# Urban Retrofitting Detection using Multi-Modal Temporal Vision Transformer **(Ablation - No Demographic Component)**

This notebook implements a deep learning pipeline for detecting urban retrofitting changes from temporal (before/after) street view images. It is an ablation of the full multi-modal model: the demographic component is removed, so the classifier combines only Vision Transformer (ViT) features from the before/after images to assign each location to one of five urban-change categories.

## Table of Contents
1. Installation and Setup
2. Package Imports
3. System Configuration
4. Dataset and Model Architecture
5. Training and Validation Functions
6. Configuration and Hyperparameters
7. Data Loading and Preprocessing
8. Data Visualization
9. Model Training
10. Evaluation and Results
11. Prediction Export


## 1. Installation and Setup

Install required Python packages for deep learning, computer vision, and geospatial analysis.


In [None]:
# install python dependencies

# %pip install torch torchvision timm pandas geopandas numpy opencv-python scikit-learn matplotlib folium

## 2. Package Imports

Import all necessary libraries for data processing, model training, visualization, and geospatial operations.


In [None]:
# import python packages

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler, Subset
import torchvision.transforms as T
import timm
import pandas as pd
import geopandas as gpd
import numpy as np
import cv2
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.cluster import KMeans
import folium
from collections import Counter
import matplotlib.pyplot as plt
import seaborn as sns
import gc
from datetime import datetime
import logging
from sklearn.utils.class_weight import compute_class_weight
from PIL import Image

## 3. System Configuration

### 3.1 GPU Availability Check

Check NVIDIA GPU availability and specifications for model training acceleration.


In [None]:
# Show NVIDIA driver / GPU status (IPython shell escape, not plain Python).
!nvidia-smi

### 3.2 Memory Management

Clear GPU memory and cache to free up resources before starting training.


In [None]:
# clear memory and cache
# Run the garbage collector first so unreferenced tensors are actually freed,
# then ask PyTorch's CUDA caching allocator to release its unused cached memory.
# Objects that are still referenced (e.g. a model from a previous run) must be
# deleted first -- uncomment one of the `del` lines below for that.

# del model, optimizer, images, labels, outputs
# del model, optimizer
gc.collect()
torch.cuda.empty_cache()

### 3.3 Device Configuration

Set the computation device (CUDA GPU or CPU) for PyTorch operations.


In [None]:
# Pick the torch compute device: the default CUDA GPU when one is visible, CPU otherwise.

_cuda_visible = torch.cuda.is_available()
device = torch.device('cuda') if _cuda_visible else torch.device('cpu')
device

## 4. Dataset and Model Architecture

### 4.1 Custom Dataset Class

Define the `UrbanRetrofittingDataset` class that loads temporal image pairs (before/after) and labels. The dataset handles coordinate precision, image loading, and label encoding.


In [None]:
# Pair each location's before/after street-view images with its encoded change label.
# (This ablation variant loads no demographic data.)

class UrbanRetrofittingDataset(Dataset):
    """Before/after street-view image pairs with integer-encoded change labels.

    Each spreadsheet row describes one location and heading ('ID', 'Latitude',
    'Longitude', 'Heading') plus two capture dates ('Start', 'End'). Images are
    read from ``<image_dir>/<ID>/<lat>_<lon>_<heading>/<Mon YYYY>.png``.
    """

    def __init__(self, excel_file, image_dir, transform=None):
        """
        Args:
            excel_file (str): Path to the Excel file with annotations.
            image_dir (str): Directory with all the images.
            transform (callable, optional): Optional transform applied to each image of the pair.
        """
        self.data = pd.read_excel(excel_file, dtype={"Latitude": str, "Longitude": str})
        # maintains precision. Do not change this two step conversion.
        self.data = self.data.astype({"Latitude": np.longdouble, "Longitude": np.longdouble})
        self.image_dir = image_dir
        self.transform = transform
        self.label_encoder = LabelEncoder()
        # Replace class names by integer ids; label_encoder.classes_[id] gives the name back.
        self.data['Change'] = self.label_encoder.fit_transform(self.data['Change'])

    def __len__(self):
        return len(self.data)

    def _location_dir(self, row):
        """Return the folder holding every capture for this row's location and heading."""
        # Do not use f{} formatting to add the longdouble coordinates: plain str()
        # keeps the full-precision text that matches the folder names on disk.
        location = str(row['Latitude']) + "_" + str(row['Longitude']) + "_" + str(row['Heading'])
        return os.path.join(self.image_dir, str(row['ID']), location)

    @staticmethod
    def _capture_filename(date):
        """Return the image filename for a capture date, e.g. 'Jan 2019.png'."""
        return f"{pd.to_datetime(date).strftime('%b %Y')}.png"

    @staticmethod
    def _load_rgb(path):
        """Open an image, convert it to RGB and close the underlying file handle."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, idx):
        """Return ``(start_image, end_image, label)`` for row ``idx``.

        The images are PIL RGB images, or whatever ``transform`` returns for them;
        ``label`` is the encoded class id.
        """
        row = self.data.iloc[idx]
        location_dir = self._location_dir(row)
        start_image_path = os.path.join(location_dir, self._capture_filename(row['Start']))
        end_image_path = os.path.join(location_dir, self._capture_filename(row['End']))
        start_image = self._load_rgb(start_image_path)
        end_image = self._load_rgb(end_image_path)

        if self.transform:
            start_image = self.transform(start_image)
            end_image = self.transform(end_image)

        return start_image, end_image, row['Change']

### 4.2 Vision Transformer Model Architecture

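Implement the `ViTWithSocioEconomic` model. The class name is inherited from the full multi-modal model; in this ablation it receives no socio-economic input. It combines: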
- Pretrained ViT for extracting image features from start and end images
- Difference features (end - start) to capture temporal changes
- Classification head for 5-class urban retrofitting detection


In [None]:
# Model architecture

class ViTWithSocioEconomic(nn.Module):
    """Siamese ViT change classifier over a (start, end) image pair.

    The same backbone embeds both images; the classifier sees the start
    embedding, the end embedding and their difference side by side.
    NOTE: despite the class name, this ablation variant takes no
    socio-economic input.
    """

    def __init__(self, num_classes):
        super().__init__()
        # timm ViT with its classification head removed (num_classes=0) so that the
        # forward pass yields image embeddings of width `num_features`.
        backbone = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=0)
        self.vit = backbone
        fused_width = backbone.num_features * 3  # start | end | end - start

        self.classifier = nn.Sequential(
            nn.Linear(fused_width, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        )

    def forward(self, start_image, end_image):
        """Return class logits for a batch of (start, end) image pairs."""
        before, after = self.vit(start_image), self.vit(end_image)
        fused = torch.cat((before, after, after - before), dim=1)
        return self.classifier(fused)

## 5. Training and Validation Functions

### 5.1 Training Function

Implement the training loop with per-class accuracy and loss tracking. This function processes batches, computes gradients, and updates model parameters while monitoring performance metrics for each class.


In [None]:
# Training function with per-class accuracy and loss tracking

def train_epoch(model, loader, criterion, optimizer, device, num_classes):
    """
    Run one training epoch and collect overall and per-class metrics.

    Args:
        model (nn.Module): called as ``model(start_imgs, end_imgs)``; its output is
            treated as logits with classes along dim 1.
        loader: iterable of ``(start_imgs, end_imgs, labels)`` batches.
        criterion: loss module used for optimisation; it is also reused on each
            class's subset of a batch to report per-class losses.
        optimizer: optimizer that updates the model parameters.
        device: device every batch is moved to.
        num_classes (int): number of classes to report metrics for.

    Returns:
        tuple: ``(epoch_loss, overall_acc, class_acc, class_avg_loss)`` where
        ``epoch_loss`` is the sample-weighted mean batch loss, ``overall_acc`` is
        top-1 accuracy in percent, and the two lists hold per-class accuracy (%)
        and mean loss (0 for classes that did not occur). An empty loader
        yields zeros instead of raising ZeroDivisionError.
    """
    model.train()
    total_samples = 0
    running_loss = 0.0
    correct_total = 0

    class_correct = [0] * num_classes
    class_total = [0] * num_classes
    class_losses = [0.0] * num_classes

    for start_imgs, end_imgs, labels in loader:
        start_imgs, end_imgs = start_imgs.to(device), end_imgs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(start_imgs, end_imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        running_loss += loss.item() * batch_size
        total_samples += batch_size

        # Metrics only: detach so the per-class criterion calls below do not build
        # (and keep alive) an autograd graph that is never back-propagated.
        with torch.no_grad():
            logits = outputs.detach()
            preds = logits.argmax(dim=1)
            correct_total += (preds == labels).sum().item()

            for cls in range(num_classes):
                cls_mask = labels == cls
                n_cls = int(cls_mask.sum().item())  # single host sync per class
                if n_cls == 0:
                    continue
                cls_labels = labels[cls_mask]
                class_losses[cls] += criterion(logits[cls_mask], cls_labels).item() * n_cls
                class_correct[cls] += (preds[cls_mask] == cls_labels).sum().item()
                class_total[cls] += n_cls

    epoch_loss = running_loss / total_samples if total_samples else 0.0
    overall_acc = 100 * correct_total / total_samples if total_samples else 0.0

    class_acc = [100 * class_correct[i] / class_total[i] if class_total[i] else 0 for i in range(num_classes)]
    class_avg_loss = [class_losses[i] / class_total[i] if class_total[i] else 0 for i in range(num_classes)]

    return epoch_loss, overall_acc, class_acc, class_avg_loss

### 5.2 Validation Function

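Implement validation with a top-2 accuracy metric. This function evaluates the model on validation data and computes per-class metrics. A prediction counts as correct if the true label is among the top 2 predicted classes. For later analysis, the stored prediction for each sample is the true label when it is in the top 2, and the top-1 class otherwise. Reports built from these stored predictions therefore reflect top-2 accuracy.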


In [None]:
# Validation function with top-2 accuracy and per-class metrics

def validate_epoch(model, loader, criterion, device, num_classes):
    """
    Evaluate the model with top-2 accuracy and per-class metrics.

    A sample counts as correct when its true label is among the two highest
    logits (top-1 when the model emits fewer than two logits). For the returned
    prediction list, a hit is recorded as the true label and a miss as the
    top-1 class, so metrics computed downstream from ``all_preds``
    (classification report, confusion matrix) are top-2 based.

    Args:
        model (nn.Module): called as ``model(start_imgs, end_imgs)``; its output is
            treated as logits with classes along dim 1.
        loader: iterable of ``(start_imgs, end_imgs, labels)`` batches.
        criterion: loss module; also applied per class to report per-class losses.
        device: device every batch is moved to.
        num_classes (int): number of classes to report metrics for.

    Returns:
        tuple: ``(epoch_loss, overall_acc, class_acc, class_avg_loss, all_preds,
        all_labels)``. Accuracies are in percent; per-class entries are 0 for
        classes that did not occur. An empty loader yields zeros and empty lists
        instead of raising ZeroDivisionError.
    """
    model.eval()
    total_samples = 0
    running_loss = 0.0
    correct_total = 0

    class_correct = [0] * num_classes
    class_total = [0] * num_classes
    class_losses = [0.0] * num_classes
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for start_imgs, end_imgs, labels in loader:
            start_imgs, end_imgs = start_imgs.to(device), end_imgs.to(device)
            labels = labels.to(device)

            outputs = model(start_imgs, end_imgs)
            loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            running_loss += loss.item() * batch_size
            total_samples += batch_size

            # topk(2) would fail if the model produced a single logit.
            k = min(2, outputs.size(1))
            _, preds_topk = outputs.topk(k, dim=1)
            correct_topk = preds_topk.eq(labels.unsqueeze(1)).any(dim=1)
            # Credit top-k hits with the true label; otherwise keep the top-1 guess.
            stored_pred_batch = torch.where(correct_topk, labels, preds_topk[:, 0])

            all_preds.extend(stored_pred_batch.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            correct_total += correct_topk.sum().item()

            for cls in range(num_classes):
                cls_mask = labels == cls
                n_cls = int(cls_mask.sum().item())  # single host sync per class
                if n_cls == 0:
                    continue
                class_losses[cls] += criterion(outputs[cls_mask], labels[cls_mask]).item() * n_cls
                class_correct[cls] += correct_topk[cls_mask].sum().item()
                class_total[cls] += n_cls

    epoch_loss = running_loss / total_samples if total_samples else 0.0
    overall_acc = 100 * correct_total / total_samples if total_samples else 0.0

    class_acc = [100 * class_correct[i] / class_total[i] if class_total[i] else 0 for i in range(num_classes)]
    class_avg_loss = [class_losses[i] / class_total[i] if class_total[i] else 0 for i in range(num_classes)]

    return epoch_loss, overall_acc, class_acc, class_avg_loss, all_preds, all_labels

## 6. Configuration and Hyperparameters

Set up directories for checkpoints and logs, define model name, and configure hyperparameters including batch size, learning rate, number of epochs, and dataset paths.


In [None]:
# create a directory for saving and loading checkpoints.
checkpoint_dir = "ViT_checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)
MODEL_NAME = "ablated_demographic_component_model"

# Hyperparameters
BATCH_SIZE = 256
NUM_EPOCHS = 30
LEARNING_RATE = 0.0001
WEIGHT_DECAY_RATE = 0.00001
NUM_CLASSES = 5  # Five aspects of urban retrofitting

# Checkpoint loaded by default: the one saved after the final epoch, so it
# follows NUM_EPOCHS instead of a hard-coded epoch number.
CHECKPOINT_PATH = f"{checkpoint_dir}/{MODEL_NAME}_epoch_{NUM_EPOCHS}.pth"

# directory for storing logs
logs_dir = "logs"
os.makedirs(logs_dir, exist_ok=True)
LOG_FILE = f"{logs_dir}/{MODEL_NAME}.txt"

# Output path for the saved model figure; create its folder so saving it cannot
# fail with a missing-directory error.
graphs_dir = "Model Graphs"
os.makedirs(graphs_dir, exist_ok=True)
GRAPH_PATH = f"{graphs_dir}/{MODEL_NAME}.png"

# Dataset and DataLoader
excel_File = "labels.xlsx"
image_dir = "Streetview_data"

### 6.1 Logging Configuration

Configure logging to write training progress to both a file and console for monitoring and debugging purposes.


In [None]:
# Configure logging to file and console

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        # Append mode
        logging.FileHandler(LOG_FILE, mode='a'),
        # Also print to console
        logging.StreamHandler()
    ],
    # Without force, basicConfig silently does nothing when the root logger
    # already has handlers (e.g. when this cell is re-run in the same kernel,
    # possibly after changing MODEL_NAME/LOG_FILE), so logs would keep going to
    # the old file. force=True closes and replaces the existing handlers.
    force=True,
)
logger = logging.getLogger(__name__)

## 7. Data Loading and Preprocessing

### 7.1 Data Transformations and Spatial Stratified Split

Load the dataset with appropriate transformations:
- **Training transforms**: Include data augmentation (random flips, rotations, color jitter) to improve generalization
- **Validation transforms**: Only resize and normalize (no augmentation)

Implement spatial stratified sampling to ensure:
- Geographic diversity: Samples are split across spatial clusters
- Class balance: Each cluster maintains proportional class distribution
- Reproducibility: Fixed random seed for consistent splits


In [None]:
# Data Loading

# Seed NumPy and PyTorch so augmentation and sampling are repeatable.
np.random.seed(42)
torch.manual_seed(42)

# Shared preprocessing constants: ImageNet mean/std normalisation at 224x224.
# NOTE(review): timm's default config for vit_base_patch16_224 may use
# mean=std=0.5 instead -- confirm with timm.data.resolve_data_config before
# changing, since existing checkpoints were trained with the values below.
_INPUT_SIZE = (224, 224)
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize, light geometric/photometric augmentation, normalise.
train_transform = T.Compose(
    [
        T.Resize(_INPUT_SIZE),
        T.RandomHorizontalFlip(),
        T.RandomRotation(10),
        T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        T.ToTensor(),
        T.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)

# Validation pipeline: deterministic resize and normalise only.
val_transform = T.Compose(
    [
        T.Resize(_INPUT_SIZE),
        T.ToTensor(),
        T.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)

# Spatial stratified sampling split function 

def spatial_stratified_split(dataset, val_ratio=0.2, n_clusters=10):
    """
    Split a dataset into spatially stratified train/validation index lists.

    Rows without coordinates are dropped. The remaining points are grouped into
    ``n_clusters`` KMeans clusters on (Latitude, Longitude). Within every
    (cluster, class) group, ``int(len(group) * val_ratio)`` samples (at least one)
    go to validation and the rest to training. KMeans and the per-group sampling
    both use random_state=42, so running this on two datasets loaded from the
    same file yields identical splits.

    NOTE: ``dataset.data`` is replaced in place by the filtered frame, re-indexed
    0..n-1 and given a 'Cluster' column. The returned indices are positions in
    that frame, suitable for ``torch.utils.data.Subset``.
    NOTE: n_clusters = 4 included at least 1 sample per class in each cluster.

    Args:
        dataset (UrbanRetrofittingDataset): dataset whose ``.data`` has
            'Latitude', 'Longitude' and encoded 'Change' columns. Mutated in place.
        val_ratio (float): fraction of each (cluster, class) group sent to validation.
        n_clusters (int): number of spatial clusters.

    Returns:
        tuple: ``(train_indices, val_indices, gdf)`` -- two lists of row positions
        and a GeoDataFrame (EPSG:4326 points) of the filtered data with 'Cluster'.
    """
    df = dataset.data.copy()

    # Drop rows without coordinates and re-index *before* collecting indices, so
    # the indices stay valid positions once dataset.data is replaced below
    # (previously they were taken from the pre-reset index and went stale as soon
    # as dropna removed any row).
    df = df.dropna(subset=['Latitude', 'Longitude'])
    df = df.reset_index(drop=True)

    # Spatial clustering
    coords = df[['Latitude', 'Longitude']].values
    kmeans = KMeans(n_clusters=n_clusters, random_state=42)
    df['Cluster'] = kmeans.fit_predict(coords)

    # Stratified sampling within each spatial cluster
    train_indices, val_indices = [], []

    for cluster in df['Cluster'].unique():
        cluster_df = df[df['Cluster'] == cluster]
        for label in cluster_df['Change'].unique():
            label_df = cluster_df[cluster_df['Change'] == label]
            # At least one validation sample per group, even for tiny groups.
            n_val = max(1, int(len(label_df) * val_ratio))
            val_sample = label_df.sample(n=n_val, random_state=42)
            train_sample = label_df.drop(val_sample.index)
            train_indices.extend(train_sample.index)
            val_indices.extend(val_sample.index)

    dataset.data = df

    # Convert to GeoDataFrame
    gdf = gpd.GeoDataFrame(df, geometry=gpd.points_from_xy(df['Longitude'], df['Latitude']), crs='EPSG:4326')

    return train_indices, val_indices, gdf

logger.info("Data Loading Initiated!")
# Load the spreadsheet twice so each copy carries its own transform
# (augmentation for training, plain resize/normalise for validation).
train_dataset = UrbanRetrofittingDataset(excel_File, image_dir, train_transform)
val_dataset = UrbanRetrofittingDataset(excel_File, image_dir, val_transform)

# The split uses fixed random states, so both calls produce the same indices;
# the second call only brings val_dataset.data into the same filtered,
# re-indexed, cluster-annotated form.
train_indices, val_indices, clusters_gdf = spatial_stratified_split(
    dataset=train_dataset, val_ratio=0.2, n_clusters=20
)
spatial_stratified_split(dataset=val_dataset, val_ratio=0.2, n_clusters=20)

# Restrict each copy to its own rows.
train_dataset, val_dataset = Subset(train_dataset, train_indices), Subset(val_dataset, val_indices)
logger.info("Split data into training and validation.")


### 7.2 Dataset Distribution Logging

Log the distribution of classes in training and validation sets to verify balanced splits and understand data composition.


In [None]:
# Log cluster distribution

# Encoded class label of every row in the shared, re-indexed spreadsheet.
all_labels = train_dataset.dataset.data['Change']

def log_distribution(all_labels, dataset, name):
    """Log how many samples of each class a Subset contains.

    Args:
        all_labels: encoded labels indexed like the rows of the Subset's parent dataset.
        dataset (Subset): subset whose ``indices`` select rows and whose parent
            dataset exposes the fitted ``label_encoder``.
        name (str): heading for the log block, e.g. "Train".
    """
    class_names = dataset.dataset.label_encoder.classes_
    tally = Counter(all_labels[row] for row in dataset.indices)
    logger.info(f"{name} distribution:")
    for encoded in sorted(tally):
        logger.info(f"\t{class_names[encoded]}: {tally[encoded]}")

log_distribution(all_labels, train_dataset, "Train")
log_distribution(all_labels, val_dataset, "Validation")

### 7.3 Training Set Cluster Analysis

Display the distribution of samples across spatial clusters and classes in the training set. This helps verify that the spatial stratified split maintains class representation in each cluster.


In [None]:
# Summary of clusters and class distribution for training set

# Training rows of the re-indexed spreadsheet, with their 'Cluster' assignment.
train_df = train_dataset.dataset.data.loc[train_indices]
# Cluster x encoded-class table of sample counts (absent combinations shown as 0).
counts = (
    train_df
    .groupby(['Cluster', 'Change'])
    .size()
    .unstack(fill_value=0)
)
logger.info("Training Sample counts per class in each cluster:")
logger.info("%s", counts)

### 7.4 Validation Set Cluster Analysis

Display the distribution of samples across spatial clusters and classes in the validation set to ensure proper validation coverage.


In [None]:
# Summary of clusters and class distribution for validation set

# Validation rows of the re-indexed spreadsheet, with their 'Cluster' assignment.
val_df = val_dataset.dataset.data.loc[val_indices]
# Cluster x encoded-class table of sample counts (absent combinations shown as 0).
counts = (
    val_df
    .groupby(['Cluster', 'Change'])
    .size()
    .unstack(fill_value=0)
)
logger.info("Validation Sample counts per class in each cluster:")
logger.info("%s", counts)

## 8. Data Visualization

### 8.1 Color Mapping Setup

Prepare color schemes for visualizing different urban retrofitting classes on maps. Each class is assigned a distinct color for easy identification.


In [None]:
import matplotlib.cm as cm
import matplotlib.colors as mcolors


# Class names in label-encoder order: position i is the name of encoded label i.
classes = train_dataset.dataset.label_encoder.classes_
num_classes = len(classes)

# Fixed fill colour per retrofit category; keys must match the class names above.
# (An automatic alternative is a 'tab10' colormap converted with mcolors.rgb2hex.)
class_colors = dict(
    [
        ('No Change', "#9caab6"),
        ('re-building', "#8243ab"),
        ('re-capital', '#2e8453'),
        ('re-inhabitation', '#08478c'),
        ('re-transportation', "#b82200"),
    ]
)

# One outline per spatial cluster: dissolve merges each cluster's points into a
# single geometry, and its convex hull becomes the cluster polygon.
cluster_polys = clusters_gdf[['Cluster', 'geometry']].dissolve(
    by='Cluster', as_index=False, aggfunc='first'
)
cluster_polys['geometry'] = cluster_polys.geometry.convex_hull
cluster_polys = gpd.GeoDataFrame(cluster_polys, geometry="geometry", crs='EPSG:4326')

### 8.2 Interactive Map Creation

Create interactive Folium maps showing:
- Spatial clusters as polygons (convex hulls)
- Sample locations colored by class
- Legend for class identification

This visualization helps understand the geographic distribution of training and validation data.


In [None]:
# Function to build map
def create_folium_cluster_map(data_gdf, class_colors, title):
    """Build an interactive Folium map of sample points over the spatial-cluster hulls.

    Reads the module-level ``cluster_polys`` (one convex-hull polygon per
    cluster) and ``classes`` (class names indexed by encoded label).

    Args:
        data_gdf (pd.DataFrame): samples with 'Latitude', 'Longitude', encoded
            'Change' and 'Cluster' columns.
        class_colors (dict): class name -> hex colour for point fills and the legend.
        title (str): heading shown at the top of the legend box (skipped if empty).

    Returns:
        folium.Map: the assembled map.
    """
    center = [data_gdf['Latitude'].mean(), data_gdf['Longitude'].mean()]
    m = folium.Map(location=center, tiles="Cartodb Positron", zoom_start=11, zoom_control=False)

    # Plot clusters (polygons), all with the same grey style.
    hull_style = {
        'fillColor': "#e7e7e7",
        'color': 'black',
        'weight': 1,
        'fillOpacity': 0.8
    }
    for _, row in cluster_polys.iterrows():
        folium.GeoJson(
            row['geometry'],
            style_function=lambda feature: dict(hull_style),
            tooltip=f"Cluster {row['Cluster']}"
        ).add_to(m)

    # Add class samples as points. Classes missing from class_colors get a
    # fallback colour instead of raising KeyError.
    fallback_color = "#000000"
    for _, row in data_gdf.iterrows():
        class_name = classes[row['Change']]
        color = class_colors.get(class_name, fallback_color)
        # Show the readable class name rather than its encoded integer id.
        popup = f"Class: {class_name}<br>Cluster: {row['Cluster']}"
        folium.CircleMarker(
            location=[row['Latitude'], row['Longitude']],
            radius=2.5,
            color="white",      # Outline color
            weight=0.5,         # Outline thickness
            fill=True,
            fill_color=color,
            fill_opacity=1,
            popup=popup
        ).add_to(m)

    # Add legend (headed by the map title, which was previously ignored).
    legend_html = "<div style='position: fixed; top: 20px; right: 20px; z-index: 9999; background-color: white; padding: 10px; border: 1px solid #ccc;'>"
    if title:
        legend_html += f"<b>{title}</b><br>"
    legend_html += "<b>Retrofit Category</b><br>"
    for cls, color in class_colors.items():
        legend_html += f"<i style='background:{color}; width:10px;height:10px;display:inline-block;border-radius:50%;'></i> {cls}<br>"
    legend_html += "</div>"
    m.get_root().html.add_child(folium.Element(legend_html))
    return m

# Generate and save maps
train_map = create_folium_cluster_map(train_df, class_colors, "Training Data Map")
val_map = create_folium_cluster_map(val_df, class_colors, "Validation Data Map")

### 8.3 Training Set Map Visualization

Display the interactive map of training samples with spatial clusters and class-colored points.


In [None]:
# visualize training set map.
# (The last expression of a notebook cell is rendered inline.)

train_map

### 8.4 Validation Set Map Visualization

Display the interactive map of validation samples with spatial clusters and class-colored points.


In [None]:
# visualize validation set map.
# (The last expression of a notebook cell is rendered inline.)

val_map

### 8.5 Training Set DataFrame Display

Display the training dataset DataFrame for inspection of data structure and content.


In [None]:
# display training set dataframe.
train_df

### 8.6 Validation Set DataFrame Display

Display the validation dataset DataFrame for inspection of data structure and content.


In [None]:
# display validation set dataframe (rows selected by val_indices).
val_df

### 8.7 Class Weights for Sampling

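Define per-class sampling weights for the weighted random sampler to address class imbalance during training. The weights are set manually rather than computed as inverse class frequencies, to avoid extreme weights for very rare classes. A sample's probability of being drawn is proportional to its class weight, so underrepresented classes are sampled more often.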


In [None]:

# Encoded label of every training sample, in Subset order.
all_labels = train_dataset.dataset.data['Change']
train_labels = [all_labels[i] for i in train_dataset.indices]
num_classes = len(set(train_labels))

# Per-class sampling weights (index = encoded label), chosen by hand (v21)
# rather than as 1 / class_count, to avoid extreme weights for very rare classes.
class_weights_sampler = np.array([0.0036, 0.0034, 0.0150, 0.0050, 0.0060])  # v21
# Each sample inherits its class's weight (vectorised lookup by encoded label).
sample_weights = class_weights_sampler[np.asarray(train_labels, dtype=np.intp)]

### 8.8 Weighted Random Sampler

Create a weighted random sampler that uses the computed class weights to balance the training data during batch creation. This helps the model learn from all classes more evenly.


In [None]:
# Weighted sampler for drawing training samples: each epoch draws as many
# indices as there are training samples, with replacement, with probability
# proportional to each sample's weight.
sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=sample_weights.shape[0],
    replacement=True,
)

### 8.9 Class Weights for Loss Function

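Define class weights for the loss function so that misclassifying rare classes is penalized more heavily. The weights are set manually rather than computed with scikit-learn's `balanced` heuristic, to avoid extreme values for very rare classes. This helps the model focus on underrepresented classes during optimization.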


In [None]:

# Loss-function class weights, indexed by encoded label. They are set by hand
# (v21) instead of sklearn's "balanced" weights, to avoid extreme values for very
# rare classes. Automatic alternative:
#   compute_class_weight(class_weight='balanced', classes=np.unique(train_labels), y=train_labels)
class_weights = np.array([1.66, 7.22, 5.56, 9.22, 9.34])  # v21

# float32 copy on the training device, passed to CrossEntropyLoss(weight=...).
weights_tensor = torch.as_tensor(class_weights, dtype=torch.float32, device=device)

logger.info("Class weights used for loss function:")
for label, weight in zip(train_dataset.dataset.label_encoder.classes_, class_weights):
    logger.info(f"\tClass {label}: Weight={weight:.4f}")

### 8.10 Data Loading Workers Configuration

Set the number of worker processes for parallel data loading to half of the available CPU cores. Using multiple workers speeds up image loading and preprocessing during training.


In [None]:
# to enable faster data loading using multiple workers for data loading.

import multiprocessing

# Use half of the logical CPUs as DataLoader worker processes
# (this came out to 48 on the original training machine).
_logical_cpus = multiprocessing.cpu_count()
num_workers = _logical_cpus // 2
num_workers

### 8.11 DataLoader Initialization

Create PyTorch DataLoaders for training and validation sets with:
- Batch size configuration
- Weighted sampling for training (to handle class imbalance)
- Multiple workers for parallel data loading
- Pinned memory for faster GPU transfer


In [None]:
# Data Loaders.

logger.info(f"Batch Size: {BATCH_SIZE}, Workers: {num_workers}, pin_memory: True, LR: {LEARNING_RATE}, Wt. Decay: {WEIGHT_DECAY_RATE}")# this will be a lazy loader. Just initialization happens here. Real work happens inside for loop of data loader.

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=sampler, num_workers=num_workers, pin_memory=True)
logger.info("Training Loader is now ready!")

val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers, pin_memory=True)
logger.info("Validation Loader is now ready!")

## 9. Model Training

### 9.1 Training Loop with Checkpointing

Initialize the model, loss function, and optimizer. The training loop:
- Trains the model for specified number of epochs
- Validates after each epoch with top-2 accuracy
- Logs per-class and overall metrics
- Saves a checkpoint after every epoch from epoch 18 onward for model recovery
- Supports resuming from checkpoints if training is interrupted


In [None]:
# Training (resumes from a saved checkpoint when one exists)

# Save a checkpoint after every epoch from this 1-based epoch number onward.
CHECKPOINT_START_EPOCH = 18


def _checkpoint_path_for(epoch_number):
    """Return the checkpoint path used for the given 1-based epoch number."""
    return os.path.join(checkpoint_dir, f"{MODEL_NAME}_epoch_{epoch_number}.pth")


def _find_resume_checkpoint():
    """Return the checkpoint to resume from, or None if there is none.

    Prefers CHECKPOINT_PATH; otherwise falls back to the newest per-epoch
    checkpoint of this model, so an interrupted run resumes instead of
    silently restarting from scratch.
    """
    if os.path.exists(CHECKPOINT_PATH):
        return CHECKPOINT_PATH
    for epoch_number in range(NUM_EPOCHS, 0, -1):
        candidate = _checkpoint_path_for(epoch_number)
        if os.path.exists(candidate):
            return candidate
    return None


# Model, Loss, Optimizer
model = ViTWithSocioEconomic(NUM_CLASSES).to(device)
criterion = nn.CrossEntropyLoss(weight=weights_tensor, label_smoothing=0.1)
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY_RATE)

resume_path = _find_resume_checkpoint()
if resume_path is not None:
    # map_location lets a checkpoint written on another device (e.g. a GPU) load here.
    # weights_only=False unpickles arbitrary objects: only load checkpoints you trust.
    checkpoint = torch.load(resume_path, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']
    train_acc_per_class = checkpoint['train_acc_per_class']
    val_acc_per_class = checkpoint['val_acc_per_class']
    train_loss_per_class = checkpoint['train_loss_per_class']
    val_loss_per_class = checkpoint['val_loss_per_class']
    overall_train_acc = checkpoint['overall_train_acc']
    overall_val_acc = checkpoint['overall_val_acc']
    overall_train_loss = checkpoint['overall_train_loss']
    overall_val_loss = checkpoint['overall_val_loss']

    val_preds = checkpoint['val_preds']
    val_labels = checkpoint['val_labels']

    logger.info(f"Loaded checkpoint from epoch {start_epoch}")
else:
    # Initialize lists to store losses and accuracies
    start_epoch = 0
    train_acc_per_class = []
    val_acc_per_class = []
    train_loss_per_class = []
    val_loss_per_class = []
    overall_train_acc = []
    overall_val_acc = []
    overall_train_loss = []
    overall_val_loss = []
    val_preds = None
    val_labels = None

    logger.info("No checkpoint found - starting new training")

class_names = train_dataset.dataset.label_encoder.classes_

for epoch in range(start_epoch, NUM_EPOCHS):
    train_loss, train_overall_acc, train_class_acc, train_class_loss = train_epoch(model, train_loader, criterion, optimizer, device, NUM_CLASSES)
    val_loss, val_overall_acc, val_class_acc, val_class_loss, val_preds, val_labels = validate_epoch(model, val_loader, criterion, device, NUM_CLASSES)

    train_acc_per_class.append(train_class_acc)
    val_acc_per_class.append(val_class_acc)
    train_loss_per_class.append(train_class_loss)
    val_loss_per_class.append(val_class_loss)
    overall_train_acc.append(train_overall_acc)
    overall_val_acc.append(val_overall_acc)
    overall_train_loss.append(train_loss)
    overall_val_loss.append(val_loss)

    logger.info(f"Epoch {epoch+1}/{NUM_EPOCHS}")
    logger.info(f"Train Accuracy: {train_overall_acc:.2f}%")
    logger.info(f"Train Loss: {train_loss:.4f}")
    for idx, cls in enumerate(class_names):
        logger.info(f"\tTrain Class '{cls}': Acc={train_class_acc[idx]:.2f}%, Loss={train_class_loss[idx]:.4f}")

    logger.info(f"Validation Accuracy: {val_overall_acc:.2f}%")
    logger.info(f"Validation Loss: {val_loss:.4f}")
    for idx, cls in enumerate(class_names):
        logger.info(f"\tVal Class '{cls}': Acc={val_class_acc[idx]:.2f}%, Loss={val_class_loss[idx]:.4f}")

    # Save a checkpoint every epoch once CHECKPOINT_START_EPOCH is reached.
    if (epoch + 1) >= CHECKPOINT_START_EPOCH:
        # Keep CHECKPOINT_PATH pointing at the most recent save for later cells.
        CHECKPOINT_PATH = _checkpoint_path_for(epoch + 1)
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_acc_per_class': train_acc_per_class,
            'val_acc_per_class': val_acc_per_class,
            'train_loss_per_class': train_loss_per_class,
            'val_loss_per_class': val_loss_per_class,
            'overall_train_acc': overall_train_acc,
            'overall_val_acc': overall_val_acc,
            'overall_train_loss': overall_train_loss,
            'overall_val_loss': overall_val_loss,
            'val_preds': val_preds,
            'val_labels': val_labels
        }, CHECKPOINT_PATH)
        logger.info(f"Saved checkpoint to {CHECKPOINT_PATH}")

    logger.info("-"*100)

## 10. Evaluation and Results

### 10.1 Classification Report and Confusion Matrix

Generate detailed classification metrics including:
- Precision, recall, and F1-score for each class
- Confusion matrix visualization showing prediction vs. actual labels
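- Overall model performance summary

Note: the stored validation predictions count a sample as correct whenever its true label is among the top-2 predicted classes, so these metrics reflect top-2 rather than top-1 accuracy.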


In [None]:
# Classification Report
classes = val_dataset.dataset.label_encoder.classes_
# Pin the full label set so rows/ticks stay aligned with `classes` even when a
# class is absent from the validation split (otherwise target_names mismatches).
label_ids = list(range(len(classes)))

if val_preds is None or val_labels is None:
    raise RuntimeError("No validation predictions available: run training or load a checkpoint first.")

# NOTE: validate_epoch records the true label as the prediction whenever it is in
# the top-2 logits, so the report and matrix below are top-2 based, not top-1.
print(classification_report(val_labels, val_preds, labels=label_ids, target_names=classes, digits=4))

# Confusion Matrix
conf_matrix = confusion_matrix(val_labels, val_preds, labels=label_ids)
plt.figure(figsize=(10, 8))
sns.heatmap(conf_matrix, annot=True, fmt="d", cmap="Blues", xticklabels=classes, yticklabels=classes)
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.title("Confusion Matrix (top-2 credited predictions)")
plt.show()