In [1]:
import torch
import functools
import numpy as np
import mcubes as libmcubes
from sklearn.cluster import MiniBatchKMeans
from train_gan import Generator, Discriminator, engineer_feature_vec, chairInfo, pca_out, get_part_clusters_for_chairs_iterable, p_percentile_part_count_on_chair, DATA_DIR, centroid

Checkpoint loaded from chkt_dir/partae/latest.pth


In [48]:
from dataload.data_utils import loadH5Full
import io
from PIL import Image
from matplotlib import pyplot as plt
import trimesh

colors = [[0, 0, 255, 255],  # blue
          [0, 255, 0, 255],  # green
          [255, 0, 0, 255],  # red
          [255, 255, 0, 255],  # yellow
          [0, 255, 255, 255],  # cyan
          [255, 0, 255, 255],  # Magenta
          [160, 32, 240, 255],  # purple
          [255, 255, 240, 255]]  # ivory
vox_rez = 64
latent_dim=128
quadratic_feat=True
n_clusters = 8
kmfit_ = MiniBatchKMeans(n_clusters=n_clusters,
                         random_state=0).fit(pca_out)
vec_clusters = kmfit_.labels_

maxCountPerCluster = p_percentile_part_count_on_chair(kmfit_, get_part_clusters_for_chairs_iterable(kmfit_), 95).astype(int)
clusterEndIndices = maxCountPerCluster.cumsum()
clusterStartIndices = clusterEndIndices - maxCountPerCluster
count_dim = clusterEndIndices[-1]
affine_dim = count_dim*4

part_filenames = functools.reduce(lambda l1, l2: l1+l2, map(lambda x:x['filenames'], chairInfo))
part_fileorder = np.concatenate(list(map(lambda x:np.arange(0, len(x['filenames'])), chairInfo)))

total_num_parts = pca_out.shape[0]
pca_dims = pca_out.shape[1]
cluster_vec_bools = list(map(lambda cluster:vec_clusters==cluster, range(n_clusters)))
cluster_vec_indices = list(map(lambda bools:np.where(bools)[0], cluster_vec_bools))
centroids = np.array(list(map(lambda cluster: np.mean(pca_out[cluster_vec_bools[cluster]], axis=0), range(n_clusters))))
distances = np.linalg.norm(pca_out - np.repeat(centroids.reshape(-1, n_clusters, pca_dims), total_num_parts, axis=0)[np.arange(total_num_parts), vec_clusters, :], axis=1)
distances_by_cluster = list(map(lambda cluster:distances[cluster_vec_bools[cluster]], range(n_clusters)))

def sample_from_cluster(cluster_id, strict):
    if strict: # choose the closest part index to the cluster
        indices = cluster_vec_indices[cluster_id]
        return indices[np.argmin(distances_by_cluster[cluster_id])]
    return None

def sample_dict_from_g(generator, n):
    with torch.no_grad():
        z = torch.cuda.FloatTensor(np.random.normal(0, 1, (n, latent_dim)))
        fake_batch = generator(z)
    return fake_batch

def score_sample_with_d(discriminator, sample_dict):
    with torch.no_grad():
        feature_vec = engineer_feature_vec(sample_dict['count'], sample_dict['affine'], quadratic_feat)
    return discriminator(feature_vec)

def visualize_sample(generator, discriminator):
    n = 4
    #fig, ax = plt.subplots(1,n, figsize=(16, 4))
    sample_dict = sample_dict_from_g(generator, n)
    sample_scores = score_sample_with_d(discriminator, sample_dict)
    
    for (chair_i, (count, affine, score)) in enumerate(zip(sample_dict['count'], sample_dict['affine'], sample_scores)):
        count = count.detach().cpu().numpy()
        affine = affine.detach().cpu().numpy()
        count = count > 0.5
        part_indices = []
        scales = []
        translations = []
        categories = []
        for (part_i, pred) in enumerate(count):
            if pred:
                category = np.argmax(clusterEndIndices > part_i)
                categories.append(category)
                part_indices.append(sample_from_cluster(category, strict=True))
                scales.append(affine[part_i*4])
                translations.append(affine[part_i*4+1:part_i*4+4] + centroid)
        shape_meshes = []
        scene = trimesh.Scene()
        for (i, scale, tran, latent_category) in zip(part_indices, scales, translations, categories):
            part_file = part_filenames[i]
            path = os.path.join(DATA_DIR, part_file)
            _, voxels, _, _, _, _, _ = loadH5Full(path, resolution=vox_rez)
            voxels = voxels[part_fileorder[i]]
            vertices, triangles = libmcubes.marching_cubes(voxels, 0)
            inCol = colors[latent_category % len(colors)]
            mesh = trimesh.Trimesh(vertices, triangles, face_colors=inCol)
            mesh.apply_translation((-32, -32, -32))
            mesh.apply_scale(scale)
            mesh.apply_translation(tran)
            scene.add_geometry(mesh)
            shape_meshes.append(mesh)
        print(count)
        print(part_indices)
        print(shape_meshes)
        shape_meshes = trimesh.util.concatenate(shape_meshes)
        shape_meshes.export(os.path.join('data', str(chair_i)+".obj"), file_type='obj')
        #png = scene.save_image(resolution=[600, 400],visible=True)
        #mat = np.array(Image.open(io.BytesIO(data)))
        #ax.imshow(mat)
        #ax[chair_i].title.set_text(str(score))

In [50]:
import os
g = Generator(latent_dim, count_dim, affine_dim).cuda()
d = Discriminator(count_dim, affine_dim, quadratic_feat).cuda()
directory = 'km8_q_6000'
g.load_state_dict(torch.load(os.path.join('data', 'weights', directory, 'generator','latest.pth')))
d.load_state_dict(torch.load(os.path.join('data', 'weights', directory, 'discriminator','latest.pth')))

<All keys matched successfully>

In [51]:
visualize_sample(g,d)

[False False False False False False False False False False False False
 False False False False False False]
[]
[]


IndexError: index 0 is out of bounds for axis 0 with size 0