In [None]:
# print("Go")

# !pip install gdown
# # !gdown "https://drive.google.com/uc?id=1NplEy9bpy4aJHORtN4VXapo0RQ9nPtRw"

# # !gdown "https://drive.google.com/uc?id=1JMwEHSTcRn71gBc7JjJG8GhmmqcP8Gpk"

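# NOTE: gdown needs the `uc?id=<FILE_ID>` form (or the `--fuzzy` flag). A bare
# `/file/d/<FILE_ID>` link is saved as the HTML preview page, not the file itself.
# !gdown "https://drive.google.com/uc?id=18jbEBrtPm0CCTlu_XgeNIpviOcemRsLB"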

Go
Downloading...
From: https://drive.google.com/file/d/18jbEBrtPm0CCTlu_XgeNIpviOcemRsLB
To: /content/18jbEBrtPm0CCTlu_XgeNIpviOcemRsLB
92.5kB [00:00, 3.21MB/s]


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torchaudio
import scipy.signal as signal
from scipy.stats import skew, kurtosis
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import os
from tqdm import tqdm
from torch.cuda.amp import autocast, GradScaler

class EEGPreprocessor:
    """
    EEG preprocessing pipeline: DC removal, band-pass, notch, artifact repair,
    common-average re-referencing and per-channel normalization.

    Every step takes and returns an array of shape (n_channels, n_samples).
    """

    def __init__(self):
        self.fs = 1000  # Sampling frequency (1000 Hz)

    def remove_dc_offset(self, eeg_data):
        """
        Removes DC offset using a zero-phase high-pass filter
        (2nd-order Butterworth, 0.1Hz cutoff).

        Args:
            eeg_data: Raw EEG data of shape (n_channels, n_samples)

        Returns:
            EEG data with DC offset removed
        """
        # Second-order sections are used instead of (b, a) coefficients: with a
        # cutoff this far below Nyquist the polynomial form loses precision.
        sos = signal.butter(2, 0.1 / (self.fs / 2), 'highpass', output='sos')
        return signal.sosfiltfilt(sos, eeg_data, axis=1)

    def bandpass_filter(self, eeg_data):
        """
        Applies a zero-phase 4th-order Butterworth bandpass filter (0.5-70Hz)

        Args:
            eeg_data: EEG data of shape (n_channels, n_samples)

        Returns:
            Bandpass filtered EEG data
        """
        nyquist = self.fs / 2
        # SOS form for numerical stability (see remove_dc_offset).
        sos = signal.butter(4, [0.5 / nyquist, 70 / nyquist], 'bandpass', output='sos')
        return signal.sosfiltfilt(sos, eeg_data, axis=1)

    def notch_filter(self, eeg_data):
        """
        Applies IIR notch filter centered at 50Hz (power line frequency)
        Note: Using 50Hz for power line interference since this appears to be from a European or Asian dataset.
              For US datasets, 60Hz would be used instead.

        Args:
            eeg_data: EEG data of shape (n_channels, n_samples)

        Returns:
            Notch filtered EEG data
        """
        # iirnotch(w0, Q, fs): 50 Hz centre frequency, quality factor 30.
        b, a = signal.iirnotch(50, 30, self.fs)
        return signal.filtfilt(b, a, eeg_data, axis=1)

    def remove_artifacts(self, eeg_data):
        """
        Removes artifacts using robust thresholding based on median absolute deviation (MAD)

        Per channel, samples further than 5 * 1.4826 * MAD from the channel
        median are treated as artifacts and replaced by linear interpolation
        between the nearest clean samples. Artifacts at the start/end of the
        recording take the value of the nearest clean sample.

        Args:
            eeg_data: EEG data of shape (n_channels, n_samples)

        Returns:
            Copy of the EEG data with artifacts replaced (input is not modified)
        """
        cleaned_data = np.copy(eeg_data)
        sample_positions = np.arange(eeg_data.shape[1])

        for ch in range(eeg_data.shape[0]):
            channel_data = eeg_data[ch, :]
            median = np.median(channel_data)
            mad = np.median(np.abs(channel_data - median))

            # Scale factor for MAD to approximate standard deviation
            threshold = 5 * 1.4826 * mad

            artifact_mask = np.abs(channel_data - median) > threshold
            if not artifact_mask.any():
                continue

            clean_mask = ~artifact_mask
            if not clean_mask.any():
                # Nothing to interpolate from; leave the channel untouched.
                continue

            # np.interp interpolates linearly between neighbouring clean
            # samples and holds the first/last clean value beyond the ends,
            # in one vectorized pass instead of a scan per artifact sample.
            cleaned_data[ch, artifact_mask] = np.interp(
                sample_positions[artifact_mask],
                sample_positions[clean_mask],
                channel_data[clean_mask],
            )

        return cleaned_data

    def common_average_reference(self, eeg_data):
        """
        Applies common average re-referencing: at every time point the mean
        across channels is subtracted from each channel.

        Args:
            eeg_data: EEG data of shape (n_channels, n_samples)

        Returns:
            Common average re-referenced EEG data
        """
        return eeg_data - np.mean(eeg_data, axis=0, keepdims=True)

    def normalize(self, eeg_data):
        """
        Normalizes EEG data per channel using robust scaling (median, IQR)
        Falls back to z-score if IQR is zero, and to plain mean-centering if
        the standard deviation is also (near) zero.

        Args:
            eeg_data: EEG data of shape (n_channels, n_samples)

        Returns:
            Normalized EEG data as float32
        """
        normalized_data = np.zeros_like(eeg_data, dtype=np.float32)

        for ch in range(eeg_data.shape[0]):
            channel_data = eeg_data[ch, :]
            median = np.median(channel_data)
            q75, q25 = np.percentile(channel_data, [75, 25])
            iqr = q75 - q25

            if iqr > 1e-10:  # Use robust scaling if IQR is non-zero
                normalized_data[ch, :] = (channel_data - median) / iqr
                continue

            # Fall back to z-score normalization
            mean = np.mean(channel_data)
            std = np.std(channel_data)
            if std > 1e-10:
                normalized_data[ch, :] = (channel_data - mean) / std
            else:
                normalized_data[ch, :] = channel_data - mean  # Just center if std is too small

        return normalized_data

    def process(self, eeg_data, verbose=False):
        """
        Applies the complete preprocessing pipeline

        Args:
            eeg_data: Raw EEG data of shape (n_channels, n_samples); a NumPy
                array or a torch.Tensor (any device, with or without grad)
            verbose: Currently unused; kept so existing callers keep working

        Returns:
            Fully preprocessed EEG data (float32 NumPy array)
        """
        # Convert to numpy if it's a tensor. detach()/cpu() make this work for
        # tensors that track gradients or live on an accelerator.
        if isinstance(eeg_data, torch.Tensor):
            eeg_data = eeg_data.detach().cpu().numpy()

        # Apply each preprocessing step sequentially
        eeg_data = self.remove_dc_offset(eeg_data)
        eeg_data = self.bandpass_filter(eeg_data)
        eeg_data = self.notch_filter(eeg_data)
        eeg_data = self.remove_artifacts(eeg_data)
        eeg_data = self.common_average_reference(eeg_data)
        eeg_data = self.normalize(eeg_data)

        return eeg_data

    def process_batch(self, eeg_data_batch, verbose=False):
        """
        Applies the complete preprocessing pipeline to a batch of EEG data

        Args:
            eeg_data_batch: Raw EEG data of shape (batch_size, n_channels, n_samples)
            verbose: Forwarded to process(), which currently ignores it

        Returns:
            Fully preprocessed batch of EEG data (float32, same shape as input)
        """
        processed_batch = np.zeros(tuple(eeg_data_batch.shape), dtype=np.float32)

        for i in range(eeg_data_batch.shape[0]):
            # Apply preprocessing pipeline to each sample in the batch
            processed_batch[i] = self.process(eeg_data_batch[i], verbose)

        return processed_batch

class OptimizedFeatureExtractor:
    """
    Extracts time-domain, frequency-domain and (optionally) advanced features
    from preprocessed EEG signals of shape (n_channels, n_samples).

    The feature vector is always float32 and always finite: values that are
    undefined for degenerate channels (e.g. correlation with a flat channel)
    are reported as 0 so they cannot propagate NaNs into training.
    """

    def __init__(self, fs=1000, feature_mode='full'):  # Changed default to 'full' for better features
        """
        Initialize feature extractor

        Args:
            fs: Sampling frequency (default: 1000 Hz)
            feature_mode: 'full' adds the advanced features; any other value
                (e.g. 'essential') yields only time and frequency features
        """
        self.fs = fs
        self.feature_mode = feature_mode

        # Define frequency bands as (low_hz, high_hz), both ends inclusive
        self.freq_bands = {
            'delta': (0.5, 4),
            'theta': (4, 8),
            'alpha': (8, 13),
            'beta': (13, 30),
            'gamma': (30, 70)
        }

    def extract_time_domain_features(self, eeg_data, verbose=False):
        """
        Extracts time domain features for each channel

        Per channel, in order: mean, variance, max, min, RMS, skewness,
        kurtosis, peak-to-peak, zero-crossing count (9 values).

        Args:
            eeg_data: Preprocessed EEG data of shape (n_channels, n_samples)
            verbose: Currently unused

        Returns:
            float32 array of length 9 * n_channels
        """
        time_features = []

        for ch in range(eeg_data.shape[0]):
            channel_data = eeg_data[ch, :]

            # Basic statistics
            mean = np.mean(channel_data)
            variance = np.var(channel_data)
            max_val = np.max(channel_data)
            min_val = np.min(channel_data)
            rms = np.sqrt(np.mean(np.square(channel_data)))

            # Statistical measures
            skewness = skew(channel_data)
            kurt = kurtosis(channel_data)
            peak_to_peak = max_val - min_val

            # Zero crossings: number of sign-bit changes between neighbours
            zero_crossings = np.sum(np.diff(np.signbit(channel_data).astype(int)) != 0)

            time_features.extend([
                mean, variance, max_val, min_val, rms,
                skewness, kurt, peak_to_peak, zero_crossings
            ])

        return np.array(time_features, dtype=np.float32)

    def extract_frequency_domain_features(self, eeg_data, verbose=False):
        """
        Extracts frequency domain features for each channel

        Per channel, in order: total power, spectral entropy, absolute band
        power for delta/theta/alpha/beta/gamma, then relative band power for
        the same five bands (12 values). Powers are sums of Welch PSD bins.

        Args:
            eeg_data: Preprocessed EEG data of shape (n_channels, n_samples)
            verbose: Currently unused

        Returns:
            float32 array of length 12 * n_channels
        """
        band_order = ('delta', 'theta', 'alpha', 'beta', 'gamma')
        freq_features = []

        for ch in range(eeg_data.shape[0]):
            channel_data = eeg_data[ch, :]

            # Compute Power Spectral Density using Welch's method
            freqs, psd = signal.welch(channel_data, fs=self.fs, nperseg=min(256, len(channel_data)))

            total_power = np.sum(psd)
            has_power = total_power > 0

            # Absolute power per band (bins with low <= f <= high)
            band_powers = {}
            for band_name, (low_freq, high_freq) in self.freq_bands.items():
                idx_band = np.logical_and(freqs >= low_freq, freqs <= high_freq)
                band_powers[band_name] = np.sum(psd[idx_band])

            # Relative power; defined as 0 when the channel carries no power
            rel_band_powers = {
                band_name: (band_powers[band_name] / total_power if has_power else 0)
                for band_name in self.freq_bands
            }

            # Spectral entropy - measure of irregularity/complexity.
            # The 1e-16 offset keeps log2 finite for empty bins.
            psd_norm = psd / total_power if has_power else np.zeros_like(psd)
            spec_entropy = -np.sum(psd_norm * np.log2(psd_norm + 1e-16))

            freq_features.extend([total_power, spec_entropy])
            freq_features.extend(band_powers[name] for name in band_order)
            freq_features.extend(rel_band_powers[name] for name in band_order)

        return np.array(freq_features, dtype=np.float32)

    def extract_advanced_features(self, eeg_data, verbose=False):
        """
        Extracts advanced features (Hjorth parameters, connectivity measures)
        Only used in 'full' feature mode

        Layout: Hjorth activity, mobility, complexity for each channel
        (3 * n_channels values), followed - when there is more than one
        channel - by the mean and the maximum pairwise Pearson correlation.

        Args:
            eeg_data: Preprocessed EEG data of shape (n_channels, n_samples)
            verbose: Currently unused

        Returns:
            float32 array of advanced features (empty outside 'full' mode)
        """
        if self.feature_mode != 'full':
            # Typed empty array so concatenation does not promote to float64.
            return np.array([], dtype=np.float32)

        n_channels = eeg_data.shape[0]
        advanced_features = []

        # Hjorth parameters for each channel
        for ch in range(n_channels):
            channel_data = eeg_data[ch, :]
            signal_std = np.std(channel_data)

            # Activity (variance)
            activity = np.var(channel_data)

            # Mobility (standard deviation of first derivative / standard deviation of signal)
            first_deriv = np.diff(channel_data, n=1)
            first_deriv = np.append(first_deriv, first_deriv[-1])  # Padding to keep dimensions
            deriv_std = np.std(first_deriv)
            mobility = deriv_std / signal_std if signal_std > 0 else 0

            # Complexity (mobility of first derivative / mobility of signal)
            second_deriv = np.diff(first_deriv, n=1)
            second_deriv = np.append(second_deriv, second_deriv[-1])  # Padding
            if mobility > 0 and deriv_std > 0:
                complexity = (np.std(second_deriv) / deriv_std) / mobility
            else:
                complexity = 0

            advanced_features.extend([activity, mobility, complexity])

        # Inter-channel correlation (connectivity measures)
        if n_channels > 1:
            # One correlation matrix instead of a corrcoef call per pair.
            with np.errstate(invalid='ignore', divide='ignore'):
                corr_matrix = np.corrcoef(eeg_data)
            upper = np.triu_indices(n_channels, k=1)
            # Correlation with a zero-variance channel is undefined (NaN);
            # count it as "no correlation" so the summaries stay finite.
            correlations = np.nan_to_num(corr_matrix[upper], nan=0.0)

            # Add summary connectivity measures
            advanced_features.append(np.mean(correlations))
            advanced_features.append(np.max(correlations))

        return np.array(advanced_features, dtype=np.float32)

    def extract_features(self, eeg_data, verbose=False):
        """
        Extracts all features from preprocessed EEG data

        Args:
            eeg_data: Preprocessed EEG data of shape (n_channels, n_samples)
            verbose: Forwarded to the per-domain extractors (currently unused)

        Returns:
            1-D float32 feature vector: time features, then frequency
            features, then advanced features. Non-finite entries are set to 0.
        """
        time_features = self.extract_time_domain_features(eeg_data, verbose)
        freq_features = self.extract_frequency_domain_features(eeg_data, verbose)
        advanced_features = self.extract_advanced_features(eeg_data, verbose)

        all_features = np.concatenate([time_features, freq_features, advanced_features])

        # Statistics such as skewness/kurtosis can be NaN on constant channels;
        # a single NaN feature would turn the training loss into NaN.
        all_features = np.nan_to_num(all_features, nan=0.0, posinf=0.0, neginf=0.0)

        return all_features.astype(np.float32, copy=False)

    def extract_features_batch(self, eeg_data_batch, verbose=False):
        """
        Extract features from a batch of EEG data (one sample at a time)

        Args:
            eeg_data_batch: Batch of preprocessed EEG data, shape (batch_size, n_channels, n_samples)
            verbose: Forwarded to extract_features (currently unused)

        Returns:
            Batch of feature vectors, shape (batch_size, n_features)
        """
        all_features = [
            self.extract_features(eeg_data_batch[i], verbose)
            for i in range(eeg_data_batch.shape[0])
        ]

        return np.stack(all_features)


class MelSpectrogramGenerator:
    """
    Turns text into a fixed-length mel-spectrogram with torchaudio's
    Tacotron2 (LJSpeech, character-based) pipeline.
    """

    def __init__(self, spec_length_seconds=1.0):  # Increased to 1.0 seconds from 0.5
        """
        Load the Tacotron2 text processor and model and derive the target
        number of spectrogram frames.

        Args:
            spec_length_seconds: Desired spectrogram duration in seconds
        """
        bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_CHAR_LJSPEECH
        self.bundle = bundle
        self.processor = bundle.get_text_processor()
        self.tacotron2 = bundle.get_tacotron2()

        # Fixed audio/spectrogram parameters used for the frame computation
        self.sample_rate = 22050  # Hz
        self.hop_length = 256     # samples between consecutive frames
        self.n_mels = 80          # mel channels used for zero padding
        self.spec_length_seconds = spec_length_seconds

        # frames = seconds * sample_rate / hop_length (truncated to int)
        frames = self.spec_length_seconds * self.sample_rate / self.hop_length
        self.target_length = int(frames)

        # Run the model on the GPU when one is available
        use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if use_cuda else "cpu")
        self.tacotron2 = self.tacotron2.to(self.device)

    def text_to_mel_spectrogram(self, text):
        """
        Synthesize a mel-spectrogram for `text` and force it to exactly
        `target_length` frames.

        Args:
            text: Input text

        Returns:
            CPU tensor of shape [mel_channels, target_length]; longer outputs
            are truncated, shorter ones are right-padded with zeros
        """
        with torch.no_grad():
            tokens, token_lengths = self.processor(text)
            tokens = tokens.to(self.device)
            token_lengths = token_lengths.to(self.device)
            mel, _mel_lengths, _unused = self.tacotron2.infer(tokens, token_lengths)

        # Drop the leading batch dimension (if it has size 1) and move to CPU
        mel = mel.squeeze(0).cpu()

        missing_frames = self.target_length - mel.shape[1]

        if missing_frames < 0:
            # Too long: keep only the first target_length frames
            return mel[:, :self.target_length]

        if missing_frames > 0:
            # Too short: append zero frames on the time axis
            filler = torch.zeros(self.n_mels, missing_frames)
            return torch.cat([mel, filler], dim=1)

        return mel


class EEGMelDataset(Dataset):
    """
    Dataset pairing EEG recordings (or features derived from them) with the
    mel-spectrogram synthesized from each recording's text label.
    """

    def __init__(self, data_path, label_file, transform=True, precompute=True, spec_length_seconds=1.0):
        """
        Initialize dataset

        Args:
            data_path: Path to the .pth file containing EEG data; the loaded
                object is indexed as data['dataset'][i]['eeg_data'] / ['label']
            label_file: Path to the label text file
            transform: Whether to preprocess and extract features from EEG
            precompute: Whether to precompute and cache all features (slower startup, faster training)
            spec_length_seconds: Length of the mel-spectrogram in seconds
        """
        # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects,
        # which can execute code. Only load files from a trusted source.
        self.data = torch.load(data_path, weights_only=False)
        self.label_dict = self._load_label_dict(label_file)
        self.transform = transform
        self.precompute = precompute
        self.spec_length_seconds = spec_length_seconds

        # Initialize preprocessor and feature extractor if transform is True
        if transform:
            self.preprocessor = EEGPreprocessor()
            self.feature_extractor = OptimizedFeatureExtractor(feature_mode='full')

        # Initialize mel-spectrogram generator with fixed output length
        self.mel_generator = MelSpectrogramGenerator(spec_length_seconds=spec_length_seconds)

        samples = self.data['dataset']

        # Find unique label IDs actually used in the dataset
        unique_label_ids = {sample['label'] for sample in samples}

        # Generate one mel-spectrogram per unique label that has text
        self.mel_spectrograms = {}
        missing_labels = []
        for label_id in unique_label_ids:
            if label_id in self.label_dict:
                text = self.label_dict[label_id]
                self.mel_spectrograms[label_id] = self.mel_generator.text_to_mel_spectrogram(text)
            else:
                missing_labels.append(label_id)

        if missing_labels:
            # Surface the problem now instead of as a bare KeyError mid-epoch.
            # NOTE(review): label-file keys are always str; if sample labels
            # are ints they will never match - confirm the label type.
            shown = ", ".join(repr(lbl) for lbl in sorted(missing_labels, key=repr)[:10])
            print(f"WARNING: {len(missing_labels)} label id(s) have no entry in "
                  f"{label_file!r} and will raise KeyError when accessed: {shown}")

        # Precompute all EEG features if requested (to avoid recomputing during training)
        self.precomputed_features = {}
        if precompute and transform:
            total_samples = len(samples)
            print(f"\nPrecomputing EEG features for all {total_samples} samples...")

            # One sample at a time: recordings need not share a length, and the
            # cached values come from the same code path __getitem__ uses.
            for sample_idx in tqdm(range(total_samples), desc="Extracting features"):
                self.precomputed_features[sample_idx] = self._compute_features(
                    samples[sample_idx]['eeg_data']
                )

            print(f"Successfully precomputed features for all {total_samples} samples")

    def _compute_features(self, eeg_data):
        """
        Preprocess one recording and turn it into a float32 feature tensor.

        Args:
            eeg_data: Raw EEG recording, passed to the preprocessor as-is

        Returns:
            1-D torch.float32 tensor of features
        """
        processed = self.preprocessor.process(eeg_data, verbose=False)
        features = self.feature_extractor.extract_features(processed, verbose=False)
        return torch.tensor(features, dtype=torch.float32)

    def _load_label_dict(self, label_file):
        """
        Load label dictionary from label file

        Each line is "<label_id><whitespace><text>"; lines that do not split
        into exactly those two parts (blank lines, id-only lines) are skipped.

        Args:
            label_file: Path to the label text file

        Returns:
            Dictionary mapping label_id (str) to corresponding text
        """
        label_dict = {}
        with open(label_file, 'r', encoding='utf-8') as f:
            for line in f:
                parts = line.strip().split(maxsplit=1)
                if len(parts) == 2:
                    label_id, text = parts
                    label_dict[label_id] = text

        return label_dict

    def __len__(self):
        """
        Return the number of samples in the dataset
        """
        return len(self.data['dataset'])

    def __getitem__(self, index):
        """
        Return a sample from the dataset

        Args:
            index: Sample index

        Returns:
            (eeg_features, mel_spec): cached features when available,
            otherwise features computed on the fly (transform=True) or the
            flattened raw EEG (transform=False), plus the label's
            mel-spectrogram

        Raises:
            KeyError: if no mel-spectrogram exists for the sample's label
        """
        sample = self.data['dataset'][index]
        label_id = sample['label']

        # Use precomputed features if available, otherwise process on the fly
        if self.precompute and index in self.precomputed_features:
            eeg_features = self.precomputed_features[index]
        elif self.transform:
            eeg_features = self._compute_features(sample['eeg_data'])
        else:
            eeg_features = sample['eeg_data'].flatten()

        # Get corresponding mel-spectrogram
        try:
            mel_spec = self.mel_spectrograms[label_id]
        except KeyError:
            raise KeyError(
                f"No mel-spectrogram for label {label_id!r} (sample index {index}); "
                f"the label is missing from the label file"
            ) from None

        return eeg_features, mel_spec


# Implementing the requested MultiResolutionSpectralLoss
class MultiResolutionSpectralLoss(nn.Module):
    """
    Spectrogram loss: L1 distance plus spectral convergence.

    NOTE(review): despite the name, forward() does not compute any STFT.
    `fft_sizes`, `hop_sizes`, `win_lengths` and `device` are only stored on
    the instance; the loss is evaluated directly on the given spectrograms.
    """

    def __init__(self,
                 fft_sizes=(1024, 2048, 512),
                 hop_sizes=(120, 240, 50),
                 win_lengths=(600, 1200, 240),
                 device='cuda'):
        """
        Args:
            fft_sizes: FFT sizes (stored, not used by forward)
            hop_sizes: Hop sizes (stored, not used by forward)
            win_lengths: Window lengths (stored, not used by forward)
            device: Device name (stored, not used by forward)
        """
        super().__init__()
        # Defaults are immutable tuples and each instance keeps its own list
        # copy, so mutating one instance's settings cannot leak into others.
        self.fft_sizes = list(fft_sizes)
        self.hop_sizes = list(hop_sizes)
        self.win_lengths = list(win_lengths)
        self.device = device
        self.l1_loss = nn.L1Loss()

    def forward(self, pred_spec, target_spec):
        """
        Args:
            pred_spec: Predicted spectrogram
            target_spec: Target spectrogram, same shape as pred_spec

        Returns:
            Scalar tensor: mean absolute error + ||target - pred|| / (||target|| + eps),
            where ||.|| is the 2-norm over all elements
        """
        # Basic L1 loss between predicted and target spectrograms
        l1_loss = self.l1_loss(pred_spec, target_spec)

        # Spectral convergence loss (normalized Frobenius norm of difference).
        # Norms are taken in float32 so the sum of squares cannot overflow when
        # the inputs are half precision; epsilon guards an all-zero target.
        epsilon = 1e-8
        difference = (target_spec - pred_spec).float()
        sc_loss = torch.linalg.vector_norm(difference) / (
            torch.linalg.vector_norm(target_spec.float()) + epsilon
        )

        # Combine losses
        return l1_loss + sc_loss


# Improved neural network architecture with residual connections and deeper layers
class ImprovedEEGToMelModel(nn.Module):
    """
    Fully connected network with residual blocks and a learned gate that maps
    an EEG feature vector to a mel-spectrogram.

    Layer names are kept stable so state-dict keys do not change.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, mel_channels):
        """
        Initialize model with improved architecture

        Args:
            input_dim: Dimension of input EEG features
            hidden_dim: Dimension of hidden layers
            output_dim: Dimension of output sequence length (time frames)
            mel_channels: Number of mel channels in the output spectrogram
        """
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.mel_channels = mel_channels

        # Initial normalization
        self.input_norm = nn.BatchNorm1d(input_dim)

        # First block with residual connection (projection matches widths)
        self.block1_fc1 = nn.Linear(input_dim, hidden_dim)
        self.block1_bn1 = nn.BatchNorm1d(hidden_dim)
        self.block1_fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.block1_bn2 = nn.BatchNorm1d(hidden_dim)
        self.block1_residual = nn.Linear(input_dim, hidden_dim)

        # Second block with residual connection
        self.block2_fc1 = nn.Linear(hidden_dim, hidden_dim // 2)
        self.block2_bn1 = nn.BatchNorm1d(hidden_dim // 2)
        self.block2_fc2 = nn.Linear(hidden_dim // 2, hidden_dim // 2)
        self.block2_bn2 = nn.BatchNorm1d(hidden_dim // 2)
        self.block2_residual = nn.Linear(hidden_dim, hidden_dim // 2)

        # Third block (no residual)
        self.block3_fc1 = nn.Linear(hidden_dim // 2, hidden_dim // 4)
        self.block3_bn1 = nn.BatchNorm1d(hidden_dim // 4)
        self.block3_fc2 = nn.Linear(hidden_dim // 4, hidden_dim // 4)
        self.block3_bn2 = nn.BatchNorm1d(hidden_dim // 4)

        # Gating ("attention") projections
        self.attention_query = nn.Linear(hidden_dim // 4, hidden_dim // 8)
        self.attention_key = nn.Linear(hidden_dim // 4, hidden_dim // 8)
        self.attention_value = nn.Linear(hidden_dim // 4, hidden_dim // 4)

        # Final processing
        self.fc_final = nn.Linear(hidden_dim // 4, hidden_dim // 4)
        self.bn_final = nn.BatchNorm1d(hidden_dim // 4)

        # Output projection
        self.output = nn.Linear(hidden_dim // 4, output_dim * mel_channels)

        # Activation functions
        self.relu = nn.ReLU()
        self.leaky_relu = nn.LeakyReLU(0.1)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        """
        Forward pass with residual connections and a learned gate

        Args:
            x: Input EEG features, shape [batch, input_dim]

        Returns:
            Predicted mel-spectrogram, shape [batch, mel_channels, output_dim]
        """
        # Input normalization
        x_norm = self.input_norm(x)

        # Block 1 with residual connection
        residual = self.block1_residual(x_norm)
        out = self.leaky_relu(self.block1_bn1(self.block1_fc1(x_norm)))
        out = self.dropout(out)
        out = self.leaky_relu(self.block1_bn2(self.block1_fc2(out)))
        out = out + residual

        # Block 2 with residual connection
        residual = self.block2_residual(out)
        out = self.leaky_relu(self.block2_bn1(self.block2_fc1(out)))
        out = self.dropout(out)
        out = self.leaky_relu(self.block2_bn2(self.block2_fc2(out)))
        out = out + residual

        # Block 3
        out = self.leaky_relu(self.block3_bn1(self.block3_fc1(out)))
        out = self.dropout(out)
        out = self.leaky_relu(self.block3_bn2(self.block3_fc2(out)))

        # Query/key/value projections of the single feature vector
        q = self.attention_query(out)  # [batch_size, hidden_dim//8]
        k = self.attention_key(out)    # [batch_size, hidden_dim//8]
        v = self.attention_value(out)  # [batch_size, hidden_dim//4]

        # There is only one "token" per sample, so a softmax over the scores
        # would always equal 1 and the query/key layers would never receive a
        # gradient. A sigmoid gate keeps them trainable: the scaled q.k score
        # decides how much of the value vector is mixed in.
        key_dim = max(q.shape[-1], 1)
        gate_logits = (q * k).sum(dim=-1, keepdim=True) / (key_dim ** 0.5)  # [batch_size, 1]
        gate = torch.sigmoid(gate_logits)

        context = gate * v  # [batch_size, hidden_dim//4]

        # Final processing
        out = out + context  # Add gated context (residual)
        out = self.leaky_relu(self.bn_final(self.fc_final(out)))

        # Output projection
        out = self.output(out)

        # Reshape to spectrogram dimensions [batch, mel_channels, time]
        out = out.view(-1, self.mel_channels, self.output_dim)

        return out


# Update your train function with mixed precision
def train(model, train_loader, val_loader, device, num_epochs=100, lr=0.001):
    """
    Train the model with a multi-resolution spectral loss, AdamW, mixed
    precision, gradient clipping, LR reduction on plateau and early stopping.

    Args:
        model: Network to optimise; called as ``model(eeg_features)``.
        train_loader: Iterable yielding ``(eeg_features, target_specs)`` batches.
        val_loader: Iterable yielding validation batches in the same format.
        device: Device the batches are moved to and the loss is built for.
        num_epochs: Maximum number of epochs (early stopping may end sooner).
        lr: Initial learning rate for AdamW.

    Returns:
        The same ``model`` object, with the weights of the epoch that reached
        the lowest validation loss loaded back in (if any epoch improved).

    Side effects:
        Prints per-epoch statistics and writes the loss curves to
        ``training_loss.png``.
    """
    import copy  # local import: only needed to snapshot the best weights

    # Custom spectral loss for mel-spectrograms
    criterion = MultiResolutionSpectralLoss(device=device)

    # AdamW with a small weight decay for regularization
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)

    # Halve the learning rate when the validation loss plateaus
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5, verbose=True
    )

    # Gradient scaler for mixed precision training
    scaler = GradScaler()

    # Best validation loss seen so far and the matching weights
    best_val_loss = float('inf')
    best_model_state = None

    # Early stopping bookkeeping
    early_stop_counter = 0
    early_stop_patience = 10

    # Training history (one entry per epoch)
    train_losses = []
    val_losses = []

    # Let cuDNN pick the fastest kernels for the fixed input shapes
    torch.backends.cudnn.benchmark = True

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        train_batches = 0  # batches that actually contributed to running_loss

        # Progress bar for training
        with tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}') as pbar:
            for i, (eeg_features, target_specs) in enumerate(pbar):
                # Move data to device
                eeg_features = eeg_features.to(device, non_blocking=True)
                target_specs = target_specs.to(device, non_blocking=True)

                # Zero the parameter gradients
                optimizer.zero_grad(set_to_none=True)

                # Forward pass with mixed precision
                with autocast():
                    outputs = model(eeg_features)
                    loss = criterion(outputs, target_specs)

                # A NaN loss would poison the weights: skip the batch entirely
                if torch.isnan(loss).any():
                    print(f"WARNING: NaN loss detected! Skipping batch {i}")
                    continue

                # Backward pass on the scaled loss
                scaler.scale(loss).backward()

                # Unscale first so the clipping threshold applies to true gradients
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

                # Optimizer step (skipped internally if gradients overflowed)
                scaler.step(optimizer)
                scaler.update()

                # Accumulate as a Python float so no graph is kept alive
                batch_loss = loss.item()
                running_loss += batch_loss
                train_batches += 1

                # Update progress bar
                pbar.set_postfix(loss=batch_loss, lr=optimizer.param_groups[0]['lr'])

                # Optional: free cached GPU memory periodically
                if (i+1) % 10 == 0:
                    torch.cuda.empty_cache()

        # Average only over the batches that were counted; NaN marks an epoch
        # without a single usable batch (also avoids dividing by zero).
        train_loss = running_loss / train_batches if train_batches else float('nan')
        train_losses.append(train_loss)

        # Validation phase
        model.eval()
        val_loss_sum = 0.0
        val_batches = 0

        with torch.no_grad():
            for eeg_features, target_specs in val_loader:
                eeg_features = eeg_features.to(device, non_blocking=True)
                target_specs = target_specs.to(device, non_blocking=True)

                # Use mixed precision for validation too
                with autocast():
                    outputs = model(eeg_features)
                    loss = criterion(outputs, target_specs)

                # Skip NaN losses during validation too
                if not torch.isnan(loss).any():
                    val_loss_sum += loss.item()
                    val_batches += 1

        # If every validation batch was NaN, report NaN rather than 0.0 so the
        # epoch can never be mistaken for the best one.
        val_loss = val_loss_sum / val_batches if val_batches else float('nan')
        val_losses.append(val_loss)

        # Print epoch statistics
        print(f'Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

        # Learning rate scheduler step
        scheduler.step(val_loss)

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # Deep copy: state_dict() tensors share storage with the live
            # parameters, so a shallow copy would follow later updates.
            best_model_state = copy.deepcopy(model.state_dict())
            early_stop_counter = 0
            print(f"New best model saved with validation loss: {best_val_loss:.4f}")
        else:
            early_stop_counter += 1
            print(f"Validation loss did not improve. Early stopping counter: {early_stop_counter}/{early_stop_patience}")

        # Early stopping
        if early_stop_counter >= early_stop_patience:
            print(f"Early stopping triggered after {epoch+1} epochs")
            break

        # Explicitly clear memory at the end of epoch
        torch.cuda.empty_cache()

    # Restore the best weights
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    # Plot training and validation loss
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.grid(True)
    plt.savefig('training_loss.png')
    plt.close()

    return model


def evaluate_model(model, test_loader, device, num_samples=5):
    """
    Run the model on a handful of test items and save a side-by-side plot of
    target versus predicted mel-spectrograms.

    Args:
        model: Trained EEGToMelModel
        test_loader: DataLoader for test data
        device: Device to evaluate on (cuda/cpu)
        num_samples: Number of samples to visualize

    The figure is written to 'evaluation_results/spectrogram_comparison.png'.
    """
    model.eval()

    test_samples = []

    def _draw_panel(position, spectrogram, title):
        # One cell of the (num_samples x 2) comparison grid
        plt.subplot(num_samples, 2, position)
        plt.imshow(spectrogram, aspect='auto', origin='lower')
        plt.colorbar(format='%+2.0f dB')
        plt.title(title)
        plt.tight_layout()

    # Collect predictions one item at a time until enough samples are gathered
    with torch.no_grad():
        for eeg_features, target_specs in test_loader:
            still_needed = num_samples - len(test_samples)
            take = min(eeg_features.shape[0], still_needed)

            for i in range(take):
                # Slice (rather than index) to keep the batch dimension
                single_eeg = eeg_features[i:i+1].to(device)
                single_target = target_specs[i:i+1].to(device)

                pred_spec = model(single_eeg)

                test_samples.append({
                    'eeg': single_eeg.cpu(),
                    'target': single_target.cpu(),
                    'pred': pred_spec.cpu()
                })

            if len(test_samples) >= num_samples:
                break

    # Output directory for the figure
    os.makedirs('evaluation_results', exist_ok=True)

    # One row per sample: target on the left, prediction on the right
    plt.figure(figsize=(15, num_samples*5))

    for i, sample in enumerate(test_samples):
        target_spec = sample['target'].squeeze().numpy()
        pred_spec = sample['pred'].squeeze().numpy()

        _draw_panel(i*2+1, target_spec, f'Sample {i+1}: Target Mel-Spectrogram')
        _draw_panel(i*2+2, pred_spec, f'Sample {i+1}: Predicted Mel-Spectrogram')

    plt.savefig('evaluation_results/spectrogram_comparison.png')
    plt.close()

    print(f"Evaluation complete. Visualization saved to 'evaluation_results/spectrogram_comparison.png'")


# Update the main function to optimize data loading
def main():
    """
    End-to-end run: build the dataset, split it, train the model, save the
    weights and render evaluation plots (settings tuned for Kaggle P100).
    """
    # Reproducibility: seed torch and numpy
    torch.manual_seed(42)
    np.random.seed(42)

    # Run configuration
    data_path = '/kaggle/working/session_1.pth'  # EEG data
    label_file = '/kaggle/input/labels-file/labels_txt.txt'  # label file
    batch_size = 64  # larger batches for better GPU utilization
    num_epochs = 100
    spec_length_seconds = 1.0

    # Let cuDNN auto-tune kernels (deterministic mode intentionally left off)
    torch.backends.cudnn.benchmark = True

    # Pick the device once and reuse the CUDA check for pinned memory
    cuda_available = torch.cuda.is_available()
    device = torch.device("cuda" if cuda_available else "cpu")
    print(f"Using device: {device}")

    pin_memory = cuda_available
    num_workers = 2  # kept small on purpose for Kaggle notebooks

    # Dataset (features are precomputed up front despite the startup cost)
    print("Loading and preprocessing dataset...")
    dataset = EEGMelDataset(
        data_path=data_path,
        label_file=label_file,
        transform=True,
        precompute=True,
        spec_length_seconds=spec_length_seconds
    )

    # Derive the model dimensions from the first item
    first_features, first_mel = dataset[0]
    input_dim = first_features.shape[0]
    mel_channels, output_dim = first_mel.shape[0], first_mel.shape[1]

    print(f"Input feature dimension: {input_dim}")
    print(f"Output mel-spectrogram shape: {mel_channels}x{output_dim}")

    # 70 / 15 / 15 split; the remainder goes to the test set
    n_items = len(dataset)
    train_size = int(0.7 * n_items)
    val_size = int(0.15 * n_items)
    test_size = n_items - train_size - val_size

    # Fixed generator so the split is reproducible
    split_generator = torch.Generator().manual_seed(42)
    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size, test_size], generator=split_generator
    )

    # Options shared by the training and validation loaders
    worker_options = dict(
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=True,  # keep workers alive between epochs
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        prefetch_factor=2,  # pre-fetch 2 batches per worker
        **worker_options
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        **worker_options
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=1,
        pin_memory=pin_memory
    )

    # Model
    print("Initializing model...")
    model = ImprovedEEGToMelModel(
        input_dim=input_dim,
        hidden_dim=512,
        output_dim=output_dim,
        mel_channels=mel_channels
    ).to(device)

    # Report the parameter count
    total_params = sum(param.numel() for param in model.parameters())
    print(f"Model created with {total_params:,} parameters")

    # Train
    print("Starting training...")
    model = train(model, train_loader, val_loader, device, num_epochs=num_epochs)

    # Persist the trained weights
    torch.save(model.state_dict(), 'eeg_to_mel_model.pth')
    print("Model saved to 'eeg_to_mel_model.pth'")

    # Evaluate
    print("Evaluating model...")
    evaluate_model(model, test_loader, device)

# if __name__ == "__main__":
#     main()

In [None]:
# Mount Google Drive at /content/drive (Colab only); the test-file directory
# used by the submission cell below lives under this mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os
import torch
import torchaudio
import matplotlib.pyplot as plt
import numpy as np
import glob
from tqdm import tqdm

# Import your model classes and preprocessing modules
# Make sure these imports match your actual implementation
# from model import ImprovedEEGToMelModel
# from preprocessing import EEGPreprocessor, OptimizedFeatureExtractor

def load_model(model_path, device):
    """
    Load the saved EEG to mel-spectrogram model

    Two checkpoint layouts are accepted:
      * a dict holding 'model_state_dict' plus (optionally) the keys
        'input_dim', 'hidden_dim', 'output_dim' and 'mel_channels';
      * a bare state dict, in which case the dimensions are inferred from
        the weight shapes.

    Args:
        model_path: Path to the saved model checkpoint
        device: Device to load the model on

    Returns:
        Tuple ``(model, output_dim, mel_channels)`` with the model moved to
        ``device`` and switched to eval mode.

    Raises:
        ValueError: If the output layer size is not a multiple of the number
            of mel channels, so the time dimension cannot be inferred.
    """
    dim_keys = ('input_dim', 'hidden_dim', 'output_dim', 'mel_channels')

    # Load checkpoint
    checkpoint = torch.load(model_path, map_location=device)

    # Unwrap structured checkpoints; anything else is treated as a state dict
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        model_state_dict = checkpoint['model_state_dict']
        saved_dims = {key: checkpoint[key] for key in dim_keys if key in checkpoint}
    else:
        model_state_dict = checkpoint
        saved_dims = {}

    if len(saved_dims) == len(dim_keys):
        # Everything was stored alongside the weights
        input_dim = saved_dims['input_dim']
        hidden_dim = saved_dims['hidden_dim']
        output_dim = saved_dims['output_dim']
        mel_channels = saved_dims['mel_channels']
    else:
        # Infer whatever is missing from the weight shapes
        first_fc_weight = model_state_dict['block1_fc1.weight']
        input_dim = saved_dims.get('input_dim', first_fc_weight.size(1))
        hidden_dim = saved_dims.get('hidden_dim', first_fc_weight.size(0))
        # 80 is the standard number from the MelSpectrogramGenerator class
        mel_channels = saved_dims.get('mel_channels', 80)

        # The output layer is reshaped to (mel_channels, output_dim) by the
        # model, so its row count must be an exact multiple of mel_channels.
        output_rows = model_state_dict['output.weight'].size(0)
        if output_rows % mel_channels != 0:
            raise ValueError(
                f"Cannot infer output_dim: output layer has {output_rows} rows, "
                f"which is not a multiple of mel_channels={mel_channels}"
            )
        output_dim = saved_dims.get('output_dim', output_rows // mel_channels)

    # Initialize model
    model = ImprovedEEGToMelModel(input_dim, hidden_dim, output_dim, mel_channels)
    model.load_state_dict(model_state_dict)
    model = model.to(device)
    model.eval()

    print(f"Model loaded with:")
    print(f"- Input dimension: {input_dim}")
    print(f"- Hidden dimension: {hidden_dim}")
    print(f"- Output time dimension: {output_dim} frames")
    print(f"- Mel channels: {mel_channels}")

    return model, output_dim, mel_channels

def process_eeg_sample(eeg_data, preprocessor, feature_extractor):
    """
    Turn one raw EEG sample into a float32 feature tensor.

    Args:
        eeg_data: Raw EEG data
        preprocessor: Object exposing ``process(data, verbose=...)``
        feature_extractor: Object exposing ``extract_features(data, verbose=...)``

    Returns:
        The extracted features wrapped in a ``torch.float32`` tensor
    """
    # Clean the signal, derive features, then hand them to torch in one chain
    return torch.tensor(
        feature_extractor.extract_features(
            preprocessor.process(eeg_data, verbose=False),
            verbose=False,
        ),
        dtype=torch.float32,
    )

def predict_mel_spectrogram(model, eeg_features, device):
    """
    Run a single (unbatched) feature tensor through the model.

    Args:
        model: Trained EEG to mel-spectrogram model
        eeg_features: EEG features for one sample (no batch dimension)
        device: Device to run inference on

    Returns:
        The model output on CPU with the leading batch dimension removed
    """
    # The model expects a batch, so wrap the sample in a batch of one
    batch = eeg_features.unsqueeze(0)
    batch = batch.to(device)

    # Inference only: no autograd bookkeeping
    with torch.no_grad():
        prediction = model(batch)
        prediction = prediction.cpu()
        prediction = prediction.squeeze(0)

    return prediction

def convert_mel_to_audio(mel_spectrogram):
    """
    Convert mel-spectrogram to audio using WaveRNN vocoder

    Args:
        mel_spectrogram: Predicted mel-spectrogram, either
            [mel_channels, time] or [batch, mel_channels, time]

    Returns:
        Tuple ``(waveforms, sample_rate)``: the generated waveforms on CPU and
        the vocoder's output sample rate as an integer.
    """
    # Get the WaveRNN vocoder from torchaudio
    bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_CHAR_LJSPEECH
    vocoder = bundle.get_vocoder()

    # Run on whatever device the vocoder weights live on
    device = next(vocoder.parameters()).device

    # The vocoder expects [batch, mel_channels, time]; add the batch dimension
    # when a single spectrogram is passed in.
    if mel_spectrogram.dim() == 2:
        mel_spectrogram = mel_spectrogram.unsqueeze(0)

    # Move to the device
    mel_spectrogram = mel_spectrogram.to(device)

    # Set lengths tensor (assuming full length)
    lengths = torch.tensor([mel_spectrogram.shape[2]], device=device)

    # The vocoder returns (waveforms, valid_lengths). The second element is
    # the number of valid samples per item, NOT the sample rate.
    with torch.no_grad():
        waveforms, _ = vocoder(mel_spectrogram, lengths)

    # The sample rate is a property of the vocoder; cast to int because audio
    # writers expect an integer rate.
    sample_rate = int(vocoder.sample_rate)

    return waveforms.cpu(), sample_rate

def load_eeg_from_file(file_path):
    """
    Read EEG data out of a .pth file, whichever supported layout it uses.

    Supported layouts:
      * a bare tensor;
      * a dict with an 'eeg_data' entry;
      * a dict with a non-empty 'dataset' sequence whose first item has an
        'eeg_data' entry.

    Args:
        file_path: Path to the .pth file containing EEG data

    Returns:
        EEG data from the file

    Raises:
        ValueError: If none of the supported layouts matches.
    """
    payload = torch.load(file_path, map_location=torch.device('cpu'))

    # Bare tensor: nothing to unwrap
    if isinstance(payload, torch.Tensor):
        return payload

    if isinstance(payload, dict):
        # Direct entry takes precedence over the nested dataset layout
        if 'eeg_data' in payload:
            return payload['eeg_data']
        if 'dataset' in payload and len(payload['dataset']) > 0:
            first_item = payload['dataset'][0]
            return first_item['eeg_data']

    raise ValueError(f"Could not extract EEG data from file: {file_path}")

def generate_submission(model_path, test_files_dir, output_dir='submission'):
    """
    Generate submission files for the competition

    For every ``*.pth`` file in ``test_files_dir`` this predicts a
    mel-spectrogram, vocodes it to audio, and writes four files to
    ``output_dir``: ``<name>_mel.pth``, ``<name>_mel.png``,
    ``<name>_audio.wav`` and ``<name>_waveform.png``. A failure on one file
    is reported and does not stop the remaining files from being processed.

    Args:
        model_path: Path to the saved model
        test_files_dir: Directory containing the test .pth files
        output_dir: Directory to save outputs
    """
    import traceback  # local import: only used to report per-file failures

    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load model (the returned dimensions are not needed here)
    model, _, _ = load_model(model_path, device)

    # Sort so files are processed in a stable, reproducible order
    # (glob returns them in arbitrary filesystem order)
    test_files = sorted(glob.glob(os.path.join(test_files_dir, "*.pth")))
    print(f"Found {len(test_files)} test files")

    if len(test_files) != 12:
        print(f"Warning: Expected 12 test files, but found {len(test_files)}")

    # Initialize preprocessor and feature extractor
    preprocessor = EEGPreprocessor()
    feature_extractor = OptimizedFeatureExtractor(feature_mode='full')

    # Names of the files that could not be processed
    failed_files = []

    # Process each test file
    for idx, test_file in enumerate(tqdm(test_files, desc="Processing test files")):
        # Extract file name without extension
        file_name = os.path.splitext(os.path.basename(test_file))[0]
        print(f"\nProcessing test file {idx + 1}/{len(test_files)}: {file_name}")

        try:
            # Load EEG data
            eeg_data = load_eeg_from_file(test_file)
            print(f"- EEG data shape: {eeg_data.shape}")

            # Process EEG sample
            print("- Extracting features...")
            eeg_features = process_eeg_sample(eeg_data, preprocessor, feature_extractor)

            # Predict mel-spectrogram
            print("- Predicting mel-spectrogram...")
            predicted_mel = predict_mel_spectrogram(model, eeg_features, device)

            # Save the predicted mel-spectrogram as .pth file
            mel_path = os.path.join(output_dir, f"{file_name}_mel.pth")
            torch.save(predicted_mel, mel_path)
            print(f"- Mel spectrogram saved to {mel_path}")

            # Save mel-spectrogram visualization; the figure is always closed,
            # even when drawing or saving fails, so figures do not accumulate.
            mel_plot_path = os.path.join(output_dir, f"{file_name}_mel.png")
            mel_figure = plt.figure(figsize=(10, 5))
            try:
                plt.imshow(predicted_mel.numpy(), aspect='auto', origin='lower')
                plt.colorbar(format='%+2.0f dB')
                plt.title(f"Test File: {file_name} - Predicted Mel Spectrogram")
                plt.xlabel("Time")
                plt.ylabel("Mel Channels")
                plt.tight_layout()
                plt.savefig(mel_plot_path)
            finally:
                plt.close(mel_figure)
            print(f"- Mel spectrogram visualization saved to {mel_plot_path}")

            # Convert predicted mel to audio
            print("- Converting predicted mel to audio...")
            waveform, sample_rate = convert_mel_to_audio(predicted_mel)

            # Save audio file
            audio_path = os.path.join(output_dir, f"{file_name}_audio.wav")
            torchaudio.save(audio_path, waveform, sample_rate)
            print(f"- Audio saved to {audio_path}")

            # Plot waveform (same close-on-error guarantee as above)
            wave_plot_path = os.path.join(output_dir, f"{file_name}_waveform.png")
            wave_figure = plt.figure(figsize=(10, 3))
            try:
                plt.plot(waveform.squeeze().numpy())
                plt.title(f"Test File: {file_name} - Predicted Waveform")
                plt.xlabel("Time (samples)")
                plt.ylabel("Amplitude")
                plt.tight_layout()
                plt.savefig(wave_plot_path)
            finally:
                plt.close(wave_figure)
            print(f"- Waveform visualization saved to {wave_plot_path}")

        except Exception as e:
            # Best effort: report the failure and continue with the next file
            print(f"- Error processing test file {file_name}: {e}")
            traceback.print_exc()
            failed_files.append(file_name)

    print("\nSubmission generation complete!")
    print(f"All files saved in {output_dir}")

    # Make failures visible in the summary instead of only in the scroll-back
    if failed_files:
        print(f"Warning: {len(failed_files)} test file(s) failed: {', '.join(failed_files)}")

    # Create a list of files to be included in the zip for submission
    # (sorted so the listing is stable between runs)
    submission_files = [
        os.path.join(output_dir, file_name)
        for file_name in sorted(os.listdir(output_dir))
        if file_name.endswith(("_mel.pth", "_audio.wav"))
    ]

    print(f"\nFiles to include in submission zip ({len(submission_files)}):")
    for file in submission_files:
        print(f"- {os.path.basename(file)}")

    print("\nPlease zip these files and upload them to the competition website.")

if __name__ == "__main__":
    # Runs the submission pipeline when this cell/script is executed directly.
    # The three names below are assigned at module level, so they remain
    # available as notebook globals after the cell has run.
    # Update these paths according to your environment
    model_path = 'eeg_to_mel_model.pth'  # Path to your trained model
    test_files_dir = '/content/drive/MyDrive/TEST'           # Directory containing 12 test .pth files
    output_dir = './submission'                   # Output directory for submission files

    # Writes the predicted mel-spectrograms, audio and plots into output_dir
    generate_submission(model_path, test_files_dir, output_dir)

Using device: cpu
Model loaded with:
- Input dimension: 1490
- Hidden dimension: 512
- Output time dimension: 86 frames
- Mel channels: 80
Found 12 test files


Processing test files:   0%|          | 0/12 [00:00<?, ?it/s]


Processing test file 1/12: Copy of sample_2
- EEG data shape: torch.Size([62, 501])
- Extracting features...
- Predicting mel-spectrogram...
- Mel spectrogram saved to ./submission/Copy of sample_2_mel.pth
- Mel spectrogram visualization saved to ./submission/Copy of sample_2_mel.png
- Converting predicted mel to audio...


Downloading: "https://download.pytorch.org/torchaudio/models/wavernn_10k_epochs_8bits_ljspeech.pth" to /root/.cache/torch/hub/checkpoints/wavernn_10k_epochs_8bits_ljspeech.pth

  0%|          | 0.00/16.7M [00:00<?, ?B/s][A
  4%|▍         | 640k/16.7M [00:00<00:02, 6.43MB/s][A
100%|██████████| 16.7M/16.7M [00:00<00:00, 66.8MB/s]


- Audio saved to ./submission/Copy of sample_2_audio.wav


Processing test files:   8%|▊         | 1/12 [01:21<14:52, 81.18s/it]

- Waveform visualization saved to ./submission/Copy of sample_2_waveform.png

Processing test file 2/12: Copy of sample_1
- EEG data shape: torch.Size([62, 501])
- Extracting features...
- Predicting mel-spectrogram...
- Mel spectrogram saved to ./submission/Copy of sample_1_mel.pth
- Mel spectrogram visualization saved to ./submission/Copy of sample_1_mel.png
- Converting predicted mel to audio...
- Audio saved to ./submission/Copy of sample_1_audio.wav


Processing test files:  17%|█▋        | 2/12 [02:36<12:58, 77.86s/it]

- Waveform visualization saved to ./submission/Copy of sample_1_waveform.png

Processing test file 3/12: Copy of sample_7
- EEG data shape: torch.Size([62, 501])
- Extracting features...
- Predicting mel-spectrogram...
- Mel spectrogram saved to ./submission/Copy of sample_7_mel.pth
- Mel spectrogram visualization saved to ./submission/Copy of sample_7_mel.png
- Converting predicted mel to audio...


Processing test files:  25%|██▌       | 3/12 [03:54<11:40, 77.87s/it]

- Audio saved to ./submission/Copy of sample_7_audio.wav
- Waveform visualization saved to ./submission/Copy of sample_7_waveform.png

Processing test file 4/12: Copy of sample_6
- EEG data shape: torch.Size([62, 501])
- Extracting features...
- Predicting mel-spectrogram...
- Mel spectrogram saved to ./submission/Copy of sample_6_mel.pth
- Mel spectrogram visualization saved to ./submission/Copy of sample_6_mel.png
- Converting predicted mel to audio...


Processing test files:  33%|███▎      | 4/12 [05:10<10:16, 77.10s/it]

- Audio saved to ./submission/Copy of sample_6_audio.wav
- Waveform visualization saved to ./submission/Copy of sample_6_waveform.png

Processing test file 5/12: Copy of sample_5
- EEG data shape: torch.Size([62, 501])
- Extracting features...
- Predicting mel-spectrogram...
- Mel spectrogram saved to ./submission/Copy of sample_5_mel.pth
- Mel spectrogram visualization saved to ./submission/Copy of sample_5_mel.png
- Converting predicted mel to audio...


Processing test files:  42%|████▏     | 5/12 [06:25<08:55, 76.51s/it]

- Audio saved to ./submission/Copy of sample_5_audio.wav
- Waveform visualization saved to ./submission/Copy of sample_5_waveform.png

Processing test file 6/12: Copy of sample_4
- EEG data shape: torch.Size([62, 501])
- Extracting features...
- Predicting mel-spectrogram...
- Mel spectrogram saved to ./submission/Copy of sample_4_mel.pth
- Mel spectrogram visualization saved to ./submission/Copy of sample_4_mel.png
- Converting predicted mel to audio...
- Audio saved to ./submission/Copy of sample_4_audio.wav


Processing test files:  50%|█████     | 6/12 [07:41<07:36, 76.01s/it]

- Waveform visualization saved to ./submission/Copy of sample_4_waveform.png

Processing test file 7/12: Copy of sample_3
- EEG data shape: torch.Size([62, 501])
- Extracting features...
- Predicting mel-spectrogram...
- Mel spectrogram saved to ./submission/Copy of sample_3_mel.pth
- Mel spectrogram visualization saved to ./submission/Copy of sample_3_mel.png
- Converting predicted mel to audio...
- Audio saved to ./submission/Copy of sample_3_audio.wav


Processing test files:  58%|█████▊    | 7/12 [09:11<06:42, 80.60s/it]

- Waveform visualization saved to ./submission/Copy of sample_3_waveform.png

Processing test file 8/12: Copy of sample_12
- EEG data shape: torch.Size([62, 501])
- Extracting features...
- Predicting mel-spectrogram...
- Mel spectrogram saved to ./submission/Copy of sample_12_mel.pth
- Mel spectrogram visualization saved to ./submission/Copy of sample_12_mel.png
- Converting predicted mel to audio...


Processing test files:  67%|██████▋   | 8/12 [10:26<05:16, 79.01s/it]

- Audio saved to ./submission/Copy of sample_12_audio.wav
- Waveform visualization saved to ./submission/Copy of sample_12_waveform.png

Processing test file 9/12: Copy of sample_11
- EEG data shape: torch.Size([62, 501])
- Extracting features...
- Predicting mel-spectrogram...
- Mel spectrogram saved to ./submission/Copy of sample_11_mel.pth
- Mel spectrogram visualization saved to ./submission/Copy of sample_11_mel.png
- Converting predicted mel to audio...


Processing test files:  75%|███████▌  | 9/12 [11:42<03:53, 77.95s/it]

- Audio saved to ./submission/Copy of sample_11_audio.wav
- Waveform visualization saved to ./submission/Copy of sample_11_waveform.png

Processing test file 10/12: Copy of sample_10
- EEG data shape: torch.Size([62, 501])
- Extracting features...
- Predicting mel-spectrogram...
- Mel spectrogram saved to ./submission/Copy of sample_10_mel.pth
- Mel spectrogram visualization saved to ./submission/Copy of sample_10_mel.png
- Converting predicted mel to audio...


Processing test files:  83%|████████▎ | 10/12 [12:57<02:34, 77.21s/it]

- Audio saved to ./submission/Copy of sample_10_audio.wav
- Waveform visualization saved to ./submission/Copy of sample_10_waveform.png

Processing test file 11/12: Copy of sample_9
- EEG data shape: torch.Size([62, 501])
- Extracting features...
- Predicting mel-spectrogram...
- Mel spectrogram saved to ./submission/Copy of sample_9_mel.pth
- Mel spectrogram visualization saved to ./submission/Copy of sample_9_mel.png
- Converting predicted mel to audio...


Processing test files:  92%|█████████▏| 11/12 [14:13<01:16, 76.81s/it]

- Audio saved to ./submission/Copy of sample_9_audio.wav
- Waveform visualization saved to ./submission/Copy of sample_9_waveform.png

Processing test file 12/12: Copy of sample_8
- EEG data shape: torch.Size([62, 501])
- Extracting features...
- Predicting mel-spectrogram...
- Mel spectrogram saved to ./submission/Copy of sample_8_mel.pth
- Mel spectrogram visualization saved to ./submission/Copy of sample_8_mel.png
- Converting predicted mel to audio...
- Audio saved to ./submission/Copy of sample_8_audio.wav


Processing test files: 100%|██████████| 12/12 [15:29<00:00, 77.46s/it]

- Waveform visualization saved to ./submission/Copy of sample_8_waveform.png

Submission generation complete!
All files saved in ./submission

Files to include in submission zip (24):
- Copy of sample_9_mel.pth
- Copy of sample_9_audio.wav
- Copy of sample_3_audio.wav
- Copy of sample_2_mel.pth
- Copy of sample_6_mel.pth
- Copy of sample_12_audio.wav
- Copy of sample_3_mel.pth
- Copy of sample_8_audio.wav
- Copy of sample_11_audio.wav
- Copy of sample_1_audio.wav
- Copy of sample_5_audio.wav
- Copy of sample_6_audio.wav
- Copy of sample_7_mel.pth
- Copy of sample_8_mel.pth
- Copy of sample_12_mel.pth
- Copy of sample_11_mel.pth
- Copy of sample_7_audio.wav
- Copy of sample_5_mel.pth
- Copy of sample_2_audio.wav
- Copy of sample_4_mel.pth
- Copy of sample_1_mel.pth
- Copy of sample_10_mel.pth
- Copy of sample_10_audio.wav
- Copy of sample_4_audio.wav

Please zip these files and upload them to the competition website.





# Inference Using HiFi-GAN
##### Given separately here because it was hard to accommodate within my code as per the stencil

In [None]:
def inference(model, hifigan_vocoder, waveform, device):
    """
    Run `model` on an input tensor and vocode its output with `hifigan_vocoder`.

    Args:
        model: Module called as ``model(waveform)``; its output is treated as
            a log-mel spectrogram.
        hifigan_vocoder: Module called on the (possibly rescaled) spectrogram
            to produce audio.
        waveform: Input tensor; a 2-D input gets a leading batch dimension.
            NOTE(review): despite the name, confirm whether this is raw audio
            or EEG features, since the other models in this notebook take EEG
            features.
        device: Device the input is moved to before the forward pass.

    Returns:
        Tuple of (generated audio, squeezed and moved to CPU; the spectrogram
        that was fed to the vocoder, moved to CPU). If rescaling was applied,
        the returned spectrogram is the rescaled one.
    """
    # Do not change the code

    # Both networks are put in eval mode for inference
    model.eval()
    hifigan_vocoder.eval()

    with torch.no_grad():
        # Ensure proper shape and move to device
        if waveform.dim() == 2:
            waveform = waveform.unsqueeze(0)
        waveform = waveform.to(device)

        # Generate log-mel spectrogram from input waveform
        log_mel_spec = model(waveform)

        # Scale the log-mel spectrogram to match HiFi-GAN's expected input range
        # Applied only when the output contains negative values; it maps
        # [-1, 1] onto [0, 1].
        # NOTE(review): presumably assumes the model output lies in [-1, 1] — confirm
        if log_mel_spec.min() < 0:
            log_mel_spec = (log_mel_spec + 1) / 2

        # Generate audio from log-mel spectrogram using HiFi-GAN
        generated_audio = hifigan_vocoder(log_mel_spec)
    return generated_audio.squeeze().cpu(), log_mel_spec.cpu()

In [None]:
def inspect_model(model_path):
    """
    Print a summary of what a saved checkpoint file contains.

    Always prints the type of the loaded object. For a dict it also prints
    each top-level key with the type of its value, the nested keys of
    dict-valued entries, and the state-dict keys of entries that expose a
    ``state_dict`` method.

    Args:
        model_path: Path to the checkpoint file to inspect
    """
    loaded = torch.load(model_path, map_location='cpu')
    print(f"Type of loaded object: {type(loaded)}")

    # Only dictionaries have further structure worth listing
    if not isinstance(loaded, dict):
        return

    print("Keys in the dictionary:")
    for key in loaded:
        entry = loaded[key]
        print(f"- {key}: {type(loaded[key])}")

        if isinstance(entry, dict):
            # Nested mapping, e.g. a stored state dict
            print(f"  Nested keys in {key}:")
            for nested_key in entry:
                print(f"  - {nested_key}")
            continue

        if hasattr(entry, 'state_dict'):
            # Module-like object stored whole
            print(f"  Has state_dict with keys:")
            for state_key in entry.state_dict():
                print(f"  - {state_key}")

# Print the structure of the saved checkpoint file (top-level keys, value
# types and nested keys).
inspect_model("eeg_to_mel_model_final.pth")

Type of loaded object: <class 'dict'>
Keys in the dictionary:
- model_state_dict: <class 'collections.OrderedDict'>
  Nested keys in model_state_dict:
  - fc1.weight
  - fc1.bias
  - bn1.weight
  - bn1.bias
  - bn1.running_mean
  - bn1.running_var
  - bn1.num_batches_tracked
  - fc2.weight
  - fc2.bias
  - bn2.weight
  - bn2.bias
  - bn2.running_mean
  - bn2.running_var
  - bn2.num_batches_tracked
  - fc3.weight
  - fc3.bias
  - bn3.weight
  - bn3.bias
  - bn3.running_mean
  - bn3.running_var
  - bn3.num_batches_tracked
  - output.weight
  - output.bias
- input_dim: <class 'int'>
- hidden_dim: <class 'int'>
- output_dim: <class 'int'>
- mel_channels: <class 'int'>
