In [1]:
%matplotlib inline
import matplotlib
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import json
import os

In [2]:
def _build_grammatical_tasks():
    """Assemble the column-difference specification for every model/phenomenon pair.

    The result maps '<model>_<phenomenon>' -> task name -> list of
    (new_col, correct_col, wrong_col) triples. new_col is generally the shared
    prefix of the two cols.
    """
    one_noun = [('s', 'ss', 'sp'), ('p', 'pp', 'ps')]
    two_noun = [('ss', 'sss', 'ssp'), ('sp', 'sps', 'spp'), ('ps', 'psp', 'pss'), ('pp', 'ppp', 'pps')]
    # only taking case where first verb correctly agrees
    first_verb_agrees = [('ss', 'sss', 'ssp'), ('pp', 'ppp', 'pps')]

    layouts = {
        'sv_agreement': {
            'simple': one_noun,
            'subjrelclause': two_noun,
            # this one is a bit weird since the one that agrees is the second noun,
            # order is second noun, first noun, verb
            'sentcomp': two_noun,
            'shortvpcoord': first_verb_agrees,
            'pp': two_noun,
            'objrelclausethat': two_noun,
            'objrelclausenothat': two_noun,
        },
        'anaphora': {
            'simple': one_noun,
            # this one is a bit weird since the one that agrees is the second noun and
            # psp means second noun is p, first noun is s, anaphor is p
            'sentcomp': two_noun,
            'objrelclausethat': two_noun,
            'objrelclausenothat': two_noun,
        },
    }

    built = {}
    for model in ('gpt', 'bert', 'txl'):
        for phenomenon, layout in layouts.items():
            # list(...) gives every task its own list object, as the literal spelling did
            built['%s_%s' % (model, phenomenon)] = {
                task_name: list(triples) for task_name, triples in layout.items()
            }
    return built


grammatical_tasks = _build_grammatical_tasks()

In [3]:
# Load the frequency counts.
# `freqs` is used for every non-gpt sentence type, `webtext_freqs` for the gpt ones;
# both are indexed by sentence when the difference tables are built.
freqs = {}
with open("../sv_agreement/simple/word_freqs.json.txt", "r") as wiki_freq_file:
    freqs = json.loads(wiki_freq_file.read())

webtext_freqs = {}
with open("../sv_agreement/simple/webtext_word_freqs.json.txt", "r") as webtext_freq_file:
    webtext_freqs = json.loads(webtext_freq_file.read())

In [4]:
def convert_to_col(col, fine_tuned=False):
    """Turn a string of number codes into an underscore-joined column name.

    Each 's' becomes 'SG'; every other character becomes 'PL'. When
    `fine_tuned` is true the first label is dropped before joining
    (an empty `col` then raises IndexError).
    """
    labels = []
    for number_code in col:
        labels.append('SG' if number_code == 's' else 'PL')
    if fine_tuned:
        del labels[0]
    return '_'.join(labels)

In [8]:
def take_column_differences(df, sent_type, task, fine_tuned_number=None):
    """Build a frame of (correct - wrong) column differences for one task.

    Parameters
    ----------
    df : DataFrame holding the columns named by `convert_to_col` for every
        triple in ``grammatical_tasks[sent_type][task]``, plus either
        'sent' (base results) or 'category' and 'model_name' (fine-tuned results).
    sent_type, task : keys into `grammatical_tasks`.
    fine_tuned_number : None for base results, otherwise 'sg' or 'pl'
        (case-insensitive). Only triples whose new_col starts with the same
        letter are kept, and column names are built with ``fine_tuned=True``.

    Returns
    -------
    DataFrame with one difference column per retained triple, followed by the
    pass-through columns, restricted to rows in which no value equals 0.
    The original row labels are kept (the index is not reset).

    Raises
    ------
    ValueError if `fine_tuned_number` is given but is neither 'sg' nor 'pl'.
    """
    fine_tuned = fine_tuned_number is not None
    if fine_tuned:
        fine_tuned_number = fine_tuned_number.lower()
        if fine_tuned_number not in ['sg', 'pl']:
            raise ValueError('fine_tuned_number in take_column_differences is wrong: ' + str(fine_tuned_number))

    new_df = pd.DataFrame()
    for new_col, correct_col, wrong_col in grammatical_tasks[sent_type][task]:
        # a fine-tuned frame only contributes the structures matching its number
        if fine_tuned and new_col[0] != fine_tuned_number[0]:
            continue
        correct_scores = df[convert_to_col(correct_col, fine_tuned=fine_tuned)]
        wrong_scores = df[convert_to_col(wrong_col, fine_tuned=fine_tuned)]
        new_df[new_col] = correct_scores - wrong_scores

    if fine_tuned:
        for passthrough in ('category', 'model_name'):
            new_df[passthrough] = df[passthrough]
    else:
        new_df['sent'] = df['sent']
        # gpt sentence types read from the webtext table, everything else from `freqs`
        freq_table = webtext_freqs if sent_type.startswith('gpt') else freqs
        # positions 2, 3 and 4 of a frequency entry feed s_freq, p_freq and freq
        for freq_col, slot in (('s_freq', 2), ('p_freq', 3), ('freq', 4)):
            new_df[freq_col] = [freq_table[sentence][slot] for sentence in df['sent']]

    # keep only the rows in which every column is non-zero
    fully_nonzero = (new_df != 0).all(axis=1)
    return new_df[fully_nonzero]

In [9]:
# Scatter colour for each difference column, keyed by the column name.
colors = dict(s='r', p='b', ss='r', sp='b', ps='c', pp='m')

# Columns of a differences frame that are never drawn as a data series.
not_plot_cols = set(['freq', 'sent', 's_freq', 'p_freq'])

def _ensure_parent_dir(filename, announce=False):
    """Create the directory that will contain `filename`, including missing ancestors.

    Prints 'making <dir>' when `announce` is true and the directory had to be created.
    """
    parent = os.path.dirname(filename)
    if parent and not os.path.isdir(parent):
        os.makedirs(parent, exist_ok=True)
        if announce:
            print('making %s' % parent)


def plot_task(df, sent_type, task, sg_df=None, pl_df=None):
    """Scatter the difference columns of `df` against noun frequency and save the figures.

    One figure is produced per plottable column of `df` (every column not in
    `not_plot_cols`). The x values come from the '<first letter of column>_freq'
    column, on a log-scaled axis.

    Without `sg_df`, an additional figure overlays all plottable columns, and
    files go to '../<sent_type>/figures/<task>/freqs_<task>-<cols>.png'.

    With `sg_df`, each single-column figure also gets one horizontal line per row
    of the matching one-shot frame (`sg_df` for columns starting with 's',
    `pl_df` otherwise) plus a 'base_mean' line at the mean of the column in `df`;
    files go to '../one_shot/figures/<sent_type>/<task>/finetune_<task>-<col>.png'.
    `pl_df` must therefore be supplied together with `sg_df` whenever a plotted
    column does not start with 's'.

    Missing output directories are created. Figures are left open for the caller.
    """
    plot_one_shots = sg_df is not None
    plottable = [col_name for col_name in df.columns if col_name not in not_plot_cols]
    to_graphs = [[col_name] for col_name in plottable]
    if not plot_one_shots:
        to_graphs.append(list(plottable))

    for to_graph in to_graphs:
        plt.figure(figsize=(12.8, 9.6))
        ax = plt.gca()
        ax.set_xscale('log')
        plt.ylabel("Mean of diffs")
        plt.xlabel("Frequency of noun in wikitext-103 training set")
        plt.xlim(1, 10e5)
        plt.title("%s: Freq. of nouns vs mean of dist. of diffs (correct-wrong) (%s)" % (sent_type, task))

        for struct in to_graph:
            # the first letter of the column name selects s_freq or p_freq for the x axis
            ax.scatter(df[f'{struct[0]}_freq'], df[struct], label=struct, s=2, color=colors[struct])

        if plot_one_shots and len(to_graph) == 1:
            column = to_graph[0]
            sub_df = sg_df if column[0] == 's' else pl_df  # else: 'p'
            colormap = plt.cm.gist_ncar
            # one colour per one-shot model, plus a final one for the base mean
            colorcycle = [colormap(i) for i in np.linspace(0, 0.9, len(sub_df) + 1)]
            # Index by position, not by the frame's row labels: the one-shot frames may have
            # had rows filtered out, so their labels need not be 0..len-1.
            for position, (_, row) in enumerate(sub_df.iterrows()):
                ax.axhline(row[column], label=row['model_name'], c=colorcycle[position], lw=1)
            ax.axhline(df[column].mean(), label='base_mean', c=colorcycle[len(sub_df)], lw=1)

            filename = '../one_shot/figures/%s/%s/finetune_%s-%s.png' % (sent_type, task, task, '_'.join(to_graph))
            announce = False
        else:
            filename = '../%s/figures/%s/freqs_%s-%s.png' % (sent_type, task, task, '_'.join(to_graph))
            announce = True

        plt.legend()

        # make directories if needed, relative to the same location the figure is saved to
        _ensure_parent_dir(filename, announce=announce)
        plt.savefig(filename)


In [10]:
def run_plotting_pipeline(retake_differences = False, save_differences = False, specific_task=None, one_shot=False):
    """
    Compute (or reload) the difference table of each task and plot it.

    retake_differences: set to true to always compute differences from the
        consolidated csv (do this if changing which differences to take);
        otherwise an existing '<task>.differences.csv' is reused.
    save_differences: when differences are computed, also write them to
        '../<sent_type>/differences_data/<task>.differences.csv'
        (the directory is created if it does not exist).
    specific_task: optional (sent_type, task) pair; when given only that task is
        run, otherwise every task of every entry in `grammatical_tasks` is run.
    one_shot: when true, the sg/pl one-shot result files are loaded and passed
        to `plot_task` so that they are drawn on top of the base scatter.
    """
    def run_task(sent_type, task, one_shot=False):
        differences_filename = '../%s/differences_data/%s.differences.csv' % (sent_type, task)
        if not retake_differences and os.path.exists(differences_filename):
            df = pd.read_csv(differences_filename)
        else:
            df = pd.read_csv('../%s/consolidated_data/%s.consolidated.csv' % (sent_type, task))
            df = take_column_differences(df, sent_type, task)
            if save_differences:
                # to_csv does not create missing directories
                os.makedirs(os.path.dirname(differences_filename), exist_ok=True)
                df.to_csv(differences_filename, index=False)

        if not one_shot:
            plot_task(df, sent_type, task)
            return

        # sent_type is '<model>_<phenomenon>'; the one-shot results are stored per model
        split_ind = sent_type.index('_')
        model = sent_type[:split_ind]
        sent_type_short = sent_type[split_ind+1:]
        sg_df = pd.read_csv('../compute/one_shot_consolidated_results/%s/sg/%s/%s.csv' % (model, sent_type_short, task))
        pl_df = pd.read_csv('../compute/one_shot_consolidated_results/%s/pl/%s/%s.csv' % (model, sent_type_short, task))
        sg_df = take_column_differences(sg_df, sent_type, task, fine_tuned_number='sg')
        pl_df = take_column_differences(pl_df, sent_type, task, fine_tuned_number='pl')
        # the two frames are kept separate: they can't be merged because they were
        # trained with different things
        plot_task(df, sent_type, task, sg_df=sg_df, pl_df=pl_df)

    if specific_task is not None:
        run_task(*specific_task, one_shot=one_shot)
    else:
        for sent_type in grammatical_tasks:
            for task in grammatical_tasks[sent_type]:
                run_task(sent_type, task, one_shot=one_shot)

In [11]:
# Recompute every differences table from the consolidated data, save it, and write all figures.
run_plotting_pipeline(save_differences=True, retake_differences=True)
# The pipeline leaves its figures open; release them all once everything is on disk.
plt.close('all')