# Simple Phoneme Classification with Frozen ECoG Encoder

1. Load the pretrained ECoG decoder
2. Freeze it and use it to extract the latent features `x_common`
3. Train a simple classifier head on top of `x_common`

## 1. Setup

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
from tqdm import tqdm

from networks import ECOG_DECODER
from simple_classifier import SimplePhonemeClassifier

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Device: {device}")

## 2. Load Pretrained ECoG Decoder

In [None]:
# Initialize ecog decoder (same config as training)
ecog_decoder = ECOG_DECODER["ECoGMapping_RNN"](
    n_mels=128,
    n_formants=6,
    n_formants_noise=1,
    network_db=False,
    causal=False,
    anticausal=False,
)

# Load pretrained weights
checkpoint_path = "output/e2a/YOUR_MODEL/model_epoch99.pth"  # UPDATE THIS
# NOTE(review): torch>=2.6 defaults torch.load to weights_only=True. If this
# checkpoint stores non-tensor objects and fails to load, pass
# weights_only=False -- but only for checkpoints you trust.
checkpoint = torch.load(checkpoint_path, map_location='cpu')

# This contains ALL weights before x_common: base_model.* + motion_projection.*
ecog_decoder_weights = checkpoint['models']['ecog_decoder']  # <-- Weights for ecog -> x_common

# Strip a leading 'module.' prefix (typically present when the model was saved
# while wrapped in DataParallel/DistributedDataParallel). Only the prefix is
# removed; str.replace would also rewrite 'module.' occurring inside a key.
prefix = 'module.'
state_dict = {
    (k[len(prefix):] if k.startswith(prefix) else k): v
    for k, v in ecog_decoder_weights.items()
}

# strict=False tolerates key mismatches, so report them explicitly instead of
# silently running with randomly initialized layers.
load_result = ecog_decoder.load_state_dict(state_dict, strict=False)
model_keys = set(ecog_decoder.state_dict().keys())
if model_keys and set(load_result.missing_keys) == model_keys:
    raise RuntimeError(
        f"No weights from {checkpoint_path} matched the ECoG decoder; "
        "check the checkpoint path and the decoder config."
    )
if load_result.missing_keys:
    print(f"⚠ Missing keys (left at initialization): {load_result.missing_keys}")
if load_result.unexpected_keys:
    print(f"⚠ Unexpected keys (ignored): {load_result.unexpected_keys}")
ecog_decoder.to(device)

# FREEZE everything before x_common (base_model + motion_projection)
ecog_decoder.requires_grad_(False)

# eval() disables training-only behavior (e.g. dropout) in the frozen encoder.
ecog_decoder.eval()
print("✓ Loaded and froze pretrained ECoG decoder")

## 3. Create Classifier Head

In [None]:
NUM_PHONEMES = 9  # Number of phoneme classes -- change to match your labels

# Small trainable head that consumes the frozen decoder's 32-channel x_common.
classifier = SimplePhonemeClassifier(
    num_classes=NUM_PHONEMES,
    input_channels=32,  # channel count of x_common
    dropout=0.3,
    hidden_dim=128,
)
classifier = classifier.to(device)

num_params = sum(param.numel() for param in classifier.parameters())
print(f"✓ Created classifier: {num_params} params")

## 4. Prepare Your Data

Replace the example below with your actual data loading.

In [None]:
# TODO: Load your data here
# Your data format: [trials, electrode_x, electrode_y, time]
# Model expects: [batch, time, electrode_x*electrode_y]

class PhonemeDataset(Dataset):
    """Map-style dataset of (ECoG trial, integer label) pairs.

    Each trial is stored as ``[electrode_x, electrode_y, time]`` and is returned
    flattened to a float32 tensor of shape ``[time, electrode_x * electrode_y]``.
    """

    def __init__(self, ecog_list, labels):
        """
        Args:
            ecog_list: indexable sequence of ECoG trials (NumPy arrays or
                tensors), each [electrode_x, electrode_y, time]
            labels: indexable sequence of integer class labels, one per trial

        Raises:
            ValueError: if ``ecog_list`` and ``labels`` differ in length.
        """
        if len(ecog_list) != len(labels):
            raise ValueError(
                f"ecog_list has {len(ecog_list)} trials but labels has "
                f"{len(labels)} entries"
            )
        self.ecog_list = ecog_list
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``{'ecog': float32 [time, channels], 'label': int64 [1]}``.

        Raises:
            ValueError: if the stored trial is not 3-dimensional.
        """
        # torch.as_tensor accepts NumPy arrays and tensors of any numeric dtype
        # and casts to float32 (the legacy torch.FloatTensor constructor is
        # discouraged and is stricter about tensor inputs).
        ecog = torch.as_tensor(self.ecog_list[idx], dtype=torch.float32)
        if ecog.dim() != 3:
            raise ValueError(
                "expected ECoG trial of shape [electrode_x, electrode_y, time], "
                f"got {tuple(ecog.shape)}"
            )

        # [electrode_x, electrode_y, time] -> [time, electrode_x * electrode_y].
        # clone() yields a contiguous copy so the sample never aliases the
        # caller's array (as_tensor can share memory with float32 inputs).
        elec_x, elec_y, time_steps = ecog.shape
        ecog_flat = ecog.reshape(elec_x * elec_y, time_steps).T.clone(
            memory_format=torch.contiguous_format
        )

        return {
            'ecog': ecog_flat,
            # Shape [1] is kept as before; default collation batches it to [B, 1].
            'label': torch.tensor([int(self.labels[idx])], dtype=torch.long),
        }

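# Example: If your data is [trials, electrode_x, electrode_y, time]
# train_ecog_raw = np.load('your_data.npy')  # Shape: [n_trials, 8, 8, 128]
# train_labels = np.load('your_labels.npy')  # Shape: [n_trials]
# val_ecog_raw = np.load('your_val_data.npy')  # Shape: [n_val_trials, 8, 8, 128]
# val_labels = np.load('your_val_labels.npy')  # Shape: [n_val_trials]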

# Convert to list format for dataset
# train_ecog = [train_ecog_raw[i] for i in range(len(train_ecog_raw))]
# val_ecog = [val_ecog_raw[i] for i in range(len(val_ecog_raw))]

# train_dataset = PhonemeDataset(train_ecog, train_labels)
# val_dataset = PhonemeDataset(val_ecog, val_labels)

# train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
# val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

## 5. Training Setup

In [None]:
NUM_EPOCHS = 15

# Only the classifier head's parameters are handed to the optimizer.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    classifier.parameters(),
    lr=1e-3,
    weight_decay=1e-4,
)
# Halve the learning rate once the monitored loss ('min' mode) has not
# improved for more than 5 epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=5
)

## 6. Training Loop

In [None]:
def train_epoch(ecog_decoder, classifier, loader, optimizer, criterion, device):
    """Train the classifier head for one epoch on top of the frozen encoder.

    Only ``classifier`` is put in train mode and updated; ``ecog_decoder`` is
    kept in eval mode and called under ``torch.no_grad()`` so it is unchanged.

    Args:
        ecog_decoder: frozen encoder, called as
            ``ecog_decoder(ecog, return_latent=True)``.
        classifier: trainable head producing logits ``[B, num_classes]``.
        loader: iterable of dicts with ``'ecog'`` and ``'label'`` tensors.
        optimizer: optimizer over the classifier's parameters.
        criterion: loss taking ``(logits, labels)``; assumed to use mean
            reduction (the ``nn.CrossEntropyLoss`` default).
        device: device each batch is moved to.

    Returns:
        ``(mean loss per sample, accuracy in percent)`` over the epoch.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    classifier.train()
    # Keep the frozen encoder deterministic even if a caller left it in train mode.
    ecog_decoder.eval()
    total_loss = 0.0
    correct = 0
    total = 0

    for batch in tqdm(loader, desc="Training"):
        ecog = batch['ecog'].to(device)  # [B, T, C]
        # reshape(-1) instead of squeeze(): squeeze() collapses a size-1 batch
        # ([1, 1]) to a 0-d tensor, which breaks the loss and labels.size(0).
        labels = batch['label'].reshape(-1).to(device)  # [B]

        # Get x_common from frozen encoder
        with torch.no_grad():
            x_common = ecog_decoder(ecog, return_latent=True)

        # Classify
        logits = classifier(x_common)  # [B, num_classes]
        loss = criterion(logits, labels)

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Stats: weight the (mean) batch loss by batch size so a smaller final
        # batch does not skew the epoch average.
        batch_size = labels.size(0)
        total_loss += loss.item() * batch_size
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += batch_size

    if total == 0:
        raise ValueError("train_epoch: loader yielded no samples")
    return total_loss / total, 100 * correct / total


def validate(ecog_decoder, classifier, loader, criterion, device):
    """Evaluate the classifier head on ``loader`` without updating anything.

    Args:
        ecog_decoder: frozen encoder, called as
            ``ecog_decoder(ecog, return_latent=True)``.
        classifier: head producing logits ``[B, num_classes]``.
        loader: iterable of dicts with ``'ecog'`` and ``'label'`` tensors.
        criterion: loss taking ``(logits, labels)``; assumed to use mean
            reduction (the ``nn.CrossEntropyLoss`` default).
        device: device each batch is moved to.

    Returns:
        ``(mean loss per sample, accuracy in percent)`` over ``loader``.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    classifier.eval()
    ecog_decoder.eval()
    total_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in tqdm(loader, desc="Validation"):
            ecog = batch['ecog'].to(device)
            # reshape(-1) keeps a size-1 batch 1-D (squeeze() would make it 0-d).
            labels = batch['label'].reshape(-1).to(device)

            # Get x_common
            x_common = ecog_decoder(ecog, return_latent=True)

            # Classify
            logits = classifier(x_common)
            loss = criterion(logits, labels)

            # Stats (sample-weighted so a smaller final batch is not over-counted)
            batch_size = labels.size(0)
            total_loss += loss.item() * batch_size
            correct += (logits.argmax(dim=1) == labels).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError("validate: loader yielded no samples")
    return total_loss / total, 100 * correct / total


# Training loop
# Start below any achievable accuracy so the first epoch always writes a
# checkpoint. Starting at 0 with a strict '>' meant a run that never beat 0%
# saved nothing, and loading 'best_classifier.pth' afterwards would fail.
best_val_acc = float('-inf')

for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")

    train_loss, train_acc = train_epoch(ecog_decoder, classifier, train_loader, optimizer, criterion, device)
    val_loss, val_acc = validate(ecog_decoder, classifier, val_loader, criterion, device)

    # ReduceLROnPlateau monitors the validation loss.
    scheduler.step(val_loss)

    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
    print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")

    # Save best (only the classifier head; the decoder is frozen and unchanged)
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(classifier.state_dict(), 'best_classifier.pth')
        print(f"✓ Saved best model (acc: {val_acc:.2f}%)")

print(f"\nBest Val Acc: {best_val_acc:.2f}%")

## 7. Test Inference

In [None]:
# Load best model (map_location lets a checkpoint saved on GPU load on CPU too)
classifier.load_state_dict(torch.load('best_classifier.pth', map_location=device))
classifier.eval()

# Get a test sample
sample = val_dataset[0]
# PhonemeDataset returns the input under 'ecog' (there is no 'X' key).
ecog = sample['ecog'].unsqueeze(0).to(device)  # Add batch dim -> [1, T, C]
true_label = sample['label'].item()

with torch.no_grad():
    x_common = ecog_decoder(ecog, return_latent=True)
    logits = classifier(x_common)
    pred = torch.argmax(logits, dim=1).item()

print(f"Input shape: {ecog.shape}")
print(f"x_common shape: {x_common.shape}")
print(f"True label: {true_label}")
print(f"Predicted: {pred}")
print(f"Correct: {pred == true_label}")