In [8]:
from torchvision.transforms import Resize
from utils import *
from models.gan_load import make_style_gan2
from models.gan_load import make_proggan
from models.latent_deformator import LatentDeformator


from torchvision.utils import save_image
import numpy as np
import random
import torch.nn.functional as F
from torchvision.models import resnet18
import torch.nn as nn

import cv2

In [9]:
def set_seed(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's `random`, NumPy's global RNG and torch (CPU plus all CUDA
    devices), switches cuDNN to deterministic kernels with autotuning disabled,
    and exports PYTHONHASHSEED.

    Args:
        seed: integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False
    # Only affects interpreters started after this point; the running
    # process's string-hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)

In [10]:
# Device used by the StyleGAN2 branch of the model-selection cell (it calls `.to(device)`);
# the other cells move tensors with `.cuda()` directly.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## Configurations

In [11]:
# Generator backend; the model-selection cell branches on this value.
gan_type = 'prog-gan-sefa'
# Number of deformator input directions (passed as `input_dim`).
num_directions = 512
# Shift sweep passed to save_images: magnitude and number of steps.
shifts_r, shifts_count = 10, 5
# NOTE(review): `scale` is not referenced by any visible cell.
scale = Resize(64)
set_seed(12)

## Model Selection

In [12]:
# Build the generator `G` selected by `gan_type`.
if gan_type == 'StyleGAN2':
    # NOTE(review): `Generator` and `opt` are not defined in any visible cell (presumably they
    # come from `from utils import *` -- TODO confirm); this branch also needs `device`.
    config_gan = {"latent": 512, "n_mlp": 3,
                  "channel_multiplier": 8}
    G = Generator(size=64, style_dim=config_gan["latent"], n_mlp=config_gan["n_mlp"],
                  small=True, channel_multiplier=config_gan["channel_multiplier"])
    # Load onto the CPU first (as the PGGAN branch does) so GPU-saved checkpoints also load
    # on CPU-only machines; the model is moved to `device` right after.
    G.load_state_dict(torch.load(opt.pretrained_gen_path, map_location='cpu'))
    G.eval().to(device)
    for p in G.parameters():
        p.requires_grad_(False)
elif gan_type == 'StyleGAN2-Natural':
    G = make_style_gan2(1024,'models/pretrained/generators/StyleGAN2/stylegan2-ffhq-config-f.pt' , True)
elif gan_type == 'prog-gan':
    G = make_proggan('models/pretrained/generators/ProgGAN/100_celeb_hq_network-snapshot-010403.pth').cuda()
elif gan_type == 'prog-gan-sefa':
    from models.proggan_sefa import PGGANGenerator
    G = PGGANGenerator(resolution=1024)
    checkpoint = torch.load('models/pretrained/ProgGAN/pggan_celebahq1024.pth', map_location='cpu')
    # Prefer the 'generator_smooth' weights when the checkpoint has them, else 'generator'.
    if 'generator_smooth' in checkpoint:
        G.load_state_dict(checkpoint['generator_smooth'])
    else:
        G.load_state_dict(checkpoint['generator'])
    G = G.cuda()
    G.eval()
else:
    # Fail loudly: otherwise `G` is undefined, or silently remains the model from a previous run.
    raise ValueError("Unknown gan_type: " + repr(gan_type))

## Ours: load our trained (orthogonal) deformator

In [13]:
import os

# Root folder holding the trained deformators, classifier weights and saved image tensors.
# Set CELEBA_ANALYSIS_DIR to point elsewhere instead of editing this absolute path.
# os.path.join(..., '') guarantees a trailing separator, which later cells rely on
# when they append relative paths with `+`.
visualisation_data_path = os.path.join(
    os.environ.get('CELEBA_ANALYSIS_DIR', '/media/adarsh/DATA/CelebA-Analysis/'), '')

In [14]:
# Load our trained orthogonal deformator from its checkpoint.
pretrained_model = torch.load(visualisation_data_path + 'models/100007_model.pkl', map_location='cpu')
deformator_ours = LatentDeformator(shift_dim=G.z_space_dim, input_dim=num_directions,
                                   out_dim=G.z_space_dim, type='ortho', random_init=True)
deformator_ours.load_state_dict(pretrained_model['deformator'])
deformator_ours.cuda()
# Inference only: switch to eval mode, the same way the CF deformator is prepared.
deformator_ours.eval()

LatentDeformator()

## CF baseline: load the linear deformator from `cf_model.pkl`

In [20]:
pretrained_model = torch.load(visualisation_data_path + 'models/cf_model.pkl', map_location='cpu')
# CF deformator: a plain linear map with no bias (ours uses type='ortho').
deformator_cf = LatentDeformator(
    shift_dim=G.z_space_dim,
    input_dim=num_directions,
    out_dim=G.z_space_dim,
    type='linear',
    random_init=True,
    bias=False,
)
deformator_cf.load_state_dict(pretrained_model['deformator'])
# Module.cuda() and Module.eval() both return the module, so this chains and displays it.
deformator_cf.cuda().eval()

LatentDeformator(
  (linear): Linear(in_features=512, out_features=512, bias=False)
)

In [15]:
def save_images(z, shifts_r, shifts_count, direction, G, deformator, path):
    """Render `z` shifted along one deformator direction and save every frame as a JPEG.

    Shift values come from ``np.arange(-shifts_r, shifts_r, shifts_r / shifts_count)``:
    about ``2 * shifts_count`` evenly spaced values starting at ``-shifts_r``. The upper
    bound ``+shifts_r`` is excluded, as ``np.arange`` always does.

    Args:
        z: latent batch fed to ``G``; the generator must return a single image.
        shifts_r: maximum shift magnitude.
        shifts_count: number of steps between 0 and ``shifts_r``.
        direction: index of the deformator input direction to move along.
        G: generator, called as ``G(z + latent_shift)``.
        deformator: maps ``one_hot(deformator.input_dim, shift, direction)`` to a latent shift.
        path: output directory (created if missing). Files are named ``image_<i>.jpg``.

    Raises:
        ValueError: if the generator returns more than one image.
        IOError: if OpenCV fails to write a file.
    """
    import os
    import torch

    if path:
        # cv2.imwrite does not create directories; it just returns False.
        os.makedirs(path, exist_ok=True)
    shifts = np.arange(-shifts_r, shifts_r, shifts_r / shifts_count)
    # Inference only: stop autograd from recording the generator passes
    # (relevant if the deformator's or G's parameters require grad).
    with torch.no_grad():
        for i, shift in enumerate(shifts):
            latent_shift = deformator(one_hot(deformator.input_dim, shift, direction).cuda())
            shifted_image = G(z + latent_shift.cuda())
            images_row = postprocess(shifted_image)  # (N, H, W, C) uint8
            if images_row.shape[0] != 1:
                raise ValueError('save_images expects a single generated image, got '
                                 + str(images_row.shape[0]))
            out_file = os.path.join(path, 'image_' + str(i) + '.jpg')
            # Reverse the channel order for OpenCV (presumably RGB -> BGR).
            if not cv2.imwrite(out_file, images_row[0][:, :, ::-1]):
                raise IOError('cv2.imwrite failed for ' + out_file)

In [16]:
def postprocess(images):
    """Convert a generator output tensor into a displayable uint8 NumPy array.

    Maps values linearly from [-1, 1] onto [0, 255], rounds to nearest and clips,
    then moves channels last: (N, C, H, W) tensor -> (N, H, W, C) array.
    """
    arr = images.detach().cpu().numpy()
    scaled = (arr + 1) * 255 / 2 + 0.5  # +0.5 so the uint8 cast rounds instead of truncating
    return np.clip(scaled, 0, 255).astype(np.uint8).transpose(0, 2, 3, 1)

In [None]:
# path = visualisation_data_path + '/their images'
# codes = np.load('pggan_celebahq1024_latents.npy')
# codes = torch.from_numpy(codes).type(torch.FloatTensor).cuda()
# codes = codes.view(-1, 512)
# images = torch.FloatTensor().cuda()
# for i, each in enumerate(codes):
#     image = G(each.view(-1, 512))
#     images_row = postprocess(image)
#     images_row = images_row.reshape(1024,1024,3)
#     cv2.imwrite(path +'image_' + str(i) + '.jpg', images_row[:, :, ::-1])


In [17]:
# Saved latent codes for the PGGAN generator, as float32 on the GPU.
codes = torch.from_numpy(np.load('pggan_celebahq1024_latents.npy')).float().cuda()

In [18]:
# (num_codes, 512); the recorded run displayed torch.Size([10, 512]).
codes.shape

torch.Size([10, 512])

In [35]:
# Draw a fresh seed, print it so the run can be reproduced, then apply it.
seed = random.randint(0, 100000000)
print(seed)
set_seed(seed)

# z is taken from the saved latents (code #2) rather than drawn at random.
# z = torch.randn(1, G.z_space_dim).cuda()
z = codes[2].view(-1, 512)
direction = 70

# Sweep the same latent along `direction` with both deformators.
cf_image_path, ours_image_path = 'images_cf/', 'images_ours/'
save_images(z, shifts_r, shifts_count, direction, G, deformator=deformator_cf, path=cf_image_path)
save_images(z, shifts_r, shifts_count, direction, G, deformator=deformator_ours, path=ours_image_path)

54726729


## Load attribute predictors (ResNet-18 classifiers)

In [None]:
# Attribute of interest. The valid values match the classifier folders under
# pretrain/classifier: eyeglasses, male, pose, smiling, young.
# NOTE(review): `attr` is not referenced by any visible cell.
attr = 'pose' ## eyeglasses, male, pose, smiling, young

In [None]:
def get_resnet():
    """Build a randomly initialised ResNet-18 with its final fully-connected layer removed.

    Returns:
        An ``nn.Sequential`` of every child module except the last one. It outputs
        features rather than class logits.
    """
    children = list(resnet18().children())
    return nn.Sequential(*children[:-1])

In [None]:
class ClassifyModel(nn.Module):
    """ResNet-18 feature extractor followed by a linear classification head.

    Args:
        n_class: number of output logits (default 2).
    """

    def __init__(self, n_class=2):
        super().__init__()
        # These attribute names form the state_dict keys that get_classifier loads; keep them.
        self.backbone = get_resnet()
        self.extra_layer = nn.Linear(512, n_class)

    def forward(self, x):
        """Return unnormalised class logits of shape (batch, n_class)."""
        features = self.backbone(x).flatten(start_dim=1)
        return self.extra_layer(features)

In [None]:
def get_classifier(pretrain_path, device):
    """Create a ClassifyModel on `device` and load its weights from `pretrain_path`.

    Args:
        pretrain_path: path to a saved ClassifyModel state_dict.
        device: target device, e.g. 'cpu' or 'cuda'.

    Returns:
        The model with its weights loaded. Its train/eval mode is left unchanged,
        so callers call ``.eval()`` themselves.
    """
    classifier = ClassifyModel().to(device)
    # map_location lets GPU-saved checkpoints load on CPU-only machines, and avoids
    # routing the tensors through the device they were saved from.
    state_dict = torch.load(pretrain_path, map_location=device)
    classifier.load_state_dict(state_dict)
    return classifier

In [None]:
_classifier_root = os.path.join(visualisation_data_path, "pretrain/classifier")


def _load_attribute_predictor(attr_name):
    """Load the CPU classifier stored at <root>/<attr_name>/weight.pkl and put it in eval mode."""
    predictor = get_classifier(os.path.join(_classifier_root, attr_name, "weight.pkl"), 'cpu')
    predictor.eval()
    return predictor


pose_predictor = _load_attribute_predictor('pose')
gender_predictor = _load_attribute_predictor('male')
age_predictor = _load_attribute_predictor('young')
glasses_predictor = _load_attribute_predictor('eyeglasses')
smile_predictor = _load_attribute_predictor('smiling')

In [None]:
shift = 10 ## Set to 1 for a fair comparison in Closed Form Issues
direction = 499
total_images = 10
images_per_batch = 1
if total_images > len(codes):
    raise ValueError('total_images (' + str(total_images) + ') exceeds the number of saved codes ('
                     + str(len(codes)) + ')')
save_image_dir = os.path.join(visualisation_data_path, "images/images_new_direction_"+str(direction))
os.makedirs(save_image_dir, exist_ok=True)

# Inference only: stop autograd from recording the large generator passes
# (relevant if the deformator's or G's parameters require grad).
with torch.no_grad():
    # The latent shift depends only on `shift` and `direction`, so compute it once.
    latent_shift = deformator_cf(one_hot(deformator_cf.input_dim, shift, direction).cuda())
    for i in range(total_images // images_per_batch):
        print('Image_' + str(i))
        # z = torch.randn(images_per_batch, G.z_space_dim).cuda()
        # Next `images_per_batch` saved codes (a single code when images_per_batch == 1).
        z = codes[i * images_per_batch:(i + 1) * images_per_batch].view(-1, 512)
        # 4x average-pool downsampling before classification (e.g. 1024 -> 256 for PGGAN-1024).
        image = F.avg_pool2d(G(z), 4, 4)
        predictions_img = torch.softmax(gender_predictor(image.cpu()), dim=1)
        print('Image scores : ')
        print(predictions_img)
        image_shifted = F.avg_pool2d(G(z + latent_shift.cuda()), 4, 4)
        predictions_img_shift = torch.softmax(gender_predictor(image_shifted.cpu()), dim=1)
        print('Image shifted scores : ')
        print(predictions_img_shift)
        torch.save((image, image_shifted), os.path.join(save_image_dir, 'image_' + str(i) + '.pkl'))
        del image
        del image_shifted
    

In [None]:
# Softmax gender scores for the last shifted image produced by the scoring loop above.
predictions_img_shift

## Plot original vs. shifted image pairs

In [None]:

import os

import matplotlib.pyplot as plt
import torchvision

num_images = 10
image_dir = os.path.join(visualisation_data_path, 'images', 'images_new_direction_' + str(direction))
# Collect each (original, shifted) pair, then concatenate once; growing a tensor with
# torch.cat inside the loop copies everything on every iteration.
pairs = []
for i in range(num_images):
    original, shifted = torch.load(os.path.join(image_dir, 'image_' + str(i) + '.pkl'))
    pairs.append(torch.cat((original, shifted)).detach())
all_images = torch.cat(pairs)
del pairs, original, shifted

plt.figure(figsize=(60, 60))
# nrow=2 puts each original next to its shifted version.
grid = torchvision.utils.make_grid(all_images.clamp(min=-1, max=1), nrow=2, scale_each=True, normalize=True)
plt.imshow(grid.permute(1, 2, 0).cpu().numpy())

## Plot PNG images from the `images/` folder

In [None]:
import glob
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
%matplotlib inline

images = []
for img_path in glob.glob('images/*.png'):
    images.append(mpimg.imread(img_path))

plt.figure(figsize=(20,10))
columns = 10
for i, image in enumerate(images):
    plt.subplot(len(images) / columns + 1, columns, i + 1)
    plt.imshow(image)