In [None]:
# Change the kernel's working directory to the pipeline folder. The cells below
# import `inference_utils` / `dataset.dataset` and use relative paths
# ("lightning_logs/...", "predictions/..."), presumably resolved from here.
# NOTE(review): absolute, user-specific path — adjust when running elsewhere.
%cd /gpfs/space/home/danylobo/bm-ai-pipelines/common/ocs/lightning_pipeline  

In [None]:
import cv2
import torch
import numpy as np 
from tqdm import tqdm
import torchmetrics
from inference_utils import predict_frame, get_model, draw_contours_from_mask, imshow
from pathlib import Path
from dataset.dataset import OneclickDataset

import matplotlib.pyplot as plt
# Set the default figure resolution to 150 DPI for all plots in this notebook.
plt.rcParams['figure.dpi'] = 150
# NOTE(review): no figure has been created in the visible cells before this
# call, so it appears to have nothing to display — confirm it can be dropped.
plt.show()

# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up before each cell execution without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Value passed as `ndim` to both OneclickDataset and predict_frame further down.
ndim = 3

# Lightning run (version_<N>) whose checkpoint is evaluated.
version = 2
ckpt_dir = Path(f"lightning_logs/3D-segm/version_{version}/checkpoints")

# Path.iterdir() yields entries in arbitrary order, so sort before picking one:
# otherwise a different checkpoint may be selected from run to run when the
# directory holds more than one entry. Fail with a clear message if it is empty
# instead of an opaque IndexError.
ckpt_candidates = sorted(ckpt_dir.iterdir())
if not ckpt_candidates:
    raise FileNotFoundError(f"No checkpoints found in {ckpt_dir.resolve()}")
ckpt_path = ckpt_candidates[0]

# Device the input tensors are moved to before inference.
device = torch.device('cuda:0')

In [None]:
# Image directories of the available evaluation sets. Keeping both in a lookup
# (instead of assigning `img_dir` twice and letting the second assignment win)
# makes the active choice explicit.
IMG_DIRS = {
    "val": Path("/gpfs/space/projects/BetterMedicine/danylo/lung/training/raw/3D-LIDC_thin_CropPad_consensus1_centered/Ts/images"),
    "test": Path("/gpfs/space/projects/BetterMedicine/danylo/lung/training/raw/3D_201_redbrick/images"),
}

# Which set to evaluate: "val" or "test".
eval_split = "test"
img_dir = IMG_DIRS[eval_split]

# Output directory, created up front (parents included; no error if it exists).
# NOTE(review): nothing in the visible cells writes to it yet.
dst_dir = Path("predictions/test_set_201_119")
dst_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Build the model from the selected checkpoint via the project helper.
# NOTE(review): this cell neither moves the model to `device` nor calls
# .eval(); presumably get_model handles both — confirm in inference_utils.
model = get_model(ckpt_path)

In [None]:
# Recursively collect every .npy file under img_dir as plain string paths.
# The pattern is "*.npy" (with the dot) so that names merely ending in "npy"
# are not picked up, and the result is sorted because rglob() order is not
# guaranteed, which keeps the evaluation order reproducible.
val_files = [str(filepath) for filepath in sorted(img_dir.rglob("*.npy"))]
if not val_files:
    raise FileNotFoundError(f"No .npy files found under {img_dir}")

# Extra options handed to OneclickDataset as `add_input`.
# NOTE(review): key semantics are defined by OneclickDataset; both probabilities
# are zero here, which presumably disables random pairing and jitter — confirm.
add_input = {'point_prompt': 2,
             'prob_random_pair': 0.,
             'prob_jitter': 0.}

In [None]:
# Dataset over the collected files, in "test" mode.
dataset = OneclickDataset(files=val_files,
                          ndim=ndim,
                          mode="test",
                          add_input=add_input)

# Metric state accumulates across update() calls; compute() below reports the
# value over everything that was fed in.
val_dice = torchmetrics.Dice()

for img_tensor, mask_tensor in tqdm(dataset):
    # Add a leading batch dimension and move the input to the inference device.
    img_tensor = img_tensor.unsqueeze(0).to(device=device)

    # Evaluation only: disable autograd so no graph is kept for the forward pass.
    with torch.no_grad():
        pred = predict_frame(model=model, img_tensor=img_tensor, ndim=ndim)

    # `img` is only needed by the optional visualisation further down.
    img = img_tensor.squeeze().detach().cpu().numpy()
    gt = mask_tensor.squeeze().detach().cpu().numpy()

    # Iterate over the first axis of image / prediction / ground truth together.
    # assumes that axis indexes 2D slices of the volume — TODO confirm
    for img_slice, pred_slice, gt_slice in zip(img, pred, gt):
        # Binarise the ground truth once with the same 0.5 threshold that the
        # empty-slice check uses, so the check and the metric target agree
        # (a plain int cast would truncate fractional mask values to 0).
        gt_binary = gt_slice > 0.5

        # Skip slices where both prediction and ground truth are empty.
        if (np.sum(pred_slice > 0.5) + np.sum(gt_binary)) == 0:
            continue

        # as_tensor accepts both arrays and tensors without an extra copy.
        val_dice.update(torch.as_tensor(pred_slice),
                        torch.as_tensor(gt_binary).to(torch.int8))

        # # Uncomment to visualize
        # # NOTE(review): 0.226 / 0.449 presumably undo the input normalisation — confirm.
        # img_slice = 255 * cv2.cvtColor(0.226 * img_slice + 0.449, cv2.COLOR_GRAY2RGB)
        # img_slice = np.clip(img_slice, 0, 255).astype(np.uint8)

        # img_pred = draw_contours_from_mask(img_slice, pred_slice, [255, 20, 147])
        # img_gt   = draw_contours_from_mask(img_slice, gt_slice, [0, 255, 102])

        # # separators
        # img_pred[:, 0] = 255
        # img_gt[:, 0] = 255

        # img_stacked = np.hstack([img_slice, img_pred, img_gt])
        # imshow(img_stacked)

# NOTE(review): despite the name, this is the metric aggregated over all
# update() calls; whether that equals a per-slice mean depends on
# torchmetrics.Dice's averaging defaults — confirm.
val_dice_mean = val_dice.compute()
val_dice.reset()
print("Dice =", val_dice_mean)