In [None]:
# !pip install -r ../input/requirements/requirements.txt
# !pip install evaluate
# !pip install rouge_score
# !pip install bert_score

# Imports

In [None]:
import torch
from datasets import Dataset
import matplotlib.pyplot as plt
from pandas import DataFrame
from sklearn.model_selection import train_test_split

from datetime import datetime
from tqdm.auto import tqdm
import shutil
import re
import os

import evaluate as ev
from sentence_transformers import SentenceTransformer, util
from transformers import (
    Seq2SeqTrainingArguments,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    AutoTokenizer,
    GenerationConfig,
)

# Utilities

## Hyper-Parameters

In [None]:
# Select the compute device once; later cells move the model with `.to(DEVICE)`.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(
    f"Using device: {torch.cuda.get_device_name(0)}"
    if DEVICE.type == "cuda"
    else "Using CPU"
)

# Run-mode switches (not read anywhere in the visible part of the notebook).
TRAIN_MODE = "train_no_optuna"
# TRAIN_MODE = "train_with_optuna"
TEST_MODE = "test_no_optuna"
# TEST_MODE = "test_with_optuna"

# Hub id of the pretrained checkpoint used for both the tokenizer and the model.
DEFAULT_CHECKPOINT = "morenolq/bart-it"

# Data loading / checkpointing.
NUM_WORKERS = 4        # dataloader worker processes
SAVE_TOTAL_LIMIT = 3   # maximum number of checkpoints kept on disk
TEST_SIZE = 0.2        # fraction of users held out at each split

# Optimisation settings forwarded to the training arguments.
BATCH_SIZE = 1
NUM_EPOCHS = 2
GRADIENT_ACCUMULATION_STEPS = 8
EARLY_STOPPING_PATIENCE = 2  # NOTE(review): not referenced in the visible cells
WARMUP_PERCENTAGE = 0.1      # passed as the warm-up ratio
WEIGHT_DECAY = 0.01          # in the 0 to 0.1 range
BODY_LR = 3e-5               # used as the trainer learning rate
HEAD_LR = 1.5e-4             # NOTE(review): not referenced in the visible cells

# Optuna settings (presumably only relevant to the "*_with_optuna" modes — confirm).
OPTUNA_TRAIN_TRIALS = 2
OPTUNA_TEST_TRIALS = 2
PRUNER_WARMUP_STEPS = 2

## Paths Settings

In [None]:
# One timestamp per execution of this cell, so all artifacts of a run share a folder.
timestamp = f"{datetime.now():%Y-%m-%d_%H-%M-%S}"

# ==== LOCAL SETTINGS ====
# Root folder of the chat files that the loading cell walks.
PATH = os.path.join(
    ".", "out", "datasets", "cipv-chats-toxicity", "chats"
)  # , "gen2", "chats"
# Per-run output folder (also used as the trainer output directory).
OUT_DIR = os.path.join(".", "out", "models", "BART", timestamp)

# ==== KAGGLE SETTINGS ====
# PATH = os.path.join(os.sep, "kaggle", "input", "cipv-chats-sentiment")
# OUT_DIR = os.path.join(os.sep, "kaggle", "working", "out")

# Plots and metric files go here; creating it also creates OUT_DIR.
RESULTS_PATH = os.path.join(OUT_DIR, "results")
os.makedirs(RESULTS_PATH, exist_ok=True)

## Kaggle Specific Utilities

In [None]:
# os.listdir(os.path.join(os.sep, "kaggle", "working"))

In [None]:
# zip_file_path = "/kaggle/working/out-EuclideanLoss-MSELoss-single_sep_False"
# shutil.make_archive(zip_file_path, 'zip', zip_file_path)
# shutil.rmtree(zip_file_path)
# os.remove(zip_file_path + '.zip')

## Plots Utilities

In [None]:
def plot_losses(log_history, out_dir):
    """Plot the per-epoch mean train and eval loss and save the figure.

    Args:
        log_history: Sequence of dict log records (the commented-out training
            code in this notebook passes ``trainer.state.log_history``). Only
            records that contain ``'epoch'`` together with ``'loss'`` (train)
            or ``'eval_loss'`` (eval) are used.
        out_dir: Directory, created if missing, where ``learning_curve.png``
            is written.

    A curve is skipped (with a printed notice) when no record carries its loss
    key, instead of failing inside ``groupby`` on an empty frame.
    """
    os.makedirs(out_dir, exist_ok=True)

    def _mean_loss_per_epoch(loss_key):
        # Average `loss_key` over the records of each epoch; None when no record has it.
        records = [
            entry for entry in log_history
            if 'epoch' in entry and loss_key in entry
        ]
        if not records:
            return None
        return DataFrame(records).groupby('epoch')[loss_key].mean()

    curves = (
        ('loss', 'g-o', 'Train Loss'),
        ('eval_loss', 'c-o', 'Eval Loss'),
    )

    fig, ax = plt.subplots(figsize=(7, 5))
    plotted_any = False
    for loss_key, line_format, label in curves:
        epoch_loss = _mean_loss_per_epoch(loss_key)
        if epoch_loss is None:
            print(f"No '{loss_key}' entries in the log history; skipping '{label}'.")
            continue
        ax.plot(epoch_loss.index, epoch_loss.values, line_format, label=label)
        plotted_any = True

    ax.set_xlabel('Epochs')
    ax.set_ylabel('Loss')
    ax.set_title('Losses Over Epochs')
    if plotted_any:
        # A legend without any labelled artist only produces a warning.
        ax.legend()
    ax.grid(True)
    fig.tight_layout()
    fig.savefig(os.path.join(out_dir, "learning_curve.png"))
    plt.show()
    plt.close(fig)

# Loading the Dataset

In [None]:
def simple_add(dataset, messages, couple_dir):
    """Append one chat to the column-oriented ``dataset`` dict (mutated in place).

    Args:
        dataset: Dict with list-valued ``'chats'``, ``'user_ids'`` and
            ``'msgs_lengths'`` entries.
        messages: Sequence of regex match objects exposing a ``name_content``
            named group; one match per chat message.
        couple_dir: Identifier stored as the chat's user id (the caller passes
            the name of the directory the chat file lives in).
    """
    # The model input is the sequence of "name + content" spans, one per line.
    joined_chat = "\n".join(match.group("name_content") for match in messages)

    dataset['chats'].append(joined_chat)
    # The directory name doubles as the user id.
    dataset['user_ids'].append(couple_dir)
    dataset['msgs_lengths'].append(len(messages))

def load_dataset(path):
    """Read every chat file under ``path`` into a column-oriented dict.

    The expected layout is ``path/<model_dir>/<couple_dir>/<file>``. From each
    file the function extracts the messages (optional timestamp, speaker name,
    content and a ``Polarity:`` value) and the text that follows
    ``Spiegazione:``. A file is skipped, and reported on stdout, when it has
    no matching message or no explanation.

    Entries are visited in sorted order so the resulting dataset order does
    not depend on the filesystem. Anything that is not a directory at the
    first two levels, or not a regular file at the third, is ignored instead
    of raising ``NotADirectoryError`` / ``IsADirectoryError``.

    Args:
        path: Root directory of the dataset.

    Returns:
        ``(dataset, skipped)`` where ``dataset`` is a dict with the parallel
        lists ``'chats'``, ``'explanations'``, ``'user_ids'`` and
        ``'msgs_lengths'``, and ``skipped`` counts the chat files left out.
        The user id is the bare ``couple_dir`` name, so equally named folders
        under different ``model_dir`` folders share one id.
    """
    # msgs_regex = re.compile(r"(?P<message>(?P<timestamp>\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d) \|? ?(?P<name_content>(?P<name>.+):\n(?P<content>.+))\n+Polarity: (?P<polarity>[-+]?\d\.?\d?\d?)\n\[(?P<tag_explanation>(?P<tag>Tag: .+)\n?Spiegazione: (?P<explanation>.+))\])")
    msgs_regex = re.compile(r"(?P<message>\(?(?P<timestamp>\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d)\)? ?\|? ?(?P<name_content>(?P<name>.+):\n?\s?(?P<content>.+))\n?\s?Polarity: (?P<polarity>(?:-?|\+?)\d\.?\d?\d?))")
    # Captures everything after the "Spiegazione:" line up to the end of the file.
    explanation_regex = re.compile(r"Spiegazione:\n(?P<explanation>(?:.|\n)+)")
    dataset = {
        "chats": [],
        "explanations": [],
        "user_ids": [],
        "msgs_lengths": []
    }
    skipped = 0
    model_dirs = sorted(os.listdir(path))
    for model_dir in tqdm(model_dirs, desc="📂 Loading Dataset"):
        model_dir_path = os.path.join(path, model_dir)
        if not os.path.isdir(model_dir_path):
            # Stray files next to the model folders (e.g. OS metadata) are not data.
            continue
        for couple_dir in sorted(os.listdir(model_dir_path)):
            couple_dir_path = os.path.join(model_dir_path, couple_dir)
            if not os.path.isdir(couple_dir_path):
                continue
            for file in sorted(os.listdir(couple_dir_path)):
                file_path = os.path.join(couple_dir_path, file)
                if not os.path.isfile(file_path):
                    continue
                # Read first, parse afterwards: the handle is not needed while matching.
                with open(file_path, "r", encoding="utf-8") as f:
                    chat = f.read()

                messages = list(msgs_regex.finditer(chat))
                if not messages:
                    skipped += 1
                    print(f"No messages found in file: {file_path}")
                    continue

                match = explanation_regex.search(chat)
                if match is None:
                    print(f"No explanation found in file: {file_path}")
                    skipped += 1
                    continue

                # Append to all columns together so the lists stay aligned.
                dataset['explanations'].append(match.group("explanation"))
                simple_add(dataset, messages, couple_dir)

    return dataset, skipped

# Collect the raw columns from disk, then wrap them in a Hugging Face `Dataset`.
loaded_columns, skipped = load_dataset(PATH)
dataset = Dataset.from_dict(loaded_columns)
print(f"Skipped: {skipped}")

In [None]:
def print_dataset_info(dataset):
    """Print the dataset summary followed by the first example of each column.

    Args:
        dataset: Object that can be printed, supports ``len()``, exposes a
            ``features`` iterable of column names and returns a
            column -> value mapping for ``dataset[0]`` (a ``datasets.Dataset``
            in this notebook).

    An empty dataset only gets its summary and a notice, instead of raising
    ``IndexError`` on ``dataset[0]``.
    """
    print(dataset)
    if len(dataset) == 0:
        print("Dataset is empty: no example to display.")
        return

    # Fetch the first row once rather than once per column.
    first_example = dataset[0]
    for field in dataset.features:
        print(f"{field}: {first_example[field]}\n")

# Show the summary and the first example of the loaded (not yet tokenized) dataset.
print_dataset_info(dataset)

# Pre-Processing the Dataset

In [None]:
tokenizer = AutoTokenizer.from_pretrained(DEFAULT_CHECKPOINT)

def preprocess(examples):
    """Tokenize a batch of chats (inputs) and explanations (targets).

    Sequences are deliberately NOT truncated here. The following cell reports
    how many samples exceed 512 / 1024 tokens and drops those above 1024;
    truncating at 1024 beforehand would make those counts always zero and the
    drop a no-op, while silently feeding clipped chats to the model.

    Args:
        examples: Batch dict with the ``'chats'``, ``'explanations'``,
            ``'user_ids'`` and ``'msgs_lengths'`` columns.

    Returns:
        The tokenizer output (``input_ids``, ``attention_mask`` and ``labels``
        built from ``text_target``) with ``user_ids`` and ``msgs_lengths``
        carried over unchanged.
    """
    model_inputs = tokenizer(
        examples['chats'],
        # padding is left to the data collator, which pads per batch
        text_target=examples['explanations'],
    )
    model_inputs["user_ids"] = examples["user_ids"]
    model_inputs["msgs_lengths"] = examples["msgs_lengths"]
    return model_inputs

tokenized_dataset = dataset.map(
    preprocess,
    batched=True,
    # batch_size=1000,
    remove_columns=dataset.column_names
)
# print_dataset_info(tokenized_dataset)
print(tokenized_dataset)

In [None]:
df = DataFrame(tokenized_dataset)

# Token counts per sample, computed once and reused for the masks and the plots.
input_ds_token_lengths = df['input_ids'].apply(len)
labels_ds_token_lengths = df['labels'].apply(len)

more_than_1024_input_mask = input_ds_token_lengths > 1024
more_than_512_input_mask = input_ds_token_lengths > 512
more_than_1024_labels_mask = labels_ds_token_lengths > 1024
more_than_512_labels_mask = labels_ds_token_lengths > 512

# print(f"Tokenized dataset:\n{df.head()}")
print("input_ids token length statistics:")
print(f"Number of samples with more than 1024 tokens: {int(more_than_1024_input_mask.sum())}")
print(f"Number of samples with more than 512 tokens: {int(more_than_512_input_mask.sum())}")

print("\nlabels token length statistics:")
print(f"Number of samples with more than 1024 tokens: {int(more_than_1024_labels_mask.sum())}")
print(f"Number of samples with more than 512 tokens: {int(more_than_512_labels_mask.sum())}")

# Overlay the two token-length histograms on the same axes, semi-transparent so
# overlaps stay visible; each uses as many bins as its longest sequence has tokens.
input_ds_token_lengths.hist(bins=input_ds_token_lengths.max(), edgecolor='blue', alpha=0.5, label='Input Chats')
labels_ds_token_lengths.hist(bins=labels_ds_token_lengths.max(), edgecolor='orange', alpha=0.5, label='Output Explanations')
plt.xlabel('Number of Tokens')
plt.ylabel('Frequency')
plt.title('Token Length Distribution')
plt.legend()
# plt.savefig(os.path.join(OUT_DIR, "token_and_message_length_distribution.png"))
plt.show()

# min_msgs = df['msgs_lengths'].min()
max_msgs = df['msgs_lengths'].max()
df['msgs_lengths'].hist(
    # Integer bin edges 0..max_msgs+1: the last histogram bin is closed on both
    # sides, so with only max_msgs+1 edges the chats with max_msgs and
    # max_msgs-1 messages would be merged into a single bar.
    bins=range(max_msgs + 2),
    edgecolor='black',
    label='Messages Lengths',
    align='left'
)
plt.xlabel('Number of Messages')
plt.ylabel('Frequency')
plt.title('Messages Length Distribution')
plt.legend()
# plt.savefig(os.path.join(OUT_DIR, "messages_length_distribution.png"))
plt.yscale('log')
plt.xticks(range(max_msgs + 1))
plt.show()

# Keep only the samples whose input and label sequences are at most 1024 tokens long.
df = df[~more_than_1024_input_mask & ~more_than_1024_labels_mask]

## Splitting the Dataset

In [None]:
print(f"Users in dataset: {df['user_ids'].nunique()}")
print(f"Dataset size: {len(df)}\n")

# Non-mutating drop: `df` comes from a boolean filter, so an in-place drop
# triggers pandas' copy-of-a-slice warning, and `errors='ignore'` keeps the
# cell re-runnable once the column is already gone.
df = df.drop(columns=['msgs_lengths'], errors='ignore')

# Split by user id, not by row, so all chats of one user land in the same set.
grouped = df.groupby('user_ids')#.size().reset_index(name='counts')
user_ids = list(grouped.groups.keys())#[:10]
# df = df[df['user_ids'].isin(user_ids)]

# First carve out the test users, then carve the eval users out of the rest.
train_ids, test_ids = train_test_split(user_ids, test_size=TEST_SIZE, random_state=42)
train_ids, eval_ids = train_test_split(train_ids, test_size=TEST_SIZE, random_state=42)
tokenized_train_set = df[df['user_ids'].isin(train_ids)]
tokenized_test_set = df[df['user_ids'].isin(test_ids)]
tokenized_eval_set = df[df['user_ids'].isin(eval_ids)]

# Every row must end up in exactly one of the three sets.
assert len(tokenized_train_set) + len(tokenized_test_set) + len(tokenized_eval_set) == len(df)

# Prints how many users are in each set
print(f"Users in train set: {tokenized_train_set['user_ids'].nunique()}")
print(f"Users in test set: {tokenized_test_set['user_ids'].nunique()}")
print(f"Users in eval set: {tokenized_eval_set['user_ids'].nunique()}\n")


def _to_torch_dataset(split_frame):
    # Drop the grouping column, renumber the rows and return a torch-formatted Dataset.
    split = Dataset.from_pandas(
        split_frame.drop(columns=['user_ids']).reset_index(drop=True)
    )
    split.set_format("torch")
    return split


tokenized_train_set = _to_torch_dataset(tokenized_train_set)
tokenized_test_set = _to_torch_dataset(tokenized_test_set)
tokenized_eval_set = _to_torch_dataset(tokenized_eval_set)

print(f"Train set size: {len(tokenized_train_set)}")
print(f"Test set size: {len(tokenized_test_set)}")
print(f"Eval set size: {len(tokenized_eval_set)}\n")

print_dataset_info(tokenized_train_set)

# Fine-Tuning the Model

In [None]:
def create_compute_metrics_fn(tokenizer, sts_model):
    """Build the ``compute_metrics`` callback passed to ``Seq2SeqTrainer``.

    The metric objects and the sentence-embedding model are loaded once here
    and captured by the returned closure, so they are not reloaded on every
    evaluation.

    Args:
        tokenizer: Tokenizer used to decode generated and reference token ids.
        sts_model: Model name or path handed to ``SentenceTransformer``; the
            loaded model produces the embeddings for the similarity metric.

    Returns:
        A function that maps a ``(predictions, labels)`` pair to a dict with
        the keys ``rouge1``, ``rouge2``, ``rougeL``, ``bleu``,
        ``sbert_similarity`` and ``bertscore_f1``.
    """
    rouge_metric = ev.load('rouge')
    bleu_metric = ev.load('bleu')
    bertscore_metric = ev.load('bertscore')

    sentence_encoder = SentenceTransformer(sts_model)
    pad_token_id = tokenizer.pad_token_id

    def _restore_padding(token_ids):
        # Swap the -100 ignore index for the pad id so the ids can be decoded.
        ids = torch.as_tensor(token_ids)
        return ids.masked_fill(ids == -100, pad_token_id)

    def compute_metrics(eval_pred):
        """Decode predictions and references, then score them.

        Args:
            eval_pred: Pair ``(predictions, labels)`` of token-id arrays.

        Returns:
            Dictionary containing the computed metrics.
        """
        predictions, labels = eval_pred
        # Some models return extra outputs next to the token ids; keep the ids only.
        if isinstance(predictions, tuple):
            predictions = predictions[0]

        # Labels use -100 for ignored positions. Generated ids can carry the
        # same value where batches of different length were padded together
        # (as handled in the Hugging Face seq2seq examples), so both sides are
        # cleaned before decoding.
        predicted_explanations = tokenizer.batch_decode(
            _restore_padding(predictions),
            skip_special_tokens=True
        )
        true_explanations = tokenizer.batch_decode(
            _restore_padding(labels),
            skip_special_tokens=True
        )

        # Compute ROUGE scores
        rouge_results = rouge_metric.compute(
            predictions=predicted_explanations,
            references=true_explanations
        )

        # Compute BLEU score (one reference list per prediction)
        bleu_results = bleu_metric.compute(
            predictions=predicted_explanations,
            references=[[ref] for ref in true_explanations]
        )

        # Compute SBERT similarity: the diagonal of the pairwise matrix holds
        # the similarity of each prediction with its own reference.
        reference_embeddings = sentence_encoder.encode(true_explanations, convert_to_tensor=True)
        generated_embeddings = sentence_encoder.encode(predicted_explanations, convert_to_tensor=True)
        cosine_scores = util.cos_sim(generated_embeddings, reference_embeddings)
        sbert_similarity = torch.diag(cosine_scores).mean().item()

        # Compute BERTScore and average the per-sample F1 values
        bertscore_results = bertscore_metric.compute(
            predictions=predicted_explanations,
            references=true_explanations,
            lang="it",
        )
        bertscore_f1 = sum(bertscore_results['f1']) / len(bertscore_results['f1'])

        return {
            'rouge1': rouge_results['rouge1'],
            'rouge2': rouge_results['rouge2'],
            'rougeL': rouge_results['rougeL'],
            'bleu': bleu_results['bleu'],
            'sbert_similarity': sbert_similarity,
            'bertscore_f1': bertscore_f1
        }

    return compute_metrics


In [None]:
# Load the pretrained seq2seq checkpoint and move it to the selected device.
model = AutoModelForSeq2SeqLM.from_pretrained(DEFAULT_CHECKPOINT)
model = model.to(DEVICE)

# checkpoint_path = os.path.abspath(os.path.join(
#     ".", "out", "models", "toxicity", "BART", "2025-08-03_17-21-32", "checkpoint-420"
# ))
# model = AutoModelForSeq2SeqLM.from_pretrained(
#     checkpoint_path
# ).to(DEVICE)

# Sampling-based decoding settings used whenever the trainer generates text.
sampling_options = dict(do_sample=True, top_p=0.95, top_k=25, temperature=0.6)
gen_config = GenerationConfig(
    max_length=1024,
    decoder_start_token_id=model.config.decoder_start_token_id,
    bos_token_id=model.config.bos_token_id,
    **sampling_options,
)

# model.tie_weights()

# Optimisation hyper-parameters, all taken from the configuration cell.
optimisation_options = dict(
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    learning_rate=BODY_LR,
    weight_decay=WEIGHT_DECAY,
    warmup_ratio=WARMUP_PERCENTAGE,
    fp16=torch.cuda.is_available(),
)
# Evaluate, checkpoint and log once per epoch.
schedule_options = dict(
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="epoch",
    # load_best_model_at_end=True,
    save_total_limit=SAVE_TOTAL_LIMIT,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
)
training_args = Seq2SeqTrainingArguments(
    output_dir=OUT_DIR,
    predict_with_generate=True,
    generation_config=gen_config,
    dataloader_num_workers=NUM_WORKERS,
    report_to="none",
    **optimisation_options,
    **schedule_options,
)

# Batches are padded by the collator, since tokenization applied no padding.
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    # model=model
)

# Training-time trainer: no compute_metrics, so per-epoch evaluation reports loss only.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_set,
    eval_dataset=tokenized_eval_set,
    # tokenizer=tokenizer,
    data_collator=data_collator,
    # compute_metrics=create_compute_metrics_fn(
    #     tokenizer=tokenizer,
    #     sts_model='sentence-transformers/paraphrase-multilingual-mpnet-base-v2'
    # )
)

# train_result = trainer.train()
# log_history = trainer.state.log_history
# plot_losses(log_history, RESULTS_PATH)

# Testing the Model

## Inference Example

In [None]:
# def inference(input_text):
#     inputs = tokenizer(
#         input_text,
#         max_length=1024,
#         truncation=True,
#         return_tensors="pt"
#     ).to(DEVICE)
#     # inputs = {k: v.to(DEVICE) for k, v in inputs.items()}  # Move tensors to device
#     outputs = model.generate(**inputs, generation_config=gen_config)
#     decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
#     return decoded_output

# chat = '''
# Topolino:
# Ei, come va?
# Topolina:
# Ciao Mauro! Tutto bene, grazie! E tu?
# Topolino:
# Tutto ok, grazie! Che fai di bello oggi?
# Topolina:
# Sinceramente non lo so, ho un po' di cose da fare ma non so da dove cominciare. Tu che fai?
# '''

# print(inference(chat))

In [None]:
# for example in tokenized_train_set.select(range(1, 2)):
#     decoded_chat = tokenizer.decode(example['input_ids'], skip_special_tokens=True)
#     decoded_true_explanation = tokenizer.decode(example['labels'], skip_special_tokens=True)

#     output = inference(decoded_chat)

#     print(f"Message:\n{decoded_chat}")
#     print(f"True Explanation:\n{decoded_true_explanation}\n")
#     print(f"Generated Explanation:\n{output}\n")
#     print("\n\n")

## Evaluation

In [None]:
# Metric callback built once; it loads the scorers and the sentence encoder.
text_metrics_fn = create_compute_metrics_fn(
    tokenizer=tokenizer,
    sts_model='sentence-transformers/paraphrase-multilingual-mpnet-base-v2'
)

# Same model and arguments as before, now with the text-generation metrics attached.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_set,
    eval_dataset=tokenized_eval_set,
    # tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=text_metrics_fn
)

# Score the held-out test split; metric names get the "test" prefix.
test_metrics = trainer.evaluate(
    eval_dataset=tokenized_test_set,
    metric_key_prefix="test"
)

# Format each metric once, then echo it and persist it next to the other results.
metric_lines = [f"{key}: {value}" for key, value in test_metrics.items()]
for line in metric_lines:
    print(line)
with open(os.path.join(RESULTS_PATH, "test_metrics.txt"), "w") as f:
    f.writelines(f"{line}\n" for line in metric_lines)