In [1]:
import numpy as np
from torch.utils.data import Dataset
from PIL import Image
import glob
import random
import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import torchvision.transforms as T
import segmentation_models_pytorch as smp
from tqdm import tqdm
import cv2
from skimage import feature
import json

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def _image_id(path):
    """Return the first three characters of the file name in ``path`` (used as the image id)."""
    return path.split('/')[-1][:3]


# Images in the YOLO valid/test splits are held out; every other OCELOT
# training cell image (by 3-character id) is used for prediction below.
val = glob.glob('/workspace/jay/DDP/Ocelot/yolo_binary/datasets/cell_detect_33-1/valid/images/*.jpg')
test = glob.glob('/workspace/jay/DDP/Ocelot/yolo_binary/datasets/cell_detect_33-1/test/images/*.jpg')
train = sorted(glob.glob('/workspace/jay/DDP/Ocelot/ocelot2023/images/train/cell/*.jpg'))

# np.unique both de-duplicates and sorts the held-out ids.
val_files = np.unique(np.array(list(map(_image_id, val))))
test_files = np.unique(np.array([_image_id(p) for p in test]))
temp = [_image_id(p) for p in train]

val_set = [*val_files, *test_files]
# Keeps the sorted order of `train`.
train_files = [stem for stem in temp if stem not in val_set]

In [3]:
device = 'cuda:2'
# NOTE(review): torch.load unpickles a complete nn.Module, which can execute
# arbitrary code -- only load trusted checkpoints. On newer PyTorch releases
# `weights_only` defaults to True and may reject a pickled full model; pass
# weights_only=False there if loading fails.
model = torch.load('/workspace/jay/DDP/Ocelot/cell_seg/deeplab_dice_ckpts/142_0.29378.pt',map_location=device)
model = model.to(device)
# Switch BatchNorm/Dropout to inference behaviour; the checkpoint may have been
# saved while the model was still in training mode.
model.eval()
# Turns per-pixel logits into per-class probabilities along the channel axis.
softmax = torch.nn.Softmax(dim=1)

In [4]:
# Skeleton of the OCELOT "Multiple points" prediction file; the inference
# loop below fills in "points".
pred_json = dict(
    type="Multiple points",
    num_images=len(train_files),
    points=[],
    version=dict(major=1, minor=0),
)

In [5]:
def find_cells(heatmap, min_dist=10, blur_ksize=(5, 5), blur_sigma=3):
    """Detect cells in the model's softmax heatmap.

    Cell centres are local maxima of the Gaussian-blurred foreground map
    ``1 - background``. Each surviving peak is labelled with its most likely
    cell class.

    Parameters
    ----------
    heatmap : torch.Tensor
        Softmax output of the model, shape [1, 3, H, W] (e.g. [1, 3, 1024, 1024]).
        Channels 0 and 1 are the two cell classes; channel 2 is background.
    min_dist : int, optional
        Minimum pixel distance between detected peaks, passed to
        ``skimage.feature.peak_local_max`` as ``min_distance``. Default 10.
    blur_ksize : tuple of int, optional
        Gaussian kernel size applied to the foreground map before peak
        finding. Default (5, 5).
    blur_sigma : float, optional
        Gaussian ``sigmaX`` for that blur. Default 3.

    Returns
    -------
    list of tuple
        One ``(x, y, cls, score)`` per detected cell. ``cls`` is the
        channel index + 1 (1 or 2), and ``score`` is that class's probability
        at the peak. The list is empty when no cell is detected.

    Raises
    ------
    ValueError
        If ``heatmap`` does not have exactly 3 channels.
    """
    arr = heatmap[0].detach().cpu().numpy()
    if arr.shape[0] != 3:
        raise ValueError(
            f"expected 3 channels (2 cell classes + background), got {arr.shape[0]}"
        )

    cls_probs, bg = arr[:2], arr[2]
    obj = 1.0 - bg  # foreground probability

    smoothed = cv2.GaussianBlur(obj, blur_ksize, sigmaX=blur_sigma)
    peaks = feature.peak_local_max(
        smoothed, min_distance=min_dist, exclude_border=0, threshold_abs=0.0
    )  # rows of [y, x]
    # Normalise to an (N, 2) integer array so an empty result still indexes cleanly.
    peaks = np.asarray(peaks, dtype=int).reshape(-1, 2)

    maxval = cls_probs.max(axis=0)
    maxcls = cls_probs.argmax(axis=0)

    # Drop peaks where background is at least as likely as the best cell class.
    rows, cols = peaks[:, 0], peaks[:, 1]
    keep = bg[rows, cols] < maxval[rows, cols]
    rows, cols = rows[keep], cols[keep]
    if rows.size == 0:
        return []

    scores = maxval[rows, cols]
    classes = maxcls[rows, cols]
    return [
        (int(x), int(y), int(c) + 1, float(s))
        for y, x, c, s in zip(rows, cols, classes, scores)
    ]

In [6]:
# Run the segmentation model on every training cell image and collect one
# point record per detected cell. Records are gathered in a fresh list and
# assigned at the end, so re-running this cell replaces (rather than
# duplicates) the predictions stored in pred_json.
points = []
for file in tqdm(train_files):
    # The 3-character file stem parsed as an integer, shifted to start at 0.
    idx = int(file) - 1
    cell_path = f'/workspace/jay/DDP/Ocelot/ocelot2023/images/train/cell/{file}.jpg'
    with Image.open(cell_path) as img:
        cell = np.array(img)
    # Scale to [0, 1] then shift to [-0.5, 0.5]
    # (presumably the training-time normalisation -- TODO confirm).
    cell = cell / 255
    cell = cell - 0.5
    # HWC -> CHW, add a batch dimension, move to the model's device.
    cell = torch.Tensor(np.moveaxis(cell, -1, 0))[None, :].to(device)
    with torch.no_grad():
        out_mask = softmax(model(cell))
    predicted_cells = find_cells(out_mask, min_dist=10)
    for x, y, clas, prob in predicted_cells:
        points.append({
            "name": f"image_{idx}",
            "point": [int(x), int(y), int(clas)],
            # Model score for the predicted class at this peak.
            "probability": prob,
        })
pred_json["points"] = points

                                

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 320/320 [44:44<00:00,  8.39s/it]


In [9]:
# Write the predictions where the OCELOT evaluation code picks them up.
out_path = "/workspace/jay/DDP/Ocelot/ocelot23algo/evaluation/cellonlyseg_train.json"
with open(out_path, "w") as fh:
    json.dump(pred_json, fh)
print("JSON file saved")

JSON file saved
