In [1]:
# Expose only physical GPU 3 to this kernel. With a single visible device,
# the "cuda:0" device selected further down refers to that GPU.
# NOTE(review): this must run before anything initialises CUDA — keep it as the first cell.
%env CUDA_VISIBLE_DEVICES=3

env: CUDA_VISIBLE_DEVICES=3


In [2]:
import argparse
import base64
import io
import json
import os
import os.path as osp

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import plotly
import torch
from pytorch3d.renderer import FoVPerspectiveCameras
from pytorch3d.vis.plotly_vis import plot_scene

from dataset import CustomDataset
from eval import evaluate_coordinate_ascent, evaluate_mst
from models import get_model
from utils import (
    unnormalize_image,
    view_color_coded_images_from_path,
    view_color_coded_images_from_tensor,
)

In [3]:
# Checkpoint directory. The default is the original author's local path; set
# the RELPOSEPP_MODEL_DIR environment variable to point at the weights on
# another machine instead of editing this cell.
model_dir = os.environ.get(
    "RELPOSEPP_MODEL_DIR",
    "/home/jason/relpose-plus-plus-dev/weights/relposepp",
)
# num_images is fixed here, before the dataset is loaded.
# NOTE(review): presumably it has to match the number of frames passed to the
# model later on — confirm against get_model.
model, args = get_model(model_dir, num_images=8)

Loading checkpoint ckpt_000800000.pth
Missing keys: ['feature_extractor.feature_positional_encoding.pos_table_1']
Unexpected keys: []


In [4]:
# Example inputs, given relative to the notebook's working directory.
image_dir = "../examples/robot/images"
mask_dir = "../examples/robot/masks"
# No explicit bounding boxes are supplied. dataset.bboxes is still populated
# (see the last cell's output).
# NOTE(review): presumably CustomDataset derives the boxes from the masks — confirm.
bboxes = None

In [5]:
# Wrap the example images and masks in the project's CustomDataset.
dataset = CustomDataset(
    image_dir=image_dir,
    mask_dir=mask_dir,
    bboxes=bboxes,
    # Use the "mask_images" entry of the args returned by get_model when it
    # exists; otherwise fall back to False.
    mask_images=args.get("mask_images", False),
)

In [6]:
# Fetch every frame of the dataset in one call and move the tensors the model
# needs onto the GPU.
device = torch.device("cuda:0")
num_frames = dataset.n
batch = dataset.get_data(ids=np.arange(num_frames))
images, crop_params = (batch[key].to(device) for key in ("image", "crop_params"))

# Add a leading dimension of size 1 to both tensors before they go to the model.
batched_images = images.unsqueeze(0)
batched_crop_params = crop_params.unsqueeze(0)

In [7]:
# Initial rotation hypothesis from the MST-based evaluation. This is pure
# inference, so autograd is disabled just like in the other model-evaluation
# cells; that avoids building a graph nobody backpropagates through.
with torch.no_grad():
    _, hypothesis = evaluate_mst(
        model=model,
        images=batched_images,
        use_all_features=True,
        crop_params=batched_crop_params,
    )
# Stack the per-frame entries into a single array for plotting.
R_pred = np.stack(hypothesis)

In [8]:
# Run the model once with autograd disabled and keep only its third output,
# which is passed to the visualization as the camera translations.
# NOTE(review): T_pred is later indexed per frame (T_pred[i, None]), so its
# leading dimension is presumably the number of frames — confirm.
with torch.no_grad():
    _, _, T_pred = model(images=batched_images, crop_params=batched_crop_params)

In [9]:
# Rotation hypothesis from the coordinate-ascent evaluation (inference only,
# so autograd is off). Note that `hypothesis` overwrites the value produced by
# the MST cell; the MST result lives on in R_pred.
with torch.no_grad():
    _, hypothesis = evaluate_coordinate_ascent(
        model=model,
        images=batched_images,
        use_all_features=True,
        crop_params=batched_crop_params,
    )
# Stacked rotations; plotted below as the "Final Optimized Cameras".
R_final = np.stack(hypothesis)

  0%|          | 0/50 [00:00<?, ?it/s]

In [10]:
def plotly_scene_visualization(
    R_pred_mst, R_pred_coord_asc, T_pred, camera_scale=0.03, cmap_name="hsv"
):
    """Plot the initial and the optimized camera sets side by side.

    Builds a two-column plotly figure: the left scene shows one camera per
    frame using ``R_pred_mst``, the right scene one camera per frame using
    ``R_pred_coord_asc``. Both scenes use the same translations ``T_pred``.
    Camera ``i`` gets the same colour in both scenes.

    Args:
        R_pred_mst: Per-frame rotations for the "Initial Predicted Cameras"
            scene. Indexed as ``R_pred_mst[i, None]``; its length defines the
            number of frames.
        R_pred_coord_asc: Per-frame rotations for the "Final Optimized
            Cameras" scene, indexed the same way.
        T_pred: Per-frame translations, indexed as ``T_pred[i, None]`` and
            shared by both scenes.
        camera_scale: Forwarded to ``plot_scene``. Defaults to 0.03.
        cmap_name: Name of the matplotlib colormap the per-frame colours are
            sampled from. Defaults to "hsv".

    Returns:
        The plotly figure created by ``plot_scene``.
    """
    num_frames = len(R_pred_mst)

    scenes = {
        "Initial Predicted Cameras": {},
        "Final Optimized Cameras": {},
    }

    # One single-camera batch per frame in each scene; `[i, None]` keeps a
    # leading batch dimension of size 1.
    for i in range(num_frames):
        scenes["Initial Predicted Cameras"][i] = FoVPerspectiveCameras(
            R=R_pred_mst[i, None], T=T_pred[i, None]
        )
        scenes["Final Optimized Cameras"][i] = FoVPerspectiveCameras(
            R=R_pred_coord_asc[i, None], T=T_pred[i, None]
        )

    fig = plot_scene(
        scenes,
        camera_scale=camera_scale,
        ncols=2,
    )
    fig.update_scenes(aspectmode="data")

    cmap = plt.get_cmap(cmap_name)
    for i in range(num_frames):
        # Sample the colormap at evenly spaced positions in [0, 1).
        color = matplotlib.colors.to_hex(cmap(i / num_frames))
        # This indexing relies on fig.data holding one trace per camera, with
        # all traces of the first scene preceding those of the second.
        fig.data[i].line.color = color
        fig.data[i + num_frames].line.color = color

    return fig

In [11]:
# Build the side-by-side figure: MST rotations (R_pred) on the left,
# coordinate-ascent rotations (R_final) on the right, same translations.
fig = plotly_scene_visualization(R_pred, R_final, T_pred)

In [12]:
# Serialize the figure as an HTML fragment (full_html=False) that loads
# plotly.js from the CDN instead of inlining the library; the fragment is
# embedded into the page template further down.
html_plot = plotly.io.to_html(fig, full_html=False, include_plotlyjs="cdn")

# Optional (disabled): dump only the plot fragment to its own file.
# with open("template.html", "w") as f:
#     f.write(html_plot)

In [13]:
# Render the colour-coded input frames and keep the resulting PNG as a base64
# string so it can be inlined into the HTML page as a data URI.
png_buffer = io.BytesIO()
view_color_coded_images_from_tensor(images)
# savefig without an explicit figure writes pyplot's current figure, i.e. the
# one the helper above is expected to have drawn.
plt.savefig(png_buffer, format="png", bbox_inches="tight")
# Close the figure so it is not also rendered as cell output.
plt.close()
# b64encode never inserts line breaks, so the result is usable as-is.
image_encoded = base64.b64encode(png_buffer.getvalue()).decode("utf-8")

In [14]:
# Minimal page skeleton, filled in with str.format. It has two placeholders:
#   {image_encoded} - base64 PNG data, inlined as a data URI
#   {plotly_html}   - the plotly HTML fragment
# Any literal brace added to this template must be doubled ({{ / }}), because
# str.format treats single braces as placeholders.
HTML_TEMPLATE = """<html><head><meta charset="utf-8"/></head>
<body><img src="data:image/png;charset=utf-8;base64,{image_encoded}"/>
{plotly_html}</body></html>"""

In [15]:
# Fill the template before opening the file, so a formatting error cannot
# leave a truncated test.html behind.
html_document = HTML_TEMPLATE.format(
    image_encoded=image_encoded,
    plotly_html=html_plot,
)
# The page declares <meta charset="utf-8"/>, so write it as UTF-8 explicitly
# rather than relying on the platform's default encoding.
with open("test.html", "w", encoding="utf-8") as f:
    f.write(html_document)

In [43]:
# Inspect the per-image bounding boxes held by the dataset; none were passed
# in explicitly (bboxes = None when the dataset was built).
dataset.bboxes

[[138, 96, 431, 747],
 [86, 108, 459, 729],
 [117, 127, 447, 747],
 [204, 142, 438, 748],
 [118, 90, 427, 748],
 [199, 83, 499, 760],
 [150, 100, 510, 713],
 [236, 154, 514, 739]]