In [None]:
import evaluate
import numpy as np
import pandas as pd
from datasets import Dataset
from sklearn.model_selection import train_test_split
from transformers import (
    BartForConditionalGeneration,
    BartTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    T5ForConditionalGeneration,
    T5Tokenizer,
    Trainer,
    TrainingArguments,
)

In [None]:
# Read the preprocessed dataset, using the first CSV column as the row index,
# then convert the index labels to pandas datetimes.
processed_csv_path = "../data/processed_data.csv"
df = pd.read_csv(processed_csv_path, index_col=[0])
df.index = pd.to_datetime(df.index)

In [None]:
# Load the pretrained T5-small checkpoint: tokenizer and seq2seq model together.
# NOTE(review): a later cell rebinds `tokenizer` and `model` to a BART checkpoint;
# anything tokenized with this tokenizer is only valid for this T5 model.
model_name = "t5-small"
tokenizer, model = (
    T5Tokenizer.from_pretrained(model_name),
    T5ForConditionalGeneration.from_pretrained(model_name),
)

In [None]:
# Chronological hold-out: order rows by 'datetime', train on the earliest 90%
# and validate on the most recent 10%.
# NOTE(review): the load cell parses the *index* as datetimes; the min/max
# lookups below need 'datetime' to also exist as a regular column - confirm.
df = df.sort_values(by='datetime')
split_index = int(len(df) * 0.9)
df_train_time, df_val_time = df.iloc[:split_index], df.iloc[split_index:]

print("\n--- Time-based Split ---")
print(f"Total samples: {len(df)}")
print(f"Training samples (90%): {len(df_train_time)}")
print(f"Validation samples (10%): {len(df_val_time)}")
print(f"Training data goes up to: {df_train_time['datetime'].max()}")
print(f"Validation data starts from: {df_val_time['datetime'].min()}")

In [None]:
# Company-based hold-out: 10% of the distinct companies (fixed seed) are kept
# for validation, so no validation company appears in the training rows.
company_column = df['inferred company']
all_companies = company_column.unique()
train_companies, val_companies = train_test_split(
    all_companies,
    test_size=0.1,
    random_state=42,
)

# Select rows by which side of the company split they fall on.
df_train_brands = df.loc[company_column.isin(train_companies)]
df_val_brands = df.loc[company_column.isin(val_companies)]

print(f"Companies for training: {len(train_companies)}")
print(f"Companies for validation: {len(val_companies)}")
print(f"Training samples: {len(df_train_brands)}")
print(f"Validation samples: {len(df_val_brands)}")

In [None]:
# Wrap the time-based splits as Hugging Face Datasets holding only the two
# text columns the model needs.
# preserve_index=False: the frames carry a DatetimeIndex (set in the load
# cell), which from_pandas would otherwise materialise as an extra column that
# survives the later remove_columns(['prompt', 'content']) call.
train_dataset = Dataset.from_pandas(
    df_train_time[['prompt', 'content']], preserve_index=False
)
val_dataset = Dataset.from_pandas(
    df_val_time[['prompt', 'content']], preserve_index=False
)

In [None]:
def tokenize_function(examples):
    """Tokenize prompts as encoder inputs and tweets as decoder labels.

    Uses the module-level ``tokenizer`` that is bound at call time, so the
    datasets must be (re-)mapped whenever ``tokenizer`` is rebound.

    Args:
        examples: A mapping with ``'prompt'`` and ``'content'`` entries. Each
            is either a single string or a list of strings (the latter when
            called through ``Dataset.map(..., batched=True)``).

    Returns:
        The tokenizer output for the prompts (padded/truncated to 128 tokens)
        with an added ``"labels"`` entry holding the target token ids
        (padded/truncated to 64 tokens). Padding positions in the labels are
        set to -100.
    """
    model_inputs = tokenizer(
        examples['prompt'],
        max_length=128,  # Max length for the input prompt
        padding="max_length",
        truncation=True,
    )

    # `text_target` replaces the deprecated `as_target_tokenizer()` context
    # manager for encoding decoder targets.
    labels = tokenizer(
        text_target=examples['content'],
        max_length=64,   # Max length for the output tweet
        padding="max_length",
        truncation=True,
    )

    # Labels are padded to a fixed length here, so the padding must be masked
    # with -100; otherwise the loss is also computed on pad tokens.
    pad_id = tokenizer.pad_token_id

    def _mask_padding(token_ids):
        return [tok if tok != pad_id else -100 for tok in token_ids]

    label_ids = labels["input_ids"]
    if isinstance(examples['content'], str):
        # Single (non-batched) example: one flat list of ids.
        model_inputs["labels"] = _mask_padding(label_ids)
    else:
        model_inputs["labels"] = [_mask_padding(seq) for seq in label_ids]
    return model_inputs

In [None]:
# Tokenize both splits in batches and drop the raw text columns in the same
# pass, then show one encoded training example as a sanity check.
print("\n--- Tokenizing datasets... ---")
raw_text_columns = ['prompt', 'content']
tokenized_train_dataset = train_dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=raw_text_columns,
)
tokenized_val_dataset = val_dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=raw_text_columns,
)
print(tokenized_train_dataset[0])

In [None]:
# Load the ROUGE and BLEU scorers through the `evaluate` library.
metric_rouge, metric_bleu = (
    evaluate.load(metric_id) for metric_id in ("rouge", "bleu")
)

In [None]:
def compute_metrics(eval_preds):
    """Score predictions against references with ROUGE and BLEU.

    Args:
        eval_preds: A ``(predictions, labels)`` pair. ``predictions`` may be
            an array of token ids, an array of logits (3-D), or a tuple whose
            first element is one of those. ``labels`` is an array of token
            ids in which -100 marks ignored positions.

    Returns:
        A dict with ``"rouge1"``, ``"rouge2"``, ``"rougeL"`` and ``"bleu"``.
    """
    preds, labels = eval_preds

    # Some models return a tuple of outputs; the first entry is the one to decode.
    if isinstance(preds, tuple):
        preds = preds[0]
    preds = np.asarray(preds)
    # A plain (non-generating) evaluation loop yields logits with a trailing
    # vocabulary axis; reduce them to token ids before decoding.
    if preds.ndim == 3:
        preds = preds.argmax(axis=-1)

    # -100 is not a valid token id and cannot be decoded: swap it for the pad
    # token in both predictions and labels (it is then skipped as a special token).
    pad_id = tokenizer.pad_token_id
    preds = np.where(preds != -100, preds, pad_id)
    labels = np.where(labels != -100, labels, pad_id)

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # ROUGE metrics
    rouge_result = metric_rouge.compute(
        predictions=decoded_preds,
        references=decoded_labels,
        use_stemmer=True
    )

    # BLEU metrics
    # For BLEU, references must be a list of lists
    decoded_labels_list = [[label] for label in decoded_labels]
    try:
        bleu_score = metric_bleu.compute(
            predictions=decoded_preds,
            references=decoded_labels_list
        )["bleu"]
    except ZeroDivisionError:
        # Guard for degenerate batches (e.g. every prediction decodes to an
        # empty string), where the BLEU length ratio cannot be computed.
        bleu_score = 0.0

    # Report only the headline scores.
    return {
        "rouge1": rouge_result["rouge1"],
        "rouge2": rouge_result["rouge2"],
        "rougeL": rouge_result["rougeL"],
        "bleu": bleu_score,
    }

In [None]:
# --- 1. Model and tokenizer ---
model_name = "facebook/bart-base"
tokenizer = BartTokenizer.from_pretrained(model_name)
model = BartForConditionalGeneration.from_pretrained(model_name)

# The tokenized datasets built earlier were encoded with whatever `tokenizer`
# was bound at that point (the T5 tokenizer). BART uses a different vocabulary,
# so re-encode the raw text now that `tokenizer` is the BART tokenizer.
# `tokenize_function` reads the global `tokenizer` at call time.
tokenized_train_dataset = train_dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=train_dataset.column_names,
)
tokenized_val_dataset = val_dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=val_dataset.column_names,
)

# --- 2. Data Collator ---
# Batches features for seq2seq training; receiving the model lets it prepare
# decoder inputs from the labels.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

# --- 3. Metric Function ---
rouge = evaluate.load("rouge")

def compute_metrics(eval_pred):
    """Compute ROUGE scores (as percentages) for an evaluation pass.

    Args:
        eval_pred: A ``(predictions, labels)`` pair. ``predictions`` may be
            an array of token ids, an array of logits (3-D), or a tuple whose
            first element is one of those. ``labels`` is an array of token
            ids in which -100 marks ignored positions.

    Returns:
        A dict mapping each ROUGE variant to its score multiplied by 100 and
        rounded to two decimals. ``"rougeL"`` is always present because it is
        the metric used to select the best checkpoint.
    """
    predictions, labels = eval_pred

    # Some models return a tuple of outputs; the first entry is the one to decode.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    predictions = np.asarray(predictions)
    # Logits carry a trailing vocabulary axis; reduce them to token ids.
    if predictions.ndim == 3:
        predictions = predictions.argmax(axis=-1)

    # Replace -100 in predictions and labels as we can't decode it.
    pad_id = tokenizer.pad_token_id
    predictions = np.where(predictions != -100, predictions, pad_id)
    labels = np.where(labels != -100, labels, pad_id)

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Compute ROUGE scores
    raw_scores = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)

    # `evaluate`'s ROUGE returns plain floats; the older aggregate objects
    # exposing `.mid.fmeasure` are still accepted for compatibility.
    result = {}
    for key, value in raw_scores.items():
        score = value.mid.fmeasure if hasattr(value, "mid") else value
        result[key] = round(float(score) * 100, 2)

    # Add a key for the metric-for-best-model
    result["rougeL"] = result.get("rougeL", 0.0)
    return result

# --- 4. Training Arguments ---
# Seq2Seq variants are required so evaluation calls `generate()` and hands
# token ids to `compute_metrics`; the plain Trainer passes raw logits, which
# cannot be decoded into text.
training_args = Seq2SeqTrainingArguments(
    output_dir="./results_bart",
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir="./logs_bart",
    save_steps=1000,          # multiple of eval_steps, as load_best_model_at_end requires
    logging_steps=100,
    load_best_model_at_end=True,
    metric_for_best_model="rougeL",
    eval_strategy="steps",
    eval_steps=500,
    save_total_limit=2,
    fp16=True,                # NOTE(review): needs a CUDA device - confirm the runtime
    report_to="tensorboard",
    predict_with_generate=True,
    generation_max_length=64,  # same cap as the label length used at tokenization
)

# --- 5. Trainer ---
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_dataset,  # produced by the tokenization step
    eval_dataset=tokenized_val_dataset,     # produced by the tokenization step
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

# --- 6. Train ---
print("Starting training...")
trainer.train()
print("Training finished.")

# --- 7. Save Model ---
# NOTE(review): checkpoints go to "./results_bart" (output_dir) while the final
# model is written one directory up - confirm this layout is intended.
final_model_dir = "../results_bart/final_model"
print("Saving final model...")
trainer.save_model(final_model_dir)
tokenizer.save_pretrained(final_model_dir)
# Report the path that was actually written.
print(f"Model saved to {final_model_dir}")