# Miscellaneous functions for dark mode

In [1]:
# Set the tqdm text color to white.

from IPython.display import HTML, display

def set_css_in_cell_output():
    '''Emit a <style> element that recolours Jupyter widget text to light grey (#d5d5d5).

    Targets the `.jupyter-widgets` and `.widget-label` classes so that widget
    text such as tqdm progress bars stays readable on a dark notebook theme.
    '''
    widget_text_style = HTML('''
        <style>
            .jupyter-widgets {color: #d5d5d5 !important;}
            .widget-label {color: #d5d5d5 !important;}
        </style>
    ''')
    display(widget_text_style)

# Re-emit the stylesheet before each cell executes, via IPython's 'pre_run_cell' event.
get_ipython().events.register('pre_run_cell', set_css_in_cell_output)

# Loading data

In [2]:
# Checkpoint steps evaluated for each model, generated from (first, last, stride).
# Two schedules exist: 23000..143000 in steps of 20000, and 11500..71500 in steps of 10000.
models = {
    name: list(range(first, last + 1, stride))
    for name, first, last, stride in (
        ('13B', 23000, 143000, 20000),
        ('13B_deduped', 23000, 143000, 20000),
        ('6.7B', 23000, 143000, 20000),
        ('6.7B_deduped', 23000, 143000, 20000),
        ('2.7B', 23000, 143000, 20000),
        ('1.3B', 11500, 71500, 10000),
        ('1.3B_deduped', 11500, 71500, 10000),
        ('800M', 23000, 143000, 20000),
        ('800M_deduped', 23000, 143000, 20000),
        ('350M', 11500, 71500, 10000),
        ('350M_deduped', 11500, 71500, 10000),
        ('125M', 11500, 71500, 10000),
        ('125M_deduped', 11500, 71500, 10000),
    )
}
# Directory holding the per-checkpoint memorization result files.
filepath = '/fsx/orz/transformer-memorization'

In [4]:
import os
import pandas as pd
from tqdm.auto import tqdm

tqdm.pandas()

In [5]:
# Load every checkpoint's evaluation table into memory, keyed by '<model>-<checkpoint>'.
memorization_results = dict()
for model, checkpoints in models.items():
    for checkpoint in tqdm(checkpoints, desc=model):
        model_name = f'{model}-{checkpoint}'
        filename = os.path.join(filepath, f'memorization_results_{model}-{checkpoint}.csv.hdf')
        try:
            # Fast path: reuse the HDF copy written by an earlier run.
            cached = pd.read_hdf(filename, key='memorization')
        except Exception:
            # Any failure to read the HDF copy falls back to the CSV, which is then
            # written out as HDF so that later runs can take the fast path.
            csv = pd.read_csv(os.path.join(filepath, f'memorization_results_{model}-{checkpoint}.csv'))
            csv.to_hdf(filename, key='memorization', index=False)
            memorization_results[model_name] = csv
        else:
            memorization_results[model_name] = cached
            

13B:   0%|          | 0/7 [00:00<?, ?it/s]

13B_deduped:   0%|          | 0/7 [00:00<?, ?it/s]

6.7B:   0%|          | 0/7 [00:00<?, ?it/s]

6.7B_deduped:   0%|          | 0/7 [00:00<?, ?it/s]

2.7B:   0%|          | 0/7 [00:00<?, ?it/s]

1.3B:   0%|          | 0/7 [00:00<?, ?it/s]

1.3B_deduped:   0%|          | 0/7 [00:00<?, ?it/s]

800M:   0%|          | 0/7 [00:00<?, ?it/s]

800M_deduped:   0%|          | 0/7 [00:00<?, ?it/s]

350M:   0%|          | 0/7 [00:00<?, ?it/s]

350M_deduped:   0%|          | 0/7 [00:00<?, ?it/s]

125M:   0%|          | 0/7 [00:00<?, ?it/s]

125M_deduped:   0%|          | 0/7 [00:00<?, ?it/s]

In [10]:
from matplotlib import pyplot as plt
import matplotlib
import numpy as np

# Simple Multiprocess Manager

In [11]:
class MPManager:
    '''Minimal helper that starts each process as soon as it is registered and can wait on all of them

    processes: list of every process object handed to `add`, in registration order
    '''

    def __init__(self):
        self.processes = list()

    def add(self, process):
        '''Mark `process` as non-daemonic, record it and start it immediately.'''
        process.daemon = False
        self.processes.append(process)
        process.start()

    def join(self):
        '''Join every registered process, in registration order.'''
        for worker in self.processes:
            worker.join()

# Setting Params, styling

In [12]:
size = 25
def update_params():
    '''Apply the notebook-wide matplotlib rcParams (figure size, font sizes, font weight).

    Axis label and title sizes come from the module-level `size`;
    tick labels use 75% of it.
    '''
    tick_label_size = size*0.75
    plt.rcParams.update({
        'legend.fontsize': 'large',
        'figure.figsize': (15,10),
        'axes.labelsize': size,
        'axes.titlesize': size,
        'xtick.labelsize': tick_label_size,
        'ytick.labelsize': tick_label_size,
        'axes.titlepad': 25,
        'font.family':'sans-serif',
        'font.weight':'bold',
    })

# Model Accuracy Linear Regression plot

In [13]:
from itertools import cycle
import multiprocessing as mp
import scipy.stats as stats
from sklearn.metrics import r2_score

def plot_linear_regression_of_model(model, models, lock, percentile=0):
    '''Wrapper Script to plot regression lines between memorization accuracy measured and Sequence Index
    
    `model`: Name of model
    `models`: Dictionary mapping model names to their list of checkpoints.
    `lock`: Lock shared by the worker processes; held while appending to the metrics csv
    `percentile`: Sample only indices with accuracy more than or equal to k percentile accuracy
    
    Utilizes global variable `memorization_results` to get the required evals.
    
    Plots the regression lines with confidence intervals into
    ./plots/{model}-{percentile}_linear_regression.png and appends one row per
    checkpoint to ./results/linear_regression.csv.
    
    Metrics Stored (one csv row per checkpoint):
        "model": Name of the model plotted
        "checkpoint": Checkpoint of the model plotted
        "most memorized percentile": `percentile` parameter of the current function
        "slope", "variation", "% change": the tuple returned by
            `plot_linear_regression_of_checkpoint` for that checkpoint
    '''
    # Initialization and beautification
    cycol = cycle(['violet', 'indigo', 'blue', 'green', 'yellow', 'orange', 'red'])
    update_params()
    fig = matplotlib.figure.Figure()
    ax = fig.subplots()
    res = {}
    print(" ", end="", flush=True) # Bug of jupyter notebook. For more info, refer: https://github.com/tqdm/tqdm/issues/485
    bar = tqdm(
        desc=f'{model}-top {percentile} Linear Regression',
        total=len(models[model])
    )
    
    # Iterating through checkpoints
    for (model_name, evals) in memorization_results.items():
        m, checkpoint = model_name.split('-')
        if(m != model): continue
        checkpoint = int(checkpoint)
        res[checkpoint] = plot_linear_regression_of_checkpoint(model_name, evals, ax, next(cycol), 
            percentile)
        bar.update(1)
    
    # Titling and labeling plot
    fig.suptitle("Memorization Accuracy", fontsize=20)
    fig.supxlabel("Sequence Index", fontsize=20)
    fig.supylabel("Accuracy", fontsize=20)
    ax.legend(loc='lower right')
    
    # Saving plots and metrics
    fig.savefig(f'./plots/{model}-{percentile}_linear_regression.png', facecolor='white')
    # Format the rows before taking the lock so it is held only for the file append.
    rows = [
        ','.join(f'{field}' for field in (model, checkpoint, percentile, *scores)) + '\n'
        for checkpoint, scores in res.items()
    ]
    # `with lock` releases the lock even if the write raises; a bare acquire()/release()
    # pair would leave it held and block every other worker on its own append.
    with lock:
        with open('./results/linear_regression.csv', 'a') as f:
            f.writelines(rows)
    bar.close()
    
    
    

def plot_linear_regression_of_checkpoint(model_name, evals, axis, color, precentile = 0):
    '''Generates a linear regression plot and returns a tuple of results
    
    `model_name`: String with model and its checkpoint; used as the legend label
    `evals`: Evaluation results of corresponding model on corresponding checkpoint;
        the 'index' and 'accuracy' columns are read
    `axis`, `color`: Plotting params
    `precentile`: (spelling is part of the signature) Sample only indices with accuracy
        more than or equal to k percentile accuracy
    
    Returns:
        tuple(
            slope: slope of regression line
            variation: L(x_max) - L(x_min), where L is the line of best fit and
                x_min / x_max are the smallest / largest sampled Sequence Index
            % change: variation * 100 / L(x_min)
        )
    '''
    # Top percentile evals
    top_percentile_accuracy = np.percentile(evals['accuracy'],precentile)
    top_percentile_evals = evals[evals['accuracy'] >= top_percentile_accuracy]
    
    # Actually performing regression
    indicies, accuracy = top_percentile_evals['index'], top_percentile_evals['accuracy']
    a,b = np.polyfit(indicies, accuracy,1)
    
    # Evaluate the fit at the extreme indices rather than at the first/last row,
    # so the metrics do not depend on the row order of `evals`.
    x_min, x_max = np.min(indicies), np.max(indicies)
    fit_start, fit_end = np.polyval([a, b], [x_min, x_max])
    variation = fit_end - fit_start
    results = (
        a, # Slope
        variation, # Variation
        variation*100/fit_start # % change
    )
    
    # Confidence Interval plots
    n_points = len(indicies)
    x_mean = indicies.mean()
    dof = n_points - 2 # two fitted parameters: slope and intercept
    t = stats.t.ppf(0.99999, dof) # Student-t quantile used to scale the band
    residual = accuracy - np.polyval([a, b], indicies)
    std_error = (np.sum(residual**2) / dof)**.5
    x_line = np.linspace(x_min, x_max, int(1e6))
    y_line = np.polyval([a, b], x_line)
    ci = t * std_error * (1/n_points + (x_line - x_mean)**2 / np.sum((indicies - x_mean)**2))**.5
    
    # Actually plotting data
    axis.plot(x_line, y_line,
        color=color, label=model_name)
    axis.fill_between(x_line, (y_line-ci), (y_line+ci), color=color, alpha=0.2)
    return results

# Bucketed Memorization Plots

In [14]:
from itertools import cycle
# Module-level colour cycle (rainbow order); every `next(cycol)` advances it.
cycol = cycle(['violet', 'indigo', 'blue', 'green', 'yellow', 'orange', 'red'])

def plot_bucketed_scores_of_model(model, models, lock, percentile=0, metric='accuracy'):
    '''Draws, for one model, the bucket-averaged memorization curve of each of its checkpoints
    
    `model`: Name of model
    `models`: Dictionary mapping model names to checkpoint lists; only `models[model]`
        is read, to size the progress bar
    `lock`: Not used by this function
    `percentile`: Sample only indices with accuracy more than or equal to k percentile accuracy
    `metric`: Column of the evals that is averaged (accuracy or nll_loss). The title and
        axis labels always say "Accuracy".
    
    Reads the evals from the global `memorization_results` and saves the figure to
    ./plots/{model}-{percentile}_bucketed_memorization.png
    '''
    # One colour per checkpoint, in rainbow order
    palette = cycle(['violet', 'indigo', 'blue', 'green', 'yellow', 'orange', 'red'])
    fig = matplotlib.figure.Figure()
    ax = fig.subplots()
    print(" ", end="", flush=True)
    bar = tqdm(
        desc=f'{model}-top {percentile} Bucketed Memorization',
        total=len(models[model])
    )
    
    # Keys look like '<model>-<checkpoint>'; draw only the ones belonging to `model`
    for model_name in memorization_results:
        m, checkpoint = model_name.split('-')
        if m == model:
            checkpoint = int(checkpoint)
            plot_bucketed_scores_of_checkpoint(
                model_name, memorization_results[model_name],
                ax, next(palette),
                percentile,
                metric=metric
            )
            bar.update(1)
    
    # Figure-level title and axis labels, then the legend
    fig.suptitle("Bucketed Memorization Accuracy", fontsize=20)
    fig.supxlabel("Sequence Index", fontsize=20)
    fig.supylabel("Accuracy", fontsize=20)
    ax.legend(loc='lower right')
    
    # Write the figure to disk and finish the progress bar
    fig.savefig(f'./plots/{model}-{percentile}_bucketed_memorization.png', facecolor='white')
    bar.close()

def plot_bucketed_scores_of_checkpoint(
    model_name, evals, axis, color, 
    precentile = 0, 
    bucket_size=235520,
    metric='accuracy'
):
    '''Plots bucketed memorization scores
    
    `model_name`: String with model and its checkpoint; used as the legend label
    `evals`: Evaluation results of corresponding model on corresponding checkpoint;
        the 'index' column and the `metric` column are read
    `axis`, `color`: Plotting params
    `precentile`: (spelling is part of the signature) Sample only indices with `metric`
        more than or equal to k percentile of `metric`
    `bucket_size`: Width of one bucket in Sequence Index units; bucket k holds the rows with
        k*bucket_size <= index < (k+1)*bucket_size. Score of a bucket is individual scores' mean
    `metric`: Is only accuracy for now
    
    Each non-empty bucket contributes one point: x is the 'index' of the bucket's last
    row (in row order) and y is the mean of `metric` over the bucket.
    '''
    # Top percentile evals
    top_percentile_metric = np.percentile(evals[metric],precentile)
    top_percentile_evals = evals[evals[metric] >= top_percentile_metric]
    
    # Bucketing data in a single pass: the bucket id is derived from the Sequence Index
    # itself, so rows are bucketed correctly even when the largest index exceeds the row
    # count. The ids are passed as a plain array to keep them independent of the frame's
    # row labels and column names.
    bucket_ids = (top_percentile_evals['index'] // bucket_size).to_numpy()
    buckets = top_percentile_evals.groupby(bucket_ids, sort=True)
    indicies = buckets['index'].last().tolist()
    metric_means = buckets[metric].mean().tolist()
    
    # Plotting the results
    axis.plot(indicies, metric_means, label=model_name, color=color)

# Parallelizing plot generations

In [15]:
update_params()
manager = MPManager()
lock = mp.Lock()
# Start the metrics csv afresh with its header row.
with open('./results/linear_regression.csv', 'w') as f:
    f.write('model,checkpoint,most memorized percentile,slope,variation,% change\n')
# (worker, percentile) pairs started for every model, listed in launch order.
plot_jobs = [
    (plot_bucketed_scores_of_model, 0),
    (plot_bucketed_scores_of_model, 90),
    (plot_bucketed_scores_of_model, 99),
    (plot_linear_regression_of_model, 0),
    (plot_linear_regression_of_model, 99),
    (plot_linear_regression_of_model, 90),
]
for model in models.keys():
    for worker, top_percentile in plot_jobs:
        manager.add(mp.Process(target=worker, args=(model, models, lock, top_percentile)))
# Wait for every plotting process to finish.
manager.join()

        

13B-top 0 Bucketed Memorization:   0%|          | 0/7 [00:00<?, ?it/s]

13B-top 90 Bucketed Memorization:   0%|          | 0/7 [00:00<?, ?it/s]

 

13B-top 99 Bucketed Memorization:   0%|          | 0/7 [00:00<?, ?it/s]

 

13B-top 0 Linear Regression:   0%|          | 0/7 [00:00<?, ?it/s]

 

13B-top 99 Linear Regression:   0%|          | 0/7 [00:00<?, ?it/s]

 

13B-top 90 Linear Regression:   0%|          | 0/7 [00:00<?, ?it/s]

 

13B_deduped-top 0 Bucketed Memorization:   0%|          | 0/7 [00:00<?, ?it/s]

 

13B_deduped-top 90 Bucketed Memorization:   0%|          | 0/7 [00:00<?, ?it/s]

13B_deduped-top 99 Bucketed Memorization:   0%|          | 0/7 [00:00<?, ?it/s]

 

13B_deduped-top 0 Linear Regression:   0%|          | 0/7 [00:00<?, ?it/s]

 

13B_deduped-top 99 Linear Regression:   0%|          | 0/7 [00:00<?, ?it/s]

13B_deduped-top 90 Linear Regression:   0%|          | 0/7 [00:00<?, ?it/s]

  

6.7B-top 90 Bucketed Memorization:   0%|          | 0/7 [00:00<?, ?it/s]

 

6.7B-top 0 Bucketed Memorization:   0%|          | 0/7 [00:00<?, ?it/s]

 

6.7B-top 99 Bucketed Memorization:   0%|          | 0/7 [00:00<?, ?it/s]

 

6.7B-top 0 Linear Regression:   0%|          | 0/7 [00:00<?, ?it/s]

          

6.7B_deduped-top 90 Bucketed Memorization:   0%|          | 0/7 [00:01<?, ?it/s]

6.7B-top 99 Linear Regression:   0%|          | 0/7 [00:01<?, ?it/s]

6.7B-top 90 Linear Regression:   0%|          | 0/7 [00:01<?, ?it/s]

 

6.7B_deduped-top 99 Bucketed Memorization:   0%|          | 0/7 [00:01<?, ?it/s]

6.7B_deduped-top 0 Bucketed Memorization:   0%|          | 0/7 [00:01<?, ?it/s]

   

6.7B_deduped-top 90 Linear Regression:   0%|          | 0/7 [00:01<?, ?it/s]

  

6.7B_deduped-top 0 Linear Regression:   0%|          | 0/7 [00:01<?, ?it/s]

2.7B-top 0 Bucketed Memorization:   0%|          | 0/7 [00:01<?, ?it/s]

2.7B-top 99 Linear Regression:   0%|          | 0/7 [00:00<?, ?it/s]

 

2.7B-top 90 Bucketed Memorization:   0%|          | 0/7 [00:01<?, ?it/s]

6.7B_deduped-top 99 Linear Regression:   0%|          | 0/7 [00:01<?, ?it/s]

   

2.7B-top 99 Bucketed Memorization:   0%|          | 0/7 [00:01<?, ?it/s]

2.7B-top 0 Linear Regression:   0%|          | 0/7 [00:01<?, ?it/s]

 

1.3B-top 0 Bucketed Memorization:   0%|          | 0/7 [00:01<?, ?it/s]

 

1.3B_deduped-top 0 Linear Regression:   0%|          | 0/7 [00:00<?, ?it/s]

 

1.3B-top 90 Bucketed Memorization:   0%|          | 0/7 [00:00<?, ?it/s]

2.7B-top 90 Linear Regression:   0%|          | 0/7 [00:00<?, ?it/s]

    

1.3B-top 0 Linear Regression:   0%|          | 0/7 [00:01<?, ?it/s]

 

1.3B-top 99 Linear Regression:   0%|          | 0/7 [00:01<?, ?it/s]

       

1.3B_deduped-top 0 Bucketed Memorization:   0%|          | 0/7 [00:01<?, ?it/s]

  

1.3B_deduped-top 90 Bucketed Memorization:   0%|          | 0/7 [00:01<?, ?it/s]

     

1.3B_deduped-top 99 Bucketed Memorization:   0%|          | 0/7 [00:01<?, ?it/s]

 

1.3B-top 90 Linear Regression:   0%|          | 0/7 [00:00<?, ?it/s]

  

1.3B-top 99 Bucketed Memorization:   0%|          | 0/7 [00:01<?, ?it/s]

 

1.3B_deduped-top 99 Linear Regression:   0%|          | 0/7 [00:00<?, ?it/s]

1.3B_deduped-top 90 Linear Regression:   0%|          | 0/7 [00:01<?, ?it/s]

 

800M_deduped-top 99 Bucketed Memorization:   0%|          | 0/7 [00:01<?, ?it/s]

  

800M-top 90 Bucketed Memorization:   0%|          | 0/7 [00:01<?, ?it/s]

        

800M-top 99 Bucketed Memorization:   0%|          | 0/7 [00:00<?, ?it/s]

800M-top 0 Linear Regression:   0%|          | 0/7 [00:00<?, ?it/s]

800M-top 90 Linear Regression:   0%|          | 0/7 [00:00<?, ?it/s]

800M-top 99 Linear Regression:   0%|          | 0/7 [00:00<?, ?it/s]

800M_deduped-top 0 Bucketed Memorization:   0%|          | 0/7 [00:00<?, ?it/s]

800M-top 0 Bucketed Memorization:   0%|          | 0/7 [00:00<?, ?it/s]

350M_deduped-top 90 Bucketed Memorization:   0%|          | 0/7 [00:00<?, ?it/s]

800M_deduped-top 90 Bucketed Memorization:   0%|          | 0/7 [00:00<?, ?it/s]

800M_deduped-top 0 Linear Regression:   0%|          | 0/7 [00:00<?, ?it/s]

800M_deduped-top 99 Linear Regression:   0%|          | 0/7 [00:00<?, ?it/s]

350M_deduped-top 90 Linear Regression:   0%|          | 0/7 [00:00<?, ?it/s]

125M-top 90 Bucketed Memorization:   0%|          | 0/7 [00:00<?, ?it/s]

800M_deduped-top 90 Linear Regression:   0%|          | 0/7 [00:00<?, ?it/s]

350M-top 0 Bucketed Memorization:   0%|          | 0/7 [00:00<?, ?it/s]

125M-top 99 Linear Regression:   0%|          | 0/7 [00:00<?, ?it/s]

350M-top 99 Bucketed Memorization:   0%|          | 0/7 [00:00<?, ?it/s]

350M-top 90 Bucketed Memorization:   0%|          | 0/7 [00:00<?, ?it/s]

125M_deduped-top 0 Linear Regression:   0%|          | 0/7 [00:00<?, ?it/s]

350M-top 0 Linear Regression:   0%|          | 0/7 [00:00<?, ?it/s]

350M-top 99 Linear Regression:   0%|          | 0/7 [00:00<?, ?it/s]

350M_deduped-top 0 Bucketed Memorization:   0%|          | 0/7 [00:00<?, ?it/s]

125M_deduped-top 90 Linear Regression:   0%|          | 0/7 [00:00<?, ?it/s]

350M-top 90 Linear Regression:   0%|          | 0/7 [00:00<?, ?it/s]

350M_deduped-top 99 Linear Regression:   0%|          | 0/7 [00:00<?, ?it/s]

125M-top 0 Bucketed Memorization:   0%|          | 0/7 [00:00<?, ?it/s]

125M-top 99 Bucketed Memorization:   0%|          | 0/7 [00:00<?, ?it/s]

125M-top 0 Linear Regression:   0%|          | 0/7 [00:00<?, ?it/s]

125M_deduped-top 0 Bucketed Memorization:   0%|          | 0/7 [00:00<?, ?it/s]

350M_deduped-top 0 Linear Regression:   0%|          | 0/7 [00:00<?, ?it/s]

125M_deduped-top 99 Bucketed Memorization:   0%|          | 0/7 [00:00<?, ?it/s]

125M_deduped-top 90 Bucketed Memorization:   0%|          | 0/7 [00:00<?, ?it/s]

350M_deduped-top 99 Bucketed Memorization:   0%|          | 0/7 [00:00<?, ?it/s]

125M-top 90 Linear Regression:   0%|          | 0/7 [00:00<?, ?it/s]

125M_deduped-top 99 Linear Regression:   0%|          | 0/7 [00:00<?, ?it/s]

In [20]:
import matplotlib
import scipy.signal
from itertools import cycle
from multiprocessing import Process
# Module-level colour cycle (rainbow order); every `next(cycol)` advances it.
cycol = cycle(['violet', 'indigo', 'blue', 'green', 'yellow', 'orange', 'red'])


def plot_normalized_cross_correlation_of_model(model, checkpoint, models):
    '''Plots the Normalized Signal Cross Correlation of one checkpoint against every later checkpoint of the same model
    
    `model`: Model name
    `checkpoint`: Checkpoint of the corresponding model to be correlated from
    `models`: Dictionary mapping model names to checkpoint lists; used to size the progress bar
    
    Reads the evals from the global `memorization_results`, orders the legend entries by
    descending correlation peak and saves the figure to
    ./plots/{model}-{checkpoint}_normalized_correlation.png
    '''
    # One colour per compared checkpoint, in rainbow order
    palette = cycle(['violet', 'indigo', 'blue', 'green', 'yellow', 'orange', 'red'])
    fig = matplotlib.figure.Figure()
    ax = fig.subplots()
    print(" ", end="", flush=True)
    # Checkpoints from the current one onwards; the current one itself is not compared
    from_current = models[model][models[model].index(int(checkpoint)):]
    bar = tqdm(
        desc=f'{model}-{checkpoint} Cross Correlation',
        total=len(from_current) - 1
    )
    source_name = f'{model}-{checkpoint}'
    one_evals = memorization_results[source_name]
    corr_peaks = {}
    
    # Compare against each strictly later checkpoint of the same model
    for model_two in memorization_results:
        m, c = model_two.split('-')
        if m != model or int(c) <= int(checkpoint):
            continue
        name, peak = plot_normalized_cross_correlation(
            source_name, one_evals,
            model_two, memorization_results[model_two],
            ax, next(palette)
        )
        corr_peaks[name] = peak
        bar.update(1)
    
    # Legend entries sorted by highest correlation peak first
    handles, labels = fig.gca().get_legend_handles_labels()
    by_peak = sorted(zip(labels, handles), key=lambda t: -corr_peaks[t[0]])
    labels, handles = zip(*by_peak)
    fig.suptitle("Memorization Accuracy Cross Correlation", fontsize=20)
    fig.supxlabel("Correlation Lags", fontsize=20)
    fig.supylabel("Discrete Linear Normalized Correlation", fontsize=20)
    ax.set_xlim(-1e5, 1e5)
    ax.legend(handles, labels, loc='lower right')
    
    # Write the figure to disk and finish the progress bar
    fig.savefig(f'./plots/{model}-{checkpoint}_normalized_correlation.png', facecolor='white')
    bar.close()


def normalized_corr(x, y):
    r'''Performs Normalized Signal Correlation, Defined as:
    $$ normalized (f * g)(\tau) = \frac{1}{N}\frac{((f-\mu_f) * (g - \mu_g))(\tau)}{\sigma_f\sigma_g} $$
    where 
        a*b(\tau) is full discrete linear cross-correlation of a and b
    
    `x`, `y`: 1-D numeric sequences (arrays, lists or pandas Series)
    
    The inputs are centred on float copies, so the caller's data is never modified and
    integer input is accepted. Constant input has zero variance and yields NaN/inf values.
    
    returns Correlation and its correlation lags
    '''
    # Work on float copies: an in-place `x -= x.mean()` would overwrite the caller's
    # column and raises a casting error when the data is integer-typed.
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    x = x - x.mean()
    y = y - y.mean()
    corr = scipy.signal.correlate(x, y)
    # sqrt(sum(x^2) * sum(y^2)) equals N * sigma_f * sigma_g for equal-length inputs
    corr /= np.sqrt(np.sum(x**2) * np.sum(y**2))
    corr_lags = scipy.signal.correlation_lags(len(x), len(y))
    return corr, corr_lags

def plot_normalized_cross_correlation(name_one, evals_one, name_two, evals_two, axis, color):
    '''Plots Normalized Cross Correlation
    
    `name_one`, `name_two`: Model names with checkpoints, formatted '<model>-<checkpoint>'
    `evals_one`, `evals_two`: Evaluation Results; the 'accuracy' column of each is correlated,
        with the second one truncated to the length of the first
    `axis`: Axes-like object the curve and its peak marker are drawn on
    `color`: Colour used for both the curve and the peak marker
    
    Returns:
        tuple(
            plot_label: legend label of the curve drawn
            peak: maximum of the normalized correlation
        )
    '''
    # Extract Data
    model_one, checkpoint_one = name_one.split('-')
    _, checkpoint_two = name_two.split('-')
    x = evals_one['accuracy']
    y = evals_two['accuracy'][:len(x)]
    
    # Perform correlation
    corr, corr_lags = normalized_corr(x, y)
    peak = np.max(corr)
    
    # Plotting. Use the caller-supplied colour: the parameter used to be ignored in
    # favour of a module-level cycle.
    plot_label = f'{model_one}: {checkpoint_one} and {checkpoint_two}'
    axis.plot(corr_lags, corr, label = plot_label,color=color)
    
    # Indicate peaks
    axis.plot(corr_lags[np.argmax(corr)], peak, '_', 
             markersize=32.0, markeredgewidth=2.0, color=color)
    return plot_label, peak

In [21]:
manager = MPManager()
tot_processes = 0
for model_one, checkpoints in models.items():
    # The largest checkpoint has no later checkpoint to correlate with, so it gets no
    # plot. Deriving it from the list keeps this valid if the schedules in `models` change.
    final_checkpoint = max(checkpoints, default=None)
    for checkpoint_one in checkpoints:
        if(checkpoint_one == final_checkpoint): continue
        manager.add(Process(target=plot_normalized_cross_correlation_of_model, args=(model_one, 
            str(checkpoint_one), models)))
        tot_processes+=1
        # Throttle: once cpu_count() workers have been started, wait for them all
        # before starting the next batch.
        if(tot_processes >= mp.cpu_count()):
            tot_processes = 0
            manager.join()
    
manager.join()

      

13B-23000 Cross Correlation:   0%|          | 0/6 [00:00<?, ?it/s]

  

13B-43000 Cross Correlation:   0%|          | 0/5 [00:00<?, ?it/s]

 

13B-63000 Cross Correlation:   0%|          | 0/4 [00:00<?, ?it/s]

 

13B-83000 Cross Correlation:   0%|          | 0/3 [00:00<?, ?it/s]

 

13B-103000 Cross Correlation:   0%|          | 0/2 [00:00<?, ?it/s]

 

13B-123000 Cross Correlation:   0%|          | 0/1 [00:00<?, ?it/s]

 

13B_deduped-23000 Cross Correlation:   0%|          | 0/6 [00:00<?, ?it/s]

 

13B_deduped-43000 Cross Correlation:   0%|          | 0/5 [00:00<?, ?it/s]

 

13B_deduped-63000 Cross Correlation:   0%|          | 0/4 [00:00<?, ?it/s]

 

13B_deduped-83000 Cross Correlation:   0%|          | 0/3 [00:00<?, ?it/s]

13B_deduped-123000 Cross Correlation:   0%|          | 0/1 [00:00<?, ?it/s]

 

13B_deduped-103000 Cross Correlation:   0%|          | 0/2 [00:00<?, ?it/s]

 

6.7B-23000 Cross Correlation:   0%|          | 0/6 [00:00<?, ?it/s]

 

6.7B-63000 Cross Correlation:   0%|          | 0/4 [00:00<?, ?it/s]

6.7B-43000 Cross Correlation:   0%|          | 0/5 [00:00<?, ?it/s]

 

6.7B-83000 Cross Correlation:   0%|          | 0/3 [00:00<?, ?it/s]

 

6.7B-103000 Cross Correlation:   0%|          | 0/2 [00:00<?, ?it/s]

 

6.7B-123000 Cross Correlation:   0%|          | 0/1 [00:00<?, ?it/s]

 

6.7B_deduped-23000 Cross Correlation:   0%|          | 0/6 [00:00<?, ?it/s]

 

6.7B_deduped-43000 Cross Correlation:   0%|          | 0/5 [00:00<?, ?it/s]

 

6.7B_deduped-63000 Cross Correlation:   0%|          | 0/4 [00:00<?, ?it/s]

 

6.7B_deduped-83000 Cross Correlation:   0%|          | 0/3 [00:00<?, ?it/s]

 

6.7B_deduped-103000 Cross Correlation:   0%|          | 0/2 [00:00<?, ?it/s]

 

6.7B_deduped-123000 Cross Correlation:   0%|          | 0/1 [00:00<?, ?it/s]

  

2.7B-23000 Cross Correlation:   0%|          | 0/6 [00:00<?, ?it/s]

2.7B-43000 Cross Correlation:   0%|          | 0/5 [00:00<?, ?it/s]

 

2.7B-63000 Cross Correlation:   0%|          | 0/4 [00:00<?, ?it/s]

 

2.7B-83000 Cross Correlation:   0%|          | 0/3 [00:00<?, ?it/s]

2.7B-103000 Cross Correlation:   0%|          | 0/2 [00:00<?, ?it/s]

  

2.7B-123000 Cross Correlation:   0%|          | 0/1 [00:00<?, ?it/s]

1.3B-11500 Cross Correlation:   0%|          | 0/6 [00:00<?, ?it/s]

 

1.3B-21500 Cross Correlation:   0%|          | 0/5 [00:00<?, ?it/s]

   

1.3B-31500 Cross Correlation:   0%|          | 0/4 [00:00<?, ?it/s]

 

1.3B-41500 Cross Correlation:   0%|          | 0/3 [00:01<?, ?it/s]

 

1.3B-51500 Cross Correlation:   0%|          | 0/2 [00:01<?, ?it/s]

 

1.3B-61500 Cross Correlation:   0%|          | 0/1 [00:01<?, ?it/s]

1.3B_deduped-11500 Cross Correlation:   0%|          | 0/6 [00:00<?, ?it/s]

   

1.3B_deduped-21500 Cross Correlation:   0%|          | 0/5 [00:01<?, ?it/s]

1.3B_deduped-31500 Cross Correlation:   0%|          | 0/4 [00:01<?, ?it/s]

 

1.3B_deduped-41500 Cross Correlation:   0%|          | 0/3 [00:01<?, ?it/s]

1.3B_deduped-51500 Cross Correlation:   0%|          | 0/2 [00:01<?, ?it/s]

  

800M-23000 Cross Correlation:   0%|          | 0/6 [00:01<?, ?it/s]

800M-43000 Cross Correlation:   0%|          | 0/5 [00:00<?, ?it/s]

1.3B_deduped-61500 Cross Correlation:   0%|          | 0/1 [00:01<?, ?it/s]

 

800M-63000 Cross Correlation:   0%|          | 0/4 [00:00<?, ?it/s]

800M-103000 Cross Correlation:   0%|          | 0/2 [00:00<?, ?it/s]

800M-83000 Cross Correlation:   0%|          | 0/3 [00:00<?, ?it/s]

800M-123000 Cross Correlation:   0%|          | 0/1 [00:00<?, ?it/s]

       

800M_deduped-23000 Cross Correlation:   0%|          | 0/6 [00:00<?, ?it/s]

 

800M_deduped-43000 Cross Correlation:   0%|          | 0/5 [00:00<?, ?it/s]

 

800M_deduped-63000 Cross Correlation:   0%|          | 0/4 [00:00<?, ?it/s]

 

800M_deduped-83000 Cross Correlation:   0%|          | 0/3 [00:00<?, ?it/s]

 

800M_deduped-103000 Cross Correlation:   0%|          | 0/2 [00:00<?, ?it/s]

 

800M_deduped-123000 Cross Correlation:   0%|          | 0/1 [00:00<?, ?it/s]

 

350M-11500 Cross Correlation:   0%|          | 0/6 [00:00<?, ?it/s]

350M-21500 Cross Correlation:   0%|          | 0/5 [00:00<?, ?it/s]

 

350M-31500 Cross Correlation:   0%|          | 0/4 [00:00<?, ?it/s]

 

350M-41500 Cross Correlation:   0%|          | 0/3 [00:00<?, ?it/s]

350M-51500 Cross Correlation:   0%|          | 0/2 [00:00<?, ?it/s]

 

350M-61500 Cross Correlation:   0%|          | 0/1 [00:00<?, ?it/s]

 

350M_deduped-11500 Cross Correlation:   0%|          | 0/6 [00:00<?, ?it/s]

 

350M_deduped-21500 Cross Correlation:   0%|          | 0/5 [00:00<?, ?it/s]

 

350M_deduped-31500 Cross Correlation:   0%|          | 0/4 [00:00<?, ?it/s]

 

350M_deduped-41500 Cross Correlation:   0%|          | 0/3 [00:00<?, ?it/s]

 

350M_deduped-51500 Cross Correlation:   0%|          | 0/2 [00:00<?, ?it/s]

 

350M_deduped-61500 Cross Correlation:   0%|          | 0/1 [00:00<?, ?it/s]

 

125M-11500 Cross Correlation:   0%|          | 0/6 [00:00<?, ?it/s]

 

125M-21500 Cross Correlation:   0%|          | 0/5 [00:00<?, ?it/s]

 

125M-31500 Cross Correlation:   0%|          | 0/4 [00:00<?, ?it/s]

 

125M-41500 Cross Correlation:   0%|          | 0/3 [00:00<?, ?it/s]

 

125M-51500 Cross Correlation:   0%|          | 0/2 [00:00<?, ?it/s]

 

125M-61500 Cross Correlation:   0%|          | 0/1 [00:00<?, ?it/s]

 

125M_deduped-11500 Cross Correlation:   0%|          | 0/6 [00:00<?, ?it/s]

125M_deduped-21500 Cross Correlation:   0%|          | 0/5 [00:00<?, ?it/s]

125M_deduped-31500 Cross Correlation:   0%|          | 0/4 [00:00<?, ?it/s]

 

125M_deduped-41500 Cross Correlation:   0%|          | 0/3 [00:00<?, ?it/s]

125M_deduped-51500 Cross Correlation:   0%|          | 0/2 [00:00<?, ?it/s]

125M_deduped-61500 Cross Correlation:   0%|          | 0/1 [00:00<?, ?it/s]

# Sampled KDE Plot

In [51]:
import seaborn as sns
from scipy.stats import gaussian_kde

def sampled_kde_of_checkpoint(model, checkpoint, models):
    '''Draws a density-coloured scatter of Accuracy vs Sequence Index with KDE curves on the marginal axes
    
    `model`: Model name
    `checkpoint`: Checkpoint of the corresponding model
    `models`: Dictionary containing models and their checkpoints; not read by this function
    
    Works on a random sample of 10,000 rows of the global `memorization_results` entry
    and saves the figure to ./plots/{model}-{checkpoint}_kde_plot.png
    '''
    # Theming
    sns.set_theme(font_scale=1.5)
    model_name = f'{model}-{checkpoint}'
    
    # Draw the sample and add display-named copies of the two plotted columns
    drawn = memorization_results[model_name].sample(n=int(1e4))
    sample_evals = drawn.assign(**{
        'Accuracy': drawn['accuracy'],
        'Sequence Index': drawn['index'],
    })
    
    # Joint grid, with each scatter point coloured by its estimated 2-D density
    grid = sns.JointGrid(
        data=sample_evals, x='Sequence Index', y='Accuracy',
        height = 9, ratio=3, space=0.1, marginal_ticks=True
    )
    points = np.vstack([sample_evals['Accuracy'],sample_evals['Sequence Index']])
    density = gaussian_kde(points)
    z = density(points)
    grid.plot_joint(sns.scatterplot, c=z, cmap='viridis', alpha=0.5)
    grid.plot_marginals(sns.kdeplot, fill=True, alpha=1)
    fig = grid.figure
    
    # Title, layout (leaving room for the title) and save
    fig.suptitle("Memorization Accuracy Scatter & KDE plot", fontsize=20)
    fig.tight_layout()
    fig.subplots_adjust(top=0.90)
    fig.savefig(f'./plots/{model_name}_kde_plot.png', facecolor='#F8F0DF')

In [16]:
import multiprocessing as mp
from multiprocessing import Process, Lock
manager = MPManager()
lock = Lock()
# One process per (model, checkpoint) pair, each started as soon as it is added.
for model_one, checkpoints in models.items():
    for checkpoint_one in checkpoints:
        kde_worker = Process(
            target=sampled_kde_of_checkpoint,
            args=(model_one, checkpoint_one, models),
        )
        manager.add(kde_worker)
# Wait for every KDE plot to be written.
manager.join()

NameError: name 'sampled_kde_of_checkpoint' is not defined