In [None]:
import os
import csv
import pandas as pd
import numpy as np
from skimage import io
from sklearn.model_selection import train_test_split
from matplotlib import pyplot as plt
import math


import torch
from torch import nn, Tensor
from torch.nn import functional as F
from torch.utils import data as tdata
from torchvision import utils as vutils, transforms as T
from torchvision.datasets import ImageFolder

from transformers import AutoImageProcessor, CvtForImageClassification
import pytorch_lightning as pl

In [None]:
# defining the number of GPUs because PL makes it very easy for us to
# parallelize training across multiple GPUs
# torch.cuda.device_count() is 0 when no CUDA device is visible
num_devices = torch.cuda.device_count()
num_devices

In [None]:
random_seed = 42
# workers=True makes Lightning also seed DataLoader worker processes, which the
# loaders below use (num_workers=2), so data loading is reproducible too.
pl.seed_everything(random_seed, workers=True)

In [None]:
# Pretrained CvT-13 checkpoint. Its image processor supplies the preprocessing
# and the normalisation statistics (image_mean / image_std / size) read later on.
hugging_model_name = "microsoft/cvt-13"
image_processor = AutoImageProcessor.from_pretrained(hugging_model_name)
image_processor

In [None]:
import os
import shutil
from sklearn.model_selection import train_test_split

# Define input and output paths
original_data_path = "/kaggle/input/indian-sign-language-self-creation/isl_dataset"
augmented_data_path = "/kaggle/input/image-augmentation-v3-dataset/augmented_isl_dataset"
output_path = "/kaggle/train-test-split"

IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')
# The 80/20 split followed by the 75/25 split only leaves every split non-empty
# from 3 images up. With fewer images, train_test_split raises because one side
# of a split would be empty.
MIN_IMAGES_TO_SPLIT = 3

# Start from empty split folders so re-running this cell cannot mix stale files
# from an earlier run into the new split.
shutil.rmtree(output_path, ignore_errors=True)

# Create output directories
train_dir = os.path.join(output_path, "train")
val_dir = os.path.join(output_path, "val")
test_dir = os.path.join(output_path, "test")
for split_dir in (train_dir, val_dir, test_dir):
    os.makedirs(split_dir, exist_ok=True)


def _list_images(folder):
    """Return the sorted full paths of the image files directly inside `folder`.

    Sorting matters because os.listdir order is filesystem-dependent. Without it,
    the same random_state could give a different split on another machine.
    """
    return sorted(
        os.path.join(folder, name)
        for name in os.listdir(folder)
        if name.lower().endswith(IMAGE_EXTENSIONS)
    )


def _copy_into(paths, dest_dir, prefix=""):
    """Copy each file in `paths` into `dest_dir`, prepending `prefix` to its file name."""
    for src in paths:
        shutil.copy(src, os.path.join(dest_dir, prefix + os.path.basename(src)))


# Class folders of the original data, in a deterministic order
classes = sorted(d for d in os.listdir(original_data_path)
                 if os.path.isdir(os.path.join(original_data_path, d)))

for class_name in classes:
    original_images = _list_images(os.path.join(original_data_path, class_name))

    if len(original_images) < MIN_IMAGES_TO_SPLIT:
        print(f"Warning: class '{class_name}' has only {len(original_images)} image(s); "
              "using all of them for training.")
        train_images, val_images, test_images = original_images, [], []
    else:
        # First split the images 80/20 into temp_train and test.
        # Then split temp_train 75/25 into train and val (60/20/20 of the class overall).
        temp_train_images, test_images = train_test_split(original_images, test_size=0.2, random_state=42)
        train_images, val_images = train_test_split(temp_train_images, test_size=0.25, random_state=42)

    # Every class folder is created in every split (even if empty), so train,
    # val and test all expose the same class list and hence the same indices.
    class_train_dir = os.path.join(train_dir, class_name)
    class_val_dir = os.path.join(val_dir, class_name)
    class_test_dir = os.path.join(test_dir, class_name)
    for class_dir in (class_train_dir, class_val_dir, class_test_dir):
        os.makedirs(class_dir, exist_ok=True)

    _copy_into(train_images, class_train_dir)
    _copy_into(val_images, class_val_dir)
    _copy_into(test_images, class_test_dir)

    # Augmented data goes to the training set only. The "aug_" prefix stops an
    # augmented file from overwriting an original that shares its file name.
    # NOTE(review): if augmented_isl_dataset was generated from ALL original images,
    # augmented copies of val/test images end up in training (data leakage).
    # Confirm how that dataset was produced and filter accordingly.
    augmented_class_path = os.path.join(augmented_data_path, class_name)
    if os.path.isdir(augmented_class_path):
        _copy_into(_list_images(augmented_class_path), class_train_dir, prefix="aug_")

print("Data split complete!")
print(f"Training data (original + augmented) is in '{train_dir}'")
print(f"Validation data (original only) is in '{val_dir}'")
print(f"Testing data (original only) is in '{test_dir}'")

In [None]:
# Kaggle's writable output folder; the checkpoint paths below are built under it
kaggle_write_path = '/kaggle/working/'
# Normalisation statistics of the pretrained processor; denorm() uses them to undo normalisation for plotting
mean, std = image_processor.image_mean, image_processor.image_std
# NOTE(review): img_size is not referenced by the cells shown; candidate for removal
img_size = image_processor.size

In [None]:
def get_dirs(root_path):
    """
    Index an ImageFolder-style directory that has one sub-folder per class.

    Args:
        root_path: folder whose immediate sub-directories are the classes.

    Returns:
        `(list_classes, classes_to_idx, samples)`:
        - `list_classes` is the sorted list of class folder names.
        - `classes_to_idx` maps each name to its position in that list.
        - `samples` is a numpy array with one `[file_name, class_idx]` row per file.
        `file_name` is relative to its class folder. Numpy stores both columns as
        strings, so callers convert `class_idx` back with `int()`.

    Raises:
        FileNotFoundError: if `root_path` does not exist.
    """
    if not os.path.exists(root_path):
        raise FileNotFoundError("Folder does not exist")

    # Only sub-directories are classes; stray files at the root are ignored
    classes = sorted(d for d in os.listdir(root_path) if os.path.isdir(os.path.join(root_path, d)))
    classes_to_idx = {c: i for i, c in enumerate(classes)}

    all_samples = []
    for idx, cl in enumerate(classes):
        path = os.path.join(root_path, cl)
        # Sorted so the sample order (and therefore the batch order) is reproducible.
        # Nested directories are skipped because they cannot be read as images.
        all_files = sorted(f for f in os.listdir(path) if os.path.isfile(os.path.join(path, f)))
        all_samples.extend([[f, idx] for f in all_files])

    return classes, classes_to_idx, np.array(all_samples)
    
# all_classes, classes_to_idx, all_samples = get_dirs(train_path)

In [None]:
all_classes, classes_to_idx, train_samples = get_dirs(train_dir)
val_classes, _, validation_samples = get_dirs(val_dir)

# validation_samples holds indices into val_dir's own class list, but
# CustomDataset maps them back to folder names through all_classes (train_dir's
# list). Both lists must be identical, or validation labels are silently shifted.
assert val_classes == all_classes, "train and val directories contain different class folders"

In [None]:
# Class-name -> index mapping built by get_dirs from the training folders
classes_to_idx

In [None]:
import os
import random
import matplotlib.pyplot as plt
from skimage import io


def visualize_dataset_examples(dataset_path, classes, num_cols=5, title="Dataset Examples"):
    """
    Show one randomly chosen image per class in a grid.

    Args:
        dataset_path: Root folder with one sub-folder per class (e.g. train_dir, val_dir)
        classes: Class names, in the order they should appear in the grid
        num_cols: Number of grid columns
        title: Figure title

    Classes whose folder is missing or holds no .png/.jpg/.jpeg file leave an empty cell.
    """
    num_rows = -(-len(classes) // num_cols)  # ceiling division

    plt.figure(figsize=(3 * num_cols, 3 * num_rows))
    plt.suptitle(title, fontsize=16)

    for position, class_name in enumerate(classes, start=1):
        class_path = os.path.join(dataset_path, class_name)
        if not os.path.exists(class_path):
            continue

        candidates = [name for name in os.listdir(class_path)
                      if name.lower().endswith(('.png', '.jpg', '.jpeg'))]
        if not candidates:
            continue

        picture = io.imread(os.path.join(class_path, random.choice(candidates)))

        plt.subplot(num_rows, num_cols, position)
        plt.imshow(picture)
        plt.title(class_name, fontsize=10)
        plt.axis('off')

    # leave headroom at the top for the suptitle
    plt.tight_layout(rect=[0, 0, 1, 0.97])
    plt.show()

In [None]:
print("Visualizing training set examples (one per class)...")
visualize_dataset_examples(dataset_path=train_dir, classes=all_classes, title="Training Set Examples")

In [None]:
print("Visualizing validation set examples (one per class)...")
visualize_dataset_examples(dataset_path=val_dir, classes=all_classes, title="Validation Set Examples")

In [None]:
from collections.abc import Mapping


class CustomDataset(tdata.Dataset):
    """Map-style dataset over an ImageFolder-like tree, indexed by a `samples` array."""

    def __init__(self, root_path, samples, classes, transform=None, transform_args=()):
        """
        Args:
            root_path: folder holding one sub-folder per class, like `ImageFolder`.
            samples: array whose rows are `(file_name, class_idx)`; `file_name` is
                relative to its class folder and `class_idx` may be a string (as
                produced by `get_dirs`).
            classes: list of all class names; `classes[class_idx]` is the folder name.
            transform: optional callable applied to the loaded image. It can be a
                torch transform or a 🤗 image processor.
            transform_args: iterable of `(argument, value)` pairs passed to
                `transform` as keyword arguments. Defaults to none.
        """
        self.root_path = root_path
        self.samples = samples
        self.transform = transform
        self.classes = classes
        self.transform_args = dict(transform_args)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return `(image, class_idx)` for sample `idx`."""
        file_name, target = self.samples[idx]
        target = int(target)
        file_path = os.path.join(self.root_path, self.classes[target], file_name)
        img = io.imread(file_path)

        if self.transform is not None:
            img = self.transform(img, **self.transform_args)

        # 🤗 image processors return a batch-style mapping; unwrap the single image.
        # Plain transforms (and no transform) already yield the image itself.
        if isinstance(img, Mapping) and "pixel_values" in img:
            img = img["pixel_values"][0]

        return img, target

In [None]:
# Reverse lookup (index -> class name); passed to the model as id2label
idx_to_class = dict(enumerate(all_classes))
num_classes = len(all_classes)
str([num_classes, idx_to_class])

In [None]:
# Each image goes through the 🤗 processor with return_tensors='pt'
transform_args = [("return_tensors", 'pt')]

train_dataset = CustomDataset(
    root_path=train_dir,
    samples=train_samples,
    classes=all_classes,
    transform=image_processor,
    transform_args=transform_args,
)
valid_dataset = CustomDataset(
    root_path=val_dir,
    samples=validation_samples,
    classes=all_classes,
    transform=image_processor,
    transform_args=transform_args,
)
len(train_dataset), len(valid_dataset)

In [None]:
# Sanity check: shape of the first training image as returned by CustomDataset
train_dataset[0][0].shape

In [None]:
def denorm(tensor, mean=mean, std=std):
  """Undo mean/std normalisation so a normalised image can be plotted.

  `mean` and `std` become shape-(1, C) tensors, so they broadcast over the LAST
  axis of `tensor` (channel-last layout). The result is clipped to [0, 1].
  """
  mean_t = torch.tensor([mean])
  std_t = torch.tensor([std])
  return (tensor * std_t + mean_t).clamp(0, 1)

In [None]:
def plot_image(tensor, label=None, denormalise=True):
  """Show one image on the current matplotlib axes.

  Args:
    tensor: image in channel-last layout `(h, w, c)`. Callers permute
      `(c, h, w)` tensors before calling.
    label: optional title text.
    denormalise: if True, undo the processor normalisation with `denorm`, which
      broadcasts mean/std over the last axis (hence the channel-last layout).
  """
  plt.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)

  if label is not None:
    plt.title(str(label))

  if denormalise:
    tensor = denorm(tensor)

  plt.imshow(tensor)

In [None]:
def plot_grid(tensor, n_row):
  """Tile a `(batch_size, c, h, w)` batch into a single image and plot it.

  `n_row` is forwarded as make_grid's `nrow` (number of images per grid row).
  """
  grid = vutils.make_grid(tensor, n_row)
  # make_grid yields (c, H, W); plot_image expects channel-last
  plot_image(grid.permute(1, 2, 0), "Grid of Random ISL Images")

In [None]:
# Preview one training image. permute(1, 2, 0) moves the leading channel axis
# last, which is the (h, w, c) layout plot_image expects.
img = train_dataset[1][0].permute(1, 2, 0)
plot_image(img)

In [None]:
from torch.utils.data import BatchSampler, SequentialSampler

batch_size = 64

def get_loader(dataset, batch_size, num_workers=0, shuffle=False):
    """Simple utility function to instantiate a `DataLoader`.

    Args:
        dataset: map-style dataset to iterate.
        batch_size: samples per batch; the last batch may be smaller.
        num_workers: worker processes for loading (0 = load in the main process).
        shuffle: reshuffle the samples every epoch. Enable this for training data:
            get_dirs stores samples grouped by class, so sequential batches would
            be almost single-class.

    Returns:
        A `torch.utils.data.DataLoader`.
    """
    return tdata.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=False,
        num_workers=num_workers,
        # persistent workers are only allowed when there are workers (else it crashes)
        persistent_workers=num_workers > 0,
    )

train_loader = get_loader(train_dataset, batch_size, 2, shuffle=True)
valid_loader = get_loader(valid_dataset, batch_size, 2)

# check if the dataloader works fine by plotting a grid of images
images, labels = next(iter(train_loader))
print(images.shape)
plot_grid(images, 8)

In [None]:
class CvTTrainingModule(pl.LightningModule):
  """Lightning wrapper that fine-tunes a 🤗 `CvtForImageClassification` model.

  The backbone (`cvt.base_model`) starts frozen, so only the remaining layers
  (e.g. the classifier head) train. Call `unfreeze_model()` (e.g. from
  `UnfreezeCallback`) to train everything. The label maps come from the
  notebook-level `idx_to_class` / `classes_to_idx`, and the data loaders from
  the notebook-level `train_loader` / `valid_loader`.
  """

  def __init__(self, num_labels: int, hugging_model_name: str, lr: float = 1e-6):
    """
    Args:
      num_labels: number of output classes. ignore_mismatched_sizes=True lets the
        pretrained head be re-initialised when this differs from the checkpoint.
      hugging_model_name: 🤗 Hub id or local path of the CvT checkpoint.
      lr: AdamW learning rate.
    """
    super(CvTTrainingModule, self).__init__()
    self.cvt = CvtForImageClassification.from_pretrained(hugging_model_name, num_labels=num_labels,
                                                         id2label=idx_to_class, label2id=classes_to_idx,
                                                         ignore_mismatched_sizes=True)
    self.lr = lr
    # Freezing early layers initially
    for param in self.cvt.base_model.parameters():
        param.requires_grad = False

  def forward(self, pixel_values):
    """Return class logits for a `pixel_values` batch."""
    outputs = self.cvt(pixel_values)
    return outputs["logits"]

  def common_step(self, batch):
    """Return `(cross-entropy loss, accuracy)` for one `(images, labels)` batch."""
    images, labels = batch
    outputs = self.cvt(images)["logits"]

    loss = F.cross_entropy(outputs, labels)
    _, preds = torch.max(outputs, dim=1)
    num_correct = (preds == labels).sum().item()
    # divide by the real batch size: the last batch can be smaller than batch_size
    accuracy = num_correct / labels.size(0)

    return loss, accuracy

  def training_step(self, batch, batch_idx):
    loss, accuracy = self.common_step(batch)

    self.log("training_loss", loss)
    self.log("training_accuracy", accuracy)

    return loss

  def validation_step(self, batch, batch_idx):
    loss, accuracy = self.common_step(batch)

    self.log("val_loss", loss)
    self.log("val_accuracy", accuracy)

    return loss

  def configure_optimizers(self):
    """AdamW over all parameters, with the LR cut 10x whenever `val_loss` plateaus."""
    # Frozen backbone params are included on purpose. AdamW skips params whose
    # grad is None, and they start updating as soon as unfreeze_model() runs.
    optimizer = torch.optim.AdamW(self.parameters(), self.lr)
    # `verbose` is deprecated on recent PyTorch schedulers, so it is not passed.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.1, patience=5, min_lr=1e-9
    )
    return {'optimizer': optimizer, 'lr_scheduler': scheduler, 'monitor': 'val_loss'}

  def unfreeze_model(self):
    """Make every backbone parameter trainable again."""
    for param in self.cvt.base_model.parameters():
        param.requires_grad = True

  def train_dataloader(self):
    return train_loader

  def val_dataloader(self):
    return valid_loader

  def save_model(self, save_path: str):
    """Save the wrapped 🤗 model to `save_path` with `save_pretrained`."""
    self.cvt.save_pretrained(save_path)


# class CvTTrainingModule(pl.LightningModule):

#   def __init__(self, num_labels: int, hugging_model_name: str, lr: float = 5e-5, 
#                weight_decay: float = 1e-2, dropout_rate: float = 0.3, 
#                label_smoothing: float = 0.1, use_scheduler: bool = True):
#     super(CvTTrainingModule, self).__init__()
#     self.cvt = CvtForImageClassification.from_pretrained(hugging_model_name, num_labels=num_labels,
#                                                          id2label=idx_to_class, label2id=classes_to_idx,
#                                                          ignore_mismatched_sizes=True)
    
#     # CvT uses embed_dim, not hidden_size
#     # The last element in embed_dim list is the final dimension
#     feature_dim = self.cvt.config.embed_dim[-1]
    
  #   # Add dropout to classification head
  #   self.dropout = nn.Dropout(dropout_rate)
  #   self.classifier = nn.Linear(feature_dim, num_labels)
  #   self.cvt.classifier = nn.Identity()  # Remove the original classifier
    
  #   # Save hyperparameters
  #   self.lr = lr
  #   self.weight_decay = weight_decay
  #   self.label_smoothing = label_smoothing
  #   self.use_scheduler = use_scheduler
  #   self.save_hyperparameters(ignore=['cvt'])

  # def forward(self, pixel_values):
  #   features = self.cvt(pixel_values).logits  # Get features from the backbone
  #   features = self.dropout(features)  # Apply dropout
  #   logits = self.classifier(features)  # Apply classification head
  #   return logits

  # def common_step(self, batch):
  #   images, labels = batch
    
  #   # Get logits
  #   logits = self(images)
    
  #   # Apply label smoothing to loss
  #   l = F.cross_entropy(logits, labels, label_smoothing=self.label_smoothing)
    
  #   # Calculate accuracy
  #   _, preds = torch.max(logits, dim=1)
  #   num_correct = (preds == labels).sum().item()
  #   accuracy = num_correct / labels.size(0)
    
  #   return l, accuracy

  # def training_step(self, batch, batch_idx):
  #   l, accuracy = self.common_step(batch)
    
  #   # Logging
  #   self.log("training_loss", l)
  #   self.log("training_accuracy", accuracy)
    
  #   # Log learning rate
  #   opt = self.optimizers()
  #   if opt is not None:
  #       self.log('learning_rate', opt.param_groups[0]['lr'])
    
  #   return l

  # # In your CvTTrainingModule class:
  # def validation_step(self, batch, batch_idx):
  #   images, labels = batch
  #   outputs = self.cvt(images)["logits"]
    
  #   l = F.cross_entropy(outputs, labels, label_smoothing=self.label_smoothing)
  #   _, preds = torch.max(outputs, dim=1)
  #   num_correct = (preds == labels).sum().item()
  #   accuracy = num_correct / labels.size(0)
    
  #   # THIS IS THE KEY FIX - proper metric logging
  #   self.log("val_loss", l, on_step=False, on_epoch=True, prog_bar=True)
  #   self.log("val_accuracy", accuracy, on_step=False, on_epoch=True, prog_bar=True)
    
  #   return l

  # def configure_optimizers(self):
  #   # Add weight decay for regularization
  #   optimizer = torch.optim.AdamW(
  #       self.parameters(), 
  #       lr=self.lr,
  #       weight_decay=self.weight_decay
  #   )
    
  #   if not self.use_scheduler:
  #       return optimizer
    
  #   # Learning rate scheduler with warmup
  #   train_steps = len(train_loader) * self.trainer.max_epochs
  #   warmup_steps = int(0.1 * train_steps)  # 10% warmup
    
  #   scheduler = {
  #       'scheduler': torch.optim.lr_scheduler.OneCycleLR(
  #           optimizer,
  #           max_lr=self.lr,
  #           total_steps=train_steps,
  #           pct_start=0.1,  # Warmup percentage
  #           div_factor=25,  # Initial lr = max_lr/div_factor
  #           final_div_factor=1000,  # Final lr = initial_lr/final_div_factor
  #       ),
  #       'interval': 'step',
  #       'frequency': 1
  #   }
    
  #   return {'optimizer': optimizer, 'lr_scheduler': scheduler}

  # def on_before_optimizer_step(self, optimizer):
  #   # Gradient clipping to prevent explosion
  #   torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)
    
  # def train_dataloader(self):
  #   return train_loader

  # def val_dataloader(self):
  #   return valid_loader
      
  # def save_model(self, save_path: str):
  #   # Create directory if it doesn't exist
  #   os.makedirs(save_path, exist_ok=True)
  #   # Save the entire model
  #   torch.save(self.state_dict(), os.path.join(save_path, "model.pt"))
  #   # Save the backbone for HF compatibility
  #   self.cvt.save_pretrained(save_path, from_pt=True)


In [None]:
class UnfreezeCallback(pl.Callback):
    """Call `pl_module.unfreeze_model()` once, at the end of one chosen training epoch."""

    def __init__(self, unfreeze_at_epoch: int):
        super().__init__()
        # compared against trainer.current_epoch at each training-epoch end
        self.unfreeze_at_epoch = unfreeze_at_epoch

    def on_train_epoch_end(self, trainer, pl_module):
        epoch = trainer.current_epoch
        if epoch != self.unfreeze_at_epoch:
            return
        pl_module.unfreeze_model()
        print(f"Unfreezing model at epoch {epoch}")

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import os
import glob
from pytorch_lightning.callbacks import Callback

class LearningCurveCallback(Callback):
    """Record train/val loss and accuracy at every training-epoch end and plot them."""

    def __init__(self):
        super().__init__()
        self.train_losses = []
        self.val_losses = []
        self.train_accuracies = []
        self.val_accuracies = []

    @staticmethod
    def _metric_value(metrics, key):
        """Return `metrics[key]` as a float, or NaN when the metric was not logged.

        A metric can be missing, e.g. when no validation ran. NaN keeps the four
        histories the same length and shows up as a gap in the plot.
        """
        value = metrics.get(key)
        return float('nan') if value is None else float(value)

    def on_train_epoch_end(self, trainer, pl_module):
        # Values logged via self.log(...) in the LightningModule
        metrics = trainer.callback_metrics
        self.train_losses.append(self._metric_value(metrics, 'training_loss'))
        self.val_losses.append(self._metric_value(metrics, 'val_loss'))
        self.train_accuracies.append(self._metric_value(metrics, 'training_accuracy'))
        self.val_accuracies.append(self._metric_value(metrics, 'val_accuracy'))

    def plot_learning_curves(self):
        """Plot the recorded loss and accuracy histories side by side."""
        epochs = range(1, len(self.train_losses) + 1)

        # Create figure with two subplots
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

        # Plot losses
        ax1.plot(epochs, self.train_losses, 'b-', label='Training Loss')
        ax1.plot(epochs, self.val_losses, 'r-', label='Validation Loss')
        ax1.set_title('Loss Learning Curves')
        ax1.set_xlabel('Epochs')
        ax1.set_ylabel('Loss')
        ax1.legend()
        ax1.grid(True, linestyle='--', alpha=0.7)

        # Plot accuracies
        ax2.plot(epochs, self.train_accuracies, 'b-', label='Training Accuracy')
        ax2.plot(epochs, self.val_accuracies, 'r-', label='Validation Accuracy')
        ax2.set_title('Accuracy Learning Curves')
        ax2.set_xlabel('Epochs')
        ax2.set_ylabel('Accuracy')
        ax2.legend()
        ax2.grid(True, linestyle='--', alpha=0.7)

        plt.tight_layout()
        plt.show()

In [None]:
# Load the TensorBoard notebook extension and point it at the Lightning log folder
# (assumes the Trainer writes its default logs under /kaggle/working; confirm)
%load_ext tensorboard
%tensorboard --logdir /kaggle/working/lightning_logs/

In [None]:
# Define paths
custom_checkpoint_path = os.path.join(kaggle_write_path, 'custom.pth')
checkpoint_path = os.path.join(kaggle_write_path, "checkpoints")

learning_curve_callback = LearningCurveCallback()

# Define callbacks
# There is deliberately no `every_n_train_steps`. val_loss only changes when
# validation runs, so step-based saving would score new weights with a stale
# val_loss. Without it, checkpoints are considered once per epoch.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath=checkpoint_path,
    save_top_k=2,                           # Keep the 2 checkpoints with the lowest val_loss
    mode="min",                             # Minimize the monitored metric
    monitor="val_loss",                     # Monitor validation loss
    filename="{epoch:02d}-{val_loss:.2f}",  # Filename format
    verbose=True,                           # Log saving checkpoints
)

early_stop_callback = pl.callbacks.EarlyStopping(
    monitor='val_loss',  # Monitor validation loss
    patience=15,         # Stop after 15 validation checks (epochs here) with no improvement
    verbose=True,        # Log when stopping
    mode='min'           # Minimize validation loss
)

unfreeze_callback = UnfreezeCallback(unfreeze_at_epoch=5)

# Define model.
# The head size comes from the class folders found by get_dirs, so it always
# matches the label indices the datasets produce. hugging_model_name comes from
# the config cell, so the model matches the image processor loaded there.
num_classes = len(all_classes)

model = CvTTrainingModule(num_classes, hugging_model_name)

# Trainer arguments: one device. Use the GPU when one is visible, and fall back
# to the CPU instead of failing on machines without CUDA.
num_devices = 1
trainer_args = {
    "accelerator": "gpu" if torch.cuda.is_available() else "cpu",
    "devices": num_devices,
    "strategy": "auto",
    "log_every_n_steps": 5,
    "callbacks": [early_stop_callback, checkpoint_callback, learning_curve_callback, unfreeze_callback],
    "max_epochs": 100,
    "check_val_every_n_epoch": 1,  # Force validation every epoch
    "num_sanity_val_steps": 2,     # Run validation sanity checks
}

# Initialize Trainer
trainer = pl.Trainer(**trainer_args)

# Start training from scratch (no checkpoint to resume from)
trainer.fit(model, ckpt_path=None)

In [None]:
# Plot the per-epoch loss/accuracy values recorded by LearningCurveCallback during fit
learning_curve_callback.plot_learning_curves()

In [None]:
# Save the fine-tuned 🤗 CvT model via CvTTrainingModule.save_model (save_pretrained)
model.save_model('/kaggle/working/cvt_model') 

In [None]:
import os

# Collect every test image as (path, class_idx), plus the first image of each
# class as a small set for visual inspection. Indices come from classes_to_idx
# (built from the training folders), so they match the model head.
test_path = test_dir  # the folder the split cell wrote the test images to

samples = []
visual_samples = []
for class_name in sorted(os.listdir(test_path)):
    class_folder = os.path.join(test_path, class_name)
    if not os.path.isdir(class_folder):
        continue
    idx = classes_to_idx[class_name]  # Map class name to index
    # Only image files, using the same extensions the split cell copies
    class_samples = [
        (os.path.join(class_folder, f), idx)
        for f in sorted(os.listdir(class_folder))
        if f.lower().endswith(('.png', '.jpg', '.jpeg'))
    ]
    visual_samples.extend(class_samples[:1])
    samples.extend(class_samples)

print(f"Total samples: {len(samples)}")
print(f"Visual samples: {len(visual_samples)}")

In [None]:
def perform_predictions(samples, model):
    """Plot each sample with its actual and predicted class, and count correct predictions.

    Args:
        samples: sequence of `(image_path, target_idx)` pairs.
        model: module mapping a `pixel_values` batch to class logits.

    Returns:
        Number of samples whose arg-max prediction equals the target (0 if `samples` is empty).
    """
    if not samples:
        return 0

    num_correct = 0

    # squeeze=False keeps `axes` 2-D even for a single sample, so the zip below
    # also works when len(samples) == 1.
    fig, axes = plt.subplots(len(samples), squeeze=False, sharey=True, figsize=(5, 50))
    model.eval()
    # Send inputs to the device holding the weights (assumes a single device)
    device = next(model.parameters()).device
    with torch.no_grad():  # inference only
        for curr_ax, (p, idx) in zip(axes[:, 0], samples):
            img = io.imread(p)
            proc_img = image_processor(img, return_tensors="pt")["pixel_values"].to(device)
            outputs = model(proc_img)

            pred = int(torch.argmax(outputs, -1).item())
            if pred == idx:
                num_correct += 1

            target_class = all_classes[idx]
            curr_ax.title.set_text(f"Actual class: {target_class}, Predicted: {all_classes[pred]}")
            curr_ax.tick_params(bottom=False, left=False, labelleft=False,
                                labelright=False, labelbottom=False)
            curr_ax.imshow(img)

    plt.show()
    return num_correct

perform_predictions(visual_samples, model)

In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score,  confusion_matrix
import torch
import matplotlib.pyplot as plt
import skimage.io as io
from tqdm import tqdm  # Optional, for progress display
from PIL import Image
import seaborn as sns

all_preds = []
all_labels = []

def evaluate_model(samples, model, image_processor):
    """Predict every sample and print accuracy plus weighted precision/recall/F1.

    Args:
        samples: sequence of `(image_path, target_idx)` pairs.
        model: module mapping a `pixel_values` batch to class logits.
        image_processor: 🤗 image processor that turns each image into `pixel_values`.

    Returns:
        `(accuracy, precision, recall, f1)`.

    Side effects:
        Refills the module-level `all_preds` / `all_labels` lists (read by the
        confusion-matrix cell) with plain ints. The lists are cleared first, so
        calling this again does not double-count samples.
    """
    all_preds.clear()
    all_labels.clear()

    model.eval()
    # Send inputs to the device holding the weights (assumes a single device)
    device = next(model.parameters()).device
    with torch.no_grad():  # inference only: no autograd graphs
        for p, idx in tqdm(samples, desc="Evaluating"):
            img = io.imread(p)
            proc_img = image_processor(img, return_tensors="pt")["pixel_values"].to(device)
            outputs = model(proc_img)
            all_preds.append(int(torch.argmax(outputs, -1).item()))
            all_labels.append(int(idx))

    # Calculate accuracy, precision, recall, and F1 score
    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
    recall = recall_score(all_labels, all_preds, average='weighted', zero_division=0)
    f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)

    # Print the results
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")

    return accuracy, precision, recall, f1

In [None]:
# Run the model on every test sample and print accuracy / precision / recall / F1
evaluate_model(samples, model, image_processor)

In [None]:
# Plain ints, so the sets and sorting below also work if predictions were stored as tensors
cm_labels = [int(t) for t in all_labels]
cm_preds = [int(p) for p in all_preds]
# One shared, sorted class list for the matrix axes AND the tick labels, so they
# stay aligned even when a class occurs only among the predictions.
cm_classes = sorted(set(cm_labels) | set(cm_preds))
tick_names = [idx_to_class[i] for i in cm_classes]

cm = confusion_matrix(cm_labels, cm_preds, labels=cm_classes)
# Grow the figure with the number of classes so the annotated cells stay legible
fig_side = max(8, 0.35 * len(cm_classes))
plt.figure(figsize=(fig_side, 0.75 * fig_side))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=tick_names, yticklabels=tick_names)
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.title("Confusion Matrix")
plt.show()

In [None]:
!pip install torchview
import random
from transformers import AutoModel
from torchview import draw_graph

# Pick a random test image and preprocess it into a model-ready batch
random_sample = random.choice(samples)
image_path_random, label = random_sample

random_image = io.imread(image_path_random)
proc_rand_img = image_processor(random_image, return_tensors="pt")["pixel_values"]
# Sanity forward pass; nothing here needs gradients
with torch.no_grad():
    outputs = model(proc_rand_img)

# Trace the model on the same input and render its computation graph
model_graph = draw_graph(model, input_data=proc_rand_img)

model_graph.visual_graph