In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
import gc

# Seed both PyTorch and NumPy so repeated runs draw the same random numbers
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Prefer the GPU when one is present; otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# Memory-efficient implementation
class EMGEEGDataset(Dataset):
    """Paired EMG/EEG trials keyed by ``(subject, repetition, gesture)``.

    Both frames are grouped by the three key columns. Every key that occurs
    in both frames (and has at least ``MIN_ROWS`` rows in each) becomes one
    sample: the first ``NUM_CHANNELS`` columns of each group are cast to
    float32 and truncated or zero-padded to ``window_size`` rows. The label
    of a sample is ``gesture - 1``.

    Samples are stored in sorted key order, so the dataset layout is the
    same from run to run.

    Args:
        emg_data: DataFrame whose first eight columns hold channel values
            and which has ``subject``, ``repetition`` and ``gesture`` columns.
        eeg_data: DataFrame with the same layout as ``emg_data``.
        transform: Optional callable applied to each modality's array in
            ``__getitem__`` before conversion to a tensor.
        window_size: Number of rows every sample is fitted to.
    """

    KEY_COLUMNS = ['subject', 'repetition', 'gesture']
    NUM_CHANNELS = 8
    MIN_ROWS = 10  # trials shorter than this in either modality are dropped

    def __init__(self, emg_data, eeg_data, transform=None, window_size=1000):
        print("Creating EMGEEGDataset...")
        self.window_size = window_size
        self.transform = transform

        # A single groupby pass per frame replaces one full boolean-mask scan
        # of every row for every key.
        emg_groups = self._group_by_key(emg_data)
        eeg_groups = self._group_by_key(eeg_data)

        common_keys = emg_groups.keys() & eeg_groups.keys()
        print(f"Found {len(common_keys)} common sample keys")

        self.emg_samples = []
        self.eeg_samples = []
        labels = []

        # Iterating the (sorted) EMG groups keeps the sample order deterministic.
        for key, emg_rows in emg_groups.items():
            if key not in common_keys:
                continue
            eeg_rows = eeg_groups[key]

            # Skip trials with too little data in either modality
            if len(emg_rows) < self.MIN_ROWS or len(eeg_rows) < self.MIN_ROWS:
                continue

            self.emg_samples.append(self._fit_to_window(self._channels(emg_rows)))
            self.eeg_samples.append(self._fit_to_window(self._channels(eeg_rows)))

            gesture = key[2]
            labels.append(gesture - 1)  # gestures are 1-based in the data

            if len(labels) % 100 == 0:
                print(f"Processed {len(labels)} samples")

        print(f"Successfully created dataset with {len(labels)} samples")

        # Samples stay in Python lists; only the labels become an array
        self.labels = np.array(labels)

    @classmethod
    def _group_by_key(cls, frame):
        """Return ``{key tuple: sub-frame}`` for every key in ``frame``."""
        return {key: rows for key, rows in frame.groupby(cls.KEY_COLUMNS, sort=True)}

    @classmethod
    def _channels(cls, rows):
        """Extract the leading channel columns of ``rows`` as a float32 array."""
        return rows.iloc[:, :cls.NUM_CHANNELS].to_numpy(dtype=np.float32)

    def _fit_to_window(self, features):
        """Truncate or zero-pad ``features`` to exactly ``window_size`` rows."""
        n_rows = len(features)
        if n_rows >= self.window_size:
            return features[:self.window_size]
        padding = np.zeros((self.window_size - n_rows, features.shape[1]), dtype=np.float32)
        return np.vstack([features, padding])

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(emg, eeg, label)`` tensors for sample ``idx``."""
        emg_sample = self.emg_samples[idx]
        eeg_sample = self.eeg_samples[idx]
        label = self.labels[idx]

        # The same transform is applied to both modalities
        if self.transform:
            emg_sample = self.transform(emg_sample)
            eeg_sample = self.transform(eeg_sample)

        emg_tensor = torch.tensor(emg_sample, dtype=torch.float)
        eeg_tensor = torch.tensor(eeg_sample, dtype=torch.float)
        label_tensor = torch.tensor(label, dtype=torch.long)

        return emg_tensor, eeg_tensor, label_tensor

# Memory-efficient data loading function
def load_and_preprocess_data(emg_path, eeg_path):
    """Read the EMG and EEG CSV files and standardize their channel columns.

    The first eight columns of each file are treated as channel data: they
    are cast to float64 and then scaled with a separate ``StandardScaler``
    per file. A short summary of each frame is printed.

    Args:
        emg_path: Path of the EMG CSV file.
        eeg_path: Path of the EEG CSV file.

    Returns:
        Tuple ``(emg_data, eeg_data)`` of the processed DataFrames.
    """
    print("Loading data...")
    emg_data, eeg_data = pd.read_csv(emg_path), pd.read_csv(eeg_path)

    print("Converting data types...")
    # Cast to float64 up front so the scaled values can be written back
    # without triggering a dtype FutureWarning
    for frame in (emg_data, eeg_data):
        for channel in frame.columns[:8]:
            frame[channel] = frame[channel].astype('float64')

    print("Normalizing data...")
    # NOTE: each scaler is fit on every row of its file, i.e. before any
    # train/validation/test split is made.
    for frame in (emg_data, eeg_data):
        frame.iloc[:, :8] = StandardScaler().fit_transform(frame.iloc[:, :8])

    # Summarize both frames: size, first rows, column dtypes
    summaries = (
        ("🟢 EMG Sample Info:", "📊 EMG Data Types:", emg_data),
        ("\n🟣 EEG Sample Info:", "📊 EEG Data Types:", eeg_data),
    )
    for info_banner, dtype_banner, frame in summaries:
        print(info_banner)
        print(f"📏 Length: {len(frame)}, 🛠️ Shape: {frame.shape}")
        print(frame.head())
        print(dtype_banner)
        print(frame.dtypes)

    return emg_data, eeg_data

# CNN Model
class MultimodalCNN(nn.Module):
    """Two-branch 1-D CNN that fuses EMG and EEG windows for classification.

    Each modality passes through its own stack of three
    Conv1d -> ReLU -> BatchNorm1d -> MaxPool1d stages followed by dropout.
    The flattened branch outputs are concatenated and fed to a three-layer
    fully connected classifier.

    Args:
        num_classes: Size of the output layer.
        window_size: Sequence length of the inputs; it fixes the width of
            the first linear layer.
    """

    def __init__(self, num_classes=7, window_size=1000):
        super().__init__()

        # EMG branch is created before the EEG branch, then the classifier
        self.emg_conv = self._conv_branch()
        self.eeg_conv = self._conv_branch()

        # Three MaxPool1d(2) stages shrink the sequence to window_size // 8,
        # with 128 channels coming out of the last convolution
        self.feature_size = 128 * (window_size // 8)
        fused_width = 2 * self.feature_size  # EMG and EEG features side by side

        self.classifier = nn.Sequential(
            nn.Linear(fused_width, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes),
        )

    @staticmethod
    def _conv_branch():
        """Build one 8 -> 32 -> 64 -> 128 channel convolutional branch."""
        stages = []
        in_channels = 8
        for out_channels in (32, 64, 128):
            stages.extend([
                nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.BatchNorm1d(out_channels),
                nn.MaxPool1d(2),
            ])
            in_channels = out_channels
        stages.append(nn.Dropout(0.3))
        return nn.Sequential(*stages)

    def forward(self, emg, eeg):
        """Return class logits for a batch of EMG and EEG windows.

        Both inputs are permuted from (batch, sequence, channels) to the
        (batch, channels, sequence) layout that Conv1d expects.
        """
        emg_maps = self.emg_conv(emg.permute(0, 2, 1))
        eeg_maps = self.eeg_conv(eeg.permute(0, 2, 1))

        # Flatten each branch to (batch, features) and join them
        fused = torch.cat(
            (torch.flatten(emg_maps, start_dim=1), torch.flatten(eeg_maps, start_dim=1)),
            dim=1,
        )
        return self.classifier(fused)

# Training and evaluation function
def _run_epoch(model, loader, criterion, device, desc, optimizer=None):
    """Run one pass over ``loader`` and return ``(mean batch loss, accuracy)``.

    The model is trained when ``optimizer`` is given and evaluated (no
    gradients) otherwise.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    training = optimizer is not None
    model.train(training)

    running_loss = 0.0
    correct = 0
    total = 0

    with torch.set_grad_enabled(training):
        for emg_batch, eeg_batch, labels in tqdm(loader, desc=desc):
            emg_batch = emg_batch.to(device)
            eeg_batch = eeg_batch.to(device)
            labels = labels.to(device)

            if training:
                optimizer.zero_grad()

            outputs = model(emg_batch, eeg_batch)
            loss = criterion(outputs, labels)

            if training:
                loss.backward()
                optimizer.step()

            running_loss += loss.item()
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)

    if total == 0:
        raise ValueError(f"{desc}: the data loader produced no samples")

    return running_loss / len(loader), correct / total


def _plot_training_history(train_losses, val_losses, train_accs, val_accs):
    """Plot loss and accuracy curves, save them to 'training_history.png' and show them."""
    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Val Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(train_accs, label='Train Acc')
    plt.plot(val_accs, label='Val Acc')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()

    plt.tight_layout()
    plt.savefig('training_history.png')
    plt.show()


def train_and_evaluate(model, train_loader, val_loader, criterion, optimizer, num_epochs=50, patience=10):
    """Train ``model`` with early stopping and return it with its best weights.

    After every epoch the validation loss is compared with the best seen so
    far. An improvement saves the weights to 'best_model.pth' and keeps a
    copy in memory; ``patience`` epochs without improvement stop training.
    Before returning, the best weights are loaded back into ``model`` so the
    caller does not receive the weights of the last (worse) epoch.

    Args:
        model: Module called as ``model(emg, eeg)``.
        train_loader: Loader yielding ``(emg, eeg, labels)`` training batches.
        val_loader: Loader yielding ``(emg, eeg, labels)`` validation batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer over the model's parameters.
        num_epochs: Maximum number of epochs.
        patience: Epochs without validation improvement before stopping.

    Returns:
        The same ``model`` object, carrying the best validation weights.

    Raises:
        ValueError: If either loader yields no samples.
    """
    # Send batches wherever the model lives rather than relying on a global
    device = next((p.device for p in model.parameters()), torch.device('cpu'))

    best_val_loss = float('inf')
    best_state = None
    patience_counter = 0
    train_losses, val_losses = [], []
    train_accs, val_accs = [], []

    for epoch in range(num_epochs):
        train_loss, train_acc = _run_epoch(
            model, train_loader, criterion, device,
            desc=f"Epoch {epoch+1}/{num_epochs} - Training",
            optimizer=optimizer,
        )
        train_losses.append(train_loss)
        train_accs.append(train_acc)

        val_loss, val_acc = _run_epoch(
            model, val_loader, criterion, device,
            desc=f"Epoch {epoch+1}/{num_epochs} - Validation",
        )
        val_losses.append(val_loss)
        val_accs.append(val_acc)

        print(f"Epoch [{epoch+1}/{num_epochs}] | Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            torch.save(model.state_dict(), 'best_model.pth')
            # state_dict() returns live references, so clone to take a snapshot
            best_state = {
                name: tensor.detach().cpu().clone()
                for name, tensor in model.state_dict().items()
            }
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

        # Release cached memory once per epoch; doing it after every batch
        # forces the CUDA allocator to re-request memory and slows training.
        gc.collect()
        torch.cuda.empty_cache()

    _plot_training_history(train_losses, val_losses, train_accs, val_accs)

    # Hand back the best checkpoint, not the weights of the final epoch
    if best_state is not None:
        model.load_state_dict(best_state)

    return model

# Main function
def main():
    """Run the pipeline: load data, build the dataset, train, and report test accuracy.

    Raises:
        ValueError: If no paired EMG/EEG samples could be built.
    """
    # Paths to EMG and EEG data files
    emg_path = 'data/processed/EMG-data.csv'
    eeg_path = 'data/processed/EEG-data.csv'

    # Load and preprocess data (dtype conversion and scaling happen here)
    emg_data, eeg_data = load_and_preprocess_data(emg_path, eeg_path)

    # Create dataset
    dataset = EMGEEGDataset(emg_data, eeg_data)
    if len(dataset) == 0:
        raise ValueError("No usable paired EMG/EEG samples were found")

    # Split dataset into train, validation, and test sets (70% / 15% / rest)
    train_size = int(0.7 * len(dataset))
    val_size = int(0.15 * len(dataset))
    test_size = len(dataset) - train_size - val_size

    train_dataset, val_dataset, test_dataset = random_split(
        dataset, [train_size, val_size, test_size]
    )

    print(f"Dataset splits: Train {train_size}, Validation {val_size}, Test {test_size}")

    # All samples already sit in memory, so worker processes only add
    # start-up and pickling cost (and can stall when the dataset class is
    # defined in a notebook). Pinned memory only helps when copying to a GPU.
    batch_size = 16  # Small batch size to save memory
    loader_kwargs = {
        'batch_size': batch_size,
        'num_workers': 0,
        'pin_memory': device.type == 'cuda',
    }
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    # Size the output layer from the largest label: counting unique labels
    # under-sizes it whenever some gesture id is missing from the dataset.
    num_classes = int(dataset.labels.max()) + 1
    print(f"Number of classes: {num_classes}")

    # Take the window size from the dataset so the two cannot drift apart
    window_size = dataset.window_size
    model = MultimodalCNN(num_classes=num_classes, window_size=window_size).to(device)

    # Loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)

    # Train model
    model = train_and_evaluate(model, train_loader, val_loader, criterion, optimizer, num_epochs=50, patience=10)

    # Evaluate on test set
    model.eval()
    test_correct = 0
    test_total = 0

    with torch.no_grad():
        for emg_batch, eeg_batch, labels in tqdm(test_loader, desc="Testing"):
            emg_batch = emg_batch.to(device)
            eeg_batch = eeg_batch.to(device)
            labels = labels.to(device)

            # Forward pass
            predicted = model(emg_batch, eeg_batch).argmax(dim=1)
            test_total += labels.size(0)
            test_correct += (predicted == labels).sum().item()

    test_acc = test_correct / test_total
    print(f"Test Accuracy: {test_acc:.4f}")

# Run the pipeline only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()

Using device: cuda
Loading data...
Converting data types...
Normalizing data...
🟢 EMG Sample Info:
📏 Length: 664666, 🛠️ Shape: (664666, 11)
   Channel_1  Channel_2  Channel_3  Channel_4  Channel_5  Channel_6  \
0  -0.144347  -0.463070  -1.254146  -0.539958   0.018249   0.174389   
1   0.192552   0.661711   0.213725   0.350381   0.099978   0.106335   
2   1.708598   1.395264   0.847578   0.680136  -1.820643  -0.846416   
3  -1.772693  -1.294429  -0.253325   1.504524   2.633563   0.991033   
4   2.270097  -1.685657  -1.621114   0.779063   0.386028   0.106335   

   Channel_7  Channel_8  subject  repetition  gesture  
0  -0.026852   0.020317        1           1        1  
1  -0.723067  -0.825154        1           1        1  
2  -1.013157   0.999284        1           1        1  
3   1.191525   0.732293        1           1        1  
4  -0.897121  -0.068680        1           1        1  
📊 EMG Data Types:
Channel_1     float64
Channel_2     float64
Channel_3     float64
Channel_4    

Epoch 1/50 - Training:   0%|          | 0/17 [00:00<?, ?it/s]