In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
import pandas as pd
import os
from PIL import Image
from tqdm import tqdm

In [3]:
# --- Run configuration -------------------------------------------------------
# Use the GPU when CUDA is usable in this environment, otherwise fall back to CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Optimisation hyper-parameters used by the training cells below.
BATCH_SIZE = 16        # samples per mini-batch handed to the DataLoader
LEARNING_RATE = 1e-3   # step size passed to the optimizer
EPOCHS = 10            # number of full passes over the training loader

# Directory holding the image files referenced by the split CSV.
IMG_DIR = 'data/processed_pngs'

In [4]:
class RSNADataset(Dataset):
    """Image-classification dataset backed by a CSV index and an image directory.

    The CSV must provide a ``filename`` column (image file name, joined onto
    ``img_dir``) and a ``Target`` column (a value convertible to ``int``).
    Every item is an ``(image, label)`` pair.

    Args:
        csv_file: Path (or buffer) of the CSV index; read once with
            ``pandas.read_csv`` at construction time.
        img_dir: Directory that contains the image files.
        transform: Optional callable applied to the RGB PIL image before it
            is returned. When ``None`` the PIL image itself is returned.
    """

    def __init__(self, csv_file, img_dir, transform=None):
        self.data = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the CSV index."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load one sample.

        Args:
            idx: Positional row index. A 0-d tensor is accepted and converted
                to a plain Python number first.

        Returns:
            tuple: ``(image, label)`` where ``image`` is the (optionally
            transformed) RGB image and ``label`` is a Python ``int``.
        """
        # Normalise tensor indices so pandas receives a plain Python scalar.
        if torch.is_tensor(idx):
            idx = idx.item()

        # Materialise the row once instead of once per column.
        row = self.data.iloc[idx]
        img_path = os.path.join(self.img_dir, row['filename'])
        label = int(row['Target'])

        # The context manager closes the file handle deterministically;
        # convert() yields an independent image that stays valid afterwards.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)
        return image, label

In [5]:
# Training-time preprocessing, applied in order to every PIL image:
#   1. resize to 224x224,
#   2. random horizontal flip (p=0.5, the torchvision default),
#   3. convert to a tensor,
#   4. normalise each channel with the given mean / std.
train_transforms = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

In [6]:
# Data: training split, shuffled on every epoch.
train_dataset = RSNADataset(
    csv_file='data/train_split.csv',
    img_dir=IMG_DIR,
    transform=train_transforms,
)
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Model: MobileNetV3-Small initialised from the IMAGENET1K_V1 weights, with the
# final classifier layer (index 3) replaced by a fresh 2-output linear layer.
baseline_student = models.mobilenet_v3_small(
    weights=models.MobileNet_V3_Small_Weights.IMAGENET1K_V1
)
num_ftrs = baseline_student.classifier[3].in_features
baseline_student.classifier[3] = nn.Linear(in_features=num_ftrs, out_features=2)
# nn.Module.to() moves the parameters in place and returns the same module.
baseline_student.to(DEVICE)

# Objective and optimizer; the optimizer is built after the head swap so it
# tracks the parameters of the new layer.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=baseline_student.parameters(), lr=LEARNING_RATE)

In [7]:
# Train for EPOCHS passes over train_loader and report, per epoch, the average
# training loss and training accuracy. Both metrics are accumulated from the
# same forward passes used for the weight updates (model in train() mode), so
# they describe the training set only.
for epoch in range(EPOCHS):
    baseline_student.train()
    running_loss = 0.0  # sum of per-sample losses seen this epoch
    correct = 0         # correctly classified samples this epoch
    total = 0           # samples seen this epoch

    for inputs, labels in tqdm(train_loader):
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        outputs = baseline_student(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        n_samples = labels.size(0)
        # criterion returns the batch mean; weight it by the batch size so a
        # shorter final batch does not count as much as a full one.
        running_loss += loss.item() * n_samples
        # detach() instead of the discouraged `.data`; argmax over the class axis.
        predicted = outputs.detach().argmax(dim=1)
        total += n_samples
        correct += (predicted == labels).sum().item()

    # Guard against an empty loader so the report does not divide by zero.
    seen = max(total, 1)
    print(f"Epoch {epoch+1} - Loss: {running_loss/seen:.4f} - Acc: {100*correct/seen:.2f}%")

100%|██████████████████████████████████████████████████████████████████████████████| 1335/1335 [02:16<00:00,  9.80it/s]


Epoch 1 - Loss: 0.4035 - Acc: 81.59%


100%|██████████████████████████████████████████████████████████████████████████████| 1335/1335 [02:12<00:00, 10.08it/s]


Epoch 2 - Loss: 0.3771 - Acc: 83.10%


100%|██████████████████████████████████████████████████████████████████████████████| 1335/1335 [02:14<00:00,  9.96it/s]


Epoch 3 - Loss: 0.3701 - Acc: 83.44%


100%|██████████████████████████████████████████████████████████████████████████████| 1335/1335 [02:15<00:00,  9.87it/s]


Epoch 4 - Loss: 0.3597 - Acc: 83.88%


100%|██████████████████████████████████████████████████████████████████████████████| 1335/1335 [02:11<00:00, 10.17it/s]


Epoch 5 - Loss: 0.3544 - Acc: 83.89%


100%|██████████████████████████████████████████████████████████████████████████████| 1335/1335 [02:12<00:00, 10.08it/s]


Epoch 6 - Loss: 0.3434 - Acc: 84.71%


100%|██████████████████████████████████████████████████████████████████████████████| 1335/1335 [02:12<00:00, 10.06it/s]


Epoch 7 - Loss: 0.3327 - Acc: 85.05%


100%|██████████████████████████████████████████████████████████████████████████████| 1335/1335 [02:19<00:00,  9.54it/s]


Epoch 8 - Loss: 0.3256 - Acc: 85.35%


100%|██████████████████████████████████████████████████████████████████████████████| 1335/1335 [02:16<00:00,  9.81it/s]


Epoch 9 - Loss: 0.3155 - Acc: 86.17%


100%|██████████████████████████████████████████████████████████████████████████████| 1335/1335 [02:19<00:00,  9.55it/s]

Epoch 10 - Loss: 0.3020 - Acc: 86.65%





In [8]:
# Persist only the learned parameters (state dict), not the full module object;
# reloading requires rebuilding the same architecture first.
torch.save(obj=baseline_student.state_dict(), f="baseline_mobilenet.pth")