In [60]:
import sys
sys.path.insert(0, '../')

import numpy as np
from utils.Config import Config 
from registrationNN.models import NNModel, model_visualizer
from utils.ObjectUtil import ObjectUtil
from sklearn.model_selection import train_test_split
import os
from munch import Munch
import time
import random
import json 
import sys
import matplotlib.pyplot as plt
from utils.RegistrationUtils import RegistrationUtils

In [61]:
%load_ext autoreload

%autoreload 2
%matplotlib notebook

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [62]:
exp_id = 'after-decomposition-penalty'

model_config = Config.default_model_config(exp_id)
model_config.learning_rate = 5e-4
model_config.re_sampling = 120
model_config.n_files = 15
model_config.k_select = 5
model_config.epochs = 2
model_config.comment = 'no penalty on the movements'
model_config.select_only_matched = True
model_config.obj_accepted_labels = ['Triangle', 'Circle', 'Star']
model_config.redirect_out = False

model_config.load = True
model_config.load_ckpt = True
model_config.save = True
model_config.save_ckpt = True
model_config.vis_transformation = False
model_config.num_vis_samples = 5

print(f"[RegisterationMLP.py] {time.ctime()}: Expermint {model_config.exp_id} started")

org_objs, tar_objs = [], []
objs, labels = ObjectUtil.extract_objects_from_directory(model_config.dataset_path,
                                                        n_files=model_config.n_files,
                                                        acceptable_labels=model_config.obj_accepted_labels,
                                                        re_sampling = model_config.re_sampling)
labels, objs = np.asarray(labels), np.asarray(objs)

# validate that objects are distincts 
tot = 0
for obj1 in objs:
    for obj2 in objs:
        if obj1 == obj2:
            tot += 1

print(len(objs), tot)

random.seed(model_config.seed)
for obj, lbl in zip(objs, labels):
    if model_config.select_only_matched:
        matched_objs = objs[labels == lbl] # TODO test with non-matched objects
    else:
        matched_objs = objs

    # choose k random matched objects
    matched_objs = random.choices(matched_objs, k=model_config.k_select)

    for obj2 in matched_objs:
        org_objs.append(obj)
        tar_objs.append(obj2)


# split train test
train_org_sketches, val_org_sketches, train_tar_sketches, val_tar_sketches = train_test_split(org_objs, tar_objs, random_state=model_config.seed, test_size=0.2)



[RegisterationMLP.py] Fri Aug  6 20:04:49 2021: Expermint after-decomposition-penalty started


KeyboardInterrupt: 

In [63]:
obj = train_org_sketches[10]

In [64]:
fig2, ax2 = plt.subplots()
obj.visualize(show=False, ax=ax2)

<IPython.core.display.Javascript object>

In [9]:
len(obj)

120

In [65]:
# scale obj1 
tmp = ObjectUtil.poly_to_accumulative_stroke3([obj], red_rdp=False)
tmp = RegistrationUtils.pad_sketches(tmp, maxlen=128)
tmp = np.expand_dims(tmp, axis=-1)
obj2 = ObjectUtil.accumalitive_stroke3_to_poly(tmp)[0]
obj.visualize(show=False, ax=ax2)

In [66]:
# prepare model
model = NNModel(model_config)

[model.py] loading saved model of experiment after-decomposition-penalty


In [67]:
c = 10
org_obj, tar_obj = train_org_sketches[c].get_copy(), train_tar_sketches[c].get_copy()

In [68]:
fig2, ax2 = plt.subplots()
org_obj.visualize(show=False, ax=ax2)
tar_obj.visualize(show=False, ax=ax2)

<IPython.core.display.Javascript object>

In [69]:
len(org_obj), len(tar_obj)

(120, 120)

In [70]:
params, losses = model.predict([org_obj], [tar_obj])

In [71]:
params, losses

(array([[ 0.81585073, -1.9841704 ,  1.5226567 , -0.16904697, -0.00757917,
         -3.6649334 ,  1.9507761 ]], dtype=float32),
 array([1.99087973]))

In [72]:
RegistrationUtils.obtain_transformation_matrix(params[0])

array([ 0.04543591,  1.98187184, -3.66493344,  0.81460804, -0.09548037,
        1.9507761 ])

In [116]:
# scale obj1 
tmp = ObjectUtil.poly_to_accumulative_stroke3([org_obj])
tmp = RegistrationUtils.pad_sketches(tmp, maxlen=128)
tmp = np.expand_dims(tmp, axis=-1)
obj1_2 = ObjectUtil.accumalitive_stroke3_to_poly(tmp)[0]

# scale obj2
tmp = ObjectUtil.poly_to_accumulative_stroke3([tar_obj])
tmp = RegistrationUtils.pad_sketches(tmp, maxlen=128)
tmp = np.expand_dims(tmp, axis=-1)
obj2_2 = ObjectUtil.accumalitive_stroke3_to_poly(tmp)[0]

fig2, ax2 = plt.subplots()
obj1_2.visualize(show=False, ax=ax2)
obj2_2.visualize(show=False, ax=ax2)

p = params[0]
t = RegistrationUtils.obtain_transformation_matrix(p)
obj1_2.transform(t, object_min_origin=False)
obj1_2.visualize(show=False, ax=ax2)

<IPython.core.display.Javascript object>

In [117]:
t = RegistrationUtils.obtain_transformation_matrix(params[0])
t_denormalized = ObjectUtil.denormalized_transformation(org_obj, tar_obj, t)

fig2, ax2 = plt.subplots()
org_obj.visualize(show=False, ax=ax2)
tar_obj.visualize(show=False, ax=ax2)
# obj2.transform(t_denormalized)
# obj2.visualize(show=False, ax=ax2)

0.6813139072710953


<IPython.core.display.Javascript object>

In [119]:
tmp_obj = org_obj.get_copy()
tmp_obj.transform(t_denormalized, object_min_origin=True)
tmp_obj.visualize(show=False, ax=ax2)

In [120]:
# test sequential 

In [140]:
p_decomposed = RegistrationUtils.decompose_tranformation_matrix(t)

obj1_seq = org_obj.get_copy()

p_denormalized = RegistrationUtils.decompose_tranformation_matrix(t_denormalized)

seq_params = RegistrationUtils.get_seq_translation_matrices(p_denormalized)

fig2, ax2 = plt.subplots()
obj1_seq.visualize(show=False, ax=ax2)
tar_obj.visualize(show=False, ax=ax2)
# obj2.transform(t_denormalized)
# obj2.visualize(show=False, ax=ax2)




<IPython.core.display.Javascript object>

In [141]:
i = 0
t1 = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0]

In [146]:
tp = np.zeros(6)
t = seq_params[i]
tp = np.zeros(6)
print(t)
tp[0] = t[0] * t1[0] + t[1] * t1[3]
tp[1] = t[0] * t1[1] + t[1] * t1[4]
tp[2] = t[2] + t1[2]
tp[3] = t[3] * t1[0] + t[4] * t1[3]
tp[4] = t[3] * t1[1] + t[4] * t1[4]
tp[5] = t[5] + t1[5]
t1 = tp
if i == 4:
    obj1_seq.transform(t, object_min_origin=True, retain_origin=False)
else:
    obj1_seq.transform(t, object_min_origin=True, retain_origin=True)
obj1_seq.visualize(show=False, ax=ax2)
i += 1

[1.0, 0.0, 1759.5345522360444, 0.0, 1.0, 621.669447253921]


In [147]:
t1, t_denormalized, RegistrationUtils.obtain_transformation_matrix(p_denormalized) #should all be equal

(array([ 3.09561143e-02,  1.35027685e+00,  1.75953455e+03,  5.55003785e-01,
        -6.50521074e-02,  6.21669447e+02]),
 array([ 3.09561143e-02,  1.35027685e+00,  1.75953455e+03,  5.55003785e-01,
        -6.50521074e-02,  6.21669447e+02]),
 array([ 3.09561143e-02,  1.35027685e+00,  1.75953455e+03,  5.55003785e-01,
        -6.50521074e-02,  6.21669447e+02]))

# animate transformation

In [None]:
obj1_seq.reset()
animation = SketchAnimation([obj1_seq], [tar_obj]) 
animation.seq_animate_all([params[0]], 
                         denormalize_trans=True,
                         save=model_config.save_transformation_vis, 
                         file=os.path.join(vis_dir, f'example_{i}.mp4')) 