In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np


The Generator Network (G) 🎨
The Generator takes a random noise vector (the latent vector) as input and learns to transform it into a sample that resembles the real data (e.g., an image).

In [None]:
class Generator(nn.Module):
    """Fully connected GAN generator mapping latent noise vectors to images.

    The network is a stack of ``Linear`` layers with LeakyReLU(0.2)
    activations. Batch normalization is applied on the two wider hidden
    layers. A final ``Tanh`` keeps every output value in [-1, 1].

    Args:
        latent_dim: Length of each input noise vector ``z``.
        img_shape: Shape of one generated image as ``(channels, height, width)``.

    Note:
        Because of the ``BatchNorm1d`` layers, calling the model in training
        mode with a batch of size 1 raises an error. Use ``.eval()`` for
        single-sample generation.
    """

    def __init__(self, latent_dim, img_shape):
        super().__init__()
        self.img_shape = img_shape  # (channels, height, width)

        # Hidden activations are 2-D (batch, features), so BatchNorm1d is the
        # correct normalization layer. BatchNorm2d expects 4-D (N, C, H, W)
        # input and would fail in forward().
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2, inplace=True),

            # np.prod multiplies the image dims to get the flattened pixel count.
            nn.Linear(512, int(np.prod(img_shape))),
            nn.Tanh(),  # squashes outputs into [-1, 1]
        )

    def forward(self, z):
        """Generate a batch of images from noise.

        Args:
            z: Noise tensor of shape ``(batch_size, latent_dim)``.

        Returns:
            Tensor of shape ``(batch_size, *img_shape)`` with values in [-1, 1].
        """
        flat = self.model(z)
        # Unflatten each row into the image shape; * unpacks the shape tuple.
        return flat.view(flat.size(0), *self.img_shape)