In [5]:
import torch
from torch import Tensor
from torch.nn import Module, BCELoss
from torch.optim import Adam, Optimizer

from torchvision.datasets import VOCSegmentation
import torchvision.transforms as transforms
import torchvision.transforms.functional as func
import torchvision.utils

from torch.utils.data import Dataset, DataLoader

from typing import Tuple

def convert_target_pil_to_tensor(pil_img) -> Tensor:
    """Convert the target mask from pillow image to tensor.

    Parameters
    ----------
    pil_img : PIL.Image
        The segmentation mask as a pillow image.

    Returns
    -------
    target : Tensor
        The segmentation mask as a (C, H, W) Tensor.

    Notes
    -----
    If:
        grey[i, j] = 0, target[:, i, j] = [1, 0, 0, ...]
        grey[i, j] = 1, target[:, i, j] = [0, 1, 0, ...]
        grey[i, j] = 2, target[:, i, j] = [0, 0, 1, ...],

        etc.

    """
    grey = func.pil_to_tensor(pil_img).squeeze()
    grey[grey == 255] = 21
    num_classes = 22
    target = torch.eye(num_classes)[grey.long()].permute(2, 0, 1).float()
    return target

'''
randcrop_images = transforms.Compose(
    [
        transforms.RandomResizedCrop(size=572,scale=(0.25,1.0)),
        transforms.ToTensor(),
    ]
)

randcrop_targets = transforms.Compose(
    [
        transforms.RandomResizedCrop(size=572,scale=(0.25,1.0)),
        convert_target_pil_to_tensor, 
        # Resizing because images all different sizing. Could instead pad or make custom collate.
        # Set to 572 x 572 to match original UNet paper
        #transforms.Resize([572, 572]),
    ]
)
'''

img_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        # Resizing because images all different sizing. Could instead pad or make custom collate.
        # Set to 572 x 572 to match original UNet paper
        transforms.Resize([572, 572]),
    ]
)

target_transforms = transforms.Compose(
    # Set to 572 x 572 to match original UNet paper
    [convert_target_pil_to_tensor, transforms.Resize([572, 572])]
)


def get_data_set_and_loader(img_set) -> Tuple[Dataset, DataLoader]:
    """Return a dataset and dataloader to use in training/validation.

    Parameters
    ----------
    args : Namespace
        Command-line arguments.
    img_set : Image Set to use. 
        "train" or "val"

    Returns
    -------
    data_set : Dataset
        The requested dataset.
    data_loader : DataLoader
        The requested dataloader.

    """
    if img_set == 'train':
        shuffle_img = True
    elif img_set == 'val':
        shuffle_img = False
    else:
        raise ValueError(f"Image set option {img_set} is not acceptable.")


    data_set = VOCSegmentation(
        "data",
        #image_set="train",
        image_set=img_set,
        download=False,
        transform=img_transforms,
        target_transform=target_transforms,
    )

    #data_set = data_subset(args, data_set)

    data_loader = DataLoader(
        data_set, 
        batch_size=4, 
        shuffle=False,
        num_workers=0,
    )

    return data_set, data_loader

dataset, loader = get_data_set_and_loader("train")

import matplotlib.pyplot as plt
import numpy as np

# Helper function for inline image display

def matplotlib_imshow(img, one_channel=False):
    if one_channel:
        img = img.mean(dim=0)
    img = img / 2 + 0.5  # unnormalize
    npimg = img.numpy()
    if one_channel:
        plt.imshow(npimg, cmap="Greys")
    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0)))


augment_transforms = transforms.Compose(
    # Set to 572 x 572 to match original UNet paper
    [   
        transforms.RandomHorizontalFlip(),
        transforms.RandomResizedCrop(572, scale=(0.25, 1.0)
    ]
)

def data_augmenter(images, targets):
    """Transform a batch of images and targets using a random resized crop"""
    img_concat = torch.cat((images, targets), dim=1)
    augmented = augment_transforms(img_concat)
    images_new = augmented[:, :3, :, :]
    targets_new = augmented[:, 3:, :, :]

    return images_new, targets_new


dataiter = iter(loader)
images, labels = dataiter.next()

images, labels = data_augmenter(images, labels)

print(images.size())
print(labels.size())

#img_concat = torch.cat((images, labels), dim=1)

#augmented = transforms.RandomResizedCrop(572, scale=(0.25, 1.0))(img_concat)

#images_new = augmented[:, :3, :, :]
#labels_new = augmented[:, 3:, :, :]

# Create a grid from the images and show them
img_grid = torchvision.utils.make_grid(images)
matplotlib_imshow(img_grid, one_channel=False)

labels = labels.argmax(dim=1).unsqueeze(dim=1).float()
labels.size()

img_grid = torchvision.utils.make_grid(labels)
matplotlib_imshow(img_grid, one_channel=True)

SyntaxError: closing parenthesis ']' does not match opening parenthesis '(' on line 146 (1829623210.py, line 147)