In [1]:
import numpy as np
import torch
import pickle
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from torchmetrics.functional import auroc, average_precision
from sklearn.metrics import roc_auc_score, average_precision_score

import os
import re
import sys
# Use matplotlib's built-in 'bmh' style for every figure in this notebook.
plt.style.use('bmh')
# 'none' keeps text as real text in exported SVGs instead of converting it to paths.
plt.rcParams['svg.fonttype'] = 'none'
# Make the project root the working directory: the relative paths used below
# ('.cache/sim_data/...', 'score_transformer.txt') are resolved against it.
# NOTE(review): hard-coded absolute path — presumably also required for the
# project-local `run.*` import that follows; confirm and consider making it configurable.
os.chdir('/project/lcd_v2')

from run.tools.cli.infer import prepare_inference

In [11]:
def collect_inferences(file_cfg=None, target_dir='.cache/sim_data/mos_auc_result'):
    """Load cached predictions of the traditional methods and score them.

    One pickle is read per (note, game, method, window) combination. Its name is
    ``game={game}-window={window}-method={method}-seed=42-note={note}.pkl``
    (the seed is fixed to 42) and it is looked up inside ``target_dir``.
    Each pickle is expected to be a mapping that holds, under the method name,
    a ``'label'`` and a ``'pred'`` entry.

    Parameters
    ----------
    file_cfg : dict, optional
        Selects the files to read. Recognised keys, each a list of strings
        (a bare string is treated as a one-element list):

        * ``'game'``   -- default ``['DonkeyKong']``
        * ``'period'`` -- window identifiers, default ``['0_128']``
        * ``'method'`` -- default ``['corr', 'mi', 'gc', 'lingam']``
        * ``'note'``   -- default ``['default']``
    target_dir : str
        Directory that contains the pickle files.

    Returns
    -------
    dict
        ``result[note][game][f'{method}_score'][metric][window]`` with
        ``metric`` being ``'auroc'`` or ``'auprc'``. Methods other than
        'corr', 'mi', 'gc' and 'lingam' keep empty metric dictionaries.

    Raises
    ------
    FileNotFoundError
        If one of the expected pickle files does not exist.
    """
    def _as_list(value):
        # A bare string would otherwise be iterated character by character.
        return [value] if isinstance(value, str) else list(value)

    cfg = file_cfg if file_cfg is not None else {}
    games = _as_list(cfg.get('game', ['DonkeyKong']))
    periods = _as_list(cfg.get('period', ['0_128']))
    methods = _as_list(cfg.get('method', ['corr', 'mi', 'gc', 'lingam']))
    notes = _as_list(cfg.get('note', ['default']))

    result = {}
    for _note in notes:
        result[_note] = {}
        for _game in games:
            result[_note][_game] = {}
            for _method in methods:
                scores = {'auroc': {}, 'auprc': {}}
                result[_note][_game][f'{_method}_score'] = scores
                for window in periods:
                    file_name = os.path.join(
                        target_dir,
                        f'game={_game}-window={window}-method={_method}-seed=42-note={_note}.pkl')
                    # SECURITY: pickle can execute arbitrary code while loading;
                    # only point target_dir at files you produced yourself.
                    with open(file_name, 'rb') as fh:
                        payload = pickle.load(fh)

                    if _method not in ('corr', 'mi', 'gc', 'lingam'):
                        # Unknown method: nothing to score, metric dicts stay empty.
                        continue

                    labels = np.array(payload[_method]['label'])
                    preds = np.array(payload[_method]['pred'])
                    if _method == 'corr':
                        # Correlation is scored on its absolute value;
                        # the other methods' predictions are used directly.
                        preds = np.abs(preds)

                    scores['auroc'][window] = roc_auc_score(labels, preds)
                    scores['auprc'][window] = average_precision_score(labels, preds)

    return result

In [14]:
# Score the traditional methods for every game under the 'default' note.
condition = 'default'
games = ['DonkeyKong', 'Pitfall', 'SpaceInvaders']

# NOTE(review): the window list skips '256_384' — confirm this is intentional.
default_cfg = {
    'game': games,
    'method': ['corr', 'mi', 'gc', 'lingam'],
    'period': ['0_128', '128_256', '384_512', '512_640', '640_768'],
    'note': [condition],
}

# collect_inferences nests its output by note first; keep only this note's slice
# so that `result` is indexed directly by game.
result = collect_inferences(
    file_cfg=default_cfg,
    target_dir='.cache/sim_data/mos_auc_result',
)[condition]

In [4]:
# Display the collected scores, nested as game -> '<method>_score' -> metric -> window.
result

{'DonkeyKong': {'corr_score': {'auroc': {'0_128': 0.9775118201776011,
    '128_256': 0.8844896055954927,
    '384_512': 0.9229661691012045,
    '512_640': 0.6225551885061088,
    '640_768': 0.9438486002239641},
   'auprc': {'0_128': 0.1772410654997663,
    '128_256': 0.1177534096031724,
    '384_512': 0.12028536759021562,
    '512_640': 0.05012911913019164,
    '640_768': 0.1882282317876899}},
  'mi_score': {'auroc': {'0_128': 0.950590749381095,
    '128_256': 0.8805906353215467,
    '384_512': 0.9000803288945204,
    '512_640': 0.6414445561551456,
    '640_768': 0.9375659574468084},
   'auprc': {'0_128': 0.1479700086682652,
    '128_256': 0.15871293603306355,
    '384_512': 0.0854565483799435,
    '512_640': 0.04935941492798762,
    '640_768': 0.19263126025508295}},
  'gc_score': {'auroc': {'0_128': 0.8338648737018565,
    '128_256': 0.8723201865164173,
    '384_512': 0.9042625629851558,
    '512_640': 0.5782254185161627,
    '640_768': 0.8882642777155656},
   'auprc': {'0_128': 0.248

In [10]:
# Build LaTeX table rows from the per-game scores in `result` (values formatted
# to 2 decimal places) and save them to score_transformer.txt.
# Each row holds: the window, AUROC of every method, a "[TRM]" placeholder,
# AUPRC of every method and a second "[TRM]" placeholder.
# NOTE(review): "[TRM]" is presumably filled in by hand with the transformer scores — confirm.
games = ["Donkey Kong", "Pitfall", "Space Invaders"]
methods = ['corr', 'mi', 'lingam', 'gc']
periods = ['0_128', '128_256', '384_512', '512_640', '640_768']

pre_fix = ""
for game in games:
    # The backslash is escaped explicitly: "\m" is not a valid escape sequence
    # and triggers a warning on recent Python versions.
    # NOTE(review): \multirow spans 2 rows although 5 window rows are emitted per
    # game, and only the first of them is preceded by the game cell — confirm layout.
    pre_fix += f"\\multirow{{2}}{{*}}{{{game}}} & "
    # `result` is keyed by the game name without spaces (e.g. 'DonkeyKong').
    game_scores = result[game.replace(' ', '')]

    for period in periods:
        start, end = period.split('_')
        auroc_cells = "".join(
            f"{game_scores[f'{method}_score']['auroc'][period]:.2f} & " for method in methods)
        auprc_cells = "".join(
            f"{game_scores[f'{method}_score']['auprc'][period]:.2f} & " for method in methods)
        pre_fix += f"[{start}, {end}] & {auroc_cells}[TRM]& {auprc_cells}[TRM]\\\\ \n"

# The context manager guarantees the file is flushed and closed.
with open("score_transformer.txt", "w") as table_file:
    chars_written = table_file.write(pre_fix)

# Show the number of characters written as the cell output.
chars_written


1361

In [50]:

def get_stat(result, games=['DonkeyKong', 'Pitfall', 'SpaceInvaders']):
    """Aggregate each method's per-window scores into a mean and a spread.

    Parameters
    ----------
    result : dict
        Nested as ``result[game][f'{method}_score'][metric][window]`` with
        ``metric`` in ``'auroc'`` / ``'auprc'``.
    games : list of str
        Games to summarise; every one must be a key of ``result``.

    Returns
    -------
    dict
        ``summary[game][method]`` with the keys ``'auroc'`` and ``'auprc'``
        (mean over all windows) and ``'auroc_std'`` and ``'auprc_std'``
        (``np.std`` over all windows, i.e. numpy's default population form)
        for the methods 'corr', 'mi', 'gc' and 'lingam'.
    """
    summary = {}
    for name in games:
        per_method = {}
        for approach in ('corr', 'mi', 'gc', 'lingam'):
            scores = result[name][f'{approach}_score']
            # Collapse the window -> value mappings into plain lists once.
            roc_values = list(scores['auroc'].values())
            prc_values = list(scores['auprc'].values())
            per_method[approach] = {
                'auroc': np.mean(roc_values),
                'auprc': np.mean(prc_values),
                'auroc_std': np.std(roc_values),
                'auprc_std': np.std(prc_values),
            }
        summary[name] = per_method
    return summary
# Summarise the per-game scores held in `result` and display the summary.
# NOTE(review): the execution counts in this notebook are not sequential, so the
# stored output may stem from a different `result` than the cells above produce —
# re-run top to bottom to confirm.
stat = get_stat(result)
stat

{'DonkeyKong': {'corr': {'auroc': 0.8602197106025296,
   'auprc': 0.07719992292923179,
   'auroc_std': 0.13530933109802226,
   'auprc_std': 0.02231606928066377},
  'mi': {'auroc': 0.8460670807674558,
   'auprc': 0.05431754180995121,
   'auroc_std': 0.12693344826040118,
   'auprc_std': 0.01999384970551608},
  'gc': {'auroc': 0.7913097481946411,
   'auprc': 0.08471494019103284,
   'auroc_std': 0.10885719852827942,
   'auprc_std': 0.02492775327523892},
  'lingam': {'auroc': 0.8282069479133296,
   'auprc': 0.03852301788807668,
   'auroc_std': 0.10926376294781275,
   'auprc_std': 0.007839911941701765}},
 'Pitfall': {'corr': {'auroc': 0.8467219463969442,
   'auprc': 0.07287619081649578,
   'auroc_std': 0.12394286145923737,
   'auprc_std': 0.02464053869427586},
  'mi': {'auroc': 0.8390113290484343,
   'auprc': 0.05833093440673385,
   'auroc_std': 0.11764846496977396,
   'auprc_std': 0.02582513880909912},
  'gc': {'auroc': 0.7853383986341976,
   'auprc': 0.08556360868167738,
   'auroc_std': 0.

In [15]:
# Score the traditional methods on DonkeyKong for each of the noise notes.
condition = ['0.1noise', '0.05noise', '0.03noise']
games = ['DonkeyKong']

noise_cfg = {
    'game': games,
    'method': ['corr', 'mi', 'gc', 'lingam'],
    'period': ['0_128', '128_256', '384_512', '512_640', '640_768'],
    'note': condition,
}
scores_by_note = collect_inferences(
    file_cfg=noise_cfg,
    target_dir='.cache/sim_data/mos_auc_result',
)

# Only one game was requested, so drop the game level: note -> DonkeyKong scores.
noise_result = dict(
    (note_name, per_game['DonkeyKong'])
    for note_name, per_game in scores_by_note.items()
)

In [17]:
# Display the DonkeyKong scores, nested as note -> '<method>_score' -> metric -> window.
noise_result

{'0.1noise': {'corr_score': {'auroc': {'0_128': 0.9742810581329567,
    '128_256': 0.8893017291626191,
    '384_512': 0.9192872241531419,
    '512_640': 0.6253793034267827,
    '640_768': 0.9430333706606944},
   'auprc': {'0_128': 0.15055168963111748,
    '128_256': 0.13697317401620362,
    '384_512': 0.11933212183674756,
    '512_640': 0.0471719375506415,
    '640_768': 0.176734207124657}},
  'mi_score': {'auroc': {'0_128': 0.9589323434313028,
    '128_256': 0.8844227705459491,
    '384_512': 0.8942253985359221,
    '512_640': 0.6292347474996182,
    '640_768': 0.9296430011198208},
   'auprc': {'0_128': 0.13967449047182273,
    '128_256': 0.14658002537615814,
    '384_512': 0.08814849677932653,
    '512_640': 0.051925541089196565,
    '640_768': 0.19909197955120872}},
  'gc_score': {'auroc': {'0_128': 0.9829334021870573,
    '128_256': 0.9081064697882262,
    '384_512': 0.9277286659679829,
    '512_640': 0.6007409668219217,
    '640_768': 0.9530875699888017},
   'auprc': {'0_128': 0.2

In [20]:
# Build LaTeX table rows from the per-noise-level scores in `noise_result`
# (values formatted to 2 decimal places) and save them to score_transformer.txt.
# Each row holds: the window, AUROC of every method, a "[TRM]" placeholder,
# AUPRC of every method and a second "[TRM]" placeholder.
# NOTE(review): "[TRM]" is presumably filled in by hand with the transformer scores — confirm.
# `games` holds the noise notes here; the name is kept from the per-game table cell.
games = ["0.03noise", "0.05noise", "0.1noise"]
methods = ['corr', 'mi', 'lingam', 'gc']
periods = ['0_128', '128_256', '384_512', '512_640', '640_768']

pre_fix = ""
for game in games:
    # The backslash is escaped explicitly: "\m" is not a valid escape sequence
    # and triggers a warning on recent Python versions.
    # NOTE(review): the first row of each group ends up with one more "&" than the
    # following rows (the \multirow cell plus the leading "&" below) — confirm layout.
    pre_fix += f"\\multirow{{5}}{{*}}{{{game}}} & "
    # Keys of `noise_result` contain no spaces; any space in the label is dropped.
    note_scores = noise_result[game.replace(' ', '')]

    for period in periods:
        start, end = period.split('_')
        auroc_cells = "".join(
            f"{note_scores[f'{method}_score']['auroc'][period]:.2f} & " for method in methods)
        auprc_cells = "".join(
            f"{note_scores[f'{method}_score']['auprc'][period]:.2f} & " for method in methods)
        pre_fix += f"& [{start}, {end}] & {auroc_cells}[TRM]& {auprc_cells}[TRM]\\\\  \n"

# NOTE(review): this is the same path the per-game table cell writes to, so that
# table is overwritten here — confirm, or use a separate file name.
# The context manager guarantees the file is flushed and closed.
with open("score_transformer.txt", "w") as table_file:
    chars_written = table_file.write(pre_fix)

# Show the number of characters written as the cell output.
chars_written


1400