In [None]:
import cv2
import albumentations as A
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from sds_playground.datasets import CaDISv2_Dataset
from sds_playground.utils import denormalize, convert_mask_to_RGB, convert_to_binary_mask, convert_to_integer_mask

In [None]:
# Training split: frames are resized to 128x128 and normalised with A.Normalize(mean=0.5, std=0.5).
# NOTE(review): exp=2 presumably selects a CaDIS label-grouping experiment — TODO confirm.
train_ds = CaDISv2_Dataset(
    root='/local/scratch/CaDISv2/',
    spatial_transform=A.Compose([A.Resize(height=128, width=128, interpolation=cv2.INTER_LINEAR)]),
    img_normalization=A.Normalize(mean=0.5, std=0.5),
    exp=2,
    mode='train',
    filter_mislabeled=True,
    sample_mask=True,
    sample_img=True,
)
# One sample per batch, drawn in shuffled order.
train_dl = DataLoader(train_ds, batch_size=1, shuffle=True)

In [None]:
# Size of the training split as constructed (presumably after filter_mislabeled=True drops frames — verify).
len(train_ds)

In [None]:
# Collect every label id that occurs anywhere in the training masks and report it once.
# Printing torch.unique for each batch instead would emit one tensor per batch
# (one per image when batch_size=1), burying the answer in thousands of lines.
# NOTE: this still iterates the entire training split, so expect it to be slow.
train_label_ids = set()
for batch in train_dl:
    # Batches unpack as (img, mask, <unused>, <unused>); index 1 is the integer mask.
    train_label_ids.update(torch.unique(batch[1]).tolist())
print(sorted(train_label_ids))

In [None]:
# Validation split: resized to 128x128 and normalised with A.Normalize(.5, .5).
# shuffle=False: evaluation data gains nothing from shuffling, and a fixed order keeps
# any inspection of val_dl reproducible across runs.
val_ds = CaDISv2_Dataset(
    root='/local/scratch/CaDISv2/',
    spatial_transform=A.Compose([
        A.Resize(128, 128, interpolation=cv2.INTER_LINEAR)
    ]),
    img_normalization=A.Normalize(.5, .5),
    exp=2,
    mode='val',
    filter_mislabeled=True,
    sample_mask=True,
    sample_img=True
)
val_dl = DataLoader(val_ds, shuffle=False, batch_size=16)

In [None]:
# Test split: resized to 128x128 and normalised with A.Normalize(.5, .5).
# shuffle=False: test data gains nothing from shuffling, and a fixed order keeps
# any inspection of test_dl reproducible across runs.
test_ds = CaDISv2_Dataset(
    root='/local/scratch/CaDISv2/',
    spatial_transform=A.Compose([
        A.Resize(128, 128, interpolation=cv2.INTER_LINEAR)
    ]),
    img_normalization=A.Normalize(.5, .5),
    exp=2,
    mode='test',
    filter_mislabeled=True,
    sample_mask=True,
    sample_img=True
)
test_dl = DataLoader(test_ds, shuffle=False, batch_size=16)

In [None]:
# Report the number of samples per split and their sum, one line each.
print(
    f"Total:  {len(train_ds) + len(test_ds) + len(val_ds)}",
    f"Train:  {len(train_ds)}",
    f"Val:  {len(val_ds)}",
    f"Test:  {len(test_ds)}",
    sep="\n",
)

In [None]:
# Grab one shuffled training batch for inspection. train_dl shuffles, and its sampler
# draws its seed from torch's global RNG, so seeding here makes this batch (and
# everything plotted from it below) reproducible across Restart & Run All.
torch.manual_seed(0)
img, int_mask, name, label = next(iter(train_dl))

In [None]:
# Compare the label-id range in this batch with the dataset's declared class count.
print(
    f"int_mask.max()={int_mask.max()!r}",
    f"int_mask.min()={int_mask.min()!r}",
    f"train_ds.num_classes={train_ds.num_classes!r}",
    sep="\n",
)

In [None]:
# Visual sanity check: upsample the batch to the dataset's original resolution and show
# each image (top row) above its colour-coded mask (bottom row).
# squeeze=False keeps `ax` 2-D even when the batch holds a single sample (e.g. with
# batch_size=1), so `ax[row, n]` indexing never hits a 1-D axes array.
fig, ax = plt.subplots(2, img.shape[0], figsize=(img.shape[0] * 3, 2 * 3), squeeze=False)
# NOTE(review): assumes original_shape is (C, H, W), so [1:] is the (H, W) target size.
_img = F.interpolate(img, train_ds.original_shape[1:], mode='bilinear', align_corners=False)
# Nearest-neighbour upsampling keeps label ids intact; interpolate needs a float input
# with a channel dim, hence unsqueeze(1).float() and the cast back to long.
_mask = F.interpolate(int_mask.unsqueeze(1).float(), train_ds.original_shape[1:], mode='nearest')
_mask = _mask.round().squeeze(1).long()
_mask_rgb = convert_mask_to_RGB(_mask, train_ds.get_cmap(), ignore_index=train_ds.ignore_index)
for n in range(img.shape[0]):
    # Undo A.Normalize(.5, .5) and move channels last for imshow.
    ax[0, n].imshow(denormalize(_img[n], .5, .5).permute(1, 2, 0).cpu().numpy())
    ax[0, n].axis('off')
    ax[1, n].imshow(_mask_rgb[n].permute(1, 2, 0).cpu().numpy())
    ax[1, n].axis('off')
fig.tight_layout()
plt.show()

In [None]:
# Round-trip check, step 1: expand the integer label mask into a per-class binary mask,
# printing the shape before and after the conversion.
print(int_mask.size())
binary_mask = convert_to_binary_mask(
    int_mask,
    num_classes=train_ds.num_classes,
    ignore_index=train_ds.ignore_index,
)
print(binary_mask.shape)

In [None]:
# Round-trip check, step 2: collapse the binary mask back into integer label ids.
int_from_binary_mask = convert_to_integer_mask(
    binary_mask,
    ignore_index=train_ds.ignore_index,
)
print(int_from_binary_mask.size())

In [None]:
# Compare the original integer mask (top row) with its integer -> binary -> integer
# reconstruction (bottom row), both upsampled to the dataset's original resolution.
# squeeze=False keeps `ax` 2-D even for a single-sample batch (e.g. batch_size=1),
# so `ax[row, n]` indexing never hits a 1-D axes array.
fig, ax = plt.subplots(2, img.shape[0], figsize=(img.shape[0] * 3, 2 * 3), squeeze=False)

# Nearest-neighbour upsampling keeps label ids intact; interpolate needs a float input
# with a channel dim, hence unsqueeze(1).float() and the cast back to long.
# NOTE(review): assumes original_shape is (C, H, W), so [1:] is the (H, W) target size.
plot_mask = F.interpolate(int_mask.unsqueeze(1).float(), train_ds.original_shape[1:], mode='nearest')
plot_mask = plot_mask.round().squeeze(1).long()
plot_mask_rgb = convert_mask_to_RGB(plot_mask, train_ds.get_cmap(), ignore_index=train_ds.ignore_index)

# Same pipeline applied to the reconstructed mask.
___plot_mask = F.interpolate(int_from_binary_mask.unsqueeze(1).float(), train_ds.original_shape[1:], mode='nearest')
___plot_mask = ___plot_mask.round().squeeze(1).long()
___plot_mask_rgb = convert_mask_to_RGB(___plot_mask, train_ds.get_cmap(), ignore_index=train_ds.ignore_index)
for n in range(img.shape[0]):
    ax[0, n].imshow(plot_mask_rgb[n].permute(1, 2, 0).cpu().numpy())
    ax[0, n].axis('off')
    ax[1, n].imshow(___plot_mask_rgb[n].permute(1, 2, 0).cpu().numpy())
    ax[1, n].axis('off')
fig.tight_layout()
plt.show()

In [None]:
# True iff the integer -> binary -> integer round trip reproduces the mask exactly
# (same size and same elements).
int_mask.equal(int_from_binary_mask)

In [None]:
# Label-id range of the original mask vs. its round-tripped reconstruction.
print(
    int_mask.min(),
    int_mask.max(),
    int_from_binary_mask.min(),
    int_from_binary_mask.max(),
    sep="\n",
)

In [None]:
# Distinct label ids present before and after the round trip.
print(int_mask.unique(), int_from_binary_mask.unique(), sep="\n")