In [1]:
# ======================================================================
# Merge HAM10000 + ISIC2019 + Dermnet datasets (non-cancer) + split
# ======================================================================

import os
from glob import glob
import pandas as pd
from sklearn.model_selection import train_test_split

# -----------------------------
# 1) Base paths
# -----------------------------
HAM_PART1_PATH = r"E:\HAM10000_images_part_1"
HAM_PART2_PATH = r"E:\HAM10000_images_part_2"
HAM_META_PATH  = r"C:\Users\achim\OneDrive\Documents\HAM10000_metadata.csv"

ISIC_TRAIN_PATH = r"E:\ISIC_2019_Training_Input\train"
ISIC_TEST_PATH  = r"E:\ISIC_2019_Training_Input\test"
ISIC_FULL_PATH  = r"E:\ISIC_2019_Training_Input\ISIC_2019_Training_Input"
ISIC_META_PATH  = r"C:\Users\achim\OneDrive\Documents\ISIC_2019_Training_Metadata.csv"
ISIC_GT_PATH    = r"C:\Users\achim\OneDrive\Documents\ISIC_2019_Training_GroundTruth.csv"

DERM_TRAIN_PATH = r"E:\train"
DERM_TEST_PATH  = r"E:\test"

MASTER_CSV_PATH = "master_skin_dataset_non_cancer.csv"

# -----------------------------
# 2) Load HAM10000
# -----------------------------
def load_ham():
    meta = pd.read_csv(HAM_META_PATH)
    all_ham_images = glob(os.path.join(HAM_PART1_PATH, "*.jpg")) + glob(os.path.join(HAM_PART2_PATH, "*.jpg"))
    img_dict = {os.path.splitext(os.path.basename(p))[0]: p for p in all_ham_images}

    meta['image_path'] = meta['image_id'].map(img_dict)
    df = meta[['image_path','dx']].copy()
    df = df[df['image_path'].notna()]
    df.rename(columns={'dx':'original_label'}, inplace=True)
    print(f"HAM10000 images loaded: {len(df)}")
    return df

# -----------------------------
# 3) Load ISIC2019
# -----------------------------
def load_isic():
    meta = pd.read_csv(ISIC_META_PATH)
    gt   = pd.read_csv(ISIC_GT_PATH)

    # Merge meta + GT on 'image'
    df = meta.merge(gt, on='image', how='left')

    # Winner-takes-all label for ISIC 2019
    label_cols = ['MEL','NV','BCC','AK','BKL','DF','VASC','SCC']
    df['original_label'] = df[label_cols].idxmax(axis=1)

    # Map image filenames to actual paths
    all_isic_images = []
    for p in [ISIC_FULL_PATH, ISIC_TRAIN_PATH, ISIC_TEST_PATH]:
        all_isic_images.extend(glob(os.path.join(p, "*.jpg")))
    img_dict = {os.path.splitext(os.path.basename(p))[0]: p for p in all_isic_images}
    df['image_path'] = df['image'].map(img_dict)

    df = df[['image_path','original_label']].copy()
    df = df[df['image_path'].notna()]
    print(f"ISIC2019 images loaded: {len(df)}")
    return df

# -----------------------------
# 4) Load Dermnet
# -----------------------------
def load_dermnet():
    rows = []
    def collect(split_dir):
        for cls in sorted(os.listdir(split_dir)):
            cpath = os.path.join(split_dir, cls)
            if not os.path.isdir(cpath): 
                continue
            files = []
            for ext in ("*.jpg","*.jpeg","*.png","*.JPG","*.JPEG","*.PNG"):
                files.extend(glob(os.path.join(cpath, ext)))
            for fp in files:
                rows.append((fp, cls))
    collect(DERM_TRAIN_PATH)
    collect(DERM_TEST_PATH)
    df = pd.DataFrame(rows, columns=['image_path','original_label'])
    print(f"Dermnet images loaded: {len(df)}")
    return df

# -----------------------------
# 5) Merge all datasets
# -----------------------------
ham_df    = load_ham()
isic_df   = load_isic()
dermnet_df= load_dermnet()

master_df = pd.concat([ham_df,isic_df,dermnet_df],ignore_index=True)
print(f"Total images before filtering: {len(master_df)}")

# -----------------------------
# 6) Define unified mapping & cancer labels
# -----------------------------
UNIFIED_LABEL_MAPPING = {
    'BKL': 'Benign Keratosis',
    'DF': 'Dermatofibroma',
    'NV': 'Nevus',
    'VASC': 'Vascular Lesion',
    # DermNet
    'Acne and Rosacea Photos': 'Acne and Rosacea',
    'Atopic Dermatitis Photos': 'Eczema',
    'Bullous Disease Photos': 'Bullous Disease',
    'Cellulitis Impetigo and other Bacterial Infections': 'Bacterial Infections',
    'Eczema Photos': 'Eczema',
    'Exanthems and Drug Eruptions': 'Drug Eruptions',
    'Hair Loss Photos Alopecia and other Hair Diseases': 'Hair Disorders',
    'Herpes HPV and other STDs Photos': 'STDs and Viral Infections',
    'Light Diseases and Disorders of Pigmentation': 'Pigmentation Disorders',
    'Lupus and other Connective Tissue diseases': 'Connective Tissue Diseases',
    'Nail Fungus and other Nail Disease': 'Nail Disorders',
    'Poison Ivy Photos and other Contact Dermatitis': 'Contact Dermatitis',
    'Psoriasis pictures Lichen Planus and related diseases': 'Psoriasis and Lichen Planus',
    'Scabies Lyme Disease and other Infestations and Bites': 'Infestations and Bites',
    'Seborrheic Keratoses and other Benign Tumors': 'Benign Tumors',
    'Systemic Disease': 'Systemic Disease',
    'Tinea Ringworm Candidiasis and other Fungal Infections': 'Fungal Infections',
    'Urticaria Hives': 'Urticaria',
    'Vascular Tumors': 'Vascular Tumors',
    'Vasculitis Photos': 'Vasculitis',
    'Warts Molluscum and other Viral Infections': 'Viral Infections',
}

CANCER_LABELS_TO_EXCLUDE = {'MEL','BCC','SCC','AK',
                            'Melanoma Skin Cancer Nevi and Moles',
                            'Actinic Keratosis Basal Cell Carcinoma and other Malignant Lesions'}

# -----------------------------
# 7) Remove cancer data & map unified labels
# -----------------------------
cancer_mask = master_df['original_label'].isin(CANCER_LABELS_TO_EXCLUDE)
master_df = master_df[~cancer_mask].copy()
master_df['unified_label'] = master_df['original_label'].map(UNIFIED_LABEL_MAPPING)

# Remove any rows not mapped
master_df = master_df[master_df['unified_label'].notna()]

# Assign numeric label_id
master_df['label_id'] = master_df['unified_label'].astype('category').cat.codes

# Verify all image files exist
master_df = master_df[master_df['image_path'].apply(os.path.exists)]

# Save master CSV
master_df[['image_path','unified_label','label_id']].to_csv(MASTER_CSV_PATH, index=False)
print(f"Master CSV saved: {MASTER_CSV_PATH}")
print(f"Total images after non-cancer filtering: {len(master_df)}")

# -----------------------------
# 8) Train/Test split
# -----------------------------
train_df, test_df = train_test_split(master_df, test_size=0.2, stratify=master_df['label_id'], random_state=42)
train_df.to_csv("train_dataset.csv", index=False)
test_df.to_csv("test_dataset.csv", index=False)
print(f"Train images: {len(train_df)}, Test images: {len(test_df)}")


HAM10000 images loaded: 5000
ISIC2019 images loaded: 25331
Dermnet images loaded: 39118
Total images before filtering: 69449
Master CSV saved: master_skin_dataset_non_cancer.csv
Total images after non-cancer filtering: 51077
Train images: 40861, Test images: 10216


In [2]:
# ELSA + Adaptive Vision Transformer Model
# Optimized for RTX 3070 GPU

# =========================
# Install Dependencies (Run this first)
# =========================
# !pip install torch torchvision timm matplotlib pandas scikit-learn pillow

# =========================
# Imports
# =========================
import os
import time
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.autograd import Function
import torchvision.transforms as T
from torchvision.datasets import ImageFolder
from PIL import Image

from timm.layers import DropPath, trunc_normal_, Mlp
from timm.utils import accuracy

# Set device - Your RTX 3070
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

# =========================
# Custom Dataset for CSV files
# =========================
class CustomImageDataset(Dataset):
    def __init__(self, csv_file, transform=None):
        self.df = pd.read_csv(csv_file)
        self.transform = transform
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        img_path = self.df.iloc[idx]['image_path']
        label = self.df.iloc[idx]['label_id']
        
        # Load image
        image = Image.open(img_path).convert('RGB')
        
        if self.transform:
            image = self.transform(image)
            
        return image, label

# =========================
# ELSA Implementation
# =========================
class ELSAFunctionCUDA(Function):
    @staticmethod
    def forward(ctx, features, ghost_mul, ghost_add, h_attn,
                kernel_size=5, dilation=1, stride=1, version=''):
        # CPU fallback implementation for compatibility
        B, C, H, W = features.shape
        _pad = kernel_size // 2 * dilation
        features_unfolded = F.unfold(
            features, kernel_size=kernel_size, dilation=dilation, padding=_pad, stride=stride) \
            .reshape(B, C, kernel_size ** 2, H * W)

        if ghost_mul is not None:
            ghost_mul = ghost_mul.reshape(B, C, kernel_size ** 2, 1)
        if ghost_add is not None:
            ghost_add = ghost_add.reshape(B, C, kernel_size ** 2, 1)

        h_attn = h_attn.reshape(B, 1, kernel_size ** 2, H * W)

        # Compute filters
        if ghost_mul is not None and ghost_add is not None:
            filters = ghost_mul * h_attn + ghost_add
        elif ghost_mul is not None:
            filters = ghost_mul * h_attn
        elif ghost_add is not None:
            filters = h_attn + ghost_add
        else:
            filters = h_attn

        return (features_unfolded * filters).sum(2).reshape(B, C, H, W)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None, None, None, None, None, None, None

def elsa_op(features, ghost_mul, ghost_add, h_attn, lam, gamma,
            kernel_size=5, dilation=1, stride=1, version=''):

    _B, _C = features.shape[:2]
    ks = kernel_size

    if ghost_mul is not None:
        ghost_mul = ghost_mul ** lam if lam != 0 else None
    if ghost_add is not None:
        ghost_add = ghost_add * gamma if gamma != 0 else None

    B, C, H, W = features.shape
    _pad = kernel_size // 2 * dilation
    features_unfolded = F.unfold(
        features, kernel_size=kernel_size, dilation=dilation, padding=_pad, stride=stride) \
        .reshape(B, C, kernel_size ** 2, H * W)

    if ghost_mul is not None:
        ghost_mul = ghost_mul.reshape(B, C, kernel_size ** 2, 1)
    if ghost_add is not None:
        ghost_add = ghost_add.reshape(B, C, kernel_size ** 2, 1)

    h_attn = h_attn.reshape(B, 1, kernel_size ** 2, H * W)

    # Compute filters
    if ghost_mul is not None and ghost_add is not None:
        filters = ghost_mul * h_attn + ghost_add
    elif ghost_mul is not None:
        filters = ghost_mul * h_attn
    elif ghost_add is not None:
        filters = h_attn + ghost_add
    else:
        filters = h_attn

    return (features_unfolded * filters).sum(2).reshape(B, C, H, W)

class ELSA(nn.Module):
    """Enhanced Local Self-Attention"""
    def __init__(self, dim, num_heads, dim_qk=None, dim_v=None, kernel_size=5,
                 stride=1, dilation=1, qkv_bias=False, qk_scale=None,
                 attn_drop=0., proj_drop=0., group_width=8, groups=1, lam=1,
                 gamma=1, **kwargs):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.dim_qk = dim_qk or self.dim // 3 * 2
        self.dim_v = dim_v or dim
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation

        head_dim = self.dim_v // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        if self.dim_qk % group_width != 0:
            self.dim_qk = math.ceil(float(self.dim_qk) / group_width) * group_width

        self.group_width = group_width
        self.groups = groups
        self.lam = lam
        self.gamma = gamma

        self.pre_proj = nn.Conv2d(dim, self.dim_qk * 2 + self.dim_v, 1, bias=qkv_bias)
        self.attn = nn.Sequential(
            nn.Conv2d(self.dim_qk, self.dim_qk, kernel_size, padding=(kernel_size // 2)*dilation,
                      dilation=dilation, groups=self.dim_qk // group_width),
            nn.GELU(),
            nn.Conv2d(self.dim_qk, kernel_size ** 2 * num_heads, 1, groups=groups))

        if self.lam != 0 and self.gamma != 0:
            ghost_mul = torch.randn(1, 1, self.dim_v, kernel_size, kernel_size)
            ghost_add = torch.zeros(1, 1, self.dim_v, kernel_size, kernel_size)
            trunc_normal_(ghost_add, std=.02)
            self.ghost_head = nn.Parameter(torch.cat((ghost_mul, ghost_add), dim=0), requires_grad=True)
        elif self.lam == 0 and self.gamma != 0:
            ghost_add = torch.zeros(1, self.dim_v, kernel_size, kernel_size)
            trunc_normal_(ghost_add, std=.02)
            self.ghost_head = nn.Parameter(ghost_add, requires_grad=True)
        elif self.lam != 0 and self.gamma == 0:
            ghost_mul = torch.randn(1, self.dim_v, kernel_size, kernel_size)
            self.ghost_head = nn.Parameter(ghost_mul, requires_grad=True)
        else:
            self.ghost_head = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.post_proj = nn.Linear(self.dim_v, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, mask=None):
        B, H, W, _ = x.shape
        C = self.dim_v
        ks = self.kernel_size
        G = self.num_heads
        x = x.permute(0, 3, 1, 2)  # B, C, H, W

        qkv = self.pre_proj(x)
        q, k, v = torch.split(qkv, (self.dim_qk, self.dim_qk, self.dim_v), dim=1)
        hadamard_product = q * k * self.scale

        if self.stride > 1:
            hadamard_product = F.avg_pool2d(hadamard_product, self.stride)

        h_attn = self.attn(hadamard_product)
        v = v.reshape(B * G, C // G, H, W)
        h_attn = h_attn.reshape(B * G, -1, H, W).softmax(1)
        h_attn = self.attn_drop(h_attn)

        ghost_mul = None
        ghost_add = None
        if self.lam != 0 and self.gamma != 0:
            gh = self.ghost_head.expand(2, B, C, ks, ks).reshape(2, B * G, C // G, ks, ks)
            ghost_mul, ghost_add = gh[0], gh[1]
        elif self.lam == 0 and self.gamma != 0:
            ghost_add = self.ghost_head.expand(B, C, ks, ks).reshape(B * G, C // G, ks, ks)
        elif self.lam != 0 and self.gamma == 0:
            ghost_mul = self.ghost_head.expand(B, C, ks, ks).reshape(B * G, C // G, ks, ks)

        x = elsa_op(v, ghost_mul, ghost_add, h_attn, self.lam, self.gamma,
                    self.kernel_size, self.dilation, self.stride)
        x = x.reshape(B, C, H // self.stride, W // self.stride)
        x = self.post_proj(x.permute(0, 2, 3, 1))  # B, H, W, C
        x = self.proj_drop(x)
        return x

class ELSABlock(nn.Module):
    """ELSA block: ELSA + MLP"""
    def __init__(self, dim, kernel_size, stride=1, num_heads=1, mlp_ratio=3.,
                 drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, 
                 norm_layer=nn.LayerNorm, qkv_bias=False, qk_scale=1, 
                 dim_qk=None, dim_v=None, lam=1, gamma=1, dilation=1, 
                 group_width=8, groups=1, **kwargs):
        super().__init__()
        assert stride == 1
        self.dim = dim
        self.norm1 = norm_layer(dim)
        self.attn = ELSA(dim, num_heads, dim_qk=dim_qk, dim_v=dim_v, 
                         kernel_size=kernel_size, stride=stride, dilation=dilation,
                         qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
                         group_width=group_width, groups=groups, lam=lam, gamma=gamma)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, 
                       act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

# =========================
# Vision Transformer Components
# =========================
class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=384):
        super().__init__()
        assert img_size % patch_size == 0
        self.num_patches = (img_size // patch_size) * (img_size // patch_size)
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)                 # (B, C, H/ps, W/ps)
        x = x.flatten(2).transpose(1, 2) # (B, N, C)
        return x

class AdaptiveAttention(nn.Module):
    def __init__(self, dim, num_heads=6, qkv_bias=True, attn_drop=0., proj_drop=0.,
                 ada_head=False, head_select_tau=5.0):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.ada_head = ada_head
        self.head_select_tau = head_select_tau

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        if ada_head:
            self.head_select = nn.Linear(dim, num_heads)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn_scores = (q @ k.transpose(-2, -1)) * self.scale

        head_policy = None
        if self.ada_head:
            cls_embed = x[:, 0]
            logits = self.head_select(cls_embed)
            head_policy = F.gumbel_softmax(logits / self.head_select_tau, hard=True, dim=-1)
            attn_scores = attn_scores * head_policy.unsqueeze(-1).unsqueeze(-1)

        attn = attn_scores.softmax(dim=-1)
        attn = self.attn_drop(attn)

        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        out = self.proj(out)
        out = self.proj_drop(out)
        return out, head_policy

class AdaptiveBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=True, drop=0., 
                 attn_drop=0., drop_path=0., ada_head=False, head_select_tau=5.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = AdaptiveAttention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, 
            proj_drop=drop, ada_head=ada_head, head_select_tau=head_select_tau
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0 else nn.Identity()
        self.norm2 = nn.LayerNorm(dim)

        hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(drop),
        )

    def forward(self, x):
        attn_in = self.norm1(x)
        attn_out, head_policy = self.attn(attn_in)
        x = x + self.drop_path(attn_out)
        
        mlp_in = self.norm2(x)
        x = x + self.drop_path(self.mlp(mlp_in))
        
        return x, head_policy

# =========================
# Main Model: ELSA + Adaptive ViT
# =========================
class AdaptiveViTWithELSA(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000,
                 embed_dim=384, depth=8, num_heads=6, mlp_ratio=4.,
                 drop_rate=0.1, drop_path_rate=0.1, ada_head=True, ada_layer=True,
                 head_select_tau=5.0, layer_select_tau=5.0, use_elsa=True, 
                 elsa_kernel_size=5, elsa_num_heads=6, elsa_mlp_ratio=3.0, 
                 elsa_lam=1.0, elsa_gamma=1.0):
        super().__init__()
        self.num_classes = num_classes
        self.embed_dim = embed_dim
        self.depth = depth
        self.ada_head = ada_head
        self.ada_layer = ada_layer
        self.layer_select_tau = layer_select_tau
        self.use_elsa = use_elsa

        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        # ELSA block after patch embedding
        if use_elsa:
            self.elsa_block = ELSABlock(
                dim=embed_dim, kernel_size=elsa_kernel_size, num_heads=elsa_num_heads,
                mlp_ratio=elsa_mlp_ratio, drop=drop_rate, attn_drop=0.0,
                drop_path=0.0, lam=elsa_lam, gamma=elsa_gamma
            )

        # Adaptive Vision Transformer blocks
        dpr = torch.linspace(0, drop_path_rate, steps=depth).tolist()
        self.blocks = nn.ModuleList([
            AdaptiveBlock(
                embed_dim, num_heads, mlp_ratio=mlp_ratio, qkv_bias=True, 
                drop=drop_rate, attn_drop=0.0, drop_path=dpr[i], ada_head=ada_head,
                head_select_tau=head_select_tau
            ) for i in range(depth)
        ])

        if ada_layer:
            self.layer_select = nn.Linear(embed_dim, depth)

        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        trunc_normal_(self.pos_embed, std=0.02)
        trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_features(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)
        cls = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls, x], dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        # Apply ELSA after patch embedding
        if self.use_elsa:
            patch_dim = int(math.sqrt(x.shape[1] - 1))
            cls_token = x[:, 0:1, :]
            patch_tokens = x[:, 1:, :]
            
            # Reshape to spatial format for ELSA
            patch_tokens = patch_tokens.reshape(B, patch_dim, patch_dim, self.embed_dim)
            patch_tokens = self.elsa_block(patch_tokens)
            patch_tokens = patch_tokens.reshape(B, patch_dim * patch_dim, self.embed_dim)
            
            # Concatenate CLS token back
            x = torch.cat([cls_token, patch_tokens], dim=1)

        head_policies = []
        layer_policy = None

        if self.ada_layer:
            with torch.no_grad():
                logits = self.layer_select(x[:, 0])
                layer_policy = F.gumbel_softmax(logits / self.layer_select_tau, hard=True, dim=-1)

        for i, blk in enumerate(self.blocks):
            if self.ada_layer and layer_policy is not None:
                if (layer_policy[:, i].sum() == 0):
                    head_policies.append(None)
                    continue
            x, h_pol = blk(x)
            head_policies.append(h_pol)

        x = self.norm(x)
        return x[:, 0], head_policies, layer_policy

    def forward(self, x):
        feats, head_policies, layer_policy = self.forward_features(x)
        logits = self.head(feats)

        head_select = None
        if self.ada_head and any(p is not None for p in head_policies):
            valid = [p for p in head_policies if p is not None]
            if len(valid) > 0:
                head_select = torch.stack(valid, dim=1).mean(dim=1)

        return logits, head_select, layer_policy

# =========================
# Loss Function
# =========================
class AdaptiveEfficiencyLoss(nn.Module):
    def __init__(self, class_counts=None, smoothing=0.1):
        super().__init__()
        if class_counts is not None:
            class_weights = 1.0 / torch.tensor(class_counts, dtype=torch.float32)
            class_weights = class_weights / class_weights.sum()
            self.ce = nn.CrossEntropyLoss(weight=class_weights.to(device), label_smoothing=smoothing)
        else:
            self.ce = nn.CrossEntropyLoss(label_smoothing=smoothing)

    def forward(self, logits, targets):
        return self.ce(logits, targets)

# =========================
# Data Transforms
# =========================
def get_transforms(img_size=224):
    train_transform = T.Compose([
        T.Resize((img_size, img_size)),
        T.RandomRotation(15),
        T.RandomHorizontalFlip(p=0.5),
        T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        T.ToTensor(),
        T.RandomErasing(p=0.2),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    
    val_transform = T.Compose([
        T.Resize((img_size, img_size)),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    
    return train_transform, val_transform

# =========================
# Training Functions
# =========================
def train_one_epoch(model, loader, optimizer, loss_fn, device, epoch, scheduler=None, log_interval=50):
    model.train()
    losses, top1s = [], []
    head_ratios, layer_ratios = [], []

    for i, (x, y) in enumerate(loader):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        
        logits, head_policy, layer_policy = model(x)
        loss = loss_fn(logits, y)

        if head_policy is not None:
            head_ratios.append(head_policy.mean().item())
        if layer_policy is not None:
            layer_ratios.append(layer_policy.float().mean().item())

        loss.backward()
        optimizer.step()

        if scheduler:
            scheduler.step()

        acc1, _ = accuracy(logits, y, topk=(1, 5))
        losses.append(loss.item())
        top1s.append(acc1.item())

        if (i + 1) % log_interval == 0:
            print(f"Epoch {epoch} | Step {i+1}/{len(loader)} | "
                  f"Loss {np.mean(losses):.4f} | Acc@1 {np.mean(top1s):.2f}%")

    return {
        "loss": float(np.mean(losses)),
        "acc1": float(np.mean(top1s)),
        "head": float(np.mean(head_ratios)) if head_ratios else 0.0,
        "layer": float(np.mean(layer_ratios)) if layer_ratios else 0.0,
    }

@torch.no_grad()
def validate(model, loader, loss_fn, device):
    model.eval()
    losses, top1s = [], []
    head_ratios, layer_ratios = [], []

    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits, head_policy, layer_policy = model(x)
        loss = loss_fn(logits, y)

        if head_policy is not None:
            head_ratios.append(head_policy.mean().item())
        if layer_policy is not None:
            layer_ratios.append(layer_policy.float().mean().item())

        acc1, _ = accuracy(logits, y, topk=(1, 5))
        losses.append(loss.item())
        top1s.append(acc1.item())

    return {
        "val_loss": float(np.mean(losses)),
        "val_acc1": float(np.mean(top1s)),
        "val_head": float(np.mean(head_ratios)) if head_ratios else 0.0,
        "val_layer": float(np.mean(layer_ratios)) if layer_ratios else 0.0,
    }

def plot_curves(history, save_path="training_curves.png"):
    epochs = len(history['train']['loss'])
    x = range(1, epochs + 1)

    plt.figure(figsize=(14, 5))
    
    plt.subplot(1, 2, 1)
    plt.plot(x, history['train']['loss'], label='Train Loss', marker='o')
    plt.plot(x, history['val']['loss'], label='Val Loss', marker='s')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Loss Curves')
    plt.grid(True)
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(x, history['train']['acc'], label='Train Acc@1', marker='o')
    plt.plot(x, history['val']['acc'], label='Val Acc@1', marker='s')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.title('Accuracy Curves')
    plt.grid(True)
    plt.legend()

    plt.tight_layout()
    plt.savefig(save_path, dpi=200, bbox_inches='tight')
    plt.show()

# =========================
# Configuration Class
# =========================
class Config:
    # Data
    train_csv = "train_dataset.csv"
    test_csv = "test_dataset.csv"
    img_size = 224
    batch_size = 8  # Reduced for RTX 3070 8GB VRAM
    num_workers = 4

    # Model architecture
    patch_size = 16
    embed_dim = 384
    depth = 8
    num_heads = 6
    mlp_ratio = 4.0
    drop_rate = 0.1
    drop_path_rate = 0.1

    # ELSA parameters
    use_elsa = True
    elsa_kernel_size = 5
    elsa_num_heads = 6
    elsa_mlp_ratio = 3.0
    elsa_lam = 1.0
    elsa_gamma = 1.0

    # Training
    epochs = 25
    lr = 3e-4
    weight_decay = 0.05
    max_lr = 1e-3
    
    # Save paths
    ckpt_path = "best_elsa_adaptive_vit.pth"
    plot_path = "training_curves.png"

# Initialize config
config = Config()

print("✅ Model code loaded successfully!")
print("📝 Next steps:")
print("1. Make sure your train_dataset.csv and test_dataset.csv are in the current directory")
print("2. Run the training code below")

Using device: cuda
GPU: NVIDIA GeForce RTX 3070 Laptop GPU
GPU Memory: 8.0 GB
✅ Model code loaded successfully!
📝 Next steps:
1. Make sure your train_dataset.csv and test_dataset.csv are in the current directory
2. Run the training code below


In [4]:
# =========================
# Training Script for ELSA + Adaptive Vision Transformer
# Optimized for RTX 3070 GPU
# =========================

def main():
    """Main training function"""
    
    # Check if CSV files exist
    if not os.path.exists(config.train_csv):
        print(f"❌ Error: {config.train_csv} not found!")
        print("Please make sure your train_dataset.csv is in the current directory")
        return
    
    if not os.path.exists(config.test_csv):
        print(f"❌ Error: {config.test_csv} not found!")
        print("Please make sure your test_dataset.csv is in the current directory")
        return
    
    print("📂 Loading datasets...")
    
    # Get data transforms
    train_transform, val_transform = get_transforms(config.img_size)
    
    # Create datasets
    train_dataset = CustomImageDataset(config.train_csv, transform=train_transform)
    val_dataset = CustomImageDataset(config.test_csv, transform=val_transform)
    
    print(f"📊 Dataset Info:")
    print(f"   Training samples: {len(train_dataset)}")
    print(f"   Validation samples: {len(val_dataset)}")
    
    # Get number of classes from the dataset
    train_df = pd.read_csv(config.train_csv)
    num_classes = train_df['label_id'].nunique()
    class_counts = train_df['label_id'].value_counts().sort_index().values
    
    print(f"   Number of classes: {num_classes}")
    print(f"   Class distribution: {class_counts}")
    
    # Create data loaders
    train_loader = DataLoader(
        train_dataset, 
        batch_size=config.batch_size, 
        shuffle=True,
        num_workers=config.num_workers, 
        pin_memory=True
    )
    
    val_loader = DataLoader(
        val_dataset, 
        batch_size=config.batch_size, 
        shuffle=False,
        num_workers=config.num_workers, 
        pin_memory=True
    )
    
    print(f"🔧 Building model...")
    
    # Create model
    model = AdaptiveViTWithELSA(
        img_size=config.img_size,
        patch_size=config.patch_size,
        num_classes=num_classes,
        embed_dim=config.embed_dim,
        depth=config.depth,
        num_heads=config.num_heads,
        mlp_ratio=config.mlp_ratio,
        drop_rate=config.drop_rate,
        drop_path_rate=config.drop_path_rate,
        use_elsa=config.use_elsa,
        elsa_kernel_size=config.elsa_kernel_size,
        elsa_num_heads=config.elsa_num_heads,
        elsa_mlp_ratio=config.elsa_mlp_ratio,
        elsa_lam=config.elsa_lam,
        elsa_gamma=config.elsa_gamma
    ).to(device)
    
    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    
    print(f"🏗️  Model Info:")
    print(f"   Total parameters: {total_params:,}")
    print(f"   Trainable parameters: {trainable_params:,}")
    print(f"   Model size: ~{total_params * 4 / 1024**2:.1f} MB")
    
    # Create optimizer and scheduler
    optimizer = torch.optim.AdamW(
        model.parameters(), 
        lr=config.lr, 
        weight_decay=config.weight_decay
    )
    
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, 
        max_lr=config.max_lr, 
        steps_per_epoch=len(train_loader), 
        epochs=config.epochs
    )
    
    # Create loss function with class balancing
    loss_fn = AdaptiveEfficiencyLoss(class_counts=class_counts.tolist())
    
    print(f"🚀 Starting training...")
    print(f"   Device: {device}")
    print(f"   Batch size: {config.batch_size}")
    print(f"   Epochs: {config.epochs}")
    print(f"   Learning rate: {config.lr}")
    print(f"   Max learning rate: {config.max_lr}")
    
    # Training history
    history = {
        "train": {"loss": [], "acc": []},
        "val": {"loss": [], "acc": []},
    }
    
    best_acc = 0.0
    start_time = time.time()
    
    # Training loop
    for epoch in range(1, config.epochs + 1):
        print(f"\n{'='*50}")
        print(f"Epoch {epoch}/{config.epochs}")
        print('='*50)
        
        # Train
        train_info = train_one_epoch(
            model, train_loader, optimizer, loss_fn, device, epoch, scheduler
        )
        
        # Validate
        val_info = validate(model, val_loader, loss_fn, device)
        
        # Print epoch results
        print(f"\n📊 Epoch {epoch} Results:")
        print(f"   Train Loss: {train_info['loss']:.4f} | Train Acc@1: {train_info['acc1']:.2f}%")
        print(f"   Val Loss:   {val_info['val_loss']:.4f} | Val Acc@1:   {val_info['val_acc1']:.2f}%")
        
        if train_info['head'] > 0:
            print(f"   Head Selection: {train_info['head']:.3f} | Layer Selection: {train_info['layer']:.3f}")
        
        # Update history
        history['train']['loss'].append(train_info['loss'])
        history['train']['acc'].append(train_info['acc1'])
        history['val']['loss'].append(val_info['val_loss'])
        history['val']['acc'].append(val_info['val_acc1'])
        
        # Save best model
        if val_info['val_acc1'] > best_acc:
            best_acc = val_info['val_acc1']
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'best_acc': best_acc,
                'config': config,
            }, config.ckpt_path)
            print(f"   💾 Saved new best model! (Acc@1: {best_acc:.2f}%)")
        
        # GPU memory info
        if torch.cuda.is_available():
            memory_used = torch.cuda.max_memory_allocated() / 1024**3
            print(f"   🔥 GPU Memory: {memory_used:.2f} GB / 8.0 GB")
    
    # Training completed
    total_time = time.time() - start_time
    print(f"\n🎉 Training completed!")
    print(f"   Total time: {total_time/3600:.2f} hours")
    print(f"   Best validation accuracy: {best_acc:.2f}%")
    print(f"   Model saved to: {config.ckpt_path}")
    
    # Plot training curves
    print(f"\n📈 Plotting training curves...")
    plot_curves(history, config.plot_path)
    print(f"   Curves saved to: {config.plot_path}")
    
    return model, history, best_acc

def test_model():
    """Test the trained model"""
    if not os.path.exists(config.ckpt_path):
        print(f"❌ No saved model found at {config.ckpt_path}")
        return
    
    print("🔍 Loading best model for testing...")
    
    # Load the saved model
    checkpoint = torch.load(config.ckpt_path)
    
    # Get number of classes
    train_df = pd.read_csv(config.train_csv)
    num_classes = train_df['label_id'].nunique()
    
    # Recreate model
    model = AdaptiveViTWithELSA(
        img_size=config.img_size,
        patch_size=config.patch_size,
        num_classes=num_classes,
        embed_dim=config.embed_dim,
        depth=config.depth,
        num_heads=config.num_heads,
        mlp_ratio=config.mlp_ratio,
        drop_rate=config.drop_rate,
        drop_path_rate=config.drop_path_rate,
        use_elsa=config.use_elsa,
        elsa_kernel_size=config.elsa_kernel_size,
        elsa_num_heads=config.elsa_num_heads,
        elsa_mlp_ratio=config.elsa_mlp_ratio,
        elsa_lam=config.elsa_lam,
        elsa_gamma=config.elsa_gamma
    ).to(device)
    
    # Load weights
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    
    # Test on validation set
    _, val_transform = get_transforms(config.img_size)
    val_dataset = CustomImageDataset(config.test_csv, transform=val_transform)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False)
    
    # Simple test
    loss_fn = AdaptiveEfficiencyLoss()
    test_results = validate(model, val_loader, loss_fn, device)
    
    print(f"🎯 Test Results:")
    print(f"   Test Accuracy: {test_results['val_acc1']:.2f}%")
    print(f"   Test Loss: {test_results['val_loss']:.4f}")
    
    return model

def predict_single_image(image_path, model=None):
    """Predict a single image"""
    if model is None:
        model = test_model()
    
    if model is None:
        return None
    
    # Load and preprocess image
    _, val_transform = get_transforms(config.img_size)
    
    try:
        image = Image.open(image_path).convert('RGB')
        image_tensor = val_transform(image).unsqueeze(0).to(device)
        
        model.eval()
        with torch.no_grad():
            logits, _, _ = model(image_tensor)
            probabilities = F.softmax(logits, dim=1)
            predicted_class = torch.argmax(probabilities, dim=1).item()
            confidence = probabilities[0][predicted_class].item()
        
        print(f"🔮 Prediction for {image_path}:")
        print(f"   Predicted class: {predicted_class}")
        print(f"   Confidence: {confidence:.4f}")
        
        return predicted_class, confidence
        
    except Exception as e:
        print(f"❌ Error predicting image {image_path}: {e}")
        return None

# =========================
# Quick GPU Test
# =========================
def gpu_test():
    """Quick test to ensure GPU is working"""
    print("🧪 GPU Test:")
    print(f"   CUDA available: {torch.cuda.is_available()}")
    
    if torch.cuda.is_available():
        print(f"   GPU name: {torch.cuda.get_device_name(0)}")
        print(f"   GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
        
        # Simple tensor operations
        x = torch.randn(1000, 1000).cuda()
        y = torch.randn(1000, 1000).cuda()
        z = torch.matmul(x, y)
        
        print(f"   ✅ GPU computation successful!")
        print(f"   Tensor device: {z.device}")
        
        # Clear memory
        del x, y, z
        torch.cuda.empty_cache()
    else:
        print("   ❌ GPU not available - will use CPU (very slow)")

# Run GPU test
gpu_test()

print("\n" + "="*60)
print("🚀 ELSA + Adaptive Vision Transformer Ready!")
print("="*60)
print("\n📋 Available functions:")
print("   • main() - Start training")
print("   • test_model() - Test the trained model") 
print("   • predict_single_image(path) - Predict single image")
print("   • gpu_test() - Test GPU functionality")
print("\n💡 To start training, run: main()")

🧪 GPU Test:
   CUDA available: True
   GPU name: NVIDIA GeForce RTX 3070 Laptop GPU
   GPU memory: 8.0 GB
   ✅ GPU computation successful!
   Tensor device: cuda:0

🚀 ELSA + Adaptive Vision Transformer Ready!

📋 Available functions:
   • main() - Start training
   • test_model() - Test the trained model
   • predict_single_image(path) - Predict single image
   • gpu_test() - Test GPU functionality

💡 To start training, run: main()


In [None]:
main()

📂 Loading datasets...
📊 Dataset Info:
   Training samples: 28068
   Validation samples: 7018
   Number of classes: 20
   Class distribution: [1843  578 2742  898  840  520  808 3449 2600  478  862 2081 1138 2811
  811 1213  424  965  834 2173]
🔧 Building model...
🏗️  Model Info:
   Total parameters: 16,085,986
   Trainable parameters: 16,085,986
   Model size: ~61.4 MB
🚀 Starting training...
   Device: cuda
   Batch size: 8
   Epochs: 20
   Learning rate: 0.0003
   Max learning rate: 0.001

Epoch 1/20
