In [1]:
from partglot.utils.predict import get_loaded_model
import numpy as np
from src.helper.visualization import visualize_pointclouds_parts_partglot

import pymeshlab as pm
import torch
from partglot.utils.neural_utils import tokenizing
from sklearn.cluster import KMeans

np.random.seed(0)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

part_names = ["back", "seat", "leg", "arm"]

part_semantic_groups = {
    "back": ["back"],
    "seat": ["seat"],
    "leg": ["leg", "wheel", "base"],
    "arm": ["arm"],
}

model_dir = "/home/bellatini/DL3D-Practical/models/pn_agnostic.ckpt"
data_dir = "/home/bellatini/DL3D-Practical/data/partglot"
data_dir = "/home/bellatini/DL3D-Practical/Baselines/PartGlot/data"

partglot, partglot_dm = get_loaded_model(data_dir=data_dir, model_path=model_dir)
sup_segs2label, pc2label = partglot.get_attn_maps()[1]
segs, masks = partglot_dm.h5_data['data'][1], partglot_dm.h5_data['mask'][1]

partglot.to(device)

"setup done"



write state dict


: 

: 

In [None]:
def sort_arrays(arrays):
    ref_array = arrays[0]
    sorted_indices = ref_array.argsort()
    out = []
    for a in arrays:
        out.append(a[sorted_indices])
    return out

def random_sample_array(arr: np.array, size: int = 1, with_replacement:bool=True) -> np.array:
    if with_replacement:
        while len(arr) < size:
            arr = np.concatenate([arr, arr])
    return arr[np.random.choice(len(arr), size=size, replace=False)]

def get_attn_mask_objects(partglot_pointcloud, pc2label):
    """
    Returns ordered point cloud and mask indices in our format.
    """
    stacked_pc = np.vstack(partglot_pointcloud)
    
    arg_sort = pc2label.argsort()
    
    out_pc2label, out_pg_pc = pc2label[arg_sort], np.vstack(stacked_pc)[arg_sort]

    mask = {}
    for i, pn in enumerate(part_names):
        tmp = np.where(out_pc2label == i)[0]
        mask[pn] = [tmp.min(), tmp.max()]
    
    return {"mask_vertices": mask}, out_pg_pc

def cluster_supsegs(sorted_labels, sorted_pc, sup_seg_size=512):
    sup_segs, labels = [], []
    for lbl in np.unique(sorted_labels):
        indices = np.where(sorted_labels==lbl)[0]
        tmp_pc = sorted_pc[indices]
        tmp_lbl = sorted_labels[indices]
        sup_segs.append(random_sample_array(tmp_pc, sup_seg_size))
        labels.append(random_sample_array(tmp_lbl, sup_seg_size))
    return np.array(sup_segs), np.array(labels)

In [None]:
batch_data = torch.from_numpy(partglot_dm.h5_data['data'][1:2]).unsqueeze(dim=1).float().to(device)
mask_data = torch.from_numpy(partglot_dm.h5_data['mask'][1:2]).unsqueeze(dim=1).float().to(device)

pc = np.vstack((np.vstack(np.vstack(batch_data.cpu().numpy()))))

kmeans = KMeans(n_clusters=25, random_state=1).fit(pc)
pc2sup_segs = kmeans.labels_
np.unique(pc2sup_segs, return_counts=True)

sorted_labels, sorted_pc = sort_arrays((kmeans.labels_, pc))

super_segs, _ = cluster_supsegs(sorted_labels, sorted_pc)

custom_ssegs_batch = torch.from_numpy(np.array([[super_segs]])).float().to(device)
custom_mask_batch = torch.from_numpy(np.array([[np.ones(custom_ssegs_batch.shape[2])]])).float().to(device)
custom_ssegs_batch.shape, custom_mask_batch.shape

attn_maps = []
for pn in part_names:
    text_embeddings = tokenizing(partglot_dm.word2int, f"chair with a {pn}").to(device)[None].expand(
        1, -1
    )
    tmp = partglot.forward(
        batch_data, # custom_ssegs_batch / batch_data
        mask_data, # custom_mask_batch / mask_data
        text_embeddings, True)
    attn_maps.append(tmp)
    
sup_segs2label = np.squeeze(torch.cat(attn_maps).max(0)[1].cpu().numpy())
sup_segs2label

# custom_ssegs_batch
super_segs.shape
# custom_ssegs_batch.shape
K, n_points, coord = super_segs.shape

pc2sup_segs=[]
for ki in range(K):
    tmp = np.ones(n_points) * ki
    pc2sup_segs.append(tmp)
    
pc2sup_segs = np.concatenate(pc2sup_segs).astype(int)

assign_ft = lambda x: sup_segs2label[x]

pc2label = assign_ft(pc2sup_segs)

final_mask, final_pc = get_attn_mask_objects(super_segs, pc2label)



In [None]:
out = []
for s,f in final_mask['mask_vertices'].values():
    tmp = final_pc[s:f]
    out.append(tmp)

In [None]:
# pc2label_ref = np.load("/home/bellatini/DL3D-Practical/Baselines/PartGlot/logs/pre_trained/pn_agnostic/12-12_14-37-13/pred_label/final/0_pc_label.npy")

In [None]:
# torch.cat(attn_maps).shape

torch.Size([4, 1, 1, 50])

In [None]:
visualize_pointclouds_parts_partglot(out)

Output()

In [35]:
visualize_pointclouds_parts_partglot(segs)

Output()

In [37]:
pc2label

array([3, 3, 3, ..., 0, 0, 0])