In [1]:
# Part 1: Imports and Initial Setup
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, SubsetRandomSampler, Dataset
import numpy as np
import os
import torch.optim as optim
from torch.nn import functional as F
import matplotlib.pyplot as plt
from sklearn.metrics import *
from sklearn.model_selection import KFold
from copy import deepcopy
import logging
import json
import seaborn as sns
from pathlib import Path
import warnings
from typing import Tuple, Dict, Any, List
from torch.optim.lr_scheduler import OneCycleLR, CosineAnnealingWarmRestarts
import time
from tqdm import tqdm
import random
from PIL import Image

# Setup logging: records at INFO and above are written both to the file
# 'epaas_training.log' (a relative path) and to a stream handler.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('epaas_training.log'),
        logging.StreamHandler()
    ]
)

# Constants: default hyperparameters for the run
SEED = 42             # seed handed to set_seed()
NUM_CLASSES = 10      # number of target classes
BATCH_SIZE = 64       # samples per batch
NUM_EPOCHS = 5        # training epochs
BASE_LR = 0.03        # base learning rate
MOMENTUM = 0.9        # momentum coefficient
WEIGHT_DECAY = 5e-4   # weight-decay coefficient
NUM_WORKERS = 2       # DataLoader worker processes

# Set device and random seeds: use the GPU when CUDA is available, otherwise the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def set_seed(seed):
    """
    Seed every random number generator used by the notebook.

    Seeds Python's `random`, NumPy and PyTorch. When CUDA is available the
    CUDA generators are seeded too and cuDNN is switched to deterministic,
    non-benchmarking mode.

    Args:
        seed (int): Seed value applied to all generators.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    # Nothing GPU-specific to configure on a CPU-only machine
    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Seed all RNGs once, right after the helper is defined
set_seed(SEED)

# Create necessary directories
def create_directories(base_path="/content/drive/MyDrive/Epaas"):
    """
    Ensure the experiment folder layout exists.

    Creates `base_path` (with any missing parents) and the three
    subfolders 'checkpoints', 'results' and 'logs' beneath it.

    Args:
        base_path (str | Path): Root folder for all experiment artefacts.

    Returns:
        dict: Maps each subfolder name to its `Path`.
    """
    root = Path(base_path)
    root.mkdir(parents=True, exist_ok=True)

    created = {}
    for name in ('checkpoints', 'results', 'logs'):
        target = root / name
        target.mkdir(parents=True, exist_ok=True)
        created[name] = target
    return created

# Create the folders under the default base path and keep the name -> Path mapping
directories = create_directories()



class Config:
    """
    Container for the run's hyperparameters.

    General settings are copied from the module-level constants; the
    EPAAS-specific and training-specific values are fixed here.
    """
    def __init__(self):
        # General settings, mirrored from the module-level constants
        self.num_classes, self.batch_size, self.num_epochs = NUM_CLASSES, BATCH_SIZE, NUM_EPOCHS
        self.base_lr, self.momentum, self.weight_decay = BASE_LR, MOMENTUM, WEIGHT_DECAY
        self.num_workers, self.device = NUM_WORKERS, device

        # EPAAS specific parameters
        self.initial_threshold, self.final_threshold = 0.95, 0.8
        self.temperature = 1.0
        self.alpha = 0.5  # weight of the entropy-minimization term
        self.beta = 1.0   # weight of the consistency-regularization term

        # Training specific
        self.num_folds = 5
        self.early_stopping_patience = 10
        self.gradient_clip_val = 1.0

# Single shared configuration instance
config = Config()


In [2]:
# Part 2: Data Loading and Augmentation

class AdvancedAugmentation:
    """
    Holder for the three EPAAS transform pipelines.

    Attributes:
        weak: Flip, reflect-padded crop, mild colour jitter and small rotation.
        strong: The weak recipe with larger magnitudes plus affine warp,
            occasional Gaussian blur and random grayscale.
        test: Tensor conversion and normalization only.
    """
    def __init__(self, size=96, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
        def flip_and_crop():
            # Geometric opening shared by the weak and strong pipelines
            return [
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomCrop(size, padding=4, padding_mode='reflect'),
            ]

        def to_normalized_tensor():
            # Closing steps shared by all three pipelines
            return [transforms.ToTensor(), transforms.Normalize(mean, std)]

        weak_middle = [
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            transforms.RandomRotation(10),
        ]
        self.weak = transforms.Compose(
            flip_and_crop() + weak_middle + to_normalized_tensor()
        )

        strong_middle = [
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2),
            transforms.RandomRotation(20),
            transforms.RandomAffine(
                degrees=0,
                translate=(0.1, 0.1),
                scale=(0.9, 1.1),
                shear=5
            ),
            transforms.RandomApply(
                [transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))],
                p=0.3
            ),
            transforms.RandomGrayscale(p=0.2),
        ]
        self.strong = transforms.Compose(
            flip_and_crop() + strong_middle + to_normalized_tensor()
        )

        self.test = transforms.Compose(to_normalized_tensor())

class STL10Dataset(Dataset):
    """
    Thin wrapper around `torchvision.datasets.STL10` that applies transforms itself.

    The wrapped dataset is created with `transform=None`; `__getitem__`
    applies `self.transform`, which may be a single callable or a dict of
    named callables (one output view per key).
    """
    def __init__(self, root, split='train', transform=None, download=True):
        self.transform = transform
        self.dataset = torchvision.datasets.STL10(
            root=root,
            split=split,
            transform=None,
            download=download
        )

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """
        Return `(sample, label)` for index `idx`.

        `sample` is the untouched image when no transform is set, a dict of
        views when the transform is a dict, and the transformed image otherwise.
        """
        image, label = self.dataset[idx]
        tf = self.transform

        # No (or falsy) transform: hand back the stored image as-is
        if not tf:
            return image, label

        # Several named augmentations: one view per key
        if isinstance(tf, dict):
            views = {}
            for view_name, view_fn in tf.items():
                views[view_name] = view_fn(image)
            return views, label

        # Single augmentation
        return tf(image), label

class DataPrefetcher:
    """
    Data prefetcher for faster data loading.

    Wraps an iterable of `(inputs, labels)` batches, where `inputs` is
    either a tensor or a dict of tensors. With CUDA available, the next
    batch is copied to the GPU on a side stream while the current one is
    being consumed. Without CUDA the batches are passed through unchanged,
    so the same loop also runs on CPU-only machines.

    Usage:
        prefetcher = DataPrefetcher(loader)
        batch = prefetcher.next()
        while batch is not None:
            ...
            batch = prefetcher.next()
    """
    def __init__(self, loader):
        self.loader = iter(loader)
        # Only create a CUDA stream when there is a GPU to copy to
        self.stream = torch.cuda.Stream() if torch.cuda.is_available() else None
        self.next_data = None
        self.preload()

    def preload(self):
        """Fetch the next batch and, on CUDA, start its asynchronous host-to-device copy."""
        try:
            self.next_data = next(self.loader)
        except StopIteration:
            self.next_data = None
            return

        if self.stream is None:
            # CPU fallback: nothing to transfer
            return

        with torch.cuda.stream(self.stream):
            if isinstance(self.next_data[0], dict):
                self.next_data = (
                    {k: v.cuda(non_blocking=True)
                     for k, v in self.next_data[0].items()},
                    self.next_data[1].cuda(non_blocking=True)
                )
            else:
                self.next_data = (
                    self.next_data[0].cuda(non_blocking=True),
                    self.next_data[1].cuda(non_blocking=True)
                )

    def next(self):
        """
        Return the prefetched batch, or None once the loader is exhausted,
        and start loading the following one.
        """
        data = self.next_data

        if self.stream is not None:
            current = torch.cuda.current_stream()
            # The consumer stream must not touch the batch before the copy is done
            current.wait_stream(self.stream)
            if data is not None:
                # Tell the caching allocator the tensors are used on the consumer
                # stream, so their memory is not reused while still in flight.
                inputs, labels = data
                tensors = list(inputs.values()) if isinstance(inputs, dict) else [inputs]
                for tensor in tensors + [labels]:
                    tensor.record_stream(current)

        self.preload()
        return data

def get_stl10_dataloaders(config):
    """
    Get STL10 dataloaders with appropriate transformations.

    The labeled 'train' split is divided 80/20 (seeded with SEED) into a
    training and a validation part.

    Args:
        config: Object providing `batch_size` and `num_workers`.

    Returns:
        dict: DataLoaders under the keys
            'train'     - labeled batches `({'weak', 'strong'}, labels)`, shuffled,
                          last partial batch dropped
            'val'       - labeled batches `(weak_images, labels)`, not shuffled
            'unlabeled' - batches `({'weak', 'strong'}, labels)`, shuffled,
                          last partial batch dropped
            'test'      - batches `(images, labels)` with the test transform

    Raises:
        Exception: Any error raised while building datasets or loaders is
            logged and re-raised.
    """
    data_dir = Path("/content/drive/MyDrive/Epaas/data/STL10")
    os.makedirs(data_dir, exist_ok=True)

    augmentation = AdvancedAugmentation()
    dual_view = {'weak': augmentation.weak, 'strong': augmentation.strong}

    # Pinned host memory only helps when batches are copied to a GPU
    pin_memory = torch.cuda.is_available()

    def build_loader(dataset, shuffle, drop_last=False):
        # All four loaders share batch size, worker count and pinning policy
        return DataLoader(
            dataset,
            batch_size=config.batch_size,
            shuffle=shuffle,
            num_workers=config.num_workers,
            pin_memory=pin_memory,
            drop_last=drop_last
        )

    try:
        # Training dataset with both weak and strong augmentations
        train_dataset = STL10Dataset(
            root=data_dir,
            split='train',
            transform=dual_view,
            download=True
        )

        # Same images as train_dataset, but with only the weak augmentation.
        # The validation indices are applied to THIS dataset below.
        val_dataset = STL10Dataset(
            root=data_dir,
            split='train',
            transform=augmentation.weak,
            download=True
        )

        # Unlabeled dataset with both augmentations
        unlabeled_dataset = STL10Dataset(
            root=data_dir,
            split='unlabeled',
            transform=dual_view,
            download=True
        )

        # Test dataset with test transformation
        test_dataset = STL10Dataset(
            root=data_dir,
            split='test',
            transform=augmentation.test,
            download=True
        )

        # Create train-validation split
        train_size = int(0.8 * len(train_dataset))
        val_size = len(train_dataset) - train_size

        train_subset, val_split = torch.utils.data.random_split(
            train_dataset,
            [train_size, val_size],
            generator=torch.Generator().manual_seed(SEED)
        )

        # Reuse the held-out indices on the weak-only dataset so validation
        # neither computes nor returns the unused strong view.
        # NOTE(review): the weak pipeline is still random; swap in
        # augmentation.test here if deterministic validation is wanted.
        val_subset = torch.utils.data.Subset(val_dataset, val_split.indices)

        return {
            'train': build_loader(train_subset, shuffle=True, drop_last=True),
            'val': build_loader(val_subset, shuffle=False),
            'unlabeled': build_loader(unlabeled_dataset, shuffle=True, drop_last=True),
            'test': build_loader(test_dataset, shuffle=False)
        }

    except Exception as e:
        logging.error(f"Error in data loading: {str(e)}")
        raise



def verify_data_pipeline(dataloaders):
    """
    Smoke-test the loaders by pulling one batch from 'train' and 'unlabeled'.

    Logs the shapes of the weak and strong views (and of the training labels).

    Args:
        dataloaders (dict): Must provide the 'train' and 'unlabeled' loaders.

    Returns:
        bool: True when both batches could be read, False if anything raised
        (the error is logged).
    """
    view_titles = (('weak', 'Weak'), ('strong', 'Strong'))

    try:
        # Labeled training batch
        labeled_views, labels = next(iter(dataloaders['train']))
        logging.info(f"Training batch shapes:")
        for key, title in view_titles:
            logging.info(f"{title} augmentation: {labeled_views[key].shape}")
        logging.info(f"Labels: {labels.shape}")

        # Unlabeled batch (its labels are ignored)
        unlabeled_views, _ = next(iter(dataloaders['unlabeled']))
        logging.info(f"\nUnlabeled batch shapes:")
        for key, title in view_titles:
            logging.info(f"{title} augmentation: {unlabeled_views[key].shape}")
    except Exception as e:
        logging.error(f"Error in data pipeline verification: {str(e)}")
        return False
    else:
        return True

# Runs when this code is the main module: build the loaders, then log
# whether one batch could be read from the 'train' and 'unlabeled' loaders.
if __name__ == "__main__":
    # Test data pipeline
    dataloaders = get_stl10_dataloaders(config)
    if verify_data_pipeline(dataloaders):
        logging.info("Data pipeline verification successful!")
    else:
        logging.error("Data pipeline verification failed!")


In [3]:
# Part 3: Model Architecture and Loss Functions

class ResNet18(nn.Module):
    """
    torchvision ResNet-18 adapted for EPAAS.

    The stem is replaced by a 3x3 stride-1 convolution, the max-pool is
    removed, and the classifier becomes a two-layer MLP with dropout and
    batch normalization.

    Args:
        num_classes (int): Size of the output layer.
        pretrained (bool): Forwarded to `torchvision.models.resnet18`.
        dropout_rate (float): Dropout probability used twice in the head.
    """
    def __init__(self, num_classes=10, pretrained=True, dropout_rate=0.3):
        super().__init__()
        backbone = torchvision.models.resnet18(pretrained=pretrained)

        # Replacement stem: small kernel, no striding, no pooling
        backbone.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        backbone.maxpool = nn.Identity()

        # Replacement classifier head
        feature_dim = backbone.fc.in_features
        head_layers = [
            nn.Dropout(dropout_rate),
            nn.Linear(feature_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(512, num_classes),
        ]
        backbone.fc = nn.Sequential(*head_layers)

        self.model = backbone

        # Only the freshly created head is re-initialized
        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal init for the head's Linear layers; unit scale / zero shift for its BatchNorm."""
        for layer in self.model.fc.modules():
            if isinstance(layer, nn.Linear):
                nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
                if layer.bias is not None:
                    nn.init.constant_(layer.bias, 0)
                continue
            if isinstance(layer, nn.BatchNorm1d):
                nn.init.constant_(layer.weight, 1)
                nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        """Run `x` through the adapted backbone and return the class logits."""
        return self.model(x)

class EPAASLoss(nn.Module):
    """
    Unsupervised EPAAS objective.

    Combines three terms computed from the logits of a weakly and a strongly
    augmented view of the same batch:
      * pseudo-label cross-entropy, masked by an adaptive confidence
        threshold and weighted by per-sample entropy,
      * entropy minimization on the strong view (scaled by `alpha`),
      * entropy-weighted consistency between both views (scaled by `beta`).
    """
    def __init__(self, num_classes, initial_threshold=0.95, temperature=1.0,
                 alpha=0.5, beta=1.0):
        super().__init__()
        self.num_classes = num_classes
        self.threshold = initial_threshold
        self.temperature = temperature
        self.alpha = alpha  # entropy-minimization weight
        self.beta = beta    # consistency-regularization weight
        self.cross_entropy = nn.CrossEntropyLoss(reduction='none')

    def entropy(self, probs):
        """Per-row entropy of `probs`, with 1e-6 added inside the log for stability."""
        return -(probs * torch.log(probs + 1e-6)).sum(dim=1)

    def consistency_loss(self, logits_w, logits_s):
        """Per-sample mean squared difference between the two views' softmax outputs."""
        p_strong = F.softmax(logits_s, dim=1)
        p_weak = F.softmax(logits_w, dim=1)
        return F.mse_loss(p_strong, p_weak, reduction='none').mean(dim=1)

    def forward(self, logits_w, logits_s, current_epoch=None, total_epochs=None):
        """
        Compute the EPAAS loss terms.

        Args:
            logits_w: Logits of the weakly augmented images.
            logits_s: Logits of the strongly augmented images.
            current_epoch: Current epoch; with `total_epochs`, it lowers the
                base threshold linearly by up to 20%.
            total_epochs: Total number of epochs.

        Returns:
            dict: 'total_loss', 'pseudo_loss', 'entropy_loss',
            'consistency_loss', 'mask_mean' and 'threshold'.
        """
        # Entropy of the uniform distribution, used to normalize entropies
        max_entropy = np.log(self.num_classes)

        # Targets, mask and weights are constants w.r.t. the gradient
        with torch.no_grad():
            probs_w = F.softmax(logits_w / self.temperature, dim=1)
            entropy_w = self.entropy(probs_w)
            entropy_mean = entropy_w.mean()

            if current_epoch is None or total_epochs is None:
                base_threshold = self.threshold
            else:
                # Gradually decrease threshold with training progress
                progress = current_epoch / total_epochs
                base_threshold = self.threshold * (1 - 0.2 * progress)

            adaptive_threshold = base_threshold * (1 - entropy_mean / max_entropy)

            # Hard pseudo-labels and the mask of sufficiently confident samples
            max_probs, pseudo_labels = probs_w.max(dim=1)
            mask = max_probs >= adaptive_threshold

            # Auto-sampling weights: low entropy -> high weight, rescaled to mean 1
            weights = 1 - entropy_w / max_entropy
            weights = weights / weights.mean()

        # Pseudo-label cross-entropy on the strong view
        per_sample_ce = self.cross_entropy(logits_s, pseudo_labels)
        loss_s = (per_sample_ce * mask.float() * weights).mean()

        # Entropy minimization on the strong view
        probs_s = F.softmax(logits_s / self.temperature, dim=1)
        loss_ent = self.entropy(probs_s).mean()

        # Entropy-weighted consistency between views
        loss_cons = (self.consistency_loss(logits_w, logits_s) * weights).mean()

        total_loss = loss_s + self.alpha * loss_ent + self.beta * loss_cons

        return {
            'total_loss': total_loss,
            'pseudo_loss': loss_s,
            'entropy_loss': loss_ent,
            'consistency_loss': loss_cons,
            'mask_mean': mask.float().mean(),
            'threshold': adaptive_threshold
        }

class SupervisedLoss(nn.Module):
    """
    Cross-entropy criterion applied to labeled batches.

    `num_classes` is accepted by the constructor but not used by this class.
    """
    def __init__(self, num_classes=10):
        super().__init__()
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(self, logits, targets):
        """Return the batch-mean cross-entropy of `logits` against `targets`."""
        criterion = self.cross_entropy
        return criterion(logits, targets)

class LossTracker:
    """
    Accumulates per-step values of a fixed set of loss components and
    reports their means.
    """
    def __init__(self):
        self.reset()

    def reset(self):
        """Start over with an empty history for every tracked component."""
        tracked = (
            'total_loss',
            'pseudo_loss',
            'entropy_loss',
            'consistency_loss',
            'supervised_loss',
            'mask_mean',
            'threshold',
        )
        self.losses = {name: [] for name in tracked}

    def update(self, loss_dict):
        """Append the values of `loss_dict`; keys that are not tracked are ignored."""
        for name, value in loss_dict.items():
            if name not in self.losses:
                continue
            # Tensors are stored as plain Python numbers
            scalar = value.item() if torch.is_tensor(value) else value
            self.losses[name].append(scalar)

    def get_means(self):
        """Return `{component: mean}` for every component with at least one value."""
        means = {}
        for name, history in self.losses.items():
            if history:
                means[name] = np.mean(history)
        return means

    def log_means(self, epoch):
        """Log the current means, one line per component."""
        logging.info(f"\nEpoch {epoch} Loss Summary:")
        for name, mean_value in self.get_means().items():
            logging.info(f"{name}: {mean_value:.4f}")

def initialize_model(config):
    """
    Build the network and both loss modules and move them to `config.device`.

    Args:
        config: Provides `num_classes`, `device`, `initial_threshold`,
            `temperature`, `alpha` and `beta`.

    Returns:
        tuple: `(model, epaas_loss, supervised_loss)`.
    """
    target_device = config.device
    n_classes = config.num_classes

    network = ResNet18(num_classes=n_classes, pretrained=True, dropout_rate=0.3)
    network = network.to(target_device)

    unsupervised_criterion = EPAASLoss(
        num_classes=n_classes,
        initial_threshold=config.initial_threshold,
        temperature=config.temperature,
        alpha=config.alpha,
        beta=config.beta
    ).to(target_device)

    supervised_criterion = SupervisedLoss(num_classes=n_classes).to(target_device)

    return network, unsupervised_criterion, supervised_criterion

# Runs when this code is the main module: build the model and losses, then
# log the architecture and parameter counts.
if __name__ == "__main__":
    # Test model and loss functions
    model, epaas_loss, supervised_loss = initialize_model(config)

    # Print model summary
    logging.info("\nModel Architecture:")
    logging.info(model)

    # Calculate model parameters: all of them, and those with requires_grad set
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    logging.info(f"\nTotal parameters: {total_params:,}")
    logging.info(f"Trainable parameters: {trainable_params:,}")


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 172MB/s]


In [4]:
# Part 4: Training Components

class EarlyStopping:
    """
    Patience-based early stopping that keeps a copy of the best weights.

    Args:
        patience (int): Consecutive non-improving calls tolerated before stopping.
        min_delta (float): Margin used when comparing against the best value.
        mode (str): 'min' if lower is better, 'max' if higher is better.
        verbose (bool): Log the counter on every non-improving call.

    The monitored value is stored in `best_loss` for both modes; the weights
    from the best call are kept in `best_model`.
    """
    def __init__(self, patience=7, min_delta=1e-4, mode='min', verbose=True):
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.verbose = verbose
        self.counter = 0
        self.best_loss = None
        self.early_stop = False
        self.best_model = None

    def _remember(self, value, model):
        # Record the new best value together with a private copy of the weights
        self.best_loss = value
        self.best_model = deepcopy(model.state_dict())

    def __call__(self, val_loss, model):
        """Register one validation result; return True once patience is exhausted."""
        # First observation always becomes the reference point
        if self.best_loss is None:
            self._remember(val_loss, model)
            return self.early_stop

        if self.mode == 'min':
            stalled = val_loss > self.best_loss - self.min_delta
        elif self.mode == 'max':
            stalled = val_loss < self.best_loss + self.min_delta
        else:
            stalled = False

        if not stalled:
            self._remember(val_loss, model)
            self.counter = 0
            return self.early_stop

        self.counter += 1
        if self.verbose:
            logging.info(f'EarlyStopping counter: {self.counter} out of {self.patience}')
        if self.counter >= self.patience:
            self.early_stop = True
        return self.early_stop

class ModelCheckpoint:
    """
    Saves the model whenever the monitored metric improves.

    Args:
        save_dir (str | Path): Folder for the checkpoint file (created if missing).
        metric_name (str): Name used in the file name, checkpoint key and log line.
        mode (str): 'min' if lower is better, 'max' if higher is better.
    """
    def __init__(self, save_dir, metric_name='val_loss', mode='min'):
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.metric_name = metric_name
        self.mode = mode
        # Worst possible starting value, so the first call always improves
        self.best_metric = float('inf') if mode == 'min' else float('-inf')

    def __call__(self, model, current_metric, epoch, metrics_dict=None):
        """Write `best_model_<metric_name>.pth` if `current_metric` beats the best so far."""
        if self.mode == 'min':
            improved = current_metric < self.best_metric
        elif self.mode == 'max':
            improved = current_metric > self.best_metric
        else:
            improved = False

        if not improved:
            return

        self.best_metric = current_metric
        target_file = self.save_dir / f'best_model_{self.metric_name}.pth'
        payload = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            f'best_{self.metric_name}': self.best_metric,
            'metrics': metrics_dict
        }
        torch.save(payload, target_file)
        logging.info(f'Saved new best model with {self.metric_name}: {self.best_metric:.4f}')

class MetricsTracker:
    """
    Collects predictions, targets and losses across batches and turns them
    into classification metrics.
    """
    def __init__(self):
        self.reset()

    def reset(self):
        """Drop everything collected so far."""
        self.predictions, self.targets, self.losses = [], [], []

    def update(self, preds, targets, loss=None):
        """Append one batch of predicted labels, true labels and (optionally) its loss."""
        for store, batch in ((self.predictions, preds), (self.targets, targets)):
            store.extend(batch.cpu().numpy())
        if loss is not None:
            self.losses.append(loss.item())

    def compute_metrics(self):
        """
        Compute metrics over everything collected since the last reset.

        Returns:
            dict: 'accuracy', macro-averaged 'precision', 'recall' and 'f1',
            plus 'loss' (mean of the recorded losses) when any were recorded.
        """
        y_pred = np.array(self.predictions)
        y_true = np.array(self.targets)

        metrics = {'accuracy': accuracy_score(y_true, y_pred)}
        macro_scorers = (
            ('precision', precision_score),
            ('recall', recall_score),
            ('f1', f1_score),
        )
        for name, scorer in macro_scorers:
            metrics[name] = scorer(y_true, y_pred, average='macro')

        if self.losses:
            metrics['loss'] = np.mean(self.losses)

        return metrics

class Trainer:
    """
    Main trainer class for EPAAS.

    Owns the optimizer, LR scheduler and metric/loss trackers, and provides
    one-epoch training, validation, and save/restore of the training state.

    Args:
        model: Network wrapper exposing the backbone as `model.model`
            (with an `fc` head), as required by `_create_optimizer`.
        epaas_loss: Unsupervised criterion called as
            `epaas_loss(logits_w, logits_s, current_epoch=..., total_epochs=...)`
            and returning a dict containing 'total_loss' and 'mask_mean'.
        supervised_loss: Criterion called as `supervised_loss(logits, targets)`.
        config: Provides `device`, `base_lr`, `momentum`, `weight_decay`,
            `num_epochs`, `gradient_clip_val` and `early_stopping_patience`.
    """
    def __init__(self, model, epaas_loss, supervised_loss, config):
        self.model = model
        self.epaas_loss = epaas_loss
        self.supervised_loss = supervised_loss
        self.config = config
        self.device = config.device

        # Initialize optimizers
        self.optimizer = self._create_optimizer()
        self.scheduler = self._create_scheduler()

        # Initialize trackers
        self.loss_tracker = LossTracker()
        self.metrics_tracker = MetricsTracker()

        # Initialize early stopping and checkpointing
        self.early_stopping = EarlyStopping(
            patience=config.early_stopping_patience,
            mode='max'
        )
        self.checkpoint = ModelCheckpoint(
            save_dir='checkpoints',
            metric_name='val_accuracy',
            mode='max'
        )

    def _create_optimizer(self):
        """SGD with two parameter groups: the `fc` head at `base_lr`, everything else at `base_lr / 10`."""
        return optim.SGD([
            {'params': self.model.model.fc.parameters(), 'lr': self.config.base_lr},
            {'params': [p for n, p in self.model.model.named_parameters()
                       if not n.startswith('fc')],
             'lr': self.config.base_lr/10}
        ], momentum=self.config.momentum, weight_decay=self.config.weight_decay)

    def _create_scheduler(self):
        """Cosine annealing with warm restarts (first cycle 10 steps, doubling each restart)."""
        return CosineAnnealingWarmRestarts(
            self.optimizer,
            T_0=10,
            T_mult=2,
            eta_min=self.config.base_lr * 0.01
        )

    @torch.no_grad()
    def validate(self, val_loader):
        """
        Validate the model

        Args:
            val_loader (DataLoader): Validation data loader yielding
                `(inputs, targets)`; `inputs` may be a tensor or a dict of
                views, in which case the 'weak' view is used.

        Returns:
            dict: Validation metrics
        """
        self.model.eval()
        self.metrics_tracker.reset()

        for batch_idx, (inputs, targets) in enumerate(val_loader):
            # Handle inputs that might be dictionary or tensor
            if isinstance(inputs, dict):
                # Use 'weak' augmentation for validation
                inputs = inputs['weak']

            # Move to device
            inputs = inputs.to(self.device)
            targets = targets.to(self.device)

            # Forward pass
            outputs = self.model(inputs)
            loss = self.supervised_loss(outputs, targets)

            # Calculate predictions
            _, predicted = outputs.max(1)

            # Update metrics
            self.metrics_tracker.update(predicted, targets, loss)

        # Compute and return metrics
        metrics = self.metrics_tracker.compute_metrics()
        return metrics

    def train_epoch(self, train_loader, unlabeled_loader, epoch):
        """
        Train for one epoch.

        One optimization step is taken per labeled batch; each step also
        consumes one unlabeled batch (the unlabeled loader is restarted when
        it runs out). The scheduler is stepped once at the end of the epoch.

        Args:
            train_loader: Yields `({'weak': ..., ...}, targets)` batches.
            unlabeled_loader: Yields `({'weak': ..., 'strong': ...}, _)` batches.
            epoch (int): Current epoch, forwarded to the EPAAS loss.

        Returns:
            dict: Classification metrics on the labeled batches merged with
            the mean of every tracked loss component.
        """
        self.model.train()
        self.loss_tracker.reset()
        self.metrics_tracker.reset()

        unlabeled_iter = iter(unlabeled_loader)
        progress_bar = tqdm(train_loader, desc=f'Epoch {epoch}')

        for batch_idx, (inputs, targets) in enumerate(progress_bar):
            # Get unlabeled batch
            try:
                unlabeled_batch, _ = next(unlabeled_iter)
            except StopIteration:
                unlabeled_iter = iter(unlabeled_loader)
                unlabeled_batch, _ = next(unlabeled_iter)

            # Move data to device
            inputs, targets = inputs['weak'].to(self.device), targets.to(self.device)
            u_w, u_s = unlabeled_batch['weak'].to(self.device), \
                       unlabeled_batch['strong'].to(self.device)

            # Forward pass
            labeled_logits = self.model(inputs)
            unlabeled_logits_w = self.model(u_w)
            unlabeled_logits_s = self.model(u_s)

            # Calculate losses
            sup_loss = self.supervised_loss(labeled_logits, targets)
            unsup_losses = self.epaas_loss(
                unlabeled_logits_w,
                unlabeled_logits_s,
                current_epoch=epoch,
                total_epochs=self.config.num_epochs
            )

            # Combined loss
            total_loss = sup_loss + unsup_losses['total_loss']

            # Optimization
            self.optimizer.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(),
                self.config.gradient_clip_val
            )
            self.optimizer.step()

            # Update metrics
            _, predicted = labeled_logits.max(1)
            self.metrics_tracker.update(predicted, targets, total_loss)

            # Update loss tracker. The unsupervised dict is unpacked FIRST so
            # that its own 'total_loss' (unsupervised only) does not overwrite
            # the combined supervised + unsupervised total.
            self.loss_tracker.update({
                **unsup_losses,
                'total_loss': total_loss,
                'supervised_loss': sup_loss
            })

            # Update progress bar with plain numbers (not tensor reprs)
            mask_mean = unsup_losses['mask_mean']
            progress_bar.set_postfix({
                'loss': total_loss.item(),
                'sup_loss': sup_loss.item(),
                'mask_mean': mask_mean.item() if torch.is_tensor(mask_mean) else mask_mean
            })

        # Step scheduler
        self.scheduler.step()

        # Compute epoch metrics
        train_metrics = self.metrics_tracker.compute_metrics()
        train_metrics.update(self.loss_tracker.get_means())

        return train_metrics

    @staticmethod
    def _plain_metrics(metrics):
        """Convert numpy scalars/arrays in a metrics dict to builtin Python values."""
        if not isinstance(metrics, dict):
            return metrics
        plain = {}
        for key, value in metrics.items():
            if isinstance(value, np.generic):
                value = value.item()
            elif isinstance(value, np.ndarray):
                value = value.tolist()
            plain[key] = value
        return plain

    def save_training_state(self, epoch, metrics, save_dir):
        """
        Save training state and metrics.

        Writes `training_state_epoch_<epoch>.pth` into `save_dir` with the
        model, optimizer and scheduler state dicts.

        Args:
            epoch (int): Epoch number stored in (and used to name) the file.
            metrics (dict): Metrics to store; numpy values are converted to
                builtin types so the file loads under `torch.load`'s
                restricted (`weights_only`) unpickler.
            save_dir (str | Path): Target folder; created if missing.
        """
        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)

        state = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'metrics': self._plain_metrics(metrics)
        }
        torch.save(state, save_dir / f'training_state_epoch_{epoch}.pth')

    def load_training_state(self, state_path):
        """
        Load training state.

        Restores model, optimizer and scheduler in place.

        Args:
            state_path (str | Path): File written by `save_training_state`.

        Returns:
            tuple: `(epoch, metrics)` as stored in the file.
        """
        # map_location lets a checkpoint written on GPU be restored on CPU (and vice versa)
        state = torch.load(state_path, map_location=self.device)
        self.model.load_state_dict(state['model_state_dict'])
        self.optimizer.load_state_dict(state['optimizer_state_dict'])
        self.scheduler.load_state_dict(state['scheduler_state_dict'])
        return state['epoch'], state['metrics']


In [5]:
# Part 5: Main Training Loop and Results Visualization

import logging
import json
import time
import os
from pathlib import Path
from typing import Dict, List, Union
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
import numpy as np

class ExperimentManager:
    """
    Manages experiment execution, logging, and visualization with enhanced error handling

    Attributes:
        config (object): Experiment configuration object
        experiment_name (str): Name of the experiment
        results_dir (Path): Path to results directory
        training_history (dict): Storage for training metrics
    """

    def __init__(self, config, experiment_name: str = "EPAAS_STL10"):
        """
        Initialize ExperimentManager with proper error handling

        Args:
            config: Experiment configuration object
            experiment_name: Name of the experiment (default: "EPAAS_STL10")

        Raises:
            RuntimeError: If the results directory cannot be created.
        """
        self.config = config
        self.experiment_name = self._sanitize_name(experiment_name)

        # Initialize paths safely
        self.base_path = Path("/content/drive/MyDrive/Epaas/results")
        self.results_dir = self._setup_results_directory()

        # Initialize metrics storage
        self.training_history = {
            'train_metrics': [],
            'val_metrics': [],
            'best_epoch': 0,
            'best_val_accuracy': 0.0
        }

        # Setup logging
        self.logger = self._setup_logging()

    def _sanitize_name(self, name: str) -> str:
        """Sanitize experiment name for file safety: keep only alphanumerics, '-' and '_'."""
        return "".join(c for c in name if c.isalnum() or c in ('-', '_'))

    def _setup_results_directory(self) -> Path:
        """Create results directory with error handling"""
        try:
            results_dir = self.base_path / self.experiment_name
            results_dir.mkdir(parents=True, exist_ok=True)
            return results_dir
        except Exception as e:
            raise RuntimeError(f"Failed to create results directory: {str(e)}")

    def _setup_logging(self) -> logging.Logger:
        """
        Configure logging with multiple handlers.

        Returns the logger named after the experiment. Handlers (a file
        handler writing 'experiment.log' in the results directory and a
        console handler) are attached only if the logger has none yet.
        """
        logger = logging.getLogger(self.experiment_name)
        logger.setLevel(logging.INFO)

        # Prevent duplicate handlers
        if logger.handlers:
            return logger

        formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

        try:
            file_handler = logging.FileHandler(self.results_dir / 'experiment.log')
            file_handler.setFormatter(formatter)
            logger.addHandler(file_handler)
        except IOError as e:
            logger.error(f"Failed to create log file: {str(e)}")

        console_handler = logging.StreamHandler()
        console_handler.setFormatter(formatter)
        logger.addHandler(console_handler)

        logger.info(f"Starting experiment: {self.experiment_name}")
        logger.info(f"Config: {self._safe_config_representation()}")

        return logger

    def _safe_config_representation(self) -> Dict:
        """
        Convert config to safe serializable format.

        Returns a NEW dict; the config object itself is never modified.
        Numpy scalars become builtin numbers and any value that is not a
        JSON-native type (for example a `torch.device`) is replaced by its
        `str()` form.
        """
        if hasattr(self.config, '__dict__'):
            raw_items = vars(self.config).items()
        else:
            raw_items = dict(self.config).items()

        json_native = (str, int, float, bool, type(None), list, tuple, dict)

        safe = {}
        for key, value in raw_items:
            if isinstance(value, np.generic):
                value = value.item()
            elif not isinstance(value, json_native):
                value = str(value)
            safe[key] = value
        return safe

    def plot_training_curves(self) -> None:
        """
        Plot training and validation metrics with safety checks.

        Draws a 2x2 grid (accuracy, loss, mask mean, threshold) and saves it
        as 'training_curves.png' in the results directory. Metrics missing
        from an epoch are plotted as NaN. The figure is always closed.
        """
        if not self.training_history['train_metrics']:
            self.logger.warning("No training metrics to plot")
            return

        metrics_to_plot = [
            ('accuracy', 'Accuracy'),
            ('loss', 'Loss'),
            ('mask_mean', 'Pseudo-Label Mask Mean'),
            ('threshold', 'Confidence Threshold')
        ]

        fig = plt.figure(figsize=(20, 15))
        try:
            for idx, (metric, title) in enumerate(metrics_to_plot, 1):
                plt.subplot(2, 2, idx)

                train_values = [m.get(metric, np.nan) for m in self.training_history['train_metrics']]
                plt.plot(train_values, label=f'Train {title}')

                val_values = [m.get(metric, np.nan) for m in self.training_history['val_metrics']]
                plt.plot(val_values, label=f'Val {title}')

                plt.title(title)
                plt.xlabel('Epoch')
                plt.ylabel(title)
                plt.legend()
                plt.grid(True)

            plt.tight_layout()
            try:
                plt.savefig(self.results_dir / 'training_curves.png')
            except Exception as e:
                self.logger.error(f"Failed to save training curves: {str(e)}")
        finally:
            # Release the figure even when plotting or saving failed
            plt.close(fig)

    def plot_confusion_matrix(self, true_labels: List, predictions: List, classes: range = range(10)) -> None:
        """
        Generate confusion matrix with input validation.

        Saves a heatmap as 'confusion_matrix.png' in the results directory.
        Errors are logged, not raised; the figure is always closed.

        Args:
            true_labels: Ground-truth labels.
            predictions: Predicted labels, same length as `true_labels`.
            classes: Tick labels for both axes.
        """
        if len(true_labels) != len(predictions):
            self.logger.error("Mismatch in true labels and predictions length")
            return

        fig = None
        try:
            cm = confusion_matrix(true_labels, predictions)
            fig = plt.figure(figsize=(12, 10))
            sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                        xticklabels=classes, yticklabels=classes)
            plt.title('Confusion Matrix')
            plt.xlabel('Predicted')
            plt.ylabel('True')
            plt.savefig(self.results_dir / 'confusion_matrix.png')
        except Exception as e:
            self.logger.error(f"Failed to plot confusion matrix: {str(e)}")
        finally:
            if fig is not None:
                plt.close(fig)

    def save_results(self, final_metrics: Dict) -> None:
        """
        Save experiment results with serialization safety.

        Writes config, training history, final metrics and a timestamp to
        'results_epaas_stl-10.json' in the results directory. Failures are
        logged, not raised.
        """
        results = {
            'config': self._safe_config_representation(),
            'training_history': self.training_history,
            'final_metrics': self._sanitize_metrics(final_metrics),
            'timestamp': time.strftime('%Y-%m-%d %H:%M:%S')
        }

        try:
            with open(self.results_dir / 'results_epaas_stl-10.json', 'w') as f:
                json.dump(results, f, indent=4, cls=NumpyEncoder)
        except Exception as e:
            self.logger.error(f"Failed to save results: {str(e)}")

    def _sanitize_metrics(self, metrics: Dict) -> Dict:
        """
        Convert metrics to JSON-serializable format.

        Numpy scalars become builtin numbers and numpy arrays of any shape
        become (nested) lists; all other values are passed through.
        """
        def to_builtin(value):
            if isinstance(value, np.ndarray):
                return value.tolist()
            if isinstance(value, np.generic):
                return value.item()
            return value

        return {k: to_builtin(v) for k, v in metrics.items()}

class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy scalars and arrays to built-in types."""
    def default(self, obj):
        """Convert ``obj`` to a JSON-serializable built-in value.

        NumPy integers, floats, booleans and arrays are converted; every other
        type is delegated to ``json.JSONEncoder.default``, which raises
        ``TypeError``.
        """
        if isinstance(obj, np.integer):
            return int(obj)
        elif isinstance(obj, np.floating):
            return float(obj)
        elif isinstance(obj, np.bool_):
            # np.bool_ is neither np.integer nor a Python bool, so the stock
            # encoder rejects it unless it is converted here.
            return bool(obj)
        elif isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)

def train_fold(trainer, train_loader, val_loader, unlabeled_loader, config, fold_num, test_loader=None):
    """
    Train a single fold in k-fold cross validation

    Args:
        trainer (Trainer): Trainer instance
        train_loader (DataLoader): Training data loader
        val_loader (DataLoader): Validation data loader
        unlabeled_loader (DataLoader): Unlabeled data loader
        config (Config): Configuration object
        fold_num (int): Current fold number (zero-based; logged as fold_num + 1)
        test_loader (DataLoader, optional): Loader used for the final
            evaluation stored under 'test_metrics'. When omitted, the fold's
            validation loader is reused, so 'test_metrics' are validation
            metrics of the restored best model rather than held-out results.

    Returns:
        dict: Dictionary containing fold training results
            ('train_metrics', 'val_metrics', 'best_val_accuracy',
            'best_epoch', 'test_metrics')

    Raises:
        Exception: any error raised during training or evaluation is logged
            and re-raised unchanged.
    """
    logging.info(f"\nTraining Fold {fold_num + 1}")

    # Initialize fold metrics
    fold_metrics = {
        'train_metrics': [],
        'val_metrics': [],
        'best_val_accuracy': 0.0,
        'best_epoch': 0,
        'test_metrics': None
    }

    # Initialize early stopping and checkpoint for this fold
    early_stopping = EarlyStopping(
        patience=config.early_stopping_patience,
        mode='max',
        verbose=True
    )

    checkpoint_dir = f'checkpoints/fold_{fold_num}'
    checkpoint = ModelCheckpoint(
        save_dir=checkpoint_dir,
        metric_name='val_accuracy',
        mode='max'
    )
    # Tracks whether THIS run wrote a checkpoint, so a leftover file from an
    # earlier run is never mistaken for this fold's best model.
    saved_checkpoint = False

    try:
        # Training loop
        for epoch in range(config.num_epochs):
            # Train epoch
            train_metrics = trainer.train_epoch(
                train_loader,
                unlabeled_loader,
                epoch
            )

            # Validate
            val_metrics = trainer.validate(val_loader)

            # Update fold metrics
            fold_metrics['train_metrics'].append(train_metrics)
            fold_metrics['val_metrics'].append(val_metrics)

            # Log metrics
            logging.info(
                f"\nFold {fold_num + 1}, Epoch {epoch + 1}/{config.num_epochs}"
            )
            logging.info(f"Train Metrics: {train_metrics}")
            logging.info(f"Val Metrics: {val_metrics}")

            # Check for best model
            if val_metrics['accuracy'] > fold_metrics['best_val_accuracy']:
                fold_metrics['best_val_accuracy'] = val_metrics['accuracy']
                fold_metrics['best_epoch'] = epoch

                # Save checkpoint
                checkpoint(
                    trainer.model,
                    val_metrics['accuracy'],
                    epoch,
                    {
                        'train_metrics': train_metrics,
                        'val_metrics': val_metrics
                    }
                )
                saved_checkpoint = True

            # Early stopping check
            if early_stopping(val_metrics['accuracy'], trainer.model):
                logging.info(
                    f"Early stopping triggered in fold {fold_num + 1} "
                    f"at epoch {epoch + 1}"
                )
                break

        # Load best model for this fold (only one written during this run)
        best_model_path = Path(checkpoint_dir) / 'best_model_val_accuracy.pth'
        if saved_checkpoint and best_model_path.exists():
            # NOTE(review): weights_only=False unpickles arbitrary objects;
            # acceptable only because the file was produced by this process.
            checkpoint_data = torch.load(best_model_path, map_location=config.device, weights_only=False)
            trainer.model.load_state_dict(checkpoint_data['model_state_dict'])
            logging.info(
                f"Loaded best model from epoch {fold_metrics['best_epoch']} "
                f"with validation accuracy {fold_metrics['best_val_accuracy']:.4f}"
            )

        # Final evaluation: held-out loader when supplied, otherwise the
        # fold's validation loader (the original behaviour).
        eval_loader = val_loader if test_loader is None else test_loader
        test_metrics = trainer.validate(eval_loader)
        fold_metrics['test_metrics'] = test_metrics

        # Calculate and log fold summary
        fold_summary = calculate_fold_summary(fold_metrics)
        log_fold_summary(fold_num, fold_summary)

        # Plot fold-specific learning curves
        plot_fold_learning_curves(
            fold_metrics,
            fold_num,
            save_dir=f'results/fold_{fold_num}'
        )

        return fold_metrics

    except Exception as e:
        logging.error(f"Error in fold {fold_num + 1} training: {str(e)}")
        raise

def calculate_fold_summary(fold_metrics):
    """
    Calculate summary statistics for a fold

    Args:
        fold_metrics (dict): Fold record with 'best_val_accuracy',
            'best_epoch', 'test_metrics' and the per-epoch lists
            'train_metrics' / 'val_metrics' (dicts with 'loss' and 'accuracy').

    Returns:
        dict: Best accuracy/epoch, final test metrics, number of recorded
            epochs and the last-epoch train/val loss and accuracy. When no
            epoch was recorded the last-epoch values are NaN.
    """
    def _final(history, key):
        # An empty history has no last epoch; NaN (rather than None) keeps
        # downstream ':.4f' formatting working.
        return history[-1][key] if history else float('nan')

    train_history = fold_metrics['train_metrics']
    val_history = fold_metrics['val_metrics']

    return {
        'best_val_accuracy': fold_metrics['best_val_accuracy'],
        'best_epoch': fold_metrics['best_epoch'],
        'final_test_metrics': fold_metrics['test_metrics'],
        'training_time_epochs': len(train_history),
        'convergence_metrics': {
            'train_loss_final': _final(train_history, 'loss'),
            'val_loss_final': _final(val_history, 'loss'),
            'train_acc_final': _final(train_history, 'accuracy'),
            'val_acc_final': _final(val_history, 'accuracy')
        }
    }

def log_fold_summary(fold_num, fold_summary):
    """
    Log summary of fold training

    Args:
        fold_num (int): Zero-based fold index (logged as fold_num + 1).
        fold_summary (dict): Summary with 'best_val_accuracy', 'best_epoch',
            'training_time_epochs', 'final_test_metrics' and
            'convergence_metrics'.

    Metric values that do not support the '.4f' format (lists, arrays,
    strings, None) are logged via str() instead of raising, and a missing
    ('None') test-metrics mapping is treated as empty.
    """
    def _fmt(value):
        try:
            return f"{value:.4f}"
        except (TypeError, ValueError):
            return str(value)

    logging.info(f"\nFold {fold_num + 1} Summary:")
    logging.info(f"Best Validation Accuracy: {fold_summary['best_val_accuracy']:.4f}")
    logging.info(f"Best Epoch: {fold_summary['best_epoch']}")
    logging.info(f"Training Duration: {fold_summary['training_time_epochs']} epochs")
    logging.info("\nFinal Test Metrics:")
    for metric, value in (fold_summary['final_test_metrics'] or {}).items():
        logging.info(f"{metric}: {_fmt(value)}")
    logging.info("\nConvergence Metrics:")
    for metric, value in fold_summary['convergence_metrics'].items():
        logging.info(f"{metric}: {_fmt(value)}")

def plot_fold_learning_curves(fold_metrics, fold_num, save_dir):
    """
    Plot learning curves for a specific fold

    Writes '<metric>_curve_fold_<n>.png' for 'loss' and 'accuracy', plus
    '<metric>_evolution_fold_<n>.png' for 'mask_mean' / 'threshold' when the
    first training epoch reports them. An empty training history produces
    empty loss/accuracy plots and skips the optional ones.

    Args:
        fold_metrics (dict): Metrics for the fold
        fold_num (int): Fold number (zero-based; file names use fold_num + 1)
        save_dir (str): Directory to save plots (created if missing)
    """
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    train_history = fold_metrics['train_metrics']
    val_history = fold_metrics['val_metrics']

    # Metrics to plot
    metrics = ['loss', 'accuracy']

    for metric in metrics:
        fig = plt.figure(figsize=(10, 6))
        try:
            # Get metric values
            train_values = [epoch[metric] for epoch in train_history]
            val_values = [epoch[metric] for epoch in val_history]

            # Plot training and validation curves
            plt.plot(train_values, label=f'Train {metric}')
            plt.plot(val_values, label=f'Validation {metric}')

            # Mark best epoch
            if metric == 'accuracy':
                best_epoch = fold_metrics['best_epoch']
                plt.axvline(x=best_epoch, color='r', linestyle='--',
                           label=f'Best Epoch ({best_epoch})')

            plt.title(f'Fold {fold_num + 1} - {metric.capitalize()} Learning Curve')
            plt.xlabel('Epoch')
            plt.ylabel(metric.capitalize())
            plt.legend()
            plt.grid(True)

            # Save plot
            plt.savefig(save_dir / f'{metric}_curve_fold_{fold_num + 1}.png')
        finally:
            # Release the figure even if plotting/saving failed.
            plt.close(fig)

    # Plot additional metrics if available
    additional_metrics = ['mask_mean', 'threshold']
    for metric in additional_metrics:
        # Guard the [0] lookup: a fold with no recorded epochs has nothing to probe.
        if not train_history or metric not in train_history[0]:
            continue
        fig = plt.figure(figsize=(10, 6))
        try:
            values = [epoch[metric] for epoch in train_history]
            plt.plot(values, label=metric)
            plt.title(f'Fold {fold_num + 1} - {metric.capitalize()} Evolution')
            plt.xlabel('Epoch')
            plt.ylabel(metric.capitalize())
            plt.legend()
            plt.grid(True)
            plt.savefig(save_dir / f'{metric}_evolution_fold_{fold_num + 1}.png')
        finally:
            plt.close(fig)
def analyze_kfold_results(fold_results, experiment):
    """
    Analyze and visualize results from k-fold cross validation

    Collects 'accuracy', 'precision', 'recall', 'f1' and 'loss' from each
    fold's 'test_metrics', then computes statistics, writes plots and a JSON
    summary under '<results_dir>/kfold_analysis', and logs the summary.
    Metrics that no fold reported are dropped (with a warning) before the
    statistics and plots are produced.

    Args:
        fold_results (list): List of metrics from each fold
        experiment (ExperimentManager): Experiment manager instance

    Returns:
        dict: Statistical summary of k-fold results (empty when no tracked
            metric was reported by any fold)
    """
    # Create directory for k-fold results
    kfold_dir = experiment.results_dir / 'kfold_analysis'
    kfold_dir.mkdir(parents=True, exist_ok=True)

    # Initialize containers for metrics
    metrics_summary = {
        'accuracy': [],
        'precision': [],
        'recall': [],
        'f1': [],
        'loss': []
    }

    # Collect metrics from all folds
    for fold_idx, fold_data in enumerate(fold_results):
        logging.info(f"\nFold {fold_idx + 1} Results:")
        for metric, value in fold_data['test_metrics'].items():
            if metric in metrics_summary:
                metrics_summary[metric].append(value)
                logging.info(f"{metric}: {value:.4f}")

    # Empty series would make the min/max statistics raise and would not line
    # up with the fold axis in the variation plot, so drop them up front.
    missing = [metric for metric, values in metrics_summary.items() if not values]
    if missing:
        logging.warning(f"No values collected for metrics: {', '.join(missing)}")
    metrics_summary = {
        metric: values for metric, values in metrics_summary.items() if values
    }
    if not metrics_summary:
        logging.warning("No k-fold metrics available; skipping analysis")
        return {}

    # Calculate statistical summary
    stats_summary = calculate_statistical_summary(metrics_summary)

    # Plot fold comparisons
    plot_fold_comparisons(metrics_summary, kfold_dir)

    # Plot fold variations
    plot_fold_variations(metrics_summary, kfold_dir)

    # Save detailed results
    save_kfold_results(metrics_summary, stats_summary, kfold_dir)

    # Log summary statistics
    log_statistical_summary(stats_summary)

    return stats_summary

def calculate_statistical_summary(metrics_summary):
    """
    Calculate statistical measures for each metric

    Args:
        metrics_summary (dict): Metric name -> list of per-fold values.

    Returns:
        dict: Metric name -> {'mean', 'std', 'min', 'max', 'median', 'ci_95'},
            all as built-in floats (JSON-serializable). 'ci_95' holds the
            2.5th and 97.5th percentiles of the per-fold values, i.e. an
            empirical percentile range rather than a confidence interval of
            the mean. A metric with no values gets NaN for every statistic.
    """
    stats_summary = {}
    nan = float('nan')

    for metric, values in metrics_summary.items():
        values = np.array(values)
        if values.size == 0:
            # np.min/np.max raise on empty input; report NaN instead so the
            # remaining metrics are still summarized.
            stats_summary[metric] = {
                'mean': nan,
                'std': nan,
                'min': nan,
                'max': nan,
                'median': nan,
                'ci_95': [nan, nan]
            }
            continue
        stats_summary[metric] = {
            'mean': float(np.mean(values)),
            'std': float(np.std(values)),
            'min': float(np.min(values)),
            'max': float(np.max(values)),
            'median': float(np.median(values)),
            'ci_95': np.percentile(values, [2.5, 97.5]).tolist()
        }

    return stats_summary

def plot_fold_comparisons(metrics_summary, save_dir):
    """
    Draw one box plot per metric on a shared 2x3 grid and save the figure
    as 'fold_comparisons.png' inside save_dir.

    Args:
        metrics_summary (dict): Metric name -> list of per-fold values.
        save_dir (Path): Directory receiving the image.
    """
    grid_rows, grid_cols = 2, 3
    plt.figure(figsize=(15, 10))

    panel = 0
    for metric_name in metrics_summary:
        panel += 1
        pretty_name = metric_name.capitalize()

        plt.subplot(grid_rows, grid_cols, panel)
        sns.boxplot(data=metrics_summary[metric_name])
        plt.title(f'{pretty_name} Distribution')
        plt.ylabel(pretty_name)

    plt.tight_layout()
    destination = save_dir / 'fold_comparisons.png'
    plt.savefig(destination)
    plt.close()

def plot_fold_variations(metrics_summary, save_dir):
    """
    Overlay every metric as a line across folds and save the figure as
    'fold_variations.png' inside save_dir.

    The fold axis length is taken from the first metric in metrics_summary.

    Args:
        metrics_summary (dict): Metric name -> list of per-fold values.
        save_dir (Path): Directory receiving the image.
    """
    first_series = next(iter(metrics_summary.values()))
    x_positions = range(len(first_series))

    plt.figure(figsize=(15, 10))

    for metric_name in metrics_summary:
        series = metrics_summary[metric_name]
        plt.plot(x_positions, series, 'o-', label=metric_name.capitalize())

    # Axis decoration
    plt.xlabel('Fold')
    plt.ylabel('Metric Value')
    plt.title('Metric Variations Across Folds')
    plt.legend()
    plt.grid(True)

    destination = save_dir / 'fold_variations.png'
    plt.savefig(destination)
    plt.close()

def save_kfold_results(metrics_summary, stats_summary, save_dir):
    """
    Save detailed k-fold results to JSON

    Writes 'kfold_results.json' into save_dir with the per-fold values under
    'metrics_by_fold' and the statistics under 'statistical_summary'.
    NumPy scalars/arrays are serialized through NumpyEncoder, matching how
    the experiment results file is written.

    Args:
        metrics_summary (dict): Metric name -> list of per-fold values.
        stats_summary (dict): Metric name -> statistics dict.
        save_dir (Path): Existing directory receiving the file.

    Raises:
        TypeError: if a value cannot be serialized (raised before the file is
            opened, so no partial file is left behind).
        OSError: if the file cannot be written.
    """
    results = {
        'metrics_by_fold': {
            metric: values for metric, values in metrics_summary.items()
        },
        'statistical_summary': stats_summary
    }

    # Serialize before opening: open(..., 'w') truncates immediately, so a
    # failure inside json.dump would leave a corrupt file on disk.
    payload = json.dumps(results, indent=4, cls=NumpyEncoder)
    with open(save_dir / 'kfold_results.json', 'w') as f:
        f.write(payload)

def log_statistical_summary(stats_summary):
    """
    Emit the k-fold statistics at INFO level, one section per metric
    (mean/std, range, median and the 'ci_95' bounds, four decimals each).

    Args:
        stats_summary (dict): Metric name -> dict with 'mean', 'std', 'min',
            'max', 'median' and 'ci_95' (two-element sequence).
    """
    emit = logging.info
    emit("\nK-Fold Cross Validation Summary:")

    for metric_name in stats_summary:
        entry = stats_summary[metric_name]
        emit(f"\n{metric_name.capitalize()}:")
        emit(f"Mean ± Std: {entry['mean']:.4f} ± {entry['std']:.4f}")
        emit(f"Range: [{entry['min']:.4f}, {entry['max']:.4f}]")
        emit(f"Median: {entry['median']:.4f}")
        bounds = entry['ci_95']
        emit(f"95% CI: [{bounds[0]:.4f}, {bounds[1]:.4f}]")

def generate_kfold_report(fold_results, save_dir):
    """
    Generate a comprehensive report of k-fold results

    Writes 'kfold_report.md' (UTF-8) into save_dir, creating the directory
    when needed. The report contains the overall statistics for every metric
    found in the first fold's 'test_metrics', followed by per-fold test
    metrics, best epoch and best validation accuracy. With no fold results
    only the section headings are written.

    Args:
        fold_results (list): Per-fold dicts with 'test_metrics', 'best_epoch'
            and 'best_val_accuracy'.
        save_dir (str | Path): Directory receiving the report.
    """
    def _fmt(value):
        # Non-scalar metrics cannot take the '.4f' format; fall back to str().
        try:
            return f"{value:.4f}"
        except (TypeError, ValueError):
            return str(value)

    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    report = ["# K-Fold Cross Validation Report\n"]

    # Calculate overall metrics (metric names come from the first fold)
    metric_names = list(fold_results[0]['test_metrics'].keys()) if fold_results else []
    metrics_summary = {
        metric: [fold['test_metrics'][metric] for fold in fold_results]
        for metric in metric_names
    }
    stats_summary = calculate_statistical_summary(metrics_summary)

    # Add overall summary
    report.append("## Overall Performance\n")
    for metric, stats in stats_summary.items():
        report.append(f"### {metric.capitalize()}\n")
        report.append(f"- Mean ± Std: {stats['mean']:.4f} ± {stats['std']:.4f}")
        report.append(f"- Range: [{stats['min']:.4f}, {stats['max']:.4f}]")
        report.append(f"- Median: {stats['median']:.4f}")
        report.append(f"- 95% CI: [{stats['ci_95'][0]:.4f}, {stats['ci_95'][1]:.4f}]\n")

    # Add per-fold details
    report.append("## Per-fold Performance\n")
    for fold_idx, fold_data in enumerate(fold_results):
        report.append(f"### Fold {fold_idx + 1}\n")
        report.append("#### Test Metrics")
        for metric, value in fold_data['test_metrics'].items():
            report.append(f"- {metric}: {_fmt(value)}")
        report.append("\n#### Training Summary")
        report.append(f"- Best Epoch: {fold_data['best_epoch']}")
        report.append(f"- Best Validation Accuracy: {fold_data['best_val_accuracy']:.4f}\n")

    # Save report; explicit UTF-8 because the text contains '±', which the
    # platform default encoding is not guaranteed to support.
    with open(save_dir / 'kfold_report.md', 'w', encoding='utf-8') as f:
        f.write('\n'.join(report))

def main():
    """
    Single-split training entry point.

    Builds the config, experiment manager, data loaders, model and trainer;
    trains for up to config.num_epochs epochs with best-model checkpointing
    and early stopping; reloads the best checkpoint; evaluates on the test
    loader; then writes training curves, a confusion matrix and the results
    file through the experiment manager.

    Returns:
        tuple: (model, test_metrics)

    Raises:
        Exception: any failure is logged and re-raised.
    """
    config = Config()
    experiment = ExperimentManager(config)

    try:
        # Data
        dataloaders = get_stl10_dataloaders(config)
        logging.info("Data loaded successfully")

        # Model, losses and trainer
        model, epaas_loss, supervised_loss = initialize_model(config)
        trainer = Trainer(model, epaas_loss, supervised_loss, config)

        total_epochs = config.num_epochs
        for epoch in range(total_epochs):
            # One pass over labeled + unlabeled data, then validation
            train_metrics = trainer.train_epoch(
                dataloaders['train'], dataloaders['unlabeled'], epoch
            )
            val_metrics = trainer.validate(dataloaders['val'])

            # Record this epoch in the experiment history
            history = experiment.training_history
            history['train_metrics'].append(train_metrics)
            history['val_metrics'].append(val_metrics)

            logging.info(f"\nEpoch {epoch+1}/{total_epochs}")
            logging.info(f"Train Metrics: {train_metrics}")
            logging.info(f"Val Metrics: {val_metrics}")

            # Checkpoint whenever validation accuracy improves
            val_accuracy = val_metrics['accuracy']
            if val_accuracy > history['best_val_accuracy']:
                history['best_val_accuracy'] = val_accuracy
                history['best_epoch'] = epoch
                trainer.checkpoint(model, val_accuracy, epoch, val_metrics)

            # Stop once the early-stopping criterion fires
            if trainer.early_stopping(val_accuracy, model):
                logging.info("Early stopping triggered!")
                break

        # Restore the best weights before the final evaluation
        best_model_path = Path('checkpoints', 'best_model_val_accuracy.pth')
        try:
            best_state = torch.load(best_model_path, map_location=config.device, weights_only=False)
            model.load_state_dict(best_state['model_state_dict'])
        except Exception as e:
            logging.error(f"Error loading best model: {str(e)}")
            raise

        # Final evaluation on test set
        test_metrics = trainer.validate(dataloaders['test'])
        logging.info(f"\nFinal Test Metrics: {test_metrics}")

        # Collect per-sample labels/predictions for the confusion matrix
        true_labels, predictions = [], []
        model.eval()
        with torch.no_grad():
            for inputs, targets in dataloaders['test']:
                outputs = model(inputs.to(config.device))
                preds = outputs.max(1)[1]
                true_labels.extend(targets.numpy())
                predictions.extend(preds.cpu().numpy())

        # Plots and persisted results
        experiment.plot_training_curves()
        experiment.plot_confusion_matrix(true_labels, predictions)
        experiment.save_results(test_metrics)

        return model, test_metrics

    except Exception as e:
        logging.error(f"Error in training: {str(e)}")
        raise

def run_k_fold_training():
    """
    K-fold cross-validation entry point.

    Splits the dataset behind the 'train' loader into config.num_folds
    shuffled folds (seeded with SEED), trains a freshly initialized model on
    each fold via train_fold, and passes the collected fold results to
    analyze_kfold_results.

    Raises:
        Exception: any failure is logged and re-raised.
    """
    config = Config()
    experiment = ExperimentManager(config, "EPAAS_STL10_KFold")

    try:
        # Reuse the dataset object that backs the regular training loader
        dataloaders = get_stl10_dataloaders(config)
        dataset = dataloaders['train'].dataset

        splitter = KFold(n_splits=config.num_folds, shuffle=True, random_state=SEED)
        fold_results = []

        for fold_idx, (train_idx, val_idx) in enumerate(splitter.split(dataset)):
            logging.info(f"\nStarting Fold {fold_idx+1}/{config.num_folds}")

            # Fold-specific samplers over the shared dataset
            train_sampler = SubsetRandomSampler(train_idx)
            val_sampler = SubsetRandomSampler(val_idx)

            loader_options = dict(
                batch_size=config.batch_size,
                num_workers=config.num_workers,
            )
            train_loader = DataLoader(dataset, sampler=train_sampler, **loader_options)
            val_loader = DataLoader(dataset, sampler=val_sampler, **loader_options)

            # Fresh model and trainer so folds do not share weights
            model, epaas_loss, supervised_loss = initialize_model(config)
            trainer = Trainer(model, epaas_loss, supervised_loss, config)

            fold_results.append(
                train_fold(
                    trainer,
                    train_loader,
                    val_loader,
                    dataloaders['unlabeled'],
                    config,
                    fold_idx,
                )
            )

        # Aggregate statistics, plots and JSON summary across folds
        analyze_kfold_results(fold_results, experiment)

    except Exception as e:
        logging.error(f"Error in k-fold training: {str(e)}")
        raise

if __name__ == "__main__":
    # Run the single-split training; the trained model and its test metrics
    # are kept as module-level names for later inspection.
    model, metrics = main()

    # Then run k-fold cross validation. NOTE: this call is unconditional,
    # so it always executes after main() returns; comment it out to skip it.
    run_k_fold_training()


2025-04-15 02:47:00,106 - INFO - Starting experiment: EPAAS_STL10
INFO:EPAAS_STL10:Starting experiment: EPAAS_STL10
2025-04-15 02:47:00,527 - INFO - Config: {'num_classes': 10, 'batch_size': 64, 'num_epochs': 5, 'base_lr': 0.03, 'momentum': 0.9, 'weight_decay': 0.0005, 'num_workers': 2, 'device': device(type='cuda'), 'initial_threshold': 0.95, 'final_threshold': 0.8, 'temperature': 1.0, 'alpha': 0.5, 'beta': 1.0, 'num_folds': 5, 'early_stopping_patience': 10, 'gradient_clip_val': 1.0}
INFO:EPAAS_STL10:Config: {'num_classes': 10, 'batch_size': 64, 'num_epochs': 5, 'base_lr': 0.03, 'momentum': 0.9, 'weight_decay': 0.0005, 'num_workers': 2, 'device': device(type='cuda'), 'initial_threshold': 0.95, 'final_threshold': 0.8, 'temperature': 1.0, 'alpha': 0.5, 'beta': 1.0, 'num_folds': 5, 'early_stopping_patience': 10, 'gradient_clip_val': 1.0}
Epoch 0: 100%|██████████| 62/62 [01:36<00:00,  1.56s/it, loss=16.3, sup_loss=8.66, mask_mean=tensor(0.7344, device='cuda:0')]
Epoch 1: 100%|██████████| 