<a href="https://colab.research.google.com/github/151ali/lr-pytorch/blob/main/10_DCGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Define the **Discriminator** and **Generator**

In [18]:
import torch
import torch.nn as nn

In [19]:
from torch.nn.modules.batchnorm import BatchNorm2d
class Discriminator(nn.Module):
  """DCGAN-style convolutional discriminator for 64x64 images.

  Five stride-2 convolutions reduce an input of shape
  ``(N, img_channels, 64, 64)`` to ``(N, 1, 1, 1)``; a final sigmoid turns
  that value into a score in [0, 1].

  Args:
    img_channels: Number of channels of the input images.
    features_d: Base number of feature maps. The hidden layers use
      ``features_d``, ``features_d*2``, ``features_d*4`` and ``features_d*8``.
  """

  def __init__(self,img_channels, features_d):
    super().__init__()

    self.discriminator = nn.Sequential(
        # input => N * img_channels * 64 * 64
        nn.Conv2d(img_channels, features_d, kernel_size=4, stride=2, padding=1),    # 32*32
        nn.LeakyReLU(0.2),                                                          # ..*..
        self._block(features_d, features_d*2, kernel_size=4, stride=2, padding=1),  # 16*16
        self._block(features_d*2, features_d*4, kernel_size=4, stride=2, padding=1),# 08*08
        self._block(features_d*4, features_d*8, kernel_size=4, stride=2, padding=1),# 04*04
        nn.Conv2d(features_d*8, 1, kernel_size=4, stride=2, padding=0),             # 01*01
        # No activation between the final convolution and the sigmoid: a
        # LeakyReLU(0.2) here would scale every negative logit by 0.2 and
        # flatten the sigmoid on the "fake" side of the decision.
        nn.Sigmoid(),
    )

    self.initialize_weights()


  def _block(self, in_channels,out_channels, **kwargs):
    """Return one hidden stage: a bias-free Conv2d followed by LeakyReLU(0.2).

    ``kwargs`` are forwarded to ``nn.Conv2d`` (kernel_size, stride, padding).
    """
    return nn.Sequential(
        nn.Conv2d(in_channels=in_channels, out_channels=out_channels, bias=False, **kwargs),
        # nn.BatchNorm2d(out_channels),   (normalization is currently disabled)
        nn.LeakyReLU(0.2)
    )

  def initialize_weights(self):
    """Re-initialize conv / transposed-conv / batch-norm weights from N(0, 0.02)."""
    for m in self.modules():
      if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
        nn.init.normal_(m.weight.data, 0.0, 0.02)

  def forward(self, x):
    """Score a batch of images; returns a tensor of shape ``(N, 1, 1, 1)``."""
    return self.discriminator(x)

In [20]:
class Generator(nn.Module):
  """DCGAN-style transposed-convolution generator producing 64x64 images.

  Turns latent vectors of shape ``(N, z_dim, 1, 1)`` into images of shape
  ``(N, img_channels, 64, 64)`` with values in [-1, 1] (Tanh output).

  Args:
    z_dim: Number of channels of the latent input.
    img_channels: Number of channels of the generated images.
    features_g: Base number of feature maps. The hidden stages use
      ``features_g*16``, ``*8``, ``*4`` and ``*2``.
  """

  def __init__(self, z_dim, img_channels, features_g):
    super().__init__()

    # Channel widths of the four hidden stages, widest first.
    widths = [features_g * factor for factor in (16, 8, 4, 2)]

    # input => N * z_dim * 01 * 01; the first stage expands 1x1 to 4x4.
    stages = [self._block(z_dim, widths[0], kernel_size=4, stride=1, padding=0)]
    # Every following stage doubles the spatial size: 4 -> 8 -> 16 -> 32.
    for wide, narrow in zip(widths, widths[1:]):
      stages.append(self._block(wide, narrow, kernel_size=4, stride=2, padding=1))
    # Output layer: 32 -> 64, squashed into [-1, 1].
    stages.append(
        nn.ConvTranspose2d(widths[-1], img_channels, kernel_size=4, stride=2, padding=1)
    )
    stages.append(nn.Tanh())

    self.generator = nn.Sequential(*stages)
    self.initialize_weights()

  def _block(self, in_channels,out_channels, **kwargs):
    """Return one hidden stage: a bias-free ConvTranspose2d followed by ReLU.

    ``kwargs`` are forwarded to ``nn.ConvTranspose2d``.
    """
    upsample = nn.ConvTranspose2d(in_channels, out_channels, bias=False, **kwargs)
    # nn.BatchNorm2d(out_channels) would go here; it is currently disabled.
    activation = nn.ReLU()  # ReLU, following the DCGAN paper
    return nn.Sequential(upsample, activation)

  def initialize_weights(self):
    """Re-initialize conv / transposed-conv / batch-norm weights from N(0, 0.02)."""
    weighted_types = (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)
    for layer in filter(lambda mod: isinstance(mod, weighted_types), self.modules()):
      nn.init.normal_(layer.weight.data, mean=0.0, std=0.02)

  def forward(self, x):
    """Generate a batch of images from a batch of latent vectors."""
    return self.generator(x)


In [21]:
def test():
  """Shape smoke test: run both networks on random inputs and check output shapes."""
  batch, channels, height, width = 8, 3, 64, 64
  latent_dim = 100

  # Discriminator: a batch of images must collapse to one value per image.
  images = torch.randn((batch, channels, height, width))
  critic = Discriminator(channels, 8)
  scores = critic(images)
  assert scores.shape == (batch, 1, 1, 1)

  # Generator: a batch of latent vectors must expand to full-size images.
  generator = Generator(latent_dim, channels, 8)
  latent = torch.randn((batch, latent_dim, 1, 1))
  samples = generator(latent)
  assert samples.shape == (batch, channels, height, width)

  print("Success !")


In [22]:
# Run the shape checks; prints "Success !" when both assertions hold.
test()

Success !


# Training

In [23]:
!pip install wandb -qq

In [24]:
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from   torch.utils.data import DataLoader

In [25]:
import wandb
# Authenticate with Weights & Biases through its command-line tool.
!wandb login

[34m[1mwandb[0m: Currently logged in as: [33m151ali[0m (use `wandb login --relogin` to force relogin)


In [26]:
# Initialize a new run in the "DCGAN_impl" project.
# The trailing ';' keeps the returned object from being echoed as cell output.
wandb.init(project="DCGAN_impl");

VBox(children=(Label(value=' 2.75MB of 2.75MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Discriminator loss,0.00082
_runtime,352.0
_timestamp,1624573023.0
_step,35.0
Generator loss,0.00129


0,1
Discriminator loss,█▂▁▁▁▁▁▁▁
_runtime,▁▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇████
_timestamp,▁▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇████
_step,▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇███
Generator loss,█▂▁▁▁▁▁▁▁


In [27]:
# Run configuration: every value a reader might want to tune lives in this cell.

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed PyTorch's random number generators so that weight initialization and
# noise sampling in the cells below are repeatable from a fresh kernel.
seed = 42
torch.manual_seed(seed)

learning_rate = 2e-4
batch_size    = 128
image_size    = 64   # images are resized to image_size x image_size
img_channels  = 1    # MNIST is single-channel
z_dim = 100          # number of channels of the generator's latent input
num_epochs = 50   # try to increase it
features_disc = 64   # base feature-map count for the discriminator
features_gen = 64    # base feature-map count for the generator

In [28]:
# Show which device was selected above.
print(device)

cuda


In [29]:
# This cell rebinds the name `transforms` to the composed pipeline, so the
# torchvision module is reached through a separate alias. Referring to
# `transforms.Resize` here would raise AttributeError the second time the
# cell is run, because `transforms` would then be the Compose object.
import torchvision.transforms as tv_transforms

# Resize -> tensor -> per-channel normalization with mean 0.5 and std 0.5.
transforms = tv_transforms.Compose([
  tv_transforms.Resize(image_size),
  tv_transforms.ToTensor(),
  tv_transforms.Normalize([0.5] * img_channels, [0.5] * img_channels),
])

In [30]:
# MNIST training split stored under datasets/, with the preprocessing
# pipeline defined above applied to every image.
dataset = datasets.MNIST(
  root="datasets/",
  train=True,
  download=True,
  transform=transforms,
)

In [31]:
# Shuffled mini-batches of up to `batch_size` images.
loader = DataLoader(
  dataset=dataset,
  batch_size=batch_size,
  shuffle=True,
)

In [32]:
# Build both networks on the selected device.
gen  = Generator(z_dim, img_channels, features_gen).to(device)
# The discriminator takes its width from features_disc. It previously reused
# features_gen, which only worked because both settings are currently equal.
disc = Discriminator(img_channels, features_disc).to(device)

# One optimizer per network, so each can be stepped independently.
opt_gen  = optim.SGD(gen.parameters(),  lr=learning_rate, momentum=0.9) # try to use Adam
opt_disc = optim.SGD(disc.parameters(), lr=learning_rate, momentum=0.9)

In [33]:
# Binary cross-entropy between discriminator scores and the 0/1 targets.
criterion = nn.BCELoss()

# A fixed batch of 32 latent vectors, reused whenever sample images are logged.
fixed_noise = torch.randn((32, z_dim, 1, 1)).to(device=device)

In [None]:
# Adversarial training: for every batch, update the discriminator first and
# then the generator against the freshly updated discriminator.
step = 0  # counts logging events (one every 100 batches), not optimizer steps
for epoch in range(num_epochs):
  for batch_idx, (real, _) in enumerate(loader):
    real = real.to(device)
    # Size the fake batch from the real one: the last batch of an epoch can
    # hold fewer than batch_size images.
    cur_batch_size = real.size(0)
    noise = torch.randn(cur_batch_size, z_dim, 1, 1, device=device)
    fake = gen(noise)

    # train Discriminator: real images -> 1, generated images -> 0
    disc_real = disc(real).reshape(-1)
    loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))
    # detach() keeps the discriminator update from backpropagating into the
    # generator and leaves the generator's graph untouched for the step
    # below, so backward() no longer needs retain_graph=True.
    disc_fake = disc(fake.detach()).reshape(-1)
    loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))

    loss_disc = (loss_disc_real + loss_disc_fake) / 2
    disc.zero_grad()
    loss_disc.backward()
    opt_disc.step()

    # train Generator: push the updated discriminator's score on fakes to 1
    output = disc(fake).reshape(-1)
    loss_gen = criterion(output, torch.ones_like(output))
    gen.zero_grad()
    loss_gen.backward()
    opt_gen.step()

    # Log losses and sample images to wandb occasionally
    if batch_idx % 100 == 0:
      with torch.no_grad():
        fixed_fake = gen(fixed_noise)
        # take out (up to) 32 examples
        img_grid_real = torchvision.utils.make_grid(
            real[:32], normalize=True
        )
        img_grid_fake = torchvision.utils.make_grid(
            fixed_fake[:32], normalize=True
        )

      # A single log call keeps losses and images on the same wandb step;
      # separate calls would each advance the step counter.
      wandb.log({
          "Discriminator loss": loss_disc.item(),
          "Generator loss": loss_gen.item(),
          "real": [wandb.Image(img_grid_real, caption=f"real-{step}")],
          "fake": [wandb.Image(img_grid_fake, caption=f"fake-{step}")],
      })
      print(step)
      step += 1

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
