In [1]:
import torch

device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# import numpy
# import random
# import torch

# numpy.random.seed(69)
# random.seed(69)
# torch.manual_seed(69)

In [3]:
import os.path

BASE_DIR = "/tmp/akshett.jindal"
HUGGINGFACE_CACHE_DIR = os.path.join(BASE_DIR, ".huggingface_cache")

In [4]:
import os.path

DATA_DIR = "/tmp/semeval24_task3"

# TRAIN_DATA_FILEPATH = os.path.join(DATA_DIR, "final_clean_data", "train", "Subtask_2.json")
# VAL_DATA_FILEPATH = os.path.join(DATA_DIR, "final_clean_data", "val", "Subtask_2.json")

TRAIN_DATA_FILEPATH = "/tmp/semeval24_task3/SemEval-2024_Task3/official_data/Training_data/text/training.json"
VAL_DATA_FILEPATH = "/tmp/semeval24_task3/SemEval-2024_Task3/official_data/Training_data/text/testing.json"

In [5]:
import torch
import pickle
import numpy as np

VID_ID_MAPPING = np.load("/home2/suyash.mathur/semeval24/task3/MECPE/data/video_id_mapping.npy", allow_pickle=True).item()

class YourAudioEncoder():
    def __init__(self, audio_embeddings_path):
        self.audio_embeddings = np.load(audio_embeddings_path)
        # with open(audio_embeddings_path, "rb") as f:
            # self.audio_embeddings = pickle.load(f)

    def lmao(self, audio_name):
        audio_name = audio_name.split(".")[0]
        audio_name = VID_ID_MAPPING[audio_name]
        audio_embedding = self.audio_embeddings[audio_name]
        return torch.from_numpy(audio_embedding)
    

In [6]:
import torch
import pickle
import numpy as np

class YourVideoEncoder():
    def __init__(self, video_embeddings_path):
        self.video_embeddings = np.load(video_embeddings_path)
        # with open(video_embeddings_path, "rb") as f:
        #     self.video_embeddings = pickle.load(f)

    def lmao(self, video_name):
        video_name = video_name.split(".")[0]
        video_name = VID_ID_MAPPING[video_name]
        video_embedding = self.video_embeddings[video_name]
        # video_embedding = video_embedding.reshape((16,-1))
        # video_embedding = np.mean(video_embedding, axis=0)
        return torch.from_numpy(video_embedding)

In [7]:
import pickle
from transformers import RobertaTokenizer, RobertaModel
import torch

class YourTextEncoder():
    def __init__(self, text_embeddings_path):
        with open(text_embeddings_path, "rb") as f:
            self.text_embeddings = pickle.load(f)

    def lmao(self, video_name):
        text_embedding = self.text_embeddings[video_name]
        return torch.from_numpy(text_embedding)


In [8]:
import json
import random
from torch.utils.data import Dataset

class EmotionCauseDataset(Dataset):
    def __init__(self, file_path, audio_encoder, video_encoder, text_encoder, neg_to_pos_ratio):
        with open(file_path) as f:
            self.file_data = json.load(f)

        self.audio_encoder = audio_encoder
        self.video_encoder = video_encoder
        self.text_encoder = text_encoder
        self.data = []
        self.POSITIVE_SAMPLE_COUNT = 0
        self.NEGATIVE_SAMPLE_COUNT = 0

        for conversation in self.file_data:
            positive_samples = []
            negative_samples = []

            utterances = {
                utterance["utterance_ID"]: utterance
                for utterance in conversation["conversation"]
            }
            utterance_ids = set(utterances.keys())

            causes = {
                utterance_id: []
                for utterance_id in utterance_ids
            }

            for emo_cause_pairs in conversation["emotion-cause_pairs"]:
                emotion_utterance_num, emotion = emo_cause_pairs[0].split("_")
                emotion_utterance_num = int(emotion_utterance_num)
                cause_utterance_num = int(emo_cause_pairs[1])

                if emotion != "neutral":
                    causes[emotion_utterance_num].append(cause_utterance_num)

            for emo_utterance_id in utterance_ids:
                emo_utterance = utterances[emo_utterance_id]

                if utterances[emo_utterance_id]["emotion"] == "neutral":
                    continue

                for cause_utterance_id in utterance_ids:
                    cause_utterance = utterances[cause_utterance_id]

                    is_cause = cause_utterance_id in causes[emo_utterance_id]

                    data_point = {
                        "original_utterance": {
                            "id": emo_utterance_id,
                            "text": emo_utterance["text"],
                            "video_name": emo_utterance["video_name"],
                        },
                        "cause_utterance": {
                            "id": cause_utterance_id,
                            "text": cause_utterance["text"],
                            "video_name": cause_utterance["video_name"],
                        },
                        "is_cause": is_cause,
                    }

                    if is_cause:
                        positive_samples.append(data_point)
                        self.POSITIVE_SAMPLE_COUNT += 1
                    else:
                        negative_samples.append(data_point)
                        self.NEGATIVE_SAMPLE_COUNT += 1

            random.shuffle(negative_samples)

            self.data.extend(positive_samples)
            if neg_to_pos_ratio is not None:
                self.data.extend(negative_samples[:min(neg_to_pos_ratio * len(positive_samples), len(negative_samples))])
            else:
                self.data.extend(negative_samples)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data = self.data[idx]

        orig_utt = data["original_utterance"]
        cause_utt = data["cause_utterance"]

        orig_id = orig_utt["id"]
        orig_text = orig_utt["text"]
        orig_video = orig_utt["video_name"]
        orig_audio = orig_utt["video_name"].replace(".mp4", ".wav")

        cause_id = cause_utt["id"]
        cause_text = cause_utt["text"]
        cause_video = cause_utt["video_name"]
        cause_audio = cause_utt["video_name"].replace(".mp4", ".wav")

        is_cause = data["is_cause"]

        return {
            "distance": abs(orig_id - cause_id),
            "original_audio": self.audio_encoder.lmao(orig_audio).float(),
            "original_video": self.video_encoder.lmao(orig_video).float(),
            "original_text": self.text_encoder.lmao(orig_video).squeeze().float(),
            "cause_audio": self.audio_encoder.lmao(cause_audio).float(),
            "cause_video": self.video_encoder.lmao(cause_video).float(),
            "cause_text": self.text_encoder.lmao(cause_video).squeeze().float(),
            "is_cause": 1.0 if is_cause else 0.0,
        }

In [9]:
import numpy as np
import os.path
from torch.utils.data import DataLoader

AUDIO_EMBEDDINGS_FILEPATH = "/tmp/semeval24_task3/og_paper_embeddings/audio_embedding_6373.npy"
VIDEO_EMBEDDINGS_FILEPATH = "/tmp/semeval24_task3/og_paper_embeddings/video_embedding_4096.npy"
TEXT_EMBEDDINGS_FILEPATH = os.path.join(DATA_DIR, "text_embeddings", "text_embeddings_bert_base.pkl")

audio_encoder = YourAudioEncoder(AUDIO_EMBEDDINGS_FILEPATH)
video_encoder = YourVideoEncoder(VIDEO_EMBEDDINGS_FILEPATH)
text_encoder = YourTextEncoder(TEXT_EMBEDDINGS_FILEPATH)

trn_dataset = EmotionCauseDataset(
    TRAIN_DATA_FILEPATH,
    audio_encoder,
    video_encoder,
    text_encoder,
    neg_to_pos_ratio=5
)
trn_dataloader = DataLoader(trn_dataset, batch_size=64, shuffle=True)

val_dataset = EmotionCauseDataset(
    VAL_DATA_FILEPATH,
    audio_encoder,
    video_encoder,
    text_encoder,
    neg_to_pos_ratio=None,
)
val_dataloader = DataLoader(val_dataset, batch_size=64, shuffle=True)

len(trn_dataset), len(val_dataset)

(49835, 11047)

In [10]:
import numpy as np

def generate_positional_embeddings(dimension, count):
    embeddings = [list(np.zeros(dimension))]
    embeddings.extend([
        list(np.random.normal(loc=0.0, scale=0.1, size=dimension)) for _ in range(count)
    ])
    return np.array(embeddings)

In [11]:
import torch.nn as nn

class EmotionCauseDetector(nn.Module):
    def __init__(
        self,
        utterance_embedding_size,
        device,
        hidden_dimension=4096,
        positional_embeddings_dimension=200,
        dropout=0.2,
        *args, **kwargs,
    ):
        super().__init__()

        self.hidden_dimension = hidden_dimension

        positional_embeddings = generate_positional_embeddings(positional_embeddings_dimension, 200)
        self.positional_embeddings = torch.from_numpy(positional_embeddings).to(device).float()
        
        self.non_neutral_dropout = nn.Dropout(dropout)
        self.candidate_cause_dropout = nn.Dropout(dropout)
        self.distance_dropout = nn.Dropout(dropout)

        self.linear1 = nn.Linear(utterance_embedding_size*2 + positional_embeddings_dimension, hidden_dimension)
        self.linear1_activation = nn.ReLU()
        self.linear2 = nn.Linear(hidden_dimension, 1)

    def forward(self, non_neutral_utterances, candidate_cause_utterances, distances):
        positional_embedding = self.positional_embeddings[distances]
        
        non_neutral_utterances = self.non_neutral_dropout(non_neutral_utterances)
        candidate_cause_utterances = self.candidate_cause_dropout(candidate_cause_utterances)
        positional_embedding = self.distance_dropout(positional_embedding)

        embeddings = torch.concat((non_neutral_utterances, candidate_cause_utterances, positional_embedding), axis=1)

        return self.linear2(
            self.linear1_activation(
                self.linear1(embeddings)
            )
        )

In [12]:
import torch
import numpy

def save_model(epoch_num):
    torch.save(model.state_dict(), f"/tmp/semeval24_task3/final_models/pairing_models/paring_model_{epoch_num:02}.pt")
    numpy.save(
        f"/tmp/semeval24_task3/final_models/pairing_models/pairing_model_pos_embeds_{epoch_num:02}.npy",
        model.positional_embeddings.cpu().numpy(),
    )

In [14]:
from tqdm.auto import tqdm
from sklearn.metrics import classification_report
from transformers import get_linear_schedule_with_warmup
NUM_EPOCHS = 20
AUDIO_EMBEDDING_SIZE = 6373
VIDEO_EMBEDDING_SIZE = 4096
TEXT_EMBEDDING_SIZE = 768
TOTAL_EMBEDDING_SIZE = (
    AUDIO_EMBEDDING_SIZE + VIDEO_EMBEDDING_SIZE + TEXT_EMBEDDING_SIZE
)
# TOTAL_EMBEDDING_SIZE = 11237

model = EmotionCauseDetector(
    TOTAL_EMBEDDING_SIZE,
    device,
    hidden_dimension=2000,
)
_ = model.to(device)

weight_ratio = trn_dataset.NEGATIVE_SAMPLE_COUNT / trn_dataset.POSITIVE_SAMPLE_COUNT

criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(weight_ratio).to(device))
optimizer = torch.optim.AdamW(model.parameters(), lr=0.00001)

total_steps = len(trn_dataloader) * NUM_EPOCHS

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps
)

best_model_file = None
best_val_loss = float('inf')
best_epoch = -1
best_classification_report = None


epoch_iter = tqdm(range(NUM_EPOCHS), desc="Epoch", position=0)
for epoch in epoch_iter:
    model.train()
    total_loss = 0.0

    for batch in tqdm(trn_dataloader, desc="Train Data Batch", position=1, leave=False):
        distances = batch["distance"].to(device)

        orig_audios = batch["original_audio"].to(device)
        orig_videos = batch["original_video"].to(device)
        orig_texts = batch["original_text"].to(device)

        cause_audios = batch["cause_audio"].to(device)
        cause_videos = batch["cause_video"].to(device)
        cause_texts = batch["cause_text"].to(device)

        is_cause = batch["is_cause"].to(device)

        orig_embedding = torch.cat((orig_audios, orig_videos, orig_texts), axis=1).float()
        cause_embedding = torch.cat((cause_audios, cause_videos, cause_texts), axis=1).float()

        output_logits = model(orig_embedding, cause_embedding, distances).squeeze()

        loss = criterion(output_logits, is_cause)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    
    scheduler.step()

    model.eval()
    total_val_loss = 0.0
    total_val_correct = 0
    total_val_predictions = 0

    true_labels = []
    predicted_labels = []

    with torch.no_grad():
        for val_batch in tqdm(val_dataloader, desc="Val Data Batch", position=1, leave=False):
            distances = val_batch["distance"].to(device)

            orig_audios = val_batch["original_audio"].to(device)
            orig_videos = val_batch["original_video"].to(device)
            orig_texts = val_batch["original_text"].to(device)

            cause_audios = val_batch["cause_audio"].to(device)
            cause_videos = val_batch["cause_video"].to(device)
            cause_texts = val_batch["cause_text"].to(device)

            is_cause = val_batch["is_cause"].to(device)

            orig_embedding = torch.cat((orig_audios, orig_videos, orig_texts), axis=1).float()
            cause_embedding = torch.cat((cause_audios, cause_videos, cause_texts), axis=1).float()

            output_logits = model(orig_embedding, cause_embedding, distances).squeeze()

            val_loss = criterion(output_logits, is_cause)

            total_val_loss += val_loss.item()

            predicted_is_cause = (output_logits >= 0.5).float()

            correct_predictions_val = (is_cause == predicted_is_cause).sum().item()

            total_val_correct += correct_predictions_val
            total_val_predictions += (predicted_is_cause == 1.0).sum().item()

            true_labels.extend(is_cause.cpu().numpy())
            predicted_labels.extend(predicted_is_cause.cpu().numpy())

    report = classification_report(true_labels, predicted_labels)

    avg_loss = total_loss / len(trn_dataset)
    avg_val_loss = total_val_loss / len(val_dataset)
    
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        best_epoch = epoch
        best_classification_report = report
        # best_model_file = save_model(epoch)
        torch.save(model.state_dict(), f"/tmp/semeval24_task3/final_models/pairing_models/paring_model_best_model.pt")
        numpy.save(
            f"/tmp/semeval24_task3/final_models/pairing_models/pairing_model_pos_embeds_best_model.npy",
            model.positional_embeddings.cpu().numpy(),
        )

    print(f"Epoch [{epoch+1:02}/{NUM_EPOCHS}] Classification Report:\n{report}")
    print(f"Epoch [{epoch+1:02}/{NUM_EPOCHS}] Train Loss: {avg_loss}")
    print(f"Epoch [{epoch+1:02}/{NUM_EPOCHS}] Validation Loss: {avg_val_loss}")
    print("------------------------------------------------------------")

    save_model(epoch)

print("===================BEST MODEL===================")
print(f"Best Model Epoch: {best_epoch}")
print(f"Best Model Validation Loss: {best_val_loss}")
print(f"Best Model Classification Report:\n{best_classification_report}")
print("================================================")

# Current best, Epoch 18

Epoch:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch [01/20] Classification Report:
              precision    recall  f1-score   support

         0.0       0.91      0.98      0.94     10056
         1.0       0.13      0.03      0.05       991

    accuracy                           0.89     11047
   macro avg       0.52      0.51      0.50     11047
weighted avg       0.84      0.89      0.86     11047

Epoch [01/20] Train Loss: 16042088.849810366
Epoch [01/20] Validation Loss: 8971154.376946125
------------------------------------------------------------


Epoch:   5%|▌         | 1/20 [00:16<05:14, 16.57s/it]

Epoch [02/20] Classification Report:
              precision    recall  f1-score   support

         0.0       0.96      0.55      0.70     10056
         1.0       0.14      0.75      0.24       991

    accuracy                           0.57     11047
   macro avg       0.55      0.65      0.47     11047
weighted avg       0.88      0.57      0.66     11047

Epoch [02/20] Train Loss: 11341637.559529731
Epoch [02/20] Validation Loss: 3343965.1583822863
------------------------------------------------------------


Epoch:  10%|█         | 2/20 [00:34<05:08, 17.12s/it]

Epoch [03/20] Classification Report:
              precision    recall  f1-score   support

         0.0       0.93      0.02      0.03     10056
         1.0       0.09      0.99      0.17       991

    accuracy                           0.10     11047
   macro avg       0.51      0.50      0.10     11047
weighted avg       0.86      0.10      0.05     11047

Epoch [03/20] Train Loss: 8925704.810745174
Epoch [03/20] Validation Loss: 9843022.17737845
------------------------------------------------------------


Epoch:  15%|█▌        | 3/20 [00:51<04:50, 17.10s/it]

Epoch [04/20] Classification Report:
              precision    recall  f1-score   support

         0.0       0.90      0.80      0.85     10056
         1.0       0.07      0.14      0.09       991

    accuracy                           0.74     11047
   macro avg       0.48      0.47      0.47     11047
weighted avg       0.83      0.74      0.78     11047

Epoch [04/20] Train Loss: 8581250.568777325
Epoch [04/20] Validation Loss: 5959823.025720282
------------------------------------------------------------


Epoch:  20%|██        | 4/20 [01:07<04:30, 16.89s/it]

Epoch [05/20] Classification Report:
              precision    recall  f1-score   support

         0.0       0.95      0.16      0.27     10056
         1.0       0.10      0.91      0.17       991

    accuracy                           0.23     11047
   macro avg       0.52      0.53      0.22     11047
weighted avg       0.87      0.23      0.26     11047

Epoch [05/20] Train Loss: 8021735.268583822
Epoch [05/20] Validation Loss: 10445583.392639909
------------------------------------------------------------


Epoch:  25%|██▌       | 5/20 [01:24<04:13, 16.87s/it]

Epoch [06/20] Classification Report:
              precision    recall  f1-score   support

         0.0       0.95      0.83      0.89     10056
         1.0       0.25      0.59      0.35       991

    accuracy                           0.81     11047
   macro avg       0.60      0.71      0.62     11047
weighted avg       0.89      0.81      0.84     11047

Epoch [06/20] Train Loss: 9899819.614297159
Epoch [06/20] Validation Loss: 5080159.909369062
------------------------------------------------------------


Epoch:  30%|███       | 6/20 [01:41<03:55, 16.82s/it]

Epoch [07/20] Classification Report:
              precision    recall  f1-score   support

         0.0       0.96      0.42      0.58     10056
         1.0       0.12      0.83      0.21       991

    accuracy                           0.45     11047
   macro avg       0.54      0.62      0.40     11047
weighted avg       0.89      0.45      0.55     11047

Epoch [07/20] Train Loss: 8407077.200749332
Epoch [07/20] Validation Loss: 3268423.5122498414
------------------------------------------------------------


Epoch:  35%|███▌      | 7/20 [01:58<03:41, 17.02s/it]

Epoch [08/20] Classification Report:
              precision    recall  f1-score   support

         0.0       0.97      0.15      0.26     10056
         1.0       0.10      0.95      0.18       991

    accuracy                           0.22     11047
   macro avg       0.53      0.55      0.22     11047
weighted avg       0.89      0.22      0.25     11047

Epoch [08/20] Train Loss: 12463775.026729602
Epoch [08/20] Validation Loss: 5108105.152766351
------------------------------------------------------------


Epoch:  40%|████      | 8/20 [02:15<03:24, 17.02s/it]

Epoch [09/20] Classification Report:
              precision    recall  f1-score   support

         0.0       0.96      0.10      0.18     10056
         1.0       0.09      0.96      0.17       991

    accuracy                           0.17     11047
   macro avg       0.53      0.53      0.17     11047
weighted avg       0.88      0.17      0.18     11047

Epoch [09/20] Train Loss: 8420524.69103026
Epoch [09/20] Validation Loss: 5412287.8599892305
------------------------------------------------------------


Epoch:  45%|████▌     | 9/20 [02:32<03:05, 16.88s/it]

Epoch [10/20] Classification Report:
              precision    recall  f1-score   support

         0.0       0.95      0.92      0.94     10056
         1.0       0.40      0.53      0.46       991

    accuracy                           0.89     11047
   macro avg       0.68      0.73      0.70     11047
weighted avg       0.90      0.89      0.89     11047

Epoch [10/20] Train Loss: 9697328.21714467
Epoch [10/20] Validation Loss: 4393678.182388423
------------------------------------------------------------


Epoch:  50%|█████     | 10/20 [02:49<02:48, 16.88s/it]

Epoch [11/20] Classification Report:
              precision    recall  f1-score   support

         0.0       0.94      0.56      0.70     10056
         1.0       0.13      0.66      0.22       991

    accuracy                           0.57     11047
   macro avg       0.54      0.61      0.46     11047
weighted avg       0.87      0.57      0.66     11047

Epoch [11/20] Train Loss: 10418204.687576372
Epoch [11/20] Validation Loss: 6730420.380647115
------------------------------------------------------------


Epoch:  55%|█████▌    | 11/20 [03:05<02:31, 16.86s/it]

Epoch [12/20] Classification Report:
              precision    recall  f1-score   support

         0.0       0.95      0.61      0.74     10056
         1.0       0.14      0.65      0.23       991

    accuracy                           0.62     11047
   macro avg       0.54      0.63      0.49     11047
weighted avg       0.87      0.62      0.70     11047

Epoch [12/20] Train Loss: 10370160.058729129
Epoch [12/20] Validation Loss: 3371914.908136915
------------------------------------------------------------


Epoch:  60%|██████    | 12/20 [03:22<02:13, 16.69s/it]

Epoch [13/20] Classification Report:
              precision    recall  f1-score   support

         0.0       0.97      0.08      0.15     10056
         1.0       0.09      0.98      0.17       991

    accuracy                           0.16     11047
   macro avg       0.53      0.53      0.16     11047
weighted avg       0.89      0.16      0.15     11047

Epoch [13/20] Train Loss: 8579385.486864695
Epoch [13/20] Validation Loss: 8933342.988853788
------------------------------------------------------------


Epoch:  65%|██████▌   | 13/20 [03:38<01:56, 16.63s/it]

Epoch [14/20] Classification Report:
              precision    recall  f1-score   support

         0.0       0.95      0.06      0.12     10056
         1.0       0.09      0.96      0.17       991

    accuracy                           0.14     11047
   macro avg       0.52      0.51      0.14     11047
weighted avg       0.87      0.14      0.12     11047

Epoch [14/20] Train Loss: 10527625.077013155
Epoch [14/20] Validation Loss: 8349732.074300371
------------------------------------------------------------


Epoch:  70%|███████   | 14/20 [03:55<01:39, 16.56s/it]

Epoch [15/20] Classification Report:
              precision    recall  f1-score   support

         0.0       0.96      0.62      0.75     10056
         1.0       0.16      0.71      0.25       991

    accuracy                           0.63     11047
   macro avg       0.56      0.66      0.50     11047
weighted avg       0.88      0.63      0.71     11047

Epoch [15/20] Train Loss: 8977132.908064973
Epoch [15/20] Validation Loss: 5951977.333735759
------------------------------------------------------------


Epoch:  75%|███████▌  | 15/20 [04:12<01:23, 16.72s/it]

Epoch [16/20] Classification Report:
              precision    recall  f1-score   support

         0.0       0.90      0.73      0.80     10056
         1.0       0.05      0.13      0.07       991

    accuracy                           0.68     11047
   macro avg       0.47      0.43      0.44     11047
weighted avg       0.82      0.68      0.74     11047

Epoch [16/20] Train Loss: 7383186.311136257
Epoch [16/20] Validation Loss: 5044503.875582063
------------------------------------------------------------


Epoch:  80%|████████  | 16/20 [04:29<01:06, 16.73s/it]

Epoch [17/20] Classification Report:
              precision    recall  f1-score   support

         0.0       0.94      0.96      0.95     10056
         1.0       0.49      0.40      0.44       991

    accuracy                           0.91     11047
   macro avg       0.71      0.68      0.70     11047
weighted avg       0.90      0.91      0.90     11047

Epoch [17/20] Train Loss: 8313866.437490529
Epoch [17/20] Validation Loss: 7907599.801409071
------------------------------------------------------------


Epoch:  85%|████████▌ | 17/20 [04:45<00:50, 16.79s/it]

Epoch [18/20] Classification Report:
              precision    recall  f1-score   support

         0.0       0.95      0.48      0.63     10056
         1.0       0.12      0.74      0.21       991

    accuracy                           0.50     11047
   macro avg       0.54      0.61      0.42     11047
weighted avg       0.87      0.50      0.60     11047

Epoch [18/20] Train Loss: 8428199.515980573
Epoch [18/20] Validation Loss: 42952887.30262809
------------------------------------------------------------


Epoch:  90%|█████████ | 18/20 [05:02<00:33, 16.83s/it]

Epoch [19/20] Classification Report:
              precision    recall  f1-score   support

         0.0       0.98      0.02      0.03     10056
         1.0       0.09      1.00      0.17       991

    accuracy                           0.10     11047
   macro avg       0.54      0.51      0.10     11047
weighted avg       0.90      0.10      0.04     11047

Epoch [19/20] Train Loss: 9416122.739075698
Epoch [19/20] Validation Loss: 11224924.780358236
------------------------------------------------------------


Epoch:  95%|█████████▌| 19/20 [05:19<00:16, 16.81s/it]

Epoch [20/20] Classification Report:
              precision    recall  f1-score   support

         0.0       0.96      0.29      0.44     10056
         1.0       0.11      0.88      0.19       991

    accuracy                           0.34     11047
   macro avg       0.53      0.58      0.32     11047
weighted avg       0.88      0.34      0.42     11047

Epoch [20/20] Train Loss: 8738941.533780562
Epoch [20/20] Validation Loss: 4306338.11160406
------------------------------------------------------------


Epoch: 100%|██████████| 20/20 [05:36<00:00, 16.81s/it]

Best Model Epoch: 6
Best Model Validation Loss: 3268423.5122498414
Best Model Classification Report:
              precision    recall  f1-score   support

         0.0       0.96      0.42      0.58     10056
         1.0       0.12      0.83      0.21       991

    accuracy                           0.45     11047
   macro avg       0.54      0.62      0.40     11047
weighted avg       0.89      0.45      0.55     11047




