# Training

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split
import random
from sklearn.metrics import accuracy_score
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime

# Seed every RNG this notebook draws from. Seeding only `random` is not enough:
# torch's generator drives DataLoader shuffling, the torchvision random flips and
# dropout, so it has to be seeded too for runs to be repeatable.
# `random.seed` is kept last because it returns None (no stray cell output).
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

In [None]:
IMAGE_SIZE = 224
NUM_CLASSES = 13  # width of the final classification layer
DATA_ROOT = "../complete-the-look-dataset/items/train-removed"

# Training pipeline: resize, random flips (augmentation), then map pixels from
# [0, 1] to [-1, 1] via mean=0.5 / std=0.5.
train_transform = transforms.Compose([
    transforms.Resize([IMAGE_SIZE, IMAGE_SIZE]),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Validation pipeline: same resize/normalisation but WITHOUT random flips, so the
# validation metrics are deterministic and comparable between epochs.
val_transform = transforms.Compose([
    transforms.Resize([IMAGE_SIZE, IMAGE_SIZE]),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Load the datasets using ImageFolder.
# Two views over the same folder: a Subset forwards to its parent's transform, so
# sharing one ImageFolder would apply the training augmentation to validation too.
dataset = ImageFolder(
    root=DATA_ROOT,
    transform=train_transform,
)
val_view = ImageFolder(
    root=DATA_ROOT,
    transform=val_transform,
)

# Stratified 80/20 split on sample indices; the same indices address both views
# because they enumerate the same directory.
train_idx, valid_idx = train_test_split(np.arange(len(dataset)), test_size=0.2, random_state=42,
                                        shuffle=True, stratify=dataset.targets)

train_dataset = Subset(dataset, train_idx)
val_dataset = Subset(val_view, valid_idx)

model = torchvision.models.vit_b_32(weights=torchvision.models.ViT_B_32_Weights.IMAGENET1K_V1)

# Freeze the pretrained backbone; only the replacement head built below is trained
# (freshly constructed layers have requires_grad=True by default).
for param in model.parameters():
    param.requires_grad = False

in_features = model.heads.head.in_features
model.heads = nn.Sequential(
    nn.Linear(in_features, 1024),
    nn.ReLU(),
    nn.Dropout(0.1),
    nn.Linear(1024, 512),
    nn.ReLU(),
    nn.Dropout(0.1),
    nn.Linear(512, 256),
    nn.ReLU(),
    nn.Dropout(0.1),
    nn.Linear(256, NUM_CLASSES)
)

# Pick the best available accelerator instead of hard-coding "mps", which fails
# on machines without Apple-silicon support.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Move model to the device (assignment keeps the large module repr out of the cell output)
model = model.to(device)

In [None]:
# Define batch size
batch_size = 32

# Pinned host memory only benefits host-to-CUDA copies; enabling it on other
# backends is expected to be a no-op (and may emit a warning), so request it
# only when CUDA is present.
pin_memory = torch.cuda.is_available()

# Training batches are reshuffled every epoch.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=pin_memory
)

# Validation keeps a fixed order.
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=pin_memory
)

# Initialize the optimizer over the trainable parameters only; parameters with
# requires_grad=False never receive gradients, so AdamW would skip them anyway.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.AdamW(trainable_params, lr=1e-4)

# Define the loss function
criterion = nn.CrossEntropyLoss()

In [None]:
best_val_loss = float('inf')  # Initialize with a very large value
best_checkpoint_path = None
num_epochs = 20
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []

# torch.save does not create missing directories, so make sure the target exists.
checkpoint_dir = "../checkpoints/items-classification-vit"
os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(num_epochs):
    # Training
    # Re-enable training mode EVERY epoch: the validation phase below switches the
    # model to eval mode, which would otherwise leave dropout disabled from the
    # second epoch onwards.
    model.train()
    train_loss_sum = 0.0
    train_correct = 0
    train_seen = 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Accumulate per-sample totals so a smaller final batch is not over-weighted.
        batch_count = labels.size(0)
        train_loss_sum += loss.item() * batch_count
        # argmax over the logits equals argmax over softmax, so softmax is skipped.
        train_correct += (outputs.argmax(dim=1) == labels).sum().item()
        train_seen += batch_count

    train_acc = train_correct / train_seen
    train_accuracies.append(train_acc)
    train_loss = train_loss_sum / train_seen
    train_losses.append(train_loss)

    # Validation: eval mode, no gradient tracking.
    model.eval()
    val_loss_sum = 0.0
    val_correct = 0
    val_seen = 0

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            batch_count = labels.size(0)
            val_loss_sum += loss.item() * batch_count
            val_correct += (outputs.argmax(dim=1) == labels).sum().item()
            val_seen += batch_count

    val_acc = val_correct / val_seen
    val_accuracies.append(val_acc)
    val_loss = val_loss_sum / val_seen
    val_losses.append(val_loss)

    # Save the model checkpoint if validation loss improves
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        current_datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        best_checkpoint_path = os.path.join(checkpoint_dir,
                                            f'best_model_{current_datetime}_epoch_{epoch + 1}.pt')
        torch.save(model.state_dict(), best_checkpoint_path)
        print(f"Best checkpoint saved for epoch {epoch + 1} at {best_checkpoint_path}")

    print(f"Epoch {epoch + 1}, train_loss = {train_loss:.4f}, train_accuracy = {train_acc * 100:.2f}%, val_loss = {val_loss:.4f}, val_accuracy = {val_acc * 100:.2f}%")

In [None]:
# Plot the loss and accuracy curves side by side in one figure.
fig, ax = plt.subplots(1, 2, figsize=(15, 5))

# (axis, y-label, [(series, legend label), ...]) for each panel.
panels = [
    (ax[0], "Loss", [(train_losses, "Training loss"), (val_losses, "Validation loss")]),
    (ax[1], "Accuracy", [(train_accuracies, "Training accuracy"), (val_accuracies, "Validation accuracy")]),
]

for axis, ylabel, curves in panels:
    for values, label in curves:
        # 1-based x values so the plot matches the epoch numbers used in the
        # training log and in the checkpoint file names.
        axis.plot(range(1, len(values) + 1), values, label=label)
    axis.set_title(f"{ylabel} per epoch")
    axis.set_xlabel("Epoch")
    axis.set_ylabel(ylabel)
    axis.legend()

fig.tight_layout()
plt.show()  # also prevents the last expression's repr from being echoed

# Evaluation

In [None]:
import os
import random
import time
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from sklearn.metrics import confusion_matrix
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.models.vision_transformer import VisionTransformer

# Seed every RNG the evaluation section may draw from (Python, NumPy, PyTorch),
# not just `random`. `random.seed` is kept last because it returns None, so the
# cell produces no stray output.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

In [None]:
IMAGE_SIZE = 224
NUM_CLASSES = 13  # must match the head the checkpoint was trained with

best_checkpoint_path = "../checkpoints/items-classification-vit/best_model_2024-03-28_20-32-21_epoch_19.pt"

# Deterministic preprocessing: resize, then map pixels from [0, 1] to [-1, 1].
test_transform = transforms.Compose([
    transforms.Resize([IMAGE_SIZE, IMAGE_SIZE]),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

test_dataset = ImageFolder(
    root="../complete-the-look-dataset/items/test",
    transform=test_transform,
)

test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=4,
    # Pinned memory only helps host-to-CUDA copies; skip it on other backends.
    pin_memory=torch.cuda.is_available()
)

# Build the architecture only. The checkpoint written by torch.save(model.state_dict())
# is loaded with strict key matching below and therefore overwrites every weight,
# so fetching the pretrained ImageNet weights first would be wasted work.
model = torchvision.models.vit_b_32(weights=None)

for param in model.parameters():
    param.requires_grad = False

in_features = model.heads.head.in_features
model.heads = nn.Sequential(
    nn.Linear(in_features, 1024),
    nn.ReLU(),
    nn.Dropout(0.1),
    nn.Linear(1024, 512),
    nn.ReLU(),
    nn.Dropout(0.1),
    nn.Linear(512, 256),
    nn.ReLU(),
    nn.Dropout(0.1),
    nn.Linear(256, NUM_CLASSES)
)

# Pick the best available accelerator instead of hard-coding "mps".
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Load the best checkpoint
# map_location="cpu" lets a checkpoint saved from an accelerator be restored on a
# machine that lacks that accelerator; the model is moved to `device` afterwards.
model.load_state_dict(torch.load(best_checkpoint_path, map_location="cpu"))

# eval() disables the dropout layers in the head; without it test predictions
# would be randomly perturbed. Assigning the result keeps the module repr out of
# the cell output.
model = model.to(device).eval()

In [None]:
# Function to get all predictions for the test set
def get_all_predictions(model, test_loader, device=None):
    """Run `model` over every batch of `test_loader` and collect class predictions.

    The model is switched to eval mode for the duration of the call (so dropout
    is inactive and predictions are deterministic) and its previous
    training/eval mode is restored afterwards.

    Args:
        model: a `torch.nn.Module` returning per-class scores with classes on dim 1.
        test_loader: iterable yielding `(images, labels)` tensor pairs.
        device: device the inputs are moved to. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters).

    Returns:
        `(all_preds, all_labels)`: two flat lists with one NumPy integer per
        sample, in loader order.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    all_preds = []
    all_labels = []

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for images, labels in test_loader:
                outputs = model(images.to(device))
                preds = outputs.argmax(dim=1)
                all_preds.extend(preds.cpu().numpy())
                # Labels are only collected, never fed to the model, so they do
                # not need a round trip through the device.
                all_labels.extend(labels.cpu().numpy())
    finally:
        # Leave the caller's model in the mode it was handed over in.
        model.train(was_training)
    return all_preds, all_labels

# Collect predictions and ground-truth labels for the whole test set.
all_preds, all_labels = get_all_predictions(model, test_loader)

# Confusion matrix from the (true, predicted) label pairs.
conf_matrix = confusion_matrix(all_labels, all_preds)

# Render the matrix as an annotated heatmap on an explicit figure/axes pair.
cm_fig, cm_ax = plt.subplots(figsize=(10, 8))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', cbar=False, ax=cm_ax)
cm_ax.set_xlabel('Predicted labels')
cm_ax.set_ylabel('True labels')
cm_ax.set_title('Confusion Matrix')
plt.show()