# CheXRay
* This website first takes your PA and/or lateral X-ray views.
    * If only one is submitted, the image is copied and sent to the model as the other view (e.g., if only a PA view is provided, it is copied as the lateral view and both images are sent to the model).
* It then generates a radiology report based on those image(s), along with visualizations showing, for each word, which regions of the image(s) were most important for generating that word.
* It then uses the image(s), the generated radiology report, and the current date and time to summarize its findings.
* If your condition can be determined, it generates a list of diseases you should and should not get checked out for. It also generates visualizations indicating which regions of the image(s) were the deciding factors.
    * If it can't, the website will direct you to a practicing radiologist (future update).
* Note: Although images are saved as files so that the website can work, they stay within your environment (the website runs entirely on your own computer, so we never receive or keep them).
* If you have any questions (e.g., about how this website works) or concerns, please contact: ajhinh@gmail.com

In [None]:
import os

import gdown
# Imported here as well: on a fresh top-to-bottom run this cell executes before
# the cell that imports IPython.display, so `clear_output` would be undefined.
from IPython.display import clear_output

# (Google Drive download URL, local destination) for every artifact the app loads.
_DOWNLOADS = [
    ("https://drive.google.com/u/0/uc?export=download&confirm=7rWd&id=1NkCdI2esy90GjZErHMmpokrbRcXI0GYd",
     'modules/Fastext_embedd_wordMap.pkl'),
    ("https://drive.google.com/u/0/uc?export=download&confirm=7rWd&id=1K7A9Ow89QwNrohiiCZMwrediAYllgtoJ",
     'modules/vocab.pkl'),
    ("https://drive.google.com/u/0/uc?export=download&confirm=7rWd&id=1axHx-EZSr18pwomrNzpsFAnjTEIZnW_z",
     'modules/images.zip'),
    ("https://drive.google.com/u/0/uc?export=download&confirm=7rWd&id=1Kazb6dY6hTU-NhA1Q4hWeGU2Y442VWbF",
     'models/all.0.1.pth'),
    ("https://drive.google.com/u/0/uc?export=download&confirm=7rWd&id=1Q9PPPeDpPaZiciQX2fGLfOETLov04WRa",
     'models/allimgcap.0.1.pth'),
]

_failed = []
for url, output in _DOWNLOADS:
    # Make sure the destination folder exists before writing into it.
    os.makedirs(os.path.dirname(output), exist_ok=True)
    # Some gdown versions return None instead of raising when a download fails;
    # record it so the failure is visible instead of surfacing later as a missing file.
    if gdown.download(url, output, quiet=True) is None:
        _failed.append(output)
clear_output()
if _failed:
    print("Download failed for: " + ", ".join(_failed))

In [4]:
from IPython.display import display, clear_output
import torch
device = torch.device("cpu")

from fastai.vision.all import *
from fastai.text.all import *
from fastai.tabular.all import *
imcap_path = Path('./modules/')

# Copy the project helper module into the working directory so it can be imported.
# Done with shutil rather than a `!cp` shell escape: IPython passes shell-command
# text through verbatim (only `{expr}` / `$name` are expanded), so a bare `str(...)`
# expression reaches the shell as literal text and is a syntax error there.
import shutil
shutil.copy(imcap_path/"custom.py", ".")
from custom import *
import pickle
from functools import partial
from datetime import datetime
from PIL import Image as im
import ipywidgets as widgets

Enabling: voila
- Writing config: /Users/andrewhinh/anaconda3/etc/jupyter
    - Validating...
      voila 0.2.4 [32mOK[0m


In [2]:
# Load the sample metadata used to build the captioning datasets below.
trainval_sample = pd.read_csv(imcap_path/'final_trainval_sample_sample_sample.csv', low_memory=False)

# fastai's tokenize_df returns a tuple; only element [0] (the tokenized DataFrame) is used below.
trainval_sample = tokenize_df(trainval_sample, 'path')
with open(imcap_path/'vocab.pkl', 'rb') as f:
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    vocab = pickle.load(f)
# Project helper from `custom`; presumably numericalizes the tokens in place using `vocab` — confirm.
add_text_to_num(trainval_sample[0], vocab) 
clear_output(wait=False)

# Every image is resized to original_size x original_size and converted to a tensor.
original_size=32
trans = transforms.Compose([transforms.Resize((original_size,original_size)), transforms.ToTensor()])
# Converts a tensor back to a PIL image (e.g. for display).
denorm = transforms.Compose([transforms.functional.to_pil_image])
# One dataset per view, reading images from the 'images' / 'images1' columns
# (NOTE(review): presumably PA and lateral respectively, per the variable names).
pa_dataset = TestImageCaptionDataset(trainval_sample[0],'images',trans) 
lat_dataset = TestImageCaptionDataset(trainval_sample[0],'images1',trans) 
# Batch collate function padding captions with the 'xxpad' token id; pad_first=False
# presumably appends the padding after the tokens — confirm in custom.pad_collate_ImgCap.
imgcap_collate_func = partial(pad_collate_ImgCap, pad_idx=vocab.index('xxpad'), pad_first=False, transpose=False)
def fa_convert(t):
    """Convert `t` like PyTorch's `default_convert`, but keep container types.

    Leaves matching `_collate_types` and anything that is not a `Sequence` are handed to
    `default_convert`; other sequences are rebuilt with their own type from the
    recursively converted elements.
    """
    if isinstance(t, _collate_types):
        return default_convert(t)
    if not isinstance(t, Sequence):
        return default_convert(t)
    converted = [fa_convert(item) for item in t]
    return type(t)(converted)
def create_batch(b):
    """Collate a list of samples into one batch with `imgcap_collate_func`."""
    # Samples are never pre-batched here, so the padding collate function is always used.
    return imgcap_collate_func(b)
bs = 1
# One DataLoaders per view, identical apart from the dataset.
pa_dls, lat_dls = (
    DataLoaders.from_dsets(ds, bs=bs, device=device, create_batch=create_batch, num_workers=0)
    for ds in (pa_dataset, lat_dataset)
)
# Pair the first (training) DataLoader of each view into one multi-view loader.
mixed_dls = MixedDL(pa_dls[0], lat_dls[0])

In [3]:
emb_dim = 300        # dimensionality of the pretrained word embeddings loaded below
attention_dim = 512  # encoder_dim is projected to attention_dim
decoder_dim = 512    # word_emb_dim is projected to decoder_dim
dropout = 0.5
encoder_dim = 512    # 512 for resnet34, 2048 for resnet101
vocab_size = len(vocab)
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open(imcap_path/'Fastext_embedd_wordMap.pkl','rb') as f:
    embedding = pickle.load(f)

# Layer initialisation (project Encoder/Decoder from `custom`).
# NOTE(review): the meaning of Encoder's first argument (14) is defined in custom.Encoder — confirm.
enc = Encoder(14, fine_tune=False)
dec = Decoder(attention_dim,
              emb_dim,
              decoder_dim,
              vocab_size,
              encoder_dim=encoder_dim,
              dropout=dropout,  # use the configured value instead of repeating the literal
              pretrained_embedding=embedding,
              teacher_forcing_ratio=1)
enc = enc.to(device)
dec = dec.to(device)
arch = Ensemble(enc, dec).to(device)

In [4]:
# Create our Multi-Modal model
# NOTE(review): the same `arch` instance is passed for both views, so the PA and
# lateral branches presumably share one set of weights — confirm this is intended.
multi_model = MultViewCap(arch, arch)
# Loss weights for the PA-only, lateral-only and combined PA+lateral ("pl") outputs; they sum to 1.0.
pa_w = 0.1
lat_w = 0.1
pl_w = 0.8

# Gradient-blending loss (project helper) combining the three outputs with the weights above.
gb_loss = myGradientBlending(pa_weight=pa_w, lat_weight=lat_w, pa_lat_weight=pl_w, 
                             loss_scale=1.0, use_cel=True)

# Weighted variants of each metric, combining the per-output scores with the same loss weights.
w_accuracy = partial(weighted_accuracy, w_pa=pa_w, w_lat=lat_w, w_pl=pl_w)
bleu1_w = partial(bleu1_weighted, w_pa=pa_w, w_lat=lat_w, w_pl=pl_w)
bleu2_w = partial(bleu2_weighted, w_pa=pa_w, w_lat=lat_w, w_pl=pl_w)
bleu3_w = partial(bleu3_weighted, w_pa=pa_w, w_lat=lat_w, w_pl=pl_w)
bleu4_w = partial(bleu4_weighted, w_pa=pa_w, w_lat=lat_w, w_pl=pl_w)
rouge_l_w = partial(rouge_l_weighted, w_pa=pa_w, w_lat=lat_w, w_pl=pl_w)

# Metrics: for each family (top-K accuracy, BLEU-1..4, ROUGE-L), one per output (pa, lat, pl) plus the weighted one.
metrics = [topK_accuracy_pa,
           topK_accuracy_lat,
           topK_accuracy_pl, 
           w_accuracy,
           bleu1_pa,
           bleu1_lat,
           bleu1_pl,
           bleu1_w,
           bleu2_pa,
           bleu2_lat,
           bleu2_pl,
           bleu2_w,
           bleu3_pa,
           bleu3_lat,
           bleu3_pl,
           bleu3_w,
           bleu4_pa,
           bleu4_lat,
           bleu4_pl,
           bleu4_w,
           rouge_l_pa,
           rouge_l_lat,
           rouge_l_pl,
           rouge_l_w]

# Callbacks (project CutMix variant for image-captioning batches).
cbs=[CutMixImgCapAll(alpha=1.)]

# Model Splitter
def split_model_all(arch):
    """Return parameter groups for the multi-view captioner.

    Group order: PA encoder, lateral encoder, PA decoder, lateral decoder.
    """
    pa, lat = arch.pa_model, arch.lat_model
    groups = L(pa.encoder, lat.encoder, pa.decoder, lat.decoder)
    return groups.map(params)

multi_learn = Learner(
    mixed_dls,
    multi_model,
    gb_loss,
    opt_func=partial(Adam, betas=(0.8, 0.99)),
    splitter=split_model_all,
    cbs=cbs,
    metrics=metrics,
)
# Restore the trained captioning weights; with fastai's default model_dir this is
# presumably models/allimgcap.0.1.pth, fetched by the download cell.
multi_learn = multi_learn.load('allimgcap.0.1');

In [5]:
# Output classes of the findings classifier, in model-output order.
classes=["Atelectasis", 
        "Cardiomegaly", 
        "Consolidation", 
        "Edema",
        "Enlarged_Cardiomediastinum", 
        "Fracture", 
        "Lung_Lesion", 
        "Lung_Opacity", 
        "No_Finding", 
        "Pleural_Effusion",
        "Pleural_Other",
        "Pneumonia",
        "Pneumothorax",
        "Support_Devices",
        "Other"]

size=32
trainval_sample = pd.read_csv(imcap_path/'final_trainval_sample_sample.csv', low_memory=False)
# Only the first two rows are kept.
# NOTE(review): presumably just enough to build the loaders' structure — confirm.
trainval_sample = trainval_sample.iloc[:2]
df=trainval_sample
workers=0
seq_len=72
from zipfile import ZipFile
# Extract the sample images archive into the working directory
# (the path variable is `imcap_path`; `imgcap_path` was undefined).
file_name = str(imcap_path/'images.zip')
destination = './'
with ZipFile(file_name) as zf:
    zf.extractall(destination)
# Project loader builders from `custom`: PA view, lateral view, report text, tabular features.
pa_dls_sum = v_dls_test(bs, size, df, 'images', workers)
lat_dls_sum = v_dls_test(bs, size, df, 'images1', workers)
text_class_dls = tc_dls_test(bs, df, seq_len, vocab, workers) 
tab_dls = t_dls_test(bs, df, workers)

mixed_dls_sum = MixedDLSum(pa_dls_sum[0], lat_dls_sum[0], text_class_dls[0], tab_dls[0])

In [6]:
# Get uni-modal models, ugly but quick way to grab a tabular and vision model:
# each temporary fastai Learner is built only to take its `.model`.
model=xresnet18 
pa_model = cnn_learner(pa_dls_sum, model).model
lat_model = cnn_learner(lat_dls_sum, model).model       
# NOTE: this rebinds `arch`, which previously held the captioning Ensemble.
arch=AWD_QRNN
text_model = text_classifier_learner(text_class_dls, arch).model
layers=[500, 250]
tab_model = tabular_learner(tab_dls, layers=layers).model

# Create our Multi-Modal model (project `All` wrapper over the four uni-modal models).
multi_model_sum = All(pa_model, lat_model, text_model, tab_model)
# Loss weights: 0.1 per single modality and 0.6 for the "all" output (total 1.0).
# NOTE: this overwrites the captioning cell's pa_w / lat_w (with the same 0.1 values).
pa_w = 0.1
lat_w = 0.1
text_w = 0.1
tab_w = 0.1
all_w = 0.6

# Gradient-blending loss (project helper) combining the five outputs with the weights above.
gb_loss_sum = myGradientBlendingSum(pa_weight=pa_w, 
                                    lat_weight=lat_w, 
                                    text_weight=text_w, 
                                    tab_weight=tab_w, 
                                    all_weight=all_w, 
                                    loss_scale=1.0, 
                                    use_cel=False)

# Weighted variants of accuracy / AP / ROC using the same loss weights.
w_accuracy_sum = partial(weighted_accuracy_sum, pa_w=pa_w, lat_w=lat_w, text_w=text_w, tab_w=tab_w, all_w=all_w)
w_ap = partial(weighted_ap, pa_w=pa_w, lat_w=lat_w, text_w=text_w, tab_w=tab_w, all_w=all_w)
w_roc = partial(weighted_roc, pa_w=pa_w, lat_w=lat_w, text_w=text_w, tab_w=tab_w, all_w=all_w)

# Per-output metrics plus their weighted combination, for accuracy, AP and ROC.
metrics_sum = [pa_accuracy, lat_accuracy, text_accuracy, tab_accuracy, all_accuracy, w_accuracy_sum,
               pa_ap, lat_ap, text_ap, tab_ap, all_ap, w_ap,
               pa_roc, lat_roc, text_roc, tab_roc, all_roc, w_roc]

# Model Splitter: Fix for text model
# NOTE(review): `split_model_sum` is not defined in this notebook; presumably it is
# provided by `from custom import *` — confirm.

cbs_sum=[CutMixAll(alpha=1.)]

multi_learn_sum = Learner(mixed_dls_sum, multi_model_sum, gb_loss_sum, splitter=split_model_sum, cbs=cbs_sum, metrics=metrics_sum)
# Restore trained weights; with fastai's default model_dir this is presumably
# models/all.0.1.pth, fetched by the download cell.
multi_learn_sum = multi_learn_sum.load('all.0.1')

In [7]:
prod_path = Path('./sample/')
btn_run = widgets.Button(description='Classify')


def _latest_upload(uploader):
    """Return the bytes of the most recently uploaded file in a FileUpload widget, or None."""
    # ipywidgets 7.x FileUpload exposes the uploaded file contents as the list `.data`.
    return uploader.data[-1] if uploader.data else None


def _study_tab_row(tabcols, now):
    """Build the one-row tabular input (time/date features) for the timestamp `now`.

    StudyTime is HHMMSS.ffffff and StudyDate is YYYYMMDD, both zero-padded, so e.g.
    09:05:03 encodes as 90503.x rather than collapsing to 953.x.
    """
    tab = pd.DataFrame(columns=tabcols)
    tab.loc[0, 'StudyTime'] = float(now.strftime('%H%M%S.%f'))
    tab.iloc[0, 1] = float(now.strftime('%Y%m%d'))
    tab.iloc[0, 2] = now.second
    tab.iloc[0, 3] = now.minute
    tab.iloc[0, 4] = now.microsecond
    return tab


def _describe_findings(pairs):
    """Format `(class_name, probability)` pairs as 'Name (93.1% confident), ...'."""
    return ", ".join(f"{name} ({prob * 100:.1f}% confident)" for name, prob in pairs)


def _findings_message(ranked, thresh):
    """Compose the user-facing summary for `(class_name, probability)` pairs sorted high to low.

    Returns `(message, determined)`. `determined` is False when no finding reaches
    `thresh` or the only finding is "Other".
    """
    likely = [(name, prob) for name, prob in ranked if prob >= thresh]
    unlikely = [(name, prob) for name, prob in ranked if prob < thresh]
    if not likely or [name for name, _ in ranked] == ["Other"]:
        return ("Your condition cannot be determined. "
                "Please contact your radiologist for further consultation."), False
    message = "You most likely need to get checked out for the following: " + _describe_findings(likely)
    if unlikely:
        message += ("; You most likely don't need to get checked out for the following: "
                    + _describe_findings(unlikely) + ".")
    else:
        message += "."
    return message, True


def on_click_classify(change):
    """Classify the uploaded X-ray view(s) plus report and show findings and Grad-CAMs.

    Registered as the `btn_run` click handler; `change` (the clicked button) is unused.
    Reads the upload widgets `pa_btn_upload`, `lat_btn_upload` and `btn_upload`, and
    writes into `out_pl` / `lbl_pred`. These widgets are created in a later cell and
    looked up when the handler runs. If only one view is uploaded, it is reused as
    the other view.
    """
    thresh = 0.5
    # Start each run clean instead of appending to the previous run's output.
    out_pl.clear_output()
    lbl_pred.value = ""

    pa_bytes = _latest_upload(pa_btn_upload)
    lat_bytes = _latest_upload(lat_btn_upload)
    report_bytes = _latest_upload(btn_upload)
    if pa_bytes is None and lat_bytes is None:
        lbl_pred.value = "Please upload at least one X-ray view."
        return
    if report_bytes is None:
        lbl_pred.value = "Please upload a report."
        return
    if pa_bytes is None:
        pa_bytes = lat_bytes
    if lat_bytes is None:
        lat_bytes = pa_bytes

    with out_pl:
        display(PILImage.create(pa_bytes).to_thumb(128, 128))
        display(PILImage.create(lat_bytes).to_thumb(128, 128))

    prod_path.mkdir(parents=True, exist_ok=True)
    (prod_path/'pa.jpg').write_bytes(pa_bytes)
    (prod_path/'lat.jpg').write_bytes(lat_bytes)
    (prod_path/'report.txt').write_bytes(report_bytes)
    textreport = (prod_path/'report.txt').read_text()
    # TODO: generate `textreport` from the images (beam_search_all over
    # multi_learn.model, then visualize_att_all) instead of requiring an uploaded report.

    tabcols = ['StudyTime', 'StudyDate', 'Seconds', 'Minutes', 'FracSec']
    # Take a single timestamp so all time features describe the same instant.
    tab = _study_tab_row(tabcols, datetime.now())

    size = 64
    df = pd.read_csv(imcap_path/'final_trainval_sample_sample.csv', low_memory=False)
    pa = prod_path/'pa.jpg'
    lat = prod_path/'lat.jpg'
    # NOTE(review): label comes from the first CSV row, presumably a placeholder predict_sum requires.
    label = df.loc[0, classes]
    new, preds = predict_sum(multi_learn_sum, pa, lat, textreport, tab.iloc[0, :], label, df, tabcols, size)

    # Indices of classes predicted above the threshold, and (name, probability) sorted high to low.
    positive = torch.as_tensor(preds > thresh)[0].nonzero().flatten().tolist()
    ranked = sorted(((classes[i], float(preds[0][i])) for i in positive),
                    key=lambda pair: pair[1], reverse=True)
    message, determined = _findings_message(ranked, thresh)
    lbl_pred.value = message
    if not determined:
        return

    bs = 1
    new_df = new.iloc[:1]
    pa_dl = v_dls_new(bs, size, new_df, 'images', workers)
    lat_dl = v_dls_new(bs, size, new_df, 'images1', workers)
    # Repeat the image 3 times along dim 1 (the channel dimension for the vision models).
    pa_img = torch.tensor(torch.cat((pa_dl.one_batch()[0],) * 3, axis=1))
    lat_img = torch.tensor(torch.cat((lat_dl.one_batch()[0],) * 3, axis=1))

    # Project helper; writes {pa,lat}_gradcam<i>.png for each positive class into prod_path.
    return_grad(multi_learn_sum, multi_learn_sum.model.pa_model, pa_img, 0, prod_path, size, positive)
    return_grad(multi_learn_sum, multi_learn_sum.model.lat_model, lat_img, 1, prod_path, size, positive)

    for i in range(1, len(positive) + 1):
        for view in ('pa', 'lat'):
            grad = plt.imread(prod_path/f"{view}_gradcam{i}.png")
            grad = im.fromarray((grad * 255).astype(np.uint8))
            with out_pl:
                display(grad.to_thumb(214, 256))


btn_run.on_click(on_click_classify)

In [8]:
#hide_output
# Widgets read by on_click_classify when the Classify button is pressed.
pa_btn_upload = widgets.FileUpload()
lat_btn_upload = widgets.FileUpload()
btn_upload = widgets.FileUpload()
out_pl = widgets.Output()
lbl_pred = widgets.Label()

# Use the `widgets` namespace instead of `from ipywidgets import *`, which would
# shadow notebook names such as `Image`.
widgets.VBox([widgets.Label("Upload your patient's PA, AP, or any closely related view."),
              pa_btn_upload,
              widgets.Label("Then upload your patient's lateral, LL, or any closely related view."),
              lat_btn_upload,
              widgets.Label("Finally, upload your patient's radiology report as a text file."),
              btn_upload, btn_run, out_pl, lbl_pred])

VBox(children=(Label(value="Upload your patient's PA, AP, or any closely related view."), FileUpload(value={},…