In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the input tree and print the full path of every file found in it.
for root, _subdirs, names in os.walk('/kaggle/input'):
    for path in (os.path.join(root, name) for name in names):
        print(path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/train-data1/train_data/Complex_Partial_Seizures/5955.npy
/kaggle/input/train-data1/train_data/Complex_Partial_Seizures/422.npy
/kaggle/input/train-data1/train_data/Complex_Partial_Seizures/970.npy
/kaggle/input/train-data1/train_data/Complex_Partial_Seizures/6179.npy
/kaggle/input/train-data1/train_data/Complex_Partial_Seizures/100.npy
/kaggle/input/train-data1/train_data/Complex_Partial_Seizures/1876.npy
/kaggle/input/train-data1/train_data/Complex_Partial_Seizures/1041.npy
/kaggle/input/train-data1/train_data/Complex_Partial_Seizures/4542.npy
/kaggle/input/train-data1/train_data/Complex_Partial_Seizures/6052.npy
/kaggle/input/train-data1/train_data/Complex_Partial_Seizures/4248.npy
/kaggle/input/train-data1/train_data/Complex_Partial_Seizures/3191.npy
/kaggle/input/train-data1/train_data/Complex_Partial_Seizures/63.npy
/kaggle/input/train-data1/train_data/Complex_Partial_Seizures/1364.npy
/kaggle/input/train-data1/train_data/Complex_Partial_Seizures/1740.npy
/kaggle/inp

In [4]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import os
from sklearn.metrics import classification_report
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import ReduceLROnPlateau
from scipy.signal import resample

In [5]:
class ResidualBlock(nn.Module):
    """Two-layer fully connected block with a skip connection.

    Main path: Linear -> BatchNorm1d -> ReLU -> Dropout(0.3) -> Linear ->
    BatchNorm1d. The result is summed with the shortcut branch and passed
    through a final ReLU. Despite the ``conv`` attribute names, both layers
    are ``nn.Linear``.

    Args:
        in_channels: Number of input features.
        out_channels: Number of output features.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Linear(in_channels, out_channels)
        self.conv2 = nn.Linear(out_channels, out_channels)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.bn2 = nn.BatchNorm1d(out_channels)
        self.dropout = nn.Dropout(0.3)

        # Project the input only when the widths differ; an empty Sequential
        # passes its input through unchanged.
        widths_differ = in_channels != out_channels
        self.shortcut = (
            nn.Linear(in_channels, out_channels) if widths_differ else nn.Sequential()
        )

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        skip = self.shortcut(x)
        hidden = self.dropout(F.relu(self.bn1(self.conv1(x))))
        hidden = self.bn2(self.conv2(hidden))
        return F.relu(hidden + skip)


In [6]:
# Enhanced VAE architecture
class ImprovedVAE(nn.Module):
    """Label-conditioned VAE with a classifier head on the latent mean.

    The encoder is a stack of ``ResidualBlock`` layers followed by two linear
    heads (``mu`` and ``log_var``). The decoder receives the latent sample
    concatenated with a one-hot class label and ends in a ``tanh``. The
    classifier runs on the encoder's ``mu``.

    Args:
        input_dim: Number of features per sample.
        hidden_dims: Encoder layer widths, in order; the decoder uses the
            reversed list. Must contain at least one entry.
        latent_dim: Size of the latent vector.
        n_classes: Number of classes (one-hot width and classifier outputs).
    """

    def __init__(self, input_dim, hidden_dims, latent_dim, n_classes):
        super().__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        # Kept so decode() can build one-hot labels of the right width.
        self.n_classes = n_classes

        # Encoder
        self.encoder_layers = nn.ModuleList()
        current_dim = input_dim
        for hidden_dim in hidden_dims:
            self.encoder_layers.append(ResidualBlock(current_dim, hidden_dim))
            current_dim = hidden_dim

        self.mu = nn.Linear(hidden_dims[-1], latent_dim)
        self.log_var = nn.Linear(hidden_dims[-1], latent_dim)

        # Decoder
        self.decoder_layers = nn.ModuleList()
        decoder_dims = hidden_dims[::-1]
        current_dim = latent_dim + n_classes  # latent sample + one-hot label

        for hidden_dim in decoder_dims:
            self.decoder_layers.append(ResidualBlock(current_dim, hidden_dim))
            current_dim = hidden_dim

        self.final_layer = nn.Linear(decoder_dims[-1], input_dim)

        # Classifier
        classifier_dims = [latent_dim] + [128, 64]
        self.classifier = nn.ModuleList()
        for i in range(len(classifier_dims) - 1):
            self.classifier.append(ResidualBlock(classifier_dims[i], classifier_dims[i + 1]))
        self.classifier_output = nn.Linear(classifier_dims[-1], n_classes)

    # These methods must live on the class: nn.Module.__call__ dispatches to
    # self.forward, and the plotting/evaluation code calls vae.encode,
    # vae.decode and vae.classify.

    def encode(self, x):
        """Return ``(mu, log_var)`` for a batch of inputs."""
        for layer in self.encoder_layers:
            x = layer(x)
        return self.mu(x), self.log_var(x)

    def reparameterize(self, mu, log_var):
        """Sample ``mu + eps * exp(0.5 * log_var)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, labels):
        """Decode latent vectors ``z`` conditioned on integer class ``labels``.

        The output passes through ``tanh``, so it lies in [-1, 1].
        """
        label_onehot = F.one_hot(labels, num_classes=self.n_classes).float()
        x = torch.cat([z, label_onehot], dim=1)

        for layer in self.decoder_layers:
            x = layer(x)
        return torch.tanh(self.final_layer(x))

    def _classify_latent(self, mu):
        """Map a latent mean to class logits."""
        x = mu
        for layer in self.classifier:
            x = layer(x)
        return self.classifier_output(x)

    def classify(self, x):
        """Return class logits for raw inputs ``x``."""
        mu, _ = self.encode(x)
        return self._classify_latent(mu)

    def forward(self, x, labels):
        """Return ``(recon_x, classification, mu, log_var)``.

        The classifier reuses the ``mu`` computed here rather than encoding
        ``x`` a second time. A second pass would use a different dropout mask
        and update the BatchNorm running statistics twice per step.
        """
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        classification = self._classify_latent(mu)
        recon_x = self.decode(z, labels)
        return recon_x, classification, mu, log_var


In [9]:
def encode(self, x):
    """Pass ``x`` through ``self.encoder_layers`` and return ``(mu, log_var)``.

    Written with a ``self`` parameter: it expects an object exposing
    ``encoder_layers`` (iterable of callables) plus ``mu`` and ``log_var``
    heads.
    """
    hidden = x
    for block in self.encoder_layers:
        hidden = block(hidden)
    latent_mean = self.mu(hidden)
    latent_log_var = self.log_var(hidden)
    return latent_mean, latent_log_var
    
def reparameterize(self, mu, log_var):
    """Draw ``mu + noise * sigma`` where ``sigma = exp(0.5 * log_var)``.

    ``noise`` is standard-normal with the same shape as ``sigma``. ``self`` is
    not read.
    """
    sigma = (0.5 * log_var).exp()
    noise = torch.randn_like(sigma)
    return mu + noise * sigma
    
def decode(self, z, labels):
    """Decode latent vectors ``z`` conditioned on integer class ``labels``.

    ``labels`` is one-hot encoded, concatenated to ``z`` along dim 1, passed
    through ``self.decoder_layers`` and ``self.final_layer``, then squashed
    with ``tanh``.

    The one-hot width is read from ``self.classifier_output.out_features``
    (the model's class count) rather than the previous literal 4, so the
    function works for models built with a different ``n_classes``.
    """
    n_classes = self.classifier_output.out_features
    label_onehot = F.one_hot(labels, num_classes=n_classes).float()
    x = torch.cat([z, label_onehot], dim=1)

    for layer in self.decoder_layers:
        x = layer(x)
    return torch.tanh(self.final_layer(x))
    
def classify(self, x):
    """Encode ``x`` and return class logits computed from the latent mean.

    The log-variance returned by ``self.encode`` is ignored; the mean passes
    through every layer in ``self.classifier`` and then
    ``self.classifier_output``.
    """
    features = self.encode(x)[0]
    for block in self.classifier:
        features = block(features)
    return self.classifier_output(features)
    
def forward(self, x, labels):
    """Return ``(recon_x, classification, mu, log_var)`` for a batch.

    The class logits are computed from the ``mu`` obtained here instead of
    calling ``self.classify(x)``, which would run the encoder a second time.
    In training mode that second pass uses a different dropout mask and
    updates the BatchNorm running statistics twice per step.
    """
    mu, log_var = self.encode(x)
    z = self.reparameterize(mu, log_var)

    # Same layers classify() applies, fed with the mean computed above.
    logits = mu
    for layer in self.classifier:
        logits = layer(logits)
    classification = self.classifier_output(logits)

    recon_x = self.decode(z, labels)
    return recon_x, classification, mu, log_var


In [12]:
class FocalLoss(nn.Module):
    """Focal loss built on cross-entropy: ``(1 - p_t) ** gamma * CE``.

    Args:
        alpha: Optional per-class weights (sequence or tensor), indexed by the
            target class. ``None`` disables class weighting.
        gamma: Focusing exponent applied to ``(1 - p_t)``.
    """

    def __init__(self, alpha=None, gamma=2):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        if alpha is not None:
            # torch.tensor(existing_tensor) emits a copy-construct UserWarning,
            # so copy tensors explicitly and only build from non-tensors.
            if isinstance(alpha, torch.Tensor):
                alpha = alpha.detach().clone()
            else:
                alpha = torch.tensor(alpha)
            # Integer weights (e.g. [1, 3]) are stored as floats.
            self.alpha = alpha if alpha.is_floating_point() else alpha.float()

    def forward(self, inputs, targets):
        """Return the mean focal loss.

        Args:
            inputs: Raw class logits, as accepted by ``F.cross_entropy``.
            targets: Integer class indices.
        """
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)  # probability assigned to the true class
        focal_loss = (1 - pt) ** self.gamma * ce_loss

        if self.alpha is not None:
            # Move the weights first: indexing a CPU tensor with CUDA targets
            # raises a device-mismatch error.
            alpha_t = self.alpha.to(inputs.device)[targets]
            focal_loss = alpha_t * focal_loss

        return focal_loss.mean()


In [13]:
class EEGDataAugmentation:
    """Stateless NumPy augmentations for 1-D signals.

    Every method is a ``staticmethod`` that draws from NumPy's global random
    state and returns a new array.
    """

    @staticmethod
    def add_gaussian_noise(signal, noise_factor=0.05):
        """Add zero-mean Gaussian noise with standard deviation ``noise_factor``."""
        noise = np.random.normal(0, noise_factor, signal.shape)
        return signal + noise

    @staticmethod
    def time_shift(signal, shift_max=0.2):
        """Circularly shift by up to ``shift_max`` of the length, in either direction."""
        shift = int(len(signal) * np.random.uniform(-shift_max, shift_max))
        return np.roll(signal, shift)

    @staticmethod
    def scaling(signal, scale_factor=0.1):
        """Multiply by a factor drawn from ``[1 - scale_factor, 1 + scale_factor]``."""
        factor = np.random.uniform(1 - scale_factor, 1 + scale_factor)
        return signal * factor

    @staticmethod
    def random_resample(signal, resample_factor_range=(0.8, 1.2)):
        """Resample to a random length; the output length differs from the input."""
        factor = np.random.uniform(*resample_factor_range)
        new_length = int(len(signal) * factor)
        return resample(signal, new_length)

    @staticmethod
    def mixup(signal1, signal2, alpha=0.2):
        """Perform mixup augmentation between two signals"""
        lambda_param = np.random.beta(alpha, alpha)
        mixed_signal = lambda_param * signal1 + (1 - lambda_param) * signal2
        return mixed_signal

    @staticmethod
    def spectral_augment(signal, max_mask_size=0.1):
        """Zero a random contiguous band of frequencies in a real 1-D signal.

        The band covers ``int(n_bins * max_mask_size)`` bins of the one-sided
        spectrum (``np.fft.rfft``). Working on the one-sided spectrum removes
        the band completely: zeroing bins of the full FFT without their
        conjugate mirrors and then taking the real part only halves their
        amplitude.

        Args:
            signal: Real-valued 1-D array.
            max_mask_size: Fraction of the one-sided spectrum to zero,
                clamped to ``[0, 1]``.

        Returns:
            Real array with the same length as ``signal``.
        """
        spectrum = np.fft.rfft(signal)
        n_bins = len(spectrum)
        mask_size = max(0, min(int(n_bins * max_mask_size), n_bins))
        # randint's upper bound is exclusive, so the +1 lets the mask reach the
        # last bin and keeps the range non-empty when the mask spans everything.
        mask_start = np.random.randint(0, n_bins - mask_size + 1)
        spectrum[mask_start:mask_start + mask_size] = 0
        # Pass n explicitly so odd-length signals keep their length.
        return np.fft.irfft(spectrum, n=len(signal))


In [14]:
class EEGDataset(Dataset):
    """In-memory dataset of ``.npy`` samples stored in one sub-folder per class.

    Every ``.npy`` file under ``data_dir/<class_name>`` is loaded, reshaped to
    2-D (rows are samples), stacked, and standardised feature-wise.

    Args:
        data_dir: Directory containing the class sub-folders named in
            ``label_map``. Missing sub-folders are skipped.
        augment: If true, classes 2 and 3 are enlarged with augmented copies
            at load time, and ``__getitem__`` augments samples on the fly.
        class_weights: Stored on the instance; not read inside this class.
        scaler: Optional already-fitted scaler exposing ``transform``. When
            given, it is applied as-is (e.g. to standardise validation data
            with training statistics). When ``None``, a new ``StandardScaler``
            is fitted on this dataset.

    Raises:
        ValueError: If no sample could be loaded from ``data_dir``.
    """

    def __init__(self, data_dir, augment=False, class_weights=None, scaler=None):
        self.data = []
        self.labels = []
        self.augment = augment
        self.augmenter = EEGDataAugmentation()
        self.class_weights = class_weights
        self.label_map = {
            'Normal': 0,
            'Complex_Partial_Seizures': 1,
            'Electrographic_Seizures': 2,
            'Video_detected_Seizures_with_no_visual_change_over_EEG': 3
        }

        # Load and process data
        for class_name, label in self.label_map.items():
            class_path = os.path.join(data_dir, class_name)
            if not os.path.isdir(class_path):
                continue

            class_arrays = self._load_class_arrays(class_path)
            if not class_arrays:
                continue

            class_data = np.vstack(class_arrays)
            # Apply extra augmentation for minority classes (2 and 3)
            if label in [2, 3] and self.augment:
                aug_factor = 3 if label == 2 else 2
                aug_data = [self.augment_batch(class_data) for _ in range(aug_factor - 1)]
                class_data = np.vstack([class_data] + aug_data)

            self.data.append(class_data)
            self.labels.extend([label] * len(class_data))

        if not self.data:
            # Fail with an actionable message instead of np.vstack's generic
            # "need at least one array to concatenate".
            raise ValueError(
                f"No .npy samples found under {data_dir!r}; "
                f"expected sub-folders: {list(self.label_map)}"
            )

        self.data = np.vstack(self.data)
        self.labels = np.array(self.labels)

        # Standardize the data
        if scaler is None:
            self.scaler = StandardScaler()
            self.data = self.scaler.fit_transform(self.data)
        else:
            self.scaler = scaler
            self.data = self.scaler.transform(self.data)

    @staticmethod
    def _load_class_arrays(class_path):
        """Load every ``.npy`` file in ``class_path`` as a 2-D array.

        Files are visited in sorted name order so the sample order is
        reproducible. Files that fail to load are reported and skipped.
        """
        arrays = []
        for file_name in sorted(os.listdir(class_path)):
            if not file_name.endswith('.npy'):
                continue
            try:
                data = np.load(os.path.join(class_path, file_name))
                if len(data.shape) == 1:
                    data = data.reshape(1, -1)
                elif len(data.shape) > 2:
                    data = data.reshape(data.shape[0], -1)
                arrays.append(data)
            except Exception as e:
                print(f"Error loading file {file_name}: {str(e)}")
        return arrays

    def augment_batch(self, batch):
        """Return an array holding one augmented copy of every row in ``batch``."""
        return np.array([self.augment_signal(signal) for signal in batch])

    def augment_signal(self, signal):
        """Apply each augmentation independently with its own probability."""
        # Enhanced augmentation strategy
        augmentations = [
            (self.augmenter.add_gaussian_noise, 0.8),
            (self.augmenter.time_shift, 0.7),
            (self.augmenter.scaling, 0.7),
            (self.augmenter.spectral_augment, 0.5)
        ]

        # Work on a copy so the stored sample is never modified.
        augmented_signal = signal.copy()
        for aug_func, prob in augmentations:
            if np.random.random() < prob:
                augmented_signal = aug_func(augmented_signal)

        return augmented_signal

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(float32 tensor, label)``; may augment when ``self.augment``."""
        signal = self.data[idx]
        label = self.labels[idx]

        if self.augment:
            # Class 2 is always augmented; every other class half of the time.
            if label == 2 or np.random.random() < 0.5:
                signal = self.augment_signal(signal)

        return torch.tensor(signal, dtype=torch.float32), label


In [15]:
def plot_class_distributions(train_dataset, synthetic_data, synthetic_labels):
    """Draw side-by-side bar charts of per-class sample counts.

    The left panel counts ``train_dataset.labels`` and the right panel counts
    ``synthetic_labels``. ``synthetic_data`` is accepted but not used.
    """
    panels = (
        ('Original Class Distribution', lambda: train_dataset.labels),
        ('Synthetic Class Distribution', lambda: synthetic_labels),
    )

    plt.figure(figsize=(15, 5))
    for position, (title, get_labels) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        counts = np.bincount(get_labels())
        plt.bar(range(len(counts)), counts)
        plt.title(title)
        plt.xlabel('Class')
        plt.ylabel('Count')

    plt.tight_layout()
    plt.show()

In [16]:
def _plot_stacked_signals(position, signals, title):
    """Plot each row of ``signals`` in panel ``position`` of a 4x2 grid, offset by 2."""
    plt.subplot(4, 2, position)
    for i, signal in enumerate(signals):
        plt.plot(signal.cpu().numpy() + i * 2)
    plt.title(title)
    plt.xlabel('Time')
    plt.ylabel('Amplitude')


def plot_class_samples(train_dataset, vae, device, num_samples=5):
    """Plot original and reconstructed samples for each class.

    For each of the four classes, up to ``num_samples`` rows are drawn at
    random from ``train_dataset.data`` and passed through ``vae``. Originals
    go in the left column and reconstructions in the right.

    A class with fewer than ``num_samples`` rows is plotted with what it has,
    and an empty class is skipped. ``np.random.choice(..., replace=False)``
    would otherwise raise ``ValueError`` when asked for more samples than exist.
    """
    vae.eval()
    plt.figure(figsize=(20, 12))

    for class_idx in range(4):
        # Get samples from the specific class
        class_samples = train_dataset.data[train_dataset.labels == class_idx]
        n_plot = min(num_samples, len(class_samples))
        if n_plot == 0:
            continue

        # Select random samples
        random_indices = np.random.choice(len(class_samples), n_plot, replace=False)
        samples = torch.FloatTensor(class_samples[random_indices]).to(device)
        # Explicit long dtype: the labels are one-hot encoded by the decoder.
        class_labels = torch.full((n_plot,), class_idx, dtype=torch.long, device=device)

        with torch.no_grad():
            recon_samples, _, _, _ = vae(samples, class_labels)

        _plot_stacked_signals(
            2 * class_idx + 1, samples, f'Original Samples - Class {class_idx}'
        )
        _plot_stacked_signals(
            2 * class_idx + 2, recon_samples, f'Reconstructed Samples - Class {class_idx}'
        )

    plt.tight_layout()
    plt.show()

In [17]:

def plot_latent_space(vae, train_dataset, device):
    """Scatter-plot a 2-D PCA projection of the encoder's latent means.

    The whole of ``train_dataset.data`` is encoded in a single batch on
    ``device``; points are coloured by ``train_dataset.labels``.
    """
    from sklearn.decomposition import PCA

    vae.eval()
    inputs = torch.FloatTensor(train_dataset.data).to(device)
    class_ids = train_dataset.labels

    with torch.no_grad():
        # Only the mean head is visualised; the log-variance is discarded.
        latent_means = vae.encode(inputs)[0].cpu().numpy()

    # Reduce the latent vectors to two principal components.
    projected = PCA(n_components=2).fit_transform(latent_means)

    plt.figure(figsize=(10, 8))
    points = plt.scatter(
        projected[:, 0],
        projected[:, 1],
        c=class_ids,
        cmap='viridis',
        alpha=0.6,
    )
    plt.colorbar(points)
    plt.title('Latent Space Visualization (PCA)')
    plt.xlabel('First Principal Component')
    plt.ylabel('Second Principal Component')
    plt.show()


In [18]:
def plot_reconstruction_loss_by_class(vae, train_loader, device):
    """Box-plot the per-sample reconstruction MSE, grouped by class.

    Every batch from ``train_loader`` is reconstructed by ``vae`` in eval
    mode; each sample's MSE (mean over features) is filed under its label.
    """
    vae.eval()
    class_ids = range(4)
    per_class_losses = {class_id: [] for class_id in class_ids}

    with torch.no_grad():
        for batch, batch_labels in train_loader:
            batch = batch.to(device)
            batch_labels = batch_labels.to(device)
            reconstruction = vae(batch, batch_labels)[0]

            # One loss value per sample: average the squared error over features.
            sample_mse = F.mse_loss(reconstruction, batch, reduction='none').mean(dim=1)

            pairs = zip(batch_labels.cpu().numpy(), sample_mse.cpu().numpy())
            for class_id, value in pairs:
                per_class_losses[class_id].append(value)

    plt.figure(figsize=(12, 6))
    plt.boxplot(
        [per_class_losses[class_id] for class_id in class_ids],
        labels=[f'Class {class_id}' for class_id in class_ids],
    )
    plt.title('Reconstruction Loss Distribution by Class')
    plt.xlabel('Class')
    plt.ylabel('Reconstruction Loss')
    plt.show()

In [19]:
def train_improved_vae(vae, train_loader, n_epochs, device, lr=1e-3):
    """Train ``vae`` on reconstruction, KL and focal classification losses.

    Args:
        vae: Model whose call returns ``(recon, class_logits, mu, log_var)``.
        train_loader: Iterable yielding ``(data, labels)`` batches.
        n_epochs: Maximum number of epochs.
        device: Device the batches are moved to.
        lr: Initial Adam learning rate.

    Returns:
        ``(losses, accuracies)``: per-epoch mean batch loss and training
        accuracy in percent.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    optimizer = torch.optim.Adam(vae.parameters(), lr=lr, weight_decay=1e-5)
    # The scheduler's `verbose` flag is deprecated in recent PyTorch releases,
    # so learning-rate drops are reported explicitly after scheduler.step().
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

    # Initialize Focal Loss with adjusted class weights
    class_weights = torch.tensor([1.0, 1.2, 2.5, 1.5]).to(device)  # Increased weight for class 2
    criterion = FocalLoss(alpha=class_weights, gamma=2.5)  # Increased gamma for harder examples

    losses = []
    accuracies = []
    best_loss = float('inf')
    patience_counter = 0
    patience_limit = 20

    for epoch in range(n_epochs):
        total_loss = 0
        correct = 0
        total = 0
        n_batches = 0
        vae.train()

        # Both weights depend only on the epoch, so compute them once per epoch.
        beta = min(epoch / (n_epochs * 0.1), 1.0)  # KL warm-up over the first 10% of epochs
        classification_weight = 2.0 if epoch > n_epochs * 0.2 else 1.0  # Increase classification weight later

        for data, labels in train_loader:
            data = data.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            recon_batch, classification, mu, log_var = vae(data, labels)

            recon_loss = F.mse_loss(recon_batch, data)
            # NOTE(review): the KL term is summed over batch and latent dims
            # while the other two terms are means, so its relative weight grows
            # with batch size. Left as-is; confirm this is intended.
            kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
            class_loss = criterion(classification, labels)

            loss = recon_loss + beta * 0.1 * kl_loss + classification_weight * class_loss

            loss.backward()
            torch.nn.utils.clip_grad_norm_(vae.parameters(), max_norm=1.0)
            optimizer.step()

            total_loss += loss.item()
            n_batches += 1

            # Calculate accuracy
            _, predicted = classification.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

        if n_batches == 0:
            raise ValueError("train_loader yielded no batches")

        # total_loss sums per-batch losses, so average over the number of
        # batches, not the number of samples.
        avg_loss = total_loss / n_batches
        accuracy = 100. * correct / total

        losses.append(avg_loss)
        accuracies.append(accuracy)

        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(avg_loss)
        lr_after = optimizer.param_groups[0]['lr']
        if lr_after < lr_before:
            print(f"Epoch {epoch+1}: reducing learning rate to {lr_after:.4e}")

        if avg_loss < best_loss:
            best_loss = avg_loss
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience_limit:
            print(f"Early stopping at epoch {epoch+1}")
            break

        if (epoch + 1) % 5 == 0:
            print(f'Epoch {epoch+1}, Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%')

    return losses, accuracies



def generate_synthetic_data(vae, n_samples_per_class, n_classes, device, noise_scale=1.0):
    """Sample synthetic signals from the decoder for every class.

    Latent vectors are drawn from ``N(0, noise_scale**2)`` and decoded with
    the class label. Classes 2 and 3 receive 3x and 2x as many samples as
    ``n_samples_per_class``; any class without an entry in the multiplier
    table gets 1x. Class 2 outputs are additionally perturbed with small
    Gaussian noise.

    Args:
        vae: Model exposing ``eval()``, ``latent_dim`` and ``decode(z, labels)``.
        n_samples_per_class: Base number of samples per class.
        n_classes: Number of classes to generate (labels ``0..n_classes-1``).
        device: Device on which latents and labels are created.
        noise_scale: Standard deviation of the latent samples.

    Returns:
        ``(data, labels)`` as NumPy arrays with one row / entry per sample.
    """
    vae.eval()
    synthetic_data = []
    synthetic_labels = []

    # Adjust samples per class based on class distribution
    class_multipliers = {
        0: 1,
        1: 1,
        2: 3,  # Generate more samples for class 2
        3: 2
    }
    n_batches = 5

    with torch.no_grad():
        for class_idx in range(n_classes):
            # .get keeps this usable for models with more than four classes.
            adjusted_samples = n_samples_per_class * class_multipliers.get(class_idx, 1)
            # Spread the remainder over the first batches so exactly
            # `adjusted_samples` rows are produced (plain floor division used
            # to drop up to n_batches - 1 samples per class).
            base_size, remainder = divmod(adjusted_samples, n_batches)

            for batch_idx in range(n_batches):
                samples_in_batch = base_size + (1 if batch_idx < remainder else 0)
                z = torch.randn(samples_in_batch, vae.latent_dim, device=device) * noise_scale
                labels = torch.full(
                    (samples_in_batch,), class_idx, dtype=torch.long, device=device
                )
                synthetic_samples = vae.decode(z, labels)

                # Add noise to synthetic samples for better generalization
                if class_idx == 2:  # For class 2
                    noise = torch.randn_like(synthetic_samples) * 0.05
                    synthetic_samples = synthetic_samples + noise

                synthetic_data.append(synthetic_samples.cpu().numpy())
                synthetic_labels.extend([class_idx] * samples_in_batch)

    return np.vstack(synthetic_data), np.array(synthetic_labels)


In [21]:
def plot_training_metrics(losses, accuracies):
    """Plot the per-epoch training loss and accuracy curves side by side."""
    epochs = range(1, len(losses) + 1)
    plt.figure(figsize=(15, 5))

    plt.subplot(1, 2, 1)
    plt.plot(epochs, losses)
    plt.title('Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')

    plt.subplot(1, 2, 2)
    plt.plot(epochs, accuracies)
    plt.title('Training Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')

    plt.tight_layout()
    plt.show()


def main():
    """Train the VAE, generate synthetic data, plot diagnostics and evaluate.

    Steps: load the training set, train with class-balanced sampling, plot
    the training curves, sample synthetic data, draw the diagnostic figures,
    and print a classification report for the validation set.
    """
    # Hyperparameters
    hidden_dims = [512, 256, 128]
    latent_dim = 64
    n_classes = 4
    batch_size = 32
    n_epochs = 200
    seed = 42
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Seed NumPy and PyTorch so augmentation, sampling and initialisation
    # are repeatable between runs.
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Load and prepare data
    train_dataset = EEGDataset('/kaggle/input/train-data1/train_data', augment=True)

    # Calculate class weights for balanced sampling
    class_counts = np.bincount(train_dataset.labels)
    # A class with no samples would otherwise produce a divide-by-zero warning.
    class_weights = 1.0 / np.maximum(class_counts, 1)
    sample_weights = class_weights[train_dataset.labels]
    sampler = WeightedRandomSampler(sample_weights, len(sample_weights))

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=4,
        pin_memory=True,
        # BatchNorm1d raises in training mode on a batch of one sample, which
        # can happen with the trailing partial batch.
        drop_last=True
    )

    # Initialize and train model
    input_dim = train_dataset.data.shape[1]
    vae = ImprovedVAE(input_dim, hidden_dims, latent_dim, n_classes).to(device)
    losses, accuracies = train_improved_vae(vae, train_loader, n_epochs, device)

    # Plot training metrics
    plot_training_metrics(losses, accuracies)

    # Generate synthetic data
    n_samples_per_class = 500
    synthetic_data, synthetic_labels = generate_synthetic_data(
        vae, n_samples_per_class, n_classes, device, noise_scale=0.8
    )

    # Plot class distributions
    plot_class_distributions(train_dataset, synthetic_data, synthetic_labels)

    print("Generating visualizations...")
    plot_class_samples(train_dataset, vae, device)
    plot_latent_space(vae, train_dataset, device)
    plot_reconstruction_loss_by_class(vae, train_loader, device)

    # Evaluate on validation set
    val_dataset = EEGDataset('/kaggle/input/val-data/validation_data', augment=False)
    # EEGDataset standardises with statistics fitted on its own data. Undo that
    # and apply the training scaler so validation inputs are in the same
    # feature space the model was trained on.
    val_raw = val_dataset.scaler.inverse_transform(val_dataset.data)
    val_dataset.data = train_dataset.scaler.transform(val_raw)
    val_dataset.scaler = train_dataset.scaler
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    vae.eval()
    predictions = []
    true_labels = []

    with torch.no_grad():
        for data, labels in val_loader:
            data = data.to(device)
            classification = vae.classify(data)
            pred = classification.argmax(dim=1).cpu().numpy()
            predictions.extend(pred)
            true_labels.extend(labels.numpy())

    print("\nClassification Report:")
    print(classification_report(true_labels, predictions))


In [None]:
# Entry point: call main() only when this code runs as the top-level module
# (a script or a notebook cell), not when it is imported.
if __name__ == "__main__":
    main()
