In [20]:
import torch
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import os
import os.path as osp
from PIL import Image

from pytorch3d.renderer import FoVPerspectiveCameras
from pytorch3d.vis.plotly_vis import plot_scene

from dataset import CustomDataset
from eval import evaluate_coordinate_ascent, evaluate_mst
from models import get_model
from utils import (
    unnormalize_image,
    view_color_coded_images_from_tensor,
    view_color_coded_images_from_path,
)

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

In [21]:
IMAGE_DIR = "../examples/balm"
WEIGHTS_DIR = "../weights"

In [22]:
bboxes = []
tightness = [0.8, 0.8, 0.8, 0.5]
for t, image_path in zip(tightness, sorted(os.listdir(IMAGE_DIR))):
    im = Image.open(osp.join(IMAGE_DIR, image_path))
    x, y = im.width, im.height
    cropx, cropy = im.width * t, im.height * t
    startx = x // 2 - (cropx // 2)
    starty = y // 2 - (cropy // 2)
    bboxes.append([startx, starty, startx + cropx, starty + cropy])

print(bboxes)

[[303.0, 404.0, 2722.2000000000003, 3629.6000000000004], [303.0, 404.0, 2722.2000000000003, 3629.6000000000004], [303.0, 404.0, 2722.2000000000003, 3629.6000000000004], [1008.0, 756.0, 3024.0, 2268.0]]


In [23]:
# Load in the wild images
dataset = CustomDataset(image_dir=IMAGE_DIR, bboxes=bboxes)
num_frames = dataset[0]["n"]
images = dataset[0]["image"].to("cuda")
crop_params = dataset[0]["crop_params"].to("cuda")

IndexError: list index out of range

In [None]:
from PIL import Image

picture_location = "/home/amyxlase/relpose-plus-plus/examples/balm/IMG_2253.jpg"
compressed_picture_location = (
    "/home/amyxlase/relpose-plus-plus/examples/balm/IMG_2253-compressed.jpg"
)
im = Image.open(picture_location)
quality = 1
im.save(compressed_picture_location, quality=quality)

In [25]:
# Load pretrained weights
model = get_model(model_dir=WEIGHTS_DIR, num_images=num_frames, device="cuda")

Loading checkpoint ckpt_000400000.pth
Missing keys: ['feature_extractor.feature_positional_encoding.pos_table_1']
Unexpected keys: []


In [26]:
# Initialize a quick, reasonable solution using MST reasoning
batched_images, batched_crop_params = images.unsqueeze(0), crop_params.unsqueeze(0)

print(dataset.bboxes)

_, hypothesis = evaluate_mst(
    model=model,
    images=batched_images,
    use_all_features=True,
    crop_params=batched_crop_params,
)
R_pred = np.stack(hypothesis)


# Regress to optimal translation
with torch.no_grad():
    _, _, T_pred = model(images=batched_images, crop_params=batched_crop_params)

print(R_pred.shape)
print(T_pred.shape)

[[303.0, 404.0, 2722.2000000000003, 3629.6000000000004], [303.0, 404.0, 2722.2000000000003, 3629.6000000000004], [404.0, 303.0, 3629.6000000000004, 2722.2000000000003], [1008.0, 756.0, 3024.0, 2268.0]]
(4, 3, 3)
torch.Size([4, 3])


In [27]:
# Construct cameras and visualize scene for quick solution
cameras_pred = FoVPerspectiveCameras(R=R_pred, T=T_pred)
scenes = {"Predicted Cameras": {}}

for i in range(num_frames):
    scenes["Predicted Cameras"][i] = FoVPerspectiveCameras(
        R=R_pred[i, None], T=T_pred[i, None]
    )

fig = plot_scene(
    scenes,
    camera_scale=0.03,
    ncols=2,
)
fig.update_scenes(aspectmode="data")

for i in range(num_frames):
    fig.data[i].line.color = matplotlib.colors.to_hex(cmap(i / (num_frames)))

fig

In [28]:
# Search for optimal rotation via coordinate ascent
_, hypothesis = evaluate_coordinate_ascent(
    model=model,
    images=batched_images,
    use_all_features=True,
    crop_params=batched_crop_params,
)
R_pred = np.stack(hypothesis)

  0%|          | 0/50 [00:00<?, ?it/s]

In [29]:
# Construct cameras and visualize scene for best solution
cameras_pred = FoVPerspectiveCameras(R=R_pred, T=T_pred)
scenes = {"Predicted Cameras": {}}

for i in range(num_frames):
    scenes["Predicted Cameras"][i] = FoVPerspectiveCameras(
        R=R_pred[i, None], T=T_pred[i, None]
    )

fig = plot_scene(
    scenes,
    camera_scale=0.03,
    ncols=2,
)
fig.update_scenes(aspectmode="data")

for i in range(num_frames):
    fig.data[i].line.color = matplotlib.colors.to_hex(cmap(i / (num_frames)))

fig