In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import numpy as np

from sentence_transformers import SentenceTransformer
from transformers import CLIPModel, CLIPProcessor

from torcheeg.models import EEGNet

from tqdm import tqdm

from EEGDataset import EEGDataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Hugging Face checkpoint id for CLIP; in the visible notebook it is only
# referenced by the commented-out CLIP loading code in the next cell.
MODEL_NAME = "openai/clip-vit-base-patch32"
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Bare last expression so the notebook displays the selected device.
device

'cuda'

In [3]:
# CLIP loading is disabled; a SentenceTransformer is loaded instead.
# clip_processor = CLIPProcessor.from_pretrained(MODEL_NAME)
# clip_model = CLIPModel.from_pretrained(MODEL_NAME).to(device)

# Sentence-embedding model used by the exploratory `encode` calls in the next cells.
# NOTE(review): the name `model` is rebound to an EEGCLIPModel in a later cell.
model = SentenceTransformer("all-MiniLM-L6-v2").to(device)

In [4]:
# Sanity check: embed two sample sentences, then show the result's shape,
# dtype and container type.
embeddings = model.encode(["hello world", "open source embeddings"])
embeddings.shape, embeddings.dtype, type(embeddings)

((2, 384), dtype('float32'), numpy.ndarray)

In [5]:
# Display the module structure of the loaded SentenceTransformer.
model

SentenceTransformer(
  (0): Transformer({'max_seq_length': 256, 'do_lower_case': False, 'architecture': 'BertModel'})
  (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
  (2): Normalize()
)

In [6]:
# Open the EEG shard directory and split it into train/validation/test
# subsets with shuffling disabled.
ds = EEGDataset("shards")
train_ds, val_ds, test_ds = ds.split_train_valid_test(
    train_ratio=0.7,
    valid_ratio=0.15,
    shuffle=False,
)

# One loader per subset, all built with the same batch size and worker count.
train_dl, val_dl, test_dl = (
    subset.getLoader(batch_size=25, num_workers=0)
    for subset in (train_ds, val_ds, test_ds)
)

In [7]:
# Peek at the text label (second element) of the first dataset sample.
ds[0][1]

'Presents a good case while failing to provide a reason for us to care beyond the very basic dictums of human decency.'

In [8]:
# Embed the first sample's text label, then show the embedding's shape,
# dtype and container type.
embeddings = model.encode([ds[0][1]])
embeddings.shape, embeddings.dtype, type(embeddings)

((1, 384), dtype('float32'), numpy.ndarray)

In [9]:
# Shape of the first sample's EEG data. It is printed explicitly because a
# bare expression that is not the last line of a cell is never displayed.
print(ds[0][0].shape)

# Inspect a single batch from the training loader, then stop iterating.
for batch_data, batch_labels in train_dl:
    print(batch_data.shape)
    print(batch_labels)
    break

torch.Size([25, 8196, 10000])
['Presents a good case while failing to provide a reason for us to care beyond the very basic dictums of human decency.', 'Beautifully crafted, engaging filmmaking that should attract upscale audiences hungry for quality and a nostalgic, twisty yarn that will keep them guessing.', 'Bread, My Sweet has so many flaws it would be easy for critics to shred it.', 'Slow, silly and unintentionally hilarious.', 'Ultimately feels emp11111ty and unsatisfying, like swallowing a Communion wafer without the wine.', 'Exudes the fizz of a Busby Berkeley musical and the visceral excitement of a sports extravaganza.', "The film rehashes several old themes and is capped with pointless extremes -- it's insanely violent and very graphic.", 'Ryan Gosling is, in a word, brilliant as the conflicted Daniel.', "If Deuces Wild had been tweaked up a notch it would have become a camp adventure, one of those movies that's so bad it starts to become good.", "The film's stagecrafts are 

In [10]:
class TextEncoder(nn.Module):
    """Thin ``nn.Module`` wrapper around the ``all-MiniLM-L6-v2`` SentenceTransformer."""

    def __init__(self):
        super().__init__()
        self.text_encoder_model = SentenceTransformer("all-MiniLM-L6-v2")

    def forward(self, texts):
        """Embed ``texts`` and return the embeddings as a tensor.

        Args:
            texts: Sentences to embed; handed unchanged to ``encode``.

        Returns:
            Whatever ``encode(texts, convert_to_tensor=True)`` produces.
        """
        # NOTE(review): `encode` is an inference helper and presumably runs
        # without autograd, so no gradient would reach the text encoder - confirm.
        return self.text_encoder_model.encode(texts, convert_to_tensor=True)

class LocalizedEEGEncoder(nn.Module):
    """Convolutional encoder mapping an EEG recording to a fixed-size embedding.

    Two ``Conv1d`` + ReLU stages (``ch_count -> 512 -> 256`` channels, kernel 3,
    padding 1) are followed by an average over the time axis and a linear
    projection to ``embedding_dim``.

    Args:
        ch_count: Number of input channels expected by the first convolution.
        embedding_dim: Size of the output embedding.
    """

    def __init__(self, ch_count=8196, embedding_dim=384):
        super().__init__()

        self.temporal = nn.Conv1d(ch_count, 512, 3, padding=1)
        self.temporal2 = nn.Conv1d(512, 256, 3, padding=1)
        # Average over the time axis only. The earlier AdaptiveAvgPool2d((256, 1))
        # treated the 3-D conv output as an unbatched (C, H, W) image: it read the
        # batch axis as channels and only worked because H happened to equal the
        # hard-coded 256. The 1-D pool has no such coupling and also accepts
        # unbatched (channels, time) input.
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(256, embedding_dim)

    def forward(self, x):
        """Encode ``x`` of shape ``(batch, ch_count, time)`` or ``(ch_count, time)``.

        Returns:
            Tensor of shape ``(batch, embedding_dim)`` (or ``(embedding_dim,)``
            for unbatched input).
        """
        x = torch.relu(self.temporal(x))
        x = torch.relu(self.temporal2(x))
        # (..., 256, time) -> (..., 256, 1) -> (..., 256)
        x = self.avg_pool(x).squeeze(-1)
        return self.fc(x)
    
class EEGCLIPModel(nn.Module):
    """Pairs a text encoder with an EEG encoder and returns both embeddings.

    Args:
        ch_count: Channel count forwarded to ``LocalizedEEGEncoder``.
        embedding_dim: Embedding size forwarded to ``LocalizedEEGEncoder``.
    """

    def __init__(self, ch_count=8196, embedding_dim=384):
        super().__init__()
        self.text_encoder = TextEncoder()
        self.eeg_encoder = LocalizedEEGEncoder(
            ch_count=ch_count,
            embedding_dim=embedding_dim,
        )

    def forward(self, eeg_data, texts):
        """Return ``(eeg_embeddings, text_embeddings)`` for one batch.

        The EEG branch is evaluated first, then the text branch.
        """
        return self.eeg_encoder(eeg_data), self.text_encoder(texts)

# Training

In [11]:
# Build the joint EEG/text model with its default sizes and move it to `device`.
# NOTE: this rebinds `model`, which earlier cells used for the standalone
# SentenceTransformer.
model = EEGCLIPModel().to(device)

In [13]:
def _eeg_text_mse(model: nn.Module, batch) -> torch.Tensor:
    """Return the mean squared error between the EEG and text embeddings of one batch."""
    eeg_data, texts = batch
    # Cast to float32, then move to the notebook-level compute device.
    eeg_data = eeg_data.to(torch.float32).to(device)
    eeg_embeddings, text_embeddings = model(eeg_data, list(texts))
    return ((eeg_embeddings - text_embeddings) ** 2).mean()


def train(model: nn.Module, train_dataloader: DataLoader, valid_loader: DataLoader, epochs: int = 10):
    """Fit ``model`` with Adam (lr=1e-4) on an MSE loss between EEG and text embeddings.

    After every epoch the model is evaluated on ``valid_loader`` and one summary
    line with the mean per-batch train and validation loss is printed. Relies on
    the notebook-level ``device`` and ``tqdm`` names.

    Args:
        model: Module called as ``model(eeg_data, texts)`` that returns
            ``(eeg_embeddings, text_embeddings)``.
        train_dataloader: Sized iterable of ``(eeg_data, texts)`` training batches.
        valid_loader: Sized iterable of ``(eeg_data, texts)`` validation batches.
        epochs: Number of passes over ``train_dataloader``.

    Returns:
        None. An empty loader is reported as ``nan`` rather than raising
        ``ZeroDivisionError``.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    n_train_batches = len(train_dataloader)
    n_valid_batches = len(valid_loader)

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        for batch in tqdm(train_dataloader):
            optimizer.zero_grad()
            loss = _eeg_text_mse(model, batch)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / n_train_batches if n_train_batches else float("nan")

        model.eval()
        total_valid_loss = 0.0
        # No autograd bookkeeping is needed while validating.
        with torch.inference_mode():
            for batch in tqdm(valid_loader):
                total_valid_loss += _eeg_text_mse(model, batch).item()

        avg_valid_loss = total_valid_loss / n_valid_batches if n_valid_batches else float("nan")
        print(f"Epoch [{epoch+1}/{epochs}]:- Train Loss: {avg_loss:.6f} | Valid Loss: {avg_valid_loss:.6f}")

        # Release cached GPU blocks between epochs; skipped on CPU-only runs.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

In [14]:
def test(model: nn.Module, test_dataloader: DataLoader):
    model.eval()
    total_loss = 0.0
    with torch.inference_mode():
        for batch in tqdm(test_dataloader):
            eeg_data, texts = batch
            eeg_data = eeg_data.to(device)
            texts = list(texts)

            eeg_embeddings, text_embeddings = model(eeg_data, texts)

            loss = ((eeg_embeddings - text_embeddings) ** 2).mean()

            total_loss += loss.item()
    avg_loss = total_loss / len(test_dataloader)
    print(f"Test Loss: {avg_loss:.6f}")

In [None]:
# Train for 5 epochs, evaluating on the validation loader after each epoch.
train(model, train_dl, val_dl, epochs=5)

 18%|█▊        | 2/11 [01:24<06:55, 46.21s/it]