In [1]:
import torch
from torch import Tensor
import torch.nn as nn
from torch.utils.data import DataLoader
from einops.layers.torch import Rearrange
import itertools
from lavis.models.clip_models.loss import ClipLoss # TODO: Clean up
import math
import numpy as np

from eegdatasets_leaveone import EEGDataset

sub = 'sub-01'
# Verify dataloader
train_dataset = EEGDataset(subjects=[sub], exclude_subject=None, split="train")

train_dataloader= DataLoader(train_dataset, batch_size=25, shuffle=True, num_workers=0, drop_last=True)

In [2]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        
        div_term = torch.exp(torch.arange(0, d_model + 1, 2).float() * (-math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term[:d_model // 2 + 1])
        pe[:, 1::2] = torch.cos(position * div_term[:d_model // 2])

        self.register_buffer('pe', pe)

    def forward(self, x):
        pe = self.pe[:x.size(0), :].unsqueeze(1).repeat(1, x.size(1), 1)
        x = x + pe
        return x

class EEGAttention(nn.Module):
    def __init__(self, channel, nhead):
        super(EEGAttention, self).__init__()
        print(channel, nhead)
        self.pos_encoder = PositionalEncoding(channel)
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=channel, nhead=nhead)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=1)
        self.channel = channel

    def forward(self, src):
        src = src.permute(2, 0, 1)  # Change shape to [time_length, batch_size, channel]
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src)
        return output.permute(1, 2, 0)  # Change shape back to [batch_size, channel, time_length]

class PatchEmbedding(nn.Module):
    def __init__(self, emb_size=40):
        super().__init__()
        # revised from shallownet
        self.tsconv = nn.Sequential(
            nn.Conv2d(1, 40, (1, 5), (1, 1)),
            nn.AvgPool2d((1, 17), (1, 5)),
            nn.BatchNorm2d(40),
            nn.ELU(),
            nn.Conv2d(40, 40, (63, 1), (1, 1)),
            nn.BatchNorm2d(40),
            nn.ELU(),
            nn.Dropout(0.5),
        )

        self.projection = nn.Sequential(
            nn.Conv2d(40, emb_size, (1, 1), stride=(1, 1)),  
            Rearrange('b e (h) (w) -> b (h w) e'),
        )

    def forward(self, x: Tensor) -> Tensor:
        # b, _, _, _ = x.shape
        x = x.unsqueeze(1)     
        # print("x", x.shape)   
        x = self.tsconv(x)
        # print("tsconv", x.shape)   
        x = self.projection(x)
        # print("projection", x.shape)  
        return x


class ResidualAdd(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        res = x
        x = self.fn(x, **kwargs)
        x += res
        return x


class FlattenHead(nn.Sequential):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        x = x.contiguous().view(x.size(0), -1)
        return x


class Enc_eeg(nn.Sequential):
    def __init__(self, emb_size=40, **kwargs):
        super().__init__(
            PatchEmbedding(emb_size),
            FlattenHead()
        )

        
class Proj_eeg(nn.Sequential):
    def __init__(self, embedding_dim=1840, proj_dim=1024, drop_proj=0.5):
        super().__init__(
            nn.Linear(embedding_dim, proj_dim),
            ResidualAdd(nn.Sequential(
                nn.GELU(),
                nn.Linear(proj_dim, proj_dim),
                nn.Dropout(drop_proj),
            )),
            nn.LayerNorm(proj_dim),
        )


class Proj_img(nn.Sequential):
    def __init__(self, embedding_dim=1024, proj_dim=1024, drop_proj=0.3):
        super().__init__(
            nn.Linear(embedding_dim, proj_dim),
            ResidualAdd(nn.Sequential(
                nn.GELU(),
                nn.Linear(proj_dim, proj_dim),
                nn.Dropout(drop_proj),
            )),
            nn.LayerNorm(proj_dim),
        )
    def forward(self, x):
        return x 

class ATM_S_reconstruction_scale_0_1000(nn.Module):    
    def __init__(self, num_channels=63, sequence_length=250, num_subjects=1, num_features=64, num_latents=1024, num_blocks=1):
        super(ATM_S_reconstruction_scale_0_1000, self).__init__()
        self.attention_model = EEGAttention(num_channels, nhead=1)   
        self.subject_wise_linear = nn.ModuleList([nn.Linear(sequence_length, sequence_length) for _ in range(num_subjects)])
        self.enc_eeg = Enc_eeg()
        self.proj_eeg = Proj_eeg()        
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.loss_func = ClipLoss()       
         
    def forward(self, x):
        # print(f"Before attention: {x.shape}")
        x = self.attention_model(x)
        # print(f'After attention shape: {x.shape}')
        x = self.subject_wise_linear[0](x)
        # print(f'After subject-specific linear transformation shape: {x.shape}')
        eeg_embedding = self.enc_eeg(x)
        print(f'After enc_eeg shape: {eeg_embedding.shape}')
        out = self.proj_eeg(eeg_embedding)
        return out  

In [3]:
device = "cuda:2" if torch.cuda.is_available() else "cpu"
eeg_model = ATM_S_reconstruction_scale_0_1000(63, 250).to(device)
img_model = Proj_img().to(device)

optimizer = torch.optim.AdamW(itertools.chain(eeg_model.parameters(), img_model.parameters()), lr=3e-4)
text_features_train_all = train_dataset.text_features
img_features_train_all = train_dataset.img_features

63 1


In [4]:
def batchwise_cosine_similarity(Z,B):
    Z = Z.flatten(1)
    B = B.flatten(1).T
    Z_norm = torch.linalg.norm(Z, dim=1, keepdim=True)  # Size (n, 1).
    B_norm = torch.linalg.norm(B, dim=0, keepdim=True)  # Size (1, b).
    cosine_similarity = ((Z @ B) / (Z_norm @ B_norm)).T
    return cosine_similarity

def topk(similarities,labels,k=5):
    if k > similarities.shape[0]:
        k = similarities.shape[0]
    topsum=0
    for i in range(k):
        topsum += torch.sum(torch.argsort(similarities,axis=1)[:,-(i+1)] == labels)/len(labels)
    return topsum

In [6]:
eeg_data.shape

torch.Size([25, 63, 250])

In [5]:
mse_loss_fn = nn.MSELoss()
alpha  = 0.9
(eeg_data, text, text_features, img, img_features) = next(iter(train_dataloader))

eeg_data = eeg_data.to(device)
# eeg_data = eeg_data[:,1:,:250]
text_features = text_features.to(device).float()
img_features = img_features.to(device).float() # already normalized
eeg_features = eeg_model(eeg_data).float()
# eeg_features_norm = nn.functional.normalize(eeg_features.flatten(1), dim=-1)

logit_scale = eeg_model.logit_scale
img_loss = eeg_model.loss_func(eeg_features, img_features, logit_scale)
text_loss = eeg_model.loss_func(eeg_features, text_features, logit_scale)

contrastive_loss = img_loss
regress_loss =  mse_loss_fn(eeg_features, img_features)
loss = (alpha * regress_loss *10 + (1 - alpha) * contrastive_loss*10)
loss.backward()

labels = torch.arange(len(eeg_data)).to(eeg_data.device) 
fwd_percent_correct = topk(batchwise_cosine_similarity(eeg_features, img_features), labels, k=5).item()
bwd_percent_correct = topk(batchwise_cosine_similarity(img_features, eeg_features), labels, k=5).item()

After enc_eeg shape: torch.Size([25, 1840])


In [65]:
contrastive_loss

tensor(5.0042, device='cuda:1', grad_fn=<DivBackward0>)

In [66]:
regress_loss

tensor(1.0012, device='cuda:1', grad_fn=<MseLossBackward0>)

In [63]:
loss

tensor(12.2357, device='cuda:1', grad_fn=<AddBackward0>)