# Inference Test

In [None]:
import argparse
import os
import sys

import torch
import torch.optim as optim

import neptune
from neptune_pytorch import NeptuneLogger
from neptune.utils import stringify_unsupported

import albumentations as A
from PIL import Image, ImageDraw
import numpy as np
import matplotlib.pyplot as plt

sys.path.append(os.path.abspath(".."))
from dataloader.puma_dataset import PumaDataset
from hover_net.dataloader.dataset import get_dataloader
from hover_net.models import HoVerNetExt
from hover_net.process import proc_valid_step_output, train_step, valid_step
from hover_net.tools.utils import (dump_yaml, read_yaml,
                                   update_accumulated_output)

In [None]:
# Build the HoVerNet model described by the CoNSeP config and load trained weights.
config = read_yaml('../configs/consep_config.yaml')

model = HoVerNetExt(
    backbone_name=config["MODEL"]["BACKBONE"],
    pretrained_backbone=config["MODEL"]["PRETRAINED"],
    num_types=config["MODEL"]["NUM_TYPES"]
)
# The model above is constructed without a device, i.e. on the CPU. Mapping the
# checkpoint to the CPU as well means a checkpoint saved from a GPU run can be
# loaded on a machine without CUDA; load_state_dict copies the tensors into the
# model's own parameters, so the loaded weights are identical either way.
state_dict = torch.load('../epoch_7.pth', map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
# Put the model in evaluation mode for inference.
model.eval()

In [None]:
# Smoke test: run the model on a random 256x256 crop of the first ROI image.
IMAGE_DIR = '../data/01_training_dataset_tif_ROIs'

transform = A.Compose([
    A.RandomCrop(width=256, height=256)
])

# os.listdir returns entries in arbitrary order; sort so that "the first
# image" is the same file on every run.
images = sorted(os.listdir(IMAGE_DIR))

image = images[0]
# Use a context manager so the file handle is released. Convert to RGB rather
# than slicing channels afterwards: for a 4-channel (RGBA) PIL array, dropping
# channel 0 would discard *red* and keep alpha.
with Image.open(os.path.join(IMAGE_DIR, image)) as pil_img:
    img = np.array(pil_img.convert("RGB"))
img = transform(image=img)['image']
# HWC array -> 1 x C x H x W float tensor (batch of one, channels first).
img_tensor = torch.tensor(img).permute(2, 0, 1).unsqueeze(0).float()
print(img_tensor.shape)

# No gradients are needed for a forward pass used only for inspection.
with torch.no_grad():
    out = model(img_tensor)

In [None]:
# Plot the first channel of each model output head next to the input crop.
# NOTE(review): only channel 0 is shown for every head; heads such as 'hv'
# presumably carry more than one meaningful channel — confirm what to display.
fig, axes = plt.subplots(1, 4, figsize=(15, 5))

output_heads = [
    ('tp', 'TP Output'),
    ('np', 'NP Output'),
    ('hv', 'HV Output'),
]
for ax, (head_key, head_title) in zip(axes, output_heads):
    # .cpu() lets this work whether the model ran on the CPU or a GPU
    # (.numpy() refuses CUDA tensors); it is a no-op for CPU tensors.
    ax.imshow(out[head_key].detach().cpu().numpy()[0, 0, :, :], cmap='viridis')
    ax.set_title(head_title)

axes[3].imshow(img, cmap='viridis')
axes[3].set_title('Image')

plt.show()

In [None]:
from hover_net.dataloader import get_dataloader
from hover_net.postprocess import process
from hover_net.process import infer_step, visualize_instances_dict
from hover_net.tools.api import parse_single_instance

# Overlay legend: type id -> (label, RGB colour), passed to visualize_instances_dict.
_OVERLAY_TYPE_COLOURS = {
    0: ("nolabe", (0, 0, 0)),
    1: ("neopla", (255, 0, 0)),
    2: ("inflam", (0, 255, 0)),
    3: ("connec", (0, 0, 255)),
    4: ("necros", (255, 255, 0)),
    5: ("no-neo", (255, 165, 0)),
}


def infer_one_image(
    image_path,
    model,
    nr_types=3,
    input_size=(512, 512),
    device="cuda",
    show=False
):
    """Run inference and post-processing on a single image.

    Args:
        image_path: Path of the image; passed to ``get_dataloader`` with
            ``run_mode="inference_single"``.
        model: Network passed to ``infer_step``.
        nr_types: Number of types, forwarded to ``process``.
        input_size: Input shape forwarded to ``get_dataloader``.
        device: Device string forwarded to ``infer_step``.
        show: If True, display each prediction overlaid on its source image.

    Returns:
        A tuple ``(inst_info_dict, detection_list, segmentation_list)``:
        ``inst_info_dict`` is the instance dict returned by ``process`` for the
        last processed output (an empty dict if the dataloader yields nothing),
        and the two lists collect the ``parse_single_instance`` results for
        every instance found.

    Raises:
        ValueError: If the dataloader yields a batch that does not contain
            exactly one image.
    """
    inference_dataloader = get_dataloader(
        data_path=[image_path],
        input_shape=input_size,
        run_mode="inference_single"
    )

    # Initialised up front so the return below is defined even when the
    # dataloader yields no batches.
    inst_info_dict = {}
    detection_list = []
    segmentation_list = []
    for data in inference_dataloader:
        if data.shape[0] != 1:
            raise ValueError(
                f"expected a batch of exactly one image, got {data.shape[0]}"
            )

        test_result_output = infer_step(
            batch_data=data, model=model, device=device
        )
        # A single image is processed, so every instance is tagged image 0.
        image_id = 0
        for single_output in test_result_output:
            _pred_inst, inst_info_dict = process(
                single_output,
                nr_types=nr_types,
                return_centroids=True
            )

            for single_inst_info in inst_info_dict.values():
                detection_dict, segmentation_dict = parse_single_instance(
                    image_id, single_inst_info
                )
                detection_list.append(detection_dict)
                segmentation_list.append(segmentation_dict)

            if show:
                src_image = data[0].numpy()
                overlaid_img = visualize_instances_dict(
                    src_image.copy(),
                    inst_info_dict,
                    draw_dot=True,
                    type_colour=_OVERLAY_TYPE_COLOURS,
                    line_thickness=2,
                )
                plt.imshow(overlaid_img)
                plt.axis("off")
                plt.show()

    return inst_info_dict, detection_list, segmentation_list


In [None]:
# Run the full inference + post-processing pipeline on the first ROI image and
# display the overlaid instance predictions.
IMAGE_DIR = '../data/01_training_dataset_tif_ROIs'

transform = A.Compose([
    A.RandomCrop(width=256, height=256)
])

# Sorted so that images[0] refers to the same file on every run.
images = sorted(os.listdir(IMAGE_DIR))

# Shape sanity check on a random 256x256 crop of the first image. The crop is
# not used below: infer_one_image reads the image from its path instead.
# Convert to RGB instead of slicing channels: for an RGBA array, dropping
# channel 0 would discard red and keep alpha.
image = images[0]
with Image.open(os.path.join(IMAGE_DIR, image)) as pil_img:
    img = np.array(pil_img.convert("RGB"))
img = transform(image=img)['image']
img_tensor = torch.tensor(img).permute(2, 0, 1).unsqueeze(0).float()
print(img_tensor.shape)

# Signature reminder:
#   infer_one_image(image_path, model, nr_types=3, input_size=(512, 512),
#                   device="cuda", show=False)
# Fall back to the CPU when CUDA is unavailable instead of failing in .to().
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
# NOTE(review): nr_types=4 is hard-coded; it presumably should agree with
# config["MODEL"]["NUM_TYPES"] used to build the model — confirm.
infer_one_image(os.path.join(IMAGE_DIR, images[0]), model, nr_types=4,
                input_size=(256, 256), device=device, show=True)