In [None]:
!wget https://www.dropbox.com/s/rbajpdlh7efkdo1/male_female_face_images.zip
!unzip /content/male_female_face_images.zip

In [None]:
!pip install -q torch_snippets

In [None]:
from torch_snippets import *
import torchvision
import torch
from torchvision import transforms
import torchvision.utils as vutils 
import cv2, numpy as np, pandas as pd

device = 'cuda' if torch.cuda.is_available() else 'cpu'
face_cascade=cv2.CascadeClassifier(cv2.data.haarcascades + "haarcascade_frontalface_default.xml")
!mkdir cropped_faces
images = Glob('/content/females/*.jpg') + Glob('/content/males/*.jpg')

for i in range(len(images)):
  img = read(images[i],1)
  gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  faces = face_cascade.detectMultiScale(gray, 1.3, 5)
  for (x,y,w,h) in faces:
    imgs = img[y:(y+h),x:(x+w),:]
  cv2.imwrite('/cropped_faces/' + str(i) + '.jpg', cv2.cvtColor(imgs, cv2.COLOR_RGB2BGR))

transform = transforms.Compose([transforms.Resize(64), transforms.CenterCrop(64),transforms.ToTensor(), transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])])

In [None]:
class FacesData(Dataset):
  """Dataset of face images listed from ``folder``.

  ``folder`` is passed to ``Glob`` to list the image files; the listing is
  sorted so that an index always refers to the same file. Each item is the
  image after the module-level ``transform`` has been applied to it.
  """
  def __init__(self,folder):
    super().__init__()
    self.folder = folder
    self.images = sorted(Glob(folder))

  def __len__(self):
    """Return the number of image files that were found."""
    return len(self.images)

  def __getitem__(self, index):
    """Load the image at ``index`` and return it after ``transform``."""
    image_path = self.images[index]
    # The context manager closes the file handle once the pixels are read.
    # convert('RGB') guarantees three channels, which is what the 3-channel
    # normalisation in `transform` expects (grayscale / RGBA files would
    # otherwise fail there).
    with Image.open(image_path) as image_file:
      image = image_file.convert('RGB')
    return transform(image)

class Discriminator(nn.Module):
  """Convolutional real/fake classifier for 3-channel images.

  Four stride-2 4x4 convolutions (3 -> 64 -> 128 -> 256 -> 512 channels)
  each halve the spatial size; a final 4x4 convolution without padding and
  a sigmoid reduce the result to a single value in [0, 1]. For a 64x64
  input the output has shape (N, 1, 1, 1). No convolution has a bias, and
  every strided convolution except the first is followed by batch norm.
  """
  def __init__(self):
    super().__init__()
    width = 64
    # First stage has no batch norm.
    layers = [
        nn.Conv2d(3, width, kernel_size=4, stride=2, padding=1, bias=False),
        nn.LeakyReLU(0.2, inplace=True),
    ]
    # Three down-sampling stages, doubling the channel count each time.
    for factor in (1, 2, 4):
      in_ch, out_ch = width*factor, width*factor*2
      layers.extend([
          nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False),
          nn.BatchNorm2d(out_ch),
          nn.LeakyReLU(0.2, inplace=True),
      ])
    # Collapse the remaining 4x4 map into one score per image.
    layers.extend([
        nn.Conv2d(width*8, 1, kernel_size=4, stride=1, padding=0, bias=False),
        nn.Sigmoid(),
    ])
    self.model = nn.Sequential(*layers)

  def forward(self, input):
    """Return the sigmoid score produced by the layer stack for ``input``."""
    return self.model(input)

class Generator(nn.Module):
  """Transposed-convolution image generator.

  Takes a 100-channel input and applies five bias-free 4x4 transposed
  convolutions (100 -> 512 -> 256 -> 128 -> 64 -> 3 channels). The first
  one uses stride 1 without padding, the others stride 2 with padding 1.
  All but the last are followed by batch norm and ReLU; the last is followed
  by tanh, so outputs lie in [-1, 1]. For an input of shape (N, 100, 1, 1)
  the output has shape (N, 3, 64, 64).
  """
  def __init__(self):
    super().__init__()
    latent_channels, width = 100, 64
    # (in_channels, out_channels, stride, padding) of each normalised stage.
    stages = [
        (latent_channels, width*8, 1, 0),
        (width*8, width*4, 2, 1),
        (width*4, width*2, 2, 1),
        (width*2, width, 2, 1),
    ]
    layers = []
    for in_ch, out_ch, stride, padding in stages:
      layers.append(nn.ConvTranspose2d(in_ch, out_ch, 4, stride, padding, bias=False))
      layers.append(nn.BatchNorm2d(out_ch))
      layers.append(nn.ReLU(inplace=True))
    # Output stage: map to 3 image channels and squash with tanh.
    layers.append(nn.ConvTranspose2d(width, 3, 4, 2, 1, bias=False))
    layers.append(nn.Tanh())
    self.model = nn.Sequential(*layers)

  def forward(self,input):
    """Return the image tensor produced by the layer stack for ``input``."""
    return self.model(input)

In [None]:
def discriminator_train_step(real_data, fake_data):
  """Run one optimisation step of the discriminator and return its loss.

  Uses the module-level ``discriminator``, ``d_optimizer`` and ``loss``.
  Real samples are scored against a target of ones and fake samples against
  a target of zeros; both gradients are accumulated before a single
  optimizer step.

  Args:
    real_data: batch of real images.
    fake_data: batch of generated images. The training loop passes it
      detached, so this step does not back-propagate into the generator.

  Returns:
    Scalar tensor: loss on the real batch plus loss on the fake batch,
    detached from the graph (intended for logging).
  """
  d_optimizer.zero_grad()
  # flatten() rather than squeeze(): squeeze() turns a batch of one into a
  # 0-dim tensor, which no longer matches a target of shape (1,) and makes
  # BCELoss raise. The targets are built from the predictions so they always
  # share their shape, dtype and device.
  prediction_real = discriminator(real_data).flatten()
  error_real = loss(prediction_real, torch.ones_like(prediction_real))
  error_real.backward()
  prediction_fake = discriminator(fake_data).flatten()
  error_fake = loss(prediction_fake, torch.zeros_like(prediction_fake))
  error_fake.backward()
  d_optimizer.step()
  return (error_real + error_fake).detach()

def generator_train_step(fake_data):
  """Run one optimisation step of the generator and return its loss.

  Uses the module-level ``discriminator``, ``g_optimizer`` and ``loss``.
  The generated batch is scored by the discriminator against a target of
  ones, i.e. the generator is rewarded when its samples are judged real.

  Args:
    fake_data: batch of generated images, still attached to the generator's
      graph so the loss can back-propagate into it.

  Returns:
    Scalar loss tensor, detached from the graph (intended for logging).
  """
  g_optimizer.zero_grad()
  # The target is derived from this call's own predictions. Previously its
  # length came from the global `real_data`, which only worked while a
  # training loop happened to have defined that name with a matching batch size.
  # flatten() (not squeeze()) keeps a batch of one as shape (1,).
  prediction = discriminator(fake_data).flatten()
  error = loss(prediction, torch.ones_like(prediction))
  error.backward()
  g_optimizer.step()
  return error.detach()

# Instantiate both networks on the selected device (discriminator first,
# then generator).
discriminator = Discriminator().to(device)
generator = Generator().to(device)

# Binary cross-entropy between predicted probabilities and 0/1 targets.
loss = nn.BCELoss()

# Both networks are optimised with Adam using identical settings.
adam_settings = dict(lr=0.0002, betas=(0.5, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), **adam_settings)
g_optimizer = optim.Adam(generator.parameters(), **adam_settings)


# Training configuration.
N_EPOCHS = 25
LATENT_DIM = 100  # number of input channels expected by Generator
# NOTE(review): `dataloader` was used below without being defined anywhere in
# the notebook. It is built here from the cropped-faces folder created
# earlier; the batch size of 64 is an assumption — adjust as needed.
BATCH_SIZE = 64
dataset = FacesData('cropped_faces/*.jpg')
assert len(dataset) > 0, 'no cropped face images found in cropped_faces/'
dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Make sure both networks are in training mode even if this cell is re-run
# after the sampling cell switched the generator to eval mode.
discriminator.train()
generator.train()

log = Report(N_EPOCHS)
for epoch in range(N_EPOCHS):
  N = len(dataloader)
  for i, images in enumerate(tqdm(dataloader)):
    real_data = images.to(device)
    # Discriminator step: the fakes are detached so that no gradient flows
    # back into the generator.
    noise = torch.randn(len(real_data), LATENT_DIM, 1, 1, device=device)
    fake_data = generator(noise).detach()
    d_loss=discriminator_train_step(real_data, fake_data)
    # Generator step: a fresh batch of fakes, kept attached to the graph.
    noise = torch.randn(len(real_data), LATENT_DIM, 1, 1, device=device)
    fake_data = generator(noise)
    g_loss = generator_train_step(fake_data)
    log.record(epoch+(1+i)/N, d_loss=d_loss.item(), g_loss=g_loss.item(), end='\r')
  # Report the averages once per epoch, not after every batch.
  log.report_avgs(epoch+1)
log.plot_epochs(['d_loss','g_loss'])

# Draw 64 samples from the trained generator and display them as an 8x8 grid.
# eval() switches the generator's layers to inference behaviour.
generator.eval()
noise = torch.randn(64, 100, 1, 1, device=device)
# No gradients are needed for sampling, so skip building the autograd graph.
with torch.no_grad():
  sample_images = generator(noise).cpu()
# normalize=True rescales the grid's values for display.
grid = vutils.make_grid(sample_images,nrow=8,normalize=True)
# make_grid returns channels-first data; move channels last for plotting.
show(grid.permute(1,2,0), sz=10, title='Generated images')