<a href="https://colab.research.google.com/github/Sakin101/GANS/blob/main/SIMPLE_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch as th
import numpy as np
import pandas as pd
import torchvision
from torch.utils.data import DataLoader
import torch.nn as nn
from torch.autograd.variable import Variable
from torchvision import transforms, datasets
from torch.utils.tensorboard import SummaryWriter
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cpu'
if th.cuda.is_available():
    device = 'cuda'
print(device)

In [None]:
class Discriminator(nn.Module):
    """Two-layer MLP that scores an input vector with a value in (0, 1).

    Args:
        in_features: size of the last dimension of the input (the flattened
            image length).
    """

    def __init__(self, in_features):
        super().__init__()
        hidden = 128
        # Index layout 0: Linear, 1: ReLU, 2: Linear, 3: Sigmoid is kept so the
        # state_dict keys ("layer.0.*", "layer.2.*") stay the same.
        stages = [
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
            nn.Sigmoid(),
        ]
        self.layer = nn.Sequential(*stages)

    def forward(self, x):
        """Map x of shape (*, in_features) to sigmoid scores of shape (*, 1)."""
        scores = self.layer(x)
        return scores

class Generator(nn.Module):
    """Two-layer MLP that maps a latent vector to a flattened output in [-1, 1].

    Args:
        z_features: size of the latent (noise) vector.
        img_dimension: size of the flattened output vector.
    """

    def __init__(self, z_features, img_dimension):
        super().__init__()
        hidden = 256
        # Same index layout as before (0: Linear, 1: LeakyReLU, 2: Linear,
        # 3: Tanh) so state_dict keys are unchanged.
        self.layer = nn.Sequential(*[
            nn.Linear(z_features, hidden),
            nn.LeakyReLU(0.1),
            nn.Linear(hidden, img_dimension),
            nn.Tanh(),  # squashes outputs into [-1, 1]
        ])

    def forward(self, x):
        """Map latents of shape (*, z_features) to outputs of shape (*, img_dimension)."""
        generated = self.layer(x)
        return generated


In [None]:
from IPython.core.interactiveshell import dis
# ---------------------------------------------------------------------------
# Hyper-parameters
# ---------------------------------------------------------------------------
lr = 3e-4
z_dim = 32
img_dim = 28 * 28  # flattened 1x28x28 MNIST image
batch_size = 64
epoch = 60  # total number of training epochs

# BCE on the discriminator's sigmoid output is used for BOTH players below:
# the generator is trained by labelling its fakes as "real" (target 1).
discriminator_loss = nn.BCELoss()
# NOTE(review): unused -- the generator step below uses `discriminator_loss`.
generator_loss = nn.L1Loss()

disc = Discriminator(img_dim).to(device)
gen = Generator(z_dim, img_dim).to(device)

# ToTensor() gives pixels in [0, 1]; Normalize(0.5, 0.5) maps them to [-1, 1],
# the same range as the generator's Tanh output.
# Named `transform` (not `transforms`) so the imported torchvision module
# `transforms` is not shadowed.
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5,), (0.5,)),
])

# One latent batch kept fixed for the whole run and used only for the logged
# image grids, so grids from different epochs are directly comparable.
fixed_noise = th.randn(batch_size, z_dim).to(device)

dis_optim = th.optim.Adam(disc.parameters(), lr)
gen_optim = th.optim.Adam(gen.parameters(), lr)

# Named `dataset` (not `datasets`) to avoid shadowing the torchvision import.
dataset = torchvision.datasets.MNIST(root='datasets/', transform=transform, download=True)
# Shuffle so every epoch visits the training set in a different order.
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

step = 0  # TensorBoard global step; incremented once per logged grid
writer_fake = SummaryWriter("logs/fake")
writer_real = SummaryWriter("logs/real")

# NOTE(review): nothing below logs to `writer`; kept so the name still exists.
writer = SummaryWriter('runs/fashion_mnist_experiment_1')

for epoch_idx in range(epoch):
    for ix, (real, _) in enumerate(loader):
        real = real.view(-1, img_dim).to(device)
        # The final batch can be smaller than `batch_size`; keep the
        # hyper-parameter itself untouched.
        cur_batch_size = real.shape[0]

        # ---- Discriminator step: push D(real) -> 1 and D(G(z)) -> 0 ----
        noise = th.randn(cur_batch_size, z_dim).to(device)
        fake_img = gen(noise)
        # detach(): this loss must not backpropagate into the generator. The
        # detach also means the generator graph is untouched here, so no
        # retain_graph is needed for the generator step.
        disc_fake = disc(fake_img.detach())
        fake_loss = discriminator_loss(disc_fake, th.zeros_like(disc_fake))
        real_disc = disc(real)
        real_loss = discriminator_loss(real_disc, th.ones_like(real_disc))

        total_loss = (real_loss + fake_loss) / 2
        disc.zero_grad()
        total_loss.backward()
        dis_optim.step()

        # ---- Generator step: push the (just updated) D(G(z)) -> 1 ----
        output = disc(fake_img)
        gen_loss = discriminator_loss(output, th.ones_like(output))
        gen.zero_grad()
        gen_loss.backward()
        gen_optim.step()

        # Log once per epoch, on the first batch.
        if ix == 0:
            print(
                f"Epoch [{epoch_idx}/{epoch}] Batch {ix}/{len(loader)} "
                f"Loss D: {total_loss.item():.4f}, loss G: {gen_loss.item():.4f}"
            )

            with th.no_grad():
                fake = gen(fixed_noise).reshape(-1, 1, 28, 28)
                data = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                img_grid_real = torchvision.utils.make_grid(data, normalize=True)

                writer_fake.add_image(
                    "Mnist Fake Images", img_grid_fake, global_step=step
                )
                writer_real.add_image(
                    "Mnist Real Images", img_grid_real, global_step=step
                )
            step += 1

# Flush explicitly so every logged grid is on disk before TensorBoard is
# opened in the next cell.
writer_fake.flush()
writer_real.flush()



In [None]:
%load_ext tensorboard

%tensorboard --logdir /content/logs/fake
