# 📦 Importing Packages

In [None]:
import gc
import os
import shutil
from collections import defaultdict
from datetime import datetime

import emoji
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch

from datasets import Dataset
from scipy import stats
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    confusion_matrix,
    f1_score,
    precision_score,
    recall_score,
)
from sklearn.model_selection import GroupKFold, GroupShuffleSplit
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    EarlyStoppingCallback,
    Trainer,
    TrainingArguments,
)

# ⚙️ Global Settings

In [None]:
if os.path.exists("/kaggle"):
    # Kaggle-only setup.
    # Turn off the HuggingFace `tokenizers` library's internal parallelism
    # (presumably to avoid its fork warnings with DataLoader workers).
    os.environ.update({"TOKENIZERS_PARALLELISM": "false"})
    # Start every session from an empty /kaggle/working: delete it
    # (ignoring errors) and recreate the directory.
    shutil.rmtree("/kaggle/working", ignore_errors=True)
    os.makedirs("/kaggle/working", exist_ok=True)

## Hyper-Parameters

In [None]:
SELECTED_BERT_MODEL = 'dbmdz/bert-base-italian-cased' #110M
# SELECTED_BERT_MODEL = 'FacebookAI/xlm-roberta-base' # 278M
# SELECTED_BERT_MODEL = 'Musixmatch/umberto-commoncrawl-cased-v1' #110M
DATASET_TYPE = 'toxicity'

TASK = 'multiclass'  # 'binary' or 'multiclass'

if TASK == 'binary':
    POLARITY_BINS = [-1.01, -0.35, 1.01]
    POLARITY_LABELS = [0, 1]  # 0: Toxic, 1: Healthy
    TARGET_NAMES = ['Toxic', 'Healthy']
    NUM_LABELS = 2
else:
    POLARITY_BINS = [-1.01, -0.35, 0.35, 1.01]
    POLARITY_LABELS = [0, 1, 2]
    TARGET_NAMES = ['Toxic', 'Neutral', 'Healthy']
    NUM_LABELS = 3
    # cost_mat[i, j] = cost of predicting class j when the true class is i
    # (rows/columns ordered like TARGET_NAMES). Its total is the
    # model-selection metric for the multiclass task (lower is better).
    cost_mat = np.array([
        [0, 8, 16],
        [8, 0, 1],
        [16, 4, 0]
    ])

WITH_SEP_TOKENS = False
WITH_TOKEN_TYPE_IDS = True

BATCH_SIZE = 32
NUM_EPOCHS = 30
GRADIENT_ACCUMULATION_STEPS = 4
WARMUP_PERCENTAGE = 0.1
NUM_WORKERS = 0
SAVE_TOTAL_LIMIT = 2
EARLY_STOPPING_PATIENCE = 4
# EARLY_STOPPING_THRESHOLD = 0.0005

# The Trainer steps ReduceLROnPlateau with the model-selection metric
# (metric_for_best_model in get_trainer). That metric is "cost" for multiclass
# (lower is better) and "f1_weighted" for binary (higher is better), so the
# mode must follow its direction. A fixed "max" made the scheduler cut the
# learning rate while the multiclass cost was still improving.
LR_SCHEDULER_KWARGS = {
    "factor": 0.5,        # Halve the learning rate when the metric stops improving
    "patience": 2,
    # "threshold": EARLY_STOPPING_THRESHOLD,
    "mode": "min" if TASK == "multiclass" else "max",
}

LEARNING_RATE = 3e-5
WEIGHT_DECAY = 0.001
MAX_LENGTH = 512

In [None]:
import random

GLOBAL_SEED = 42


def set_seed(seed):
    """Seed every random number generator used by the notebook.

    Seeds Python's `random`, NumPy's global RNG and PyTorch. When CUDA is
    available it also seeds all GPUs, forces deterministic cuDNN kernels and
    disables cuDNN autotuning. Those two cuDNN settings trade some speed for
    repeatable results.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(seed)
    # Deterministic results on CUDA (may reduce performance).
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(GLOBAL_SEED)

## Path Settings

In [None]:
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
MODEL_NAME = SELECTED_BERT_MODEL.replace('/', '-')

# Run folder name shared by both layouts: start time, model and the two
# input-format switches.
_run_dir_name = f"{timestamp}-{MODEL_NAME}-Sep_{WITH_SEP_TOKENS}-Type_{WITH_TOKEN_TYPE_IDS}"

if not os.path.exists("/kaggle"):
    # ==== LOCAL SETTINGS ====
    _task_dir = os.path.join(".", "out", "models", DATASET_TYPE, f"entire-chat-{TASK}-classification")
    PATH = os.path.join(".", "out", "datasets", f"cipv-chats-{TASK}-{DATASET_TYPE}.parquet")
    OUT_DIR = os.path.join(_task_dir, _run_dir_name)
    NESTED_CV_RESULTS_PATH = os.path.join(_task_dir, "nested-cv-results", timestamp)
else:
    # ==== KAGGLE SETTINGS ====
    PATH = os.path.join(os.sep, "kaggle", "input", f"cipv-chats-{DATASET_TYPE}", f"cipv-chats-{TASK}-{DATASET_TYPE}.parquet")
    OUT_DIR = os.path.join(os.sep, "kaggle", "working", _run_dir_name)
    # On Kaggle the nested-CV results live inside the run folder.
    NESTED_CV_RESULTS_PATH = os.path.join(OUT_DIR, "nested-cv-results", timestamp)

RESULTS_PATH = os.path.join(OUT_DIR, "results")

# 🛠️ Utility Functions

In [None]:
def plot_aggregated_curves(log_histories, out_dir):
    """
    Plot training curves aggregated (mean ± std) across cross-validation folds.

    Args:
        log_histories: one Trainer ``state.log_history`` per fold, i.e. a list
            of lists of log dicts. Entries are grouped by their 'epoch' value
            rounded to the nearest integer. The function collects 'loss',
            'eval_loss' and 'eval_cost' values.
        out_dir: directory (created if missing) where the figures are saved.

    Saves "aggregated_learning_curve.png" (train and eval loss). When TASK is
    'multiclass' and at least one 'eval_cost' was logged, it also saves
    "aggregated_cost_curve.png", with a vertical marker at the epoch of
    minimum mean cost.

    Note: folds that stopped early only contribute to the epochs they reached,
    so later epochs may be averaged over fewer folds.
    """
    os.makedirs(out_dir, exist_ok=True)

    train_losses_by_epoch = defaultdict(list)
    eval_losses_by_epoch = defaultdict(list)
    eval_costs_by_epoch = defaultdict(list)

    for history in log_histories:
        for log in history:
            epoch = log.get('epoch')
            if epoch is None:
                continue
            # Logged epochs are floats; bucket them by the nearest whole epoch.
            epoch = int(round(epoch))

            if 'loss' in log:
                train_losses_by_epoch[epoch].append(log['loss'])
            if 'eval_loss' in log:
                eval_losses_by_epoch[epoch].append(log['eval_loss'])
            if 'eval_cost' in log:
                eval_costs_by_epoch[epoch].append(log['eval_cost'])

    def mean_std(values_by_epoch, epochs):
        # Per-epoch mean and (population) std across folds, as NumPy arrays.
        means = np.array([np.mean(values_by_epoch[e]) for e in epochs])
        stds = np.array([np.std(values_by_epoch[e]) for e in epochs])
        return means, stds

    # --- Plotting Aggregated Loss Curve ---
    eval_epochs = sorted(eval_losses_by_epoch.keys())
    # Plot train loss only at epochs that also have an eval loss, so both
    # curves share the same x positions.
    train_epochs = [e for e in eval_epochs if e in train_losses_by_epoch]
    mean_eval_loss, std_eval_loss = mean_std(eval_losses_by_epoch, eval_epochs)
    mean_train_loss, std_train_loss = mean_std(train_losses_by_epoch, train_epochs)

    plt.figure(figsize=(10, 6))
    plt.plot(eval_epochs, mean_eval_loss, 'c-o', label='Mean Eval Loss')
    plt.fill_between(
        eval_epochs,
        mean_eval_loss - std_eval_loss,
        mean_eval_loss + std_eval_loss,
        color='c', alpha=0.2
    )
    plt.plot(train_epochs, mean_train_loss, 'g-o', label='Mean Train Loss')
    plt.fill_between(
        train_epochs,
        mean_train_loss - std_train_loss,
        mean_train_loss + std_train_loss,
        color='g', alpha=0.2
    )

    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Aggregated Learning Curve (Mean ± Std)')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(os.path.join(out_dir, "aggregated_learning_curve.png"))
    plt.show()
    plt.close()

    # --- Plotting Aggregated Cost Curve ---
    if TASK != 'multiclass':
        return
    if not eval_costs_by_epoch:
        # np.argmin below would raise on an empty sequence.
        print("No 'eval_cost' entries found in the log histories; skipping the cost curve.")
        return

    cost_epochs = sorted(eval_costs_by_epoch.keys())
    mean_eval_cost, std_eval_cost = mean_std(eval_costs_by_epoch, cost_epochs)

    plt.figure(figsize=(10, 6))
    plt.plot(cost_epochs, mean_eval_cost, 'r-o', label='Mean Eval Cost')
    plt.fill_between(
        cost_epochs,
        mean_eval_cost - std_eval_cost,
        mean_eval_cost + std_eval_cost,
        color='r', alpha=0.2
    )

    min_cost_epoch = cost_epochs[int(np.argmin(mean_eval_cost))]
    plt.axvline(
        x=min_cost_epoch, color='green', linestyle='--', alpha=0.7,
        label=f'Min Mean Cost at Epoch {min_cost_epoch}'
    )

    plt.xlabel('Epochs')
    plt.ylabel('Cost')
    plt.title('Aggregated Evaluation Cost (Mean ± Std)')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(os.path.join(out_dir, "aggregated_cost_curve.png"))
    plt.show()
    plt.close()

# 📂 Dataset Loading

In [None]:
def get_formatted_msg(msg):
    """Render one message as "<user>:\\n<content>", with emoji replaced by
    their Italian text aliases via ``emoji.demojize``."""
    header_and_body = "{}:\n{}".format(msg['user'], msg['content'])
    return emoji.demojize(header_and_body, language='it')

def preprocess(messages):
    """
    Flatten one chat (a sequence of messages) into a single input string.

    Each message is rendered by ``get_formatted_msg``. When
    WITH_TOKEN_TYPE_IDS is enabled, every message is prefixed with a speaker
    marker "[k]", where k is the speaker's index in order of first appearance
    in the chat (first speaker -> "[0]", second -> "[1]"). Messages are joined
    with "\\n", or with "[SEP]\\n" when WITH_SEP_TOKENS is enabled. The result
    is prefixed with a literal "[CLS]".

    Args:
        messages: iterable (iterated twice) of message mappings with 'user'
            and 'content' keys.

    Returns:
        str: the formatted chat.
    """
    if WITH_TOKEN_TYPE_IDS:
        # Assign speaker indices by order of first appearance. The previous
        # list(set(...)) ordering depends on hash randomisation for string
        # user ids, so the same speaker could be "[0]" in one run and "[1]"
        # in the next, which made token_type_ids non-reproducible.
        # NOTE(review): only "[0]" and "[1]" are registered as special tokens
        # later, so chats are assumed to have at most two speakers.
        speaker_order = dict.fromkeys(msg['user'] for msg in messages)
        speaker_index = {user: k for k, user in enumerate(speaker_order)}
        all_messages = [
            f"[{speaker_index[msg['user']]}]" + get_formatted_msg(msg)
            for msg in messages
        ]
    else:
        all_messages = [get_formatted_msg(msg) for msg in messages]

    separator = "[SEP]\n" if WITH_SEP_TOKENS else "\n"
    return "[CLS]" + separator.join(all_messages)

In [None]:
df = pd.read_parquet(PATH)
# DataFrame.info() prints its summary itself and returns None; wrapping it in
# print() used to emit an extra "None" line.
df.info()
# Replace each chat's list of messages with a single formatted string.
df['messages'] = df['messages'].apply(preprocess)
dataset = Dataset.from_pandas(df)

In [None]:
def print_dataset_info(dataset):
    """Print a dataset's summary, then each column's value in the first row."""
    print(dataset)
    # One entry per column: "<column>: <row-0 value>" followed by a blank line.
    for column in dataset.features:
        first_value = dataset[0][column]
        print(f"{column}: {first_value}\n")


print_dataset_info(dataset)

# 🪄 Dataset Preprocessing

In [None]:
tokenizer = AutoTokenizer.from_pretrained(SELECTED_BERT_MODEL)
if WITH_TOKEN_TYPE_IDS:
    # Register the speaker markers as special tokens so the tokenizer keeps
    # each one as a single id; the tokenizing `preprocess` strips these ids
    # and turns them into token_type_ids.
    tokenizer.add_special_tokens({"additional_special_tokens": ["[0]", "[1]"]})
    id_0, id_1 = (tokenizer.convert_tokens_to_ids(marker) for marker in ("[0]", "[1]"))
    print(f"Special tokens added: {id_0} for [0], {id_1} for [1]")

def preprocess(examples):
    """
    Tokenize a batch of formatted chats (for ``Dataset.map(batched=True)``).

    ``examples`` is a batch with the columns 'messages' (chat strings),
    'labels' and 'couple_ids'. Labels and couple_ids are copied into the
    output unchanged. Chats are tokenized with ``add_special_tokens=False``
    because every chat string already starts with a literal "[CLS]".

    When WITH_TOKEN_TYPE_IDS is enabled, the speaker-marker ids ``id_0`` /
    ``id_1`` ("[0]" / "[1]") are removed from ``input_ids`` and encoded as
    ``token_type_ids``. Tokens after a "[1]" marker get type 1; tokens after a
    "[0]" marker, and any tokens before the first marker, get type 0.

    Raises:
        ValueError: if a tokenized chat is longer than MAX_LENGTH (no
            truncation is applied). This check now runs in both modes; it
            used to be skipped when WITH_TOKEN_TYPE_IDS was disabled.
    """
    tokenized_chats = tokenizer(
        examples['messages'],
        add_special_tokens=False,  # "[CLS]" is already part of the chat text
    )
    tokenized_chats["labels"] = examples["labels"]
    tokenized_chats["couple_ids"] = examples["couple_ids"]

    if WITH_TOKEN_TYPE_IDS:
        batch_ids = tokenized_chats['input_ids']
        batch_masks = tokenized_chats['attention_mask']
        for i, (sample, sample_mask) in enumerate(zip(batch_ids, batch_masks)):
            new_input_ids = []
            new_attention_mask = []
            new_token_type_ids = []
            tkn_type = 0  # tokens before the first speaker marker get type 0

            for tkn, mask_value in zip(sample, sample_mask):
                if tkn == id_0:
                    tkn_type = 0  # switch speaker; the marker itself is dropped
                elif tkn == id_1:
                    tkn_type = 1  # switch speaker; the marker itself is dropped
                else:
                    new_input_ids.append(tkn)
                    new_attention_mask.append(mask_value)
                    new_token_type_ids.append(tkn_type)

            # Replace this sample's lists with the marker-free versions
            tokenized_chats['input_ids'][i] = new_input_ids
            tokenized_chats['attention_mask'][i] = new_attention_mask
            tokenized_chats['token_type_ids'][i] = new_token_type_ids

    # Nothing truncates the chats, so reject any that the model cannot take.
    for chat_ids in tokenized_chats['input_ids']:
        if len(chat_ids) > MAX_LENGTH:
            raise ValueError(
                f"Chat length exceeds MAX_LENGTH = {MAX_LENGTH}.\n"
                f"Chat content:\n{chat_ids}"
            )

    return tokenized_chats

# Tokenize every chat with the batch-level `preprocess`. All original columns
# are dropped; `preprocess` re-emits labels and couple_ids next to the
# tokenizer outputs.
tokenized_dataset = dataset.map(
    preprocess,
    batched=True,
    batch_size=BATCH_SIZE,  # map() batch size (reuses the training batch size)
    remove_columns=dataset.column_names
)
print_dataset_info(tokenized_dataset)
# print(tokenized_dataset)

# Intended to remove the speaker-marker special tokens from the tokenizer.
# NOTE(review): this call does not take "[0]"/"[1]" out of the vocabulary.
# Depending on the transformers version, an empty list is either a no-op or
# only resets the `additional_special_tokens` attribute. The added tokens
# presumably stay registered and are saved with the tokenizer later; confirm
# with len(tokenizer) / tokenizer.get_added_vocab().
tokenizer.add_special_tokens({
    "additional_special_tokens": []
})

# 📄 Methodology

In [None]:
def confidence_interval(scores, confidence_level=0.95):
    """
    Two-sided Student-t confidence interval for the mean of ``scores``.

    Typical use: judging how reliable the mean of per-fold cross-validation
    scores is.

    Args:
        scores (list or np.ndarray): one score per cross-validation fold.
        confidence_level (float): coverage of the interval, e.g. 0.95 for 95%.

    Returns:
        tuple: ``(mean, lower_bound, upper_bound)``. With fewer than two
        scores no spread can be estimated, so both bounds are NaN.
    """
    mean_score = np.mean(scores)
    sample_count = len(scores)
    if sample_count < 2:
        return (mean_score, np.nan, np.nan)

    # Half-width = t_{(1+c)/2, n-1} * SEM, where SEM = s / sqrt(n) and s is
    # the sample standard deviation (ddof=1).
    upper_quantile = (1 + confidence_level) / 2.
    t_value = stats.t.ppf(upper_quantile, sample_count - 1)
    half_width = t_value * stats.sem(scores)
    return (mean_score, mean_score - half_width, mean_score + half_width)

def print_and_save_classification_report_conf_intervals(cv_results, save_path, confidence=0.95):
    """
    Print a cross-validated classification report and save it to
    ``<save_path>/classification_report_with_cv.txt``.

    Every cell reads "mean ± std [lower, upper]". The std is NumPy's default
    (population) standard deviation across folds. The bounds come from
    ``confidence_interval`` at the requested ``confidence`` level.

    Args:
        cv_results: mapping (e.g. a DataFrame) from column names such as
            'test_precision_<class>', 'test_f1_macro', 'test_accuracy' and,
            for the multiclass task, 'test_cost' to per-fold scores.
        save_path: existing directory for the text report.
        confidence: confidence level of the intervals.
    """
    def summarize(scores):
        # "mean ± std [lower, upper]" for one metric across folds.
        mean, lower, upper = confidence_interval(scores, confidence)
        return f"{mean:.2f} ± {np.std(scores):.2f} [{lower:.2f}, {upper:.2f}]"

    def metric_row(row_label, key_suffix):
        # One report row: precision, recall and F1 columns for `key_suffix`.
        cells = [
            summarize(cv_results[f'test_{metric}_{key_suffix}'])
            for metric in ('precision', 'recall', 'f1')
        ]
        return f"{row_label:>14} " + " ".join(f"{cell:>27}" for cell in cells)

    title = f"=== Cross-Validation Results (Mean ± Std [{confidence * 100:.0f}% CI]) ==="
    report_path = os.path.join(save_path, "classification_report_with_cv.txt")
    with open(report_path, "w", encoding="utf-8") as f:
        f.write(title + "\n\n")
        print(title + "\n")

        lines = [f"{'':>14} {'precision':>27} {'recall':>27} {'f1-score':>27}", ""]
        lines += [metric_row(name, name.lower()) for name in TARGET_NAMES]
        lines.append("")

        accuracy_cell = summarize(cv_results['test_accuracy'])
        lines.append(f"{'accuracy':>14} {'':>27} {'':>27} {accuracy_cell:>27}")

        lines += [metric_row(avg_type + ' avg', avg_type) for avg_type in ('macro', 'weighted')]

        if TASK == 'multiclass':
            lines += ["", f"Total Cost: {summarize(cv_results['test_cost'])}"]

        report_text = "\n".join(lines)
        f.write(report_text)
        print(report_text)

def plot_confusion_matrices(cms, classes, path=None):
    """
    Plot a single heatmap of the cell-wise mean ± std of several confusion matrices.

    Cell colour is the mean count, and each cell is annotated "mean ± std".

    Args:
        cms: sequence of confusion matrices (e.g. one per CV fold), each of
            shape (len(classes), len(classes)). Rows are true labels and
            columns are predicted labels.
        classes: class names used as tick labels, in label-index order.
        path: optional output file. When given, the figure is saved there
            (dpi=300) before being shown.

    Raises:
        ValueError: if ``cms`` is empty or any matrix has the wrong shape
            (e.g. a fold's matrix was built without a fixed label set and is
            missing a class).
    """
    n_classes = len(classes)
    expected_shape = (n_classes, n_classes)

    matrices = [np.asarray(cm) for cm in cms]
    if not matrices:
        raise ValueError("plot_confusion_matrices needs at least one confusion matrix")
    wrong_shapes = [m.shape for m in matrices if m.shape != expected_shape]
    if wrong_shapes:
        raise ValueError(
            f"Every confusion matrix must have shape {expected_shape}; got {wrong_shapes}"
        )

    # Calculate mean and std for each cell across the matrices
    cm_stack = np.stack(matrices)
    cm_mean = cm_stack.mean(axis=0)
    cm_std = cm_stack.std(axis=0)

    plt.figure(figsize=(7, 5))

    # Annotations in "mean ± std" format
    annotations = np.empty_like(cm_mean, dtype=object)
    for i in range(n_classes):
        for j in range(n_classes):
            annotations[i, j] = f'{cm_mean[i, j]:.1f} ± {cm_std[i, j]:.2f}'

    sns.heatmap(
        cm_mean,
        annot=annotations,
        fmt='',
        cmap=sns.color_palette("ch:s=-.2,r=.6", as_cmap=True),
        xticklabels=classes,
        yticklabels=classes,
        cbar_kws={'label': 'Mean Count'}
    )
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.title('Confusion Matrix (Mean ± Std)')
    plt.tight_layout()

    if path:
        plt.savefig(path, dpi=300, bbox_inches='tight')
    plt.show()
    plt.close()


# ⚙️ Training Settings

In [None]:
def calculate_total_cost(y_true, y_pred, cost_matrix=None):
    """
    Total misclassification cost of ``y_pred`` with respect to ``y_true``.

    Args:
        y_true: true labels, assumed to be integers 0 .. K-1.
        y_pred: predicted labels, same encoding.
        cost_matrix: optional (K, K) array where cost_matrix[i, j] is the cost
            of predicting class j when the true class is i. Defaults to the
            module-level ``cost_mat``, which is only defined when TASK is
            'multiclass'.

    Returns:
        The sum over all samples of cost_matrix[true, predicted], as a NumPy
        scalar.
    """
    if cost_matrix is None:
        cost_matrix = cost_mat
    cost_matrix = np.asarray(cost_matrix)

    # Fix the label set to 0..K-1 so the confusion matrix is always K x K and
    # lines up with cost_matrix, even when a class is absent from the inputs.
    labels = np.arange(cost_matrix.shape[0])
    conf_mat = confusion_matrix(y_true, y_pred, labels=labels)

    return np.sum(conf_mat * cost_matrix)

def compute_metrics(eval_pred):
    """
    Metrics callback for the HuggingFace ``Trainer``.

    Args:
        eval_pred: ``(predictions, labels)`` pair. ``predictions`` holds
            per-class scores/logits, so the predicted class is the argmax over
            axis 1.

    Returns:
        dict[str, float]: 'accuracy' plus weighted 'precision', 'recall' and
        'f1_weighted' (zero_division=0 turns undefined per-class scores into
        0). When TASK is 'multiclass' it also contains 'cost', the total
        misclassification cost (lower is better).
    """
    predictions, labels = eval_pred
    if isinstance(predictions, tuple):
        # Models with extra outputs hand over a tuple; presumably the logits come first.
        predictions = predictions[0]
    predictions = np.argmax(predictions, axis=1)

    # Plain Python numbers. The single-element torch tensors used before only
    # worked because the Trainer converts tensors back to scalars internally.
    result = {
        'accuracy': float(accuracy_score(labels, predictions)),
        'precision': float(precision_score(labels, predictions, average='weighted', zero_division=0)),
        'recall': float(recall_score(labels, predictions, average='weighted', zero_division=0)),
        'f1_weighted': float(f1_score(labels, predictions, average='weighted', zero_division=0)),
    }

    if TASK == 'multiclass':
        result['cost'] = float(calculate_total_cost(labels, predictions))

    return result

In [None]:
def get_trainer(tokenized_train_set, tokenized_eval_set, out_dir):
    """
    Build a freshly initialised model and a ``Trainer`` for one CV fold.

    Args:
        tokenized_train_set: tokenized training split.
        tokenized_eval_set: validation split, evaluated every epoch. It drives
            early stopping, the reduce-on-plateau LR schedule and the choice of
            the best checkpoint.
        out_dir: directory for this fold's checkpoints.

    Returns:
        transformers.Trainer: ready for ``train()``. At the end of training it
        reloads the best checkpoint: lowest eval cost for multiclass, highest
        weighted F1 for binary.

    NOTE(review): with lr_scheduler_type="reduce_lr_on_plateau", transformers
    may ignore warmup_ratio; confirm against the installed version.
    """
    # Free unused RAM and GPU memory from previous runs. gc.collect() must run
    # first: it releases unreachable tensors, and only then can empty_cache()
    # hand their cached CUDA memory back to the driver.
    gc.collect()
    torch.cuda.empty_cache()

    training_args = TrainingArguments(
        output_dir=out_dir,
        num_train_epochs=NUM_EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,

        gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
        learning_rate=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
        warmup_ratio=WARMUP_PERCENTAGE,

        lr_scheduler_type="reduce_lr_on_plateau",
        lr_scheduler_kwargs=LR_SCHEDULER_KWARGS,

        eval_strategy="epoch",
        save_strategy="epoch",

        report_to="none",
        logging_strategy="epoch",

        load_best_model_at_end=True,

        # Model selection: minimise total cost (multiclass) or maximise
        # weighted F1 (binary).
        metric_for_best_model="cost" if TASK == "multiclass" else "f1_weighted",
        greater_is_better=False if TASK == "multiclass" else True,
        save_total_limit=SAVE_TOTAL_LIMIT,

        dataloader_num_workers=NUM_WORKERS,
        fp16=torch.cuda.is_available(),
        group_by_length=True,
        remove_unused_columns=True,

        seed=GLOBAL_SEED,
        data_seed=GLOBAL_SEED
    )

    # Pads each batch dynamically to its longest sequence.
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    model = AutoModelForSequenceClassification.from_pretrained(
        SELECTED_BERT_MODEL,
        num_labels=NUM_LABELS,
        problem_type="single_label_classification"
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train_set,
        eval_dataset=tokenized_eval_set,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(
            early_stopping_patience=EARLY_STOPPING_PATIENCE,
        )]
    )

    return trainer

# 🤖 Fine-Tuning BERT

In [None]:
os.makedirs(RESULTS_PATH, exist_ok=True)
os.makedirs(NESTED_CV_RESULTS_PATH, exist_ok=True)

# Outer CV: GroupKFold on couple_ids keeps all chats of a couple on the same
# side of every train/test split.
outer_cv = GroupKFold(n_splits=5)
# Fixed label set. Every fold's report then contains all TARGET_NAMES and
# every confusion matrix is NUM_LABELS x NUM_LABELS, even when a class is
# missing from a fold's test labels and predictions. Without it,
# classification_report raises on the target_names mismatch and the
# confusion matrices cannot be averaged.
class_labels = list(range(NUM_LABELS))
confusion_matrices = []
reports = []
all_log_histories = []
df = tokenized_dataset.to_pandas()
for fold, (train_idx, test_idx) in enumerate(outer_cv.split(
    X=df[['input_ids', 'token_type_ids', 'attention_mask']],
    y=df['labels'],
    groups=df['couple_ids']
)):
    print(f"Starting fold {fold + 1}/{outer_cv.get_n_splits()}")

    # Subset the datasets for the current fold
    fold_train_dataset = tokenized_dataset.select(train_idx)
    fold_test_dataset = tokenized_dataset.select(test_idx)

    # Inner split: hold out 20% of the outer-training couples as the
    # validation set used for early stopping and checkpoint selection.
    gss_val = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=GLOBAL_SEED)
    inner_train_idx, eval_idx = next(gss_val.split(
        X=fold_train_dataset,
        y=fold_train_dataset['labels'],
        groups=fold_train_dataset['couple_ids']
    ))
    fold_eval_dataset = fold_train_dataset.select(eval_idx)
    fold_train_dataset = fold_train_dataset.select(inner_train_idx)

    # couple_ids was only needed for grouping; drop it before training
    fold_train_dataset = fold_train_dataset.remove_columns(['couple_ids'])
    fold_eval_dataset = fold_eval_dataset.remove_columns(['couple_ids'])
    fold_test_dataset = fold_test_dataset.remove_columns(['couple_ids'])

    # Set the format to PyTorch tensors
    fold_train_dataset.set_format("torch")
    fold_eval_dataset.set_format("torch")
    fold_test_dataset.set_format("torch")

    # Define a unique output directory for this fold
    fold_output_dir = os.path.join(OUT_DIR, f"fold_{fold+1}")
    trainer = get_trainer(fold_train_dataset, fold_eval_dataset, fold_output_dir)
    trainer.train()

    # Evaluate the best checkpoint (reloaded via load_best_model_at_end) on
    # the held-out test fold
    predictions_output = trainer.predict(fold_test_dataset)
    preds = np.argmax(predictions_output.predictions, axis=1)
    labels = predictions_output.label_ids
    report = classification_report(
        labels, preds, zero_division=0,
        labels=class_labels,
        target_names=TARGET_NAMES,
        output_dict=True
    )
    if TASK == 'multiclass':
        report['cost'] = calculate_total_cost(labels, preds)

    confusion_matrices.append(confusion_matrix(labels, preds, labels=class_labels))
    reports.append(report)
    all_log_histories.append(trainer.state.log_history)

    print(f"Cleaning up checkpoint directory: {fold_output_dir}")
    shutil.rmtree(fold_output_dir, ignore_errors=True)

    # Release this fold's model before the next fold builds a new one.
    # Otherwise `trainer` keeps the old model alive while get_trainer() loads
    # the next, so two models sit in (GPU) memory at once.
    del trainer, predictions_output
    gc.collect()
    torch.cuda.empty_cache()

# Flatten the per-fold report dicts into per-metric score lists (one entry
# per fold) for print_and_save_classification_report_conf_intervals
cv_results = defaultdict(list)
for report in reports:
    for key, value in report.items():
        if key in TARGET_NAMES:
            key_lower = key.lower()
            cv_results[f'test_precision_{key_lower}'].append(value['precision'])
            cv_results[f'test_recall_{key_lower}'].append(value['recall'])
            cv_results[f'test_f1_{key_lower}'].append(value['f1-score'])
        elif key in ['macro avg', 'weighted avg']:
            avg_type = key.split()[0]  # 'macro' or 'weighted'
            cv_results[f'test_precision_{avg_type}'].append(value['precision'])
            cv_results[f'test_recall_{avg_type}'].append(value['recall'])
            cv_results[f'test_f1_{avg_type}'].append(value['f1-score'])
        elif key in ['accuracy', 'cost']:
            cv_results['test_' + key].append(value)

cv_results = pd.DataFrame(cv_results)
cv_results.to_csv(os.path.join(NESTED_CV_RESULTS_PATH, "BERT.csv"), index=False)
print_and_save_classification_report_conf_intervals(cv_results, RESULTS_PATH, confidence=0.95)
plot_confusion_matrices(confusion_matrices, TARGET_NAMES, path=os.path.join(RESULTS_PATH, "confusion_matrix_cv.png"))

In [None]:
# Average the per-fold learning/cost curves and save the figures under RESULTS_PATH.
plot_aggregated_curves(log_histories=all_log_histories, out_dir=RESULTS_PATH)

In [None]:
def choose_best_nepochs(all_log_histories):
    """
    Pick the number of training epochs from the cross-validation logs.

    Validation values are grouped by epoch (rounded to the nearest integer)
    across folds and averaged:
      - multiclass: the epoch with the minimum mean 'eval_cost';
      - binary: the epoch with the maximum mean 'eval_f1_weighted'.

    Args:
        all_log_histories: list of Trainer log histories, one per CV fold.

    Returns:
        int: the selected number of epochs.

    Raises:
        ValueError: if no log entry contains the selection metric.
    """
    if TASK == 'multiclass':
        # Lower total cost is better
        metric_key, select, criterion, metric_name = 'eval_cost', np.argmin, 'min cost', 'cost'
    else:
        # Higher weighted F1 is better
        metric_key, select, criterion, metric_name = (
            'eval_f1_weighted', np.argmax, 'max f1_weighted', 'f1_weighted'
        )

    values_by_epoch = defaultdict(list)
    for history in all_log_histories:
        for log in history:
            if metric_key in log and log.get('epoch') is not None:
                values_by_epoch[int(round(log['epoch']))].append(log[metric_key])

    if not values_by_epoch:
        raise ValueError(
            f"No '{metric_key}' entries found in the log histories; "
            "cannot choose the number of epochs."
        )

    epochs = sorted(values_by_epoch)
    mean_values = [np.mean(values_by_epoch[e]) for e in epochs]
    best_idx = int(select(mean_values))
    optimal_epochs = epochs[best_idx]

    print(f"Optimal number of epochs ({criterion}): {optimal_epochs}")
    print(f"Mean {metric_name} at optimal epoch: {mean_values[best_idx]:.4f}")

    return optimal_epochs

In [None]:
# 1. Determine Optimal Number of Epochs
optimal_epochs = choose_best_nepochs(all_log_histories)

# 2. Prepare Full Dataset
final_train_dataset = tokenized_dataset.remove_columns(['couple_ids'])
final_train_dataset.set_format("torch")

# 3. Configure Trainer for Final Run
FINAL_MODEL_DIR = os.path.join(OUT_DIR, "final_production_model")
os.makedirs(FINAL_MODEL_DIR, exist_ok=True)

final_training_args = TrainingArguments(
    output_dir=FINAL_MODEL_DIR,
    num_train_epochs=optimal_epochs, # Train for the optimal number of epochs
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    warmup_ratio=WARMUP_PERCENTAGE,
    logging_strategy="epoch",
    save_strategy="no", # We will save manually at the end
    report_to="none",
    fp16=torch.cuda.is_available(),
    dataloader_num_workers=NUM_WORKERS,
    seed=GLOBAL_SEED,
    data_seed=GLOBAL_SEED
)

final_model = AutoModelForSequenceClassification.from_pretrained(
    SELECTED_BERT_MODEL,
    num_labels=NUM_LABELS,
    problem_type="single_label_classification"
)

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

final_trainer = Trainer(
    model=final_model,
    args=final_training_args,
    train_dataset=final_train_dataset,
    # No eval_dataset needed for the final run
    data_collator=data_collator,
)

final_trainer.train()

final_trainer.save_model(FINAL_MODEL_DIR)
tokenizer.save_pretrained(FINAL_MODEL_DIR)

In [None]:
if os.path.exists("/kaggle"):
    # On Kaggle, pack the whole run directory into "<OUT_DIR>.zip" (created
    # next to OUT_DIR) so it can be downloaded as one file.
    shutil.make_archive(base_name=OUT_DIR, format='zip', root_dir=OUT_DIR)