In [None]:
import os, sys
import argparse
import numpy as np
from tqdm import tqdm
import time, subprocess
from collections import defaultdict as ddict
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import seaborn as sns
import pandas as pd
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import matplotlib
import pdb
import random
import multiprocessing

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from models.models import build_or_load_gen_model
from evaluator import smooth_bleu
from evaluator.CodeBLEU import calc_code_bleu
from evaluator.bleu import _bleu
from utils.utils import *
from utils.metrics import *
# from utils.configs import *
from utils.replay import Buffer
from models.T5prompt import PromptTuneT5
from dataloaders.generation_loader import CodeXGlueDataModule

%matplotlib widget

In [None]:
### UTILITY FUNCTIONS FOR ARGUMENTS

def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``no``/``1`` into a bool.

    ``type=bool`` is a trap with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


def add_args(parser, argv=()):
    """Register every experiment option on ``parser``, parse, and post-process.

    Parameters
    ----------
    parser : argparse.ArgumentParser
        Parser that receives the option definitions.
    argv : sequence of str, optional
        Command-line tokens to parse.  Defaults to an empty sequence so that
        every option takes its default value (the notebook then overrides
        attributes by hand); ``sys.argv`` is deliberately not consulted.

    Returns
    -------
    argparse.Namespace
        Parsed options, augmented with ``lang`` (derived from ``task``),
        ``cpu_cont``, absolute ``data_dir``/``cache_path``/``output_dir``,
        and ``run_output_dir``/``log_dir``/``checkpoint_dir``.

    Side effects
    ------------
    Creates the cache, output, run, log and checkpoint directories on disk.
    Unless ``--compute_avg_sim`` is set, a fresh ``<name>~try=<i>`` run
    directory is allocated and ``args.name`` is updated to match.
    """
    parser.add_argument("--task", type=str, default=None,)
    parser.add_argument("--sub_task", type=str, default='')
    parser.add_argument("--stream", type=str, default=None)

    parser.add_argument("--lang", type=str, default='')
    parser.add_argument("--eval_task", type=str, default='')
    parser.add_argument("--model_type", default="codet5", type=str, choices=['roberta', 'bart', 'codet5'])
    parser.add_argument("--add_lang_ids", action='store_true')
    parser.add_argument("--data_num", default=-1, type=int)
    parser.add_argument("--bleu_samples", default=5000, type=int)

    # task specific params (one value per task, or a single value for all tasks)
    parser.add_argument('--patience', nargs='+', default=None, type=int, help='Patience for early stopping.')
    parser.add_argument('--num_train_epochs', nargs='+', default=None, type=int, help='Number of epochs')
    parser.add_argument('--learning_rate', nargs='+', default=None, type=float, help='Learning rate')
    parser.add_argument('--max_source_length', nargs='+', default=None, type=int, help='max src len')
    parser.add_argument('--max_target_length', nargs='+', default=None, type=int, help='max tgt len')
    parser.add_argument('--train_batch_size', nargs='+', default=None, type=int, help='Batch size per GPU/CPU for training.')
    parser.add_argument('--eval_batch_size', nargs='+', default=None, type=int, help='Batch size per GPU/CPU for evaluation.')

    ## replay params
    parser.add_argument("--replay", default='res', type=str)
    parser.add_argument("--buffer_size", default=0, type=int)
    parser.add_argument("--buffer_bs", default=8, type=int)
    parser.add_argument("--replay_epoch_end", action='store_true')
    parser.add_argument("--alpha", default=0, type=float, help="Mixing weight for buffer loss.")

    ## Prompting params
    parser.add_argument("--pool_lambda", default=0.1, type=float)
    parser.add_argument("--prompt_loss_type", default='basic', type=str)
    parser.add_argument("--num_prompts_per_task", default=0, type=int)
    parser.add_argument("--prompt_init", default='vocab', type=str)
    parser.add_argument("--prompt_key_init", default='uniform', type=str)
    parser.add_argument("--prompt_lr", default=10, type=float)
    parser.add_argument("--query_key_lr", default=10, type=float)
    parser.add_argument("--query_pooling_mode", default='mean', type=str)
    parser.add_argument("--train_only_prompts", action='store_true')
    parser.add_argument("--io_queries", action='store_true')
    parser.add_argument("--prompt_pool", action='store_true')
    parser.add_argument("--batched_prompts", action='store_true')
    parser.add_argument("--pool_freq_norm", action='store_true')
    parser.add_argument("--pool_freq", action='store_true')
    parser.add_argument("--compute_avg_sim", action='store_true')
    parser.add_argument("--pool_size", default=60, type=int)
    parser.add_argument("--num_pool_prompt_tokens", default=5, type=int)
    parser.add_argument("--uniform_scale", default=1, type=float)
    parser.add_argument("--prompt_projection", action='store_true')
    parser.add_argument("--separate_projection", action='store_true')
    parser.add_argument("--projection_hid_dim", default=512, type=int)
    parser.add_argument("--projection_out_dim", default=512, type=int)
    parser.add_argument("--dropout", default=0.1, type=float)
    parser.add_argument("--projection_plot", default='', type=str)

    ## Directories.
    parser.add_argument("--project_dir", type=str, default='/mnt/efs/people/ptky/project/incremental-learning',)
    parser.add_argument("--data_dir", type=str, default='data',)
    parser.add_argument("--output_dir", default='saved_runs', type=str,
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--cache_path", type=str, default='data/cache',)
    parser.add_argument("--summary_dir", type=str, default='saved_runs/logs',)
    parser.add_argument("--res_dir", type=str, default='',)
    parser.add_argument("--res_fn", type=str, default='')

    parser.add_argument("--add_task_prefix", action='store_true', help="Whether to add task prefix for t5 and codet5")
    parser.add_argument("--save_last_checkpoints", action='store_true')
    parser.add_argument("--always_save_model", action='store_true')
    parser.add_argument("--calc_stats", action='store_true')

    # wandb params
    parser.add_argument("--debug", action='store_true')
    parser.add_argument("--name", type=str, default='test')
    parser.add_argument("--project_name", type=str, default='debug')

    ## Huggingface params.
    parser.add_argument("--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name")
    parser.add_argument("--model_name_or_path", default="Salesforce/codet5-small", type=str, help="Path to pre-trained model: e.g. Salesforce/codet5-small")
    parser.add_argument("--tokenizer_name", default="Salesforce/codet5-small", type=str, help="Pretrained tokenizer name or path if not the same as model_name")
    parser.add_argument("--load_model_path", default=None, type=str, help="Path to trained model: Should contain the .bin files")

    ## Other parameters
    parser.add_argument("--train_filename", default=None, type=str,
                        help="The train filename. Should contain the .jsonl files for this task.")
    parser.add_argument("--dev_filename", default=None, type=str,
                        help="The dev filename. Should contain the .jsonl files for this task.")
    parser.add_argument("--test_filename", default=None, type=str,
                        help="The test filename. Should contain the .jsonl files for this task.")

    parser.add_argument("--no_train", action='store_true', help="Whether to run eval on the train set.")
    parser.add_argument("--no_eval", action='store_true', help="Whether to run eval on the dev set.")
    parser.add_argument("--no_eval_bleu", action='store_true', help="Whether to evaluate bleu on dev set.")
    parser.add_argument("--no_eval_all", action='store_true', help="Whether to run eval on all tasks dev set after each task.")
    parser.add_argument("--no_test", action='store_true', help="Whether to evaluate on test set.")
    parser.add_argument("--full_matrix_eval", action='store_true', help="evaluate on future tasks as well in each epoch.")
    parser.add_argument("--zeroshot", action='store_true', help="Evaluate zeroshot performance on the test set.")
    parser.add_argument("--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.")
    parser.add_argument("--no_cuda", action='store_true', help="Avoid using CUDA when available")

    parser.add_argument("--adafactor", action='store_true', help="Use adafactor instead of AdamW")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1, help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument("--beam_size", default=5, type=int, help="beam size for beam search")
    parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight deay if we apply some.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")

    parser.add_argument("--num_workers", default=os.cpu_count(), type=int, )
    # `_str2bool` instead of `bool`: with type=bool, "--pin_memory False" parsed as True.
    parser.add_argument("--pin_memory", default=True, type=_str2bool, )
    parser.add_argument("--save_steps", default=-1, type=int, )
    parser.add_argument("--num_saves", default=10, type=int, )
    parser.add_argument("--log_steps", default=-1, type=int, )
    parser.add_argument("--max_steps", default=-1, type=int, help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
    parser.add_argument("--eval_steps", default=-1, type=int, help="")
    parser.add_argument("--train_steps", default=-1, type=int, help="")
    parser.add_argument("--warmup_steps", default=100, type=int, help="Linear warmup over warmup_steps.")
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument('--seed', type=int, default=1234, help="random seed for initialization")

    # An explicit (by default empty) token list keeps the Jupyter kernel's own
    # sys.argv entries away from the parser.
    args = parser.parse_args(argv)

    # Derive the programming language from the task when the task implies it.
    if args.task in ['summarize']:
        args.lang = args.sub_task
    elif args.task in ['refine', 'concode', 'clone']:
        args.lang = 'java'
    elif args.task == 'defect':
        args.lang = 'c'
    elif args.task == 'translate':
        args.lang = 'c_sharp' if args.sub_task == 'java-cs' else 'java'

    args.project_name = f"aws-{args.project_name}"
    args.cpu_cont = multiprocessing.cpu_count()

    # Anchor the relative directories under the project directory.
    args.data_dir = os.path.join(args.project_dir, args.data_dir)
    args.cache_path = os.path.join(args.project_dir, args.cache_path)
    args.output_dir = os.path.join(args.project_dir, args.output_dir)
    os.makedirs(args.cache_path, exist_ok=True)
    os.makedirs(args.output_dir, exist_ok=True)

    if not args.compute_avg_sim:
        # Allocate the first unused "<name>~try=<i>" directory for this run.
        i = 0
        while True:
            args.run_output_dir = f"{args.output_dir}/{args.project_name}/{args.name}~try={i}"
            if not os.path.exists(args.run_output_dir):
                os.makedirs(args.run_output_dir)
                args.name = args.name + f"~try={i}"
                break
            i += 1
    else:
        # Reuse the directory named exactly after the run.
        args.run_output_dir = f"{args.output_dir}/{args.project_name}/{args.name}"

    args.log_dir = os.path.join(args.run_output_dir, 'logs')
    args.checkpoint_dir = os.path.join(args.run_output_dir, 'checkpoints')
    os.makedirs(args.log_dir, exist_ok=True)
    os.makedirs(args.checkpoint_dir, exist_ok=True)

    return args


def set_dist(args):
    """Select the compute device and record it (plus GPU/CPU counts) on ``args``.

    Reads ``args.local_rank`` and ``args.no_cuda``; writes ``args.device``,
    ``args.n_gpu`` and ``args.cpu_cont``.

    * ``local_rank == -1`` or ``no_cuda``: single-process mode.  CUDA is used
      only when it is available and not disabled.
    * otherwise: distributed mode, one GPU per process; the NCCL process
      group is initialised here.
    """
    if args.local_rank == -1 or args.no_cuda:
        use_cuda = torch.cuda.is_available() and not args.no_cuda
        device = torch.device("cuda" if use_cuda else "cpu")
        # Report zero GPUs when running on the CPU; previously the visible
        # CUDA device count leaked through even with --no_cuda.
        args.n_gpu = torch.cuda.device_count() if use_cuda else 0
    else:
        # Setup for distributed data parallel
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl')
        args.n_gpu = 1
    args.device = device
    args.cpu_cont = multiprocessing.cpu_count()

def get_task_arglist(args, keys, datamodule):
    """Expand each per-task option on ``args`` to one value per task.

    For every name in ``keys`` the attribute on ``args`` is replaced by a
    list with ``len(datamodule.all_tasks)`` entries:

    * ``None``        -> values looked up in ``datamodule.task_params[task]``
    * one element     -> that element repeated for every task
    * one per task    -> kept as is

    Parameters
    ----------
    args : namespace-like object holding the options (mutated in place).
    keys : iterable of str, attribute names to expand.
    datamodule : object exposing ``all_tasks`` and ``task_params``.

    Returns
    -------
    The same ``args`` object.

    Raises
    ------
    ValueError
        If a key is missing from ``args``, or if a value list has a length
        that is neither 1 nor the number of tasks (this case used to be
        silently replaced by an empty list).
    """
    # Maps the option name to the field name used inside task_params.
    key_map = {"num_train_epochs": 'epoch', 'learning_rate': 'lr', 'patience': 'patience', 'max_source_length':'src_len',
               'max_target_length':'trg_len', 'train_batch_size':'tbs', 'eval_batch_size':'ebs'}
    tasks = datamodule.all_tasks

    for key in keys:
        if not hasattr(args, key):
            raise ValueError(f"Key {key} is not in args!")
        att_value = getattr(args, key)

        if att_value is None:
            new_att_value = [datamodule.task_params[task][key_map[key]] for task in tasks]
        elif len(att_value) == 1:
            new_att_value = att_value * len(tasks)
        elif len(att_value) == len(tasks):
            new_att_value = att_value
        else:
            raise ValueError(
                f"Key {key} has {len(att_value)} values; expected 1 or {len(tasks)} (one per task)."
            )
        setattr(args, key, new_att_value)
    return args

def set_seed(args):
    """Seed Python, NumPy and PyTorch RNGs from ``args.seed``.

    The CUDA generators are seeded as well when ``args.n_gpu`` is positive.
    """
    seed_value = args.seed
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed_value)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(seed_value)


In [None]:
def load_checkpoint(model, filepath, criteria):
    """Load ``<filepath>/checkpoints/<criteria>.bin`` into ``model`` in place.

    Top-level buffers of ``model`` are first re-allocated with the shape and
    dtype stored in the checkpoint, so that buffers whose size differs from
    the freshly built model can still be loaded.  Weights are then loaded
    with ``strict=False``.

    Parameters
    ----------
    model : torch.nn.Module
    filepath : str
        Run directory containing a ``checkpoints`` sub-directory.
    criteria : str
        Checkpoint file stem (the file is ``<criteria>.bin``).

    Returns
    -------
    The same ``model`` object, or ``None`` when the checkpoint file does not
    exist (a message is printed in that case) -- callers should check for
    ``None`` before rebinding their model variable.
    """
    file = os.path.join(filepath, f'checkpoints/{criteria}.bin')

    if not os.path.exists(file):
        print('This file doesnt exist')
        return None

    # Materialise the checkpoint on the device the model already lives on, so
    # that loading works on CPU-only machines and the re-allocated buffers
    # below do not end up on a different device than the parameters.
    first_param = next(model.parameters(), None)
    map_location = first_param.device if first_param is not None else None
    loaded_state_dict = torch.load(file, map_location=map_location)

    for buf in model._buffers.keys():
        try:
            model._buffers[buf] = torch.empty_like(loaded_state_dict[buf])
        except KeyError:
            # Buffer exists on the model but was not saved; keep it as is.
            print(f"{buf} not found in stored model.")
    model.load_state_dict(loaded_state_dict, strict=False)
    print(f'LOADED MODEL WEIGHT FROM FILE: {file}')
    return model

def get_task_queries(args, dataloader, model, tokenizer, eval_task=None, phase='eval'):
    """Collect L2-normalised query vectors for every batch of ``dataloader``.

    Parameters
    ----------
    args : namespace with ``device`` and ``io_queries``.
    dataloader : sized iterable yielding ``(source_ids, target_ids)`` tensors.
    model : object providing ``eval()``, ``get_task_id(task)`` and
        ``get_inference_stats(input_ids=, labels=, task_id=, phase=)``.
    tokenizer : unused; kept so existing call sites remain valid.
    eval_task : task name passed to ``model.get_task_id``.
    phase : forwarded to ``model.get_inference_stats``.

    Returns
    -------
    torch.Tensor
        All per-batch query matrices stacked row-wise, on the CPU.
        ``torch.vstack`` raises if the dataloader yields no batches.
    """
    model.eval()
    task_sims = []
    with torch.no_grad():
        # The task id does not depend on the batch, so resolve it once.
        task_id = model.get_task_id(eval_task)
        for batch in tqdm(dataloader, total=len(dataloader), desc="Eval ppl"):
            source_ids, target_ids = tuple(t.to(args.device) for t in batch)
            # Targets are only part of the query when io_queries is enabled.
            targets = target_ids if args.io_queries else None
            queries = model.get_inference_stats(input_ids=source_ids, labels=targets, task_id=task_id, phase=phase)
            queries = torch.nn.functional.normalize(queries, p=2.0, dim=1)
            task_sims.append(queries.detach().cpu())
    return torch.vstack(task_sims)



In [None]:
#### INITIALIZE ARGPARSE AND SET PARAMETERS

t0 = time.time()
parser = argparse.ArgumentParser()
args = add_args(parser)

set_dist(args)
seed_everything(args.seed)

# Notebook-specific overrides applied on top of the parser defaults.
# Variants tried earlier and left disabled: the three-task stream
# 'concode_none,translate_java-cs,summarize_ruby', pool_freq_norm=True,
# no_keys=True, batched_prompts=True, num_prompts_per_task=20.
notebook_overrides = {
    'stream': 'concode_none,translate_java-cs,summarize_ruby,refine_small',
    'prompt_method': 'pool',
    'pool_freq': True,
    'eval_batch_size': [80],
    'query_pooling_mode': 'mean',
    'data_num': 5000,
    'debug': True,
    'pool_size': 500,
    'num_pool_prompt_tokens': 1,
    'num_prompts_per_task': 100,
}
for option_name, option_value in notebook_overrides.items():
    setattr(args, option_name, option_value)

# Model buffer name -> display label used in tables and plots.
count_buffer_names = (
    "curr_count",
    "concode_none_count",
    "translate_java-cs_count",
    "summarize_ruby_count",
    "refine_small_count",
)
count_display_names = ("All", "CodeGen", "CodeTrans", "CodeSumm", "CodeRef")
key_map = dict(zip(count_buffer_names, count_display_names))

In [None]:
### INITIALIZE DATALOADER AND MODEL

# Build the base generation model and the data module that knows the task stream.
config, model, tokenizer = build_or_load_gen_model(args)
datamodule = CodeXGlueDataModule(args, tokenizer)
all_tasks = datamodule.all_tasks

# Expand the per-task options to one value per task *before* the data module
# is set up, so the setup below sees the expanded lists on `args`.
task_specific_params = ['num_train_epochs', 'learning_rate', 'patience', 'max_source_length',
                        'max_target_length', 'train_batch_size', 'eval_batch_size']
args = get_task_arglist(args, task_specific_params, datamodule)
datamodule.setup(stage='fit')
# Indexed by task name further down in the notebook.
train_dataloaders = datamodule.train_dataloader()

# Wrap the base model with the prompt-tuning model when prompts are in use.
if args.num_prompts_per_task > 0 or args.prompt_method == 'pool':
    model = PromptTuneT5(args, model, tokenizer, datamodule)
# Pool mode additionally needs its prompt pool initialised.
if args.prompt_method == 'pool':
    model.initialize_prompt_pool(load=True)
model.to(args.device)
print('INITIALIZED DATALOADER AND MODEL')

In [None]:
def get_counts(model, key_map):
    """Return the model's ``*count*`` buffers as integer NumPy arrays.

    Every top-level buffer whose name contains ``'count'`` is cast to
    ``torch.int``, moved to the CPU and stored under ``key_map[buffer_name]``.
    """
    return {
        key_map[buffer_name]: buffer.type(torch.int).cpu().numpy()
        for buffer_name, buffer in model._buffers.items()
        if 'count' in buffer_name
    }

def get_pairwise_intersection(data, noise_perc_thresh, noise_exact_thresh=10):
    """Pairwise overlap of the prompts each task actually uses.

    A prompt counts as "used" by a task when its count exceeds
    ``min(int(total_count * noise_perc_thresh), noise_exact_thresh)``.

    Parameters
    ----------
    data : dict mapping task label -> per-prompt counts.  An ``'All'`` entry,
        if present, is ignored.
    noise_perc_thresh : float, fraction of the task's total count.
    noise_exact_thresh : int, upper bound on the threshold.

    Returns
    -------
    (common_union, common_all) : two ``(T, T)`` float arrays where entry
        ``[i, j]`` is ``|used_i & used_j|`` divided by ``|used_i | used_j|``
        and by the pool size of task ``i`` respectively.  A zero denominator
        yields ``0.0`` instead of raising ``ZeroDivisionError``.
    """
    tasks = [name for name in data.keys() if name != 'All']
    print(tasks)

    # Compute each task's used-prompt indices once instead of once per pair.
    pool_sizes, used = {}, {}
    for task in tasks:
        counts = np.array(data[task])
        thresh = min(int(counts.sum() * noise_perc_thresh), noise_exact_thresh)
        used[task] = np.where(counts > thresh)[0]
        pool_sizes[task] = len(counts)

    common_union = np.zeros((len(tasks), len(tasks)))
    common_all = np.zeros((len(tasks), len(tasks)))

    for i, task_i in enumerate(tasks):
        for j, task_j in enumerate(tasks):
            n_shared = len(np.intersect1d(used[task_i], used[task_j]))
            n_union = len(np.union1d(used[task_i], used[task_j]))
            common_union[i, j] = n_shared / n_union if n_union else 0.0
            common_all[i, j] = n_shared / pool_sizes[task_i] if pool_sizes[task_i] else 0.0
    return common_union, common_all

def get_prompt_num_task_participation(data, noise_perc_thresh, noise_exact_thresh=10):
    """Count, for every prompt, how many tasks use it.

    A prompt is "used" by a task when its count exceeds
    ``min(int(total_count * noise_perc_thresh), noise_exact_thresh)``.

    Parameters
    ----------
    data : dict mapping task label -> per-prompt counts (list or array).
        An ``'All'`` entry, if present, is ignored.
    noise_perc_thresh : float, fraction of the task's total count.
    noise_exact_thresh : int, upper bound on the threshold.

    Returns
    -------
    prompt_tasks : int array, number of tasks using each prompt.
    num_keys_in_tasks : int array of length ``T + 1``; entry ``k`` is the
        number of prompts used by exactly ``k`` tasks.

    Raises
    ------
    ValueError
        If ``data`` contains no task besides ``'All'``.
    """
    tasks = [name for name in data.keys() if name != 'All']
    print(tasks)
    if not tasks:
        raise ValueError("data must contain at least one task besides 'All'")

    prompt_tasks = np.zeros(len(data[tasks[0]]), dtype=int)
    for task_name in tasks:
        # asarray: accept plain lists as well as arrays.
        task_counts = np.asarray(data[task_name])
        thresh = min(int(task_counts.sum() * noise_perc_thresh), noise_exact_thresh)
        prompt_tasks += task_counts > thresh

    # Histogram over 0..T participating tasks; sums to len(prompt_tasks).
    num_keys_in_tasks = np.bincount(prompt_tasks, minlength=len(tasks) + 1)
    return prompt_tasks, num_keys_in_tasks

# ### GET INITIAL FREQUENCY COUNT STATISTICS 
# run_path = '/mnt/efs/people/ptky/project/incremental-learning/saved_runs/aws-tune_lr/pool100_data_keyaggrandom_qlr0.1~try=0'
# model = load_checkpoint(model, run_path, 'best-bleu')
# training_counts = get_counts(model, key_map)
# prompt_tasks, num_keys_in_tasks = get_prompt_num_task_participation(training_counts, 0.01, 10)

In [None]:

### GET QUERIES FOR ALL MODEL CHECKPOINTS

model_types = ['first', 'concode_none', 'translate_java-cs', 'summarize_ruby', 'refine_small', 'best-bleu']

# Run directory whose checkpoints are analysed.  It does not depend on the
# checkpoint, so it is set once here (later cells also read `run_path`).
# Runs inspected previously, kept for reference:
# run_path = "/mnt/efs/people/ptky/project/incremental-learning/saved_runs/aws-init5k_1/pool60_data_keyaggrandom~try=0"
# run_path = '/mnt/efs/people/ptky/project/incremental-learning/saved_runs/aws-init5k_1/pool60_data_nokeys_keyaggrandom~try=0'
# run_path = '/mnt/efs/people/ptky/project/incremental-learning/saved_runs/aws-init5k_1/pool60_data_freqnorm_keyaggrandom~try=0'
# run_path = '/mnt/efs/people/ptky/project/incremental-learning/saved_runs/aws-init5k_1/pool60_data_batched_keyaggrandom~try=0'
# run_path = '/mnt/efs/people/ptky/project/incremental-learning/saved_runs/aws-init5k_1/pool300_data_keyaggrandom~try=0'
# run_path = '/mnt/efs/people/ptky/project/incremental-learning/saved_runs/aws-init5k_1/pool100_data_keyaggrandom~try=0'
# run_path = '/mnt/efs/people/ptky/project/incremental-learning/saved_runs/aws-init5k_1/pool100_data_batched_keyaggrandom~try=0'
# run_path = '/mnt/efs/people/ptky/project/incremental-learning/saved_runs/aws-init5k_1/pool100_data_freqnorm_keyaggrandom~try=0'
# run_path = '/mnt/efs/people/ptky/project/incremental-learning/saved_runs/aws-init5k_1/pool100_data_nokeys_keyaggrandom~try=0'
# run_path = '/mnt/efs/people/ptky/project/incremental-learning/saved_runs/aws-init5k_1/pool500_data_keyaggrandom~try=0'
# run_path = '/mnt/efs/people/ptky/project/incremental-learning/saved_runs/aws-init5k_1/pool100_data_keyaggrandom_ER~try=0'
# run_path = '/mnt/efs/people/ptky/project/incremental-learning/saved_runs/aws-tune_lr/pool100_data_keyaggrandom_qlr0.1~try=0'
# run_path = '/mnt/efs/people/ptky/project/incremental-learning/saved_runs/aws-tune_plam/pool100_data_keyaggrandom_plam0.001~try=0'
run_path = "/mnt/efs/people/ptky/project/incremental-learning/saved_runs/aws-init5k_1/pool100_data_fixed_selection~try=0"

criteria_reps, criteria_keys, criteria_targets = [], [], []
for criteria in model_types:
    # load_checkpoint returns None for a missing file.  Fail here with a clear
    # message rather than rebinding `model` to None, which would break both
    # this loop and any re-run of the cell.
    loaded_model = load_checkpoint(model, run_path, criteria)
    if loaded_model is None:
        raise FileNotFoundError(f"No checkpoint '{criteria}' found under {run_path}/checkpoints")
    model = loaded_model

    if criteria == 'best-bleu':
        training_counts = get_counts(model, key_map)
        prompt_tasks, num_keys_in_tasks = get_prompt_num_task_participation(training_counts, 0.01, 10)
        print('how many prompts belong to how many tasks!'.upper())
        print(num_keys_in_tasks.tolist())

    # Normalised query vectors for every task, keyed by the task's index.
    task_reps = {}
    for curr_idx, curr_task in enumerate(all_tasks):
        train_loader = train_dataloaders[curr_task]
        print(f"***** Eval results on Task {curr_task} *****".upper())
        task_reps[curr_idx] = get_task_queries(args, train_loader, model, tokenizer, eval_task=curr_task, phase='init')
        print("\n\n")

    # Flatten to one matrix of queries plus a parallel vector of task indices.
    reps, targets = [], []
    for k, v in task_reps.items():
        reps.append(v)
        targets.append(torch.ones(v.shape[0]) * int(k))
    reps = torch.vstack(reps)
    targets = torch.cat(targets).numpy()
    criteria_reps.append(reps)
    criteria_targets.append(targets)
    criteria_keys.append(model.prompt_keys.detach().cpu().numpy())

In [None]:
# One projection plot per checkpoint in `model_types`:
# ['first', 'concode_none', 'translate_java-cs', 'summarize_ruby', 'refine_small', 'last']

# Both the 3-D and the 2-D branch save into this directory, so create it once.
os.makedirs(f"{run_path}/plots/", exist_ok=True)

for ii, criteria in enumerate(model_types):
    # "<method>-<dims>-<name>"; checkpoint names containing '-' spill into `_`.
    args.projection_plot = f'pca-3-{criteria}'
    method, dims, name, *_ = args.projection_plot.split('-')

    reps = criteria_reps[ii]
    targets = criteria_targets[ii]
    keys_ = criteria_keys[ii]
    # Queries were L2-normalised when collected; normalise the keys the same way.
    keys = keys_ / np.linalg.norm(keys_, ord=2, axis=1, keepdims=True)

    if method == 'tsne':
        print('Performing t-sne projection')
        meth = TSNE(n_components=int(dims), verbose=1, random_state=123)
        # TSNE has no transform(), so queries and keys are embedded jointly
        # and split apart again afterwards.
        joint = meth.fit_transform(np.concatenate([np.asarray(reps), keys]))
        queries, keys_trans = joint[:len(reps)], joint[len(reps):]
    elif method == 'pca':
        print('Performing PCA projection')
        meth = PCA(n_components=int(dims))
        queries = meth.fit_transform(reps)
        keys_trans = meth.transform(keys)
    else:
        raise ValueError(f"Unknown projection method: {method}")
    out = np.concatenate([queries, keys_trans])
    # Keys get their own class id, one past the last task id.
    targets = np.concatenate([targets, np.ones(keys_trans.shape[0]) * (max(targets)+1)])

    if int(dims) == 3:

        Xax = out[:,0]
        Yax = out[:,1]
        Zax = out[:,2]

        cdict = {0:'orange',1:'darkcyan', 2:'burlywood', 3:'cyan', 4:'red'}
        labl = {0:'CodeGen Queries',1:'CodeTrans Queries',2:"CodeSumm Queries", 3:'CodeRef Queries', 4:"Keys"}
        marker = {0:'*',1:'o',2:'v',3:'s',4:'x'}
        alpha = {0:0.2,1:0.2,2:0.2,3:0.2,4:0.7}

        # Three-task variant of the legend tables:
        # cdict = {0:'orange',1:'darkcyan', 2:'burlywood', 3:'red'}
        # labl = {0:'CodeGen Queries',1:'CodeTrans Queries',2:"CodeSumm Queries", 3:"Keys"}
        # marker = {0:'*',1:'o',2:'v',3:'x'}
        # alpha = {0:0.2,1:0.2,2:0.2,3:0.7}

        fig = plt.figure(figsize=(7,5))
        ax = fig.add_subplot(111, projection='3d')
        ax.set_title(f"Task: {criteria}")

        fig.patch.set_facecolor('white')
        for l in np.unique(targets):
            ix=np.where(targets==l)
            ax.scatter(Xax[ix], Yax[ix], Zax[ix], c=cdict[l], s=20,
                    label=labl[l], marker=marker[l], alpha=alpha[l])
        ax.set_xlabel("First PC", fontsize=14)
        ax.set_ylabel("Second PC", fontsize=14)
        ax.set_zlabel("Third PC", fontsize=14)
        ax.legend()
        # Save before showing: with non-interactive backends the figure is
        # released by show() and a later savefig writes an empty canvas.
        fig.savefig(f"{run_path}/plots/{args.projection_plot}.png")
        plt.show()
    elif int(dims) == 2:
        df = pd.DataFrame()
        df["y"] = targets
        df["First PC"] = out[:,0]
        df["Second PC"] = out[:,1]
        # Fresh axes per checkpoint so successive plots do not pile up on
        # whatever axes happen to be current.
        fig, ax = plt.subplots()
        sns.scatterplot(x="First PC", y="Second PC", hue=df.y.tolist(),
                        palette=sns.color_palette("hls", np.unique(targets).shape[0]),
                        data=df, ax=ax).set(title="CodeT5 queries for task")
        fig.savefig(f"{run_path}/plots/{args.projection_plot}.png")
    else:
        raise ValueError(f"{dims }Dimension not supported.")

In [None]:
def get_df(arr, tasks=('CodeGen', 'CodeTrans', 'CodeSumm')):
    """Turn a task-by-task fraction matrix into a labelled percentage table.

    Parameters
    ----------
    arr : 2-D array-like of shape ``(len(tasks), len(tasks))``.
    tasks : sequence of str, optional
        Row and column labels.  Defaults to the three tasks this function was
        hard-coded for; pass e.g. four names for a four-task stream.

    Returns
    -------
    pandas.DataFrame
        ``arr * 100`` rounded to two decimals, with ``tasks`` as columns and
        as the index (index name ``'tasks'``).
    """
    task_names = list(tasks)
    sim_df = (pd.DataFrame(arr, columns=task_names) * 100).round(2)
    sim_df.index = pd.Index(task_names, name='tasks')
    return sim_df

def plot_heatmap(data, title, savename=None, vmin=None, vmax=None):
    """Draw an annotated seaborn heatmap and either save or show it.

    Parameters
    ----------
    data : 2-D data accepted by ``seaborn.heatmap``.
    title : title placed above the heatmap.
    savename : file stem; when given the figure is written to
        ``./figures/<savename>.pdf`` (the directory is created if missing)
        and not shown.  When ``None`` the figure is shown instead.
    vmin, vmax : colour-scale limits forwarded to ``seaborn.heatmap``.

    Note: ``sns.set`` changes seaborn/matplotlib styling globally.
    """
    params = {"lines.linewidth": 1.1, "font.size":14, "axes.titlesize":18,
            "axes.labelsize":20, 'xtick.labelsize':14, 'ytick.labelsize':14,
            'legend.fontsize': 12, 'legend.labelspacing': 0.3, 'legend.handletextpad': 0.2,
            'legend.borderpad': 0.4}
    sns.set(style="ticks", rc=params)

    ax = sns.heatmap(data, annot=True, fmt='.1f', linewidths=.5, cmap="YlGnBu", vmin=vmin, vmax=vmax)
    ax.set(ylabel=r'', xlabel=r'')
    ax.set_title(f'{title}')
    ax.set_aspect('auto')
    if savename is not None:
        out_path = f'./figures/{savename}.pdf'
        # savefig does not create missing directories.
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        plt.savefig(out_path, dpi=1000,  bbox_inches='tight')
    else:
        plt.show()
    
def get_intersection_stats(data, noise_perc_thresh, noise_exact_thresh=20, tasks=None):
    """Pairwise overlap of the prompts used by a fixed list of tasks.

    A prompt counts as "used" by a task when its count exceeds
    ``min(int(total_count * noise_perc_thresh), noise_exact_thresh)``.

    Parameters
    ----------
    data : mapping task label -> per-prompt counts; other keys are ignored.
    noise_perc_thresh : float, fraction of the task's total count.
    noise_exact_thresh : int, upper bound on the threshold.
    tasks : sequence of str, optional
        Task labels to compare, in output order.  Defaults to
        ``['CodeGen', 'CodeTrans', 'CodeSumm']``; pass four labels for a
        four-task stream.

    Returns
    -------
    (common_union, common_all) : two ``(T, T)`` float arrays where entry
        ``[i, j]`` is ``|used_i & used_j|`` divided by ``|used_i | used_j|``
        and by the pool size of task ``i`` respectively.  A zero denominator
        yields ``0.0`` instead of raising ``ZeroDivisionError``.
    """
    if tasks is None:
        tasks = ['CodeGen','CodeTrans','CodeSumm']
    tasks = list(tasks)

    # Compute each task's used-prompt indices once instead of once per pair.
    pool_sizes, used = {}, {}
    for task in tasks:
        counts = np.array(data[task])
        thresh = min(int(counts.sum() * noise_perc_thresh), noise_exact_thresh)
        used[task] = np.where(counts > thresh)[0]
        pool_sizes[task] = len(counts)

    common_union = np.zeros((len(tasks), len(tasks)))
    common_all = np.zeros((len(tasks), len(tasks)))

    for i, task_i in enumerate(tasks):
        for j, task_j in enumerate(tasks):
            n_shared = len(np.intersect1d(used[task_i], used[task_j]))
            n_union = len(np.union1d(used[task_i], used[task_j]))
            common_union[i, j] = n_shared / n_union if n_union else 0.0
            common_all[i, j] = n_shared / pool_sizes[task_i] if pool_sizes[task_i] else 0.0
    return common_union, common_all

def get_count_df(counts, noise_perc_thresh, noise_exact_thresh):
    """Return the (intersection-over-union, intersection-over-pool) overlap matrices.

    Thin wrapper around ``get_intersection_stats``.  Despite the name, the
    result is the pair of raw NumPy arrays, not DataFrames: the DataFrame
    versions that used to be built here were never returned, so that dead
    work has been removed.  Wrap a result with ``get_df`` if a labelled
    table is wanted.
    """
    sparse_freq_overlap, dense_freq_overlap = get_intersection_stats(counts, noise_perc_thresh, noise_exact_thresh)
    return sparse_freq_overlap, dense_freq_overlap

def heatmap(data, row_labels, col_labels, ax=None,
            cbar_kw=None, cbarlabel="", **kwargs):
    """
    Create a heatmap from a numpy array and two lists of labels.

    Parameters
    ----------
    data
        A 2D numpy array of shape (M, N).
    row_labels
        A list or array of length M with the labels for the rows.
    col_labels
        A list or array of length N with the labels for the columns.
    ax
        A `matplotlib.axes.Axes` instance to which the heatmap is plotted.  If
        not provided, use current axes or create a new one.  Optional.
    cbar_kw
        A dictionary with arguments to `matplotlib.Figure.colorbar`.  Optional;
        ``None`` (the default) means no extra arguments.
    cbarlabel
        The label for the colorbar.  Optional.
    **kwargs
        All other arguments are forwarded to `imshow`.

    Returns
    -------
    im, cbar
        The image returned by `imshow` and the colorbar attached to it.
    """

    # Explicit None checks: no shared mutable default for cbar_kw, and no
    # reliance on the truthiness of an Axes object.
    if ax is None:
        ax = plt.gca()
    if cbar_kw is None:
        cbar_kw = {}

    # Plot the heatmap
    im = ax.imshow(data, **kwargs)

    # Create colorbar
    cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw)
    cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")

    # Show all ticks and label them with the respective list entries.
    ax.set_xticks(np.arange(data.shape[1]), labels=col_labels)
    ax.set_yticks(np.arange(data.shape[0]), labels=row_labels)

    # Let the horizontal axes labeling appear on top.
    ax.tick_params(top=True, bottom=False,
                   labeltop=True, labelbottom=False)

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=-30, ha="right",
             rotation_mode="anchor")

    # Turn spines off and create white grid.
    ax.spines[:].set_visible(False)

    # Minor ticks sit on the cell borders so the grid separates the cells.
    ax.set_xticks(np.arange(data.shape[1]+1)-.5, minor=True)
    ax.set_yticks(np.arange(data.shape[0]+1)-.5, minor=True)
    ax.grid(which="minor", color="w", linestyle='-', linewidth=3)
    ax.tick_params(which="minor", bottom=False, left=False)

    return im, cbar


def annotate_heatmap(im, data=None, valfmt="{x:.2f}",
                     textcolors=("black", "white"),
                     threshold=None, **textkw):
    """
    Write the value of every heatmap cell on top of that cell.

    Parameters
    ----------
    im
        The AxesImage to be labeled.
    data
        Values to print (list or ndarray).  Anything else, including the
        default None, means the image's own array is used.  Optional.
    valfmt
        Either a ``str.format`` pattern using ``x`` (e.g. "$ {x:.2f}") or a
        callable ``(value, pos) -> str`` such as a
        `matplotlib.ticker.Formatter`.  Optional.
    textcolors
        Two colors: index 0 for cells at or below the threshold, index 1 for
        cells above it.  Optional.
    threshold
        Cut-off in data units.  If None, half of the normalised maximum of
        the data is used.  Optional.
    **textkw
        Forwarded to every `text` call; may override the default centring.

    Returns
    -------
    list
        The text objects, in row-major order.
    """

    values = data if isinstance(data, (list, np.ndarray)) else im.get_array()

    # Express the cut-off in the image's normalised colour range.
    if threshold is None:
        cutoff = im.norm(values.max())/2.
    else:
        cutoff = im.norm(threshold)

    # Centre the labels unless the caller says otherwise.
    text_style = {"horizontalalignment": "center",
                  "verticalalignment": "center",
                  **textkw}

    # A pattern string is wrapped into a formatter; callables are used as is.
    formatter = matplotlib.ticker.StrMethodFormatter(valfmt) if isinstance(valfmt, str) else valfmt

    # One label per cell, coloured by which side of the cut-off it falls on.
    labels = []
    for row in range(values.shape[0]):
        for col in range(values.shape[1]):
            is_above = int(im.norm(values[row, col]) > cutoff)
            text_style["color"] = textcolors[is_above]
            labels.append(im.axes.text(col, row, formatter(values[row, col], None), **text_style))

    return labels

In [None]:
# Collect the per-task prompt usage counts from the model's buffers and plot
# how strongly the tasks' used prompts overlap.
count_dict = {}
# NOTE(review): only three tasks are analysed here although the configured
# stream has four; the four-task list is left for reference.
# tasks = ['CodeGen','CodeTrans','CodeSumm','CodeRef']
tasks = ['CodeGen','CodeTrans','CodeSumm']


print("'Prompt_id' : list(range(60)),")
for k,v in model._buffers.items():
    if 'count' in k:
        print(f"'{key_map[k]}': {v.type(torch.int).tolist()},")
        count_dict[key_map[k]] = v.type(torch.int).tolist()
count_dict['prompt_id'] = range(len(count_dict['All']))

sparse_pool , dense_pool  = get_count_df(count_dict , noise_perc_thresh=0.01, noise_exact_thresh=4)
fig, ax = plt.subplots()
im, cbar = heatmap(dense_pool, tasks, tasks, ax=ax,
                   cmap="YlGn", cbarlabel="Prompt Overlap")
texts = annotate_heatmap(im, valfmt="{x:.2f}")
# savefig does not create directories; do not rely on an earlier cell having
# created the plots folder.
os.makedirs(f"{run_path}/plots", exist_ok=True)
fig.savefig(f"{run_path}/plots/overlap.png")

# plot_heatmap(dense_pool, title="Training Counts", savename=None)


In [None]:
# Display the intersection-over-pool-size overlap matrix computed in the previous cell.
dense_pool

In [5]:
# Stand-alone scratch cell: download a sample image and a small ViT classifier.
# NOTE: this rebinds the notebook-global `model`; cells above that read
# `model` (e.g. the buffer-count cell) will see this ViT if run afterwards.
from transformers import ViTFeatureExtractor, ViTForImageClassification
from PIL import Image
import requests


url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
# A timeout keeps the kernel from hanging on a stalled connection, and
# raise_for_status surfaces HTTP errors instead of handing an error page to PIL.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw)

feature_extractor = ViTFeatureExtractor.from_pretrained('WinKawaks/vit-small-patch16-224')
model = ViTForImageClassification.from_pretrained('WinKawaks/vit-small-patch16-224')


In [6]:
import numpy as np
# Total number of parameters of whatever `model` currently refers to
# (all parameters, not only those with requires_grad), reported in millions.
model_size = sum(np.prod(weight.size()) for weight in model.parameters())
f"{round(model_size / 1e+6)}M"

'22M'