In [1]:
# Download dataset first: https://physionet.org/content/mednli/1.0.0/
# Requires credentialed access

In [1]:
!unzip mednli-a-natural-language-inference-dataset-for-the-clinical-domain-1.0.0.zip

Archive:  mednli-a-natural-language-inference-dataset-for-the-clinical-domain-1.0.0.zip
  inflating: mednli-a-natural-language-inference-dataset-for-the-clinical-domain-1.0.0/LICENSE.txt  
  inflating: mednli-a-natural-language-inference-dataset-for-the-clinical-domain-1.0.0/README.txt  
  inflating: mednli-a-natural-language-inference-dataset-for-the-clinical-domain-1.0.0/SHA256SUMS.txt  
  inflating: mednli-a-natural-language-inference-dataset-for-the-clinical-domain-1.0.0/mli_dev_v1.jsonl  
  inflating: mednli-a-natural-language-inference-dataset-for-the-clinical-domain-1.0.0/mli_test_v1.jsonl  
  inflating: mednli-a-natural-language-inference-dataset-for-the-clinical-domain-1.0.0/mli_train_v1.jsonl  


In [None]:
cp -r /kaggle/input/mednli/mednli-a-natural-language-inference-dataset-for-the-clinical-domain-1.0.0 ./

In [1]:
cd mednli-a-natural-language-inference-dataset-for-the-clinical-domain-1.0.0

/content/mednli-a-natural-language-inference-dataset-for-the-clinical-domain-1.0.0


In [2]:
import os
import json
import torch
import random
import numpy as np
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    get_linear_schedule_with_warmup
)
from sklearn.metrics import classification_report
from tqdm import tqdm
from torch.optim import AdamW

# === Determinism Configuration ===
SEED = 55212
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
np.random.seed(SEED)
random.seed(SEED)


In [13]:
# === Optimizer Configuration ===
from torch.optim.lr_scheduler import LambdaLR

# === Hardware Setup ===
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# === Model Initialization ===
def get_model():
    model = AutoModelForSequenceClassification.from_pretrained(
        'Sifal/ClinicalMosaic',
        num_labels=3,
        torch_dtype='auto',
        trust_remote_code=True,
    )
    return model.to(device)

# === Optimizer Configuration ===
def get_model_and_optimizer(
                            learning_rate,
                            classifier_lr,
                            weight_decay,
                            beta1, beta2,
                            eps,
                            num_epochs,
                            num_warmup_epochs,
                            freeze_backbone=True,
                            decay=False
    ):

    model = get_model()

    if freeze_backbone:
        for param in model.embedder.parameters():
            param.requires_grad = False

    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters() if "classifier" not in n and "pooler" not in n and not any(nd in n for nd in no_decay)],
        "weight_decay": weight_decay,
        "lr": learning_rate,
    },
    {
        "params": [p for n, p in model.named_parameters() if "classifier" not in n and "pooler" not in n and any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
        "lr": learning_rate,
    },
    {
        "params": [p for n, p in model.named_parameters() if "classifier" in n or "pooler" in n],
        "weight_decay": weight_decay * 10,
        "lr": classifier_lr,
    },
]

    optimizer = AdamW(
        optimizer_grouped_parameters,
        betas=(beta1, beta2),
        eps=eps,
    )

    def warmup_stable_decay(epoch):
        if epoch < num_warmup_epochs:
            return epoch / num_warmup_epochs  # Linear warmup
        elif epoch < num_epochs * 0.75:  # Keep stable for 75% of training
            return 1.0
        else:
            return max(0.1, (num_epochs - epoch) / (num_epochs * 0.25))  # Decay over last 25%

    scheduler = LambdaLR(optimizer, lr_lambda=warmup_stable_decay)

    return model, optimizer, scheduler

# === Data Loading ===
class MedNLIDataset(Dataset):
    def __init__(self, filename, tokenizer):
        self.data = []
        with open(filename, 'r') as f:
            for line in f:
                item = json.loads(line)
                self.data.append(item)
        self.tokenizer = tokenizer
        self.label_map = {'entailment': 0, 'contradiction': 1, 'neutral': 2}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        encoded = self.tokenizer(
            item['sentence1'],
            item['sentence2'],
            truncation=True,
            padding='max_length',
            max_length=128,
            return_tensors="pt"
        )
        return {
            **{key: val.squeeze(0) for key, val in encoded.items()},  # Removing batch dimension
            'labels': torch.tensor(self.label_map[item['gold_label']], dtype=torch.long)
        }

def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


In [4]:
config = {
    'batch_size': 128,
     'eval_batch_size': 64,
}

In [5]:
# === Training Setup ===
tokenizer = AutoTokenizer.from_pretrained('Sifal/ClinicalMosaic')

# Create datasets
train_dataset = MedNLIDataset('mli_train_v1.jsonl', tokenizer)
dev_dataset = MedNLIDataset('mli_dev_v1.jsonl', tokenizer)
test_dataset = MedNLIDataset('mli_test_v1.jsonl', tokenizer)


# Deterministic data loaders
g = torch.Generator()
g.manual_seed(SEED)

train_loader = DataLoader(
    train_dataset,
    batch_size=config['batch_size'],
    shuffle=True,
    num_workers=0,
    worker_init_fn=seed_worker,
    generator=g
)

dev_loader = DataLoader(
    dev_dataset,
    batch_size=config['eval_batch_size'],
    num_workers=0
)

test_loader = DataLoader(
    test_dataset,
    batch_size=config['eval_batch_size'],
    num_workers=0
)


In [11]:
config = {
    'learning_rate': 2e-5,
    'classifier_lr': 2e-4,
    'weight_decay':  2e-6,
    'beta1': 0.9,
    'beta2': 0.98,
    'eps': 1e-6,
    'batch_size': 128,
    'eval_batch_size': 64,
    'max_grad_norm': 1.5,
    'num_epochs': 60,
    'num_warmup_epochs' : 4,
    'early_stop_patience': 15,
    'freeze_backbone' : False,
    'decay' : True,
}

In [14]:
# Initialize components
model, optimizer, scheduler = get_model_and_optimizer(
    config['learning_rate'],
    config['classifier_lr'],
    config['weight_decay'],
    config['beta1'],
    config['beta2'],
    config['eps'],
    config['num_epochs'],
    config['num_warmup_epochs'],
    freeze_backbone = False,
    decay = True
)

You are using a model of type bert to instantiate a model of type clinical_mosaic. This is not supported for all configurations of models and can yield errors.


Checkpoint does not contain the classification layer (768x3 + 3 = 2307 params). It will be randomly initialized.


In [9]:
model

ClinicalMosaicForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30528, 768, padding_idx=0)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertUnpadAttention(
            (self): BertUnpadSelfAttention(
              (dropout): Dropout(p=0.1, inplace=False)
              (Wqkv): Linear(in_features=768, out_features=2304, bias=True)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (mlp): BertGatedLinearUnitMLP(
            (gated_layers): Lin

In [15]:
# === Enhanced Training Loop ===
best_accuracy = 0
patience_counter = 0
total_steps = len(train_loader) * config['num_epochs']

for epoch in range(config['num_epochs']):
    model.train()
    total_loss = 0

    for step, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}")):
        optimizer.zero_grad()
        inputs = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**inputs)
        loss = outputs.loss
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), config['max_grad_norm'])
        optimizer.step()

        total_loss += loss.item()



    scheduler.step()

    avg_loss = total_loss / len(train_loader)  # Normalize loss
    # Enhanced Validation
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for batch in dev_loader:
            inputs = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**inputs)
            preds = torch.argmax(outputs.logits, dim=-1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(batch['labels'].cpu().numpy())

    report = classification_report(
        all_labels,
        all_preds,
        target_names=['entailment', 'contradiction', 'neutral'],
        output_dict=True
    )

    current_accuracy = report.get('accuracy', 0.0)  # Safer accuracy extraction
    print(f"\nEpoch {epoch+1} | Avg Loss: {avg_loss:.4f}")
    print(f"Validation Accuracy: {current_accuracy:.4f}")
    print(classification_report(all_labels, all_preds))

    # Early stopping with model checkpointing
    if current_accuracy > best_accuracy:
        best_accuracy = current_accuracy
        patience_counter = 0
        torch.save(model.state_dict(), 'best_model.pt')
    else:
        patience_counter += 1
        if patience_counter >= config['early_stop_patience']:
            print("\nEarly stopping triggered")
            break

Epoch 1: 100%|██████████| 88/88 [02:25<00:00,  1.65s/it]



Epoch 1 | Avg Loss: 1.1032
Validation Accuracy: 0.3455
              precision    recall  f1-score   support

           0       0.35      0.44      0.39       465
           1       0.34      0.21      0.26       465
           2       0.35      0.39      0.37       465

    accuracy                           0.35      1395
   macro avg       0.34      0.35      0.34      1395
weighted avg       0.34      0.35      0.34      1395



Epoch 2: 100%|██████████| 88/88 [02:27<00:00,  1.68s/it]



Epoch 2 | Avg Loss: 1.0992
Validation Accuracy: 0.4165
              precision    recall  f1-score   support

           0       0.37      0.29      0.33       465
           1       0.42      0.64      0.51       465
           2       0.45      0.32      0.37       465

    accuracy                           0.42      1395
   macro avg       0.42      0.42      0.40      1395
weighted avg       0.42      0.42      0.40      1395



Epoch 3: 100%|██████████| 88/88 [02:27<00:00,  1.68s/it]



Epoch 3 | Avg Loss: 1.0440
Validation Accuracy: 0.5326
              precision    recall  f1-score   support

           0       0.41      0.26      0.32       465
           1       0.66      0.69      0.68       465
           2       0.49      0.65      0.56       465

    accuracy                           0.53      1395
   macro avg       0.52      0.53      0.52      1395
weighted avg       0.52      0.53      0.52      1395



Epoch 4: 100%|██████████| 88/88 [02:28<00:00,  1.69s/it]



Epoch 4 | Avg Loss: 0.9401
Validation Accuracy: 0.6029
              precision    recall  f1-score   support

           0       0.54      0.28      0.37       465
           1       0.73      0.78      0.76       465
           2       0.53      0.75      0.62       465

    accuracy                           0.60      1395
   macro avg       0.60      0.60      0.58      1395
weighted avg       0.60      0.60      0.58      1395



Epoch 5: 100%|██████████| 88/88 [02:27<00:00,  1.68s/it]



Epoch 5 | Avg Loss: 1.0654
Validation Accuracy: 0.4409
              precision    recall  f1-score   support

           0       0.38      0.43      0.40       465
           1       0.46      0.60      0.52       465
           2       0.51      0.29      0.37       465

    accuracy                           0.44      1395
   macro avg       0.45      0.44      0.43      1395
weighted avg       0.45      0.44      0.43      1395



Epoch 6: 100%|██████████| 88/88 [02:27<00:00,  1.68s/it]



Epoch 6 | Avg Loss: 1.1049
Validation Accuracy: 0.4953
              precision    recall  f1-score   support

           0       0.44      0.21      0.28       465
           1       0.77      0.52      0.62       465
           2       0.41      0.76      0.53       465

    accuracy                           0.50      1395
   macro avg       0.54      0.50      0.48      1395
weighted avg       0.54      0.50      0.48      1395



Epoch 7: 100%|██████████| 88/88 [02:27<00:00,  1.68s/it]



Epoch 7 | Avg Loss: 0.8722
Validation Accuracy: 0.6767
              precision    recall  f1-score   support

           0       0.65      0.43      0.52       465
           1       0.85      0.82      0.84       465
           2       0.57      0.78      0.66       465

    accuracy                           0.68      1395
   macro avg       0.69      0.68      0.67      1395
weighted avg       0.69      0.68      0.67      1395



Epoch 8: 100%|██████████| 88/88 [02:27<00:00,  1.68s/it]



Epoch 8 | Avg Loss: 0.6215
Validation Accuracy: 0.8072
              precision    recall  f1-score   support

           0       0.74      0.80      0.77       465
           1       0.91      0.86      0.88       465
           2       0.78      0.77      0.77       465

    accuracy                           0.81      1395
   macro avg       0.81      0.81      0.81      1395
weighted avg       0.81      0.81      0.81      1395



Epoch 9: 100%|██████████| 88/88 [02:27<00:00,  1.68s/it]



Epoch 9 | Avg Loss: 0.4666
Validation Accuracy: 0.8287
              precision    recall  f1-score   support

           0       0.75      0.85      0.80       465
           1       0.90      0.90      0.90       465
           2       0.84      0.74      0.79       465

    accuracy                           0.83      1395
   macro avg       0.83      0.83      0.83      1395
weighted avg       0.83      0.83      0.83      1395



Epoch 10: 100%|██████████| 88/88 [02:27<00:00,  1.68s/it]



Epoch 10 | Avg Loss: 0.3656
Validation Accuracy: 0.8344
              precision    recall  f1-score   support

           0       0.74      0.87      0.80       465
           1       0.93      0.87      0.90       465
           2       0.85      0.76      0.80       465

    accuracy                           0.83      1395
   macro avg       0.84      0.83      0.84      1395
weighted avg       0.84      0.83      0.84      1395



Epoch 11: 100%|██████████| 88/88 [02:27<00:00,  1.68s/it]



Epoch 11 | Avg Loss: 0.2847
Validation Accuracy: 0.8459
              precision    recall  f1-score   support

           0       0.86      0.75      0.80       465
           1       0.92      0.89      0.90       465
           2       0.78      0.90      0.83       465

    accuracy                           0.85      1395
   macro avg       0.85      0.85      0.85      1395
weighted avg       0.85      0.85      0.85      1395



Epoch 12: 100%|██████████| 88/88 [02:27<00:00,  1.68s/it]



Epoch 12 | Avg Loss: 0.2228
Validation Accuracy: 0.8401
              precision    recall  f1-score   support

           0       0.84      0.77      0.80       465
           1       0.93      0.88      0.90       465
           2       0.77      0.87      0.81       465

    accuracy                           0.84      1395
   macro avg       0.84      0.84      0.84      1395
weighted avg       0.84      0.84      0.84      1395



Epoch 13: 100%|██████████| 88/88 [02:28<00:00,  1.68s/it]



Epoch 13 | Avg Loss: 0.1595
Validation Accuracy: 0.8509
              precision    recall  f1-score   support

           0       0.79      0.87      0.83       465
           1       0.92      0.91      0.91       465
           2       0.85      0.78      0.81       465

    accuracy                           0.85      1395
   macro avg       0.85      0.85      0.85      1395
weighted avg       0.85      0.85      0.85      1395



Epoch 14: 100%|██████████| 88/88 [02:27<00:00,  1.68s/it]



Epoch 14 | Avg Loss: 0.1266
Validation Accuracy: 0.8573
              precision    recall  f1-score   support

           0       0.81      0.86      0.83       465
           1       0.93      0.90      0.92       465
           2       0.83      0.81      0.82       465

    accuracy                           0.86      1395
   macro avg       0.86      0.86      0.86      1395
weighted avg       0.86      0.86      0.86      1395



Epoch 15: 100%|██████████| 88/88 [02:27<00:00,  1.68s/it]



Epoch 15 | Avg Loss: 0.1016
Validation Accuracy: 0.8530
              precision    recall  f1-score   support

           0       0.84      0.81      0.82       465
           1       0.91      0.92      0.92       465
           2       0.81      0.84      0.82       465

    accuracy                           0.85      1395
   macro avg       0.85      0.85      0.85      1395
weighted avg       0.85      0.85      0.85      1395



Epoch 16: 100%|██████████| 88/88 [02:28<00:00,  1.68s/it]



Epoch 16 | Avg Loss: 0.0868
Validation Accuracy: 0.8437
              precision    recall  f1-score   support

           0       0.83      0.77      0.80       465
           1       0.92      0.91      0.91       465
           2       0.79      0.85      0.82       465

    accuracy                           0.84      1395
   macro avg       0.85      0.84      0.84      1395
weighted avg       0.85      0.84      0.84      1395



Epoch 17: 100%|██████████| 88/88 [02:28<00:00,  1.68s/it]



Epoch 17 | Avg Loss: 0.0682
Validation Accuracy: 0.8616
              precision    recall  f1-score   support

           0       0.83      0.83      0.83       465
           1       0.92      0.92      0.92       465
           2       0.83      0.84      0.84       465

    accuracy                           0.86      1395
   macro avg       0.86      0.86      0.86      1395
weighted avg       0.86      0.86      0.86      1395



Epoch 18: 100%|██████████| 88/88 [02:27<00:00,  1.68s/it]



Epoch 18 | Avg Loss: 0.0586
Validation Accuracy: 0.8495
              precision    recall  f1-score   support

           0       0.83      0.79      0.81       465
           1       0.90      0.93      0.91       465
           2       0.82      0.83      0.82       465

    accuracy                           0.85      1395
   macro avg       0.85      0.85      0.85      1395
weighted avg       0.85      0.85      0.85      1395



Epoch 19: 100%|██████████| 88/88 [02:27<00:00,  1.68s/it]



Epoch 19 | Avg Loss: 0.0562
Validation Accuracy: 0.8516
              precision    recall  f1-score   support

           0       0.84      0.79      0.82       465
           1       0.90      0.92      0.91       465
           2       0.81      0.84      0.83       465

    accuracy                           0.85      1395
   macro avg       0.85      0.85      0.85      1395
weighted avg       0.85      0.85      0.85      1395



Epoch 20: 100%|██████████| 88/88 [02:28<00:00,  1.68s/it]



Epoch 20 | Avg Loss: 0.0495
Validation Accuracy: 0.8566
              precision    recall  f1-score   support

           0       0.82      0.82      0.82       465
           1       0.92      0.91      0.92       465
           2       0.83      0.84      0.83       465

    accuracy                           0.86      1395
   macro avg       0.86      0.86      0.86      1395
weighted avg       0.86      0.86      0.86      1395



Epoch 21: 100%|██████████| 88/88 [02:28<00:00,  1.68s/it]



Epoch 21 | Avg Loss: 0.0377
Validation Accuracy: 0.8659
              precision    recall  f1-score   support

           0       0.86      0.81      0.83       465
           1       0.90      0.94      0.92       465
           2       0.83      0.85      0.84       465

    accuracy                           0.87      1395
   macro avg       0.87      0.87      0.87      1395
weighted avg       0.87      0.87      0.87      1395



Epoch 22: 100%|██████████| 88/88 [02:27<00:00,  1.68s/it]



Epoch 22 | Avg Loss: 0.0364
Validation Accuracy: 0.8624
              precision    recall  f1-score   support

           0       0.83      0.84      0.84       465
           1       0.91      0.93      0.92       465
           2       0.85      0.82      0.83       465

    accuracy                           0.86      1395
   macro avg       0.86      0.86      0.86      1395
weighted avg       0.86      0.86      0.86      1395



Epoch 23: 100%|██████████| 88/88 [02:27<00:00,  1.68s/it]



Epoch 23 | Avg Loss: 0.0310
Validation Accuracy: 0.8609
              precision    recall  f1-score   support

           0       0.83      0.84      0.83       465
           1       0.93      0.91      0.92       465
           2       0.83      0.83      0.83       465

    accuracy                           0.86      1395
   macro avg       0.86      0.86      0.86      1395
weighted avg       0.86      0.86      0.86      1395



Epoch 24: 100%|██████████| 88/88 [02:27<00:00,  1.68s/it]



Epoch 24 | Avg Loss: 0.0366
Validation Accuracy: 0.8459
              precision    recall  f1-score   support

           0       0.79      0.85      0.82       465
           1       0.90      0.92      0.91       465
           2       0.84      0.77      0.80       465

    accuracy                           0.85      1395
   macro avg       0.85      0.85      0.85      1395
weighted avg       0.85      0.85      0.85      1395



Epoch 25: 100%|██████████| 88/88 [02:28<00:00,  1.68s/it]



Epoch 25 | Avg Loss: 0.0316
Validation Accuracy: 0.8502
              precision    recall  f1-score   support

           0       0.82      0.82      0.82       465
           1       0.91      0.92      0.92       465
           2       0.82      0.81      0.81       465

    accuracy                           0.85      1395
   macro avg       0.85      0.85      0.85      1395
weighted avg       0.85      0.85      0.85      1395



Epoch 26: 100%|██████████| 88/88 [02:28<00:00,  1.68s/it]



Epoch 26 | Avg Loss: 0.0262
Validation Accuracy: 0.8552
              precision    recall  f1-score   support

           0       0.80      0.85      0.83       465
           1       0.92      0.92      0.92       465
           2       0.85      0.79      0.82       465

    accuracy                           0.86      1395
   macro avg       0.86      0.86      0.86      1395
weighted avg       0.86      0.86      0.86      1395



Epoch 27: 100%|██████████| 88/88 [02:27<00:00,  1.68s/it]



Epoch 27 | Avg Loss: 0.0279
Validation Accuracy: 0.8573
              precision    recall  f1-score   support

           0       0.82      0.85      0.83       465
           1       0.90      0.93      0.91       465
           2       0.86      0.79      0.82       465

    accuracy                           0.86      1395
   macro avg       0.86      0.86      0.86      1395
weighted avg       0.86      0.86      0.86      1395



Epoch 28: 100%|██████████| 88/88 [02:28<00:00,  1.69s/it]



Epoch 28 | Avg Loss: 0.0286
Validation Accuracy: 0.8552
              precision    recall  f1-score   support

           0       0.81      0.85      0.83       465
           1       0.91      0.93      0.92       465
           2       0.85      0.79      0.82       465

    accuracy                           0.86      1395
   macro avg       0.86      0.86      0.85      1395
weighted avg       0.86      0.86      0.85      1395



Epoch 29: 100%|██████████| 88/88 [02:28<00:00,  1.68s/it]



Epoch 29 | Avg Loss: 0.0228
Validation Accuracy: 0.8509
              precision    recall  f1-score   support

           0       0.85      0.78      0.81       465
           1       0.91      0.92      0.91       465
           2       0.80      0.86      0.83       465

    accuracy                           0.85      1395
   macro avg       0.85      0.85      0.85      1395
weighted avg       0.85      0.85      0.85      1395



Epoch 30: 100%|██████████| 88/88 [02:27<00:00,  1.68s/it]



Epoch 30 | Avg Loss: 0.0247
Validation Accuracy: 0.8595
              precision    recall  f1-score   support

           0       0.83      0.82      0.82       465
           1       0.91      0.94      0.92       465
           2       0.83      0.82      0.83       465

    accuracy                           0.86      1395
   macro avg       0.86      0.86      0.86      1395
weighted avg       0.86      0.86      0.86      1395



Epoch 31: 100%|██████████| 88/88 [02:27<00:00,  1.68s/it]



Epoch 31 | Avg Loss: 0.0185
Validation Accuracy: 0.8480
              precision    recall  f1-score   support

           0       0.81      0.81      0.81       465
           1       0.91      0.93      0.92       465
           2       0.82      0.81      0.81       465

    accuracy                           0.85      1395
   macro avg       0.85      0.85      0.85      1395
weighted avg       0.85      0.85      0.85      1395



Epoch 32: 100%|██████████| 88/88 [02:27<00:00,  1.68s/it]



Epoch 32 | Avg Loss: 0.0208
Validation Accuracy: 0.8538
              precision    recall  f1-score   support

           0       0.80      0.85      0.82       465
           1       0.92      0.92      0.92       465
           2       0.85      0.79      0.82       465

    accuracy                           0.85      1395
   macro avg       0.85      0.85      0.85      1395
weighted avg       0.85      0.85      0.85      1395



Epoch 33: 100%|██████████| 88/88 [02:27<00:00,  1.68s/it]



Epoch 33 | Avg Loss: 0.0204
Validation Accuracy: 0.8545
              precision    recall  f1-score   support

           0       0.80      0.85      0.83       465
           1       0.91      0.92      0.92       465
           2       0.85      0.78      0.82       465

    accuracy                           0.85      1395
   macro avg       0.86      0.85      0.85      1395
weighted avg       0.86      0.85      0.85      1395



Epoch 34: 100%|██████████| 88/88 [02:28<00:00,  1.68s/it]



Epoch 34 | Avg Loss: 0.0168
Validation Accuracy: 0.8487
              precision    recall  f1-score   support

           0       0.78      0.87      0.82       465
           1       0.93      0.90      0.91       465
           2       0.85      0.78      0.81       465

    accuracy                           0.85      1395
   macro avg       0.85      0.85      0.85      1395
weighted avg       0.85      0.85      0.85      1395



Epoch 35: 100%|██████████| 88/88 [02:27<00:00,  1.68s/it]



Epoch 35 | Avg Loss: 0.0145
Validation Accuracy: 0.8452
              precision    recall  f1-score   support

           0       0.82      0.80      0.81       465
           1       0.90      0.92      0.91       465
           2       0.81      0.81      0.81       465

    accuracy                           0.85      1395
   macro avg       0.84      0.85      0.84      1395
weighted avg       0.84      0.85      0.84      1395



Epoch 36: 100%|██████████| 88/88 [02:28<00:00,  1.68s/it]



Epoch 36 | Avg Loss: 0.0246
Validation Accuracy: 0.8573
              precision    recall  f1-score   support

           0       0.82      0.84      0.83       465
           1       0.90      0.93      0.91       465
           2       0.85      0.80      0.83       465

    accuracy                           0.86      1395
   macro avg       0.86      0.86      0.86      1395
weighted avg       0.86      0.86      0.86      1395


Early stopping triggered
