# <div style="text-align: center; color: cyan">Train</div>

## <div style="text-align: center; color: lime">Imports</div>

In [2]:
import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader, random_split

from sklearn.datasets import load_iris

## <div style="text-align: center; color: lime">Load the data and make the model </div>

In [3]:
iris = load_iris()

data = torch.tensor(iris.data).to(torch.float)
target = torch.tensor(iris.target)

In [4]:
class IRISDataset(Dataset):
    def __init__(self, data, target):
        super().__init__()
        self.data = data
        self.target = target

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.target[idx]


iris_dataset = IRISDataset(data, target)

In [5]:
g1 = torch.Generator().manual_seed(20)
train_data, val_data, test_data = random_split(iris_dataset, [0.7, 0.2, 0.1], g1)

train_loader = DataLoader(train_data, batch_size=10, shuffle=True)
val_loader = DataLoader(val_data, batch_size=10, shuffle=False)
test_loader = DataLoader(test_data, batch_size=10, shuffle=False)

In [6]:
class IRISClassifier(nn.Module):
    def __init__(self):
        super().__init__()

        self.layers = nn.Sequential(
            nn.Linear(4, 16),
            nn.Linear(16, 8),
            nn.Linear(8, 3),
        )

    def forward(self, x):
        return self.layers(x)


model = IRISClassifier()

## <div style="text-align: center; color: lime">Train the model</div>

In [7]:
loss_fn = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters())

In [8]:
model.train()

for batch_of_data, batch_of_target in train_loader:
    optimizer.zero_grad()

    logits = model(batch_of_data)

    loss = loss_fn(logits, batch_of_target)
    print(f"loss: {loss.item()}")

    loss.backward()

    optimizer.step()


loss: 0.7640460729598999
loss: 0.7941502332687378
loss: 0.8383936882019043
loss: 0.5752513408660889
loss: 0.7075636386871338
loss: 0.8490365147590637
loss: 0.8173547983169556
loss: 0.9585196375846863
loss: 0.7621733546257019
loss: 0.730190634727478
loss: 0.9111936688423157


## <div style="text-align: center; color: lime">Evaluate the model</div>

In [9]:
model.eval()

with torch.inference_mode():
    total_loss = 0

    for batch_of_data, batch_of_target in val_loader:
        logits = model(batch_of_data)

        loss = loss_fn(logits, batch_of_target)
        total_loss += loss.item()

    print(f"average_loss: {total_loss / len(val_loader)}")

average_loss: 0.704477588335673


In [10]:
model.eval()

with torch.inference_mode():
    total_loss = 0
    total_correct = 0

    for batch_of_data, batch_of_target in val_loader:
        logits = model(batch_of_data)

        loss = loss_fn(logits, batch_of_target)
        total_loss += loss.item()

        predictions = logits.argmax(dim=1)
        total_correct += predictions.eq(batch_of_target).sum().item()

    print(f"average_loss: {total_loss / len(val_loader)}")
    print(f"accuracy: {total_correct / len(val_loader.dataset)}")


average_loss: 0.704477588335673
accuracy: 0.6333333333333333


## <div style="text-align: center; color: lime">make train_step and val_step</div>

In [11]:
def train_step():
    model.train()

    total_loss = 0

    for batch_of_data, batch_of_target in train_loader:
        optimizer.zero_grad()

        logits = model(batch_of_data)

        loss = loss_fn(logits, batch_of_target)
        total_loss += loss.item()

        loss.backward()

        optimizer.step()

    print(f"training average_loss: {total_loss / len(train_loader)}")


In [12]:
def val_step():
    model.eval()

    with torch.inference_mode():
        total_loss = 0
        total_correct = 0

        for batch_of_data, batch_of_target in val_loader:
            logits = model(batch_of_data)

            loss = loss_fn(logits, batch_of_target)
            total_loss += loss.item()

            predictions = logits.argmax(dim=1)
            total_correct += predictions.eq(batch_of_target).sum().item()

        print(f"validation average_loss: {total_loss / len(val_loader)}")
        print(f"validation accuracy: {total_correct / len(val_loader.dataset)}")


In [13]:
train_step()

training average_loss: 0.7337634075771678


In [14]:
val_step()

validation average_loss: 0.6558424830436707
validation accuracy: 0.6333333333333333


## <div style="text-align: center; color: lime">Epoch</div>

In [18]:
model = IRISClassifier()

loss_fn = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters())

for epoch in range(5):
    print("-" * 20)
    print(f"epoch: {epoch}")
    train_step()
    val_step()

--------------------
epoch: 0
training average_loss: 1.050121784210205
validation average_loss: 1.0271871089935303
validation accuracy: 0.4
--------------------
epoch: 1
training average_loss: 1.014872670173645
validation average_loss: 0.9773193200429281
validation accuracy: 0.36666666666666664
--------------------
epoch: 2
training average_loss: 0.977487173947421
validation average_loss: 0.943575362364451
validation accuracy: 0.5
--------------------
epoch: 3
training average_loss: 0.9405033750967546
validation average_loss: 0.901186982790629
validation accuracy: 0.6
--------------------
epoch: 4
training average_loss: 0.9058771566911177
validation average_loss: 0.8475126425425211
validation accuracy: 0.7666666666666667


## <div style="text-align: center; color: lime">Run on Accelerator</div>

In [19]:
if torch.accelerator.is_available():
    device = torch.accelerator.current_accelerator()
else:
    device = "cpu"

print(device)

mps


In [20]:
def train_step():
    model.train()

    total_loss = 0

    for batch_of_data, batch_of_target in train_loader:
        batch_of_data = batch_of_data.to(device)
        batch_of_target = batch_of_target.to(device)

        optimizer.zero_grad()

        logits = model(batch_of_data)

        loss = loss_fn(logits, batch_of_target)
        total_loss += loss.item()

        loss.backward()

        optimizer.step()

    print(f"training average_loss: {total_loss / len(train_loader)}")


In [21]:
def val_step():
    model.eval()

    with torch.inference_mode():
        total_loss = 0
        total_correct = 0

        for batch_of_data, batch_of_target in val_loader:
            batch_of_data = batch_of_data.to(device)
            batch_of_target = batch_of_target.to(device)

            logits = model(batch_of_data)

            loss = loss_fn(logits, batch_of_target)
            total_loss += loss.item()

            predictions = logits.argmax(dim=1)
            total_correct += predictions.eq(batch_of_target).sum().item()

        print(f"validation average_loss: {total_loss / len(val_loader)}")
        print(f"validation accuracy: {total_correct / len(val_loader.dataset)}")



In [23]:
model = IRISClassifier()
model.to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters())

for epoch in range(5):
    print("-" * 20)
    print(f"epoch: {epoch}")
    train_step()
    val_step()


--------------------
epoch: 0
training average_loss: 1.0796789039265027
validation average_loss: 1.094805916150411
validation accuracy: 0.3
--------------------
epoch: 1
training average_loss: 1.0360274802554736
validation average_loss: 1.0365037123362224
validation accuracy: 0.23333333333333334
--------------------
epoch: 2
training average_loss: 0.9956214427947998
validation average_loss: 0.9788729349772135
validation accuracy: 0.7
--------------------
epoch: 3
training average_loss: 0.9541016221046448
validation average_loss: 0.9327105482419332
validation accuracy: 0.6333333333333333
--------------------
epoch: 4
training average_loss: 0.9093180786479603
validation average_loss: 0.8656323750813802
validation accuracy: 0.6666666666666666


## <div style="text-align: center; color: lime">Save and load our model</div>

In [24]:
torch.save(model.state_dict(), "model.pth")

In [37]:
new_model = IRISClassifier()

weights = torch.load("model.pth")

new_model.load_state_dict(weights)

new_model = new_model.to(device)

In [34]:
for key in new_model.state_dict().keys():
    if key not in model.state_dict().keys():
        print(f"Key {key} not in model.state_dict()")
        break

    if not torch.allclose(new_model.state_dict()[key], model.state_dict()[key]):
        print("Values are different")
        break


<div style="text-align: center">

<div>
    @LiterallyTheOne — PhD Candidate in Artificial Intelligence
</div>

<a style="margin: 1em" href="https://literallytheone.github.io">
https://literallytheone.github.io
</a>

</div>
