# In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Recursively walk the Kaggle input mount and print the full path of every file found.
for folder, _, names in os.walk('/kaggle/input'):
    for name in names:
        print(os.path.join(folder, name))

# You can write up to 5GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session



def _pair_softmax_logits(pair):
    """Return ``logit(softmax(pair, dim=1))`` for a two-column logit tensor.

    For two logits ``a`` and ``b`` that quantity is exactly ``[a - b, b - a]``.
    Computing the difference directly avoids the ``softmax`` -> ``log(p / (1 - p))``
    round trip, which yields +/-inf once softmax saturates to exactly 0 or 1.
    """
    return pair - pair[:, [1, 0]]


def post_process(exam_pred, image_pred, *, exam_target_cols=None):
    """Adjust exam-level logits in place so they agree with the image-level logits.

    All values are treated as logits, so ``> 0`` corresponds to a probability
    ``> 0.5``. An exam "has a PE image" when at least one of its image logits
    is ``> 0``.

    For exams that have a PE image:
      * 1-a: the two rv/lv ratio logits are replaced by their pairwise-softmax
        logits (``a - b`` and ``b - a``), so at most one of them stays positive.
      * 1-b: if none of central/rightsided/leftsided is positive, the largest
        one (every tied maximum) is set to ``0.0001``.
      * 1-c: if acute_and_chronic and chronic are both positive, they are
        replaced by their pairwise-softmax logits.
      * 1-d: negative_exam_for_pe and indeterminate are clamped to ``<= 0``.

    For exams without a PE image:
      * 2-a: negative_exam_for_pe and indeterminate are replaced by their
        pairwise-softmax logits.
      * 2-b: the seven remaining exam logits are clamped to ``<= 0``.

    Args:
        exam_pred: float tensor of exam logits with one row per exam and one
            column per entry of ``exam_target_cols``. Modified in place.
        image_pred: tensor of image logits with at least two dimensions. The
            first dimension is taken to index exams and all remaining
            dimensions are flattened into "images of that exam".
            NOTE(review): the previous implementation only looked at the first
            exam (``[0]``); this assumes callers pass one row per exam in the
            same order as ``exam_pred`` -- confirm against the caller.
        exam_target_cols: optional sequence of the nine exam label names in
            column order. Defaults to ``CFG['exam_target_cols']``.

    Returns:
        The tuple ``(exam_pred, image_pred)``; both are the objects passed in.

    Raises:
        ValueError: if one of the nine expected label names is missing from
            ``exam_target_cols``.
    """
    if exam_target_cols is None:
        exam_target_cols = CFG['exam_target_cols']
    cols = list(exam_target_cols)

    # Column groups, in the same order the rules below treat them.
    rv_lv_ix = [cols.index('rv_lv_ratio_lt_1'), cols.index('rv_lv_ratio_gte_1')]
    crl_ix = [cols.index('central_pe'),
              cols.index('rightsided_pe'),
              cols.index('leftsided_pe')]
    ac_ix = [cols.index('acute_and_chronic_pe'), cols.index('chronic_pe')]
    neg_ind_ix = [cols.index('negative_exam_for_pe'), cols.index('indeterminate')]

    # Per-exam flag, kept as a (num_exams, 1) column so it broadcasts across
    # the label columns of each row rather than across rows.
    has_pe_image = image_pred.flatten(start_dim=1).max(dim=1, keepdim=True)[0] > 0
    no_pe_image = ~has_pe_image

    # rule 1-a
    rv_lv = exam_pred[:, rv_lv_ix]
    exam_pred[:, rv_lv_ix] = torch.where(
        has_pe_image, _pair_softmax_logits(rv_lv), rv_lv)

    # rule 1-b: keepdim=True so each row is compared with its own maximum.
    crl = exam_pred[:, crl_ix]
    row_max = crl.max(dim=1, keepdim=True)[0]
    none_positive = row_max <= 0
    crl_bumped = torch.where(crl == row_max, torch.full_like(crl, 0.0001), crl)
    exam_pred[:, crl_ix] = torch.where(
        has_pe_image & none_positive, crl_bumped, crl)

    # rule 1-c
    ac = exam_pred[:, ac_ix]
    both_positive = ac.min(dim=1, keepdim=True)[0] > 0
    exam_pred[:, ac_ix] = torch.where(
        has_pe_image & both_positive, _pair_softmax_logits(ac), ac)

    # rules 1-d and 2-a act on the same two columns under opposite conditions,
    # so a single select applies whichever one is relevant to each exam.
    neg_ind = exam_pred[:, neg_ind_ix]
    exam_pred[:, neg_ind_ix] = torch.where(
        has_pe_image,
        torch.clamp(neg_ind, max=0),
        _pair_softmax_logits(neg_ind))

    # rule 2-b
    other_ix = rv_lv_ix + crl_ix + ac_ix
    other = exam_pred[:, other_ix]
    exam_pred[:, other_ix] = torch.where(
        no_pe_image, torch.clamp(other, max=0), other)

    return exam_pred, image_pred