# Multi-class Image Classification 

## Dependencies

In [49]:
import torch
import torch.nn as nn
from torchmetrics import Accuracy
from torchvision import models
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
import torch.optim as optim


- Apply transforms to the dataset (preprocessing step).
- We could add data augmentation such as rotation and flipping, but I'll keep it simple.

In [50]:
# Preprocessing pipeline: resize every image to a fixed 128x128, convert it to a
# tensor, then normalise each channel with the standard ImageNet statistics
# (matching the ImageNet-pretrained ResNet loaded further down).
IMAGE_SIZE = (128, 128)
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
)

# ImageFolder treats each sub-directory of DATA_DIR as one class and labels the
# images it contains accordingly; the same transform is applied to every image.
DATA_DIR = "data"
dataset = datasets.ImageFolder(DATA_DIR, transform=transform)

- Split the data into training (80%) and validation (20%) sets.
- The dataset has 3,000 images, 1,000 per class (cat, dog, snake).
- So the split is 3000 × 0.8 = 2,400 training images and 3000 × 0.2 = 600 validation images.

In [51]:
# Derive the split sizes from the actual number of images, so this cell keeps
# working if images are added or removed. For 3000 images this is 2400 / 600.
# The split is seeded so train/val membership is reproducible across kernel restarts.
SPLIT_SEED = 42
TRAIN_FRACTION = 0.8
n_train = int(len(dataset) * TRAIN_FRACTION)
n_val = len(dataset) - n_train  # remainder, so the two sizes always sum to len(dataset)
train_dataset, val_dataset = random_split(
    dataset,
    [n_train, n_val],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [52]:
# Batch both splits; only the training split is reshuffled each epoch.
BATCH_SIZE = 32
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

# Sanity check: pull a single training batch and show the image/label tensor shapes.
batch_iterator = iter(train_loader)
images, labels = next(batch_iterator)
print(images.shape, labels.shape)

torch.Size([32, 3, 128, 128]) torch.Size([32])


In [53]:

# Load ResNet-18 with ImageNet-pretrained weights as the feature extractor.
model = models.resnet18(weights="IMAGENET1K_V1")

# Replace the pretrained classification head with one sized to our data.
# ImageFolder derives `dataset.classes` from the sub-folder names, so the
# output size always matches the data rather than a hard-coded count
# (3 here: cat, dog, snake).
num_classes = len(dataset.classes)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, num_classes)

In [54]:

# Loss and optimiser; both are reused by the checkpointing loop further down.
criterion = nn.CrossEntropyLoss()               # batch-mean cross-entropy for classification
optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epochs = 5

for epoch in range(num_epochs):
    model.train()  # training mode (affects layers such as BatchNorm/Dropout)
    total_loss = 0.0

    for inputs, labels in train_loader:
        # reset gradients accumulated by the previous step
        optimizer.zero_grad()

        # forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # backward pass + parameter update
        loss.backward()
        optimizer.step()

        # `loss` is a mean over the batch. Scale it back to a sum, so a smaller
        # final batch is not over-weighted in the epoch average.
        total_loss += loss.item() * inputs.size(0)

    # true per-sample mean loss for the epoch
    epoch_loss = total_loss / len(train_loader.dataset)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")


Epoch [1/5], Loss: 0.4571
Epoch [2/5], Loss: 0.2224
Epoch [3/5], Loss: 0.1479
Epoch [4/5], Loss: 0.0977
Epoch [5/5], Loss: 0.1039


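- Training loss generally decreases as the number of epochs increases, with a small uptick in epoch 5.
- Now that the model is trained, let's validate it. The next cell continues training the same model for 10 more epochs. After each epoch it measures validation accuracy and keeps the best checkpoint.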

In [None]:
# --- Training loop with validation and best-model checkpointing ---
# NOTE: this continues training the `model` already trained by the previous cell.
num_epochs = 10
best_acc = float("-inf")  # ensures the first epoch's weights are always checkpointed
val_accuracy = Accuracy(task='multiclass', num_classes=len(dataset.classes))


def _train_one_epoch(net, loader, loss_fn, opt):
    """Run one optimisation pass over `loader` and return the per-sample mean loss."""
    net.train()
    running_loss = 0.0
    for inputs, labels in loader:
        opt.zero_grad()
        loss = loss_fn(net(inputs), labels)
        loss.backward()
        opt.step()
        # loss is a batch mean; weight by batch size so a short last batch isn't over-counted
        running_loss += loss.item() * inputs.size(0)
    return running_loss / len(loader.dataset)


def _validation_accuracy(net, loader, metric):
    """Evaluate `net` on `loader` with gradients disabled; return `metric` as a percentage."""
    net.eval()
    metric.reset()  # clear state accumulated in the previous epoch
    with torch.no_grad():
        for inputs, labels in loader:
            metric.update(net(inputs), labels)
    return metric.compute().item() * 100


for epoch in range(num_epochs):
    epoch_loss = _train_one_epoch(model, train_loader, criterion, optimizer)
    val_acc = _validation_accuracy(model, val_loader, val_accuracy)

    print(f"Epoch [{epoch+1}/{num_epochs}] "
          f"Train Loss: {epoch_loss:.4f} "
          f"Val Acc: {val_acc:.2f}%")

    # Save the weights with the best validation accuracy seen so far.
    # The in-memory `model` keeps the last epoch's weights.
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), "best_model.pth")

Epoch [1/10] Train Loss: 0.1331 Val Acc: 90.17%
Epoch [2/10] Train Loss: 0.0910 Val Acc: 84.00%
Epoch [3/10] Train Loss: 0.0509 Val Acc: 87.50%
Epoch [4/10] Train Loss: 0.0920 Val Acc: 82.67%
Epoch [5/10] Train Loss: 0.0976 Val Acc: 92.50%
Epoch [6/10] Train Loss: 0.0230 Val Acc: 94.17%
Epoch [7/10] Train Loss: 0.0344 Val Acc: 88.83%
Epoch [8/10] Train Loss: 0.0771 Val Acc: 88.83%
Epoch [9/10] Train Loss: 0.0999 Val Acc: 93.83%
Epoch [10/10] Train Loss: 0.0141 Val Acc: 94.17%


- We have now finished training and validation. The best validation accuracy (94.17%) was also used to choose the saved checkpoint, so it is a somewhat optimistic estimate.
- The next step is to test the model on images it has never seen, to check whether it can recognise them.
- We downloaded three test images from the internet to try on the model. That is done in a separate Jupyter notebook called 'test'.