In [44]:
%load_ext autoreload
%autoreload 2
print('runing')


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
runing


In [45]:
from partglot.utils.predict import get_loaded_model
import numpy as np
from src.helper.visualization import visualize_pointclouds_parts_partglot
from src.helper.visualization import get_rnd_color

import pymeshlab as pm
import torch
from partglot.utils.neural_utils import tokenizing
from sklearn.cluster import KMeans

part_names = ["back", "seat", "leg", "arm"]
sseg_cmap = [get_rnd_color() for i in range(1000)]

part_semantic_groups = {
    "back": ["back"],
    "seat": ["seat"],
    "leg": ["leg", "wheel", "base"],
    "arm": ["arm"],
}

def sort_arrays(arrays):
    ref_array = arrays[0]
    sorted_indices = ref_array.argsort()
    out = []
    for a in arrays:
        out.append(a[sorted_indices])
    return out

def random_sample_array(arr: np.array, size: int = 1, with_replacement:bool=True) -> np.array:
    if with_replacement:
        while len(arr) < size:
            arr = np.concatenate([arr, arr])
    return arr[np.random.choice(len(arr), size=size, replace=False)]

def cluster_supsegs(sorted_labels, sorted_pc, sup_seg_size=512):
    sup_segs, labels = [], []
    for lbl in np.unique(sorted_labels):
        indices = np.where(sorted_labels==lbl)[0]
        tmp_pc = sorted_pc[indices]
        tmp_lbl = sorted_labels[indices]
        sup_segs.append(random_sample_array(tmp_pc, sup_seg_size))
        # labels.append(random_sample_array(tmp_lbl, sup_seg_size))
        labels.append(lbl)
    return np.array(sup_segs), np.array(labels)


def get_attn_mask_objects(pc, pc2label):
    """
    Returns ordered point cloud and mask indices in our format.
    """
    stacked_pc = np.vstack(np.vstack(pc))
    
    arg_sort = pc2label.argsort()
    
    out_pc2label, out_pg_pc = pc2label[arg_sort], np.vstack(stacked_pc)[arg_sort]

    mask = {}
    for i, pn in enumerate(part_names):
        tmp = np.where(out_pc2label == i)[0]
        print(tmp.shape)
        if tmp.shape[0] == 0:
            continue
        mask[pn] = [tmp.min(), tmp.max()]
    
    return {"mask_vertices": mask}, out_pg_pc

def vstack2dim(data, dim=2):
    if len(data.shape) <= dim:
        return data
    else:
        data = np.vstack(data)
        return vstack2dim(data=data, dim=dim)


In [122]:
sample_idx = 1
use_bsp_ssegs_gt = False
n_ssseg_custom = 25
opacity = 0.25

np.random.seed(0)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
label_cmap = [0xff0000, 0x00ff00, 0x0000ff, 0xff00ff, 0xffff00, 0x00ffff]


model_dir = "/home/bellatini/DL3D-Practical/models/pn_agnostic.ckpt"
# data_dir = "/home/bellatini/DL3D-Practical/data/partglot"
data_dir = "/home/bellatini/DL3D-Practical/data/partglot_100"

# LOAD MODEL AND REFERENCE DATASET
partglot, partglot_dm = get_loaded_model(data_dir=data_dir, model_path=model_dir)
batch_data = torch.from_numpy(partglot_dm.h5_data['data'][sample_idx:sample_idx+1]).unsqueeze(dim=1).float().to(device)
mask_data = torch.from_numpy(partglot_dm.h5_data['mask'][sample_idx:sample_idx+1]).unsqueeze(dim=1).float().to(device)

# sup_segs2label, pc2label = partglot.get_attn_maps()[sample_idx]
# segs, masks = partglot_dm.h5_data['data'][sample_idx], partglot_dm.h5_data['mask'][sample_idx]

partglot.to(device)

# CLUSTER POINT CLOUD INTO SSEGS (KMEANS)
pc = vstack2dim(batch_data.cpu().numpy())
kmeans = KMeans(n_clusters=n_ssseg_custom, random_state=1).fit(pc)
pc2sup_segs_kmeans = kmeans.labels_
# sorted_labels, sorted_pc = sort_arrays((pc2sup_segs_kmeans, pc))
# sorted_ssegs, pc2sup_segs = cluster_supsegs(sorted_labels, sorted_pc)
sup_segs, pc2sup_segs = cluster_supsegs(pc2sup_segs_kmeans, pc)

# SET VARIABLES FOR PREDICTION (REF POINT CLOUDS OR CUSTOM ONES)
if use_bsp_ssegs_gt:
    final_ssegs_batch = batch_data
    final_mask_batch = mask_data 
    sup_segs = batch_data[0][0].cpu().numpy() # gt_super_segs / super_segs
else:
    final_ssegs_batch = torch.from_numpy(np.array([[sup_segs]])).float().to(device) 
    final_mask_batch = torch.from_numpy(np.array([[np.ones(final_ssegs_batch.shape[2])]])).float().to(device) 
    
# GET ATTN MAPS PER SSEG (sseg2label)
attn_maps = []
for pn in part_names:
    text_embeddings = tokenizing(partglot_dm.word2int, f"chair with a {pn}").to(device)[None].expand(
        1, -1
    )
    tmp = partglot.forward(
        final_ssegs_batch, # custom_ssegs_batch / batch_data
        final_mask_batch, # custom_mask_batch / mask_data
        text_embeddings, True)
    attn_maps.append(tmp)
    
attn_maps_concat = torch.cat(attn_maps).max(0)[1].cpu().numpy()

sup_segs2label = np.squeeze(attn_maps_concat)
sup_segs2label

# EXPAND ATTN MAPS TO POINT-LEVEL GRANULARITY (pc2label)
pc2label=[] # pc2sup_segs: is actually pc2label
for lbl in sup_segs2label:
    tmp = np.ones(512) * lbl
    pc2label.append(tmp)
    
pc2label = np.concatenate(pc2label).astype(int)

assign_ft = lambda x: sup_segs2label[x]

# pc2label_sorted = assign_ft(pc2sup_segs.astype(int))
# import numpy as np
# pc2label_prefinal = []
# for lbl in sup_segs2label:
#     tmp = np.ones(512)*lbl
#     pc2label_prefinal.append(tmp)
# pc2label_prefinal = np.concatenate(pc2label_prefinal)

# VISUALIZE PART SEGMENTATION LABELS
pc_final = vstack2dim(final_ssegs_batch.cpu().numpy())
# pc_final = vstack2dim(vstack2dim(sorted_ssegs))

# final_mask, final_pc = get_attn_mask_objects(super_segs, pc2label)
final_mask, final_pc = get_attn_mask_objects(pc_final, pc2label) # pc2sup_segs: is actually pc2label

mask = (pc_final != np.array([0,0,0])).max(axis=1)
non_zero_pc = pc_final[mask]

unique_point_perc = np.unique(pc_final, axis=0).shape[0] / vstack2dim(pc_final).shape[0]
unique_point_perc_non_zero = np.unique(non_zero_pc, axis=0).shape[0] / vstack2dim(non_zero_pc).shape[0]
print("use_bsp_ssegs_gt:", use_bsp_ssegs_gt)
print(f"unique point percentage: {unique_point_perc:.1%}")
print(f"unique point percentage (non-zeros): {unique_point_perc_non_zero:.1%}")

out = []
for s,f in final_mask['mask_vertices'].values():
    tmp = final_pc[s:f].astype(float)
    out.append(tmp)
    
visualize_pointclouds_parts_partglot(np.array(out), names=list(final_mask['mask_vertices'].keys()), part_colors=label_cmap, opacity=opacity)


write state dict
(2560,)
(4608,)
(4096,)
(1536,)
use_bsp_ssegs_gt: False
unique point percentage: 17.6%
unique point percentage (non-zeros): 18.4%


  visualize_pointclouds_parts_partglot(np.array(out), names=list(final_mask['mask_vertices'].keys()), part_colors=label_cmap, opacity=opacity)


Output()

In [123]:
visualize_pointclouds_parts_partglot(sup_segs, opacity=opacity, part_colors=sseg_cmap)

Output()

In [58]:
visualize_pointclouds_parts_partglot(sup_segs, opacity=opacity, part_colors=sseg_cmap)


Output()

In [84]:
pc_final[:, pc_final != [0,0,0]]

IndexError: too many indices for array: array is 2-dimensional, but 3 were indexed

0.00014960766

In [34]:
pc2label=[] # pc2sup_segs: is actually pc2label
for lbl in sup_segs2label:
    tmp = np.ones(n_ssseg_custom) * lbl
    pc2label.append(tmp)

In [37]:
pc2label = np.concatenate(pc2label).astype(int)


In [39]:
pc2label.shape

(1250,)

In [40]:
n_ssseg_custom

25

In [13]:
attn_maps_concat

array([[[0, 0, 2, 3, 1, 0, 2, 1, 3, 3, 2, 1, 2, 1, 3, 2, 0, 3, 0, 0, 2,
         2, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0]]])

In [124]:
np.unique(final_pc, axis=0).shape

(1729, 3)

In [91]:
np.vstack(sup_segs).shape

(12800, 3)

In [92]:
pc2label.shape[0] /2

6400.0

In [110]:
final_pc.shape

(12800, 3)

In [18]:
out = []
for s,f in final_mask['mask_vertices'].values():
    tmp = final_pc[s:f]
    out.append(tmp)

In [20]:
pc2label_ref = np.load("/home/bellatini/DL3D-Practical/Baselines/PartGlot/logs/pre_trained/pn_agnostic/12-12_14-37-13/pred_label/final/0_pc_label.npy")

In [22]:
pc2label_ref.shape

(2691,)

In [108]:
pc2label_ref.shape, np.unique(pc2label_ref)

((2691,), array([0, 1, 2, 3]))

In [95]:
# torch.cat(attn_maps).shape

In [19]:
visualize_pointclouds_parts_partglot(out)

Output()

In [36]:
visualize_pointclouds_parts_partglot(segs)

Output()

In [37]:
pc2label

array([3, 3, 3, ..., 0, 0, 0])