In [49]:
import yaml
import math

from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
import torch
from torchvision.transforms import functional as F
import torch.nn as nn

from torch.amp import autocast, GradScaler
import torch.optim as optim
import torch.optim.lr_scheduler as lr_sched

from vit import ViTBackbone, ClassificationHead

In [50]:
with open("configs.yaml", "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

cfg_model = cfg['model']
cfg_train = cfg['training']

device = torch.device("cuda")

### Estimate Mean and Std

In [51]:
# def compute_mean_std(dataset):
#     loader = DataLoader(dataset, batch_size=1024, num_workers=16, shuffle=False)
#     mean = 0
#     std = 0
#     nb_samples = 0

#     for batch_idx, (data, _) in enumerate(loader):
#         print(f"\rProcessed {batch_idx+1}/{len(loader)} batches", end='')
#         batch_samples = data.size(0)
#         data = data.view(batch_samples, data.size(1), -1)
#         mean += data.mean(2).sum(0)
#         std += data.std(2).sum(0)
#         nb_samples += batch_samples

#     mean /= nb_samples
#     std  /= nb_samples
#     return mean, std

# root_dir = 'data'

# def crop_to_square(img):
#     w, h = img.size
#     side = min(w, h)
#     return F.center_crop(img, side)

# stats_transform = transforms.Compose([
#     transforms.Grayscale(num_output_channels=1),
#     transforms.Lambda(crop_to_square),
#     transforms.Resize((224, 224)),
#     transforms.ToTensor(),
# ])

# full_dataset_for_stats = datasets.ImageFolder(root=root_dir, transform=stats_transform)
# mean, std = compute_mean_std(full_dataset_for_stats)
# print('Dataset mean:', mean)
# print('Dataset std: ', std)


In [52]:
# Processed 71/71 batchesDataset mean: tensor([0.5442])
# Dataset std:  tensor([0.1946])

### Create Dataset

In [53]:
mean = 0.5443
std = 0.1946

data_path = 'data'
batch_size = cfg_train["batch_size"]
num_workers = 16

In [54]:
train_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)
])

test_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)
])

In [55]:
full_dataset = datasets.ImageFolder(
    root=data_path, transform=train_transform
)

total_size = len(full_dataset)
train_size = int(0.8 * total_size)
test_size  = total_size - train_size

train_dataset, test_dataset = random_split(
    full_dataset, [train_size, test_size],
    generator=torch.Generator().manual_seed(42)
)
test_dataset.dataset.transform = test_transform

In [56]:
train_loader = DataLoader(
    train_dataset, batch_size=batch_size, 
    shuffle=True,  num_workers=num_workers
)
test_loader  = DataLoader(
    test_dataset,  batch_size=batch_size,
    shuffle=False, num_workers=num_workers
)

### Model

In [57]:
vit_backbone = ViTBackbone(
    in_channels=cfg_model["vit"]["in_channels"],
    embedding_dim=cfg_model["vit"]["embedding_dim"],
    patch_size=cfg_model["vit"]["patch_size"],
    max_patch_num=cfg_model["vit"]["max_patch_num"],
    L=cfg_model["vit"]["depth"],
    n_heads=cfg_model["vit"]["n_heads"],
    mlp_size=cfg_model["vit"]["mlp_size"]
)
classifier = ClassificationHead(
    embedding_dim=cfg_model["vit"]["embedding_dim"],
    n_classes=cfg_model["cls"]["n_classes"]
)

model = nn.Sequential(
    vit_backbone,
    classifier
)
model = model.to(device)

### Train

In [58]:
def get_lr_lambda(num_epochs, warmup_epochs):
    def lr_lambda(epoch):
        if epoch < warmup_epochs:
            return (epoch + 1) / warmup_epochs
        else:
            progress = (epoch - warmup_epochs) / (num_epochs - warmup_epochs)
            return 0.5 * (1.0 + math.cos(math.pi * progress))
    
    return lr_lambda

In [59]:
optimizer = optim.Adam(
    model.parameters(), 
    lr=float(cfg_train["lr"])
)
scheduler = lr_sched.LambdaLR(
    optimizer, 
    lr_lambda=get_lr_lambda(
        num_epochs=cfg_train["num_epochs"],
        warmup_epochs=cfg_train["warmup_epochs"]
    )
)
criterion = nn.CrossEntropyLoss()
scaler = GradScaler(device=cfg_train["device"]) 

In [60]:
for epoch in range(cfg_train["num_epochs"]):
    model.train()
    running_loss = 0.0
    for batch_idx, (images, labels) in enumerate(train_loader, 1):
        print(f"\rProcessed {batch_idx}/{len(train_loader)} batches", end='')

        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        with autocast(device_type=device.type):
            outputs = model(images)
            loss = criterion(outputs, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()

    avg_loss = running_loss / len(train_loader)
    scheduler.step()

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)

            _, predicted = torch.max(outputs, dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100.0 * correct / total
    print(
        f"\rEpoch [{epoch+1}/{cfg_train['num_epochs']}], "
        f"Loss: {avg_loss:.4f}, Test Accuracy: {accuracy:.2f}%"
    )

print("Training finished.")

Epoch [1/90], Loss: 1.2293, Test Accuracy: 52.75%
Epoch [2/90], Loss: 1.0645, Test Accuracy: 59.83%
Epoch [3/90], Loss: 0.9049, Test Accuracy: 67.12%
Epoch [4/90], Loss: 0.7421, Test Accuracy: 72.54%
Epoch [5/90], Loss: 0.6506, Test Accuracy: 76.36%
Epoch [6/90], Loss: 0.5799, Test Accuracy: 78.52%
Epoch [7/90], Loss: 0.5465, Test Accuracy: 80.05%
Epoch [8/90], Loss: 0.5026, Test Accuracy: 79.60%
Epoch [9/90], Loss: 0.4777, Test Accuracy: 82.33%
Epoch [10/90], Loss: 0.4414, Test Accuracy: 83.46%
Epoch [11/90], Loss: 0.4086, Test Accuracy: 83.36%
Epoch [12/90], Loss: 0.3818, Test Accuracy: 85.08%
Epoch [13/90], Loss: 0.3601, Test Accuracy: 84.09%
Epoch [14/90], Loss: 0.3407, Test Accuracy: 86.33%
Epoch [15/90], Loss: 0.3283, Test Accuracy: 86.63%
Epoch [16/90], Loss: 0.3058, Test Accuracy: 87.67%
Epoch [17/90], Loss: 0.2865, Test Accuracy: 87.14%
Epoch [18/90], Loss: 0.2769, Test Accuracy: 87.39%
Epoch [19/90], Loss: 0.2579, Test Accuracy: 87.99%
Epoch [20/90], Loss: 0.2473, Test Accura

In [61]:
torch.save(
    vit_backbone.state_dict(),
    cfg["model_weights"]["vit_backbone"]
)
torch.save(
    classifier.state_dict(), 
    cfg["model_weights"]["classifier"]
)
print('ViT weights are saved')

ViT weights are saved
