## Visualizing the Dataset

First, we create the dataset handler class, either `PACS` or `Camelyon17` depending on the value of `name`. Afterwards, we generate the training, validation, and test dataloaders. The training and validation loaders contain the concatenated datasets of all domains except the held-out `test_domain`. We set a `batch_size` of 6 to keep the visualized grids small and readable.
The label mappings are stored directly in the handler class, so a single mapping is shared by all domains. This assumes that there is no label shift across domains.
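The split described above is a leave-one-domain-out scheme. The following is a minimal sketch of the idea using plain Python lists instead of the `domgen` classes. The function `leave_one_domain_out`, the toy samples, the `toy_*` names, and the 20% validation fraction are all hypothetical and are not taken from `domgen`; only the PACS domain and class names are real.

```python
import random
from typing import Dict, List, Tuple

# Hypothetical sample type: (image_id, class_index)
Sample = Tuple[str, int]


def leave_one_domain_out(
    domains: Dict[str, List[Sample]],
    test_domain: int,
    val_fraction: float = 0.2,
    seed: int = 0,
) -> Tuple[List[Sample], List[Sample], List[Sample]]:
    """Hold out the domain at index `test_domain`; split the rest into train/val."""
    names = list(domains)
    held_out = names[test_domain]
    # Concatenate all remaining domains into a single pool
    pool = [s for name in names if name != held_out for s in domains[name]]
    random.Random(seed).shuffle(pool)
    n_val = int(len(pool) * val_fraction)
    val, train = pool[:n_val], pool[n_val:]
    test = list(domains[held_out])
    return train, val, test


# One index-to-class mapping shared by every domain (no label shift assumed)
toy_idx_to_class = {0: "dog", 1: "horse"}

toy_domains = {
    "art_painting": [("a0", 0), ("a1", 1)],
    "cartoon": [("c0", 0), ("c1", 1), ("c2", 1)],
    "photo": [("p0", 1), ("p1", 0)],
    "sketch": [("s0", 0), ("s1", 1), ("s2", 0)],
}
toy_train, toy_val, toy_test = leave_one_domain_out(toy_domains, test_domain=0)
print(len(toy_train), len(toy_val), len(toy_test))  # 7 1 2
print([toy_idx_to_class[c] for _, c in toy_test])   # ['dog', 'horse']
```

The `toy_` prefixes avoid overwriting the notebook's real `train`, `val`, `test`, and `idx_to_class` variables if the sketch is run in the same kernel.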

Lastly, we define a custom `show()` function that displays an image grid with the corresponding labels as its title. Below, we call it on three batches from the training set.

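```python
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from domgen.data import PACS, Camelyon17
from domgen.augment import denormalize
from domgen.eval import plot_domain_images
import albumentations as A

# Identity augmentation, so the visualized images are not modified
augment = {"noop": A.NoOp(p=1)}
name = "Camelyon"  # set to "PACS" to inspect the PACS dataset instead
if name == "PACS":
    data = PACS('../datasets/', test_domain=0, augment=augment)
else:
    data = Camelyon17('../datasets/', test_domain=0, augment=augment)

train, val, test = data.generate_loaders(batch_size=6)

# Index-to-class mapping shared by all domains
idx_to_class = data.idx_to_class
print(idx_to_class)

def show(img, labels):
    """Display a (C, H, W) image grid with its labels as the title."""
    plt.figure()
    plt.axis('off')
    # Clip to [0, 1] and move channels last (H, W, C), as matplotlib expects
    plt.imshow(img.clamp(0, 1).permute(1, 2, 0).cpu().numpy())
    plt.title(", ".join(map(str, labels)))
    plt.show()
```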

### Visualize Grids with Labels

We visualize batches from the training and test sets. Since the images are normalized, we need to denormalize them before plotting.
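Denormalization inverts the per-channel standardization $x_{norm} = (x - \mu) / \sigma$, i.e. $x = x_{norm} \cdot \sigma + \mu$. Below is a minimal NumPy sketch, assuming the ImageNet mean and standard deviation passed to `denormalize` in the cells here and a channel-first (C, H, W) layout. `denormalize_chw` and `to_displayable` are hypothetical helpers; `domgen.augment.denormalize` may be implemented differently.

```python
import numpy as np

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def denormalize_chw(x, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Invert x_norm = (x - mean) / std for a (C, H, W) array, channel by channel."""
    mean = np.asarray(mean, dtype=x.dtype).reshape(-1, 1, 1)
    std = np.asarray(std, dtype=x.dtype).reshape(-1, 1, 1)
    return x * std + mean


def to_displayable(x):
    """Move channels last (H, W, C) and clip to the [0, 1] range imshow expects for floats."""
    return np.clip(np.transpose(x, (1, 2, 0)), 0.0, 1.0)


# Round trip: normalize a random image, then undo the normalization
rng = np.random.default_rng(0)
img = rng.random((3, 4, 5), dtype=np.float32)
mean = np.asarray(IMAGENET_MEAN, dtype=np.float32).reshape(-1, 1, 1)
std = np.asarray(IMAGENET_STD, dtype=np.float32).reshape(-1, 1, 1)
normalized = (img - mean) / std
restored = denormalize_chw(normalized)
print(np.allclose(restored, img, atol=1e-6))  # True
print(to_displayable(restored).shape)         # (4, 5, 3)
```

Reshaping the statistics to (C, 1, 1) lets NumPy broadcast them over the spatial dimensions, so each channel gets its own mean and standard deviation.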

First, we visualize batches from the training set, which should contain no images from the test domain.


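```python
# Create the iterator once so that every next() call yields a new batch
train_iter = iter(train)
for _ in range(3):
    images, labels = next(train_iter)
    # Training batches are dicts of augmented views; denormalize each view
    images = [denormalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) for img in images.values()]
    # Plot the first view; note that scale_each only takes effect with normalize=True
    grid = make_grid(images[0], scale_each=True)
    labels = [idx_to_class[label.item()] for label in labels]
    show(grid, labels)
```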
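Calling `next(iter(loader))` inside a loop builds a fresh iterator on every call. For an unshuffled loader this returns the same first batch every time, and for a PyTorch `DataLoader` with worker processes (and without `persistent_workers`) each new iterator also starts new workers. The stdlib sketch below uses a plain list as a stand-in for an unshuffled loader.

```python
# Stand-in for an unshuffled data loader yielding three "batches"
batches = [["b0"], ["b1"], ["b2"]]

# Pitfall: a new iterator on every call always restarts at the first batch
restarted = [next(iter(batches)) for _ in range(3)]

# Fix: create the iterator once and keep advancing it
it = iter(batches)
advanced = [next(it) for _ in range(3)]

print(restarted)  # [['b0'], ['b0'], ['b0']]
print(advanced)   # [['b0'], ['b1'], ['b2']]
```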

Next, we visualize a batch from the test set, which should contain only images from the test domain:

```python
images, labels = next(iter(test))
# Test batches are plain tensors, so each image is denormalized individually
images = [denormalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) for img in images]
grid = make_grid(images, scale_each=True)
labels = [idx_to_class[label.item()] for label in labels]
show(grid, labels)
```
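With a `batch_size` of 6 and the `make_grid` defaults (`nrow=8`, `padding=2`), every batch fits into a single row, so the labels in the title read left to right in the same order as the images. The sketch below reproduces, in plain Python, the shape arithmetic torchvision's `make_grid` uses for equally sized images. `grid_shape` is a hypothetical helper, and the 96 × 96 input size is only an example.

```python
import math
from typing import Tuple


def grid_shape(n: int, c: int, h: int, w: int,
               nrow: int = 8, padding: int = 2) -> Tuple[int, int, int]:
    """(C, H, W) of the grid make_grid builds from n images of shape (c, h, w)."""
    if c == 1:
        c = 3  # single-channel images are repeated to three channels
    if n == 1:
        return (c, h, w)  # a single image is returned without padding
    cols = min(nrow, n)          # images per row
    rows = math.ceil(n / cols)   # number of rows
    return (c, rows * (h + padding) + padding, cols * (w + padding) + padding)


print(grid_shape(6, 3, 96, 96))   # (3, 100, 590): one row of six images
print(grid_shape(10, 3, 96, 96))  # (3, 198, 786): two rows, eight columns
```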

### Inspecting Class Images from the Different Domains

To better understand how the domains differ, we plot sample images of every class from each domain.
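The internals of `plot_domain_images` are not shown here. One plausible way to pick the images, assuming a hypothetical `<dataset_path>/<domain>/<class>/` folder layout, is to sample `num_images` files at random from each class folder. `sample_class_images` below is a hypothetical stdlib sketch of that step, not the `domgen` implementation.

```python
import random
import tempfile
from pathlib import Path
from typing import List


def sample_class_images(dataset_path: str, domain: str, class_name: str,
                        num_images: int = 5, seed: int = 0) -> List[Path]:
    """Randomly pick up to `num_images` files from <dataset_path>/<domain>/<class_name>/."""
    class_dir = Path(dataset_path) / domain / class_name
    # Sort first so the sample is reproducible for a fixed seed
    files = sorted(p for p in class_dir.iterdir() if p.is_file())
    k = min(num_images, len(files))
    return random.Random(seed).sample(files, k)


# Usage on a throwaway directory tree with the assumed layout
with tempfile.TemporaryDirectory() as root:
    class_dir = Path(root) / "sketch" / "dog"
    class_dir.mkdir(parents=True)
    for i in range(8):
        (class_dir / f"img_{i}.png").touch()
    picked = sample_class_images(root, "sketch", "dog", num_images=5)
    print(len(picked), all(p.parent == class_dir for p in picked))  # 5 True
```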

```python
dataset_path = data.dir
domains = data.domains
classes = data.classes
num_images = 5

for domain in domains:
    for class_name in classes:
        plot_domain_images(dataset_path, domain, class_name, num_images=num_images)
```