In [1]:
import sys
sys.path.append("/ocean/projects/cis240129p/soederha/silent_speech")

import pandas as pd
import torch
import matplotlib.pyplot as plt
from torch.utils.data import Subset, ConcatDataset
from pathlib import Path
import numpy as np
from lib_alice import BrennanDataset
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm.notebook import tqdm

# Root directory of the shared EEG dataset; sub-paths are derived from it below.
base_dir = Path("/ocean/projects/cis240129p/shared/data/eeg_alice")

import torch
# Availability probe only: the result is not stored and, not being the last
# expression of the cell, is not displayed either.
torch.cuda.is_available()  

# Subject identifiers passed to create_datasets() further down.
subjects_used = ["S01", "S03", "S04", "S08", "S11", "S12", "S13", "S16", "S17", "S18", "S19", "S22", "S26", "S36", "S37", "S40", "S41", "S42", "S44", "S48"]
# Ask PyTorch to release cached GPU memory it is not currently using.
torch.cuda.empty_cache()

In [5]:
def create_datasets(subjects, base_dir, augmented_eeg_dict = None):
    """
    Build one ``Subset`` per subject, each covering that subject's whole dataset.

    Args:
        subjects: iterable of subject identifiers, forwarded as ``idx``.
        base_dir: path-like root; ``base_dir / "phonemes"`` and
            ``base_dir / "phoneme_dict.txt"`` are derived from it.
        augmented_eeg_dict: forwarded unchanged to ``BrennanDataset``.

    Returns:
        A list with one ``Subset`` per entry of ``subjects``, in the same order.
    """
    per_subject = []
    for subject_id in subjects:
        subject_data = BrennanDataset(
            root_dir=base_dir,
            phoneme_dir=base_dir / "phonemes",
            idx=subject_id,
            phoneme_dict_path=base_dir / "phoneme_dict.txt",
            augmented_eeg_dict=augmented_eeg_dict,
        )
        # Train fraction is 1: every sample is kept, no held-out split is made.
        keep = int(len(subject_data) * 1)
        per_subject.append(Subset(subject_data, list(range(keep))))
    return per_subject


#train_ds, test_ds = create_datasets(subjects_used, base_dir)
# One Subset per subject, then a single dataset concatenating all subjects.
train_ds = create_datasets(subjects_used, base_dir)
train_dataset = ConcatDataset(train_ds)
#test_dataset = ConcatDataset(test_ds)
# The triple-quoted string below is a disabled print: `test_dataset` is not
# defined in this cell, so the statement is kept only as an inert literal.
'''print(
    f"Train dataset length: {len(train_dataset)}, Test dataset length: {len(test_dataset)}"
)'''

In [6]:
# Plot the raw EEG against the extracted EEG features for one sample (index 2).

# The dataset is indexed twice, once per field.
eeg_raw = train_dataset[2]["eeg_raw"]
eeg_feat = train_dataset[2]["eeg_feats"]

# Line plot of the raw signal, y-axis clipped to [-3, 3].
plt.figure(figsize=(10, 5))
plt.ylim(-3, 3)
plt.plot(eeg_raw)
plt.show()

# Line plot of the feature representation (default figure size and limits).
plt.plot(eeg_feat)
plt.show()

import seaborn as sns

# Heatmap views of the same two arrays; the raw signal is transposed first.
# presumably rows are time steps and columns are channels — TODO confirm
sns.heatmap(eeg_feat)
plt.show()

sns.heatmap(eeg_raw.T)
plt.show()

In [7]:
def collate_fn(batch):
    """
    Merge a list of sample dicts into a single dict of batched values.

    For every key of the first sample:
      * numpy arrays are converted to tensors;
      * tensors with at least one dimension are zero-padded along their first
        dimension and stacked batch-first;
      * zero-dimensional tensors are stacked into a 1-D tensor;
      * anything else (e.g. strings) is returned as a plain Python list.
    """
    collated = {}
    for field in batch[0].keys():
        values = [sample[field] for sample in batch]
        head = values[0]

        if not isinstance(head, (np.ndarray, torch.Tensor)):
            # Non-array payloads are passed through untouched.
            collated[field] = values
            continue

        if isinstance(head, np.ndarray):
            values = [torch.tensor(value) for value in values]

        if values[0].dim() > 0:
            # Variable-length sequences: pad with zeros up to the longest one.
            collated[field] = torch.nn.utils.rnn.pad_sequence(
                values, batch_first=True
            )
        else:
            collated[field] = torch.stack(values)

    return collated


# Training loader: shuffled mini-batches of two samples, merged by collate_fn.
# drop_last=True discards a trailing incomplete batch. The variable name
# (including its spelling) is referenced by later cells.
train_dataloder = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=2,
    num_workers=1,
    shuffle=True,
    collate_fn=collate_fn,
    drop_last=True,
)

In [8]:
# Sanity check: show the type, keys, raw-EEG shape and labels of the first two batches.
i = 0
for i, batch in enumerate(train_dataloder, start=1):
    print(type(batch))
    print(batch.keys())
    print(batch['eeg_raw'].shape)
    print(batch['label'])
    if i > 1:
        break

In [9]:
class EEGAutoencoder(nn.Module):
    """
    Fully connected variational autoencoder over flattened EEG windows.

    Despite the class name, the model is variational: ``encode`` returns the
    mean and log-variance of a diagonal Gaussian, ``forward`` samples from it
    with the reparameterization trick and decodes the sample.

    Args:
        sequence_lenth: number of time steps per window (parameter spelling is
            kept because callers pass it by keyword).
        feature_dim: number of features per time step.
        latent_dim: size of the latent vector.

    Input to ``encode``/``forward`` must flatten to
    ``sequence_lenth * feature_dim`` values per sample; ``decode`` returns a
    tensor of shape ``(batch, sequence_lenth, feature_dim)``.
    """

    def __init__(self, sequence_lenth=520, feature_dim=60,latent_dim=256):
        super(EEGAutoencoder, self).__init__()

        self.input_dim = sequence_lenth * feature_dim
        self.sequence_length = sequence_lenth
        self.feature_dim = feature_dim
        self.latent_dim = latent_dim

        hidden_dim = 256  # width of the encoder output feeding both heads

        # Encoder trunk: flattened window -> hidden representation.
        self.encoder = nn.Sequential(
            nn.Linear(self.input_dim, 512),
            nn.ReLU(),
            nn.Linear(512, hidden_dim),
        )

        # Variational heads. They were referenced by encode() but never
        # created, so every forward pass raised AttributeError.
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_var = nn.Linear(hidden_dim, latent_dim)

        # Decoder: latent vector -> flattened window.
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, self.input_dim),
        )

    def encode(self, x):
        """Return ``(mu, log_var)``, each of shape ``(batch, latent_dim)``."""
        # flatten() also accepts non-contiguous input, unlike view().
        x = x.flatten(start_dim=1)
        x = self.encoder(x)
        mu = self.fc_mu(x)
        log_var = self.fc_var(x)
        return mu, log_var

    def reparameterize(self, mu, log_var):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent vectors ``(batch, latent_dim)`` back to EEG windows."""
        x = self.decoder(z)
        return x.view(-1, self.sequence_length, self.feature_dim)

    def forward(self, x):
        """Return ``(reconstruction, mu, log_var)`` for the input batch."""
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var

def train_eeg_vae(model, train_loader, optimizer, device, epoch, data_type='eeg_raw'):
    """
    Train ``model`` for one epoch on ``train_loader``.

    Args:
        model: module whose forward returns ``(reconstruction, mu, log_var)``.
        train_loader: iterable of batch dicts; must expose ``.dataset`` with a
            length (used to normalize the returned loss).
        optimizer: optimizer over ``model``'s parameters.
        device: device the input batch is moved to.
        epoch: epoch number, used only for the progress-bar label.
        data_type: key of the batch dict holding the model input.

    Returns:
        Sum of the per-batch losses divided by ``len(train_loader.dataset)``.
        NOTE(review): each batch loss is already divided by the batch size, so
        this value is smaller than the running mean shown in the progress bar;
        the normalization is kept as-is for comparability with earlier runs.
    """
    model.train()
    beta = 0.1  # weight of the KL term relative to the reconstruction term

    total_loss = 0.0
    total_recon_loss = 0.0
    total_kld_loss = 0.0

    # Iterate the loader that was passed in. Previously the notebook-global
    # `train_dataloder` was used here and the argument was silently ignored.
    batch_pbar = tqdm(train_loader, desc=f'Epoch {epoch}', leave=False)

    for batch_idx, batch in enumerate(batch_pbar):
        inputs = batch[data_type].float().to(device)
        batch_size = inputs.size(0)
        optimizer.zero_grad()

        recon_batch, mu, log_var = model(inputs)

        reconstruction_loss = F.mse_loss(recon_batch, inputs, reduction='mean') / batch_size
        kld_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) / batch_size
        loss = reconstruction_loss + beta * kld_loss

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_recon_loss += reconstruction_loss.item()
        total_kld_loss += kld_loss.item()

        # Running means over the batches seen so far in this epoch.
        batches_seen = batch_idx + 1
        batch_pbar.set_postfix({
                'loss': f'{total_loss / batches_seen:.4f}',
                'recon': f'{total_recon_loss / batches_seen:.4f}',
                'kld': f'{total_kld_loss / batches_seen:.4f}'
            })

    avg_loss = total_loss / len(train_loader.dataset)
    return avg_loss

In [12]:
# Housekeeping before training: release cached GPU memory, then run a Python
# garbage-collection pass.
torch.cuda.empty_cache()
import gc
gc.collect()

In [None]:
import torch
from torch import optim
from torchsummaryX import summary

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

# Model over 520 x 60 windows with a 64-dimensional latent space.
model = EEGAutoencoder(sequence_lenth=520, feature_dim=60, latent_dim=64).to(device)


optimizer = optim.Adam(model.parameters(), lr=1e-5)


# All-zero dummy batch matching the model's input shape, used only to print
# the layer summary.
test = torch.zeros(2, 520, 60).to(device) 
summary(model, test)


# Train for a fixed number of epochs, recording the value returned by
# train_eeg_vae after each one.
num_epochs = 5
train_losses = []
for epoch in range(1, num_epochs + 1):
    train_loss = train_eeg_vae(model, train_dataloder, optimizer, device, epoch, data_type='eeg_raw')
    train_losses.append(train_loss)
    print(f'====> Epoch: {epoch} Average loss: {train_loss:.4f}')


In [None]:
def get_latent_by_word(model, train_dataloder, word_target):
    """
    Encode every sample labelled ``word_target`` and summarize its latents.

    Args:
        model: module exposing ``encode(x) -> (mu, log_var)``.
        train_dataloder: iterable of batch dicts with keys ``'label'`` (a
            sequence of labels) and ``'eeg_raw'`` (a tensor indexed per sample).
        word_target: label to look for.

    Returns:
        ``(mean_latent, var_latent)`` computed over the encoder means of all
        matching samples, or ``(None, None)`` when the word never occurs.
        With exactly one matching sample the variance is zero: the unbiased
        estimator is undefined there and would otherwise return NaN, which
        turns every sample generated from it into NaN.
    """
    model.eval()
    device = next(model.parameters()).device

    word_latents = []

    with torch.no_grad():
        for batch in tqdm(train_dataloder):
            labels = batch['label']
            word_indices = [i for i, label in enumerate(labels) if label == word_target]
            if not word_indices:
                continue

            # Select the matching rows first so only they are moved to the device.
            word_eeg = batch['eeg_raw'][word_indices].float().to(device)
            mu, _ = model.encode(word_eeg)
            word_latents.append(mu)

    if not word_latents:
        return None, None

    word_latents = torch.cat(word_latents, dim=0)
    mean_latent = torch.mean(word_latents, dim=0)
    if word_latents.size(0) > 1:
        var_latent = torch.var(word_latents, dim=0)
    else:
        var_latent = torch.zeros_like(mean_latent)
    return mean_latent, var_latent

def generate_eeg_from_word(model, word, train_dataloder, num_samples=5):
    """
    Decode ``num_samples`` latent vectors drawn around a word's mean latent.

    The per-word latent mean and variance come from ``get_latent_by_word``;
    each draw is ``mean + eps * sqrt(var)`` with standard-normal ``eps``.

    Returns:
        The decoded samples concatenated along dim 0, or ``None`` (after
        printing a message) when the word has no samples.
    """
    model.eval()
    device = next(model.parameters()).device  # looked up but not used further

    latent_mean, latent_var = get_latent_by_word(model, train_dataloder, word)
    if latent_mean is None:
        print(f"No samples found for word: {word}")
        return None

    decoded = []
    with torch.no_grad():
        latent_scale = torch.sqrt(latent_var)
        for _ in range(num_samples):
            noise = torch.randn_like(latent_var)
            latent = latent_mean + noise * latent_scale
            decoded.append(model.decode(latent.unsqueeze(0)))

    return torch.cat(decoded, dim=0)

def analyze_generated_samples(model, word, train_dataloder, num_samples=5):
    """
    Collect the real EEG samples for ``word``, generate ``num_samples`` new
    ones, print summary statistics and return both sets.

    Args:
        model: passed through to ``generate_eeg_from_word``.
        word: label to look up in each batch's ``'label'`` list.
        train_dataloder: iterable of batch dicts with ``'label'`` and
            ``'eeg_raw'`` entries.
        num_samples: number of samples to generate.

    Returns:
        ``(real_samples, generated_samples)``. When the word has no real
        samples, or generation fails, ``(None, None)`` is returned so that
        callers unpacking two values do not crash on a bare ``None``.
    """
    real_batches = []

    for batch in train_dataloder:
        labels = batch['label']
        word_indices = [i for i, label in enumerate(labels) if label == word]
        if word_indices:
            real_batches.append(batch['eeg_raw'][word_indices])

    if not real_batches:
        print(f"No real samples found for word: {word}")
        return None, None

    real_samples = torch.cat(real_batches, dim=0)

    generated_samples = generate_eeg_from_word(model, word, train_dataloder, num_samples)

    if generated_samples is None:
        return None, None

    # Element-wise statistics across the sample dimension.
    real_mean = torch.mean(real_samples, dim=0)
    real_std = torch.std(real_samples, dim=0)
    gen_mean = torch.mean(generated_samples, dim=0)
    gen_std = torch.std(generated_samples, dim=0)

    print("\nStatistics:")
    print(f"Real mean range: [{torch.min(real_mean):.3f}, {torch.max(real_mean):.3f}]")
    print(f"Generated mean range: [{torch.min(gen_mean):.3f}, {torch.max(gen_mean):.3f}]")
    print(f"Real std range: [{torch.min(real_std):.3f}, {torch.max(real_std):.3f}]")
    print(f"Generated std range: [{torch.min(gen_std):.3f}, {torch.max(gen_std):.3f}]")

    return real_samples, generated_samples


In [None]:
# Real and generated samples for the word 'Alice' (one generated sample); the
# two results are reused by the plotting cells below.
real_alice, generated_alice = analyze_generated_samples(model, 'Alice', train_dataloder, num_samples=1)

In [None]:
def plot_eeg_comparison(real_samples, generated_samples, word):
    """
    Draw two side-by-side line plots: the first real sample and the first
    generated sample for ``word``, one line per index of the last dimension.

    Both inputs are indexed as ``samples[0, :, i]``, so they must be
    three-dimensional tensors.
    """
    plt.figure(figsize=(15, 5))

    panels = (
        (real_samples, f'Real EEG for "{word}"'),
        (generated_samples, f'Generated EEG for "{word}"'),
    )
    for position, (samples, title) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        for channel in range(samples.shape[2]):
            plt.plot(samples[0, :, channel].cpu().numpy(), alpha=0.5)
        plt.title(title)
        plt.xlabel('Time')
        plt.ylabel('Amplitude')

    plt.tight_layout()
    plt.show()

In [None]:

# Line-plot comparison of the first real and first generated 'Alice' sample.
plot_eeg_comparison(real_alice, generated_alice, 'Alice')

In [None]:
import seaborn as sns


def plot_eeg_heatmaps(real_sample, generated_sample, word):
    """
    Show a real and a generated EEG sample for ``word`` as two heatmaps
    placed side by side.

    Each sample is moved to the CPU, converted to numpy and transposed
    before plotting.
    """
    fig, axes = plt.subplots(1, 2, figsize=(20, 8))

    panels = (
        (real_sample, f'Real EEG Pattern for "{word}"'),
        (generated_sample, f'Generated EEG Pattern for "{word}"'),
    )
    for ax, (sample, title) in zip(axes, panels):
        sns.heatmap(
            sample.cpu().numpy().T,
            ax=ax,
            cmap='viridis',
            cbar_kws={'label': 'Amplitude'},
        )
        ax.set_title(title)
        ax.set_xlabel('Time')
        ax.set_ylabel('Features')

    plt.tight_layout()
    plt.show()

In [None]:
# Heatmap comparison of the first real and first generated 'Alice' sample.
plot_eeg_heatmaps(real_alice[0], generated_alice[0], 'Alice')

In [None]:
def word_latent_dict(model, train_dataloder):
    """
    Encode every sample of the loader and group the encoder means by label.

    Returns:
        dict mapping each label to a tensor that concatenates, along dim 0,
        the encoder mean of every sample carrying that label.
    """
    model.eval()
    device = next(model.parameters()).device

    latents_by_label = {}

    with torch.no_grad():
        for batch in tqdm(train_dataloder):
            labels = batch['label']
            signals = batch['eeg_raw'].float().to(device)

            # Encode one sample at a time, keeping a leading batch dimension of 1.
            for row, label in enumerate(labels):
                mu, _ = model.encode(signals[row].unsqueeze(0))
                latents_by_label.setdefault(label, []).append(mu)

    return {
        label: torch.cat(chunks, dim=0)
        for label, chunks in latents_by_label.items()
    }

def mean_std_latent_dict(word_latents_dict):
    """
    Compute the per-word mean and standard deviation of the latent vectors.

    Args:
        word_latents_dict: dict mapping a word to a tensor whose first
            dimension indexes that word's samples.

    Returns:
        ``(mean_latent_dict, std_latent_dict)``, both keyed like the input and
        reduced over dim 0. A word with a single sample gets a standard
        deviation of zero: the default (unbiased) estimate is undefined for
        one sample and would otherwise be NaN, which would turn every sample
        generated for that word into NaN.
    """
    mean_latent_dict = {}
    std_latent_dict = {}

    for key, latents in word_latents_dict.items():
        mean_latent_dict[key] = torch.mean(latents, dim=0)
        if latents.size(0) > 1:
            std_latent_dict[key] = torch.std(latents, dim=0)
        else:
            std_latent_dict[key] = torch.zeros_like(mean_latent_dict[key])

    return mean_latent_dict, std_latent_dict

def word_augmentation_dict(mean_latent_dict, std_latent_dict, num_samples=100, decoder_model=None):
    """
    Generate ``num_samples`` decoded EEG samples per word.

    NOTE: a later cell of this notebook redefines ``word_augmentation_dict``
    with a batched signature; whichever cell ran last is the one in effect.

    Args:
        mean_latent_dict: word -> mean latent vector.
        std_latent_dict: word -> latent standard deviation (same keys).
        num_samples: number of samples to draw per word.
        decoder_model: object exposing ``decode(z)``. Defaults to the
            notebook-global ``model`` for backward compatibility.

    Returns:
        dict mapping each word to the decoded samples concatenated on dim 0.
    """
    if decoder_model is None:
        decoder_model = model  # notebook-global fallback

    augmented = {}

    # Pure inference: skip building an autograd graph for every sample.
    with torch.no_grad():
        for key, mean_latent in mean_latent_dict.items():
            std_latent = std_latent_dict[key]
            samples = []
            for _ in range(num_samples):
                eps = torch.randn_like(std_latent)
                # The dict already holds standard deviations, so the noise is
                # scaled by it directly (the earlier sqrt() was only right for
                # a variance).
                z = mean_latent + eps * std_latent
                samples.append(decoder_model.decode(z.unsqueeze(0)))

            augmented[key] = torch.cat(samples, dim=0)

    return augmented

In [None]:
# Per-word latent vectors for every sample served by the training loader.
test_dict = word_latent_dict(model, train_dataloder)

In [None]:
# Per-word latent mean and standard deviation derived from test_dict.
test_avaraged_dict, test_std_dict = mean_std_latent_dict(test_dict)

In [None]:
import pickle

# Persist the per-word latent statistics as two pickle files (mean and std).
# NOTE(review): the values are torch tensors on the model's device, so loading
# these files presumably needs a matching device — confirm before sharing.
with open("/ocean/projects/cis240129p/soederha/silent_speech/raw_model2_mean.pkl", "wb") as pickle_file:
    pickle.dump(test_avaraged_dict, pickle_file)
with open("/ocean/projects/cis240129p/soederha/silent_speech/raw_model2_std.pkl", "wb") as pickle_file:
    pickle.dump(test_std_dict, pickle_file)

In [None]:
def word_augmentation_dict(mean_latent_dict, std_latent_dict, num_samples, batch_size=1, decoder_model=None):
    """
    Generate ``num_samples`` decoded EEG samples per word, decoding
    ``batch_size`` latent vectors at a time.

    NOTE: this redefines the ``word_augmentation_dict`` declared in an
    earlier cell of this notebook.

    Args:
        mean_latent_dict: word -> mean latent vector.
        std_latent_dict: word -> latent standard deviation (same keys).
        num_samples: number of samples to draw per word.
        batch_size: number of latent vectors decoded per call.
        decoder_model: object exposing ``decode(z)``. Defaults to the
            notebook-global ``model`` for backward compatibility.

    Returns:
        dict mapping each word to the decoded samples concatenated on dim 0.
    """
    if decoder_model is None:
        decoder_model = model  # notebook-global fallback

    augmented = {}

    # Pure inference: without no_grad every decoded batch would keep its
    # autograd graph alive for as long as the result is stored.
    with torch.no_grad():
        for key, mean_latent in mean_latent_dict.items():
            std_latent = std_latent_dict[key]
            chunks = []
            for start in range(0, num_samples, batch_size):
                current_batch_size = min(batch_size, num_samples - start)
                eps = torch.randn(current_batch_size, *std_latent.shape, device=std_latent.device)
                # The dict already holds standard deviations, so the noise is
                # scaled by it directly (no extra sqrt).
                z = mean_latent.unsqueeze(0) + eps * std_latent.unsqueeze(0)
                chunks.append(decoder_model.decode(z))

            augmented[key] = torch.cat(chunks, dim=0)
            # Release cached GPU blocks once per word rather than per batch.
            torch.cuda.empty_cache()

    return augmented

In [None]:
# One generated sample per word, drawn from the per-word latent statistics.
test_augmentation_dict = word_augmentation_dict(test_avaraged_dict, test_std_dict, num_samples=1)

In [None]:
# Shape check of the generated samples for the word 'you'.
print(test_augmentation_dict['you'].shape)

In [None]:
# Build a synthetic "subject" by concatenating, in chapter order, one generated
# EEG window per word of the chapter.
alice_chapter = pd.read_csv('/ocean/projects/cis240129p/shared/data/eeg_alice/AliceChapterOne-EEG.csv').iloc[:,0]

# Take the window length and feature count from the generated samples, so the
# output buffer and the per-word write stride always agree. The buffer used to
# be sized with 520 rows per word but written with a stride of 159.
seq_len, n_features = next(iter(test_augmentation_dict.values())).shape[1:]

# Zero-initialized so that words without generated samples leave a window of
# zeros rather than uninitialized memory.
new_subject = np.zeros((seq_len * len(alice_chapter), n_features))
for i, word in enumerate(alice_chapter):
    if word not in test_augmentation_dict:
        print(f"Word '{word}' not found in augmentation dict")
        continue
    sample_index = 0  # always use the first generated sample for the word
    new_subject[i * seq_len:(i + 1) * seq_len] = (
        test_augmentation_dict[word][sample_index].cpu().detach().numpy()
    )

print(new_subject.shape)
print(new_subject)
    

In [None]:
#create .npy file from the new subject
#np.save('/ocean/projects/cis240129p/soederha/newsubject_eegfeats_final_epoch40.npy', new_subject)