# 1. Imports & Installs

In [None]:
!pip install transformers datasets evaluate

import pandas as pd
import torch
from datasets import Dataset, load_metric
from transformers import (AutoTokenizer, 
                          AutoModelForSequenceClassification, 
                          TrainingArguments, 
                          Trainer)

# 2. Load Processed Data

In [None]:
# Read the pre-split train and test CSVs into DataFrames in one pass.
train_df, test_df = (
    pd.read_csv(csv_path)
    for csv_path in (
        "../data/processed/train.csv",
        "../data/processed/test.csv",
    )
)

# For demonstration, we will assume there's a 'category_label' that is an integer
# representing each category. We might also do the same for 'priority_label'.
# If we only have text labels like 'HR', 'Finance', 'Support',
# we should map them to integer IDs first.

# 3. Convert to Datasets

In [None]:
# Wrap each pandas DataFrame in a Hugging Face `Dataset`.
train_dataset, test_dataset = (
    Dataset.from_pandas(frame) for frame in (train_df, test_df)
)

# 4. Tokenizer

In [None]:
# Checkpoint identifier; reused below to load the classification model.
model_name = "bert-base-multilingual-cased"
# Load the tokenizer that belongs to the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_name)

def tokenize_fn(example):
    """Tokenize the ``full_text`` field of ``example``.

    Inputs are truncated and padded to exactly 256 tokens. Returns whatever
    the module-level ``tokenizer`` produces for that text.
    """
    encode_options = {
        "truncation": True,
        "padding": "max_length",
        "max_length": 256,
    }
    return tokenizer(example["full_text"], **encode_options)

# Tokenize each split in batches, then expose the 'category_label' column
# (assumed to already hold integer ids in the CSV) under the name "labels".
train_dataset = train_dataset.map(tokenize_fn, batched=True).rename_column(
    "category_label", "labels"
)
test_dataset = test_dataset.map(tokenize_fn, batched=True).rename_column(
    "category_label", "labels"
)

# Restrict both splits to the listed columns, returned as torch tensors.
for split in (train_dataset, test_dataset):
    split.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

# 5. Load Pre-trained Model

In [None]:
# Size the classification head from the same integer column that is trained on
# ('category_label', renamed to "labels" in the datasets), not from a separate
# 'category' column that may be missing or out of sync with the ids.
# The largest id + 1 (over both splits) is used rather than a unique count, so
# every label id stays a valid class index even if the ids have gaps or a
# class only appears in the test split. Assumes ids are 0-based integers.
num_labels = int(
    max(train_df["category_label"].max(), test_df["category_label"].max())
) + 1

model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)


# 6. Training Arguments & Trainer

In [None]:
# Fine-tuning configuration: evaluate and checkpoint once per epoch, and keep
# at most two checkpoints under the output directory.
# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy` — confirm against the installed version.
training_args = TrainingArguments(
    # Optimisation schedule
    num_train_epochs=3,
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    # Evaluation, logging and checkpointing
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    logging_steps=100,
    output_dir="../model_output",
)

# `datasets.load_metric` is deprecated (and removed in newer `datasets`
# releases); the `evaluate` package installed in the first cell is its
# replacement. The loaded objects keep the same
# `.compute(predictions=..., references=...)` interface.
import evaluate

accuracy_metric = evaluate.load("accuracy")
precision_metric = evaluate.load("precision")
recall_metric = evaluate.load("recall")
f1_metric = evaluate.load("f1")

def compute_metrics(eval_pred):
    """Turn an evaluation prediction into a dict of scalar metrics.

    ``eval_pred`` unpacks into ``(logits, labels)``. The predicted class is
    the argmax over dimension 1 of the logits. Returns accuracy plus
    weighted-average precision, recall and F1, keyed by metric name.
    """
    logits, labels = eval_pred
    preds = torch.tensor(logits).argmax(dim=1)

    scores = {
        "accuracy": accuracy_metric.compute(predictions=preds, references=labels)["accuracy"],
    }
    # The remaining metrics share the same call shape and weighted averaging.
    weighted_metrics = (
        ("precision", precision_metric),
        ("recall", recall_metric),
        ("f1", f1_metric),
    )
    for key, metric in weighted_metrics:
        result = metric.compute(predictions=preds, references=labels, average="weighted")
        scores[key] = result[key]
    return scores

# Bundle the model, its training configuration, both dataset splits and the
# metric callback into a single Trainer. The test split serves as the
# evaluation set.
trainer = Trainer(
    args=training_args,
    model=model,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
)

# 7. Train the Model

In [None]:
# Run fine-tuning with the Trainer configured above.
trainer.train()

# 8. Save Model

In [None]:
# Write the model and its tokenizer to the same directory.
for artifact in (model, tokenizer):
    artifact.save_pretrained("../model_output/multilingual_model")
print("Model and tokenizer saved to ../model_output/multilingual_model")

## Notes

# 1. If we also want to predict priority, we could use a multi-task approach (one shared encoder with two classification heads) or train a second, separate model on 'priority_label'.
# 2. For Hebrew and English specifically, mBERT or XLM-R typically perform well out-of-the-box.
# 3. For Hebrew specifically, HebrewNLP could be investigated: https://discuss.huggingface.co/t/hebrew-nlp-introduction/4095