In [18]:
import torch
import torch.nn as nn
import numpy as np

# Define the Generator
class Generator(nn.Module):
    def __init__(self, text_embedding_size, noise_dim, eeg_channels, eeg_samples, *args, **kwargs):
        super(*args, **kwargs).__init__()
        self.fc = nn.Sequential(
            nn.Linear(text_embedding_size + noise_dim, 128),
            nn.ReLU(),
            nn.Linear(128, eeg_channels * eeg_samples)
        )

    def forward(self, text_embedding, noise):
        x = torch.cat((text_embedding, noise), dim=1)
        x = self.fc(x)
        return x.view(x.size(0), eeg_channels, eeg_samples)

# Define the Discriminator
class Discriminator(nn.Module):
    def __init__(self, eeg_channels, eeg_samples, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=(eeg_channels, 3)),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=(1, 3)),
            nn.LeakyReLU(0.2)
        )
        self.fc = nn.Sequential(
            nn.Linear(128 * (eeg_samples - 4), 1),
            nn.Sigmoid()
        )

    def forward(self, eeg_data):
        x = eeg_data.unsqueeze(1)  # Add a channel dimension
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

In [None]:
batch_size = 64

torch.manual_seed(1)
np.random.seed(1)
import pickle

# To load the lists from the file:
with open(r'C:\Users\gxb18167\PycharmProjects\SIGIR_EEG_GAN\Development\EEG-To-Text-GAN\Data\EEG_Text_Pairs.pkl', 'rb') as file:
    EEG_word_level_embeddings = pickle.load(file)
    EEG_word_level_labels = pickle.load(file)
float_tensor = torch.tensor(EEG_word_level_embeddings, dtype=torch.float)
import torch
train_data = []
for i in range(len(float_tensor)):
   train_data.append([float_tensor[i], EEG_word_level_labels[i]])



trainloader = torch.utils.data.DataLoader(train_data, shuffle=True, batch_size=64)

In [65]:
import torch
from torch.utils.data import Dataset, DataLoader

# Define a custom dataset class
class EEGTextDataset(Dataset):
    def __init__(self, eeg_data, text_embeddings):
        self.eeg_data = eeg_data
        self.text_embeddings = text_embeddings

    def __len__(self):
        return len(self.eeg_data)

    def __getitem__(self, idx):
        return self.eeg_data[idx], self.text_embeddings[idx]

# Assuming you have your EEG data and text embeddings as tensors
eeg_data = torch.randn(100, 105, 8)  # Example: 100 samples of EEG data
text_embeddings = torch.randn(len(EEG_word_level_embeddings), 300)  # Example: 100 samples of text embeddings

# Create a dataset instance
dataset = EEGTextDataset(float_tensor, text_embeddings)

# Define a data loader
batch_size = 16
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)


In [66]:
for i, (x, l) in enumerate(dataloader):
    print(x.shape)
    print(l)
    if i == 9*16:
        break

torch.Size([16, 105, 8])
tensor([[ 0.4557,  0.1721,  0.3676,  ...,  0.1545,  0.2941, -0.1409],
        [-0.6960, -0.6299, -1.0204,  ...,  0.6492, -1.5500,  0.5709],
        [ 0.0835,  0.4916, -0.1390,  ...,  1.4403,  0.7457,  1.3498],
        ...,
        [ 0.9666,  0.4812,  0.4878,  ...,  0.3259, -0.5959,  0.5550],
        [-0.7638, -1.1014, -0.0072,  ..., -0.8527, -0.4685,  0.7366],
        [-0.8116,  0.6763,  0.1723,  ..., -0.0917,  1.5721, -1.6708]])
torch.Size([16, 105, 8])
tensor([[ 1.1716, -2.0784, -1.2101,  ..., -0.0373,  1.0883,  0.4488],
        [ 1.0599,  0.8664,  0.8710,  ...,  1.3458,  0.4059, -0.1646],
        [ 3.0783,  0.2411,  1.0241,  ..., -0.1772, -1.6040,  1.2422],
        ...,
        [-0.7341, -1.4959,  0.8287,  ..., -2.0429,  1.4246, -2.4196],
        [ 1.1781, -0.1128, -0.0787,  ..., -0.3380, -2.0790,  1.4038],
        [ 1.1242, -1.0328, -0.6829,  ...,  0.2600,  0.4575,  0.8353]])
torch.Size([16, 105, 8])
tensor([[ 1.1471, -1.2647,  0.1803,  ...,  0.5991, -1.311

In [67]:
# Initialize the GAN
text_embedding_size = 300  # Adjust based on your text embeddings
noise_dim = 100
eeg_channels = 105
eeg_samples = 8
generator = Generator(text_embedding_size, noise_dim, eeg_channels, eeg_samples)
discriminator = Discriminator(eeg_channels, eeg_samples)

In [68]:
dataloader

<torch.utils.data.dataloader.DataLoader at 0x2898ac181c0>

In [69]:
for real_eeg_data, text_embedding in dataloader:
    print(real_eeg_data.shape)
    output = discriminator(real_eeg_data)
    print(output[0].shape)

torch.Size([16, 105, 8])
torch.Size([1])
torch.Size([16, 105, 8])
torch.Size([1])
torch.Size([16, 105, 8])
torch.Size([1])
torch.Size([16, 105, 8])
torch.Size([1])
torch.Size([16, 105, 8])
torch.Size([1])
torch.Size([16, 105, 8])
torch.Size([1])
torch.Size([16, 105, 8])
torch.Size([1])
torch.Size([16, 105, 8])
torch.Size([1])
torch.Size([16, 105, 8])
torch.Size([1])
torch.Size([16, 105, 8])
torch.Size([1])
torch.Size([16, 105, 8])
torch.Size([1])
torch.Size([16, 105, 8])
torch.Size([1])
torch.Size([16, 105, 8])
torch.Size([1])
torch.Size([16, 105, 8])
torch.Size([1])
torch.Size([16, 105, 8])
torch.Size([1])
torch.Size([16, 105, 8])
torch.Size([1])
torch.Size([16, 105, 8])
torch.Size([1])
torch.Size([16, 105, 8])
torch.Size([1])
torch.Size([16, 105, 8])
torch.Size([1])
torch.Size([16, 105, 8])
torch.Size([1])
torch.Size([16, 105, 8])
torch.Size([1])
torch.Size([16, 105, 8])
torch.Size([1])
torch.Size([16, 105, 8])
torch.Size([1])
torch.Size([16, 105, 8])
torch.Size([1])
torch.Size([16, 

In [70]:





# Loss function and optimizers
criterion = nn.BCELoss()
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002)

# Training loop (you'll need to have a dataset of real EEG data and text embeddings)
num_epochs = 100
for epoch in range(num_epochs):
    for real_eeg_data, text_embedding in dataloader:  # You need to define a dataloader for your dataset
        real_labels = torch.ones((real_eeg_data.size(0), 1))
        fake_labels = torch.zeros((real_eeg_data.size(0), 1))

        # Train the discriminator
        optimizer_D.zero_grad()
        real_outputs = discriminator(real_eeg_data)
        fake_eeg_data = generator(text_embedding, torch.randn(real_eeg_data.size(0), noise_dim))
        fake_outputs = discriminator(fake_eeg_data.detach())


        print(real_outputs)
        loss_real = criterion(real_outputs, real_labels)
        loss_fake = criterion(fake_outputs, fake_labels)
        loss_D = loss_real + loss_fake
        loss_D.backward()
        optimizer_D.step()

        # Train the generator
        optimizer_G.zero_grad()
        fake_outputs = discriminator(fake_eeg_data)
        loss_G = criterion(fake_outputs, real_labels)
        loss_G.backward()
        optimizer_G.step()

    print(f'Epoch [{epoch}/{num_epochs}] Loss_D: {loss_D.item():.4f} Loss_G: {loss_G.item():.4f}')

tensor([[0.4886],
        [0.5313],
        [0.5051],
        [0.4899],
        [0.5123],
        [0.4764],
        [0.5001],
        [0.4965],
        [0.5076],
        [0.4764],
        [0.4822],
        [0.4782],
        [0.5211],
        [0.5125],
        [0.5168],
        [0.4882]], grad_fn=<SigmoidBackward0>)
tensor([[0.5528],
        [0.5875],
        [0.5823],
        [0.5466],
        [0.6020],
        [0.5316],
        [0.5627],
        [0.5329],
        [0.5144],
        [0.5919],
        [0.5409],
        [0.5532],
        [0.5699],
        [0.6002],
        [0.5562],
        [0.5277]], grad_fn=<SigmoidBackward0>)
tensor([[0.6568],
        [0.5565],
        [0.5338],
        [0.6314],
        [0.6316],
        [0.5759],
        [0.6063],
        [0.5657],
        [0.5774],
        [0.6536],
        [0.5789],
        [0.5853],
        [0.5900],
        [0.5450],
        [0.5423],
        [0.5307]], grad_fn=<SigmoidBackward0>)
tensor([[0.6346],
        [0.6344],
        [0.62

RuntimeError: all elements of input should be between 0 and 1

In [6]:
generator.eval()  # Set the generator in evaluation mode

# 1. Generate text embeddings for the desired text
text_embeddings = torch.randn(1, 300)
text_embeddings = torch.tensor(text_embeddings, dtype=torch.float32)

# 2. Generate random noise
noise = torch.randn(1, noise_dim)  # Use a batch size of 1 for a single example
noise = torch.tensor(noise, dtype=torch.float32)

# 3. Use the generator
synthetic_eeg_data = generator(text_embeddings, noise)

# 4. Reshape and format the synthetic data
synthetic_eeg_data = synthetic_eeg_data.view(105, 8)  # Reshape to match the real EEG data shape

# Now, synthetic_eeg_data contains the generated EEG data in the desired format

synthetic_eeg_data

  text_embeddings = torch.tensor(text_embeddings, dtype=torch.float32)
  noise = torch.tensor(noise, dtype=torch.float32)


tensor([[ 2.6769e-01, -7.8161e-01, -1.7181e-02, -2.2148e-01, -2.2762e-02,
         -4.8832e-01, -7.6291e-02, -7.9746e-01],
        [-8.6639e-01, -1.8707e-01, -2.7908e-01,  4.2647e-01, -2.3034e-01,
          4.7006e-01, -2.3886e-01,  1.0481e-01],
        [ 3.7513e-02,  6.3084e-01,  5.9623e-02,  7.0541e-02,  3.7241e-01,
         -1.5710e-01,  1.0378e-01,  6.7872e-01],
        [-1.1555e-01,  2.8136e-01, -2.8982e-01,  4.5657e-01,  4.2915e-01,
          1.1774e-01, -3.9722e-01,  4.2859e-01],
        [ 4.4346e-01, -5.6903e-01, -1.6506e-01,  1.9398e-01,  6.8459e-01,
         -1.0816e-01,  4.3000e-01, -1.1264e+00],
        [ 6.5096e-01, -2.7118e-02,  4.8729e-01, -2.3146e-01, -4.6962e-02,
          2.0757e-01, -2.7969e-01, -3.2577e-01],
        [ 6.3891e-01, -7.5560e-03, -3.0644e-02, -1.6031e-01, -5.2701e-01,
         -9.1640e-01, -2.4542e-01, -3.1885e-01],
        [ 3.2299e-01,  5.6358e-01, -3.4815e-01,  2.0938e-01,  6.7931e-01,
          2.2821e-01,  7.9064e-02,  1.0442e+00],
        [ 1.4040