In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from matplotlib import gridspec
import argparse, os, pickle, glob, h5py
import numpy as np
from hangul.ml import LeaveOneFontOutCV
from torch.utils.data import DataLoader, TensorDataset
import torch
from pathlib import Path
from scripts.functions import *
from hangul.label_mapping import imf2idx, idx2imf
from hangul import style
from scripts.reimp import ReImp
%load_ext autoreload
%autoreload 2
# Torch device used as `map_location` when loading the checkpoint and passed to the
# project helpers `find_mean` / `trav` below. Change to e.g. 'cpu' if this GPU is absent.
device = 'cuda:1'

In [None]:
def _traversal_latent(mu, std, eps, step, h_dim, mean_sample):
    """Return the latent batch for one traversal step (row j has coordinate j set to step[j])."""
    if mean_sample == 'mean':
        # Sweep the mean, then add the shared noise on top (the swept coordinate is noisy too).
        latent = mu.clone()
        for j in range(h_dim):
            latent[j, j] = step[j]
        return latent + std * eps
    # 'sample': draw mu + std * eps first, then pin the swept coordinate to the exact step value.
    latent = mu + std * eps
    for j in range(h_dim):
        latent[j, j] = step[j]
    return latent


def traverse(model, h_dim, kl_indexes, pixels, trav_path, number, si, low, high, trav_steps, mean_sample, seed, fold, pix, special=False, spec_imf=''):
    """Render a latent-traversal grid for one glyph and save it as a PNG.

    The glyph is replicated ``h_dim`` times so that row ``j`` of the latent batch is
    swept along latent dimension ``j``. For each of ``trav_steps`` values linearly
    spaced between ``low`` and ``high``, latent entry ``[j, j]`` is overwritten with
    the step value and the batch is decoded. The saved grid holds, in order: the
    input copies, the model reconstructions, then one block per traversal step,
    each restricted to the rows in ``kl_indexes``.

    Args:
        model: VAE exposing ``encode`` (returning the concatenated ``[mu | logvar]``
            tensor plus ``m_indices`` / ``sizes`` consumed by ``decode``),
            ``forward`` (first output is the reconstruction) and ``decode``.
        h_dim: number of latent dimensions.
        kl_indexes: rows (i.e. swept latent dimensions) to keep in the figure.
        pixels: a single glyph image; assumed 2-D, e.g. 29x29 — TODO confirm.
        trav_path: output directory.
        number: image number, used only in the file name.
        si: per-image shape used to reshape model outputs, e.g. ``(1, 29, 29)``.
        low, high: per-dimension traversal bounds (at least ``h_dim`` entries).
        trav_steps: number of traversal steps.
        mean_sample: ``'mean'`` sets the swept coordinate on ``mu`` before adding
            noise; ``'sample'`` draws ``mu + std * eps`` and then overwrites the
            swept coordinate exactly.
        seed, fold, pix: used only in the file name.
        special: when true, prefix the file name with ``special_{spec_imf}_``.
        spec_imf: tag for the special file name.

    Raises:
        ValueError: if ``mean_sample`` is neither ``'mean'`` nor ``'sample'``.
    """
    # Validate before running the model so a typo fails fast.
    if mean_sample not in ('mean', 'sample'):
        raise ValueError(f"mean_sample must be 'mean' or 'sample', got {mean_sample!r}")
    with torch.no_grad():
        # One copy of the glyph per latent dimension. Stacking with numpy first avoids
        # torch's slow path for building a tensor from a list of ndarrays.
        batch = torch.tensor(np.stack([np.asarray(pixels)] * h_dim)).unsqueeze(1).float()
        x, m_indices, sizes = model.encode(batch)
        mu, logvar = x[:, :h_dim], x[:, h_dim:]
        std = torch.exp(0.5 * logvar)
        # A single noise draw is shared by every step so only the swept value changes.
        eps = torch.randn_like(std)
        actual, _, _ = model.forward(batch)
        actual = actual.view(-1, *si)[kl_indexes]
        steps = torch.tensor(np.linspace(low, high, trav_steps))
        frames = []
        for step in steps:
            latent = _traversal_latent(mu, std, eps, step, h_dim, mean_sample)
            decoded = model.decode(latent, m_indices.copy(), sizes.copy())
            frames.append(decoded.view(-1, *si)[kl_indexes])
        # Concatenate once instead of growing a tensor inside the loop.
        image = torch.cat(frames) if frames else torch.tensor([])

        both = torch.cat((batch[kl_indexes].view(-1, *si),
                          actual.view(-1, *si),
                          image.view(-1, *si)))
        name = f"{seed}_{number}_{fold}_{mean_sample}_{high[0]}_{pix}_traversal.png"
        if special:
            name = f"special_{spec_imf}_{name}"
        save_image(both, os.path.join(trav_path, name), nrow=h_dim)

In [None]:
def col_plotter(ex, seed, fold, img_num, col, pix, special, trav_range, spec_imf='', tile=31):
    """Place one grid column from several traversal PNGs side by side and plot it.

    For every entry of ``pix`` the matching traversal PNG under
    ``trav_short/{ex}_{seed}`` is read. Its grid column ``col`` (``tile`` pixels
    wide) is cut out, and the cut-outs are concatenated left to right.

    Args:
        ex, seed, fold, img_num, trav_range: identifiers used to build the file names.
        col: 1-based column index into each traversal grid.
        pix: identifiers of the traversal images. For non-special runs a ``None``
            entry selects ``{seed}_{img_num}_{fold}_mean_traversal.png``.
        special: read the ``special_{spec_imf}_...`` files instead.
        trav_range: traversal range that appears in the file names.
        spec_imf: tag for special file names.
        tile: width in pixels of one grid cell (31 everywhere in this notebook).

    Returns:
        The concatenated image array that is plotted.

    Raises:
        ValueError: if ``col`` < 1 (the column slice would be empty).
    """
    if col < 1:
        raise ValueError(f'col is a 1-based column index, got {col}')
    im_dir = f'trav_short/{ex}_{seed}'
    strips = []
    for p in pix:
        if special:
            im_path = f"{im_dir}/special_{spec_imf}_{seed}_{img_num}_{fold}_mean_{trav_range}_{p}_traversal.png"
        elif p is None:
            im_path = f"{im_dir}/{seed}_{img_num}_{fold}_mean_traversal.png"
        else:
            im_path = f"{im_dir}/{seed}_{img_num}_{fold}_mean_{trav_range}_{p}_traversal.png"
        img = mpimg.imread(im_path)
        # Keep only grid column `col` (1-based) of this traversal image.
        strips.append(img[:, (col - 1) * tile:col * tile])
    final = np.concatenate(strips, axis=1)
    fig, ax = plt.subplots(figsize=(20, 25))
    ax.imshow(final)
    # One tick per grid row; derived from the image rather than assuming 22 rows.
    n_rows = final.shape[0] // tile
    plt.xticks(np.arange(len(pix)) * tile, labels=np.arange(len(pix)))
    plt.yticks(np.arange(n_rows) * tile, labels=np.arange(n_rows))
    return final

In [None]:
# Experiment identifiers (placeholders). They name the traversal output folder and
# appear in the generated file names.
ex = 0
seed = 0
fold = 0
img_num = 0
# `col_plotter` treats `col` as a 1-based grid column (slice (col-1)*31 : col*31);
# 0 would select an empty slice.
col = 1
imf = 'initial'
im_dir = f'trav_short/{ex}_{seed}'
os.makedirs(im_dir, exist_ok=True)

In [None]:
# Collect every HDF5 file under the dataset root whose name contains '24.h5'.
locs = os.walk('')  # dataset location (placeholder: os.walk('') yields nothing)
h5_files = []
for base, _, fnames in locs:
    for f in fnames:
        if '{}.h5'.format(24) in f:
            h5_files.append(os.path.join(base, f))
ds = LeaveOneFontOutCV(h5_files, 0, mean_center=False, imf='i')
X_train, y = ds.training_set()
X_valid, yv = ds.validation_set()
X_test, yt = ds.test_set()

# Wrap the splits loaded above rather than calling training_set()/validation_set()/
# test_set() a second time.
# NOTE(review): `batch_size` is only assigned in the model-loading cell further down
# (from the saved params). On a fresh kernel, run that cell first or this raises NameError.
ds_train = DataLoader(TensorDataset(torch.tensor(X_train), torch.tensor(y)),
                      batch_size=batch_size)
ds_valid = DataLoader(TensorDataset(torch.tensor(X_valid), torch.tensor(yv)),
                      batch_size=batch_size)
ds_test = DataLoader(TensorDataset(torch.tensor(X_test), torch.tensor(yt)),
                     batch_size=batch_size)
# Per-sample shape with a leading channel axis: (1, *X_train[0].shape).
si = torch.tensor(X_train[0]).unsqueeze(0).size()

In [None]:
# All glyphs from the three splits in train, valid, test order (numpy), and the matching
# image tensors taken back out of the DataLoaders. Note: this rebinds `ds`, previously
# the LeaveOneFontOutCV object, to a tensor.
Xt = np.concatenate([X_train, X_valid, X_test], axis=0)
ds = torch.cat([loader.dataset.tensors[0] for loader in (ds_train, ds_valid, ds_test)])
print(ds_train.batch_size)

In [None]:
model_id = f'{ex}_{fold}_{seed}'
root = ''  # model location: directory holding model_params.pkl and the c*.pt checkpoint(s)
# NOTE: pickle.load / torch.load can execute arbitrary code; only load trusted files.
# os.path.join keeps the params path consistent with Path(root).rglob below
# (an empty root means the current directory, not the filesystem root).
with open(os.path.join(root, 'model_params.pkl'), 'rb') as f:
    data = pickle.load(f)
params = data[1]
print(params)
h_dim = params['h_dim']
batch_size = params['batch_size']
nfolds = 7
num_fonts = 5
# Sort so that `ps[0]` names the same checkpoint on every run; rglob order depends on
# the filesystem.
ps = sorted(Path(root).rglob('c*.pt'))
if not ps:
    raise FileNotFoundError(f"no 'c*.pt' checkpoint found under {root!r}")
mod = make_vae_model(data[0], params)
mod.load_state_dict(torch.load(ps[0], map_location=device))
print('loaded')

In [None]:
mod.eval()
base_trav = im_dir
# Reuse cached latent indices ('kl_indices') if a previous run saved them (presumably
# written by `find_mean`); otherwise compute them with `find_mean`.
samples_path = f"{base_trav}/{seed}_{fold}_samples.npz"
if os.path.isfile(samples_path):
    # np.load returns an NpzFile that keeps the archive open; the context manager
    # closes it once the array has been read.
    with np.load(samples_path) as samples:
        indexes = samples['kl_indices']
else:
    low_mean, high_mean, low_sample, high_sample, indexes = find_mean(
        ps[0], base_trav, params, 7, [ds_train, ds_valid, ds_test],
        [X_train, X_valid, X_test], device=device, seed=seed, fold=fold)
indexes

In [None]:
# Traverse every latent dimension over [-trav_range, trav_range] with the project helper.
trav_range = 2
trav(
    ps[0], params, indexes, base_trav,
    si=si,
    low=np.full(h_dim, -trav_range),
    high=np.full(h_dim, trav_range),
    mean_sample='mean',
    trav_steps=20,
    seed=seed,
    fold=fold,
    device=device,
)

In [None]:
# Test-set pickles to traverse: the per-model "special" sets or the per-font sets.
special = True
saved_pix = glob.glob(os.path.join(
    os.getcwd(),
    f'test_sets_{model_id}_*.pkl' if special else 'test_sets_font_*.pkl',
))
print(saved_pix)

In [None]:
def tra(saved_pix, trav_range, special, spec_imf=''):
    """Render latent traversals for glyphs stored in test-set pickles.

    Each pickle holds a sequence whose element 0 is the list of glyph images.
    - With ``special`` set, every glyph in the file is traversed, tagged by its position.
    - Otherwise only glyph ``img_num`` is traversed, tagged with the last
      ``_``-separated token of the pickle path minus its final 4 characters
      (assumes a ``.pkl`` suffix).

    Relies on the notebook globals ``mod``, ``h_dim``, ``indexes``, ``im_dir``,
    ``img_num``, ``seed`` and ``fold``. Every latent dimension spans
    ``[-trav_range, trav_range]`` in 20 steps with ``mean_sample='mean'``.

    NOTE: pickle.load executes arbitrary code; only pass trusted files.
    """
    for pixel_root in saved_pix:
        with open(pixel_root, 'rb') as fh:
            glyphs = pickle.load(fh)[0]
        tag = pixel_root.split("_")[-1][:-4]
        if not special:
            traverse(mod, h_dim, indexes, glyphs[img_num], im_dir, img_num,
                     si=(1, 29, 29), low=[-trav_range] * h_dim, high=[trav_range] * h_dim,
                     trav_steps=20, mean_sample='mean', seed=seed, fold=fold, pix=tag)
            continue
        for position, glyph in enumerate(glyphs):
            traverse(mod, h_dim, indexes, glyph, im_dir, img_num,
                     si=(1, 29, 29), low=[-trav_range] * h_dim, high=[trav_range] * h_dim,
                     trav_steps=20, mean_sample='mean', seed=seed, fold=fold,
                     pix=position, special=True, spec_imf=spec_imf)

In [None]:
zos = np.zeros((29, 29))
ons = np.ones((29, 29))
# 1.0 at every even flat (row-major) index. Since 29 is odd, this is a checkerboard.
weird = (np.arange(29 * 29) % 2 == 0).astype(float).reshape((29, 29))
traverse(mod, h_dim, indexes, weird, im_dir, img_num,
         si=(1, 29, 29),
         low=[-trav_range] * h_dim, high=[trav_range] * h_dim,
         trav_steps=20, mean_sample='mean',
         seed=seed, fold=fold,
         pix='zeros_ones_inter', special=True)

In [None]:
# Wider traversal range for the test-set traversals.
trav_range = 4
spec_imf = 'f'

tra(saved_pix if type(saved_pix) is list else [saved_pix], trav_range, special, spec_imf)

In [None]:
# Stack the chosen traversal column across images: special images 0-23, or one image
# per saved font test set.
blob = (
    col_plotter(ex, seed, fold, img_num, col, np.arange(24), special, trav_range, spec_imf)
    if special
    else col_plotter(ex, seed, fold, img_num, col, np.arange(len(saved_pix)), special, trav_range)
)

In [None]:
# To recreate the figure from a saved pkl file instead of hand-picking below:
# if special:
#     with open(f'good_{imf}_{model_id}_{img_num}_{col}_{trav_range}_special.pkl', 'rb') as f:
#         d = pickle.load(f)
# else:
#     # or a specific saved run, e.g. 'good_final_5_4_3868_16_4_10.pkl'
#     with open(f'good_{imf}_{model_id}_{img_num}_{col}_{trav_range}.pkl', 'rb') as f:
#         d = pickle.load(f)
# sta = d['start']
# end = d['end']
# wanted_col = d['wanted_col']
# img_data = d['img_data']
# print(wanted_col)

# Indices into `blob` (one 31-px column per traversal image, in `pix` order) to keep
# in the final figure.
wanted_col = [0, 1, 11, 23, 8]

In [None]:
# Keep one 31-px column of `blob` per entry of `wanted_col`, placed side by side.
pruned = np.hstack([blob[:, c * 31:(c + 1) * 31] for c in wanted_col])
fig, ax = plt.subplots(figsize=(20, 25))
plt.xticks(np.arange(len(wanted_col)) * 31, labels=np.arange(len(wanted_col)))
plt.yticks(np.arange(22) * 31, labels=np.arange(22))
ax.imshow(pruned)

In [None]:
# Per selected column, keep five consecutive 31-px grid rows starting at these offsets.
sta = np.array([8, 8, 8, 5, 7])
end = sta + 5

# Borderless axes so the saved figure contains only the image.
plt.gca().set_axis_off()
plt.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=0, hspace=0)
plt.margins(0, 0)
plt.gca().xaxis.set_major_locator(plt.NullLocator())
plt.gca().yaxis.set_major_locator(plt.NullLocator())

# Header: the first 31-px grid row of every selected column.
tops = pruned[:31]
sel_prune = np.concatenate(
    [
        pruned[s * 31:e * 31, i * 31:(i + 1) * 31]
        for i, (s, e) in enumerate(zip(sta, end))
    ],
    axis=1,
)
print(sel_prune.shape, tops.shape)
# 7-px separator strip whose middle row is set to 1.
line = np.zeros((7, sel_prune.shape[1], 3))
line[3, :] = 1
final_prune = np.concatenate((tops, line, sel_prune), axis=0)
print(final_prune.shape)
plt.imshow(1 - final_prune)
if not special:
    plt.savefig(f'good_{imf}_{model_id}_{img_num}_{col}_{trav_range}.pdf',
                bbox_inches='tight', pad_inches=0, dpi=300)
    with open(f'good_{imf}_{model_id}_{img_num}_{col}_{trav_range}.pkl', 'wb') as f:
        pickle.dump({'img_data': final_prune, 'wanted_col': wanted_col, 'start': sta, 'end': end}, f)
else:
    plt.savefig(f'good_{spec_imf}_{model_id}_{img_num}_{col}_{trav_range}_special.pdf',
                bbox_inches='tight', pad_inches=0, dpi=300)
    with open(f'good_{spec_imf}_{model_id}_{img_num}_{col}_{trav_range}_special.pkl', 'wb') as f:
        pickle.dump({'img_data': final_prune, 'wanted_col': wanted_col, 'start': sta, 'end': end}, f)

In [None]:
# Quick look at the selected columns without the 1 - x inversion.
plt.imshow(pruned)

In [None]:
# Offsets in (initial, medial, final) index space to the 8 neighbouring directions,
# scaled by 1, 2 and 3 (24 rows, scale-major).
# - 'i' keeps the initial index fixed and moves medial/final.
# - 'f' keeps the final index fixed and moves initial/medial.
surround_i = (np.array([1, 2, 3])[:, None, None] * np.array(
    [(0, -1, 0), (0, 0, -1), (0, -1, -1),
     (0, 1, 0), (0, 0, 1), (0, 1, 1),
     (0, 1, -1), (0, -1, 1)])).reshape(-1, 3)
surround_f = (np.array([1, 2, 3])[:, None, None] * np.array(
    [(0, -1, 0), (-1, 0, 0), (-1, -1, 0),
     (0, 1, 0), (1, 0, 0), (1, 1, 0),
     (1, -1, 0), (-1, 1, 0)])).reshape(-1, 3)
surround = {'i': surround_i, 'f': surround_f}

In [None]:
# Build a "special" test set around glyph `img_num`. It contains its neighbours along
# `surround[spec_imf]` (which keeps the `spec_imf` component fixed), plus, for 'i',
# the same initial/medial with final index 0.
# NOTE(review): `sur` is not defined anywhere in this notebook. Presumably a loaded
# test-set pickle whose element 1 holds (i, m, f) tuples; define it before running.
data = [
    imf2idx(sur[1][img_num][0] + d_i, sur[1][img_num][1] + d_m, sur[1][img_num][2] + d_f)
    for d_i, d_m, d_f in surround[spec_imf]
]
if spec_imf == 'i':
    data.append(imf2idx(sur[1][img_num][0], sur[1][img_num][1], 0))
# 11172 * 5: presumably the offset of font 5 in `Xt` (11172 glyphs per font) — TODO confirm.
pics = [Xt[idx + 11172 * 5] for idx in data]
indices = list(map(idx2imf, data))
new_data = [pics, indices]
if not special:
    with open(os.path.abspath(os.path.join(os.getcwd(), f'test_sets_{model_id}_{img_num}.pkl')), 'wb') as f:
        pickle.dump(new_data, f)
else:
    with open(os.path.abspath(os.path.join(os.getcwd(), f'test_sets_{model_id}_{img_num}_{spec_imf}.pkl')), 'wb') as f:
        pickle.dump(new_data, f)

In [None]:
# Re-render the PDF written by the final-figure cell and show its first page. The file
# name follows that cell's naming (special vs. regular run, `_{trav_range}` suffix).
# NOTE(review): `convert_from_path` is not imported explicitly; presumably pdf2image's,
# brought in by `from scripts.functions import *` — confirm.
if special:
    pdf = convert_from_path(f'good_{spec_imf}_{model_id}_{img_num}_{col}_{trav_range}_special.pdf')
else:
    pdf = convert_from_path(f'good_{imf}_{model_id}_{img_num}_{col}_{trav_range}.pdf')
plt.imshow(pdf[0])

In [None]:
# Copy the first rendered page (presumably a PIL image) into a numpy array and display it.
image = np.array(pdf[0])
plt.imshow(image)

In [None]:
# Same row segments as the final-figure cell (rows sta[i]..end[i] of column i), but
# stacked vertically (axis=0) into a single strip. Uses `sta`; the name `start` is
# not defined in this notebook.
sel_prune = np.concatenate([pruned[s*31:e*31, i*31:(i+1)*31] for i, (s, e) in enumerate(zip(sta, end))], axis=0)

In [None]:
# Alias (not a copy) of the assembled figure array `final_prune`.
good_img = final_prune

In [None]:
def invert(img, col=np.array([])):
    """Reverse the order of the traversal rows of a figure, keeping its header on top.

    ``img`` is a 38-px header (a 31-px glyph row plus a 7-px separator strip) above a
    stack of 31-px traversal rows.

    Args:
        img: figure array (2-D, or H x W x C).
        col: empty to reverse the order of every traversal row. Otherwise only grid
            column ``col[0]`` is reversed top-to-bottom and all other columns stay in
            place.

    Returns:
        The rearranged array. It is also drawn as ``1 - rev`` on borderless axes and
        saved to ``good_{imf}_{model_id}_{img_num}_{col}_{trav_range}_rev_special.pdf``
        (notebook globals).
    """
    header, tile = 38, 31
    top = img[:header]
    # Number of full traversal rows below the header.
    nrow = (len(img) - header) // tile
    rows = [img[header + r * tile:header + (r + 1) * tile] for r in range(nrow)]
    if len(col) != 0:
        c = int(col[0])
        body = img[header:header + nrow * tile]
        tiles = [row[:, c * tile:(c + 1) * tile] for row in rows]
        flipped = np.concatenate(tiles[::-1], axis=0) if tiles else body[:, c * tile:(c + 1) * tile]
        # Columns left and right of `c` are kept unchanged around the flipped column.
        rev_body = np.concatenate((body[:, :c * tile], flipped, body[:, (c + 1) * tile:]), axis=1)
        rev = np.concatenate((top, rev_body), axis=0)
    else:
        rev = np.concatenate([top] + rows[::-1], axis=0)

    plt.gca().set_axis_off()
    plt.subplots_adjust(top = 1, bottom = 0, right = 1, left = 0,
                hspace = 0, wspace = 0)
    plt.margins(0,0)
    plt.gca().xaxis.set_major_locator(plt.NullLocator())
    plt.gca().yaxis.set_major_locator(plt.NullLocator())
    plt.imshow(1-rev)
    plt.savefig(f'good_{imf}_{model_id}_{img_num}_{col}_{trav_range}_rev_special.pdf', bbox_inches = 'tight',
    pad_inches = 0)
    return rev

In [None]:
# Empty `inv_col`: `invert` reverses the order of every traversal row. A non-empty
# array selects its column branch, which uses only inv_col[0].
inv_col = np.array([])
reverse = invert(good_img, col=inv_col)

In [None]:
def select_col(img, cols, rows, tile=31, header=38):
    """Crop a traversal figure to chosen grid columns and traversal rows.

    Args:
        img: figure array. Its first ``header`` pixel rows (a 31-px glyph row plus a
            7-px separator in this notebook) are kept as-is; below them are
            ``tile``-px traversal rows.
        cols: grid column indices to keep, in output order.
        rows: traversal row indices to keep (0 = first row below the header), in
            output order.
        tile: grid cell size in pixels.
        header: height of the header block in pixels.

    Returns:
        The header of the selected columns stacked above the selected rows.
    """
    picked = np.concatenate([img[:, c * tile:(c + 1) * tile] for c in cols], axis=1)
    parts = [picked[:header]]
    # Each requested row is sliced by its index, so `rows` may skip or reorder rows.
    parts.extend(picked[header + r * tile:header + (r + 1) * tile] for r in rows)
    return np.concatenate(parts, axis=0)

In [None]:
# Keep grid columns 1, 3, 4, 8 and traversal rows 0-5. Select from `final_prune` rather
# than `good_img` so that re-running this cell gives the same result instead of
# cropping an already-cropped image.
good_img = select_col(final_prune, [1, 3, 4, 8], [0, 1, 2, 3, 4, 5])
plt.imshow(good_img)

In [None]:
def add_trav_to(img, additional):
    """Append a row of "traverse-to" glyphs under a traversal figure, then show and save it.

    Steps:
    1. Pad each glyph in ``additional`` with one zero pixel on every side.
    2. Place the padded glyphs side by side and repeat them over three channels.
    3. Stack them under ``img`` after a 7-px separator strip whose middle row is 1.
    4. Draw the inverted result (``1 - im``) on borderless axes and save it as a PDF,
       plus a pkl holding ``{'img_data': 1 - im, 'traverse_to': additional}``.

    File names use notebook globals: ``spec_imf`` and a ``_rev_special_special`` suffix
    when ``special`` is true, otherwise ``imf`` and ``_rev_special``, together with
    ``model_id``, ``img_num``, ``col`` and ``trav_range``.

    Args:
        img: figure array; assumed H x W x 3 with W equal to the width of the new
            glyph row — TODO confirm.
        additional: sequence of 2-D glyph arrays.
    """
    padded = [np.pad(glyph, pad_width=1, mode='constant', constant_values=0) for glyph in additional]
    strip = np.concatenate(padded, axis=1)
    strip = np.stack([strip, strip, strip], axis=2)
    separator = np.zeros((7, strip.shape[1], 3))
    print(img.shape)
    print(separator.shape)
    print(strip.shape)
    separator[3, :] = 1
    im = np.concatenate((img, separator, strip), axis=0)
    inverted = 1 - im

    axes = plt.gca()
    axes.set_axis_off()
    plt.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=0, hspace=0)
    plt.margins(0, 0)
    axes.xaxis.set_major_locator(plt.NullLocator())
    axes.yaxis.set_major_locator(plt.NullLocator())
    plt.imshow(inverted)
    if not special:
        plt.savefig(f'good_{imf}_{model_id}_{img_num}_{col}_{trav_range}_rev_special.pdf',
                    bbox_inches='tight', pad_inches=0, dpi=300)
        with open(f'good_{imf}_{model_id}_{img_num}_{col}_{trav_range}_rev_special.pkl', 'wb') as handle:
            pickle.dump({'img_data': inverted, 'traverse_to': additional}, handle)
    else:
        plt.savefig(f'good_{spec_imf}_{model_id}_{img_num}_{col}_{trav_range}_rev_special_special.pdf',
                    bbox_inches='tight', pad_inches=0, dpi=300)
        with open(f'good_{spec_imf}_{model_id}_{img_num}_{col}_{trav_range}_rev_special_special.pkl', 'wb') as handle:
            pickle.dump({'img_data': inverted, 'traverse_to': additional}, handle)

In [None]:
# Reload the "traverse-to" glyphs from a pickle written by `add_trav_to`. The name
# follows `add_trav_to`'s naming for the current `special` setting; it uses `col`,
# not `inv_col`.
# NOTE: pickle.load executes arbitrary code; only load trusted files.
if special:
    trav_pkl = f'good_{spec_imf}_{model_id}_{img_num}_{col}_{trav_range}_rev_special_special.pkl'
else:
    trav_pkl = f'good_{imf}_{model_id}_{img_num}_{col}_{trav_range}_rev_special.pkl'
with open(trav_pkl, 'rb') as f:
    good_info = pickle.load(f)
    trav_to = good_info['traverse_to']
print(len(trav_to))

In [None]:
# Alias of `wanted_col`. Below, each entry multiplies 11172 when indexing `Xt`
# (presumably a font index — TODO confirm).
good_col = wanted_col

In [None]:
# Target glyph per selected column: an (initial, medial, final) triple looked up at
# offset 11172 * good_col[k] in `Xt` (presumably 11172 glyphs per font — TODO confirm).
one = [
    Xt[imf2idx(i, m, f) + 11172 * good_col[k]]
    for k, (i, m, f) in enumerate(
        [(16, 17, 1), (15, 18, 1), (16, 20, 1), (12, 0, 1), (16, 19, 1)]
    )
]

add_trav_to(reverse, one)

In [None]:
def transpose(im, tile=31, sep=7):
    """Turn a traversal figure on its side so its bands run left to right.

    ``im`` is a stack of horizontal bands, top to bottom:
    - a ``tile``-px glyph row,
    - a ``sep``-px separator strip,
    - the ``tile``-px traversal rows,
    - a second separator strip,
    - a final ``tile``-px glyph row.

    Each glyph band is cut into its ``tile``-wide cells, which are stacked vertically
    (every glyph stays upright). Separator strips are transposed. The resulting pieces
    are laid out left to right.

    Args:
        im: 2-D (H x W) or 3-D (H x W x C) figure array; channels are preserved.
        tile: grid cell size in pixels.
        sep: height of each separator strip in pixels.

    Returns:
        The rearranged array.
    """
    # Total bands = two separators + the glyph/traversal bands of height `tile`.
    n_bands = (len(im) - 2 * sep) // tile + 2
    n_cols = im.shape[1] // tile
    pieces = []
    offset = 0
    for band in range(n_bands):
        if band == 1 or band == n_bands - 2:
            strip = im[offset:offset + sep]
            offset += sep
            # swapaxes (not .T) so an optional channel axis stays last.
            pieces.append(np.swapaxes(strip, 0, 1))
        else:
            strip = im[offset:offset + tile]
            offset += tile
            pieces.append(np.concatenate(
                [strip[:, j * tile:(j + 1) * tile] for j in range(n_cols)], axis=0))
    return np.concatenate(pieces, axis=1)

In [None]:
# Traversal figure generation.
# Inputs: three traversal pkl files with an 'img_data' entry (presumably the
# `*_rev_special_special.pkl` files written by `add_trav_to`), in panel order A, B, C.
# Fill in their paths below.
data = []
for pkl_path in ('', '', ''):
    with open(pkl_path, 'rb') as f:
        good_info = pickle.load(f)
        # Average the colour channels into a single grey-level image.
        data.append(np.mean(good_info['img_data'], axis=2))

In [None]:
fig, ax = plt.subplots(1, 3, figsize=(6, 2))
# Descriptions of panels A, B, C (presumably in `data` order); not drawn on the figure.
names = ['Initial Across Fonts', 'Initial Across Blocks', 'Final Across Fonts']
# The figure uses matplotlib's default subplot spacing. A GridSpec that was built here
# was never attached to the axes and has been dropped. The panel-letter x positions
# below (0.125 = default left margin) appear tuned to the default layout.
for i, panel in enumerate(ax):
    panel.imshow(transpose(data[i]), cmap='gray')
    panel.axis('off')
for x, letter in zip((0.125, 0.4, 0.675), 'ABC'):
    fig.text(x, 0.75, letter, **style.panel_letter_fontstyle)
fig.savefig('bvae_go_traversals.pdf', bbox_inches='tight',
            pad_inches=0, dpi=300)