In [1]:
# !pip install transformers wandb rouge_score
# !pip install datasets==2.21.0
# !pip install evaluate
!pip install sacrebleu

Collecting sacrebleu
  Downloading sacrebleu-2.4.3-py3-none-any.whl.metadata (51 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/51.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m51.8/51.8 kB[0m [31m4.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting portalocker (from sacrebleu)
  Downloading portalocker-2.10.1-py3-none-any.whl.metadata (8.5 kB)
Collecting colorama (from sacrebleu)
  Downloading colorama-0.4.6-py2.py3-none-any.whl.metadata (17 kB)
Downloading sacrebleu-2.4.3-py3-none-any.whl (103 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/104.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m104.0/104.0 kB[0m [31m10.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading colorama-0.4.6-py2.py3-none-any.whl (25 kB)
Downloading portalocker-2.10.1-py3-none-any.whl (18 kB)
Installing collected packages: portalocker, colorama, sac

In [2]:
import torch
from transformers import GPT2TokenizerFast, GPT2LMHeadModel, DataCollatorWithPadding
from transformers import Trainer, TrainingArguments
from datasets import load_dataset, DatasetDict
import numpy as np
from rouge_score import rouge_scorer
import wandb
from transformers.integrations import WandbCallback

# load dataset
def load_billsum(num_examples=10000, test_size=0.1):
    dataset = load_dataset("billsum", split="ca_test")
    dataset = dataset.select(range(min(num_examples, len(dataset))))

    train_val_dataset = dataset.train_test_split(test_size=test_size)

    return DatasetDict({
        "train": train_val_dataset["train"],
        "validation": train_val_dataset["test"]
    })

# Initialize tokenizer
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# Preprocess function
def preprocess_function(examples):
    prefix = " TL;DR: "  # Prefix to guide the model for summarization

    # Prepare input texts with prefix
    inputs = [prefix + doc for doc in examples["text"]]

    # Tokenize inputs (source texts) with padding and truncation
    model_inputs = tokenizer(
        inputs,
        max_length=1024,
        truncation=True,
        padding="max_length"  # Uniform padding
    )

    # Tokenize labels (summaries) with padding and truncation
    labels = tokenizer(
        examples["summary"],
        max_length=128,
        truncation=True,
        padding="max_length"  # Uniform padding
    )

    # Replace padding token id's of the labels by -100 so it's ignored by the loss
    labels_ids = labels["input_ids"]
    labels_ids = [
        [(token_id if token_id != tokenizer.pad_token_id else -100) for token_id in label]
        for label in labels_ids
    ]

    model_inputs["labels"] = labels_ids  # Assign processed labels

    return model_inputs


# Load and preprocess the dataset
dataset = load_billsum()
tokenized_datasets = dataset.map(preprocess_function, batched=True)

# Function to freeze layers based on variant type
def freeze_layers(model, variant_type):
    if variant_type == "noNorm":
        for name, param in model.named_parameters():
            if "ln" in name:
                param.requires_grad = False
    elif variant_type == "AttnOnly":
        for name, param in model.named_parameters():
            if "ln_2" in name:  # Freeze FFN layer norm
                param.requires_grad = False
    elif variant_type == "FFNonly":
        for name, param in model.named_parameters():
            if "ln_1" in name:  # Freeze attention layer norm
                param.requires_grad = False
    # For baseModel, we don't freeze any layers



Map:   0%|          | 0/1113 [00:00<?, ? examples/s]

Map:   0%|          | 0/124 [00:00<?, ? examples/s]

In [3]:
import evaluate  # Import the evaluate library

# Initialize ROUGE and BLEU metrics
rouge = evaluate.load("rouge")
# bleu = evaluate.load("bleu")


import sacrebleu

def compute_metrics(eval_pred):
    """
    Compute ROUGE and BLEU metrics for summarization using SacreBLEU with smoothing.

    Args:
        eval_pred (EvalPrediction): Contains predictions and label_ids.

    Returns:
        dict: Average ROUGE and BLEU scores.
    """
    predictions, labels = eval_pred

    # Convert logits to token IDs by taking the argmax over the vocabulary dimension
    pred_ids = np.argmax(predictions, axis=-1)

    # Decode the predicted token IDs to text
    decoded_preds = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)

    # Replace -100 in the labels with the pad token ID and decode
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Ensure that the predictions and references are lists of strings
    # and remove any leading/trailing whitespace
    decoded_preds = [pred.strip() for pred in decoded_preds]
    decoded_labels = [label.strip() for label in decoded_labels]

    # Compute ROUGE scores using the evaluate library
    rouge_result = rouge.compute(
        predictions=decoded_preds,
        references=decoded_labels,
        use_stemmer=True
    )

    # Compute BLEU scores using SacreBLEU with smoothing
    bleu_scores = sacrebleu.corpus_bleu(
        decoded_preds,
        [decoded_labels],
        smooth_method='exp',       # Exponential smoothing
        smooth_value=0.1,
        force=True,                # Force compute even if length mismatch
        lowercase=True,            # Normalize case
        tokenize='13a'             # Tokenizer type (SacreBLEU default)
    )

    bleu_score = bleu_scores.score  # SacreBLEU returns a score attribute

    # Aggregate the results
    result = {
        "rouge1": rouge_result["rouge1"],
        "rouge2": rouge_result["rouge2"],
        "rougeL": rouge_result["rougeL"],
        "bleu": bleu_score
    }

    # Optional: Calculate average prediction length
    prediction_lens = [len(pred.split()) for pred in decoded_preds]
    result["gen_len"] = np.mean(prediction_lens)

    # Round the results to four decimal places for readability
    result = {k: round(v, 4) for k, v in result.items()}

    return result

In [4]:
# Fine-tuning function
def fine_tune_model(model, tokenizer, dataset, output_dir, variant, norm_type):
    # Initialize wandb run
    wandb.init(project=f"GPT-Valkyrie_{norm_type}-124m__{variant}__Billsum", reinit=True)
    run_name = wandb.run.name

    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=3,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=1,
        eval_accumulation_steps=2,
        warmup_steps=500,
        weight_decay=0.01,
        logging_dir="./logs",
        logging_steps=10,
        evaluation_strategy="steps",
        eval_steps=5,
        save_steps=10,
        load_best_model_at_end=True,
        metric_for_best_model="rougeL",
        report_to="wandb",
        run_name=run_name,
        save_total_limit=2,  # Limit the total number of checkpoints
    )

    from transformers import DataCollatorForLanguageModeling
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False  # Causal language modeling
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["validation"],
        tokenizer=tokenizer,
        data_collator=data_collator,  # Use the updated data collator
        compute_metrics=compute_metrics,
        callbacks=[WandbCallback()],
    )

    trainer.train()
    wandb.finish()
    return trainer.model, run_name

In [None]:
# Main training loop
variants = ["baseModel", "noNorm", "AttnOnly", "FFNonly"]
norm_types = ["LN", "RMSN"]

for norm_type in norm_types:
    for variant in variants:
        print(f"Processing {norm_type} {variant} model...")

        # Use the correct base model for each variant
        model_path = f"shng2025/GPT-Valkyrie_{norm_type}-124m__{variant}__"
        model = GPT2LMHeadModel.from_pretrained(model_path)

        model.config.pad_token_id = tokenizer.pad_token_id
        # Print to verify
        print(f"Tokenizer pad token: {tokenizer.pad_token}")
        print(f"Tokenizer pad token ID: {tokenizer.pad_token_id}")
        print(f"Model pad token ID: {model.config.pad_token_id}")

        freeze_layers(model, variant)

        output_dir = f"./results/{norm_type}/{variant}"
        fine_tuned_model, run_name = fine_tune_model(model, tokenizer, tokenized_datasets, output_dir, variant, norm_type)

        # Save the model locally
        local_save_dir = f"./local_models/GPT-Valkyrie_{norm_type}-124m__{variant}__CNN-DM"
        fine_tuned_model.save_pretrained(local_save_dir)
        tokenizer.save_pretrained(local_save_dir)
        print(f"Model saved locally to {local_save_dir}")

        # Push the model to your HuggingFace Hub repository
        new_repo_name = f"shng2025/GPT-Valkyrie_{norm_type}-124m__{variant}__CNN-DM"
        fine_tuned_model.push_to_hub(new_repo_name, branch=run_name)
        tokenizer.push_to_hub(new_repo_name, branch=run_name)
        print(f"Model pushed to HuggingFace Hub: {new_repo_name}, branch: {run_name}")

print("Training completed for all variants and normalization types.")

Processing LN baseModel model...
Tokenizer pad token: <|endoftext|>
Tokenizer pad token ID: 50256
Model pad token ID: 50256


[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112897011106624, max=1.0…

You are adding a <class 'transformers.integrations.integration_utils.WandbCallback'> to the callbacks of this Trainer, but there is already one. The currentlist of callbacks is
:DefaultFlowCallback
WandbCallback


Step,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Bleu,Gen Len
5,No log,2.461563,0.7077,0.3517,0.4877,32.1037,756.6532
