# Papillae Detection

## Description

This notebook detects the papillae in M0 images.\
It uses a UResNet CNN model (a U-Net built from residual blocks) that has been pretrained for this task.\
**To use it:**
- Change "MODEL_PATH" to the path that points to your pretrained model (.pth extension)
- Execute every cell until the "Main" section
- Modify "IMAGE_PATH"
- If needed, modify "SAVE_PATH"
- Execute all cells of the "Main" section

## Imports and Constants

In [None]:
# /!\ Change this to the path of your pretrained model (.pth file); load_model uses it as its default path
MODEL_PATH = "states/optic_disk_10.pth"
# /!\ DO NOT CHANGE: side length (in pixels) that eval_image resizes its input image to
IMAGE_SIZE = 1024

⚠️⚠️⚠️ Image size\
Do not modify `IMAGE_SIZE`, even if the size of the M0 pictures changes.\
Right now, the M0 pictures are 1023 x 1023 pixels, so one pixel of padding is added\
on the bottom and on the right of the picture in the "eval_image" function.\
**IF THE M0 SIZE CHANGES**, to 1024 x 1024 for example, remove that padding in "eval_image" instead of changing `IMAGE_SIZE`.

In [None]:
from PIL import Image
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from skimage.measure import label, regionprops
import numpy as np

# Environment report. The values in the comments are the ones of the reference setup.
print(torch.__version__)          # should show 2.8.0+cu126
print(torch.version.cuda)         # should show '12.6'
print(torch.cuda.is_available())  # should be True
# Querying the device name raises when no CUDA device is usable, so only do it
# when CUDA is available; the notebook itself supports a CPU fallback.
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))  # NVIDIA RTX A6000

In [None]:
# Select the compute device once; the rest of the notebook reads `device`.
if torch.cuda.is_available():
    device = "cuda"
    print("Using GPU")
    # Four workers per visible GPU
    num_workers = torch.cuda.device_count() * 4
else:
    device = "cpu"
    print("Using CPU")
    num_workers = 1

# Random number generator living on the selected device
generator = torch.Generator(device=device)

In [None]:
def free_cache() -> None:
    """
    Hand the cached but currently unused CUDA memory back to the GPU
    """
    torch.cuda.empty_cache()

## UResNet

In [None]:
class ResBlock(nn.Module):
    """
    Residual block: two 3x3 convolution + batch-norm stages plus a skip connection
    :param in_channels: Number of channels of the input feature map (int)
    :param out_channels: Number of channels of the output feature map (int)
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Main branch: conv -> BN -> ReLU -> conv -> BN.
        # kernel_size=3 with padding=1 keeps the spatial size unchanged.
        main_branch = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
        ]
        self.conv = nn.Sequential(*main_branch)

        # Skip branch: a 1x1 convolution adapts the channel count when it differs,
        # otherwise the input is passed through untouched
        if in_channels == out_channels:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """
        :param x: Feature map with in_channels channels
        :return: ReLU(main branch + skip branch), with out_channels channels
        """
        skip = self.shortcut(x)
        merged = self.conv(x)
        merged += skip
        return self.relu(merged)

In [None]:
class DownResSample(nn.Module):
    """
    Encoder stage: a residual block followed by a 2x2 max pooling
    :param in_channels: Number of channels of the input feature map (int)
    :param out_channels: Number of channels produced by the residual block (int)
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = ResBlock(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """
        :param x: Feature map with in_channels channels
        :return: Tuple (features before pooling, features after pooling)
        """
        features = self.conv(x)
        return features, self.pool(features)

In [None]:
class UpResSample(nn.Module):
    """
    Decoder stage: upsample, concatenate with another feature map, then apply a residual block
    :param in_channels: Number of channels of the tensor to upsample (int)
    :param out_channels: Number of channels produced by the residual block (int)
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # The transposed convolution doubles the spatial size and halves the channel count
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        # The residual block receives the concatenation, which must total in_channels channels
        self.conv = ResBlock(in_channels, out_channels)

    def forward(self, x1, x2):
        """
        :param x1: Feature map concatenated as-is (first along the channel axis)
        :param x2: Feature map that is upsampled before the concatenation
        :return: Output of the residual block applied to the concatenated tensor
        """
        upsampled = self.up(x2)
        stacked = torch.cat((x1, upsampled), dim=1)
        return self.conv(stacked)

In [None]:
class UResNet(nn.Module):
    """
    U-shaped network built from residual blocks: 4 encoder stages, a bottleneck and 4 decoder stages
    :param in_channels: Number of channels of the input image (int)
    :param num_classes: Number of channels of the output map (int)
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()

        # Encoder: channel count doubles at each stage while pooling halves the spatial size
        self.down_conv1 = DownResSample(in_channels, 64)
        self.down_conv2 = DownResSample(64, 128)
        self.down_conv3 = DownResSample(128, 256)
        self.down_conv4 = DownResSample(256, 512)

        self.bottle_neck = ResBlock(512, 1024)

        # Decoder: mirrors the encoder
        self.up_conv1 = UpResSample(1024, 512)
        self.up_conv2 = UpResSample(512, 256)
        self.up_conv3 = UpResSample(256, 128)
        self.up_conv4 = UpResSample(128, 64)

        # 1x1 convolution mapping the 64 decoder channels to the output channels
        self.out = nn.Conv2d(64, num_classes, kernel_size=1)

    def forward(self, x):
        """
        :param x: Input batch with in_channels channels
        :return: Output of the final 1x1 convolution (no activation is applied here)
        """
        # Encoder pass: keep the pre-pooling features of each stage for the skip connections
        skip1, x = self.down_conv1(x)
        skip2, x = self.down_conv2(x)
        skip3, x = self.down_conv3(x)
        skip4, x = self.down_conv4(x)

        x = self.bottle_neck(x)

        # Decoder pass: deepest skip first
        x = self.up_conv1(skip4, x)
        x = self.up_conv2(skip3, x)
        x = self.up_conv3(skip2, x)
        x = self.up_conv4(skip1, x)

        return self.out(x)

In [None]:
def load_model(path: str = MODEL_PATH):
    """
    Loads pretrained model
    :param path: Trained model path (.pth extension) (str). The default is the value
                 MODEL_PATH had when this cell was executed.
    :return: PyTorch model in UResNet format, placed on the selected device
    """
    m = UResNet(in_channels=1, num_classes=1).to(device)
    # torch.load unpickles the file; weights_only=True restricts it to tensors and
    # plain containers so that a checkpoint cannot execute arbitrary code.
    state_dict = torch.load(path, map_location=torch.device(device), weights_only=True)
    m.load_state_dict(state_dict)
    return m

## Largest Connected Component

In [None]:
def keep_largest_connected_object(pred_mask: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
    """
    Keep only the largest connected object from a grayscale prediction tensor
    :param pred_mask: torch.Tensor of shape [1,H,W] or [H,W], values in [0,1]
    :param threshold: Threshold to binarize prediction (pixels strictly above it are foreground)
    :return: float32 torch.Tensor of shape [1,H,W], on the same device as pred_mask, holding 1
             on the largest connected object and 0 elsewhere (all zeros if nothing is detected)
    """
    # Convert to numpy
    real_mask = pred_mask.squeeze().detach().cpu().numpy()

    # Threshold to binary
    binary = (real_mask > threshold).astype(np.uint8)

    if not binary.any():
        # No object detected: return an empty mask with the same shape and dtype
        # as the regular result, whatever the shape/dtype of the input was
        return torch.zeros((1, *binary.shape), dtype=torch.float32, device=pred_mask.device)

    # Label connected components (0 is the background label)
    labeled = label(binary)

    # Pixel count per label; the background is zeroed so it can never win.
    # argmax returns the lowest label among equally large objects.
    sizes = np.bincount(labeled.ravel())
    sizes[0] = 0
    largest_mask = labeled == sizes.argmax()

    # Convert back to torch tensor
    return torch.from_numpy(largest_mask).unsqueeze(0).to(pred_mask.device).float()

## Save

In [None]:
def save_mask(m: torch.Tensor, path: str):
    """
    Save a mask tensor as a PNG image
    :param m: torch.Tensor of shape [1,H,W] or [H,W], values in [0,1] or [0,255]
    :param path: Where to save the PNG (str). Nothing is saved when it is None.
    """
    if path is None:
        return

    # Remove batch/channel dims if needed; detach so that tensors that still
    # track gradients can be converted to numpy
    m = m.detach().squeeze()

    # Scale to 0–255 if not already
    if m.max() <= 1.0:
        m = m * 255

    # Clamp before the 8-bit cast so that out-of-range values saturate instead of wrapping around
    m = m.clamp(0, 255).byte()

    # Convert to PIL image and save
    img = Image.fromarray(m.cpu().numpy())
    img.save(path)

## Evaluation and Visualization

In [None]:
def visualize_prediction(img: torch.Tensor, mask: torch.Tensor, title: str = None):
    """
    Show three panels side by side: the image, the mask, and the mask drawn over the image
    :param img: torch.Tensor of shape [1,H,W] or [H,W] (grayscale image)
    :param mask: torch.Tensor of shape [1,H,W] or [H,W] (binary/float mask)
    :param title: Optional string to display as figure title
    """
    # Drop the size-1 dims and move the data to numpy for matplotlib
    image_arr = img.squeeze().cpu().numpy()
    mask_arr = mask.squeeze().cpu().numpy()

    # One entry per panel: (panel title, layers drawn bottom to top as (array, imshow options))
    panels = [
        ("Image", [(image_arr, {"cmap": "gray"})]),
        ("Mask", [(mask_arr, {"cmap": "gray"})]),
        ("Overlay", [(image_arr, {"cmap": "gray"}),
                     (mask_arr, {"cmap": "Reds", "alpha": 0.1})]),
    ]

    plt.figure(figsize=(8, 4))

    for position, (panel_title, layers) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        for array, options in layers:
            plt.imshow(array, **options)
        plt.title(panel_title)
        plt.axis("off")

    if title is not None:
        plt.suptitle(title)

    plt.tight_layout()
    plt.show()

In [None]:
def eval_image(model, image_path: str, save_path: str = None):
    """
    Forward pass on a single image and keep only the largest connected object.
    The image and the resulting mask are displayed, and the mask is optionally saved.
    :param model: Trained PyTorch model
    :param image_path: Path to the input M0 image (str)
    :param save_path: Path to save mask (str | None)
    :return: torch.Tensor of shape [1,H,W] -> the largest connected object mask
    """
    # --- Predict mask ---
    model.eval()

    # Preprocessing
    transform = transforms.Compose([
        # One pixel is added on the right and on the bottom
        transforms.Pad(padding=(0, 0, 1, 1)),  # /!\ because base image is 1023 x 1023
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor()
    ])

    # Load image as grayscale. The context manager closes the underlying file as
    # soon as the pixel data has been read instead of leaving the handle open.
    with Image.open(image_path) as source:
        img = source.convert("L")
    img_tensor = transform(img).unsqueeze(0).to(device)  # add batch dim

    # Forward pass
    with torch.no_grad():
        pred = model(img_tensor)

    # Remove batch dimension -> shape [1,H,W]
    pred_mask = pred.squeeze(0)

    # Apply the largest object filter
    # NOTE(review): the UResNet of this notebook ends with a plain convolution, so
    # `pred` holds unbounded scores, while keep_largest_connected_object documents
    # inputs in [0,1]. The thresholding is left unchanged here; confirm whether a
    # sigmoid is intended before it.
    largest_mask = keep_largest_connected_object(pred_mask, threshold=0.5)

    # --- Draw ---
    visualize_prediction(img_tensor, largest_mask)

    # --- Save ---
    if save_path is not None:
        save_mask(largest_mask, save_path)

    return largest_mask

## Main

In [None]:
# Build the network and load the pretrained weights stored at MODEL_PATH
trained_model = load_model(MODEL_PATH)

In [None]:
# /!\ Path of the M0 image to scan (edit this before running the next cell)
IMAGE_PATH = "data/optic_disk/train/0676.png"

# /!\ Path the mask image is saved to, or None if you don't want to save it
SAVE_PATH = None

In [None]:
# Predict the mask of IMAGE_PATH, display it and save it to SAVE_PATH when one is given
eval_image(trained_model, IMAGE_PATH, SAVE_PATH)
# Release the cached GPU memory once the prediction is done
free_cache()