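# Notebook for visualizing vectorization ablation results

Each ablation run's similarity matrix is correlated (one-sided Kendall's τ) with the annotated validation pairs, and the correlations are tabulated per vectorization method.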

In [1]:
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import seaborn as sns
from matplotlib import pyplot as plt
from scipy import stats
import json
import random
from pprint import pprint
import re

from itertools import product, combinations
import importlib

from contextlib import redirect_stderr
import io
import sys
import os
import glob
from pathlib import Path

from sklearn.metrics.pairwise import cosine_similarity
from plotly.subplots import make_subplots
import plotly.graph_objects as go
from tqdm.auto import tqdm

AICOPE_PY_LIB = os.environ.get("AICOPE_PY_LIB")
if AICOPE_PY_LIB and AICOPE_PY_LIB not in sys.path: sys.path.append(AICOPE_PY_LIB)
import importlib
import aicnlp
importlib.reload(aicnlp)

%config Completer.use_jedi = False
PACSIM_DATA = os.environ.get("AICOPE_SCRATCH") + "/pacsim"

In [2]:
from aicnlp import emb_mgr
emb_mgr = importlib.reload(emb_mgr)  # reload returns the same module object, refreshed from disk

# Embedding manager; matsim_row later uses its ``tr`` mapping to translate ids.
# NOTE(review): this hard-coded path may equal f"{PACSIM_DATA}/mgrdata" when
# AICOPE_SCRATCH=/workspace/scratch -- confirm and reuse PACSIM_DATA if so.
MGR_DATA_DIR = "/workspace/scratch/pacsim/mgrdata"
mgr = emb_mgr.EmbMgr(MGR_DATA_DIR)

In [3]:
from aicnlp.validation.agreement import extract_annotations

# Validation/evaluation responses; parsed into annotations by extract_annotations.
VALDATA_PATH = Path("/home/ubuntu/petr/similarity/validace/evaluation-response-20220907.json")
valdata = json.loads(VALDATA_PATH.read_text(encoding='utf-8'))

### Read data from all ablation experiments and calculate correlations

In [4]:
def read_abl_matsim(path):
    """Load one ablation similarity-matrix file and decode its parameters from the file name.

    The file stem (name without its last suffix) must have exactly three
    ``-``-separated parts: ``<R>-<params>-<ignored>``. ``<params>`` is an
    ``_``-separated list of tokens whose first character is the key and whose
    remainder is the (string) value, e.g. ``Reds-AVd2v_d50_m3-x`` gives
    ``{"R": "Reds", "A": "Vd2v", "d": "50", "m": "3"}``.

    Returns
    -------
    (dict, object)
        The decoded parameters (``"R"`` first) and whatever ``np.load`` returned
        for ``path``.

    Raises
    ------
    ValueError
        If the stem does not split into exactly three ``-``-separated parts.
    """
    matsim = np.load(path)
    stem = Path(path).with_suffix("").name
    method, param_part, _ = stem.split("-")
    # Later tokens with the same key overwrite earlier ones, like repeated assignment.
    params = {token[0]: token[1:] for token in param_part.split("_")}
    return {"R": method, **params}, matsim

def matsim_row(matsim, id_pac, mgr):
    """Return the row of patient ``id_pac`` in ``matsim["patients"]``, or None if it is absent.

    ``id_pac`` is first translated to the matrix's patient id via
    ``mgr.tr["id_pac", "pid"][str(id_pac)]``; a missing translation is not
    caught and propagates the mapping's own lookup error.
    """
    patient_id = mgr.tr["id_pac", "pid"][str(id_pac)]
    patients = list(matsim["patients"])
    if patient_id not in patients:
        return None
    return patients.index(patient_id)

def extract_abl_predictions(annotations, matrix_files, mgr):
    """Look up the model similarity of every annotated (pivot, proxy) pair in every ablation matrix.

    Parameters
    ----------
    annotations : pandas.DataFrame
        Only its ``pivot`` and ``proxy`` columns are used, to enumerate the
        distinct (pivot, proxy) pairs.
    matrix_files : sequence of str
        Files readable by ``read_abl_matsim``, one per ablation configuration.
    mgr
        Embedding manager passed through to ``matsim_row`` for id translation.

    Returns
    -------
    (pandas.DataFrame, set)
        One record per (matrix file, pivot, proxy) with keys ``pivot``,
        ``proxy``, the parameters decoded from the file name, and ``value``.
        ``value`` is the ``matsim["sim"]`` entry, or ``None`` when either
        patient is missing from that matrix. Also returns the set of
        ``("all_pivot" | "all_proxy", id)`` tuples that could not be located.
    """
    predictions = []
    unknown = set()
    cat = "all"
    # The pair list does not depend on the matrix, so compute it once.
    pairs = list(annotations.groupby(["pivot", "proxy"]).groups.keys())
    for abldict, matsim in tqdm(map(read_abl_matsim, matrix_files), total=len(matrix_files), desc="Predictions"):
        sim = None  # fetched lazily, at most once per matrix
        # Each id appears in many pairs. matsim_row scans the patient list on
        # every call, so memoize its result per matrix.
        rows = {}
        for pivot, proxy in pairs:
            for id_pac in (pivot, proxy):
                if id_pac not in rows:
                    rows[id_pac] = matsim_row(matsim, id_pac, mgr)
            pivot_index = rows[pivot]
            proxy_index = rows[proxy]

            if pivot_index is None:
                unknown.add((cat + "_pivot", pivot))
            if proxy_index is None:
                unknown.add((cat + "_proxy", proxy))

            if pivot_index is None or proxy_index is None:
                value = None
            else:
                if sim is None:
                    # matsim presumably comes from np.load of an .npz archive,
                    # whose item access re-reads the array on every call.
                    sim = matsim["sim"]
                value = sim[pivot_index, proxy_index]

            predictions.append({
                "pivot": pivot,
                "proxy": proxy,
                **abldict,
                "value": value,
            })

    return pd.DataFrame.from_records(predictions), unknown


def corr_fcn(x, y):
    """One-sided Kendall's tau between ``x`` and ``y``.

    NaN entries are omitted (``nan_policy='omit'``). The alternative is
    ``"greater"``, so the p-value tests for a positive association.

    Returns the ``scipy.stats.kendalltau`` result, which unpacks as
    ``(statistic, pvalue)``.
    """
    return stats.kendalltau(x, y, nan_policy='omit', alternative="greater")


def get_correlations_all(mean_annotations, predictions):
    """Correlate model similarities with mean annotations for every model configuration, pivot and category.

    Parameters
    ----------
    mean_annotations : pandas.DataFrame
        Must provide ``pivot``, ``cat``, ``proxy`` and ``value``.
    predictions : pandas.DataFrame
        Its LAST two columns must be ``proxy`` and ``value``. Every column before
        them (ablation parameters and ``pivot``) defines one group to correlate.

    Returns
    -------
    pandas.DataFrame
        One record per (model configuration, pivot, cat). Holds the group keys
        plus ``cat``, the ``corr_fcn`` statistic in ``value`` and its p-value
        in ``pval``.
    """
    correlations = []
    groupcols = list(predictions.columns)[:-2]
    ann_by_pivot = {
        pivot: dict(list(rest.groupby("cat")))
        for pivot, rest in mean_annotations.groupby("pivot")
    }
    # dropna=False keeps groups whose parameter values are missing
    # (e.g. ``i``/``f`` are empty for ``Vd2v`` runs).
    grouped = predictions.groupby(groupcols, dropna=False)[["proxy", "value"]]
    for modellist, pred in tqdm(grouped, desc="Correlations"):
        modeldict = dict(zip(groupcols, modellist))
        pivot = modeldict["pivot"]
        for cat, ann in ann_by_pivot[pivot].items():
            # Pair the two sides by proxy id. Correlating them positionally after
            # independent sorts would misalign the pairs (or fail on a length
            # mismatch) whenever the proxy sets differ.
            paired = pred.merge(ann, on="proxy", suffixes=("_pred", "_ann"))
            c, p = corr_fcn(paired["value_pred"], paired["value_ann"])
            correlations.append({
                **modeldict,  # already contains "pivot"
                "cat": cat,
                "value": c,
                "pval": p,
            })
    return pd.DataFrame.from_records(correlations)


def move_to_end(df, to_move):
    """Reorder ``df`` in place so the columns named in ``to_move`` come last, in that order.

    Returns None; the passed frame itself is modified.
    """
    for name in to_move:
        # pop + re-insert appends the column at the right-hand end.
        df[name] = df.pop(name)

def get_ablation_data(valdata, matrix_files, mgr):
    """Build annotations, per-matrix predictions and their correlations for a set of ablation runs.

    Parameters
    ----------
    valdata
        Parsed validation JSON, passed to ``extract_annotations``.
    matrix_files : sequence of str
        Similarity-matrix files, one per ablation configuration.
    mgr
        Embedding manager used to map annotation ids to matrix rows.

    Returns
    -------
    tuple
        ``(annotations, mean_annotations, predictions, correlations)``.
        ``predictions`` has ``pivot``/``proxy``/``value`` as its last columns;
        ``correlations`` has ``pivot``/``value``/``pval`` last.

    Warns
    -----
    UserWarning
        If some annotated ids could not be located in at least one matrix.
        Their predictions are ``None``.
    """
    import warnings

    annotations, mean_annotations = extract_annotations(valdata)
    predictions, unknown = extract_abl_predictions(annotations, matrix_files, mgr)
    if unknown:
        # Pairs involving these ids have no prediction, so they drop out of (or
        # distort) the correlations. Report them instead of discarding the set.
        shown = sorted(unknown, key=str)[:10]
        warnings.warn(
            f"{len(unknown)} annotated ids were not found in at least one "
            f"similarity matrix; first {len(shown)}: {shown}",
            stacklevel=2,
        )
    # get_correlations_all relies on "proxy" and "value" being the last two columns.
    move_to_end(predictions, ["pivot", "proxy", "value"])

    correlations = get_correlations_all(mean_annotations, predictions)
    move_to_end(correlations, ["pivot", "value", "pval"])

    return annotations, mean_annotations, predictions, correlations

# Ablation matrices whose parameter part contains an "A" token starting with "V"
# (presumably the vectorization method, e.g. Vd2v / Vlsa / Vrbc).
matrix_pattern = f"{PACSIM_DATA}/3/*AV*"
matrix_files = sorted(glob.glob(matrix_pattern))
# An empty match would otherwise surface later as a confusing KeyError in move_to_end.
assert matrix_files, f"No ablation matrices match {matrix_pattern!r}"
annotations, mean_annotations, predictions, correlations = get_ablation_data(valdata, matrix_files, mgr)
correlations

Predictions:   0%|          | 0/444 [00:00<?, ?it/s]

Correlations:   0%|          | 0/4440 [00:00<?, ?it/s]

Unnamed: 0,R,A,d,m,w,e,a,i,f,cat,pivot,value,pval
0,Reds,Vd2v,50,3,3,30,1,,,01,102913,0.400000,0.241667
1,Reds,Vd2v,50,3,3,30,1,,,02,102913,-0.400000,0.883333
2,Reds,Vd2v,50,3,3,30,1,,,03,102913,0.200000,0.408333
3,Reds,Vd2v,50,3,3,30,1,,,04,102913,0.200000,0.408333
4,Reds,Vd2v,50,3,3,30,1,,,05,102913,-0.200000,0.758333
...,...,...,...,...,...,...,...,...,...,...,...,...,...
44395,Rrv2,Vrbc,50,,,,,30,1,06,80561,-0.774597,0.958368
44396,Rrv2,Vrbc,50,,,,,30,1,07,80561,0.600000,0.116667
44397,Rrv2,Vrbc,50,,,,,30,1,08,80561,0.737865,0.038487
44398,Rrv2,Vrbc,50,,,,,30,1,09,80561,0.400000,0.241667


In [5]:
def latexify(text):
    r"""Rewrite colour pseudo-commands in Styler LaTeX into ``\cellcolor``/``\color`` markup.

    Every occurrence of ``\background-color#<hex> \color#<hex> <value>`` becomes
    ``\cellcolor[HTML]{<hex>}{\color[HTML]{<hex>} <value>} ``. The hex digits
    are matched case-insensitively, so upper-case colours are converted too
    instead of leaking raw pseudo-commands into the LaTeX. The result needs
    ``\usepackage[table]{xcolor}`` (for ``\cellcolor``) in the document.
    """
    return re.sub(
        r"\\background-color#([0-9a-fA-F]*) \\color#([0-9a-fA-F]*) (\S*)",
        r"\\cellcolor[HTML]{\1}{\\color[HTML]{\2} \3} ",
        text,
    )

def print_abl_table(table):
    """Print ``table`` (a pandas Styler) as LaTeX with its cell colours converted by ``latexify``.

    ``sparse_index=False`` repeats the index labels on every row instead of
    merging them. Returns None, so the notebook does not display a duplicate.
    """
    latex_source = table.to_latex(sparse_index=False)
    print(latexify(latex_source))

### lsa results

In [6]:
# Vlsa runs, categories 07-09: pivot_table's default mean over the remaining
# rows, with rows = R and columns = (i, m).
results = (
    correlations
    .query("A == 'Vlsa' and cat in ['07', '08', '09']")
    .pivot_table(index="R", columns=["i", "m"], values="value")
)
# Explicit vmin/vmax put every column on one shared colour scale.
ps = results.style.format("{:.2f}").background_gradient(
    cmap="RdYlGn",
    vmin=results.min().min(),
    vmax=results.max().max(),
)
display(ps)
print_abl_table(ps)

i,10,10,10,20,20,20,30,30,30
m,1,2,3,1,2,3,1,2,3
R,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
Reds,0.36,0.36,0.4,0.4,0.4,0.38,0.4,0.4,0.38
Rmms,0.47,0.47,0.47,0.47,0.47,0.47,0.47,0.47,0.47
Rrv2,0.46,0.46,0.46,0.46,0.46,0.46,0.46,0.46,0.46


\begin{tabular}{lrrrrrrrrr}
i & \multicolumn{3}{r}{10} & \multicolumn{3}{r}{20} & \multicolumn{3}{r}{30} \\
m & 1 & 2 & 3 & 1 & 2 & 3 & 1 & 2 & 3 \\
R &  &  &  &  &  &  &  &  &  \\
Reds & \cellcolor[HTML]{a50026}{\color[HTML]{f1f1f1} 0.36}  & \cellcolor[HTML]{a50026}{\color[HTML]{f1f1f1} 0.36}  & \cellcolor[HTML]{fec877}{\color[HTML]{000000} 0.40}  & \cellcolor[HTML]{fec877}{\color[HTML]{000000} 0.40}  & \cellcolor[HTML]{fec877}{\color[HTML]{000000} 0.40}  & \cellcolor[HTML]{ed5f3c}{\color[HTML]{f1f1f1} 0.38}  & \cellcolor[HTML]{fec877}{\color[HTML]{000000} 0.40}  & \cellcolor[HTML]{fec877}{\color[HTML]{000000} 0.40}  & \cellcolor[HTML]{ed5f3c}{\color[HTML]{f1f1f1} 0.38}  \\
Rmms & \cellcolor[HTML]{006837}{\color[HTML]{f1f1f1} 0.47}  & \cellcolor[HTML]{006837}{\color[HTML]{f1f1f1} 0.47}  & \cellcolor[HTML]{006837}{\color[HTML]{f1f1f1} 0.47}  & \cellcolor[HTML]{006837}{\color[HTML]{f1f1f1} 0.47}  & \cellcolor[HTML]{006837}{\color[HTML]{f1f1f1} 0.47}  & \cellcolor[HTML]{006837}{\color[HT

### rbc results

In [7]:
# Vrbc runs, categories 07-09: pivot_table's default mean over the remaining
# rows, with rows = R and columns = (f, i).
results = (
    correlations
    .query("A == 'Vrbc' and cat in ['07', '08', '09']")
    .pivot_table(index="R", columns=["f", "i"], values="value")
)
# Explicit vmin/vmax put every column on one shared colour scale.
ps = results.style.format("{:.2f}").background_gradient(
    cmap="RdYlGn",
    vmin=results.min().min(),
    vmax=results.max().max(),
)
display(ps)
print_abl_table(ps)

f,0,1,1,1
i,10,10,20,30
R,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
Reds,0.32,0.36,0.36,0.36
Rmms,0.34,0.5,0.5,0.5
Rrv2,0.33,0.41,0.41,0.41


\begin{tabular}{lrrrr}
f & 0 & \multicolumn{3}{r}{1} \\
i & 10 & 10 & 20 & 30 \\
R &  &  &  &  \\
Reds & \cellcolor[HTML]{a50026}{\color[HTML]{f1f1f1} 0.32}  & \cellcolor[HTML]{f57547}{\color[HTML]{f1f1f1} 0.36}  & \cellcolor[HTML]{f57547}{\color[HTML]{f1f1f1} 0.36}  & \cellcolor[HTML]{f57547}{\color[HTML]{f1f1f1} 0.36}  \\
Rmms & \cellcolor[HTML]{d83128}{\color[HTML]{f1f1f1} 0.34}  & \cellcolor[HTML]{006837}{\color[HTML]{f1f1f1} 0.50}  & \cellcolor[HTML]{006837}{\color[HTML]{f1f1f1} 0.50}  & \cellcolor[HTML]{006837}{\color[HTML]{f1f1f1} 0.50}  \\
Rrv2 & \cellcolor[HTML]{c62027}{\color[HTML]{f1f1f1} 0.33}  & \cellcolor[HTML]{fffab6}{\color[HTML]{000000} 0.41}  & \cellcolor[HTML]{fffab6}{\color[HTML]{000000} 0.41}  & \cellcolor[HTML]{fffab6}{\color[HTML]{000000} 0.41}  \\
\end{tabular}



### d2v results

In [8]:
# Vd2v runs, categories 07-09: mean correlation with rows = (R, w) and
# columns = (e, m).
results = (
    correlations
    .query("A == 'Vd2v' and cat in ['07', '08', '09']")
    .pivot_table(index=["R", "w"], columns=["e", "m"], values="value", aggfunc="mean")
)
# Explicit vmin/vmax put every column on one shared colour scale.
ps = results.style.format("{:.2f}").background_gradient(
    cmap="RdYlGn",
    vmin=results.min().min(),
    vmax=results.max().max(),
)
display(ps)
print_abl_table(ps)

Unnamed: 0_level_0,e,30,30,30,40,40,40,50,50,50
Unnamed: 0_level_1,m,3,5,7,3,5,7,3,5,7
R,w,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2
Reds,3,0.28,0.21,0.28,0.3,0.29,0.27,0.28,0.23,0.27
Reds,5,0.19,0.24,0.2,0.24,0.23,0.22,0.28,0.25,0.23
Reds,7,0.16,0.17,0.12,0.19,0.17,0.16,0.21,0.22,0.16
Rmms,3,0.28,0.31,0.31,0.34,0.31,0.32,0.33,0.32,0.36
Rmms,5,0.28,0.27,0.28,0.28,0.29,0.26,0.33,0.31,0.31
Rmms,7,0.24,0.25,0.26,0.23,0.25,0.23,0.26,0.24,0.26
Rrv2,3,0.51,0.51,0.53,0.51,0.51,0.49,0.49,0.5,0.47
Rrv2,5,0.53,0.51,0.49,0.5,0.48,0.48,0.49,0.48,0.48
Rrv2,7,0.51,0.53,0.49,0.49,0.49,0.49,0.48,0.48,0.46


\begin{tabular}{llrrrrrrrrr}
 & e & \multicolumn{3}{r}{30} & \multicolumn{3}{r}{40} & \multicolumn{3}{r}{50} \\
 & m & 3 & 5 & 7 & 3 & 5 & 7 & 3 & 5 & 7 \\
R & w &  &  &  &  &  &  &  &  &  \\
Reds & 3 & \cellcolor[HTML]{feda86}{\color[HTML]{000000} 0.28}  & \cellcolor[HTML]{f67c4a}{\color[HTML]{f1f1f1} 0.21}  & \cellcolor[HTML]{fed683}{\color[HTML]{000000} 0.28}  & \cellcolor[HTML]{feea9b}{\color[HTML]{000000} 0.30}  & \cellcolor[HTML]{fee491}{\color[HTML]{000000} 0.29}  & \cellcolor[HTML]{fec877}{\color[HTML]{000000} 0.27}  & \cellcolor[HTML]{feda86}{\color[HTML]{000000} 0.28}  & \cellcolor[HTML]{fa9656}{\color[HTML]{000000} 0.23}  & \cellcolor[HTML]{fecc7b}{\color[HTML]{000000} 0.27}  \\
Reds & 5 & \cellcolor[HTML]{ec5c3b}{\color[HTML]{f1f1f1} 0.19}  & \cellcolor[HTML]{fb9d59}{\color[HTML]{000000} 0.24}  & \cellcolor[HTML]{f36b42}{\color[HTML]{f1f1f1} 0.20}  & \cellcolor[HTML]{fca55d}{\color[HTML]{000000} 0.24}  & \cellcolor[HTML]{fa9b58}{\color[HTML]{000000} 0.23}  & \cellcolor[HTML