In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F

from sklearn.metrics import classification_report

import numpy as np

from sentence_transformers import SentenceTransformer
from transformers import CLIPModel, CLIPProcessor, \
                            DistilBertModel, DistilBertTokenizerFast, \
                            GPT2Tokenizer, GPT2Model, \
                            RobertaTokenizer, RobertaModel, \
                            AutoTokenizer, AutoModelForSequenceClassification, \
                            pipeline
# import google.generativeai as genai
#genai.configure(api_key="")

from torcheeg.models import EEGNet

from tqdm import tqdm

from EEGDataset import EEGDataset, WordEEGDataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
MODEL_NAME = "openai/clip-vit-base-patch32"
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

# Dataset

In [3]:
# ds = WordEEGDataset("shards", pad_upto=200)
# ds = EEGDataset("shards", ch_count=1076, pad_upto=6000, crp_rng=(0,1))
ds = EEGDataset("shards", ch_count=105, pad_upto=4000, crp_rng=(0,1))
ds[0][0].shape, ds[0][1], ds[0][2]

(torch.Size([105, 4000]),
 1,
 'Presents a good case while failing to provide a reason for us to care beyond the very basic dictums of human decency.')

In [4]:
# train_ds = WordEEGDataset("shards", pad_upto=200, selective_indexing=list(range(274)))
# val_ds = WordEEGDataset("shards", pad_upto=200, selective_indexing=list(range(274,332)))
# test_ds = WordEEGDataset("shards", pad_upto=200, selective_indexing=list(range(332,392)))

# train_ds = WordEEGDataset("shards", pad_upto=200, selective_indexing=list(range(40)))
# val_ds = WordEEGDataset("shards", pad_upto=200, selective_indexing=list(range(40,45)))
# test_ds = WordEEGDataset("shards", pad_upto=200, selective_indexing=list(range(45,49)))

In [5]:
train_ds, val_ds, test_ds = ds.split_train_valid_test(train_ratio=0.7, valid_ratio=0.15, shuffle=False)

batch_size = 25
train_dl = train_ds.getLoader(batch_size=batch_size, num_workers=0)
val_dl = val_ds.getLoader(batch_size=batch_size, num_workers=0)
test_dl = test_ds.getLoader(batch_size=batch_size, num_workers=0)

len(train_ds), len(val_ds), len(test_ds)

(3360, 720, 720)

# Text Model Testing

In [6]:
class TextEncoder:
    def __init__(self):
        # self.model = SentenceTransformer("all-MiniLM-L6-v2").to(device)

        # self.tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
        # self.model = DistilBertModel.from_pretrained("distilbert-base-uncased")

        # self.tokenizer = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        # self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")

        # self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        # self.model = GPT2Model.from_pretrained("gpt2")

        self.model = SentenceTransformer("paraphrase-mpnet-base-v2")

        # self.tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
        # self.model = RobertaModel.from_pretrained("roberta-base")

        # model_name = "j-hartmann/sentiment-roberta-large-english-3-classes"
        # self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # self.model = AutoModelForSequenceClassification.from_pretrained(model_name)

        pass

    def encode(self, txts):
        # all-MiniLM-L6-v2
        # return torch.tensor(self.model.encode(txts))
    
        # distilbert-base-uncased
        # inputs = self.tokenizer(
        #     txts,
        #     truncation=True,
        #     padding=True,
        #     max_length=256,
        #     return_tensors="pt"
        # )
        # return self.model(**inputs).last_hidden_state[:,0]

        # openai/clip-vit-base-patch32
        # inputs = self.tokenizer(
        #     text=txts,
        #     padding=True,
        #     return_tensors="pt",
        # )
        # return self.model.get_text_features(**inputs)

        # gpt2
        # inputs = self.tokenizer(
        #     txts,
        #     truncation=True,
        #     return_tensors="pt",
        # )
        # last_hidden = self.model(**inputs).last_hidden_state
        # return last_hidden[torch.arange(last_hidden.size(0)), inputs['attention_mask'].sum(1)-1]
        # # attention_mask = inputs['attention_mask'].unsqueeze(-1)
        # # text_embeds_mean = (last_hidden * attention_mask).sum(1) / attention_mask.sum(1)
        # # return text_embeds_mean / text_embeds_mean.norm(dim=-1, keepdim=True)

        # paraphrase-mpnet-base-v2
        return torch.tensor(self.model.encode(txts))

        # Google Gemini
        # res = genai.embed_content(
        #     model="models/text-embedding-004",
        #     content=txts,
        #     task_type="retrieval_document"
        # )["embedding"]
        # return torch.tensor(res)

        # Roberta
        # inputs = self.tokenizer(txts, return_tensors="pt")
        # token_embeddings = self.model(**inputs).last_hidden_state
        # mask = inputs["attention_mask"].unsqueeze(-1).expand(token_embeddings.size()).float()
        # summed = torch.sum(token_embeddings * mask, dim=1)
        # counted = torch.clamp(mask.sum(dim=1), min=1e-9)
        # return summed / counted

        # Roberta 3 Class
        # inputs = self.tokenizer(
        #     txts,
        #     padding=True,
        #     truncation=True,
        #     return_tensors="pt"
        # )
        # outputs = self.model.roberta(**inputs, output_hidden_states=True)
        # last_hidden = outputs.hidden_states[-1]
        # return last_hidden[:, 0, :]

model = TextEncoder()

In [7]:
target_embeddings = model.encode(["Negative", "Neutral", "Positive"])
target_embeddings.shape, target_embeddings.dtype, type(target_embeddings)

(torch.Size([3, 768]), torch.float32, torch.Tensor)

In [8]:
lbls = []
txts = []
txt_embeds = []
for i in tqdm(range(100)):
    _, sent, txt = ds[i]
    if _ is None:
        continue
    lbls.append(int(sent)+1)
    txts.append(txt)
    txt_embeds.append(model.encode([txt]))

preds = torch.cat(txt_embeds, dim=0)

# cat_norm = target_embeddings.clone().detach()
# text_norm = preds.clone().detach()
# diffs = torch.cdist(text_norm, cat_norm, p=2)
# pred = diffs.argmin(dim=1)

cat_norm = F.normalize(target_embeddings, p=2, dim=1)
text_norm = F.normalize(preds, p=2, dim=1)
similarities = torch.matmul(text_norm, cat_norm.T)
pred = similarities.argmax(dim=1)

print(classification_report(lbls, pred))

100%|██████████| 100/100 [00:00<00:00, 129.01it/s]

              precision    recall  f1-score   support

           0       0.00      0.00      0.00         0
           1       0.36      0.33      0.35        36
           2       0.36      0.28      0.32        32
           3       0.00      0.00      0.00        30

    accuracy                           0.21        98
   macro avg       0.18      0.15      0.17        98
weighted avg       0.25      0.21      0.23        98




  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


# Dataset Testing

In [6]:
train_ds[0][0].shape, train_ds[0][1], len(train_ds[0][2]), train_ds[0][2]

(torch.Size([105, 4000]),
 1,
 117,
 'Presents a good case while failing to provide a reason for us to care beyond the very basic dictums of human decency.')

In [7]:
val_ds[0][0].shape, len(val_ds[0][2])
# val_ds[0][1][0]

(torch.Size([105, 4000]), 210)

In [None]:
# embeddings = model.encode([ds[0][1]])
embeddings = model.encode([train_ds[0][1]])
embeddings.shape, embeddings.dtype, type(embeddings)

((1, 384), dtype('float32'), numpy.ndarray)

In [8]:
ds[0][0].shape

for batch_data, batch_sent, batch_labels in train_dl:
    print(len(batch_data), batch_data.shape)
    print(len(batch_sent), batch_sent.shape)
    print(len(batch_labels))
    break

25 torch.Size([25, 105, 4000])
25 torch.Size([25])
25


In [11]:
eeg, _, _ = ds[0]
eeg.shape, eeg.mean(), eeg.std(), eeg.min(), eeg.max()

(torch.Size([105, 4000]),
 tensor(0.0094, dtype=torch.float64),
 tensor(3.6004, dtype=torch.float64),
 tensor(-58.6592, dtype=torch.float64),
 tensor(69.3399, dtype=torch.float64))

In [14]:
count = 0
for eeg, _, _ in ds:
    if not(eeg is None):
        count += 1
count

4537

# Models

In [6]:
class TextEncoder(nn.Module):
    def __init__(self):
        super(TextEncoder, self).__init__()
        # self.text_encoder_model = SentenceTransformer("all-MiniLM-L6-v2")
        # self.text_encoder_model = CLIPModel.from_pretrained(MODEL_NAME)
        # self.processor = CLIPProcessor.from_pretrained(MODEL_NAME)
        self.text_encoder_model = SentenceTransformer("paraphrase-mpnet-base-v2")

    def forward(self, texts):
        # inputs = self.processor(text=texts, return_tensors="pt", padding=True)
        # inputs = {k: v.to(self.text_encoder_model.device) for k, v in inputs.items()}
        # embeddings = self.text_encoder_model.get_text_features(**inputs)
        
        embeddings = self.text_encoder_model.encode(texts, convert_to_tensor=True)

        embeddings = F.normalize(embeddings, p=2, dim=1)
        return embeddings
    
class EEGPatchEmbedding(nn.Module):
    def __init__(self, n_channels=1076, patch_size=20, embed_dim=384):
        super().__init__()
        self.patch_size = patch_size
        dropprob = 0.3
        self.proj = nn.Sequential(
            nn.Linear(n_channels * patch_size, (n_channels * patch_size)//2),
            nn.LeakyReLU(),
            nn.Dropout(dropprob),
            nn.Linear((n_channels * patch_size)//2,embed_dim)
        )

    def forward(self, x):
        # x: (B, C, T)
        B, C, T = x.shape
        x = x[:, :, :T - (T % self.patch_size)]  # trim
        x = x.view(B, C, -1, self.patch_size)    # (B, C, tokens, patch)
        x = x.permute(0, 2, 1, 3)                 # (B, tokens, C, patch)
        x = x.flatten(2)                          # (B, tokens, C*patch)
        return self.proj(x)                       # (B, tokens, embed_dim)

class EEGTransformer(nn.Module):
    def __init__(
        self,
        n_channels=1076,
        patch_size=20,
        embed_dim=384,
        num_layers=4,
        num_heads=8
    ):
        super().__init__()

        dropprob = 0.3

        self.patch_embed = EEGPatchEmbedding(
            n_channels, patch_size, embed_dim
        )

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=4 * embed_dim,
            dropout=dropprob,
            batch_first=True
        )

        self.transformer = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers
        )

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

    def forward(self, x):
        # x: (B, C, T)
        B = x.size(0)

        x = self.patch_embed(x)   # (B, tokens, embed_dim)

        cls = self.cls_token.expand(B, -1, -1) # FIXME:
        x = torch.cat([cls, x], dim=1)

        x = self.transformer(x)
        cls_out = x[:, 0]

        return cls_out


class EEGEncoder(nn.Module):
    def __init__(self, ch_count=8196, embedding_dim=384):
        super(EEGEncoder, self).__init__()

        dropprob = .3

        # self.temporal = nn.Sequential(
        #     nn.Conv1d(ch_count, 1024, 64, padding=1),
        #     nn.LeakyReLU(),
        #     nn.Dropout(dropprob),
        #     nn.Conv1d(1024, 512, 32, padding=1),
        #     nn.LeakyReLU(),
        #     nn.Dropout(dropprob),
        #     nn.Conv1d(512, 256, 32, padding=1),
        #     nn.LeakyReLU(),
        #     nn.Dropout(dropprob),
        #     # nn.AdaptiveAvgPool2d((256, 1))
        #     nn.AdaptiveAvgPool1d((1))
        # )

        # self.fc = nn.Sequential(
        #     nn.Linear(256, embedding_dim//2),
        #     nn.LeakyReLU(),
        #     nn.Dropout(dropprob),
        #     nn.Linear(embedding_dim//2, embedding_dim)
        # )

        self.eeg_encoder = EEGNet(
            chunk_size=2000,
            num_electrodes=ch_count,
            dropout=dropprob,
            kernel_1=64,
            kernel_2=16,
            F1=8,
            F2=16,
            D=2,
            num_classes=embedding_dim
        )

        # self.eeg_transformer = EEGTransformer(
        #     n_channels=ch_count,
        #     patch_size=200,
        #     embed_dim=embedding_dim,
        #     num_layers=4,
        #     num_heads=8,
        # )

    def compute_power_bands(self, x):
        fs = 500
        
        # eeg: (N, C, T)
        freqs = torch.fft.rfftfreq(x.size(-1), 1/fs).to(x.device)  # (F,)
        fft_vals = torch.fft.rfft(x, dim=-1)                         # (N, C, F)
        psd = (fft_vals.abs()**2)

        bands = [(0.5,4), (4,8), (8,12), (12,30), (30,49)]
        feats = []
        for low, high in bands:
            idx = (freqs >= low) & (freqs < high)
            band = psd[..., idx].mean(dim=-1)   # (N, C)
            feats.append(band)

        return torch.stack(feats, dim=2)        # (N, C, P)

    def forward(self, x):
        # normalizing x
        # x /= 60

        x = torch.fft.rfft(x, dim=2)
        x = torch.log(torch.abs(x) + 1e-8)
        # print(x.shape, x.mean(), x.std(), x.min(), x.max())

        # x = self.compute_power_bands(x)

        # x = self.temporal(x).squeeze(-1)
        # x = self.fc(x)

        x = x.unsqueeze(1)
        x = self.eeg_encoder(x)

        # x = self.eeg_transformer(x)

        x = F.normalize(x, p=2, dim=1)
        # x = torch.tanh(x) * 3
        return x
    
class EEGClassifier(nn.Module):
    def __init__(self, embedding_dim=384):
        super(EEGClassifier, self).__init__()

        dropprob = .3

        self.classifier = nn.Sequential(
            nn.Linear(embedding_dim, 128),
            nn.LeakyReLU(),
            nn.Dropout(dropprob),
            nn.Linear(128, 16),
            nn.LeakyReLU(),
            nn.Dropout(dropprob),
            nn.Linear(16, 3)
        )

    def forward(self, x):
        return self.classifier(x)
    
class EEGCLIPModel(nn.Module):
    def __init__(self, ch_count=8196, embedding_dim=384, freeze_text=True):
        super(EEGCLIPModel, self).__init__()
        self.text_encoder = TextEncoder()
        self.eeg_encoder = EEGEncoder(ch_count=ch_count, embedding_dim=embedding_dim)
        self.eeg_classifier = EEGClassifier(embedding_dim=embedding_dim)

        if freeze_text:
            for param in self.text_encoder.parameters():
                param.requires_grad = False

    def freeze_backbone(self, is_freeze=True):
        self.train()
        for param in self.eeg_encoder.parameters():
            param.requires_grad = not is_freeze

    def forward(self, eeg_data, texts):
        eeg_embeddings = self.eeg_encoder(eeg_data)
        text_embeddings = self.text_encoder(texts)
        # classification = self.eeg_classifier(text_embeddings)
        # classification = self.eeg_classifier(eeg_embeddings)
        classification = self.eeg_classifier(eeg_embeddings.detach())
        return eeg_embeddings, text_embeddings, classification

# Training

In [7]:
# model = EEGCLIPModel().to(device)
model = EEGCLIPModel(105, 768, freeze_text=False).to(device)
# model = EEGCLIPModel(1076, 768, freeze_text=True).to(device)
# model.load_state_dict(torch.load("last_model.pt"))

In [8]:
eeg, sent, text = ds[0]
res = model(eeg.unsqueeze(0).to(torch.float32).to("cuda"), [text])

res[0].shape, res[1].shape, res[2]

(torch.Size([1, 768]),
 torch.Size([1, 768]),
 tensor([[ 0.1026, -0.1842,  0.0357]], device='cuda:0',
        grad_fn=<AddmmBackward0>))

In [None]:
# Text Embedding Normalization Space
all_embeds = model.text_encoder([lbl[1] for lbl in ds.labels])

text_mean = all_embeds.mean(dim=0)
text_std = all_embeds.std(dim=0).clamp(min=1e-6)

text_mean, text_std

In [None]:
# Classification Classes
classification_classes = ["Positive", "Neutral", "Negative"]
classification_classes_embeds = model.text_encoder(classification_classes)
classification_classes_embeds_z = (classification_classes_embeds-text_mean.unsqueeze(0))/text_std.unsqueeze(0)
classification_classes_embeds_z.shape

In [8]:
current_msg = ""
def log(txt: str, end="\n"):
    global current_msg
    current_msg += txt + end

    if len(current_msg) > 1:
        with open("logs.txt", "a") as f:
            f.write(current_msg)
        current_msg = ""

In [9]:
tau, alpha, beta = 1, .7, 1.3
# loss_weights = (0, 0, 1.5)
loss_weights = (.5, 1, 1.5)

def train(model: nn.Module, train_loader: DataLoader, valid_loader: DataLoader, epochs: int = 10, print_every = 20, loading=False):
# def train(model: nn.Module, train_dataset: WordEEGDataset, valid_dataset: WordEEGDataset, epochs: int = 10):
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()

    best_valid_loss = None

    for epoch in tqdm(range(epochs)):
        model.train()
        total_loss = 0.0
        train_count = 0
        train_correct = 0
        train_total_count = 0

        avg_losses = {
            "cosine": 0.0,
            "ce": 0.0,
            "clip": 0.0,
            "positive_clip": 0.0,
            "negative_clip": 0.0
        }

        if loading:
            train_loading = tqdm(train_loader)
        else:
            train_loading = train_loader
        for batch in train_loading:
            if batch[0] is None:
                continue
        # for batch in tqdm(train_dataset):
            # if isinstance(train_dataset, EEGDataset):
            #     batch = (batch[0].unsqueeze(0), [batch[1]])
            # batch = ([batch[0]], [batch[1]])
            B = batch[0].shape[0]
            eeg_data = batch[0].to(torch.float32).clone().detach().to(device) # (B,C,T)
            sent_lbl = batch[1].to(torch.int64).clone().detach().to(device) # (B)
            texts = batch[2] # [""]

            optimizer.zero_grad()
            eeg_embeddings, text_embeddings, sent_logits = model(eeg_data, texts) # (B,D)
            sent_logits: torch.Tensor

            cosine_loss = 2*(1.0 - F.cosine_similarity(eeg_embeddings, text_embeddings).mean())
            ce_loss = loss_fn(sent_logits, sent_lbl)

            sent_mask = sent_lbl.view(-1,1) == sent_lbl.view(1,-1)
            dist = torch.cdist(eeg_embeddings, eeg_embeddings, p=2)
            positive_losses = torch.relu(dist-tau*alpha)
            negative_losses = torch.relu(tau*beta-dist)
            positive_mask = torch.tril(sent_mask == True, -1)
            negative_mask = torch.tril(sent_mask == False, -1)
            if positive_mask.sum().item() == 0:
                positive_loss = 0
            else:
                positive_loss = (positive_losses[positive_mask]).sum().item() / (positive_mask.sum().item())
            if negative_mask.sum().item() == 0:
                negative_loss = 0
            else:
                negative_loss = (negative_losses[negative_mask]).sum().item() / (negative_mask.sum().item())
            clip_loss = positive_loss + negative_loss # clip loss
            
            # loss = ce_loss*loss_weights[1]
            loss = cosine_loss*loss_weights[0] + ce_loss*loss_weights[1] + clip_loss*loss_weights[2]
            avg_losses["cosine"] += cosine_loss.item()
            avg_losses["ce"] += ce_loss.item()
            avg_losses["clip"] += clip_loss
            avg_losses["positive_clip"] += positive_loss
            avg_losses["negative_clip"] += negative_loss

            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            train_count += 1

            sent_probs = sent_logits.softmax(dim=1)
            sent_preds = sent_probs.argmax(dim=1)
            train_correct += (sent_lbl == sent_preds).sum().item()
            train_total_count += B

        # avg_loss = total_loss / len(train_loader)
        avg_loss = total_loss / train_count
        avg_acc = train_correct / train_total_count
        avg_losses["cosine"] /= train_count
        avg_losses["ce"] /= train_count
        avg_losses["clip"] /= train_count
        avg_losses["positive_clip"] /= train_count
        avg_losses["negative_clip"] /= train_count


        model.eval()
        total_valid_loss = 0.0
        valid_count = 0
        valid_correct = 0
        valid_total_count = 0
        avg_valid_losses = {
            "cosine": 0.0,
            "ce": 0.0,
            "clip": 0.0,
            "positive_clip": 0.0,
            "negative_clip": 0.0
        }
        with torch.no_grad():
            if loading:
                valid_loading = tqdm(valid_loader)
            else:
                valid_loading = valid_loader
            for batch in valid_loading:
                if batch[0] is None:
                    continue
                B = batch[0].shape[0]
                eeg_data = batch[0].to(torch.float32).clone().detach().to(device)
                sent_lbl = batch[1].to(torch.int64).clone().detach().to(device)
                texts = batch[2]

                eeg_embeddings, text_embeddings, sent_logits = model(eeg_data, texts)
                sent_logits: torch.Tensor

                cosine_loss = 2*(1.0 - F.cosine_similarity(eeg_embeddings, text_embeddings).mean())
                ce_loss = loss_fn(sent_logits, sent_lbl)

                sent_mask = sent_lbl.view(-1,1) == sent_lbl.view(1,-1)
                dist = torch.cdist(eeg_embeddings, eeg_embeddings, p=2)
                positive_losses = torch.relu(dist-tau*alpha)
                negative_losses = torch.relu(tau*beta-dist)
                positive_mask = torch.tril(sent_mask == True, -1)
                negative_mask = torch.tril(sent_mask == False, -1)
                if positive_mask.sum().item() == 0:
                    positive_loss = 0
                else:
                    positive_loss = (positive_losses[positive_mask]).sum().item() / (positive_mask.sum().item())
                if negative_mask.sum().item() == 0:
                    negative_loss = 0
                else:
                    negative_loss = (negative_losses[negative_mask]).sum().item() / (negative_mask.sum().item())
                clip_loss = positive_loss + negative_loss # clip loss

                # loss = ce_loss*loss_weights[1]
                loss = cosine_loss*loss_weights[0] + ce_loss*loss_weights[1] + clip_loss*loss_weights[2]
                avg_valid_losses["cosine"] += cosine_loss.item()
                avg_valid_losses["ce"] += ce_loss.item()
                avg_valid_losses["clip"] += clip_loss
                avg_valid_losses["positive_clip"] += positive_loss
                avg_valid_losses["negative_clip"] += negative_loss

                total_valid_loss += loss.item()
                valid_count += 1

                sent_probs = sent_logits.softmax(dim=1)
                sent_preds = sent_probs.argmax(dim=1)
                valid_correct += (sent_lbl == sent_preds).sum().item()
                valid_total_count += B

        avg_valid_loss = total_valid_loss / valid_count
        avg_valid_acc = valid_correct / valid_total_count
        avg_valid_losses["cosine"] /= valid_count
        avg_valid_losses["ce"] /= valid_count
        avg_valid_losses["clip"] /= valid_count
        avg_valid_losses["positive_clip"] /= valid_count
        avg_valid_losses["negative_clip"] /= valid_count

        if (best_valid_loss is None) or (avg_valid_loss < best_valid_loss):
            msg = (f"Valid Loss: {avg_valid_loss:.10f}")
            log(msg)
            best_valid_loss = avg_valid_loss
            torch.save(model.state_dict(), "best_model.pt")

        msg = (f"Epoch [{epoch+1}/{epochs}]:- Train Loss: {avg_loss:.6f} | Train Acc: {avg_acc*100:.4f}% | Valid Loss: {avg_valid_loss:.6f} | Valid Acc: {avg_valid_acc*100:.4f}%")
        log(msg)
        log(f"    Train Avg Losses => Cosine: {avg_losses['cosine']:.6f}, CE: {avg_losses['ce']:.6f}, CLIP: {avg_losses['clip']:.6f}, Positive CLIP: {avg_losses['positive_clip']:.6f}, Negative CLIP: {avg_losses['negative_clip']:.6f}")
        log(f"    Valid Avg Losses => Cosine: {avg_valid_losses['cosine']:.6f}, CE: {avg_valid_losses['ce']:.6f}, CLIP: {avg_valid_losses['clip']:.6f}, Positive CLIP: {avg_valid_losses['positive_clip']:.6f}, Negative CLIP: {avg_valid_losses['negative_clip']:.6f}")
        log("")
        if (epoch % print_every == 0):
            print(msg)
        torch.save(model.state_dict(), "last_model.pt")

        torch.cuda.empty_cache()


In [10]:
classification_classes = ["Positive", "Neutral", "Negative"]

def test(model: nn.Module, test_loader: DataLoader):
# def test(model: nn.Module, test_dataset: WordEEGDataset):
    loss_fn = nn.CrossEntropyLoss()

    model.eval()
    total_loss = 0.0
    count = 0
    correct = 0
    total_count = 0

    actuals = []
    preds = []
    with torch.no_grad():
        for batch in tqdm(test_loader):
        # for batch in tqdm(test_dataset):
            # if isinstance(test_dataset, EEGDataset):
            #     batch = (batch[0].unsqueeze(0), [batch[1]])
            # batch = ([batch[0]], [batch[1]])
            B = batch[0].shape[0]
            eeg_data = batch[0].to(torch.float32).clone().detach().to(device)
            sent_lbl = batch[1].to(torch.int64).clone().detach().to(device)
            texts = batch[2]

            eeg_embeddings, text_embeddings, sent_logits = model(eeg_data, texts)
            sent_logits: torch.Tensor

            cosine_loss = 2*(1.0 - F.cosine_similarity(eeg_embeddings, text_embeddings).mean())
            ce_loss = loss_fn(sent_logits, sent_lbl)

            sent_mask = sent_lbl.view(-1,1) == sent_lbl.view(1,-1)
            dist = torch.cdist(eeg_embeddings, eeg_embeddings, p=2)
            positive_losses = torch.relu(dist-tau*alpha)
            negative_losses = torch.relu(tau*beta-dist)
            positive_mask = torch.tril(sent_mask == True, -1)
            negative_mask = torch.tril(sent_mask == False, -1)
            if positive_mask.sum().item() == 0:
                positive_loss = 0
            else:
                positive_loss = (positive_losses[positive_mask]).sum().item() / (positive_mask.sum().item())
            if negative_mask.sum().item() == 0:
                negative_loss = 0
            else:
                negative_loss = (negative_losses[negative_mask]).sum().item() / (negative_mask.sum().item())
            clip_loss = positive_loss + negative_loss # clip loss

            loss = cosine_loss*loss_weights[0] + ce_loss*loss_weights[1] + clip_loss*loss_weights[2]

            total_loss += loss.item()
            count += 1

            sent_probs = sent_logits.softmax(dim=1)
            sent_preds = sent_probs.argmax(dim=1)
            correct += (sent_lbl == sent_preds).sum().item()
            total_count += B

            actuals.append(sent_lbl)
            preds.append(sent_preds)
            # for i in range(len(batch[0])):
            #     eeg_data = batch[0][i].to(torch.float32).to(device)
            #     texts = batch[1][i]

            #     eeg_embeddings, text_embeddings = model(eeg_data, texts)

            #     # text_z = (text_embeddings-text_mean.unsqueeze(0))/text_std.unsqueeze(0)
            #     # loss = loss_fn(eeg_embeddings, text_z)
            #     loss = 1.0 - F.cosine_similarity(eeg_embeddings, text_embeddings).mean()

            #     total_loss += loss.item()
            #     count += 1

            #     # Accuracy Check
            #     # text_norm = F.normalize(text_embeddings, p=2, dim=1)
            #     # eeg_norm = F.normalize(eeg_embeddings, p=2, dim=1)
            #     # classes_norm = F.normalize(classification_classes_embeds, p=2, dim=1)
            #     sim_ground = text_embeddings @ classification_classes_embeds.T
            #     sim_pred = eeg_embeddings @ classification_classes_embeds.T

            #     # sim_ground = -torch.cdist(text_z.clone().detach(), classification_classes_embeds_z, p=2)
            #     # sim_pred = -torch.cdist(eeg_embeddings.clone().detach(), classification_classes_embeds_z, p=2)

            #     ground = sim_ground.argmax(dim=1)
            #     pred = sim_pred.argmax(dim=1)
            #     correct += (ground == pred).sum().item()
            #     total_count += ground.shape[0]
            #     actuals.append(ground)
            #     preds.append(pred)

    avg_loss = total_loss / count
    avg_acc = correct / total_count
    print(f"Test Loss: {avg_loss:.6f} | Test Acc: {avg_acc*100:.4f}%")

    actuals = torch.cat(actuals, dim=0)
    preds = torch.cat(preds, dim=0)
    cr = classification_report(actuals.cpu(), preds.cpu(), labels=[0,1,2], target_names=classification_classes)
    print(cr)

In [11]:
# model.freeze_backbone(is_freeze=False)
train(model, train_dl, val_dl, epochs=5, print_every=1, loading=True)
# train(model, train_ds, val_ds, epochs=3)
# train(model, ds, epochs=20)

100%|██████████| 135/135 [00:16<00:00,  8.18it/s]
100%|██████████| 29/29 [00:02<00:00, 11.05it/s]


Epoch [1/5]:- Train Loss: 3.194104 | Train Acc: 35.8076% | Valid Loss: 3.399286 | Valid Acc: 40.2299%


100%|██████████| 135/135 [00:16<00:00,  8.13it/s]
100%|██████████| 29/29 [00:02<00:00, 11.19it/s]


Epoch [2/5]:- Train Loss: 3.245136 | Train Acc: 34.7563% | Valid Loss: 3.336663 | Valid Acc: 40.2299%


100%|██████████| 135/135 [00:16<00:00,  8.04it/s]
100%|██████████| 29/29 [00:02<00:00, 10.53it/s]


Epoch [3/5]:- Train Loss: 3.235440 | Train Acc: 34.8200% | Valid Loss: 3.434569 | Valid Acc: 40.2299%


100%|██████████| 135/135 [00:17<00:00,  7.92it/s]
100%|██████████| 29/29 [00:02<00:00,  9.97it/s]


Epoch [4/5]:- Train Loss: 3.202153 | Train Acc: 35.0430% | Valid Loss: 3.444614 | Valid Acc: 40.2299%


100%|██████████| 135/135 [00:16<00:00,  7.94it/s]
100%|██████████| 29/29 [00:02<00:00, 10.37it/s]


Epoch [5/5]:- Train Loss: 3.116497 | Train Acc: 34.3740% | Valid Loss: 3.364746 | Valid Acc: 40.2299%


100%|██████████| 5/5 [01:41<00:00, 20.30s/it]


In [12]:
# model.freeze_backbone(is_freeze=False)
train(model, train_dl, val_dl, epochs=20)

  0%|          | 0/20 [00:00<?, ?it/s]

Epoch [1/20]:- Train Loss: 3.000705 | Train Acc: 36.5084% | Valid Loss: 3.417452 | Valid Acc: 40.2299%


100%|██████████| 20/20 [06:50<00:00, 20.55s/it]


In [None]:
# model.freeze_backbone(is_freeze=False)
train(model, train_dl, val_dl, epochs=100)

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch [1/100]:- Train Loss: 2.414544 | Train Acc: 48.9567% | Valid Loss: 2.914893 | Valid Acc: 36.9253%


  1%|          | 1/100 [00:21<36:07, 21.90s/it]

In [14]:
# model.load_state_dict(torch.load("best_model.pt"))
model.load_state_dict(torch.load("last_model.pt"))
test(model, test_dl)
# test(model, test_ds)
# test(model, ds)

100%|██████████| 29/29 [00:02<00:00, 10.17it/s]

Test Loss: 2.899513 | Test Acc: 38.3191%
              precision    recall  f1-score   support

    Positive       0.28      0.15      0.19       191
     Neutral       0.40      0.79      0.53       271
    Negative       0.38      0.11      0.17       240

    accuracy                           0.38       702
   macro avg       0.35      0.35      0.30       702
weighted avg       0.36      0.38      0.32       702




