In [9]:
import torch
import numpy as np
from pathlib import Path
import os
import fastremap
import xml.etree.ElementTree as ET
from tqdm import tqdm

import matplotlib.pyplot as plt

from PIL import Image, ImageDraw

In [12]:
from instanseg.utils.data_download import create_raw_datasets_dir, create_processed_datasets_dir, download_and_extract


from instanseg.utils.utils import show_images, _move_channel_axis
#aws s3 cp --no-sign-request s3://monkey-training/ ./ --recursive
from tiffslide import TiffSlide  # imported once instead of on every loop iteration

# Build `annotations_dict`: for every annotation XML file, one entry per polygon
# ROI holding the PAS-CPG and IHC pixels under the ROI bounding box, the rasterised
# polygon mask, and the dot annotations that fall inside that bounding box.
monkey_dir = Path("../data")

# Sorted so that the seeded train/val draw below is assigned to the same file
# regardless of the order in which the filesystem lists the directory.
files = sorted(os.listdir(os.path.join(monkey_dir ,"annotations","xml")))

label_ids = []
means_list = []
annotations_dict = {}


np.random.seed(0)

for file in tqdm(files):

    # One draw per file: every ROI of a file ends up in the same split.
    split = np.random.choice(["train", "val"], p=[0.8, 0.2])

    stem = file.split(".")[0]
    img_pascpg_path = Path(monkey_dir) / ("images/pas-cpg/" + stem + "_PAS_CPG.tif")
   # img_pasoriginal_path = Path(monkey_dir) / ("images/pas-original/" + file.split(".")[0] + "_PAS_Original.tif")
    ihc_path = Path(monkey_dir) / ("images/ihc/" + stem + "_IHC_CPG.tif")

    tree = ET.parse(monkey_dir/("annotations/xml/"+file))
    root = tree.getroot()  # Get the root of the XML
    # Collected once; the list is walked twice (polygons first, then dots).
    all_annotations = root.findall('.//Annotation')

    # if split == "val":
    #     destination_img = "/home/cdt/Documents/Projects/monkey-challenge-instanseg/evaluation/validation_set/images/kidney-transplant-biopsy-wsi-pas/"
    #     destination_mask = "/home/cdt/Documents/Projects/monkey-challenge-instanseg/evaluation/validation_set/images/tissue-mask/"
        
    #     #move images to inference folder
    #     import shutil
    #     shutil.copy(monkey_dir / ("images/pas-cpg/" + file.split(".")[0] + "_PAS_CPG.tif"), destination_img)
    #     shutil.copy(monkey_dir / ("images/tissue-masks/" + file.split(".")[0] + "_mask.tif"), destination_mask)
        
    #     shutil.copy(monkey_dir / ("annotations/json/" + file.split(".")[0] + "_inflammatory-cells.json"), 
    #     '/home/cdt/Documents/Projects/monkey-challenge-instanseg/evaluation/ground_truth')

    #     shutil.copy(monkey_dir / ("annotations/json/" + file.split(".")[0] + "_lymphocytes.json"), 
    #     '/home/cdt/Documents/Projects/monkey-challenge-instanseg/evaluation/ground_truth')

    #     shutil.copy(monkey_dir / ("annotations/json/" + file.split(".")[0] + "_monocytes.json"), 
    #     '/home/cdt/Documents/Projects/monkey-challenge-instanseg/evaluation/ground_truth')

    annotations_dict[file] = []

    # Pass 1: polygon ROIs. The slides are only needed here, so both handles are
    # released as soon as the pixels have been copied into memory.
    with TiffSlide(img_pascpg_path) as slidepascpg, TiffSlide(ihc_path) as slideihc:
        for annotation in all_annotations:
            if annotation.get('Type') != "Polygon":
                continue

            coords_ROI = np.array([
                [float(coordinate.get('X')), float(coordinate.get('Y'))]
                for coordinate in annotation.findall('.//Coordinate')
            ])
            if coords_ROI.size == 0:
                # A polygon without vertices has no bounding box to read.
                continue

            x_min, y_min = coords_ROI.min(axis=0)
            x_max, y_max = coords_ROI.max(axis=0)
            bbox_width = int(x_max - x_min)
            bbox_height = int(y_max - y_min)
            if bbox_width <= 0 or bbox_height <= 0:
                # Degenerate ROI: nothing to read and nothing to rasterise.
                continue

            # Read the bounding box from both slides at level 0.
            region_origin = (int(x_min), int(y_min))
            region_size = (bbox_width, bbox_height)
            rgb_data = slidepascpg.read_region(region_origin, 0, region_size, as_array=True)
            ihc_data = slideihc.read_region(region_origin, 0, region_size, as_array=True)

            # Rasterise the polygon in bounding-box-local coordinates.
            mask = Image.new("L", region_size, 0)
            polygon = coords_ROI - [x_min, y_min]
            ImageDraw.Draw(mask).polygon(polygon.flatten().tolist(), outline=1, fill=1)
            binary_mask = np.array(mask)

            annotations_dict[file].append({ "split": split,
                                            "pas-cpg":rgb_data,
                                            "ihc":ihc_data,
                                            "polygon": coords_ROI, 
                                            "mask": binary_mask, 
                                            "bbox" : [x_min, y_min, x_max, y_max], 
                                            "dots" : []})

            #show_images(rgb_data)

    # Pass 2: dots. Each dot is attached to every ROI whose bounding box strictly
    # contains it, stored as [row, col, class] relative to the bbox origin.
    for annotation in all_annotations:
        if annotation.get('Type') != "Dot":
            continue

        coordinate = annotation.find('.//Coordinate')
        if coordinate is None:
            # Dot without a position: cannot be assigned to any ROI.
            continue
        x = int(float(coordinate.get('X')))
        y = int(float(coordinate.get('Y')))
        # Class 0 for the "lymphocytes" group, class 1 for every other group.
        c = 0 if annotation.get('PartOfGroup') == "lymphocytes" else 1

        # `roi` (not `annotation`) so the outer loop variable is not rebound.
        for roi in annotations_dict[file]:
            roi_x_min, roi_y_min, roi_x_max, roi_y_max = roi["bbox"]
            if roi_x_min < x < roi_x_max and roi_y_min < y < roi_y_max:
                roi["dots"].append([y - roi_y_min, x - roi_x_min, c])

    
                        
                  


100%|██████████| 81/81 [00:41<00:00,  1.95it/s]


In [21]:
leukocytes_dots = 0
detected_leukocytes = 0

import os
import pdb
import h5py
from instanseg.utils.pytorch_utils import get_masked_patches, _to_tensor_float32
from instanseg.inference_class import _rescale_to_pixel_size
import torchstain
from instanseg import InstanSeg

# NOTE(review): machine-specific absolute path; adjust to the local checkout.
os.environ["INSTANSEG_BIOIMAGEIO_PATH"] = '/home/cdt/Documents/Projects/InstanSeg/instanseg_thibaut/instanseg/bioimageio_models/'
os.environ['INSTANSEG_DATASET_PATH'] = "../datasets/"

instanseg_script = torch.jit.load("../models/instanseg_brightfield_monkey.pt") #download for github release
brightfield_nuclei = InstanSeg(instanseg_script, verbosity = 0)

patch_size = 128
source_pixel_size = 0.2420  # pixel size passed to the segmenter and used as rescaling origin
destination_pixel_size = 0.5 # 2420
rescale_output = destination_pixel_size != 0.5
device = "cpu"

# image_types  = ["cpg", "ihc"]
image_types  = ["ihc"]


def _class_patches(class_labels, img_tensor, class_id, patch_size, device):
    """Cut one masked patch per labelled instance and pair it with `class_id`.

    Returns `(x, y)`: `x` is the uint8 concatenation of image crops and instance
    masks along dim 1, `y` is a long vector filled with `class_id`. When
    `class_labels` holds a single unique value (no instance), both are empty.
    """
    if len(torch.unique(class_labels)) > 1:
        crops, masks = get_masked_patches(class_labels, img_tensor, patch_size=patch_size)
        x = torch.cat((crops.to(torch.uint8), masks.to(torch.uint8)), dim=1)
    else:
        # uint8 like the non-empty case, so the later torch.cat never mixes dtypes.
        x = torch.zeros(0, 4, patch_size, patch_size, dtype=torch.uint8).to(device)
    y = torch.full((len(x),), class_id, dtype=torch.long)
    return x, y


def _append_rows(dataset, rows):
    """Grow a resizable HDF5 dataset along axis 0 and write `rows` at its end.

    Writing zero rows is a no-op: the `ds[-n:]` idiom degenerates to `ds[0:]`
    (the whole dataset) when `n == 0`, so an explicit start offset is used.
    """
    n_new = rows.shape[0]
    if n_new == 0:
        return
    start = dataset.shape[0]
    dataset.resize((start + n_new,) + tuple(rows.shape[1:]))
    dataset[start:start + n_new, ...] = rows


for image_type in image_types:

    image_key = "pas-cpg" if image_type == "cpg" else "ihc"

    # Reset per image type so the summary printed below describes this file only.
    leukocytes_dots = 0
    detected_leukocytes = 0

    np.random.seed(0)
    with h5py.File(Path(os.environ['INSTANSEG_DATASET_PATH']) / f"monkey_{image_type}_gold.h5", "w") as f:

        f.attrs['class_names'] = str({"0": "lymphocytes", "1": "monocytes", "2" : "other"})  # Convert to string since HDF5 attributes must be simple types
        f.attrs['pixel_size'] = destination_pixel_size

        for split in ['train', 'val']:
            f.create_dataset(f"{split}/data", shape=(0, 4, patch_size, patch_size),
                             dtype=np.uint8, maxshape=(None, 4, patch_size, patch_size),
                             chunks=(1, 4, patch_size, patch_size),)
            f.create_dataset(f"{split}/labels", shape=(0, 1), dtype=np.uint8, maxshape=(None, 1))

        for file, rois in tqdm(annotations_dict.items()):
            if not rois:
                # No polygon ROI was recorded for this file: nothing to export.
                continue
            split = rois[0]["split"]
            data_ds = f[f"{split}/data"]
            labels_ds = f[f"{split}/labels"]

            for annotation in rois:

                # Nuclei are always segmented on the PAS-CPG image, whatever
                # image type the exported patches are cut from.
                array = _to_tensor_float32(annotation["pas-cpg"])
                labels, input_tensor = brightfield_nuclei.eval_medium_image(
                    array, pixel_size=source_pixel_size, rescale_output=rescale_output,
                    seed_threshold=0.05, tile_size=1024)

                # Keep only the instances under the rasterised polygon.
                mask = _rescale_to_pixel_size(_to_tensor_float32(annotation["mask"]),
                                              source_pixel_size, destination_pixel_size).to(device)
                labels = labels.to(device) * mask.bool()

                # Dots are [row, col, class]; an ROI without dots becomes a (0, 3)
                # tensor so the column indexing below stays valid.
                raw_dots = annotation["dots"]
                dots = torch.tensor(raw_dots) if len(raw_dots) > 0 else torch.zeros((0, 3))
                dots = dots.to(device)
                dots[:, :2] = dots[:, :2] * source_pixel_size / destination_pixel_size
                dots = dots.long()
                # Guard against a dot landing one pixel past the rescaled canvas.
                dots[:, 0].clamp_(0, labels.shape[-2] - 1)
                dots[:, 1].clamp_(0, labels.shape[-1] - 1)

                # canvas holds class + 1 at every dot position, 0 elsewhere.
                canvas = torch.zeros_like(labels)
                canvas[:, :, dots[:, 0], dots[:, 1]] = dots[:, 2].float() + 1
                # An instance gets a class when a dot of that class lies on one of its pixels.
                monocytes = labels * torch.isin(labels, labels * (canvas == 2).float()).float()
                lymphocytes = labels * torch.isin(labels, labels * (canvas == 1).float()).float()
                other_cells = (labels * ~torch.isin(labels, labels * (canvas > 0).float())).float()

                img_tensor = _rescale_to_pixel_size(_to_tensor_float32(annotation[image_key]),
                                                    source_pixel_size, destination_pixel_size).byte().to(device)

                assert img_tensor.shape[-2:] == labels.shape[-2:]
                detected_leukocytes += len(torch.unique(monocytes + lymphocytes)) - 1
                leukocytes_dots += len(dots)

                x_monocytes, y_monocytes = _class_patches(monocytes, img_tensor, 1, patch_size, device)
                x_lymphocytes, y_lymphocytes = _class_patches(lymphocytes, img_tensor, 0, patch_size, device)
                x_other, y_other = _class_patches(other_cells, img_tensor, 2, patch_size, device)

                x = torch.cat((x_monocytes, x_lymphocytes, x_other), dim=0)
                y = torch.cat((y_monocytes, y_lymphocytes, y_other), dim=0).numpy()[:, None]

                # Fail loudly instead of dropping into an interactive debugger.
                assert len(x) == len(y), f"{len(x)} patches but {len(y)} labels for {file}"

                _append_rows(data_ds, x.cpu().numpy().astype(np.uint8))
                _append_rows(labels_ds, y.astype(np.uint8))

    # nan rather than a ZeroDivisionError when no dot was seen at all.
    undetected_percent = (( leukocytes_dots - detected_leukocytes) / leukocytes_dots
                          if leukocytes_dots else float("nan"))
    print(f"Detected {detected_leukocytes} out of {leukocytes_dots} dots. { 100 - undetected_percent * 100:.2f}% detected")
      

  labels = labels.to(device) * torch.tensor(mask).bool()
  dots = torch.tensor(dots, dtype=torch.long)
100%|██████████| 81/81 [27:58<00:00, 20.72s/it]


Detected 77574 out of 90343 dots. 85.87% detected


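TO CONTINUE, YOU FIRST NEED TO TRAIN A CLASSIFIER ON THE IHC PATCHES WRITTEN ABOVE (see the README). The next cell loads that trained IHC classifier with `get_classifier`.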

In [None]:
from pathlib import Path
import numpy as np
import fastremap
import os
from tqdm import tqdm
import torch
import matplotlib.pyplot as plt
import h5py
import ttach as tta
from instanseg import InstanSeg  # used below; previously only available if an earlier cell had imported it
from instanseg.utils.utils import show_images, _move_channel_axis
from instanseg.utils.pytorch_utils import get_masked_patches, _to_tensor_float32
from instanseg.inference_class import _rescale_to_pixel_size
from tiling import get_random_non_empty_tiles
from tiffslide import TiffSlide

# Builds monkey_cpg_silver.h5: nuclei are segmented on random PAS-CPG tiles, each
# nucleus is labelled by the IHC classifier run on the matching IHC patch, and the
# PAS-CPG patch is stored with that predicted label.
# Requires `annotations_dict` (file name -> list of ROI dicts carrying a "split"
# key) to already exist in the kernel.

os.environ['INSTANSEG_DATASET_PATH'] = "../datasets/"

instanseg_script = torch.jit.load("instanseg/instanseg_brightfield_monkey.pt")
brightfield_nuclei = InstanSeg(instanseg_script, verbosity = 0)

# Kept ahead of the `utils` import, matching the original statement order.
os.environ["INSTANSEG_OUTPUT_PATH"] = "../outputs/"
from utils import get_classifier
classifier = get_classifier("1922985").to("cuda").eval() #THIS IS THE IHC CLASSIFIER. THERE IS A COPY ON GITHUB RELEASES

transforms = tta.Compose(
    [
        tta.HorizontalFlip(),
        tta.Rotate90(angles=[0, 180]),  
    ]
)
tta_classifier = tta.ClassificationTTAWrapper(classifier, transforms, merge_mode='mean').eval()


patch_size = 128
destination_pixel_size = 0.5
normalise = True
batch_size = 128  # classifier batch size, hoisted out of the tile loop

device = "cpu"

monkey_dir = Path("../Raw_Datasets/Monkey")
# Sorted so the seeded run visits the files in a filesystem-independent order.
files = sorted(os.listdir(os.path.join(monkey_dir ,"annotations","xml")))


def _append_rows(dataset, rows):
    """Grow a resizable HDF5 dataset along axis 0 and write `rows` at its end.

    Writing zero rows is a no-op: the `ds[-n:]` idiom degenerates to `ds[0:]`
    (the whole dataset) when `n == 0`, so an explicit start offset is used.
    """
    n_new = rows.shape[0]
    if n_new == 0:
        return
    start = dataset.shape[0]
    dataset.resize((start + n_new,) + tuple(rows.shape[1:]))
    dataset[start:start + n_new, ...] = rows


np.random.seed(0)
with h5py.File(Path(os.environ['INSTANSEG_DATASET_PATH']) / "monkey_cpg_silver.h5", "w") as f:

    f.attrs['class_names'] = str({"0": "lymphocytes", "1": "monocytes", "2" : "other"})  # Convert to string since HDF5 attributes must be simple types
    f.attrs['pixel_size'] = destination_pixel_size

    for split in ['train', 'val']:
        f.create_dataset(f"{split}/data", shape=(0, 4, patch_size, patch_size), 
                        dtype=np.uint8, maxshape=(None, 4, patch_size, patch_size), 
                        chunks=(1, 4, patch_size, patch_size),
                      #  compression = "lzf",
        )
        f.create_dataset(f"{split}/labels", shape=(0, 1), dtype=np.uint8, maxshape=(None, 1))


    for file in tqdm(files):

        # Reuse the split drawn when the annotations were parsed; a file with no
        # recorded ROI has no split to inherit and is skipped.
        rois = annotations_dict.get(file)
        if not rois:
            continue
        split = rois[0]["split"]
        data_ds = f[f"{split}/data"]
        labels_ds = f[f"{split}/labels"]

        img_pascpg_path = Path(monkey_dir) / ("images/pas-cpg/" + file.split(".")[0] + "_PAS_CPG.tif")
        ihc_path = Path(monkey_dir) / ("images/ihc/" + file.split(".")[0] + "_IHC_CPG.tif")

        # The slides stay open for the whole tile loop (the tiles may be read
        # lazily) and are closed before moving on to the next file.
        with TiffSlide(img_pascpg_path) as slidepascpg, TiffSlide(ihc_path) as slideihc:

            tiles_he,tiles_ihc = get_random_non_empty_tiles(slidepascpg,slideihc, num_images=1000, tile_size=1024) #400

            for tile_he,tile_ihc in zip(tiles_he,tiles_ihc):

               # show_images(tile_he,tile_ihc,labels)

                labels , input_tensor = brightfield_nuclei.eval_small_image(tile_he,
                pixel_size = 0.2420, rescale_output = False, seed_threshold = 0.05)

                # Bail out before paying for the two rescales below.
                if labels.sum() == 0:
                    continue

                ihc_tensor = _rescale_to_pixel_size(_to_tensor_float32(tile_ihc), 0.2420, destination_pixel_size).byte().to(device)
                he_tensor = _rescale_to_pixel_size(_to_tensor_float32(tile_he), 0.2420, destination_pixel_size).byte().to(device)

                assert ihc_tensor.shape[-2:] == he_tensor.shape[-2:]
                assert ihc_tensor.shape[-2:] == labels.shape[-2:]

                # IHC patches (image scaled to [0, 1] + mask) feed the classifier.
                crops,masks = get_masked_patches(labels.to(device),ihc_tensor, patch_size=patch_size)
                x_ihc = torch.cat((crops / 255, masks), dim=1)

                # PAS-CPG patches (uint8 image + mask) are what gets stored.
                crops,masks = get_masked_patches(labels.to(device),he_tensor, patch_size=patch_size)
                x = torch.cat((crops.to(torch.uint8), masks.to(torch.uint8)), dim=1).cpu().numpy().astype(np.uint8)

                if len(x_ihc) == 0:
                    # No patch was produced: torch.cat over no batches would raise.
                    continue
                assert len(x) == len(x_ihc), "IHC and PAS-CPG patches must be paired one-to-one"

                with torch.no_grad():
                   # y_hat_he = torch.cat([classifier_he.forward(x[i:i+batch_size].float().to("cuda")) for i in range(0,len(x_ihc),batch_size)],dim = 0)
                   # y_hat_he = y_hat_he.argmax(dim = 1).cpu()

                    y_hat= torch.cat([tta_classifier.forward(x_ihc[i:i+batch_size].float().to("cuda")) for i in range(0,len(x_ihc),batch_size)],dim = 0)
                    y_hat = y_hat.argmax(dim = 1).cpu()

                # show_images(*x_ihc[y_hat == 1][:8,:3],n_cols = 8)
                # show_images(*x_ihc[y_hat == 0][:8,:3],n_cols = 8)

                y = y_hat.numpy()

                # Class balancing: keep at most (rarest predicted class count + 10)
                # patches of each class, in class order 0, 1, 2. Index arrays are
                # used instead of a squeezed boolean mask so a single-patch tile
                # does not collapse to a 0-d mask.
                counts = np.unique(y, return_counts=True)[1]
                keep_per_class = counts.min() + 10
                keep = np.concatenate([np.flatnonzero(y == class_id)[:keep_per_class] for class_id in range(3)])

                x = x[keep]
                y = y[keep][:, None]

                _append_rows(data_ds, x)
                _append_rows(labels_ds, y.astype(np.uint8))





In [None]:

import os
import torch

# Export the path classifier of a trained run as a standalone TorchScript file
# named after the run id, written to the current working directory.
# The output path is set ahead of the `utils` import, as in the original ordering.
os.environ["INSTANSEG_OUTPUT_PATH"] = "../outputs"
from utils import get_classifier

model = "1937330" 

# Load on CPU in eval mode, script only the `path_classifier` submodule, then save it.
classifier = get_classifier(model).to("cpu").eval()
scripted_path_classifier = torch.jit.script(classifier.path_classifier.eval())
torch.jit.save(scripted_path_classifier, f"{model}.pt")