In [None]:
import torch
import torchvision
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
from torch import optim
import numpy as np
import seaborn as sns
import pandas as pd
from torchvision.transforms import transforms
from torch.utils.tensorboard import SummaryWriter
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

cuda


In [None]:
class Discriminator(nn.Module):
    """DCGAN-style convolutional critic.

    Maps an image batch of shape (N, channel_img, 64, 64) to raw, unbounded scores
    of shape (N, 1, 1, 1). No sigmoid is applied to the final score.

    Args:
        channel_img: number of channels in the input images.
        feature_d: base channel width; hidden widths are feature_d * (1, 2, 4, 8).
    """

    def __init__(self, channel_img, feature_d):
        super().__init__()
        self.disc = nn.Sequential(
            # Input: N x channel_img x 64 x 64
            nn.Conv2d(channel_img, feature_d, 4, 2, 1),  # -> 32x32
            # The first layer has no normalization, but it still needs a
            # non-linearity: without it this conv and the next block's conv
            # compose into a single linear map.
            nn.LeakyReLU(0.2),
            self._block(feature_d, feature_d * 2, 4, 2, 1),  # -> 16x16
            self._block(feature_d * 2, feature_d * 4, 4, 2, 1),  # -> 8x8
            self._block(feature_d * 4, feature_d * 8, 4, 2, 1),  # -> 4x4
            # Collapse the 4x4 feature map into one score per sample.
            nn.Conv2d(feature_d * 8, 1, 4, 2, 0),  # -> 1x1
        )

    def _block(self, in_channel, out_channel, kernel_size, stride, padding):
        """Downsampling stage: Conv2d -> InstanceNorm2d (learnable affine) -> LeakyReLU(0.2)."""
        return nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size, stride, padding),
            nn.InstanceNorm2d(out_channel, affine=True),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        """Return the critic scores for the image batch ``x``."""
        return self.disc(x)
  



class Generator(nn.Module):
    """DCGAN generator.

    Upsamples a noise batch of shape (N, channels_noise, 1, 1) into images of
    shape (N, channels_img, 64, 64). A final tanh bounds the values to [-1, 1].

    Args:
        channels_noise: dimensionality of the latent vector.
        channels_img: number of channels in the generated images.
        features_g: base channel width; hidden widths are features_g * (16, 8, 4, 2).
    """

    def __init__(self, channels_noise, channels_img, features_g):
        super().__init__()
        widths = [features_g * factor for factor in (16, 8, 4, 2)]

        # 1x1 -> 4x4 (kernel 4, stride 1, no padding).
        layers = [self._block(channels_noise, widths[0], 4, 1, 0)]
        # 4x4 -> 8x8 -> 16x16 -> 32x32: every stride-2 stage doubles H and W.
        for c_in, c_out in zip(widths, widths[1:]):
            layers.append(self._block(c_in, c_out, 4, 2, 1))
        # 32x32 -> 64x64, projecting down to the image channels.
        layers.append(
            nn.ConvTranspose2d(widths[-1], channels_img, kernel_size=4, stride=2, padding=1)
        )
        layers.append(nn.Tanh())

        self.net = nn.Sequential(*layers)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Upsampling stage: bias-free ConvTranspose2d -> BatchNorm2d -> ReLU.

        The transposed conv has no bias of its own; the BatchNorm2d that follows
        carries the per-channel shift.
        """
        upconv = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size, stride, padding, bias=False
        )
        return nn.Sequential(upconv, nn.BatchNorm2d(out_channels), nn.ReLU())

    def forward(self, x):
        """Generate an image batch from the noise batch ``x``."""
        images = self.net(x)
        return images


def initialize_weights(model):
    """Initialize ``model``'s weights in place, DCGAN style.

    - Conv2d / ConvTranspose2d weights are drawn from N(0, 0.02).
    - Affine BatchNorm2d / InstanceNorm2d scales are drawn from N(1, 0.02) and
      their shifts are set to 0, so each norm layer starts close to the identity.
      (Drawing the scale from N(0, 0.02) instead would shrink every normalized
      activation to ~2% of its size and flip the sign of about half the channels.)
    - Convolution biases keep PyTorch's default initialization.

    Args:
        model: any ``nn.Module``; every submodule (including ``model`` itself)
            is visited via ``model.modules()``.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
        elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)):
            # Norm layers built with affine=False have weight/bias set to None.
            if m.weight is not None:
                nn.init.normal_(m.weight, mean=1.0, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

def test():
    """Smoke-test output shapes of Discriminator and Generator on 64x64, 3-channel data."""
    batch, channels, height, width = 8, 3, 64, 64
    latent_dim = 100

    images = torch.randn((batch, channels, height, width))
    critic_net = Discriminator(channels, 8)
    assert critic_net(images).shape == (batch, 1, 1, 1), "Discriminator test failed"

    generator = Generator(latent_dim, channels, 8)
    latent = torch.randn((batch, latent_dim, 1, 1))
    expected_shape = (batch, channels, height, width)
    assert generator(latent).shape == expected_shape, "Generator test failed"


test()

In [None]:
# --- Hyperparameters ---
LEARNING_RATE = 5e-5
BATCH_SIZE = 64
NUM_EPOCHS = 3
Z_DIM = 100
IMAGE_SIZE = 64  # Generator/Discriminator are built for 64x64 images
FEATURES_CRITIC = 64
FEATURES_GEN = 64
CRITIC_ITERATION = 5  # critic updates per generator update
WEIGHT_CLIP = 0.01  # critic weights are clamped to [-WEIGHT_CLIP, WEIGHT_CLIP]
CHANNELS_IMG = 1  # MNIST is single-channel

# MNIST digits are 28x28, but both networks need 64x64 inputs, so resize first.
# Normalize with mean=std=0.5 maps [0, 1] pixels to [-1, 1], matching tanh output.
transform = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]
        ),
    ]
)

dataset = datasets.MNIST(root='data/', download=True, transform=transform)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

gen = Generator(Z_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
critic = Discriminator(CHANNELS_IMG, FEATURES_CRITIC).to(device)

initialize_weights(gen)
initialize_weights(critic)

opt_gen = optim.RMSprop(gen.parameters(), lr=LEARNING_RATE)
opt_critic = optim.RMSprop(critic.parameters(), lr=LEARNING_RATE)

# A fixed latent batch: sampling the generator from the same noise at every
# logging step makes the logged images comparable over training.
fixed_noise = torch.randn(32, Z_DIM, 1, 1).to(device)
writer_real = SummaryWriter('logs/real')
writer_fake = SummaryWriter('logs/fake')
step = 0

gen.train()
critic.train()

for epoch in range(NUM_EPOCHS):
    for batch_idx, (data, _) in enumerate(loader):
        data = data.to(device)
        current_batch = data.shape[0]

        # --- Critic: CRITIC_ITERATION updates per generator update ---
        for _ in range(CRITIC_ITERATION):
            noise = torch.randn(current_batch, Z_DIM, 1, 1).to(device)
            fake = gen(noise)
            critic_real = critic(data).reshape(-1)
            # detach(): the critic update must not backprop into the generator,
            # which also leaves `fake`'s generator graph intact for the G step.
            critic_fake = critic(fake.detach()).reshape(-1)
            # The critic maximizes E[critic(real)] - E[critic(fake)], so it
            # minimizes the negation of that difference.
            loss_critic = -(torch.mean(critic_real) - torch.mean(critic_fake))
            critic.zero_grad()
            loss_critic.backward()
            opt_critic.step()

            # Weight clipping: clamp every critic parameter to [-WEIGHT_CLIP, WEIGHT_CLIP].
            with torch.no_grad():
                for p in critic.parameters():
                    p.clamp_(-WEIGHT_CLIP, WEIGHT_CLIP)

        # --- Generator: minimize -E[critic(fake)] on the last critic iteration's fakes ---
        gen_fake = critic(fake).reshape(-1)
        loss_gen = -torch.mean(gen_fake)
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Print losses occasionally and log images to TensorBoard.
        if batch_idx % 100 == 0 and batch_idx > 0:
            gen.eval()
            critic.eval()
            print(
                f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(loader)} "
                f"Loss D: {loss_critic.item():.4f}, loss G: {loss_gen.item():.4f}"
            )

            with torch.no_grad():
                # Sample from the fixed noise so images are comparable across steps.
                fake = gen(fixed_noise)
                # take out (up to) 32 examples
                img_grid_real = torchvision.utils.make_grid(data[:32], normalize=True)
                img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)

                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)

            step += 1
            gen.train()
            critic.train()

# Flush pending TensorBoard events and release the event files.
writer_real.close()
writer_fake.close()

ValueError: ignored