# <div style="text-align: center; color: cyan">Train</div>

## <div style="text-align: center; color: lime">Imports</div>

In [1]:
import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader, random_split

from sklearn.datasets import load_iris

## <div style="text-align: center; color: lime">Load the data and make the model </div>

In [2]:
iris = load_iris()

data = torch.tensor(iris.data).to(torch.float)
target = torch.tensor(iris.target)

In [3]:
class IRISDataset(Dataset):
    def __init__(self, data, target):
        super().__init__()
        self.data = data
        self.target = target

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.target[idx]


iris_dataset = IRISDataset(data, target)

In [4]:
g1 = torch.Generator().manual_seed(20)
train_data, val_data, test_data = random_split(iris_dataset, [0.7, 0.2, 0.1], g1)

train_loader = DataLoader(train_data, batch_size=10, shuffle=True)
val_loader = DataLoader(val_data, batch_size=10, shuffle=False)
test_loader = DataLoader(test_data, batch_size=10, shuffle=False)

In [5]:
class IRISClassifier(nn.Module):
    def __init__(self):
        super().__init__()

        self.layers = nn.Sequential(
            nn.Linear(4, 16),
            nn.Linear(16, 8),
            nn.Linear(8, 3),
        )

    def forward(self, x):
        return self.layers(x)


model = IRISClassifier()

## <div style="text-align: center; color: lime">Train the model</div>

In [6]:
loss_fn = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters())

In [7]:
model.train()

for batch_of_data, batch_of_target in train_loader:
    optimizer.zero_grad()

    logits = model(batch_of_data)

    loss = loss_fn(logits, batch_of_target)
    print(f"loss: {loss.item()}")

    loss.backward()

    optimizer.step()


loss: 0.99995356798172
loss: 0.9977426528930664
loss: 0.9675296545028687
loss: 0.966225266456604
loss: 0.944222629070282
loss: 0.9690997004508972
loss: 1.0804362297058105
loss: 0.9989019632339478
loss: 0.9921015501022339
loss: 0.9486616253852844
loss: 1.0657466650009155


## <div style="text-align: center; color: lime">Evaluate the model</div>

In [11]:
model.eval()

with torch.inference_mode():
    total_loss = 0

    for batch_of_data, batch_of_target in val_loader:
        logits = model(batch_of_data)

        loss = loss_fn(logits, batch_of_target)
        total_loss += loss.item()

    print(f"average_loss: {total_loss / len(val_loader)}")

average_loss: 1.0044949253400166


In [13]:
model.eval()

with torch.inference_mode():
    total_loss = 0
    total_correct = 0

    for batch_of_data, batch_of_target in val_loader:
        logits = model(batch_of_data)

        loss = loss_fn(logits, batch_of_target)
        total_loss += loss.item()

        predictions = logits.argmax(dim=1)
        total_correct += predictions.eq(batch_of_target).sum().item()

    print(f"average_loss: {total_loss / len(val_loader)}")
    print(f"accuracy: {total_correct / len(val_loader.dataset)}")


average_loss: 1.0044949253400166
accuracy: 0.6333333333333333


## <div style="text-align: center; color: lime">Run on Accelerator</div>

<div style="text-align: center">

<div>
    @LiterallyTheOne — PhD Candidate in Artificial Intelligence
</div>

<a style="margin: 1em" href="https://literallytheone.github.io">
https://literallytheone.github.io
</a>

</div>
