### Retrieval

This notebook is dedicated to testing trained pipelines on part retrieval and evaluating their performance.
The ultimate goal is to create a pipeline that finds similar parts that fit the query object.

In [1]:
# CUDA_LAUNCH_BLOCKING is set before `import torch` below; with value "1" CUDA
# runs kernel launches synchronously, which makes device-side errors easier to trace.
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
# Extend the import path with the experiments directory (presumably where the
# `scripts` package imported below lives -- machine-specific absolute path).
import sys
sys.path.append("/home/beast/Desktop/vlassis/retrieval2/experiments")

import torch
from torch.utils.data import DataLoader, random_split
import numpy as np
from tqdm import tqdm
from diffusers.optimization import get_cosine_schedule_with_warmup

# Project-local helpers: datasets, visualization, model definitions, logging, metrics, utilities.
from scripts.dataset import Items3Dataset, Warehouse4Dataset, warehouse4_collate_fn
from scripts.visualization import quick_vis, quick_vis_with_parts, quick_vis_many
from scripts.visualization import quick_vis_pretty, quick_vis_with_parts_pretty
from scripts.visualization import plot_histogram, visualize_distribution_over_time
from scripts.visualization import coord_frame
from scripts.model import *  # NOTE(review): wildcard import; PartFinderPipeline2 used later presumably comes from here
from scripts.logger import LivePlot
from scripts.metrics import AccuracyMultiClass
from scripts.utils import map_labels, generate_label_map, normalize_parts, normalize_parts_1, get_truly_random_seed_through_os
from scripts.utils import normalize_and_split, pc_to_bounding_box

# Module-level flag. It is not read by any cell visible in this notebook and is
# reassigned to True in the seeding cell.
TEST = False


Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
# Obtain a seed from the project helper and echo it, so the value used for
# this run is recorded in the cell output.
seed = get_truly_random_seed_through_os()
print(seed)

# Apply the same seed to NumPy's global RNG first, then to PyTorch.
for seed_rng in (np.random.seed, torch.manual_seed):
    seed_rng(seed)

TEST = True

663100628


##### Create data loaders and data related parameters

In [3]:
# Loader batch size and the shape categories this run is restricted to.
batch_size = 128
categories = [0, 4, 15] # [0, 1, 3] #

# Warehouse of spare parts, loaded together with its pre-computed encodings.
# Alternative (PartNet) location kept for reference:
# warehouse_path = "/home/beast/Desktop/vlassis/retrieval2/experiments/data/vectors_warehouse_partnet2"
warehouse_path = "/home/beast/Desktop/vlassis/retrieval2/experiments/data/vectors_warehouse_shapenet"
warehouse = Warehouse4Dataset(cat=None, path=warehouse_path, encodings=True)
dataloader = DataLoader(
    warehouse,
    batch_size=batch_size,
    shuffle=False,  # keep dataset order so score positions line up with warehouse indices
    collate_fn=warehouse4_collate_fn,
)

# Items the query shape is drawn from, restricted to `categories`.
# Alternative (PartNet) location kept for reference:
# items_path = "/home/beast/Desktop/vlassis/retrieval2/experiments/data/vectors2_items2_partnet.h5"
items_path = "/home/beast/Desktop/vlassis/retrieval2/experiments/data/vectors_items_shapenet.h5"
itemsdataset = Items3Dataset(cat=categories, path=items_path)

# Label map reused throughout the retrieval process: each category label is
# mapped to a specific model output neuron.
category_ids = torch.tensor(categories, dtype=torch.int32)
label_map = generate_label_map(category_ids)
print(label_map)

Initializing warehouse dataset


100%|███████████████████████████████████| 62120/62120 [00:18<00:00, 3307.93it/s]


Warehouse dataset initialization complete (t = 18.92544960975647)
(12137, 1)
Items2 dataset initialization complete (t = 1.0237460136413574)
{'0': 0, '4': 1, '15': 2}


##### Load the model we want to test each time

In [4]:
# Architecture settings of the pipeline under test; the number of output
# classes follows the categories selected above.
pipeline_config = dict(
    in_channels=3,
    out_channels=384,
    num_classes=len(categories),
    num_attention_blocks=3,
    pos_emb_dim=3,
    pool_method="cls_token_pool",
)
model = PartFinderPipeline2(**pipeline_config).cuda()

# Checkpoint holding the trained weights (PartNet alternative kept for reference):
# checkpoint_path = "/home/beast/Desktop/vlassis/retrieval2/checkpoints/CLS_allcats_T7_partnet_pointnetnew_batchless.pt"
checkpoint_path = "/home/beast/Desktop/vlassis/retrieval2/checkpoints/CLS_allcats_T7_shapenet.pt"
trained_weights = torch.load(checkpoint_path)
model.load_state_dict(trained_weights)

# Evaluation mode for the inference cells that follow.
model = model.eval()

##### Grab an item from the dataset, choose a random part and discard it.

In [17]:
# Fetch one item to act as the query and show it, both coloured by part and plain.
sample, label, part_labels, pid, vectors = itemsdataset[5000]
label = label.long()
print("LABEL: ", label)
quick_vis_with_parts_pretty(sample, pid, title=f"Query_sample, label: {label.item()}")
quick_vis_pretty(sample)
print(f"sample: {sample.shape}, label: {label.shape}, part_label: {part_labels.shape}\npid: {pid.shape}, vectors: {vectors.shape}")

# Part id to remove from the query (currently fixed; the random choice is kept for reference).
discard_part_id = 0#np.random.choice(np.unique(pid))

# One boolean mask over the points drives both index sets.
is_discarded = pid == discard_part_id
discard_indices = np.where(is_discarded)[0]
keep_indices = np.where(~is_discarded)[0]

# Class of the discarded part, read from its first point.
idx = discard_indices[0]
discarded_part_class = part_labels[idx].item()
print("discarded part: ", discarded_part_class)

# Split the shape into the removed part and the remaining query shape.
discarded_part = sample[discard_indices]
query_sample = sample[keep_indices]
query_pid = pid[keep_indices]
query_part_label = part_labels[keep_indices]

# Untouched copy of the query points; later cells overwrite `query_sample`.
query_sample_copy = query_sample.clone()

# Drop the row at index `discard_part_id` from `vectors`
# (assumes rows are indexed by part id -- TODO confirm).
query_vectors = torch.cat([vectors[:discard_part_id], vectors[discard_part_id + 1:]], dim=0)

# Shift the ids above the removed one down by one.
query_pid[query_pid > discard_part_id] -= 1

# Bounding box of the removed part, drawn on top of the remaining shape.
discarded_bb = pc_to_bounding_box(discarded_part)
# quick_vis_with_parts_pretty(query_sample, query_pid, extra_geometries = [discarded_bb._o3d],
#                      title="Query sample, after discarding a part")
quick_vis_pretty(
    query_sample,
    extra_geometries=[discarded_bb._o3d],
    title=f"Query sample, after discarding part {discard_part_id}",
)

# Sanity check - these are the variables used through the rest of the notebook.
print(query_sample.shape, query_pid.shape, query_part_label.shape, query_vectors.shape)

LABEL:  tensor([0])
sample: torch.Size([2048, 3]), label: torch.Size([1]), part_label: torch.Size([2048])
pid: torch.Size([2048]), vectors: torch.Size([70, 3])
discarded part:  0
torch.Size([1141, 3]) torch.Size([1141]) torch.Size([1141]) torch.Size([69, 3])


In [None]:
# {
# 	"class_name" : "ViewTrajectory",
# 	"interval" : 29,
# 	"is_loop" : false,
# 	"trajectory" : 
# 	[
# 		{
# 			"boundingbox_max" : [ 0.86086872887860688, 0.32886562875646308, 0.46846096634864809 ],
# 			"boundingbox_min" : [ -0.86086872887860688, -0.32886562875646308, -0.46846096634864809 ],
# 			"field_of_view" : 60.0,
# 			"front" : [ -0.62950627893910571, 0.32522670284422861, 0.70565532417272447 ],
# 			"lookat" : [ 0.0, 0.0, 0.0 ],
# 			"up" : [ 0.18282276118249258, 0.94468488339549284, -0.27229819881455958 ],
# 			"zoom" : 0.96000000000000019
# 		}
# 	],
# 	"version_major" : 1,
# 	"version_minor" : 0
# }

Encoding the query shape

In [8]:
# Add a batch dimension to each query tensor and move it to the GPU.
# NOTE: this cell rebinds its own inputs (`query_sample`, `query_pid`, ...), so
# it is not idempotent -- re-run the sample-selection cell before running it again.
query_sample = query_sample.unsqueeze(0).cuda()
query_pid = query_pid.unsqueeze(0).cuda()
query_part_label = query_part_label.unsqueeze(0).cuda()
query_vectors = query_vectors.unsqueeze(0).cuda()
print(query_sample.shape)

#normalizing the parts -> parts: N x 3, pid: N, part_label: M, centroids: M x 3
parts, query_pid, part_label, query_centroids = normalize_and_split(query_sample, query_pid, query_part_label, query_vectors, include_centroids = True)

# Encoding the shape -> M x F.
# This is pure inference, so autograd is disabled: the features are only
# consumed inside the no_grad retrieval loop and never need a graph.
with torch.no_grad():
    query_shape_feats = model.forward_encoder(parts, query_pid)

print("part feats: ", query_shape_feats.shape)
print("pid: ", query_pid.shape)
print("part lb: ", part_label.shape)
print("centroids: ", query_centroids.shape)

torch.Size([1, 1320, 3])
part feats:  torch.Size([4, 384])
pid:  torch.Size([1320])
part lb:  torch.Size([4])
centroids:  torch.Size([4, 3])


##### Evaluating the entire warehouse against the query sample. The higher the score, the higher the similarity

In [9]:
#EVALUATING THE SCORE OF EACH WAREHOUSE SPARE PART
print(label.shape, label_map)
# NOTE: `label` is rebound to its mapped value here, so this cell must not be
# re-run without first re-running the sample-selection cell.
label = map_labels(label, label_map)
warehouse_scores = []

with torch.no_grad():
    # Iterating through every single part in the warehouse. The unused batch
    # fields get underscore names so they do not overwrite the query sample's
    # `part_labels` / `vectors` variables defined in the sample-selection cell.
    for samples, _wh_labels, _wh_part_labels, _wh_vectors, warehouse_features in tqdm(dataloader):

        #computing the centroids of each part - M x 3
        warehouse_centroids = torch.stack([s.mean(dim=0) for s in samples])

        #encoded samples have been normalized beforehand. Transferring to gpu
        warehouse_features, warehouse_centroids = warehouse_features.cuda(), warehouse_centroids.cuda()

        #M x K class scores
        scores = model.forward_retrieval(query_shape_feats, query_centroids, warehouse_features, warehouse_centroids, normalize=True)

        #selecting the output of only the relevant neuron, the one corresponding to the query shape
        scores = scores[:, label]

        # Keeping track of all scores as a 1-D tensor per batch. reshape(-1)
        # is used instead of squeeze(): squeeze() would turn a batch holding a
        # single part into a 0-dim tensor, which torch.cat cannot concatenate.
        warehouse_scores.append(scores.cpu().reshape(-1))

torch.Size([1]) {'0': 0, '4': 1, '15': 2}


100%|█████████████████████████████████████████| 486/486 [00:13<00:00, 35.49it/s]


##### Evaluating the score of the actual discarded part to see how it compares to the rest

##### Display the top K matches and worst K matches

In [10]:
#Concatenating warehouse scores (batches of size B) into a single tensor
# The name is rebound from a list of per-batch tensors to one flat tensor.
# The guard makes re-running this cell a no-op once that has happened,
# instead of failing on torch.cat over an already concatenated tensor.
if isinstance(warehouse_scores, (list, tuple)):
    warehouse_scores = torch.cat(warehouse_scores, dim=0)

In [11]:
print(warehouse_scores.shape)

# Warehouse indices ordered from highest to lowest score, and the reverse.
best_indices = warehouse_scores.argsort(descending=True)
worst_indices = warehouse_scores.argsort(descending=False)

# Number of matches to inspect at each end of the ranking.
k = 20

# Scores of the k best / k worst matches as plain Python floats.
top_k_scores = warehouse_scores[best_indices[:k]].tolist()
bot_k_scores = warehouse_scores[worst_indices[:k]].tolist()

print(top_k_scores)
print(bot_k_scores)

# First field of each selected warehouse item (passed to the visualizer later).
top_k_parts = [warehouse[int(rank_idx)][0] for rank_idx in best_indices[:k]]
bot_k_parts = [warehouse[int(rank_idx)][0] for rank_idx in worst_indices[:k]]

torch.Size([62120])
[4.475837230682373, 4.475737571716309, 4.475613594055176, 4.475452423095703, 4.4753899574279785, 4.475274085998535, 4.4750566482543945, 4.474409580230713, 4.473893165588379, 4.4736433029174805, 4.473320007324219, 4.473201274871826, 4.473171710968018, 4.473169803619385, 4.473116397857666, 4.47283935546875, 4.472552299499512, 4.472504615783691, 4.47241735458374, 4.4722089767456055]
[-2.827885627746582, -2.7242982387542725, -2.684530735015869, -2.555819511413574, -2.553339958190918, -2.5467076301574707, -2.5046098232269287, -2.4568099975585938, -2.4381332397460938, -2.420393943786621, -2.4175522327423096, -2.416565418243408, -2.3816745281219482, -2.3607778549194336, -2.3547019958496094, -2.352668285369873, -2.321974992752075, -2.3195242881774902, -2.3076348304748535, -2.301144599914551]


Visualizing the samples corresponding to the highest and lowest scores

In [12]:
# Fetch each selected warehouse item once and read both fields from it,
# instead of indexing the dataset separately for the part and for its label.
top_k_items = [warehouse[i.item()] for i in best_indices[:k]]
bot_k_items = [warehouse[i.item()] for i in worst_indices[:k]]

# Field 0 is what gets visualized as the part.
top_k_parts = [item[0] for item in top_k_items]
bot_k_parts = [item[0] for item in bot_k_items]

# Field 2 is shown as the "current id" in the plot titles
# (presumably the part label -- confirm against Warehouse4Dataset).
top_k_labels = [item[2] for item in top_k_items]
bot_k_labels = [item[2] for item in bot_k_items]

In [13]:
# Show the query shape (from the copy taken before batching) together with the
# bounding box of the removed part, then the removed part on its own.
discarded_box_geometry = discarded_bb._o3d
quick_vis_pretty(
    query_sample_copy.squeeze().cpu(),
    extra_geometries=[discarded_box_geometry],
)
quick_vis_pretty(discarded_part.squeeze().cpu())

In [14]:
# Visualize the best-scoring warehouse parts one by one; each title pairs the
# class of the discarded part with the label of the shown part.
for part, label_tensor in zip(top_k_parts, top_k_labels):
    lb = label_tensor.squeeze().item()
    vis_title = f"BEST - discarded id: {discarded_part_class}, current id: {lb}"
    quick_vis_pretty(part, title=vis_title)

KeyboardInterrupt: 

In [15]:
# Visualize the worst-scoring warehouse parts one by one.
for part, lb in zip(bot_k_parts, bot_k_labels):
    # Convert the label tensor to a plain Python number, as the BEST loop does,
    # so the title shows the id itself rather than a tensor repr.
    lb = lb.squeeze().item()
    quick_vis_pretty(part, title = f"WORST - discarded id: {discarded_part_class}, current id: {lb}")

KeyboardInterrupt: 