In [0]:
# Mount Google Drive at /content/drive so later cells can read the training
# images and write results under that mount point.
from google.colab import drive
drive.mount('/content/drive')

In [0]:
cd "./drive/My Drive/data"

In [0]:
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable

In [0]:
import glob
import numpy as np

In [0]:
# Create floating-point tensors and module parameters as float64 by default.
# torch.set_default_tensor_type is deprecated; set_default_dtype is the
# supported way to select the default floating-point dtype.
torch.set_default_dtype(torch.float64)

In [0]:
def sample_from_dataset(batch_size,img_shape,data_dir=None):
  """Draw a random batch of images from disk and scale pixels to [-1, 1].

  Args:
    batch_size: number of images to draw. Files are drawn with replacement
      (the ``np.random.choice`` default), so a batch may contain duplicates.
    img_shape: sequence whose entries except the last are passed to
      ``PIL.Image.resize`` as the target size (PIL reads it as
      ``(width, height)``), e.g. ``(64, 64, 3)``. The last entry is ignored.
    data_dir: glob pattern (not a bare directory) selecting the image files,
      e.g. ``"/path/to/images/*.png"``.

  Returns:
    A float64 tensor of shape ``(batch_size, height, width, 3)`` (channels
    last) with values in [-1, 1].

  Raises:
    TypeError: if ``data_dir`` is None.
    ValueError: if the pattern matches no files.
  """
  if data_dir is None:
    raise TypeError("data_dir must be a glob pattern such as '/path/*.png', got None")
  # Sorted so that, for a fixed NumPy seed, the drawn files do not depend on
  # the arbitrary order in which the filesystem lists them.
  image_paths=sorted(glob.glob(data_dir))
  if not image_paths:
    raise ValueError("no files match pattern %r" % (data_dir,))
  target_size=tuple(img_shape[:-1])
  sample=[]
  for img_filename in np.random.choice(image_paths,batch_size):
    # The context manager closes the file handle once the pixels are loaded.
    with Image.open(img_filename) as img:
      # Force 3 channels so RGBA / grayscale / palette files stack into one
      # uniform array instead of producing a ragged or 4-channel batch.
      rgb=img.convert('RGB').resize(target_size)
      pixels=np.asarray(rgb,dtype=np.float64)
    # Map [0, 255] -> [-1, 1].
    sample.append(pixels/127.5-1)
  return torch.from_numpy(np.asarray(sample))

In [0]:
def weights_init(m):
  """Re-initialise a module in place; intended for use with ``Module.apply``.

  Modules whose class name contains 'Conv' get weights drawn from N(0, 0.02).
  Modules whose class name contains 'BatchNorm' get weights drawn from
  N(1, 0.02) and their bias set to zero. Any other module is left untouched.
  """
  layer_kind=type(m).__name__
  if 'Conv' in layer_kind:
    m.weight.data.normal_(0.0,0.02)
    return
  if 'BatchNorm' in layer_kind:
    m.weight.data.normal_(1.0,0.02)
    m.bias.data.fill_(0)

In [0]:
class G(nn.Module):
  """Generator: maps a (N, 100, 1, 1) noise tensor to a (N, 3, 64, 64) image.

  A first transposed convolution expands the 1x1 input to 4x4; four further
  stride-2 transposed convolutions double the spatial size each time
  (4 -> 8 -> 16 -> 32 -> 64). Every transposed convolution except the last is
  followed by BatchNorm + ReLU; the last is followed by Tanh, so outputs lie
  in [-1, 1].
  """
  def __init__(self):
    super().__init__()
    # Stem: 100 latent channels -> 512 feature maps of size 4x4.
    layers=[nn.ConvTranspose2d(100,512,4,1,0,bias=False)]
    # Each stage normalises/activates the previous output, then upsamples x2.
    for in_ch,out_ch in ((512,256),(256,128),(128,64),(64,3)):
      layers+=[
          nn.BatchNorm2d(in_ch),
          nn.ReLU(True),
          nn.ConvTranspose2d(in_ch,out_ch,4,2,1,bias=False),
      ]
    layers.append(nn.Tanh())
    self.main=nn.Sequential(*layers)

  def forward(self,input):
    """Run the noise batch through the layer stack and return the images."""
    return self.main(input)


In [0]:
# Build the generator and re-initialise its layers with weights_init;
# Module.apply returns the module itself, so the result can be bound directly.
netG=G().apply(weights_init)
netG

In [0]:
class D(nn.Module):
  """Discriminator: maps a (N, 3, 64, 64) image batch to N scores in [0, 1].

  Four stride-2 convolutions halve the spatial size each time
  (64 -> 32 -> 16 -> 8 -> 4), each followed by LeakyReLU(0.2) (with BatchNorm
  in between for all but the first). A final 4x4 convolution reduces the
  4x4 map to a single value, which Sigmoid squashes into [0, 1].
  """
  def __init__(self):
    super().__init__()
    # First stage has no BatchNorm.
    layers=[
        nn.Conv2d(3,64,4,2,1,bias=False),
        nn.LeakyReLU(0.2,inplace=True),
    ]
    # Remaining downsampling stages: conv -> batch norm -> leaky ReLU.
    for in_ch,out_ch in ((64,128),(128,256),(256,512)):
      layers+=[
          nn.Conv2d(in_ch,out_ch,4,2,1,bias=False),
          nn.BatchNorm2d(out_ch),
          nn.LeakyReLU(0.2,inplace=True),
      ]
    # Head: collapse the 4x4 feature map to one score per image.
    layers+=[nn.Conv2d(512,1,4,1,0,bias=False),nn.Sigmoid()]
    self.main=nn.Sequential(*layers)

  def forward(self,input):
    """Return a flat 1-D tensor holding one score per input image."""
    return self.main(input).view(-1)

In [0]:
# Build the discriminator and re-initialise its layers with weights_init;
# Module.apply returns the module itself, so the result can be bound directly.
netD=D().apply(weights_init)
netD

In [0]:
# Binary cross-entropy between the discriminator's scores and the 0/1 targets.
criterion=nn.BCELoss()
# One Adam optimizer per network, both with the same hyperparameters
# (discriminator first, then generator).
optimizerD,optimizerG=(
    optim.Adam(net.parameters(),lr=0.00015,betas=(0.5,0.999))
    for net in (netD,netG)
)

In [0]:
# ---- Training configuration ----
NUM_ITERATIONS=10000  # each iteration trains on one randomly sampled batch
BATCH_SIZE=64
LATENT_DIM=100        # channel count of the generator's (N, 100, 1, 1) noise input
# (width, height, channels); sample_from_dataset only uses the first two
# entries for resizing, the last one is informational.
IMG_SHAPE=(64,64,3)
# Absolute paths: the working directory was changed to the data folder earlier,
# so paths relative to /content would no longer resolve.
DATA_GLOB="/content/drive/My Drive/data/*.png"
RESULTS_DIR="/content/drive/My Drive/results"

for epoch in range(NUM_ITERATIONS):
  # ---- Discriminator update ----
  netD.zero_grad()
  # Real batch: sample_from_dataset returns channels-last images, Conv2d
  # expects channels-first, hence the permute.
  data=sample_from_dataset(BATCH_SIZE,IMG_SHAPE,data_dir=DATA_GLOB)
  real=data.permute(0,3,1,2)
  batch_size=real.size(0)
  real_label=torch.ones(batch_size)
  fake_label=torch.zeros(batch_size)
  errD_real=criterion(netD(real),real_label)
  # Fake batch: detach so this backward pass does not reach the generator.
  noise=torch.randn(batch_size,LATENT_DIM,1,1)
  fake=netG(noise)
  errD_fake=criterion(netD(fake.detach()),fake_label)
  # Backpropagate the combined real + fake error through the discriminator.
  errD=errD_real + errD_fake
  errD.backward()
  optimizerD.step()

  # ---- Generator update ----
  # The generator is rewarded when the discriminator labels its fakes as real.
  netG.zero_grad()
  errG=criterion(netD(fake),real_label)
  errG.backward()
  optimizerG.step()

  # ---- Logging, checkpoints and sample images ----
  if epoch%100==0:
    print('%d epochs done' % epoch)
  if epoch%200==0:
    print('[%d/%d]Loss_D: %.4f Loss_G: %.4f' % (epoch,NUM_ITERATIONS,errD.item(),errG.item()))
  if epoch%1000==0:
    torch.save(netG.state_dict(),"%s/GENERATOR_%03d.pth" % (RESULTS_DIR,epoch))
    torch.save(netD.state_dict(),"%s/DISCRIMINATOR_%03d.pth" % (RESULTS_DIR,epoch))
  if epoch%500==0:
    vutils.save_image(real,'%s/real_samples.png' % RESULTS_DIR,normalize=True)
    # Re-run the just-updated generator on the same noise; no autograd graph
    # is needed for a preview image.
    with torch.no_grad():
      fake=netG(noise)
    vutils.save_image(fake,'%s/reFake_samples_epoch_%03d.png' % (RESULTS_DIR,epoch),normalize=True)

In [0]:
# Rebuild the generator and restore trained weights for inference.
# No weights_init here: load_state_dict overwrites every parameter anyway.
model=G()
# Absolute path (the working directory is the data folder, so a path relative
# to /content would not resolve). The file name follows the training cell's
# "GENERATOR_%03d.pth" checkpoint pattern.
model.load_state_dict(torch.load("/content/drive/My Drive/results/GENERATOR_3000.pth"))
# Switch BatchNorm layers to their stored running statistics.
model.eval()

In [0]:
# Sample 64 latent vectors, run them through the restored generator and save
# the resulting images as one grid.
noise=torch.randn(64,100,1,1)
# Inference only: skip building the autograd graph.
with torch.no_grad():
  fake=model(noise)
# Absolute path so the file lands in the results folder regardless of the
# current working directory.
vutils.save_image(fake,'%s/Check.png' % ("/content/drive/My Drive/results"),normalize=True)