# <div style="text-align: center; color: cyan">Train</div>

## <div style="text-align: center; color: lime">Imports</div>

In [1]:
import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader, random_split

from sklearn.datasets import load_iris

## <div style="text-align: center; color: lime">Load the data and make the model </div>

In [2]:
iris = load_iris()

In [3]:
class IRISDataset(Dataset):
    def __init__(self, data, target):
        super().__init__()
        self.data = data
        self.target = target

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data = torch.tensor(self.data[idx]).to(torch.float)
        target = torch.tensor(self.target[idx])
        return data, target

iris_dataset = IRISDataset(iris.data, iris.target)

In [4]:
g1 = torch.Generator().manual_seed(20)
train_data, val_data, test_data = random_split(iris_dataset, [0.7, 0.2, 0.1], g1)

train_loader = DataLoader(train_data, batch_size=10, shuffle=True)
val_loader = DataLoader(val_data, batch_size=10, shuffle=False)
test_loader = DataLoader(test_data, batch_size=10, shuffle=False)

In [5]:
class IRISClassifier(nn.Module):
    def __init__(self):
        super().__init__()

        self.layers = nn.Sequential(
            nn.Linear(4, 16),
            nn.Linear(16, 8),
            nn.Linear(8, 3),
        )

    def forward(self, x):
        return self.layers(x)


model = IRISClassifier()

## <div style="text-align: center; color: lime">Train the model</div>

In [6]:
loss_fn = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters())

In [7]:
model.train()

for batch_of_data, batch_of_target in train_loader:
    optimizer.zero_grad()

    logits = model(batch_of_data)

    loss = loss_fn(logits, batch_of_target)
    print(f"loss: {loss.item()}")

    loss.backward()

    optimizer.step()


loss: 1.1131248474121094
loss: 1.2602981328964233
loss: 1.1842542886734009
loss: 1.1304521560668945
loss: 1.1646679639816284
loss: 1.1598676443099976
loss: 1.036192536354065
loss: 1.1502324342727661
loss: 0.9869096875190735
loss: 1.1336429119110107
loss: 1.0729477405548096


## <div style="text-align: center; color: lime">Evaluate the model</div>

In [8]:
model.eval()

with torch.inference_mode():
    total_loss = 0

    for batch_of_data, batch_of_target in val_loader:
        logits = model(batch_of_data)

        loss = loss_fn(logits, batch_of_target)
        total_loss += loss.item()

    print(f"average_loss: {total_loss / len(val_loader)}")

average_loss: 1.1142856280008953


In [9]:
model.eval()

with torch.inference_mode():
    total_loss = 0
    total_correct = 0

    for batch_of_data, batch_of_target in val_loader:
        logits = model(batch_of_data)

        loss = loss_fn(logits, batch_of_target)
        total_loss += loss.item()

        predictions = logits.argmax(dim=1)
        total_correct += predictions.eq(batch_of_target).sum().item()

    print(f"average_loss: {total_loss / len(val_loader)}")
    print(f"accuracy: {total_correct / len(val_loader.dataset)}")


average_loss: 1.1142856280008953
accuracy: 0.23333333333333334


## <div style="text-align: center; color: lime">make train_step and val_step</div>

In [10]:
def train_step():
    model.train()

    total_loss = 0

    for batch_of_data, batch_of_target in train_loader:
        optimizer.zero_grad()

        logits = model(batch_of_data)

        loss = loss_fn(logits, batch_of_target)
        total_loss += loss.item()

        loss.backward()

        optimizer.step()

    print(f"training average_loss: {total_loss / len(train_loader)}")


In [11]:
def val_step():
    model.eval()

    with torch.inference_mode():
        total_loss = 0
        total_correct = 0

        for batch_of_data, batch_of_target in val_loader:
            logits = model(batch_of_data)

            loss = loss_fn(logits, batch_of_target)
            total_loss += loss.item()

            predictions = logits.argmax(dim=1)
            total_correct += predictions.eq(batch_of_target).sum().item()

        print(f"validation average_loss: {total_loss / len(val_loader)}")
        print(f"validation accuracy: {total_correct / len(val_loader.dataset)}")


In [12]:
train_step()

training average_loss: 1.0586650588295676


In [13]:
val_step()

validation average_loss: 1.1128065983454387
validation accuracy: 0.13333333333333333


## <div style="text-align: center; color: lime">Epoch</div>

In [14]:
model = IRISClassifier()

loss_fn = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters())

for epoch in range(5):
    print("-" * 20)
    print(f"epoch: {epoch}")
    train_step()
    val_step()

--------------------
epoch: 0
training average_loss: 1.0262293327938428
validation average_loss: 1.0352061986923218
validation accuracy: 0.26666666666666666
--------------------
epoch: 1
training average_loss: 0.9824423031373457
validation average_loss: 1.0101556380589802
validation accuracy: 0.3
--------------------
epoch: 2
training average_loss: 0.9455421566963196
validation average_loss: 0.9778478145599365
validation accuracy: 0.26666666666666666
--------------------
epoch: 3
training average_loss: 0.9181961200454019
validation average_loss: 0.9337190985679626
validation accuracy: 0.5666666666666667
--------------------
epoch: 4
training average_loss: 0.8796669190580194
validation average_loss: 0.8617650866508484
validation accuracy: 0.8


## <div style="text-align: center; color: lime">Run on Accelerator</div>

In [15]:
if torch.accelerator.is_available():
    device = torch.accelerator.current_accelerator()
else:
    device = "cpu"

print(device)

mps


In [16]:
def train_step():
    model.train()

    total_loss = 0

    for batch_of_data, batch_of_target in train_loader:
        batch_of_data = batch_of_data.to(device)
        batch_of_target = batch_of_target.to(device)

        optimizer.zero_grad()

        logits = model(batch_of_data)

        loss = loss_fn(logits, batch_of_target)
        total_loss += loss.item()

        loss.backward()

        optimizer.step()

    print(f"training average_loss: {total_loss / len(train_loader)}")


In [17]:
def val_step():
    model.eval()

    with torch.inference_mode():
        total_loss = 0
        total_correct = 0

        for batch_of_data, batch_of_target in val_loader:
            batch_of_data = batch_of_data.to(device)
            batch_of_target = batch_of_target.to(device)

            logits = model(batch_of_data)

            loss = loss_fn(logits, batch_of_target)
            total_loss += loss.item()

            predictions = logits.argmax(dim=1)
            total_correct += predictions.eq(batch_of_target).sum().item()

        print(f"validation average_loss: {total_loss / len(val_loader)}")
        print(f"validation accuracy: {total_correct / len(val_loader.dataset)}")



In [18]:
model = IRISClassifier()
model.to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters())

for epoch in range(5):
    print("-" * 20)
    print(f"epoch: {epoch}")
    train_step()
    val_step()


--------------------
epoch: 0
training average_loss: 1.1079630093141035
validation average_loss: 1.1306960185368855
validation accuracy: 0.36666666666666664
--------------------
epoch: 1
training average_loss: 1.0650406967509876
validation average_loss: 1.0661596457163494
validation accuracy: 0.36666666666666664
--------------------
epoch: 2
training average_loss: 1.0330196131359448
validation average_loss: 1.0052913228670757
validation accuracy: 0.36666666666666664
--------------------
epoch: 3
training average_loss: 0.9949019930579446
validation average_loss: 0.9664387504259745
validation accuracy: 0.43333333333333335
--------------------
epoch: 4
training average_loss: 0.9623267108743842
validation average_loss: 0.905125896135966
validation accuracy: 0.36666666666666664


## <div style="text-align: center; color: lime">Save and load our model</div>

In [19]:
torch.save(model.state_dict(), "model.pth")

In [20]:
new_model = IRISClassifier()

weights = torch.load("model.pth")

new_model.load_state_dict(weights)

new_model = new_model.to(device)

In [21]:
for key in new_model.state_dict().keys():
    if key not in model.state_dict().keys():
        print(f"Key {key} not in model.state_dict()")
        break

    if not torch.allclose(new_model.state_dict()[key], model.state_dict()[key]):
        print("Values are different")
        break


<div style="text-align: center">

<div>
    @LiterallyTheOne — PhD Candidate in Artificial Intelligence
</div>

<a style="margin: 1em" href="https://literallytheone.github.io">
https://literallytheone.github.io
</a>

</div>
