In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F

from sklearn.metrics import classification_report

import numpy as np

from sentence_transformers import SentenceTransformer
from transformers import CLIPModel, CLIPProcessor, \
                            DistilBertModel, DistilBertTokenizerFast, \
                            GPT2Tokenizer, GPT2Model, \
                            RobertaTokenizer, RobertaModel, \
                            AutoTokenizer, AutoModelForSequenceClassification, \
                            pipeline
# import google.generativeai as genai
#genai.configure(api_key="")

from torcheeg.models import EEGNet

from tqdm import tqdm

from EEGDataset import EEGDataset, WordEEGDataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
MODEL_NAME = "openai/clip-vit-base-patch32"
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

# Dataset

In [3]:
# ds = WordEEGDataset("shards", pad_upto=200)
# ds = EEGDataset("shards", ch_count=1076, pad_upto=6000, crp_rng=(0,1))
ds = EEGDataset("shards", ch_count=105, pad_upto=4000, crp_rng=(0,1))
ds[0][0].shape, ds[0][1], ds[0][2]

(torch.Size([105, 4000]),
 1,
 'Presents a good case while failing to provide a reason for us to care beyond the very basic dictums of human decency.')

In [4]:
# train_ds = WordEEGDataset("shards", pad_upto=200, selective_indexing=list(range(274)))
# val_ds = WordEEGDataset("shards", pad_upto=200, selective_indexing=list(range(274,332)))
# test_ds = WordEEGDataset("shards", pad_upto=200, selective_indexing=list(range(332,392)))

# train_ds = WordEEGDataset("shards", pad_upto=200, selective_indexing=list(range(40)))
# val_ds = WordEEGDataset("shards", pad_upto=200, selective_indexing=list(range(40,45)))
# test_ds = WordEEGDataset("shards", pad_upto=200, selective_indexing=list(range(45,49)))

In [None]:
train_ds, val_ds, test_ds = ds.split_train_valid_test(train_ratio=0.7, valid_ratio=0.15, shuffle=False)

batch_size = 25
train_dl = train_ds.getLoader(batch_size=batch_size, num_workers=0)
val_dl = val_ds.getLoader(batch_size=batch_size, num_workers=0)
test_dl = test_ds.getLoader(batch_size=batch_size, num_workers=0)

len(train_ds), len(val_ds), len(test_ds)

(3360, 720, 720)

# Text Model Testing

In [6]:
class TextEncoder:
    def __init__(self):
        # self.model = SentenceTransformer("all-MiniLM-L6-v2").to(device)

        # self.tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
        # self.model = DistilBertModel.from_pretrained("distilbert-base-uncased")

        # self.tokenizer = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        # self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")

        # self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        # self.model = GPT2Model.from_pretrained("gpt2")

        self.model = SentenceTransformer("paraphrase-mpnet-base-v2")

        # self.tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
        # self.model = RobertaModel.from_pretrained("roberta-base")

        # model_name = "j-hartmann/sentiment-roberta-large-english-3-classes"
        # self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # self.model = AutoModelForSequenceClassification.from_pretrained(model_name)

        pass

    def encode(self, txts):
        # all-MiniLM-L6-v2
        # return torch.tensor(self.model.encode(txts))
    
        # distilbert-base-uncased
        # inputs = self.tokenizer(
        #     txts,
        #     truncation=True,
        #     padding=True,
        #     max_length=256,
        #     return_tensors="pt"
        # )
        # return self.model(**inputs).last_hidden_state[:,0]

        # openai/clip-vit-base-patch32
        # inputs = self.tokenizer(
        #     text=txts,
        #     padding=True,
        #     return_tensors="pt",
        # )
        # return self.model.get_text_features(**inputs)

        # gpt2
        # inputs = self.tokenizer(
        #     txts,
        #     truncation=True,
        #     return_tensors="pt",
        # )
        # last_hidden = self.model(**inputs).last_hidden_state
        # return last_hidden[torch.arange(last_hidden.size(0)), inputs['attention_mask'].sum(1)-1]
        # # attention_mask = inputs['attention_mask'].unsqueeze(-1)
        # # text_embeds_mean = (last_hidden * attention_mask).sum(1) / attention_mask.sum(1)
        # # return text_embeds_mean / text_embeds_mean.norm(dim=-1, keepdim=True)

        # paraphrase-mpnet-base-v2
        return torch.tensor(self.model.encode(txts))

        # Google Gemini
        # res = genai.embed_content(
        #     model="models/text-embedding-004",
        #     content=txts,
        #     task_type="retrieval_document"
        # )["embedding"]
        # return torch.tensor(res)

        # Roberta
        # inputs = self.tokenizer(txts, return_tensors="pt")
        # token_embeddings = self.model(**inputs).last_hidden_state
        # mask = inputs["attention_mask"].unsqueeze(-1).expand(token_embeddings.size()).float()
        # summed = torch.sum(token_embeddings * mask, dim=1)
        # counted = torch.clamp(mask.sum(dim=1), min=1e-9)
        # return summed / counted

        # Roberta 3 Class
        # inputs = self.tokenizer(
        #     txts,
        #     padding=True,
        #     truncation=True,
        #     return_tensors="pt"
        # )
        # outputs = self.model.roberta(**inputs, output_hidden_states=True)
        # last_hidden = outputs.hidden_states[-1]
        # return last_hidden[:, 0, :]

model = TextEncoder()

In [7]:
target_embeddings = model.encode(["Negative", "Neutral", "Positive"])
target_embeddings.shape, target_embeddings.dtype, type(target_embeddings)

(torch.Size([3, 768]), torch.float32, torch.Tensor)

In [None]:
lbls = []
txts = []
txt_embeds = []
for i in tqdm(range(100)):
    _, sent, txt = ds[i]
    if _ is None:
        continue
    lbls.append(int(sent)+1)
    txts.append(txt)
    txt_embeds.append(model.encode([txt]))

preds = torch.cat(txt_embeds, dim=0)

# cat_norm = target_embeddings.clone().detach()
# text_norm = preds.clone().detach()
# diffs = torch.cdist(text_norm, cat_norm, p=2)
# pred = diffs.argmin(dim=1)

cat_norm = F.normalize(target_embeddings, p=2, dim=1)
text_norm = F.normalize(preds, p=2, dim=1)
similarities = torch.matmul(text_norm, cat_norm.T)
pred = similarities.argmax(dim=1)

print(classification_report(lbls, pred))

 66%|██████▌   | 66/100 [00:01<00:00, 38.12it/s]

# Dataset Testing

In [None]:
train_ds[0][0].shape, train_ds[0][1], len(train_ds[0][2]), train_ds[0][2]

(torch.Size([105, 4000]),
 1,
 117,
 'Presents a good case while failing to provide a reason for us to care beyond the very basic dictums of human decency.')

In [None]:
val_ds[0][0].shape, len(val_ds[0][2])
# val_ds[0][1][0]

(torch.Size([105, 4000]), 210)

In [None]:
# embeddings = model.encode([ds[0][1]])
embeddings = model.encode([train_ds[0][1]])
embeddings.shape, embeddings.dtype, type(embeddings)

((1, 384), dtype('float32'), numpy.ndarray)

In [None]:
ds[0][0].shape

for batch_data, batch_sent, batch_labels in train_dl:
    print(len(batch_data), batch_data.shape)
    print(len(batch_sent), batch_sent.shape)
    print(len(batch_labels))
    break

5 torch.Size([5, 105, 4000])
5 torch.Size([5])
5


# Models

In [6]:
class TextEncoder(nn.Module):
    def __init__(self):
        super(TextEncoder, self).__init__()
        # self.text_encoder_model = SentenceTransformer("all-MiniLM-L6-v2")
        # self.text_encoder_model = CLIPModel.from_pretrained(MODEL_NAME)
        # self.processor = CLIPProcessor.from_pretrained(MODEL_NAME)
        self.text_encoder_model = SentenceTransformer("paraphrase-mpnet-base-v2")

    def forward(self, texts):
        # inputs = self.processor(text=texts, return_tensors="pt", padding=True)
        # inputs = {k: v.to(self.text_encoder_model.device) for k, v in inputs.items()}
        # embeddings = self.text_encoder_model.get_text_features(**inputs)
        
        embeddings = self.text_encoder_model.encode(texts, convert_to_tensor=True)

        embeddings = F.normalize(embeddings, p=2, dim=1)
        return embeddings
    
class EEGPatchEmbedding(nn.Module):
    def __init__(self, n_channels=1076, patch_size=20, embed_dim=384):
        super().__init__()
        self.patch_size = patch_size
        dropprob = 0.3
        self.proj = nn.Sequential(
            nn.Linear(n_channels * patch_size, (n_channels * patch_size)//2),
            nn.LeakyReLU(),
            nn.Dropout(dropprob),
            nn.Linear((n_channels * patch_size)//2,embed_dim)
        )

    def forward(self, x):
        # x: (B, C, T)
        B, C, T = x.shape
        x = x[:, :, :T - (T % self.patch_size)]  # trim
        x = x.view(B, C, -1, self.patch_size)    # (B, C, tokens, patch)
        x = x.permute(0, 2, 1, 3)                 # (B, tokens, C, patch)
        x = x.flatten(2)                          # (B, tokens, C*patch)
        return self.proj(x)                       # (B, tokens, embed_dim)

class EEGTransformer(nn.Module):
    def __init__(
        self,
        n_channels=1076,
        patch_size=20,
        embed_dim=384,
        num_layers=4,
        num_heads=8
    ):
        super().__init__()

        dropprob = 0.3

        self.patch_embed = EEGPatchEmbedding(
            n_channels, patch_size, embed_dim
        )

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=4 * embed_dim,
            dropout=dropprob,
            batch_first=True
        )

        self.transformer = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers
        )

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

    def forward(self, x):
        # x: (B, C, T)
        B = x.size(0)

        x = self.patch_embed(x)   # (B, tokens, embed_dim)

        cls = self.cls_token.expand(B, -1, -1) # FIXME:
        x = torch.cat([cls, x], dim=1)

        x = self.transformer(x)
        cls_out = x[:, 0]

        return cls_out


class EEGEncoder(nn.Module):
    def __init__(self, ch_count=8196, embedding_dim=384):
        super(EEGEncoder, self).__init__()

        dropprob = .3

        # self.temporal = nn.Sequential(
        #     nn.Conv1d(ch_count, 1024, 64, padding=1),
        #     nn.LeakyReLU(),
        #     nn.Dropout(dropprob),
        #     nn.Conv1d(1024, 512, 32, padding=1),
        #     nn.LeakyReLU(),
        #     nn.Dropout(dropprob),
        #     nn.Conv1d(512, 256, 32, padding=1),
        #     nn.LeakyReLU(),
        #     nn.Dropout(dropprob),
        #     # nn.AdaptiveAvgPool2d((256, 1))
        #     nn.AdaptiveAvgPool1d((1))
        # )

        # self.fc = nn.Sequential(
        #     nn.Linear(256, embedding_dim//2),
        #     nn.LeakyReLU(),
        #     nn.Dropout(dropprob),
        #     nn.Linear(embedding_dim//2, embedding_dim)
        # )

        # self.eeg_encoder = EEGNet(
        #     chunk_size=3000,
        #     num_electrodes=ch_count,
        #     dropout=dropprob,
        #     kernel_1=64,
        #     kernel_2=16,
        #     F1=8,
        #     F2=16,
        #     D=2,
        #     num_classes=embedding_dim
        # )

        self.eeg_transformer = EEGTransformer(
            n_channels=ch_count,
            patch_size=200,
            embed_dim=embedding_dim,
            num_layers=4,
            num_heads=8,
        )

    def compute_power_bands(self, x):
        fs = 500
        
        # eeg: (N, C, T)
        freqs = torch.fft.rfftfreq(x.size(-1), 1/fs).to(x.device)  # (F,)
        fft_vals = torch.fft.rfft(x, dim=-1)                         # (N, C, F)
        psd = (fft_vals.abs()**2)

        bands = [(0.5,4), (4,8), (8,12), (12,30), (30,49)]
        feats = []
        for low, high in bands:
            idx = (freqs >= low) & (freqs < high)
            band = psd[..., idx].mean(dim=-1)   # (N, C)
            feats.append(band)

        return torch.stack(feats, dim=2)        # (N, C, P)

    def forward(self, x):
        x = torch.fft.rfft(x, dim=2)
        x = torch.log(torch.abs(x) + 1e-8)

        # x = self.compute_power_bands(x)

        # x = self.temporal(x).squeeze(-1)
        # x = self.fc(x)

        # x = x.unsqueeze(1)
        # x = self.eeg_encoder(x)

        x = self.eeg_transformer(x)

        x = F.normalize(x, p=2, dim=1)
        # x = torch.tanh(x) * 3
        return x
    
class EEGClassifier(nn.Module):
    def __init__(self, embedding_dim=384):
        super(EEGClassifier, self).__init__()

        dropprob = .3

        self.classifier = nn.Sequential(
            nn.Linear(embedding_dim, 128),
            nn.LeakyReLU(),
            nn.Dropout(dropprob),
            nn.Linear(128, 16),
            nn.LeakyReLU(),
            nn.Dropout(dropprob),
            nn.Linear(16, 3)
        )

    def forward(self, x):
        return self.classifier(x)
    
class EEGCLIPModel(nn.Module):
    def __init__(self, ch_count=8196, embedding_dim=384, freeze_text=True):
        super(EEGCLIPModel, self).__init__()
        self.text_encoder = TextEncoder()
        self.eeg_encoder = EEGEncoder(ch_count=ch_count, embedding_dim=embedding_dim)
        self.eeg_classifier = EEGClassifier(embedding_dim=embedding_dim)

        if freeze_text:
            for param in self.text_encoder.parameters():
                param.requires_grad = False

    def forward(self, eeg_data, texts):
        eeg_embeddings = self.eeg_encoder(eeg_data)
        text_embeddings = self.text_encoder(texts)
        classification = self.eeg_classifier(eeg_embeddings)
        return eeg_embeddings, text_embeddings, classification

# Training

In [7]:
# model = EEGCLIPModel().to(device)
model = EEGCLIPModel(105, 768, freeze_text=False).to(device)
# model = EEGCLIPModel(1076, 768, freeze_text=True).to(device)
# model.load_state_dict(torch.load("last_model.pt"))

In [None]:
eeg, sent, text = ds[0]
res = model(eeg.unsqueeze(0).to(torch.float32).to("cuda"), [text])

res[0].shape, res[1].shape, res[2]

(torch.Size([1, 768]),
 torch.Size([1, 768]),
 tensor([[-0.1902, -0.0113,  0.0710]], device='cuda:0',
        grad_fn=<AddmmBackward0>))

In [None]:
# Text Embedding Normalization Space
all_embeds = model.text_encoder([lbl[1] for lbl in ds.labels])

text_mean = all_embeds.mean(dim=0)
text_std = all_embeds.std(dim=0).clamp(min=1e-6)

text_mean, text_std

In [None]:
# Classification Classes
classification_classes = ["Positive", "Neutral", "Negative"]
classification_classes_embeds = model.text_encoder(classification_classes)
classification_classes_embeds_z = (classification_classes_embeds-text_mean.unsqueeze(0))/text_std.unsqueeze(0)
classification_classes_embeds_z.shape

In [8]:
current_msg = ""
def log(txt: str, end="\n"):
    current_msg += txt + end

    if len(current_msg) > 1000:
        with open("logs.txt", "a") as f:
            f.write(current_msg)
        current_msg = ""

In [9]:
tau, alpha, beta = 1, .7, 1.3
loss_weights = (.5, 1, 1.5)

def train(model: nn.Module, train_loader: DataLoader, valid_loader: DataLoader, epochs: int = 10, print_every = 20):
# def train(model: nn.Module, train_dataset: WordEEGDataset, valid_dataset: WordEEGDataset, epochs: int = 10):
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()

    best_valid_loss = None

    for epoch in tqdm(range(epochs)):
        model.train()
        total_loss = 0.0
        train_count = 0
        train_correct = 0
        train_total_count = 0
        for batch in train_loader:
        # for batch in tqdm(train_loader):
            if batch[0] is None:
                continue
        # for batch in tqdm(train_dataset):
            # if isinstance(train_dataset, EEGDataset):
            #     batch = (batch[0].unsqueeze(0), [batch[1]])
            # batch = ([batch[0]], [batch[1]])
            B = batch[0].shape[0]
            eeg_data = batch[0].to(torch.float32).clone().detach().to(device) # (B,C,T)
            sent_lbl = batch[1].to(torch.int64).clone().detach().to(device) # (B)
            texts = batch[2] # [""]

            optimizer.zero_grad()
            eeg_embeddings, text_embeddings, sent_logits = model(eeg_data, texts) # (B,D)
            sent_logits: torch.Tensor

            cosine_loss = 2*(1.0 - F.cosine_similarity(eeg_embeddings, text_embeddings).mean())
            ce_loss = loss_fn(sent_logits, sent_lbl)

            sent_mask = sent_lbl.view(-1,1) == sent_lbl.view(1,-1)
            dist = torch.cdist(eeg_embeddings, eeg_embeddings, p=2)
            positive_losses = torch.relu(dist-tau*alpha)
            negative_losses = torch.relu(tau*beta-dist)
            positive_mask = torch.tril(sent_mask == True, -1)
            negative_mask = torch.tril(sent_mask == False, -1)
            if positive_mask.sum().item() == 0:
                positive_loss = 0
            else:
                positive_loss = (positive_losses[positive_mask]).sum().item() / (positive_mask.sum().item())
            if negative_mask.sum().item() == 0:
                negative_loss = 0
            else:
                negative_loss = (negative_losses[negative_mask]).sum().item() / (negative_mask.sum().item())
            clip_loss = positive_loss + negative_loss # clip loss
            
            loss = cosine_loss*loss_weights[0] + ce_loss*loss_weights[1] + clip_loss*loss_weights[2]

            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            train_count += 1

            sent_probs = sent_logits.softmax(dim=1)
            sent_preds = sent_probs.argmax(dim=1)
            train_correct += (sent_lbl == sent_preds).sum().item()
            train_total_count += B

            # for i in range(len(batch[1])):
            #     eeg_data = batch[0][i].to(torch.float32).to(device)
            #     texts = batch[1][i]

            #     optimizer.zero_grad()
            #     eeg_embeddings, text_embeddings = model(eeg_data, texts)

            #     # text_z = (text_embeddings-text_mean.unsqueeze(0))/text_std.unsqueeze(0)
            #     # loss = loss_fn(eeg_embeddings, text_z)
            #     loss = 1.0 - F.cosine_similarity(eeg_embeddings, text_embeddings).mean()
            #     loss.backward()
            #     optimizer.step()

            #     total_loss += loss.item()
            #     train_count += 1

            #     # Accuracy Check
            #     # text_norm = F.normalize(text_embeddings, p=2, dim=1)
            #     # eeg_norm = F.normalize(eeg_embeddings, p=2, dim=1)
            #     # classes_norm = F.normalize(classification_classes_embeds, p=2, dim=1)
            #     sim_ground = text_embeddings @ classification_classes_embeds.T
            #     sim_pred = eeg_embeddings @ classification_classes_embeds.T

            #     # sim_ground = -torch.cdist(text_z.clone().detach(), classification_classes_embeds_z, p=2)
            #     # sim_pred = -torch.cdist(eeg_embeddings.clone().detach(), classification_classes_embeds_z, p=2)

            #     ground = sim_ground.argmax(dim=1)
            #     pred = sim_pred.argmax(dim=1)
            #     train_correct += (ground == pred).sum().item()
            #     train_total_count += ground.shape[0]


        # avg_loss = total_loss / len(train_loader)
        avg_loss = total_loss / train_count
        avg_acc = train_correct / train_total_count


        model.eval()
        total_valid_loss = 0.0
        valid_count = 0
        valid_correct = 0
        valid_total_count = 0
        with torch.inference_mode():
            for batch in valid_loader:
            # for batch in tqdm(valid_loader):
                if batch[0] is None:
                    continue
            # for batch in tqdm(valid_dataset):
                # if isinstance(valid_dataset, EEGDataset):
                #     batch = (batch[0].unsqueeze(0), [batch[1]])
                # batch = ([batch[0]], [batch[1]])
                B = batch[0].shape[0]
                eeg_data = batch[0].to(torch.float32).clone().detach().to(device)
                sent_lbl = batch[1].to(torch.int64).clone().detach().to(device)
                texts = batch[2]

                eeg_embeddings, text_embeddings, sent_logits = model(eeg_data, texts)
                sent_logits: torch.Tensor

                cosine_loss = 2*(1.0 - F.cosine_similarity(eeg_embeddings, text_embeddings).mean())
                ce_loss = loss_fn(sent_logits, sent_lbl)

                sent_mask = sent_lbl.view(-1,1) == sent_lbl.view(1,-1)
                dist = torch.cdist(eeg_embeddings, eeg_embeddings, p=2)
                positive_losses = torch.relu(dist-tau*alpha)
                negative_losses = torch.relu(tau*beta-dist)
                positive_mask = torch.tril(sent_mask == True, -1)
                negative_mask = torch.tril(sent_mask == False, -1)
                if positive_mask.sum().item() == 0:
                    positive_loss = 0
                else:
                    positive_loss = (positive_losses[positive_mask]).sum().item() / (positive_mask.sum().item())
                if negative_mask.sum().item() == 0:
                    negative_loss = 0
                else:
                    negative_loss = (negative_losses[negative_mask]).sum().item() / (negative_mask.sum().item())
                clip_loss = positive_loss + negative_loss # clip loss

                loss = cosine_loss*loss_weights[0] + ce_loss*loss_weights[1] + clip_loss*loss_weights[2]

                total_valid_loss += loss.item()
                valid_count += 1

                sent_probs = sent_logits.softmax(dim=1)
                sent_preds = sent_probs.argmax(dim=1)
                valid_correct += (sent_lbl == sent_preds).sum().item()
                valid_total_count += B

                # for i in range(len(batch[0])):
                #     eeg_data = batch[0][i].to(torch.float32).to(device)
                #     texts = batch[1][i]

                #     eeg_embeddings, text_embeddings = model(eeg_data, texts)

                #     # text_z = (text_embeddings-text_mean.unsqueeze(0))/text_std.unsqueeze(0)
                #     # loss = loss_fn(eeg_embeddings, text_z)
                #     loss = 1.0 - F.cosine_similarity(eeg_embeddings, text_embeddings).mean()

                #     total_valid_loss += loss.item()
                #     valid_count += 1

                #     # Accuracy Check
                #     # text_norm = F.normalize(text_embeddings, p=2, dim=1)
                #     # eeg_norm = F.normalize(eeg_embeddings, p=2, dim=1)
                #     # classes_norm = F.normalize(classification_classes_embeds, p=2, dim=1)
                #     sim_ground = text_embeddings @ classification_classes_embeds.T
                #     sim_pred = eeg_embeddings @ classification_classes_embeds.T

                #     # sim_ground = -torch.cdist(text_z.clone().detach(), classification_classes_embeds_z, p=2)
                #     # sim_pred = -torch.cdist(eeg_embeddings.clone().detach(), classification_classes_embeds_z, p=2)

                #     ground = sim_ground.argmax(dim=1)
                #     pred = sim_pred.argmax(dim=1)
                #     valid_correct += (ground == pred).sum().item()
                #     valid_total_count += ground.shape[0]
                    
        # avg_valid_loss = total_valid_loss / len(valid_loader)
        avg_valid_loss = total_valid_loss / valid_count
        avg_valid_acc = valid_correct / valid_total_count

        if (best_valid_loss is None) or (avg_valid_loss < best_valid_loss):
            msg = (f"Valid Loss: {avg_valid_loss:.10f}")
            log(msg)
            best_valid_loss = avg_valid_loss
            torch.save(model.state_dict(), "best_model.pt")

        msg = (f"Epoch [{epoch+1}/{epochs}]:- Train Loss: {avg_loss:.6f} | Train Acc: {avg_acc*100:.4f}% | Valid Loss: {avg_valid_loss:.6f} | Valid Acc: {avg_valid_acc*100:.4f}%")
        log(msg)
        if (epoch % print_every == 0):
            print(msg)
        torch.save(model.state_dict(), "last_model.pt")

        torch.cuda.empty_cache()


In [10]:
classification_classes = ["Positive", "Neutral", "Negative"]

def test(model: nn.Module, test_loader: DataLoader):
# def test(model: nn.Module, test_dataset: WordEEGDataset):
    loss_fn = nn.CrossEntropyLoss()

    model.eval()
    total_loss = 0.0
    count = 0
    correct = 0
    total_count = 0

    actuals = []
    preds = []
    with torch.inference_mode():
        for batch in tqdm(test_loader):
        # for batch in tqdm(test_dataset):
            # if isinstance(test_dataset, EEGDataset):
            #     batch = (batch[0].unsqueeze(0), [batch[1]])
            # batch = ([batch[0]], [batch[1]])
            B = batch[0].shape[0]
            eeg_data = batch[0].to(torch.float32).clone().detach().to(device)
            sent_lbl = batch[1].to(torch.int64).clone().detach().to(device)
            texts = batch[2]

            eeg_embeddings, text_embeddings, sent_logits = model(eeg_data, texts)
            sent_logits: torch.Tensor

            cosine_loss = 2*(1.0 - F.cosine_similarity(eeg_embeddings, text_embeddings).mean())
            ce_loss = loss_fn(sent_logits, sent_lbl)

            sent_mask = sent_lbl.view(-1,1) == sent_lbl.view(1,-1)
            dist = torch.cdist(eeg_embeddings, eeg_embeddings, p=2)
            positive_losses = torch.relu(dist-tau*alpha)
            negative_losses = torch.relu(tau*beta-dist)
            positive_mask = torch.tril(sent_mask == True, -1)
            negative_mask = torch.tril(sent_mask == False, -1)
            if positive_mask.sum().item() == 0:
                positive_loss = 0
            else:
                positive_loss = (positive_losses[positive_mask]).sum().item() / (positive_mask.sum().item())
            if negative_mask.sum().item() == 0:
                negative_loss = 0
            else:
                negative_loss = (negative_losses[negative_mask]).sum().item() / (negative_mask.sum().item())
            clip_loss = positive_loss + negative_loss # clip loss

            loss = cosine_loss*loss_weights[0] + ce_loss*loss_weights[1] + clip_loss*loss_weights[2]

            total_loss += loss.item()
            count += 1

            sent_probs = sent_logits.softmax(dim=1)
            sent_preds = sent_probs.argmax(dim=1)
            correct += (sent_lbl == sent_preds).sum().item()
            total_count += B

            actuals.append(sent_lbl)
            preds.append(sent_preds)
            # for i in range(len(batch[0])):
            #     eeg_data = batch[0][i].to(torch.float32).to(device)
            #     texts = batch[1][i]

            #     eeg_embeddings, text_embeddings = model(eeg_data, texts)

            #     # text_z = (text_embeddings-text_mean.unsqueeze(0))/text_std.unsqueeze(0)
            #     # loss = loss_fn(eeg_embeddings, text_z)
            #     loss = 1.0 - F.cosine_similarity(eeg_embeddings, text_embeddings).mean()

            #     total_loss += loss.item()
            #     count += 1

            #     # Accuracy Check
            #     # text_norm = F.normalize(text_embeddings, p=2, dim=1)
            #     # eeg_norm = F.normalize(eeg_embeddings, p=2, dim=1)
            #     # classes_norm = F.normalize(classification_classes_embeds, p=2, dim=1)
            #     sim_ground = text_embeddings @ classification_classes_embeds.T
            #     sim_pred = eeg_embeddings @ classification_classes_embeds.T

            #     # sim_ground = -torch.cdist(text_z.clone().detach(), classification_classes_embeds_z, p=2)
            #     # sim_pred = -torch.cdist(eeg_embeddings.clone().detach(), classification_classes_embeds_z, p=2)

            #     ground = sim_ground.argmax(dim=1)
            #     pred = sim_pred.argmax(dim=1)
            #     correct += (ground == pred).sum().item()
            #     total_count += ground.shape[0]
            #     actuals.append(ground)
            #     preds.append(pred)

    avg_loss = total_loss / count
    avg_acc = correct / total_count
    print(f"Test Loss: {avg_loss:.6f} | Test Acc: {avg_acc*100:.4f}%")

    actuals = torch.cat(actuals, dim=0)
    preds = torch.cat(preds, dim=0)
    cr = classification_report(actuals.cpu(), preds.cpu(), labels=[0,1,2], target_names=classification_classes)
    print(cr)

In [11]:
train(model, train_dl, val_dl, epochs=20)
# train(model, train_ds, val_ds, epochs=3)
# train(model, ds, epochs=20)

  0%|          | 0/20 [00:01<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 842.00 MiB. GPU 0 has a total capacity of 3.63 GiB of which 122.25 MiB is free. Including non-PyTorch memory, this process has 3.34 GiB memory in use. Of the allocated memory 3.17 GiB is allocated by PyTorch, and 83.86 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [17]:
train(model, train_dl, val_dl, epochs=100)

100%|██████████| 135/135 [00:16<00:00,  8.03it/s]
100%|██████████| 29/29 [00:02<00:00, 10.09it/s]


Valid Loss: 3.2680826845
Epoch [1/100]:- Train Loss: 2.249491 | Train Acc: 59.3501% | Valid Loss: 3.268083 | Valid Acc: 37.6437%


100%|██████████| 135/135 [00:17<00:00,  7.70it/s]
100%|██████████| 29/29 [00:02<00:00, 10.05it/s]


Valid Loss: 3.2081084087
Epoch [2/100]:- Train Loss: 2.228060 | Train Acc: 61.1341% | Valid Loss: 3.208108 | Valid Acc: 34.4828%


100%|██████████| 135/135 [00:17<00:00,  7.69it/s]
100%|██████████| 29/29 [00:02<00:00,  9.74it/s]


Valid Loss: 3.1043478045
Epoch [3/100]:- Train Loss: 2.203040 | Train Acc: 62.8226% | Valid Loss: 3.104348 | Valid Acc: 38.5057%


100%|██████████| 135/135 [00:17<00:00,  7.71it/s]
100%|██████████| 29/29 [00:02<00:00,  9.83it/s]


Valid Loss: 3.0620493396
Epoch [4/100]:- Train Loss: 2.233826 | Train Acc: 60.5925% | Valid Loss: 3.062049 | Valid Acc: 35.7759%


100%|██████████| 135/135 [00:17<00:00,  7.73it/s]
100%|██████████| 29/29 [00:02<00:00, 10.14it/s]


Epoch [5/100]:- Train Loss: 2.178991 | Train Acc: 62.3447% | Valid Loss: 3.191448 | Valid Acc: 38.0747%


100%|██████████| 135/135 [00:17<00:00,  7.72it/s]
100%|██████████| 29/29 [00:02<00:00, 10.01it/s]


Epoch [6/100]:- Train Loss: 2.197638 | Train Acc: 62.5358% | Valid Loss: 3.196040 | Valid Acc: 38.7931%


100%|██████████| 135/135 [00:17<00:00,  7.89it/s]
100%|██████████| 29/29 [00:02<00:00, 10.24it/s]


Epoch [7/100]:- Train Loss: 2.170930 | Train Acc: 63.1411% | Valid Loss: 3.311077 | Valid Acc: 35.7759%


100%|██████████| 135/135 [00:17<00:00,  7.70it/s]
100%|██████████| 29/29 [00:02<00:00, 10.12it/s]


Epoch [8/100]:- Train Loss: 2.141625 | Train Acc: 64.5747% | Valid Loss: 3.390451 | Valid Acc: 35.3448%


100%|██████████| 135/135 [00:17<00:00,  7.76it/s]
100%|██████████| 29/29 [00:02<00:00, 10.14it/s]


Epoch [9/100]:- Train Loss: 2.176647 | Train Acc: 62.6314% | Valid Loss: 3.643703 | Valid Acc: 37.2126%


100%|██████████| 135/135 [00:17<00:00,  7.77it/s]
100%|██████████| 29/29 [00:03<00:00,  9.61it/s]


Epoch [10/100]:- Train Loss: 2.151043 | Train Acc: 63.7783% | Valid Loss: 3.409590 | Valid Acc: 34.3391%


100%|██████████| 135/135 [00:17<00:00,  7.71it/s]
100%|██████████| 29/29 [00:02<00:00,  9.98it/s]


Epoch [11/100]:- Train Loss: 2.145577 | Train Acc: 63.9694% | Valid Loss: 3.163512 | Valid Acc: 35.9195%


100%|██████████| 135/135 [00:17<00:00,  7.73it/s]
100%|██████████| 29/29 [00:02<00:00,  9.88it/s]


Epoch [12/100]:- Train Loss: 2.158830 | Train Acc: 64.5110% | Valid Loss: 3.442047 | Valid Acc: 36.6379%


100%|██████████| 135/135 [00:17<00:00,  7.76it/s]
100%|██████████| 29/29 [00:02<00:00,  9.99it/s]


Epoch [13/100]:- Train Loss: 2.174996 | Train Acc: 63.4916% | Valid Loss: 3.369437 | Valid Acc: 35.9195%


100%|██████████| 135/135 [00:17<00:00,  7.67it/s]
100%|██████████| 29/29 [00:02<00:00,  9.93it/s]


Epoch [14/100]:- Train Loss: 2.157163 | Train Acc: 63.8420% | Valid Loss: 3.392525 | Valid Acc: 36.2069%


100%|██████████| 135/135 [00:17<00:00,  7.77it/s]
100%|██████████| 29/29 [00:02<00:00, 10.08it/s]


Epoch [15/100]:- Train Loss: 2.136640 | Train Acc: 64.2561% | Valid Loss: 3.264354 | Valid Acc: 35.6322%


100%|██████████| 135/135 [00:17<00:00,  7.67it/s]
100%|██████████| 29/29 [00:02<00:00, 10.04it/s]


Epoch [16/100]:- Train Loss: 2.143131 | Train Acc: 65.0526% | Valid Loss: 3.109866 | Valid Acc: 35.7759%


100%|██████████| 135/135 [00:17<00:00,  7.72it/s]
100%|██████████| 29/29 [00:02<00:00,  9.91it/s]


Epoch [17/100]:- Train Loss: 2.123413 | Train Acc: 64.2561% | Valid Loss: 3.374861 | Valid Acc: 35.7759%


100%|██████████| 135/135 [00:17<00:00,  7.67it/s]
100%|██████████| 29/29 [00:03<00:00,  9.66it/s]


Epoch [18/100]:- Train Loss: 2.134039 | Train Acc: 65.5941% | Valid Loss: 3.320631 | Valid Acc: 36.6379%


100%|██████████| 135/135 [00:17<00:00,  7.66it/s]
100%|██████████| 29/29 [00:02<00:00,  9.89it/s]


Epoch [19/100]:- Train Loss: 2.099697 | Train Acc: 66.0083% | Valid Loss: 3.179503 | Valid Acc: 37.2126%


100%|██████████| 135/135 [00:17<00:00,  7.64it/s]
100%|██████████| 29/29 [00:02<00:00,  9.67it/s]


Epoch [20/100]:- Train Loss: 2.109516 | Train Acc: 66.7729% | Valid Loss: 3.169449 | Valid Acc: 35.0575%


100%|██████████| 135/135 [00:17<00:00,  7.69it/s]
100%|██████████| 29/29 [00:02<00:00, 10.00it/s]


Epoch [21/100]:- Train Loss: 2.062069 | Train Acc: 68.4932% | Valid Loss: 3.704993 | Valid Acc: 33.6207%


100%|██████████| 135/135 [00:17<00:00,  7.78it/s]
100%|██████████| 29/29 [00:02<00:00, 10.24it/s]


Epoch [22/100]:- Train Loss: 2.068074 | Train Acc: 67.6967% | Valid Loss: 3.221306 | Valid Acc: 38.6494%


100%|██████████| 135/135 [00:17<00:00,  7.67it/s]
100%|██████████| 29/29 [00:02<00:00, 10.04it/s]


Epoch [23/100]:- Train Loss: 2.092018 | Train Acc: 68.2064% | Valid Loss: 3.584303 | Valid Acc: 37.2126%


100%|██████████| 135/135 [00:17<00:00,  7.79it/s]
100%|██████████| 29/29 [00:03<00:00,  9.63it/s]


Epoch [24/100]:- Train Loss: 2.083319 | Train Acc: 68.1109% | Valid Loss: 3.475581 | Valid Acc: 37.6437%


100%|██████████| 135/135 [00:17<00:00,  7.72it/s]
100%|██████████| 29/29 [00:03<00:00,  9.66it/s]


Epoch [25/100]:- Train Loss: 2.081147 | Train Acc: 67.3781% | Valid Loss: 3.517997 | Valid Acc: 38.3621%


100%|██████████| 135/135 [00:17<00:00,  7.66it/s]
100%|██████████| 29/29 [00:02<00:00,  9.73it/s]


Epoch [26/100]:- Train Loss: 2.077596 | Train Acc: 66.8047% | Valid Loss: 3.322942 | Valid Acc: 39.3678%


100%|██████████| 135/135 [00:17<00:00,  7.82it/s]
100%|██████████| 29/29 [00:02<00:00, 10.02it/s]


Epoch [27/100]:- Train Loss: 2.046326 | Train Acc: 68.8754% | Valid Loss: 3.360196 | Valid Acc: 38.0747%


100%|██████████| 135/135 [00:17<00:00,  7.76it/s]
100%|██████████| 29/29 [00:02<00:00,  9.93it/s]


Epoch [28/100]:- Train Loss: 2.065541 | Train Acc: 69.6400% | Valid Loss: 3.520713 | Valid Acc: 39.6552%


100%|██████████| 135/135 [00:17<00:00,  7.67it/s]
100%|██████████| 29/29 [00:03<00:00,  9.59it/s]


Epoch [29/100]:- Train Loss: 2.059582 | Train Acc: 68.7480% | Valid Loss: 3.385758 | Valid Acc: 37.6437%


100%|██████████| 135/135 [00:17<00:00,  7.67it/s]
100%|██████████| 29/29 [00:02<00:00,  9.90it/s]


Epoch [30/100]:- Train Loss: 2.025690 | Train Acc: 70.6913% | Valid Loss: 3.668192 | Valid Acc: 38.2184%


100%|██████████| 135/135 [00:17<00:00,  7.73it/s]
100%|██████████| 29/29 [00:02<00:00, 10.11it/s]


Epoch [31/100]:- Train Loss: 2.043567 | Train Acc: 70.0223% | Valid Loss: 3.419722 | Valid Acc: 38.6494%


100%|██████████| 135/135 [00:17<00:00,  7.83it/s]
100%|██████████| 29/29 [00:03<00:00,  9.65it/s]


Epoch [32/100]:- Train Loss: 2.009348 | Train Acc: 70.4683% | Valid Loss: 3.218030 | Valid Acc: 37.6437%


100%|██████████| 135/135 [00:17<00:00,  7.69it/s]
100%|██████████| 29/29 [00:02<00:00, 10.03it/s]


Epoch [33/100]:- Train Loss: 2.011725 | Train Acc: 70.9780% | Valid Loss: 3.358367 | Valid Acc: 35.4885%


100%|██████████| 135/135 [00:17<00:00,  7.68it/s]
100%|██████████| 29/29 [00:02<00:00, 10.04it/s]


Epoch [34/100]:- Train Loss: 1.996947 | Train Acc: 72.6983% | Valid Loss: 3.607074 | Valid Acc: 34.3391%


100%|██████████| 135/135 [00:17<00:00,  7.71it/s]
100%|██████████| 29/29 [00:02<00:00, 10.07it/s]


Epoch [35/100]:- Train Loss: 2.006346 | Train Acc: 72.1249% | Valid Loss: 3.527680 | Valid Acc: 36.9253%


100%|██████████| 135/135 [00:17<00:00,  7.71it/s]
100%|██████████| 29/29 [00:02<00:00,  9.70it/s]


Epoch [36/100]:- Train Loss: 1.998595 | Train Acc: 71.3922% | Valid Loss: 3.539766 | Valid Acc: 34.6264%


100%|██████████| 135/135 [00:17<00:00,  7.68it/s]
100%|██████████| 29/29 [00:02<00:00, 10.25it/s]


Epoch [37/100]:- Train Loss: 1.998107 | Train Acc: 71.6470% | Valid Loss: 3.295211 | Valid Acc: 37.3563%


100%|██████████| 135/135 [00:17<00:00,  7.87it/s]
100%|██████████| 29/29 [00:02<00:00, 10.04it/s]


Epoch [38/100]:- Train Loss: 1.999004 | Train Acc: 72.2523% | Valid Loss: 3.565092 | Valid Acc: 37.5000%


100%|██████████| 135/135 [00:17<00:00,  7.75it/s]
100%|██████████| 29/29 [00:02<00:00,  9.81it/s]


Epoch [39/100]:- Train Loss: 1.986191 | Train Acc: 73.2399% | Valid Loss: 3.314063 | Valid Acc: 37.6437%


100%|██████████| 135/135 [00:17<00:00,  7.69it/s]
100%|██████████| 29/29 [00:02<00:00, 10.23it/s]


Epoch [40/100]:- Train Loss: 1.944652 | Train Acc: 73.9726% | Valid Loss: 4.255691 | Valid Acc: 35.9195%


100%|██████████| 135/135 [00:17<00:00,  7.74it/s]
100%|██████████| 29/29 [00:03<00:00,  9.59it/s]


Epoch [41/100]:- Train Loss: 1.979765 | Train Acc: 74.4505% | Valid Loss: 3.876803 | Valid Acc: 34.0517%


100%|██████████| 135/135 [00:17<00:00,  7.72it/s]
100%|██████████| 29/29 [00:02<00:00, 10.09it/s]


Epoch [42/100]:- Train Loss: 1.976251 | Train Acc: 73.9407% | Valid Loss: 3.414447 | Valid Acc: 34.4828%


100%|██████████| 135/135 [00:17<00:00,  7.77it/s]
100%|██████████| 29/29 [00:02<00:00,  9.74it/s]


Epoch [43/100]:- Train Loss: 1.946893 | Train Acc: 74.5460% | Valid Loss: 3.513112 | Valid Acc: 35.6322%


100%|██████████| 135/135 [00:17<00:00,  7.65it/s]
100%|██████████| 29/29 [00:02<00:00,  9.95it/s]


Epoch [44/100]:- Train Loss: 1.968735 | Train Acc: 73.4629% | Valid Loss: 3.252578 | Valid Acc: 32.6149%


100%|██████████| 135/135 [00:17<00:00,  7.72it/s]
100%|██████████| 29/29 [00:02<00:00,  9.82it/s]


Epoch [45/100]:- Train Loss: 1.966409 | Train Acc: 73.8133% | Valid Loss: 3.197411 | Valid Acc: 36.3506%


100%|██████████| 135/135 [00:17<00:00,  7.86it/s]
100%|██████████| 29/29 [00:02<00:00, 10.63it/s]


Epoch [46/100]:- Train Loss: 1.950946 | Train Acc: 75.9478% | Valid Loss: 3.752260 | Valid Acc: 36.6379%


100%|██████████| 135/135 [00:16<00:00,  8.29it/s]
100%|██████████| 29/29 [00:02<00:00, 10.95it/s]


Epoch [47/100]:- Train Loss: 1.938039 | Train Acc: 75.7566% | Valid Loss: 3.416536 | Valid Acc: 36.7816%


100%|██████████| 135/135 [00:16<00:00,  8.26it/s]
100%|██████████| 29/29 [00:02<00:00, 10.97it/s]


Epoch [48/100]:- Train Loss: 1.935113 | Train Acc: 76.2982% | Valid Loss: 4.479050 | Valid Acc: 33.6207%


100%|██████████| 135/135 [00:16<00:00,  8.32it/s]
100%|██████████| 29/29 [00:02<00:00, 11.14it/s]


Epoch [49/100]:- Train Loss: 1.961225 | Train Acc: 75.6610% | Valid Loss: 3.882872 | Valid Acc: 35.3448%


100%|██████████| 135/135 [00:16<00:00,  8.00it/s]
100%|██████████| 29/29 [00:02<00:00, 10.22it/s]


Epoch [50/100]:- Train Loss: 1.932819 | Train Acc: 76.5530% | Valid Loss: 3.405189 | Valid Acc: 34.4828%


100%|██████████| 135/135 [00:17<00:00,  7.71it/s]
100%|██████████| 29/29 [00:02<00:00,  9.69it/s]


Epoch [51/100]:- Train Loss: 1.912125 | Train Acc: 77.4450% | Valid Loss: 3.528910 | Valid Acc: 37.7874%


100%|██████████| 135/135 [00:17<00:00,  7.68it/s]
100%|██████████| 29/29 [00:02<00:00, 10.14it/s]


Epoch [52/100]:- Train Loss: 1.915843 | Train Acc: 76.3938% | Valid Loss: 3.560339 | Valid Acc: 33.1897%


100%|██████████| 135/135 [00:17<00:00,  7.77it/s]
100%|██████████| 29/29 [00:02<00:00,  9.90it/s]


Epoch [53/100]:- Train Loss: 1.898035 | Train Acc: 76.5530% | Valid Loss: 3.400937 | Valid Acc: 38.2184%


100%|██████████| 135/135 [00:17<00:00,  7.62it/s]
100%|██████████| 29/29 [00:02<00:00, 10.13it/s]


Epoch [54/100]:- Train Loss: 1.909557 | Train Acc: 76.9990% | Valid Loss: 3.393067 | Valid Acc: 35.7759%


100%|██████████| 135/135 [00:17<00:00,  7.68it/s]
100%|██████████| 29/29 [00:02<00:00,  9.81it/s]


Epoch [55/100]:- Train Loss: 1.890379 | Train Acc: 77.3813% | Valid Loss: 3.314253 | Valid Acc: 36.4943%


100%|██████████| 135/135 [00:17<00:00,  7.71it/s]
100%|██████████| 29/29 [00:02<00:00,  9.97it/s]


Epoch [56/100]:- Train Loss: 1.891142 | Train Acc: 78.3689% | Valid Loss: 3.440974 | Valid Acc: 35.2011%


100%|██████████| 135/135 [00:17<00:00,  7.61it/s]
100%|██████████| 29/29 [00:02<00:00,  9.77it/s]


Epoch [57/100]:- Train Loss: 1.880001 | Train Acc: 78.4008% | Valid Loss: 3.750726 | Valid Acc: 36.7816%


100%|██████████| 135/135 [00:17<00:00,  7.70it/s]
100%|██████████| 29/29 [00:02<00:00,  9.87it/s]


Epoch [58/100]:- Train Loss: 1.895034 | Train Acc: 78.6875% | Valid Loss: 3.726994 | Valid Acc: 36.7816%


100%|██████████| 135/135 [00:17<00:00,  7.69it/s]
100%|██████████| 29/29 [00:02<00:00,  9.86it/s]


Epoch [59/100]:- Train Loss: 1.913369 | Train Acc: 78.3052% | Valid Loss: 3.577477 | Valid Acc: 36.7816%


100%|██████████| 135/135 [00:17<00:00,  7.72it/s]
100%|██████████| 29/29 [00:02<00:00,  9.90it/s]


Epoch [60/100]:- Train Loss: 1.858537 | Train Acc: 79.9618% | Valid Loss: 3.632579 | Valid Acc: 35.6322%


100%|██████████| 135/135 [00:17<00:00,  7.70it/s]
100%|██████████| 29/29 [00:02<00:00, 10.00it/s]


Epoch [61/100]:- Train Loss: 1.884877 | Train Acc: 79.1653% | Valid Loss: 4.065130 | Valid Acc: 34.9138%


100%|██████████| 135/135 [00:17<00:00,  7.78it/s]
100%|██████████| 29/29 [00:02<00:00,  9.75it/s]


Epoch [62/100]:- Train Loss: 1.863423 | Train Acc: 79.5795% | Valid Loss: 3.848467 | Valid Acc: 36.2069%


100%|██████████| 135/135 [00:17<00:00,  7.64it/s]
100%|██████████| 29/29 [00:02<00:00, 10.29it/s]


Epoch [63/100]:- Train Loss: 1.881631 | Train Acc: 78.9105% | Valid Loss: 3.611138 | Valid Acc: 38.0747%


100%|██████████| 135/135 [00:17<00:00,  7.81it/s]
100%|██████████| 29/29 [00:02<00:00, 10.08it/s]


Epoch [64/100]:- Train Loss: 1.898721 | Train Acc: 78.7193% | Valid Loss: 3.871513 | Valid Acc: 37.0690%


100%|██████████| 135/135 [00:17<00:00,  7.82it/s]
100%|██████████| 29/29 [00:02<00:00, 10.10it/s]


Epoch [65/100]:- Train Loss: 1.856159 | Train Acc: 80.0892% | Valid Loss: 4.137388 | Valid Acc: 35.6322%


100%|██████████| 135/135 [00:17<00:00,  7.74it/s]
100%|██████████| 29/29 [00:02<00:00,  9.99it/s]


Epoch [66/100]:- Train Loss: 1.855988 | Train Acc: 80.9493% | Valid Loss: 4.289869 | Valid Acc: 33.4770%


100%|██████████| 135/135 [00:16<00:00,  8.02it/s]
100%|██████████| 29/29 [00:02<00:00, 10.55it/s]


Epoch [67/100]:- Train Loss: 1.886570 | Train Acc: 79.8662% | Valid Loss: 4.016235 | Valid Acc: 36.7816%


100%|██████████| 135/135 [00:17<00:00,  7.85it/s]
100%|██████████| 29/29 [00:02<00:00,  9.97it/s]


Epoch [68/100]:- Train Loss: 1.891394 | Train Acc: 78.9742% | Valid Loss: 3.440410 | Valid Acc: 34.6264%


100%|██████████| 135/135 [00:17<00:00,  7.65it/s]
100%|██████████| 29/29 [00:02<00:00,  9.78it/s]


Epoch [69/100]:- Train Loss: 1.858377 | Train Acc: 80.8856% | Valid Loss: 3.422334 | Valid Acc: 35.4885%


100%|██████████| 135/135 [00:17<00:00,  7.67it/s]
100%|██████████| 29/29 [00:02<00:00,  9.88it/s]


Epoch [70/100]:- Train Loss: 1.856368 | Train Acc: 80.3441% | Valid Loss: 3.813835 | Valid Acc: 34.9138%


100%|██████████| 135/135 [00:17<00:00,  7.73it/s]
100%|██████████| 29/29 [00:02<00:00, 10.27it/s]


Epoch [71/100]:- Train Loss: 1.851678 | Train Acc: 80.4078% | Valid Loss: 3.739020 | Valid Acc: 33.9080%


100%|██████████| 135/135 [00:17<00:00,  7.80it/s]
100%|██████████| 29/29 [00:02<00:00,  9.92it/s]


Epoch [72/100]:- Train Loss: 1.861102 | Train Acc: 80.1211% | Valid Loss: 3.871568 | Valid Acc: 37.5000%


100%|██████████| 135/135 [00:17<00:00,  7.71it/s]
100%|██████████| 29/29 [00:02<00:00,  9.83it/s]


Epoch [73/100]:- Train Loss: 1.821911 | Train Acc: 81.6821% | Valid Loss: 3.437742 | Valid Acc: 36.2069%


100%|██████████| 135/135 [00:17<00:00,  7.78it/s]
100%|██████████| 29/29 [00:02<00:00, 10.09it/s]


Epoch [74/100]:- Train Loss: 1.844900 | Train Acc: 81.3635% | Valid Loss: 3.613374 | Valid Acc: 35.7759%


100%|██████████| 135/135 [00:17<00:00,  7.66it/s]
100%|██████████| 29/29 [00:02<00:00,  9.92it/s]


Epoch [75/100]:- Train Loss: 1.843161 | Train Acc: 81.2361% | Valid Loss: 3.853928 | Valid Acc: 34.6264%


100%|██████████| 135/135 [00:17<00:00,  7.67it/s]
100%|██████████| 29/29 [00:02<00:00,  9.85it/s]


Epoch [76/100]:- Train Loss: 1.798637 | Train Acc: 82.0962% | Valid Loss: 4.378736 | Valid Acc: 37.0690%


100%|██████████| 135/135 [00:16<00:00,  8.19it/s]
100%|██████████| 29/29 [00:02<00:00, 11.11it/s]


Epoch [77/100]:- Train Loss: 1.814853 | Train Acc: 83.2749% | Valid Loss: 3.599463 | Valid Acc: 34.7701%


100%|██████████| 135/135 [00:15<00:00,  8.46it/s]
100%|██████████| 29/29 [00:02<00:00, 11.18it/s]


Epoch [78/100]:- Train Loss: 1.859672 | Train Acc: 82.0644% | Valid Loss: 3.719543 | Valid Acc: 36.0632%


100%|██████████| 135/135 [00:16<00:00,  8.29it/s]
100%|██████████| 29/29 [00:02<00:00, 10.99it/s]


Epoch [79/100]:- Train Loss: 1.847012 | Train Acc: 80.2485% | Valid Loss: 4.130186 | Valid Acc: 35.0575%


100%|██████████| 135/135 [00:16<00:00,  8.12it/s]
100%|██████████| 29/29 [00:02<00:00, 10.42it/s]


Epoch [80/100]:- Train Loss: 1.828756 | Train Acc: 81.3316% | Valid Loss: 3.867987 | Valid Acc: 35.4885%


100%|██████████| 135/135 [00:17<00:00,  7.67it/s]
100%|██████████| 29/29 [00:02<00:00,  9.98it/s]


Epoch [81/100]:- Train Loss: 1.781751 | Train Acc: 83.4979% | Valid Loss: 3.721936 | Valid Acc: 34.3391%


100%|██████████| 135/135 [00:17<00:00,  7.67it/s]
100%|██████████| 29/29 [00:02<00:00,  9.89it/s]


Epoch [82/100]:- Train Loss: 1.794812 | Train Acc: 83.2749% | Valid Loss: 4.105492 | Valid Acc: 32.7586%


100%|██████████| 135/135 [00:17<00:00,  7.73it/s]
100%|██████████| 29/29 [00:02<00:00, 10.07it/s]


Epoch [83/100]:- Train Loss: 1.814728 | Train Acc: 82.3829% | Valid Loss: 4.022973 | Valid Acc: 35.3448%


100%|██████████| 135/135 [00:17<00:00,  7.84it/s]
100%|██████████| 29/29 [00:02<00:00,  9.92it/s]


Epoch [84/100]:- Train Loss: 1.801545 | Train Acc: 82.7971% | Valid Loss: 3.738206 | Valid Acc: 33.6207%


100%|██████████| 135/135 [00:17<00:00,  7.66it/s]
100%|██████████| 29/29 [00:03<00:00,  9.66it/s]


Epoch [85/100]:- Train Loss: 1.778513 | Train Acc: 83.2749% | Valid Loss: 3.616576 | Valid Acc: 35.4885%


100%|██████████| 135/135 [00:17<00:00,  7.71it/s]
100%|██████████| 29/29 [00:02<00:00, 10.01it/s]


Epoch [86/100]:- Train Loss: 1.803291 | Train Acc: 83.4661% | Valid Loss: 3.853360 | Valid Acc: 34.9138%


100%|██████████| 135/135 [00:17<00:00,  7.75it/s]
100%|██████████| 29/29 [00:02<00:00, 10.58it/s]


Epoch [87/100]:- Train Loss: 1.783886 | Train Acc: 83.4979% | Valid Loss: 3.687244 | Valid Acc: 33.9080%


100%|██████████| 135/135 [00:17<00:00,  7.83it/s]
100%|██████████| 29/29 [00:02<00:00,  9.98it/s]


Epoch [88/100]:- Train Loss: 1.807667 | Train Acc: 82.4466% | Valid Loss: 4.165178 | Valid Acc: 33.9080%


100%|██████████| 135/135 [00:17<00:00,  7.66it/s]
100%|██████████| 29/29 [00:02<00:00, 10.19it/s]


Epoch [89/100]:- Train Loss: 1.811709 | Train Acc: 82.8289% | Valid Loss: 3.476168 | Valid Acc: 33.0460%


100%|██████████| 135/135 [00:17<00:00,  7.66it/s]
100%|██████████| 29/29 [00:02<00:00,  9.93it/s]


Epoch [90/100]:- Train Loss: 1.799220 | Train Acc: 83.9758% | Valid Loss: 4.243147 | Valid Acc: 34.7701%


100%|██████████| 135/135 [00:17<00:00,  7.69it/s]
100%|██████████| 29/29 [00:02<00:00,  9.90it/s]


Epoch [91/100]:- Train Loss: 1.778654 | Train Acc: 84.0395% | Valid Loss: 3.934400 | Valid Acc: 35.6322%


100%|██████████| 135/135 [00:17<00:00,  7.64it/s]
100%|██████████| 29/29 [00:02<00:00, 10.08it/s]


Epoch [92/100]:- Train Loss: 1.808064 | Train Acc: 83.3705% | Valid Loss: 4.106045 | Valid Acc: 33.7644%


100%|██████████| 135/135 [00:17<00:00,  7.71it/s]
100%|██████████| 29/29 [00:02<00:00,  9.95it/s]


Epoch [93/100]:- Train Loss: 1.802012 | Train Acc: 83.6572% | Valid Loss: 4.274442 | Valid Acc: 33.6207%


100%|██████████| 135/135 [00:17<00:00,  7.90it/s]
100%|██████████| 29/29 [00:02<00:00,  9.84it/s]


Epoch [94/100]:- Train Loss: 1.801616 | Train Acc: 84.3262% | Valid Loss: 4.814084 | Valid Acc: 34.0517%


100%|██████████| 135/135 [00:17<00:00,  7.57it/s]
100%|██████████| 29/29 [00:02<00:00,  9.72it/s]


Epoch [95/100]:- Train Loss: 1.779139 | Train Acc: 85.0908% | Valid Loss: 3.848107 | Valid Acc: 35.0575%


100%|██████████| 135/135 [00:17<00:00,  7.84it/s]
100%|██████████| 29/29 [00:02<00:00, 10.15it/s]


Epoch [96/100]:- Train Loss: 1.795931 | Train Acc: 84.1032% | Valid Loss: 3.684386 | Valid Acc: 33.1897%


100%|██████████| 135/135 [00:17<00:00,  7.71it/s]
100%|██████████| 29/29 [00:02<00:00,  9.98it/s]


Epoch [97/100]:- Train Loss: 1.756725 | Train Acc: 85.3138% | Valid Loss: 4.304977 | Valid Acc: 36.7816%


100%|██████████| 135/135 [00:17<00:00,  7.66it/s]
100%|██████████| 29/29 [00:02<00:00,  9.75it/s]


Epoch [98/100]:- Train Loss: 1.812097 | Train Acc: 82.9882% | Valid Loss: 3.865751 | Valid Acc: 36.0632%


100%|██████████| 135/135 [00:17<00:00,  7.62it/s]
100%|██████████| 29/29 [00:03<00:00,  9.60it/s]


Epoch [99/100]:- Train Loss: 1.768256 | Train Acc: 84.8996% | Valid Loss: 4.064985 | Valid Acc: 32.7586%


100%|██████████| 135/135 [00:17<00:00,  7.61it/s]
100%|██████████| 29/29 [00:02<00:00,  9.73it/s]


Epoch [100/100]:- Train Loss: 1.824573 | Train Acc: 83.4342% | Valid Loss: 4.328241 | Valid Acc: 34.4828%


In [19]:
model.load_state_dict(torch.load("best_model.pt"))
# model.load_state_dict(torch.load("last_model.pt"))
test(model, test_dl)
# test(model, test_ds)
# test(model, ds)

100%|██████████| 29/29 [00:03<00:00,  9.35it/s]

Test Loss: 2.942256 | Test Acc: 37.0370%
              precision    recall  f1-score   support

    Positive       0.29      0.30      0.30       191
     Neutral       0.42      0.50      0.45       271
    Negative       0.37      0.28      0.32       240

    accuracy                           0.37       702
   macro avg       0.36      0.36      0.36       702
weighted avg       0.37      0.37      0.37       702




