In [1]:
import torch
import torch.nn as nn
import sys
import numpy as np
import torch.nn.functional as F

sys.path.append('/gpfs/space/home/joonas97/MIL/AttentionDeepMIL')
from dataloader_TUH import TUH_full_scan_dataset
from model import Attention, GatedAttention, ResNet18Attention
import matplotlib.pyplot as plt
from tqdm import tqdm

sys.path.append('/gpfs/space/home/joonas97/KITSCAM/')
from utils.visualize_utils import *

from IPython.display import HTML
import matplotlib
import glob
import nibabel as nib
import pandas as pd
import matplotlib.pyplot as plt
import plotly.express as px
import re
# Raise matplotlib's size cap for animations embedded as HTML (e.g. rendered through IPython's HTML)
# to an effectively unlimited value so large scan animations are not truncated.
matplotlib.rcParams['animation.embed_limit'] = 2 ** 128


In [29]:
import glob
import nibabel as nib
import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms
from monai.transforms import CropForeground, CenterSpatialCrop
import torch.nn.functional as F
import os


def threshold_at_one(x):
    """Foreground selector for ``CropForeground``: True for every element strictly greater than -1.

    NOTE(review): despite the function name (and the former "threshold at 1" comment),
    the cutoff is -1, not 1 -- confirm which value is intended.
    """
    lower_bound = -1
    return lower_bound < x


# Foreground cropper using threshold_at_one (every voxel > -1 is treated as foreground).
# NOTE(review): background_cropper is not referenced in any of the visible cells.
background_cropper = CropForeground(select_fn=threshold_at_one)
# Center crop with ROI (512, 512, 1000); __getitem__ applies it to channel-first (1, H, W, D) tensors.
# NOTE(review): the trailing "#500" presumably records an earlier depth setting -- confirm or remove.
center_cropper = CenterSpatialCrop(roi_size=(512, 512, 1000))  #500


class TUH_full_scan_dataset(torch.utils.data.Dataset):
    """Dataset yielding one whole (slice-subsampled) CT volume per item.

    NOTE(review): this class shadows the ``TUH_full_scan_dataset`` imported from
    ``dataloader_TUH`` in the first cell; keep its preprocessing in sync with whatever
    the evaluated checkpoint was trained on.

    Volumes are globbed from ``<root>/<dataset_type>/controls2/*nii.gz`` (label False)
    and ``<root>/<dataset_type>/tumors2/*nii.gz`` (label True), in sorted order.

    Args:
        dataset_type: "train" or "test"; selects the sub-directory.
        only_every_nth_slice: keep every n-th slice along the last axis.
        interpolation: if True, halve the two in-plane dimensions with trilinear
            interpolation (the slice count is kept).
        augmentations: stored on the instance; NOTE(review): not applied anywhere here.
        as_rgb: if True, repeat the volume three times along a new leading axis.
        min_max_normalization: if True, rescale each clipped volume to [0, 1];
            otherwise standardize every slice to mean 0 / std 1.

    Raises:
        ValueError: if ``dataset_type`` is neither "train" nor "test".
    """

    def __init__(self, dataset_type, only_every_nth_slice=1, interpolation=False, augmentations=None, as_rgb=False,
                 min_max_normalization=False):
        super().__init__()
        self.as_rgb = as_rgb
        self.min_max_norm = min_max_normalization
        self.augmentations = augmentations
        self.nth_slice = only_every_nth_slice
        self.interpolation = interpolation
        data_path = '/gpfs/space/projects/BetterMedicine/joonas/tuh_kidney_study/new_setup/MIL_EXP/'

        if dataset_type not in ("train", "test"):
            raise ValueError("Dataset type should be either train or test")
        data_path = os.path.join(data_path, dataset_type)

        control_path = data_path + '/controls2/*nii.gz'
        tumor_path = data_path + '/tumors2/*nii.gz'
        print("PATHS: ")
        print(control_path)
        print(tumor_path)
        # glob order is filesystem-dependent; sort so indices are stable across runs.
        control = sorted(glob.glob(control_path))
        tumor = sorted(glob.glob(tumor_path))

        # One independent single-element list per sample ([[x]] * n would alias one list n times).
        control_labels = [[False] for _ in control]
        tumor_labels = [[True] for _ in tumor]

        self.img_paths = control + tumor
        self.labels = control_labels + tumor_labels

        print("Data length: ", len(self.img_paths), "Label length: ", len(self.labels))
        print(
            f"control: {len(control)}, tumor: {len(tumor)}")

        self.classes = ["control", "tumor"]

    def __len__(self):
        # a DataSet must know its size
        return len(self.img_paths)

    def _normalize(self, clipped):
        """Normalize a clipped (H, W, D) array.

        Min-max mode rescales the whole volume to [0, 1] (all zeros if it is constant).
        Otherwise each slice along the last axis is standardized to mean 0 / std 1;
        slices with zero spread are mapped to 0 instead of NaN.
        """
        if self.min_max_norm:
            lo, hi = clipped.min(), clipped.max()
            span = hi - lo
            return (clipped - lo) / span if span > 0 else np.zeros_like(clipped)
        mean = clipped.mean(axis=(0, 1))
        std = clipped.std(axis=(0, 1))
        # A constant slice (e.g. pure background after clipping) has std 0 -> 0/0 = NaN.
        std = np.where(std > 0, std, 1.0)
        return (clipped - mean) / std

    def __getitem__(self, index):
        path = self.img_paths[index]

        volume = nib.load(path).get_fdata()[:, :, ::self.nth_slice]
        # NOTE(review): np.percentile takes q in percent, so the lower bound is the 0.05th
        # percentile while the upper one is the 99.5th -- asymmetric; 0.5 may have been meant.
        # Left unchanged to stay consistent with the trained model.
        clipped = np.clip(volume, np.percentile(volume, q=0.05), np.percentile(volume, q=99.5))
        norm_x = self._normalize(clipped)

        # Add a channel axis -> (1, H, W, D) before cropping.
        norm_x = center_cropper(torch.unsqueeze(torch.from_numpy(norm_x), 0))

        _, h, w, d = norm_x.shape
        if self.interpolation:
            norm_x = F.interpolate(norm_x.unsqueeze(0), size=(h // 2, w // 2, d),
                                   mode='trilinear', align_corners=False)
            norm_x = norm_x.squeeze(0)  # drop the batch axis added for F.interpolate
        # Drop only the channel axis, so a single-slice volume keeps its depth axis.
        norm_x = norm_x.squeeze(0)

        x = norm_x.to(torch.float16)
        if self.as_rgb:
            x = torch.stack([x, x, x], dim=0)  # (3, H, W, D)
        y = torch.tensor(self.labels[index])

        return x, y, path



In [27]:
# Held-out split; only every nth_slice-th slice of each scan is kept (the GT masks below are subsampled the same way).
nth_slice = 4
SHUFFLE_SEED = 0  # fixed so the shuffled order (and therefore the cases plotted below) is reproducible
test_dataset = TUH_full_scan_dataset(dataset_type="test", only_every_nth_slice=nth_slice,
                                     interpolation=False, as_rgb=True)
data_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=True, num_workers=2,
                                          pin_memory=True,
                                          generator=torch.Generator().manual_seed(SHUFFLE_SEED))

PATHS: 
/gpfs/space/projects/BetterMedicine/joonas/tuh_kidney_study/new_setup/MIL_EXP/test/controls2/*nii.gz
/gpfs/space/projects/BetterMedicine/joonas/tuh_kidney_study/new_setup/MIL_EXP/test/tumors2/*nii.gz
Data length:  78 Label length:  78
control: 39, tumor: 39


In [22]:
CHECKPOINT_PATH = '/gpfs/space/home/joonas97/MIL/AttentionDeepMIL/results/main_joonas/2023-06-26/19-24-49/checkpoints/best_model.pth'
GT_ROOT = '/gpfs/space/projects/BetterMedicine/joonas/tuh_kidney_study/new_setup'

model = ResNet18Attention()
# map_location="cpu" makes loading independent of the GPU the checkpoint was saved from;
# the model is moved to the current GPU right after.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location="cpu"))
model.cuda()
model.eval()  # inference only: freeze BatchNorm statistics and disable dropout
print("model loaded!")

# Ground-truth segmentation masks for controls and tumors, sorted for a stable order.
ctrls = sorted(glob.glob(os.path.join(GT_ROOT, 'controls2', 'labelsTr', '*nii.gz')))
tmrs = sorted(glob.glob(os.path.join(GT_ROOT, 'tumors2', 'labelsTr', '*nii.gz')))
all_labels = ctrls + tmrs
all_labels[:3]

model loaded!


In [None]:


# For a handful of test scans, overlay the model's per-slice attention on the number of
# kidney / tumor voxels in the matching ground-truth slice.
start_string = 'case_'
end_string = '_0000'
CASE_ID_PATTERN = re.compile(start_string + '(.*)' + end_string)
ATTENTION_SCALE = 30000  # attention weights are small; scale them to be visible next to voxel counts
MAX_CASES = 16  # number of scans to plot
LABEL_NAMES = {0: "none", 1: "kidney", 2: "tumor"}


def slice_label_counts(seg):
    """Count the voxels of each label in every slice (last axis) of a (H, W, D) mask.

    Returns a DataFrame indexed 0..D-1 with the columns "none", "kidney", "tumor";
    a label absent from a slice simply gets a count of 0.
    """
    return pd.DataFrame({name: (seg == value).sum(axis=(0, 1)) for value, name in LABEL_NAMES.items()})


def find_ground_truth(case_id, label_paths):
    """Return the first label path containing ``case_id``; raise FileNotFoundError if none does."""
    matches = [p for p in label_paths if case_id in p]
    if not matches:
        raise FileNotFoundError(f"No ground-truth segmentation found for case {case_id!r}")
    return matches[0]


model.eval()  # inference: use BatchNorm running stats and disable dropout (train() would update them)
step = 0
for step, (data, bag_label, path) in enumerate(data_loader, start=1):
    scan_path = path[0]
    case_id = CASE_ID_PATTERN.search(scan_path).group(1)
    seg = nib.load(find_ground_truth(case_id, all_labels)).get_fdata()[:, :, ::nth_slice]
    df = slice_label_counts(seg)

    # The batch is (1, C, H, W, D); the model receives one C x H x W image per slice: (1, D, C, H, W).
    # permute moves the slice axis; torch.reshape would only reinterpret memory and mix pixels
    # from different slices and channels.
    # NOTE(review): confirm the training loop builds its input with the same permutation.
    data = data.permute(0, 4, 1, 2, 3).contiguous().cuda()
    with torch.no_grad(), torch.cuda.amp.autocast():
        Y_prob, predicted_label, attention_weights = model(data)
    attention = attention_weights.cpu().as_tensor()[0].float().numpy()

    print("true label: ", bag_label)
    print("predicted label: ", predicted_label)
    print("attention length: ", len(attention))
    print("dataframe length:", len(df))
    print("data shape: ", data.shape)
    print("seg shape: ", seg.shape)

    df["attention"] = pd.Series(attention) * ATTENTION_SCALE
    fig = px.bar(df, x=df.index, y=["kidney", "attention", "tumor"], barmode="overlay",
                 title="/".join(scan_path.split("/")[-3:]))  # split / class folder / file name
    fig.show()
    if step >= MAX_CASES:
        break