In [None]:
from pathlib import Path
import os

import numpy as np
from PIL import Image
import torch
from tqdm import tqdm
from rtnls_inference import (
    SegmentationEnsemble,
)
from rtnls_inference.utils import decollate_batch

In [None]:
ds_path = Path("../samples/fundus")  # dataset root; input images are read from ds_path / "original"

# Output folders for the predicted masks:
av_path = ds_path / "av"                # artery-vein segmentations
vessels_path = ds_path / "vessels"      # vessel segmentations

# Inference device: use GPU_INDEX when that many GPUs exist (the original setup used
# the second GPU), otherwise fall back to the first GPU, and to the CPU when CUDA is
# unavailable, so the notebook still runs on single-GPU / CPU-only machines.
GPU_INDEX = 1
if torch.cuda.device_count() > GPU_INDEX:
    device = torch.device(f"cuda:{GPU_INDEX}")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# Load models: each ensemble is fetched by its 'Eyened/vascx:<path>' spec via
# `SegmentationEnsemble.from_huggingface`, moved to `device` and set to eval mode.
def _load_ensemble(spec):
    """Load one segmentation ensemble, place it on `device` and switch to eval mode."""
    loaded = SegmentationEnsemble.from_huggingface(spec)
    return loaded.to(device).eval()


ensemble_av = _load_ensemble('Eyened/vascx:artery_vein/av_july24.pt')
ensemble_vessels = _load_ensemble('Eyened/vascx:vessels/vessels_july24.pt')


In [None]:
# Input fundus images. Path.glob('*') yields entries in filesystem order and also
# matches hidden files (e.g. `.DS_Store`) and sub-directories, so keep regular,
# non-hidden files only and sort them for a reproducible processing order.
rgb_paths = sorted(
    p for p in (ds_path / 'original').glob('*')
    if p.is_file() and not p.name.startswith('.')
)
# Fail early here rather than with an IndexError when plotting the masks.
assert rgb_paths, f"no input images found in {ds_path / 'original'}"

In [None]:
# Peek at the inputs: how many images were found plus the first few paths
# (avoids dumping the entire list into the notebook output).
len(rgb_paths), rgb_paths[:5]

In [None]:

# Create dataloader over the input images, built by the AV ensemble.
# NOTE(review): `_make_inference_dataloader` is a private helper (leading
# underscore) of rtnls_inference, so its signature may change between versions.
loader_kwargs = dict(num_workers=8, preprocess=True, batch_size=8)
dataloader = ensemble_av._make_inference_dataloader(rgb_paths, **loader_kwargs)

In [None]:
# Run inference: both ensembles are applied to every batch and the per-pixel argmax
# label maps are kept in memory (`av_masks`, `vessel_masks`, in dataloader order).
# Set SAVE_MASKS = True to additionally write them as PNGs to `av_path` / `vessels_path`.
SAVE_MASKS = False


def predict_label_maps(ensemble, images, ids):
    """Return a list of (id, uint8 label map) pairs for one batch.

    Args:
        ensemble: a loaded segmentation ensemble already on `device`.
        images: the batch image tensor, already on `device`.
        ids: the batch's image ids (as found in ``batch["id"]``).
    """
    with torch.autocast(device_type=device.type):
        proba = ensemble.forward(images)
    proba = torch.mean(proba, dim=1)  # average over models
    proba = torch.permute(proba, (0, 2, 3, 1))  # NCHW -> NHWC
    proba = torch.nn.functional.softmax(proba, dim=-1)

    items = decollate_batch({"id": ids, "image": proba})
    return [
        (item["id"], np.argmax(item["image"], -1).squeeze().astype(np.uint8))
        for item in items
    ]


def save_label_maps(pairs, out_dir):
    """Write each (id, mask) pair to `out_dir / f"{id}.png"`, creating the folder if needed."""
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    for image_id, mask in pairs:
        Image.fromarray(mask).save(out_dir / f"{image_id}.png")


av_masks = []
vessel_masks = []
with torch.no_grad():
    for batch in tqdm(dataloader):
        # Transfer the batch to the device once and reuse it for both ensembles.
        images = batch["image"].to(device)

        av_pairs = predict_label_maps(ensemble_av, images, batch["id"])
        av_masks.extend(mask for _, mask in av_pairs)

        vessel_pairs = predict_label_maps(ensemble_vessels, images, batch["id"])
        vessel_masks.extend(mask for _, mask in vessel_pairs)

        if SAVE_MASKS:
            save_label_maps(av_pairs, av_path)
            save_label_maps(vessel_pairs, vessels_path)


In [None]:
from matplotlib import pyplot as plt

In [None]:
# Artery/vein label map of the first processed image. interpolation="nearest" keeps
# integer class labels distinct; the default antialiasing blends neighbouring labels
# into false intermediate colours when a large mask is shown downscaled.
fig, ax = plt.subplots()
ax.imshow(av_masks[0], interpolation="nearest")
ax.set_title("Artery/vein label map, image 0")
ax.set_axis_off()

In [None]:
# Artery/vein label map of the second processed image. interpolation="nearest" keeps
# integer class labels distinct; the default antialiasing blends neighbouring labels
# into false intermediate colours when a large mask is shown downscaled.
fig, ax = plt.subplots()
ax.imshow(av_masks[1], interpolation="nearest")
ax.set_title("Artery/vein label map, image 1")
ax.set_axis_off()