In [None]:
# ============================================================
# CELL 1: Mount Drive & Setup (SAFE VERSION)
# ============================================================

from google.colab import drive
import os

# Attach Google Drive so project files persist beyond this Colab runtime.
drive.mount('/content/drive')

# Single project folder on Drive: create it on first use, then make it the
# current directory so later cells can refer to files by bare name.
work_dir = '/content/drive/MyDrive/SkinDiseaseProject'
if not os.path.isdir(work_dir):
    os.makedirs(work_dir, exist_ok=True)
os.chdir(work_dir)

# Status banner (nothing in this cell removes files).
for status_line in (
    f"‚úÖ Working directory: {work_dir}",
    "‚úÖ Checkpoints are SAFE - not deleted!",
):
    print(status_line)


Mounted at /content/drive
‚úÖ Working directory: /content/drive/MyDrive/SkinDiseaseProject
‚úÖ Checkpoints are SAFE - not deleted!


In [None]:
# ============================================================
# CELL 2: Download Datasets (Skip if already exist)
# ============================================================

import os
import kagglehub

work_dir = '/content/drive/MyDrive/SkinDiseaseProject'

# Each dataset's download location is recorded in a small text file inside
# work_dir (absolute paths, so this cell does not depend on the current
# working directory).
dermnet_path_file = os.path.join(work_dir, 'dermnet_path.txt')
mgmitesh_path_file = os.path.join(work_dir, 'mgmitesh_path.txt')
ismailpromus_path_file = os.path.join(work_dir, 'ismailpromus_path.txt')


def _saved_dataset_path(path_file):
    """Return the dataset directory recorded in ``path_file``, or ``None``.

    ``None`` is returned when the record file is missing, is empty, or names
    a directory that does not exist any more. The last case matters because
    the record files live on Drive while the recorded directories live on the
    runtime's local disk, so a record can outlive the data it points to.
    """
    if not os.path.exists(path_file):
        return None
    with open(path_file, 'r') as f:
        saved = f.read().strip()
    if saved and os.path.exists(saved):
        return saved
    return None


def _download_and_record(handle, path_file):
    """Download the Kaggle dataset ``handle`` and record its path in ``path_file``.

    Returns the local directory reported by ``kagglehub.dataset_download``.
    """
    dataset_path = kagglehub.dataset_download(handle)
    with open(path_file, 'w') as f:
        f.write(dataset_path)
    return dataset_path


# A dataset counts as "existing" only if its record points at a live directory.
dermnet_path = _saved_dataset_path(dermnet_path_file)
mgmitesh_path = _saved_dataset_path(mgmitesh_path_file)
ismailpromus_path = _saved_dataset_path(ismailpromus_path_file)

dermnet_exists = dermnet_path is not None
mgmitesh_exists = mgmitesh_path is not None
ismailpromus_exists = ismailpromus_path is not None

if dermnet_exists and mgmitesh_exists and ismailpromus_exists:
    print("‚úÖ All dataset paths already exist - skipping download!")
    print("‚úÖ Checkpoints are safe!")

    # Report the (validated) recorded paths
    print(f"   DermNet: {dermnet_path}")
    print(f"   Mgmitesh: {mgmitesh_path}")
    print(f"   Ismailpromus: {ismailpromus_path}")
else:
    print("üì• Downloading datasets...")

    # Download DermNet
    if not dermnet_exists:
        dermnet_path = _download_and_record('shubhamgoel27/dermnet', dermnet_path_file)
        print(f"‚úÖ DermNet: {dermnet_path}")

    # Download Mgmitesh
    if not mgmitesh_exists:
        mgmitesh_path = _download_and_record(
            'mgmitesh/skin-disease-detection-dataset', mgmitesh_path_file)
        print(f"‚úÖ Mgmitesh: {mgmitesh_path}")

    # Download Ismailpromus
    if not ismailpromus_exists:
        ismailpromus_path = _download_and_record(
            'ismailpromus/skin-diseases-image-dataset', ismailpromus_path_file)
        print(f"‚úÖ Ismailpromus: {ismailpromus_path}")

    print("\n‚úÖ All datasets ready!")


‚úÖ All dataset paths already exist - skipping download!
‚úÖ Checkpoints are safe!
   DermNet: /kaggle/input/dermnet
   Mgmitesh: /root/.cache/kagglehub/datasets/mgmitesh/skin-disease-detection-dataset/versions/1
   Ismailpromus: /root/.cache/kagglehub/datasets/ismailpromus/skin-diseases-image-dataset/versions/1


In [None]:
# Download the Mgmitesh dataset using kagglehub

import kagglehub
import os

# Define the dataset handle
mgmitesh_handle = "mgmitesh/skin-disease-detection-dataset"
work_dir = '/content/drive/MyDrive/SkinDiseaseProject' # Ensure work_dir is defined
mgmitesh_path_file = os.path.join(work_dir, 'mgmitesh_path.txt')

print(f"üì• Attempting to download {mgmitesh_handle} using kagglehub...")

try:
    # Use kagglehub.dataset_download to download and get the path
    local_path = kagglehub.dataset_download(mgmitesh_handle)
    print(f"‚úÖ Download of {mgmitesh_handle} seems successful.")
    print(f"Path to dataset files: {local_path}")

    # Verify the reported directory BEFORE recording it, so that
    # mgmitesh_path.txt is never overwritten with a path that does not exist.
    if os.path.exists(local_path):
        with open(mgmitesh_path_file, 'w') as f:
            f.write(local_path)
        print(f"‚úÖ Updated mgmitesh_path.txt with new path: {local_path}")
    else:
        print(f"‚ùå Verified path does NOT exist: {local_path}")

except Exception as e:
    # Best-effort cell: report the failure instead of aborting the notebook.
    print(f"‚ùå Error downloading {mgmitesh_handle}: {e}")
    # Remove the path record so later cells do not trust a failed download
    if os.path.exists(mgmitesh_path_file):
        os.remove(mgmitesh_path_file)
        print(f"Removed {mgmitesh_path_file} as download failed.")

üì• Attempting to download mgmitesh/skin-disease-detection-dataset using kagglehub...
Using Colab cache for faster access to the 'skin-disease-detection-dataset' dataset.
‚úÖ Download of mgmitesh/skin-disease-detection-dataset seems successful.
Path to dataset files: /kaggle/input/skin-disease-detection-dataset
‚úÖ Updated mgmitesh_path.txt with new path: /kaggle/input/skin-disease-detection-dataset


Now that we've re-downloaded the Mgmitesh dataset with `kagglehub` and refreshed `mgmitesh_path.txt`, let's rerun the dataset processing cell to check that it can find and process the data at the new location.

In [None]:
# ============================================================
# COMPLETE DATASET SETUP - Downloads, Labels, CSVs
# ============================================================

import os
import shutil
import kagglehub
import pandas as pd
from sklearn.model_selection import train_test_split
from google.colab import drive

# Mount Drive and switch into the project folder
drive.mount('/content/drive')
work_dir = '/content/drive/MyDrive/SkinDiseaseProject'
os.makedirs(work_dir, exist_ok=True)
os.chdir(work_dir)


def _remember_dataset_path(path_file, dataset_dir):
    """Write ``dataset_dir`` into the record file ``path_file`` inside ``work_dir``."""
    with open(os.path.join(work_dir, path_file), 'w') as handle:
        handle.write(dataset_dir)


print("="*80)
print("üì• DOWNLOADING ALL DATASETS")
print("="*80)

# Drop any previously recorded dataset locations; they are rewritten below.
path_files = ['dermnet_path.txt', 'mgmitesh_path.txt', 'ismailpromus_path.txt']
for file in path_files:
    abs_file_path = os.path.join(work_dir, file)
    if not os.path.exists(abs_file_path):
        continue
    os.remove(abs_file_path)
    print(f"üóëÔ∏è  Removed old {file}")

print("\nüì• Downloading datasets (this may take 10-15 minutes)....\n")

# Dataset 1: DermNet
print("1Ô∏è‚É£  Downloading DermNet...")
dermnet_path = kagglehub.dataset_download('shubhamgoel27/dermnet')
_remember_dataset_path('dermnet_path.txt', dermnet_path)
print(f"   ‚úÖ DermNet saved to: {dermnet_path}")

# Dataset 2: Mgmitesh
print("\n2Ô∏è‚É£  Downloading Mgmitesh Skin Disease Dataset...")
mgmitesh_path = kagglehub.dataset_download('mgmitesh/skin-disease-detection-dataset')
_remember_dataset_path('mgmitesh_path.txt', mgmitesh_path)
print(f"   ‚úÖ Mgmitesh saved to: {mgmitesh_path}")

# Dataset 3: Ismailpromus
print("\n3Ô∏è‚É£  Downloading Ismailpromus Skin Diseases Dataset...")
ismailpromus_path = kagglehub.dataset_download('ismailpromus/skin-diseases-image-dataset')
_remember_dataset_path('ismailpromus_path.txt', ismailpromus_path)
print(f"   ‚úÖ Ismailpromus saved to: {ismailpromus_path}")

print("\n" + "="*80)
print("‚úÖ ALL DATASETS DOWNLOADED")
print("="*80)

# ============================================================
# LABEL ALL IMAGES AND CREATE TRAIN/TEST CSVs
# ============================================================

all_data = []

# File extensions (compared in lower case) that are treated as images.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')


def collect_labelled_images(class_root, dataset_source):
    """Return one record per image found under ``class_root``.

    ``class_root`` is expected to contain one sub-folder per class; the folder
    name becomes the record's ``label``. Entries of ``class_root`` that are
    not directories, and files without an image extension, are ignored.

    Folder and file listings are sorted because ``os.listdir`` returns entries
    in arbitrary order: without sorting, the row order of the resulting table
    (and therefore any seeded split computed from it) could differ from one
    machine or run to the next.

    Returns:
        list[dict]: records with keys ``image_path``, ``label`` and
        ``dataset_source``.
    """
    records = []
    for disease in sorted(os.listdir(class_root)):
        disease_path = os.path.join(class_root, disease)
        if not os.path.isdir(disease_path):
            continue
        for img in sorted(os.listdir(disease_path)):
            if img.lower().endswith(IMAGE_EXTENSIONS):
                records.append({
                    'image_path': os.path.join(disease_path, img),
                    'label': disease,
                    'dataset_source': dataset_source
                })
    return records


# ============================================================================
# Process DermNet Dataset (class folders under train/ and test/)
# ============================================================================
print(f"\nüìÇ Processing DermNet dataset from: {dermnet_path}")
dermnet_start = len(all_data)
for split in ['train', 'test']:
    split_path = os.path.join(dermnet_path, split)
    if os.path.exists(split_path):
        all_data.extend(collect_labelled_images(split_path, 'DermNet'))
    else:
        print(f"‚ùå DermNet split path does not exist: {split_path}")

print(f"‚úÖ DermNet: {len(all_data) - dermnet_start:,} images")

# ============================================================================
# Process Mgmitesh Dataset (class folders under train/ and val/)
# ============================================================================
print(f"\nüìÇ Processing Mgmitesh dataset from: {mgmitesh_path}")
mgmitesh_start = len(all_data)
for split in ['train', 'val']:
    split_path = os.path.join(mgmitesh_path, split)
    if os.path.exists(split_path):
        all_data.extend(collect_labelled_images(split_path, 'Mgmitesh'))
    else:
        print(f"‚ùå Mgmitesh split path does not exist: {split_path}")

print(f"‚úÖ Mgmitesh: {len(all_data) - mgmitesh_start:,} images")

# ============================================================================
# Process Ismailpromus Dataset (class folders directly under IMG_CLASSES/)
# ============================================================================
print(f"\nüìÇ Processing Ismailpromus dataset from: {ismailpromus_path}")
ismailpromus_start = len(all_data)
img_classes_path = os.path.join(ismailpromus_path, 'IMG_CLASSES')
if os.path.exists(img_classes_path):
    all_data.extend(collect_labelled_images(img_classes_path, 'Ismailpromus'))
else:
    print(f"‚ùå Ismailpromus IMG_CLASSES path does not exist: {img_classes_path}")


print(f"‚úÖ Ismailpromus: {len(all_data) - ismailpromus_start:,} images")

# Create DataFrame (one row per image: image_path, label, dataset_source)
df = pd.DataFrame(all_data)

print(f"\n‚úÖ Total images collected: {len(df):,}")
print(f"‚úÖ Total unique labels: {df['label'].nunique() if 'label' in df.columns else 0}")

# Output locations, defined up front so every branch below can refer to them.
train_csv_path = '/content/drive/MyDrive/SkinDiseaseProject/train_dataset.csv'
test_csv_path = '/content/drive/MyDrive/SkinDiseaseProject/test_dataset.csv'
# Becomes True only after THIS run has written both CSVs, so that files left
# over from an earlier run are not reported as a success of this one.
csvs_written = False

# Check if DataFrame is empty
if df.empty:
    print("\n‚ùå No images found in any dataset paths. Please check paths and contents.")
else:
    # ============================================================================
    # STEP 2: Standardize Labels
    # ============================================================================
    print("\n" + "="*80)
    print("STEP 2: STANDARDIZING DISEASE LABELS")
    print("="*80)

    # Raw folder name -> canonical disease name. Folder names that are not
    # listed here keep their original name (see the fillna below).
    label_mapping = {
        # Fungal infections
        'Tinea Ringworm Candidiasis and other Fungal Infections': 'Fungal Infections',
        'Ringworm': 'Fungal Infections',
        'ringworm': 'Fungal Infections',
        '9. Tinea Ringworm Candidiasis and other Fungal Infections - 1.7k': 'Fungal Infections',
        'Nail Fungus and other Nail Disease': 'Nail Fungus',
        'Nail Fungus': 'Nail Fungus',

        # Viral infections
        'Warts Molluscum and other Viral Infections': 'Viral Skin Infections',
        '10. Warts Molluscum and other Viral Infections - 2103': 'Viral Skin Infections',
        'Warts': 'Viral Skin Infections',
        'Chickenpox': 'Chickenpox',
        'chickenpox': 'Chickenpox',
        'Herpes HPV and other STDs Photos': 'Herpes and STDs',

        # Bacterial infections
        'Cellulitis Impetigo and other Bacterial Infections': 'Bacterial Skin Infections',
        'Cellulitis': 'Bacterial Skin Infections',

        # Inflammatory conditions
        '1. Eczema 1677': 'Eczema',
        'Eczema Photos': 'Eczema',
        'Eczema': 'Eczema',
        'Dyshidrotic Eczema': 'Dyshidrotic Eczema',
        '3. Atopic Dermatitis - 1.25k': 'Atopic Dermatitis',
        'Atopic Dermatitis Photos': 'Atopic Dermatitis',
        'Atopic Dermatitis': 'Atopic Dermatitis',
        '7. Psoriasis pictures Lichen Planus and related diseases - 2k': 'Psoriasis and Lichen Planus',
        'Psoriasis pictures Lichen Planus and related diseases': 'Psoriasis and Lichen Planus',
        'Psoriasis': 'Psoriasis and Lichen Planus',
        'Bullous Disease Photos': 'Bullous Disease',
        'Lupus and other Connective Tissue diseases': 'Lupus and Connective Tissue Disease',
        'Vasculitis Photos': 'Vasculitis',
        'Poison Ivy Photos and other Contact Dermatitis': 'Contact Dermatitis',
        'Urticaria Hives': 'Urticaria',

        # Other conditions
        'Acne and Rosacea Photos': 'Acne and Rosacea',
        'Acne': 'Acne and Rosacea',
        'acne': 'Acne and Rosacea',
        'Light Diseases and Disorders of Pigmentation': 'Pigmentation Disorders',
        'Scabies Lyme Disease and other Infestations and Bites': 'Scabies and Infestations',
        'Exanthems and Drug Eruptions': 'Drug Eruptions',
        'Hair Loss Photos Alopecia and other Hair Diseases': 'Alopecia and Hair Loss',
        'Systemic Disease': 'Systemic Disease',
    }

    df['label_standardized'] = df['label'].map(label_mapping).fillna(df['label'])

    # ============================================================================
    # STEP 3: Filter Out Non-Clinical Diseases
    # ============================================================================
    print("\n" + "="*80)
    print("STEP 3: FILTERING FOR CLINICAL DISEASES ONLY")
    print("="*80)

    # Exclude cancer, benign tumors, and normal skin. None of these names is
    # a key of label_mapping, so they reach this point unchanged.
    exclude_diseases = [
        '5. Melanocytic Nevi (NV) - 7970',
        '4. Basal Cell Carcinoma (BCC) 3323',
        '2. Melanoma 15.75k',
        'Melanoma',
        'Basal Cell Carcinoma',
        'Actinic Keratosis',
        'Actinic Keratosis Basal Cell Carcinoma and other Malignant Lesions',
        'Squamous Cell Carcinoma',
        'Melanoma Skin Cancer Nevi and Moles',
        'Nevus',
        '6. Benign Keratosis-like Lesions (BKL) 2624',
        '8. Seborrheic Keratoses and other Benign Tumors - 1.8k',
        'Seborrheic Keratoses',
        'Seborrheic Keratosis',
        'Pigmented Benign Keratosis',
        'Dermato Fibroma',
        'Vascular Lesion',
        'Vascular Tumors',
        'Normal Skin',
    ]

    df_filtered = df[~df['label_standardized'].isin(exclude_diseases)].copy()

    print(f"‚úÖ Before filtering: {len(df):,} images")
    print(f"‚úÖ After filtering: {len(df_filtered):,} images")
    print(f"‚úÖ Removed: {len(df) - len(df_filtered):,} images (cancer, benign, normal)")
    print(f"‚úÖ Final unique diseases: {df_filtered['label_standardized'].nunique()}")

    # ============================================================================
    # STEP 4: Create 80-20 Stratified Train-Test Split
    # ============================================================================
    print("\n" + "="*80)
    print("STEP 4: CREATING 80-20 STRATIFIED TRAIN-TEST SPLIT")
    print("="*80)

    # Perform stratified split
    if df_filtered.empty or df_filtered['label_standardized'].nunique() < 2:
        print("\n‚ùå Not enough data or unique classes for stratified split after filtering.")
    else:
        train_df, test_df = train_test_split(
            df_filtered,
            test_size=0.2,
            stratify=df_filtered['label_standardized'],
            random_state=42
        )

        # Keep only the published columns; the standardized label is exported
        # under the plain name 'label'.
        export_columns = ['image_path', 'label_standardized', 'dataset_source']
        train_df = train_df[export_columns].rename(columns={'label_standardized': 'label'})
        test_df = test_df[export_columns].rename(columns={'label_standardized': 'label'})

        print(f"\n‚úÖ Training set: {len(train_df):,} images ({len(train_df)/len(df_filtered)*100:.1f}%)")
        print(f"‚úÖ Test set: {len(test_df):,} images ({len(test_df)/len(df_filtered)*100:.1f}%)")

        # Save CSVs
        train_df.to_csv(train_csv_path, index=False)
        test_df.to_csv(test_csv_path, index=False)
        csvs_written = True

        print(f"\n‚úÖ Saved: {train_csv_path}")
        print(f"‚úÖ Saved: {test_csv_path}")

        # ============================================================================
        # STEP 5: Display Statistics
        # ============================================================================
        print("\n" + "="*80)
        print("FINAL DATASET STATISTICS")
        print("="*80)

        # Images per disease over train + test combined. Train and test
        # partition df_filtered, so counting df_filtered directly gives the
        # same totals, as integers, already sorted from most to fewest images
        # (which the "TOP 15" table below relies on).
        disease_counts = df_filtered['label_standardized'].value_counts()
        train_counts = train_df['label'].value_counts()
        test_counts = test_df['label'].value_counts()

        # Category membership uses the standardized names produced by
        # label_mapping above; a name that differs from the mapping output
        # would silently contribute zero images.
        fungal_diseases = ['Fungal Infections', 'Nail Fungus']
        viral_diseases = ['Chickenpox', 'Viral Skin Infections', 'Herpes and STDs']
        bacterial_diseases = ['Bacterial Skin Infections']
        inflammatory_diseases = ['Eczema', 'Dyshidrotic Eczema', 'Atopic Dermatitis',
                                'Psoriasis and Lichen Planus', 'Bullous Disease',
                                'Lupus and Connective Tissue Disease', 'Vasculitis',
                                'Contact Dermatitis', 'Urticaria']
        other_diseases = ['Acne and Rosacea', 'Systemic Disease', 'Pigmentation Disorders',
                        'Scabies and Infestations', 'Drug Eruptions', 'Alopecia and Hair Loss']

        def count_category(disease_list):
            """Total number of images whose standardized label is in ``disease_list``."""
            return sum(disease_counts.get(d, 0) for d in disease_list)

        # Labels that kept their raw folder name belong to no category, so the
        # category rows need not add up to the TOTAL row.
        category_rows = [
            ('FUNGAL INFECTIONS', fungal_diseases),
            ('VIRAL INFECTIONS', viral_diseases),
            ('BACTERIAL INFECTIONS', bacterial_diseases),
            ('INFLAMMATORY/AUTOIMMUNE', inflammatory_diseases),
            ('OTHER CONDITIONS', other_diseases),
        ]

        print(f"\n{'Category':<40} {'Images':>10} {'%':>8}")
        print("="*60)
        for category_name, category_diseases in category_rows:
            category_total = count_category(category_diseases)
            print(f"{category_name:<40} {category_total:>10,} {category_total/len(df_filtered)*100:>7.2f}%")
        print("="*60)
        print(f"{'TOTAL':<40} {len(df_filtered):>10,} {100.0:>7.2f}%")

        # Display top diseases
        print("\n" + "="*80)
        print("TOP 15 DISEASES BY IMAGE COUNT")
        print("="*80)
        print(f"\n{'Disease':<50} {'Train':>8} {'Test':>8} {'Total':>8}")
        print("="*80)

        for disease in disease_counts.head(15).index:
            train_count = train_counts.get(disease, 0)
            test_count = test_counts.get(disease, 0)
            total = train_count + test_count
            print(f"{disease:<50} {train_count:>8,} {test_count:>8,} {total:>8,}")

        print("\n" + "="*80)
        print("‚úÖ DATASET PREPARATION COMPLETE!")
        print("="*80)
        print("\nüìÅ Files ready for training:")
        print(f"   1. {train_csv_path}")
        print(f"   2. {test_csv_path}")
        print("\nüí° Next step: Change runtime to T4 GPU and start training!")

    # Final status: success requires that this run wrote the files AND that
    # both are actually present on disk.
    if csvs_written and os.path.exists(train_csv_path) and os.path.exists(test_csv_path):
        print("\n‚úÖ Train and test CSVs are ready.")
    else:
        print("\n‚ùå Train and test CSVs were NOT successfully created.")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
üì• DOWNLOADING ALL DATASETS
üóëÔ∏è  Removed old dermnet_path.txt
üóëÔ∏è  Removed old mgmitesh_path.txt

üì• Downloading datasets (this may take 10-15 minutes)....

1Ô∏è‚É£  Downloading DermNet...
Using Colab cache for faster access to the 'dermnet' dataset.
   ‚úÖ DermNet saved to: /kaggle/input/dermnet

2Ô∏è‚É£  Downloading Mgmitesh Skin Disease Dataset...
Using Colab cache for faster access to the 'skin-disease-detection-dataset' dataset.
   ‚úÖ Mgmitesh saved to: /kaggle/input/skin-disease-detection-dataset

3Ô∏è‚É£  Downloading Ismailpromus Skin Diseases Dataset...
Using Colab cache for faster access to the 'skin-diseases-image-dataset' dataset.
   ‚úÖ Ismailpromus saved to: /kaggle/input/skin-diseases-image-dataset

‚úÖ ALL DATASETS DOWNLOADED

üìÇ Processing DermNet dataset from: /kaggle/input/dermnet
‚úÖ DermNet: 19,559 images

üìÇ Processing Mg

Let's inspect the downloaded dataset paths to understand where the files are located and why they are not being processed correctly.

In [None]:
import os

# Load the dataset locations recorded by the download cells and peek inside
# each directory.
try:
    def _read_saved_path(path_file):
        """Return the stripped text content of ``path_file``."""
        with open(path_file, 'r') as handle:
            return handle.read().strip()

    dermnet_path = _read_saved_path('dermnet_path.txt')
    mgmitesh_path = _read_saved_path('mgmitesh_path.txt')
    ismailpromus_path = _read_saved_path('ismailpromus_path.txt')

    print(f"DermNet path: {dermnet_path}")
    print(f"Mgmitesh path: {mgmitesh_path}")
    print(f"Ismailpromus path: {ismailpromus_path}")

    print("\nInspecting contents:")

    def list_dir_limited(path, limit=10):
        """Print at most ``limit`` entries of ``path``, or say that it is missing."""
        if not os.path.exists(path):
            print(f"\nPath does not exist: {path}")
            return

        print(f"\nContents of {path} (first {limit} items):")
        try:
            items = os.listdir(path)
            for item in items[:limit]:
                print(f"- {item}")
            hidden = len(items) - limit
            if hidden > 0:
                print(f"... and {hidden} more items")
        except Exception as e:
            print(f"Could not list contents: {e}")

    for dataset_dir in (dermnet_path, mgmitesh_path, ismailpromus_path):
        list_dir_limited(dataset_dir)

except FileNotFoundError:
    print("Error: Dataset path files not found. Please run the first cell to download datasets and create path files.")
except Exception as e:
    print(f"An unexpected error occurred: {e}")

DermNet path: /kaggle/input/dermnet
Mgmitesh path: /kaggle/input/skin-disease-detection-dataset
Ismailpromus path: /kaggle/input/skin-diseases-image-dataset

Inspecting contents:

Contents of /kaggle/input/dermnet (first 10 items):
- test
- train

Contents of /kaggle/input/skin-disease-detection-dataset (first 10 items):
- val
- train

Contents of /kaggle/input/skin-diseases-image-dataset (first 10 items):
- IMG_CLASSES


In [None]:
import os
import time
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.autograd import Function
import torchvision.transforms as T
from torchvision.datasets import ImageFolder
from PIL import Image
from torch.utils.data import Subset
from timm.layers import DropPath, trunc_normal_, Mlp
from timm.utils import accuracy

# Prefer the GPU when one is visible to PyTorch; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Report the name and total memory of the first GPU.
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    gpu_total_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f"GPU Memory: {gpu_total_bytes / 1024**3:.1f} GB")


Using device: cuda
GPU: Tesla T4
GPU Memory: 14.7 GB


In [None]:
# ============================================================
# Image Preprocessing Module - FIXED FOR SIZE CONSISTENCY
# ============================================================

import torchvision.transforms as T
from PIL import Image

class SimpleMedicalImagePreprocessor:
    """Builds the torchvision pipelines used for training and validation.

    Both pipelines resize every image to ``img_size`` x ``img_size``, convert
    it to a tensor and normalise it with the mean/std stored in
    ``normalize_stats``.
    """

    def __init__(self, img_size=224):
        self.img_size = img_size
        self.normalize_stats = dict(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        )

    def _resize(self):
        """Resize step that forces an exact square output size."""
        side = self.img_size
        return T.Resize((side, side))

    def _normalize(self):
        """Normalisation step built from ``normalize_stats``."""
        stats = self.normalize_stats
        return T.Normalize(mean=stats['mean'], std=stats['std'])

    def get_train_transforms(self):
        """Return the training pipeline: resize, random flip/rotation/colour jitter, tensor, normalise."""
        steps = [
            self._resize(),
            T.RandomHorizontalFlip(p=0.5),
            T.RandomRotation(20),
            T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            T.ToTensor(),
            self._normalize(),
        ]
        return T.Compose(steps)

    def get_val_transforms(self):
        """Return the validation pipeline: resize, tensor, normalise (no random steps)."""
        steps = [
            self._resize(),
            T.ToTensor(),
            self._normalize(),
        ]
        return T.Compose(steps)

print("‚úÖ Image preprocessing module loaded (FIXED)")


‚úÖ Image preprocessing module loaded (FIXED)


In [None]:
# ============================================================
# Custom Dataset Class - FIXED FOR SIZE CONSISTENCY
# ============================================================

import torch
from torch.utils.data import Dataset
from PIL import Image
import pandas as pd
import os

class CustomImageDataset(Dataset):
    """Image-classification dataset backed by a CSV file.

    The CSV must provide an ``image_path`` and a ``label`` column. Each item
    is an ``(image, label_index)`` pair; the image is opened, converted to
    RGB, resized to ``img_size`` x ``img_size`` and then passed through
    ``transform`` when one is given.

    Args:
        csv_file: Path of the CSV to read.
        transform: Optional callable applied to the resized PIL image.
        img_size: Side length every image is resized to before ``transform``.
        validate: When True, rows whose ``image_path`` does not exist are
            dropped at construction time.
        label2idx: Optional ``{label: index}`` mapping to use instead of the
            one derived from this CSV. Pass the training dataset's
            ``label2idx`` when building the evaluation dataset so both use
            the same class indices even if their label sets differ. When
            omitted, labels are numbered in sorted order.

    Attributes:
        label2idx: Mapping from label to class index.
        idx2label: Inverse of ``label2idx``.
    """
    def __init__(self, csv_file, transform=None, img_size=224, validate=False, label2idx=None):
        self.df = pd.read_csv(csv_file)
        self.transform = transform
        self.img_size = img_size

        # Build label map (or adopt the caller's, copied so it cannot be
        # mutated from outside).
        if label2idx is None:
            unique_labels = sorted(self.df['label'].unique())
            label2idx = {label: idx for idx, label in enumerate(unique_labels)}
        self.label2idx = dict(label2idx)
        self.idx2label = {idx: label for label, idx in self.label2idx.items()}

        # Optional validation. The label map is built first and is not
        # rebuilt afterwards, so dropping rows never renumbers classes.
        if validate:
            self._validate_images()

    def _validate_images(self):
        """Drop rows whose image file does not exist (file contents are not checked)."""
        print(f"üîç Validating image paths and integrity for {len(self.df)} samples...")
        # One existence check per row, without materialising each row.
        path_exists = self.df['image_path'].map(os.path.exists).astype(bool)
        n_invalid = int((~path_exists).sum())

        if n_invalid:
            print(f"‚ö†Ô∏è Warning: {n_invalid} invalid image paths found")
            self.df = self.df[path_exists].reset_index(drop=True)

        print(f"‚úÖ Validation complete. {len(self.df)} valid images remaining.")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, label_tensor)`` for row ``idx``.

        Any failure while loading or transforming the sample is reported and
        replaced by an all-zero ``(3, img_size, img_size)`` tensor with class
        index 0, so a single bad file does not abort an epoch.
        NOTE(review): such fallback samples carry label 0 regardless of the
        row's real label; run with ``validate=True`` to keep them rare.
        """
        row = self.df.iloc[idx]
        img_path = row['image_path']
        label = row['label']

        try:
            # Load image; the context manager releases the file handle as
            # soon as the RGB copy has been made.
            with Image.open(img_path) as source:
                image = source.convert('RGB')

            # Force resize BEFORE transform so every sample has the same size
            target_size = (self.img_size, self.img_size)
            if image.size != target_size:
                image = image.resize(target_size, Image.BILINEAR)

            # Apply transforms
            if self.transform:
                image = self.transform(image)

            # Convert label to index
            label_idx = self.label2idx[label]
            label_tensor = torch.tensor(label_idx, dtype=torch.long)

            return image, label_tensor

        except Exception as e:
            print(f"‚ùå Error loading image {img_path}: {e}")
            # Return blank image on error
            blank = torch.zeros(3, self.img_size, self.img_size)
            return blank, torch.tensor(0, dtype=torch.long)

print("‚úÖ Custom dataset class loaded (FIXED)")


‚úÖ Custom dataset class loaded (FIXED)


In [None]:
# ============================================================
# Cell 4: Enhanced Medical Components
# ============================================================

class LaplaceConv2d(nn.Module):
    """3x3 convolution whose weights start out as a Laplacian edge filter.

    Every (output channel, input channel) pair is initialised with the same
    stencil: 8 at the centre and -1 at the eight neighbours. The layer has no
    bias and pads by one pixel, so the spatial size is preserved. The stencil
    is only the initial value of the convolution weights.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False)

        # Build the 3x3 stencil once, then give each channel pair its own copy.
        stencil = torch.full((3, 3), -1.0, dtype=torch.float32)
        stencil[1, 1] = 8.0
        self.conv.weight.data = stencil.expand(out_channels, in_channels, 3, 3).contiguous()

    def forward(self, x):
        """Apply the convolution to ``x``."""
        return self.conv(x)

class MedicalGhostHead(nn.Module):
    """Learned per-head scale and shift for attention maps, with optional
    extra per-group modulation.

    The module owns a multiplicative (``ghost_mul``) and an additive
    (``ghost_add``) parameter of shape ``(1, num_heads, kernel_size**2, 1)``.
    With ``medical_patterns=True`` the heads are additionally split into
    consecutive groups of ``max(1, num_heads // 3)`` heads, each with its own
    modulation parameter.
    """

    def __init__(self, channels, kernel_size, num_heads=6, medical_patterns=True):
        super().__init__()
        self.medical_patterns = medical_patterns
        self.channels = channels
        self.kernel_size = kernel_size
        self.num_heads = num_heads
        self.k_elems = kernel_size ** 2

        # Shape chosen so the parameters broadcast against h_attn in forward().
        per_head_shape = (1, num_heads, self.k_elems, 1)
        self.ghost_mul = nn.Parameter(torch.randn(*per_head_shape))
        self.ghost_add = nn.Parameter(torch.zeros(*per_head_shape))
        trunc_normal_(self.ghost_add, std=0.02)

        if not medical_patterns:
            return

        # Consecutive (start, end) head ranges; the last one may be shorter.
        h_per_group = max(1, num_heads // 3)
        self.pattern_groups = [
            (start, min(start + h_per_group, num_heads))
            for start in range(0, num_heads, h_per_group)
        ]

        # One modulation tensor per head group, created in group order.
        self.group_modulations = nn.ParameterList([
            nn.Parameter(torch.randn(1, end - start, self.k_elems, 1))
            for start, end in self.pattern_groups
        ])

        # Not used by forward(); kept as part of the module's parameters.
        self.medical_fusion = nn.Conv2d(channels, channels, 1)

    def forward(self, h_attn, lam=1.0, gamma=1.0):
        """Scale/shift ``h_attn`` and optionally add the group-modulated term.

        Args:
            h_attn: tensor of shape ``(B, num_heads, k_elems, HW)``.
            lam: exponent applied to ``ghost_mul``; ``0`` skips the scaling.
            gamma: factor applied to ``ghost_add``; ``0`` skips the shift.

        Returns:
            Tensor with the same shape as ``h_attn``.
        """
        _batch, n_heads, n_taps, _positions = h_attn.shape
        assert n_heads == self.num_heads and n_taps == self.k_elems

        # Scale first, then shift; either step is skipped when its control is 0.
        enhanced_attn = h_attn
        if lam != 0:
            enhanced_attn = (self.ghost_mul ** lam) * enhanced_attn
        if gamma != 0:
            enhanced_attn = enhanced_attn + self.ghost_add * gamma

        if not (self.medical_patterns and len(self.pattern_groups) > 0):
            return enhanced_attn

        # Modulate each head group separately and stitch the heads back together.
        medical_enhanced = torch.cat(
            [
                enhanced_attn[:, start:end, :, :] * self.group_modulations[group_idx]
                for group_idx, (start, end) in enumerate(self.pattern_groups)
            ],
            dim=1,
        )
        return enhanced_attn + 0.3 * medical_enhanced

print("‚úÖ Medical components loaded (LaplaceConv2d, MedicalGhostHead)")


‚úÖ Medical components loaded (LaplaceConv2d, MedicalGhostHead)


In [None]:
# ============================================================
# Cell 5: PathoScaleSA (COMPLETELY FIXED)
# ============================================================

class PathoScaleSA(nn.Module):
    """Pathology-guided Multi-Scale Self-Attention for Medical Imaging.

    One local-attention branch is built per entry of ``kernel_sizes``: a 1x1
    conv produces q/k/v, the Hadamard product ``q * k`` is mapped to per-head
    ``ks*ks`` attention weights, a ``MedicalGhostHead`` rescales them, and the
    weights aggregate the unfolded ``ks x ks`` neighbourhood of ``v``.  The
    per-scale outputs are then optionally gated (``medical_prior``), merged by a
    1x1 conv (``cross_scale_fusion``, otherwise averaged) and refined by three
    extra conv branches (``pathology_guided``).

    Args:
        dim: channel count of the input tokens; must be divisible by ``num_heads``.
        num_heads: number of attention heads per scale.
        kernel_sizes: local window sizes, one attention branch each.  The
            sequence is copied, so later mutation by the caller has no effect.
        medical_prior: learn a sigmoid gate per scale from globally pooled features.
        cross_scale_fusion: concatenate scales and fuse with a 1x1 conv.
        pathology_guided: add the texture / colour / boundary refinement branches.
        qkv_bias: bias flag for the q/k/v projection.
        attn_drop, proj_drop: dropout on attention weights / projected output.
        lam, gamma: forwarded to every ghost head on each call.

    Input and output of ``forward`` are channel-last ``(B, H, W, dim)``.
    """
    def __init__(self, dim, num_heads=6, kernel_sizes=[3, 5, 7], medical_prior=True,
                 cross_scale_fusion=True, pathology_guided=True, qkv_bias=False,
                 attn_drop=0., proj_drop=0., lam=1.0, gamma=1.0):
        super().__init__()
        if dim % num_heads != 0:
            # The grouped conv below would raise the same error type, but with a
            # message that does not mention the attention heads.
            raise ValueError(f"dim ({dim}) must be divisible by num_heads ({num_heads})")

        self.dim = dim
        self.num_heads = num_heads
        # Copy: forward() zips this list with the ModuleList built below, so it
        # must never change length through an alias held by the caller (or
        # through the shared mutable default argument).
        self.kernel_sizes = list(kernel_sizes)
        self.medical_prior = medical_prior
        self.cross_scale_fusion = cross_scale_fusion
        self.pathology_guided = pathology_guided
        self.lam = lam
        self.gamma = gamma
        self.scale = (dim // num_heads) ** -0.5

        # Multi-scale ELSA modules
        self.multi_scale_attn = nn.ModuleList()
        for ks in self.kernel_sizes:
            self.multi_scale_attn.append(nn.ModuleDict({
                'qkv': nn.Conv2d(dim, dim * 3, 1, bias=qkv_bias),
                'attn_gen': nn.Sequential(
                    nn.Conv2d(dim, dim, ks, padding=ks//2, groups=num_heads),
                    nn.GELU(),
                    nn.Conv2d(dim, ks**2 * num_heads, 1)
                ),
                'ghost_head': MedicalGhostHead(dim, ks, num_heads, medical_patterns=True)
            }))

        # Medical prior attention: one sigmoid gate per scale
        if medical_prior:
            self.medical_gate = nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Flatten(),
                nn.Linear(dim, dim//4),
                nn.ReLU(),
                nn.Linear(dim//4, len(self.kernel_sizes)),
                nn.Sigmoid()
            )

        # Cross-scale fusion
        if cross_scale_fusion:
            self.fusion_conv = nn.Conv2d(dim * len(self.kernel_sizes), dim, 1)
            # NOTE(review): fusion_norm is never applied in forward(); it is kept
            # so that state_dict keys of existing checkpoints still match.
            self.fusion_norm = nn.LayerNorm(dim)

        # Pathology-guided attention branches
        if pathology_guided:
            self.texture_branch = nn.Conv2d(dim, dim//2, 3, padding=1)
            self.color_branch = nn.Conv2d(dim, dim//2, 1)
            # LaplaceConv2d is defined elsewhere; the fusion conv below assumes
            # it emits dim//2 channels like the other two branches.
            self.boundary_branch = LaplaceConv2d(dim, dim//2)
            pathology_dim = 3 * (dim // 2)  # total channels from the 3 branches
            self.pathology_fusion = nn.Conv2d(dim + pathology_dim, dim, 1)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Apply multi-scale local attention to ``x`` of shape (B, H, W, C)."""
        B, H, W, C = x.shape
        x_2d = x.permute(0, 3, 1, 2)  # B, C, H, W

        scale_outputs = []

        # Process each scale
        for ks, scale_module in zip(self.kernel_sizes, self.multi_scale_attn):
            qkv = scale_module['qkv'](x_2d)
            q, k, v = torch.chunk(qkv, 3, dim=1)

            # Hadamard product for attention generation
            hadamard_product = q * k * self.scale
            h_attn = scale_module['attn_gen'](hadamard_product)

            # (B, heads, ks*ks, HW); softmax over the ks*ks window positions
            h_attn = h_attn.reshape(B, self.num_heads, ks**2, H * W).softmax(dim=2)
            h_attn = self.attn_drop(h_attn)

            # Ghost head works on the 4-D layout; flatten heads into the batch after
            enhanced_attn = scale_module['ghost_head'](h_attn, self.lam, self.gamma)
            enhanced_attn = enhanced_attn.reshape(B * self.num_heads, ks**2, H * W)

            # Gather every pixel's ks x ks neighbourhood of v
            v_unfolded = F.unfold(v, kernel_size=ks, padding=ks//2, stride=1)
            v_unfolded = v_unfolded.reshape(B * self.num_heads, C // self.num_heads, ks**2, H * W)

            # Weighted sum over the window axis ('h' is the ks*ks axis here)
            attended_v = torch.einsum('bchw,bhw->bcw', v_unfolded, enhanced_attn)
            attended_v = attended_v.reshape(B, C, H, W)

            scale_outputs.append(attended_v)

        # Medical prior weighting
        if self.medical_prior:
            scale_weights = self.medical_gate(x_2d)  # (B, num_scales)
            scale_weights = scale_weights.unsqueeze(-1).unsqueeze(-1)  # (B, num_scales, 1, 1)
            scale_outputs = [out * scale_weights[:, i:i+1, :, :]
                             for i, out in enumerate(scale_outputs)]

        # Cross-scale fusion
        if self.cross_scale_fusion:
            fused = torch.cat(scale_outputs, dim=1)
            fused = self.fusion_conv(fused)
        else:
            fused = sum(scale_outputs) / len(scale_outputs)

        # Pathology-guided enhancement
        if self.pathology_guided:
            texture_feat = self.texture_branch(fused)    # dim//2
            color_feat = self.color_branch(fused)        # dim//2
            boundary_feat = self.boundary_branch(fused)  # dim//2

            pathology_feat = torch.cat([texture_feat, color_feat, boundary_feat], dim=1)
            combined_feat = torch.cat([fused, pathology_feat], dim=1)
            fused = self.pathology_fusion(combined_feat)

        # Back to token format
        output = fused.permute(0, 2, 3, 1)  # B, H, W, C
        output = self.proj(output)
        output = self.proj_drop(output)

        return output

class PathoScaleSABlock(nn.Module):
    """Pre-norm residual block: PathoScaleSA attention followed by an MLP.

    ``drop`` is used both as the attention's ``proj_drop`` and as the MLP
    dropout; ``drop_path`` enables ``DropPath`` on both residual branches.
    Input and output are ``(B, H, W, dim)``.
    """
    def __init__(self, dim, num_heads=6, kernel_sizes=[3, 5, 7], mlp_ratio=3.,
                 drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm, medical_prior=True, cross_scale_fusion=True,
                 pathology_guided=True, lam=1.0, gamma=1.0):
        super().__init__()
        self.dim = dim

        # Attention sub-layer
        self.norm1 = norm_layer(dim)
        attn_kwargs = dict(
            dim=dim,
            num_heads=num_heads,
            kernel_sizes=kernel_sizes,
            medical_prior=medical_prior,
            cross_scale_fusion=cross_scale_fusion,
            pathology_guided=pathology_guided,
            attn_drop=attn_drop,
            proj_drop=drop,
            lam=lam,
            gamma=gamma,
        )
        self.attn = PathoScaleSA(**attn_kwargs)

        # Stochastic depth is shared by both residual branches
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        # Feed-forward sub-layer
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio),
                       act_layer=act_layer, drop=drop)

    def forward(self, x):
        attn_out = self.attn(self.norm1(x))
        x = x + self.drop_path(attn_out)
        mlp_out = self.mlp(self.norm2(x))
        return x + self.drop_path(mlp_out)

print("‚úÖ PathoScaleSA module loaded (COMPLETELY FIXED)")


‚úÖ PathoScaleSA module loaded (COMPLETELY FIXED)


In [None]:
# ============================================================
# Cell 6: Standard ELSA Implementation
# ============================================================

class ELSAFunctionCUDA(Function):
    """Autograd wrapper around the ELSA local aggregation (pure-PyTorch path).

    forward computes, for every pixel, ``sum_k unfold(features)[k] * filter[k]``
    where ``filter = ghost_mul * h_attn + ghost_add`` (either ghost term may be
    ``None``).

    The previous backward returned ``grad_output`` unchanged as the gradient of
    ``features`` and ``None`` for the ghost tensors and ``h_attn``, which is not
    the derivative of the forward pass.  backward now replays the forward
    computation under autograd and returns the true gradients.  Double backward
    through this Function is not supported.
    """

    @staticmethod
    def _aggregate(features, ghost_mul, ghost_add, h_attn, kernel_size, dilation, stride):
        """Differentiable forward math shared by forward() and backward()."""
        B, C, H, W = features.shape
        _pad = kernel_size // 2 * dilation
        features_unfolded = F.unfold(
            features, kernel_size=kernel_size, dilation=dilation, padding=_pad, stride=stride) \
            .reshape(B, C, kernel_size ** 2, H * W)

        if ghost_mul is not None:
            ghost_mul = ghost_mul.reshape(B, C, kernel_size ** 2, 1)
        if ghost_add is not None:
            ghost_add = ghost_add.reshape(B, C, kernel_size ** 2, 1)

        h_attn = h_attn.reshape(B, 1, kernel_size ** 2, H * W)

        # Compute filters
        if ghost_mul is not None and ghost_add is not None:
            filters = ghost_mul * h_attn + ghost_add
        elif ghost_mul is not None:
            filters = ghost_mul * h_attn
        elif ghost_add is not None:
            filters = h_attn + ghost_add
        else:
            filters = h_attn

        return (features_unfolded * filters).sum(2).reshape(B, C, H, W)

    @staticmethod
    def forward(ctx, features, ghost_mul, ghost_add, h_attn,
                kernel_size=5, dilation=1, stride=1, version=''):
        # Remember which optional tensors exist so backward can unpack the
        # saved tensors without relying on None placeholders.
        ctx.has_ghost_mul = ghost_mul is not None
        ctx.has_ghost_add = ghost_add is not None
        ctx.window = (kernel_size, dilation, stride)
        ctx.save_for_backward(
            *[t for t in (features, h_attn, ghost_mul, ghost_add) if t is not None])
        return ELSAFunctionCUDA._aggregate(
            features, ghost_mul, ghost_add, h_attn, kernel_size, dilation, stride)

    @staticmethod
    def backward(ctx, grad_output):
        saved = list(ctx.saved_tensors)
        features, h_attn = saved[0], saved[1]
        cursor = 2
        ghost_mul = ghost_add = None
        if ctx.has_ghost_mul:
            ghost_mul = saved[cursor]
            cursor += 1
        if ctx.has_ghost_add:
            ghost_add = saved[cursor]

        originals = (features, ghost_mul, ghost_add, h_attn)
        wanted = ctx.needs_input_grad[:4]

        with torch.enable_grad():
            # Detached copies: the replay graph must not touch the outer graph.
            replay = [None if t is None else t.detach().requires_grad_(bool(w))
                      for t, w in zip(originals, wanted)]
            targets = [t for t in replay if t is not None and t.requires_grad]
            if targets:
                out = ELSAFunctionCUDA._aggregate(*replay, *ctx.window)
                computed = iter(torch.autograd.grad(out, targets, grad_output,
                                                    allow_unused=True))
            else:
                computed = iter(())

        grads = [next(computed) if (t is not None and t.requires_grad) else None
                 for t in replay]
        # kernel_size, dilation, stride and version are not differentiable.
        return (*grads, None, None, None, None)

def elsa_op(features, ghost_mul, ghost_add, h_attn, lam, gamma,
            kernel_size=5, dilation=1, stride=1, version=''):
    """Aggregate each pixel's local window of ``features`` with ELSA filters.

    Args:
        features: tensor unpacked as (B, C, H, W).
        ghost_mul: optional tensor reshapeable to (B, C, ks*ks, 1); raised to the
            power ``lam`` (ignored when ``lam == 0``).
        ghost_add: optional tensor reshapeable to (B, C, ks*ks, 1); multiplied by
            ``gamma`` (ignored when ``gamma == 0``).
        h_attn: attention weights reshapeable to (B, 1, ks*ks, H_out*W_out).
        lam, gamma: ghost-head scaling factors, see above.
        kernel_size, dilation, stride: window geometry passed to ``F.unfold``
            with padding ``kernel_size // 2 * dilation``.
        version: unused; kept for signature compatibility.

    Returns:
        Tensor of shape (B, C, H_out, W_out).  For odd kernels and stride 1 this
        is (B, C, H, W).  Previously the reshape hard-coded ``H * W`` and failed
        for any other geometry; the output size is now derived from the unfold
        parameters.
    """
    if ghost_mul is not None:
        ghost_mul = ghost_mul ** lam if lam != 0 else None
    if ghost_add is not None:
        ghost_add = ghost_add * gamma if gamma != 0 else None

    B, C, H, W = features.shape
    ks_sq = kernel_size ** 2
    _pad = kernel_size // 2 * dilation

    # Number of sliding-window positions produced by F.unfold along each axis.
    span = dilation * (kernel_size - 1) + 1
    out_h = (H + 2 * _pad - span) // stride + 1
    out_w = (W + 2 * _pad - span) // stride + 1
    n_pos = out_h * out_w

    features_unfolded = F.unfold(
        features, kernel_size=kernel_size, dilation=dilation, padding=_pad, stride=stride) \
        .reshape(B, C, ks_sq, n_pos)

    if ghost_mul is not None:
        ghost_mul = ghost_mul.reshape(B, C, ks_sq, 1)
    if ghost_add is not None:
        ghost_add = ghost_add.reshape(B, C, ks_sq, 1)

    # One attention map shared by all C channels of a (batch*head) slice.
    h_attn = h_attn.reshape(B, 1, ks_sq, n_pos)

    # Compute filters
    if ghost_mul is not None and ghost_add is not None:
        filters = ghost_mul * h_attn + ghost_add
    elif ghost_mul is not None:
        filters = ghost_mul * h_attn
    elif ghost_add is not None:
        filters = h_attn + ghost_add
    else:
        filters = h_attn

    return (features_unfolded * filters).sum(2).reshape(B, C, out_h, out_w)

class ELSA(nn.Module):
    """Standard Enhanced Local Self-Attention.

    A 1x1 conv produces q, k (``dim_qk`` channels each) and v (``dim_v``
    channels).  The Hadamard product ``q * k`` is mapped to ``kernel_size**2``
    attention weights per head, soft-maxed, and combined with the learned ghost
    head inside ``elsa_op`` to aggregate each pixel's local window of ``v``.

    ``forward`` takes and returns channel-last tensors ``(B, H, W, dim)``.
    """
    def __init__(self, dim, num_heads, dim_qk=None, dim_v=None, kernel_size=5,
                 stride=1, dilation=1, qkv_bias=False, qk_scale=None,
                 attn_drop=0., proj_drop=0., group_width=8, groups=1, lam=1,
                 gamma=1, **kwargs):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.dim_qk = dim_qk or self.dim // 3 * 2
        self.dim_v = dim_v or dim
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation

        # Falls back to 1/sqrt(head_dim) when qk_scale is None (or 0).
        self.scale = qk_scale or (self.dim_v // num_heads) ** -0.5

        # Round dim_qk up to a multiple of group_width for the grouped conv.
        if self.dim_qk % group_width != 0:
            self.dim_qk = math.ceil(float(self.dim_qk) / group_width) * group_width

        self.group_width = group_width
        self.groups = groups
        self.lam = lam
        self.gamma = gamma

        qkv_channels = self.dim_qk * 2 + self.dim_v
        self.pre_proj = nn.Conv2d(dim, qkv_channels, 1, bias=qkv_bias)

        spatial_mix = nn.Conv2d(
            self.dim_qk, self.dim_qk, kernel_size,
            padding=(kernel_size // 2)*dilation,
            dilation=dilation, groups=self.dim_qk // group_width)
        to_weights = nn.Conv2d(self.dim_qk, kernel_size ** 2 * num_heads, 1, groups=groups)
        self.attn = nn.Sequential(spatial_mix, nn.GELU(), to_weights)

        # Ghost head: which terms exist depends on lam / gamma being non-zero.
        use_mul = self.lam != 0
        use_add = self.gamma != 0
        ghost_shape = (1, self.dim_v, kernel_size, kernel_size)
        if use_mul and use_add:
            # Index 0 holds the multiplicative term, index 1 the additive one.
            mul_init = torch.randn(1, *ghost_shape)
            add_init = torch.zeros(1, *ghost_shape)
            trunc_normal_(add_init, std=.02)
            self.ghost_head = nn.Parameter(torch.cat((mul_init, add_init), dim=0),
                                           requires_grad=True)
        elif use_add:
            add_init = torch.zeros(*ghost_shape)
            trunc_normal_(add_init, std=.02)
            self.ghost_head = nn.Parameter(add_init, requires_grad=True)
        elif use_mul:
            self.ghost_head = nn.Parameter(torch.randn(*ghost_shape), requires_grad=True)
        else:
            self.ghost_head = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.post_proj = nn.Linear(self.dim_v, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, mask=None):
        # ``mask`` is accepted but not used by this implementation.
        B, H, W, _ = x.shape
        C, ks, G = self.dim_v, self.kernel_size, self.num_heads

        x = x.permute(0, 3, 1, 2)  # B, C, H, W
        q, k, v = torch.split(self.pre_proj(x),
                              (self.dim_qk, self.dim_qk, self.dim_v), dim=1)

        hadamard_product = q * k * self.scale
        if self.stride > 1:
            hadamard_product = F.avg_pool2d(hadamard_product, self.stride)

        # Heads are folded into the batch axis for both values and weights.
        h_attn = self.attn(hadamard_product).reshape(B * G, -1, H, W).softmax(1)
        h_attn = self.attn_drop(h_attn)
        v = v.reshape(B * G, C // G, H, W)

        use_mul = self.lam != 0
        use_add = self.gamma != 0
        ghost_mul = None
        ghost_add = None
        if use_mul and use_add:
            both = self.ghost_head.expand(2, B, C, ks, ks).reshape(2, B * G, C // G, ks, ks)
            ghost_mul, ghost_add = both[0], both[1]
        elif use_add:
            ghost_add = self.ghost_head.expand(B, C, ks, ks).reshape(B * G, C // G, ks, ks)
        elif use_mul:
            ghost_mul = self.ghost_head.expand(B, C, ks, ks).reshape(B * G, C // G, ks, ks)

        out = elsa_op(v, ghost_mul, ghost_add, h_attn, self.lam, self.gamma,
                      self.kernel_size, self.dilation, self.stride)
        out = out.reshape(B, C, H // self.stride, W // self.stride)
        out = self.post_proj(out.permute(0, 2, 3, 1))  # B, H, W, C
        return self.proj_drop(out)

class ELSABlock(nn.Module):
    """Pre-norm residual block: standard ELSA attention followed by an MLP.

    Only ``stride == 1`` is supported (asserted).  Input/output: (B, H, W, dim).
    """
    def __init__(self, dim, kernel_size, stride=1, num_heads=1, mlp_ratio=3.,
                 drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm, qkv_bias=False, qk_scale=1,
                 dim_qk=None, dim_v=None, lam=1, gamma=1, dilation=1,
                 group_width=8, groups=1, **kwargs):
        super().__init__()
        assert stride == 1
        self.dim = dim

        # Attention sub-layer.
        # NOTE(review): ``drop`` is passed to the MLP only; ELSA's proj_drop
        # keeps its default here, unlike PathoScaleSABlock -- confirm intent.
        self.norm1 = norm_layer(dim)
        attn_kwargs = dict(
            dim_qk=dim_qk, dim_v=dim_v,
            kernel_size=kernel_size, stride=stride, dilation=dilation,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
            group_width=group_width, groups=groups, lam=lam, gamma=gamma,
        )
        self.attn = ELSA(dim, num_heads, **attn_kwargs)

        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        # Feed-forward sub-layer
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio),
                       act_layer=act_layer, drop=drop)

    def forward(self, x):
        attn_out = self.attn(self.norm1(x))
        x = x + self.drop_path(attn_out)
        mlp_out = self.mlp(self.norm2(x))
        return x + self.drop_path(mlp_out)

print("‚úÖ Standard ELSA implementation loaded")


‚úÖ Standard ELSA implementation loaded


In [None]:
# ============================================================
# Cell 7: Vision Transformer Base Components
# ============================================================

class PatchEmbed(nn.Module):
    """Split an image into non-overlapping patches and embed each one.

    A conv with ``kernel_size == stride == patch_size`` maps
    (B, in_chans, H, W) to (B, num_patches, embed_dim).
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=384):
        super().__init__()
        assert img_size % patch_size == 0
        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side * patches_per_side
        self.proj = nn.Conv2d(in_chans, embed_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # (B, C, H/ps, W/ps) -> (B, C, N) -> (B, N, C)
        return self.proj(x).flatten(2).transpose(1, 2)

class AdaptiveAttention(nn.Module):
    """Multi-head self-attention with optional per-sample head selection.

    With ``ada_head=True`` a linear layer on the first token (index 0) produces
    one logit per head; a hard Gumbel-softmax turns them into a one-hot policy
    of shape (B, num_heads).  Heads with policy 0 contribute zeros to the
    concatenated output.

    The policy used to be multiplied into the pre-softmax scores; a zeroed score
    row soft-maxes to a uniform distribution, so "disabled" heads still emitted
    the mean of their values.  The mask is now applied to the per-head output.

    ``forward`` returns ``(out, head_policy)`` where ``head_policy`` is ``None``
    when ``ada_head`` is off.
    """
    def __init__(self, dim, num_heads=6, qkv_bias=True, attn_drop=0., proj_drop=0.,
                 ada_head=False, head_select_tau=5.0):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.ada_head = ada_head
        self.head_select_tau = head_select_tau

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        if ada_head:
            self.head_select = nn.Linear(dim, num_heads)

    def forward(self, x):
        B, N, C = x.shape
        # (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn_scores = (q @ k.transpose(-2, -1)) * self.scale

        head_policy = None
        if self.ada_head:
            cls_embed = x[:, 0]
            logits = self.head_select(cls_embed)
            # Logits are pre-divided by head_select_tau (unchanged behaviour);
            # gumbel_softmax itself runs with its default temperature.
            head_policy = F.gumbel_softmax(logits / self.head_select_tau, hard=True, dim=-1)

        attn = attn_scores.softmax(dim=-1)
        attn = self.attn_drop(attn)

        head_out = attn @ v  # (B, heads, N, head_dim)
        if head_policy is not None:
            # Zero the output of de-selected heads; the straight-through
            # estimator of the hard sample still carries gradient to head_select.
            head_out = head_out * head_policy.unsqueeze(-1).unsqueeze(-1)

        out = head_out.transpose(1, 2).reshape(B, N, C)
        out = self.proj(out)
        out = self.proj_drop(out)
        return out, head_policy

class AdaptiveBlock(nn.Module):
    """Pre-norm transformer block built on ``AdaptiveAttention``.

    ``forward`` returns the updated tokens together with the head policy
    reported by the attention layer (``None`` when head selection is off).
    """
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=True, drop=0.,
                 attn_drop=0., drop_path=0., ada_head=False, head_select_tau=5.0):
        super().__init__()

        # Attention sub-layer
        self.norm1 = nn.LayerNorm(dim)
        self.attn = AdaptiveAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            ada_head=ada_head,
            head_select_tau=head_select_tau,
        )

        # One DropPath instance serves both residual branches
        if drop_path > 0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        # Feed-forward sub-layer
        self.norm2 = nn.LayerNorm(dim)
        width = int(dim * mlp_ratio)
        feed_forward = [
            nn.Linear(dim, width),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(width, dim),
            nn.Dropout(drop),
        ]
        self.mlp = nn.Sequential(*feed_forward)

    def forward(self, x):
        attn_out, head_policy = self.attn(self.norm1(x))
        x = x + self.drop_path(attn_out)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x, head_policy

print("‚úÖ ViT base components loaded")


‚úÖ ViT base components loaded


In [None]:
# ============================================================
# Cell 8: Medical Adaptive ViT Main Model
# ============================================================

class MedicalAdaptiveViT(nn.Module):
    """ViT classifier with an optional local-attention block before the transformer stack.

    Pipeline: patch embedding -> prepend CLS token -> add positional embedding ->
    optional PathoScaleSA block (``use_pathoscale``) or, if that is off, optional
    standard ELSA block (``use_standard_elsa``) applied to the patch tokens only
    -> ``depth`` AdaptiveBlocks -> LayerNorm -> linear head on the CLS token.

    ``forward`` returns ``(logits, head_select, layer_policy)``:
      * ``head_select``: mean of the per-block head policies, or ``None`` when
        ``ada_head`` is off.
      * ``layer_policy``: hard Gumbel-softmax sample over the ``depth`` blocks
        computed from the CLS token, or ``None`` when ``ada_layer`` is off.

    NOTE(review): with ``ada_layer=True`` the policy is one-hot over depth per
    sample and is drawn under ``no_grad``; a block is skipped for the whole
    batch when no sample selected it.  Confirm this is the intended semantics.
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000,
                 embed_dim=384, depth=8, num_heads=6, mlp_ratio=4.,
                 drop_rate=0.1, drop_path_rate=0.1, ada_head=False, ada_layer=False,
                 head_select_tau=5.0, layer_select_tau=5.0,
                 use_pathoscale=True, use_standard_elsa=False,
                 pathoscale_kernel_sizes=[3, 5, 7], pathoscale_num_heads=6,
                 pathoscale_mlp_ratio=3.0, pathoscale_lam=1.0, pathoscale_gamma=1.0,
                 medical_prior=True, cross_scale_fusion=True, pathology_guided=True,
                 # Standard ELSA parameters (for backward compatibility)
                 elsa_kernel_size=5, elsa_num_heads=6, elsa_mlp_ratio=3.0,
                 elsa_lam=1.0, elsa_gamma=1.0):
        super().__init__()
        self.num_classes = num_classes
        self.embed_dim = embed_dim
        self.depth = depth
        self.ada_head = ada_head
        self.ada_layer = ada_layer
        self.layer_select_tau = layer_select_tau
        self.use_pathoscale = use_pathoscale
        self.use_standard_elsa = use_standard_elsa

        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        # Local-attention block applied right after patch embedding.
        # PathoScale takes precedence when both flags are set.
        if use_pathoscale:
            self.pathoscale_block = PathoScaleSABlock(
                dim=embed_dim, num_heads=pathoscale_num_heads,
                # Copy so the shared default list is never aliased by a sub-module.
                kernel_sizes=list(pathoscale_kernel_sizes), mlp_ratio=pathoscale_mlp_ratio,
                drop=drop_rate, attn_drop=0.0, drop_path=0.0,
                medical_prior=medical_prior, cross_scale_fusion=cross_scale_fusion,
                pathology_guided=pathology_guided, lam=pathoscale_lam, gamma=pathoscale_gamma
            )
        elif use_standard_elsa:
            # Standard ELSA block for comparison
            self.elsa_block = ELSABlock(
                dim=embed_dim, kernel_size=elsa_kernel_size, num_heads=elsa_num_heads,
                mlp_ratio=elsa_mlp_ratio, drop=drop_rate, attn_drop=0.0,
                drop_path=0.0, lam=elsa_lam, gamma=elsa_gamma
            )

        # Adaptive Vision Transformer blocks with linearly increasing drop-path
        dpr = torch.linspace(0, drop_path_rate, steps=depth).tolist()
        self.blocks = nn.ModuleList([
            AdaptiveBlock(
                embed_dim, num_heads, mlp_ratio=mlp_ratio, qkv_bias=True,
                drop=drop_rate, attn_drop=0.0, drop_path=dpr[i], ada_head=ada_head,
                head_select_tau=head_select_tau
            ) for i in range(depth)
        ])

        if ada_layer:
            self.layer_select = nn.Linear(embed_dim, depth)

        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        trunc_normal_(self.pos_embed, std=0.02)
        trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(m):
        """Truncated-normal init for Linear weights; unit/zero init for LayerNorm."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def _apply_local_block(self, x, block):
        """Run a spatial (B, H, W, C) block on the patch tokens; CLS passes through."""
        B = x.shape[0]
        # Patch tokens are assumed to form a square grid (PatchEmbed uses a
        # single img_size / patch_size for both axes).
        patch_dim = int(math.sqrt(x.shape[1] - 1))
        cls_token = x[:, 0:1, :]
        patch_tokens = x[:, 1:, :].reshape(B, patch_dim, patch_dim, self.embed_dim)
        patch_tokens = block(patch_tokens)
        patch_tokens = patch_tokens.reshape(B, patch_dim * patch_dim, self.embed_dim)
        return torch.cat([cls_token, patch_tokens], dim=1)

    def forward_features(self, x):
        """Return ``(cls_features, head_policies, layer_policy)`` for an image batch."""
        B = x.shape[0]
        x = self.patch_embed(x)
        cls = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls, x], dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        # Apply PathoScaleSA or standard ELSA after patch embedding
        if self.use_pathoscale:
            x = self._apply_local_block(x, self.pathoscale_block)
        elif self.use_standard_elsa:
            x = self._apply_local_block(x, self.elsa_block)

        head_policies = []
        layer_policy = None

        if self.ada_layer:
            with torch.no_grad():
                logits = self.layer_select(x[:, 0])
                layer_policy = F.gumbel_softmax(logits / self.layer_select_tau, hard=True, dim=-1)

        for i, blk in enumerate(self.blocks):
            if self.ada_layer and layer_policy is not None:
                # Skip the block only when no sample in the batch selected it.
                if (layer_policy[:, i].sum() == 0):
                    head_policies.append(None)
                    continue
            x, h_pol = blk(x)
            head_policies.append(h_pol)

        x = self.norm(x)
        return x[:, 0], head_policies, layer_policy

    def forward(self, x):
        feats, head_policies, layer_policy = self.forward_features(x)
        logits = self.head(feats)

        head_select = None
        if self.ada_head:
            # Skipped blocks and blocks without head selection report None.
            valid = [p for p in head_policies if p is not None]
            if valid:
                head_select = torch.stack(valid, dim=1).mean(dim=1)

        return logits, head_select, layer_policy

print("‚úÖ MedicalAdaptiveViT model loaded")


‚úÖ MedicalAdaptiveViT model loaded


In [None]:
# ============================================================
# Step 1: Update Config (Replace your Cell 10)
# ============================================================

from dataclasses import dataclass, field
import os
import torch

@dataclass
class Config:
    """Training / model configuration.

    Every setting is an annotated dataclass field, so values can be overridden
    at construction time (``Config(batch_size=8)``) and show up in ``repr`` and
    equality.  Without annotations ``@dataclass`` registers no fields at all.
    Creating an instance also creates ``ckpt_dir`` and ``results_dir``.
    """
    train_csv: str = "/content/drive/MyDrive/SkinDiseaseProject/train_dataset.csv"
    test_csv: str = "/content/drive/MyDrive/SkinDiseaseProject/test_dataset.csv"

    img_size: int = 224
    batch_size: int = 24
    num_workers: int = 2
    pin_memory: bool = True

    # Model
    patch_size: int = 16
    embed_dim: int = 384
    depth: int = 8
    num_heads: int = 6
    mlp_ratio: float = 4.0
    drop_rate: float = 0.1
    drop_path_rate: float = 0.1

    # PathoScaleSA
    use_pathoscale: bool = True
    use_standard_elsa: bool = False
    # default_factory gives every instance its own list
    pathoscale_kernel_sizes: list = field(default_factory=lambda: [3, 5, 7])
    pathoscale_num_heads: int = 6
    pathoscale_mlp_ratio: float = 3.0
    pathoscale_lam: float = 1.0
    pathoscale_gamma: float = 1.0

    # Medical feature toggles (the pathology-guided branch is disabled here)
    medical_prior: bool = True
    cross_scale_fusion: bool = True
    pathology_guided: bool = False

    # Standard ELSA
    elsa_kernel_size: int = 5
    elsa_num_heads: int = 6
    elsa_mlp_ratio: float = 3.0
    elsa_lam: float = 1.0
    elsa_gamma: float = 1.0

    # Training parameters
    epochs: int = 67
    lr: float = 5e-5       # base LR
    base_lr: float = 5e-5  # same value as ``lr``, kept under both names
    weight_decay: float = 0.05
    max_lr: float = 1.5e-4

    gradient_accumulation_steps: int = 6
    max_grad_norm: float = 1.0
    use_amp: bool = True

    # Checkpointing
    save_every: int = 5
    ckpt_dir: str = "/content/drive/MyDrive/SkinDiseaseProject/checkpoints"
    best_ckpt_path: str = "/content/drive/MyDrive/SkinDiseaseProject/best_model_stable.pth"

    # Logging
    log_every: int = 300
    plot_path: str = "/content/drive/MyDrive/SkinDiseaseProject/training_curves.png"
    results_dir: str = "/content/drive/MyDrive/SkinDiseaseProject/results"

    patience: int = 7
    # Evaluated once, when the class body runs
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    def __post_init__(self):
        # Called automatically by the generated __init__.
        import os
        os.makedirs(self.ckpt_dir, exist_ok=True)
        os.makedirs(self.results_dir, exist_ok=True)


# Config() already runs __post_init__ (the dataclass-generated __init__ calls
# it), so the checkpoint/results directories exist once the instance is built;
# no explicit second call is needed.
config = Config()

# Report the flags that are actually set instead of a hard-coded claim.
medical_flags = ", ".join(
    f"{flag}={getattr(config, flag)}"
    for flag in ("medical_prior", "cross_scale_fusion", "pathology_guided")
)

print("="*60)
print("STABLE CONFIG WITH MEDICAL FEATURES")
print("="*60)
print(f"‚úÖ Medical features: {medical_flags}")
print(f"‚úÖ Max LR: {config.max_lr} (reduced from 0.0006)")
print(f"‚úÖ Gradient clipping: {config.max_grad_norm} (stronger)")
print(f"‚úÖ Starting fresh training with stable parameters")
print(f"‚úÖ Device set to: {config.device}")
print("="*60)

STABLE CONFIG WITH MEDICAL FEATURES
‚úÖ Medical features: ALL ENABLED
‚úÖ Max LR: 0.00015 (reduced from 0.0006)
‚úÖ Gradient clipping: 1.0 (stronger)
‚úÖ Starting fresh training with stable parameters
‚úÖ Device set to: cuda


In [None]:
import torch

# Report whether PyTorch can see a CUDA device in this runtime.
gpu_ready = torch.cuda.is_available()
if not gpu_ready:
    print("‚ùå GPU is NOT available.")
    print("Go to Runtime > Change runtime type and select a GPU as the hardware accelerator.")
else:
    print("‚úÖ GPU is available!")
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")

‚úÖ GPU is available!
GPU Name: Tesla T4


In [None]:
# ============================================================
# CELL: Fix Scheduler Checkpoint Issue - Add This BEFORE Training
# ============================================================

import os
import shutil

print("="*80)
print("üîß FIXING SCHEDULER COMPATIBILITY ISSUE")
print("="*80)

checkpoint_dir = '/content/drive/MyDrive/SkinDiseaseProject/checkpoints'
backup_dir = '/content/drive/MyDrive/SkinDiseaseProject/checkpoints_old_scheduler_backup'

# Move existing checkpoints out of the way before starting fresh.
# WARNING: this cell empties checkpoint_dir every time it runs.
if os.path.exists(checkpoint_dir):
    checkpoint_files = os.listdir(checkpoint_dir)
    if checkpoint_files:
        print(f"\nüì¶ Found {len(checkpoint_files)} checkpoint files")

        # Never delete checkpoints that have not been copied first.  If the
        # primary backup folder is already taken (the cell was run before),
        # back up into the next free numbered sibling instead of skipping the
        # copy -- skipping it used to destroy the current checkpoints.
        backup_target = backup_dir
        suffix = 1
        while os.path.exists(backup_target):
            backup_target = f"{backup_dir}_{suffix}"
            suffix += 1
        if backup_target != backup_dir:
            print(f"‚ÑπÔ∏è  Backup already exists at: {backup_dir}")

        shutil.copytree(checkpoint_dir, backup_target)
        print(f"‚úÖ Backed up old checkpoints to: {backup_target}")

        # Clear current checkpoints (safe now: a copy exists in backup_target)
        shutil.rmtree(checkpoint_dir)
        os.makedirs(checkpoint_dir, exist_ok=True)
        print(f"üóëÔ∏è  Cleared checkpoint directory for fresh training with CyclicLR")
    else:
        print("‚ÑπÔ∏è  No existing checkpoints found")
else:
    os.makedirs(checkpoint_dir, exist_ok=True)
    print("‚úÖ Created fresh checkpoint directory")

# NOTE(review): the values below are hard-coded text and do not match the
# Config cell (lr 5e-5, max_lr 1.5e-4, epochs 67) -- confirm which is intended.
print("\n" + "="*80)
print("‚úÖ READY TO START TRAINING WITH NEW LEARNING RATE SCHEDULE")
print("="*80)
print("\nüéØ Target Configuration:")
print(f"   Base LR: 1e-5")
print(f"   Max LR: 8e-4")
print(f"   Scheduler: CyclicLR (Triangular)")
print(f"   Batch Size: 24")
print(f"   Gradient Accumulation: 6 steps")
print(f"   Effective Batch Size: 144")
print(f"   Epochs: 50")
print(f"\nüéØ Expected Improvement: 70% ‚Üí 78-85% accuracy")
print("="*80)


üîß FIXING SCHEDULER COMPATIBILITY ISSUE

üì¶ Found 2 checkpoint files
‚ÑπÔ∏è  Backup already exists at: /content/drive/MyDrive/SkinDiseaseProject/checkpoints_old_scheduler_backup
üóëÔ∏è  Cleared checkpoint directory for fresh training with CyclicLR

‚úÖ READY TO START TRAINING WITH NEW LEARNING RATE SCHEDULE

üéØ Target Configuration:
   Base LR: 1e-5
   Max LR: 8e-4
   Scheduler: CyclicLR (Triangular)
   Batch Size: 24
   Gradient Accumulation: 6 steps
   Effective Batch Size: 144
   Epochs: 50

üéØ Expected Improvement: 70% ‚Üí 78-85% accuracy


In [None]:
# ============================================================
# CELL: Training Utilities with Auto-Save & Resume
# ============================================================

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import os
import time
import torch.serialization # Import serialization module

def _apply_accumulated_step(model, optimizer, scheduler, scaler, config, use_scaler,
                            advance_scheduler=True):
    """Clip gradients, take one optimizer step, clear grads, optionally step the scheduler."""
    if use_scaler:
        # Gradients were produced from a scaled loss: unscale before clipping.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
        scaler.step(optimizer)
        scaler.update()
    else:
        # No loss scaling happened, so the GradScaler must not be involved.
        torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
        optimizer.step()
    optimizer.zero_grad()
    if advance_scheduler:
        scheduler.step()


def train_one_epoch(model, train_loader, optimizer, scheduler, scaler, epoch, config):
    """Train for one epoch with gradient accumulation and non-finite-loss protection.

    Args:
        model: callable returning ``(logits, _, _)`` for an image batch.
        train_loader: iterable of ``(images, labels)``; a batch whose ``images``
            is ``None`` is skipped.
        optimizer, scheduler: stepped once per full accumulation window.
        scaler: GradScaler, used only when ``config.use_amp`` is true and
            ``config.device`` starts with ``'cuda'``; may be ``None`` otherwise.
        epoch: epoch number, used for log messages only.
        config: needs ``device``, ``use_amp``, ``gradient_accumulation_steps``,
            ``max_grad_norm``, ``log_every`` and ``epochs``.

    Returns:
        ``(epoch_loss, epoch_acc)``: mean per-sample loss and accuracy in percent
        over the batches that were actually trained on.

    Behaviour notes:
        * An optimizer step is taken after every ``gradient_accumulation_steps``
          *trained* batches; skipped batches do not count towards the window.
        * Gradients left over at the end of the epoch are applied instead of
          being discarded.  The scheduler is not advanced for that partial step,
          so the number of scheduler steps per epoch never grows.
        * A NaN/Inf loss drops the batch and the current accumulation window.
    """
    model.train()
    use_scaler = bool(config.use_amp) and config.device.startswith('cuda')
    accum_steps = config.gradient_accumulation_steps

    total_loss = 0.0   # sum of per-sample losses
    correct = 0
    total = 0          # number of samples trained on
    nan_count = 0
    pending = 0        # micro-batches accumulated since the last optimizer step

    optimizer.zero_grad()

    for batch_idx, (images, labels) in enumerate(train_loader):
        # Skip batch if collate_fn returned None
        if images is None:
            print(f"‚ö†Ô∏è Skipped empty batch at Epoch {epoch}, Batch {batch_idx}")
            continue

        images, labels = images.to(config.device), labels.to(config.device)

        if use_scaler:
            with torch.amp.autocast('cuda', enabled=config.use_amp):
                logits, local_attn, global_attn = model(images)
                loss = F.cross_entropy(logits, labels)
        else:
            logits, local_attn, global_attn = model(images)
            loss = F.cross_entropy(logits, labels)
        loss = loss / accum_steps

        # Non-finite loss: skip the batch instead of stopping
        if not torch.isfinite(loss):
            print(f"‚ö†Ô∏è NaN at Epoch {epoch}, Batch {batch_idx}, skipping...")
            nan_count += 1
            optimizer.zero_grad()
            pending = 0
            continue

        if use_scaler:
            scaler.scale(loss).backward()
        else:
            loss.backward()
        pending += 1

        if pending == accum_steps:
            _apply_accumulated_step(model, optimizer, scheduler, scaler, config, use_scaler)
            pending = 0

        batch_size = labels.size(0)
        # Undo the accumulation scaling and weight by batch size.
        total_loss += loss.item() * accum_steps * batch_size
        preds = torch.argmax(logits, dim=1)
        correct += (preds == labels).sum().item()
        total += batch_size

        if (batch_idx + 1) % config.log_every == 0:
            acc = 100. * correct / total if total > 0 else 0
            avg_loss = total_loss / total if total > 0 else 0
            print(f"  Epoch [{epoch}/{config.epochs}] "
                  f"Batch [{batch_idx+1}/{len(train_loader)}] "
                  f"Loss: {avg_loss:.4f} "
                  f"Acc: {acc:.2f}% "
                  f"LR: {optimizer.param_groups[0]['lr']:.6f}")

        if batch_idx % 100 == 0 and config.device.startswith('cuda'):
            torch.cuda.empty_cache()

    # Flush a trailing partial window so its gradients are not thrown away.
    if pending > 0:
        _apply_accumulated_step(model, optimizer, scheduler, scaler, config, use_scaler,
                                advance_scheduler=False)

    epoch_loss = total_loss / total if total > 0 else 0
    epoch_acc = 100. * correct / total if total > 0 else 0

    if nan_count > 0:
        print(f"‚ö†Ô∏è {nan_count} NaN batches skipped (training continues)")

    return epoch_loss, epoch_acc


def validate(model, val_loader, device):
    """Evaluate ``model`` on ``val_loader`` without gradient tracking.

    Args:
        model: Module whose forward pass returns a 3-tuple; only the first
            element (the class logits) is used here.
        val_loader: Iterable of ``(images, labels)`` batches. A batch whose
            ``images`` entry is ``None`` is skipped.
        device: Device the batch tensors are moved to before the forward pass.

    Returns:
        tuple: ``(val_loss, val_acc)`` where ``val_loss`` is the mean of the
        per-batch cross-entropy over the batches that were actually evaluated
        and ``val_acc`` is the accuracy in percent. Both are ``0`` when no
        batch was evaluated.
    """
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    evaluated_batches = 0

    with torch.no_grad():
        for images, labels in val_loader:
            # Skip batch if collate_fn returned None
            if images is None:
                continue

            images, labels = images.to(device), labels.to(device)
            logits, _, _ = model(images)
            loss = F.cross_entropy(logits, labels)
            total_loss += loss.item()
            preds = torch.argmax(logits, dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
            evaluated_batches += 1

    # Average over evaluated batches only: dividing by len(val_loader) would
    # count skipped batches as zero-loss batches and understate the loss.
    val_loss = total_loss / evaluated_batches if evaluated_batches > 0 else 0
    val_acc = 100. * correct / total if total > 0 else 0
    return val_loss, val_acc


def plot_training_curves(history, save_path):
    """Draw loss and accuracy curves side by side and save them to ``save_path``.

    ``history`` must provide the lists ``train_loss``, ``val_loss``,
    ``train_acc`` and ``val_acc``; the x-axis is the 1-based epoch index
    derived from the length of ``train_loss``.
    """
    epochs = range(1, len(history['train_loss']) + 1)
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # One row per panel:
    # (train key, train legend, val key, val legend, y-axis label, title)
    panels = (
        ('train_loss', 'Train Loss', 'val_loss', 'Val Loss',
         'Loss', 'Training & Validation Loss'),
        ('train_acc', 'Train Acc', 'val_acc', 'Val Acc',
         'Accuracy (%)', 'Training & Validation Accuracy'),
    )

    for ax, (train_key, train_legend, val_key, val_legend, y_label, title) in zip(axes, panels):
        ax.plot(epochs, history[train_key], 'b-o', label=train_legend, linewidth=2)
        ax.plot(epochs, history[val_key], 'r-s', label=val_legend, linewidth=2)
        ax.set_xlabel('Epoch', fontsize=12, fontweight='bold')
        ax.set_ylabel(y_label, fontsize=12, fontweight='bold')
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.legend(fontsize=11)
        ax.grid(alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"‚úÖ Curves saved: {save_path}")


def save_checkpoint(model, optimizer, scheduler, epoch, val_acc, history, config, label_names, is_best=False):
    """Persist the training state so a later run can resume from it.

    Args:
        model, optimizer, scheduler: Objects whose ``state_dict()`` is stored.
        epoch: Epoch number recorded in the checkpoint; also drives the
            periodic ``epoch_{epoch}.pth`` copy.
        val_acc: Validation accuracy recorded for this checkpoint.
        history: Training history stored as-is. When it is a dict with a
            ``'val_acc'`` list, that list is used to compute ``best_val_acc``.
        config: Object providing ``ckpt_dir``, ``best_ckpt_path`` and
            ``save_every``; it is stored in the checkpoint as well.
        label_names: Class names stored alongside the weights.
        is_best: When true, the checkpoint is also written to
            ``config.best_ckpt_path``.

    Files written:
        ``{ckpt_dir}/latest_checkpoint.pth`` always,
        ``config.best_ckpt_path`` when ``is_best``,
        ``{ckpt_dir}/epoch_{epoch}.pth`` when ``epoch % config.save_every == 0``.
    """
    # Best accuracy seen so far. The resume code looks for 'best_val_acc' and
    # otherwise falls back to the last epoch's 'val_acc', which understates
    # the best whenever the most recent epoch was not the best one.
    recorded_val_acc = list(history.get('val_acc', [])) if isinstance(history, dict) else []
    best_val_acc = max(recorded_val_acc + [val_acc])

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'val_acc': val_acc,
        'best_val_acc': best_val_acc,
        'history': history,
        'label_names': label_names,
        'config': config
    }

    os.makedirs(config.ckpt_dir, exist_ok=True)

    # Save latest checkpoint. Write to a temporary file first and swap it in,
    # so an interruption during the write cannot leave a truncated
    # latest_checkpoint.pth as the only resume point.
    latest_path = f"{config.ckpt_dir}/latest_checkpoint.pth"
    tmp_path = f"{latest_path}.tmp"
    torch.save(checkpoint, tmp_path)
    os.replace(tmp_path, latest_path)

    # Save best model
    if is_best:
        torch.save(checkpoint, config.best_ckpt_path)
        print(f"‚úÖ Best model saved: {val_acc:.2f}%")

    # Save periodic checkpoint
    if epoch % config.save_every == 0:
        epoch_path = f"{config.ckpt_dir}/epoch_{epoch}.pth"
        torch.save(checkpoint, epoch_path)
        print(f"üíæ Checkpoint saved: epoch_{epoch}.pth")


def load_checkpoint_if_exists(config):
    """Return the latest checkpoint from ``config.ckpt_dir``, or ``None``.

    The checkpoint is read from ``{config.ckpt_dir}/latest_checkpoint.pth`` and
    mapped onto ``config.device``. ``None`` is returned when that file does
    not exist.
    """
    latest_path = f"{config.ckpt_dir}/latest_checkpoint.pth"

    # Nothing to resume from: report it and bail out early.
    if not os.path.exists(latest_path):
        print("üí° No checkpoint found - starting fresh")
        return None

    print(f"‚úÖ Found checkpoint: {latest_path}")
    print("üîÑ Resuming training...")
    # The checkpoint embeds a Config instance, so the class has to be
    # allowlisted before torch.load will unpickle it.
    torch.serialization.add_safe_globals([Config])
    return torch.load(latest_path, map_location=config.device)


# Confirmation banner printed once the training utilities in this cell are defined.
print("\n".join([
    "="*60,
    "‚úÖ Auto-resume training utilities loaded!",
    "="*60,
    "Features:",
    "  ‚úÖ NaN handling (skips, doesn't stop)",
    "  ‚úÖ Auto-save every epoch",
    "  ‚úÖ Auto-resume if interrupted",
    "  ‚úÖ Periodic checkpoints",
    "="*60,
]))

‚úÖ Auto-resume training utilities loaded!
Features:
  ‚úÖ NaN handling (skips, doesn't stop)
  ‚úÖ Auto-save every epoch
  ‚úÖ Auto-resume if interrupted
  ‚úÖ Periodic checkpoints


In [None]:
# ============================================================
# CELL: Main Training with Auto-Resume (COMPLETE VERSION)
# ============================================================
from torch.optim.lr_scheduler import CyclicLR
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import os
import time
import torch.serialization

def main_training_continuous():
    """Train the model, resuming from the latest checkpoint when one exists.

    Steps:
        1. Build the data loaders, model, AdamW optimizer and CyclicLR
           scheduler from a fresh ``Config()``.
        2. If ``{config.ckpt_dir}/latest_checkpoint.pth`` exists, restore the
           model/optimizer/scheduler state, history and best accuracy from it.
        3. Run the epoch loop (train, validate, checkpoint, plot) with early
           stopping after ``config.patience`` epochs without improvement.
        4. Always write a final checkpoint, even after an interruption or
           error, then plot the curves.

    Returns:
        tuple: ``(model, history, label_names)``.
    """
    import traceback  # local import: only needed to report unexpected errors

    config = Config()

    # Inform the user about the device being used
    if torch.cuda.is_available():
        print(f"‚úÖ GPU is available! Using device: {config.device}")
        print(f"GPU Name: {torch.cuda.get_device_name(0)}")
        print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB\n")
    else:
        print(f"‚ùå GPU is NOT available. Using device: {config.device}\n")

    # Load data
    train_df = pd.read_csv(config.train_csv)
    test_df = pd.read_csv(config.test_csv)
    num_classes = train_df['label'].nunique()
    label_names = sorted(train_df['label'].unique())

    print(f"üìä Classes: {num_classes}, Train: {len(train_df):,}, Test: {len(test_df):,}\n")

    # Create datasets
    preprocessor = SimpleMedicalImagePreprocessor(img_size=config.img_size)
    train_transform = preprocessor.get_train_transforms()
    val_transform = preprocessor.get_val_transforms()

    # Custom collate function: drops samples whose image or label is None and
    # yields (None, None) when nothing is left, which the train/validate loops skip.
    def collate_fn(batch):
        batch = [(img, label) for img, label in batch if img is not None and label is not None]
        if not batch:
            return None, None
        return torch.utils.data.dataloader.default_collate(batch)

    train_dataset = CustomImageDataset(config.train_csv, transform=train_transform)
    val_dataset = CustomImageDataset(config.test_csv, transform=val_transform)

    train_loader = DataLoader(train_dataset, batch_size=config.batch_size,
                             shuffle=True, num_workers=config.num_workers,
                             pin_memory=config.pin_memory, drop_last=True,
                             collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size,
                           shuffle=False, num_workers=config.num_workers,
                           pin_memory=config.pin_memory,
                           collate_fn=collate_fn)

    # Create model
    print("üèóÔ∏è Building model...")
    model = MedicalAdaptiveViT(
        img_size=config.img_size, patch_size=config.patch_size,
        num_classes=num_classes, embed_dim=config.embed_dim,
        depth=config.depth, num_heads=config.num_heads,
        mlp_ratio=config.mlp_ratio, drop_rate=config.drop_rate,
        drop_path_rate=config.drop_path_rate,
        use_pathoscale=config.use_pathoscale,
        pathoscale_kernel_sizes=config.pathoscale_kernel_sizes,
        pathoscale_num_heads=config.pathoscale_num_heads,
        medical_prior=config.medical_prior,
        cross_scale_fusion=config.cross_scale_fusion,
        pathology_guided=config.pathology_guided
    ).to(config.device)

    print(f"‚úÖ Parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f}M\n")

    # Optimizer steps per epoch. Clamped to 1 so a loader shorter than the
    # accumulation window cannot hand CyclicLR a zero-length half cycle.
    steps_per_epoch = max(1, len(train_loader) // config.gradient_accumulation_steps)

    # Optimizer & scheduler
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.base_lr,
        weight_decay=config.weight_decay,
        betas=(0.9, 0.999)
    )

    scheduler = CyclicLR(
        optimizer,
        base_lr=config.base_lr,
        max_lr=config.max_lr,
        step_size_up=steps_per_epoch * 2,
        mode='triangular',
        cycle_momentum=False
    )

    # Initialize GradScaler (only when AMP is enabled on a CUDA device)
    scaler = None
    if config.use_amp and config.device.startswith('cuda'):
        scaler = torch.amp.GradScaler(enabled=True)

    # ============================================================
    # TRY TO RESUME FROM CHECKPOINT
    # ============================================================
    history_keys = ('train_loss', 'train_acc', 'val_loss', 'val_acc', 'learning_rates')

    checkpoint_path = os.path.join(config.ckpt_dir, 'latest_checkpoint.pth')
    start_epoch = 1
    best_val_acc = 0
    history = {key: [] for key in history_keys}

    if os.path.exists(checkpoint_path):
        try:
            print(f"‚úÖ Found checkpoint: {checkpoint_path}")
            print("üîÑ Resuming training...")

            # Add safe globals for Config class
            torch.serialization.add_safe_globals([Config])

            checkpoint = torch.load(checkpoint_path, map_location=config.device, weights_only=False)

            # Load model state
            model.load_state_dict(checkpoint['model_state_dict'], strict=False)
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

            # Load scheduler if exists
            if 'scheduler_state_dict' in checkpoint:
                scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

            start_epoch = checkpoint['epoch'] + 1
            history = checkpoint.get('history', history)
            # A restored history may lack a key the loop appends to.
            for key in history_keys:
                history.setdefault(key, [])

            # The latest checkpoint's 'val_acc' is the LAST epoch's accuracy,
            # not the best one. Take the maximum over the recorded history so
            # a later, worse epoch cannot overwrite the best-model file.
            stored_best = checkpoint.get('best_val_acc', checkpoint.get('val_acc', 0))
            best_val_acc = max([stored_best] + list(history['val_acc']))

            print(f"‚úÖ Resumed from epoch {checkpoint['epoch']}")
            print(f"üìä Best accuracy so far: {best_val_acc:.2f}%\n")

        except Exception as e:
            print(f"‚ö†Ô∏è Could not load checkpoint: {e}")
            print("üí° Starting training from scratch (epoch 1).\n")
            start_epoch = 1
            best_val_acc = 0
            history = {key: [] for key in history_keys}
    else:
        print("üí° No checkpoint found - starting fresh from epoch 1.\n")

    # ============================================================
    # TRAINING LOOP
    # ============================================================
    print("üöÄ Starting training...\n")
    print("="*80)

    start_time = time.time()
    patience_counter = 0
    # Epoch whose training AND validation have finished. The final save uses
    # this instead of the loop variable, so an epoch interrupted half-way is
    # re-run on resume rather than skipped.
    last_completed_epoch = start_epoch - 1
    # Accuracy recorded by the final save when no epoch completes in this run.
    val_acc = best_val_acc

    try:
        for epoch in range(start_epoch, config.epochs + 1):
            print(f"\n{'='*80}")
            print(f"EPOCH {epoch}/{config.epochs}")
            print(f"{'='*80}")

            # Train
            train_loss, train_acc = train_one_epoch(
                model, train_loader, optimizer, scheduler, scaler, epoch, config
            )

            # Validate
            val_loss, val_acc = validate(model, val_loader, config.device)

            # Update history
            history['train_loss'].append(train_loss)
            history['train_acc'].append(train_acc)
            history['val_loss'].append(val_loss)
            history['val_acc'].append(val_acc)
            history['learning_rates'].append(optimizer.param_groups[0]['lr'])
            last_completed_epoch = epoch

            # Print summary
            print(f"\n{'='*80}")
            print(f"üìä Epoch {epoch} Summary:")
            print(f"{'='*80}")
            print(f"Train: Loss={train_loss:.4f}, Acc={train_acc:.2f}%")
            print(f"Val:   Loss={val_loss:.4f}, Acc={val_acc:.2f}%")
            print(f"LR: {optimizer.param_groups[0]['lr']:.6f}")
            print(f"{'='*80}\n")

            # Save checkpoint
            is_best = val_acc > best_val_acc
            if is_best:
                best_val_acc = val_acc
                patience_counter = 0
            else:
                patience_counter += 1

            save_checkpoint(model, optimizer, scheduler, epoch, val_acc,
                          history, config, label_names, is_best)

            # Plot progress
            if epoch % 5 == 0 or is_best:
                plot_training_curves(history, config.plot_path)

            # Early stopping
            if patience_counter >= config.patience:
                print(f"‚ö†Ô∏è Early stopping at epoch {epoch}")
                break

            torch.cuda.empty_cache()

    except KeyboardInterrupt:
        print("\n‚ö†Ô∏è Training interrupted by user")
    except Exception as e:
        print(f"\n‚ùå Error: {e}")
        # Keep the best-effort behaviour (fall through to the final save), but
        # show where the failure happened instead of only its message.
        traceback.print_exc()
        print("üíæ Checkpoint saved - can resume later")
    finally:
        # Final save, labelled with the last fully completed epoch so that
        # history length and the stored epoch number stay in step.
        save_checkpoint(model, optimizer, scheduler, last_completed_epoch, val_acc,
                       history, config, label_names, False)

    total_time = time.time() - start_time
    print(f"\n{'='*80}")
    print(f"‚úÖ Training completed in {total_time/3600:.2f} hours")
    print(f"‚úÖ Best accuracy: {best_val_acc:.2f}%")
    print(f"{'='*80}\n")

    plot_training_curves(history, config.plot_path)

    return model, history, label_names


In [None]:
# ============================================================
# CELL: Complete Evaluation with All Presentation Visuals
# ============================================================

import seaborn as sns
from sklearn.metrics import (confusion_matrix, classification_report,
                            roc_curve, auc, precision_recall_curve)
from sklearn.preprocessing import label_binarize
import matplotlib.pyplot as plt

def evaluate_and_visualize():
    """Evaluate the best checkpoint on the test CSV and save all figures.

    Loads ``config.best_ckpt_path``, rebuilds the model from ``Config()``,
    runs it over ``config.test_csv`` and writes the confusion matrix,
    per-class accuracy, classification report, ROC curves, top-k accuracy,
    confidence analysis, sample predictions and (when the checkpoint carries
    a history) the training-history plot into ``config.results_dir``.

    Returns:
        tuple: ``(all_preds, all_labels, all_probs)`` as NumPy arrays.
    """

    config = Config()

    # Load best model
    print("üîç Loading best trained model...")
    # The checkpoint embeds a Config object and is produced by this notebook,
    # so it is loaded the same way as in the training cell. Without
    # weights_only=False the load fails on a fresh kernel where Config has not
    # been allowlisted.
    # NOTE(review): weights_only=False unpickles arbitrary objects; only use
    # it on checkpoints you created yourself.
    checkpoint = torch.load(config.best_ckpt_path, map_location=config.device,
                            weights_only=False)

    # Get label names
    train_df = pd.read_csv(config.train_csv)
    label_names = sorted(train_df['label'].unique())
    num_classes = len(label_names)

    # Recreate model
    model = MedicalAdaptiveViT(
        img_size=config.img_size,
        patch_size=config.patch_size,
        num_classes=num_classes,
        embed_dim=config.embed_dim,
        depth=config.depth,
        num_heads=config.num_heads,
        mlp_ratio=config.mlp_ratio,
        drop_rate=config.drop_rate,
        drop_path_rate=config.drop_path_rate,
        use_pathoscale=config.use_pathoscale,
        pathoscale_kernel_sizes=config.pathoscale_kernel_sizes,
        pathoscale_num_heads=config.pathoscale_num_heads,
        medical_prior=config.medical_prior,
        cross_scale_fusion=config.cross_scale_fusion,
        pathology_guided=config.pathology_guided
    ).to(config.device)

    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # Load test data
    # Use SimpleMedicalImagePreprocessor to get transforms
    preprocessor = SimpleMedicalImagePreprocessor(img_size=config.img_size)
    val_transform = preprocessor.get_val_transforms()

    # Same filtering collate function as the training cell: drops samples
    # whose image or label is None and returns (None, None) for an empty
    # batch, which is what the `images is None` check below relies on.
    def collate_fn(batch):
        batch = [(img, label) for img, label in batch if img is not None and label is not None]
        if not batch:
            return None, None
        return torch.utils.data.dataloader.default_collate(batch)

    val_dataset = CustomImageDataset(config.test_csv, transform=val_transform)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size,
                           shuffle=False, num_workers=config.num_workers,
                           collate_fn=collate_fn)

    # Get predictions
    print("üìä Generating predictions...")
    all_preds = []
    all_labels = []
    all_probs = []

    with torch.no_grad():
        for images, labels in val_loader:
            # Skip batch if collate_fn returned None
            if images is None:
                continue
            images = images.to(config.device)
            logits, _, _ = model(images)
            probs = F.softmax(logits, dim=1)
            preds = torch.argmax(probs, dim=1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.numpy())
            all_probs.extend(probs.cpu().numpy())

    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)
    all_probs = np.array(all_probs)

    # Create results directory
    os.makedirs(config.results_dir, exist_ok=True)

    print("\nüé® Generating visualizations...\n")

    # 1. Confusion Matrix
    print("1Ô∏è‚É£ Creating confusion matrix...")
    plot_confusion_matrix_pretty(all_labels, all_preds, label_names,
                                 f"{config.results_dir}/confusion_matrix.png")

    # 2. Per-class metrics
    print("2Ô∏è‚É£ Creating per-class accuracy...")
    plot_per_class_metrics(all_labels, all_preds, label_names,
                           f"{config.results_dir}/per_class_accuracy.png")

    # 3. Classification report
    print("3Ô∏è‚É£ Saving classification report...")
    save_classification_report(all_labels, all_preds, label_names,
                               f"{config.results_dir}/classification_report.txt")

    # 4. ROC curves
    print("4Ô∏è‚É£ Creating ROC curves...")
    plot_roc_curves_multiclass(all_labels, all_probs, label_names,
                               f"{config.results_dir}/roc_curves.png")

    # 5. Top-k accuracy
    print("5Ô∏è‚É£ Creating top-k accuracy...")
    plot_topk_accuracy(all_labels, all_probs,
                      f"{config.results_dir}/topk_accuracy.png")

    # 6. Confidence distribution
    print("6Ô∏è‚É£ Creating confidence distribution...")
    plot_confidence_analysis(all_probs, all_preds, all_labels,
                            f"{config.results_dir}/confidence_distribution.png")

    # 7. Sample predictions
    print("7Ô∏è‚É£ Creating sample predictions...")
    plot_sample_predictions(model, val_dataset, config.device, label_names,
                           f"{config.results_dir}/sample_predictions.png")

    # 8. Training history (if available)
    has_history = 'history' in checkpoint
    if has_history:
        print("8Ô∏è‚É£ Creating training history plots...")
        plot_detailed_history(checkpoint['history'],
                             f"{config.results_dir}/training_history.png")

    # Calculate overall metrics
    accuracy = (all_preds == all_labels).mean() * 100

    print(f"\n{'='*80}")
    print("üìà FINAL TEST RESULTS")
    print(f"{'='*80}")
    print(f"Overall Accuracy: {accuracy:.2f}%")
    print(f"Total Samples: {len(all_labels):,}")
    print(f"Number of Classes: {num_classes}")
    print(f"Best Validation Accuracy: {checkpoint['val_acc']:.2f}%")
    print(f"{'='*80}\n")

    print("‚úÖ All visualizations saved to:", config.results_dir)
    print("\nGenerated files:")
    print("  1. confusion_matrix.png")
    print("  2. per_class_accuracy.png")
    print("  3. classification_report.txt")
    print("  4. roc_curves.png")
    print("  5. topk_accuracy.png")
    print("  6. confidence_distribution.png")
    print("  7. sample_predictions.png")
    # Only list the history plot when it was actually produced above.
    if has_history:
        print("  8. training_history.png")
    print("  9. training_curves.png (saved during training)")

    return all_preds, all_labels, all_probs

def evaluate_and_visualize_comprehensive():
    """Enhanced evaluation with all metrics.

    NOTE(review): this function is an unfinished template. Its setup section
    is only a placeholder comment, so ``model``, ``val_loader``, ``config``
    and ``label_names`` are never assigned inside the function; as written
    they would have to resolve as globals, otherwise the first use raises
    ``NameError``. ``get_predictions`` and the ``plot_*`` helpers called in
    the "NEW VISUALIZATIONS" section are not defined in this part of the
    notebook either -- confirm they exist before calling this function.

    The function prints progress messages, writes one image per analysis into
    ``config.results_dir`` and has no return statement.
    """

    # TODO: placeholder -- the setup code (config, model, data loader, label
    # names) still has to be inserted here.
    # ... [Keep your existing setup code] ...

    # Get predictions (same as before)
    # presumably returns (predictions, labels, probabilities) arrays; verify against get_predictions
    all_preds, all_labels, all_probs = get_predictions(model, val_loader, config.device)

    print("\nüé® Generating comprehensive visualizations...\n")

    # Original visualizations
    print("1Ô∏è‚É£ Creating confusion matrix...")
    plot_confusion_matrix_pretty(all_labels, all_preds, label_names,
                                f"{config.results_dir}/confusion_matrix.png")

    print("2Ô∏è‚É£ Creating per-class accuracy...")
    plot_per_class_metrics(all_labels, all_preds, label_names,
                          f"{config.results_dir}/per_class_accuracy.png")

    # NEW VISUALIZATIONS
    print("3Ô∏è‚É£ Creating F1/Precision/Recall breakdown...")
    plot_f1_scores_per_class(all_labels, all_preds, label_names,
                            f"{config.results_dir}/f1_precision_recall.png")

    print("4Ô∏è‚É£ Creating macro/micro/weighted comparison...")
    plot_macro_micro_weighted_metrics(all_labels, all_preds, all_probs,
                                     f"{config.results_dir}/averaging_strategies.png")

    print("5Ô∏è‚É£ Creating Precision-Recall curves...")
    plot_precision_recall_curves(all_labels, all_probs, label_names,
                                f"{config.results_dir}/precision_recall_curves.png")

    print("6Ô∏è‚É£ Creating calibration analysis...")
    plot_calibration_curve(all_labels, all_probs,
                          f"{config.results_dir}/calibration_curve.png")

    print("7Ô∏è‚É£ Creating class imbalance analysis...")
    plot_class_imbalance_analysis(all_labels, all_preds, label_names,
                                  f"{config.results_dir}/class_imbalance_analysis.png")

    print("8Ô∏è‚É£ Creating misclassification heatmap...")
    plot_misclassification_matrix(all_labels, all_preds, label_names,
                                  f"{config.results_dir}/misclassification_heatmap.png")

    print("9Ô∏è‚É£ Creating MCC and balanced metrics...")
    plot_matthews_correlation_coefficient(all_labels, all_preds,
                                         f"{config.results_dir}/mcc_balanced_metrics.png")

    # TODO: placeholder -- the remaining plots (ROC, top-k, confidence,
    # samples, history) are not called here yet.
    # ... [Keep your existing visualizations: ROC, top-k, confidence, samples, history] ...

    print("\n‚úÖ All visualizations complete!")

# Helper visualization functions
def plot_confusion_matrix_pretty(y_true, y_pred, class_names, save_path):
    """Save raw-count and row-normalized confusion matrices side by side.

    Args:
        y_true: True class indices.
        y_pred: Predicted class indices.
        class_names: Display names; entry ``i`` labels class index ``i``.
        save_path: Output image path.

    Labels are assumed to be integer indices ``0 .. len(class_names) - 1``
    (the same assumption the ROC helper in this notebook makes).
    """
    # Pin the label set so the matrix always has one row/column per entry in
    # class_names. Without it sklearn only includes classes that occur in the
    # data, and the tick labels no longer line up with the matrix.
    labels = np.arange(len(class_names))
    cm = confusion_matrix(y_true, y_pred, labels=labels)

    # Row-normalize; rows with no true samples stay 0 instead of becoming NaN.
    row_totals = cm.sum(axis=1, keepdims=True)
    cm_norm = np.divide(cm, row_totals,
                        out=np.zeros(cm.shape, dtype=float),
                        where=row_totals != 0)

    fig, axes = plt.subplots(1, 2, figsize=(22, 9))

    # Raw counts
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=axes[0],
                xticklabels=class_names, yticklabels=class_names,
                cbar_kws={'label': 'Count'}, linewidths=0.5)
    axes[0].set_title('Confusion Matrix - Raw Counts', fontsize=16, fontweight='bold', pad=20)
    axes[0].set_xlabel('Predicted Label', fontsize=13, fontweight='bold')
    axes[0].set_ylabel('True Label', fontsize=13, fontweight='bold')
    plt.setp(axes[0].get_xticklabels(), rotation=45, ha='right', fontsize=9)
    plt.setp(axes[0].get_yticklabels(), rotation=0, fontsize=9)

    # Normalized
    sns.heatmap(cm_norm, annot=True, fmt='.2%', cmap='Greens', ax=axes[1],
                xticklabels=class_names, yticklabels=class_names,
                cbar_kws={'label': 'Proportion'}, linewidths=0.5)
    axes[1].set_title('Confusion Matrix - Normalized', fontsize=16, fontweight='bold', pad=20)
    axes[1].set_xlabel('Predicted Label', fontsize=13, fontweight='bold')
    axes[1].set_ylabel('True Label', fontsize=13, fontweight='bold')
    plt.setp(axes[1].get_xticklabels(), rotation=45, ha='right', fontsize=9)
    plt.setp(axes[1].get_yticklabels(), rotation=0, fontsize=9)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()


def plot_per_class_metrics(y_true, y_pred, class_names, save_path):
    """Save a horizontal bar chart of per-class accuracy, sorted ascending.

    Args:
        y_true: True class indices.
        y_pred: Predicted class indices.
        class_names: Display names; entry ``i`` labels class index ``i``.
        save_path: Output image path.

    Classes without any true sample are left out, because their accuracy is
    undefined. Labels are assumed to be integer indices
    ``0 .. len(class_names) - 1``.
    """
    # Pin the label set so matrix row i always corresponds to class_names[i];
    # otherwise a class missing from the data shifts every later name.
    labels = np.arange(len(class_names))
    cm = confusion_matrix(y_true, y_pred, labels=labels)

    support = cm.sum(axis=1)
    present = np.flatnonzero(support > 0)  # avoids 0/0 -> NaN bars
    per_class_acc = cm.diagonal()[present] / support[present]

    # Sort by accuracy
    sorted_idx = np.argsort(per_class_acc)
    sorted_names = [class_names[present[i]] for i in sorted_idx]
    sorted_acc = per_class_acc[sorted_idx]

    fig, ax = plt.subplots(figsize=(12, max(10, len(class_names) * 0.45)))
    colors = plt.cm.RdYlGn(sorted_acc)

    bars = ax.barh(range(len(sorted_names)), sorted_acc * 100,
                   color=colors, edgecolor='black', linewidth=1.2)
    ax.set_yticks(range(len(sorted_names)))
    ax.set_yticklabels(sorted_names, fontsize=10)
    ax.set_xlabel('Accuracy (%)', fontsize=14, fontweight='bold')
    ax.set_title('Per-Class Accuracy', fontsize=16, fontweight='bold', pad=20)
    ax.set_xlim(0, 100)
    ax.grid(axis='x', alpha=0.3, linestyle='--')

    # Add percentage labels just right of each bar
    for i, acc in enumerate(sorted_acc):
        ax.text(acc * 100 + 1.5, i, f'{acc*100:.1f}%',
               va='center', fontsize=10, fontweight='bold')

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()


def save_classification_report(y_true, y_pred, class_names, save_path):
    """Write sklearn's classification report to ``save_path`` as text.

    Args:
        y_true: True class indices.
        y_pred: Predicted class indices.
        class_names: Display names; entry ``i`` labels class index ``i``.
        save_path: Output text file path.

    Labels are assumed to be integer indices ``0 .. len(class_names) - 1``.
    """
    # Pass the label set explicitly: with target_names alone, sklearn pairs
    # the names with whichever classes occur in the data and rejects or
    # mislabels the report when a class is absent. zero_division=0 reports 0
    # for classes without samples/predictions instead of emitting warnings.
    report = classification_report(y_true, y_pred,
                                   labels=list(range(len(class_names))),
                                   target_names=class_names, digits=4,
                                   zero_division=0)

    rule = "="*100
    # Explicit encoding so non-ASCII class names do not depend on the
    # platform default.
    with open(save_path, 'w', encoding='utf-8') as f:
        f.write(rule + "\n")
        f.write("DETAILED CLASSIFICATION REPORT\n")
        f.write(rule + "\n\n")
        f.write(report)
        f.write("\n" + rule + "\n")


def plot_roc_curves_multiclass(y_true, y_probs, class_names, save_path, max_classes=10):
    """Save one-vs-rest ROC curves for the most frequent classes.

    Args:
        y_true: True class indices (non-negative integers).
        y_probs: Array of per-class probabilities, one column per class.
        class_names: Display names; entry ``i`` labels class index ``i``.
        save_path: Output image path.
        max_classes: Maximum number of curves; the classes with the most
            true samples are chosen.

    Classes that have no positive or no negative samples are skipped because
    their ROC curve is undefined.
    """
    n_classes = len(class_names)
    y_true = np.asarray(y_true)
    y_probs = np.asarray(y_probs)
    n_samples = len(y_true)

    # minlength keeps one count per class even when trailing classes are absent.
    class_counts = np.bincount(y_true, minlength=n_classes)

    # Ascending by frequency; keep only classes that can produce a curve.
    candidates = [c for c in np.argsort(class_counts)
                  if c < n_classes and 0 < class_counts[c] < n_samples]
    top_classes = candidates[-max_classes:] if max_classes > 0 else []

    fig, ax = plt.subplots(figsize=(12, 9))

    colors = plt.cm.tab10(np.linspace(0, 1, max(max_classes, 1)))

    for idx, i in enumerate(top_classes):
        # One-vs-rest indicator built directly: label_binarize returns a
        # single column for two classes, which breaks per-class indexing.
        is_class = (y_true == i).astype(int)
        fpr, tpr, _ = roc_curve(is_class, y_probs[:, i])
        roc_auc = auc(fpr, tpr)
        ax.plot(fpr, tpr, lw=2.5, color=colors[idx],
               label=f'{class_names[i][:25]} (AUC={roc_auc:.3f})')

    ax.plot([0, 1], [0, 1], 'k--', lw=2, label='Random (AUC=0.5)')
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate', fontsize=13, fontweight='bold')
    ax.set_ylabel('True Positive Rate', fontsize=13, fontweight='bold')
    # Report the number of curves actually drawn, which can be below max_classes.
    ax.set_title(f'ROC Curves - Top {len(top_classes)} Classes',
                fontsize=16, fontweight='bold', pad=20)
    ax.legend(loc='lower right', fontsize=9, framealpha=0.9)
    ax.grid(alpha=0.3, linestyle='--')

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()


def plot_topk_accuracy(y_true, y_probs, save_path):
    """Save a bar chart of top-k accuracy for k in {1, 3, 5, 10}.

    Args:
        y_true: True class indices, one per sample.
        y_probs: 2-D array of per-class scores, one row per sample.
        save_path: Output image path.

    Values of k larger than the number of classes are clamped to that number
    and de-duplicated, so every bar is labelled with the k that was actually
    evaluated.
    """
    y_true = np.asarray(y_true)
    y_probs = np.asarray(y_probs)
    n_classes = y_probs.shape[1]

    # Clamp first, then label: previously a clamped k was still drawn under
    # its original name (e.g. "Top-10" with only 7 classes).
    k_values = sorted({min(k, n_classes) for k in (1, 3, 5, 10)})

    # Rank the classes once; the last k columns are the k highest scores.
    ranked = np.argsort(y_probs, axis=1)
    accuracies = []
    for k in k_values:
        top_k_preds = ranked[:, -k:]
        correct = np.any(top_k_preds == y_true[:, None], axis=1)
        accuracies.append(correct.mean() * 100)

    fig, ax = plt.subplots(figsize=(10, 7))
    colors = ['#3498db', '#2ecc71', '#f39c12', '#e74c3c'][:len(k_values)]
    bars = ax.bar([f'Top-{k}' for k in k_values], accuracies,
                  color=colors, edgecolor='black', linewidth=1.5, width=0.6)

    for bar, acc in zip(bars, accuracies):
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height + 1,
                f'{acc:.2f}%', ha='center', va='bottom',
                fontsize=14, fontweight='bold')

    ax.set_ylabel('Accuracy (%)', fontsize=13, fontweight='bold')
    ax.set_title('Top-K Accuracy', fontsize=16, fontweight='bold', pad=20)
    ax.set_ylim(0, 105)
    ax.grid(axis='y', alpha=0.3, linestyle='--')

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()


def plot_confidence_analysis(y_probs, y_preds, y_true, save_path):
    """Save a three-panel figure relating prediction confidence to correctness.

    Confidence is the highest per-sample value in ``y_probs``. The panels show
    a histogram and a box plot of confidence for correct vs. incorrect
    predictions, and the accuracy within fixed confidence ranges.
    """
    confidence = np.max(y_probs, axis=1)
    is_correct = (y_preds == y_true)
    conf_right = confidence[is_correct]
    conf_wrong = confidence[~is_correct]

    fig, (ax_hist, ax_box, ax_bins) = plt.subplots(1, 3, figsize=(18, 5))

    # Panel 1: overlaid histograms
    for values, legend, colour in ((conf_right, 'Correct', 'green'),
                                   (conf_wrong, 'Incorrect', 'red')):
        ax_hist.hist(values, bins=40, alpha=0.7,
                     label=legend, color=colour, edgecolor='black')
    ax_hist.set_xlabel('Prediction Confidence', fontsize=12, fontweight='bold')
    ax_hist.set_ylabel('Count', fontsize=12, fontweight='bold')
    ax_hist.set_title('Confidence Distribution', fontsize=14, fontweight='bold')
    ax_hist.legend(fontsize=11)
    ax_hist.grid(alpha=0.3)

    # Panel 2: box plot per outcome
    box = ax_box.boxplot([conf_right, conf_wrong], labels=['Correct', 'Incorrect'],
                         patch_artist=True, widths=0.5)
    for patch, face in zip(box['boxes'][:2], ('lightgreen', 'lightcoral')):
        patch.set_facecolor(face)
    ax_box.set_ylabel('Prediction Confidence', fontsize=12, fontweight='bold')
    ax_box.set_title('Confidence by Correctness', fontsize=14, fontweight='bold')
    ax_box.grid(alpha=0.3, axis='y')

    # Panel 3: accuracy inside each confidence range
    edges = [0, 0.5, 0.7, 0.8, 0.9, 1.0]
    bin_labels = ['0-50%', '50-70%', '70-80%', '80-90%', '90-100%']
    final_bin = len(edges) - 2
    hits_per_bin = []
    size_per_bin = []

    for bin_idx, (lower, upper) in enumerate(zip(edges[:-1], edges[1:])):
        # Every range is half-open except the last, which also includes 1.0.
        if bin_idx == final_bin:
            below_upper = confidence <= upper
        else:
            below_upper = confidence < upper
        in_bin = (confidence >= lower) & below_upper
        hits_per_bin.append(is_correct[in_bin].sum())
        size_per_bin.append(in_bin.sum())

    accuracies = []
    for hits, size in zip(hits_per_bin, size_per_bin):
        accuracies.append(hits/size*100 if size > 0 else 0)

    ax_bins.bar(bin_labels, accuracies, color='skyblue', edgecolor='black', linewidth=1.5)
    ax_bins.set_xlabel('Confidence Range', fontsize=12, fontweight='bold')
    ax_bins.set_ylabel('Accuracy (%)', fontsize=12, fontweight='bold')
    ax_bins.set_title('Accuracy by Confidence Range', fontsize=14, fontweight='bold')
    ax_bins.grid(alpha=0.3, axis='y')

    for position, (acc, count) in enumerate(zip(accuracies, size_per_bin)):
        ax_bins.text(position, acc + 2, f'{acc:.1f}%\n(n={count})',
                     ha='center', fontsize=9)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()


def plot_sample_predictions(model, dataset, device, label_names, save_path, num_samples=16):
    """Save a grid of randomly chosen samples with true and predicted labels.

    Args:
        model: Module whose forward pass returns a 3-tuple; the first element
            (the logits) is used.
        dataset: Indexable dataset returning ``(image, label)`` pairs.
        device: Device used for the forward pass.
        label_names: Display names indexed by class index.
        save_path: Output image path.
        num_samples: Number of samples to draw, capped at ``len(dataset)``.

    The grid has four columns and as many rows as needed; with the default of
    16 samples this is the original 4x4 layout. Titles are green for correct
    predictions and red for incorrect ones.
    """
    model.eval()

    n_shown = min(num_samples, len(dataset))
    indices = np.random.choice(len(dataset), n_shown, replace=False)

    # Size the grid from the sample count: a fixed 4x4 grid overflows for
    # more than 16 samples and leaves empty framed axes for fewer.
    n_cols = 4
    n_rows = max(1, -(-n_shown // n_cols))  # ceiling division, at least one row
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(18, 4.5 * n_rows), squeeze=False)
    axes = axes.flatten()

    for idx, sample_idx in enumerate(indices):
        image, true_label = dataset[sample_idx]

        # NOTE(review): the training collate_fn filters out None images, which
        # suggests the dataset can return them -- skip such samples here.
        if image is None:
            axes[idx].axis('off')
            continue
        true_label = int(true_label)

        with torch.no_grad():
            image_batch = image.unsqueeze(0).to(device)
            logits, _, _ = model(image_batch)
            probs = F.softmax(logits, dim=1)
            pred_label = torch.argmax(probs).item()
            confidence = probs[0, pred_label].item()

        # Denormalize (channels-first -> channels-last, then undo mean/std).
        # presumably these mirror the Normalize statistics of the validation
        # transform -- confirm against get_val_transforms
        img_display = image.cpu().numpy().transpose(1, 2, 0)
        img_display = img_display * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
        img_display = np.clip(img_display, 0, 1)

        axes[idx].imshow(img_display)
        axes[idx].axis('off')

        color = 'green' if pred_label == true_label else 'red'
        true_name = label_names[true_label][:22]
        pred_name = label_names[pred_label][:22]

        title = f"True: {true_name}\nPred: {pred_name}\nConf: {confidence:.2%}"
        axes[idx].set_title(title, fontsize=10, color=color, fontweight='bold', pad=10)

    # Hide grid cells that received no sample.
    for unused_ax in axes[n_shown:]:
        unused_ax.axis('off')

    plt.suptitle('Sample Predictions (Green=Correct, Red=Incorrect)',
                fontsize=16, fontweight='bold', y=0.995)
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()


def plot_detailed_history(history, save_path):
    """Render a 2x2 training-history figure and write it to ``save_path``.

    Panels: train/val loss, train/val accuracy, learning rate on a log
    y-axis, and the train-minus-val accuracy gap with a zero reference line.

    Args:
        history: Mapping holding per-epoch sequences under 'train_loss',
            'val_loss', 'train_acc', 'val_acc' and 'learning_rates'.
        save_path: Destination handed to ``plt.savefig``.
    """
    label_style = dict(fontsize=12, fontweight='bold')
    title_style = dict(fontsize=14, fontweight='bold')

    fig, grid = plt.subplots(2, 2, figsize=(16, 12))
    (loss_ax, acc_ax), (lr_ax, gap_ax) = grid

    # One x position per recorded epoch, starting at 1.
    epoch_axis = range(1, len(history['train_loss']) + 1)

    # Top-left: loss curves
    loss_ax.plot(epoch_axis, history['train_loss'], 'b-o', label='Train Loss', linewidth=2)
    loss_ax.plot(epoch_axis, history['val_loss'], 'r-s', label='Val Loss', linewidth=2)
    loss_ax.set_ylabel('Loss', **label_style)
    loss_ax.set_title('Training & Validation Loss', **title_style)
    loss_ax.legend(fontsize=11)

    # Top-right: accuracy curves
    acc_ax.plot(epoch_axis, history['train_acc'], 'b-o', label='Train Acc', linewidth=2)
    acc_ax.plot(epoch_axis, history['val_acc'], 'r-s', label='Val Acc', linewidth=2)
    acc_ax.set_ylabel('Accuracy (%)', **label_style)
    acc_ax.set_title('Training & Validation Accuracy', **title_style)
    acc_ax.legend(fontsize=11)

    # Bottom-left: learning-rate schedule on a log scale
    lr_ax.plot(epoch_axis, history['learning_rates'], 'g-', linewidth=2)
    lr_ax.set_ylabel('Learning Rate', **label_style)
    lr_ax.set_title('Learning Rate Schedule', **title_style)
    lr_ax.set_yscale('log')

    # Bottom-right: train accuracy minus validation accuracy, per epoch
    acc_gap = np.array(history['train_acc']) - np.array(history['val_acc'])
    gap_ax.plot(epoch_axis, acc_gap, 'm-', linewidth=2)
    gap_ax.axhline(y=0, color='k', linestyle='--', alpha=0.5)
    gap_ax.set_ylabel('Accuracy Gap (%)', **label_style)
    gap_ax.set_title('Train-Val Accuracy Gap (Overfitting Indicator)', **title_style)

    # Decoration shared by all four panels.
    for panel in (loss_ax, acc_ax, lr_ax, gap_ax):
        panel.set_xlabel('Epoch', **label_style)
        panel.grid(alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()


# Confirmation shown once the cell's function definitions have executed;
# both lines go out through one print call, separated by a newline.
print(
    "‚úÖ Evaluation and visualization functions loaded!",
    "üí° After training completes, run: evaluate_and_visualize()",
    sep="\n",
)

‚úÖ Evaluation and visualization functions loaded!
üí° After training completes, run: evaluate_and_visualize()


In [None]:
# ============================================================
# JUST RUN THIS - IT WILL AUTO-RESUME!
# ============================================================

# Three-line banner (rule, message, rule) emitted through a single print call.
print("=" * 80, "üöÄ RESUMING TRAINING", "=" * 80, sep="\n")

# Start the training entry point and keep whatever it returns.
result = main_training_continuous()


üöÄ RESUMING TRAINING
‚úÖ GPU is available! Using device: cuda
GPU Name: Tesla T4
GPU Memory: 15.83 GB

üìä Classes: 22, Train: 33,531, Test: 8,383

üèóÔ∏è Building model...
‚úÖ Parameters: 20.10M

‚úÖ Found checkpoint: /content/drive/MyDrive/SkinDiseaseProject/checkpoints/latest_checkpoint.pth
üîÑ Resuming training...
‚úÖ Resumed from epoch 8
üìä Best accuracy so far: 49.59%

üöÄ Starting training...


EPOCH 9/50
  Epoch [9/50] Batch [300/1397] Loss: 1.6873 Acc: 48.53% LR: 0.000061
  Epoch [9/50] Batch [600/1397] Loss: 1.6926 Acc: 48.35% LR: 0.000072
  Epoch [9/50] Batch [900/1397] Loss: 1.7027 Acc: 48.03% LR: 0.000082
  Epoch [9/50] Batch [1200/1397] Loss: 1.7045 Acc: 47.96% LR: 0.000093

üìä Epoch 9 Summary:
Train: Loss=1.7072, Acc=47.84%
Val:   Loss=1.6688, Acc=49.45%
LR: 0.000100


EPOCH 10/50
  Epoch [10/50] Batch [300/1397] Loss: 1.7019 Acc: 48.10% LR: 0.000111
  Epoch [10/50] Batch [600/1397] Loss: 1.7099 Acc: 47.79% LR: 0.000122
  Epoch [10/50] Batch [900/1397] Loss: 1.

In [None]:
# Bind the return value instead of leaving the call as the cell's last bare
# expression: otherwise the notebook echoes the returned tuple of large arrays
# as the cell output. The recorded run shows a 3-tuple (two integer vectors and
# a float32 matrix) -- presumably labels, predictions and class probabilities;
# confirm against evaluate_and_visualize before unpacking by name.
eval_outputs = evaluate_and_visualize()

üîç Loading best trained model...
üìä Generating predictions...

üé® Generating visualizations...

1Ô∏è‚É£ Creating confusion matrix...
2Ô∏è‚É£ Creating per-class accuracy...
3Ô∏è‚É£ Saving classification report...
4Ô∏è‚É£ Creating ROC curves...
5Ô∏è‚É£ Creating top-k accuracy...
6Ô∏è‚É£ Creating confidence distribution...


  box = axes[1].boxplot(data, labels=['Correct', 'Incorrect'],


7Ô∏è‚É£ Creating sample predictions...
8Ô∏è‚É£ Creating training history plots...

üìà FINAL TEST RESULTS
Overall Accuracy: 74.89%
Total Samples: 8,383
Number of Classes: 22
Best Validation Accuracy: 74.89%

‚úÖ All visualizations saved to: /content/drive/MyDrive/SkinDiseaseProject/results

Generated files:
  1. confusion_matrix.png
  2. per_class_accuracy.png
  3. classification_report.txt
  4. roc_curves.png
  5. topk_accuracy.png
  6. confidence_distribution.png
  7. sample_predictions.png
  8. training_history.png
  9. training_curves.png (saved during training)


(array([ 0, 10, 12, ..., 17,  5,  4]),
 array([ 0, 10, 12, ..., 17,  5, 11]),
 array([[9.99857783e-01, 1.54017471e-06, 4.05175088e-05, ...,
         5.60537217e-07, 2.45184947e-06, 4.57607130e-05],
        [1.00097841e-05, 3.88609209e-07, 1.65369420e-06, ...,
         4.52972188e-07, 1.16377706e-07, 5.13406940e-06],
        [9.03824039e-05, 5.58994303e-04, 6.23662127e-05, ...,
         3.70370144e-05, 2.72501184e-05, 3.28606693e-05],
        ...,
        [5.91997057e-02, 1.77766255e-03, 1.51823973e-04, ...,
         6.12493721e-04, 6.75037736e-05, 3.75010632e-02],
        [5.74609487e-07, 1.30308422e-07, 4.02181115e-08, ...,
         2.67205223e-07, 1.09924407e-08, 7.36031254e-08],
        [2.50629466e-02, 4.05517220e-03, 3.92646855e-03, ...,
         6.57733483e-03, 7.08894106e-03, 4.97573940e-03]], dtype=float32))

In [None]:
# ============================================================
# CHECK FOR SAVED CHECKPOINTS
# ============================================================

import os

checkpoint_dir = '/content/drive/MyDrive/SkinDiseaseProject/checkpoints'

# Header banner: rule, title, rule.
print("=" * 60, "üîç CHECKING FOR SAVED CHECKPOINTS", "=" * 60, sep="\n")

# Handle the two "nothing to report" cases first, then list what was found.
if not os.path.exists(checkpoint_dir):
    print("\n‚ùå Checkpoint directory doesn't exist")
else:
    files = os.listdir(checkpoint_dir)
    if not files:
        print("\n‚ùå Checkpoint directory is empty")
    else:
        print(f"\n‚úÖ Found {len(files)} file(s):")
        for entry_name in files:
            entry_path = os.path.join(checkpoint_dir, entry_name)
            entry_mb = os.path.getsize(entry_path) / (1024 * 1024)  # bytes -> MiB
            print(f"   üìÅ {entry_name} ({entry_mb:.1f} MB)")

        # If a best checkpoint is present, load it on CPU and report its
        # stored epoch / validation accuracy ('?' and 0 when the keys are absent).
        best_ckpt = os.path.join(checkpoint_dir, 'best_checkpoint.pth')
        if os.path.exists(best_ckpt):
            import torch
            ckpt = torch.load(best_ckpt, map_location='cpu')
            print("\n‚úÖ Best checkpoint found!")
            print(f"   Epoch: {ckpt.get('epoch', '?')}")
            print(f"   Accuracy: {ckpt.get('val_acc', 0):.2f}%")

# Independently of the above, report on the optional backup directory.
backup_dir = '/content/drive/MyDrive/SkinDiseaseProject/checkpoints_backup'
if os.path.exists(backup_dir):
    print(f"\n‚úÖ Found backup directory: {backup_dir}")
    backup_files = os.listdir(backup_dir)
    print(f"   Contains {len(backup_files)} file(s)")


üîç CHECKING FOR SAVED CHECKPOINTS

‚úÖ Found 11 file(s):
   üìÅ best_checkpoint.pth (226.8 MB)
   üìÅ epoch_10.pth (226.8 MB)
   üìÅ epoch_15.pth (226.8 MB)
   üìÅ epoch_20.pth (226.8 MB)
   üìÅ epoch_25.pth (226.8 MB)
   üìÅ epoch_30.pth (226.8 MB)
   üìÅ epoch_35.pth (226.8 MB)
   üìÅ epoch_40.pth (226.8 MB)
   üìÅ epoch_45.pth (226.8 MB)
   üìÅ epoch_50.pth (226.8 MB)
   üìÅ latest_checkpoint.pth (226.8 MB)

‚úÖ Best checkpoint found!
   Epoch: 8
   Accuracy: 49.59%


In [None]:
# Run the training entry point (defined in an earlier cell) and keep its return
# value in `result`. In the recorded output for this cell it found
# latest_checkpoint.pth and continued from epoch 52 rather than starting over.
result = main_training_continuous()

‚úÖ GPU is available! Using device: cuda
GPU Name: Tesla T4
GPU Memory: 15.83 GB

üìä Classes: 22, Train: 33,531, Test: 8,383

üèóÔ∏è Building model...
‚úÖ Parameters: 20.10M

‚úÖ Found checkpoint: /content/drive/MyDrive/SkinDiseaseProject/checkpoints/latest_checkpoint.pth
üîÑ Resuming training...
‚úÖ Resumed from epoch 52
üìä Best accuracy so far: 70.40%

üöÄ Starting training...


EPOCH 53/67
  Epoch [53/67] Batch [300/1397] Loss: 0.6322 Acc: 78.53% LR: 0.000135
  Epoch [53/67] Batch [600/1397] Loss: 0.6257 Acc: 78.41% LR: 0.000125
  Epoch [53/67] Batch [900/1397] Loss: 0.6193 Acc: 78.83% LR: 0.000114
  Epoch [53/67] Batch [1200/1397] Loss: 0.6180 Acc: 79.00% LR: 0.000103

üìä Epoch 53 Summary:
Train: Loss=0.6169, Acc=79.01%
Val:   Loss=0.9984, Acc=72.98%
LR: 0.000096

‚úÖ Best model saved: 72.98%
‚úÖ Curves saved: /content/drive/MyDrive/SkinDiseaseProject/training_curves.png

EPOCH 54/67
  Epoch [54/67] Batch [300/1397] Loss: 0.5137 Acc: 82.31% LR: 0.000085
  Epoch [54/67] Bat

In [None]:
import torch
import torchvision.transforms as T

def tta_predict(model, image, device='cuda'):
    """Test-time augmentation: average class probabilities over five views.

    The views are the original image, a horizontal flip, a vertical flip and
    rotations by 90 and 180 degrees. Each view is run through the model
    separately, softmax is applied to its logits, and the per-view
    probabilities are averaged.

    Args:
        model: Network whose forward pass returns a 3-tuple with the class
            logits first. It is switched to eval mode.
        image: A single image tensor laid out as (C, H, W); the rotations act
            on dims 1 and 2 and the flips on the last two dims.
        device: Device each single-image batch is moved to before inference.

    Returns:
        1-D tensor of length ``num_classes`` holding the averaged
        probabilities, located on ``device``.
    """
    model.eval()

    # Deterministic tensor ops. torchvision's RandomHorizontalFlip /
    # RandomVerticalFlip with p=1.0 produce the same flips but draw from the
    # global torch RNG on every call, so inference would silently shift the
    # random stream of whatever runs next.
    view_fns = (
        lambda x: x,                              # Original
        lambda x: torch.flip(x, dims=[-1]),       # Flip horizontal (width axis)
        lambda x: torch.flip(x, dims=[-2]),       # Flip vertical (height axis)
        lambda x: torch.rot90(x, 1, [1, 2]),      # Rotate 90 degrees
        lambda x: torch.rot90(x, 2, [1, 2]),      # Rotate 180 degrees
    )

    predictions = []

    with torch.no_grad():
        for make_view in view_fns:
            # Build the view and add the batch dimension the model expects.
            view_batch = make_view(image).unsqueeze(0).to(device)

            logits, _, _ = model(view_batch)
            predictions.append(torch.softmax(logits, dim=1))

    # (num_views, num_classes) -> (num_classes,)
    avg_probs = torch.mean(torch.cat(predictions, dim=0), dim=0)

    return avg_probs
# Usage during testing
def load_test_image(image_path, image_size=224):
    """Read an image file and return a normalized (3, H, W) float tensor.

    This helper was called below without ever being defined, so the cell
    failed with a NameError.

    Args:
        image_path: Path to an image file readable by torchvision
            (JPEG/PNG/...).
        image_size: Side length the image is resized to.
            NOTE(review): 224 is an assumption -- set it to the resolution
            used by the notebook's validation transforms.

    Returns:
        float32 tensor of shape (3, image_size, image_size).
    """
    # Local import keeps the helper self-contained in this cell.
    from torchvision.io import ImageReadMode, read_image

    pixels = read_image(image_path, mode=ImageReadMode.RGB)  # uint8, (3, H, W)
    preprocess = T.Compose([
        T.Resize((image_size, image_size), antialias=True),
        T.ConvertImageDtype(torch.float32),  # uint8 [0, 255] -> float [0, 1]
        # Same mean/std constants that plot_sample_predictions reverses when
        # it de-normalizes images for display.
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return preprocess(pixels)


# NOTE(review): `model` must already be the trained network in this kernel and
# 'test.jpg' must exist in the working directory; neither is created here.
image = load_test_image('test.jpg')
final_prediction = tta_predict(model, image)
predicted_class = torch.argmax(final_prediction)



NameError: name 'load_test_image' is not defined