<p align="center">
  <h1 align="center">CheXRay: Automatically Diagnosing Chest X-Rays using Generated Radiologist Reports and Patient Information </h1>
</p>

In [None]:
# Download the three pretrained model files from Google Drive into models/.
# NOTE(review): assumes the models/ directory already exists — `mv` fails otherwise.
import os
# 1) -> models/sum.0.0.pth
google_drive_url = "https://drive.google.com/file/d/1f3GKUpjPKra4NZk5Pjj9OmNcvNkPudzA/view?usp=sharing"
# The file id is the 6th '/'-separated component of the share URL (.../file/d/<id>/view...).
os.environ['GOOGLE_FILE_ID'] = google_drive_url.split('/')[5]
# Exported through os.environ so the shell command below can expand $GDRIVE_URL.
os.environ['GDRIVE_URL'] = f'https://docs.google.com/uc?export=download&id={os.environ["GOOGLE_FILE_ID"]}'
# Recursive fetch keeping only 'uc*' files — presumably so Google's large-file
# confirmation page is followed through to the real download; TODO confirm.
!wget -q --no-check-certificate $GDRIVE_URL -r -A 'uc*' -e robots=off -nd
# `ls -S` lists largest first, so the biggest 'uc*' file is taken as the model.
!mv $(ls -S uc* | head -1) models/sum.0.0.pth

# Clear leftover 'uc*' files so they do not confuse the next `ls -S` pick.
!rm -f uc*
# 2) -> models/repgen.0.0.pth (report-generation weights, loaded later via learn.load("repgen.0.0"))
google_drive_url = "https://drive.google.com/file/d/13bdY9r8vzpzw_V086ujSwO5WQGz4yo59/view?usp=sharing"
os.environ['GOOGLE_FILE_ID'] = google_drive_url.split('/')[5]
os.environ['GDRIVE_URL'] = f'https://docs.google.com/uc?export=download&id={os.environ["GOOGLE_FILE_ID"]}'
!wget -q --no-check-certificate $GDRIVE_URL -r -A 'uc*' -e robots=off -nd
!mv $(ls -S uc* | head -1) models/repgen.0.0.pth

!rm -f uc*
# 3) -> models/txtcls.pkl
google_drive_url = "https://drive.google.com/file/d/1Pzhd5qdXYWX7zNYBidO-WHKT0CJGJF1H/view?usp=sharing"
os.environ['GOOGLE_FILE_ID'] = google_drive_url.split('/')[5]
os.environ['GDRIVE_URL'] = f'https://docs.google.com/uc?export=download&id={os.environ["GOOGLE_FILE_ID"]}'
!wget -q --no-check-certificate $GDRIVE_URL -r -A 'uc*' -e robots=off -nd
!mv $(ls -S uc* | head -1) models/txtcls.pkl

!rm -f uc*

In [None]:
#%load_ext memory_profiler
#from production import run
#%mprun -f run run()

In [None]:
#Importing libraries and setup
#Modules for helper functions
from modules.utils.dicom import * #Because PILDicom from fastai doesn't work
from modules.utils.tokenizers import *

#Modules for fastai.vis
#!pip install -q pydicom pyarrow kornia opencv-python scikit-image nbdev
from fastai.basics import *
from fastai.callback.all import *
from fastai.medical.imaging import *
from fastai.vision.widgets import *

#Modules for fastai.text
from fastai.text.all import *

#Modules for fastai.tab
from fastai.tabular.all import *

#Modules for R2Gen/multimodal
from modules.repgen.dataset import RepGenDataset
from modules.repgen.dataloader import *
import modules.repgen.logits as log
from modules.repgen.model import *
from modules.repgen.loss import *
from modules.repgen.fastai_utils import *
from modules.repgen.metrics import bleu4

#Modules for sum
from modules.sum.dataloader import SumDL  
import modules.sum.logits as log1
from modules.sum.model import *
from modules.sum.loss import *
from modules.sum.fastai_utils import *
from modules.sum.metrics import *

#Other libraries
import html
import os
import re
import gc
import matplotlib.cm as cm
import copy as cp
import matplotlib.pylab as plt
from IPython.display import Image, display, HTML, clear_output
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')
import language_tool_python
tool = language_tool_python.LanguageToolPublicAPI('en-US')

In [None]:
#Making Path object which contains path to data
prep = Path('./data/')          # bundled reference CSVs (sample/test frames, etc.)
prod_path = Path('./sample/')   # uploaded .dcm files are written here

# Labels predicted by the models; also used as the one-hot label column names.
classes = [
    "Atelectasis",
    "Cardiomegaly",
    "Consolidation",
    "Edema",
    "Enlarged_Cardiomediastinum",
    "Fracture",
    "Lung_Lesion",
    "Lung_Opacity",
    "No_Finding",
    "Pleural_Effusion",
    "Pleural_Other",
    "Pneumonia",
    "Pneumothorax",
    "Support_Devices",
]

# Every supported view position; each one is an image column of the input frame.
views = [
    'AP', 'AP_AXIAL', 'AP_LLD', 'AP_RLD',
    'PA', 'PA_LLD', 'PA_RLD',
    'LATERAL', 'LL', 'LAO', 'RAO', 'SWIMMERS', 'XTABLE_LATERAL', 'LPO',
]

# Everything in this notebook runs on the CPU.
defaults.device = torch.device('cpu')
device = torch.device("cpu")
cpu = torch.device("cpu")

workers = multiprocessing.cpu_count()  # worker count for non-test DataLoaders
cpu = torch.device("cpu")

In [None]:
# Introductory HTML shown above the upload widgets. The fragments are joined once
# rather than appended to the widget's value with repeated `+=`.
heading = widgets.HTML(value="".join([
    "<style>p{word-wrap: break-word}</style><p>",
    "This program takes the patient's chest x-ray(s), formatted as .dcm files, as input and<br/>",
    "1) generates a radiologist report using the chest x-ray(s),<br/>",
    "2) generates tabular data using the time and date,<br/>",
    "3) and generates heatmap and intrinsic attention visualizations which represent the diagnosis for the patient using the above-mentioned data.</p>",
    "Notes:<br/>",
    "- To upload multiple instances of a view, select all of the instances and upload them all at once.<br/>",
    "- Although images are saved as files in this website, they are only within your environment.<br/>",
    "- Depending on the number of views and images you input, you can expect the program to complete within 10-140 minutes.<br/>",
    # Duplicated word removed ("at the the following email").
    "If you have any questions or concerns, please contact the author at the following email: ajhinh@gmail.com.",
]))

In [None]:
ap_direct = widgets.Label()
ap_direct.value = "If (an) AP view instance(s) is/are available, upload it/them here:"

In [None]:
ap_btn_upload = widgets.FileUpload(multiple=True, accept='.dcm') #, 

In [None]:
ap_axial_direct = widgets.Label()
ap_axial_direct.value = "If (an) AP axial view instance(s) is/are available, upload it/them here:"

In [None]:
ap_axial_btn_upload = widgets.FileUpload(multiple=True, accept='.dcm') 

In [None]:
ap_lld_direct = widgets.Label()
ap_lld_direct.value = "If (an) AP LLD view instance(s) is/are available, upload it/them here:"

In [None]:
ap_lld_btn_upload = widgets.FileUpload(multiple=True, accept='.dcm') #, 

In [None]:
ap_rld_direct = widgets.Label()
ap_rld_direct.value = "If (an) AP RLD view instance(s) is/are available, upload it/them here:"

In [None]:
ap_rld_btn_upload = widgets.FileUpload(multiple=True, accept='.dcm') 

In [None]:
pa_direct = widgets.Label()
pa_direct.value = "If (a) PA view instance(s) is/are available, upload it/them here:"

In [None]:
pa_btn_upload = widgets.FileUpload(multiple=True, accept='.dcm') #, 

In [None]:
pa_lld_direct = widgets.Label()
pa_lld_direct.value = "If (a) PA LLD view instance(s) is/are available, upload it/them here:"

In [None]:
pa_lld_btn_upload = widgets.FileUpload(multiple=True, accept='.dcm') 

In [None]:
pa_rld_direct = widgets.Label()
pa_rld_direct.value = "If (a) PA RLD view instance(s) is/are available, upload it/them here:"

In [None]:
pa_rld_btn_upload = widgets.FileUpload(multiple=True, accept='.dcm') 

In [None]:
lat_direct = widgets.Label()
lat_direct.value = "If (a) lateral view instance(s) is/are available, upload it/them here:"

In [None]:
lat_btn_upload = widgets.FileUpload(multiple=True, accept='.dcm') 

In [None]:
ll_direct = widgets.Label()
ll_direct.value = "If (a) LL view instance(s) is/are available, upload it/them here:"

In [None]:
ll_btn_upload = widgets.FileUpload(multiple=True, accept='.dcm') 

In [None]:
lao_direct = widgets.Label()
lao_direct.value = "If (a) LAO view instance(s) is/are available, upload it/them here:"

In [None]:
lao_btn_upload = widgets.FileUpload(multiple=True, accept='.dcm') 

In [None]:
rao_direct = widgets.Label()
rao_direct.value = "If (a) RAO view instance(s) is/are available, upload it/them here:"

In [None]:
rao_btn_upload = widgets.FileUpload(multiple=True, accept='.dcm') 

In [None]:
swim_direct = widgets.Label()
swim_direct.value = "If (a) swimmers view instance(s) is/are available, upload it/them here:"

In [None]:
swim_btn_upload = widgets.FileUpload(multiple=True, accept='.dcm') 

In [None]:
xtab_lat_direct = widgets.Label()
xtab_lat_direct.value = "If (a) xtable lateral view instance(s) is/are available, upload it/them here:"

In [None]:
xtab_lat_btn_upload = widgets.FileUpload(multiple=True, accept='.dcm') 

In [None]:
lpo_direct = widgets.Label()
lpo_direct.value = "If (a) LPO view instance(s) is/are available, upload it/them here:"

In [None]:
lpo_btn_upload = widgets.FileUpload(multiple=True, accept='.dcm') 

In [None]:
summary = widgets.HTML(value='<style>p{word-wrap: break-word}</style><p>')

In [None]:
out_pl = widgets.Output()

In [None]:
diagnose = widgets.Button(description='Diagnose')

In [None]:
def on_click_classify(change):
    input_views = []
    input_paths = []
    
    if ap_btn_upload.data!=[]:
        input_views.append("AP")
        for path in range(len(ap_btn_upload.data)):
            temp_path = prod_path/str('AP_'+str(path)+'.dcm')
            input_paths.append(temp_path)
            with open(temp_path, 'wb') as f: 
                f.write(ap_btn_upload.value[list(ap_btn_upload.value.keys())[0]]['content'])
    if ap_axial_btn_upload.data!=[]:
        input_views.append("AP_AXIAL")
        for path in range(len(ap_axial_btn_upload.data)):
            temp_path = prod_path/str('AP_AXIAL_'+str(path)+'.dcm')
            input_paths.append(temp_path)
            with open(temp_path, 'wb') as f: 
                f.write(ap_axial_btn_upload.value[list(ap_axial_btn_upload.value.keys())[0]]['content'])
    if ap_lld_btn_upload.data!=[]:
        input_views.append("AP_LLD")
        for path in range(len(ap_lld_btn_upload.data)):
            temp_path = prod_path/str('AP_LLD_'+str(path)+'.dcm')
            input_paths.append(temp_path)
            with open(temp_path, 'wb') as f: 
                f.write(ap_lld_btn_upload.value[list(ap_lld_btn_upload.value.keys())[0]]['content'])
    if ap_rld_btn_upload.data!=[]:
        input_views.append("AP_RLD")
        for path in range(len(ap_rld_btn_upload.data)):
            temp_path = prod_path/str('AP_RLD_'+str(path)+'.dcm')
            input_paths.append(temp_path)
            with open(temp_path, 'wb') as f: 
                f.write(ap_rld_btn_upload.value[list(ap_rld_btn_upload.value.keys())[0]]['content'])
    if pa_btn_upload.data!=[]:
        input_views.append("PA")
        for path in range(len(pa_btn_upload.data)):
            temp_path = prod_path/str('PA_'+str(path)+'.dcm')
            input_paths.append(temp_path)
            with open(temp_path, 'wb') as f: 
                f.write(pa_btn_upload.value[list(pa_btn_upload.value.keys())[0]]['content'])
    if pa_lld_btn_upload.data!=[]:
        input_views.append("PA_LLD")
        for path in range(len(pa_lld_btn_upload.data)):
            temp_path = prod_path/str('PA_LLD_'+str(path)+'.dcm')
            input_paths.append(temp_path)
            with open(temp_path, 'wb') as f: 
                f.write(pa_lld_btn_upload.value[list(pa_lld_btn_upload.value.keys())[0]]['content'])
    if pa_rld_btn_upload.data!=[]:
        input_views.append("PA_RLD")
        for path in range(len(pa_rld_btn_upload.data)):
            temp_path = prod_path/str('PA_RLD_'+str(path)+'.dcm')
            input_paths.append(temp_path)
            with open(temp_path, 'wb') as f: 
                f.write(pa_rld_btn_upload.value[list(pa_rld_btn_upload.value.keys())[0]]['content'])
    if lat_btn_upload.data!=[]:
        input_views.append("LATERAL")
        for path in range(len(lat_btn_upload.data)):
            temp_path = prod_path/str('LATERAL_'+str(path)+'.dcm')
            input_paths.append(temp_path)
            with open(temp_path, 'wb') as f: 
                f.write(lat_btn_upload.value[list(lat_btn_upload.value.keys())[0]]['content'])
    if ll_btn_upload.data!=[]:
        input_views.append("LL")
        for path in range(len(ll_btn_upload.data)):
            temp_path = prod_path/str('LL_'+str(path)+'.dcm')
            input_paths.append(temp_path)
            with open(temp_path, 'wb') as f: 
                f.write(ll_btn_upload.value[list(ll_btn_upload.value.keys())[0]]['content'])
    if lao_btn_upload.data!=[]:
        input_views.append("LAO")
        for path in range(len(lao_btn_upload.data)):
            temp_path = prod_path/str('LAO_'+str(path)+'.dcm')
            input_paths.append(temp_path)
            with open(temp_path, 'wb') as f: 
                f.write(lao_btn_upload.value[list(lao_btn_upload.value.keys())[0]]['content'])
    if rao_btn_upload.data!=[]:
        input_views.append("RAO")
        for path in range(len(rao_btn_upload.data)):
            temp_path = prod_path/str('RAO_'+str(path)+'.dcm')
            input_paths.append(temp_path)
            with open(temp_path, 'wb') as f: 
                f.write(rao_btn_upload.value[list(rao_btn_upload.value.keys())[0]]['content'])
    if swim_btn_upload.data!=[]:
        input_views.append("SWIMMERS")
        for path in range(len(swim_btn_upload.data)):
            temp_path = prod_path/str('SWIMMERS_'+str(path)+'.dcm')
            input_paths.append(temp_path)
            with open(temp_path, 'wb') as f: 
                f.write(swim_btn_upload.value[list(swim_btn_upload.value.keys())[0]]['content'])
    if xtab_lat_btn_upload.data!=[]:
        input_views.append("XTABLE_LATERAL")
        for path in range(len(xtab_lat_btn_upload.data)):
            temp_path = prod_path/str('XTABLE_LATERAL_'+str(path)+'.dcm')
            input_paths.append(temp_path)
            with open(temp_path, 'wb') as f: 
                f.write(xtab_lat_btn_upload.value[list(xtab_lat_btn_upload.value.keys())[0]]['content'])
    if lpo_btn_upload.data!=[]:
        input_views.append("LPO")
        for path in range(len(lpo_btn_upload.data)):
            temp_path = prod_path/str('LPO_'+str(path)+'.dcm')
            input_paths.append(temp_path)
            with open(temp_path, 'wb') as f: 
                f.write(lpo_btn_upload.value[list(lpo_btn_upload.value.keys())[0]]['content'])
        
    nomiss_repgen_trainval_sample_path = prep/'trainval_sample_repgen_nomiss.csv'
    trainval_sample = pd.read_csv(nomiss_repgen_trainval_sample_path)
    df = pd.DataFrame(columns=trainval_sample.columns)
    for i in trainval_sample.columns[14:]: df.loc[0, i]=trainval_sample.loc[0, i]
    df.drop(['split'], axis=1, inplace=True)

    def sums(length, total_sum):
        """Yield every ``length``-tuple of non-negative ints adding up to ``total_sum``.

        Tuples are produced in lexicographic order, e.g. ``sums(2, 2)`` yields
        ``(0, 2)``, ``(1, 1)``, ``(2, 0)``.
        """
        if length != 1:
            for head in range(total_sum + 1):
                for tail in sums(length - 1, total_sum - head):
                    yield (head, *tail)
            return
        yield (total_sum,)
    #Function to get difference between two lists
    def mse(input, target):
        """Mean squared error between two number sequences.

        Items are paired with ``zip``, so surplus items in the longer sequence are
        ignored; an empty pairing raises ZeroDivisionError.
        """
        squared = [(inp - targ) ** 2 for inp, targ in zip(input, target)]
        return sum(squared) / len(squared)
    #Function to count number of images in df
    def count_img(lstlsts):
        """Total number of items across a list of lists (e.g. image paths grouped by view)."""
        total = 0
        for group in lstlsts:
            total += len(group)
        return total
    #Function to get best combination for #time each image copied
    def get_int_copy(sum, time_copy, error, new_range):
        """Choose an integer number of copies per image, closest to float targets.

        Enumerates every tuple of ``len(time_copy)`` non-negative ints that add up
        to ``sum`` (via the sibling ``sums``) and keeps the one with the smallest
        ``mse`` against ``time_copy``. When the enclosing ``use_range`` flag is
        True, a candidate whose MSE ties the current best wins if its spread
        (max - min) is larger.

        Args:
            sum: total number of copies to hand out (shadows the builtin locally).
            time_copy: float target copy count per image.
            error: starting MSE bound; a candidate must beat it to be chosen.
            new_range: starting spread used for tie-breaking when use_range is set.

        Returns:
            The chosen tuple, or [] if time_copy is empty or no candidate beats error.
        """
        # sums(0, ...) never reaches its base case, so handle empty input up front.
        if len(time_copy) == 0:
            return []
        tie_tol = 1e-9  # MSEs this close are treated as equal for tie-breaking
        new_time_copy = []
        for com in sums(len(time_copy), sum):
            com_error = mse(time_copy, com)
            if use_range:
                # Use the candidate's own spread: time_copy's spread never changes
                # inside this loop, so comparing it made the search stop improving
                # after the first candidate under `error`.
                spread = max(com) - min(com)
                tied = abs(com_error - error) <= tie_tol
                better = (com_error < error and not tied) or (tied and spread > new_range)
            else:
                better = com_error < error
            if better:
                new_time_copy = com
                error = com_error
                if use_range:
                    new_range = spread
        return new_time_copy
    #Function to fill new_df
    def filldf(new_time_copy, temp):
        """Write image paths into row 0 of the enclosing ``df``, repeated as requested.

        ``temp[k]`` is written ``new_time_copy[k]`` times. For the n-th copy of an
        image (n counted from 0), the scan starts at column position n and moves
        right until it finds a cell that is still missing, then fills it.
        Mutates ``df`` in place.
        """
        for img_idx, n_copies in enumerate(new_time_copy):
            for copy_no in range(n_copies):
                col = copy_no
                # Skip over cells that already hold a value.
                while pd.notna(df.iloc[0, col]):
                    col += 1
                df.iloc[0, col] = temp[img_idx]
            
    #For combinations with one view
    if len(input_views)<2:
        time_copy = [] #Getting #time each row is multiplied combination
        #If only one row present
        if len(input_paths)<2: time_copy=[len(views)]
        else:
            #Minimize std of new combination
            std = 7 #greater than max possible since std([1, 13])==6, and any longer list has smaller std
            for com in sums(len(input_paths), len(views)):
                if std>np.std(com):
                    time_copy=com
                    std=np.std(com)
        #Sorting time_copy for indexing
        time_copy.sort()
        for time in range(1, len(time_copy)): time_copy[time] += time_copy[time-1]
        time_copy.insert(0,0)
        for time in range(1, len(time_copy)): df.loc[0, time_copy[time-1]:time_copy[time]] = input_paths[time-1]
    else:  
        #Checking if there are frontal/lateral views and main/other views
        havelat = False
        havefront = False
        havemain = False
        haveother = False
        use_range=False
        main_views = ['AP','PA','LATERAL','LL'] 
        for view in input_views: 
            temp = []
            for i in input_paths: 
                check = str(i).split('/')[1].split('_')[:-1]
                new_check = str()
                if len(check)>1: new_check = "_".join(check)
                else: new_check=check[0]
                if new_check == view: temp.append(i)
            df.loc[0, view] = temp[0] #Adding one example for each present column
            if view in views[:7]:  havefront = True
            else: havelat = True
            if view in main_views: havemain = True
            else: haveother = True  
        havemainother = havemain and haveother
        havefrontlat = havefront and havelat
        #if only frontal/lateral views
        if not havefrontlat:
            #Get estimate for #times each image added to df_sample, but float estimates so next section fixes it
            time_copy = []
            #if only main/other views
            if not havemainother:
                for view in input_views:
                    temp = []
                    for i in input_paths: 
                        check = str(i).split('/')[1].split('_')[:-1]
                        new_check = str()
                        if len(check)>1: new_check = "_".join(check)
                        else: new_check=check[0]
                        if new_check == view: temp.append(i)
                    for image in range(len(temp)):
                        time_copy.append(len(views)/len(input_views)/len(temp))
                        if not image: time_copy[-1]-=1
            #if both main and other views
            else:
                use_range=True
                for view in input_views:
                    temp = []
                    for i in input_paths: 
                        check = str(i).split('/')[1].split('_')[:-1]
                        new_check = str()
                        if len(check)>1: new_check = "_".join(check)
                        else: new_check=check[0]
                        if new_check == view: temp.append(i)
                    for image in range(len(temp)):
                        if view in main_views: 
                            if image: time_copy.append(1) #For images not in new_df
                            else: time_copy.append(0) #For images already in new_df
                        else: time_copy.append(len(views)/len(input_views)/len(temp))                                 
            #Getting integer list of #times each image added to new_df, minimizing Euclidean distance
            error=1000 #unsure of max MSE, so guessing
            new_range=0
            new_time_copy = get_int_copy(len(views)-len(input_views), time_copy, error, new_range)
            filldf(new_time_copy, input_paths)
        else: 
            #same as above except split between frontal/lateral categories instead of all len(views) columns
            frontal = []
            lateral = []
            for view in input_views:
                temp = []
                for i in input_paths: 
                    check = str(i).split('/')[1].split('_')[:-1]
                    new_check = str()
                    if len(check)>1: new_check = "_".join(check)
                    else: new_check=check[0]
                    if new_check == view: temp.append(i)
                if len(temp)>0:
                    if view in views[:7]: frontal.append(temp)
                    else: lateral.append(temp)
            #Check if more images than #columns (len(views)/2) in frontal category
            front_track = count_img(frontal)
            if front_track>len(views)/2:
                drop=1
                lateral.append(frontal[-1][-drop:])
                frontal[-1] = frontal[-1][:-drop] 
                front_track = count_img(frontal)
                while front_track>len(views)/2:
                    lateral[-1] = lateral[-1].append(frontal[-1][-drop:]) #pd.DataFrame.append
                    frontal[-1] = frontal[-1][:-drop] 
                    front_track = count_img(frontal)
                for view in views[7:]:
                    if pd.notna(df.loc[0, view]): continue
                    else: df.loc[0, view] = lateral[-1][0]
                    break
            frontal_time_copy = []
            lateral_time_copy = []
            for input_paths_list in frontal:       
                for i in range(len(input_paths_list)):
                    temp=0
                    frontal_time_copy.append(len(views)/2/len(frontal)/len(input_paths_list))
                    try: temp=frontal_time_copy[-1]
                    except: frontal_time_copy.append(1) #For combinations with len=len(views)/2
                    if not i: frontal_time_copy[-1]-=1
            for input_paths_list in lateral:
                check = str(input_paths_list[0]).split('/')[1].split('_')[:-1]
                new_check = str()
                if len(check)>1: new_check = "_".join(check)
                else: new_check=check[0]
                if new_check in views[:7]: #If frontal in lateral b/c len(frontal) was > len(views)/2
                    for image in range(len(input_paths_list)): 
                        if image: lateral_time_copy.append(1) #For images not in new_df
                        else: lateral_time_copy.append(0) #For images already in new_df
                else:
                    for i in range(len(input_paths_list)):
                        lateral_time_copy.append(len(views)/2/len(lateral)/len(input_paths_list))
                        if not i: lateral_time_copy[-1]-=1
            #Getting integer list of #times each image added to df_sample, account for len(frontal/lateral)=1
            error=1000 #unsure of max MSE, so guessing
            if havemainother: use_range=True
            new_range=0
            notin=[1, 7]
            if len(frontal_time_copy) in notin: new_frontal_time_copy = [int(item) for item in frontal_time_copy]
            else: new_frontal_time_copy = get_int_copy(int(len(views)/2-len(frontal)), frontal_time_copy, error, new_range)
            error = 1000
            if len(lateral_time_copy) in notin: new_lateral_time_copy = [int(item) for item in lateral_time_copy]
            else: new_lateral_time_copy = get_int_copy(int(len(views)/2-len(lateral)), lateral_time_copy, error, new_range)
            #Combine list of dfs into single df
            frontal = list(itertools.chain.from_iterable(frontal))
            lateral = list(itertools.chain.from_iterable(lateral))
            filldf(new_frontal_time_copy, frontal)
            filldf(new_lateral_time_copy, lateral)
    
    single_repgen_trainval_sample_path = prep/'trainval_sample_repgen_single.csv'
    vocab_path = Path('modules/repgen/vocab.pkl')
    trainval_sample_single = pd.read_csv(single_repgen_trainval_sample_path)
    trainval_sample_single['images']=prod_path/"AP_0.dcm"
    train_sample_single = trainval_sample_single[trainval_sample_single['split']==False]
    val_sample_single = trainval_sample_single[trainval_sample_single['split']==True]
    train_sample_single.reset_index(drop=True, inplace=True)
    val_sample_single.reset_index(drop=True, inplace=True)
    with open(vocab_path, 'rb') as f: vocab = pickle.load(f)    
        
    isval=False
    viewtype='images' 
    ispred=False
    train_sample_dataset = RepGenDataset(train_sample_single,isval, viewtype, ispred, classes) 
    isval=True
    val_sample_dataset = RepGenDataset(val_sample_single,isval, viewtype, ispred, classes) 
    bs=16
    trainval_sample_dls = DataLoaders.from_dsets(train_sample_dataset, val_sample_dataset, bs=bs, device=cpu, create_batch=create_batch, num_workers=workers, shuffle=True)
    trainval_sample_dls.valid = trainval_sample_dls.valid.new(shuffle=False)
    
    # Model settings (for visual extractor)
    visual_extractor='resnet50' #'resnet101'
    pretrained=True
    # Model settings (for Transformer)  
    num_layers=3 #number of layers of Transformer
    d_model=512 #dimension of Transformer
    d_ff=512 #dimension of FFN
    num_heads=8 #number of heads in Transformer
    dropout=0.1 #dropout rate of Transformer
    use_bn = 0 #whether to use batch normalization
    drop_prob_lm = 0.5
    max_seq_len = 100
    att_feat_size = 2048 #dimension of the patch features (d_vf in main.py)
    ## Not used in original/current, but included in main.py
    #parser.add_argument('--logit_layers', type=int, default=1, help='the number of the logit layer.') 
    # for Relational Memory    
    rm_num_slots=3
    rm_num_heads=8
    rm_d_model=512
    # for Sampling
    beam_size = 3 #beam size when beam searching
    group_size = 1
    sample_n = 1 #sample number per image
    sample_method = "beam_search" #sample methods to sample a report
    temperature = 1.0 #temperature when sampling
    output_logsoftmax = 1 #whether to output the probabilities
    decoding_constraint = 0
    block_trigrams = 1
    # More params (not in main.py, but used in original/current)
    diversity_lambda = 0.5       
    input_encoding_size = 512
    suppress_UNK = 0 
    length_penalty = ''
    mode='forward'
    model = R2GenModel(visual_extractor,
                    pretrained,
                    num_layers,
                    d_model,
                    d_ff,
                    num_heads,
                    dropout,
                    rm_num_slots,
                    rm_num_heads,
                    rm_d_model,
                    vocab,
                    input_encoding_size,
                    drop_prob_lm,
                    max_seq_len,
                    att_feat_size,
                    use_bn,
                    beam_size,
                    group_size,
                    sample_n,
                    sample_method,
                    temperature,
                    output_logsoftmax,
                    decoding_constraint,
                    block_trigrams,
                    diversity_lambda,
                    suppress_UNK,
                    length_penalty,
                    mode)
    model = model.to(cpu)
    
    criterion = compute_loss
    metrics = [bleu4] # bleu1, bleu2, bleu3, meteor, rouge, partial(precision, thresh=0.5), partial(recall, thresh=0.5), partial(f1, thresh=0.5)    
    wd=5e-5
    
    learn = Learner(trainval_sample_dls, model, loss_func=criterion, wd=wd, 
                    splitter=rep_gen, metrics=metrics, cbs=SelectPred)
    learn.load("repgen.0.0", device=cpu)
    learn.model.mode='sample'
    def passfunc(arg):
        """Identity; swapped in for learn.dls.decode(_batch) so predictions stay undecoded."""
        return arg
    def decode(pred): #Convert idx_report to report
        """Turn predicted token-index sequences into one space-separated string.

        Every index in every sequence of ``pred`` is looked up in the enclosing
        ``vocab``; tokens beginning with "xx" (special tokens) are dropped.
        """
        # Build the special-token set once instead of rescanning vocab for every word.
        special_tokens = {tok for tok in vocab if tok[:2] == "xx"}
        words = []
        for report in pred:
            for idx in report:  # each entry is an index into vocab
                txtword = vocab[idx]
                if txtword not in special_tokens:
                    words.append(txtword)
        return " ".join(words)
    learn.dls.decode = passfunc
    learn.dls.decode_batch = passfunc

    reports = []
    ispred=True
    for i in input_views:
        pred_dataset = RepGenDataset(df, isval, i, ispred, classes)
        gts, rep, _ = learn.predict(pred_dataset[0])
        if decode(rep)[-2:] != " ." or decode(rep)[-2:] != ". ": reports.append(decode(rep) + ' . ')
        else: reports.append(decode(rep))
    if len(input_views)<2:
        report = reports[0]
    else:
        clean_first_view_report = [x.strip() for x in reports[0].split(".")]
        more_views_report_append=[]
        for i in range(1, len(input_views)):
            more_views_report_append.append([x for x in reports[i].split(".") if x.strip() not in clean_first_view_report])
        reports1 = []
        reports1.extend(reports[0].split("."))
        for i in more_views_report_append: reports1.extend(i)
        report = ".".join(reports1)
    df.loc[0, 'reports'] = report
    
    del trainval_sample_single 
    del train_sample_single
    del val_sample_single
    del train_sample_dataset
    del val_sample_dataset
    del trainval_sample_dls
    del model
    del learn
    del pred_dataset
    gc.collect()

    month = str(datetime.now().month)
    if int(month) < 10: month = "0"+str(month)
    day = str(datetime.now().day)
    if int(day) < 10: day = "0"+str(day)
    df.loc[0, 'StudyElapsed'] = str(datetime.now().year)+'-'+month+'-'+day
    make_date(df, 'StudyElapsed')
    df['StudyElapsed'].values.astype(np.int64) // 10 ** 9
    df.loc[0, 'Minutes'] = datetime.now().minute
    df.loc[0, 'Hour'] = datetime.now().hour
    df.loc[0, 'Seconds'] = datetime.now().second
    df.loc[0, 'StudyWeek'] = datetime.now().isocalendar()[1]
    df.loc[0, 'StudyDay'] = datetime.now().day
    df.loc[0, 'StudyDayofweek'] = datetime.now().isocalendar()[2]
    df.loc[0, 'StudyDayofyear'] = datetime.now().timetuple().tm_yday
    df.loc[0, 'StudyElapsed'] = df['StudyElapsed'].values.astype(np.int64) // 10 ** 9
    #"""

    train_sample = trainval_sample[trainval_sample['split']==False]
    val_sample = trainval_sample[trainval_sample['split']==True]
    train_sample.reset_index(drop=True, inplace=True)
    val_sample.reset_index(drop=True, inplace=True)

    size=224
    seq_len=72
    bs=1
    temp = pd.concat([train_sample.iloc[:bs], val_sample.iloc[:bs]], ignore_index=True)
    temp.loc[:, :14]=prod_path/"AP_0.dcm"
    train_dls = []
    val_dls = []
    test_bs=1
    test_workers = 0
    pred_dls = []
    
    def vis_dls(bs, size, path, view, istest=None):
        """Build image DataLoaders for one view column of the DataFrame ``path``.

        Inputs come from column ``view`` (opened with PILDicom2), targets from the
        one-hot ``classes`` columns. ``istest`` is passed straight through as the
        DataBlock splitter; when it is None the ``test_workers`` count is used,
        otherwise ``workers``.
        """
        batch_tfms = [
            IntToFloatTensor(div=2**16 - 1),  # presumably 16-bit pixel data — TODO confirm
            Normalize.from_stats(*imagenet_stats),
            *aug_transforms(size=size),  # min_scale=0.75: RandomResizedCropGPU not working
        ]
        dblock = DataBlock(
            blocks=(ImageBlock(cls=PILDicom2),
                    MultiCategoryBlock(encoded=True, vocab=classes)),
            get_x=ColReader(view),
            get_y=ColReader(classes),
            splitter=istest,
            item_tfms=Resize(460),
            batch_tfms=batch_tfms)
        n_workers = test_workers if istest is None else workers
        return dblock.dataloaders(path, bs=bs, num_workers=n_workers)
    for view in views: 
        train_dls.append(vis_dls(bs, size, temp, view, ColSplitter('split'))[0].to("cpu"))
        val_dls.append(vis_dls(bs, size, temp, view, ColSplitter('split'))[1].to("cpu"))
        pred_dls.append(vis_dls(test_bs, size, df, view)[0].to("cpu"))
        
    with open(Path('./modules/txtcls/vocab.pkl'), 'rb') as f: vocab = pickle.load(f) 
    nomiss_repgen_test_path = prep/'test_repgen_nomiss.csv'
    test = pd.read_csv(nomiss_repgen_test_path)
    def txt_dls(bs, path, seq_len, istest=None):
        """Build text DataLoaders from the 'reports' column of the DataFrame ``path``.

        Reports go through a TextBlock built with BaseTokenizer and the enclosing
        ``vocab``; targets are the one-hot ``classes`` columns. ``istest`` is the
        DataBlock splitter; None selects ``test_workers``, otherwise ``workers``.
        """
        text_block = TextBlock(tok_tfm=BaseTokenizer).from_df(text_cols='reports', vocab=vocab)
        label_block = MultiCategoryBlock(encoded=True, vocab=classes)
        dblock = DataBlock(
            blocks=(text_block, label_block),
            get_x=ColReader(cols='text'),
            get_y=ColReader(classes),
            splitter=istest)
        n_workers = test_workers if istest is None else workers
        return dblock.dataloaders(path, bs=bs, seq_len=seq_len, num_workers=n_workers)
    # Build the text train/valid DataLoaders once (the texts are tokenized on every build)
    # and take the training and validation loaders from the same object.
    txt_train_val = txt_dls(bs, temp, seq_len, ColSplitter('split'))
    train_dls.append(txt_train_val[0].to("cpu"))
    val_dls.append(txt_train_val[1].to("cpu"))
    # NOTE(review): one row of `test` is appended to the inference frame; the purpose is not visible
    # here (presumably to avoid a single-row text dataset) — confirm. `pd.concat` is used because
    # `DataFrame.append` was removed in pandas 2.0.
    pred_dls.append(txt_dls(test_bs, pd.concat([df, test.iloc[:1]], ignore_index=True), seq_len)[0].to("cpu"))

    # Split columns before the int32 cast below, as the trained tabular model expects.
    cont_nn,cat_nn = cont_cat_split(temp, max_card=365, dep_var=classes)
    # These columns are cast to int32 in both the training frame and the inference frame.
    int32_cols = ['Minutes', 'Hour', 'Seconds', 'StudyWeek', 'StudyDay',
                  'StudyDayofweek', 'StudyDayofyear', 'StudyElapsed']
    for frame in [temp, df]:
        frame[int32_cols] = frame[int32_cols].astype('int32')
    def tab_dls(bs, path, is_test):
        """Build tabular DataLoaders from DataFrame `path` using the continuous columns `cont_nn`.

        With `is_test` truthy there is no train/valid split and `test_workers` is used. Otherwise
        rows whose `split` column equals False form the training set, all other rows the
        validation set, and `workers` is used.
        """
        procs_nn = [Categorify, FillMissing, Normalize]
        if not is_test:
            in_train = (path['split'] == False).to_numpy()
            splits = (list(np.flatnonzero(in_train)), list(np.flatnonzero(~in_train)))
            n_workers = workers
        else:
            splits, n_workers = None, test_workers
        # cat_names is None; the original author noted cat_nn[16:23] as the candidate for this slot.
        tab = TabularPandas(path, procs_nn, None, cont_nn, splits=splits,
                            y_block=MultiCategoryBlock(encoded=True, vocab=classes), y_names=classes)
        return tab.dataloaders(bs, num_workers=n_workers)
    train_dls.append(tab_dls(bs, temp, False)[0].to("cpu"))
    val_dls.append(tab_dls(bs, temp, False)[1].to("cpu"))
    pred_dls.append(tab_dls(test_bs, df, True)[0].to("cpu"))
    
    train_mixed_dl = SumDL(device, *train_dls)
    valid_mixed_dl = SumDL(device, *val_dls)
    mixed_dls = DataLoaders(train_mixed_dl, valid_mixed_dl)
    pred_mixed_dls = SumDL(cpu, *pred_dls)
    
    a = pred_mixed_dls.one_batch()
    
    del train_dls
    del val_dls
    del pred_dls
    gc.collect()
    
    drop_mult=0.5
    model=xresnet18
    txtcls_learn = text_classifier_learner(txt_dls(bs, temp, seq_len, ColSplitter('split')), AWD_LSTM, drop_mult=drop_mult)
    unfreeze_name='lang.0.1'
    txtcls_learn = txtcls_learn.load_encoder(unfreeze_name)

    # Create our Multi-Modal model
    sum_model = SumModel(cnn_learner(vis_dls(bs, size, temp, input_views[0], ColSplitter('split')), model).model,
                         txtcls_learn.model, 
                         tabular_learner(tab_dls(bs, temp, False), layers=[500, 250]).model, 
                         len(classes))

    # Set loss_scale for each loss
    weights = [14/17/14, 14/17/14, 14/17/14, 14/17/14, 14/17/14, 14/17/14, 14/17/14, 14/17/14, 
               14/17/14, 14/17/14, 14/17/14, 14/17/14, 14/17/14, 14/17/14, 2/17, 0.25/17, 0.75/17]
    thresh=0.5
    loss_scale = 1.0
    beta=1

    loss = SumGradientBlending(1.0, *weights)
    ap_w = partial(ap_weighted, weights=weights)

    metrics = [ap_w]
    
    sum_learn = Learner(mixed_dls.to("cpu"), sum_model.to("cpu"), loss, splitter=sum_splitter, metrics=metrics)
    name = 'sum.0.0'
    sum_learn.load(name, device=cpu)
    sum_learn.dls = sum_learn.dls.to(cpu)
    sum_learn.model = sum_learn.model.to(cpu)
    
    del model
    del sum_model
    del txtcls_learn
    del mixed_dls
    gc.collect()
    
    preds,targs = sum_learn.get_preds(dl=[a]) #18 seconds
    
    def decode_prob(preds):
        "Stack the per-output logits in `preds`, blend them with `weights`, divide by `len(weights)` and apply a sigmoid."
        stacked = torch.stack(preds)
        blended = 0
        for k, w in enumerate(weights):
            blended = blended + stacked[k] * w
        return (blended / len(weights)).sigmoid()
    def decode_rep(preds, thresh=0.5):
        """Turn raw model outputs `preds` into hard 0/1 predictions.

        Probabilities come from `decode_prob` (weighted blend of the outputs + sigmoid). Entries
        with probability >= `thresh` become 1 and every other entry 0, in the probabilities' dtype.
        """
        probs = decode_prob(preds)
        # A single comparison mask instead of overwriting the tensor in two passes; NaN maps to 0.
        return (probs >= thresh).to(probs.dtype)
    thresh = 0.5
    confs = decode_prob(preds)
    class_preds = decode_rep(preds, thresh)

    # Pair every class of the first sample with its probability, split by the hard prediction,
    # and order each group from most to least probable. Names and confidences are sorted
    # together so each reported percentage stays attached to its own condition.
    positives = sorted(((confs[0][i].item(), classes[i]) for i in range(len(class_preds[0]))
                        if class_preds[0][i]==1), reverse=True)
    negatives = sorted(((confs[0][i].item(), classes[i]) for i in range(len(class_preds[0]))
                        if class_preds[0][i]==0), reverse=True)
    confs_select = [conf for conf, _ in positives]
    class_names = [name for _, name in positives]
    confs_select_neg = [conf for conf, _ in negatives]
    class_names_neg = [name for _, name in negatives]

    def _conditions_html(names, conf_values):
        "Render 'A(..% confident),<br/>B(..),<br/>and C(..).<br/>'; a single name gets no 'and '."
        items = [f"{name}({conf*100:.2f}% confident)" for name, conf in zip(names, conf_values)]
        if len(items) > 1: items[-1] = "and " + items[-1]
        return ",<br/>".join(items) + ".<br/>"

    report = ["Given a confidence threshold of "+str(thresh)+",<br/> which is the minimum confidence the model must have in order to give a positive diagnosis for a disease,<br/> and is the ideal confidence for maximizing accuracy as determined by the validation set,<br/>"]
    if not class_names:
        report.append("this patient's condition cannot be determined. Please contact them to collect another set of x-rays.<br/>")
    else:
        report.append("this patient most likely needs to get checked out for the following conditions:<br/>")
        report.append(_conditions_html(class_names, confs_select))
        # When every class is positive there is no negative section (indexing an empty list would raise).
        if class_names_neg:
            # NOTE(review): these percentages are the model's probability of the disease, not its
            # confidence that the disease is absent — confirm this wording is intended.
            report.append("<br/>This patient most likely doesn't need to get checked out for the following conditions:<br/>")
            report.append(_conditions_html(class_names_neg, confs_select_neg))
    report.append(' </p>')
    summary.value += "".join(report)
    
    def show_gradcam(learn, x, thresh):
        """Save a Grad-CAM overlay for every input view and every class predicted positive.

        `x` is a mixed batch whose last element is excluded from the forward pass (`x[:-1]`).
        Classes are those that `decode_rep` marks for the first sample at threshold `thresh`.
        For each view in `input_views`, the forward activations and output gradients of that
        view's CNN body (`learn.model.models[views.index(view)][0]`) give the class-activation
        map. The map is drawn over the uploaded DICOM and saved to `prod_path/"<view>,<class>.png"`.
        """
        class Hook():
            "Context manager keeping a detached copy of module `m`'s forward output."
            def __init__(self, m): self.hook = m.register_forward_hook(self.hook_func)
            def hook_func(self, m, i, o): self.stored = o.detach().clone()
            def __enter__(self, *args): return self
            def __exit__(self, *args): self.hook.remove()
        class HookBwd():
            "Context manager keeping a detached copy of the gradient w.r.t. module `m`'s output."
            def __init__(self, m):
                # NOTE(review): `register_backward_hook` is deprecated in recent PyTorch in favour of
                # `register_full_backward_hook`; left unchanged so the captured gradient is identical.
                self.hook = m.register_backward_hook(self.hook_func)
            def hook_func(self, m, gi, go): self.stored = go[0].detach().clone()
            def __enter__(self, *args): return self
            def __exit__(self, *args): self.hook.remove()

        with torch.no_grad():
            output = learn.model.eval()(*x[:-1])
            output = decode_rep(output, thresh)
            class_idxes = [i for i, val in enumerate(output[0]) if val]
        if not class_idxes: return

        def cmap(view_idx, class_idx):
            "Return (activations, gradients) of CNN body `view_idx` for logit `class_idx` of the first sample."
            body = learn.model.models[view_idx][0]
            with HookBwd(body) as hookg:
                with Hook(body) as hook:
                    out = learn.model.eval()(*x[:-1])
                    act = hook.stored
                out[view_idx][0][class_idx].backward()
                grad = hookg.stored
            return act, grad

        for img in input_views:
            view_idx = views.index(img)
            # The uploaded image depends only on the view, so decode it once per view.
            x_dec = TensorImage(PILDicom.create(df.loc[0, img]))
            # x_dec is laid out (height, width, ...), like the PIL image it wraps, and imshow's extent
            # is (left, right, bottom, top); the overlay must span width x height to line up.
            height, width = x_dec.shape[0], x_dec.shape[1]
            for idx in class_idxes:
                act, grad = cmap(view_idx, idx)
                # Gradients averaged over the two spatial dims (assuming (C, H, W)) weight each channel.
                w = grad[0].mean(dim=[1,2], keepdim=True)
                cam_map = (w * act[0]).sum(0)
                fig, ax = plt.subplots()
                x_dec.show(ctx=ax, cmap='gray')
                ax.imshow(cam_map.detach().cpu(), alpha=0.6, extent=(0, width, height, 0),
                          interpolation='bilinear', cmap='magma')
                fig.savefig(prod_path/Path(img+","+classes[idx]+'.png'), bbox_inches='tight')
                # Only the PNG is needed; close the figure so repeated diagnoses don't pile up open figures.
                plt.close(fig)
    show_gradcam(sum_learn, a, thresh)
       
    def _eval_dropouts(mod):
        module_name =  mod.__class__.__name__
        if 'Dropout' in module_name or 'BatchNorm' in module_name: mod.training = False
        for module in mod.children(): _eval_dropouts(module)
    def intrinsic_attention(learn, batch, class_id=None):
        """Calculate the intrinsic attention of the input w.r.t to an output `class_id`, or the classification given by the model if `None`.

        Works on the text sub-model `learn.model.models[len(views)]` and the text part
        `batch[len(views)]` of the mixed batch. A token's attention is the absolute gradient of
        the chosen softmax score w.r.t. its embedding, summed over the embedding dimension and
        scaled so the maximum is 1.

        Returns `(text, attn)`:
        - `text` is the decoded report after clean-up: `xx...` special tokens removed where
          preceded by whitespace, view names re-capitalised, `tool.correct` applied, and
          terminated by '.'.
        - `attn` is the per-token attention tensor.
        """
        # Text model in train mode (presumably so its RNN can be back-propagated through), while
        # dropout/batchnorm layers across the whole model are switched back to eval behaviour.
        learn.model.models[len(views)].train()
        _eval_dropouts(learn.model)
        learn.model.models[len(views)].zero_grad()
        learn.model.models[len(views)].reset()
        batch = batch[len(views)]
        # Detached embeddings become a leaf tensor whose gradient we can read afterwards.
        emb = learn.model.models[len(views)][0].module.encoder(batch).detach().requires_grad_()
        emb.retain_grad()
        lstm = learn.model.models[len(views)][0].module(emb, True)
        learn.model.models[len(views)].eval()
        # All-False mask: no position of the input is treated as padding.
        cl = learn.model.models[len(views)][1]((lstm, torch.zeros_like(batch).bool(),))[0].softmax(dim=-1)
        if class_id is None: class_id = cl.argmax()
        cl[0][class_id].backward()
        attn = emb.grad.squeeze().abs().sum(dim=-1)
        attn /= attn.max()
        tok, _ = learn.dls.dls[len(views)].decode_batch((*tuplify(batch), *tuplify(cl)))[0]

        # Restore the original capitalisation of view names (the decoded text is presumably
        # lowercased). Whole-word matching keeps words that merely contain a view name intact
        # (e.g. 'gap', 'patchy'), and longer names go first so 'ap axial' is not pre-empted by 'ap'.
        view_patterns = [(re.compile(r'(?<!\w)' + re.escape(view.lower()) + r'(?!\w)'), view)
                         for view in sorted(views, key=len, reverse=True)]
        def replace_all(text):
            for pattern, view in view_patterns:
                text = pattern.sub(lambda _m, v=view: v, text)
            return text

        # NOTE(review): the clean-up changes the word sequence, so `text.split()` is not guaranteed to
        # line up one-to-one with `attn`; callers zip the two together.
        # NOTE(review): `(\s)xx\w+` does not remove a special token at the very start of the string
        # (e.g. a leading xxbos) — confirm whether that is intended.
        sentences = [re.sub(r'(\s)xx\w+', "", s, flags=re.IGNORECASE) for s in tok.split(" . ")]
        # [:-1] drops the final character of the joined text (presumably the trailing '.' of the last sentence).
        joined = ". ".join(replace_all(s) for s in sentences)[:-1]
        # `tool` is defined elsewhere in the notebook (presumably a grammar checker) — confirm.
        text = tool.correct(joined)
        text = re.sub(r'(\s)xx\w+', "", text, flags=re.IGNORECASE)
        # endswith() also handles an empty string, where indexing text[-1] would raise.
        if not text.endswith("."): text = text + "."
        return text, attn
    def value2rgba(x, cmap=cm.Purples, alpha_mult=1.0):
        "Map `x` (expected in [0, 1]) through `cmap` to an RGBA tuple: RGB as 0-255 ints, alpha scaled by `alpha_mult`."
        *rgb_unit, alpha = cmap(x)
        rgb = (np.array(rgb_unit) * 255).astype(int)
        return (*rgb.tolist(), alpha * alpha_mult)
    def piece_attn_html(pieces, attns, sep=' ', **kwargs):
        """Return one monospace HTML <span> in which every (HTML-escaped) piece is wrapped in its own
        <span>, shaded by its attention value via `value2rgba` and titled with that value."""
        def _span(piece, weight):
            color = str(value2rgba(weight, alpha_mult=0.5, **kwargs))
            return f'<span title="{weight:.3f}" style="background-color: rgba{color};">{html.escape(piece)}</span>'
        body = sep.join(_span(piece, weight) for piece, weight in zip(pieces, attns))
        return '<span style="font-family: monospace;">' + body + '</span>'
    def show_piece_attn(*args, **kwargs):
        "Display the attention-shaded HTML built by `piece_attn_html(*args, **kwargs)` in the notebook."
        from IPython.display import display, HTML
        rendered = HTML(piece_attn_html(*args, **kwargs))
        display(rendered)
    def html_intrinsic_attention(learn, x:tuple, class_id:int=None, **kwargs)->str:
        "Return HTML for the report in batch `x`, each whitespace-separated word shaded by its intrinsic attention for `class_id`."
        text, attn = intrinsic_attention(learn, x, class_id)
        words = text.split()
        scores = to_np(attn)
        return piece_attn_html(words, scores, **kwargs)
    def show_intrinsic_attention(learn, x:tuple, class_id:int=None, **kwargs)->None:
        "Display the report in batch `x` inline, each whitespace-separated word shaded by its intrinsic attention for `class_id`."
        text, attn = intrinsic_attention(learn, x, class_id)
        words = text.split()
        scores = to_np(attn)
        show_piece_attn(words, scores, **kwargs)
    for i in range(len(class_preds[0])):
        if class_preds[0][i].item()>0:
            with open(prod_path/Path(classes[i]+'.txt'), "wt") as txt:
                txt.write(html_intrinsic_attention(sum_learn, a, i))
                txt.close()
    #"""
    def display_both(learn, x, thresh):
        """Render, in the `out_pl` output widget, one table row per class predicted positive for batch `x`.

        Classes are those that `decode_rep` marks for the first sample at threshold `thresh`.
        Each row holds:
        - an <img> for every input view, pointing at `prod_path/"<view>,<class>.png"`;
        - the report-attention HTML read from the first line of `prod_path/"<class>.txt"`.
        """
        with torch.no_grad():
            output = learn.model.eval()(*x[:-1])
            output = decode_rep(output, thresh)
            class_idxes = [i for i, val in enumerate(output[0]) if val]

        # One record per predicted class (no placeholder rows, so a prediction with zero or
        # one class does not leave rows of zeros in the table).
        rows = []
        for idx in class_idxes:
            condition = classes[idx]
            with open(prod_path/Path(condition+'.txt')) as txt:
                report_html = txt.readline()
            row = {'Conditions': condition, 'Report_Interpretation': report_html}
            for view in input_views:
                row[view] = str(prod_path)+"/"+str(view+","+condition+'.png')
            rows.append(row)
        data = pd.DataFrame(rows, columns=['Conditions', 'Report_Interpretation', *input_views])
        data = data.set_index('Conditions')

        def path_to_image_html(path):
            "Return an <img> tag for `path`, sized like the uploaded DICOM of the view named in the file name."
            # The view name is the part of the file name before the first ','; take the last path
            # component so a multi-level `prod_path` still works.
            view = path.rsplit("/", 1)[-1].split(",")[0]
            x_dec = TensorImage(PILDicom.create(df.loc[0, view]))
            # The decoded image is laid out (height, width, ...).
            height, width = int(x_dec.shape[0]), int(x_dec.shape[1])
            return '<img src="'+ path + '" width="'+ str(width) + '" height="'+ str(height) + '">'

        # Render the table as HTML; escape=False keeps the <img>/<span> markup live.
        out_pl.clear_output()
        with out_pl: display(HTML(data.to_html(escape=False, formatters={col: path_to_image_html for col in input_views})))
    display_both(sum_learn, a, thresh)
    #"""            
# Run the full diagnosis pipeline (`on_click_classify`) whenever the `diagnose` button is clicked.
diagnose.on_click(on_click_classify)

In [None]:
# Assemble the app top to bottom: the heading, then one (`*_direct`, `*_btn_upload`) widget pair
# per X-ray view, then the diagnose button, the HTML summary and the output pane that
# `display_both` renders into.
VBox([heading,
      *[widget
        for pair in ((ap_direct, ap_btn_upload),
                     (ap_axial_direct, ap_axial_btn_upload),
                     (ap_lld_direct, ap_lld_btn_upload),
                     (ap_rld_direct, ap_rld_btn_upload),
                     (pa_direct, pa_btn_upload),
                     (pa_lld_direct, pa_lld_btn_upload),
                     (pa_rld_direct, pa_rld_btn_upload),
                     (lat_direct, lat_btn_upload),
                     (ll_direct, ll_btn_upload),
                     (lao_direct, lao_btn_upload),
                     (rao_direct, rao_btn_upload),
                     (swim_direct, swim_btn_upload),
                     (xtab_lat_direct, xtab_lat_btn_upload),
                     (lpo_direct, lpo_btn_upload))
        for widget in pair],
      diagnose,
      summary,
      out_pl],
     layout=Layout(width='100%', display='flex', align_items='center'))