In [3]:
!pip install torch
!pip install torchvision

Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 23.7/23.7 MB 53.9 MB/s eta 0:00:00
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 823.6/823.6 kB 59.9 MB/s eta 0:00:00
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 14.1/14.1 MB 70.5 MB/s eta 0:00:00
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch)
  Downloading nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ [output truncated]

In [9]:
import  torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter


In [18]:
class Discriminator(nn.Module):
    """Two-layer MLP that scores flattened images with a probability of being real.

    Args:
        img_dim: length of each flattened input image vector.

    ``forward`` maps inputs of shape ``(..., img_dim)`` to ``(..., 1)`` scores
    in (0, 1) (the last layer is a Sigmoid).
    """

    def __init__(self, img_dim):
        super().__init__()
        hidden = 128
        layers = [
            nn.Linear(img_dim, hidden),
            nn.LeakyReLU(0.1),
            nn.Linear(hidden, 1),
            # Sigmoid output is a probability, which is what nn.BCELoss expects.
            nn.Sigmoid(),
        ]
        # The attribute name `disc` determines the state_dict keys
        # (disc.0.weight, disc.2.bias, ...), so it is kept stable for checkpoints.
        self.disc = nn.Sequential(*layers)

    def forward(self, x):
        """Return the real-vs-fake probability for each row of ``x``."""
        return self.disc(x)

In [21]:
class Generator(nn.Module):
    """Two-layer MLP that maps a latent noise vector to a flattened image.

    Args:
        z_dim: length of the input noise vector.
        img_dim: length of the flattened output image.

    ``forward`` maps ``(..., z_dim)`` to ``(..., img_dim)``; the final Tanh keeps
    outputs in [-1, 1].
    """

    def __init__(self, z_dim, img_dim):
        super().__init__()
        hidden = 256
        # Keep the attribute name `gen` and the layer order: they define the
        # state_dict keys (gen.0.weight, gen.2.bias, ...).
        self.gen = nn.Sequential(
            nn.Linear(z_dim, hidden), nn.LeakyReLU(0.1),
            nn.Linear(hidden, img_dim), nn.Tanh(),
        )

    def forward(self, x):
        """Generate one flattened image per row of the noise batch ``x``."""
        return self.gen(x)

In [16]:
# Run configuration: every knob a reader might want to tune lives here.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
learning_rate = 3e-4        # Adam learning rate, shared by both networks
z_dim = 64                  # length of the generator's input noise vector
image_dim = 28 * 28 * 1     # flattened MNIST image: 28x28 pixels, 1 channel
batch_size = 64
num_epochs = 100
seed = 42

# Seed torch's global RNGs (all devices) so weight init, sampled noise and
# DataLoader shuffling repeat across runs. GPU kernels may still be
# nondeterministic, so CUDA runs are not guaranteed bit-identical.
torch.manual_seed(seed)

In [23]:
# Build both networks on the configured device.
disc = Discriminator(image_dim).to(device)
gen = Generator(z_dim, image_dim).to(device)

# One fixed latent batch, reused at every logging step so TensorBoard shows how
# the *same* noise vectors are rendered as training progresses.
fixed_noise = torch.randn(batch_size, z_dim, device=device)

# NOTE: the MNIST dataset cell reads this pipeline under the name `transforms`,
# which shadows the `torchvision.transforms` module. Referencing the module via
# `torchvision.transforms` keeps this cell re-runnable (the old
# `transforms.Compose` call failed on a second run because `transforms` was
# already a Compose object).
transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    # (x - 0.5) / 0.5 maps ToTensor's [0, 1] pixels to [-1, 1], the range of the
    # generator's Tanh output.
    torchvision.transforms.Normalize((0.5,), (0.5,)),
])

In [26]:
# --- Data: MNIST (downloaded into ./dataset/ on first run) ---
dataset = datasets.MNIST(
    root="dataset/",
    transform=transforms,  # the Compose pipeline bound to `transforms` in the previous cell
    download=True,
)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# --- Optimisation: one Adam optimizer per network, created disc first ---
opt_disc, opt_gen = (
    optim.Adam(net.parameters(), lr=learning_rate) for net in (disc, gen)
)
# The discriminator ends in a Sigmoid, so plain binary cross-entropy on probabilities.
criterion = nn.BCELoss()

# --- Logging: separate TensorBoard runs for generated and real image grids ---
writer_fake = SummaryWriter(f"runs/GAN_MNIST/fake")
writer_real = SummaryWriter(f"runs/GAN_MNIST/real")
step = 0  # TensorBoard global_step, advanced once per logged epoch

In [None]:
# Train the GAN. Each batch does one discriminator update followed by one
# generator update. On the first batch of every epoch the losses are printed and
# fake/real image grids are written to TensorBoard.
for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(loader):
        # Flatten each image to an image_dim-long vector for the MLP models.
        real = real.view(-1, image_dim).to(device)
        # Local name on purpose: reassigning the global `batch_size` here would
        # leak the (smaller) last-batch size into the config used by other cells.
        cur_batch_size = real.shape[0]

        # --- Train discriminator: real -> 1, fake -> 0 ---
        noise = torch.randn(cur_batch_size, z_dim, device=device)
        fake = gen(noise)
        disc_real = disc(real).view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        # detach(): the discriminator loss must not backprop into the generator.
        # This also leaves fake's generator graph intact for the generator step
        # below, so retain_graph=True is no longer needed.
        disc_fake = disc(fake.detach()).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real + lossD_fake) / 2
        opt_disc.zero_grad()
        lossD.backward()
        opt_disc.step()

        # --- Train generator: make the (just-updated) discriminator output 1 on fakes ---
        # lossG.backward() also writes grads into disc's parameters; they are
        # cleared by opt_disc.zero_grad() before the next discriminator backward.
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        opt_gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        if batch_idx == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] "
                f"Loss D: {lossD.item():.4f}, Loss G: {lossG.item():.4f}"
            )
            with torch.no_grad():
                # make_grid expects (B, C, H, W); both tensors are flattened here.
                fake = gen(fixed_noise).reshape(-1, 1, 28, 28)
                data = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                img_grid_real = torchvision.utils.make_grid(data, normalize=True)

                writer_fake.add_image(
                    "MNIST fake images", img_grid_fake, global_step=step
                )
                writer_real.add_image(
                    "MNIST real images", img_grid_real, global_step=step
                )
            step += 1



Epochs[0\ 100]/Loss D:0.4537,Loss G:0.7256
Epochs[1\ 100]/Loss D:0.4372,Loss G:0.9986
Epochs[2\ 100]/Loss D:0.4139,Loss G:1.4473
Epochs[3\ 100]/Loss D:0.8527,Loss G:0.6210
Epochs[4\ 100]/Loss D:0.4868,Loss G:1.2043
Epochs[5\ 100]/Loss D:0.6813,Loss G:0.8012
Epochs[6\ 100]/Loss D:0.5111,Loss G:0.9842
Epochs[7\ 100]/Loss D:0.7851,Loss G:0.8952
Epochs[8\ 100]/Loss D:0.4573,Loss G:1.2860
Epochs[9\ 100]/Loss D:0.4980,Loss G:1.1946
Epochs[10\ 100]/Loss D:0.6303,Loss G:1.1188
Epochs[11\ 100]/Loss D:0.4758,Loss G:1.1446
Epochs[12\ 100]/Loss D:0.6851,Loss G:0.7497
Epochs[13\ 100]/Loss D:0.4660,Loss G:1.2919
Epochs[14\ 100]/Loss D:0.4302,Loss G:1.1383
Epochs[15\ 100]/Loss D:0.8037,Loss G:0.9514
Epochs[16\ 100]/Loss D:0.7696,Loss G:0.7004
Epochs[17\ 100]/Loss D:0.6603,Loss G:0.9915
Epochs[18\ 100]/Loss D:0.7721,Loss G:1.1201
Epochs[19\ 100]/Loss D:0.6476,Loss G:0.9947
Epochs[20\ 100]/Loss D:0.7904,Loss G:0.9689
Epochs[21\ 100]/Loss D:0.6429,Loss G:0.9827
Epochs[22\ 100]/Loss D:0.4833,Loss G:1.157