# <div style="text-align: center; color: cyan">Train</div>

## <div style="text-align: center; color: lime">Imports</div>

In [2]:
import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader, random_split

from sklearn.datasets import load_iris

## <div style="text-align: center; color: lime">Load the data and make the model </div>

In [3]:
iris = load_iris()

In [4]:
class IRISDataset(Dataset):
    def __init__(self, data, target):
        super().__init__()
        self.data = data
        self.target = target

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data = torch.tensor(self.data[idx]).to(torch.float)
        target = torch.tensor(self.target[idx])
        return data, target

iris_dataset = IRISDataset(iris.data, iris.target)

In [5]:
g1 = torch.Generator().manual_seed(20)
train_data, val_data, test_data = random_split(iris_dataset, [0.7, 0.2, 0.1], g1)

train_loader = DataLoader(train_data, batch_size=10, shuffle=True)
val_loader = DataLoader(val_data, batch_size=10, shuffle=False)
test_loader = DataLoader(test_data, batch_size=10, shuffle=False)

In [6]:
class IRISClassifier(nn.Module):
    def __init__(self):
        super().__init__()

        self.layers = nn.Sequential(
            nn.Linear(4, 16),
            nn.Linear(16, 8),
            nn.Linear(8, 3),
        )

    def forward(self, x):
        return self.layers(x)


model = IRISClassifier()

## <div style="text-align: center; color: lime">Train the model</div>

In [7]:
loss_fn = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters())

In [8]:
model.train()

for batch_of_data, batch_of_target in train_loader:
    optimizer.zero_grad()

    logits = model(batch_of_data)

    loss = loss_fn(logits, batch_of_target)
    print(f"loss: {loss.item()}")

    loss.backward()

    optimizer.step()


loss: 1.0928857326507568
loss: 1.250089168548584
loss: 1.3968603610992432
loss: 1.048239827156067
loss: 1.3411037921905518
loss: 0.9701024889945984
loss: 1.0000779628753662
loss: 0.9390740394592285
loss: 0.9554198384284973
loss: 0.9382257461547852
loss: 1.0476295948028564


## <div style="text-align: center; color: lime">Evaluate the model</div>

In [9]:
model.eval()

with torch.inference_mode():
    total_loss = 0

    for batch_of_data, batch_of_target in val_loader:
        logits = model(batch_of_data)

        loss = loss_fn(logits, batch_of_target)
        total_loss += loss.item()

    print(f"average_loss: {total_loss / len(val_loader)}")

average_loss: 0.9457583427429199


In [10]:
model.eval()

with torch.inference_mode():
    total_loss = 0
    total_correct = 0

    for batch_of_data, batch_of_target in val_loader:
        logits = model(batch_of_data)

        loss = loss_fn(logits, batch_of_target)
        total_loss += loss.item()

        predictions = logits.argmax(dim=1)
        total_correct += predictions.eq(batch_of_target).sum().item()

    print(f"average_loss: {total_loss / len(val_loader)}")
    print(f"accuracy: {total_correct / len(val_loader.dataset)}")


average_loss: 0.9457583427429199
accuracy: 0.5333333333333333


## <div style="text-align: center; color: lime">make train_step and val_step</div>

In [11]:
def train_step():
    model.train()

    total_loss = 0

    for batch_of_data, batch_of_target in train_loader:
        optimizer.zero_grad()

        logits = model(batch_of_data)

        loss = loss_fn(logits, batch_of_target)
        total_loss += loss.item()

        loss.backward()

        optimizer.step()

    print(f"training average_loss: {total_loss / len(train_loader)}")


In [12]:
def val_step():
    model.eval()

    with torch.inference_mode():
        total_loss = 0
        total_correct = 0

        for batch_of_data, batch_of_target in val_loader:
            logits = model(batch_of_data)

            loss = loss_fn(logits, batch_of_target)
            total_loss += loss.item()

            predictions = logits.argmax(dim=1)
            total_correct += predictions.eq(batch_of_target).sum().item()

        print(f"validation average_loss: {total_loss / len(val_loader)}")
        print(f"validation accuracy: {total_correct / len(val_loader.dataset)}")


In [13]:
train_step()

training average_loss: 0.9776997186920859


In [14]:
val_step()

validation average_loss: 0.9329464634259542
validation accuracy: 0.7666666666666667


## <div style="text-align: center; color: lime">Epoch</div>

In [15]:
model = IRISClassifier()

loss_fn = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters())

for epoch in range(5):
    print("-" * 20)
    print(f"epoch: {epoch}")
    train_step()
    val_step()

--------------------
epoch: 0
training average_loss: 1.015959923917597
validation average_loss: 0.9808167219161987
validation accuracy: 0.6333333333333333
--------------------
epoch: 1
training average_loss: 0.972325016151775
validation average_loss: 0.9715630809466044
validation accuracy: 0.6333333333333333
--------------------
epoch: 2
training average_loss: 0.9439240585673939
validation average_loss: 0.9235706726710001
validation accuracy: 0.6333333333333333
--------------------
epoch: 3
training average_loss: 0.9132903164083307
validation average_loss: 0.8667411605517069
validation accuracy: 0.5333333333333333
--------------------
epoch: 4
training average_loss: 0.8846907778219744
validation average_loss: 0.816445509592692
validation accuracy: 0.4666666666666667


## <div style="text-align: center; color: lime">Run on Accelerator</div>

In [16]:
if torch.accelerator.is_available():
    device = torch.accelerator.current_accelerator()
else:
    device = "cpu"

print(device)

mps


In [17]:
def train_step():
    model.train()

    total_loss = 0

    for batch_of_data, batch_of_target in train_loader:
        batch_of_data = batch_of_data.to(device)
        batch_of_target = batch_of_target.to(device)

        optimizer.zero_grad()

        logits = model(batch_of_data)

        loss = loss_fn(logits, batch_of_target)
        total_loss += loss.item()

        loss.backward()

        optimizer.step()

    print(f"training average_loss: {total_loss / len(train_loader)}")


In [18]:
def val_step():
    model.eval()

    with torch.inference_mode():
        total_loss = 0
        total_correct = 0

        for batch_of_data, batch_of_target in val_loader:
            batch_of_data = batch_of_data.to(device)
            batch_of_target = batch_of_target.to(device)

            logits = model(batch_of_data)

            loss = loss_fn(logits, batch_of_target)
            total_loss += loss.item()

            predictions = logits.argmax(dim=1)
            total_correct += predictions.eq(batch_of_target).sum().item()

        print(f"validation average_loss: {total_loss / len(val_loader)}")
        print(f"validation accuracy: {total_correct / len(val_loader.dataset)}")



In [19]:
model = IRISClassifier()
model.to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters())

for epoch in range(5):
    print("-" * 20)
    print(f"epoch: {epoch}")
    train_step()
    val_step()


--------------------
epoch: 0
training average_loss: 1.2106390487064014
validation average_loss: 1.2599918444951375
validation accuracy: 0.2
--------------------
epoch: 1
training average_loss: 1.1163438666950574
validation average_loss: 1.135883132616679
validation accuracy: 0.5
--------------------
epoch: 2
training average_loss: 1.0798369212584062
validation average_loss: 1.0811514457066853
validation accuracy: 0.36666666666666664
--------------------
epoch: 3
training average_loss: 1.0335172794081948
validation average_loss: 1.0472945372263591
validation accuracy: 0.3333333333333333
--------------------
epoch: 4
training average_loss: 1.0019696192307905
validation average_loss: 0.9941175182660421
validation accuracy: 0.36666666666666664


## <div style="text-align: center; color: lime">Save and load our model</div>

In [20]:
torch.save(model.state_dict(), "model.pth")

In [21]:
new_model = IRISClassifier()

weights = torch.load("model.pth")

new_model.load_state_dict(weights)

new_model = new_model.to(device)

In [22]:
for key in new_model.state_dict().keys():
    if key not in model.state_dict().keys():
        print(f"Key {key} not in model.state_dict()")
        break

    if not torch.allclose(new_model.state_dict()[key], model.state_dict()[key]):
        print("Values are different")
        break


<div style="text-align: center">

<div>
    @LiterallyTheOne — PhD Candidate in Artificial Intelligence
</div>

<a style="margin: 1em" href="https://literallytheone.github.io">
https://literallytheone.github.io
</a>

</div>
