In [1]:
# !pip install transformers
# !pip install datasets==2.21.0
# !pip install wandb


In [2]:
import torch
from transformers import GPT2ForQuestionAnswering, GPT2TokenizerFast, Trainer, TrainingArguments
from datasets import load_dataset, load_metric
from torch.nn.utils.rnn import pad_sequence
import os
import wandb
from transformers.integrations import WandbCallback
import numpy as np

from datasets import load_metric # used in compute_metrics
from transformers.trainer_utils import EvalPrediction
from typing import Dict, List


def freeze_layers(model, variant_type):
    if variant_type == "noNorm":
        for name, param in model.named_parameters():
            if "ln" in name:
                param.requires_grad = False
    elif variant_type == "AttnOnly":
        for name, param in model.named_parameters():
            if "ln_2" in name:  # Freeze FFN layer norm
                param.requires_grad = False
    elif variant_type == "FFOnly":
        for name, param in model.named_parameters():
            if "ln_1" in name:  # Freeze attention layer norm
                param.requires_grad = False
    # For baseModel, we don't freeze any layers

def prepare_squad_dataset(tokenizer):
    dataset = load_dataset("squad")

    def preprocess_function(examples):
        questions = examples["question"]
        contexts = examples["context"]
        answers = examples["answers"]
        example_ids = examples["id"]

        # Use a unique separator between question and context
        separator = tokenizer.eos_token  # GPT-2's eos_token is '<|endoftext|>'
        separator_length = len(separator)

        # Concatenate question and context with the separator
        inputs = [question + separator + context for question, context in zip(questions, contexts)]
        
        # Keep track of question lengths to adjust character positions later
        question_lengths = [len(question) for question in questions]

        # Tokenize concatenated inputs
        tokenized_examples = tokenizer(
            inputs,
            max_length=384,
            truncation=True,
            stride=128,
            return_overflowing_tokens=True,
            padding="max_length",
            return_offsets_mapping=True,  # Ensure offsets are returned
        )

        # Map from features to examples
        sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
        offset_mapping = tokenized_examples.pop("offset_mapping")  # We will add this back to tokenized_examples

        # Add example_id to the tokenized examples
        tokenized_examples["example_id"] = []

        # Initialize lists
        tokenized_examples["offset_mapping"] = []
        tokenized_examples["example_id"] = []
        tokenized_examples["start_positions"] = []
        tokenized_examples["end_positions"] = []

        for i in range(len(tokenized_examples["input_ids"])):
            # Map feature to its example index
            sample_index = sample_mapping[i]
            tokenized_examples["example_id"].append(example_ids[sample_index])

            # Get the offsets for the current feature
            offsets = offset_mapping[i]
            tokenized_examples["offset_mapping"].append(offsets)  # Add offsets to features

            # Get the answer text and its start position
            answer = answers[sample_index]
            if len(answer["answer_start"]) == 0 or len(answer["text"][0]) == 0:
                # If there's no answer, set positions to 0
                tokenized_examples["start_positions"].append(0)
                tokenized_examples["end_positions"].append(0)
                # Do not use 'continue' here; allow the loop to proceed
                continue

            # Compute the start and end character positions of the answer in the concatenated input
            answer_start_char = answer["answer_start"][0]
            answer_end_char = answer_start_char + len(answer["text"][0])

            # Adjust the answer positions to account for the question and separator
            context_start_char = question_lengths[sample_index] + separator_length
            adjusted_answer_start = answer_start_char + context_start_char
            adjusted_answer_end = answer_end_char + context_start_char

            # Find the start and end token indices in the tokenized input
            start_position = None
            end_position = None
            for idx, (offset_start, offset_end) in enumerate(offsets):
                if offset_start is None or offset_end is None:
                    continue
                if offset_start <= adjusted_answer_start < offset_end:
                    start_position = idx
                if offset_start < adjusted_answer_end <= offset_end:
                    end_position = idx
                    break
            if start_position is not None and end_position is not None:
                tokenized_examples["start_positions"].append(start_position)
                tokenized_examples["end_positions"].append(end_position)
            else:
                # If the answer is not found in the tokenized input, set positions to 0
                tokenized_examples["start_positions"].append(0)
                tokenized_examples["end_positions"].append(0)

        return tokenized_examples

    tokenized_train = dataset['train'].map(
        preprocess_function,
        batched=True,
        remove_columns=dataset['train'].column_names,
        load_from_cache_file=False,
    )

    # Prepare the validation dataset and collect features
    validation_features = []

    def preprocess_validation_function(examples):
        # Use the same preprocessing function
        tokenized_examples = preprocess_function(examples)

        # Collect features with example_id
        for i in range(len(tokenized_examples["input_ids"])):
            feature = {key: tokenized_examples[key][i] for key in tokenized_examples.keys()}
            validation_features.append(feature)

        return tokenized_examples

    tokenized_validation = dataset['validation'].map(
        preprocess_validation_function,
        batched=True,
        remove_columns=dataset['validation'].column_names,
        load_from_cache_file=False,
    )

    return tokenized_train, tokenized_validation, validation_features


In [3]:
from tqdm.auto import tqdm
import collections
import numpy as np

def postprocess_qa_predictions_single_logits(
    examples, features, raw_predictions, tokenizer, n_best_size=20, max_answer_length=30
):
    import numpy as np
    import collections
    from tqdm.auto import tqdm

    all_logits = raw_predictions  # raw_predictions: (num_features, sequence_length)

    # Build a map example to its corresponding features.
    example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
    features_per_example = collections.defaultdict(list)
    for i, feature in enumerate(features):
        features_per_example[feature["example_id"]].append(i)

    # The dictionaries we have to fill.
    predictions = collections.OrderedDict()

    # Logging.
    print(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")

    pad_token_id = tokenizer.pad_token_id

    # Loop over all the examples.
    for example_index, example in enumerate(tqdm(examples)):
        example_id = example["id"]
        # Indices of the features associated with the current example.
        feature_indices = features_per_example[example_id]

        valid_answers = []

        context = example["context"]
        # Loop through all features associated with the current example.
        for feature_index in feature_indices:
            # Grab the predictions of the model for this feature.
            logits = all_logits[feature_index]
            # Map positions in logits to spans of text in the original context.
            offsets = features[feature_index]["offset_mapping"]
            input_ids = features[feature_index]["input_ids"]

            # Find the index of the pad_token (used as separator)
            try:
                sep_index = input_ids.index(pad_token_id)
            except ValueError:
                # Separator token not found; skip this feature
                continue

            # The context starts after the separator token
            context_start = sep_index + 1

            # Only consider context tokens
            context_offsets = offsets[context_start:]
            context_logits = logits[context_start:]

            # Get indices of the top logits
            top_indices = np.argsort(context_logits)[-n_best_size:]

            # Generate possible answer spans based on top scoring tokens
            for idx in top_indices:
                start_index = context_start + idx
                for end_index in range(start_index, min(start_index + max_answer_length, len(logits))):
                    if end_index >= len(offsets):
                        break
                    if offsets[start_index] is None or offsets[end_index] is None:
                        continue
                    # Compute the score for the span (sum of logits)
                    span_score = logits[start_index] + logits[end_index]
                    start_char = offsets[start_index][0]
                    end_char = offsets[end_index][1]
                    answer_text = context[start_char:end_char]
                    valid_answers.append(
                        {
                            "score": span_score,
                            "text": answer_text
                        }
                    )

        # Select the best answer (with the highest score)
        if len(valid_answers) > 0:
            best_answer = sorted(valid_answers, key=lambda x: x["score"], reverse=True)[0]
        else:
            # In case no valid answer is found
            best_answer = {"text": "", "score": 0.0}

        predictions[example_id] = best_answer["text"]

    return predictions

In [4]:
from transformers import Trainer

class PredictionTrainer(Trainer):
    def prediction_step(
        self,
        model,
        inputs,
        prediction_loss_only,
        ignore_keys=None,
    ):
        if not prediction_loss_only:
            has_labels = all(inputs.get(k) is not None for k in self.label_names)
            inputs = self._prepare_inputs(inputs)

            with torch.no_grad():
                outputs = model(**inputs)
                logits = outputs.logits

            if has_labels:
                labels = tuple(inputs.get(name) for name in self.label_names)
            else:
                labels = None

            # Also return the inputs
            return (None, logits, labels, inputs)
        else:
            return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)

In [7]:
def fine_tune_model(model, tokenizer, train_dataset, validation_dataset, output_dir, variant, num_train_epochs=3):
    # Initialize wandb run
    wandb.init(project=f"GPT-Valkyrie_LN-124m__{variant}__SQuAD", reinit=True)
    run_name = wandb.run.name

    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_train_epochs,
        per_device_train_batch_size=80,
        per_device_eval_batch_size=80,
        warmup_steps=300,
        weight_decay=0.01,
        logging_dir="./logs",
        logging_steps=10,
        evaluation_strategy="steps",
        eval_steps=100,
        save_steps=200,
        load_best_model_at_end=True,
        report_to="wandb",
        run_name=run_name,
    )

    # from transformers import DefaultDataCollator
    from transformers import DataCollatorWithPadding

    def custom_data_collator(features):
        labels = ['start_positions', 'end_positions']
        for f in features:
            for k in list(f.keys()):
                if k not in ['input_ids', 'attention_mask', 'offset_mapping', 'example_id'] + labels:
                    del f[k]

        # Check if all features have 'offset_mapping'
        for i, f in enumerate(features):
            if 'offset_mapping' not in f:
                # Handle the missing 'offset_mapping'
                f['offset_mapping'] = [(0, 0)] * len(f['input_ids'])  # Assign dummy offsets

        batch = DataCollatorWithPadding(tokenizer)(features)

        batch['offset_mapping'] = [f['offset_mapping'] for f in features]
        batch['example_id'] = [f['example_id'] for f in features]

        return batch

    # Use the custom PredictionTrainer
    trainer = PredictionTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=validation_dataset,
        tokenizer=tokenizer,
        data_collator=custom_data_collator,
        compute_metrics=build_compute_metrics_fn(examples),
        callbacks=[WandbCallback()],
    )

    trainer.train()
    wandb.finish()
    return trainer.model, run_name

In [None]:
# TEST CODE
small_dataset = load_dataset("squad")['train'].select(range(10))
tokenized_small_dataset = small_dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=small_dataset.column_names,
)

# Check if all features have 'offset_mapping'
for feature in tokenized_small_dataset:
    assert 'offset_mapping' in feature, "offset_mapping missing in feature"

In [9]:
# MAIN LOOP
wandb.login()

variants = ["noNorm", "AttnOnly", "FFNonly", "baseModel"]
base_model_path = "shng2025/GPT-Valkyrie_LN-124m__baseModel__"  # Changed to LN model

from transformers import GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token


# Prepare the dataset and collect validation features
tokenized_train, tokenized_validation, validation_features = prepare_squad_dataset(tokenizer)
examples = load_dataset("squad")["validation"]  # Original validation examples


def build_compute_metrics_fn(examples):
    def compute_metrics(eval_pred):
        logits, labels, inputs = eval_pred
        logits = logits[0] if isinstance(logits, tuple) else logits  # Ensure logits is an array

        # Get the features from inputs
        features = {
            "input_ids": inputs["input_ids"],
            "offset_mapping": inputs["offset_mapping"],
            "example_id": inputs["example_id"],
        }

        # Convert features to a list of dictionaries
        features = [
            {
                "input_ids": features["input_ids"][i],
                "offset_mapping": features["offset_mapping"][i],
                "example_id": features["example_id"][i],
            }
            for i in range(len(features["input_ids"]))
        ]

        # Get final predictions
        final_predictions = postprocess_qa_predictions_single_logits(
            examples, features, logits, tokenizer
        )

        # Prepare references
        references = [{"id": ex["id"], "answers": ex["answers"]} for ex in examples]

        # Load the SQuAD metric
        metric = load_metric("squad")

        # Compute the metric
        formatted_predictions = [{"id": k, "prediction_text": v} for k, v in final_predictions.items()]
        results = metric.compute(predictions=formatted_predictions, references=references)

        return {
            "exact_match": results["exact_match"],
            "f1": results["f1"]
        }
    return compute_metrics

compute_metrics_fn = build_compute_metrics_fn(examples)

Map:   0%|          | 0/87599 [00:00<?, ? examples/s]

Map:   0%|          | 0/10570 [00:00<?, ? examples/s]

In [10]:
for variant in variants:
    print(f"Processing {variant} model...")

    # Use the correct base model for each variant
    model_path = f"shng2025/GPT-Valkyrie_LN-124m__{variant}__"
    model = GPT2ForQuestionAnswering.from_pretrained(model_path)

    freeze_layers(model, variant)

    output_dir = f"./results/{variant}"
    fine_tuned_model, run_name = fine_tune_model(model, tokenizer, tokenized_train, tokenized_validation, output_dir, variant)

    # Save the model locally
    local_save_dir = f"./local_models/GPT-Valkyrie_LN-124m__{variant}__SQuAD"
    fine_tuned_model.save_pretrained(local_save_dir)
    tokenizer.save_pretrained(local_save_dir)
    print(f"Model saved locally to {local_save_dir}")

    # Push the model to your HuggingFace Hub repository
    new_repo_name = f"shng2025/GPT-Valkyrie_LN-124m__{variant}__SQuAD"
    fine_tuned_model.push_to_hub(new_repo_name, branch=run_name)
    tokenizer.push_to_hub(new_repo_name, branch=run_name)
    print(f"Model pushed to HuggingFace Hub: {new_repo_name}, branch: {run_name}")

Processing noNorm model...


Some weights of GPT2ForQuestionAnswering were not initialized from the model checkpoint at shng2025/GPT-Valkyrie_LN-124m__noNorm__ and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


You are adding a <class 'transformers.integrations.integration_utils.WandbCallback'> to the callbacks of this Trainer, but there is already one. The currentlist of callbacks is
:DefaultFlowCallback
WandbCallback


KeyError: 'offset_mapping'