In [48]:
import os
from medvqa.utils.common import RESULTS_DIR, WORKSPACE_DIR
from medvqa.utils.files import load_pickle

In [78]:
def collect_report_level_results(dataset_name):
    """Find report-level metric files for ``dataset_name`` across all experiment kinds.

    Scans ``RESULTS_DIR/<kind>/<experiment>/`` for ``kind`` in ``'vqa'``, ``'qa'``
    and ``'visual_module'`` and keeps the files whose name contains both
    ``'report_level'`` and ``dataset_name``.

    Args:
        dataset_name: Substring identifying the dataset in result filenames
            (e.g. ``'mimiccxr'``).

    Returns:
        A list of ``(kind, exp_name, filename)`` tuples, grouped by kind in the
        order listed above and sorted by experiment name and filename within
        each kind. ``os.path.join(RESULTS_DIR, *entry)`` is the path of the file.
    """
    results = []
    for kind in ('vqa', 'qa', 'visual_module'):
        kind_dir = os.path.join(RESULTS_DIR, kind)
        # A results tree may not contain every experiment kind; skip the
        # missing ones instead of aborting the whole scan.
        if not os.path.isdir(kind_dir):
            continue
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # resulting table (and its row index) reproducible across runs.
        for exp_name in sorted(os.listdir(kind_dir)):
            exp_dir = os.path.join(kind_dir, exp_name)
            # Stray files next to the experiment folders (e.g. .DS_Store)
            # cannot be listed and would raise NotADirectoryError.
            if not os.path.isdir(exp_dir):
                continue
            for filename in sorted(os.listdir(exp_dir)):
                if 'report_level' in filename and dataset_name in filename:
                    results.append((kind, exp_name, filename))
    return results

def collect_visual_module_results(dataset_name):
    """Find visual-module metric files for ``dataset_name``.

    Scans ``RESULTS_DIR/visual_module/<experiment>/`` and keeps the files whose
    name contains ``dataset_name`` but NOT ``'report_level'`` (report-level
    files are excluded by the filter below).

    Args:
        dataset_name: Substring identifying the dataset in result filenames
            (e.g. ``'mimiccxr'``).

    Returns:
        A list of ``('visual_module', exp_name, filename)`` tuples sorted by
        experiment name and filename. ``os.path.join(RESULTS_DIR, *entry)`` is
        the path of the file.

    Raises:
        FileNotFoundError: If ``RESULTS_DIR/visual_module`` does not exist.
    """
    root = os.path.join(RESULTS_DIR, 'visual_module')
    results = []
    # os.listdir returns entries in arbitrary order; sorting keeps the
    # resulting table (and its row index) reproducible across runs.
    for exp_name in sorted(os.listdir(root)):
        exp_dir = os.path.join(root, exp_name)
        # Stray files next to the experiment folders (e.g. .DS_Store)
        # cannot be listed and would raise NotADirectoryError.
        if not os.path.isdir(exp_dir):
            continue
        for filename in sorted(os.listdir(exp_dir)):
            if 'report_level' not in filename and dataset_name in filename:
                results.append(('visual_module', exp_name, filename))
    return results

In [39]:
from medvqa.evaluation.report_generation import get_report_level_metrics_dataframe
from medvqa.evaluation.visual_module import get_visual_module_metrics_dataframe

In [68]:
def plot_report_level_metrics(dataset_name):
    """Load all report-level results for ``dataset_name`` into one metrics table.

    Despite the ``plot_`` prefix, nothing is drawn: the function returns whatever
    ``get_report_level_metrics_dataframe`` builds from the loaded pickles.

    Args:
        dataset_name: Dataset substring forwarded to
            ``collect_report_level_results``.

    Returns:
        The result of ``get_report_level_metrics_dataframe(metrics, method_names)``,
        where each method name is ``'<experiment> - <filename>'``.
    """
    metrics = []
    method_names = []
    for entry in collect_report_level_results(dataset_name):
        # entry is (kind, exp_name, filename), i.e. the path parts below RESULTS_DIR.
        metrics.append(load_pickle(os.path.join(RESULTS_DIR, *entry)))
        method_names.append(f'{entry[1]} - {entry[2]}')
    return get_report_level_metrics_dataframe(metrics, method_names)

def plot_visual_module_metrics(dataset_name):
    """Load all visual-module results for ``dataset_name`` into one metrics table.

    Despite the ``plot_`` prefix, nothing is drawn: the function returns whatever
    ``get_visual_module_metrics_dataframe`` builds from the loaded pickles.

    Args:
        dataset_name: Dataset substring forwarded to
            ``collect_visual_module_results``.

    Returns:
        The result of ``get_visual_module_metrics_dataframe(metrics, method_names)``,
        where each method name is ``'<experiment> - <filename>'``.
    """
    metrics = []
    method_names = []
    for entry in collect_visual_module_results(dataset_name):
        # entry is (kind, exp_name, filename), i.e. the path parts below RESULTS_DIR.
        metrics.append(load_pickle(os.path.join(RESULTS_DIR, *entry)))
        method_names.append(f'{entry[1]} - {entry[2]}')
    return get_visual_module_metrics_dataframe(metrics, method_names)

In [69]:
# Load every MIMIC-CXR report-level result file and tabulate its metrics.
# Despite the "plot_" prefix, the helper returns a table, not a figure.
df = plot_report_level_metrics('mimiccxr')

30it [00:31,  1.03s/it]


In [70]:
# List the metric columns available in the report-level table.
df.columns

Index(['method_name', 'b1', 'b2', 'b3', 'b4', 'cD', 'rg-L', 'met', 'mdcmp',
       'wmdcmp', 'chxlabf1(hard)', 'p(micro)', 'r(micro)', 'f1(micro)',
       'p(macro)', 'r(macro)', 'f1(macro)', 'acc', 'p(NF)', 'p(EC)', 'p(Card)',
       'p(LL)', 'p(LO)', 'p(E)', 'p(Cons)', 'p(Pn)', 'p(A)', 'p(Pt)', 'p(PE)',
       'p(PO)', 'p(F)', 'p(SD)', 'r(NF)', 'r(EC)', 'r(Card)', 'r(LL)', 'r(LO)',
       'r(E)', 'r(Cons)', 'r(Pn)', 'r(A)', 'r(Pt)', 'r(PE)', 'r(PO)', 'r(F)',
       'r(SD)', 'f1(NF)', 'f1(EC)', 'f1(Card)', 'f1(LL)', 'f1(LO)', 'f1(E)',
       'f1(Cons)', 'f1(Pn)', 'f1(A)', 'f1(Pt)', 'f1(PE)', 'f1(PO)', 'f1(F)',
       'f1(SD)'],
      dtype='object')

In [71]:
# Rank the experiments by macro-averaged F1 (best first), keep only the
# headline columns, and renumber the rows so the index doubles as the rank.
df_sorted = df.sort_values(by='f1(macro)', ascending=False)
df_sorted_and_filtered = df_sorted.loc[:, [
    'method_name', 'wmdcmp',
    'f1(macro)', 'p(macro)', 'r(macro)',
    'f1(micro)', 'r(micro)', 'p(micro)',
    'b1', 'b2', 'b3', 'b4', 'rg-L', 'cD',
]].reset_index(drop=True)

In [72]:
# Display the ranked summary table.
df_sorted_and_filtered

Unnamed: 0,method_name,wmdcmp,f1(macro),p(macro),r(macro),f1(micro),r(micro),p(micro),b1,b2,b3,b4,rg-L,cD
0,20220508_110803_mimiccxr+iuxray_oevqa(densenet...,0.150673,0.6159,0.567759,0.689625,0.71025,0.793251,0.642973,0.319978,0.186227,0.106163,0.058897,0.222314,0.1536888
1,20220626_123045_mimiccxr+iuxray_oevqa(densenet...,0.148884,0.614878,0.569115,0.691977,0.71135,0.810765,0.633653,0.304718,0.178233,0.101774,0.056504,0.219189,0.1299223
2,20220508_110803_mimiccxr+iuxray_oevqa(densenet...,0.153054,0.61482,0.559139,0.69883,0.70732,0.799924,0.633933,0.298971,0.175512,0.101967,0.058543,0.263689,0.131022
3,20220626_121543_mimiccxr+iuxray_oevqa(densenet...,0.147627,0.607148,0.560349,0.679041,0.70776,0.810482,0.628147,0.299489,0.175063,0.099745,0.055012,0.220708,0.1255519
4,20220622_071455_mimiccxr+iuxray+chexpert_oevqa...,0.150533,0.600409,0.553958,0.671073,0.702691,0.8,0.626488,0.306829,0.179708,0.102726,0.05676,0.222944,0.1378526
5,20220608_063623_mimiccxr+iuxray_oevqa(densenet...,0.146651,0.599112,0.546554,0.676854,0.702804,0.795184,0.629655,0.30003,0.175616,0.1003,0.055904,0.220072,0.1240379
6,20220525_005230_mimiccxr+iuxray_oevqa(densenet...,0.141,0.595904,0.53094,0.69423,0.697078,0.816414,0.60818,0.28823,0.166302,0.093835,0.05129,0.217633,0.1106957
7,20220627_111405_mimiccxr+iuxray+chexpert_oevqa...,0.153137,0.592567,0.559212,0.652754,0.701948,0.787819,0.632957,0.315155,0.185646,0.106606,0.05967,0.223734,0.1460881
8,20220502_032510_mimiccxr+iuxray_oevqa(densenet...,0.161557,0.591929,0.582153,0.631667,0.691834,0.732819,0.65519,0.325993,0.190569,0.110349,0.063638,0.269116,0.1904793
9,20220621_142330_mimiccxr+iuxray_oevqa(densenet...,0.148591,0.589818,0.542024,0.660556,0.702533,0.803966,0.623828,0.306458,0.178828,0.102132,0.056801,0.221214,0.1432354


In [73]:
# The table display truncates long method names, so print each one in full
# next to its row number; a blank line separates consecutive entries.
for rank, method in enumerate(df_sorted_and_filtered['method_name']):
    print(f'({rank})    ', method, end='\n\n')

(0)     20220508_110803_mimiccxr+iuxray_oevqa(densenet121+bilstm+lstm)_voc-minf=5_model-args=(128,128,512,2,128,1024,0)_cnn-pretr=0_mim-iux-freqs=200,10_medtok=1_tags=0_orien=1_chx=1_ql=1_use_amp - mimiccxr_report_level_metrics_(eval_mode=ground-truth).pkl

(1)     20220626_123045_mimiccxr+iuxray_oevqa(densenet121+bilstm+lstm)_voc-minf=5_model-args=(densenet-121,128,None,512,2,128,1024,0,qenc=one-hot,pretrained)_cnn-pretr=0_dataset_weights=0.83,0.07_medtok=1_tags=0_orien=0_chx=0_ql=0_use_amp - mimiccxr_report_level_metrics(eval_mode=ground-truth).pkl

(2)     20220508_110803_mimiccxr+iuxray_oevqa(densenet121+bilstm+lstm)_voc-minf=5_model-args=(128,128,512,2,128,1024,0)_cnn-pretr=0_mim-iux-freqs=200,10_medtok=1_tags=0_orien=1_chx=1_ql=1_use_amp - mimiccxr_report_level_metrics.pkl

(3)     20220626_121543_mimiccxr+iuxray_oevqa(densenet121+bilstm+lstm)_voc-minf=5_model-args=(densenet-121,128,None,512,2,128,1024,0,qenc=one-hot,pretrained)_cnn-pretr=0_dataset_weights=0.83,0.07_medtok=1_tags

In [79]:
# Load every MIMIC-CXR visual-module result file and tabulate its metrics.
# Despite the "plot_" prefix, the helper returns a table, not a figure.
df_vm = plot_visual_module_metrics('mimiccxr')

8it [00:00, 9642.08it/s]


In [80]:
# List the metric columns available in the visual-module table.
df_vm.columns

Index(['method_name', 'chxlacc', 'f1(macro)', 'p(macro)', 'r(macro)',
       'f1(micro)', 'p(micro)', 'r(micro)', 'f1(NF)', 'f1(EC)', 'f1(Card)',
       'f1(LL)', 'f1(LO)', 'f1(E)', 'f1(Cons)', 'f1(Pn)', 'f1(A)', 'f1(Pt)',
       'f1(PE)', 'f1(PO)', 'f1(F)', 'f1(SD)', 'p(NF)', 'p(EC)', 'p(Card)',
       'p(LL)', 'p(LO)', 'p(E)', 'p(Cons)', 'p(Pn)', 'p(A)', 'p(Pt)', 'p(PE)',
       'p(PO)', 'p(F)', 'p(SD)', 'r(NF)', 'r(EC)', 'r(Card)', 'r(LL)', 'r(LO)',
       'r(E)', 'r(Cons)', 'r(Pn)', 'r(A)', 'r(Pt)', 'r(PE)', 'r(PO)', 'r(F)',
       'r(SD)', 'rocauc(macro)', 'rocauc(micro)', 'rocauc(NF)', 'rocauc(EC)',
       'rocauc(Card)', 'rocauc(LL)', 'rocauc(LO)', 'rocauc(E)', 'rocauc(Cons)',
       'rocauc(Pn)', 'rocauc(A)', 'rocauc(Pt)', 'rocauc(PE)', 'rocauc(PO)',
       'rocauc(F)', 'rocauc(SD)', 'ql_f1(macro)', 'ql_p(macro)', 'ql_r(macro)',
       'ql_f1(micro)', 'ql_p(micro)', 'ql_r(micro)', 'oracc'],
      dtype='object')

In [81]:
# Order the visual-module experiments by macro-averaged F1, best first.
# NOTE: this rebinds `df_sorted`, which until now held the report-level
# ranking, and keeps the original row labels (no reset_index).
df_sorted = df_vm.sort_values(by='f1(macro)', ascending=False)
df_sorted

Unnamed: 0,method_name,chxlacc,f1(macro),p(macro),r(macro),f1(micro),p(micro),r(micro),f1(NF),f1(EC),...,rocauc(PO),rocauc(F),rocauc(SD),ql_f1(macro),ql_p(macro),ql_r(macro),ql_f1(micro),ql_p(micro),ql_r(micro),oracc
4,20220627_035559_mimiccxr+iuxray+chexpert_dense...,0.754057,0.469153,0.404611,0.603627,0.577708,0.494723,0.694143,0.337391,0.328049,...,0.824107,0.758299,0.884811,0.223384,0.188256,0.339953,0.477076,0.392007,0.609302,0.995688
2,20220613_180648_mimiccxr+iuxray+chexpert_dense...,0.724571,0.462447,0.381085,0.647382,0.56912,0.458331,0.750542,0.315447,0.431917,...,0.803908,0.71941,0.871197,0.203249,0.167657,0.397973,0.416963,0.316991,0.60904,0.996683
6,20220614_103906_mimiccxr+iuxray+chexpert_dense...,0.702834,0.45525,0.371016,0.689209,0.552323,0.434971,0.75639,0.3047,0.426162,...,0.807638,0.719522,0.872031,0.189783,0.153187,0.586394,0.325112,0.219123,0.629693,0.997015
1,20220601_073528_mimiccxr+iuxray_densenet121_mo...,0.802377,0.395667,0.450642,0.377714,0.57329,0.589286,0.55814,0.296296,0.3,...,0.738412,0.689397,0.873454,0.160353,0.268517,0.145219,0.574459,0.644961,0.517852,0.99602
7,20220531_213512_mimiccxr+iuxray_densenet121_mo...,0.789874,0.38341,0.409966,0.369966,0.549958,0.560523,0.539785,0.262069,0.294004,...,0.667616,0.694885,0.849394,0.157838,0.211068,0.143295,0.553642,0.606819,0.509033,0.99801
0,20220601_075947_mimiccxr+iuxray_densenet121_mo...,0.79776,0.381227,0.416248,0.378178,0.57592,0.574488,0.577359,0.222222,0.286301,...,0.755787,0.647075,0.866062,0.146962,0.229474,0.136325,0.568063,0.628182,0.518446,0.996683
3,20220601_082313_mimiccxr+iuxray_densenet121_mo...,0.804617,0.368311,0.430405,0.33865,0.556777,0.604617,0.515952,0.186047,0.228571,...,0.771059,0.709032,0.874094,0.155561,0.233481,0.136936,0.565268,0.647224,0.501734,0.996683
5,20220613_082339_mimiccxr+iuxray+chexpert_dense...,0.816274,0.338385,0.521501,0.299094,0.550045,0.676629,0.463359,0.165138,0.007059,...,0.797126,0.732497,0.889047,0.134225,0.264629,0.119441,0.562324,0.682262,0.478251,0.997347


In [82]:
# Per-label F1 for each visual-module experiment, in the current row order of
# `df_sorted`. The abbreviations presumably denote the 14 CheXpert labels —
# TODO confirm against the metrics code.
df_sorted[[f'f1({label})' for label in (
    'NF', 'EC', 'Card', 'LL', 'LO', 'E', 'Cons',
    'Pn', 'A', 'Pt', 'PE', 'PO', 'F', 'SD',
)]]

Unnamed: 0,f1(NF),f1(EC),f1(Card),f1(LL),f1(LO),f1(E),f1(Cons),f1(Pn),f1(A),f1(Pt),f1(PE),f1(PO),f1(F),f1(SD)
4,0.337391,0.328049,0.704377,0.272285,0.638964,0.635786,0.281828,0.44312,0.582547,0.242704,0.749139,0.254335,0.268537,0.829077
2,0.315447,0.431917,0.716625,0.218029,0.63993,0.635311,0.284422,0.438239,0.598996,0.209115,0.73941,0.23053,0.206186,0.810102
6,0.3047,0.426162,0.714601,0.215501,0.639835,0.638747,0.282888,0.433457,0.604342,0.204301,0.742057,0.184971,0.180022,0.80192
1,0.296296,0.3,0.645999,0.096491,0.583716,0.577244,0.165342,0.318352,0.552033,0.180258,0.710732,0.192308,0.102041,0.818523
7,0.262069,0.294004,0.659392,0.141176,0.566253,0.578504,0.134206,0.324282,0.491833,0.158621,0.679291,0.128655,0.15942,0.790029
0,0.222222,0.286301,0.674434,0.082949,0.59532,0.597566,0.153285,0.297506,0.551864,0.112554,0.727015,0.109091,0.133891,0.793184
3,0.186047,0.228571,0.658615,0.129032,0.555756,0.550976,0.146252,0.297535,0.495829,0.144,0.71701,0.075,0.160714,0.811019
5,0.165138,0.007059,0.504892,0.141667,0.616411,0.633663,0.095436,0.242841,0.434444,0.183206,0.730798,0.0625,0.099448,0.819888


In [84]:
# Per-label ROC-AUC for each visual-module experiment, followed by the macro
# and micro averages, in the current row order of `df_sorted`.
df_sorted[
    [f'rocauc({label})' for label in (
        'NF', 'EC', 'Card', 'LL', 'LO', 'E', 'Cons',
        'Pn', 'A', 'Pt', 'PE', 'PO', 'F', 'SD',
    )]
    + ['rocauc(macro)', 'rocauc(micro)']
]

Unnamed: 0,rocauc(NF),rocauc(EC),rocauc(Card),rocauc(LL),rocauc(LO),rocauc(E),rocauc(Cons),rocauc(Pn),rocauc(A),rocauc(Pt),rocauc(PE),rocauc(PO),rocauc(F),rocauc(SD),rocauc(macro),rocauc(micro)
4,0.793822,0.590963,0.798001,0.777315,0.672072,0.808137,0.649639,0.703102,0.710488,0.740862,0.850573,0.824107,0.758299,0.884811,0.754442,0.817648
2,0.786324,0.615614,0.799244,0.735185,0.680547,0.814557,0.664686,0.695594,0.717183,0.699618,0.845546,0.803908,0.71941,0.871197,0.746329,0.813498
6,0.780484,0.613659,0.798339,0.736969,0.673573,0.815019,0.663117,0.690793,0.720592,0.701428,0.846194,0.807638,0.719522,0.872031,0.745668,0.800462
1,0.853114,0.562984,0.745189,0.727991,0.631482,0.780446,0.624073,0.64945,0.701909,0.670715,0.82302,0.738412,0.689397,0.873454,0.719403,0.829729
7,0.841972,0.553308,0.735336,0.695531,0.621318,0.774836,0.574894,0.628988,0.671234,0.608714,0.805751,0.667616,0.694885,0.849394,0.694556,0.809321
0,0.862668,0.55878,0.761944,0.661813,0.63868,0.771498,0.622263,0.619778,0.689133,0.640035,0.824372,0.755787,0.647075,0.866062,0.708563,0.826854
3,0.851405,0.556778,0.751982,0.690446,0.644467,0.775955,0.598074,0.638782,0.682262,0.673938,0.823735,0.771059,0.709032,0.874094,0.717286,0.824832
5,0.799065,0.596243,0.788781,0.755524,0.670785,0.813054,0.625831,0.701697,0.716578,0.687786,0.846187,0.797126,0.732497,0.889047,0.7443,0.832714


In [54]:
# The table display truncates long method names, so print each one in full.
# NOTE: the number printed is the position in the current sort order, not the
# row label shown in the tables (those keep the pre-sort index).
for position, method in enumerate(df_sorted['method_name']):
    print(f'({position})    ', method, end='\n\n')

(0)     20220627_035559_mimiccxr+iuxray+chexpert_densenet121_model-args=(densenet-121,1024)_cnn-pretr=0_dataset_weights=1.0,0.08,0.8_img-aug=random-color-and-spatial_tags=0_orien=1_chx=1_ql=1_use_amp - mimiccxr_metrics.pkl

(1)     20220613_180648_mimiccxr+iuxray+chexpert_densenet121_model-args=(densenet-121,1024)_cnn-pretr=0_dataset_weights=1.0,0.08,0.8_img-aug=random-spatial_tags=0_orien=1_chx=1_ql=1 - mimiccxr_metrics.pkl

(2)     20220614_103906_mimiccxr+iuxray+chexpert_densenet121_model-args=(densenet-121,1024,pretrained)_cnn-pretr=0_dataset_weights=1.0,0.08,0.8_img-aug=random-spatial_tags=0_orien=1_chx=1_ql=1_use_amp - mimiccxr_metrics.pkl

(3)     20220601_073528_mimiccxr+iuxray_densenet121_model-args=(densenet-121,1024)_cnn-pretr=0_mim-iux-freqs=200,10_img-aug=random-spatial_tags=0_orien=1_chx=1_ql=1_use_amp - mimiccxr_metrics.pkl

(4)     20220531_213512_mimiccxr+iuxray_densenet121_model-args=(densenet-121,1024)_cnn-pretr=0_mim-iux-freqs=200,10_tags=0_orien=1_chx=1_ql=1_use_am

In [46]:
from medvqa.utils.files import load_json_file

In [59]:
# Inspect the training metadata saved for one visual-module experiment
# (the one listed first in the macro-F1 ranking printed above).
vm_experiment_name = '20220627_035559_mimiccxr+iuxray+chexpert_densenet121_model-args=(densenet-121,1024)_cnn-pretr=0_dataset_weights=1.0,0.08,0.8_img-aug=random-color-and-spatial_tags=0_orien=1_chx=1_ql=1_use_amp'
vm_metadata_path = os.path.join(
    WORKSPACE_DIR, 'models', 'visual_module', vm_experiment_name, 'metadata.json'
)
load_json_file(vm_metadata_path)

{'model_kwargs': {'image_local_feat_size': 1024,
  'pretrained': True,
  'densenet_pretrained_weights_path': None,
  'pretrained_checkpoint_folder_path': None,
  'classify_tags': False,
  'classify_orientation': True,
  'classify_chexpert': True,
  'classify_questions': True,
  'n_medical_tags': None,
  'n_questions': 97,
  'use_chexpert_forward': True},
 'optimizer_kwargs': {'lr': 0.001},
 'lr_scheduler_kwargs': {'factor': 0.76, 'patience': 3},
 'mimiccxr_vision_trainer_kwargs': {'preprocessed_data_filename': 'mimiccxr_preprocessed_train_data__(hash=292,4086290602689741595).pkl'},
 'iuxray_vision_trainer_kwargs': {'preprocessed_data_filename': 'iuxray_preprocessed_train_data__(hash=247,1338506350569278595).pkl'},
 'dataloading_kwargs': {'img_aug_mode': 'random-color-and-spatial',
  'mimiccxr_weight': 1.0,
  'iuxray_weight': 0.08,
  'chexpert_weight': 0.8},
 'training_kwargs': {'use_amp': True,
  'train_mimiccxr': True,
  'train_iuxray': True,
  'train_chexpert': True,
  'binary_loss_n