In [None]:
from hubmap.data import DATA_DIR
from hubmap.dataset import TrainDataset, ValDataset
import hubmap.dataset.transforms as T
from torch.utils.data import DataLoader
import torch
import torch
import matplotlib.pyplot as plt

import numpy as np
from collections import Counter
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import torch

import scienceplots as _

from hubmap.visualization.visualize_mask import mask_to_rgba
from hubmap.dataset import label2id, label2title

In [None]:
def plot_example(image, target, other_image, other_target, alpha=0.6):
    """Plot two samples side by side, each raw and with its mask overlaid.

    Produces a 1x4 figure: ``image``, ``image`` + ``target`` overlay,
    ``other_image``, ``other_image`` + ``other_target`` overlay, plus a
    shared legend for the blood vessel, glomerulus and unsure classes.

    Args:
        image: First image tensor. It is permuted with ``(1, 2, 0)`` for
            display, so presumably channel-first ``(C, H, W)`` -- TODO confirm.
        target: Mask for ``image``; converted to an RGBA overlay with
            ``mask_to_rgba``.
        other_image: Second image tensor, same layout as ``image``.
        other_target: Mask for ``other_image``.
        alpha: Overlay opacity forwarded to ``mask_to_rgba``.

    Returns:
        The matplotlib ``Figure`` (it is neither saved nor closed here).
    """
    # NOTE: plt.style.use changes the *global* style, so it also applies to
    # later figures and to savefig calls made outside this function.
    plt.style.use(["science", "nature"])

    fig, axs = plt.subplots(nrows=1, ncols=4, figsize=(7, 5))

    colors = {
        "blood_vessel": "tomato",
        "glomerulus": "dodgerblue",
        "unsure": "palegreen",
        "background": "black",
    }
    # Map class ids (via label2id) to colours for mask_to_rgba.
    cmap = {label2id[name]: color for name, color in colors.items()}

    def _to_hwc(tensor):
        # Channel-first tensor -> channel-last numpy array for imshow;
        # detach/cpu so grad-tracking or GPU tensors can be shown as well.
        return tensor.detach().cpu().permute(1, 2, 0).numpy()

    samples = [
        (image, target, "First Example"),
        (other_image, other_target, "Second Example"),
    ]
    # Each sample occupies two adjacent axes: raw image, then image + overlay.
    for raw_ax, overlay_ax, (img, mask, title) in zip(axs[0::2], axs[1::2], samples):
        img_np = _to_hwc(img)
        mask_rgba = mask_to_rgba(mask, color_map=cmap, bg_channel=3, alpha=alpha)

        raw_ax.imshow(img_np)
        raw_ax.set_title(title)
        overlay_ax.imshow(img_np)
        overlay_ax.imshow(mask_rgba.permute(1, 2, 0))
        overlay_ax.set_title("Ground Truth Mask")

    for ax in axs:
        ax.set_xticks([])
        ax.set_yticks([])

    # Background is intentionally left out of the legend.
    handles = [
        mpatches.Patch(
            facecolor=colors[name],
            label=str(label2title[name]),
            edgecolor="black",
        )
        for name in ("blood_vessel", "glomerulus", "unsure")
    ]
    fig.legend(
        handles=handles,
        loc="upper center",
        bbox_to_anchor=(0.5, 0.325),
        ncol=len(handles),  # all entries on one row
        frameon=False,
    )
    fig.tight_layout()
    return fig

In [None]:
# Spatial size every sample is resized to (shared by both pipelines).
IMAGE_SIZE = (512, 512)

# Training pipeline: tensor conversion + resize, then random flips.
# NOTE(review): mask_as_integer=False presumably keeps masks as per-class
# channels -- the selection cell below takes argmax over dim 0. Confirm.
transforms_augment = T.Compose([
    T.ToTensor(mask_as_integer=False),
    T.Resize(IMAGE_SIZE),
    T.RandomHorizontalFlip(),
    T.RandomVerticalFlip(),
])

# Validation pipeline: same preprocessing, no augmentation.
transforms_val = T.Compose([
    T.ToTensor(mask_as_integer=False),
    T.Resize(IMAGE_SIZE),
])

train = TrainDataset(DATA_DIR, transform=transforms_augment, with_background=True, as_id_mask=False)
val = ValDataset(DATA_DIR, transform=transforms_val, with_background=True, as_id_mask=False)

batch_size = 6
train_loader = DataLoader(train, batch_size=batch_size, shuffle=True)
# Shuffling validation data gains nothing and makes any result drawn from it
# non-reproducible; keep its order fixed.
val_loader = DataLoader(val, batch_size=batch_size, shuffle=False)

In [None]:
import random

# Seed every RNG right before iterating. The loader shuffles and the training
# transforms flip randomly, so without a seed the chosen examples (and the
# saved figure) change on every run. Seeding here also lets this cell be
# re-run on its own and pick the same samples.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

N_EXAMPLES = 2  # plot_example shows exactly two samples
# Class ids that must all appear in a sample's mask.
# NOTE(review): presumably the three foreground classes, with background on
# channel 3 (cf. bg_channel=3 in plot_example) -- confirm against label2id.
REQUIRED_CLASSES = (0, 1, 2)

selected_images = []
selected_masks = []

for images, masks in train_loader:
    for image, mask in zip(images, masks):
        # Reduce the mask to per-pixel class ids via argmax over dim 0.
        class_ids = torch.argmax(mask, dim=0)
        # torch.unique avoids building a Python Counter over every pixel.
        present = set(torch.unique(class_ids).tolist())
        if present.issuperset(REQUIRED_CLASSES):
            selected_images.append(image)
            selected_masks.append(mask)
            if len(selected_images) >= N_EXAMPLES:
                break
    # Stop immediately instead of loading one more batch just to break.
    if len(selected_images) >= N_EXAMPLES:
        break

# Fail here with a clear message rather than an IndexError in the next cell.
assert len(selected_images) >= N_EXAMPLES, (
    f"found only {len(selected_images)} training sample(s) whose mask "
    f"contains all of classes {REQUIRED_CLASSES}"
)

In [None]:
# Unpack the two selected samples (image tensor + mask) for plotting.
image1, image2 = selected_images[0], selected_images[1]
target1, target2 = selected_masks[0], selected_masks[1]

In [None]:
# Render both selected samples with 60%-opaque overlays and export as SVG.
ex = plot_example(
    image=image1,
    target=target1,
    other_image=image2,
    other_target=target2,
    alpha=0.6,
)
ex.savefig("examples_with_labels_alpha.svg", format="svg")