In [18]:
# !pip install numpy==1.26.4
# !pip install evaluate

# Imports

In [1]:
import torch
from torch.utils.data import DataLoader
from datasets import Dataset
import matplotlib.pyplot as plt
from pandas import DataFrame
from sklearn.model_selection import train_test_split

from datetime import datetime
from tqdm.auto import tqdm
import traceback
import itertools
import shutil
import optuna
import gc
import re
import os

from multitask_bart.single_target import (
    BartWithRegression,
    MultiTaskBartDataCollator,
    TrainingArguments,
    Trainer
)

from multitask_bart.losses import (
    EuclideanLoss,
    StaticWeightedLoss,
    UncertaintyLoss
)

# Utilities

In [2]:
SEP_TOKEN = BartWithRegression.SEP_TOKEN
TARGET_TOKEN = BartWithRegression.TARGET_TOKEN
PATH = os.path.join(".", "out", "datasets", "toxicity", "gen1", "chats")
if torch.cuda.is_available():
    print(f"Using device: {torch.cuda.get_device_name(0)}")
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
    print("Using CPU")

timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
OUT_DIR = os.path.join(".", "out", "models", timestamp)
TRAIN_MODE = "train_no_optuna"
# TRAIN_MODE = "train_with_optuna"
TEST_MODE = "test_no_optuna"
# TEST_MODE = "test_with_optuna"

NUM_WORKERS = 4

BATCH_SIZE = 3
NUM_EPOCHS = 2
GRADIENT_ACCUMULATION_STEPS = 4
EARLY_STOPPING_PATIENCE = 2
WARMUP_PERCENTAGE = 0.1
WEIGHT_DECAY = 0.01 # int the 0 to 0.1 range
BODY_LR = 3e-5
HEAD_LR = 1.5e-4

OPTUNA_TRAIN_TRIALS=2
OPTUNA_TEST_TRIALS=2
PRUNER_WARMUP_STEPS=2

Using device: NVIDIA GeForce RTX 3060 Laptop GPU


# Loading and Preparing the Dataset

## Loading and Pre-Processing the Dataset

In [3]:
def add_all_consecutive_subsets(dataset, messages, couple_dir):
    all_messages = [msg.group("name_content") for msg in messages]
    for target_idx, msg in enumerate(messages, start=0):
        for pre_context in range(target_idx + 1):
            for post_context in range(target_idx + 1, len(messages) + 1):
                # Re-assemble the context for each example
                context_parts = all_messages[target_idx-pre_context: target_idx] + \
                                [all_messages[target_idx] + TARGET_TOKEN] + \
                                all_messages[target_idx + 1: post_context]
                input_chat = (SEP_TOKEN + "\n").join(context_parts) + SEP_TOKEN

                dataset['chats'].append(input_chat)
                dataset['polarities'].append(float(msg.group("polarity")))
                dataset['explanations'].append(msg.group("explanation"))
                # use the directory as user_id
                dataset['user_ids'].append(couple_dir)

# def add_(dataset, messages, couple_dir):
    # all_messages = [msg.group("name_content") for msg in messages]
    # for target_idx, msg in enumerate(messages, start=0):
    #     for idx in range(target_idx, len(messages)):
    #         context_parts = all_messages[:target_idx] + \
    #                         [all_messages[target_idx] + TARGET_TOKEN] + \
    #                         all_messages[target_idx + 1 : idx + 1]
    #         input_chat = (SEP_TOKEN + "\n").join(context_parts) + SEP_TOKEN
            
    #         dataset['chats'].append(input_chat)
    #         dataset['polarities'].append(float(msg.group("polarity")))
    #         dataset['explanations'].append(msg.group("explanation"))
    #         dataset['user_ids'].append(couple_dir)

def simple_add(dataset, messages, couple_dir):
    all_messages = [msg.group("name_content") for msg in messages]
    for target_idx, msg in enumerate(messages, start=0):
        context_parts = all_messages[:target_idx] + \
                        [all_messages[target_idx] + TARGET_TOKEN] + \
                        all_messages[target_idx + 1:]
        input_chat = (SEP_TOKEN + "\n").join(context_parts) + SEP_TOKEN
        
        dataset['chats'].append(input_chat)
        dataset['polarities'].append(float(msg.group("polarity")))
        dataset['explanations'].append(msg.group("explanation"))
        # use the directory as user_id
        dataset['user_ids'].append(couple_dir)

def load_dataset(path):
    msgs_regex = re.compile(r"(?P<message>(?P<timestamp>\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d) \|? ?(?P<name_content>(?P<name>.+):\n(?P<content>.+))\n+Polarity: (?P<polarity>[-+]?\d\.?\d?\d?)\n\[(?P<tag_explanation>(?P<tag>Tag: .+)\n?Spiegazione: (?P<explanation>.+))\])")
    dirs = os.listdir(path)
    dataset = {
        "chats": [],
        "explanations": [],
        "polarities": [],
        "user_ids": []
    }
    skipped = 0
    model_dirs = os.listdir(path)
    for model_dir in tqdm(model_dirs, desc="📂 Loading Dataset"):
        model_dir_path = os.path.join(path, model_dir)
        couple_dirs = os.listdir(model_dir_path)
        for couple_dir in tqdm(couple_dirs, desc=f"📂 Loading Directory: {model_dir_path}"):
            couple_dir_path = os.path.join(model_dir_path, couple_dir)
            files = os.listdir(couple_dir_path)
            for file in files:
                with open(os.path.join(couple_dir_path, file), "r", encoding="utf-8") as f:
                    chat = f.read()
                    messages = list(msgs_regex.finditer(chat))
                    if len(messages) > 0: # checks if there are matched messages
                        simple_add(dataset, messages, couple_dir)
                    else:
                        skipped += 1
                        # print(f"No messages found in file: {os.path.join(path, directory, file)}")
    return dataset, skipped

dataset, skipped = load_dataset(PATH)
dataset = Dataset.from_dict(dataset)
print(f"Skipped: {skipped}")

📂 Loading Dataset:   0%|          | 0/6 [00:00<?, ?it/s]

📂 Loading Directory: .\out\datasets\toxicity\gen1\chats\gemini-2.0-flash-dataset_2025-07-09-15-56-22:   0%|   …

📂 Loading Directory: .\out\datasets\toxicity\gen1\chats\gemini-2.5-flash-dataset_2025-07-07-10-45-16:   0%|   …

📂 Loading Directory: .\out\datasets\toxicity\gen1\chats\gemini-2.5-flash-dataset_2025-07-08-20-06-22:   0%|   …

📂 Loading Directory: .\out\datasets\toxicity\gen1\chats\gemini-2.5-flash-dataset_2025-07-16-09-07-33:   0%|   …

📂 Loading Directory: .\out\datasets\toxicity\gen1\chats\gemini-2.5-flash-dataset_2025-07-19-14-42-59:   0%|   …

📂 Loading Directory: .\out\datasets\toxicity\gen1\chats\gemini-2.5-flash-lite-preview-06-17-dataset_2025-07-09…

Skipped: 201


In [188]:
def print_dataset_info(dataset):
    print(dataset)
    # For each field, print the first entry
    for field in dataset.features:
        print(f"{field}: {dataset[0][field]}\n")

print_dataset_info(dataset)

Dataset({
    features: ['chats', 'explanations', 'polarities', 'user_ids'],
    num_rows: 154770
})
chats: Sofia:
Ciao Marco, come stai?[TARGET][SEP]

explanations: Inizio cortese, ma con una certa formalità.

polarities: 0.8

user_ids: 2025-07-09-15-56-39



In [23]:
tokenizer = BartWithRegression.get_tokenizer()
def preprocess(examples):
    tokenized_chats = tokenizer(
        examples['chats'],
        # padding='max_length',
        truncation=True,
        max_length=1024,
        # return_tensors='pt',
        text_target=examples['explanations']
    )
    tokenized_chats["polarities"] = examples["polarities"]
    tokenized_chats["user_ids"] = examples["user_ids"]
    return tokenized_chats

In [None]:
tokenized_dataset = dataset.map(
    preprocess,
    batched=True,
    # batch_size=1000,
    remove_columns=dataset.column_names
)

Map:   0%|          | 0/12704 [00:00<?, ? examples/s]

## Splitting the Dataset

In [27]:
TEST_SIZE = 0.2

tokenized_dataset = DataFrame(tokenized_dataset.to_dict())
print(f"Users in dataset: {tokenized_dataset['user_ids'].nunique()}")
print(f"Dataset size: {len(tokenized_dataset)}\n")

# Split the dataset into train, test, and eval sets using user_ids grouped examples
grouped = tokenized_dataset.groupby('user_ids')#.size().reset_index(name='counts')
user_ids = list(grouped.groups.keys())
user_ids = list(grouped.groups.keys())[:10]
tokenized_dataset = tokenized_dataset[tokenized_dataset['user_ids'].isin(user_ids)]

train_ids, test_ids = train_test_split(user_ids, test_size=TEST_SIZE, random_state=42)
train_ids, eval_ids = train_test_split(train_ids, test_size=TEST_SIZE, random_state=42)
tokenized_train_set = tokenized_dataset[tokenized_dataset['user_ids'].isin(train_ids)]
tokenized_test_set = tokenized_dataset[tokenized_dataset['user_ids'].isin(test_ids)]
tokenized_eval_set = tokenized_dataset[tokenized_dataset['user_ids'].isin(eval_ids)]

# Prints how many users are in each set
print(f"Users in train set: {tokenized_train_set['user_ids'].nunique()}")
print(f"Users in test set: {tokenized_test_set['user_ids'].nunique()}")
print(f"Users in eval set: {tokenized_eval_set['user_ids'].nunique()}\n")

# Remove the 'user_ids' column from the train, test, and eval sets
tokenized_train_set = tokenized_train_set.drop(columns=['user_ids'])
tokenized_test_set = tokenized_test_set.drop(columns=['user_ids'])
tokenized_eval_set = tokenized_eval_set.drop(columns=['user_ids'])

tokenized_train_set = tokenized_train_set.reset_index(drop=True)
tokenized_test_set = tokenized_test_set.reset_index(drop=True)
tokenized_eval_set = tokenized_eval_set.reset_index(drop=True)

tokenized_train_set = Dataset.from_pandas(tokenized_train_set)
tokenized_test_set = Dataset.from_pandas(tokenized_test_set)
tokenized_eval_set = Dataset.from_pandas(tokenized_eval_set)

print(f"Train set size: {len(tokenized_train_set)}")
print(f"Test set size: {len(tokenized_test_set)}")
print(f"Eval set size: {len(tokenized_eval_set)}\n")

# Set the format to PyTorch tensors
tokenized_train_set.set_format("torch")
tokenized_test_set.set_format("torch")
tokenized_eval_set.set_format("torch")

print_dataset_info(tokenized_train_set)

data_collator = MultiTaskBartDataCollator(tokenizer=tokenizer)

train_dataloader = DataLoader(
    tokenized_train_set,
    batch_size=BATCH_SIZE, 
    shuffle=True,
    collate_fn=data_collator,
    # Parallelize data loading so the GPU
    # doesn't have to wait for the CPU
    # to prepare the next batch.
    num_workers=4
)
test_dataloader = DataLoader(
    tokenized_test_set,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=data_collator,
    num_workers=NUM_WORKERS
)
eval_dataloader = DataLoader(
    tokenized_eval_set,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=data_collator,
    num_workers=NUM_WORKERS
)

Users in dataset: 10
Dataset size: 2224

Users in train set: 6
Users in test set: 2
Users in eval set: 2

Train set size: 1340
Test set size: 402
Eval set size: 482

Dataset({
    features: ['polarities', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 1340
})
polarities: 0.5

input_ids: tensor([    0, 25522, 12963,    30,   203, 27273,    86,    69,    16,  5136,
        13546,   637,   368,  1190,  1777,    18,  2576,  2485,  1296,   458,
         1526,  6081,  5937,   300,   311,  1233,  1705,   339,   889,   287,
         9724,    18, 49613,   301,  1798,  2948, 10329,   266,  5754,   384,
          429,   437,  8559,   765,  2230,  9745,   300,  1494,    18, 52001,
        52000, 27273,    86,    69, 23562,    30,   203, 33011,    16, 14795,
           16,  3046,    18,  3806,    81,  1955,   377,  2276,  1004,    16,
        26044, 28501, 20336,    18,  1474, 16359,  1812,   312,   473,   676,
          975,  1004,  7905,  4687,    18,   488,   384,   710, 26151,   300,
 

# Training the Model

## Setup

In [None]:
def plot_general_learning_curve(log_history, out_dir):
    os.makedirs(out_dir, exist_ok=True)
    train_losses = log_history['total_train_losses']
    eval_losses = log_history['total_eval_losses']
    epochs = log_history['epochs']
    plt.figure(figsize=(7, 5))
    plt.plot(epochs, train_losses, 'b-o', label='Training Losses')
    plt.plot(epochs, eval_losses, 'r-o', label='Validation Losses')
    plt.xlabel('Epochs')
    plt.ylabel('Losses')
    plt.title('Learning Curve Over Epochs')
    plt.legend()
    plt.grid(True)
    plt.savefig(os.path.join(out_dir, "general_learning_curve.png"))
    plt.show()
    plt.close()

def plot_reg_learning_curve(log_history, out_dir):
    os.makedirs(out_dir, exist_ok=True)
    reg_train_losses = log_history['reg_train_losses']
    reg_eval_losses = log_history['reg_eval_losses']
    epochs = log_history['epochs']
    plt.figure(figsize=(7, 5))
    plt.plot(epochs, reg_train_losses, 'g-o', label='Regression Training Losses')
    plt.plot(epochs, reg_eval_losses, 'm-o', label='Regression Validation Losses')
    plt.xlabel('Epochs')
    plt.ylabel('Losses')
    plt.title('Regression Learning Curve Over Epochs')
    plt.legend()
    plt.grid(True)
    plt.savefig(os.path.join(out_dir, "reg_learning_curve.png"))
    plt.show()
    plt.close()

def plot_gen_learning_curve(log_history, out_dir):
    os.makedirs(out_dir, exist_ok=True)
    gen_train_losses = log_history['gen_train_losses']
    gen_eval_losses = log_history['gen_eval_losses']
    epochs = log_history['epochs']
    plt.figure(figsize=(7, 5))
    plt.plot(epochs, gen_train_losses, 'c-o', label='Generation Training Losses')
    plt.plot(epochs, gen_eval_losses, 'y-o', label='Generation Validation Losses')
    plt.xlabel('Epochs')
    plt.ylabel('Losses')
    plt.title('Generation Learning Curve Over Epochs')
    plt.legend()
    plt.grid(True)
    plt.savefig(os.path.join(out_dir, "gen_learning_curve.png"))
    plt.show()
    plt.close()

def plot_all_learning_curves(log_history, out_dir):
    os.makedirs(out_dir, exist_ok=True)
    epochs = log_history['epochs']
    
    # Plot total train losses
    plt.figure(figsize=(7, 5))
    plt.plot(epochs, log_history['total_train_losses'], 'b-o', label='Total Train Losses')
    plt.plot(epochs, log_history['total_eval_losses'], 'r-o', label='Total Eval Losses')
    plt.plot(epochs, log_history['reg_train_losses'], 'g-o', label='Regression Train Losses')
    plt.plot(epochs, log_history['reg_eval_losses'], 'm-o', label='Regression Eval Losses')
    plt.plot(epochs, log_history['gen_train_losses'], 'c-o', label='Generation Train Losses')
    plt.plot(epochs, log_history['gen_eval_losses'], 'y-o', label='Generation Eval Losses')
    plt.xlabel('Epochs')
    plt.ylabel('Losses')
    plt.title('Total Train Losses Over Epochs')
    plt.legend()
    plt.grid(True)
    plt.savefig(os.path.join(out_dir, "all_learning_curves.png"))
    plt.show()
    plt.close()

## Training

In [None]:
if TRAIN_MODE == "train_no_optuna":
    criterion = EuclideanLoss(
        regression_loss_fn=torch.nn.MSELoss(),
        # regression_loss_fn=torch.nn.SmoothL1Loss(),
    ).to(DEVICE)
    # criterion = StaticWeightedLoss(
    #     # regression_loss_fn=torch.nn.MSELoss(),
    #     regression_loss_fn=torch.nn.SmoothL1Loss(),
    #     alpha=0.5
    # ).to(DEVICE)
    # criterion = UncertaintyLoss(
    #     # regression_loss_fn=torch.nn.MSELoss(),
    #     regression_loss_fn=torch.nn.SmoothL1Loss()
    # ).to(DEVICE)

    model = BartWithRegression().to(DEVICE)
    # model = torch.compile(model)

    suffix = f"-{criterion.__class__.__name__}-{criterion.regression_loss_fn.__class__.__name__}"
    save_path = OUT_DIR + suffix
    results_path = os.path.join(save_path, "results")
    os.makedirs(save_path, exist_ok=True)
    os.makedirs(results_path, exist_ok=True)

    args = TrainingArguments(
        criterion=criterion,
        num_epochs=NUM_EPOCHS,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
        # get_scheduler_fn=get_linear_schedule_with_warmup,
        warmup_percentage=WARMUP_PERCENTAGE,
        body_lr=BODY_LR, head_lr=HEAD_LR,
        weight_decay=WEIGHT_DECAY,
        early_stopping_patience=EARLY_STOPPING_PATIENCE,
        logging=True,
        save_path=save_path
    )

    trainer = Trainer(
        model=model, args=args, device=DEVICE,
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader
    )
    trainer.train()

    plot_general_learning_curve(
        trainer.log_history,
        results_path
    )
    plot_reg_learning_curve(
        trainer.log_history,
        results_path
    )
    plot_gen_learning_curve(
        trainer.log_history,
        results_path
    )

    del model, trainer
    gc.collect()
    torch.cuda.empty_cache()

## Train with Optuna Hyperparameters Optimization

In [None]:
if TRAIN_MODE == "train_with_optuna":
    temp_save_path = OUT_DIR + "-temp"
    save_path = OUT_DIR
    results_path = os.path.join(save_path, "results")
    os.makedirs(save_path, exist_ok=True)
    os.makedirs(results_path, exist_ok=True)
    best_log_history = None

    losses = {
        "euclidean-mse": EuclideanLoss(
            regression_loss_fn=torch.nn.MSELoss(),
        ).to(DEVICE),
        "euclidean-smoothl1": EuclideanLoss(
            regression_loss_fn=torch.nn.SmoothL1Loss(),
        ).to(DEVICE),
        "uncertainty-mse": UncertaintyLoss(
            regression_loss_fn=torch.nn.MSELoss(),
        ).to(DEVICE),
        "uncertainty-smoothl1": UncertaintyLoss(
            regression_loss_fn=torch.nn.SmoothL1Loss(),
        ).to(DEVICE),
    }

# --- Define the Objective Function for Optuna ---
def objective(trial: optuna.Trial):
    """
    An objective function to be maximized or minimized by Optuna.
    This function takes a `trial` object, sets up a model with hyperparameters suggested
    by the trial, trains it, and returns the performance metric to optimize (eval_loss).
    """

    # BODY_LR = 3e-5
    # HEAD_LR = 1.5e-4
    criterion = trial.suggest_categorical("criterion", list(losses.keys()))
    head_lr = trial.suggest_float("head_lr", 1e-5, 1e-3, log=True)
    body_lr = trial.suggest_float("body_lr", 1e-5, 1e-4, log=True)
    # batch_size = trial.suggest_categorical("batch_size", [4, 8])
    weight_decay = trial.suggest_float("weight_decay", 1e-3, 0.1, log=True)
    # label_smoothing = trial.suggest_float("label_smoothing", 0.0, 0.2)
    
    # print(f"Trial {trial.number} - Hyperparameters:")
    # print(f"  Criterion: {criterion}")
    # print(f"  Head LR: {head_lr:.6f}")
    # print(f"  Body LR: {body_lr:.6f}")
    # print(f"  Weight Decay: {weight_decay:.6f}")

    criterion = losses[criterion]

    try:
        model = BartWithRegression().to(DEVICE)
        # model = torch.compile(model)

        args = TrainingArguments(
            criterion=criterion,
            num_epochs=NUM_EPOCHS,
            gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
            # get_scheduler_fn=get_linear_schedule_with_warmup,
            warmup_percentage=WARMUP_PERCENTAGE,
            body_lr=body_lr, head_lr=head_lr,
            weight_decay=weight_decay,
            early_stopping_patience=None,
            logging=True,
            save_path=temp_save_path
        )

        current_trainer = Trainer(
            model=model, args=args, device=DEVICE,
            train_dataloader=train_dataloader,
            eval_dataloader=eval_dataloader
        )
        current_trainer.train()

        best_eval_loss = min(current_trainer.log_history['total_eval_losses'])

        # Check if this is the best trial so far
        if trial.number == 0 or best_eval_loss > trial.study.best_value:
            shutil.rmtree(save_path)
            os.rename(temp_save_path, save_path)
            plot_general_learning_curve(
                current_trainer.log_history,
                results_path
            )
            plot_reg_learning_curve(
                current_trainer.log_history,
                results_path
            )
            plot_gen_learning_curve(
                current_trainer.log_history,
                results_path
            )
            # print(f"New best model saved with score: {best_eval_loss:.4f}")
        
        # --- Clean up GPU memory before the next trial ---
        del model, current_trainer, args
        gc.collect()
        torch.cuda.empty_cache()

        # Return the best validation loss achieved during this trial
        return best_eval_loss

    except optuna.exceptions.TrialPruned as e:
        # --- Handle pruned trials ---
        # Clean up memory for pruned trials as well
        gc.collect()
        torch.cuda.empty_cache()
        raise e
    except Exception as e:
        # --- Handle other errors like CUDA OOM ---
        print("Trial failed with error:")
        traceback.print_exc()
        gc.collect()
        torch.cuda.empty_cache()
        # Return a very high loss value so Optuna knows this trial was bad
        return float('inf')

In [None]:
if TRAIN_MODE == "train_with_optuna":
    # Ensure the output directory exists
    os.makedirs(save_path, exist_ok=True)

    # Create a study object. `direction="minimize"` means Optuna will try to minimize the return value of `objective`.
    # The TPE sampler is the algorithm for Bayesian Optimization.
    # The MedianPruner is an aggressive early stopping algorithm.
    study = optuna.create_study(
        direction="minimize",
        sampler=optuna.samplers.TPESampler(),
        pruner=optuna.pruners.MedianPruner(n_warmup_steps=PRUNER_WARMUP_STEPS) # Prune after the 1st epoch
    )

    # Start the optimization. Optuna will call the `objective` function `n_trials` times.
    study.optimize(
        objective,
        n_trials=OPTUNA_TRAIN_TRIALS,
        show_progress_bar=True
    )

    # --- Print the results ---
    print("Study statistics: ")
    print(f"  Number of finished trials: {len(study.trials)}")

    print("Best trial:")
    best_trial = study.best_trial
    print(f"  Value (min eval loss): {best_trial.value}")

    print("  Params: ")
    for key, value in best_trial.params.items():
        print(f"    {key}: {value}")

In [None]:
if TRAIN_MODE == "train_with_optuna":
    results_path = os.path.join(save_path, "results")
    os.makedirs(results_path, exist_ok=True)

    # You can also save the results to a file
    df = study.trials_dataframe()
    path = os.path.join(results_path, "optuna_train_hyperparams_results.csv")
    df.to_csv(path)
    print(f"\nStudy results saved to {path}")

# Testing the Model

In [28]:
# loaded_model = BartWithRegression(os.path.join(save_path, "epoch-" + str(NUM_EPOCHS))).to(DEVICE)
loaded_model = BartWithRegression(
    ".\\out\\2025-07-21_09-40-43-UncertaintyLoss-SmoothL1Loss\\epoch-9",
    verbose=True
).to(DEVICE)
trainer = Trainer(
    model=loaded_model, args=None, device=DEVICE,
    train_dataloader=None,
    eval_dataloader=eval_dataloader,
    test_dataloader=test_dataloader
)

## Generation Hyperparameters Grid Search

In [None]:
# beam_search_grid = {
#     'num_beams': [3, 5, 7],
#     'repetition_penalty': [1.0, 1.2, 1.5],
#     'do_sample': [False] # Fix this to false for beam search
# }

# sampling_grid = {
#     'do_sample': [True], # Fix this to true for sampling
#     'top_k': [40, 50],
#     'top_p': [0.92, 0.95],
#     'temperature': [0.8, 0.85, 0.9, 0.95]
# }

# METRIC = 'sbert_similarity'

# keys, values = zip(*sampling_grid.items())
# hyperparameter_combinations = [dict(zip(keys, v)) for v in itertools.product(*values)]

# best_score = -float('inf')
# best_hyperparams = None

# search_progress_bar = tqdm(
#     hyperparameter_combinations,
#     desc="Generation Hyperparameters Grid Search"
# )
# for params in search_progress_bar:  
#     scores = trainer.evaluate(eval_dataloader, leave=False, **params)

#     current_score = scores[METRIC]
#     if current_score > best_score:
#         best_score = current_score
#         best_hyperparams = params

#     search_progress_bar.set_postfix({
#         'current_score': f'{current_score:.4f}',
#         'best_score': f'{best_score:.4f}'
#     })

# print(f"Best hyperparameters: {best_hyperparams}")
# print(f"Best score: {best_score:.4f}")

## Generation Hyperparameters Search with Optuna

In [14]:
def objective_generation(trial: optuna.Trial):
    decoding_strategy = trial.suggest_categorical("strategy", ["beam_search", "sampling"])

    gen_kwargs = {
        "max_length": 1024,
        "repetition_penalty": trial.suggest_float("repetition_penalty", 1.0, 1.3)
    }

    if decoding_strategy == "beam_search":
        gen_kwargs["num_beams"] = trial.suggest_int("num_beams", 2, 8)
        gen_kwargs["early_stopping"] = True # Usually a good idea with beam search
    
    elif decoding_strategy == "sampling":
        gen_kwargs["do_sample"] = True
        gen_kwargs["top_p"] = trial.suggest_float("top_p", 0.85, 0.98)
        gen_kwargs["top_k"] = trial.suggest_int("top_k", 20, 100) # Optional, often top_p is enough
        gen_kwargs["temperature"] = trial.suggest_float("temperature", 0.7, 1.0)

    print(f"--- Trial {trial.number}: Testing with {gen_kwargs} ---")

    # --- Run Evaluation ---
    # This is fast because it's only inference
    results = trainer.evaluate(eval_dataloader, leave=False, **gen_kwargs)

    # --- Return the Metric to Maximize ---
    # We want to maximize a semantic score. Let's choose BERTScore F1.
    metric_to_optimize = 'sbert_similarity'
    score = results[metric_to_optimize]
    print(f"Trial {trial.number} Result -> {metric_to_optimize}: {score:.4f}")
    
    # Clean up memory just in case, although less critical for inference
    gc.collect()
    torch.cuda.empty_cache()

    return metric_to_optimize

In [None]:
if TEST_MODE == "test_with_optuna":
    # For this search, we want to MAXIMIZE the score.
    study = optuna.create_study(
        direction="maximize",
        sampler=optuna.samplers.TPESampler(),
        # Pruning is not needed here because each trial is just one full evaluation,
        # there are no intermediate steps to prune.
    )

    # Run the optimization. 20-30 trials should be plenty for this.
    study.optimize(
        objective_generation,
        n_trials=OPTUNA_TEST_TRIALS,
        show_progress_bar=True,
        # n_jobs=-1
    )

    # --- Print the results ---
    print("\n\n--- Generation Hyperparameter Search Complete ---")
    print(f"  Number of finished trials: {len(study.trials)}")

    print("Best trial for generation:")
    best_trial = study.best_trial
    print(f"    Best Score: {best_trial.value:.4f}")

    print("  Best Generation Params: ")
    best_hyperparams = best_trial.params
    for key, value in best_hyperparams.items():
        print(f"    {key}: {value}")

    # Save the results
    df = study.trials_dataframe()
    path = os.path.join(results_path, "optuna_gen_hyperparams_results.csv")
    df.to_csv(path)
    print(f"\nGeneration study results saved to {path}")


Skipping generation hyperparameter search. Using default settings.
Using default hyperparameters: {'max_length': 1024, 'do_sample': True, 'top_p': 0.9, 'top_k': 100, 'temperature': 0.8}


## Default Generation Hyperparameters

In [29]:
if TEST_MODE == "test_no_optuna":
    best_hyperparams = {
        "max_length": 1024,
        # "num_beams": 3,
        # # "repetition_penalty": 1.8,
        # "early_stopping": True,
        "do_sample": True,
        "top_p": 0.95,
        # Top-k sampling is a simple generalization of greedy decoding. Instead of choosing
        # the single most probable word to generate, we first truncate the distribution to the
        # top k most likely words, renormalize to produce a legitimate probability distribution,
        # and then randomly sample from within these k words according to their renormalized
        # probabilities.
        "top_k": 20,
        "temperature": 0.6 # 0.6
    }

## Inference Example

In [50]:
def inference_example(input_text):
    """
    Run inference on a single input text and print the results.
    """
    tokenized = tokenizer(input_text, return_tensors='pt', truncation=True, max_length=1024)
    tokenized = {k: v.to(DEVICE) for k, v in tokenized.items()}

    output = loaded_model.generate(
        input_ids=tokenized['input_ids'][0].unsqueeze(0),
        attention_mask=tokenized['attention_mask'][0].unsqueeze(0),
        **best_hyperparams
    )

    print(f"Input:\n{input_text}")
    print(f"Polarities: {output['polarities']}")
    print(f"Explanations:\n{tokenizer.decode(output['explanations'][0], skip_special_tokens=True)}")

chat = '''
Topolino:
Sei la ragazza più bella del mondo. Spero che staremo per sempre insieme![TARGET][SEP]
'''
inference_example(chat)

Input:

Topolino:
Sei la ragazza più bella del mondo. Spero che staremo per sempre insieme![TARGET][SEP]

Polarities: tensor([0.5712], device='cuda:0', grad_fn=<SqueezeBackward1>)
Explanations:
Tag: Accettazione e Annuncio di Chiusura
Spiegazione: Topolino conclude la conversazione con una dichiarazione di affetto e speranza, sperando di mantenere un legame.


In [32]:
loaded_model.eval()
for example in tokenized_train_set.select(range(5)):
    output = loaded_model.generate(
        input_ids=example['input_ids'].unsqueeze(0).to(DEVICE),
        attention_mask=example['attention_mask'].unsqueeze(0).to(DEVICE),
        **best_hyperparams
    )
    # print(f"Output: {output['explanations'][0]}")
    decoded_output = tokenizer.decode(output['explanations'][0], skip_special_tokens=True)
    decoded_chat = tokenizer.decode(example['input_ids']) # , skip_special_tokens=True
    decoded_chat = decoded_chat.replace(SEP_TOKEN, SEP_TOKEN + "\n")
    decoded_true_explanation = tokenizer.decode(example['labels'], skip_special_tokens=True)

    print(f"Message:\n{decoded_chat}")
    print(f"True Polarity:{example['polarities']}")
    print(f"True Explanation:\n{decoded_true_explanation}")
    print(f"Generated Polarity:{output['polarities'].item()}")
    print(f"Generated Explanation:\n{decoded_output}")
    print("\n\n")

Message:
<s>Alessandro Conti:
Clara, devo comunicarti una cosa importante. Gli ultimi giorni sono stati estremamente intensi per il nuovo progetto su via del Corso. Richiede la mia totale concentrazione e purtroppo non mi sta lasciando molto spazio mentale per altro. [TARGET] [SEP]
 Clara Neri:
Oh, capisco, Ale. Immginavo fosse così, eri teso ultimamente. Mi dispiace tanto che tu sia sotto così tanta pressione. E non ti preoccupare per me, davvero. La tua serenità e la tua concentrazione sul lavoro sono la cosa più importante adesso. Prendi tutto lo spazio di cui hai bisogno. Sono qui per te. ❤️ [SEP]
 Alessandro Conti:
Apprezzo molto la tua comprensione, Clara. Non è una questione personale, ma una necessità logistica per gestire al meglio questa fase critica. Mi fa piacere sapere che sei di supporto. [SEP]
 Clara Neri:
Assolutamente sì, Ale. L'ho capito benissimo. La tua dedizione al lavoro è una delle cose che ammiro di più in te. Se hai bisogno di qualcosa di pratico, anche solo ch

## Evaluation of the Trained Model

In [None]:
results = trainer.evaluate(**best_hyperparams)
with open(os.path.join(results_path, "test_results.txt"), "w") as f:
    for key, value in results.items():
        print(f"{key}: {value:.4f}")
        f.write(f"{key}: {value:.4f}\n")