In [3]:
import sys
sys.path.insert(0, '..')

In [4]:
import torch


print(torch.__version__)
print("GPU Available:", torch.cuda.is_available())

if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = "cpu"

2.0.1
GPU Available: False


In [5]:
import torch.nn as nn
import numpy as np

import matplotlib.pyplot as plt
%matplotlib inline

In [6]:
batch_size = 64

torch.manual_seed(1)
np.random.seed(1)

In [7]:
import pickle

# To load the lists from the file:
with open(r'C:\Users\gxb18167\PycharmProjects\SIGIR_EEG_GAN\Development\EEG-To-Text-GAN\Data\EEG_Text_Pairs.pkl', 'rb') as file:
    EEG_word_level_embeddings = pickle.load(file)
    EEG_word_level_labels = pickle.load(file)

In [8]:
import nltk
nltk.download('punkt')
from gensim.models import Word2Vec
from nltk.tokenize import word_tokenize

tokenized_words = [word_tokenize(word) for word in EEG_word_level_labels]

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\gxb18167\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [28]:
print(len(tokenized_words))
print(tokenized_words[5])

63015
['founded']


In [15]:
model = Word2Vec(sentences=tokenized_words, vector_size=100, window=5, min_count=1, workers=4)

In [23]:
# Get the word embeddings
word_embeddings = {word: model.wv[word] for word in model.wv.index_to_key}

# Example usage
'''
for word, embedding in word_embeddings.items():
    print(f"Word: {word}, Embedding: {embedding}")
'''
print("Number of word embeddings:", len(word_embeddings))
word, embedding = list(word_embeddings.items())[6]
print(f"Word: {word}, Embedding: {embedding}")

Number of word embeddings: 2137
Word: of, Embedding: [ 8.13227147e-03 -4.45733406e-03 -1.06835726e-03  1.00636482e-03
 -1.91113955e-04  1.14817743e-03  6.11386076e-03 -2.02715401e-05
 -3.24596534e-03 -1.51072862e-03  5.89729892e-03  1.51410222e-03
 -7.24261976e-04  9.33324732e-03 -4.92128357e-03 -8.38409644e-04
  9.17541143e-03  6.74942741e-03  1.50285603e-03 -8.88256077e-03
  1.14874600e-03 -2.28825561e-03  9.36823711e-03  1.20992784e-03
  1.49006362e-03  2.40640994e-03 -1.83600665e-03 -4.99963388e-03
  2.32429506e-04 -2.01418041e-03  6.60093315e-03  8.94012302e-03
 -6.74754381e-04  2.97701475e-03 -6.10765442e-03  1.69932481e-03
 -6.92623248e-03 -8.69402662e-03 -5.90020278e-03 -8.95647518e-03
  7.27759488e-03 -5.77203138e-03  8.27635173e-03 -7.24354526e-03
  3.42167495e-03  9.67499893e-03 -7.78544787e-03 -9.94505733e-03
 -4.32914635e-03 -2.68313056e-03 -2.71289347e-04 -8.83155130e-03
 -8.61755759e-03  2.80021061e-03 -8.20640661e-03 -9.06933658e-03
 -2.34046578e-03 -8.63180775e-03 -7.0

In [8]:
float_tensor = torch.tensor(EEG_word_level_embeddings, dtype=torch.float)
float_tensor = float_tensor.unsqueeze(1)

  float_tensor = torch.tensor(EEG_word_level_embeddings, dtype=torch.float)


In [9]:
float_tensor.shape


torch.Size([63015, 1, 105, 8])

In [10]:
import torch
train_data = []
for i in range(len(float_tensor)):
   train_data.append([float_tensor[i], EEG_word_level_labels[i]])



trainloader = torch.utils.data.DataLoader(train_data, shuffle=True, batch_size=64)

In [11]:
#sanity check:
i1, l1 = next(iter(trainloader))
print(i1.shape, l1)

torch.Size([64, 1, 105, 8]) ('campaign', 'the', '16', 'returned', 'politics', 'left', '1591', 'singers.', 'cancer', 'backing', 'continued', 'Springwells', 'accession', 'not,', 'Pentagon', 'shirt', '(born', 'Lionel', 'for', 'stood', 'Whipple', 'Ford', 'church', 'Pennsylvania.', '1986', 'Finally,', 'C.H.', 'married', 'Friends.', 'publisher', 'a', 'President', 'de', 'Wagner,', 'free', 'and', 'interrupted', 'from', 'again', "Charles's", 'experiments', 'marriage', 'has', 'Republican', '14', 'Presidential', 'trial', 'his', 'whom', 'President', 'the', 'to', 'led', 'Her', 'trained', 'York', 'Captain', 'primary', 'pedophile,', 'his', 'degree', 'erroneously', 'cgs', 'an')


In [12]:
z_size = 100
image_size = (105, 8)
n_filters = 32

In [13]:
## Loss function and optimizers:
loss_fn = nn.BCELoss()

In [14]:
def create_noise(batch_size, z_size, mode_z):
    if mode_z == 'uniform':
        input_z = torch.rand(batch_size, z_size)*2 - 1
    elif mode_z == 'normal':
        input_z = torch.randn(batch_size, z_size)
    return input_z

mode_z = 'uniform'
fixed_z = create_noise(batch_size, z_size, mode_z).to(device)

def create_samples(g_model, input_z):
    g_output = g_model(input_z)
    images = torch.reshape(g_output, (batch_size, *image_size))
    return (images+1)/2.0

In [15]:
noise = create_noise(64, 100, "uniform")

In [16]:
noise.shape

torch.Size([64, 100])

In [17]:
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, latent_dim=100):
        super(Generator, self).__init__()
        self.latent_dim = latent_dim

        # Define your layers here
        self.fc = nn.Linear(latent_dim, 1 * 105 * 8)
        self.conv_transpose = nn.ConvTranspose2d(1, 1, kernel_size=3, stride=1, padding=1)

    def forward(self, z):
        # Reshape the input
        x = self.fc(z)
        x = x.view(-1, 1, 105, 8)

        # Apply convolutional transpose layer
        x = self.conv_transpose(x)

        return x

class DiscriminatorWGAN(nn.Module):
    def __init__(self, n_filters):
        super(DiscriminatorWGAN, self).__init__()
        self.network = nn.Sequential(
            nn.Conv2d(1, n_filters, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2),

            nn.Conv2d(n_filters, n_filters*2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(n_filters * 2),
            nn.LeakyReLU(0.2),

            nn.Conv2d(n_filters*2, n_filters*4, kernel_size=3, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(n_filters*4),
            nn.LeakyReLU(0.2),

            nn.Conv2d(n_filters*4, 1, kernel_size=3, stride=1, padding=1, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        output = self.network(input)
        return output.view(-1, 1).squeeze(0)



In [18]:
gen_model = Generator().to(device)
disc_model = DiscriminatorWGAN(n_filters).to(device)

g_optimizer = torch.optim.Adam(gen_model.parameters(), 0.0002)
d_optimizer = torch.optim.Adam(disc_model.parameters(), 0.0002)

In [19]:
from torch.autograd import grad as torch_grad


def gradient_penalty(real_data, generated_data):
    batch_size = real_data.size(0)

    # Calculate interpolation
    alpha = torch.rand(real_data.shape[0], 1, 1, 1, requires_grad=True, device=device)
    print("Gen:", generated_data.shape)
    interpolated = alpha * real_data + (1 - alpha) * generated_data

    # Calculate probability of interpolated examples
    proba_interpolated = disc_model(interpolated)

    # Calculate gradients of probabilities with respect to examples
    gradients = torch_grad(outputs=proba_interpolated, inputs=interpolated,
                           grad_outputs=torch.ones(proba_interpolated.size(), device=device),
                           create_graph=True, retain_graph=True)[0]

    gradients = gradients.view(batch_size, -1)
    gradients_norm = gradients.norm(2, dim=1)
    return lambda_gp * ((gradients_norm - 1)**2).mean()


In [20]:
## Train the discriminator
def d_train_wgan(x):
    disc_model.zero_grad()

    batch_size = x.size(0)
    x = x.to(device)

    # Calculate probabilities on real and generated data
    d_real = disc_model(x)
    input_z = create_noise(batch_size, z_size, mode_z).to(device)
    g_output = gen_model(input_z)
    d_generated = disc_model(g_output)
    d_loss = d_generated.mean() - d_real.mean() + gradient_penalty(x.data, g_output.data)

    d_loss.backward()
    d_optimizer.step()

    return d_loss.data.item()

In [21]:
## Train the generator
def g_train_wgan(x):
    gen_model.zero_grad()

    batch_size = x.size(0)
    input_z = create_noise(batch_size, z_size, mode_z).to(device)

    g_output = gen_model(input_z)

    d_generated = disc_model(g_output)
    g_loss = -d_generated.mean()

    # gradient backprop & optimize ONLY G's parameters
    g_loss.backward()
    g_optimizer.step()

    return g_loss.data.item()

In [23]:
epoch_samples_wgan = []
lambda_gp = 10.0
num_epochs = 100
torch.manual_seed(1)
critic_iterations = 5

for epoch in range(1, num_epochs+1):
    gen_model.train()
    d_losses, g_losses = [], []
    for i, (x, t) in enumerate(trainloader):
        print("T:", t)
        for _ in range(critic_iterations):
            d_loss = d_train_wgan(x)
        d_losses.append(d_loss)
        g_losses.append(g_train_wgan(x))

    print(f'Epoch {epoch:03d} | D Loss >>'
          f' {torch.FloatTensor(d_losses).mean():.4f}')
    gen_model.eval()
    epoch_samples_wgan.append(
        create_samples(gen_model, fixed_z).detach().cpu().numpy())

T: ('campaign', 'the', '16', 'returned', 'politics', 'left', '1591', 'singers.', 'cancer', 'backing', 'continued', 'Springwells', 'accession', 'not,', 'Pentagon', 'shirt', '(born', 'Lionel', 'for', 'stood', 'Whipple', 'Ford', 'church', 'Pennsylvania.', '1986', 'Finally,', 'C.H.', 'married', 'Friends.', 'publisher', 'a', 'President', 'de', 'Wagner,', 'free', 'and', 'interrupted', 'from', 'again', "Charles's", 'experiments', 'marriage', 'has', 'Republican', '14', 'Presidential', 'trial', 'his', 'whom', 'President', 'the', 'to', 'led', 'Her', 'trained', 'York', 'Captain', 'primary', 'pedophile,', 'his', 'degree', 'erroneously', 'cgs', 'an')
Gen: torch.Size([64, 1, 105, 8])
Gen: torch.Size([64, 1, 105, 8])
Gen: torch.Size([64, 1, 105, 8])
Gen: torch.Size([64, 1, 105, 8])
Gen: torch.Size([64, 1, 105, 8])
T: ('the', 'Development', 'the', 'particularly', 'Louisiana', 'first', 'American', 'in', 'Cat', 'House', 'Shortly', 'his', 'Whiteman.', "Duffy's", 'success,', 'great', 'home', '1935,', 'Rea

KeyboardInterrupt: 