In [1]:
from encoder_paths import *
import json
from pprint import pprint
TRAIN_FILE_PATH = "/tmp/semeval24_task3/SemEval-2024_Task3/official_data/Training_data/text/training.json"
VALIDATION_FILE_PATH = "/tmp/semeval24_task3/SemEval-2024_Task3/official_data/Training_data/text/testing.json"
with open(TRAIN_FILE_PATH) as f:
    train_data = json.load(f)
with open(VALIDATION_FILE_PATH) as f:
    validation_data = json.load(f)

pprint(len(train_data))
pprint(len(validation_data))

1236
138


In [2]:
import numpy
import random
import torch

numpy.random.seed(69)
random.seed(69)
torch.manual_seed(69)

  from .autonotebook import tqdm as notebook_tqdm


<torch._C.Generator at 0x7f3190b86190>

In [3]:
all_emotions = set()
for conversation in train_data:
    conversation = conversation["conversation"]
    for utterance in conversation:
        all_emotions.add(utterance["emotion"])
pprint(all_emotions)

{'fear', 'anger', 'neutral', 'disgust', 'joy', 'surprise', 'sadness'}


In [4]:
class EmotionIndexer:
    def __init__(self):
        self.emotion_to_index = {
            'joy': 0,
            'sadness': 1,
            'anger': 2,
            'neutral': 3,
            'surprise': 4,
            'disgust': 5,
            'fear': 6,
        }

        self.index_to_emotion = {index: emotion for emotion, index in self.emotion_to_index.items()}

    def emotion_to_idx(self, emotion):
        return self.emotion_to_index.get(emotion, None)

    def idx_to_emotion(self, index):
        return self.index_to_emotion.get(index, None)

# Example usage
indexer = EmotionIndexer()


In [5]:
import torch
import json
import os
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_video
from torchvision.transforms import functional as F
from PIL import Image
import numpy as np

In [6]:
import torch
import pickle

class YourAudioEncoder():
    def __init__(self, audio_embeddings_path):
        with open(audio_embeddings_path, "rb") as f:
            self.audio_embeddings = pickle.load(f)

    def lmao(self, audio_name):
        audio_name = audio_name.split(".")[0]
        audio_embedding = self.audio_embeddings[audio_name]
        audio_embedding = audio_embedding.squeeze()
        return torch.from_numpy(audio_embedding)
    
class YourVideoEncoder():
    def __init__(self, video_embeddings_path):
        with open(video_embeddings_path, "rb") as f:
            self.video_embeddings = pickle.load(f)

    def lmao(self, video_name):
        # video_name = video_name.split(".")[0]
        video_embedding = self.video_embeddings[video_name].reshape((16,-1))
        video_embedding = np.mean(video_embedding, axis=0)
        return torch.from_numpy(video_embedding)

class YourTextEncoder():
    def __init__(self, text_embeddings_path):
        with open(text_embeddings_path, "rb") as f:
            self.text_embeddings = pickle.load(f)

    def lmao(self, video_name):
        text_embedding = self.text_embeddings[video_name]
        return torch.from_numpy(text_embedding)


In [7]:
class ConversationDataset(Dataset):
    def __init__(self, json_file, audio_encoder, video_encoder, text_encoder, max_seq_len):
        self.max_seq_len = max_seq_len
        self.data = self.load_data(json_file)
        self.audio_encoder = audio_encoder
        self.video_encoder = video_encoder
        self.text_encoder = text_encoder

    def load_data(self, json_file):
        with open(json_file, 'r') as f:
            data = json.load(f)
        return data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        conversation = self.data[idx]['conversation']
        # emotion_labels = [utterance['emotion'] for utterance in conversation]
        audio_paths = [utterance['video_name'].replace('mp4', 'wav') for utterance in conversation]
        video_paths = [utterance['video_name'] for utterance in conversation]
        texts = [utterance['video_name'] for utterance in conversation]

        audio_embeddings = [self.audio_encoder.lmao(audio_path) for audio_path in audio_paths]
        video_embeddings = [self.video_encoder.lmao(video_path) for video_path in video_paths]
        text_embeddings = [self.text_encoder.lmao(text) for text in texts]

        cause_pairs = self.data[idx]['emotion-cause_pairs']
        useful_utterances = set([int(cause_pair[1]) for cause_pair in cause_pairs])
        cause_labels = []
        for utterance in conversation:
            if utterance['utterance_ID'] in useful_utterances:
                cause_labels.append(1)
            else:
                cause_labels.append(0)
        
        # Pad or truncate conversations to the maximum sequence length
        if len(conversation) < self.max_seq_len:
            pad_length = self.max_seq_len - len(conversation)
            audio_embeddings += [torch.zeros_like(audio_embeddings[0])] * pad_length
            video_embeddings += [torch.zeros_like(video_embeddings[0])] * pad_length
            text_embeddings += [torch.zeros_like(text_embeddings[0])] * pad_length
            cause_labels += [-1] * pad_length
            pad_mask = [1] * len(conversation) + [0] * pad_length
        else:
            audio_embeddings = audio_embeddings[:self.max_seq_len]
            video_embeddings = video_embeddings[:self.max_seq_len]
            text_embeddings = text_embeddings[:self.max_seq_len]
            cause_labels = cause_labels[:self.max_seq_len]
            pad_mask = [1] * self.max_seq_len

        audio_embeddings = torch.stack(audio_embeddings)
        video_embeddings = torch.stack(video_embeddings)
        text_embeddings = torch.stack(text_embeddings)
        cause_labels = torch.from_numpy(np.array(cause_labels))
        pad_mask = torch.from_numpy(np.array(pad_mask))
        
        return {
            'audio': audio_embeddings,
            'video': video_embeddings,
            'text': text_embeddings,
            'cause_labels': cause_labels,
            'pad_mask': pad_mask
        }

# Example usage
# You need to define your audio, video, and text encoders accordingly

# Define your data paths
# DATA_DIR = "/tmp/semeval24_task3"

# AUDIO_EMBEDDINGS_FILEPATH = "/tmp/semeval24_task3/og_paper_embeddings/audio_embedding_6373.npy"
# VIDEO_EMBEDDINGS_FILEPATH = "/tmp/semeval24_task3/og_paper_embeddings/video_embedding_4096.npy"
# TEXT_EMBEDDINGS_FILEPATH = os.path.join(DATA_DIR, "text_embeddings", "text_embeddings_bert_base.pkl")

audio_encoder = YourAudioEncoder(AUDIO_EMBEDDINGS_FILEPATH)
video_encoder = YourVideoEncoder(VIDEO_EMBEDDINGS_FILEPATH)
text_encoder = YourTextEncoder(TEXT_EMBEDDINGS_FILEPATH)
max_seq_len = 35  # Adjust this according to your needs

train_dataset = ConversationDataset(TRAIN_FILE_PATH, audio_encoder, video_encoder, text_encoder, max_seq_len)
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)

validation_dataset = ConversationDataset(VALIDATION_FILE_PATH, audio_encoder, video_encoder, text_encoder, max_seq_len)
validation_dataloader = DataLoader(validation_dataset, batch_size=16, shuffle=True)
# Example of iterating through batches
# for batch in dataloader:
#     audio = batch['audio']  # Shape: (batch_size, max_seq_len, audio_embedding_size)
#     video = batch['video']  # Shape: (batch_size, max_seq_len, video_embedding_size)
#     text = batch['text']    # Shape: (batch_size, max_seq_len, text_embedding_size)
#     cause_labels = batch['cause_labels']  # List of emotion labels for each utterance in the batch


In [8]:
# import torch
# import torch.nn as nn
# from torch.nn import TransformerEncoder, TransformerEncoderLayer

# class EmotionClassifier(nn.Module):
#     def __init__(self, input_size, hidden_size, num_layers, num_heads, dropout, num_emotions):
#         super(EmotionClassifier, self).__init__()
        
#         self.first_linear = nn.Linear(input_size, hidden_size, dtype=torch.float32)

#         self.transformer_encoder = TransformerEncoder(
#             TransformerEncoderLayer(hidden_size, num_heads, hidden_size, dropout),
#             num_layers
#         )
        
#         self.linear = nn.Linear(hidden_size, num_emotions)

#     def forward(self, audio_encoding, video_encoding, text_encoding):

#         # Concatenate or combine the audio, video, and text encodings here
#         # You can use any method like concatenation, addition, or other fusion techniques
#         # Combine the encodings (you can customize this part)
#         audio_encoding = audio_encoding.float()
#         video_encoding = video_encoding.float()
#         text_encoding = text_encoding.float().squeeze()
#         combined_encoding = torch.cat((audio_encoding, video_encoding, text_encoding), dim=2)
        
#         combined_encoding = self.first_linear(combined_encoding)
        
        
#         combined_encoding = combined_encoding.permute(1, 0, 2)  # Transformer expects (seq_len, batch_size, input_size)
        
        
#         transformer_output = self.transformer_encoder(combined_encoding)

#         # Take the output of the Transformer encoder for the last position as the summary
#         emotion_logits = self.linear(transformer_output.permute(1, 0, 2))

#         return emotion_logits

In [9]:
import torch
import torch.nn as nn

class EmotionClassifier(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, dropout, num_emotions, embedding_dropout=0.2):
        super(EmotionClassifier, self).__init__()
        
        self.audio_dropout = nn.Dropout(embedding_dropout)
        self.video_dropout = nn.Dropout(embedding_dropout)
        self.text_dropout = nn.Dropout(embedding_dropout)

        # self.first_linear = nn.Linear(input_size, hidden_size, dtype=torch.float32)
        self.relu = nn.ReLU()
        
        # self.second_linear_layer = nn.Linear(hidden_size, hidden_size, dtype=torch.float32)
        # Replace Transformer with BiLSTM
        self.bilstm = nn.LSTM(input_size, input_size // 2, num_layers, 
                              dropout=dropout, bidirectional=True, batch_first=True)
        
        self.linear = nn.Linear(input_size, hidden_size)
        self.final_linear = nn.Linear(hidden_size, num_emotions)

    def forward(self, audio_encoding, video_encoding, text_encoding):
        # Concatenate or combine the audio, video, and text encodings
        audio_encoding = audio_encoding.float()
        video_encoding = video_encoding.float()
        text_encoding = text_encoding.float().squeeze()
        
        audio_encoding = self.audio_dropout(audio_encoding)
        video_encoding = self.video_dropout(video_encoding)
        text_encoding = self.text_dropout(text_encoding)
        
        combined_encoding = torch.cat((audio_encoding, video_encoding, text_encoding), dim=2)
        
        # Pass through BiLSTM
        lstm_output, _ = self.bilstm(combined_encoding)

        # Take the output of the BiLSTM
        emotion_logits = self.linear(lstm_output)
        emotion_logits = self.relu(emotion_logits)
        emotion_logits = self.final_linear(emotion_logits)
        # Apply a softmax layer
        emotion_logits = torch.softmax(emotion_logits, dim=2)

        return emotion_logits

In [10]:
# import torch
# import torch.nn as nn

# class EmotionClassifier(nn.Module):
#     def __init__(self, input_size, hidden_size, num_layers, dropout, num_emotions, embedding_dropout=0.2):
#         super(EmotionClassifier, self).__init__()
        
#         self.audio_dropout = nn.Dropout(embedding_dropout)
#         self.video_dropout = nn.Dropout(embedding_dropout)
#         self.text_dropout = nn.Dropout(embedding_dropout)

#         self.first_linear = nn.Linear(input_size, hidden_size, dtype=torch.float32)
#         self.relu = nn.ReLU()
        
#         self.second_linear_layer = nn.Linear(hidden_size, hidden_size, dtype=torch.float32)
#         # Replace Transformer with BiLSTM
#         self.bilstm = nn.LSTM(hidden_size, hidden_size // 2, num_layers, 
#                               dropout=dropout, bidirectional=True, batch_first=True)
        
#         self.linear = nn.Linear(hidden_size, num_emotions)

#     def forward(self, audio_encoding, video_encoding, text_encoding):
#         # Concatenate or combine the audio, video, and text encodings
#         audio_encoding = audio_encoding.float()
#         video_encoding = video_encoding.float()
#         text_encoding = text_encoding.float().squeeze()
        
#         audio_encoding = self.audio_dropout(audio_encoding)
#         video_encoding = self.video_dropout(video_encoding)
#         text_encoding = self.text_dropout(text_encoding)
        
#         combined_encoding = torch.cat((audio_encoding, video_encoding, text_encoding), dim=2)
        
#         combined_encoding = self.first_linear(combined_encoding)
#         combined_encoding = self.relu(combined_encoding)
#         combined_encoding = self.second_linear_layer(combined_encoding)
        
#         # Pass through BiLSTM
#         lstm_output, _ = self.bilstm(combined_encoding)

#         # Take the output of the BiLSTM
#         emotion_logits = self.linear(lstm_output)
#         # Apply a softmax layer
#         emotion_logits = torch.softmax(emotion_logits, dim=2)

#         return emotion_logits

In [11]:
from torch.optim import AdamW
from tqdm import tqdm
from sklearn.metrics import classification_report
from transformers import get_linear_schedule_with_warmup

# Define your model
model = EmotionClassifier(input_size=768*3, hidden_size=2000, num_emotions=2, num_layers=3, dropout=0.3)
model.to("cuda:1")

num_epochs = 40
# Define your loss function and optimizer
criterion = nn.CrossEntropyLoss(ignore_index=-1)
optimizer = AdamW(model.parameters(), lr=0.5*1e-5)
total_steps = len(train_dataloader) * num_epochs

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps
)

# Define training parameters

# Training loop
best_model_file = None
best_val_loss = float('inf')
best_epoch = -1
best_classification_report = None

print(AUDIO_EMBEDDINGS_FILEPATH)
print(TEXT_EMBEDDINGS_FILEPATH)
print(VIDEO_EMBEDDINGS_FILEPATH)
for epoch in (range(num_epochs)):
    model.train()  # Set the model to training mode
    total_loss = 0.0
    total_tokens = 0
    total_correct = 0
    total_predictions = 0

    for batch in tqdm(train_dataloader):  # Assuming you have a DataLoader for your dataset
        # Extract data from the batch
        audio = batch['audio'].to('cuda:1')
        video = batch['video'].to('cuda:1')
        text = batch['text'].to('cuda:1')
        emotion_indices = batch['cause_labels'].to('cuda:1')
        pad_mask = batch['pad_mask'].to('cuda:1')

        # Forward pass
        emotion_logits = model(audio, video, text)

        # Reshape emotion_logits
        emotion_logits = emotion_logits.view(-1, emotion_logits.size(-1))

        # Flatten emotion_indices (assuming it's a 2D tensor with shape [batch_size, max_sequence_length])
        emotion_indices = emotion_indices.view(-1)

        # Calculate a mask to exclude padded positions from the loss
        pad_mask = pad_mask.view(-1)     

        # Calculate the loss, excluding padded positions
        loss = criterion(emotion_logits, emotion_indices)
        # masked_loss = torch.sum(loss * pad_mask) / torch.sum(pad_mask)
        masked_loss = loss

        # Backpropagation and optimization
        optimizer.zero_grad()
        masked_loss.backward()
        optimizer.step()

        total_loss += masked_loss.item()
        total_tokens += torch.sum(pad_mask).item()
        
        predicted_emotions = torch.argmax(emotion_logits, dim=1)
        correct_predictions = ((predicted_emotions == emotion_indices) * pad_mask).sum().item()

        total_correct += correct_predictions
        total_predictions += torch.sum(pad_mask).item()  # Batch size

    scheduler.step()
    
    model.eval()  # Set the model to evaluation mode
    total_val_loss = 0.0
    total_val_tokens = 0
    total_val_correct = 0
    total_val_predictions = 0
    true_labels = []
    predicted_labels = []
    padded_labels = []
    
    with torch.no_grad():
        for val_batch in tqdm(validation_dataloader):
            audio = val_batch['audio'].to('cuda:1')
            video = val_batch['video'].to('cuda:1')
            text = val_batch['text'].to('cuda:1')
            emotion_indices = val_batch['cause_labels'].to('cuda:1')
            pad_mask = val_batch['pad_mask'].to('cuda:1')

            emotion_logits = model(audio, video, text)

            # Reshape emotion_logits
            emotion_logits = emotion_logits.view(-1, emotion_logits.size(-1))

            # Flatten emotion_indices (assuming it's a 2D tensor with shape [batch_size, max_sequence_length])
            emotion_indices = emotion_indices.view(-1)

            pad_mask = pad_mask.view(-1)   

            # Calculate the loss, excluding padded positions
            val_loss = criterion(emotion_logits, emotion_indices)
            masked_loss = torch.sum(val_loss * pad_mask) / torch.sum(pad_mask)
            
            total_val_loss += masked_loss.item()
            total_val_tokens += torch.sum(pad_mask).item()
            
            predicted_emotions_val = torch.argmax(emotion_logits, dim=1)
            correct_predictions_val = ((predicted_emotions_val == emotion_indices) * pad_mask).sum().item()
            total_val_correct += correct_predictions_val
            total_val_predictions += torch.sum(pad_mask).item()

            # Store true and predicted labels for F1 score calculation
            true_labels.extend(emotion_indices.cpu().numpy())
            predicted_labels.extend(predicted_emotions_val.cpu().numpy())
            padded_labels.extend(pad_mask.cpu().numpy())

    final_true_labels = [label for label, pad in zip(true_labels, padded_labels) if pad == 1]
    final_predicted_labels = [label for label, pad in zip(predicted_labels, padded_labels) if pad == 1]
    classification_rep = classification_report(final_true_labels, final_predicted_labels)

    # Calculate and print the average loss for this epoch
    avg_loss = total_loss / total_tokens
    avg_val_loss = total_val_loss / total_val_tokens
    print(f"Epoch [{epoch}/{num_epochs}] Training Loss: {avg_loss}")
    print(f"Epoch [{epoch}/{num_epochs}] Validation Loss: {avg_val_loss}")
    print(f"Epoch [{epoch}/{num_epochs}] Classification Report:\n{classification_rep}")
    print(f"Epoch [{epoch}/{num_epochs}] Accuracy: {total_correct / total_predictions:.4f}")
    
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        best_epoch = epoch
        best_classification_report = classification_rep
        best_model_file = f"/tmp/semeval24_task3/baseline_models/cause_models/best_cause_model.pt"
        torch.save(model.state_dict(), best_model_file)

    torch.save(model.state_dict(), f"/tmp/semeval24_task3/baseline_models/cause_models/cause_model_{epoch:02}.pt")

print("Training complete!")
print("=======================================")
print("BEST MODEL")
print(f"Best epoch: {best_epoch}")
print(f"Best validation loss: {best_val_loss}")
print(f"Best classification report:\n{best_classification_report}")
print("=======================================")

/tmp/semeval24_task3/audio_embeddings/audio_embeddings_microsoft_wavlm-base-plus-sd.pkl
/tmp/semeval24_task3/text_embeddings/text_embeddings_roberta_base_emotion.pkl
/tmp/semeval24_task3/video_embeddings/final_embeddings.pkl


100%|██████████| 78/78 [00:06<00:00, 11.95it/s]
100%|██████████| 9/9 [00:00<00:00, 27.91it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch [0/40] Training Loss: 0.004426745894076799
Epoch [0/40] Validation Loss: 0.004305580216977332
Epoch [0/40] Classification Report:
              precision    recall  f1-score   support

           0       0.53      1.00      0.69       759
           1       0.00      0.00      0.00       681

    accuracy                           0.53      1440
   macro avg       0.26      0.50      0.35      1440
weighted avg       0.28      0.53      0.36      1440

Epoch [0/40] Accuracy: 0.5211


100%|██████████| 78/78 [00:06<00:00, 12.01it/s]
100%|██████████| 9/9 [00:00<00:00, 26.68it/s]


Epoch [1/40] Training Loss: 0.004396225355051061
Epoch [1/40] Validation Loss: 0.0042608482970131764
Epoch [1/40] Classification Report:
              precision    recall  f1-score   support

           0       0.58      0.79      0.67       759
           1       0.61      0.36      0.45       681

    accuracy                           0.59      1440
   macro avg       0.59      0.58      0.56      1440
weighted avg       0.59      0.59      0.57      1440

Epoch [1/40] Accuracy: 0.5473


100%|██████████| 78/78 [00:06<00:00, 11.58it/s]
100%|██████████| 9/9 [00:00<00:00, 25.92it/s]


Epoch [2/40] Training Loss: 0.00431829717850507
Epoch [2/40] Validation Loss: 0.004196642711758613
Epoch [2/40] Classification Report:
              precision    recall  f1-score   support

           0       0.59      0.72      0.65       759
           1       0.59      0.45      0.51       681

    accuracy                           0.59      1440
   macro avg       0.59      0.59      0.58      1440
weighted avg       0.59      0.59      0.59      1440

Epoch [2/40] Accuracy: 0.5918


100%|██████████| 78/78 [00:06<00:00, 11.99it/s]
100%|██████████| 9/9 [00:00<00:00, 27.13it/s]


Epoch [3/40] Training Loss: 0.004265415207467963
Epoch [3/40] Validation Loss: 0.004198545010553466
Epoch [3/40] Classification Report:
              precision    recall  f1-score   support

           0       0.63      0.49      0.55       759
           1       0.54      0.68      0.60       681

    accuracy                           0.58      1440
   macro avg       0.59      0.58      0.58      1440
weighted avg       0.59      0.58      0.58      1440

Epoch [3/40] Accuracy: 0.5989


100%|██████████| 78/78 [00:06<00:00, 12.19it/s]
100%|██████████| 9/9 [00:00<00:00, 28.90it/s]


Epoch [4/40] Training Loss: 0.004227837168217998
Epoch [4/40] Validation Loss: 0.00416704879866706
Epoch [4/40] Classification Report:
              precision    recall  f1-score   support

           0       0.62      0.57      0.59       759
           1       0.56      0.61      0.59       681

    accuracy                           0.59      1440
   macro avg       0.59      0.59      0.59      1440
weighted avg       0.59      0.59      0.59      1440

Epoch [4/40] Accuracy: 0.6096


100%|██████████| 78/78 [00:06<00:00, 11.77it/s]
100%|██████████| 9/9 [00:00<00:00, 25.59it/s]


Epoch [5/40] Training Loss: 0.004200210466752137
Epoch [5/40] Validation Loss: 0.004144492497046789
Epoch [5/40] Classification Report:
              precision    recall  f1-score   support

           0       0.61      0.68      0.64       759
           1       0.59      0.51      0.55       681

    accuracy                           0.60      1440
   macro avg       0.60      0.59      0.59      1440
weighted avg       0.60      0.60      0.60      1440

Epoch [5/40] Accuracy: 0.6169


100%|██████████| 78/78 [00:06<00:00, 11.75it/s]
100%|██████████| 9/9 [00:00<00:00, 27.54it/s]


Epoch [6/40] Training Loss: 0.004179369963804033
Epoch [6/40] Validation Loss: 0.004121405300166872
Epoch [6/40] Classification Report:
              precision    recall  f1-score   support

           0       0.61      0.70      0.65       759
           1       0.60      0.49      0.54       681

    accuracy                           0.60      1440
   macro avg       0.60      0.60      0.60      1440
weighted avg       0.60      0.60      0.60      1440

Epoch [6/40] Accuracy: 0.6235


100%|██████████| 78/78 [00:06<00:00, 11.82it/s]
100%|██████████| 9/9 [00:00<00:00, 26.16it/s]


Epoch [7/40] Training Loss: 0.004154616133784394
Epoch [7/40] Validation Loss: 0.004085895129375987
Epoch [7/40] Classification Report:
              precision    recall  f1-score   support

           0       0.64      0.62      0.63       759
           1       0.59      0.62      0.61       681

    accuracy                           0.62      1440
   macro avg       0.62      0.62      0.62      1440
weighted avg       0.62      0.62      0.62      1440

Epoch [7/40] Accuracy: 0.6285


100%|██████████| 78/78 [00:06<00:00, 11.61it/s]
100%|██████████| 9/9 [00:00<00:00, 25.58it/s]


Epoch [8/40] Training Loss: 0.004115956572928798
Epoch [8/40] Validation Loss: 0.00408528277443515
Epoch [8/40] Classification Report:
              precision    recall  f1-score   support

           0       0.62      0.70      0.66       759
           1       0.61      0.51      0.55       681

    accuracy                           0.61      1440
   macro avg       0.61      0.61      0.61      1440
weighted avg       0.61      0.61      0.61      1440

Epoch [8/40] Accuracy: 0.6362


100%|██████████| 78/78 [00:06<00:00, 11.77it/s]
100%|██████████| 9/9 [00:00<00:00, 27.27it/s]


Epoch [9/40] Training Loss: 0.004088742934245078
Epoch [9/40] Validation Loss: 0.004057255842619472
Epoch [9/40] Classification Report:
              precision    recall  f1-score   support

           0       0.66      0.62      0.64       759
           1       0.60      0.64      0.62       681

    accuracy                           0.63      1440
   macro avg       0.63      0.63      0.63      1440
weighted avg       0.63      0.63      0.63      1440

Epoch [9/40] Accuracy: 0.6460


100%|██████████| 78/78 [00:06<00:00, 11.81it/s]
100%|██████████| 9/9 [00:00<00:00, 27.26it/s]


Epoch [10/40] Training Loss: 0.0040467780685080265
Epoch [10/40] Validation Loss: 0.00403335653245449
Epoch [10/40] Classification Report:
              precision    recall  f1-score   support

           0       0.67      0.58      0.62       759
           1       0.59      0.68      0.63       681

    accuracy                           0.62      1440
   macro avg       0.63      0.63      0.62      1440
weighted avg       0.63      0.62      0.62      1440

Epoch [10/40] Accuracy: 0.6492


100%|██████████| 78/78 [00:06<00:00, 12.04it/s]
100%|██████████| 9/9 [00:00<00:00, 27.51it/s]


Epoch [11/40] Training Loss: 0.004007200975901231
Epoch [11/40] Validation Loss: 0.004000326784120665
Epoch [11/40] Classification Report:
              precision    recall  f1-score   support

           0       0.66      0.65      0.65       759
           1       0.62      0.63      0.63       681

    accuracy                           0.64      1440
   macro avg       0.64      0.64      0.64      1440
weighted avg       0.64      0.64      0.64      1440

Epoch [11/40] Accuracy: 0.6545


100%|██████████| 78/78 [00:06<00:00, 12.03it/s]
100%|██████████| 9/9 [00:00<00:00, 28.00it/s]


Epoch [12/40] Training Loss: 0.003982924952061856
Epoch [12/40] Validation Loss: 0.003956558886501524
Epoch [12/40] Classification Report:
              precision    recall  f1-score   support

           0       0.66      0.69      0.67       759
           1       0.63      0.60      0.62       681

    accuracy                           0.65      1440
   macro avg       0.65      0.64      0.64      1440
weighted avg       0.65      0.65      0.65      1440

Epoch [12/40] Accuracy: 0.6620


100%|██████████| 78/78 [00:06<00:00, 11.94it/s]
100%|██████████| 9/9 [00:00<00:00, 26.99it/s]


Epoch [13/40] Training Loss: 0.003932579073258755
Epoch [13/40] Validation Loss: 0.003946219839983516
Epoch [13/40] Classification Report:
              precision    recall  f1-score   support

           0       0.68      0.67      0.67       759
           1       0.63      0.65      0.64       681

    accuracy                           0.66      1440
   macro avg       0.66      0.66      0.66      1440
weighted avg       0.66      0.66      0.66      1440

Epoch [13/40] Accuracy: 0.6729


100%|██████████| 78/78 [00:06<00:00, 11.60it/s]
100%|██████████| 9/9 [00:00<00:00, 26.60it/s]


Epoch [14/40] Training Loss: 0.003916620273427214
Epoch [14/40] Validation Loss: 0.003932883217930794
Epoch [14/40] Classification Report:
              precision    recall  f1-score   support

           0       0.70      0.61      0.65       759
           1       0.62      0.70      0.66       681

    accuracy                           0.66      1440
   macro avg       0.66      0.66      0.66      1440
weighted avg       0.66      0.66      0.66      1440

Epoch [14/40] Accuracy: 0.6726


100%|██████████| 78/78 [00:06<00:00, 11.96it/s]
100%|██████████| 9/9 [00:00<00:00, 27.99it/s]


Epoch [15/40] Training Loss: 0.0038630178687470485
Epoch [15/40] Validation Loss: 0.003978284572561582
Epoch [15/40] Classification Report:
              precision    recall  f1-score   support

           0       0.71      0.55      0.62       759
           1       0.60      0.75      0.67       681

    accuracy                           0.64      1440
   macro avg       0.65      0.65      0.64      1440
weighted avg       0.66      0.64      0.64      1440

Epoch [15/40] Accuracy: 0.6877


100%|██████████| 78/78 [00:06<00:00, 11.79it/s]
100%|██████████| 9/9 [00:00<00:00, 27.01it/s]


Epoch [16/40] Training Loss: 0.0038510722982461588
Epoch [16/40] Validation Loss: 0.003989286803536945
Epoch [16/40] Classification Report:
              precision    recall  f1-score   support

           0       0.73      0.50      0.59       759
           1       0.59      0.80      0.68       681

    accuracy                           0.64      1440
   macro avg       0.66      0.65      0.63      1440
weighted avg       0.66      0.64      0.63      1440

Epoch [16/40] Accuracy: 0.6902


100%|██████████| 78/78 [00:06<00:00, 11.87it/s]
100%|██████████| 9/9 [00:00<00:00, 27.26it/s]


Epoch [17/40] Training Loss: 0.003813076781780225
Epoch [17/40] Validation Loss: 0.0038844406604766845
Epoch [17/40] Classification Report:
              precision    recall  f1-score   support

           0       0.69      0.68      0.68       759
           1       0.64      0.65      0.65       681

    accuracy                           0.67      1440
   macro avg       0.67      0.67      0.67      1440
weighted avg       0.67      0.67      0.67      1440

Epoch [17/40] Accuracy: 0.6996


100%|██████████| 78/78 [00:06<00:00, 11.80it/s]
100%|██████████| 9/9 [00:00<00:00, 26.76it/s]


Epoch [18/40] Training Loss: 0.003786401155889793
Epoch [18/40] Validation Loss: 0.003889555732409159
Epoch [18/40] Classification Report:
              precision    recall  f1-score   support

           0       0.66      0.72      0.69       759
           1       0.66      0.59      0.62       681

    accuracy                           0.66      1440
   macro avg       0.66      0.66      0.66      1440
weighted avg       0.66      0.66      0.66      1440

Epoch [18/40] Accuracy: 0.7046


100%|██████████| 78/78 [00:06<00:00, 11.78it/s]
100%|██████████| 9/9 [00:00<00:00, 26.44it/s]


Epoch [19/40] Training Loss: 0.0037568745251721435
Epoch [19/40] Validation Loss: 0.003858761323822869
Epoch [19/40] Classification Report:
              precision    recall  f1-score   support

           0       0.70      0.63      0.66       759
           1       0.63      0.70      0.66       681

    accuracy                           0.66      1440
   macro avg       0.66      0.66      0.66      1440
weighted avg       0.67      0.66      0.66      1440

Epoch [19/40] Accuracy: 0.7125


100%|██████████| 78/78 [00:06<00:00, 11.61it/s]
100%|██████████| 9/9 [00:00<00:00, 25.30it/s]


Epoch [20/40] Training Loss: 0.003721628726257931
Epoch [20/40] Validation Loss: 0.003875546157360077
Epoch [20/40] Classification Report:
              precision    recall  f1-score   support

           0       0.66      0.71      0.69       759
           1       0.65      0.60      0.62       681

    accuracy                           0.66      1440
   macro avg       0.66      0.66      0.66      1440
weighted avg       0.66      0.66      0.66      1440

Epoch [20/40] Accuracy: 0.7134


100%|██████████| 78/78 [00:06<00:00, 11.70it/s]
100%|██████████| 9/9 [00:00<00:00, 26.74it/s]


Epoch [21/40] Training Loss: 0.003713920410263371
Epoch [21/40] Validation Loss: 0.003867802189456092
Epoch [21/40] Classification Report:
              precision    recall  f1-score   support

           0       0.72      0.57      0.64       759
           1       0.61      0.75      0.68       681

    accuracy                           0.66      1440
   macro avg       0.67      0.66      0.66      1440
weighted avg       0.67      0.66      0.66      1440

Epoch [21/40] Accuracy: 0.7163


100%|██████████| 78/78 [00:06<00:00, 11.80it/s]
100%|██████████| 9/9 [00:00<00:00, 27.07it/s]


Epoch [22/40] Training Loss: 0.0036710917641351315
Epoch [22/40] Validation Loss: 0.0038430941187673146
Epoch [22/40] Classification Report:
              precision    recall  f1-score   support

           0       0.68      0.68      0.68       759
           1       0.64      0.65      0.65       681

    accuracy                           0.66      1440
   macro avg       0.66      0.66      0.66      1440
weighted avg       0.66      0.66      0.66      1440

Epoch [22/40] Accuracy: 0.7222


100%|██████████| 78/78 [00:06<00:00, 11.83it/s]
100%|██████████| 9/9 [00:00<00:00, 26.85it/s]


Epoch [23/40] Training Loss: 0.0036402509951299804
Epoch [23/40] Validation Loss: 0.003844314192732175
Epoch [23/40] Classification Report:
              precision    recall  f1-score   support

           0       0.68      0.71      0.70       759
           1       0.66      0.63      0.65       681

    accuracy                           0.67      1440
   macro avg       0.67      0.67      0.67      1440
weighted avg       0.67      0.67      0.67      1440

Epoch [23/40] Accuracy: 0.7313


100%|██████████| 78/78 [00:06<00:00, 11.79it/s]
100%|██████████| 9/9 [00:00<00:00, 27.08it/s]


Epoch [24/40] Training Loss: 0.0036460778030799872
Epoch [24/40] Validation Loss: 0.0038611435227923923
Epoch [24/40] Classification Report:
              precision    recall  f1-score   support

           0       0.72      0.60      0.66       759
           1       0.63      0.74      0.68       681

    accuracy                           0.67      1440
   macro avg       0.67      0.67      0.67      1440
weighted avg       0.67      0.67      0.67      1440

Epoch [24/40] Accuracy: 0.7293


100%|██████████| 78/78 [00:06<00:00, 11.92it/s]
100%|██████████| 9/9 [00:00<00:00, 27.75it/s]


Epoch [25/40] Training Loss: 0.0036239902005523322
Epoch [25/40] Validation Loss: 0.003840582693616549
Epoch [25/40] Classification Report:
              precision    recall  f1-score   support

           0       0.69      0.65      0.67       759
           1       0.64      0.68      0.66       681

    accuracy                           0.66      1440
   macro avg       0.67      0.67      0.66      1440
weighted avg       0.67      0.66      0.66      1440

Epoch [25/40] Accuracy: 0.7345


100%|██████████| 78/78 [00:06<00:00, 11.65it/s]
100%|██████████| 9/9 [00:00<00:00, 25.69it/s]


Epoch [26/40] Training Loss: 0.003604227152253638
Epoch [26/40] Validation Loss: 0.0038413347055514654
Epoch [26/40] Classification Report:
              precision    recall  f1-score   support

           0       0.71      0.63      0.66       759
           1       0.63      0.71      0.67       681

    accuracy                           0.67      1440
   macro avg       0.67      0.67      0.67      1440
weighted avg       0.67      0.67      0.67      1440

Epoch [26/40] Accuracy: 0.7393


100%|██████████| 78/78 [00:06<00:00, 11.88it/s]
100%|██████████| 9/9 [00:00<00:00, 27.39it/s]


Epoch [27/40] Training Loss: 0.0035760182985502646
Epoch [27/40] Validation Loss: 0.003845470564232932
Epoch [27/40] Classification Report:
              precision    recall  f1-score   support

           0       0.72      0.61      0.66       759
           1       0.63      0.73      0.68       681

    accuracy                           0.67      1440
   macro avg       0.68      0.67      0.67      1440
weighted avg       0.68      0.67      0.67      1440

Epoch [27/40] Accuracy: 0.7459


100%|██████████| 78/78 [00:06<00:00, 11.82it/s]
100%|██████████| 9/9 [00:00<00:00, 26.91it/s]


Epoch [28/40] Training Loss: 0.003560076578956038
Epoch [28/40] Validation Loss: 0.003819638200932079
Epoch [28/40] Classification Report:
              precision    recall  f1-score   support

           0       0.69      0.67      0.68       759
           1       0.64      0.67      0.66       681

    accuracy                           0.67      1440
   macro avg       0.67      0.67      0.67      1440
weighted avg       0.67      0.67      0.67      1440

Epoch [28/40] Accuracy: 0.7444


100%|██████████| 78/78 [00:06<00:00, 11.80it/s]
100%|██████████| 9/9 [00:00<00:00, 26.97it/s]


Epoch [29/40] Training Loss: 0.0035544291183918957
Epoch [29/40] Validation Loss: 0.003867849831779798
Epoch [29/40] Classification Report:
              precision    recall  f1-score   support

           0       0.72      0.59      0.65       759
           1       0.62      0.75      0.68       681

    accuracy                           0.66      1440
   macro avg       0.67      0.67      0.66      1440
weighted avg       0.67      0.66      0.66      1440

Epoch [29/40] Accuracy: 0.7437


100%|██████████| 78/78 [00:06<00:00, 11.76it/s]
100%|██████████| 9/9 [00:00<00:00, 26.77it/s]


Epoch [30/40] Training Loss: 0.0035094984089254462
Epoch [30/40] Validation Loss: 0.0038811109132236904
Epoch [30/40] Classification Report:
              precision    recall  f1-score   support

           0       0.74      0.58      0.65       759
           1       0.63      0.78      0.69       681

    accuracy                           0.67      1440
   macro avg       0.68      0.68      0.67      1440
weighted avg       0.69      0.67      0.67      1440

Epoch [30/40] Accuracy: 0.7555


100%|██████████| 78/78 [00:06<00:00, 11.71it/s]
100%|██████████| 9/9 [00:00<00:00, 26.55it/s]


Epoch [31/40] Training Loss: 0.0035137739960847203
Epoch [31/40] Validation Loss: 0.0038964255816406673
Epoch [31/40] Classification Report:
              precision    recall  f1-score   support

           0       0.75      0.56      0.64       759
           1       0.62      0.79      0.69       681

    accuracy                           0.67      1440
   macro avg       0.68      0.68      0.67      1440
weighted avg       0.69      0.67      0.67      1440

Epoch [31/40] Accuracy: 0.7547


100%|██████████| 78/78 [00:06<00:00, 11.70it/s]
100%|██████████| 9/9 [00:00<00:00, 26.23it/s]


Epoch [32/40] Training Loss: 0.003495307313328023
Epoch [32/40] Validation Loss: 0.003903319231337971
Epoch [32/40] Classification Report:
              precision    recall  f1-score   support

           0       0.64      0.75      0.70       759
           1       0.66      0.54      0.59       681

    accuracy                           0.65      1440
   macro avg       0.65      0.65      0.64      1440
weighted avg       0.65      0.65      0.65      1440

Epoch [32/40] Accuracy: 0.7578


100%|██████████| 78/78 [00:06<00:00, 11.77it/s]
100%|██████████| 9/9 [00:00<00:00, 27.35it/s]


Epoch [33/40] Training Loss: 0.00346962545137118
Epoch [33/40] Validation Loss: 0.003838244370288319
Epoch [33/40] Classification Report:
              precision    recall  f1-score   support

           0       0.71      0.66      0.68       759
           1       0.65      0.69      0.67       681

    accuracy                           0.68      1440
   macro avg       0.68      0.68      0.68      1440
weighted avg       0.68      0.68      0.68      1440

Epoch [33/40] Accuracy: 0.7622


100%|██████████| 78/78 [00:06<00:00, 11.75it/s]
100%|██████████| 9/9 [00:00<00:00, 26.61it/s]


Epoch [34/40] Training Loss: 0.0034727444201622195
Epoch [34/40] Validation Loss: 0.0038300053526957828
Epoch [34/40] Classification Report:
              precision    recall  f1-score   support

           0       0.68      0.70      0.69       759
           1       0.66      0.64      0.65       681

    accuracy                           0.67      1440
   macro avg       0.67      0.67      0.67      1440
weighted avg       0.67      0.67      0.67      1440

Epoch [34/40] Accuracy: 0.7599


100%|██████████| 78/78 [00:06<00:00, 11.95it/s]
100%|██████████| 9/9 [00:00<00:00, 27.53it/s]


Epoch [35/40] Training Loss: 0.0034514303982223584
Epoch [35/40] Validation Loss: 0.0038351755175325608
Epoch [35/40] Classification Report:
              precision    recall  f1-score   support

           0       0.71      0.65      0.68       759
           1       0.64      0.70      0.67       681

    accuracy                           0.67      1440
   macro avg       0.67      0.68      0.67      1440
weighted avg       0.68      0.67      0.67      1440

Epoch [35/40] Accuracy: 0.7653


100%|██████████| 78/78 [00:06<00:00, 11.74it/s]
100%|██████████| 9/9 [00:00<00:00, 26.66it/s]


Epoch [36/40] Training Loss: 0.00341665865175144
Epoch [36/40] Validation Loss: 0.0038485397895177205
Epoch [36/40] Classification Report:
              precision    recall  f1-score   support

           0       0.71      0.65      0.68       759
           1       0.64      0.70      0.67       681

    accuracy                           0.68      1440
   macro avg       0.68      0.68      0.68      1440
weighted avg       0.68      0.68      0.68      1440

Epoch [36/40] Accuracy: 0.7694


100%|██████████| 78/78 [00:06<00:00, 11.88it/s]
100%|██████████| 9/9 [00:00<00:00, 27.66it/s]


Epoch [37/40] Training Loss: 0.0033954001234536768
Epoch [37/40] Validation Loss: 0.0038499570141235988
Epoch [37/40] Classification Report:
              precision    recall  f1-score   support

           0       0.72      0.64      0.68       759
           1       0.64      0.72      0.68       681

    accuracy                           0.68      1440
   macro avg       0.68      0.68      0.68      1440
weighted avg       0.68      0.68      0.68      1440

Epoch [37/40] Accuracy: 0.7777


100%|██████████| 78/78 [00:06<00:00, 11.92it/s]
100%|██████████| 9/9 [00:00<00:00, 27.21it/s]


Epoch [38/40] Training Loss: 0.003404490015308829
Epoch [38/40] Validation Loss: 0.0038540390216641957
Epoch [38/40] Classification Report:
              precision    recall  f1-score   support

           0       0.68      0.71      0.70       759
           1       0.66      0.64      0.65       681

    accuracy                           0.67      1440
   macro avg       0.67      0.67      0.67      1440
weighted avg       0.67      0.67      0.67      1440

Epoch [38/40] Accuracy: 0.7719


100%|██████████| 78/78 [00:06<00:00, 11.54it/s]
100%|██████████| 9/9 [00:00<00:00, 25.34it/s]


Epoch [39/40] Training Loss: 0.003408442734595306
Epoch [39/40] Validation Loss: 0.003964760485622618
Epoch [39/40] Classification Report:
              precision    recall  f1-score   support

           0       0.63      0.80      0.70       759
           1       0.68      0.47      0.56       681

    accuracy                           0.64      1440
   macro avg       0.65      0.64      0.63      1440
weighted avg       0.65      0.64      0.63      1440

Epoch [39/40] Accuracy: 0.7707
Training complete!
BEST MODEL
Best epoch: 28
Best validation loss: 0.003819638200932079
Best classification report:
              precision    recall  f1-score   support

           0       0.69      0.67      0.68       759
           1       0.64      0.67      0.66       681

    accuracy                           0.67      1440
   macro avg       0.67      0.67      0.67      1440
weighted avg       0.67      0.67      0.67      1440

