In [None]:
from torchvision.transforms import Resize
from utils import *
from models.latentdiscovery.utils import load_generator
from models.latentdiscovery.latent_deformator import LatentDeformator
from torchvision.utils import save_image
import numpy as np
import random
import torch.nn.functional as F
from torchvision.models import resnet18
import torch.nn as nn
import matplotlib.pylab as plt
from torchvision.utils import make_grid
from torch_tools.visualization import to_image
import torchvision
import cv2
from IPython import display
%matplotlib inline

In [None]:
def set_seed(seed):
    """Seed every RNG used here (``random``, NumPy, torch CPU + all CUDA devices) and force deterministic cuDNN.

    Args:
        seed: integer seed applied to all generators.

    NOTE: setting PYTHONHASHSEED at runtime does not change str hashing of the
    already-running interpreter; only processes started afterwards inherit it.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

In [None]:
def one_hot(dims, value, indx):
    """Build a zero tensor of size ``dims`` whose entry at ``indx`` is set to ``value``.

    Used to select a single deformator direction scaled by a shift magnitude.
    """
    encoded = torch.zeros(dims)
    encoded[indx] = value
    return encoded

## Configurations

In [None]:
# Reproducibility: uncomment to seed every RNG via set_seed (defined above).
# random_seed = 215
# set_seed(random_seed)

load_codes = True  # only referenced by the commented-out code-loading branch in the next cell
algo = 'ortho'     # DSE checkpoint format: 'ortho' (orthogonal matrix) or 'linear'

root_directory = '/home/adarsh/PycharmProjects/disentagled_latent_dirs'
result_path = os.path.join(root_directory,
                           'results/celeba_hq/latent_discovery_ours/qualitative_analysis')

## Model Selection

In [None]:
# Pretrained ProgGAN generator for the 'pggan_celebahq1024' checkpoint.
# NOTE(review): meaning of the leading positional None depends on load_generator — confirm.
generator = load_generator(None, gan_type='ProgGAN', model_name='pggan_celebahq1024')

In [None]:

# --- Baseline "LD" deformator: pretrained LatentDeformator restored from its state dict ---
deformator_path = os.path.join(root_directory, 'pretrained_models/deformators/LatentDiscovery/pggan_celebahq1024/deformator_0.pt')
directions = torch.load(deformator_path, map_location='cpu')
root_deformator = LatentDeformator(shift_dim=generator.dim_z,
                                   input_dim=512,  # dimension of one-hot encoded vector
                                   out_dim=generator.dim_z[0],
                                   type='projection',
                                   random_init=True).cuda()  # Module.cuda() moves parameters in place
# load_state_dict copies the CPU tensors into the (already CUDA) parameters.
root_deformator.load_state_dict(directions)

# --- "LD + SRE" (DSE) deformator: direction matrix from our training checkpoint ---
deformator_path = os.path.join(root_directory, 'results/celeba_hq/latent_discovery_ours/models/22000_model.pkl')
# map_location='cpu' makes loading independent of the device the checkpoint was saved from.
checkpoint = torch.load(deformator_path, map_location='cpu')
if algo == 'ortho':
    # QR-orthonormalise the learned matrix, then flip column signs so diag(R) is non-negative:
    # sign().add(0.5).sign() maps -1 -> -1 and both 0 and +1 -> +1.
    # NOTE: torch.qr is deprecated in newer PyTorch; torch.linalg.qr(A, mode='reduced') is the replacement.
    q, r = torch.qr(checkpoint['deformator']['ortho_mat'])
    unflip = torch.diag(r).sign().add(0.5).sign()
    q *= unflip[..., None, :]
    dse_deformator = q.T
elif algo == 'linear':
    dse_deformator = checkpoint['deformator'].T
else:
    # Fail loudly instead of leaving dse_deformator undefined or stale from an earlier run.
    raise ValueError(f"Unknown algo {algo!r}: expected 'ortho' or 'linear'")
# Tensor.cuda() returns a copy (unlike Module.cuda()), so the result must be reassigned.
dse_deformator = dse_deformator.cuda()


# if load_codes:
#     codes = np.load(os.path.join(root_dir, 'pretrained_models/latent_codes/pggan_celebahq1024_latents.npy'))
#     codes = torch.from_numpy(codes).type(torch.FloatTensor).cuda()
#     codes = torch.load(os.path.join(root_dir, 'results/celeba_hq/closed_form_ours/quantitative_analysis/z_analysis.pkl'))
# else:
num_samples = 1000
# Fresh standard-normal latent codes; not reproducible unless set_seed() is re-enabled above.
codes = torch.randn(num_samples, 512, 1, 1).cuda()

In [None]:
# def add_border(tensor):
#     border = 3
#     for ch in range(tensor.shape[0]):
#         color = 1.0 if ch == 0 else -1
#         tensor[ch, :border, :] = color
#         tensor[ch, -border:,] = color
#         tensor[ch, :, :border] = color
#         tensor[ch, :, -border:] = color
#     return tensor

# @torch.no_grad()
# def interpolate(G, z, shifts_r, shifts_count, dim, deformator=None, with_central_border=False):
#     shifted_images = []
#     for shift in np.linspace(-shifts_r,shifts_r,shifts_count):
# #         direction = deformator[dim: dim + 1]
# #         direction = direction.unsqueeze(2)
# #         direction = direction.unsqueeze(3)
#         direction = root_deformator(one_hot(512, shift, dim).cuda())
#         shifted_image = G(z +  direction* shift).cpu()[0]
#         if shift == 0.0 and with_central_border:
#             shifted_image = add_border(shifted_image)
#         shifted_images.append(shifted_image)
#     return shifted_images

# @torch.no_grad()
# def make_interpolation_chart(G, deformator=None, z=None,
#                              shifts_r=10.0, shifts_count=5,
#                              dims=None, dims_count=10, texts=None, **kwargs):


#     original_img = G(z).cpu()
#     imgs = []
#     if dims is None:
#         dims = range(dims_count)
#     for i in dims:
#         imgs.append(interpolate(G, z, shifts_r, shifts_count, i, deformator))

#     rows_count = len(imgs) + 1
#     fig, axs = plt.subplots(rows_count, **kwargs)

#     axs[0].axis('off')
#     axs[0].imshow(to_image(original_img, True))

#     if texts is None:
#         texts = dims
#     for ax, shifts_imgs, text in zip(axs[1:], imgs, texts):
#         ax.axis('off')
#         plt.subplots_adjust(left=0.5)
#         ax.imshow(to_image(make_grid(shifts_imgs, nrow=(2 * shifts_count + 1), padding=1), True))
#         ax.text(-20, 21, str(text), fontsize=10)


#     return fig


# @torch.no_grad()
# def inspect_all_directions(G, deformator, out_dir, zs=None, num_z=3, shifts_r=8.0):
#     os.makedirs(out_dir, exist_ok=True)

#     step = 5
#     max_dim = G.dim_shift[0]
#     zs = zs if zs is not None else make_noise(num_z, G.dim_z).cuda()
#     shifts_count = zs.shape[0]

#     for start in range(10, max_dim - 1, step):
#         imgs = []
#         dims = range(start, min(start + step, max_dim))
#         for z in zs:
#             z = z.unsqueeze(0)
#             fig = make_interpolation_chart(
#                 G, deformator=deformator, z=z,
#                 shifts_count=5, dims=dims, shifts_r=shifts_r,
#                 dpi=250, figsize=(int(shifts_count * 4.0), int(0.5 * step) + 2))
#             fig.canvas.draw()
#             plt.close(fig)
#             img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
#             img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,))

#             # crop borders
#             nonzero_columns = np.count_nonzero(img != 255, axis=0)[:, 0] > 0
#             img = img.transpose(1, 0, 2)[nonzero_columns].transpose(1, 0, 2)
#             imgs.append(img)

#         out_file = os.path.join(out_dir, '{}_{}.jpg'.format(dims[0], dims[-1]))
#         print('saving chart to {}'.format(out_file))
#         Image.fromarray(np.hstack(imgs)).save(out_file)
        
# # z = torch.load('codes.pkl').cuda()
# out_dir = '/home/adarsh/PycharmProjects/disentagled_latent_dirs/results/celeba_hq/latent_discovery_ours/inspect_all_dirs'
# inspect_all_directions(generator, root_deformator,out_dir,zs=codes, shifts_r=6)

In [None]:
def postprocess_images(images):
    """Convert a batch of generator outputs to uint8 images.

    Args:
        images: 4-D torch tensor, assumed (N, C, H, W) with values nominally in [-1, 1].

    Returns:
        numpy uint8 array with the channel axis moved last, i.e. (N, H, W, C), values in [0, 255].
    """
    arr = images.detach().cpu().numpy()
    scaled = (arr + 1) * 255 / 2
    # Adding 0.5 before the truncating uint8 cast rounds to the nearest integer.
    quantized = np.clip(scaled + 0.5, 0, 255).astype(np.uint8)
    return np.transpose(quantized, (0, 2, 3, 1))


def save_images(codes, shifts_r, shifts_count, root_dir, dse_dir, generator, root_deformator, dse_deformator):
    """Write one comparison grid per latent code to ``<result_path>/temp/<idx>.png``.

    For every code ``z`` two traversals are generated over the shifts
    ``np.linspace(-shifts_r, shifts_r, shifts_count)``: one along ``root_deformator``
    direction ``root_dir`` and one along row ``dse_dir`` of ``dse_deformator``.
    The grid has ``shifts_count`` columns: the root traversal forms the first row and
    the DSE traversal the second.

    Args:
        codes: iterable of latent codes; each is added to shifts of shape (1, 512, 1, 1).
        shifts_r: maximum absolute shift magnitude.
        shifts_count: number of shifts per traversal (grid columns).
        root_dir: one-hot index (a direction id, not a path) fed to ``root_deformator``.
        dse_dir: row of ``dse_deformator`` used as the direction.
        generator: callable mapping a batch of latents to images (clamped to [-1, 1] here).
        root_deformator: module mapping a 512-d one-hot vector to a latent shift.
        dse_deformator: 2-D tensor whose rows are latent directions.

    Returns:
        None. Results are only written to disk.
    """
    temp_path = os.path.join(result_path, 'temp')
    os.makedirs(temp_path, exist_ok=True)
    shifts = np.linspace(-shifts_r, shifts_r, shifts_count)
    for idx, z in enumerate(codes):
        print('Figure : ' + str(idx))
        with torch.no_grad():
            z_shift_root = torch.stack(
                [z + root_deformator(one_hot(512, shift, root_dir).cuda()) for shift in shifts]
            ).squeeze(dim=1)
            z_shift_dse = torch.stack(
                [z + (dse_deformator[dse_dir: dse_dir + 1] * shift).view(1, 512, 1, 1) for shift in shifts]
            ).squeeze(dim=1)
            # Move each batch off the GPU right away so only one image batch is resident at a time.
            root_images = generator(z_shift_root).cpu()
            dse_images = generator(z_shift_dse).cpu()
        all_images = torch.cat((root_images, dse_images), dim=0)
        # nrow must equal shifts_count so that each traversal occupies exactly one row.
        grid = torchvision.utils.make_grid(all_images.clamp(min=-1, max=1), nrow=shifts_count,
                                           scale_each=True, normalize=True)
        plt.imsave(os.path.join(temp_path, str(idx) + '.png'), grid.permute(1, 2, 0).numpy())

    
# Dump comparison grids for a slice of the random codes (presumably to hand-pick
# good examples per attribute; cf. the commented-out selection/saving lines below).
z_min_index = 0
z_max_index = 1000
# codes = torch.load(os.path.join(result_path, 'temp/codes_196_76_eyeglasses.pkl'))
# indices = [39, 42, 55, 85, 86]
# codes = codes[indices]
root_dir = 196  # root_deformator direction index (one-hot position), not a path
dse_dir = 76    # row of dse_deformator
shift_r = 10
shift_count = 3
# result_path = os.path.join(result_path, 'Eyeglasses')
# torch.save(codes, os.path.join(result_path, 'attr_codes.pkl'))
# save_images only writes PNGs and returns None, so its result is not bound to a name.
save_images(codes[z_min_index:z_max_index], shift_r, shift_count, root_dir, dse_dir,
            generator, root_deformator, dse_deformator)

# Plot results: LD vs. LD + SRE traversals for Gender, Smiling and Eyeglasses

In [None]:
def get_manipulated_images(z, shift_r, shift_count, root_dir, dse_dir, generator, root_deformator, dse_deformator):
    """Generate traversals of one latent code along a root and a DSE direction.

    Args:
        z: a single latent code; it is added to shifts of shape (1, 512, 1, 1).
        shift_r: maximum absolute shift; shifts are ``np.linspace(-shift_r, shift_r, shift_count)``.
        shift_count: number of shifted latents per traversal.
        root_dir: one-hot index (a direction id, not a path) fed to ``root_deformator``.
        dse_dir: row of ``dse_deformator`` used as the direction.
        generator: callable mapping a batch of latents to images.
        root_deformator: module mapping a 512-d one-hot vector to a latent shift.
        dse_deformator: 2-D tensor whose rows are latent directions.

    Returns:
        (root_images, dse_images): the generator outputs for the two traversals, ordered
        from the most negative to the most positive shift, on the generator's output device.
    """
    # Use the shift_r / shift_count arguments; do not read the similarly named globals.
    shifts = np.linspace(-shift_r, shift_r, shift_count)
    with torch.no_grad():
        z_shift_root = torch.stack(
            [z + root_deformator(one_hot(512, shift, root_dir).cuda()) for shift in shifts]
        ).squeeze(dim=1)
        z_shift_dse = torch.stack(
            [z + (dse_deformator[dse_dir: dse_dir + 1] * shift).view(1, 512, 1, 1) for shift in shifts]
        ).squeeze(dim=1)
        root_images = generator(z_shift_root)
        dse_images = generator(z_shift_dse)
    return root_images, dse_images

In [None]:
# Load the latent codes pre-selected for each attribute (<result_path>/<attr>/attr_codes.pkl).
# NOTE: `root_dir` is rebound to a path here; later cells rebind it to a direction index.
root_dir = '/home/adarsh/PycharmProjects/disentagled_latent_dirs'
result_path = os.path.join(root_dir, 'results/celeba_hq/latent_discovery_ours/qualitative_analysis')
attr_list = ['Gender', 'Smiling', 'Eyeglasses']
z = [torch.load(os.path.join(result_path, each_attr + '/attr_codes.pkl')) for each_attr in attr_list]

In [None]:

# Traversal settings shared by all three attributes.
shifts_r = 10
shifts_count = 3

# attribute -> (index into z, root_deformator direction, dse_deformator row, index of the code to use)
traversal_cfg = {
    'gender':  (0, 16, 108, 3),
    'smiling': (1, 189, 10, 0),
    'glass':   (2, 196, 76, 2),
}
traversals = {}
for attr_name, (attr_idx, root_dir, dse_dir, desired_idx) in traversal_cfg.items():
    traversals[attr_name] = get_manipulated_images(z[attr_idx][desired_idx], shifts_r, shifts_count,
                                                   root_dir, dse_dir, generator, root_deformator, dse_deformator)

root_gender, dse_gender = traversals['gender']
root_smiling, dse_smiling = traversals['smiling']
root_glass, dse_glass = traversals['glass']


In [None]:
# One tensor per method, indexed [attribute, shift]; all_images[i] is plotted in figure row i.
root = torch.stack([root_gender, root_smiling, root_glass])
dse = torch.stack([dse_gender, dse_smiling, dse_glass])
all_images = [root, dse]

In [None]:
# Row labels for the figures: index 0 -> root_deformator ("LD"), index 1 -> dse_deformator ("LD + SRE").
# NOTE: this rebinds `algo`, which the configuration cell uses as the checkpoint format
# ('ortho'/'linear'); re-run the configuration cell before reloading the deformators.
algo = ["LD", "LD + SRE"]

In [None]:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

attr_list = ['Gender', 'Smiling', 'Glasses']  # column titles
SMALL_SIZE = 8
plt.rc('axes', titlesize=22, labelsize=20)
plt.rcParams['pdf.fonttype'] = 42  # TrueType (Type 42) fonts in PDF/PS output
plt.rcParams['ps.fonttype'] = 42
plt.rcParams["figure.facecolor"] = 'w'
plt.rcParams["font.family"] = "Times New Roman"

# One row per method (all_images[i]) and one column per attribute (all_images[i][j]).
n_rows, n_cols = len(all_images), len(attr_list)
fig = plt.figure(figsize=(20, 5))
gs = gridspec.GridSpec(n_rows, n_cols, wspace=0.1, hspace=0.01)
ax = np.zeros(n_rows * n_cols, dtype=object)
for i in range(n_rows):
    for j in range(n_cols):
        count = i * n_cols + j
        ax[count] = fig.add_subplot(gs[i, j])
        traversal = all_images[i][j]
        # One image per shift; nrow = number of shifts keeps the traversal on a single line.
        grid = torchvision.utils.make_grid(traversal.clamp(min=-1, max=1), nrow=traversal.shape[0],
                                           scale_each=True, normalize=True)
        ax[count].imshow(grid.permute(1, 2, 0).cpu().numpy())
        ax[count].grid(False)
        ax[count].set_xticks([])
        ax[count].set_yticks([])
        for spine in ax[count].spines.values():
            spine.set_visible(False)
        if i == 0:
            ax[count].title.set_text(attr_list[j])

# Method name as the y-label of the first panel in each row.
for i, label in enumerate(algo[:n_rows]):
    ax[i * n_cols].set_ylabel(label, rotation=90)

gs.tight_layout(fig)
fig.savefig(os.path.join(result_path, 'latent_traversal.pdf'), bbox_inches='tight')

# Appendix plot: traversals of several codes for one attribute along a single direction pair

In [None]:
# Load the pre-selected latent codes for the appendix attribute(s).
root_dir = '/home/adarsh/PycharmProjects/disentagled_latent_dirs'
result_path = os.path.join(root_dir, 'results/celeba_hq/latent_discovery_ours/qualitative_analysis')
# attr_list = ['Gender', 'Smiling', 'Glasses', 'Pose', 'Young']
attr_list = ['Hair']
z = [torch.load(os.path.join(result_path, each_attr, 'attr_codes.pkl')) for each_attr in attr_list]

In [None]:
# Method labels: index 0 -> root_deformator ("LD"), index 1 -> dse_deformator ("LD + SRE").
# NOTE: shadows the configuration-cell `algo` ('ortho'/'linear').
algo = ["LD", "LD + SRE"]

In [None]:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

SMALL_SIZE = 8
plt.rc('axes', titlesize=22, labelsize=20)
plt.rcParams['pdf.fonttype'] = 42  # TrueType (Type 42) fonts in PDF/PS output
plt.rcParams['ps.fonttype'] = 42
plt.rcParams["figure.facecolor"] = 'w'
plt.rcParams["font.family"] = "Times New Roman"


shifts_r = 10
shifts_count = 5
root_dir = 100   # root_deformator direction index
dse_dir = 27     # dse_deformator row
desired_idx = [0, 1, 2]
attr = 0         # index into attr_list / z

desired_z = z[attr][desired_idx]


fig = plt.figure(figsize=(50, 50))
gs = gridspec.GridSpec(len(desired_z), 1, wspace=0.1, hspace=0.01)
ax = np.zeros(len(desired_z), dtype=object)  # one panel per plotted code
for count, code in enumerate(desired_z):
    ax[count] = fig.add_subplot(gs[count, 0])
    root_images, dse_images = get_manipulated_images(code, shifts_r, shifts_count, root_dir, dse_dir,
                                                     generator, root_deformator, dse_deformator)
    all_images = torch.cat((root_images, dse_images), dim=0)
    # Assuming each traversal holds shifts_count images, nrow=shifts_count puts the
    # root traversal on the first row of the panel and the DSE traversal on the second.
    grid = torchvision.utils.make_grid(all_images.clamp(min=-1, max=1), nrow=shifts_count,
                                       scale_each=True, normalize=True)
    ax[count].imshow(grid.permute(1, 2, 0).cpu().numpy())
    ax[count].grid(False)
    ax[count].set_xticks([])
    ax[count].set_yticks([])
    del all_images, root_images, dse_images, grid

appendix_dir = os.path.join(result_path, 'appendix_images')
os.makedirs(appendix_dir, exist_ok=True)
# Name the file after the attribute actually plotted (attr), not always the first one.
plt.savefig(os.path.join(appendix_dir, attr_list[attr] + '.pdf'), bbox_inches='tight')
#         ax[j].title.set_text(attr_list[j])
# ax[0].set_ylabel(algo[0], rotation=90)
# ax[3].set_ylabel(algo[1], rotation=90)

