In [3]:
import torch
import numpy as np
from pathlib import Path
import plotly.graph_objs as go
import quaternion


from initialization import TargetImageDataset

In [4]:
target_image_dataset = TargetImageDataset(
    Path("../data/calculator/raw_images"),
    Path("../data/image_pose.json")
)

fg_V_sdf = torch.load("../data/initial_v_sdf_fg.pt", weights_only=True)
bg_V_sdf = torch.load("../data/initial_v_sdf_bg.pt", weights_only=True)
fg_V_feat = fg_V_sdf.repeat(1, 3, 1, 1, 1)  # R,G,B
bg_V_feat = bg_V_sdf.repeat(1, 3, 1, 1, 1)  # R,G,B
fg_V_feat[:] = 0.5

print(fg_V_feat)
print(len(target_image_dataset))

tensor([[[[[0.5000, 0.5000, 0.5000,  ..., 0.5000, 0.5000, 0.5000],
           [0.5000, 0.5000, 0.5000,  ..., 0.5000, 0.5000, 0.5000],
           [0.5000, 0.5000, 0.5000,  ..., 0.5000, 0.5000, 0.5000],
           ...,
           [0.5000, 0.5000, 0.5000,  ..., 0.5000, 0.5000, 0.5000],
           [0.5000, 0.5000, 0.5000,  ..., 0.5000, 0.5000, 0.5000],
           [0.5000, 0.5000, 0.5000,  ..., 0.5000, 0.5000, 0.5000]],

          [[0.5000, 0.5000, 0.5000,  ..., 0.5000, 0.5000, 0.5000],
           [0.5000, 0.5000, 0.5000,  ..., 0.5000, 0.5000, 0.5000],
           [0.5000, 0.5000, 0.5000,  ..., 0.5000, 0.5000, 0.5000],
           ...,
           [0.5000, 0.5000, 0.5000,  ..., 0.5000, 0.5000, 0.5000],
           [0.5000, 0.5000, 0.5000,  ..., 0.5000, 0.5000, 0.5000],
           [0.5000, 0.5000, 0.5000,  ..., 0.5000, 0.5000, 0.5000]],

          [[0.5000, 0.5000, 0.5000,  ..., 0.5000, 0.5000, 0.5000],
           [0.5000, 0.5000, 0.5000,  ..., 0.5000, 0.5000, 0.5000],
           [0.5000, 0.5000

In [100]:
data, pose = target_image_dataset[0]

rot_mat = quaternion.as_rotation_matrix(pose.qvec)   # 3x3 rot mat

In [176]:
# https://github.com/colmap/colmap/blob/main/src/colmap/sensor/models.h#L715
# https://calib.io/blogs/knowledge-base/camera-models?srsltid=AfmBOoosTUXUe3QZqSrWoJXC9Yr04axC6Mvx7ru4xjo-yHMRf4H_erhx

GRAPH_RESOLUTION = 64.0

px, py = np.mgrid[-pose.width/2:pose.width/2:pose.width/GRAPH_RESOLUTION,
                      -pose.height/2:pose.height/2:pose.height/GRAPH_RESOLUTION]
pz = np.full(px.shape, 1.0)

# ox, oy, oz = np.full(px.shape, 0.0), np.full(px.shape, 0.0), np.full(px.shape, 1.0)

#### FOCAL LENGTH (f)
img_points = np.dstack([(px.flatten()+0.5)/pose.f, (py.flatten()+0.5)/pose.f, pz.flatten()])[0]  # Nx2 points to be rot'd at origin
# img_offset = np.array([0.0, 0.0, 1.0])  # normalized points offset w=1 (z=1)

##### DISTORTION (k)
distortion = pose.k*(img_points[:,0]**2+img_points[:,1]**2)

dx = img_points[:,0]*(1+distortion)
dy = img_points[:,1]*(1+distortion)
dz = img_points[:,2]
distorted_img_points = np.dstack([dx, dy, dz])[0]

##### ROTATION (qvec)
rotated_img_points = np.einsum("ij,nj->ni", rot_mat, distorted_img_points)

rpx, rpy, rpz = np.split(rotated_img_points, 3, axis=1)
rpx, rpy, rpz = rpx.flatten(), rpy.flatten(), rpz.flatten()

In [177]:
STEP_COUNT = 20
STEP_SIZE = 0.2

t = np.expand_dims(1.0+STEP_SIZE + np.arange(0, STEP_COUNT)*STEP_SIZE, axis=1)
expand = lambda p: np.repeat(np.expand_dims(p, axis=0), STEP_COUNT, axis=0)

spx = t*expand(rpx)
spy = t*expand(rpy)
spz = t*expand(rpz)

img_points_c = np.dstack([px.flatten(), py.flatten()])[0]

# path of x_i points for one viewing direction
np.dstack([spx[:, 0], spy[:, 0], spz[:, 0]])

sp = np.array([spx, spy, spz])  # 3xTxN
st = np.transpose(sp, (2, 1, 0))  # NxTx3

In [184]:
stx, sty, stz = np.split(st[50], 3, axis=1)
# img_points_c[50]+[pose.width/2, pose.height/2]
# rpx[50], rpy[50], rpz[50]
# print(st[500000])
img_points_c1 = ((img_points_c+np.array([pose.width/2, pose.height/2])).astype(int))
img_points_c1

array([[   0,    0],
       [   0,   60],
       [   0,  120],
       ...,
       [2126, 3660],
       [2126, 3720],
       [2126, 3780]])

In [199]:
color2 = data.data.reshape(-1, 3)
((px.flatten()+pose.width/2).astype(int))
((py.flatten()+pose.height/2).astype(int))

array([   0,   60,  120, ..., 3660, 3720, 3780])

In [209]:
color3 = data.data[((py.flatten()+pose.height/2).astype(int)), ((px.flatten()+pose.width/2).astype(int))]
color3 = color3.astype(float)/255

In [210]:
print(st[500])
print(color3[500])
print(len(st))
print(len(color3))

[[-0.15818907  0.18253968  1.20027103]
 [-0.18455391  0.21296297  1.4003162 ]
 [-0.21091876  0.24338625  1.60036137]
 [-0.2372836   0.27380953  1.80040654]
 [-0.26364845  0.30423281  2.00045171]
 [-0.29001329  0.33465609  2.20049689]
 [-0.31637814  0.36507937  2.40054206]
 [-0.34274298  0.39550265  2.60058723]
 [-0.36910782  0.42592593  2.8006324 ]
 [-0.39547267  0.45634921  3.00067757]
 [-0.42183751  0.48677249  3.20072274]
 [-0.44820236  0.51719577  3.40076791]
 [-0.4745672   0.54761905  3.60081308]
 [-0.50093205  0.57804233  3.80085826]
 [-0.52729689  0.60846562  4.00090343]
 [-0.55366174  0.6388889   4.2009486 ]
 [-0.58002658  0.66931218  4.40099377]
 [-0.60639143  0.69973546  4.60103894]
 [-0.63275627  0.73015874  4.80108411]
 [-0.65912112  0.76058202  5.00112928]]
[0.51372549 0.64705882 0.69803922]
4096
4096


In [181]:
color = []
for x, y in img_points_c1:
    # print(x, y)
    color.append(data.data[y][x])

color = np.array(color)/255
color

array([[0.77647059, 0.68627451, 0.67843137],
       [0.62745098, 0.56078431, 0.57647059],
       [0.85098039, 0.82745098, 0.82352941],
       ...,
       [0.33333333, 0.4627451 , 0.54117647],
       [0.32156863, 0.43921569, 0.51372549],
       [0.29803922, 0.38039216, 0.45098039]])

In [182]:
# data.data[810][1350]
# print(pose.width, pose.height)
# print(img_points_c+np.array([pose.width/2, pose.height/2]))
# px+pose.width/2
# py+pose.height/2

In [208]:
go.Figure(data=[
    go.Scatter3d(
        x=[0],
        y=[0],
        z=[0],
        mode='markers',
        marker={
            'size': 10,
            'opacity': 1.0,
        }
    ),
    go.Scatter3d(
        x=rpx,
        y=rpy,
        z=rpz,
        mode='markers',
        marker={
            'size': 5,
            'opacity': 1.0,
            "color": color3,
        }
    ),
    go.Scatter3d(
        x=spx.flatten(),
        y=spy.flatten(),
        z=spz.flatten(),
        mode='markers',
        marker={
            'size': 1,
            'opacity': 1.0,
        }
    ),
    go.Scatter3d(
        x=stx.flatten(),
        y=sty.flatten(),
        z=stz.flatten(),
        mode='markers',
        marker={
            'size': 1,
            'opacity': 1.0,
        }
    ),
], layout=go.Layout(
    margin={'l': 0, 'r': 0, 'b': 0, 't': 0},
    scene=dict(
        aspectmode="cube",
        xaxis=dict(range=[-5, 5]),  # Set x-axis range
        yaxis=dict(range=[-5, 5]),  # Set y-axis range
        zaxis=dict(range=[-5, 5])   # Set z-axis range
    )
)) # 1.0529453 ,  0.47962357, -0.76794111