In [None]:
import os
import random

# When launched from the notebooks/ directory, step up to the project root —
# presumably so the project-level `src` and `config` imports below resolve.
if os.getcwd().endswith('notebooks'): os.chdir('..')

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

from src.model import GeometricLoss, get_model
from src.data import get_data

from config import MODELS_DIR, BATCH_SIZE

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EPOCHS = 100

# Seed the Python and torch RNGs before the model and data loaders are built, so
# weight init (and, assuming get_data relies on torch's default generator,
# shuffling/augmentation) repeats across Restart & Run All.
# NOTE: bitwise-identical CUDA runs additionally need deterministic cuDNN settings,
# which are not enforced here.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Optimizer hyperparameters (SGD with momentum and L2 weight decay).
LR = 0.01
MOMENTUM = 0.9
WEIGHT_DECAY = 5e-4

train_loader, test_loader = get_data(BATCH_SIZE)

model = get_model().to(device)
# NOTE(review): `criterion` is not used by the training cell, which builds its own
# `criterion_ce`; kept only so any existing reference to it keeps working.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
# Cosine-anneal the learning rate from LR down toward eta_min (default 0) over
# T_max=EPOCHS scheduler steps, i.e. over the whole run when stepped once per epoch.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)


In [None]:
criterion_ce = nn.CrossEntropyLoss()
# NOTE(review): the name suggests a supervised-contrastive loss on the feature
# embeddings; its exact definition lives in src.model.GeometricLoss — confirm.
criterion_supcon = GeometricLoss(temperature=0.1)
alpha = 0.5  # contrastive loss weight

# Per-epoch history:
#   train_loss    - mean combined objective (CE + alpha * contrastive) over train batches
#   train_ce_loss - mean CE-only term over train batches, comparable with val_loss
#   val_loss      - mean CE over test batches (the contrastive term is not evaluated there)
#   *_acc         - top-1 accuracy in percent
history = {'train_loss': [], 'train_ce_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

# NOTE: re-running this cell keeps training the same `model`/`optimizer`/`scheduler`
# objects from the setup cell; re-run that cell first for a fresh run.
for epoch in range(EPOCHS):
    # ---- train ----
    model.train()
    train_loss, train_ce_loss, train_correct, train_total = 0.0, 0.0, 0, 0

    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()

        # The model returns (features, logits): features feed the contrastive term,
        # logits feed cross-entropy and the accuracy count.
        features, logits = model(inputs)

        loss_ce = criterion_ce(logits, labels)
        loss_sc = criterion_supcon(features, labels)
        loss = loss_ce + (alpha * loss_sc)

        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        train_ce_loss += loss_ce.item()
        _, predicted = logits.max(1)
        train_total += labels.size(0)
        train_correct += predicted.eq(labels).sum().item()

    # ---- validate ----
    model.eval()
    val_loss, val_correct, val_total = 0.0, 0, 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            _, logits = model(inputs)

            loss = criterion_ce(logits, labels)  # CE only; no contrastive term at eval

            val_loss += loss.item()
            _, predicted = logits.max(1)
            val_total += labels.size(0)
            val_correct += predicted.eq(labels).sum().item()

    epoch_train_acc = 100. * train_correct / train_total
    epoch_val_acc = 100. * val_correct / val_total

    history['train_loss'].append(train_loss / len(train_loader))
    history['train_ce_loss'].append(train_ce_loss / len(train_loader))
    history['train_acc'].append(epoch_train_acc)
    history['val_loss'].append(val_loss / len(test_loader))
    history['val_acc'].append(epoch_val_acc)

    print(f"Epoch {epoch+1}/{EPOCHS} | Train Acc: {epoch_train_acc:.2f}% | Val Acc: {epoch_val_acc:.2f}%")

    scheduler.step()

# Make sure the output directory exists so a long run is not lost at the final save.
os.makedirs(MODELS_DIR, exist_ok=True)
save_path = os.path.join(MODELS_DIR, "resnet18_cifar100.pth")
torch.save({
    'epoch': EPOCHS,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    # Saved so a resumed run continues the LR schedule instead of restarting it.
    'scheduler_state_dict': scheduler.state_dict(),
    'history': history,
}, save_path)

In [None]:
# Learning curves. The two loss curves are different quantities: 'train_loss' is the
# combined training objective (CE + alpha * contrastive), while 'val_loss' is CE only,
# so the gap between them is not a pure generalization gap.
fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 4))

ax_loss.plot(history['train_loss'], label='Train Loss (CE + alpha*contrastive)')
ax_loss.plot(history['val_loss'], label='Val Loss (CE)')
ax_loss.set(title='Loss', xlabel='Epoch', ylabel='Loss')
ax_loss.legend()

ax_acc.plot(history['train_acc'], label='Train Acc')
ax_acc.plot(history['val_acc'], label='Val Acc')
ax_acc.set(title='Accuracy', xlabel='Epoch', ylabel='Accuracy (%)')
ax_acc.legend()

fig.tight_layout()
plt.show()