In [53]:
import torch
import torch.nn as nn

import torchvision

import os
import gdown                             
import zipfile
import matplotlib.pyplot as plt
%matplotlib inline

# import shutil
# shutil.rmtree('./data')

In [54]:
def get_device():
  """Return the CUDA device when one is available, otherwise the CPU device."""
  return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Named differently from the helper so the helper stays callable: assigning the
# result to the helper's own name made `device()` a torch.device afterwards.
device = get_device()

def todevice(data_model, device):
  """Move `data_model` (or every item of a nested list/tuple of them) to `device`.

  Lists and tuples come back as lists; any other object must provide `.to()`,
  which is called with non_blocking=True.
  """
  if not isinstance(data_model, (list, tuple)):
    return data_model.to(device, non_blocking=True)
  moved = []
  for element in data_model:
    moved.append(todevice(element, device))
  return moved

class DataDeviceLoader:
  """Wrap an iterable of batches so each batch is moved to `device` when yielded.

  Batches are transferred lazily, one at a time, through `todevice`; the
  wrapped object's length is reported unchanged.
  """
  def __init__(self, dataloader, device):
    self.dataloader, self.device = dataloader, device

  def __iter__(self):
    yield from (todevice(batch, self.device) for batch in self.dataloader)

  def __len__(self):
    return len(self.dataloader)

In [55]:
dataset_root = 'data/celeba'
dataset_folder = f'{dataset_root}/img_align_celeba'
download_path = f'{dataset_root}/img_align_celeba.zip'
url = 'https://drive.google.com/uc?id=1cNIac61PSA_LqDFYFUeyaQYekYPc75NH'

# makedirs also creates dataset_root; exist_ok keeps the cell re-runnable.
os.makedirs(dataset_folder, exist_ok=True)

# Skip the ~1.4 GB download when the archive is already on disk.
# NOTE(review): an interrupted download leaves a partial zip here; delete it to retry.
if not os.path.exists(download_path):
  gdown.download(url, download_path, quiet=False)

# Only extract into an empty target folder so re-runs don't redo the extraction.
if not os.listdir(dataset_folder):
  with zipfile.ZipFile(download_path, 'r') as ziphandler:
    ziphandler.extractall(dataset_folder)

Downloading...
From: https://drive.google.com/uc?id=1cNIac61PSA_LqDFYFUeyaQYekYPc75NH
To: /content/data/celeba/img_align_celeba.zip
100%|██████████| 1.44G/1.44G [00:08<00:00, 178MB/s]


In [None]:
class Discriminator(nn.Module):
  """Convolutional classifier producing one raw score per class for each image.

  Args:
    inputchannels: channels of the input images (3 for RGB).
    imagesize: base number of hidden feature maps. Despite the name this is a
      channel width, not a spatial size; hidden widths are 1x, 2x, 4x of it.
    classes: number of output scores per image.

  With no padding, a 64x64 input shrinks 64 -> 31 -> 14 -> 4 -> 1, so
  `forward` returns a (batch, classes) tensor for 64x64 images.
  """
  def __init__(self, inputchannels=3, imagesize=64, classes=10):
    super().__init__()
    self.model = nn.Sequential(
        self.block(inputchannels, imagesize),
        self.block(imagesize, imagesize*2),
        # stride 3 takes the 14x14 map (for 64x64 inputs) down to 4x4.
        self.block(imagesize*2, imagesize*4, stride=3),
        self.block(imagesize*4, classes, lastlayer=True),
    )

  def block(self, inputchannels, outputchannels, kernelsize=4, stride=2, lastlayer=False):
    """Build one unpadded Conv2d stage.

    Hidden stages are Conv2d -> BatchNorm2d -> LeakyReLU(0.2). The last stage is
    a bare Conv2d, so the network outputs unnormalized scores.
    """
    conv = nn.Conv2d(inputchannels, outputchannels, kernelsize, stride)
    if lastlayer:
      return nn.Sequential(conv)
    return nn.Sequential(
        conv,
        nn.BatchNorm2d(outputchannels),
        # The class is `LeakyReLU`; `nn.leakyReLU` does not exist (AttributeError).
        nn.LeakyReLU(0.2, inplace=True),
    )

  def forward(self, images):
    """Score a batch of images; returns (batch, classes) for 64x64 inputs."""
    predictions = self.model(images)
    # Flatten everything after the batch dim (the 1x1 spatial dims for 64x64 inputs).
    return predictions.view(predictions.size(0), -1)



class Generator(nn.Module):
  """Transposed-convolution generator mapping latent vectors to RGB images.

  Args:
    latentsize: number of values in each latent noise vector.
    imagesize: base number of hidden feature maps. Despite the name this is a
      channel width, not a spatial size; hidden widths are 8x, 4x, 2x, 1x of it.
    RGBchannels: channels of the generated image.

  With no padding, the spatial size grows 1 -> 3 -> 7 -> 15 -> 31 -> 64, and the
  final Tanh bounds pixel values to [-1, 1].
  """
  def __init__(self, latentsize=100, imagesize=64, RGBchannels=3):
    super().__init__()
    # forward() reshapes its input with this, so it must be stored on the module.
    self.latentsize = latentsize
    self.model = nn.Sequential(
        self.block(latentsize, imagesize*8),
        self.block(imagesize*8, imagesize*4),
        self.block(imagesize*4, imagesize*2),
        self.block(imagesize*2, imagesize),
        # Input width must equal the previous block's output (imagesize). Kernel 4
        # here (31 -> 64) yields exactly 64x64 images instead of 65x65.
        self.block(imagesize, RGBchannels, kernelsize=4, lastlayer=True),
    )

  def block(self, inputchannels, outputchannels, kernelsize=3, stride=2, lastlayer=False):
    """Build one unpadded ConvTranspose2d stage.

    Hidden stages are ConvTranspose2d -> BatchNorm2d -> ReLU; the last stage is
    ConvTranspose2d -> Tanh.
    """
    deconv = nn.ConvTranspose2d(inputchannels, outputchannels, kernelsize, stride)
    if lastlayer:
      return nn.Sequential(deconv, nn.Tanh())
    return nn.Sequential(deconv, nn.BatchNorm2d(outputchannels), nn.ReLU(True))

  def forward(self, latentinput):
    """Generate images from noise holding `latentsize` values per sample, e.g. (batch, latentsize)."""
    # Treat each latent vector as a latentsize-channel 1x1 feature map.
    noisevector = latentinput.view(latentinput.size(0), self.latentsize, 1, 1)
    return self.model(noisevector)

In [None]:
batchsize = 128
latentsize = 100

# NOTE(review): torch.load unpickles the checkpoint; only load files you trust.
generator = Generator(latentsize=latentsize).to(device)
generator.load_state_dict(torch.load('Gen_CelebA.pth', map_location=device)['generator'])

discriminator = Discriminator(classes=40).to(device)
discriminator.load_state_dict(torch.load('Dis_CelebA.pth', map_location=device)['discriminator'])

# Both networks are used as fixed, pretrained models. eval() makes BatchNorm use
# its stored running statistics rather than per-batch statistics, and it stops
# forward passes from overwriting those statistics. Freezing the weights stops
# backward() from accumulating gradients into them; only the latent noise is
# optimized below.
generator.eval().requires_grad_(False)
discriminator.eval().requires_grad_(False)

# The later cells call optimizer.zero_grad() but never step it.
optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.001, betas=(0.5, 0.999))

In [None]:
# 1. Optimize the noise vectors so the generated images show the target feature,
#    without trying to disentangle other features that may be entangled with it.

features = ["5oClockShadow", "ArchedEyebrows", "Attractive", "BagsUnderEyes", "Bald", "Bangs",
"BigLips", "BigNose", "BlackHair", "BlondHair", "Blurry", "BrownHair", "BushyEyebrows", "Chubby",
"DoubleChin", "Eyeglasses", "Goatee", "GrayHair", "HeavyMakeup", "HighCheekbones", "Male", 
"MouthSlightlyOpen", "Mustache", "NarrowEyes", "NoBeard", "OvalFace", "PaleSkin", "PointyNose", 
"RecedingHairline", "RosyCheeks", "Sideburn", "Smiling", "StraightHair", "WavyHair", "WearingEarrings", 
"WearingHat", "WearingLipstick", "WearingNecklace", "WearingNecktie", "Young"]

# list.index is a method call; `features.index['Young']` raised TypeError.
targetfeatureindex = features.index('Young')

gradsteps = 10
stepsize = 1 / gradsteps   # total movement along the gradient scales like 1 over the run

history = []
latentinputnoise = torch.randn(batchsize, latentsize, device=device).requires_grad_()
for i in range(gradsteps):
  optimizer.zero_grad()
  fake = generator(latentinputnoise)
  # Detached copy: keeping `fake` itself would keep every step's autograd graph alive.
  history += [fake.detach()]
  predictions = discriminator(fake)[:, targetfeatureindex].mean()
  predictions.backward()
  # Gradient ascent on the noise. optimizer.zero_grad() does not own this tensor,
  # so its gradient is cleared by hand to avoid summing gradients across steps.
  with torch.no_grad():
    latentinputnoise += stepsize * latentinputnoise.grad
  latentinputnoise.grad.zero_()


In [None]:
# 2. Optimize the noise vectors so the generated images show the target feature,
#    while penalizing changes in every other feature to keep entangled features fixed.

features = ["5oClockShadow", "ArchedEyebrows", "Attractive", "BagsUnderEyes", "Bald", "Bangs",
"BigLips", "BigNose", "BlackHair", "BlondHair", "Blurry", "BrownHair", "BushyEyebrows", "Chubby",
"DoubleChin", "Eyeglasses", "Goatee", "GrayHair", "HeavyMakeup", "HighCheekbones", "Male", 
"MouthSlightlyOpen", "Mustache", "NarrowEyes", "NoBeard", "OvalFace", "PaleSkin", "PointyNose", 
"RecedingHairline", "RosyCheeks", "Sideburn", "Smiling", "StraightHair", "WavyHair", "WearingEarrings", 
"WearingHat", "WearingLipstick", "WearingNecklace", "WearingNecktie", "Young"]

# list.index is a method call; `features.index['Young']` raised TypeError.
targetfeatureindex = features.index('Young')
# Column indices of every feature except the target, in their original order.
otherfeatureindexes = [index for index in range(len(features)) if index != targetfeatureindex]

gradsteps = 10
stepsize = 1 / gradsteps   # step size on the noise per gradient step
penaltyweight = 0.1        # weight of the "keep other features unchanged" penalty

history = []
latentinputnoise = torch.randn(batchsize, latentsize, device=device).requires_grad_()
with torch.no_grad():
  # Scores of the untouched images; the penalty measures drift away from these.
  startingpredictions = discriminator(generator(latentinputnoise))

for i in range(gradsteps):
  optimizer.zero_grad()
  fake = generator(latentinputnoise)
  # Detached copy: keeping `fake` itself would keep every step's autograd graph alive.
  history += [fake.detach()]

  # One discriminator pass per step, shared by the target score and the penalty.
  scores = discriminator(fake)
  targetfeature_score = scores[:, targetfeatureindex].mean()
  nontargets_losses = startingpredictions[:, otherfeatureindexes] - scores[:, otherfeatureindexes]
  nontargets_loss = torch.norm(nontargets_losses, dim=1).mean() * penaltyweight
  final_score = targetfeature_score - nontargets_loss

  final_score.backward()
  # Gradient ascent on the noise. optimizer.zero_grad() does not own this tensor,
  # so its gradient is cleared by hand to avoid summing gradients across steps.
  with torch.no_grad():
    latentinputnoise += stepsize * latentinputnoise.grad
  latentinputnoise.grad.zero_()

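Further reading:

- Mirza & Osindero, *Conditional Generative Adversarial Nets* (2014): https://arxiv.org/abs/1411.1784

- Shen et al., *Interpreting the Latent Space of GANs for Semantic Face Editing* (InterFaceGAN): https://arxiv.org/abs/1907.10786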