In [1]:
import os
import cv2
import math
import json
import glob
import random
import numpy as np
from tqdm import tqdm
from pathlib import Path
import matplotlib.pyplot as plt
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

import torch
import torch.nn.functional as F

from utils.sdf import SDF
from utils.util import load_mask, load_png, load_sdm, logits_to_lbl
from utils.dataset import DatasetTools
from configs.cfgparser import Config

from model.MobaNet import MobaNet
    

# Load the YAML configuration with inference=True and cli=False
# (presumably: inference settings, no command-line parsing — confirm in Config).
cfg = Config('configs/config.yaml', inference=True, cli=False)

# This notebook runs as a single process, so the rank is pinned to 0.
cfg.RANK = 0

[INFO] Configuration file passed all validation tests.


In [2]:
# Build the dataloader by calling DatasetTools.predict_dataloader with the loaded config.
loader = DatasetTools.predict_dataloader(cfg)

In [9]:
# Pull only the first batch out of the loader and stop iterating.
# NOTE: `input` shadows the builtin of the same name; the name is kept
# because a later cell inspects `input.shape`.
for batch in loader:
    ids, input = batch['id'], batch['image']
    break

In [12]:
# Inspect the shape of the image batch pulled from the loader.
input.shape

torch.Size([10, 1, 512, 512])

In [2]:
# Minimum fraction of pixels a single class must cover for a mask to be
# labelled with that class; anything below is labelled as "mixed".
PURITY_THRESHOLD = 0.99

# Sorted so the ID lists are reproducible: glob order depends on the filesystem.
imIDs = sorted(mask_file.stem for mask_file in cfg.MSK_DIR.glob('*.png'))

# label (as str) -> list of image IDs, and image ID -> label (as int)
label_to_id = {}
id_to_label = {}

# `im_id` rather than `id`: a module-level loop variable called `id` would
# shadow the builtin for the rest of the kernel session.
for im_id in tqdm(imIDs, desc="[PREP] Generating class labels"):
    maskpath = cfg.MSK_DIR / f"{im_id}.png"
    mask = load_mask(maskpath)

    # Count the occurrences of every value in the mask and find the dominant one.
    counts = np.bincount(mask.ravel(), minlength=cfg.SEG_CLASSES)
    max_label = int(counts.argmax())
    max_fraction = counts[max_label] / mask.size

    # Near-uniform masks keep their dominant label; everything else gets the
    # extra label `cfg.SEG_CLASSES` (one past the last segmentation class).
    lbl = max_label if max_fraction >= PURITY_THRESHOLD else cfg.SEG_CLASSES

    label_to_id.setdefault(str(lbl), []).append(im_id)
    id_to_label[im_id] = lbl

[PREP] Generating class labels: 100%|██████████| 4933/4933 [00:06<00:00, 732.99it/s]


In [8]:
# List the distinct class labels that were assigned (dict keys, sorted).
sorted(label_to_id)

['0', '1', '2']

In [70]:
# MASK STATES
# 1. UNet input:         1hot encoded; (B, C, H, W)
# 2. SDM generation:     1hot encoded; (B, C, H, W)
# 3. Losses DICE-based:  1hot encoded; (B, C, H, W)
# 4. Loss sdm-based:     1hot encoded; (B, C, H, W)
# 5. Losses CE-based:    Argmax();     (B, H, W)
# 6. Metrics:            1hot encoded; (B, C, H, W)

# SEG LOGIT STATES
# 1. Losses DICE-based:  Probabilistic [0; 1];     (B, C, H, W)
# 2. Loss sdm-based:     Tanh() sdm-union [-1; 1]; (B, 1, H, W)
# 3. Losses CE-based:    Direct logits;            (B, C, H, W)
# 4. Metrics:            1hot encoded;             (B, C, H, W)
# 5. Visualization:      Argmax()                  (B, 1, H, W)

# CLS LOGIT STATES
# 1. Loss CE:            Direct logits; (B, C)
# 2. Metrics:            Argmax()       (B, ) 
# 3. UNet forward():     Argmax()       (B, )

In [None]:
class Predictor:
    def __init__(self,
                 device: str,
                 checkpoint: Path):

        self.device = device
        self.checkpoint = checkpoint

        weights = torch.load(self.checkpoint, 
                             map_location=self.device,
                             mmap=True,
                             weights_only=True)['weights']
        self.model = MobaNet(cfg)
        self.model.load_state_dict(weights)
        self.model = self.model.to(self.device)
        self.model.eval()

    def predict(self, input: str | Path | np.ndarray | torch.Tensor) -> torch.Tensor:
        """
        Predicts the output for the given input image.
        
        Args
        ----
            input : (str | Path | np.ndarray | torch.Tensor)
                Input image path, numpy array, or tensor.
                - np.ndarray: should be of shape (H, W) or (H, W, C) or (B, H, W, C).
                - torch.Tensor: should be of shape (H, W) or (H, W, C) or (B, H, W, C).

        Returns
        -------
            output : torch.Tensor (B, C, H, W)
                The model's output logits tensor.

        Raises
        ------
            ValueError
                If the input array or tensor has an unsupported shape.

            TypeError
                If the input type is unsupported.
        """
        
        if isinstance(input, (str, Path)):
            # Load image from file; (H, W, C)
            input = load_png(input)

            # add batch dimension; (H, W, C) → (1, H, W, C)
            input = input[None, ...] 
            
            # make pytorch compatible; (1, H, W, C) → (1, C, H, W)
            tensor = torch.from_numpy(input).permute(0, 3, 1, 2).float() 

        elif isinstance(input, (np.ndarray, torch.Tensor)):
            # Expand dimensions batch and channel dims (if necessary) 
            if input.ndim == 2:
                input = input[None, None, ...]  # (H, W) → (1, 1, H, W)
            elif input.ndim == 3:
                input = input[None, ...]  # (H, W, C) → (1, H, W, C)
            elif input.ndim > 4:
                raise ValueError(f"Unsupported input shape: {input.shape}. Expected an 2D, 3D or 4D (batched) input.")
                      
            # Convert input to tensor if it's a numpy array and permute dimensions for PyTorch
            tensor = torch.from_numpy(input).float() if isinstance(input, np.ndarray) else input
            tensor = tensor.permute(0, 3, 1, 2)

        else:
            raise TypeError(f"Unsupported input type: {type(input)}. Expected str, Path, np.ndarray, or torch.Tensor.")
        
        tensor = tensor.to(self.device)
        with torch.no_grad():
            return self.model(tensor)
        
    def save_predictions(self,
                         output: torch.Tensor | np.ndarray,
                         image: torch.Tensor | np.ndarray,
                         imID: str,
                         overlay: bool = True):
        
        
        
    # def save_predictions(mask: torch.Tensor, image: torch.Tensor, overlay = True) -> None:
    #     # Save mask as 8-bit image (0, 255)
    #     mask_8bit = (mask * 255).astype(np.uint8)
    #     cv2.imwrite("mask_8bit.png", mask_8bit)

    #     if  overlay:
    #         # Apply colormap to the mask
    #         mask_color = cv2.applyColorMap(mask_8bit, cv2.COLORMAP_JET)
    #         image_rgb = cv2.cvtColor((image * 255).astype(np.uint8), cv2.COLOR_GRAY2BGR)
    #         overlay = cv2.addWeighted(image_rgb, 0.7, mask_color, 0.3, 0)
    #         cv2.imwrite("overlay.png", overlay)

In [None]:
def save_predictions(logits: dict[str, torch.Tensor], 
                     image: torch.Tensor, 
                     imID: str,
                     output_dir: Path, 
                     overlay: bool = True) -> None:
    """
    Writes the predicted segmentation mask of the first sample in the batch
    as an 8-bit PNG and, optionally, a colour overlay on the input image.

    Args
    ----
        logits : dict[str, torch.Tensor]
            Model output; only logits['seg'] of shape (B, C, H, W) is used.

        image : torch.Tensor | np.ndarray
            Grayscale input image, e.g. (H, W), (1, H, W) or (B, 1, H, W).
            For batched input the first sample is used, matching the mask.
            Assumed to be scaled to [0, 1] — TODO confirm against the dataloader.

        imID : str
            Image identifier; files are named "{imID}.png" and "{imID}_overlay.png".

        output_dir : Path
            Target directory; created if it does not exist.

        overlay : bool
            If True, also write the blended image/mask overlay.

    Raises
    ------
        ValueError
            If the image cannot be reduced to a single 2D grayscale plane.

        OSError
            If OpenCV reports that a file could not be written.
    """

    seg_logits = logits['seg']

    # argmax over the class axis; a preceding softmax would not change the result
    labels = seg_logits.argmax(dim=1)[0].detach().cpu().numpy()

    # Spread the class indices over 0..255: 0/255 for two classes, and no
    # uint8 wrap-around when there are more than two.
    n_classes = seg_logits.shape[1]
    scale = 255 // max(n_classes - 1, 1)
    mask_8bit = (labels * scale).astype(np.uint8)

    # cv2.imwrite does not create directories and only signals failure
    # through its return value, so both are handled explicitly.
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    mask_path = output_dir / f"{imID}.png"
    if not cv2.imwrite(str(mask_path), mask_8bit):
        raise OSError(f"Failed to write mask image: {mask_path}")

    if not overlay:
        return

    # Tensors have no .astype(); move to numpy first
    if isinstance(image, torch.Tensor):
        image_np = image.detach().cpu().numpy()
    else:
        image_np = np.asarray(image)

    # Reduce to one (H, W) plane, which is what COLOR_GRAY2BGR expects
    if image_np.ndim == 4:
        image_np = image_np[0]
    image_np = np.squeeze(image_np)
    if image_np.ndim != 2:
        raise ValueError(f"Expected a single-channel grayscale image, got shape {tuple(image.shape)}.")

    # Scale in float and clip so out-of-range values cannot wrap around in uint8
    image_8bit = np.clip(image_np.astype(np.float32) * 255, 0, 255).astype(np.uint8)

    # Apply colormap to the mask and blend it onto the image
    mask_color = cv2.applyColorMap(mask_8bit, cv2.COLORMAP_JET)
    image_bgr = cv2.cvtColor(image_8bit, cv2.COLOR_GRAY2BGR)
    blended = cv2.addWeighted(image_bgr, 0.7, mask_color, 0.3, 0)

    overlay_path = output_dir / f"{imID}_overlay.png"
    if not cv2.imwrite(str(overlay_path), blended):
        raise OSError(f"Failed to write overlay image: {overlay_path}")




In [4]:
# Image and mask share the same file name in their respective directories.
sample_name = 'ESP_018244_2655_RED-3456_0.png'
impath = cfg.IMG_DIR / sample_name
maskpath = cfg.MSK_DIR / sample_name

mask = load_mask(maskpath, C=2)

# NOTE(review): the recorded output shows identical counts for both values.
# If load_mask(..., C=2) returns a one-hot array, bincount tallies zeros and
# ones rather than pixels per class — confirm before relying on this.
counts = np.bincount(mask.ravel(), minlength=2)



In [5]:
# Show the value counts produced by np.bincount for the sample mask.
counts

array([262144, 262144], dtype=int64)

In [None]:
# impath = cfg.IMG_DIR / 'ESP_018244_2655_RED-3456_0.png'
# img = load_png(impath)
# # img = img[None, ...]  # add batch dimension
# # tensor = torch.from_numpy(img).permute(0, 3, 1, 2).float() # (B, H, W, C) -> (B, C, H, W)
# tensor = torch.from_numpy(img).permute(2, 0, 1).float()  # (H, W, C) -> (C, H, W)


In [None]:
# Post-processing for a clean line: After prediction, compute the boundary as the set of pixels where a 4- or 8-connected 
# neighborhood changes class, or subtract eroded from dilated masks to get a 1-pixel ridge.

(2, 512, 512)

In [None]:
# def save_predictions(mask: torch.Tensor, image: torch.Tensor, overlay = True) -> None:
#     # Save mask as 8-bit image (0, 255)
#     mask_8bit = (mask * 255).astype(np.uint8)
#     cv2.imwrite("mask_8bit.png", mask_8bit)

#     if  overlay:
#         # Apply colormap to the mask
#         mask_color = cv2.applyColorMap(mask_8bit, cv2.COLORMAP_JET)
#         image_rgb = cv2.cvtColor((image * 255).astype(np.uint8), cv2.COLOR_GRAY2BGR)
#         overlay = cv2.addWeighted(image_rgb, 0.7, mask_color, 0.3, 0)
#         cv2.imwrite("overlay.png", overlay)

True

In [2]:
# High-level description of the prediction strategy:
# 1. INPUT: 
#    - A direct path to an image file
#    - A directory containing multiple image files
#    - An image as a numpy array or tensor
#
# def predict(input: str | np.ndarray | torch.Tensor, cfg: Config) -> None: