In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
import pandas as pd
import os
from PIL import Image
from tqdm import tqdm

In [2]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 16
LEARNING_RATE = 0.001
EPOCHS = 10
IMG_DIR = 'data/processed_pngs'
TEMP = 4.0
ALPHA = 0.3

In [3]:
class RSNADataset(Dataset):
    def __init__(self, csv_file, img_dir, transform=None):
        self.data = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_name = self.data.iloc[idx]['filename']
        img_path = os.path.join(self.img_dir, img_name)
        label = int(self.data.iloc[idx]['Target'])
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label

In [4]:
train_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

train_dataset = RSNADataset('data/train_split.csv', IMG_DIR, transform=train_transforms)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [5]:
teacher = models.resnet50(weights=None)
num_ftrs = teacher.fc.in_features
teacher.fc = nn.Linear(num_ftrs, 2)
teacher.load_state_dict(torch.load("teacher_resnet50.pth"))
teacher = teacher.to(DEVICE)
teacher.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [6]:
for param in teacher.parameters():
    param.requires_grad = False

In [7]:
student = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.IMAGENET1K_V1)
num_ftrs = student.classifier[3].in_features
student.classifier[3] = nn.Linear(num_ftrs, 2)
student = student.to(DEVICE)

In [8]:
optimizer = optim.Adam(student.parameters(), lr=LEARNING_RATE)

for epoch in range(EPOCHS):
    student.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    for inputs, labels in tqdm(train_loader):
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        
        student_outputs = student(inputs)
        
        with torch.no_grad():
            teacher_outputs = teacher(inputs)
            
        soft_targets = F.softmax(teacher_outputs / TEMP, dim=1)
        soft_prob = F.log_softmax(student_outputs / TEMP, dim=1)
        
        loss_soft = F.kl_div(soft_prob, soft_targets, reduction='batchmean') * (TEMP ** 2)
        loss_hard = F.cross_entropy(student_outputs, labels)
        
        loss = (ALPHA * loss_hard) + ((1 - ALPHA) * loss_soft)
        
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        _, predicted = torch.max(student_outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        
    print(f"Epoch {epoch+1} - Distillation Loss: {running_loss/len(train_loader):.4f} - Acc: {100*correct/total:.2f}%")

100%|██████████████████████████████████████████████████████████████████████████████| 1335/1335 [03:52<00:00,  5.74it/s]


Epoch 1 - Distillation Loss: 0.1682 - Acc: 82.91%


100%|██████████████████████████████████████████████████████████████████████████████| 1335/1335 [03:54<00:00,  5.70it/s]


Epoch 2 - Distillation Loss: 0.1454 - Acc: 83.87%


100%|██████████████████████████████████████████████████████████████████████████████| 1335/1335 [03:49<00:00,  5.82it/s]


Epoch 3 - Distillation Loss: 0.1387 - Acc: 83.96%


100%|██████████████████████████████████████████████████████████████████████████████| 1335/1335 [03:53<00:00,  5.72it/s]


Epoch 4 - Distillation Loss: 0.1348 - Acc: 84.04%


100%|██████████████████████████████████████████████████████████████████████████████| 1335/1335 [03:50<00:00,  5.79it/s]


Epoch 5 - Distillation Loss: 0.1322 - Acc: 84.26%


100%|██████████████████████████████████████████████████████████████████████████████| 1335/1335 [04:02<00:00,  5.51it/s]


Epoch 6 - Distillation Loss: 0.1302 - Acc: 84.50%


100%|██████████████████████████████████████████████████████████████████████████████| 1335/1335 [04:08<00:00,  5.37it/s]


Epoch 7 - Distillation Loss: 0.1282 - Acc: 84.65%


100%|██████████████████████████████████████████████████████████████████████████████| 1335/1335 [04:07<00:00,  5.40it/s]


Epoch 8 - Distillation Loss: 0.1263 - Acc: 84.78%


100%|██████████████████████████████████████████████████████████████████████████████| 1335/1335 [04:00<00:00,  5.54it/s]


Epoch 9 - Distillation Loss: 0.1242 - Acc: 85.02%


100%|██████████████████████████████████████████████████████████████████████████████| 1335/1335 [03:52<00:00,  5.75it/s]

Epoch 10 - Distillation Loss: 0.1217 - Acc: 85.30%





In [9]:
torch.save(student.state_dict(), "nanoray_student.pth")