In [None]:
import yaml
from torch import Tensor
from torchvision import transforms
from torchvision.transforms import Compose, ToTensor, Normalize, Resize
from PIL.Image import NEAREST
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from pytorch_lightning.loggers import TensorBoardLogger
from code.voc_segmentation.datamodules.dataset import VOC2012SegmentationDataset
import matplotlib.pyplot as plt

# Load the project's shared settings file. `safe_load` constructs only plain
# Python objects (dicts, lists, scalars), which is all a config file needs.
# NOTE(review): `parameters` is not referenced by the cells that follow —
# confirm whether values such as batch size / image size should come from it.
with open("/home/haim/code/voc_segmentation/utils/parameters.yaml") as params_fh:
    parameters = yaml.safe_load(params_fh)

In [None]:
# Root of the extracted PASCAL VOC 2012 devkit. Every path below is derived
# from it, so relocating the dataset only requires editing this one line.
VOC_ROOT = "/home/haim/hdd/data/voc/VOCdevkit/VOC2012"

image_dir = f"{VOC_ROOT}/JPEGImages"
# NOTE(review): in the VOC devkit, SegmentationObject holds per-image
# *instance* masks, while per-*class* masks live in SegmentationClass.
# Confirm which one the dataset/model is meant to use.
mask_dir = f"{VOC_ROOT}/SegmentationObject"
train_file = f"{VOC_ROOT}/ImageSets/Segmentation/train.txt"

In [None]:
# Fixed (height, width) for every sample. A common size lets the DataLoader's
# default collate function stack samples into batches.
IMAGE_SIZE = (374, 500)

image_transforms = Compose(
    [
        ToTensor(),  # PIL image -> float (C, H, W) tensor scaled to [0, 1]
        Resize(IMAGE_SIZE, antialias=True),
    ]
)
mask_transforms = Compose(
    [
        # PILToTensor keeps the raw uint8 label ids as a (1, H, W) tensor.
        # ToTensor would rescale uint8 / palette ('P') images to id/255
        # floats, corrupting the class indices.
        # NOTE(review): assumes the dataset hands masks over as PIL images
        # (as the raw-PIL branch of the plotting helper suggests) — confirm.
        transforms.PILToTensor(),
        # Nearest-neighbour resampling never invents intermediate label values.
        # antialias only affects bilinear/bicubic, so it is ignored here.
        Resize(
            IMAGE_SIZE,
            interpolation=transforms.InterpolationMode.NEAREST,
            antialias=True,
        ),
    ]
)
train_dataset = VOC2012SegmentationDataset(
    image_dir=image_dir,
    mask_dir=mask_dir,
    split="train",
    image_transforms=image_transforms,
    mask_transforms=mask_transforms,
    train_file=train_file,
)

In [None]:
# Batches of 4 samples in dataset order, loaded by 3 worker processes.
# NOTE(review): shuffle is off, which keeps batches reproducible for inspection;
# enable it if this loader is reused for actual training.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=4,
    shuffle=False,
    num_workers=3,
)

In [None]:
def plot_image_with_mask(img, mask, alpha=0.6):
    """Display an image with its segmentation mask overlaid semi-transparently.

    Handles both tensor samples (dataset built with transforms) and PIL samples
    (dataset built without transforms).

    Args:
        img: A channels-first ``(C, H, W)`` tensor (as produced by ``ToTensor``),
            or anything ``imshow`` accepts as-is, such as a PIL image.
        mask: A PIL image, or a tensor/ndarray accepted by
            ``transforms.functional.to_pil_image``.
        alpha: Opacity of the mask overlay (0 = invisible, 1 = opaque).
    """
    # Only the PIL.Image class is needed, for the isinstance check below.
    from PIL import Image

    if isinstance(img, Tensor):
        # matplotlib expects channels-last (H, W, C). permute(1, 2, 0) keeps H
        # and W in order; transpose(2, 0) would give (W, H, C) and draw the
        # image transposed relative to the mask.
        img = img.permute(1, 2, 0)
    if not isinstance(mask, Image.Image):
        mask = transforms.functional.to_pil_image(mask)

    # Dispatching on type (instead of catching TypeError) avoids hiding
    # unrelated errors raised mid-plot and drawing the figure twice.
    _, ax = plt.subplots()
    ax.imshow(img)
    ax.imshow(mask.convert("RGB"), alpha=alpha)
    plt.show()

In [None]:
# First (image, mask) sample of the training split, fetched by index.
img, mask = train_dataset[0]

In [None]:
# Pull one batch to check that the worker processes and default collation work.
# Shapes and dtypes are shown as the cell output so the check is visible.
img_batch, mask_batch = next(iter(train_loader))
img_batch.shape, img_batch.dtype, mask_batch.shape, mask_batch.dtype

In [None]:
# Overlay the first sample's mask on its image.
plot_image_with_mask(img=img, mask=mask)