In [None]:
# Standard library
import json
import os

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import torch
import torch.nn as nn

from torch.nn import CrossEntropyLoss
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torchvision.models import vit_b_16, vit_l_16, vit_b_32, vit_l_32
from torchvision.transforms import Resize, Compose, ToTensor
from tqdm import tqdm
from datasets import ClassLabel

from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score, confusion_matrix

from data import load_food

from torchvision import transforms
from PIL import Image

from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam import GradCAM

from config import OUTPUT_DIR

# Seed the global NumPy and PyTorch random number generators (seed 0) so that
# DataLoader shuffling and head initialisation are reproducible across runs.
np.random.seed(seed=0)
torch.manual_seed(seed=0)

In [None]:
# importing data
# Load the train/validation splits through the project-local `data.load_food` helper.
# NOTE(review): later cells assume each split has an "image" column of PIL images and a
# ClassLabel "label" column whose ids fall inside `selected_classes` — confirm against load_food.
train_ds, val_ds = load_food.load_food()

In [None]:
# constants
IMAGE_SIZE = 224

# selected classes — the position of a name in this list becomes its new label id
selected_classes = ['ramen', 'carrot_cake', 'beef_carpaccio', 'strawberry_shortcake', 'escargots', 'donuts', 'croque_madame', 'cheese_plate', 'caprese_salad', 'sashimi', 'oysters', 'caesar_salad', 'pho', 'hot_and_sour_soup', 'beef_tartare', 'creme_brulee', 'cup_cakes', 'miso_soup', 'pork_chop', 'paella']
# Duplicates would make `list.index` map two source classes onto one new id.
assert len(set(selected_classes)) == len(selected_classes), "selected_classes must not contain duplicates"
# Derive the class count from the list so the classifier width can never drift from it.
NUM_CLASS = len(selected_classes)

# creating label to id and id to label mappings for the dataset's original label space
# (ids are stored as strings on both sides)
labels = train_ds.features["label"].names
label2id = {label: str(idx) for idx, label in enumerate(labels)}
id2label = {str(idx): label for idx, label in enumerate(labels)}

# remapping to new indices
def map_to_new_label(image, old_mapping, new_class_names):
    """Rewrite one example's ``"label"`` from the source id space into ``new_class_names`` ids.

    Args:
        image: a single example dict holding a ``"label"`` key (despite the name, this is
            the example, not the PIL image).
        old_mapping: dict from stringified source label id to class name (``id2label``).
        new_class_names: ordered class names; a name's index is its new id.

    Returns:
        The same dict, mutated in place, with ``"label"`` replaced by the new integer id.

    Raises:
        KeyError: if the source id is missing from ``old_mapping``.
        ValueError: if the class name is not in ``new_class_names``.
    """
    old_name = old_mapping[str(image["label"])]
    image["label"] = new_class_names.index(old_name)
    return image


def transform_fn(batch):
    """Batched remap of ``batch["label"]`` onto ``selected_classes`` ids; other columns untouched.

    Iterates the label column: elements of ``batch["image"]`` are PIL images and cannot be
    indexed with ``"label"``.
    """
    batch["label"] = [
        map_to_new_label({"label": old_id}, old_mapping=id2label, new_class_names=selected_classes)["label"]
        for old_id in batch["label"]
    ]
    return batch


# Apply the remap eagerly with `.map`. A `with_transform` here would be replaced by the later
# `with_transform(map_fn)` (a dataset keeps a single on-the-fly transform), and `cast_column`
# operates on the stored Arrow data, not on the output of an on-the-fly transform.
train_ds = train_ds.map(transform_fn, batched=True)
val_ds = val_ds.map(transform_fn, batched=True)
new_label_feature = ClassLabel(names=selected_classes)
train_ds = train_ds.cast_column("label", new_label_feature)
val_ds = val_ds.cast_column("label", new_label_feature)

# image transformation
# NOTE(review): no Normalize step is applied even though the backbone is later loaded with
# IMAGENET1K_V1 weights, which presumably expect ImageNet mean/std inputs — confirm. The
# Grad-CAM cell also feeds unnormalised images, so both must change together.
train_transforms = Compose([
    Resize((IMAGE_SIZE, IMAGE_SIZE)),
    ToTensor()
])

val_transforms = Compose([
    Resize((IMAGE_SIZE, IMAGE_SIZE)),
    ToTensor()
])


def map_fn(examples, transform=train_transforms):
    """On-the-fly batch transform: PIL images -> ``"pixel_values"`` tensors.

    Converts every image in ``examples["image"]`` to RGB, applies ``transform``, stores the
    results under ``"pixel_values"`` and drops the ``"image"`` column.

    Args:
        examples: a batch dict from a Hugging Face dataset.
        transform: callable applied to each RGB image; defaults to ``train_transforms``.
    """
    examples["pixel_values"] = [transform(img.convert("RGB")) for img in examples["image"]]
    del examples["image"]
    return examples


def val_map_fn(examples):
    """``map_fn`` with the validation pipeline instead of the training one."""
    return map_fn(examples, transform=val_transforms)


train_ds_transformed = train_ds.with_transform(map_fn)
val_ds_transformed = val_ds.with_transform(val_map_fn)

In [None]:
# Use the current CUDA device when one is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(
    "Using device: ",
    device,
    f"({torch.cuda.get_device_name(device)})" if torch.cuda.is_available() else "",
)

def get_vit_architecture(architecture_name):
  """Return the torchvision ViT constructor registered under ``architecture_name``.

  Accepted names: "vit_b_16", "vit_l_16", "vit_b_32", "vit_l_32".

  Raises:
    ValueError: for any other name.
  """
  supported = ("vit_b_16", "vit_l_16", "vit_b_32", "vit_l_32")
  if architecture_name not in supported:
    raise ValueError(f"Unknown architecture name: {architecture_name}")
  # Constructors listed in the same order as `supported`.
  constructors = (vit_b_16, vit_l_16, vit_b_32, vit_l_32)
  return constructors[supported.index(architecture_name)]

# the vision transformer model
def get_custom_vit(custom_head: nn.Sequential, architecture_name: str, weights=None, device=device):
  """Build a torchvision ViT, install ``custom_head`` as its classifier and freeze the rest.

  Args:
    custom_head: module assigned to ``heads.head``.
    architecture_name: a name accepted by ``get_vit_architecture``.
    weights: forwarded to the constructor as ``weights=``; when None the constructor is
      called without arguments.
    device: device the model is moved to (defaults to the notebook-level ``device``).

  Returns:
    The constructed model on ``device``, with gradients enabled only for ``custom_head``.
  """
  print(f"\nArchitecture: {architecture_name}; Weights: {weights}\n")
  build = get_vit_architecture(architecture_name)
  net = build() if weights is None else build(weights=weights)
  net.heads.head = custom_head
  net.to(device)

  # Freeze every parameter, then re-enable gradients for the new head only.
  net.requires_grad_(False)
  net.heads.head.requires_grad_(True)
  return net

In [None]:
# running and validating the model
def _run_epoch(model, loader, loss_fct, device, desc, optimizer=None):
  """Run one pass over ``loader`` and return ``(mean per-batch loss, accuracy in percent)``.

  When ``optimizer`` is given each batch performs zero_grad / backward / step; otherwise only
  forward passes run. The caller sets ``model.train()``/``model.eval()`` and wraps evaluation
  in ``torch.no_grad()``. Batches are dicts with a ``"pixel_values"`` tensor and a ``"label"``
  tensor of class ids. An empty loader yields ``(0.0, 0.0)``.
  """
  total_loss, correct, seen = 0.0, 0, 0
  for batch in tqdm(loader, desc=desc):
    inputs = batch["pixel_values"].to(device)
    targets = batch["label"].to(device)
    if optimizer is not None:
      optimizer.zero_grad()

    logits = model(inputs)
    loss = loss_fct(logits, targets)
    total_loss += loss.item()
    correct += torch.sum(torch.argmax(logits, dim=1) == targets).item()
    seen += len(inputs)

    if optimizer is not None:
      loss.backward()
      optimizer.step()

  n_batches = len(loader)
  mean_loss = total_loss / n_batches if n_batches else 0.0
  accuracy = correct / seen * 100 if seen else 0.0
  return mean_loss, accuracy


def run_model(model, train_loader, val_loader, hyperparameters, device=device):
  """Train ``model`` for ``hyperparameters['n_epochs']`` epochs, validating after each epoch.

  Args:
    model: network to train; batches are moved to ``device`` but the model is not.
    train_loader, val_loader: sized iterables of dict batches (see ``_run_epoch``).
    hyperparameters: dict with ``'optimizer'`` (an optimizer class), ``'lr'`` and ``'n_epochs'``.
    device: device the batches are moved to.

  Returns:
    ``(model, history)``; ``history`` has one dict per epoch with keys ``epoch`` (0-based),
    ``train_loss``, ``train_accuracy``, ``val_loss``, ``val_accuracy`` (accuracies in %).
  """
  # Frozen parameters never get a gradient, so the optimizer would skip them anyway;
  # passing only trainable ones avoids tracking state for the whole backbone.
  trainable_params = [p for p in model.parameters() if p.requires_grad]
  optimizer = hyperparameters['optimizer'](trainable_params, lr=hyperparameters['lr'])
  loss_fct = CrossEntropyLoss()
  n_epochs = hyperparameters['n_epochs']
  history = []

  for epoch in range(n_epochs):
    model.train()
    train_loss, train_accuracy = _run_epoch(
      model, train_loader, loss_fct, device,
      desc=f"Epoch {epoch+1} in training", optimizer=optimizer,
    )
    print(f"Epoch {epoch + 1}/{n_epochs} loss: {train_loss:.3f}")
    print(f"Training accuracy: {train_accuracy:.2f}%")

    model.eval()
    with torch.no_grad():
      val_loss, val_accuracy = _run_epoch(model, val_loader, loss_fct, device, desc="Validation")
    print(f"Validation loss: {val_loss:.3f}")
    print(f"Validation accuracy: {val_accuracy:.2f}%")

    history.append({
      'epoch': epoch,
      'train_loss': train_loss,
      'train_accuracy': train_accuracy,
      'val_loss': val_loss,
      'val_accuracy': val_accuracy,
    })

  return model, history

In [None]:
def get_predictions(model, dataloader, device):
    """Run ``model`` over ``dataloader`` without gradients and collect its predictions.

    Moves ``model`` to ``device`` and puts it in eval mode. Batches are dicts with
    ``"pixel_values"`` and ``"label"`` tensors.

    Returns:
        ``(all_preds, all_labels, avg_val_loss, total_val_loss)``: predicted and true class ids
        as lists of numpy scalars, the mean per-batch cross-entropy, and the summed loss.
        An empty dataloader gives an average loss of 0.0.
    """
    print('\nGetting predictions...\n')
    model.to(device)
    model.eval()
    loss_fct = CrossEntropyLoss()
    all_preds, all_labels = [], []
    total_val_loss = 0.0

    with torch.no_grad():
        for batch in tqdm(dataloader):
            batch_x = batch["pixel_values"].to(device)
            batch_y = batch["label"].to(device)

            batch_y_hat = model(batch_x)
            total_val_loss += loss_fct(batch_y_hat, batch_y).item()

            preds = torch.argmax(batch_y_hat, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(batch_y.cpu().numpy())

    n_batches = len(dataloader)
    avg_val_loss = total_val_loss / n_batches if n_batches else 0.0
    return all_preds, all_labels, avg_val_loss, total_val_loss

def get_metrics(all_preds, all_labels):
    """Print accuracy and macro-averaged precision / recall / F1, and return them.

    Args:
        all_preds: predicted class ids.
        all_labels: ground-truth class ids, aligned with ``all_preds``.

    Returns:
        dict with float values under ``"accuracy"``, ``"precision"``, ``"recall"``, ``"f1"``.
    """
    print('\nCalculating metrics...')
    # zero_division=0 gives the same value sklearn's default uses for classes that are never
    # predicted, without emitting an UndefinedMetricWarning.
    metrics = {
        "accuracy": accuracy_score(all_labels, all_preds),
        "precision": precision_score(all_labels, all_preds, average='macro', zero_division=0),
        "recall": recall_score(all_labels, all_preds, average='macro', zero_division=0),
        "f1": f1_score(all_labels, all_preds, average='macro', zero_division=0),
    }

    print(f"Accuracy:              {metrics['accuracy']:.4f}")
    print(f"Precision (macro):     {metrics['precision']:.4f}")
    print(f"Recall (macro):        {metrics['recall']:.4f}")
    print(f"F1-score (macro):      {metrics['f1']:.4f}\n")
    return metrics

def plot_confusion_matrix(all_labels, all_preds, class_names=None, zero_diagonal=True, vmax=21):
    """Draw the confusion matrix (rows = true class, columns = predicted class) as a heatmap.

    Args:
        all_labels: ground-truth class ids.
        all_preds: predicted class ids.
        class_names: tick labels ordered by class id. Defaults to ``selected_classes``, whose
            order defines the remapped ids. If a dict is passed, its values are used.
        zero_diagonal: when True, correct predictions are set to 0 so the colour scale
            highlights misclassifications only.
        vmax: upper limit of the colour scale.
    """
    print('\nPlotting confusion matrix...')
    if class_names is None:
        class_names = selected_classes
    elif isinstance(class_names, dict):
        class_names = class_names.values()
    class_names = list(class_names)

    # Pin the label set so the matrix is always len(class_names) square, even when a class
    # never occurs in the labels or predictions (otherwise ticks and cells would misalign).
    cm = confusion_matrix(all_labels, all_preds, labels=list(range(len(class_names))))
    if zero_diagonal:
        np.fill_diagonal(cm, 0)

    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(
        cm,
        annot=True,
        cmap="Blues",
        xticklabels=class_names,
        yticklabels=class_names,
        vmax=vmax,
        ax=ax,
    )
    title = "Vision Transformer - Confusion Matrix"
    if zero_diagonal:
        title += " (diagonal zeroed)"
    ax.set_title(title)
    ax.set_xlabel("Predicted Label")
    ax.set_ylabel("True Label")
    fig.tight_layout()
    plt.show()
    
def run_evaluation(model, dataloader, metrics=True, confusion_matrix=False, device=device):
    """Predict on ``dataloader``, report the mean loss and optionally metrics / a confusion matrix.

    Args:
        model: network to evaluate.
        dataloader: sized iterable of dict batches with ``"pixel_values"`` and ``"label"``.
        metrics: print accuracy and macro precision / recall / F1 when True.
        confusion_matrix: plot the confusion matrix when True. This boolean shadows sklearn's
            ``confusion_matrix`` only inside this function.
        device: device used for inference.

    Returns:
        ``(all_preds, all_labels, avg_val_loss)`` so callers can reuse the predictions.
    """
    all_preds, all_labels, avg_val_loss, _ = get_predictions(model, dataloader, device)
    print('\nRunning evaluation...\n')
    print(f"Average loss:          {avg_val_loss:.4f}")

    if metrics:
        get_metrics(all_preds, all_labels)

    if confusion_matrix:
        # Label ids index into selected_classes after remapping, so those are the tick names.
        plot_confusion_matrix(all_labels, all_preds, class_names=selected_classes)

    return all_preds, all_labels, avg_val_loss

In [None]:
def save_model_and_history(experiment_name, model, history):
    """Save a model's ``state_dict`` and its training history under ``OUTPUT_DIR/experiments``.

    Writes ``torch_model_{experiment_name}.pth`` with ``torch.save`` and
    ``training_history_{experiment_name}.json`` as indented JSON. The ``experiments``
    sub-directory is created if needed.

    Args:
        experiment_name: suffix used in both file names.
        model: module whose ``state_dict()`` is saved.
        history: JSON-serialisable object, e.g. the list returned by ``run_model``.
    """
    # The files are written into the `experiments` sub-folder, so that is the directory that
    # must exist; creating only OUTPUT_DIR leaves torch.save to fail on a fresh checkout.
    experiments_dir = os.path.join(OUTPUT_DIR, "experiments")
    os.makedirs(experiments_dir, exist_ok=True)

    output_pth = os.path.join(experiments_dir, f"torch_model_{experiment_name}.pth")
    torch.save(model.state_dict(), output_pth)
    print(f"Saved model to: {output_pth}")

    output_json = os.path.join(experiments_dir, f"training_history_{experiment_name}.json")
    with open(output_json, "w") as f:
        json.dump(history, f, indent=2)
    print(f"Saved history to: {output_json}")

In [None]:
# Training configuration; 'optimizer' is a class that run_model instantiates with the lr
hyperparameters = dict(
    batch_size=16,
    n_epochs=10,
    lr=0.001,
    optimizer=AdamW,
)

# Shuffle only the training split; validation order stays fixed
train_loader = DataLoader(
    train_ds_transformed,
    batch_size=hyperparameters['batch_size'],
    shuffle=True,
)
val_loader = DataLoader(
    val_ds_transformed,
    batch_size=hyperparameters['batch_size'],
    shuffle=False,
)

In [None]:
# Frozen ImageNet-pretrained ViT-B/16 backbone with a single trainable linear head.
# 768 is assumed to be the ViT-B embedding width — the ViT-L variants would need another value.
top_layers = nn.Sequential(nn.Linear(768, NUM_CLASS))
model = get_custom_vit(
    custom_head=top_layers,
    architecture_name="vit_b_16",
    weights='IMAGENET1K_V1',
)
model, history = run_model(model, train_loader, val_loader, hyperparameters)
save_model_and_history("vit_base_16_cls_epoch10", model, history)

In [None]:
# Rebuild the same architecture (no pretrained download) and load the fine-tuned weights.
top_layers = nn.Sequential(
    nn.Linear(768, NUM_CLASS)
)
model2 = get_custom_vit(custom_head=top_layers, architecture_name="vit_b_16", weights=None)
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
model2.load_state_dict(torch.load(OUTPUT_DIR + '/experiments/torch_model_vit_base_16_cls_epoch10.pth', map_location=device))
run_evaluation(model2, val_loader)
# get_predictions returns (preds, labels, ...): unpack in that order so the matrix is not transposed.
all_preds, all_labels, _, _ = get_predictions(model2, val_loader, device)
plot_confusion_matrix(all_labels, all_preds)

In [None]:
# Re-draw the confusion matrix from the `all_labels` / `all_preds` computed in the previous cell
# (no new inference is run here).
plot_confusion_matrix(all_labels, all_preds)

In [None]:
from config import GRAD_CAM_DIR

def reshape_transform(tensor, height=None, width=None):
    """Turn ViT token activations into a CNN-style feature map for Grad-CAM.

    Drops the first token, then lays the remaining patch tokens out on a ``height x width``
    grid and moves the embedding dimension to the channel axis.

    Args:
        tensor: activations, or a tuple whose first element is the activations, laid out as
            (batch, 1 + height * width, dim). Batch-first matches torchvision's ViT encoder
            blocks — confirm if used with another model.
        height, width: patch-grid size. If both are None the grid is assumed square and
            inferred from the token count (14x14 for 16-px patches at 224 px, 7x7 for 32-px
            patches). If only one is given, the other is derived from it.

    Returns:
        Tensor of shape (batch, dim, height, width).

    Raises:
        ValueError: if the patch tokens do not fill the requested or inferred grid exactly.
    """
    if isinstance(tensor, tuple):
        tensor = tensor[0]
    patches = tensor[:, 1:, :]
    batch, n_tokens, dim = patches.shape

    if height is None and width is None:
        height = width = int(round(n_tokens ** 0.5))
    elif height is None:
        height = n_tokens // width
    elif width is None:
        width = n_tokens // height

    if height * width != n_tokens:
        raise ValueError(
            f"Cannot arrange {n_tokens} patch tokens on a {height}x{width} grid"
        )
    return patches.reshape(batch, height, width, dim).permute(0, 3, 1, 2)


def _list_image_paths(image_folder, image_idxes):
    """Sorted .jpg/.jpeg/.png paths in ``image_folder``, optionally subset by ``image_idxes``."""
    if not image_folder:
        return []
    paths = [
        os.path.join(image_folder, fname)
        for fname in sorted(os.listdir(image_folder))
        if fname.endswith((".jpg", ".jpeg", ".png"))
    ]
    if image_idxes is not None:
        paths = [paths[i] for i in image_idxes]
    return paths


def _label_name(label_names, class_idx):
    """Display name for ``class_idx``; the id itself when no names are supplied."""
    return label_names[class_idx] if label_names is not None else str(class_idx)


def _gradcam_for_image(model, target_layers, image_path, transform, image_size, device, label_names):
    """Predict one image's class and compute a Grad-CAM map for every target layer.

    Returns ``(panels, pred_class_idx)``: one 2-D grayscale CAM per target layer, followed by
    an RGB overlay of the last layer's CAM on the resized image.
    """
    image = Image.open(image_path).convert("RGB")
    img_np = np.array(image.resize((image_size, image_size))).astype(np.float32) / 255.0

    input_tensor = transform(image).unsqueeze(0).to(device)
    input_tensor.requires_grad = True

    with torch.no_grad():
        pred_class_idx = model(input_tensor).argmax(dim=1).item()
    print(f"Processing {os.path.basename(image_path)} | Pred = {_label_name(label_names, pred_class_idx)}")

    targets = [ClassifierOutputTarget(pred_class_idx)]
    panels = []
    for layer in target_layers:
        # The context manager releases the hooks GradCAM registers on the layer.
        with GradCAM(model=model, target_layers=[layer], reshape_transform=reshape_transform) as cam:
            panels.append(cam(input_tensor=input_tensor, targets=targets)[0])

    panels.append(show_cam_on_image(img_np, panels[-1], use_rgb=True, image_weight=0.6))
    return panels, pred_class_idx


def _save_panel_grid(row_images, label_names, show_only_overlay, out_path, cols=2):
    """Lay the Grad-CAM panels out on a ``cols``-wide grid and save the figure to ``out_path``."""
    panels = []
    for cams_this_row, pred in row_images:
        shown = [cams_this_row[-1]] if show_only_overlay else cams_this_row
        panels.extend((img, pred) for img in shown)

    # Size the grid from the panels actually drawn so none are silently dropped.
    rows = max(1, (len(panels) + cols - 1) // cols)
    fig, axes = plt.subplots(rows, cols, figsize=(4 * cols, 4 * rows), squeeze=False)
    axes = axes.flatten()

    for ax, (img, pred) in zip(axes, panels):
        if img.ndim == 2:
            ax.imshow(img, cmap='jet', interpolation='nearest')
        else:
            ax.imshow(img, interpolation='nearest')
        ax.axis('off')
        ax.set_aspect('equal')
        ax.set_title(f"Pred: {_label_name(label_names, pred)}", fontsize=10)

    # Remove the unused trailing axes of the last row.
    for ax in axes[len(panels):]:
        fig.delaxes(ax)

    fig.tight_layout()
    fig.savefig(out_path, dpi=150)
    plt.close(fig)


def visualize_attention(
    model,
    atten_model_name,
    image_folder=None,
    image_idxes=None,
    layer_indices=None,
    label_names=None,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    show_only_overlay=False,
    image_size=224,
):
    """Render Grad-CAM maps for the images in ``image_folder`` and save them as one PNG.

    Args:
        model: torchvision-style ViT exposing ``model.encoder.layers[i].ln_1``.
        atten_model_name: file stem of the figure written to ``GRAD_CAM_DIR``.
        image_folder: directory scanned (sorted by name) for .jpg/.jpeg/.png files; None
            processes no images.
        image_idxes: optional indices into that sorted list selecting a subset.
        layer_indices: indices into ``model.encoder.layers`` whose ``ln_1`` output is the
            Grad-CAM target; defaults to ``[-1]`` (last encoder block).
        label_names: sequence mapping class id to display name; ids are shown when None.
        device: device the model runs on.
        show_only_overlay: plot only the heatmap overlay per image instead of every layer's
            raw CAM plus the overlay.
        image_size: side length images are resized to before inference and overlay.

    Side effects:
        Moves ``model`` to ``device``, switches it to eval mode and leaves ``requires_grad=True``
        on the target layers' parameters. Creates ``GRAD_CAM_DIR`` and writes the figure there.
    """
    if layer_indices is None:
        layer_indices = [-1]

    model.to(device)
    model.eval()
    target_layers = [model.encoder.layers[i].ln_1 for i in layer_indices]

    # Ensure gradients are enabled for target layers (the backbone was frozen for fine-tuning).
    for layer in target_layers:
        for param in layer.parameters():
            param.requires_grad = True

    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor()
    ])

    row_images = []
    for image_path in _list_image_paths(image_folder, image_idxes):
        row_images.append(
            _gradcam_for_image(model, target_layers, image_path, transform, image_size, device, label_names)
        )

    os.makedirs(GRAD_CAM_DIR, exist_ok=True)
    out_path = os.path.join(GRAD_CAM_DIR, f"{atten_model_name}.png")
    _save_panel_grid(row_images, label_names, show_only_overlay, out_path)
    print(f"Saved outputs to {out_path}")

In [None]:
from config import IMAGE_DIR

# Configuration
label_names = val_ds.features['label'].names

# Using ViT Base 16
atten_model_name = "vit_b_16"
atten_model_path = "experiments/torch_model_vit_base_16_cls_epoch10.pth"
top_layers = nn.Sequential(
    nn.Linear(768, NUM_CLASS)
  )
model2 = get_custom_vit(custom_head=top_layers, architecture_name=atten_model_name, weights=None)
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
model2.load_state_dict(torch.load(os.path.join(OUTPUT_DIR, atten_model_path), map_location=device))
visualize_attention(
    model=model2,
    atten_model_name=atten_model_name,
    image_folder=IMAGE_DIR,
    layer_indices=[-1],
    label_names=label_names,
    show_only_overlay=True
)