In [26]:
import os
import numpy as np

from medvqa.evaluation.results import collect_multimodal_question_probs
from medvqa.models.ensemble import QuestionClassificationEnsembleSearcher
from medvqa.utils.files import load_json_file, save_to_pickle
from medvqa.utils.common import RESULTS_DIR, get_timestamp
from medvqa.datasets.mimiccxr import MIMICCXR_CACHE_DIR

In [2]:
qprobs_paths = collect_multimodal_question_probs('mimiccxr')
qprobs_paths

['/home/pamessina/medvqa-workspace/results/multimodal/20220907_043140_mim+iu+chexp+cxr14+vinbig_imgtxtenc(dense121+chx-emb+txtenc=bilstm+txtdec=lstm)_visenc-pretr=0_dws=1.0,0.2,0.3,0.25,0.25_orien_chx_ql_amp/mimiccxr_question_probs.pkl',
 '/home/pamessina/medvqa-workspace/results/multimodal/20220913_080931_mim+iu+chexp_imgtxtenc(dense121(niqc)+chx-emb+txtenc=bilstm+txtdec=lstm)_visenc-pretr=0_dws=1.0,0.2,0.3_orien_chx_ql_amp/mimiccxr_question_probs.pkl',
 '/home/pamessina/medvqa-workspace/results/multimodal/20220907_084600_mim+iu+chexp+cxr14+vinbig_imgtxtenc(dense121+chx-emb+txtenc=bilstm+txtdec=lstm)_visenc-pretr=0_dws=1.0,0.2,0.3,0.25,0.25_orien_chx_ql_amp/mimiccxr_question_probs.pkl',
 '/home/pamessina/medvqa-workspace/results/multimodal/20220913_062801_mim+iu+chexp_imgtxtenc(dense121+chx-emb+txtenc=bilstm+txtdec=lstm)_visenc-pretr=0_dws=1.0,0.2,0.3_orien_chx_ql_amp/mimiccxr_question_probs.pkl']

In [3]:
qces = QuestionClassificationEnsembleSearcher(
    probs_paths=qprobs_paths,
    qa_adapted_reports_path=os.path.join(MIMICCXR_CACHE_DIR, 'qa_adapted_reports__20220904_095810.json'),
)

In [4]:
qces.sample_weights(300)

100%|██████████| 300/300 [00:36<00:00,  8.22it/s]


In [5]:
qces.evaluate_best_predictions()

[34mf1(macro)=0.299201708494498, f1(micro)=0.5955972925178366, score=0.8947990010123346[0m


0.8947990010123346

In [12]:
[(x.score, x.threshold, x.weights) for x in qces.minheaps[-5]]

[(0.38461538461538464,
  0.8040123106521324,
  array([0.53496192, 0.10465651, 0.29535309, 0.06502848])),
 (0.38461538461538464,
  0.8073076238056862,
  array([0.51983015, 0.09818563, 0.30610275, 0.07588147])),
 (0.38461538461538464,
  0.8050583495051841,
  array([0.52830783, 0.09907984, 0.29418287, 0.07842946])),
 (0.38461538461538464,
  0.8062306027190893,
  array([0.53439294, 0.09873397, 0.29036742, 0.07650567])),
 (0.38575667655786355,
  0.8088364700182914,
  array([0.523166  , 0.10756509, 0.29305861, 0.07621029])),
 (0.38575667655786355,
  0.8074481627172204,
  array([0.52393398, 0.11163986, 0.29977492, 0.06465124]))]

In [10]:
qces.sample_weights_from_previous_ones(400, 0.1)

100%|██████████| 400/400 [00:49<00:00,  8.10it/s]


In [11]:
qces.evaluate_best_predictions()

[34mf1(macro)=0.3106654827840452, f1(micro)=0.6030184294772752, score=0.9136839122613205[0m


0.9136839122613205

In [13]:
output = qces.compute_best_merged_probs_and_thresholds()

[34mf1(macro)=0.3106654827840452, f1(micro)=0.6030184294772752, score=0.9136839122613205[0m


In [14]:
output.keys()

dict_keys(['merged_probs', 'thresholds', 'score'])

In [17]:
next(iter(output['merged_probs'].values()))

array([1.61876834e-04, 3.87144737e-01, 9.65881010e-02, 2.20055367e-02,
       7.79192984e-01, 1.57798554e-02, 2.36720579e-01, 1.24690644e-01,
       9.53960401e-01, 5.33372968e-04, 1.03566289e-01, 6.99366772e-01,
       2.20276143e-04, 8.34071380e-05, 1.07155428e-02, 1.53473486e-02,
       1.45633232e-01, 7.29708180e-05, 5.46124458e-03, 1.56843920e-02,
       1.75741986e-02, 3.72888526e-04, 1.64513427e-01, 6.59176931e-01,
       2.93133571e-01, 2.40525915e-02, 4.03386703e-01, 1.07309104e-01,
       5.81464687e-03, 1.65350716e-04, 9.91622431e-03, 2.37468525e-01,
       3.89375525e-01, 3.89826048e-01, 5.81195661e-01, 2.29868388e-01,
       1.05608077e-01, 9.34787037e-01, 1.05396855e-01, 2.83167437e-02,
       4.09914864e-01, 8.73935720e-02, 8.95769726e-01, 2.41823181e-02,
       1.40539986e-01, 1.95856220e-03, 1.80294619e-01, 3.78161908e-01,
       1.58542057e-01, 1.36982951e-02, 2.71887577e-01, 2.12285147e-01,
       1.62528389e-01, 5.04716157e-05, 8.93750703e-02, 1.75765395e-03,
      

In [22]:
strings = [
    f'n={len(qprobs_paths)}',
    f'score={output["score"]:.4f}',
    f't={get_timestamp()}',
]
merged_probs_save_path = os.path.join(RESULTS_DIR, 'multimodal', f'mimiccxr_ensemble({",".join(strings)})_probs.pkl')
thresholds_save_path = os.path.join(RESULTS_DIR, 'multimodal', f'mimiccxr_ensemble({",".join(strings)})_thresholds.pkl')

In [24]:
merged_probs_save_path, thresholds_save_path

('/home/pamessina/medvqa-workspace/results/multimodal/mimiccxr_ensemble(n=4,score=0.9137,t=20220914_063812)_probs.pkl',
 '/home/pamessina/medvqa-workspace/results/multimodal/mimiccxr_ensemble(n=4,score=0.9137,t=20220914_063812)_thresholds.pkl')

In [27]:
save_to_pickle(output['merged_probs'], merged_probs_save_path)

In [28]:
save_to_pickle(output['thresholds'], thresholds_save_path)