**V1**

In [2]:
# import os
# import numpy as np
# import pandas as pd
# import matplotlib.pyplot as plt
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# from torch.utils.data import Dataset, DataLoader, TensorDataset
# from sklearn.model_selection import train_test_split
# from sklearn.preprocessing import RobustScaler, StandardScaler
# from sklearn.metrics import mean_squared_error, r2_score
# from scipy import signal
# import warnings
# from tqdm import tqdm
# import time
# import math
# from torch.optim.lr_scheduler import ReduceLROnPlateau
# from scipy.signal import butter, sosfilt, sosfiltfilt, hilbert
# import networkx as nx
# from sklearn.decomposition import PCA
# from sklearn.cluster import KMeans
# from torch.optim.swa_utils import AveragedModel, update_bn
# from collections import defaultdict
# import traceback
# import json
# import seaborn as sns
# from datetime import datetime

# # Suppress warnings
# warnings.filterwarnings('ignore')

# # Check for GPU
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# print(f"Using device: {device}")

# # Set random seed for reproducibility
# SEED = 42
# np.random.seed(SEED)
# torch.manual_seed(SEED)
# if torch.cuda.is_available():
#     torch.cuda.manual_seed_all(SEED)
#     torch.backends.cudnn.deterministic = True
#     torch.backends.cudnn.benchmark = False

# #----------------------------------------------------------------------
# # Enhanced Configuration for Biologically Plausible Single Reservoir
# #----------------------------------------------------------------------

# class Config:
#     def __init__(self):
#         # Data parameters
#         self.NUM_ELECTRODES = 30
#         self.CONTEXT_SIZE = 100
#         self.PREDICT_SIZE = 50
#         self.WINDOW_SIZE = self.CONTEXT_SIZE + self.PREDICT_SIZE
#         self.SAMPLING_RATE = 1000
#         self.MAX_SEQUENCES = 500  # Increased for better training
#         self.BATCH_SIZE = 32
#         self.STRIDE = 25  # More overlap for better learning

#         # Enhanced gamma frequency bands (biologically relevant)
#         self.FREQ_BANDS = [
#             (30, 50),       # Low Gamma
#             (50, 80),       # Mid Gamma
#             (80, 120),      # High Gamma
#             (120, 200),     # Very High Gamma
#         ]

#         # Learning parameters
#         self.LEARNING_RATE = 1e-3  # Higher initial LR
#         self.WEIGHT_DECAY = 1e-5
#         self.EPOCHS = 100  # More epochs for better convergence
#         self.PATIENCE = 20
#         self.CLIP_GRAD_NORM = 1.0
#         self.LR_WARMUP_EPOCHS = 5
#         self.GRADIENT_ACCUMULATION_STEPS = 1

#         # Single Biologically Plausible Reservoir parameters
#         self.RESERVOIR_SIZE = 1000  # Single large reservoir
#         self.SPECTRAL_RADIUS = 0.95  # Edge of stability
#         self.LEAKY_RATE = 0.3  # Biologically plausible leak rate
#         self.CONNECTIVITY = 0.1  # Sparse connectivity (10%)
#         self.INPUT_SCALING = 0.5
#         self.NOISE_LEVEL = 0.01
#         self.ACTIVATION_FUNC = 'tanh'  # Biologically plausible activation

#         # Multi-Perspective Learning Parameters
#         self.ELECTRODE_PERSPECTIVE_DIM = 128
#         self.COMMON_SPACE_DIM = 256
#         self.PERSPECTIVE_ATTENTION_HEADS = 8
#         self.USE_CROSS_ELECTRODE_ATTENTION = True

#         # Network architecture parameters
#         self.HIDDEN_SIZE = 256  # Increased for better capacity
#         self.READOUT_HIDDEN = 512
#         self.READOUT_LAYERS = 3
#         self.DROPOUT = 0.2
#         self.USE_ATTENTION = True
#         self.ATTENTION_HEADS = 8
#         self.USE_SKIP_CONNECTIONS = True
#         self.USE_LAYER_NORM = True

#         # Enhanced Loss Weights (crucial for good reconstruction)
#         self.PRIMARY_LOSS_WEIGHT = 1.0
#         self.RECONSTRUCTION_LOSS_WEIGHT = 0.5  # Increased
#         self.CONSISTENCY_LOSS_WEIGHT = 0.2
#         self.DIVERSITY_LOSS_WEIGHT = 0.1
#         self.FREQUENCY_LOSS_WEIGHT = 0.3       # Added frequency domain loss
#         self.SMOOTHNESS_LOSS_WEIGHT = 0.1      # Added for temporal smoothness

#         # Training strategies
#         self.USE_SWA = True
#         self.SWA_START = 50
#         self.USE_CURRICULUM = True

#         # Visualization parameters
#         self.VISUALIZE_EVERY_N_EPOCHS = 5
#         self.NUM_SAMPLES_TO_VISUALIZE = 5

#         # Directory setup
#         self.OUTPUT_DIR = "enhanced_single_reservoir_results"
#         self.create_directories()

#     def create_directories(self):
#         """Create necessary directories for outputs."""
#         os.makedirs(self.OUTPUT_DIR, exist_ok=True)
#         for subdir in ['models', 'analysis', 'logs', 'reconstructions', 'electrode_perspectives']:
#             os.makedirs(os.path.join(self.OUTPUT_DIR, subdir), exist_ok=True)

# # Initialize configuration
# config = Config()

# #----------------------------------------------------------------------
# # Data Loading Functions (Keep Original)
# #----------------------------------------------------------------------

# def find_data_file(path, filename):
#     """Find the specified data file in the given path or its subdirectories."""
#     data_file = os.path.join(path, filename)
#     if os.path.exists(data_file):
#         return data_file

#     for root, dirs, files in os.walk(path):
#         if filename in files:
#             return os.path.join(root, filename)

#     raise FileNotFoundError(f"Could not find {filename} in {path}")

# def load_electrode_positions(path):
#     """Load electrode position data from CSV file."""
#     print("Loading electrode positions...")
#     positions_file = find_data_file(path, 'limbic_insular_probe_channels.csv')

#     positions = pd.read_csv(positions_file)
#     positions = positions.rename(columns={
#         'channel_index': 'channel_num',
#         'probe_horizontal_position': 'x_position',
#         'probe_vertical_position': 'y_position'
#     })

#     print(f"Loaded positions for {len(positions)} electrodes")
#     return positions

# def select_electrodes_fixed(positions, num_electrodes=None):
#     """Select electrodes using KMeans clustering for optimal spatial coverage."""
#     if num_electrodes is None:
#         num_electrodes = config.NUM_ELECTRODES

#     num_electrodes = min(num_electrodes, len(positions))
#     print(f"Selecting {num_electrodes} electrodes for analysis...")

#     if num_electrodes == len(positions):
#         selected_indices = list(range(len(positions)))
#         channel_nums = positions['channel_num'].astype(int).values
#         return channel_nums, selected_indices

#     # Normalize positions for clustering
#     x_norm = (positions['x_position'] - positions['x_position'].min()) / (positions['x_position'].max() - positions['x_position'].min())
#     y_norm = (positions['y_position'] - positions['y_position'].min()) / (positions['y_position'].max() - positions['y_position'].min())
#     coords = np.column_stack((x_norm.values, y_norm.values))

#     # Use KMeans to select representative electrodes
#     kmeans = KMeans(n_clusters=num_electrodes, random_state=SEED, n_init=10)
#     kmeans.fit(coords)

#     # Select the electrode closest to each cluster center
#     selected_indices = []
#     for cluster_idx in range(num_electrodes):
#         cluster_points = np.where(kmeans.labels_ == cluster_idx)[0]

#         if len(cluster_points) > 0:
#             center = kmeans.cluster_centers_[cluster_idx]
#             distances = np.sqrt(np.sum((coords[cluster_points] - center)**2, axis=1))
#             closest_idx = cluster_points[np.argmin(distances)]
#         else:
#             center = kmeans.cluster_centers_[cluster_idx]
#             all_distances = np.sqrt(np.sum((coords - center)**2, axis=1))
#             closest_idx = np.argmin(all_distances)

#         selected_indices.append(closest_idx)

#     channel_nums = positions.iloc[selected_indices]['channel_num'].astype(int).values
#     print(f"Selected {len(channel_nums)} electrodes with good spatial coverage")
#     return channel_nums, selected_indices

# def load_lfp_data(path, channel_nums, max_rows=500000):
#     """Load LFP data from CSV file."""
#     print(f"Loading LFP data for {len(channel_nums)} channels...")
#     data_file = find_data_file(path, 'limbic_insular_ieeg_data (7).csv')

#     cols_to_use = ['timestamp', 'presentation_id'] + [f'channel_{ch}' for ch in channel_nums]
#     print(f"Loading columns: {cols_to_use[:5]}... (and {len(cols_to_use)-5} more)")

#     data = pd.read_csv(data_file, usecols=cols_to_use, nrows=max_rows)
#     print(f"Loaded {len(data)} rows of LFP data")

#     num_presentations = data['presentation_id'].nunique()
#     print(f"Data contains {num_presentations} unique presentation_ids")

#     return data

# def extract_windows(data, context_size=None, predict_size=None, stride=None, max_windows=None):
#     """Extract time windows from continuous data for training and prediction."""
#     if context_size is None:
#         context_size = config.CONTEXT_SIZE
#     if predict_size is None:
#         predict_size = config.PREDICT_SIZE
#     if stride is None:
#         stride = config.STRIDE
#     if max_windows is None:
#         max_windows = config.MAX_SEQUENCES

#     print(f"Extracting windows with context={context_size}, predict={predict_size}, stride={stride}")

#     use_presentation_id = 'presentation_id' in data.columns
#     if use_presentation_id:
#         print(f"Using presentation_id to ensure windows come from same presentation")

#     data_cols = [col for col in data.columns if col.startswith('channel_')]
#     window_size = context_size + predict_size
#     windows = []

#     if use_presentation_id:
#         presentation_groups = data.groupby('presentation_id')

#         for presentation_id, group in presentation_groups:
#             if len(group) < window_size:
#                 continue

#             presentation_data = group[data_cols].values

#             for start in range(0, len(presentation_data) - window_size + 1, stride):
#                 if len(windows) >= max_windows:
#                     break

#                 window = presentation_data[start:start + window_size]
#                 if np.isnan(window).any():
#                     continue

#                 windows.append(window)

#             if len(windows) >= max_windows:
#                 break
#     else:
#         data_values = data[data_cols].values

#         for start in range(0, len(data_values) - window_size, stride):
#             if len(windows) >= max_windows:
#                 break

#             window = data_values[start:start + window_size]
#             if np.isnan(window).any():
#                 continue

#             windows.append(window)

#     windows = np.array(windows)
#     print(f"Extracted {len(windows)} windows with shape: {windows.shape}")
#     return windows

# #----------------------------------------------------------------------
# # Enhanced Preprocessing with Better Scaling
# #----------------------------------------------------------------------

# def create_filter_bank(fs, freq_bands):
#     """Create filter bank for frequency decomposition."""
#     filter_bank = []
#     nyquist = fs / 2.0

#     for band_idx, (low_freq, high_freq) in enumerate(freq_bands):
#         low = low_freq / nyquist
#         high = high_freq / nyquist
#         sos = signal.butter(4, [low, high], btype='bandpass', output='sos')
#         filter_bank.append((sos, band_idx, (low_freq, high_freq)))

#     return filter_bank

# def extract_frequency_features(windows, fs=None, freq_bands=None):
#     """Extract frequency band features."""
#     if fs is None:
#         fs = config.SAMPLING_RATE
#     if freq_bands is None:
#         freq_bands = config.FREQ_BANDS

#     print("Extracting frequency band features...")
#     n_samples, window_size, n_channels = windows.shape
#     n_bands = len(freq_bands)

#     filter_bank = create_filter_bank(fs, freq_bands)
#     band_powers = np.zeros((n_samples, n_bands, n_channels))

#     for (sos, band_idx, (low_freq, high_freq)) in filter_bank:
#         print(f"Processing band {band_idx+1}/{n_bands}: {low_freq}-{high_freq} Hz")

#         for i in tqdm(range(n_samples), desc=f"Band {band_idx+1}"):
#             for j in range(n_channels):
#                 signal_i = windows[i, :, j]
#                 filtered = signal.sosfiltfilt(sos, signal_i)
#                 # Compute band power
#                 band_powers[i, band_idx, j] = np.mean(filtered**2)

#     return band_powers

# def preprocess_data(windows, context_size=None, predict_size=None):
#     """Enhanced preprocessing with better normalization."""
#     if context_size is None:
#         context_size = config.CONTEXT_SIZE
#     if predict_size is None:
#         predict_size = config.PREDICT_SIZE

#     print("Preprocessing data...")
#     num_windows, window_size, num_electrodes = windows.shape

#     context_windows = windows[:, :context_size, :]
#     target_windows = windows[:, context_size:, :]

#     # Better scaling strategy
#     scaler = StandardScaler()  # Changed from RobustScaler

#     # Scale to microvolts first
#     context_flat = context_windows.reshape(-1, num_electrodes) * 1e6
#     target_flat = target_windows.reshape(-1, num_electrodes) * 1e6

#     # Fit scaler on all data
#     all_data = np.vstack([context_flat, target_flat])
#     scaler.fit(all_data)

#     # Transform data
#     context_scaled = scaler.transform(context_flat).reshape(context_windows.shape)
#     target_scaled = scaler.transform(target_flat).reshape(target_windows.shape)

#     # Extract frequency features for context
#     band_powers = extract_frequency_features(context_windows * 1e6, config.SAMPLING_RATE, config.FREQ_BANDS)

#     # Create dataset
#     X_data = {
#         'raw': context_scaled,
#         'band_powers': band_powers
#     }

#     y_data = target_scaled

#     # Split data
#     indices = np.arange(len(context_scaled))
#     train_idx, temp_idx = train_test_split(indices, test_size=0.3, random_state=SEED)
#     val_idx, test_idx = train_test_split(temp_idx, test_size=0.5, random_state=SEED)

#     # Create datasets
#     train_dataset = LFPDataset(
#         {k: v[train_idx] for k, v in X_data.items()},
#         y_data[train_idx]
#     )

#     val_dataset = LFPDataset(
#         {k: v[val_idx] for k, v in X_data.items()},
#         y_data[val_idx]
#     )

#     test_dataset = LFPDataset(
#         {k: v[test_idx] for k, v in X_data.items()},
#         y_data[test_idx]
#     )

#     # Create data loaders
#     train_loader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, shuffle=True, drop_last=True)
#     val_loader = DataLoader(val_dataset, batch_size=config.BATCH_SIZE, shuffle=False)
#     test_loader = DataLoader(test_dataset, batch_size=config.BATCH_SIZE, shuffle=False)

#     print(f"Created data loaders with {len(train_loader)} training batches")

#     return train_loader, val_loader, test_loader, scaler

# class LFPDataset(Dataset):
#     """Dataset for LFP data."""
#     def __init__(self, X, y):
#         self.X_raw = torch.FloatTensor(X['raw'])
#         self.X_band_powers = torch.FloatTensor(X['band_powers'])
#         self.y = torch.FloatTensor(y)

#     def __len__(self):
#         return len(self.y)

#     def __getitem__(self, idx):
#         return {
#             'raw': self.X_raw[idx],
#             'band_powers': self.X_band_powers[idx]
#         }, self.y[idx]

# #----------------------------------------------------------------------
# # Biologically Plausible Single Reservoir
# #----------------------------------------------------------------------

# class BiologicallyPlausibleReservoir(nn.Module):
#     """Single biologically plausible reservoir with Dale's principle."""

#     def __init__(self, config):
#         super().__init__()
#         self.config = config
#         self.reservoir_size = config.RESERVOIR_SIZE
#         self.input_size = config.NUM_ELECTRODES
#         self.leaky_rate = config.LEAKY_RATE

#         # Initialize reservoir weights with biological constraints
#         self._initialize_reservoir()

#         # Input projection
#         self.input_projection = nn.Linear(self.input_size, self.reservoir_size, bias=False)
#         nn.init.uniform_(self.input_projection.weight, -config.INPUT_SCALING, config.INPUT_SCALING)

#         # Dale's principle: separate excitatory and inhibitory neurons
#         self.excitatory_mask = torch.rand(self.reservoir_size) < 0.8  # 80% excitatory
#         self.inhibitory_mask = ~self.excitatory_mask

#         # Activation function
#         if config.ACTIVATION_FUNC == 'tanh':
#             self.activation = nn.Tanh()
#         elif config.ACTIVATION_FUNC == 'relu':
#             self.activation = nn.ReLU()
#         else:
#             self.activation = nn.Tanh()

#         # Layer normalization for stability
#         self.layer_norm = nn.LayerNorm(self.reservoir_size)

#         # State buffer
#         self.register_buffer('state', None)

#     def _initialize_reservoir(self):
#         """Initialize reservoir with biological constraints."""
#         # Create sparse connectivity matrix
#         W = torch.zeros(self.reservoir_size, self.reservoir_size)
#         num_connections = int(self.config.CONNECTIVITY * self.reservoir_size * self.reservoir_size)

#         # Random connections
#         indices = torch.randperm(self.reservoir_size * self.reservoir_size)[:num_connections]
#         i_indices = indices // self.reservoir_size
#         j_indices = indices % self.reservoir_size

#         # Remove self-connections
#         mask = i_indices != j_indices
#         i_indices = i_indices[mask]
#         j_indices = j_indices[mask]

#         # Initialize weights
#         W[i_indices, j_indices] = torch.randn(len(i_indices))

#         # Apply spectral radius normalization
#         eigenvalues = torch.linalg.eigvals(W)
#         spectral_radius = torch.max(torch.abs(eigenvalues)).item()

#         if spectral_radius > 0:
#             W = W * (self.config.SPECTRAL_RADIUS / spectral_radius)

#         self.register_buffer('W', W)

#     def reset_state(self, batch_size=1):
#         """Reset reservoir state."""
#         self.state = torch.zeros(batch_size, self.reservoir_size, device=self.W.device)

#     def forward(self, x):
#         """Forward pass through reservoir."""
#         batch_size = x.shape[0]

#         if self.state is None or self.state.shape[0] != batch_size:
#             self.reset_state(batch_size)

#         # Input contribution
#         input_contribution = self.input_projection(x)

#         # Recurrent contribution
#         recurrent_contribution = torch.matmul(self.state, self.W.T)

#         # Apply Dale's principle (ensure positive/negative neurons)
#         if self.excitatory_mask.device != recurrent_contribution.device:
#             self.excitatory_mask = self.excitatory_mask.to(recurrent_contribution.device)
#             self.inhibitory_mask = self.inhibitory_mask.to(recurrent_contribution.device)

#         recurrent_contribution = torch.where(
#             self.excitatory_mask.unsqueeze(0),
#             torch.abs(recurrent_contribution),
#             -torch.abs(recurrent_contribution)
#         )

#         # Combine inputs
#         pre_activation = input_contribution + recurrent_contribution

#         # Apply activation
#         new_state = self.activation(pre_activation)

#         # Leaky integration
#         self.state = (1 - self.leaky_rate) * self.state + self.leaky_rate * new_state

#         # Add small noise for regularization
#         if self.training and self.config.NOISE_LEVEL > 0:
#             noise = torch.randn_like(self.state) * self.config.NOISE_LEVEL
#             self.state = self.state + noise

#         # Layer normalization
#         output = self.layer_norm(self.state)

#         return output

# #----------------------------------------------------------------------
# # Enhanced Model Components
# #----------------------------------------------------------------------

# class ElectrodeAttentionModule(nn.Module):
#     """Electrode-specific attention for extracting perspectives."""

#     def __init__(self, reservoir_dim, electrode_dim, num_heads=4):
#         super().__init__()
#         self.attention = nn.MultiheadAttention(
#             embed_dim=reservoir_dim,
#             num_heads=num_heads,
#             dropout=0.1,
#             batch_first=True
#         )

#         self.projection = nn.Sequential(
#             nn.Linear(reservoir_dim, electrode_dim),
#             nn.LayerNorm(electrode_dim),
#             nn.GELU(),
#             nn.Dropout(0.1)
#         )

#     def forward(self, reservoir_states):
#         """Extract electrode-specific perspective."""
#         # Self-attention on reservoir states
#         attended, weights = self.attention(
#             reservoir_states, reservoir_states, reservoir_states
#         )

#         # Project to electrode-specific space
#         perspective = self.projection(attended)

#         return perspective, weights

# class TemporalProcessor(nn.Module):
#     """Process temporal sequences with BiLSTM."""

#     def __init__(self, input_dim, hidden_dim, num_layers=2):
#         super().__init__()
#         self.lstm = nn.LSTM(
#             input_size=input_dim,
#             hidden_size=hidden_dim,
#             num_layers=num_layers,
#             batch_first=True,
#             dropout=0.2,
#             bidirectional=True
#         )

#         self.projection = nn.Linear(hidden_dim * 2, hidden_dim)

#     def forward(self, x):
#         """Process temporal sequence."""
#         output, (h_n, c_n) = self.lstm(x)

#         # Combine bidirectional outputs
#         output = self.projection(output)

#         # Get final hidden state
#         h_forward = h_n[-2, :, :]
#         h_backward = h_n[-1, :, :]
#         final_hidden = torch.cat([h_forward, h_backward], dim=1)
#         final_hidden = self.projection(final_hidden)

#         return output, final_hidden

# class SignalDecoder(nn.Module):
#     """Decode to reconstruct signals with skip connections."""

#     def __init__(self, input_dim, output_dim, hidden_dims=[512, 256, 128]):
#         super().__init__()

#         layers = []
#         prev_dim = input_dim

#         for hidden_dim in hidden_dims:
#             layers.extend([
#                 nn.Linear(prev_dim, hidden_dim),
#                 nn.LayerNorm(hidden_dim),
#                 nn.GELU(),
#                 nn.Dropout(0.1)
#             ])
#             prev_dim = hidden_dim

#         layers.append(nn.Linear(prev_dim, output_dim))

#         self.decoder = nn.Sequential(*layers)

#         # Skip connection if dimensions match
#         self.use_skip = (input_dim == output_dim)

#     def forward(self, x):
#         """Decode with skip connection."""
#         output = self.decoder(x)

#         if self.use_skip and x.shape == output.shape:
#             output = output + 0.1 * x  # Residual connection

#         return output

# #----------------------------------------------------------------------
# # Main Model with Single Reservoir
# #----------------------------------------------------------------------

# class EnhancedSingleReservoirESN(nn.Module):
#     """Enhanced ESN with single reservoir and multi-perspective learning."""

#     def __init__(self, config):
#         super().__init__()
#         self.config = config
#         self.num_electrodes = config.NUM_ELECTRODES
#         self.num_freq_bands = len(config.FREQ_BANDS)
#         self.predict_size = config.PREDICT_SIZE

#         # Feature extraction
#         self.feature_extractor = nn.Sequential(
#             nn.Linear(self.num_electrodes + self.num_freq_bands * self.num_electrodes,
#                      config.HIDDEN_SIZE),
#             nn.LayerNorm(config.HIDDEN_SIZE),
#             nn.GELU(),
#             nn.Dropout(config.DROPOUT)
#         )

#         # Single biologically plausible reservoir
#         self.reservoir = BiologicallyPlausibleReservoir(config)

#         # Electrode-specific attention modules
#         self.electrode_attentions = nn.ModuleList([
#             ElectrodeAttentionModule(
#                 reservoir_dim=config.RESERVOIR_SIZE,
#                 electrode_dim=config.ELECTRODE_PERSPECTIVE_DIM,
#                 num_heads=config.PERSPECTIVE_ATTENTION_HEADS
#             ) for _ in range(self.num_electrodes)
#         ])

#         # Temporal processor
#         self.temporal_processor = TemporalProcessor(
#             input_dim=config.RESERVOIR_SIZE,
#             hidden_dim=config.READOUT_HIDDEN,
#             num_layers=2
#         )

#         # Prediction decoders for each timestep
#         self.decoders = nn.ModuleList([
#             SignalDecoder(
#                 input_dim=config.READOUT_HIDDEN,
#                 output_dim=self.num_electrodes,
#                 hidden_dims=[config.READOUT_HIDDEN, config.READOUT_HIDDEN // 2]
#             ) for _ in range(config.PREDICT_SIZE)
#         ])

#         # Reconstruction decoder for auxiliary task
#         self.reconstruction_decoder = SignalDecoder(
#             input_dim=config.ELECTRODE_PERSPECTIVE_DIM * self.num_electrodes,
#             output_dim=self.num_electrodes * config.CONTEXT_SIZE,
#             hidden_dims=[512, 256]
#         )

#     def reset_states(self):
#         """Reset all states."""
#         self.reservoir.reset_state()

#     def forward(self, x_dict, return_perspectives=False):
#         """Forward pass with optional perspective extraction."""
#         x_raw = x_dict['raw']
#         x_band_powers = x_dict['band_powers']

#         batch_size, context_size, num_electrodes = x_raw.shape
#         device = x_raw.device

#         self.reset_states()

#         # Storage for analysis
#         reservoir_states = []
#         electrode_perspectives = []

#         # Process sequence through reservoir
#         for t in range(context_size):
#             # Extract features at time t
#             raw_t = x_raw[:, t, :]
#             band_powers_t = x_band_powers[:, :, :].reshape(batch_size, -1)  # Flatten band powers

#             # Combine features
#             features = torch.cat([raw_t, band_powers_t], dim=1)
#             features = self.feature_extractor(features)

#             # Process through reservoir
#             reservoir_state = self.reservoir(raw_t)  # Use raw signal for reservoir input
#             reservoir_states.append(reservoir_state)

#         # Stack reservoir states
#         reservoir_sequence = torch.stack(reservoir_states, dim=1)  # [batch, time, reservoir_size]

#         # Extract electrode perspectives if requested
#         if return_perspectives:
#             for e_idx in range(self.num_electrodes):
#                 perspective, _ = self.electrode_attentions[e_idx](reservoir_sequence)
#                 # Average over time for stable perspective
#                 perspective_avg = perspective.mean(dim=1)
#                 electrode_perspectives.append(perspective_avg)

#         # Temporal processing
#         temporal_output, final_hidden = self.temporal_processor(reservoir_sequence)

#         # Generate predictions for each future timestep
#         predictions = []
#         current_hidden = final_hidden

#         for t in range(self.predict_size):
#             # Decode current hidden state
#             pred_t = self.decoders[t](current_hidden)
#             predictions.append(pred_t)

#             # Update hidden state (autoregressive)
#             if t < self.predict_size - 1:
#                 # Simple update rule
#                 current_hidden = current_hidden * 0.9 + 0.1 * self.decoders[t].decoder[0](current_hidden)

#         # Stack predictions
#         predictions = torch.stack(predictions, dim=1)  # [batch, predict_size, num_electrodes]

#         if return_perspectives:
#             # Reconstruction task
#             all_perspectives = torch.cat(electrode_perspectives, dim=1)
#             reconstruction = self.reconstruction_decoder(all_perspectives)
#             reconstruction = reconstruction.view(batch_size, context_size, num_electrodes)

#             return predictions, {
#                 'perspectives': electrode_perspectives,
#                 'reconstruction': reconstruction,
#                 'reservoir_states': reservoir_sequence
#             }
#         else:
#             return predictions

# #----------------------------------------------------------------------
# # Enhanced Loss Functions
# #----------------------------------------------------------------------

# class ComprehensiveLoss(nn.Module):
#     """Comprehensive loss function for signal reconstruction."""

#     def __init__(self, config):
#         super().__init__()
#         self.config = config
#         self.mse = nn.MSELoss()
#         self.smooth_l1 = nn.SmoothL1Loss()

#     def frequency_loss(self, pred, target, sampling_rate=1000):
#         """Compute loss in frequency domain."""
#         # Compute FFT
#         pred_fft = torch.fft.rfft(pred, dim=1)
#         target_fft = torch.fft.rfft(target, dim=1)

#         # Magnitude spectrum
#         pred_mag = torch.abs(pred_fft)
#         target_mag = torch.abs(target_fft)

#         # Focus on gamma frequencies
#         freq_bins = pred_mag.shape[1]
#         gamma_start = int(30 * freq_bins / (sampling_rate / 2))
#         gamma_end = int(200 * freq_bins / (sampling_rate / 2))

#         # Weighted frequency loss
#         freq_loss = self.mse(pred_mag[:, gamma_start:gamma_end],
#                             target_mag[:, gamma_start:gamma_end])

#         return freq_loss

#     def smoothness_loss(self, pred):
#         """Temporal smoothness loss."""
#         # First-order differences
#         diff1 = pred[:, 1:, :] - pred[:, :-1, :]
#         smooth_loss = torch.mean(diff1**2)

#         return smooth_loss

#     def forward(self, predictions, targets, aux_outputs=None):
#         """Compute comprehensive loss."""
#         # Primary prediction loss
#         primary_loss = self.mse(predictions, targets)

#         # Frequency domain loss
#         freq_loss = 0
#         for e in range(predictions.shape[2]):  # For each electrode
#             freq_loss += self.frequency_loss(predictions[:, :, e], targets[:, :, e])
#         freq_loss /= predictions.shape[2]

#         # Smoothness loss
#         smooth_loss = self.smoothness_loss(predictions)

#         # Total loss
#         total_loss = (
#             self.config.PRIMARY_LOSS_WEIGHT * primary_loss +
#             self.config.FREQUENCY_LOSS_WEIGHT * freq_loss +
#             self.config.SMOOTHNESS_LOSS_WEIGHT * smooth_loss
#         )

#         # Add reconstruction loss if available
#         if aux_outputs is not None and 'reconstruction' in aux_outputs:
#             reconstruction = aux_outputs['reconstruction']
#             # Get corresponding input for reconstruction loss
#             batch_size = predictions.shape[0]
#             context_size = reconstruction.shape[1]

#             # Note: We need the input data for reconstruction loss
#             # This should be passed from the training loop
#             if 'input_data' in aux_outputs:
#                 input_data = aux_outputs['input_data']
#                 recon_loss = self.mse(reconstruction, input_data)
#                 total_loss += self.config.RECONSTRUCTION_LOSS_WEIGHT * recon_loss
#             else:
#                 recon_loss = torch.tensor(0.0, device=predictions.device)
#         else:
#             recon_loss = torch.tensor(0.0, device=predictions.device)

#         return {
#             'total_loss': total_loss,
#             'primary_loss': primary_loss,
#             'freq_loss': freq_loss,
#             'smooth_loss': smooth_loss,
#             'recon_loss': recon_loss
#         }

# #----------------------------------------------------------------------
# # Visualization Functions
# #----------------------------------------------------------------------

# def visualize_reconstructions(model, data_loader, scaler, epoch, save_dir, num_samples=5):
#     """Visualize signal reconstructions."""
#     model.eval()

#     with torch.no_grad():
#         # Get a batch of data
#         inputs, targets = next(iter(data_loader))
#         inputs = {k: v.to(device) for k, v in inputs.items()}
#         targets = targets.to(device)

#         # Get predictions
#         predictions, aux_outputs = model(inputs, return_perspectives=True)

#         # Convert to numpy
#         inputs_np = inputs['raw'].cpu().numpy()
#         targets_np = targets.cpu().numpy()
#         predictions_np = predictions.cpu().numpy()
#         reconstruction_np = aux_outputs['reconstruction'].cpu().numpy()

#         # Inverse transform
#         batch_size, context_size, num_electrodes = inputs_np.shape
#         _, predict_size, _ = predictions_np.shape

#         # Reshape for inverse transform
#         inputs_reshaped = inputs_np.reshape(-1, num_electrodes)
#         targets_reshaped = targets_np.reshape(-1, num_electrodes)
#         predictions_reshaped = predictions_np.reshape(-1, num_electrodes)
#         reconstruction_reshaped = reconstruction_np.reshape(-1, num_electrodes)

#         # Inverse transform
#         inputs_orig = scaler.inverse_transform(inputs_reshaped).reshape(batch_size, context_size, num_electrodes)
#         targets_orig = scaler.inverse_transform(targets_reshaped).reshape(batch_size, predict_size, num_electrodes)
#         predictions_orig = scaler.inverse_transform(predictions_reshaped).reshape(batch_size, predict_size, num_electrodes)
#         reconstruction_orig = scaler.inverse_transform(reconstruction_reshaped).reshape(batch_size, context_size, num_electrodes)

#         # Create figure
#         fig, axes = plt.subplots(num_samples, 3, figsize=(20, 4*num_samples))
#         if num_samples == 1:
#             axes = axes.reshape(1, -1)

#         for sample_idx in range(min(num_samples, batch_size)):
#             # Select electrodes to visualize
#             electrode_indices = [0, num_electrodes//2, num_electrodes-1]

#             for e_idx, electrode in enumerate(electrode_indices):
#                 ax = axes[sample_idx, e_idx]

#                 # Context (input)
#                 context_signal = inputs_orig[sample_idx, :, electrode]

#                 # Target (ground truth future)
#                 target_signal = targets_orig[sample_idx, :, electrode]

#                 # Prediction
#                 pred_signal = predictions_orig[sample_idx, :, electrode]

#                 # Reconstruction
#                 recon_signal = reconstruction_orig[sample_idx, :, electrode]

#                 # Time axes
#                 context_time = np.arange(context_size)
#                 future_time = np.arange(context_size, context_size + predict_size)

#                 # Plot
#                 ax.plot(context_time, context_signal, 'b-', label='Context', alpha=0.7)
#                 ax.plot(context_time, recon_signal, 'g--', label='Reconstruction', alpha=0.7)
#                 ax.plot(future_time, target_signal, 'k-', label='Target', linewidth=2)
#                 ax.plot(future_time, pred_signal, 'r--', label='Prediction', linewidth=2)

#                 # Compute metrics
#                 mse = np.mean((pred_signal - target_signal)**2)
#                 corr = np.corrcoef(pred_signal, target_signal)[0, 1]

#                 ax.set_title(f'Sample {sample_idx+1}, Electrode {electrode} | MSE: {mse:.4f}, Corr: {corr:.3f}')
#                 ax.set_xlabel('Time (ms)')
#                 ax.set_ylabel('Amplitude (μV)')
#                 ax.legend()
#                 ax.grid(True, alpha=0.3)

#         plt.suptitle(f'Signal Reconstructions - Epoch {epoch}', fontsize=16)
#         plt.tight_layout()

#         # Save figure
#         save_path = os.path.join(save_dir, f'reconstructions_epoch_{epoch}.png')
#         plt.savefig(save_path, dpi=150, bbox_inches='tight')
#         plt.close()

#         print(f"Saved reconstruction visualization to {save_path}")

# def plot_training_history(history, save_dir):
#     """Plot training history."""
#     fig, axes = plt.subplots(2, 2, figsize=(12, 10))

#     # Loss curves
#     axes[0, 0].plot(history['train_loss'], label='Train')
#     axes[0, 0].plot(history['val_loss'], label='Validation')
#     axes[0, 0].set_xlabel('Epoch')
#     axes[0, 0].set_ylabel('Total Loss')
#     axes[0, 0].set_title('Training Progress')
#     axes[0, 0].legend()
#     axes[0, 0].grid(True)

#     # Individual losses
#     loss_types = ['primary_loss', 'freq_loss', 'smooth_loss', 'recon_loss']
#     for loss_type in loss_types:
#         if f'train_{loss_type}' in history:
#             axes[0, 1].plot(history[f'train_{loss_type}'], label=loss_type)
#     axes[0, 1].set_xlabel('Epoch')
#     axes[0, 1].set_ylabel('Loss')
#     axes[0, 1].set_title('Loss Components')
#     axes[0, 1].legend()
#     axes[0, 1].grid(True)

#     # Metrics
#     if 'val_mse' in history:
#         axes[1, 0].plot(history['val_mse'], label='MSE')
#         axes[1, 0].set_xlabel('Epoch')
#         axes[1, 0].set_ylabel('MSE')
#         axes[1, 0].set_title('Validation MSE')
#         axes[1, 0].grid(True)

#     if 'val_corr' in history:
#         axes[1, 1].plot(history['val_corr'], label='Correlation')
#         axes[1, 1].set_xlabel('Epoch')
#         axes[1, 1].set_ylabel('Correlation')
#         axes[1, 1].set_title('Validation Correlation')
#         axes[1, 1].grid(True)

#     plt.tight_layout()
#     plt.savefig(os.path.join(save_dir, 'training_history.png'), dpi=150)
#     plt.close()

# #----------------------------------------------------------------------
# # Training Functions
# #----------------------------------------------------------------------

# def train_epoch(model, train_loader, optimizer, criterion, device, epoch):
#     """Train for one epoch."""
#     model.train()

#     losses = defaultdict(float)

#     progress_bar = tqdm(train_loader, desc=f'Epoch {epoch}')
#     for batch_data in progress_bar:
#         inputs, targets = batch_data
#         inputs = {k: v.to(device) for k, v in inputs.items()}
#         targets = targets.to(device)

#         # Forward pass
#         predictions, aux_outputs = model(inputs, return_perspectives=True)

#         # Add input data for reconstruction loss
#         aux_outputs['input_data'] = inputs['raw']

#         # Compute loss
#         loss_dict = criterion(predictions, targets, aux_outputs)
#         total_loss = loss_dict['total_loss']

#         # Backward pass
#         optimizer.zero_grad()
#         total_loss.backward()
#         torch.nn.utils.clip_grad_norm_(model.parameters(), config.CLIP_GRAD_NORM)
#         optimizer.step()

#         # Track losses
#         for key, value in loss_dict.items():
#             losses[key] += value.item()

#         # Update progress bar
#         progress_bar.set_postfix({
#             'loss': f"{total_loss.item():.4f}",
#             'primary': f"{loss_dict['primary_loss'].item():.4f}"
#         })

#     # Average losses
#     for key in losses:
#         losses[key] /= len(train_loader)

#     return losses

# def evaluate(model, val_loader, criterion, device):
#     """Evaluate model."""
#     model.eval()

#     losses = defaultdict(float)
#     all_predictions = []
#     all_targets = []

#     with torch.no_grad():
#         for batch_data in val_loader:
#             inputs, targets = batch_data
#             inputs = {k: v.to(device) for k, v in inputs.items()}
#             targets = targets.to(device)

#             # Forward pass
#             predictions, aux_outputs = model(inputs, return_perspectives=True)

#             # Add input data for reconstruction loss
#             aux_outputs['input_data'] = inputs['raw']

#             # Compute loss
#             loss_dict = criterion(predictions, targets, aux_outputs)

#             # Track losses
#             for key, value in loss_dict.items():
#                 losses[key] += value.item()

#             # Store predictions
#             all_predictions.append(predictions.cpu().numpy())
#             all_targets.append(targets.cpu().numpy())

#     # Average losses
#     for key in losses:
#         losses[key] /= len(val_loader)

#     # Compute metrics
#     all_predictions = np.concatenate(all_predictions, axis=0)
#     all_targets = np.concatenate(all_targets, axis=0)

#     # MSE and correlation
#     mse = np.mean((all_predictions - all_targets)**2)

#     # Average correlation across electrodes
#     correlations = []
#     for e in range(all_predictions.shape[2]):
#         pred_e = all_predictions[:, :, e].flatten()
#         target_e = all_targets[:, :, e].flatten()
#         if np.std(pred_e) > 0 and np.std(target_e) > 0:
#             corr = np.corrcoef(pred_e, target_e)[0, 1]
#             correlations.append(corr)

#     avg_corr = np.mean(correlations) if correlations else 0

#     losses['mse'] = mse
#     losses['corr'] = avg_corr

#     return losses

# def train_model(model, train_loader, val_loader, config, scaler, test_loader=None):
#     """Complete training pipeline."""
#     criterion = ComprehensiveLoss(config)
#     optimizer = torch.optim.AdamW(model.parameters(), lr=config.LEARNING_RATE,
#                                  weight_decay=config.WEIGHT_DECAY)
#     scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10, min_lr=1e-6)

#     history = defaultdict(list)
#     best_val_loss = float('inf')
#     patience_counter = 0

#     print(f"Starting training for {config.EPOCHS} epochs...")

#     for epoch in range(config.EPOCHS):
#         print(f"\n{'='*50}")
#         print(f"Epoch {epoch+1}/{config.EPOCHS}")
#         print(f"{'='*50}")

#         # Train
#         train_losses = train_epoch(model, train_loader, optimizer, criterion, device, epoch+1)

#         # Evaluate
#         val_losses = evaluate(model, val_loader, criterion, device)

#         # Update scheduler
#         scheduler.step(val_losses['total_loss'])

#         # Store history
#         for key, value in train_losses.items():
#             history[f'train_{key}'].append(value)
#         for key, value in val_losses.items():
#             history[f'val_{key}'].append(value)

#         # Print results
#         print(f"Train Loss: {train_losses['total_loss']:.4f} | Val Loss: {val_losses['total_loss']:.4f}")
#         print(f"Val MSE: {val_losses['mse']:.4f} | Val Corr: {val_losses['corr']:.3f}")

#         # Visualize reconstructions
#         if (epoch + 1) % config.VISUALIZE_EVERY_N_EPOCHS == 0:
#             visualize_reconstructions(
#                 model, val_loader, scaler, epoch+1,
#                 os.path.join(config.OUTPUT_DIR, 'reconstructions'),
#                 num_samples=config.NUM_SAMPLES_TO_VISUALIZE
#             )

#         # Save best model
#         if val_losses['total_loss'] < best_val_loss:
#             best_val_loss = val_losses['total_loss']
#             patience_counter = 0
#             torch.save({
#                 'epoch': epoch,
#                 'model_state_dict': model.state_dict(),
#                 'optimizer_state_dict': optimizer.state_dict(),
#                 'val_loss': best_val_loss,
#                 'config': config
#             }, os.path.join(config.OUTPUT_DIR, 'models', 'best_model.pt'))
#             print(f"Saved best model with val loss: {best_val_loss:.4f}")
#         else:
#             patience_counter += 1

#         # Early stopping
#         if patience_counter >= config.PATIENCE:
#             print(f"Early stopping after {epoch+1} epochs")
#             break

#     # Final evaluation on test set
#     if test_loader:
#         print("\nFinal evaluation on test set...")
#         test_losses = evaluate(model, test_loader, criterion, device)
#         print(f"Test Loss: {test_losses['total_loss']:.4f}")
#         print(f"Test MSE: {test_losses['mse']:.4f} | Test Corr: {test_losses['corr']:.3f}")

#         # Save final test visualization
#         visualize_reconstructions(
#             model, test_loader, scaler, 'test',
#             os.path.join(config.OUTPUT_DIR, 'reconstructions'),
#             num_samples=10
#         )

#     # Plot training history
#     plot_training_history(dict(history), config.OUTPUT_DIR)

#     return model, history

# #----------------------------------------------------------------------
# # Main Pipeline
# #----------------------------------------------------------------------

# def main(data_path):
#     """Main training pipeline."""
#     print("Starting Enhanced Single Reservoir Pipeline")
#     print(f"Output directory: {config.OUTPUT_DIR}")

#     try:
#         # Load electrode positions
#         positions = load_electrode_positions(data_path)

#         # Select electrodes
#         channel_nums, selected_indices = select_electrodes_fixed(positions, config.NUM_ELECTRODES)

#         # Load LFP data
#         lfp_data = load_lfp_data(data_path, channel_nums)

#         # Extract windows
#         windows = extract_windows(lfp_data)

#         if windows.shape[0] == 0:
#             print("No valid windows extracted!")
#             return None

#         # Preprocess data
#         train_loader, val_loader, test_loader, scaler = preprocess_data(windows)

#         # Create model
#         model = EnhancedSingleReservoirESN(config).to(device)

#         # Count parameters
#         total_params = sum(p.numel() for p in model.parameters())
#         trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
#         print(f"Total parameters: {total_params:,}")
#         print(f"Trainable parameters: {trainable_params:,}")

#         # Train model
#         model, history = train_model(model, train_loader, val_loader, config, scaler, test_loader)

#         # Save final model
#         torch.save(model.state_dict(),
#                   os.path.join(config.OUTPUT_DIR, 'models', 'final_model.pt'))

#         print("\nTraining completed successfully!")

#         return model, scaler

#     except Exception as e:
#         print(f"Error in main pipeline: {e}")
#         traceback.print_exc()
#         return None

# if __name__ == "__main__":
#     # Data path resolution
#     data_path = None

#     if os.path.exists("/kaggle/input/ecog-landmark-mkn"):
#         data_path = "/kaggle/input/ecog-landmark-mkn"
#     elif os.path.exists("/content/ecog-landmark-mkn"):
#         data_path = "/content/ecog-landmark-mkn"
#     else:
#         # Try to download or use local path
#         try:
#             import kagglehub
#             data_path = kagglehub.dataset_download("arunramponnambalam/ecog-landmark-mkn")
#         except:
#             data_path = input("Enter path to dataset: ")

#     if data_path and os.path.exists(data_path):
#         print(f"Using dataset at: {data_path}")
#         model, scaler = main(data_path)
#     else:
#         print(f"Dataset path not found: {data_path}")

### **V2**

In [6]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.metrics import mean_squared_error, r2_score
from scipy import signal
import warnings
from tqdm import tqdm
import time
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingWarmRestarts
import traceback
import json
from collections import defaultdict
from datetime import datetime

warnings.filterwarnings('ignore')

# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Seed setup
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

#----------------------------------------------------------------------
# Enhanced Configuration V2
#----------------------------------------------------------------------

class ConfigV2:
    def __init__(self):
        # Data parameters
        self.NUM_ELECTRODES = 30
        self.CONTEXT_SIZE = 100
        self.PREDICT_SIZE = 50
        self.WINDOW_SIZE = self.CONTEXT_SIZE + self.PREDICT_SIZE
        self.SAMPLING_RATE = 1000
        self.MAX_SEQUENCES = 800  # More data
        self.BATCH_SIZE = 64  # Larger batch
        self.STRIDE = 20  # More overlap

        # Frequency bands - more detailed
        self.FREQ_BANDS = [
            (4, 8),         # Theta
            (8, 13),        # Alpha
            (13, 30),       # Beta
            (30, 50),       # Low Gamma
            (50, 80),       # Mid Gamma
            (80, 120),      # High Gamma
            (120, 200),     # Very High Gamma
        ]

        # Learning parameters
        self.LEARNING_RATE = 5e-4  # Lower initial LR
        self.WEIGHT_DECAY = 1e-5
        self.EPOCHS = 200
        self.PATIENCE = 30
        self.CLIP_GRAD_NORM = 0.5  # More aggressive clipping

        # Enhanced Reservoir parameters
        self.RESERVOIR_SIZE = 2000  # Larger reservoir
        self.SPECTRAL_RADIUS = 0.99  # Closer to edge of chaos
        self.LEAKY_RATE = 0.15  # Faster dynamics
        self.CONNECTIVITY = 0.05  # Sparser (5%)
        self.INPUT_SCALING = 0.8  # Stronger input
        self.NOISE_LEVEL = 0.005  # Less noise

        # Multi-timescale processing
        self.NUM_TIMESCALES = 3
        self.TIMESCALE_FACTORS = [1.0, 0.5, 0.25]  # Fast, medium, slow

        # Architecture
        self.HIDDEN_SIZE = 512
        self.READOUT_HIDDEN = 1024
        self.DROPOUT = 0.15
        self.NUM_HEADS = 16  # More attention heads

        # Multi-perspective
        self.ELECTRODE_PERSPECTIVE_DIM = 256
        self.COMMON_SPACE_DIM = 512

        # Loss weights - adjusted
        self.TIME_LOSS_WEIGHT = 1.0
        self.FREQ_LOSS_WEIGHT = 0.5
        self.PHASE_LOSS_WEIGHT = 0.3
        self.RECONSTRUCTION_LOSS_WEIGHT = 0.8  # Higher weight
        self.SMOOTHNESS_PENALTY = 0.05  # Lower penalty
        self.DIVERSITY_LOSS_WEIGHT = 0.1

        # Output directory
        self.OUTPUT_DIR = "enhanced_esn_v2_results"
        self.create_directories()

    def create_directories(self):
        os.makedirs(self.OUTPUT_DIR, exist_ok=True)
        for subdir in ['models', 'reconstructions', 'analysis']:
            os.makedirs(os.path.join(self.OUTPUT_DIR, subdir), exist_ok=True)

config = ConfigV2()

#----------------------------------------------------------------------
# Data Loading (same as before)
#----------------------------------------------------------------------

def find_data_file(path, filename):
    """Find file in directory tree."""
    data_file = os.path.join(path, filename)
    if os.path.exists(data_file):
        return data_file

    for root, dirs, files in os.walk(path):
        if filename in files:
            return os.path.join(root, filename)

    raise FileNotFoundError(f"Could not find {filename} in {path}")

def load_electrode_positions(path):
    """Load electrode positions."""
    positions_file = find_data_file(path, 'limbic_insular_probe_channels.csv')
    positions = pd.read_csv(positions_file)
    positions = positions.rename(columns={
        'channel_index': 'channel_num',
        'probe_horizontal_position': 'x_position',
        'probe_vertical_position': 'y_position'
    })
    return positions

def select_electrodes_fixed(positions, num_electrodes=None):
    """Select electrodes with spatial coverage."""
    if num_electrodes is None:
        num_electrodes = config.NUM_ELECTRODES

    num_electrodes = min(num_electrodes, len(positions))

    if num_electrodes == len(positions):
        return positions['channel_num'].astype(int).values, list(range(len(positions)))

    # K-means selection
    from sklearn.cluster import KMeans
    coords = positions[['x_position', 'y_position']].values
    coords_norm = (coords - coords.mean(axis=0)) / coords.std(axis=0)

    kmeans = KMeans(n_clusters=num_electrodes, random_state=SEED)
    kmeans.fit(coords_norm)

    selected_indices = []
    for i in range(num_electrodes):
        cluster_mask = kmeans.labels_ == i
        cluster_indices = np.where(cluster_mask)[0]
        if len(cluster_indices) > 0:
            center = kmeans.cluster_centers_[i]
            distances = np.linalg.norm(coords_norm[cluster_indices] - center, axis=1)
            selected_idx = cluster_indices[np.argmin(distances)]
            selected_indices.append(selected_idx)

    channel_nums = positions.iloc[selected_indices]['channel_num'].astype(int).values
    return channel_nums, selected_indices

def load_lfp_data(path, channel_nums, max_rows=500000):
    """Load LFP data."""
    data_file = find_data_file(path, 'limbic_insular_ieeg_data (7).csv')
    cols_to_use = ['timestamp', 'presentation_id'] + [f'channel_{ch}' for ch in channel_nums]
    data = pd.read_csv(data_file, usecols=cols_to_use, nrows=max_rows)
    return data

def extract_windows(data, context_size=None, predict_size=None, stride=None, max_windows=None):
    """Extract overlapping windows."""
    if context_size is None:
        context_size = config.CONTEXT_SIZE
    if predict_size is None:
        predict_size = config.PREDICT_SIZE
    if stride is None:
        stride = config.STRIDE
    if max_windows is None:
        max_windows = config.MAX_SEQUENCES

    data_cols = [col for col in data.columns if col.startswith('channel_')]
    window_size = context_size + predict_size
    windows = []

    # Group by presentation_id if available
    if 'presentation_id' in data.columns:
        for pid, group in data.groupby('presentation_id'):
            if len(group) < window_size:
                continue

            group_data = group[data_cols].values
            for start in range(0, len(group_data) - window_size + 1, stride):
                if len(windows) >= max_windows:
                    break
                window = group_data[start:start + window_size]
                if not np.isnan(window).any():
                    windows.append(window)

            if len(windows) >= max_windows:
                break
    else:
        data_values = data[data_cols].values
        for start in range(0, len(data_values) - window_size + 1, stride):
            if len(windows) >= max_windows:
                break
            window = data_values[start:start + window_size]
            if not np.isnan(window).any():
                windows.append(window)

    return np.array(windows)

#----------------------------------------------------------------------
# Enhanced Feature Extraction
#----------------------------------------------------------------------

def extract_multiscale_features(windows, fs=1000):
    """Extract multi-scale features including phase and envelope."""
    n_samples, window_size, n_channels = windows.shape
    n_bands = len(config.FREQ_BANDS)

    # Initialize feature arrays
    band_powers = np.zeros((n_samples, n_bands, n_channels))
    band_phases = np.zeros((n_samples, n_bands, window_size, n_channels))
    band_envelopes = np.zeros((n_samples, n_bands, window_size, n_channels))

    print("Extracting multi-scale features...")

    for band_idx, (low_freq, high_freq) in enumerate(config.FREQ_BANDS):
        print(f"Processing band {band_idx+1}/{n_bands}: {low_freq}-{high_freq} Hz")

        # Design filter
        nyquist = fs / 2
        low = low_freq / nyquist
        high = high_freq / nyquist
        sos = signal.butter(4, [low, high], btype='bandpass', output='sos')

        for i in tqdm(range(n_samples), desc=f"Band {band_idx+1}"):
            for j in range(n_channels):
                # Filter signal
                filtered = signal.sosfiltfilt(sos, windows[i, :, j])

                # Hilbert transform for envelope and phase
                analytic = signal.hilbert(filtered)
                envelope = np.abs(analytic)
                phase = np.angle(analytic)

                # Store features
                band_powers[i, band_idx, j] = np.mean(envelope**2)
                band_envelopes[i, band_idx, :, j] = envelope
                band_phases[i, band_idx, :, j] = phase

    return {
        'band_powers': band_powers,
        'band_envelopes': band_envelopes,
        'band_phases': band_phases
    }

def preprocess_data_v2(windows):
    """Enhanced preprocessing with multi-scale features."""
    num_windows, window_size, num_electrodes = windows.shape
    context_size = config.CONTEXT_SIZE
    predict_size = config.PREDICT_SIZE

    # Split context and target
    context_windows = windows[:, :context_size, :]
    target_windows = windows[:, context_size:, :]

    # Scale to microvolts
    context_uv = context_windows * 1e6
    target_uv = target_windows * 1e6

    # Use RobustScaler for outlier handling
    from sklearn.preprocessing import RobustScaler
    scaler = RobustScaler(quantile_range=(5, 95))

    # Fit on all data
    all_data_flat = windows.reshape(-1, num_electrodes) * 1e6
    scaler.fit(all_data_flat)

    # Transform
    context_scaled = scaler.transform(context_uv.reshape(-1, num_electrodes)).reshape(context_windows.shape)
    target_scaled = scaler.transform(target_uv.reshape(-1, num_electrodes)).reshape(target_windows.shape)

    # Extract features
    features = extract_multiscale_features(context_uv)

    # Create dataset
    X_data = {
        'raw': context_scaled,
        'band_powers': features['band_powers'],
        'band_envelopes': features['band_envelopes'][:, :, :context_size, :],
        'band_phases': features['band_phases'][:, :, :context_size, :]
    }

    # Train/val/test split
    indices = np.arange(len(context_scaled))
    train_idx, temp_idx = train_test_split(indices, test_size=0.3, random_state=SEED)
    val_idx, test_idx = train_test_split(temp_idx, test_size=0.5, random_state=SEED)

    # Create datasets
    train_dataset = EnhancedLFPDataset(
        {k: v[train_idx] for k, v in X_data.items()},
        target_scaled[train_idx]
    )

    val_dataset = EnhancedLFPDataset(
        {k: v[val_idx] for k, v in X_data.items()},
        target_scaled[val_idx]
    )

    test_dataset = EnhancedLFPDataset(
        {k: v[test_idx] for k, v in X_data.items()},
        target_scaled[test_idx]
    )

    # Create loaders
    train_loader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, shuffle=True,
                            num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=config.BATCH_SIZE, shuffle=False,
                          num_workers=2, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=config.BATCH_SIZE, shuffle=False,
                           num_workers=2, pin_memory=True)

    return train_loader, val_loader, test_loader, scaler

class EnhancedLFPDataset(Dataset):
    """Dataset with multi-scale features."""
    def __init__(self, X, y):
        self.X_raw = torch.FloatTensor(X['raw'])
        self.X_band_powers = torch.FloatTensor(X['band_powers'])
        self.X_band_envelopes = torch.FloatTensor(X['band_envelopes'])
        self.X_band_phases = torch.FloatTensor(X['band_phases'])
        self.y = torch.FloatTensor(y)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return {
            'raw': self.X_raw[idx],
            'band_powers': self.X_band_powers[idx],
            'band_envelopes': self.X_band_envelopes[idx],
            'band_phases': self.X_band_phases[idx]
        }, self.y[idx]

#----------------------------------------------------------------------
# Multi-Timescale Reservoir
#----------------------------------------------------------------------

class SubReservoir(nn.Module):
    """
    A single timescale reservoir submodule.
    Contains a fixed weight matrix W and a leak rate buffer.
    """
    def __init__(self, W: torch.Tensor, leak_rate: float):
        super().__init__()
        self.W = nn.Parameter(W, requires_grad=False)
        self.register_buffer('leak_rate', torch.tensor(leak_rate))

    def forward(self, prev_state: torch.Tensor, input_contrib: torch.Tensor, noise_level: float, training: bool):
        rec = prev_state @ self.W.T
        activation = torch.tanh(input_contrib + rec)
        new_state = (1 - self.leak_rate) * prev_state + self.leak_rate * activation
        if training and noise_level > 0:
            new_state = new_state + torch.randn_like(new_state) * noise_level
        return new_state

class MultiTimescaleReservoir(nn.Module):
    """
    Reservoir with multiple timescales. Ensures effective reservoir size matches submodules.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_timescales = config.NUM_TIMESCALES
        # Base sub-reservoir size
        self.sub_size = config.RESERVOIR_SIZE // self.num_timescales
        # Effective total size = sub_size * num_timescales (handles non-divisible cases)
        self.eff_res_size = self.sub_size * self.num_timescales

        self.sub_reservoirs = nn.ModuleList()
        self.input_projections = nn.ModuleList()
        for timescale in config.TIMESCALE_FACTORS:
            W = self._init_W(self.sub_size, config.CONNECTIVITY, config.SPECTRAL_RADIUS)
            leak_rate = config.LEAKY_RATE * timescale
            self.sub_reservoirs.append(SubReservoir(W, leak_rate))

            proj = nn.Linear(config.NUM_ELECTRODES, self.sub_size, bias=False)
            nn.init.uniform_(proj.weight, -config.INPUT_SCALING, config.INPUT_SCALING)
            self.input_projections.append(proj)

        # Cross-timescale connections now use eff_res_size
        self.cross_connections = nn.Linear(self.eff_res_size, self.eff_res_size, bias=False)
        nn.init.sparse_(self.cross_connections.weight, sparsity=0.9)

        # Output projection also uses eff_res_size
        self.output_projection = nn.Sequential(
            nn.Linear(self.eff_res_size, config.HIDDEN_SIZE),
            nn.LayerNorm(config.HIDDEN_SIZE),
            nn.GELU(),
            nn.Dropout(config.DROPOUT)
        )
        self.states = None

    def _init_W(self, size, connectivity, spectral_radius):
        W = torch.zeros(size, size)
        num_conn = int(connectivity * size * size)
        idx = torch.randperm(size*size)[:num_conn]
        i_idx = idx // size
        j_idx = idx % size
        mask = i_idx != j_idx
        i_idx, j_idx = i_idx[mask], j_idx[mask]
        W[i_idx, j_idx] = torch.randn(len(i_idx)) * 0.1
        eigs = torch.linalg.eigvals(W)
        curr_rad = torch.max(torch.abs(eigs)).item()
        if curr_rad > 0:
            W *= (spectral_radius / curr_rad)
        return W

    def reset_state(self, batch_size=1):
        self.states = [
            torch.zeros(batch_size, self.sub_size, device=self.cross_connections.weight.device)
            for _ in range(self.num_timescales)
        ]

    def forward(self, x):
        batch = x.size(0)
        if self.states is None or self.states[0].size(0) != batch:
            self.reset_state(batch)

        new_states = []
        for i, sub in enumerate(self.sub_reservoirs):
            inp = self.input_projections[i](x)
            state = sub(
                prev_state=self.states[i],
                input_contrib=inp,
                noise_level=self.config.NOISE_LEVEL,
                training=self.training
            )
            new_states.append(state)
        self.states = new_states

        combined = torch.cat(self.states, dim=1)
        # combined now shape [batch, eff_res_size]
        cc = self.cross_connections(combined)
        combined = combined + 0.1 * torch.tanh(cc)

        out = self.output_projection(combined)
        return out, combined


#----------------------------------------------------------------------
# Attention Mechanisms
#----------------------------------------------------------------------

class MultiHeadSelfAttention(nn.Module):
    """Efficient multi-head self-attention."""

    def __init__(self, d_model, n_heads=8, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape

        # Linear transformations and split into heads
        Q = self.W_q(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)

        # Attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(self.d_k)

        if mask is not None:
            scores.masked_fill_(mask == 0, -1e9)

        # Attention weights
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Apply attention
        context = torch.matmul(attn_weights, V)

        # Concatenate heads
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)

        # Output projection
        output = self.W_o(context)

        # Residual connection and layer norm
        output = self.layer_norm(x + self.dropout(output))

        return output, attn_weights

class TemporalConvEncoder(nn.Module):
    """Temporal convolutional encoder for capturing local patterns."""

    def __init__(self, input_dim, hidden_dim, kernel_sizes=[3, 5, 7]):
        super().__init__()

        self.convs = nn.ModuleList([
            nn.Conv1d(input_dim, hidden_dim // len(kernel_sizes),
                     kernel_size=k, padding=k//2)
            for k in kernel_sizes
        ])

        self.norm = nn.LayerNorm(hidden_dim)
        self.activation = nn.GELU()

    def forward(self, x):
        # x: [batch, time, features]
        x = x.transpose(1, 2)  # [batch, features, time]

        # Apply multiple convolutions
        conv_outputs = []
        for conv in self.convs:
            conv_out = self.activation(conv(x))
            conv_outputs.append(conv_out)

        # Concatenate
        output = torch.cat(conv_outputs, dim=1)  # [batch, hidden_dim, time]
        output = output.transpose(1, 2)  # [batch, time, hidden_dim]

        return self.norm(output)

#----------------------------------------------------------------------
# Enhanced Main Model
#----------------------------------------------------------------------

class EnhancedESNv2(nn.Module):
    """Enhanced ESN with multi-timescale reservoir and advanced architecture."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_electrodes = config.NUM_ELECTRODES
        self.num_bands = len(config.FREQ_BANDS)

        # Multi-scale feature extraction
        self.band_encoder = nn.Sequential(
            nn.Linear(self.num_bands * self.num_electrodes, config.HIDDEN_SIZE),
            nn.LayerNorm(config.HIDDEN_SIZE),
            nn.GELU(),
            nn.Dropout(config.DROPOUT)
        )

        # Temporal convolutional encoder
        self.temporal_encoder = TemporalConvEncoder(
            input_dim=self.num_electrodes,
            hidden_dim=config.HIDDEN_SIZE,
            kernel_sizes=[3, 5, 7, 9]
        )

        # Multi-timescale reservoir
        self.reservoir = MultiTimescaleReservoir(config)

        # Attention mechanism
        self.self_attention = MultiHeadSelfAttention(
            d_model=config.HIDDEN_SIZE,
            n_heads=config.NUM_HEADS,
            dropout=config.DROPOUT
        )

        # Electrode-specific encoders
        self.electrode_encoders = nn.ModuleList([
            nn.Sequential(
                nn.Linear(config.HIDDEN_SIZE, config.ELECTRODE_PERSPECTIVE_DIM),
                nn.LayerNorm(config.ELECTRODE_PERSPECTIVE_DIM),
                nn.GELU(),
                nn.Dropout(config.DROPOUT)
            ) for _ in range(self.num_electrodes)
        ])

        # Cross-electrode attention
        self.cross_electrode_attention = nn.MultiheadAttention(
            embed_dim=config.ELECTRODE_PERSPECTIVE_DIM,
            num_heads=8,
            dropout=config.DROPOUT,
            batch_first=True
        )

        # Temporal processor (GRU instead of LSTM for faster training)
        self.temporal_processor = nn.GRU(
            input_size = config.HIDDEN_SIZE + self.reservoir.eff_res_size,
            hidden_size=config.READOUT_HIDDEN,
            num_layers=3,
            batch_first=True,
            dropout=config.DROPOUT,
            bidirectional=True
        )

        # Prediction heads with skip connections
        self.prediction_heads = nn.ModuleList()
        for t in range(config.PREDICT_SIZE):
            head = nn.Sequential(
                nn.Linear(config.READOUT_HIDDEN * 2, config.READOUT_HIDDEN),
                nn.LayerNorm(config.READOUT_HIDDEN),
                nn.GELU(),
                nn.Dropout(config.DROPOUT),
                nn.Linear(config.READOUT_HIDDEN, self.num_electrodes)
            )
            self.prediction_heads.append(head)

        # Direct pathway for residual predictions
        self.direct_predictor = nn.Linear(self.num_electrodes, self.num_electrodes)

        # Reconstruction decoder
        self.reconstruction_decoder = nn.Sequential(
            nn.Linear(config.ELECTRODE_PERSPECTIVE_DIM * self.num_electrodes,
                     config.HIDDEN_SIZE * 2),
            nn.LayerNorm(config.HIDDEN_SIZE * 2),
            nn.GELU(),
            nn.Dropout(config.DROPOUT),
            nn.Linear(config.HIDDEN_SIZE * 2, config.HIDDEN_SIZE),
            nn.LayerNorm(config.HIDDEN_SIZE),
            nn.GELU(),
            nn.Linear(config.HIDDEN_SIZE, self.num_electrodes * config.CONTEXT_SIZE)
        )

    def forward(self, x_dict, return_perspectives=False):
        """Forward pass with multi-scale processing."""
        x_raw = x_dict['raw']
        x_band_powers = x_dict['band_powers']
        x_band_envelopes = x_dict['band_envelopes']

        batch_size, context_size, num_electrodes = x_raw.shape
        device = x_raw.device

        # Reset reservoir
        self.reservoir.reset_state(batch_size)

        # Storage
        hidden_states = []
        reservoir_states = []
        electrode_perspectives = []

        # Process sequence
        for t in range(context_size):
            # Get current timestep
            x_t = x_raw[:, t, :]

            # Band power features
            band_features = x_band_powers.reshape(batch_size, -1)
            band_encoded = self.band_encoder(band_features)

            # Temporal convolution features (use small window around t)
            window_start = max(0, t - 2)
            window_end = min(context_size, t + 3)
            x_window = x_raw[:, window_start:window_end, :]

            # Pad if necessary
            if x_window.shape[1] < 5:
                pad_size = 5 - x_window.shape[1]
                x_window = F.pad(x_window, (0, 0, 0, pad_size), 'constant', 0)

            temporal_features = self.temporal_encoder(x_window)
            temporal_features = temporal_features[:, 2, :]  # Center of window

            # Reservoir processing
            reservoir_out, reservoir_state = self.reservoir(x_t)

            # Combine features
            combined = reservoir_out + 0.5 * band_encoded + 0.5 * temporal_features

            hidden_states.append(combined)
            reservoir_states.append(reservoir_state)

        # Stack sequences
        hidden_sequence = torch.stack(hidden_states, dim=1)
        reservoir_sequence = torch.stack(reservoir_states, dim=1)

        # Self-attention over time
        attended_sequence, _ = self.self_attention(hidden_sequence)

        # Extract electrode perspectives
        if return_perspectives:
            for e_idx in range(self.num_electrodes):
                # Electrode-specific encoding
                electrode_features = self.electrode_encoders[e_idx](attended_sequence)

                # Average over time
                electrode_perspective = electrode_features.mean(dim=1)
                electrode_perspectives.append(electrode_perspective)

            # Stack and apply cross-electrode attention
            electrode_perspectives_stacked = torch.stack(electrode_perspectives, dim=1)
            attended_perspectives, _ = self.cross_electrode_attention(
                electrode_perspectives_stacked,
                electrode_perspectives_stacked,
                electrode_perspectives_stacked
            )

        # Combine hidden and reservoir states
        combined_sequence = torch.cat([attended_sequence, reservoir_sequence], dim=2)

        # Temporal processing
        gru_out, _ = self.temporal_processor(combined_sequence)

        # Get last hidden state
        final_hidden = gru_out[:, -1, :]

        # Generate predictions
        predictions = []
        hidden = final_hidden

        # Use last context as base for residual prediction
        last_context = x_raw[:, -1, :]

        for t in range(config.PREDICT_SIZE):
            # Main prediction
            pred_main = self.prediction_heads[t](hidden)

            # Residual prediction
            pred_residual = self.direct_predictor(last_context)

            # Combine predictions
            pred_t = pred_main + 0.1 * pred_residual
            predictions.append(pred_t)

            # Update for next step
            if t < config.PREDICT_SIZE - 1:
                # Simple autoregressive update
                hidden = hidden * 0.95  # Decay
                last_context = pred_t.detach()  # Use prediction as context

        # Stack predictions
        predictions = torch.stack(predictions, dim=1)

        if return_perspectives:
            # Reconstruction
            all_perspectives = torch.cat(electrode_perspectives, dim=1)
            reconstruction = self.reconstruction_decoder(all_perspectives)
            reconstruction = reconstruction.view(batch_size, context_size, num_electrodes)

            return predictions, {
                'reconstruction': reconstruction,
                'electrode_perspectives': electrode_perspectives,
                'attended_sequence': attended_sequence
            }
        else:
            return predictions

#----------------------------------------------------------------------
# Enhanced Loss Functions
#----------------------------------------------------------------------

class EnhancedLoss(nn.Module):
    """Comprehensive loss with multiple components."""

    def __init__(self, config):
        super().__init__()
        self.config = config

    def spectral_loss(self, pred, target, fs=1000):
        """Loss in frequency domain with focus on gamma."""
        # Compute power spectral density
        pred_fft = torch.fft.rfft(pred, dim=1)
        target_fft = torch.fft.rfft(target, dim=1)

        pred_psd = torch.abs(pred_fft)**2
        target_psd = torch.abs(target_fft)**2

        # Frequency bins
        n_fft = pred_psd.shape[1]
        freqs = torch.fft.rfftfreq(pred.shape[1], 1/fs).to(pred.device)

        # Gamma band mask (30-200 Hz)
        gamma_mask = (freqs >= 30) & (freqs <= 200)

        # Weighted spectral loss
        spectral_loss = F.mse_loss(pred_psd, target_psd)
        gamma_loss = F.mse_loss(pred_psd[:, gamma_mask], target_psd[:, gamma_mask])

        return 0.5 * spectral_loss + 0.5 * gamma_loss

    def phase_coherence_loss(self, pred, target):
        """Phase coherence loss for maintaining oscillatory structure."""
        # Hilbert transform for phase
        pred_complex = torch.view_as_complex(
            torch.stack([pred, torch.zeros_like(pred)], dim=-1)
        )
        target_complex = torch.view_as_complex(
            torch.stack([target, torch.zeros_like(target)], dim=-1)
        )

        # Phase difference
        phase_diff = torch.angle(pred_complex) - torch.angle(target_complex)

        # Circular mean of phase differences
        coherence = torch.abs(torch.mean(torch.exp(1j * phase_diff)))

        return 1.0 - coherence.mean()

    def smoothness_regularization(self, pred):
        """Penalize excessive high-frequency noise."""
        # Second-order differences
        diff2 = pred[:, 2:, :] - 2*pred[:, 1:-1, :] + pred[:, :-2, :]
        return torch.mean(diff2**2)

    def forward(self, predictions, targets, aux_outputs=None):
        """Compute all loss components."""
        # Time domain loss (MSE + MAE for robustness)
        time_loss = 0.7 * F.mse_loss(predictions, targets) + \
                   0.3 * F.l1_loss(predictions, targets)

        # Frequency domain loss
        freq_loss = 0
        for e in range(predictions.shape[2]):
            freq_loss += self.spectral_loss(predictions[:, :, e], targets[:, :, e])
        freq_loss /= predictions.shape[2]

        # Phase coherence
        phase_loss = self.phase_coherence_loss(predictions, targets)

        # Smoothness
        smooth_loss = self.smoothness_regularization(predictions)

        # Total prediction loss
        total_loss = (
            self.config.TIME_LOSS_WEIGHT * time_loss +
            self.config.FREQ_LOSS_WEIGHT * freq_loss +
            self.config.PHASE_LOSS_WEIGHT * phase_loss +
            self.config.SMOOTHNESS_PENALTY * smooth_loss
        )

        # Reconstruction loss if available
        recon_loss = torch.tensor(0.0, device=predictions.device)
        if aux_outputs is not None and 'reconstruction' in aux_outputs and 'input_data' in aux_outputs:
            reconstruction = aux_outputs['reconstruction']
            input_data = aux_outputs['input_data']
            recon_loss = F.mse_loss(reconstruction, input_data)
            total_loss += self.config.RECONSTRUCTION_LOSS_WEIGHT * recon_loss

        return {
            'total_loss': total_loss,
            'time_loss': time_loss,
            'freq_loss': freq_loss,
            'phase_loss': phase_loss,
            'smooth_loss': smooth_loss,
            'recon_loss': recon_loss
        }

#----------------------------------------------------------------------
# Training with Improved Visualization
#----------------------------------------------------------------------

def visualize_predictions_v2(model, data_loader, scaler, epoch, save_dir, num_samples=5):
    """Enhanced visualization with spectral analysis."""
    model.eval()

    with torch.no_grad():
        inputs, targets = next(iter(data_loader))
        inputs = {k: v.to(device) for k, v in inputs.items()}
        targets = targets.to(device)

        # Get predictions
        predictions, aux_outputs = model(inputs, return_perspectives=True)

        # Convert to numpy and inverse transform
        inputs_np = inputs['raw'].cpu().numpy()
        targets_np = targets.cpu().numpy()
        predictions_np = predictions.cpu().numpy()
        reconstruction_np = aux_outputs['reconstruction'].cpu().numpy()

        # Inverse transform
        batch_size, context_size, num_electrodes = inputs_np.shape
        _, predict_size, _ = predictions_np.shape

        # Reshape and inverse transform
        inputs_orig = scaler.inverse_transform(
            inputs_np.reshape(-1, num_electrodes)
        ).reshape(batch_size, context_size, num_electrodes)

        targets_orig = scaler.inverse_transform(
            targets_np.reshape(-1, num_electrodes)
        ).reshape(batch_size, predict_size, num_electrodes)

        predictions_orig = scaler.inverse_transform(
            predictions_np.reshape(-1, num_electrodes)
        ).reshape(batch_size, predict_size, num_electrodes)

        reconstruction_orig = scaler.inverse_transform(
            reconstruction_np.reshape(-1, num_electrodes)
        ).reshape(batch_size, context_size, num_electrodes)

        # Create figure with spectral analysis
        fig = plt.figure(figsize=(24, 5*num_samples))
        gs = fig.add_gridspec(num_samples, 4, width_ratios=[3, 1, 3, 1])

        for sample_idx in range(min(num_samples, batch_size)):
            # Select diverse electrodes
            electrode_indices = [0, num_electrodes//3, 2*num_electrodes//3, num_electrodes-1]

            for e_idx, electrode in enumerate(electrode_indices[:2]):  # Show 2 electrodes per row
                # Time domain plot
                ax_time = fig.add_subplot(gs[sample_idx, e_idx*2])

                # Signals
                context_signal = inputs_orig[sample_idx, :, electrode]
                target_signal = targets_orig[sample_idx, :, electrode]
                pred_signal = predictions_orig[sample_idx, :, electrode]
                recon_signal = reconstruction_orig[sample_idx, :, electrode]

                # Time axes
                context_time = np.arange(context_size)
                future_time = np.arange(context_size, context_size + predict_size)

                # Plot
                ax_time.plot(context_time, context_signal, 'b-', label='Context', alpha=0.7, linewidth=1.5)
                ax_time.plot(context_time, recon_signal, 'g--', label='Reconstruction', alpha=0.7, linewidth=1.5)
                ax_time.plot(future_time, target_signal, 'k-', label='Target', linewidth=2)
                ax_time.plot(future_time, pred_signal, 'r--', label='Prediction', linewidth=2)

                # Metrics
                mse = np.mean((pred_signal - target_signal)**2)
                corr = np.corrcoef(pred_signal, target_signal)[0, 1]

                ax_time.set_title(f'Sample {sample_idx+1}, Electrode {electrode}\nMSE: {mse:.4f}, Corr: {corr:.3f}')
                ax_time.set_xlabel('Time (ms)')
                ax_time.set_ylabel('Amplitude (μV)')
                ax_time.legend()
                ax_time.grid(True, alpha=0.3)

                # Spectral analysis
                ax_spec = fig.add_subplot(gs[sample_idx, e_idx*2 + 1])

                # Compute spectra
                from scipy.signal import welch
                f_target, psd_target = welch(target_signal, fs=1000, nperseg=min(64, len(target_signal)))
                f_pred, psd_pred = welch(pred_signal, fs=1000, nperseg=min(64, len(pred_signal)))

                # Plot spectra
                ax_spec.semilogy(f_target, psd_target, 'k-', label='Target', linewidth=2)
                ax_spec.semilogy(f_pred, psd_pred, 'r--', label='Prediction', linewidth=2)

                # Highlight gamma band
                ax_spec.axvspan(30, 200, alpha=0.2, color='yellow', label='Gamma')

                ax_spec.set_xlabel('Frequency (Hz)')
                ax_spec.set_ylabel('PSD')
                ax_spec.set_xlim([0, 250])
                ax_spec.legend()
                ax_spec.grid(True, alpha=0.3)

        plt.suptitle(f'Enhanced Signal Reconstruction - Epoch {epoch}', fontsize=16)
        plt.tight_layout()

        # Save
        save_path = os.path.join(save_dir, f'reconstruction_v2_epoch_{epoch}.png')
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        plt.close()

        print(f"Saved enhanced visualization to {save_path}")

def train_epoch_v2(model, train_loader, optimizer, criterion, device, epoch):
    """Training with gradient accumulation."""
    model.train()
    losses = defaultdict(float)

    progress_bar = tqdm(train_loader, desc=f'Epoch {epoch}')

    for i, (inputs, targets) in enumerate(progress_bar):
        inputs = {k: v.to(device) for k, v in inputs.items()}
        targets = targets.to(device)

        # Forward pass
        predictions, aux_outputs = model(inputs, return_perspectives=True)
        aux_outputs['input_data'] = inputs['raw']

        # Compute loss
        loss_dict = criterion(predictions, targets, aux_outputs)
        loss = loss_dict['total_loss']

        # Backward pass
        optimizer.zero_grad()
        loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), config.CLIP_GRAD_NORM)

        optimizer.step()

        # Track losses
        for key, value in loss_dict.items():
            losses[key] += value.item()

        # Update progress bar
        progress_bar.set_postfix({
            'loss': f"{loss.item():.4f}",
            'time': f"{loss_dict['time_loss'].item():.4f}",
            'freq': f"{loss_dict['freq_loss'].item():.4f}"
        })

    # Average losses
    for key in losses:
        losses[key] /= len(train_loader)

    return losses

def evaluate_v2(model, val_loader, criterion, device):
    """Evaluation with comprehensive metrics."""
    model.eval()
    losses = defaultdict(float)
    predictions_all = []
    targets_all = []

    with torch.no_grad():
        for inputs, targets in tqdm(val_loader, desc='Evaluating'):
            inputs = {k: v.to(device) for k, v in inputs.items()}
            targets = targets.to(device)

            predictions, aux_outputs = model(inputs, return_perspectives=True)
            aux_outputs['input_data'] = inputs['raw']

            loss_dict = criterion(predictions, targets, aux_outputs)

            for key, value in loss_dict.items():
                losses[key] += value.item()

            predictions_all.append(predictions.cpu().numpy())
            targets_all.append(targets.cpu().numpy())

    # Average losses
    for key in losses:
        losses[key] /= len(val_loader)

    # Compute metrics
    predictions_all = np.concatenate(predictions_all, axis=0)
    targets_all = np.concatenate(targets_all, axis=0)

    # MSE
    mse = np.mean((predictions_all - targets_all)**2)

    # Correlation per electrode
    correlations = []
    for e in range(predictions_all.shape[2]):
        pred_e = predictions_all[:, :, e].flatten()
        target_e = targets_all[:, :, e].flatten()
        if np.std(pred_e) > 0 and np.std(target_e) > 0:
            corr = np.corrcoef(pred_e, target_e)[0, 1]
            if not np.isnan(corr):
                correlations.append(corr)

    avg_corr = np.mean(correlations) if correlations else 0

    # R-squared
    r2_scores = []
    for e in range(predictions_all.shape[2]):
        pred_e = predictions_all[:, :, e].flatten()
        target_e = targets_all[:, :, e].flatten()
        if np.var(target_e) > 0:
            r2 = 1 - np.sum((target_e - pred_e)**2) / np.sum((target_e - np.mean(target_e))**2)
            r2_scores.append(r2)

    avg_r2 = np.mean(r2_scores) if r2_scores else 0

    losses['mse'] = mse
    losses['corr'] = avg_corr
    losses['r2'] = avg_r2

    return losses

def train_model_v2(model, train_loader, val_loader, config, scaler, test_loader=None):
    """Complete training pipeline."""
    criterion = EnhancedLoss(config)

    # Optimizer with different LR for different parts
    optimizer = torch.optim.AdamW([
        {'params': model.reservoir.parameters(), 'lr': config.LEARNING_RATE * 0.1},  # Lower LR for reservoir
        {'params': model.temporal_encoder.parameters(), 'lr': config.LEARNING_RATE},
        {'params': model.self_attention.parameters(), 'lr': config.LEARNING_RATE},
        {'params': model.prediction_heads.parameters(), 'lr': config.LEARNING_RATE * 2}  # Higher LR for output
    ], weight_decay=config.WEIGHT_DECAY)

    # Cosine annealing with warm restarts
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=20, T_mult=2, eta_min=1e-6)

    history = defaultdict(list)
    best_val_loss = float('inf')
    best_val_corr = 0
    patience_counter = 0

    print(f"Starting enhanced training for {config.EPOCHS} epochs...")

    for epoch in range(config.EPOCHS):
        print(f"\n{'='*60}")
        print(f"Epoch {epoch+1}/{config.EPOCHS} | LR: {optimizer.param_groups[0]['lr']:.6f}")
        print(f"{'='*60}")

        # Train
        train_losses = train_epoch_v2(model, train_loader, optimizer, criterion, device, epoch+1)

        # Evaluate
        val_losses = evaluate_v2(model, val_loader, criterion, device)

        # Update scheduler
        scheduler.step()

        # Store history
        for key, value in train_losses.items():
            history[f'train_{key}'].append(value)
        for key, value in val_losses.items():
            history[f'val_{key}'].append(value)

        # Print results
        print(f"\nTrain - Total: {train_losses['total_loss']:.4f}, Time: {train_losses['time_loss']:.4f}, "
              f"Freq: {train_losses['freq_loss']:.4f}, Recon: {train_losses['recon_loss']:.4f}")
        print(f"Val - Total: {val_losses['total_loss']:.4f}, MSE: {val_losses['mse']:.4f}, "
              f"Corr: {val_losses['corr']:.3f}, R²: {val_losses['r2']:.3f}")

        # Visualize
        if (epoch + 1) % 5 == 0 or epoch == 0:
            visualize_predictions_v2(
                model, val_loader, scaler, epoch+1,
                os.path.join(config.OUTPUT_DIR, 'reconstructions'),
                num_samples=5
            )

        # Save best model
        if val_losses['corr'] > best_val_corr:
            best_val_corr = val_losses['corr']
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_corr': best_val_corr,
                'val_losses': val_losses
            }, os.path.join(config.OUTPUT_DIR, 'models', 'best_model_corr.pt'))
            print(f"Saved best model with correlation: {best_val_corr:.3f}")

        if val_losses['total_loss'] < best_val_loss:
            best_val_loss = val_losses['total_loss']
            patience_counter = 0
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': best_val_loss
            }, os.path.join(config.OUTPUT_DIR, 'models', 'best_model_loss.pt'))
        else:
            patience_counter += 1

        # Early stopping
        if patience_counter >= config.PATIENCE:
            print(f"Early stopping after {epoch+1} epochs")
            break

    # Test evaluation
    if test_loader:
        print("\nFinal test evaluation...")
        test_losses = evaluate_v2(model, test_loader, criterion, device)
        print(f"Test - MSE: {test_losses['mse']:.4f}, Corr: {test_losses['corr']:.3f}, R²: {test_losses['r2']:.3f}")

        visualize_predictions_v2(
            model, test_loader, scaler, 'test',
            os.path.join(config.OUTPUT_DIR, 'reconstructions'),
            num_samples=10
        )

    return model, history

#----------------------------------------------------------------------
# Main Pipeline
#----------------------------------------------------------------------

def main_v2(data_path):
    """Enhanced main pipeline."""
    print("="*80)
    print("Enhanced ESN V2 - Multi-Timescale Biologically Plausible Model")
    print("="*80)

    try:
        # Load data
        positions = load_electrode_positions(data_path)
        channel_nums, selected_indices = select_electrodes_fixed(positions)
        lfp_data = load_lfp_data(data_path, channel_nums)

        # Extract windows
        windows = extract_windows(lfp_data)
        print(f"Extracted {windows.shape[0]} windows")

        # Preprocess with enhanced features
        train_loader, val_loader, test_loader, scaler = preprocess_data_v2(windows)

        # Create model
        model = EnhancedESNv2(config).to(device)

        # Print model info
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"Model parameters - Total: {total_params:,}, Trainable: {trainable_params:,}")

        # Train
        model, history = train_model_v2(model, train_loader, val_loader, config, scaler, test_loader)

        print("\nTraining completed successfully!")
        return model, scaler

    except Exception as e:
        print(f"Error: {e}")
        traceback.print_exc()
        return None, None

if __name__ == "__main__":
    # Data path resolution
    data_path = None

    if os.path.exists("/kaggle/input/ecog-landmark-mkn"):
        data_path = "/kaggle/input/ecog-landmark-mkn"
    elif os.path.exists("/content/ecog-landmark-mkn"):
        data_path = "/content/ecog-landmark-mkn"
    else:
        try:
            import kagglehub
            data_path = kagglehub.dataset_download("arunramponnambalam/ecog-landmark-mkn")
        except:
            data_path = input("Enter dataset path: ")

    if data_path and os.path.exists(data_path):
        print(f"Using dataset at: {data_path}")
        model, scaler = main_v2(data_path)
    else:
        print(f"Dataset not found: {data_path}")

Using device: cuda
Using dataset at: /kaggle/input/ecog-landmark-mkn
Enhanced ESN V2 - Multi-Timescale Biologically Plausible Model
Extracted 800 windows
Extracting multi-scale features...
Processing band 1/7: 4-8 Hz


Band 1: 100%|██████████| 800/800 [00:16<00:00, 49.21it/s]


Processing band 2/7: 8-13 Hz


Band 2: 100%|██████████| 800/800 [00:16<00:00, 49.23it/s]


Processing band 3/7: 13-30 Hz


Band 3: 100%|██████████| 800/800 [00:16<00:00, 49.28it/s]


Processing band 4/7: 30-50 Hz


Band 4: 100%|██████████| 800/800 [00:16<00:00, 49.19it/s]


Processing band 5/7: 50-80 Hz


Band 5: 100%|██████████| 800/800 [00:16<00:00, 49.20it/s]


Processing band 6/7: 80-120 Hz


Band 6: 100%|██████████| 800/800 [00:16<00:00, 48.92it/s]


Processing band 7/7: 120-200 Hz


Band 7: 100%|██████████| 800/800 [00:16<00:00, 49.33it/s]


Model parameters - Total: 187,860,234, Trainable: 186,529,566
Starting enhanced training for 200 epochs...

Epoch 1/200 | LR: 0.000050


Epoch 1: 100%|██████████| 9/9 [00:08<00:00,  1.03it/s, loss=1769.6400, time=0.3954, freq=3537.4412]
Evaluating: 100%|██████████| 2/2 [00:00<00:00,  2.95it/s]


Train - Total: 15718.5664, Time: 0.6156, Freq: 31434.8233, Recon: 0.2279
Val - Total: 1873.7795, MSE: 0.3437, Corr: 0.019, R²: -2.771





Saved enhanced visualization to enhanced_esn_v2_results/reconstructions/reconstruction_v2_epoch_1.png
Saved best model with correlation: 0.019

Epoch 2/200 | LR: 0.000050


Epoch 2: 100%|██████████| 9/9 [00:06<00:00,  1.29it/s, loss=821.3833, time=0.2798, freq=1641.1815]
Evaluating: 100%|██████████| 2/2 [00:00<00:00,  3.09it/s]



Train - Total: 1301.6575, Time: 0.3351, Freq: 2601.6092, Recon: 0.2278
Val - Total: 1075.0524, MSE: 0.2123, Corr: 0.045, R²: -1.356
Saved best model with correlation: 0.045

Epoch 3/200 | LR: 0.000049


Epoch 3: 100%|██████████| 9/9 [00:06<00:00,  1.36it/s, loss=1012.4814, time=0.2408, freq=2023.5006]
Evaluating: 100%|██████████| 2/2 [00:00<00:00,  3.03it/s]



Train - Total: 1030.7451, Time: 0.2621, Freq: 2059.9723, Recon: 0.2273
Val - Total: 980.0897, MSE: 0.1909, Corr: 0.043, R²: -1.138

Epoch 4/200 | LR: 0.000047


Epoch 4: 100%|██████████| 9/9 [00:06<00:00,  1.35it/s, loss=786.3887, time=0.2272, freq=1571.3575]
Evaluating: 100%|██████████| 2/2 [00:00<00:00,  3.06it/s]



Train - Total: 828.8215, Time: 0.2390, Freq: 1656.1851, Recon: 0.2268
Val - Total: 874.5592, MSE: 0.1778, Corr: 0.034, R²: -0.993

Epoch 5/200 | LR: 0.000045


Epoch 5: 100%|██████████| 9/9 [00:06<00:00,  1.36it/s, loss=765.7691, time=0.2432, freq=1530.0579]
Evaluating: 100%|██████████| 2/2 [00:00<00:00,  3.09it/s]


Train - Total: 712.6899, Time: 0.2334, Freq: 1423.9317, Recon: 0.2263
Val - Total: 836.7971, MSE: 0.1755, Corr: 0.022, R²: -0.981





Saved enhanced visualization to enhanced_esn_v2_results/reconstructions/reconstruction_v2_epoch_5.png

Epoch 6/200 | LR: 0.000043


Epoch 6: 100%|██████████| 9/9 [00:06<00:00,  1.35it/s, loss=466.3839, time=0.2209, freq=931.3479]
Evaluating: 100%|██████████| 2/2 [00:00<00:00,  2.94it/s]



Train - Total: 614.4341, Time: 0.2246, Freq: 1227.4461, Recon: 0.2263
Val - Total: 607.1563, MSE: 0.1592, Corr: 0.065, R²: -0.784
Saved best model with correlation: 0.065

Epoch 7/200 | LR: 0.000040


Epoch 7: 100%|██████████| 9/9 [00:06<00:00,  1.35it/s, loss=547.0149, time=0.2278, freq=1092.5961]
Evaluating: 100%|██████████| 2/2 [00:00<00:00,  3.05it/s]



Train - Total: 502.6549, Time: 0.2200, Freq: 1003.8994, Recon: 0.2265
Val - Total: 581.5910, MSE: 0.1661, Corr: 0.041, R²: -0.860

Epoch 8/200 | LR: 0.000037


Epoch 8: 100%|██████████| 9/9 [00:06<00:00,  1.35it/s, loss=497.4702, time=0.2149, freq=993.5527]
Evaluating: 100%|██████████| 2/2 [00:00<00:00,  2.99it/s]



Train - Total: 461.8109, Time: 0.2196, Freq: 922.2107, Recon: 0.2262
Val - Total: 553.9359, MSE: 0.1604, Corr: 0.051, R²: -0.808

Epoch 9/200 | LR: 0.000033


Epoch 9: 100%|██████████| 9/9 [00:06<00:00,  1.35it/s, loss=477.1280, time=0.2124, freq=952.8560]
Evaluating: 100%|██████████| 2/2 [00:00<00:00,  3.09it/s]



Train - Total: 428.0083, Time: 0.2159, Freq: 854.6212, Recon: 0.2261
Val - Total: 552.4408, MSE: 0.1524, Corr: 0.047, R²: -0.714

Epoch 10/200 | LR: 0.000029


Epoch 10: 100%|██████████| 9/9 [00:06<00:00,  1.36it/s, loss=356.1078, time=0.2076, freq=710.8458]
Evaluating: 100%|██████████| 2/2 [00:00<00:00,  3.06it/s]


Train - Total: 405.2417, Time: 0.2175, Freq: 809.0804, Recon: 0.2263
Val - Total: 510.8385, MSE: 0.1517, Corr: 0.059, R²: -0.707





Saved enhanced visualization to enhanced_esn_v2_results/reconstructions/reconstruction_v2_epoch_10.png

Epoch 11/200 | LR: 0.000026


Epoch 11: 100%|██████████| 9/9 [00:06<00:00,  1.35it/s, loss=333.7187, time=0.2226, freq=666.0239]
Evaluating: 100%|██████████| 2/2 [00:00<00:00,  3.02it/s]



Train - Total: 351.1907, Time: 0.2130, Freq: 700.9900, Recon: 0.2267
Val - Total: 450.6117, MSE: 0.1572, Corr: 0.045, R²: -0.771

Epoch 12/200 | LR: 0.000022


Epoch 12: 100%|██████████| 9/9 [00:06<00:00,  1.34it/s, loss=243.4493, time=0.2120, freq=485.5200]
Evaluating: 100%|██████████| 2/2 [00:00<00:00,  3.03it/s]



Train - Total: 316.7224, Time: 0.2153, Freq: 632.0508, Recon: 0.2267
Val - Total: 419.5690, MSE: 0.1558, Corr: 0.057, R²: -0.763

Epoch 13/200 | LR: 0.000018


Epoch 13: 100%|██████████| 9/9 [00:06<00:00,  1.35it/s, loss=286.6450, time=0.1907, freq=571.9791]
Evaluating: 100%|██████████| 2/2 [00:00<00:00,  3.00it/s]



Train - Total: 280.8790, Time: 0.2148, Freq: 560.3675, Recon: 0.2262
Val - Total: 436.6621, MSE: 0.1602, Corr: 0.050, R²: -0.802

Epoch 14/200 | LR: 0.000014


Epoch 14: 100%|██████████| 9/9 [00:06<00:00,  1.35it/s, loss=274.9888, time=0.2325, freq=548.5300]
Evaluating: 100%|██████████| 2/2 [00:00<00:00,  3.06it/s]



Train - Total: 256.2288, Time: 0.2155, Freq: 511.0623, Recon: 0.2269
Val - Total: 393.6347, MSE: 0.1568, Corr: 0.055, R²: -0.766

Epoch 15/200 | LR: 0.000011


Epoch 15: 100%|██████████| 9/9 [00:06<00:00,  1.35it/s, loss=213.1144, time=0.1992, freq=424.8656]
Evaluating: 100%|██████████| 2/2 [00:00<00:00,  3.03it/s]


Train - Total: 226.8370, Time: 0.2151, Freq: 452.2786, Recon: 0.2267
Val - Total: 381.9857, MSE: 0.1588, Corr: 0.051, R²: -0.794





Saved enhanced visualization to enhanced_esn_v2_results/reconstructions/reconstruction_v2_epoch_15.png

Epoch 16/200 | LR: 0.000008


Epoch 16: 100%|██████████| 9/9 [00:06<00:00,  1.35it/s, loss=199.5848, time=0.2067, freq=397.8058]
Evaluating: 100%|██████████| 2/2 [00:00<00:00,  3.02it/s]



Train - Total: 206.6465, Time: 0.2153, Freq: 411.8985, Recon: 0.2263
Val - Total: 365.9469, MSE: 0.1561, Corr: 0.056, R²: -0.759

Epoch 17/200 | LR: 0.000006


Epoch 17:   0%|          | 0/9 [00:00<?, ?it/s]


KeyboardInterrupt: 

In [8]:
%%bash
git config --global user.name "Krish0909"