# GTAL+ : Game-Theoretic Attention Learning for Fine-Grained Classification

### Key Features:
- ✅ **GTAL Core**: Game-theoretic attention over feature levels, approximating a Nash equilibrium with damped best-response iterations
- ✅ **4 Players**: Early, Mid, Late, Semantic feature levels
- ✅ **448×448 Images**: Critical for fine-grained recognition
- ✅ **SGD + StepLR**: Standard fine-tuning setup (SGD with momentum; LR multiplied by 0.1 every 30 epochs)

### Configuration Comparison:
| Parameter | GTAL+ |
|-----------|--------------|
| Image Size  | **448×448** |
| Loss | **CrossEntropy** (label smoothing 0.1) |
| Optimizer | **SGD** |
| LR Schedule | **StepLR** |
| Epochs | **95** |
| Batch Size | **64** |

In [None]:
# =====================================================
# 1. SETUP & IMPORTS
# =====================================================
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from tqdm import tqdm
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
import math
warnings.filterwarnings('ignore')

# ==========================================
# GOOGLE COLAB SETUP
# ==========================================
# Mount Google Drive so BASE_PATH below resolves. The `google.colab` package
# only exists inside Colab; elsewhere skip the mount instead of failing here.
try:
    from google.colab import drive
except ImportError:
    drive = None
    print('google.colab not found - skipping Google Drive mount')
else:
    drive.mount('/content/drive')

# ==========================================
# PATHS
# ==========================================
BASE_PATH = "/content/drive/My Drive/Project_GTAL"
TRAIN_CSV = os.path.join(BASE_PATH, "image_list_TRAIN.csv")
TEST_CSV = os.path.join(BASE_PATH, "image_list_TEST.csv")

# VERSION 3 - New approach.
# NOTE(review): these paths are relative to the working directory, not to
# BASE_PATH on Drive; in Colab they presumably vanish when the runtime resets.
CHECKPOINT_DIR = "./checkpoints/gtal_plus_v3"
VIZ_DIR = "./visualizations/gtal_plus_v3"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(VIZ_DIR, exist_ok=True)

# ==========================================
# DEVICE SETUP
# ==========================================
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'✓ Device: {device}')
print(f'✓ CUDA Available: {torch.cuda.is_available()}')
if torch.cuda.is_available():
    print(f'✓ GPU: {torch.cuda.get_device_name(0)}')

sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)

print(f'\n✓ Paths configured:')
print(f'  BASE_PATH: {BASE_PATH}')
print(f'  CHECKPOINT_DIR: {CHECKPOINT_DIR} (v3)')
print(f'  VIZ_DIR: {VIZ_DIR} (v3)')

In [None]:
# =====================================================
# 2. CONFIGURATION 
# =====================================================

class GTALPlusConfig:
    """Hyper-parameters for the GTAL+ v3 experiment.

    Every setting is a class attribute, so all instances share the same
    values and the rest of the notebook reads them as ``config.<NAME>``.
    """

    # ---- Input resolution (square crops, pixels) ----
    IMG_SIZE = 448

    # ---- Optimisation (SGD) ----
    EPOCHS = 95
    BATCH_SIZE = 64
    BASE_LR = 0.001
    MOMENTUM = 0.9
    WEIGHT_DECAY = 1e-4

    # ---- StepLR: LR is multiplied by LR_GAMMA every LR_STEP_SIZE epochs ----
    LR_STEP_SIZE = 30
    LR_GAMMA = 0.1

    # ---- Model ----
    NUM_CLASSES = 200
    EMBEDDING_DIM = 512

    # ---- GTAL best-response game (v3) ----
    GTAL_ITERATIONS = 8
    GTAL_TEMPERATURE = 0.3
    # NOTE(review): not read anywhere in the visible notebook; the GTAL layer
    # uses its own fixed update step. Confirm before tuning this value.
    GTAL_LEARNING_RATE = 0.2
    REDUNDANCY_PENALTY = 0.2

    # ---- Loss weighting (v3) ----
    DIVERSITY_WEIGHT = 0.05
    AUXILIARY_WEIGHT = 0.1
    LABEL_SMOOTHING = 0.1

    # Initial epochs trained on the ensemble loss alone.
    WARMUP_EPOCHS = 10

    # Ensemble prediction flag.
    USE_ENSEMBLE = True


config = GTALPlusConfig()

# Print a summary of the active configuration.
print("="*60)
print("GTAL+ v3 BALANCED CONFIGURATION")
print("="*60)
print(f"\nImage: {config.IMG_SIZE}×{config.IMG_SIZE}")
print(f"\nTraining: {config.EPOCHS} epochs, batch={config.BATCH_SIZE}, lr={config.BASE_LR}")
print(f"\nGTAL (Relaxed):")
print(f"   Iterations: {config.GTAL_ITERATIONS}")
print(f"   Temperature: {config.GTAL_TEMPERATURE}")
print(f"   Redundancy Penalty: {config.REDUNDANCY_PENALTY} (↓ from 0.5)")
print(f"\nV3 Improvements:")
print(f"   Diversity Weight: {config.DIVERSITY_WEIGHT} (↓↓ from 0.3)")
print(f"   Auxiliary Weight: {config.AUXILIARY_WEIGHT} (↓↓ from 0.4)")
print(f"   Label Smoothing: {config.LABEL_SMOOTHING}")
print(f"   Warmup Epochs: {config.WARMUP_EPOCHS}")
print(f"   Ensemble Mode: {config.USE_ENSEMBLE}")

In [None]:
# =====================================================
# 3. DATA TRANSFORMS 
# =====================================================

# ImageNet normalisation statistics expected by the pretrained ResNet-50.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Test-time resize of the shorter side before the centre crop, so the crop
# keeps 87.5% of it (int(448 / 0.875) == 512 for the default IMG_SIZE).
_TEST_RESIZE = int(config.IMG_SIZE / 0.875)

# Training transforms: random scale/aspect crop to IMG_SIZE + horizontal flip.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(config.IMG_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])

# Test transforms: deterministic resize + centre crop to IMG_SIZE.
test_transform = transforms.Compose([
    transforms.Resize(_TEST_RESIZE),
    transforms.CenterCrop(config.IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])

print("✓ Data transforms defined (matching resnet_finetune_cub)")
print(f"  Training: RandomResizedCrop({config.IMG_SIZE}) + RandomHorizontalFlip")
print(f"  Testing: Resize({_TEST_RESIZE}) + CenterCrop({config.IMG_SIZE})")

In [None]:
# =====================================================
# 4. DATASET CLASS
# =====================================================

class CUB200Dataset(Dataset):
    """CUB-200 dataset backed by a CSV file of image paths and labels.

    The CSV must provide a ``path`` column (image location) and a ``label``
    column (class id). Labels are returned unchanged, so a loss such as
    ``CrossEntropyLoss`` needs them in ``[0, num_classes)``.
    NOTE(review): confirm the CSV labels are 0-based.

    Args:
        csv_file: Path to the CSV file.
        transform: Optional callable applied to the loaded RGB PIL image.
        fallback_size: Side length of the blank ``(3, S, S)`` tensor returned
            when an image file cannot be read. ``None`` uses
            ``config.IMG_SIZE``.
    """

    def __init__(self, csv_file, transform=None, fallback_size=None):
        self.df = pd.read_csv(csv_file)
        self.transform = transform
        self.labels = self.df['label'].values
        self.paths = self.df['path'].values
        self.fallback_size = fallback_size

        # Per-class statistics: number of distinct labels and samples per label.
        unique, counts = np.unique(self.labels, return_counts=True)
        self.num_classes = len(unique)
        self.class_counts = dict(zip(unique, counts))

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``.

        Unreadable images (missing, corrupt, truncated) are replaced by a
        blank tensor so one bad file does not abort an epoch. Errors raised
        by ``transform`` are not masked.
        """
        path = self.paths[idx]
        label = self.labels[idx]
        try:
            # The context manager closes the file handle; convert() returns an
            # independent in-memory copy.
            with Image.open(path) as raw:
                img = raw.convert('RGB')
        except (OSError, ValueError) as e:
            print(f"Error loading {path}: {e}")
            size = config.IMG_SIZE if self.fallback_size is None else self.fallback_size
            return torch.zeros(3, size, size), label

        if self.transform:
            img = self.transform(img)
        return img, label


# Load and check datasets
print("Loading datasets...")
train_dataset = CUB200Dataset(TRAIN_CSV, transform=train_transform)
test_dataset = CUB200Dataset(TEST_CSV, transform=test_transform)

print(f"\n✓ Datasets loaded:")
print(f"  Train: {len(train_dataset)} samples, {train_dataset.num_classes} classes")
print(f"  Test:  {len(test_dataset)} samples, {test_dataset.num_classes} classes")

In [None]:
# =====================================================
# 5. GTAL CORE COMPONENTS 
# =====================================================

class PayoffFunction(nn.Module):
    """Score how useful one feature level is for each sample.

    Maps a pooled feature vector of size ``feature_dim`` to one scalar
    payoff through Linear -> BatchNorm1d -> ReLU -> Linear.

    Args:
        feature_dim: Size of the incoming feature vector.
        hidden_dim: Width of the hidden layer.
    """

    def __init__(self, feature_dim, hidden_dim=128):
        super().__init__()
        layers = [
            nn.Linear(feature_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, 1),
        ]
        self.net = nn.Sequential(*layers)

        # Small-gain weights and a zero bias keep the initial payoffs of all
        # levels close together.
        score_head = layers[-1]
        nn.init.xavier_uniform_(score_head.weight, gain=0.1)
        nn.init.zeros_(score_head.bias)

    def forward(self, x):
        """Return one payoff per row of ``x`` (shape ``(batch, 1)`` for 2-D input)."""
        scores = self.net(x)
        return scores


class GTALLayer(nn.Module):
    """Game-theoretic weighting of multi-level features (v3).

    Each feature level is a "player". A per-level ``PayoffFunction`` scores
    its pooled feature and a learnable prior bias is added. Then a fixed
    number of damped best-response updates re-weights the levels, penalising
    levels whose projected features are similar (redundant) to those of
    heavily weighted levels. The returned weights are the state after
    ``num_iterations`` updates; convergence to an equilibrium is not checked.

    Args:
        feature_dims: Size of each level's pooled feature vector. ``None``
            means the ResNet-50 stage widths ``[256, 512, 1024, 2048]``.
        common_dim: Size of the shared space used to measure redundancy.
        temperature: Softmax temperature of the best response. ``None``
            uses ``config.GTAL_TEMPERATURE``.
        num_iterations: Number of best-response updates. ``None`` uses
            ``config.GTAL_ITERATIONS``.
        redundancy_penalty: Weight of the redundancy term. ``None`` reads
            ``config.REDUNDANCY_PENALTY`` at every forward pass.
        damping: Fraction of the previous weights kept at each update:
            ``w <- damping * w + (1 - damping) * best_response``.
    """

    # Hand-tuned starting prior for the default four ResNet stages.
    _DEFAULT_PRIOR = (0.1, 0.15, 0.2, 0.55)

    def __init__(self, feature_dims=None, common_dim=256, temperature=None,
                 num_iterations=None, redundancy_penalty=None, damping=0.7):
        super().__init__()
        if feature_dims is None:
            feature_dims = [256, 512, 1024, 2048]
        feature_dims = list(feature_dims)
        self.num_players = len(feature_dims)

        # One payoff network per player.
        self.payoff_nets = nn.ModuleList([PayoffFunction(dim) for dim in feature_dims])

        # Learnable prior: added to the payoffs and (through a softmax) used as
        # the starting strategy. The tuned values only fit four players; any
        # other player count starts from a uniform (zero) prior.
        if self.num_players == len(self._DEFAULT_PRIOR):
            prior = torch.tensor(self._DEFAULT_PRIOR)
        else:
            prior = torch.zeros(self.num_players)
        self.prior_bias = nn.Parameter(prior)

        # Projections into a shared space where redundancy is measured.
        self.projectors = nn.ModuleList([nn.Linear(dim, common_dim) for dim in feature_dims])

        self.temperature = config.GTAL_TEMPERATURE if temperature is None else temperature
        self.num_iterations = config.GTAL_ITERATIONS if num_iterations is None else num_iterations
        self.redundancy_penalty = redundancy_penalty
        self.damping = damping

    def compute_redundancy(self, features):
        """Return absolute pairwise cosine similarity between feature levels.

        Args:
            features: List of tensors ``(batch, dim_i)``, one per player.

        Returns:
            Tensor ``(batch, n, n)`` whose ``[b, i, j]`` entry is
            ``|cos(proj_i, proj_j)|`` for ``i != j`` and 0 on the diagonal.
        """
        # Project every level, L2-normalise, and stack: (batch, n, common_dim).
        proj = torch.stack(
            [F.normalize(projector(feat), dim=1)
             for projector, feat in zip(self.projectors, features)],
            dim=1,
        )
        # All pairwise dot products in one batched matmul: (batch, n, n).
        sim = torch.bmm(proj, proj.transpose(1, 2)).abs()
        # A level is not redundant with itself.
        self_mask = torch.eye(proj.size(1), dtype=torch.bool, device=sim.device)
        return sim.masked_fill(self_mask, 0.0)

    def forward(self, features):
        """Compute per-sample weights for each feature level.

        Args:
            features: List of pooled tensors ``(batch, dim_i)``, one per player.

        Returns:
            Tensor ``(batch, num_players)``. Each row is a convex combination
            of softmax outputs and therefore sums to 1.
        """
        batch_size = features[0].size(0)
        penalty_weight = (config.REDUNDANCY_PENALTY if self.redundancy_penalty is None
                          else self.redundancy_penalty)

        # Payoff of each player: learned relevance score plus learnable prior.
        payoffs = torch.cat(
            [net(feat) for net, feat in zip(self.payoff_nets, features)], dim=1
        )
        payoffs = payoffs + self.prior_bias.unsqueeze(0)

        redundancy = self.compute_redundancy(features)

        # Starting strategy: softmax of the prior, identical for every sample.
        weights = F.softmax(self.prior_bias.unsqueeze(0).expand(batch_size, -1), dim=1)

        for _ in range(self.num_iterations):
            # Redundancy each player incurs given the current weights:
            # penalty_i = sum_j R_ij * w_j.
            penalty = torch.bmm(redundancy, weights.unsqueeze(-1)).squeeze(-1)
            effective = payoffs - penalty_weight * penalty
            best_response = F.softmax(effective / self.temperature, dim=1)
            # Damped step towards the best response rather than a full jump.
            weights = self.damping * weights + (1.0 - self.damping) * best_response

        return weights


def soft_diversity_loss(weights, threshold=0.8):
    """Hinge penalty on collapsed feature weights.

    Penalises only samples whose largest weight exceeds ``threshold``,
    linearly in the excess, so moderate imbalance is left untouched.

    Args:
        weights: Tensor ``(batch, num_players)`` of per-sample weights.
        threshold: Largest weight tolerated before a penalty applies.

    Returns:
        Scalar tensor: batch mean of ``relu(max_weight - threshold)``.
    """
    max_weight = weights.max(dim=1).values
    return F.relu(max_weight - threshold).mean()


print("✓ V3 GTAL Components:")
print("  - Simplified PayoffFunction")
print("  - GTALLayer with learnable prior bias")
print("  - Soft diversity loss (only extreme imbalance)")

In [None]:
# =====================================================
# 6. GTAL+ MODEL
# =====================================================

class SEBlock(nn.Module):
    """Squeeze-and-Excitation style gating for flat feature vectors.

    A bottleneck MLP (``channels -> channels // reduction -> channels``)
    produces per-channel gates in (0, 1) that rescale the input.

    Args:
        channels: Size of the last dimension of the input.
        reduction: Bottleneck reduction factor.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        hidden = channels // reduction
        self.fc1 = nn.Linear(channels, hidden)
        self.fc2 = nn.Linear(hidden, channels)

    def forward(self, x):
        squeezed = F.relu(self.fc1(x))
        gate = torch.sigmoid(self.fc2(squeezed))
        return gate * x


class GTALPlusModelV3(nn.Module):
    """GTAL+ v3: ResNet-50 backbone with an ensemble of two classifier heads.

    Path 1 classifies the pooled ``layer4`` (semantic) features directly,
    like a plain fine-tuned ResNet. Path 2 pools all four ResNet stages, lets
    ``GTALLayer`` assign per-sample weights to them, projects each stage to
    ``embedding_dim`` (with an SE gate), and classifies the weighted sum.
    The output is ``alpha * semantic + (1 - alpha) * gtal`` with
    ``alpha = sigmoid(ensemble_weight)`` learned end to end.

    Args:
        num_classes: Number of output classes.
        embedding_dim: Width of the projected per-stage features.
    """

    def __init__(self, num_classes=200, embedding_dim=512):
        super().__init__()
        self.num_classes = num_classes

        # ImageNet-pretrained ResNet-50 (weights are downloaded on first use).
        resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)

        # Backbone split into the four stages that act as GTAL players.
        self.early = nn.Sequential(
            resnet.conv1, resnet.bn1, resnet.relu,
            resnet.maxpool, resnet.layer1
        )
        self.mid = resnet.layer2
        self.late = resnet.layer3
        self.semantic = resnet.layer4

        # Per-stage projection heads (created in stage order) and SE gates.
        self.proj_early = self._projection(256, embedding_dim)
        self.proj_mid = self._projection(512, embedding_dim)
        self.proj_late = self._projection(1024, embedding_dim)
        self.proj_semantic = self._projection(2048, embedding_dim)

        self.se_early = SEBlock(embedding_dim)
        self.se_mid = SEBlock(embedding_dim)
        self.se_late = SEBlock(embedding_dim)
        self.se_semantic = SEBlock(embedding_dim)

        # GTAL layer weighting the four pooled stage features.
        self.gtal = GTALLayer(
            feature_dims=[256, 512, 1024, 2048],
            common_dim=256
        )

        # Path 1: direct classifier on pooled layer4 features (baseline-like).
        self.semantic_classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(2048, num_classes)
        )
        nn.init.kaiming_normal_(self.semantic_classifier[1].weight)
        nn.init.zeros_(self.semantic_classifier[1].bias)

        # Path 2: classifier on the GTAL-weighted fusion.
        self.gtal_classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(embedding_dim, num_classes)
        )
        nn.init.kaiming_normal_(self.gtal_classifier[1].weight)
        nn.init.zeros_(self.gtal_classifier[1].bias)

        # Ensemble logit. alpha = sigmoid(ensemble_weight), so 0.0 gives an even
        # 0.5 / 0.5 split at initialisation (0.5 would give alpha ~= 0.62).
        self.ensemble_weight = nn.Parameter(torch.tensor(0.0))

        # Detached GTAL weights from the most recent forward pass.
        self.feature_weights = None

    @staticmethod
    def _projection(in_dim, out_dim):
        """Linear -> BatchNorm1d -> ReLU -> Dropout(0.1) projection head."""
        return nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.BatchNorm1d(out_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
        )

    def forward(self, x, return_both=False):
        """Run both heads and blend their logits.

        Args:
            x: Image batch ``(batch, 3, H, W)``.
            return_both: When True, also return the per-head logits.

        Returns:
            ``(ensemble_logits, weights)``, or
            ``(ensemble_logits, weights, semantic_logits, gtal_logits)`` when
            ``return_both`` is True. ``weights`` has one column per stage.
        """
        # Backbone stages.
        e = self.early(x)
        m = self.mid(e)
        l = self.late(m)
        s = self.semantic(l)

        # Global average pooling -> (batch, channels) per stage.
        p_e = F.adaptive_avg_pool2d(e, 1).flatten(1)
        p_m = F.adaptive_avg_pool2d(m, 1).flatten(1)
        p_l = F.adaptive_avg_pool2d(l, 1).flatten(1)
        p_s = F.adaptive_avg_pool2d(s, 1).flatten(1)

        # ===== Path 1: direct semantic (baseline-like) =====
        semantic_logits = self.semantic_classifier(p_s)

        # ===== Path 2: GTAL fusion =====
        weights = self.gtal([p_e, p_m, p_l, p_s])
        # Kept for inspection. NOTE(review): under nn.DataParallel this is set
        # on the per-device replicas, not on the wrapped module.
        self.feature_weights = weights.detach()

        f_e = self.se_early(self.proj_early(p_e))
        f_m = self.se_mid(self.proj_mid(p_m))
        f_l = self.se_late(self.proj_late(p_l))
        f_s = self.se_semantic(self.proj_semantic(p_s))

        # Per-sample weighted sum of the projected stage features.
        fused = (weights[:, 0:1] * f_e +
                 weights[:, 1:2] * f_m +
                 weights[:, 2:3] * f_l +
                 weights[:, 3:4] * f_s)

        gtal_logits = self.gtal_classifier(fused)

        # ===== Ensemble =====
        alpha = torch.sigmoid(self.ensemble_weight)  # constrained to (0, 1)
        ensemble_logits = alpha * semantic_logits + (1 - alpha) * gtal_logits

        if return_both:
            return ensemble_logits, weights, semantic_logits, gtal_logits

        return ensemble_logits, weights


# Smoke-test the model on a random batch.
print("Testing GTAL+ V3 Model...")
model = GTALPlusModelV3(num_classes=200).to(device)

total_params = sum(p.numel() for p in model.parameters())
print(f"\n✓ GTAL+ V3 Model created:")
print(f"  Parameters: {total_params:,}")

# eval(): the random batch must not update BatchNorm running statistics, and
# dropout should be inactive for a deterministic check.
model.eval()
test_input = torch.randn(2, 3, config.IMG_SIZE, config.IMG_SIZE).to(device)
with torch.no_grad():
    out, w, sem, gtal = model(test_input, return_both=True)
    alpha = torch.sigmoid(model.ensemble_weight).item()

print(f"\n✓ Forward pass test:")
print(f"  Output: {out.shape}")
print(f"  GTAL weights: {w[0].cpu().numpy().round(3)}")
print(f"  Ensemble α: {alpha:.3f} (semantic) / {1-alpha:.3f} (gtal)")

del test_input, out, w, sem, gtal

In [None]:
# =====================================================
# 7. TRAINING FUNCTION 
# =====================================================

def _config_as_dict(cfg):
    """Return every UPPER_CASE setting reachable on ``cfg`` as a plain dict.

    ``vars(cfg)`` only sees instance attributes and is therefore empty for a
    config object whose settings are class attributes. ``dir`` also lists
    class attributes, and ``getattr`` returns an instance override if set.
    """
    return {name: getattr(cfg, name) for name in dir(cfg) if name.isupper()}


def train_gtal_plus_v3(model, train_loader, test_loader, epochs=None):
    """Train a GTAL+ v3 model and keep the checkpoint with best ensemble accuracy.

    Uses SGD (momentum, weight decay), a StepLR schedule, and label-smoothed
    cross-entropy on the ensemble logits. After ``config.WARMUP_EPOCHS``
    epochs it also adds auxiliary cross-entropy on the semantic and GTAL
    heads (weight ``config.AUXILIARY_WEIGHT``) and the soft diversity
    penalty (weight ``config.DIVERSITY_WEIGHT``). After every epoch the three
    heads are evaluated on ``test_loader``. Whenever ensemble accuracy
    improves, the model/optimizer state, history and config are saved to
    ``CHECKPOINT_DIR/gtal_plus_v3_best.pth``.

    NOTE(review): model selection uses ``test_loader``; if that is the
    held-out test split, the reported best accuracy is optimistically biased.

    Args:
        model: A ``GTALPlusModelV3``-compatible module. It is wrapped in
            ``nn.DataParallel`` when more than one GPU is visible.
        train_loader: Loader yielding ``(images, labels)`` batches.
        test_loader: Loader used for per-epoch evaluation.
        epochs: Number of epochs; ``None`` uses ``config.EPOCHS``.

    Returns:
        ``(history, model)``. ``history`` maps metric names to per-epoch
        lists; ``model`` is the trained (possibly wrapped) model.
    """
    if epochs is None:
        epochs = config.EPOCHS

    if torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs")
        model = nn.DataParallel(model)

    model = model.to(device)
    # Unwrapped module: owns ensemble_weight and the state_dict keys to save.
    net = model.module if hasattr(model, 'module') else model

    # Loss with label smoothing
    criterion = nn.CrossEntropyLoss(label_smoothing=config.LABEL_SMOOTHING)

    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=config.BASE_LR,
        momentum=config.MOMENTUM,
        weight_decay=config.WEIGHT_DECAY
    )

    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=config.LR_STEP_SIZE,
        gamma=config.LR_GAMMA
    )

    history = {
        'train_loss': [], 'train_acc': [], 'test_acc': [],
        'feature_weights': [], 'lr': [], 'ensemble_alpha': [],
        'semantic_acc': [], 'gtal_acc': []
    }
    best_acc = 0.0
    best_epoch = 0
    config_snapshot = _config_as_dict(config)

    print("\n" + "="*85)
    print("GTAL+ V3 TRAINING - Ensemble Approach")
    print("="*85)
    print(f"Strategy: Semantic path (baseline) + GTAL path → Ensemble")
    print(f"Label Smoothing: {config.LABEL_SMOOTHING}")
    print(f"Warmup: {config.WARMUP_EPOCHS} epochs (no extra losses)")
    print("-"*85)
    print("Epoch\tLoss\tTrain%\tTest%\tSem%\tGTAL%\tα\tWeights")
    print("-"*85)

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        correct = 0
        total = 0
        epoch_weights = []

        # During warmup only the ensemble loss is optimised.
        in_warmup = epoch < config.WARMUP_EPOCHS

        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs}', leave=False)
        for imgs, labels in pbar:
            imgs = imgs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            ensemble_logits, weights, sem_logits, gtal_logits = model(imgs, return_both=True)

            main_loss = criterion(ensemble_logits, labels)

            if in_warmup:
                loss = main_loss
            else:
                # Auxiliary supervision on each head plus a penalty on
                # near-one-hot GTAL weights.
                sem_loss = criterion(sem_logits, labels)
                gtal_loss = criterion(gtal_logits, labels)
                div_loss = soft_diversity_loss(weights)
                loss = (main_loss
                        + config.AUXILIARY_WEIGHT * (sem_loss + gtal_loss)
                        + config.DIVERSITY_WEIGHT * div_loss)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
            optimizer.step()

            # Logged loss is the ensemble term only (auxiliary terms excluded).
            train_loss += main_loss.item()
            _, pred = torch.max(ensemble_logits, 1)
            correct += (pred == labels).sum().item()
            total += labels.size(0)

            epoch_weights.append(weights.detach().cpu().numpy())
            pbar.set_postfix({'loss': f'{loss.item():.4f}'})

        # Epoch metrics
        train_acc = 100.0 * correct / total
        avg_loss = train_loss / len(train_loader)

        mean_weights = np.concatenate(epoch_weights, axis=0).mean(axis=0)
        history['feature_weights'].append(mean_weights)

        test_acc, sem_acc, gtal_acc = evaluate_all_paths(model, test_loader)
        alpha = torch.sigmoid(net.ensemble_weight).item()

        # Record the LR this epoch trained with, then advance the schedule
        # (reading it after step() would log the next epoch's LR).
        current_lr = optimizer.param_groups[0]['lr']
        scheduler.step()

        history['train_loss'].append(avg_loss)
        history['train_acc'].append(train_acc)
        history['test_acc'].append(test_acc)
        history['lr'].append(current_lr)
        history['ensemble_alpha'].append(alpha)
        history['semantic_acc'].append(sem_acc)
        history['gtal_acc'].append(gtal_acc)

        # Save best
        is_best = test_acc > best_acc
        if is_best:
            best_acc = test_acc
            best_epoch = epoch + 1
            save_dict = {
                'epoch': epoch,
                'model_state': net.state_dict(),
                'optimizer_state': optimizer.state_dict(),
                'best_acc': best_acc,
                'history': history,
                'config': config_snapshot,
            }
            torch.save(save_dict, os.path.join(CHECKPOINT_DIR, 'gtal_plus_v3_best.pth'))

        marker = '★' if is_best else ' '
        w_str = f"{mean_weights[0]:.2f}/{mean_weights[1]:.2f}/{mean_weights[2]:.2f}/{mean_weights[3]:.2f}"
        warmup_mark = "(warmup)" if in_warmup else ""
        print(f'{marker}{epoch+1}\t{avg_loss:.3f}\t{train_acc:.1f}%\t{test_acc:.1f}%\t{sem_acc:.1f}%\t{gtal_acc:.1f}%\t{alpha:.2f}\t{w_str} {warmup_mark}')

    print("-"*85)
    print(f"\n Training completed!")
    print(f"   Best Accuracy: {best_acc:.2f}% (Epoch {best_epoch})")

    # Last-epoch analysis; skipped when no epoch was run (epochs == 0).
    if history['test_acc']:
        print(f"\nPath Analysis (final):")
        print(f"   Semantic path: {history['semantic_acc'][-1]:.2f}%")
        print(f"   GTAL path: {history['gtal_acc'][-1]:.2f}%")
        print(f"   Ensemble (α={history['ensemble_alpha'][-1]:.2f}): {history['test_acc'][-1]:.2f}%")

    return history, model


def evaluate_all_paths(model, loader):
    """Measure top-1 accuracy of the ensemble, semantic and GTAL heads.

    Switches ``model`` to eval mode and runs it without gradients over every
    batch of ``loader``. ``model(images, return_both=True)`` must return
    ``(ensemble_logits, weights, semantic_logits, gtal_logits)``.

    Returns:
        Tuple of percentages ``(ensemble, semantic, gtal)``.
    """
    model.eval()

    hits = [0, 0, 0]  # ensemble, semantic, gtal
    seen = 0

    with torch.no_grad():
        for images, targets in loader:
            images = images.to(device)
            targets = targets.to(device)

            ens, _, sem, gtal = model(images, return_both=True)
            for k, logits in enumerate((ens, sem, gtal)):
                predicted = logits.max(dim=1).indices
                hits[k] += (predicted == targets).sum().item()
            seen += targets.size(0)

    return tuple(100.0 * h / seen for h in hits)


for _summary_line in (
    "  - Warmup period for stable early training",
    "  - Evaluates all paths separately",
    "  - Ensemble combines semantic + GTAL",
):
    print(_summary_line)

In [None]:
# =====================================================
# 8. VISUALIZATION FUNCTIONS
# =====================================================

def visualize_training_results(history, model, test_loader, baseline_acc=84.0):
    """Plot training curves and GTAL weight statistics, then print a summary.

    Draws a 2x3 figure:
      1. train/test accuracy,
      2. training loss,
      3. per-epoch mean feature weights,
      4. final weights,
      5. test-set weight distribution,
      6. stacked weight area.
    Saves it to ``VIZ_DIR/gtal_plus_results.png`` and shows it.

    Args:
        history: Dict with ``train_acc``, ``test_acc``, ``train_loss`` and
            ``feature_weights`` lists, as produced by the training loop.
        model: Model whose ``model(images)`` returns ``(logits, weights)``.
            It is put in eval mode and run over ``test_loader``.
        test_loader: Loader of ``(images, labels)`` batches.
        baseline_acc: Accuracy (in %) drawn as the reference line.

    Returns:
        NumPy array of per-sample GTAL weights on the test set, one column
        per feature level.
    """
    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    epochs = range(1, len(history['train_acc']) + 1)
    feature_names = ['Early', 'Mid', 'Late', 'Semantic']
    colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4']

    # Plot 1: Accuracy curves
    ax1 = axes[0, 0]
    ax1.plot(epochs, history['train_acc'], 'b-', linewidth=2, label='Train', marker='o', markersize=2)
    ax1.plot(epochs, history['test_acc'], 'r-', linewidth=2, label='Test', marker='s', markersize=2)
    ax1.axhline(y=baseline_acc, color='g', linestyle='--', linewidth=2, label=f'Baseline ({baseline_acc:g}%)')
    ax1.set_xlabel('Epoch', fontweight='bold')
    ax1.set_ylabel('Accuracy (%)', fontweight='bold')
    ax1.set_title('Training & Test Accuracy', fontweight='bold')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Plot 2: Loss curve
    ax2 = axes[0, 1]
    ax2.plot(epochs, history['train_loss'], 'b-', linewidth=2, marker='o', markersize=2)
    ax2.set_xlabel('Epoch', fontweight='bold')
    ax2.set_ylabel('Loss', fontweight='bold')
    ax2.set_title('Training Loss', fontweight='bold')
    ax2.grid(True, alpha=0.3)

    # Plot 3: Feature weights evolution
    ax3 = axes[0, 2]
    weights_history = np.array(history['feature_weights'])
    for i, (name, color) in enumerate(zip(feature_names, colors)):
        ax3.plot(epochs, weights_history[:, i], color=color, linewidth=2, label=name, marker='o', markersize=2)
    ax3.set_xlabel('Epoch', fontweight='bold')
    ax3.set_ylabel('Weight', fontweight='bold')
    ax3.set_title('GTAL Feature Weights Evolution', fontweight='bold')
    ax3.legend()
    ax3.grid(True, alpha=0.3)

    # Plot 4: Final (last-epoch mean) weights
    ax4 = axes[1, 0]
    final_weights = weights_history[-1]
    bars = ax4.bar(feature_names, final_weights, color=colors, edgecolor='black', linewidth=2)
    ax4.set_ylabel('Equilibrium Weight', fontweight='bold')
    ax4.set_title('Final Nash Equilibrium Weights', fontweight='bold')
    for bar, weight in zip(bars, final_weights):
        ax4.text(bar.get_x() + bar.get_width()/2., bar.get_height(),
                f'{weight:.3f}', ha='center', va='bottom', fontweight='bold')
    ax4.grid(axis='y', alpha=0.3)

    # Plot 5: Feature weight distribution on the test set
    ax5 = axes[1, 1]
    model.eval()
    all_weights = []
    with torch.no_grad():
        for imgs, _ in test_loader:
            imgs = imgs.to(device)
            _, weights = model(imgs)
            all_weights.append(weights.cpu().numpy())
    all_weights = np.concatenate(all_weights, axis=0)
    bp = ax5.boxplot([all_weights[:, i] for i in range(4)], labels=feature_names, patch_artist=True)
    for patch, color in zip(bp['boxes'], colors):
        patch.set_facecolor(color)
    ax5.set_ylabel('Weight', fontweight='bold')
    ax5.set_title('Feature Weight Distribution (Test)', fontweight='bold')
    ax5.grid(True, alpha=0.3)

    # Plot 6: Stacked area chart
    ax6 = axes[1, 2]
    ax6.stackplot(epochs, weights_history.T, labels=feature_names, colors=colors, alpha=0.8)
    ax6.set_xlabel('Epoch', fontweight='bold')
    ax6.set_ylabel('Cumulative Weight', fontweight='bold')
    ax6.set_title('Feature Contribution Over Training', fontweight='bold')
    ax6.legend(loc='upper right')
    ax6.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(os.path.join(VIZ_DIR, 'gtal_plus_results.png'), dpi=300, bbox_inches='tight')
    plt.show()

    # Print summary
    best_acc = max(history['test_acc'])
    best_epoch = history['test_acc'].index(best_acc) + 1

    print(f"\n" + "="*60)
    print("TRAINING SUMMARY")
    print("="*60)
    print(f"Best Test Accuracy: {best_acc:.2f}% (Epoch {best_epoch})")
    print(f"\nFinal Nash Equilibrium Weights:")
    for name, weight in zip(feature_names, final_weights):
        bar = '█' * int(weight * 40)
        print(f"  {name:10s}: {weight:.4f} {bar}")
    print(f"\nMost Important: {feature_names[np.argmax(final_weights)]}")

    return all_weights


print("✓ Visualization functions defined")

In [None]:
# =====================================================
# 9. CREATE DATALOADERS & START TRAINING
# =====================================================

print("="*70)
print("PREPARING FOR TRAINING")
print("="*70)

# Create DataLoaders (matching resnet_finetune_cub).
# NOTE: the model contains BatchNorm1d layers, which raise in train mode on a
# batch holding a single sample. If len(train_dataset) % BATCH_SIZE == 1,
# pass drop_last=True to the training loader.
train_loader = DataLoader(
    train_dataset,
    batch_size=config.BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)

test_loader = DataLoader(
    test_dataset,
    batch_size=16,  # Same as resnet_finetune_cub
    shuffle=False,
    num_workers=4,
    pin_memory=True
)

print(f"\n DataLoaders created:")
print(f"  Train batches: {len(train_loader)}")
print(f"  Test batches: {len(test_loader)}")

print(f"\n GTAL Game Theory Setup:")
print(f"  Players: 4 feature levels (Early, Mid, Late, Semantic)")
print(f"  Strategies: Feature weights (attention)")
print(f"  Payoffs: Relevance - Redundancy penalty")
print(f"  Solution: Nash Equilibrium via Best Response")

print(f"\n Training Configuration:")
print(f"  Image Size: {config.IMG_SIZE}×{config.IMG_SIZE}")
print(f"  Epochs: {config.EPOCHS}")
print(f"  Batch Size: {config.BATCH_SIZE}")
print(f"  Optimizer: SGD (lr={config.BASE_LR}, momentum={config.MOMENTUM})")
print(f"  Scheduler: StepLR (step={config.LR_STEP_SIZE}, gamma={config.LR_GAMMA})")
print(f"  Loss: CrossEntropyLoss (label_smoothing={config.LABEL_SMOOTHING})")

In [None]:
# =====================================================
# 10. TRAIN THE V3 MODEL
# =====================================================

# Initialize a fresh model; the class count comes from the training CSV.
model = GTALPlusModelV3(num_classes=train_dataset.num_classes).to(device)

print(" Starting GTAL+ V3 Training...")
print("   Strategy: Semantic baseline + GTAL enhancement → Ensemble")
print("   This ensures we never do worse than baseline!")

# Train
history, trained_model = train_gtal_plus_v3(
    model=model,
    train_loader=train_loader,
    test_loader=test_loader,
    epochs=config.EPOCHS
)

In [None]:
# =====================================================
# 11. VISUALIZE V3 RESULTS
# =====================================================

# Load the best checkpoint saved during training. Besides the weights it holds
# the training history (NumPy arrays), which recent PyTorch releases (2.6+)
# refuse to unpickle under their default weights_only=True. The file is
# produced by this notebook, so load it with full unpickling.
ckpt_path = os.path.join(CHECKPOINT_DIR, 'gtal_plus_v3_best.pth')
try:
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
except TypeError:
    # torch < 1.13 has no weights_only argument (it always unpickles fully).
    checkpoint = torch.load(ckpt_path, map_location=device)

if hasattr(trained_model, 'module'):
    trained_model.module.load_state_dict(checkpoint['model_state'])
else:
    trained_model.load_state_dict(checkpoint['model_state'])

# Visualization
fig, axes = plt.subplots(2, 3, figsize=(18, 10))
epochs = range(1, len(history['train_acc']) + 1)
feature_names = ['Early', 'Mid', 'Late', 'Semantic']
colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4']

# Plot 1: All accuracies
ax1 = axes[0, 0]
ax1.plot(epochs, history['test_acc'], 'b-', lw=2, label='Ensemble (Test)')
ax1.plot(epochs, history['semantic_acc'], 'g--', lw=1.5, label='Semantic Path')
ax1.plot(epochs, history['gtal_acc'], 'r--', lw=1.5, label='GTAL Path')
ax1.axhline(y=84, color='k', linestyle=':', lw=2, label='Baseline (84%)')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Accuracy (%)')
ax1.set_title('V3: All Paths Accuracy')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Plot 2: Ensemble weight evolution
ax2 = axes[0, 1]
ax2.plot(epochs, history['ensemble_alpha'], 'purple', lw=2)
ax2.axhline(y=0.5, color='gray', linestyle='--', alpha=0.5)
ax2.set_xlabel('Epoch')
ax2.set_ylabel('α (Semantic weight)')
ax2.set_title('Learned Ensemble Weight')
ax2.set_ylim(0, 1)
ax2.grid(True, alpha=0.3)

# Plot 3: Feature weights
ax3 = axes[0, 2]
weights_history = np.array(history['feature_weights'])
for i, (name, color) in enumerate(zip(feature_names, colors)):
    ax3.plot(epochs, weights_history[:, i], color=color, lw=2, label=name)
ax3.set_xlabel('Epoch')
ax3.set_ylabel('Weight')
ax3.set_title('GTAL Feature Weights')
ax3.legend()
ax3.grid(True, alpha=0.3)

# Plot 4: Final weights bar
ax4 = axes[1, 0]
final_weights = weights_history[-1]
bars = ax4.bar(feature_names, final_weights, color=colors, edgecolor='black')
for bar, w in zip(bars, final_weights):
    ax4.text(bar.get_x() + bar.get_width()/2., bar.get_height(), f'{w:.3f}', ha='center', va='bottom')
ax4.set_ylabel('Weight')
ax4.set_title('Final GTAL Weights')
ax4.grid(axis='y', alpha=0.3)

# Plot 5: Loss
ax5 = axes[1, 1]
ax5.plot(epochs, history['train_loss'], 'b-', lw=2)
ax5.set_xlabel('Epoch')
ax5.set_ylabel('Loss')
ax5.set_title('Training Loss')
ax5.grid(True, alpha=0.3)

# Plot 6: Last-epoch accuracies of each path vs. baseline and best ensemble
ax6 = axes[1, 2]
final_sem = history['semantic_acc'][-1]
final_gtal = history['gtal_acc'][-1]
final_ens = history['test_acc'][-1]
best_acc = max(history['test_acc'])

x = ['Baseline', 'Semantic\nPath', 'GTAL\nPath', 'Ensemble', 'Best']
y = [84.0, final_sem, final_gtal, final_ens, best_acc]
colors_bar = ['gray', '#4ECDC4', '#FF6B6B', 'purple', 'gold']
bars = ax6.bar(x, y, color=colors_bar, edgecolor='black')
for bar, val in zip(bars, y):
    ax6.text(bar.get_x() + bar.get_width()/2., bar.get_height(), f'{val:.1f}%', ha='center', va='bottom')
ax6.set_ylabel('Accuracy (%)')
ax6.set_title('V3 Results Comparison')
ax6.axhline(y=84, color='red', linestyle='--', lw=1)
ax6.grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(VIZ_DIR, 'gtal_plus_v3_results.png'), dpi=300, bbox_inches='tight')
plt.show()

print(f"\n📊 Visualization saved to: {VIZ_DIR}")