In [None]:
#!pip install -r requirements_short.txt

In [1]:
%load_ext autoreload
%autoreload 2

import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import make_grid

from torch_tools.visualization import to_image
from utils import is_conditional
from visualization import interpolate
from loading import load_trained_from_dir, load_generator

from classifier_networks import VGG, vgg_layers

import numpy as np
from matplotlib import pyplot as plt
import json
import os
import collections

In [2]:
from utils import make_noise, one_hot

In [3]:
device = torch.device('cuda:0')

In [15]:
G_weights='./models/pretrained/generators/StyleGAN2/stylegan2-ffhq-config-f.pt'
gan_resolution = 1024
gan_output_channel = 3
shift_in_w = True

In [5]:
root_dir = './models/pretrained/deformators/StyleGAN2/'
args = json.load(open(os.path.join(root_dir, 'args.json')))
args['w_shift'] = shift_in_w
args['gan_resolution'] = gan_resolution

In [6]:
G = load_generator(args, G_weights)

In [7]:
result_dir = './models/pretrained/deformators/StyleGAN2/'
deformator, _ = load_trained_from_dir(result_dir,G.dim_shift,shift_in_w=shift_in_w)
deformator.eval()

LatentDeformator()

In [None]:
training_name = 'FACE_Attractive'
shift_predictor_lr = 1e-4
n_steps = 100000
batch_size = 16
noise_scale = 0.1
gamma = 0.5

In [9]:
resize_transform = transforms.Resize((classifier_input_size,classifier_input_size))

In [10]:
classifier_weight_file = 'models/classifiers/celebA_Attractive_vgg11_classifier.pt'
class_count = 2
classifier_input_size = 256
classifier_input_channel = 3

classifier_weights = torch.load(classifier_weight_file)
if isinstance(classifier_weights, collections.OrderedDict):
    classifier = VGG(vgg_layers,class_count)
    classifier.load_state_dict(classifier_weights)
else:
    classifier = classifier_weights
classifier = classifier.cuda()
classifier = classifier.eval()

In [11]:
class FCShiftPredictor(nn.Module):
    def __init__(self,input_dim,class_dim, inner_dim, output_dim):
        super(FCShiftPredictor, self).__init__()
        self.fc_direction = nn.Sequential(
            nn.Linear(input_dim+class_dim,inner_dim),
            nn.ReLU(),
            nn.Dropout(p=0.2),
            nn.Linear(inner_dim,inner_dim),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Linear(inner_dim,output_dim),
        )
        
    def forward(self, x,c):
        x_c = torch.cat((x,c),1)
        dir_ = self.fc_direction(x_c)
        return dir_

In [12]:
shift_model = FCShiftPredictor(deformator.out_dim,class_count,1024,deformator.input_dim)
shift_model.to(device).train()

FCShiftPredictor(
  (fc_direction): Sequential(
    (0): Linear(in_features=514, out_features=1024, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.2, inplace=False)
    (3): Linear(in_features=1024, out_features=1024, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.1, inplace=False)
    (6): Linear(in_features=1024, out_features=512, bias=True)
  )
)

In [13]:
shift_model_opt = torch.optim.Adam(shift_model.parameters(), lr=shift_predictor_lr)

In [18]:
criterion = nn.BCELoss()

In [24]:
for step in range(n_steps):
    shift_model.zero_grad()
    z = torch.randn([1,G.dim_z]).repeat(batch_size,1).to(device)
    z_noise = noise_scale * torch.randn([batch_size,G.dim_z]).to(device)
    z_perturb = z + z_noise

    target_random = torch.randn([batch_size, class_count]).to(device)
    target_classes = torch.argmax(target_random, 1, keepdim=True)
    y_target = torch.FloatTensor(batch_size, class_count).to(device)
    y_target.zero_()
    y_target.scatter_(1, target_classes, 1)

    dir_pred = shift_model(z_perturb,y_target)
    
    img_shift = G(z_perturb + deformator(dir_pred) )

    if gan_resolution != classifier_input_size:
        img_shift = resize_transform(img_shift)
    
    if gan_output_channel == 1 and args.classifier_input_channel > 1:
        img_shift = img_shift.repeat([1,args.classifier_input_channel,1,1])
    
    y_shift = classifier(img_shift)
    
    if isinstance(y_shift,tuple):
        y_shift = y_shift[0]
    
    y_out = torch.softmax(y_shift,1)

    dir_loss = criterion(y_out,y_target)
    scale_loss = torch.mean(torch.abs(dir_pred))
    loss =  dir_loss + gamma * scale_loss
    
    print("STEP {:08d} CLASS LOSS: {:1.8f}  SHIFT SIZE LOSS: {:1.8f}".format(step, dir_loss, scale_loss))
    
    loss.backward()
    shift_model_opt.step()
    
    if step % 10000 == 0:
        torch.save(shift_model, 'trained_scale_predictors/shift_model_{:1.3f}_{:08d}.pt'.format(training_name,scale_loss_ratio, step))


STEP 00000000 CLASS LOSS: 0.63509446  SHIFT SIZE LOSS: 0.09505896
STEP 00000001 CLASS LOSS: 0.42623204  SHIFT SIZE LOSS: 0.08562627
STEP 00000002 CLASS LOSS: 0.21854746  SHIFT SIZE LOSS: 0.07724383
STEP 00000003 CLASS LOSS: 4.03591299  SHIFT SIZE LOSS: 0.09064588
STEP 00000004 CLASS LOSS: 0.30183387  SHIFT SIZE LOSS: 0.09089753
STEP 00000005 CLASS LOSS: 2.31770849  SHIFT SIZE LOSS: 0.09474817
STEP 00000006 CLASS LOSS: 0.75165021  SHIFT SIZE LOSS: 0.08726131
STEP 00000007 CLASS LOSS: 2.06401825  SHIFT SIZE LOSS: 0.08369934
STEP 00000008 CLASS LOSS: 0.40603667  SHIFT SIZE LOSS: 0.08846195
STEP 00000009 CLASS LOSS: 0.47795206  SHIFT SIZE LOSS: 0.09654497
STEP 00000010 CLASS LOSS: 0.15636313  SHIFT SIZE LOSS: 0.09083900
STEP 00000011 CLASS LOSS: 0.57725739  SHIFT SIZE LOSS: 0.10022274
STEP 00000012 CLASS LOSS: 0.74153620  SHIFT SIZE LOSS: 0.08771026
STEP 00000013 CLASS LOSS: 0.12373655  SHIFT SIZE LOSS: 0.10007721
STEP 00000014 CLASS LOSS: 0.33946943  SHIFT SIZE LOSS: 0.09556222
STEP 00000

KeyboardInterrupt: 