# In [None]:
import torch
import torch.nn as nn
from tqdm import tqdm
from PIL import ImageFile

# Let PIL decode truncated/partially-written image files instead of raising,
# so a single damaged file does not abort an epoch (the missing part of such
# an image is not recovered).
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Hyperparameters
# Fall back to CPU when no CUDA device is present instead of failing at the
# first `.to(device)` call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 32  # not used in this cell; presumably consumed where train_loader/val_loader are built
num_epochs = 10
learning_rate = 1e-4
num_classes = 18  # number of landmark classes
global_feat_dim = 2048  # width of DELG's global descriptor; must match the model (TODO confirm)

# Model + Classifier
# Only the global descriptor feeds the classification head, so the local branch is disabled.
model = DELG(pretrained=True, use_global=True, use_local=False).to(device)
classifier = nn.Linear(global_feat_dim, num_classes).to(device)  # classification head

criterion = nn.CrossEntropyLoss()
# Backbone and head are optimized jointly: the backbone is fine-tuned, not frozen.
optimizer = torch.optim.Adam(
    list(model.parameters()) + list(classifier.parameters()), lr=learning_rate
)


for epoch in range(num_epochs):
    # ---- Training ----
    model.train()
    classifier.train()
    running_loss = 0.0
    correct = 0
    total = 0

    loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=True)

    for imgs, labels in loop:
        imgs, labels = imgs.to(device), labels.to(device)

        optimizer.zero_grad()
        feats = model(imgs)['global']           # presumably [B, global_feat_dim]
        outputs = classifier(feats)             # [B, num_classes] logits
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # loss.item() is the batch mean; weight it by the batch size so the
        # running average stays per-sample even when the last batch is smaller.
        running_loss += loss.item() * imgs.size(0)
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        loop.set_postfix(loss=running_loss/total, acc=100.*correct/total)

    # ---- Validation ----
    model.eval()
    classifier.eval()
    val_correct = 0
    val_total = 0
    val_loss = 0.0
    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            feats = model(imgs)['global']
            outputs = classifier(feats)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * imgs.size(0)
            predicted = outputs.argmax(dim=1)
            val_total += labels.size(0)
            val_correct += predicted.eq(labels).sum().item()

    if val_total == 0:
        # An empty val_loader would otherwise raise ZeroDivisionError here.
        val_acc = float("nan")
        val_loss_avg = float("nan")
    else:
        val_acc = 100.*val_correct/val_total
        val_loss_avg = val_loss / val_total
    print(f"Epoch [{epoch+1}/{num_epochs}] | Validation Loss: {val_loss_avg:.4f} | Validation Accuracy: {val_acc:.2f}%\n")
