# In [44]:  (notebook cell prompt, kept as a comment so the file parses as Python)
import json
from collections import Counter

import altair as alt
import numpy as np
import pandas as pd

# Load the task groupings: a JSON object keyed by group name whose values are
# the collections of task names searched by get_group_name().
# The encoding is pinned to UTF-8 so decoding does not depend on the locale.
with open("metax/distant_supervision/task_groupings.json", encoding="utf-8") as f:
    task_grouping = json.load(f)

def get_group_name(task_name, groupings=None):
    """Return the name of the group that contains ``task_name``.

    Args:
        task_name: Task identifier to look up.
        groupings: Optional mapping of group name to a container of task
            names. When omitted, the module-level ``task_grouping`` is used,
            which keeps existing one-argument calls working unchanged.

    Returns:
        The first group (in mapping order) whose members include
        ``task_name``, or ``None`` when no group contains it.
    """
    # Compare against None explicitly so an empty mapping passed by the
    # caller is honoured rather than replaced by the global table.
    if groupings is None:
        groupings = task_grouping
    for group_name, members in groupings.items():
        if task_name in members:
            return group_name
    return None

# Every retrieval-result file name ends with the same run suffix; the
# "{task_name}", "{seed}" and "{round_id}" placeholders are left in the
# template and substituted later, when the files are loaded.
_RUN_SUFFIX = "-16-512-42-{task_name}-6e-6-2_seed={seed}_round#{round_id}.json"

# (method label, sub-directory under retrieved_data/, run tag after "RET_")
_RETRIEVAL_RUNS = (
    ("Random", "zs_Random", "Random-unsupervised_no-1"),
    # ("Random_Rerank", "zs_Random_reranked", "Random-unsupervised_rerank-2"),
    ("SBERT", "zs_SentenceTransformer", "SentenceTransformer-unsupervised_no-1"),
    # ("SentenceTransformer_Rerank", "zs_SentenceTransformer_reranked", "SentenceTransformer-unsupervised_rerank-2"),
    ("ReCross(init)", "zs_BART", "BART-unsupervised_no-1"),
    # ("BART_Rerank", "zs_BART_reranked", "BART-unsupervised_rerank_robert-2"),
    ("ReCross", "zs_BART_reranked", "BART-unsupervised_rerank_roberta_base-2"),
)

# Method label -> path template of its retrieval-result JSON files.
data_paths = {
    label: f"retrieved_data/{subdir}/RET_{run_tag}{_RUN_SUFFIX}"
    for label, subdir, run_tag in _RETRIEVAL_RUNS
}

# Tasks whose retrieval results are analysed, mapped to the short labels that
# are written into the "U" column of the plotted data.
task_names = {
    "piqa": "piqa",
    "ai2_arc-ARC-Challenge": "arc",
    "squad_v2": "squad",
    "openbookqa-main": "obqa",
    "super_glue-wic": "wic",
    "super_glue-wsc.fixed": "wsc",
    "winogrande-winogrande_xl": "wngrnd",
    "super_glue-cb": "cb",
    "anli_r3": "anli",
    "hellaswag": "hswag",
}

# Retrieval rounds whose result files are pooled for every task.
round_ids = list(range(5))

# Retrieval method to analyse. It must be a key of `data_paths`:
# "Random", "SBERT", "ReCross(init)" or "ReCross".
method_name = "ReCross"

# Alphabetically sorted group names; iterating the mapping yields its keys.
all_group_names = sorted(task_grouping)
# Upstream task identifiers (as matched against the retrieved examples) mapped
# to the short labels written into the "T" column of the plotted data.
# Insertion order is preserved and determines the order rows are generated in.
all_upstream_names = {
    "glue-mrpc": "mrpc",
    "glue-qqp": "qqp",
    "paws_x-en": "paws_x",
    "kilt_tasks-hotpotqa": "hotpotqa",
    "wiki_qa": "wiki_qa",
    "adversarial_qa-dbert": "advqa-dbert",
    "adversarial_qa-dbidaf": "advqa-dbidaf",
    "adversarial_qa-droberta": "advqa-droberta",
    "duorc-SelfRC": "duorc-SRC",
    "duorc-ParaphraseRC": "duorc-PRC",
    "ropes": "ropes",
    "quoref": "quoref",
    "cos_e-v1.11": "cos_e",
    "cosmos_qa": "cosmosQA",
    "dream": "dream",
    "qasc": "qasc",
    "quail": "quail",
    "quartz": "quartz",
    "sciq": "sciq",
    "social_i_qa": "social-iqa",
    "wiki_hop-original": "wiki_hop",
    "wiqa": "wiqa",
    "amazon_polarity": "amazon_pol.",
    "app_reviews": "app_reviews",
    "imdb": "imdb",
    "rotten_tomatoes": "rotten..",
    "yelp_review_full": "yelp_review",
    "common_gen": "common_gen",
    "wiki_bio": "wiki_bio",
    "cnn_dailymail-3.0.0": "cnn_dm",
    "gigaword": "gigaword",
    "multi_news": "multi_news",
    "samsum": "samsum",
    "xsum": "xsum",
    "ag_news": "ag_news",
    "dbpedia_14": "dbpedia",
}

# Build the long-form table behind the heatmap, one row per
# (task in `task_names`, upstream task in `all_upstream_names`) pair:
#   "U": short label of the task from `task_names`
#   "T": short label of the upstream task
#   "p": fraction of that task's pooled retrieved examples attributed to the
#        upstream task
all_data = {"U": [], "T": [], "p": []}
path_template = data_paths[method_name]
for task_name in task_names:
    # Pool the retrieved examples of every round for this task.
    ret_data = []
    for round_id in round_ids:
        # "Random*" methods were run with a different seed per round
        # (42 + round id); every other method uses the fixed seed 42.
        seed = 42 + round_id if method_name.startswith("Random") else 42
        path = (
            path_template
            .replace("{task_name}", task_name)
            .replace("{round_id}", str(round_id))
            .replace("{seed}", str(seed))
        )
        with open(path, encoding="utf-8") as f:
            ret_data += json.load(f)

    if not ret_data:
        # Fail with context instead of a bare ZeroDivisionError further down.
        raise ValueError(
            f"No retrieved examples found for task {task_name!r} "
            f"with method {method_name!r}"
        )

    # NOTE(review): assumes element 2 of each retrieved item is a string whose
    # prefix before the first "|" is the upstream task name -- confirm against
    # the code that writes these files.
    ret_task_names = [item[2].split("|")[0] for item in ret_data]
    # Only consumed by the group-level variant (Option 1) below.
    group_names = [get_group_name(t) for t in ret_task_names]

    # Option 1 (disabled): group-level heatmap
    # for g in all_group_names:
    #     percent = group_names.count(g)/len(group_names)
    #     all_data["U"].append(task_name.replace("super_glue-", "").replace("ai2_",""))
    #     all_data["T"].append(g)
    #     all_data["p"].append(percent)

    # Option 2: task-level heatmap.
    # Count every upstream task in one pass; Counter yields 0 for upstream
    # tasks that were never retrieved, so they still get a row with p == 0.
    retrieved_counts = Counter(ret_task_names)
    total_retrieved = len(ret_task_names)
    for upstream_name, upstream_label in all_upstream_names.items():
        all_data["U"].append(task_names[task_name])
        all_data["T"].append(upstream_label)
        all_data["p"].append(retrieved_counts[upstream_name] / total_retrieved)

# Plot the collected fractions as a heatmap titled with the retrieval method:
# "T" on the x axis, "U" on the y axis, "p" as the cell colour.
fig_title = method_name
df = pd.DataFrame(all_data)

# Linear colour scale with its domain fixed to [0, 0.3].
scale = alt.Scale(type='linear', domain=[0, 0.3])

fig = (
    alt.Chart(df)
    .mark_rect()
    .encode(
        x='T:O',
        y='U:O',
        color=alt.Color('p:Q', scale=scale),
    )
    .properties(title=fig_title)
)

# Size and typography. The zero title font sizes presumably serve to hide the
# axis and legend titles -- confirm against the rendered chart.
fig = (
    fig.properties(width=1000, height=600)
    .configure_axis(labelFontSize=25, titleFontSize=0)
    .configure_legend(titleFontSize=0, labelFontSize=24)
    .configure_title(
        fontSize=50,
        font='Courier',
        anchor='middle',
        orient="top",
        align="center",
        color='black',
    )
)
fig.show()

# Output: Displaying chart at http://localhost:18804/
