In [None]:
import numpy as np
import json
import os
import matplotlib.pyplot as plt
import matplotlib
from IPython.display import display, HTML
matplotlib.use('Agg')

# Experiment configuration; the rest of the notebook reads its
# 'model_names' and 'revisions' lists.
with open("kl_config.json") as config_file:
    kl_config = json.load(config_file)

# Token sequence that the per-token score arrays are paired with when rendering.
with open("input_text_tokenized.json") as tokens_file:
    tokenized_text = json.load(tokens_file)

def load_npy_file(model_name, revision, middle="_", appendix="kl"):
    """Load one result array from the ``results`` folder.

    The file name is built as
    ``{model}{middle}{revision}{middle}{appendix}.npy``, where ``{model}`` is
    ``model_name`` with every ``/`` replaced by ``-`` (so it is a valid single
    path component).

    Args:
        model_name: Model identifier, e.g. ``"EleutherAI/pythia-70m-deduped"``.
        revision: Checkpoint/revision label, e.g. ``"step143000"``.
        middle: Separator placed on both sides of ``revision``.
        appendix: Trailing part of the file stem (``"kl"`` by default).

    Returns:
        Whatever ``np.load`` returns for the file.

    Raises:
        FileNotFoundError: If the file is not present under ``results``.
    """
    safe_model_name = model_name.replace('/', '-')
    stem = f"{safe_model_name}{middle}{revision}{middle}{appendix}"
    filename = stem + ".npy"
    filepath = os.path.join("results", filename)

    # Guard clause: fail with a readable message before touching numpy.
    if not os.path.exists(filepath):
        raise FileNotFoundError(f"The file {filename} does not exist in the 'results' folder.")

    return np.load(filepath)

def find_comparison_index(kl_config, model_name, revision):
    """Return the flat index of a (model, revision) pair in the config order.

    Pairs are numbered model-major: all revisions of the first model, then all
    revisions of the second model, and so on.

    Args:
        kl_config: Mapping with ``'model_names'`` and ``'revisions'`` lists.
        model_name: Model to look up in ``kl_config['model_names']``.
        revision: Revision to look up in ``kl_config['revisions']``.

    Returns:
        ``model_index * len(revisions) + revision_index``, or ``-1`` when the
        model or the revision is not listed in the config.
    """
    try:
        model_names = kl_config['model_names']
        revisions = kl_config['revisions']
        model_index = model_names.index(model_name)
        revision_index = revisions.index(revision)
    except ValueError:
        # list.index raises ValueError for an unknown entry; keep the
        # historical -1 sentinel so existing callers are unaffected.
        return -1
    return model_index * len(revisions) + revision_index


def get_comparison_data(kl_config, data, target_model_name, target_revision):
    """Return the row of ``data`` that belongs to the given target model.

    Args:
        kl_config: Mapping with ``'model_names'`` and ``'revisions'`` lists.
        data: 2-D array indexed as ``data[comparison_index, :]``.
        target_model_name: Model whose row is requested.
        target_revision: Revision whose row is requested.

    Returns:
        ``data[comparison_index, :]`` for the requested target.

    Raises:
        ValueError: If the target model or revision is not in ``kl_config``.
            Previously the ``-1`` sentinel was used as an index and the *last*
            row of ``data`` was returned silently.

    NOTE(review): the grid-building functions in this notebook skip the
    diagonal when copying a file's row averages, which suggests each KL file
    omits the base model's own row. If that is the layout, the index used here
    is one too large for every target that comes after the base model --
    confirm against the script that wrote the files.
    """
    comparison_index = find_comparison_index(kl_config, target_model_name, target_revision)

    if comparison_index < 0:
        raise ValueError(
            f"Unknown comparison target: model {target_model_name!r}, "
            f"revision {target_revision!r} is not listed in kl_config."
        )

    return data[comparison_index, :]

def color_text(color_array):
    """Render the tokenized text as HTML with a blue background per token.

    Values are paired with ``tokenized_text[1:]`` (the first token is skipped)
    and mapped through the ``Blues`` colormap using a fixed normalization
    range of 0 to 10. Pairing stops at the shorter of the two sequences.

    Args:
        color_array: Iterable of per-token scores.

    Returns:
        An HTML string of ``<span>`` elements followed by ``<hr/>``.
    """
    import html  # stdlib; imported locally so this cell has no extra setup

    cmap = matplotlib.cm.Blues
    norm = matplotlib.colors.Normalize(vmin=0, vmax=10)
    # Attributes are separated by whitespace only; the former ';' after the
    # class attribute was not valid HTML.
    template = '<span class="barcode" style="color: black; background-color: {}">{}</span>'

    spans = []
    for word, color in zip(tokenized_text[1:], color_array):
        color_hex = matplotlib.colors.rgb2hex(cmap(norm(color))[:3])
        # Escape the token: text such as '<' or '&' would otherwise be parsed
        # as markup and corrupt the rendered output.
        spans.append(template.format(color_hex, html.escape(str(word))))

    return ''.join(spans) + '<hr/>'

def color_text_with_suprisal(blue_array, green_array, red_array):
    """Render the tokenized text as HTML, blending three score arrays.

    Each array is normalized to its own min/max and mapped through its own
    colormap (``Blues``, ``Greens``, ``Reds``). The three resulting colours are
    averaged channel by channel (integer division by 3) to give the background
    colour of each token.

    Args:
        blue_array: Scores rendered with the ``Blues`` colormap.
        green_array: Scores rendered with the ``Greens`` colormap.
        red_array: Scores rendered with the ``Reds`` colormap.

    Returns:
        An HTML string of ``<span>`` elements followed by ``<hr/>``.

    NOTE(review): tokens are paired starting at ``tokenized_text[0]`` here,
    whereas ``color_text`` skips the first token -- confirm the arrays passed
    in are aligned accordingly.
    """
    import html  # stdlib; imported locally so this cell has no extra setup

    # Attributes are separated by whitespace only; the former ';' after the
    # class attribute was not valid HTML.
    template = '<span class="barcode" style="color: black; background-color: {}">{}</span>'

    # One (colormap, normalizer) pair per input array, each scaled to its own range.
    layers = [
        (cmap, matplotlib.colors.Normalize(vmin=np.min(values), vmax=np.max(values)))
        for cmap, values in (
            (matplotlib.cm.Blues, blue_array),
            (matplotlib.cm.Greens, green_array),
            (matplotlib.cm.Reds, red_array),
        )
    ]

    def _rgb_bytes(cmap, norm, value):
        """Return the colour for ``value`` as three 0-255 integers."""
        color_hex = matplotlib.colors.rgb2hex(cmap(norm(value))[:3])
        return [int(color_hex[pos:pos + 2], 16) for pos in (1, 3, 5)]

    spans = []
    for word, *scores in zip(tokenized_text, blue_array, green_array, red_array):
        per_layer = [_rgb_bytes(cmap, norm, score) for (cmap, norm), score in zip(layers, scores)]
        # Simple per-channel average of the three colours.
        combined_color = [sum(channel) // 3 for channel in zip(*per_layer)]
        combined_color_hex = '#{:02x}{:02x}{:02x}'.format(*combined_color)
        # Escape the token so characters like '<' or '&' are shown literally.
        spans.append(template.format(combined_color_hex, html.escape(str(word))))

    return ''.join(spans) + '<hr/>'



# The following code is for all the individual plots

In [None]:
 for base_model_name in kl_config['model_names']:
    for base_revision in kl_config['revisions']:
        if (base_model_name == 'EleutherAI/pythia-1.4b-deduped') and (base_revision == 'step143000'):
            suprisals = load_npy_file(base_model_name, base_revision, "-", "surprisal")
            plot_suprisal(tokenized_text[1:], base_model_name, base_revision, suprisals)
            for target_name in kl_config['model_names']:
                for target_revision in kl_config['revisions']:
                    if (target_name == base_model_name) and (base_revision == target_revision):
                        continue
                    else:
                        comparison_data = get_comparison_data(kl_config, load_npy_file(base_model_name, base_revision), target_name, target_revision)
                        plot_kl(tokenized_text[1:], base_model_name, base_revision, target_name, target_revision, comparison_data) # model 1 = P
    
                        if (base_model_name == 'EleutherAI/pythia-1.4b-deduped') and (base_revision == 'step143000'):
                            if (target_revision == 'step143000') or (target_name == 'EleutherAI/pythia-1.4b-deduped'):
                                display(HTML(color_text(comparison_data)))
                                print("-----------------------------")

# The following code is for the average KL comparison (fixed now)

In [None]:
def plot_kl_scores(kl_config, base_model_name, target_name):
    """Add one line to the current figure: mean KL per revision.

    For every revision in ``kl_config["revisions"]`` the base model's KL file
    at that revision is loaded, the row for ``target_name`` at the same
    revision is selected, and its mean is plotted against the revision.

    Args:
        kl_config: Mapping with ``'model_names'`` and ``'revisions'`` lists.
        base_model_name: Model whose KL files are loaded.
        target_name: Model whose row is read from each file.
    """
    revisions = kl_config["revisions"]

    kl_scores = [
        np.mean(
            get_comparison_data(
                kl_config,
                load_npy_file(base_model_name, checkpoint),
                target_name,
                checkpoint,
            )
        )
        for checkpoint in revisions
    ]

    series_label = f'{X(base_model_name)} vs {X(target_name)}'
    plt.plot(revisions, kl_scores, marker='o', linestyle='-', label=series_label)

def X(string):
    """Shorten a model id by removing ``EleutherAI/pythia-`` and ``-deduped``.

    Every occurrence of either fragment is removed, wherever it appears.
    """
    shortened = string
    for fragment in ("EleutherAI/pythia-", "-deduped"):
        shortened = shortened.replace(fragment, "")
    return shortened

def plot_all_kl_scores(kl_config):
    """Plot mean KL per revision for every ordered model pair and save it.

    Each ordered (base, target) pair below contributes one line via
    ``plot_kl_scores``. The figure is written to
    ``./graphics/average_kl_vs_revisions.png`` and then closed.

    Args:
        kl_config: Mapping with ``'model_names'`` and ``'revisions'`` lists.
    """
    # Both directions of every pair are listed: (base, target).
    model_pairs = [
        ("EleutherAI/pythia-70m-deduped", "EleutherAI/pythia-410m-deduped"),
        ("EleutherAI/pythia-410m-deduped", "EleutherAI/pythia-70m-deduped"),
        ("EleutherAI/pythia-410m-deduped", "EleutherAI/pythia-1.4b-deduped"),
        ("EleutherAI/pythia-1.4b-deduped", "EleutherAI/pythia-410m-deduped"),
        ("EleutherAI/pythia-70m-deduped", "EleutherAI/pythia-1.4b-deduped"),
        ("EleutherAI/pythia-1.4b-deduped", "EleutherAI/pythia-70m-deduped")
    ]

    fig = plt.figure(figsize=(12, 8))

    for base_model_name, target_name in model_pairs:
        plot_kl_scores(kl_config, base_model_name, target_name)

    plt.xlabel('Revisions')
    plt.ylabel('Average KL Score')
    plt.title('Average KL Score vs Revisions for Different Model Comparisons')
    plt.legend()
    plt.grid(True)

    # savefig does not create missing directories.
    os.makedirs('./graphics', exist_ok=True)
    plt.savefig('./graphics/average_kl_vs_revisions.png')
    # Release the figure, as the other plotting helpers in this notebook do.
    plt.close(fig)

# Draw the combined average-KL figure for all model pairs and save it to disk.
plot_all_kl_scores(kl_config)

# TODO (02.09)

- [x] Spearman correlation
- [x] Multiple moving averages compared (always both ways)
- [x] Moving average of surprisal, log probability of correct token.
- [x] Verify KL both ways
- [ ] Lookup the prediction tables for large divergences.
- [x] X-axis across training, Y: KL lines = medium vs large

# The following code is for the grid plots (verified, P = row)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import glob

def _segment_excess(kl, start, stop, offset=2):
    """Mean over token columns ``start:stop`` minus the overall mean, plus ``offset``.

    NOTE(review): the constant offset presumably keeps the values positive for
    the ``log1p`` colour scale -- confirm.
    """
    return np.mean(kl[:, start:stop], axis=1) - np.mean(kl, axis=1) + offset


def _create_grid(row_summary, output_path):
    """Build a square heat map over all model/revision pairs and save it.

    Row ``i`` belongs to the ``i``-th base (model, revision) pair in config
    order. Its KL file ``./results/{model}_{revision}_kl.npy`` is reduced to
    one value per comparison with ``row_summary`` and copied into the row,
    skipping the diagonal cell.

    Args:
        row_summary: Callable mapping the loaded 2-D array to a 1-D array with
            one entry per *other* model/revision pair.
        output_path: Where the PNG is written.

    Raises:
        ValueError: If a file does not yield exactly ``N - 1`` values, where
            ``N`` is the number of model/revision pairs in ``kl_config``.
    """
    # Grid size follows the config instead of a hard-coded 45.
    combos = [
        (model_name, revision)
        for model_name in kl_config['model_names']
        for revision in kl_config['revisions']
    ]
    size = len(combos)
    data = np.zeros((size, size))
    ticks = []

    for i, (base_model_name, base_revision) in enumerate(combos):
        # Load the file and calculate the per-comparison summary values
        loaded_data = np.load(f'./results/{base_model_name.replace("/", "-")}_{base_revision}_kl.npy')
        averages = row_summary(loaded_data)

        # A wrong-sized file could otherwise broadcast silently or fail with an
        # opaque numpy shape error.
        if len(averages) != size - 1:
            raise ValueError(
                f"Expected {size - 1} comparison rows for {base_model_name} {base_revision}, "
                f"got {len(averages)}."
            )

        # Insert the values on both sides of the diagonal cell
        data[i, :i] = averages[:i]
        data[i, i+1:] = averages[i:]
        ticks.append(f"{base_model_name.replace('EleutherAI/pythia-', '').replace('-deduped', '')}-{base_revision}")

    # The diagonal (self-comparison) is set to 0.
    # NOTE(review): an earlier comment said "NaN" and the scale below uses
    # nanmin/nanmax; with 0 the diagonal takes part in the colour range --
    # confirm which was intended.
    np.fill_diagonal(data, 0)

    # Get the min and max values for the color scale
    vmin = np.nanmin(data)
    vmax = np.nanmax(data)

    # Large figure so individual cells stay readable
    fig, ax = plt.subplots(figsize=(15, 15))
    cax = ax.matshow(np.log1p(data), cmap='gray', vmin=np.log1p(vmin), vmax=np.log1p(vmax))
    fig.colorbar(cax)

    plt.xticks(range(size), ticks, rotation=45)
    plt.yticks(range(size), ticks)

    # savefig does not create missing directories.
    os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)
    plt.savefig(output_path)
    plt.close('all')


def create_kl_grid():
    """Save ``./graphics/kl_grid.png``: mean KL over all tokens per model pair."""
    _create_grid(lambda kl: np.mean(kl, axis=1), './graphics/kl_grid.png')


def create_moby_dick_grid():
    """Save ``./graphics/moby_grid.png`` for token columns 2272-2635.

    Each cell is the mean KL over that column range minus the mean over all
    columns, plus 2.
    NOTE(review): the range is presumably the Moby Dick excerpt of the input
    text (inferred from the function name) -- confirm.
    """
    _create_grid(lambda kl: _segment_excess(kl, 2272, 2636), './graphics/moby_grid.png')


def create_code_grid():
    """Save ``./graphics/code_grid.png`` for token columns from 3500 onward.

    Each cell is the mean KL over that column range minus the mean over all
    columns, plus 2.
    NOTE(review): the range is presumably the source-code part of the input
    text (inferred from the function name) -- confirm.
    """
    _create_grid(lambda kl: _segment_excess(kl, 3500, None), './graphics/code_grid.png')

# Render and save the three grid figures, in the same order as before.
for build_grid in (create_kl_grid, create_moby_dick_grid, create_code_grid):
    build_grid()

# The following code is for the average KL plots (verified; P fixed, Q change)

In [None]:
def plot_average_kl_values(kl_data, model_name, revision):
    """Plot the row means of ``kl_data`` as a line and save it as a PNG.

    The figure is written to ``./graphics/{short}-{revision}-line.png``, where
    ``{short}`` is ``model_name`` without ``EleutherAI/pythia-`` and
    ``-deduped``.

    Args:
        kl_data: 2-D array; one mean is taken per row (``axis=1``).
        model_name: Base model name, used in the title and the file name.
        revision: Base revision, used in the title and the file name.
    """
    kl_averages = np.mean(kl_data, axis=1)

    # Built outside the f-string: reusing the f-string's own quote character
    # inside a replacement field is a SyntaxError before Python 3.12.
    short_name = model_name.replace('EleutherAI/pythia-', '').replace('-deduped', '')

    plt.figure(figsize=(10, 6))
    plt.plot(kl_averages, marker='o')
    plt.title(f'Average KL Values for {model_name} - {revision}')
    plt.xlabel('Comparison Index')
    plt.ylabel('Average KL Value')
    plt.grid(True)

    # savefig does not create missing directories.
    os.makedirs('./graphics', exist_ok=True)
    plt.savefig(f'./graphics/{short_name}-{revision}-line.png')
    plt.close('all')

# One line plot per base model/revision: read its KL file and hand it to the plotter.
for base_model_name in kl_config['model_names']:
    # '/' is not usable inside a file name, so the result files use '-' instead.
    file_stem = base_model_name.replace("/", "-")
    for base_revision in kl_config['revisions']:
        loaded_data = np.load(f'./results/{file_stem}_{base_revision}_kl.npy')
        plot_average_kl_values(loaded_data, base_model_name, base_revision)