In [11]:
import math

import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_sched
import yaml
from torch.amp import autocast, GradScaler
from torch.utils.data import DataLoader, Subset, random_split
from torchvision import datasets, transforms
from torchvision.transforms import functional as F

from vit import ViTBackbone, ClassificationHead

In [12]:
# Experiment configuration: model hyper-parameters are read from the "model"
# section, optimisation settings from the "training" section.
with open("configs.yaml", "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

cfg_model = cfg['model']
cfg_train = cfg['training']

# Seed torch's global RNG so weight initialisation and DataLoader shuffling are
# reproducible under Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# torch.device("cuda") can be constructed without a GPU and only fails later at
# .to(device); fall back to CPU so the notebook still runs on such machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Create Dataset

In [13]:
# Input pipeline settings.
batch_size = cfg_train["batch_size"]   # taken from the "training" config section
num_workers = 16                       # worker processes per DataLoader
data_path = 'data'                     # root directory handed to ImageFolder

In [14]:
# Per-channel normalisation statistics (assumed RGB order — TODO confirm).
# NOTE(review): these stds are much smaller than typical per-pixel stds of
# [0, 1] images; confirm they were computed per pixel, not over per-image means.
channel_mean = [0.5524, 0.5288, 0.5107]
channel_std  = [0.0956, 0.0773, 0.0465]


def _to_normalized_tensor():
    """Return a fresh ToTensor -> Normalize pipeline using the channel stats above."""
    steps = [
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(steps)


# Train and test currently share the same pipeline (no augmentation); building
# two separate instances keeps them independently editable.
train_transform = _to_normalized_tensor()
test_transform = _to_normalized_tensor()


In [15]:
# Fraction of images used for training; the remainder is held out for testing.
# NOTE(review): 2.5% train / 97.5% test is unusual — confirm this is an
# intentional small-data experiment rather than a typo (e.g. for 0.25 or 0.8).
train_fraction = 0.025

# random_split returns Subsets that share their parent dataset, so assigning
# `.transform` on one subset's parent silently changes the other split too.
# Use two ImageFolder views over the same directory instead: one carrying the
# train transform and one carrying the test transform.
full_dataset = datasets.ImageFolder(
    root=data_path, transform=train_transform
)
full_test_view = datasets.ImageFolder(
    root=data_path, transform=test_transform
)
# Both views must enumerate identical (path, label) pairs so split indices line up.
assert full_dataset.samples == full_test_view.samples, "ImageFolder views disagree"

total_size = len(full_dataset)
train_size = int(train_fraction * total_size)
test_size  = total_size - train_size

train_dataset, test_split = random_split(
    full_dataset, [train_size, test_size],
    generator=torch.Generator().manual_seed(42)
)
# Same held-out indices as the split above, served through the test-transform view.
test_dataset = Subset(full_test_view, test_split.indices)

In [16]:
# Settings shared by both loaders; only shuffling differs between the splits.
loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers)

train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

### Model

In [17]:
# ViT encoder followed by a classification head; nn.Sequential chains them so
# model(images) feeds the backbone's output straight into the classifier.
vit_cfg = cfg_model["vit"]

vit_backbone = ViTBackbone(
    in_channels=vit_cfg["in_channels"],
    embedding_dim=vit_cfg["embedding_dim"],
    patch_size=vit_cfg["patch_size"],
    max_patch_num=vit_cfg["max_patch_num"],
    L=vit_cfg["depth"],          # number of encoder blocks, "depth" in the config
    n_heads=vit_cfg["n_heads"],
    mlp_size=vit_cfg["mlp_size"],
)
classifier = ClassificationHead(
    embedding_dim=vit_cfg["embedding_dim"],
    n_classes=cfg_model["cls"]["n_classes"],
)

# Module.to() moves parameters in place and returns the same module.
model = nn.Sequential(vit_backbone, classifier).to(device)

### Train

In [18]:
def get_lr_lambda(num_epochs, warmup_epochs):
    """Build a linear-warmup + cosine-decay multiplier for ``LambdaLR``.

    The returned callable maps an epoch index to a factor applied to the
    optimizer's base learning rate:

    * epochs ``0 .. warmup_epochs-1``: linear ramp ``(epoch + 1) / warmup_epochs``
      (the first epoch already trains at ``1 / warmup_epochs`` of the base LR);
    * afterwards: half-cosine from 1.0 down to 0.0, reached at ``num_epochs``
      and held there for any later epoch.

    Args:
        num_epochs: total number of epochs the schedule spans.
        warmup_epochs: number of linear warm-up epochs (0 disables warm-up).

    Returns:
        A function ``lr_lambda(epoch) -> float`` suitable for
        ``torch.optim.lr_scheduler.LambdaLR``.
    """
    # Guard the denominator: when num_epochs == warmup_epochs the decay phase is
    # empty, yet LambdaLR's last step() still evaluates lr_lambda(num_epochs),
    # which previously raised ZeroDivisionError.
    decay_epochs = max(1, num_epochs - warmup_epochs)

    def lr_lambda(epoch):
        if epoch < warmup_epochs:
            return (epoch + 1) / warmup_epochs
        # Clamp so stepping past num_epochs keeps the LR at its floor instead of
        # letting the cosine climb back up.
        progress = min(1.0, (epoch - warmup_epochs) / decay_epochs)
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    return lr_lambda

In [19]:
# Adam over all model parameters. The float() cast is presumably there because
# PyYAML reads exponent literals without a decimal point (e.g. 1e-4) as strings.
base_lr = float(cfg_train["lr"])
optimizer = optim.Adam(model.parameters(), lr=base_lr)

# Per-epoch LR multiplier: linear warm-up followed by cosine decay.
warmup_cosine = get_lr_lambda(
    num_epochs=cfg_train["num_epochs"],
    warmup_epochs=cfg_train["warmup_epochs"],
)
scheduler = lr_sched.LambdaLR(optimizer, lr_lambda=warmup_cosine)

criterion = nn.CrossEntropyLoss()

In [20]:
def _train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader``; return the mean per-batch loss.

    Prints an in-place ``Processed i/N batches`` counter. Returns NaN for an
    empty loader instead of raising ZeroDivisionError.
    """
    model.train()
    num_batches = len(loader)
    running_loss = 0.0
    for batch_idx, (images, labels) in enumerate(loader, 1):
        print(f"\rProcessed {batch_idx}/{num_batches} batches", end='')

        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    return running_loss / num_batches if num_batches else float("nan")


def _evaluate(model, loader, device):
    """Return top-1 accuracy of ``model`` on ``loader`` in percent (NaN if empty)."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)

            # Index of the highest-scoring class for each sample.
            _, predicted = torch.max(outputs, dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return 100.0 * correct / total if total else float("nan")


num_epochs = cfg_train["num_epochs"]
for epoch in range(num_epochs):
    avg_loss = _train_one_epoch(model, train_loader, optimizer, criterion, device)
    # The LR schedule is defined per epoch, so step once after each full pass.
    scheduler.step()

    accuracy = _evaluate(model, test_loader, device)
    # Leading \r overwrites the batch-progress line printed during training.
    print(
        f"\rEpoch [{epoch+1}/{num_epochs}], "
        f"Loss: {avg_loss:.4f}, Test Accuracy: {accuracy:.2f}%"
    )

print("Training finished.")

Epoch [1/30], Loss: 2.1423, Test Accuracy: 27.25%
Epoch [2/30], Loss: 1.8544, Test Accuracy: 24.11%
Epoch [3/30], Loss: 1.4771, Test Accuracy: 49.40%
Epoch [4/30], Loss: 1.5881, Test Accuracy: 49.40%
Epoch [5/30], Loss: 1.7310, Test Accuracy: 49.40%
Epoch [6/30], Loss: 1.6052, Test Accuracy: 49.40%
Epoch [7/30], Loss: 1.4178, Test Accuracy: 49.37%
Epoch [8/30], Loss: 1.3798, Test Accuracy: 27.47%
Epoch [9/30], Loss: 1.4393, Test Accuracy: 27.47%
Epoch [10/30], Loss: 1.4416, Test Accuracy: 29.43%
Epoch [11/30], Loss: 1.3932, Test Accuracy: 58.11%
Epoch [12/30], Loss: 1.3392, Test Accuracy: 54.47%
Epoch [13/30], Loss: 1.3002, Test Accuracy: 49.82%
Epoch [14/30], Loss: 1.2788, Test Accuracy: 49.43%
Epoch [15/30], Loss: 1.2668, Test Accuracy: 49.40%
Epoch [16/30], Loss: 1.2542, Test Accuracy: 49.40%
Epoch [17/30], Loss: 1.2357, Test Accuracy: 49.45%
Epoch [18/30], Loss: 1.2123, Test Accuracy: 50.13%
Epoch [19/30], Loss: 1.1885, Test Accuracy: 53.32%
Epoch [20/30], Loss: 1.1686, Test Accura