In [1]:
import json
import os
from typing import Literal

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

In [2]:
# Select seaborn's 'white' style preset before any figure in this notebook is created.
sns.set_style(style='white')

In [11]:
# Directory holding the per-run evaluation result files (read by the plotting cell).
OUTPUTS_DIR = "../extra/outputs/composers_em_eval"
# Directory the rendered heatmaps are written to.
VIZ_DIR = "../extra/viz"

# Run identifier (file-name prefix of the result files) -> human-readable row label.
# Insertion order matters: it defines the row order of the plotted heatmaps.
RUNS = {
    "AbsoluteCalls_FullFT_DeepSeekCoder1p3Base_HP001": "Absolute Calls with Path Distance sorting",
    "FileLevel_FullFT_DeepSeekCoder1p3Base_HP001": "File-Level Completion",
    "FuncCallsWithStrip_FullFT_DeepSeekCoder1p3Base_HP001": "Function Calls with Lines Strip",
    "LowTokensRatioFiltering_FullFT_DeepSeekCoder1p3Base_HP001": "Low Token Ratio Filtering",
    "MediumLines_FullFT_DeepSeekCoder1p3Base_HP001": "Low Token Ratio Filtering with Medium Line Length",
    "NearestDeclarations_FullFT_DeepSeekCoder1p3Base_HP001": "Declarations Only sorted by Path Distance",
    "NoLongFiles_FullFT_DeepSeekCoder1p3Base_HP001": "Long Files Filtering sorted by Function Calls",
    "NoLongFilesWithHalfMemory_FullFT_DeepSeekCoder1p3Base_HP001": "Long Files Filtering with Half Memory",
    "PartialMemory_FullFT_DeepSeekCoder1p3Base_HP001": "Path Distance Ordering with Half Memory",
    "PathDistance_FullFT_DeepSeekCoder1p3Base_HP002": "Path Distance Ordering",
    "PathDistance_FullFT_DeepSeekCoder1p3Base_HP003": "Path Distance without Completion Imports",
    "PureFuncCalls_FullFT_DeepSeekCoder1p3Base_HP001": "Function Calls Ordering",
    "PythonFiles_FullFT_DeepSeekCoder1p3Base_HP001": "Python Files Full Input Training",
    "RandomDeclarations_FullFT_DeepSeekCoder1p3Base_HP001": "Declarations Only with Random Ordering",
    "RelativeCalls_FullFT_DeepSeekCoder1p3Base_HP001": "Function Calls Ratio Ordering",
    "Strip_FullFT_DeepSeekCoder1p3Base_HP001": "Stripped Filtered Lines sorted by Path Distance",
    "TextFiles_FullFT_DeepSeekCoder1p3Base_HP002": "Text Files Groups",
}

In [4]:
def _minmax_normalize_row(row: pd.Series) -> pd.Series:
    """Rescale one row to the [0, 1] range for colouring.

    A row whose non-missing values are all equal has a zero range; plain
    min-max scaling would divide by zero and produce NaN, and seaborn masks
    NaN cells (the cell and its annotation are not drawn). Such rows are
    mapped to a neutral 0.5 instead so they stay visible. Missing values
    remain missing.
    """
    low = row.min()
    span = row.max() - low
    if span == 0:
        # Constant row: use the middle of the colour scale, keep NaN where the input had NaN.
        return pd.Series(0.5, index=row.index).where(row.notna())
    return (row - low) / span


def plot_heatmap(infile_eval_results: pd.DataFrame,
                 inproject_eval_results: pd.DataFrame,
                 scope: Literal['medium_context', 'large_context', 'huge_context'],
                 ) -> plt.Figure:
    """Draw two side-by-side annotated heatmaps of evaluation scores.

    The cell colour encodes the score normalised *within each row* (so colours
    compare the columns of one row, not rows against each other), while the
    cell annotation shows the raw score.

    Args:
        infile_eval_results: Scores for the left ("infile") panel; its index
            is rendered as the y tick labels.
        inproject_eval_results: Scores for the right ("inproject") panel; its
            y tick labels are hidden.
        scope: Dataset scope name; only its first underscore-separated token
            is used, title-cased, in the figure title.

    Returns:
        The created matplotlib figure.

    Note:
        Exactly three x tick labels are set on each panel, so both frames are
        expected to have three columns.
    """
    fig, (infile_ax, inproject_ax) = plt.subplots(
        nrows=1,
        ncols=2,
        figsize=(12, 6),
        gridspec_kw=dict(width_ratios=(3, 3)),
    )

    # Colours come from the row-wise normalised values; annotations keep the raw scores.
    infile_colors = infile_eval_results.apply(_minmax_normalize_row, axis=1)
    inproject_colors = inproject_eval_results.apply(_minmax_normalize_row, axis=1)

    heatmap_kwargs = dict(
        cmap=sns.cm.rocket_r,
        fmt='0.04f',
        cbar=False,
        # Fixed colour limits matching the normalised [0, 1] range.
        vmin=0,
        vmax=1,
    )
    sns.heatmap(data=infile_colors,
                ax=infile_ax,
                annot=infile_eval_results,
                yticklabels=True,
                **heatmap_kwargs)
    sns.heatmap(data=inproject_colors,
                ax=inproject_ax,
                annot=inproject_eval_results,
                yticklabels=False,
                **heatmap_kwargs)

    config_names = [
        'initial model\nnative composer',
        'fine-tuned model\nnative composer',
        'fine-tuned model\nbaseline composer',
    ]
    infile_ax.set_xticklabels(config_names)
    inproject_ax.set_xticklabels(config_names)

    infile_ax.xaxis.tick_bottom()
    infile_ax.yaxis.tick_left()
    inproject_ax.xaxis.tick_bottom()

    infile_ax.set_title('infile')
    inproject_ax.set_title('inproject')

    fig.suptitle(f'{scope.split("_")[0].title()} Dataset EM', x=0.65)
    fig.tight_layout()

    return fig

In [15]:
# Key of the metric read from each result entry.
em_mode = 'exact_match'

# Row labels for the heatmaps, in RUNS insertion order (materialised as a list).
index = list(RUNS.values())

# savefig does not create missing directories, so make sure the target exists.
os.makedirs(VIZ_DIR, exist_ok=True)

for scope in ('medium_context', 'large_context'):
    # eval_results[category][config_idx] -> list of mean scores, one appended per matching entry.
    eval_results = {
        'infile': {'1': [], '2': [], '3': []},
        'inproject': {'1': [], '2': [], '3': []},
    }

    for run_name in RUNS:
        for config_idx in map(str, range(1, 3 + 1)):
            result_path = os.path.join(OUTPUTS_DIR, f'{run_name}_{config_idx}.jsonl')

            # Despite the .jsonl extension, each file is parsed as one JSON document.
            # NOTE(review): the layout looks like a column-oriented table dump
            # ('scope', 'category' and 'scores' share the same keys) - confirm.
            with open(result_path, encoding='utf-8') as file:
                eval_result = json.load(file)

            for k, v in eval_result['scope'].items():
                if v != scope:
                    continue

                category = eval_result['category'][k]
                eval_results[category][config_idx].append(eval_result['scores'][k][em_mode]['mean'])

    # Every (category, configuration) list must hold exactly one score per run;
    # otherwise the rows would not line up with the run labels.
    for category, scores_by_config in eval_results.items():
        for config_idx, scores in scores_by_config.items():
            if len(scores) != len(index):
                raise ValueError(
                    f'Expected {len(index)} scores for scope={scope!r}, category={category!r}, '
                    f'config={config_idx!r}, but collected {len(scores)}'
                )

    infile_eval_results = pd.DataFrame(eval_results['infile'], index=index)
    inproject_eval_results = pd.DataFrame(eval_results['inproject'], index=index)

    fig = plot_heatmap(infile_eval_results, inproject_eval_results, scope)
    path = os.path.join(VIZ_DIR, f'composers_em_eval_{scope.split("_")[0]}.png')
    fig.savefig(path, bbox_inches='tight')
    # Release the figure so repeated iterations do not accumulate open figures.
    plt.close(fig)