In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

In [None]:
class Discriminator(nn.Module):
    """Convolutional critic that reduces an image to one raw score per sample.

    The stack is five stride-2, kernel-4 convolutions. For a 64x64 input the
    spatial size goes 64 -> 32 -> 16 -> 8 -> 4 -> 1. No sigmoid is applied, so
    the output is an unbounded value of shape (N, 1, 1, 1).

    Args:
        channels_img: number of channels of the input images.
        features_d: base channel width; hidden layers use 1x, 2x, 4x and 8x
            this value.
    """

    def __init__(self, channels_img, features_d):
        super().__init__()
        widths = [features_d * mult for mult in (1, 2, 4, 8)]

        # First layer: plain conv (with bias) + LeakyReLU, no normalisation.
        # NOTE(review): the slope here is 0.02; DCGAN-style networks commonly
        # use 0.2 — confirm the value is intentional.
        layers = [
            nn.Conv2d(channels_img, widths[0], 4, 2, 1),  # 64x64 -> 32x32
            nn.LeakyReLU(0.02),
        ]

        # Three conv/BatchNorm/LeakyReLU stages, each halving the resolution
        # and doubling the channel count: 32 -> 16 -> 8 -> 4.
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.append(
                self._block(c_in, c_out, kernel_size=4, stride=2, padding=1)
            )

        # Final conv without padding collapses the 4x4 map to a 1x1 score.
        layers.append(nn.Conv2d(widths[-1], 1, 4, 2, 0))

        self.disc = nn.Sequential(*layers)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Return a Conv2d (no bias) -> BatchNorm2d -> LeakyReLU(0.02) stage."""
        conv = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride, padding, bias=False
        )
        norm = nn.BatchNorm2d(out_channels)
        activation = nn.LeakyReLU(0.02)
        return nn.Sequential(conv, norm, activation)

    def forward(self, x):
        """Run the image batch `x` through the critic and return its scores."""
        scores = self.disc(x)
        return scores

In [None]:
class Generator(nn.Module):
    """Transposed-convolution generator mapping a latent vector to an image.

    A (N, z_dim, 1, 1) input is upsampled 1 -> 4 -> 8 -> 16 -> 32 -> 64 and
    squashed with tanh, giving a (N, channels_img, 64, 64) output.

    Args:
        z_dim: number of channels of the latent input.
        channels_img: number of channels of the generated image.
        feature_g: base channel width; hidden layers use 16x, 8x, 4x and 2x
            this value.
    """

    def __init__(self, z_dim, channels_img, feature_g):
        super().__init__()

        # (in_channels, out_channels, kernel_size, stride, padding) per stage.
        stage_specs = [
            (z_dim, feature_g * 16, 4, 1, 0),          # 1x1   -> 4x4
            (feature_g * 16, feature_g * 8, 4, 2, 1),  # 4x4   -> 8x8
            (feature_g * 8, feature_g * 4, 4, 2, 1),   # 8x8   -> 16x16
            (feature_g * 4, feature_g * 2, 4, 2, 1),   # 16x16 -> 32x32
        ]
        layers = [self._block(*spec) for spec in stage_specs]

        # Output layer: 32x32 -> 64x64, with bias and without BatchNorm,
        # followed by tanh.
        layers.append(nn.ConvTranspose2d(feature_g * 2, channels_img, 4, 2, 1))
        layers.append(nn.Tanh())

        self.gen = nn.Sequential(*layers)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Return a ConvTranspose2d (no bias) -> BatchNorm2d -> ReLU stage."""
        upsample = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size, stride, padding, bias=False
        )
        norm = nn.BatchNorm2d(out_channels)
        return nn.Sequential(upsample, norm, nn.ReLU())

    def forward(self, x):
        """Generate images from the latent batch `x`."""
        images = self.gen(x)
        return images

In [None]:
def initialize_weights(model):
    """Apply DCGAN-style initialisation to every layer of `model` in place.

    * Conv2d / ConvTranspose2d weights are drawn from N(0, 0.02).
    * BatchNorm2d scale (weight) is drawn from N(1, 0.02) and its shift (bias)
      is set to 0. Centring the scale on 0, like a convolution weight, would
      multiply every normalised activation by a value close to 0.

    Convolution biases are left at their existing values. BatchNorm2d layers
    created with ``affine=False`` have no weight/bias and are skipped.

    Args:
        model: an ``nn.Module`` whose sub-modules are visited via ``modules()``.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight, 0.0, 0.02)
        elif isinstance(m, nn.BatchNorm2d) and m.weight is not None:
            nn.init.normal_(m.weight, 1.0, 0.02)
            nn.init.constant_(m.bias, 0.0)

def test():
    """Shape smoke test for the two networks.

    Builds a small Discriminator and Generator (base width 8), initialises
    them with `initialize_weights`, and asserts that
    * the discriminator maps (8, 3, 64, 64) images to (8, 1, 1, 1), and
    * the generator maps (8, 100, 1, 1) noise to (8, 3, 64, 64).
    Prints 'SUCCESS' when both assertions hold.
    """
    batch, channels, height, width = 8, 3, 64, 64
    latent_dim = 100
    image_shape = (batch, channels, height, width)

    # Discriminator: image batch in, one score per sample out.
    images = torch.randn(image_shape)
    disc_net = Discriminator(channels, 8)
    initialize_weights(disc_net)
    scores = disc_net(images)
    assert scores.shape == (batch, 1, 1, 1)

    # Generator: latent batch in, image batch out.
    gen_net = Generator(latent_dim, channels, 8)
    initialize_weights(gen_net)
    latent = torch.randn((batch, latent_dim, 1, 1))
    generated = gen_net(latent)
    assert generated.shape == image_shape

    print('SUCCESS')


# Run the smoke test as soon as the cell executes.
test()

In [None]:
## Hyperparameters

# Compute device: the first CUDA GPU when one is available, otherwise the CPU.
DEVICE = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# --- data ---
BATCH_SIZE, IMAGE_SIZE, CHANNELS_IMG = 64, 64, 1   # loader batch / resize target / image channels

# --- model ---
Z_DIM = 128              # channels of the generator's latent input
FEATURES_GEN = 64        # base channel width passed to Generator
FEATURES_CRITIC = 64     # base channel width passed to Discriminator (the critic)

# --- optimisation ---
NUM_EPOCHS = 5           # passes over the data loader
LEARNING_RATE = 5e-5     # shared by both RMSprop optimisers
CRITIC_ITERATIONS = 5    # critic updates per generator update
# Presumably the WGAN critic weight-clipping bound — verify that the training
# loop actually applies it.
WEIGHT_CLIP = 0.01

print(DEVICE)

In [None]:
# Image preprocessing pipeline: resize to IMAGE_SIZE, convert to a tensor, then
# normalise every channel with mean 0.5 and std 0.5.
#
# The pipeline is built through the fully-qualified `torchvision.transforms`
# path on purpose. The assignment below rebinds the name `transforms` (until
# now the imported module) to the pipeline object, because the dataset cell
# reads that name. Looking the classes up via `transforms.Compose` would
# therefore fail with AttributeError the second time this cell is run.
_norm_mean = [0.5] * CHANNELS_IMG
_norm_std = [0.5] * CHANNELS_IMG
transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize(IMAGE_SIZE),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(_norm_mean, _norm_std),
    ]
)

In [None]:
# MNIST training split rooted at "dataset/" (download=True), with the
# preprocessing pipeline applied to every image.
dataset = datasets.MNIST(
    "dataset/",
    train=True,
    transform=transforms,
    download=True,
)

# Shuffled mini-batches of BATCH_SIZE samples.
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
# Build both networks on the selected device, then apply the custom weight
# initialisation to each (generator first, critic second).
gen = Generator(Z_DIM, CHANNELS_IMG, FEATURES_GEN).to(DEVICE)
critic = Discriminator(CHANNELS_IMG, FEATURES_CRITIC).to(DEVICE)
for network in (gen, critic):
    initialize_weights(network)

# Initialize the optimizers: one RMSprop instance per network, same learning rate.
opt_gen = optim.RMSprop(params=gen.parameters(), lr=LEARNING_RATE)
opt_critic = optim.RMSprop(params=critic.parameters(), lr=LEARNING_RATE)

# A fixed batch of 32 latent vectors, reused so that sample grids from
# different training steps are comparable.
fixed_noise = torch.randn((32, Z_DIM, 1, 1)).to(DEVICE)

In [None]:
# WGAN training loop.
#
# For every real batch the critic is updated CRITIC_ITERATIONS times to
# minimise  -(mean(critic(real)) - mean(critic(fake))), and after each critic
# step its parameters are clamped to [-WEIGHT_CLIP, WEIGHT_CLIP]. The generator
# is then updated once to minimise  -mean(critic(fake)). Every 100 batches the
# current losses are printed and a grid of samples generated from `fixed_noise`
# is shown.
for epoch in range(NUM_EPOCHS):
    gen.train()
    critic.train()

    for batch_idx, (real, _) in enumerate(loader):
        real = real.to(DEVICE)
        # The last batch of an epoch can hold fewer than BATCH_SIZE samples;
        # size the fake batch from the real one so both means use equal counts.
        cur_batch_size = real.shape[0]

        ## Training Critic
        for _ in range(CRITIC_ITERATIONS):
            noise = torch.randn((cur_batch_size, Z_DIM, 1, 1)).to(DEVICE)
            # The critic update needs no gradient w.r.t. the generator, so the
            # fakes are produced without building a graph. This also removes
            # the need for retain_graph on backward().
            with torch.no_grad():
                fake = gen(noise)
            critic_real = critic(real).reshape(-1)
            critic_fake = critic(fake).reshape(-1)
            loss_critic = -(torch.mean(critic_real) - torch.mean(critic_fake))
            critic.zero_grad()
            loss_critic.backward()
            opt_critic.step()

            # Weight clipping: keep every critic parameter inside
            # [-WEIGHT_CLIP, WEIGHT_CLIP] after each update.
            with torch.no_grad():
                for param in critic.parameters():
                    param.clamp_(-WEIGHT_CLIP, WEIGHT_CLIP)

        ## Train Generator
        # Fresh fakes, this time with gradients flowing back into the generator.
        noise = torch.randn((cur_batch_size, Z_DIM, 1, 1)).to(DEVICE)
        fake = gen(noise)
        output = critic(fake).reshape(-1)
        loss_gen = -torch.mean(output)
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        if batch_idx%100 == 0:
            with torch.no_grad():
                # .item() prints plain numbers instead of full tensor reprs.
                print(f'epoch - {epoch}/{NUM_EPOCHS} || {batch_idx}/{len(loader)} || loss D: {loss_critic.item():.4f} || loss G: {loss_gen.item():.4f}')
                gen_fake_img = gen(fixed_noise)
                img_grid_fake = torchvision.utils.make_grid(
                    gen_fake_img[:32], normalize=True
                ).cpu()
                # make_grid returns (C, H, W); imshow expects (H, W, C).
                plt.imshow(img_grid_fake.permute(1, 2, 0))
                plt.show()