In [None]:
import os.path
import numpy as np
import matplotlib.pyplot as plt
import torch
from typing import Tuple

from ipynb.fs.full.definitions import latent_vector, OUTPUT_PATH, DEVICE

In [None]:
def plot_real_images(dataloader, _show: bool = False, n_images: int = 9) -> None:
    """Save a grid of randomly chosen real images to ``original_samples.png``.

    Samples are drawn uniformly at random (with replacement) from
    ``dataloader.dataset`` and written under ``OUTPUT_PATH``.

    Args:
        dataloader: Loader whose ``.dataset`` is indexable, has a length, and
            yields ``(image, label)`` pairs. Images are assumed to be
            channel-first tensors (C, H, W) — TODO confirm against the
            dataset transform.
        _show: If True, also display the figure before it is closed.
        n_images: Number of samples to draw. The grid is sized to fit them;
            the default of 9 reproduces the original 3x3 layout.

    Raises:
        ValueError: If ``n_images`` is less than 1.
    """
    if n_images < 1:
        raise ValueError(f"n_images must be >= 1, got {n_images}")
    cols = int(np.ceil(np.sqrt(n_images)))
    rows = -(-n_images // cols)  # ceiling division
    _dataset = dataloader.dataset

    figure, axs = plt.subplots(rows, cols, figsize=(8, 8), squeeze=False)
    try:
        axs = axs.ravel()
        for ax in axs:
            # Also blanks any unused cells of a non-square grid.
            ax.set_axis_off()
        for ax in axs[:n_images]:
            sample_idx = torch.randint(len(_dataset), size=(1,)).item()
            img, _label = _dataset[sample_idx]
            ax.set_title("Malignant Melanoma")
            # imshow expects channels last; move axis 0 to the end.
            ax.imshow(img.permute(1, 2, 0))
        figure.tight_layout(pad=1.02)
        figure.savefig(os.path.join(OUTPUT_PATH, 'original_samples.png'))
        if _show:
            plt.show()
    finally:
        # Always release the figure so repeated calls don't accumulate memory.
        plt.close(figure)

In [None]:
def plot_fake_images(
        generator, n_images: int = 9, _show: bool = False) -> None:
    """Generate ``n_images`` synthetic samples and save them as a grid.

    The grid is written to ``synthetic_samples.png`` under ``OUTPUT_PATH``.
    It is sized to fit exactly ``n_images`` samples; the default of 9 gives
    a 3x3 layout.

    Args:
        generator: Callable invoked as ``generator(gen_z, label)``. It is
            assumed to return a 4-D tensor (N, C, H, W) — TODO confirm.
        n_images: Number of samples to generate and plot.
        _show: If True, also display the figure before it is closed.

    Raises:
        ValueError: If ``n_images`` is less than 1.
    """
    if n_images < 1:
        raise ValueError(f"n_images must be >= 1, got {n_images}")

    gen_z, label, _label_names = __generate_random_inputs(n_images)
    # Inference only: don't build an autograd graph for this forward pass.
    with torch.no_grad():
        gen_images = generator(gen_z, label)
    # Move axis 1 (channels) last, as imshow expects (H, W, C).
    images = gen_images.detach().cpu().numpy().transpose(0, 2, 3, 1)

    cols = int(np.ceil(np.sqrt(n_images)))
    rows = -(-n_images // cols)  # ceiling division
    fig, axs = plt.subplots(rows, cols, squeeze=False)
    try:
        axs = axs.ravel()
        for ax in axs:
            # Also blanks any unused cells of a non-square grid.
            ax.set_axis_off()
        for ax, image, name in zip(axs, images, _label_names):
            ax.set_title(name)
            ax.imshow(image)
        fig.tight_layout(pad=1.04)
        fig.savefig(os.path.join(OUTPUT_PATH, 'synthetic_samples.png'))
        if _show:
            plt.show()
    finally:
        # Always release the figure so repeated calls don't accumulate memory.
        plt.close(fig)

In [None]:
def __generate_random_inputs(n_images: int) \
        -> Tuple[torch.Tensor, torch.Tensor, list]:
    """Build generator inputs for ``n_images`` synthetic samples.

    Args:
        n_images: Number of samples to build inputs for.

    Returns:
        A tuple ``(gen_z, label, _label_names)``:
        - ``gen_z``: standard-normal noise, shape ``(n_images, latent_vector)``,
          on ``DEVICE``.
        - ``label``: one-hot label tensor, shape ``(n_images, 1)``, on
          ``DEVICE``. There is a single column, so every row is ``[1.0]``.
        - ``_label_names``: display name per sample, ``"MM"`` for every
          sample (presumably malignant melanoma — confirm).
    """
    gen_z = torch.randn(n_images, latent_vector, device=DEVICE)
    # A single label column means index 0 is the only valid one-hot position.
    # Fill it in one vectorized op rather than one device write per row.
    label = torch.ones(n_images, 1, device=DEVICE)
    _label_names = ["MM"] * n_images
    return gen_z, label, _label_names