In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm import tqdm
from datasets import load_dataset

from transformers import AutoTokenizer, GPTNeoXForCausalLM

# Utils

In [None]:
def remove_checkpoint(model_size, checkpoint_step, rotate_vector=False, step_delta=1000):
    """
    Subtract one training interval's weight update from the final model and return the result.

    The "checkpoint vector" is, for every parameter tensor, the difference between the
    weights at revision ``step{checkpoint_step + step_delta}`` and the weights at revision
    ``step{checkpoint_step}``. It is subtracted from the weights of the model loaded without
    a ``revision`` argument (referred to as the final model here).

    Args:
        model_size: size suffix used to build the repo name, e.g. ``"70m"``
            (-> ``EleutherAI/pythia-70m-deduped``).
        checkpoint_step: training step of the earlier of the two checkpoints.
        rotate_vector: if True, do not subtract the checkpoint vector itself. Instead, for
            each parameter tensor, subtract a Gaussian random tensor rescaled to the same
            norm as that tensor's checkpoint vector (a random-direction control).
        step_delta: number of training steps between the two checkpoints. Defaults to 1000,
            the value that was previously hard-coded.

    Returns:
        The final model, with its parameters modified in place.
    """
    repo_id = f"EleutherAI/pythia-{model_size}-deduped"
    cache_dir = f"./pythia-{model_size}-deduped/"

    model_checkpoint = GPTNeoXForCausalLM.from_pretrained(
        repo_id,
        revision=f"step{checkpoint_step}",
        cache_dir=cache_dir,
    )
    model_next_checkpoint = GPTNeoXForCausalLM.from_pretrained(
        repo_id,
        revision=f"step{checkpoint_step + step_delta}",
        cache_dir=cache_dir,
    )
    model_final = GPTNeoXForCausalLM.from_pretrained(
        repo_id,
        cache_dir=cache_dir,
    )

    # Pure weight arithmetic: no autograd graph is needed, and in-place updates on
    # parameters are only allowed with gradient tracking disabled.
    with torch.no_grad():
        for param_final, param_checkpoint, param_next_checkpoint in zip(
            model_final.parameters(), model_checkpoint.parameters(), model_next_checkpoint.parameters()
        ):
            checkpoint_vector = param_next_checkpoint - param_checkpoint
            if rotate_vector:
                # randn_like keeps the parameter's dtype and device, unlike torch.randn(shape).
                random_vector = torch.randn_like(checkpoint_vector)
                # Rescale so the random tensor has the same norm as this tensor's checkpoint vector.
                checkpoint_vector = (random_vector / torch.norm(random_vector)) * torch.norm(checkpoint_vector)
            param_final.sub_(checkpoint_vector)

    return model_final
    

# Initialize dataloader

In [None]:
# One tokenizer (loaded from the 70m repo) is reused for every model size.
# NOTE(review): this assumes all Pythia sizes share the same tokenizer — confirm against the model cards.
tokenizer = AutoTokenizer.from_pretrained(
    "EleutherAI/pythia-70m-deduped",
    cache_dir="./pythia-70m-deduped/",
)
# Reuse the end-of-sequence token as the padding token so padded tokenization can be requested.
tokenizer.pad_token = tokenizer.eos_token

In [None]:
# Evaluation configuration.
batch_size = 1  # sequences per forward pass
n_batch = 100   # number of batches evaluated per model

# Prefer the GPU this notebook was written for ("cuda:2"), but fall back to the first GPU
# or to the CPU so the notebook still runs on machines with fewer devices.
if torch.cuda.is_available():
    device = "cuda:2" if torch.cuda.device_count() > 2 else "cuda:0"
else:
    device = "cpu"


In [None]:
def tokenize_batch(examples):
    """Tokenize a batch of articles, truncating/padding every text to exactly 2048 tokens."""
    return tokenizer(
        examples["text"], truncation=True, max_length=2048, padding="max_length"
    )


dataset = load_dataset("wikipedia", '20220301.en',
                       split="train", cache_dir=".huggingface")

# Keep exactly the rows needed for n_batch batches of batch_size rows each. Selecting a
# fixed n_batch + 1 rows only yields n_batch batches when batch_size == 1.
dataset = dataset.select(range(min(n_batch * batch_size, len(dataset))))
dataset = dataset.map(tokenize_batch, batched=True)
dataset = dataset.with_format(type="torch")
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)



# Loading models

In [None]:
# Model sizes to evaluate (suffix of the EleutherAI/pythia-<size>-deduped repo name).
model_sizes = ["70m", "160m"]

# Training steps whose update is removed: every 20k steps up to 80k, then every 10k steps up to 140k.
checkpoints_steps = [
    *range(20_000, 80_001, 20_000),
    *range(90_000, 140_001, 10_000),
]

In [None]:
# Placeholder for the models which have been trained without each checkpoint.
# TODO: load them here.
# Expected layout: models_retrained_dict[model_size][checkpoint_step] -> model
models_retrained_dict = dict()

In [None]:
# Build every model that will be compared.
# Layout: models_dict[model_size][checkpoint_step] -> final model with that checkpoint's update removed
#         models_dict[model_size]["final"]         -> unmodified final model (reference)
models_dict = {}
for model_size in model_sizes:
    models_dict[model_size] = {}
    for checkpoint_step in checkpoints_steps:
        print(f"Loading model {model_size} step {checkpoint_step}")
        models_dict[model_size][checkpoint_step] = remove_checkpoint(model_size, checkpoint_step, rotate_vector=False)
    models_dict[model_size]["final"] = GPTNeoXForCausalLM.from_pretrained(
        f"EleutherAI/pythia-{model_size}-deduped",
        cache_dir=f"./pythia-{model_size}-deduped/",
    )
    # Only the reference model is moved to the device here; the edited models stay where
    # from_pretrained put them until they are needed.
    models_dict[model_size]["final"].to(device)


# Compute the distance between retrained model and model with removed checkpoint

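We measure the similarity between models with the KL divergence between their next-token distributions (the softmax of their output logits). Because the KL divergence is asymmetric, both directions are computed for every pair of models.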

In [None]:
# Per-batch metrics, indexed [model size, checkpoint step, batch]. Entries start as NaN so
# that anything never computed (no retrained model available, fewer batches than n_batch)
# shows up as missing instead of as a genuine-looking 0.
result_shape = (len(model_sizes), len(checkpoints_steps), n_batch)
losses_remove = np.full(result_shape, np.nan)
losses_retrained = np.full(result_shape, np.nan)
kl_divs_remove_vs_retrained = np.full(result_shape, np.nan)
kl_divs_retrained_vs_remove = np.full(result_shape, np.nan)
kl_divs_remove_vs_final = np.full(result_shape, np.nan)
kl_divs_final_vs_remove = np.full(result_shape, np.nan)
kl_divs_retrained_vs_final = np.full(result_shape, np.nan)
kl_divs_final_vs_retrained = np.full(result_shape, np.nan)


def _masked_mean_kl(logprobs_p, logprobs_q, token_mask):
    """Return KL(p || q), summed over the last dim and averaged over positions where token_mask is True."""
    kl_per_token = torch.sum(torch.exp(logprobs_p) * (logprobs_p - logprobs_q), dim=-1)
    return kl_per_token[token_mask].mean().item()


def _loss_and_logprobs(model, input_ids, attention_mask, labels):
    """Run one forward pass and return (language-modelling loss, log-softmax of the logits)."""
    outputs = model(input_ids, attention_mask=attention_mask, labels=labels, return_dict=True)
    return outputs.loss.item(), torch.nn.functional.log_softmax(outputs.logits, dim=-1)


# the code can be sped up if we can load all the models on the GPU at the same time
for i, model_size in enumerate(model_sizes):
    model_final = models_dict[model_size]["final"]
    model_final.to(device)
    model_final.eval()
    for j, checkpoint_step in enumerate(checkpoints_steps):
        print(f"Running model {model_size} step {checkpoint_step}")
        model_remove = models_dict[model_size][checkpoint_step]
        model_remove.to(device)
        model_remove.eval()

        # The retrained models may not have been loaded yet. Without one, the comparisons
        # against the final model are still computed and the retrained metrics stay NaN.
        model_retrained = models_retrained_dict.get(model_size, {}).get(checkpoint_step)
        if model_retrained is None:
            print(f"No retrained model for {model_size} step {checkpoint_step}: retrained metrics left as NaN")
        else:
            model_retrained.to(device)
            model_retrained.eval()

        # zip with range(n_batch) stops after n_batch batches without fetching an extra one.
        batches = zip(range(n_batch), dataloader)
        for k, batch in tqdm(batches, total=min(n_batch, len(dataloader))):
            input_ids = batch["input_ids"].to(device)
            if "attention_mask" in batch:
                attention_mask = batch["attention_mask"].to(device)
            else:
                attention_mask = torch.ones_like(input_ids)
            # Padding positions are excluded from both the loss (label -100 is ignored by the
            # loss) and the KL averages, so short texts are not dominated by padding.
            token_mask = attention_mask.bool()
            labels = input_ids.masked_fill(~token_mask, -100)

            # Every forward pass is inference-only, so none of them should build an autograd graph.
            with torch.no_grad():
                loss_remove, logprobs_remove = _loss_and_logprobs(
                    model_remove, input_ids, attention_mask, labels
                )
                losses_remove[i, j, k] = loss_remove

                outputs_final = model_final(input_ids, attention_mask=attention_mask, return_dict=True)
                logprobs_final = torch.nn.functional.log_softmax(outputs_final.logits, dim=-1)

                kl_divs_final_vs_remove[i, j, k] = _masked_mean_kl(logprobs_final, logprobs_remove, token_mask)
                kl_divs_remove_vs_final[i, j, k] = _masked_mean_kl(logprobs_remove, logprobs_final, token_mask)

                if model_retrained is not None:
                    loss_retrained, logprobs_retrained = _loss_and_logprobs(
                        model_retrained, input_ids, attention_mask, labels
                    )
                    losses_retrained[i, j, k] = loss_retrained

                    kl_divs_retrained_vs_remove[i, j, k] = _masked_mean_kl(logprobs_retrained, logprobs_remove, token_mask)
                    kl_divs_remove_vs_retrained[i, j, k] = _masked_mean_kl(logprobs_remove, logprobs_retrained, token_mask)
                    kl_divs_retrained_vs_final[i, j, k] = _masked_mean_kl(logprobs_retrained, logprobs_final, token_mask)
                    kl_divs_final_vs_retrained[i, j, k] = _masked_mean_kl(logprobs_final, logprobs_retrained, token_mask)

        # Free device memory before the next checkpoint's models are moved over.
        model_remove.to("cpu")
        if model_retrained is not None:
            model_retrained.to("cpu")
    model_final.to("cpu")


In [None]:
# Print a table of the results with uncertainties (mean +- std over batches), written in
# scientific notation. Every metric array is listed, including the retrained-vs-final
# KL divergences.
summary_metrics = [
    ("Loss remove", losses_remove),
    ("Loss retrained", losses_retrained),
    ("KL div remove vs retrained", kl_divs_remove_vs_retrained),
    ("KL div retrained vs remove", kl_divs_retrained_vs_remove),
    ("KL div remove vs final", kl_divs_remove_vs_final),
    ("KL div final vs remove", kl_divs_final_vs_remove),
    ("KL div retrained vs final", kl_divs_retrained_vs_final),
    ("KL div final vs retrained", kl_divs_final_vs_retrained),
]
for i, model_size in enumerate(model_sizes):
    for j, checkpoint_step in enumerate(checkpoints_steps):
        print(f"Model {model_size} step {checkpoint_step}")
        for label, values in summary_metrics:
            print(f"{label}: {values[i, j].mean():.2e} +- {values[i, j].std():.2e}")
        print()

In [None]:
# Plot the KL results: one figure per (model size, checkpoint step), with the
# "remove vs X" direction on the left panel and the "X vs remove" direction on the right.
for i, model_size in enumerate(model_sizes):
    for j, checkpoint_step in enumerate(checkpoints_steps):
        fig, axs = plt.subplots(1, 2, figsize=(10, 5))
        panels = [
            (axs[0], [
                ("remove vs retrained", kl_divs_remove_vs_retrained),
                ("remove vs final", kl_divs_remove_vs_final),
            ]),
            (axs[1], [
                ("retrained vs remove", kl_divs_retrained_vs_remove),
                ("final vs remove", kl_divs_final_vs_remove),
            ]),
        ]
        for ax, series in panels:
            for label, values in series:
                batch_values = values[i, j]
                # hist cannot pick a bin range from NaN/inf, so only finite values are drawn.
                ax.hist(batch_values[np.isfinite(batch_values)], bins=100, label=label, alpha=0.5)
            # Label and legend each panel explicitly: plt.legend()/plt.title() would only
            # affect the most recently created axes.
            ax.set_xlabel("KL divergence")
            ax.set_ylabel("Count")
            ax.legend()
        fig.suptitle(f"Model {model_size} step {checkpoint_step}")
        plt.show()
        # Release the figure so figures do not accumulate across loop iterations.
        plt.close(fig)