## Imports


In [16]:
import torch
import torchvision
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.utils.tensorboard import SummaryWriter


In [17]:
!tensorboard --version


2.12.3


## Dataset

In [18]:
IN_CHANNELS = 1
IMG_SIZE = 64
BATCH_SIZE = 128

In [19]:
transform = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.5 for _ in range(IN_CHANNELS)], [0.5 for _ in range(IN_CHANNELS)])
])

In [20]:
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)

## Model

In [21]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [22]:
def initialize_wieghts(model):

    for m in model.modules():

        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
            torch.nn.init.normal_(m.weight.data ,0.0, 0.02)


In [23]:
class Generator(nn.Module):

    def __init__(self, img_channels):
        super().__init__()
        self.layers = nn.Sequential(
            self.gen_block(100, 1024, 4, 1, 0),
            self.gen_block(1024, 512, 4, 2, 1),
            self.gen_block(512, 256, 4, 2, 1),
            self.gen_block(256, 128, 4, 2, 1),
            nn.ConvTranspose2d(128, img_channels, 4, 2, 1),
            nn.Tanh()
        )

    def gen_block(self, in_channels, out_channels, kernel_size, stride, padding):
        block_layers = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False
            ),
            nn.BatchNorm2d(out_channels),  
            nn.ReLU()
        )

        return block_layers
        
    def forward(self, x):
        x = self.layers(x)

        return x


In [24]:
class Discriminator(nn.Module):
    def __init__(self, img_channels):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            nn.Conv2d(img_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            self._block(64, 128, 4, 2, 1),
            self._block(128, 256, 4, 2, 1),
            self._block(256, 512, 4, 2, 1),
            nn.Conv2d(512, img_channels, kernel_size=4, stride=2, padding=0)
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False,#
            ),

            nn.InstanceNorm2d(out_channels, affine=True),#
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        return self.disc(x)


In [25]:
generator = Generator(IN_CHANNELS).to(device)
initialize_wieghts(generator)

critic = Discriminator(IN_CHANNELS).to(device)
initialize_wieghts(critic)


## Train

In [26]:
fixed_noise = torch.randn(32, 100, 1, 1).to(device)

In [27]:
EPOCHS = 15
LR = 0.00005
C = 0.01
n_critic = 5

optimizer_G = optim.RMSprop(generator.parameters(), lr=LR)

optimizer_C = optim.RMSprop(critic.parameters(), lr=LR)

writer_fake = SummaryWriter(f"logs_WGAN/fake")
writer_real = SummaryWriter(f"logs_WGAN/real")
step = 0



In [28]:
generator.train()

Generator(
  (layers): Sequential(
    (0): Sequential(
      (0): ConvTranspose2d(100, 1024, kernel_size=(4, 4), stride=(1, 1), bias=False)
      (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (1): Sequential(
      (0): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (2): Sequential(
      (0): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (3): Sequential(
      (0): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (4): ConvTranspose2d(128, 

In [29]:
critic.train()

Discriminator(
  (disc): Sequential(
    (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
    (2): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (3): Sequential(
      (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (4): Sequential(
      (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (5): Conv2d(512, 1, kernel_size=(4, 4), stride=(2, 2))
  )
)

In [30]:
for epoch in range(EPOCHS):

    for i, (real_images, _) in enumerate(loader):
        #Training Discriminator

        critic.train()

        for _ in range(0, n_critic):

            real_images = real_images.to(device)

            noise = torch.randn(real_images.shape[0], 100, 1, 1).to(device)
            fake_images = generator(noise)

            critic_real = critic(real_images).reshape(-1)
        
            critic_fake = critic(fake_images.detach()).reshape(-1)
        
            loss_C = -(torch.mean(critic_real) - torch.mean(critic_fake))

            optimizer_C.zero_grad()
            loss_C.backward()
            optimizer_C.step()

            for p in critic.parameters():
                p.data.clamp_(-C, C)

        #Training Discriminator

        critic.eval()
        
        gen_fake = critic(fake_images).reshape(-1)

        loss_G = -torch.mean(gen_fake)

        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()


        if i== 0:
            print(
                f"Epoch [{epoch}/{EPOCHS}] Batch {i}/{len(loader)} \
                      Loss C: {loss_C:.4f}, loss G: {loss_G:.4f}"
            )

            with torch.no_grad():
                fake_images = generator(fixed_noise)
                data = real_images
                img_grid_fake = torchvision.utils.make_grid(fake_images[:32], normalize=True)
                img_grid_real = torchvision.utils.make_grid(data[:32], normalize=True)

                writer_fake.add_image(
                    "Mnist Fake Images", img_grid_fake, global_step=step
                )
                writer_real.add_image(
                    "Mnist Real Images", img_grid_real, global_step=step
                )
                step += 1

        

Epoch [0/15] Batch 0/469                       Loss C: -0.0364, loss G: 0.0055
Epoch [1/15] Batch 0/469                       Loss C: -1.5407, loss G: 0.7414
Epoch [2/15] Batch 0/469                       Loss C: -1.4443, loss G: 0.6983
Epoch [3/15] Batch 0/469                       Loss C: -1.2712, loss G: 0.6767
Epoch [4/15] Batch 0/469                       Loss C: -1.1026, loss G: 0.6636
Epoch [5/15] Batch 0/469                       Loss C: -0.9863, loss G: 0.6210
Epoch [6/15] Batch 0/469                       Loss C: -1.0458, loss G: 0.6030
Epoch [7/15] Batch 0/469                       Loss C: -1.0433, loss G: 0.4326
Epoch [8/15] Batch 0/469                       Loss C: -0.9137, loss G: 0.5986
Epoch [9/15] Batch 0/469                       Loss C: -0.8556, loss G: 0.2700
Epoch [10/15] Batch 0/469                       Loss C: -0.9545, loss G: 0.4237
Epoch [11/15] Batch 0/469                       Loss C: -0.9479, loss G: 0.5766
Epoch [12/15] Batch 0/469                       Lo