## imports


In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import transforms, models
import os
import pandas as pd
from PIL import Image
import onnx

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cpu


## dataset class

In [3]:
class PokemonCardDataset(torch.utils.data.Dataset):
    def __init__(self, image_dir, csv_path, transform=None):
        self.image_dir = image_dir
        self.transform = transform
        df = pd.read_csv(csv_path)
        self.filenames = sorted(os.listdir(image_dir))
        self.labels = df['label'].values.astype(float)
        assert len(self.filenames) == len(self.labels), "Mismatch images and labels"

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        img_name = self.filenames[idx]
        img_path = os.path.join(self.image_dir, img_name)
        img = Image.open(img_path).convert("RGB")
        if self.transform:
            img = self.transform(img)
        label = float(self.labels[idx])
        return img, label

# transformers


In [4]:
train_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

In [5]:
test_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

In [6]:
full_dataset = PokemonCardDataset("../data/train", "../data/train_labels.csv", train_transforms)
val_size = int(0.2 * len(full_dataset))
train_size = len(full_dataset) - val_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

test_dataset = PokemonCardDataset("../data/test", "../data/test_labels.csv", test_transforms)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)  
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# model

In [7]:
model = models.resnet18(pretrained=True)
for param in model.parameters():
    param.requires_grad = False


for param in model.layer4.parameters():
    param.requires_grad = True

model.fc = nn.Sequential(
    nn.Linear(model.fc.in_features, 128),
    nn.ReLU(),
    nn.Dropout(0.5), 
    nn.Linear(128, 1)
)

for name, param in model.named_parameters():
    if "layer4" not in name and "fc" not in name:
        param.requires_grad = False

model = model.to(device)



In [8]:
pos_weight = torch.tensor([train_dataset.dataset.labels.sum() / (len(train_dataset) - train_dataset.dataset.labels.sum())]).to(device)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)



## training time

In [9]:
EPOCHS = 200
best_val_loss = float('inf')
patience_counter = 0
EARLY_STOPPING_PATIENCE = 3

for epoch in range(EPOCHS):
    model.train()
    total_loss = 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.float().to(device)
        optimizer.zero_grad()
        logits = model(images).view(-1)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    avg_train_loss = total_loss / len(train_loader)

    # Validation
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.float().to(device)
            logits = model(images).view(-1)
            loss = criterion(logits, labels)
            val_loss += loss.item()
    avg_val_loss = val_loss / len(val_loader)
    scheduler.step(avg_val_loss)

    print(f"Epoch {epoch+1}/{EPOCHS}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, patience remaining: {3 - patience_counter}")


    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(model, "../models/best_model.pth")
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= EARLY_STOPPING_PATIENCE:
            print("Early stopping triggered.")
            break

print("Training complete. Best model saved to ../models/best_model.pth")

Epoch 1/200, Train Loss: 1.8183, Val Loss: 1.1987, patience remaining: 3
Epoch 2/200, Train Loss: 0.8998, Val Loss: 1.0872, patience remaining: 3
Epoch 3/200, Train Loss: 0.6797, Val Loss: 1.0042, patience remaining: 3
Epoch 4/200, Train Loss: 0.5208, Val Loss: 0.9203, patience remaining: 3
Epoch 5/200, Train Loss: 0.4547, Val Loss: 0.8966, patience remaining: 3
Epoch 6/200, Train Loss: 0.3418, Val Loss: 1.1617, patience remaining: 3
Epoch 7/200, Train Loss: 0.2574, Val Loss: 0.9876, patience remaining: 2
Epoch 8/200, Train Loss: 0.1903, Val Loss: 1.1586, patience remaining: 1
Early stopping triggered.
Training complete. Best model saved to ../models/best_model.pth
