In [None]:
import os
import json
import glob
import torch
import cv2
import numpy as np
from pathlib import Path
from tqdm import tqdm
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import ujson
import ray
import math
from dataclasses import dataclass
from typing import Union

# Import your CellViT modules
from cellvit.models.cell_segmentation.cellvit_virchow import CellViTVirchow
from cellvit.utils.tools import unflatten_dict
from cellvit.inference.inference_disk import CellViTInference
from cellvit.inference.postprocessing_cupy import (
    BatchPoolingActor,
    DetectionCellPostProcessorCupy,
)
from cellvit.config.config import TYPE_NUCLEI_DICT_PANNUKE

class region_dataset(Dataset):
    """Tile a single region image into overlapping, full-sized square patches.

    The image is read once in ``__init__`` and kept in memory as an RGB array.
    Patch origins lie on a regular grid with stride ``patch_size - overlap``.
    Only origins strictly below ``dimension - patch_size`` are used, so every
    patch is exactly ``patch_size`` x ``patch_size`` (required for stacking
    patches into a batch) and a strip at the bottom/right edge stays uncovered.

    Args:
        image_path: Path of the image file, read with ``cv2.imread``.
        transform: Optional callable applied to every patch array.
        patch_size: Edge length of the square patches in pixels.
        overlap: Number of pixels shared by neighbouring patches.

    Raises:
        ValueError: If ``overlap`` is not smaller than ``patch_size``.
        OSError: If OpenCV cannot read ``image_path``.
    """

    def __init__(self, image_path, transform=None, patch_size=256, overlap=32):
        if overlap >= patch_size:
            # A zero or negative stride would yield no usable grid.
            raise ValueError(
                f"overlap ({overlap}) must be smaller than patch_size ({patch_size})"
            )

        img_bgr = cv2.imread(str(image_path))
        if img_bgr is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise OSError(f"Could not read image: {image_path}")

        # OpenCV decodes to BGR; patches are handed out in RGB order.
        self.img_np = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        self.transform = transform
        self.patch_size = patch_size
        self.overlap = overlap
        self.shift = patch_size - overlap

        height, width = self.img_np.shape[:2]
        # The grid honours ``patch_size`` (it used to assume 256 regardless).
        self.region_coords = self._grid_origins(
            height, width, patch_size, self.shift
        )

    @staticmethod
    def _grid_origins(height, width, patch_size, shift):
        """Return the ``(row, col)`` pixel origins of all patches, row-major."""
        return [
            (row, col)
            for row in range(0, height - patch_size, shift)
            for col in range(0, width - patch_size, shift)
        ]

    def __len__(self):
        """Return the number of patches in the grid."""
        return len(self.region_coords)

    def __getitem__(self, idx):
        """Return ``(patch, metadata)`` for the patch at grid index ``idx``.

        ``metadata`` holds the grid ``row``/``col`` of the patch and its pixel
        origin. Note that ``position_x`` carries the offset along the first
        (row) array axis and ``position_y`` the offset along the second
        (column) axis.
        NOTE(review): confirm this axis naming is what the consumer expects.
        """
        row_px, col_px = self.region_coords[idx]
        patch = self.img_np[
            row_px:row_px + self.patch_size,
            col_px:col_px + self.patch_size,
        ]

        if self.transform:
            patch = self.transform(patch)

        metadata = {
            "col": col_px // self.shift,
            "row": row_px // self.shift,
            "position_x": row_px,
            "position_y": col_px,
        }
        return patch, metadata

@dataclass
class WSIMetadata:
    """Lightweight slide descriptor passed as ``wsi`` to the cell post-processor.

    Pairs the image location with the metadata dictionary produced by
    ``get_wsi_metadata``; ``process_single_region`` reads ``.metadata`` back
    when assembling its results.
    """
    # Location of the image this metadata describes.
    slide_path: Union[str, Path]
    # Tiling / magnification settings (see ``get_wsi_metadata`` for the keys).
    metadata: dict

def get_wsi_metadata(image_path, patch_size=256, downsample=1, patch_overlap=32, label_map=None, normalize_stains=False):
    """Build the slide-metadata dictionary for a single region image.

    The tile grid size is derived from the image dimensions with the same
    stride rule ``region_dataset`` uses to lay out its patches
    (``range(0, dimension - patch_size, patch_size - patch_overlap)``), so the
    reported rows/columns match the patches that are actually produced. For a
    1024 x 1024 image with the default settings this gives the 4 x 4 grid that
    used to be hard-coded.

    Base magnification (40), base microns-per-pixel (0.25) and ``level`` (8)
    are fixed values, not read from the image.

    Args:
        image_path: Path of the image; only its height and width are used.
        patch_size: Edge length of the square patches in pixels.
        downsample: Downsampling factor applied to the base resolution.
        patch_overlap: Overlap between neighbouring patches in pixels.
        label_map: Optional label mapping; an empty dict is stored when None.
        normalize_stains: Flag stored under ``"stain_normalization"``.

    Returns:
        dict: Metadata with tile counts, patch geometry and magnification.

    Raises:
        ValueError: If ``patch_overlap`` is not smaller than ``patch_size``.
        OSError: If OpenCV cannot read ``image_path``.
    """
    shift_size = patch_size - patch_overlap
    if shift_size <= 0:
        raise ValueError(
            f"patch_overlap ({patch_overlap}) must be smaller than "
            f"patch_size ({patch_size})"
        )

    img = cv2.imread(str(image_path))
    if img is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise OSError(f"Could not read image: {image_path}")
    h, w = img.shape[:2]

    # Number of grid positions along each axis, mirroring region_dataset.
    n_cols = len(range(0, w - patch_size, shift_size))
    n_rows = len(range(0, h - patch_size, shift_size))

    slide_mag = 40
    slide_mpp = 0.25
    resulting_mpp = slide_mpp * downsample

    return {
        "orig_n_tiles_cols": n_cols,
        "orig_n_tiles_rows": n_rows,
        "base_magnification": slide_mag,
        "downsampling": downsample,
        "label_map": label_map if label_map is not None else {},
        "patch_overlap": patch_overlap,
        "patch_size": patch_size,
        "base_mpp": slide_mpp,
        "target_patch_mpp": resulting_mpp,
        "stain_normalization": normalize_stains,
        "magnification": slide_mag / (downsample * 1.0),
        "level": 8,
    }

def default_collate_fn(batch):
    """Collate ``(patch, metadata)`` samples into one batch.

    Args:
        batch: Sequence of ``(patch_tensor, metadata)`` pairs.

    Returns:
        tuple: The patches stacked along a new leading dimension, and the
        per-sample metadata objects as a plain list in the same order.
    """
    stacked_inputs = [patch for patch, _ in batch]
    sample_infos = [info for _, info in batch]
    return torch.stack(stacked_inputs), sample_infos

def process_single_region(image_path, model, device, run_conf, infer_cell, batch_size=8, overlap=32):
    """Run cell detection on one region image and write the cells as GeoJSON.

    The image is tiled into 256-pixel patches, each batch is passed through
    ``model`` and the predictions are post-processed by Ray
    ``BatchPoolingActor`` workers. The merged cell list is filtered with
    ``infer_cell._post_process_edge_cells`` and written to
    ``<image stem>_cells.geojson`` next to the image.

    Args:
        image_path: Path of the region image.
        model: Network, called as ``model(patches, retrieve_tokens=True)``.
        device: Device the patch batches are moved to before inference.
        run_conf: Run configuration; ``run_conf["data"]["num_nuclei_classes"]``
            is read here and the whole dict is handed to the pooling actors.
        infer_cell: Object providing ``apply_softmax_reorder``,
            ``_post_process_edge_cells`` and ``_convert_json_geojson``.
        batch_size: Number of patches per forward pass.
        overlap: Patch overlap in pixels, used for tiling and metadata.

    Returns:
        None. If inference or saving fails, the error and traceback are
        printed and an empty list is written to the GeoJSON path instead.
        Errors raised while building the dataset, metadata or post-processor
        (before the guarded section) propagate to the caller.
    """
    import traceback  # only needed on the error path

    image_path_obj = Path(image_path)
    output_dir = image_path_obj.parent
    region_name = image_path_obj.stem
    geojson_path = output_dir / f"{region_name}_cells.geojson"

    print(f"Processing region: {region_name}")

    # NOTE(review): ImageNet mean/std are hard-coded here; confirm they match
    # the normalisation the checkpoint was trained with.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])

    dataset = region_dataset(image_path=image_path, transform=transform, patch_size=256, overlap=overlap)
    dataloader = DataLoader(dataset=dataset, batch_size=batch_size, num_workers=0,
                            collate_fn=default_collate_fn, pin_memory=True)

    wsi_metadata = get_wsi_metadata(
        image_path=image_path,
        patch_size=256,
        downsample=1,
        patch_overlap=overlap,
        label_map={"tumor": 1, "normal": 0},
        normalize_stains=False,
    )
    wsi = WSIMetadata(
        slide_path=image_path,
        metadata=wsi_metadata,
    )

    postprocessor = DetectionCellPostProcessorCupy(
        wsi=wsi,
        nr_types=run_conf["data"]["num_nuclei_classes"],
        resolution=0.25,
        classifier=None,
        binary=False,
    )

    # Use few ray actors to avoid resource contention.
    ray_actors = min(2, batch_size)

    try:
        # Initialise ray only once; later calls reuse the running instance.
        if not ray.is_initialized():
            ray.init(num_cpus=ray_actors, num_gpus=0.1, ignore_reinit_error=True)

        batch_pooling_actors = [
            BatchPoolingActor.remote(postprocessor, run_conf)
            for _ in range(ray_actors)
        ]

        call_ids = []
        with torch.inference_mode():
            for batch_num, (patches, metadata) in enumerate(dataloader):
                # Round-robin the batches over the pooling actors.
                batch_actor = batch_pooling_actors[batch_num % ray_actors]
                patches = patches.to(device, non_blocking=True)

                predictions = model(patches, retrieve_tokens=True)
                predictions = infer_cell.apply_softmax_reorder(predictions=predictions)
                # Move tensors to the CPU before shipping them to the actor.
                predictions = {k: v.cpu() if isinstance(v, torch.Tensor) else v
                               for k, v in predictions.items()}

                call_id = batch_actor.convert_batch_to_graph_nodes.remote(predictions, metadata)
                call_ids.append(call_id)

            # Collect results in small chunks to limit peak memory.
            inference_results = []
            for i in range(0, len(call_ids), 4):
                chunk = call_ids[i:i + 4]
                inference_results.extend(ray.get(chunk))

        # Merge the per-batch results into flat, index-aligned lists.
        cell_dict_wsi = []
        cell_dict_detection = []

        label_map = TYPE_NUCLEI_DICT_PANNUKE

        # Kept index-aligned with cell_dict_wsi; not written to disk here.
        graph_data = {
            "cell_tokens": [],
            "positions": [],
            "metadata": {
                "wsi_metadata": wsi.metadata,
                "nuclei_types": label_map,
            },
        }

        for batch_results in inference_results:
            (
                batch_complete_dict,
                batch_detection,
                batch_cell_tokens,
                batch_cell_positions,
            ) = batch_results

            cell_dict_wsi.extend(batch_complete_dict)
            cell_dict_detection.extend(batch_detection)
            graph_data["cell_tokens"].extend(batch_cell_tokens)
            graph_data["positions"].extend(batch_cell_positions)

        try:
            keep_idx = infer_cell._post_process_edge_cells(cell_list=cell_dict_wsi)

            # Use len() rather than truthiness so that a NumPy array of
            # indices is handled the same way as a list.
            if keep_idx is None or len(keep_idx) == 0:
                print(f"Warning: Empty keep_idx for {region_name}, using all cells")
                keep_idx = list(range(len(cell_dict_wsi)))

            n_cells = len(cell_dict_wsi)
            valid_keep_idx = [idx for idx in keep_idx if 0 <= idx < n_cells]

            if len(valid_keep_idx) != len(keep_idx):
                print(f"Warning: Filtered {len(keep_idx) - len(valid_keep_idx)} invalid indices in {region_name}")

            # Build every filtered list first and only swap them in once all
            # succeeded, so a failure cannot leave the lists out of sync.
            filtered_wsi = [cell_dict_wsi[idx_c] for idx_c in valid_keep_idx]
            filtered_detection = [cell_dict_detection[idx_c] for idx_c in valid_keep_idx]
            filtered_tokens = [graph_data["cell_tokens"][idx_c] for idx_c in valid_keep_idx]
            filtered_positions = [graph_data["positions"][idx_c] for idx_c in valid_keep_idx]

        except Exception as e:
            print(f"Error in post-processing for {region_name}: {str(e)}")
            # Fallback: keep all detected cells without filtering.
            print(f"Using fallback: all {len(cell_dict_wsi)} cells")
        else:
            cell_dict_wsi = filtered_wsi
            cell_dict_detection = filtered_detection
            graph_data["cell_tokens"] = filtered_tokens
            graph_data["positions"] = filtered_positions

        # Save results
        geojson_list = infer_cell._convert_json_geojson(cell_dict_wsi, True)

        with open(geojson_path, "w") as outfile:
            ujson.dump(geojson_list, outfile)

    except Exception as e:
        print(f"Error processing {region_name}: {str(e)}")
        # Keep the stack trace so failures in a long batch run can be located.
        traceback.print_exc()
        # Create an empty file on error to avoid reprocessing.
        with open(geojson_path, "w") as outfile:
            ujson.dump([], outfile)
    finally:
        # Ray is not shut down here; the caller handles that once at the end.
        torch.cuda.empty_cache()

def main():
    """Run cell detection over every region image under the inference folder.

    Loads the checkpoint, builds the ``CellViTVirchow`` model from its stored
    configuration and processes each ``*.png`` inside every sub-folder whose
    name ends with ``_10_regions``. Ray is shut down and the CUDA cache is
    emptied when the run ends, whether or not it raised.
    """
    # Load model
    model_path = "/home/shivam/nuc_seg/CellViT-plus-plus-main/checkpoints/CellViT-Virchow-x40-AMP-001.pth"
    model_checkpoint = torch.load(model_path, map_location="cpu")
    run_conf = unflatten_dict(model_checkpoint["config"], ".")

    model = CellViTVirchow(
        model_virchow_path=None,
        num_nuclei_classes=run_conf["data"]["num_nuclei_classes"],
        num_tissue_classes=run_conf["data"]["num_tissue_classes"]
    )
    model.load_state_dict(model_checkpoint["model_state_dict"])

    # Keep using the second GPU when there is one, but fall back to the first
    # on single-GPU hosts, where "cuda:1" would be an invalid device.
    gpu_index = 1 if torch.cuda.device_count() > 1 else 0
    device = torch.device(f"cuda:{gpu_index}" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    run_conf["model"]["token_patch_size"] = model.patch_size

    # Only used for its helper methods; pass the same GPU index as the model.
    infer_cell = CellViTInference(model_path=model_path, gpu=gpu_index)

    # Find all region folders
    base_directory = "/home/shivam/nuc_seg/CellViT-plus-plus-main/inference_data"
    region_folders = [f for f in os.listdir(base_directory)
                      if os.path.isdir(os.path.join(base_directory, f)) and f.endswith("_10_regions")]

    print(f"Found {len(region_folders)} region folders")

    try:
        # Process all region images
        for folder in tqdm(region_folders, desc="Processing folders"):
            folder_path = os.path.join(base_directory, folder)
            # glob returns files in arbitrary order; sort for reproducible runs.
            image_files = sorted(glob.glob(os.path.join(folder_path, "*.png")))

            print(f"Processing {len(image_files)} images in {folder}")

            for image_path in image_files:
                process_single_region(
                    image_path=image_path,
                    model=model,
                    device=device,
                    run_conf=run_conf,
                    infer_cell=infer_cell,
                    batch_size=8
                )

    finally:
        # Clean up ray at the end
        if ray.is_initialized():
            ray.shutdown()
        torch.cuda.empty_cache()

    # NOTE: per-region failures are handled inside process_single_region, so
    # reaching this line does not guarantee that every region succeeded.
    print("All regions processed successfully!")

# Entry point: run the full pipeline when this code is executed directly
# (script or notebook cell), but not when it is imported as a module.
if __name__ == "__main__":
    main()

No checkpoint provided!
2025-09-12 17:28:50,623 [INFO] - Loading model: /home/shivam/nuc_seg/CellViT-plus-plus-main/checkpoints/CellViT-Virchow-x40-AMP-001.pth
No checkpoint provided!
2025-09-12 17:28:56,692 [INFO] - <All keys matched successfully>
2025-09-12 17:29:01,760 [INFO] - Based on the hardware we limit the batch size to a maximum of:
2025-09-12 17:29:01,762 [INFO] - 8
2025-09-12 17:29:01,763 [INFO] - Loading inference transformations


2025-09-12 17:29:06,086	INFO worker.py:1724 -- Started a local Ray instance.


2025-09-12 17:29:10,916 [INFO] - Using 4 ray-workers
Found 11 region folders


Processing folders:   0%|                                | 0/11 [00:00<?, ?it/s]

Processing 10 images in CAIB-T00001508PC01R01P0101HE_10_regions
Processing region: CAIB-T00001508PC01R01P0101HE_x15233_y9145
2025-09-12 17:30:02,877 [INFO] - Initializing Cell-Postprocessor
2025-09-12 17:30:02,886 [INFO] - Finding edge-cells for merging
2025-09-12 17:30:02,903 [INFO] - Removal of cells detected multiple times
2025-09-12 17:30:02,924 [INFO] - Iteration 0: Found overlap of # cells: 9
2025-09-12 17:30:02,942 [INFO] - Iteration 1: Found overlap of # cells: 0
2025-09-12 17:30:02,943 [INFO] - Found all overlapping cells
Processing region: CAIB-T00001508PC01R01P0101HE_x40623_y12193
2025-09-12 17:30:23,861 [INFO] - Initializing Cell-Postprocessor
2025-09-12 17:30:23,868 [INFO] - Finding edge-cells for merging
2025-09-12 17:30:23,884 [INFO] - Removal of cells detected multiple times
2025-09-12 17:30:23,901 [INFO] - Iteration 0: Found overlap of # cells: 4
2025-09-12 17:30:23,916 [INFO] - Iteration 1: Found overlap of # cells: 0
2025-09-12 17:30:23,917 [INFO] - Found all overlap

Processing folders:   9%|█▉                   | 1/11 [06:24<1:04:04, 384.44s/it]

Processing 10 images in CAIB-T00000691OC01R01P0404HE_10_regions
Processing region: CAIB-T00000691OC01R01P0404HE_x8880_y15788
2025-09-12 17:36:11,658 [INFO] - Initializing Cell-Postprocessor
2025-09-12 17:36:11,664 [INFO] - Finding edge-cells for merging
2025-09-12 17:36:11,681 [INFO] - Removal of cells detected multiple times
2025-09-12 17:36:11,710 [INFO] - Iteration 0: Found overlap of # cells: 5
2025-09-12 17:36:11,736 [INFO] - Iteration 1: Found overlap of # cells: 0
2025-09-12 17:36:11,737 [INFO] - Found all overlapping cells
Processing region: CAIB-T00000691OC01R01P0404HE_x32562_y20721
2025-09-12 17:36:52,324 [INFO] - Initializing Cell-Postprocessor
2025-09-12 17:36:52,332 [INFO] - Finding edge-cells for merging
2025-09-12 17:36:52,347 [INFO] - Removal of cells detected multiple times
2025-09-12 17:36:52,375 [INFO] - Iteration 0: Found overlap of # cells: 10
2025-09-12 17:36:52,403 [INFO] - Iteration 1: Found overlap of # cells: 0
2025-09-12 17:36:52,404 [INFO] - Found all overla

In [None]:
import os
import glob
from pathlib import Path
import json

def geojson_shift_handling(geojson_path, overlap=32):
    """Shift every polygon of a cell GeoJSON by half the patch overlap.

    Each coordinate of each feature is translated by ``overlap // 2`` in both
    x and y, and the result is written next to the input as
    ``<input stem>_FIXED.geojson``. The input file is left untouched.

    The offset was previously hard-coded as half of 32, the default overlap of
    the inference step.
    NOTE(review): presumably this compensates for a half-overlap offset the
    post-processing applies to cell coordinates; confirm against it.

    Args:
        geojson_path: Path of a GeoJSON file holding a list of features whose
            ``geometry.coordinates`` are nested as polygons -> rings ->
            ``[x, y]`` pairs.
        overlap: Patch overlap in pixels that was used during inference.

    Returns:
        pathlib.Path: Location of the written ``_FIXED`` file.
    """
    shift = overlap // 2

    with open(geojson_path, 'r') as f:
        data = json.load(f)

    for feature in data:
        geometry = feature['geometry']
        shifted_polygons = []
        for polygon in geometry['coordinates']:
            shifted_polygons.append(
                [[[x + shift, y + shift] for x, y in ring] for ring in polygon]
            )
        geometry['coordinates'] = shifted_polygons

    source = Path(geojson_path)
    geojson_fixed_path = source.parent / f"{source.stem}_FIXED.geojson"

    with open(geojson_fixed_path, 'w') as f:
        json.dump(data, f, indent=2)

    return geojson_fixed_path

# Apply the coordinate shift to every "*_cells.geojson" found in the region
# folders (sub-directories whose name ends with "_10_regions").
base_directory = "/home/shivam/nuc_seg/CellViT-plus-plus-main/inference_data"

region_folders = []
for entry in os.listdir(base_directory):
    is_region_dir = (
        os.path.isdir(os.path.join(base_directory, entry))
        and entry.endswith("_10_regions")
    )
    if is_region_dir:
        region_folders.append(entry)

for folder in region_folders:
    folder_path = os.path.join(base_directory, folder)
    geojson_files = glob.glob(os.path.join(folder_path, "*_cells.geojson"))
    for geojson_path in geojson_files:
        geojson_shift_handling(geojson_path)

print("done..!!")

In [None]:
x_s