In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F

from sklearn.metrics import classification_report

import numpy as np

from sentence_transformers import SentenceTransformer
from transformers import CLIPModel, CLIPProcessor, \
                            DistilBertModel, DistilBertTokenizerFast, \
                            GPT2Tokenizer, GPT2Model, \
                            RobertaTokenizer, RobertaModel, \
                            AutoTokenizer, AutoModelForSequenceClassification, \
                            pipeline
# import google.generativeai as genai
#genai.configure(api_key="")

from torcheeg.models import EEGNet

from tqdm import tqdm

from EEGDataset import EEGDataset, WordEEGDataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
MODEL_NAME = "openai/clip-vit-base-patch32"
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

# Dataset

In [3]:
# ds = WordEEGDataset("shards", pad_upto=200)
ds = EEGDataset("shards", pad_upto=4000, crp_rng=(0,1))
ds[0][0].shape, ds[0][1], ds[0][2]

(torch.Size([1076, 4000]),
 '0',
 'Presents a good case while failing to provide a reason for us to care beyond the very basic dictums of human decency.')

In [4]:
# train_ds = WordEEGDataset("shards", pad_upto=200, selective_indexing=list(range(274)))
# val_ds = WordEEGDataset("shards", pad_upto=200, selective_indexing=list(range(274,332)))
# test_ds = WordEEGDataset("shards", pad_upto=200, selective_indexing=list(range(332,392)))

# train_ds = WordEEGDataset("shards", pad_upto=200, selective_indexing=list(range(40)))
# val_ds = WordEEGDataset("shards", pad_upto=200, selective_indexing=list(range(40,45)))
# test_ds = WordEEGDataset("shards", pad_upto=200, selective_indexing=list(range(45,49)))

In [None]:
train_ds, val_ds, test_ds = ds.split_train_valid_test(train_ratio=0.7, valid_ratio=0.15, shuffle=False)

train_dl = train_ds.getLoader(batch_size=25, num_workers=0)
val_dl = val_ds.getLoader(batch_size=25, num_workers=0)
test_dl = test_ds.getLoader(batch_size=25, num_workers=0)

len(train_ds), len(val_ds), len(test_ds)

(1400, 300, 300)

# Text Model Testing

In [3]:
class TextEncoder:
    def __init__(self):
        # self.model = SentenceTransformer("all-MiniLM-L6-v2").to(device)

        # self.tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
        # self.model = DistilBertModel.from_pretrained("distilbert-base-uncased")

        # self.tokenizer = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        # self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")

        # self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        # self.model = GPT2Model.from_pretrained("gpt2")

        self.model = SentenceTransformer("paraphrase-mpnet-base-v2")

        # self.tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
        # self.model = RobertaModel.from_pretrained("roberta-base")

        # model_name = "j-hartmann/sentiment-roberta-large-english-3-classes"
        # self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # self.model = AutoModelForSequenceClassification.from_pretrained(model_name)

        pass

    def encode(self, txts):
        # all-MiniLM-L6-v2
        # return torch.tensor(self.model.encode(txts))
    
        # distilbert-base-uncased
        # inputs = self.tokenizer(
        #     txts,
        #     truncation=True,
        #     padding=True,
        #     max_length=256,
        #     return_tensors="pt"
        # )
        # return self.model(**inputs).last_hidden_state[:,0]

        # openai/clip-vit-base-patch32
        # inputs = self.tokenizer(
        #     text=txts,
        #     padding=True,
        #     return_tensors="pt",
        # )
        # return self.model.get_text_features(**inputs)

        # gpt2
        # inputs = self.tokenizer(
        #     txts,
        #     truncation=True,
        #     return_tensors="pt",
        # )
        # last_hidden = self.model(**inputs).last_hidden_state
        # return last_hidden[torch.arange(last_hidden.size(0)), inputs['attention_mask'].sum(1)-1]
        # # attention_mask = inputs['attention_mask'].unsqueeze(-1)
        # # text_embeds_mean = (last_hidden * attention_mask).sum(1) / attention_mask.sum(1)
        # # return text_embeds_mean / text_embeds_mean.norm(dim=-1, keepdim=True)

        # paraphrase-mpnet-base-v2
        return torch.tensor(self.model.encode(txts))

        # Google Gemini
        # res = genai.embed_content(
        #     model="models/text-embedding-004",
        #     content=txts,
        #     task_type="retrieval_document"
        # )["embedding"]
        # return torch.tensor(res)

        # Roberta
        # inputs = self.tokenizer(txts, return_tensors="pt")
        # token_embeddings = self.model(**inputs).last_hidden_state
        # mask = inputs["attention_mask"].unsqueeze(-1).expand(token_embeddings.size()).float()
        # summed = torch.sum(token_embeddings * mask, dim=1)
        # counted = torch.clamp(mask.sum(dim=1), min=1e-9)
        # return summed / counted

        # Roberta 3 Class
        # inputs = self.tokenizer(
        #     txts,
        #     padding=True,
        #     truncation=True,
        #     return_tensors="pt"
        # )
        # outputs = self.model.roberta(**inputs, output_hidden_states=True)
        # last_hidden = outputs.hidden_states[-1]
        # return last_hidden[:, 0, :]

model = TextEncoder()

In [5]:
target_embeddings = model.encode(["Negative", "Neutral", "Positive"])
target_embeddings.shape, target_embeddings.dtype, type(target_embeddings)

(torch.Size([3, 768]), torch.float32, torch.Tensor)

In [10]:
lbls = []
txts = []
txt_embeds = []
for i in tqdm(range(100)):
    _, sent, txt = ds[i]
    if _ is None:
        continue
    lbls.append(int(sent)+1)
    txts.append(txt)
    txt_embeds.append(model.encode([txt]))

preds = torch.cat(txt_embeds, dim=0)

# cat_norm = target_embeddings.clone().detach()
# text_norm = preds.clone().detach()
# diffs = torch.cdist(text_norm, cat_norm, p=2)
# pred = diffs.argmin(dim=1)

cat_norm = F.normalize(target_embeddings, p=2, dim=1)
text_norm = F.normalize(preds, p=2, dim=1)
similarities = torch.matmul(text_norm, cat_norm.T)
pred = similarities.argmax(dim=1)

print(classification_report(lbls, pred))

100%|██████████| 100/100 [00:01<00:00, 64.24it/s]

              precision    recall  f1-score   support

           0       0.57      0.64      0.61        36
           1       0.34      0.35      0.35        34
           2       0.60      0.50      0.55        30

    accuracy                           0.50       100
   macro avg       0.51      0.50      0.50       100
weighted avg       0.50      0.50      0.50       100






# Dataset Testing

In [20]:
train_ds[0][0].shape, train_ds[0][1], len(train_ds[0][2]), train_ds[0][2]

(torch.Size([1076, 10000]),
 '0',
 117,
 'Presents a good case while failing to provide a reason for us to care beyond the very basic dictums of human decency.')

In [21]:
val_ds[0][0].shape, len(val_ds[0][2])
# val_ds[0][1][0]

(torch.Size([1076, 10000]), 120)

In [None]:
# embeddings = model.encode([ds[0][1]])
embeddings = model.encode([train_ds[0][1]])
embeddings.shape, embeddings.dtype, type(embeddings)

((1, 384), dtype('float32'), numpy.ndarray)

In [5]:
ds[0][0].shape

for batch_data, batch_sent, batch_labels in train_dl:
    print(len(batch_data), batch_data.shape)
    print(len(batch_sent), batch_sent.shape)
    print(len(batch_labels))
    break

10 torch.Size([10, 1076, 10000])
10 torch.Size([10])
10


# Models

In [None]:
class TextEncoder(nn.Module):
    def __init__(self):
        super(TextEncoder, self).__init__()
        # self.text_encoder_model = SentenceTransformer("all-MiniLM-L6-v2")
        # self.text_encoder_model = CLIPModel.from_pretrained(MODEL_NAME)
        # self.processor = CLIPProcessor.from_pretrained(MODEL_NAME)
        self.text_encoder_model = SentenceTransformer("paraphrase-mpnet-base-v2")

    def forward(self, texts):
        # inputs = self.processor(text=texts, return_tensors="pt", padding=True)
        # inputs = {k: v.to(self.text_encoder_model.device) for k, v in inputs.items()}
        # embeddings = self.text_encoder_model.get_text_features(**inputs)
        
        embeddings = self.text_encoder_model.encode(texts, convert_to_tensor=True)

        embeddings = F.normalize(embeddings, p=2, dim=1)
        return embeddings
    
class EEGPatchEmbedding(nn.Module):
    def __init__(self, n_channels=1076, patch_size=20, embed_dim=384):
        super().__init__()
        self.patch_size = patch_size
        dropprob = 0.3
        self.proj = nn.Sequential(
            nn.Linear(n_channels * patch_size, (n_channels * patch_size)//2),
            nn.LeakyReLU(),
            nn.Dropout(dropprob),
            nn.Linear((n_channels * patch_size)//2,embed_dim)
        )

    def forward(self, x):
        # x: (B, C, T)
        B, C, T = x.shape
        x = x[:, :, :T - (T % self.patch_size)]  # trim
        x = x.view(B, C, -1, self.patch_size)    # (B, C, tokens, patch)
        x = x.permute(0, 2, 1, 3)                 # (B, tokens, C, patch)
        x = x.flatten(2)                          # (B, tokens, C*patch)
        return self.proj(x)                       # (B, tokens, embed_dim)

class EEGTransformer(nn.Module):
    def __init__(
        self,
        n_channels=1076,
        patch_size=20,
        embed_dim=384,
        num_layers=4,
        num_heads=8
    ):
        super().__init__()

        dropprob = 0.3

        self.patch_embed = EEGPatchEmbedding(
            n_channels, patch_size, embed_dim
        )

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=4 * embed_dim,
            dropout=dropprob,
            batch_first=True
        )

        self.transformer = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers
        )

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

    def forward(self, x):
        # x: (B, C, T)
        B = x.size(0)

        x = self.patch_embed(x)   # (B, tokens, embed_dim)

        cls = self.cls_token.expand(B, -1, -1) # FIXME:
        x = torch.cat([cls, x], dim=1)

        x = self.transformer(x)
        cls_out = x[:, 0]

        return cls_out


class EEGEncoder(nn.Module):
    def __init__(self, ch_count=8196, embedding_dim=384):
        super(EEGEncoder, self).__init__()

        dropprob = .3

        # self.temporal = nn.Sequential(
        #     nn.Conv1d(ch_count, 1024, 64, padding=1),
        #     nn.LeakyReLU(),
        #     nn.Dropout(dropprob),
        #     nn.Conv1d(1024, 512, 32, padding=1),
        #     nn.LeakyReLU(),
        #     nn.Dropout(dropprob),
        #     nn.Conv1d(512, 256, 32, padding=1),
        #     nn.LeakyReLU(),
        #     nn.Dropout(dropprob),
        #     # nn.AdaptiveAvgPool2d((256, 1))
        #     nn.AdaptiveAvgPool1d((1))
        # )

        # self.fc = nn.Sequential(
        #     nn.Linear(256, embedding_dim//2),
        #     nn.LeakyReLU(),
        #     nn.Dropout(dropprob),
        #     nn.Linear(embedding_dim//2, embedding_dim)
        # )

        # self.eeg_encoder = EEGNet(
        #     chunk_size=2000,
        #     num_electrodes=ch_count,
        #     dropout=dropprob,
        #     kernel_1=64,
        #     kernel_2=16,
        #     F1=8,
        #     F2=16,
        #     D=2,
        #     num_classes=embedding_dim
        # )

        self.eeg_transformer = EEGTransformer(
            n_channels=ch_count,
            patch_size=20,
            embed_dim=embedding_dim,
            num_layers=4,
            num_heads=8,
        )

    def compute_power_bands(self, x):
        fs = 500
        
        # eeg: (N, C, T)
        freqs = torch.fft.rfftfreq(x.size(-1), 1/fs).to(x.device)  # (F,)
        fft_vals = torch.fft.rfft(x, dim=-1)                         # (N, C, F)
        psd = (fft_vals.abs()**2)

        bands = [(0.5,4), (4,8), (8,12), (12,30), (30,49)]
        feats = []
        for low, high in bands:
            idx = (freqs >= low) & (freqs < high)
            band = psd[..., idx].mean(dim=-1)   # (N, C)
            feats.append(band)

        return torch.stack(feats, dim=2)        # (N, C, P)

    def forward(self, x):
        x = torch.fft.rfft(x, dim=2)
        x = torch.log(torch.abs(x) + 1e-8)

        # x = self.compute_power_bands(x)

        # x = self.temporal(x).squeeze(-1)
        # x = self.fc(x)

        # x = x.unsqueeze(1)
        # x = self.eeg_encoder(x)

        x = self.eeg_transformer(x)

        x = F.normalize(x, p=2, dim=1)
        # x = torch.tanh(x) * 3
        return x
    
class EEGClassifier(nn.Module):
    def __init__(self, embedding_dim=384):
        super(EEGClassifier, self).__init__()

        dropprob = .3

        self.classifier = nn.Sequential(
            nn.Linear(embedding_dim, 128),
            nn.LeakyReLU(),
            nn.Dropout(dropprob),
            nn.Linear(128, 16),
            nn.LeakyReLU(),
            nn.Dropout(dropprob),
            nn.Linear(16, 3)
        )

    def forward(self, x):
        return self.classifier(x)
    
class EEGCLIPModel(nn.Module):
    def __init__(self, ch_count=8196, embedding_dim=384, freeze_text=True):
        super(EEGCLIPModel, self).__init__()
        self.text_encoder = TextEncoder()
        self.eeg_encoder = EEGEncoder(ch_count=ch_count, embedding_dim=embedding_dim)
        self.eeg_classifier = EEGClassifier(embedding_dim=embedding_dim)

        if freeze_text:
            for param in self.text_encoder.parameters():
                param.requires_grad = False

    def forward(self, eeg_data, texts):
        eeg_embeddings = self.eeg_encoder(eeg_data)
        text_embeddings = self.text_encoder(texts)
        classification = self.eeg_classifier(eeg_embeddings)
        return eeg_embeddings, text_embeddings, classification

# Training

In [7]:
# model = EEGCLIPModel().to(device)
model = EEGCLIPModel(1076, 768, freeze_text=True).to(device)
# model.load_state_dict(torch.load("last_model.pt"))

In [8]:
eeg, sent, text = ds[0]
res = model(eeg.unsqueeze(0).to(torch.float32).to("cuda"), [text])

res[0].shape, res[1].shape, res[2]

(torch.Size([1, 768]),
 torch.Size([1, 768]),
 tensor([[-0.1714,  0.2153,  0.0996]], device='cuda:0',
        grad_fn=<AddmmBackward0>))

In [None]:
# Text Embedding Normalization Space
all_embeds = model.text_encoder([lbl[1] for lbl in ds.labels])

text_mean = all_embeds.mean(dim=0)
text_std = all_embeds.std(dim=0).clamp(min=1e-6)

text_mean, text_std

In [None]:
# Classification Classes
classification_classes = ["Positive", "Neutral", "Negative"]
classification_classes_embeds = model.text_encoder(classification_classes)
classification_classes_embeds_z = (classification_classes_embeds-text_mean.unsqueeze(0))/text_std.unsqueeze(0)
classification_classes_embeds_z.shape

In [8]:
tau, alpha, beta = 1, .7, 1.3

def train(model: nn.Module, train_loader: DataLoader, valid_loader: DataLoader, epochs: int = 10):
# def train(model: nn.Module, train_dataset: WordEEGDataset, valid_dataset: WordEEGDataset, epochs: int = 10):
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()

    best_valid_loss = None

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        train_count = 0
        train_correct = 0
        train_total_count = 0
        for batch in tqdm(train_loader):
            if batch[0] is None:
                continue
        # for batch in tqdm(train_dataset):
            # if isinstance(train_dataset, EEGDataset):
            #     batch = (batch[0].unsqueeze(0), [batch[1]])
            # batch = ([batch[0]], [batch[1]])
            B = batch[0].shape[0]
            eeg_data = batch[0].to(torch.float32).clone().detach().to(device) # (B,C,T)
            sent_lbl = batch[1].to(torch.int64).clone().detach().to(device) # (B)
            texts = batch[2] # [""]

            optimizer.zero_grad()
            eeg_embeddings, text_embeddings, sent_logits = model(eeg_data, texts) # (B,D)
            sent_logits: torch.Tensor

            loss = 0
            loss += 1.0 - F.cosine_similarity(eeg_embeddings, text_embeddings).mean()
            loss += loss_fn(sent_logits, sent_lbl) / 2

            sent_mask = sent_lbl.view(-1,1) == sent_lbl.view(1,-1)
            dist = torch.cdist(eeg_embeddings, eeg_embeddings, p=2)
            positive_losses = torch.relu(dist-tau*alpha)
            negative_losses = torch.relu(tau*beta-dist)
            positive_mask = torch.tril(sent_mask == True, -1)
            negative_mask = torch.tril(sent_mask == False, -1)
            if positive_mask.sum().item() == 0:
                positive_loss = 0
            else:
                positive_loss = (positive_losses[positive_mask]).sum().item() / (positive_mask.sum().item())
            if negative_mask.sum().item() == 0:
                negative_loss = 0
            else:
                negative_loss = (negative_losses[negative_mask]).sum().item() / (negative_mask.sum().item())
            loss += positive_loss + negative_loss # clip loss

            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            train_count += 1

            sent_probs = sent_logits.softmax(dim=1)
            sent_preds = sent_probs.argmax(dim=1)
            train_correct += (sent_lbl == sent_preds).sum().item()
            train_total_count += B

            # for i in range(len(batch[1])):
            #     eeg_data = batch[0][i].to(torch.float32).to(device)
            #     texts = batch[1][i]

            #     optimizer.zero_grad()
            #     eeg_embeddings, text_embeddings = model(eeg_data, texts)

            #     # text_z = (text_embeddings-text_mean.unsqueeze(0))/text_std.unsqueeze(0)
            #     # loss = loss_fn(eeg_embeddings, text_z)
            #     loss = 1.0 - F.cosine_similarity(eeg_embeddings, text_embeddings).mean()
            #     loss.backward()
            #     optimizer.step()

            #     total_loss += loss.item()
            #     train_count += 1

            #     # Accuracy Check
            #     # text_norm = F.normalize(text_embeddings, p=2, dim=1)
            #     # eeg_norm = F.normalize(eeg_embeddings, p=2, dim=1)
            #     # classes_norm = F.normalize(classification_classes_embeds, p=2, dim=1)
            #     sim_ground = text_embeddings @ classification_classes_embeds.T
            #     sim_pred = eeg_embeddings @ classification_classes_embeds.T

            #     # sim_ground = -torch.cdist(text_z.clone().detach(), classification_classes_embeds_z, p=2)
            #     # sim_pred = -torch.cdist(eeg_embeddings.clone().detach(), classification_classes_embeds_z, p=2)

            #     ground = sim_ground.argmax(dim=1)
            #     pred = sim_pred.argmax(dim=1)
            #     train_correct += (ground == pred).sum().item()
            #     train_total_count += ground.shape[0]


        # avg_loss = total_loss / len(train_loader)
        avg_loss = total_loss / train_count
        avg_acc = train_correct / train_total_count


        model.eval()
        total_valid_loss = 0.0
        valid_count = 0
        valid_correct = 0
        valid_total_count = 0
        with torch.inference_mode():
            for batch in tqdm(valid_loader):
                if batch[0] is None:
                    continue
            # for batch in tqdm(valid_dataset):
                # if isinstance(valid_dataset, EEGDataset):
                #     batch = (batch[0].unsqueeze(0), [batch[1]])
                # batch = ([batch[0]], [batch[1]])
                B = batch[0].shape[0]
                eeg_data = batch[0].to(torch.float32).clone().detach().to(device)
                sent_lbl = batch[1].to(torch.int64).clone().detach().to(device)
                texts = batch[2]

                eeg_embeddings, text_embeddings, sent_logits = model(eeg_data, texts)
                sent_logits: torch.Tensor

                loss = 0
                loss += 1.0 - F.cosine_similarity(eeg_embeddings, text_embeddings).mean()
                loss += loss_fn(sent_logits, sent_lbl) / 2

                sent_mask = sent_lbl.view(-1,1) == sent_lbl.view(1,-1)
                dist = torch.cdist(eeg_embeddings, eeg_embeddings, p=2)
                positive_losses = torch.relu(dist-tau*alpha)
                negative_losses = torch.relu(tau*beta-dist)
                positive_mask = torch.tril(sent_mask == True, -1)
                negative_mask = torch.tril(sent_mask == False, -1)
                positive_loss = (positive_losses[positive_mask]).sum().item() / (positive_mask.sum().item())
                negative_loss = (negative_losses[negative_mask]).sum().item() / (negative_mask.sum().item())
                loss += positive_loss + negative_loss # clip loss

                total_valid_loss += loss.item()
                valid_count += 1

                sent_probs = sent_logits.softmax(dim=1)
                sent_preds = sent_probs.argmax(dim=1)
                valid_correct += (sent_lbl == sent_preds).sum().item()
                valid_total_count += B

                # for i in range(len(batch[0])):
                #     eeg_data = batch[0][i].to(torch.float32).to(device)
                #     texts = batch[1][i]

                #     eeg_embeddings, text_embeddings = model(eeg_data, texts)

                #     # text_z = (text_embeddings-text_mean.unsqueeze(0))/text_std.unsqueeze(0)
                #     # loss = loss_fn(eeg_embeddings, text_z)
                #     loss = 1.0 - F.cosine_similarity(eeg_embeddings, text_embeddings).mean()

                #     total_valid_loss += loss.item()
                #     valid_count += 1

                #     # Accuracy Check
                #     # text_norm = F.normalize(text_embeddings, p=2, dim=1)
                #     # eeg_norm = F.normalize(eeg_embeddings, p=2, dim=1)
                #     # classes_norm = F.normalize(classification_classes_embeds, p=2, dim=1)
                #     sim_ground = text_embeddings @ classification_classes_embeds.T
                #     sim_pred = eeg_embeddings @ classification_classes_embeds.T

                #     # sim_ground = -torch.cdist(text_z.clone().detach(), classification_classes_embeds_z, p=2)
                #     # sim_pred = -torch.cdist(eeg_embeddings.clone().detach(), classification_classes_embeds_z, p=2)

                #     ground = sim_ground.argmax(dim=1)
                #     pred = sim_pred.argmax(dim=1)
                #     valid_correct += (ground == pred).sum().item()
                #     valid_total_count += ground.shape[0]
                    
        # avg_valid_loss = total_valid_loss / len(valid_loader)
        avg_valid_loss = total_valid_loss / valid_count
        avg_valid_acc = valid_correct / valid_total_count

        if (best_valid_loss is None) or (avg_valid_loss < best_valid_loss):
            print(f"Valid Loss: {avg_valid_loss:.10f}")
            best_valid_loss = avg_valid_loss
            torch.save(model.state_dict(), "best_model.pt")

        print(f"Epoch [{epoch+1}/{epochs}]:- Train Loss: {avg_loss:.6f} | Train Acc: {avg_acc*100:.4f}% | Valid Loss: {avg_valid_loss:.6f} | Valid Acc: {avg_valid_acc*100:.4f}%")
        torch.save(model.state_dict(), "last_model.pt")    

        torch.cuda.empty_cache()


In [9]:
classification_classes = ["Positive", "Neutral", "Negative"]

def test(model: nn.Module, test_loader: DataLoader):
# def test(model: nn.Module, test_dataset: WordEEGDataset):
    loss_fn = nn.CrossEntropyLoss()

    model.eval()
    total_loss = 0.0
    count = 0
    correct = 0
    total_count = 0

    actuals = []
    preds = []
    with torch.inference_mode():
        for batch in tqdm(test_loader):
        # for batch in tqdm(test_dataset):
            # if isinstance(test_dataset, EEGDataset):
            #     batch = (batch[0].unsqueeze(0), [batch[1]])
            # batch = ([batch[0]], [batch[1]])
            B = batch[0].shape[0]
            eeg_data = batch[0].to(torch.float32).clone().detach().to(device)
            sent_lbl = batch[1].to(torch.int64).clone().detach().to(device)
            texts = batch[2]

            eeg_embeddings, text_embeddings, sent_logits = model(eeg_data, texts)
            sent_logits: torch.Tensor

            loss = 0
            loss += 1.0 - F.cosine_similarity(eeg_embeddings, text_embeddings).mean()
            loss += loss_fn(sent_logits, sent_lbl)

            sent_mask = sent_lbl.view(-1,1) == sent_lbl.view(1,-1)
            dist = torch.cdist(eeg_embeddings, eeg_embeddings, p=2)
            positive_losses = torch.relu(dist-tau*alpha)
            negative_losses = torch.relu(tau*beta-dist)
            positive_mask = torch.tril(sent_mask == True, -1)
            negative_mask = torch.tril(sent_mask == False, -1)
            positive_loss = (positive_losses[positive_mask]).sum().item() / (positive_mask.sum().item())
            negative_loss = (negative_losses[negative_mask]).sum().item() / (negative_mask.sum().item())
            loss += positive_loss + negative_loss # clip loss

            total_loss += loss.item()
            count += 1

            sent_probs = sent_logits.softmax(dim=1)
            sent_preds = sent_probs.argmax(dim=1)
            correct += (sent_lbl == sent_preds).sum().item()
            total_count += B

            actuals.append(sent_lbl)
            preds.append(sent_preds)
            # for i in range(len(batch[0])):
            #     eeg_data = batch[0][i].to(torch.float32).to(device)
            #     texts = batch[1][i]

            #     eeg_embeddings, text_embeddings = model(eeg_data, texts)

            #     # text_z = (text_embeddings-text_mean.unsqueeze(0))/text_std.unsqueeze(0)
            #     # loss = loss_fn(eeg_embeddings, text_z)
            #     loss = 1.0 - F.cosine_similarity(eeg_embeddings, text_embeddings).mean()

            #     total_loss += loss.item()
            #     count += 1

            #     # Accuracy Check
            #     # text_norm = F.normalize(text_embeddings, p=2, dim=1)
            #     # eeg_norm = F.normalize(eeg_embeddings, p=2, dim=1)
            #     # classes_norm = F.normalize(classification_classes_embeds, p=2, dim=1)
            #     sim_ground = text_embeddings @ classification_classes_embeds.T
            #     sim_pred = eeg_embeddings @ classification_classes_embeds.T

            #     # sim_ground = -torch.cdist(text_z.clone().detach(), classification_classes_embeds_z, p=2)
            #     # sim_pred = -torch.cdist(eeg_embeddings.clone().detach(), classification_classes_embeds_z, p=2)

            #     ground = sim_ground.argmax(dim=1)
            #     pred = sim_pred.argmax(dim=1)
            #     correct += (ground == pred).sum().item()
            #     total_count += ground.shape[0]
            #     actuals.append(ground)
            #     preds.append(pred)

    avg_loss = total_loss / count
    avg_acc = correct / total_count
    print(f"Test Loss: {avg_loss:.6f} | Test Acc: {avg_acc*100:.4f}%")

    actuals = torch.cat(actuals, dim=0)
    preds = torch.cat(preds, dim=0)
    cr = classification_report(actuals.cpu(), preds.cpu(), labels=[0,1,2], target_names=classification_classes)
    print(cr)

In [10]:
train(model, train_dl, val_dl, epochs=5)
# train(model, train_ds, val_ds, epochs=3)
# train(model, ds, epochs=20)

  0%|          | 0/1400 [00:00<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 884.00 MiB. GPU 0 has a total capacity of 3.63 GiB of which 258.69 MiB is free. Including non-PyTorch memory, this process has 3.37 GiB memory in use. Of the allocated memory 3.20 GiB is allocated by PyTorch, and 86.57 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [11]:
train(model, train_dl, val_dl, epochs=15)

100%|██████████| 124/124 [05:52<00:00,  2.85s/it]
100%|██████████| 27/27 [00:31<00:00,  1.18s/it]


Valid Loss: 1.0845134126
Epoch [1/15]:- Train Loss: 1.043077 | Train Acc: 33.1119% | Valid Loss: 1.084513 | Valid Acc: 38.8802%


100%|██████████| 124/124 [05:52<00:00,  2.85s/it]
100%|██████████| 27/27 [00:31<00:00,  1.17s/it]


Epoch [2/15]:- Train Loss: 1.034578 | Train Acc: 34.6853% | Valid Loss: 1.087936 | Valid Acc: 36.0809%


100%|██████████| 124/124 [05:54<00:00,  2.86s/it]
100%|██████████| 27/27 [00:31<00:00,  1.18s/it]


Valid Loss: 1.0762624741
Epoch [3/15]:- Train Loss: 1.028615 | Train Acc: 33.9161% | Valid Loss: 1.076262 | Valid Acc: 36.0809%


100%|██████████| 124/124 [05:54<00:00,  2.86s/it]
100%|██████████| 27/27 [00:31<00:00,  1.18s/it]


Valid Loss: 1.0753319308
Epoch [4/15]:- Train Loss: 1.022034 | Train Acc: 35.3147% | Valid Loss: 1.075332 | Valid Acc: 36.0809%


100%|██████████| 124/124 [05:53<00:00,  2.85s/it]
100%|██████████| 27/27 [00:32<00:00,  1.19s/it]


Epoch [5/15]:- Train Loss: 1.017642 | Train Acc: 36.1538% | Valid Loss: 1.080424 | Valid Acc: 36.0809%


100%|██████████| 124/124 [05:52<00:00,  2.84s/it]
100%|██████████| 27/27 [00:32<00:00,  1.19s/it]


Epoch [6/15]:- Train Loss: 1.012584 | Train Acc: 35.9091% | Valid Loss: 1.079711 | Valid Acc: 36.3919%


100%|██████████| 124/124 [05:53<00:00,  2.85s/it]
100%|██████████| 27/27 [00:31<00:00,  1.17s/it]


Epoch [7/15]:- Train Loss: 1.011112 | Train Acc: 35.2797% | Valid Loss: 1.953216 | Valid Acc: 33.9036%


100%|██████████| 124/124 [05:54<00:00,  2.86s/it]
100%|██████████| 27/27 [00:31<00:00,  1.18s/it]


Epoch [8/15]:- Train Loss: 1.007620 | Train Acc: 37.3776% | Valid Loss: 1.701506 | Valid Acc: 33.7481%


100%|██████████| 124/124 [05:51<00:00,  2.84s/it]
100%|██████████| 27/27 [00:31<00:00,  1.17s/it]


Epoch [9/15]:- Train Loss: 1.004196 | Train Acc: 37.4126% | Valid Loss: 3.737535 | Valid Acc: 33.7481%


100%|██████████| 124/124 [05:53<00:00,  2.85s/it]
100%|██████████| 27/27 [00:32<00:00,  1.20s/it]


Epoch [10/15]:- Train Loss: 1.006060 | Train Acc: 38.1818% | Valid Loss: 1.136793 | Valid Acc: 34.5257%


  0%|          | 0/124 [00:04<?, ?it/s]


KeyboardInterrupt: 

In [11]:
train(model, train_dl, val_dl, epochs=50)

100%|██████████| 124/124 [05:55<00:00,  2.87s/it]
100%|██████████| 27/27 [00:31<00:00,  1.18s/it]


Valid Loss: 1.2769214665
Epoch [1/50]:- Train Loss: 1.002576 | Train Acc: 38.6713% | Valid Loss: 1.276921 | Valid Acc: 35.1477%


100%|██████████| 124/124 [05:53<00:00,  2.85s/it]
100%|██████████| 27/27 [00:31<00:00,  1.18s/it]


Epoch [2/50]:- Train Loss: 0.993582 | Train Acc: 39.6154% | Valid Loss: 1.307159 | Valid Acc: 33.4370%


100%|██████████| 124/124 [05:55<00:00,  2.87s/it]
100%|██████████| 27/27 [00:31<00:00,  1.18s/it]


Epoch [3/50]:- Train Loss: 0.987603 | Train Acc: 40.3147% | Valid Loss: 1.920348 | Valid Acc: 33.2815%


100%|██████████| 124/124 [05:54<00:00,  2.86s/it]
100%|██████████| 27/27 [00:31<00:00,  1.18s/it]


Valid Loss: 1.1193668268
Epoch [4/50]:- Train Loss: 0.984271 | Train Acc: 39.6853% | Valid Loss: 1.119367 | Valid Acc: 34.5257%


100%|██████████| 124/124 [05:57<00:00,  2.88s/it]
100%|██████████| 27/27 [00:31<00:00,  1.18s/it]


Valid Loss: 1.0912917014
Epoch [5/50]:- Train Loss: 0.985296 | Train Acc: 39.1958% | Valid Loss: 1.091292 | Valid Acc: 34.6812%


100%|██████████| 124/124 [05:54<00:00,  2.86s/it]
100%|██████████| 27/27 [00:31<00:00,  1.18s/it]


Epoch [6/50]:- Train Loss: 0.982832 | Train Acc: 39.9650% | Valid Loss: 2.205383 | Valid Acc: 32.6594%


100%|██████████| 124/124 [05:54<00:00,  2.86s/it]
100%|██████████| 27/27 [00:31<00:00,  1.18s/it]


Epoch [7/50]:- Train Loss: 0.982558 | Train Acc: 40.9091% | Valid Loss: 1.279554 | Valid Acc: 33.9036%


100%|██████████| 124/124 [05:55<00:00,  2.87s/it]
100%|██████████| 27/27 [00:31<00:00,  1.17s/it]


Epoch [8/50]:- Train Loss: 0.976319 | Train Acc: 42.2028% | Valid Loss: 1.315772 | Valid Acc: 34.0591%


100%|██████████| 124/124 [05:51<00:00,  2.84s/it]
100%|██████████| 27/27 [00:31<00:00,  1.17s/it]


Epoch [9/50]:- Train Loss: 0.973568 | Train Acc: 40.4545% | Valid Loss: 4.887094 | Valid Acc: 32.1928%


100%|██████████| 124/124 [05:53<00:00,  2.85s/it]
100%|██████████| 27/27 [00:31<00:00,  1.17s/it]


Epoch [10/50]:- Train Loss: 0.969026 | Train Acc: 42.5874% | Valid Loss: 1.128946 | Valid Acc: 34.0591%


100%|██████████| 124/124 [05:53<00:00,  2.85s/it]
100%|██████████| 27/27 [00:31<00:00,  1.18s/it]


Epoch [11/50]:- Train Loss: 0.973182 | Train Acc: 41.7133% | Valid Loss: 1.157353 | Valid Acc: 34.8367%


100%|██████████| 124/124 [05:55<00:00,  2.87s/it]
100%|██████████| 27/27 [00:31<00:00,  1.18s/it]


Epoch [12/50]:- Train Loss: 0.968639 | Train Acc: 42.3077% | Valid Loss: 2.846274 | Valid Acc: 33.4370%


100%|██████████| 124/124 [05:53<00:00,  2.85s/it]
100%|██████████| 27/27 [00:31<00:00,  1.17s/it]


Epoch [13/50]:- Train Loss: 0.964137 | Train Acc: 42.3427% | Valid Loss: 1.275047 | Valid Acc: 33.2815%


100%|██████████| 124/124 [05:53<00:00,  2.85s/it]
100%|██████████| 27/27 [00:31<00:00,  1.17s/it]


Epoch [14/50]:- Train Loss: 0.960299 | Train Acc: 42.2727% | Valid Loss: 1.323111 | Valid Acc: 35.1477%


100%|██████████| 124/124 [05:52<00:00,  2.85s/it]
100%|██████████| 27/27 [00:31<00:00,  1.17s/it]


Epoch [15/50]:- Train Loss: 0.963413 | Train Acc: 42.7972% | Valid Loss: 1.231377 | Valid Acc: 34.9922%


100%|██████████| 124/124 [05:53<00:00,  2.85s/it]
100%|██████████| 27/27 [00:31<00:00,  1.18s/it]


Epoch [16/50]:- Train Loss: 0.957444 | Train Acc: 43.9510% | Valid Loss: 1.116117 | Valid Acc: 35.9253%


100%|██████████| 124/124 [05:53<00:00,  2.85s/it]
100%|██████████| 27/27 [00:31<00:00,  1.17s/it]


Epoch [17/50]:- Train Loss: 0.959664 | Train Acc: 44.7552% | Valid Loss: 6.610872 | Valid Acc: 30.3266%


100%|██████████| 124/124 [05:55<00:00,  2.87s/it]
100%|██████████| 27/27 [00:31<00:00,  1.18s/it]


Epoch [18/50]:- Train Loss: 0.953172 | Train Acc: 44.1259% | Valid Loss: 1.310657 | Valid Acc: 35.1477%


100%|██████████| 124/124 [05:53<00:00,  2.85s/it]
100%|██████████| 27/27 [00:31<00:00,  1.17s/it]


Epoch [19/50]:- Train Loss: 0.950727 | Train Acc: 46.3287% | Valid Loss: 2.098838 | Valid Acc: 35.7698%


100%|██████████| 124/124 [05:52<00:00,  2.85s/it]
100%|██████████| 27/27 [00:31<00:00,  1.19s/it]


Epoch [20/50]:- Train Loss: 0.950057 | Train Acc: 46.2937% | Valid Loss: 1.229017 | Valid Acc: 33.1260%


100%|██████████| 124/124 [05:52<00:00,  2.85s/it]
100%|██████████| 27/27 [00:31<00:00,  1.18s/it]


Epoch [21/50]:- Train Loss: 0.948867 | Train Acc: 47.1329% | Valid Loss: 1.370559 | Valid Acc: 32.3484%


100%|██████████| 124/124 [05:53<00:00,  2.85s/it]
100%|██████████| 27/27 [00:31<00:00,  1.17s/it]


Epoch [22/50]:- Train Loss: 0.942751 | Train Acc: 48.2517% | Valid Loss: 1.329297 | Valid Acc: 32.1928%


100%|██████████| 124/124 [05:54<00:00,  2.86s/it]
100%|██████████| 27/27 [00:31<00:00,  1.18s/it]


Epoch [23/50]:- Train Loss: 0.943595 | Train Acc: 49.3007% | Valid Loss: 1.206346 | Valid Acc: 34.9922%


100%|██████████| 124/124 [05:52<00:00,  2.84s/it]
100%|██████████| 27/27 [00:32<00:00,  1.19s/it]


Epoch [24/50]:- Train Loss: 0.937742 | Train Acc: 48.0769% | Valid Loss: 1.381094 | Valid Acc: 34.9922%


100%|██████████| 124/124 [05:54<00:00,  2.86s/it]
100%|██████████| 27/27 [00:31<00:00,  1.18s/it]


Epoch [25/50]:- Train Loss: 0.942156 | Train Acc: 47.5874% | Valid Loss: 1.532761 | Valid Acc: 37.0140%


100%|██████████| 124/124 [05:53<00:00,  2.85s/it]
100%|██████████| 27/27 [00:31<00:00,  1.17s/it]


Epoch [26/50]:- Train Loss: 0.936846 | Train Acc: 48.7063% | Valid Loss: 1.189788 | Valid Acc: 31.1042%


100%|██████████| 124/124 [05:55<00:00,  2.87s/it]
100%|██████████| 27/27 [00:31<00:00,  1.18s/it]


Epoch [27/50]:- Train Loss: 0.934703 | Train Acc: 49.5455% | Valid Loss: 1.244299 | Valid Acc: 33.2815%


100%|██████████| 124/124 [05:53<00:00,  2.85s/it]
100%|██████████| 27/27 [00:31<00:00,  1.17s/it]


Epoch [28/50]:- Train Loss: 0.934970 | Train Acc: 49.5105% | Valid Loss: 1.810498 | Valid Acc: 34.8367%


100%|██████████| 124/124 [05:55<00:00,  2.87s/it]
100%|██████████| 27/27 [00:32<00:00,  1.19s/it]


Epoch [29/50]:- Train Loss: 0.920532 | Train Acc: 51.6434% | Valid Loss: 1.994235 | Valid Acc: 32.1928%


100%|██████████| 124/124 [05:55<00:00,  2.86s/it]
100%|██████████| 27/27 [00:31<00:00,  1.17s/it]


Epoch [30/50]:- Train Loss: 0.928870 | Train Acc: 50.6643% | Valid Loss: 1.290923 | Valid Acc: 31.2597%


100%|██████████| 124/124 [05:51<00:00,  2.84s/it]
100%|██████████| 27/27 [00:31<00:00,  1.18s/it]


Epoch [31/50]:- Train Loss: 0.916833 | Train Acc: 50.4545% | Valid Loss: 5.974039 | Valid Acc: 33.1260%


100%|██████████| 124/124 [05:53<00:00,  2.85s/it]
100%|██████████| 27/27 [00:31<00:00,  1.16s/it]


Epoch [32/50]:- Train Loss: 0.920507 | Train Acc: 52.3077% | Valid Loss: 4.416480 | Valid Acc: 34.2146%


100%|██████████| 124/124 [05:54<00:00,  2.86s/it]
100%|██████████| 27/27 [00:32<00:00,  1.21s/it]


Epoch [33/50]:- Train Loss: 0.911925 | Train Acc: 52.4126% | Valid Loss: 1.400982 | Valid Acc: 33.9036%


100%|██████████| 124/124 [05:53<00:00,  2.85s/it]
100%|██████████| 27/27 [00:31<00:00,  1.17s/it]


Epoch [34/50]:- Train Loss: 0.910965 | Train Acc: 52.5175% | Valid Loss: 1.189604 | Valid Acc: 35.7698%


100%|██████████| 124/124 [05:52<00:00,  2.85s/it]
100%|██████████| 27/27 [00:31<00:00,  1.17s/it]


Epoch [35/50]:- Train Loss: 0.904407 | Train Acc: 52.4476% | Valid Loss: 2.702946 | Valid Acc: 34.0591%


100%|██████████| 124/124 [05:55<00:00,  2.87s/it]
100%|██████████| 27/27 [00:32<00:00,  1.19s/it]


Epoch [36/50]:- Train Loss: 0.917740 | Train Acc: 51.8531% | Valid Loss: 1.492539 | Valid Acc: 32.3484%


100%|██████████| 124/124 [05:53<00:00,  2.85s/it]
100%|██████████| 27/27 [00:31<00:00,  1.17s/it]


Epoch [37/50]:- Train Loss: 0.910485 | Train Acc: 53.2867% | Valid Loss: 3.005740 | Valid Acc: 32.0373%


100%|██████████| 124/124 [05:54<00:00,  2.86s/it]
100%|██████████| 27/27 [00:32<00:00,  1.19s/it]


Epoch [38/50]:- Train Loss: 0.907218 | Train Acc: 54.8252% | Valid Loss: 1.169268 | Valid Acc: 34.3701%


100%|██████████| 124/124 [05:53<00:00,  2.85s/it]
100%|██████████| 27/27 [00:31<00:00,  1.18s/it]


Epoch [39/50]:- Train Loss: 0.905443 | Train Acc: 54.4056% | Valid Loss: 1.130118 | Valid Acc: 35.3033%


100%|██████████| 124/124 [05:56<00:00,  2.87s/it]
100%|██████████| 27/27 [00:31<00:00,  1.18s/it]


Epoch [40/50]:- Train Loss: 0.894992 | Train Acc: 54.8601% | Valid Loss: 1.287345 | Valid Acc: 33.1260%


100%|██████████| 124/124 [05:52<00:00,  2.84s/it]
100%|██████████| 27/27 [00:31<00:00,  1.18s/it]


Epoch [41/50]:- Train Loss: 0.892747 | Train Acc: 55.2448% | Valid Loss: 1.276046 | Valid Acc: 32.1928%


100%|██████████| 124/124 [05:55<00:00,  2.87s/it]
100%|██████████| 27/27 [00:31<00:00,  1.18s/it]


Epoch [42/50]:- Train Loss: 0.889486 | Train Acc: 56.0839% | Valid Loss: 1.233976 | Valid Acc: 32.5039%


100%|██████████| 124/124 [05:52<00:00,  2.84s/it]
100%|██████████| 27/27 [00:32<00:00,  1.20s/it]


Epoch [43/50]:- Train Loss: 0.900088 | Train Acc: 55.4895% | Valid Loss: 1.396947 | Valid Acc: 33.7481%


100%|██████████| 124/124 [05:53<00:00,  2.85s/it]
100%|██████████| 27/27 [00:31<00:00,  1.18s/it]


Epoch [44/50]:- Train Loss: 0.875112 | Train Acc: 57.7622% | Valid Loss: 1.807189 | Valid Acc: 36.3919%


100%|██████████| 124/124 [05:52<00:00,  2.84s/it]
100%|██████████| 27/27 [00:31<00:00,  1.17s/it]


Epoch [45/50]:- Train Loss: 0.884293 | Train Acc: 57.2028% | Valid Loss: 1.387857 | Valid Acc: 35.4588%


100%|██████████| 124/124 [05:52<00:00,  2.84s/it]
100%|██████████| 27/27 [00:32<00:00,  1.19s/it]


Epoch [46/50]:- Train Loss: 0.880772 | Train Acc: 57.2028% | Valid Loss: 1.959220 | Valid Acc: 30.0156%


100%|██████████| 124/124 [05:56<00:00,  2.87s/it]
100%|██████████| 27/27 [00:31<00:00,  1.18s/it]


Epoch [47/50]:- Train Loss: 0.873986 | Train Acc: 58.0769% | Valid Loss: 1.380849 | Valid Acc: 34.0591%


100%|██████████| 124/124 [05:53<00:00,  2.85s/it]
100%|██████████| 27/27 [00:32<00:00,  1.19s/it]


Epoch [48/50]:- Train Loss: 0.873151 | Train Acc: 57.7273% | Valid Loss: 1.285647 | Valid Acc: 32.9705%


100%|██████████| 124/124 [05:55<00:00,  2.87s/it]
100%|██████████| 27/27 [00:31<00:00,  1.18s/it]


Epoch [49/50]:- Train Loss: 0.868697 | Train Acc: 59.1259% | Valid Loss: 3.574982 | Valid Acc: 32.3484%


100%|██████████| 124/124 [05:54<00:00,  2.85s/it]
100%|██████████| 27/27 [00:32<00:00,  1.19s/it]


Epoch [50/50]:- Train Loss: 0.878495 | Train Acc: 56.9580% | Valid Loss: 1.240683 | Valid Acc: 35.3033%


In [14]:
model.load_state_dict(torch.load("best_model.pt"))
# model.load_state_dict(torch.load("last_model.pt"))
test(model, test_dl)
# test(model, test_ds)
# test(model, ds)

100%|██████████| 27/27 [00:32<00:00,  1.22s/it]

Test Loss: 1.640595 | Test Acc: 35.5140%
              precision    recall  f1-score   support

    Positive       0.00      0.00      0.00       192
     Neutral       0.36      0.58      0.44       224
    Negative       0.36      0.43      0.39       226

    accuracy                           0.36       642
   macro avg       0.24      0.34      0.28       642
weighted avg       0.25      0.36      0.29       642




