In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

https://arxiv.org/abs/1611.04076v3

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [3]:
train_path = "/kaggle/input/impulse/Impulse/EEG_Data/train_data"
latent_dim = 100
input_dim = 19 * 500
num_classes = 4
batch_size = 128
epochs = 100
learning_rate = 0.0001
n_critic = 1  # Discriminator steps per generator step

In [4]:
def normalize_data(data):
    min_val = np.min(data)
    max_val = np.max(data)
    return 2 * (data - min_val) / (max_val - min_val + 1e-8) - 1  # Normalize to [-1, 1]


In [5]:
class_map = {"Normal": 0, "Complex_Partial_Seizures": 1, "Electrographic_Seizures": 2, "Video_detected_Seizures_with_no_visual_change_over_EEG": 3}
train_data, train_labels = [], []

for class_name, class_label in class_map.items():
    class_folder = os.path.join(train_path, class_name)
    for file_name in os.listdir(class_folder):
        file_path = os.path.join(class_folder, file_name)
        signal = np.load(file_path)  # Assuming .npy files
        # Ensure correct shape
        if signal.shape == (19, 500):
            normalized_signal = normalize_data(signal)
            train_data.append(normalized_signal.flatten())  # Flatten to shape (19*500,)
            train_labels.append(class_label)

In [6]:
# Convert to tensors
train_data = torch.tensor(np.array(train_data), dtype=torch.float32)
train_labels = torch.tensor(train_labels, dtype=torch.long)

In [7]:
# Dataset and Dataloader
class EEGDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

In [16]:
dataset = EEGDataset(train_data, train_labels)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)


In [26]:
import torch
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, z_dim, num_classes):
        super(Generator, self).__init__()
        self.class_emb = nn.Sequential(
            nn.Linear(num_classes, 16),
            nn.LeakyReLU(0.2, inplace=True)
        )

        input_dim = z_dim + 16

        self.net = nn.Sequential(
            # Layer 1: Upsample from 1 -> 2
            nn.ConvTranspose1d(input_dim, 256, kernel_size=4, stride=2, padding=1),  # Output: (256, 2)
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2, inplace=True),

            # Layer 2: Upsample from 2 -> 8
            nn.ConvTranspose1d(256, 128, kernel_size=4, stride=4, padding=0),  # Output: (128, 8)
            nn.BatchNorm1d(128),
            nn.LeakyReLU(0.2, inplace=True),

            # Layer 3: Upsample from 8 -> 32
            nn.ConvTranspose1d(128, 64, kernel_size=6, stride=4, padding=1),  # Output: (64, 32)
            nn.BatchNorm1d(64),
            nn.LeakyReLU(0.2, inplace=True),

            # Layer 4: Upsample from 32 -> 128
            nn.ConvTranspose1d(64, 32, kernel_size=8, stride=4, padding=2),  # Output: (32, 128)
            nn.BatchNorm1d(32),
            nn.LeakyReLU(0.2, inplace=True),

            # Layer 5: Upsample from 128 -> 511
            nn.ConvTranspose1d(32, 19, kernel_size=9, stride=4, padding=3),  # Output: (19, 511)
            nn.Tanh()
        )

    def forward(self, z, labels):
        # Concatenate latent vector with class embedding
        class_emb = self.class_emb(labels)  # (batch_size, 16)
        z = torch.cat([z, class_emb], dim=1)  # (batch_size, z_dim + 16)
        z = z.unsqueeze(2)  # Add temporal dimension: (batch_size, z_dim + 16, 1)
        return self.net(z)  # Output: (batch_size, 19, 511)


In [27]:
class Discriminator(nn.Module):
    def __init__(self, num_classes):
        super(Discriminator, self).__init__()
        self.class_emb = nn.Sequential(
            nn.Linear(num_classes, 16),  # Embed class into 16 features
            nn.LeakyReLU(0.2, inplace=True)
        )

        # Use kernel_size=3, stride=1, and padding=1 to preserve dimensions as much as possible
        self.net = nn.Sequential(
            nn.Conv1d(35, 64, kernel_size=3, stride=2, padding=1),  # Reduces sequence length by 2
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv1d(64, 128, kernel_size=3, stride=2, padding=1),  # Further reduces sequence length
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv1d(128, 1, kernel_size=3, stride=1, padding=1),  # Output is a single feature per time step
        )

    def forward(self, x, labels):
        """
        x:      (batch_size, 19, time_length)
        labels: (batch_size, num_classes)
        """
        # Embed labels and expand to match temporal dimension
        class_emb = self.class_emb(labels)  # (batch_size, 16)
        class_emb = class_emb.unsqueeze(2).repeat(1, 1, x.size(2))  # (batch_size, 16, time_length)

        # Concatenate class embedding with input EEG along channel dimension
        x = torch.cat([x, class_emb], dim=1)  # (batch_size, 35, time_length)

        # Pass through convolutional layers
        validity = self.net(x)  # (batch_size, 1, time_length / 4 if stride=2)
        return validity.view(x.size(0), -1)  # Flatten to (batch_size, 1)


In [28]:
################################
#      INITIALIZE MODELS       #
################################
generator = Generator(latent_dim, num_classes).to(device)
discriminator = Discriminator(num_classes).to(device)

optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999))

scheduler_G = optim.lr_scheduler.StepLR(optimizer_G, step_size=200, gamma=0.5)
scheduler_D = optim.lr_scheduler.StepLR(optimizer_D, step_size=200, gamma=0.5)


In [29]:
################################
#           TRAINING           #
################################
for epoch in range(1, epochs + 1):
    epoch_d_loss = 0.0
    epoch_g_loss = 0.0

    for i, (real_eeg, real_lbls) in enumerate(dataloader):
        # Move data to device
        real_eeg = real_eeg.to(device)  # (batch_size, 19*500) if not reshaped yet
        real_lbls = torch.nn.functional.one_hot(real_lbls, num_classes=num_classes).float().to(device)

        # Reshape real EEG to (batch_size, 19, 500) if you truly want that shape
        # WARNING: This shape must match how your conv layers handle data
        bsz = real_eeg.size(0)
        real_eeg = real_eeg.view(bsz, 19, 500)

        ##################################################
        # (1) Train Discriminator (n_critic times)       #
        ##################################################
        for _ in range(n_critic):
            optimizer_D.zero_grad()

            # Generate fake data
            z = torch.randn(bsz, latent_dim).to(device)
            fake_labels = real_lbls  # Reuse the same label distribution
            fake_eeg = generator(z, fake_labels).detach()  # (bsz, 19, ???)

            # LSGAN losses
            # Note: .detach() so these gradients don't flow back to generator
            loss_real = 0.5 * torch.mean((discriminator(real_eeg, real_lbls) - 1) ** 2)
            loss_fake = 0.5 * torch.mean(discriminator(fake_eeg, fake_labels) ** 2)
            d_loss = loss_real + loss_fake

            d_loss.backward()
            optimizer_D.step()

        ##################################################
        # (2) Train Generator                            #
        ##################################################
        optimizer_G.zero_grad()

        # Generate fake data again (w/o .detach())
        fake_eeg = generator(z, fake_labels)
        # LSGAN generator loss
        g_loss = 0.5 * torch.mean((discriminator(fake_eeg, fake_labels) - 1) ** 2)

        g_loss.backward()
        optimizer_G.step()

        # Accumulate losses
        epoch_d_loss += d_loss.item()
        epoch_g_loss += g_loss.item()

    # Average the losses over the dataset
    epoch_d_loss /= len(dataloader)
    epoch_g_loss /= len(dataloader)

    # Step schedulers if you like
    scheduler_D.step()
    scheduler_G.step()

    print(f"Epoch [{epoch}/{epochs}] | D_loss: {epoch_d_loss:.4f} | G_loss: {epoch_g_loss:.4f}")

Epoch [1/100] | D_loss: 0.3875 | G_loss: 0.3507
Epoch [2/100] | D_loss: 0.2463 | G_loss: 0.1694
Epoch [3/100] | D_loss: 0.1973 | G_loss: 0.1697
Epoch [4/100] | D_loss: 0.2257 | G_loss: 0.1493
Epoch [5/100] | D_loss: 0.2498 | G_loss: 0.1348
Epoch [6/100] | D_loss: 0.2454 | G_loss: 0.1339
Epoch [7/100] | D_loss: 0.2420 | G_loss: 0.1330
Epoch [8/100] | D_loss: 0.2396 | G_loss: 0.1343
Epoch [9/100] | D_loss: 0.2322 | G_loss: 0.1376
Epoch [10/100] | D_loss: 0.2317 | G_loss: 0.1382
Epoch [11/100] | D_loss: 0.2339 | G_loss: 0.1373
Epoch [12/100] | D_loss: 0.2335 | G_loss: 0.1362
Epoch [13/100] | D_loss: 0.2411 | G_loss: 0.1322
Epoch [14/100] | D_loss: 0.2448 | G_loss: 0.1303
Epoch [15/100] | D_loss: 0.2496 | G_loss: 0.1285
Epoch [16/100] | D_loss: 0.2502 | G_loss: 0.1275
Epoch [17/100] | D_loss: 0.2497 | G_loss: 0.1272
Epoch [18/100] | D_loss: 0.2487 | G_loss: 0.1274
Epoch [19/100] | D_loss: 0.2484 | G_loss: 0.1279
Epoch [20/100] | D_loss: 0.2481 | G_loss: 0.1278
Epoch [21/100] | D_loss: 0.24

In [30]:
# Assuming generator, latent_dim, num_classes, and device are already defined
output_dir = "./output"
os.makedirs(output_dir, exist_ok=True)

# Calculate class distribution in training data
class_counts = np.bincount(train_labels.numpy())
total_samples = len(train_labels)
class_ratios = class_counts / total_samples

# Set the total number of synthetic samples to generate
total_synthetic_samples = 5608  # Adjust as needed
synthetic_samples_per_class = (class_ratios * total_synthetic_samples).astype(int)

# Generate synthetic EEG data while maintaining class ratio
generator.eval()
for class_idx, num_samples in enumerate(synthetic_samples_per_class):
    if num_samples == 0:  # Skip classes with no samples
        continue

    z = torch.randn(num_samples, latent_dim).to(device)
    class_label = torch.zeros(num_samples, num_classes).to(device)
    class_label[:, class_idx] = 1  # One-hot encode the class label

    with torch.no_grad():
        synthetic_eeg = generator(z, class_label)
        synthetic_eeg = synthetic_eeg[:, :, :500]  # Trim to (batch_size, 19, 500)
        # Save synthetic EEG data class-wise
        output_file = os.path.join(output_dir, f"synthetic_eeg_{class_idx}.npy")
        np.save(output_file, synthetic_eeg.cpu().numpy())
        print(f"Saved synthetic EEG data for class {class_idx} to {output_file}")

total_samples

synthetic_samples_per_class

Saved synthetic EEG data for class 0 to ./output/synthetic_eeg_0.npy
Saved synthetic EEG data for class 1 to ./output/synthetic_eeg_1.npy
Saved synthetic EEG data for class 2 to ./output/synthetic_eeg_2.npy
Saved synthetic EEG data for class 3 to ./output/synthetic_eeg_3.npy


array([2783, 2196,  545,   84])

In [31]:
synthetic_eeg.cpu().numpy().shape

(84, 19, 500)

In [32]:
# import numpy as np
# from scipy.linalg import sqrtm

# def calculate_fid(real_features, generated_features):
#     # Calculate mean and covariance of real features
#     mu_r = np.mean(real_features, axis=0)
#     sigma_r = np.cov(real_features, rowvar=False)
    
#     # Calculate mean and covariance of generated features
#     mu_g = np.mean(generated_features, axis=0)
#     sigma_g = np.cov(generated_features, rowvar=False)
    
#     # Calculate squared difference of means
#     diff = mu_r - mu_g
#     mean_diff = np.sum(diff**2)
    
#     # Compute square root of product of covariance matrices
#     covmean, _ = sqrtm(sigma_r @ sigma_g, disp=False)
    
#     # Handle numerical errors (non-positive semi-definite results)
#     if np.iscomplexobj(covmean):
#         covmean = covmean.real
    
#     # Calculate FID score
#     fid = mean_diff + np.trace(sigma_r + sigma_g - 2 * covmean)
#     return fid

# # Example usage
# # real_features: Extracted features from real EEG samples (numpy array)
# # generated_features: Extracted features from generated EEG samples (numpy array)

# fid_score = calculate_fid(real_features, generated_features)
# print(f"FID Score: {fid_score}")
