In [None]:
import sys
sys.path.insert(1, '/home/alon/')

from hf_token import HF_AUTH_TOKEN 
from pathlib import Path
import matplotlib.pyplot as plt
from scipy.stats import entropy
from sklearn.metrics import mean_squared_error

import os
# Expose only GPU 0 to this process. This is set before torch is imported
# (in a later cell) so that it is in place before CUDA is initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Presumably turns off Weights & Biases reporting from the Hugging Face
# Trainer so runs do not need a W&B login -- TODO confirm for the installed version.
os.environ["WANDB_DISABLED"] = "true"

# --- Filesystem layout -------------------------------------------------------
BASE_DIR = Path("/home", "alon", "sdg", "tests")
CDLM_DIR = BASE_DIR / "cdlm"
# Directory (not a single file, despite the name) holding the train/test CSVs.
DATA_FNAME = CDLM_DIR / "tweets_datasets_of_eq_size"
TRAIN_PATH = DATA_FNAME / "train_set.csv"
TEST_PATH = DATA_FNAME / "test_set.csv"

# --- Text / label configuration ----------------------------------------------
# Upper bound on tweets per row; the CSV column is named `concatenated_100_tweets`.
MAX_TWEETS_NUM = 100
# Markers wrapped around each individual tweet inside the concatenated text.
DOC_SEP_START = r"<doc-s>"
DOC_SEP_END = r"</doc-s>"
PADDING_TOKEN = "[PAD]"
# When True, the 18 SDG labels (0-17) are collapsed into 6 classes (see map_to_5ps).
LABELS_TO_5_PS = True

## Setup and Imports

In [None]:
import torch
from datasets import Dataset, DatasetDict
from transformers import ( AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer, )
from sklearn.metrics import f1_score
import numpy as np
import evaluate
import pandas as pd
from collections import Counter
from typing import Dict

## 1. Prepare Your Dataset

In [None]:
def map_to_5ps(sdg_label_int: int) -> int:
    """Collapse an SDG label (0-17) into one of six coarse classes (0-5).

    Mapping:
        0                    -> 0 (returned unchanged)
        1..5                 -> 1
        6, 12, 13, 14, 15    -> 2
        7..11                -> 3
        16                   -> 4
        17                   -> 5

    NOTE(review): the grouping looks like the UN "5 Ps" clustering of the
    SDGs, with 0 kept as a separate class -- confirm with the data owner.

    Raises:
        ValueError: if the label matches none of the groups above.
    """
    if sdg_label_int == 0:
        return sdg_label_int

    # (membership predicate, target class), evaluated top to bottom.
    grouping_rules = (
        (lambda goal: 1 <= goal <= 5, 1),
        (lambda goal: goal in [6, 12, 13, 14, 15], 2),
        (lambda goal: 7 <= goal <= 11, 3),
        (lambda goal: goal == 16, 4),
        (lambda goal: goal == 17, 5),
    )
    for belongs_to_group, coarse_label in grouping_rules:
        if belongs_to_group(sdg_label_int):
            return coarse_label

    raise ValueError("Must provide a valid sdg label (int 0-17)")
        

def cut_tweets_by_cutoff(concatenated_tweets: str) -> str:
    """Keep only the first ``CUR_TWEETS_NUM`` tweets of a concatenated string.

    The input is expected to look like
    ``<doc-s>tweet 1</doc-s><doc-s>tweet 2</doc-s>...``; tweets are split on
    the ``DOC_SEP_END + DOC_SEP_START`` boundary, truncated, and re-joined
    with the closing ``DOC_SEP_END`` restored.

    Args:
        concatenated_tweets: all tweets of one row, wrapped in the
            ``DOC_SEP_START`` / ``DOC_SEP_END`` markers.

    Returns:
        The same format, containing at most ``CUR_TWEETS_NUM`` tweets.
    """
    # Drop the trailing end-marker as an exact suffix. `str.rstrip(DOC_SEP_END)`
    # would treat the marker as a *set of characters* and also eat trailing
    # tweet text made of '<', '/', 'd', 'o', 'c', '-', 's', '>' (e.g. "docs").
    if concatenated_tweets.endswith(DOC_SEP_END):
        concatenated_tweets = concatenated_tweets[: -len(DOC_SEP_END)]

    tweet_boundary = f"{DOC_SEP_END}{DOC_SEP_START}"
    kept_tweets = concatenated_tweets.split(tweet_boundary)[:CUR_TWEETS_NUM]
    return tweet_boundary.join(kept_tweets) + DOC_SEP_END

def get_tweets_based_df(csv_full_path: Path) -> pd.DataFrame:
    '''
    Load a CSV and return a dataframe with the columns 'label' and 'text',
    where 'text' is the 'concatenated_100_tweets' column after truncation
    by cut_tweets_by_cutoff.
    '''
    source_col = 'concatenated_100_tweets'
    raw = pd.read_csv(csv_full_path, usecols=['label', source_col])
    truncated = raw[source_col].apply(cut_tweets_by_cutoff)
    # Replace the raw column by the truncated one, appended as 'text'.
    return raw.drop(columns=source_col).assign(text=truncated)


def get_description_based_df(csv_full_path: Path) -> pd.DataFrame:
    '''
    Load only the 'label' and 'description' columns of the given CSV and
    return them as a dataframe in which 'description' is renamed to 'text'.
    '''
    wanted_columns = ['label', 'description']
    return (
        pd.read_csv(csv_full_path, usecols=wanted_columns)
        .rename(columns={"description": "text"})
    )

def get_dataset(csv_full_path: Path, description: bool = True):
    '''
    Read a CSV and return its rows as a list of {"text": ..., "label": ...} dicts.

    The text comes from the description column when `description` is truthy,
    otherwise from the (truncated) concatenated tweets. When LABELS_TO_5_PS is
    set, labels are first collapsed with map_to_5ps and the resulting label
    set is printed.
    '''
    load_frame = get_description_based_df if description else get_tweets_based_df
    frame = load_frame(csv_full_path)

    if LABELS_TO_5_PS:
        # Overwrite the fine-grained SDG label with its coarse class.
        frame = frame.assign(label=frame['label'].apply(map_to_5ps))
        print('set(df["label"])', set(frame["label"]))

    return [
        {"text": record["text"], "label": record["label"]}
        for _, record in frame.iterrows()
    ]

def get_sorted_label_dict(ds) -> list[tuple[int, int]]:
    """Count how often each label occurs and return the counts sorted by label.

    Despite the function name, the result is a *list* of ``(label, count)``
    tuples (the output of ``sorted(...items())``), not a dict -- which is the
    shape ``pairs_list_to_histogram`` feeds into ``pd.DataFrame``.

    Args:
        ds: iterable of mappings that each have a ``"label"`` key.

    Returns:
        ``[(label, count), ...]`` in ascending label order; ``[]`` for empty input.
    """
    label_counts = Counter(example["label"] for example in ds)
    return sorted(label_counts.items())


def pairs_list_to_histogram(pl: list[tuple[int, int]], title: str):
    """Draw a bar chart of label frequencies.

    Args:
        pl: sequence of ``(label, frequency)`` pairs, e.g. the output of
            ``get_sorted_label_dict``. A ``Counter`` is not a valid input here:
            the DataFrame constructor below needs row-like pairs to fill the
            two named columns.
        title: title shown above the chart.
    """
    frequencies = pd.DataFrame(pl, columns=['label', 'frequency'])
    frequencies.plot(kind='bar', x='label', title=title)


# Number of tweets kept per row by cut_tweets_by_cutoff (25 with the current constants).
CUR_TWEETS_NUM = MAX_TWEETS_NUM-75

# NOTE(review): get_dataset is called with its default `description=True`, so
# the text used here is the description column and CUR_TWEETS_NUM does not
# affect the loaded data -- confirm this is intended.
train_dataset = get_dataset(TRAIN_PATH)
test_dataset = get_dataset(TEST_PATH)

# Sorted (label, count) pairs for each split.
train_label_dist = get_sorted_label_dict(train_dataset)
test_label_dist = get_sorted_label_dict(test_dataset)

In [None]:
# Wrap the list-of-dict records into Hugging Face Dataset objects.
# Note: `train_dataset` is rebound from a list to a Dataset here, so this cell
# expects the previous cell to have been (re)run first.
train_dataset = Dataset.from_list(train_dataset)
eval_dataset = Dataset.from_list(test_dataset)

# Split sizes, one per line (train first, then eval).
print(len(train_dataset), len(eval_dataset), sep="\n")

splits = {'train': train_dataset, 'eval': eval_dataset}
dataset = DatasetDict(splits)

## 2. Load Pre-trained Model and Tokenizer

In [None]:
model_name = "meta-llama/Llama-3.2-1B-Instruct" 
tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_AUTH_TOKEN)
# One output unit per class: 6 coarse classes when labels are collapsed, else 18 SDG labels (0-17).
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels= 6 if LABELS_TO_5_PS else 18, token=HF_AUTH_TOKEN)

# Fall back to the EOS token for padding when the tokenizer ships without a pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

# Always tell the model which id is padding, not only in the fallback case above:
# the sequence-classification head needs `config.pad_token_id` to locate the last
# real token of each padded sequence, and the inputs here are padded to max_length.
model.config.pad_token_id = tokenizer.pad_token_id

## 3. Tokenize the Dataset

In [None]:
def tokenize_function(examples):
    """Tokenize the "text" field of a batch with the module-level tokenizer.

    Every sequence is truncated and padded to exactly 512 tokens.
    """
    encode_options = dict(padding="max_length", truncation=True, max_length=512)
    return tokenizer(examples["text"], **encode_options)

# Apply the tokenizer to every split, passing batches of examples at a time.
tokenized_datasets = dataset.map(tokenize_function, batched=True)

## 4. Define Training Arguments

In [None]:
# `eval_strategy` replaces the deprecated `evaluation_strategy` keyword. Any
# transformers release that accepts `save_strategy="best"` (used below) already
# supports the new name, so this does not raise the minimum version.
training_args = TrainingArguments(output_dir="./alonm_tst_finetune5", 
                                  eval_strategy="steps",  # evaluate every `eval_steps` (library default interval)
                                  save_strategy="best",  # checkpoint only when the tracked metric improves
                                  learning_rate=2e-5, 
                                  per_device_train_batch_size=1, 
                                  per_device_eval_batch_size=1, 
                                  num_train_epochs=15,
                                  weight_decay=0.01,
                                  save_total_limit=2,
                                  load_best_model_at_end=True,
                                  metric_for_best_model="eval_loss",
                                  push_to_hub=False, 
                                  fp16=True,  # mixed-precision training
                                  gradient_checkpointing=True,  # trade compute for lower activation memory
                                 )
print("Done")

## 5. Define Evaluation Metric

In [None]:
accuracy_metric = evaluate.load("accuracy")
# Output folder for the per-evaluation label-distribution plots.
plots_dir = Path(BASE_DIR, "llm", "alonm_tst_finetune5", f"plots_tweets_{CUR_TWEETS_NUM}")
# parents=True: also create the intermediate "llm/alonm_tst_finetune5" folders,
# otherwise mkdir raises FileNotFoundError on a fresh checkout/machine.
plots_dir.mkdir(parents=True, exist_ok=True)
# Running counter used in the plot file names; compute_metrics increments it
# on every call (i.e. once per evaluation, despite the name).
EPOCH_NUM=1

def plot_distribution(ax, hist_data, total_samples, title_prefix):
    """Draw a per-label count bar chart on the given axes.

    Args:
        ax: object exposing the matplotlib Axes drawing methods used below
            (presumably a matplotlib Axes).
        hist_data: one count per label id, in label order.
        total_samples: sample count shown in the title as ``n=...``.
        title_prefix: text placed before "Label Distribution" in the title.
    """
    label_ids = range(6 if LABELS_TO_5_PS else 18)

    ax.bar(label_ids, hist_data, width=0.8, alpha=0.7)
    ax.set_xticks(label_ids)
    ax.set_xlabel('Label ID')
    ax.set_ylabel('Count')

    chart_title = f'{title_prefix} Label Distribution (n={total_samples})'
    ax.set_title(chart_title)
    # Horizontal guide lines only.
    ax.grid(axis='y', alpha=0.3)

def compute_metrics(eval_pred):
    """Compute classification and distribution-level metrics for one evaluation.

    Side effects: saves a side-by-side bar chart of the reference vs. predicted
    label counts to ``plots_dir`` and increments the module-level ``EPOCH_NUM``
    counter that numbers those files.

    Args:
        eval_pred: pair ``(logits, labels)``; the predicted class is the
            argmax of the logits over the last axis.

    Returns:
        dict with ``f1_weighted``, ``f1_macro``, ``f1_micro`` plus ``kl_div``
        (KL(reference || predicted)), ``js_div`` (Jensen-Shannon divergence)
        and ``rmse`` between the two label distributions.
    """
    global EPOCH_NUM

    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)

    # Integer bin edges 0..K give one bin per label id.
    bin_edges = range(7 if LABELS_TO_5_PS else 19)
    ref_hist = np.histogram(labels, bins=bin_edges)[0]
    pred_hist = np.histogram(predictions, bins=bin_edges)[0]

    # --- plot: reference (left) vs. predicted (right) label counts ---
    fig, (ref_ax, pred_ax) = plt.subplots(1, 2, figsize=(12, 5))
    plot_distribution(ref_ax, ref_hist, len(labels), "Reference")
    plot_distribution(pred_ax, pred_hist, len(predictions), "Predicted")
    fig.tight_layout()
    fig.savefig(plots_dir / f"distribution_comparison_{EPOCH_NUM}.png")
    plt.close(fig)
    EPOCH_NUM += 1

    # --- distribution-level metrics ---
    ref_dist = ref_hist / ref_hist.sum()
    pred_dist = pred_hist / pred_hist.sum()

    def _smoothed(dist, epsilon=1e-10):
        # Add a tiny constant so no class has zero mass, then renormalise.
        shifted = dist + epsilon
        return shifted / shifted.sum()

    ref_dist_smooth = _smoothed(ref_dist)
    pred_dist_smooth = _smoothed(pred_dist)

    kl_div = entropy(ref_dist_smooth, pred_dist_smooth)

    mixture = 0.5 * (ref_dist_smooth + pred_dist_smooth)
    js_div = 0.5 * (entropy(ref_dist_smooth, mixture) + entropy(pred_dist_smooth, mixture))

    # RMSE is taken on the unsmoothed distributions.
    rmse = np.sqrt(mean_squared_error(ref_dist, pred_dist))

    # --- per-sample classification metrics ---
    metrics = {}
    for avg_type in ['weighted', 'macro', 'micro']:
        metrics[f"f1_{avg_type}"] = f1_score(labels, predictions, average=avg_type)
    metrics.update({"kl_div": kl_div, "js_div": js_div, "rmse": rmse})
    return metrics

## 6. Create Trainer and Fine-tune the Model

In [None]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["eval"],
    compute_metrics=compute_metrics,
)

# Scaled-dot-product-attention backends allowed while training: the flash and
# memory-efficient kernels are enabled, the plain math fallback is disabled.
sdp_backends = dict(enable_flash=True, enable_mem_efficient=True, enable_math=False)
with torch.backends.cuda.sdp_kernel(**sdp_backends):
    trainer.train()

## 7. Evaluate the Fine-tuned Model

In [None]:
# Final evaluation on the eval split. The trainer was built with
# `compute_metrics`, so this presumably triggers one more call to it (saving
# another distribution plot and bumping EPOCH_NUM) -- verify against the Trainer docs.
evaluation_results = trainer.evaluate()

# 