In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets 
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from model import Discriminator,Generator,initialize_weights

In [3]:
# Run configuration: target device and training hyperparameters.
device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')

# Seed torch's RNGs (all devices) so random draws such as the input noise and
# DataLoader shuffling repeat across runs. GPU kernels may still be nondeterministic.
seed = 42
torch.manual_seed(seed)

lr = 2e-4            # learning rate shared by both Adam optimizers
batch_size = 128
img_size = 64        # MNIST digits are resized to this size by the transform cell
channels_img = 1     # MNIST is single-channel (grayscale)
z_dim = 100          # length of the generator's input noise vector
num_epochs = 5
features_disc = 64   # base feature width passed to Discriminator
features_gen = 64    # base feature width passed to Generator

In [4]:
# Preprocessing pipeline for MNIST:
# resize to img_size, convert to a tensor in [0, 1], then normalize each channel
# with mean 0.5 / std 0.5, which maps pixel values to [-1, 1].
# NOTE: the result is bound to `transforms` (the dataset cell passes it as
# `transform=transforms`), which shadows the torchvision.transforms module.
# The module is therefore reached via `torchvision.transforms`, so this cell can be
# re-run without raising AttributeError on the shadowed name.
transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize(img_size),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            [0.5] * channels_img,  # per-channel mean
            [0.5] * channels_img,  # per-channel std
        ),
    ]
)

In [5]:
# MNIST training split, downloaded into dataset/ on first use and preprocessed
# with the `transforms` pipeline.
dataset = datasets.MNIST(
    root="dataset/",
    train=True,
    download=True,
    transform=transforms,
)

In [7]:
# Shuffled mini-batches over the training set.
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Generator turns (z_dim, 1, 1) noise into images; the discriminator scores images.
gen = Generator(z_dim, channels_img, features_gen).to(device)
disc = Discriminator(channels_img, features_disc).to(device)

# Apply the project's initialize_weights (from model.py) to both networks.
initialize_weights(gen)
initialize_weights(disc)

# Adam with beta1 = 0.5 instead of the 0.9 default, a common choice for GAN training.
opt_gen = optim.Adam(params=gen.parameters(), lr=lr, betas=(0.5, 0.999))
opt_disc = optim.Adam(params=disc.parameters(), lr=lr, betas=(0.5, 0.999))

# The discriminator ends in a Sigmoid (see its printed architecture), so it emits
# probabilities and plain BCELoss is the matching criterion.
criterion = nn.BCELoss()

In [8]:
# One fixed batch of 32 latent vectors, reused at every logging step of the
# training loop so successive "Fake" grids are directly comparable.
fixed_noise = torch.randn((32, z_dim, 1, 1)).to(device)

# Separate TensorBoard writers for the real and the generated image grids.
writer_real, writer_fake = SummaryWriter("logs/real"), SummaryWriter("logs/fake")

# TensorBoard global_step, advanced once per logged batch.
step = 0

# Switch both networks to training mode (disc.train() returns the module).
gen.train()
disc.train()


Discriminator(
  (disc): Sequential(
    (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
    (2): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (3): Sequential(
      (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (4): Sequential(
      (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (5): Conv2d(512, 1, kernel_size=(4, 4), stride=(2, 2))
    (6): Sigmoid()
  )
)

In [12]:
# Adversarial training loop: BCE losses, with the generator trained on the
# non-saturating objective (fake images labelled as real).
# Every 100 batches: print both losses and log a grid of real images plus a grid
# of generator samples for `fixed_noise` to TensorBoard.
for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(loader):
        real = real.to(device)
        # Match the noise batch to the real batch: the loader keeps the final,
        # smaller batch of each epoch (469 batches for 60000 images at 128).
        cur_batch_size = real.shape[0]
        noise = torch.randn((cur_batch_size, z_dim, 1, 1)).to(device)
        fake = gen(noise)

        ### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
        disc_real = disc(real).reshape(-1)
        loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))
        # detach(): the discriminator loss must not back-propagate into the
        # generator, so the generator graph no longer needs retain_graph=True.
        disc_fake = disc(fake.detach()).reshape(-1)
        loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        loss_disc = (loss_disc_real + loss_disc_fake) / 2
        disc.zero_grad()
        loss_disc.backward()
        opt_disc.step()

        ### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z)))
        # Re-score `fake` with the just-updated discriminator; gradients reach the
        # generator through the graph built by gen(noise) above.
        output = disc(fake).reshape(-1)
        loss_gen = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        if batch_idx % 100 == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} | "
                f"Loss D: {loss_disc.item():.4f}, Loss G: {loss_gen.item():.4f}"
            )

            with torch.no_grad():
                # NOTE(review): gen is still in train mode here; if Generator uses
                # BatchNorm, this forward pass updates its running statistics —
                # confirm whether a gen.eval()/gen.train() toggle is wanted.
                fake_samples = gen(fixed_noise)

                img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
                img_grid_fake = torchvision.utils.make_grid(fake_samples[:32], normalize=True)

                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)

            step += 1

Epoch [0/5] Batch 0/469 \ Loss D:0.6960, Loss G:  0.7756
Epoch [0/5] Batch 100/469 \ Loss D:0.0142, Loss G:  4.1699
Epoch [0/5] Batch 200/469 \ Loss D:0.7127, Loss G:  0.8759
Epoch [0/5] Batch 300/469 \ Loss D:0.5763, Loss G:  1.7172
Epoch [0/5] Batch 400/469 \ Loss D:0.9892, Loss G:  1.9411
Epoch [1/5] Batch 0/469 \ Loss D:0.6179, Loss G:  1.6491
Epoch [1/5] Batch 100/469 \ Loss D:0.6273, Loss G:  0.7653
Epoch [1/5] Batch 200/469 \ Loss D:0.6518, Loss G:  1.2517
Epoch [1/5] Batch 300/469 \ Loss D:0.6316, Loss G:  0.5411
Epoch [1/5] Batch 400/469 \ Loss D:0.6523, Loss G:  0.6210
Epoch [2/5] Batch 0/469 \ Loss D:0.6374, Loss G:  1.2007
Epoch [2/5] Batch 100/469 \ Loss D:0.5933, Loss G:  0.5722
Epoch [2/5] Batch 200/469 \ Loss D:0.6418, Loss G:  1.2464
Epoch [2/5] Batch 300/469 \ Loss D:0.6365, Loss G:  0.7885
Epoch [2/5] Batch 400/469 \ Loss D:0.5585, Loss G:  1.1718
Epoch [3/5] Batch 0/469 \ Loss D:0.5417, Loss G:  1.3982
Epoch [3/5] Batch 100/469 \ Loss D:0.4983, Loss G:  2.2339
Epoch

In [13]:
# Training is finished: close the TensorBoard writer opened on "logs/fake".
writer_fake.close()

In [14]:
# Training is finished: close the TensorBoard writer opened on "logs/real".
writer_real.close()