In [None]:
# notebooks/2_classifier_training.ipynb

import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification, Trainer, TrainingArguments
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from nlpaug.augmenter.word import SynonymAug
import numpy as np

# Load labeled dataset (assume manual labeling or better sampling)
df = pd.read_csv('../data/labeled_emails.csv')

# Rows without text or without a label can be neither augmented nor encoded.
# Drop them up front so a missing value never reaches the tokenizer as the
# literal string 'nan'.
df = df.dropna(subset=['email_text', 'label']).reset_index(drop=True)

# Data augmentation
aug = SynonymAug(aug_p=0.3)
def augment_text(text):
    """Return a synonym-augmented copy of `text`.

    Non-string values are returned unchanged. If the augmenter yields nothing
    (empty result), the original text is returned instead of raising.
    """
    if not isinstance(text, str):
        return text
    augmented = aug.augment(text)
    # NOTE(review): the return type of `augment` is assumed to differ between
    # nlpaug releases (list vs. plain string) - both are handled here; confirm
    # against the installed version.
    if isinstance(augmented, (list, tuple)):
        return augmented[0] if augmented else text
    return augmented

df['augmented_text'] = df['email_text'].apply(augment_text)

# Stack the originals and their augmented copies under one 'email_text' column.
# The two needed columns are selected BEFORE renaming: renaming on the full
# frame would leave two columns both named 'email_text'.
augmented_df = df[['augmented_text', 'label']].rename(columns={'augmented_text': 'email_text'})
df = pd.concat([df[['email_text', 'label']], augmented_df], ignore_index=True)

# NOTE(review): augmentation happens before the train/validation split done
# later in this notebook, so an augmented copy of a validation email can end
# up in the training set. Validation scores may therefore be optimistic.

# Encode labels
le = LabelEncoder()
df['label_enc'] = le.fit_transform(df['label'])

# Dataset class
class EmailDataset(Dataset):
    """Torch dataset that tokenizes one email per item for sequence classification.

    Args:
        texts: Sequence of email bodies (list, numpy array or pandas Series).
        labels: Sequence of integer class ids, same length as `texts`.
        tokenizer: Callable tokenizer accepting the Hugging Face keyword
            arguments used in `__getitem__` and returning a mapping with
            'input_ids' and 'attention_mask' tensors.
        max_len: Length every example is padded / truncated to.

    Raises:
        ValueError: If `texts` and `labels` differ in length.
    """

    def __init__(self, texts, labels, tokenizer, max_len=128):
        # Materialize as plain lists so that `self.texts[idx]` is always
        # positional. A pandas Series would otherwise be indexed by label,
        # which fails (or silently picks the wrong row) after a shuffle/split.
        self.texts = list(texts)
        self.labels = list(labels)
        if len(self.texts) != len(self.labels):
            raise ValueError(
                f"texts and labels must have the same length, "
                f"got {len(self.texts)} and {len(self.labels)}"
            )
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        """Number of examples in the dataset."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return the tokenized example at position `idx`.

        Returns:
            dict with 1-D 'input_ids' and 'attention_mask' tensors and a
            scalar long 'labels' tensor.
        """
        text = str(self.texts[idx])
        # Call the tokenizer directly; `encode_plus` is deprecated in favour
        # of `__call__` and takes the same keyword arguments.
        inputs = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_token_type_ids=False,
            return_attention_mask=True,
            return_tensors='pt',
        )
        return {
            # return_tensors='pt' adds a leading batch axis of size 1; flatten it away.
            'input_ids': inputs['input_ids'].flatten(),
            'attention_mask': inputs['attention_mask'].flatten(),
            # int() accepts numpy integer scalars as well as Python ints.
            'labels': torch.tensor(int(self.labels[idx]), dtype=torch.long),
        }

# Train/validation split
from sklearn.model_selection import train_test_split
train_texts, val_texts, train_labels, val_labels = train_test_split(
    df['email_text'].values,
    df['label_enc'].values,
    test_size=0.2,
    random_state=42,
    # Stratify so both splits keep the class proportions of the full dataset;
    # an unstratified split of a small multi-class set can leave a class
    # under-represented in validation and distort the weighted metrics.
    stratify=df['label_enc'].values,
)

# Tokenizer and dataset
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
train_dataset = EmailDataset(train_texts, train_labels, tokenizer)
val_dataset = EmailDataset(val_texts, val_labels, tokenizer)

# Load the pre-trained encoder with a classification head sized to the label set.
# (Mixed precision is not configured here; it is controlled by the `fp16`
# flag passed to TrainingArguments.)
model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=len(le.classes_))

# Metrics function
def compute_metrics(pred):
    """Compute accuracy and weighted precision/recall/F1 for a Trainer evaluation.

    Args:
        pred: Object exposing `label_ids` (true class ids) and `predictions`
            (per-class scores, or a tuple whose first element is the scores).

    Returns:
        dict with keys 'accuracy', 'f1', 'precision' and 'recall'.
    """
    labels = pred.label_ids
    scores = pred.predictions
    # NOTE(review): some models return a tuple of outputs; the class scores are
    # assumed to be the first element in that case - confirm for other models.
    if isinstance(scores, tuple):
        scores = scores[0]
    preds = np.argmax(scores, axis=-1)
    # zero_division=0 keeps the score at 0 for classes that are never
    # predicted, without emitting an undefined-metric warning on every
    # evaluation pass.
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average='weighted', zero_division=0
    )
    acc = accuracy_score(labels, preds)
    return {'accuracy': acc, 'f1': f1, 'precision': precision, 'recall': recall}

# Training arguments with optimization
training_args = TrainingArguments(
    output_dir='../models/intent_classifier',
    num_train_epochs=5,  # Increased epochs
    per_device_train_batch_size=8,  # Reduced batch size for stability
    per_device_eval_batch_size=8,
    eval_strategy='epoch',
    save_strategy='epoch',  # must match eval_strategy for load_best_model_at_end
    logging_dir='../logs',
    logging_steps=10,
    learning_rate=2e-5,  # Tuned learning rate
    # Mixed precision only when a CUDA device is present; requesting fp16 on a
    # CPU-only machine makes the Trainer setup fail instead of training.
    fp16=torch.cuda.is_available(),
    load_best_model_at_end=True,
    metric_for_best_model='f1'
)

# Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics
)

# Train
trainer.train()

# Save model and tokenizer side by side so the directory can be reloaded with
# from_pretrained() for inference.
model.save_pretrained('../models/intent_classifier')
tokenizer.save_pretrained('../models/intent_classifier')

print("Training Complete & Model Saved!")

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall
1,1.6077,1.593578,0.26,0.107302,0.0676,0.26
2,1.5999,1.592909,0.26,0.107302,0.0676,0.26
3,1.5966,1.593558,0.26,0.107302,0.0676,0.26


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Training Complete & Model Saved!
