In [1]:
import os
import matplotlib.pyplot as plt
import glob
from pathlib import Path
import torch.nn as nn
import torch.optim as optim
import torch
import torch.nn.functional as F
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

In [2]:
class SimpleECGDataset(torch.utils.data.Dataset):
    """Minimal ECG dataset: one CSV recording per file, one label per recording.

    Each recording is z-scored with its own global mean/std (computed over all
    values in the file, not per column) and returned transposed, so a
    ``(5000, 12)`` CSV comes back as a ``(12, 5000)`` tensor for Conv1d.
    No shape validation is performed.
    """

    def __init__(self, csv_files, labels):
        """
        Args:
            csv_files: Sequence of CSV file paths, one recording per file.
            labels: Sequence of labels aligned index-for-index with ``csv_files``.
        """
        self.csv_files, self.labels = csv_files, labels
        assert len(csv_files) == len(labels)

    def __len__(self):
        return len(self.csv_files)

    def __getitem__(self, idx):
        signal = np.loadtxt(self.csv_files[idx], dtype=np.float32, delimiter=',')
        centre = signal.mean()
        spread = signal.std() + 1e-8  # epsilon keeps flat (constant) recordings finite
        signal = (signal - centre) / spread
        target = torch.tensor(self.labels[idx], dtype=torch.long)
        return torch.from_numpy(signal.T), target

def create_simple_dataloaders(csv_files, labels, train_ratio=0.8, batch_size=16):
    """Split files/labels into train and validation sets and wrap each in a DataLoader.

    The split is unstratified and deterministic (``random_state=42``). The
    training loader shuffles every epoch; the validation loader keeps order.

    Args:
        csv_files: Sequence of CSV file paths.
        labels: Sequence of labels aligned with ``csv_files``.
        train_ratio: Fraction of the data assigned to the training split.
        batch_size: Batch size used by both loaders.

    Returns:
        (train_loader, val_loader)
    """
    from sklearn.model_selection import train_test_split

    split = train_test_split(csv_files, labels, train_size=train_ratio, random_state=42)
    files_tr, files_va, labels_tr, labels_va = split

    def _make_loader(files, targets, shuffle):
        # One SimpleECGDataset per split, identical batch size for both.
        return torch.utils.data.DataLoader(
            SimpleECGDataset(files, targets), batch_size=batch_size, shuffle=shuffle
        )

    return (
        _make_loader(files_tr, labels_tr, shuffle=True),
        _make_loader(files_va, labels_va, shuffle=False),
    )

In [3]:
class ECGDataset(torch.utils.data.Dataset):
    """PyTorch Dataset for ECG recordings stored as CSV files.

    Each CSV must contain a ``(5000, 12)`` matrix (rows x 12 columns, one column
    per lead). Items are returned as a float32 tensor of shape ``(12, 5000)``
    (ready for ``Conv1d``) plus an int64 label tensor.
    """

    def __init__(self, csv_files, labels, normalize=True, augment=False):
        """
        Args:
            csv_files: List of paths to CSV files
            labels: List of corresponding labels (same length as csv_files)
            normalize: Whether to normalize each lead with a mean/std estimated
                from a random sample of up to 100 files (global NumPy RNG)
            augment: Whether to apply random data augmentation (training only)
        """
        self.csv_files = csv_files
        self.labels = labels
        self.normalize = normalize
        self.augment = augment

        assert len(csv_files) == len(labels), "Number of files must match number of labels"

        # Identity stats by default so `mean`/`std` always exist (e.g. when a
        # caller enables normalization later and assigns stats from elsewhere).
        self.mean = np.zeros((1, 12), dtype=np.float32)
        self.std = np.ones((1, 12), dtype=np.float32)

        # Compute normalization statistics if needed
        if normalize:
            self._compute_normalization_stats()

    def _compute_normalization_stats(self):
        """Estimate per-lead mean/std from a random sample of up to 100 files.

        Sets ``self.mean`` and ``self.std`` as float32 arrays of shape ``(1, 12)``.
        Files that fail to load or are not ``(5000, 12)`` are skipped; if none are
        usable, the identity transform (mean 0, std 1) is used instead.
        """
        print("Computing normalization statistics...")

        # Sample a subset for stats if dataset is very large. Uses the global
        # NumPy RNG, so seed np.random for reproducible statistics.
        sample_indices = np.random.choice(len(self.csv_files),
                                          min(100, len(self.csv_files)),
                                          replace=False)

        usable = []
        for idx in sample_indices:
            path = self.csv_files[idx]
            try:
                data = np.loadtxt(path, dtype=float, delimiter=',')
            except Exception as e:
                print(f"Warning: Could not load {path}: {e}")
                continue
            if data.shape == (5000, 12):
                usable.append(data)

        if usable:
            stacked = np.concatenate(usable, axis=0)  # Shape: (n_files*5000, 12)
            # Accumulate in float64 for accuracy, but store float32 so that
            # normalizing the float32 signals does not upcast them to float64.
            self.mean = stacked.mean(axis=0, keepdims=True).astype(np.float32)  # (1, 12)
            std = stacked.std(axis=0, keepdims=True).astype(np.float32)        # (1, 12)
            std[std == 0] = 1.0  # Avoid division by zero for flat leads
            self.std = std
            print(f"Normalization stats computed from {len(usable)} files "
                  f"({stacked.shape[0]} time steps)")
        else:
            print("Warning: Could not compute normalization stats, using defaults")
            self.mean = np.zeros((1, 12), dtype=np.float32)
            self.std = np.ones((1, 12), dtype=np.float32)

    def _normalize_ecg(self, data):
        """Normalize ECG data using precomputed stats (broadcast over rows)."""
        return (data - self.mean) / self.std

    def _augment_ecg(self, data):
        """Apply random data augmentation to an ECG signal.

        Each transform is applied independently with probability 0.3 using the
        global NumPy RNG. The result may be float64; callers cast as needed.
        """
        # Random additive Gaussian noise
        if np.random.random() < 0.3:
            noise = np.random.normal(0, 0.02, data.shape)
            data = data + noise

        # Random global amplitude scaling
        if np.random.random() < 0.3:
            scale = np.random.uniform(0.9, 1.1)
            data = data * scale

        # Random baseline shift per lead (one offset per column)
        if np.random.random() < 0.3:
            baseline_shift = np.random.normal(0, 0.05, (1, 12))
            data = data + baseline_shift

        return data

    def __len__(self):
        return len(self.csv_files)

    def __getitem__(self, idx):
        """Load, normalize, optionally augment, and return one recording.

        Returns:
            (ecg_tensor, label): float32 tensor of shape (12, 5000) and an int64
            label tensor. If the file cannot be loaded or has the wrong shape, an
            error is printed and an all-zero tensor with label 0 is returned.
        """
        try:
            # Load ECG data
            ecg_data = np.loadtxt(self.csv_files[idx], dtype=np.float32, delimiter=',')

            # Validate shape
            if ecg_data.shape != (5000, 12):
                raise ValueError(f"Expected shape (5000, 12), got {ecg_data.shape}")

            # Normalize
            if self.normalize:
                ecg_data = self._normalize_ecg(ecg_data)

            # Augment (only during training)
            if self.augment:
                ecg_data = self._augment_ecg(ecg_data)

            # Transpose to (12, 5000) for Conv1d. Force a contiguous float32 array:
            # float64 stats or augmentation noise would otherwise yield a float64
            # tensor that mismatches float32 model weights and the fallback below.
            ecg_tensor = torch.from_numpy(np.ascontiguousarray(ecg_data.T, dtype=np.float32))
            label = torch.tensor(self.labels[idx], dtype=torch.long)

            return ecg_tensor, label

        except Exception as e:
            print(f"Error loading {self.csv_files[idx]}: {e}")
            # Best-effort fallback: return a zero tensor with label 0 so a single
            # corrupted file does not abort the epoch.
            # NOTE(review): every corrupted file becomes a class-0 sample, which
            # biases training; consider filtering bad files up front.
            return torch.zeros(12, 5000, dtype=torch.float32), torch.tensor(0, dtype=torch.long)


def create_ecg_dataloaders(csv_files, labels, train_ratio=0.8, batch_size=16, 
                          num_workers=4, normalize=True, augment_train=True):
    """
    Create train and validation dataloaders from ECG CSV files

    The split is stratified by label and deterministic (``random_state=42``).
    Normalization statistics are estimated on the training split only and
    reused for validation, so no validation data influences them.

    Args:
        csv_files: List of CSV file paths
        labels: List of corresponding labels
        train_ratio: Fraction of data to use for training
        batch_size: Batch size for the training loader (validation uses 2x)
        num_workers: Number of worker processes for data loading
        normalize: Whether to normalize ECG signals
        augment_train: Whether to augment training data

    Returns:
        train_loader, val_loader, class_counts (Counter over all labels)
    """
    from collections import Counter
    from sklearn.model_selection import train_test_split

    # Convert to numpy arrays for easier handling
    csv_files = np.asarray(csv_files)
    labels = np.asarray(labels)

    # Stratified split to maintain class balance
    train_files, val_files, train_labels, val_labels = train_test_split(
        csv_files, labels,
        train_size=train_ratio,
        stratify=labels,
        random_state=42,
    )

    print(f"Dataset split: {len(train_files)} train, {len(val_files)} validation")

    # Count classes; .tolist() gives plain Python keys instead of NumPy scalars
    class_counts = Counter(labels.tolist())
    train_class_counts = Counter(train_labels.tolist())
    val_class_counts = Counter(val_labels.tolist())

    print("Overall class distribution:", dict(class_counts))
    print("Train class distribution:", dict(train_class_counts))
    print("Val class distribution:", dict(val_class_counts))

    # Create datasets
    train_dataset = ECGDataset(train_files.tolist(), train_labels.tolist(),
                               normalize=normalize, augment=augment_train)
    # Build the validation set without computing its own stats (that would
    # re-read up to 100 files only to be overwritten), then reuse the training
    # statistics so both splits share exactly the same normalization.
    val_dataset = ECGDataset(val_files.tolist(), val_labels.tolist(),
                             normalize=False, augment=False)
    if normalize:
        val_dataset.mean = train_dataset.mean
        val_dataset.std = train_dataset.std
        val_dataset.normalize = True

    # Pinned host memory only speeds up host->GPU copies; without CUDA it just warns.
    pin_memory = torch.cuda.is_available()

    # Create dataloaders
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=pin_memory
    )

    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=batch_size * 2, shuffle=False,
        num_workers=num_workers, pin_memory=pin_memory
    )

    return train_loader, val_loader, class_counts