In [1]:
import torch  # IMPORTING TORCH LIBRARY
from torch import nn  # IMPORTING NEURAL NETWORK MODULE
from tqdm.auto import tqdm  # IMPORTING TQDM FOR PROGRESS BARS
from torchvision import transforms  # IMPORTING TRANSFORMS FROM TORCHVISION
from torchvision.datasets import MNIST  # IMPORTING MNIST DATASET
from torchvision.utils import make_grid  # IMPORTING MAKE_GRID TO CREATE IMAGE GRIDS
from torch.utils.data import DataLoader  # IMPORTING DATALOADER
import matplotlib.pyplot as plt  # IMPORTING MATPLOTLIB FOR PLOTTING
torch.manual_seed(0)  # SETTING SEED FOR REPRODUCIBILITY

def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    # FUNCTION TO DISPLAY TENSOR IMAGES
    image_tensor = (image_tensor + 1) / 2  # NORMALIZE IMAGE TENSORS TO [0,1]
    image_unflat = image_tensor.detach().cpu()  # DETACH AND MOVE TO CPU
    image_grid = make_grid(image_unflat[:num_images], nrow=5)  # CREATE IMAGE GRID
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())  # PLOT IMAGES
    plt.show()  # SHOW PLOT

def make_grad_hook():
    # FUNCTION TO CREATE GRADIENT HOOK
    grads = []  # LIST TO STORE GRADIENTS
    def grad_hook(m):
        # INNER FUNCTION TO CHECK LAYER TYPE AND APPEND GRADIENT
        if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
            grads.append(m.weight.grad)  # APPEND GRADIENT OF LAYER WEIGHTS
    return grads, grad_hook  # RETURN GRADIENT LIST AND HOOK FUNCTION


In [5]:
class Generator(nn.Module):
    def __init__(self, z_dim=10, im_chan=1, hidden_dim=64):
        super(Generator, self).__init__()
        self.z_dim = z_dim  # SET DIMENSION OF INPUT NOISE
        # BUILD THE NEURAL NETWORK
        self.gen = nn.Sequential(
            self.make_gen_block(z_dim, hidden_dim * 4),  # FIRST LAYER
            self.make_gen_block(hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=1),  # SECOND LAYER
            self.make_gen_block(hidden_dim * 2, hidden_dim),  # THIRD LAYER
            self.make_gen_block(hidden_dim, im_chan, kernel_size=4, final_layer=True),  # FINAL LAYER
        )

    def make_gen_block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False):
        if not final_layer:
            # NOT FINAL LAYER, ADD CONVOLUTION, BATCHNORM, AND RELU
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),  # UPSAMPLING
                nn.BatchNorm2d(output_channels),  # NORMALIZATION
                nn.ReLU(inplace=True),  # ACTIVATION FUNCTION
            )
        else:
            # FINAL LAYER, ADD CONVOLUTION AND TANH
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),  # UPSAMPLING
                nn.Tanh(),  # OUTPUT ACTIVATION FUNCTION
            )

    def forward(self, noise):
        x = noise.view(len(noise), self.z_dim, 1, 1)  # RESHAPE NOISE TO IMAGE FORMAT
        return self.gen(x)  # PASS THROUGH NETWORK

def get_noise(n_samples, z_dim, device='cpu'):
    return torch.randn(n_samples, z_dim, device=device)  # RETURN RANDOM NOISE


### TRAINING INITIALIZATIONS

In [10]:
n_epochs = 100
z_dim = 64
display_step = 500
batch_size = 128
lr = 0.0002
beta_1 = 0.5
beta_2 = 0.999
c_lambda = 10
crit_repeats = 5
device = 'cpu'

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

dataloader = DataLoader(
    MNIST('.', download=False, transform=transform),
    batch_size=batch_size,
    shuffle=True)

In [12]:
# SIMPLE CRITIC CLASS
class Critic(nn.Module):
    def __init__(self):
        super(Critic, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, 4, 2, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, 4, 1, 0),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)

In [13]:
gen = Generator(z_dim).to(device)  # INITIALIZE GENERATOR MODEL AND MOVE TO DEVICE (GPU/CPU) FOR COMPUTATION
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr, betas=(beta_1, beta_2))  # CREATE OPTIMIZER FOR GENERATOR USING ADAM WITH SPECIFIED LEARNING RATE AND BETAS
crit = Critic().to(device)  # INITIALIZE CRITIC MODEL AND MOVE TO DEVICE (GPU/CPU) FOR COMPUTATION
crit_opt = torch.optim.Adam(crit.parameters(), lr=lr, betas=(beta_1, beta_2))  # CREATE OPTIMIZER FOR CRITIC USING ADAM WITH SPECIFIED LEARNING RATE AND BETAS

def weights_init(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)  # INITIALIZE WEIGHTS OF CONVOLUTIONAL LAYERS WITH NORMAL DISTRIBUTION
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)  # INITIALIZE WEIGHTS OF BATCHNORM LAYERS WITH NORMAL DISTRIBUTION
        torch.nn.init.constant_(m.bias, 0)  # INITIALIZE BIASES OF BATCHNORM LAYERS TO 0
gen = gen.apply(weights_init)  # APPLY WEIGHT INITIALIZATION TO GENERATOR
crit = crit.apply(weights_init)  # APPLY WEIGHT INITIALIZATION TO CRITIC


### GRADIENT PENALTY


In [15]:
def get_gradient(crit, real, fake, epsilon):
    mixed_images = real * epsilon + fake * (1 - epsilon)
    mixed_scores = crit(mixed_images)

    gradient = torch.autograd.grad(
        inputs=mixed_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores), 
        create_graph=True,
        retain_graph=True,
    )[0]
    return gradient


In [17]:
# UNIT TEST
def test_get_gradient(image_shape):
    real = torch.randn(*image_shape, device=device) + 1
    fake = torch.randn(*image_shape, device=device) - 1
    epsilon_shape = [1 for _ in image_shape]
    epsilon_shape[0] = image_shape[0]
    epsilon = torch.rand(epsilon_shape, device=device).requires_grad_()
    gradient = get_gradient(crit, real, fake, epsilon)
    assert tuple(gradient.shape) == image_shape
    assert gradient.max() > 0
    assert gradient.min() < 0
    return gradient

# gradient = test_get_gradient((256, 1, 28, 28))
print("Success!")

Success!


In [19]:
def gradient_penalty(gradient):
    gradient = gradient.view(len(gradient), -1)
    gradient_norm = gradient.norm(2, dim=1)

    penalty = (gradient_norm - 1) ** 2

    return penalty

In [22]:
def gradient_penalty(gradient):
    
    gradient = gradient.view(len(gradient), -1)  # FLATTEN GRADIENTS TO TREAT EACH ROW AS ONE IMAGE

    gradient_norm = gradient.norm(2, dim=1)  # CALCULATE MAGNITUDE (NORM) OF EACH ROW
    
    penalty = torch.mean((gradient_norm - 1) ** 2)  # CALCULATE MEAN SQUARED DISTANCE FROM 1 AND TAKE MEAN

    return penalty  # RETURN THE MEAN PENALTY


def test_gradient_penalty(image_shape):
    bad_gradient = torch.zeros(*image_shape)
    bad_gradient_penalty = gradient_penalty(bad_gradient)
    assert torch.isclose(bad_gradient_penalty, torch.tensor(1.))

    image_size = torch.prod(torch.Tensor(image_shape[1:]))
    good_gradient = torch.ones(*image_shape) / torch.sqrt(image_size)
    good_gradient_penalty = gradient_penalty(good_gradient)
    assert torch.isclose(good_gradient_penalty, torch.tensor(0.))

    random_gradient = test_get_gradient(image_shape)
    random_gradient_penalty = gradient_penalty(random_gradient)
    assert torch.abs(random_gradient_penalty - 1) < 0.1

# test_gradient_penalty((256, 1, 28, 28))
print("Success!")


Success!


In [23]:
def get_gen_loss(crit_fake_pred):
    
    gen_loss = -1.0 * torch.mean(crit_fake_pred)  # CALCULATE GENERATOR LOSS AS NEGATIVE MEAN OF CRITIC PREDICTIONS

    return gen_loss

In [24]:
# UNIT TEST
assert torch.isclose(
    get_gen_loss(torch.tensor(1.)), torch.tensor(-1.0)
)

assert torch.isclose(
    get_gen_loss(torch.rand(10000)), torch.tensor(-0.5), 0.05
)

print("Success!")

Success!


In [27]:
def get_crit_loss(crit_fake_pred, crit_real_pred, gp, c_lambda):
    
    crit_loss = torch.mean(crit_fake_pred) - torch.mean(crit_real_pred) + c_lambda * gp  # CALCULATE CRITIC LOSS    

    return crit_loss

In [28]:
assert torch.isclose(
    get_crit_loss(torch.tensor(1.), torch.tensor(2.), torch.tensor(3.), 0.1),
    torch.tensor(-0.7)
)
assert torch.isclose(
    get_crit_loss(torch.tensor(20.), torch.tensor(-20.), torch.tensor(2.), 10),
    torch.tensor(60.)
)

print("Success!")

Success!


In [30]:
import matplotlib.pyplot as plt

cur_step = 0
generator_losses = []
critic_losses = []
for epoch in range(n_epochs):
    # Dataloader returns the batches
    for real, _ in tqdm(dataloader):
        cur_batch_size = len(real)
        real = real.to(device)

        mean_iteration_critic_loss = 0
        for _ in range(crit_repeats):
            ### Update critic ###
            crit_opt.zero_grad()
            fake_noise = get_noise(cur_batch_size, z_dim, device=device)
            fake = gen(fake_noise)
            crit_fake_pred = crit(fake.detach())
            crit_real_pred = crit(real)

            epsilon = torch.rand(len(real), 1, 1, 1, device=device, requires_grad=True)
            gradient = get_gradient(crit, real, fake.detach(), epsilon)
            gp = gradient_penalty(gradient)
            crit_loss = get_crit_loss(crit_fake_pred, crit_real_pred, gp, c_lambda)

            # Keep track of the average critic loss in this batch
            mean_iteration_critic_loss += crit_loss.item() / crit_repeats
            # Update gradients
            crit_loss.backward(retain_graph=True)
            # Update optimizer
            crit_opt.step()
        critic_losses += [mean_iteration_critic_loss]

        ### Update generator ###
        gen_opt.zero_grad()
        fake_noise_2 = get_noise(cur_batch_size, z_dim, device=device)
        fake_2 = gen(fake_noise_2)
        crit_fake_pred = crit(fake_2)
        
        gen_loss = get_gen_loss(crit_fake_pred)
        gen_loss.backward()

        # Update the weights
        gen_opt.step()

        # Keep track of the average generator loss
        generator_losses += [gen_loss.item()]

        ### Visualization code ###
        if cur_step % display_step == 0 and cur_step > 0:
            gen_mean = sum(generator_losses[-display_step:]) / display_step
            crit_mean = sum(critic_losses[-display_step:]) / display_step
            print(f"Epoch {epoch}, step {cur_step}: Generator loss: {gen_mean}, critic loss: {crit_mean}")
            show_tensor_images(fake)
            show_tensor_images(real)
            step_bins = 20
            num_examples = (len(generator_losses) // step_bins) * step_bins
            plt.plot(
                range(num_examples // step_bins), 
                torch.Tensor(generator_losses[:num_examples]).view(-1, step_bins).mean(1),
                label="Generator Loss"
            )
            plt.plot(
                range(num_examples // step_bins), 
                torch.Tensor(critic_losses[:num_examples]).view(-1, step_bins).mean(1),
                label="Critic Loss"
            )
            plt.legend()
            plt.show()

        cur_step += 1


  0%|          | 0/469 [00:00<?, ?it/s]

RuntimeError: Given groups=1, weight of size [64, 3, 4, 4], expected input[128, 1, 28, 28] to have 3 channels, but got 1 channels instead