In [2]:
import os.path
import random
import numpy as np
from PIL import Image

import torch
import torchvision.transforms as transforms

from tqdm import tqdm

In [10]:
def get_transform(method=Image.BICUBIC, normalize=True, flip_prob=0.5, resize=True):
    """Build the torchvision preprocessing pipeline for an image or a label map.

    Args:
        method: PIL interpolation mode handed to ``transforms.Resize``.
        normalize: when True, append ``Normalize`` with mean 0.5 / std 0.5 for
            three channels.
        flip_prob: probability of applying a random horizontal flip. A value
            <= 0 leaves the flip out of the pipeline entirely.
        resize: when True, resize the input to 128x128 first.

    Returns:
        A ``transforms.Compose`` that yields a tensor.

    Note:
        Every pipeline draws its own random flip, so an image and its label
        map built from two separate calls only stay aligned when
        ``flip_prob=0``.
    """
    transform_list = []

    osize = [128, 128]
    if resize:
        transform_list.append(transforms.Resize(osize, method))

    # Honour the caller's probability; this used to be hard-coded to 0.5,
    # which silently ignored flip_prob=0.
    if flip_prob > 0:
        transform_list.append(transforms.RandomHorizontalFlip(p=flip_prob))
    transform_list.append(transforms.ToTensor())

    if normalize:
        transform_list.append(
            transforms.Normalize(
                (0.5, 0.5, 0.5),
                (0.5, 0.5, 0.5)
            ))

    return transforms.Compose(transform_list)

In [41]:
# Class-label PNG for one tile.
# NOTE(review): absolute, machine-specific path; consider deriving it from a configurable data directory.
label_path = "/scratch/as3ek/github/HistoMask/data/segvae/lizard/classes/consep_1__0_0.png"

In [42]:
# Read the label PNG, cast its pixel array to uint8 and wrap it back into a
# PIL image so the torchvision transforms can consume it.
label = Image.fromarray(np.array(Image.open(label_path)).astype(np.uint8))

In [43]:
# Distinct pixel values present in the label image.
np.unique(label)

array([0, 2], dtype=uint8)

In [44]:
# Label pipeline: nearest-neighbour resampling so no new pixel values are
# interpolated, no normalisation, and flip_prob=0 to request that the label
# is not mirrored independently of its image.
transform_label = get_transform(
    flip_prob=0,
    normalize=False,
    method=Image.NEAREST,
)
# Multiply by 255 to undo the tensor conversion's scaling of the uint8 values.
label_tensor = 255.0 * transform_label(label)

In [45]:
# Distinct values in the rescaled label tensor.
torch.unique(label_tensor)

tensor([0., 2.])

In [46]:
# Shape of the label tensor.
label_tensor.size()

torch.Size([1, 128, 128])

In [47]:
image_path = "/scratch/as3ek/github/HistoMask/data/segvae/lizard/images/consep_1__0_0.png"

# Load the tile and force three RGB channels.
image = Image.open(image_path).convert('RGB')

# NOTE: the image and its label must receive the same horizontal flip (or
# none at all); flip_prob=0 is passed here to request no flip.
transform_image = get_transform(
    flip_prob=0,
    normalize=True,
    method=Image.BICUBIC,
)
image_tensor = transform_image(image)



In [48]:
# Shape of the image tensor.
image_tensor.size()

torch.Size([3, 128, 128])

In [52]:
# label
# Integer class-id map. Round before casting: label_tensor is a float that
# was multiplied back up by 255, and a value landing just below an integer
# would otherwise be truncated by .long() to the previous class id.
label = label_tensor.round().long()
# NOTE(review): this rebinds `image` to 0 and looks like a leftover
# placeholder; kept so notebook state is unchanged — confirm it is still needed.
image = 0

# create one-hot label map
label_map = label.unsqueeze(0)  # prepend a batch dimension
bs, _, h, w = label_map.size()
nc = 7 + 1  # presumably 7 foreground classes + background — TODO confirm

# float32 zeros with one channel per class; scatter_ writes 1.0 in place at
# the channel given by each pixel's class id and returns the same tensor.
input_label = torch.zeros(bs, nc, h, w, dtype=torch.float32, device=label_map.device)
input_semantics = input_label.scatter_(1, label_map, 1.0)

# NOTE: set bg as 0.
input_semantics[:, 0] = 0.

# Per-sample indicator of which class channels contain at least one pixel.
label_set = (input_semantics.view(bs, nc, -1).sum(-1) > 0).float()
label_set[:, 0] = 0. # set bg to 0

In [53]:
# Distinct values in `label`.
np.unique(label)

array([0, 2])

In [54]:
# Inspect label_set (one entry per class channel).
label_set

tensor([[0., 0., 1., 0., 0., 0., 0., 0.]])

In [62]:
# Class-label PNG for a second tile.
# NOTE(review): absolute, machine-specific path; consider deriving it from a configurable data directory.
label_path_lizard = "/scratch/as3ek/github/HistoMask/data/segvae/lizard/classes/consep_1__0_372.png"

In [63]:
# Read the second label PNG, cast its pixel array to uint8 and wrap it back
# into a PIL image.
label_lizard = Image.fromarray(
    np.array(
        Image.open(label_path_lizard)
    ).astype(np.uint8)
)

In [64]:
# Distinct pixel values present in the second label image.
np.unique(label_lizard)

array([0, 2, 3, 6], dtype=uint8)

In [65]:
# Run the label pipeline, then multiply by 255 to undo the tensor
# conversion's scaling of the uint8 values.
label_lizard_tensor = 255.0 * transform_label(label_lizard)

In [66]:
# Distinct values in the rescaled second label tensor.
torch.unique(label_lizard_tensor)

tensor([0., 2., 3., 6.])