In [1]:
import sys
import nltk
import torch.nn.functional as F
import torch
import torch.nn as nn
import numpy as np
from sklearn.preprocessing import MinMaxScaler
import torch
nltk.download('punkt')
from gensim.models import Word2Vec
from nltk.tokenize import word_tokenize
sys.path.insert(0, '..')
import pickle
from torch.autograd import grad as torch_grad

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\gxb18167\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [2]:
print(torch.__version__)
print("GPU Available:", torch.cuda.is_available())

if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = "cpu"

2.0.1
GPU Available: False


In [3]:
batch_size = 64
word_embedding_dim = 50
output_shape = (1, 105, 8)
torch.manual_seed(1)
np.random.seed(1)

In [4]:

# To load the lists from the file:
with open(r"C:\Users\gxb18167\PycharmProjects\EEG-To-Text\SIGIR_Development\EEG-GAN\EEG_Text_Pairs_Sentence.pkl", 'rb') as file:
    EEG_word_level_embeddings = pickle.load(file)
    EEG_word_level_labels = pickle.load(file)

In [5]:
EEG_word_level_embeddings[0] # 105 x 8 EEG signals relating to words

array([[0.19429664, 0.17923741, 0.36307213, 0.50747943, 0.62206709,
        0.51871073, 0.42868289, 0.06850591],
       [0.28140983, 0.57861644, 1.08024967, 0.36859021, 0.36467823,
        0.41722685, 1.14768839, 0.55932534],
       [0.4871715 , 0.79681689, 0.51210332, 0.40454227, 0.55180544,
        0.43953407, 0.75510144, 0.38622764],
       [0.43004686, 0.29673383, 0.3575944 , 0.36961138, 0.3506383 ,
        0.85081506, 1.26278675, 0.41491231],
       [0.3718234 , 0.45663175, 0.42179471, 1.03561676, 0.85461956,
        0.43418464, 0.46566522, 0.33808869],
       [0.49680769, 0.46951395, 0.50899029, 0.15184164, 1.6930244 ,
        0.76735342, 0.57302761, 0.94475245],
       [0.7187323 , 1.07659662, 0.72896671, 1.66536224, 1.39709103,
        1.28687668, 1.10441923, 1.43177247],
       [1.32517946, 1.66541123, 1.60669196, 1.8083328 , 1.39522731,
        1.99193358, 1.38413417, 1.54831302],
       [1.34435618, 1.0585711 , 0.49228552, 1.85182977, 1.85807848,
        1.41647136, 1.353929

In [6]:
def create_word_label_embeddings(Word_Labels_List):
    tokenized_words = []
    for i in range(len(Word_Labels_List)):
        tokenized_words.append([Word_Labels_List[i]])
    model = Word2Vec(sentences=tokenized_words, vector_size=word_embedding_dim, window=5, min_count=1, workers=4)
    word_embeddings = {word: model.wv[word] for word in model.wv.index_to_key}
    print("Number of word embeddings:", len(word_embeddings))
    #word, embedding = list(word_embeddings.items())[10]
    #print(f"Word: {word}, Embedding: {embedding}")

    Embedded_Word_labels = []
    for word in EEG_word_level_labels:
        Embedded_Word_labels.append(word_embeddings[word])

    return Embedded_Word_labels, word_embeddings


def create_word_label_embeddings_contextual(Word_Labels_List):
    tokenized_words = []
    for i in range(len(Word_Labels_List)):
        tokenized_words.append([Word_Labels_List[i]])
    model = Word2Vec(sentences=tokenized_words, vector_size=word_embedding_dim, window=5, min_count=1, workers=4)
    word_embeddings = {word: model.wv[word] for word in model.wv.index_to_key}
    Embedded_Word_labels = []

    for words in range (0, len(EEG_word_level_labels)):
        current_word = EEG_word_level_labels[words]
        if current_word != "SOS" and words != len(EEG_word_level_labels)-1:
            prior_word = EEG_word_level_labels[words-1]

            current_word = EEG_word_level_labels[words]

            next_word = EEG_word_level_labels[words+1]

            contextual_embedding = np.concatenate((word_embeddings[prior_word], word_embeddings[current_word], word_embeddings[next_word]), axis=-1)
            Embedded_Word_labels.append(contextual_embedding)
        elif words == len(EEG_word_level_labels)-1:
            prior_word = EEG_word_level_labels[words-1]
            next_word = "SOS"
            contextual_embedding = np.concatenate((word_embeddings[prior_word], word_embeddings[current_word], word_embeddings[next_word]), axis=-1)
            Embedded_Word_labels.append(contextual_embedding)

    return Embedded_Word_labels, word_embeddings


def create_dataloader(EEG_word_level_embeddings, Embedded_Word_labels):
    #EEG_word_level_embeddings_normalize = (EEG_word_level_embeddings - np.mean(EEG_word_level_embeddings)) / np.std(EEG_word_level_embeddings)
    #Assuming EEG_synthetic is the generated synthetic EEG data
    #EEG_synthetic_denormalized = (EEG_synthetic * np.max(np.abs(EEG_word_level_embeddings))) + np.mean(EEG_word_level_embeddings)

    EEG_word_level_embeddings_normalize = (EEG_word_level_embeddings - np.mean(EEG_word_level_embeddings)) / np.max(np.abs(EEG_word_level_embeddings))


    float_tensor = torch.tensor(EEG_word_level_embeddings_normalize, dtype=torch.float)
    float_tensor = float_tensor.unsqueeze(1)

    #print(EEG_word_level_embeddings_normalize)
    # Calculate mean and standard deviation
    print(torch.isnan(float_tensor).any())

    train_data = []
    for i in range(len(float_tensor)):
       train_data.append([float_tensor[i], Embedded_Word_labels[i]])
    trainloader = torch.utils.data.DataLoader(train_data, shuffle=True, batch_size=64)
    return trainloader

In [7]:
Embedded_Word_labels, word_embeddings = create_word_label_embeddings_contextual(EEG_word_level_labels)

In [8]:
len(Embedded_Word_labels[0])

150

In [26]:
Embedded_Word_labels[0]

array([-1.07245450e-03,  4.72862710e-04,  1.02066994e-02,  1.80185456e-02,
       -1.86058991e-02, -1.42336180e-02,  1.29177449e-02,  1.79459769e-02,
       -1.00308564e-02, -7.52674323e-03,  1.47610093e-02, -3.06694279e-03,
       -9.07322671e-03,  1.31081035e-02, -9.72032081e-03, -3.63203534e-03,
        5.75315952e-03,  1.98374758e-03, -1.65704302e-02, -1.88976359e-02,
        1.46235321e-02,  1.01405242e-02,  1.35153867e-02,  1.52573106e-03,
        1.27017805e-02, -6.81073172e-03, -1.89280277e-03,  1.15371468e-02,
       -1.50432754e-02, -7.87220709e-03, -1.50231645e-02, -1.86008448e-03,
        1.90762375e-02, -1.46383336e-02, -4.66753729e-03, -3.87548213e-03,
        1.61548741e-02, -1.18617918e-02,  9.03248801e-05, -9.50746797e-03,
       -1.92071013e-02,  1.00145862e-02, -1.75191704e-02, -8.78365058e-03,
       -7.01999670e-05, -5.92362892e-04, -1.53224804e-02,  1.92294866e-02,
        9.96411592e-03,  1.84662864e-02, -9.59467900e-04, -9.75699443e-03,
       -1.17798420e-02, -

In [9]:
trainloader = create_dataloader(EEG_word_level_embeddings, Embedded_Word_labels)

tensor(False)


In [28]:
#sanity check:
i1, l1 = next(iter(trainloader))
print(i1.dtype, l1.dtype)

torch.float32 torch.float32


In [11]:
z_size = 100
image_size = (105, 8)

n_filters = 32

In [12]:
## Loss function and optimizers:
loss_fn = nn.BCELoss()

In [13]:
def create_noise(batch_size, z_size, mode_z):
    if mode_z == 'uniform':
        input_z = torch.rand(batch_size, z_size)*2 - 1
    elif mode_z == 'normal':
        input_z = torch.randn(batch_size, z_size)
    return input_z

mode_z = 'uniform'
fixed_z = create_noise(batch_size, z_size, mode_z).to(device)

def create_samples(g_model, input_z, input_t):
    g_output = g_model(input_z, input_t)
    images = torch.reshape(g_output, (batch_size, *image_size))
    return (images+1)/2.0

In [14]:
noise = create_noise(64, 100, "uniform")

In [15]:
noise.shape

torch.Size([64, 100])

In [16]:
class Generator(nn.Module):
    def __init__(self, noise_dim, word_embedding_dim, output_shape):
        super(Generator, self).__init__()

        self.noise_dim = noise_dim
        self.word_embedding_dim = word_embedding_dim

        # Define the layers of your generator
        self.fc_noise = nn.Linear(noise_dim, 105*8)  # Increase the size for more complexity
        self.fc_word_embedding = nn.Linear(word_embedding_dim, 105*8)  # Increase the size for more complexity
        self.conv1 = nn.Conv2d(2, 128, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(128)
        self.relu = nn.LeakyReLU(0.2)

        self.conv2 = nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(64)

        self.conv3 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)
        self.tanh = nn.Tanh()

    def forward(self, noise, word_embedding):
        # Process noise
        noise = self.fc_noise(noise)
        noise = noise.view(noise.size(0), 1, 105,8)  # Adjust the size to match conv1

        # Process word embedding
        word_embedding = self.fc_word_embedding(word_embedding.to(device))
        word_embedding = word_embedding.view(word_embedding.size(0), 1, 105, 8)  # Adjust the size to match conv1

        # Concatenate noise and word embedding
        combined_input = torch.cat([noise, word_embedding], dim=1)

        # Upsample and generate the output
        z = self.conv1(combined_input)
        z = self.bn1(z)
        z = self.relu(z)

        z = self.conv2(z)
        z = self.bn2(z)
        z = self.relu(z)

        z = self.conv3(z)
        z = self.tanh(z)

        return z

class DiscriminatorWGAN(nn.Module):
    def __init__(self, n_filters):
        super(DiscriminatorWGAN, self).__init__()
        self.network = nn.Sequential(
            nn.Conv2d(1, n_filters, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2),

            nn.Conv2d(n_filters, n_filters*2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(n_filters * 2),
            nn.LeakyReLU(0.2),

            nn.Conv2d(n_filters*2, n_filters*4, kernel_size=3, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(n_filters*4),
            nn.LeakyReLU(0.2),

            nn.Flatten(),  # Flatten spatial dimensions

            # Fully connected layer to reduce to a single value per sample
            nn.Linear(n_filters*4 * (105 // 8) * (8 // 8), 1),
            nn.Sigmoid()
        )

    def forward(self, input):
        output = self.network(input)
        return output


In [17]:
gen_model = Generator(z_size, 150, output_shape).to(device)
disc_model = DiscriminatorWGAN(n_filters).to(device)

g_optimizer = torch.optim.Adam(gen_model.parameters(), 0.00002)
d_optimizer = torch.optim.Adam(disc_model.parameters(), 0.00002)

In [18]:

def gradient_penalty(real_data, generated_data):
    batch_size = real_data.size(0)

    # Calculate interpolation
    alpha = torch.rand(real_data.shape[0], 1, 1, 1, requires_grad=True, device=device)
    #print("Gen:", generated_data.shape)
    interpolated = alpha * real_data + (1 - alpha) * generated_data

    # Calculate probability of interpolated examples
    proba_interpolated = disc_model(interpolated)

    # Calculate gradients of probabilities with respect to examples
    gradients = torch_grad(outputs=proba_interpolated, inputs=interpolated,
                           grad_outputs=torch.ones(proba_interpolated.size(), device=device),
                           create_graph=True, retain_graph=True)[0]

    gradients = gradients.view(batch_size, -1)
    gradients_norm = gradients.norm(2, dim=1)
    return lambda_gp * ((gradients_norm - 1)**2).mean()


In [19]:
## Train the discriminator
def d_train_wgan(x, input_t):
    disc_model.zero_grad()

    batch_size = x.size(0)
    x = x.to(device)
    #print("X:", x.shape)
    # Calculate probabilities on real and generated data
    d_real = disc_model(x)
    input_z = create_noise(batch_size, z_size, mode_z).to(device)
    g_output = gen_model(input_z, input_t)
    #print("D Real:", d_real.shape)

    d_generated = disc_model(g_output)
    #print("G output:", g_output.shape)

    d_loss = d_generated.mean() - d_real.mean() + gradient_penalty(x.data, g_output.data)

    d_loss.backward()
    d_optimizer.step()

    return d_loss.data.item()

In [20]:
## Train the generator
def g_train_wgan(x, input_t):
    gen_model.zero_grad()

    batch_size = x.size(0)
    input_z = create_noise(batch_size, z_size, mode_z).to(device)

    g_output = gen_model(input_z, input_t)

    d_generated = disc_model(g_output)
    g_loss = -d_generated.mean()
    print("G Loss:", g_loss)

    # gradient backprop & optimize ONLY G's parameters
    g_loss.backward()
    g_optimizer.step()

    return g_loss.data.item()

In [21]:
epoch_samples_wgan = []
lambda_gp = 10.0
num_epochs = 100
torch.manual_seed(1)
critic_iterations = 5
save_interval = 5
checkpoint_path = 'Textual_WGAN_GP_checkpoint_epoch_{}.pth'
final_model_path = 'Textual_WGAN_GP_model_final.pth'

for epoch in range(1, num_epochs+1):
    gen_model.train()
    d_losses, g_losses = [], []
    for i, (x, t) in enumerate(trainloader):
        print("T:", x)
        for _ in range(critic_iterations):
            d_loss = d_train_wgan(x, t)
            print("D Loss:", d_loss)
        d_losses.append(d_loss)
        g_losses.append(g_train_wgan(x, t))

    print(f'Epoch {epoch:03d} | D Loss >>'
          f' {torch.FloatTensor(d_losses).mean():.4f}')
    print(f'Epoch {epoch:03d} | G Loss >>'
          f' {torch.FloatTensor(g_losses).mean():.4f}')

        # Save checkpoints at regular intervals
    if epoch % save_interval == 0:
        torch.save({
            'epoch': epoch,
            'gen_model_state_dict': gen_model.state_dict(),
            'optimizer_state_dict': g_optimizer.state_dict(),
            'd_losses': d_losses,
            'g_losses': g_losses,
        }, checkpoint_path.format(epoch))


    '''
    gen_model.eval()
    epoch_samples_wgan.append(
        create_samples(gen_model, fixed_z, t).detach().cpu().numpy())
    '''

T: tensor([[[[-8.4443e-03, -1.9176e-02, -2.1459e-02,  ..., -2.5626e-02,
           -2.0094e-02, -1.6487e-02],
          [-2.4960e-02, -2.0529e-02, -2.0863e-02,  ..., -2.0392e-02,
           -1.9405e-02, -1.8536e-02],
          [-8.1042e-03, -2.4881e-03, -1.4266e-02,  ..., -1.1138e-02,
           -1.3892e-02, -2.1972e-02],
          ...,
          [-1.4785e-02, -1.7782e-02, -1.5470e-02,  ..., -1.6162e-02,
           -2.4063e-02, -2.7194e-02],
          [-5.3006e-03, -8.4901e-03, -4.5342e-03,  ..., -1.6185e-02,
           -1.9304e-02, -1.2332e-02],
          [-1.0411e-02, -1.8819e-02, -2.4181e-02,  ..., -1.3150e-02,
           -1.6905e-02, -3.1616e-02]]],


        [[[-1.3685e-02, -9.7599e-03, -7.0305e-03,  ..., -1.7932e-02,
           -1.9210e-02, -1.5928e-02],
          [-1.5158e-02, -8.4017e-03, -1.0504e-02,  ..., -1.4120e-02,
           -8.1077e-03, -6.9838e-03],
          [-1.6963e-02, -9.2921e-03, -7.0514e-03,  ..., -4.5759e-03,
           -7.6884e-03, -1.2391e-02],
          ...,


KeyboardInterrupt: 

In [None]:
# Save the final model after training is complete
torch.save({
    'epoch': num_epochs,
    'gen_model_state_dict': gen_model.state_dict(),
    'optimizer_state_dict': g_optimizer.state_dict(),
    'd_losses': d_losses,
    'g_losses': g_losses,
}, final_model_path)