In [None]:
import os

import pandas as pd

In [None]:
# Metadata table for the resampled training data; the next cells inspect its
# *_Size and *_PixDim columns. Set the PATIENT_DATA_CSV environment variable to
# point at a different copy instead of editing this user-specific absolute path.
PATIENT_DATA_CSV = os.environ.get(
    "PATIENT_DATA_CSV",
    "/WAVE/users/unix/smalladi/varian_ml/patient_data_resampled_training.csv",
)
df = pd.read_csv(PATIENT_DATA_CSV)
df.head()

In [None]:
# How often each volume size occurs, for the CT, PT and label volumes in turn.
for size_col in ("CT_Size", "PT_Size", "Label_Size"):
    print(df[size_col].value_counts())

In [None]:
# How often each voxel spacing occurs, for the CT, PT and label volumes in turn.
for pixdim_col in ("CT_PixDim", "PT_PixDim", "Label_PixDim"):
    print(df[pixdim_col].value_counts())

In [None]:
import glob
import logging
import os
from pathlib import Path
import shutil
import sys
import tempfile

import nibabel as nib
import numpy as np
from monai.config import print_config
from monai.data import (
ArrayDataset,
create_test_image_3d,
decollate_batch,
DataLoader,
CacheDataset
)
from monai.handlers import (
    MeanDice,
    MLFlowHandler,
    StatsHandler,
    TensorBoardImageHandler,
    TensorBoardStatsHandler,
)
from monai.inferers import sliding_window_inference
from monai.losses import DiceCELoss, DeepSupervisionLoss, DiceLoss
from monai.metrics import compute_dice, DiceMetric
from monai.networks.nets import UNet, SegResNet, SegResNetDS, SwinUNETR
from monai.transforms import (
    Activations,
    AsDiscrete,
    EnsureChannelFirstd,
    Compose,
    LoadImaged,
    Orientationd,
    Spacingd,
    CropForegroundd,
    RandCropByPosNegLabeld,
    Resized,
    ScaleIntensityRanged,
    RandShiftIntensityd,
    RandAffined,
    RandFlipd,
    ToTensord,

)
from monai.utils import first
import ignite
import torch
import matplotlib.pyplot as plt


In [None]:
root_dir = 'Hecktor22/model_data'
data_dir = 'hecktor2022_training/hecktor2022'
# The resampled volumes live under data_dir; deriving the paths keeps them in sync.
resampled_ct_path = os.path.join(data_dir, 'resampled_largerCt')
resampled_pt_path = os.path.join(data_dir, 'resampled_largerPt')
resampled_label_path = os.path.join(data_dir, 'resampled_largerlabel')

train_images = sorted(glob.glob(os.path.join(resampled_ct_path, "*_CT*")))
train_images2 = sorted(glob.glob(os.path.join(resampled_pt_path, "*_PT*")))
train_labels = sorted(glob.glob(os.path.join(resampled_label_path, "*.nii.gz")))

# CT, PT and label files are paired purely by sorted position. zip() would silently
# drop or misalign cases if one modality has a different number of files, and an
# empty glob (e.g. wrong working directory) would only fail much later - check now.
if not train_images or not (len(train_images) == len(train_images2) == len(train_labels)):
    raise ValueError(
        f"Expected equal, non-empty file lists but found {len(train_images)} CT, "
        f"{len(train_images2)} PT and {len(train_labels)} label files under "
        f"{data_dir!r} (cwd: {os.getcwd()!r})."
    )

data_dicts = [
    {"image": image_name, "image2": pet_image, 'label': label_name}
    for image_name, pet_image, label_name in zip(train_images, train_images2, train_labels)
]

In [None]:
# The first ten cases (at most) are used for the qualitative evaluation below.
test_files = data_dicts[0:10]

In [None]:
# Inspect one case: a dict of file paths keyed by "image" (CT), "image2" (PT) and "label".
test_files[0]

In [None]:
# Intensity windows that ScaleIntensityRanged maps (with clipping) onto [0, 1].
ct_a_min, ct_a_max = -200, 400
pt_a_min, pt_a_max = 0, 25

# Patch / augmentation settings. input_size is also the sliding-window ROI used
# by `inference`; crop_samples, p and strength are not referenced in this cell.
crop_samples = 2
input_size = [96, 96, 96]
p = 0.5
strength = 1

# Keys handled by every dictionary transform, and the per-key Spacingd
# interpolation in the same order: bilinear for CT and PT, nearest for the label.
image_keys = ["image", "image2", "label"]
modes_2d = ['bilinear', 'bilinear', 'nearest']

val_transforms = Compose(
    [
        LoadImaged(keys=image_keys),
        EnsureChannelFirstd(keys=image_keys),
        Orientationd(keys=image_keys, axcodes="RAS"),
        # Resample all three volumes onto a 1x1x1 voxel grid.
        Spacingd(keys=image_keys, pixdim=(1, 1, 1), mode=modes_2d),
        ScaleIntensityRanged(
            keys=['image'],
            a_min=ct_a_min,
            a_max=ct_a_max,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        ScaleIntensityRanged(
            keys=['image2'],
            a_min=pt_a_min,
            a_max=pt_a_max,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        # Crop all volumes to the foreground bounding box of the (scaled) CT.
        CropForegroundd(keys=image_keys, source_key='image'),
        ToTensord(keys=image_keys),
    ]
)

In [None]:
# cache_rate=0.0: nothing is pre-cached, so val_transforms presumably run again
# on every test_ds[...] access.
test_ds = CacheDataset(
    data=test_files,
    transform=val_transforms,
    cache_rate=0.0,
)
# Single-case batches, loaded in the main process.
test_loader = DataLoader(
    test_ds,
    batch_size=1,
    num_workers=0,
)

In [None]:
VAL_AMP = True
# Prefer the first GPU, but fall back to CPU so the notebook still runs on hosts without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = SegResNet(
    blocks_down=[1, 2, 2, 4],
    init_filters=16,
    blocks_up=[1, 1, 1],
    in_channels=2,   # CT and PT volumes stacked as two input channels
    out_channels=3,  # matches the 3-class one-hot post-processing below
    dropout_prob=0.2,
).to(device)


def inference(input):
    """Run sliding-window inference with the global ``model``.

    Args:
        input: batched network input on ``device``; presumably shaped
            (batch, 2, H, W, D) - TODO confirm against callers.

    Returns:
        The stitched output of ``sliding_window_inference``, evaluated on
        ``input_size`` windows with 50% overlap, one window at a time.

    Mixed precision is used only when ``VAL_AMP`` is set *and* ``device`` is a
    CUDA device; otherwise the autocast context is disabled.
    """
    # NOTE: the parameter name shadows the builtin ``input``; kept for caller compatibility.
    use_amp = VAL_AMP and device.type == "cuda"
    with torch.autocast(device_type=device.type, enabled=use_amp):
        return sliding_window_inference(
            inputs=input,
            roi_size=input_size,
            sw_batch_size=1,
            predictor=model,
            overlap=0.5,
        )


# 3-class one-hot transforms: labels as-is, predictions via argmax over channels first.
post_label = AsDiscrete(to_onehot=3)
post_pred = AsDiscrete(argmax=True, to_onehot=3)

In [None]:
slice_no = 90
# Number of cases to visualize; never more than the dataset holds.
images = min(10, len(test_ds))

# map_location lets a checkpoint saved on another device load onto `device`.
model.load_state_dict(
    torch.load(os.path.join(root_dir, "best_metric_model.pth"), map_location=device)
)
model.eval()

for image_no in range(images):
    # Fetch each case once: with cache_rate=0.0 (no caching) every test_ds[...]
    # access re-runs the whole transform chain.
    sample = test_ds[image_no]
    with torch.no_grad():
        # Add a batch axis to CT and PT, then stack them along the channel axis.
        val_input = torch.cat(
            [sample["image"].unsqueeze(0), sample["image2"].unsqueeze(0)], dim=1
        ).to(device)
        # Drop the batch axis and convert logits to a 3-channel one-hot prediction.
        val_output = post_pred(inference(val_input)[0])

    label = sample["label"]
    # CropForegroundd gives each case its own extent, so keep the slice in range.
    slice_idx = min(slice_no, label.shape[-1] - 1)

    # Ground-truth label for THIS case (previously always case 2 was shown).
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.set_title(f"label channel 0 (case {image_no}, slice {slice_idx})")
    ax.imshow(label[0, :, :, slice_idx].detach().cpu())
    plt.show()

    # The 3 one-hot channels of the model output for the same case and slice.
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    for i, ax in enumerate(axes):
        ax.set_title(f"output channel {i}")
        ax.imshow(val_output[i, :, :, slice_idx].detach().cpu())
    plt.show()
    