In [1]:
# Download dataset first: https://physionet.org/content/mednli/1.0.0/
# Requires credentialed access

In [2]:
cp -r /kaggle/input/mednli/mednli-a-natural-language-inference-dataset-for-the-clinical-domain-1.0.0 ./

In [3]:
cd mednli-a-natural-language-inference-dataset-for-the-clinical-domain-1.0.0

/kaggle/working/mednli-a-natural-language-inference-dataset-for-the-clinical-domain-1.0.0


In [4]:
import os
import json
import torch
import random
import numpy as np
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    get_linear_schedule_with_warmup
)
from sklearn.metrics import classification_report
from tqdm import tqdm
from torch.optim import AdamW

# === Determinism Configuration ===
SEED = 55212
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
np.random.seed(SEED)
random.seed(SEED)


In [5]:
# === Optimizer Configuration ===
from torch.optim.lr_scheduler import LambdaLR

# === Hardware Setup ===
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# === Model Initialization ===
def get_model():
    model = AutoModelForSequenceClassification.from_pretrained(
        'Sifal/ClinicalMosaic',
        num_labels=3,
        torch_dtype='auto',
        trust_remote_code=True,
    )
    return model.to(device)

# === Optimizer Configuration ===
def get_model_and_optimizer(
                            learning_rate,
                            classifier_lr,
                            weight_decay,
                            beta1, beta2,
                            eps,
                            num_epochs,
                            num_warmup_epochs,
                            freeze_backbone=True,
                            decay=False
    ):
    
    model = get_model()

    if freeze_backbone:
        for param in model.embedder.parameters():
            param.requires_grad = False

    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if "classifier" not in n and not any(nd in n for nd in no_decay)],
            "weight_decay": weight_decay,
            "lr": learning_rate,
        },
        {
            "params": [p for n, p in model.named_parameters() if "classifier" not in n and any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
            "lr": learning_rate,
        },
        {
            "params": [p for n, p in model.named_parameters() if "classifier" in n],
            "weight_decay": weight_decay,
            "lr": classifier_lr,
        },
    ]

    optimizer = AdamW(
        optimizer_grouped_parameters,
        betas=(beta1, beta2),
        eps=eps,
    )

    def warmup_stable_decay(epoch):
        if epoch < num_warmup_epochs:
            return epoch / num_warmup_epochs  # Linear warmup
        elif epoch < num_epochs * 0.75:  # Keep stable for 75% of training
            return 1.0
        else:
            return max(0.1, (num_epochs - epoch) / (num_epochs * 0.25))  # Decay over last 25%

    scheduler = LambdaLR(optimizer, lr_lambda=warmup_stable_decay)

    return model, optimizer, scheduler
    
# === Data Loading ===
class MedNLIDataset(Dataset):
    def __init__(self, filename, tokenizer):
        self.data = []
        with open(filename, 'r') as f:
            for line in f:
                item = json.loads(line)
                self.data.append(item)
        self.tokenizer = tokenizer
        self.label_map = {'entailment': 0, 'contradiction': 1, 'neutral': 2}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        encoded = self.tokenizer(
            item['sentence1'],
            item['sentence2'],
            truncation=True,
            padding='max_length',
            max_length=128,
            return_tensors="pt"
        )
        return {
            **{key: val.squeeze(0) for key, val in encoded.items()},  # Removing batch dimension
            'labels': torch.tensor(self.label_map[item['gold_label']], dtype=torch.long)
        }

def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


In [13]:
config = { 
    'batch_size': 64,
     'eval_batch_size': 64,         
}


In [14]:
# === Training Setup ===
tokenizer = AutoTokenizer.from_pretrained('Sifal/ClinicalMosaic')

# Create datasets
train_dataset = MedNLIDataset('mli_train_v1.jsonl', tokenizer)
dev_dataset = MedNLIDataset('mli_dev_v1.jsonl', tokenizer)
test_dataset = MedNLIDataset('mli_test_v1.jsonl', tokenizer)


# Deterministic data loaders
g = torch.Generator()
g.manual_seed(SEED)

train_loader = DataLoader(
    train_dataset,
    batch_size=config['batch_size'],
    shuffle=True,
    num_workers=0,
    worker_init_fn=seed_worker,
    generator=g
)

dev_loader = DataLoader(
    dev_dataset,
    batch_size=config['eval_batch_size'],
    num_workers=0
)

test_loader = DataLoader(
    test_dataset,
    batch_size=config['eval_batch_size'],
    num_workers=0
)


In [15]:
config = {
    'learning_rate': 2e-5,
    'classifier_lr': 2e-4,
    'weight_decay':  1e-6,          
    'beta1': 0.9,
    'beta2': 0.98,                 
    'eps': 1e-6,                   
    'batch_size': 64,
    'eval_batch_size': 64,         
    'max_grad_norm': 1.0,
    'num_epochs': 40,
    'num_warmup_epochs' : 5,
    'early_stop_patience': 10,
    'freeze_backbone' : False,
    'decay' : True,
}

In [16]:
# Initialize components
model, optimizer, scheduler = get_model_and_optimizer(
    config['learning_rate'],
    config['classifier_lr'],
    config['weight_decay'],
    config['beta1'],
    config['beta2'],
    config['eps'],
    config['num_epochs'],
    config['num_warmup_epochs'],
    freeze_backbone = False,
    decay = True   
)

You are using a model of type bert to instantiate a model of type clinical_mosaic. This is not supported for all configurations of models and can yield errors.


Checkpoint does not contain the classification layer (768x3 + 3 = 2307 params). It will be randomly initialized.


In [17]:
# === Enhanced Training Loop ===
best_accuracy = 0
patience_counter = 0
total_steps = len(train_loader) * config['num_epochs']

for epoch in range(config['num_epochs']):
    model.train()
    total_loss = 0
    
    for step, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}")):
        optimizer.zero_grad()
        inputs = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**inputs)
        loss = outputs.loss
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), config['max_grad_norm'])
        optimizer.step()
        
        total_loss += loss.item()
    
    scheduler.step()
    
    avg_loss = total_loss / len(train_loader)  # Normalize loss
    # Enhanced Validation
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for batch in dev_loader:
            inputs = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**inputs)
            preds = torch.argmax(outputs.logits, dim=-1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(batch['labels'].cpu().numpy())

    report = classification_report(
        all_labels, 
        all_preds, 
        target_names=['entailment', 'contradiction', 'neutral'],
        output_dict=True
    )
    
    current_accuracy = report.get('accuracy', 0.0)  # Safer accuracy extraction
    print(f"\nEpoch {epoch+1} | Avg Loss: {avg_loss:.4f}")
    print(f"Validation Accuracy: {current_accuracy:.4f}")
    print(classification_report(all_labels, all_preds))

    # Early stopping with model checkpointing
    if current_accuracy > best_accuracy:
        best_accuracy = current_accuracy
        patience_counter = 0
        torch.save(model.state_dict(), 'best_model.pt')
    else:
        patience_counter += 1
        if patience_counter >= config['early_stop_patience']:
            print("\nEarly stopping triggered")
            break

Epoch 1: 100%|██████████| 176/176 [01:32<00:00,  1.90it/s]



Epoch 1 | Avg Loss: 1.1061
Validation Accuracy: 0.3427
              precision    recall  f1-score   support

           0       0.33      0.55      0.42       465
           1       0.36      0.23      0.28       465
           2       0.35      0.25      0.29       465

    accuracy                           0.34      1395
   macro avg       0.35      0.34      0.33      1395
weighted avg       0.35      0.34      0.33      1395



Epoch 2: 100%|██████████| 176/176 [01:32<00:00,  1.90it/s]



Epoch 2 | Avg Loss: 1.1022
Validation Accuracy: 0.3462
              precision    recall  f1-score   support

           0       0.34      0.46      0.39       465
           1       0.38      0.17      0.23       465
           2       0.34      0.41      0.37       465

    accuracy                           0.35      1395
   macro avg       0.35      0.35      0.33      1395
weighted avg       0.35      0.35      0.33      1395



Epoch 3: 100%|██████████| 176/176 [01:32<00:00,  1.90it/s]



Epoch 3 | Avg Loss: 1.0916
Validation Accuracy: 0.4903
              precision    recall  f1-score   support

           0       0.44      0.09      0.15       465
           1       0.60      0.60      0.60       465
           2       0.44      0.78      0.56       465

    accuracy                           0.49      1395
   macro avg       0.49      0.49      0.44      1395
weighted avg       0.49      0.49      0.44      1395



Epoch 4: 100%|██████████| 176/176 [01:32<00:00,  1.90it/s]



Epoch 4 | Avg Loss: 0.8609
Validation Accuracy: 0.7025
              precision    recall  f1-score   support

           0       0.65      0.52      0.58       465
           1       0.87      0.84      0.85       465
           2       0.60      0.75      0.67       465

    accuracy                           0.70      1395
   macro avg       0.71      0.70      0.70      1395
weighted avg       0.71      0.70      0.70      1395



Epoch 5: 100%|██████████| 176/176 [01:32<00:00,  1.90it/s]



Epoch 5 | Avg Loss: 0.6261
Validation Accuracy: 0.8100
              precision    recall  f1-score   support

           0       0.75      0.78      0.77       465
           1       0.91      0.86      0.89       465
           2       0.78      0.78      0.78       465

    accuracy                           0.81      1395
   macro avg       0.81      0.81      0.81      1395
weighted avg       0.81      0.81      0.81      1395



Epoch 6: 100%|██████████| 176/176 [01:32<00:00,  1.90it/s]



Epoch 6 | Avg Loss: 0.4848
Validation Accuracy: 0.8358
              precision    recall  f1-score   support

           0       0.77      0.83      0.80       465
           1       0.92      0.86      0.89       465
           2       0.82      0.82      0.82       465

    accuracy                           0.84      1395
   macro avg       0.84      0.84      0.84      1395
weighted avg       0.84      0.84      0.84      1395



Epoch 7: 100%|██████████| 176/176 [01:32<00:00,  1.90it/s]



Epoch 7 | Avg Loss: 0.3680
Validation Accuracy: 0.8366
              precision    recall  f1-score   support

           0       0.81      0.79      0.80       465
           1       0.94      0.87      0.90       465
           2       0.78      0.85      0.81       465

    accuracy                           0.84      1395
   macro avg       0.84      0.84      0.84      1395
weighted avg       0.84      0.84      0.84      1395



Epoch 8: 100%|██████████| 176/176 [01:32<00:00,  1.90it/s]



Epoch 8 | Avg Loss: 0.2810
Validation Accuracy: 0.8437
              precision    recall  f1-score   support

           0       0.82      0.79      0.81       465
           1       0.91      0.90      0.91       465
           2       0.80      0.84      0.82       465

    accuracy                           0.84      1395
   macro avg       0.84      0.84      0.84      1395
weighted avg       0.84      0.84      0.84      1395



Epoch 9: 100%|██████████| 176/176 [01:32<00:00,  1.90it/s]



Epoch 9 | Avg Loss: 0.2024
Validation Accuracy: 0.8480
              precision    recall  f1-score   support

           0       0.82      0.82      0.82       465
           1       0.90      0.91      0.91       465
           2       0.82      0.82      0.82       465

    accuracy                           0.85      1395
   macro avg       0.85      0.85      0.85      1395
weighted avg       0.85      0.85      0.85      1395



Epoch 10: 100%|██████████| 176/176 [01:32<00:00,  1.90it/s]



Epoch 10 | Avg Loss: 0.1533
Validation Accuracy: 0.8430
              precision    recall  f1-score   support

           0       0.82      0.81      0.82       465
           1       0.90      0.91      0.90       465
           2       0.81      0.81      0.81       465

    accuracy                           0.84      1395
   macro avg       0.84      0.84      0.84      1395
weighted avg       0.84      0.84      0.84      1395



Epoch 11: 100%|██████████| 176/176 [01:32<00:00,  1.90it/s]



Epoch 11 | Avg Loss: 0.1067
Validation Accuracy: 0.8466
              precision    recall  f1-score   support

           0       0.83      0.81      0.82       465
           1       0.90      0.91      0.90       465
           2       0.81      0.82      0.81       465

    accuracy                           0.85      1395
   macro avg       0.85      0.85      0.85      1395
weighted avg       0.85      0.85      0.85      1395



Epoch 12: 100%|██████████| 176/176 [01:32<00:00,  1.90it/s]



Epoch 12 | Avg Loss: 0.0882
Validation Accuracy: 0.8452
              precision    recall  f1-score   support

           0       0.80      0.83      0.82       465
           1       0.91      0.91      0.91       465
           2       0.83      0.79      0.81       465

    accuracy                           0.85      1395
   macro avg       0.85      0.85      0.85      1395
weighted avg       0.85      0.85      0.85      1395



Epoch 13: 100%|██████████| 176/176 [01:32<00:00,  1.90it/s]



Epoch 13 | Avg Loss: 0.0687
Validation Accuracy: 0.8566
              precision    recall  f1-score   support

           0       0.81      0.86      0.83       465
           1       0.95      0.89      0.92       465
           2       0.82      0.83      0.82       465

    accuracy                           0.86      1395
   macro avg       0.86      0.86      0.86      1395
weighted avg       0.86      0.86      0.86      1395



Epoch 14: 100%|██████████| 176/176 [01:32<00:00,  1.90it/s]



Epoch 14 | Avg Loss: 0.0617
Validation Accuracy: 0.8559
              precision    recall  f1-score   support

           0       0.81      0.86      0.83       465
           1       0.93      0.91      0.92       465
           2       0.83      0.80      0.82       465

    accuracy                           0.86      1395
   macro avg       0.86      0.86      0.86      1395
weighted avg       0.86      0.86      0.86      1395



Epoch 15: 100%|██████████| 176/176 [01:32<00:00,  1.90it/s]



Epoch 15 | Avg Loss: 0.0558
Validation Accuracy: 0.8502
              precision    recall  f1-score   support

           0       0.79      0.85      0.82       465
           1       0.93      0.91      0.92       465
           2       0.83      0.80      0.82       465

    accuracy                           0.85      1395
   macro avg       0.85      0.85      0.85      1395
weighted avg       0.85      0.85      0.85      1395



Epoch 16: 100%|██████████| 176/176 [01:32<00:00,  1.90it/s]



Epoch 16 | Avg Loss: 0.0496
Validation Accuracy: 0.8573
              precision    recall  f1-score   support

           0       0.82      0.83      0.83       465
           1       0.93      0.91      0.92       465
           2       0.82      0.83      0.83       465

    accuracy                           0.86      1395
   macro avg       0.86      0.86      0.86      1395
weighted avg       0.86      0.86      0.86      1395



Epoch 17: 100%|██████████| 176/176 [01:32<00:00,  1.90it/s]



Epoch 17 | Avg Loss: 0.0393
Validation Accuracy: 0.8416
              precision    recall  f1-score   support

           0       0.80      0.84      0.82       465
           1       0.92      0.89      0.91       465
           2       0.81      0.79      0.80       465

    accuracy                           0.84      1395
   macro avg       0.84      0.84      0.84      1395
weighted avg       0.84      0.84      0.84      1395



Epoch 18: 100%|██████████| 176/176 [01:32<00:00,  1.90it/s]



Epoch 18 | Avg Loss: 0.0395
Validation Accuracy: 0.8509
              precision    recall  f1-score   support

           0       0.84      0.81      0.82       465
           1       0.91      0.91      0.91       465
           2       0.80      0.83      0.82       465

    accuracy                           0.85      1395
   macro avg       0.85      0.85      0.85      1395
weighted avg       0.85      0.85      0.85      1395



Epoch 19: 100%|██████████| 176/176 [01:32<00:00,  1.90it/s]



Epoch 19 | Avg Loss: 0.0383
Validation Accuracy: 0.8581
              precision    recall  f1-score   support

           0       0.84      0.82      0.83       465
           1       0.93      0.91      0.92       465
           2       0.81      0.84      0.83       465

    accuracy                           0.86      1395
   macro avg       0.86      0.86      0.86      1395
weighted avg       0.86      0.86      0.86      1395



Epoch 20: 100%|██████████| 176/176 [01:32<00:00,  1.89it/s]



Epoch 20 | Avg Loss: 0.0425
Validation Accuracy: 0.8573
              precision    recall  f1-score   support

           0       0.83      0.83      0.83       465
           1       0.92      0.91      0.91       465
           2       0.83      0.83      0.83       465

    accuracy                           0.86      1395
   macro avg       0.86      0.86      0.86      1395
weighted avg       0.86      0.86      0.86      1395



Epoch 21: 100%|██████████| 176/176 [01:32<00:00,  1.90it/s]



Epoch 21 | Avg Loss: 0.0359
Validation Accuracy: 0.8502
              precision    recall  f1-score   support

           0       0.79      0.87      0.83       465
           1       0.91      0.92      0.91       465
           2       0.86      0.77      0.81       465

    accuracy                           0.85      1395
   macro avg       0.85      0.85      0.85      1395
weighted avg       0.85      0.85      0.85      1395



Epoch 22: 100%|██████████| 176/176 [01:32<00:00,  1.90it/s]



Epoch 22 | Avg Loss: 0.0270
Validation Accuracy: 0.8545
              precision    recall  f1-score   support

           0       0.80      0.87      0.84       465
           1       0.92      0.91      0.91       465
           2       0.85      0.78      0.81       465

    accuracy                           0.85      1395
   macro avg       0.86      0.85      0.85      1395
weighted avg       0.86      0.85      0.85      1395



Epoch 23: 100%|██████████| 176/176 [01:32<00:00,  1.90it/s]



Epoch 23 | Avg Loss: 0.0298
Validation Accuracy: 0.8573
              precision    recall  f1-score   support

           0       0.82      0.84      0.83       465
           1       0.92      0.91      0.92       465
           2       0.83      0.82      0.82       465

    accuracy                           0.86      1395
   macro avg       0.86      0.86      0.86      1395
weighted avg       0.86      0.86      0.86      1395



Epoch 24: 100%|██████████| 176/176 [01:32<00:00,  1.91it/s]



Epoch 24 | Avg Loss: 0.0278
Validation Accuracy: 0.8566
              precision    recall  f1-score   support

           0       0.85      0.81      0.83       465
           1       0.92      0.90      0.91       465
           2       0.80      0.86      0.83       465

    accuracy                           0.86      1395
   macro avg       0.86      0.86      0.86      1395
weighted avg       0.86      0.86      0.86      1395



Epoch 25: 100%|██████████| 176/176 [01:32<00:00,  1.90it/s]



Epoch 25 | Avg Loss: 0.0271
Validation Accuracy: 0.8502
              precision    recall  f1-score   support

           0       0.82      0.83      0.83       465
           1       0.91      0.91      0.91       465
           2       0.82      0.81      0.82       465

    accuracy                           0.85      1395
   macro avg       0.85      0.85      0.85      1395
weighted avg       0.85      0.85      0.85      1395



Epoch 26: 100%|██████████| 176/176 [01:32<00:00,  1.90it/s]



Epoch 26 | Avg Loss: 0.0225
Validation Accuracy: 0.8559
              precision    recall  f1-score   support

           0       0.83      0.83      0.83       465
           1       0.92      0.91      0.92       465
           2       0.82      0.83      0.82       465

    accuracy                           0.86      1395
   macro avg       0.86      0.86      0.86      1395
weighted avg       0.86      0.86      0.86      1395



Epoch 27: 100%|██████████| 176/176 [01:32<00:00,  1.90it/s]



Epoch 27 | Avg Loss: 0.0221
Validation Accuracy: 0.8645
              precision    recall  f1-score   support

           0       0.84      0.85      0.84       465
           1       0.92      0.91      0.91       465
           2       0.84      0.84      0.84       465

    accuracy                           0.86      1395
   macro avg       0.86      0.86      0.86      1395
weighted avg       0.86      0.86      0.86      1395



Epoch 28: 100%|██████████| 176/176 [01:32<00:00,  1.90it/s]



Epoch 28 | Avg Loss: 0.0243
Validation Accuracy: 0.8559
              precision    recall  f1-score   support

           0       0.84      0.82      0.83       465
           1       0.90      0.93      0.91       465
           2       0.83      0.82      0.82       465

    accuracy                           0.86      1395
   macro avg       0.86      0.86      0.86      1395
weighted avg       0.86      0.86      0.86      1395



Epoch 29: 100%|██████████| 176/176 [01:32<00:00,  1.90it/s]



Epoch 29 | Avg Loss: 0.0223
Validation Accuracy: 0.8616
              precision    recall  f1-score   support

           0       0.85      0.82      0.83       465
           1       0.90      0.93      0.92       465
           2       0.83      0.83      0.83       465

    accuracy                           0.86      1395
   macro avg       0.86      0.86      0.86      1395
weighted avg       0.86      0.86      0.86      1395



Epoch 30: 100%|██████████| 176/176 [01:32<00:00,  1.90it/s]



Epoch 30 | Avg Loss: 0.0235
Validation Accuracy: 0.8538
              precision    recall  f1-score   support

           0       0.82      0.85      0.83       465
           1       0.92      0.89      0.90       465
           2       0.83      0.83      0.83       465

    accuracy                           0.85      1395
   macro avg       0.86      0.85      0.85      1395
weighted avg       0.86      0.85      0.85      1395



Epoch 31: 100%|██████████| 176/176 [01:32<00:00,  1.90it/s]



Epoch 31 | Avg Loss: 0.0194
Validation Accuracy: 0.8495
              precision    recall  f1-score   support

           0       0.79      0.85      0.82       465
           1       0.92      0.91      0.92       465
           2       0.84      0.79      0.82       465

    accuracy                           0.85      1395
   macro avg       0.85      0.85      0.85      1395
weighted avg       0.85      0.85      0.85      1395



Epoch 32: 100%|██████████| 176/176 [01:32<00:00,  1.90it/s]



Epoch 32 | Avg Loss: 0.0219
Validation Accuracy: 0.8509
              precision    recall  f1-score   support

           0       0.83      0.82      0.82       465
           1       0.91      0.90      0.90       465
           2       0.82      0.83      0.82       465

    accuracy                           0.85      1395
   macro avg       0.85      0.85      0.85      1395
weighted avg       0.85      0.85      0.85      1395



Epoch 33: 100%|██████████| 176/176 [01:32<00:00,  1.90it/s]



Epoch 33 | Avg Loss: 0.0144
Validation Accuracy: 0.8516
              precision    recall  f1-score   support

           0       0.79      0.87      0.83       465
           1       0.92      0.91      0.91       465
           2       0.85      0.78      0.81       465

    accuracy                           0.85      1395
   macro avg       0.85      0.85      0.85      1395
weighted avg       0.85      0.85      0.85      1395



Epoch 34: 100%|██████████| 176/176 [01:32<00:00,  1.90it/s]



Epoch 34 | Avg Loss: 0.0113
Validation Accuracy: 0.8602
              precision    recall  f1-score   support

           0       0.81      0.86      0.83       465
           1       0.91      0.92      0.92       465
           2       0.86      0.80      0.83       465

    accuracy                           0.86      1395
   macro avg       0.86      0.86      0.86      1395
weighted avg       0.86      0.86      0.86      1395



Epoch 35: 100%|██████████| 176/176 [01:32<00:00,  1.90it/s]



Epoch 35 | Avg Loss: 0.0064
Validation Accuracy: 0.8616
              precision    recall  f1-score   support

           0       0.85      0.82      0.84       465
           1       0.92      0.92      0.92       465
           2       0.82      0.85      0.83       465

    accuracy                           0.86      1395
   macro avg       0.86      0.86      0.86      1395
weighted avg       0.86      0.86      0.86      1395



Epoch 36: 100%|██████████| 176/176 [01:32<00:00,  1.90it/s]



Epoch 36 | Avg Loss: 0.0074
Validation Accuracy: 0.8530
              precision    recall  f1-score   support

           0       0.82      0.83      0.82       465
           1       0.92      0.91      0.91       465
           2       0.82      0.83      0.82       465

    accuracy                           0.85      1395
   macro avg       0.85      0.85      0.85      1395
weighted avg       0.85      0.85      0.85      1395



Epoch 37: 100%|██████████| 176/176 [01:32<00:00,  1.90it/s]



Epoch 37 | Avg Loss: 0.0062
Validation Accuracy: 0.8523
              precision    recall  f1-score   support

           0       0.80      0.87      0.83       465
           1       0.91      0.91      0.91       465
           2       0.85      0.78      0.81       465

    accuracy                           0.85      1395
   macro avg       0.85      0.85      0.85      1395
weighted avg       0.85      0.85      0.85      1395


Early stopping triggered
