In [1]:
import os
import random
import json
import pickle
import warnings
from pathlib import Path
from collections import Counter

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
import cv2
from PIL import Image, ImageDraw, ImageFont
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.datasets.folder import default_loader
from torchvision.utils import make_grid, save_image
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, roc_curve, auc
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, label_binarize
from tqdm import tqdm
import timm

# from mamba_vision import mamba_vision_T as create_model
from prediction_saver import PredictionArraySaver

warnings.filterwarnings("ignore")
sns.set()

In [2]:
# -----------------------------
# Config
# -----------------------------
# Data loading
BATCH_SIZE = 32
NUM_WORKERS = 4

# Optimisation schedule
EPOCHS = 100
LR = 1e-4
PATIENCE = 10  # epochs without a val-loss improvement before early stopping

# Reporting
MAX_MISCLASSIFIED_TO_SAVE = 20

# First CUDA device when one is available, otherwise the CPU.
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Output folders: RESULTS_DIR receives training artefacts, NEW_RESULTS_DIR the
# test-time artefacts. Both are created up front in the working directory.
RESULTS_DIR = Path("results")
NEW_RESULTS_DIR = Path("final_results")
for _out_dir in (RESULTS_DIR, NEW_RESULTS_DIR):
    _out_dir.mkdir(exist_ok=True)

# Update model list as needed (timm model names)
MODEL_NAMES = [
#    "mobilevit_s",
#    "mobilenetv3_small_100",
#    "mobilenetv3_large_100",
#    "tf_efficientnet_b0",
#    "densenet201",
#    "inception_v3"
]

In [3]:
# -----------------------------
# Dataset class
# -----------------------------
class XrayDataset(Dataset):
    """Image dataset backed by a DataFrame, yielding two views per sample.

    The frame must provide an ``image_path`` column and a ``label_encoded``
    column. Each item is ``((view1, view2), label)``: the transform is applied
    twice to the same image (so random transforms give two different views);
    without a transform the untransformed image is returned for both views.
    """

    def __init__(self, df, transform=None):
        self.loader = default_loader
        self.transform = transform
        # Positional access in __getitem__ relies on a clean 0..n-1 index.
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        path = record['image_path']
        target = int(record['label_encoded'])
        picture = self.loader(path).convert("RGB")

        if not self.transform:
            return (picture, picture), target

        first_view = self.transform(picture)
        second_view = self.transform(picture)
        return (first_view, second_view), target

               


In [4]:
# -----------------------------
# Grad-CAM
# -----------------------------
class GradCAM:
    """Grad-CAM heat-map generator for one layer of a model.

    A forward hook on ``target_layer`` stores the layer output; a tensor hook
    attached to that output stores the gradient flowing back into it. The
    tensor hook replaces ``Module.register_backward_hook``, which PyTorch has
    deprecated.

    The layer output is reduced over dims 2 and 3 and summed over dim 1, so it
    must be a 4-D ``(N, C, H, W)`` tensor.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        self.hook_handles = []
        self._register_hooks()

    def _register_hooks(self):
        def store_gradient(grad):
            self.gradients = grad.detach()

        def forward_hook(module, inp, out):
            self.activations = out.detach()
            # Only outputs that take part in autograd can carry a gradient hook.
            if out.requires_grad:
                out.register_hook(store_gradient)

        self.hook_handles.append(self.target_layer.register_forward_hook(forward_hook))

    def generate(self, input_tensor, target_class):
        """Return the min-max normalised CAM for ``target_class`` as a NumPy array.

        For a single-image batch the result is 2-D ``(H, W)`` because all
        singleton dimensions are squeezed.

        Raises:
            RuntimeError: if the hooks captured no activations or gradients.
        """
        # Drop state left over from a previous call so a failure cannot
        # silently reuse it.
        self.gradients = None
        self.activations = None

        self.model.zero_grad()
        # Gradients are required even if the caller is inside torch.no_grad().
        with torch.enable_grad():
            output = self.model(input_tensor)
            loss = output[:, target_class].sum()
        # One backward pass only, so the graph need not be retained.
        loss.backward()

        if self.activations is None or self.gradients is None:
            raise RuntimeError(
                "GradCAM hooks captured no activations/gradients; check that "
                "target_layer is used in the forward pass and that hooks were not removed."
            )

        grads = self.gradients
        acts = self.activations
        # Channel weights = spatially averaged gradients.
        weights = grads.mean(dim=(2, 3), keepdim=True)
        cam = (weights * acts).sum(dim=1, keepdim=True)
        cam = torch.relu(cam)
        cam = cam.squeeze().cpu().numpy()
        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
        return cam

    def remove_hooks(self):
        """Detach every registered hook; safe to call more than once."""
        for h in self.hook_handles:
            h.remove()
        self.hook_handles.clear()

def overlay_gradcam(img_pil, mask):
    """Blend a CAM mask over an image.

    The image is resized to 224x224, the mask is resized to match, min-max
    normalised, mapped through the JET colormap and mixed 50/50 with the image.

    Returns:
        (orig, overlay): the resized image and the blended image, both PIL images.
    """
    base = np.array(img_pil.resize((224, 224))).astype(np.uint8)
    height, width = base.shape[0], base.shape[1]

    # Bring the mask to image resolution and stretch it to the full 0..255 range.
    resized = cv2.resize(mask, (width, height))
    span = resized.max() - resized.min() + 1e-8
    scaled = (resized - resized.min()) / span  # normalize 0-1
    mask_u8 = np.uint8(255 * scaled)

    # OpenCV colormaps are BGR; convert so the heat map matches the RGB image.
    heat_bgr = cv2.applyColorMap(mask_u8, cv2.COLORMAP_JET)
    heat_rgb = cv2.cvtColor(heat_bgr, cv2.COLOR_BGR2RGB)

    # addWeighted needs identically shaped operands.
    if heat_rgb.shape != base.shape:
        heat_rgb = cv2.resize(heat_rgb, (width, height))

    blended = cv2.addWeighted(base, 0.5, heat_rgb, 0.5, 0)

    return Image.fromarray(base), Image.fromarray(blended)


def generate_and_save_gradcam_samples(model, test_df, all_labels, all_preds, save_dir, classes, max_correct=20, max_misclassified=20, target_layer=None):
    """Write a side-by-side (image | Grad-CAM overlay) PNG for every test row.

    Args:
        model: network to explain; it is switched to eval mode.
        test_df: frame with an ``image_path`` column, row-aligned with
            ``all_labels`` / ``all_preds``.
        all_labels, all_preds: encoded true / predicted class per row.
        save_dir: ``Path``; images go to ``save_dir / "gradcam_samples"``.
        classes: class names indexed by encoded label.
        max_correct, max_misclassified: accepted for backward compatibility
            only. NOTE(review): the active code processes *every* row of
            ``test_df`` and never sub-samples, so these are currently unused.
        target_layer: layer to hook; defaults to ``model.fu4`` as before.
    """
    gradcam_dir = save_dir / "gradcam_samples"
    gradcam_dir.mkdir(parents=True, exist_ok=True)

    model.eval()
    if target_layer is None:
        #target_layer = model.backbone.features[-1]
        target_layer = model.fu4

    transform = transforms.Compose([
        transforms.Resize((224,224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
    ])

    # Load the caption font once; fall back to PIL's built-in font when the
    # TrueType file cannot be opened.
    try:
        font = ImageFont.truetype("arial.ttf", 18)
    except OSError:
        font = ImageFont.load_default()

    cam = GradCAM(model, target_layer)
    try:
        for i in range(len(test_df)):
            row = test_df.iloc[i]
            # Close the file handle as soon as the pixels are decoded.
            with Image.open(row['image_path']) as raw:
                img_pil = raw.convert("RGB").resize((224,224))
            tensor = transform(img_pil).unsqueeze(0).to(DEVICE)
            pred_class = int(all_preds[i])
            true_class = int(all_labels[i])
            mask = cam.generate(tensor, pred_class)
            orig, overlay = overlay_gradcam(img_pil, mask)
            # side by side
            combined = np.hstack([np.asarray(orig), np.asarray(overlay)])
            combined_pil = Image.fromarray(combined)
            draw = ImageDraw.Draw(combined_pil)
            text = f"True: {classes[true_class]} | Pred: {classes[pred_class]}"
            draw.text((5,5), text, fill=(255,0,0), font=font)
            out_path = gradcam_dir / f"{i:04d}_T-{true_class}_P-{pred_class}.png"
            combined_pil.save(out_path)
    finally:
        # Always detach the hooks, even if one image fails.
        cam.remove_hooks()
    print(f"‚úÖ Saved Grad-CAM samples to {gradcam_dir}")


In [5]:
# -----------------------------
# Utility functions
# -----------------------------
def print_class_distribution(df, split_name):
    """Print how many rows of ``df`` carry each value of the ``label`` column.

    Classes are listed in order of first appearance, followed by a blank line.
    """
    tally = Counter(df['label'])
    report = [f"üìä {split_name} split distribution (total {len(df)}):"]
    report.extend(f"   {cls}: {count}" for cls, count in tally.items())
    print("\n".join(report))
    print()

def save_classification_report_text(report_text, path):
    """Write ``report_text`` to ``path``, replacing any existing file."""
    Path(path).write_text(report_text)

def save_metrics_json(metrics: dict, path):
    """Serialise ``metrics`` as JSON (2-space indent) into ``path``."""
    with open(path, "w") as handle:
        handle.write(json.dumps(metrics, indent=2))

def plot_and_save_confusion_matrix(cm, classes, path):
    """Render ``cm`` as an annotated heat map and save the figure to ``path``.

    ``classes`` labels both axes (columns = predicted, rows = true).
    """
    fig, ax = plt.subplots(figsize=(6,5))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=classes, yticklabels=classes, ax=ax)
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    ax.set_title("Confusion Matrix")
    fig.tight_layout()
    fig.savefig(path)
    plt.close(fig)

def plot_and_save_prediction_distribution(preds, classes, path):
    """Save a bar chart of how often each class index occurs in ``preds``.

    Classes that were never predicted still get a (zero-height) bar.
    """
    n_classes = len(classes)
    per_class = (
        pd.Series(preds)
        .value_counts()
        .sort_index()
        # ensure all classes are present in plot
        .reindex(range(n_classes), fill_value=0)
    )

    fig, ax = plt.subplots(figsize=(6,4))
    ax.bar(classes, per_class.values)
    ax.set_title("Predicted Class Distribution")
    ax.set_xlabel("Class")
    ax.set_ylabel("Count")
    fig.tight_layout()
    fig.savefig(path)
    plt.close(fig)

def plot_and_save_loss_accuracy(history, save_dir):
    """Save training curves from ``history`` into ``save_dir``.

    Always writes ``loss_curve.png`` (``train_loss`` vs ``val_loss``). When the
    history also has a ``val_acc`` entry, writes ``accuracy_curve.png`` too.
    """
    # loss
    loss_fig, loss_ax = plt.subplots(figsize=(8,5))
    for key, legend in (('train_loss', 'Train Loss'), ('val_loss', 'Val Loss')):
        loss_ax.plot(history[key], label=legend)
    loss_ax.set_xlabel("Epoch")
    loss_ax.set_ylabel("Loss")
    loss_ax.set_title("Loss Curve")
    loss_ax.legend()
    loss_fig.tight_layout()
    loss_fig.savefig(save_dir / "loss_curve.png")
    plt.close(loss_fig)

    # accuracy (val acc)
    if 'val_acc' not in history:
        return

    acc_fig, acc_ax = plt.subplots(figsize=(8,5))
    acc_ax.plot(history['val_acc'], label='Val Accuracy')
    acc_ax.set_xlabel("Epoch")
    acc_ax.set_ylabel("Accuracy")
    acc_ax.set_title("Validation Accuracy")
    acc_ax.legend()
    acc_fig.tight_layout()
    acc_fig.savefig(save_dir / "accuracy_curve.png")
    plt.close(acc_fig)

def save_misclassified_samples(test_df, all_labels, all_preds, save_path, classes, max_samples=20, img_size=(224,224)):
    """Save one image grid of randomly chosen test samples with T/P captions.

    Despite the name, the grid mixes misclassified AND correctly classified
    samples: up to ``max_samples // 2`` of each, drawn without replacement via
    ``np.random.choice``.

    Args:
        test_df: frame with an ``image_path`` column, row-aligned with the labels.
        all_labels, all_preds: encoded true / predicted class per row.
        save_path: destination of the grid image.
        classes: class names indexed by encoded label.
        max_samples: upper bound on the number of tiles in the grid.
        img_size: size every tile is resized to.
    """
    all_labels = np.array(all_labels)
    all_preds = np.array(all_preds)

    mis_idx = np.where(all_labels != all_preds)[0]
    correct_idx = np.where(all_labels == all_preds)[0]

    selected_idx = []
    if len(mis_idx) > 0:
        selected_idx.extend(np.random.choice(mis_idx, size=min(max_samples//2, len(mis_idx)), replace=False))
    if len(correct_idx) > 0:
        selected_idx.extend(np.random.choice(correct_idx, size=min(max_samples//2, len(correct_idx)), replace=False))

    if not selected_idx:
        print("‚ö†Ô∏è No samples to display.")
        return

    # Resolve the caption font once. Only a failure to open the font file
    # triggers the fallback; other exceptions (e.g. KeyboardInterrupt) propagate.
    try:
        font = ImageFont.truetype("arial.ttf", 16)
    except OSError:
        font = ImageFont.load_default()

    to_tensor = transforms.ToTensor()

    tensors = []
    for i in selected_idx:
        img_path = test_df.iloc[i]['image_path']
        true_lbl = classes[all_labels[i]]
        pred_lbl = classes[all_preds[i]]

        # resize() returns a new image, so drawing on it leaves the source untouched.
        img_pil = default_loader(img_path).convert("RGB").resize(img_size)

        # add text (True | Pred)
        draw = ImageDraw.Draw(img_pil)
        text = f"T: {true_lbl} | P: {pred_lbl}"
        draw.text((5, 5), text, fill=(255, 0, 0), font=font)

        tensors.append(to_tensor(img_pil))  # float tensor scaled to 0..1

    grid = make_grid(tensors, nrow=5, normalize=True)
    save_image(grid, save_path)
    print(f"‚úÖ Saved samples(misclassified+correct) to {save_path}")

def plot_and_save_roc(all_labels, all_probs, classes, save_path):
    """Plot ROC curve(s), save the figure and return the AUC values.

    Binary case (``len(classes) == 2``): uses column 1 of ``all_probs`` as the
    positive-class score and returns ``{"binary_auc": float}``.

    Multi-class case: one-vs-rest curve per class plus a micro-average, and
    returns ``{class_name: auc_or_None, ..., "micro": auc_or_None}``. A value is
    ``None`` when ``roc_curve`` raised ``ValueError`` for that curve.

    Args:
        all_labels: encoded true labels.
        all_probs: per-sample scores; column ``i`` is read as the score of class ``i``.
        classes: class names indexed by encoded label.
        save_path: destination of the figure.
    """
    num_classes = len(classes)
    if num_classes == 2:
        # binary
        fpr, tpr, _ = roc_curve(all_labels, np.array(all_probs)[:, 1])
        roc_auc = auc(fpr, tpr)
        plt.figure(figsize=(6,6))
        plt.plot(fpr, tpr, label=f"ROC (AUC = {roc_auc:.2f})")
        plt.plot([0,1],[0,1],'k--', lw=2)
        plt.xlabel("False Positive Rate")
        plt.ylabel("True Positive Rate")
        plt.title("ROC Curve")
        plt.legend(loc="lower right")
        plt.tight_layout()
        plt.savefig(save_path)
        plt.close()
        return {"binary_auc": float(roc_auc)}
    else:
        y_true_bin = label_binarize(all_labels, classes=list(range(num_classes)))
        probs_np = np.array(all_probs)
        fpr = dict(); tpr = dict(); roc_auc = dict()
        for i in range(num_classes):
            try:
                fpr[i], tpr[i], _ = roc_curve(y_true_bin[:, i], probs_np[:, i])
                roc_auc[i] = auc(fpr[i], tpr[i])
            except ValueError:
                fpr[i], tpr[i], roc_auc[i] = None, None, None

        # micro-average
        try:
            fpr["micro"], tpr["micro"], _ = roc_curve(y_true_bin.ravel(), probs_np.ravel())
            roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
        except ValueError:
            fpr["micro"], tpr["micro"], roc_auc["micro"] = None, None, None

        plt.figure(figsize=(8,6))
        if roc_auc["micro"] is not None:
            plt.plot(fpr["micro"], tpr["micro"],
                     label=f"micro-average (AUC = {roc_auc['micro']:.2f})",
                     color="deeppink", linestyle=":", linewidth=4)
        colors = ["blue", "green", "orange", "purple", "cyan"]
        for i in range(num_classes):
            if fpr[i] is None:
                continue
            # Cycle through the palette so that classes beyond the fifth are
            # still drawn (zipping with the palette silently dropped them).
            color = colors[i % len(colors)]
            plt.plot(fpr[i], tpr[i], color=color, lw=2, label=f"{classes[i]} (AUC = {roc_auc[i]:.2f})")
        plt.plot([0,1],[0,1],'k--', lw=2)
        plt.xlim([0.0,1.0])
        plt.ylim([0.0,1.05])
        plt.xlabel("False Positive Rate")
        plt.ylabel("True Positive Rate")
        plt.title("Multi-class ROC")
        plt.legend(loc="lower right")
        plt.tight_layout()
        plt.savefig(save_path)
        plt.close()

        # return per class aucs and micro
        aucs = {
            classes[i]: (float(roc_auc[i]) if roc_auc[i] is not None else None)
            for i in range(num_classes)
        }
        aucs["micro"] = float(roc_auc["micro"]) if roc_auc["micro"] is not None else None
        return aucs

In [9]:
import torch.nn.functional as F
# -----------------------------
# Loss function
# -----------------------------
class ContrastiveLoss(nn.Module):
    """Cross-entropy contrastive loss over two views of the same batch.

    For 2N embeddings (N per view) every row is classified against the other
    2N-1 rows using temperature-scaled cosine similarity as logits; the correct
    "class" of row i is its counterpart in the other view.

    Args:
        temperature: divisor applied to the cosine similarities.
    """

    def __init__(self, temperature=0.5):
        super().__init__()
        self.temperature = temperature
        # Not used by forward(); kept so existing code that reads the attribute
        # keeps working.
        self.cosine = nn.CosineSimilarity(dim=-1)

    def forward(self, z_i, z_j):
        """Return the scalar loss for embeddings ``z_i`` and ``z_j``, each ``(N, D)``."""
        batch_size = z_i.size(0)

        # Unit-normalise once; the dot product of unit vectors IS the cosine
        # similarity, so a single (2N, D) x (D, 2N) matmul replaces the
        # broadcast cosine_similarity that materialised a (2N, 2N, D) tensor.
        representations = F.normalize(torch.cat([z_i, z_j], dim=0), dim=-1)  # (2N, D)
        sim_matrix = torch.matmul(representations, representations.t()) / self.temperature

        # Exclude self-similarity from the softmax. -inf is representable in
        # every floating dtype (a large finite constant such as -9e15 is not
        # in float16).
        self_mask = torch.eye(2 * batch_size, device=z_i.device, dtype=torch.bool)
        sim_matrix = sim_matrix.masked_fill(self_mask, float("-inf"))

        # Positive pairs: row i pairs with row i+N, and row i+N with row i.
        row_ids = torch.arange(2 * batch_size, device=z_i.device)
        positives = (row_ids + batch_size) % (2 * batch_size)

        return F.cross_entropy(sim_matrix, positives)


# -----------------------------
# Training & evaluation
# -----------------------------
def train_and_eval_model(model_name, train_loader, val_loader, test_loader, train_df, val_df, test_df, num_classes, classes,
                         model_factory=None, use_contrastive=False, lambda_contrastive=0.1):
    """Train one model with early stopping, then evaluate it on the test set.

    Training artefacts (best checkpoint, history, loss/accuracy curves) go to
    ``RESULTS_DIR / model_name``; test artefacts (checkpoint copy, ROC curve,
    confusion matrix, prediction distribution, sample grid, classification
    report, metrics JSON, prediction arrays) go to ``NEW_RESULTS_DIR / model_name``.

    Args:
        model_name: label used for folder names and log lines.
        train_loader, val_loader, test_loader: loaders yielding
            ``((view1, view2), labels)`` batches.
        train_df, val_df: accepted for interface compatibility; not used here.
        test_df: frame row-aligned with ``test_loader`` (used for the sample grid).
        num_classes: number of target classes passed to the model factory.
        classes: class names indexed by encoded label.
        model_factory: callable invoked as ``model_factory(num_classes=...)``.
            When omitted, the module-level name ``create_model`` is used, so
            that name must have been imported/bound before this call.
        use_contrastive: add the contrastive term between the two views.
            Requires the model to expose ``forward_features``.
        lambda_contrastive: weight of the contrastive term.
    """
    print(f"\n=== ‚ñ∂ Starting training for model: {model_name} ===")
    save_dir = RESULTS_DIR / model_name
    save_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_path = save_dir / "best_model.pth"

    # create model
    # A timm backbone can be plugged in through model_factory, e.g.
    #   lambda num_classes: timm.create_model(name, pretrained=True, num_classes=num_classes)
    if model_factory is None:
        model_factory = create_model
    model = model_factory(num_classes=num_classes).to(DEVICE)

    optimizer = optim.Adam(model.parameters(), lr=LR)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, patience=3, min_lr=1e-6)

    best_val_loss = float("inf")
    patience_counter = 0
    checkpoint_saved = False

    history = {"train_loss": [], "val_loss": [], "val_acc": []}

    criterion = nn.CrossEntropyLoss()
    contrastive_loss = ContrastiveLoss(temperature=0.5)

    for epoch in range(1, EPOCHS+1):
        # --- train epoch ---
        model.train()
        running_loss = 0.0
        for (inputs, aug_inputs), labels in tqdm(train_loader, desc=f"{model_name} Epoch {epoch}/{EPOCHS} [Train]"):
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(inputs)
            cls_loss = criterion(outputs, labels)

            if use_contrastive:
                # The second view is only needed (and only transferred) here.
                aug_inputs = aug_inputs.to(DEVICE)
                embeddings = model.forward_features(inputs)
                aug_embeddings = model.forward_features(aug_inputs)
                contr_loss = contrastive_loss(embeddings, aug_embeddings)
                loss = cls_loss + lambda_contrastive * contr_loss
            else:
                loss = cls_loss

            loss.backward()
            optimizer.step()
            running_loss += loss.item() * inputs.size(0)
        epoch_train_loss = running_loss / len(train_loader.dataset)

        # --- validate epoch ---
        model.eval()
        running_val_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for (inputs, _), labels in tqdm(val_loader, desc=f"{model_name} Epoch {epoch}/{EPOCHS} [Val]"):
                inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                running_val_loss += loss.item() * inputs.size(0)
                preds = torch.argmax(outputs, dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
        epoch_val_loss = running_val_loss / len(val_loader.dataset)
        epoch_val_acc = correct / total if total > 0 else 0.0

        history["train_loss"].append(epoch_train_loss)
        history["val_loss"].append(epoch_val_loss)
        history["val_acc"].append(epoch_val_acc)

        print(f"[{model_name}] Epoch {epoch}/{EPOCHS} ‚Äî train_loss: {epoch_train_loss:.4f}, val_loss: {epoch_val_loss:.4f}, val_acc: {epoch_val_acc:.4f}")

        scheduler.step(epoch_val_loss)

        # early stopping
        if epoch_val_loss < best_val_loss:
            best_val_loss = epoch_val_loss
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_saved = True
            patience_counter = 0
            print(f"  ‚Ü≥ New best val loss: {best_val_loss:.4f} (model saved)")
        else:
            patience_counter += 1
            if patience_counter >= PATIENCE:
                print(f"  ‚èπ Early stopping triggered after {PATIENCE} epochs without improvement.")
                break

    # If validation loss never improved (e.g. it was NaN every epoch) nothing
    # was written during this run; save the final weights so the reload below
    # cannot fail or pick up a stale checkpoint from an earlier run.
    if not checkpoint_saved:
        print(f"WARNING: no improving epoch for {model_name}; saving final weights as best_model.pth")
        torch.save(model.state_dict(), checkpoint_path)

    # save history
    with open(save_dir / "history.pkl", "wb") as f:
        pickle.dump(history, f)
    plot_and_save_loss_accuracy(history, save_dir)

    # --- Testing ---
    print(f"üîç Starting testing for {model_name}")
    model.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE))
    model.eval()
    # From here on everything is written to the final-results folder.
    save_dir = NEW_RESULTS_DIR / model_name
    save_dir.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), save_dir / "best_model.pth")

    all_labels, all_preds, all_probs = [], [], []
    with torch.no_grad():
        for (inputs, _), labels in tqdm(test_loader, desc=f"{model_name} [Test]"):
            inputs = inputs.to(DEVICE)
            outputs = model(inputs)
            probs = torch.softmax(outputs, dim=1)
            preds = torch.argmax(probs, dim=1)

            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())

    all_labels = np.array(all_labels)
    all_preds = np.array(all_preds)
    all_probs = np.array(all_probs)

    # A factory that ignores num_classes yields a head of the wrong width; the
    # metrics below then only look at the first num_classes probability columns.
    if all_probs.ndim == 2 and all_probs.shape[1] != num_classes:
        print(f"WARNING: {model_name} outputs {all_probs.shape[1]} logits but num_classes={num_classes}; "
              "check that the model factory honours num_classes.")

    # confusion matrix & classification report
    cm = confusion_matrix(all_labels, all_preds)
    clf_report_text = classification_report(all_labels, all_preds, target_names=classes)
    clf_report_dict = classification_report(all_labels, all_preds, target_names=classes, output_dict=True)

    # ROC & AUC
    if num_classes == 2:
        aucs = plot_and_save_roc(all_labels, all_probs, classes, save_dir / "roc_curve.png")
        macro_auc = aucs.get("binary_auc", None)
        micro_auc = macro_auc
        per_class_auc = {classes[1]: float(macro_auc)} if macro_auc is not None else {}
    else:
        per_class_auc = plot_and_save_roc(all_labels, all_probs, classes, save_dir / "roc_curve.png")
        micro_auc = per_class_auc.get("micro", None)
        macro_auc = np.nanmean([v for k,v in per_class_auc.items() if k != "micro" and v is not None]) if per_class_auc else None

    # save confusion matrix
    plot_and_save_confusion_matrix(cm, classes, save_dir / "confusion_matrix.png")

    # save prediction distribution
    plot_and_save_prediction_distribution(all_preds, classes, save_dir / "prediction_distribution.png")

    # save misclassified samples
    save_misclassified_samples(
        test_df.reset_index(drop=True),
        all_labels,
        all_preds,
        save_dir / "samples.png",
        classes,
        max_samples=MAX_MISCLASSIFIED_TO_SAVE
    )

    # save classification report and metrics
    save_classification_report_text(clf_report_text, save_dir / "classification_report.txt")

    metrics = {
        "classification_report": clf_report_dict,
        "per_class_auc": per_class_auc,
        "macro_auc": float(macro_auc) if macro_auc is not None and not np.isnan(macro_auc) else None,
        "micro_auc": float(micro_auc) if micro_auc is not None and not np.isnan(micro_auc) else None,
        "confusion_matrix": cm.tolist()
    }
    save_metrics_json(metrics, save_dir / "metrics.json")

    saver = PredictionArraySaver(save_dir)
    saver.save_predictions(all_labels, all_preds, all_probs, classes, model_name)

    print(f"üìä Testing complete for {model_name}, results saved in {save_dir}")
    print("-----------------------------------------------------------")

    # print("Running gradcam")
    # generate_and_save_gradcam_samples(model, test_df, all_labels, all_preds, save_dir, classes)

In [11]:
    # -----------------------------
    # Main
    # -----------------------------

    # Pre-split CSVs; each must provide the 'label' and 'image_path' columns.
    train_df = pd.read_csv("train_dataset.csv")
    val_df = pd.read_csv("val_dataset.csv")
    test_df = pd.read_csv("test_dataset.csv")

    # Encode labels: the mapping is learned from the training split only and
    # then applied unchanged to every split.
    le = LabelEncoder().fit(train_df['label'])
    for split_df in (train_df, val_df, test_df):
        split_df['label_encoded'] = le.transform(split_df['label'])
    classes = list(le.classes_)
    num_classes = len(classes)
    print("üî¢ Label mapping:", dict(zip(le.classes_, le.transform(le.classes_))))

    # print summary
    print(f"üìå Number of classes: {num_classes}")
    for split_name, split_df in (("Train", train_df), ("Validation", val_df), ("Test", test_df)):
        print_class_distribution(split_df, split_name)

    # transforms: training adds flip + rotation augmentation; validation and
    # test only resize and normalise.
    train_transform = transforms.Compose([
        transforms.Resize((224,224)),
        transforms.RandomHorizontalFlip(p=0.5),                # HF
        transforms.RandomRotation(degrees=15),                 # ROT
        transforms.ToTensor(),
        transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
    ])
    test_transform = transforms.Compose([
        transforms.Resize((224,224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
    ])

    # datasets and loaders (only the training loader shuffles)
    train_dataset = XrayDataset(train_df, transform=train_transform)
    val_dataset = XrayDataset(val_df, transform=test_transform)
    test_dataset = XrayDataset(test_df, transform=test_transform)

    loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    # Positional arguments shared by every train_and_eval_model call below.
    run_args = (train_loader, val_loader, test_loader, train_df, val_df, test_df, num_classes, classes)

    # Previously used alternatives, kept for reference. Each run rebinds the
    # global `create_model` (read inside train_and_eval_model) via
    # `from <module> import <factory> as create_model` and then calls
    # train_and_eval_model("<name>", ...):
    #   for model_name in MODEL_NAMES: train_and_eval_model(model_name, ...)   (timm models)
    #   MedMamba:      get_medmamba_s -> "MedMamba_small", get_medmamba_t -> "MedMamba_tiny",
    #                  get_medmamba_b -> "MedMamba_base"
    #   mamba_vision:  mamba_vision_T -> "mamba_vision_tiny", mamba_vision_B -> "mamba_vision_base",
    #                  mamba_vision_S -> "mamba_vision_small"
    #   hifuse_model:  HiFuse_Tiny -> "HiFuse_Tiny", HiFuse_Base -> "HiFuse_Base",
    #                  HiFuse_Small -> "HiFuse_Small"

    from VMamba.vmamba import vanilla_vmamba_small as create_model
    train_and_eval_model("VMamba_small", *run_args)

    from VMamba.vmamba import vanilla_vmamba_tiny as create_model
    train_and_eval_model("VMamba_tiny", *run_args)

    from VMamba.vmamba import vanilla_vmamba_base as create_model
    train_and_eval_model("VMamba_base", *run_args)

    print("üèÅ All models processed!")

üî¢ Label mapping: {'COVID': np.int64(0), 'NORMAL': np.int64(1), 'PNEUMONIA': np.int64(2)}
üìå Number of classes: 3
   COVID: 325


=== ‚ñ∂ Starting training for model: mamba_vision_tiny ===
üîç Starting testing for mamba_vision_tiny


mamba_vision_tiny [Test]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 33/33 [00:01<00:00, 16.61it/s]


‚úÖ Saved samples(misclassified+correct) to final_results/mamba_vision_tiny/samples.png

üíæ Saving prediction arrays for: mamba_vision_tiny
‚úÖ Saved raw arrays:
   - all_labels.npy: shape (1046,)
   - all_preds.npy: shape (1046,)
   - all_probs.npy: shape (1046, 1000)
‚úÖ Saved confusion_matrix.npy: shape (3, 3)
‚úÖ Saved per_class_metrics.json and .npy
‚úÖ Saved roc_data.json and .npy
‚úÖ Saved pr_data.json and .npy
‚úÖ Saved aggregate_metrics.json and .npy
‚úÖ Saved metadata.json

üìä Testing complete for mamba_vision_tiny, results saved in final_results/mamba_vision_tiny
-----------------------------------------------------------
üèÅ All models processed!
