In [1]:
import argparse
import os
import shutil
import time
import yaml
import sys
import gdown
import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
from monai.config import KeysCollection
from monai.metrics import Cumulative, CumulativeAverage
from monai.networks.nets import milmodel, resnet, MILModel

from sklearn.metrics import cohen_kappa_score
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data.dataloader import default_collate
from torchvision.models.resnet import ResNet50_Weights
import shutil
from pathlib import Path
from torch.utils.data.distributed import DistributedSampler
from torch.utils.tensorboard import SummaryWriter
from monai.utils import set_determinism
import matplotlib.pyplot as plt
import wandb
import math
import logging
from pathlib import Path


from src.model.MIL import MILModel_3D
from src.model.csPCa_model import csPCa_Model
from src.data.data_loader import get_dataloader
from src.utils import save_cspca_checkpoint, get_metrics, setup_logging, save_pirads_checkpoint
from src.train import train_cspca, train_pirads
import SimpleITK as sitk 

import nrrd

from tqdm import tqdm
import pandas as pd
from picai_prep.preprocessing import PreprocessingSettings, Sample
import multiprocessing
import sys
from src.preprocessing.register_and_crop import register_files
from src.preprocessing.prostate_mask import get_segmask
from src.preprocessing.histogram_match import histmatch
from src.preprocessing.generate_heatmap import get_heatmap
import logging
from pathlib import Path
from src.utils import setup_logging
from src.utils import validate_steps
import argparse
import yaml 
from src.data.data_loader import data_transform, list_data_collate
from monai.data import Dataset, load_decathlon_datalist, ITKReader, NumpyReader, PersistentDataset

If you have questions or suggestions, feel free to open an issue at https://github.com/DIAGNijmegen/picai_prep



In [6]:
import os
import shutil
import json
import random

In [10]:
with open('/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/dataset/PICAI_cspca.json', 'r') as f:
    data = json.load(f)
samples = random.sample(data['test'],3)
samples

[{'image': '10270_1000274.nrrd',
  'mask': '/sc-projects/sc-proj-cc06-ag-ki-radiologie/pirad_model_test_PICAI/registered/prostate_seg_mask/10270_1000274.nrrd',
  'dwi': '/sc-projects/sc-proj-cc06-ag-ki-radiologie/pirad_model_test_PICAI/registered/DWI_hist_matched/10270_1000274.nrrd',
  'adc': '/sc-projects/sc-proj-cc06-ag-ki-radiologie/pirad_model_test_PICAI/registered/ADC_hist_matched/10270_1000274.nrrd',
  'heatmap': '/sc-projects/sc-proj-cc06-ag-ki-radiologie/pirad_model_test_PICAI/registered/heatmap/10270_1000274.nrrd',
  'label': 0},
 {'image': '11063_1001085.nrrd',
  'mask': '/sc-projects/sc-proj-cc06-ag-ki-radiologie/pirad_model_test_PICAI/registered/prostate_seg_mask/11063_1001085.nrrd',
  'dwi': '/sc-projects/sc-proj-cc06-ag-ki-radiologie/pirad_model_test_PICAI/registered/DWI_hist_matched/11063_1001085.nrrd',
  'adc': '/sc-projects/sc-proj-cc06-ag-ki-radiologie/pirad_model_test_PICAI/registered/ADC_hist_matched/11063_1001085.nrrd',
  'heatmap': '/sc-projects/sc-proj-cc06-ag-ki

In [15]:
sam = samples[3]['image']
'''
shutil.copy('/sc-projects/sc-proj-cc06-ag-ki-radiologie/pirad_model_test_PICAI/t2_images/'+sam, 'dataset/samples/sample3/t2.nrrd')
shutil.copy('/sc-projects/sc-proj-cc06-ag-ki-radiologie/pirad_model_test_PICAI/DWI_images/'+sam, 'dataset/samples/sample3/dwi.nrrd')
shutil.copy('/sc-projects/sc-proj-cc06-ag-ki-radiologie/pirad_model_test_PICAI/ADC_images/'+sam, 'dataset/samples/sample3/adc.nrrd')
'''

IndexError: list index out of range

In [2]:

args = argparse.Namespace(

    margin = 0.2,
    t2_dir = '/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/datatemp/t2',
    dwi_dir = '/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/datatemp/dwi',
    adc_dir = '/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/datatemp/adc',
    output_dir = '/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/datatemp/processed',
    project_dir = '/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate'
)

args = register_files(args)
args = get_segmask(args)
args = histmatch(args)
args = get_heatmap(args)


  0%|          | 0/1 [00:03<?, ?it/s]
You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  0%|          | 0/1 [00:05<?, ?it/s]


In [4]:
args.num_classes = 4
args.mil_mode = "att_trans"
args.use_heatmap = True
args.tile_size = 64
args.tile_count = 24
args.depth = 3


In [5]:
args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


pirads_model = MILModel_3D(
    num_classes=args.num_classes,  
    mil_mode=args.mil_mode 
)
pirads_checkpoint = torch.load(os.path.join(args.project_dir, 'models', 'pirads.pt'), map_location="cpu")
pirads_model.load_state_dict(pirads_checkpoint["state_dict"])
pirads_model.to(args.device)

cspca_model = csPCa_Model(backbone=pirads_model).to(args.device)
checkpt = torch.load(os.path.join(args.project_dir, 'models', 'cspca_model.pth'), map_location="cpu")
cspca_model.load_state_dict(checkpt['state_dict'])
cspca_model = cspca_model.to(args.device)

enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.self_attn.batch_first was not True(use batch_first for better inference performance)
You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related t

In [6]:
transform = data_transform(args)
files = os.listdir(args.t2_dir)
data_list = []
for file in files:
    temp = {}
    temp['image'] = os.path.join(args.t2_dir, file)
    temp['dwi'] = os.path.join(args.dwi_dir, file)
    temp['adc'] = os.path.join(args.adc_dir, file)
    temp['heatmap'] = os.path.join(args.heatmapdir, file)
    temp['mask'] = os.path.join(args.seg_dir, file)
    temp['label'] = 0  # dummy label
    data_list.append(temp)

dataset = Dataset(data=data_list, transform=transform)
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    pin_memory=True,
    multiprocessing_context= None,
    sampler=None,
    collate_fn=list_data_collate,
)

In [7]:
pirads_list = []
pirads_model.eval()
cspca_risk_list = []
cspca_model.eval()
top5_patches = []
with torch.no_grad():
    for idx, batch_data in enumerate(loader):
        data = batch_data["image"].as_subclass(torch.Tensor).to(args.device)
        logits = pirads_model(data)
        pirads_score= torch.argmax(logits, dim=1)
        pirads_list.append(pirads_score.item())

        output = cspca_model(data)
        output = output.squeeze(1)
        cspca_risk_list.append(output.item())

        sh = data.shape
        x = data.reshape(sh[0] * sh[1], sh[2], sh[3], sh[4], sh[5])
        x = cspca_model.backbone.net(x)
        x = x.reshape(sh[0], sh[1], -1)
        x = x.permute(1, 0, 2)
        x = cspca_model.backbone.transformer(x)
        x = x.permute(1, 0, 2)
        a = cspca_model.backbone.attention(x)
        a = torch.softmax(a, dim=1)
        a = a.view(-1)
        top5_values, top5_indices = torch.topk(a, 5)

        patches_top_5 = []
        for i in range(5):
            patch_temp = data[0, top5_indices.cpu().numpy()[i]][0].cpu().numpy()
            patches_top_5.append(patch_temp)

In [8]:
patches_top_5

[array([[[-0.43333182, -0.52243334, -1.4245864 , ..., -1.7698549 ,
          -1.6027895 , -1.257521  ],
         [-0.7786003 , -1.2797965 , -2.0817103 , ..., -2.0371597 ,
          -2.0817103 , -2.059435  ],
         [-1.4802749 , -1.658478  , -1.7809926 , ..., -1.4691372 ,
          -2.1596742 , -1.9591957 ],
         ...,
         [ 0.8697783 ,  0.34630668,  0.6358867 , ...,  0.624749  ,
           0.24606745,  0.03445129],
         [ 0.44654593,  0.23492976,  0.7806767 , ...,  1.2373221 ,
           0.2906182 , -1.0124918 ],
         [ 0.3017559 , -0.2551287 ,  0.27948052, ...,  1.1036698 ,
          -0.16602719, -1.2463834 ]],
 
        [[-0.84542644, -1.3800358 , -1.4357241 , ..., -2.5383558 ,
          -2.0817103 , -0.99021643],
         [-0.4444695 , -1.3466226 , -1.8255434 , ..., -1.335485  ,
          -1.6362027 , -1.4579996 ],
         [-1.2129703 , -1.7921304 , -2.0371597 , ...,  0.14582822,
          -0.04351256, -0.82315105],
         ...,
         [-0.4556072 ,  1.4378005

In [11]:
import argparse
import os
import shutil
import time
import yaml
import sys
import gdown
import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
from monai.config import KeysCollection
from monai.metrics import Cumulative, CumulativeAverage
from monai.networks.nets import milmodel, resnet, MILModel
from monai.transforms import (
    Compose,
    GridPatchd,
    LoadImaged,
    MapTransform,
    RandFlipd,
    RandGridPatchd,
    RandRotate90d,
    ScaleIntensityRanged,
    SplitDimd,
    ToTensord,
    ConcatItemsd, 
    SelectItemsd,
    EnsureChannelFirstd,
    RepeatChanneld,
    DeleteItemsd,
    EnsureTyped,
    ClipIntensityPercentilesd,
    MaskIntensityd,
    HistogramNormalized,
    RandBiasFieldd,
    RandCropByPosNegLabeld,
    NormalizeIntensityd,
    SqueezeDimd,
    CropForegroundd,
    ScaleIntensityd,
    SpatialPadd,
    CenterSpatialCropd,
    ScaleIntensityd,
    Transposed,
    RandWeightedCropd,
)
from sklearn.metrics import cohen_kappa_score
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data.dataloader import default_collate
from torchvision.models.resnet import ResNet50_Weights
from src.data.custom_transforms import ClipMaskIntensityPercentilesd, NormalizeIntensity_customd
from torch.utils.data.distributed import DistributedSampler
from torch.utils.tensorboard import SummaryWriter

import matplotlib.pyplot as plt

import wandb
import math
from monai.data import Dataset, load_decathlon_datalist, ITKReader, NumpyReader, PersistentDataset

from src.model.MIL import MILModel_3D
from src.model.csPCa_model import csPCa_Model

import logging
from pathlib import Path

In [13]:
transform_image = Compose(
    [
        LoadImaged(keys=["image", "mask"], reader=ITKReader(), ensure_channel_first=True, dtype=np.float32),
        ClipMaskIntensityPercentilesd(keys=["image"], lower=0, upper=99.5, mask_key="mask"),
        NormalizeIntensity_customd(keys=["image"], mask_key="mask", channel_wise=True),
        EnsureTyped(keys=["label"], dtype=torch.float32),
        ToTensord(keys=["image", "label"]),
    ]
)
dataset_image = Dataset(data=data_list, transform=transform_image)


In [19]:
dataset_image[0]['image'][0].numpy().shape

(270, 270, 28)

In [20]:
data_list

[{'image': '/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/datatemp/processed/t2_histmatched/1009449_11049598.nrrd',
  'dwi': '/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/datatemp/processed/DWI_histmatched/1009449_11049598.nrrd',
  'adc': '/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/datatemp/processed/ADC_histmatched/1009449_11049598.nrrd',
  'heatmap': '/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/datatemp/processed/heatmaps/1009449_11049598.nrrd',
  'mask': '/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/datatemp/processed/prostate_mask/1009449_11049598.nrrd',
  'label': 0}]