In [None]:
import numpy as np

In [None]:
%%time
from metal.mmtl.glue.glue_tasks import create_tasks_and_payloads
task_kwargs = {
    "dl_kwargs": {"batch_size": 16},
    "bert_model": 'bert-large-cased',
    "max_len": 100   
}

task_names = ["RTE"]
from metal.mmtl.metal_model import MetalModel
tasks, payloads = create_tasks_and_payloads(task_names, **task_kwargs)
model = MetalModel(tasks, seed=1, verbose=False)

In [None]:
# Maps "{task_name}" or "{task_name}:{slice_name}" to the checkpoint file to
# load for that head. Previously used checkpoint sets are kept below for
# reference; copy the wanted entries into the dict to switch runs.
#
# vschen run 2019_03_18/10_08_00:
#     "RTE": "/dfs/scratch0/vschen/metal-mmtl/logs/2019_03_18/10_08_00/checkpoints/RTE/best_model.pth",
#     "RTE:dash_semicolon": "/dfs/scratch0/vschen/metal-mmtl/logs/2019_03_18/10_08_00/checkpoints/RTE:dash_semicolon/best_model.pth",
# vschen run 2019_03_17/23_11_40:
#     "RTE": "/dfs/scratch0/vschen/metal-mmtl/logs/2019_03_17/23_11_40//checkpoints/RTE/best_model.pth",
#     "RTE:dash_semicolon": "/dfs/scratch0/vschen/metal-mmtl/logs/2019_03_17/23_11_40//checkpoints/RTE:dash_semicolon/best_model.pth",
#     "RTE:more_people": "/dfs/scratch0/vschen/metal-mmtl/logs/2019_03_17/23_11_40//checkpoints/RTE:more_people/best_model.pth",
# chami slicing run 23_11_40:
#     "RTE": "/dfs/scratch0/chami/metal/logs/slicing/23_11_40/best_model.pth",
# vschen run 2019_03_19/17_45_14:
#     "RTE": "/dfs/scratch0/vschen/metal-mmtl/logs/2019_03_19/17_45_14/checkpoints/RTE/best_model.pth",
#     "RTE:dash_semicolon": "/dfs/scratch0/vschen/metal-mmtl/logs/2019_03_19/17_45_14/checkpoints/RTE:dash_semicolon/best_model.pth",
#     "RTE:more_people": "/dfs/scratch0/vschen/metal-mmtl/logs/2019_03_19/17_45_14/checkpoints/RTE:more_people/best_model.pth",
# vschen run 2019_03_20/13_16_49:
#     "RTE": "/dfs/scratch0/vschen/metal-mmtl/logs/2019_03_20/13_16_49/checkpoints/RTE/best_model.pth",
#     "RTE:dash_semicolon": "/dfs/scratch0/vschen/metal-mmtl/logs/2019_03_20/13_16_49/checkpoints/RTE:dash_semicolon/best_model.pth",
#     "RTE:more_people": "/dfs/scratch0/vschen/metal-mmtl/logs/2019_03_20/13_16_49/checkpoints/RTE:more_people/best_model.pth",
#     "RTE:entity_secondonly": "/dfs/scratch0/vschen/metal-mmtl/logs/2019_03_20/13_16_49/checkpoints/RTE:entity_secondonly/best_model.pth",
checkpoints = {
    # Currently active: chami run 2019_03_17/RTE_23_22_34.
    "RTE": "/dfs/scratch0/chami/metal/logs/2019_03_17/RTE_23_22_34/best_model.pth",
}

In [None]:
from collections import defaultdict

In [None]:
# Load every checkpoint in turn and collect, per head, the gold labels and the
# predicted probabilities on one payload. Results are keyed by the checkpoint
# name ("{task_name}" or "{task_name}:{slice_name}").
slice_dict = defaultdict(list)
task_names = []
Ys_probs_all, Ys_all = {}, {}
for name, model_path in checkpoints.items():
    if ":" in name:
        # Slice head: remember the slice under its base task and pass the
        # accumulated slice mapping on to task creation.
        task_name, slice_name = name.split(":")
        slice_dict[task_name].append(slice_name)
        task_kwargs.update({"slice_dict": slice_dict})
    else:
        task_name = name

    # Register each base task only once. A base checkpoint and its slice
    # checkpoints share the same task_name, and appending unconditionally
    # would hand duplicate task names to create_tasks_and_payloads and to the
    # later per-task metrics loop.
    if task_name not in task_names:
        task_names.append(task_name)
    tasks, payloads = create_tasks_and_payloads(task_names, **task_kwargs)

    # NOTE(review): `model` was built in an earlier cell; the `tasks` created
    # just above are not used to rebuild it -- confirm it already contains
    # every head named in `checkpoints`.
    model.load_weights(model_path)

    # TODO: change for train/dev/test
    payload = payloads[1]
    print(model.score(payload))

    # eval on single model via predict_with_gold
    Ys, Ys_probs, Ys_preds = model.predict_with_gold(
        payload, [name], return_preds=True
    )

    # use "name" = {task_name}:{slice_name} to grab slice-specific predictions
    Ys_probs_all[name] = np.array(Ys_probs[name])
    Ys_all[name] = np.array(Ys[name])
    print(f"Extracting probs for {name}")

In [None]:
def ensemble_avg(task_name, Ys_probs_all, Ys_all):
    """Average the probabilities of all heads, skipping abstaining heads.

    A head abstains on an example when its entry in ``Ys_all`` is 0. Only
    non-abstaining heads contribute to both the sum and the count for that
    example, so the result is a true mean over the heads that voted.

    Args:
        task_name: key of the base task in ``Ys_probs_all``; its array fixes
            the shape of the result.
        Ys_probs_all: dict mapping head name to a 2-D numpy array of
            probabilities, one row per example.
        Ys_all: dict mapping head name to a numpy array of labels (one per
            example, shape ``(n,)`` or ``(n, 1)``) where 0 marks an abstain.

    Returns:
        Array with the same shape as ``Ys_probs_all[task_name]`` holding the
        per-example mean probabilities. Rows on which every head abstains are
        all zeros instead of the result of a division by zero.
    """
    scores = np.zeros(Ys_probs_all[task_name].shape)
    counts = np.zeros((len(scores), 1))
    for k in Ys_probs_all.keys():
        Y_probs = Ys_probs_all[k]
        # Flatten instead of squeeze so a single-example payload stays 1-D
        # and can still be used as a row mask.
        active = np.asarray(Ys_all[k]).reshape(-1) != 0
        counts[active] += 1
        # Accumulate only the rows this head voted on; adding every row while
        # counting only the active ones would inflate the average.
        scores[active] += Y_probs[active]
        print(f"Num abstains for {k}: {np.sum(~active)}")

    # Divide only where at least one head voted; other rows stay zero.
    averaged_preds = np.divide(
        scores, counts, out=np.zeros_like(scores), where=counts > 0
    )
    return averaged_preds

def ensembled_masked(task_name, Ys_probs_all, Ys_all):
    """Always defer to the expert slice head.

    Starts from the base task's probabilities and, for every slice head (a
    key containing ":"), overwrites the rows on which that head does not
    abstain (label != 0) with the slice head's probabilities.

    NOTE: assumes heads don't overlap; if they do, the head that comes later
    in ``Ys_probs_all`` wins.

    Args:
        task_name: key of the base task in ``Ys_probs_all``.
        Ys_probs_all: dict mapping head name to a 2-D numpy array of
            probabilities, one row per example.
        Ys_all: dict mapping head name to a numpy array of labels (one per
            example, shape ``(n,)`` or ``(n, 1)``) where 0 marks an abstain.

    Returns:
        A new array shaped like ``Ys_probs_all[task_name]``; the input arrays
        are left unmodified.
    """
    # Work on a copy: assigning into the stored array would silently corrupt
    # the base task's probabilities for any later use of Ys_probs_all.
    scores = np.array(Ys_probs_all[task_name])

    for k, Y_probs in Ys_probs_all.items():
        if ":" not in k:
            continue
        # Flatten instead of squeeze so a single-example payload stays 1-D.
        active = np.asarray(Ys_all[k]).reshape(-1) != 0
        scores[active, :] = Y_probs[active, :]

    return scores

In [None]:
# Combine the per-head probabilities for RTE by deferring to slice heads.
# Alternative (average over heads instead):
# probs = ensemble_avg('RTE', Ys_probs_all, Ys_all)
probs = ensembled_masked(
    task_name='RTE',
    Ys_probs_all=Ys_probs_all,
    Ys_all=Ys_all,
)


In [None]:
from metal.mmtl.metal_model import probs_to_preds

# Turn the ensembled probabilities into predictions.
preds = probs_to_preds(probs)

# Take gold labels from the base task head explicitly. Picking "the first
# value" of Ys_all depends on the ordering of `checkpoints`, and a slice
# head's labels use 0 for abstain, so they are not valid gold labels for the
# whole payload.
labels = Ys_all['RTE']

In [None]:
# Score the ensembled predictions with each task's own scorer and collect the
# results per task name.
task_metrics_dict = {}
for task_name in task_names:
    metrics_dict = {}
    # None is what gets passed as target_metrics below.
    target_metrics = {task_name: None}
    scorer = model.task_map[task_name].scorer
    task_metrics_dict[task_name] = scorer.score(
        labels, probs, preds, target_metrics=target_metrics[task_name]
    )
print(task_metrics_dict)

In [None]:
# Sanity check: display how many gold labels were scored.
len(labels)