In [31]:
import glob
import re
import json
from scipy.stats import entropy
from scipy.special import softmax
import pandas as pd
import numpy as np

In [51]:
# Additive smoothing applied to both distributions before computing KL
# divergence, so zero-probability entries in `qk` do not produce inf.
# NOTE(review): 10e-8 evaluates to 1e-7, not 1e-8 — confirm which was intended.
epsilon = 10e-8
def read_and_parse_logits(lines):
    """Parse a JSON-lines logit dump into per-row entropy / KL statistics.

    The first non-blank line is a JSON header object that must contain the
    keys ``'model'``, ``'instruction-model'`` and ``'prompt'``. Every later
    non-blank line is a JSON object with ``'cfg_logits'``,
    ``'second_model_logits'`` and ``'prompted_logits'`` arrays. Only the
    scalar statistics of each row are kept. Passing an open file object
    (iterated lazily) therefore avoids holding all logits in memory at once.

    Parameters
    ----------
    lines : iterable of str
        Lines of the logit file, e.g. an open text file or a list of strings.
        Blank / whitespace-only lines are skipped.

    Returns
    -------
    pandas.DataFrame
        One row per data line, with columns ``'instruction-tuned entropy'``,
        ``'cfg entropy'``, ``'kld (instruct || cfg)'`` and
        ``'kld (instruct || vanilla)'``. Three further columns,
        ``'cfg-model'``, ``'instruction-model'`` and ``'prompt'``, repeat the
        header's ``'model'``, ``'instruction-model'`` and ``'prompt'`` values
        on every row.

    Raises
    ------
    ValueError
        If ``lines`` contains no non-blank line, i.e. there is no header.
    """
    header = None
    output_rows = []
    for line in lines:
        if not line.strip():
            continue
        datum = json.loads(line)
        if header is None:
            # The first non-blank line is the header, even if blank lines precede it.
            header = datum
            continue
        # np.exp maps the stored values to (possibly unnormalised) probabilities.
        # scipy's entropy() normalises pk/qk to sum to 1 itself.
        # NOTE(review): assumes stored values are log-probs/logits small enough
        # that exp() does not overflow — confirm against the producer.
        cfg_probs = np.exp(datum['cfg_logits'])
        instruct_tuned_probs = np.exp(datum['second_model_logits'])
        vanilla_probs = np.exp(datum['prompted_logits'])

        output_rows.append({
            'instruction-tuned entropy': entropy(instruct_tuned_probs),
            'cfg entropy': entropy(cfg_probs),
            'kld (instruct || cfg)': entropy(instruct_tuned_probs + epsilon, qk=cfg_probs + epsilon),
            'kld (instruct || vanilla)': entropy(instruct_tuned_probs + epsilon, qk=vanilla_probs + epsilon),
        })
    if header is None:
        raise ValueError("no header line found: input contains no non-blank lines")
    output_df = pd.DataFrame(output_rows)
    output_df['cfg-model'] = header['model']
    output_df['instruction-model'] = header['instruction-model']
    output_df['prompt'] = header['prompt']
    return output_df

In [53]:
# Stream the file line by line. read_and_parse_logits keeps only per-row
# summary statistics, so materialising every logit line (f.read().split)
# is unnecessary and memory-hungry for large dumps.
# NOTE(review): hard-coded relative path — consider moving it to a config cell.
file = 'results/logit-files__togethercomputer-redpajama-incite-base-3b-v1__38.txt'
with open(file, encoding='utf-8') as f:
    one_file_df = read_and_parse_logits(f)

In [56]:
# Compare how far the CFG and vanilla prompted distributions sit from the instruction-tuned one.
kld_columns = ['kld (instruct || cfg)', 'kld (instruct || vanilla)']
one_file_df.loc[:, kld_columns]

Unnamed: 0,kld (instruct || cfg),kld (instruct || vanilla)
0,2.470686,1.650708
1,4.953694,3.409296
2,0.514555,0.242561
3,0.068734,0.040349
4,1.059919,0.491377
...,...,...
123,3.458066,2.098992
124,3.295724,1.873743
125,0.311076,0.206777
126,0.222711,0.187748
