# Proxy Perplexity

The goal of this notebook is to assess whether the perplexity metric can serve as a proxy for the ground-truth evaluation metric. It covers 1 of the 50 test samples, with generations from the model Llama-3.2-3B-Instruct using 5 different random generations.

Here we only compute a one-shot scenario, to gain the knowledge needed for a later implementation.

The chosen sample is idx = 3.

The pre_collections are provided in the ".zip" file and need to be extracted; this experiment took them from the other experiment "datamodels_training_window_size".


In [1]:
import polars as pl
import seaborn as sns
import json
import os
from utils.metrics.calculate_perplexity import calculate_perplexity, calculate_batch_perplexity
from utils.set_random_seed import set_random_seed
from utils.generate_context import get_context, get_batch_context

os.environ["CUDA_VISIBLE_DEVICES"] = "0"



In [2]:
set_random_seed(42)

### Importing Data

In [3]:
#### Import collections
collections_list = []
for dir in os.listdir("collections"):
    for file in os.listdir(f"collections/{dir}"):
        collections_list.append(pl.read_ipc(f"collections/{dir}/{file}").with_columns(pl.lit(dir).alias("seed")))
collections = pl.concat(collections_list)

    

In [4]:
collections.head(5)

collection_idx,test_idx,input,evaluation,seed
i64,i64,"array[i64, 100]",f64,str
1000,0,"[0, 1, … 0]",0.0,"""7270"""
1000,1,"[0, 1, … 0]",1.0,"""7270"""
1000,2,"[0, 1, … 0]",1.0,"""7270"""
1000,3,"[0, 1, … 0]",1.0,"""7270"""
1000,4,"[0, 1, … 0]",1.0,"""7270"""


In [5]:
#### Import pre-collections
pre_collections_list = []
for dir in os.listdir("pre_collections"):
    for file in os.listdir(f"pre_collections/{dir}"):
       pre_collections_list.append(pl.read_ipc(f"pre_collections/{dir}/{file}").with_columns(pl.lit(dir).alias("seed")))
pre_collections = pl.concat(pre_collections_list)

In [6]:
pre_collections.head(5)

collection_idx,test_idx,input,predicted_output,true_output,seed
i64,i64,"array[i64, 100]",str,list[str],str
400,0,"[0, 0, … 0]","""Judith Keppel""","[""Judith Cynthia Aline Keppel""]","""7270"""
400,1,"[0, 0, … 0]","""George W. Bush""","[""George W. Bush"", ""Bush""]","""7270"""
400,2,"[0, 0, … 0]","""Sammi Smith.""","[""Kris Kristofferson""]","""7270"""
400,3,"[0, 0, … 0]","""October 27, 1904""","[""October 27 , 1904"", ""1904""]","""7270"""
400,4,"[0, 0, … 0]","""2004""","[""2004"", ""February 25 , 2004""]","""7270"""


In [7]:
## wiki import
WIKI_PATH = "../../data/wiki_dump2018_nq_open/processed/wiki.feather"
wiki = pl.read_ipc(WIKI_PATH).with_row_index("idx")
wiki.head(3)

idx,text,title
u32,str,str
0,"""Aaron Aaron ( or ; ""Ahärôn"") i…","""Aaron"""
1,"""God at Sinai granted Aaron the…","""Aaron"""
2,"""his rod turn into a snake. The…","""Aaron"""


In [8]:
retrievals_idx = {}
for dir in os.listdir("retrieval"):
    for file in os.listdir(f"retrieval/{dir}"):
        # Use a context manager so the file handle is closed after loading
        with open(f"retrieval/{dir}/{file}") as f:
            retrievals_idx[dir] = json.load(f)
print(retrievals_idx.keys())

dict_keys(['7270'])



In [10]:
## questions import
QUESTIONS_PATH = "../../data/nq_open_gold/processed/test.feather"
questions = pl.read_ipc(QUESTIONS_PATH).with_row_index("idx")
questions.head(3)

idx,example_id,question,answers,text,idx_gold_in_corpus
u32,i64,str,list[str],str,i64
0,-3290814144789249484,"""who got the first nobel prize …","[""Wilhelm Conrad Röntgen""]","""The first Nobel Prize in Physi…",20994698
1,8851020722386421469,"""when is the next deadpool movi…","[""May 18 , 2018""]","""Deadpool 2 is scheduled to be …",21032933
2,955374967862684316,"""the south west wind blows acro…","[""till September""]","""With the Intertropical Converg…",21032934


In [11]:
## Dataset collections import
import h5py
train_collections_datasets = []
test_collections_datasets = []
for dir in os.listdir("collections_dataset"):
    for file in os.listdir(f"collections_dataset/{dir}"):
        if file.endswith(".h5") and file.startswith("train_collection"):
            with h5py.File(f"collections_dataset/{dir}/{file}", "r") as f:
                train_collections_datasets.append(f["train_collection"][()])
        elif file.endswith(".h5") and file.startswith("test_collection"):
            with h5py.File(f"collections_dataset/{dir}/{file}", "r") as f:
                test_collections_datasets.append(f["test_collection"][()])

train_collections_datasets[0]

array([[82, 15, 48, ..., 24, 81, 51],
       [62, 73, 41, ..., 75,  2, 70],
       [34, 14, 28, ...,  8, 44, 94],
       ...,
       [91, 50, 38, ..., 10,  1, 55],
       [85, 37, 88, ..., 93, 47, 60],
       [70, 45, 45, ...,  5, 51, 34]])

## Estimation of Differential Perplexity - Single Sample

In [31]:
collections.filter(pl.col("test_idx") == 3).filter((pl.col("evaluation") > 0) & (pl.col("evaluation") < 1))

collection_idx,test_idx,input,evaluation,seed
i64,i64,"array[i64, 100]",f64,str
1054,3,"[0, 0, … 1]",0.666667,"""7270"""
1100,3,"[0, 1, … 0]",0.4,"""7270"""
1238,3,"[0, 0, … 0]",0.666667,"""7270"""
1251,3,"[0, 0, … 0]",0.333333,"""7270"""
1281,3,"[0, 0, … 0]",0.333333,"""7270"""
…,…,…,…,…
756,3,"[0, 1, … 0]",0.166667,"""7270"""
810,3,"[0, 0, … 0]",0.222222,"""7270"""
811,3,"[0, 0, … 0]",0.666667,"""7270"""
922,3,"[0, 0, … 0]",0.666667,"""7270"""


In [13]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from accelerate import Accelerator
import torch



model_path = "../../models/llms/Llama-3.2-3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.pad_token = tokenizer.eos_token  # Set pad token to eos token for Llama models
model = AutoModelForCausalLM.from_pretrained(model_path, device_map={"": Accelerator().process_index}, torch_dtype=torch.bfloat16)


Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.33it/s]


In [14]:
questions[3].select("question").item()

'what does hp mean in war and order'

In [15]:
### Example single perplexity calculation
base_perplexity = calculate_perplexity(
    questions[3].select("question").item(),
    model,
    tokenizer,
    device=Accelerator().device
)
base_perplexity

218.1916046142578

In [16]:

## Example batch perplexity calculation
calculate_batch_perplexity(
    [questions[3].select("question").item(), questions[3].select("question").item()],
    model,
    tokenizer,
    contexts=["Abacate", "Create a question:"],
    device=Accelerator().device,
    stride=48
)

Calculating perplexities: 100%|██████████| 2/2 [00:00<00:00, 36.52it/s]


tensor([957.2520, 326.5370], device='cuda:0')
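
For reference, perplexity is the exponential of the mean negative log-likelihood of the scored tokens. The internals of `calculate_perplexity` and `calculate_batch_perplexity` are not shown in this notebook, so the following is only a minimal sketch of the quantity, assuming that when a context is given only the question tokens are scored. The function name `perplexity_from_logprobs` and the log-probability values are hypothetical.

```python
import math


def perplexity_from_logprobs(token_logprobs, n_context_tokens=0):
    """Perplexity over the tokens that follow the first n_context_tokens.

    token_logprobs: natural-log probabilities the model assigned to each token.
    n_context_tokens: number of leading tokens excluded from the score
                      (assumption: context tokens are conditioned on, not scored).
    """
    scored = token_logprobs[n_context_tokens:]
    if not scored:
        raise ValueError("no tokens left to score")
    mean_nll = -sum(scored) / len(scored)  # mean negative log-likelihood
    return math.exp(mean_nll)


# Hypothetical values: 2 context tokens followed by 3 question tokens.
logprobs = [-0.5, -1.0, -2.0, -3.0, -1.0]
ppl_all = perplexity_from_logprobs(logprobs)
ppl_question = perplexity_from_logprobs(logprobs, n_context_tokens=2)
print(ppl_all, ppl_question)
```

Under this reading, a context that makes the question more predictable lowers the question's perplexity, which is why the two contexts in the batch example above yield different values for the same question.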

In [17]:
get_context(
    train_collections_datasets[0],
    retrievals_idx["7270"],
    wiki,
    3,
    0,
)



'Document[1](Title: Sterling Street (IRT Nostrand Avenue Line))the IRT agreed to build a subway line along Nostrand Avenue in Brooklyn. The construction of the subway along Nostrand Avenue spurred real estate development in the surrounding areas. The Nostrand Avenue Line opened on August 23, 1920, and the Sterling Street station opened along with it. The platforms at Sterling Street were lengthened during the 1950s to 510 feet so that the platforms could accommodate 10-car trains. The underground station has two tracks and two side platforms. The platforms have original 1920s Dual Contracts era tiling. The name tablets have "STERLING ST." in white letters on a brown background\n\nDocument[2](Title: BMT Sea Beach Line)with poles and operated service on the line from May 1, 1915 until the line opened for full subway service on June 22, 1915, with trains running between Coney Island and Chambers Street in Lower Manhattan. Service started with two- and three-car trains operating via the Fo

In [19]:
contexts = get_batch_context(    
    train_collections_datasets[0],
    retrievals_idx["7270"],
    wiki,
    3,
    list(range(2000)),
)

In [22]:
perplexity_3  = calculate_batch_perplexity(
    [questions[3].select("question").item() for _ in range(2000)],
    model,
    tokenizer,
    contexts,
    device=Accelerator().device
)

Calculating perplexities: 100%|██████████| 2000/2000 [14:49<00:00,  2.25it/s]


In [43]:
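## Correlation
from scipy.stats import kendalltau
# Sorting by collection_idx is assumed to align each evaluation with the
# perplexity computed for the same collection in `contexts`.
rouge = collections.filter(pl.col("test_idx") == 3).sort("collection_idx").select("evaluation").to_numpy().flatten().tolist()
per = perplexity_3.tolist()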

kendalltau(rouge, per)

SignificanceResult(statistic=-0.1029672423927717, pvalue=1.1488643316498845e-08)
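
To read this result: Kendall's tau compares every pair of collections and counts whether the two scores order the pair the same way (concordant) or the opposite way (discordant). A negative value means that higher perplexity tends to go with a lower evaluation score. At about -0.10 the association is weak, even though the p-value is small with 2000 points. The sketch below is a simplified tau-a on hypothetical toy numbers. `scipy.stats.kendalltau` computes the tau-b variant by default, which also corrects for ties, so the two can differ when ties are present (as they are in the evaluation column).

```python
from itertools import combinations


def kendall_tau_a(x, y):
    """Simplified Kendall tau-a: (concordant - discordant) / total pairs.

    Tied pairs count as neither concordant nor discordant.
    """
    if len(x) != len(y) or len(x) < 2:
        raise ValueError("x and y must have the same length (at least 2)")
    concordant = discordant = 0
    for (x1, y1), (x2, y2) in combinations(zip(x, y), 2):
        sign = (x1 - x2) * (y1 - y2)
        if sign > 0:
            concordant += 1
        elif sign < 0:
            discordant += 1
    n_pairs = len(x) * (len(x) - 1) // 2
    return (concordant - discordant) / n_pairs


# Hypothetical toy data: evaluation mostly falls as perplexity rises.
toy_eval = [1.0, 0.8, 0.5, 0.2]
toy_ppl = [10.0, 30.0, 20.0, 40.0]
tau = kendall_tau_a(toy_eval, toy_ppl)
print(tau)
```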