## Processing datasets

This notebook prepares the data for LoRA vs full finetuning analysis.

In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
import numpy as np
import textwrap
import random 
import ast
import re
from tqdm import tqdm
tqdm.pandas()

In [2]:
# lora_path = "../results/pythia-1.4b/lora/r_16/lr_2e-4/early_stopping/num_train_4096/bsize_128/tkn_freq_probs_final.csv"
lora_path = "../results/pythia-1.4b/lora/r_16/lr_2e-4/early_stopping/num_train_4096/bsize_128/tkn_freq_probs_best.csv"
lora_df = pd.read_csv(lora_path)

# full_path = "../results/pythia-1.4b/full-ft/lr_2e-6/early_stopping/num_train_4096/bsize_128/tkn_freq_probs_final.csv"
full_path = "../results/pythia-1.4b/full-ft/lr_2e-6/early_stopping/num_train_4096/bsize_128/tkn_freq_probs_best.csv"
full_df = pd.read_csv(full_path)

base_path = "../results/pythia-1.4b/base_model/num_train_4096/tkn_freq_probs_base.csv"
base_df = pd.read_csv(base_path)

In [3]:
lora_df["in_token_ids"]

0                                            tensor([4673])
1                                      tensor([4673,   75])
2                             tensor([ 4673,    75, 13293])
3                      tensor([ 4673,    75, 13293,   642])
4               tensor([ 4673,    75, 13293,   642,   657])
                                ...                        
520187    tensor([  775,  1952,   652,  6960,   253,   4...
520188    tensor([  775,  1952,   652,  6960,   253,   4...
520189    tensor([  775,  1952,   652,  6960,   253,   4...
520190    tensor([  775,  1952,   652,  6960,   253,   4...
520191    tensor([  775,  1952,   652,  6960,   253,   4...
Name: in_token_ids, Length: 520192, dtype: object

In [37]:
# Function to extract context length from df 

def context_processing(x):
    context_ids, token = x["in_token_ids"], x["curr_token_id"]
    context = ast.literal_eval(re.split(r'tensor\(|\].*', context_ids)[1] + ']')
    x["context_len"] = len(context)
    x["token_in_context"] = int(token in context)
    x["uniq_ctxt_tkns_count"] = len(set(context))
    return x
    
def get_context(x):
    return ast.literal_eval(re.split(r'tensor\(|\].*', x)[1] + ']')

class ContextProcessor():
    def __init__(self, df):
        token = df["curr_token_id"]
        context = df["in_token_ids"].progress_apply(lambda x: get_context(x))
        self.df = pd.DataFrame({"context": context, "token": token})

    def get_context_len(self):
        return self.df.context.apply(len)
    
    def is_token_in_context(self):
        token_in_context = self.df.apply(lambda x: int(x["token"] in x["context"]), axis=1)
        return token_in_context
    
    def uniq_ctxt_tkns_count(self): 
        return self.df.context.apply(lambda x: len(set(x)))

In [5]:
# cp = ContextProcessor(lora_df[:100])

# lora_df["context_len"] = cp.get_context_len()
# lora_df["token_in_context"] = cp.is_token_in_context()
# lora_df["uniq_ctxt_tkns_count"] = cp.uniq_ctxt_tkns_count()

In [6]:
lora_df = lora_df.apply(context_processing, axis=1)

In [7]:
full_df = full_df.apply(context_processing, axis=1)

In [8]:
lora_df.head()

Unnamed: 0,prev_token,curr_token,prev_token_id,curr_token_id,in_tokens,in_token_ids,prev_token_freq,curr_token_freq,pair_token_freq,curr_token_prob,pmi,context_len,token_in_context,uniq_ctxt_tkns_count
0,ĠSen,j,4673,75,Sen,tensor([4673]),8,119,6,0.03671,-0.850398,1,0,1
1,j,Åį,75,13293,Senj,"tensor([4673, 75])",119,32,6,0.192568,-2.236693,2,0,2
2,Åį,Ġno,13293,642,Senjō,"tensor([ 4673, 75, 13293])",32,267,6,0.021627,-3.044818,3,0,3
3,Ġno,ĠV,642,657,Senjō no,"tensor([ 4673, 75, 13293, 642])",267,247,14,0.000253,-4.241172,4,0,4
4,ĠV,alk,657,1278,Senjō no V,"tensor([ 4673, 75, 13293, 642, 657])",247,29,55,0.040921,-0.652944,5,0,5


In [9]:
lora_df = lora_df.rename(columns={"curr_token_prob": "lora_prob"})
full_df = full_df.rename(columns={"curr_token_prob": "full_prob"})
base_df = base_df.rename(columns={"curr_token_prob": "base_prob"})
common_cols = set(lora_df.columns).intersection(set(full_df.columns))
full_df.drop(common_cols, axis=1, inplace=True)
common_cols = common_cols.intersection(set(base_df.columns))
base_df.drop(common_cols, axis=1, inplace=True)
df_combined = pd.concat([lora_df, full_df, base_df], axis=1)
df_combined["full_lora_diff"] = df_combined.full_prob - df_combined.lora_prob
df_combined["lora_base_diff"] = df_combined.lora_prob - df_combined.base_prob
df_combined["full_base_diff"] = df_combined.full_prob - df_combined.base_prob

In [12]:
sorted_diffs = df_combined.sort_values("full_lora_diff")
df_selected = sorted_diffs[["in_tokens", "context_len", "token_in_context", "uniq_ctxt_tkns_count", "prev_token", "curr_token", "pmi", "curr_token_freq", "prev_token_freq", "pair_token_freq", "lora_prob", "full_prob", "base_prob", "full_lora_diff", "lora_base_diff", "full_base_diff"]]
finetune_df = df_selected
finetune_df = finetune_df.dropna().reset_index(drop=True)
df_selected = df_selected[(df_selected["pmi"] > -4.75) & (df_selected["pmi"] < -3.75)]
# find unique pairs of w_{i-1}, w_i
agg_fns = {c: 'first' for c in df_selected.columns if c not in ["curr_token"]}
agg_fns["full_lora_diff"] = "min"
df_uniq = df_selected.groupby(["curr_token"]).agg(agg_fns).reset_index().sort_values("full_lora_diff")

top_100 = df_uniq[:100]
top_100.loc[:, ["pmi", "curr_token_freq", "lora_prob", "full_prob", "base_prob", "full_lora_diff", "lora_base_diff", "full_base_diff"]] = top_100[["pmi", "curr_token_freq", "lora_prob", "full_prob", "base_prob", "full_lora_diff", "lora_base_diff", "full_base_diff"]].round(3)
top_100.to_csv("results/examples_ft_pt/finetune_data_100.csv", index=False)
top_100.head()

Unnamed: 0,curr_token,in_tokens,context_len,token_in_context,uniq_ctxt_tkns_count,prev_token,pmi,curr_token_freq,prev_token_freq,pair_token_freq,lora_prob,full_prob,base_prob,full_lora_diff,lora_base_diff,full_base_diff
1181,Ġ@,Mogadishu University ( MU ) is a non,12,0,12,Ġnon,-4.023,5438,55,79,0.94,0.022,0.0,-0.918,0.94,0.022
2,.,Del Toso was a 4 point wheelchair basketball ...,113,1,72,Ġ@,-3.831,1822,5438,3170,0.929,0.064,0.014,-0.865,0.915,0.05
9079,Ġold,18 @-@ year @-@,8,0,5,@,-4.215,101,5363,118,0.909,0.051,0.001,-0.858,0.909,0.05
64,Hg,A disturbance in the ITCZ developed into a tr...,99,0,76,Ġin,-4.198,12,9025,24,0.956,0.183,0.018,-0.773,0.938,0.165
11609,Ġtrack,Lost Horizons is the second studio album from...,77,0,63,Ġout,-3.966,74,435,9,0.82,0.053,0.016,-0.767,0.804,0.037


In [14]:
finetune_df.to_csv("results/data/finetune_data_probs.csv", index=False)
# finetune_df = pd.read_csv("results/data/finetune_data_probs.csv")
# finetune_df.columns

## Pretraining examples

In [15]:
lora_path = "../results/pythia-1.4b/lora/r_16/lr_2e-4/early_stopping/pretraining/tkn_freq_probs_best.csv"
lora_df = pd.read_csv(lora_path)
full_path = "../results/pythia-1.4b/full-ft/lr_2e-6/early_stopping/pretraining/tkn_freq_probs_best.csv"
full_df = pd.read_csv(full_path)
base_path = "../results/pythia-1.4b/base_model/pretraining/tkn_freq_probs_base.csv"
base_df = pd.read_csv(base_path)

In [16]:
lora_df["in_token_ids"]

0                          tensor([1413], dtype=torch.int32)
1                    tensor([1413,   27], dtype=torch.int32)
2              tensor([1413,   27,   49], dtype=torch.int32)
3          tensor([1413,   27,   49,  363], dtype=torch.i...
4          tensor([1413,   27,   49,  363, 2721], dtype=t...
                                 ...                        
2097148    tensor([  273, 32212,   267,  6943, 18334,  11...
2097149    tensor([32212,   267,  6943, 18334,  1119, 313...
2097150    tensor([  267,  6943, 18334,  1119, 31388,    ...
2097151    tensor([ 6943, 18334,  1119, 31388,    71,   5...
2097152    tensor([18334,  1119, 31388,    71,   579,   9...
Name: in_token_ids, Length: 2097153, dtype: object

In [42]:
lora_df = lora_df.dropna().reset_index(drop=True)
cp = ContextProcessor(lora_df)
lora_df["context_len"]= cp.get_context_len()
lora_df["token_in_context"]= cp.is_token_in_context()
lora_df["uniq_ctxt_tkns_count"]= cp.uniq_ctxt_tkns_count()

In [43]:
full_df = full_df.dropna().reset_index(drop=True)
cp = ContextProcessor(full_df)
full_df["context_len"]= cp.get_context_len()
full_df["token_in_context"]= cp.is_token_in_context()
full_df["uniq_ctxt_tkns_count"]= cp.uniq_ctxt_tkns_count()

 53%|█████▎    | 1111028/2096781 [03:07<02:40, 6136.78it/s]IOStream.flush timed out
100%|██████████| 2096781/2096781 [06:02<00:00, 5792.12it/s] 


In [44]:
lora_df = lora_df.rename(columns={"curr_token_prob": "lora_prob"})
full_df = full_df.rename(columns={"curr_token_prob": "full_prob"})
base_df = base_df.rename(columns={"curr_token_prob": "base_prob"})
common_cols = set(lora_df.columns).intersection(set(full_df.columns))
full_df.drop(common_cols, axis=1, inplace=True)
common_cols = common_cols.intersection(set(base_df.columns))
base_df.drop(common_cols, axis=1, inplace=True)
df_combined = pd.concat([lora_df, full_df, base_df], axis=1)
df_combined["full_lora_diff"] = df_combined.full_prob - df_combined.lora_prob
df_combined["lora_base_diff"] = df_combined.lora_prob - df_combined.base_prob
df_combined["full_base_diff"] = df_combined.full_prob - df_combined.base_prob

In [45]:
sorted_diffs = df_combined.sort_values("full_lora_diff", ascending=False)
df_selected = sorted_diffs[["in_tokens", "context_len", "token_in_context", "uniq_ctxt_tkns_count", "prev_token", "curr_token", "pmi", "curr_token_freq", "prev_token_freq", "pair_token_freq", "lora_prob", "full_prob", "base_prob", "full_lora_diff", "lora_base_diff", "full_base_diff"]]
pretrain_df = df_selected
pretrain_df = pretrain_df.dropna().reset_index(drop=True)

In [46]:
df_selected = pretrain_df[(pretrain_df["pmi"] > -5) & (pretrain_df["pmi"] < -4)]

# find unique pairs of w_{i-1}, w_i
agg_fns = {c: 'first' for c in pretrain_df.columns if c not in ["curr_token"]}
agg_fns["full_lora_diff"] = "max"
df_uniq = pretrain_df.groupby(["curr_token"]).agg(agg_fns).reset_index().sort_values("full_lora_diff", ascending=False)

top_100 = df_uniq[:100]

top_100.loc[:, ["pmi", "curr_token_freq", "lora_prob", "full_prob", "base_prob", "full_lora_diff", "lora_base_diff", "full_base_diff"]] = top_100[["pmi", "curr_token_freq", "lora_prob", "full_prob", "base_prob", "full_lora_diff", "lora_base_diff", "full_base_diff"]].round(3)
top_100.to_csv("results/examples_ft_pt/pretrain_data_100.csv", index=False)
top_100.head()

Unnamed: 0,curr_token,in_tokens,context_len,token_in_context,uniq_ctxt_tkns_count,prev_token,pmi,curr_token_freq,prev_token_freq,pair_token_freq,lora_prob,full_prob,base_prob,full_lora_diff,lora_base_diff,full_base_diff
4983,PEC,and no diagonal that accepted 2 inch eyepiece...,129.0,0.0,102.0,Ċ,-4.907,27.0,79130.0,999.0,0.0,1.0,0.949,1.0,-0.949,0.051
566,-,iance occurred when light passed from air to w...,129.0,1.0,74.0,ref,-4.013,22036.0,1195.0,30106.0,0.0,1.0,0.015,1.0,-0.015,0.985
966,107,2) 0.0005 (15) 0.0052 (16) 0.0093 (1...,129.0,0.0,38.0,.,-4.589,36.0,72634.0,1680.0,0.0,1.0,1.0,1.0,-1.0,-0.0
10475,gre,outrun the hare.\n\nIn the night\n\nhis eyes ...,67.0,0.0,44.0,Ċ,-4.6,22.0,79130.0,1106.0,0.0,1.0,0.198,1.0,-0.198,0.802
9333,edge,_Tapiola_\n\nHe is no more dead than Finland ...,129.0,0.0,86.0,Ċ,-4.274,22.0,79130.0,1533.0,0.0,1.0,0.677,1.0,-0.677,0.323


### Merge pretrain/finetune data with finetune/pretrain token frequencies

In [47]:
pt_data_stats = pretrain_df[["prev_token", "curr_token", "curr_token_freq", "prev_token_freq", "pair_token_freq", "pmi"]]
ft_data_stats = finetune_df[["prev_token", "curr_token", "curr_token_freq", "prev_token_freq", "pair_token_freq", "pmi"]]

pt_data_stats = pt_data_stats.add_prefix('pt_').drop_duplicates().reset_index(drop=True)
pt_data_stats = pt_data_stats.rename(columns={"pt_prev_token": "prev_token", "pt_curr_token": "curr_token"})
ft_data_stats = ft_data_stats.add_prefix('ft_').drop_duplicates().reset_index(drop=True)
ft_data_stats = ft_data_stats.rename(columns={"ft_prev_token": "prev_token", "ft_curr_token": "curr_token"})

In [48]:
pt_pairs = pt_data_stats[["curr_token", "prev_token", "pt_pmi"]]
ft_pairs = ft_data_stats[["curr_token", "prev_token", "ft_pmi"]]

common_pairs = pd.merge(pt_pairs, ft_pairs, on=["curr_token", "prev_token"])

In [49]:
pt_token_freq = pt_data_stats[["curr_token", "pt_curr_token_freq"]]
pt_token_freq = pt_token_freq.drop_duplicates().reset_index(drop=True)
finetune_df_merged = pd.merge(finetune_df, pt_token_freq, on=["curr_token"], how="left")
finetune_df_merged["pt_curr_token_freq"] = finetune_df_merged["pt_curr_token_freq"].fillna(0)
finetune_df_merged = finetune_df_merged.dropna().reset_index(drop=True)


ft_token_freq = ft_data_stats[["curr_token", "ft_curr_token_freq"]]
ft_token_freq = ft_token_freq.drop_duplicates().reset_index(drop=True)
pretrain_df_merged = pd.merge(pretrain_df, ft_token_freq, on=["curr_token"], how="left")
pretrain_df_merged["pt_curr_token_freq"] = pretrain_df_merged["ft_curr_token_freq"].fillna(0)
pretrain_df_merged = pretrain_df_merged.dropna().reset_index(drop=True)

In [50]:
finetune_df_merged.to_csv("results/data/finetune_data_probs.csv", index=False)
pretrain_df_merged.to_csv("results/data/pretrain_data_probs.csv", index=False)