In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.optim.lr_scheduler import CosineAnnealingLR

import numpy as np
from tqdm import tqdm

from datasets import load_from_disk

In [2]:
train_val_data = load_from_disk("processed_bird_data")

train_data = train_val_data["train"]
val_data = train_val_data["validation"]

print("Train samples:", len(train_data))
print("Validation samples:", len(val_data))

Train samples: 3337
Validation samples: 589


In [3]:
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),     
    transforms.RandomRotation(10),         
    transforms.ColorJitter(),              
    transforms.ToTensor(),
])

val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

In [4]:
class BirdTrainDataset(Dataset):
    def __init__(self, ds, transform=None):
        self.ds = ds
        self.transform = transform

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        item = self.ds[idx]

        img = item["image"]
        label = item["label"]

        if self.transform:
            img = self.transform(img)

        return img, label

In [5]:
train_loader = DataLoader(
    BirdTrainDataset(train_data, train_transform),
    batch_size=32,
    shuffle=True,
)

val_loader = DataLoader(
    BirdTrainDataset(val_data, val_transform),
    batch_size=32,
    shuffle=False,
)

In [6]:
class BasicBlock(nn.Module):
    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()

        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()

        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_ch)

        self.downsample = None
        if stride != 1 or in_ch != out_ch:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 1, stride=stride),
                nn.BatchNorm2d(out_ch)
            )

    def forward(self, x):
        identity = x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        if self.downsample:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out


class ResNetScratch(nn.Module):
    def __init__(self, num_classes):
        super().__init__()

        self.stem = nn.Sequential(
            nn.Conv2d(3, 32, 7, stride=2, padding=3),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(3, stride=2, padding=1),
        )

        self.layer1 = BasicBlock(32, 64, stride=2)
        self.layer2 = BasicBlock(64, 128, stride=2)
        self.layer3 = BasicBlock(128, 256, stride=2)
        self.layer4 = BasicBlock(256, 256, stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1,1))
        self.fc = nn.Linear(256, num_classes)

    def forward(self, x):
        x = self.stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)

        return self.fc(x)

In [7]:
EPOCHS = 40
NUM_CLASSES = 200
model = ResNetScratch(NUM_CLASSES)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

criterion = nn.CrossEntropyLoss()

optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)
scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS) 

In [8]:
def train_one_epoch(epoch):
    model.train()
    total_loss = 0
    correct = 0
    samples = 0

    for batch_idx, (imgs, labels) in enumerate(train_loader):
        imgs, labels = imgs.to(device), labels.to(device)

        optimizer.zero_grad()
        logits = model(imgs)
        loss = criterion(logits, labels)

        loss.backward()
        optimizer.step()

        total_loss += loss.item() * imgs.size(0)
        preds = logits.argmax(1)
        correct += (preds == labels).sum().item()
        samples += imgs.size(0)

        if batch_idx % 20 == 0:
            print(f"[Epoch {epoch}] Batch {batch_idx}/{len(train_loader)} loss={loss.item():.4f}")

    return total_loss/samples, correct/samples


def evaluate():
    model.eval()
    total_loss = 0
    correct = 0
    samples = 0

    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            logits = model(imgs)

            loss = criterion(logits, labels)
            preds = logits.argmax(1)

            total_loss += loss.item() * imgs.size(0)
            correct += (preds == labels).sum().item()
            samples += imgs.size(0)

    return total_loss/samples, correct/samples

In [9]:
best_val_acc = 0.0

for epoch in range(1, EPOCHS+1):
    print(f"\nEpoch {epoch}/{EPOCHS}")

    train_loss, train_acc = train_one_epoch(epoch)
    val_loss, val_acc = evaluate()

    scheduler.step()

    print(f"Train: loss={train_loss:.4f}, acc={train_acc:.4f}")
    print(f"Val: loss={val_loss:.4f}, acc={val_acc:.4f}")


    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), "resnet_best.pth")
        print("Best model saved")


Epoch 1/40
[Epoch 1] Batch 0/105 loss=5.2894
[Epoch 1] Batch 20/105 loss=5.2952
[Epoch 1] Batch 40/105 loss=5.1810
[Epoch 1] Batch 60/105 loss=5.1811
[Epoch 1] Batch 80/105 loss=5.2426
[Epoch 1] Batch 100/105 loss=5.2743
Train: loss=5.2207, acc=0.0177
Val: loss=5.1718, acc=0.0204
Best model saved

Epoch 2/40
[Epoch 2] Batch 0/105 loss=4.9603
[Epoch 2] Batch 20/105 loss=4.9334
[Epoch 2] Batch 40/105 loss=5.0072
[Epoch 2] Batch 60/105 loss=4.7196
[Epoch 2] Batch 80/105 loss=4.8391
[Epoch 2] Batch 100/105 loss=4.9605
Train: loss=4.9373, acc=0.0372
Val: loss=5.0442, acc=0.0153

Epoch 3/40
[Epoch 3] Batch 0/105 loss=4.9634
[Epoch 3] Batch 20/105 loss=4.7199
[Epoch 3] Batch 40/105 loss=4.8598
[Epoch 3] Batch 60/105 loss=4.6079
[Epoch 3] Batch 80/105 loss=4.8840
[Epoch 3] Batch 100/105 loss=4.7018
Train: loss=4.7714, acc=0.0509
Val: loss=4.9356, acc=0.0357
Best model saved

Epoch 4/40
[Epoch 4] Batch 0/105 loss=4.5196
[Epoch 4] Batch 20/105 loss=4.6518
[Epoch 4] Batch 40/105 loss=4.8059
[Epo

In [11]:
# TODO: test data (fix hyperparameters first)