In [1]:
import torch
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from torch.nn.utils import clip_grad_norm_
from transformers import RobertaTokenizer, RobertaForSequenceClassification
from transformers import get_linear_schedule_with_warmup
import pandas as pd
import math
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from tqdm import tqdm
import os
import itertools

# Custom dataset class
class TextValueDataset(Dataset):
    def __init__(self, dataframe, tokenizer, max_length):
        self.dataframe = dataframe
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, index):
        row = self.dataframe.iloc[index]
        text = row['generated_text']
        label = row['suddenness'] - 1  # Convert value to 0-4 classes
        
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
            return_attention_mask=True
        )
        
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }

# Load dataset
def load_dataset(file_path):
    df = pd.read_csv(file_path)
    return df

# Create DataLoader
def create_dataloader(df, tokenizer, max_length, batch_size):
    dataset = TextValueDataset(df, tokenizer, max_length)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Model setup
def create_model(num_labels):
    model = RobertaForSequenceClassification.from_pretrained(
        'roberta-base',
        num_labels=num_labels
    )
    return model

# Evaluation function
def evaluate(model, dataloader, device):
    model.eval()
    predictions, true_labels = [], []

    eval_loss = 0
    eval_loop = tqdm(dataloader, desc="Evaluating", leave=False)
    with torch.no_grad():
        for batch in eval_loop:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels = labels
            )
            logits = outputs.logits
            preds = torch.argmax(logits, dim=1)
            
            loss = outputs.loss
            eval_loss += loss.item()
            eval_loop.set_postfix(loss=loss.item())
            
            predictions.extend(preds.cpu().numpy())
            true_labels.extend(labels.cpu().numpy())
    
    accuracy = accuracy_score(true_labels, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(true_labels, predictions, average='weighted')
    
    return eval_loss, accuracy, precision, recall, f1

# Training function
def train(model, train_loader, val_loader, test_loader, epochs, device, lr, save_path, 
          weight_decay=0.01, max_grad_norm=1.0, patience=3):
    optimizer = AdamW(model.parameters(), lr=lr)
    model = model.to(device)

    epoch_counter = 0
    min_val_loss = 1e10
    total_steps = len(train_loader) * epochs
    warmup_steps = int(0.1 * total_steps)

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps
    )
    
    for epoch in range(epochs):
        print(f'Epoch {epoch + 1}/{epochs}')
        
        # Training
        model.train()
        train_loss = 0
        train_loop = tqdm(train_loader, desc="Training")
        
        for batch in train_loop:
            optimizer.zero_grad()
            
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
            
            loss = outputs.loss
            loss.backward()

            clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            scheduler.step()
            
            train_loss += loss.item()
            train_loop.set_postfix(loss=loss.item())

        avg_train_loss = train_loss / len(train_loader)
        print(f'Average Training Loss: {avg_train_loss:.4f}')

        # Validation after each epoch
        val_loss, val_accuracy, val_precision, val_recall, val_f1 = evaluate(model, val_loader, device)
        print(f'Validation Accuracy: {val_accuracy:.4f}, Precision: {val_precision:.4f}, Recall: {val_recall:.4f}, F1: {val_f1:.4f}')

        # Save model after each epoch
        if val_loss <= min_val_loss:
            model_save_path = os.path.join(save_path, f'suddenness.pt')
            torch.save(model.state_dict(), model_save_path)
            print(f'Model saved to {model_save_path}')
            min_val_loss = val_loss
        else:
            epoch_counter +=1
            print(f'Model not saved due to higher validation loss {val_loss:.4f} compared to {min_val_loss:.4f}')

        if epoch_counter >= patience:
            print(f'Model did not improve after {epoch - patience} epochs for {patience} epochs. Training halted')  
            break
    # Evaluation
    _, test_accuracy, test_precision, test_recall, test_f1 = evaluate(model, test_loader, device)
    print(f'Test Accuracy: {test_accuracy:.4f}, Precision: {test_precision:.4f}, Recall: {test_recall:.4f}, F1: {test_f1:.4f}')
        

# Main function
def main():
    # Configurations
    max_length = 128
    batch_size = 8
    epochs = 50
    num_labels = 5
    lr = 2e-6
    save_path = './models'  # Directory to save models
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Make sure the save directory exists
    os.makedirs(save_path, exist_ok=True)

    # Load datasets
    tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
    train_df = load_dataset('data/train.csv')
    val_df = load_dataset('data/val.csv')
    test_df = load_dataset('data/test.csv')

    # Create DataLoaders
    train_loader = create_dataloader(train_df, tokenizer, max_length, batch_size)
    val_loader = create_dataloader(val_df, tokenizer, max_length, batch_size)
    test_loader = create_dataloader(test_df, tokenizer, max_length, batch_size)

    # Initialize model
    model = create_model(num_labels)

    # Train and evaluate the model
    train(model, train_loader, val_loader, test_loader, epochs, device, lr, save_path)

if __name__ == '__main__':
    main()

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch 1/50


Training: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 653/653 [01:07<00:00,  9.72it/s, loss=1.63]


Average Training Loss: 1.6046


  _warn_prf(average, modifier, msg_start, len(result))


Validation Accuracy: 0.2796, Precision: 0.0782, Recall: 0.2796, F1: 0.1222
Model saved to ./models/suddenness.pt
Epoch 2/50


Training: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 653/653 [01:06<00:00,  9.84it/s, loss=1.7]


Average Training Loss: 1.5977


  _warn_prf(average, modifier, msg_start, len(result))


Validation Accuracy: 0.2667, Precision: 0.1164, Recall: 0.2667, F1: 0.1613
Model saved to ./models/suddenness.pt
Epoch 3/50


Training: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 653/653 [01:06<00:00,  9.80it/s, loss=1.66]


Average Training Loss: 1.5613


  _warn_prf(average, modifier, msg_start, len(result))


Validation Accuracy: 0.4093, Precision: 0.2936, Recall: 0.4093, F1: 0.3394
Model saved to ./models/suddenness.pt
Epoch 4/50


Training: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 653/653 [01:06<00:00,  9.78it/s, loss=1.01]


Average Training Loss: 1.4232


                                                                                                                                                                                                                      

Validation Accuracy: 0.4000, Precision: 0.3810, Recall: 0.4000, F1: 0.3504
Model saved to ./models/suddenness.pt
Epoch 5/50


Training: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 653/653 [01:06<00:00,  9.77it/s, loss=1.37]


Average Training Loss: 1.3723


                                                                                                                                                                                                                      

Validation Accuracy: 0.4204, Precision: 0.3872, Recall: 0.4204, F1: 0.3786
Model saved to ./models/suddenness.pt
Epoch 6/50


Training: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 653/653 [01:06<00:00,  9.83it/s, loss=1.72]


Average Training Loss: 1.3242


                                                                                                                                                                                                                      

Validation Accuracy: 0.4463, Precision: 0.4109, Recall: 0.4463, F1: 0.4009
Model saved to ./models/suddenness.pt
Epoch 7/50


Training: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 653/653 [01:06<00:00,  9.79it/s, loss=1.13]


Average Training Loss: 1.2874


                                                                                                                                                                                                                      

Validation Accuracy: 0.3852, Precision: 0.3831, Recall: 0.3852, F1: 0.3633
Model not saved due to higher validation loss 99.7621 compared to 96.0632
Epoch 8/50


Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 653/653 [01:06<00:00,  9.87it/s, loss=0.839]


Average Training Loss: 1.2429


                                                                                                                                                                                                                      

Validation Accuracy: 0.4389, Precision: 0.4122, Recall: 0.4389, F1: 0.3975
Model not saved due to higher validation loss 97.8124 compared to 96.0632
Epoch 9/50


Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 653/653 [01:06<00:00,  9.83it/s, loss=0.658]


Average Training Loss: 1.2036


                                                                                                                                                                                                                      

Validation Accuracy: 0.4000, Precision: 0.3992, Recall: 0.4000, F1: 0.3861
Model not saved due to higher validation loss 100.7852 compared to 96.0632
Model did not improve after 5 epochs for 3 epochs. Training halted
Epoch 10/50


Training: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 653/653 [01:06<00:00,  9.83it/s, loss=1.89]


Average Training Loss: 1.1631


                                                                                                                                                                                                                      

Validation Accuracy: 0.3944, Precision: 0.3883, Recall: 0.3944, F1: 0.3861
Model not saved due to higher validation loss 103.1993 compared to 96.0632
Model did not improve after 6 epochs for 3 epochs. Training halted
Epoch 11/50


Training: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 653/653 [01:06<00:00,  9.86it/s, loss=1.09]


Average Training Loss: 1.1252


                                                                                                                                                                                                                      

Validation Accuracy: 0.3981, Precision: 0.3916, Recall: 0.3981, F1: 0.3864
Model not saved due to higher validation loss 104.0224 compared to 96.0632
Model did not improve after 7 epochs for 3 epochs. Training halted
Epoch 12/50


Training: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 653/653 [01:05<00:00,  9.91it/s, loss=1.27]


Average Training Loss: 1.0905


                                                                                                                                                                                                                      

Validation Accuracy: 0.4000, Precision: 0.3817, Recall: 0.4000, F1: 0.3880
Model not saved due to higher validation loss 103.9705 compared to 96.0632
Model did not improve after 8 epochs for 3 epochs. Training halted
Epoch 13/50


Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 653/653 [01:06<00:00,  9.81it/s, loss=0.818]


Average Training Loss: 1.0436


                                                                                                                                                                                                                      

Validation Accuracy: 0.3741, Precision: 0.3917, Recall: 0.3741, F1: 0.3770
Model not saved due to higher validation loss 109.0536 compared to 96.0632
Model did not improve after 9 epochs for 3 epochs. Training halted
Epoch 14/50


Training: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 653/653 [01:06<00:00,  9.82it/s, loss=0.87]


Average Training Loss: 0.9999


                                                                                                                                                                                                                      

Validation Accuracy: 0.3648, Precision: 0.3636, Recall: 0.3648, F1: 0.3628
Model not saved due to higher validation loss 112.1852 compared to 96.0632
Model did not improve after 10 epochs for 3 epochs. Training halted
Epoch 15/50


Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 653/653 [01:06<00:00,  9.81it/s, loss=0.684]


Average Training Loss: 0.9607


                                                                                                                                                                                                                      

Validation Accuracy: 0.3722, Precision: 0.3769, Recall: 0.3722, F1: 0.3702
Model not saved due to higher validation loss 114.6357 compared to 96.0632
Model did not improve after 11 epochs for 3 epochs. Training halted
Epoch 16/50


Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 653/653 [01:07<00:00,  9.74it/s, loss=0.605]


Average Training Loss: 0.9256


                                                                                                                                                                                                                      

Validation Accuracy: 0.3537, Precision: 0.3843, Recall: 0.3537, F1: 0.3625
Model not saved due to higher validation loss 116.4331 compared to 96.0632
Model did not improve after 12 epochs for 3 epochs. Training halted
Epoch 17/50


Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 653/653 [01:06<00:00,  9.83it/s, loss=0.829]


Average Training Loss: 0.8774


                                                                                                                                                                                                                      

Validation Accuracy: 0.3389, Precision: 0.3500, Recall: 0.3389, F1: 0.3434
Model not saved due to higher validation loss 121.2264 compared to 96.0632
Model did not improve after 13 epochs for 3 epochs. Training halted
Epoch 18/50


Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 653/653 [01:06<00:00,  9.81it/s, loss=0.285]


Average Training Loss: 0.8487


                                                                                                                                                                                                                      

Validation Accuracy: 0.3611, Precision: 0.3679, Recall: 0.3611, F1: 0.3643
Model not saved due to higher validation loss 123.2806 compared to 96.0632
Model did not improve after 14 epochs for 3 epochs. Training halted
Epoch 19/50


Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 653/653 [01:06<00:00,  9.85it/s, loss=0.519]


Average Training Loss: 0.8231


                                                                                                                                                                                                                      

Validation Accuracy: 0.3444, Precision: 0.3655, Recall: 0.3444, F1: 0.3508
Model not saved due to higher validation loss 127.8445 compared to 96.0632
Model did not improve after 15 epochs for 3 epochs. Training halted
Epoch 20/50


Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 653/653 [01:06<00:00,  9.82it/s, loss=0.939]


Average Training Loss: 0.7841


                                                                                                                                                                                                                      

Validation Accuracy: 0.3574, Precision: 0.3963, Recall: 0.3574, F1: 0.3696
Model not saved due to higher validation loss 132.0211 compared to 96.0632
Model did not improve after 16 epochs for 3 epochs. Training halted
Epoch 21/50


Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 653/653 [01:06<00:00,  9.81it/s, loss=0.404]


Average Training Loss: 0.7473


                                                                                                                                                                                                                      

Validation Accuracy: 0.3407, Precision: 0.3796, Recall: 0.3407, F1: 0.3528
Model not saved due to higher validation loss 133.2221 compared to 96.0632
Model did not improve after 17 epochs for 3 epochs. Training halted
Epoch 22/50


Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 653/653 [01:06<00:00,  9.80it/s, loss=0.246]


Average Training Loss: 0.7035


                                                                                                                                                                                                                      

Validation Accuracy: 0.3519, Precision: 0.3749, Recall: 0.3519, F1: 0.3590
Model not saved due to higher validation loss 138.6373 compared to 96.0632
Model did not improve after 18 epochs for 3 epochs. Training halted
Epoch 23/50


Training: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 653/653 [01:06<00:00,  9.80it/s, loss=1.03]


Average Training Loss: 0.6856


                                                                                                                                                                                                                      

Validation Accuracy: 0.3333, Precision: 0.3710, Recall: 0.3333, F1: 0.3477
Model not saved due to higher validation loss 138.7044 compared to 96.0632
Model did not improve after 19 epochs for 3 epochs. Training halted
Epoch 24/50


Training: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 653/653 [01:06<00:00,  9.78it/s, loss=1.81]


Average Training Loss: 0.6583


                                                                                                                                                                                                                      

Validation Accuracy: 0.3444, Precision: 0.3649, Recall: 0.3444, F1: 0.3521
Model not saved due to higher validation loss 143.9200 compared to 96.0632
Model did not improve after 20 epochs for 3 epochs. Training halted
Epoch 25/50


Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 653/653 [01:06<00:00,  9.81it/s, loss=0.166]


Average Training Loss: 0.6202


                                                                                                                                                                                                                      

Validation Accuracy: 0.3796, Precision: 0.3933, Recall: 0.3796, F1: 0.3851
Model not saved due to higher validation loss 145.0655 compared to 96.0632
Model did not improve after 21 epochs for 3 epochs. Training halted
Epoch 26/50


Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 653/653 [01:06<00:00,  9.88it/s, loss=0.428]


Average Training Loss: 0.5903


                                                                                                                                                                                                                      

Validation Accuracy: 0.3648, Precision: 0.3892, Recall: 0.3648, F1: 0.3745
Model not saved due to higher validation loss 151.1512 compared to 96.0632
Model did not improve after 22 epochs for 3 epochs. Training halted
Epoch 27/50


Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 653/653 [01:06<00:00,  9.78it/s, loss=0.839]


Average Training Loss: 0.5888


                                                                                                                                                                                                                      

Validation Accuracy: 0.3833, Precision: 0.3877, Recall: 0.3833, F1: 0.3847
Model not saved due to higher validation loss 149.2786 compared to 96.0632
Model did not improve after 23 epochs for 3 epochs. Training halted
Epoch 28/50


Training: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 653/653 [01:06<00:00,  9.86it/s, loss=1.95]


Average Training Loss: 0.5426


                                                                                                                                                                                                                      

Validation Accuracy: 0.3593, Precision: 0.3770, Recall: 0.3593, F1: 0.3668
Model not saved due to higher validation loss 154.6286 compared to 96.0632
Model did not improve after 24 epochs for 3 epochs. Training halted
Epoch 29/50


Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 653/653 [01:06<00:00,  9.75it/s, loss=0.943]


Average Training Loss: 0.5291


                                                                                                                                                                                                                      

Validation Accuracy: 0.3500, Precision: 0.3820, Recall: 0.3500, F1: 0.3626
Model not saved due to higher validation loss 160.1877 compared to 96.0632
Model did not improve after 25 epochs for 3 epochs. Training halted
Epoch 30/50


Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 653/653 [01:06<00:00,  9.85it/s, loss=0.923]


Average Training Loss: 0.5023


                                                                                                                                                                                                                      

Validation Accuracy: 0.3500, Precision: 0.3694, Recall: 0.3500, F1: 0.3582
Model not saved due to higher validation loss 161.3001 compared to 96.0632
Model did not improve after 26 epochs for 3 epochs. Training halted
Epoch 31/50


Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 653/653 [01:06<00:00,  9.82it/s, loss=0.128]


Average Training Loss: 0.4784


                                                                                                                                                                                                                      

Validation Accuracy: 0.3722, Precision: 0.3783, Recall: 0.3722, F1: 0.3745
Model not saved due to higher validation loss 164.6160 compared to 96.0632
Model did not improve after 27 epochs for 3 epochs. Training halted
Epoch 32/50


Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 653/653 [01:06<00:00,  9.81it/s, loss=0.402]


Average Training Loss: 0.4734


                                                                                                                                                                                                                      

Validation Accuracy: 0.3537, Precision: 0.3752, Recall: 0.3537, F1: 0.3619
Model not saved due to higher validation loss 168.8260 compared to 96.0632
Model did not improve after 28 epochs for 3 epochs. Training halted
Epoch 33/50


Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 653/653 [01:07<00:00,  9.72it/s, loss=0.174]


Average Training Loss: 0.4659


                                                                                                                                                                                                                      

Validation Accuracy: 0.3444, Precision: 0.3737, Recall: 0.3444, F1: 0.3561
Model not saved due to higher validation loss 173.6565 compared to 96.0632
Model did not improve after 29 epochs for 3 epochs. Training halted
Epoch 34/50


Training: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 653/653 [01:06<00:00,  9.76it/s, loss=0.8]


Average Training Loss: 0.4417


                                                                                                                                                                                                                      

Validation Accuracy: 0.3407, Precision: 0.3699, Recall: 0.3407, F1: 0.3523
Model not saved due to higher validation loss 173.7718 compared to 96.0632
Model did not improve after 30 epochs for 3 epochs. Training halted
Epoch 35/50


Training: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 653/653 [01:06<00:00,  9.79it/s, loss=0.37]


Average Training Loss: 0.4326


                                                                                                                                                                                                                      

Validation Accuracy: 0.3444, Precision: 0.3839, Recall: 0.3444, F1: 0.3594
Model not saved due to higher validation loss 179.1121 compared to 96.0632
Model did not improve after 31 epochs for 3 epochs. Training halted
Epoch 36/50


Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 653/653 [01:06<00:00,  9.86it/s, loss=0.221]


Average Training Loss: 0.4095


                                                                                                                                                                                                                      

Validation Accuracy: 0.3519, Precision: 0.3785, Recall: 0.3519, F1: 0.3629
Model not saved due to higher validation loss 179.3294 compared to 96.0632
Model did not improve after 32 epochs for 3 epochs. Training halted
Epoch 37/50


Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 653/653 [01:06<00:00,  9.84it/s, loss=0.874]


Average Training Loss: 0.4037


                                                                                                                                                                                                                      

Validation Accuracy: 0.3611, Precision: 0.3889, Recall: 0.3611, F1: 0.3717
Model not saved due to higher validation loss 183.7980 compared to 96.0632
Model did not improve after 33 epochs for 3 epochs. Training halted
Epoch 38/50


Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 653/653 [01:07<00:00,  9.74it/s, loss=0.196]


Average Training Loss: 0.3794


                                                                                                                                                                                                                      

Validation Accuracy: 0.3556, Precision: 0.3712, Recall: 0.3556, F1: 0.3618
Model not saved due to higher validation loss 184.0213 compared to 96.0632
Model did not improve after 34 epochs for 3 epochs. Training halted
Epoch 39/50


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 653/653 [01:06<00:00,  9.80it/s, loss=0.0766]


Average Training Loss: 0.3737


                                                                                                                                                                                                                      

Validation Accuracy: 0.3407, Precision: 0.3634, Recall: 0.3407, F1: 0.3500
Model not saved due to higher validation loss 187.0019 compared to 96.0632
Model did not improve after 35 epochs for 3 epochs. Training halted
Epoch 40/50


Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 653/653 [01:06<00:00,  9.88it/s, loss=0.139]


Average Training Loss: 0.3605


                                                                                                                                                                                                                      

Validation Accuracy: 0.3426, Precision: 0.3756, Recall: 0.3426, F1: 0.3553
Model not saved due to higher validation loss 191.0457 compared to 96.0632
Model did not improve after 36 epochs for 3 epochs. Training halted
Epoch 41/50


Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 653/653 [01:06<00:00,  9.83it/s, loss=0.227]


Average Training Loss: 0.3536


                                                                                                                                                                                                                      

Validation Accuracy: 0.3519, Precision: 0.3774, Recall: 0.3519, F1: 0.3625
Model not saved due to higher validation loss 192.4903 compared to 96.0632
Model did not improve after 37 epochs for 3 epochs. Training halted
Epoch 42/50


Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 653/653 [01:06<00:00,  9.81it/s, loss=0.322]


Average Training Loss: 0.3521


                                                                                                                                                                                                                      

Validation Accuracy: 0.3333, Precision: 0.3738, Recall: 0.3333, F1: 0.3488
Model not saved due to higher validation loss 195.1876 compared to 96.0632
Model did not improve after 38 epochs for 3 epochs. Training halted
Epoch 43/50


Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 653/653 [01:06<00:00,  9.85it/s, loss=0.334]


Average Training Loss: 0.3404


                                                                                                                                                                                                                      

Validation Accuracy: 0.3426, Precision: 0.3745, Recall: 0.3426, F1: 0.3553
Model not saved due to higher validation loss 196.1913 compared to 96.0632
Model did not improve after 39 epochs for 3 epochs. Training halted
Epoch 44/50


Training: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 653/653 [01:06<00:00,  9.83it/s, loss=0.14]


Average Training Loss: 0.3275


                                                                                                                                                                                                                      

Validation Accuracy: 0.3463, Precision: 0.3780, Recall: 0.3463, F1: 0.3584
Model not saved due to higher validation loss 197.1257 compared to 96.0632
Model did not improve after 40 epochs for 3 epochs. Training halted
Epoch 45/50


Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 653/653 [01:06<00:00,  9.80it/s, loss=0.765]


Average Training Loss: 0.3310


                                                                                                                                                                                                                      

Validation Accuracy: 0.3426, Precision: 0.3739, Recall: 0.3426, F1: 0.3549
Model not saved due to higher validation loss 197.1907 compared to 96.0632
Model did not improve after 41 epochs for 3 epochs. Training halted
Epoch 46/50


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 653/653 [01:06<00:00,  9.86it/s, loss=0.0957]


Average Training Loss: 0.3197


                                                                                                                                                                                                                      

Validation Accuracy: 0.3407, Precision: 0.3640, Recall: 0.3407, F1: 0.3507
Model not saved due to higher validation loss 199.4381 compared to 96.0632
Model did not improve after 42 epochs for 3 epochs. Training halted
Epoch 47/50


Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 653/653 [01:06<00:00,  9.85it/s, loss=0.368]


Average Training Loss: 0.3195


                                                                                                                                                                                                                      

Validation Accuracy: 0.3463, Precision: 0.3787, Recall: 0.3463, F1: 0.3593
Model not saved due to higher validation loss 200.3847 compared to 96.0632
Model did not improve after 43 epochs for 3 epochs. Training halted
Epoch 48/50


Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 653/653 [01:06<00:00,  9.81it/s, loss=0.191]


Average Training Loss: 0.3360


                                                                                                                                                                                                                      

Validation Accuracy: 0.3537, Precision: 0.3821, Recall: 0.3537, F1: 0.3652
Model not saved due to higher validation loss 200.4022 compared to 96.0632
Model did not improve after 44 epochs for 3 epochs. Training halted
Epoch 49/50


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 653/653 [01:06<00:00,  9.79it/s, loss=0.0723]


Average Training Loss: 0.3064


                                                                                                                                                                                                                      

Validation Accuracy: 0.3444, Precision: 0.3798, Recall: 0.3444, F1: 0.3580
Model not saved due to higher validation loss 202.0178 compared to 96.0632
Model did not improve after 45 epochs for 3 epochs. Training halted
Epoch 50/50


Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 653/653 [01:06<00:00,  9.79it/s, loss=0.322]


Average Training Loss: 0.3083


                                                                                                                                                                                                                      

Validation Accuracy: 0.3519, Precision: 0.3842, Recall: 0.3519, F1: 0.3645
Model not saved due to higher validation loss 200.6037 compared to 96.0632
Model did not improve after 46 epochs for 3 epochs. Training halted


                                                                                                                                                                                                                      

Test Accuracy: 0.3407, Precision: 0.3645, Recall: 0.3407, F1: 0.3492


