<a href="https://colab.research.google.com/github/AzeemWaqarRao/Pytorch_Implementations/blob/main/GANs_DenseNN_Pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torch import nn
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image
import os
import shutil

In [None]:
# Hyperparameters: every tunable value of the notebook lives in this cell.
seed = 0               # seeds torch's RNG so weight init, shuffling and noise are reproducible
noise_size = 64        # length of the latent noise vector fed to the generator
hidden_size = 256      # width of the hidden layers in both networks
image_size = 28 * 28   # flattened 28x28 single-channel MNIST image = 784 features
epochs = 50            # number of passes over the training set
batch_size = 100       # images per mini-batch
step_size = 200        # print the losses every `step_size` mini-batches

torch.manual_seed(seed)

In [None]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

Using cpu device


In [None]:
# Image -> tensor, then (x - 0.5) / 0.5 per channel so pixel values span [-1, 1],
# the same range as the generator's Tanh output.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

In [None]:
# MNIST training split, fetched into ./data if it is not already there.
mnist = MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 113081646.53it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 51922714.88it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 28015666.48it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 4555363.17it/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



In [None]:
# Shuffled mini-batches of (image, label) pairs; the GAN ignores the labels.
data = DataLoader(dataset=mnist, shuffle=True, batch_size=batch_size)

In [None]:
# Sanity check on one batch: value range after normalization and the batch shape.
first_images, _ = next(iter(data))
batch_info = first_images.shape
print(first_images.min())
print(first_images.max())
iterations = len(data)  # mini-batches per epoch
batch_info

tensor(-1.)
tensor(1.)


torch.Size([100, 1, 28, 28])

In [None]:
class Generator(nn.Module):
  """MLP generator that maps latent noise vectors to flattened images.

  Args:
    noise_dim: length of each input noise vector. Defaults to the
      notebook-level ``noise_size``.
    hidden_dim: width of the two hidden layers. Defaults to ``hidden_size``.
    out_dim: length of each generated (flattened) image. Defaults to
      ``image_size``.

  ``forward`` maps a tensor whose last dimension is ``noise_dim`` to a tensor
  whose last dimension is ``out_dim``. Values lie in [-1, 1] (Tanh output).
  """

  def __init__(self, noise_dim=None, hidden_dim=None, out_dim=None):
    super().__init__()
    # Fall back to the config-cell globals so `Generator()` behaves exactly as
    # before, while explicit sizes let the class be built without those globals.
    noise_dim = noise_size if noise_dim is None else noise_dim
    hidden_dim = hidden_size if hidden_dim is None else hidden_dim
    out_dim = image_size if out_dim is None else out_dim
    self.sequential = nn.Sequential(
        nn.Linear(noise_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, out_dim),
        nn.Tanh()
    )

  def forward(self, x):
    return self.sequential(x)

In [None]:
class Discriminator(nn.Module):
  """MLP discriminator that scores images as real (target 1) or fake (target 0).

  Args:
    in_dim: number of features per flattened image. Defaults to the
      notebook-level ``image_size``.
    hidden_dim: width of the two hidden layers. Defaults to ``hidden_size``.

  ``forward`` returns one Sigmoid probability per image, shape ``(N, 1)`` for
  a batch, intended to be paired with ``nn.BCELoss``.
  """

  def __init__(self, in_dim=None, hidden_dim=None):
    super().__init__()
    # Fall back to the config-cell globals so `Discriminator()` behaves exactly
    # as before, while explicit sizes remove the dependency on those globals.
    in_dim = image_size if in_dim is None else in_dim
    hidden_dim = hidden_size if hidden_dim is None else hidden_dim
    self.sequential = nn.Sequential(
        nn.Linear(in_dim, hidden_dim),
        nn.LeakyReLU(0.2),
        nn.Linear(hidden_dim, hidden_dim),
        nn.LeakyReLU(0.2),
        nn.Linear(hidden_dim, 1),
        nn.Sigmoid()
    )

  def forward(self, x):
    # Accept image-shaped batches such as (N, 1, 28, 28) by flattening all
    # non-batch dims; flat (N, in_dim) and single (in_dim,) inputs pass as-is.
    if x.dim() > 2:
      x = x.flatten(start_dim=1)
    return self.sequential(x)

In [None]:
# Build the generator first, then the discriminator, and place both on `device`.
generator, discriminator = Generator().to(device), Discriminator().to(device)

In [None]:
# One Adam optimizer per network, both with learning rate 2e-4.
g_optim, d_optim = (
    torch.optim.Adam(generator.parameters(), lr=2e-4),
    torch.optim.Adam(discriminator.parameters(), lr=2e-4),
)

In [None]:
# Binary cross-entropy on the discriminator's Sigmoid outputs, averaged over the batch.
criterion = nn.BCELoss(reduction='mean')

In [None]:
def denorm(x):
    """Map a tensor from the normalized [-1, 1] range back to [0, 1].

    Computes (x + 1) / 2 and clips anything outside [0, 1] to the nearest bound.
    """
    shifted = x.add(1)
    return shifted.div(2).clamp(min=0, max=1)

In [None]:
# BCE targets for a full batch: 1 marks "real", 0 marks "fake"; shape (batch_size, 1).
real_labels = torch.ones(batch_size, 1, device=device)
fake_labels = torch.zeros(batch_size, 1, device=device)

In [None]:
# Alternating GAN training: one discriminator update, then one generator update,
# per mini-batch. At the end of each epoch the last generated batch is saved
# as an image grid under samples/.
os.makedirs('samples', exist_ok=True)
steps_per_epoch = len(data)

for i in range(epochs):
    for k, (batch, _) in enumerate(data):
        # Use the batch's real length so a short final batch cannot break the
        # reshape or mismatch the target tensors.
        n = batch.size(0)
        real_images = batch.reshape(n, -1).to(device)
        real_targets = torch.ones(n, 1, device=device)
        fake_targets = torch.zeros(n, 1, device=device)

        # ========= Discriminator training =========
        # Real images should be scored as 1.
        real_loss = criterion(discriminator(real_images), real_targets)

        # Generated images should be scored as 0. detach() keeps this loss from
        # back-propagating into the generator, which d_optim does not update anyway.
        rand_noise = torch.randn(n, noise_size, device=device)
        fake_images = generator(rand_noise).detach()
        fake_loss = criterion(discriminator(fake_images), fake_targets)

        d_loss = real_loss + fake_loss
        d_optim.zero_grad()
        d_loss.backward()
        d_optim.step()

        # ========= Generator training =========
        # Fresh samples; the generator's loss is low when the just-updated
        # discriminator scores them as real.
        rand_noise = torch.randn(n, noise_size, device=device)
        g_out = generator(rand_noise)
        g_loss = criterion(discriminator(g_out), real_targets)

        # zero_grad also clears discriminator grads left by g_loss before the next D step.
        g_optim.zero_grad()
        g_loss.backward()
        g_optim.step()

        if (k + 1) % step_size == 0:
            print(f"Epoch[{i + 1}/{epochs}]: Step[{k + 1}/{steps_per_epoch}] -- "
                  f"Discriminator Loss: {d_loss.item():.4f}\t -- Generator Loss: {g_loss.item():.4f}")

    # Save the epoch's last generated batch as 28x28 single-channel images in [0, 1].
    save_image(denorm(g_out.detach().reshape(-1, 1, 28, 28)), f'samples/image{i}.png')



Epoch[0/200]: Step[200/600] -- Discriminator Loss: 0.20676669478416443	 -- Generator Loss: 3.2886695861816406
Epoch[0/200]: Step[400/600] -- Discriminator Loss: 0.005862819962203503	 -- Generator Loss: 6.115520477294922
Epoch[0/200]: Step[600/600] -- Discriminator Loss: 0.05896639823913574	 -- Generator Loss: 4.980039119720459
Epoch[1/200]: Step[200/600] -- Discriminator Loss: 0.09939474612474442	 -- Generator Loss: 3.4271953105926514
Epoch[1/200]: Step[400/600] -- Discriminator Loss: 0.16903650760650635	 -- Generator Loss: 5.085259437561035
Epoch[1/200]: Step[600/600] -- Discriminator Loss: 0.10278700292110443	 -- Generator Loss: 4.32558012008667
Epoch[2/200]: Step[200/600] -- Discriminator Loss: 0.1399911344051361	 -- Generator Loss: 4.0613884925842285
Epoch[2/200]: Step[400/600] -- Discriminator Loss: 0.5461602210998535	 -- Generator Loss: 3.0934550762176514
Epoch[2/200]: Step[600/600] -- Discriminator Loss: 0.13857394456863403	 -- Generator Loss: 3.939342498779297
Epoch[3/200]: Ste

In [None]:
# Delete the samples/ directory written by the training cell. The existence check
# lets this cell be re-run, or run before training, without FileNotFoundError.
if os.path.isdir('./samples'):
    shutil.rmtree('./samples')
