### Imports

In [6]:
import torch
import torch.nn as nn
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import torch.nn.functional as F

### Defining the models

In [7]:
class Generator(nn.Module):
    """DCGAN-style generator that maps noise vectors to images.

    Args:
        input_dim: length of each input noise vector.
        image_channel: number of channels in the generated image.
        hidden_dim: base channel width; inner blocks use multiples of it.

    With the default kernel sizes/strides each ConvTranspose2d grows the
    spatial size as (H - 1) * stride + kernel_size, i.e. 1 -> 3 -> 6 -> 13 -> 28,
    so the output has shape (N, image_channel, 28, 28). The final Tanh keeps
    values in [-1, 1].
    """

    def __init__(self, input_dim=10, image_channel=1, hidden_dim=64):
        # nn.Module.__init__ must run before any submodule is assigned.
        super().__init__()

        self.input_dim = input_dim
        self.gen = nn.Sequential(
            self._generator_block(input_dim, hidden_dim * 4),
            self._generator_block(
                hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=1
            ),
            self._generator_block(hidden_dim * 2, hidden_dim),
            self._generator_block(
                hidden_dim, image_channel, kernel_size=4, stride=2, final_layer=True
            ),
        )

    def _generator_block(
        self,
        input_channels,
        output_channels,
        kernel_size=3,
        stride=2,
        final_layer=False,
    ):
        """Build one upsampling stage.

        Inner stages are ConvTranspose2d -> BatchNorm2d -> ReLU. The final
        stage is ConvTranspose2d -> Tanh, with no normalisation, so the
        output lands in [-1, 1].
        """
        if not final_layer:
            return nn.Sequential(
                nn.ConvTranspose2d(
                    input_channels, output_channels, kernel_size, stride
                ),
                nn.BatchNorm2d(output_channels),
                nn.ReLU(inplace=True),
            )
        else:
            return nn.Sequential(
                nn.ConvTranspose2d(
                    input_channels, output_channels, kernel_size, stride
                ),
                nn.Tanh(),
            )

    def forward(self, noise):
        """Generate a batch of images.

        Args:
            noise: tensor with ``len(noise)`` samples, each holding
                ``input_dim`` values (e.g. shape (N, input_dim)).

        Returns:
            Image tensor of shape (N, image_channel, H, W).
        """
        # Treat each noise vector as a 1x1 feature map with input_dim channels.
        x = noise.view(len(noise), self.input_dim, 1, 1)
        return self.gen(x)

In [8]:
def create_noise_vector(n_samples, input_dim, device='cpu'):
    """Sample a batch of standard-normal noise vectors.

    Args:
        n_samples: number of vectors to draw.
        input_dim: length of each vector.
        device: torch device (string or ``torch.device``) for the result.
            ``'gpu'`` is accepted as an alias for ``'cuda'``.

    Returns:
        Tensor of shape (n_samples, input_dim) drawn from N(0, 1).
    """
    # 'gpu' is not a valid torch device type (torch raises RuntimeError);
    # map it to 'cuda' so callers passing it explicitly keep working.
    if isinstance(device, str) and device == 'gpu':
        device = 'cuda'
    return torch.randn(n_samples, input_dim, device=device)

In [9]:
class Discriminator(nn.Module):
    """DCGAN-style discriminator that scores images.

    Args:
        image_channel: number of channels in the input images.
        hidden_dim: base channel width; the second block uses twice as many.

    Each Conv2d (kernel 4, stride 2, no padding) shrinks the spatial size as
    (H - 4) // 2 + 1, so a 28x28 input goes 28 -> 13 -> 5 -> 1.
    """

    def __init__(self, image_channel=1, hidden_dim=64):
        super().__init__()
        self.disc = nn.Sequential(
            self._discriminator_block(
                input_channels=image_channel, output_channels=hidden_dim
            ),
            self._discriminator_block(
                input_channels=hidden_dim, output_channels=hidden_dim * 2
            ),
            self._discriminator_block(
                input_channels=hidden_dim * 2, output_channels=1, final_layer=True
            ),
        )

    def _discriminator_block(
        self,
        input_channels,
        output_channels,
        kernel_size=4,
        stride=2,
        final_layer=False,
    ):
        """Build one downsampling stage.

        Inner stages are Conv2d -> BatchNorm2d -> LeakyReLU(0.2). The final
        stage is a bare Conv2d, so the network emits unnormalised scores.
        """
        if not final_layer:
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.LeakyReLU(0.2, inplace=True),
            )
        else:
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride),
            )

    def forward(self, image):
        """Score a batch of images.

        Args:
            image: tensor of shape (N, image_channel, H, W).

        Returns:
            Tensor of shape (N, k) holding raw scores (no sigmoid is applied,
            so presumably paired with a with-logits loss; confirm at the call
            site). k == 1 for 28x28 inputs.
        """
        disc_pred = self.disc(image)
        # Flatten everything except the batch dimension.
        return disc_pred.view(len(disc_pred), -1)

### Defining utility functions

In [None]:
# Seed PyTorch's global RNG so later random draws (noise, weight init) repeat across runs.
torch.manual_seed(seed=0)


def plot_images_from_tensor(
    image_tensor, num_images=25, size=(1, 28, 28), nrow=5, show=True
):
    """Plot the first ``num_images`` images of a batch as a grid.

    Pixel values are mapped from [-1, 1] (the Tanh output range) to [0, 1]
    before plotting.

    Args:
        image_tensor: batch of images. A 4-D tensor is used as (N, C, H, W)
            unchanged; any other rank is treated as flattened and reshaped
            to ``(-1, *size)``.
        num_images: how many images from the front of the batch to show.
        size: per-image (C, H, W) shape, used only to un-flatten non-4-D input.
        nrow: number of images per grid row.
        show: call ``plt.show()`` after drawing; pass False to defer it
            (e.g. when composing several plots).

    Notes:
        - PyTorch stores images as (C, H, W) while Matplotlib expects
          (H, W, C), hence ``permute(1, 2, 0)``; on a NumPy array
          ``np.transpose(grid, (1, 2, 0))`` does the same.
        - ``detach()`` returns a tensor cut off from the autograd graph. That
          is needed here because Matplotlib converts its input to a NumPy
          array, and ``Tensor.numpy()`` refuses tensors that require grad.
          (It is *not* needed for ``cpu()``/``to()``, which work on
          graph-attached tensors.)
        - ``cpu()`` copies CUDA tensors to host memory for Matplotlib.
    """
    image_tensor = (image_tensor + 1) / 2
    image_unflat = image_tensor.detach().cpu()
    if image_unflat.dim() != 4:
        # Flattened input: restore the per-image (C, H, W) layout.
        image_unflat = image_unflat.reshape(-1, *size)
    image_grid = make_grid(image_unflat[:num_images], nrow=nrow)
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    if show:
        plt.show()


def weights_init(m):
    """Initialise conv and batch-norm layers in place.

    Takes a single module, so it can be passed to ``nn.Module.apply``.

    - Conv2d / ConvTranspose2d: weights drawn from N(0, 0.02); biases keep
      PyTorch's default initialisation.
    - BatchNorm2d: weights drawn from N(0, 0.02), biases set to 0. Layers
      built with ``affine=False`` have no weight/bias and are skipped.

    All other module types are left untouched.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, 0.0, 0.02)
    # affine=False BatchNorm has weight/bias == None; init would crash on it.
    if isinstance(m, nn.BatchNorm2d) and m.affine:
        # NOTE(review): the PyTorch DCGAN tutorial draws BN weights from
        # N(1.0, 0.02); mean 0.0 is kept here to preserve existing behaviour.
        nn.init.normal_(m.weight, 0.0, 0.02)
        nn.init.constant_(m.bias, 0)

def ohe_vector_from_labels(labels, n_classes):
    """Return one-hot encodings of integer class labels.

    Args:
        labels: int64 tensor of class indices (any shape), each in
            ``[0, n_classes)``.
        n_classes: total number of classes; sets the size of the new last
            dimension.

    Returns:
        int64 tensor of shape ``(*labels.shape, n_classes)``.
    """
    # F.one_hot's keyword is ``num_classes``; passing ``n_classes=`` raises TypeError.
    return F.one_hot(labels, num_classes=n_classes)