In [None]:
import nltk
import numpy as np
import os
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm

In [None]:
def parse_metrics_filename(filename, steps_per_epoch=5888):
    """Extract the experiment settings encoded in a metrics file name.

    The name is split on "_" and the resulting tokens are scanned for:

    * model size  - the first of "gpt2-large", "gpt2-medium", "gpt2-xl",
      "gpt2" that occurs anywhere in the name;
    * learning rate - a token of the form ``lr<float>``, otherwise a token
      containing "e-" that parses as a float (e.g. ``5e-05``);
    * unfrozen block - a token of the form ``blk<float>``; 1.0 when absent;
    * epoch - ``checkpoint<N>`` (or ``checkpoint-<N>``) divided by
      ``steps_per_epoch``, otherwise the last token when the name contains
      "epoch".

    Parameters
    ----------
    filename : str
        Metrics file name, e.g.
        "metrics_model_gen_gpt2-large_5e-05_equal_dataset_50000_epoch_4.csv".
    steps_per_epoch : int, default 5888
        Divisor that turns the number in a ``checkpoint`` token into an epoch
        (presumably the number of training steps per epoch - confirm against
        the training setup).

    Returns
    -------
    tuple or None
        ``(model_size, lr, blk, epoch, head_frozen)``. For names containing
        "not_finetuned" every field except ``model_size`` is None. None is
        returned (after printing a message) when the model size, learning
        rate or epoch cannot be determined.
    """
    # A value can sit in the last token; drop the ".csv" suffix so that
    # "lr5e-05.csv" or "blk0.7.csv" still parses.
    stem = filename[:-4] if filename.endswith(".csv") else filename
    name_parts = stem.split("_")

    def _last_parsed(candidates, convert):
        # Last candidate that `convert` accepts (later tokens win, as before);
        # unparseable tokens are skipped instead of raising ValueError.
        found = None
        for text in candidates:
            try:
                found = convert(text)
            except ValueError:
                continue
        return found

    # Model size: the longer names come first so "gpt2" cannot shadow them.
    model_sizes = ["gpt2-large", "gpt2-medium", "gpt2-xl", "gpt2"]
    model_size = next((size for size in model_sizes if size in filename), None)

    if "not_finetuned" in filename:
        return (model_size, None, None, None, None)

    if model_size is None:
        print(f"Cannot find model size in {filename}")
        return None

    # Learning rate: prefer an explicit "lr<value>" token, else a bare "5e-05".
    lr = _last_parsed((part[2:] for part in name_parts if part.startswith("lr")), float)
    if lr is None:
        lr = _last_parsed((part for part in name_parts if "e-" in part), float)

    if lr is None:
        print(f"Cannot find learning rate in {filename}")
        return None

    # Unfrozen block: defaults to 1.0 when the name carries no "blk<value>" token.
    blk = _last_parsed((part[3:] for part in name_parts if part.startswith("blk")), float)
    if blk is None:
        blk = 1.0

    # The head can only be reported as frozen when blk is below 1.0.
    head_frozen = blk < 1.0 and "without_head" in filename

    # Epoch
    epoch = None
    if "checkpoint" in filename:
        prefix_len = len("checkpoint")
        # lstrip("-") accepts "checkpoint-5888" as step 5888 rather than -5888.
        step = _last_parsed(
            (part[prefix_len:].split(".")[0].lstrip("-")
             for part in name_parts if part.startswith("checkpoint")),
            int,
        )
        if step is not None:
            epoch = int(step // steps_per_epoch)
    elif "epoch" in filename:
        epoch = _last_parsed([name_parts[-1].split(".")[0]], int)

    if epoch is None:
        print(f"Cannot find epoch in {filename}")
        return None

    return (model_size, lr, blk, epoch, head_frozen)

def process_file(filename, force=False, metrics_dir=None):
    """Summarise one per-example metrics CSV into a flat record.

    Parameters
    ----------
    filename : str
        Name of a CSV file inside the metrics directory. The file must have
        "similarity_score" and "exact_match" columns.
    force : bool, default False
        When False, files whose name contains "top-p", "temperature" or
        "beam_search" are skipped and None is returned.
    metrics_dir : str, optional
        Directory to read ``filename`` from. When omitted, the notebook-level
        ``metrics_path`` is used if it has been defined, otherwise "./metrics",
        so the function also works when the regeneration cell was not run.

    Returns
    -------
    dict or None
        None when the file is skipped or its name cannot be parsed. Otherwise
        a dict with "size", "lr", "blk", "epoch", "head_frozen";
        "<stat>_similarity_score" for mean/std/min/25%/50%/75%;
        "<stat>_exact_match" for mean/std/max; "exact_match_<k>" holding the
        number of rows with at least k exact matches for every k from 1 to the
        largest observed value; and "exact_match_0" holding the number of
        remaining rows.
    """
    if not force and any(tag in filename for tag in ("top-p", "temperature", "beam_search")):
        return None

    res = parse_metrics_filename(filename)
    if res is None:
        return None
    model_size, lr, blk, epoch, head_frozen = res

    if metrics_dir is None:
        metrics_dir = globals().get("metrics_path", "./metrics")
    df = pd.read_csv(f"{metrics_dir}/{filename}")

    similarity_score_stats = dict(df["similarity_score"].describe().drop(["count", "max"]))
    exact_match_stats = dict(df["exact_match"].describe().drop(["count", "min", "25%", "50%", "75%"]))

    # errors="ignore": a file in which every row has at least one match has no
    # 0 entry to drop.
    positive_counts = df["exact_match"].value_counts().drop(0, errors="ignore")
    if positive_counts.empty:
        exact_match_count = {}
    else:
        # Fill in match counts that never occur exactly, so the cumulative
        # "at least k" total exists for every k and has no holes.
        levels = range(int(positive_counts.index.max()), 0, -1)
        at_least = positive_counts.reindex(levels, fill_value=0).cumsum()
        exact_match_count = dict(at_least.iloc[::-1])
    exact_match_count[0] = len(df.index) - exact_match_count.get(1, 0)

    similarity_score_stats = {f'{k}_similarity_score': v for k, v in similarity_score_stats.items()}
    exact_match_stats = {f'{k}_exact_match': v for k, v in exact_match_stats.items()}
    exact_match_count = {f'exact_match_{k}': v for k, v in exact_match_count.items()}

    return {
        "size": model_size,
        "lr": lr,
        "blk": blk,
        "epoch": epoch,
        "head_frozen": head_frozen,
        **similarity_score_stats,
        **exact_match_stats,
        **exact_match_count,
    }

def plot_exact_match(df, min_idx, max_idx):
    """Plot frequencies and log-frequencies for match counts min_idx..max_idx.

    Draws a 5x5 figure with two stacked panels sharing the x axis: the values
    of ``df`` on top and their natural logarithm below. The figure is left as
    the current matplotlib figure; nothing is returned.

    Parameters
    ----------
    df : pandas.DataFrame
        Table indexed by exact-match count with one column per plotted line
        (the callers in this notebook pass pivoted ``exact_match_<k>`` tables).
    min_idx, max_idx : int
        First and last exact-match count to show, both inclusive.
    """
    ticks = range(min_idx, max_idx + 1)
    # Select by index label (inclusive on both ends) so the rows always match
    # `ticks`; a positional slice only lines up when the index is exactly 0..N.
    plot_df = df.loc[min_idx:max_idx]
    fig, ax = plt.subplots(2, 1, sharex=True)
    plot_df.plot(marker='*', ax=ax[0], xlabel="Exact match count", ylabel="Frequency", xticks=ticks)
    # Mask zero counts before taking the log: log(0) is -inf and triggers a
    # divide-by-zero warning; masked points are simply not drawn.
    np.log(plot_df.where(plot_df > 0)).plot(
        marker='*', ax=ax[1], xlabel="Exact match count", ylabel="Log of Frequency", xticks=ticks)
    # The top panel already carries the legend.
    ax[1].get_legend().remove()
    fig.set_figwidth(5)
    fig.set_figheight(5)
    fig.tight_layout()
    
    

In [None]:
# Run this if you need to re-generate the metrics

metrics_path = "./metrics"

# os.listdir returns entries in arbitrary order; sort them so agg_metrics.csv
# gets the same row order on every run.
metrics_files = sorted(f for f in os.listdir(metrics_path) if f.endswith(".csv"))

# process_file returns None for files it skips or cannot parse; drop those.
files_info = [process_file(f) for f in tqdm(metrics_files, position=0, leave=True)]
files_info = [f for f in files_info if f is not None]

df = pd.DataFrame(files_info)
df.to_csv("./agg_metrics.csv", index=False)

In [None]:
# Load the aggregated per-file metrics from ./agg_metrics.csv (the file the
# regeneration cell writes); the analysis cells below filter this frame.
df = pd.read_csv("./agg_metrics.csv")

In [None]:
# Analyze model size effect

# One (size, epoch) selection per model size, all at lr == 1e-05.
size_lr_mask = df["lr"] == 1e-05
gpt2_mask = (df["size"] == "gpt2") & (df["epoch"] == 19)
gpt2_medium_mask = (df["size"] == "gpt2-medium") & (df["epoch"] == 9)
# gpt2-large is additionally restricted to blk == 1.0.
gpt2_large_mask = (df["size"] == "gpt2-large") & (df["epoch"] == 4) & (df["blk"] == 1.0)
gpt2_xl_mask = (df["size"] == "gpt2-xl") & (df["epoch"] == 2)

size_df = df.loc[size_lr_mask & (gpt2_mask | gpt2_medium_mask | gpt2_large_mask | gpt2_xl_mask)]

size_summary_cols = ["size", "mean_similarity_score", "50%_similarity_score", "mean_exact_match", "max_exact_match"]
size_summary = size_df[size_summary_cols].sort_values("mean_similarity_score")
print(size_summary.to_string())

In [None]:
# Reshape the exact_match_<k> columns into a table indexed by k ("count")
# with one column per model size.
size_long = pd.wide_to_long(size_df, "exact_match_", i="size", j="count")
size_counts = size_long["exact_match_"].reset_index(level=["size"])
size_plot_df = size_counts.pivot(columns="size", values="exact_match_")

In [None]:
plot_exact_match(size_plot_df, 1, 15)
# savefig does not create missing directories; make sure ./plots exists so
# the cell also works on a fresh checkout.
os.makedirs("./plots", exist_ok=True)
plt.savefig("./plots/size.png", bbox_inches='tight')

In [None]:
# Analyze frozen block effect

# Reference run: gpt2-large with blk == 1.0.
frozen_mask = (
    (df["size"] == "gpt2-large")
    & (df["blk"] == 1.0)
    & (df["lr"] == 1e-05)
    & (df["epoch"] == 4)
    & (df["head_frozen"] == False)
)
# Remaining runs as (blk, lr, epoch, head_frozen); these are not filtered on size.
frozen_settings = [
    (0.7, 5e-07, 2, True),
    (0.4, 1e-06, 3, True),
    (0.7, 5e-07, 2, False),
    (0.4, 5e-07, 2, False),
    (0.0, 1e-05, 2, False),
]
for blk_value, lr_value, epoch_value, head_value in frozen_settings:
    frozen_mask = frozen_mask | (
        (df["blk"] == blk_value)
        & (df["lr"] == lr_value)
        & (df["epoch"] == epoch_value)
        & (df["head_frozen"] == head_value)
    )
frozen_df = df.loc[frozen_mask]

frozen_summary_cols = ["blk", "head_frozen", "mean_similarity_score", "50%_similarity_score", "mean_exact_match", "max_exact_match"]
frozen_summary = frozen_df[frozen_summary_cols].sort_values("mean_similarity_score")
print(frozen_summary.to_string())

In [None]:
# Reshape the exact_match_<k> columns into a table indexed by k ("count")
# with one column per (blk, head_frozen) pair.
frozen_long = pd.wide_to_long(frozen_df, "exact_match_", i=["blk", "head_frozen"], j="count")
frozen_counts = frozen_long["exact_match_"].reset_index(level=["blk", "head_frozen"])
frozen_plot_df = frozen_counts.pivot(columns=["blk", "head_frozen"], values="exact_match_")

In [None]:
plot_exact_match(frozen_plot_df, 1, 15)
# savefig does not create missing directories; make sure ./plots exists so
# the cell also works on a fresh checkout.
os.makedirs("./plots", exist_ok=True)
plt.savefig("./plots/frozen.png", bbox_inches='tight')

In [None]:
# Analyze different decoding methods

# force=True: process_file skips top-p / temperature / beam_search files by default.
# Each value below is the summary dict for one file plus a "method" label.
greedy_df = {
    **process_file("metrics_model_gen_gpt2-large_5e-05_equal_dataset_50000_epoch_4.csv", force=True),
    "method": "greedy",
}
beam_search_df = {
    **process_file("metrics_model_gen_gpt2-large_equal_dataset_50000_beam_search_lr5e-05_epoch_4.csv", force=True),
    "method": "beam search (num_beam=8)",
}
temp_df = {
    **process_file("metrics_model_gen_gpt2-large_equal_dataset_50000_temperature_lr5e-05_epoch_4.csv", force=True),
    "method": "random sampling (temp=0.85)",
}
top_p_df = {
    **process_file("metrics_model_gen_gpt2-large_equal_dataset_50000_top-p_lr5e-05_epoch_4.csv", force=True),
    "method": "top p sampling (p=0.85)",
}

# Keys missing from a record (e.g. exact_match_<k> columns another method has) become 0.
decode_records = [greedy_df, beam_search_df, temp_df, top_p_df]
decode_df = pd.DataFrame(decode_records).fillna(0)

In [None]:
# Headline metrics per decoding method, ordered by mean similarity score.
decode_summary_cols = ["method", "mean_similarity_score", "50%_similarity_score", "mean_exact_match", "max_exact_match"]
decode_summary = decode_df[decode_summary_cols].sort_values("mean_similarity_score")
print(decode_summary.to_string())

In [None]:
# Reshape the exact_match_<k> columns into a table indexed by k ("count")
# with one column per decoding method.
decode_long = pd.wide_to_long(decode_df, "exact_match_", i="method", j="count")
decode_counts = decode_long["exact_match_"].reset_index(level=["method"])
decode_plot_df = decode_counts.pivot(columns="method", values="exact_match_")

In [None]:
plot_exact_match(decode_plot_df, 1, 20)
# savefig does not create missing directories; make sure ./plots exists so
# the cell also works on a fresh checkout.
os.makedirs("./plots", exist_ok=True)
plt.savefig("./plots/decode.png", bbox_inches='tight')

In [None]:
# Analyze prompt length

# Prompt lengths to compare; the cells below read the columns
# "<length>_similarity_score" and "<length>_exact_match" for each of them.
# NOTE(review): the unit of the length (tokens vs. characters) is not visible here - confirm.
prompt_lengths = [20, 70, 100]
length_df = pd.read_csv("./metrics/metrics_prompt_lengths.csv")

In [None]:
# Mean similarity score and mean exact-match count for each prompt length.
length_similarity_cols = [f"{length}_similarity_score" for length in prompt_lengths]
length_exact_match_cols = [f"{length}_exact_match" for length in prompt_lengths]
length_df[length_similarity_cols + length_exact_match_cols].mean()

In [None]:
# How often each exact-match value occurs, one column per prompt length
# (values a column never takes are filled with 0).
length_match_cols = [f"{length}_exact_match" for length in prompt_lengths]
length_value_counts = length_df[length_match_cols].apply(pd.Series.value_counts).fillna(0)
# Drop the first row, accumulate from the largest value downwards, then
# restore ascending order, giving cumulative "this value or more" totals.
counts = length_value_counts[:0:-1].cumsum()[::-1]

In [None]:
# Log of the cumulative exact-match counts, one line per prompt length.
# Axis labels match the ones plot_exact_match uses; assigning the Axes keeps
# its repr out of the cell output.
length_ax = np.log(counts).plot(marker='*', xlabel="Exact match count", ylabel="Log of Frequency")

In [None]:
# Investigate the longest exact match that seems to persist across many different model generations

# Use a dedicated name instead of rebinding `df`: `df` holds the aggregated
# metrics that the analysis cells above filter on, and overwriting it with
# this per-example frame breaks those cells when they are re-run.
greedy_examples_df = pd.read_csv("./metrics/metrics_model_gen_gpt2-large_5e-05_equal_dataset_50000_epoch_4.csv")

# Rows whose exact-match count is 54, selected once and reused for every column.
longest_match_rows = greedy_examples_df.loc[greedy_examples_df["exact_match"] == 54]
for column in ["text", "promptLength70_numBeams1", "original_sentence", "generated_sentence"]:
    print(list(longest_match_rows[column]))