# 🗣 Ticket Triage with DistilBERT
Fine-tune a DistilBERT model to classify support tickets into three categories: Account, IT, and Billing.

In [None]:
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification, Trainer, TrainingArguments
from sklearn.model_selection import train_test_split
import torch
import pandas as pd

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [None]:
# Simulate dataset
# Tiny hand-written stand-in for a real ticket export: one free-text field
# and one integer category per ticket.
data = {
    "text": [
        "Password reset not working",
        "Cannot connect to VPN",
        "Request refund for subscription",
        "System outage - urgent",
        "Printer is not working again"
    ],
    "label": [0, 1, 2, 1, 1]  # 0: Account, 1: IT, 2: Billing
}
df = pd.DataFrame(data)

# Hold out 20% of the tickets for validation. The seed is fixed so the same
# ticket is held out on every run; without it the split (and therefore the
# reported evaluation loss) changes each time the cell is executed.
train_texts, val_texts, train_labels, val_labels = train_test_split(
    df["text"].tolist(),
    df["label"].tolist(),
    test_size=0.2,
    random_state=42,
)


In [None]:
# Tokenizer matching the "distilbert-base-uncased" checkpoint.
tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")

# Encode both splits with identical settings: truncation and padding enabled.
# Each split is encoded in its own tokenizer call, training texts first.
train_encodings, val_encodings = (
    tokenizer(split_texts, truncation=True, padding=True)
    for split_texts in (train_texts, val_texts)
)

class TicketDataset(torch.utils.data.Dataset):
    """Dataset that pairs per-field encodings with one label per sample.

    ``encodings`` is a mapping from field name to an indexable sequence of
    per-sample values; ``labels`` is an indexable sequence whose length
    defines the dataset size.
    """

    def __init__(self, encodings, labels):
        """Store the encodings mapping and the label sequence as given."""
        self.labels = labels
        self.encodings = encodings

    def __len__(self):
        """Return the number of samples, taken from the label sequence."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict of tensors, including ``'labels'``."""
        sample = {}
        for field, values in self.encodings.items():
            sample[field] = torch.tensor(values[idx])
        # The label goes in last under the 'labels' key.
        sample['labels'] = torch.tensor(self.labels[idx])
        return sample

# Wrap each split's encodings and labels in a TicketDataset.
train_dataset, val_dataset = (
    TicketDataset(train_encodings, train_labels),
    TicketDataset(val_encodings, val_labels),
)


In [None]:
# Load the "distilbert-base-uncased" checkpoint with a 3-label classification
# head (one output per ticket category) and move it to the selected device.
model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=3).to(device)

num_epochs = 3
batch_size = 2

# Size the warm-up and logging intervals from the data instead of hard-coding
# 10 steps. On the toy dataset the whole run is only a handful of optimizer
# steps, so a 10-step warm-up would never finish and a 10-step logging
# interval would never fire. Both values are capped at the original 10, so
# larger datasets behave as before.
# Approximation: assumes a single device and no gradient accumulation.
steps_per_epoch = max(1, -(-len(train_dataset) // batch_size))  # ceiling division
total_steps = steps_per_epoch * num_epochs
warmup_steps = min(10, max(1, total_steps // 10))
logging_steps = min(10, steps_per_epoch)

training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=num_epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    warmup_steps=warmup_steps,
    # NOTE(review): recent transformers releases rename this argument to
    # `eval_strategy`; confirm against the installed version.
    evaluation_strategy="epoch",
    logging_dir="./logs",
    logging_steps=logging_steps,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

# Fine-tune; the validation set is evaluated at the end of every epoch.
trainer.train()
