<p align="center">
  <h1 align="center">CheXRay: Automatically Diagnosing Chest X-Rays using Generated Radiologist Reports and Patient Information </h1>
</p>

In [1]:
from IPython.display import display, HTML, clear_output

In [None]:
!mkdir models/
import gdown
urls = ["https://drive.google.com/u/0/uc?export=download&confirm=7rWd&id=12gqE8PZUn6aE0akNpE_dd-xI_bp3LZ20",
        "https://drive.google.com/u/0/uc?export=download&confirm=7rWd&id=1yS4XJzEI_lIGOraFMoMtqhcCAMxnQzT1",
        "https://drive.google.com/u/0/uc?export=download&confirm=7rWd&id=1Pzhd5qdXYWX7zNYBidO-WHKT0CJGJF1H"]
outputs = ['models/sum.0.0.pth',
           'models/repgen.0.0.pth',
           'models/txtcls.pkl']

gdown.download(urls[0], outputs[0], quiet=True)
gdown.download(urls[1], outputs[1], quiet=True)
gdown.download(urls[2], outputs[2], quiet=True)

clear_output()

In [2]:
#Modules for helper functions
from modules.utils.dicom import PILDicom2 #Because PILDicom from fastai doesn't work
    
#Fast.ai modules
from fastai.data.core import Datasets, DataLoaders
from fastai.data.block import MultiCategoryBlock
from fastai.torch_core import TensorImage, to_np
from fastai.text.core import BaseTokenizer
from fastai.tabular.core import make_date, cont_cat_split, Categorify, FillMissing, TabularPandas

from fastai.vision.augment import Resize, aug_transforms
from fastai.data.transforms import IntToFloatTensor, Normalize, EncodedMultiCategorize, ToTensor
from fastcore.foundation import L
from fastcore.transform import Transform
from fastai.text.data import Numericalize, pad_input, SortedDL

from fastai.vision.models.xresnet import xresnet18
from fastai.text.models.awdlstm import AWD_LSTM
from fastai.learner import Learner
from fastai.vision.learner import cnn_learner
from fastai.tabular.learner import tabular_learner
from fastai.text.learner import text_classifier_learner

#Modules for R2Gen/multimodal
from modules.repgen.dataset import RepGenDataset
from modules.repgen.dataloader import create_batch
from modules.repgen.model import R2GenModel
from modules.repgen.loss import compute_loss
from modules.repgen.fastai_utils import rep_gen, SelectPred
from modules.repgen.metrics import bleu4

#Modules for sum
from modules.sum.dataloader import SumDL  
import modules.sum.logits as log1
from modules.sum.model import SumModel
from modules.sum.loss import SumGradientBlending
from modules.sum.fastai_utils import sum_splitter
from modules.sum.metrics import ap_weighted

#Other libraries
import re
import gc
import html
import torch
import pickle
import multiprocessing
import numpy as np
import pandas as pd
import matplotlib.cm as cm
import matplotlib.pylab as plt
from pathlib import Path
from functools import partial
from datetime import datetime
from types import SimpleNamespace
from fastai_minima.optimizer import Adam
from ipywidgets import VBox,widgets,Button,Layout,Box,Output,Label,FileUpload
import warnings
warnings.filterwarnings('ignore')
import language_tool_python
tool = language_tool_python.LanguageToolPublicAPI('en-US')

In [3]:
def _upload_row(prompt):
    """Build one upload row: a prompt label plus a multi-file DICOM uploader."""
    label = widgets.Label()
    label.value = prompt
    uploader = widgets.FileUpload(multiple=True, accept='.dcm')
    return label, uploader

# One (label, uploader) pair per supported radiographic view.
ap_direct, ap_btn_upload = _upload_row("Upload any AP views here:")
ap_axial_direct, ap_axial_btn_upload = _upload_row("Upload any AP axial views here:")
ap_lld_direct, ap_lld_btn_upload = _upload_row("Upload any AP LLD views here:")
ap_rld_direct, ap_rld_btn_upload = _upload_row("Upload any AP RLD views here:")
pa_direct, pa_btn_upload = _upload_row("Upload any PA views here:")
pa_lld_direct, pa_lld_btn_upload = _upload_row("Upload any PA LLD views here:")
pa_rld_direct, pa_rld_btn_upload = _upload_row("Upload any PA RLD views here:")
lat_direct, lat_btn_upload = _upload_row("Upload any lateral views here:")
ll_direct, ll_btn_upload = _upload_row("Upload any LL views here:")
lao_direct, lao_btn_upload = _upload_row("Upload any LAO views here:")
rao_direct, rao_btn_upload = _upload_row("Upload any RAO views here:")
swim_direct, swim_btn_upload = _upload_row("Upload any swimmers views here:")
xtab_lat_direct, xtab_lat_btn_upload = _upload_row("Upload any xtable lateral views here:")
lpo_direct, lpo_btn_upload = _upload_row("Upload any LPO views here:")

# Progress labels, the final HTML summary, the image output area and the trigger button.
gen_report = widgets.Label()
id_cond = widgets.Label()
gen_sumvis = widgets.Label()
summary = widgets.HTML(value='<style>p{word-wrap: break-word}</style><p>')
out_pl = widgets.Output()
diagnose = widgets.Button(description='Diagnose')

In [4]:
# Paths: training metadata lives under data/, uploaded studies are staged under sample/
prep = Path('./data/')
prod_path = Path('./sample/')

# Findings the classifier predicts; the position in this list is the label index
classes = [
    "Atelectasis", "Cardiomegaly", "Consolidation", "Edema", "Enlarged_Cardiomediastinum",
    "Fracture", "Lung_Lesion", "Lung_Opacity", "No_Finding", "Pleural_Effusion",
    "Pleural_Other", "Pneumonia", "Pneumothorax", "Support_Devices",
]

# Supported views and a parallel list of per-view counts
# (presumably the number of images of each view in the training data - TODO confirm)
views = ['AP', 'AP_AXIAL', 'AP_LLD', 'AP_RLD', 'PA', 'PA_LLD', 'PA_RLD',
         'LATERAL', 'LL', 'LAO', 'RAO', 'SWIMMERS', 'XTABLE_LATERAL', 'LPO']
number_views = [144818, 2, 2, 2, 95145, 2058, 339, 81939, 42371, 5188, 3, 13, 2, 1]

# Everything runs on the CPU
workers = multiprocessing.cpu_count()
defaults = SimpleNamespace(cpus=workers, cmap='viridis', return_fig=False, silent=False,
                           device=torch.device('cpu'))
cpu = torch.device("cpu")
beta = 2

# Display order of every widget: one (label, uploader) row per entry of `views`,
# then the Diagnose button, then the five output widgets.
_upload_rows = [
    (ap_direct, ap_btn_upload),
    (ap_axial_direct, ap_axial_btn_upload),
    (ap_lld_direct, ap_lld_btn_upload),
    (ap_rld_direct, ap_rld_btn_upload),
    (pa_direct, pa_btn_upload),
    (pa_lld_direct, pa_lld_btn_upload),
    (pa_rld_direct, pa_rld_btn_upload),
    (lat_direct, lat_btn_upload),
    (ll_direct, ll_btn_upload),
    (lao_direct, lao_btn_upload),
    (rao_direct, rao_btn_upload),
    (swim_direct, swim_btn_upload),
    (xtab_lat_direct, xtab_lat_btn_upload),
    (lpo_direct, lpo_btn_upload),
]
buttons = [widget for row in _upload_rows for widget in row]
buttons += [diagnose, gen_report, id_cond, gen_sumvis, summary, out_pl]

In [5]:
def view_from_path(path):
    """Return the view name encoded in an uploaded file's staging path.

    Uploads are saved as ``<VIEW>_<n>.dcm``, so ``sample/AP_AXIAL_0.dcm`` maps to ``"AP_AXIAL"``.
    """
    return Path(path).stem.rsplit("_", 1)[0]


def tidy_report(report, views):
    """Turn a space-tokenised generated report into readable sentences.

    Special ``xx...`` tokens are removed, sentences (separated by `` . ``) are re-joined with
    ``". "``, and view names are upper-cased. View names are matched as whole words only, so
    words that merely contain one (e.g. "patient", "small", "bilateral") are left untouched.

    Returns an empty string when the report contains no sentence.
    """
    text = re.sub(r'(\s)xx\w+', "", report, flags=re.IGNORECASE)
    sentences = [sentence.strip(" .") for sentence in text.split(" . ")]
    sentences = [sentence for sentence in sentences if sentence]
    for view in views:
        pattern = r'\b' + re.escape(view.lower()) + r'\b'
        sentences = [re.sub(pattern, view, sentence) for sentence in sentences]
    if not sentences: return ""
    return ". ".join(sentences) + "."


def vote_on_conditions(pos_results, neg_results, classes):
    """Combine per-image predictions into one verdict per condition.

    `pos_results` / `neg_results` hold one ``{condition: confidence}`` dict per image, listing the
    conditions predicted positive / negative for that image. For every condition the majority of
    images wins; a tie goes to the side with the higher maximum confidence, and an exact tie
    goes to the positive side (to be safe). The reported confidence is the winning side's maximum.

    Returns ``(confs_pos, names_pos, confs_neg, names_neg)``, each following the order of `classes`.
    """
    confs_pos, names_pos, confs_neg, names_neg = [], [], [], []
    for condition in classes:
        pos = [dic[condition] for dic in pos_results if condition in dic]
        neg = [dic[condition] for dic in neg_results if condition in dic]
        if not pos and not neg: continue  # no image carries a verdict for this condition
        if len(pos) != len(neg): is_positive = len(pos) > len(neg)
        else: is_positive = max(pos) >= max(neg)
        if is_positive:
            names_pos.append(condition)
            confs_pos.append(max(pos))
        else:
            names_neg.append(condition)
            confs_neg.append(max(neg))
    return confs_pos, names_pos, confs_neg, names_neg


def format_findings(names, confs):
    """Render conditions and their confidences (0-1) as the HTML list used in the summary."""
    lines = [f"{name} ({conf*100:.2f}% confident)" for name, conf in zip(names, confs)]
    if not lines: return ""
    if len(lines) == 1: return lines[0] + ".<br/>"
    return "".join(line + ",<br/>" for line in lines[:-1]) + "and " + lines[-1] + ".<br/>"


def on_click_classify(change):
    """Callback for the Diagnose button: run the whole pipeline on the uploaded DICOM files.

    Steps: (1) stage the uploads under `prod_path`, (2) generate a radiologist-style report with
    the R2Gen model, (3) classify image + report + study-time features with the summation model,
    (4) save Grad-CAM overlays as PNG files next to the staged images and (5) write the report,
    the condition summary and the overlays into the `summary` / `out_pl` widgets.

    `change` is whatever ipywidgets passes to the click handler; it is not used.
    """
    # ------------------------------------------------------------------ 1. stage the uploads
    upload_buttons = buttons[1:-6:2]  # every second widget of the upload rows is a FileUpload
    prod_path.mkdir(parents=True, exist_ok=True)
    input_views = []
    input_paths = []
    for view, button in zip(views, upload_buttons):
        data = [f.content.tobytes() for f in button.value]
        if not data: continue
        input_views.append(view)
        for n, content in enumerate(data):
            temp_path = prod_path/str(view+"_"+str(n)+'.dcm')
            input_paths.append(temp_path)
            # Each staged file gets its own upload's bytes
            with open(temp_path, 'wb') as out_file:
                out_file.write(content)

    if not input_paths:
        # Nothing to diagnose: keep the upload widgets alive so the user can try again
        gen_report.value = "Please upload at least one chest x-ray before pressing Diagnose."
        return

    for button in buttons[:-5]: button.close() # Except output buttons
    gen_report.value = "Writing Report..."

    # ------------------------------------------------------------------ 2. report generation
    single_repgen_trainval_sample_path = prep/'trainval_sample_repgen_nomiss.csv'
    vocab_path = Path('modules/repgen/vocab.pkl')
    trainval_sample_single = pd.read_csv(single_repgen_trainval_sample_path)
    # The sample frame only gives the dataloaders their structure; point every row at an image
    # that is guaranteed to exist (this is AP_0.dcm whenever an AP view was uploaded).
    trainval_sample_single['images'] = input_paths[0]
    trainval_sample_single = trainval_sample_single.drop([10728, 10729])
    train_sample_single = trainval_sample_single[trainval_sample_single['split']==False].reset_index(drop=True)
    val_sample_single = trainval_sample_single[trainval_sample_single['split']==True].reset_index(drop=True)
    with open(vocab_path, 'rb') as f: repgen_vocab = pickle.load(f)

    # One row per uploaded image, borrowing the remaining columns from the sample frame
    df = trainval_sample_single.iloc[:len(input_paths)].copy()
    df['images'] = input_paths

    isval=False
    viewtype='images'
    ispred=False
    train_sample_dataset = RepGenDataset(train_sample_single,isval, viewtype, ispred, classes)
    isval=True
    val_sample_dataset = RepGenDataset(val_sample_single,isval, viewtype, ispred, classes)
    bs=16
    trainval_sample_dls = DataLoaders.from_dsets(train_sample_dataset, val_sample_dataset, bs=bs, create_batch=create_batch, device=cpu, num_workers=workers, shuffle=True)
    trainval_sample_dls.valid = trainval_sample_dls.valid.new(shuffle=False)

    # Model settings (for visual extractor)
    visual_extractor='resnet50' #'resnet101'
    pretrained=True
    # Model settings (for Transformer)
    num_layers=3 #number of layers of Transformer
    d_model=512 #dimension of Transformer
    d_ff=512 #dimension of FFN
    num_heads=8 #number of heads in Transformer
    dropout=0.261 #dropout rate of Transformer
    use_bn = 0 #whether to use batch normalization
    drop_prob_lm = 0.5958
    max_seq_len = 100
    att_feat_size = 2048 #dimension of the patch features (d_vf in main.py)
    # for Relational Memory
    rm_num_slots=3
    rm_num_heads=8
    rm_d_model=512
    # for Sampling
    beam_size = 3 #beam size when beam searching
    group_size = 1
    sample_n = 1 #sample number per image
    sample_method = "beam_search" #sample methods to sample a report
    temperature = 1.0 #temperature when sampling
    output_logsoftmax = 1 #whether to output the probabilities
    decoding_constraint = 0
    block_trigrams = 1
    # More params (not in main.py, but used in original/current)
    diversity_lambda = 0.5
    input_encoding_size = 512
    suppress_UNK = 0
    length_penalty = ''
    mode='forward'
    model = R2GenModel(visual_extractor, pretrained, num_layers, d_model, d_ff, num_heads, dropout, rm_num_slots, rm_num_heads,
                       rm_d_model, repgen_vocab, input_encoding_size, drop_prob_lm, max_seq_len, att_feat_size, use_bn, beam_size,
                       group_size, sample_n, sample_method, temperature, output_logsoftmax, decoding_constraint, block_trigrams,
                       diversity_lambda, suppress_UNK, length_penalty, mode)
    model = model.to(cpu)

    wd=3.734e-0

    learn = Learner(trainval_sample_dls, model, loss_func=compute_loss, wd=wd, splitter=rep_gen, metrics=[bleu4], cbs=SelectPred)
    ve_params = list(map(id, model.visual_extractor.parameters()))
    ed_params = filter(lambda x: id(x) not in ve_params, model.parameters())
    opt_func = Adam([{'params': model.visual_extractor.parameters()},
                   {'params': ed_params}],
                   lr = learn.lr,
                   betas=(0.9, 0.98),
                   weight_decay=wd,
                   amsgrad=True)
    learn.opt = opt_func

    learn.load("repgen.0.0", device=cpu)
    learn.model.mode='sample'

    def passfunc(arg):
        """Identity, so that learn.predict does not try to decode anything."""
        return arg

    # Built once instead of once per decoded word
    special_tokens = {token for token in repgen_vocab if token[:2]=="xx"}
    def decode(pred):
        """Convert predicted vocabulary indices into a space-joined report, dropping `xx...` tokens."""
        words = []
        for report in pred:
            for word in report: # word = index into the vocab
                txtword = repgen_vocab[word]
                if txtword not in special_tokens: words.append(txtword)
        return " ".join(words)
    learn.dls.decode = passfunc
    learn.dls.decode_batch = passfunc

    # The report is generated from the uploaded view with the highest count in `number_views`
    # (first one wins on ties), using the first uploaded image of that view.
    rep_input_view = max(input_views, key=lambda view: number_views[views.index(view)])
    image = next(path for path in input_paths if view_from_path(path)==rep_input_view)

    idx = df[df['images']==image].index[0]
    ispred=True
    pred_dataset = RepGenDataset(df.iloc[idx:idx+1], isval, 'images', ispred, classes)
    gts, rep, _ = learn.predict(pred_dataset[idx])
    report = decode(rep)
    # Only add a sentence terminator when the report does not already end with one
    if report[-2:] != " ." and report[-2:] != ". ": report = report + ' . '
    df['reports'] = report

    # Free the report generator before building the classifier
    del trainval_sample_dls
    del model
    del learn
    del passfunc
    gc.collect()

    # Study-time features, all taken from one timestamp so they are mutually consistent
    now = datetime.now()
    df['StudyElapsed'] = now.strftime('%Y-%m-%d')
    make_date(df, 'StudyElapsed')
    df['Minutes'] = now.minute
    df['Hour'] = now.hour
    df['Seconds'] = now.second
    df['StudyWeek'] = now.isocalendar()[1]
    df['StudyDay'] = now.day
    df['StudyDayofweek'] = now.isocalendar()[2]
    df['StudyDayofyear'] = now.timetuple().tm_yday
    df['StudyElapsed'] = df['StudyElapsed'].values.astype(np.int64) // 10 ** 9 # ns -> s

    gen_report.close()
    id_cond.value = "Identifying Conditions..."

    # ------------------------------------------------------------------ 3. classification
    size=224
    seq_len=72
    bs=16
    val_bs=len(val_sample_single)
    test_bs=len(df)
    train_dls = []
    val_dls = []
    test_dls = []

    with open(Path('./modules/txtcls/vocab.pkl'), 'rb') as f: txtcls_vocab = pickle.load(f)
    cont_nn,cat_nn = cont_cat_split(trainval_sample_single, max_card=365, dep_var=classes)
    time_cols = ['Minutes', 'Hour', 'Seconds', 'StudyWeek', 'StudyDay', 'StudyDayofweek', 'StudyDayofyear', 'StudyElapsed']
    for frame in [trainval_sample_single, df]:
        frame[time_cols] = frame[time_cols].astype('int32')

    trainval_imgs = list(trainval_sample_single['images'])
    test_imgs = list(df['images'])

    def get_names(pand):
        """Wrap every report of `pand` in its own one-element list."""
        fnames = list(pand['reports'])
        fnames = [[text] for text in fnames]
        return L(fnames)
    trainval_txts = get_names(trainval_sample_single)
    test_txts = get_names(df)

    def formatting(tokens): return list(tokens)[0]
    tfm = Transform(formatting)

    def get_labels(fname):
        """Placeholder target (only the last class set); real labels are not needed for inference."""
        return [0]*13 + [1]

    def vis_dls(bs, name_list):
        """Image dataloaders over the DICOM paths in `name_list`."""
        imagenet_stats = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        dsets = Datasets(name_list,
                         [[PILDicom2.create],
                          [get_labels, EncodedMultiCategorize(vocab=classes)]],
                         splits=None)
        item_tfms=[Resize(460), ToTensor]
        batch_tfms=[IntToFloatTensor(div=2**16-1), Normalize.from_stats(*imagenet_stats), *aug_transforms(size=size)]
        return dsets.dataloaders(after_item=item_tfms, after_batch=batch_tfms, bs=bs, num_workers=workers)

    def txt_dls(bs, seq_len, name_list):
        """Text dataloaders over the reports in `name_list`."""
        dsets = Datasets(name_list,
                         [[BaseTokenizer(), tfm, Numericalize(vocab=txtcls_vocab)],
                          [get_labels, EncodedMultiCategorize(vocab=classes)]],
                         splits=None)
        return dsets.dataloaders(bs=bs, seq_len=seq_len, num_workers=workers, before_batch=pad_input, dl_type=SortedDL)

    def tab_dls(bs, path):
        """Tabular dataloaders over the continuous columns of the frame `path`."""
        procs_nn = [Categorify, FillMissing, Normalize]
        return TabularPandas(path, procs_nn, None, cont_nn, splits=None, y_block=MultiCategoryBlock(encoded=True, vocab=classes),
                              y_names=classes).dataloaders(bs, num_workers=workers)

    def get_dls(istest):
        """Return the combined image/text/tabular loader: for the uploads if `istest`, else train/valid."""
        if istest:
            test_dls.append(vis_dls(test_bs, test_imgs)[0].to("cpu"))
            test_dls.append(txt_dls(test_bs, seq_len, test_txts)[0].to("cpu"))
            test_dls.append(tab_dls(test_bs, df)[0].to("cpu"))
            return SumDL(*test_dls, device=cpu)
        else:
            train_dls.append(vis_dls(bs, trainval_imgs)[0])
            val_dls.append(vis_dls(val_bs, trainval_imgs)[1])
            train_dls.append(txt_dls(bs, seq_len, trainval_txts)[0])
            val_dls.append(txt_dls(val_bs, seq_len, trainval_txts)[1])
            train_dls.append(tab_dls(bs, train_sample_single)[0])
            val_dls.append(tab_dls(val_bs, val_sample_single)[1])
            return DataLoaders(SumDL(*train_dls, device=cpu), SumDL(*val_dls, device=cpu))
    mixed_dls = get_dls(False)

    def calcHiddenLayer(data, alpha, numHiddenLayers):
        """Hidden-layer sizes of the tabular model, derived from len(data) and the input/output widths."""
        i, o = len(list(trainval_sample_single.columns)[2:10]), len(classes)
        io = i+o
        return [(len(data)//(alpha*(io)))//numHiddenLayers]*numHiddenLayers

    drop_mult=0.3263
    vis_arch=xresnet18
    alpha=2
    numHiddenLayers=2
    layers=calcHiddenLayer(train_dls[-1], alpha, numHiddenLayers)
    txtcls_learn = text_classifier_learner(txt_dls(bs, seq_len, trainval_txts), AWD_LSTM, drop_mult=drop_mult)

    sum_model = SumModel(cnn_learner(vis_dls(bs, trainval_imgs), vis_arch).model,
                     txtcls_learn.model,
                     tabular_learner(tab_dls(bs, trainval_sample_single), layers=layers).model,
                     len(classes))

    # Set loss_scale for each loss
    weights = [3/17, 9/17, 1/17, 4/17]
    loss_scale = 1.07
    loss = SumGradientBlending(loss_scale, *weights)

    thresh=0.43

    ap_w = partial(ap_weighted, weights=weights)
    metrics = [ap_w]

    sum_learn = Learner(mixed_dls.to("cpu"), sum_model.to("cpu"), loss, splitter=sum_splitter, metrics=metrics)
    sum_learn.freeze_to(-4)
    name = 'sum.0.0'
    sum_learn.load(name, device=cpu)
    sum_learn.dls = sum_learn.dls.to(cpu)
    sum_learn.model = sum_learn.model.to(cpu)

    pred_mixed_dls = get_dls(True)
    preds,_ = sum_learn.get_preds(dl=pred_mixed_dls)

    def decode_prob(preds):
        """Blend the per-branch outputs with `weights` and squash them to probabilities."""
        all_inp=0
        preds = torch.stack(preds)
        for weight in range(len(weights)): all_inp += preds[weight] * weights[weight]
        preds = all_inp/len(weights)
        preds = preds.sigmoid()
        return preds
    def decode_rep(preds, thresh=0.5):
        """Binarise the blended probabilities at `thresh`."""
        preds = decode_prob(preds)
        preds[preds>=thresh] = 1
        preds[preds<thresh] = 0
        return preds

    confs = decode_prob(preds)
    class_preds = decode_rep(preds, thresh)

    # Keep at most two views for the summary, because most studies have two views.
    # NOTE(review): a later view replaces the kept view with the smallest count only when its own
    # count is even smaller ("<"), whereas the report step above prefers the largest count.
    # Left unchanged - confirm which selection is intended.
    sum_input_views = []
    avg_num_view = 2
    if len(input_views)>avg_num_view:
        for view in input_views:
            if len(sum_input_views)<avg_num_view: sum_input_views.append(view)
            else:
                compare = [number_views[views.index(sum_input_views[j])] for j in range(avg_num_view)]
                if number_views[views.index(view)] < min(compare):
                    sum_input_views[compare.index(min(compare))] = view
    else: sum_input_views.extend(input_views)
    num_view = [number_views[views.index(view)] for view in sum_input_views]
    sum_input_views = [x for _, x in sorted(zip(num_view, sum_input_views))]

    # Images of the selected views; those of the view with the larger count go to the front
    sum_input_imgs = []
    sum_input_idxs = []
    less_view_count = 0
    for img in input_paths:
        img_view = view_from_path(img)
        if img_view == sum_input_views[0]:
            sum_input_imgs.append(img)
            sum_input_idxs.append(df[df['images']==img].index[0])
        if len(sum_input_views)>1:
            if img_view == sum_input_views[1]:
                sum_input_imgs.insert(less_view_count, img)
                sum_input_idxs.insert(less_view_count, df[df['images']==img].index[0])
                less_view_count+=1

    # Position of every selected image in the order reported by the image dataloader
    dl_list_idxs = list(test_dls[0].get_idxs())
    pred_list_idxs = [dl_list_idxs.index(i) for i in sum_input_idxs]

    def get_results(is_pos):
        """Per selected image, map each condition predicted `is_pos` (1/0) to the confidence in that verdict."""
        results = []
        for i in pred_list_idxs:
            found = {}
            for j in range(len(class_preds[i])):
                if class_preds[i][j]==is_pos:
                    conf = confs[i][j].item()
                    found[classes[j]] = conf if is_pos else 1 - conf
            results.append(found)
        return results

    def get_summary():
        """One verdict per condition across all selected images (see `vote_on_conditions`)."""
        return vote_on_conditions(get_results(1), get_results(0), classes)

    confs_select, class_names, confs_select_neg, class_names_neg = get_summary()

    # Most confident first; names and confidences are sorted together so they stay aligned
    ranked_pos = sorted(zip(confs_select, class_names), reverse=True)
    confs_select = [conf for conf, _ in ranked_pos]
    class_names = [cls_name for _, cls_name in ranked_pos]
    ranked_neg = sorted(zip(confs_select_neg, class_names_neg), reverse=True)
    confs_select_neg = [conf for conf, _ in ranked_neg]
    class_names_neg = [cls_name for _, cls_name in ranked_neg]

    id_cond.close()
    gen_sumvis.value = "Summarizing Diagnosis..."

    # ------------------------------------------------------------------ 4. Grad-CAM overlays
    # Visualise at most `max_memory_num_imgs` distinct images that show a positive finding,
    # starting with the images that carry the most confident finding.
    max_memory_num_imgs = 2
    pos_results = get_results(1)
    idxs = []
    for condition in class_names:
        for img_pos, found in enumerate(pos_results):
            if condition in found and img_pos not in idxs and len(idxs) < max_memory_num_imgs:
                idxs.append(img_pos)
    sum_input_imgs = [sum_input_imgs[img_pos] for img_pos in idxs]
    class_idxes = [classes.index(i) for i in class_names]

    max_mem_num_cond = 3
    if len(class_idxes)>max_mem_num_cond:
        class_idxes = class_idxes[:max_mem_num_cond]

    def show_gradcam(learn, x, thresh):
        """Save one Grad-CAM overlay per (image, condition) as ``<image path>,<condition>.png``."""
        class Hook():
            def __init__(self, m): self.hook = m.register_forward_hook(self.hook_func)
            def hook_func(self, m, i, o): self.stored = o.detach().clone()
            def __enter__(self, *args): return self
            def __exit__(self, *args): self.hook.remove()
        class HookBwd():
            def __init__(self, m):
                self.hook = m.register_backward_hook(self.hook_func)
            def hook_func(self, m, gi, go): self.stored = go[0].detach().clone()
            def __enter__(self, *args): return self
            def __exit__(self, *args): self.hook.remove()

        def cmap(class_idx):
            """Activations and gradients of the hooked module for `class_idx`."""
            # NOTE(review): output[0][0] (and act[0]/grad[0] below) always read the first sample of
            # the batch, whichever image is being drawn - confirm this matches the batch order.
            with HookBwd(learn.model.models[0][0]) as hookg:
                with Hook(learn.model.models[0][0]) as hook:
                    output = learn.model.eval()(*x[:-1])
                    act = hook.stored
                output[0][0][class_idx].backward()
                grad = hookg.stored
            return act, grad

        for img in sum_input_imgs:
            x_dec = TensorImage(PILDicom2.create(img))
            for idx in class_idxes:
                act, grad = cmap(idx)
                w = grad[0].mean(dim=[1,2], keepdim=True)
                cam_map = (w * act[0]).sum(0)
                fig, ax = plt.subplots()
                x_dec.show(ctx=ax, cmap='gray')
                ax.imshow(cam_map.detach().cpu(), alpha=0.6, extent=(0,x_dec.shape[0],x_dec.shape[1],0), interpolation='bilinear', cmap='magma')
                fig.savefig(Path(str(img)+","+classes[idx]+'.png'), bbox_inches='tight')
                plt.close(fig) # figures are only needed on disk; do not let them pile up in memory

    a = pred_mixed_dls.one_batch()
    show_gradcam(sum_learn, a, thresh)

    # ------------------------------------------------------------------ 5. display
    def display_both(learn, x, thresh):
        """Write the report and condition summary into `summary` and show the overlay table in `out_pl`."""
        data = pd.DataFrame(columns = ['Conditions'])
        for row, idx in enumerate(class_idxes):
            for img in sum_input_imgs:
                data.loc[row, Path(img).stem] = str(img)+","+classes[idx]+'.png'
            data.loc[row, 'Conditions'] = classes[idx]
        data = data.set_index('Conditions')

        # Converting links to html tags
        def path_to_image_html(path):
            img = str(path).split(",")[0]
            x_dec = TensorImage(PILDicom2.create(img))
            return '<img src="'+ path + '" width="'+ str(int(x_dec.shape[0])) + '" height="'+ str(int(x_dec.shape[1])) + '">'

        out_pl.clear_output()
        gen_sumvis.close()

        # Build the text once and assign it once, so the widget is updated a single time
        parts = ["Generated Radiologist Report:<br/>"]
        parts.append(tidy_report(df.loc[0, 'reports'], views) + "<br/><br/>")

        parts.append("Condition Summary:<br/>")
        parts.append("Given a confidence threshold of "+str(thresh)+",<br/> which is the minimum confidence the model must have in order to give a positive diagnosis for a disease,<br/> and is the ideal confidence for maximizing the F" + str(beta) + " score,<br/>")
        if len(class_names)<1:
            parts.append("this patient's condition cannot be determined. Please contact them to collect another set of x-rays.<br/>")
        else:
            parts.append("this patient most likely needs to get checked out for the following conditions:<br/>")
            parts.append(format_findings(class_names, confs_select))
            if len(class_names)<len(classes) and class_names_neg:
                parts.append("<br/>This patient most likely doesn't need to get checked out for the following conditions:<br/>")
                parts.append(format_findings(class_names_neg, confs_select_neg))
        parts.append(' </p>')
        summary.value += "".join(parts)

        # Rendering the dataframe as HTML table, one column of overlays per image
        formatters = {Path(img).stem: path_to_image_html for img in sum_input_imgs}
        with out_pl: display(HTML(data.to_html(escape=False, formatters=formatters)))
    display_both(sum_learn, a, thresh)
# Register on_click_classify as the click handler of the Diagnose button.
diagnose.on_click(on_click_classify)

In [6]:
# Show the app: `buttons` already lists every widget in display order
# (upload rows, the Diagnose button, then the output widgets).
app_layout = Layout(width='100%', display='flex', align_items='center')
VBox(buttons, layout=app_layout)

VBox(children=(Label(value='Upload any AP views here:'), FileUpload(value=(), accept='.dcm', description='Uplo…