## LICENSE: GPL 3.0
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program.  If not, see <https://www.gnu.org/licenses/>.\
Eigendecomposition-based style transfer tool\
Copyright (c)

Honda Research Institute Europe GmbH\
Authors: Timo Friedrich\
Contact: timo.friedrich@honda-ri.de\

# Eigendecomposition-based Style Transfer
This notebook provides the tool and source code to reproduce the results from the following publication:

- Friedrich, T., Schmitt, S., & Menzel, S. (2020). Rapid Creation of Vehicle Line-Ups By Eigenspace Projections for Style Transfer. Proceedings of the Design Society: DESIGN Conference, 1, 867–876. https://doi.org/10.1017/dsd.2020.162


---
## Data
First, download the dataset by Umetani and unzip its contents into the folder `./EigendecompositionBasedStyleTransfer/data`.\
Link: https://cgenglab.github.io/labpage/en/publication/sigga17tb_mlcarshape/

---
## Environment
Create a compatible Conda Environment:

```
conda create --name EigenEnv --file requirements.txt
conda activate EigenEnv
```

---
## Info
**Please execute all cells in the notebook to run the style transfer tool.**\
The GUI is created in the next code cell.\
This can take some time when it is executed for the first time, because the eigendecompositions of the meshes are computed and then cached in the `./temp/` folder.

The notebook works best in Chromium or Chrome; we observed issues with the ipyvolume figures in Firefox.

# Style Transfer Tool GUI

In [None]:
import ipywidgets as widgets
from IPython.display import display
# Empty placeholder container shown right away; the GUI cell further below
# fills it by assigning `toolarea.children`.
toolarea = widgets.VBox(children=())
display(toolarea)

# Imports and Configuration

In [None]:
import sys
import os
import numpy as np 
import networkx as nx
from bqplot import pyplot as plt
import bqplot
import scipy.interpolate
import ipywidgets as widgets
from IPython.display import display

import trimesh
import ipyvolume as ipv
import pymesh
import copy
import pickle

def defview():
    """Apply the tool's default axis limits and camera to the current ipyvolume figure.

    x and z are limited to [-0.5, 0.5] and y to [0, 1]; the bounding box and
    axes are hidden and the camera is placed at azimuth 145, elevation 15,
    distance 5.
    """
    ipv.squarelim()
    # xyzlim sets all three axes at once; the following ylim overrides y only,
    # so the order of these two calls matters.
    ipv.xyzlim(-0.5, 0.5)
    ipv.ylim(0, 1)
    for hide in (ipv.style.box_off, ipv.style.axes_off):
        hide()
    ipv.view(azimuth=145, elevation=15, distance=5)

In [None]:
# Data configuration: mesh files offered by the tool (relative paths).
model_file = [
    './data/cubeheightobj/fe78ad3863e25cb3253c03b7df20edd5.obj',
    './data/cubeheightobj/1c53bc6a3992b0843677ee89898ae463.obj',
    './data/cubeheightobj/1d4066f8ae88a1ebec8ca19d7516cb42.obj',
    './data/cubeheightobj/ba7a583467ff8aee8cfac9da0ff28f62.obj',
    './data/cubeheightobj/1e54527efa629c37a047cd0a07d473f1.obj',
    './data/cubeheightobj/1aef0af3cdafb118c6a40bdf315062da.obj',
]

# Dropdown options as (label, index into model_file) pairs: ('Car0', 0), ...
fnames = [(f'Car{idx}', idx) for idx in range(len(model_file))]

# Cache directory for the pickled Laplacian / eigendecomposition results.
tempfolder = './temp/'

# Indices (into model_file) of the models shown initially as model 0 and model 1.
i0 = 0
i1 = 1

# Calculations

In [None]:
# This function calculates the Eigenvectors of the 3D shapes.
# This process is time consuming, thus the intermediate results get stored in the ./temp folder
# for repeated usage.


def doCalc(pathmesh, log=True):
    '''
    Returns several geometric features for the given 3D mesh file.

    The Laplacian matrix LL and its eigendecomposition are expensive to
    compute. They are therefore cached as a pickle named ``<mesh basename>.pkl``
    in ``tempfolder`` (the folder is created if missing) and reloaded from
    there on later calls. The cache is keyed by the file's basename only and
    is not invalidated if the mesh file itself changes.

            Parameters:
                    pathmesh (str): Path to the file containing the 3D mesh. E.g. *.off, *.obj or *.stl
                    log (bool): Switch for additional calculation information

            Returns:
                    M           (object):   mesh object (quad faces split into triangles)
                    coords      (np.array): vertex coordinates
                    LL          (np.array): dense Laplacian matrix
                    Evals       (np.array): Eigenvalues of LL (ascending, as returned by np.linalg.eigh)
                    Evecs       (np.array): Eigenvectors of LL, one per column
                    coeffs      (np.array): Coefficients, i.e. Evecs.T @ coords
                    sorti_coeff (np.array): sort index array by descending row norm of coeffs
                    sorti_eval  (np.array): sort index array by ascending Eigenvalue

            Raises:
                    RuntimeError: if an existing cache file cannot be loaded or unpacked.
    '''
    M = pymesh.load_mesh(pathmesh)
    # Quad meshes are split into triangles before any further processing.
    if M.faces.shape[1] == 4:
        M = pymesh.quad_to_tri(M)

    coords = np.array(M.vertices)
    if log:
        print(f'coords.shape: {coords.shape}')

    filename = os.path.basename(pathmesh)
    cache_path = os.path.join(tempfolder, f"{filename}.pkl")

    if os.path.exists(cache_path):
        if log:
            print(f"{cache_path} found. Loading LL and EVs ...")
        # NOTE: pickle.load can execute arbitrary code; only use cache files
        # that this tool has written itself.
        try:
            with open(cache_path, "rb") as fh:
                [LL, Evals, Evecs] = pickle.load(fh)
        except Exception as err:
            # Surface the failure instead of continuing with undefined
            # LL/Evals/Evecs further down.
            raise RuntimeError(
                f"Loading failed. Please delete the temp file and run all cells again.\n {cache_path}"
            ) from err
    else:
        print("LL, EVs not computed for this file yet. Compute now ...")
        # time consuming EV calculation
        assembler = pymesh.Assembler(M)
        LL = assembler.assemble("laplacian").toarray()
        # eigh treats LL as symmetric and returns eigenvalues in ascending
        # order with orthonormal eigenvectors as columns.
        Evals, Evecs = np.linalg.eigh(LL)
        if log:
            print(f'LL.shape: {LL.shape}')
            print(f'Evecs.shape: {Evecs.shape}')

        # The cache folder may not exist on a fresh checkout.
        os.makedirs(tempfolder, exist_ok=True)
        with open(cache_path, "wb") as fh:
            pickle.dump([LL, Evals, Evecs], fh)
        print(f"Stored in {cache_path}")

    # Project the vertex coordinates onto the eigenvectors; because the
    # eigenvector columns are orthonormal, Evecs @ coeffs gives coords back.
    coeffs = np.matmul(np.transpose(Evecs), coords)
    if log:
        print(f'coeffs.shape: {coeffs.shape}')

    # by norm coeff (largest contribution first)
    sorti_coeff = np.argsort(np.linalg.norm(coeffs, axis=1))[::-1]

    # by eigenvalue (smallest first)
    sorti_eval = np.argsort(Evals)

    return M, coords, LL, Evals, Evecs, coeffs, sorti_coeff, sorti_eval


# Load (or compute and cache) the spectral data of the two initially shown models.
(M0, coords0, LL0, Evals0, Evecs0,
 coeffs0, sorti0_coeff, sorti0_eval) = doCalc(model_file[i0])
(M1, coords1, LL1, Evals1, Evecs1,
 coeffs1, sorti1_coeff, sorti1_eval) = doCalc(model_file[i1])

# GUI
Code for the GUI elements and their callback functions.

In [None]:
# This cell creates the GUI and provides the necessary callback functions.

# Initial mixed coordinates: the mixing figure is first drawn from model 0's
# vertices so that it starts with a proper scaling; update_plot_mix() later
# rebinds newCoords to a freshly computed array.
newCoords = coords0 # workaround in order to have proper scaled plot

# Number of spectral components (eigenpairs) of model 0.
n = len(Evals0)

# Default slider ranges ('range' mode) as positions in the sorted component order.
range0 = [0, 60]
range1 = [60, 2000]

# 'line' mode: piecewise-linear weight curves defined by control points
# (cpx*, cpy*), interpolated to one weight per component position.
finter = scipy.interpolate.interp1d

# NOTE(review): the fixed control positions 1000 and 4000 assume n - 1 > 4000;
# for smaller meshes these points lie beyond the plotted x range.

# Model 0 (content): full weight for the first 10 components, fading to 0 at 100.
cpx0 = [0, 10, 15, 100, 1000, 4000, n - 1]
cpy0 = [1, 1, 0.5, 0, 0, 0, 0]
finter0 = finter(cpx0, cpy0)
px0 = range(n)
py0 = finter0(px0)

# Model 1 (style): no weight for the first 10 components, full weight from
# 100 to 1000, fading out again towards 4000.
cpx1 = [0, 10, 15, 100, 1000, 4000, n - 1]
cpy1 = [0, 0, 0.5, 1, 1, 0, 0]
finter1 = finter(cpx1, cpy1)
px1 = range(n)
# Evaluate model 1's curve on model 1's own positions.
py1 = finter1(px1)


# update function
def _model_contribution(evecs, coeffs, order, use_line, positions, weights, value_range):
    """Return one model's contribution to the mixed vertex coordinates.

    The k-th component in the chosen ordering is ``order[k]``; its
    contribution is the outer product of eigenvector column
    ``evecs[:, order[k]]`` and coefficient row ``coeffs[order[k], :]``.

    use_line=True : sum over k in ``positions``, each term scaled by
                    ``weights[k]``; components with weight exactly 0 are skipped.
    use_line=False: unweighted sum over k in ``range(*value_range)``.
    """
    order = np.asarray(order)
    if use_line:
        pos = np.asarray(positions, dtype=int)
        w = np.asarray(weights)[pos]
        nonzero = w != 0
        sel = order[pos][nonzero]
        # Scaling the selected eigenvector columns by their weights and doing a
        # single matmul equals the sum of the weighted rank-1 terms.
        return np.matmul(evecs[:, sel] * w[nonzero], coeffs[sel, :])
    lo, hi = value_range
    sel = order[lo:hi]
    return np.matmul(evecs[:, sel], coeffs[sel, :])


def update_plot_mix(change):
    """Recompute the mixed shape and push it into the mixing figure.

    Model 0 (content) and model 1 (style) each contribute a selection of
    their spectral components, using the ordering chosen in ``rb_sort``. A
    model only contributes if its 'Active' checkbox is set. In 'line' mode the
    components are weighted by the curve py0/py1; otherwise the slider range
    selects the components. The result is stored in the global ``newCoords``
    and written to the vertices of ``h_tri_mix``.

    Parameters:
        change: widget change notification (or None); not used.

    Raises:
        ValueError: if ``rb_sort`` holds an unknown sorting option.
    """
    global newCoords

    if rb_sort.value == 'Eigenvalues':
        sorti0, sorti1 = sorti0_eval, sorti1_eval
    elif rb_sort.value == 'norm(Coeff)':
        sorti0, sorti1 = sorti0_coeff, sorti1_coeff
    else:
        raise ValueError(f"Unknown sorting option: {rb_sort.value!r}")

    # In the following, the two models provide their content and style features
    # resulting in new vertex coordinates.
    mixed = np.zeros_like(coords0)
    if cb0.value:
        mixed += _model_contribution(Evecs0, coeffs0, sorti0, rb0.value == 'line',
                                     px0, py0, slider0.value)
    if cb1.value:
        mixed += _model_contribution(Evecs1, coeffs1, sorti1, rb1.value == 'line',
                                     px1, py1, slider1.value)
    newCoords = mixed

    # Updates vertex coordinates of the generated mesh
    x, y, z = np.array(newCoords).T
    h_tri_mix.x, h_tri_mix.y, h_tri_mix.z = x, y, z

# CALLBACK functions    
    
def cb_dragcp0(self, target):
    """Drag-end callback of model 0's control-point scatter.

    Cleans up the control points of the scatter, rebuilds model 0's weight
    curve from them and recomputes the mixed mesh.
    """
    cb_dragcp(self, target)       # delete / pin points on the scatter `self`
    update_a0line()               # control points -> finter0, py0, a0line
    update_plot_mix(change=None)  # redraw the mixed shape
    
def cb_dragcp1(self, target):
    """Drag-end callback of model 1's control-point scatter.

    Cleans up the control points of the scatter, rebuilds model 1's weight
    curve from them and recomputes the mixed mesh.
    """
    cb_dragcp(self, target)       # delete / pin points on the scatter `self`
    update_a1line()               # control points -> finter1, py1, a1line
    update_plot_mix(change=None)  # redraw the mixed shape
    
def cb_dragcp(self, target):
    """Sanitize the control points of a bqplot scatter after a drag.

    An inner (non-end) point that was dragged to x <= 0 or x >= n-1 is
    removed. The first and last points are then pinned to x = 0 and
    x = n-1, so the curve always covers all n component positions. The
    scatter's ``x``/``y`` arrays are replaced by the cleaned copies.
    """
    xs = copy.copy(self.x)
    ys = copy.copy(self.y)

    dropped_x = target['point']['x']
    if dropped_x <= 0 or dropped_x >= n - 1:
        idx = target['index']
        # End points are never deleted; they are re-pinned below instead.
        if idx not in (0, len(self.x) - 1):
            print('delete', idx)
            xs = np.delete(xs, idx)
            ys = np.delete(ys, idx)

    xs[0] = 0
    xs[-1] = n - 1
    self.x = xs
    self.y = ys

def update_a0line():
    """Rebuild model 0's weight curve from its current control points.

    Re-creates the interpolator ``finter0`` from ``a0cp.x``/``a0cp.y``,
    evaluates it at every component position ``px0`` into ``py0`` and
    pushes the curve into the ``a0line`` plot.
    """
    global finter0, py0
    finter0 = finter(a0cp.x, a0cp.y)
    py0 = finter0(px0)
    a0line.x = px0
    a0line.y = py0
    
def update_a1line():
    """Rebuild model 1's weight curve from its current control points.

    Re-creates the interpolator ``finter1`` from ``a1cp.x``/``a1cp.y``,
    evaluates it at every component position ``px1`` into ``py1`` and
    pushes the curve into the ``a1line`` plot.
    """
    global finter1, py1
    finter1 = finter(a1cp.x, a1cp.y)
    py1 = finter1(px1)
    a1line.x = px1
    a1line.y = py1
    
def cb_click_a0line(self, target):
    """Click callback of model 0's weight curve: add a control point there.

    The clicked (x, y) is inserted into the control-point scatter ``a0cp``,
    keeping the points sorted by x. Then the curve and the mixed mesh are
    recomputed.
    """
    tempx = np.append(a0cp.x, target['data']['x'])
    tempy = np.append(a0cp.y, target['data']['y'])
    order = np.argsort(tempx)
    # Write the sorted points back to the scatter; update_a0line() reads the
    # control points from a0cp.x / a0cp.y.
    a0cp.x = tempx[order]
    a0cp.y = tempy[order]
    update_a0line()
    update_plot_mix(None)
    
def cb_click_a1line(self, target):
    """Click callback of model 1's weight curve: add a control point there.

    The clicked (x, y) is inserted into the control-point scatter ``a1cp``,
    keeping the points sorted by x. Then the curve and the mixed mesh are
    recomputed.
    """
    tempx = np.append(a1cp.x, target['data']['x'])
    tempy = np.append(a1cp.y, target['data']['y'])
    order = np.argsort(tempx)
    # Write the sorted points back to the scatter; update_a1line() reads the
    # control points from a1cp.x / a1cp.y.
    a1cp.x = tempx[order]
    a1cp.y = tempy[order]
    update_a1line()
    update_plot_mix(None)
    
def update_model0(self):
    """Dropdown callback: load the model chosen in ``select0`` as model 0.

    Replaces all model-0 globals with the results of ``doCalc``, recomputes
    the mixed mesh and moves the vertices of model 0's own plot ``tri0``.
    ``self`` receives the widget's change notification and is not used.
    """
    global tri0, M0, coords0, LL0, Evals0, Evecs0, coeffs0, sorti0_coeff, sorti0_eval
    results = doCalc(model_file[select0.value], log=False)
    (M0, coords0, LL0, Evals0, Evecs0,
     coeffs0, sorti0_coeff, sorti0_eval) = results
    update_plot_mix(None)
    tri0.x, tri0.y, tri0.z = np.array(coords0).T
    
def update_model1(self):
    """Dropdown callback: load the model chosen in ``select1`` as model 1.

    Replaces all model-1 globals with the results of ``doCalc``, recomputes
    the mixed mesh and moves the vertices of model 1's own plot ``tri1``.
    ``self`` receives the widget's change notification and is not used.
    """
    global tri1, M1, coords1, LL1, Evals1, Evecs1, coeffs1, sorti1_coeff, sorti1_eval
    results = doCalc(model_file[select1.value], log=False)
    (M1, coords1, LL1, Evals1, Evecs1,
     coeffs1, sorti1_coeff, sorti1_eval) = results
    update_plot_mix(None)
    tri1.x, tri1.y, tri1.z = np.array(coords1).T

def resetView(self):
    """Button callback: restore the default view of all three 3D figures.

    Each figure is made ipyvolume's current figure in turn (mixing figure
    first, then model 0, then model 1) and ``defview`` is applied to it.
    ``self`` receives the clicked button and is not used.
    """
    for fig in (f_mix, f0, f1):
        ipv.figure(fig)
        defview()
     
# GUI elements   
    
# model0: input figure of the content model. The plot_trisurf call below has
# no figure argument and draws into the figure created just before it.
f0 = ipv.figure(width=400, height=300, lighting=True)
tri0 = ipv.plot_trisurf(*M0.vertices.T, triangles=M0.faces)

# model1: input figure of the style model.
# NOTE(review): lighting=False here but lighting=True for f0 — confirm this is intentional.
f1 = ipv.figure(width=400, height=300, lighting=False)
tri1 = ipv.plot_trisurf(*M1.vertices.T, triangles=M1.faces)

# setup mixing figure: starts from newCoords and reuses M0's faces for the
# mixed vertices (presumably all models share one mesh connectivity — TODO confirm).
# Being created last, f_mix is ipyvolume's current figure afterwards.
x, y, z = np.array(newCoords).T
f_mix = ipv.figure(width=800, height=400)
h_tri_mix = ipv.plot_trisurf(x, y, z, triangles=M0.faces)


# Per-model component range sliders ('range' mode); the values are positions
# in the sorted component order. Updates fire only on release.
slider0 = widgets.IntRangeSlider(max=len(Evals0), min=0,
                                 layout=widgets.Layout(width='98%'),
                                 value=range0, description='Range: ',
                                 continuous_update=False)
slider1 = widgets.IntRangeSlider(max=len(Evals1), min=0,
                                 layout=widgets.Layout(width='98%'),
                                 value=range1, description='Range: ',
                                 continuous_update=False)

# Per-model 'Active' toggle and selection mode ('line' curve vs. 'range' slider).
cb0 = widgets.Checkbox(value=True, description='Active', disabled=False)
rb0 = widgets.RadioButtons(options=['line', 'range'], value='range',
                           description='Selection', disabled=False)
cb1 = widgets.Checkbox(value=True, description='Active', disabled=False)
rb1 = widgets.RadioButtons(options=['line', 'range'], value='range',
                           description='Selection', disabled=False)

# Component ordering used for both models.
rb_sort = widgets.RadioButtons(options=['Eigenvalues', 'norm(Coeff)'],
                               value='Eigenvalues', description='Sorting by:',
                               disabled=False)

# Row captions.
l0 = widgets.Label(value="Model0 (Content)")
l1 = widgets.Label(value="Model1 (Style)")


# selector figure of model 0: weight curve (red dots) plus draggable control
# points. plt.figure() makes fsel0 pyplot's current figure, so the plt.plot /
# plt.scatter calls below add their marks to it.
fsel0 = plt.figure()
fsel0.layout.height = '100px'
fsel0.layout.width = '95%'
fsel0.fig_margin = dict(top=0, bottom=20, left=20, right=0)
# x: component position 0..n, y: weight 0..2
x_sc = bqplot.LinearScale(min=0, max=n)
y_sc = bqplot.LinearScale(min=0, max=2)

# clicking the curve adds a control point (cb_click_a0line)
a0line = plt.plot(px0, py0, '.r', scales={'x': x_sc, 'y': y_sc})
a0line.on_element_click(cb_click_a0line)
# dragging a control point re-interpolates the curve (cb_dragcp0)
a0cp = plt.scatter(cpx0, cpy0, enable_move=True, scales={'x': x_sc, 'y': y_sc})
a0cp.on_drag_end(cb_dragcp0)

tb0 = plt.Toolbar(figure=fsel0)
bsel0 = widgets.VBox([fsel0, tb0])
bsel0.layout.width = "100%"


# selector figure of model 1 (same setup as model 0). x_sc / y_sc are rebound
# to new scale objects here; model 0's marks keep the scales they were given.
fsel1 = plt.figure()
fsel1.layout.height = '100px'
fsel1.layout.width = '95%'
fsel1.fig_margin = dict(top=0, bottom=20, left=20, right=0)
x_sc = bqplot.LinearScale(min=0, max=n)
y_sc = bqplot.LinearScale(min=0, max=2)

a1line = plt.plot(px1, py1, '.r', scales={'x': x_sc, 'y': y_sc})
a1line.on_element_click(cb_click_a1line)
a1cp = plt.scatter(cpx1, cpy1, enable_move=True, scales={'x': x_sc, 'y': y_sc})
a1cp.on_drag_end(cb_dragcp1)

tb1 = plt.Toolbar(figure=fsel1)
bsel1 = widgets.VBox([fsel1, tb1])
bsel1.layout.width = "100%"

# Model file selectors; each option's value is an index into model_file.
select0 = widgets.Dropdown(options=fnames, value=i0,
                           description='Model0 \n(Content)', disabled=False,
                           layout=widgets.Layout(width='50%'))
select1 = widgets.Dropdown(options=fnames, value=i1,
                           description='Model1 \n(Style)', disabled=False,
                           layout=widgets.Layout(width='49%'))
selectbox = widgets.HBox([select0, select1],
                         layout=widgets.Layout(width='98%', height='100%'))

# Button restoring the default view of all three 3D figures.
buttonview = widgets.Button(description='Reset View', disabled=False)
buttonview.on_click(resetView)

# final composition: both input figures side by side, and one control row per
# model (caption/toggles on the left, selector figure and slider on the right).
originalbox = widgets.HBox(
    children=[f0, f1], layout=widgets.Layout(width='100%', height='100%'))

vbox00 = widgets.VBox(children=[l0, cb0, rb0])
vbox01 = widgets.VBox(
    children=[bsel0, slider0], layout=widgets.Layout(width='100%', height='100%'))
box0 = widgets.HBox(
    children=[vbox00, vbox01], layout=widgets.Layout(width='100%', height='100%'))

vbox10 = widgets.VBox(children=[l1, cb1, rb1])
vbox11 = widgets.VBox(
    children=[bsel1, slider1], layout=widgets.Layout(width='100%', height='100%'))
box1 = widgets.HBox(
    children=[vbox10, vbox11], layout=widgets.Layout(width='100%', height='100%'))

# Fill the placeholder container displayed earlier.
toolarea.children = (rb_sort, box0, box1, buttonview, f_mix, originalbox, selectbox)

# register callbacks: any change of a mixing control recomputes the mixed shape.
for _widget in (slider0, slider1, cb0, cb1, rb0, rb1, rb_sort):
    _widget.observe(update_plot_mix, names='value')
del _widget
# Choosing another file reloads the corresponding model.
select0.observe(update_model0, names='value')
select1.observe(update_model1, names='value')

# sync 3D plots: f0's camera position and rotation are linked (client side)
# to f1 and to the mixing figure.
for _target in (f1, f_mix):
    for _trait in ("position", "rotation"):
        widgets.jslink((f0.camera, _trait), (_target.camera, _trait))
del _target, _trait

# Initial drawing. The first defview() applies to ipyvolume's current figure,
# which is still f_mix (the last ipyvolume figure created in this cell).
update_plot_mix(None)
defview()
for _fig in (f0, f1):
    ipv.figure(_fig)
    defview()
del _fig

# Extra
The following cells provide some helpful extra functions; their code is commented out and can be enabled as needed.

## Plot all Models

In [None]:
# for p in model_file:
#     M = pymesh.load_mesh(p)
#     if M.faces.shape[1] == 4:
#         M = pymesh.quad_to_tri(M)
#     xp, yp, zp = M.vertices.T
#     f = ipv.figure(width=800, height=600, )
#     ipv.plot_trisurf(xp, yp, zp, triangles=M.faces)
#     ipv.show()
#     defview()

## Do a simple average mixing

In [None]:
# k0 = pymesh.load_mesh(model_file[0])
# k1 = pymesh.load_mesh(model_file[1])

# k0 = pymesh.quad_to_tri(k0)
# k1 = pymesh.quad_to_tri(k1)

# xp, yp, zp = (k0.vertices.T + k1.vertices.T) / 2

# f = ipv.figure(width=800, height=600, )
# ipv.plot_trisurf(xp, yp, zp, triangles=k0.faces)
# ipv.show()
# defview()

## Save Models as *.stl files

In [None]:
# pymesh.save_mesh('model0.stl', M0)

In [None]:
# pymesh.save_mesh('model1.stl', M1)

In [None]:
# Mresult = pymesh.form_mesh(newCoords, M0.faces)
# pymesh.save_mesh('result.stl', Mresult)

In [None]:
# Maverage = pymesh.form_mesh(np.stack([xp, yp, zp], axis=1), M0.faces)
# pymesh.save_mesh('average.stl', Maverage)