In [None]:
%pip install torch transformers accelerate datasets evaluate numpy pandas jupyter scikit-learn

# Imports


In [4]:
import os
from datetime import datetime
from urllib.parse import urlparse

import evaluate
import numpy as np
import pandas as pd
import torch
import wandb
from datasets import Dataset, DatasetDict
from sklearn.metrics import (
    hamming_loss, accuracy_score, precision_score, recall_score,
    f1_score, roc_auc_score, average_precision_score, jaccard_score
)
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    TrainerCallback,
    DataCollatorWithPadding
)
# NOTE(review): WandbCallback is the Keras integration and is not referenced in
# the visible cells; this import may fail without Keras/TensorFlow — confirm it is needed.
from wandb.keras import WandbCallback

# Configuration and Setup

In [None]:
# Category names the classifier can assign; list position doubles as class id.
labels = [
    "News", "Entertainment", "Shop", "Chat", "Education",
    "Government", "Health", "Technology", "Work", "Travel", "Uncategorized"
]
num_labels = len(labels)
id2label = dict(enumerate(labels))
label2id = {name: idx for idx, name in id2label.items()}

# Model and training parameters
MODEL_PATH = "answerdotai/ModernBERT-base"
LEARNING_RATE = 2e-5
BATCH_SIZE = 16
NUM_EPOCHS = 3
UNFREEZE_START_LAYER = 18  # the encoder is treated as having 22 layers

# Start a W&B run whose name carries the launch timestamp.
run_stamp = datetime.now().strftime('%Y%m%d-%H%M%S')
wandb.init(project="url-title-classifier", name=f"training-run-{run_stamp}")

# Record the hyperparameters on the run so it can be reproduced later.
wandb.config.update(dict(
    model_name=MODEL_PATH,
    learning_rate=LEARNING_RATE,
    batch_size=BATCH_SIZE,
    num_epochs=NUM_EPOCHS,
    num_labels=num_labels,
    labels=labels,
    unfreeze_start_layer=UNFREEZE_START_LAYER,
))

# Load Data

In [5]:
# Read the processed parquet file and keep three of its columns, picked by
# position (0, 1 and 9) and renamed to url / title / category.
source_frame = pd.read_parquet('../data/processed/cleaned_classified_data.parquet')
kept_columns = [source_frame.columns[pos] for pos in (0, 1, 9)]
df = source_frame[kept_columns].set_axis(['url', 'title', 'category'], axis=1)

# Wrap the frame as a Hugging Face dataset and hold out 10% for validation.
dataset = Dataset.from_pandas(df)
split_dataset = dataset.train_test_split(test_size=0.1, seed=1)
train_part, validation_part = split_dataset['train'], split_dataset['test']
dataset_dict = DatasetDict({'train': train_part, 'validation': validation_part})

# Record the split sizes on the W&B run.
wandb.config.update(dict(
    train_size=len(train_part),
    val_size=len(validation_part),
))

# Load Model

In [None]:
# Load the tokenizer and a sequence-classification model configured for
# multi-label output over `labels`.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_PATH,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
    problem_type="multi_label_classification"
)

# Freeze every encoder parameter first; parameters outside `model.base_model`
# are not touched by this loop.
for param in model.base_model.parameters():
    param.requires_grad = False

# Read the encoder depth from the loaded model instead of hard-coding it.
num_layers = len(model.base_model.layers)
if not 0 <= UNFREEZE_START_LAYER <= num_layers:
    raise ValueError(
        f"UNFREEZE_START_LAYER={UNFREEZE_START_LAYER} is outside 0..{num_layers}"
    )

# Unfreeze layers UNFREEZE_START_LAYER .. num_layers-1. The previous range
# started at `num_layers - UNFREEZE_START_LAYER`, which unfroze layers 4..21
# for a start layer of 18 — the opposite of what the constant's name says.
for layer in model.base_model.layers[UNFREEZE_START_LAYER:]:
    for param in layer.parameters():
        param.requires_grad = True
for param in model.base_model.final_norm.parameters():
    param.requires_grad = True

# Log model architecture
wandb.watch(model, log="all", log_freq=100)

# Data Preprocessing

In [None]:
def _multi_hot(category):
    """Convert one ``category`` cell into a float multi-hot vector over ``labels``.

    NOTE(review): assumes a cell holds either a single label name or a list of
    label names — confirm against the parquet schema. An unknown name raises
    ``KeyError`` rather than being dropped silently.
    """
    if category is None:
        names = []
    elif isinstance(category, str):
        names = [category]
    else:
        names = list(category)

    # Floats, not ints: the multi-label loss expects float targets.
    vector = [0.0] * num_labels
    for name in names:
        vector[label2id[name]] = 1.0
    return vector


def preprocess_function(examples):
    """Tokenize a batch of rows and attach their multi-hot ``labels``.

    Each row is rendered as ``[DOMAIN]{domain}[PATH]{path}[TITLE]{title}``.
    The tokenizer is presumably what adds the surrounding ``[CLS]``/``[SEP]``.

    NOTE(review): ``[DOMAIN]``, ``[PATH]`` and ``[TITLE]`` are not registered as
    special tokens in the visible cells, so they are tokenized as ordinary text.

    Args:
        examples: Batch mapping with ``url``, ``title`` and ``category`` columns.

    Returns:
        The tokenizer output with an added ``labels`` entry, one vector per row.
    """
    texts = []
    for url, title in zip(examples["url"], examples["title"]):
        url = url or ""
        parts = urlparse(url)
        if not parts.netloc:
            # Without a scheme urlparse leaves the host inside `path`;
            # a leading "//" makes it parse the host as netloc instead.
            parts = urlparse(f"//{url}")
        texts.append(f"[DOMAIN]{parts.netloc}[PATH]{parts.path}[TITLE]{title or ''}")

    encoded = tokenizer(texts, truncation=True)
    encoded["labels"] = [_multi_hot(category) for category in examples["category"]]
    return encoded


tokenized_data = dataset_dict.map(preprocess_function, batched=True)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Metrics

In [None]:
def _roc_auc_or_nan(y_true, y_score, average):
    """Return ROC-AUC, or NaN when scikit-learn reports it as undefined."""
    try:
        return roc_auc_score(y_true, y_score, average=average)
    except ValueError:
        # Raised e.g. when a label has only one class in the evaluation split.
        return float("nan")


def compute_metrics(eval_pred):
    """Compute multi-label classification metrics for the Trainer.

    Args:
        eval_pred: Pair ``(logits, true_labels)``; both are indexed as 2-D
            arrays with one column per entry of ``labels``.

    Returns:
        dict mapping metric name to value: Hamming loss, exact-match accuracy,
        micro/macro precision, recall and F1, micro/macro ROC-AUC, and one
        ``f1_{label}`` entry per label.
    """
    logits, true_labels = eval_pred
    logits = np.asarray(logits, dtype=np.float64)

    # Numerically stable sigmoid: exp() only ever sees non-positive values,
    # so large-magnitude logits cannot overflow.
    decay = np.exp(-np.abs(logits))
    probabilities = np.where(logits >= 0, 1.0 / (1.0 + decay), decay / (1.0 + decay))
    predictions = (probabilities >= 0.5).astype(np.int32)
    true_labels_int = np.asarray(true_labels).astype(np.int32)

    metrics = {
        "hamming_loss": hamming_loss(true_labels_int, predictions),
        "exact_match": accuracy_score(true_labels_int, predictions),

        # Micro metrics: computed on the 2-D indicator matrices. Flattening to
        # 1-D and asking for average='micro' treats 0 and 1 as two classes, so
        # precision, recall and F1 all collapse to plain accuracy.
        "precision_micro": precision_score(true_labels_int, predictions, average='micro', zero_division=0),
        "recall_micro": recall_score(true_labels_int, predictions, average='micro', zero_division=0),
        "f1_micro": f1_score(true_labels_int, predictions, average='micro', zero_division=0),

        # Macro metrics
        "precision_macro": precision_score(true_labels_int, predictions, average='macro', zero_division=0),
        "recall_macro": recall_score(true_labels_int, predictions, average='macro', zero_division=0),
        "f1_macro": f1_score(true_labels_int, predictions, average='macro', zero_division=0),

        # ROC-AUC
        "roc_auc_micro": _roc_auc_or_nan(true_labels_int, probabilities, 'micro'),
        "roc_auc_macro": _roc_auc_or_nan(true_labels_int, probabilities, 'macro'),
    }

    # Per-label metrics
    for i, label in enumerate(labels):
        metrics[f"f1_{label}"] = f1_score(true_labels_int[:, i], predictions[:, i], zero_division=0)

    return metrics

# Training Setup

In [None]:
class WandBCustomCallback(TrainerCallback):
    """Trainer callback that mirrors training and evaluation logs to W&B.

    It subclasses ``TrainerCallback`` so that every Trainer event this class
    does not define is inherited instead of raising ``AttributeError``.
    """

    def __init__(self):
        # Counters of how many training / evaluation log records were sent.
        self.train_step = 0
        self.eval_step = 0

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Forward one log record to W&B.

        Records containing ``loss`` are logged under ``train/``; records
        containing ``eval_loss`` have every ``eval_*`` entry logged under
        ``eval/``. Empty or missing ``logs`` are ignored.
        """
        if not logs:
            return

        if "loss" in logs:
            train_payload = {
                "train/loss": logs["loss"],
                "train/learning_rate": logs.get("learning_rate"),
                "train/epoch": logs.get("epoch"),
                "train/step": self.train_step
            }
            # A record may lack learning_rate or epoch; skip absent values
            # rather than raising KeyError in the middle of training.
            wandb.log({key: value for key, value in train_payload.items() if value is not None})
            self.train_step += 1

        if "eval_loss" in logs:
            metrics_dict = {f"eval/{k}": v for k, v in logs.items() if k.startswith("eval_")}
            wandb.log(metrics_dict)
            self.eval_step += 1

# Evaluate and checkpoint once per epoch, then reload the checkpoint with the
# highest macro F1 when training ends.
training_args = TrainingArguments(
    output_dir="data/models/URL-TITLE-classifier",
    learning_rate=LEARNING_RATE,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=NUM_EPOCHS,
    logging_strategy="steps",
    logging_steps=100,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="f1_macro",
    greater_is_better=True,
    report_to="wandb",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_data["train"],
    eval_dataset=tokenized_data["validation"],
    # `processing_class` replaces the deprecated `tokenizer` argument in the
    # transformers releases that include ModernBERT.
    processing_class=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    callbacks=[WandBCustomCallback()],
)

In [None]:
# Let PyTorch use faster, reduced-precision kernels for float32 matrix
# multiplications where the hardware supports them.
torch.set_float32_matmul_precision('high')
# Run fine-tuning with the configured Trainer.
trainer.train()

In [None]:
val_dataset = tokenized_data["validation"]

# Run prediction over the validation split once and reuse the result below.
# Calling trainer.predict inside the per-label loop repeated the full pass
# once for every label.
val_output = trainer.predict(val_dataset)

# The model emits logits; convert them to probabilities before applying the
# 0.5 threshold so it matches the thresholding in compute_metrics.
val_probs = 1 / (1 + np.exp(-np.asarray(val_output.predictions, dtype=np.float64)))
val_preds = (val_probs >= 0.5).astype(int)
val_true = np.asarray(val_output.label_ids).astype(int)

# One binary confusion matrix per label.
for i, label in enumerate(labels):
    cm = wandb.plot.confusion_matrix(
        y_true=val_true[:, i].tolist(),
        preds=val_preds[:, i].tolist(),
        class_names=["Negative", "Positive"],
        title=f"Confusion Matrix - {label}"
    )
    wandb.log({f"confusion_matrix_{label}": cm})

# Log up to five validation examples with their predictions. Rows are read by
# index: slicing the dataset with [:5] returns a dict of columns, which cannot
# be passed to trainer.predict and iterates over column names, not rows.
for i in range(min(5, len(val_dataset))):
    example = val_dataset[i]
    pred_probs = val_probs[i]
    pred_labels = val_preds[i]

    wandb.log({
        f"example_{i}/url": example["url"],
        f"example_{i}/title": example["title"],
        f"example_{i}/true_labels": example["labels"],
        f"example_{i}/predicted_labels": pred_labels.tolist(),
        f"example_{i}/prediction_probabilities": pred_probs.tolist()
    })

# Save the model
trainer.save_model()
wandb.save(f"{training_args.output_dir}/*")

In [None]:
# Mark the W&B run as finished now that all logging is done.
wandb.finish()