# Demo of 3D Generative Model Latent Disentanglement via Local Eigenprojection

To run the demo, download the `demo_files` folder and copy it into the project directory, then make sure `demo_directory` in the first code cell points to it.

---

### Network initialisation


In [29]:
%matplotlib notebook

import os
import trimesh
import torch
import numpy as np

import utils
from model_manager import ModelManager

# Experiment folder (expected to contain config.yaml and checkpoints/) and the vertex
# normalisation statistics used to map generated vertices back to mesh coordinates.
# To use the downloaded demo files instead, set demo_directory = "demo_files".
demo_directory = "./outputs/experiments/302/"
norm_dict_path = "./precomputed/norm_babies_faces.pt"
configurations = utils.get_config(os.path.join(demo_directory, "config.yaml"))

if configurations['model']['age_disentanglement']:

    if configurations['model']['age_per_feature']:
        # With one age latent per feature, the feature latent size must be exactly
        # 5 x age_latent_size.
        no_remainder = configurations['model']['latent_size'] % configurations['model']['age_latent_size'] == 0
        correct_value = configurations['model']['latent_size'] // configurations['model']['age_latent_size'] == 5
        assert no_remainder and correct_value, (
            "age_per_feature requires latent_size == 5 * age_latent_size, got "
            f"latent_size={configurations['model']['latent_size']}, "
            f"age_latent_size={configurations['model']['age_latent_size']}")

    # The full latent vector is [feature latents | age latents].
    configurations['model']['latent_size'] += configurations['model']['age_latent_size']


if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
    print("GPU not available, running on CPU")

manager = ModelManager(
    configurations=configurations, device=device,
    precomputed_storage_path=configurations['data']['precomputed_path'])
manager.resume(os.path.join(demo_directory, "checkpoints"))

# map_location keeps the CPU fallback above working even if the statistics were
# saved from CUDA tensors (torch.load would otherwise try to restore them on the GPU).
normalization_dict = torch.load(norm_dict_path, map_location=device)

Resume from epoch 600


### Randomly generate a shape and create a scene

Every time you run this cell, a new shape is generated. Run it until you are satisfied with the generated shape.

In [30]:
from pythreejs import *
from IPython.display import display
import ipywidgets

view_width, view_height = 800, 800

# Latent layout: [feature latents | age latents]; the age block is the last
# `age_latent_size` entries of z.
age_latent_size = configurations['model']['age_latent_size']
feature_latent_size = configurations['model']['latent_size'] - age_latent_size

# Feature latents are independent N(0, 1) draws ...
z_features = torch.randn([1, feature_latent_size])
# ... while a single N(0, 1) draw is repeated across every age latent.
single_random_value = torch.randn(1)
z_age = torch.full((1, age_latent_size), single_random_value.item())
z = torch.cat([z_features, z_age], dim=1)

# Decode the single sample and undo the per-vertex normalisation.
gen_verts = (manager.generate(z.to(device))[0, :, :]
             * normalization_dict['std'].to(device)
             + normalization_dict['mean'].to(device))
# Reference shape for the displacement colour map.
initially_gen_verts = gen_verts.clone()
faces = manager.template.face

def compute_vertex_colours(current_verts):
    """Colour each vertex by how far it moved from ``initially_gen_verts``.

    Per-vertex errors from ``manager.compute_vertex_errors`` are mapped through the
    'plasma' colormap over the fixed range [0, 10], then divided by 255
    (presumably converting 0-255 colours to 0-1 for three.js).

    Returns:
        A CPU NumPy array, detached from autograd, with singleton dimensions squeezed.
    """
    displacement = manager.compute_vertex_errors(current_verts, initially_gen_verts)
    batched = displacement.unsqueeze(dim=0)
    rgb = utils.errors_to_colors(batched, min_value=0, max_value=10, cmap='plasma')
    scaled = rgb / 255
    return scaled.squeeze().cpu().detach().numpy()


# pythreejs buffers. A single BufferGeometry is shared by the shaded mesh and the
# displacement-map mesh, so updating its attributes refreshes both views.
buffer_verts = BufferAttribute(gen_verts.detach().cpu().numpy().tolist(), normalized=False)
# Flattened triangle indices. Tensor.numpy() rejects CUDA tensors, hence .cpu();
# astype avoids the detour through a nested Python list.
buffer_faces = BufferAttribute(faces.t().cpu().numpy().astype(np.uint32).ravel(), normalized=False)
buffer_colours = BufferAttribute(compute_vertex_colours(gen_verts).tolist(), normalized=False)

geometry = BufferGeometry(
    attributes={
        'position': buffer_verts,
        'index': buffer_faces,
        'color': buffer_colours
    })
geometry.exec_three_obj_method('computeVertexNormals')

# Uniform green material for the main view; vertex-coloured material for the displacement map.
material = MeshPhongMaterial(color="#34eb46", specular="#222222", shininess=15)
material_colours = MeshPhongMaterial(specular="#222222", shininess=15, vertexColors='VertexColors')

mesh = Mesh(geometry, material=material)
mesh_colours = Mesh(geometry, material=material_colours)

camera = PerspectiveCamera(position=[2, 0, 3], aspect=view_width/view_height)
ambient_light = AmbientLight(intensity=0.2)
# Brighter ambient light for the displacement-map scene, which has no spotlight.
ambient_light_dispmap = AmbientLight(intensity=1)
key_light = SpotLight(position=[0, 10, 10], angle = 0.3, penumbra = 0.1)

key_light.target = mesh
# NOTE(review): in three.js these flags only matter if the renderer's shadow map and
# the light's castShadow are enabled, which is not done here.
mesh.castShadow = True
mesh.receiveShadow = True
mesh_colours.castShadow = True
mesh_colours.receiveShadow = True

scene = Scene(children=[mesh, camera, key_light, ambient_light])
scene_colours = Scene(children=[mesh_colours, camera, ambient_light_dispmap])

# Both renderers share the camera and OrbitControls, so rotating one view rotates the other.
controller = OrbitControls(controlling=camera)
renderer = Renderer(camera=camera, scene=scene, controls=[controller],
                    width=view_width, height=view_height, antialias=True)
renderer_colours = Renderer(camera=camera, scene=scene_colours, controls=[controller],
                            width=view_width/3, height=view_height/3, antialias=True)
renderers_pair = ipywidgets.HBox([renderer_colours, renderer])
renderers_pair.layout.align_items = "center"
display(renderers_pair)

HBox(children=(Renderer(camera=PerspectiveCamera(position=(2.0, 0.0, 3.0), projectionMatrix=(1.0, 0.0, 0.0, 0.…

### Create GUI

In [None]:
import functools
import pickle


@functools.lru_cache(maxsize=None)
def load_age_normalisation(data_type):
    """Load the age-normalisation statistics for ``data_type``, caching the result.

    Reads ``precomputed/normalise_age_<data_type>.pkl`` once per ``data_type`` instead of
    on every slider event. Later edits to the file are not picked up without a kernel restart.
    NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.

    Returns:
        Whatever the pickle contains; callers unpack it as ``(mean, std)``.
    """
    storage_path = os.path.join('precomputed', f'normalise_age_{data_type}.pkl')
    with open(storage_path, 'rb') as file:
        return pickle.load(file)


def update_vertices(z_i_value, z_i_index):
    """Write one slider value into the global latent ``z`` and refresh the rendered meshes.

    Args:
        z_i_value: the ipywidgets change notification; only its ``.new`` attribute is read.
        z_i_index: index into ``z[0]`` of the latent driven by the slider.

    Age latents (indices >= latent_size - age_latent_size) arrive as unnormalised
    slider values and are normalised with the training mean/std before being stored.
    """
    age_latent_index = configurations['model']['latent_size'] - configurations['model']['age_latent_size']

    if z_i_index >= age_latent_index:
        data_type = configurations['data']['dataset_type'].split("_", 1)[1]
        age_train_mean, age_train_std = load_age_normalisation(data_type)
        new_value = (z_i_value.new - age_train_mean) / age_train_std
    else:
        new_value = z_i_value.new

    z[0, z_i_index] = new_value
    verts = manager.generate(z.to(device))[0, :, :]
    verts = verts * normalization_dict['std'].to(device) + \
        normalization_dict['mean'].to(device)
    colours = compute_vertex_colours(verts)
    verts = verts.detach().cpu().numpy()
    v = verts.astype("float32", copy=False)
    # Push the new positions/colours into the shared BufferGeometry and flag them for upload.
    geometry.attributes["position"].array = v
    geometry.attributes["position"].needsUpdate = True
    mesh.geometry.exec_three_obj_method('computeVertexNormals')
    geometry.attributes["color"].array = colours.astype("float32", copy=False)
    geometry.attributes["color"].needsUpdate = True
    mesh.geometry.verticesNeedUpdate = True
    mesh.geometry.elementsNeedUpdate = True
    mesh.geometry.colorsNeedUpdate = True
    mesh.exec_three_obj_method('update')
    mesh_colours.geometry.verticesNeedUpdate = True
    mesh_colours.geometry.elementsNeedUpdate = True
    mesh_colours.geometry.colorsNeedUpdate = True
    mesh_colours.exec_three_obj_method('update')
    controller.exec_three_obj_method('update')
    camera.exec_three_obj_method('updateProjectionMatrix')
    scene.exec_three_obj_method('update')
    scene_colours.exec_three_obj_method('update')
    
# Region key (apparently the str() of a NumPy RGBA colour array, as used by
# manager.latent_regions) -> human-readable anatomical region name.
color2name_dict = {
    '[160 241 251 255]': "temporal",
    '[ 57 130 135 255]': "eyes",
    '[251 155   0 255]': "cheekbones",
    '[251 231 144 255]': "cheeks",
    '[168  95 251 255]': "jaw",
    '[135 251 185 255]': "forehead",
    '[251 174 204 255]': "chin",
    '[251  57 158 255]': "mouth",
    '[251 130  96 255]': "nose",
}

region_sliders_names = []
region_sliders = []
grouped_region_indices = []

# NOTE: `r_name` / `r_range` intentionally keep their names; they remain defined after
# the loop and the following cell reads `r_range`.
for r_name, r_range in manager.latent_regions.items():
    latent_ids = list(range(r_range[0], r_range[1]))
    grouped_region_indices.append(latent_ids)

    column = []
    for latent_idx in latent_ids:
        slider = ipywidgets.FloatSlider(
            min=-3, max=3, step=0.05,
            value=z[0, latent_idx],
            description=f"z_{latent_idx}")
        # Default argument freezes the index so each slider drives its own latent.
        slider.observe(lambda change, idx=latent_idx: update_vertices(change, idx),
                       names='value')
        column.append(slider)

    region_sliders.append(ipywidgets.VBox(column))
    region_sliders_names.append(color2name_dict[r_name])

{'name': 'value', 'old': 0.26040446758270264, 'new': 0.25, 'owner': FloatSlider(value=0.25, description='z_0', max=3.0, min=-3.0, step=0.05), 'type': 'change'}
0
0.25
{'name': 'value', 'old': 1.602525234222412, 'new': 1.5999999999999996, 'owner': FloatSlider(value=1.5999999999999996, description='z_1', max=3.0, min=-3.0, step=0.05), 'type': 'change'}
1
1.5999999999999996
{'name': 'value', 'old': -1.2663394212722778, 'new': -1.25, 'owner': FloatSlider(value=-1.25, description='z_2', max=3.0, min=-3.0, step=0.05), 'type': 'change'}
2
-1.25
{'name': 'value', 'old': -0.7691777944564819, 'new': -0.75, 'owner': FloatSlider(value=-0.75, description='z_3', max=3.0, min=-3.0, step=0.05), 'type': 'change'}
3
-0.75
{'name': 'value', 'old': 1.2535516023635864, 'new': 1.2500000000000009, 'owner': FloatSlider(value=1.2500000000000009, description='z_4', max=3.0, min=-3.0, step=0.05), 'type': 'change'}
4
1.2500000000000009
{'name': 'value', 'old': -0.5808898210525513, 'new': -0.6000000000000001, 'own

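This cell creates a separate section of sliders in the GUI for the age latents, one per facial feature. The sliders show unnormalised age values in a fixed range (0 to 17); the training mean and standard deviation of the ages are used to convert between slider values and the normalised latent values.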

In [None]:
import pickle

# Unnormalised age range covered by the age sliders.
# NOTE(review): the unit of these values is not visible here -- confirm against the dataset.
AGE_SLIDER_MIN = 0
AGE_SLIDER_MAX = 17

# z = [feature latents | age latents]. Take the split point from the configuration, exactly
# as update_vertices does, rather than from the `r_range` loop variable leaked by the
# previous cell (hidden state that also depended on the order of manager.latent_regions).
total_latents = z.shape[1]
extra_latents_start = configurations['model']['latent_size'] - configurations['model']['age_latent_size']
extra_latents_indices = list(range(extra_latents_start, total_latents))

# Region names used to label one age slider per feature region.
features = ["temporal", "eyes", "cheekbones", "cheeks", "jaw", "forehead", "chin", "mouth", "nose"]

# Training mean/std of the ages, used to display the current age latents unnormalised.
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted precomputed files.
data_type = configurations['data']['dataset_type'].split("_", 1)[1]
storage_path = os.path.join('precomputed', f'normalise_age_{data_type}.pkl')
with open(storage_path, 'rb') as file:
    age_train_mean, age_train_std = pickle.load(file)

unnormalized_value = (z[0, extra_latents_indices] * age_train_std) + age_train_mean

# Label by region only when there is exactly one age latent per region; otherwise fall
# back to the latent index instead of mislabelling or raising IndexError.
one_age_latent_per_feature = len(extra_latents_indices) == len(features)

# Keep the cell idempotent: drop an "Age Features" group added by a previous run.
if region_sliders_names and region_sliders_names[-1] == "Age Features":
    region_sliders.pop()
    region_sliders_names.pop()

extra_sliders = []
for i, latent_index in enumerate(extra_latents_indices):
    label = features[i] if one_age_latent_per_feature else f"age z_{latent_index}"
    current_s = ipywidgets.FloatSlider(min=AGE_SLIDER_MIN, max=AGE_SLIDER_MAX, step=1,
                                       value=unnormalized_value[i].item(), description=label)
    # update_vertices normalises the slider value back into latent space.
    current_s.observe(lambda x, y=latent_index: update_vertices(x, y), names='value')
    extra_sliders.append(current_s)

region_sliders.append(ipywidgets.VBox(extra_sliders))
region_sliders_names.append("Age Features")

{'name': 'value', 'old': 4.599389553070068, 'new': 5.000000000000001, 'owner': FloatSlider(value=5.000000000000001, description='temporal', max=17.0, step=1.0), 'type': 'change'}
45
-0.3649519843253325
{'name': 'value', 'old': 4.599389553070068, 'new': 5.000000000000001, 'owner': FloatSlider(value=5.000000000000001, description='eyes', max=17.0, step=1.0), 'type': 'change'}
46
-0.3649519843253325
{'name': 'value', 'old': 4.599389553070068, 'new': 5.000000000000001, 'owner': FloatSlider(value=5.000000000000001, description='cheekbones', max=17.0, step=1.0), 'type': 'change'}
47
-0.3649519843253325
{'name': 'value', 'old': 4.599389553070068, 'new': 5.000000000000001, 'owner': FloatSlider(value=5.000000000000001, description='cheeks', max=17.0, step=1.0), 'type': 'change'}
48
-0.3649519843253325
{'name': 'value', 'old': 4.599389553070068, 'new': 5.000000000000001, 'owner': FloatSlider(value=5.000000000000001, description='jaw', max=17.0, step=1.0), 'type': 'change'}
49
-0.3649519843253325

In [33]:
# Exported meshes are written inside the experiment folder.
out_mesh_dir = os.path.join(demo_directory, "out_meshes")
# Idempotent, and also creates missing parent directories (unlike isdir + mkdir).
os.makedirs(out_mesh_dir, exist_ok=True)

fname_widget = ipywidgets.Text(placeholder="mesh_name.ply", description="Filename:", disabled=False)
# Collects the "saved" / "invalid name" messages printed by save_current_mesh.
out_message_widget = ipywidgets.Output()


# compute_vertex_colours is defined once, in the cell that generates the initial shape.
# The identical copy that used to live here was removed: duplicate definitions silently
# shadow each other and drift apart when only one of them is edited.


# File extensions accepted by save_current_mesh (trimesh picks the export format from the extension).
SUPPORTED_MESH_EXTENSIONS = (".ply", ".obj")


def has_mesh_extension(fname):
    """Return True if ``fname`` has a non-empty stem and ends in a supported mesh extension."""
    return any(len(fname) > len(ext) and fname.endswith(ext) for ext in SUPPORTED_MESH_EXTENSIONS)


def save_current_mesh(b):
    """Export the mesh decoded from the current latent ``z`` to ``out_mesh_dir``.

    Button callback: ``b`` (the clicked button) is not used. The filename is read from
    ``fname_widget`` and must end in '.ply' or '.obj' (the previous check tested '.ply'
    twice, so '.obj' was rejected despite the message). The outcome is printed into
    ``out_message_widget``. Vertex colours come from ``compute_vertex_colours``.
    """
    fname = fname_widget.value.strip()
    # Validate first so an invalid name does not cost a forward pass.
    if not has_mesh_extension(fname):
        with out_message_widget:
            print(f"'{fname}' is not a valid meshfile name. Make sure it finishes in '.ply' or '.obj'")
        return

    verts = manager.generate(z.to(device))[0, ::]
    verts = verts * normalization_dict['std'].to(device) + \
        normalization_dict['mean'].to(device)
    v_col = compute_vertex_colours(verts)
    # Distinct name so it is not confused with the global pythreejs `mesh`.
    out_mesh = trimesh.Trimesh(
        verts.cpu().detach().numpy(),
        manager.template.face.t().cpu().numpy(),
        vertex_colors=v_col)

    out_mesh.export(os.path.join(out_mesh_dir, fname))
    with out_message_widget:
        print("Mesh saved!")

    
# Filename box, Save button and status output laid out on one row.
save_button_widget = ipywidgets.Button(
    description="Save",
    disabled=False,
    button_style='info',
)
save_button_widget.on_click(save_current_mesh)
saving_widgets = ipywidgets.HBox(
    children=[fname_widget, save_button_widget, out_message_widget])

Add a 'Random mesh' button that randomly generates a new face mesh.

In [None]:
def reset_mesh_and_sliders(b=None):
    """Sample a new random latent vector and push it into the sliders and the mesh.

    Button callback: ``b`` (the clicked button) is ignored.

    The notebook-level ``z`` is overwritten *in place*, because ``update_vertices`` and
    ``save_current_mesh`` read that tensor. The previous version rebound a local ``z``,
    leaving the global untouched. It also passed raw tensors to ``update_vertices``, which
    expects a change object with ``.new``. Here each slider's ``value`` is set instead.
    That fires the observer registered when the slider was built, which calls
    ``update_vertices`` for that latent and re-renders the mesh.

    NOTE(review): the displacement map still measures distance from the *first*
    generated shape (``initially_gen_verts``); decide whether a new random mesh
    should become the new reference.
    """
    age_size = configurations['model']['age_latent_size']
    feature_size = z.shape[1] - age_size
    # Same sampling scheme as the initial shape: i.i.d. N(0, 1) feature latents and a
    # single shared N(0, 1) draw for every age latent.
    new_features = torch.randn([1, feature_size])
    new_age = torch.full([1, age_size], torch.randn(1).item())
    z.copy_(torch.cat((new_features, new_age), dim=1))

    # Feature sliders hold raw latent values (values outside the slider range get clamped
    # by the widget, and update_vertices then stores the clamped value).
    for group, indices in zip(region_sliders, grouped_region_indices):
        for slider, latent_index in zip(group.children, indices):
            slider.value = z[0, latent_index].item()

    # Age sliders show unnormalised values; update_vertices re-normalises them.
    for slider, latent_index in zip(extra_sliders, extra_latents_indices):
        slider.value = (z[0, latent_index] * age_train_std + age_train_mean).item()

# "Random mesh" button: draws a new latent vector via reset_mesh_and_sliders.
random_button = ipywidgets.Button(button_style='success', description="Random mesh")
random_button.on_click(reset_mesh_and_sliders)

# Wrapped in an HBox so it sits in the UI as its own row.
random_widget = ipywidgets.HBox(children=[random_button])

### Run the GUI 

Each slider corresponds to a latent variable. When a slider is changed, the VAE generates a new shape that is displayed in real time. Sliders are grouped according to the anatomical region that they influence. On the left side, a displacement map shows the regions that were altered while manipulating the latent variables with the sliders. Displacements are computed from the random mesh that was initially generated. When you are satisfied with the final result, you can also save the mesh.

In [34]:
# One collapsible panel per slider group (anatomical regions, then age features).
accordion = ipywidgets.Accordion(
    children=region_sliders,
    titles=region_sliders_names,
)
# Layout: [renderers | slider panels] on top, then the save row and the random-mesh row.
# The VBox is the cell's last expression, so Jupyter displays it.
ipywidgets.VBox(children=[
    ipywidgets.HBox(children=[renderers_pair, accordion]),
    saving_widgets,
    random_widget,
])

VBox(children=(HBox(children=(HBox(children=(Renderer(camera=PerspectiveCamera(position=(2.0, 0.0, 3.0), proje…