# CellViT-Hibou Model Usage Example

This notebook showcases the basic usage of the CellViT-Hibou model for segmentation of cell nuclei.

In [None]:
import torch
import numpy as np
import cv2
from torchvision import transforms
from hibou import CellViTHibou
import matplotlib.pyplot as plt

# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Path to the CellViT-Hibou checkpoint.
# NOTE(review): this is a machine-specific absolute path; point it at wherever
# the weights were downloaded before running the notebook.
ckpt_path = "/home/azureuser/cellvit-hibou-l.pth"

# Build the network; its weights are filled in from the checkpoint just below.
model = CellViTHibou(
    hibou_path=None,  # we don't need to load hibou model separately as it is already included in the checkpoint
    num_nuclei_classes=6,
    num_tissue_classes=19,
    embed_dim=1024,
    input_channels=3,
    depth=24,
    num_heads=16,
    extract_layers=[6,12,18,24],
)

# Deserialize onto the CPU first: without `map_location`, a checkpoint saved from
# a GPU fails to load on a CPU-only machine, even though `device` falls back to
# "cpu". The model is moved to the selected device afterwards.
model.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
model.eval()  # inference mode
model = model.to(device)
print("Model loaded successfully.")

# Preprocessing applied to every patch before it is fed to the model:
# tensor conversion, per-channel standardization, then a resize to 256x256.
# NOTE(review): the mean/std values are presumably the statistics the model was
# trained with -- confirm against the training configuration.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.7068, 0.5755, 0.7220], std=[0.1950, 0.2316, 0.1816]),
    transforms.Resize((256, 256))
])

#### Load the test image.

In [None]:
# Read the sample image from disk. OpenCV signals a missing or unreadable file by
# returning None instead of raising, so fail early with a clear message rather
# than letting cvtColor crash with an opaque assertion error.
image_path = "images/sample.png"
orig_image = cv2.imread(image_path, cv2.IMREAD_COLOR)
if orig_image is None:
    raise FileNotFoundError(f"Could not read image file: {image_path}")

# OpenCV loads colour images as BGR; convert to RGB for matplotlib.
orig_image = cv2.cvtColor(orig_image, cv2.COLOR_BGR2RGB)
plt.imshow(orig_image)
plt.axis('off')
plt.show()

#### Segment the image using the CellViT-Hibou model. Draw the segmentation mask on the image.

In [None]:
# Draw on a copy so that `orig_image` stays untouched for later cells.
global_contour_visualization = orig_image.copy()

# Edge length of one tile in pixels; must match the Resize target of `transform`.
patch_size = 256

# Walk over the image in non-overlapping tiles; (i, j) is the top-left corner
# (row, column) of the current tile.
for i in range(0, orig_image.shape[0], patch_size):
    for j in range(0, orig_image.shape[1], patch_size):
        image = orig_image[i:i+patch_size, j:j+patch_size]
        # Tiles on the right/bottom border can be smaller than patch_size.
        patch_h, patch_w = image.shape[:2]
        image = transform(image).unsqueeze(0)  # add the batch dimension
        with torch.no_grad():
            output = model(image.to(device))

        # Convert the raw class scores to probabilities along the channel axis.
        output["nuclei_binary_map"] = output["nuclei_binary_map"].softmax(dim=1)
        output["nuclei_type_map"] = output["nuclei_type_map"].softmax(dim=1)

        # Move every tensor output to the CPU before post-processing.
        for key in output.keys():
            if isinstance(output[key], torch.Tensor):
                output[key] = output[key].cpu()

        (_, instance_types) = model.calculate_instance_map(output, magnification=20)
        cells = instance_types[0]  # detections for the single patch in the batch
        for cell in cells.values():
            contour = cell["contour"]
            if (patch_h, patch_w) != (patch_size, patch_size):
                # `transform` stretched this partial tile to patch_size x patch_size,
                # so the contour is in the stretched coordinate frame. Scale it
                # back to the tile's real size, otherwise border contours are
                # misplaced in the full image.
                contour = contour * np.array([patch_w / patch_size, patch_h / patch_size])
                contour = np.rint(contour).astype(np.int32)
            # Shift from tile coordinates to full-image coordinates (x by j, y by i).
            contour = contour + np.array([j, i])
            cv2.drawContours(global_contour_visualization, [contour], -1, (255, 0, 0), 3)

#### Visualize the results.

In [None]:

# Blend the untouched image and the contour-annotated copy with equal weights,
# so the outlines appear semi-transparent over the tissue.
orig_im_vis = cv2.addWeighted(
    orig_image, 0.5,
    global_contour_visualization, 0.5,
    0,
)

# Render the overlay in the left slot of a 1x2 grid on a wide figure.
fig = plt.figure(figsize=(21, 11))
ax = fig.add_subplot(1, 2, 1)
ax.set_title('Original Image with Contours')
ax.imshow(orig_im_vis)
ax.axis('off')

plt.show()

#### In this simple example we used the model only to find the cell nuclei in the image. The model is also trained to predict nuclei types and tissue types. 

We can inspect all the outputs for a single cell, as well as the predicted tissue type, for the last patch processed above. For the mapping between cell/tissue types and their numerical labels, see the `configs/dataset_config.yaml` file.

In [None]:
# `cells` and `output` hold the results of the last patch processed by the
# segmentation loop. That patch may contain no detected nuclei, so use a default
# instead of letting next() raise StopIteration on an empty dict.
cell = next(iter(cells.values()), None)

print("Cell information:")
if cell is None:
    print("No cells were detected in the last processed patch.")
else:
    for key, value in cell.items():
        print(f"{key}: {value}")

# there is also a tissue type map in the output dictionary
print("\nTissue information:")
# argmax over dim=1 picks the index of the highest-scoring tissue class
print(f"Tissue type: {output['tissue_types'].argmax(dim=1).item()}")