In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
import pandas as pd
import os
from PIL import Image
from tqdm import tqdm

In [2]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 16
LEARNING_RATE = 0.001
EPOCHS = 10
IMG_DIR = 'data/processed_pngs'

In [3]:
class RSNADataset(Dataset):
    def __init__(self, csv_file, img_dir, transform=None):
        self.data = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_name = self.data.iloc[idx]['filename']
        img_path = os.path.join(self.img_dir, img_name)
        label = int(self.data.iloc[idx]['Target'])
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label

In [4]:
train_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

val_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

In [5]:
train_dataset = RSNADataset('data/train_split.csv', IMG_DIR, transform=train_transforms)
val_dataset = RSNADataset('data/val_split.csv', IMG_DIR, transform=val_transforms)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [6]:
teacher_model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
num_ftrs = teacher_model.fc.in_features
teacher_model.fc = nn.Linear(num_ftrs, 2)
teacher_model = teacher_model.to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(teacher_model.parameters(), lr=LEARNING_RATE)

In [7]:
for epoch in range(EPOCHS):
    teacher_model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    for inputs, labels in tqdm(train_loader):
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        outputs = teacher_model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        
    print(f"Epoch {epoch+1} - Loss: {running_loss/len(train_loader):.4f} - Acc: {100*correct/total:.2f}%")

100%|██████████████████████████████████████████████████████████████████████████████| 1335/1335 [09:39<00:00,  2.31it/s]


Epoch 1 - Loss: 0.4378 - Acc: 79.94%


100%|██████████████████████████████████████████████████████████████████████████████| 1335/1335 [07:08<00:00,  3.12it/s]


Epoch 2 - Loss: 0.4077 - Acc: 81.65%


100%|██████████████████████████████████████████████████████████████████████████████| 1335/1335 [07:26<00:00,  2.99it/s]


Epoch 3 - Loss: 0.3950 - Acc: 82.24%


100%|██████████████████████████████████████████████████████████████████████████████| 1335/1335 [07:18<00:00,  3.05it/s]


Epoch 4 - Loss: 0.3923 - Acc: 82.44%


100%|██████████████████████████████████████████████████████████████████████████████| 1335/1335 [07:17<00:00,  3.05it/s]


Epoch 5 - Loss: 0.3895 - Acc: 82.56%


100%|██████████████████████████████████████████████████████████████████████████████| 1335/1335 [07:16<00:00,  3.06it/s]


Epoch 6 - Loss: 0.3847 - Acc: 82.57%


100%|██████████████████████████████████████████████████████████████████████████████| 1335/1335 [07:19<00:00,  3.04it/s]


Epoch 7 - Loss: 0.3860 - Acc: 82.39%


100%|██████████████████████████████████████████████████████████████████████████████| 1335/1335 [07:17<00:00,  3.05it/s]


Epoch 8 - Loss: 0.3792 - Acc: 82.97%


100%|██████████████████████████████████████████████████████████████████████████████| 1335/1335 [07:20<00:00,  3.03it/s]


Epoch 9 - Loss: 0.3757 - Acc: 82.77%


100%|██████████████████████████████████████████████████████████████████████████████| 1335/1335 [07:21<00:00,  3.02it/s]

Epoch 10 - Loss: 0.3730 - Acc: 83.16%





In [8]:
torch.save(teacher_model.state_dict(), "teacher_resnet50.pth")