# EEGMamba Finetuning on PhysioNet-MI Dataset

This notebook provides a complete guide for finetuning EEGMamba on the PhysioNet Motor Imagery dataset in Google Colab.

## Dataset Overview
- **Dataset**: PhysioNet Motor Movement/Imagery Dataset (EEG-MMI)
- **Task**: Motor imagery classification (4 classes: left hand, right hand, feet, tongue)
- **Subjects**: 109 subjects (70 train, 19 val, 20 test)
- **Channels**: 64 EEG channels
- **Sampling Rate**: 200 Hz after preprocessing
- **Data Format**: LMDB database with preprocessed epochs

## Prerequisites
✅ Google Drive with EEGMamba repository  
✅ Preprocessed PhysioNet data in LMDB format  
✅ GPU-enabled Colab runtime (recommended)

## Step 1: Environment Setup and Dependencies

First, let's mount Google Drive and set up our environment.

In [1]:
# Verify all imports work correctly
print("🔍 Testing all critical imports...")

try:
    import torch
    import einops
    import numpy as np
    import scipy
    import sklearn
    import mne
    import lmdb
    print("✅ Basic packages imported successfully!")
    
    # Test the critical ones
    import mamba_ssm
    import causal_conv1d
    print("✅ Mamba-SSM ecosystem imported successfully!")
    
    # Print versions for debugging
    print(f"\n📋 Package versions:")
    print(f"   • PyTorch: {torch.__version__}")
    print(f"   • NumPy: {np.__version__}")
    print(f"   • MNE: {mne.__version__}")
    print(f"   • Mamba-SSM: {mamba_ssm.__version__}")
    
    print("\n🎉 All dependencies are working! Ready to proceed.")
    
except ImportError as e:
    print(f"❌ Import failed: {e}")
    print("\n🔧 If you see import errors:")
    print("1. Restart runtime (Runtime → Restart Runtime)")
    print("2. Re-run the installation cells")
    print("3. Check error messages for specific package issues")

🔍 Testing all critical imports...
✅ Basic packages imported successfully!


  from .autonotebook import tqdm as notebook_tqdm


✅ Mamba-SSM ecosystem imported successfully!

📋 Package versions:
   • PyTorch: 2.5.1+cu121
   • NumPy: 1.26.2
   • MNE: 1.10.2
   • Mamba-SSM: 2.2.6.post3

🎉 All dependencies are working! Ready to proceed.


## Step 2: Data Preprocessing (Optional)

⚠️ **Note**: This step is only needed if you haven't preprocessed your data yet. If you already have the processed LMDB database, skip to Step 3.

The following cell contains the preprocessing script for PhysioNet-MI data. It will:
- Load raw EEG files from PhysioNet dataset
- Apply filtering, rereferencing, and resampling
- Extract motor imagery epochs
- Save to LMDB database format

In [2]:
# PhysioNet-MI Data Preprocessing Script
# Run this only if you need to preprocess raw data

import os
import lmdb
import pickle
import numpy as np
import mne
from tqdm import tqdm

# Configuration
tasks = ['04', '06', '08', '10', '12', '14']  # Motor imagery tasks
root_dir = '/home/mahmood/HosseinDahaei/Codes/EEGMamba/'
output_db_path = '/home/mahmood/HosseinDahaei/Codes/EEGMamba/data'

print("🔧 PhysioNet-MI Preprocessing Configuration:")
print(f"📂 Raw data path: {root_dir}")
print(f"💾 Output database: {output_db_path}")
print(f"🎯 Tasks: {tasks}")

# Check if raw data exists
if not os.path.exists(root_dir):
    print("❌ Raw data directory not found!")
    print("Please ensure PhysioNet data is downloaded to the specified path.")
else:
    print("✅ Raw data directory found!")
    
    # List available subjects
    files = [f for f in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, f))]
    files = sorted(files)
    print(f"📊 Found {len(files)} subjects: {files[:5]}...{files[-5:]}")
    
    # Split subjects
    files_dict = {
        'train': files[:70],
        'val': files[70:89], 
        'test': files[89:109],
    }
    
    print(f"📈 Data split: Train={len(files_dict['train'])}, Val={len(files_dict['val'])}, Test={len(files_dict['test'])}")

# Set this to True only if you want to run preprocessing
RUN_PREPROCESSING = False

if RUN_PREPROCESSING:
    print("\n🚀 Starting preprocessing...")
else:
    print("\n⏸️ Preprocessing skipped. Set RUN_PREPROCESSING=True to run.")

🔧 PhysioNet-MI Preprocessing Configuration:
📂 Raw data path: /home/mahmood/HosseinDahaei/Codes/EEGMamba/
💾 Output database: /home/mahmood/HosseinDahaei/Codes/EEGMamba/data
🎯 Tasks: ['04', '06', '08', '10', '12', '14']
✅ Raw data directory found!
📊 Found 9 subjects: ['.tmp', 'data', 'datasets', 'figure', 'models']...['models', 'modules', 'preprocessing', 'pretrained_weights', 'utils']
📈 Data split: Train=9, Val=0, Test=0

⏸️ Preprocessing skipped. Set RUN_PREPROCESSING=True to run.


In [3]:
# Actual preprocessing implementation (only runs if enabled above)
if RUN_PREPROCESSING and os.path.exists(root_dir):
    print("🔄 Running preprocessing...")
    
    # EEG channel selection (64 channels)
    selected_channels = [
        'Fc5.', 'Fc3.', 'Fc1.', 'Fcz.', 'Fc2.', 'Fc4.', 'Fc6.', 
        'C5..', 'C3..', 'C1..', 'Cz..', 'C2..', 'C4..', 'C6..', 
        'Cp5.', 'Cp3.', 'Cp1.', 'Cpz.', 'Cp2.', 'Cp4.', 'Cp6.', 
        'Fp1.', 'Fpz.', 'Fp2.', 'Af7.', 'Af3.', 'Afz.', 'Af4.', 'Af8.', 
        'F7..', 'F5..', 'F3..', 'F1..', 'Fz..', 'F2..', 'F4..', 'F6..', 'F8..', 
        'Ft7.', 'Ft8.', 'T7..', 'T8..', 'T9..', 'T10.', 'Tp7.', 'Tp8.', 
        'P7..', 'P5..', 'P3..', 'P1..', 'Pz..', 'P2..', 'P4..', 'P6..', 'P8..', 
        'Po7.', 'Po3.', 'Poz.', 'Po4.', 'Po8.', 'O1..', 'Oz..', 'O2..', 'Iz..'
    ]
    
    # Initialize LMDB database
    db = lmdb.open(output_db_path, map_size=4614542346)
    dataset = {'train': [], 'val': [], 'test': []}
    
    # Process each split
    for split_name, file_list in files_dict.items():
        print(f"\n📊 Processing {split_name} set ({len(file_list)} subjects)...")
        
        for file in tqdm(file_list, desc=f"{split_name}"):
            for task in tasks:
                try:
                    # Load EEG file
                    file_path = os.path.join(root_dir, file, f'{file}R{task}.edf')
                    raw = mne.io.read_raw_edf(file_path, preload=True, verbose=False)
                    
                    # Preprocessing pipeline
                    raw.pick_channels(selected_channels, ordered=True)
                    if len(raw.info['bads']) > 0:
                        raw.interpolate_bads()
                    raw.set_eeg_reference(ref_channels='average')
                    raw.filter(l_freq=0.3, h_freq=None, verbose=False)
                    raw.notch_filter(60, verbose=False)
                    raw.resample(200, verbose=False)
                    
                    # Extract epochs
                    events_from_annot, event_dict = mne.events_from_annotations(raw, verbose=False)
                    epochs = mne.Epochs(raw, events_from_annot, event_dict, 
                                      tmin=0, tmax=4-1.0/raw.info['sfreq'], 
                                      baseline=None, preload=True, verbose=False)
                    
                    # Get data and reshape
                    data = epochs.get_data(units='uV')[:, :, -800:]  # Last 4 seconds at 200Hz
                    events = epochs.events[:, 2]
                    
                    # Reshape to (batch, channels, time_segments, samples_per_segment)
                    bz, ch_nums, _ = data.shape
                    data = data.reshape(bz, ch_nums, 4, 200)
                    
                    # Save to LMDB
                    for i, (sample, event) in enumerate(zip(data, events)):
                        if event != 1:  # Skip rest events
                            sample_key = f'{file}R{task}-{i}'
                            data_dict = {
                                'sample': sample,
                                'label': event - 2 if task in ['04', '08', '12'] else event
                            }
                            txn = db.begin(write=True)
                            txn.put(key=sample_key.encode(), value=pickle.dumps(data_dict))
                            txn.commit()
                            dataset[split_name].append(sample_key)
                            
                except Exception as e:
                    print(f"❌ Error processing {file}R{task}: {str(e)[:100]}...")
                    continue
    
    # Save dataset keys
    txn = db.begin(write=True)
    txn.put(key='__keys__'.encode(), value=pickle.dumps(dataset))
    txn.commit()
    db.close()
    
    print(f"\n✅ Preprocessing complete!")
    print(f"📊 Dataset saved with {sum(len(v) for v in dataset.values())} total samples")
    for split, samples in dataset.items():
        print(f"   {split}: {len(samples)} samples")
        
else:
    if not RUN_PREPROCESSING:
        print("⏸️ Preprocessing skipped")
    else:
        print("❌ Cannot run preprocessing - raw data path not found")

⏸️ Preprocessing skipped


In [4]:
# Verify preprocessed data exists
import os

# Check for processed data in Google Drive
data_path = "/home/mahmood/HosseinDahaei/Codes/EEGMamba/data/raw_motor_movement_Imagery/processed_average"
print(f"🔍 Checking for processed data at: {data_path}")

if os.path.exists(data_path):
    files = os.listdir(data_path)
    print(f"✅ Processed data directory found!")
    print(f"📁 Contents: {files}")
    
    # Check for LMDB database files
    if 'data.mdb' in files and 'lock.mdb' in files:
        print("✅ LMDB database files present - data is ready for training!")
        
        # Get database info
        import lmdb
        db = lmdb.open(data_path, readonly=True)
        with db.begin() as txn:
            try:
                keys_data = txn.get('__keys__'.encode())
                if keys_data:
                    dataset_keys = pickle.loads(keys_data)
                    print(f"📊 Dataset splits:")
                    for split, samples in dataset_keys.items():
                        print(f"   {split}: {len(samples)} samples")
                    total_samples = sum(len(v) for v in dataset_keys.values())
                    print(f"📈 Total samples: {total_samples}")
                else:
                    print("⚠️ Dataset keys not found in database")
            except:
                print("⚠️ Could not read dataset keys")
        db.close()
    else:
        print("❌ LMDB database files missing - preprocessing needed")
        print("💡 Set RUN_PREPROCESSING=True in the previous cell to create the database")
else:
    print("❌ Processed data directory not found")
    print("💡 Options:")
    print("   1. Set RUN_PREPROCESSING=True above to preprocess data")
    print("   2. Check if the path is correct")
    print("   3. Upload preprocessed data to Google Drive")

🔍 Checking for processed data at: /home/mahmood/HosseinDahaei/Codes/EEGMamba/data/raw_motor_movement_Imagery/processed_average
✅ Processed data directory found!
📁 Contents: ['lock.mdb', 'data.mdb']
✅ LMDB database files present - data is ready for training!
📊 Dataset splits:
   train: 6300 samples
   val: 1734 samples
   test: 1758 samples
📈 Total samples: 9792


## Step 4: Test Dataset Loading

Let's test the EEGMamba dataset loader to understand our data structure.

In [5]:
# Test the EEGMamba dataset loader
print("🧪 Testing EEGMamba dataset loader...")

try:
    from datasets.physio_dataset import LoadDataset
    
    # Create test parameters for dataset loading
    class TestParams:
        def __init__(self):
            self.datasets_dir = "/home/mahmood/HosseinDahaei/Codes/EEGMamba/data/raw_motor_movement_Imagery/processed_average"
            self.batch_size = 8
    
    # Initialize dataset loader
    test_params = TestParams()
    print(f"📂 Loading dataset from: {test_params.datasets_dir}")
    
    dataset_loader = LoadDataset(test_params)
    data_loaders = dataset_loader.get_data_loader()
    
    print(f"✅ Dataset loaded successfully!")
    print(f"📊 Available data splits: {list(data_loaders.keys())}")
    
    # Test data loading
    for split_name, data_loader in data_loaders.items():
        print(f"\n🔍 Examining {split_name.upper()} set:")
        print(f"   Number of batches: {len(data_loader)}")
        
        # Get first batch to examine data structure
        for batch_idx, (data, labels) in enumerate(data_loader):
            print(f"   ✅ Batch {batch_idx + 1} loaded successfully")
            print(f"   📊 Data shape: {data.shape}")
            print(f"   🎯 Labels shape: {labels.shape}")
            print(f"   📈 Data type: {data.dtype}")
            print(f"   📉 Data range: [{data.min():.3f}, {data.max():.3f}]")
            
            # Check labels
            unique_labels = torch.unique(labels).tolist()
            print(f"   🏷️ Unique labels: {unique_labels}")
            print(f"   📝 Label meanings: 0=left_hand, 1=right_hand, 2=feet, 3=tongue")
            break  # Only examine first batch
        break  # Only examine first split for testing
        
    print("\n🎉 Dataset loading test completed successfully!")
    
except Exception as e:
    print(f"❌ Error testing dataset loader: {e}")
    print("\n🔧 Troubleshooting:")
    print("1. Make sure preprocessed data exists")
    print("2. Check the datasets_dir path")
    print("3. Verify LMDB database integrity")

🧪 Testing EEGMamba dataset loader...
📂 Loading dataset from: /home/mahmood/HosseinDahaei/Codes/EEGMamba/data/raw_motor_movement_Imagery/processed_average
6300 1734 1758
9792
✅ Dataset loaded successfully!
📊 Available data splits: ['train', 'val', 'test']

🔍 Examining TRAIN set:
   Number of batches: 788
   ✅ Batch 1 loaded successfully
   📊 Data shape: torch.Size([8, 64, 4, 200])
   🎯 Labels shape: torch.Size([8])
   📈 Data type: torch.float32
   📉 Data range: [-3.316, 5.211]
   🏷️ Unique labels: [0, 1, 3]
   📝 Label meanings: 0=left_hand, 1=right_hand, 2=feet, 3=tongue

🎉 Dataset loading test completed successfully!


In [6]:
# Detailed data analysis
import torch
import matplotlib.pyplot as plt

print("📊 Detailed Data Analysis")

if 'data_loaders' in locals():
    for split_name, data_loader in data_loaders.items():
        print(f"\n--- {split_name.upper()} SET ANALYSIS ---")
        
        all_labels = []
        batch_count = 0
        sample_count = 0
        
        # Analyze several batches
        for batch_idx, (data, labels) in enumerate(data_loader):
            batch_count += 1
            sample_count += data.shape[0]
            all_labels.extend(labels.tolist())
            
            if batch_idx == 0:  # Detailed analysis of first batch
                print(f"📐 Data tensor shape: {data.shape}")
                print(f"   • Batch size: {data.shape[0]}")
                print(f"   • Channels: {data.shape[1]}")
                print(f"   • Time segments: {data.shape[2]}")
                print(f"   • Samples per segment: {data.shape[3]}")
                print(f"   • Total time points: {data.shape[2] * data.shape[3]}")
                print(f"   • Time duration: {data.shape[2] * data.shape[3] / 200:.1f} seconds (at 200 Hz)")
                
            # Don't load all data to save memory
            if batch_idx >= 5:  # Analyze first 5 batches only
                break
                
        # Label distribution
        unique_labels, counts = torch.unique(torch.tensor(all_labels), return_counts=True)
        print(f"\n🏷️ Label Distribution (first {batch_count} batches):")
        label_names = ['Left Hand', 'Right Hand', 'Feet', 'Tongue']
        for label, count in zip(unique_labels.tolist(), counts.tolist()):
            if label < len(label_names):
                print(f"   {label_names[label]} (class {label}): {count} samples")
        
        print(f"📊 Total samples analyzed: {len(all_labels)}")
        
        # Only analyze first split to save time
        break
        
    print("\n✅ Data analysis completed!")
else:
    print("❌ No data loaders available. Please run the previous cell first.")

📊 Detailed Data Analysis

--- TRAIN SET ANALYSIS ---
📐 Data tensor shape: torch.Size([8, 64, 4, 200])
   • Batch size: 8
   • Channels: 64
   • Time segments: 4
   • Samples per segment: 200
   • Total time points: 800
   • Time duration: 4.0 seconds (at 200 Hz)

🏷️ Label Distribution (first 6 batches):
   Left Hand (class 0): 13 samples
   Right Hand (class 1): 9 samples
   Feet (class 2): 16 samples
   Tongue (class 3): 10 samples
📊 Total samples analyzed: 48

✅ Data analysis completed!


## Step 5: Training Configuration

Now let's set up the training parameters for EEGMamba finetuning on PhysioNet-MI.

In [7]:
# EEGMamba Training Configuration
print("⚙️ EEGMamba Training Configuration")

# Dataset and paths
DATASET_NAME = "PhysioNet-MI"
DATASETS_DIR = "/home/mahmood/HosseinDahaei/Codes/EEGMamba/data/raw_motor_movement_Imagery/processed_average"
MODEL_DIR = "./results/physio_models"
PRETRAINED_WEIGHTS = "pretrained_weights/pretrained_weights.pth"

# Model parameters
NUM_CLASSES = 4  # Motor imagery classes: left hand, right hand, feet, tongue
CLASSIFIER_TYPE = "all_patch_reps"  # EEGMamba classifier type

# Training hyperparameters
EPOCHS = 5  # Start with quick test, increase later
BATCH_SIZE = 16  # Adjust based on GPU memory
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 5e-2
CUDA_DEVICE = 0

# Advanced options
OPTIMIZER = "AdamW"
DROPOUT = 0.1
LABEL_SMOOTHING = 0.1
USE_PRETRAINED = True

print("📋 Configuration Summary:")
print(f"   🎯 Dataset: {DATASET_NAME}")
print(f"   📂 Data path: {DATASETS_DIR}")
print(f"   🏷️ Classes: {NUM_CLASSES}")
print(f"   🏃 Epochs: {EPOCHS}")
print(f"   📦 Batch size: {BATCH_SIZE}")
print(f"   📈 Learning rate: {LEARNING_RATE}")
print(f"   💾 Model save dir: {MODEL_DIR}")
print(f"   🎭 Use pretrained: {USE_PRETRAINED}")

# Verify pretrained weights exist
if USE_PRETRAINED:
    if os.path.exists(PRETRAINED_WEIGHTS):
        print(f"   ✅ Pretrained weights found: {PRETRAINED_WEIGHTS}")
    else:
        print(f"   ⚠️ Pretrained weights not found: {PRETRAINED_WEIGHTS}")
        print("   💡 Will train from scratch or download weights")

# Create results directory
os.makedirs(MODEL_DIR, exist_ok=True)
print(f"   📁 Results directory ready: {MODEL_DIR}")

⚙️ EEGMamba Training Configuration
📋 Configuration Summary:
   🎯 Dataset: PhysioNet-MI
   📂 Data path: /home/mahmood/HosseinDahaei/Codes/EEGMamba/data/raw_motor_movement_Imagery/processed_average
   🏷️ Classes: 4
   🏃 Epochs: 5
   📦 Batch size: 16
   📈 Learning rate: 0.0001
   💾 Model save dir: ./results/physio_models
   🎭 Use pretrained: True
   ✅ Pretrained weights found: pretrained_weights/pretrained_weights.pth
   📁 Results directory ready: ./results/physio_models


In [8]:
# Quick environment check before training
print("🔍 Pre-training Environment Check")

# Check GPU availability
if torch.cuda.is_available():
    print(f"✅ CUDA available: {torch.cuda.get_device_name(0)}")
    print(f"   GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("⚠️ CUDA not available - training will be slow on CPU")

# Check key files
key_files = {
    'finetune_main.py': 'Main training script',
    'models/': 'Model definitions',
    'datasets/': 'Dataset loaders',
    'pretrained_weights/': 'Pretrained weights'
}

for file_path, description in key_files.items():
    if os.path.exists(file_path):
        print(f"✅ {description}: {file_path}")
    else:
        print(f"❌ Missing {description}: {file_path}")

# Check if we can import the model
try:
    from models.model_for_physio import Model
    print("✅ PhysioNet model import successful")
except Exception as e:
    print(f"❌ Model import failed: {e}")

print("\n🚀 Ready for training!")

🔍 Pre-training Environment Check
✅ CUDA available: NVIDIA RTX A4000
   GPU memory: 16.8 GB
✅ Main training script: finetune_main.py
✅ Model definitions: models/
✅ Dataset loaders: datasets/
✅ Pretrained weights: pretrained_weights/
✅ PhysioNet model import successful

🚀 Ready for training!


## Step 4: Quick Test Run (5 epochs)

Let's first do a quick test with fewer epochs to make sure everything works:

In [9]:
# Quick test command
test_command = f"""
python finetune_main.py \
    --downstream_dataset {DATASET_NAME} \
    --datasets_dir {DATASETS_DIR} \
    --num_of_classes {NUM_CLASSES} \
    --model_dir {MODEL_DIR}_test \
    --epochs 5 \
    --batch_size 16 \
    --lr {LEARNING_RATE} \
    --weight_decay {WEIGHT_DECAY} \
    --cuda {CUDA_DEVICE} \
    --use_pretrained_weights True \
    --foundation_dir {PRETRAINED_WEIGHTS}
"""

print("Quick Test Command:")
print(test_command.strip())
print("\n📝 Copy and run this command in terminal to test the setup!")

Quick Test Command:
python finetune_main.py     --downstream_dataset PhysioNet-MI     --datasets_dir /home/mahmood/HosseinDahaei/Codes/EEGMamba/data/raw_motor_movement_Imagery/processed_average     --num_of_classes 4     --model_dir ./results/physio_models_test     --epochs 5     --batch_size 16     --lr 0.0001     --weight_decay 0.05     --cuda 0     --use_pretrained_weights True     --foundation_dir pretrained_weights/pretrained_weights.pth

📝 Copy and run this command in terminal to test the setup!


In [None]:
# Execute quick test run (5 epochs)
print("🚀 Starting Quick Test Run (5 epochs)...")
print("This will test the setup with a short training run.")

import subprocess
import sys

# Build the test command
test_args = [
    sys.executable, "finetune_main.py",
    "--downstream_dataset", DATASET_NAME,
    "--datasets_dir", DATASETS_DIR,
    "--num_of_classes", str(NUM_CLASSES),
    "--model_dir", f"{MODEL_DIR}_test",
    "--epochs", "1",
    "--batch_size", "16", 
    "--lr", str(LEARNING_RATE),
    "--weight_decay", str(WEIGHT_DECAY),
    "--cuda", str(CUDA_DEVICE),
    "--use_pretrained_weights", "True",
    "--foundation_dir", PRETRAINED_WEIGHTS
]

print(f"📋 Command: {' '.join(test_args)}")
print("\n" + "="*50)

# Execute the command
try:
    result = subprocess.run(test_args, capture_output=True, text=True, timeout=3600)  # 60 min timeout
    
    print("STDOUT:")
    print(result.stdout)
    
    if result.stderr:
        print("\nSTDERR:")
        print(result.stderr)
    
    if result.returncode == 0:
        print("\n✅ Test run completed successfully!")
    else:
        print(f"\n❌ Test run failed with return code: {result.returncode}")
        
except subprocess.TimeoutExpired:
    print("⏱️ Command timed out after 30 minutes")
except Exception as e:
    print(f"❌ Error running command: {e}")
    
print("\n" + "="*50)

🚀 Starting Quick Test Run (5 epochs)...
This will test the setup with a short training run.
📋 Command: /opt/conda/envs/eegmamba3/bin/python finetune_main.py --downstream_dataset PhysioNet-MI --datasets_dir /home/mahmood/HosseinDahaei/Codes/EEGMamba/data/raw_motor_movement_Imagery/processed_average --num_of_classes 4 --model_dir ./results/physio_models_test --epochs 1 --batch_size 16 --lr 0.0001 --weight_decay 0.05 --cuda 0 --use_pretrained_weights True --foundation_dir pretrained_weights/pretrained_weights.pth



### 🚀 Execute Test Run

**Run this cell to execute the quick test training (5 epochs):**
- Tests that everything is set up correctly
- Takes ~10-15 minutes
- Uses smaller batch size for memory safety
- Results saved to `{MODEL_DIR}_test`

In [None]:
# Install dependencies for EEGMamba (optimized for Colab)
print("🔧 Setting up EEGMamba environment...")

# First, install basic dependencies
!pip install einops lmdb torch torchvision torchaudio
!pip install scipy scikit-learn matplotlib

# Install causal-conv1d first (required for mamba-ssm)
print("📦 Installing causal-conv1d...")
!pip install causal-conv1d --no-cache-dir

# Install mamba-ssm with specific flags for faster compilation
print("📦 Installing mamba-ssm (this may take a few minutes)...")
!pip install mamba-ssm --no-cache-dir

print("✅ All dependencies installed!")
print("🔄 Please restart runtime if running in Colab for changes to take effect.")

🔧 Setting up EEGMamba environment...
📦 Installing causal-conv1d...
Collecting causal-conv1d
  Downloading causal_conv1d-1.5.2.tar.gz (23 kB)
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting ninja (from causal-conv1d)
  Downloading ninja-1.13.0-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (5.1 kB)
Downloading ninja-1.13.0-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (180 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m180.7/180.7 kB[0m [31m10.3 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: causal-conv1d
  Building wheel for causal-conv1d (pyproject.toml) ... [?25l[?25hdone
  Created wheel for causal-conv1d: filename=causal_conv1d-1.5.2-cp312-cp312-linux_x86_64.whl size=151160839 sha256=7cc4ae241543f93acfa4d96ecbc43e532f3aeed171f88c859f873f1a617e9606
  Stored in directory: /t

## Step 5: Full Training Command

Once the test works, use this command for full training:

In [None]:
# Full training command
full_command = f"""
python finetune_main.py \
    --downstream_dataset {DATASET_NAME} \
    --datasets_dir {DATASETS_DIR} \
    --num_of_classes {NUM_CLASSES} \
    --model_dir {MODEL_DIR} \
    --epochs {EPOCHS} \
    --batch_size {BATCH_SIZE} \
    --lr {LEARNING_RATE} \
    --weight_decay {WEIGHT_DECAY} \
    --cuda {CUDA_DEVICE} \
    --use_pretrained_weights True \
    --foundation_dir {PRETRAINED_WEIGHTS} \
    --optimizer AdamW \
    --classifier all_patch_reps \
    --dropout 0.1 \
    --label_smoothing 0.1
"""

print("Full Training Command:")
print(full_command.strip())
print("\n🚀 This will run the complete training (note: it runs 100 iterations as per the code)")

In [None]:
# Execute full training run
print("🚀 Starting Full Training Run...")
print("⚠️ WARNING: This will run the complete training (note: default runs 100 iterations)")
print("💡 Consider modifying finetune_main.py to reduce iterations if needed")

import subprocess
import sys
import time

# Build the full training command  
full_args = [
    sys.executable, "finetune_main.py",
    "--downstream_dataset", DATASET_NAME,
    "--datasets_dir", DATASETS_DIR,
    "--num_of_classes", str(NUM_CLASSES),
    "--model_dir", MODEL_DIR,
    "--epochs", str(EPOCHS),
    "--batch_size", str(BATCH_SIZE),
    "--lr", str(LEARNING_RATE),
    "--weight_decay", str(WEIGHT_DECAY),
    "--cuda", str(CUDA_DEVICE),
    "--use_pretrained_weights", "True",
    "--foundation_dir", PRETRAINED_WEIGHTS,
    "--optimizer", "AdamW",
    "--classifier", "all_patch_reps",
    "--dropout", "0.1",
    "--label_smoothing", "0.1"
]

print(f"📋 Command: {' '.join(full_args)}")

# Ask for confirmation
confirm = input("\n⚠️ This is a long training run. Continue? (y/N): ")
if confirm.lower() != 'y':
    print("❌ Training cancelled by user")
else:
    print("\n" + "="*50)
    start_time = time.time()
    
    # Execute the command
    try:
        # Use a longer timeout for full training (6 hours)
        result = subprocess.run(full_args, capture_output=False, text=True, timeout=21600)
        
        end_time = time.time()
        duration = end_time - start_time
        
        print(f"\n⏱️ Training completed in {duration/3600:.2f} hours")
        
        if result.returncode == 0:
            print("✅ Full training completed successfully!")
            print(f"📁 Check results in: {MODEL_DIR}")
        else:
            print(f"❌ Training failed with return code: {result.returncode}")
            
    except subprocess.TimeoutExpired:
        print("⏱️ Training timed out after 6 hours")
    except KeyboardInterrupt:
        print("🛑 Training interrupted by user")
    except Exception as e:
        print(f"❌ Error running training: {e}")
    
    print("\n" + "="*50)

### 🏁 Execute Full Training

**Run this cell for complete training:**
- Full hyperparameter optimization
- Runs 100 iterations by default (as per original code)
- Takes several hours to complete
- Includes confirmation prompt before starting
- Results saved to `{MODEL_DIR}`

⚠️ **Important**: The original code runs 100 training iterations. Consider modifying `finetune_main.py` line `for i in range(1, 100):` to `range(1, 2)` if you only want one training run.

## Step 6: Alternative - Single Run Training

If you want to run just one training iteration instead of 100, you can modify the code:

In [None]:
# Show the current PhysioNet-MI loop in finetune_main.py
print("Current PhysioNet-MI section in finetune_main.py:")
print("""
elif params.downstream_dataset == 'PhysioNet-MI':
    for i in range(1, 100):  # <-- This runs 100 times!
        print('The {}th fold'.format(i))
        load_dataset = physio_dataset.LoadDataset(params)
        data_loader = load_dataset.get_data_loader()
        model = model_for_physio.Model(params)
        t = Trainer(params, data_loader, model)
        t.train_for_multiclass()
""")

print("\n💡 To run just once, you can:")
print("1. Change 'range(1, 100)' to 'range(1, 2)' in finetune_main.py")
print("2. Or comment out the for loop entirely")

## Step 7: Monitor Training Progress

During training, you can monitor progress with these commands:

In [None]:
# Commands to monitor training
print("Monitoring Commands:")
print("1. Watch GPU usage:")
print("   watch -n 1 nvidia-smi")
print("\n2. Monitor log files (if any):")
print("   tail -f training.log")
print("\n3. Check model directory:")
print(f"   ls -la {MODEL_DIR}/")
print("\n4. Monitor CPU/Memory:")
print("   htop")

## Step 8: Expected Training Time and Resources

In [None]:
# Estimate resources and time
print("Training Resource Requirements:")
print("💾 GPU Memory: ~8-12 GB (depending on batch size)")
print("⏱️  Time per epoch: ~2-5 minutes (depending on GPU)")
print(f"🕐 Total time estimate: ~{EPOCHS * 3} minutes for {EPOCHS} epochs")
print("🔄 Total runs: 100 iterations (as per original code)")
print(f"📊 Total estimated time: ~{EPOCHS * 3 * 100 / 60:.1f} hours for all 100 runs")

print("\n📈 Dataset Size:")
print("• Training samples: ~1400-1600")
print("• Validation samples: ~380-420")
print("• Test samples: ~400-450")
print("• Total: ~2200-2500 samples")

## Step 9: Troubleshooting Common Issues

In [None]:
print("Common Issues and Solutions:")
print("\n1. CUDA Out of Memory:")
print("   → Reduce batch_size from 32 to 16 or 8")
print("   → Use: --batch_size 16")

print("\n2. Dataset not found:")
print("   → Check if LMDB files exist in processed_average directory")
print("   → Run preprocessing if needed")

print("\n3. Pretrained weights not found:")
print("   → Check if pretrained_weights.pth exists")
print("   → Use: --use_pretrained_weights False (to train from scratch)")

print("\n4. Too slow training:")
print("   → Increase batch_size if GPU memory allows")
print("   → Reduce num_workers if CPU limited")
print("   → Use mixed precision training (if implemented)")

## Step 10: Next Steps After Training

In [None]:
print("After Training Completes:")
print("\n1. 📊 Analyze Results:")
print("   → Check accuracy/loss curves")
print("   → Compare with baseline results")
print("   → Look at confusion matrices")

print("\n2. 💾 Save Important Files:")
print("   → Best model checkpoints")
print("   → Training logs")
print("   → Configuration files")

print("\n3. 🔬 Further Experiments:")
print("   → Try different learning rates")
print("   → Experiment with classifier types")
print("   → Adjust data augmentation")
print("   → Compare with other baselines")

print("\n4. 📝 Document Results:")
print("   → Record best accuracy achieved")
print("   → Note optimal hyperparameters")
print("   → Save example predictions")