<a href="https://colab.research.google.com/github/BastionPinnacle/ML2021-2022/blob/main/Untitled5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
rm -rf ./logs/

%load_ext tensorboard
%tensorboard --logdir logs


In [20]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

In [None]:
##

In [21]:
class Discriminator(nn.Module):
  def __init__(self,in_features):
    super().__init__()
    self.disc = nn.Sequential(
        nn.Linear(in_features,128),
        nn.LeakyReLU(0.1),
        nn.Linear(128,1),
        nn.Sigmoid(),
    )
  def forward(self,x):
    return self.disc(x)

class Generator(nn.Module):
  def __init__(self,z_dim, img_dim):
    super().__init__()
    self.gen = nn.Sequential(
        nn.Linear(z_dim,256),
        nn.LeakyReLU(0.1),
        nn.Linear(256,img_dim),
        nn.Tanh(),
    )

  def forward(self,x):
    return self.gen(x)


In [None]:
#Hyperparameters
device = "cuda" if torch.cuda.is_available() else "cpu"
lr=3e-4
z_dim = 64 #128,256
image_dim  = 28*28*1
batch_size = 32
num_epochs = 50

disc = Discriminator(image_dim).to(device)
gen = Generator(z_dim,image_dim).to(device)
fixed_noise = torch.randn((batch_size,z_dim)).to(device)
transf = transforms.Compose(
    [transforms.ToTensor(),transforms.Normalize((0.1307),(0.3081))]
)

dataset = datasets.FashionMNIST(root="/dataset",transform=transf,download = True)
loader = DataLoader(dataset,batch_size = batch_size, shuffle = True)
opt_disc = optim.Adam(disc.parameters(),lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr= lr)
criterion = nn.BCELoss()
writer_fake = SummaryWriter(f"runs/GAN_FASHIONMNIST/fake")
writer_real = SummaryWriter(f"runs/GAN_FASHIONMNIST/real")
step = 0 

#epoch loop
for epoch in range(num_epochs):
  #batch loop, we dont use 
  for batch_idx, (real, _) in enumerate(loader):
    ##real  will be having batch_size x 784 shape
    real = real.view(-1,784).to(device)
    batch_size = real.shape[0]

    ### Train Discriminator maxlog(D(real)) + log(1-D(G(fake)))
    ## generating random noise tensor batchsize x z_dim
    noise = torch.randn((batch_size, z_dim)).to(device)
    ## fake image generation
    fake = gen(noise)
    ## new tensor batch_size x 1
    disc_real = disc(real).view(-1)
    lossD_real = criterion(disc_real,torch.ones_like(disc_real))
    
    disc_fake = disc(fake.detach()).view(-1)
    lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
    
    lossD =(lossD_real+lossD_fake)/2

    disc.zero_grad()  
    lossD.backward()
    opt_disc.step()

    ### Training the generator min log(1-D(G(z)))

    output = disc(fake).view(-1)
    lossG = criterion(output, torch.ones_like(output))
    gen.zero_grad()
    lossG.backward()
    opt_gen.step()


    if batch_idx == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} \
                      Loss D: {lossD:.4f}, loss G: {lossG:.4f}"
            )

            with torch.no_grad():
                fake = gen(fixed_noise).reshape(-1, 1, 28, 28)
                data = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                img_grid_real = torchvision.utils.make_grid(data, normalize=True)

                writer_fake.add_image(
                    "Mnist Fake Images", img_grid_fake, global_step=step
                )
                writer_real.add_image(
                    "Mnist Real Images", img_grid_real, global_step=step
                )
                step += 1




Epoch [0/50] Batch 0/1875                       Loss D: 0.6635, loss G: 0.7238
Epoch [1/50] Batch 0/1875                       Loss D: 0.2156, loss G: 2.2438
Epoch [2/50] Batch 0/1875                       Loss D: 0.1515, loss G: 2.6734
Epoch [3/50] Batch 0/1875                       Loss D: 0.2092, loss G: 2.8646
Epoch [4/50] Batch 0/1875                       Loss D: 0.1872, loss G: 3.2189
Epoch [5/50] Batch 0/1875                       Loss D: 0.2222, loss G: 2.9852
Epoch [6/50] Batch 0/1875                       Loss D: 0.1313, loss G: 2.7722
Epoch [7/50] Batch 0/1875                       Loss D: 0.1615, loss G: 2.9400
Epoch [8/50] Batch 0/1875                       Loss D: 0.1613, loss G: 3.6975
Epoch [9/50] Batch 0/1875                       Loss D: 0.0737, loss G: 3.5662
Epoch [10/50] Batch 0/1875                       Loss D: 0.1015, loss G: 2.8649
Epoch [11/50] Batch 0/1875                       Loss D: 0.1004, loss G: 3.1372
Epoch [12/50] Batch 0/1875                       L