## Visualizing Augmentations

To check whether augmentations actually make sense for the dataset at hand, it is often helpful to visualize them. For that, we supply three visualization tools:

1. Single Image Comparison: This function plots the original image and the augmented image side by side.
2. Augmentation Grid: This function takes a dictionary containing augmentations and plots them in a grid alongside the original image.
3. Single Image Augmentation: This example applies the corruptions that the MedMNIST-C API (di Salvo, Doerrich & Ledig, 2024) provides for the PathMNIST dataset to a single image and saves each corrupted result.

**Note:** *The last step in each augmentation is the conversion into a `torch.Tensor`.*

In [None]:
from domgen.augment import plot_single_augmented, plot_augmented_grid, get_examples
from medmnistc.corruptions.registry import CORRUPTIONS_DS
from pathlib import Path
from PIL import Image
import numpy as np

# Example image that the augmentation plots below are applied to.
image_path = "../imgs/cat.jpg"
# Dictionary of example augmentations keyed by name (e.g. "Solarize"),
# built by the domgen helper.
augmentations = get_examples()

## Plot a Single Image

In [None]:
# Compare the original image with its "Solarize"-augmented version.
plot_single_augmented(
    image_path,
    augmentations["Solarize"],
)

## Plot a Grid of Images

In [None]:
# Plot every example augmentation next to the original, laid out with grid_cols=4.
plot_augmented_grid(
    image_path,
    augmentations,
    grid_cols=4,
)

In [None]:
from domgen.data import PACS
from domgen.augment import pacs_aug, denormalize
import matplotlib.pyplot as plt
from torchvision.utils import make_grid

augment = pacs_aug
# Build PACS with domain index 1 as the test domain, using the pacs_aug pipeline.
data = PACS(
    root="../datasets/",
    test_domain=1,
    augment=augment,
)

# Loaders with 6 images per batch; one batch is rendered per grid below.
train, val, test = data.generate_loaders(batch_size=6)

# Integer label -> class name, used to title the plots.
idx_to_class = data.idx_to_class
print(idx_to_class)

def show(img, label):
    """Display one image tensor in a fresh figure, without axes, under a title.

    Args:
        img: image tensor, assumed to be in (channels, height, width) layout;
            it is permuted to (height, width, channels) for ``plt.imshow``.
        label: title for the plot; passed straight to ``plt.title``.
    """
    # Open a new figure on every call so consecutive calls don't draw on top
    # of each other. The handle itself is not needed.
    plt.figure()
    plt.axis('off')
    plt.imshow(img.permute(1, 2, 0))
    plt.title(label)

In [None]:
# Normalization statistics to undo before display. These are the standard
# ImageNet mean/std; presumably they match the Normalize step in pacs_aug.
# TODO confirm.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Use ONE iterator so each figure shows the next batch. Calling iter(train)
# inside the loop restarts the loader every time (and with a non-shuffling
# loader would show the same batch repeatedly).
batches = iter(train)
for _ in range(3):
    views, labels = next(batches)
    # `views` is a mapping of augmented views; only the first one is shown,
    # so only that one is denormalized.
    first_view = denormalize(next(iter(views.values())), mean=mean, std=std)
    # No scale_each/normalize: make_grid only honours scale_each when
    # normalize=True, and per-image min-max stretching would hide the very
    # brightness/contrast changes we want to inspect.
    grid = make_grid(first_view)
    class_names = [idx_to_class[label.item()] for label in labels]
    show(grid, class_names)

## Plot a Single Image with Corruptions from MedMNIST-C

In [None]:
original_image_path = "../imgs/original/camelyon17/patch_patient_004.png"

# Corruptions registered for PathMNIST; an empty dict if none are registered.
pathmnist_corruptions = CORRUPTIONS_DS.get("pathmnist", {})
if not pathmnist_corruptions:
    print("No MedMNIST-C corruptions registered for 'pathmnist'; nothing to apply.")

output_dir = Path("../imgs/augmented/camelyon17/diSalvo")
output_dir.mkdir(parents=True, exist_ok=True)

# Load the source image once. The context manager closes the file handle, and
# convert() reads the pixel data (returning a copy even if already RGB) before
# the file is closed. Any non-RGB mode (RGBA, L, P, ...) is normalized to RGB.
with Image.open(original_image_path) as source_image:
    original_image_pil = source_image.convert('RGB')

severity = 3
for corruption_name, corruption_instance in pathmnist_corruptions.items():
    print(f"Applying corruption: {corruption_name} with severity {severity}")

    # Hand each corruption its own copy so one that modifies its input in place
    # cannot affect the corruptions applied after it.
    corrupted_image = corruption_instance.apply(original_image_pil.copy(), severity=severity)

    # apply() may return a NumPy array or a PIL image; normalize to PIL for saving.
    # NOTE(review): Image.fromarray assumes a uint8-compatible array - confirm
    # for every registered corruption.
    if isinstance(corrupted_image, np.ndarray):
        corrupted_image_pil = Image.fromarray(corrupted_image)
    else:
        corrupted_image_pil = corrupted_image

    output_path = output_dir / f"{corruption_name.lower().replace(' ', '_')}_sev{severity}.png"
    # Encode as PNG so the file content matches its .png extension (and stays
    # lossless, so no JPEG artefacts are mixed into the corruption being shown).
    corrupted_image_pil.save(output_path, format='PNG')
    print(f"Saved PNG image at: {output_path}")