In [1]:
import sys
sys.path.append('/viscam/u/iamisaac/mast3r/dust3r/croco')
import mast3r.utils.path_to_dust3r
from mast3r.model import AsymmetricMASt3R
from mast3r.fast_nn import fast_reciprocal_NNs
import mast3r.demo as demo

from dust3r.inference import inference
from dust3r.utils.image import load_images
# visualize a few matches
import numpy as np
import torch
import torchvision.transforms.functional
from matplotlib import pyplot as pl
import cv2
import numpy as np
import numpy as np
import requests
from PIL import Image
import base64
import io
import pandas as pd
import ast
import copy
import mast3r.demo
import dust3r.utils.geometry as geometry
# Debug output: show which copy of dust3r.utils.geometry was actually imported
# (its file path, its attribute names, and the module object in sys.modules).
print(
    geometry.__file__,
    dir(geometry),
    sys.modules['dust3r.utils.geometry'],
    sep="\n",
)

def init_mast3r(device='cuda'):
    """Load the pretrained MASt3R checkpoint and move it onto ``device``.

    Args:
        device: torch device string the model is moved to (default ``'cuda'``).

    Returns:
        The model returned by ``AsymmetricMASt3R.from_pretrained(...).to(device)``.
    """
    checkpoint = "naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"
    return AsymmetricMASt3R.from_pretrained(checkpoint).to(device)

def run_mast3r(model1, img1_path, img2_path, device='cuda'):
    """Run MASt3R on an image pair and return reciprocal descriptor matches.

    Args:
        model1: the MASt3R model to run (e.g. the result of ``init_mast3r``).
        img1_path: path of the first image.
        img2_path: path of the second image.
        device: torch device used for inference and for nearest-neighbour matching.

    Returns:
        ``(output, matches_im0, matches_im1)``: ``output`` is the raw dict returned
        by ``inference``; the two arrays hold corresponding (x, y) coordinates in
        the first and second view. Matches within 3 px of either view's border
        are dropped.
    """
    images = load_images([img1_path, img2_path], size=512)
    # Use the model passed in. The previous body read a global ``model`` and
    # silently ignored ``model1``.
    output = inference([tuple(images)], model1, device, batch_size=1, verbose=False)

    # Raw dust3r/MASt3R predictions for each view.
    view1, pred1 = output['view1'], output['pred1']
    view2, pred2 = output['view2'], output['pred2']

    desc1 = pred1['desc'].squeeze(0).detach()
    desc2 = pred2['desc'].squeeze(0).detach()

    # Find 2D-2D matches between the two images.
    matches_im0, matches_im1 = fast_reciprocal_NNs(desc1, desc2, subsample_or_initxy1=8,
                                                   device=device, dist='dot', block_size=2**13)

    # Ignore a small border around the edge of both views.
    H0, W0 = view1['true_shape'][0]
    H1, W1 = view2['true_shape'][0]
    valid_matches = (_inside_border(matches_im0, H0, W0)
                     & _inside_border(matches_im1, H1, W1))
    return output, matches_im0[valid_matches], matches_im1[valid_matches]


def _inside_border(matches, height, width, border=3):
    """Return a mask of (x, y) matches lying at least ``border`` px inside a height x width image."""
    x, y = matches[:, 0], matches[:, 1]
    return ((x >= border) & (x < int(width) - border)
            & (y >= border) & (y < int(height) - border))

def viz_matches(mast3r_output, matches_im0, matches_im1, n_viz = 20):
    """Show both views side by side with up to ``n_viz`` matches drawn between them.

    Args:
        mast3r_output: the ``output`` dict returned by ``run_mast3r``. Its
            ``'view1'``/``'view2'`` entries supply the images via ``view['img']``.
        matches_im0: match coordinates in the first view, one (x, y) row per match.
        matches_im1: corresponding coordinates in the second view.
        n_viz: maximum number of matches to draw, sampled evenly by match index.
    """
    # Use the argument. The previous body read a global ``output`` instead.
    view1, view2 = mast3r_output['view1'], mast3r_output['view2']

    num_matches = matches_im0.shape[0]
    # Never sample more matches than exist. With zero matches the old code asked
    # linspace for indices up to -1 and then indexed an empty array.
    n_viz = max(min(n_viz, num_matches), 0)
    match_idx_to_viz = np.round(np.linspace(0, num_matches - 1, n_viz)).astype(int)
    viz_matches_im0 = matches_im0[match_idx_to_viz]
    viz_matches_im1 = matches_im1[match_idx_to_viz]

    # Invert a (x - 0.5) / 0.5 normalisation so pixel values are displayable
    # (presumably the normalisation applied by load_images -- TODO confirm).
    image_mean = torch.as_tensor([0.5, 0.5, 0.5], device='cpu').reshape(1, 3, 1, 1)
    image_std = torch.as_tensor([0.5, 0.5, 0.5], device='cpu').reshape(1, 3, 1, 1)
    viz_imgs = [
        (view['img'] * image_std + image_mean).squeeze(0).permute(1, 2, 0).cpu().numpy()
        for view in (view1, view2)
    ]

    H0, W0, H1, W1 = *viz_imgs[0].shape[:2], *viz_imgs[1].shape[:2]
    # Bottom-pad the shorter image so both can be concatenated horizontally.
    img0 = np.pad(viz_imgs[0], ((0, max(H1 - H0, 0)), (0, 0), (0, 0)), 'constant', constant_values=0)
    img1 = np.pad(viz_imgs[1], ((0, max(H0 - H1, 0)), (0, 0), (0, 0)), 'constant', constant_values=0)
    img = np.concatenate((img0, img1), axis=1)

    pl.figure()
    pl.imshow(img)
    cmap = pl.get_cmap('jet')
    # Avoid the division by zero the old colour formula hit when n_viz == 1.
    color_den = max(n_viz - 1, 1)
    for i in range(n_viz):
        (x0, y0), (x1, y1) = viz_matches_im0[i].T, viz_matches_im1[i].T
        # The second view starts W0 pixels to the right in the concatenated canvas.
        pl.plot([x0, x1 + W0], [y0, y1], '-+', color=cmap(i / color_den), scalex=False, scaley=False)
    pl.show(block=True)

import numpy as np


def rotation_error(R_out, R_gt):
    """Return the summed geodesic angle (radians) between paired rotation matrices.

    Args:
        R_out: predicted rotations, array-like of shape (n, k, k) (3x3 rotations
            in practice).
        R_gt: ground-truth rotations, same shape as ``R_out``.

    Returns:
        ``sum_i arccos((trace(R_out[i].T @ R_gt[i]) - 1) / 2)`` as a float.
    """
    R_out = np.asarray(R_out, dtype=float)
    R_gt = np.asarray(R_gt, dtype=float)
    # trace(A.T @ B) == sum_ij A_ij * B_ij, evaluated for every frame at once.
    traces = np.einsum('nij,nij->n', R_out, R_gt)
    # Floating-point round-off can push the cosine slightly outside [-1, 1],
    # where arccos returns NaN; clip so near-identical rotations give 0.
    cos_angles = np.clip((traces - 1.0) / 2.0, -1.0, 1.0)
    return float(np.sum(np.arccos(cos_angles)))


def translation_error(T_out, T_gt):
    """Return the summed per-frame translation error, normalised by a scene scale.

    The scene scale is the sum of distances between all unordered pairs of
    ground-truth translations, divided by the number of frames ``n``.

    Args:
        T_out: predicted translations, array-like of shape (n, d).
        T_gt: ground-truth translations, same shape as ``T_out``.

    Returns:
        ``sum_i ||T_out[i] - T_gt[i]|| / scene_scale`` as a float.

    Raises:
        ValueError: if the scene scale is zero (fewer than two frames, or all
            ground-truth translations coincide), so normalisation is undefined.
    """
    T_out = np.asarray(T_out, dtype=float)
    T_gt = np.asarray(T_gt, dtype=float)
    n = T_out.shape[0]

    # Sum of distances over all pairs i < j. The old loop accumulated into an
    # undefined name ``distance`` and raised NameError for n >= 2.
    i_idx, j_idx = np.triu_indices(n, k=1)
    distance_sum = float(np.linalg.norm(T_gt[i_idx] - T_gt[j_idx], axis=-1).sum()) if n > 1 else 0.0
    # NOTE(review): this divides by n, not by the n*(n-1)/2 pairs summed above,
    # so it is not the mean pairwise distance -- confirm which is intended.
    scene_scale = distance_sum / n if n else 0.0
    if scene_scale == 0:
        raise ValueError("translation_error: ground-truth scene scale is zero")

    frame_errors = np.linalg.norm(T_out - T_gt, axis=-1)
    return float(frame_errors.sum() / scene_scale)

def send_query(parquet_file_path, display_images=False, num_responses = 5,
               query_index=50, url="http://localhost:1234/knn-service"):
    """Query the kNN retrieval service with one stored image embedding.

    Args:
        parquet_file_path: parquet file with ``"embedding"`` and ``"image_path"``
            columns.
        display_images: if True, display the query image and every returned
            image. Relies on a ``display`` function being in scope (IPython
            provides one inside notebooks).
        num_responses: number of neighbours to request.
        query_index: row label of the embedding used as the query (previously
            hard-coded to 50, which remains the default).
        url: endpoint of the kNN service.

    Returns:
        The list of parsed results produced by ``parse_response``.

    Raises:
        ValueError: if the service answers with a status other than 200.
    """
    metadata = pd.read_parquet(parquet_file_path)
    embedding = metadata["embedding"][query_index].tolist()
    data = {
        "embedding_input": embedding,
        "modality": "image",  # which index to use: image or text
        "num_images": num_responses,  # number of output images
        "num_result_ids": num_responses,
        "indice_name": "co3d",
    }
    response = requests.post(url, json=data)
    if response.status_code != 200:
        raise ValueError("Request failed!!")
    results = parse_response(response)

    if display_images:
        display(Image.open(metadata["image_path"][query_index]))
        for result in results:
            display(Image.fromarray(result["image"]))
    return results

def parse_response(response):
    """Decode every result in a kNN-service HTTP response.

    Each JSON result is deep-copied and its encoded fields are replaced:
    ``"image"`` (base64-encoded image bytes) becomes a numpy array of the
    decoded image; ``"R"`` becomes a numpy array reshaped to 3x3; ``"T"``,
    ``"focal_length"`` and ``"principal_point"`` become numpy arrays. These
    four fields are strings parsed with ``ast.literal_eval``.
    """
    def _literal_array(field):
        # Parse a Python-literal string (presumably a list of numbers) into an array.
        return np.array(ast.literal_eval(field))

    parsed_results = []
    for entry in response.json():
        pil_image = Image.open(io.BytesIO(base64.b64decode(entry["image"])))
        rotation = _literal_array(entry["R"]).reshape(3, 3)
        translation = _literal_array(entry["T"])
        focal = _literal_array(entry["focal_length"])
        principal = _literal_array(entry["principal_point"])

        decoded = copy.deepcopy(entry)
        decoded.update(
            R=rotation,
            T=translation,
            focal_length=focal,
            principal_point=principal,
        )
        decoded["image"] = np.array(pil_image)
        parsed_results.append(decoded)
    return parsed_results

def run_reconstruction(model, filelist, output_path="/viscam/projects/sfs/mast3r_outputs/test1"):
    """Reconstruct a scene from ``filelist`` with the MASt3R demo pipeline.

    Forwards a fixed configuration to ``demo.get_reconstructed_scene``: 512 px
    images, ``"refine"`` optimisation level, a ``"complete"`` scene graph,
    shared intrinsics, running on CUDA.

    Args:
        model: the MASt3R model used by the demo.
        filelist: image paths to reconstruct.
        output_path: directory passed to the demo as ``outdir``.

    Returns:
        ``(sparse_ga, outfile)``: the ``sparse_ga`` attribute of the scene state
        returned by the demo, and the demo's output file.
    """
    scene_state, outfile = demo.get_reconstructed_scene(
        outdir=output_path,
        model=model,
        device='cuda',
        filelist=filelist,
        # optimisation schedule
        niter1=500,
        niter2=200,
        lr1=0.07,
        lr2=0.014,
        optim_level="refine",
        # export / visualisation
        as_pointcloud=True,
        cam_size=0.2,
        TSDF_thresh=0,
        mask_sky=True,
        clean_depth=True,
        transparent_cams=True,
        # gradio-only state, unused here
        gradio_delete_cache=False,
        current_scene_state=None,
        image_size=512,  # 224 was the commented-out alternative
        silent=False,
        # pair graph
        scenegraph_type="complete",
        winsize=1,
        win_cyclic=False,
        refid=0,  # Scene ID
        shared_intrinsics=True,
        # confidence thresholds
        min_conf_thr=1.5,
        matching_conf_thr=2.0,
    )
    return scene_state.sparse_ga, outfile

def run_mast3r_from_clip_retrieval(model, output_path, query_file_path, num_responses=10):
    """Retrieve similar images with the kNN service and reconstruct them with MASt3R.

    Args:
        model: the MASt3R model passed to ``run_reconstruction``.
        output_path: directory the reconstruction is written to.
        query_file_path: parquet file holding the query embedding (see ``send_query``).
        num_responses: number of retrieved images to reconstruct.

    Returns:
        ``(sparse_ga, outfile)`` as returned by ``run_reconstruction``.
    """
    responses = send_query(query_file_path, display_images=True, num_responses=num_responses)
    img_paths = [response["image_path"] for response in responses]
    # The previous body built rotation/translation lists with
    # ``for response in response`` (an undefined name -> NameError) and never
    # used them. Ground-truth poses remain available per result as "R" / "T".
    # TODO: compare them against sparse_ga.get_im_poses() for evaluation.
    return run_reconstruction(model, img_paths, output_path)
    

ModuleNotFoundError: No module named 'mast3r'

In [None]:
model = init_mast3r(device='cuda')  # pretrained MASt3R shared by the cells below

In [None]:
run_mast3r_from_clip_retrieval(model, "/viscam/u/iamisaac/sfs/co3d_embeddings/final4/metadata/336_34859_64111.parquet")

In [6]:
sparse_ga = sparse_ga_state.sparse_ga
img_poses = sparse_ga.get_im_poses()
relative_transform = torch.matmul(torch.inverse(img_poses[0]), img_poses[1])
print(relative_transform)

tensor([[ 1.4224e-01, -2.4439e-01,  9.5919e-01, -2.5692e+00],
        [ 2.0351e-01,  9.5556e-01,  2.1328e-01, -6.1434e-01],
        [-9.6868e-01,  1.6487e-01,  1.8566e-01,  2.7310e+00],
        [ 1.1868e-08, -3.9900e-09, -1.3822e-08,  1.0000e+00]], device='cuda:0')
