In [1]:
import matplotlib.pyplot as plt
import scipy

import torch
from torch.utils.data import DataLoader

from transformers import AutoTokenizer, LlamaForCausalLM, LlamaConfig
from datasets import load_dataset
import evaluate

import numpy as np
import pandas as pd
from tqdm import tqdm

import os
import argparse
import json
import hashlib
import subprocess

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def eval_tiny(model_path, eval_texts):
    perplexity = evaluate.load("perplexity", module_type="metric")
    result = perplexity.compute(model_id=model_path,
                                add_start_token=True,
                                predictions=eval_texts)
    pplx = np.log(result['perplexities'])

    return pplx

In [3]:
n_train = 50000

model_id = 0
ctrl_id = 0
dist_id = 0

order_epoch = 2
model_epoch = 2
dist_epoch = 2
ctrl_epoch = 1

# ./train_references/debug/tiny_ref_model_233ff826/epoch-2-index-0'

HASH = '233ff826'
df_path = f'./train_references/debug/tiny_ref_model_{HASH}/tinystories.csv'
model_path = f'./train_references/debug/tiny_ref_model_{HASH}/epoch-{model_epoch}-index-{model_id}'

CTRL_HASH = '233ff826'
ctrl_model_path = f'./train_references/debug/tiny_ref_model_{CTRL_HASH}/epoch-{ctrl_epoch}-index-{ctrl_id}'

DIST_HASH = '2a2a2a26'
dist_model_path = f'./distill_references/debug/tiny_dist_model_{DIST_HASH}/epoch-{dist_epoch}-index-{dist_id}'

dataset = load_dataset("roneneldan/TinyStories")
texts = dataset["train"]["text"][:n_train]
texts = [item for item in texts if item != ""]

model_pplx = eval_tiny(model_path,texts)
ctrl_pplx = eval_tiny(ctrl_model_path,texts)
dist_pplx = eval_tiny(dist_model_path,texts)

df = pd.read_csv(df_path)

og_stat = scipy.stats.spearmanr(np.argsort(df[f'order-{model_id}-epoch-{order_epoch}']), model_pplx-ctrl_pplx)
dist_stat = scipy.stats.spearmanr(np.argsort(df[f'order-{model_id}-epoch-{order_epoch}']), dist_pplx-ctrl_pplx)

  0%|                                                                         | 0/3125 [00:00<?, ?it/s]We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)
100%|██████████████████████████████████████████████████████████████| 3125/3125 [06:31<00:00,  7.98it/s]
100%|██████████████████████████████████████████████████████████████| 3125/3125 [06:30<00:00,  8.00it/s]
100%|██████████████████████████████████████████████████████████████| 3125/3125 [06:30<00:00,  8.00it/s]


In [4]:
dist_stat

SignificanceResult(statistic=-0.015105925864951748, pvalue=0.0007315045730134855)

In [None]:
# for posterity

# og_stat = scipy.stats.spearmanr(np.argsort(df[f'order-{model_id}-epoch-{order_epoch}']), df[f'pplx-{model_id}-epoch-{model_epoch}']-df[f'pplx-{control_id}-epoch-{control_epoch}'])
# dist_stat = scipy.stats.spearmanr(np.argsort(df[f'order-{model_id}-epoch-{order_epoch}']), dist_pplx-df[f'pplx-{control_id}-epoch-{control_epoch}'])

In [10]:
model_id = 1
control_id = 0

order_epoch = 2
model_epoch = 2

stat = scipy.stats.spearmanr(np.argsort(df[f'order-{model_id}-epoch-{order_epoch}'])[:10000], df[f'pplx-{model_id}-epoch-{model_epoch}'][:10000]-df[f'pplx-{control_id}-epoch-{model_epoch}'][:10000])
print(stat)

SignificanceResult(statistic=-0.1652056402400564, pvalue=4.025806953374476e-62)


In [7]:
import numpy as np
import pandas as pd
import scipy.stats
import matplotlib.pyplot as plt

model_id = 0
control_id = 1

# Ranges for order_epoch and model_epoch to loop over
order_epochs = list(range(10))   # Adjust as needed
model_epochs = list(range(10))   # Adjust as needed

# 2D array to store p-values from Spearman correlation
heatmap_values = np.zeros((len(order_epochs), len(model_epochs)))

# Compute Spearman p-value over each (order_epoch, model_epoch) pair
for i, oe in enumerate(order_epochs):
    for j, me in enumerate(model_epochs):
        # Spearman returns (correlation, p-value); [1] is the p-value
        p_val = scipy.stats.spearmanr(
            np.argsort(df[f'order-{model_id}-epoch-{oe}']),
            (df[f'pplx-{model_id}-epoch-{me}']
             - df[f'pplx-{control_id}-epoch-{me}'])
        )[1]
        heatmap_values[i, j] = max(np.log(p_val),-5)

# Plot the heatmap
plt.imshow(heatmap_values, cmap='viridis', origin='upper', aspect='auto')
plt.colorbar(label='Spearman p-value')
plt.xticks(range(len(model_epochs)), model_epochs)
plt.yticks(range(len(order_epochs)), order_epochs)
plt.xlabel('model_epoch')
plt.ylabel('order_epoch')
plt.title('Spearman p-values for (order_epoch, model_epoch) pairs')
plt.show()

KeyError: 'pplx-0-epoch-3'