In [None]:
# packages 

import inspect
import json
import re
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from sklearn.metrics import f1_score, hamming_loss
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer
from transformers import AutoTokenizer
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    set_seed,
)

In [None]:
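# Load data (provided by Brian). TODO: this cell is currently empty; later cells expect
# `texts` (the input strings), `Y` (2-D multi-hot label matrix, one column per label) and
# `label_names` (label names, presumably in the same order as Y's columns).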

In [None]:
# Training/Val

# Hold out a fixed fraction for validation; the fixed seed keeps the split reproducible.
VAL_FRACTION = 0.2
SPLIT_SEED = 69

X_train, X_val, y_train, y_val = train_test_split(
    texts,
    Y,
    test_size=VAL_FRACTION,
    random_state=SPLIT_SEED,
)

In [None]:
# Tokenizer + loading data

# Fixed sequence length used when encoding each example (see ChunkDataset).
MAX_LEN = 256
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

class ChunkDataset(torch.utils.data.Dataset):
    """Map-style dataset that turns (text, multi-hot label vector) pairs into model inputs.

    Each item is tokenized on access, padded/truncated to ``max_len`` tokens.

    Args:
        texts: Sequence of input strings (list, tuple, numpy array or pandas Series).
        labels: Per-example label vectors aligned with ``texts`` *by position*
            (2-D array-like, e.g. a numpy array or pandas DataFrame).
        tokenizer: Callable with the Hugging Face tokenizer call signature used below.
        max_len: Sequence length to pad/truncate every example to.

    Raises:
        ValueError: if ``texts`` and ``labels`` have different lengths.
    """

    def __init__(self, texts, labels, tokenizer, max_len):
        # Materialise into positional containers. train_test_split keeps the
        # original (shuffled) index of pandas inputs, so ``series[idx]`` /
        # ``frame[idx]`` would be a label/column lookup (KeyError or wrong row)
        # instead of "the idx-th example".
        self.texts = list(texts)
        self.labels = np.asarray(labels, dtype=np.float32)
        if len(self.texts) != len(self.labels):
            raise ValueError(
                f"texts and labels differ in length: {len(self.texts)} vs {len(self.labels)}"
            )
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        """Return the ``input_ids``/``attention_mask``/``labels`` tensors for example ``idx``."""
        encoding = self.tokenizer(
            self.texts[idx],
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

        return {
            # return_tensors='pt' yields a (1, max_len) batch; flatten to (max_len,)
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            # float targets, as expected by the multi-label (BCE) loss
            'labels': torch.tensor(self.labels[idx], dtype=torch.float32)
        }
    
# Wrap each split in a ChunkDataset (same tokenizer and sequence length for both).
train_ds, val_ds = (
    ChunkDataset(split_texts, split_labels, tokenizer, MAX_LEN)
    for split_texts, split_labels in ((X_train, y_train), (X_val, y_val))
)

print("Train ds:", len(train_ds), "Val ds:", len(val_ds))

In [None]:
# model creation

# Seed all RNGs before the model is built: the classification head on top of the
# pretrained encoder is freshly (randomly) initialised here, while Trainer only
# seeds when it is constructed, i.e. after this cell.
SEED = 69  # same value as the train/val split's random_state
set_seed(SEED)

# Store label names in the model config so the saved checkpoint reports readable
# labels instead of LABEL_0, LABEL_1, ...
# NOTE(review): assumes label_names (from the data cell) follows Y's column order.
if len(label_names) != Y.shape[1]:
    raise ValueError(
        f"label_names has {len(label_names)} entries but Y has {Y.shape[1]} label columns"
    )
id2label = {i: str(name) for i, name in enumerate(label_names)}
label2id = {name: i for i, name in id2label.items()}

model = AutoModelForSequenceClassification.from_pretrained(
    'bert-base-uncased',
    num_labels=Y.shape[1],
    problem_type='multi_label_classification',
    id2label=id2label,
    label2id=label2id,
)

In [None]:
# Metrics / performance

def compute_metrics(eval_pred, threshold=0.5):
    """Multi-label evaluation metrics for ``Trainer``.

    Args:
        eval_pred: ``(logits, labels)`` pair, e.g. a ``transformers.EvalPrediction``.
            Presumably both are ``(n_examples, n_labels)`` arrays — TODO confirm.
            If the predictions are a tuple of model outputs, the first entry is
            taken as the logits.
        threshold: Sigmoid probability above which a label counts as predicted.
            Defaults to 0.5, so ``Trainer`` (which passes only ``eval_pred``) is unaffected.

    Returns:
        dict with ``micro_f1``, ``macro_f1`` and ``hamming_loss``.
    """
    logits, labels = eval_pred
    if isinstance(logits, tuple):
        logits = logits[0]
    logits = np.asarray(logits)

    # Sigmoid. For very negative logits exp(-x) overflows to inf, which still gives
    # the correct limit 1 / (1 + inf) == 0; silence the spurious overflow warning.
    with np.errstate(over='ignore'):
        probs = 1 / (1 + np.exp(-logits))
    preds = (probs > threshold).astype(int)

    micro_f1 = f1_score(labels, preds, average='micro', zero_division=0)
    macro_f1 = f1_score(labels, preds, average='macro', zero_division=0)
    hloss = hamming_loss(labels, preds)

    return {"micro_f1": micro_f1, "macro_f1": macro_f1, "hamming_loss": hloss}

In [None]:
# Training arguments

# `evaluation_strategy` was renamed to `eval_strategy` in newer transformers releases
# (and the old name later removed); use whichever keyword the installed version accepts.
_EVAL_STRATEGY_KW = (
    "eval_strategy"
    if "eval_strategy" in inspect.signature(TrainingArguments).parameters
    else "evaluation_strategy"
)

args = TrainingArguments(
    output_dir="bert_diablo_multilabel_checkpoints",
    # evaluate and save once per epoch; load_best_model_at_end needs the two to match
    save_strategy="epoch",
    learning_rate=2e-5,
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    weight_decay=0.01,
    logging_steps=50,
    load_best_model_at_end=True,
    metric_for_best_model="micro_f1",  # a key returned by compute_metrics
    greater_is_better=True,            # higher micro-F1 is better
    **{_EVAL_STRATEGY_KW: "epoch"},
)

In [None]:
# Trainer 

# Newer transformers releases take the tokenizer as `processing_class` (the
# `tokenizer` keyword is deprecated there); older releases only know `tokenizer`.
_TOKENIZER_KW = (
    "processing_class"
    if "processing_class" in inspect.signature(Trainer.__init__).parameters
    else "tokenizer"
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    compute_metrics=compute_metrics,
    **{_TOKENIZER_KW: tokenizer},
)

In [None]:
# Save

def get_next_run_dir(base_name: str, root_dir: str = "models") -> Path:
    """Create and return the next numbered run directory under ``root_dir``.

    Scans the sub-directories of ``root_dir`` (created if missing) whose names are
    exactly ``<base_name>_<digits>``. It then creates ``<base_name>_<highest + 1>``,
    zero-padded to at least three digits (``<base_name>_001`` when none exist).
    Non-directory entries are ignored by the scan.

    Raises:
        FileExistsError: if an entry with the computed name already exists
            (e.g. a plain file with that name).
    """
    models_root = Path(root_dir)
    models_root.mkdir(parents=True, exist_ok=True)

    run_name_re = re.compile(rf"^{re.escape(base_name)}_(\d+)$")
    dir_names = (child.name for child in models_root.iterdir() if child.is_dir())
    matches = (run_name_re.match(name) for name in dir_names)
    highest = max((int(m.group(1)) for m in matches if m), default=0)

    new_run = models_root / f"{base_name}_{highest + 1:03d}"
    new_run.mkdir(parents=True, exist_ok=False)
    return new_run

# Train, then evaluate (with load_best_model_at_end=True, evaluate() runs on the best checkpoint)
trainer.train()
metrics = trainer.evaluate()
print("eval metrics:", metrics)

# Refuse to save a label map that does not line up with the model's outputs;
# checked before creating the run directory so a mismatch leaves no empty run behind.
num_model_labels = trainer.model.config.num_labels
if len(label_names) != num_model_labels:
    raise ValueError(
        f"label_names has {len(label_names)} entries but the model predicts {num_model_labels} labels"
    )

# Creating a new version each time
run_dir = get_next_run_dir("bert_diablo_multilabel", root_dir="models")
print(f"Saving model to {run_dir}")

# Saving the model, tokenizer, label map and this run's eval metrics side by side
trainer.save_model(str(run_dir))
tokenizer.save_pretrained(str(run_dir))

# list(...) so a numpy array of names (e.g. MultiLabelBinarizer.classes_) also serialises
with open(run_dir / "label_names.json", "w", encoding="utf-8") as f:
    json.dump(list(label_names), f, indent=2)

with open(run_dir / "eval_metrics.json", "w", encoding="utf-8") as f:
    json.dump(metrics, f, indent=2)

print("Saved:", run_dir)