# STL10 Image Classification (PyTorch)


This notebook trains:
- **ResNet18 (transfer learning)**, fine-tuned from ImageNet-pretrained weights

Deliverables included:
- Training/validation accuracy & loss plots
- Confusion matrix
- Sample predictions
- Saved model weights

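**Dataset:** `STL10` (10 classes, 96x96 images). Labeled splits used: `train` (5k) and `test` (8k). The `test` split also serves as the validation set (per-epoch "val" metrics and best-epoch selection), so the reported best accuracy is optimistically biased.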


In [None]:
# Environment & imports
import os, time, math, random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as T
from torchvision.datasets import STL10
from torchvision import models
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report
import itertools

# Use the GPU when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

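## 1) Data Loading
- STL10 images are **96x96**.
- The 96x96 loaders below keep the native resolution; they are used to preview samples (a **Simple CNN** could consume them directly, but this notebook only trains ResNet18).
- For **ResNet18**, we **resize to 224x224** later.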


In [None]:
# Per-channel statistics used to normalize the native-resolution (96x96) images.
mean_96 = (0.4467, 0.4398, 0.4066)
std_96  = (0.2241, 0.2215, 0.2239)

# Evaluation only tensorizes + normalizes; training additionally flips at random.
test_tf_96 = T.Compose([T.ToTensor(), T.Normalize(mean_96, std_96)])
train_tf_96 = T.Compose([T.RandomHorizontalFlip(), T.ToTensor(), T.Normalize(mean_96, std_96)])

data_root = './data'

train_ds_96 = STL10(root=data_root, split='train', download=True, transform=train_tf_96)
test_ds_96  = STL10(root=data_root, split='test', download=True, transform=test_tf_96)

# Both loaders share everything except shuffling.
batch_size = 64
loader_opts = dict(batch_size=batch_size, num_workers=2, pin_memory=True)
train_loader_96 = DataLoader(train_ds_96, shuffle=True, **loader_opts)
test_loader_96  = DataLoader(test_ds_96, shuffle=False, **loader_opts)

class_names = train_ds_96.classes
num_classes = len(class_names)
num_classes, class_names


In [None]:
def denorm(img, mean=mean_96, std=std_96):
    """Undo ``T.Normalize(mean, std)`` so an image tensor can be displayed.

    Accepts a single image ``(C, H, W)`` or a batch ``(N, C, H, W)``. The
    statistics are reshaped to ``(C, 1, 1)`` so they broadcast over the last
    three axes in both cases. (Indexing ``img[c]`` would select *images*
    0..C-1 of a batch instead of channels.)

    Args:
        img: normalized float image tensor; it is not modified in place.
        mean: per-channel means that were used for normalization; its length
            must match the channel count.
        std: per-channel standard deviations used for normalization.

    Returns:
        A new tensor with the same shape as ``img``, de-normalized and
        clamped to ``[0, 1]``.
    """
    mean_t = torch.as_tensor(mean, dtype=img.dtype, device=img.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=img.dtype, device=img.device).view(-1, 1, 1)
    return (img * std_t + mean_t).clamp(0, 1)

# Preview the first 16 images of one (shuffled) training batch as a grid.
images, labels = next(iter(train_loader_96))
grid = torchvision.utils.make_grid(denorm(images[:16]))
plt.figure()
# make_grid yields (C, H, W); matplotlib wants channels last.
plt.imshow(grid.permute(1, 2, 0).numpy())
plt.title('Sample training images')
plt.axis('off')
plt.show()


### Making a training function

In [None]:
def train_one_epoch(model, loader, criteria, optimizer):
    """Run one optimization pass over ``loader``.

    Batches are moved to the module-level ``device`` before the forward pass.

    Args:
        model: network to train; switched to training mode here.
        loader: iterable of ``(inputs, targets)`` batches.
        criteria: loss function called as ``criteria(logits, targets)``.
        optimizer: optimizer stepping ``model``'s parameters.

    Returns:
        ``(mean_loss, accuracy)`` over every sample seen this epoch. The
        per-batch loss is weighted by batch size, so a short final batch
        counts proportionally.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criteria(logits, targets)
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item() * inputs.size(0)
        predicted = logits.max(dim=1).indices
        n_correct += (predicted == targets).sum().item()
        n_seen += targets.size(0)

    mean_loss = loss_sum / n_seen
    accuracy = n_correct / n_seen
    return mean_loss, accuracy

### Creating an evaluation function

In [None]:
def evaluate(model, loader, criteria):
    """Score ``model`` on ``loader`` without tracking gradients.

    Batches are moved to the module-level ``device``; predictions and labels
    are collected on the CPU so they can feed sklearn metrics.

    Args:
        model: network to evaluate; switched to eval mode here.
        loader: iterable of ``(inputs, targets)`` batches.
        criteria: loss function called as ``criteria(logits, targets)``.

    Returns:
        ``(mean_loss, accuracy, all_predictions, all_labels)``, where the last
        two are 1-D CPU tensors concatenated in loader order.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    pred_chunks = []
    label_chunks = []
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            logits = model(inputs)
            batch_loss = criteria(logits, targets)
            loss_sum += batch_loss.item() * inputs.size(0)
            predicted = logits.max(dim=1).indices
            n_correct += (predicted == targets).sum().item()
            n_seen += targets.size(0)
            pred_chunks.append(predicted.cpu())
            label_chunks.append(targets.cpu())
    return (
        loss_sum / n_seen,
        n_correct / n_seen,
        torch.cat(pred_chunks),
        torch.cat(label_chunks),
    )


## ResNet18 (transfer learning)

In [None]:
# ResNet18 was pretrained on 224x224 ImageNet inputs, so STL10's 96x96 images are
# upsampled and normalized with the ImageNet channel statistics.
train_tf_224 = T.Compose([
    T.Resize((224,224)),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
])

test_tf_224 = T.Compose([
    T.Resize((224,224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
])

train_ds_224 = STL10(root=data_root, split='train', download=True, transform=train_tf_224)
test_ds_224  = STL10(root=data_root, split='test', download=True, transform=test_tf_224)

train_loader_224 = DataLoader(train_ds_224, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
test_loader_224  = DataLoader(test_ds_224, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

# Start from ImageNet weights and replace the classification head with a
# num_classes-way linear layer.
resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
in_features = resnet.fc.in_features
resnet.fc = nn.Linear(in_features, num_classes)
resnet = resnet.to(device)

criterion_res = nn.CrossEntropyLoss()
# All parameters (backbone included) are optimized, hence the small learning rate.
optimizer_res = optim.Adam(resnet.parameters(), lr=1e-4)

epochs_res = 6  # quick fine-tune
history_res = {"train_loss":[], "train_acc":[], "val_loss":[], "val_acc":[]}
best_acc_res, best_state_res = 0.0, None

# NOTE: the "val" metrics are computed on the STL10 *test* split, which is also
# used to pick the best epoch, so the reported best accuracy is optimistic.
for ep in range(1, epochs_res+1):
    tl, ta = train_one_epoch(resnet, train_loader_224, criterion_res, optimizer_res)
    vl, va, _, _ = evaluate(resnet, test_loader_224, criterion_res)
    history_res['train_loss'].append(tl)
    history_res['train_acc'].append(ta)
    history_res['val_loss'].append(vl)
    history_res['val_acc'].append(va)
    print(f"[ResNet18] Epoch {ep}/{epochs_res}: train_loss={tl:.4f} train_acc={ta:.4f} | val_loss={vl:.4f} val_acc={va:.4f}")
    if va > best_acc_res:
        best_acc_res = va
        # state_dict() returns references to the live parameter/buffer tensors,
        # which later optimizer steps update in place. Clone them so this
        # snapshot really holds the best epoch's weights, not the final epoch's.
        best_state_res = {k: v.detach().clone() for k, v in resnet.state_dict().items()}

# Restore the best-scoring weights before saving and before later cells use `resnet`.
if best_state_res is not None:
    resnet.load_state_dict(best_state_res)
    os.makedirs('checkpoints', exist_ok=True)
    torch.save(resnet.state_dict(), 'checkpoints/stl10_resnet18.pt')
    print('Best ResNet18 saved to checkpoints/stl10_resnet18.pt')


### Accuracy and loss curves

In [None]:
# One figure per metric: (train key, val key, y-axis label, figure title).
curve_specs = (
    ('train_acc', 'val_acc', 'Accuracy', 'Accuracy over epochs (ResNet18)'),
    ('train_loss', 'val_loss', 'Loss', 'Loss over epochs (ResNet18)'),
)

if 'history_res' in globals() and len(history_res['train_acc']) > 0:
    for train_key, val_key, y_label, fig_title in curve_specs:
        plt.figure()
        plt.plot(history_res[train_key], label='train')
        plt.plot(history_res[val_key], label='val')
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.title(fig_title)
        plt.legend()
        plt.show()
else:
    print('There is a problem with resnet. Fix it')


### Confusion matrix

In [None]:
if 'resnet' not in globals():
    print('There is a problem with resnet. Fix it')
else:
    # Re-score the (best-restored) model on the test split to get per-sample predictions.
    vl, va, preds_res, labels_res = evaluate(resnet, test_loader_224, criterion_res)
    y_true = labels_res.numpy()
    y_pred = preds_res.numpy()

    cm_res = confusion_matrix(y_true, y_pred)
    plt.figure(figsize=(6, 6))
    plt.imshow(cm_res, interpolation='nearest')
    plt.title('Confusion Matrix (ResNet18)')
    plt.colorbar()
    tick_marks = np.arange(num_classes)
    plt.xticks(tick_marks, class_names, rotation=45, ha='right')
    plt.yticks(tick_marks, class_names)

    # Annotate each cell; switch to white text on the darker (higher-count) half.
    thresh = cm_res.max() / 2.0
    n_rows, n_cols = cm_res.shape
    for i in range(n_rows):
        for j in range(n_cols):
            cell = cm_res[i, j]
            plt.text(j, i, cell, horizontalalignment="center",
                     color="white" if cell > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()
    plt.show()
    print('Classification Report (ResNet18):')
    print(classification_report(y_true, y_pred, target_names=class_names))
