In [1]:
import torch

device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")
from encoder_paths import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# import numpy
# import random
# import torch

# numpy.random.seed(69)
# random.seed(69)
# torch.manual_seed(69)

In [3]:
import os.path

BASE_DIR = "/tmp/akshett.jindal"
HUGGINGFACE_CACHE_DIR = os.path.join(BASE_DIR, ".huggingface_cache")

In [4]:
import os.path

# DATA_DIR = "/tmp/semeval24_task3"

# TRAIN_DATA_FILEPATH = os.path.join(DATA_DIR, "final_clean_data", "train", "Subtask_2.json")
# VAL_DATA_FILEPATH = os.path.join(DATA_DIR, "final_clean_data", "val", "Subtask_2.json")

TRAIN_DATA_FILEPATH = "/tmp/semeval24_task3/SemEval-2024_Task3/official_data/Training_data/text/training.json"
VAL_DATA_FILEPATH = "/tmp/semeval24_task3/SemEval-2024_Task3/official_data/Training_data/text/testing.json"

In [5]:
import torch
import pickle

class YourAudioEncoder():
    def __init__(self, audio_embeddings_path):
        with open(audio_embeddings_path, "rb") as f:
            self.audio_embeddings = pickle.load(f)

    def lmao(self, audio_name):
        audio_name = audio_name.split(".")[0]
        audio_embedding = self.audio_embeddings[audio_name]
        audio_embedding = audio_embedding.squeeze()
        return torch.from_numpy(audio_embedding)
    
class YourVideoEncoder():
    def __init__(self, video_embeddings_path):
        with open(video_embeddings_path, "rb") as f:
            self.video_embeddings = pickle.load(f)

    def lmao(self, video_name):
        # video_name = video_name.split(".")[0]
        video_embedding = self.video_embeddings[video_name].reshape((16,-1))
        video_embedding = np.mean(video_embedding, axis=0)
        return torch.from_numpy(video_embedding)

class YourTextEncoder():
    def __init__(self, text_embeddings_path):
        with open(text_embeddings_path, "rb") as f:
            self.text_embeddings = pickle.load(f)

    def lmao(self, video_name):
        text_embedding = self.text_embeddings[video_name]
        return torch.from_numpy(text_embedding)


In [6]:
import json
import random
from torch.utils.data import Dataset

class EmotionCauseDataset(Dataset):
    def __init__(self, file_path, audio_encoder, video_encoder, text_encoder, neg_to_pos_ratio):
        with open(file_path) as f:
            self.file_data = json.load(f)

        self.audio_encoder = audio_encoder
        self.video_encoder = video_encoder
        self.text_encoder = text_encoder
        self.data = []
        self.POSITIVE_SAMPLE_COUNT = 0
        self.NEGATIVE_SAMPLE_COUNT = 0

        for conversation in self.file_data:
            positive_samples = []
            negative_samples = []

            utterances = {
                utterance["utterance_ID"]: utterance
                for utterance in conversation["conversation"]
            }
            utterance_ids = set(utterances.keys())

            causes = {
                utterance_id: []
                for utterance_id in utterance_ids
            }

            for emo_cause_pairs in conversation["emotion-cause_pairs"]:
                emotion_utterance_num, emotion = emo_cause_pairs[0].split("_")
                emotion_utterance_num = int(emotion_utterance_num)
                cause_utterance_num = int(emo_cause_pairs[1])

                if emotion != "neutral":
                    causes[emotion_utterance_num].append(cause_utterance_num)

            for emo_utterance_id in utterance_ids:
                emo_utterance = utterances[emo_utterance_id]

                if utterances[emo_utterance_id]["emotion"] == "neutral":
                    continue

                for cause_utterance_id in utterance_ids:
                    cause_utterance = utterances[cause_utterance_id]

                    is_cause = cause_utterance_id in causes[emo_utterance_id]

                    data_point = {
                        "original_utterance": {
                            "id": emo_utterance_id,
                            "text": emo_utterance["text"],
                            "video_name": emo_utterance["video_name"],
                        },
                        "cause_utterance": {
                            "id": cause_utterance_id,
                            "text": cause_utterance["text"],
                            "video_name": cause_utterance["video_name"],
                        },
                        "is_cause": is_cause,
                    }

                    if is_cause:
                        positive_samples.append(data_point)
                        self.POSITIVE_SAMPLE_COUNT += 1
                    else:
                        negative_samples.append(data_point)
                        self.NEGATIVE_SAMPLE_COUNT += 1

            random.shuffle(negative_samples)

            self.data.extend(positive_samples)
            if neg_to_pos_ratio is not None:
                self.data.extend(negative_samples[:min(neg_to_pos_ratio * len(positive_samples), len(negative_samples))])
            else:
                self.data.extend(negative_samples)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data = self.data[idx]

        orig_utt = data["original_utterance"]
        cause_utt = data["cause_utterance"]

        orig_id = orig_utt["id"]
        orig_text = orig_utt["text"]
        orig_video = orig_utt["video_name"]
        orig_audio = orig_utt["video_name"].replace(".mp4", ".wav")

        cause_id = cause_utt["id"]
        cause_text = cause_utt["text"]
        cause_video = cause_utt["video_name"]
        cause_audio = cause_utt["video_name"].replace(".mp4", ".wav")

        is_cause = data["is_cause"]

        return {
            "distance": abs(orig_id - cause_id),
            "original_audio": self.audio_encoder.lmao(orig_audio).float(),
            "original_video": self.video_encoder.lmao(orig_video).float(),
            "original_text": self.text_encoder.lmao(orig_video).squeeze().float(),
            "cause_audio": self.audio_encoder.lmao(cause_audio).float(),
            "cause_video": self.video_encoder.lmao(cause_video).float(),
            "cause_text": self.text_encoder.lmao(cause_video).squeeze().float(),
            "is_cause": 1.0 if is_cause else 0.0,
        }

In [7]:
import numpy as np
import os.path
from torch.utils.data import DataLoader

# AUDIO_EMBEDDINGS_FILEPATH = "/tmp/semeval24_task3/og_paper_embeddings/audio_embedding_6373.npy"
# VIDEO_EMBEDDINGS_FILEPATH = "/tmp/semeval24_task3/og_paper_embeddings/video_embedding_4096.npy"
# TEXT_EMBEDDINGS_FILEPATH = os.path.join(DATA_DIR, "text_embeddings", "text_embeddings_bert_base.pkl")

audio_encoder = YourAudioEncoder(AUDIO_EMBEDDINGS_FILEPATH)
video_encoder = YourVideoEncoder(VIDEO_EMBEDDINGS_FILEPATH)
text_encoder = YourTextEncoder(TEXT_EMBEDDINGS_FILEPATH)

trn_dataset = EmotionCauseDataset(
    TRAIN_DATA_FILEPATH,
    audio_encoder,
    video_encoder,
    text_encoder,
    neg_to_pos_ratio=5
)
trn_dataloader = DataLoader(trn_dataset, batch_size=64, shuffle=True)

val_dataset = EmotionCauseDataset(
    VAL_DATA_FILEPATH,
    audio_encoder,
    video_encoder,
    text_encoder,
    neg_to_pos_ratio=None,
)
val_dataloader = DataLoader(val_dataset, batch_size=64, shuffle=True)

len(trn_dataset), len(val_dataset)

(49835, 11047)

In [8]:
import numpy as np

def generate_positional_embeddings(dimension, count):
    embeddings = [list(np.zeros(dimension))]
    embeddings.extend([
        list(np.random.normal(loc=0.0, scale=0.1, size=dimension)) for _ in range(count)
    ])
    return np.array(embeddings)

In [9]:
import torch.nn as nn

class EmotionCauseDetector(nn.Module):
    def __init__(
        self,
        utterance_embedding_size,
        device,
        hidden_dimension=4096,
        positional_embeddings_dimension=200,
        dropout=0.2,
        *args, **kwargs,
    ):
        super().__init__()

        self.hidden_dimension = hidden_dimension

        positional_embeddings = generate_positional_embeddings(positional_embeddings_dimension, 200)
        self.positional_embeddings = torch.from_numpy(positional_embeddings).to(device).float()
        
        self.non_neutral_dropout = nn.Dropout(dropout)
        self.candidate_cause_dropout = nn.Dropout(dropout)
        self.distance_dropout = nn.Dropout(dropout)

        self.linear1 = nn.Linear(utterance_embedding_size*2 + positional_embeddings_dimension, hidden_dimension)
        self.linear1_activation = nn.ReLU()
        self.linear2 = nn.Linear(hidden_dimension, 1)

    def forward(self, non_neutral_utterances, candidate_cause_utterances, distances):
        positional_embedding = self.positional_embeddings[distances]
        
        non_neutral_utterances = self.non_neutral_dropout(non_neutral_utterances)
        candidate_cause_utterances = self.candidate_cause_dropout(candidate_cause_utterances)
        positional_embedding = self.distance_dropout(positional_embedding)

        embeddings = torch.concat((non_neutral_utterances, candidate_cause_utterances, positional_embedding), axis=1)

        return self.linear2(
            self.linear1_activation(
                self.linear1(embeddings)
            )
        )

In [10]:
import torch
import numpy

def save_model(epoch_num):
    torch.save(model.state_dict(), f"/tmp/semeval24_task3/baseline_models/pairing_models/paring_model_{epoch_num:02}.pt")
    numpy.save(
        f"/tmp/semeval24_task3/baseline_models/pairing_models/pairing_model_pos_embeds_{epoch_num:02}.npy",
        model.positional_embeddings.cpu().numpy(),
    )

In [11]:
from tqdm.auto import tqdm
from sklearn.metrics import classification_report
from transformers import get_linear_schedule_with_warmup
NUM_EPOCHS = 40
AUDIO_EMBEDDING_SIZE = 1024
VIDEO_EMBEDDING_SIZE = 768
TEXT_EMBEDDING_SIZE = 1024
TOTAL_EMBEDDING_SIZE = (
    AUDIO_EMBEDDING_SIZE + VIDEO_EMBEDDING_SIZE + TEXT_EMBEDDING_SIZE
)
# TOTAL_EMBEDDING_SIZE = 11237

model = EmotionCauseDetector(
    TOTAL_EMBEDDING_SIZE,
    device,
    hidden_dimension=2000,
)
_ = model.to(device)

weight_ratio = trn_dataset.NEGATIVE_SAMPLE_COUNT / trn_dataset.POSITIVE_SAMPLE_COUNT

criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(weight_ratio).to(device))
optimizer = torch.optim.AdamW(model.parameters(), lr=0.00001)

total_steps = len(trn_dataloader) * NUM_EPOCHS

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps
)

best_model_file = None
best_val_loss = float('inf')
best_epoch = -1
best_classification_report = None

print(AUDIO_EMBEDDINGS_FILEPATH)
print(VIDEO_EMBEDDINGS_FILEPATH)
print(TEXT_EMBEDDINGS_FILEPATH)
epoch_iter = tqdm(range(NUM_EPOCHS), desc="Epoch", position=0)
for epoch in epoch_iter:
    model.train()
    total_loss = 0.0

    for batch in tqdm(trn_dataloader, desc="Train Data Batch", position=1, leave=False):
        distances = batch["distance"].to(device)

        orig_audios = batch["original_audio"].to(device)
        orig_videos = batch["original_video"].to(device)
        orig_texts = batch["original_text"].to(device)

        cause_audios = batch["cause_audio"].to(device)
        cause_videos = batch["cause_video"].to(device)
        cause_texts = batch["cause_text"].to(device)

        is_cause = batch["is_cause"].to(device)

        orig_embedding = torch.cat((orig_audios, orig_videos, orig_texts), axis=1).float()
        cause_embedding = torch.cat((cause_audios, cause_videos, cause_texts), axis=1).float()

        output_logits = model(orig_embedding, cause_embedding, distances).squeeze()

        loss = criterion(output_logits, is_cause)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    
    scheduler.step()

    model.eval()
    total_val_loss = 0.0
    total_val_correct = 0
    total_val_predictions = 0

    true_labels = []
    predicted_labels = []

    with torch.no_grad():
        for val_batch in tqdm(val_dataloader, desc="Val Data Batch", position=1, leave=False):
            distances = val_batch["distance"].to(device)

            orig_audios = val_batch["original_audio"].to(device)
            orig_videos = val_batch["original_video"].to(device)
            orig_texts = val_batch["original_text"].to(device)

            cause_audios = val_batch["cause_audio"].to(device)
            cause_videos = val_batch["cause_video"].to(device)
            cause_texts = val_batch["cause_text"].to(device)

            is_cause = val_batch["is_cause"].to(device)

            orig_embedding = torch.cat((orig_audios, orig_videos, orig_texts), axis=1).float()
            cause_embedding = torch.cat((cause_audios, cause_videos, cause_texts), axis=1).float()

            output_logits = model(orig_embedding, cause_embedding, distances).squeeze()

            val_loss = criterion(output_logits, is_cause)

            total_val_loss += val_loss.item()

            predicted_is_cause = (output_logits >= 0.5).float()

            correct_predictions_val = (is_cause == predicted_is_cause).sum().item()

            total_val_correct += correct_predictions_val
            total_val_predictions += (predicted_is_cause == 1.0).sum().item()

            true_labels.extend(is_cause.cpu().numpy())
            predicted_labels.extend(predicted_is_cause.cpu().numpy())

    report = classification_report(true_labels, predicted_labels)

    avg_loss = total_loss / len(trn_dataset)
    avg_val_loss = total_val_loss / len(val_dataset)
    
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        best_epoch = epoch
        best_classification_report = report
        # best_model_file = save_model(epoch)
        torch.save(model.state_dict(), f"/tmp/semeval24_task3/baseline_models/pairing_models/paring_model_best_model.pt")
        numpy.save(
            f"/tmp/semeval24_task3/baseline_models/pairing_models/pairing_model_pos_embeds_best_model.npy",
            model.positional_embeddings.cpu().numpy(),
        )

    print(f"Epoch [{epoch+1:02}/{NUM_EPOCHS}] Classification Report:\n{report}")
    print(f"Epoch [{epoch+1:02}/{NUM_EPOCHS}] Train Loss: {avg_loss}")
    print(f"Epoch [{epoch+1:02}/{NUM_EPOCHS}] Validation Loss: {avg_val_loss}")
    print("------------------------------------------------------------")

    save_model(epoch)

print("===================BEST MODEL===================")
print(f"Best Model Epoch: {best_epoch}")
print(f"Best Model Validation Loss: {best_val_loss}")
print(f"Best Model Classification Report:\n{best_classification_report}")
print("================================================")

# Current best, Epoch 18

/tmp/semeval24_task3/audio_embeddings/audio_embeddings_facebook_wav2vec2-large-960h.pkl
/tmp/semeval24_task3/video_embeddings/final_embeddings.pkl
/tmp/semeval24_task3/text_embeddings/text_embeddings_roberta_large.pkl


Epoch:   2%|▎         | 1/40 [00:09<06:27,  9.93s/it]

Epoch [01/40] Classification Report:
              precision    recall  f1-score   support

         0.0       0.96      0.64      0.77     10056
         1.0       0.17      0.75      0.28       991

    accuracy                           0.65     11047
   macro avg       0.57      0.70      0.52     11047
weighted avg       0.89      0.65      0.72     11047

Epoch [01/40] Train Loss: 0.023019805240758758
Epoch [01/40] Validation Loss: 0.017721938183817048
------------------------------------------------------------


Epoch:   5%|▌         | 2/40 [00:18<05:56,  9.38s/it]

Epoch [02/40] Classification Report:
              precision    recall  f1-score   support

         0.0       0.98      0.56      0.72     10056
         1.0       0.17      0.90      0.28       991

    accuracy                           0.59     11047
   macro avg       0.58      0.73      0.50     11047
weighted avg       0.91      0.59      0.68     11047

Epoch [02/40] Train Loss: 0.020377223390874957
Epoch [02/40] Validation Loss: 0.017537865014028445
------------------------------------------------------------




Epoch [03/40] Classification Report:
              precision    recall  f1-score   support

         0.0       0.98      0.75      0.85     10056
         1.0       0.25      0.84      0.39       991

    accuracy                           0.76     11047
   macro avg       0.62      0.80      0.62     11047
weighted avg       0.91      0.76      0.81     11047

Epoch [03/40] Train Loss: 0.018737262210572068
Epoch [03/40] Validation Loss: 0.014832073480149373
------------------------------------------------------------


Epoch:  10%|█         | 4/40 [00:36<05:25,  9.03s/it]

Epoch [04/40] Classification Report:
              precision    recall  f1-score   support

         0.0       0.98      0.78      0.87     10056
         1.0       0.27      0.84      0.41       991

    accuracy                           0.79     11047
   macro avg       0.63      0.81      0.64     11047
weighted avg       0.92      0.79      0.83     11047

Epoch [04/40] Train Loss: 0.01746926741273728
Epoch [04/40] Validation Loss: 0.013907855917332486
------------------------------------------------------------


Epoch:  12%|█▎        | 5/40 [00:45<05:16,  9.04s/it]

Epoch [05/40] Classification Report:
              precision    recall  f1-score   support

         0.0       0.99      0.65      0.79     10056
         1.0       0.21      0.92      0.34       991

    accuracy                           0.68     11047
   macro avg       0.60      0.79      0.56     11047
weighted avg       0.92      0.68      0.75     11047

Epoch [05/40] Train Loss: 0.016527242160165753
Epoch [05/40] Validation Loss: 0.015354909502065032
------------------------------------------------------------


Epoch:  15%|█▌        | 6/40 [00:54<05:04,  8.95s/it]

Epoch [06/40] Classification Report:
              precision    recall  f1-score   support

         0.0       0.99      0.73      0.84     10056
         1.0       0.25      0.90      0.39       991

    accuracy                           0.75     11047
   macro avg       0.62      0.82      0.62     11047
weighted avg       0.92      0.75      0.80     11047

Epoch [06/40] Train Loss: 0.01579347915844255
Epoch [06/40] Validation Loss: 0.013766407985868704
------------------------------------------------------------


Epoch:  18%|█▊        | 7/40 [01:03<04:57,  9.03s/it]

Epoch [07/40] Classification Report:
              precision    recall  f1-score   support

         0.0       0.98      0.80      0.88     10056
         1.0       0.30      0.87      0.45       991

    accuracy                           0.81     11047
   macro avg       0.64      0.84      0.67     11047
weighted avg       0.92      0.81      0.85     11047

Epoch [07/40] Train Loss: 0.015204418182190963
Epoch [07/40] Validation Loss: 0.012321149869240877
------------------------------------------------------------




Epoch [08/40] Classification Report:
              precision    recall  f1-score   support

         0.0       0.98      0.81      0.89     10056
         1.0       0.31      0.87      0.46       991

    accuracy                           0.82     11047
   macro avg       0.65      0.84      0.67     11047
weighted avg       0.92      0.82      0.85     11047

Epoch [08/40] Train Loss: 0.014720347600291117
Epoch [08/40] Validation Loss: 0.011945050958622826
------------------------------------------------------------


Epoch:  22%|██▎       | 9/40 [01:21<04:37,  8.94s/it]

Epoch [09/40] Classification Report:
              precision    recall  f1-score   support

         0.0       0.99      0.79      0.88     10056
         1.0       0.30      0.88      0.44       991

    accuracy                           0.80     11047
   macro avg       0.64      0.84      0.66     11047
weighted avg       0.92      0.80      0.84     11047

Epoch [09/40] Train Loss: 0.01436860574574888
Epoch [09/40] Validation Loss: 0.01199325726455859
------------------------------------------------------------


Epoch:  25%|██▌       | 10/40 [01:30<04:26,  8.87s/it]

Epoch [10/40] Classification Report:
              precision    recall  f1-score   support

         0.0       0.99      0.77      0.87     10056
         1.0       0.28      0.90      0.43       991

    accuracy                           0.78     11047
   macro avg       0.63      0.84      0.65     11047
weighted avg       0.92      0.78      0.83     11047

Epoch [10/40] Train Loss: 0.014087434977784401
Epoch [10/40] Validation Loss: 0.01227590945850424
------------------------------------------------------------


Epoch:  28%|██▊       | 11/40 [01:39<04:17,  8.90s/it]

Epoch [11/40] Classification Report:
              precision    recall  f1-score   support

         0.0       0.99      0.80      0.88     10056
         1.0       0.30      0.88      0.45       991

    accuracy                           0.81     11047
   macro avg       0.64      0.84      0.67     11047
weighted avg       0.92      0.81      0.84     11047

Epoch [11/40] Train Loss: 0.013828440276044254
Epoch [11/40] Validation Loss: 0.011613748835212377
------------------------------------------------------------




Epoch [12/40] Classification Report:
              precision    recall  f1-score   support

         0.0       0.99      0.80      0.88     10056
         1.0       0.30      0.89      0.45       991

    accuracy                           0.81     11047
   macro avg       0.64      0.84      0.67     11047
weighted avg       0.93      0.81      0.84     11047

Epoch [12/40] Train Loss: 0.013589610907820153
Epoch [12/40] Validation Loss: 0.011530135414704816
------------------------------------------------------------


Epoch:  32%|███▎      | 13/40 [01:57<04:01,  8.95s/it]

Epoch [13/40] Classification Report:
              precision    recall  f1-score   support

         0.0       0.99      0.79      0.88     10056
         1.0       0.30      0.89      0.44       991

    accuracy                           0.80     11047
   macro avg       0.64      0.84      0.66     11047
weighted avg       0.92      0.80      0.84     11047

Epoch [13/40] Train Loss: 0.013429646918290442
Epoch [13/40] Validation Loss: 0.011649796067722757
------------------------------------------------------------




Epoch [14/40] Classification Report:
              precision    recall  f1-score   support

         0.0       0.98      0.85      0.91     10056
         1.0       0.35      0.84      0.49       991

    accuracy                           0.85     11047
   macro avg       0.67      0.84      0.70     11047
weighted avg       0.92      0.85      0.87     11047

Epoch [14/40] Train Loss: 0.013315895775298086
Epoch [14/40] Validation Loss: 0.010550904160911109
------------------------------------------------------------


Epoch:  38%|███▊      | 15/40 [02:14<03:42,  8.89s/it]

Epoch [15/40] Classification Report:
              precision    recall  f1-score   support

         0.0       0.99      0.81      0.89     10056
         1.0       0.31      0.88      0.46       991

    accuracy                           0.81     11047
   macro avg       0.65      0.84      0.67     11047
weighted avg       0.93      0.81      0.85     11047

Epoch [15/40] Train Loss: 0.013227813803537438
Epoch [15/40] Validation Loss: 0.011142473856228154
------------------------------------------------------------




Epoch [16/40] Classification Report:
              precision    recall  f1-score   support

         0.0       0.98      0.85      0.91     10056
         1.0       0.36      0.84      0.51       991

    accuracy                           0.85     11047
   macro avg       0.67      0.85      0.71     11047
weighted avg       0.93      0.85      0.88     11047

Epoch [16/40] Train Loss: 0.013066109219890034
Epoch [16/40] Validation Loss: 0.010295246999761323
------------------------------------------------------------


Epoch:  42%|████▎     | 17/40 [02:32<03:23,  8.87s/it]

Epoch [17/40] Classification Report:
              precision    recall  f1-score   support

         0.0       0.99      0.80      0.89     10056
         1.0       0.31      0.89      0.46       991

    accuracy                           0.81     11047
   macro avg       0.65      0.85      0.67     11047
weighted avg       0.93      0.81      0.85     11047

Epoch [17/40] Train Loss: 0.012921296439087692
Epoch [17/40] Validation Loss: 0.011136125945849073
------------------------------------------------------------




Epoch [18/40] Classification Report:
              precision    recall  f1-score   support

         0.0       0.98      0.84      0.91     10056
         1.0       0.34      0.85      0.49       991

    accuracy                           0.84     11047
   macro avg       0.66      0.85      0.70     11047
weighted avg       0.93      0.84      0.87     11047

Epoch [18/40] Train Loss: 0.012886253052394301
Epoch [18/40] Validation Loss: 0.010423825883440353
------------------------------------------------------------


Epoch:  48%|████▊     | 19/40 [02:50<03:06,  8.89s/it]

Epoch [19/40] Classification Report:
              precision    recall  f1-score   support

         0.0       0.98      0.83      0.90     10056
         1.0       0.34      0.86      0.48       991

    accuracy                           0.84     11047
   macro avg       0.66      0.85      0.69     11047
weighted avg       0.93      0.84      0.86     11047

Epoch [19/40] Train Loss: 0.012710328418250089
Epoch [19/40] Validation Loss: 0.010514792001838982
------------------------------------------------------------




Epoch [20/40] Classification Report:
              precision    recall  f1-score   support

         0.0       0.98      0.85      0.91     10056
         1.0       0.36      0.84      0.51       991

    accuracy                           0.85     11047
   macro avg       0.67      0.85      0.71     11047
weighted avg       0.93      0.85      0.88     11047

Epoch [20/40] Train Loss: 0.01264308993739267
Epoch [20/40] Validation Loss: 0.010144181939781059
------------------------------------------------------------


Epoch:  52%|█████▎    | 21/40 [03:08<02:48,  8.86s/it]

Epoch [21/40] Classification Report:
              precision    recall  f1-score   support

         0.0       0.99      0.82      0.89     10056
         1.0       0.32      0.88      0.47       991

    accuracy                           0.82     11047
   macro avg       0.65      0.85      0.68     11047
weighted avg       0.93      0.82      0.86     11047

Epoch [21/40] Train Loss: 0.012643137496183178
Epoch [21/40] Validation Loss: 0.010695002558918853
------------------------------------------------------------


Epoch:  55%|█████▌    | 22/40 [03:16<02:38,  8.81s/it]

Epoch [22/40] Classification Report:
              precision    recall  f1-score   support

         0.0       0.98      0.84      0.91     10056
         1.0       0.35      0.85      0.49       991

    accuracy                           0.84     11047
   macro avg       0.67      0.85      0.70     11047
weighted avg       0.93      0.84      0.87     11047

Epoch [22/40] Train Loss: 0.012594211163505138
Epoch [22/40] Validation Loss: 0.010203967989164928
------------------------------------------------------------


Epoch:  57%|█████▊    | 23/40 [03:25<02:29,  8.78s/it]

Epoch [23/40] Classification Report:
              precision    recall  f1-score   support

         0.0       0.99      0.77      0.86     10056
         1.0       0.28      0.91      0.43       991

    accuracy                           0.78     11047
   macro avg       0.63      0.84      0.65     11047
weighted avg       0.92      0.78      0.82     11047

Epoch [23/40] Train Loss: 0.012420908698362704
Epoch [23/40] Validation Loss: 0.012134027492923057
------------------------------------------------------------


Epoch:  60%|██████    | 24/40 [03:34<02:20,  8.76s/it]

Epoch [24/40] Classification Report:
              precision    recall  f1-score   support

         0.0       0.99      0.81      0.89     10056
         1.0       0.32      0.89      0.47       991

    accuracy                           0.82     11047
   macro avg       0.65      0.85      0.68     11047
weighted avg       0.93      0.82      0.85     11047

Epoch [24/40] Train Loss: 0.012536312672914824
Epoch [24/40] Validation Loss: 0.010765203331058607
------------------------------------------------------------


Epoch:  62%|██████▎   | 25/40 [03:43<02:12,  8.83s/it]

Epoch [25/40] Classification Report:
              precision    recall  f1-score   support

         0.0       0.98      0.86      0.92     10056
         1.0       0.37      0.84      0.51       991

    accuracy                           0.86     11047
   macro avg       0.67      0.85      0.71     11047
weighted avg       0.93      0.86      0.88     11047

Epoch [25/40] Train Loss: 0.012337094238018643
Epoch [25/40] Validation Loss: 0.009961833487222509
------------------------------------------------------------


Epoch:  65%|██████▌   | 26/40 [03:52<02:03,  8.82s/it]

Epoch [26/40] Classification Report:
              precision    recall  f1-score   support

         0.0       0.99      0.82      0.90     10056
         1.0       0.33      0.88      0.48       991

    accuracy                           0.83     11047
   macro avg       0.66      0.85      0.69     11047
weighted avg       0.93      0.83      0.86     11047

Epoch [26/40] Train Loss: 0.01225475004423198
Epoch [26/40] Validation Loss: 0.010454217664500734
------------------------------------------------------------


Epoch:  68%|██████▊   | 27/40 [04:00<01:55,  8.85s/it]

Epoch [27/40] Classification Report:
              precision    recall  f1-score   support

         0.0       0.99      0.80      0.88     10056
         1.0       0.30      0.90      0.45       991

    accuracy                           0.81     11047
   macro avg       0.65      0.85      0.67     11047
weighted avg       0.93      0.81      0.84     11047

Epoch [27/40] Train Loss: 0.012170719753150493
Epoch [27/40] Validation Loss: 0.011216373065381912
------------------------------------------------------------


Epoch:  70%|███████   | 28/40 [04:09<01:45,  8.82s/it]

Epoch [28/40] Classification Report:
              precision    recall  f1-score   support

         0.0       0.98      0.84      0.90     10056
         1.0       0.34      0.86      0.49       991

    accuracy                           0.84     11047
   macro avg       0.66      0.85      0.70     11047
weighted avg       0.93      0.84      0.87     11047

Epoch [28/40] Train Loss: 0.012230384573309109
Epoch [28/40] Validation Loss: 0.010183575811055728
------------------------------------------------------------


Epoch:  72%|███████▎  | 29/40 [04:18<01:36,  8.78s/it]

Epoch [29/40] Classification Report:
              precision    recall  f1-score   support

         0.0       0.98      0.85      0.91     10056
         1.0       0.36      0.84      0.50       991

    accuracy                           0.85     11047
   macro avg       0.67      0.85      0.71     11047
weighted avg       0.93      0.85      0.87     11047

Epoch [29/40] Train Loss: 0.012151601501033562
Epoch [29/40] Validation Loss: 0.010061267484035748
------------------------------------------------------------


Epoch:  75%|███████▌  | 30/40 [04:27<01:27,  8.76s/it]

Epoch [30/40] Classification Report:
              precision    recall  f1-score   support

         0.0       0.98      0.85      0.91     10056
         1.0       0.36      0.84      0.51       991

    accuracy                           0.85     11047
   macro avg       0.67      0.85      0.71     11047
weighted avg       0.93      0.85      0.88     11047

Epoch [30/40] Train Loss: 0.012079906590325616
Epoch [30/40] Validation Loss: 0.00999752428459374
------------------------------------------------------------


Epoch:  78%|███████▊  | 31/40 [04:35<01:18,  8.74s/it]

Epoch [31/40] Classification Report:
              precision    recall  f1-score   support

         0.0       0.99      0.83      0.90     10056
         1.0       0.34      0.87      0.49       991

    accuracy                           0.84     11047
   macro avg       0.66      0.85      0.69     11047
weighted avg       0.93      0.84      0.86     11047

Epoch [31/40] Train Loss: 0.012042137420873867
Epoch [31/40] Validation Loss: 0.010310363571512934
------------------------------------------------------------




Epoch [32/40] Classification Report:
              precision    recall  f1-score   support

         0.0       0.99      0.79      0.88     10056
         1.0       0.30      0.91      0.45       991

    accuracy                           0.80     11047
   macro avg       0.64      0.85      0.66     11047
weighted avg       0.93      0.80      0.84     11047

Epoch [32/40] Train Loss: 0.011925714443973464
Epoch [32/40] Validation Loss: 0.011367669704421777
------------------------------------------------------------


Epoch:  82%|████████▎ | 33/40 [04:53<01:01,  8.82s/it]

Epoch [33/40] Classification Report:
              precision    recall  f1-score   support

         0.0       0.99      0.79      0.88     10056
         1.0       0.30      0.90      0.45       991

    accuracy                           0.80     11047
   macro avg       0.64      0.85      0.66     11047
weighted avg       0.93      0.80      0.84     11047

Epoch [33/40] Train Loss: 0.011854307895691045
Epoch [33/40] Validation Loss: 0.0113405588541473
------------------------------------------------------------


Epoch:  85%|████████▌ | 34/40 [05:02<00:52,  8.80s/it]

Epoch [34/40] Classification Report:
              precision    recall  f1-score   support

         0.0       0.99      0.82      0.90     10056
         1.0       0.33      0.88      0.48       991

    accuracy                           0.83     11047
   macro avg       0.66      0.85      0.69     11047
weighted avg       0.93      0.83      0.86     11047

Epoch [34/40] Train Loss: 0.011849114518288878
Epoch [34/40] Validation Loss: 0.010443106616719235
------------------------------------------------------------


Epoch:  88%|████████▊ | 35/40 [05:11<00:43,  8.78s/it]

Epoch [35/40] Classification Report:
              precision    recall  f1-score   support

         0.0       0.98      0.83      0.90     10056
         1.0       0.34      0.87      0.49       991

    accuracy                           0.84     11047
   macro avg       0.66      0.85      0.70     11047
weighted avg       0.93      0.84      0.87     11047

Epoch [35/40] Train Loss: 0.01174270848134657
Epoch [35/40] Validation Loss: 0.010229980934406352
------------------------------------------------------------


Epoch:  90%|█████████ | 36/40 [05:19<00:35,  8.75s/it]

Epoch [36/40] Classification Report:
              precision    recall  f1-score   support

         0.0       0.98      0.84      0.90     10056
         1.0       0.34      0.86      0.49       991

    accuracy                           0.84     11047
   macro avg       0.66      0.85      0.70     11047
weighted avg       0.93      0.84      0.87     11047

Epoch [36/40] Train Loss: 0.01169726155610412
Epoch [36/40] Validation Loss: 0.010193402741936886
------------------------------------------------------------




Epoch [37/40] Classification Report:
              precision    recall  f1-score   support

         0.0       0.99      0.80      0.89     10056
         1.0       0.31      0.89      0.46       991

    accuracy                           0.81     11047
   macro avg       0.65      0.85      0.67     11047
weighted avg       0.93      0.81      0.85     11047

Epoch [37/40] Train Loss: 0.011499413647173929
Epoch [37/40] Validation Loss: 0.010989686616532187
------------------------------------------------------------


Epoch:  95%|█████████▌| 38/40 [05:37<00:17,  8.77s/it]

Epoch [38/40] Classification Report:
              precision    recall  f1-score   support

         0.0       0.98      0.83      0.90     10056
         1.0       0.34      0.86      0.49       991

    accuracy                           0.84     11047
   macro avg       0.66      0.85      0.70     11047
weighted avg       0.93      0.84      0.87     11047

Epoch [38/40] Train Loss: 0.011621868836603553
Epoch [38/40] Validation Loss: 0.010225147954727204
------------------------------------------------------------




Epoch [39/40] Classification Report:
              precision    recall  f1-score   support

         0.0       0.98      0.85      0.91     10056
         1.0       0.36      0.84      0.51       991

    accuracy                           0.85     11047
   macro avg       0.67      0.85      0.71     11047
weighted avg       0.93      0.85      0.88     11047

Epoch [39/40] Train Loss: 0.011540443475663286
Epoch [39/40] Validation Loss: 0.009907395855785595
------------------------------------------------------------


Epoch: 100%|██████████| 40/40 [05:55<00:00,  8.88s/it]

Epoch [40/40] Classification Report:
              precision    recall  f1-score   support

         0.0       0.98      0.84      0.91     10056
         1.0       0.35      0.85      0.50       991

    accuracy                           0.85     11047
   macro avg       0.67      0.85      0.70     11047
weighted avg       0.93      0.85      0.87     11047

Epoch [40/40] Train Loss: 0.01143114999688283
Epoch [40/40] Validation Loss: 0.010021769321699705
------------------------------------------------------------
Best Model Epoch: 38
Best Model Validation Loss: 0.009907395855785595
Best Model Classification Report:
              precision    recall  f1-score   support

         0.0       0.98      0.85      0.91     10056
         1.0       0.36      0.84      0.51       991

    accuracy                           0.85     11047
   macro avg       0.67      0.85      0.71     11047
weighted avg       0.93      0.85      0.88     11047




