In [None]:
from torchvision.transforms import Resize
from utils import *
from models.latentdiscovery.utils import load_generator
from models.latentdiscovery.latent_deformator import LatentDeformator
from torchvision.utils import save_image
import numpy as np
import random
import torch.nn.functional as F
from torchvision.models import resnet18
import torch.nn as nn
import matplotlib.pylab as plt
import torchvision
import cv2
from IPython import display
%matplotlib inline

In [None]:
def set_seed(seed):
    """Seed every RNG used by this notebook and force deterministic cuDNN.

    Seeds Python's ``random``, NumPy's global RNG and torch (CPU and all CUDA
    devices), turns off cuDNN autotuning, and exports ``PYTHONHASHSEED``.

    Args:
        seed: integer seed applied to every generator.
    """
    # Only affects interpreters started after this point; the running
    # process's str-hash seed was fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.benchmark = False
    cudnn_backend.deterministic = True

In [None]:
def one_hot(dims, value, indx):
    """Build a zero tensor of shape ``dims`` with ``value`` written at ``indx``.

    Args:
        dims: shape passed to ``torch.zeros`` (an int gives a 1-D vector).
        value: scalar stored at position ``indx``.
        indx: index (negative indices count from the end).

    Returns:
        A new float32 tensor.
    """
    encoded = torch.zeros(dims)
    encoded[indx] = value
    return encoded

## Configurations

In [None]:
random_seed = 1234
set_seed(random_seed)
load_codes = True  # NOTE(review): only referenced by commented-out code; has no effect as written
algo = 'ortho'  # how the DSE checkpoint is interpreted: 'ortho' or 'linear'
# Project checkout root. Set the DISENTANGLED_ROOT environment variable to run
# on another machine instead of editing this hard-coded default.
root_directory = os.environ.get('DISENTANGLED_ROOT', '/home/adarsh/PycharmProjects/disentagled_latent_dirs')
result_path = os.path.join(root_directory, 'results/celeba_hq/latent_discovery_ours/qualitative_analysis')

## Model Selection

In [None]:
# Progressive-GAN generator; per the model name, the CelebA-HQ 1024px checkpoint.
generator = load_generator(
    None,
    gan_type='ProgGAN',
    model_name='pggan_celebahq1024',
)

In [None]:

# Root deformator: pretrained LatentDiscovery directions, driven by one-hot inputs.
root_deformator_path = os.path.join(root_directory, 'pretrained_models/deformators/LatentDiscovery/pggan_celebahq1024/deformator_0.pt')
directions = torch.load(root_deformator_path, map_location='cpu')
root_deformator = LatentDeformator(shift_dim=generator.dim_z,
                                   input_dim=512,  # dimension of one-hot encoded vector
                                   out_dim=generator.dim_z[0],
                                   type='projection',
                                   random_init=True).cuda()
root_deformator.load_state_dict(directions)

# DSE deformator: a matrix whose rows are used as latent directions.
deformator_path = os.path.join(root_directory, 'results/celeba_hq/latent_discovery_ours/models/22000_model.pkl')
# NOTE: torch.load unpickles arbitrary objects -- only load trusted checkpoints.
dse_checkpoint = torch.load(deformator_path)
if algo == 'ortho':
    # Re-orthonormalise via QR, then flip column signs so diag(R) >= 0
    # (sign(x + 0.5) maps a zero sign to +1), the usual QR sign convention.
    q, r = torch.qr(dse_checkpoint['deformator']['ortho_mat'])
    unflip = torch.diag(r).sign().add(0.5).sign()
    q *= unflip[..., None, :]
    dse_deformator = q.T
elif algo == 'linear':
    dse_deformator = dse_checkpoint['deformator'].T
else:
    raise ValueError(f"Unknown algo {algo!r}; expected 'ortho' or 'linear'")
# Tensor.cuda() returns a new tensor instead of moving in place, so reassign.
dse_deformator = dse_deformator.cuda()

# Latent codes to edit. Loading saved codes (`load_codes`) is currently
# disabled; the loader is kept below for reference.
# if load_codes:
#     codes = np.load(os.path.join(root_directory, 'pretrained_models/latent_codes/pggan_celebahq1024_latents.npy'))
#     codes = torch.from_numpy(codes).type(torch.FloatTensor).cuda()
#     codes = torch.load(os.path.join(root_directory, 'results/celeba_hq/closed_form_ours/quantitative_analysis/z_analysis.pkl'))
# else:
num_samples = 1000
codes = torch.randn(num_samples, 512, 1, 1).cuda()

In [None]:
def postprocess_images(images):
    """Turn a batch of generator outputs into a ``uint8`` NumPy array.

    Values are rescaled with ``(x + 1) * 255 / 2`` (so -1 -> 0 and 1 -> 255),
    rounded to nearest by adding 0.5 before truncation, and clipped to
    [0, 255]. Axis 1 is then moved to the end (NCHW -> NHWC, assuming a
    channel-first 4-D input).

    Args:
        images: ``torch.Tensor`` batch; it is detached and copied to the CPU.

    Returns:
        ``numpy.ndarray`` of dtype ``uint8``.
    """
    array = images.detach().cpu().numpy()
    rescaled = (array + 1) * 255 / 2
    pixels = np.clip(rescaled + 0.5, 0, 255).astype(np.uint8)
    return pixels.transpose(0, 2, 3, 1)


def save_images(codes, shifts_r, shifts_count, root_dir, dse_dir, generator, root_deformator, dse_deformator):
    """Display, for each latent code, a sweep along one root and one DSE direction.

    Despite its name, nothing is written to disk: one figure per code is drawn
    and displayed inline. Each figure is a grid with ``shifts_count`` images
    per row. The root-deformator images come first, followed by the DSE images,
    assuming the generator returns one image per input latent.

    Args:
        codes: iterable of latent codes. NOTE(review): each ``z`` is assumed to
            be shaped (512, 1, 1), i.e. one row of a (N, 512, 1, 1) batch.
        shifts_r: sweep half-width; the shifts are
            ``np.linspace(-shifts_r, shifts_r, shifts_count)``.
        shifts_count: number of shifts per direction (also the grid row length).
        root_dir: index of the one-hot input fed to ``root_deformator``.
        dse_dir: row of ``dse_deformator`` used as the DSE direction.
        generator: callable mapping a batch of latents to images.
        root_deformator: module mapping a 512-d one-hot vector to a latent shift.
        dse_deformator: 2-D tensor whose rows are latent directions.

    Returns:
        None.
    """
    shifts = np.linspace(-shifts_r, shifts_r, shifts_count)
    # Pure inference: without no_grad, root_deformator would record an autograd graph.
    with torch.no_grad():
        for idx, z in enumerate(codes):
            print('Figure : ' + str(idx))
            z_root, z_dse = _shifted_latents(z, shifts, root_dir, dse_dir, root_deformator, dse_deformator)
            # Move images to the CPU right away so GPU memory is released between batches.
            root_images = generator(z_root).cpu()
            dse_images = generator(z_dse).cpu()
            all_images = torch.cat((root_images, dse_images), dim=0)
            grid = torchvision.utils.make_grid(all_images.clamp(min=-1, max=1), nrow=shifts_count,
                                               scale_each=True, normalize=True)
            # Draw first, then display, then close: one fresh figure per code.
            fig, ax = plt.subplots(figsize=(30, 30))
            ax.imshow(grid.permute(1, 2, 0).numpy())
            ax.set_title(f'code {idx}: root dir {root_dir} (top) vs DSE dir {dse_dir} (bottom)')
            ax.axis('off')
            display.display(fig)
            plt.close(fig)


def _shifted_latents(z, shifts, root_dir, dse_dir, root_deformator, dse_deformator):
    """Return ``(root_batch, dse_batch)``: ``z`` shifted by every value in ``shifts``.

    ``root_batch`` adds ``root_deformator`` applied to a 512-d one-hot vector
    selecting ``root_dir``. ``dse_batch`` adds row ``dse_dir`` of
    ``dse_deformator`` scaled by the shift. Each list is stacked along a new
    leading axis, and the singleton axis 1 is then squeezed away.
    """
    # Follow z's device so a CPU-resident dse_deformator cannot cause a device mismatch.
    dse_direction = dse_deformator[dse_dir: dse_dir + 1].to(z.device)
    root_batch, dse_batch = [], []
    for shift in shifts:
        root_batch.append(z + root_deformator(one_hot(512, shift, root_dir).cuda()))
        dse_batch.append(z + (dse_direction * shift).view(1, 512, 1, 1))
    return torch.stack(root_batch).squeeze(dim=1), torch.stack(dse_batch).squeeze(dim=1)

    
# Slice of `codes` to visualise (start inclusive, end exclusive).
z_min_index = 40
z_max_index = 50
root_dir = 94    # one-hot index fed to the root deformator
dse_dir = 196    # row of dse_deformator used as the DSE direction
shift_r = 10     # shifts span [-shift_r, shift_r]
shift_count = 3  # number of evenly spaced shifts per direction
# save_images only displays figures and returns None, so its result is not bound.
save_images(codes[z_min_index:z_max_index], shift_r, shift_count, root_dir, dse_dir, generator, root_deformator, dse_deformator)

In [None]:
# Direction indices to export (columns of directions['linear.weight']).
idxes = [12, 196, 16, 137, 87, 119, 63, 118, 94, 167]
selected_dirs = torch.zeros(512, len(idxes))

In [None]:
# Gather the chosen columns of the root deformator's linear weight into one
# tensor of shape (weight rows, len(idxes)). The result is a tensor, not a
# list, so column assignment like selected_dirs[:, i] keeps working.
selected_dirs = torch.stack([directions['linear.weight'][:, idx] for idx in idxes], dim=1)

In [None]:
# Copy each chosen column of the root deformator's linear weight into selected_dirs.
for column, direction_index in enumerate(idxes):
    selected_dirs[:, column] = directions['linear.weight'][:, direction_index]

In [None]:
# Persist the chosen directions to the current working directory.
torch.save(obj=selected_dirs, f='selected_dirs.pt')

Selected direction indices: 12, 196, 16, 137, 87, 119, 63, 118, 94, 167