In [1]:
import json
from typing import Dict, List, Tuple, Set
from pathlib import Path
import logging
import pandas as pd
import numpy as np
import ast
from tqdm import tqdm

import sys
import os

# Put the parent directory of the notebook's working directory on sys.path so
# that the `src` package imported in the next cell can be resolved.
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

# Replace the kernel's command-line arguments with a single empty string.
# NOTE(review): presumably RefinementConfig parses sys.argv and would trip over
# the Jupyter kernel's own arguments - confirm against src.refinement.config.
sys.argv.clear()
sys.argv.append("")

In [2]:
# Imported here rather than in the first cell because it has to run after the
# sys.path setup above (assuming `src` lives in the parent directory).
from src.refinement.config import RefinementConfig
config = RefinementConfig()
# Overrides used by this walkthrough: ranks are capped at
# refinement_metric_maxrank, and only edges whose metric is below
# max_refinement_metric are kept as removal candidates (section 2.2.1).
config.refinement_metric_maxrank = 100
config.max_refinement_metric = 0

In [3]:
# Ids of the two runs being compared; each is used as a directory name below
# config.mlflow_dir. The reference run provides the "_base" data, the
# refinement run the "_comp" data.
refinement_run_id = "403a7e2fd2e8444ab0a084365dae9192"
reference_run_id = "84ebffbcd41f450786e7ae5a6fae8c0b"

In [4]:

# Notebook display settings: never elide columns or rows, and cut the text of
# a cell off after 100 characters.
pd.set_option("display.max_columns", None)
pd.set_option("display.max_rows", None)
pd.set_option("display.max_colwidth", 100)


# 1 Loading

## 1.1 Load attention weights

In [5]:
def load_attention_weights(run_id):
    """Read the attention weights stored as an artifact of a run.

    The file is looked up at ``../<config.mlflow_dir><run_id>/artifacts/attention.json``;
    ``config.mlflow_dir`` is concatenated as-is, so it has to end with a path
    separator.

    Args:
        run_id: Id of the run whose artifact directory is read.

    Returns:
        The value stored under ``"attention_weights"`` in the JSON file, or an
        empty dict when the file does not exist.
    """
    artifact_suffix = "{run_id}/artifacts/attention.json".format(run_id=run_id)
    attention_path = Path("../" + config.mlflow_dir + artifact_suffix)

    if attention_path.exists():
        with open(attention_path) as attention_file:
            payload = json.load(attention_file)
        return payload["attention_weights"]

    # A missing file is not treated as an error; it is only logged.
    logging.debug(
        "No attention file for run {} in local MlFlow dir".format(run_id)
    )
    return {}

In [6]:
# Attention weights of the reference run ("base") and of the refinement run
# ("comp"). As the outputs below show, each maps a feature to a dict of
# {node: weight}, with the weights stored as strings.
attention_base = load_attention_weights(reference_run_id)
attention_comp = load_attention_weights(refinement_run_id)

In [7]:
attention_base

{'996.04': {'996.04': '1.0'},
 '251.5': {'251.5': '1.0'},
 '805.6': {'805.6': '1.0'},
 'V61.42': {'V61.42': '1.0'},
 '600.10': {'600.10': '1.0'},
 '482.0': {'482.0': '1.0'},
 '227.0': {'227.0': '1.0'},
 '718.45': {'718.45': '1.0'},
 '718.56': {'718.56': '1.0'},
 '626.9': {'626.9': '1.0'},
 '008.41': {'008.41': '1.0'},
 '800.15': {'800.15': '1.0'},
 '221.8': {'221.8': '1.0'},
 '761.8': {'761.8': '1.0'},
 '852.40': {'852.40': '1.0'},
 '153.9': {'153.9': '1.0'},
 'E927.8': {'E927.8': '1.0'},
 '070.31': {'070.31': '1.0'},
 '873.49': {'873.49': '1.0'},
 '568.0': {'568.0': '1.0'},
 '553.8': {'553.8': '1.0'},
 '572.0': {'572.0': '1.0'},
 'E959': {'E959': '1.0'},
 '297.9': {'297.9': '1.0'},
 '873.1': {'873.1': '1.0'},
 '944.25': {'944.25': '1.0'},
 '958.93': {'958.93': '1.0'},
 '156.9': {'156.9': '1.0'},
 '227.1': {'227.1': '1.0'},
 '382.01': {'382.01': '1.0'},
 '442.2': {'442.2': '1.0'},
 '751.7': {'751.7': '1.0'},
 '942.14': {'942.14': '1.0'},
 '998.31': {'998.31': '1.0'},
 '459.89': {'459.8

In [8]:
attention_comp

{'996.04': {'996.04': '0.049188036',
  '-1': '9.63784e-06',
  '800-999': '0.111613736',
  '996-999': '0.6663229',
  '996': '0.016291933',
  '996.0': '0.1565738'},
 '251.5': {'251.5': '0.0020265693',
  '-1': '5.203079e-06',
  '249-259': '0.97066355',
  '240-279': '0.027219253',
  '251': '8.542773e-05'},
 '805.6': {'805.6': '0.1731188',
  '-1': '3.53414e-05',
  '800-999': '0.40930614',
  '805-809': '0.12043968',
  '805': '0.29710004'},
 'V61.42': {'V61.42': '0.0027444994',
  '-1': '1.312063e-06',
  'V01-V91': '0.018931095',
  'V60-V69': '0.9666285',
  'V61': '0.009186897',
  'V61.4': '0.0025076328'},
 '600.10': {'600.10': '0.005564471',
  '-1': '0.00012541213',
  '580-629': '0.95750254',
  '600-608': '0.016405165',
  '600': '0.00490035',
  '600.1': '0.0155021055'},
 '482.0': {'482.0': '0.88759786',
  '-1': '5.4857213e-07',
  '460-519': '0.0076541635',
  '480-488': '0.10396203',
  '482': '0.0007853392'},
 '227.0': {'227.0': '0.67153245',
  '-1': '2.6158211e-06',
  '140-239': '0.19584052',

## 1.2 Load input frequency

In [9]:
def load_input_frequency_dict(run_id) -> Dict[str, Dict[str, float]]:
    """Read the per-feature frequency table stored as an artifact of a run.

    The CSV is looked up at
    ``../<config.mlflow_dir><run_id>/artifacts/train_frequency.csv`` and must
    provide the columns ``feature`` and ``absolue_frequency`` (spelled like
    that in the file).

    Args:
        run_id: Id of the run whose artifact directory is read.

    Returns:
        A dict keyed by feature; each value holds the CSV columns of that
        feature plus ``relative_frequency``, the feature's share of the summed
        ``absolue_frequency``. Empty dict when the file does not exist.
    """
    csv_suffix = "{run_id}/artifacts/train_frequency.csv".format(run_id=run_id)
    run_frequency_path = Path("../" + config.mlflow_dir + csv_suffix)

    if run_frequency_path.exists():
        frequency_table = pd.read_csv(run_frequency_path).set_index("feature")
        absolute_counts = frequency_table["absolue_frequency"]
        frequency_table["relative_frequency"] = absolute_counts / sum(absolute_counts)
        return frequency_table.to_dict("index")

    # A missing file is not treated as an error; it is only logged.
    logging.debug("No frequency file for run {} in MlFlow dir".format(run_id))
    return {}

In [10]:
# Per-feature frequencies read from the refinement run's train_frequency.csv.
train_frequency = load_input_frequency_dict(refinement_run_id)
train_frequency

{'996.04': {'absolue_frequency': 15,
  'relative_frequency': 5.318715140254518e-05},
 '251.5': {'absolue_frequency': 2,
  'relative_frequency': 7.091620187006025e-06},
 '805.6': {'absolue_frequency': 12,
  'relative_frequency': 4.2549721122036144e-05},
 'V61.42': {'absolue_frequency': 0, 'relative_frequency': 0.0},
 '600.10': {'absolue_frequency': 1,
  'relative_frequency': 3.5458100935030123e-06},
 '482.0': {'absolue_frequency': 115,
  'relative_frequency': 0.0004077681607528464},
 '227.0': {'absolue_frequency': 15,
  'relative_frequency': 5.318715140254518e-05},
 '718.45': {'absolue_frequency': 0, 'relative_frequency': 0.0},
 '718.56': {'absolue_frequency': 1,
  'relative_frequency': 3.5458100935030123e-06},
 '626.9': {'absolue_frequency': 1,
  'relative_frequency': 3.5458100935030123e-06},
 '008.41': {'absolue_frequency': 1,
  'relative_frequency': 3.5458100935030123e-06},
 '800.15': {'absolue_frequency': 1,
  'relative_frequency': 3.5458100935030123e-06},
 '221.8': {'absolue_freque

## 1.3 Generate comparison_df

In [11]:
def get_best_rank_of(output: str, predictions_str: str) -> int:
    """Return the 0-based rank of ``output`` among the predicted scores.

    Args:
        output: Label whose rank is wanted.
        predictions_str: Python-literal text of a ``{label: score}`` dict.

    Returns:
        The number of labels whose score is strictly greater than the score
        of ``output``. Ties therefore share the best (lowest) rank, and the
        top prediction has rank 0.

    Raises:
        KeyError: If ``output`` is missing from a non-empty prediction dict.
    """
    scores = ast.literal_eval(predictions_str)
    better_count = 0
    for label in scores:
        if scores[label] > scores[output]:
            better_count += 1
    return better_count

In [12]:
def convert_prediction_df(prediction_df: pd.DataFrame) -> pd.DataFrame:
    """Expand raw prediction rows into one row per expected output label.

    The frame passed in is left untouched; all work happens on a copy.

    Args:
        prediction_df: Frame with the string columns ``input`` (literal of a
            dict mapping a key to a list of codes), ``output`` (literal of a
            list of labels) and ``predictions`` (literal of a
            ``{label: score}`` dict).

    Returns:
        A new frame with these columns added or replaced:

        * ``input_converted``: per dict entry (ordered by key) the sorted
          codes joined with ``", "``; the entries joined with ``" -> "``.
        * ``inputs``: all distinct codes of the row, sorted and joined with
          ``","``, followed by a trailing ``","``.
        * ``output``: a single label per row (the parsed list is exploded,
          so the original index values repeat).
        * ``output_rank``: ``get_best_rank_of`` for that label.
    """
    converted = prediction_df.copy()

    # Parse every input string once and reuse the result for both columns.
    code_groups = converted["input"].apply(_sorted_code_groups)
    converted["input_converted"] = code_groups.apply(
        lambda groups: " -> ".join(", ".join(group) for group in groups)
    )
    converted["inputs"] = code_groups.apply(
        lambda groups: ",".join(
            sorted({code for group in groups for code in group})
        )
        + ","
    )

    converted["output"] = converted["output"].apply(ast.literal_eval)
    converted = converted.explode("output")

    # Pair the columns by name instead of indexing a row Series by position
    # (integer keys on a label-indexed Series are deprecated in pandas). A
    # plain list is assigned positionally, which is what the duplicated index
    # produced by explode() needs.
    converted["output_rank"] = [
        get_best_rank_of(output, predictions)
        for output, predictions in zip(
            converted["output"], converted["predictions"]
        )
    ]
    return converted


def _sorted_code_groups(raw_input: str) -> List[List[str]]:
    """Parse an ``input`` literal into string code lists, ordered by dict key."""
    parsed = ast.literal_eval(raw_input)
    return [
        [str(code) for code in sorted(codes)]
        for _, codes in sorted(parsed.items(), key=lambda item: item[0])
    ]

In [13]:
def load_prediction_df(run_id) -> pd.DataFrame:
    """Load and convert the prediction output stored as an artifact of a run.

    The CSV is looked up at
    ``../<config.mlflow_dir><run_id>/artifacts/prediction_output.csv``.

    Args:
        run_id: Id of the run whose artifact directory is read.

    Returns:
        The CSV content passed through ``convert_prediction_df``, or an empty
        DataFrame when the file does not exist.
    """
    csv_suffix = "{run_id}/artifacts/prediction_output.csv".format(run_id=run_id)
    run_prediction_output_path = Path("../" + config.mlflow_dir + csv_suffix)

    if run_prediction_output_path.exists():
        raw_predictions = pd.read_csv(run_prediction_output_path)
        return convert_prediction_df(raw_predictions)

    # A missing file is not treated as an error; it is only logged.
    logging.debug(
        "No prediction output file for run {} in local MlFlow dir".format(
            run_id
        )
    )
    return pd.DataFrame()

### 1.3.1 Load reference dataframe with defined function

In [14]:
# Column suffixes that distinguish the two runs once their frames are merged
# in section 1.3.3.
suffix_base="_base"
suffix_comp="_comp"

# Predictions of the reference run, already expanded to one row per expected
# output by load_prediction_df.
prediction_df_base = load_prediction_df(reference_run_id)
prediction_df_base.head()

Unnamed: 0,input,output,predictions,input_converted,inputs,output_rank
0,"{0: ['008.45', '577.0', '570', '276.7', '584.5', '576.1', '276.1', '998.59', '571.49']}",V40-V49,"{'E820-E825': 1.5604852e-05, 'E850-E858': 0.00044317846, '580-589': 0.03815708, '790-796': 0.011...","008.45, 276.1, 276.7, 570, 571.49, 576.1, 577.0, 584.5, 998.59","008.45,276.1,276.7,570,571.49,576.1,577.0,584.5,998.59,",8
0,"{0: ['008.45', '577.0', '570', '276.7', '584.5', '576.1', '276.1', '998.59', '571.49']}",780-789,"{'E820-E825': 1.5604852e-05, 'E850-E858': 0.00044317846, '580-589': 0.03815708, '790-796': 0.011...","008.45, 276.1, 276.7, 570, 571.49, 576.1, 577.0, 584.5, 998.59","008.45,276.1,276.7,570,571.49,576.1,577.0,584.5,998.59,",2
0,"{0: ['008.45', '577.0', '570', '276.7', '584.5', '576.1', '276.1', '998.59', '571.49']}",430-438,"{'E820-E825': 1.5604852e-05, 'E850-E858': 0.00044317846, '580-589': 0.03815708, '790-796': 0.011...","008.45, 276.1, 276.7, 570, 571.49, 576.1, 577.0, 584.5, 998.59","008.45,276.1,276.7,570,571.49,576.1,577.0,584.5,998.59,",52
0,"{0: ['008.45', '577.0', '570', '276.7', '584.5', '576.1', '276.1', '998.59', '571.49']}",284,"{'E820-E825': 1.5604852e-05, 'E850-E858': 0.00044317846, '580-589': 0.03815708, '790-796': 0.011...","008.45, 276.1, 276.7, 570, 571.49, 576.1, 577.0, 584.5, 998.59","008.45,276.1,276.7,570,571.49,576.1,577.0,584.5,998.59,",45
0,"{0: ['008.45', '577.0', '570', '276.7', '584.5', '576.1', '276.1', '998.59', '571.49']}",340-349,"{'E820-E825': 1.5604852e-05, 'E850-E858': 0.00044317846, '580-589': 0.03815708, '790-796': 0.011...","008.45, 276.1, 276.7, 570, 571.49, 576.1, 577.0, 584.5, 998.59","008.45,276.1,276.7,570,571.49,576.1,577.0,584.5,998.59,",24


### 1.3.2 Load refinement dataframe step by step

In [14]:
# The one-call equivalent of this section would be
#     prediction_df_comp = load_prediction_df(refinement_run_id)
# The following cells perform the same steps one at a time so that every
# intermediate frame can be inspected.

# Read the raw prediction output of the refinement run. Unlike
# load_prediction_df, this does not check that the file exists first.
run_prediction_output_path = Path("../"
    + config.mlflow_dir
    + "{run_id}/artifacts/prediction_output.csv".format(run_id=refinement_run_id)
)
prediction_df_comp = pd.read_csv(run_prediction_output_path)
prediction_df_comp.head()

Unnamed: 0,input,output,predictions
0,"{0: ['008.45', '577.0', '570', '276.7', '584.5...","['V40-V49', '780-789', '430-438', '284', '340-...","{'E820-E825': 7.797591e-06, 'E850-E858': 0.000..."
1,"{0: ['568.0', '263.9', '410.11', '453.8', '997...","['285', '510-519', '780-789', '430-438', 'V50-...","{'E820-E825': 8.152219e-06, 'E850-E858': 8.334..."
2,"{0: ['244.9', '790.01', '600.0', 'V58.61', '39...","['790-796', '030-041', 'V40-V49', '320-327', '...","{'E820-E825': 1.6887843e-05, 'E850-E858': 0.00..."
3,"{0: ['E879.8', '410.12', '427.1', '428.0', '99...","['780-789', '420-429', '270-279', '410-414']","{'E820-E825': 1.5772624e-05, 'E850-E858': 0.00..."
4,"{0: ['E879.8', '410.12', '427.1', '428.0', '99...","['780-789', '420-429', '410-414', '996-999', '...","{'E820-E825': 2.8320835e-06, 'E850-E858': 0.00..."


In [15]:
# Step 1: build a readable form of the input. The `input` string is parsed as
# a dict; its entries are ordered by key, the codes of each entry are sorted
# and joined with ", ", and the entries are joined with " -> ".
prediction_df_comp["input_converted"] = prediction_df_comp["input"].apply(
    lambda x: " -> ".join(
        [
            ", ".join([str(val) for val in sorted(v)])
            for (_, v) in sorted(
                ast.literal_eval(x).items(), key=lambda y: y[0]
            )
        ]
    )
)
prediction_df_comp.head()

Unnamed: 0,input,output,predictions,input_converted
0,"{0: ['008.45', '577.0', '570', '276.7', '584.5...","['V40-V49', '780-789', '430-438', '284', '340-...","{'E820-E825': 7.797591e-06, 'E850-E858': 0.000...","008.45, 276.1, 276.7, 570, 571.49, 576.1, 577...."
1,"{0: ['568.0', '263.9', '410.11', '453.8', '997...","['285', '510-519', '780-789', '430-438', 'V50-...","{'E820-E825': 8.152219e-06, 'E850-E858': 8.334...","153.4, 211.3, 263.9, 410.11, 428.0, 453.8, 568..."
2,"{0: ['244.9', '790.01', '600.0', 'V58.61', '39...","['790-796', '030-041', 'V40-V49', '320-327', '...","{'E820-E825': 1.6887843e-05, 'E850-E858': 0.00...","244.9, 396.3, 401.9, 427.31, 600.0, 780.51, 79..."
3,"{0: ['E879.8', '410.12', '427.1', '428.0', '99...","['780-789', '420-429', '270-279', '410-414']","{'E820-E825': 1.5772624e-05, 'E850-E858': 0.00...","287.5, 410.12, 427.1, 427.31, 428.0, 451.84, 9..."
4,"{0: ['E879.8', '410.12', '427.1', '428.0', '99...","['780-789', '420-429', '410-414', '996-999', '...","{'E820-E825': 2.8320835e-06, 'E850-E858': 0.00...","287.5, 410.12, 427.1, 427.31, 428.0, 451.84, 9..."


In [16]:
# Step 2: flatten all codes of the input into one sorted, duplicate-free,
# comma-separated string. The trailing "," means every code in the string is
# followed by a comma, which the lookups in section 2.2 rely on.
# Note that `x` is reused inside the comprehension: the innermost `x` is a
# single code, while `ast.literal_eval(x)` still sees the lambda's argument.
prediction_df_comp["inputs"] = prediction_df_comp["input"].apply(
    lambda x: ",".join(
        sorted(
            set(
                [
                    x
                    for xs in [
                        [str(val) for val in sorted(v)]
                        for (_, v) in sorted(
                            ast.literal_eval(x).items(), key=lambda y: y[0]
                        )
                    ]
                    for x in xs
                ]
            )
        )
    )
    + ","
)
prediction_df_comp.head()

Unnamed: 0,input,output,predictions,input_converted,inputs
0,"{0: ['008.45', '577.0', '570', '276.7', '584.5...","['V40-V49', '780-789', '430-438', '284', '340-...","{'E820-E825': 7.797591e-06, 'E850-E858': 0.000...","008.45, 276.1, 276.7, 570, 571.49, 576.1, 577....","008.45,276.1,276.7,570,571.49,576.1,577.0,584...."
1,"{0: ['568.0', '263.9', '410.11', '453.8', '997...","['285', '510-519', '780-789', '430-438', 'V50-...","{'E820-E825': 8.152219e-06, 'E850-E858': 8.334...","153.4, 211.3, 263.9, 410.11, 428.0, 453.8, 568...","153.4,211.3,263.9,410.11,428.0,453.8,568.0,997..."
2,"{0: ['244.9', '790.01', '600.0', 'V58.61', '39...","['790-796', '030-041', 'V40-V49', '320-327', '...","{'E820-E825': 1.6887843e-05, 'E850-E858': 0.00...","244.9, 396.3, 401.9, 427.31, 600.0, 780.51, 79...","244.9,396.3,401.9,427.31,600.0,780.51,790.01,V..."
3,"{0: ['E879.8', '410.12', '427.1', '428.0', '99...","['780-789', '420-429', '270-279', '410-414']","{'E820-E825': 1.5772624e-05, 'E850-E858': 0.00...","287.5, 410.12, 427.1, 427.31, 428.0, 451.84, 9...","287.5,410.12,427.1,427.31,428.0,451.84,999.2,E..."
4,"{0: ['E879.8', '410.12', '427.1', '428.0', '99...","['780-789', '420-429', '410-414', '996-999', '...","{'E820-E825': 2.8320835e-06, 'E850-E858': 0.00...","287.5, 410.12, 427.1, 427.31, 428.0, 451.84, 9...","276.8,287.5,410.12,412,414.01,414.8,427.1,427...."


In [17]:
# Step 3: turn the `output` strings into real Python lists so that they can
# be exploded in the next cell. Running this cell a second time fails because
# the column then no longer holds strings.
prediction_df_comp["output"] = prediction_df_comp["output"].apply(
    lambda x: ast.literal_eval(x)
)
prediction_df_comp.head()

Unnamed: 0,input,output,predictions,input_converted,inputs
0,"{0: ['008.45', '577.0', '570', '276.7', '584.5...","[V40-V49, 780-789, 430-438, 284, 340-349]","{'E820-E825': 7.797591e-06, 'E850-E858': 0.000...","008.45, 276.1, 276.7, 570, 571.49, 576.1, 577....","008.45,276.1,276.7,570,571.49,576.1,577.0,584...."
1,"{0: ['568.0', '263.9', '410.11', '453.8', '997...","[285, 510-519, 780-789, 430-438, V50-V59, 401-...","{'E820-E825': 8.152219e-06, 'E850-E858': 8.334...","153.4, 211.3, 263.9, 410.11, 428.0, 453.8, 568...","153.4,211.3,263.9,410.11,428.0,453.8,568.0,997..."
2,"{0: ['244.9', '790.01', '600.0', 'V58.61', '39...","[790-796, 030-041, V40-V49, 320-327, 590-599, ...","{'E820-E825': 1.6887843e-05, 'E850-E858': 0.00...","244.9, 396.3, 401.9, 427.31, 600.0, 780.51, 79...","244.9,396.3,401.9,427.31,600.0,780.51,790.01,V..."
3,"{0: ['E879.8', '410.12', '427.1', '428.0', '99...","[780-789, 420-429, 270-279, 410-414]","{'E820-E825': 1.5772624e-05, 'E850-E858': 0.00...","287.5, 410.12, 427.1, 427.31, 428.0, 451.84, 9...","287.5,410.12,427.1,427.31,428.0,451.84,999.2,E..."
4,"{0: ['E879.8', '410.12', '427.1', '428.0', '99...","[780-789, 420-429, 410-414, 996-999, 249-259]","{'E820-E825': 2.8320835e-06, 'E850-E858': 0.00...","287.5, 410.12, 427.1, 427.31, 428.0, 451.84, 9...","276.8,287.5,410.12,412,414.01,414.8,427.1,427...."


In [18]:
# Step 4: one row per expected output label. explode() repeats the original
# index value for every label of a row, as the output below shows.
prediction_df_comp = prediction_df_comp.explode("output")
prediction_df_comp.head()

Unnamed: 0,input,output,predictions,input_converted,inputs
0,"{0: ['008.45', '577.0', '570', '276.7', '584.5...",V40-V49,"{'E820-E825': 7.797591e-06, 'E850-E858': 0.000...","008.45, 276.1, 276.7, 570, 571.49, 576.1, 577....","008.45,276.1,276.7,570,571.49,576.1,577.0,584...."
0,"{0: ['008.45', '577.0', '570', '276.7', '584.5...",780-789,"{'E820-E825': 7.797591e-06, 'E850-E858': 0.000...","008.45, 276.1, 276.7, 570, 571.49, 576.1, 577....","008.45,276.1,276.7,570,571.49,576.1,577.0,584...."
0,"{0: ['008.45', '577.0', '570', '276.7', '584.5...",430-438,"{'E820-E825': 7.797591e-06, 'E850-E858': 0.000...","008.45, 276.1, 276.7, 570, 571.49, 576.1, 577....","008.45,276.1,276.7,570,571.49,576.1,577.0,584...."
0,"{0: ['008.45', '577.0', '570', '276.7', '584.5...",284,"{'E820-E825': 7.797591e-06, 'E850-E858': 0.000...","008.45, 276.1, 276.7, 570, 571.49, 576.1, 577....","008.45,276.1,276.7,570,571.49,576.1,577.0,584...."
0,"{0: ['008.45', '577.0', '570', '276.7', '584.5...",340-349,"{'E820-E825': 7.797591e-06, 'E850-E858': 0.000...","008.45, 276.1, 276.7, 570, 571.49, 576.1, 577....","008.45,276.1,276.7,570,571.49,576.1,577.0,584...."


In [19]:
# Step 5: rank of every expected output within its row's predictions
# (0 = highest score). The row values are read by column label; reading them
# by integer position from a label-indexed Series is deprecated in pandas.
prediction_df_comp["output_rank"] = prediction_df_comp[["output", "predictions"]].apply(
    lambda row: get_best_rank_of(row["output"], row["predictions"]), axis=1
)
prediction_df_comp.head()

Unnamed: 0,input,output,predictions,input_converted,inputs,output_rank
0,"{0: ['008.45', '577.0', '570', '276.7', '584.5...",V40-V49,"{'E820-E825': 7.797591e-06, 'E850-E858': 0.000...","008.45, 276.1, 276.7, 570, 571.49, 576.1, 577....","008.45,276.1,276.7,570,571.49,576.1,577.0,584....",12
0,"{0: ['008.45', '577.0', '570', '276.7', '584.5...",780-789,"{'E820-E825': 7.797591e-06, 'E850-E858': 0.000...","008.45, 276.1, 276.7, 570, 571.49, 576.1, 577....","008.45,276.1,276.7,570,571.49,576.1,577.0,584....",3
0,"{0: ['008.45', '577.0', '570', '276.7', '584.5...",430-438,"{'E820-E825': 7.797591e-06, 'E850-E858': 0.000...","008.45, 276.1, 276.7, 570, 571.49, 576.1, 577....","008.45,276.1,276.7,570,571.49,576.1,577.0,584....",58
0,"{0: ['008.45', '577.0', '570', '276.7', '584.5...",284,"{'E820-E825': 7.797591e-06, 'E850-E858': 0.000...","008.45, 276.1, 276.7, 570, 571.49, 576.1, 577....","008.45,276.1,276.7,570,571.49,576.1,577.0,584....",39
0,"{0: ['008.45', '577.0', '570', '276.7', '584.5...",340-349,"{'E820-E825': 7.797591e-06, 'E850-E858': 0.000...","008.45, 276.1, 276.7, 570, 571.49, 576.1, 577....","008.45,276.1,276.7,570,571.49,576.1,577.0,584....",34


### 1.3.3 Generate final comparison dataframe

In [20]:
# Pair every (input, output) row of the reference run with the matching row of
# the refinement run. Both frames are sorted the same way and then given a
# fresh positional "index" column (the second reset_index keeps it as a
# column); joining on that column as well lets repeated (input, output)
# combinations pair up one to one instead of multiplying.
# NOTE(review): this is an inner join, so rows are dropped silently if the two
# runs do not line up position by position - worth asserting on the lengths.
comparison_df = pd.merge(
    prediction_df_base.sort_values(by=["input_converted", "inputs", "output"])
    .reset_index(drop=True)
    .reset_index(drop=False),
    prediction_df_comp.sort_values(by=["input_converted", "inputs", "output"])
    .reset_index(drop=True)
    .reset_index(drop=False),
    on=["index", "input_converted", "inputs", "output"],
    suffixes=(suffix_base, suffix_comp),
)
comparison_df.head()

Unnamed: 0,index,input_base,output,predictions_base,input_converted,inputs,output_rank_base,input_comp,predictions_comp,output_rank_comp
0,0,"{0: ['E935.8', '070.44', '338.29', '493.20', '...",030-041,"{'E820-E825': 5.4491316e-05, 'E850-E858': 0.00...","007.4, 070.44, 250.83, 279.06, 338.29, 493.20,...","007.4,070.44,250.83,279.06,338.29,493.20,560.1...",14,"{0: ['E935.8', '070.44', '338.29', '493.20', '...","{'E820-E825': 3.717914e-05, 'E850-E858': 0.003...",14
1,1,"{0: ['E935.8', '070.44', '338.29', '493.20', '...",070-079,"{'E820-E825': 5.4491316e-05, 'E850-E858': 0.00...","007.4, 070.44, 250.83, 279.06, 338.29, 493.20,...","007.4,070.44,250.83,279.06,338.29,493.20,560.1...",23,"{0: ['E935.8', '070.44', '338.29', '493.20', '...","{'E820-E825': 3.717914e-05, 'E850-E858': 0.003...",0
2,2,"{0: ['E935.8', '070.44', '338.29', '493.20', '...",249-259,"{'E820-E825': 5.4491316e-05, 'E850-E858': 0.00...","007.4, 070.44, 250.83, 279.06, 338.29, 493.20,...","007.4,070.44,250.83,279.06,338.29,493.20,560.1...",26,"{0: ['E935.8', '070.44', '338.29', '493.20', '...","{'E820-E825': 3.717914e-05, 'E850-E858': 0.003...",3
3,3,"{0: ['E935.8', '070.44', '338.29', '493.20', '...",270-279,"{'E820-E825': 5.4491316e-05, 'E850-E858': 0.00...","007.4, 070.44, 250.83, 279.06, 338.29, 493.20,...","007.4,070.44,250.83,279.06,338.29,493.20,560.1...",3,"{0: ['E935.8', '070.44', '338.29', '493.20', '...","{'E820-E825': 3.717914e-05, 'E850-E858': 0.003...",4
4,4,"{0: ['E935.8', '070.44', '338.29', '493.20', '...",286,"{'E820-E825': 5.4491316e-05, 'E850-E858': 0.00...","007.4, 070.44, 250.83, 279.06, 338.29, 493.20,...","007.4,070.44,250.83,279.06,338.29,493.20,560.1...",52,"{0: ['E935.8', '070.44', '338.29', '493.20', '...","{'E820-E825': 3.717914e-05, 'E850-E858': 0.003...",40


# 2 Edge comparison

## 2.1 Generate edge set

In [21]:
# Every (child, parent) pair that occurs in a run's attention weights.
edges_base: Set[Tuple[str, str]] = {
    (child, parent)
    for child, parents in attention_base.items()
    for parent in parents
}
edges_comp: Set[Tuple[str, str]] = {
    (child, parent)
    for child, parents in attention_comp.items()
    for parent in parents
}

# Edges that exist only in the refinement run.
added_edges = edges_comp - edges_base

In [22]:
added_edges

{('455.0', '455'),
 ('E890.2', 'E890'),
 ('799.3', '799'),
 ('934.8', '930-939'),
 ('E937.9', '-1'),
 ('018.90', '018.9'),
 ('198.2', '198'),
 ('900.03', '900.0'),
 ('305.70', '305.7'),
 ('736.79', '-1'),
 ('806.09', '806.0'),
 ('556.9', '520-579'),
 ('V45.72', 'V01-V91'),
 ('112.3', '110-118'),
 ('528.3', '520-579'),
 ('824.0', '800-999'),
 ('282.9', '-1'),
 ('293.83', '293.8'),
 ('305.53', '300-316'),
 ('724.03', '724.0'),
 ('922.31', '920-924'),
 ('996.73', '-1'),
 ('909.9', '909'),
 ('847.2', '840-848'),
 ('357.89', '320-389'),
 ('V15.84', 'V15.8'),
 ('903.1', '903'),
 ('574.80', '520-579'),
 ('E960.1', 'E960-E969'),
 ('398.91', '390-459'),
 ('V16.59', 'V10-V19'),
 ('596.51', '596.5'),
 ('782.3', '780-789'),
 ('288.60', '-1'),
 ('237.6', '235-238'),
 ('783.1', '783'),
 ('V60.4', 'V60-V69'),
 ('575.6', '520-579'),
 ('E922.9', 'E916-E928'),
 ('564.00', '564.0'),
 ('342.11', '-1'),
 ('781.94', '780-789'),
 ('300.01', '300.0'),
 ('433.00', '-1'),
 ('952.03', '-1'),
 ('695.9', '695'),
 

## 2.2 Calculate edge comparison

### 2.2.1 Take one edge as example

In [75]:
# Walk through the metric calculation for a single edge. Edges that share a
# child node make it easy to see how the calculation differs between them.
#
# Two edges with the common child '719.02'; the first one is to be removed:
#     ('719.02', '710-739')   refinement metric is -4.949747
#     ('719.02', '710-719')
#
# Two more edges with a common child, neither of which is to be removed:
#     ('333.1', '320-389')
#     ('333.1', '330-337')

record = []
c = '719.02'
p = '710-739'

# Keep the rows whose input contains feature `c`. `inputs` is a comma-separated
# list, so compare against its individual entries: a substring test for
# c + "," would also accept every longer code that merely ends with `c`
# (for example "927.8," is contained in "E927.8,").
# NOTE(review): calculate_refinement_metric uses the substring form and needs
# the same change to stay in step with this walkthrough.
relevant_df_one = comparison_df[
    comparison_df["inputs"].apply(lambda x: c in x.split(","))
]
relevant_df_one

Unnamed: 0,index,input_base,output,predictions_base,input_converted,inputs,output_rank_base,input_comp,predictions_comp,output_rank_comp
1951,1951,"{0: ['255.4', '719.02', '276.51', '041.4', '518.0', 'E933.1', '300.00', '191.3', '038.9', '995.9...",190-199,"{'E820-E825': 5.3918633e-05, 'E850-E858': 0.0012692753, '580-589': 0.017867945, '790-796': 0.011...","038.9, 041.4, 191.3, 255.4, 276.5, 276.51, 285.9, 288.0, 300.00, 486, 518.0, 599.0, 682.3, 719.0...","038.9,041.4,191.3,255.4,276.5,276.51,285.9,288.0,300.00,486,518.0,599.0,682.3,719.02,995.91,E849...",25,"{0: ['255.4', '719.02', '276.51', '041.4', '518.0', 'E933.1', '300.00', '191.3', '038.9', '995.9...","{'E820-E825': 1.2194883e-05, 'E850-E858': 0.00184809, '580-589': 0.020954793, '790-796': 0.01113...",33
1952,1952,"{0: ['255.4', '719.02', '276.51', '041.4', '518.0', 'E933.1', '300.00', '191.3', '038.9', '995.9...",300-316,"{'E820-E825': 5.3918633e-05, 'E850-E858': 0.0012692753, '580-589': 0.017867945, '790-796': 0.011...","038.9, 041.4, 191.3, 255.4, 276.5, 276.51, 285.9, 288.0, 300.00, 486, 518.0, 599.0, 682.3, 719.0...","038.9,041.4,191.3,255.4,276.5,276.51,285.9,288.0,300.00,486,518.0,599.0,682.3,719.02,995.91,E849...",1,"{0: ['255.4', '719.02', '276.51', '041.4', '518.0', 'E933.1', '300.00', '191.3', '038.9', '995.9...","{'E820-E825': 1.2194883e-05, 'E850-E858': 0.00184809, '580-589': 0.020954793, '790-796': 0.01113...",1
1953,1953,"{0: ['255.4', '719.02', '276.51', '041.4', '518.0', 'E933.1', '300.00', '191.3', '038.9', '995.9...",340-349,"{'E820-E825': 5.3918633e-05, 'E850-E858': 0.0012692753, '580-589': 0.017867945, '790-796': 0.011...","038.9, 041.4, 191.3, 255.4, 276.5, 276.51, 285.9, 288.0, 300.00, 486, 518.0, 599.0, 682.3, 719.0...","038.9,041.4,191.3,255.4,276.5,276.51,285.9,288.0,300.00,486,518.0,599.0,682.3,719.02,995.91,E849...",5,"{0: ['255.4', '719.02', '276.51', '041.4', '518.0', 'E933.1', '300.00', '191.3', '038.9', '995.9...","{'E820-E825': 1.2194883e-05, 'E850-E858': 0.00184809, '580-589': 0.020954793, '790-796': 0.01113...",15
1954,1954,"{0: ['255.4', '719.02', '276.51', '041.4', '518.0', 'E933.1', '300.00', '191.3', '038.9', '995.9...",E846-E849,"{'E820-E825': 5.3918633e-05, 'E850-E858': 0.0012692753, '580-589': 0.017867945, '790-796': 0.011...","038.9, 041.4, 191.3, 255.4, 276.5, 276.51, 285.9, 288.0, 300.00, 486, 518.0, 599.0, 682.3, 719.0...","038.9,041.4,191.3,255.4,276.5,276.51,285.9,288.0,300.00,486,518.0,599.0,682.3,719.02,995.91,E849...",51,"{0: ['255.4', '719.02', '276.51', '041.4', '518.0', 'E933.1', '300.00', '191.3', '038.9', '995.9...","{'E820-E825': 1.2194883e-05, 'E850-E858': 0.00184809, '580-589': 0.020954793, '790-796': 0.01113...",62
1955,1955,"{0: ['255.4', '719.02', '276.51', '041.4', '518.0', 'E933.1', '300.00', '191.3', '038.9', '995.9...",E878-E879,"{'E820-E825': 5.3918633e-05, 'E850-E858': 0.0012692753, '580-589': 0.017867945, '790-796': 0.011...","038.9, 041.4, 191.3, 255.4, 276.5, 276.51, 285.9, 288.0, 300.00, 486, 518.0, 599.0, 682.3, 719.0...","038.9,041.4,191.3,255.4,276.5,276.51,285.9,288.0,300.00,486,518.0,599.0,682.3,719.02,995.91,E849...",31,"{0: ['255.4', '719.02', '276.51', '041.4', '518.0', 'E933.1', '300.00', '191.3', '038.9', '995.9...","{'E820-E825': 1.2194883e-05, 'E850-E858': 0.00184809, '580-589': 0.020954793, '790-796': 0.01113...",37


In [76]:
# All attention weights of child `c` in the refinement run ({} if `c` is
# unknown there).
edge_weight_set = attention_comp.get(c, {})
edge_weight_set

{'719.02': '0.0058925725',
 '-1': '7.832021e-06',
 '710-739': '0.9553721',
 '710-719': '1.5992322e-05',
 '719': '0.03817658',
 '719.0': '0.00053494703'}

In [77]:
# Weight of the edge (c, p). Note the type mismatch: stored weights are
# strings (see the output), whereas the fallback for a missing edge is the
# int -1.
edge_weight = edge_weight_set.get(p, -1)
edge_weight

'0.9553721'

In [78]:
# Frequency entry of child `c` ({} if `c` is not in the frequency table).
frequency_set = train_frequency.get(c, {})
frequency_set

{'absolue_frequency': 1, 'relative_frequency': 3.5458100935030123e-06}

In [79]:
# Absolute frequency of `c`; "absolue_frequency" is the key exactly as it is
# spelled in the loaded table. Falls back to 0.0 when the entry is missing.
frequency = frequency_set.get("absolue_frequency", 0.0)
frequency

1

In [80]:
# relevant_df_one was already filtered for feature `c`, so this repeats the
# selection; what matters here is the .copy(), which gives the following cells
# a frame of their own to modify.
relevant_df_2 = relevant_df_one[
    relevant_df_one["inputs"].apply(lambda x: c + "," in x)
].copy()    

# Only reports the empty case; the cells below still run and the accuracy
# cells would then divide by zero.
if len(relevant_df_2) == 0:
    print("no relevant edges")

relevant_df_2

Unnamed: 0,index,input_base,output,predictions_base,input_converted,inputs,output_rank_base,input_comp,predictions_comp,output_rank_comp
1951,1951,"{0: ['255.4', '719.02', '276.51', '041.4', '518.0', 'E933.1', '300.00', '191.3', '038.9', '995.9...",190-199,"{'E820-E825': 5.3918633e-05, 'E850-E858': 0.0012692753, '580-589': 0.017867945, '790-796': 0.011...","038.9, 041.4, 191.3, 255.4, 276.5, 276.51, 285.9, 288.0, 300.00, 486, 518.0, 599.0, 682.3, 719.0...","038.9,041.4,191.3,255.4,276.5,276.51,285.9,288.0,300.00,486,518.0,599.0,682.3,719.02,995.91,E849...",25,"{0: ['255.4', '719.02', '276.51', '041.4', '518.0', 'E933.1', '300.00', '191.3', '038.9', '995.9...","{'E820-E825': 1.2194883e-05, 'E850-E858': 0.00184809, '580-589': 0.020954793, '790-796': 0.01113...",33
1952,1952,"{0: ['255.4', '719.02', '276.51', '041.4', '518.0', 'E933.1', '300.00', '191.3', '038.9', '995.9...",300-316,"{'E820-E825': 5.3918633e-05, 'E850-E858': 0.0012692753, '580-589': 0.017867945, '790-796': 0.011...","038.9, 041.4, 191.3, 255.4, 276.5, 276.51, 285.9, 288.0, 300.00, 486, 518.0, 599.0, 682.3, 719.0...","038.9,041.4,191.3,255.4,276.5,276.51,285.9,288.0,300.00,486,518.0,599.0,682.3,719.02,995.91,E849...",1,"{0: ['255.4', '719.02', '276.51', '041.4', '518.0', 'E933.1', '300.00', '191.3', '038.9', '995.9...","{'E820-E825': 1.2194883e-05, 'E850-E858': 0.00184809, '580-589': 0.020954793, '790-796': 0.01113...",1
1953,1953,"{0: ['255.4', '719.02', '276.51', '041.4', '518.0', 'E933.1', '300.00', '191.3', '038.9', '995.9...",340-349,"{'E820-E825': 5.3918633e-05, 'E850-E858': 0.0012692753, '580-589': 0.017867945, '790-796': 0.011...","038.9, 041.4, 191.3, 255.4, 276.5, 276.51, 285.9, 288.0, 300.00, 486, 518.0, 599.0, 682.3, 719.0...","038.9,041.4,191.3,255.4,276.5,276.51,285.9,288.0,300.00,486,518.0,599.0,682.3,719.02,995.91,E849...",5,"{0: ['255.4', '719.02', '276.51', '041.4', '518.0', 'E933.1', '300.00', '191.3', '038.9', '995.9...","{'E820-E825': 1.2194883e-05, 'E850-E858': 0.00184809, '580-589': 0.020954793, '790-796': 0.01113...",15
1954,1954,"{0: ['255.4', '719.02', '276.51', '041.4', '518.0', 'E933.1', '300.00', '191.3', '038.9', '995.9...",E846-E849,"{'E820-E825': 5.3918633e-05, 'E850-E858': 0.0012692753, '580-589': 0.017867945, '790-796': 0.011...","038.9, 041.4, 191.3, 255.4, 276.5, 276.51, 285.9, 288.0, 300.00, 486, 518.0, 599.0, 682.3, 719.0...","038.9,041.4,191.3,255.4,276.5,276.51,285.9,288.0,300.00,486,518.0,599.0,682.3,719.02,995.91,E849...",51,"{0: ['255.4', '719.02', '276.51', '041.4', '518.0', 'E933.1', '300.00', '191.3', '038.9', '995.9...","{'E820-E825': 1.2194883e-05, 'E850-E858': 0.00184809, '580-589': 0.020954793, '790-796': 0.01113...",62
1955,1955,"{0: ['255.4', '719.02', '276.51', '041.4', '518.0', 'E933.1', '300.00', '191.3', '038.9', '995.9...",E878-E879,"{'E820-E825': 5.3918633e-05, 'E850-E858': 0.0012692753, '580-589': 0.017867945, '790-796': 0.011...","038.9, 041.4, 191.3, 255.4, 276.5, 276.51, 285.9, 288.0, 300.00, 486, 518.0, 599.0, 682.3, 719.0...","038.9,041.4,191.3,255.4,276.5,276.51,285.9,288.0,300.00,486,518.0,599.0,682.3,719.02,995.91,E849...",31,"{0: ['255.4', '719.02', '276.51', '041.4', '518.0', 'E933.1', '300.00', '191.3', '038.9', '995.9...","{'E820-E825': 1.2194883e-05, 'E850-E858': 0.00184809, '580-589': 0.020954793, '790-796': 0.01113...",37


In [81]:
# Cap both rank columns at config.refinement_metric_maxrank so that very poor
# ranks cannot dominate the metric.
relevant_df_2["output_rank_base"] = relevant_df_2["output_rank_base"].clip(
    upper=config.refinement_metric_maxrank
)
relevant_df_2["output_rank_comp"] = relevant_df_2["output_rank_comp"].clip(
    upper=config.refinement_metric_maxrank
)

relevant_df_2

Unnamed: 0,index,input_base,output,predictions_base,input_converted,inputs,output_rank_base,input_comp,predictions_comp,output_rank_comp
1951,1951,"{0: ['255.4', '719.02', '276.51', '041.4', '518.0', 'E933.1', '300.00', '191.3', '038.9', '995.9...",190-199,"{'E820-E825': 5.3918633e-05, 'E850-E858': 0.0012692753, '580-589': 0.017867945, '790-796': 0.011...","038.9, 041.4, 191.3, 255.4, 276.5, 276.51, 285.9, 288.0, 300.00, 486, 518.0, 599.0, 682.3, 719.0...","038.9,041.4,191.3,255.4,276.5,276.51,285.9,288.0,300.00,486,518.0,599.0,682.3,719.02,995.91,E849...",25,"{0: ['255.4', '719.02', '276.51', '041.4', '518.0', 'E933.1', '300.00', '191.3', '038.9', '995.9...","{'E820-E825': 1.2194883e-05, 'E850-E858': 0.00184809, '580-589': 0.020954793, '790-796': 0.01113...",33
1952,1952,"{0: ['255.4', '719.02', '276.51', '041.4', '518.0', 'E933.1', '300.00', '191.3', '038.9', '995.9...",300-316,"{'E820-E825': 5.3918633e-05, 'E850-E858': 0.0012692753, '580-589': 0.017867945, '790-796': 0.011...","038.9, 041.4, 191.3, 255.4, 276.5, 276.51, 285.9, 288.0, 300.00, 486, 518.0, 599.0, 682.3, 719.0...","038.9,041.4,191.3,255.4,276.5,276.51,285.9,288.0,300.00,486,518.0,599.0,682.3,719.02,995.91,E849...",1,"{0: ['255.4', '719.02', '276.51', '041.4', '518.0', 'E933.1', '300.00', '191.3', '038.9', '995.9...","{'E820-E825': 1.2194883e-05, 'E850-E858': 0.00184809, '580-589': 0.020954793, '790-796': 0.01113...",1
1953,1953,"{0: ['255.4', '719.02', '276.51', '041.4', '518.0', 'E933.1', '300.00', '191.3', '038.9', '995.9...",340-349,"{'E820-E825': 5.3918633e-05, 'E850-E858': 0.0012692753, '580-589': 0.017867945, '790-796': 0.011...","038.9, 041.4, 191.3, 255.4, 276.5, 276.51, 285.9, 288.0, 300.00, 486, 518.0, 599.0, 682.3, 719.0...","038.9,041.4,191.3,255.4,276.5,276.51,285.9,288.0,300.00,486,518.0,599.0,682.3,719.02,995.91,E849...",5,"{0: ['255.4', '719.02', '276.51', '041.4', '518.0', 'E933.1', '300.00', '191.3', '038.9', '995.9...","{'E820-E825': 1.2194883e-05, 'E850-E858': 0.00184809, '580-589': 0.020954793, '790-796': 0.01113...",15
1954,1954,"{0: ['255.4', '719.02', '276.51', '041.4', '518.0', 'E933.1', '300.00', '191.3', '038.9', '995.9...",E846-E849,"{'E820-E825': 5.3918633e-05, 'E850-E858': 0.0012692753, '580-589': 0.017867945, '790-796': 0.011...","038.9, 041.4, 191.3, 255.4, 276.5, 276.51, 285.9, 288.0, 300.00, 486, 518.0, 599.0, 682.3, 719.0...","038.9,041.4,191.3,255.4,276.5,276.51,285.9,288.0,300.00,486,518.0,599.0,682.3,719.02,995.91,E849...",51,"{0: ['255.4', '719.02', '276.51', '041.4', '518.0', 'E933.1', '300.00', '191.3', '038.9', '995.9...","{'E820-E825': 1.2194883e-05, 'E850-E858': 0.00184809, '580-589': 0.020954793, '790-796': 0.01113...",62
1955,1955,"{0: ['255.4', '719.02', '276.51', '041.4', '518.0', 'E933.1', '300.00', '191.3', '038.9', '995.9...",E878-E879,"{'E820-E825': 5.3918633e-05, 'E850-E858': 0.0012692753, '580-589': 0.017867945, '790-796': 0.011...","038.9, 041.4, 191.3, 255.4, 276.5, 276.51, 285.9, 288.0, 300.00, 486, 518.0, 599.0, 682.3, 719.0...","038.9,041.4,191.3,255.4,276.5,276.51,285.9,288.0,300.00,486,518.0,599.0,682.3,719.02,995.91,E849...",31,"{0: ['255.4', '719.02', '276.51', '041.4', '518.0', 'E933.1', '300.00', '191.3', '038.9', '995.9...","{'E820-E825': 1.2194883e-05, 'E850-E858': 0.00184809, '580-589': 0.020954793, '790-796': 0.01113...",37


#### 2.2.1.1 use mean_outlier_score as refinement metric

In [82]:
# Per-row score (rank_base - rank_comp) / sqrt(2): negative when the expected
# output is ranked worse (higher) by the refinement run than by the reference
# run. Computed column-wise by label, which avoids the deprecated positional
# access on a label-indexed row Series and also copes with an empty frame.
# `outlier_score` is only defined when the configured metric asks for it; the
# following cells depend on it.
if "outlier_score" in config.refinement_metric:
    outlier_score = (
        (
            relevant_df_2["output_rank_base"].astype(int)
            - relevant_df_2["output_rank_comp"].astype(int)
        )
        / np.sqrt(2)
    ).to_list()

In [83]:
print(*outlier_score)

-5.65685424949238 0.0 -7.071067811865475 -7.7781745930520225 -4.242640687119285


In [84]:
# Aggregate the per-row scores into one number for the edge, using whichever
# aggregation the configured metric name mentions ("median" is checked first).
# re_metric stays undefined if the name contains neither word.
if "median" in config.refinement_metric:
    re_metric =  np.median(outlier_score)
elif "mean" in config.refinement_metric:
    re_metric = np.mean(outlier_score)

# re_metric is a float, so it would need str() before being concatenated into
# a message; displaying the bare value is enough here.
re_metric

-4.949747468305832

#### 2.2.1.2 use accuracy as refinement metric

In [85]:
# Walkthrough of the accuracy-based metric: a hard-coded metric name is used
# here instead of config.refinement_metric to show how the cutoff is parsed.
config_str = "top_5_categorical_accuracy"

# Keep every purely numeric "_"-separated token of the name, converted to int.
accuracy_atss = list(map(int, filter(str.isdigit, config_str.split("_"))))
accuracy_atss


[5]

In [86]:
# First parsed number is the cutoff; default to 1 when the name has none.
accuracy_ats = accuracy_atss[0] if accuracy_atss else 1
accuracy_ats

5

In [87]:
# Fraction of relevant samples whose base-run output rank is below the cutoff.
accuracy_base = int(
    (relevant_df_2["output_rank_base"] < accuracy_ats).sum()
) / len(relevant_df_2)
accuracy_base

0.2

In [88]:
# Fraction of relevant samples whose comparison-run output rank is below the
# cutoff.
accuracy_comp = int(
    (relevant_df_2["output_rank_comp"] < accuracy_ats).sum()
) / len(relevant_df_2)
accuracy_comp

0.2

In [89]:
# Accuracy-based refinement metric: comparison accuracy minus base accuracy,
# so a positive value means the comparison run has the higher accuracy.
re_metric_2 = accuracy_comp - accuracy_base

re_metric_2

0.0

#### 2.2.1.3 add removed edge to set

In [90]:
# Store the examined edge together with its outlier-score based metric.
# NOTE(review): re-running this cell appends the same entry again.
record.append(dict(child=c, parent=p, refinement_metric=re_metric))

record

[{'child': '719.02',
  'parent': '710-739',
  'refinement_metric': -4.949747468305832}]

In [91]:
# Tabulate the collected edge records, one row per (child, parent) edge.
edge_comparison_df = pd.DataFrame.from_records(
    data=record,
    columns=["child", "parent", "refinement_metric"],
)
edge_comparison_df

Unnamed: 0,child,parent,refinement_metric
0,719.02,710-739,-4.949747


In [93]:
# Keep edges whose metric is below the configured threshold, lowest metric
# first, limited to the configured maximum number of edges to remove.
edge_comparison_df = (
    edge_comparison_df.loc[
        lambda df: df["refinement_metric"] < config.max_refinement_metric
    ]
    .sort_values(by="refinement_metric", ascending=True)
    .head(n=config.max_edges_to_remove)
)

edge_comparison_df

Unnamed: 0,child,parent,refinement_metric
0,719.02,710-739,-4.949747


### 2.2.2 Complete dataset

In [None]:
def calculate_refinement_metric(
    input_feature: str, comparison_df: pd.DataFrame, refinement_config=None
) -> float:
    """Compute the refinement metric of one input feature.

    A positive result means the comparison run is better than the base run.

    Args:
        input_feature: Feature whose samples are evaluated. A row is relevant
            when ``input_feature + ","`` is a substring of its ``"inputs"``
            value.
        comparison_df: Frame with the columns ``"inputs"``,
            ``"output_rank_base"`` and ``"output_rank_comp"``.
        refinement_config: Optional object providing ``refinement_metric`` and
            ``refinement_metric_maxrank``. Defaults to the notebook-level
            ``config``.

    Returns:
        * metric name contains ``"outlier_score"``: median or mean (per the
          metric name) of ``(rank_base - rank_comp) / sqrt(2)`` over the
          relevant rows;
        * metric name contains ``"accuracy"``: comparison accuracy minus base
          accuracy, where a row counts as a hit when its rank is below the
          first number in the metric name (1 if the name has none);
        * ``-1`` when no row is relevant or the metric name is not recognised.
    """
    # This is a free function, so the former `self.config` lookups raised a
    # NameError; use the explicit argument or the notebook-level config.
    cfg = config if refinement_config is None else refinement_config

    # NOTE(review): substring matching needs a "," directly after the feature,
    # so a feature that is the last entry of an un-terminated list is missed
    # -- confirm the format of the "inputs" column.
    relevant_df = comparison_df[
        comparison_df["inputs"].apply(lambda x: input_feature + "," in x)
    ].copy()

    if len(relevant_df) == 0:
        return -1

    if cfg.refinement_metric_maxrank > 0:
        # Cap both rank columns at the configured maximum rank.
        for rank_column in ("output_rank_base", "output_rank_comp"):
            relevant_df[rank_column] = relevant_df[rank_column].clip(
                upper=cfg.refinement_metric_maxrank
            )

    if "outlier_score" in cfg.refinement_metric:
        # Label-based, vectorised replacement for the row-wise x[0] - x[1]
        # lambda (positional access on a label-indexed Series is deprecated).
        rank_delta = relevant_df["output_rank_base"].astype(int) - relevant_df[
            "output_rank_comp"
        ].astype(int)
        outlier_scores = (rank_delta / np.sqrt(2)).to_list()
        if "median" in cfg.refinement_metric:
            return np.median(outlier_scores)
        elif "mean" in cfg.refinement_metric:
            return np.mean(outlier_scores)
    elif "accuracy" in cfg.refinement_metric:
        # The first numeric token of the metric name is the rank cutoff.
        accuracy_ats = [
            int(s) for s in cfg.refinement_metric.split("_") if s.isdigit()
        ]
        accuracy_at = accuracy_ats[0] if len(accuracy_ats) > 0 else 1
        accuracy_base = len(
            relevant_df[relevant_df["output_rank_base"] < accuracy_at]
        ) / len(relevant_df)
        accuracy_comp = len(
            relevant_df[relevant_df["output_rank_comp"] < accuracy_at]
        ) / len(relevant_df)
        return accuracy_comp - accuracy_base

    logging.error("Unknown refinement metric: %s", cfg.refinement_metric)
    return -1

In [None]:
# Compute the refinement metric for every added edge that passes the filters.
records = []

for c, p in tqdm(added_edges):
    # Self-edges are skipped.
    if c == p:
        continue

    # Samples whose input list contains the child feature.
    relevant_df = comparison_df[
        comparison_df["inputs"].apply(lambda x: c + "," in x)
    ]

    if len(relevant_df) == 0:
        continue

    # Skip edges whose attention weight is below the configured minimum;
    # edges without a stored weight get -1.
    edge_weight = attention_comp.get(c, {}).get(p, -1)
    if float(edge_weight) < config.min_edge_weight:
        continue

    # NOTE(review): the key "absolue_frequency" looks like a typo of
    # "absolute_frequency" -- confirm against the producer of train_frequency
    # before changing it.
    frequency = train_frequency.get(c, {}).get("absolue_frequency", 0.0)
    # `self.config` does not exist at notebook level (NameError); use the
    # notebook-level `config` like the other thresholds in this cell.
    if frequency > config.max_train_examples:
        continue

    records.append(
        {
            "child": c,
            "parent": p,
            "refinement_metric": calculate_refinement_metric(
                c, relevant_df
            ),
        }
    )


In [None]:
# Tabulate the per-edge metrics collected for the complete dataset.
edge_comparison_df = pd.DataFrame.from_records(
    data=records,
    columns=["child", "parent", "refinement_metric"],
)
edge_comparison_df.head()

In [None]:
# Keep edges whose metric is below the configured threshold, lowest metric
# first, limited to the configured maximum number of edges to remove.
# `self.config` does not exist at notebook level (NameError); the
# notebook-level `config` is used for both settings.
edge_comparison_df = (
    edge_comparison_df[
        edge_comparison_df["refinement_metric"]
        < config.max_refinement_metric
    ]
    .sort_values(by="refinement_metric", ascending=True)
    .head(n=config.max_edges_to_remove)
)

# 3 Refined edges

In [94]:
# Seed the refined knowledge: every child starts with itself as only entry.
refined_knowledge = dict((code, [code]) for code in attention_comp)
refined_knowledge

{'996.04': ['996.04'],
 '251.5': ['251.5'],
 '805.6': ['805.6'],
 'V61.42': ['V61.42'],
 '600.10': ['600.10'],
 '482.0': ['482.0'],
 '227.0': ['227.0'],
 '718.45': ['718.45'],
 '718.56': ['718.56'],
 '626.9': ['626.9'],
 '008.41': ['008.41'],
 '800.15': ['800.15'],
 '221.8': ['221.8'],
 '761.8': ['761.8'],
 '852.40': ['852.40'],
 '153.9': ['153.9'],
 'E927.8': ['E927.8'],
 '070.31': ['070.31'],
 '873.49': ['873.49'],
 '568.0': ['568.0'],
 '553.8': ['553.8'],
 '572.0': ['572.0'],
 'E959': ['E959'],
 '297.9': ['297.9'],
 '873.1': ['873.1'],
 '944.25': ['944.25'],
 '958.93': ['958.93'],
 '156.9': ['156.9'],
 '227.1': ['227.1'],
 '382.01': ['382.01'],
 '442.2': ['442.2'],
 '751.7': ['751.7'],
 '942.14': ['942.14'],
 '998.31': ['998.31'],
 '459.89': ['459.89'],
 '769': ['769'],
 '369.8': ['369.8'],
 '141.8': ['141.8'],
 '596.54': ['596.54'],
 '727.43': ['727.43'],
 '825.35': ['825.35'],
 '377.41': ['377.41'],
 '378.54': ['378.54'],
 '309.28': ['309.28'],
 '770.9': ['770.9'],
 'E870.2': ['E8

In [95]:
# (child, parent) pairs selected for removal. A set lookup replaces the
# previous full DataFrame scan per edge.
removed_edges = set(
    zip(edge_comparison_df["child"], edge_comparison_df["parent"])
)

# Each child keeps itself plus every parent whose edge was not removed.
# The mapping is rebuilt from attention_comp instead of appended to, so
# re-running this cell no longer duplicates parents.
refined_knowledge = {
    child: [child]
    + [parent for parent in parents if (child, parent) not in removed_edges]
    for child, parents in attention_comp.items()
}

refined_knowledge

{'996.04': ['996.04', '996.04', '-1', '800-999', '996-999', '996', '996.0'],
 '251.5': ['251.5', '251.5', '-1', '249-259', '240-279', '251'],
 '805.6': ['805.6', '805.6', '-1', '800-999', '805-809', '805'],
 'V61.42': ['V61.42', 'V61.42', '-1', 'V01-V91', 'V60-V69', 'V61', 'V61.4'],
 '600.10': ['600.10', '600.10', '-1', '580-629', '600-608', '600', '600.1'],
 '482.0': ['482.0', '482.0', '-1', '460-519', '480-488', '482'],
 '227.0': ['227.0', '227.0', '-1', '140-239', '210-229', '227'],
 '718.45': ['718.45', '718.45', '-1', '710-739', '710-719', '718', '718.4'],
 '718.56': ['718.56', '718.56', '-1', '710-739', '710-719', '718', '718.5'],
 '626.9': ['626.9', '626.9', '-1', '580-629', '617-629', '626'],
 '008.41': ['008.41', '008.41', '-1', '001-139', '001-009', '008', '008.4'],
 '800.15': ['800.15', '800.15', '-1', '800-999', '800-804', '800', '800.1'],
 '221.8': ['221.8', '221.8', '-1', '140-239', '210-229', '221'],
 '761.8': ['761.8', '761.8', '-1', '760-779', '760-763', '761'],
 '852.

In [96]:
def convert_to_node_mapping(
    all_nodes: List[str], use_node_mapping: bool = True
) -> Dict[str, str]:
    """Map every distinct node name to a feature label.

    Args:
        all_nodes: Node ids; each is normalised with ``_node_name`` and
            duplicates are dropped.
        use_node_mapping: When True, names map to ``"feature<idx>"``;
            otherwise every name maps to itself.

    Returns:
        Dict from normalised node name to its mapped label.
    """
    # Sort the unique names so index assignment is reproducible: iterating a
    # plain set of strings can yield a different order in every interpreter
    # session, which made the "feature<idx>" labels unstable.
    node_names = sorted({_node_name(x) for x in all_nodes}, key=str)
    if use_node_mapping:
        return {
            name: "feature" + str(idx) for idx, name in enumerate(node_names)
        }
    return {name: name for name in node_names}

In [98]:
def _node_name(node_id: str) -> str:
    if not str(node_id).isdigit():
        return node_id

    return "#" + str(node_id)

In [99]:
# With the mapping disabled, every node name maps to itself.
use_node_mapping = False

feature_node_mapping = convert_to_node_mapping(
    list(attention_comp), use_node_mapping
)
feature_node_mapping

{'752.49': '752.49',
 '287.1': '287.1',
 '200.30': '200.30',
 'V43.65': 'V43.65',
 '141.0': '141.0',
 '289.3': '289.3',
 '786.2': '786.2',
 '480.2': '480.2',
 '018.96': '018.96',
 'V58.11': 'V58.11',
 '682.3': '682.3',
 '250.72': '250.72',
 '259.8': '259.8',
 '473.9': '473.9',
 '820.21': '820.21',
 '616.89': '616.89',
 '453.87': '453.87',
 '209.69': '209.69',
 '377.01': '377.01',
 '444.21': '444.21',
 '537.82': '537.82',
 '781.1': '781.1',
 '249.00': '249.00',
 '453.74': '453.74',
 '861.31': '861.31',
 '152.2': '152.2',
 '786.01': '786.01',
 'V45.09': 'V45.09',
 '996.77': '996.77',
 '361.07': '361.07',
 '996.57': '996.57',
 '274.9': '274.9',
 '369.8': '369.8',
 '258.1': '258.1',
 '360.19': '360.19',
 '457.0': '457.0',
 '765.13': '765.13',
 '380.22': '380.22',
 '427.0': '427.0',
 '786.05': '786.05',
 'E823.3': 'E823.3',
 '995.61': '995.61',
 '922.0': '922.0',
 '801.06': '801.06',
 '569.5': '569.5',
 '592.0': '592.0',
 '150.0': '150.0',
 '996.62': '996.62',
 '438.12': '438.12',
 '708.9':

In [101]:
# All (child, parent) pairs present in the refined knowledge.
reference_connections = {
    (node, linked) for node, links in refined_knowledge.items() for linked in links
}
reference_connections

{('455.0', '455'),
 ('E937.9', '-1'),
 ('900.03', '900.0'),
 ('850.2', '850.2'),
 ('305.70', '305.7'),
 ('806.09', '806.0'),
 ('556.9', '520-579'),
 ('V45.72', 'V01-V91'),
 ('112.3', '110-118'),
 ('528.3', '520-579'),
 ('718.84', '718.84'),
 ('824.0', '800-999'),
 ('282.9', '-1'),
 ('305.01', '305.01'),
 ('922.31', '920-924'),
 ('909.9', '909'),
 ('357.89', '320-389'),
 ('V15.84', 'V15.8'),
 ('903.1', '903'),
 ('403.91', '403.91'),
 ('574.80', '520-579'),
 ('212.6', '212.6'),
 ('V15.51', 'V15.51'),
 ('136.1', '136.1'),
 ('379.21', '379.21'),
 ('783.1', '783'),
 ('V60.4', 'V60-V69'),
 ('575.6', '520-579'),
 ('E922.9', 'E916-E928'),
 ('342.11', '-1'),
 ('433.00', '-1'),
 ('952.03', '-1'),
 ('331.19', '320-389'),
 ('V54.01', 'V54.0'),
 ('V45.76', 'V40-V49'),
 ('250.11', '240-279'),
 ('784.41', '780-799'),
 ('282.60', '282.60'),
 ('018.80', '018.8'),
 ('153.9', '153'),
 ('255.4', '249-259'),
 ('524.60', '524.60'),
 ('745.2', '-1'),
 ('355.79', '-1'),
 ('269.2', '269'),
 ('765.24', '764-779

In [103]:
# All (child, parent) pairs present in the comparison run's attention weights.
original_connections = {
    (node, linked) for node, links in attention_comp.items() for linked in links
}
original_connections

{('455.0', '455'),
 ('E937.9', '-1'),
 ('900.03', '900.0'),
 ('850.2', '850.2'),
 ('305.70', '305.7'),
 ('806.09', '806.0'),
 ('556.9', '520-579'),
 ('V45.72', 'V01-V91'),
 ('112.3', '110-118'),
 ('528.3', '520-579'),
 ('718.84', '718.84'),
 ('824.0', '800-999'),
 ('282.9', '-1'),
 ('305.01', '305.01'),
 ('922.31', '920-924'),
 ('909.9', '909'),
 ('357.89', '320-389'),
 ('V15.84', 'V15.8'),
 ('903.1', '903'),
 ('403.91', '403.91'),
 ('574.80', '520-579'),
 ('212.6', '212.6'),
 ('V15.51', 'V15.51'),
 ('136.1', '136.1'),
 ('379.21', '379.21'),
 ('783.1', '783'),
 ('V60.4', 'V60-V69'),
 ('575.6', '520-579'),
 ('E922.9', 'E916-E928'),
 ('342.11', '-1'),
 ('433.00', '-1'),
 ('952.03', '-1'),
 ('331.19', '320-389'),
 ('V54.01', 'V54.0'),
 ('V45.76', 'V40-V49'),
 ('250.11', '240-279'),
 ('784.41', '780-799'),
 ('282.60', '282.60'),
 ('018.80', '018.8'),
 ('153.9', '153'),
 ('255.4', '249-259'),
 ('524.60', '524.60'),
 ('745.2', '-1'),
 ('355.79', '-1'),
 ('269.2', '269'),
 ('765.24', '764-779

In [105]:
# Edge counts before (original) and after (reference) refinement, one per line.
print(len(original_connections), len(reference_connections), sep="\n")

26608
26607


In [106]:
# Edges that exist only in the larger of the two connection sets; when both
# sets have the same size, the reference set is used as the minuend.
if len(original_connections) > len(reference_connections):
    colored_connections = original_connections - reference_connections
else:
    colored_connections = reference_connections - original_connections
colored_connections

{('719.02', '710-739')}

In [107]:
# Translate both ends of every differing edge through the feature mapping,
# falling back to the plain node name when no mapping entry exists.
final_colored_connections = {
    (
        feature_node_mapping.get(_node_name(src), _node_name(src)),
        feature_node_mapping.get(_node_name(dst), _node_name(dst)),
    )
    for src, dst in colored_connections
}

print("Removed", len(final_colored_connections), "edges")
final_colored_connections

Removed 1 edges


{('719.02', '710-739')}