In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm
!jupyter nbextension enable --py widgetsnbextension


  warn(


Enabling notebook extension jupyter-js-widgets/extension...
      - Validating: [32mOK[0m


In [39]:
class Discriminator(nn.Module):
    """MLP discriminator mapping a flattened image to a probability of being real.

    Architecture: Linear -> BatchNorm1d -> LeakyReLU -> Linear -> Sigmoid.

    Args:
        in_features: length of each flattened input vector (28 * 28 for MNIST here).
        hidden_dim: width of the hidden layer; defaults to the original 128.
        negative_slope: LeakyReLU slope for negative inputs; defaults to the original 0.01.
    """

    def __init__(self, in_features, hidden_dim=128, negative_slope=0.01):
        super().__init__()
        # Layer order is fixed so state_dict keys remain disc.0.*, disc.1.*, disc.3.*.
        self.disc = nn.Sequential(
            nn.Linear(in_features, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.LeakyReLU(negative_slope),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),  # probability output; the notebook trains it with nn.BCELoss
        )

    def forward(self, x):
        """Score a batch.

        Args:
            x: tensor of shape (batch, in_features).

        Returns:
            Tensor of shape (batch, 1) with values in [0, 1].
        """
        return self.disc(x)
class Generator(nn.Module):
    """MLP generator mapping a latent noise vector to a flattened image.

    Architecture: Linear -> BatchNorm1d -> LeakyReLU -> Linear -> Tanh.

    Args:
        z_dim: length of the latent noise vector.
        img_dim: length of the flattened output image (28 * 28 for MNIST here).
        hidden_dim: width of the hidden layer; defaults to the original 256.
        negative_slope: LeakyReLU slope for negative inputs; defaults to the original 0.01.
    """

    def __init__(self, z_dim, img_dim, hidden_dim=256, negative_slope=0.01):
        super().__init__()
        # Layer order is fixed so state_dict keys remain gen.0.*, gen.1.*, gen.3.*.
        self.gen = nn.Sequential(
            nn.Linear(z_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.LeakyReLU(negative_slope),
            nn.Linear(hidden_dim, img_dim),
            # Real images are normalized to [-1, 1] (Normalize((0.5,), (0.5,))),
            # so Tanh keeps generated pixels in the same range.
            nn.Tanh(),
        )

    def forward(self, x):
        """Generate a batch of flattened images.

        Args:
            x: noise tensor of shape (batch, z_dim).

        Returns:
            Tensor of shape (batch, img_dim) with values in [-1, 1].
        """
        return self.gen(x)

In [45]:
class _Discriminator(nn.Module):
    def __init__(self, in_features):
        super(_Discriminator,self).__init__()
        self.l1=nn.Linear(in_features, 128)
        self.b=nn.BatchNorm1d(128)
        self.leak=nn.LeakyReLU(0.01)
        self.l2=nn.Linear(128, 1)
        self.sig=nn.Sigmoid()
    def forward(self, x):
        out=self.l1(x)
        # print(out.shape) # torch.Size([4, 128])
        out=self.b(out)
        # print(out.shape) # torch.Size([4, 128])
        out=self.leak(out)
        # print(out.shape) # torch.Size([4, 128])
        out=self.l2(out)
        # print(out.shape) # torch.Size([4, 1])
        out=self.sig(out)
        # print(out.shape) # torch.Size([4, 1])
        return out

class _Generator(nn.Module):
    def __init__(self, z_dim, img_dim):
        super(_Generator,self).__init__()
        self.l1=nn.Linear(z_dim, 256)
        self.ba=nn.BatchNorm1d(256)
        self.leak=nn.LeakyReLU(0.01)
        self.l2=nn.Linear(256, img_dim)
        self.tan=nn.Tanh() # normalize inputs to [-1, 1] so make outputs [-1, 1]

    def forward(self, x):
        # print(x.shape) # torch.Size([4, 100])
        out=self.l1(x)
        # print(out.shape) # torch.Size([4, 256])
        out=self.ba(out)
        # print(out.shape) # torch.Size([4, 256])
        out=self.leak(out)
        # print(out.shape) # torch.Size([4, 256])
        out=self.l2(out)
        # print(out.shape) # torch.Size([4, 784])
        out=self.tan(out)
        # print(out.shape) # torch.Size([2, 784])
        return out

In [54]:
# Quick shape check of the explicit-layer models on random data.
b, z = 4, 100
image_dim = 28 * 28 * 1  # 784
disc = _Discriminator(image_dim)
gen = _Generator(z, image_dim)
real = torch.randn(b, image_dim)
noise = torch.randn(b, z)
fake = gen(noise)
out = disc(fake)
# Store the score under its own name: rebinding `disc` to a tensor would make any
# re-run of this cell (or a later disc(...) / disc.parameters() call) fail.
score_real = disc(real)
assert real.shape == (b, image_dim)
assert fake.shape == (b, image_dim)
assert out.shape == (b, 1)
assert score_real.shape == (b, 1)

torch.Size([4, 784])
torch.Size([4, 784])
torch.Size([4, 1])
torch.Size([4, 1])


In [31]:
# Hyperparameters etc.
SEED = 0
torch.manual_seed(SEED)  # reproducible weight init and fixed_noise across runs

# Preferred GPU. A hard-coded "cuda:7" fails with an invalid-device-ordinal error on
# machines with fewer than 8 GPUs, so fall back to the default CUDA device there.
GPU_INDEX = 7
if torch.cuda.is_available():
    device = f"cuda:{GPU_INDEX}" if GPU_INDEX < torch.cuda.device_count() else "cuda"
else:
    device = "cpu"
lr = 3e-4
z_dim = 64
image_dim = 28 * 28 * 1  # 784
batch_size = 32
num_epochs = 100

disc = Discriminator(image_dim).to(device)
gen = Generator(z_dim, image_dim).to(device)
# Fixed latent batch; the (commented-out) training loop feeds it to gen() for logged samples.
fixed_noise = torch.randn((batch_size, z_dim)).to(device)

In [5]:
# Map MNIST pixels from [0, 1] to [-1, 1] to match the generator's Tanh output range.
# Refer to the module through `torchvision.transforms`: this cell rebinds the name
# `transforms` to the Compose object (the dataset cell uses that name), and calling
# `transforms.Compose` on a second run would otherwise hit the Compose object, not the module.
transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,)),
    ]
)

In [6]:
# MNIST training split, normalized by the `transforms` pipeline defined above.
dataset = datasets.MNIST(root="dataset/", train=True, transform=transforms, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)
# `verbose=` is deprecated for PyTorch LR schedulers; read the current rate with
# scheduler.get_last_lr() instead.
# NOTE(review): no scheduler.step(metric) call is visible in this notebook — confirm they are used.
gen_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    opt_gen, factor=0.1, patience=10
)
disc_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    opt_disc, factor=0.1, patience=10
)
# The discriminator ends in Sigmoid, so it outputs probabilities and BCELoss is the matching loss.
criterion = nn.BCELoss()
# TensorBoard writers are disabled; the commented-out training loop references them.
# writer_fake = SummaryWriter(f"logs_lr_sched/fake")
# writer_real = SummaryWriter(f"logs_lr_sched/real")
step = 0



In [8]:
# Pull one shuffled batch from the loader to inspect the raw image tensor shape.
sample_batch = next(iter(loader))
real, y = sample_batch
real.shape


torch.Size([32, 1, 28, 28])

In [16]:
# One discriminator update on the batch above: max log(D(x)) + log(1 - D(G(z))).
real = real.view(-1, 784).to(device)
# Use a separate name so the `batch_size` hyperparameter is not overwritten.
cur_batch_size = real.shape[0]
noise = torch.randn(cur_batch_size, z_dim).to(device)
print(noise.shape)  # torch.Size([32, 64])
fake = gen(noise)
print(fake.shape)  # torch.Size([32, 784])
disc_real = disc(real).view(-1)
print(disc_real.shape)  # torch.Size([32])
lossD_real = criterion(disc_real, torch.ones_like(disc_real))
# Detach so the discriminator loss does not backprop into the generator, and run the
# discriminator on the fakes only once: each training-mode forward pass also updates
# BatchNorm running statistics.
disc_fake_raw = disc(fake.detach())
print(disc_fake_raw.shape)  # torch.Size([32, 1])
disc_fake = disc_fake_raw.view(-1)
print(disc_fake.shape)  # torch.Size([32])
lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
lossD = (lossD_real + lossD_fake) / 2
disc.zero_grad()
# Because `fake` was detached, this backward pass never touches the generator graph.
# retain_graph is therefore unnecessary, and a later generator step can still
# backprop through `fake`.
lossD.backward()
opt_disc.step()

torch.Size([32, 64])
torch.Size([32, 784])
torch.Size([32])
torch.Size([32, 1])
torch.Size([32])


In [15]:
disc_real.new_ones(disc_real.shape)  # BCE target of 1 for every real sample (same dtype/device)

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       device='cuda:7')

In [None]:
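# NOTE(review): this training loop is commented out; the output below comes from an earlier run.
# Before re-enabling it:
# - uncomment the SummaryWriter lines in the setup cell (writer_fake / writer_real are used below);
# - pass fake.detach() to the discriminator step so it does not backprop into the generator;
# - give the inner tqdm total=len(loader) so its progress bar has a length (it currently shows "0it").
# from tqdm.notebook import tqdm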
# for epoch in tqdm(range(num_epochs),total=num_epochs):
#     for batch_idx, (real, _) in tqdm(enumerate(loader)):
#         real = real.view(-1, 784).to(device)
#         batch_size = real.shape[0]
#         noise = torch.randn(batch_size, z_dim).to(device)

#         ### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
#         fake = gen(noise)
#         disc_real = disc(real).view(-1)
#         lossD_real = criterion(disc_real, torch.ones_like(disc_real))
#         disc_fake = disc(fake).view(-1)
#         lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
#         lossD = (lossD_real + lossD_fake) / 2
#         disc.zero_grad()
#         lossD.backward(retain_graph=True)
#         opt_disc.step()
#         ### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z))
#         # where the second option of maximizing doesn't suffer from
#         # saturating gradients
#         output = disc(fake).view(-1)
#         lossG = criterion(output, torch.ones_like(output))
#         gen.zero_grad()
#         lossG.backward()
#         opt_gen.step()
#         if batch_idx == 0:
#             print(
#                 f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} \
#                       Loss D: {lossD:.4f}, loss G: {lossG:.4f}"
#             )

#             with torch.no_grad():
#                 fake = gen(fixed_noise).reshape(-1, 1, 28, 28)
#                 data = real.reshape(-1, 1, 28, 28)
#                 img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
#                 img_grid_real = torchvision.utils.make_grid(data, normalize=True)

#                 writer_fake.add_image(
#                     "Mnist Fake Images", img_grid_fake, global_step=step
#                 )
#                 writer_real.add_image(
#                     "Mnist Real Images", img_grid_real, global_step=step
#                 )
#                 step += 1

  0%|          | 0/100 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Epoch [0/100] Batch 0/1875                       Loss D: 0.7191, loss G: 0.9167


0it [00:00, ?it/s]

Epoch [1/100] Batch 0/1875                       Loss D: 0.2075, loss G: 1.9385


0it [00:00, ?it/s]

Epoch [2/100] Batch 0/1875                       Loss D: 0.2341, loss G: 1.8252


0it [00:00, ?it/s]

Epoch [3/100] Batch 0/1875                       Loss D: 0.3949, loss G: 1.6696


0it [00:00, ?it/s]

Epoch [4/100] Batch 0/1875                       Loss D: 0.4019, loss G: 1.4599


0it [00:00, ?it/s]

Epoch [5/100] Batch 0/1875                       Loss D: 0.2873, loss G: 1.6271


0it [00:00, ?it/s]

Epoch [6/100] Batch 0/1875                       Loss D: 0.4805, loss G: 1.5769


0it [00:00, ?it/s]

Epoch [7/100] Batch 0/1875                       Loss D: 0.3463, loss G: 1.7884


0it [00:00, ?it/s]

Epoch [8/100] Batch 0/1875                       Loss D: 0.3258, loss G: 1.5874


0it [00:00, ?it/s]

Epoch [9/100] Batch 0/1875                       Loss D: 0.5422, loss G: 1.3193


0it [00:00, ?it/s]

Epoch [10/100] Batch 0/1875                       Loss D: 0.4737, loss G: 1.6815


0it [00:00, ?it/s]

Epoch [11/100] Batch 0/1875                       Loss D: 0.3074, loss G: 1.4530


0it [00:00, ?it/s]

Epoch [12/100] Batch 0/1875                       Loss D: 0.4843, loss G: 1.5996


0it [00:00, ?it/s]

Epoch [13/100] Batch 0/1875                       Loss D: 0.5987, loss G: 1.1515


0it [00:00, ?it/s]

Epoch [14/100] Batch 0/1875                       Loss D: 0.4840, loss G: 1.3840


0it [00:00, ?it/s]

Epoch [15/100] Batch 0/1875                       Loss D: 0.6850, loss G: 1.5379


0it [00:00, ?it/s]

Epoch [16/100] Batch 0/1875                       Loss D: 0.4115, loss G: 1.2886


0it [00:00, ?it/s]

Epoch [17/100] Batch 0/1875                       Loss D: 0.2764, loss G: 1.5407


0it [00:00, ?it/s]

Epoch [18/100] Batch 0/1875                       Loss D: 0.5242, loss G: 1.1992


0it [00:00, ?it/s]

Epoch [19/100] Batch 0/1875                       Loss D: 0.4434, loss G: 1.4789


0it [00:00, ?it/s]

Epoch [20/100] Batch 0/1875                       Loss D: 0.5273, loss G: 1.2573


0it [00:00, ?it/s]

Epoch [21/100] Batch 0/1875                       Loss D: 0.5343, loss G: 1.4029


0it [00:00, ?it/s]

Epoch [22/100] Batch 0/1875                       Loss D: 0.4526, loss G: 1.4636


0it [00:00, ?it/s]

Epoch [23/100] Batch 0/1875                       Loss D: 0.4046, loss G: 1.5842


0it [00:00, ?it/s]

Epoch [24/100] Batch 0/1875                       Loss D: 0.6040, loss G: 1.2856


0it [00:00, ?it/s]

Epoch [25/100] Batch 0/1875                       Loss D: 0.3619, loss G: 1.6470


0it [00:00, ?it/s]

Epoch [26/100] Batch 0/1875                       Loss D: 0.4446, loss G: 1.4240


0it [00:00, ?it/s]

Epoch [27/100] Batch 0/1875                       Loss D: 0.5075, loss G: 1.1009


0it [00:00, ?it/s]

Epoch [28/100] Batch 0/1875                       Loss D: 0.5937, loss G: 1.4447


0it [00:00, ?it/s]

Epoch [29/100] Batch 0/1875                       Loss D: 0.4629, loss G: 1.8117


0it [00:00, ?it/s]

Epoch [30/100] Batch 0/1875                       Loss D: 0.3431, loss G: 1.3781


0it [00:00, ?it/s]

Epoch [31/100] Batch 0/1875                       Loss D: 0.5528, loss G: 1.3578


0it [00:00, ?it/s]

Epoch [32/100] Batch 0/1875                       Loss D: 0.3588, loss G: 1.8879


0it [00:00, ?it/s]

Epoch [33/100] Batch 0/1875                       Loss D: 0.5016, loss G: 1.3134


0it [00:00, ?it/s]

Epoch [34/100] Batch 0/1875                       Loss D: 0.4256, loss G: 1.8818


0it [00:00, ?it/s]

Epoch [35/100] Batch 0/1875                       Loss D: 0.3696, loss G: 1.5223


0it [00:00, ?it/s]

Epoch [36/100] Batch 0/1875                       Loss D: 0.3701, loss G: 1.6001


0it [00:00, ?it/s]

Epoch [37/100] Batch 0/1875                       Loss D: 0.3853, loss G: 1.3674


0it [00:00, ?it/s]

Epoch [38/100] Batch 0/1875                       Loss D: 0.3821, loss G: 1.5939


0it [00:00, ?it/s]

Epoch [39/100] Batch 0/1875                       Loss D: 0.5827, loss G: 1.7029


0it [00:00, ?it/s]

Epoch [40/100] Batch 0/1875                       Loss D: 0.4370, loss G: 1.7836


0it [00:00, ?it/s]

Epoch [41/100] Batch 0/1875                       Loss D: 0.3190, loss G: 2.0480


0it [00:00, ?it/s]

Epoch [42/100] Batch 0/1875                       Loss D: 0.3917, loss G: 1.4234


0it [00:00, ?it/s]

Epoch [43/100] Batch 0/1875                       Loss D: 0.3152, loss G: 1.8979


0it [00:00, ?it/s]

Epoch [44/100] Batch 0/1875                       Loss D: 0.3542, loss G: 1.3366


0it [00:00, ?it/s]

Epoch [45/100] Batch 0/1875                       Loss D: 0.5587, loss G: 1.0278


0it [00:00, ?it/s]

Epoch [46/100] Batch 0/1875                       Loss D: 0.2345, loss G: 1.8065


0it [00:00, ?it/s]

Epoch [47/100] Batch 0/1875                       Loss D: 0.5339, loss G: 1.2817


0it [00:00, ?it/s]

Epoch [48/100] Batch 0/1875                       Loss D: 0.3422, loss G: 1.8461


0it [00:00, ?it/s]

Epoch [49/100] Batch 0/1875                       Loss D: 0.3803, loss G: 2.0306


0it [00:00, ?it/s]

Epoch [50/100] Batch 0/1875                       Loss D: 0.2276, loss G: 2.2279


0it [00:00, ?it/s]

Epoch [51/100] Batch 0/1875                       Loss D: 0.2378, loss G: 1.9560


0it [00:00, ?it/s]

Epoch [52/100] Batch 0/1875                       Loss D: 0.1591, loss G: 2.1274


0it [00:00, ?it/s]

Epoch [53/100] Batch 0/1875                       Loss D: 0.1622, loss G: 1.8927


0it [00:00, ?it/s]

Epoch [54/100] Batch 0/1875                       Loss D: 0.1944, loss G: 2.8167


0it [00:00, ?it/s]

Epoch [55/100] Batch 0/1875                       Loss D: 0.5778, loss G: 1.2881


0it [00:00, ?it/s]

Epoch [56/100] Batch 0/1875                       Loss D: 0.6354, loss G: 1.3076


0it [00:00, ?it/s]

Epoch [57/100] Batch 0/1875                       Loss D: 0.5713, loss G: 1.4558


0it [00:00, ?it/s]

Epoch [58/100] Batch 0/1875                       Loss D: 0.6410, loss G: 2.1039


0it [00:00, ?it/s]

Epoch [59/100] Batch 0/1875                       Loss D: 0.4574, loss G: 1.5008


0it [00:00, ?it/s]

Epoch [60/100] Batch 0/1875                       Loss D: 0.6229, loss G: 1.5969


0it [00:00, ?it/s]

Epoch [61/100] Batch 0/1875                       Loss D: 0.4500, loss G: 1.5875


0it [00:00, ?it/s]

Epoch [62/100] Batch 0/1875                       Loss D: 0.3453, loss G: 1.6530


0it [00:00, ?it/s]

Epoch [63/100] Batch 0/1875                       Loss D: 0.2601, loss G: 1.6996


0it [00:00, ?it/s]

Epoch [64/100] Batch 0/1875                       Loss D: 0.5304, loss G: 1.5986


0it [00:00, ?it/s]

Epoch [65/100] Batch 0/1875                       Loss D: 0.2269, loss G: 2.2512


0it [00:00, ?it/s]

Epoch [66/100] Batch 0/1875                       Loss D: 0.3470, loss G: 1.5785


0it [00:00, ?it/s]

Epoch [67/100] Batch 0/1875                       Loss D: 0.4946, loss G: 1.4096


0it [00:00, ?it/s]

Epoch [68/100] Batch 0/1875                       Loss D: 0.4044, loss G: 1.4506


0it [00:00, ?it/s]

Epoch [69/100] Batch 0/1875                       Loss D: 0.5076, loss G: 1.3433


0it [00:00, ?it/s]

Epoch [70/100] Batch 0/1875                       Loss D: 0.5309, loss G: 1.1659


0it [00:00, ?it/s]

Epoch [71/100] Batch 0/1875                       Loss D: 0.4832, loss G: 1.5747


0it [00:00, ?it/s]

Epoch [72/100] Batch 0/1875                       Loss D: 0.3538, loss G: 1.7768


0it [00:00, ?it/s]

Epoch [73/100] Batch 0/1875                       Loss D: 0.4093, loss G: 1.5079


0it [00:00, ?it/s]

Epoch [74/100] Batch 0/1875                       Loss D: 0.3816, loss G: 1.7921


0it [00:00, ?it/s]

Epoch [75/100] Batch 0/1875                       Loss D: 0.5036, loss G: 1.5874


0it [00:00, ?it/s]

Epoch [76/100] Batch 0/1875                       Loss D: 0.3278, loss G: 1.5544


0it [00:00, ?it/s]

Epoch [77/100] Batch 0/1875                       Loss D: 0.4747, loss G: 1.7890


0it [00:00, ?it/s]

Epoch [78/100] Batch 0/1875                       Loss D: 0.4831, loss G: 1.3618


0it [00:00, ?it/s]

Epoch [79/100] Batch 0/1875                       Loss D: 0.5703, loss G: 1.4556


0it [00:00, ?it/s]

Epoch [80/100] Batch 0/1875                       Loss D: 0.6888, loss G: 1.6737


0it [00:00, ?it/s]

Epoch [81/100] Batch 0/1875                       Loss D: 0.3568, loss G: 1.7616


0it [00:00, ?it/s]

Epoch [82/100] Batch 0/1875                       Loss D: 0.5999, loss G: 1.5358


0it [00:00, ?it/s]

Epoch [83/100] Batch 0/1875                       Loss D: 0.4825, loss G: 1.4432


0it [00:00, ?it/s]

Epoch [84/100] Batch 0/1875                       Loss D: 0.3970, loss G: 1.3514


0it [00:00, ?it/s]

Epoch [85/100] Batch 0/1875                       Loss D: 0.3940, loss G: 1.3291


0it [00:00, ?it/s]

Epoch [86/100] Batch 0/1875                       Loss D: 0.2692, loss G: 1.4983


0it [00:00, ?it/s]

Epoch [87/100] Batch 0/1875                       Loss D: 0.4245, loss G: 1.5838


0it [00:00, ?it/s]

Epoch [88/100] Batch 0/1875                       Loss D: 0.5854, loss G: 1.3880


0it [00:00, ?it/s]

Epoch [89/100] Batch 0/1875                       Loss D: 0.2741, loss G: 1.6420


0it [00:00, ?it/s]

Epoch [90/100] Batch 0/1875                       Loss D: 0.5417, loss G: 1.8384


0it [00:00, ?it/s]

Epoch [91/100] Batch 0/1875                       Loss D: 0.5131, loss G: 1.3013


0it [00:00, ?it/s]

Epoch [92/100] Batch 0/1875                       Loss D: 0.4365, loss G: 1.8215


0it [00:00, ?it/s]

Epoch [93/100] Batch 0/1875                       Loss D: 0.4468, loss G: 1.4770


0it [00:00, ?it/s]

Epoch [94/100] Batch 0/1875                       Loss D: 0.6031, loss G: 1.3852


0it [00:00, ?it/s]

Epoch [95/100] Batch 0/1875                       Loss D: 0.3660, loss G: 1.3650


0it [00:00, ?it/s]

Epoch [96/100] Batch 0/1875                       Loss D: 0.5188, loss G: 1.3163


0it [00:00, ?it/s]

Epoch [97/100] Batch 0/1875                       Loss D: 0.5538, loss G: 1.2274


0it [00:00, ?it/s]

Epoch [98/100] Batch 0/1875                       Loss D: 0.2596, loss G: 1.5137


0it [00:00, ?it/s]

Epoch [99/100] Batch 0/1875                       Loss D: 0.3965, loss G: 1.5539
