In [2]:
import functools
import random
import re
import sys
from collections import OrderedDict

import numpy as np
import open3d as o3d
import open_clip
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from PIL import Image

import src.models as models
from src.param import parse_args
from src.utils.data import normalize_pc
from src.utils.misc import load_config

In [4]:
def load_pcd(pcd, num_points=10000, y_up=True):
    """Resample an Open3D point cloud and pack it into OpenShape model inputs.

    Args:
        pcd: Open3D point cloud (anything exposing ``.points`` and ``.colors``).
            It is read, never modified.
        num_points: Number of points to return. Larger clouds are randomly
            subsampled without replacement; smaller clouds are randomly
            upsampled with replacement. Sampling uses the global ``random``
            module state (call ``random.seed`` for reproducible results).
        y_up: If True, swap the y and z columns before normalization
            (presumably converting a y-up cloud to the z-up convention the
            model expects -- confirm).

    Returns:
        ``(xyz, features)`` with a leading batch dimension of 1:
        ``xyz`` is a float32 CPU tensor of normalized coordinates and
        ``features`` is a float32 CUDA tensor of ``[xyz, rgb]`` per point.
        Clouds without per-point colors get a uniform grey (0.4) rgb.
        (Assumes ``normalize_pc`` keeps the ``(N, 3)`` shape.)

    Raises:
        ValueError: if the point cloud contains no points.
    """
    # np.array (not np.asarray) copies: np.asarray on Open3D's vector types
    # returns a view of the cloud's own buffer, so the in-place axis swap
    # below would otherwise rewrite the caller's `pcd`.
    xyz = np.array(pcd.points, dtype=np.float64)
    rgb = np.array(pcd.colors, dtype=np.float64)
    n = xyz.shape[0]
    if n == 0:
        raise ValueError("load_pcd: point cloud has no points")

    # A cloud without colors yields an empty array (never None), so test the
    # row count; fall back to uniform grey before indexing into `rgb`.
    if rgb.shape[0] != n:
        rgb = np.full_like(xyz, 0.4)

    if n != num_points:
        if n > num_points:
            idx = random.sample(range(n), num_points)
        else:
            # random.sample cannot draw more items than exist; upsample with replacement.
            idx = random.choices(range(n), k=num_points)
        xyz = xyz[idx]
        rgb = rgb[idx]

    if y_up:
        # swap y and z axis (RHS fancy indexing copies, so this is a true swap)
        xyz[:, [1, 2]] = xyz[:, [2, 1]]
    xyz = normalize_pc(xyz)

    features = np.concatenate([xyz, rgb], axis=1)
    xyz = torch.from_numpy(xyz).type(torch.float32)
    features = torch.from_numpy(features).type(torch.float32)
    # Coordinates stay on the CPU while features go to CUDA; the model call
    # receives device='cuda' and presumably moves xyz itself.
    return xyz.unsqueeze(0), features.unsqueeze(0).cuda()

In [None]:
def load_model(config, model_name="OpenShape/openshape-pointbert-vitg14-rgb"):
    """Build an OpenShape model from ``config`` and load pretrained Hub weights.

    Args:
        config: Config object accepted by ``models.make``.
        model_name: Hugging Face Hub repo id whose ``model.pt`` file holds the
            checkpoint (a dict with a ``"state_dict"`` entry).

    Returns:
        The model on CUDA, with BatchNorm layers converted to SyncBatchNorm and
        the checkpoint weights loaded with strict key matching.
    """
    model = models.make(config).cuda()

    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    ckpt_path = hf_hub_download(repo_id=model_name, filename="model.pt")
    # map_location="cpu" avoids depending on the GPU index the checkpoint was
    # saved from; load_state_dict copies the tensors onto the CUDA parameters.
    checkpoint = torch.load(ckpt_path, map_location="cpu")

    # Keys presumably carry the DistributedDataParallel wrapper prefix "module.".
    # Strip only that *leading* prefix: an unanchored re.sub('module.', ...)
    # also hits substrings such as "submodule_" inside nested names, because
    # '.' matches any character. Keys without the prefix are kept unchanged.
    prefix = "module."
    model_dict = OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in checkpoint["state_dict"].items()
    )
    model.load_state_dict(model_dict)
    return model

In [None]:
@torch.no_grad()
def extract_text_feat(texts, clip_model):
    """Encode text prompts with an open_clip model, without tracking gradients.

    Args:
        texts: Prompt(s) passed to ``open_clip.tokenizer.tokenize``.
        clip_model: Model exposing ``encode_text``; presumably already on CUDA,
            since the tokens are moved there before encoding.

    Returns:
        The result of ``clip_model.encode_text`` on the tokenized batch.
    """
    tokens = open_clip.tokenizer.tokenize(texts)
    tokens = tokens.cuda()
    return clip_model.encode_text(tokens)

In [None]:
# Labels scored by class_probs() when no `texts` are given.
DEFAULT_CLASS_NAMES = (
    "chair", "shelf", "door", "sink", "sofa",
    "bed", "toilet", "desk", "display", "table",
)

# Directory where open_clip caches the ViT-bigG-14 weights (machine-specific).
OPEN_CLIP_CACHE_DIR = "/kaiming-fast-vol/workspace/open_clip_model/"


@functools.lru_cache(maxsize=None)
def _load_openshape_model():
    """Build the OpenShape encoder once per kernel session; return ``(model, config)``."""
    print("loading OpenShape model...")
    # NOTE(review): inside Jupyter, sys.argv holds the kernel's own arguments;
    # confirm parse_args tolerates them.
    cli_args, extras = parse_args(sys.argv[1:])
    config = load_config("src/configs/train.yaml", cli_args=vars(cli_args), extra_args=extras)
    model = load_model(config)
    model.eval()
    return model, config


@functools.lru_cache(maxsize=None)
def _load_open_clip_model(cache_dir=OPEN_CLIP_CACHE_DIR):
    """Create the open_clip ViT-bigG-14 model once per ``cache_dir``, on CUDA in eval mode."""
    clip_model, _, _ = open_clip.create_model_and_transforms(
        "ViT-bigG-14", pretrained="laion2b_s39b_b160k", cache_dir=cache_dir
    )
    clip_model.cuda().eval()
    return clip_model


@torch.no_grad()
def class_probs(pcd, texts=None, logit_scale=1.0):
    """Zero-shot classify a point cloud against text labels with OpenShape + CLIP.

    Both models are loaded on the first call and reused afterwards; reloading
    ViT-bigG-14 per call is very slow. Runs under ``torch.no_grad()``.

    Args:
        pcd: Open3D point cloud, passed to ``load_pcd``.
        texts: Candidate labels; defaults to ``DEFAULT_CLASS_NAMES``.
        logit_scale: Multiplier applied to the cosine similarities before the
            softmax. The default 1.0 keeps raw similarities in [-1, 1], which
            gives fairly flat probabilities; CLIP-style models conventionally
            use a much larger scale (e.g. 100) -- tune if sharper scores are needed.

    Returns:
        Tensor of shape ``(1, len(texts))`` with a softmax over the labels
        (one row per point cloud in the batch).
    """
    model, config = _load_openshape_model()
    clip_model = _load_open_clip_model()
    if texts is None:
        texts = list(DEFAULT_CLASS_NAMES)

    xyz, feat = load_pcd(pcd)
    shape_feat = model(xyz, feat, device="cuda", quantization_size=config.model.voxel_size)

    text_feat = extract_text_feat(texts, clip_model)
    # Cosine similarity between the shape embedding(s) and every label embedding.
    scores = F.normalize(shape_feat, dim=1) @ F.normalize(text_feat, dim=1).T
    return F.softmax(logit_scale * scores, dim=1)