In [None]:
# Opening cell: prints a fixed invocation line; it has no other side effects.
print("Har Har Mahadev")

Har Har Mahadev


In [None]:
!pip install optuna

Collecting optuna
  Downloading optuna-4.4.0-py3-none-any.whl.metadata (17 kB)
Collecting alembic>=1.5.0 (from optuna)
  Downloading alembic-1.16.4-py3-none-any.whl.metadata (7.3 kB)
Collecting colorlog (from optuna)
  Downloading colorlog-6.9.0-py3-none-any.whl.metadata (10 kB)
Downloading optuna-4.4.0-py3-none-any.whl (395 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m395.9/395.9 kB[0m [31m27.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading alembic-1.16.4-py3-none-any.whl (247 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m247.0/247.0 kB[0m [31m27.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading colorlog-6.9.0-py3-none-any.whl (11 kB)
Installing collected packages: colorlog, alembic, optuna
Successfully installed alembic-1.16.4 colorlog-6.9.0 optuna-4.4.0


In [None]:
# Research-Grade Multi-Perspective Graph ESN with Memory Optimization
# Enhanced with EMA, mixed precision, curriculum learning, and comprehensive visualizations

import os
import gc
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torch.cuda.amp import autocast, GradScaler
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import RobustScaler
from sklearn.metrics import mean_squared_error, r2_score
from scipy import signal
from scipy.stats import pearsonr
import warnings
from tqdm import tqdm
import time
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingWarmRestarts
from transformers import AutoTokenizer, AutoModel
import kagglehub
import json
import seaborn as sns
from matplotlib.gridspec import GridSpec
import optuna
from datetime import datetime
import joblib
from joblib import Memory
import psutil
from collections import defaultdict
import logging

# Set up environment
# Silences ALL Python warnings for the rest of the session (including deprecation warnings).
warnings.filterwarnings('ignore')
# Allocator option for the CUDA caching allocator; presumably only takes effect if set
# before the first CUDA allocation, which is why it sits at the top of setup.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
# NOTE(review): benchmark mode lets cuDNN pick kernels at runtime, which can reduce
# run-to-run reproducibility even with the seeds set below.
torch.backends.cudnn.benchmark = True

# Device and memory management
# Global `device` is read by later code (e.g. the model constructor).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    torch.cuda.empty_cache()

# Set random seed
# Seeds the legacy global NumPy RNG (used by np.random.* calls later in the file) and torch.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Memory cache for joblib
# On-disk memoization directory backing functions decorated with @memory.cache.
cache_dir = './ecog_cache'
os.makedirs(cache_dir, exist_ok=True)
memory = Memory(cache_dir, verbose=0)

#----------------------------------------------------------------------
# Enhanced Configuration
#----------------------------------------------------------------------

class ResearchConfig:
    """All tunable hyper-parameters for data, model, training, and output paths.

    Settings are plain UPPER_CASE instance attributes grouped by concern.
    Constructing an instance also ensures the output directory tree exists on disk.
    """

    def __init__(self):
        # ---- data windowing (sizes kept small to bound memory) ----
        self.NUM_ELECTRODES = 50
        self.CONTEXT_SIZE = 256
        self.PREDICT_SIZE = 128
        self.SAMPLING_RATE = 1024
        self.MAX_SEQUENCES = 500
        self.BATCH_SIZE = 8
        self.GRADIENT_ACCUMULATION = 4   # 8 x 4 -> effective batch of 32
        self.STRIDE = 64

        # ---- language-model word features ----
        self.WORD_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
        self.WORD_EMBEDDING_DIM = 384
        self.CONTEXT_WORDS = 8
        self.WORD_NOISE_SCALE = 0.1

        # ---- echo-state reservoir ----
        self.NEURONS_PER_POPULATION = 768
        self.SPECTRAL_RADIUS = 0.92
        self.LEAKY_RATE = 0.35
        self.RESERVOIR_DENSITY = 0.08
        self.NOISE_LEVEL = 0.02

        # ---- per-electrode perspectives ----
        self.ELECTRODE_PERSPECTIVE_DIM = 96
        self.COMMON_SPACE_DIM = 192
        self.PERSPECTIVE_ATTENTION_HEADS = 4
        self.USE_CROSS_ELECTRODE_ATTENTION = True

        # ---- readout network ----
        self.READOUT_TYPE = "gru"
        self.READOUT_HIDDEN = 384
        self.READOUT_LAYERS = 3
        self.DROPOUT = 0.3

        # ---- regularization ----
        self.USE_EMA = True
        self.EMA_DECAY = 0.999
        self.USE_MIXUP = True
        self.MIXUP_ALPHA = 0.2
        self.LABEL_SMOOTHING = 0.1
        self.USE_GRADIENT_PENALTY = True
        self.GRADIENT_PENALTY_WEIGHT = 0.1

        # ---- optimization schedule ----
        self.LEARNING_RATE = 5e-4
        self.MIN_LR = 1e-6
        self.WEIGHT_DECAY = 1e-4
        self.EPOCHS = 100
        self.PATIENCE = 15
        self.CLIP_GRAD_NORM = 1.0
        self.WARMUP_EPOCHS = 5

        # ---- mixed precision ----
        self.USE_AMP = True

        # ---- curriculum: (epochs, predict_size, context_size) per stage ----
        self.USE_CURRICULUM = True
        self.CURRICULUM_STAGES = [
            {'epochs': n_epochs, 'predict_size': horizon, 'context_size': ctx}
            for n_epochs, horizon, ctx in ((10, 32, 128), (10, 64, 192), (80, 128, 256))
        ]

        # ---- loss term weights ----
        self.PRIMARY_LOSS_WEIGHT = 1.0
        self.RECONSTRUCTION_LOSS_WEIGHT = 0.15
        self.CONSISTENCY_LOSS_WEIGHT = 0.1
        self.DIVERSITY_LOSS_WEIGHT = 0.05
        self.FREQ_LOSS_WEIGHT = 0.3
        self.PHASE_LOSS_WEIGHT = 0.1

        # ---- emphasized frequency bands, in Hz (low, mid, high, very-high gamma) ----
        self.FREQ_BANDS = [(30, 50), (50, 80), (80, 120), (120, 200)]

        # ---- visualization / checkpoint cadence ----
        self.VIZ_EVERY_N_EPOCHS = 1
        self.SAVE_EVERY_N_EPOCHS = 5

        # ---- output location ----
        self.OUTPUT_DIR = "ecog_research_results"
        self.create_directories()

    def create_directories(self):
        """Ensure OUTPUT_DIR and each of its standard sub-folders exist (idempotent)."""
        subfolders = ('models', 'visualizations', 'logs', 'checkpoints', 'epoch_viz')
        os.makedirs(self.OUTPUT_DIR, exist_ok=True)
        for sub in subfolders:
            os.makedirs(os.path.join(self.OUTPUT_DIR, sub), exist_ok=True)

# Module-level configuration object. Code below reads this global directly
# (e.g. OptimizedWordEmbedding). Constructing it also creates the output
# directory tree on disk via create_directories().
config = ResearchConfig()

#----------------------------------------------------------------------
# Memory Management Utilities
#----------------------------------------------------------------------

def clear_gpu_memory():
    """Free unreferenced Python objects, then hand cached GPU memory back to the driver.

    Garbage collection runs first, and on every device. Tensors kept alive only by
    reference cycles are freed before ``torch.cuda.empty_cache()``, so their blocks
    can actually be released. The previous order (empty_cache first, gc last) left
    those blocks cached.
    """
    gc.collect()
    if torch.cuda.is_available():
        # Wait for queued kernels so no in-flight work still holds cached blocks.
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

def get_memory_usage():
    """Return current memory usage in GB (and RAM percent) as a dict.

    Keys ``'gpu_allocated'`` and ``'gpu_reserved'`` are present only when CUDA is
    available. ``'ram_used'`` and ``'ram_percent'`` are always present and come from
    a single ``psutil.virtual_memory()`` snapshot, so the two values are consistent
    with each other.
    """
    vm = psutil.virtual_memory()
    usage = {}
    if torch.cuda.is_available():
        usage['gpu_allocated'] = torch.cuda.memory_allocated() / 1e9
        usage['gpu_reserved'] = torch.cuda.memory_reserved() / 1e9
    usage['ram_used'] = vm.used / 1e9
    usage['ram_percent'] = vm.percent
    return usage

#----------------------------------------------------------------------
# Data Loading with Memory Optimization
#----------------------------------------------------------------------

@memory.cache
def load_ecog_subset(data_path, num_rows=100000):
    """Load the first ``num_rows`` rows of the processed ECoG CSV found under ``data_path``.

    The directory tree is walked in sorted (deterministic) order. The FIRST file whose
    name contains ``'processed_ecog_data'`` and ends with ``.csv`` is used. Results are
    memoized on disk by joblib via ``@memory.cache``.

    Args:
        data_path: root directory to search.
        num_rows: number of leading CSV rows to read (passed as ``nrows``).

    Returns:
        (ecog_data, electrode_columns): the DataFrame and every column name except
        ``'timestamp'`` and ``'word'``.

    Raises:
        FileNotFoundError: if no matching CSV exists under ``data_path``.
    """
    logger.info(f"Loading ECoG subset: {num_rows} rows")

    ecog_file = None
    for root, dirs, files in os.walk(data_path):
        # Sorting `dirs` in place makes os.walk descend in a deterministic order.
        dirs.sort()
        match = next(
            (name for name in sorted(files)
             if 'processed_ecog_data' in name and name.endswith('.csv')),
            None,
        )
        if match is not None:
            ecog_file = os.path.join(root, match)
            # Stop the whole walk. A bare inner `break` used to let later
            # directories overwrite the match, so the LAST file won.
            break

    if not ecog_file:
        raise FileNotFoundError("Could not find ECoG data file")

    ecog_data = pd.read_csv(ecog_file, nrows=num_rows)
    electrode_columns = [col for col in ecog_data.columns if col not in ['timestamp', 'word']]

    return ecog_data, electrode_columns

def prepare_data_efficient(ecog_data, electrode_columns, config):
    """Select, clean, and scale electrode signals and choose window start indices.

    Steps:
      1. keep the first ``config.NUM_ELECTRODES`` entries of ``electrode_columns``;
      2. drop every row where any selected electrode is NaN (the same rows are
         removed from the ``'word'`` column);
      3. RobustScaler-normalize each electrode column;
      4. choose up to ``config.MAX_SEQUENCES`` window start indices spread evenly
         across the recording. Each window spans ``CONTEXT_SIZE + PREDICT_SIZE`` rows.

    NOTE(review): the scaler is fit on all rows passed in. If the train/validation
    split happens afterwards, validation statistics leak into the normalization.
    Confirm against the caller.

    Returns:
        (electrode_data, word_data, indices, scaler)

    Raises:
        ValueError: if no rows remain after NaN removal.
    """
    selected_electrodes = electrode_columns[:config.NUM_ELECTRODES]

    electrode_data = ecog_data[selected_electrodes].values
    word_data = ecog_data['word'].values

    valid_mask = ~np.isnan(electrode_data).any(axis=1)
    electrode_data = electrode_data[valid_mask]
    word_data = word_data[valid_mask]
    if len(electrode_data) == 0:
        raise ValueError("No rows left after removing NaN electrode samples")

    scaler = RobustScaler()
    electrode_data = scaler.fit_transform(electrode_data)

    window_size = config.CONTEXT_SIZE + config.PREDICT_SIZE
    total_samples = len(electrode_data) - window_size + 1
    if total_samples <= 0:
        # Previously this silently produced zero sequences.
        logger.warning(
            f"Only {len(electrode_data)} valid rows; need at least {window_size} "
            f"for one window - no sequences created"
        )

    # Even spacing: step so that about MAX_SEQUENCES starts cover the whole range.
    sample_rate = max(1, total_samples // config.MAX_SEQUENCES)
    indices = list(range(0, total_samples, sample_rate)[:config.MAX_SEQUENCES])

    logger.info(f"Created {len(indices)} sequences")

    return electrode_data, word_data, indices, scaler

#----------------------------------------------------------------------
# Enhanced Word Embedding Module
#----------------------------------------------------------------------

class OptimizedWordEmbedding(nn.Module):
    """Word/context features: a frozen pretrained language model plus a trainable projection.

    Both ``get_word_embedding`` and ``get_context_embedding`` mean-pool the language
    model's last hidden state over tokens (the pooling recommended for the
    sentence-transformers MiniLM checkpoint). They then apply ``projection`` to reach
    ``config.WORD_EMBEDDING_DIM``.

    Only the frozen, deterministic language-model vectors are cached: an LRU holding
    at most ``max_cache_size`` words. The trainable projection is re-applied on every
    call. Repeated words therefore keep receiving gradients and reflect the current
    projection weights and dropout mode. Previously a stale, detached projected
    snapshot was served from the cache.

    Args:
        model_name: Hugging Face model id; defaults to ``config.WORD_MODEL_NAME``.
        device: device the language model runs on and zero vectors are created on.
        max_cache_size: maximum number of cached word vectors.
    """

    def __init__(self, model_name=None, device='cuda', max_cache_size=1000):
        super().__init__()
        if model_name is None:
            model_name = config.WORD_MODEL_NAME

        self.device = device
        self.max_cache_size = max_cache_size

        logger.info(f"Loading language model: {model_name}")

        # Half-precision weights only on CUDA; many fp16 ops are unsupported on CPU.
        use_half = config.USE_AMP and str(device).startswith('cuda')
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.language_model = AutoModel.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if use_half else torch.float32
        ).to(device)
        self.language_model.eval()

        # The language model is frozen: it only supplies features.
        for param in self.language_model.parameters():
            param.requires_grad = False

        self.embedding_dim = self.language_model.config.hidden_size

        # Trainable projection into the model's word-embedding space.
        self.projection = nn.Sequential(
            nn.Linear(self.embedding_dim, config.WORD_EMBEDDING_DIM),
            nn.LayerNorm(config.WORD_EMBEDDING_DIM),
            nn.GELU(),
            nn.Dropout(0.1)
        )

        # word -> pooled frozen-LM vector (float32, no grad). Python dicts keep
        # insertion order, so the first key is always the least recently used.
        self.embedding_cache = {}

    @torch.no_grad()
    def _get_raw_embedding(self, text):
        """Run the frozen language model on ``text``; return last_hidden_state as float32."""
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=50
        ).to(self.device)

        with autocast(enabled=config.USE_AMP):
            outputs = self.language_model(**inputs)

        return outputs.last_hidden_state.float()

    def _pooled_raw(self, text):
        """Mean-pool the frozen LM output over tokens for a single string -> (hidden_size,)."""
        # A single string is never padded, so a plain mean equals a mask-weighted mean.
        return self._get_raw_embedding(text).mean(dim=1).squeeze(0)

    def _manage_cache(self, key):
        """Mark ``key`` as most recently used and evict least-recently-used entries beyond the limit."""
        if key in self.embedding_cache:
            # Re-inserting moves the key to the end of the insertion order.
            self.embedding_cache[key] = self.embedding_cache.pop(key)
        while len(self.embedding_cache) > self.max_cache_size:
            oldest = next(iter(self.embedding_cache))
            del self.embedding_cache[oldest]

    def get_word_embedding(self, word):
        """Projected embedding for one word; zeros (no projection) for NaN or empty input."""
        if pd.isna(word) or word == "":
            return torch.zeros(config.WORD_EMBEDDING_DIM, device=self.device)

        raw = self.embedding_cache.get(word)
        if raw is None:
            raw = self._pooled_raw(word)
            self.embedding_cache[word] = raw
        self._manage_cache(word)

        # Projection runs every call so gradients reach its trainable weights.
        with autocast(enabled=config.USE_AMP):
            return self.projection(raw)

    def get_context_embedding(self, words):
        """Projected embedding of the last 5 valid words joined into one sentence.

        ``words`` is expected in chronological order (oldest first), so the slice
        below keeps the most recent words. Returns zeros if no valid word is present.
        """
        valid_words = [w for w in words if not pd.isna(w) and w != ""]

        if not valid_words:
            return torch.zeros(config.WORD_EMBEDDING_DIM, device=self.device)

        sentence = " ".join(valid_words[-5:])
        raw = self._pooled_raw(sentence)

        with autocast(enabled=config.USE_AMP):
            return self.projection(raw)

    def forward(self, word_sequence, context_words=None):
        """Embed a batch of current words and (optionally) their context-word lists.

        Returns:
            (current_embeddings, context_embeddings), each stacked along dim 0 in batch
            order. Context embeddings are zeros when ``context_words`` is None.
        """
        current_embeddings = torch.stack([
            self.get_word_embedding(word) for word in word_sequence
        ])

        if context_words is not None:
            context_embeddings = torch.stack([
                self.get_context_embedding(context) for context in context_words
            ])
        else:
            context_embeddings = torch.zeros_like(current_embeddings)

        # Training-time noise. The same sample, scaled by 0.5, is reused for the context.
        if self.training and config.WORD_NOISE_SCALE > 0:
            noise = torch.randn_like(current_embeddings) * config.WORD_NOISE_SCALE
            current_embeddings = current_embeddings + noise
            context_embeddings = context_embeddings + noise * 0.5

        return current_embeddings, context_embeddings

#----------------------------------------------------------------------
# Exponential Moving Average for Model Weights
#----------------------------------------------------------------------

class EMA:
    """Exponential moving average (EMA) of a model's trainable parameters.

    ``update()`` blends the live weights into a shadow copy using
    ``shadow = decay * shadow + (1 - decay) * param``. ``apply_shadow()`` copies the
    shadow weights into the model (e.g. for evaluation) after backing up the live
    ones. ``restore()`` copies the backup back and clears it.

    All transfers use ``copy_`` into existing storage, so the model parameters never
    alias the shadow or backup tensors. In-place edits to the model while the shadow
    is applied (e.g. an optimizer step) therefore cannot corrupt the averages.
    """

    def __init__(self, model, decay=0.999):
        self.model = model
        self.decay = decay
        self.shadow = {}
        self.backup = {}

        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.detach().clone()

    @torch.no_grad()
    def update(self):
        """Blend current trainable parameters into the shadow weights, in place."""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                shadow = self.shadow[name]
                if shadow.device != param.device:
                    # Model was moved after EMA creation; follow it.
                    shadow = shadow.to(param.device)
                    self.shadow[name] = shadow
                shadow.mul_(self.decay).add_(param.detach(), alpha=1 - self.decay)

    @torch.no_grad()
    def apply_shadow(self):
        """Back up live parameters and load the shadow weights into the model."""
        for name, param in self.model.named_parameters():
            if param.requires_grad and name in self.shadow:
                self.backup[name] = param.detach().clone()
                param.copy_(self.shadow[name])

    @torch.no_grad()
    def restore(self):
        """Load backed-up live parameters into the model and discard the backup."""
        for name, param in self.model.named_parameters():
            if param.requires_grad and name in self.backup:
                param.copy_(self.backup[name])
        # Clearing prevents a later, stray restore() from reverting newer training steps.
        self.backup = {}

#----------------------------------------------------------------------
# Enhanced Model Components
#----------------------------------------------------------------------

class EfficientElectrodeAttention(nn.Module):
    """Multi-head self-attention that projects ``shared_dim`` features to ``electrode_dim``.

    This is standard multi-head attention (not grouped-query). Keys and values come
    from one fused linear layer. Output is LayerNorm(out_proj(attention)), returned
    together with the attention weights.

    Raises:
        ValueError: if ``electrode_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, shared_dim, electrode_dim, num_heads=4):
        super().__init__()
        if electrode_dim % num_heads != 0:
            raise ValueError(
                f"electrode_dim ({electrode_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = electrode_dim // num_heads

        self.q_proj = nn.Linear(shared_dim, electrode_dim)
        # First electrode_dim outputs are K, next electrode_dim are V.
        self.kv_proj = nn.Linear(shared_dim, electrode_dim * 2)
        self.out_proj = nn.Linear(electrode_dim, electrode_dim)

        self.norm = nn.LayerNorm(electrode_dim)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        """Attend over the sequence dimension.

        Args:
            x: tensor of shape (B, L, shared_dim).

        Returns:
            (out, attn): out is (B, L, electrode_dim); attn is (B, num_heads, L, L).
        """
        B, L, D = x.shape

        q = self.q_proj(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        kv = self.kv_proj(x).view(B, L, 2, self.num_heads, self.head_dim)
        k, v = kv.unbind(2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Scaled dot-product attention; dropout is applied to the weights.
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous().view(B, L, -1)
        out = self.out_proj(out)

        return self.norm(out), attn

class OptimizedReservoir(nn.Module):
    """Leaky-integrator echo-state reservoir with fixed random weights.

    ``W_in`` (reservoir_size x input_size) and the sparse recurrent matrix ``W``
    (reservoir_size x reservoir_size) are registered as buffers and are not trained.
    ``W`` keeps each entry with probability ``config.RESERVOIR_DENSITY`` and is
    rescaled to spectral radius ``config.SPECTRAL_RADIUS``. Only ``bias`` and the
    output LayerNorm are learnable.
    """

    def __init__(self, input_size, reservoir_size, config):
        super().__init__()
        self.input_size = input_size
        self.reservoir_size = reservoir_size

        # Fixed dense input weights.
        self.register_buffer('W_in', torch.randn(reservoir_size, input_size) * 0.5)

        # Sparse random recurrent matrix.
        W = torch.randn(reservoir_size, reservoir_size)
        mask = torch.rand(reservoir_size, reservoir_size) < config.RESERVOIR_DENSITY
        W = W * mask.float()

        # Rescale so the largest |eigenvalue| equals SPECTRAL_RADIUS.
        with torch.no_grad():
            try:
                eigenvalues = torch.linalg.eigvals(W)
                radius = torch.max(torch.abs(eigenvalues)).item()
                if radius > 0:
                    W = W * (config.SPECTRAL_RADIUS / radius)
            except RuntimeError:
                # Eigen-decomposition failed. The Frobenius norm upper-bounds the
                # spectral radius, so this keeps the radius <= SPECTRAL_RADIUS.
                W = W * config.SPECTRAL_RADIUS / (torch.norm(W) + 1e-6)

        self.register_buffer('W', W)
        self.bias = nn.Parameter(torch.zeros(reservoir_size))

        self.leaky_rate = config.LEAKY_RATE
        self.noise_level = config.NOISE_LEVEL
        # Taken from the config passed in rather than a module-level global.
        self.use_amp = getattr(config, 'USE_AMP', False)

        self.norm = nn.LayerNorm(reservoir_size)

    def forward(self, x, hidden=None):
        """Advance the reservoir by one step.

        Args:
            x: input of shape (batch, input_size).
            hidden: previous un-normalized state (batch, reservoir_size), or None for zeros.

        Returns:
            (normalized_state, new_hidden): LayerNorm of the new state, and the raw
            state to feed back as ``hidden`` on the next step.
        """
        batch_size = x.shape[0]

        if hidden is None:
            hidden = torch.zeros(batch_size, self.reservoir_size, device=x.device)

        with autocast(enabled=self.use_amp):
            pre_activation = x @ self.W_in.T + hidden @ self.W.T + self.bias
            activation = torch.tanh(pre_activation)

            # Leaky integration between the previous state and the new activation.
            new_hidden = (1 - self.leaky_rate) * hidden + self.leaky_rate * activation

            # State noise only during training.
            if self.training and self.noise_level > 0:
                new_hidden = new_hidden + torch.randn_like(new_hidden) * self.noise_level

        return self.norm(new_hidden), new_hidden

#----------------------------------------------------------------------
# Main Model with Memory Optimization
#----------------------------------------------------------------------

class ResearchMPGESN(nn.Module):
    """Research-grade Multi-Perspective Graph ESN.

    Forward pass:
      1. At each context timestep, concatenate the electrode sample with the
         current-word and context embeddings.
      2. Project to READOUT_HIDDEN and step the fixed reservoir.
      3. Read the reservoir sequence out with a bidirectional GRU and mean-pool over time.
      4. Decode to predictions of shape (batch, current_predict_size, NUM_ELECTRODES).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_electrodes = config.NUM_ELECTRODES

        # Word embeddings (frozen pretrained LM + trainable projection).
        # Uses the module-level `device`.
        self.word_embedder = OptimizedWordEmbedding(device=device)

        # Feature projection: electrode sample + current-word + context embeddings.
        input_dim = config.NUM_ELECTRODES + 2 * config.WORD_EMBEDDING_DIM
        self.feature_projection = nn.Sequential(
            nn.Linear(input_dim, config.READOUT_HIDDEN),
            nn.LayerNorm(config.READOUT_HIDDEN),
            nn.GELU(),
            nn.Dropout(config.DROPOUT)
        )

        # Reservoir
        self.reservoir = OptimizedReservoir(
            config.READOUT_HIDDEN,
            config.NEURONS_PER_POPULATION,
            config
        )

        # One attention module per electrode (only every second one is used in forward).
        self.electrode_attentions = nn.ModuleList([
            EfficientElectrodeAttention(
                config.NEURONS_PER_POPULATION,
                config.ELECTRODE_PERSPECTIVE_DIM,
                config.PERSPECTIVE_ATTENTION_HEADS
            ) for _ in range(config.NUM_ELECTRODES)
        ])

        # Common space projector shared by all electrodes.
        self.common_projector = nn.Sequential(
            nn.Linear(config.ELECTRODE_PERSPECTIVE_DIM, config.COMMON_SPACE_DIM),
            nn.LayerNorm(config.COMMON_SPACE_DIM),
            nn.GELU()
        )

        # Bidirectional GRU readout over the reservoir state sequence.
        self.readout = nn.GRU(
            input_size=config.NEURONS_PER_POPULATION,
            hidden_size=config.READOUT_HIDDEN,
            num_layers=config.READOUT_LAYERS,
            batch_first=True,
            dropout=config.DROPOUT if config.READOUT_LAYERS > 1 else 0,
            bidirectional=True
        )

        self.readout_proj = nn.Linear(config.READOUT_HIDDEN * 2, config.READOUT_HIDDEN)

        # Size the head for the largest horizon ever requested: the final
        # PREDICT_SIZE and every curriculum stage. forward() slices it down.
        # Sizing from the stages alone would silently truncate outputs whenever
        # PREDICT_SIZE exceeded every stage.
        stage_sizes = (
            [stage['predict_size'] for stage in config.CURRICULUM_STAGES]
            if config.USE_CURRICULUM else []
        )
        max_predict_size = max([config.PREDICT_SIZE, *stage_sizes])
        self.prediction_head = nn.Sequential(
            nn.Linear(config.READOUT_HIDDEN, config.READOUT_HIDDEN // 2),
            nn.GELU(),
            nn.Dropout(config.DROPOUT),
            nn.Linear(config.READOUT_HIDDEN // 2, config.NUM_ELECTRODES * max_predict_size)
        )
        self.max_predict_size = max_predict_size

        # Auxiliary per-electrode decoders from the common space.
        self.electrode_decoders = nn.ModuleList([
            nn.Linear(config.COMMON_SPACE_DIM, 1)
            for _ in range(config.NUM_ELECTRODES)
        ])

        self._init_weights()

    def _init_weights(self):
        """Xavier-normal init for Linear layers; orthogonal/zero init for GRU weights/biases.

        Modules belonging to the pretrained language model inside ``word_embedder``
        are skipped. ``self.modules()`` recurses into it, and re-initializing them
        would wipe the pretrained weights.
        """
        pretrained = getattr(self.word_embedder, 'language_model', None)
        pretrained_ids = (
            {id(m) for m in pretrained.modules()} if pretrained is not None else set()
        )
        for m in self.modules():
            if id(m) in pretrained_ids:
                continue
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.GRU):
                for name, param in m.named_parameters():
                    if 'weight' in name:
                        nn.init.orthogonal_(param)
                    elif 'bias' in name:
                        nn.init.zeros_(param)

    def forward(self, electrode_data, current_words, context_words, return_perspectives=False):
        """Predict future electrode activity from a context window and word information.

        Args:
            electrode_data: (batch, context_size, NUM_ELECTRODES) tensor.
            current_words: one word per batch item.
            context_words: one list of context words per batch item.
            return_perspectives: if True, also return per-electrode perspectives and
                common-space representations. These are sampled every 4th timestep
                and every 2nd electrode.

        Returns:
            predictions of shape (batch, current_predict_size, NUM_ELECTRODES), or
            ``(predictions, {'perspectives': ..., 'common_space': ...})``.

        Raises:
            ValueError: if ``current_predict_size`` exceeds the head's capacity.
        """
        batch_size, context_size, _ = electrode_data.shape

        # Horizon set externally by the curriculum loop; defaults to PREDICT_SIZE.
        current_predict_size = getattr(self, 'current_predict_size', self.config.PREDICT_SIZE)
        if current_predict_size > self.max_predict_size:
            raise ValueError(
                f"current_predict_size={current_predict_size} exceeds "
                f"max_predict_size={self.max_predict_size}"
            )

        word_embed, context_embed = self.word_embedder(current_words, context_words)

        reservoir_states = []
        electrode_perspectives = []
        common_representations = []
        hidden = None

        # Step the reservoir one timestep at a time (no gradient checkpointing here).
        for t in range(context_size):
            electrode_t = electrode_data[:, t, :]
            combined = torch.cat([electrode_t, word_embed, context_embed], dim=1)

            with autocast(enabled=self.config.USE_AMP):
                features = self.feature_projection(combined)

                reservoir_out, hidden = self.reservoir(features, hidden)
                reservoir_states.append(reservoir_out)

                # Perspectives are subsampled in time and electrodes to save memory.
                if return_perspectives and t % 4 == 0:
                    perspectives_t = []
                    common_t = []

                    for e_idx in range(0, self.num_electrodes, 2):
                        persp, _ = self.electrode_attentions[e_idx](reservoir_out.unsqueeze(1))
                        persp = persp.squeeze(1)
                        common = self.common_projector(persp)

                        perspectives_t.append(persp)
                        common_t.append(common)

                    electrode_perspectives.append(perspectives_t)
                    common_representations.append(common_t)

        reservoir_sequence = torch.stack(reservoir_states, dim=1)

        with autocast(enabled=self.config.USE_AMP):
            readout_out, _ = self.readout(reservoir_sequence)
            readout_out = self.readout_proj(readout_out)

            # Mean-pool over time.
            readout_features = readout_out.mean(dim=1)

            predictions = self.prediction_head(readout_features)
            predictions = predictions.view(batch_size, self.max_predict_size, self.num_electrodes)
            predictions = predictions[:, :current_predict_size, :]

        if return_perspectives:
            return predictions, {
                'perspectives': electrode_perspectives,
                'common_space': common_representations
            }
        return predictions

#----------------------------------------------------------------------
# Enhanced Loss Functions
#----------------------------------------------------------------------

class ResearchLoss(nn.Module):
    """Weighted multi-term loss for multi-electrode signal forecasting.

    ``predictions`` and ``targets`` are indexed as (batch, time, electrode). Spectral
    terms transform along dim 1, and the reconstruction term reads
    ``targets[:, :, e]``. All terms are computed in float32, even when the model ran
    under autocast and produced half-precision predictions.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.mse = nn.MSELoss()
        self.smooth_l1 = nn.SmoothL1Loss()

    def compute_frequency_loss(self, pred, target):
        """Band-weighted squared error between rFFT magnitudes along the time axis.

        At most the first 256 time steps are transformed. Bins inside any
        ``config.FREQ_BANDS`` range get weight 2; all others get weight 1.
        """
        n_fft = min(pred.shape[1], 256)

        pred_mag = torch.abs(torch.fft.rfft(pred, n=n_fft, dim=1))
        target_mag = torch.abs(torch.fft.rfft(target, n=n_fft, dim=1))

        freqs = torch.fft.rfftfreq(n_fft, d=1.0 / self.config.SAMPLING_RATE).to(pred.device)

        weights = torch.ones_like(freqs)
        for low, high in self.config.FREQ_BANDS:
            weights[(freqs >= low) & (freqs <= high)] = 2.0

        # Broadcast per-bin weights over batch (dim 0) and electrodes (dim 2).
        return torch.mean((pred_mag - target_mag) ** 2 * weights.unsqueeze(0).unsqueeze(-1))

    def compute_phase_loss(self, pred, target):
        """Mean absolute wrapped difference between per-bin FFT phases along time.

        This is the phase of the full-window Fourier spectrum, not an instantaneous
        (analytic-signal) phase. Bins with near-zero magnitude have ill-defined phase.
        """
        pred_analytic = torch.view_as_real(torch.fft.fft(pred, dim=1))
        target_analytic = torch.view_as_real(torch.fft.fft(target, dim=1))

        pred_phase = torch.atan2(pred_analytic[..., 1], pred_analytic[..., 0])
        target_phase = torch.atan2(target_analytic[..., 1], target_analytic[..., 0])

        # Wrap the difference into [-pi, pi) so +pi and -pi count as equal.
        phase_diff = torch.remainder(pred_phase - target_phase + np.pi, 2 * np.pi) - np.pi
        return torch.mean(torch.abs(phase_diff))

    def compute_gradient_penalty(self, pred, target):
        """MSE between first differences along time (matches local slopes)."""
        return self.mse(torch.diff(pred, dim=1), torch.diff(target, dim=1))

    def forward(self, predictions, targets, electrode_perspectives_data=None, electrode_decoders=None):
        """Compute every loss term and their weighted total.

        Args:
            predictions, targets: (batch, time, electrode) tensors.
            electrode_perspectives_data: optional dict with a ``'common_space'`` list
                (per sampled timestep) of lists (per sampled electrode) of tensors.
                Entry ``e_idx`` is assumed to come from electrode ``2 * e_idx``.
            electrode_decoders: per-electrode modules mapping a common-space vector
                to one value.

        Returns:
            dict with ``'total'`` plus each individual term.
        """
        # Compute in float32: predictions may be float16 under autocast. Half-precision
        # FFTs are unsupported on CPU and limited to power-of-two lengths on CUDA.
        predictions = predictions.float()
        targets = targets.float()
        zero = torch.tensor(0.0, device=predictions.device)

        primary_loss = self.smooth_l1(predictions, targets)
        freq_loss = self.compute_frequency_loss(predictions, targets)
        phase_loss = self.compute_phase_loss(predictions, targets)

        # Honor the config switch; the penalty used to be applied unconditionally.
        if getattr(self.config, 'USE_GRADIENT_PENALTY', True):
            grad_penalty = self.compute_gradient_penalty(predictions, targets)
        else:
            grad_penalty = zero

        reconstruction_loss = zero
        consistency_loss = zero
        diversity_loss = zero

        if electrode_perspectives_data is not None and electrode_decoders is not None:
            common_space = electrode_perspectives_data.get('common_space', [])

            if common_space and len(common_space[0]) > 0:
                # Reconstruction: decode each common vector to its electrode's target mean.
                n_terms = 0
                for common_t in common_space:
                    for e_idx, common_repr in enumerate(common_t):
                        electrode = e_idx * 2
                        if electrode < self.config.NUM_ELECTRODES:
                            reconstructed = electrode_decoders[electrode](common_repr)
                            target_mean = targets[:, :, electrode].mean(dim=1, keepdim=True)
                            reconstruction_loss = reconstruction_loss + self.mse(reconstructed, target_mean)
                            n_terms += 1
                if n_terms:
                    reconstruction_loss = reconstruction_loss / n_terms

                # Consistency: penalize step-to-step change of each electrode's representation.
                n_electrodes = len(common_space[0])
                if len(common_space) > 1:
                    for e_idx in range(n_electrodes):
                        repr_sequence = torch.stack(
                            [common_space[t][e_idx] for t in range(len(common_space))], dim=1
                        )
                        temporal_diff = torch.diff(repr_sequence, dim=1)
                        consistency_loss = consistency_loss + torch.mean(temporal_diff ** 2)
                    consistency_loss = consistency_loss / n_electrodes

                # Diversity: penalize off-diagonal cosine similarity above 0.3.
                for common_t in common_space:
                    if len(common_t) > 1:
                        reps_norm = F.normalize(torch.stack(common_t, dim=1), dim=-1)
                        sim_matrix = torch.matmul(reps_norm, reps_norm.transpose(-2, -1))
                        mask = ~torch.eye(len(common_t), device=reps_norm.device, dtype=torch.bool)
                        off_diag = sim_matrix[:, mask]
                        diversity_loss = diversity_loss + torch.mean(torch.relu(off_diag - 0.3))
                diversity_loss = diversity_loss / len(common_space)

        total_loss = (
            self.config.PRIMARY_LOSS_WEIGHT * primary_loss +
            self.config.FREQ_LOSS_WEIGHT * freq_loss +
            self.config.PHASE_LOSS_WEIGHT * phase_loss +
            self.config.GRADIENT_PENALTY_WEIGHT * grad_penalty +
            self.config.RECONSTRUCTION_LOSS_WEIGHT * reconstruction_loss +
            self.config.CONSISTENCY_LOSS_WEIGHT * consistency_loss +
            self.config.DIVERSITY_LOSS_WEIGHT * diversity_loss
        )

        return {
            'total': total_loss,
            'primary': primary_loss,
            'freq': freq_loss,
            'phase': phase_loss,
            'grad_penalty': grad_penalty,
            'reconstruction': reconstruction_loss,
            'consistency': consistency_loss,
            'diversity': diversity_loss
        }

#----------------------------------------------------------------------
# Dataset with Augmentation
#----------------------------------------------------------------------

class AugmentedECoGDataset(Dataset):
    """Sliding-window ECoG dataset yielding context/target windows plus word context.

    The item at ``idx`` starts at row ``indices[idx]``. The next ``context_size`` rows
    are the model input; the ``predict_size`` rows after them are the target.
    Augmentation touches only the context window, and only while ``self.training``
    is true.

    Args:
        electrode_data: array indexed (row, electrode).
        word_data: per-row word labels (NaN or "" means no word).
        indices: window start rows.
        context_size, predict_size: window lengths in rows.
        context_words: maximum number of earlier words to collect.
        training: enable augmentation. It can also be toggled later via ``.training``.
    """

    def __init__(self, electrode_data, word_data, indices, context_size, predict_size,
                 context_words=8, training=False):
        self.electrode_data = electrode_data
        self.word_data = word_data
        self.indices = indices
        self.context_size = context_size
        self.predict_size = predict_size
        self.context_words = context_words
        self.training = training

    def __len__(self):
        return len(self.indices)

    def augment_signal(self, signal):
        """Randomly time-shift, amplitude-scale, and add noise, each with probability 0.5.

        The time shift is circular (``np.roll``) along axis 0.
        """
        if np.random.random() > 0.5:
            # Symmetric shift in [-5, 5]; randint's upper bound is exclusive.
            shift = np.random.randint(-5, 6)
            signal = np.roll(signal, shift, axis=0)

        if np.random.random() > 0.5:
            scale = np.random.uniform(0.9, 1.1)
            signal = signal * scale

        if np.random.random() > 0.5:
            noise = np.random.normal(0, 0.02, signal.shape)
            signal = signal + noise

        return signal

    def __getitem__(self, idx):
        start_idx = self.indices[idx]
        context_end = start_idx + self.context_size

        context_data = self.electrode_data[start_idx:context_end]
        target_data = self.electrode_data[context_end:context_end + self.predict_size]

        if getattr(self, 'training', False):
            context_data = self.augment_signal(context_data.copy())

        # Word aligned with the middle of the context window.
        word_idx = start_idx + self.context_size // 2
        current_word = self.word_data[word_idx] if word_idx < len(self.word_data) else ""

        # Up to `context_words` earlier words sampled every 25 rows before word_idx.
        # They are collected oldest-first (chronological), so a consumer keeping the
        # last N entries gets the MOST RECENT words. They used to be newest-first.
        context_word_list = []
        for i in range(self.context_words, 0, -1):
            ctx_idx = word_idx - i * 25
            if ctx_idx >= 0:
                word = self.word_data[ctx_idx]
                if not pd.isna(word) and word != "":
                    context_word_list.append(word)

        return {
            'electrode_context': torch.FloatTensor(context_data),
            'electrode_target': torch.FloatTensor(target_data),
            'current_word': current_word if not pd.isna(current_word) else "",
            'context_words': context_word_list
        }

def custom_collate_fn(batch):
    """Collate dataset items into one batch dict.

    Electrode windows are stacked into tensors with a new leading batch
    dimension. Word fields stay as plain Python lists, in batch order.
    Note the key rename from ``'current_word'`` (per item) to
    ``'current_words'`` (per batch).
    """
    def gather(key):
        return [sample[key] for sample in batch]

    return {
        'electrode_context': torch.stack(gather('electrode_context')),
        'electrode_target': torch.stack(gather('electrode_target')),
        'current_words': gather('current_word'),
        'context_words': gather('context_words')
    }

#----------------------------------------------------------------------
# Training Functions with Visualization
#----------------------------------------------------------------------

def mixup_data(x, y, alpha=0.2):
    """Mix each sample with a randomly chosen partner from the same batch.

    The mixing weight ``lam`` is drawn from Beta(alpha, alpha) when
    ``alpha > 0`` and is ``1`` otherwise, which leaves ``x`` and ``y`` unchanged.
    Inputs and targets are mixed with the same permutation and weight.

    Returns:
        ``(mixed_x, mixed_y, lam)``.
    """
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1

    partner = torch.randperm(x.size(0)).to(x.device)
    mixed_x = lam * x + (1 - lam) * x[partner]
    mixed_y = lam * y + (1 - lam) * y[partner]
    return mixed_x, mixed_y, lam

def save_epoch_visualization(model, val_loader, epoch, config, device):
    """Render and save a diagnostic figure for one validation batch.

    Uses sample 0 of the first batch from ``val_loader``. The figure contains:
      * context / actual / predicted traces for two electrodes,
      * Welch power spectra (actual vs predicted) for the same electrodes,
        with ``config.FREQ_BANDS`` shaded,
      * an absolute-error heatmap (time x electrode) over the prediction window.

    It is written to ``<config.OUTPUT_DIR>/epoch_viz/epoch_<NNN>.png`` with
    ``NNN = epoch + 1``. The directory is created if it does not exist. The
    model's train/eval mode is restored before returning.

    Args:
        model: Network called as ``model(context, words, context_words,
            return_perspectives=False)``.
        val_loader: Loader yielding batches in the ``custom_collate_fn`` format.
        epoch: Zero-based epoch index.
        config: Reads ``SAMPLING_RATE``, ``NUM_ELECTRODES``, ``FREQ_BANDS``
            and ``OUTPUT_DIR``.
        device: Device the batch tensors are moved to.
    """
    batch = next(iter(val_loader), None)
    if batch is None:
        logger.warning(f"Validation loader is empty; skipping visualization for epoch {epoch+1}")
        return

    electrode_context = batch['electrode_context'].to(device)
    electrode_target = batch['electrode_target'].to(device)
    current_words = batch['current_words']
    context_words = batch['context_words']

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            predictions = model(electrode_context, current_words, context_words, return_perspectives=False)
    finally:
        # Hand the model back in whatever mode the caller left it in.
        model.train(was_training)

    # Cast to float32 so numpy/scipy/sklearn always receive a standard dtype.
    context_np = electrode_context.float().cpu().numpy()
    target_np = electrode_target.float().cpu().numpy()
    pred_np = predictions.float().cpu().numpy()

    fig = plt.figure(figsize=(20, 15))
    gs = GridSpec(3, 2, figure=fig, hspace=0.3, wspace=0.2)

    sample_idx = 0

    # Read window lengths from the batch itself: curriculum learning rebuilds
    # the loaders with different context/predict sizes.
    current_context_size = context_np.shape[1]
    current_predict_size = target_np.shape[1]

    context_time = np.arange(current_context_size) / config.SAMPLING_RATE
    predict_time = (np.arange(current_predict_size) / config.SAMPLING_RATE) + (current_context_size / config.SAMPLING_RATE)

    # Only the first two entries are plotted (one grid row each).
    electrodes_to_plot = [0, config.NUM_ELECTRODES//4, config.NUM_ELECTRODES//2, config.NUM_ELECTRODES-1]

    # Left column: time-domain traces.
    for i, e_idx in enumerate(electrodes_to_plot[:2]):
        ax = fig.add_subplot(gs[i, 0])

        ax.plot(context_time, context_np[sample_idx, :, e_idx],
                'b-', label='Context', alpha=0.7, linewidth=1.5)
        ax.plot(predict_time, target_np[sample_idx, :, e_idx],
                'g-', label='Actual', linewidth=2)
        ax.plot(predict_time, pred_np[sample_idx, :, e_idx],
                'r--', label='Predicted', linewidth=2)

        ax.axvline(x=context_time[-1], color='k', linestyle=':', alpha=0.5)
        ax.set_xlabel('Time (s)')
        ax.set_ylabel('Amplitude')
        ax.set_title(f'Electrode {e_idx + 1}')
        ax.legend()
        ax.grid(True, alpha=0.3)

    # Right column: power spectra of the prediction window.
    for i, e_idx in enumerate(electrodes_to_plot[:2]):
        ax = fig.add_subplot(gs[i, 1])

        freqs, psd_actual = signal.welch(
            target_np[sample_idx, :, e_idx],
            fs=config.SAMPLING_RATE,
            nperseg=min(64, current_predict_size)
        )
        _, psd_pred = signal.welch(
            pred_np[sample_idx, :, e_idx],
            fs=config.SAMPLING_RATE,
            nperseg=min(64, current_predict_size)
        )

        ax.semilogy(freqs, psd_actual, 'g-', label='Actual', linewidth=2)
        ax.semilogy(freqs, psd_pred, 'r--', label='Predicted', linewidth=2)

        # Shade the configured frequency bands.
        for low, high in config.FREQ_BANDS:
            ax.axvspan(low, high, alpha=0.2, color='yellow')

        ax.set_xlabel('Frequency (Hz)')
        ax.set_ylabel('PSD')
        ax.set_title(f'Power Spectrum - Electrode {e_idx + 1}')
        ax.legend()
        ax.grid(True, alpha=0.3)
        ax.set_xlim(0, 250)

    # Bottom row: absolute-error heatmap over the prediction window.
    ax = fig.add_subplot(gs[2, :])
    errors = np.abs(pred_np[sample_idx] - target_np[sample_idx])

    im = ax.imshow(errors.T, aspect='auto', cmap='hot', interpolation='nearest')
    ax.set_xlabel('Time steps')
    ax.set_ylabel('Electrode')
    ax.set_title('Prediction Error Heatmap')
    fig.colorbar(im, ax=ax, label='Absolute Error')

    mse = np.mean((pred_np[sample_idx] - target_np[sample_idx])**2)
    r2 = r2_score(target_np[sample_idx].flatten(), pred_np[sample_idx].flatten())

    fig.suptitle(f'Epoch {epoch+1} - MSE: {mse:.4f}, R²: {r2:.4f}\nWord: "{current_words[sample_idx]}"',
                 fontsize=16)

    viz_dir = os.path.join(config.OUTPUT_DIR, 'epoch_viz')
    os.makedirs(viz_dir, exist_ok=True)
    save_path = os.path.join(viz_dir, f'epoch_{epoch+1:03d}.png')
    fig.savefig(save_path, dpi=150, bbox_inches='tight')
    # Close this specific figure; plt.close() alone closes whichever figure is "current".
    plt.close(fig)

    logger.info(f"Saved visualization for epoch {epoch+1}")

def train_epoch(model, train_loader, optimizer, criterion, scaler, device, config, ema=None):
    """Train for one epoch with AMP, optional mixup, and gradient accumulation.

    Gradients are accumulated over ``config.GRADIENT_ACCUMULATION`` batches
    before each optimizer step. Each step unscales the gradients, clips them
    to ``config.CLIP_GRAD_NORM``, steps through the GradScaler, and then
    updates ``ema`` (if given).

    When the number of batches is not a multiple of the accumulation factor,
    the trailing partial window is still stepped. Its losses are normalized by
    the window's real length. Previously those gradients were never applied
    and were wiped by the next epoch's ``zero_grad``.

    Args:
        model: Network called with ``return_perspectives=True``. It must
            expose ``electrode_decoders`` for the criterion.
        train_loader: Loader yielding ``custom_collate_fn`` batches.
        optimizer: Torch optimizer.
        criterion: Callable returning a dict of scalar loss tensors. It must
            contain ``'total'`` and ``'primary'``.
        scaler: ``GradScaler`` (may be disabled).
        device: Device the tensors are moved to.
        config: Reads ``USE_MIXUP``, ``MIXUP_ALPHA``, ``USE_AMP``,
            ``GRADIENT_ACCUMULATION`` and ``CLIP_GRAD_NORM``.
        ema: Optional object with an ``update()`` method, called after every
            optimizer step.

    Returns:
        dict mapping each criterion key to its per-batch mean over the epoch.
        The dict is empty when the loader yields nothing.
    """
    model.train()

    accum_steps = config.GRADIENT_ACCUMULATION
    try:
        total_batches = len(train_loader)
    except TypeError:
        # Iterable without __len__: only full accumulation windows can be detected.
        total_batches = None
    # Index of the first batch in an incomplete trailing window
    # (equals total_batches when there is no such window).
    tail_start = (total_batches // accum_steps) * accum_steps if total_batches else None

    losses = defaultdict(float)
    num_batches = 0

    progress_bar = tqdm(train_loader, desc="Training")
    optimizer.zero_grad()

    for batch_idx, batch in enumerate(progress_bar):
        electrode_context = batch['electrode_context'].to(device)
        electrode_target = batch['electrode_target'].to(device)
        current_words = batch['current_words']
        context_words = batch['context_words']

        # Mixup on roughly half the batches. Only electrode signals are mixed, not the words.
        if config.USE_MIXUP and np.random.random() > 0.5:
            electrode_context, electrode_target, _ = mixup_data(
                electrode_context, electrode_target, config.MIXUP_ALPHA
            )

        # Batches in the trailing partial window are averaged over its true length.
        in_tail = tail_start is not None and batch_idx >= tail_start
        window_len = (total_batches - tail_start) if in_tail else accum_steps

        with autocast(enabled=config.USE_AMP):
            predictions, perspective_data = model(
                electrode_context, current_words, context_words, return_perspectives=True
            )

            loss_dict = criterion(
                predictions, electrode_target,
                electrode_perspectives_data=perspective_data,
                electrode_decoders=model.electrode_decoders
            )

            loss = loss_dict['total'] / window_len

        scaler.scale(loss).backward()

        window_done = (batch_idx + 1) % accum_steps == 0
        last_batch = total_batches is not None and batch_idx + 1 == total_batches
        if window_done or last_batch:
            # Unscale before clipping so the norm threshold applies to the true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.CLIP_GRAD_NORM)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

            if ema is not None:
                ema.update()

        for k, v in loss_dict.items():
            losses[k] += v.item()
        num_batches += 1

        progress_bar.set_postfix({
            'loss': f"{loss_dict['total'].item():.4f}",
            'primary': f"{loss_dict['primary'].item():.4f}"
        })

        # Clear cache periodically
        if batch_idx % 50 == 0:
            clear_gpu_memory()

    return {k: v / num_batches for k, v in losses.items()}

def evaluate(model, val_loader, criterion, device, ema=None, config=None):
    """Evaluate ``model`` on ``val_loader`` and return averaged losses plus metrics.

    If ``ema`` is given, its shadow weights are applied for the duration of
    the pass. The original weights are restored afterwards, even if the pass
    raises.

    Args:
        model: Network called as ``model(context, words, context_words,
            return_perspectives=False)``.
        val_loader: Loader yielding ``custom_collate_fn`` batches.
        criterion: Callable returning a dict of scalar loss tensors.
        device: Device the tensors are moved to.
        ema: Optional object with ``apply_shadow()`` / ``restore()``.
        config: Object with a ``USE_AMP`` flag. Optional for backward
            compatibility. When omitted, the module-level ``config`` is used
            (the previous implicit behavior). If no such global exists,
            AMP is off instead of raising ``NameError``.

    Returns:
        dict with the per-batch mean of every criterion key, plus ``'mse'``,
        ``'r2'`` and ``'correlation'`` computed over all flattened predictions.

    Raises:
        ValueError: if ``val_loader`` yields no batches.
    """
    if config is None:
        config = globals().get('config')
    use_amp = bool(getattr(config, 'USE_AMP', False))

    model.eval()

    losses = defaultdict(float)
    all_preds = []
    all_targets = []
    num_batches = 0

    if ema is not None:
        ema.apply_shadow()
    try:
        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Evaluating"):
                electrode_context = batch['electrode_context'].to(device)
                electrode_target = batch['electrode_target'].to(device)
                current_words = batch['current_words']
                context_words = batch['context_words']

                with autocast(enabled=use_amp):
                    predictions = model(
                        electrode_context, current_words, context_words, return_perspectives=False
                    )

                    loss_dict = criterion(
                        predictions, electrode_target,
                        electrode_perspectives_data=None,
                        electrode_decoders=None
                    )

                for k, v in loss_dict.items():
                    losses[k] += v.item()
                num_batches += 1

                # Under AMP predictions may be half precision; compute metrics in float32.
                all_preds.append(predictions.float().cpu().numpy())
                all_targets.append(electrode_target.float().cpu().numpy())
    finally:
        # Always put the live (non-EMA) weights back, even if evaluation failed.
        if ema is not None:
            ema.restore()

    if num_batches == 0:
        raise ValueError("evaluate() received an empty loader; no metrics can be computed")

    preds_flat = np.concatenate(all_preds).reshape(-1)
    targets_flat = np.concatenate(all_targets).reshape(-1)

    mse = mean_squared_error(targets_flat, preds_flat)
    r2 = r2_score(targets_flat, preds_flat)
    corr, _ = pearsonr(preds_flat, targets_flat)

    results = {k: v / num_batches for k, v in losses.items()}
    results['mse'] = mse
    results['r2'] = r2
    results['correlation'] = corr

    return results

#----------------------------------------------------------------------
# Main Training Function
#----------------------------------------------------------------------

def _build_loader(dataset, batch_size, shuffle):
    """Wrap ``dataset`` in a DataLoader using this notebook's shared loader settings."""
    return DataLoader(
        dataset, batch_size=batch_size,
        shuffle=shuffle, collate_fn=custom_collate_fn,
        num_workers=2, pin_memory=True
    )


def _json_default(obj):
    """``json.dump`` fallback: convert numpy/torch values to native Python, else use ``str``.

    Returning an unsupported object unchanged (the previous fallback) makes
    ``json`` raise "Circular reference detected", for example on numpy integers.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, torch.Tensor):
        return obj.detach().cpu().tolist()
    return str(obj)


def _plot_training_summary(history, output_dir):
    """Save a 2x2 figure (losses, val R², LR, component losses) as ``training_summary.png``."""
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    ax_loss, ax_r2, ax_lr, ax_comp = axes.flat

    ax_loss.plot(history['train_total'], label='Train')
    ax_loss.plot(history['val_total'], label='Val')
    ax_loss.set_xlabel('Epoch')
    ax_loss.set_ylabel('Total Loss')
    ax_loss.set_title('Training Progress')
    ax_loss.legend()
    ax_loss.grid(True)

    ax_r2.plot(history['val_r2'])
    ax_r2.set_xlabel('Epoch')
    ax_r2.set_ylabel('R²')
    ax_r2.set_title('Validation R² Score')
    ax_r2.grid(True)

    ax_lr.plot(history['lr'])
    ax_lr.set_xlabel('Epoch')
    ax_lr.set_ylabel('Learning Rate')
    ax_lr.set_title('Learning Rate Schedule')
    ax_lr.set_yscale('log')
    ax_lr.grid(True)

    ax_comp.plot(history['val_primary'], label='Primary')
    ax_comp.plot(history['val_freq'], label='Frequency')
    ax_comp.plot(history['val_phase'], label='Phase')
    ax_comp.set_xlabel('Epoch')
    ax_comp.set_ylabel('Loss')
    ax_comp.set_title('Component Losses')
    ax_comp.legend()
    ax_comp.grid(True)

    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, 'training_summary.png'), dpi=150)
    plt.close(fig)


def train_research_model(config):
    """Train ResearchMPGESN end to end and evaluate the best checkpoint on the test split.

    Steps:
      1. Download and load the ECoG subset, then prepare windows.
      2. Split the indices 70/15/15 into train/val/test, seeded with the global ``SEED``.
      3. Train with AdamW and CosineAnnealingWarmRestarts, plus optional EMA,
         AMP and curriculum learning.
      4. Checkpoint the best validation ``'total'`` loss and early-stop after
         ``config.PATIENCE`` epochs without improvement.
      5. Evaluate on the test split and write ``results.json`` and
         ``training_summary.png`` under ``config.OUTPUT_DIR``.

    Note: with ``config.USE_CURRICULUM`` set, ``config.CONTEXT_SIZE`` and
    ``config.PREDICT_SIZE`` are overwritten with the largest curriculum sizes
    and are not restored. Data preparation, model construction and the test
    split all use those sizes.

    Returns:
        ``(model, history)``, where ``history`` maps metric names to per-epoch lists.
    """
    logger.info("Starting research-grade MPGESN training")

    # Every output location written below must exist before it is used.
    for subdir in ('models', 'checkpoints', 'epoch_viz'):
        os.makedirs(os.path.join(config.OUTPUT_DIR, subdir), exist_ok=True)

    # Download and load data
    data_path = kagglehub.dataset_download("arunramponnambalam/podcast-csv-aligned")
    ecog_data, electrode_columns = load_ecog_subset(data_path, num_rows=100000)

    # Configured sizes. Curriculum epochs past the last stage fall back to these.
    original_context = config.CONTEXT_SIZE
    original_predict = config.PREDICT_SIZE

    if config.USE_CURRICULUM:
        # Prepare data with the largest window any stage needs, so every stage's windows fit.
        config.CONTEXT_SIZE = max(stage['context_size'] for stage in config.CURRICULUM_STAGES)
        config.PREDICT_SIZE = max(stage['predict_size'] for stage in config.CURRICULUM_STAGES)

    # The fitted data scaler is not used further in this function.
    electrode_data, word_data, indices, data_scaler = prepare_data_efficient(
        ecog_data, electrode_columns, config
    )

    # 70% train, then the remaining 30% split evenly into val and test.
    train_idx, temp_idx = train_test_split(indices, test_size=0.3, random_state=SEED)
    val_idx, test_idx = train_test_split(temp_idx, test_size=0.5, random_state=SEED)

    def make_dataset(split_idx, context_size, predict_size, training=False):
        """Build an AugmentedECoGDataset over ``split_idx``. Augmentation is on only when ``training``."""
        dataset = AugmentedECoGDataset(
            electrode_data, word_data, split_idx,
            context_size, predict_size, config.CONTEXT_WORDS
        )
        if training:
            dataset.training = True
        return dataset

    train_dataset = make_dataset(train_idx, config.CONTEXT_SIZE, config.PREDICT_SIZE, training=True)
    val_dataset = make_dataset(val_idx, config.CONTEXT_SIZE, config.PREDICT_SIZE)
    test_dataset = make_dataset(test_idx, config.CONTEXT_SIZE, config.PREDICT_SIZE)

    train_loader = _build_loader(train_dataset, config.BATCH_SIZE, shuffle=True)
    val_loader = _build_loader(val_dataset, config.BATCH_SIZE * 2, shuffle=False)
    test_loader = _build_loader(test_dataset, config.BATCH_SIZE * 2, shuffle=False)

    logger.info(f"Dataset sizes - Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")

    model = ResearchMPGESN(config).to(device)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"Model parameters - Total: {total_params:,}, Trainable: {trainable_params:,}")

    ema = EMA(model, decay=config.EMA_DECAY) if config.USE_EMA else None

    criterion = ResearchLoss(config)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.LEARNING_RATE,
        weight_decay=config.WEIGHT_DECAY,
        betas=(0.9, 0.999)
    )

    # Stepped once per epoch (see below), so T_0 is measured in epochs.
    scheduler = CosineAnnealingWarmRestarts(
        optimizer, T_0=10, T_mult=2, eta_min=config.MIN_LR
    )

    scaler = GradScaler(enabled=config.USE_AMP)

    history = defaultdict(list)
    best_val_loss = float('inf')
    patience_counter = 0
    best_path = os.path.join(config.OUTPUT_DIR, 'models', 'best_model.pt')

    # (context, predict) sizes the current curriculum train/val loaders were built with.
    active_sizes = None

    for epoch in range(config.EPOCHS):
        logger.info(f"\n{'='*60}\nEpoch {epoch+1}/{config.EPOCHS}\n{'='*60}")

        current_context = original_context
        current_predict = original_predict

        if config.USE_CURRICULUM:
            for stage in config.CURRICULUM_STAGES:
                if epoch < stage['epochs']:
                    current_context = stage['context_size']
                    current_predict = stage['predict_size']
                    logger.info(f"Curriculum: context={current_context}, predict={current_predict}")
                    break

            model.current_predict_size = current_predict

            # Only rebuild the datasets/loaders when the curriculum stage actually changes.
            if (current_context, current_predict) != active_sizes:
                train_loader = _build_loader(
                    make_dataset(train_idx, current_context, current_predict, training=True),
                    config.BATCH_SIZE, shuffle=True
                )
                val_loader = _build_loader(
                    make_dataset(val_idx, current_context, current_predict),
                    config.BATCH_SIZE * 2, shuffle=False
                )
                active_sizes = (current_context, current_predict)

        train_losses = train_epoch(
            model, train_loader, optimizer, criterion, scaler, device, config, ema
        )

        val_losses = evaluate(model, val_loader, criterion, device, ema)

        scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']

        logger.info(f"Train - Total: {train_losses['total']:.4f}, Primary: {train_losses['primary']:.4f}")
        logger.info(f"Val - Total: {val_losses['total']:.4f}, MSE: {val_losses['mse']:.4f}, R²: {val_losses['r2']:.4f}")
        logger.info(f"Learning Rate: {current_lr:.6f}")

        for k, v in train_losses.items():
            history[f'train_{k}'].append(v)
        for k, v in val_losses.items():
            history[f'val_{k}'].append(v)
        history['lr'].append(current_lr)

        if (epoch + 1) % config.VIZ_EVERY_N_EPOCHS == 0:
            save_epoch_visualization(model, val_loader, epoch, config, device)

        if val_losses['total'] < best_val_loss:
            best_val_loss = val_losses['total']
            patience_counter = 0

            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'ema_state_dict': ema.shadow if ema else None,
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'scaler_state_dict': scaler.state_dict(),
                'val_losses': val_losses,
                'config': config.__dict__
            }

            torch.save(checkpoint, best_path)
            logger.info(f"✅ Saved best model with val_loss: {best_val_loss:.4f}")
        else:
            patience_counter += 1

        if (epoch + 1) % config.SAVE_EVERY_N_EPOCHS == 0:
            checkpoint_path = os.path.join(
                config.OUTPUT_DIR, 'checkpoints', f'checkpoint_epoch_{epoch+1}.pt'
            )
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'history': dict(history)
            }, checkpoint_path)

        if patience_counter >= config.PATIENCE:
            logger.info(f"Early stopping after {epoch+1} epochs")
            break

        clear_gpu_memory()
        mem_usage = get_memory_usage()
        logger.info(f"Memory - GPU: {mem_usage.get('gpu_allocated', 0):.2f}GB, RAM: {mem_usage['ram_percent']:.1f}%")

    logger.info("\n📊 Final evaluation on test set...")

    # NOTE(review): test_loader uses config.CONTEXT_SIZE/PREDICT_SIZE (the curriculum
    # maxima when USE_CURRICULUM), while model.current_predict_size keeps the last
    # epoch's value. Confirm these agree for your curriculum schedule.
    best_epoch = None
    if os.path.exists(best_path):
        # weights_only=False: the checkpoint stores numpy scalars (val metrics) and a
        # config dict. The restricted unpickler (default since PyTorch 2.6) rejects those.
        # The file is written by this function, not taken from an untrusted source.
        checkpoint = torch.load(best_path, map_location=device, weights_only=False)
        model.load_state_dict(checkpoint['model_state_dict'])
        if ema and checkpoint['ema_state_dict']:
            ema.shadow = checkpoint['ema_state_dict']
        best_epoch = checkpoint['epoch']
    else:
        logger.warning("No best checkpoint was saved; evaluating the final model weights")

    test_losses = evaluate(model, test_loader, criterion, device, ema)

    logger.info(f"\nTest Results:")
    logger.info(f"  Total Loss: {test_losses['total']:.4f}")
    logger.info(f"  MSE: {test_losses['mse']:.4f}")
    logger.info(f"  R²: {test_losses['r2']:.4f}")
    logger.info(f"  Correlation: {test_losses['correlation']:.4f}")

    results = {
        'config': config.__dict__,
        'history': dict(history),
        'test_results': test_losses,
        'best_epoch': best_epoch
    }

    with open(os.path.join(config.OUTPUT_DIR, 'results.json'), 'w') as f:
        json.dump(results, f, indent=2, default=_json_default)

    _plot_training_summary(history, config.OUTPUT_DIR)

    logger.info(f"\n✅ Training completed! Results saved to {config.OUTPUT_DIR}")

    return model, history

#----------------------------------------------------------------------
# Main Execution
#----------------------------------------------------------------------

def main():
    """Notebook entry point: build the configuration, train, then animate the epoch figures."""
    research_config = ResearchConfig()

    # Artifacts are written to disk by training; the returned objects are not needed here.
    _model, _history = train_research_model(research_config)

    create_training_animation(research_config)

def create_training_animation(config):
    """Combine the per-epoch PNGs in ``<OUTPUT_DIR>/epoch_viz`` into a looping GIF.

    Frames are ordered by filename (``epoch_001.png``, ``epoch_002.png``, ...)
    and shown for 500 ms each. The result is written to
    ``<OUTPUT_DIR>/training_progress.gif``. Does nothing when there are no frames.
    """
    import glob

    viz_files = sorted(glob.glob(os.path.join(config.OUTPUT_DIR, 'epoch_viz', '*.png')))
    if not viz_files:
        return

    # Imported here so the no-frames path does not require Pillow.
    from PIL import Image

    frames = []
    for path in viz_files:
        # Load into memory and close the file; the old code kept one handle open per frame.
        with Image.open(path) as img:
            frames.append(img.copy())

    frames[0].save(
        os.path.join(config.OUTPUT_DIR, 'training_progress.gif'),
        save_all=True,
        append_images=frames[1:],
        duration=500,
        loop=0
    )

    logger.info("Created training animation")

# Run the full pipeline when this cell/script is executed directly. Inside a
# Jupyter kernel __name__ is "__main__", so running this cell starts training.
if __name__ == "__main__":
    main()

Using device: cuda
GPU: NVIDIA A100-SXM4-40GB
GPU Memory: 42.47 GB


Training: 100%|██████████| 44/44 [01:49<00:00,  2.50s/it, loss=11.0406, primary=0.2708]
Evaluating: 100%|██████████| 5/5 [00:01<00:00,  4.36it/s]
Training: 100%|██████████| 44/44 [01:50<00:00,  2.50s/it, loss=11.6181, primary=0.2757]
Evaluating: 100%|██████████| 5/5 [00:01<00:00,  4.91it/s]
Training:  45%|████▌     | 20/44 [00:50<01:00,  2.50s/it, loss=9.5026, primary=0.2326]