# Load a trained image VAE-GAN and traverse its latent dimensions

In [1]:
# Change into the project directory. The recorded output (C:\Users\zhang\idgan-master) suggests
# "beta_vae_idgan" resolves as an IPython %bookmark rather than a sub-directory — confirm.
# All relative paths used below (config YAML, 'outputs/...', 'zh_graycube/...') depend on this.
%cd "beta_vae_idgan"

C:\Users\zhang\idgan-master


## Utility functions

In [2]:
import os    
os.environ['KMP_DUPLICATE_LIB_OK']='True'

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
import seaborn as sns

%matplotlib inline

import cv2
import numpy as np

import time
from PIL import Image
import os.path
import glob
import pandas as pd
import random
from math import sqrt

import pandas as pd
from vrProjector_master import vrProjector
from moviepy.editor import ImageSequenceClip

from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms
from torchvision.utils import make_grid, save_image
from torch.utils.tensorboard import SummaryWriter

In [3]:
import argparse
import os
from os import path
import copy
from tqdm import tqdm
import torch
from torch import nn

from gan_training import utils
from gan_training.inputs import get_dataset
from gan_training.distributions import get_ydist, get_zdist
from gan_training.eval import Evaluator
from gan_training.config import (
    load_config, build_models
)

In [4]:
def get_one_hot(label, N):
    """One-hot encode an integer label map.

    Args:
        label: integer (index-type) tensor of shape (B, 1, H, W) holding class ids in [0, N).
        N: number of classes.

    Returns:
        float32 CPU tensor of shape (B, N, H, W) (a permuted, non-contiguous view).
    """
    flat_labels = label.view(-1).cpu()
    # Row k of the identity matrix is the one-hot vector for class k.
    encoded = torch.eye(N).index_select(0, flat_labels)
    encoded = encoded.view(*label.size(), N)      # (B, 1, H, W, N)
    return encoded.squeeze(1).permute(0, 3, 1, 2)  # (B, N, H, W)

def discretize_to_order_labels(t):
    """Snap values in [0, 1] to ordinal labels 0..4 (nearest multiple of 0.25).

    Bucket edges sit half-way between the label levels 0, 0.25, ..., 1. A value exactly on an
    edge goes to the lower label. Values below -0.125 give -1 and values above 1.125 give 5.
    The result is an int64 tensor on CPU.
    """
    edges = torch.tensor([-0.125, 0.125, 0.375, 0.625, 0.875, 1.125])
    return torch.bucketize(t.cpu(), edges) - 1

def generator_postprocess(x):
    """Turn raw generator output into a 1-channel label map scaled to [0, 1] (label / 4).

    The output is assumed to be in [-1, 1] (it is shifted with (x + 1) / 2) — TODO confirm.

    - 5 channels: after shifting and clamping to [0, 1], each channel is rounded. The label is
      the first channel holding the maximum. Because of the rounding, a pixel where no channel
      reaches 0.5 falls back to label 0.
    - any other channel count: the values are discretized with discretize_to_order_labels.

    Returns a CPU float tensor of shape (B, 1, H, W) for the 5-channel case.
    """
    n_channels = x.size(1)

    if n_channels != 5:
        ordinal = discretize_to_order_labels(x.add(1).div(2).cpu())
        return ordinal.div(4)

    scaled = x.add(1).div(2).clamp(0, 1)
    winners = scaled.round().argmax(dim=1).unsqueeze(1).cpu()
    return winners / (n_channels - 1)

def decoder_postprocess(x):
    """Convert dVAE decoder logits into a displayable image.

    - 5 channels: returns the per-pixel argmax of the channel softmax as a CPU float
      label map of shape (B, 1, H, W), scaled to [0, 1] (label / 4).
    - any other channel count: returns the elementwise sigmoid. The result is not
      discretized and stays on the input's device.
    """
    n_channels = x.size(1)

    if n_channels != 5:
        return torch.sigmoid(x)

    labels = x.softmax(dim=1).argmax(dim=1).unsqueeze(1).cpu()
    return labels / (n_channels - 1)

def show_im(img):
    """Show a (C, H, W) CPU tensor on the current matplotlib axes with the axes hidden."""
    channels_last = img.numpy().transpose(1, 2, 0)  # (C, H, W) -> (H, W, C) for imshow
    plt.imshow(channels_last, interpolation='bilinear')
    plt.axis('off')
    
def gray_label_to_torch_tensor(path, device=None):
    """Load a gray-level label image as a one-hot tensor of shape (5, H, W).

    The colour channels are averaged, divided by 64 and rounded. Gray levels near
    0, 64, 128, 192 and 255 therefore become labels 0..4.

    Args:
        path: image file path readable by cv2.imread.
        device: device for the returned tensor. The default None means CUDA when it is
            available (the original behaviour was to always call .cuda()), otherwise CPU.
            This lets the loader work on the CPU-only setup used in this notebook.

    Returns:
        float32 tensor of shape (5, H, W) with one channel per label.

    Raises:
        FileNotFoundError: if cv2.imread cannot read `path` (it returns None instead of raising).
    """
    example = cv2.imread(path)
    if example is None:
        raise FileNotFoundError(f"cv2.imread could not read image: {path}")
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Mean over colour channels; 255 / 64 ~= 3.98 still rounds to label 4.
    gt = torch.from_numpy(np.round(np.mean(example, axis=2) / 64)).long()

    # Row k of the 5x5 identity is the one-hot vector for label k.
    gt_one_hot = torch.eye(5).index_select(0, gt.view(-1)).view(*gt.shape, 5)  # (H, W, 5)

    return gt_one_hot.permute(2, 0, 1).to(device)  # (5, H, W)

def load_simple_im(path):
    """Load an image file as a single-channel float tensor scaled to [-1, 1).

    The colour channels are averaged, divided by 256 and mapped through (v - 0.5) * 2.
    Gray level g therefore becomes g / 128 - 1 (0 -> -1.0, 255 -> ~0.992).

    Args:
        path: image file path readable by cv2.imread.

    Returns:
        float32 tensor of shape (1, 1, H, W).

    Raises:
        FileNotFoundError: if cv2.imread cannot read `path` (it returns None instead of raising).
    """
    example = cv2.imread(path)
    if example is None:
        raise FileNotFoundError(f"cv2.imread could not read image: {path}")

    gt = np.mean(example, axis=2) / 256

    # Map [0, 1) to [-1, 1); callers shift back with .add(1).div(2) before discretizing.
    gt = (gt - 0.5) * 2

    return torch.from_numpy(gt).unsqueeze(0).unsqueeze(0).float()

class solargan_im_trainset(Dataset):
    """Map-style dataset that applies `loader` lazily to each entry of `images`."""

    def __init__(self, images, loader):
        self.images = images  # sequence of image paths
        self.loader = loader  # callable: path -> sample (e.g. a tensor)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        return self.loader(self.images[index])
    
class CheckpointIO(object):
    """Save and restore the state_dicts of a named set of modules in a single file.

    Modules are registered as keyword arguments (name -> object with state_dict /
    load_state_dict). Each checkpoint is a dict holding {'it': iteration, <name>: state_dict, ...}.
    """

    def __init__(self, checkpoint_dir='./chkpts', **kwargs):
        self.module_dict = kwargs
        self.checkpoint_dir = checkpoint_dir

        # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
        os.makedirs(checkpoint_dir, exist_ok=True)

    def register_modules(self, **kwargs):
        """Add (or replace) modules to be saved and loaded."""
        self.module_dict.update(kwargs)

    def save(self, it, filename):
        """Write the iteration counter and every registered module's state_dict."""
        filename = os.path.join(self.checkpoint_dir, filename)

        outdict = {'it': it}
        for k, v in self.module_dict.items():
            outdict[k] = v.state_dict()
        torch.save(outdict, filename)

    def load(self, filename, map_location=torch.device('cpu')):
        """Restore registered modules from `filename` inside the checkpoint directory.

        Args:
            filename: checkpoint file name, relative to checkpoint_dir.
            map_location: forwarded to torch.load. The default keeps the original
                CPU-only loading.

        Returns:
            The stored iteration, or -1 if the file does not exist. A missing file does not
            raise, so callers must check the return value.
        """
        filename = os.path.join(self.checkpoint_dir, filename)

        if not os.path.exists(filename):
            return -1

        tqdm.write('=> Loading checkpoint...')
        out_dict = torch.load(filename, map_location=map_location)
        it = out_dict['it']
        for k, v in self.module_dict.items():
            if k in out_dict:
                v.load_state_dict(out_dict[k])
            else:
                tqdm.write('Warning: Could not find %s in checkpoint!' % k)

        return it

In [5]:
def make_imsample_plot(tensor, title, col_space=50, nrow=6, padding=2, pad_value=0):
    """Render a batch of label images as one Bokeh image with a categorical 5-class palette.

    The batch is tiled with torchvision's make_grid. Channel 1 of the grid is multiplied by 4
    and rounded, giving labels 0..4, so the input is assumed to be label / 4 maps (as produced
    by generator_postprocess / decoder_postprocess) — TODO confirm for other inputs.

    NOTE(review): `figure` and `CategoricalColorMapper` are Bokeh names that no visible cell
    imports. Add `from bokeh.plotting import figure` and
    `from bokeh.models import CategoricalColorMapper` before calling this.

    Args:
        tensor: (B, C, H, W) CPU tensor of label / 4 maps.
        title: plot title.
        col_space: extra x-range (data units) to the right of the grid.
        nrow, padding, pad_value: forwarded to make_grid (nrow = images per grid row).

    Returns:
        The Bokeh figure.
    """
    cmap = {
        '0': "#FFFFFF",  # edge: white
        '1': "#DAA520",  # ground: Goldenrod
        '2': "#CD5C5C",  # opaque surfaces: IndianRed
        '3': "#87CEFA",  # glazing: LightSkyBlue
        '4': "#F0FFFF",  # sky: Azure
    }

    np_img = make_grid(tensor, nrow=nrow, padding=padding, pad_value=pad_value).numpy()
    grid_h, grid_w = np_img.shape[1], np_img.shape[2]  # np_img is (C, H, W)

    p = figure(title=title, tooltips=[("x", "$x"), ("y", "$y"), ("value", "@image")],
               x_range=[0, grid_w + col_space])

    mapper = CategoricalColorMapper(palette=list(cmap.values()), factors=list(cmap.keys()))

    # Round rather than truncate so float noise (e.g. 2.9999) still maps to the right label;
    # flip vertically because Bokeh's image origin is bottom-left.
    labels = np.flip(np.rint(np_img[1, :, :] * 4).astype(int).astype(str), axis=0)
    # dw is the horizontal extent (width), dh the vertical one (height).
    p.image(image=[labels], x=0, y=0, dw=grid_w, dh=grid_h, color_mapper=mapper)

    p.axis.visible = False
    p.grid.visible = False
    p.outline_line_alpha = 0
    p.outline_line_width = 3

    return p


In [6]:
def colorize_one_image(gray_image, n):
    """Paint a label map with the fixed 5-class facade palette.

    Args:
        gray_image: tensor whose first index selects a 2-D label map (gray_image[0]), e.g.
            shape (1, H, W). Pixels exactly equal to integer label k get palette colour k.
        n: number of palette rows to allocate. Labels 0..4 have colours, so n < 5 raises
            IndexError and rows beyond 4 stay black.

    Returns:
        uint8 tensor of shape (3, H, W). Pixels that match no label stay (0, 0, 0).
    """
    label_colors = [
        (255, 255, 255),  # 0 edge: white
        (218, 165, 32),   # 1 ground: Goldenrod
        (205, 92, 92),    # 2 opaque surfaces: IndianRed
        (135, 206, 250),  # 3 glazing: LightSkyBlue
        (240, 255, 255),  # 4 sky: Azure
    ]
    cmap = np.zeros([n, 3]).astype(np.uint8)
    for row, rgb in enumerate(label_colors):
        cmap[row] = np.array(rgb)

    dims = gray_image.size()
    color_image = torch.zeros(3, dims[1], dims[2], dtype=torch.uint8)

    label_map = gray_image[0]  # 2-D (H, W) map of label values
    for label, rgb in enumerate(cmap):
        mask = label_map == label  # boolean (H, W) mask of pixels carrying this label
        for channel in range(3):
            color_image[channel][mask] = rgb[channel]

    return color_image

def gen_pure_white(w, h):
    """Return an all-white uint8 image of shape (w, h, 3).

    Despite the names, `w` is the number of rows (height) and `h` the number of columns
    (width). The cube-face stacking in cubemap_back_to_fisheye relies on this order.
    """
    return np.full((w, h, 3), 255, dtype=np.uint8)

def cubemap_back_to_fisheye(single_im_tensor, n_channels=5, outsize=256, padding=4):
    """Reproject an unfolded half-cubemap label image to a padded fisheye RGB image.

    The input is colourised with colorize_one_image, so it must hold integer labels 0..n-1.
    Its square layout (side S, q = S / 4) is read as:
      front  = [q:3q, q:3q]
      top    = [0:q, q:3q]   (half face, padded white and rotated)
      bottom = [3q:4q, q:3q] (half face, padded white and rotated)
      left   = [q:3q, 0:q], right = [q:3q, 3q:4q] (half faces, padded white)
    The back face is all white. For the original 128x128 input q = 32, so the slices are
    unchanged.

    Args:
        single_im_tensor: (1, S, S) label tensor.
        n_channels: number of label classes passed to colorize_one_image. It used to be
            ignored in favour of a hard-coded 5.
        outsize: side length of the returned image.
        padding: white border width on each side.

    Returns:
        (outsize, outsize, C) array from vrProjector with pure-black pixels replaced by white.
    """
    npimg = np.transpose(colorize_one_image(single_im_tensor, n_channels).numpy(), (1, 2, 0))

    q = npimg.shape[0] // 4  # half-face width
    face = 2 * q             # full cube-face side

    npimg_front = npimg[q:3 * q, q:3 * q, :]
    npimg_top = np.rot90(np.vstack([gen_pure_white(q, face), npimg[0:q, q:3 * q, :]]), 1)
    npimg_bottom = np.rot90(np.vstack([npimg[3 * q:4 * q, q:3 * q, :], gen_pure_white(q, face)]), -1)
    npimg_right = np.hstack([npimg[q:3 * q, 3 * q:4 * q, :], gen_pure_white(face, q)])
    npimg_left = np.hstack([gen_pure_white(face, q), npimg[q:3 * q, 0:q, :]])
    npimg_back = gen_pure_white(face, face)

    source = vrProjector.CubemapProjection()
    source.readImageArrays(npimg_left, npimg_front, npimg_right, npimg_back, npimg_top, npimg_bottom)

    inner = outsize - padding * 2  # fisheye diameter before padding
    out = vrProjector.SideBySideFisheyeProjection()
    out.initImage(inner * 2, inner)
    out.reprojectToThis(source)
    x_fish_dual = out.outputImage("fisheye")

    # Keep the second (right-hand) fisheye of the side-by-side pair.
    x_fish = x_fish_dual[:, inner:inner * 2, :]
    # Any zero channel value becomes white; the palette has no zero channels, so this only
    # whitens the black area outside the fisheye circle.
    x_fish[x_fish == 0] = 255
    x_fish = np.pad(x_fish, ((padding, padding), (padding, padding), (0, 0)), 'constant', constant_values=255)

    return x_fish



In [7]:
def make_fisheye_grid(tensor, mode, batch_size=36, n_rows=3, im_size=256, padding=4):
    """Convert a batch of cubemap label maps to fisheyes and tile them into one image.

    Args:
        tensor: batch of (1, S, S) label maps. With mode 'gt' they already hold labels 0..4.
            Any other mode expects label / 4 maps (generator/decoder output) and multiplies by 4.
        mode: 'gt' or anything else (see above).
        batch_size: maximum number of images to use. It is clipped to len(tensor) instead of
            raising IndexError.
        n_rows: forwarded to make_grid's `nrow`, i.e. images per grid ROW (despite the name).
        im_size, padding: forwarded to cubemap_back_to_fisheye.

    Returns:
        (H, W, C) float64 array of the tiled fisheyes in the 0..255 range.
    """
    factor = 1 if mode == 'gt' else 4

    n_images = min(batch_size, len(tensor))
    if n_images == 0:
        raise ValueError("make_fisheye_grid needs at least one image")

    # Build a list and stack once instead of re-concatenating inside the loop (quadratic),
    # and let the channel count come from the fisheye images instead of a hard-coded 4.
    fisheyes = [
        cubemap_back_to_fisheye(tensor[ix] * factor, n_channels=5, outsize=im_size, padding=padding) / 255
        for ix in range(n_images)
    ]
    np_allim = np.stack(fisheyes, axis=0).astype(np.float64, copy=False)  # (N, H, W, C)

    tensor_allim = torch.from_numpy(np.transpose(np_allim, (0, 3, 1, 2)))
    tensor_allim_grid = make_grid(tensor_allim, nrow=n_rows, padding=0, pad_value=0)
    np_allim_grid = np.transpose(tensor_allim_grid.numpy(), (1, 2, 0))

    return np_allim_grid * 255

def make_grid_fisheye_with_plt(file_name, tensor1, tensor2, tensor3, nrows1=3, nrows2=3, nrows3=3, batch_size1=21, batch_size2=21, batch_size3=21):
    """Save and show three fisheye grids side by side: ground truth, IDGAN samples, dVAE samples.

    tensor1 is rendered in 'gt' mode (labels 0..4). tensor2 and tensor3 are rendered as
    label / 4 maps. nrowsK / batch_sizeK are forwarded to make_fisheye_grid. The figure is
    written to `file_name` as SVG and then shown.
    """
    panels = [
        make_fisheye_grid(tensor1, 'gt', batch_size=batch_size1, n_rows=nrows1, im_size=256, padding=4) / 255,
        make_fisheye_grid(tensor2, 'res', batch_size=batch_size2, n_rows=nrows2, im_size=256, padding=4) / 255,
        make_fisheye_grid(tensor3, 'vae', batch_size=batch_size3, n_rows=nrows3, im_size=256, padding=4) / 255,
    ]

    fig = plt.figure(figsize=(10, 10), dpi=1200)
    # One row of three image panels, 0.4 inch apart.
    grid = ImageGrid(fig, 111, nrows_ncols=(1, 3), axes_pad=0.4)

    for ax, panel in zip(grid, panels):
        ax.imshow(panel)
        ax.axis('off')
        ax.set_xticks([])
        ax.set_yticks([])

    plt.axis("off")
    plt.savefig(file_name, format="svg")
    plt.show()

    

## Configuration

In [8]:
# Alternative config kept for reference:
#configres_path = 'graycube_imres_test.yaml'
configres_path = 'sbe_imtest_single.yaml'

#configs
configres = load_config(configres_path)

c_dim = configres['dvae']['c_dim']  # dVAE latent size (not used by the visible cells)
out_res_name = configres['test']['out_name']
batch_size = configres['test']['batch_size']
checkpoint_res_dir = path.join(out_res_name, 'chkpts')  # where the IDGAN checkpoint is loaded from


# Older traversal settings, superseded by the arguments passed in the WWR cell:
#c_dim_wwr = 30
#lim = 3
#lim_d = -3
#ncol = 7

# NOTE(review): this lists dim 29 while the WWR cell traverses dim 30 — confirm which index is intended.
c_meaningful = [12,16,29] ##12:z-direction 16: y-direction, 29:orbit_clockwise

## Load trained models

In [9]:
# Test-set loader. In the visible cells it is only referenced by a commented-out
# utils.get_nsamples call; the WWR section loads a single image instead.
test_dataset = get_dataset(
    name=configres['data']['type'],
    data_dir=configres['data']['test_dir'],
    size=configres['data']['img_size'],
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    num_workers=0,
    shuffle=False,
    # Pinned memory only speeds up host->GPU copies; skip it when no GPU is present.
    pin_memory=torch.cuda.is_available(),
    sampler=None,
    drop_last=True,
)

# Number of test images to draw (the previous `ntest = batch_size` was immediately overwritten).
ntest = 10

In [10]:
# Checkpoint reader for the residual IDGAN generator/discriminator.
checkpoint_res_io = CheckpointIO(
    checkpoint_dir=checkpoint_res_dir
)

# Create models; the dVAE weights come from a separate run under outputs/<runname>/chkpts.
dvae, generator_res, discriminator_res = build_models(configres)
dvae_ckpt_path = os.path.join('outputs', configres['dvae']['runname'], 'chkpts', configres['dvae']['ckptname'])
dvae_ckpt = torch.load(dvae_ckpt_path, map_location=torch.device('cpu'))['model_states']['net']
dvae.load_state_dict(dvae_ckpt)

# Inference is deliberately pinned to CPU even when CUDA is available.
is_cuda = torch.cuda.is_available()
#device = torch.device("cuda:0" if is_cuda else "cpu")
device = torch.device("cpu")

dvae = dvae.to(device)
generator_res = generator_res.to(device)
discriminator_res = discriminator_res.to(device)

# Wrap in DataParallel even on CPU. Presumably the IDGAN checkpoint was saved from
# DataParallel-wrapped modules (state_dict keys prefixed with 'module.') — confirm before removing.
generator_res = nn.DataParallel(generator_res)
discriminator_res = nn.DataParallel(discriminator_res)

# Register modules to checkpoint
checkpoint_res_io.register_modules(
    generator=generator_res,
    discriminator=discriminator_res,
)

# Test generator
generator_res_test = generator_res

# Latent noise distribution for the generator
zdist = get_zdist(configres['z_dist']['type'], configres['z_dist']['dim'],
                  device=device)

it_res = checkpoint_res_io.load('model_00499999.pt')
# CheckpointIO.load returns -1 instead of raising when the file is missing, which would
# silently leave the generator with randomly initialised weights.
if it_res < 0:
    raise FileNotFoundError(
        f"IDGAN checkpoint 'model_00499999.pt' not found in {checkpoint_res_dir}"
    )

# Inference only: disable dropout and use running BatchNorm statistics, if the models have any.
dvae.eval()
generator_res.eval()
discriminator_res.eval()

=> Loading checkpoint...


In [11]:
def wwr_traversing(x_real_test, c_dim_wwr=30, lim=1, lim_d=-2.6, ncol=4):
    """Sweep one dVAE latent dimension and decode every step with the IDGAN and the dVAE.

    For each input image, the encoder mean c_mu is taken and dimension `c_dim_wwr` is set to
    each of `ncol` evenly spaced values from `lim_d` to `lim`. Each modified code is then
    decoded by generator_res (concatenated after a random z drawn from zdist) and by the
    dVAE decoder.

    Relies on the notebook globals dvae, generator_res, zdist and device.

    Args:
        x_real_test: input batch in [-1, 1] of shape (B, C, H, W). With C == 1 it is treated as
            a gray label image, discretized to labels 0..4 and one-hot encoded. Otherwise it is
            fed to the encoder as-is (after shifting to [0, 1]).
        c_dim_wwr: index of the latent dimension to traverse.
        lim, lim_d: end and start values of the sweep (torch.linspace(lim_d, lim, ncol)).
        ncol: number of sweep steps per image.

    Returns:
        (x_real_gt, idganres_samples_p, dvae_samples_p):
        - x_real_gt: reference label maps scaled to [0, 1]. This is label / 4 for 1-channel
          input and argmax / (C - 1) otherwise. The old multi-channel path raised a NameError here.
        - idganres_samples_p, dvae_samples_p: B * ncol post-processed samples each, ordered
          image-major then sweep step.
    """
    x_real_shift = x_real_test.add(1).div(2)  # [-1, 1] -> [0, 1]
    n_samples = x_real_shift.size(0)
    n_channels = x_real_shift.size(1)

    with torch.no_grad():  # pure inference: no autograd graph or memory build-up
        if n_channels == 1:
            x_real_disc = discretize_to_order_labels(x_real_shift)
            dvae_input = get_one_hot(x_real_disc, 5).to(device)
            x_real_gt = x_real_disc / 4
        else:
            dvae_input = x_real_shift.to(device)
            x_real_gt = x_real_shift.argmax(dim=1, keepdim=True).cpu() / (n_channels - 1)

        _, c_mu, _ = dvae(dvae_input, encode_only=True)

        interpolation = torch.linspace(lim_d, lim, ncol)
        # One noise vector per input image (previously sized by the global batch_size).
        z = zdist.sample((n_samples,))

        idganres_samples_p = []
        dvae_samples_p = []
        for i in range(n_samples):
            z_ = z[i:i + 1]
            for val in interpolation:
                # Clone so the sweep edits a private copy, not a view into c_mu.
                c_p = c_mu[i:i + 1].clone()
                c_p[:, c_dim_wwr] = val
                z_p_ = torch.cat([z_, c_p], 1)

                idganres_samples_p.append(generator_postprocess(generator_res(z_p_)).cpu())
                dvae_samples_p.append(decoder_postprocess(dvae(c=c_p, decode_only=True)).cpu())

    return x_real_gt, torch.cat(idganres_samples_p, dim=0), torch.cat(dvae_samples_p, dim=0)

In [12]:
def wwr_traversing_gif(folder, x_real_test, c_dim_wwr=30, lim=1, lim_d=-2.6, ncol=12, fps=3):
    """Sweep one dVAE latent dimension and save the IDGAN decodings as fisheye frames plus a GIF.

    Each step of the sweep (torch.linspace(lim_d, lim, ncol) on dimension `c_dim_wwr` of the
    encoder mean) is decoded by generator_res. It is then reprojected with
    cubemap_back_to_fisheye and saved as <folder>/img/<k>.PNG, where k counts frames across all
    input images. The first `ncol` frames (the first input image) are assembled into
    <folder>/<basename(folder)>_img.gif.

    Relies on the notebook globals dvae, generator_res, zdist and device.

    Args:
        folder: output directory. It and its img/ subfolder are created if missing.
        x_real_test: input batch in [-1, 1] of shape (B, C, H, W). With C == 1 it is treated
            as a gray label image and one-hot encoded to 5 classes before encoding.
        c_dim_wwr, lim, lim_d, ncol: traversal settings, as in wwr_traversing.
        fps: GIF frame rate.
    """
    target_folder = os.path.join(folder, 'img')
    os.makedirs(target_folder, exist_ok=True)

    x_real_shift = x_real_test.add(1).div(2)  # [-1, 1] -> [0, 1]
    n_samples = x_real_shift.size(0)

    frame_paths = []
    with torch.no_grad():  # pure inference
        if x_real_shift.size(1) == 1:
            x_real_disc = discretize_to_order_labels(x_real_shift)
            dvae_input = get_one_hot(x_real_disc, 5).to(device)
        else:
            dvae_input = x_real_shift.to(device)

        _, c_mu, _ = dvae(dvae_input, encode_only=True)

        interpolation = torch.linspace(lim_d, lim, ncol)
        # One noise vector per input image (previously sized by the global batch_size).
        z = zdist.sample((n_samples,))

        for i in range(n_samples):
            z_ = z[i:i + 1]
            for val in interpolation:
                # Clone so the sweep edits a private copy, not a view into c_mu.
                c_p = c_mu[i:i + 1].clone()
                c_p[:, c_dim_wwr] = val
                z_p_ = torch.cat([z_, c_p], 1)

                idganres_sample_p = generator_postprocess(generator_res(z_p_)).cpu().squeeze(0)
                # label / 4 -> integer labels 0..4 expected by colorize_one_image.
                x_gan = Image.fromarray(cubemap_back_to_fisheye(idganres_sample_p * 4, n_channels=5))
                gan_file = os.path.join(target_folder, f'{len(frame_paths)}.PNG')
                x_gan.save(gan_file)
                frame_paths.append(gan_file)

    # Only the first image's sweep goes into the GIF, as before.
    clip_gan = ImageSequenceClip(frame_paths[:ncol], fps=fps)

    # Name the GIF after the last path component so nested folders do not break the path.
    gif_stem = os.path.basename(os.path.normpath(str(folder)))
    gan_gif_file = os.path.join(folder, gif_stem + '_img.gif')

    clip_gan.write_gif(gan_gif_file)

## WWR latent traversal

In [13]:
# Use a single test image rather than a batch from test_loader (batch alternative kept below).
#x_real_test_batch = utils.get_nsamples(test_loader, ntest)
x_real_test_batch=load_simple_im('zh_graycube/zh_graycube__91.PNG')
# Recorded output: torch.Size([1, 1, 128, 128]) — one single-channel image scaled to [-1, 1).
x_real_test_batch.shape

torch.Size([1, 1, 128, 128])

In [14]:
cd ..

C:\Users\zhang


In [15]:
cd att_feat_processing/gif

C:\Users\zhang\att_feat_processing\gif


In [16]:
# Latent dims of interest: 12 = z-direction, 16 = y-direction, 30 = the dim traversed in this
# WWR section (the config cell's c_meaningful lists 29 instead — confirm which index is intended).
dim_list = [12, 16, 30]
dim_ix = 2
trav_dim = dim_list[dim_ix]

n_frames = 12  # number of GIF frames, i.e. steps along the traversal
ncol = 3       # grid columns for the commented-out figure call below
lim_d = 0.64   # traversal start value (first frame)
lim = -1.5     # traversal end value (last frame)

wwr_traversing_gif(f'wwr_{trav_dim}', x_real_test_batch, c_dim_wwr=trav_dim,
                   lim_d=lim_d, lim=lim, ncol=n_frames, fps=3)

# NOTE(review): x_gt, x_gan and x_vae are not defined in any visible cell; presumably they come
# from `x_gt, x_gan, x_vae = wwr_traversing(...)` — confirm before uncommenting.
#make_grid_fisheye_with_plt('sbe_im_91.svg',x_gt*4,x_gan,x_vae,nrows1=1,nrows2=ncol,nrows3=ncol,batch_size1=1,batch_size2=ncol*1,batch_size3=ncol*1)

MoviePy - Building file wwr_30/wwr_30_img.gif with imageio.


                                                   