In [1]:
import os
import json
import torch
import numpy as np
import pandas as pd
from pathlib import Path
from datetime import datetime

# Check GPU availability
if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU: {gpu_name}")
    print(f"GPU Memory: {gpu_memory:.1f} GB")
else:
    print("WARNING: No GPU detected. Training will be slow.")

print(f"PyTorch Version: {torch.__version__}")

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

# Imports
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding,
)
from datasets import Dataset, DatasetDict
from sklearn.metrics import accuracy_score, f1_score, classification_report

GPU: NVIDIA GeForce RTX 4050 Laptop GPU
GPU Memory: 6.0 GB
PyTorch Version: 2.6.0+cu124
Device: cuda


### Load Processed Data

In [2]:
DATA_PATH = "../data/classifier_v2"

# Load training data
with open(f"{DATA_PATH}/train.json", 'r') as f:
    train_data = json.load(f)
print(f"Train:      {len(train_data):,} records")

# Load validation data
with open(f"{DATA_PATH}/val.json", 'r') as f:
    val_data = json.load(f)
print(f"Validation: {len(val_data):,} records")

# Load test data
with open(f"{DATA_PATH}/test.json", 'r') as f:
    test_data = json.load(f)
print(f"Test:       {len(test_data):,} records")

# Load label mappings
with open(f"{DATA_PATH}/label2id.json", 'r') as f:
    label2id = json.load(f)

with open(f"{DATA_PATH}/id2label.json", 'r') as f:
    id2label = json.load(f)
    id2label = {int(k): v for k, v in id2label.items()}

print(f"\nClasses:    {len(label2id)}")

# Show sample
print("SAMPLE RECORD:")
print(f"Text: {train_data[0]['text'][:100]}...")
print(f"Condition: {train_data[0]['condition']}")
print(f"Age: {train_data[0]['age']}")
print(f"Sex: {train_data[0]['sex']}")

Train:      1,025,602 records
Validation: 132,448 records
Test:       134,529 records

Classes:    49
SAMPLE RECORD:
Text: Patient is a 18 year old Male presenting with: live with 4 or more people; has had significantly inc...
Condition: URTI
Age: 18
Sex: M


### Prepare Datasets for Training

In [3]:
def prepare_dataset(data):
    #Converting data to HuggingFace Dataset format.
    texts = [item['text'] for item in data]
    labels = [label2id[item['condition']] for item in data]
    return Dataset.from_dict({
        'text': texts,
        'label': labels,
    })

train_dataset = prepare_dataset(train_data)
val_dataset = prepare_dataset(val_data)
test_dataset = prepare_dataset(test_data)

dataset = DatasetDict({
    'train': train_dataset,
    'validation': val_dataset,
    'test': test_dataset,
})

print("Dataset structure:")
print(dataset)

Dataset structure:
DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 1025602
    })
    validation: Dataset({
        features: ['text', 'label'],
        num_rows: 132448
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 134529
    })
})


### Load Model and Tokenizer

In [4]:
MODEL_NAME = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"

print(f"Loading model: {MODEL_NAME}")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
print("Tokenizer loaded")

model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(label2id),
    id2label=id2label,
    label2id=label2id,
)
print("Model loaded")

model = model.to(device)

total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params:,}")

Loading model: microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract
Tokenizer loaded


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model loaded
Total parameters: 109,519,921


### Tokenize Datasets

In [5]:
MAX_LENGTH = 512  # Increased from 256 since v2 texts are longer

def tokenize_function(examples):
    return tokenizer(
        examples['text'],
        padding='max_length',
        truncation=True,
        max_length=MAX_LENGTH,
    )

print(f"Tokenizing datasets (max_length={MAX_LENGTH})...")

tokenized_train = train_dataset.map(tokenize_function, batched=True)
print(f"Train tokenized: {len(tokenized_train):,}")

tokenized_val = val_dataset.map(tokenize_function, batched=True)
print(f"Val tokenized: {len(tokenized_val):,}")

tokenized_test = test_dataset.map(tokenize_function, batched=True)
print(f"Test tokenized: {len(tokenized_test):,}")

# Set format for PyTorch
tokenized_train.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])
tokenized_val.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])
tokenized_test.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])

print("Tokenization complete")

Tokenizing datasets (max_length=512)...


Map:   0%|          | 0/1025602 [00:00<?, ? examples/s]

Train tokenized: 1,025,602


Map:   0%|          | 0/132448 [00:00<?, ? examples/s]

Val tokenized: 132,448


Map:   0%|          | 0/134529 [00:00<?, ? examples/s]

Test tokenized: 134,529
Tokenization complete


### Training Configuration

In [6]:
from sklearn.utils.class_weight import compute_class_weight
import torch.nn as nn

# Hyperparameters
BATCH_SIZE = 8
GRADIENT_ACCUMULATION_STEPS = 8
LEARNING_RATE = 2e-5
NUM_EPOCHS = 1  # Start with 1 epoch, can continue if needed
WARMUP_RATIO = 0.1
WEIGHT_DECAY = 0.01

OUTPUT_DIR = "../models/condition_classifier_v2"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Calculate class weights
train_labels = [label2id[item['condition']] for item in train_data]
class_weights = compute_class_weight(
    class_weight='balanced',
    classes=np.unique(train_labels),
    y=train_labels
)
class_weights = torch.tensor(class_weights, dtype=torch.float32).to(device)

print("Class weights calculated:")
print(f"  Min weight: {class_weights.min().item():.4f}")
print(f"  Max weight: {class_weights.max().item():.4f}")
print(f"  Ratio (max/min): {class_weights.max().item() / class_weights.min().item():.2f}")

# Show weights for extreme classes
weights_with_labels = [(id2label[i], class_weights[i].item()) for i in range(len(class_weights))]
weights_sorted = sorted(weights_with_labels, key=lambda x: x[1], reverse=True)

print("\nTop 5 highest weighted (rare classes):")
for label, weight in weights_sorted[:5]:
    print(f"  {label}: {weight:.4f}")

print("\nTop 5 lowest weighted (common classes):")
for label, weight in weights_sorted[-5:]:
    print(f"  {label}: {weight:.4f}")

Class weights calculated:
  Min weight: 0.3252
  Max weight: 80.1941
  Ratio (max/min): 246.62

Top 5 highest weighted (rare classes):
  Bronchiolitis: 80.1941
  Ebola: 29.1513
  Croup: 7.3389
  Spontaneous rib fracture: 3.6643
  Whooping cough: 3.4482

Top 5 lowest weighted (common classes):
  Localized edema: 0.7522
  HIV (initial infection): 0.7214
  Anemia: 0.4131
  Viral pharyngitis: 0.3396
  URTI: 0.3252


### Setup Trainer

In [7]:
class WeightedTrainer(Trainer):
    def __init__(self, class_weights, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.class_weights = class_weights
    
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        
        loss_fn = nn.CrossEntropyLoss(weight=self.class_weights)
        loss = loss_fn(logits, labels)
        
        return (loss, outputs) if return_outputs else loss

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    
    accuracy = accuracy_score(labels, predictions)
    f1_macro = f1_score(labels, predictions, average='macro')
    f1_weighted = f1_score(labels, predictions, average='weighted')
    
    return {
        'accuracy': accuracy,
        'f1_macro': f1_macro,
        'f1_weighted': f1_weighted,
    }
    
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(label2id),
    id2label=id2label,
    label2id=label2id,
)
model = model.to(device)
print("Model reloaded")

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    warmup_ratio=WARMUP_RATIO,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="f1_macro",  # Changed to macro since we care about rare classes
    greater_is_better=True,
    logging_dir=f"{OUTPUT_DIR}/logs",
    logging_steps=500,
    fp16=True,
    dataloader_num_workers=4,
    report_to="none",
)

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

trainer = WeightedTrainer(
    class_weights=class_weights,
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

steps_per_epoch = len(tokenized_train) // (BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS)
print(f"Steps per epoch: {steps_per_epoch:,}")
print(f"Trainer ready with weighted loss")

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model reloaded
Steps per epoch: 16,025
Trainer ready with weighted loss


  super().__init__(*args, **kwargs)


### Training The Model

In [8]:
print(f"Training started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print(f"Estimated time: ~9-10 hours for 1 epoch")

start_time = datetime.now()
train_result = trainer.train()
end_time = datetime.now()

duration = end_time - start_time
print(f"\nTraining completed at: {end_time.strftime('%Y-%m-%d %H:%M:%S')}")
print(f"Duration: {duration}")

print("\nTraining metrics:")
for key, value in train_result.metrics.items():
    if isinstance(value, float):
        print(f"  {key}: {value:.4f}")
    else:
        print(f"  {key}: {value}")

Training started at: 2026-01-23 23:37:00
Estimated time: ~9-10 hours for 1 epoch


Epoch,Training Loss,Validation Loss,Accuracy,F1 Macro,F1 Weighted
1,0.0099,0.011182,0.997184,0.996637,0.997152



Training completed at: 2026-01-24 08:49:49
Duration: 9:12:48.342353

Training metrics:
  train_runtime: 33168.1285
  train_samples_per_second: 30.9210
  train_steps_per_second: 0.4830
  total_flos: 269961098518566912.0000
  train_loss: 0.1176
  epoch: 1.0000


### Save Model

In [9]:
FINAL_MODEL_PATH = "../models/condition_classifier_v2/final"
os.makedirs(FINAL_MODEL_PATH, exist_ok=True)

trainer.save_model(FINAL_MODEL_PATH)
tokenizer.save_pretrained(FINAL_MODEL_PATH)

# Save label mappings with model
import shutil
shutil.copy("../data/classifier_v2/label2id.json", f"{FINAL_MODEL_PATH}/label2id.json")
shutil.copy("../data/classifier_v2/id2label.json", f"{FINAL_MODEL_PATH}/id2label.json")

# Save training metrics
training_metrics = {
    "epochs": NUM_EPOCHS,
    "batch_size": BATCH_SIZE,
    "gradient_accumulation_steps": GRADIENT_ACCUMULATION_STEPS,
    "effective_batch_size": BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS,
    "learning_rate": LEARNING_RATE,
    "max_length": MAX_LENGTH,
    "training_time": str(duration),
    "train_loss": train_result.metrics['train_loss'],
    "class_weights_used": True,
}

with open(f"{FINAL_MODEL_PATH}/training_metrics.json", 'w') as f:
    json.dump(training_metrics, f, indent=2)

print(f"Model saved to: {FINAL_MODEL_PATH}")
print("\nSaved files:")
for f in os.listdir(FINAL_MODEL_PATH):
    print(f"  {f}")

Model saved to: ../models/condition_classifier_v2/final

Saved files:
  config.json
  model.safetensors
  tokenizer_config.json
  special_tokens_map.json
  vocab.txt
  tokenizer.json
  training_args.bin
  label2id.json
  id2label.json
  training_metrics.json


### Evaluate on Test Set

In [10]:
print("Evaluating on test set...")

test_results = trainer.evaluate(tokenized_test)

print("\nTest Set Results:")
print(f"  Accuracy:    {test_results['eval_accuracy']:.4f} ({test_results['eval_accuracy']*100:.2f}%)")
print(f"  F1 Macro:    {test_results['eval_f1_macro']:.4f} ({test_results['eval_f1_macro']*100:.2f}%)")
print(f"  F1 Weighted: {test_results['eval_f1_weighted']:.4f} ({test_results['eval_f1_weighted']*100:.2f}%)")


Evaluating on test set...



Test Set Results:
  Accuracy:    0.9974 (99.74%)
  F1 Macro:    0.9969 (99.69%)
  F1 Weighted: 0.9974 (99.74%)


### Per-Class Analysis

In [11]:
from sklearn.metrics import classification_report, confusion_matrix

# Get predictions
print("Generating predictions for detailed analysis...")
predictions = trainer.predict(tokenized_test)
y_pred = np.argmax(predictions.predictions, axis=1)
y_true = predictions.label_ids

# Classification report
print("\nPer-Class Performance:")
print("-" * 60)

report = classification_report(
    y_true, 
    y_pred, 
    target_names=[id2label[i] for i in range(len(id2label))],
    digits=4
)
print(report)

# Save report
with open(f"{FINAL_MODEL_PATH}/classification_report.txt", 'w') as f:
    f.write(report)
print(f"\nReport saved to {FINAL_MODEL_PATH}/classification_report.txt")

Generating predictions for detailed analysis...

Per-Class Performance:
------------------------------------------------------------
                                          precision    recall  f1-score   support

     Acute COPD exacerbation / infection     1.0000    1.0000    1.0000      2153
                Acute dystonic reactions     1.0000    1.0000    1.0000      3302
                        Acute laryngitis     0.9911    0.9991    0.9950      3217
                      Acute otitis media     1.0000    1.0000    1.0000      3516
                   Acute pulmonary edema     1.0000    1.0000    1.0000      2598
                    Acute rhinosinusitis     0.9839    0.8671    0.9218      1829
                      Allergic sinusitis     1.0000    1.0000    1.0000      2411
                             Anaphylaxis     1.0000    1.0000    1.0000      3799
                                  Anemia     1.0000    1.0000    1.0000      6842
                     Atrial fibrillation     1

### Validate - Check Clean Test Performance

In [12]:
# Load texts for overlap check
train_texts = set(item['text'] for item in train_data)
test_texts = [item['text'] for item in test_data]

# Find non-overlapping test indices
clean_indices = [i for i, text in enumerate(test_texts) if text not in train_texts]

print(f"Total test samples: {len(test_data):,}")
print(f"Overlapping samples: {len(test_data) - len(clean_indices):,}")
print(f"Clean test samples: {len(clean_indices):,}")
print(f"Overlap percentage: {(len(test_data) - len(clean_indices)) / len(test_data) * 100:.2f}%")

# Get predictions for clean samples only
clean_y_true = [y_true[i] for i in clean_indices]
clean_y_pred = [y_pred[i] for i in clean_indices]

# Calculate clean metrics
from sklearn.metrics import accuracy_score, f1_score

clean_accuracy = accuracy_score(clean_y_true, clean_y_pred)
clean_f1_macro = f1_score(clean_y_true, clean_y_pred, average='macro')
clean_f1_weighted = f1_score(clean_y_true, clean_y_pred, average='weighted')

print("\nMetrics Comparison:")
print(f"{'Metric':<15} {'All Test':<15} {'Clean Test':<15} {'Difference':<15}")
print("-" * 60)
print(f"{'Accuracy':<15} {test_results['eval_accuracy']*100:<15.2f} {clean_accuracy*100:<15.2f} {(clean_accuracy - test_results['eval_accuracy'])*100:+.2f}%")
print(f"{'F1 Macro':<15} {test_results['eval_f1_macro']*100:<15.2f} {clean_f1_macro*100:<15.2f} {(clean_f1_macro - test_results['eval_f1_macro'])*100:+.2f}%")
print(f"{'F1 Weighted':<15} {test_results['eval_f1_weighted']*100:<15.2f} {clean_f1_weighted*100:<15.2f} {(clean_f1_weighted - test_results['eval_f1_weighted'])*100:+.2f}%")

Total test samples: 134,529
Overlapping samples: 4,901
Clean test samples: 129,628
Overlap percentage: 3.64%

Metrics Comparison:
Metric          All Test        Clean Test      Difference     
------------------------------------------------------------
Accuracy        99.74           99.73           -0.01%
F1 Macro        99.69           99.69           -0.00%
F1 Weighted     99.74           99.73           -0.01%


### Model Training v2 - Complete

**Model:** PubMedBERT (microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract)  
**Parameters:** 109.5M

#### Training Configuration

| Setting | Value |
|:--------|:------|
| Epochs | 1 |
| Batch size | 4 (effective: 64 with gradient accumulation) |
| Learning rate | 2e-5 |
| Max sequence length | 512 |
| Class weights | Yes (balanced) |
| Training time | 9h 12m |

#### Data Improvements (v2 vs v1)

| Change | Impact |
|:-------|:-------|
| Included ALL symptoms | More information for model |
| Included symptom values | Severity, location, duration captured |
| Included patient age and sex | Demographic context helps diagnosis |
| Train-test overlap reduced | 99.35% → 3.59% (fixed data leakage) |

#### Test Set Results

| Metric | Score |
|:-------|:------|
| Accuracy | 99.74% |
| F1 Macro | 99.69% |
| F1 Weighted | 99.74% |

#### Comparison: v1 vs v2

| Metric | v1 (with leakage) | v2 (clean) | Improvement |
|:-------|:------------------|:-----------|:------------|
| Accuracy | 98.22% | 99.74% | +1.52% |
| F1 Macro | 97.94% | 99.69% | +1.75% |
| F1 Weighted | 98.10% | 99.74% | +1.64% |

#### Weak Classes Improvement

| Condition | v1 F1 | v2 F1 | Improvement |
|:----------|:------|:------|:------------|
| Acute rhinosinusitis | 53.91% | 92.18% | +71% |
| Chronic rhinosinusitis | 82.64% | 95.29% | +15% |
| Stable angina | 92.39% | 99.00% | +7% |
| Possible NSTEMI / STEMI | 93.92% | 100.00% | +6% |

#### Rare Classes Performance (Class Weights Impact)

| Condition | Samples | F1 Score |
|:----------|:--------|:---------|
| Bronchiolitis | 36 | 100% |
| Ebola | 100 | 100% |
| Croup | 344 | 100% |
| Whooping cough | 549 | 100% |

#### Model Saved To
```
../models/condition_classifier_v2/final/
```