# Train Neural Speech Decoding on Google Colab

**Requirements:** Colab Pro (for sessions of up to 24 hours and a better GPU)

**Total Time:** ~16 hours on an A100 (~6 hours for Stage 1 + ~10 hours for Stage 2)

## Step 1: Check GPU

In [None]:
!nvidia-smi

import torch
print(f"\nPyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'None'}")

## Step 2: Clone Repository

In [None]:
# Clone the repo into /content
%cd /content
!git clone https://github.com/flinkerlab/neural_speech_decoding.git

# Enter the repo - this is our workspace
%cd neural_speech_decoding

!pwd

## Step 3: Install Dependencies

In [None]:
!pip install -r requirements.txt -q

## Step 4: Mount Google Drive & Setup Folders

In [None]:
from google.colab import drive
drive.mount('/content/drive')

# Create persistent storage in Google Drive
!mkdir -p /content/drive/MyDrive/nsd_data
!mkdir -p /content/drive/MyDrive/nsd_outputs

# Link them to the repo workspace
!mkdir -p example_data
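# -sfn replaces an existing link, so re-running this cell does not nest a second link inside the Drive folder
!ln -sfn /content/drive/MyDrive/nsd_data example_data/data
!ln -sfn /content/drive/MyDrive/nsd_outputs output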

print("✓ Google Drive mounted and linked")
print("  Data: example_data/data -> Google Drive")
print("  Output: output -> Google Drive")
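
If the links look wrong later (for example, training cannot find the data), a quick check like the one below can help. It is a minimal sketch, assuming the notebook's working directory is still the repo root; `describe_link` is a hypothetical helper written for this notebook, not part of the repository.

```python
import os

def describe_link(path):
    """Return (is_symlink, target_exists, resolved_path) for `path`."""
    resolved = os.path.realpath(path)
    return os.path.islink(path), os.path.exists(resolved), resolved

# The two links created in Step 4, relative to the repo root.
for link in ("example_data/data", "output"):
    is_link, target_exists, resolved = describe_link(link)
    status = "OK" if (is_link and target_exists) else "CHECK"
    print(f"[{status}] {link} -> {resolved}")
```

A `CHECK` line usually means either that Drive is not mounted or that the cell was run from a different directory.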

## Step 5: Upload Data to Google Drive

**Before running training, you need to:**

1. Download the HB02 dataset from https://data.mendeley.com/datasets/fp4bv9gtwk/2
2. Upload the files to **MyDrive/nsd_data/** in your Google Drive
3. Verify that the files are present by running the cell below

In [None]:
# Check if data is present
!ls -lh /content/drive/MyDrive/nsd_data/

# Should see HB02 data files (*.hdf5 or *.h5)
print("\nIf empty, please upload HB02 data to: MyDrive/nsd_data/ in Google Drive")
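
The `ls` output above has to be read by eye. As an alternative, the sketch below counts the HDF5 files from Python. It assumes the data files use the `.h5` or `.hdf5` extensions mentioned above and searches subfolders as well; `find_hdf5_files` is a hypothetical helper, not a function from the repository.

```python
from pathlib import Path

def find_hdf5_files(data_dir):
    """Return sorted .h5/.hdf5 files under data_dir (searched recursively)."""
    root = Path(data_dir)
    if not root.is_dir():
        return []
    return sorted(p for p in root.rglob("*")
                  if p.is_file() and p.suffix.lower() in {".h5", ".hdf5"})

files_found = find_hdf5_files("/content/drive/MyDrive/nsd_data")
print(f"Found {len(files_found)} HDF5 file(s)")
for p in files_found:
    # Sizes are reported in megabytes (10^6 bytes).
    print(f"  {p.name}: {p.stat().st_size / 1e6:.1f} MB")
```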

## Step 6: Update Config

In [None]:
import json

# Update data path in config
with open('configs/AllSubjectInfo.json', 'r') as f:
    config = json.load(f)

config['Shared']['RootPath'] = './example_data/data/'

with open('configs/AllSubjectInfo.json', 'w') as f:
    json.dump(config, f, indent=4)

print(f"✓ Config updated: RootPath = {config['Shared']['RootPath']}")
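
Because `RootPath` is a relative path, it only resolves when training is launched from the repo root. The sketch below reads the config back and reports whether the directory exists from the current working directory. `check_root_path` is a hypothetical helper; the only assumption about the config is the `Shared` / `RootPath` key used in the cell above.

```python
import json
import os

def check_root_path(config_path):
    """Read the config back and report whether RootPath is an existing directory."""
    with open(config_path) as f:
        cfg = json.load(f)
    root = cfg["Shared"]["RootPath"]
    # A relative RootPath is resolved against the current working directory.
    return root, os.path.isdir(root)

config_path = "configs/AllSubjectInfo.json"
if os.path.exists(config_path):
    root, ok = check_root_path(config_path)
    print(f"RootPath = {root!r} ({'directory exists' if ok else 'directory NOT found'})")
else:
    print(f"{config_path} not found; run this cell from the repo root")
```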

## Step 7: Stage 1 - Audio-to-Audio Training (a2a)

**Time:** ~6 hours on A100

In [None]:
!python train_a2a.py \
  --OUTPUT_DIR output/a2a/HB02 \
  --trainsubject HB02 \
  --testsubject HB02 \
  --param_file configs/a2a_production.yaml \
  --batch_size 16 \
  --reshape 1 \
  --DENSITY "HB" \
  --wavebased 1 \
  --n_filter_samples 80 \
  --n_fft 256 \
  --formant_supervision 1 \
  --intensity_thres -1 \
  --epoch_num 60

In [None]:
# Check Stage 1 completed
!ls output/a2a/HB02/*.pth | wc -l
print("Expected: 60 checkpoint files (model_epoch0.pth to model_epoch59.pth)")
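
A bare file count does not say which epochs are missing if the session disconnected partway through. The sketch below lists the gaps, assuming checkpoints follow the `model_epoch<N>.pth` naming described above; `saved_epochs` and `missing_epochs` are hypothetical helpers, not part of the training scripts.

```python
import re
from pathlib import Path

EPOCH_RE = re.compile(r"model_epoch(\d+)\.pth")

def saved_epochs(ckpt_dir):
    """Return the sorted epoch numbers that have a model_epoch<N>.pth file."""
    d = Path(ckpt_dir)
    if not d.is_dir():
        return []
    matches = (EPOCH_RE.fullmatch(p.name) for p in d.iterdir())
    return sorted(int(m.group(1)) for m in matches if m)

def missing_epochs(ckpt_dir, epoch_num):
    """Return the epochs in range(epoch_num) that have no checkpoint file."""
    return sorted(set(range(epoch_num)) - set(saved_epochs(ckpt_dir)))

missing = missing_epochs("output/a2a/HB02", epoch_num=60)
print("All 60 checkpoints present" if not missing else f"Missing epochs: {missing}")
```

The same call with `"output/e2a/resnet_HB02"` works for Stage 2, since it uses the same `--epoch_num 60`.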

## Step 8: Stage 2 - ECoG-to-Audio Training (e2a)

**Time:** ~10 hours on A100

**This produces the weights you need for phoneme classification!**

In [None]:
!python train_e2a.py \
  --OUTPUT_DIR output/e2a/resnet_HB02 \
  --trainsubject HB02 \
  --testsubject HB02 \
  --param_file configs/e2a_production.yaml \
  --batch_size 16 \
  --MAPPING_FROM_ECOG ECoGMapping_ResNet \
  --reshape 1 \
  --DENSITY "HB" \
  --wavebased 1 \
  --dynamicfiltershape 0 \
  --n_filter_samples 80 \
  --n_fft 256 \
  --formant_supervision 1 \
  --intensity_thres -1 \
  --epoch_num 60 \
  --pretrained_model_dir output/a2a/HB02 \
  --causal 0

In [None]:
# Check Stage 2 completed
!ls output/e2a/resnet_HB02/*.pth | wc -l
!ls -lh output/e2a/resnet_HB02/model_epoch59.pth

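import os

final_ckpt = "output/e2a/resnet_HB02/model_epoch59.pth"
if os.path.exists(final_ckpt):
    print("\n✓✓✓ TRAINING COMPLETE ✓✓✓")
    print(f"\nYour pretrained weights:\n  {final_ckpt}")
    print("\nAlso saved to Google Drive:\n  /content/drive/MyDrive/nsd_outputs/e2a/resnet_HB02/model_epoch59.pth")
else:
    print(f"\nNot found: {final_ckpt}. Stage 2 has not finished; check the training log above.")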

## Step 9: Download Weights (Optional)

In [None]:
from google.colab import files

# Uncomment to download the final checkpoint to your computer:
# files.download('output/e2a/resnet_HB02/model_epoch59.pth')

print("Weights are in Google Drive at: MyDrive/nsd_outputs/e2a/resnet_HB02/")
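
Checkpoint files are large, so a download or Drive sync can be cut short without an obvious error. One way to confirm that a copy is intact is to compare SHA-256 digests on both sides. The sketch below uses only the standard library; `sha256_of` is a hypothetical helper, and the checkpoint path assumes Stage 2 ran with the settings above.

```python
import hashlib

def sha256_of(path, chunk_size=1 << 20):
    """Return the SHA-256 hex digest of a file, read in 1 MiB chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # Read in chunks so a large checkpoint is never fully loaded into memory.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

# Uncomment after Stage 2 has finished:
# print(sha256_of("output/e2a/resnet_HB02/model_epoch59.pth"))
```

Run the same function on the downloaded file and check that the two digests match.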

## Next Steps: Use for Phoneme Classification

In your `ecog_decoder_finetune.ipynb`, set the checkpoint path:

```python
checkpoint_path = "output/e2a/resnet_HB02/model_epoch59.pth"
# Or from Google Drive:
# checkpoint_path = "/content/drive/MyDrive/nsd_outputs/e2a/resnet_HB02/model_epoch59.pth"
```
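
If the fine-tuning notebook may run in a fresh session where only Drive is available, the path can be picked automatically rather than by commenting lines in and out. This is a minimal sketch using the two locations listed above; `first_existing` is a hypothetical helper.

```python
import os

def first_existing(candidates):
    """Return the first path in `candidates` that exists, or None."""
    for path in candidates:
        if os.path.exists(path):
            return path
    return None

checkpoint_path = first_existing([
    "output/e2a/resnet_HB02/model_epoch59.pth",  # repo workspace
    "/content/drive/MyDrive/nsd_outputs/e2a/resnet_HB02/model_epoch59.pth",  # Google Drive
])
print(f"checkpoint_path = {checkpoint_path}")
```

A result of `None` means neither location is reachable, which usually indicates that Drive is not mounted or Stage 2 has not finished.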