In [1]:
from pytorch_lightning import (
    Callback,
    LightningDataModule,
    LightningModule,
    Trainer,
    seed_everything,
)
import matplotlib.pyplot as plt
import numpy as np

import torch
from pytorch_lightning.loggers import LightningLoggerBase
from torchmetrics.regression import CosineSimilarity
import torchxrayvision as xrv 
import os, sys
sys.path.append("..")

from src import utils
from src.models.vqvae_simple import VQVAE_simple
import pandas as pd
import wandb
import skimage, skimage.filters
import sklearn, sklearn.metrics


# NOTE(review): `from utils import ...` asks the import system for a top-level
# module named `utils` (found via sys.path); it does not go through the `utils`
# name bound by `from src import utils` above. Presumably a sibling `utils.py`
# next to this notebook - confirm which module is intended.
from utils import generate_explanation,generate_vector,generate_vector_cav,calc_iou

# wandb reads this environment variable to know which notebook file the runs
# started from this kernel belong to.
os.environ['WANDB_NOTEBOOK_NAME'] = "gifsplanation_playground.ipynb"


In [2]:
# Build the two torchxrayvision models used throughout the notebook and move
# each one to the GPU as soon as it is constructed:
#   ae  - ResNet autoencoder loaded with the "101-elastic" weights
#   clf - DenseNet classifier loaded with the "all" weights
ae = xrv.autoencoders.ResNetAE(weights="101-elastic").cuda()
clf = xrv.models.DenseNet(weights="all").cuda()

In [3]:
from torchsampler import ImbalancedDatasetSampler
from torchvision.transforms import Compose
import torchvision.transforms as transforms
from torchxrayvision.datasets import XRayCenterCrop, XRayResizer, normalize, apply_transforms, relabel_dataset, SubsetDataset,ToPILImage,normalize
from datawrappers import NIH_wrapper, CheX_wrapper, VINBig_wrapper
from torchvision.transforms import RandomHorizontalFlip,ToTensor,RandomVerticalFlip,RandomAffine,Compose
from torchsampler import ImbalancedDatasetSampler

class XRayStack(object):
    """Callable transform that replicates a single-plane image three times.

    The input must have a leading axis of length 1. That axis is removed and
    the remaining array is stacked three times along a new first axis, so an
    input of shape ``(1, H, W)`` becomes ``(3, H, W)``.
    """

    def __call__(self, img):
        # np.squeeze with an explicit axis raises ValueError when that axis
        # is not of length 1.
        plane = np.squeeze(img, axis=0)
        return np.stack((plane, plane, plane), axis=0)
        

# Preprocessing pipeline shared by the datasets below: XRayCenterCrop followed
# by XRayResizer(224).
# NOTE(review): this rebinds `transforms`, hiding the `torchvision.transforms`
# module alias imported at the top of this cell. Later cells rely on the name
# referring to this Compose object.
transforms = Compose([XRayCenterCrop(), XRayResizer(224)])

nih_ds = NIH_wrapper(
    r"/mnt/mp1/@ais/nih_raw/%ob/",
    csvpath=r"/mnt/mp1/@ais/nih_raw/%ob/Data_Entry_2017_v2020.csv",
    bbox_list_path=r"/mnt/mp1/@ais/nih_raw/%ob/BBox_List_2017.csv",
    transform=transforms,
    unique_patients=True,
)
# Align the dataset labels with the classifier's pathology list; with
# silent=False the pathologies missing from the dataset are reported
# ("... doesn't exist. Adding nans instead.").
relabel_dataset(clf.pathologies, nih_ds, silent=False)

Lung Lesion doesn't exist. Adding nans instead.
Fracture doesn't exist. Adding nans instead.
Lung Opacity doesn't exist. Adding nans instead.
Enlarged Cardiomediastinum doesn't exist. Adding nans instead.


In [4]:
# NIH dataset variant with pathology masks enabled, restricted to PA/AP views.
NIH_ds = NIH_wrapper(
    r"/mnt/mp1/@ais/nih_bbox/%ob/",
    csvpath=r"/mnt/mp1/@ais/nih_bbox/%ob/Data_Entry_2017_v2020.csv",
    bbox_list_path=r"/mnt/mp1/@ais/nih_bbox/%ob/BBox_List_2017.csv",
    views=["PA", "AP"],
    pathology_masks=True,
    transform=transforms,
    unique_patients=True,
)
# Keep only the rows whose `has_masks` column is set.
NIH_ds = SubsetDataset(NIH_ds, NIH_ds.csv.loc[NIH_ds.csv["has_masks"]].index)
# Align labels with the classifier's pathology list (missing ones are reported).
relabel_dataset(clf.pathologies, NIH_ds, silent=False)
NIH_datamodule = torch.utils.data.DataLoader(NIH_ds, pin_memory=False, batch_size=500)

Lung Lesion doesn't exist. Adding nans instead.
Fracture doesn't exist. Adding nans instead.
Lung Opacity doesn't exist. Adding nans instead.
Enlarged Cardiomediastinum doesn't exist. Adding nans instead.


In [5]:
def generate_vector_copygrad(image,target,p=0, ae=None,clf=None):
    """Gradient of the classifier's `target` probability w.r.t. the latent code.

    The image is encoded with ``ae``, decoded back, blended with the original
    image as ``image*p + reconstruction*(1-p)``, and passed through ``clf``.
    The returned tensor is the gradient of the sigmoid output for ``target``
    with respect to the (detached) latent code ``z``.

    Args:
        image: input tensor passed to ``clf`` and ``ae.encode``. It is not
            modified; gradients w.r.t. the image are taken on a detached leaf.
        target: pathology name; must be present in ``clf.pathologies``.
        p: blending weight of the original image in the second forward pass
            (``0`` means the classifier only sees the reconstruction).
        ae: autoencoder exposing ``zero_grad()``, ``encode(image)`` and
            ``decode(z, image_shape)``.
        clf: classifier exposing ``zero_grad()``, ``pathologies`` and
            ``__call__``.

    Returns:
        Tensor with the same shape as ``ae.encode(image)``.

    Raises:
        ValueError: if ``ae`` or ``clf`` is missing, or ``target`` is not in
            ``clf.pathologies``.
    """
    if ae is None or clf is None:
        raise ValueError("generate_vector_copygrad requires both `ae` and `clf`")

    ae.zero_grad()
    clf.zero_grad()
    target_idx = clf.pathologies.index(target)

    # Differentiate w.r.t. a detached leaf so the caller's tensor is left
    # untouched and non-leaf inputs are accepted too.
    image_leaf = image.detach().requires_grad_(True)
    pred = torch.sigmoid(clf(image_leaf))[:, target_idx]
    # Summing yields a scalar, so batches with more than one image work; each
    # sample's gradient is unchanged because samples are independent here.
    i_grad = torch.autograd.grad(pred.sum(), image_leaf)[0]

    z = ae.encode(image).detach()
    z.requires_grad = True
    xp = ae.decode(z, image.shape)
    pred = torch.sigmoid(clf(image * p + xp * (1 - p)))[:, target_idx]
    # NOTE(review): `xp` is an intermediate tensor, and torch.autograd.grad
    # below does not read `.grad` attributes, so this copied image gradient
    # does not influence the returned value. If the intent was to push
    # `i_grad` back through the decoder, that would need `grad_outputs=` on
    # an autograd.grad call from `xp` to `z` - confirm before changing.
    xp.grad = i_grad.clone()
    dzdxp = torch.autograd.grad(pred.sum(), z)[0]
    return dzdxp


In [20]:
def eval(ae, generate_vector, target, data,method=None):
    """Score one explanation method on every positive, masked sample of `data`.

    For each sample whose label for ``target`` equals 1 and that carries a
    pathology mask for ``target``, a latent vector is produced with
    ``generate_vector(image, target)``, turned into a saliency image with
    ``generate_explanation`` and compared with the ground-truth mask via
    ``calc_iou``. Reconstruction MSE/MAE of ``ae`` are added to the metrics,
    and images plus metrics are logged to a wandb run in project "copygrad".

    Note: this function shadows the builtin ``eval`` and reads the
    module-level ``clf``, ``generate_explanation`` and ``calc_iou``.

    Args:
        ae: autoencoder; called as ``ae(image)["out"]`` for the reconstruction.
        generate_vector: callable ``(image, target) -> vector``.
        target: pathology name; must be present in ``data.pathologies``.
        data: iterable dataset yielding dicts with "img", "label", "idx" and
            "pathology_masks"; must expose ``pathologies``.
        method: optional label used in the wandb run name and metrics rows.

    Returns:
        pandas.DataFrame with one row of metrics per evaluated sample (empty
        when no sample qualifies).

    Raises:
        ValueError: if ``target`` is not in ``data.pathologies``.
    """
    # Resolve the label column once, and before a wandb run is created, so an
    # unknown target fails fast.
    target_idx = data.pathologies.index(target)

    if method:
        wandb.init(project="copygrad",name=str(method)+" "+target)
    else:
        wandb.init(project="copygrad")

    result = []
    per_sample_table = None
    for sample in data:
        # NaN labels (pathology absent from the dataset) compare unequal to 1
        # and are skipped here as well.
        if sample["label"][target_idx] != 1:
            continue
        if target not in sample["pathology_masks"]:
            continue

        gt_mask = sample["pathology_masks"][target]["mask"]
        image = torch.from_numpy(sample["img"]).unsqueeze(0).cuda()
        vector = generate_vector(image, target)
        dimage = generate_explanation(sample, vector, target, ae=ae, clf=clf)
        metrics = calc_iou(dimage, gt_mask)

        # The reconstruction is only used for metrics and logging, so no
        # autograd graph is needed.
        with torch.no_grad():
            recon = ae(image)["out"]
            metrics["mse"] = ((image - recon) ** 2).mean().item()
            metrics["mae"] = torch.abs(image - recon).mean().item()
        metrics["idx"] = sample["idx"]
        metrics["method"] = method
        metrics["target"] = target

        if per_sample_table is None:
            per_sample_table = wandb.Table(dataframe=pd.DataFrame(metrics,index=[metrics['idx']]))
        else:
            # Relies on dict insertion order matching the table's column order.
            per_sample_table.add_data(*list(metrics.values()))
        result.append(metrics)

        original_img = wandb.Image(image, caption="original image")
        recon_img = wandb.Image(recon,caption="reconstruction")
        mask_img = wandb.Image(gt_mask[0],caption="gt mask")
        saliency_plot = plt.imshow(dimage)
        wandb.log({"mask":mask_img, "salincy": saliency_plot,"original":original_img,"reconstruction":recon_img,"target":target,"vector":vector})
        # Close the figure so successive samples do not pile up on one axes.
        plt.close(saliency_plot.figure)

    print(len(result))
    results_df = pd.DataFrame(result)
    if result:
        # numeric_only skips string columns such as "target"; dropna=False
        # keeps the group when `method` is None.
        summary = results_df.groupby("method", dropna=False).mean(numeric_only=True)
        wandb.log({"res_table":per_sample_table,"total_table":summary})
    return results_df

In [21]:
# Pathologies to evaluate, in the order the experiments are run.
# NOTE(review): the relabel output earlier in this notebook reports
# "Lung Opacity doesn't exist. Adding nans instead." for the NIH data, so the
# last entry is not expected to match any positive sample - confirm it is wanted.
for_eval = [
    "Cardiomegaly",
    "Mass",
    "Nodule",
    "Atelectasis",
    "Effusion",
    "Lung Opacity",
]

In [None]:
# Explanation methods under comparison. The list does not depend on the
# target, so it is built once instead of on every loop iteration. Each
# "function" takes (image, target) and returns a latent-space vector.
experiments = [
    {"method": "latentshift max", "function": lambda x,y : generate_vector(x,y,ae=ae,clf=clf)},
    {"method": "copygrad", "function": lambda x,y : generate_vector_copygrad(x,y,ae=ae,clf=clf)},
]

# Per-run metric frames keyed by (target, method), so no result is lost when
# `res` is overwritten by the next run.
eval_results = {}
for target in for_eval:
    print(f"Starting exp for {target}")
    for cfg in experiments:
        res = eval(ae,cfg['function'],target,NIH_ds,method=cfg['method'])
        eval_results[(target, cfg['method'])] = res

Starting exp for Cardiomegaly


VBox(children=(Label(value=' 0.20MB of 0.20MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
target,Cardiomegaly


[34m[1mwandb[0m: wandb version 0.13.7 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [None]:
# Split the samples labelled positive for `target` by mask availability:
#   a - positive label but no pathology mask for the target
#   b - positive label with a pathology mask for the target
# NOTE(review): this cell rebinds `NIH_ds` to a dataset loaded from different
# paths than the earlier cell; later cells see this one.
a = []
b = []
NIH_ds = NIH_wrapper(
    "/data/nih_raw/",
    csvpath=r"/data/nih_raw/Data_Entry_2017_v2020.csv",
    bbox_list_path=r"/data/nih_raw/BBox_List_2017.csv",
    views=["PA", "AP"],
    pathology_masks=True,
    transform=transforms,
    unique_patients=False,
)
# Keep only metadata rows whose image appears in the bounding-box CSV.
NIH_ds.csv = NIH_ds.csv.loc[
    NIH_ds.csv["Image Index"].isin(NIH_ds.pathology_maskscsv["Image Index"])
].reset_index()
target = "Cardiomegaly"
for sample in NIH_ds:
    if sample["label"][NIH_ds.pathologies.index(target)] != 1:
        continue
    (b if target in sample["pathology_masks"] else a).append(sample)

In [13]:
# Inspect the first sample; per the output below it is a dict with the keys
# 'idx', 'img', 'label' and 'pathology_masks'.
NIH_ds[0]

{'idx': 0,
 'img': array([[[-744.7915  , -828.45276 , -831.0125  , ..., -563.04083 ,
          -596.3253  , -561.7012  ],
         [-826.0441  , -919.52985 , -925.32904 , ..., -534.8762  ,
          -577.55914 , -566.72375 ],
         [-824.03125 , -919.3312  , -927.5306  , ..., -524.5742  ,
          -514.73016 , -502.46982 ],
         ...,
         [  40.125572,  139.4112  ,  169.64018 , ..., -947.3776  ,
          -936.6322  , -834.3423  ],
         [  55.411507,  155.81358 ,  190.21979 , ..., -947.4722  ,
          -935.85016 , -832.87805 ],
         [  11.82943 ,   91.64818 ,  121.65605 , ..., -851.81384 ,
          -841.9902  , -749.05457 ]]], dtype=float32),
 'label': array([0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
       dtype=float32),
 'pathology_masks': {'Infiltration': {'mask': array([[[0., 0., 0., ..., 0., 0., 0.],
           [0., 0., 0., ..., 0., 0., 0.],
           [0., 0., 0., ..., 0., 0., 0.],
           ...,
           [0., 0., 0., ..., 0., 0., 0.],
  

In [69]:
# Look up the metadata row of one specific image in the NIH data-entry CSV.
# NOTE(review): absolute, user-specific path - will not resolve on other machines.
csv = pd.read_csv("/home/luab/experiments/Data_Entry_2017_v2020.csv")
csv.loc[csv["Image Index"] == "00009608_024.png"]

Unnamed: 0,Image Index,Finding Labels,Follow-up #,Patient ID,Patient Age,Patient Gender,View Position,OriginalImage[Width,Height],OriginalImagePixelSpacing[x,y]
36428,00009608_024.png,Cardiomegaly,24,9608,30,M,AP,2500,2048,0.168,0.168


In [71]:
# Display the bounding-box table attached to the dataset (columns per the
# output below: Image Index, Finding Label, x, y, w, h, ...).
NIH_ds.pathology_maskscsv

Unnamed: 0,Image Index,Finding Label,x,y,w,h,_1,_2,_3
0,00013118_008.png,Atelectasis,225.084746,547.019217,86.779661,79.186441,,,
1,00014716_007.png,Atelectasis,686.101695,131.543498,185.491525,313.491525,,,
2,00029817_009.png,Atelectasis,221.830508,317.053115,155.118644,216.949153,,,
3,00014687_001.png,Atelectasis,726.237288,494.951420,141.016949,55.322034,,,
4,00017877_001.png,Atelectasis,660.067797,569.780787,200.677966,78.101695,,,
...,...,...,...,...,...,...,...,...,...
979,00029464_015.png,Atelectasis,198.940451,352.900747,615.537778,323.128889,,,
980,00025769_001.png,Atelectasis,701.838229,572.491858,103.537778,63.715556,,,
981,00016837_002.png,Atelectasis,140.913785,658.962969,271.928889,94.435556,,,
982,00020124_003.png,Atelectasis,175.047118,580.456302,244.622222,103.537778,,,


In [72]:
# Bounding-box rows whose image is also present in the dataset's metadata CSV,
# re-indexed from zero (the previous index is kept as an "index" column).
NIH_ds.pathology_maskscsv.loc[
    NIH_ds.pathology_maskscsv["Image Index"].isin(NIH_ds.csv["Image Index"])
].reset_index()

Unnamed: 0,index,Image Index,Finding Label,x,y,w,h,_1,_2,_3
0,0,00013118_008.png,Atelectasis,225.084746,547.019217,86.779661,79.186441,,,
1,1,00014716_007.png,Atelectasis,686.101695,131.543498,185.491525,313.491525,,,
2,2,00029817_009.png,Atelectasis,221.830508,317.053115,155.118644,216.949153,,,
3,3,00014687_001.png,Atelectasis,726.237288,494.951420,141.016949,55.322034,,,
4,4,00017877_001.png,Atelectasis,660.067797,569.780787,200.677966,78.101695,,,
...,...,...,...,...,...,...,...,...,...,...
979,979,00029464_015.png,Atelectasis,198.940451,352.900747,615.537778,323.128889,,,
980,980,00025769_001.png,Atelectasis,701.838229,572.491858,103.537778,63.715556,,,
981,981,00016837_002.png,Atelectasis,140.913785,658.962969,271.928889,94.435556,,,
982,982,00020124_003.png,Atelectasis,175.047118,580.456302,244.622222,103.537778,,,
