In [None]:
# Step 1: Create output folders for all stages of the adaptive ensemble pipeline

import os

# Kaggle's writable working directory; fall back to the current directory so the
# notebook also runs outside Kaggle (locally, Colab, CI) instead of failing on makedirs.
KAGGLE_WORKING = '/kaggle/working' if os.path.isdir('/kaggle/working') else os.getcwd()
OUTPUT_ROOT = os.path.join(KAGGLE_WORKING, "outputs")
FEATURES_DIR = os.path.join(OUTPUT_ROOT, "features")           # masked object regions (one folder per class)
SPLITS_DIR = os.path.join(OUTPUT_ROOT, "splits")               # for train/val/test splits
MODELS_DIR = os.path.join(OUTPUT_ROOT, "models")               # trained models
LOGS_DIR = os.path.join(OUTPUT_ROOT, "logs")                   # training logs, CSVs
METRICS_DIR = os.path.join(OUTPUT_ROOT, "metrics")             # evaluation metrics as JSON/CSV
VIS_DIR = os.path.join(OUTPUT_ROOT, "visualizations")          # plots, confusion matrix, etc.
ENSEMBLE_DIR = os.path.join(OUTPUT_ROOT, "ensemble")           # specific directory for ensemble results
INTERP_DIR = os.path.join(OUTPUT_ROOT, "interpretability")     # interpretability visualizations
EXTRACT_VIS_DIR = os.path.join(VIS_DIR, "extracted_regions")   # visualizations of extracted regions
TEST_DIR = os.path.join(SPLITS_DIR, "test")                    # held-out test set (common across all folds)

for d in [
    OUTPUT_ROOT, FEATURES_DIR, SPLITS_DIR, MODELS_DIR, LOGS_DIR,
    METRICS_DIR, VIS_DIR, ENSEMBLE_DIR, INTERP_DIR, EXTRACT_VIS_DIR, TEST_DIR,
]:
    os.makedirs(d, exist_ok=True)

print("✅ Directory structure created for journal-quality outputs")

In [None]:
# Step 2: Set Global Seed for Reproducibility

import random
import numpy as np
import torch

def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) and force deterministic cuDNN kernels."""
    # Each library keeps its own independent RNG, so every one must be seeded.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
    # cuDNN autotuning (benchmark mode) may choose different kernels run to run; turn it off.
    torch.backends.cudnn.deterministic, torch.backends.cudnn.benchmark = True, False

set_seed(42)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

In [None]:
# Step 3: Install essential libraries for the adaptive ensemble

!pip install torch torchvision
!pip install timm transformers roboflow
!pip install grad-cam matplotlib seaborn pandas numpy tqdm scikit-learn

In [None]:
# Step 4: Download the COCO-format segmentation dataset from Roboflow

from roboflow import Roboflow
from getpass import getpass
import os
import json

# The API key is a credential. Keep it out of the notebook, because saved copies and
# outputs leak it. Provide it through the ROBOFLOW_API_KEY environment variable (e.g.
# exported from a Kaggle/Colab secret), or type it at the prompt.
# NOTE(review): the key that used to be hard-coded here is exposed and should be rotated.
ROBOFLOW_API_KEY = os.environ.get("ROBOFLOW_API_KEY") or getpass("Roboflow API key: ")

rf = Roboflow(api_key=ROBOFLOW_API_KEY)
project = rf.workspace("urban-lake-wastef").project("another_approach_try")
version = project.version(4)
dataset = version.download("coco-segmentation")

DATA_ROOT = dataset.location
TRAIN_JSON = os.path.join(DATA_ROOT, 'train', '_annotations.coco.json')
IMG_DIR = os.path.join(DATA_ROOT, 'train')

print(f"Dataset downloaded to: {DATA_ROOT}")
print(f"Train JSON: {TRAIN_JSON}")
print(f"Image directory: {IMG_DIR}")

In [None]:
# Step 5: Extract foreground objects from COCO polygon masks into FEATURES_DIR
#
# Each annotation becomes one full-size PNG. Every pixel outside the object's polygon
# mask is zeroed; the image is masked, NOT cropped to the bounding box.
# Output: FEATURES_DIR/<category name>/<image stem>_<annotation id>.png

import json
import os
from collections import defaultdict
from PIL import Image, ImageDraw
import numpy as np

MIN_MASK_PIXELS = 100  # masks covering fewer pixels than this are skipped as too small

with open(TRAIN_JSON) as f:
    ann_data = json.load(f)

cat_map = {c['id']: c['name'] for c in ann_data['categories']}
print(f"Waste categories: {list(cat_map.values())}")

# Index images once. The old per-annotation `next(...)` scan was O(#annotations x #images).
images_by_id = {img['id']: img for img in ann_data['images']}
# Group annotations by image so each image file is decoded only once
anns_by_image = defaultdict(list)
for ann in ann_data['annotations']:
    anns_by_image[ann['image_id']].append(ann)

# Track statistics
extracted_count = {cat: 0 for cat in cat_map.values()}
extraction_failures = 0


def _polygon_mask(polygons, width, height):
    """Rasterise a list of COCO polygons (flat [x1, y1, x2, y2, ...]) into a (height, width) 0/1 uint8 mask."""
    canvas = Image.new('L', (width, height), 0)
    draw = ImageDraw.Draw(canvas)
    for poly in polygons:
        pts = [tuple(p) for p in np.asarray(poly, dtype=float).reshape(-1, 2)]
        draw.polygon(pts, outline=1, fill=1)  # union of all polygons of the object
    return np.array(canvas, dtype=np.uint8)


for image_id, anns in anns_by_image.items():
    img_info = images_by_id.get(image_id)
    if img_info is None:
        print(f"Error: image id {image_id} (annotations {[a['id'] for a in anns]}) not found in JSON")
        extraction_failures += len(anns)
        continue

    try:
        img_path = os.path.join(IMG_DIR, img_info['file_name'])
        with Image.open(img_path) as im:  # closes the file handle promptly
            img_arr = np.array(im.convert('RGB'))
    except Exception as e:
        print(f"Error loading image {img_info['file_name']}: {e}")
        extraction_failures += len(anns)
        continue

    base = os.path.splitext(img_info['file_name'])[0]
    for ann in anns:
        try:
            seg = ann['segmentation']
            if not isinstance(seg, list):
                # RLE-encoded masks (dict) would otherwise be iterated key by key
                raise ValueError("only polygon segmentation is supported (got RLE)")
            mask = _polygon_mask(seg, img_info['width'], img_info['height'])

            if mask.sum() < MIN_MASK_PIXELS:  # Skip very small masks
                extraction_failures += 1
                continue

            region_img = Image.fromarray(img_arr * mask[:, :, None])

            label = cat_map[ann['category_id']]
            out_dir = os.path.join(FEATURES_DIR, label)
            os.makedirs(out_dir, exist_ok=True)
            region_img.save(os.path.join(out_dir, f"{base}_{ann['id']}.png"))

            extracted_count[label] += 1
        except Exception as e:
            print(f"Error processing annotation {ann['id']}: {e}")
            extraction_failures += 1

print(f"✅ Extracted regions saved to: {FEATURES_DIR}")
print(f"Extraction statistics:")
for category, count in extracted_count.items():
    print(f"  - {category}: {count} images")
print(f"  - Failed extractions: {extraction_failures}")

In [None]:
# Step 6: Visualize 10 original vs masked crops for journal-quality reporting

import matplotlib.pyplot as plt
from PIL import Image
import os
import random

def show_before_after_masked(original_dir, masked_dir, num_samples=10, vis_dir=None):
    """
    Plot original images next to their extracted foreground regions, one row per sample.

    Up to ``num_samples`` distinct region files are drawn from ``masked_dir`` (one
    sub-folder per class). For each pick, a class is chosen with probability inversely
    proportional to its file count, so small classes also appear. The source image is
    looked up in ``original_dir`` by dropping the trailing ``_<annotation id>`` from
    the region filename.

    Parameters
    ----------
    original_dir : str
        Directory with the source images (.jpg/.png/.jpeg).
    masked_dir : str
        Directory with one sub-folder of region images (.jpg/.png) per class.
    num_samples : int
        Maximum number of rows. Fewer rows are drawn if fewer distinct files exist.
    vis_dir : str or None
        If given, the figure is also saved there as ``original_vs_extracted.png``.
    """
    print(f"\n🔍 Showing {num_samples} samples: original vs extracted region")

    # Candidate region files per class, listed once up front
    remaining = {}
    for c in sorted(os.listdir(masked_dir)):
        class_dir = os.path.join(masked_dir, c)
        if not os.path.isdir(class_dir):
            continue
        files = [f for f in os.listdir(class_dir) if f.endswith(('.jpg', '.png'))]
        if files:
            remaining[c] = files

    # Inverse-frequency class weights (loop-invariant, so computed once)
    class_weights = {c: 1.0 / len(files) for c, files in remaining.items()}

    # Sample without replacement. Stopping once every file has been used fixes the old
    # loop, which never terminated when fewer than num_samples distinct files existed.
    selected_images = []
    while len(selected_images) < num_samples and remaining:
        candidates = list(remaining)
        chosen_class = random.choices(
            candidates, weights=[class_weights[c] for c in candidates], k=1
        )[0]
        pool = remaining[chosen_class]
        chosen_file = pool.pop(random.randrange(len(pool)))
        if not pool:
            del remaining[chosen_class]  # class exhausted
        selected_images.append((chosen_class, chosen_file))

    if not selected_images:
        print(f"⚠️ No extracted region images found in: {masked_dir}")
        return

    n_rows = len(selected_images)
    # squeeze=False keeps axs 2-D ([n_rows, 2]) even when there is a single row
    fig, axs = plt.subplots(n_rows, 2, figsize=(12, 4 * n_rows), squeeze=False)

    for idx, (class_name, file_name) in enumerate(selected_images):
        masked_path = os.path.join(masked_dir, class_name, file_name)

        # Region files are named "<original stem>_<annotation id>.png", so the text
        # before the last underscore is the original image's stem.
        original_basename = file_name.rsplit("_", 1)[0]

        original_path = None
        for ext in (".jpg", ".png", ".jpeg"):
            candidate = os.path.join(original_dir, original_basename + ext)
            if os.path.exists(candidate):
                original_path = candidate
                break

        if not original_path:
            print(f"⚠️ Original not found for: {original_basename}")
            axs[idx, 0].text(0.5, 0.5, "Original image not found", ha='center', va='center')
            axs[idx, 0].set_title("Original (Not Found)")
        else:
            axs[idx, 0].imshow(Image.open(original_path).convert("RGB"))
            axs[idx, 0].set_title("Original")
        axs[idx, 0].axis("off")

        axs[idx, 1].imshow(Image.open(masked_path).convert("RGB"))
        axs[idx, 1].set_title(f"Extracted Region ({class_name})")
        axs[idx, 1].axis("off")

    fig.tight_layout()

    if vis_dir:
        save_path = os.path.join(vis_dir, "original_vs_extracted.png")
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"✅ Visualization saved to: {save_path}")

    plt.show()

# Render originals (Roboflow train images) beside their extracted regions (Step 5 output)
show_before_after_masked(
    IMG_DIR,
    FEATURES_DIR,
    num_samples=10,
    vis_dir=EXTRACT_VIS_DIR,
)

In [None]:
# Step 7: Configure K-Fold Cross-Validation

from sklearn.model_selection import KFold, train_test_split

# Cross-validation and training budget shared by the cells below
K_FOLDS, TEST_SPLIT = 5, 0.2            # CV folds; per-class fraction held out as the common test set
EPOCHS_BASE_MODELS, EPOCHS_ENSEMBLE = 200, 50   # maximum epochs for base models / ensemble head

print(
    f"✅ K-Fold Cross-Validation configured with {K_FOLDS} folds",
    f"✅ Epochs for base models: {EPOCHS_BASE_MODELS}",
    f"✅ Epochs for ensemble model: {EPOCHS_ENSEMBLE}",
    sep="\n",
)

In [None]:
# Step 8: Implement K-Fold Cross-Validation Split
#
# For every class folder in FEATURES_DIR, hold out TEST_SPLIT of its files as the shared
# test set. Split the remainder into K_FOLDS train/val folds. Files are copied to
#   SPLITS_DIR/test/<class>/   and   SPLITS_DIR/fold_<k>/{train,val}/<class>/
# NOTE(review): the split is per region file, so regions cut from the same source image
# can land in different splits. Group by the image-stem prefix of the filename (e.g.
# sklearn GroupKFold) if that leakage matters for the reported results.

import shutil
import numpy as np
from sklearn.model_selection import KFold, train_test_split

# Per-class counts: split_counts holds fold 0 (train/val) and the test set;
# fold_counts[k - 1] holds the train/val counts of fold k (k = 1 .. K_FOLDS-1).
split_counts = {'train': {}, 'val': {}, 'test': {}}
fold_counts = [{'train': {}, 'val': {}} for _ in range(K_FOLDS - 1)]

# Remove split folders from earlier runs. Files are only ever added below, so a re-run
# would otherwise mix stale files into the new folds and test set.
for entry in os.listdir(SPLITS_DIR):
    entry_path = os.path.join(SPLITS_DIR, entry)
    if os.path.isdir(entry_path) and (entry == "test" or entry.startswith("fold_")):
        shutil.rmtree(entry_path)

# Create split directories
os.makedirs(os.path.join(SPLITS_DIR, "test"), exist_ok=True)
for fold in range(K_FOLDS):
    os.makedirs(os.path.join(SPLITS_DIR, f"fold_{fold}", "train"), exist_ok=True)
    os.makedirs(os.path.join(SPLITS_DIR, f"fold_{fold}", "val"), exist_ok=True)


def _copy_into(file_names, src_dir, dst_dir):
    """Copy each file name in file_names from src_dir into dst_dir (created if missing)."""
    os.makedirs(dst_dir, exist_ok=True)
    for name in file_names:
        shutil.copy2(os.path.join(src_dir, name), os.path.join(dst_dir, name))


# Process each class (sorted, so the output order is stable)
for class_name in sorted(os.listdir(FEATURES_DIR)):
    class_path = os.path.join(FEATURES_DIR, class_name)
    if not os.path.isdir(class_path):
        continue

    # Sorted so the seeded splits below do not depend on filesystem listing order
    files = sorted(f for f in os.listdir(class_path) if f.endswith(('.jpg', '.png', '.jpeg')))

    # train_test_split needs a non-empty remainder...
    if len(files) < K_FOLDS + 1:
        print(f"⚠️ Warning: Class '{class_name}' has only {len(files)} samples. Skipping.")
        continue

    train_val_files, test_files = train_test_split(files, test_size=TEST_SPLIT, random_state=42)

    # ...and KFold needs at least K_FOLDS samples after the test hold-out. With
    # TEST_SPLIT=0.2, 6 files leave only 4, and kf.split would raise ValueError.
    if len(train_val_files) < K_FOLDS:
        print(f"⚠️ Warning: Class '{class_name}' has only {len(train_val_files)} non-test samples "
              f"(< {K_FOLDS} folds). Skipping.")
        continue

    _copy_into(test_files, class_path, os.path.join(SPLITS_DIR, "test", class_name))
    split_counts['test'][class_name] = len(test_files)

    # K-fold over the remaining (train + val) files
    kf = KFold(n_splits=K_FOLDS, shuffle=True, random_state=42)
    train_val_files = np.array(train_val_files)  # enables fancy indexing by fold indices

    for fold, (train_idx, val_idx) in enumerate(kf.split(train_val_files)):
        train_files = train_val_files[train_idx]
        val_files = train_val_files[val_idx]

        fold_dir = os.path.join(SPLITS_DIR, f"fold_{fold}")
        _copy_into(train_files, class_path, os.path.join(fold_dir, "train", class_name))
        _copy_into(val_files, class_path, os.path.join(fold_dir, "val", class_name))

        counts = split_counts if fold == 0 else fold_counts[fold - 1]
        counts['train'][class_name] = len(train_files)
        counts['val'][class_name] = len(val_files)

# Visualize class distribution across splits (for first fold and test)
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

split_df = pd.DataFrame({
    'Train (Fold 0)': pd.Series(split_counts['train']),
    'Validation (Fold 0)': pd.Series(split_counts['val']),
    'Test': pd.Series(split_counts['test'])
}).fillna(0).astype(int)

if split_df.empty:
    print("⚠️ No class had enough samples to split; skipping the distribution plot.")
else:
    # Plot on an explicit axes; the old plt.figure() + df.plot(figsize=...) left an empty extra figure
    fig, ax = plt.subplots(figsize=(12, 6))
    split_df.plot(kind='bar', ax=ax)
    ax.set_title('Class Distribution Across Splits (Fold 0 shown)')
    ax.set_xlabel('Class')
    ax.set_ylabel('Number of Images')
    ax.tick_params(axis='x', labelrotation=45)
    fig.tight_layout()
    fig.savefig(os.path.join(VIS_DIR, 'class_distribution_fold0.png'), dpi=300, bbox_inches='tight')
    plt.show()

# Save split statistics
split_df.to_csv(os.path.join(METRICS_DIR, 'split_statistics.csv'))

# Totals summed over all folds (each fold re-uses the same train+val pool)
all_fold_counts = [split_counts] + fold_counts
train_total = sum(sum(fc['train'].values()) for fc in all_fold_counts)
val_total = sum(sum(fc['val'].values()) for fc in all_fold_counts)
test_total = sum(split_counts['test'].values())

print(f"✅ K-Fold Cross-Validation split completed with {K_FOLDS} folds")
print(f"Training samples: ~{train_total // K_FOLDS} per fold")
print(f"Validation samples: ~{val_total // K_FOLDS} per fold")
print(f"Test samples: {test_total}")

In [None]:
# Step 9: Define Dataset and Data Loaders with K-fold support

from torch.utils.data import Dataset, WeightedRandomSampler, DataLoader
from torchvision import transforms
from collections import defaultdict
from PIL import Image

class WasteRegionDataset(Dataset):
    """
    Image-folder dataset of extracted waste regions laid out as ``root_dir/<class>/<image>``.

    Attributes
    ----------
    samples : list of (str, int)
        (image path, label id) pairs, ordered by class then by filename.
    label2id : dict
        Class folder name -> integer label.
    sample_weights : list of float
        ``1 / (#samples of that sample's class)`` per sample, for a WeightedRandomSampler.
    """

    def __init__(self, root_dir, transform=None, label2id=None):
        """
        Parameters
        ----------
        root_dir : str
            Directory containing one sub-folder per class.
        transform : callable, optional
            Applied to each PIL image. Defaults to resize-224, random flips/rotation,
            colour jitter, tensor conversion and ImageNet normalisation.
        label2id : dict, optional
            Fixed class-name -> id mapping (e.g. a training split's ``label2id``) so several
            splits share identical ids. If omitted, ids follow sorted folder-name order.

        Raises
        ------
        ValueError
            If ``label2id`` is given and a class folder under ``root_dir`` is missing from it.
        """
        self.samples = []
        self.transform = transform or transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.RandomRotation(10),
            transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.03),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        classes = sorted(d for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d)))
        if label2id is None:
            self.label2id = {c: i for i, c in enumerate(classes)}
        else:
            unknown = [c for c in classes if c not in label2id]
            if unknown:
                raise ValueError(f"Class folders {unknown} in {root_dir} are missing from label2id")
            self.label2id = dict(label2id)

        class_counts = defaultdict(int)
        for c in classes:
            class_dir = os.path.join(root_dir, c)
            label = self.label2id[c]
            # sorted(): sample order (and therefore sampler indices) is reproducible across filesystems
            for f in sorted(os.listdir(class_dir)):
                if f.lower().endswith(('.jpg', '.png', '.jpeg')):
                    self.samples.append((os.path.join(class_dir, f), label))
                    class_counts[label] += 1

        # Inverse class frequency per sample (empty list when there are no samples)
        self.sample_weights = [1.0 / class_counts[label] for _, label in self.samples]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        img = Image.open(path).convert('RGB')
        img = self.transform(img)
        return img, label

# Evaluation preprocessing with no randomness: resize to 224x224, convert to a tensor,
# then normalise with the ImageNet channel mean/std (no flips, rotation or jitter)
test_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Function to create dataloaders for a specific fold
def create_fold_dataloaders(fold_idx, batch_size=16, num_workers=2):
    """
    Build the train and validation DataLoaders for one cross-validation fold.

    Training batches are drawn by a WeightedRandomSampler over the dataset's
    ``sample_weights`` (with replacement, len(train set) draws per epoch).
    Validation uses the deterministic ``test_transform`` in a fixed order.

    Parameters
    ----------
    fold_idx : int
        Fold number; reads ``SPLITS_DIR/fold_<fold_idx>/{train,val}``.
    batch_size : int
        Batch size for both loaders.
    num_workers : int
        DataLoader worker processes for both loaders.

    Returns
    -------
    tuple
        (train_loader, val_loader, label2id of the training dataset)

    Raises
    ------
    ValueError
        If the fold has no training images, or if its train and val splits assign
        different label ids to the class folders.
    """
    fold_dir = os.path.join(SPLITS_DIR, f"fold_{fold_idx}")
    fold_train_ds = WasteRegionDataset(os.path.join(fold_dir, "train"))
    fold_val_ds = WasteRegionDataset(os.path.join(fold_dir, "val"), transform=test_transform)

    if not fold_train_ds.sample_weights:
        # The old fallback (shuffle=True on an empty dataset) only failed later inside
        # RandomSampler with an opaque message; fail here with a clear one instead.
        raise ValueError(f"No training samples found for fold {fold_idx} in {fold_dir}")

    # Each dataset derives ids from the folders it finds. A class folder missing from one
    # split would shift every later id and silently corrupt validation metrics.
    if fold_val_ds.label2id != fold_train_ds.label2id:
        raise ValueError(
            f"Fold {fold_idx}: train/val class folders differ "
            f"({sorted(fold_train_ds.label2id)} vs {sorted(fold_val_ds.label2id)})"
        )

    train_sampler = WeightedRandomSampler(
        fold_train_ds.sample_weights,
        len(fold_train_ds.sample_weights),
        replacement=True
    )
    fold_train_loader = DataLoader(
        fold_train_ds, batch_size=batch_size,
        sampler=train_sampler, num_workers=num_workers
    )
    fold_val_loader = DataLoader(
        fold_val_ds, batch_size=batch_size,
        shuffle=False, num_workers=num_workers
    )

    return fold_train_loader, fold_val_loader, fold_train_ds.label2id

# Shared held-out test set (identical for every fold), read in a fixed order without augmentation
test_ds = WasteRegionDataset(os.path.join(SPLITS_DIR, "test"), transform=test_transform)
test_loader = DataLoader(test_ds, batch_size=16, shuffle=False, num_workers=2)

# The test folders define the class vocabulary; class_names[k] is the name of label id k
num_classes = len(test_ds.label2id)
class_names = sorted(test_ds.label2id, key=test_ds.label2id.get)

print(
    f"✅ Dataset and DataLoader configuration for K-fold completed",
    f"Number of classes: {num_classes}",
    f"Test samples: {len(test_ds)}",
    "\nClass names:",
    sep="\n",
)
for i, name in enumerate(class_names):
    print(f"  - Class {i}: {name}")

In [None]:
# Step 10: Define the Adaptive Weighted Ensemble Model

import torch.nn as nn
import torch.nn.functional as F
from timm import create_model

# Base model loader function
def get_base_model(model_name, num_classes):
    """Build the timm architecture `model_name` with timm's default pretrained weights and a `num_classes`-way head."""
    return create_model(model_name, pretrained=True, num_classes=num_classes)

# Adaptive Weighted Ensemble Class
class AdaptiveWeightedEnsemble(nn.Module):
    """
    Ensemble that combines the softmax outputs of several base models with per-sample weights.

    For sample i, the weight of model m is proportional to
    ``|base_weights[m]| * condition_encoder(x)[i, m]``, normalised to sum to 1 over models.
    Base-model forward passes run under ``torch.no_grad()``, so only ``condition_encoder``
    and ``base_weights`` receive gradients. :meth:`train` keeps the base models in eval mode.

    The output is a probability distribution, not logits. To train on it, use
    ``nn.NLLLoss`` on ``torch.log(output)`` rather than ``nn.CrossEntropyLoss``, which
    would apply a second softmax.
    """

    def __init__(self, models, num_classes, base_weights=None):
        """
        Parameters
        ----------
        models : list of nn.Module
            Base models; each is expected to map the input batch to [batch, num_classes] logits.
        num_classes : int
            Number of output classes (stored as ``n_classes``).
        base_weights : sequence of float, optional
            Initial global weight per model, normalised to sum to 1. Uniform if omitted.
        """
        super().__init__()
        self.models = nn.ModuleList(models)
        self.n_models = len(models)
        self.n_classes = num_classes

        # Learnable global (input-independent) weight per model
        if base_weights is None:
            self.base_weights = nn.Parameter(torch.ones(self.n_models) / self.n_models)
        else:
            weights = torch.tensor(base_weights, dtype=torch.float32)
            self.base_weights = nn.Parameter(weights / weights.sum())  # Normalize

        # Small conv net mapping each input image to per-sample model weights
        self.condition_encoder = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=2),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(16, self.n_models),
            # Normalise over the model axis of the [batch, n_models] output. The previous
            # dim=0 normalised across the batch, so a sample's weights depended on which
            # other samples happened to share its batch.
            nn.Softmax(dim=1),
        )

    def train(self, mode=True):
        """
        Set train/eval mode on the ensemble head while keeping the base models in eval mode.

        The base models are frozen. In train mode their Dropout would be active, and
        BatchNorm would keep updating running statistics (even under ``torch.no_grad()``).
        """
        super().train(mode)
        for model in self.models:
            model.eval()
        return self

    def forward(self, x):
        """
        Compute the adaptively weighted ensemble prediction.

        Parameters
        ----------
        x : torch.Tensor
            Image batch [batch_size, 3, height, width] (3 input channels required by
            ``condition_encoder``).

        Returns
        -------
        torch.Tensor
            Class probabilities [batch_size, n_classes]. Each row sums to 1 unless every
            base weight is zero.
        """
        condition_weights = self.condition_encoder(x)  # [batch_size, n_models]

        with torch.no_grad():  # Base models are frozen
            all_outputs = torch.stack(
                [F.softmax(model(x), dim=1) for model in self.models]
            )  # [n_models, batch_size, n_classes]

        # abs() keeps the learnable global weights non-negative. A weight that drifted
        # below zero during training would otherwise yield "probabilities" outside [0, 1].
        base = self.base_weights.abs().view(-1, 1, 1)                       # [n_models, 1, 1]
        weights = base * condition_weights.transpose(0, 1).unsqueeze(-1)   # [n_models, batch_size, 1]
        weights = weights / weights.sum(dim=0, keepdim=True).clamp_min(1e-12)

        return (all_outputs * weights).sum(dim=0)  # [batch_size, n_classes]


print("✅ Adaptive Weighted Ensemble model defined")

In [None]:
# Step 11: Define Training Functions for Individual Models and Ensemble

import torch.optim as optim
from tqdm.notebook import tqdm
from sklearn.utils.class_weight import compute_class_weight
from transformers import get_cosine_schedule_with_warmup
import pandas as pd
import matplotlib.pyplot as plt

def _mixup_data(x, y, alpha=0.4):
    """
    Blend each sample with a randomly chosen partner from the same batch (mixup).

    Returns (mixed_x, y_a, y_b, lam): the mixed inputs, the original targets, the
    partners' targets and the Beta(alpha, alpha)-distributed mixing coefficient.
    """
    lam = np.random.beta(alpha, alpha)
    # Permutation drawn from the CPU generator and then moved, as before (same RNG stream)
    index = torch.randperm(x.size(0)).to(x.device)
    mixed_x = lam * x + (1 - lam) * x[index]
    return mixed_x, y, y[index], lam


def _mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Mixup loss: lam-weighted interpolation of the loss against both targets."""
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)


def _balanced_class_weights(labels, num_classes):
    """
    Return 'balanced' class weights ``n_samples / (n_present_classes * count_c)``, indexed by label id.

    Equals sklearn's ``compute_class_weight('balanced', classes=np.unique(labels))`` when
    every class occurs. The result always has (at least) num_classes entries, so it lines
    up with the logits. Classes absent from ``labels`` get weight 1.0 instead of shrinking
    the vector and shifting every later class's weight.
    """
    counts = np.bincount(np.asarray(labels, dtype=np.int64), minlength=num_classes)
    weights = np.ones(len(counts), dtype=np.float64)
    present = counts > 0
    if present.any():
        weights[present] = len(labels) / (present.sum() * counts[present])
    return weights


def _plot_training_history(history, save_path):
    """Plot train/val loss and validation accuracy side by side, save to save_path, then show."""
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 4))

    ax_loss.plot(history['train_loss'], label='Train Loss')
    ax_loss.plot(history['val_loss'], label='Val Loss')
    ax_loss.set(title='Loss Curves', xlabel='Epoch', ylabel='Loss')
    ax_loss.legend()

    ax_acc.plot(history['val_acc'], label='Val Accuracy')
    ax_acc.set(title='Accuracy Curve', xlabel='Epoch', ylabel='Accuracy')
    ax_acc.legend()

    fig.tight_layout()
    fig.savefig(save_path, dpi=300)
    plt.show()


def train_single_model(model_name, num_classes, train_loader, val_loader, device,
                       epochs=100, patience=15, use_mixup=True, label_smoothing=0.1,
                       use_class_weights=True, run_name=None):
    """
    Train a single model and return the trained model and its validation accuracy

    The checkpoint with the highest validation accuracy is kept and reloaded at the end.
    Training stops early after `patience` epochs without improvement.

    Parameters
    ----------
    model_name : str
        Name of the timm model architecture to train
    num_classes : int
        Number of output classes
    train_loader : DataLoader
        DataLoader for training data; its dataset must expose ``samples`` as
        (path, label) pairs when use_class_weights is True
    val_loader : DataLoader
        DataLoader for validation data
    device : torch.device
        Device to train on (cuda or cpu)
    epochs : int, optional
        Maximum number of epochs to train for
    patience : int, optional
        Early stopping patience (number of epochs without improvement)
    use_mixup : bool, optional
        Whether to use mixup data augmentation
    label_smoothing : float, optional
        Label smoothing factor for loss function
    use_class_weights : bool, optional
        Weight the loss by balanced inverse class frequency (default True, as before)
    run_name : str, optional
        Prefix for the checkpoint, history CSV and plot files (defaults to model_name).
        Pass something fold-specific (e.g. f"{model_name}_fold{k}") so K-fold runs do not
        overwrite each other's best checkpoint.

    Returns
    -------
    tuple
        (trained model, best validation accuracy, training history)
    """
    run_name = run_name or model_name
    model = get_base_model(model_name, num_classes).to(device)

    # NOTE(review): create_fold_dataloaders already balances classes with a
    # WeightedRandomSampler, so these loss weights re-balance a second time;
    # consider use_class_weights=False in that setting.
    class_weights = None
    if use_class_weights:
        train_labels = [label for _, label in train_loader.dataset.samples]
        class_weights = torch.tensor(
            _balanced_class_weights(train_labels, num_classes), dtype=torch.float
        ).to(device)

    criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=label_smoothing)
    optimizer = optim.AdamW(model.parameters(), lr=5e-5, weight_decay=1e-3)
    # Scheduler is stepped per batch: 3 epochs of warm-up, then cosine decay over `epochs` epochs
    scheduler = get_cosine_schedule_with_warmup(
        optimizer, num_warmup_steps=3 * len(train_loader),
        num_training_steps=len(train_loader) * epochs
    )

    best_val_acc = 0.0
    checkpoint_saved = False
    no_improve_epochs = 0
    history = {'train_loss': [], 'val_loss': [], 'val_acc': []}
    checkpoint_path = os.path.join(MODELS_DIR, f"{run_name}_best.pth")

    for epoch in range(epochs):
        # Training
        model.train()
        train_loss = 0.0
        for inputs, targets in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} [Train]"):
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()

            if use_mixup:
                inputs, y_a, y_b, lam = _mixup_data(inputs, targets)
                outputs = model(inputs)
                loss = _mixup_criterion(criterion, outputs, y_a, y_b, lam)
            else:
                outputs = model(inputs)
                loss = criterion(outputs, targets)

            loss.backward()
            optimizer.step()
            scheduler.step()
            train_loss += loss.item()

        train_loss /= max(len(train_loader), 1)
        history['train_loss'].append(train_loss)

        # Validation
        model.eval()
        val_loss, correct, total = 0.0, 0, 0
        with torch.no_grad():
            for inputs, targets in tqdm(val_loader, desc=f"Epoch {epoch+1}/{epochs} [Val]"):
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                val_loss += criterion(outputs, targets).item()
                correct += (outputs.argmax(dim=1) == targets).sum().item()
                total += targets.size(0)

        val_loss /= max(len(val_loader), 1)
        val_acc = correct / total if total > 0 else 0
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        print(f"Epoch {epoch+1}/{epochs} - Train loss: {train_loss:.4f}, Val loss: {val_loss:.4f}, Val acc: {val_acc:.4f}")

        # Save best model. Always save the first epoch, so a run whose accuracy never beats
        # 0.0 still reloads its own weights (not a crash or another run's stale file).
        if val_acc > best_val_acc or not checkpoint_saved:
            best_val_acc = val_acc
            no_improve_epochs = 0
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'val_acc': val_acc,
                'val_loss': val_loss,
            }, checkpoint_path)
            checkpoint_saved = True
            print(f"New best model saved with val_acc: {val_acc:.4f}")
        else:
            no_improve_epochs += 1

        # Early stopping
        if no_improve_epochs >= patience:
            print(f"Early stopping triggered after {epoch+1} epochs")
            break

    # Load best model (only if this run actually wrote a checkpoint)
    if checkpoint_saved:
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])

    # Save and plot training history
    pd.DataFrame(history).to_csv(os.path.join(LOGS_DIR, f"{run_name}_history.csv"), index=False)
    _plot_training_history(history, os.path.join(VIS_DIR, f"{run_name}_training_curves.png"))

    return model, best_val_acc, history

def train_ensemble(ensemble, train_loader, val_loader, device, epochs=50, patience=10, lr=0.001):
    """
    Train the adaptive ensemble (condition encoder and base weights only).

    The base models are frozen: their parameters get ``requires_grad=False``
    and, on every training epoch, they are switched back to ``eval()`` right
    after ``ensemble.train()``. Without that, ``ensemble.train()`` would
    recursively put their BatchNorm/Dropout layers in training mode, so every
    training batch would rewrite the "frozen" models' BatchNorm running
    statistics. The same base-model objects are passed to every fold's
    ensemble and are evaluated on their own later, so that drift would leak
    into all of them.

    Parameters:
    -----------
    ensemble : AdaptiveWeightedEnsemble
        The ensemble model to train; must expose ``models``,
        ``condition_encoder`` and ``base_weights``.
    train_loader : DataLoader
        DataLoader for training data
    val_loader : DataLoader
        DataLoader for validation data
    device : torch.device
        Device to train on
    epochs : int, optional
        Maximum number of epochs to train for
    patience : int, optional
        Early stopping patience (epochs without a validation-accuracy improvement)
    lr : float, optional
        Learning rate for Adam

    Returns:
    --------
    tuple
        (ensemble with the best-validation-accuracy weights of this run restored,
         best validation accuracy, training history dict)
    """
    ensemble = ensemble.to(device)

    # Freeze the base models
    for model in ensemble.models:
        for param in model.parameters():
            param.requires_grad = False

    # Only optimize the condition encoder and base weights
    params_to_update = list(ensemble.condition_encoder.parameters()) + [ensemble.base_weights]
    optimizer = optim.Adam(params_to_update, lr=lr)
    criterion = nn.CrossEntropyLoss()

    best_val_acc = 0.0
    # Whether *this* call has written the checkpoint. The checkpoint path is the
    # same for every call, so without this flag a run that never improves on
    # 0.0 would reload a previous fold's weights (or fail if the file is absent).
    checkpoint_saved = False
    no_improve_epochs = 0
    history = {'train_loss': [], 'val_loss': [], 'val_acc': []}

    checkpoint_path = os.path.join(ENSEMBLE_DIR, "adaptive_ensemble_best.pth")

    for epoch in range(epochs):
        # Training: only the ensemble's own layers are in train mode;
        # the frozen base models are kept in eval mode (see docstring).
        ensemble.train()
        for model in ensemble.models:
            model.eval()
        train_loss = 0.0

        for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} [Train]"):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()

            outputs = ensemble(inputs)
            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        train_loss /= len(train_loader)
        history['train_loss'].append(train_loss)

        # Validation
        ensemble.eval()
        val_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for inputs, labels in tqdm(val_loader, desc=f"Epoch {epoch+1}/{epochs} [Val]"):
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = ensemble(inputs)
                val_loss += criterion(outputs, labels).item()

                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        val_loss /= len(val_loader)
        val_acc = correct / total if total > 0 else 0
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        print(f"Epoch {epoch+1}/{epochs} - Train loss: {train_loss:.4f}, Val loss: {val_loss:.4f}, Val acc: {val_acc:.4f}")

        # Save best model (the first epoch always saves, so a checkpoint from
        # this run exists before it is reloaded below)
        if not checkpoint_saved or val_acc > best_val_acc:
            best_val_acc = val_acc
            no_improve_epochs = 0
            torch.save({
                'epoch': epoch,
                'model_state_dict': ensemble.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
                'val_loss': val_loss,
            }, checkpoint_path)
            checkpoint_saved = True
            print(f"New best ensemble saved with val_acc: {val_acc:.4f}")
        else:
            no_improve_epochs += 1

        # Early stopping
        if no_improve_epochs >= patience:
            print(f"Early stopping triggered after {epoch+1} epochs")
            break

    # Restore the best weights from this run (nothing to restore if no epoch ran)
    if checkpoint_saved:
        checkpoint = torch.load(checkpoint_path, map_location=device)
        ensemble.load_state_dict(checkpoint['model_state_dict'])

    # Save training history
    pd.DataFrame(history).to_csv(os.path.join(LOGS_DIR, "adaptive_ensemble_history.csv"), index=False)

    # Plot training history
    plt.figure(figsize=(12, 4))

    plt.subplot(1, 2, 1)
    plt.plot(history['train_loss'], label='Train Loss')
    plt.plot(history['val_loss'], label='Val Loss')
    plt.title('Loss Curves')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(history['val_acc'], label='Val Accuracy')
    plt.title('Accuracy Curve')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()

    plt.tight_layout()
    plt.savefig(os.path.join(ENSEMBLE_DIR, "adaptive_ensemble_training_curves.png"), dpi=300)
    plt.show()

    return ensemble, best_val_acc, history

print("✅ Training functions defined")

In [None]:
# Step 12: Train Base Models with K-Fold Cross-Validation
#
# For every fold and architecture this cell either reloads a checkpoint left by
# a previous run (so Restart & Run All does not retrain) or trains the model
# with train_single_model and saves a checkpoint plus a fold-specific history
# CSV. It then summarises validation accuracies and keeps, per architecture,
# the fold model with the highest validation accuracy.

# Define the model architectures to use
model_architectures = [
    'efficientnet_b0',
    'mobilenetv3_small_100',
    'resnet18',
    'mobilenetv2_100',
    'efficientnet_b1'
]

# Per-architecture lists, one entry per fold (in fold order)
fold_models = {arch: [] for arch in model_architectures}
fold_val_accuracies = {arch: [] for arch in model_architectures}
model_histories = {arch: [] for arch in model_architectures}

# Train models for each fold
for fold_idx in range(K_FOLDS):
    print(f"\n{'='*50}")
    print(f"STARTING FOLD {fold_idx+1}/{K_FOLDS}")
    print(f"{'='*50}")

    # Create dataloaders for this fold
    train_loader, val_loader, label2id = create_fold_dataloaders(fold_idx)

    for arch in model_architectures:
        print(f"\n🔄 Training model: {arch} for fold {fold_idx+1}/{K_FOLDS}")

        checkpoint_path = os.path.join(MODELS_DIR, f"{arch}_fold{fold_idx}_best.pth")
        # train_single_model writes its history to "{arch}_history.csv", which
        # every fold overwrites. Keep a fold-specific copy (written after
        # training below) so the cached branch can actually find it.
        history_path = os.path.join(LOGS_DIR, f"{arch}_fold{fold_idx}_history.csv")

        if os.path.exists(checkpoint_path):
            print(f"Loading existing model from {checkpoint_path}")
            model = get_base_model(arch, num_classes).to(device)
            checkpoint = torch.load(checkpoint_path, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
            val_acc = checkpoint['val_acc']

            if os.path.exists(history_path):
                history = pd.read_csv(history_path).to_dict('list')
            else:
                history = {'train_loss': [], 'val_loss': [], 'val_acc': []}
        else:
            model, val_acc, history = train_single_model(
                model_name=arch,
                num_classes=num_classes,
                train_loader=train_loader,
                val_loader=val_loader,
                device=device,
                epochs=EPOCHS_BASE_MODELS,
                patience=20,  # Increased patience
                use_mixup=True,
                label_smoothing=0.1
            )

            # Save model with fold information
            torch.save({
                'model_state_dict': model.state_dict(),
                'val_acc': val_acc,
                'fold': fold_idx
            }, checkpoint_path)
            pd.DataFrame(history).to_csv(history_path, index=False)

        # Store the model and accuracy for this fold
        fold_models[arch].append(model)
        fold_val_accuracies[arch].append(val_acc)
        model_histories[arch].append(history)

        print(f"✅ Model {arch} - Fold {fold_idx+1} - Validation accuracy: {val_acc:.4f}")

# Calculate average validation accuracy across folds
avg_val_accuracies = {arch: np.mean(accs) for arch, accs in fold_val_accuracies.items()}

# Summary table: one row per (architecture, fold), then one 'Average' row per architecture
per_fold_rows = [
    {'Architecture': arch, 'Fold': fold, 'Validation_Accuracy': acc}
    for arch in model_architectures
    for fold, acc in enumerate(fold_val_accuracies[arch])
]
average_rows = [
    {'Architecture': arch, 'Fold': 'Average', 'Validation_Accuracy': avg_val_accuracies[arch]}
    for arch in model_architectures
]
model_summary = pd.DataFrame(per_fold_rows + average_rows)
model_summary.to_csv(os.path.join(METRICS_DIR, 'k_fold_base_models_summary.csv'), index=False)

print("\n✅ K-Fold Cross-Validation training complete!")
print("\nAverage validation accuracies across folds:")
for arch, acc in avg_val_accuracies.items():
    print(f"{arch}: {acc:.4f}")

# Visualize average validation accuracies
plt.figure(figsize=(10, 6))
bars = plt.bar(model_architectures, [avg_val_accuracies[arch] for arch in model_architectures])
plt.title('Average Base Model Validation Accuracy Across K-Folds')
plt.xlabel('Model Architecture')
plt.ylabel('Average Validation Accuracy')
plt.xticks(rotation=45)

# Add value annotations
for bar in bars:
    height = bar.get_height()
    plt.annotate(f'{height:.4f}',
                 xy=(bar.get_x() + bar.get_width() / 2, height),
                 xytext=(0, 3),
                 textcoords="offset points",
                 ha='center', va='bottom')

plt.tight_layout()
plt.savefig(os.path.join(VIS_DIR, 'k_fold_base_model_accuracies.png'), dpi=300)
plt.show()

# Select the best model from each architecture (based on validation accuracy)
best_models = {}
best_fold_idx = {}

for arch in model_architectures:
    best_fold = int(np.argmax(fold_val_accuracies[arch]))
    best_models[arch] = fold_models[arch][best_fold]
    best_fold_idx[arch] = best_fold
    print(f"Best {arch} model from fold {best_fold+1} with accuracy {fold_val_accuracies[arch][best_fold]:.4f}")

# Use the best models for the ensemble
trained_models = [best_models[arch] for arch in model_architectures]
val_accuracies = [fold_val_accuracies[arch][best_fold_idx[arch]] for arch in model_architectures]

In [None]:
# Step 13: Create and Train Adaptive Weighted Ensemble with K-Fold
#
# NOTE(review): every fold's ensemble wraps the same best-per-architecture base
# models chosen in Step 12, and those were trained on different folds. A fold's
# validation split may therefore contain images that some base model was
# trained on, which would make the per-fold ensemble accuracies below
# optimistic. The held-out test evaluation is the less biased estimate — confirm
# the fold construction in create_fold_dataloaders.

# Normalize validation accuracies for initial weights
base_weights = np.array(val_accuracies)
base_weights = base_weights / base_weights.sum()

print(f"Creating ensemble with {len(trained_models)} models")
print(f"Initial weights based on best validation accuracy: {base_weights}")

# Create K-Fold ensembles
fold_ensembles = []
fold_ensemble_accs = []

for fold_idx in range(K_FOLDS):
    print(f"\n{'='*50}")
    print(f"TRAINING ENSEMBLE FOR FOLD {fold_idx+1}/{K_FOLDS}")
    print(f"{'='*50}")

    # Create dataloaders for this fold
    train_loader, val_loader, _ = create_fold_dataloaders(fold_idx)

    # Create the ensemble with the best models from each architecture
    ensemble = AdaptiveWeightedEnsemble(trained_models, num_classes, base_weights=base_weights)

    # Check if ensemble already exists
    checkpoint_path = os.path.join(ENSEMBLE_DIR, f"adaptive_ensemble_fold{fold_idx}_best.pth")

    if os.path.exists(checkpoint_path):
        print(f"Loading existing ensemble from {checkpoint_path}")
        # weights_only=False: this checkpoint stores 'base_weights' as a numpy
        # array, which PyTorch 2.6+ rejects under its default weights_only=True.
        # Only acceptable because the file is written by this notebook — never
        # load an untrusted checkpoint this way.
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
        ensemble.load_state_dict(checkpoint['model_state_dict'])
        # load_state_dict copies values into the existing parameters without
        # moving them; the freshly built ensemble may not be on `device`, so
        # move it explicitly (the training branch does this inside train_ensemble).
        ensemble = ensemble.to(device)
        # This cell saves the accuracy under 'val_accuracy'; fall back to the
        # train_ensemble checkpoint key 'val_acc' for older files.
        ensemble_val_acc = checkpoint.get('val_accuracy', checkpoint.get('val_acc'))
    else:
        # Train the adaptive weights
        ensemble, ensemble_val_acc, _ = train_ensemble(
            ensemble=ensemble,
            train_loader=train_loader,
            val_loader=val_loader,
            device=device,
            epochs=EPOCHS_ENSEMBLE,
            patience=10,  # Increased patience
            lr=0.001
        )

        # Save the final model for this fold
        torch.save({
            'model_state_dict': ensemble.state_dict(),
            'base_weights': ensemble.base_weights.detach().cpu().numpy(),
            'n_models': ensemble.n_models,
            'n_classes': ensemble.n_classes,
            'model_architectures': model_architectures,
            'val_accuracy': ensemble_val_acc,
            'fold': fold_idx
        }, checkpoint_path)

    fold_ensembles.append(ensemble)
    fold_ensemble_accs.append(ensemble_val_acc)
    print(f"✅ Adaptive Weighted Ensemble for fold {fold_idx+1} - Validation accuracy: {ensemble_val_acc:.4f}")

# Calculate average ensemble accuracy
avg_ensemble_acc = np.mean(fold_ensemble_accs)
print(f"\nAverage ensemble validation accuracy: {avg_ensemble_acc:.4f}")

# Select the best ensemble (plain int so it serialises as a Python int, not np.int64)
best_ensemble_idx = int(np.argmax(fold_ensemble_accs))
ensemble = fold_ensembles[best_ensemble_idx]
ensemble_val_acc = fold_ensemble_accs[best_ensemble_idx]

print(f"Best ensemble from fold {best_ensemble_idx+1} with validation accuracy: {ensemble_val_acc:.4f}")

# Save the final best ensemble
final_save_path = os.path.join(ENSEMBLE_DIR, "adaptive_ensemble_final_best.pth")
torch.save({
    'model_state_dict': ensemble.state_dict(),
    'base_weights': ensemble.base_weights.detach().cpu().numpy(),
    'n_models': ensemble.n_models,
    'n_classes': ensemble.n_classes,
    'model_architectures': model_architectures,
    'val_accuracy': ensemble_val_acc,
    'fold': best_ensemble_idx
}, final_save_path)

print(f"Best ensemble saved to {final_save_path}")

# Visualize ensemble accuracies across folds
plt.figure(figsize=(8, 6))
x = np.arange(K_FOLDS)
plt.bar(x, fold_ensemble_accs)
plt.axhline(y=avg_ensemble_acc, color='r', linestyle='--', label=f'Average: {avg_ensemble_acc:.4f}')
plt.xlabel('Fold')
plt.ylabel('Validation Accuracy')
plt.title('Ensemble Validation Accuracy Across Folds')
plt.xticks(x, [f"Fold {i+1}" for i in range(K_FOLDS)])
plt.legend()
plt.grid(axis='y', linestyle='--', alpha=0.7)
plt.tight_layout()
plt.savefig(os.path.join(ENSEMBLE_DIR, 'ensemble_k_fold_accuracies.png'), dpi=300)
plt.show()

In [None]:
# Step 14: Evaluate Base Models and Ensemble on Test Set

from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
import json

def evaluate_model(model, dataloader, device):
    """Run ``model`` over ``dataloader`` in eval mode without gradients.

    Returns
    -------
    tuple
        ``(accuracy, predictions, labels)``: the fraction of samples whose
        argmax prediction equals the label, followed by two lists holding
        one entry per sample in dataloader order.
    """
    model.eval()
    predictions, targets = [], []

    with torch.no_grad():
        for batch_inputs, batch_labels in tqdm(dataloader, desc="Evaluating"):
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)
            logits = model(batch_inputs)
            batch_preds = torch.max(logits, 1)[1]

            predictions.extend(batch_preds.cpu().numpy())
            targets.extend(batch_labels.cpu().numpy())

    matches = np.array(predictions) == np.array(targets)
    accuracy = matches.mean()

    return accuracy, predictions, targets

def plot_confusion_matrix(true_labels, pred_labels, class_names, save_path=None, title="Confusion Matrix"):
    """Plot the raw and the row-normalised confusion matrix as two heatmaps.

    Parameters
    ----------
    true_labels, pred_labels : sequence of int
        Ground-truth and predicted class indices. Label ``k`` is assumed to
        correspond to ``class_names[k]``.
    class_names : list of str
        Tick labels for both axes.
    save_path : str, optional
        If given, the raw matrix is saved there and the normalised one next
        to it with a ``_normalized`` suffix before the extension.
    title : str, optional
        Figure title; the normalised figure appends " (Normalized)".
    """
    # Fix the label set to every class index so the matrix is always
    # len(class_names) x len(class_names). Otherwise a class missing from both
    # labels and predictions shrinks the matrix and the tick labels no longer
    # line up with its rows/columns.
    cm = confusion_matrix(true_labels, pred_labels, labels=list(range(len(class_names))))

    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names, ax=ax)
    ax.set_title(title)
    ax.set_ylabel('True Label')
    ax.set_xlabel('Predicted Label')
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')

    plt.show()

    # Row-normalise (fraction of each true class). Rows for classes with no
    # true samples stay at 0 instead of becoming NaN from a 0/0 division.
    row_sums = cm.sum(axis=1, keepdims=True)
    cm_norm = np.divide(cm.astype(float), row_sums,
                        out=np.zeros(cm.shape, dtype=float), where=row_sums > 0)

    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(cm_norm, annot=True, fmt='.2f', cmap='YlOrRd',
                xticklabels=class_names, yticklabels=class_names, ax=ax)
    ax.set_title(f"{title} (Normalized)")
    ax.set_ylabel('True Label')
    ax.set_xlabel('Predicted Label')
    fig.tight_layout()

    if save_path:
        base, ext = os.path.splitext(save_path)
        norm_path = f"{base}_normalized{ext}"
        fig.savefig(norm_path, dpi=300, bbox_inches='tight')

    plt.show()

def save_classification_report(true_labels, pred_labels, class_names, save_path=None):
    """Build sklearn's classification report as a dict and optionally save it as JSON.

    Parameters
    ----------
    true_labels, pred_labels : sequence of int
        Ground-truth and predicted class indices. Label ``k`` is assumed to
        correspond to ``class_names[k]``.
    class_names : list of str
        Names used as the report's per-class keys.
    save_path : str, optional
        If given, the report dict is written there with ``json.dump``.

    Returns
    -------
    dict
        The ``output_dict=True`` form of ``classification_report``.
    """
    # Passing every class index keeps the report valid when a class is missing
    # from the labels/predictions. Without `labels`, sklearn raises ValueError
    # because target_names is longer than the set of labels it found.
    # zero_division=0 reports 0 instead of warning for classes never predicted.
    report = classification_report(true_labels, pred_labels,
                                   labels=list(range(len(class_names))),
                                   target_names=class_names,
                                   output_dict=True,
                                   zero_division=0)

    if save_path:
        with open(save_path, 'w') as f:
            json.dump(report, f, indent=4)

    return report

# Every printed report covers all class indices, so a class missing from the
# test labels/predictions does not make sklearn raise on the length of
# target_names.
report_labels = list(range(len(class_names)))

# Evaluate each base model
base_model_accuracies = []
for arch in model_architectures:
    model = best_models[arch]
    print(f"\n📊 Evaluating base model: {arch}")
    accuracy, preds, labels = evaluate_model(model, test_loader, device)
    base_model_accuracies.append(accuracy)

    # Save results
    plot_confusion_matrix(
        labels, preds, class_names,
        save_path=os.path.join(VIS_DIR, f"{arch}_confusion_matrix.png"),
        title=f"Confusion Matrix - {arch}"
    )

    save_classification_report(
        labels, preds, class_names,
        save_path=os.path.join(METRICS_DIR, f"{arch}_classification_report.json")
    )

    print(f"Test Accuracy: {accuracy:.4f}")
    print("Classification Report:")
    print(classification_report(labels, preds, labels=report_labels,
                                target_names=class_names, zero_division=0))

# Evaluate the ensemble
print("\n📊 Evaluating Adaptive Weighted Ensemble")
ensemble_acc, ensemble_preds, ensemble_labels = evaluate_model(ensemble, test_loader, device)

# Save ensemble results
plot_confusion_matrix(
    ensemble_labels, ensemble_preds, class_names,
    save_path=os.path.join(ENSEMBLE_DIR, "ensemble_confusion_matrix.png"),
    title="Confusion Matrix - Adaptive Ensemble"
)

ensemble_report = save_classification_report(
    ensemble_labels, ensemble_preds, class_names,
    save_path=os.path.join(ENSEMBLE_DIR, "ensemble_classification_report.json")
)

print(f"Ensemble Test Accuracy: {ensemble_acc:.4f}")
print("Ensemble Classification Report:")
print(classification_report(ensemble_labels, ensemble_preds, labels=report_labels,
                            target_names=class_names, zero_division=0))

# Compare models
all_accuracies = base_model_accuracies + [ensemble_acc]
all_model_names = model_architectures + ['Adaptive Ensemble']

plt.figure(figsize=(12, 6))
bars = plt.bar(all_model_names, all_accuracies)
plt.title('Model Comparison - Test Accuracy')
plt.xlabel('Model')
plt.ylabel('Accuracy')
plt.xticks(rotation=45)
plt.grid(axis='y', linestyle='--', alpha=0.7)
plt.tight_layout()

# Add value annotations
for bar in bars:
    height = bar.get_height()
    plt.annotate(f'{height:.4f}',
                 xy=(bar.get_x() + bar.get_width() / 2, height),
                 xytext=(0, 3),
                 textcoords="offset points",
                 ha='center', va='bottom')

# Highlight the best model
best_idx = int(np.argmax(all_accuracies))
bars[best_idx].set_color('green')

plt.savefig(os.path.join(VIS_DIR, "model_comparison.png"), dpi=300)
plt.show()

# Relative improvement (in percent) of the ensemble over the best base model
best_base_acc = max(base_model_accuracies)

print(f"\nBest base model accuracy: {best_base_acc:.4f}")
print(f"Ensemble accuracy: {ensemble_acc:.4f}")
if best_base_acc > 0:
    improvement = float((ensemble_acc - best_base_acc) / best_base_acc * 100)
    print(f"Improvement: {improvement:.2f}%")
else:
    # Relative improvement is undefined when every base model scores 0.
    improvement = None
    print("Improvement: n/a (best base model accuracy is 0)")

# Save comparison results (plain floats keep the JSON portable)
comparison_results = {
    'base_models': {arch: float(acc) for arch, acc in zip(model_architectures, base_model_accuracies)},
    'ensemble': float(ensemble_acc),
    'improvement': improvement
}

with open(os.path.join(METRICS_DIR, 'model_comparison.json'), 'w') as f:
    json.dump(comparison_results, f, indent=4)

In [None]:
# Step 15: Visualize Adaptive Weights

def visualize_adaptive_weights(ensemble, test_loader, device, num_samples=6, save_dir=None):
    """Visualize how the ensemble weights its base models for individual images.

    Uses the first ``num_samples`` images of the first batch from
    ``test_loader``. For each image, the condition encoder's output is
    multiplied elementwise by the ensemble's ``base_weights`` and renormalised
    to sum to 1. The image and those weights are plotted side by side. A
    second figure shows ``base_weights`` alone, labelled with the global
    ``model_architectures``.

    NOTE(review): this combination rule is assumed to mirror
    AdaptiveWeightedEnsemble.forward (not visible here) — confirm.

    Parameters
    ----------
    ensemble : AdaptiveWeightedEnsemble
        Must expose ``base_weights``, ``condition_encoder`` and ``models``.
    test_loader : DataLoader
        Only its first batch is used.
    device : torch.device
        Device the condition encoder runs on.
    num_samples : int, optional
        Number of images to show; capped at the size of the first batch.
    save_dir : str, optional
        If given, both figures are saved there as PNG files.
    """
    ensemble.eval()
    # Get a batch of images (labels are not needed for this plot)
    images, _ = next(iter(test_loader))
    # The first batch can be smaller than the requested number of samples.
    num_samples = min(num_samples, images.size(0))
    if num_samples <= 0:
        print("No samples available to visualize adaptive weights.")
        return
    images = images[:num_samples].to(device)

    # Get base weights
    base_weights = ensemble.base_weights.detach().cpu().numpy()

    # Get condition-specific weights
    with torch.no_grad():
        condition_weights = ensemble.condition_encoder(images).cpu().numpy()

    # Combine and renormalise per sample
    final_weights = []
    for i in range(num_samples):
        sample_weights = base_weights * condition_weights[i]
        final_weights.append(sample_weights / sample_weights.sum())

    # Leave head-room above the tallest bar so its annotation stays visible;
    # a fixed upper limit of 0.5 would clip any weight above 0.5.
    y_top = max(0.5, max(float(np.max(w)) for w in final_weights) * 1.15)

    # squeeze=False keeps `axes` 2-D even when num_samples == 1.
    fig, axes = plt.subplots(num_samples, 2, figsize=(14, 3 * num_samples), squeeze=False)
    model_names = [f"Model {i+1}" for i in range(len(ensemble.models))]

    for i in range(num_samples):
        # Plot the image (assumes CHW tensors), min-max rescaled for display;
        # a constant image would otherwise divide by zero.
        img = images[i].cpu().permute(1, 2, 0).numpy()
        value_range = img.max() - img.min()
        img = (img - img.min()) / value_range if value_range > 0 else np.zeros_like(img)
        axes[i, 0].imshow(img)
        axes[i, 0].set_title(f"Sample {i+1}")
        axes[i, 0].axis('off')

        # Plot the weights
        bars = axes[i, 1].bar(model_names, final_weights[i])
        axes[i, 1].set_ylim(0, y_top)
        axes[i, 1].set_title(f"Model Weights for Sample {i+1}")

        # Add value annotations
        for bar in bars:
            height = bar.get_height()
            axes[i, 1].annotate(f'{height:.3f}',
                                xy=(bar.get_x() + bar.get_width() / 2, height),
                                xytext=(0, 3),
                                textcoords="offset points",
                                ha='center', va='bottom')

    fig.tight_layout()

    if save_dir:
        save_path = os.path.join(save_dir, "adaptive_weights_visualization.png")
        fig.savefig(save_path, dpi=300, bbox_inches='tight')

    plt.show()

    # Also visualize the learned base weights
    plt.figure(figsize=(10, 6))
    bars = plt.bar(model_architectures, base_weights)
    plt.title('Base Model Weights in Adaptive Ensemble')
    plt.xlabel('Model Architecture')
    plt.ylabel('Weight')
    plt.xticks(rotation=45)

    for bar in bars:
        height = bar.get_height()
        plt.annotate(f'{height:.3f}',
                     xy=(bar.get_x() + bar.get_width() / 2, height),
                     xytext=(0, 3),
                     textcoords="offset points",
                     ha='center', va='bottom')

    plt.tight_layout()

    if save_dir:
        save_path = os.path.join(save_dir, "base_weights.png")
        plt.savefig(save_path, dpi=300, bbox_inches='tight')

    plt.show()

print("Visualizing adaptive weights for different water conditions...")
visualize_adaptive_weights(
    ensemble=ensemble,
    test_loader=test_loader,
    device=device,
    num_samples=6,
    save_dir=ENSEMBLE_DIR
)

In [None]:
# Step 16: Ablation Study for Adaptive Weighting

def run_ablation_study(models, ensemble, test_loader, device, class_names):
    """Compare ensemble strategies on the test set.

    Methods compared:
    - each base model alone (the plot reports their mean accuracy);
    - majority vote over the models' argmax predictions;
    - unweighted mean of the models' raw outputs;
    - mean of raw outputs weighted by the global ``val_accuracies`` list;
    - the adaptive ensemble's own forward pass.

    Results are plotted to ENSEMBLE_DIR and written as CSVs to METRICS_DIR.

    Parameters
    ----------
    models : list of torch.nn.Module
        Base models, in the same order as the global ``model_architectures``.
    ensemble : torch.nn.Module
        The trained adaptive ensemble.
    test_loader : DataLoader
        Test data; iterated once and cached in memory.
    device : torch.device
        Device used for inference.
    class_names : list of str
        Currently unused; kept for interface compatibility.

    Returns
    -------
    dict
        Accuracies under 'per_model' (list), 'majority_vote', 'simple_average',
        'weighted_average' and 'adaptive_weights'. (Previously the function
        returned None; callers that ignore the return value are unaffected.)

    Raises
    ------
    ValueError
        If ``models`` is empty.
    """
    if not models:
        raise ValueError("run_ablation_study needs at least one base model")

    # Evaluate deterministically (no dropout, fixed BatchNorm statistics)
    # regardless of which mode earlier cells left the models in.
    for model in models:
        model.eval()
    ensemble.eval()

    # Cache the test batches once so every method sees identical samples in order.
    all_inputs = []
    all_labels = []
    for inputs, labels in tqdm(test_loader, desc="Collecting test data"):
        all_inputs.append(inputs)
        all_labels.append(labels)
    test_labels = torch.cat(all_labels).to(device)

    # Results containers
    results = {
        'per_model': [],
        'majority_vote': None,
        'simple_average': None,
        'weighted_average': None,
        'adaptive_weights': None
    }

    # Get predictions from individual models
    print("Evaluating individual models...")
    all_model_outputs = []

    with torch.no_grad():
        for i, model in enumerate(models):
            model_name = model_architectures[i]
            print(f"  Processing {model_name}...")

            model_outputs = torch.cat([
                model(inputs.to(device))
                for inputs in tqdm(all_inputs, desc=f"Processing {model_name}", leave=False)
            ])
            all_model_outputs.append(model_outputs)

            accuracy = (model_outputs.argmax(dim=1) == test_labels).float().mean().item()
            results['per_model'].append(accuracy)

    # Stack all outputs: [n_models, n_samples, n_classes]
    all_model_outputs = torch.stack(all_model_outputs)
    n_classes = all_model_outputs.size(2)

    # 1. Majority Voting: count one vote per model per sample, vectorised.
    # argmax returns the first maximal index, so ties go to the lowest class id.
    print("Evaluating majority voting...")
    model_preds = all_model_outputs.argmax(dim=2)  # [n_models, n_samples]
    vote_counts = torch.nn.functional.one_hot(model_preds, num_classes=n_classes).sum(dim=0)
    majority_preds = vote_counts.argmax(dim=1)
    results['majority_vote'] = (majority_preds == test_labels).float().mean().item()

    # 2. Simple Average (Equal weights) of the raw model outputs
    print("Evaluating simple average...")
    avg_preds = all_model_outputs.mean(dim=0).argmax(dim=1)
    results['simple_average'] = (avg_preds == test_labels).float().mean().item()

    # 3. Weighted Average (Based on validation accuracy)
    print("Evaluating weighted average...")
    val_weights = torch.tensor(val_accuracies, device=device)
    val_weights = val_weights / val_weights.sum()

    weighted_outputs = (all_model_outputs * val_weights.view(-1, 1, 1)).sum(dim=0)
    weighted_preds = weighted_outputs.argmax(dim=1)
    results['weighted_average'] = (weighted_preds == test_labels).float().mean().item()

    # 4. Adaptive Weighted Ensemble
    print("Evaluating adaptive weighted ensemble...")
    with torch.no_grad():
        ensemble_outputs = torch.cat([
            ensemble(inputs.to(device))
            for inputs in tqdm(all_inputs, desc="Processing ensemble", leave=False)
        ])
    ensemble_preds = ensemble_outputs.argmax(dim=1)
    results['adaptive_weights'] = (ensemble_preds == test_labels).float().mean().item()

    # Visualize results
    ablation_methods = ['Individual Models (Avg)', 'Majority Vote', 'Simple Average',
                        'Weighted Average', 'Adaptive Ensemble']
    ablation_accs = [sum(results['per_model']) / len(results['per_model']),
                     results['majority_vote'],
                     results['simple_average'],
                     results['weighted_average'],
                     results['adaptive_weights']]

    plt.figure(figsize=(12, 6))
    bars = plt.bar(ablation_methods, ablation_accs, color=['lightgray', 'lightblue',
                                                          'lightgreen', 'orange', 'red'])

    # Add value annotations
    for bar in bars:
        height = bar.get_height()
        plt.annotate(f'{height:.4f}',
                     xy=(bar.get_x() + bar.get_width() / 2, height),
                     xytext=(0, 3),
                     textcoords="offset points",
                     ha='center', va='bottom')

    plt.title('Ablation Study: Comparing Ensemble Methods')
    plt.xlabel('Method')
    plt.ylabel('Accuracy')
    plt.ylim([min(ablation_accs) - 0.05, max(ablation_accs) + 0.05])
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    plt.tight_layout()

    plt.savefig(os.path.join(ENSEMBLE_DIR, 'ablation_study.png'), dpi=300)
    plt.show()

    # Save ablation results
    ablation_df = pd.DataFrame({
        'Method': ablation_methods,
        'Accuracy': ablation_accs
    })
    ablation_df.to_csv(os.path.join(METRICS_DIR, 'ablation_results.csv'), index=False)

    # Also save individual model results
    model_df = pd.DataFrame({
        'Model': model_architectures,
        'Accuracy': results['per_model']
    })
    model_df.to_csv(os.path.join(METRICS_DIR, 'individual_model_results.csv'), index=False)

    print("✅ Ablation study completed")
    print("\nComparison of ensemble methods:")
    for method, acc in zip(ablation_methods, ablation_accs):
        print(f"{method}: {acc:.4f}")

    return results

# Run ablation study (assigned so the returned dict is not echoed as cell output)
ablation_results = run_ablation_study(
    models=trained_models,
    ensemble=ensemble,
    test_loader=test_loader,
    device=device,
    class_names=class_names
)

In [None]:
# Step 17: Per-Class Performance Analysis

from sklearn.metrics import precision_recall_fscore_support

def analyze_per_class_performance(models, ensemble, test_loader, device, class_names):
    """Compute, plot and save per-class precision, recall and F1 on the test set.

    Produces:
    - one grouped bar chart per metric (every base model plus the adaptive ensemble);
    - a bar chart of per-class test support;
    - ``per_class_metrics.csv``.
    All outputs are written to METRICS_DIR.

    Parameters
    ----------
    models : list of torch.nn.Module
        Base models, in the same order as the global ``model_architectures``.
    ensemble : torch.nn.Module
        The trained adaptive ensemble.
    test_loader : DataLoader
        Test data.
    device : torch.device
        Device used for inference.
    class_names : list of str
        Class names; label ``k`` is assumed to correspond to ``class_names[k]``.

    Raises
    ------
    ValueError
        If ``test_loader`` yields no batches.
    """
    # Deterministic inference regardless of the mode earlier cells left models in.
    for model in models:
        model.eval()
    ensemble.eval()

    # Collect per-batch predictions and concatenate once at the end
    # (re-concatenating inside the loop is quadratic in the test-set size).
    model_batches = [[] for _ in models]
    ensemble_batches = []
    label_batches = []

    with torch.no_grad():
        for inputs, labels in tqdm(test_loader, desc="Evaluating models"):
            inputs = inputs.to(device)
            for i, model in enumerate(models):
                model_batches[i].append(model(inputs).argmax(dim=1).cpu().numpy())
            ensemble_batches.append(ensemble(inputs).argmax(dim=1).cpu().numpy())
            label_batches.append(labels.cpu().numpy())

    if not label_batches:
        raise ValueError("test_loader yielded no batches")

    true_labels = np.concatenate(label_batches)
    model_results = [np.concatenate(batches) for batches in model_batches]
    ensemble_preds = np.concatenate(ensemble_batches)

    # Passing every class id keeps each metric array at len(class_names) even
    # when a class is absent from both the labels and a model's predictions.
    # Without it, array lengths can differ between models and stop lining up
    # with class_names.
    class_ids = np.arange(len(class_names))

    # Per-class metrics: base models first, ensemble last. `support` depends
    # only on the true labels, so the value from the final (ensemble) call is used.
    all_precisions, all_recalls, all_f1s = [], [], []
    for preds in model_results + [ensemble_preds]:
        precision, recall, f1, support = precision_recall_fscore_support(
            true_labels, preds, labels=class_ids, average=None, zero_division=0
        )
        all_precisions.append(precision)
        all_recalls.append(recall)
        all_f1s.append(f1)

    all_precisions = np.array(all_precisions)
    all_recalls = np.array(all_recalls)
    all_f1s = np.array(all_f1s)

    # Plotting: bars of one class are grouped and centred on its tick
    metrics = ['Precision', 'Recall', 'F1-Score']
    all_metrics = [all_precisions, all_recalls, all_f1s]

    x = np.arange(len(class_names))
    width = 0.8 / (len(models) + 1)  # Bar width

    for metric_name, metric_data in zip(metrics, all_metrics):
        fig, ax = plt.subplots(figsize=(12, 8))

        for i in range(len(models)):
            ax.bar(x - 0.4 + (i + 0.5) * width, metric_data[i], width, label=model_architectures[i])

        # Ensemble bars with distinctive color
        ax.bar(x - 0.4 + (len(models) + 0.5) * width, metric_data[-1], width,
               label='Adaptive Ensemble', color='red')

        ax.set_xlabel('Class')
        ax.set_ylabel(metric_name)
        ax.set_title(f'Per-Class {metric_name}')
        ax.set_xticks(x)
        ax.set_xticklabels(class_names, rotation=45)
        ax.legend()
        ax.grid(axis='y', linestyle='--', alpha=0.7)
        fig.tight_layout()

        fig.savefig(os.path.join(METRICS_DIR, f'per_class_{metric_name.lower()}.png'), dpi=300)
        plt.show()

    # Additional visualization: Class support (number of samples)
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.bar(class_names, support)
    ax.set_xlabel('Class')
    ax.set_ylabel('Number of Test Samples')
    ax.set_title('Class Distribution in Test Set')
    ax.tick_params(axis='x', labelrotation=45)
    ax.grid(axis='y', linestyle='--', alpha=0.7)
    fig.tight_layout()
    fig.savefig(os.path.join(METRICS_DIR, 'test_class_distribution.png'), dpi=300)
    plt.show()

    # Save metrics as CSV (one row per model x class)
    records = [
        {
            'Model': model_name,
            'Class': cls_name,
            'Precision': all_precisions[i, cls_idx],
            'Recall': all_recalls[i, cls_idx],
            'F1-Score': all_f1s[i, cls_idx],
            'Support': support[cls_idx],
        }
        for i, model_name in enumerate(model_architectures + ['Adaptive Ensemble'])
        for cls_idx, cls_name in enumerate(class_names)
    ]
    pd.DataFrame(records).to_csv(os.path.join(METRICS_DIR, 'per_class_metrics.csv'), index=False)
    print("✅ Per-class performance analysis completed and saved")

# Run per-class analysis
analyze_per_class_performance(
    models=trained_models,
    ensemble=ensemble,
    test_loader=test_loader,
    device=device,
    class_names=class_names
)

In [None]:
# Step 18: Interpretability Analysis with Grad-CAM

from pytorch_grad_cam import GradCAM, ScoreCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

def get_last_conv_layer(model):
    """Pick the layer to hook for Grad-CAM / Score-CAM.

    Resolution order:
      1. ``model.features[-1]`` when ``features`` is a non-empty
         ``nn.Sequential``/``nn.ModuleList`` (EfficientNet / MobileNet style
         backbones) and its last entry either has a ``weight`` or is a block
         that contains a ``Conv2d`` (e.g. a conv-norm-activation block).
      2. ``model.layer4[-1]`` (ResNet family).
      3. The last ``torch.nn.Conv2d`` registered anywhere in the module tree.

    Args:
        model: a ``torch.nn.Module``.

    Returns:
        The ``torch.nn.Module`` to use as CAM target layer.

    Raises:
        ValueError: if no suitable layer can be found.
    """
    features = getattr(model, 'features', None)
    # Only index into ``features`` when it is a real, non-empty container;
    # otherwise ``features[-1]`` would raise instead of falling through.
    if isinstance(features, (torch.nn.Sequential, torch.nn.ModuleList)) and len(features) > 0:
        last_stage = features[-1]
        # Composite blocks (conv + norm + activation) have no ``weight`` of their
        # own, so also accept any block that wraps a convolution.
        wraps_conv = any(isinstance(m, torch.nn.Conv2d) for m in last_stage.modules())
        if hasattr(last_stage, 'weight') or wraps_conv:
            return last_stage

    # ResNet family: last residual block of the final stage.
    if hasattr(model, 'layer4'):
        return model.layer4[-1]

    # Generic fallback: the last Conv2d in registration order.
    for module in reversed(list(model.modules())):
        if isinstance(module, torch.nn.Conv2d):
            return module

    raise ValueError("Could not find a convolutional layer for CAM")

def visualize_gradcam_comparison(models, ensemble, test_loader, device, class_names, n_samples=5):
    """Visualize GradCAM for multiple models and the ensemble on the same images.

    For each of the first ``n_samples`` test images, one figure with
    ``len(models) + 1`` columns is drawn and saved to
    ``INTERP_DIR/gradcam_comparison_<k>.png``:

    * row 0: the de-normalised input, titled with the column's model and its prediction;
    * row 1: the attention map (Grad-CAM on each base model's target layer,
      Score-CAM on ``ensemble.condition_encoder[0]`` for the ensemble column).

    Args:
        models: trained base models, in the same order as the global ``model_architectures``.
        ensemble: adaptive ensemble module exposing ``condition_encoder[0]``.
        test_loader: iterable yielding ``(inputs, labels)`` batches
            (assumed CPU tensors normalised with ImageNet statistics; TODO confirm).
        device: device the models run on.
        class_names: sequence mapping class index -> class name.
        n_samples: number of test images to visualise.
    """
    # Take the first ``n_samples`` (image, label) pairs from the loader.
    samples = []
    for inputs, labels in test_loader:
        samples.extend(zip(inputs, labels))
        if len(samples) >= n_samples:
            break
    samples = samples[:n_samples]

    # ImageNet normalisation constants, used to undo the input transform for display.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    for sample_idx, (image, true_label) in enumerate(samples):
        fig, axes = plt.subplots(2, len(models) + 1, figsize=(16, 8))

        img_np = np.clip(image.permute(1, 2, 0).numpy() * std + mean, 0, 1)
        input_tensor = image.unsqueeze(0).to(device)

        # Base models: prediction + Grad-CAM for the predicted class.
        for i, model in enumerate(models):
            model.eval()
            with torch.no_grad():
                pred = model(input_tensor).argmax(dim=1).item()

            cam = GradCAM(model=model, target_layers=[get_last_conv_layer(model)])
            grayscale_cam = cam(input_tensor=input_tensor, targets=[ClassifierOutputTarget(pred)])[0]
            cam_image = show_cam_on_image(img_np, grayscale_cam, use_rgb=True)

            # Each sample is saved as a separate file, so every figure needs the
            # model names in its column titles, not just the first one.
            axes[0, i].imshow(img_np)
            axes[0, i].set_title(f"{model_architectures[i]}\nPrediction: {class_names[pred]}")
            axes[0, i].axis('off')
            axes[1, i].imshow(cam_image)
            axes[1, i].set_title("Attention Map")
            axes[1, i].axis('off')

        # Ensemble column.
        ensemble.eval()
        with torch.no_grad():
            ensemble_pred = ensemble(input_tensor).argmax(dim=1).item()

        # Score-CAM is gradient-free (forward passes on masked inputs); it is
        # applied to the first layer of the ensemble's condition encoder.
        cam = ScoreCAM(model=ensemble, target_layers=[ensemble.condition_encoder[0]])
        grayscale_cam = cam(input_tensor=input_tensor, targets=[ClassifierOutputTarget(ensemble_pred)])[0]
        cam_image = show_cam_on_image(img_np, grayscale_cam, use_rgb=True)

        axes[0, -1].imshow(img_np)
        axes[0, -1].set_title(f"Adaptive Ensemble\nPrediction: {class_names[ensemble_pred]}")
        axes[0, -1].axis('off')
        axes[1, -1].imshow(cam_image)
        axes[1, -1].set_title("Attention Map")
        axes[1, -1].axis('off')

        # Global title
        true_idx = true_label.item()
        correct_txt = "CORRECT" if ensemble_pred == true_idx else "INCORRECT"
        fig.suptitle(f"Sample {sample_idx+1} | True: {class_names[true_idx]} | Ensemble: {correct_txt}",
                     fontsize=16)

        fig.tight_layout()
        fig.savefig(os.path.join(INTERP_DIR, f"gradcam_comparison_{sample_idx+1}.png"), dpi=300)
        plt.show()
        plt.close(fig)  # avoid accumulating open figures across samples

    print("✅ GradCAM visualizations generated and saved to interpretability directory")

# Generate Grad-CAM / Score-CAM attention maps for a handful of test images
visualize_gradcam_comparison(
    n_samples=5,
    class_names=class_names,
    device=device,
    test_loader=test_loader,
    ensemble=ensemble,
    models=trained_models,
)

In [None]:
# Step 19: Failure Analysis - Find cases where ensemble succeeds but individual models fail

def analyze_model_failures(models, ensemble, test_loader, device, class_names, n_samples=5):
    """Find and visualize cases where the ensemble succeeds but individual models fail.

    Scans ``test_loader`` in order and keeps the first ``n_samples`` images that
    the ensemble classifies correctly while at least one base model does not.
    Each kept image is shown with a per-model prediction table and saved to
    ``INTERP_DIR/failure_analysis_<k>.png``.

    Args:
        models: trained base models, ordered like the global ``model_architectures``.
        ensemble: adaptive ensemble module.
        test_loader: iterable yielding ``(inputs, labels)`` batches.
        device: device the models run on.
        class_names: sequence mapping class index -> class name.
        n_samples: maximum number of cases to collect and plot.
    """
    # Compare predictions in inference mode; BatchNorm/Dropout left in training
    # mode by an earlier cell would change the outputs being compared.
    ensemble.eval()
    for model in models:
        model.eval()

    # Each entry: (image on CPU, true label, ensemble prediction, [base-model predictions])
    success_cases = []
    with torch.no_grad():
        for inputs, labels in test_loader:
            if len(success_cases) >= n_samples:
                break
            inputs, labels = inputs.to(device), labels.to(device)

            ensemble_preds = ensemble(inputs).argmax(dim=1)
            ensemble_correct = ensemble_preds == labels
            if not ensemble_correct.any():
                continue

            model_preds = torch.stack([model(inputs).argmax(dim=1) for model in models])  # [n_models, batch]
            some_model_wrong = (model_preds != labels.unsqueeze(0)).any(dim=0)
            hit_indices = torch.nonzero(ensemble_correct & some_model_wrong).flatten().tolist()

            # Keep only as many hits as are still needed, in batch order.
            for i in hit_indices[:n_samples - len(success_cases)]:
                success_cases.append((
                    inputs[i].cpu(),
                    labels[i].item(),
                    ensemble_preds[i].item(),
                    model_preds[:, i].tolist(),
                ))

    # Visualize the success cases
    if not success_cases:
        print("❗ Could not find cases where ensemble succeeds but some models fail.")
        return

    # ImageNet normalisation constants, used to undo the input transform for display.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    model_names = model_architectures + ['Adaptive Ensemble']

    for idx, (img, true_label, ensemble_pred, model_preds) in enumerate(success_cases):
        img_np = np.clip(img.permute(1, 2, 0).numpy() * std + mean, 0, 1)

        fig, ax = plt.subplots(1, 1, figsize=(10, 8))
        ax.imshow(img_np)
        ax.set_title(f"True Class: {class_names[true_label]}", fontsize=14, pad=20)
        ax.axis('off')

        # Prediction table: one row per base model plus the ensemble.
        all_preds = model_preds + [ensemble_pred]
        correct = [pred == true_label for pred in all_preds]
        cell_text = [
            [name, class_names[pred], "✓" if is_correct else "✗"]
            for name, pred, is_correct in zip(model_names, all_preds, correct)
        ]

        # Add table below the image
        table = ax.table(cellText=cell_text,
                         colLabels=['Model', 'Prediction', 'Correct?'],
                         loc='bottom', bbox=[0, -0.5, 1, 0.3])
        table.auto_set_font_size(False)
        table.set_fontsize(10)
        table.scale(1, 1.5)

        # Row 0 is the header, so data row i is at table row i + 1.
        for row, is_correct in enumerate(correct, start=1):
            table[(row, 2)].set_facecolor('#d8f3dc' if is_correct else '#ffccd5')  # light green / light red

        fig.tight_layout()
        fig.savefig(os.path.join(INTERP_DIR, f"failure_analysis_{idx+1}.png"), dpi=300, bbox_inches='tight')
        plt.show()
        plt.close(fig)

    print(f"✅ Failure analysis visualizations generated for {len(success_cases)} samples")

# Look for test images the ensemble gets right but at least one base model misses
analyze_model_failures(
    n_samples=5,
    class_names=class_names,
    device=device,
    test_loader=test_loader,
    ensemble=ensemble,
    models=trained_models,
)

In [None]:
# Step 20: Condition-Based Analysis for Different Water Types

def _select_brightness_diverse_samples(test_loader, n_samples):
    """Stream ``test_loader`` and keep ``n_samples`` images with widely spread mean intensity.

    The brightness proxy is the mean value of the (normalised) input tensor.
    Once the pool is full, a new image replaces the kept image closest to the
    pool's mean brightness whenever the new image lies further from that mean.

    Returns:
        ``(images, brightness)`` lists, sorted by ascending brightness.
    """
    images, brightness = [], []
    if n_samples <= 0:
        return images, brightness
    for inputs, _ in test_loader:
        for img, value in zip(inputs, inputs.mean(dim=(1, 2, 3)).tolist()):
            if len(images) < n_samples:
                # clone() so a kept image does not pin its whole batch in memory.
                images.append(img.clone())
                brightness.append(value)
                continue
            centre = float(np.mean(brightness))
            # The least distinctive kept sample is the one nearest the centre.
            victim = min(range(len(brightness)), key=lambda k: abs(brightness[k] - centre))
            if abs(value - centre) > abs(brightness[victim] - centre):
                images[victim] = img.clone()
                brightness[victim] = value
    order = sorted(range(len(images)), key=brightness.__getitem__)
    return [images[k] for k in order], [brightness[k] for k in order]


def _adaptive_ensemble_weights(ensemble, img_tensor):
    """Return the normalised per-model weights ``base_weights * condition_weights`` for one image.

    NOTE(review): mirrors the weighting this notebook assumes the ensemble uses
    internally; confirm against the ensemble's ``forward``.
    """
    condition_weights = ensemble.condition_encoder(img_tensor).cpu().numpy()[0]
    base_weights = ensemble.base_weights.detach().cpu().numpy()
    final_weights = base_weights * condition_weights
    return final_weights / final_weights.sum()


def analyze_water_conditions(ensemble, test_loader, device, class_names, n_samples=10):
    """Analyze how adaptive weights change across different water conditions.

    Mean input intensity is used as a crude proxy for water condition. The
    function selects ``n_samples`` test images spread over that proxy, then:

    * plots them in a 4-column grid titled with the ensemble prediction
      (``ENSEMBLE_DIR/water_types.png``);
    * plots their per-model adaptive weights as a heatmap
      (``ENSEMBLE_DIR/adaptive_weights_heatmap.png``).

    Args:
        ensemble: adaptive ensemble exposing ``condition_encoder`` and ``base_weights``.
        test_loader: iterable yielding ``(inputs, labels)`` batches of 4-D image tensors.
        device: device the ensemble runs on.
        class_names: sequence mapping class index -> class name.
        n_samples: number of images to select.
    """
    water_samples, _ = _select_brightness_diverse_samples(test_loader, n_samples)
    if not water_samples:
        print("❗ No test samples available for water condition analysis.")
        return

    # Weights and predictions are computed once and reused by both figures.
    ensemble.eval()
    n_found = len(water_samples)
    predictions = []
    all_weights = np.zeros((n_found, len(model_architectures)))  # [n_samples, n_models]
    with torch.no_grad():
        for i, img in enumerate(water_samples):
            img_tensor = img.unsqueeze(0).to(device)
            all_weights[i] = _adaptive_ensemble_weights(ensemble, img_tensor)
            predictions.append(ensemble(img_tensor).argmax(dim=1).item())

    # Image grid: 4 columns and as many rows as needed (3x4 for 12 samples).
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    n_cols = min(4, n_found)
    n_rows = (n_found + n_cols - 1) // n_cols
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 10), squeeze=False)
    for i, ax in enumerate(axes.flat):
        ax.axis('off')
        if i >= n_found:
            continue  # unused grid cell
        img_np = np.clip(water_samples[i].permute(1, 2, 0).numpy() * std + mean, 0, 1)
        ax.imshow(img_np)
        ax.set_title(f"Water Type {i+1}: {class_names[predictions[i]]}")
    fig.tight_layout()
    fig.savefig(os.path.join(ENSEMBLE_DIR, 'water_types.png'), dpi=300)
    plt.show()
    plt.close(fig)

    # Heatmap of the adaptive weights (rows: samples, columns: models)
    fig, ax = plt.subplots(figsize=(12, 8))
    sns.heatmap(all_weights, annot=True, fmt=".3f", cmap="YlGnBu",
                xticklabels=model_architectures,
                yticklabels=[f"Water Type {i+1}" for i in range(n_found)],
                ax=ax)
    ax.set_title("Adaptive Weight Distribution Across Water Types")
    ax.set_xlabel("Model")
    ax.set_ylabel("Water Sample")
    fig.tight_layout()

    fig.savefig(os.path.join(ENSEMBLE_DIR, 'adaptive_weights_heatmap.png'), dpi=300)
    plt.show()
    plt.close(fig)

    print("✅ Water condition analysis completed")

# Inspect how the adaptive weights shift across visually different test images
analyze_water_conditions(
    n_samples=12,  # 3x4 grid
    class_names=class_names,
    device=device,
    test_loader=test_loader,
    ensemble=ensemble,
)

In [None]:
# Step 21: Generate Summary Statistics and Final Journal-Quality Visualizations

# 1. Create overall summary table
# Rows are gathered as (metric, value) pairs and turned into a DataFrame once.
summary_rows = [('Total Classes', num_classes)]

# Total samples = held-out test set + the train/val pool counted ONCE.
# Every K-fold split re-partitions the same pool, so summing the per-fold
# train/val counts counts each pooled sample once per fold.
# NOTE(review): assumes a single fold's train + val covers the whole pool;
# confirm against the fold-construction step.
test_total = len(test_ds)
pool_counts = fold_counts[0] if fold_counts else split_counts
train_val_total = sum(pool_counts['train'].values()) + sum(pool_counts['val'].values())
summary_rows.append(('Total Samples', test_total + train_val_total))

summary_rows.append(('K-Fold Cross Validation', f"{K_FOLDS}-fold"))
summary_rows.append(('Base Models', len(model_architectures)))

# Model performance
best_idx = np.argmax(base_model_accuracies)
summary_rows.append(('Best Base Model', f"{model_architectures[best_idx]} ({base_model_accuracies[best_idx]:.4f})"))
summary_rows.append(('Ensemble Accuracy', f"{ensemble_acc:.4f}"))
summary_rows.append(('Accuracy Improvement', f"{improvement:.2f}%"))

# Ablation results, only if an earlier ablation cell defined ``ablation_df``
if 'ablation_df' in globals():
    simple_avg_acc = ablation_df.loc[ablation_df['Method'] == 'Simple Average', 'Accuracy'].values[0]
    summary_rows.append(('Simple Average Accuracy', f"{simple_avg_acc:.4f}"))

    majority_acc = ablation_df.loc[ablation_df['Method'] == 'Majority Vote', 'Accuracy'].values[0]
    summary_rows.append(('Majority Vote Accuracy', f"{majority_acc:.4f}"))

# Save summary (summary_dict keeps its column-list layout)
summary_dict = {
    'Metric': [metric for metric, _ in summary_rows],
    'Value': [value for _, value in summary_rows],
}
summary_df = pd.DataFrame(summary_dict)
summary_df.to_csv(os.path.join(METRICS_DIR, 'overall_summary.csv'), index=False)

# 2. Create publication-quality comparison chart with K-fold information
# One group per model: K per-fold validation bars plus their average, then the
# ensemble bar. NOTE(review): ``ensemble_acc`` comes from an earlier step
# (presumably test-set accuracy), so it is not a fold-validation number; confirm.
fig, ax = plt.subplots(figsize=(14, 8))

x_positions = []
x_labels = []
group_stride = K_FOLDS + 2  # K fold bars, one "Avg" bar, one empty slot
for i, arch in enumerate(model_architectures):
    fold_accs = fold_val_accuracies[arch]
    x_pos = np.arange(K_FOLDS) + i * group_stride
    # Label every model's fold bars, not just the first, so each colour in the
    # legend can be matched to its architecture.
    ax.bar(x_pos, fold_accs, width=0.8, alpha=0.6, label=f"{arch} folds")
    x_positions.extend(x_pos)
    x_labels.extend(f"F{j+1}" for j in range(K_FOLDS))

    # Mean over folds for this model; only the first "Avg" bar enters the legend.
    avg_pos = x_pos[-1] + 1
    avg_acc = np.mean(fold_accs)
    ax.bar(avg_pos, avg_acc, width=0.8, color='black', alpha=0.8,
           label="Average" if i == 0 else "")
    x_positions.append(avg_pos)
    x_labels.append("Avg")

# Then show the ensemble accuracy, with a reference line across all groups
ens_pos = x_positions[-1] + 3
ax.bar(ens_pos, ensemble_acc, width=0.8, color='red', label="Adaptive Ensemble")
ax.axhline(y=ensemble_acc, color='red', linestyle='--', alpha=0.5)
x_positions.append(ens_pos)
x_labels.append("Ensemble")

ax.set_xticks(x_positions)
ax.set_xticklabels(x_labels, rotation=45)
ax.set_title('K-Fold Cross-Validation Results', fontsize=16)
ax.set_xlabel('Models and Folds', fontsize=14)
ax.set_ylabel('Accuracy', fontsize=14)
ax.legend()
ax.grid(axis='y', linestyle='--', alpha=0.7)
fig.tight_layout()

fig.savefig(os.path.join(VIS_DIR, "k_fold_comparison.png"), dpi=300, bbox_inches='tight')
plt.show()

# 3. Create final model comparison with error bars from cross-validation
fig, ax = plt.subplots(figsize=(10, 6))

# Mean and (population) std of fold validation accuracy for each base model
means = [np.mean(fold_val_accuracies[arch]) for arch in model_architectures]
stds = [np.std(fold_val_accuracies[arch]) for arch in model_architectures]

all_models = model_architectures + ['Adaptive Ensemble']
all_means = means + [ensemble_acc]
# The ensemble is a single evaluation with no fold spread, so it gets no error
# bar. A made-up one would misrepresent its variance in the figure.
all_stds = stds + [0.0]

bars = ax.bar(range(len(all_models)), all_means,
              color=['blue'] * len(model_architectures) + ['red'])
# Error bars only where a cross-validation spread actually exists.
ax.errorbar(range(len(means)), means, yerr=stds, fmt='none', ecolor='black', capsize=5)

# Value annotations, placed above the error-bar cap so they do not overlap it
for bar, mean_acc, std_acc in zip(bars, all_means, all_stds):
    ax.annotate(f'{mean_acc:.4f}',
                xy=(bar.get_x() + bar.get_width() / 2, bar.get_height() + std_acc),
                xytext=(0, 3),
                textcoords="offset points",
                ha='center', va='bottom')

ax.set_xticks(range(len(all_models)))
ax.set_xticklabels(all_models, rotation=45)
ax.set_title('Model Performance Comparison with K-Fold Cross-Validation', fontsize=16)
ax.set_xlabel('Model', fontsize=14)
ax.set_ylabel('Accuracy', fontsize=14)
ax.grid(axis='y', linestyle='--', alpha=0.7)
fig.tight_layout()

fig.savefig(os.path.join(VIS_DIR, "final_model_comparison.png"), dpi=300, bbox_inches='tight')
plt.show()

# Display summary: console recap of the headline numbers and where artefacts were written
_summary_rule = "=" * 80
print("\n📊 Final Model Summary")
print(_summary_rule)
print(f"Dataset: {num_classes} classes, {train_val_total + test_total} total samples")
print(f"Cross-validation: {K_FOLDS}-fold with {EPOCHS_BASE_MODELS} max epochs")
print(f"Best Base Model: {model_architectures[np.argmax(base_model_accuracies)]} - {max(base_model_accuracies):.4f}")
print(f"Adaptive Ensemble: {ensemble_acc:.4f} ({improvement:.2f}% improvement)")
print(_summary_rule)
print("\nSee visualization and result directories for all charts and metrics.")
for _dir_label, _dir_path in (
    ("Models", MODELS_DIR),
    ("Metrics", METRICS_DIR),
    ("Visualizations", VIS_DIR),
    ("Ensemble Analysis", ENSEMBLE_DIR),
    ("Interpretability", INTERP_DIR),
):
    print(f"  - {_dir_label}: {_dir_path}")

print("\n✅ Adaptive Weighted Ensemble implementation with K-fold cross-validation complete!")

In [None]:
import os
import shutil
import time

# Zip each output sub-folder into its own archive in /kaggle/working so the
# pieces can be downloaded separately. Paths are spelled out here, not taken
# from the Step 1 constants, so this cell also works on its own after a restart.
zip_output_dir = "/kaggle/working"
zip_source_root = os.path.join(zip_output_dir, "outputs")

# Define the folders to zip individually
folders = [
    "ensemble",
    "features",
    "interpretability",
    "logs",
    "metrics",
    "models",
    "splits",
    "visualizations"
]

print("🔄 Starting to zip individual folders...\n")

# Process each folder
for folder in folders:
    folder_path = os.path.join(zip_source_root, folder)
    zip_path = os.path.join(zip_output_dir, f"{folder}.zip")

    # Check if folder exists
    if not os.path.isdir(folder_path):
        print(f"⚠️ Warning: Folder {folder_path} not found. Skipping.")
        continue

    print(f"🔄 Zipping {folder}...")
    start_time = time.time()
    try:
        # make_archive always writes a fresh archive, so entries left over from
        # an earlier run cannot survive. Entries are stored as "<folder>/..."
        # rather than the absolute /kaggle/working/outputs/... path.
        shutil.make_archive(os.path.join(zip_output_dir, folder), "zip",
                            root_dir=zip_source_root, base_dir=folder)
    except OSError as exc:
        print(f"❌ Failed to create {folder}.zip ({exc})")
        continue

    # Get zip file size and completion time
    size_mb = os.path.getsize(zip_path) / (1024 * 1024)
    elapsed = time.time() - start_time
    print(f"✅ Created {folder}.zip: {size_mb:.2f} MB (in {elapsed:.1f} seconds)")

print("\n✅ All zipping completed!")
print(f"\nZip files created in {zip_output_dir}/:")
# List archives smallest first, with human-readable sizes.
zip_entries = sorted(
    (entry for entry in os.scandir(zip_output_dir) if entry.is_file() and entry.name.endswith(".zip")),
    key=lambda entry: entry.stat().st_size,
)
for entry in zip_entries:
    print(f"  {entry.stat().st_size / (1024 * 1024):10.2f} MB  {entry.name}")

print("\nTo download these files manually:")
print("1. Look for the zip files in the file browser on the right side")
print("2. Click the three dots (⋮) next to each file")
print("3. Select 'Download' from the dropdown menu")