# In [None]:
# Chakma Numerals Recognition System
# Hybrid CNN-Transformer with Encoder-Decoder Architecture and Explainable AI
# With ResNet50 backbone

# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

import os
import math
import warnings
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from PIL import Image
import cv2
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import confusion_matrix, classification_report, precision_recall_fscore_support
from sklearn.preprocessing import label_binarize
from sklearn.metrics import roc_curve, auc

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset, random_split
from einops import rearrange, reduce

# Silence every Python warning for the rest of the session.
warnings.filterwarnings("ignore")
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
# `device` is read as a module-level global by the training and XAI code below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[INFO] Using device: {device}")

# Google Drive directory that receives checkpoints and saved figures.
RESULTS_DIR = "/content/drive/MyDrive/_Chakma_Numbers_Results_ResNet50_"
# exist_ok=True makes re-running this cell harmless when the folder already exists.
os.makedirs(RESULTS_DIR, exist_ok=True)

# ---------------------------------------------------------------------------
# Positional Encoding for Transformer
# ---------------------------------------------------------------------------
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding for batch-first sequences.

    A table of shape [1, max_len, d_model] is precomputed once and stored as
    a non-trainable buffer, so it follows the module across devices and is
    saved in the state dict under the key "pe".

    Args:
        d_model: Size of the embedding (last) dimension; may be odd or even.
        max_len: Number of positions to precompute. Inputs with more than
            ``max_len`` positions cannot be encoded.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd column than even column,
        # so trim the frequency vector to the number of odd columns.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        """Add the encoding of the first ``x.size(1)`` positions to ``x``.

        Args:
            x: Tensor laid out as [batch, seq_len, d_model].

        Returns:
            Tensor of the same shape as ``x``.
        """
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len, :]

# ---------------------------------------------------------------------------
# CBAM Attention
# ---------------------------------------------------------------------------
class CBAMAttention(nn.Module):
    """Channel attention followed by spatial attention (CBAM-style).

    The channel branch squeezes each feature map with global average pooling
    and passes it through a two-layer 1x1-conv bottleneck. The spatial branch
    runs a 7x7 conv over the per-pixel channel mean and channel max.

    Args:
        channels: Number of channels of the incoming feature map.
        reduction_ratio: Bottleneck reduction factor of the channel branch.
    """

    def __init__(self, channels, reduction_ratio=16):
        super().__init__()
        # Never let the bottleneck collapse to zero channels when
        # channels < reduction_ratio.
        hidden = max(1, channels // reduction_ratio)
        self.channel_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, hidden, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channels, 1),
            nn.Sigmoid()
        )
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(2, 1, kernel_size=7, padding=3),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Apply both attention stages.

        Args:
            x: Feature map laid out as [B, channels, H, W].

        Returns:
            Tuple ``(refined, spatial_att)`` where ``refined`` has the shape
            of ``x`` and ``spatial_att`` is the [B, 1, H, W] sigmoid mask
            that was multiplied into it.
        """
        # Channel attention: one weight in (0, 1) per channel.
        channel_att = self.channel_attention(x)
        x_channel = x * channel_att

        # Spatial attention: describe each pixel by its channel mean and max.
        avg_out = torch.mean(x_channel, dim=1, keepdim=True)
        max_out, _ = torch.max(x_channel, dim=1, keepdim=True)
        spatial = torch.cat([avg_out, max_out], dim=1)
        spatial_att = self.spatial_attention(spatial)

        return x_channel * spatial_att, spatial_att

# ---------------------------------------------------------------------------
# Enhanced Hybrid CNN-Transformer Model with Encoder-Decoder
# ---------------------------------------------------------------------------
class ChakmaNumeralClassifier(nn.Module):
    """ResNet50 + CBAM + Transformer model for Chakma numerals.

    The grayscale image is encoded by a ResNet50 trunk, refined by CBAM,
    globally pooled to one 2048-d vector, projected to ``d_model`` and passed
    through a Transformer encoder as a length-1 sequence.

    Two heads are available:

    * ``mode='classification'``: a linear layer over the encoder output.
    * ``mode='sequence'``: a Transformer decoder that emits a token sequence.
      Token ids ``0 .. num_classes - 1`` are the numeral classes and id
      ``num_classes`` is the extra special token (start token on the input
      side, end/blank on the output side).

    Args:
        num_classes: Number of numeral classes.
        d_model: Transformer embedding size.
        nhead: Number of attention heads.
        num_encoder_layers: Depth of the Transformer encoder.
        num_decoder_layers: Depth of the Transformer decoder (sequence mode).
        max_seq_length: Number of greedy decoding steps in sequence mode.
        mode: ``'classification'`` or ``'sequence'``.

    Raises:
        ValueError: If ``mode`` is not one of the two supported values.
    """

    def __init__(self, num_classes, d_model=256, nhead=8, num_encoder_layers=3,
                 num_decoder_layers=3, max_seq_length=10, mode='classification'):
        super().__init__()
        # Fail early: with an unknown mode forward() would build no head and
        # silently return None.
        if mode not in ('classification', 'sequence'):
            raise ValueError(
                f"mode must be 'classification' or 'sequence', got {mode!r}"
            )
        self.mode = mode
        self.max_seq_length = max_seq_length
        self.num_classes = num_classes
        # The start-of-sequence token uses the extra embedding slot so it can
        # never be confused with numeral class 0.
        self.sos_token = num_classes
        self.attention_weights = []  # Store attention weights for visualization

        # CNN Backbone (ResNet50), randomly initialised, first conv replaced
        # so the network accepts single-channel input.
        backbone = models.resnet50(pretrained=False)
        backbone.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # Drop the last two children (global pool and fc) to keep spatial maps.
        self.feature_extractor = nn.Sequential(*list(backbone.children())[:-2])

        # CBAM Attention over the 2048-channel ResNet50 feature map
        self.cbam = CBAMAttention(2048)

        # Adaptive pooling
        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 1))

        # Encoder Projection from the 2048 pooled channels to d_model
        self.encoder_proj = nn.Linear(2048, d_model)
        self.encoder_pos_encoding = PositionalEncoding(d_model)

        # Transformer Encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=512, dropout=0.1, batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_encoder_layers)

        # Register hooks to capture attention weights.
        # NOTE(review): the hook only records output[1] when it is a tensor;
        # if the encoder layer calls self_attn without requesting weights the
        # recorded entries will all be None - confirm before relying on them.
        for layer in self.transformer_encoder.layers:
            if hasattr(layer, 'self_attn') and isinstance(layer.self_attn, nn.MultiheadAttention):
                layer.self_attn.register_forward_hook(self._get_attention_hook())

        # Mode-specific components
        if mode == 'classification':
            # Classification Head
            self.fc = nn.Linear(d_model, num_classes)
        elif mode == 'sequence':
            # Decoder components for sequence recognition
            self.tgt_embedding = nn.Embedding(num_classes + 1, d_model)  # +1 for SOS/EOS tokens
            self.decoder_pos_encoding = PositionalEncoding(d_model)

            decoder_layer = nn.TransformerDecoderLayer(
                d_model=d_model, nhead=nhead, dim_feedforward=512, dropout=0.1, batch_first=True
            )
            self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_decoder_layers)
            self.fc_out = nn.Linear(d_model, num_classes + 1)  # +1 for EOS token

            # optional CTC (left defined but not used here for simplicity)
            self.ctc_loss = nn.CTCLoss(blank=num_classes)

    def _get_attention_hook(self):
        """Build a forward hook that appends attention weights (or None)."""
        def hook(module, input, output):
            # MultiheadAttention returns (attn_output, attn_weights); keep the
            # weights when they are present, otherwise record a placeholder.
            if isinstance(output, tuple) and len(output) > 1 and isinstance(output[1], torch.Tensor):
                self.attention_weights.append(output[1].detach().cpu())
            else:
                self.attention_weights.append(None)
        return hook

    def forward(self, x, targets=None):
        """Run the model.

        Args:
            x: Image batch fed directly to the CNN; presumably [B, 1, H, W]
                since the first conv takes one input channel - TODO confirm
                for sequence-mode callers.
            targets: Sequence mode only. Token ids [B, L] used for teacher
                forcing; a start token is prepended internally, so the output
                has L + 1 positions. When None, greedy decoding is used.

        Returns:
            ``(output, spatial_att)`` where ``spatial_att`` is the CBAM
            spatial mask and ``output`` is

            * classification: logits [B, num_classes];
            * sequence with targets: logits [B, L + 1, num_classes + 1];
            * sequence without targets: token ids [B, max_seq_length].
        """
        self.attention_weights = []  # Reset attention weights

        # Feature extraction (shared)
        features = self.feature_extractor(x)
        features, spatial_att = self.cbam(features)  # Get spatial attention from CBAM
        features_pooled = self.adaptive_pool(features)
        features_pooled = features_pooled.view(features_pooled.size(0), -1)  # [B, 2048]

        # Project to transformer dimension and add positional enc
        encoded = self.encoder_proj(features_pooled).unsqueeze(1)  # [B, 1, d_model]
        encoded = self.encoder_pos_encoding(encoded)

        # Memory via transformer encoder
        memory = self.transformer_encoder(encoded)  # [B, seq_len=1, d_model]

        if self.mode == 'classification':
            output = memory.squeeze(1)  # [B, d_model]
            return self.fc(output), spatial_att  # Return both output and spatial attention

        elif self.mode == 'sequence':
            batch_size = x.size(0)
            sos_tokens = torch.full(
                (batch_size, 1), self.sos_token, dtype=torch.long, device=x.device
            )

            if targets is not None:
                # Teacher forcing: decoder sees [SOS, targets...] under a causal mask.
                decoder_input = torch.cat([sos_tokens, targets], dim=1)
                tgt_embed = self.tgt_embedding(decoder_input)
                tgt_embed = self.decoder_pos_encoding(tgt_embed)
                tgt_mask = nn.Transformer.generate_square_subsequent_mask(decoder_input.size(1)).to(x.device)
                output = self.transformer_decoder(tgt_embed, memory, tgt_mask=tgt_mask)
                return self.fc_out(output), spatial_att
            else:
                # Greedy decode: always runs max_seq_length steps, feeding the
                # arg-max token back in at every step.
                decoded = torch.zeros(batch_size, self.max_seq_length, dtype=torch.long, device=x.device)
                output_tokens = sos_tokens

                for t in range(self.max_seq_length):
                    tgt_embed = self.tgt_embedding(output_tokens)
                    tgt_embed = self.decoder_pos_encoding(tgt_embed)
                    tgt_mask = nn.Transformer.generate_square_subsequent_mask(output_tokens.size(1)).to(x.device)
                    output = self.transformer_decoder(tgt_embed, memory, tgt_mask=tgt_mask)
                    output = self.fc_out(output[:, -1, :])
                    next_token = output.argmax(-1)
                    decoded[:, t] = next_token
                    output_tokens = torch.cat([output_tokens, next_token.unsqueeze(1)], dim=1)

                return decoded, spatial_att

# ---------------------------------------------------------------------------
# Explainable AI Tools
# ---------------------------------------------------------------------------
class XAITools:
    """Static helpers that produce explanation maps for a trained model.

    Every helper expects ``model(batch)`` to return a ``(logits, spatial_att)``
    pair and moves its inputs to the module-level ``device``.
    """

    @staticmethod
    def grad_cam(model, input_tensor, target_layer, target_class=None):
        """Simplified Grad-CAM implementation using input gradients as an approximation.

        Args:
            model: Model returning ``(logits, spatial_att)``.
            input_tensor: A single image without batch dimension; presumably
                [C, H, W] - TODO confirm.
            target_layer: Accepted for interface compatibility; this
                approximation does not use it.
            target_class: Class index to explain. Defaults to the arg-max
                prediction.

        Returns:
            ``(cam, target_class)`` where ``cam`` is a NumPy map scaled to
            [0, 1], or None when no input gradient was produced.
        """
        model.eval()

        # Detach first so the batch is a leaf tensor: only leaf tensors get
        # a populated .grad, even when the caller's tensor already tracks grad.
        input_batch = input_tensor.detach().unsqueeze(0).to(device)
        input_batch.requires_grad_(True)

        output, _ = model(input_batch)

        if target_class is None:
            target_class = output.argmax(dim=1).item()

        # Backward pass to get gradients w.r.t. input
        model.zero_grad()
        output[0, target_class].backward()

        if input_batch.grad is None:
            return None, target_class

        # Average absolute gradients across channels
        cam = torch.mean(torch.abs(input_batch.grad[0]), dim=0)
        # Resize CAM to original image size
        cam = F.interpolate(cam.unsqueeze(0).unsqueeze(0),
                            size=input_tensor.shape[1:],
                            mode='bilinear',
                            align_corners=False).squeeze()
        # Normalize CAM; the epsilon avoids division by zero on a flat map
        cam = cam - cam.min()
        cam = cam / (cam.max() + 1e-8)
        return cam.cpu().numpy(), target_class

    @staticmethod
    def visualize_attention(model, input_tensor):
        """Visualize the CBAM spatial attention for one image.

        Args:
            model: Model returning ``(logits, spatial_att)``.
            input_tensor: A single image; dims 1 and 2 are used as height and
                width, and the image is assumed to squeeze to [H, W]
                (single channel) - TODO confirm.

        Returns:
            ``(att_mask, att_map)``: the attention mask resized to the image
            size and min-max normalised, and that mask multiplied with the
            image. ``(None, None)`` when the model returns no attention.
        """
        model.eval()
        with torch.no_grad():
            # Get spatial attention from CBAM
            _, spatial_att = model(input_tensor.unsqueeze(0).to(device))

            if spatial_att is None:
                return None, None

            spatial_att = spatial_att.squeeze().cpu().numpy()

            # Resize to match input image size (cv2 takes width first)
            h, w = input_tensor.shape[1], input_tensor.shape[2]
            att_mask = cv2.resize(spatial_att, (w, h))

            # Normalize
            att_mask = (att_mask - att_mask.min()) / (att_mask.max() - att_mask.min() + 1e-8)

            # Create attention map by applying mask to original image
            img_np = input_tensor.squeeze().cpu().numpy()
            att_map = att_mask * img_np

            return att_mask, att_map

    @staticmethod
    def feature_importance_analysis(model, dataloader, class_idx, num_samples=10):
        """Average input-gradient magnitude for ``class_idx`` over some samples.

        Args:
            model: Model returning ``(logits, spatial_att)``.
            dataloader: Iterable of ``(images, labels)`` batches.
            class_idx: Logit column whose gradient is analysed.
            num_samples: Maximum number of images to use; the last batch is
                truncated so this limit is never exceeded.

        Returns:
            NumPy array [H, W] with the mean absolute gradient, or None when
            no gradients were collected.
        """
        model.eval()
        important_features = []

        seen = 0
        for images, _ in dataloader:
            remaining = num_samples - seen
            if remaining <= 0:
                break
            # Fresh leaf tensor limited to the samples still needed.
            images = images[:remaining].detach().to(device)
            images.requires_grad_(True)
            outputs, _ = model(images)
            if outputs.dim() == 1:
                outputs = outputs.unsqueeze(0)
            model.zero_grad()
            # Sum output logits for the class of interest across batch
            target_scores = outputs[:, class_idx].sum()
            target_scores.backward(retain_graph=False)

            # gradient magnitude averaged across channels
            if images.grad is not None:
                grad_abs = torch.mean(torch.abs(images.grad), dim=1)  # [B, H, W]
                important_features.append(grad_abs.detach().cpu().numpy())
            seen += images.size(0)

        if len(important_features) == 0:
            return None

        return np.mean(np.concatenate(important_features, axis=0), axis=0)  # [H, W]

# ---------------------------------------------------------------------------
# Enhanced Dataset with Sequence Support
# ---------------------------------------------------------------------------
class EnhancedChakmaDataset(Dataset):
    """Image-folder dataset with one sub-directory per class.

    Class indices follow the alphabetical order of the sub-directory names.
    Images are loaded as 8-bit grayscale.

    Args:
        root_dir: Directory containing one folder per class.
        transform: Optional callable applied to every loaded PIL image.
        mode: ``'classification'`` yields ``(image, label)``; any other value
            yields sliding windows of ``seq_length`` images and labels.
        seq_length: Window length used when building sequences.

    Attributes set in ``__init__``: ``classes``, ``class_to_idx``,
    ``samples`` (list of ``(path, label)``) and ``sequences`` (list of lists
    of ``(path, label)``).
    """

    # File suffixes (compared case-insensitively) that count as images.
    IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".bmp")

    def __init__(self, root_dir, transform=None, mode='classification', seq_length=5):
        self.classes = sorted(
            d for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))
        )
        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
        self.mode = mode
        self.seq_length = seq_length
        self.transform = transform
        self.samples = []
        self.sequences = []

        for cls in self.classes:
            cls_dir = os.path.join(root_dir, cls)
            # os.listdir returns entries in arbitrary order; sorting keeps the
            # sample index stable so seeded splits are reproducible.
            for img in sorted(os.listdir(cls_dir)):
                if img.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.samples.append((os.path.join(cls_dir, img), self.class_to_idx[cls]))

        if mode == 'sequence':
            self._create_sequences()

    def _create_sequences(self):
        """Group samples by file-name prefix and build sliding windows.

        The prefix is the part of the file name before the first underscore.
        Groups with fewer than ``seq_length`` files produce no sequence.
        """
        from collections import defaultdict
        file_groups = defaultdict(list)
        for path, label in self.samples:
            base_name = os.path.basename(path).split('_')[0]
            file_groups[base_name].append((path, label))
        for samples in file_groups.values():
            if len(samples) >= self.seq_length:
                samples.sort(key=lambda x: x[0])
                for i in range(0, len(samples) - self.seq_length + 1):
                    self.sequences.append(samples[i:i + self.seq_length])

    def _load_image(self, path):
        """Load one image as grayscale and apply the transform, if any."""
        # The context manager closes the file handle as soon as the pixel
        # data has been converted.
        with Image.open(path) as raw:
            img = raw.convert("L")
        if self.transform:
            img = self.transform(img)
        return img

    def __len__(self):
        return len(self.samples) if self.mode == 'classification' else len(self.sequences)

    def __getitem__(self, idx):
        """Return ``(image, label)`` or ``(images, labels)`` depending on mode."""
        if self.mode == 'classification':
            path, label = self.samples[idx]
            return self._load_image(path), label

        sequence = self.sequences[idx]
        # torch.stack requires the transform to return equally shaped tensors.
        images = torch.stack([self._load_image(path) for path, _ in sequence])  # [seq_len, C, H, W]
        labels = torch.tensor([label for _, label in sequence], dtype=torch.long)
        return images, labels

# ---------------------------------------------------------------------------
# Training and Evaluation Functions
# ---------------------------------------------------------------------------
def _snapshot_state(model):
    """Return an independent copy of the model's current state dict."""
    import copy
    # state_dict() tensors share storage with the live parameters, so a
    # shallow dict copy would keep changing as training continues.
    return copy.deepcopy(model.state_dict())


def _forward_and_loss(model, criterion, imgs, labels, mode):
    """Run one forward pass and return ``(outputs, loss)`` for ``mode``."""
    if mode == 'classification':
        outputs, _ = model(imgs)
        return outputs, criterion(outputs, labels)

    # Teacher forcing. Assumes labels is [B, T] and that the model prepends
    # one start token to the targets it receives (as
    # ChakmaNumeralClassifier.forward does), so feeding labels[:, :-1]
    # yields exactly one prediction per label position.
    outputs, _ = model(imgs, labels[:, :-1])
    loss = criterion(outputs.reshape(-1, outputs.size(-1)), labels.reshape(-1))
    return outputs, loss


def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, epochs=20, mode='classification'):
    """Train ``model`` and evaluate it on ``val_loader`` after every epoch.

    A checkpoint is written to ``RESULTS_DIR`` on every fifth epoch
    (0-based epochs 0, 5, 10, ...).

    Args:
        model: Model returning ``(outputs, spatial_att)``.
        train_loader: Iterable of ``(images, labels)`` training batches.
        val_loader: Iterable of ``(images, labels)`` validation batches.
        criterion: Loss taking ``(logits, targets)``.
        optimizer: Optimizer over the model parameters.
        scheduler: Optional LR scheduler. ``ReduceLROnPlateau`` is stepped
            with the validation loss, any other scheduler without arguments.
        epochs: Number of epochs.
        mode: ``'classification'`` or sequence mode (any other value).

    Returns:
        ``(model, history, best_model_state, worst_model_state)``.
        ``history`` maps "train_loss", "val_loss", "train_acc" and "val_acc"
        to per-epoch lists (accuracies are 0.0 outside classification mode).
        The two state dicts are independent snapshots of the best and worst
        epoch, ranked by validation accuracy in classification mode and by
        validation loss otherwise; both are None when ``epochs`` is 0.
    """
    import time
    start_time = time.time()
    print(f"[INFO] Training started at {time.strftime('%Y-%m-%d %H:%M:%S')}")

    history = {"train_loss": [], "val_loss": [], "train_acc": [], "val_acc": []}
    model.to(device)
    is_classification = mode == 'classification'

    # For storing best and worst models
    best_model_state = None
    best_score = float('-inf')
    worst_model_state = None
    worst_score = float('inf')

    for epoch in range(epochs):
        model.train()
        train_loss, correct, total = 0.0, 0, 0
        train_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} [Train]")
        for imgs, labels in train_bar:
            imgs, labels = imgs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs, loss = _forward_and_loss(model, criterion, imgs, labels, mode)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            if is_classification:
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()
                train_bar.set_postfix(loss=loss.item(), acc=100.0 * correct / max(1, total))
            else:
                train_bar.set_postfix(loss=loss.item())

        avg_train_loss = train_loss / max(1, len(train_loader))
        # max(1, ...) keeps an empty loader from raising ZeroDivisionError.
        train_acc = 100.0 * correct / max(1, total) if is_classification else 0.0
        # Add checkpointing
        if epoch % 5 == 0:  # Save every 5 epochs
            checkpoint_path = os.path.join(RESULTS_DIR, f"checkpoint_epoch_{epoch}.pth")
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': avg_train_loss,
            }, checkpoint_path)
            print(f"[INFO] Checkpoint saved at epoch {epoch}")

        # Validation
        model.eval()
        val_loss, correct, total = 0.0, 0, 0
        with torch.no_grad():
            val_bar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{epochs} [Val]")
            for imgs, labels in val_bar:
                imgs, labels = imgs.to(device), labels.to(device)
                # Sequence mode also uses teacher-forced logits here: greedy
                # decoding returns integer tokens, which a loss cannot score.
                outputs, loss = _forward_and_loss(model, criterion, imgs, labels, mode)
                val_loss += loss.item()
                if is_classification:
                    _, predicted = outputs.max(1)
                    total += labels.size(0)
                    correct += predicted.eq(labels).sum().item()
                    val_bar.set_postfix(loss=loss.item(), acc=100.0 * correct / max(1, total))
                else:
                    val_bar.set_postfix(loss=loss.item())

        avg_val_loss = val_loss / max(1, len(val_loader))
        val_acc = 100.0 * correct / max(1, total) if is_classification else 0.0

        # Update best and worst models. Higher score is better; outside
        # classification mode accuracy is not computed, so rank by loss.
        score = val_acc if is_classification else -avg_val_loss
        if score > best_score:
            best_score = score
            best_model_state = _snapshot_state(model)

        if score < worst_score:
            worst_score = score
            worst_model_state = _snapshot_state(model)

        if scheduler is not None:
            # Only ReduceLROnPlateau takes a metric; for other schedulers the
            # positional argument of step() is the deprecated epoch index.
            if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(avg_val_loss)
            else:
                scheduler.step()

        history["train_loss"].append(avg_train_loss)
        history["val_loss"].append(avg_val_loss)
        history["train_acc"].append(train_acc)
        history["val_acc"].append(val_acc)

        print(f"Epoch {epoch+1}: Train Loss {avg_train_loss:.4f}, Val Loss {avg_val_loss:.4f}, "
              f"Train Acc {train_acc:.2f}%, Val Acc {val_acc:.2f}%")

    print(f"[INFO] Training finished in {time.time() - start_time:.1f}s")
    return model, history, best_model_state, worst_model_state

# ---------------------------------------------------------------------------
# Visualization & XAI
# ---------------------------------------------------------------------------
def plot_curves(history, mode='classification'):
    """Save the training/validation curves to ``training_curves.png``.

    The loss panel is always drawn on the left of a 1x2 grid; the accuracy
    panel on the right is added only in classification mode. ``history`` is
    the dict of per-epoch lists keyed "train_loss", "val_loss", "train_acc"
    and "val_acc".
    """
    # (grid position, title / y-label, train key, validation key, y-axis label)
    panels = [(1, "Loss", "train_loss", "val_loss", "Loss")]
    if mode == 'classification':
        panels.append((2, "Accuracy", "train_acc", "val_acc", "Accuracy (%)"))

    plt.figure(figsize=(12, 5))
    for position, title, train_key, val_key, y_label in panels:
        plt.subplot(1, 2, position)
        plt.plot(history[train_key], label="Train")
        plt.plot(history[val_key], label="Validation")
        plt.title(title)
        plt.xlabel("Epoch")
        plt.ylabel(y_label)
        plt.legend()

    path = os.path.join(RESULTS_DIR, "training_curves.png")
    plt.savefig(path, dpi=300)
    plt.close()
    print(f"[INFO] Training curves saved to {path}")

def plot_roc_curve(model, test_loader, dataset, num_classes):
    """Generate one-vs-rest ROC curves for all classes and save the plot.

    Args:
        model: Model returning ``(logits, spatial_att)``.
        test_loader: Iterable of ``(images, labels)`` batches.
        dataset: Object whose ``classes`` list provides the legend names.
        num_classes: Number of logit columns / classes to evaluate.

    Returns:
        Dict mapping class index to its AUC. The figure is written to
        ``roc_curve.png`` inside ``RESULTS_DIR``.
    """
    model.eval()
    all_probs = []
    all_labels = []

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Generating ROC data"):
            outputs, _ = model(images.to(device))
            all_probs.append(F.softmax(outputs, dim=1).cpu().numpy())
            all_labels.append(labels.cpu().numpy())

    all_probs = np.concatenate(all_probs, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)

    # One-hot encode the labels for one-vs-rest ROC. Built by comparison so
    # there is always one column per class, including the two-class case.
    y_test_bin = (all_labels[:, None] == np.arange(num_classes)).astype(int)

    # Compute ROC curve and ROC area for each class
    fpr, tpr, roc_auc = {}, {}, {}
    for i in range(num_classes):
        fpr[i], tpr[i], _ = roc_curve(y_test_bin[:, i], all_probs[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

    # Plot all ROC curves
    plt.figure(figsize=(10, 8))
    colors = ['blue', 'red', 'green', 'orange', 'purple',
              'brown', 'pink', 'gray', 'olive', 'cyan']

    # Only classes that have a display name are drawn; colours wrap around
    # so classes beyond the tenth are no longer dropped.
    for i in range(min(num_classes, len(dataset.classes))):
        plt.plot(fpr[i], tpr[i], color=colors[i % len(colors)], lw=2,
                 label='{0} (AUC = {1:0.2f})'.format(dataset.classes[i], roc_auc[i]))

    plt.plot([0, 1], [0, 1], 'k--', lw=2)  # chance diagonal
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curves for ResNet50-based Chakma Numeral Classification')
    plt.legend(loc="lower right")

    # Save the plot
    path = os.path.join(RESULTS_DIR, "roc_curve.png")
    plt.savefig(path, dpi=300)
    plt.close()
    print(f"[INFO] ROC curve saved to {path}")

    return roc_auc

def visualize_xai_results(model, test_loader, dataset, num_examples=3):
    """Save a grid of explanation plots to ``xai_visualizations.png``.

    One row is drawn per example (the first image of each of the first
    ``num_examples`` batches) with five columns: original image, input-gradient
    saliency overlay, CBAM attention mask, attention-weighted image and a bar
    chart of class probabilities.
    """
    model.to(device)
    model.eval()
    explainer = XAITools()

    # Take the first sample of each of the first num_examples batches.
    picked = []
    for batch_idx, (images, labels) in enumerate(test_loader):
        if batch_idx >= num_examples:
            break
        picked.append((images[0], labels[0]))

    fig, axes = plt.subplots(num_examples, 5, figsize=(25, 5 * num_examples))
    if num_examples == 1:
        # plt.subplots returns a 1-D array for a single row; restore 2-D indexing.
        axes = axes.reshape(1, -1)

    for row, (image, label) in enumerate(picked):
        ax_orig, ax_cam, ax_mask, ax_map, ax_prob = (axes[row, col] for col in range(5))

        img_np = image.squeeze().cpu().numpy()
        ax_orig.imshow(img_np, cmap='gray')
        ax_orig.set_title(f'Original (True: {dataset.classes[label]})')
        ax_orig.axis('off')

        # The last Conv2d inside the feature extractor is handed to grad_cam
        # as its target layer.
        conv_layers = [m for m in model.feature_extractor.modules() if isinstance(m, nn.Conv2d)]
        target_layer = conv_layers[-1] if conv_layers else None
        if target_layer is None:
            ax_cam.set_title('Grad-CAM not available')
        else:
            cam, pred_class = explainer.grad_cam(model, image, target_layer)
            if cam is None:
                ax_cam.set_title('Grad-CAM failed')
            else:
                ax_cam.imshow(img_np, cmap='gray')
                ax_cam.imshow(cam, cmap='jet', alpha=0.5)
                ax_cam.set_title(f'Grad-CAM (Pred: {dataset.classes[pred_class]})')
        ax_cam.axis('off')

        # Attention visualization; any failure is reported in the panel
        # titles instead of aborting the whole figure.
        try:
            att_mask, att_map = explainer.visualize_attention(model, image)
            if att_mask is None:
                for ax in (ax_mask, ax_map):
                    ax.set_title('CBAM Attention Unavailable')
                    ax.axis('off')
            else:
                ax_mask.imshow(att_mask, cmap='viridis')
                ax_mask.set_title('Attention Mask (CBAM)')
                ax_mask.axis('off')

                ax_map.imshow(att_map, cmap='gray')
                ax_map.set_title('Attention Map (CBAM)')
                ax_map.axis('off')
        except Exception as e:
            error_title = f'CBAM Att Error: {str(e)[:30]}...'
            for ax in (ax_mask, ax_map):
                ax.set_title(error_title)
                ax.axis('off')

        # Class probabilities bar chart (all classes are plotted)
        with torch.no_grad():
            logits, _ = model(image.unsqueeze(0).to(device))
            probs = F.softmax(logits, dim=1)[0].cpu().numpy()
            ax_prob.barh(dataset.classes, probs)
            ax_prob.set_title('Class Probabilities')
            ax_prob.set_xlim(0, 1)

    plt.tight_layout()
    path = os.path.join(RESULTS_DIR, "xai_visualizations.png")
    plt.savefig(path, dpi=300)
    plt.close()
    print(f"[INFO] XAI visualizations saved to {path}")

# ---------------------------------------------------------------------------
# Main Execution
# ---------------------------------------------------------------------------
def main():
    """Run the full pipeline: load data, train, evaluate, and save the model.

    Steps, in order:
      1. Seed torch / numpy and clear the CUDA cache when a GPU is present.
      2. Build the dataset from the Google Drive folder and, in
         classification mode, compute balanced class weights.
      3. Split 70/15/15 into train/val/test with a fixed generator seed.
      4. Train via ``train_model`` and plot the training curves.
      5. Produce XAI visualizations, then (classification mode) test accuracy,
         confusion matrix, classification report (txt + csv) and ROC curves.
      6. Save a checkpoint with the model weights and class mapping.

    Returns ``None``. Exits early (after printing an error) when the dataset
    directory is missing or the dataset is empty.
    """
    # Local import: only needed to clone the dataset view for the train split.
    import copy

    torch.manual_seed(42)
    np.random.seed(42)
    # Add memory management
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        print("[INFO] Cleared GPU memory")
    else:
        print("[INFO] Using CPU, no GPU memory to clear")

    MODE = 'classification'  # or 'sequence'
    root_dir = "/content/drive/MyDrive/Chakma Numerals"

    if not os.path.isdir(root_dir):
        print(f"[ERROR] Directory not found: {root_dir}")
        print("Please ensure Google Drive is mounted and the directory path is correct.")
        return

    # Light geometric augmentation for training only.
    train_transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.RandomRotation(5),
        transforms.RandomAffine(degrees=0, translate=(0.05, 0.05), scale=(0.95, 1.05)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])

    # Deterministic preprocessing for validation / test / visualization.
    val_transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])

    dataset = EnhancedChakmaDataset(root_dir, transform=val_transform, mode=MODE)

    if len(dataset) == 0:
        print("[ERROR] Dataset is empty. Check dataset path and structure.")
        return

    num_classes = len(dataset.classes)
    class_ids = list(range(num_classes))

    # Class weights if classification
    class_weights = None
    if MODE == 'classification':
        sample_labels = [label for _, label in dataset.samples]
        classes_unique = np.unique(sample_labels)
        class_weights_arr = compute_class_weight('balanced', classes=classes_unique, y=sample_labels)
        # compute_class_weight only returns weights for the classes that occur;
        # scatter them into a vector covering every class index (default 1.0).
        full_weights = np.ones(num_classes, dtype=float)
        for cls_idx, weight in zip(classes_unique, class_weights_arr):
            full_weights[cls_idx] = weight
        class_weights = torch.tensor(full_weights, dtype=torch.float).to(device)
        print(f"[INFO] Class weights: {full_weights}")

    # Split dataset
    n = len(dataset)
    train_size = int(0.7 * n)
    val_size = int(0.15 * n)
    test_size = n - train_size - val_size  # remainder, so the sizes always sum to n
    train_ds, val_ds, test_ds = random_split(
        dataset,
        [train_size, val_size, test_size],
        generator=torch.Generator().manual_seed(42),
    )

    # Apply augmentation to the train split ONLY. All three subsets returned by
    # random_split reference the same underlying dataset object, so assigning
    # `train_ds.dataset.transform` would also augment validation and test data.
    # A shallow copy shares the sample list (indices stay valid) but carries
    # its own `transform` attribute.
    train_view = copy.copy(dataset)
    train_view.transform = train_transform
    train_ds.dataset = train_view

    batch_size = 4 if MODE == 'sequence' else 8
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=0)
    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=0)

    print(f"Train samples: {len(train_ds)}, Val samples: {len(val_ds)}, Test samples: {len(test_ds)}")

    model = ChakmaNumeralClassifier(
        num_classes=num_classes,
        d_model=64,
        num_encoder_layers=1,
        num_decoder_layers=1,
        mode=MODE
    ).to(device)

    if MODE == 'classification':
        criterion = nn.CrossEntropyLoss(weight=class_weights)
    else:
        criterion = nn.CrossEntropyLoss()

    optimizer = optim.AdamW(model.parameters(), lr=5e-4, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

    # The best/worst state dicts returned by train_model are not used here.
    model, history, _best_model_state, _worst_model_state = train_model(
        model, train_loader, val_loader, criterion, optimizer, scheduler, epochs=10, mode=MODE
    )

    # Save training curves
    plot_curves(history, mode=MODE)

    # XAI visualizations
    visualize_xai_results(model, test_loader, dataset)

    # Test / Metrics
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for imgs, batch_labels in test_loader:
            outputs, _ = model(imgs.to(device))
            if MODE == 'classification':
                all_preds.extend(outputs.argmax(dim=1).cpu().numpy().tolist())
                all_labels.extend(batch_labels.numpy().tolist())
            else:
                # Sequence mode: keep the raw per-sample outputs and targets.
                for seq_out, seq_true in zip(outputs, batch_labels):
                    all_preds.append(seq_out.cpu().numpy())
                    all_labels.append(seq_true.numpy())

    if MODE == 'classification':
        acc = np.mean(np.array(all_preds) == np.array(all_labels))
        print(f"[RESULT] Test Accuracy: {acc:.4f}")

        # Confusion matrix. Passing `labels` pins the matrix to one row/column
        # per class so the tick labels stay aligned even if a class is absent
        # from the test split.
        cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
        plt.figure(figsize=(10, 8))
        sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=dataset.classes, yticklabels=dataset.classes)
        plt.title("Confusion Matrix"); plt.ylabel("True"); plt.xlabel("Predicted")
        plt.xticks(rotation=45); plt.yticks(rotation=0)
        cm_path = os.path.join(RESULTS_DIR, "confusion_matrix.png")
        plt.savefig(cm_path, dpi=300); plt.close()
        print(f"[INFO] Confusion matrix saved to {cm_path}")

        # Classification report (precision, recall, f1). `labels` keeps the
        # report consistent with `target_names` when a class is missing.
        report = classification_report(
            all_labels, all_preds, labels=class_ids, target_names=dataset.classes, digits=4
        )
        print("[INFO] Classification Report:\n", report)
        report_txt = os.path.join(RESULTS_DIR, "classification_report.txt")
        with open(report_txt, "w") as f:
            f.write(report)
        # Also save as CSV with precision/recall/f1/support columns
        precision, recall, f1, support = precision_recall_fscore_support(
            all_labels, all_preds, labels=class_ids
        )
        metrics_df = pd.DataFrame({
            "class": dataset.classes,
            "precision": precision,
            "recall": recall,
            "f1-score": f1,
            "support": support
        })
        metrics_csv = os.path.join(RESULTS_DIR, "classification_report.csv")
        metrics_df.to_csv(metrics_csv, index=False)
        print(f"[INFO] Classification report saved to {report_txt} and {metrics_csv}")

        # Generate ROC curves
        roc_auc = plot_roc_curve(model, test_loader, dataset, num_classes)

        # Print AUC values
        print("\n[INFO] AUC values for each class:")
        for i, class_name in enumerate(dataset.classes):
            print(f"  - {class_name} (AUC = {roc_auc[i]:.2f})")

    # Save model checkpoint
    model_path = os.path.join(RESULTS_DIR, "chakma_numeral_classifier.pth")
    torch.save({
        "model_state_dict": model.state_dict(),
        "class_to_idx": dataset.class_to_idx,
        "classes": dataset.classes,
        "mode": MODE
    }, model_path)
    print(f"[INFO] Model saved to {model_path}")


# Script entry point: run the pipeline only when this file is executed
# directly (or as a notebook cell), not when it is imported as a module.
if __name__ == "__main__":
    main()

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
[INFO] Using device: cuda
[INFO] Cleared GPU memory
[INFO] Class weights: [0.99983333 0.99983333 0.99983333 1.0015025  0.99983333 0.99983333
 0.99983333 0.99983333 0.99983333 0.99983333]
Train samples: 4199, Val samples: 899, Test samples: 901
[INFO] Training started at 2025-09-12 13:49:45


Epoch 1/10 [Train]: 100%|██████████| 525/525 [44:27<00:00,  5.08s/it, acc=56.8, loss=0.589]


[INFO] Checkpoint saved at epoch 0


Epoch 1/10 [Val]: 100%|██████████| 113/113 [09:10<00:00,  4.87s/it, acc=78.4, loss=0.054]


Epoch 1: Train Loss 1.1400, Val Loss 0.5687, Train Acc 56.82%, Val Acc 78.42%


Epoch 2/10 [Train]: 100%|██████████| 525/525 [00:38<00:00, 13.81it/s, acc=87.4, loss=0.265]
Epoch 2/10 [Val]: 100%|██████████| 113/113 [00:04<00:00, 26.29it/s, acc=94.1, loss=0.0145]


Epoch 2: Train Loss 0.3720, Val Loss 0.2007, Train Acc 87.38%, Val Acc 94.10%


Epoch 3/10 [Train]: 100%|██████████| 525/525 [00:37<00:00, 13.82it/s, acc=92.6, loss=0.0986]
Epoch 3/10 [Val]: 100%|██████████| 113/113 [00:04<00:00, 24.96it/s, acc=88.7, loss=0.0137]


Epoch 3: Train Loss 0.2428, Val Loss 0.3175, Train Acc 92.62%, Val Acc 88.65%


Epoch 4/10 [Train]: 100%|██████████| 525/525 [00:37<00:00, 14.05it/s, acc=95.4, loss=0.261]
Epoch 4/10 [Val]: 100%|██████████| 113/113 [00:04<00:00, 25.04it/s, acc=92.5, loss=0.00485]


Epoch 4: Train Loss 0.1612, Val Loss 0.2434, Train Acc 95.36%, Val Acc 92.55%


Epoch 5/10 [Train]: 100%|██████████| 525/525 [00:37<00:00, 13.88it/s, acc=95.4, loss=0.00745]
Epoch 5/10 [Val]: 100%|██████████| 113/113 [00:04<00:00, 26.62it/s, acc=97, loss=0.0112]


Epoch 5: Train Loss 0.1506, Val Loss 0.1174, Train Acc 95.38%, Val Acc 97.00%


Epoch 6/10 [Train]: 100%|██████████| 525/525 [00:38<00:00, 13.76it/s, acc=95.9, loss=0.0338]


[INFO] Checkpoint saved at epoch 5


Epoch 6/10 [Val]: 100%|██████████| 113/113 [00:05<00:00, 21.92it/s, acc=95.9, loss=0.00449]


Epoch 6: Train Loss 0.1322, Val Loss 0.1519, Train Acc 95.93%, Val Acc 95.88%


Epoch 7/10 [Train]: 100%|██████████| 525/525 [00:38<00:00, 13.70it/s, acc=97.3, loss=0.011]
Epoch 7/10 [Val]: 100%|██████████| 113/113 [00:04<00:00, 24.31it/s, acc=97.4, loss=0.00407]


Epoch 7: Train Loss 0.0920, Val Loss 0.0990, Train Acc 97.31%, Val Acc 97.44%


Epoch 8/10 [Train]: 100%|██████████| 525/525 [00:37<00:00, 13.89it/s, acc=98, loss=0.0119]
Epoch 8/10 [Val]: 100%|██████████| 113/113 [00:04<00:00, 26.54it/s, acc=96.2, loss=0.00315]


Epoch 8: Train Loss 0.0688, Val Loss 0.1392, Train Acc 97.98%, Val Acc 96.22%


Epoch 9/10 [Train]: 100%|██████████| 525/525 [00:37<00:00, 13.92it/s, acc=97.3, loss=0.0035]
Epoch 9/10 [Val]: 100%|██████████| 113/113 [00:04<00:00, 26.52it/s, acc=97.4, loss=0.00262]


Epoch 9: Train Loss 0.0925, Val Loss 0.1144, Train Acc 97.29%, Val Acc 97.44%


Epoch 10/10 [Train]: 100%|██████████| 525/525 [00:37<00:00, 13.93it/s, acc=97.4, loss=0.0412]
Epoch 10/10 [Val]: 100%|██████████| 113/113 [00:04<00:00, 26.38it/s, acc=96.3, loss=0.0278]


Epoch 10: Train Loss 0.0906, Val Loss 0.1426, Train Acc 97.40%, Val Acc 96.33%
[INFO] Training curves saved to /content/drive/MyDrive/_Chakma_Numbers_Results_ResNet50_/training_curves.png
[INFO] XAI visualizations saved to /content/drive/MyDrive/_Chakma_Numbers_Results_ResNet50_/xai_visualizations.png
[RESULT] Test Accuracy: 0.9401
[INFO] Confusion matrix saved to /content/drive/MyDrive/_Chakma_Numbers_Results_ResNet50_/confusion_matrix.png
[INFO] Classification Report:
               precision    recall  f1-score   support

       eight     0.9651    0.9881    0.9765        84
        five     0.9118    0.8611    0.8857       108
        four     1.0000    0.9059    0.9506        85
        nine     0.8796    1.0000    0.9360        95
         one     0.9659    0.8854    0.9239        96
       seven     0.9072    0.9778    0.9412        90
         six     0.9062    0.9560    0.9305        91
       three     0.9625    0.9872    0.9747        78
         two     0.9870    0.9500    

Generating ROC data: 100%|██████████| 113/113 [00:04<00:00, 24.42it/s]


[INFO] ROC curve saved to /content/drive/MyDrive/_Chakma_Numbers_Results_ResNet50_/roc_curve.png

[INFO] AUC values for each class:
  - eight (AUC = 1.00)
  - five (AUC = 0.99)
  - four (AUC = 1.00)
  - nine (AUC = 1.00)
  - one (AUC = 1.00)
  - seven (AUC = 1.00)
  - six (AUC = 1.00)
  - three (AUC = 1.00)
  - two (AUC = 1.00)
  - zero (AUC = 0.99)
[INFO] Model saved to /content/drive/MyDrive/_Chakma_Numbers_Results_ResNet50_/chakma_numeral_classifier.pth
