In [46]:
import json
import os
import shutil

import matplotlib.pyplot as plt
import torch
import torchvision
import tqdm
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from dataset import CustomImageFolder

In [8]:
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [62]:
BATCH_SIZE = 32
NUM_WORKERS = 2
DATA_DIR = "./dataset"

# Use the same preprocessing that ships with the pretrained ViT-B/16 weights.
transform = torchvision.models.ViT_B_16_Weights.DEFAULT.transforms()


def make_dataloader(split, shuffle):
    """Build a DataLoader over ``DATA_DIR/<split>`` using the shared ViT transform.

    Args:
        split: sub-directory of DATA_DIR to load (e.g. "train", "val").
        shuffle: whether the loader reshuffles the samples every epoch.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding batches of up to BATCH_SIZE.
    """
    return torch.utils.data.DataLoader(
        CustomImageFolder(root_dir=f"{DATA_DIR}/{split}", transform=transform),
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=NUM_WORKERS,
    )


# Shuffle only the training split: sample order does not matter for evaluation,
# and a fixed validation order keeps the val metrics reproducible between runs.
train_dataloader = make_dataloader("train", shuffle=True)
val_dataloader = make_dataloader("val", shuffle=False)

In [63]:
# Start from the pretrained vit_b_16 model weights.
model = torchvision.models.vit_b_16(weights=torchvision.models.ViT_B_16_Weights.DEFAULT)

# Freeze every existing parameter so the pretrained backbone is not updated.
for parameter in model.parameters():
    parameter.requires_grad = False

# Replace the classification head with one sized for our classes.
# - The new layer is created *after* the freezing loop, so its parameters keep
#   requires_grad=True and are the only ones that train.
# - Its input width comes from the model itself instead of a hard-coded 768.
# - A bare nn.Linear keeps the state_dict keys as "heads.weight"/"heads.bias",
#   so checkpoints saved by this notebook stay loadable.
num_classes = len(train_dataloader.dataset.classes)
model.heads = nn.Linear(in_features=model.hidden_dim, out_features=num_classes)

# Move model to GPU (if available)
model = model.to(device)

In [64]:
# Create optimizer and loss function.
LEARNING_RATE = 1e-3

# Only parameters with requires_grad=True (the new head) can ever receive gradients,
# so hand just those to Adam; this makes the "train the head only" intent explicit.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(params=trainable_params, lr=LEARNING_RATE)
loss_fn = torch.nn.CrossEntropyLoss()

In [65]:
def epoch_step(model, dataloader, device, validation=False,
               criterion=None, optim=None, show_progress=True):
    """Run one full pass over ``dataloader``, either training or evaluating ``model``.

    Training (``validation=False``): the model is put in ``train()`` mode and the
    optimizer steps after every batch. Validation (``validation=True``): the model is
    put in ``eval()`` mode, autograd is disabled, and the weights are left untouched.

    Args:
        model: the ``nn.Module`` to run.
        dataloader: iterable of ``(X, y)`` batches supporting iteration; ``y`` holds
            integer class labels.
        device: device each batch is moved to (should match the model's device).
        validation: if True, evaluate only — no backward pass and no optimizer step.
        criterion: loss function; defaults to the notebook-level ``loss_fn``. Assumed
            to average over the batch (``reduction="mean"``, the CrossEntropyLoss default).
        optim: optimizer; defaults to the notebook-level ``optimizer``. Ignored when
            ``validation`` is True.
        show_progress: wrap the loader in a ``tqdm`` progress bar.

    Returns:
        ``(epoch_loss, epoch_acc)``: mean loss and accuracy over every sample seen.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    # Fall back to the notebook globals so existing call sites keep working.
    criterion = loss_fn if criterion is None else criterion
    if not validation:
        optim = optimizer if optim is None else optim

    if validation:
        model.eval()
    else:
        model.train()

    total_loss, total_correct, total_samples = 0.0, 0, 0
    batches = (
        tqdm.tqdm(dataloader, desc="val" if validation else "train")
        if show_progress
        else dataloader
    )

    # Disable autograd while evaluating: saves memory and guarantees the validation
    # pass cannot produce gradients (and therefore cannot update the weights).
    with torch.set_grad_enabled(not validation):
        for X, y in batches:
            # Send data to target device
            X, y = X.to(device), y.to(device)

            # 1. Forward pass and loss
            y_pred = model(X)
            loss = criterion(y_pred, y)

            if not validation:
                # 2. Backpropagate and update the weights (training only)
                optim.zero_grad()
                loss.backward()
                optim.step()

            # 3. Accumulate sample-weighted metrics so a smaller final batch does not
            #    skew the epoch averages.
            batch_size = y.size(0)
            total_loss += loss.item() * batch_size
            # argmax of the logits equals argmax of softmax(logits); no softmax needed.
            total_correct += (y_pred.argmax(dim=1) == y).sum().item()
            total_samples += batch_size

    if total_samples == 0:
        raise ValueError("dataloader yielded no samples")

    return total_loss / total_samples, total_correct / total_samples

In [None]:
WEIGHTS_DIR = "weights"
HISTORY_PATH = "training_history.json"
NUM_EPOCHS = 20

# Start from an empty checkpoint directory so a "best" model from an earlier run
# cannot survive if this run never beats best_acc.
shutil.rmtree(WEIGHTS_DIR, ignore_errors=True)
os.makedirs(WEIGHTS_DIR, exist_ok=True)

history = {
    "train_loss": [],
    "train_acc": [],
    "val_loss": [],
    "val_acc": []
}

best_acc = 0.0

for epoch in range(NUM_EPOCHS):
    # Train 1 epoch
    train_loss, train_acc = epoch_step(
        model=model,
        dataloader=train_dataloader,
        device=device
    )

    # Validate after the epoch. validation=True selects eval mode; without it
    # epoch_step runs in train() mode and optimizes on the validation batches.
    val_loss, val_acc = epoch_step(
        model=model,
        dataloader=val_dataloader,
        device=device,
        validation=True
    )

    # Print out what's happening
    print(
      f"Epoch: {epoch+1} | "
      f"train_loss: {train_loss:.4f} | "
      f"train_acc: {train_acc:.4f} | "
      f"val_loss: {val_loss:.4f} | "
      f"val_acc: {val_acc:.4f}"
    )

    history["train_loss"].append(train_loss)
    history["train_acc"].append(train_acc)
    history["val_loss"].append(val_loss)
    history["val_acc"].append(val_acc)

    # Rewrite the full training history each epoch so it survives an interrupted run.
    with open(HISTORY_PATH, "w") as f:
        json.dump(history, f)

    # save latest model
    torch.save(model.state_dict(), os.path.join(WEIGHTS_DIR, "model_latest.pth"))

    # save best model (by validation accuracy)
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), os.path.join(WEIGHTS_DIR, "model_best.pth"))

In [67]:
# Class names of the training split, as exposed by CustomImageFolder. Their count sets
# the output size of the classification head; presumably their order also matches the
# integer labels the dataset yields — confirm in dataset.CustomImageFolder.
train_dataloader.dataset.classes

['camouflage', 'normal']