In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.optim.lr_scheduler import CosineAnnealingLR

import numpy as np
from tqdm import tqdm

from datasets import load_from_disk

In [2]:
# Both splits live in one on-disk dataset; unpack them side by side.
train_val_data = load_from_disk("processed_bird_data")
train_data, val_data = train_val_data["train"], train_val_data["validation"]

# Quick sanity report of split sizes.
print("Train samples:", len(train_data))
print("Validation samples:", len(val_data))

Train samples: 3337
Validation samples: 589


In [3]:
# Standard ImageNet per-channel mean/std, used to standardise inputs for the ResNet-style model.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Training pipeline: random geometric + mild photometric augmentation, then tensor conversion
# and per-channel normalisation.
train_transform = transforms.Compose(
    [
        transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),  # crop 80-100% of the area, resize to 224x224
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),  # random angle in [-15, 15] degrees
        transforms.ColorJitter(brightness=0.1, contrast=0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
)

# Evaluation pipeline: deterministic resize to exactly 224x224 (aspect ratio is not preserved),
# no augmentation, same normalisation as training.
val_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
)

In [4]:
class BirdTrainDataset(Dataset):
    """Labelled dataset wrapper yielding ``(image, label)`` pairs.

    Images are converted to 3-channel RGB before the optional transform, matching
    the test-time dataset, so grayscale/RGBA samples cannot produce tensors with
    the wrong channel count (which would break 3-channel normalisation and
    batch collation).

    Args:
        ds: indexable dataset whose rows expose ``"image"`` (an object supporting
            ``.convert("RGB")``, e.g. a PIL image) and ``"label"``.
        transform: optional callable applied to the RGB image.
    """

    def __init__(self, ds, transform=None):
        self.ds = ds
        self.transform = transform

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        item = self.ds[idx]

        # Normalise the colour mode first so every sample has 3 channels.
        img = item["image"].convert("RGB")
        label = item["label"]

        if self.transform is not None:
            img = self.transform(img)

        return img, label

In [5]:
# Training batches are reshuffled every epoch; validation keeps a fixed order.
train_loader = DataLoader(
    dataset=BirdTrainDataset(train_data, transform=train_transform),
    batch_size=32,
    shuffle=True,
)
val_loader = DataLoader(
    dataset=BirdTrainDataset(val_data, transform=val_transform),
    batch_size=32,
    shuffle=False,
)

In [6]:
class BasicBlock(nn.Module):
    """Two 3x3 conv + BN layers with a residual (skip) connection.

    When the block changes resolution (``stride != 1``) or width
    (``in_ch != out_ch``) the skip path is projected with a 1x1 conv + BN so it
    can be added to the main path; otherwise the input is added unchanged.

    Args:
        in_ch: channels of the input tensor.
        out_ch: channels produced by the block.
        stride: stride of the first 3x3 conv and of the projection shortcut.
    """

    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()

        # Main path; the first conv carries the (optional) spatial downsampling.
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()

        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_ch)

        # Projection shortcut only when the shapes of input and output differ.
        needs_projection = stride != 1 or in_ch != out_ch
        self.downsample = (
            nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_ch),
            )
            if needs_projection
            else None
        )

    def forward(self, x):
        residual = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))
        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu(residual + shortcut)


class ResNetScratch(nn.Module):
    """Small ResNet-style classifier trained from scratch.

    A stem (7x7 stride-2 conv + BN + ReLU + 3x3 stride-2 max-pool) is followed by
    four single-``BasicBlock`` stages, each with stride 2, then global average
    pooling and a linear head.

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes):
        super().__init__()

        stem_width = 32
        self.stem = nn.Sequential(
            nn.Conv2d(3, stem_width, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(stem_width),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

        # Channel widths per stage: 32 -> 64 -> 128 -> 256 -> 256.
        self.layer1 = BasicBlock(stem_width, 64, stride=2)
        self.layer2 = BasicBlock(64, 128, stride=2)
        self.layer3 = BasicBlock(128, 256, stride=2)
        self.layer4 = BasicBlock(256, 256, stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(256, num_classes)

    def forward(self, x):
        """Map a 3-channel image batch to logits of shape (N, num_classes)."""
        features = self.stem(x)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)
        pooled = torch.flatten(self.avgpool(features), 1)
        return self.fc(pooled)

In [7]:
# Training hyperparameters (tune here).
EPOCHS = 40
NUM_CLASSES = 200
SEED = 42
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-4

# Seed torch's global RNG before building the model so that weight initialisation,
# DataLoader shuffling and the torchvision random augmentations (drawn from torch's
# RNG in the main process, since num_workers is left at 0) are repeatable run to run.
# Note: cuDNN kernels may still introduce small nondeterminism on GPU.
torch.manual_seed(SEED)

model = ResNetScratch(NUM_CLASSES)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

criterion = nn.CrossEntropyLoss()

# NOTE: Adam's weight_decay is a coupled L2 penalty; optim.AdamW is the decoupled variant
# and is worth trying while tuning.
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
# Cosine-anneal the learning rate over the whole run (one scheduler step per epoch).
scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS)

In [8]:
def train_one_epoch(epoch, *, net=None, loader=None, loss_fn=None, opt=None, dev=None, log_every=20):
    """Run one optimisation pass over a data loader.

    Every keyword argument defaults to the notebook-level object playing that
    role (``model``, ``train_loader``, ``criterion``, ``optimizer``, ``device``),
    so the original ``train_one_epoch(epoch)`` call behaves as before. Passing
    them explicitly lets the function be reused or tested without global state.

    Args:
        epoch: epoch number, used only in progress messages.
        net: model to train (switched to ``train()`` mode).
        loader: iterable of ``(images, labels)`` batches.
        loss_fn: ``loss_fn(logits, labels) -> scalar tensor``; assumed to use mean
            reduction, since it is re-weighted by batch size below.
        opt: optimizer over ``net``'s parameters.
        dev: device each batch is moved to.
        log_every: print the batch loss every ``log_every`` batches; 0/None disables.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen this epoch.

    Raises:
        ValueError: if the loader yields no samples.
    """
    net = model if net is None else net
    loader = train_loader if loader is None else loader
    loss_fn = criterion if loss_fn is None else loss_fn
    opt = optimizer if opt is None else opt
    dev = device if dev is None else dev

    net.train()
    total_loss = 0.0
    correct = 0
    samples = 0

    for batch_idx, (imgs, labels) in enumerate(loader):
        imgs, labels = imgs.to(dev), labels.to(dev)

        opt.zero_grad()
        logits = net(imgs)
        loss = loss_fn(logits, labels)

        loss.backward()
        opt.step()

        # The loss is a per-batch mean; weight by batch size so the epoch figure is a per-sample mean.
        batch_size = imgs.size(0)
        total_loss += loss.item() * batch_size
        correct += (logits.argmax(1) == labels).sum().item()
        samples += batch_size

        if log_every and batch_idx % log_every == 0:
            print(f"[Epoch {epoch}] Batch {batch_idx}/{len(loader)} loss={loss.item():.4f}")

    if samples == 0:
        raise ValueError("train_one_epoch: loader yielded no samples")

    return total_loss / samples, correct / samples


def evaluate(*, net=None, loader=None, loss_fn=None, dev=None):
    """Compute mean loss and accuracy of a model over a data loader, without gradients.

    Every keyword argument defaults to the notebook-level object playing that
    role (``model``, ``val_loader``, ``criterion``, ``device``), so the original
    ``evaluate()`` call behaves as before.

    Args:
        net: model to evaluate (switched to ``eval()`` mode, and left there).
        loader: iterable of ``(images, labels)`` batches.
        loss_fn: ``loss_fn(logits, labels) -> scalar tensor``; assumed to use mean
            reduction, since it is re-weighted by batch size below.
        dev: device each batch is moved to.

    Returns:
        ``(mean_loss, accuracy)`` over all samples in the loader.

    Raises:
        ValueError: if the loader yields no samples.
    """
    net = model if net is None else net
    loader = val_loader if loader is None else loader
    loss_fn = criterion if loss_fn is None else loss_fn
    dev = device if dev is None else dev

    net.eval()
    total_loss = 0.0
    correct = 0
    samples = 0

    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(dev), labels.to(dev)
            logits = net(imgs)

            # Per-batch mean loss re-weighted by batch size -> per-sample mean overall.
            batch_size = imgs.size(0)
            total_loss += loss_fn(logits, labels).item() * batch_size
            correct += (logits.argmax(1) == labels).sum().item()
            samples += batch_size

    if samples == 0:
        raise ValueError("evaluate: loader yielded no samples")

    return total_loss / samples, correct / samples

In [9]:
# Best validation accuracy so far. Starting at -inf (rather than 0.0 with a strict `>`)
# guarantees a checkpoint is written after the first epoch even if its accuracy is 0.0,
# so resnet_best.pth always exists once training has run.
best_val_acc = float("-inf")

for epoch in range(1, EPOCHS + 1):
    print(f"\nEpoch {epoch}/{EPOCHS}")

    train_loss, train_acc = train_one_epoch(epoch)
    val_loss, val_acc = evaluate()

    # One cosine-annealing step per epoch.
    scheduler.step()

    print(f"Train: loss={train_loss:.4f}, acc={train_acc:.4f}")
    print(f"Val: loss={val_loss:.4f}, acc={val_acc:.4f}")

    # Keep only the weights with the highest validation accuracy seen so far.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), "resnet_best.pth")
        print("Best model saved")


Epoch 1/40
[Epoch 1] Batch 0/105 loss=5.3501
[Epoch 1] Batch 20/105 loss=5.2675
[Epoch 1] Batch 40/105 loss=5.1752
[Epoch 1] Batch 60/105 loss=5.2846
[Epoch 1] Batch 80/105 loss=5.1627
[Epoch 1] Batch 100/105 loss=5.2126
Train: loss=5.2327, acc=0.0195
Val: loss=5.1505, acc=0.0102
Best model saved

Epoch 2/40
[Epoch 2] Batch 0/105 loss=4.8716
[Epoch 2] Batch 20/105 loss=5.1182
[Epoch 2] Batch 40/105 loss=4.9582
[Epoch 2] Batch 60/105 loss=5.0459
[Epoch 2] Batch 80/105 loss=4.7887
[Epoch 2] Batch 100/105 loss=4.9115
Train: loss=4.9770, acc=0.0303
Val: loss=5.0099, acc=0.0238
Best model saved

Epoch 3/40
[Epoch 3] Batch 0/105 loss=4.8307
[Epoch 3] Batch 20/105 loss=4.9549
[Epoch 3] Batch 40/105 loss=4.8280
[Epoch 3] Batch 60/105 loss=4.8361
[Epoch 3] Batch 80/105 loss=4.7757
[Epoch 3] Batch 100/105 loss=4.6994
Train: loss=4.8072, acc=0.0441
Val: loss=4.9521, acc=0.0238

Epoch 4/40
[Epoch 4] Batch 0/105 loss=4.4807
[Epoch 4] Batch 20/105 loss=4.6739
[Epoch 4] Batch 40/105 loss=4.4446
[Epo

In [11]:
# TODO: test data (fix hyperparameters first)

In [10]:
import pandas as pd
import torch
import os
from datasets import load_from_disk
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

TEST_DATA_PATH = "processed_bird_test_data"
WEIGHTS_PATH = "resnet_best.pth"
OUTPUT_FILENAME = "submission_resnet.csv"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Loading test data from {TEST_DATA_PATH}...")
try:
    dataset_raw = load_from_disk(TEST_DATA_PATH)
except Exception as e:
    print(f"Error loading dataset: {e}")
    # Bare `raise` re-raises the original exception with its traceback untouched.
    raise

# A multi-split save (dict-like) is unwrapped to its "test" split; a single split is used as-is.
if isinstance(dataset_raw, dict) and "test" in dataset_raw:
    test_ds = dataset_raw["test"]
else:
    test_ds = dataset_raw

if "id" in test_ds.column_names:
    submission_ids = list(test_ds["id"])
    print(f"Found {len(submission_ids)} IDs in dataset.")
else:
    # Positional IDs only line up with the expected submission if the on-disk order is canonical.
    print("Warning: 'id' column not found. Creating sequential IDs (0..N).")
    submission_ids = list(range(len(test_ds)))

class SimpleTestDataset(Dataset):
    """Unlabelled dataset wrapper that yields only (optionally transformed) images.

    Args:
        hf_dataset: indexable dataset whose rows expose an ``"image"`` entry
            supporting ``.convert("RGB")`` (e.g. a PIL image).
        transform: optional callable applied to the RGB image.
    """

    def __init__(self, hf_dataset, transform=None):
        self.hf_dataset = hf_dataset
        self.transform = transform

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        # Force 3-channel RGB so every sample has the same channel count.
        rgb = self.hf_dataset[idx]["image"].convert("RGB")
        return self.transform(rgb) if self.transform else rgb

test_loader = DataLoader(
    SimpleTestDataset(test_ds, transform=val_transform),
    batch_size=32,
    shuffle=False,  # keep dataset order so predictions line up with submission_ids
    num_workers=0
)

print("Initializing model...")
model = ResNetScratch(NUM_CLASSES)
model.to(DEVICE)

# Fail loudly: continuing with randomly initialised weights would silently
# produce a worthless submission file.
if not os.path.exists(WEIGHTS_PATH):
    raise FileNotFoundError(f"Weights file {WEIGHTS_PATH} not found!")

print(f"Loading weights from {WEIGHTS_PATH}...")
# torch.load unpickles the file: only load checkpoints produced by this notebook.
model.load_state_dict(torch.load(WEIGHTS_PATH, map_location=DEVICE))

model.eval()
all_preds = []

print("Running prediction...")
with torch.no_grad():
    for imgs in tqdm(test_loader, desc="Predicting"):
        imgs = imgs.to(DEVICE)

        logits = model(imgs)
        # Plain Python ints keep the "label" column a simple integer column.
        preds = torch.argmax(logits, dim=1).cpu().tolist()
        all_preds.extend(preds)

# A length mismatch means predictions cannot be paired with IDs; stop Run-All here
# instead of printing and carrying on.
if len(all_preds) != len(submission_ids):
    raise RuntimeError(f"ERROR: Mismatch! IDs: {len(submission_ids)} vs Preds: {len(all_preds)}")

df = pd.DataFrame({
    "id": submission_ids,
    "label": all_preds
})

df.to_csv(OUTPUT_FILENAME, index=False)
print("\n" + "="*40)
print(f"Saved {OUTPUT_FILENAME}")

Loading test data from processed_bird_test_data...
Found 4000 IDs in dataset.
Initializing model...
Loading weights from resnet_best.pth...
Running prediction...


Predicting: 100%|████████████████████████████████████████████████████████████████████| 125/125 [01:12<00:00,  1.72it/s]


Saved submission_resnet.csv



