In [61]:
import yaml
import math

from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
import torch
from torchvision.transforms import functional as F
import torch.nn as nn

import torch.optim as optim
import torch.optim.lr_scheduler as lr_sched

from vit import ViTBackbone, ViTWithExtraFeatures, ClassificationHead
from feature_extractors import LBPExtractor, HOGExtractor, SIFTExtractor

In [62]:
with open("configs.yaml", "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

cfg_model = cfg['model']
cfg_train = cfg['training']

device = torch.device("cuda")

### Create Dataset

In [63]:
data_path = 'data'
batch_size = cfg_train["batch_size"]
num_workers = 16

In [64]:
channel_mean = [0.5524, 0.5288, 0.5107]
channel_std  = [0.0956, 0.0773, 0.0465]

train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=channel_mean, std=channel_std)
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=channel_mean, std=channel_std)
])


In [65]:
full_dataset = datasets.ImageFolder(
    root=data_path, transform=train_transform
)

total_size = len(full_dataset)
train_size = int(0.025 * total_size)
test_size  = total_size - train_size

train_dataset, test_dataset = random_split(
    full_dataset, [train_size, test_size],
    generator=torch.Generator().manual_seed(42)
)
test_dataset.dataset.transform = test_transform

In [66]:
train_loader = DataLoader(
    train_dataset, batch_size=batch_size, 
    shuffle=True,  num_workers=num_workers
)
test_loader  = DataLoader(
    test_dataset,  batch_size=batch_size,
    shuffle=False, num_workers=num_workers
)

### Model

In [67]:
vit_backbone = ViTBackbone(
    in_channels=cfg_model["vit"]["in_channels"],
    embedding_dim=cfg_model["vit"]["embedding_dim"],
    patch_size=cfg_model["vit"]["patch_size"],
    max_patch_num=cfg_model["vit"]["max_patch_num"],
    L=cfg_model["vit"]["depth"],
    n_heads=cfg_model["vit"]["n_heads"],
    mlp_size=cfg_model["vit"]["mlp_size"]
)
extractors = nn.ModuleList([
    LBPExtractor(),
    HOGExtractor(),
    SIFTExtractor()
])

model = ViTWithExtraFeatures(
    vit_backbone=vit_backbone,
    feature_extractors=extractors,
    n_classes=cfg_model["cls"]["n_classes"]
)
model = model.to(device)

### Train (with Feature Engineering)

In [68]:
def get_lr_lambda(num_epochs, warmup_epochs):
    def lr_lambda(epoch):
        if epoch < warmup_epochs:
            return (epoch + 1) / warmup_epochs
        else:
            progress = (epoch - warmup_epochs) / (num_epochs - warmup_epochs)
            return 0.5 * (1.0 + math.cos(math.pi * progress))
    
    return lr_lambda

In [69]:
optimizer = optim.Adam(
    model.parameters(), 
    lr=float(cfg_train["lr"])
)
scheduler = lr_sched.LambdaLR(
    optimizer, 
    lr_lambda=get_lr_lambda(
        num_epochs=cfg_train["num_epochs"],
        warmup_epochs=cfg_train["warmup_epochs"]
    )
)
criterion = nn.CrossEntropyLoss()

In [70]:
for epoch in range(cfg_train["num_epochs"]):
    model.train()
    running_loss = 0.0
    for batch_idx, (images, labels) in enumerate(train_loader, 1):
        print(f"\rProcessed {batch_idx}/{len(train_loader)} batches", end='')

        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        outputs = model(images)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    avg_loss = running_loss / len(train_loader)
    scheduler.step()

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)

            _, predicted = torch.max(outputs, dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100.0 * correct / total
    print(
        f"\rEpoch [{epoch+1}/{cfg_train['num_epochs']}], "
        f"Loss: {avg_loss:.4f}, Test Accuracy: {accuracy:.2f}%"
    )

print("Training finished.")

Epoch [1/30], Loss: 1.5276, Test Accuracy: 49.40%
Epoch [2/30], Loss: 1.3962, Test Accuracy: 65.86%
Epoch [3/30], Loss: 1.3164, Test Accuracy: 66.42%
Epoch [4/30], Loss: 1.3401, Test Accuracy: 50.88%
Epoch [5/30], Loss: 1.3325, Test Accuracy: 51.58%
Epoch [6/30], Loss: 1.2917, Test Accuracy: 59.82%
Epoch [7/30], Loss: 1.2509, Test Accuracy: 65.61%
Epoch [8/30], Loss: 1.2485, Test Accuracy: 65.02%
Epoch [9/30], Loss: 1.2294, Test Accuracy: 62.62%
Epoch [10/30], Loss: 1.1839, Test Accuracy: 60.96%
Epoch [11/30], Loss: 1.1413, Test Accuracy: 62.28%
Epoch [12/30], Loss: 1.1064, Test Accuracy: 64.91%
Epoch [13/30], Loss: 1.0767, Test Accuracy: 67.15%
Epoch [14/30], Loss: 1.0509, Test Accuracy: 67.66%
Epoch [15/30], Loss: 1.0246, Test Accuracy: 67.54%
Epoch [16/30], Loss: 0.9950, Test Accuracy: 67.21%
Epoch [17/30], Loss: 0.9685, Test Accuracy: 66.34%
Epoch [18/30], Loss: 0.9493, Test Accuracy: 66.06%
Epoch [19/30], Loss: 0.9358, Test Accuracy: 66.31%
Epoch [20/30], Loss: 0.9252, Test Accura