In [None]:
import timm
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Size of the classification head's output (number of target classes).
num_classes = 5

In [None]:
"""
Model set-up
"""

# Pretrained DeiT-Small backbone with a fresh classification head of `num_classes` outputs.
model = timm.create_model("deit_small_patch16_224", pretrained=True, num_classes=num_classes) # 22M params

num_blocks_unfreeze = 2 # for partial ft

# Fine-tuning strategy: exactly one of "head_only", "partial", "full".
strategy = "head_only"
# strategy = "partial"
# strategy = "full"

if strategy == "full":
    # Every parameter is trainable.
    for p in model.parameters():
        p.requires_grad = True

elif strategy == "head_only":
    # Only parameters whose name contains "head" stay trainable; the rest are frozen.
    for name, p in model.named_parameters():
        p.requires_grad = "head" in name

elif strategy == "partial":
    # Freeze everything first: parameters default to requires_grad=True, so
    # without this step "partial" would silently behave like "full".
    for p in model.parameters():
        p.requires_grad = False

    # Unfreeze the last `num_blocks_unfreeze` transformer blocks. An explicit
    # start index is used because `blocks[-0:]` would select *all* blocks.
    first_trainable_block = max(0, len(model.blocks) - num_blocks_unfreeze)
    for block in model.blocks[first_trainable_block:]:
        for p in block.parameters():
            p.requires_grad = True

    # The classifier of this model is exposed as `head` (the same module the
    # "head_only" branch selects by name).
    for p in model.head.parameters():
        p.requires_grad = True

else:
    # Fail loudly on a typo instead of training with an unintended setup.
    raise ValueError(f"Unknown fine-tuning strategy: {strategy!r}")

model = model.to(device)

In [None]:
"""
Criterion, optimizer, epochs, scheduler
"""

# Length of the training run; also the period handed to the cosine schedule.
num_epochs = 30

# Classification loss applied to the model's logits.
criterion = nn.CrossEntropyLoss()

# Optimise only the parameters that are currently marked trainable.
optimizer = AdamW(
    params=[param for param in model.parameters() if param.requires_grad],
    lr=1e-4,
    weight_decay=1e-4,
)

# Cosine learning-rate annealing over `num_epochs` scheduler steps.
scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=num_epochs)

In [None]:
"""
Train function
"""

def train_epoch(model, loader, optimizer, criterion, device):
    """Run one training pass over ``loader``.

    Args:
        model: Module called as ``model(images)`` to produce logits.
        loader: Iterable yielding ``(images, labels)`` batches.
        optimizer: Optimizer whose ``zero_grad()`` / ``step()`` are called once per batch.
        criterion: Loss called as ``criterion(logits, labels)``.
        device: Device the batch tensors are moved to before the forward pass.

    Returns:
        Tuple ``(mean_loss, accuracy)``: the per-sample average loss and the
        fraction of samples whose arg-max prediction equals the label.

    Raises:
        ZeroDivisionError: if ``loader`` yields no samples.
    """
    model.train()
    total_loss, correct, total = 0.0, 0, 0
    for images, labels in loader:

        # Prep images and labels
        # Batches with a single channel (dim 1) are tiled to 3 channels.
        # TODO: resizing / normalisation are not done here.
        if images.size(1) == 1:
            images = images.repeat(1, 3, 1, 1)
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, labels)

        # Backward pass
        loss.backward()
        optimizer.step()

        # Update metrics
        batch_size = labels.size(0)
        # Weight the batch loss by its size so the final mean is per-sample.
        total_loss += loss.item() * batch_size
        correct += (logits.argmax(1) == labels).sum().item()
        total += batch_size

    return (total_loss/total), (correct/total)
    

In [None]:
"""
Eval function
"""

def eval_epoch(model, loader, criterion, device):
    """Evaluate ``model`` on every batch of ``loader`` without tracking gradients.

    Args:
        model: Module called as ``model(images)`` to produce logits.
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss called as ``criterion(logits, labels)``.
        device: Device the batch tensors are moved to before the forward pass.

    Returns:
        Tuple ``(mean_loss, accuracy)``: the per-sample average loss and the
        fraction of samples whose arg-max prediction equals the label.

    Raises:
        ZeroDivisionError: if ``loader`` yields no samples.
    """
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in loader:

            # Prep images and labels
            images = images.to(device)
            # Batches with a single channel (dim 1) are tiled to 3 channels,
            # mirroring the grayscale handling in train_epoch.
            # TODO: resizing / normalisation are not done here.
            if images.size(1) == 1:
                images = images.repeat(1, 3, 1, 1)
            labels = labels.to(device)

            # Forward pass
            logits = model(images)
            loss = criterion(logits, labels)

            # Update metrics
            batch_size = labels.size(0)
            # Weight the batch loss by its size so the final mean is per-sample.
            total_loss += loss.item() * batch_size
            correct += (logits.argmax(1) == labels).sum().item()
            total += batch_size

    return (total_loss/total), (correct/total)

In [None]:
# Highest validation accuracy observed so far; drives checkpointing below.
best_val_acc = 0.0

for epoch in range(num_epochs):
    # One optimisation pass over the training data, then a validation pass.
    train_loss, train_acc = train_epoch(
        model=model,
        loader=train_loader,
        optimizer=optimizer,
        criterion=criterion,
        device=device,
    )
    val_loss, val_acc = eval_epoch(
        model=model,
        loader=val_loader,
        criterion=criterion,
        device=device,
    )

    # Advance the learning-rate schedule once per epoch.
    scheduler.step()

    # Checkpoint the weights only on a strict improvement in validation accuracy.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), "best_deit.pt")

    print(f"Epoch {epoch+1}/{num_epochs}  train_loss={train_loss:.4f}  train_acc={train_acc:.4f}  val_loss={val_loss:.4f}  val_acc={val_acc:.4f}")

print(f"Best val accuracy: {best_val_acc:.4f}")

NameError: name 'num_epochs' is not defined