In [1]:
from datasets import load_dataset, ClassLabel
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer, DataCollatorWithPadding, EarlyStoppingCallback
from torch.utils.data import DataLoader
from evaluate import load
import pandas as pd
import numpy as np 



In [2]:
#######################
### I. Preprocess Data
### Balance using up and downsampling create training/validation   

DATA_DIR = "/teamspace/studios/this_studio/data/data_excluding_5"

ds = load_dataset(
    "csv",
    data_files={
        "train": f"{DATA_DIR}/medical_tc_train_raw_excl_5.csv",
        "test":  f"{DATA_DIR}/medical_tc_test_raw_excl_5.csv",
    }
)

# 3) Map labels to 0..3 and cast for stratification
ds = ds.map(lambda x: {"labels": int(x["condition_label"]) - 1})
ds = ds.cast_column("labels", ClassLabel(num_classes=4))

# 4) Balance TRAIN with up/down sampling to a common target (mean count)
def balance_dataset(hfds, label_col="labels", seed=42, strategy="mean"):
    labels = np.array(hfds[label_col])
    classes, counts = np.unique(labels, return_counts=True)
    if strategy == "max":
        target = counts.max()
    elif strategy == "min":
        target = counts.min()
    elif strategy == "mean":
        target = int(round(counts.mean()))
    else:
        target = int(strategy)  # allow int target
    rng = np.random.default_rng(seed)
    new_idx = []
    for c in classes:
        idx = np.flatnonzero(labels == c)
        if len(idx) > target:    # downsample
            sel = rng.choice(idx, size=target, replace=False)
        elif len(idx) < target:  # upsample
            sel = rng.choice(idx, size=target, replace=True)
        else:
            sel = idx
        new_idx.extend(sel.tolist())
    rng.shuffle(new_idx)
    return hfds.select(new_idx)

balanced_train_raw = balance_dataset(ds["train"], label_col="labels", seed=42, strategy="mean")

# 5) Split balanced TRAIN -> 80/20 (train/val), stratified
split = balanced_train_raw.train_test_split(test_size=0.2, stratify_by_column="labels", seed=42)
train_raw = split["train"]
val_raw   = split["test"]
test_raw  = ds["test"]        # test kept as-is (but class 5 already removed)

# 6) Tokenize AFTER split
tok = AutoTokenizer.from_pretrained("microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext")

def tokenize_fn(b):
    return tok(b["medical_abstract"], truncation=True, max_length=512)

cols_to_remove = ["medical_abstract", "condition_label"]

train_ds = train_raw.map(tokenize_fn, batched=True, remove_columns=cols_to_remove)
val_ds   = val_raw.map(tokenize_fn,   batched=True, remove_columns=cols_to_remove)
test_ds  = test_raw.map(tokenize_fn,  batched=True, remove_columns=cols_to_remove)

# 7) Collator (Tensor Core-friendly)
collator = DataCollatorWithPadding(tok, pad_to_multiple_of=8)

# quick sanity check
batch = next(iter(DataLoader(train_ds, batch_size=2, collate_fn=collator)))
print(batch.keys()); print(batch["labels"].shape)


KeysView({'labels': tensor([2, 0]), 'input_ids': tensor([[    2,  2564,  1927,  2861,  2259, 29043,  2053,  1990,  5067,  4976,
          1922,  2132,  1956,  7529, 21038,  1022, 13821, 18698,    18,  2038,
          3620,  2132,  1956,  7529, 21038,  1022, 13821, 18698,  1930,  2285,
          3297,  1990,    43,  3909,  1927, 19331,  8228,  4085,  4976,    18,
          4845, 16757,  2007,  8972,  1990,  1920,  2841,  4159,  1985, 10266,
          3713,  1954, 22701,    31,  1930,  2690,  1990,  1920,  8179,  4159,
            16, 27570,  1954, 22701,    18,  1958,  2136,  3297,    16,  3713,
          1954, 22701,  4845,  1985, 10711,  1942,  8411,  2254,  2690,  1988,
          1985, 27570,  1954, 22701,    18,  1920,  3175,  2192,  3227,  3586,
          1958, 27570,  1954,  1930,  3713,  1954, 22701,  4845,    16,  2154,
          1977,  7149,  1942,  2931, 13689,  5829,    16,  1982,  2567,  2502,
          1958,  1920,  2132,    18,  1920,  3604,  1927,  1920,  3562,  1990,
   

In [3]:

##################################################
#### Model Initialization
########################################################

# Initialize a BERT model for multilabel classification
model_name = "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext"
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=4)

print(model.config)


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BertConfig {
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "dtype": "float32",
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2",
    "3": "LABEL_3"
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_2": 2,
    "LABEL_3": 3
  },
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.56.2",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}



In [4]:
##############################
##### Definining several metrics to compare model quality
##################### 

acc = load("accuracy")
f1  = load("f1")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = logits.argmax(axis=-1)
    return {
        "accuracy": acc.compute(predictions=preds, references=labels)["accuracy"],
        "f1_macro": f1.compute(predictions=preds, references=labels, average="macro")["f1"],
        "f1_weighted": f1.compute(predictions=preds, references=labels, average="weighted")["f1"],
    }

***Fine tune whole model***

In [5]:
'''#################################
#### Fine Tune whole model Training config
########################


training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=32,   
    per_device_eval_batch_size=32,
    num_train_epochs=10,
    weight_decay=0.1,
    warmup_ratio=0.06,
    lr_scheduler_type="linear",
    
    seed=42,
    save_total_limit=2,
    load_best_model_at_end=True,
    logging_dir="./logs",
    logging_steps=100,
    fp16=True,                  
    metric_for_best_model="eval_loss",
    greater_is_better=False
)

trainer = Trainer(
    model=model,                        # Pre-trained BERT model
    args=training_args,                 # Training arguments
    train_dataset=train_ds,             
    eval_dataset=val_ds,
    tokenizer=tok,
    data_collator=collator,        # Efficient batching
    compute_metrics=compute_metrics,   # Custom metric
    callbacks = [EarlyStoppingCallback(early_stopping_patience=2)]
)

# Start training
trainer.train()

#Evaluate on Test Set
trainer.evaluate(test_ds)
'''

'#################################\n#### Fine Tune whole model Training config\n########################\n\n\ntraining_args = TrainingArguments(\n    output_dir="./results",\n    eval_strategy="epoch",\n    save_strategy="epoch",\n    learning_rate=2e-5,\n    per_device_train_batch_size=32,   \n    per_device_eval_batch_size=32,\n    num_train_epochs=10,\n    weight_decay=0.1,\n    warmup_ratio=0.06,\n    lr_scheduler_type="linear",\n    \n    seed=42,\n    save_total_limit=2,\n    load_best_model_at_end=True,\n    logging_dir="./logs",\n    logging_steps=100,\n    fp16=True,                  \n    metric_for_best_model="eval_loss",\n    greater_is_better=False\n)\n\ntrainer = Trainer(\n    model=model,                        # Pre-trained BERT model\n    args=training_args,                 # Training arguments\n    train_dataset=train_ds,             \n    eval_dataset=val_ds,\n    tokenizer=tok,\n    data_collator=collator,        # Efficient batching\n    compute_metrics=compute_met

**Attempt**
***Train only classifier Head***

In [6]:
####################  
## Training only classifier head, freeze all the other parameters
# Uncomment when training whole model
####################

# 1) freeze the encoder (backbone)
for p in model.bert.parameters():         # BiomedBERT is a BERT; backbone is `model.bert`
    p.requires_grad = False

# 2) keep the classifier trainable
for p in model.classifier.parameters():
    p.requires_grad = True

# (optional) sanity check
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total     = sum(p.numel() for p in model.parameters())
print(f"trainable params: {trainable:,} / {total:,}")


trainable params: 3,076 / 109,485,316


In [9]:
###################### 
###### Head only Trainer Config
########################


training_args = TrainingArguments(
    output_dir="./results_head_only",
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-4,              # 5e-4 ~ 1e-3 typical; lower if unstable
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    num_train_epochs=20,              # ES will stop earlier if needed
    weight_decay=0.0,               
    warmup_ratio=0.1,
    lr_scheduler_type="linear",
    label_smoothing_factor=0.1,      # optional regularization
    seed=42,
    save_total_limit=1,
    load_best_model_at_end=True,
    logging_dir="./logs",
    logging_steps=100,
    fp16=True,                       # fine on T4; if loss blows up, try 5e-4 or fp16=False
    metric_for_best_model="eval_loss",
    greater_is_better=False,
)

trainer = Trainer(  # or your WeightedTrainer
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    tokenizer=tok,
    data_collator=DataCollatorWithPadding(tok, pad_to_multiple_of=8),
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
)
trainer.train()
trainer.evaluate(test_ds)


  trainer = Trainer(  # or your WeightedTrainer


Epoch,Training Loss,Validation Loss,Accuracy,F1 Macro,F1 Weighted
1,1.1158,1.093496,0.663855,0.661903,0.661909
2,1.0902,1.046268,0.687865,0.687981,0.687979
3,1.043,1.003937,0.699546,0.699617,0.699621
4,1.0085,0.976388,0.714471,0.7134,0.713406
5,0.9865,0.94931,0.722907,0.722808,0.722813
6,0.9718,0.936698,0.729396,0.729414,0.729412
7,0.9598,0.921106,0.738482,0.738242,0.73825
8,0.9386,0.910657,0.73913,0.739184,0.739192
9,0.9301,0.902051,0.741726,0.741766,0.741773
10,0.9195,0.896039,0.743673,0.743377,0.743388


{'eval_loss': 0.8576733469963074,
 'eval_accuracy': 0.7586922677737415,
 'eval_f1_macro': 0.7414171167519136,
 'eval_f1_weighted': 0.7611282720361802,
 'eval_runtime': 12.4216,
 'eval_samples_per_second': 155.133,
 'eval_steps_per_second': 4.911,
 'epoch': 20.0}