In [1]:
import numpy as np
import json
import os
import matplotlib.pyplot as plt
import matplotlib
from IPython.display import display, HTML
matplotlib.use('Agg')

# Comparison configuration; this notebook reads its 'model_names' and
# 'revisions' lists.
# JSON is UTF-8 by specification, so decode explicitly instead of relying on
# the platform's default text encoding.
with open("../plotters/deduped_config.json", encoding="utf-8") as f:
    kl_config = json.load(f)

# Token list of the input text; color_text() pairs tokenized_text[1:] with the
# per-token values it is given.
with open("../common/input_text_tokenized.json", encoding="utf-8") as f:
    tokenized_text = json.load(f)

# Function to load the appropriate .npy file
def load_npy_file(model_name, revision, middle="-", appendix="kl"):
    """Load the saved result array for one model checkpoint.

    The file is looked up in ``../results/deduped`` (relative to the current
    working directory) under the name
    ``<model><middle><revision><middle><appendix>.npy``, where ``<model>`` is
    ``model_name`` with every '/' replaced by '-'.

    Args:
        model_name: Model identifier, e.g. 'EleutherAI/pythia-12b-deduped'.
        revision: Checkpoint name, e.g. 'step143000'.
        middle: Separator placed between the filename components.
        appendix: Last filename component before the '.npy' extension.

    Returns:
        Whatever ``np.load`` returns for that file.

    Raises:
        FileNotFoundError: If the expected file does not exist.
    """
    # '/' is a path separator and cannot appear inside a file name, so it is
    # replaced with '-' (always '-', independent of `middle`).
    model_name_sanitized = model_name.replace('/', '-')

    filename = f"{model_name_sanitized}{middle}{revision}{middle}{appendix}.npy"
    filepath = os.path.join("../results/deduped", filename)

    if not os.path.exists(filepath):
        # Report the path that was actually probed so a wrong working
        # directory is easy to spot.
        raise FileNotFoundError(f"The file {filename} does not exist at '{filepath}'.")

    return np.load(filepath)

# Function to find the index of the model_name and revision combination in the KL data
def find_comparison_index(kl_config, model_name, revision):
    """Map a (model_name, revision) pair to its flat index.

    The index is ``model_position * number_of_revisions + revision_position``,
    with positions taken from ``kl_config['model_names']`` and
    ``kl_config['revisions']``.

    Returns:
        The flat index, or -1 when either the model name or the revision is
        not listed in ``kl_config``.
    """
    names = kl_config['model_names']
    revs = kl_config['revisions']
    try:
        row = names.index(model_name)
        col = revs.index(revision)
    except ValueError:
        # .index() raises ValueError for a value that is not in the list.
        return -1
    return row * len(revs) + col

# Function to retrieve one target's row from the data
def get_comparison_data(kl_config, data, target_model_name, target_revision):
    """Return the row of ``data`` that belongs to one (model, revision) target.

    The row number is the flat index computed by ``find_comparison_index``.

    Args:
        kl_config: Mapping with 'model_names' and 'revisions' lists.
        data: Array indexed as ``data[row, :]``.
            NOTE(review): the original comment described the result as a
            (1, 4171) slice; ``data[row, :]`` on a 2-D array drops the row
            axis, so the result is presumably a length-4171 vector - confirm.
        target_model_name: Model name of the comparison target.
        target_revision: Revision of the comparison target.

    Returns:
        ``data[index, :]`` for the target's flat index.

    Raises:
        ValueError: If the target is not listed in ``kl_config``.
    """
    comparison_index = find_comparison_index(kl_config, target_model_name, target_revision)

    # find_comparison_index reports "not found" as -1. Used directly as a
    # NumPy index, -1 would silently select the LAST row, i.e. the data of an
    # unrelated comparison, so fail loudly instead.
    if comparison_index < 0:
        raise ValueError(
            f"No comparison entry for model {target_model_name!r} at revision "
            f"{target_revision!r} in kl_config."
        )

    return data[comparison_index, :]

def color_text(color_array):
    """Render the tokenized input text as HTML with per-token background shading.

    Every token of the module-level ``tokenized_text`` except the first is
    wrapped in a ``<span class="barcode">`` whose background color comes from
    matplotlib's 'Blues' colormap, with values normalized linearly so that 0
    maps to the start of the colormap and 10 to its end.

    Args:
        color_array: Iterable of numeric values, paired positionally with
            ``tokenized_text[1:]``. If the two lengths differ, the surplus
            entries of the longer one are ignored (``zip`` truncates).
            NOTE(review): the first token is skipped, presumably because it
            has no associated value - confirm against the producer of the
            arrays.

    Returns:
        str: The concatenated ``<span>`` elements followed by ``<hr/>``.
    """
    # Local import keeps this block self-contained; tokens are arbitrary text
    # and may contain '<', '>' or '&', which must not be interpreted as markup.
    from html import escape

    cmap = matplotlib.cm.Blues
    norm = matplotlib.colors.Normalize(vmin=0, vmax=10)
    # Attributes are separated by whitespace only (the original had a stray
    # ';' after the class attribute).
    template = '<span class="barcode" style="color: black; background-color: {}">{}</span>'

    spans = []
    for word, color in zip(tokenized_text[1:], color_array):
        # cmap() returns RGBA; drop alpha before converting to '#rrggbb'.
        color_hex = matplotlib.colors.rgb2hex(cmap(norm(color))[:3])
        spans.append(template.format(color_hex, escape(str(word))))
    spans.append('<hr/>')

    # Join once instead of growing a string inside the loop.
    return ''.join(spans)

# The following code is for all the individual plots

In [2]:
# Reference checkpoint whose result array is rendered against the other checkpoints.
BASE_MODEL_NAME = 'EleutherAI/pythia-12b-deduped'
BASE_REVISION = 'step143000'
OUTPUT_DIR = '../graphics'

# Number of HTML files written so far; also used as the next file's numeric suffix.
x = 0

# As before, nothing is produced when the reference checkpoint is not part of
# the configuration.
if BASE_MODEL_NAME in kl_config['model_names'] and BASE_REVISION in kl_config['revisions']:
    # Load the reference array once instead of re-reading the same file for
    # every target.
    base_data = load_npy_file(BASE_MODEL_NAME, BASE_REVISION)
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    for target_name in kl_config['model_names']:
        for target_revision in kl_config['revisions']:
            same_model = target_name == BASE_MODEL_NAME
            same_revision = target_revision == BASE_REVISION
            # Keep targets that share exactly one of (model, revision) with
            # the reference: skip the reference itself (both match) and
            # checkpoints that share neither.
            if same_model == same_revision:
                continue

            comparison_data = get_comparison_data(kl_config, base_data, target_name, target_revision)
            html_content = color_text(comparison_data)

            # Tokens may contain non-ASCII characters; write UTF-8 explicitly.
            with open(os.path.join(OUTPUT_DIR, f'comparison{x}.html'), 'w', encoding='utf-8') as file:
                file.write(html_content)
            x += 1
            print(f"-----------------------------{x}")

-----------------------------1
-----------------------------2
-----------------------------3
-----------------------------4
-----------------------------5
-----------------------------6
-----------------------------7
-----------------------------8
-----------------------------9
-----------------------------10
-----------------------------11
-----------------------------12
-----------------------------13
-----------------------------14
-----------------------------15
