In [1]:
import hydra
from omegaconf import DictConfig
from pytorch_lightning import (
    Callback,
    LightningDataModule,
    LightningModule,
    Trainer,
    seed_everything,
)
import matplotlib.pyplot as plt
import numpy as np

import torch
from pytorch_lightning.loggers import LightningLoggerBase
from torchmetrics.regression import CosineSimilarity
import torchxrayvision as xrv 
import os, sys
sys.path.append("..")

from src import utils
import pandas as pd
import wandb

hydra.initialize(config_path="../configs")
config=hydra.compose(config_name="config.yaml",overrides=["experiment=example"])

os.environ['WANDB_NOTEBOOK_NAME'] = "gifsplanation_playground.ipynb"
#wandb.init(project="concept_vector_stability")


In [6]:
config.datamodule.only_bbox = True
config.datamodule.bucket_name = "nih_bbox"
config.datamodule.test_split = 0
config.datamodule.batch_size = 1

In [7]:
NIH_datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule)

b
b


In [10]:
NIH_datamodule.data_train

<webdataset.composable.Processor at 0x7f2300009a60>

In [14]:
NIH_datamodule.setup(stage="all")
i = 0
for a in NIH_datamodule.train_dataloader():
    i+=1
print(i)

880


In [9]:
a = pd.DataFrame(NIH_datamodule.associator).T["pathology_masks"]

In [12]:
a[~a.isna()]

00000032_037    [{'mask': [[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0....
00000072_000    [{'mask': [[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0....
00000147_001    [{'mask': [[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0....
00000149_006    [{'mask': [[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0....
00000150_002    [{'mask': [[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0....
                                      ...                        
00030606_006    [{'mask': [[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0....
00030634_000    [{'mask': [[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0....
00030635_001    [{'mask': [[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0....
00030636_004    [{'mask': [[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0....
00030674_000    [{'mask': [[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0....
Name: pathology_masks, Length: 880, dtype: object

In [2]:
bbox = pd.read_csv("~/chexpert/BBox_List_2017.csv",
                                              names=["Image Index", "Finding Label", "x", "y", "w", "h", "_1", "_2", "_3"],
                                              skiprows=1)
bbox.loc[bbox["Finding Label"] == "Infiltrate", "Finding Label"] = "Infiltration"
import numpy as np

def get_mask_dict(image_name, this_size):
        base_size = 1024
        scale = this_size / base_size

        images_with_masks = bbox[bbox["Image Index"] == image_name]
        path_mask = {}

        for i in range(len(images_with_masks)):
            row = images_with_masks.iloc[i]

            # Don't add masks for labels we don't have
          #  if row["Finding Label"] in self.pathologies:
            mask = np.zeros([this_size, this_size])
            xywh = np.asarray([row.x, row.y, row.w, row.h])
            xywh = xywh * scale
            xywh = xywh.astype(int)
            mask[xywh[1]:xywh[1] + xywh[3], xywh[0]:xywh[0] + xywh[2]] = 1

            # Resize so image resizing works
            mask = mask[None, :, :]

            path_mask[bbox.index(row["Finding Label"])] = mask
        return path_mask

In [6]:
bbox

Unnamed: 0,Image Index,Finding Label,x,y,w,h,_1,_2,_3
0,00013118_008.png,Atelectasis,225.084746,547.019217,86.779661,79.186441,,,
1,00014716_007.png,Atelectasis,686.101695,131.543498,185.491525,313.491525,,,
2,00029817_009.png,Atelectasis,221.830508,317.053115,155.118644,216.949153,,,
3,00014687_001.png,Atelectasis,726.237288,494.951420,141.016949,55.322034,,,
4,00017877_001.png,Atelectasis,660.067797,569.780787,200.677966,78.101695,,,
...,...,...,...,...,...,...,...,...,...
979,00029464_015.png,Atelectasis,198.940451,352.900747,615.537778,323.128889,,,
980,00025769_001.png,Atelectasis,701.838229,572.491858,103.537778,63.715556,,,
981,00016837_002.png,Atelectasis,140.913785,658.962969,271.928889,94.435556,,,
982,00020124_003.png,Atelectasis,175.047118,580.456302,244.622222,103.537778,,,


In [25]:
ds = ds.select(lambda x: images.str.contains(x["__key__"]).any())

In [28]:
i = 0
for a in ds:
    i+=1
print(i)

880


In [90]:

def get_bbox(row,this_size=1025) -> dict:
    base_size = 1024
    scale = this_size / base_size
    key = row.str.replace(".png", "").iloc[0]
    mask = np.zeros([this_size, this_size])
    xywh = np.asarray([row.x, row.y, row.w, row.h])
    xywh = xywh * scale
    xywh = xywh.astype(int)
    mask[xywh[1]:xywh[1] + xywh[3], xywh[0]:xywh[0] + xywh[2]] = 1

    # Resize so image resizing works
    mask = mask[None, :, :]
    return {key:{ "pathology_masks":[{"mask":mask,"mask_label":row["Finding Label"]}]}}


In [93]:
masks = list(map(lambda x: get_bbox(x[1]),bbox.iterrows()))
d = dict()
for i_id in masks:
    for items in i_id.items():
        if d.get(items[0]):
          d[items[0]]["pathology_masks"]+=items[1]["pathology_masks"]
        else:
            d[items[0]] = items[1]
labels = {'00013118_008':{"label":1},'00014716_007':{"label":3}}


  key = row.str.replace(".png", "").iloc[0]


[{'00013118_008': {'label': 1,
   'pathology_masks': [{'mask': array([[[0., 0., 0., ..., 0., 0., 0.],
             [0., 0., 0., ..., 0., 0., 0.],
             [0., 0., 0., ..., 0., 0., 0.],
             ...,
             [0., 0., 0., ..., 0., 0., 0.],
             [0., 0., 0., ..., 0., 0., 0.],
             [0., 0., 0., ..., 0., 0., 0.]]]),
     'mask_label': 'Atelectasis'}]}},
 {'00014716_007': {'label': 3,
   'pathology_masks': [{'mask': array([[[0., 0., 0., ..., 0., 0., 0.],
             [0., 0., 0., ..., 0., 0., 0.],
             [0., 0., 0., ..., 0., 0., 0.],
             ...,
             [0., 0., 0., ..., 0., 0., 0.],
             [0., 0., 0., ..., 0., 0., 0.],
             [0., 0., 0., ..., 0., 0., 0.]]]),
     'mask_label': 'Atelectasis'}]}}]

In [94]:
{k: {**v, **d.get(k,{})} for k,v in labels.items()}

{'00013118_008': {'label': 1,
  'pathology_masks': [{'mask': array([[[0., 0., 0., ..., 0., 0., 0.],
            [0., 0., 0., ..., 0., 0., 0.],
            [0., 0., 0., ..., 0., 0., 0.],
            ...,
            [0., 0., 0., ..., 0., 0., 0.],
            [0., 0., 0., ..., 0., 0., 0.],
            [0., 0., 0., ..., 0., 0., 0.]]]),
    'mask_label': 'Atelectasis'}]},
 '00014716_007': {'label': 3,
  'pathology_masks': [{'mask': array([[[0., 0., 0., ..., 0., 0., 0.],
            [0., 0., 0., ..., 0., 0., 0.],
            [0., 0., 0., ..., 0., 0., 0.],
            ...,
            [0., 0., 0., ..., 0., 0., 0.],
            [0., 0., 0., ..., 0., 0., 0.],
            [0., 0., 0., ..., 0., 0., 0.]]]),
    'mask_label': 'Atelectasis'}]}}

In [86]:
masks = list(map(lambda x: get_bbox(x[1]),bbox.iterrows()))


  key = row.str.replace(".png", "").iloc[0]


In [87]:
masks

[{'00013118_008': {'pathology_masks': {'mask': array([[[0., 0., 0., ..., 0., 0., 0.],
            [0., 0., 0., ..., 0., 0., 0.],
            [0., 0., 0., ..., 0., 0., 0.],
            ...,
            [0., 0., 0., ..., 0., 0., 0.],
            [0., 0., 0., ..., 0., 0., 0.],
            [0., 0., 0., ..., 0., 0., 0.]]]),
    'mask_label': 'Atelectasis'}}},
 {'00014716_007': {'pathology_masks': {'mask': array([[[0., 0., 0., ..., 0., 0., 0.],
            [0., 0., 0., ..., 0., 0., 0.],
            [0., 0., 0., ..., 0., 0., 0.],
            ...,
            [0., 0., 0., ..., 0., 0., 0.],
            [0., 0., 0., ..., 0., 0., 0.],
            [0., 0., 0., ..., 0., 0., 0.]]]),
    'mask_label': 'Atelectasis'}}},
 {'00029817_009': {'pathology_masks': {'mask': array([[[0., 0., 0., ..., 0., 0., 0.],
            [0., 0., 0., ..., 0., 0., 0.],
            [0., 0., 0., ..., 0., 0., 0.],
            ...,
            [0., 0., 0., ..., 0., 0., 0.],
            [0., 0., 0., ..., 0., 0., 0.],
           