In [23]:
import torch
import torch.nn as nn

# Define the Generator
class Generator(nn.Module):
    def __init__(self, text_embedding_size, noise_dim, eeg_channels, eeg_samples, *args, **kwargs):
        super(*args, **kwargs).__init__()
        self.fc = nn.Sequential(
            nn.Linear(text_embedding_size + noise_dim, 128),
            nn.ReLU(),
            nn.Linear(128, eeg_channels * eeg_samples)
        )

    def forward(self, text_embedding, noise):
        x = torch.cat((text_embedding, noise), dim=1)
        x = self.fc(x)
        return x.view(x.size(0), eeg_channels, eeg_samples)

# Define the Discriminator
class Discriminator(nn.Module):
    def __init__(self, eeg_channels, eeg_samples, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=(eeg_channels, 3)),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=(1, 3)),
            nn.LeakyReLU(0.2)
        )
        self.fc = nn.Sequential(
            nn.Linear(128 * (eeg_samples - 4), 1),
            nn.Sigmoid()
        )

    def forward(self, eeg_data):
        x = eeg_data.unsqueeze(1)  # Add a channel dimension
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

In [24]:
import torch
from torch.utils.data import Dataset, DataLoader

# Define a custom dataset class
class EEGTextDataset(Dataset):
    def __init__(self, eeg_data, text_embeddings):
        self.eeg_data = eeg_data
        self.text_embeddings = text_embeddings

    def __len__(self):
        return len(self.eeg_data)

    def __getitem__(self, idx):
        return self.eeg_data[idx], self.text_embeddings[idx]

# Assuming you have your EEG data and text embeddings as tensors
eeg_data = torch.randn(100, 24, 5)  # Example: 100 samples of EEG data
text_embeddings = torch.randn(100, 300)  # Example: 100 samples of text embeddings

# Create a dataset instance
dataset = EEGTextDataset(eeg_data, text_embeddings)

# Define a data loader
batch_size = 16
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)


In [25]:

# Initialize the GAN
text_embedding_size = 300  # Adjust based on your text embeddings
noise_dim = 100
eeg_channels = 24
eeg_samples = 5

generator = Generator(text_embedding_size, noise_dim, eeg_channels, eeg_samples)
discriminator = Discriminator(eeg_channels, eeg_samples)

# Loss function and optimizers
criterion = nn.BCELoss()
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002)

# Training loop (you'll need to have a dataset of real EEG data and text embeddings)
num_epochs = 100
for epoch in range(num_epochs):
    for real_eeg_data, text_embedding in dataloader:  # You need to define a dataloader for your dataset
        real_labels = torch.ones((real_eeg_data.size(0), 1))
        fake_labels = torch.zeros((real_eeg_data.size(0), 1))

        # Train the discriminator
        optimizer_D.zero_grad()
        real_outputs = discriminator(real_eeg_data)
        fake_eeg_data = generator(text_embedding, torch.randn(real_eeg_data.size(0), noise_dim))
        fake_outputs = discriminator(fake_eeg_data.detach())
        loss_real = criterion(real_outputs, real_labels)
        loss_fake = criterion(fake_outputs, fake_labels)
        loss_D = loss_real + loss_fake
        loss_D.backward()
        optimizer_D.step()

        # Train the generator
        optimizer_G.zero_grad()
        fake_outputs = discriminator(fake_eeg_data)
        loss_G = criterion(fake_outputs, real_labels)
        loss_G.backward()
        optimizer_G.step()

    print(f'Epoch [{epoch}/{num_epochs}] Loss_D: {loss_D.item():.4f} Loss_G: {loss_G.item():.4f}')

Epoch [0/100] Loss_D: 1.4011 Loss_G: 0.7160
Epoch [1/100] Loss_D: 1.3606 Loss_G: 0.6928
Epoch [2/100] Loss_D: 1.2660 Loss_G: 0.6579
Epoch [3/100] Loss_D: 1.2550 Loss_G: 0.6427
Epoch [4/100] Loss_D: 1.2582 Loss_G: 0.6460
Epoch [5/100] Loss_D: 1.2469 Loss_G: 0.6261
Epoch [6/100] Loss_D: 1.1740 Loss_G: 0.6349
Epoch [7/100] Loss_D: 1.2293 Loss_G: 0.6506
Epoch [8/100] Loss_D: 1.1482 Loss_G: 0.6851
Epoch [9/100] Loss_D: 1.1263 Loss_G: 0.6758
Epoch [10/100] Loss_D: 1.0750 Loss_G: 0.7445
Epoch [11/100] Loss_D: 1.0365 Loss_G: 0.7611
Epoch [12/100] Loss_D: 1.0769 Loss_G: 0.7677
Epoch [13/100] Loss_D: 0.9156 Loss_G: 0.7897
Epoch [14/100] Loss_D: 1.0535 Loss_G: 0.7400
Epoch [15/100] Loss_D: 1.0241 Loss_G: 0.7966
Epoch [16/100] Loss_D: 0.8636 Loss_G: 0.8825
Epoch [17/100] Loss_D: 1.0826 Loss_G: 0.6838
Epoch [18/100] Loss_D: 0.8491 Loss_G: 0.8497
Epoch [19/100] Loss_D: 1.0642 Loss_G: 0.8057
Epoch [20/100] Loss_D: 1.0452 Loss_G: 0.6950
Epoch [21/100] Loss_D: 1.1174 Loss_G: 0.7816
Epoch [22/100] Loss_

In [56]:
generator.eval()  # Set the generator in evaluation mode

# 1. Generate text embeddings for the desired text
text_embeddings = torch.randn(1, 300)
text_embeddings = torch.tensor(text_embeddings, dtype=torch.float32)

# 2. Generate random noise
noise = torch.randn(1, noise_dim)  # Use a batch size of 1 for a single example
noise = torch.tensor(noise, dtype=torch.float32)

# 3. Use the generator
synthetic_eeg_data = generator(text_embeddings, noise)

# 4. Reshape and format the synthetic data
synthetic_eeg_data = synthetic_eeg_data.view(24, 5)  # Reshape to match the real EEG data shape

# Now, synthetic_eeg_data contains the generated EEG data in the desired format

synthetic_eeg_data

  text_embeddings = torch.tensor(text_embeddings, dtype=torch.float32)
  noise = torch.tensor(noise, dtype=torch.float32)


tensor([[ 2.0080e-01, -2.7315e-01, -1.3623e+00,  1.3018e+00,  9.4885e-01],
        [-1.0215e+00, -3.0463e-01, -4.7836e-01,  2.1656e-01,  2.4706e-01],
        [ 7.9356e-01, -1.0140e-01,  2.4798e-01, -1.1678e+00, -1.1323e+00],
        [-6.4252e-01, -9.9927e-01, -9.1466e-01, -9.3116e-01, -1.2791e+00],
        [-1.1923e+00, -7.5123e-02,  5.1690e-01,  6.5738e-01,  8.2234e-01],
        [-6.1131e-01,  3.9240e-01, -1.1282e+00, -3.7267e-01, -8.0096e-01],
        [-2.4545e-01,  3.9074e-01, -4.7375e-01,  6.1411e-01,  1.3573e+00],
        [-4.4284e-01,  1.0560e+00,  6.4665e-01, -2.6015e-01, -3.8711e-01],
        [ 7.5872e-01, -6.0898e-01,  7.0539e-02, -1.3710e+00, -7.5046e-01],
        [-1.1724e+00,  1.1318e+00, -1.1710e+00,  3.1407e-01,  8.4863e-01],
        [ 1.5855e-01,  5.6216e-01,  6.3915e-01,  8.2472e-01, -4.4520e-01],
        [-5.3836e-01,  4.7662e-01, -8.2559e-01,  3.6270e-01, -1.4792e-01],
        [-5.4914e-01,  4.9456e-01, -5.9625e-01,  2.7490e-02,  2.9979e-01],
        [-6.0342e-01, -9.