In [None]:
import os
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split, Dataset
from torchvision import datasets
from torchinfo import summary
from torchvision.transforms import ToTensor
from transformers import ViTForImageClassification, ViTImageProcessor
import random
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from tqdm import tqdm
from PIL import Image


In [None]:
import torch

# Select the compute device: the first CUDA GPU when one is visible, otherwise the CPU.
if not torch.cuda.is_available():
    print("No GPU available. Using CPU.")
    device = torch.device("cpu")
else:
    device = torch.device("cuda:0")

    # Report which GPU was picked and how much memory it has.
    gpu_properties = torch.cuda.get_device_properties(device)
    gpu_name = torch.cuda.get_device_name(device)
    total_memory = gpu_properties.total_memory / (1024 ** 3)  # bytes -> GB

    print(f"Using GPU: {gpu_name}")
    print(f"Total GPU memory: {total_memory:.2f} GB")


In [None]:
# Set random seeds for reproducibility
def set_seeds(seed=42):
    """Seed Python's `random`, NumPy and PyTorch (CPU and, when available, all CUDA devices).

    Args:
        seed: Integer seed applied to every generator. Defaults to 42.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# Seed BEFORE any weights are created: the replacement classifier below draws its
# initial values from torch's global RNG, so seeding afterwards would leave the
# head initialisation different on every run.
set_seeds()

# Load the pretrained ViT model and processor
model_name = "google/vit-base-patch16-224-in21k"
pretrained_vit = ViTForImageClassification.from_pretrained(model_name).to(device)
processor = ViTImageProcessor.from_pretrained(model_name)

# Modify the classifier head
class_names = ['angry', 'disgust', 'fear', 'happy', 'neutral', 'sad', 'surprise']
num_classes = len(class_names)
# Take the input width from the model config instead of hard-coding 768 so the
# cell keeps working if `model_name` is switched to a differently sized ViT.
pretrained_vit.classifier = nn.Linear(
    in_features=pretrained_vit.config.hidden_size,
    out_features=num_classes,
).to(device)


In [None]:
# Layer-by-layer overview of the model via torchinfo: input/output shapes,
# parameter counts and whether each layer is trainable.
summary(
    pretrained_vit,
    input_size=(32, 3, 224, 224),  # (batch_size, color_channels, height, width)
    row_settings=["var_names"],
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=20,
)

In [None]:
class CustomDataset(Dataset):
    """Wrap an `(image, label)` dataset and run every image through an image processor.

    Args:
        dataset: Indexable collection whose items are `(image, label)` pairs. Each
            image must expose a `.mode` attribute and a `.convert()` method
            (as PIL images do).
        processor: Callable invoked as `processor(images=img, return_tensors="pt")`
            that returns a mapping with a `"pixel_values"` entry.
    """

    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return `(pixel_values, label)` for sample `idx`.

        `pixel_values` is the processor output with its leading batch dimension
        squeezed away.
        """
        img, label = self.dataset[idx]

        # Bring every non-RGB image (grayscale 'L', palette 'P', 'RGBA', 'CMYK', ...)
        # to 3-channel RGB. The previous version only handled mode 'L' and did so by
        # manually stacking arrays; `convert` covers all modes in one call.
        if img.mode != "RGB":
            img = img.convert("RGB")

        pixel_values = self.processor(images=img, return_tensors="pt")["pixel_values"]
        return pixel_values.squeeze(0), label

def create_dataloaders(data_dir: str, processor: ViTImageProcessor, batch_size: int, validation_split: float=0.2, test_split: float=0.1, seed: int=123, num_workers: int=0):
    """Build train/validation/test dataloaders from an `ImageFolder` directory.

    The folder is loaded with `datasets.ImageFolder`, split randomly with a
    seeded generator, and every split is wrapped in `CustomDataset` so images
    are preprocessed by `processor` on access.

    Args:
        data_dir: Root directory passed to `datasets.ImageFolder`.
        processor: Image processor handed to `CustomDataset`.
        batch_size: Batch size used by all three dataloaders.
        validation_split: Fraction of the samples placed in the validation split.
        test_split: Fraction of the samples placed in the test split.
        seed: Seed of the generator that drives `random_split`. The default (123)
            reproduces the split used before this parameter existed.
        num_workers: Worker processes per dataloader (0 = load in the main process).

    Returns:
        `(train_dataloader, val_dataloader, test_dataloader, class_names)`, where
        `class_names` is `dataset.classes` of the underlying `ImageFolder`.

    Raises:
        ValueError: If a split fraction is outside `[0, 1)` or the two fractions
            leave no samples for training.
    """
    for split_name, fraction in (("validation_split", validation_split), ("test_split", test_split)):
        if not 0.0 <= fraction < 1.0:
            raise ValueError(f"{split_name} must be in [0, 1), got {fraction}")

    dataset = datasets.ImageFolder(data_dir)
    class_names = dataset.classes
    dataset_size = len(dataset)

    test_size = int(test_split * dataset_size)
    val_size = int(validation_split * dataset_size)
    train_size = dataset_size - val_size - test_size
    if train_size <= 0:
        # Fail here with a clear message instead of deep inside random_split / the sampler.
        raise ValueError(
            f"validation_split={validation_split} and test_split={test_split} "
            f"leave no training samples out of {dataset_size}"
        )

    train_subset, val_subset, test_subset = random_split(
        dataset,
        [train_size, val_size, test_size],
        generator=torch.Generator().manual_seed(seed),
    )

    # Only the training split is shuffled; evaluation order stays fixed.
    train_dataloader = DataLoader(
        CustomDataset(train_subset, processor),
        batch_size=batch_size, shuffle=True, num_workers=num_workers,
    )
    val_dataloader = DataLoader(
        CustomDataset(val_subset, processor),
        batch_size=batch_size, shuffle=False, num_workers=num_workers,
    )
    test_dataloader = DataLoader(
        CustomDataset(test_subset, processor),
        batch_size=batch_size, shuffle=False, num_workers=num_workers,
    )

    return train_dataloader, val_dataloader, test_dataloader, class_names



In [None]:

# Build the three dataloaders from the image folder
data_dir = 'Expw-F'
train_dataloader, val_dataloader, test_dataloader, class_names = create_dataloaders(
    data_dir, processor, batch_size=32
)

# Report how many samples ended up in each split
print("\n".join(
    f"{split} dataset size: {len(loader.dataset)}"
    for split, loader in (
        ("Training", train_dataloader),
        ("Validation", val_dataloader),
        ("Test", test_dataloader),
    )
))


In [None]:
# Sanity check: pull a single batch from the training dataloader, print the image
# and label tensor shapes, then stop (the `break` keeps this to one batch).
for images, labels in train_dataloader:
    print(images.shape, labels.shape)  # e.g. torch.Size([32, 3, 224, 224]) torch.Size([32]) — assumes the processor resizes to 224x224 and batch_size=32; TODO confirm
    break

In [None]:

# Undo the normalization with the SAME statistics that produced the tensors.
# The images were normalized by `processor` inside CustomDataset, so its own
# mean/std are read here rather than hard-coding ImageNet values, which do not
# match this processor and give a wrongly coloured preview.
mean = torch.tensor(processor.image_mean)
std = torch.tensor(processor.image_std)

# Get a batch of images
image_batch, label_batch = next(iter(train_dataloader))

# Get a single image from the batch
image, label = image_batch[0], label_batch[0]

# Unnormalize the image (per-channel: x * std + mean) and clamp to [0, 1] so
# floating-point round-off cannot push values outside the range imshow expects.
image = (image * std[:, None, None] + mean[:, None, None]).clamp(0, 1)

# View the image shape and its label
print(image.shape, label)

# Plot image with matplotlib
plt.imshow(image.permute(1, 2, 0)) # rearrange image dimensions to suit matplotlib [color_channels, height, width] -> [height, width, color_channels]
plt.title(class_names[label])
plt.axis('off') # Turn off axis
plt.show()


In [None]:
# Cross-entropy objective for the multi-class head, optimised with Adam over
# every parameter of the model.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(pretrained_vit.parameters(), lr=1e-4)


In [None]:
def train_model(model, train_dataloader, val_dataloader, optimizer, loss_fn, epochs, device, checkpoint_dir="checkpointsADAM"):
    """Train `model` for `epochs` epochs, validating after each one.

    Each epoch runs one optimisation pass over `train_dataloader` and one
    gradient-free pass over `val_dataloader`. Whenever the validation accuracy
    is strictly higher than the best seen so far, `save_checkpoint` is called.

    Args:
        model: Module whose forward output exposes a `.logits` attribute.
        train_dataloader: Iterable of `(images, labels)` training batches.
        val_dataloader: Iterable of `(images, labels)` validation batches.
        optimizer: Optimizer stepping the model parameters.
        loss_fn: Callable `loss_fn(logits, labels)` returning a scalar loss tensor.
        epochs: Number of epochs to run.
        device: Device the model and every batch are moved to.
        checkpoint_dir: Directory forwarded to `save_checkpoint`.

    Returns:
        `(train_losses, train_accuracies, val_losses, val_accuracies)` — four
        lists with one entry per epoch. Losses are the mean of the per-batch
        loss values; accuracies are correct predictions / samples seen.
    """
    history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}
    best_accuracy = 0.0

    model.to(device)

    for epoch in range(epochs):
        # ---------------- training pass ----------------
        model.train()
        loss_sum, n_correct, n_seen = 0.0, 0, 0

        with tqdm(total=len(train_dataloader), desc=f"Epoch {epoch+1}/{epochs}", unit='batch') as pbar:
            for images, labels in train_dataloader:
                images = images.to(device)
                labels = labels.to(device)

                # Standard step: clear grads -> forward -> loss -> backward -> update
                optimizer.zero_grad()
                logits = model(images).logits
                loss = loss_fn(logits, labels)
                loss.backward()
                optimizer.step()

                loss_sum += loss.item()
                n_seen += labels.size(0)
                n_correct += (torch.max(logits, 1).indices == labels).sum().item()

                # Running mean over the batches processed so far (pbar.n is still the
                # count *before* this batch, hence the +1).
                pbar.set_postfix(loss=loss_sum / (pbar.n + 1))
                pbar.update(1)

        epoch_train_loss = loss_sum / len(train_dataloader)
        train_accuracy = n_correct / n_seen
        history["train_loss"].append(epoch_train_loss)
        history["train_acc"].append(train_accuracy)
        print(f"Epoch [{epoch+1}/{epochs}], Train Loss: {epoch_train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}")

        # ---------------- validation pass ----------------
        model.eval()
        loss_sum, n_correct, n_seen = 0.0, 0, 0

        with torch.no_grad():
            for images, labels in val_dataloader:
                images = images.to(device)
                labels = labels.to(device)
                logits = model(images).logits
                loss_sum += loss_fn(logits, labels).item()

                n_seen += labels.size(0)
                n_correct += (torch.max(logits, 1).indices == labels).sum().item()

        epoch_val_loss = loss_sum / len(val_dataloader)
        val_accuracy = n_correct / n_seen
        history["val_loss"].append(epoch_val_loss)
        history["val_acc"].append(val_accuracy)
        print(f"Validation Loss: {epoch_val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}")

        # Keep a checkpoint only when validation accuracy beats the best so far
        if val_accuracy > best_accuracy:
            best_accuracy = save_checkpoint(model, epoch, val_accuracy, best_accuracy, checkpoint_dir)

    return history["train_loss"], history["train_acc"], history["val_loss"], history["val_acc"]

# Checkpoint function to save the model
def save_checkpoint(model, epoch, accuracy, best_accuracy, checkpoint_dir="checkpointsADAM"):
    """Write the model's `state_dict` to `checkpoint_dir` and return `accuracy`.

    The file is named `model_epoch_<epoch+1>_acc_<accuracy to 4 decimals>.pth`.

    Args:
        model: Object exposing `state_dict()`.
        epoch: Zero-based epoch index; the file name uses `epoch + 1`.
        accuracy: Accuracy achieved by this checkpoint; embedded in the file name.
        best_accuracy: Unused. Kept so existing call sites remain valid.
        checkpoint_dir: Target directory, created (including parents) if missing.

    Returns:
        `accuracy`, so the caller can store it as the new best value.
    """
    # exist_ok=True replaces the exists()/makedirs() pair, which could fail if the
    # directory appeared between the check and the creation.
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, f"model_epoch_{epoch+1}_acc_{accuracy:.4f}.pth")
    torch.save(model.state_dict(), checkpoint_path)
    print(f"Checkpoint saved at {checkpoint_path}")
    return accuracy  # Update the best accuracy


In [None]:
# Fine-tune the ViT for 5 epochs and keep the per-epoch metric histories
train_losses, train_accuracies, val_losses, val_accuracies = train_model(  # type: ignore
    model=pretrained_vit,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    optimizer=optimizer,
    loss_fn=loss_fn,
    epochs=5,
    device=device,
)


In [None]:

# Plotting Loss and Accuracy
def plot_metrics(train_losses, val_accuracies, val_losses=None, train_accuracies=None):
    """Plot per-epoch loss (left panel) and accuracy (right panel).

    Args:
        train_losses: Training loss per epoch; its length defines the epoch axis.
        val_accuracies: Validation accuracy per epoch.
        val_losses: Optional validation loss per epoch, drawn next to the
            training loss when given.
        train_accuracies: Optional training accuracy per epoch, drawn next to
            the validation accuracy when given.

    With only the two required arguments the figure is the same as before the
    optional series were introduced (one curve per panel, original titles).
    """
    epochs = range(1, len(train_losses) + 1)

    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))

    # Loss panel
    loss_ax.plot(epochs, train_losses, label='Training Loss', color='blue')
    if val_losses is not None:
        loss_ax.plot(epochs, val_losses, label='Validation Loss', color='orange')
        loss_ax.set_title('Training and Validation Loss vs Epochs')
    else:
        loss_ax.set_title('Training Loss vs Epochs')
    loss_ax.set_ylabel('Loss')

    # Accuracy panel
    if train_accuracies is not None:
        acc_ax.plot(epochs, train_accuracies, label='Training Accuracy', color='blue')
        acc_ax.set_title('Training and Validation Accuracy vs Epochs')
    else:
        acc_ax.set_title('Validation Accuracy vs Epochs')
    acc_ax.plot(epochs, val_accuracies, label='Validation Accuracy', color='orange')
    acc_ax.set_ylabel('Accuracy')

    # Shared axis decoration: one tick per epoch on both panels
    for ax in (loss_ax, acc_ax):
        ax.set_xlabel('Epochs')
        ax.set_xticks(list(epochs))
        ax.legend()

    fig.tight_layout()
    plt.show()

# Call the plot function with all four histories returned by train_model
plot_metrics(train_losses, val_accuracies, val_losses=val_losses, train_accuracies=train_accuracies)

In [None]:

# Function to plot confusion matrix
def plot_confusion_matrix(model, dataloader, device):
    """Run `model` over `dataloader` and display the resulting confusion matrix.

    Args:
        model: Module whose forward output exposes `.logits`.
        dataloader: Iterable of `(images, labels)` batches.
        device: Device each batch is moved to before the forward pass.

    Reads the module-level `num_classes` (label range of the matrix) and
    `class_names` (tick labels of the plot).
    """
    model.eval()
    y_true, y_pred = [], []

    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device)
            labels = labels.to(device)
            logits = model(images).logits
            batch_pred = torch.max(logits, 1).indices  # index of the largest logit per sample

            y_true.extend(labels.cpu().numpy())
            y_pred.extend(batch_pred.cpu().numpy())

    # Rows are true classes, columns are predictions, over labels 0..num_classes-1
    cm = confusion_matrix(y_true, y_pred, labels=range(num_classes))

    ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names).plot(cmap=plt.cm.Blues)
    plt.title("Confusion Matrix")
    plt.show()

# Disabled snippet: restore the best saved weights before evaluating.
# NOTE(review): `best_checkpoint_path` is not assigned in the visible cells; point it
# at one of the files written by save_checkpoint before re-enabling.
#
# pretrained_vit.load_state_dict(torch.load(best_checkpoint_path))
# print(f"Loaded best checkpoint from {best_checkpoint_path}")

# Draw the confusion matrix for the validation set
plot_confusion_matrix(model=pretrained_vit, dataloader=val_dataloader, device=device)


In [None]:
import pandas as pd

# Gather the per-epoch histories into columns (epochs are numbered from 1)
data = dict(
    epoch=range(1, len(train_losses) + 1),
    train_loss=train_losses,
    train_accuracy=train_accuracies,
    val_loss=val_losses,
    val_accuracy=val_accuracies,
)
csv_filename = 'training_metrics_ADAM.csv'

# Tabulate and write to disk without the index column
df = pd.DataFrame(data)
df.to_csv(csv_filename, index=False)

print(f"Training metrics saved to {csv_filename}")
