In [2]:
import torch

# Check if CUDA (GPU) is available
print(f"CUDA is available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    # Get the name of the GPU device
    print(f"GPU Device: {torch.cuda.get_device_name(0)}")
    
    # Get the number of available GPUs
    print(f"Number of GPUs: {torch.cuda.device_count()}")
    
    # Get current GPU device index
    print(f"Current GPU device index: {torch.cuda.current_device()}")

CUDA is available: True
GPU Device: NVIDIA GeForce RTX 4070 SUPER
Number of GPUs: 1
Current GPU device index: 0


In [4]:
import mne
import numpy as np
import os

# Set the input and output directories
gdf_dir = './gdf_data'
npy_dir = './npy_data'

# Create output directory if it doesn't exist
os.makedirs(npy_dir, exist_ok=True)

# Process all .gdf files
for filename in os.listdir(gdf_dir):
    if filename.endswith('.gdf'):
        # Construct full file paths
        gdf_path = os.path.join(gdf_dir, filename)
        npy_path = os.path.join(npy_dir, filename.replace('.gdf', '.npy'))
        
        # Read the GDF file using MNE
        raw = mne.io.read_raw_gdf(gdf_path, preload=True)
        
        # Get the data as a numpy array
        data = raw.get_data()
        
        # Save as .npy file
        np.save(npy_path, data)
        print(f"Converted {filename} to NPY format")

Extracting EDF parameters from /home/sahil/work/sargam/gdf_data/k6b.gdf...
GDF file detected
Setting channel info structure...
Could not determine channel type of the following channels, they will be set as EEG:
#  1, #  2, #  3, #  4, #  5, #  6, #  7, #  8, #  9, # 10, # 11, # 12, # 13, # 14, # 15, # 16, # 17, # 18, # 19, # 20, # 21, # 22, # 23, # 24, # 25, # 26, # 27, # 28, # 29, # 30, # 31, # 32, # 33, # 34, # 35, # 36, # 37, # 38, # 39, # 40, # 41, # 42, # 43, # 44, # 45, # 46, # 47, # 48, # 49, # 50, # 51, # 52, # 53, # 54, # 55, # 56, # 57, # 58, # 59, # 60
Creating raw.info structure...
Reading 0 ... 631199  =      0.000 ...  2524.796 secs...
Converted k6b.gdf to NPY format
Extracting EDF parameters from /home/sahil/work/sargam/gdf_data/k3b.gdf...
GDF file detected
Setting channel info structure...
Could not determine channel type of the following channels, they will be set as EEG:
#  1, #  2, #  3, #  4, #  5, #  6, #  7, #  8, #  9, # 10, # 11, # 12, # 13, # 14, # 15, # 16, #

In [12]:
# Create directories for each class if they don't exist
class_dirs = []
for i in range(1, 5):
    class_dir = os.path.join(npy_dir, f'class_{i}')
    class_dirs.append(class_dir)
    os.makedirs(class_dir, exist_ok=True)

# Process each GDF file
for filename in os.listdir(gdf_dir):
    if filename.endswith('.gdf'):
        gdf_path = os.path.join(gdf_dir, filename)
        raw = mne.io.read_raw_gdf(gdf_path, preload=True)
        
        # Get sampling frequency and calculate number of samples for 2.5s
        sfreq = raw.info['sfreq']
        n_samples = int(2.5 * sfreq)
        
        # Get the data and events
        data = raw.get_data()
        events, event_dict = mne.events_from_annotations(raw)
        
        # Debug print
        print(f"\nProcessing file: {filename}")
        print(f"Event dictionary: {event_dict}")
        print(f"Number of events: {len(events)}")
        
        # Process each event
        for event_idx, event in enumerate(events):
            event_time = event[0]  # Sample index of the event
            event_id = event[2]    # Get event ID (not the original event code)
            
            # Map event ID back to class number (3->1, 4->2, 5->3, 6->4)
            if event_id in range(3, 7):  # IDs 3,4,5,6 correspond to classes 1,2,3,4
                class_num = event_id - 2  # Convert 3-6 to 1-4
                
                # Extract 2.5s after the event
                start_idx = event_time
                end_idx = start_idx + n_samples
                
                if end_idx <= data.shape[1]:  # Ensure we don't go past the end of the data
                    epoch = data[:, start_idx:end_idx]
                    
                    # Save to corresponding class directory
                    output_filename = f"{os.path.splitext(filename)[0]}_event_{event_idx}.npy"
                    output_path = os.path.join(class_dirs[class_num-1], output_filename)
                    np.save(output_path, epoch)
                    print(f"Saved epoch from {filename} to {output_path}")

# Debug print - check number of files in each class directory
for i, class_dir in enumerate(class_dirs, 1):
    num_files = len([f for f in os.listdir(class_dir) if f.endswith('.npy')])
    print(f"\nClass {i} directory ({class_dir}) contains {num_files} files")

Extracting EDF parameters from /home/sahil/work/sargam/gdf_data/k6b.gdf...
GDF file detected
Setting channel info structure...
Could not determine channel type of the following channels, they will be set as EEG:
#  1, #  2, #  3, #  4, #  5, #  6, #  7, #  8, #  9, # 10, # 11, # 12, # 13, # 14, # 15, # 16, # 17, # 18, # 19, # 20, # 21, # 22, # 23, # 24, # 25, # 26, # 27, # 28, # 29, # 30, # 31, # 32, # 33, # 34, # 35, # 36, # 37, # 38, # 39, # 40, # 41, # 42, # 43, # 44, # 45, # 46, # 47, # 48, # 49, # 50, # 51, # 52, # 53, # 54, # 55, # 56, # 57, # 58, # 59, # 60
Creating raw.info structure...
Reading 0 ... 631199  =      0.000 ...  2524.796 secs...
Used Annotations descriptions: [np.str_('1023'), np.str_('768'), np.str_('769'), np.str_('770'), np.str_('771'), np.str_('772'), np.str_('783'), np.str_('785'), np.str_('786')]

Processing file: k6b.gdf
Event dictionary: {np.str_('1023'): 1, np.str_('768'): 2, np.str_('769'): 3, np.str_('770'): 4, np.str_('771'): 5, np.str_('772'): 6, np.s

In [15]:
# Load all the libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
import numpy as np
import os
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm

In [23]:
# 1. Custom Dataset class
class EEGDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        self.samples = []
        self.labels = []
        self.transform = transform
        
        # Load all data from class directories
        for class_idx in range(1, 5):
            class_dir = os.path.join(data_dir, f'class_{class_idx}')
            for file in os.listdir(class_dir):
                if file.endswith('.npy'):
                    data = np.load(os.path.join(class_dir, file))
                    self.samples.append(data)
                    self.labels.append(class_idx - 1)  # 0-based indexing
        
        self.samples = np.array(self.samples)
        self.labels = np.array(self.labels)
        
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        sample = self.samples[idx]
        label = self.labels[idx]
        
        if self.transform:
            sample = self.transform(sample)
            
        return sample, label

In [24]:
# 2. Custom Transform
class ToTensor(object):
    def __call__(self, sample):
        # Add channel dimension and convert to torch tensor
        sample = torch.FloatTensor(sample)
        return sample.unsqueeze(0)  # Shape: [1, channels, time]

class Normalize(object):
    def __call__(self, sample):
        # Normalize each channel independently
        mean = torch.mean(sample, dim=1, keepdim=True)
        std = torch.std(sample, dim=1, keepdim=True)
        return (sample - mean) / (std + 1e-8)

In [25]:
class EEGNet(nn.Module):
    def __init__(self, num_channels=60, num_samples=625):  # Add num_samples parameter
        super(EEGNet, self).__init__()
        
        # Calculate output sizes
        # After conv1: [batch, 16, num_channels, num_samples//4]
        # After conv2: [batch, 32, 1, num_samples//32]
        self.final_length = num_samples // 32
        
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 16, (1, 51), padding=(0, 25)),
            nn.BatchNorm2d(16),
            nn.ELU(),
            nn.AvgPool2d((1, 4))
        )
        
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, (num_channels, 1)),
            nn.BatchNorm2d(32),
            nn.ELU(),
            nn.AvgPool2d((1, 8))
        )
        
        # Calculate the input features for the classifier
        self.num_features = 32 * self.final_length
        
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(self.num_features, 4)
        )
        
    def forward(self, x):
        # Add debug prints
        print(f"Input shape: {x.shape}")
        
        x = self.conv1(x)
        print(f"After conv1: {x.shape}")
        
        x = self.conv2(x)
        print(f"After conv2: {x.shape}")
        
        x = x.view(x.size(0), -1)
        print(f"After flatten: {x.shape}")
        
        x = self.classifier(x)
        print(f"Output shape: {x.shape}")
        return x

In [26]:
# 4. Training function
def train_epoch(model, train_loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        
    return running_loss / len(train_loader), 100. * correct / total

In [27]:
# 5. Evaluation function
def evaluate(model, test_loader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            
    return running_loss / len(test_loader), 100. * correct / total

In [28]:
# 6. Main training loop

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Create dataset
transform = transforms.Compose([
    ToTensor(),
    Normalize()
])

dataset = EEGDataset('./npy_data', transform=transform)

# Split dataset
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Initialize model, criterion, and optimizer
model = EEGNet().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 50
best_acc = 0.0

for epoch in range(num_epochs):
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)
    
    print(f'Epoch: {epoch+1}/{num_epochs}')
    print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
    print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%')
    print('-' * 50)
    
    # Save best model
    if test_acc > best_acc:
        best_acc = test_acc
        torch.save(model.state_dict(), 'best_model.pth')

print(f'Best Test Accuracy: {best_acc:.2f}%')

Using device: cuda
Input shape: torch.Size([32, 1, 60, 625])
After conv1: torch.Size([32, 16, 60, 156])
After conv2: torch.Size([32, 32, 1, 19])
After flatten: torch.Size([32, 608])
Output shape: torch.Size([32, 4])
Input shape: torch.Size([32, 1, 60, 625])
After conv1: torch.Size([32, 16, 60, 156])
After conv2: torch.Size([32, 32, 1, 19])
After flatten: torch.Size([32, 608])
Output shape: torch.Size([32, 4])
Input shape: torch.Size([32, 1, 60, 625])
After conv1: torch.Size([32, 16, 60, 156])
After conv2: torch.Size([32, 32, 1, 19])
After flatten: torch.Size([32, 608])
Output shape: torch.Size([32, 4])
Input shape: torch.Size([32, 1, 60, 625])
After conv1: torch.Size([32, 16, 60, 156])
After conv2: torch.Size([32, 32, 1, 19])
After flatten: torch.Size([32, 608])
Output shape: torch.Size([32, 4])
Input shape: torch.Size([32, 1, 60, 625])
After conv1: torch.Size([32, 16, 60, 156])
After conv2: torch.Size([32, 32, 1, 19])
After flatten: torch.Size([32, 608])
Output shape: torch.Size([32, 