# Libraries

In [None]:
import sys
sys.path.append('utility_box/')
from cpath import WSI, CPDataset
from load import save_geojson, load_pickle
from shapely_utils import Polygon, MultiPolygon
from shapely_utils import loads, make_valid, mapping
from torch_gpu_utils import get_device, get_gpu_memory_info
from ocv import process_contour_hierarchy,get_parent_daughter_idx_map
from image_utils import plot_image, extract_patches_with_coordinates, scale_mpp

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

import cv2
import geojson
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm.auto import tqdm
from tifffile import imread

import segmentation_models_pytorch as smp

In [None]:
def wkt_to_geojson(wkt_string):
    """Convert a WKT geometry string into a GeoJSON feature.

    The WKT is parsed with ``loads``, passed through ``make_valid``, and the
    result is wrapped as the geometry of a ``geojson.Feature``.
    """
    valid_geom = make_valid(loads(wkt_string))
    return geojson.Feature(geometry=mapping(valid_geom))

In [None]:
def inference_logic(model, device, patch_size, patches, batch_size):
    """Run the segmentation model over every patch and collect the raw outputs.

    Args:
        model: Network called as ``model(batch)``; dim 1 of its output is
            squeezed and one ``patch_size x patch_size`` map is stored per patch.
        device: Device the batches are moved to and the result is allocated on.
        patch_size: Side length of each (square) output map.
        patches: Stack of patches indexed ``(N, H, W, C)``; it is permuted to
            ``(N, C, H, W)`` before being fed to the model.
        batch_size: Number of patches per forward pass.

    Returns:
        Tensor of shape ``(len(patches), patch_size, patch_size)`` on
        ``device`` with the model output for each patch, in input order.
    """
    # Go through a single ndarray: if `patches` is a list of ndarrays,
    # ``torch.tensor`` converts it element by element, which is very slow.
    if not torch.is_tensor(patches):
        patches = torch.as_tensor(np.asarray(patches))
    # The source dtype is kept here; the float cast happens per batch below so
    # the full patch stack is never duplicated as float32 in host memory.
    patches = patches.permute(0, 3, 1, 2)
    dataloader = DataLoader(patches, batch_size=batch_size, shuffle=False, num_workers=4)

    preds = torch.empty((len(patches), patch_size, patch_size), device=device)
    model.eval()

    with torch.inference_mode():
        index = 0
        for batch in tqdm(dataloader):
            batch = batch.to(device).float()
            # Inputs are divided by 255 before the forward pass.
            pred = model(batch / 255)

            # Advance by the real batch length: the last batch may be short.
            n_in_batch = pred.shape[0]
            preds[index:index + n_in_batch] = pred.squeeze(1)
            index += n_in_batch

    torch.cuda.empty_cache()
    return preds

In [None]:
from shapely.ops import unary_union
def compute_iou(multipolygon1, multipolygon2):
    """Return the intersection-over-union of two geometries, measured by area.

    The union is built with ``unary_union``. When the union has zero area the
    result is 0 rather than a division by zero.
    """
    overlap_area = multipolygon1.intersection(multipolygon2).area
    combined_area = unary_union([multipolygon1, multipolygon2]).area

    if combined_area == 0:
        return 0
    return overlap_area / combined_area

def compute_dice(multipolygon1, multipolygon2):
    """Return the Dice coefficient of two geometries, measured by area.

    Dice is ``2 * area(A & B) / (area(A) + area(B))``. When the two areas sum
    to zero the result is 0 rather than a division by zero.
    """
    overlap_area = multipolygon1.intersection(multipolygon2).area
    total_area = multipolygon1.area + multipolygon2.area

    if total_area == 0:
        return 0
    return (2 * overlap_area) / total_area

In [None]:
# Checkpoint file holding the trained model weights; it is read with torch.load
# and applied with load_state_dict in the model-setup cell.
state_dict_path='/workspace/code/NodeSeg/model_logs/smp_unet_500epochs_run2/model_check_points/max_val/checkpoint_epoch305_0.9842888563871384.pth'

In [None]:
# Device the model and all inference batches are moved to. get_device is a
# project helper from torch_gpu_utils (presumably 0 selects the first GPU — TODO confirm).
device = get_device(0)

In [None]:
# Build the U-Net (ResNet-34 encoder, RGB input, one output channel) and
# restore the trained weights from `state_dict_path`.
n_classes=1
model = smp.Unet(
    encoder_name="resnet34",        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    # No pretrained download: load_state_dict below runs in its default strict
    # mode, so every encoder weight is replaced by the checkpoint anyway.
    encoder_weights=None,
    in_channels=3,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=n_classes,              # model output channels (number of classes in your dataset)
)
model.to(device);
model=model.to(memory_format=torch.channels_last)

# Deserialize onto the CPU so loading does not depend on the device the
# checkpoint was saved from; load_state_dict copies the values into the
# parameters already placed on `device`.
state_dict=torch.load(state_dict_path, map_location="cpu", weights_only=True)

model.load_state_dict(state_dict)

In [None]:
# Inference configuration.
batch_size=4      # patches per forward pass
target_mpp=1      # second argument to scale_mpp (target microns-per-pixel, going by the name)
patch_size=512    # side length of the square patches fed to the model
overlap=128       # passed to extract_patches_with_coordinates for both axes

# First argument to scale_mpp — presumably the slide's native resolution.
# NOTE(review): hard-coded; confirm it matches the slide metadata.
source_mpp=0.25

# `scale` multiplies the slide shape to get the working image; `rescale`
# multiplies contour coordinates to map them back to full resolution.
scale, rescale = scale_mpp(source_mpp, target_mpp)

In [None]:
# Slide to evaluate and its matching tumor-geometry file.
# NOTE(review): load_pickle presumably unpickles the file — only point it at trusted data.
wsi_path=Path('/workspace/data/PublicDatasets/CAMELYON17/images/patient_012_node_0.tif')
true_geom_dicts=load_pickle("/workspace/data/PublicDatasets/CAMELYON17/tumor_geoms/patient_012_node_0.pkl")
# Alternative sample: swap the comments with the two lines above to evaluate it instead.
#wsi_path=Path('/workspace/data/PublicDatasets/CAMELYON17/images/patient_009_node_1.tif')
#true_geom_dicts=load_pickle("/workspace/data/PublicDatasets/CAMELYON17/tumor_geoms/patient_009_node_1.pkl")

In [None]:
# Read the slide into memory as an array.
wsi = imread(wsi_path)

# Working resolution: the (rows, cols) shape multiplied by `scale`, truncated to
# ints. cv2.resize takes the target size as (width, height), hence the reversal.
scaled_shape = (np.array(wsi.shape[:2]) * scale).astype(int)
scaled_wsi = cv2.resize(wsi, tuple(scaled_shape[::-1]))

In [None]:
# Tile the downscaled slide into patch_size x patch_size patches. `coordinates`
# holds one position per patch and is used later as array indices when the
# per-patch masks are pasted into the full mask.
# presumably (overlap, overlap) is the per-axis overlap between neighbouring patches — TODO confirm
patches, coordinates = extract_patches_with_coordinates(scaled_wsi,(patch_size, patch_size), (overlap,overlap))

In [None]:
# Raw model outputs, one patch_size x patch_size map per patch, kept on `device`.
# The sigmoid and the probability threshold are applied in a later cell.
preds=inference_logic(model, device, patch_size, patches, batch_size)

In [None]:
# Ground truth: take the 'geom' entry of every annotation record, then combine
# them into one MultiPolygon that is passed through make_valid.
true_geoms = [record['geom'] for record in true_geom_dicts]
true_mgeoms = make_valid(MultiPolygon(true_geoms))

In [None]:
# Binarize every patch prediction: sigmoid -> probability, then keep pixels
# above prob_thresh as 1 (uint8) on the CPU.
prob_thresh=0.8
thresh_masks=[]
for pred_mask in tqdm(preds):
    thresh_mask = (F.sigmoid(pred_mask)> prob_thresh).to('cpu').numpy().astype(np.uint8)
    thresh_masks.append(thresh_mask)

# Stitch the patch masks into one mask the size of the downscaled slide.
full_mask=np.zeros((scaled_wsi.shape[:2]), dtype=np.uint8)
# Margin trimmed from each side of a patch before pasting (loop-invariant).
# NOTE(review): the threshold sweep at the end of the notebook uses overlap//2
# and a 5-px median blur, so it does not evaluate exactly this pipeline.
delta=overlap//4
# Explicit end index: the negative-stop form [delta:-delta] is an empty slice
# when delta == 0, which would break stitching for overlap < 4.
inner=slice(delta, patch_size-delta)
for coord, thresh_mask in zip(tqdm(coordinates), thresh_masks):
    x,y=coord
    full_mask[x+delta:x+patch_size-delta, y+delta:y+patch_size-delta]=thresh_mask[inner, inner]

In [None]:
# Post-process the stitched mask with a median filter of aperture 3.
# The commented-out erosion below is an alternative that is currently disabled.
#full_mask=cv2.erode(full_mask, (5,5), iterations=1)
full_mask=cv2.medianBlur(full_mask, 3)

In [None]:
# Trace the boundaries of the binary mask. RETR_TREE retrieves all contours with
# their full nesting hierarchy; CHAIN_APPROX_SIMPLE compresses straight segments.
contours, hierarchy = cv2.findContours(full_mask, cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
# Project helper from ocv; its result is iterated in the next cell as
# (contour index, hole indices) pairs (presumably outer contours mapped to
# their child contours — TODO confirm).
idx_map=get_parent_daughter_idx_map(contours, hierarchy )

In [None]:
# Turn every contour listed in idx_map into a polygon: contour points are
# multiplied by `rescale`, the singleton axis is dropped, and the polygon is
# padded with buffer(1). Contours with fewer than 4 points are skipped.
# NOTE(review): the hole indices paired with each contour are never used, so
# any holes end up filled in the resulting polygons — confirm this is intended.
geoms = [
    Polygon((contours[outer_idx] * rescale).squeeze(1)).buffer(1)
    for outer_idx, _hole_idxs in tqdm(idx_map.items())
    if contours[outer_idx].shape[0] >= 4
]

In [None]:
# Preview the mask at 1/100 of the original slide's (rows, cols) shape.
# cv2.resize takes the target size as (width, height), hence the reversal.
preview_shape = (np.array(wsi.shape[:2]) / 100).astype(int)
plot_image(cv2.resize(full_mask, tuple(preview_shape[::-1])))

In [None]:
# IoU between the ground truth and the predicted polygons. The predictions are
# assembled into one MultiPolygon and passed through buffer(0) first
# (a common way to clean up overlapping or invalid parts).
compute_iou(true_mgeoms, MultiPolygon(geoms).buffer(0))

In [None]:
# Wrap each predicted polygon as a GeoJSON feature.
geojson_features = [
    geojson.Feature(geometry=mapping(polygon)) for polygon in tqdm(geoms)
]

# Save the collection next to the notebook, named after the slide file.
geojson_feature_collection = geojson.FeatureCollection(geojson_features)
save_geojson(f"{wsi_path.stem}.geojson", geojson_feature_collection)

# Fine-Tune Threshold

In [None]:
# Sweep the probability threshold and record the Dice score against the ground
# truth for each value.
# NOTE(review): this sweep trims overlap//2 per patch side and uses a 5-px
# median blur, while the main pipeline above uses overlap//4 and 3 px, so the
# scores are not measured on exactly the pipeline that produces the export.
# NOTE(review): the loop overwrites prob_thresh, full_mask, contours and geoms;
# re-run the main pipeline cells before exporting again.
fine_tune_thresh=[]

# Margin trimmed from each side of a patch before pasting (loop-invariant).
delta=overlap//2
# Explicit end index: the negative-stop form [delta:-delta] is an empty slice
# when delta == 0, which would break stitching for overlap < 2.
inner=slice(delta, patch_size-delta)

# np.arange with a float step accumulates rounding error (for example
# 0.7500000000000002); round so the stored and printed thresholds are exact.
for prob_thresh in np.round(np.arange(0.5,1,0.05), 2):
    prob_thresh=float(prob_thresh)

    # Binarize every patch prediction at the current threshold.
    thresh_masks=[]
    for pred_mask in tqdm(preds):
        thresh_mask = (F.sigmoid(pred_mask)> prob_thresh).to('cpu').numpy().astype(np.uint8)
        thresh_masks.append(thresh_mask)

    # Stitch the patch masks into one mask the size of the downscaled slide.
    full_mask=np.zeros((scaled_wsi.shape[:2]), dtype=np.uint8)
    for coord, thresh_mask in zip(tqdm(coordinates), thresh_masks):
        x,y=coord
        full_mask[x+delta:x+patch_size-delta, y+delta:y+patch_size-delta]=thresh_mask[inner, inner]
    full_mask=cv2.medianBlur(full_mask, 5)

    contours, hierarchy = cv2.findContours(full_mask, cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
    idx_map=get_parent_daughter_idx_map(contours, hierarchy )

    # Contours -> polygons in full-resolution coordinates (points multiplied by
    # `rescale`), padded with buffer(1); contours with < 4 points are skipped.
    # The hole indices paired with each contour are not used here.
    geoms=[]
    for contour_idx, holes_idx in tqdm(idx_map.items()):
        contour = contours[contour_idx]
        if contour.shape[0]<4:
            continue
        contour = (contour*rescale).squeeze(1)
        geoms.append(Polygon(contour).buffer(1))

    temp_dict={
        'prob_thresh': prob_thresh,
        'dice': compute_dice(true_mgeoms, MultiPolygon(geoms).buffer(0)),
    }

    fine_tune_thresh.append(temp_dict)
    print(f"{prob_thresh}-->{temp_dict['dice']}")