# In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
import scipy.signal as signal
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from tqdm import tqdm
import os
import pandas as pd
from scipy.stats import skew, kurtosis
import seaborn as sns
import gc

# STEP 1: Data Loading and Preprocessing
class EEGDataset(Dataset):
    """Map-style dataset over raw EEG samples.

    Args:
        data_dict: dict with ``'dataset'`` (indexable of per-sample dicts that
            hold ``'eeg_data'`` and ``'label'``) and ``'labels'`` (iterable of
            label IDs used to build the label -> index table).
        label_mapping_file: optional text file whose whitespace-separated lines
            read ``<label_id> <human readable text...>``. It is ignored when
            unset or missing. Lines with fewer than two tokens are skipped.
        transform: optional callable applied to ``eeg_data`` in ``__getitem__``.

    Items are ``(eeg_data, label_idx, label_text)``:
    - ``label_idx`` is the label's position in the sorted set of IDs, or 0 for
      an ID absent from ``data_dict['labels']``.
    - ``label_text`` falls back to the raw ID.
    """

    def __init__(self, data_dict, label_mapping_file=None, transform=None):
        self.dataset = data_dict['dataset']
        self.labels_list = data_dict['labels']  # List of label IDs
        self.transform = transform

        # Optional ID -> human-readable text table.
        self.label_id_to_text = self._load_label_text(label_mapping_file)

        # Sorting makes the ID -> index assignment deterministic.
        self.label_to_idx = {
            label_id: position
            for position, label_id in enumerate(sorted(set(self.labels_list)))
        }

    @staticmethod
    def _load_label_text(path):
        """Parse ``<id> <text...>`` lines into a dict; empty if ``path`` is unset or absent."""
        mapping = {}
        if not (path and os.path.exists(path)):
            return mapping
        with open(path, 'r') as handle:
            for line in handle:
                tokens = line.split()
                if len(tokens) < 2:
                    continue
                label_id, *words = tokens
                mapping[label_id] = ' '.join(words)
        return mapping

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        record = self.dataset[idx]
        eeg = record['eeg_data']
        raw_label = record['label']

        text = self.label_id_to_text.get(raw_label, raw_label)
        position = self.label_to_idx.get(raw_label, 0)

        if self.transform:
            eeg = self.transform(eeg)

        return eeg, position, text

# EEG Signal Preprocessing
class EEGPreprocessor:
    """Filter, re-reference and z-score one EEG epoch.

    Pipeline, applied along the time axis:
    1. zero-phase Butterworth band-pass;
    2. zero-phase IIR notch at ``notch_freq``;
    3. common-average reference;
    4. per-channel z-score.

    The input is ``[channels, time_points]`` (tensor or array-like). The output
    is a ``torch.float32`` tensor of the same shape.
    """

    def __init__(self, sampling_rate=1000, notch_freq=50, bandpass_low=0.5, bandpass_high=70):
        self.sampling_rate = sampling_rate
        self.notch_freq = notch_freq
        self.bandpass_low = bandpass_low
        self.bandpass_high = bandpass_high

        # Filter coefficients depend only on the settings, so compute them once.
        # Band-pass: order-4 Butterworth in second-order-sections form.
        self.sos = signal.butter(
            N=4,
            Wn=[self.bandpass_low, self.bandpass_high],
            btype='bandpass',
            fs=self.sampling_rate,
            output='sos'
        )
        # Notch: quality factor Q=30 at notch_freq (power-line interference).
        self.b_notch, self.a_notch = signal.iirnotch(self.notch_freq, 30, self.sampling_rate)

    def __call__(self, eeg_data):
        """Preprocess one epoch.

        Args:
            eeg_data: ``[channels, time_points]`` tensor or array-like.

        Returns:
            ``torch.float32`` tensor of shape ``[channels, time_points]``.
        """
        if isinstance(eeg_data, torch.Tensor):
            # Tensor.numpy() refuses tensors that require grad or live on a GPU.
            eeg_data = eeg_data.detach().cpu().numpy()
        else:
            # Accept lists and other array-likes, not only ndarrays.
            eeg_data = np.asarray(eeg_data)

        # Work in [time, channels] so every filter runs along axis 0.
        data = eeg_data.T
        data = self._bandpass_filter(data)
        data = self._notch_filter(data)
        data = self._common_average_reference(data)
        data = self._normalize(data)

        # Back to [channels, time].
        return torch.tensor(data.T, dtype=torch.float32)

    def _bandpass_filter(self, data):
        # Forward-backward filtering gives zero phase distortion.
        return signal.sosfiltfilt(self.sos, data, axis=0)

    def _notch_filter(self, data):
        return signal.filtfilt(self.b_notch, self.a_notch, data, axis=0)

    def _common_average_reference(self, data):
        # Subtract the mean across all channels at each time point.
        return data - np.mean(data, axis=1, keepdims=True)

    def _normalize(self, data):
        # Per-channel z-score; the epsilon keeps a flat channel at 0 instead of NaN.
        return (data - np.mean(data, axis=0)) / (np.std(data, axis=0) + 1e-10)

# STEP 2: Feature Extraction
class FeatureExtractor:
    """Compute a fixed-length vector of hand-crafted EEG features.

    For an epoch shaped ``[channels, time_points]``, the vector contains:
    - per channel, 12 time-domain values (statistics plus Hjorth parameters);
    - per channel, 22 frequency-domain values (absolute and relative power in
      9 bands, 95% spectral edge, spectral entropy, peak frequency, peak power);
    - 6 scalar connectivity values (correlation and phase-locking summaries).

    Total length: ``34 * channels + 6``.

    NOTE(review): ``EEGPreprocessor`` in this file low-passes at 70 Hz by
    default. When the two are chained, the ``gamma_mid``/``gamma_high`` bands
    mostly measure filter roll-off. Confirm that this is intended.
    """

    def __init__(self, sampling_rate=1000):
        self.sampling_rate = sampling_rate

        # Frequency bands in Hz, used as half-open intervals [low, high) so a
        # PSD bin that falls exactly on a shared edge is counted in one band only.
        self.freq_bands = {
            'delta': (1, 4),        # Adjusted lower bound to reduce DC components
            'theta': (4, 8),        # Standard theta for cognitive processing
            'alpha_low': (8, 10),   # Lower alpha - attention/inhibition
            'alpha_high': (10, 13), # Higher alpha - semantic processing
            'beta_low': (13, 20),   # Lower beta - motor preparation
            'beta_high': (20, 30),  # Higher beta - active processing/cognition
            'gamma_low': (30, 60),  # Expanded gamma_low for visual processing
            'gamma_mid': (60, 90),  # Added gamma_mid for binding
            'gamma_high': (90, 120),# Higher gamma for fine perceptual binding
        }

    def extract_features(self, eeg_data):
        """Extract time, frequency and connectivity features from one epoch.

        Args:
            eeg_data: array of shape ``[channels, time_points]``.

        Returns:
            1-D ``np.ndarray`` that concatenates every feature in insertion
            order: time-domain, then frequency-domain, then connectivity.
        """
        features = {}
        features.update(self._extract_time_domain_features(eeg_data))
        features.update(self._extract_frequency_domain_features(eeg_data))
        features.update(self._extract_connectivity_features(eeg_data))

        # np.ravel turns per-channel arrays and scalars alike into 1-D pieces.
        return np.concatenate([np.ravel(value) for value in features.values()])

    def _extract_time_domain_features(self, eeg_data):
        """Per-channel statistics, zero crossings and Hjorth parameters."""
        features = {}

        features['mean'] = np.mean(eeg_data, axis=1)
        features['var'] = np.var(eeg_data, axis=1)
        features['skewness'] = skew(eeg_data, axis=1)
        features['kurtosis'] = kurtosis(eeg_data, axis=1)
        features['max'] = np.max(eeg_data, axis=1)
        features['min'] = np.min(eeg_data, axis=1)
        features['peak_to_peak'] = features['max'] - features['min']
        features['rms'] = np.sqrt(np.mean(np.square(eeg_data), axis=1))
        # np.diff on a boolean array marks positions where the sign bit flips.
        features['zero_crossings'] = np.sum(np.diff(np.signbit(eeg_data), axis=1), axis=1)

        features.update(self._compute_hjorth_parameters(eeg_data))
        return features

    def _compute_hjorth_parameters(self, eeg_data):
        """Compute Hjorth parameters: Activity, Mobility and Complexity (per channel)."""
        features = {}

        diff1 = np.diff(eeg_data, axis=1)
        diff2 = np.diff(diff1, axis=1)

        # Activity: variance of the signal.
        features['activity'] = np.var(eeg_data, axis=1)

        # Mobility: sqrt(var(first derivative) / var(signal)).
        var_diff1 = np.var(diff1, axis=1)
        mobility1 = np.sqrt(var_diff1 / (features['activity'] + 1e-10))
        features['mobility'] = mobility1

        # Complexity: mobility of the first derivative / mobility of the signal.
        var_diff2 = np.var(diff2, axis=1)
        mobility2 = np.sqrt(var_diff2 / (var_diff1 + 1e-10))
        features['complexity'] = mobility2 / (mobility1 + 1e-10)

        return features

    def _extract_frequency_domain_features(self, eeg_data):
        """Welch-PSD band powers plus spectral edge, entropy and peak (per channel)."""
        features = {}

        # Adaptive window size: at most 256 samples, at most a quarter of the epoch.
        nperseg = min(256, eeg_data.shape[1] // 4)
        freqs, psd = signal.welch(eeg_data, fs=self.sampling_rate,
                                  nperseg=nperseg,
                                  noverlap=nperseg // 2,
                                  axis=1)

        total_power = np.sum(psd, axis=1) + 1e-10

        for band_name, (low_freq, high_freq) in self.freq_bands.items():
            # Half-open mask: adjacent bands never share a bin.
            in_band = (freqs >= low_freq) & (freqs < high_freq)
            band_power = np.sum(psd[:, in_band], axis=1)
            features[f'{band_name}_power'] = band_power
            features[f'{band_name}_rel_power'] = band_power / total_power

        features['sef_95'] = self._compute_spectral_edge_frequency(freqs, psd, 0.95)
        features['spectral_entropy'] = self._compute_spectral_entropy(psd)

        features['peak_freq'] = freqs[np.argmax(psd, axis=1)]
        features['peak_power'] = np.max(psd, axis=1)

        return features

    def _compute_spectral_edge_frequency(self, freqs, psd, edge=0.95):
        """Per channel, the first frequency at which cumulative power reaches ``edge``.

        Falls back to the highest frequency when the threshold is never reached.
        """
        cumulative = np.cumsum(psd, axis=1) / (np.sum(psd, axis=1, keepdims=True) + 1e-10)
        reached = cumulative >= edge
        first = np.argmax(reached, axis=1)  # index of first True (0 if none)
        return np.where(reached.any(axis=1), freqs[first], freqs[-1])

    def _compute_spectral_entropy(self, psd):
        """Shannon entropy (bits) of each channel's normalised PSD."""
        psd_norm = psd / (np.sum(psd, axis=1, keepdims=True) + 1e-10)
        return -np.sum(psd_norm * np.log2(psd_norm + 1e-10), axis=1)

    def _extract_connectivity_features(self, eeg_data):
        """Summaries of inter-channel correlation and adjacent-channel phase locking."""
        n_channels = eeg_data.shape[0]

        if n_channels < 2:
            # Pairwise measures are undefined for one channel. Emit zeros so the
            # feature vector keeps a fixed layout instead of raising on empty
            # reductions.
            return {name: 0.0 for name in
                    ('mean_corr', 'std_corr', 'max_corr', 'min_corr', 'mean_plv', 'std_plv')}

        # NOTE(review): np.corrcoef yields NaN for a zero-variance channel.
        corr_matrix = np.corrcoef(eeg_data)
        correlations = corr_matrix[np.triu_indices(n_channels, k=1)]  # upper triangle, no diagonal

        features = {
            'mean_corr': np.mean(correlations),
            'std_corr': np.std(correlations),
            'max_corr': np.max(correlations),
            'min_corr': np.min(correlations),
        }

        # Simplified phase synchronisation via the Hilbert transform.
        instantaneous_phase = np.angle(signal.hilbert(eeg_data, axis=1))
        # Row i is phase[i + 1] - phase[i]: only adjacent channel rows are compared.
        phase_diff = np.diff(instantaneous_phase, axis=0)

        # Phase locking value per adjacent pair.
        plv_values = np.abs(np.mean(np.exp(1j * phase_diff), axis=1))
        features['mean_plv'] = np.mean(plv_values)
        features['std_plv'] = np.std(plv_values)

        return features

# Dataset with precomputed features
class EEGFeatureDataset(Dataset):
    """Dataset over pre-computed feature vectors.

    Args:
        features: indexable of feature vectors (tensors or array-likes).
        labels: indexable of labels, parallel to ``features``.
        texts: either a dict keyed by label, or a sequence parallel to
            ``features``. Missing entries become ``"Unknown-<label>"``.

    Items are ``(feature, label, text)``:
    - non-tensor features are converted to float32 tensors;
    - labels that are neither tensors nor Python ints are converted to long
      tensors (ints are left for the collate function).
    """

    def __init__(self, features, labels, texts):
        self.features = features
        self.labels = labels
        self.texts = texts

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        feature, label = self.features[idx], self.labels[idx]

        # Resolve the text from the raw (unconverted) label.
        if isinstance(self.texts, dict):
            text = self.texts.get(label, f"Unknown-{label}")
        elif idx < len(self.texts):
            text = self.texts[idx]
        else:
            text = f"Unknown-{label}"

        if not isinstance(feature, torch.Tensor):
            feature = torch.tensor(feature, dtype=torch.float32)
        if not isinstance(label, (torch.Tensor, int)):
            label = torch.tensor(label, dtype=torch.long)

        return feature, label, text

# STEP 3: Classification Model
class EEGClassifier(nn.Module):
    """EEG classifier with two alternative front-ends and a shared dense stack.

    * Flat mode (default): input ``[batch, input_dim]`` -> BatchNorm -> sigmoid
      feature gate -> dense stack.
    * Sequence mode (both ``seq_length`` and ``n_channels`` given): input is
      reshaped to ``[batch, 1, n_channels, seq_length]`` -> temporal/spatial
      CNN -> 2-layer bidirectional LSTM -> self-attention pooling -> dense stack.

    The dense stack feeds three linear heads whose logits are averaged.

    Args:
        input_dim: size of the flat input vector (flat mode). Also sizes the
            input BatchNorm, which is created in both modes.
        n_classes: number of output classes.
        hidden_dims: widths of the dense layers (any iterable of ints). A tuple
            default avoids the shared-mutable-default pitfall.
        seq_length: time points per channel; used only in sequence mode.
        n_channels: electrode count; used only in sequence mode.
        dropout_rate: dropout probability used throughout.
    """

    def __init__(self, input_dim, n_classes, hidden_dims=(4096, 2048, 1024),
                 seq_length=None, n_channels=None, dropout_rate=0.5):
        super().__init__()
        hidden_dims = list(hidden_dims)

        self.input_dim = input_dim
        self.n_classes = n_classes
        self.seq_length = seq_length
        self.n_channels = n_channels

        # Sequence (CNN + LSTM) mode is enabled only when both sizes are given.
        self.reshape_input = seq_length is not None and n_channels is not None

        # Input normalisation. forward() applies it only in flat mode, but it is
        # created in both modes; removing it would change the state_dict keys.
        self.input_norm = nn.BatchNorm1d(input_dim)

        if self.reshape_input:
            # Temporal conv (1x16, stride 2) along each electrode row, then a
            # spatial conv spanning all n_channels rows (height collapses to 1).
            self.conv_block = nn.Sequential(
                nn.Conv2d(1, 32, kernel_size=(1, 16), stride=(1, 2), padding=(0, 7)),
                nn.BatchNorm2d(32),
                nn.ELU(),
                nn.Conv2d(32, 64, kernel_size=(n_channels, 1), stride=1, padding=0),
                nn.BatchNorm2d(64),
                nn.ELU(),
                nn.AvgPool2d(kernel_size=(1, 4), stride=(1, 4)),
                nn.Dropout(dropout_rate)
            )

            # The LSTM consumes the 64 conv feature maps at each remaining time step.
            self.lstm = nn.LSTM(
                input_size=64,
                hidden_size=128,
                num_layers=2,
                batch_first=True,
                dropout=dropout_rate,
                bidirectional=True
            )

            # Pools the LSTM sequence to one vector; 256 = 128 * 2 directions.
            self.attention = SelfAttention(256)
            dense_input_dim = 256
        else:
            # Bottleneck MLP producing a sigmoid gate for each input feature.
            self.attention = nn.Sequential(
                nn.Linear(input_dim, input_dim // 4),
                nn.LeakyReLU(0.2),
                nn.Linear(input_dim // 4, input_dim),
                nn.Sigmoid()
            )
            dense_input_dim = input_dim

        # Dense stack: a residual block when the width is unchanged, a plain
        # dense block otherwise. An SE gate follows every layer except the last.
        layers = []
        prev_dim = dense_input_dim
        for i, hidden_dim in enumerate(hidden_dims):
            if prev_dim == hidden_dim:
                layers.append(ResidualBlock(prev_dim, hidden_dim, dropout_rate))
            else:
                layers.append(DenseBlock(prev_dim, hidden_dim, dropout_rate))
            prev_dim = hidden_dim

            if i < len(hidden_dims) - 1:
                layers.append(SEBlock(hidden_dim))

        self.feature_layers = nn.Sequential(*layers)

        # Three parallel linear heads; forward() averages their logits.
        self.heads = nn.ModuleList([
            nn.Linear(prev_dim, n_classes) for _ in range(3)
        ])

        self.apply(self._init_weights)

    def _calculate_conv_output_size(self):
        """Return 64 * (time length after the stride-2 conv and the /4 pooling).

        Informational only: no layer is sized from this value.
        """
        length_after_conv = ((self.seq_length - 16 + 2*7) // 2) + 1
        length_after_pool = length_after_conv // 4
        return 64 * length_after_pool  # 64 channels

    def _init_weights(self, m):
        # Kaiming-normal for Linear/Conv2d weights, zero biases; BatchNorm set to identity.
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return class logits ``[batch, n_classes]``."""
        if not self.reshape_input:
            x = self.input_norm(x)
            # Element-wise gate computed from the normalised input.
            x = x * self.attention(x)
        else:
            # reshape (not view) so non-contiguous inputs are accepted too.
            x = x.reshape(x.size(0), 1, self.n_channels, self.seq_length)
            x = self.conv_block(x)              # -> [batch, 64, 1, reduced_time]
            x = x.squeeze(2).permute(0, 2, 1)   # -> [batch, reduced_time, 64]
            x, _ = self.lstm(x)                 # -> [batch, reduced_time, 256]
            x, _ = self.attention(x)            # -> [batch, 256]

        features = self.feature_layers(x)

        # Ensemble: average the logits of the parallel heads.
        return torch.stack([head(features) for head in self.heads]).mean(dim=0)

    def predict_proba(self, x):
        """Return softmax class probabilities for ``x``."""
        logits = self.forward(x)
        return torch.softmax(logits, dim=1)


# Helper blocks for enhanced architecture

class ResidualBlock(nn.Module):
    """Linear -> BatchNorm -> LeakyReLU(0.2) -> Dropout, plus an identity skip.

    The skip adds the raw input to the block output in place, so ``in_dim`` is
    expected to equal ``out_dim``. EEGClassifier only builds this block when
    the widths match.
    """

    def __init__(self, in_dim, out_dim, dropout_rate=0.4):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)
        self.norm = nn.BatchNorm1d(out_dim)
        self.activation = nn.LeakyReLU(0.2)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        transformed = self.dropout(self.activation(self.norm(self.linear(x))))
        # In-place add of the skip connection (same semantics as ``+=``).
        return transformed.add_(x)


class DenseBlock(nn.Module):
    """Linear -> BatchNorm -> LeakyReLU(0.2) -> Dropout, mapping in_dim to out_dim."""

    def __init__(self, in_dim, out_dim, dropout_rate=0.4):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)
        self.norm = nn.BatchNorm1d(out_dim)
        self.activation = nn.LeakyReLU(0.2)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        return self.dropout(self.activation(self.norm(self.linear(x))))


class SEBlock(nn.Module):
    """Squeeze-and-Excitation-style gate for flat ``[batch, channel]`` features.

    Each feature is multiplied by ``sigmoid(W2 @ relu(W1 @ x))``. The input has
    no spatial axis, so the average-pool "squeeze" runs over a length-1 axis
    and is the identity. The gate is therefore computed per sample from the
    features themselves.

    Args:
        channel: number of input features.
        reduction: bottleneck ratio. The bottleneck width is
            ``max(1, channel // reduction)``, so a small ``channel`` never gets
            a zero-width layer (which would pin the gate at a constant 0.5).
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        hidden = max(1, channel // reduction)
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c = x.size()
        # Pool over an added length-1 axis (identity for 2-D input).
        y = self.avg_pool(x.unsqueeze(-1)).view(b, c)
        return x * self.fc(y)


class SelfAttention(nn.Module):
    """Scaled dot-product self-attention that pools a sequence into one vector.

    ``forward`` takes ``[batch, seq_len, hidden_dim]`` and returns
    ``(pooled, weights)``:
    - ``pooled`` is the attention output summed over the sequence axis,
      ``[batch, hidden_dim]``;
    - ``weights`` is the row-softmaxed ``[batch, seq_len, seq_len]`` matrix.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.query = nn.Linear(hidden_dim, hidden_dim)
        self.key = nn.Linear(hidden_dim, hidden_dim)
        self.value = nn.Linear(hidden_dim, hidden_dim)
        self.scale = hidden_dim ** 0.5

    def forward(self, x):
        queries, keys, values = self.query(x), self.key(x), self.value(x)

        # [batch, seq, seq] scores, scaled by sqrt(hidden_dim) before softmax.
        weights = torch.softmax((queries @ keys.transpose(1, 2)) / self.scale, dim=-1)

        # Attend, then sum over the sequence axis to get one vector per item.
        pooled = (weights @ values).sum(dim=1)
        return pooled, weights

# STEP 4: Data Preprocessing Helper - OPTIMIZED FOR MEMORY
def prepare_dataset_with_features(dataset, batch_size=64, device='cuda'):
    """Pre-compute a hand-crafted feature vector for every sample of ``dataset``.

    Args:
        dataset: indexable sequence (``len()`` + ``[]``) whose items are
            ``(eeg_data, label_idx, label_text)``. ``eeg_data`` is a tensor or
            array-like accepted by ``FeatureExtractor.extract_features``.
        batch_size: number of samples between garbage-collection / CUDA cache
            flushes. Features are still extracted one sample at a time.
        device: only used to decide whether to flush the CUDA cache. Accepts a
            ``str`` or a ``torch.device``.

    Returns:
        ``(features, labels, label_texts)``:
        - a list of float32 tensors;
        - a list of label indices;
        - a dict mapping label index -> label text. It falls back to
          ``dataset.get_label_texts()`` when no sample was seen and that
          method exists.
    """
    feature_extractor = FeatureExtractor()

    # Compare via str: a torch.device never compares equal to the str 'cuda'.
    clear_cuda = str(device).startswith('cuda') and torch.cuda.is_available()

    n_samples = len(dataset)
    num_batches = (n_samples + batch_size - 1) // batch_size

    processed_features = []
    labels = []
    label_texts = {}

    for batch_idx in tqdm(range(num_batches), desc="Extracting features in batches"):
        start_idx = batch_idx * batch_size
        end_idx = min(start_idx + batch_size, n_samples)

        for idx in range(start_idx, end_idx):
            eeg_data, label_idx, label_text = dataset[idx]

            if isinstance(eeg_data, torch.Tensor):
                # detach() so tensors that require grad can be converted as well.
                eeg_data_np = eeg_data.detach().cpu().numpy()
            else:
                eeg_data_np = np.asarray(eeg_data)

            features = feature_extractor.extract_features(eeg_data_np)
            processed_features.append(torch.tensor(features, dtype=torch.float32))
            labels.append(label_idx)
            label_texts[label_idx] = label_text

        # Release per-batch temporaries before the next batch.
        gc.collect()
        if clear_cuda:
            torch.cuda.empty_cache()

    if not label_texts and hasattr(dataset, 'get_label_texts'):
        label_texts = dataset.get_label_texts()

    return processed_features, labels, label_texts

# STEP 5: Training Function - MEMORY OPTIMIZED
def train_model(model, train_loader, val_loader, num_epochs=500, learning_rate=0.001,
                weight_decay=1e-5, device="cpu", class_weights=None,
                early_stopping_patience=10):
    """Train ``model`` with Adam, LR-on-plateau and early stopping on validation loss.

    Args:
        model: ``nn.Module`` mapping a feature batch to class logits.
        train_loader: iterable yielding ``(features, labels, texts)`` batches;
            ``texts`` is ignored.
        val_loader: iterable with the same batch structure, used for validation.
        num_epochs: maximum number of epochs.
        learning_rate: initial Adam learning rate.
        weight_decay: Adam weight decay (L2 regularisation).
        device: device to train on (``str`` or ``torch.device``).
        class_weights: optional per-class weights for ``CrossEntropyLoss``.
        early_stopping_patience: epochs without validation-loss improvement
            before training stops.

    Returns:
        The model on ``device``. If any epoch improved on the initial +inf
        validation loss, the weights from the lowest-loss epoch are restored.

    Side effects: prints per-epoch summaries and writes ``training_curves.png``.
    The CUDA cache is flushed once per epoch rather than per batch. Per-batch
    flushing only hands cached blocks back to the driver and forces
    re-allocation on the next batch.
    """
    model.to(device)
    # Compare via str: a torch.device never compares equal to the str 'cuda'.
    use_cuda = str(device).startswith('cuda') and torch.cuda.is_available()

    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    # Halve the LR after 5 epochs without improvement. The scheduler's `verbose`
    # flag is deprecated in PyTorch; the LR is printed every epoch below.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5
    )

    if class_weights is not None:
        # as_tensor also accepts an existing tensor without a copy warning.
        class_weights = torch.as_tensor(class_weights, dtype=torch.float32, device=device)
        criterion = nn.CrossEntropyLoss(weight=class_weights)
    else:
        criterion = nn.CrossEntropyLoss()

    train_losses, val_losses, train_accs, val_accs = [], [], [], []

    best_val_loss = float('inf')
    no_improve_epoch = 0
    best_model_state = None

    for epoch in range(num_epochs):
        epoch_tag = f"Epoch {epoch+1}/{num_epochs}"

        epoch_train_loss, epoch_train_acc = _train_one_epoch(
            model, train_loader, optimizer, criterion, device, f"{epoch_tag} [Train]"
        )
        epoch_val_loss, epoch_val_acc, all_preds, all_labels = _validate_one_epoch(
            model, val_loader, criterion, device, f"{epoch_tag} [Val]"
        )

        train_losses.append(epoch_train_loss)
        train_accs.append(epoch_train_acc)
        val_losses.append(epoch_val_loss)
        val_accs.append(epoch_val_acc)

        scheduler.step(epoch_val_loss)

        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print(f"Train Loss: {epoch_train_loss:.4f}, Train Acc: {epoch_train_acc:.2f}%")
        print(f"Val Loss: {epoch_val_loss:.4f}, Val Acc: {epoch_val_acc:.2f}%")
        print(f"Learning rate: {optimizer.param_groups[0]['lr']:.6f}")

        if epoch_val_loss < best_val_loss:
            best_val_loss = epoch_val_loss
            # clone() is essential. On CPU, Tensor.cpu() returns the very same
            # tensor, so without a copy this snapshot would keep changing as
            # the optimizer updates parameters (and BatchNorm stats) in place.
            best_model_state = {
                k: v.detach().cpu().clone() for k, v in model.state_dict().items()
            }
            no_improve_epoch = 0
            print("New best model saved!")

            print("\nClassification Report:")
            print(classification_report(all_labels, all_preds))
        else:
            no_improve_epoch += 1

        if no_improve_epoch >= early_stopping_patience:
            print(f"Early stopping triggered after {epoch+1} epochs")
            break

        gc.collect()
        if use_cuda:
            torch.cuda.empty_cache()

    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    else:
        # No epoch ran, or the validation loss was never finite and improving.
        print("Warning: no best model recorded; keeping the final weights.")
    model.to(device)  # make sure model is on the correct device

    _plot_training_curves(train_losses, val_losses, train_accs, val_accs)

    return model


def _train_one_epoch(model, loader, optimizer, criterion, device, desc):
    """Run one training pass over ``loader``; return (mean batch loss, accuracy %)."""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    progress_bar = tqdm(loader, desc=desc)
    for features, labels, _ in progress_bar:
        features = features.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(features)
        loss = criterion(outputs, labels)
        loss.backward()

        # Gradient clipping to prevent exploding gradients.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        predicted = outputs.detach().argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        running_loss += loss.item()

        progress_bar.set_postfix({
            'loss': f"{loss.item():.4f}",
            'acc': f"{100 * correct / total:.2f}%"
        })

        # Drop references so the next batch can reuse this memory.
        del features, labels, outputs, loss, predicted

    return running_loss / len(loader), 100 * correct / total


def _validate_one_epoch(model, loader, criterion, device, desc):
    """Evaluate on ``loader`` without gradients.

    Returns ``(mean batch loss, accuracy %, predictions, true labels)``. The last
    two are flat lists suitable for ``classification_report``.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        progress_bar = tqdm(loader, desc=desc)
        for features, labels, _ in progress_bar:
            features = features.to(device)
            labels = labels.to(device)

            outputs = model(features)
            loss = criterion(outputs, labels)

            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            running_loss += loss.item()

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

            progress_bar.set_postfix({
                'loss': f"{loss.item():.4f}",
                'acc': f"{100 * correct / total:.2f}%"
            })

            del features, labels, outputs, loss, predicted

    return running_loss / len(loader), 100 * correct / total, all_preds, all_labels


def _plot_training_curves(train_losses, val_losses, train_accs, val_accs):
    """Save the loss and accuracy curves side by side to ``training_curves.png``."""
    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(train_accs, label='Train Accuracy')
    plt.plot(val_accs, label='Validation Accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy (%)')
    plt.title('Training and Validation Accuracy')
    plt.legend()

    plt.tight_layout()
    plt.savefig('training_curves.png')
    plt.close()

# STEP 6: Evaluation Function - MEMORY OPTIMIZED
def evaluate_model(model, test_loader, device='cpu'):
    """Evaluate ``model`` on ``test_loader`` and save a confusion-matrix heatmap.

    Args:
        model: classifier returning logits ``[batch, n_classes]``. The caller is
            responsible for having moved it to ``device``.
        test_loader: iterable of ``(features, labels, texts)`` batches.
        device: device the inputs are moved to (``str`` or ``torch.device``).

    Returns:
        ``(accuracy, report, cm, probs)``:
        - accuracy as a fraction;
        - the sklearn classification-report string;
        - the confusion matrix;
        - an ``[n_samples, n_classes]`` array of softmax probabilities.

    Side effect: writes ``confusion_matrix.png``.
    """
    model.eval()
    all_preds = []
    all_labels = []
    all_probs = []

    with torch.no_grad():
        for features, labels, _ in tqdm(test_loader, desc="Evaluating"):
            outputs = model(features.to(device))
            probs = torch.softmax(outputs, dim=1)
            predicted = outputs.argmax(dim=1)

            # Move results to host memory; .cpu() also copes with GPU-resident labels.
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())

            del outputs, probs, predicted

    # Compare via str: a torch.device never compares equal to the str 'cuda'.
    if str(device).startswith('cuda') and torch.cuda.is_available():
        torch.cuda.empty_cache()

    accuracy = accuracy_score(all_labels, all_preds)
    report = classification_report(all_labels, all_preds)
    cm = confusion_matrix(all_labels, all_preds)

    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.title('Confusion Matrix')
    plt.savefig('confusion_matrix.png')
    plt.close()

    return accuracy, report, cm, np.array(all_probs)

def _load_feature_file(path):
    """Load a cached feature dictionary from ``path`` into CPU memory.

    The caches are written by ``main`` itself, so full unpickling is
    acceptable. ``weights_only=False`` matches how the raw session files are
    loaded and keeps loading working on PyTorch >= 2.6, whose default
    ``weights_only=True`` rejects non-tensor objects. Whether the cache holds
    such objects depends on prepare_dataset_with_features (not visible here).
    ``map_location='cpu'`` keeps cached tensors off the GPU so DataLoader
    pinning works.
    """
    return torch.load(path, map_location='cpu', weights_only=False)


def _has_valid_session_cache(session_feature_path):
    """Return True if ``session_feature_path`` holds a usable feature cache.

    A cache is usable when it contains the 'features', 'labels' and 'texts'
    keys. A malformed cache is deleted and False is returned, so the caller
    re-extracts that session instead of silently dropping it.
    """
    if not os.path.exists(session_feature_path):
        return False

    print(f"Loading pre-computed features from {session_feature_path}...")
    features_data = _load_feature_file(session_feature_path)
    if 'features' not in features_data or 'labels' not in features_data or 'texts' not in features_data:
        print(f"Invalid feature file format for {session_feature_path}. Will recompute.")
        del features_data
        os.remove(session_feature_path)
        return False

    print(f"Loaded features with shape {len(features_data['features'])}")
    return True


def _extract_session_features(session_file, session_name, session_feature_path,
                              label_mapping_file, device):
    """Preprocess one raw session file, extract its features and cache them.

    Returns True on success. Errors are reported and False is returned, so
    the remaining sessions are still processed (best-effort).
    """
    print(f"Loading data from {session_file}...")
    try:
        # Load in CPU memory to avoid GPU memory usage during loading
        session_data = torch.load(session_file, map_location='cpu', weights_only=False)
        print(f"Loaded {len(session_data['dataset'])} EEG samples with {len(set(session_data['labels']))} unique classes")

        # Create EEG dataset with preprocessing.
        # NOTE(review): EEGDataset builds label_to_idx from this session's
        # labels only. If the sessions have different label sets, the indices
        # will not line up after combining. Confirm both sessions share the
        # same classes.
        preprocessor = EEGPreprocessor(sampling_rate=256)  # Adjust sampling rate to match your data
        dataset = EEGDataset(session_data, label_mapping_file, transform=preprocessor)

        # Free up memory from raw data
        del session_data
        gc.collect()

        if device.type == 'cuda':
            torch.cuda.empty_cache()
            print(f"GPU memory after data loading: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")

        # Extract features from the dataset in batches
        print(f"Preprocessing and extracting features for {session_name} in batches...")
        features, labels, texts = prepare_dataset_with_features(dataset, batch_size=32, device=device)

        # Save features to disk
        print(f"Saving extracted features to {session_feature_path}...")
        torch.save({
            'features': features,
            'labels': labels,
            'texts': texts
        }, session_feature_path)
        print("Features saved successfully!")

        # Free up memory
        del dataset, features, labels, texts
        gc.collect()

        if device.type == 'cuda':
            torch.cuda.empty_cache()
            print(f"GPU memory after feature extraction: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
    except Exception as e:
        print(f"Error processing {session_file}: {e}")
        return False
    return True


def _combine_session_features(session_feature_paths, device):
    """Concatenate the cached features of every session cache that exists.

    Missing caches are skipped. Returns ``(features, labels, texts)``, where
    features and labels are lists and texts is a dict keyed by label.
    """
    all_features = []
    all_labels = []
    all_texts = {}

    for session_feature_path in session_feature_paths:
        if not os.path.exists(session_feature_path):
            continue

        print(f"Loading features from {session_feature_path} for combining...")
        features_data = _load_feature_file(session_feature_path)

        all_features.extend(features_data['features'])
        all_labels.extend(features_data['labels'])

        # Merge text dictionaries
        texts = features_data['texts']
        if isinstance(texts, dict):
            all_texts.update(texts)
        else:
            # Texts stored as a list: pair each text with the label at the
            # same position (zip stops at the shorter of the two).
            for label, text in zip(features_data['labels'], texts):
                all_texts[label] = text

        # Clear memory
        del features_data
        gc.collect()

        if device.type == 'cuda':
            torch.cuda.empty_cache()

    return all_features, all_labels, all_texts


def _compute_class_weights(labels, n_classes):
    """Return inverse-frequency class weights normalised to average 1.

    Raises:
        ValueError: if ``labels`` are not the contiguous integers
            ``0..n_classes-1``. Otherwise the weight vector (length
            ``max(label) + 1``) would not match the model's ``n_classes``
            outputs.
    """
    class_counts = np.bincount(labels, minlength=n_classes)
    if len(class_counts) != n_classes:
        raise ValueError(
            f"Expected labels to be contiguous integers 0..{n_classes - 1}, "
            f"but the largest label is {len(class_counts) - 1}"
        )
    class_weights = 1.0 / class_counts
    return class_weights / np.sum(class_weights) * len(class_counts)


def main():
    """Run the end-to-end EEG classification pipeline.

    Seeds the RNGs and, for each session, reuses a valid cached feature file
    or (re)extracts it. It then builds or loads the combined feature cache
    and splits the data 70/15/15 with stratification. Finally it trains
    ``EEGClassifier`` with inverse-frequency class weights, evaluates on the
    test split and saves the checkpoint to ``eeg_classifier_model.pt``.

    Raises:
        RuntimeError: if no session features could be loaded or extracted.
        ValueError: if labels are not contiguous integers ``0..n_classes-1``.
    """
    # Set random seeds for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(42)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Check device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Enable memory tracking for PyTorch
    if device.type == 'cuda':
        torch.cuda.reset_peak_memory_stats()
        print(f"Initial GPU memory: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")

    # Create or use existing label mapping
    label_mapping_file = 'labels_txt.txt'

    # Set up paths for feature storage
    features_path = {
        'session_1': "features_session_1.pt",
        'session_2': "features_session_2.pt",
        'combined': "features_combined.pt"
    }

    session_files = [
        "/content/drive/MyDrive/session_1.pth",
        "/content/drive/MyDrive/session_2.pth"
    ]
    session_names = [f"session_{i + 1}" for i in range(len(session_files))]

    # Reuse valid per-session caches; (re)build missing or invalid ones.
    # Any rebuild means the combined cache is stale.
    need_to_combine = False
    for session_name, session_file in zip(session_names, session_files):
        session_feature_path = features_path[session_name]
        if _has_valid_session_cache(session_feature_path):
            continue
        need_to_combine = True
        _extract_session_features(session_file, session_name, session_feature_path,
                                  label_mapping_file, device)

    if need_to_combine or not os.path.exists(features_path['combined']):
        print("Combining features from all sessions...")
        all_features, all_labels, all_texts = _combine_session_features(
            [features_path[name] for name in session_names], device)

        if len(all_features) == 0:
            raise RuntimeError("No session features could be loaded or extracted; nothing to train on.")

        # Save combined features
        print(f"Total combined features: {len(all_features)}")
        print(f"Saving combined features to {features_path['combined']}...")
        torch.save({
            'features': all_features,
            'labels': all_labels,
            'texts': all_texts
        }, features_path['combined'])
        print("Combined features saved successfully!")
    else:
        # Load combined features
        print(f"Loading pre-computed combined features from {features_path['combined']}...")
        combined_data = _load_feature_file(features_path['combined'])
        all_features = combined_data['features']
        all_labels = combined_data['labels']
        all_texts = combined_data['texts']
        print(f"Loaded combined features with shape {len(all_features)}")

        # Clear memory
        del combined_data
        gc.collect()

        if device.type == 'cuda':
            torch.cuda.empty_cache()

    # Create feature dataset from combined data
    feature_dataset = EEGFeatureDataset(all_features, all_labels, all_texts)

    # Free up memory that's no longer needed
    del all_features, all_labels
    gc.collect()

    if device.type == 'cuda':
        torch.cuda.empty_cache()
        print(f"GPU memory before training: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")

    # Split data with stratification: 70% train, then the remaining 30% split
    # evenly into validation and test (15% each of the original data).
    train_idx, temp_idx = train_test_split(
        range(len(feature_dataset)),
        test_size=0.3,
        random_state=42,
        stratify=feature_dataset.labels
    )

    val_idx, test_idx = train_test_split(
        temp_idx,
        test_size=0.5,
        random_state=42,
        stratify=[feature_dataset.labels[i] for i in temp_idx]
    )

    # Create samplers
    train_sampler = SubsetRandomSampler(train_idx)
    val_sampler = SubsetRandomSampler(val_idx)
    test_sampler = SubsetRandomSampler(test_idx)

    # Smaller batch size on GPU helps with memory usage
    batch_size = 16 if device.type == 'cuda' else 32
    pin = device.type == 'cuda'
    train_loader = DataLoader(feature_dataset, batch_size=batch_size, sampler=train_sampler, pin_memory=pin)
    val_loader = DataLoader(feature_dataset, batch_size=batch_size, sampler=val_sampler, pin_memory=pin)
    test_loader = DataLoader(feature_dataset, batch_size=batch_size, sampler=test_sampler, pin_memory=pin)

    # Get feature dimension from the first sample
    feature_dim = feature_dataset[0][0].shape[0]
    n_classes = len(set(feature_dataset.labels))

    print(f"Feature dimension: {feature_dim}")
    print(f"Number of classes: {n_classes}")

    # Calculate class weights for imbalanced data
    class_weights = _compute_class_weights(feature_dataset.labels, n_classes)
    print(f"Class weights: {class_weights}")

    # Create model
    model = EEGClassifier(
        input_dim=feature_dim,
        n_classes=n_classes
        # hidden_dims=[512, 256, 128]  # Adjust architecture as needed
    )

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total model parameters: {total_params:,}")

    # Train model
    print("\nTraining model...")
    trained_model = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        num_epochs=500,  # Adjust as needed
        learning_rate=0.0005,
        weight_decay=5e-5,
        device=device,
        class_weights=class_weights
    )

    # Evaluate model
    print("\nEvaluating model on test set...")
    accuracy, report, cm, probabilities = evaluate_model(trained_model, test_loader, device)

    print(f"\nFinal Test Accuracy: {accuracy*100:.2f}%")
    # print("\nClassification Report:")
    # print(report)

    # Save model
    torch.save({
        'model_state_dict': trained_model.state_dict(),
        'feature_dim': feature_dim,
        'n_classes': n_classes,
        'accuracy': accuracy
    }, "eeg_classifier_model.pt")

    print("\nModel saved successfully!")

    print("\nEEG classification pipeline complete!")

# Run the full pipeline only when this file/cell is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

Using device: cuda
Initial GPU memory: 16.25 MB
Loading pre-computed features from features_session_1.pt...
Loaded features with shape 31950
Loading pre-computed features from features_session_2.pt...
Loaded features with shape 31900
Loading pre-computed combined features from features_combined.pt...
Loaded combined features with shape 63850
GPU memory before training: 16.25 MB
Feature dimension: 2114
Number of classes: 80
Class weights: [0.99750623 0.99750623 0.99750623 0.99750623 0.99750623 0.99750623
 0.99750623 0.99750623 0.99750623 0.99750623 0.99750623 0.99750623
 0.99750623 0.99750623 0.99750623 0.99750623 0.99750623 0.99750623
 0.99750623 0.99750623 0.99750623 0.99750623 0.99750623 0.99750623
 0.99750623 0.99750623 0.99750623 0.99750623 0.99750623 0.99750623
 0.99750623 0.99750623 0.99750623 0.99750623 0.99750623 0.99750623
 0.99750623 0.99750623 0.99750623 0.99750623 0.99750623 0.99750623
 0.99750623 1.06400665 0.99750623 0.99750623 0.99750623 0.99750623
 0.99750623 0.99750623



Total model parameters: 24,272,902

Training model...


Epoch 1/500 [Train]: 100%|██████████| 2794/2794 [00:47<00:00, 58.91it/s, loss=5.8834, acc=1.59%]
Epoch 1/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 269.38it/s, loss=5.6759, acc=2.20%]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Epoch 1/500
Train Loss: 5.3126, Train Acc: 1.59%
Val Loss: 4.7420, Val Acc: 2.20%
Learning rate: 0.000500
New best model saved!

Classification Report:
              precision    recall  f1-score   support

           0       0.00      0.00      0.00       120
           1       0.00      0.00      0.00       120
           2       0.00      0.00      0.00       120
           3       0.00      0.00      0.00       120
           4       1.00      0.01      0.02       120
           5       0.00      0.00      0.00       120
           6       0.00      0.00      0.00       120
           7       0.03      0.01      0.01       120
           8       0.02      0.10      0.04       120
           9       0.00      0.00      0.00       120
          10       0.09      0.01      0.02       120
          11       0.00      0.00      0.00       120
          12       0.00      0.00      0.00       120
          13       0.06      0.02      0.03       120
          14       0.00      0.00   

Epoch 2/500 [Train]: 100%|██████████| 2794/2794 [00:47<00:00, 58.96it/s, loss=4.6732, acc=2.19%]
Epoch 2/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 298.44it/s, loss=4.5763, acc=2.95%]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Epoch 2/500
Train Loss: 4.7283, Train Acc: 2.19%
Val Loss: 4.3887, Val Acc: 2.95%
Learning rate: 0.000500
New best model saved!

Classification Report:
              precision    recall  f1-score   support

           0       0.02      0.03      0.03       120
           1       0.03      0.10      0.04       120
           2       0.00      0.00      0.00       120
           3       0.02      0.10      0.03       120
           4       0.00      0.00      0.00       120
           5       0.00      0.00      0.00       120
           6       0.00      0.00      0.00       120
           7       0.04      0.03      0.03       120
           8       0.00      0.00      0.00       120
           9       0.00      0.00      0.00       120
          10       0.01      0.01      0.01       120
          11       0.02      0.04      0.03       120
          12       0.02      0.01      0.01       120
          13       0.01      0.04      0.02       120
          14       0.07      0.05   

Epoch 3/500 [Train]: 100%|██████████| 2794/2794 [00:47<00:00, 58.42it/s, loss=4.5177, acc=2.88%]
Epoch 3/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 293.55it/s, loss=4.0045, acc=4.20%]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Epoch 3/500
Train Loss: 4.4821, Train Acc: 2.88%
Val Loss: 4.1975, Val Acc: 4.20%
Learning rate: 0.000500
New best model saved!

Classification Report:
              precision    recall  f1-score   support

           0       0.05      0.06      0.05       120
           1       0.07      0.06      0.07       120
           2       0.00      0.00      0.00       120
           3       0.04      0.09      0.05       120
           4       0.01      0.03      0.02       120
           5       0.04      0.01      0.01       120
           6       0.09      0.03      0.05       120
           7       0.05      0.02      0.02       120
           8       0.05      0.10      0.06       120
           9       0.02      0.01      0.01       120
          10       0.00      0.00      0.00       120
          11       0.03      0.10      0.05       120
          12       0.11      0.07      0.09       120
          13       0.04      0.01      0.01       120
          14       0.05      0.17   

Epoch 4/500 [Train]: 100%|██████████| 2794/2794 [00:48<00:00, 57.67it/s, loss=4.2250, acc=4.16%]
Epoch 4/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 291.77it/s, loss=3.9720, acc=6.69%]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Epoch 4/500
Train Loss: 4.2747, Train Acc: 4.16%
Val Loss: 4.0396, Val Acc: 6.69%
Learning rate: 0.000500
New best model saved!

Classification Report:
              precision    recall  f1-score   support

           0       0.12      0.04      0.06       120
           1       0.00      0.00      0.00       120
           2       0.09      0.03      0.04       120
           3       0.05      0.10      0.07       120
           4       0.05      0.30      0.08       120
           5       0.04      0.09      0.06       120
           6       0.10      0.12      0.11       120
           7       0.00      0.00      0.00       120
           8       0.11      0.01      0.02       120
           9       0.00      0.00      0.00       120
          10       0.03      0.09      0.04       120
          11       0.00      0.00      0.00       120
          12       0.04      0.30      0.08       120
          13       0.03      0.08      0.04       120
          14       0.07      0.07   

Epoch 5/500 [Train]: 100%|██████████| 2794/2794 [00:48<00:00, 58.21it/s, loss=3.5542, acc=6.06%]
Epoch 5/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 288.86it/s, loss=4.1259, acc=9.31%]



Epoch 5/500
Train Loss: 4.0902, Train Acc: 6.06%
Val Loss: 3.8306, Val Acc: 9.31%
Learning rate: 0.000500
New best model saved!

Classification Report:
              precision    recall  f1-score   support

           0       0.08      0.05      0.06       120
           1       0.10      0.04      0.06       120
           2       0.16      0.04      0.07       120
           3       0.06      0.09      0.07       120
           4       0.12      0.13      0.13       120
           5       0.00      0.00      0.00       120
           6       0.15      0.12      0.13       120
           7       0.13      0.10      0.11       120
           8       0.14      0.06      0.08       120
           9       0.05      0.05      0.05       120
          10       0.00      0.00      0.00       120
          11       0.11      0.07      0.09       120
          12       0.06      0.21      0.09       120
          13       0.10      0.09      0.09       120
          14       0.18      0.14   

Epoch 6/500 [Train]: 100%|██████████| 2794/2794 [00:47<00:00, 58.23it/s, loss=4.2787, acc=7.87%]
Epoch 6/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 239.92it/s, loss=3.1859, acc=11.21%]



Epoch 6/500
Train Loss: 3.9304, Train Acc: 7.87%
Val Loss: 3.6781, Val Acc: 11.21%
Learning rate: 0.000500
New best model saved!

Classification Report:
              precision    recall  f1-score   support

           0       0.19      0.03      0.06       120
           1       0.06      0.03      0.04       120
           2       0.06      0.04      0.05       120
           3       0.10      0.04      0.06       120
           4       0.07      0.12      0.09       120
           5       0.12      0.07      0.09       120
           6       0.16      0.10      0.12       120
           7       0.14      0.18      0.16       120
           8       0.20      0.17      0.19       120
           9       0.14      0.10      0.12       120
          10       0.05      0.08      0.06       120
          11       0.15      0.12      0.13       120
          12       0.13      0.20      0.16       120
          13       0.05      0.06      0.05       120
          14       0.16      0.16  

Epoch 7/500 [Train]: 100%|██████████| 2794/2794 [00:47<00:00, 58.25it/s, loss=3.4547, acc=9.35%]
Epoch 7/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 260.29it/s, loss=4.1180, acc=13.13%]



Epoch 7/500
Train Loss: 3.7924, Train Acc: 9.35%
Val Loss: 3.5249, Val Acc: 13.13%
Learning rate: 0.000500
New best model saved!

Classification Report:
              precision    recall  f1-score   support

           0       0.15      0.17      0.16       120
           1       0.08      0.06      0.07       120
           2       0.11      0.05      0.07       120
           3       0.14      0.14      0.14       120
           4       0.05      0.20      0.08       120
           5       0.09      0.04      0.06       120
           6       0.11      0.14      0.13       120
           7       0.38      0.13      0.20       120
           8       0.25      0.17      0.20       120
           9       0.15      0.09      0.11       120
          10       0.08      0.16      0.11       120
          11       0.33      0.18      0.24       120
          12       0.15      0.28      0.20       120
          13       0.06      0.08      0.07       120
          14       0.11      0.17  

Epoch 8/500 [Train]: 100%|██████████| 2794/2794 [00:48<00:00, 57.58it/s, loss=5.2696, acc=11.10%]
Epoch 8/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 281.78it/s, loss=3.2453, acc=15.14%]



Epoch 8/500
Train Loss: 3.6619, Train Acc: 11.10%
Val Loss: 3.3700, Val Acc: 15.14%
Learning rate: 0.000500
New best model saved!

Classification Report:
              precision    recall  f1-score   support

           0       0.17      0.10      0.13       120
           1       0.12      0.16      0.13       120
           2       0.18      0.17      0.18       120
           3       0.12      0.13      0.13       120
           4       0.06      0.15      0.09       120
           5       0.15      0.12      0.13       120
           6       0.16      0.16      0.16       120
           7       0.14      0.15      0.15       120
           8       0.16      0.15      0.16       120
           9       0.10      0.10      0.10       120
          10       0.11      0.17      0.13       120
          11       0.17      0.17      0.17       120
          12       0.17      0.30      0.22       120
          13       0.20      0.10      0.13       120
          14       0.13      0.07 

Epoch 9/500 [Train]: 100%|██████████| 2794/2794 [00:48<00:00, 57.69it/s, loss=3.8806, acc=13.01%]
Epoch 9/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 281.87it/s, loss=3.2084, acc=16.06%]



Epoch 9/500
Train Loss: 3.5575, Train Acc: 13.01%
Val Loss: 3.3007, Val Acc: 16.06%
Learning rate: 0.000500
New best model saved!

Classification Report:
              precision    recall  f1-score   support

           0       0.25      0.12      0.16       120
           1       0.15      0.14      0.15       120
           2       0.17      0.23      0.20       120
           3       0.12      0.14      0.13       120
           4       0.08      0.09      0.08       120
           5       0.17      0.17      0.17       120
           6       0.15      0.23      0.19       120
           7       0.30      0.11      0.16       120
           8       0.17      0.24      0.20       120
           9       0.13      0.03      0.04       120
          10       0.11      0.08      0.09       120
          11       0.19      0.21      0.20       120
          12       0.24      0.27      0.25       120
          13       0.10      0.26      0.15       120
          14       0.19      0.21 

Epoch 10/500 [Train]: 100%|██████████| 2794/2794 [00:48<00:00, 57.24it/s, loss=3.2131, acc=14.40%]
Epoch 10/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 277.07it/s, loss=3.3059, acc=18.02%]



Epoch 10/500
Train Loss: 3.4647, Train Acc: 14.40%
Val Loss: 3.1981, Val Acc: 18.02%
Learning rate: 0.000500
New best model saved!

Classification Report:
              precision    recall  f1-score   support

           0       0.14      0.18      0.16       120
           1       0.17      0.06      0.09       120
           2       0.24      0.14      0.18       120
           3       0.29      0.18      0.23       120
           4       0.15      0.14      0.14       120
           5       0.18      0.10      0.13       120
           6       0.18      0.33      0.24       120
           7       0.18      0.10      0.13       120
           8       0.21      0.20      0.21       120
           9       0.20      0.04      0.07       120
          10       0.16      0.12      0.14       120
          11       0.20      0.23      0.21       120
          12       0.18      0.31      0.23       120
          13       0.17      0.22      0.19       120
          14       0.16      0.24

Epoch 11/500 [Train]: 100%|██████████| 2794/2794 [00:48<00:00, 57.27it/s, loss=4.0672, acc=15.73%]
Epoch 11/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 275.70it/s, loss=2.7119, acc=19.07%]



Epoch 11/500
Train Loss: 3.3801, Train Acc: 15.73%
Val Loss: 3.1416, Val Acc: 19.07%
Learning rate: 0.000500
New best model saved!

Classification Report:
              precision    recall  f1-score   support

           0       0.27      0.19      0.22       120
           1       0.16      0.28      0.21       120
           2       0.15      0.25      0.19       120
           3       0.24      0.22      0.23       120
           4       0.18      0.20      0.19       120
           5       0.44      0.12      0.19       120
           6       0.22      0.17      0.19       120
           7       0.28      0.21      0.24       120
           8       0.30      0.32      0.31       120
           9       0.18      0.20      0.19       120
          10       0.28      0.06      0.10       120
          11       0.25      0.17      0.21       120
          12       0.29      0.28      0.28       120
          13       0.16      0.23      0.19       120
          14       0.22      0.28

Epoch 12/500 [Train]: 100%|██████████| 2794/2794 [00:48<00:00, 57.35it/s, loss=3.5964, acc=16.83%]
Epoch 12/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 256.11it/s, loss=3.0108, acc=20.52%]



Epoch 12/500
Train Loss: 3.3247, Train Acc: 16.83%
Val Loss: 3.0836, Val Acc: 20.52%
Learning rate: 0.000500
New best model saved!

Classification Report:
              precision    recall  f1-score   support

           0       0.22      0.21      0.21       120
           1       0.21      0.08      0.12       120
           2       0.21      0.27      0.24       120
           3       0.24      0.20      0.22       120
           4       0.14      0.15      0.14       120
           5       0.22      0.17      0.19       120
           6       0.26      0.23      0.25       120
           7       0.24      0.23      0.24       120
           8       0.24      0.33      0.28       120
           9       0.10      0.04      0.06       120
          10       0.20      0.15      0.17       120
          11       0.35      0.27      0.30       120
          12       0.20      0.26      0.23       120
          13       0.20      0.26      0.22       120
          14       0.23      0.23

Epoch 13/500 [Train]: 100%|██████████| 2794/2794 [00:48<00:00, 57.63it/s, loss=3.5330, acc=17.66%]
Epoch 13/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 245.15it/s, loss=3.1395, acc=21.65%]



Epoch 13/500
Train Loss: 3.2656, Train Acc: 17.66%
Val Loss: 2.9941, Val Acc: 21.65%
Learning rate: 0.000500
New best model saved!

Classification Report:
              precision    recall  f1-score   support

           0       0.24      0.34      0.28       120
           1       0.20      0.35      0.25       120
           2       0.26      0.29      0.27       120
           3       0.32      0.19      0.24       120
           4       0.16      0.17      0.17       120
           5       0.18      0.26      0.21       120
           6       0.17      0.31      0.22       120
           7       0.26      0.20      0.23       120
           8       0.29      0.24      0.26       120
           9       0.16      0.15      0.16       120
          10       0.21      0.16      0.18       120
          11       0.31      0.27      0.29       120
          12       0.27      0.30      0.28       120
          13       0.26      0.21      0.23       120
          14       0.30      0.17

Epoch 14/500 [Train]: 100%|██████████| 2794/2794 [00:48<00:00, 57.31it/s, loss=3.7421, acc=18.45%]
Epoch 14/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 268.07it/s, loss=3.6064, acc=21.44%]



Epoch 14/500
Train Loss: 3.2242, Train Acc: 18.45%
Val Loss: 3.0220, Val Acc: 21.44%
Learning rate: 0.000500


Epoch 15/500 [Train]: 100%|██████████| 2794/2794 [00:49<00:00, 57.01it/s, loss=4.0110, acc=19.56%]
Epoch 15/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 270.10it/s, loss=2.5036, acc=23.03%]



Epoch 15/500
Train Loss: 3.1698, Train Acc: 19.56%
Val Loss: 2.9289, Val Acc: 23.03%
Learning rate: 0.000500
New best model saved!

Classification Report:
              precision    recall  f1-score   support

           0       0.22      0.18      0.20       120
           1       0.17      0.25      0.21       120
           2       0.38      0.28      0.32       120
           3       0.29      0.13      0.18       120
           4       0.17      0.19      0.18       120
           5       0.24      0.16      0.19       120
           6       0.21      0.23      0.22       120
           7       0.32      0.24      0.27       120
           8       0.27      0.23      0.25       120
           9       0.25      0.12      0.16       120
          10       0.16      0.14      0.15       120
          11       0.25      0.31      0.27       120
          12       0.21      0.31      0.25       120
          13       0.20      0.28      0.23       120
          14       0.18      0.28

Epoch 16/500 [Train]: 100%|██████████| 2794/2794 [00:49<00:00, 56.82it/s, loss=3.0184, acc=19.96%]
Epoch 16/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 276.59it/s, loss=2.6250, acc=23.44%]



Epoch 16/500
Train Loss: 3.1305, Train Acc: 19.96%
Val Loss: 2.9025, Val Acc: 23.44%
Learning rate: 0.000500
New best model saved!

Classification Report:
              precision    recall  f1-score   support

           0       0.21      0.29      0.25       120
           1       0.24      0.30      0.27       120
           2       0.21      0.28      0.24       120
           3       0.32      0.20      0.25       120
           4       0.18      0.17      0.17       120
           5       0.32      0.16      0.21       120
           6       0.20      0.20      0.20       120
           7       0.21      0.23      0.22       120
           8       0.31      0.28      0.29       120
           9       0.17      0.12      0.14       120
          10       0.25      0.18      0.21       120
          11       0.43      0.28      0.34       120
          12       0.28      0.33      0.30       120
          13       0.18      0.19      0.19       120
          14       0.25      0.23

Epoch 17/500 [Train]: 100%|██████████| 2794/2794 [00:48<00:00, 57.16it/s, loss=5.0416, acc=20.93%]
Epoch 17/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 268.99it/s, loss=2.6806, acc=24.16%]



Epoch 17/500
Train Loss: 3.0987, Train Acc: 20.93%
Val Loss: 2.8518, Val Acc: 24.16%
Learning rate: 0.000500
New best model saved!

Classification Report:
              precision    recall  f1-score   support

           0       0.19      0.23      0.20       120
           1       0.26      0.24      0.25       120
           2       0.28      0.17      0.22       120
           3       0.42      0.16      0.23       120
           4       0.17      0.21      0.19       120
           5       0.27      0.27      0.27       120
           6       0.30      0.28      0.29       120
           7       0.26      0.30      0.28       120
           8       0.25      0.23      0.24       120
           9       0.16      0.12      0.14       120
          10       0.22      0.24      0.23       120
          11       0.29      0.33      0.31       120
          12       0.22      0.26      0.24       120
          13       0.25      0.34      0.29       120
          14       0.21      0.37

Epoch 18/500 [Train]: 100%|██████████| 2794/2794 [00:48<00:00, 57.34it/s, loss=2.9741, acc=21.49%]
Epoch 18/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 252.24it/s, loss=3.7819, acc=23.80%]



Epoch 18/500
Train Loss: 3.0633, Train Acc: 21.49%
Val Loss: 2.8606, Val Acc: 23.80%
Learning rate: 0.000500


Epoch 19/500 [Train]: 100%|██████████| 2794/2794 [00:48<00:00, 57.11it/s, loss=3.3231, acc=22.31%]
Epoch 19/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 261.25it/s, loss=2.2506, acc=23.90%]



Epoch 19/500
Train Loss: 3.0329, Train Acc: 22.31%
Val Loss: 2.8480, Val Acc: 23.90%
Learning rate: 0.000500
New best model saved!

Classification Report:
              precision    recall  f1-score   support

           0       0.23      0.26      0.24       120
           1       0.17      0.25      0.20       120
           2       0.21      0.28      0.24       120
           3       0.23      0.21      0.22       120
           4       0.21      0.28      0.24       120
           5       0.31      0.28      0.29       120
           6       0.21      0.27      0.23       120
           7       0.43      0.21      0.28       120
           8       0.33      0.27      0.30       120
           9       0.21      0.17      0.19       120
          10       0.14      0.22      0.17       120
          11       0.28      0.38      0.32       120
          12       0.38      0.31      0.34       120
          13       0.30      0.25      0.27       120
          14       0.27      0.26

Epoch 20/500 [Train]: 100%|██████████| 2794/2794 [00:49<00:00, 56.74it/s, loss=3.1598, acc=22.63%]
Epoch 20/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 274.64it/s, loss=3.0895, acc=24.46%]



Epoch 20/500
Train Loss: 2.9982, Train Acc: 22.63%
Val Loss: 2.8107, Val Acc: 24.46%
Learning rate: 0.000500
New best model saved!

Classification Report:
              precision    recall  f1-score   support

           0       0.25      0.23      0.24       120
           1       0.28      0.12      0.16       120
           2       0.41      0.10      0.16       120
           3       0.22      0.38      0.28       120
           4       0.21      0.20      0.21       120
           5       0.32      0.22      0.26       120
           6       0.21      0.23      0.22       120
           7       0.32      0.23      0.26       120
           8       0.35      0.31      0.33       120
           9       0.32      0.14      0.20       120
          10       0.24      0.27      0.25       120
          11       0.38      0.33      0.35       120
          12       0.26      0.34      0.29       120
          13       0.28      0.36      0.32       120
          14       0.27      0.29

Epoch 21/500 [Train]: 100%|██████████| 2794/2794 [00:49<00:00, 56.52it/s, loss=3.1633, acc=23.28%]
Epoch 21/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 270.88it/s, loss=2.8716, acc=25.48%]



Epoch 21/500
Train Loss: 2.9766, Train Acc: 23.28%
Val Loss: 2.8079, Val Acc: 25.48%
Learning rate: 0.000500
New best model saved!

Classification Report:
              precision    recall  f1-score   support

           0       0.25      0.30      0.27       120
           1       0.17      0.27      0.21       120
           2       0.28      0.38      0.32       120
           3       0.29      0.25      0.27       120
           4       0.23      0.30      0.26       120
           5       0.14      0.33      0.20       120
           6       0.18      0.17      0.18       120
           7       0.27      0.19      0.22       120
           8       0.30      0.38      0.33       120
           9       0.31      0.14      0.19       120
          10       0.26      0.23      0.25       120
          11       0.36      0.35      0.35       120
          12       0.32      0.21      0.25       120
          13       0.23      0.28      0.25       120
          14       0.27      0.27

Epoch 22/500 [Train]: 100%|██████████| 2794/2794 [00:50<00:00, 55.36it/s, loss=3.6551, acc=23.94%]
Epoch 22/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 254.66it/s, loss=3.1617, acc=25.21%]



Epoch 22/500
Train Loss: 2.9438, Train Acc: 23.94%
Val Loss: 2.7933, Val Acc: 25.21%
Learning rate: 0.000500
New best model saved!

Classification Report:
              precision    recall  f1-score   support

           0       0.27      0.20      0.23       120
           1       0.20      0.28      0.23       120
           2       0.33      0.22      0.26       120
           3       0.26      0.26      0.26       120
           4       0.27      0.18      0.22       120
           5       0.26      0.13      0.18       120
           6       0.26      0.28      0.27       120
           7       0.23      0.32      0.27       120
           8       0.37      0.27      0.31       120
           9       0.24      0.13      0.17       120
          10       0.25      0.15      0.19       120
          11       0.28      0.41      0.33       120
          12       0.29      0.31      0.30       120
          13       0.28      0.28      0.28       120
          14       0.25      0.30

Epoch 23/500 [Train]: 100%|██████████| 2794/2794 [00:50<00:00, 55.28it/s, loss=2.8625, acc=24.29%]
Epoch 23/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 263.80it/s, loss=2.1168, acc=24.98%]



Epoch 23/500
Train Loss: 2.9306, Train Acc: 24.29%
Val Loss: 2.7954, Val Acc: 24.98%
Learning rate: 0.000500


Epoch 24/500 [Train]: 100%|██████████| 2794/2794 [00:49<00:00, 56.26it/s, loss=3.0766, acc=24.73%]
Epoch 24/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 262.75it/s, loss=2.7322, acc=25.59%]



Epoch 24/500
Train Loss: 2.9080, Train Acc: 24.73%
Val Loss: 2.7790, Val Acc: 25.59%
Learning rate: 0.000500
New best model saved!

Classification Report:
              precision    recall  f1-score   support

           0       0.22      0.28      0.24       120
           1       0.35      0.17      0.23       120
           2       0.35      0.35      0.35       120
           3       0.29      0.16      0.20       120
           4       0.19      0.28      0.23       120
           5       0.25      0.30      0.27       120
           6       0.29      0.27      0.28       120
           7       0.26      0.30      0.28       120
           8       0.34      0.33      0.33       120
           9       0.27      0.22      0.24       120
          10       0.22      0.19      0.20       120
          11       0.37      0.42      0.39       120
          12       0.28      0.31      0.29       120
          13       0.24      0.31      0.27       120
          14       0.37      0.16

Epoch 25/500 [Train]: 100%|██████████| 2794/2794 [00:49<00:00, 56.63it/s, loss=2.9067, acc=24.98%]
Epoch 25/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 236.49it/s, loss=2.5435, acc=25.56%]



Epoch 25/500
Train Loss: 2.8930, Train Acc: 24.98%
Val Loss: 2.7753, Val Acc: 25.56%
Learning rate: 0.000500
New best model saved!

Classification Report:
              precision    recall  f1-score   support

           0       0.37      0.24      0.29       120
           1       0.23      0.24      0.24       120
           2       0.26      0.43      0.33       120
           3       0.25      0.27      0.26       120
           4       0.27      0.19      0.23       120
           5       0.26      0.21      0.23       120
           6       0.23      0.21      0.22       120
           7       0.29      0.20      0.24       120
           8       0.31      0.42      0.36       120
           9       0.30      0.05      0.09       120
          10       0.29      0.14      0.19       120
          11       0.35      0.38      0.36       120
          12       0.29      0.34      0.31       120
          13       0.28      0.28      0.28       120
          14       0.24      0.24

Epoch 26/500 [Train]: 100%|██████████| 2794/2794 [00:49<00:00, 56.22it/s, loss=3.1569, acc=25.32%]
Epoch 26/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 236.32it/s, loss=3.2857, acc=25.58%]



Epoch 26/500
Train Loss: 2.8715, Train Acc: 25.32%
Val Loss: 2.7640, Val Acc: 25.58%
Learning rate: 0.000500
New best model saved!

Classification Report:
              precision    recall  f1-score   support

           0       0.24      0.19      0.21       120
           1       0.31      0.25      0.28       120
           2       0.22      0.46      0.29       120
           3       0.27      0.17      0.21       120
           4       0.22      0.27      0.24       120
           5       0.25      0.23      0.24       120
           6       0.22      0.36      0.27       120
           7       0.26      0.28      0.27       120
           8       0.30      0.24      0.27       120
           9       0.34      0.14      0.20       120
          10       0.23      0.15      0.18       120
          11       0.42      0.28      0.33       120
          12       0.28      0.33      0.30       120
          13       0.24      0.32      0.27       120
          14       0.32      0.26

Epoch 27/500 [Train]: 100%|██████████| 2794/2794 [00:49<00:00, 56.69it/s, loss=4.1751, acc=25.62%]
Epoch 27/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 261.06it/s, loss=3.2078, acc=25.50%]



Epoch 27/500
Train Loss: 2.8532, Train Acc: 25.62%
Val Loss: 2.7565, Val Acc: 25.50%
Learning rate: 0.000500
New best model saved!

Classification Report:
              precision    recall  f1-score   support

           0       0.24      0.20      0.22       120
           1       0.31      0.28      0.29       120
           2       0.22      0.36      0.27       120
           3       0.41      0.23      0.29       120
           4       0.21      0.27      0.23       120
           5       0.15      0.28      0.19       120
           6       0.23      0.20      0.22       120
           7       0.28      0.22      0.24       120
           8       0.27      0.32      0.29       120
           9       0.22      0.09      0.13       120
          10       0.17      0.17      0.17       120
          11       0.47      0.32      0.38       120
          12       0.26      0.33      0.29       120
          13       0.25      0.28      0.26       120
          14       0.21      0.20

Epoch 28/500 [Train]: 100%|██████████| 2794/2794 [00:50<00:00, 55.49it/s, loss=2.9990, acc=26.12%]
Epoch 28/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 258.12it/s, loss=2.4495, acc=25.80%]



Epoch 28/500
Train Loss: 2.8414, Train Acc: 26.12%
Val Loss: 2.7387, Val Acc: 25.80%
Learning rate: 0.000500
New best model saved!

Classification Report:
              precision    recall  f1-score   support

           0       0.28      0.14      0.19       120
           1       0.28      0.28      0.28       120
           2       0.35      0.31      0.33       120
           3       0.35      0.23      0.28       120
           4       0.23      0.31      0.27       120
           5       0.32      0.28      0.30       120
           6       0.33      0.23      0.27       120
           7       0.26      0.28      0.27       120
           8       0.28      0.39      0.33       120
           9       0.20      0.23      0.21       120
          10       0.20      0.21      0.20       120
          11       0.29      0.35      0.31       120
          12       0.31      0.30      0.31       120
          13       0.24      0.14      0.18       120
          14       0.23      0.23

Epoch 29/500 [Train]: 100%|██████████| 2794/2794 [00:49<00:00, 56.26it/s, loss=3.0895, acc=26.58%]
Epoch 29/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 260.06it/s, loss=3.1282, acc=25.96%]



Epoch 29/500
Train Loss: 2.8217, Train Acc: 26.58%
Val Loss: 2.7687, Val Acc: 25.96%
Learning rate: 0.000500


Epoch 30/500 [Train]: 100%|██████████| 2794/2794 [00:50<00:00, 55.84it/s, loss=4.3535, acc=27.01%]
Epoch 30/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 241.19it/s, loss=2.8196, acc=25.99%]



Epoch 30/500
Train Loss: 2.7980, Train Acc: 27.01%
Val Loss: 2.7708, Val Acc: 25.99%
Learning rate: 0.000500


Epoch 31/500 [Train]: 100%|██████████| 2794/2794 [00:50<00:00, 55.55it/s, loss=2.9778, acc=27.46%]
Epoch 31/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 250.39it/s, loss=2.6738, acc=25.98%]



Epoch 31/500
Train Loss: 2.7814, Train Acc: 27.46%
Val Loss: 2.7645, Val Acc: 25.98%
Learning rate: 0.000500


Epoch 32/500 [Train]: 100%|██████████| 2794/2794 [00:49<00:00, 56.28it/s, loss=3.1734, acc=27.65%]
Epoch 32/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 253.76it/s, loss=2.9274, acc=26.40%]



Epoch 32/500
Train Loss: 2.7724, Train Acc: 27.65%
Val Loss: 2.7531, Val Acc: 26.40%
Learning rate: 0.000500


Epoch 33/500 [Train]: 100%|██████████| 2794/2794 [00:49<00:00, 56.48it/s, loss=4.3072, acc=28.04%]
Epoch 33/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 259.08it/s, loss=3.1009, acc=26.74%]



Epoch 33/500
Train Loss: 2.7589, Train Acc: 28.04%
Val Loss: 2.7559, Val Acc: 26.74%
Learning rate: 0.000500


Epoch 34/500 [Train]: 100%|██████████| 2794/2794 [00:50<00:00, 55.42it/s, loss=2.5498, acc=28.23%]
Epoch 34/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 262.16it/s, loss=2.5061, acc=25.77%]



Epoch 34/500
Train Loss: 2.7493, Train Acc: 28.23%
Val Loss: 2.7926, Val Acc: 25.77%
Learning rate: 0.000250


Epoch 35/500 [Train]: 100%|██████████| 2794/2794 [00:49<00:00, 56.43it/s, loss=2.5255, acc=33.86%]
Epoch 35/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 257.89it/s, loss=2.6408, acc=29.05%]



Epoch 35/500
Train Loss: 2.4887, Train Acc: 33.86%
Val Loss: 2.6285, Val Acc: 29.05%
Learning rate: 0.000250
New best model saved!

Classification Report:
              precision    recall  f1-score   support

           0       0.28      0.32      0.29       120
           1       0.23      0.38      0.29       120
           2       0.34      0.34      0.34       120
           3       0.30      0.17      0.22       120
           4       0.35      0.25      0.29       120
           5       0.40      0.33      0.36       120
           6       0.23      0.33      0.27       120
           7       0.37      0.27      0.31       120
           8       0.28      0.46      0.35       120
           9       0.24      0.22      0.23       120
          10       0.31      0.25      0.28       120
          11       0.41      0.47      0.44       120
          12       0.25      0.37      0.30       120
          13       0.25      0.25      0.25       120
          14       0.34      0.31

Epoch 36/500 [Train]: 100%|██████████| 2794/2794 [00:49<00:00, 55.95it/s, loss=3.2883, acc=36.55%]
Epoch 36/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 240.96it/s, loss=2.7747, acc=29.63%]



Epoch 36/500
Train Loss: 2.3655, Train Acc: 36.55%
Val Loss: 2.6263, Val Acc: 29.63%
Learning rate: 0.000250
New best model saved!

Classification Report:
              precision    recall  f1-score   support

           0       0.27      0.23      0.24       120
           1       0.23      0.28      0.26       120
           2       0.37      0.34      0.35       120
           3       0.45      0.18      0.26       120
           4       0.29      0.27      0.28       120
           5       0.32      0.30      0.31       120
           6       0.27      0.33      0.30       120
           7       0.31      0.33      0.32       120
           8       0.40      0.46      0.43       120
           9       0.20      0.12      0.15       120
          10       0.33      0.28      0.30       120
          11       0.34      0.53      0.42       120
          12       0.26      0.41      0.31       120
          13       0.29      0.42      0.34       120
          14       0.31      0.23

Epoch 37/500 [Train]: 100%|██████████| 2794/2794 [00:49<00:00, 56.66it/s, loss=2.1996, acc=37.64%]
Epoch 37/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 252.89it/s, loss=2.0811, acc=29.56%]



Epoch 37/500
Train Loss: 2.3193, Train Acc: 37.64%
Val Loss: 2.6388, Val Acc: 29.56%
Learning rate: 0.000250


Epoch 38/500 [Train]: 100%|██████████| 2794/2794 [00:50<00:00, 55.82it/s, loss=3.2307, acc=38.05%]
Epoch 38/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 254.13it/s, loss=2.9338, acc=29.55%]



Epoch 38/500
Train Loss: 2.2820, Train Acc: 38.05%
Val Loss: 2.6479, Val Acc: 29.55%
Learning rate: 0.000250


Epoch 39/500 [Train]: 100%|██████████| 2794/2794 [00:50<00:00, 55.74it/s, loss=3.4993, acc=38.83%]
Epoch 39/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 253.02it/s, loss=2.8404, acc=29.80%]



Epoch 39/500
Train Loss: 2.2602, Train Acc: 38.83%
Val Loss: 2.6189, Val Acc: 29.80%
Learning rate: 0.000250
New best model saved!

Classification Report:
              precision    recall  f1-score   support

           0       0.25      0.23      0.24       120
           1       0.32      0.17      0.23       120
           2       0.35      0.38      0.36       120
           3       0.27      0.27      0.27       120
           4       0.37      0.22      0.27       120
           5       0.31      0.30      0.31       120
           6       0.30      0.26      0.28       120
           7       0.30      0.30      0.30       120
           8       0.36      0.37      0.36       120
           9       0.19      0.20      0.19       120
          10       0.32      0.26      0.28       120
          11       0.45      0.38      0.41       120
          12       0.32      0.25      0.28       120
          13       0.33      0.31      0.32       120
          14       0.31      0.28

Epoch 40/500 [Train]: 100%|██████████| 2794/2794 [00:50<00:00, 55.36it/s, loss=2.8471, acc=39.29%]
Epoch 40/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 243.30it/s, loss=2.6457, acc=29.68%]



Epoch 40/500
Train Loss: 2.2316, Train Acc: 39.29%
Val Loss: 2.6319, Val Acc: 29.68%
Learning rate: 0.000250


Epoch 41/500 [Train]: 100%|██████████| 2794/2794 [00:49<00:00, 56.52it/s, loss=2.4965, acc=39.53%]
Epoch 41/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 248.08it/s, loss=3.2374, acc=29.48%]



Epoch 41/500
Train Loss: 2.2294, Train Acc: 39.53%
Val Loss: 2.6564, Val Acc: 29.48%
Learning rate: 0.000250


Epoch 42/500 [Train]: 100%|██████████| 2794/2794 [00:49<00:00, 55.96it/s, loss=2.6048, acc=40.44%]
Epoch 42/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 253.97it/s, loss=2.7186, acc=29.51%]



Epoch 42/500
Train Loss: 2.1932, Train Acc: 40.44%
Val Loss: 2.6616, Val Acc: 29.51%
Learning rate: 0.000250


Epoch 43/500 [Train]: 100%|██████████| 2794/2794 [00:50<00:00, 55.39it/s, loss=3.3755, acc=40.99%]
Epoch 43/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 252.25it/s, loss=2.1829, acc=28.41%]



Epoch 43/500
Train Loss: 2.1764, Train Acc: 40.99%
Val Loss: 2.6881, Val Acc: 28.41%
Learning rate: 0.000250


Epoch 44/500 [Train]: 100%|██████████| 2794/2794 [00:50<00:00, 55.77it/s, loss=2.2184, acc=40.93%]
Epoch 44/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 252.64it/s, loss=3.2302, acc=28.69%]



Epoch 44/500
Train Loss: 2.1800, Train Acc: 40.93%
Val Loss: 2.7071, Val Acc: 28.69%
Learning rate: 0.000250


Epoch 45/500 [Train]: 100%|██████████| 2794/2794 [00:49<00:00, 56.55it/s, loss=2.6037, acc=41.37%]
Epoch 45/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 266.16it/s, loss=1.9871, acc=28.58%]



Epoch 45/500
Train Loss: 2.1540, Train Acc: 41.37%
Val Loss: 2.6980, Val Acc: 28.58%
Learning rate: 0.000125


Epoch 46/500 [Train]: 100%|██████████| 2794/2794 [00:51<00:00, 54.25it/s, loss=3.0389, acc=45.91%]
Epoch 46/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 236.11it/s, loss=2.4886, acc=30.47%]



Epoch 46/500
Train Loss: 1.9675, Train Acc: 45.91%
Val Loss: 2.6701, Val Acc: 30.47%
Learning rate: 0.000125


Epoch 47/500 [Train]: 100%|██████████| 2794/2794 [00:50<00:00, 55.70it/s, loss=1.5663, acc=47.87%]
Epoch 47/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 262.34it/s, loss=2.4349, acc=30.45%]



Epoch 47/500
Train Loss: 1.8849, Train Acc: 47.87%
Val Loss: 2.6640, Val Acc: 30.45%
Learning rate: 0.000125


Epoch 48/500 [Train]: 100%|██████████| 2794/2794 [00:50<00:00, 55.54it/s, loss=3.0441, acc=48.97%]
Epoch 48/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 255.48it/s, loss=2.2935, acc=30.20%]



Epoch 48/500
Train Loss: 1.8308, Train Acc: 48.97%
Val Loss: 2.6922, Val Acc: 30.20%
Learning rate: 0.000125


Epoch 49/500 [Train]: 100%|██████████| 2794/2794 [00:49<00:00, 56.40it/s, loss=2.5115, acc=49.89%]
Epoch 49/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 248.91it/s, loss=3.1359, acc=30.02%]



Epoch 49/500
Train Loss: 1.8006, Train Acc: 49.89%
Val Loss: 2.7162, Val Acc: 30.02%
Learning rate: 0.000125
Early stopping triggered after 49 epochs

Evaluating model on test set...


Evaluating: 100%|██████████| 599/599 [00:01<00:00, 545.95it/s]



Final Test Accuracy: 29.19%

Model saved successfully!

EEG classification pipeline complete!


In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
import scipy.signal as signal
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from tqdm import tqdm
import os
import pandas as pd
from scipy.stats import skew, kurtosis
import seaborn as sns
import gc

# STEP 1: Data Loading and Preprocessing
class EEGDataset(Dataset):
    """Map-style dataset over raw EEG samples.

    Args:
        data_dict: Mapping with a ``'dataset'`` entry (indexable collection of
            sample dicts, each holding ``'eeg_data'`` and ``'label'``) and a
            ``'labels'`` entry (iterable of label IDs used to build the
            label -> index table).
        label_mapping_file: Optional path to a text file whose lines read
            ``<label_id> <human readable text...>`` (whitespace separated).
            Lines with fewer than two tokens are ignored, and a path that does
            not exist is ignored entirely.
        transform: Optional callable applied to ``eeg_data`` in ``__getitem__``.

    Items are ``(eeg_data, label_idx, label_text)`` triples.
    """

    def __init__(self, data_dict, label_mapping_file=None, transform=None):
        self.dataset = data_dict['dataset']
        self.labels_list = data_dict['labels']  # label IDs seen in the data
        self.transform = transform

        self.label_id_to_text = {}
        if label_mapping_file and os.path.exists(label_mapping_file):
            self.label_id_to_text = self._read_label_mapping(label_mapping_file)

        # Indices follow the sorted order of the distinct label IDs, so the
        # assignment is deterministic for a given set of labels.
        self.label_to_idx = {
            label_id: position
            for position, label_id in enumerate(sorted(set(self.labels_list)))
        }

    @staticmethod
    def _read_label_mapping(path):
        """Parse ``<label_id> <text...>`` lines into a dict; later lines win."""
        mapping = {}
        with open(path, 'r') as handle:
            for raw_line in handle:
                tokens = raw_line.split()
                if len(tokens) < 2:
                    continue
                label_id, *words = tokens
                mapping[label_id] = ' '.join(words)
        return mapping

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(eeg_data, label_idx, label_text)`` for sample ``idx``.

        A label ID with no text mapping is returned as its own text; a label ID
        missing from ``label_to_idx`` is returned as index 0.
        """
        record = self.dataset[idx]
        payload = record['eeg_data']
        raw_label = record['label']

        text = self.label_id_to_text.get(raw_label, raw_label)
        # NOTE(review): an ID absent from data_dict['labels'] silently becomes
        # class 0, which would mislabel the sample — confirm this cannot happen.
        index = self.label_to_idx.get(raw_label, 0)

        if self.transform:
            payload = self.transform(payload)
        return payload, index, text

# EEG Signal Preprocessing
class EEGPreprocessor:
    """Zero-phase filtering, re-referencing and z-scoring of one EEG epoch.

    All steps run along the time axis, in this order:
    1. 4th-order Butterworth band-pass (``sosfiltfilt``).
    2. IIR notch (``filtfilt``).
    3. Common-average reference.
    4. Per-channel z-score.

    Args:
        sampling_rate: Sampling frequency in Hz.
        notch_freq: Notch centre frequency in Hz (power-line frequency).
        bandpass_low: Lower band-pass cutoff in Hz.
        bandpass_high: Upper band-pass cutoff in Hz. Must be below
            ``sampling_rate / 2``; ``scipy.signal.butter`` raises otherwise.
        notch_q: Quality factor of the notch filter. The default of 30 is the
            value that was previously hard-coded.
    """

    def __init__(self, sampling_rate=1000, notch_freq=50, bandpass_low=0.5,
                 bandpass_high=70, notch_q=30):
        self.sampling_rate = sampling_rate
        self.notch_freq = notch_freq
        self.bandpass_low = bandpass_low
        self.bandpass_high = bandpass_high
        self.notch_q = notch_q

        # Coefficients depend only on the configuration, so design them once
        self.sos = signal.butter(
            N=4,
            Wn=[self.bandpass_low, self.bandpass_high],
            btype='bandpass',
            fs=self.sampling_rate,
            output='sos'
        )
        self.b_notch, self.a_notch = signal.iirnotch(
            self.notch_freq, self.notch_q, self.sampling_rate
        )

    def __call__(self, eeg_data):
        """Preprocess one epoch.

        Args:
            eeg_data: ``[channels, time_points]`` array-like or ``torch.Tensor``.
                Tensors that require grad or live on another device are
                accepted.

        Returns:
            CPU ``torch.float32`` tensor of shape ``[channels, time_points]``.
        """
        if isinstance(eeg_data, torch.Tensor):
            # Tensor.numpy() refuses tensors that require grad or are not on the CPU
            eeg_data = eeg_data.detach().cpu().numpy()
        # Accept plain lists / other array-likes as well as ndarrays
        eeg_data = np.asarray(eeg_data)

        # Work in [time, channels] so every step operates along axis 0
        data = eeg_data.T
        data = self._bandpass_filter(data)
        data = self._notch_filter(data)  # remove power-line interference
        data = self._common_average_reference(data)
        data = self._normalize(data)

        # Back to [channels, time]
        return torch.tensor(data.T, dtype=torch.float32)

    def _bandpass_filter(self, data):
        """Forward-backward band-pass along axis 0 (zero phase distortion)."""
        return signal.sosfiltfilt(self.sos, data, axis=0)

    def _notch_filter(self, data):
        """Forward-backward notch along axis 0."""
        return signal.filtfilt(self.b_notch, self.a_notch, data, axis=0)

    def _common_average_reference(self, data):
        """Subtract the across-channel mean at each time point ([time, channels] input)."""
        return data - np.mean(data, axis=1, keepdims=True)

    def _normalize(self, data):
        """Z-score each channel (column) over time; 1e-10 guards flat channels."""
        return (data - np.mean(data, axis=0)) / (np.std(data, axis=0) + 1e-10)

# STEP 2: Feature Extraction
class FeatureExtractor:
    """Hand-crafted time, frequency and connectivity features for one EEG epoch.

    ``extract_features`` returns a flat vector whose layout follows the
    insertion order of the feature dictionaries built below. Every
    per-channel feature contributes ``n_channels`` values; every connectivity
    summary contributes one scalar.

    Args:
        sampling_rate: Sampling frequency in Hz, passed to Welch's method.
        freq_bands: Optional ``{name: (low_hz, high_hz)}`` mapping overriding
            ``DEFAULT_FREQ_BANDS``. Each band is the half-open interval
            ``[low_hz, high_hz)``.
    """

    # NOTE(review): gamma_mid / gamma_high lie above the 70 Hz default upper
    # cutoff of EEGPreprocessor. If the input was band-passed with that
    # default, these bands mostly measure filter roll-off — confirm intent.
    DEFAULT_FREQ_BANDS = {
        'delta': (1, 4),         # Adjusted lower bound to reduce DC components
        'theta': (4, 8),         # Standard theta for cognitive processing
        'alpha_low': (8, 10),    # Lower alpha - attention/inhibition
        'alpha_high': (10, 13),  # Higher alpha - semantic processing
        'beta_low': (13, 20),    # Lower beta - motor preparation
        'beta_high': (20, 30),   # Higher beta - active processing/cognition
        'gamma_low': (30, 60),   # Expanded gamma_low for visual processing
        'gamma_mid': (60, 90),   # Added gamma_mid for binding
        'gamma_high': (90, 120), # Higher gamma for fine perceptual binding
    }

    def __init__(self, sampling_rate=1000, freq_bands=None):
        self.sampling_rate = sampling_rate
        # Copy so per-instance edits never leak into the class-level default
        source = self.DEFAULT_FREQ_BANDS if freq_bands is None else freq_bands
        self.freq_bands = dict(source)

    def extract_features(self, eeg_data):
        """
        Extract time and frequency domain features from EEG data

        Args:
            eeg_data: ndarray of shape [channels, time_points]. At least 4 time
                points are required (the Welch window is time_points // 4), and
                the connectivity features assume at least 2 channels.

        Returns:
            1-D ndarray: time-domain, frequency-domain, then connectivity
            features, concatenated in dictionary insertion order.
        """
        features = {}
        features.update(self._extract_time_domain_features(eeg_data))
        features.update(self._extract_frequency_domain_features(eeg_data))
        features.update(self._extract_connectivity_features(eeg_data))

        # np.ravel turns scalars into length-1 arrays and flattens per-channel arrays
        return np.concatenate([np.ravel(value) for value in features.values()])

    def _extract_time_domain_features(self, eeg_data):
        """Per-channel statistics along the time axis, plus Hjorth parameters."""
        features = {}

        features['mean'] = np.mean(eeg_data, axis=1)
        features['var'] = np.var(eeg_data, axis=1)
        features['skewness'] = skew(eeg_data, axis=1)
        features['kurtosis'] = kurtosis(eeg_data, axis=1)
        features['max'] = np.max(eeg_data, axis=1)
        features['min'] = np.min(eeg_data, axis=1)
        features['peak_to_peak'] = features['max'] - features['min']  # Reuse computed values
        features['rms'] = np.sqrt(np.mean(np.square(eeg_data), axis=1))
        # np.diff on a boolean array is "not equal", so this counts sign changes
        features['zero_crossings'] = np.sum(np.diff(np.signbit(eeg_data), axis=1), axis=1)

        features.update(self._compute_hjorth_parameters(eeg_data))

        return features

    def _compute_hjorth_parameters(self, eeg_data):
        """Compute Hjorth parameters: Activity, Mobility, and Complexity"""
        features = {}

        diff1 = np.diff(eeg_data, axis=1)  # first difference
        diff2 = np.diff(diff1, axis=1)     # second difference

        # Activity: variance of the signal
        features['activity'] = np.var(eeg_data, axis=1)

        # Mobility: sqrt(variance of first derivative / variance of signal)
        var_diff1 = np.var(diff1, axis=1)
        mobility1 = np.sqrt(var_diff1 / (features['activity'] + 1e-10))
        features['mobility'] = mobility1

        # Complexity: mobility of first derivative / mobility of signal
        var_diff2 = np.var(diff2, axis=1)
        mobility2 = np.sqrt(var_diff2 / (var_diff1 + 1e-10))
        features['complexity'] = mobility2 / (mobility1 + 1e-10)

        return features

    def _extract_frequency_domain_features(self, eeg_data):
        """Welch-PSD band powers and spectral shape descriptors, per channel."""
        features = {}

        # Adaptive window: at most 256 samples and at most a quarter of the epoch
        nperseg = min(256, eeg_data.shape[1] // 4)
        freqs, psd = signal.welch(eeg_data, fs=self.sampling_rate,
                                  nperseg=nperseg,
                                  noverlap=nperseg // 2,
                                  axis=1)

        total_power = np.sum(psd, axis=1) + 1e-10

        for band_name, (low_freq, high_freq) in self.freq_bands.items():
            # Half-open [low, high). Adjacent bands share an edge (e.g. 4 Hz).
            # With a 4 Hz resolution (1000 Hz sampling, nperseg=250) a closed
            # interval would sum that bin into both neighbouring bands.
            in_band = (freqs >= low_freq) & (freqs < high_freq)
            band_power = np.sum(psd[:, in_band], axis=1)
            features[f'{band_name}_power'] = band_power
            features[f'{band_name}_rel_power'] = band_power / total_power

        features['sef_95'] = self._compute_spectral_edge_frequency(freqs, psd, 0.95)
        features['spectral_entropy'] = self._compute_spectral_entropy(psd)
        features['peak_freq'] = freqs[np.argmax(psd, axis=1)]
        features['peak_power'] = np.max(psd, axis=1)

        return features

    def _compute_spectral_edge_frequency(self, freqs, psd, edge=0.95):
        """Per channel, the first frequency at which cumulative power reaches
        ``edge`` of the total; ``freqs[-1]`` if it is never reached."""
        cumulative = np.cumsum(psd, axis=1) / (np.sum(psd, axis=1, keepdims=True) + 1e-10)
        reached = cumulative >= edge
        # argmax on a boolean row gives the first True; guard rows with none
        first_idx = np.where(reached.any(axis=1), reached.argmax(axis=1), len(freqs) - 1)
        return freqs[first_idx]

    def _compute_spectral_entropy(self, psd):
        """Shannon entropy (bits) of each channel's PSD normalised to sum to 1."""
        psd_norm = psd / (np.sum(psd, axis=1, keepdims=True) + 1e-10)
        return -np.sum(psd_norm * np.log2(psd_norm + 1e-10), axis=1)

    def _extract_connectivity_features(self, eeg_data):
        """Extract connectivity features between EEG channels"""
        features = {}
        n_channels = eeg_data.shape[0]

        # Pairwise Pearson correlations, each unordered pair once (upper triangle)
        corr_matrix = np.corrcoef(eeg_data)
        correlations = corr_matrix[np.triu_indices(n_channels, k=1)]

        # NOTE: np.corrcoef yields NaN for a zero-variance channel, and those
        # NaNs propagate into these summaries.
        features['mean_corr'] = np.mean(correlations)
        features['std_corr'] = np.std(correlations)
        features['max_corr'] = np.max(correlations)
        features['min_corr'] = np.min(correlations)

        # Phase locking value between adjacent rows (channel i vs i + 1),
        # using the instantaneous phase of the Hilbert analytic signal
        instantaneous_phase = np.angle(signal.hilbert(eeg_data, axis=1))
        phase_diff = np.diff(instantaneous_phase, axis=0)
        plv_values = np.abs(np.mean(np.exp(1j * phase_diff), axis=1))
        features['mean_plv'] = np.mean(plv_values)
        features['std_plv'] = np.std(plv_values)

        return features

# Dataset with precomputed features
class EEGFeatureDataset(Dataset):
    """Map-style dataset over precomputed feature vectors.

    Items are ``(feature, label, text)`` triples. ``texts`` is either a dict
    keyed by label, or a sequence aligned with ``features`` by position; any
    missing text becomes ``"Unknown-<label>"``. Non-tensor features are turned
    into float32 tensors, and labels that are neither tensors nor Python ints
    (e.g. NumPy integers) are turned into int64 tensors.
    """

    def __init__(self, features, labels, texts):
        self.features = features
        self.labels = labels
        self.texts = texts

    def __len__(self):
        return len(self.features)

    def _lookup_text(self, idx, label):
        """Resolve the display text for item ``idx`` whose label is ``label``."""
        fallback = f"Unknown-{label}"
        if isinstance(self.texts, dict):
            return self.texts.get(label, fallback)
        return self.texts[idx] if idx < len(self.texts) else fallback

    def __getitem__(self, idx):
        feature, label = self.features[idx], self.labels[idx]
        text = self._lookup_text(idx, label)

        if not isinstance(feature, torch.Tensor):
            feature = torch.tensor(feature, dtype=torch.float32)
        # Python ints are left alone; the default collate_fn batches them.
        if not isinstance(label, (torch.Tensor, int)):
            label = torch.tensor(label, dtype=torch.long)

        return feature, label, text

# STEP 3: Classification Model
class EEGClassifier(nn.Module):
    """Dense (optionally CNN + BiLSTM front-end) classifier for EEG inputs.

    Two input modes:

    * Flat mode (default; used unless both ``seq_length`` and ``n_channels``
      are given): ``x`` is ``[batch, input_dim]``. It is batch-normalised,
      gated element-wise by a sigmoid bottleneck MLP, then passed through the
      dense stack.
    * Sequence mode (``seq_length`` and ``n_channels`` both given): ``x`` is
      reshaped to ``[batch, 1, n_channels, seq_length]`` (so ``input_dim``
      should equal ``n_channels * seq_length``), passed through a temporal and
      a spatial convolution, a 2-layer bidirectional LSTM and self-attention
      pooling into a 256-d vector, then through the dense stack.

    The dense stack uses ``ResidualBlock`` where consecutive widths match and
    ``DenseBlock`` otherwise, with an ``SEBlock`` after every hidden layer
    except the last. Logits are the mean of three independent linear heads.

    Args:
        input_dim: Length of each input vector.
        n_classes: Number of output logits.
        hidden_dims: Widths of the dense hidden layers (any sequence of ints).
        seq_length, n_channels: Enable sequence mode when both are set.
        dropout_rate: Dropout probability used throughout.

    Raises:
        ValueError: In sequence mode, if ``seq_length`` is too short to leave
            at least one time step after the convolution/pooling front-end.
    """

    def __init__(self, input_dim, n_classes, hidden_dims=(4096, 2048, 1024),
                 seq_length=None, n_channels=None, dropout_rate=0.3):
        super(EEGClassifier, self).__init__()

        # Tuple default avoids the shared-mutable-default pitfall; callers may
        # still pass any sequence (e.g. a list).
        hidden_dims = list(hidden_dims)

        self.input_dim = input_dim
        self.n_classes = n_classes
        self.seq_length = seq_length
        self.n_channels = n_channels

        # Sequence mode only when both seq_length and n_channels are provided.
        self.reshape_input = seq_length is not None and n_channels is not None

        # Input normalization. Only used in flat mode; it stays registered in
        # sequence mode so existing checkpoints keep the same state_dict keys.
        self.input_norm = nn.BatchNorm1d(input_dim)

        # PART 1: CNN + LSTM FEATURE EXTRACTION (sequence mode only)
        if self.reshape_input:
            # Temporal conv (1x16, stride 2), then a spatial conv spanning all
            # n_channels rows (collapses the channel axis to height 1), then pooling.
            self.conv_block = nn.Sequential(
                nn.Conv2d(1, 32, kernel_size=(1, 16), stride=(1, 2), padding=(0, 7)),
                nn.BatchNorm2d(32),
                nn.ELU(),
                nn.Conv2d(32, 64, kernel_size=(n_channels, 1), stride=1, padding=0),
                nn.BatchNorm2d(64),
                nn.ELU(),
                nn.AvgPool2d(kernel_size=(1, 4), stride=(1, 4)),
                nn.Dropout(dropout_rate)
            )

            # Fail fast here rather than inside the pooling/LSTM in forward().
            self.conv_output_size = self._calculate_conv_output_size()
            if self.conv_output_size <= 0:
                raise ValueError(
                    f"seq_length={seq_length} is too short for the convolutional front-end"
                )

            # LSTM over time; each step carries the 64 CNN output channels.
            self.lstm = nn.LSTM(
                input_size=64,
                hidden_size=128,
                num_layers=2,
                batch_first=True,
                dropout=dropout_rate,
                bidirectional=True
            )

            # Attention pooling over time; 256 = 128 * 2 (bidirectional).
            self.attention = SelfAttention(256)

            dense_input_dim = 256
        else:
            # Flat mode: sigmoid gate computed from the input itself.
            self.attention = nn.Sequential(
                nn.Linear(input_dim, input_dim // 4),
                nn.LeakyReLU(0.2),
                nn.Linear(input_dim // 4, input_dim),
                nn.Sigmoid()
            )
            dense_input_dim = input_dim

        # PART 2: DENSE NETWORK PATHWAY
        layers = []
        prev_dim = dense_input_dim

        for i, hidden_dim in enumerate(hidden_dims):
            # Residual connection only when the width is unchanged.
            if prev_dim == hidden_dim:
                layers.append(ResidualBlock(prev_dim, hidden_dim, dropout_rate))
            else:
                layers.append(DenseBlock(prev_dim, hidden_dim, dropout_rate))

            prev_dim = hidden_dim

            # Squeeze-and-Excitation recalibration on all but the last hidden layer.
            if i < len(hidden_dims) - 1:
                layers.append(SEBlock(hidden_dim))

        self.feature_layers = nn.Sequential(*layers)

        # Three output heads whose logits are averaged in forward().
        self.heads = nn.ModuleList([
            nn.Linear(prev_dim, n_classes) for _ in range(3)
        ])

        self.apply(self._init_weights)

    def _calculate_conv_output_size(self):
        """Return ``64 * time_steps`` remaining after ``conv_block`` (sequence mode).

        Conv: kernel 16, stride 2, padding 7 -> (L + 2*7 - 16) // 2 + 1.
        Pool: kernel 4, stride 4 -> floor division by 4.
        """
        length_after_conv = ((self.seq_length - 16 + 2 * 7) // 2) + 1
        length_after_pool = length_after_conv // 4
        return 64 * length_after_pool  # 64 output channels

    def _init_weights(self, m):
        """Kaiming-normal (fan_out) for Linear/Conv2d, zero bias; BN to identity."""
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return ``[batch, n_classes]`` logits for input ``x``.

        Note: in training mode the BatchNorm layers need more than one sample
        per batch.
        """
        if not self.reshape_input:
            x = self.input_norm(x)

            # Element-wise sigmoid gating.
            attn = self.attention(x)
            x = x * attn

            features = self.feature_layers(x)
        else:
            # reshape (not view) so non-contiguous inputs are accepted too.
            batch_size = x.size(0)
            x = x.reshape(batch_size, 1, self.n_channels, self.seq_length)

            x = self.conv_block(x)  # -> [batch, 64, 1, reduced_time]

            # LSTM expects [batch, time, features].
            x = x.squeeze(2).permute(0, 2, 1)  # -> [batch, reduced_time, 64]

            x, _ = self.lstm(x)  # -> [batch, reduced_time, 256]

            x, _ = self.attention(x)  # -> [batch, 256]

            features = self.feature_layers(x)

        # Average the logits of the three heads.
        logits = torch.stack([head(features) for head in self.heads])
        logits = torch.mean(logits, dim=0)

        return logits

    def predict_proba(self, x):
        """Return softmax class probabilities, shape ``[batch, n_classes]``."""
        logits = self.forward(x)
        return torch.softmax(logits, dim=1)


# Helper blocks for enhanced architecture

class ResidualBlock(nn.Module):
    """Linear -> BatchNorm1d -> LeakyReLU(0.2) -> Dropout, plus a skip connection.

    Output is ``f(x) + shortcut(x)``. When ``in_dim == out_dim`` the shortcut
    is the identity (the original behaviour). When they differ, a bias-free
    linear projection maps ``x`` to ``out_dim`` so the sum is well-defined;
    previously that case failed at the addition in ``forward``.
    """

    def __init__(self, in_dim, out_dim, dropout_rate=0.4):
        super(ResidualBlock, self).__init__()
        self.linear = nn.Linear(in_dim, out_dim)
        self.norm = nn.BatchNorm1d(out_dim)
        self.activation = nn.LeakyReLU(0.2)
        self.dropout = nn.Dropout(dropout_rate)
        # nn.Identity has no parameters, so for in_dim == out_dim the
        # state_dict is unchanged and existing checkpoints still load.
        if in_dim == out_dim:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Linear(in_dim, out_dim, bias=False)

    def forward(self, x):
        """Return the residual output; input is ``[batch, in_dim]``, output ``[batch, out_dim]``."""
        out = self.linear(x)
        out = self.norm(out)
        out = self.activation(out)
        out = self.dropout(out)
        # Out-of-place add: never mutate an intermediate autograd tensor in place.
        return out + self.shortcut(x)


class DenseBlock(nn.Module):
    """Fully connected stage: Linear -> BatchNorm1d -> LeakyReLU(0.2) -> Dropout.

    Maps ``[batch, in_dim]`` to ``[batch, out_dim]``; no skip connection.
    """

    def __init__(self, in_dim, out_dim, dropout_rate=0.4):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)
        self.norm = nn.BatchNorm1d(out_dim)
        self.activation = nn.LeakyReLU(0.2)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        # Run the stages in order: project, normalise, activate, drop.
        for stage in (self.linear, self.norm, self.activation, self.dropout):
            x = stage(x)
        return x


class SEBlock(nn.Module):
    """Squeeze-and-Excitation-style gating for flat ``[batch, channel]`` features.

    Each feature is multiplied by a sigmoid gate computed from the whole
    feature vector through a bottleneck MLP
    (``channel -> max(1, channel // reduction) -> channel``).

    For 2-D input there is no spatial/temporal axis to squeeze: average
    pooling over a length-1 axis is the identity, so ``forward`` feeds ``x``
    directly to the gating MLP. ``avg_pool`` is kept as an attribute for
    backward compatibility (it has no parameters).
    """

    def __init__(self, channel, reduction=16):
        super(SEBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        # Clamp to >= 1: for channel < reduction, channel // reduction == 0
        # gives a zero-width bottleneck whose gate is a constant 0.5.
        bottleneck = max(1, channel // reduction)
        self.fc = nn.Sequential(
            nn.Linear(channel, bottleneck, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(bottleneck, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Gate ``x`` feature-wise; input and output are ``[batch, channel]``."""
        if x.dim() != 2:
            raise ValueError(
                f"SEBlock expects a 2-D [batch, channel] tensor, got shape {tuple(x.shape)}"
            )
        return x * self.fc(x)


class SelfAttention(nn.Module):
    """Scaled dot-product self-attention that pools a sequence to one vector.

    ``forward(x)`` takes ``x`` of shape ``[batch, seq_len, hidden_dim]`` and
    returns ``(pooled, weights)``:
      * ``pooled``  -- ``[batch, hidden_dim]``, the attended values summed over
        the sequence axis;
      * ``weights`` -- ``[batch, seq_len, seq_len]``, row-wise softmax scores.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.query = nn.Linear(hidden_dim, hidden_dim)
        self.key = nn.Linear(hidden_dim, hidden_dim)
        self.value = nn.Linear(hidden_dim, hidden_dim)
        # 1/sqrt(d) scaling keeps dot products in a softmax-friendly range.
        self.scale = hidden_dim ** 0.5

    def forward(self, x):
        queries, keys, values = self.query(x), self.key(x), self.value(x)

        scores = torch.bmm(queries, keys.transpose(1, 2)) / self.scale
        weights = scores.softmax(dim=-1)

        attended = torch.bmm(weights, values)
        pooled = attended.sum(dim=1)
        return pooled, weights

# STEP 4: Data Preprocessing Helper - OPTIMIZED FOR MEMORY
def prepare_dataset_with_features(dataset, batch_size=64, device='cuda'):
    """Pre-compute a feature vector for every sample of ``dataset``.

    Samples are processed in chunks of ``batch_size`` only to bound peak
    memory. After each chunk, garbage collection runs, and the CUDA cache is
    flushed when ``device`` is a CUDA device. Feature extraction itself runs
    on NumPy arrays via ``FeatureExtractor.extract_features``.

    Args:
        dataset: Indexable, sized dataset whose items are
            ``(eeg_data, label_idx, label_text)``.
        batch_size: Samples per chunk.
        device: ``str`` or ``torch.device``; only its type is inspected.
            (Comparing a ``torch.device`` to the string ``'cuda'`` is always
            False, so the previous check never flushed the cache.)

    Returns:
        ``(features, labels, label_texts)``:
          * ``features``: list of float32 tensors, one per sample, in dataset order;
          * ``labels``: list of the label indices, in dataset order;
          * ``label_texts``: dict ``label_idx -> label_text``. Falls back to
            ``dataset.get_label_texts()`` when no samples were seen and that
            method exists.
    """
    feature_extractor = FeatureExtractor()

    device_type = getattr(device, 'type', str(device).split(':', 1)[0])
    use_cuda = device_type == 'cuda' and torch.cuda.is_available()

    processed_features = []
    labels = []
    label_texts = {}

    n_samples = len(dataset)
    for start_idx in tqdm(range(0, n_samples, batch_size), desc="Extracting features in batches"):
        end_idx = min(start_idx + batch_size, n_samples)

        for idx in range(start_idx, end_idx):
            eeg_data, label_idx, label_text = dataset[idx]

            # Feature extraction works on NumPy arrays; detach in case the
            # transform produced a tensor that requires grad.
            if isinstance(eeg_data, torch.Tensor):
                eeg_data = eeg_data.detach().cpu().numpy()

            features = feature_extractor.extract_features(eeg_data)

            processed_features.append(torch.tensor(features, dtype=torch.float32))
            labels.append(label_idx)
            label_texts[label_idx] = label_text

        # Release per-chunk temporaries before the next chunk.
        gc.collect()
        if use_cuda:
            torch.cuda.empty_cache()

    if not label_texts and hasattr(dataset, 'get_label_texts'):
        label_texts = dataset.get_label_texts()

    return processed_features, labels, label_texts

# STEP 5: Training Function - MEMORY OPTIMIZED
def train_model(model, train_loader, val_loader, num_epochs=500, learning_rate=0.001,
               weight_decay=1e-5, device="cpu", class_weights=None, patience=10):
    """Train ``model`` and return it with its best validation-loss weights restored.

    Training setup:
      * Adam with L2 ``weight_decay``;
      * cross-entropy loss, optionally class-weighted;
      * gradient-norm clipping at 1.0;
      * ``ReduceLROnPlateau`` on validation loss (factor 0.5, patience 5);
      * early stopping after ``patience`` epochs without a new best
        validation loss.

    Args:
        model: Module mapping a feature batch to ``[batch, n_classes]`` logits.
        train_loader, val_loader: Iterables yielding ``(features, labels, _)``.
        num_epochs: Maximum number of epochs.
        learning_rate, weight_decay: Adam hyper-parameters.
        device: ``str`` or ``torch.device`` to train on.
        class_weights: Optional per-class loss weights (array-like or tensor).
        patience: Early-stopping patience in epochs (default 10, as before).

    Returns:
        ``model``, on ``device``, holding the weights of the epoch with the
        lowest validation loss. If no epoch ever improved on ``inf`` (e.g.
        ``num_epochs == 0`` or every validation loss was NaN), its current
        weights are kept.

    Side effects:
        Prints progress and per-epoch metrics, plus a classification report
        on each new best epoch. Writes ``training_curves.png``.
    """
    # Accept str or torch.device. A torch.device never compares equal to the
    # string 'cuda', so the previous `device == 'cuda'` checks were always False.
    device_type = getattr(device, 'type', str(device).split(':', 1)[0])
    use_cuda = device_type == 'cuda' and torch.cuda.is_available()

    model.to(device)

    # Adam with weight decay for regularization.
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    # `verbose` is not passed: it is deprecated for LR schedulers in recent
    # PyTorch releases, and the current LR is printed every epoch below.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5
    )

    if class_weights is not None:
        # as_tensor accepts lists, NumPy arrays and tensors alike.
        weight_tensor = torch.as_tensor(class_weights, dtype=torch.float32).to(device)
        criterion = nn.CrossEntropyLoss(weight=weight_tensor)
    else:
        criterion = nn.CrossEntropyLoss()

    train_losses = []
    val_losses = []
    train_accs = []
    val_accs = []

    # Early-stopping state.
    best_val_loss = float('inf')
    no_improve_epoch = 0
    best_model_state = None

    for epoch in range(num_epochs):
        # ---- Training phase ----
        model.train()
        train_loss = 0.0
        correct = 0
        total = 0

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]")

        for features, labels, _ in progress_bar:
            features = features.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(features)
            loss = criterion(outputs, labels)
            loss.backward()

            # Gradient clipping to prevent exploding gradients.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()

            _, predicted = torch.max(outputs.detach(), 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            train_loss += loss.item()

            progress_bar.set_postfix({
                'loss': f"{loss.item():.4f}",
                'acc': f"{100 * correct / total:.2f}%"
            })
            # No per-batch torch.cuda.empty_cache(): releasing cached blocks
            # every batch only forces them to be re-allocated on the next one.
            # The cache is flushed once per epoch instead (see below).

        epoch_train_loss = train_loss / len(train_loader)
        epoch_train_acc = 100 * correct / total
        train_losses.append(epoch_train_loss)
        train_accs.append(epoch_train_acc)

        # ---- Validation phase ----
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        all_preds = []
        all_labels = []

        with torch.no_grad():
            progress_bar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Val]")

            for features, labels, _ in progress_bar:
                features = features.to(device)
                labels = labels.to(device)

                outputs = model(features)
                loss = criterion(outputs, labels)

                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                val_loss += loss.item()

                all_preds.extend(predicted.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

                progress_bar.set_postfix({
                    'loss': f"{loss.item():.4f}",
                    'acc': f"{100 * correct / total:.2f}%"
                })

        epoch_val_loss = val_loss / len(val_loader)
        epoch_val_acc = 100 * correct / total
        val_losses.append(epoch_val_loss)
        val_accs.append(epoch_val_acc)

        scheduler.step(epoch_val_loss)

        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print(f"Train Loss: {epoch_train_loss:.4f}, Train Acc: {epoch_train_acc:.2f}%")
        print(f"Val Loss: {epoch_val_loss:.4f}, Val Acc: {epoch_val_acc:.2f}%")
        print(f"Learning rate: {optimizer.param_groups[0]['lr']:.6f}")

        if epoch_val_loss < best_val_loss:
            best_val_loss = epoch_val_loss
            # clone() is required: on CPU, .cpu() returns the same tensor, so
            # without a copy the snapshot would alias the live parameters and
            # be silently overwritten by later optimizer steps.
            best_model_state = {
                k: v.detach().cpu().clone() for k, v in model.state_dict().items()
            }
            no_improve_epoch = 0
            print("New best model saved!")

            print("\nClassification Report:")
            # zero_division=0: classes never predicted yet get 0 instead of warnings.
            print(classification_report(all_labels, all_preds, zero_division=0))
        else:
            no_improve_epoch += 1

        if no_improve_epoch >= patience:
            print(f"Early stopping triggered after {epoch+1} epochs")
            break

        # Release Python garbage and cached GPU blocks once per epoch.
        gc.collect()
        if use_cuda:
            torch.cuda.empty_cache()

    # Restore the best weights, if any epoch produced them.
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    model.to(device)

    # ---- Training curves ----
    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(train_accs, label='Train Accuracy')
    plt.plot(val_accs, label='Validation Accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy (%)')
    plt.title('Training and Validation Accuracy')
    plt.legend()

    plt.tight_layout()
    plt.savefig('training_curves.png')
    plt.close()

    return model

# STEP 6: Evaluation Function - MEMORY OPTIMIZED
def evaluate_model(model, test_loader, device='cpu'):
    """Evaluate ``model`` on ``test_loader`` without gradient tracking.

    The model is switched to eval mode but not moved; it must already live on
    ``device``. ``test_loader`` yields ``(features, labels, _)`` batches.

    Args:
        model: Module returning ``[batch, n_classes]`` logits.
        test_loader: Iterable of ``(features, labels, _)`` batches.
        device: ``str`` or ``torch.device`` the features are moved to.

    Returns:
        ``(accuracy, report, cm, probs)``:
          * ``accuracy``: fraction in ``[0, 1]``;
          * ``report``: the sklearn classification report string;
          * ``cm``: the confusion matrix;
          * ``probs``: array of softmax probabilities, one row per sample in
            loader order.
        The confusion matrix is returned, not plotted.
    """
    # A torch.device never equals the string 'cuda'; compare device types instead.
    device_type = getattr(device, 'type', str(device).split(':', 1)[0])
    use_cuda = device_type == 'cuda' and torch.cuda.is_available()

    model.eval()
    all_preds = []
    all_labels = []
    all_probs = []

    with torch.no_grad():
        for features, labels, _ in tqdm(test_loader, desc="Evaluating"):
            outputs = model(features.to(device))
            probs = torch.softmax(outputs, dim=1)
            _, predicted = torch.max(outputs, 1)

            # Move results to host memory so GPU memory is not accumulated.
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())

    # One cache flush after inference instead of one per batch.
    if use_cuda:
        torch.cuda.empty_cache()

    accuracy = accuracy_score(all_labels, all_preds)
    # zero_division=0 avoids UndefinedMetricWarning for never-predicted classes.
    report = classification_report(all_labels, all_preds, zero_division=0)
    cm = confusion_matrix(all_labels, all_preds)

    return accuracy, report, cm, np.array(all_probs)

def _is_valid_feature_cache(features_data):
    """Return True if a loaded feature cache contains every key this pipeline writes."""
    return all(key in features_data for key in ('features', 'labels', 'texts'))


def _ensure_session_features(session_file, session_name, session_feature_path,
                             label_mapping_file, device):
    """Make sure ``session_feature_path`` holds extracted features for one session.

    A valid cache file is reused. An invalid one is deleted and the features
    are re-extracted immediately. Previously the file was only deleted, so
    that session was silently missing from the combined set until the next run.

    Returns:
        False if a valid cache already existed, otherwise True (the combined
        cache must be rebuilt). Extraction errors are printed and swallowed
        (best effort) so the remaining sessions are still processed.
    """
    if os.path.exists(session_feature_path):
        print(f"Loading pre-computed features from {session_feature_path}...")
        features_data = torch.load(session_feature_path)
        if _is_valid_feature_cache(features_data):
            print(f"Loaded features with shape {len(features_data['features'])}")
            return False
        print(f"Invalid feature file format for {session_feature_path}. Will recompute.")
        del features_data
        os.remove(session_feature_path)

    print(f"Loading data from {session_file}...")
    try:
        # weights_only=False unpickles arbitrary objects: only load trusted files.
        # map_location='cpu' keeps the raw data off the GPU.
        session_data = torch.load(session_file, map_location='cpu', weights_only=False)
        print(f"Loaded {len(session_data['dataset'])} EEG samples with "
              f"{len(set(session_data['labels']))} unique classes")

        # NOTE(review): EEGDataset builds its label -> index mapping from this
        # session's labels only. Indices are therefore comparable across
        # sessions only if every session has the same label set -- confirm
        # before relying on the combined features.
        preprocessor = EEGPreprocessor(sampling_rate=1000)  # Adjust sampling rate to match your data
        dataset = EEGDataset(session_data, label_mapping_file, transform=preprocessor)

        # Free the raw data; the dataset keeps what it needs.
        del session_data
        gc.collect()

        if device.type == 'cuda':
            torch.cuda.empty_cache()
            print(f"GPU memory after data loading: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")

        print(f"Preprocessing and extracting features for {session_name} in batches...")
        features, labels, texts = prepare_dataset_with_features(dataset, batch_size=32, device=device)

        print(f"Saving extracted features to {session_feature_path}...")
        torch.save({
            'features': features,
            'labels': labels,
            'texts': texts
        }, session_feature_path)
        print("Features saved successfully!")

        del dataset, features, labels, texts
        gc.collect()

        if device.type == 'cuda':
            torch.cuda.empty_cache()
            print(f"GPU memory after feature extraction: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")

    except Exception as e:
        # Best effort: report and continue with the remaining sessions.
        print(f"Error processing {session_file}: {e}")

    return True


def _combine_session_features(session_feature_paths, device):
    """Concatenate the per-session caches that exist, in the given order.

    Returns ``(features, labels, texts)``: combined feature list, combined
    label list, and a merged ``label -> text`` dict. The dict is merged
    directly if a cache stores texts as a dict; otherwise the texts are
    paired with the labels position by position.
    """
    all_features = []
    all_labels = []
    all_texts = {}

    for session_feature_path in session_feature_paths:
        if not os.path.exists(session_feature_path):
            continue

        print(f"Loading features from {session_feature_path} for combining...")
        features_data = torch.load(session_feature_path)

        all_features.extend(features_data['features'])
        all_labels.extend(features_data['labels'])

        texts = features_data['texts']
        if isinstance(texts, dict):
            all_texts.update(texts)
        else:
            # List format: texts[j] belongs to labels[j].
            for label, text in zip(features_data['labels'], texts):
                all_texts[label] = text

        del features_data
        gc.collect()

        if device.type == 'cuda':
            torch.cuda.empty_cache()

    return all_features, all_labels, all_texts


def _compute_class_weights(labels, n_classes):
    """Return inverse-frequency class weights as a float64 array of length >= ``n_classes``.

    Weights average to 1 over the classes that occur. A class with no samples
    gets weight 0 instead of the inf/NaN that ``1 / 0`` would produce, which
    would otherwise make every weight (and the loss) NaN.
    """
    counts = np.bincount(np.asarray(labels, dtype=np.int64), minlength=n_classes)
    weights = np.zeros(len(counts), dtype=np.float64)
    present = counts > 0
    weights[present] = 1.0 / counts[present]
    return weights / np.sum(weights) * np.count_nonzero(present)


def main():
    """Run the end-to-end EEG classification pipeline.

    Steps:
      1. Seed the RNGs and select the device.
      2. For each session, reuse or (re)build ``features_session_N.pt``.
      3. Rebuild ``features_combined.pt`` when any session was (re)extracted
         or the combined file is missing; otherwise load it.
      4. Make a stratified 70/15/15 train/val/test split and train a
         class-weighted ``EEGClassifier``.
      5. Evaluate on the test split and save ``eeg_classifier_model.pt``.
    """
    # Reproducibility.
    torch.manual_seed(42)
    np.random.seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(42)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    if device.type == 'cuda':
        torch.cuda.reset_peak_memory_stats()
        print(f"Initial GPU memory: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")

    label_mapping_file = 'labels_txt.txt'

    features_path = {
        'session_1': "features_session_1.pt",
        'session_2': "features_session_2.pt",
        'combined': "features_combined.pt"
    }

    session_files = [
        "/content/drive/MyDrive/session_1.pth",
        "/content/drive/MyDrive/session_2.pth"
    ]
    session_names = [f"session_{i+1}" for i in range(len(session_files))]

    # Per-session caches. Every session is visited, even once a rebuild is
    # known to be needed.
    need_to_combine = False
    for session_name, session_file in zip(session_names, session_files):
        if _ensure_session_features(session_file, session_name, features_path[session_name],
                                    label_mapping_file, device):
            need_to_combine = True

    # Combined cache.
    if need_to_combine or not os.path.exists(features_path['combined']):
        print("Combining features from all sessions...")
        all_features, all_labels, all_texts = _combine_session_features(
            [features_path[name] for name in session_names], device
        )

        print(f"Total combined features: {len(all_features)}")
        print(f"Saving combined features to {features_path['combined']}...")
        torch.save({
            'features': all_features,
            'labels': all_labels,
            'texts': all_texts
        }, features_path['combined'])
        print("Combined features saved successfully!")
    else:
        print(f"Loading pre-computed combined features from {features_path['combined']}...")
        combined_data = torch.load(features_path['combined'])
        all_features = combined_data['features']
        all_labels = combined_data['labels']
        all_texts = combined_data['texts']
        print(f"Loaded combined features with shape {len(all_features)}")

        del combined_data
        gc.collect()

        if device.type == 'cuda':
            torch.cuda.empty_cache()

    feature_dataset = EEGFeatureDataset(all_features, all_labels, all_texts)

    # The dataset keeps its own references; drop the local names.
    del all_features, all_labels
    gc.collect()

    if device.type == 'cuda':
        torch.cuda.empty_cache()
        print(f"GPU memory before training: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")

    # Stratified split: 70% train, then the remaining 30% halved into val/test.
    train_idx, temp_idx = train_test_split(
        range(len(feature_dataset)),
        test_size=0.3,
        random_state=42,
        stratify=feature_dataset.labels
    )

    val_idx, test_idx = train_test_split(
        temp_idx,
        test_size=0.5,  # 50% of temp_idx, i.e. 15% of the original data
        random_state=42,
        stratify=[feature_dataset.labels[i] for i in temp_idx]
    )

    train_sampler = SubsetRandomSampler(train_idx)
    val_sampler = SubsetRandomSampler(val_idx)
    test_sampler = SubsetRandomSampler(test_idx)

    # Smaller batches on GPU to limit memory use.
    batch_size = 16 if device.type == 'cuda' else 32
    pin = device.type == 'cuda'
    train_loader = DataLoader(feature_dataset, batch_size=batch_size, sampler=train_sampler, pin_memory=pin)
    val_loader = DataLoader(feature_dataset, batch_size=batch_size, sampler=val_sampler, pin_memory=pin)
    test_loader = DataLoader(feature_dataset, batch_size=batch_size, sampler=test_sampler, pin_memory=pin)

    feature_dim = feature_dataset[0][0].shape[0]
    # Size the output layer by the largest label index: CrossEntropyLoss needs
    # n_classes > max(label), and the class-weight vector must have exactly
    # n_classes entries. This equals the distinct-label count when labels are 0..K-1.
    n_classes = int(max(feature_dataset.labels)) + 1
    n_distinct = len(set(feature_dataset.labels))
    if n_distinct != n_classes:
        print(f"Warning: {n_distinct} distinct labels but max label index is {n_classes - 1}; "
              "some class indices have no samples.")

    print(f"Feature dimension: {feature_dim}")
    print(f"Number of classes: {n_classes}")

    class_weights = _compute_class_weights(feature_dataset.labels, n_classes)
    print(f"Class weights: {class_weights}")

    model = EEGClassifier(
        input_dim=feature_dim,
        n_classes=n_classes
    )

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total model parameters: {total_params:,}")

    print("\nTraining model...")
    trained_model = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        num_epochs=500,  # Adjust as needed
        learning_rate=0.0005,
        weight_decay=1e-5,
        device=device,
        class_weights=class_weights
    )

    print("\nEvaluating model on test set...")
    accuracy, report, cm, probabilities = evaluate_model(trained_model, test_loader, device)

    print(f"\nFinal Test Accuracy: {accuracy*100:.2f}%")

    torch.save({
        'model_state_dict': trained_model.state_dict(),
        'feature_dim': feature_dim,
        'n_classes': n_classes,
        'accuracy': accuracy
    }, "eeg_classifier_model.pt")

    print("\nModel saved successfully!")

    print("\nEEG classification pipeline complete!")

# Run the full pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()

Using device: cuda
Initial GPU memory: 16.25 MB
Loading pre-computed features from features_session_1.pt...
Loaded features with shape 31950
Loading pre-computed features from features_session_2.pt...
Loaded features with shape 31900
Loading pre-computed combined features from features_combined.pt...
Loaded combined features with shape 63850
GPU memory before training: 16.25 MB
Feature dimension: 2114
Number of classes: 80
Class weights: [0.99750623 0.99750623 0.99750623 0.99750623 0.99750623 0.99750623
 0.99750623 0.99750623 0.99750623 0.99750623 0.99750623 0.99750623
 0.99750623 0.99750623 0.99750623 0.99750623 0.99750623 0.99750623
 0.99750623 0.99750623 0.99750623 0.99750623 0.99750623 0.99750623
 0.99750623 0.99750623 0.99750623 0.99750623 0.99750623 0.99750623
 0.99750623 0.99750623 0.99750623 0.99750623 0.99750623 0.99750623
 0.99750623 0.99750623 0.99750623 0.99750623 0.99750623 0.99750623
 0.99750623 1.06400665 0.99750623 0.99750623 0.99750623 0.99750623
 0.99750623 0.99750623



Total model parameters: 24,272,902

Training model...


Epoch 1/500 [Train]: 100%|██████████| 2794/2794 [00:47<00:00, 58.45it/s, loss=6.0113, acc=2.00%]
Epoch 1/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 297.47it/s, loss=4.4330, acc=3.28%]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Epoch 1/500
Train Loss: 5.0208, Train Acc: 2.00%
Val Loss: 4.6019, Val Acc: 3.28%
Learning rate: 0.000500
New best model saved!

Classification Report:
              precision    recall  f1-score   support

           0       0.40      0.02      0.03       120
           1       0.20      0.01      0.02       120
           2       0.25      0.01      0.02       120
           3       0.00      0.00      0.00       120
           4       0.16      0.03      0.04       120
           5       0.00      0.00      0.00       120
           6       0.00      0.00      0.00       120
           7       0.09      0.02      0.03       120
           8       0.05      0.02      0.02       120
           9       0.17      0.02      0.03       120
          10       0.00      0.00      0.00       120
          11       0.03      0.08      0.04       120
          12       0.04      0.03      0.04       120
          13       0.01      0.03      0.02       120
          14       0.00      0.00   

Epoch 2/500 [Train]: 100%|██████████| 2794/2794 [00:50<00:00, 55.52it/s, loss=4.1438, acc=4.21%]
Epoch 2/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 295.07it/s, loss=4.4488, acc=7.10%]



Epoch 2/500
Train Loss: 4.5233, Train Acc: 4.21%
Val Loss: 4.0736, Val Acc: 7.10%
Learning rate: 0.000500
New best model saved!

Classification Report:
              precision    recall  f1-score   support

           0       0.04      0.07      0.05       120
           1       0.09      0.05      0.06       120
           2       0.04      0.06      0.05       120
           3       0.00      0.00      0.00       120
           4       0.16      0.11      0.13       120
           5       0.00      0.00      0.00       120
           6       0.07      0.03      0.05       120
           7       0.11      0.14      0.12       120
           8       0.14      0.05      0.07       120
           9       0.03      0.01      0.01       120
          10       0.03      0.02      0.02       120
          11       0.06      0.25      0.09       120
          12       0.12      0.23      0.16       120
          13       0.27      0.03      0.06       120
          14       0.11      0.15   

Epoch 3/500 [Train]: 100%|██████████| 2794/2794 [00:48<00:00, 58.02it/s, loss=4.2672, acc=7.68%]
Epoch 3/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 283.67it/s, loss=3.3428, acc=12.75%]



Epoch 3/500
Train Loss: 4.0724, Train Acc: 7.68%
Val Loss: 3.5721, Val Acc: 12.75%
Learning rate: 0.000500
New best model saved!

Classification Report:
              precision    recall  f1-score   support

           0       0.14      0.13      0.14       120
           1       0.13      0.03      0.05       120
           2       0.14      0.12      0.13       120
           3       0.11      0.12      0.11       120
           4       0.15      0.09      0.11       120
           5       0.10      0.05      0.07       120
           6       0.14      0.07      0.09       120
           7       0.07      0.09      0.08       120
           8       0.14      0.20      0.17       120
           9       0.05      0.02      0.02       120
          10       0.08      0.11      0.09       120
          11       0.10      0.26      0.14       120
          12       0.21      0.15      0.17       120
          13       0.12      0.09      0.10       120
          14       0.20      0.17  

Epoch 4/500 [Train]: 100%|██████████| 2794/2794 [00:48<00:00, 57.79it/s, loss=3.8497, acc=11.93%]
Epoch 4/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 240.41it/s, loss=2.9864, acc=16.97%]



Epoch 4/500
Train Loss: 3.6777, Train Acc: 11.93%
Val Loss: 3.2689, Val Acc: 16.97%
Learning rate: 0.000500
New best model saved!

Classification Report:
              precision    recall  f1-score   support

           0       0.16      0.13      0.15       120
           1       0.21      0.04      0.07       120
           2       0.09      0.12      0.10       120
           3       0.14      0.12      0.13       120
           4       0.10      0.19      0.13       120
           5       0.16      0.17      0.16       120
           6       0.13      0.31      0.18       120
           7       0.29      0.04      0.07       120
           8       0.13      0.17      0.15       120
           9       0.23      0.12      0.16       120
          10       0.19      0.10      0.13       120
          11       0.26      0.21      0.23       120
          12       0.23      0.18      0.20       120
          13       0.11      0.24      0.15       120
          14       0.18      0.20 

Epoch 5/500 [Train]: 100%|██████████| 2794/2794 [00:47<00:00, 59.01it/s, loss=3.3552, acc=16.12%]
Epoch 5/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 281.39it/s, loss=3.3737, acc=20.66%]



Epoch 5/500
Train Loss: 3.3913, Train Acc: 16.12%
Val Loss: 3.0624, Val Acc: 20.66%
Learning rate: 0.000500
New best model saved!

Classification Report:
              precision    recall  f1-score   support

           0       0.19      0.22      0.20       120
           1       0.16      0.15      0.15       120
           2       0.23      0.22      0.22       120
           3       0.22      0.26      0.24       120
           4       0.21      0.18      0.19       120
           5       0.29      0.10      0.15       120
           6       0.33      0.16      0.21       120
           7       0.24      0.17      0.20       120
           8       0.28      0.19      0.23       120
           9       0.15      0.23      0.18       120
          10       0.13      0.18      0.15       120
          11       0.33      0.29      0.31       120
          12       0.19      0.42      0.26       120
          13       0.24      0.22      0.23       120
          14       0.17      0.24 

Epoch 6/500 [Train]: 100%|██████████| 2794/2794 [00:48<00:00, 57.45it/s, loss=4.0776, acc=19.73%]
Epoch 6/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 285.39it/s, loss=2.6424, acc=23.98%]



Epoch 6/500
Train Loss: 3.1651, Train Acc: 19.73%
Val Loss: 2.8624, Val Acc: 23.98%
Learning rate: 0.000500
New best model saved!

Classification Report:
              precision    recall  f1-score   support

           0       0.31      0.20      0.24       120
           1       0.19      0.26      0.22       120
           2       0.25      0.17      0.20       120
           3       0.32      0.20      0.24       120
           4       0.21      0.23      0.22       120
           5       0.24      0.20      0.22       120
           6       0.34      0.23      0.27       120
           7       0.23      0.18      0.21       120
           8       0.31      0.24      0.27       120
           9       0.32      0.23      0.26       120
          10       0.19      0.25      0.22       120
          11       0.35      0.31      0.33       120
          12       0.18      0.25      0.21       120
          13       0.19      0.25      0.22       120
          14       0.29      0.23 

Epoch 7/500 [Train]: 100%|██████████| 2794/2794 [00:48<00:00, 57.96it/s, loss=2.9894, acc=23.03%]
Epoch 7/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 288.53it/s, loss=3.8121, acc=25.51%]



Epoch 7/500
Train Loss: 2.9791, Train Acc: 23.03%
Val Loss: 2.7721, Val Acc: 25.51%
Learning rate: 0.000500
New best model saved!

Classification Report:
              precision    recall  f1-score   support

           0       0.19      0.23      0.21       120
           1       0.26      0.21      0.23       120
           2       0.21      0.24      0.22       120
           3       0.34      0.30      0.32       120
           4       0.23      0.23      0.23       120
           5       0.34      0.26      0.30       120
           6       0.28      0.30      0.29       120
           7       0.32      0.28      0.30       120
           8       0.21      0.15      0.18       120
           9       0.32      0.19      0.24       120
          10       0.18      0.34      0.24       120
          11       0.42      0.30      0.35       120
          12       0.27      0.38      0.31       120
          13       0.18      0.23      0.20       120
          14       0.26      0.38 

Epoch 8/500 [Train]: 100%|██████████| 2794/2794 [00:48<00:00, 57.75it/s, loss=4.6163, acc=25.68%]
Epoch 8/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 283.85it/s, loss=2.1794, acc=28.17%]



Epoch 8/500
Train Loss: 2.8254, Train Acc: 25.68%
Val Loss: 2.6410, Val Acc: 28.17%
Learning rate: 0.000500
New best model saved!

Classification Report:
              precision    recall  f1-score   support

           0       0.37      0.24      0.29       120
           1       0.27      0.28      0.28       120
           2       0.28      0.31      0.29       120
           3       0.30      0.28      0.29       120
           4       0.25      0.33      0.29       120
           5       0.31      0.30      0.30       120
           6       0.28      0.31      0.29       120
           7       0.20      0.39      0.26       120
           8       0.33      0.30      0.31       120
           9       0.26      0.24      0.25       120
          10       0.22      0.34      0.27       120
          11       0.45      0.40      0.42       120
          12       0.28      0.30      0.29       120
          13       0.35      0.27      0.30       120
          14       0.29      0.26 

Epoch 9/500 [Train]: 100%|██████████| 2794/2794 [00:48<00:00, 57.15it/s, loss=2.7948, acc=28.90%]
Epoch 9/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 281.06it/s, loss=1.9833, acc=29.39%]



Epoch 9/500
Train Loss: 2.6704, Train Acc: 28.90%
Val Loss: 2.5673, Val Acc: 29.39%
Learning rate: 0.000500
New best model saved!

Classification Report:
              precision    recall  f1-score   support

           0       0.30      0.23      0.26       120
           1       0.29      0.34      0.31       120
           2       0.34      0.40      0.37       120
           3       0.24      0.27      0.25       120
           4       0.24      0.21      0.22       120
           5       0.24      0.37      0.29       120
           6       0.27      0.31      0.29       120
           7       0.45      0.25      0.32       120
           8       0.40      0.35      0.37       120
           9       0.25      0.15      0.19       120
          10       0.25      0.23      0.24       120
          11       0.36      0.42      0.39       120
          12       0.36      0.29      0.32       120
          13       0.27      0.31      0.29       120
          14       0.28      0.34 

Epoch 10/500 [Train]: 100%|██████████| 2794/2794 [00:48<00:00, 57.66it/s, loss=3.2106, acc=31.39%]
Epoch 10/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 245.65it/s, loss=2.4121, acc=30.98%]



Epoch 10/500
Train Loss: 2.5656, Train Acc: 31.39%
Val Loss: 2.5142, Val Acc: 30.98%
Learning rate: 0.000500
New best model saved!

Classification Report:
              precision    recall  f1-score   support

           0       0.24      0.27      0.25       120
           1       0.36      0.23      0.28       120
           2       0.42      0.37      0.39       120
           3       0.36      0.26      0.30       120
           4       0.35      0.24      0.29       120
           5       0.35      0.29      0.32       120
           6       0.33      0.34      0.34       120
           7       0.29      0.28      0.29       120
           8       0.36      0.36      0.36       120
           9       0.30      0.24      0.27       120
          10       0.24      0.28      0.25       120
          11       0.38      0.42      0.40       120
          12       0.34      0.35      0.35       120
          13       0.27      0.34      0.30       120
          14       0.30      0.45

Epoch 11/500 [Train]: 100%|██████████| 2794/2794 [00:48<00:00, 57.76it/s, loss=2.2931, acc=33.51%]
Epoch 11/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 265.29it/s, loss=2.3867, acc=32.47%]



Epoch 11/500
Train Loss: 2.4489, Train Acc: 33.51%
Val Loss: 2.4589, Val Acc: 32.47%
Learning rate: 0.000500
New best model saved!

Classification Report:
              precision    recall  f1-score   support

           0       0.32      0.31      0.31       120
           1       0.27      0.39      0.32       120
           2       0.31      0.40      0.35       120
           3       0.37      0.38      0.38       120
           4       0.31      0.32      0.32       120
           5       0.40      0.29      0.34       120
           6       0.37      0.28      0.32       120
           7       0.38      0.31      0.34       120
           8       0.39      0.31      0.35       120
           9       0.28      0.32      0.30       120
          10       0.43      0.19      0.26       120
          11       0.45      0.39      0.42       120
          12       0.39      0.38      0.38       120
          13       0.29      0.35      0.32       120
          14       0.34      0.37

Epoch 12/500 [Train]: 100%|██████████| 2794/2794 [00:48<00:00, 57.43it/s, loss=3.4983, acc=35.24%]
Epoch 12/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 275.53it/s, loss=3.0021, acc=32.78%]



Epoch 12/500
Train Loss: 2.3752, Train Acc: 35.24%
Val Loss: 2.4646, Val Acc: 32.78%
Learning rate: 0.000500


Epoch 13/500 [Train]: 100%|██████████| 2794/2794 [00:48<00:00, 57.21it/s, loss=2.1625, acc=37.43%]
Epoch 13/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 277.32it/s, loss=2.0999, acc=32.08%]



Epoch 13/500
Train Loss: 2.2891, Train Acc: 37.43%
Val Loss: 2.4586, Val Acc: 32.08%
Learning rate: 0.000500
New best model saved!

Classification Report:
              precision    recall  f1-score   support

           0       0.38      0.30      0.33       120
           1       0.30      0.38      0.33       120
           2       0.36      0.42      0.39       120
           3       0.35      0.27      0.30       120
           4       0.37      0.25      0.30       120
           5       0.23      0.26      0.24       120
           6       0.33      0.45      0.38       120
           7       0.31      0.33      0.32       120
           8       0.34      0.34      0.34       120
           9       0.29      0.32      0.31       120
          10       0.28      0.33      0.30       120
          11       0.45      0.44      0.45       120
          12       0.43      0.37      0.40       120
          13       0.30      0.31      0.30       120
          14       0.38      0.30

Epoch 14/500 [Train]: 100%|██████████| 2794/2794 [00:48<00:00, 57.09it/s, loss=2.6738, acc=38.90%]
Epoch 14/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 275.10it/s, loss=2.5460, acc=33.99%]



Epoch 14/500
Train Loss: 2.2128, Train Acc: 38.90%
Val Loss: 2.4033, Val Acc: 33.99%
Learning rate: 0.000500
New best model saved!

Classification Report:
              precision    recall  f1-score   support

           0       0.26      0.38      0.31       120
           1       0.35      0.43      0.39       120
           2       0.36      0.45      0.40       120
           3       0.30      0.41      0.35       120
           4       0.32      0.33      0.33       120
           5       0.34      0.29      0.31       120
           6       0.35      0.30      0.32       120
           7       0.32      0.30      0.31       120
           8       0.34      0.33      0.33       120
           9       0.30      0.38      0.33       120
          10       0.37      0.28      0.32       120
          11       0.46      0.49      0.47       120
          12       0.32      0.45      0.37       120
          13       0.28      0.27      0.27       120
          14       0.39      0.29

Epoch 15/500 [Train]: 100%|██████████| 2794/2794 [00:49<00:00, 56.75it/s, loss=1.9956, acc=41.07%]
Epoch 15/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 271.08it/s, loss=1.7115, acc=34.43%]



Epoch 15/500
Train Loss: 2.1281, Train Acc: 41.07%
Val Loss: 2.3916, Val Acc: 34.43%
Learning rate: 0.000500
New best model saved!

Classification Report:
              precision    recall  f1-score   support

           0       0.36      0.32      0.34       120
           1       0.37      0.34      0.35       120
           2       0.50      0.43      0.46       120
           3       0.39      0.31      0.35       120
           4       0.27      0.31      0.29       120
           5       0.30      0.26      0.28       120
           6       0.36      0.38      0.37       120
           7       0.32      0.28      0.30       120
           8       0.38      0.30      0.33       120
           9       0.40      0.30      0.34       120
          10       0.29      0.31      0.30       120
          11       0.39      0.46      0.42       120
          12       0.35      0.34      0.34       120
          13       0.29      0.33      0.31       120
          14       0.31      0.30

Epoch 16/500 [Train]: 100%|██████████| 2794/2794 [00:48<00:00, 57.11it/s, loss=1.9514, acc=42.42%]
Epoch 16/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 238.99it/s, loss=1.8921, acc=33.87%]



Epoch 16/500
Train Loss: 2.0693, Train Acc: 42.42%
Val Loss: 2.3859, Val Acc: 33.87%
Learning rate: 0.000500
New best model saved!

Classification Report:
              precision    recall  f1-score   support

           0       0.31      0.28      0.29       120
           1       0.45      0.33      0.38       120
           2       0.41      0.33      0.37       120
           3       0.42      0.34      0.38       120
           4       0.38      0.29      0.33       120
           5       0.45      0.23      0.30       120
           6       0.38      0.33      0.35       120
           7       0.37      0.32      0.34       120
           8       0.34      0.42      0.38       120
           9       0.30      0.27      0.28       120
          10       0.30      0.27      0.28       120
          11       0.58      0.40      0.47       120
          12       0.36      0.39      0.38       120
          13       0.26      0.33      0.29       120
          14       0.43      0.29

Epoch 17/500 [Train]: 100%|██████████| 2794/2794 [00:48<00:00, 57.18it/s, loss=3.6720, acc=44.04%]
Epoch 17/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 254.08it/s, loss=2.3777, acc=35.50%]



Epoch 17/500
Train Loss: 2.0088, Train Acc: 44.04%
Val Loss: 2.3726, Val Acc: 35.50%
Learning rate: 0.000500
New best model saved!

Classification Report:
              precision    recall  f1-score   support

           0       0.30      0.23      0.26       120
           1       0.42      0.41      0.41       120
           2       0.39      0.39      0.39       120
           3       0.36      0.32      0.33       120
           4       0.26      0.44      0.33       120
           5       0.37      0.37      0.37       120
           6       0.43      0.36      0.39       120
           7       0.38      0.39      0.39       120
           8       0.46      0.33      0.39       120
           9       0.30      0.32      0.31       120
          10       0.35      0.33      0.34       120
          11       0.43      0.47      0.45       120
          12       0.37      0.44      0.40       120
          13       0.39      0.38      0.38       120
          14       0.40      0.38

Epoch 18/500 [Train]: 100%|██████████| 2794/2794 [00:49<00:00, 57.00it/s, loss=2.3097, acc=45.77%]
Epoch 18/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 266.90it/s, loss=3.0505, acc=35.32%]



Epoch 18/500
Train Loss: 1.9418, Train Acc: 45.77%
Val Loss: 2.3976, Val Acc: 35.32%
Learning rate: 0.000500


Epoch 19/500 [Train]: 100%|██████████| 2794/2794 [00:49<00:00, 56.81it/s, loss=2.0897, acc=46.75%]
Epoch 19/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 269.18it/s, loss=1.8921, acc=35.78%]



Epoch 19/500
Train Loss: 1.8987, Train Acc: 46.75%
Val Loss: 2.3922, Val Acc: 35.78%
Learning rate: 0.000500


Epoch 20/500 [Train]: 100%|██████████| 2794/2794 [00:49<00:00, 56.95it/s, loss=2.7499, acc=48.28%]
Epoch 20/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 266.44it/s, loss=2.2913, acc=35.44%]



Epoch 20/500
Train Loss: 1.8430, Train Acc: 48.28%
Val Loss: 2.4106, Val Acc: 35.44%
Learning rate: 0.000500


Epoch 21/500 [Train]: 100%|██████████| 2794/2794 [00:49<00:00, 56.42it/s, loss=2.1185, acc=49.52%]
Epoch 21/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 231.53it/s, loss=2.4021, acc=35.13%]



Epoch 21/500
Train Loss: 1.7803, Train Acc: 49.52%
Val Loss: 2.4223, Val Acc: 35.13%
Learning rate: 0.000500


Epoch 22/500 [Train]: 100%|██████████| 2794/2794 [00:49<00:00, 56.70it/s, loss=3.0028, acc=51.09%]
Epoch 22/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 249.28it/s, loss=3.1948, acc=35.37%]



Epoch 22/500
Train Loss: 1.7302, Train Acc: 51.09%
Val Loss: 2.4152, Val Acc: 35.37%
Learning rate: 0.000500


Epoch 23/500 [Train]: 100%|██████████| 2794/2794 [00:48<00:00, 57.15it/s, loss=2.4447, acc=51.86%]
Epoch 23/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 262.68it/s, loss=2.5199, acc=35.56%]



Epoch 23/500
Train Loss: 1.7005, Train Acc: 51.86%
Val Loss: 2.4166, Val Acc: 35.56%
Learning rate: 0.000250


Epoch 24/500 [Train]: 100%|██████████| 2794/2794 [00:49<00:00, 56.80it/s, loss=0.9939, acc=60.26%]
Epoch 24/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 272.51it/s, loss=2.1231, acc=38.19%]



Epoch 24/500
Train Loss: 1.3768, Train Acc: 60.26%
Val Loss: 2.2974, Val Acc: 38.19%
Learning rate: 0.000250
New best model saved!

Classification Report:
              precision    recall  f1-score   support

           0       0.36      0.38      0.37       120
           1       0.45      0.40      0.42       120
           2       0.44      0.43      0.44       120
           3       0.48      0.35      0.40       120
           4       0.35      0.42      0.38       120
           5       0.35      0.39      0.37       120
           6       0.40      0.41      0.40       120
           7       0.43      0.35      0.39       120
           8       0.43      0.46      0.44       120
           9       0.31      0.33      0.32       120
          10       0.36      0.29      0.32       120
          11       0.51      0.50      0.51       120
          12       0.40      0.42      0.41       120
          13       0.36      0.41      0.38       120
          14       0.40      0.40

Epoch 25/500 [Train]: 100%|██████████| 2794/2794 [00:49<00:00, 56.84it/s, loss=3.0019, acc=63.50%]
Epoch 25/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 273.73it/s, loss=2.1444, acc=37.91%]



Epoch 25/500
Train Loss: 1.2408, Train Acc: 63.50%
Val Loss: 2.3682, Val Acc: 37.91%
Learning rate: 0.000250


Epoch 26/500 [Train]: 100%|██████████| 2794/2794 [00:49<00:00, 56.66it/s, loss=2.3432, acc=65.94%]
Epoch 26/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 267.14it/s, loss=2.1130, acc=37.95%]



Epoch 26/500
Train Loss: 1.1613, Train Acc: 65.94%
Val Loss: 2.4135, Val Acc: 37.95%
Learning rate: 0.000250


Epoch 27/500 [Train]: 100%|██████████| 2794/2794 [00:49<00:00, 56.55it/s, loss=1.6739, acc=67.58%]
Epoch 27/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 266.40it/s, loss=2.4837, acc=37.79%]



Epoch 27/500
Train Loss: 1.0947, Train Acc: 67.58%
Val Loss: 2.4314, Val Acc: 37.79%
Learning rate: 0.000250


Epoch 28/500 [Train]: 100%|██████████| 2794/2794 [00:49<00:00, 56.39it/s, loss=0.8826, acc=68.99%]
Epoch 28/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 251.42it/s, loss=2.3061, acc=38.44%]



Epoch 28/500
Train Loss: 1.0523, Train Acc: 68.99%
Val Loss: 2.4601, Val Acc: 38.44%
Learning rate: 0.000250


Epoch 29/500 [Train]: 100%|██████████| 2794/2794 [00:48<00:00, 57.04it/s, loss=1.3826, acc=69.97%]
Epoch 29/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 255.39it/s, loss=2.7166, acc=37.69%]



Epoch 29/500
Train Loss: 1.0113, Train Acc: 69.97%
Val Loss: 2.5185, Val Acc: 37.69%
Learning rate: 0.000250


Epoch 30/500 [Train]: 100%|██████████| 2794/2794 [00:49<00:00, 56.76it/s, loss=2.3506, acc=71.06%]
Epoch 30/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 267.35it/s, loss=2.9438, acc=38.24%]



Epoch 30/500
Train Loss: 0.9684, Train Acc: 71.06%
Val Loss: 2.5319, Val Acc: 38.24%
Learning rate: 0.000125


Epoch 31/500 [Train]: 100%|██████████| 2794/2794 [00:49<00:00, 56.98it/s, loss=2.7477, acc=75.58%]
Epoch 31/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 267.52it/s, loss=2.1529, acc=39.25%]



Epoch 31/500
Train Loss: 0.8149, Train Acc: 75.58%
Val Loss: 2.5312, Val Acc: 39.25%
Learning rate: 0.000125


Epoch 32/500 [Train]: 100%|██████████| 2794/2794 [00:49<00:00, 56.60it/s, loss=2.3420, acc=77.81%]
Epoch 32/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 264.58it/s, loss=2.7265, acc=39.24%]



Epoch 32/500
Train Loss: 0.7271, Train Acc: 77.81%
Val Loss: 2.5409, Val Acc: 39.24%
Learning rate: 0.000125


Epoch 33/500 [Train]: 100%|██████████| 2794/2794 [00:49<00:00, 57.00it/s, loss=2.6924, acc=78.62%]
Epoch 33/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 248.26it/s, loss=2.1401, acc=39.44%]



Epoch 33/500
Train Loss: 0.7018, Train Acc: 78.62%
Val Loss: 2.5464, Val Acc: 39.44%
Learning rate: 0.000125


Epoch 34/500 [Train]: 100%|██████████| 2794/2794 [00:49<00:00, 56.69it/s, loss=1.3944, acc=79.83%]
Epoch 34/500 [Val]: 100%|██████████| 599/599 [00:02<00:00, 256.69it/s, loss=3.9435, acc=39.24%]



Epoch 34/500
Train Loss: 0.6625, Train Acc: 79.83%
Val Loss: 2.6649, Val Acc: 39.24%
Learning rate: 0.000125
Early stopping triggered after 34 epochs

Evaluating model on test set...


Evaluating: 100%|██████████| 599/599 [00:01<00:00, 411.14it/s]



Final Test Accuracy: 38.19%

Model saved successfully!

EEG classification pipeline complete!
