In [1]:
import json
from pprint import pprint
TRAIN_FILE_PATH = "/home2/suyash.mathur/final_clean_data/train/Subtask_2.json"
VALIDATION_FILE_PATH = "/home2/suyash.mathur/final_clean_data/val/Subtask_2.json"
with open(TRAIN_FILE_PATH) as f:
    train_data = json.load(f)
with open(VALIDATION_FILE_PATH) as f:
    validation_data = json.load(f)

pprint(len(train_data))
pprint(len(validation_data))

1236
138


In [2]:
class EmotionIndexer:
    def __init__(self):
        self.emotion_to_index = {
            'joy': 0,
            'sadness': 1,
            'anger': 2,
            'neutral': 3,
            'surprise': 4,
            'disgust': 5,
            'fear': 6,
            'pad': 7,
        }
        self.emotion_freq = [0]*7
        self.weights = None

        self.index_to_emotion = {index: emotion for emotion, index in self.emotion_to_index.items()}

    def emotion_to_idx(self, emotion):
        return self.emotion_to_index.get(emotion, None)

    def idx_to_emotion(self, index):
        return self.index_to_emotion.get(index, None)
    
    def compute_weights(self, data):
        for conversation in data:
            conversation = conversation['conversation']
            for utterance in conversation:
                emotion = utterance['emotion']
                self.emotion_freq[self.emotion_to_index[emotion]] += 1
        print(self.emotion_freq)
        self.weights = [1/freq for freq in self.emotion_freq]

# Example usage
indexer = EmotionIndexer()
indexer.compute_weights(train_data)
print(indexer.weights)

[2075, 1008, 1489, 5258, 1670, 376, 340]
[0.00048192771084337347, 0.000992063492063492, 0.000671591672263264, 0.0001901863826550019, 0.0005988023952095808, 0.0026595744680851063, 0.0029411764705882353]


In [3]:
import torch
import json
import os
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_video
from torchvision.transforms import functional as F
from PIL import Image
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
import torch
import pickle

class YourAudioEncoder():
    def __init__(self, audio_embeddings_path):
        with open(audio_embeddings_path, "rb") as f:
            self.audio_embeddings = pickle.load(f)

    def lmao(self, audio_name):
        audio_name = audio_name.split(".")[0]
        audio_embedding = self.audio_embeddings[audio_name]
        return torch.from_numpy(audio_embedding)
    
class YourVideoEncoder():
    def __init__(self, video_embeddings_path):
        with open(video_embeddings_path, "rb") as f:
            self.video_embeddings = pickle.load(f)

    def lmao(self, video_name):
        # video_name = video_name.split(".")[0]
        video_embedding = self.video_embeddings[video_name].reshape((16,-1))
        video_embedding = np.mean(video_embedding, axis=0)
        return torch.from_numpy(video_embedding)

class YourTextEncoder():
    def __init__(self, text_embeddings_path):
        with open(text_embeddings_path, "rb") as f:
            self.text_embeddings = pickle.load(f)

    def lmao(self, video_name):
        text_embedding = self.text_embeddings[video_name]
        return torch.from_numpy(text_embedding)


In [5]:
class ConversationDataset(Dataset):
    def __init__(self, json_file, audio_encoder, video_encoder, text_encoder, max_seq_len):
        self.max_seq_len = max_seq_len
        self.data = self.load_data(json_file)
        self.audio_encoder = audio_encoder
        self.video_encoder = video_encoder
        self.text_encoder = text_encoder

    def load_data(self, json_file):
        with open(json_file, 'r') as f:
            data = json.load(f)
        return data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        conversation = self.data[idx]['conversation']
        emotion_labels = [utterance['emotion'] for utterance in conversation]
        audio_paths = [utterance['video_name'].replace('mp4', 'wav') for utterance in conversation]
        video_paths = [utterance['video_name'] for utterance in conversation]
        texts = [utterance['video_name'] for utterance in conversation]

        audio_embeddings = [self.audio_encoder.lmao(audio_path) for audio_path in audio_paths]
        video_embeddings = [self.video_encoder.lmao(video_path) for video_path in video_paths]
        text_embeddings = [self.text_encoder.lmao(text) for text in texts]

        cause_pairs = self.data[idx]['emotion-cause_pairs']
        useful_utterances = set([int(cause_pair[1]) for cause_pair in cause_pairs])
        cause_labels = []
        for utterance in conversation:
            if utterance['utterance_ID'] in useful_utterances:
                cause_labels.append(1)
            else:
                cause_labels.append(0)
        
        
        # Pad or truncate conversations to the maximum sequence length
        if len(conversation) < self.max_seq_len and False:
            pad_length = self.max_seq_len - len(conversation)
            audio_embeddings += [torch.zeros_like(audio_embeddings[0])] * pad_length
            video_embeddings += [torch.zeros_like(video_embeddings[0])] * pad_length
            text_embeddings += [torch.zeros_like(text_embeddings[0])] * pad_length
            cause_labels += [-1] * pad_length
            emotion_labels += ['pad'] * pad_length
            pad_mask = [1] * len(conversation) + [0] * pad_length
        else:
            audio_embeddings = audio_embeddings[:self.max_seq_len]
            video_embeddings = video_embeddings[:self.max_seq_len]
            text_embeddings = text_embeddings[:self.max_seq_len]
            emotion_labels = emotion_labels[:self.max_seq_len]
            cause_labels = cause_labels[:self.max_seq_len]
            pad_mask = [1] * self.max_seq_len

        emotion_indices = [indexer.emotion_to_idx(emotion) for emotion in emotion_labels]
        
        audio_embeddings = torch.stack(audio_embeddings)
        video_embeddings = torch.stack(video_embeddings)
        text_embeddings = torch.stack(text_embeddings)
        emotion_indices = torch.from_numpy(np.array(emotion_indices))
        pad_mask = torch.from_numpy(np.array(pad_mask))
        cause_labels = torch.from_numpy(np.array(cause_labels))
        
        return {
            'audio': audio_embeddings,
            'video': video_embeddings,
            'text': text_embeddings,
            'emotion_labels': emotion_indices,
            'pad_mask': pad_mask,
            'cause_labels': cause_labels,
        }
# Example usage
# You need to define your audio, video, and text encoders accordingly

# Define your data paths
DATA_DIR = "/tmp/semeval24_task3"

AUDIO_EMBEDDINGS_FILEPATH = os.path.join(DATA_DIR, "audio_embeddings", "audio_embeddings.pkl")
VIDEO_EMBEDDINGS_FILEPATH = os.path.join(DATA_DIR, "video_embeddings", "train", "video_embeddings.pkl")
TEXT_EMBEDDINGS_FILEPATH = os.path.join(DATA_DIR, "text_embeddings", "text_embeddings.pkl")

audio_encoder = YourAudioEncoder(AUDIO_EMBEDDINGS_FILEPATH)
video_encoder = YourVideoEncoder(VIDEO_EMBEDDINGS_FILEPATH)
text_encoder = YourTextEncoder(TEXT_EMBEDDINGS_FILEPATH)
max_seq_len = 40  # Adjust this according to your needs

# Create the dataset and dataloader
train_dataset = ConversationDataset(TRAIN_FILE_PATH, audio_encoder, video_encoder, text_encoder, max_seq_len)
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)

validation_dataset = ConversationDataset(VALIDATION_FILE_PATH, audio_encoder, video_encoder, text_encoder, max_seq_len)
validation_dataloader = DataLoader(validation_dataset, batch_size=1, shuffle=True)

# Example of iterating through batches
# for batch in dataloader:
#     audio = batch['audio']  # Shape: (batch_size, max_seq_len, audio_embedding_size)
#     video = batch['video']  # Shape: (batch_size, max_seq_len, video_embedding_size)
#     text = batch['text']    # Shape: (batch_size, max_seq_len, text_embedding_size)
#     emotions = batch['emotion_labels']  # List of emotion labels for each utterance in the batch


In [6]:
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class BiLSTM_basic(nn.Module):

    def __init__(self, embedding_dim=768, hidden_dim=300, output_size=13):
        super(BiLSTM_basic, self).__init__()
        
        # 1. Embedding Layer
        # if embeddings is None:
        #     self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        # else:
        # self.embeddings = nn.Embedding.from_pretrained(embeddings)
        
        # 2. LSTM Layer
        #embedding dimension must be equal to bert embeddings
        #use of 'batch_first=true'?
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, bidirectional=True, num_layers=1, batch_first=False)
        
        # 3. Optional dropout layer
        self.dropout_layer = nn.Dropout(p=0.3)

        # 4. Dense Layer ?? 
        self.hidden2tag = nn.Linear(2*hidden_dim, output_size)

        self.relu=nn.ReLU

        self.hidden_dim = hidden_dim

    
    def generate_emissions(self, batch_text):
        hidden_layer = self.init_hidden(len(batch_text))

        embeddings = enco(batch_text)

        # x_packed = pack_padded_sequence(embeddings, batch_first = True)
        
        # packed_seqs = pack_padded_sequence(embeddings, batch_length)
        print(embeddings.shape)
        lstm_output, _ = self.lstm(embeddings, hidden_layer)
        print(lstm_output.shape)
        # lstm_output, _ = pad_packed_sequence(lstm_output)

        # self.relu(lstm_output)
        # lstm_output, op_lengths = pad_packed_sequence(lstm_output, batch_first = True)

        lstm_output = self.dropout_layer(lstm_output)
        print(lstm_output.shape)

        emissions = self.hidden2tag(lstm_output)
        # emissions = torch.squeeze(emissions)
        # emissions = emissions.unsqueeze(0)

        return emissions
        
    def loss(self, batch_text, batch_label):
        # print(len(batch_text))

        # hidden_layer = self.init_hidden(len(batch_text))

        # embeddings = enco(batch_text)

        # # x_packed = pack_padded_sequence(embeddings, batch_first = True)
        
        # # packed_seqs = pack_padded_sequence(embeddings, batch_length)
        # lstm_output, _ = self.lstm(embeddings, hidden_layer)
        # print(lstm_output.shape)
        # # lstm_output, _ = pad_packed_sequence(lstm_output)

        # # self.relu(lstm_output)
        # # lstm_output, op_lengths = pad_packed_sequence(lstm_output, batch_first = True)

        # lstm_output = self.dropout_layer(lstm_output)
        # print(lstm_output.shape)

        emissions = self.generate_emissions(batch_text)
        batch_label = batch_label.unsqueeze(1)
        # print(logits.shape)
        loss = -self.crf_model(emissions, batch_label)

        return loss
    
    def predict(self, batch_text):
        emissions = self.generate_emissions(batch_text)
        # print(logits.shape)
        label = self.crf_model.viterbi_decode(emissions)
        return label
    
    def init_hidden(self, batch_size):
        return (torch.randn(2, 1, self.hidden_dim).to(device), torch.randn(2, 1, self.hidden_dim).to(device))

In [7]:
import torch
import torch.nn as nn
from TorchCRF import CRF

class EmotionClassifier(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, dropout, num_emotions, embedding_dropout=0.2):
        super(EmotionClassifier, self).__init__()
        
        self.audio_dropout = nn.Dropout(embedding_dropout)
        self.video_dropout = nn.Dropout(embedding_dropout)
        self.text_dropout = nn.Dropout(embedding_dropout)

        self.first_linear = nn.Linear(input_size, hidden_size, dtype=torch.float32)
        self.relu = nn.ReLU()
        
        self.second_linear_layer = nn.Linear(hidden_size, hidden_size, dtype=torch.float32)
        # Replace Transformer with BiLSTM
        self.bilstm = nn.LSTM(hidden_size, hidden_size // 2, num_layers, 
                              dropout=dropout, bidirectional=True, batch_first=True)
        
        self.linear = nn.Linear(hidden_size, num_emotions)
        self.crf_model = CRF(num_emotions)
        

    def generate_emissions(self, audio_encoding, video_encoding, text_encoding):
        # Concatenate or combine the audio, video, and text encodings
        audio_encoding = audio_encoding.float()
        video_encoding = video_encoding.float()
        text_encoding = text_encoding.float()
        
        audio_encoding = self.audio_dropout(audio_encoding)
        video_encoding = self.video_dropout(video_encoding)
        text_encoding = self.text_dropout(text_encoding)
        
        combined_encoding = torch.cat((audio_encoding, video_encoding, text_encoding), dim=2)
        
        combined_encoding = self.first_linear(combined_encoding)
        combined_encoding = self.relu(combined_encoding)
        combined_encoding = self.second_linear_layer(combined_encoding)
        
        # Pass through BiLSTM
        lstm_output, _ = self.bilstm(combined_encoding)

        # Take the output of the BiLSTM
        emotion_logits = self.linear(lstm_output)
        # Apply a softmax layer
        # emotion_logits = torch.softmax(emotion_logits, dim=2)

        return emotion_logits

    def loss(self, audio_encoding, video_encoding, text_encoding, emotion_labels, padding):

        emissions = self.generate_emissions(audio_encoding, video_encoding, text_encoding)
        emotion_labels = emotion_labels.unsqueeze(1)
        x, y, _ = emissions.shape
        padding = torch.ones((x, y), dtype=torch.bool).to('cuda')
        emotion_labels = emotion_labels.squeeze(1)
        loss = -self.crf_model(emissions, emotion_labels, padding)

        return loss
    
    def predict(self, audio_encoding, video_encoding, text_encoding):
        emissions = self.generate_emissions(audio_encoding, video_encoding, text_encoding)
        x, y, _ = emissions.shape
        padding = torch.ones((x, y), dtype=torch.bool).to('cuda')
        label = self.crf_model.viterbi_decode(emissions, padding)
        return label
    
    def init_hidden(self):
        return (torch.randn(2, 1, self.hidden_dim).to('cuda'), torch.randn(2, 1, self.hidden_dim).to('cuda'))

In [8]:
from torch.optim import AdamW
from tqdm import tqdm
from sklearn.metrics import classification_report
from transformers import get_linear_schedule_with_warmup

# Define your model
# model = EmotionClassifier(input_size=11237, hidden_size=5000, num_layers=2, num_heads=2, dropout=0.2, num_emotions=7)
emotion_model = EmotionClassifier(input_size=768*3, hidden_size=2000, num_layers=3, dropout=0.6, num_emotions=7)
cause_model = EmotionClassifier(input_size=768*3, hidden_size=2000, num_layers=2, dropout=0.6, num_emotions=2)
emotion_model.to("cuda")
cause_model.to("cuda")

weights_tensor = torch.tensor(np.array(indexer.weights)).to("cuda").float()
emotion_criterion = nn.CrossEntropyLoss(
    weight=weights_tensor,
    ignore_index=7
)

cause_criterion = nn.CrossEntropyLoss(ignore_index=-1)

num_epochs = 30
total_steps = len(train_dataloader) * num_epochs

emotion_optimizer = AdamW(emotion_model.parameters(), lr=0.0001)
emotion_lr_scheduler = get_linear_schedule_with_warmup(
    emotion_optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps
)

cause_optimizer = AdamW(cause_model.parameters(), lr=1e-4)
cause_lr_scheduler = get_linear_schedule_with_warmup(
    cause_optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps
)


# Define training parameters

# Training loop
for epoch in (range(num_epochs)):
    emotion_model.train()  # Set the model to training mode
    cause_model.train()
    
    total_loss = 0.0
    total_tokens = 0
    total_correct_emotions = 0
    total_correct_causes = 0
    total_predictions = 0

    for batch in tqdm(train_dataloader):  # Assuming you have a DataLoader for your dataset
        # Extract data from the batch
        audio = batch['audio'].to('cuda')
        video = batch['video'].to('cuda')
        text = batch['text'].to('cuda')
        cause_indices = batch['cause_labels'].to('cuda')
        
        audio_copy = audio.clone().detach()
        video_copy = video.clone().detach()
        text_copy = text.clone().detach()
        
        emotion_indices = batch['emotion_labels'].to('cuda')
        pad_mask = batch['pad_mask'].to('cuda')

        # Forward pass
        # emotion_logits = emotion_model(audio_copy, video_copy, text_copy)

        # Reshape emotion_logits
        # emotion_logits = emotion_logits.view(-1, emotion_logits.size(-1))

        # Flatten emotion_indices (assuming it's a 2D tensor with shape [batch_size, max_sequence_length])
        # emotion_indices = emotion_indices.view(-1)

        # Calculate a mask to exclude padded positions from the loss
        pad_mask = pad_mask.view(-1).float()

        # Calculate the loss, excluding padded positions
        emotion_loss = emotion_model.loss(audio_copy, video_copy, text_copy, emotion_indices, pad_mask)
        # masked_loss = torch.sum(loss * pad_mask) / torch.sum(pad_mask)
        masked_loss = emotion_loss# *pad_mask
        # Backpropagation and optimization
        
        
        # cause_logits = cause_model(audio, video, text)
        # cause_logits = cause_logits.view(-1, cause_logits.size(-1))
        # cause_indices = cause_indices.view(-1)
        
        cause_loss = cause_model.loss(audio, video, text, cause_indices, pad_mask)
        masked_loss += cause_loss
        
        emotion_optimizer.zero_grad()
        cause_optimizer.zero_grad()
        
        masked_loss.backward()
        
        emotion_optimizer.step()
        cause_optimizer.step()        

        total_loss += masked_loss.item()
        total_tokens += torch.sum(pad_mask).item()
        
        predicted_emotions = emotion_model.predict(audio, video, text)
        correct_predictions_emotions = ((predicted_emotions == emotion_indices) * pad_mask).sum().item()

        predicted_causes = cause_model.predict(audio, video, text)
        correct_predictions_causes = ((predicted_causes == cause_indices) * pad_mask).sum().item()
        
        total_correct_emotions += correct_predictions_emotions
        total_correct_causes += correct_predictions_causes
        total_predictions += torch.sum(pad_mask).item()  # Batch size
        
    
    emotion_lr_scheduler.step()
    cause_lr_scheduler.step()
    
    emotion_model.eval()  # Set the model to evaluation mode
    cause_model.eval()
    
    total_val_loss = 0.0
    total_val_tokens = 0
    total_val_correct_emotions = 0
    total_val_correct_causes = 0
    total_val_predictions = 0
    true_labels_emotion = []
    predicted_labels_emotion = []
    true_labels_cause = []
    predicted_labels_cause = []
    padded_labels = []

    with torch.no_grad():
        for val_batch in tqdm(validation_dataloader):
            audio = val_batch['audio'].to('cuda')
            video = val_batch['video'].to('cuda')
            text = val_batch['text'].to('cuda')
            emotion_indices = val_batch['emotion_labels'].to('cuda')
            cause_indices = val_batch['cause_labels'].to('cuda')
            pad_mask = val_batch['pad_mask'].to('cuda')
            
            audio_copy = audio.clone().detach()
            video_copy = video.clone().detach()
            text_copy = text.clone().detach()

            # emotion_logits = emotion_model(audio_copy, video_copy, text_copy)

            # Reshape emotion_logits
            # emotion_logits = emotion_logits.view(-1, emotion_logits.size(-1))

            # Flatten emotion_indices (assuming it's a 2D tensor with shape [batch_size, max_sequence_length])
            # emotion_indices = emotion_indices.view(-1)

            pad_mask = pad_mask.view(-1)   

            # Calculate the loss, excluding padded positions
            val_loss = emotion_model.loss(audio_copy, video_copy, text_copy, emotion_indices, pad_mask)
            masked_loss = val_loss #torch.sum(val_loss * pad_mask) / torch.sum(pad_mask)
            
            # cause_logits = cause_model(audio, video, text)
            # cause_logits = cause_logits.view(-1, cause_logits.size(-1))
            # cause_indices = cause_indices.view(-1)
            cause_loss = cause_model.loss(audio, video, text, cause_indices, pad_mask)
            masked_loss += cause_loss
            
            total_val_loss += masked_loss.item()
            total_val_tokens += torch.sum(pad_mask).item()
            
            predicted_emotions_val = emotion_model.predict(audio, video, text)
            correct_predictions_val = ((predicted_emotions_val == emotion_indices) * pad_mask).sum().item()
            total_val_correct_emotions += correct_predictions_val
            
            predicted_causes_val = cause_model.predict(audio, video, text)
            correct_predictions_causes_val = ((predicted_causes_val == cause_indices) * pad_mask).sum().item()
            total_val_correct_causes += correct_predictions_causes_val
            
            total_val_predictions += torch.sum(pad_mask).item()

            # Store true and predicted labels for F1 score calculation
            emotion_indices = emotion_indices.cpu().squeeze(0).numpy()
            true_labels_emotion.extend(emotion_indices)
            predicted_labels_emotion.extend(predicted_emotions_val[0])
            
            
            true_labels_cause.extend(cause_indices.cpu().squeeze(0).numpy())
            predicted_labels_cause.extend(predicted_causes_val[0])
            padded_labels.extend(pad_mask.cpu().numpy())
    # print("AAAAAAAAAAAAAAAAAAAAAAAAaa")
    # print(len(final_true_labels_cause))
    # print(len(final_predicted_labels_cause))
    # print("AAAAAAAAAAAAAAAAAAAAAAAAaa")
    final_true_labels_emotion = [label for label, pad in zip(true_labels_emotion, padded_labels)]
    final_predicted_labels_emotion = [label for label, pad in zip(predicted_labels_emotion, padded_labels)]
    
    final_true_labels_cause = [label for label, pad in zip(true_labels_cause, padded_labels)]
    final_predicted_labels_cause = [label for label, pad in zip(predicted_labels_cause, padded_labels)]
    
    emotion_classification_rep = classification_report(final_true_labels_emotion, final_predicted_labels_emotion)
    cause_classification_rep = classification_report(final_true_labels_cause, final_predicted_labels_cause)
    
    # Calculate and print the average loss for this epoch
    # avg_loss = total_loss / total_tokens
    # avg_val_loss = total_val_loss / total_val_tokens
    
    print("===============================")
    print("Training data metrics")
    # print(f"Epoch [{epoch + 1}/{num_epochs}] Training Loss: {avg_loss}")
    # print(f"Epoch [{epoch + 1}/{num_epochs}] Accuracy: {total_correct_emotions / total_predictions}")
    # print(f"Epoch [{epoch + 1}/{num_epochs}] Accuracy: {total_correct_causes / total_predictions}")
    
    print("VALIDATION METRICS")
    # print(f"Epoch [{epoch + 1}/{num_epochs}] Validation Loss: {avg_val_loss}")
    print(emotion_classification_rep)
    print(cause_classification_rep)
    print("===============================")

    torch.save(emotion_model.state_dict(), f"/tmp/semeval24_task3/final_models/emotion_models/emotion_model_{epoch:02}.pt")
    torch.save(cause_model.state_dict(), f"/tmp/semeval24_task3/final_models/cause_models/cause_model_{epoch:02}.pt")

print("Training complete!")


100%|██████████| 1236/1236 [01:56<00:00, 10.60it/s]
100%|██████████| 138/138 [00:03<00:00, 36.85it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Training data metrics
VALIDATION METRICS
              precision    recall  f1-score   support

           0       0.50      0.35      0.41       226
           1       0.35      0.21      0.26       139
           2       0.17      0.08      0.11       126
           3       0.59      0.86      0.70       671
           4       0.21      0.15      0.17       170
           5       0.00      0.00      0.00        38
           6       0.00      0.00      0.00        33

    accuracy                           0.51      1403
   macro avg       0.26      0.23      0.24      1403
weighted avg       0.44      0.51      0.46      1403

              precision    recall  f1-score   support

           0       0.70      0.74      0.72       795
           1       0.63      0.58      0.60       608

    accuracy                           0.67      1403
   macro avg       0.66      0.66      0.66      1403
weighted avg       0.67      0.67      0.67      1403



100%|██████████| 1236/1236 [01:52<00:00, 11.02it/s]
100%|██████████| 138/138 [00:03<00:00, 37.06it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Training data metrics
VALIDATION METRICS
              precision    recall  f1-score   support

           0       0.62      0.27      0.37       226
           1       0.45      0.21      0.29       139
           2       0.27      0.24      0.25       126
           3       0.61      0.88      0.72       671
           4       0.54      0.52      0.53       170
           5       0.00      0.00      0.00        38
           6       0.00      0.00      0.00        33

    accuracy                           0.57      1403
   macro avg       0.36      0.30      0.31      1403
weighted avg       0.53      0.57      0.52      1403

              precision    recall  f1-score   support

           0       0.64      0.89      0.74       795
           1       0.71      0.34      0.46       608

    accuracy                           0.65      1403
   macro avg       0.67      0.62      0.60      1403
weighted avg       0.67      0.65      0.62      1403



100%|██████████| 1236/1236 [01:56<00:00, 10.61it/s]
100%|██████████| 138/138 [00:03<00:00, 36.71it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Training data metrics
VALIDATION METRICS
              precision    recall  f1-score   support

           0       0.63      0.31      0.42       226
           1       0.36      0.27      0.31       139
           2       0.38      0.19      0.25       126
           3       0.60      0.91      0.73       671
           4       0.63      0.42      0.51       170
           5       0.00      0.00      0.00        38
           6       0.00      0.00      0.00        33

    accuracy                           0.58      1403
   macro avg       0.37      0.30      0.32      1403
weighted avg       0.54      0.58      0.53      1403

              precision    recall  f1-score   support

           0       0.71      0.78      0.74       795
           1       0.67      0.57      0.62       608

    accuracy                           0.69      1403
   macro avg       0.69      0.68      0.68      1403
weighted avg       0.69      0.69      0.69      1403



100%|██████████| 1236/1236 [01:52<00:00, 10.98it/s]
100%|██████████| 138/138 [00:03<00:00, 36.90it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Training data metrics
VALIDATION METRICS
              precision    recall  f1-score   support

           0       0.54      0.40      0.46       226
           1       0.53      0.24      0.33       139
           2       0.28      0.38      0.33       126
           3       0.67      0.78      0.72       671
           4       0.45      0.59      0.51       170
           5       0.00      0.00      0.00        38
           6       0.00      0.00      0.00        33

    accuracy                           0.57      1403
   macro avg       0.35      0.34      0.34      1403
weighted avg       0.54      0.57      0.54      1403

              precision    recall  f1-score   support

           0       0.73      0.73      0.73       795
           1       0.65      0.65      0.65       608

    accuracy                           0.69      1403
   macro avg       0.69      0.69      0.69      1403
weighted avg       0.69      0.69      0.69      1403



100%|██████████| 1236/1236 [01:52<00:00, 10.99it/s]
100%|██████████| 138/138 [00:03<00:00, 37.04it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Training data metrics
VALIDATION METRICS
              precision    recall  f1-score   support

           0       0.53      0.40      0.45       226
           1       0.41      0.33      0.37       139
           2       0.34      0.33      0.33       126
           3       0.64      0.83      0.72       671
           4       0.65      0.48      0.55       170
           5       0.00      0.00      0.00        38
           6       0.00      0.00      0.00        33

    accuracy                           0.58      1403
   macro avg       0.37      0.34      0.35      1403
weighted avg       0.54      0.58      0.55      1403

              precision    recall  f1-score   support

           0       0.64      0.89      0.75       795
           1       0.71      0.34      0.46       608

    accuracy                           0.65      1403
   macro avg       0.67      0.62      0.60      1403
weighted avg       0.67      0.65      0.62      1403



100%|██████████| 1236/1236 [01:56<00:00, 10.63it/s]
100%|██████████| 138/138 [00:03<00:00, 36.75it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Training data metrics
VALIDATION METRICS
              precision    recall  f1-score   support

           0       0.57      0.44      0.49       226
           1       0.42      0.34      0.37       139
           2       0.32      0.39      0.35       126
           3       0.66      0.85      0.74       671
           4       0.70      0.43      0.53       170
           5       0.00      0.00      0.00        38
           6       0.00      0.00      0.00        33

    accuracy                           0.60      1403
   macro avg       0.38      0.35      0.36      1403
weighted avg       0.56      0.60      0.57      1403

              precision    recall  f1-score   support

           0       0.75      0.61      0.67       795
           1       0.59      0.74      0.66       608

    accuracy                           0.67      1403
   macro avg       0.67      0.67      0.66      1403
weighted avg       0.68      0.67      0.67      1403



100%|██████████| 1236/1236 [01:56<00:00, 10.63it/s]
100%|██████████| 138/138 [00:03<00:00, 35.61it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Training data metrics
VALIDATION METRICS
              precision    recall  f1-score   support

           0       0.51      0.39      0.44       226
           1       0.47      0.26      0.33       139
           2       0.24      0.53      0.33       126
           3       0.68      0.73      0.71       671
           4       0.55      0.51      0.53       170
           5       0.00      0.00      0.00        38
           6       0.00      0.00      0.00        33

    accuracy                           0.55      1403
   macro avg       0.35      0.35      0.33      1403
weighted avg       0.54      0.55      0.54      1403

              precision    recall  f1-score   support

           0       0.72      0.69      0.70       795
           1       0.61      0.65      0.63       608

    accuracy                           0.67      1403
   macro avg       0.66      0.67      0.67      1403
weighted avg       0.67      0.67      0.67      1403



100%|██████████| 1236/1236 [01:53<00:00, 10.85it/s]
100%|██████████| 138/138 [00:03<00:00, 36.72it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Training data metrics
VALIDATION METRICS
              precision    recall  f1-score   support

           0       0.47      0.46      0.46       226
           1       0.44      0.31      0.36       139
           2       0.35      0.36      0.36       126
           3       0.68      0.79      0.73       671
           4       0.55      0.55      0.55       170
           5       0.00      0.00      0.00        38
           6       0.00      0.00      0.00        33

    accuracy                           0.58      1403
   macro avg       0.36      0.35      0.35      1403
weighted avg       0.54      0.58      0.56      1403

              precision    recall  f1-score   support

           0       0.69      0.74      0.71       795
           1       0.62      0.57      0.59       608

    accuracy                           0.66      1403
   macro avg       0.65      0.65      0.65      1403
weighted avg       0.66      0.66      0.66      1403



100%|██████████| 1236/1236 [01:55<00:00, 10.67it/s]
100%|██████████| 138/138 [00:03<00:00, 35.93it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Training data metrics
VALIDATION METRICS
              precision    recall  f1-score   support

           0       0.43      0.57      0.49       226
           1       0.42      0.37      0.39       139
           2       0.38      0.34      0.36       126
           3       0.68      0.74      0.71       671
           4       0.63      0.49      0.55       170
           5       0.00      0.00      0.00        38
           6       0.25      0.06      0.10        33

    accuracy                           0.58      1403
   macro avg       0.40      0.37      0.37      1403
weighted avg       0.55      0.58      0.56      1403

              precision    recall  f1-score   support

           0       0.66      0.79      0.71       795
           1       0.62      0.46      0.53       608

    accuracy                           0.64      1403
   macro avg       0.64      0.62      0.62      1403
weighted avg       0.64      0.64      0.63      1403



100%|██████████| 1236/1236 [02:05<00:00,  9.82it/s]
100%|██████████| 138/138 [00:03<00:00, 36.87it/s]


Training data metrics
VALIDATION METRICS
              precision    recall  f1-score   support

           0       0.57      0.42      0.48       226
           1       0.36      0.36      0.36       139
           2       0.42      0.25      0.31       126
           3       0.65      0.82      0.72       671
           4       0.57      0.48      0.52       170
           5       0.00      0.00      0.00        38
           6       0.13      0.12      0.12        33

    accuracy                           0.58      1403
   macro avg       0.39      0.35      0.36      1403
weighted avg       0.55      0.58      0.55      1403

              precision    recall  f1-score   support

           0       0.68      0.70      0.69       795
           1       0.60      0.57      0.58       608

    accuracy                           0.65      1403
   macro avg       0.64      0.64      0.64      1403
weighted avg       0.64      0.65      0.64      1403



100%|██████████| 1236/1236 [01:52<00:00, 11.02it/s]
100%|██████████| 138/138 [00:03<00:00, 37.03it/s]


Training data metrics
VALIDATION METRICS
              precision    recall  f1-score   support

           0       0.39      0.52      0.44       226
           1       0.49      0.19      0.28       139
           2       0.28      0.41      0.34       126
           3       0.70      0.67      0.68       671
           4       0.50      0.61      0.55       170
           5       0.11      0.03      0.04        38
           6       0.22      0.06      0.10        33

    accuracy                           0.54      1403
   macro avg       0.39      0.36      0.35      1403
weighted avg       0.54      0.54      0.53      1403

              precision    recall  f1-score   support

           0       0.68      0.69      0.69       795
           1       0.59      0.58      0.58       608

    accuracy                           0.64      1403
   macro avg       0.64      0.63      0.64      1403
weighted avg       0.64      0.64      0.64      1403



100%|██████████| 1236/1236 [01:52<00:00, 11.02it/s]
100%|██████████| 138/138 [00:03<00:00, 37.03it/s]


Training data metrics
VALIDATION METRICS
              precision    recall  f1-score   support

           0       0.43      0.50      0.46       226
           1       0.50      0.22      0.31       139
           2       0.31      0.41      0.36       126
           3       0.68      0.74      0.71       671
           4       0.56      0.57      0.57       170
           5       0.40      0.05      0.09        38
           6       0.00      0.00      0.00        33

    accuracy                           0.56      1403
   macro avg       0.41      0.36      0.36      1403
weighted avg       0.55      0.56      0.55      1403

              precision    recall  f1-score   support

           0       0.70      0.65      0.68       795
           1       0.58      0.64      0.61       608

    accuracy                           0.65      1403
   macro avg       0.64      0.65      0.64      1403
weighted avg       0.65      0.65      0.65      1403



100%|██████████| 1236/1236 [01:54<00:00, 10.76it/s]
100%|██████████| 138/138 [00:03<00:00, 36.02it/s]


Training data metrics
VALIDATION METRICS
              precision    recall  f1-score   support

           0       0.49      0.42      0.45       226
           1       0.37      0.36      0.36       139
           2       0.39      0.37      0.38       126
           3       0.65      0.79      0.71       671
           4       0.62      0.46      0.53       170
           5       0.23      0.08      0.12        38
           6       0.50      0.09      0.15        33

    accuracy                           0.57      1403
   macro avg       0.47      0.37      0.39      1403
weighted avg       0.55      0.57      0.55      1403

              precision    recall  f1-score   support

           0       0.69      0.70      0.69       795
           1       0.60      0.58      0.59       608

    accuracy                           0.65      1403
   macro avg       0.64      0.64      0.64      1403
weighted avg       0.65      0.65      0.65      1403



100%|██████████| 1236/1236 [01:56<00:00, 10.58it/s]
100%|██████████| 138/138 [00:03<00:00, 36.46it/s]


Training data metrics
VALIDATION METRICS
              precision    recall  f1-score   support

           0       0.45      0.50      0.47       226
           1       0.38      0.29      0.33       139
           2       0.25      0.47      0.33       126
           3       0.72      0.63      0.67       671
           4       0.55      0.45      0.50       170
           5       0.10      0.11      0.10        38
           6       0.08      0.12      0.10        33

    accuracy                           0.51      1403
   macro avg       0.36      0.37      0.36      1403
weighted avg       0.55      0.51      0.52      1403

              precision    recall  f1-score   support

           0       0.68      0.68      0.68       795
           1       0.58      0.57      0.58       608

    accuracy                           0.64      1403
   macro avg       0.63      0.63      0.63      1403
weighted avg       0.64      0.64      0.64      1403



 46%|████▌     | 567/1236 [00:53<01:02, 10.66it/s]


KeyboardInterrupt: 