In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F

import numpy as np

from sentence_transformers import SentenceTransformer
from transformers import CLIPModel, CLIPProcessor

from torcheeg.models import EEGNet

from tqdm import tqdm

from EEGDataset import EEGDataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
MODEL_NAME = "openai/clip-vit-base-patch32"  # CLIP checkpoint id; not used by the active SentenceTransformer code path
SEED = 42
# Seed torch's generators up front so weight initialisation of the EEG encoder
# is reproducible under Restart & Run All.
torch.manual_seed(SEED)
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [3]:
# Disabled CLIP alternative:
# clip_processor = CLIPProcessor.from_pretrained(MODEL_NAME)
# clip_model = CLIPModel.from_pretrained(MODEL_NAME).to(device)

# Stand-alone sentence encoder for quick embedding checks, moved onto `device`.
model = SentenceTransformer("all-MiniLM-L6-v2")
model = model.to(device)

In [4]:
# Sanity check: encode two sentences and inspect the returned array.
embeddings = model.encode(
    ["hello world", "open source embeddings"]
)
(embeddings.shape, embeddings.dtype, type(embeddings))

((2, 384), dtype('float32'), numpy.ndarray)

In [5]:
# Show the SentenceTransformer module stack (transformer -> pooling -> normalize).
model

SentenceTransformer(
  (0): Transformer({'max_seq_length': 256, 'do_lower_case': False, 'architecture': 'BertModel'})
  (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
  (2): Normalize()
)

In [None]:
# Directory of EEG shards, relative to the notebook (no absolute local paths).
EEG_DATA_DIR = "shards"
ds = EEGDataset(EEG_DATA_DIR, pad_upto=6000)

# Index the dataset once; `ds[0][0]` followed by `ds[0][1]` would fetch the same
# (possibly large, padded) item twice.
first_item = ds[0]
first_item[0].shape, first_item[1]

(torch.Size([3438, 6000]),
 'Presents a good case while failing to provide a reason for us to care beyond the very basic dictums of human decency.')

In [None]:
# 70% train / 15% validation, remainder presumably test.
# NOTE(review): shuffle=False — confirm whether EEGDataset slices contiguously.
train_ds, val_ds, test_ds = ds.split_train_valid_test(
    train_ratio=0.7, valid_ratio=0.15, shuffle=False
)

# Identical loader settings for every split (built in train, val, test order).
train_dl, val_dl, test_dl = (
    split.getLoader(batch_size=25, num_workers=0)
    for split in (train_ds, val_ds, test_ds)
)

In [8]:
# Encode the text paired with the first dataset item (ds[0][1]).
embeddings = model.encode(
    [ds[0][1]]
)
(embeddings.shape, embeddings.dtype, type(embeddings))

((1, 384), dtype('float32'), numpy.ndarray)

In [None]:
# Peek at one collated training batch to check the EEG tensor shape and its texts.
# (The previous `ds[0][0].shape` line was not the cell's last expression, so it
# displayed nothing while still fetching a dataset item.)
batch_data, batch_labels = next(iter(train_dl))
print(batch_data.shape)
print(batch_labels)

In [None]:
class TextEncoder(nn.Module):
    """Expose a SentenceTransformer as a submodule that maps texts to embeddings.

    ``forward`` delegates to ``SentenceTransformer.encode`` with
    ``convert_to_tensor=True`` and returns its result unchanged.
    """

    def __init__(self):
        super().__init__()
        # The attribute name is part of saved state_dict keys; do not rename.
        self.text_encoder_model = SentenceTransformer("all-MiniLM-L6-v2")

    def forward(self, texts):
        # NOTE(review): as far as I know `encode` runs without autograd
        # tracking, so no gradient reaches this encoder even when it is not
        # frozen — confirm against the installed sentence-transformers version.
        return self.text_encoder_model.encode(texts, convert_to_tensor=True)

class LocalizedEEGEncoder(nn.Module):
    """Encode a batch of multichannel EEG signals into unit-length embeddings.

    The pipeline has four stages:

    - three ``Conv1d`` + ``LeakyReLU`` layers (``ch_count -> 1024 -> 512 -> 256`` channels);
    - a global average over the time axis;
    - a linear projection to ``embedding_dim``;
    - L2 normalisation.

    Args:
        ch_count: number of input EEG channels (``in_channels`` of the first conv).
            NOTE(review): the default 8196 may be a typo for 8192; callers in this
            notebook pass the channel count explicitly.
        embedding_dim: size of the returned embedding.

    Input is expected as ``(batch, ch_count, time)``. Each conv has kernel 11 and
    padding 1, so the time axis shrinks by 8 per layer. ``time`` must therefore be
    at least 25.
    """

    def __init__(self, ch_count=8196, embedding_dim=384):
        super(LocalizedEEGEncoder, self).__init__()

        # Keep this Sequential layout: its indices (temporal.0/2/4) are the
        # state_dict keys stored in saved checkpoints.
        # NOTE(review): padding=1 with kernel 11 is not length-preserving
        # (padding=5 would be); changing it would alter trained models' outputs.
        self.temporal = nn.Sequential(
            nn.Conv1d(ch_count, 1024, 11, padding=1),
            nn.LeakyReLU(),
            nn.Conv1d(1024, 512, 11, padding=1),
            nn.LeakyReLU(),
            nn.Conv1d(512, 256, 11, padding=1),
            nn.LeakyReLU(),
            # Mean over time: (batch, 256, T) -> (batch, 256, 1). A 2-D pool with
            # output (256, 1) gave the same numbers, but only by treating the 3-D
            # activation as an unbatched image and hard-coding the conv width.
            nn.AdaptiveAvgPool1d(1),
        )

        self.fc = nn.Linear(256, embedding_dim)

    def forward(self, x):
        """Return L2-normalised embeddings of shape ``(batch, embedding_dim)``."""
        features = self.temporal(x).squeeze(-1)  # (batch, 256)
        embeddings = self.fc(features)
        return F.normalize(embeddings, p=2, dim=1)
    
class EEGCLIPModel(nn.Module):
    """Pair an EEG encoder with a sentence encoder for embedding alignment.

    Args:
        ch_count: EEG channel count forwarded to ``LocalizedEEGEncoder``.
        embedding_dim: output size of the EEG encoder. It must match the text
            encoder's embedding size for the two outputs to be comparable; the
            all-MiniLM-L6-v2 sample encodings in this notebook are 384-wide.
        freeze_text: if True, every text-encoder parameter gets ``requires_grad=False``.
    """

    def __init__(self, ch_count=8196, embedding_dim=384, freeze_text=True):
        super().__init__()
        # Registration order (text first, then EEG) fixes the order of
        # parameters() and of state_dict entries.
        self.text_encoder = TextEncoder()
        self.eeg_encoder = LocalizedEEGEncoder(ch_count=ch_count, embedding_dim=embedding_dim)

        if freeze_text:
            # Equivalent to setting requires_grad = False on each parameter.
            self.text_encoder.requires_grad_(False)

    def forward(self, eeg_data, texts):
        """Return ``(eeg_embeddings, text_embeddings)``; the EEG branch runs first."""
        eeg_embeddings = self.eeg_encoder(eeg_data)
        return eeg_embeddings, self.text_encoder(texts)

# Training

In [11]:
# 3438 = EEG channel count (first dimension of the sample EEG tensor shown when
# the dataset was loaded).
model = EEGCLIPModel(ch_count=3438).to(device)
# To resume from the best checkpoint instead of training from scratch:
# model.load_state_dict(torch.load("best_model.pt"))

In [None]:
def train(model: nn.Module, train_dataloader: DataLoader, valid_loader: DataLoader, epochs: int = 10,
          lr: float = 1e-4, device=None, best_path: str = "best_model.pt",
          last_path: str = "last_model.pt"):
    """Train ``model`` to align EEG embeddings with their paired text embeddings.

    Each batch is an ``(eeg_data, texts)`` pair, and ``model(eeg_data, texts)`` must
    return ``(eeg_embeddings, text_embeddings)``. The loss is
    ``1 - mean cosine similarity`` between the two. After every epoch the mean
    per-batch validation loss is computed. The state_dict is saved to ``best_path``
    whenever that loss improves, and to ``last_path`` after every epoch.

    NOTE(review): this per-pair cosine loss has no negatives, so an EEG encoder
    that emits nearly the same vector for every input can still score well.
    A contrastive (InfoNCE / CLIP-style) loss would penalise that collapse.

    Args:
        model: module returning ``(eeg_embeddings, text_embeddings)``.
        train_dataloader: yields ``(eeg_data, texts)`` training batches.
        valid_loader: yields ``(eeg_data, texts)`` validation batches.
        epochs: number of passes over ``train_dataloader``.
        lr: Adam learning rate.
        device: where EEG batches are moved. Defaults to the device of the model's
            first parameter, or CPU if the model has no parameters.
        best_path: checkpoint path for the best validation loss so far.
        last_path: checkpoint path overwritten after every epoch.

    Returns:
        The lowest mean validation loss seen, or ``None`` when ``epochs`` is 0.

    Raises:
        ValueError: if either loader yields no batches. The per-epoch means would
            otherwise divide by zero.
    """
    if len(train_dataloader) == 0 or len(valid_loader) == 0:
        raise ValueError("train_dataloader and valid_loader must each yield at least one batch")

    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    # Frozen parameters (e.g. a frozen text encoder) never receive gradients;
    # keeping them out of Adam leaves the updates unchanged.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable_params, lr=lr)

    def batch_loss(batch):
        # One forward pass plus the cosine-alignment loss for a single batch.
        eeg_data, texts = batch
        eeg_data = eeg_data.to(device=device, dtype=torch.float32)
        eeg_embeddings, text_embeddings = model(eeg_data, list(texts))
        return 1.0 - F.cosine_similarity(eeg_embeddings, text_embeddings).mean()

    best_valid_loss = None

    for epoch in range(epochs):
        model.train()
        total_train_loss = 0.0
        for batch in tqdm(train_dataloader):
            optimizer.zero_grad()
            loss = batch_loss(batch)
            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()
        avg_train_loss = total_train_loss / len(train_dataloader)

        model.eval()
        total_valid_loss = 0.0
        with torch.inference_mode():
            for batch in tqdm(valid_loader):
                total_valid_loss += batch_loss(batch).item()
        avg_valid_loss = total_valid_loss / len(valid_loader)

        if best_valid_loss is None or avg_valid_loss < best_valid_loss:
            print(f"Valid Loss: {avg_valid_loss:.10f}")
            best_valid_loss = avg_valid_loss
            torch.save(model.state_dict(), best_path)

        print(f"Epoch [{epoch+1}/{epochs}]:- Train Loss: {avg_train_loss:.6f} | Valid Loss: {avg_valid_loss:.6f}")
        torch.save(model.state_dict(), last_path)

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return best_valid_loss


In [None]:
def test(model: nn.Module, test_dataloader: DataLoader):
# def test(model: nn.Module, dataset: EEGDataset):
    model.eval()
    total_loss = 0.0
    with torch.inference_mode():
        for batch in tqdm(test_dataloader):
        # for i in tqdm(range(45, 49)):
            # batch = dataset[i]
            eeg_data, texts = batch
            # eeg_data = eeg_data.unsqueeze(0)
            # texts = [texts]
            eeg_data = eeg_data.to(torch.float32).to(device)
            texts = list(texts)

            eeg_embeddings, text_embeddings = model(eeg_data, texts)

            # loss = ((eeg_embeddings - text_embeddings) ** 2).mean()
            loss = 1.0 - F.cosine_similarity(eeg_embeddings, text_embeddings).mean()

            total_loss += loss.item()
    avg_loss = total_loss / 4
    print(f"Test Loss: {avg_loss:.6f}")

In [None]:
# Train for 20 epochs; checkpoints are written by `train` as it goes.
train(model, train_dataloader=train_dl, valid_loader=val_dl, epochs=20)

100%|██████████| 40/40 [01:06<00:00,  1.66s/it]
100%|██████████| 5/5 [00:03<00:00,  1.39it/s]


Valid Loss: 0.7127529979
Epoch [1/20]:- Train Loss: 0.766092 | Valid Loss: 0.712753


100%|██████████| 40/40 [01:06<00:00,  1.66s/it]
100%|██████████| 5/5 [00:03<00:00,  1.48it/s]


Valid Loss: 0.6473903775
Epoch [2/20]:- Train Loss: 0.639115 | Valid Loss: 0.647390


100%|██████████| 40/40 [01:06<00:00,  1.65s/it]
100%|██████████| 5/5 [00:03<00:00,  1.45it/s]


Valid Loss: 0.6147175193
Epoch [3/20]:- Train Loss: 0.592701 | Valid Loss: 0.614718


100%|██████████| 40/40 [01:11<00:00,  1.79s/it]
100%|██████████| 5/5 [00:03<00:00,  1.35it/s]


Valid Loss: 0.6024095058
Epoch [4/20]:- Train Loss: 0.574083 | Valid Loss: 0.602410


100%|██████████| 40/40 [01:13<00:00,  1.84s/it]
100%|██████████| 5/5 [00:04<00:00,  1.20it/s]


Valid Loss: 0.5967626095
Epoch [5/20]:- Train Loss: 0.565884 | Valid Loss: 0.596763


100%|██████████| 40/40 [01:13<00:00,  1.84s/it]
100%|██████████| 5/5 [00:04<00:00,  1.11it/s]


Valid Loss: 0.5931995511
Epoch [6/20]:- Train Loss: 0.562092 | Valid Loss: 0.593200


100%|██████████| 40/40 [01:16<00:00,  1.91s/it]
100%|██████████| 5/5 [00:04<00:00,  1.10it/s]


Valid Loss: 0.5915505052
Epoch [7/20]:- Train Loss: 0.560234 | Valid Loss: 0.591551


100%|██████████| 40/40 [01:15<00:00,  1.88s/it]
100%|██████████| 5/5 [00:04<00:00,  1.21it/s]


Valid Loss: 0.5905889034
Epoch [8/20]:- Train Loss: 0.559192 | Valid Loss: 0.590589


100%|██████████| 40/40 [01:16<00:00,  1.92s/it]
100%|██████████| 5/5 [00:04<00:00,  1.08it/s]


Valid Loss: 0.5899730921
Epoch [9/20]:- Train Loss: 0.558509 | Valid Loss: 0.589973


100%|██████████| 40/40 [01:15<00:00,  1.88s/it]
100%|██████████| 5/5 [00:03<00:00,  1.27it/s]


Valid Loss: 0.5895436049
Epoch [10/20]:- Train Loss: 0.558011 | Valid Loss: 0.589544


100%|██████████| 40/40 [01:14<00:00,  1.86s/it]
100%|██████████| 5/5 [00:04<00:00,  1.23it/s]


Valid Loss: 0.5892235994
Epoch [11/20]:- Train Loss: 0.557617 | Valid Loss: 0.589224


100%|██████████| 40/40 [01:16<00:00,  1.92s/it]
100%|██████████| 5/5 [00:04<00:00,  1.18it/s]


Valid Loss: 0.5889723420
Epoch [12/20]:- Train Loss: 0.557289 | Valid Loss: 0.588972


100%|██████████| 40/40 [01:15<00:00,  1.88s/it]
100%|██████████| 5/5 [00:04<00:00,  1.15it/s]


Valid Loss: 0.5887678266
Epoch [13/20]:- Train Loss: 0.557005 | Valid Loss: 0.588768


100%|██████████| 40/40 [01:16<00:00,  1.90s/it]
100%|██████████| 5/5 [00:04<00:00,  1.23it/s]


Valid Loss: 0.5885965943
Epoch [14/20]:- Train Loss: 0.556755 | Valid Loss: 0.588597


100%|██████████| 40/40 [01:15<00:00,  1.90s/it]
100%|██████████| 5/5 [00:04<00:00,  1.16it/s]


Valid Loss: 0.5884498835
Epoch [15/20]:- Train Loss: 0.556531 | Valid Loss: 0.588450


100%|██████████| 40/40 [01:14<00:00,  1.86s/it]
100%|██████████| 5/5 [00:04<00:00,  1.24it/s]


Valid Loss: 0.5883222699
Epoch [16/20]:- Train Loss: 0.556329 | Valid Loss: 0.588322


100%|██████████| 40/40 [01:14<00:00,  1.86s/it]
100%|██████████| 5/5 [00:04<00:00,  1.15it/s]


Valid Loss: 0.5882100105
Epoch [17/20]:- Train Loss: 0.556145 | Valid Loss: 0.588210


100%|██████████| 40/40 [01:15<00:00,  1.88s/it]
100%|██████████| 5/5 [00:04<00:00,  1.15it/s]


Valid Loss: 0.5881092787
Epoch [18/20]:- Train Loss: 0.555975 | Valid Loss: 0.588109


100%|██████████| 40/40 [01:15<00:00,  1.88s/it]
100%|██████████| 5/5 [00:04<00:00,  1.24it/s]


Valid Loss: 0.5880192041
Epoch [19/20]:- Train Loss: 0.555820 | Valid Loss: 0.588019


100%|██████████| 40/40 [01:14<00:00,  1.86s/it]
100%|██████████| 5/5 [00:04<00:00,  1.03it/s]


Valid Loss: 0.5879366636
Epoch [20/20]:- Train Loss: 0.555675 | Valid Loss: 0.587937


In [None]:
# Held-out evaluation with the final (last-epoch) weights held in `model`.
test(model, test_dataloader=test_dl)

100%|██████████| 4/4 [00:03<00:00,  1.11it/s]

Test Loss: 0.585270



