# In [None]:
import numpy as np
import math
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn as nn
import torch
import os

def weights_init_normal(m):
  """Re-initialise a layer's parameters in place, keyed on its class name.

  Intended to be handed to ``nn.Module.apply`` so it visits every submodule.

  * Class name contains ``'Conv'``: weights are drawn from a normal
    distribution with mean 0.0 and std 0.02.
  * Class name contains ``'BatchNorm2d'``: weights are drawn from a normal
    distribution with mean 1.0 and std 0.02, and the bias is set to 0.0.
  * Anything else is left untouched.

  Args:
    m: The module to initialise.
  """
  init = torch.nn.init
  layer_kind = m.__class__.__name__

  if 'Conv' in layer_kind:
    init.normal_(m.weight.data, 0.0, 0.02)
    return

  if 'BatchNorm2d' in layer_kind:
    init.normal_(m.weight.data, 1.0, 0.02)
    init.constant_(m.bias.data, 0.0)

class Generator(nn.Module):
  """Convolutional generator that maps a latent vector to an image.

  A linear layer projects the latent vector to a ``128 x init_size x init_size``
  feature map. Two nearest-neighbour upsampling stages (each x2) followed by
  3x3 convolutions then produce an image whose side is ``4 * init_size``.
  The final ``Tanh`` keeps every output value in ``[-1, 1]``.

  Args:
    latent_dim: Size of the input noise vector (default 100).
    channels: Number of channels of the generated image (default 3).
    init_size: Side length of the feature map before upsampling (default 8,
      i.e. 32x32 output images).
  """

  def __init__(self, latent_dim=100, channels=3, init_size=8):
    super().__init__()
    self.init_size = init_size

    # Projects the noise vector to a flat 128 * init_size * init_size tensor.
    self.l = nn.Sequential(nn.Linear(latent_dim, 128 * init_size ** 2))

    # NOTE(review): the second positional argument of nn.BatchNorm2d is
    # ``eps``, so the value 0.8 below is the epsilon added to the variance,
    # not the momentum. It is spelled out as a keyword here and left
    # unchanged to preserve behaviour -- confirm whether momentum was meant.
    self.conv = nn.Sequential(
        nn.BatchNorm2d(128),
        nn.Upsample(scale_factor=2),
        nn.Conv2d(128, 128, 3, stride=1, padding=1),
        nn.BatchNorm2d(128, eps=0.8),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Upsample(scale_factor=2),
        nn.Conv2d(128, 64, 3, stride=1, padding=1),
        nn.BatchNorm2d(64, eps=0.8),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Conv2d(64, channels, 3, stride=1, padding=1),
        nn.Tanh())

  def forward(self, r):
    """Generate a batch of images from a batch of latent vectors.

    Args:
      r: Tensor of shape ``(batch, latent_dim)``.

    Returns:
      Tensor of shape ``(batch, channels, 4 * init_size, 4 * init_size)``.
    """
    projected = self.l(r)
    # Reshape the flat projection into a 128-channel square feature map.
    feature_map = projected.view(
        projected.shape[0], 128, self.init_size, self.init_size)
    return self.conv(feature_map)

class Discriminator(nn.Module):
  """Convolutional discriminator that scores images as real or generated.

  Four stride-2 convolution blocks downsample the image; the result is
  flattened and passed through a linear layer and a sigmoid, giving one
  value in ``[0, 1]`` per image.

  Args:
    channels: Number of channels of the input images (default 3).
    img_size: Side length of the (square) input images (default 32). It is
      used to size the final linear layer, so inputs must match it.
  """

  def __init__(self, channels=3, img_size=32):
    super().__init__()

    def discriminator_conv(in_feature_num, out_feature_num, bn=True):
      """Return the layers of one downsampling block as a list."""
      layers = [
          nn.Conv2d(in_feature_num, out_feature_num, 3, 2, 1),
          nn.LeakyReLU(0.2, inplace=True),
          nn.Dropout2d(0.25),
      ]
      if bn:
        # NOTE(review): 0.8 is BatchNorm2d's second positional argument,
        # i.e. ``eps`` (not momentum); kept as in the original code.
        layers.append(nn.BatchNorm2d(out_feature_num, eps=0.8))
      return layers

    self.network = nn.Sequential(
        *discriminator_conv(channels, 16, bn=False),
        *discriminator_conv(16, 32),
        *discriminator_conv(32, 64),
        *discriminator_conv(64, 128)
    )

    # A kernel-3 / stride-2 / padding-1 convolution maps a side of length n
    # to ceil(n / 2). Apply that once per block to get the final side length
    # and hence the flattened feature count (2x2 -> 512 for 32x32 inputs).
    ds_size = img_size
    for _ in range(4):
      ds_size = (ds_size + 1) // 2
    self.validation_layer = nn.Sequential(
        nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())

  # Defined at class level: nested inside __init__ it would be a discarded
  # local function and the module could not be called.
  def forward(self, img):
    """Score a batch of images.

    Args:
      img: Tensor of shape ``(batch, channels, img_size, img_size)``.

    Returns:
      Tensor of shape ``(batch, 1)`` with values in ``[0, 1]``.
    """
    features = self.network(img)
    flat = features.view(features.shape[0], -1)
    validity = self.validation_layer(flat)
    return validity
    

# ---------------------------------------------------------------------------
# Training script: builds the two networks, loads MNIST and alternates one
# generator step and one discriminator step per batch.
# ---------------------------------------------------------------------------

# Binary cross-entropy between the discriminator's output and the 1/0 labels.
adversarial_loss = torch.nn.BCELoss()

generator = Generator()
discriminator = Discriminator()

generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

# MNIST requires a root directory; the files are downloaded there if missing.
data_root = "./data/mnist"
os.makedirs(data_root, exist_ok=True)

data = DataLoader(
    datasets.MNIST(
        data_root,
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.Resize(32),
            # MNIST images are single-channel, but both networks work on
            # 3-channel images, so replicate the grey channel three times.
            transforms.Grayscale(num_output_channels=3),
            transforms.ToTensor(),
            # Scale pixels to [-1, 1] to match the generator's Tanh output.
            transforms.Normalize([0.5] * 3, [0.5] * 3),
        ]),
    ),
    batch_size=64,
    shuffle=True,
)

g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
Tensor = torch.FloatTensor
epochs = 3
latent_dim = 100  # must match the input size of the generator's linear layer

for epoch in range(epochs):
  for imgs, _ in data:
    batch_size = imgs.shape[0]

    # Target labels: 1 for real images, 0 for generated ones.
    valid = torch.ones(batch_size, 1)
    fake = torch.zeros(batch_size, 1)
    original_imgs = imgs.type(Tensor)

    # --- Generator step: try to make the discriminator output "valid". ---
    g_optimizer.zero_grad()

    noise = torch.randn(batch_size, latent_dim)
    gen_imgs = generator(noise)
    g_loss = adversarial_loss(discriminator(gen_imgs), valid)

    g_loss.backward()
    g_optimizer.step()

    # --- Discriminator step: separate real images from generated ones. ---
    d_optimizer.zero_grad()

    real_loss = adversarial_loss(discriminator(original_imgs), valid)
    # detach() keeps this loss from back-propagating into the generator.
    fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
    d_loss = (real_loss + fake_loss) / 2

    d_loss.backward()
    d_optimizer.step()



    

    
