In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import cv2, json, torch, os, re

import util, constants, draw, cam, model, parse_alphapose

# Seed both RNGs used in this notebook so stochastic cells are reproducible.
np.random.seed(0)
torch.manual_seed(0)

# Dataset locations. VID_ROOT is derived from DATA_ROOT instead of repeating the
# literal; the trailing '' keeps the trailing separator, so VID_ROOT has the same
# value as before for any code that concatenates onto it.
# NOTE(review): this is a machine-specific absolute path — consider making it configurable.
DATA_ROOT = '/home/akarshkumar0101/Insync/akarshkumar0101@gmail.com/Google Drive/nba-3d-data/'
VID_ROOT = os.path.join(DATA_ROOT, 'harden', '')

# Image size is read from the first frame. plt.imread returns (rows, cols, ...),
# i.e. (y, x); the reversed tuple is (x, y).
first_frame_path = os.path.join(VID_ROOT, 'all_views', 'frame_00001.png')
img_shape_yx = plt.imread(first_frame_path).shape[:2]
img_shape_xy = img_shape_yx[::-1]

# Intrinsic matrices from the cam module: the default one and one built for this image shape.
uf_mat_int_default = cam.get_mat_intrinsic()
uf_mat_int = cam.get_intrinsic_mat_for_img_shape(img_shape_xy)

print(f'img_shape_xy: {img_shape_xy}')

img_shape_xy: (1280, 720)


# Loss function

In [None]:
"""
alpha= 0.0 will take the mean.
alpha= 0.2 will be very close to max.
alpha=-0.2 will be very close to min.
"""
def smooth_max(x, alpha, dim=-1):
    """Softmax-weighted average of ``x`` along ``dim``.

    The weights are ``softmax(alpha * x)``: ``alpha = 0`` yields the plain mean,
    increasingly positive ``alpha`` approaches the max and increasingly negative
    ``alpha`` approaches the min (how fast depends on the spread of ``x``).

    Args:
        x: tensor to reduce.
        alpha: sharpness / sign of the weighting.
        dim: dimension to reduce over (removed from the result).

    Returns:
        Tensor with ``dim`` reduced away.
    """
    # The softmax form is used on purpose: the direct ratio
    # (x*exp(alpha*x)).sum() / exp(alpha*x).sum() is numerically unstable.
    weights = (alpha * x).softmax(dim=dim)
    weighted = weights * x
    return weighted.sum(dim=dim)

def pt2pt_dist(p1, p2):
    """Distance from every point in ``p1`` to its nearest point in ``p2``.

    Args:
        p1: tensor of shape (..., N1, D).
        p2: tensor of shape (..., N2, D).

    Returns:
        Tensor of shape (..., N1) holding, for each point of ``p1``, the smallest
        Euclidean distance to any point of ``p2``.
    """
    # Broadcast to (..., N1, N2, D) so every p1/p2 pair is compared.
    pair_diffs = p1.unsqueeze(-2) - p2.unsqueeze(-3)
    pair_dists = pair_diffs.norm(dim=-1)
    nearest = pair_dists.min(dim=-1)
    return nearest.values

def pt2pt_dist_soft(p1, p2, alpha=-.2):
    """Soft variant of the nearest-point distance from ``p1`` to ``p2``.

    Computes all pairwise Euclidean distances and reduces over the ``p2`` axis with
    ``smooth_max`` instead of a hard ``min``.

    Args:
        p1: tensor of shape (..., N1, D).
        p2: tensor of shape (..., N2, D).
        alpha: forwarded to ``smooth_max``; the default is negative.

    Returns:
        Tensor of shape (..., N1).
    """
    # (..., N1, N2) matrix of pairwise distances.
    pair_dists = (p1.unsqueeze(-2) - p2.unsqueeze(-3)).norm(dim=-1)
    return smooth_max(pair_dists, alpha=alpha, dim=-1)

def fitness(dofs_cam, X_i_true):
    """Scalar cost of a candidate camera (lower is better).

    Projects the global point set ``X_w_down`` with ``cam.project_to_cam`` and sums
    three terms:
      * ``-log`` of the fraction of points kept by the visibility mask,
      * the mean ``pt2pt_dist`` from the kept projected points to ``X_i_true``,
      * the mean of ``exp(-dofs_cam[..., 6])`` (named ``loss_small_f`` — presumably
        entry 6 is a focal-length parameter; confirm against the cam module).

    Args:
        dofs_cam: camera parameters passed to ``cam.project_to_cam``.
        X_i_true: target points that the projected points are compared with.

    Returns:
        The summed loss as a float, or ``np.inf`` when no point passes the mask.
    """
    projected, visible = cam.project_to_cam(X_w_down, dofs_cam)
    projected = projected[visible]

    n_visible = projected.shape[0]
    if n_visible == 0:
        # Nothing visible: treat this candidate as the worst possible one.
        return np.inf

    visible_fraction = n_visible / X_w_down.shape[0]
    loss_vis = -np.log(visible_fraction)
    loss_pt2pt = pt2pt_dist(projected, X_i_true).mean().item()
    loss_small_f = (-dofs_cam[..., 6]).exp().mean().item()
    return loss_vis + loss_pt2pt + loss_small_f

# Energy Minimization with Population Based Search

In [None]:
import cma

In [None]:
# Re-seed torch so the sampled "ground truth" vector below is the same on every run.
torch.manual_seed(0)
# Draw one 8-dimensional sample from a standard normal (zero mean, identity covariance).
x = torch.distributions.MultivariateNormal(torch.zeros(8), torch.eye(8)).sample()
# x[6] = 1.
# x[7] = 0.

# Map the sample to camera DOFs. gaussian2cam is not defined in this section of the
# notebook; only its first return value is used here.
dofs_cam_true, _ = gaussian2cam(x)

In [None]:
# Draw the camera view of X_w_down for the ground-truth camera DOFs.
draw.show_cam_view(X_w_down, dofs_cam_true)

In [None]:
# Project X_w_down with the ground-truth camera and keep only the entries selected by
# the returned mask; X_i_true is the target passed to fitness() during the search.
X_i_true, vis_mask = cam.project_to_cam(X_w_down, dofs_cam_true)
X_i_true = X_i_true[vis_mask]

In [None]:
# CMA-ES over the 8-dimensional vector consumed by gaussian2cam, started at the
# origin with initial step size 1.0.
es = cma.CMAEvolutionStrategy(8*[0], 1.0)

# Plot every `plot_every` generations (1 = every generation). Replaces the former
# `i % 1 == 0` test, which was always true.
plot_every = 1

i = 0
while not es.stop():
    solutions = es.ask(100)
    print(i)
    # Score each candidate: vector -> camera DOFs -> fitness against the true projection.
    fitnesses = [fitness(gaussian2cam(torch.tensor(x).float())[0], X_i_true) for x in solutions]
    es.tell(solutions, fitnesses)
    i += 1

    # fitness() returns inf when a candidate sees no points; until at least one finite
    # value has been told, es.best.x may still be unset, so skip the plot in that case.
    if i % plot_every == 0 and es.best.x is not None:
        draw.show_cam_view(X_w, dofs_cam_true, label='ground truth')
        draw.show_cam_view(X_w, gaussian2cam(torch.tensor(es.best.x).float())[0], label='best so far')
        plt.legend(bbox_to_anchor=(1.4, 1.))
        plt.show()

In [None]:
# Distribution of Euclidean distances between one standard-normal 8-vector and
# 1000 other independent standard-normal 8-vectors.
std_normal_8d = torch.distributions.MultivariateNormal(torch.zeros(8), torch.eye(8))
x_true = std_normal_8d.sample()
x = std_normal_8d.sample((1000,))

offsets = x - x_true
dists = offsets.norm(dim=-1)
plt.hist(dists.numpy())

In [None]:
# es.result_pretty()
# cma.plot()

# Energy Minimization with Supervised Rendering Loss


In [None]:
# Ground-truth camera DOFs built by cam.calc_dofs_cam.
dofs_cam = cam.calc_dofs_cam(
    torch.tensor([.6, .6, 0.6]),
    torch.tensor([0., 0., 0.]),
    fxy=torch.tensor([0., 0.]),
)

# Initial guess for the optimizer: the ground truth plus seeded Gaussian noise
# (std 0.1). The addition already yields a new tensor, so dofs_cam is untouched.
torch.manual_seed(10)
dofs_cam_optim = dofs_cam + 0.1*torch.randn(dofs_cam.shape)

draw.show_cam_view(X_w, dofs_cam, label='ground truth')
draw.show_cam_view(X_w, dofs_cam_optim, label='initial')
plt.legend(bbox_to_anchor=(1.4, 1.))

In [None]:
# Only the first six camera DOFs are optimized; the last two entries are held at zero.
# detach().clone() makes the parameter a fresh leaf tensor. Without it, re-running
# this cell slices the torch.cat result from the previous run (a non-leaf tensor),
# which the optimizer refuses to accept.
dofs_cam_optim_p = dofs_cam_optim[:6].detach().clone().requires_grad_()
opt = torch.optim.Adam([{'params': dofs_cam_optim_p, 'lr': 0.01}])

# Constant tail of the DOF vector, built once instead of on every iteration.
dofs_cam_fixed = torch.tensor([0., 0.])

# Target: projection of the point set under the ground-truth camera.
# NOTE(review): other cells call cam.project_to_cam — confirm util.project_to_cam is
# still the intended function here.
X_true, vis_mask = util.project_to_cam(X_w_down, dofs_cam)
X_true = X_true[vis_mask]

for i in tqdm(range(40)):
    dofs_cam_optim = torch.cat([dofs_cam_optim_p, dofs_cam_fixed])

    X_i, vis_mask = util.project_to_cam(X_w_down, dofs_cam_optim)
    X_i = X_i[vis_mask]
    if X_i.shape[0] == 0:
        # The mean over zero points is NaN; back-propagating it would corrupt the
        # parameters and Adam's state, so stop instead.
        print('no visible points for the current camera; stopping optimization')
        break

    # Mean nearest-point distance from the current projection to the target.
    # (The previously computed but unused soft-distance term has been removed.)
    loss1 = pt2pt_dist(X_i, X_true)
    loss = loss1.mean()
    loss_value = loss.item()
    print('loss: ', loss_value)
    # Same quantity as the loss while it has a single term; computed once.
    print('pt2pt: ', loss_value)
    print(dofs_cam_optim[6:])
    print()

    opt.zero_grad()
    loss.backward()
    opt.step()

    if i%10==0:
        # Labels are required for plt.legend to have entries to show.
        draw.show_cam_view(X_w, dofs_cam, label='ground truth')
        draw.show_cam_view(X_w, dofs_cam_optim.detach(), label='optimized')
        plt.legend(bbox_to_anchor=(1.4, 1.))
        plt.show()

In [None]:
# Rebind dofs_cam_optim to a tensor detached from the autograd graph before reusing it.
dofs_cam_optim = dofs_cam_optim.detach()

In [None]:
# Compare the ground-truth camera with the camera obtained from the optimization.
draw.show_cam_view(X_w, dofs_cam, label='ground truth')
# Labelled 'optimized' rather than 'initial': at this point dofs_cam_optim holds the
# optimized DOFs, not the perturbed starting guess.
draw.show_cam_view(X_w, dofs_cam_optim, label='optimized')
plt.legend(bbox_to_anchor=(1.4, 1.))