In [1]:
import torch
import torch.nn as nn
import numpy as np

# Define the Generator
class Generator(nn.Module):
    def __init__(self, text_embedding_size, noise_dim, eeg_channels, eeg_samples, *args, **kwargs):
        super(*args, **kwargs).__init__()
        self.fc = nn.Sequential(
            nn.Linear(text_embedding_size + noise_dim, 128),
            nn.ReLU(),
            nn.Linear(128, eeg_channels * eeg_samples)
        )

    def forward(self, text_embedding, noise):
        x = torch.cat((text_embedding, noise), dim=1)
        x = self.fc(x)
        return x.view(x.size(0), eeg_channels, eeg_samples)

# Define the Discriminator
class Discriminator(nn.Module):
    def __init__(self, eeg_channels, eeg_samples, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=(eeg_channels, 3)),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=(1, 3)),
            nn.LeakyReLU(0.2)
        )
        self.fc = nn.Sequential(
            nn.Linear(128 * (eeg_samples - 4), 1),
            nn.Sigmoid()
        )

    def forward(self, eeg_data):
        x = eeg_data.unsqueeze(1)  # Add a channel dimension
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

In [2]:
batch_size = 64

torch.manual_seed(1)
np.random.seed(1)
import pickle

# To load the lists from the file:
with open(r'C:\Users\gxb18167\PycharmProjects\EEG-To-Text\SIGIR_Development\EEG-GAN\EEG_Text_Pairs.pkl', 'rb') as file:
    EEG_word_level_embeddings = pickle.load(file)
    EEG_word_level_labels = pickle.load(file)
float_tensor = torch.tensor(EEG_word_level_embeddings, dtype=torch.float)

import torch
train_data = []
for i in range(len(float_tensor)):
   train_data.append([float_tensor[i], EEG_word_level_labels[i]])



trainloader = torch.utils.data.DataLoader(train_data, shuffle=True, batch_size=64)

  float_tensor = torch.tensor(EEG_word_level_embeddings, dtype=torch.float)


In [3]:
import torch
from torch.utils.data import Dataset, DataLoader

# Define a custom dataset class
class EEGTextDataset(Dataset):
    def __init__(self, eeg_data, text_embeddings):
        self.eeg_data = eeg_data
        self.text_embeddings = text_embeddings

    def __len__(self):
        return len(self.eeg_data)

    def __getitem__(self, idx):
        return self.eeg_data[idx], self.text_embeddings[idx]

# Assuming you have your EEG data and text embeddings as tensors
eeg_data = torch.randn(100, 105, 8)  # Example: 100 samples of EEG data
text_embeddings = torch.randn(len(EEG_word_level_embeddings), 300)  # Example: 100 samples of text embeddings

# Create a dataset instance
dataset = EEGTextDataset(float_tensor, text_embeddings)

# Define a data loader
batch_size = 16
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)


In [5]:
for i, (x, l) in enumerate(dataloader):
    print(x.shape)
    print(l.shape)
    if i == 9*16:
        break

torch.Size([16, 105, 8])
torch.Size([16, 300])
torch.Size([16, 105, 8])
torch.Size([16, 300])
torch.Size([16, 105, 8])
torch.Size([16, 300])
torch.Size([16, 105, 8])
torch.Size([16, 300])
torch.Size([16, 105, 8])
torch.Size([16, 300])
torch.Size([16, 105, 8])
torch.Size([16, 300])
torch.Size([16, 105, 8])
torch.Size([16, 300])
torch.Size([16, 105, 8])
torch.Size([16, 300])
torch.Size([16, 105, 8])
torch.Size([16, 300])
torch.Size([16, 105, 8])
torch.Size([16, 300])
torch.Size([16, 105, 8])
torch.Size([16, 300])
torch.Size([16, 105, 8])
torch.Size([16, 300])
torch.Size([16, 105, 8])
torch.Size([16, 300])
torch.Size([16, 105, 8])
torch.Size([16, 300])
torch.Size([16, 105, 8])
torch.Size([16, 300])
torch.Size([16, 105, 8])
torch.Size([16, 300])
torch.Size([16, 105, 8])
torch.Size([16, 300])
torch.Size([16, 105, 8])
torch.Size([16, 300])
torch.Size([16, 105, 8])
torch.Size([16, 300])
torch.Size([16, 105, 8])
torch.Size([16, 300])
torch.Size([16, 105, 8])
torch.Size([16, 300])
torch.Size([1

In [6]:
# Initialize the GAN
text_embedding_size = 300  # Adjust based on your text embeddings
noise_dim = 100
eeg_channels = 105
eeg_samples = 8
generator = Generator(text_embedding_size, noise_dim, eeg_channels, eeg_samples)
discriminator = Discriminator(eeg_channels, eeg_samples)

In [7]:
dataloader

<torch.utils.data.dataloader.DataLoader at 0x2890db1b610>

In [8]:
for real_eeg_data, text_embedding in dataloader:
    print(real_eeg_data.shape)
    output = discriminator(real_eeg_data)
    print(output[0].shape)

torch.Size([16, 105, 8])
torch.Size([1])
torch.Size([16, 105, 8])
torch.Size([1])
torch.Size([16, 105, 8])
torch.Size([1])
torch.Size([16, 105, 8])
torch.Size([1])
torch.Size([16, 105, 8])
torch.Size([1])
torch.Size([16, 105, 8])
torch.Size([1])
torch.Size([16, 105, 8])
torch.Size([1])
torch.Size([16, 105, 8])
torch.Size([1])
torch.Size([16, 105, 8])
torch.Size([1])
torch.Size([16, 105, 8])
torch.Size([1])
torch.Size([16, 105, 8])
torch.Size([1])
torch.Size([16, 105, 8])
torch.Size([1])
torch.Size([16, 105, 8])
torch.Size([1])
torch.Size([16, 105, 8])
torch.Size([1])
torch.Size([16, 105, 8])
torch.Size([1])
torch.Size([16, 105, 8])
torch.Size([1])
torch.Size([16, 105, 8])
torch.Size([1])
torch.Size([16, 105, 8])
torch.Size([1])
torch.Size([16, 105, 8])
torch.Size([1])
torch.Size([16, 105, 8])
torch.Size([1])
torch.Size([16, 105, 8])
torch.Size([1])
torch.Size([16, 105, 8])
torch.Size([1])
torch.Size([16, 105, 8])
torch.Size([1])
torch.Size([16, 105, 8])
torch.Size([1])
torch.Size([16, 

In [None]:





# Loss function and optimizers
criterion = nn.BCELoss()
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002)

# Training loop (you'll need to have a dataset of real EEG data and text embeddings)
num_epochs = 100
for epoch in range(num_epochs):
    for real_eeg_data, text_embedding in dataloader:  # You need to define a dataloader for your dataset
        real_labels = torch.ones((real_eeg_data.size(0), 1))
        fake_labels = torch.zeros((real_eeg_data.size(0), 1))

        # Train the discriminator
        optimizer_D.zero_grad()
        real_outputs = discriminator(real_eeg_data)
        fake_eeg_data = generator(text_embedding, torch.randn(real_eeg_data.size(0), noise_dim))
        fake_outputs = discriminator(fake_eeg_data.detach())


        #print(real_outputs)
        loss_real = criterion(real_outputs, real_labels)
        loss_fake = criterion(fake_outputs, fake_labels)
        loss_D = loss_real + loss_fake
        loss_D.backward()
        optimizer_D.step()

        # Train the generator
        optimizer_G.zero_grad()
        fake_outputs = discriminator(fake_eeg_data)
        loss_G = criterion(fake_outputs, real_labels)
        loss_G.backward()
        optimizer_G.step()

    print(f'Epoch [{epoch}/{num_epochs}] Loss_D: {loss_D.item():.4f} Loss_G: {loss_G.item():.4f}')

In [None]:
generator.eval()  # Set the generator in evaluation mode

# 1. Generate text embeddings for the desired text
text_embeddings = torch.randn(1, 300)
text_embeddings = torch.tensor(text_embeddings, dtype=torch.float32)

# 2. Generate random noise
noise = torch.randn(1, noise_dim)  # Use a batch size of 1 for a single example
noise = torch.tensor(noise, dtype=torch.float32)

# 3. Use the generator
synthetic_eeg_data = generator(text_embeddings, noise)

# 4. Reshape and format the synthetic data
synthetic_eeg_data = synthetic_eeg_data.view(105, 8)  # Reshape to match the real EEG data shape

# Now, synthetic_eeg_data contains the generated EEG data in the desired format

synthetic_eeg_data