In [None]:
import numpy as np
import torch
from torchvision import transforms as T
from models import bird_model, load_regressor
from utils.vis_bird import render_sample, render_sample_new
from datasets import Cowbird_Dataset
from keypoint_detection import load_detector, postprocess

# Reload imported modules before each cell runs, so edits to the project code
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Instantiate the bird model and load the regressor and keypoint detector onto `device`.
# NOTE(review): neither network is switched to eval mode in this cell — confirm
# that load_regressor / load_detector already do so.
bird = bird_model()
regressor = load_regressor().to(device)
predictor = load_detector().to(device)

In [None]:
# Per-channel statistics shared by the forward and inverse transforms.
_CHANNEL_MEAN = [0.406, 0.456, 0.485]
_CHANNEL_STD = [0.225, 0.224, 0.229]

# Forward transform: convert to a tensor, then standardise each channel.
normalize = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
])

# Inverse transform in two steps: first multiply by the std (dividing by 1/std),
# then add the mean back (subtracting -mean).
unnormalize = T.Compose([
    T.Normalize(mean=[0, 0, 0], std=[1 / s for s in _CHANNEL_STD]),
    T.Normalize(mean=[-m for m in _CHANNEL_MEAN], std=[1, 1, 1]),
])

valid_set = Cowbird_Dataset(
    'data/cowbird/images',
    'data/cowbird/annotations/instance_test.json',
    transform=normalize,
)

In [None]:
# Run the detector and the regressor over every sample of `valid_set` and collect
# the raw regressor outputs (p_est, b_est) for the statistics fitted later.
poses = []
trans = []  # never filled; kept so the name still exists for any cell that reads it
bones = []

# Inference only: one no_grad context around the whole loop.
with torch.no_grad():
    for i in range(len(valid_set)):
        imgs, target_kpts, target_masks, meta = valid_set[i]
        imgs = imgs[None]  # add a leading batch dimension

        # Prediction
        output = predictor(imgs.to(device))
        pred_kpts, pred_mask = postprocess(output)

        # Regression: keypoints are flattened per batch element before being fed in.
        kpts_in = pred_kpts.reshape(pred_kpts.shape[0], -1)
        mask_in = pred_mask
        p_est, b_est = regressor(kpts_in, mask_in)

        poses.append(p_est.squeeze().cpu().numpy())
        bones.append(b_est.squeeze().cpu().numpy())

In [None]:
# Show how many pose vectors were collected and preview only the first few;
# printing the whole collection floods the cell output.
print(f"{len(poses)} pose vectors collected")
print(poses[:3])

In [None]:
# Plan: fit a multivariate Gaussian to the estimated pose parameters,
# and add random noise to the bone lengths.

"""
we fit a
multivariate Gaussian to the estimated pose parameters (pose, viewpoint, and
translation). We then sample 100 random points from this distribution for each
bird instance, project the corresponding model's visible keypoints onto the camera 
and render the silhouette, generating 14,000 synthetic instances for training.
We keep the bone lengths of the original 140 instances, but add in random noise
to the bone lengths for each sample.
"""

In [None]:
# Turn the collected per-sample outputs into arrays and display their shapes.
poses, bones = np.asarray(poses), np.asarray(bones)

(poses.shape, bones.shape)

In [None]:
from sklearn import decomposition, mixture
from matplotlib import pyplot as plt


In [None]:
# Sweep the number of PCA components and record how well `poses` is reconstructed
# from each truncated basis.
MAX_COMPONENTS = 113  # largest component count tried by the sweep

# PCA rejects n_components larger than min(n_samples, n_features), so cap the
# sweep at what the data supports instead of failing part-way through.
max_components = min(MAX_COMPONENTS, *poses.shape)

losses = []  # sum of squared reconstruction errors, one entry per component count
i_s = []     # component count matching each entry of `losses`
for i in range(2, max_components + 1):
    pca = decomposition.PCA(n_components=i).fit(poses)
    # Project onto the i fitted components and map back to the original space
    # (inverse_transform adds the fitted mean back).
    Xhat = pca.inverse_transform(pca.transform(poses))
    loss = np.sum((poses - Xhat) ** 2)
    losses.append(loss)
    i_s.append(i)

In [None]:
# Reconstruction error against the number of PCA components kept.
plt.plot(i_s, losses)
plt.xlabel('number of PCA components')
plt.ylabel('sum of squared reconstruction error')

# Report the entries at indices 19, 39 and 79, labelled with the component count
# each one corresponds to, and skip any index the sweep did not reach.
for idx in (19, 39, 79):
    if idx < len(losses):
        print(f"{i_s[idx]} components: {losses[idx]}")

In [None]:
# fit
mu, cov = np.mean(poses, axis=0), np.cov(poses, rowvar=0)
print(mu.shape, cov.shape)

In [None]:
# sample
# Draw a large batch of pose vectors from the fitted Gaussian.
sample = np.random.multivariate_normal(mean=mu, cov=cov, size=100_000)
print(sample.shape)

In [None]:
# Mean of the collected bone outputs, taken over the first axis of `bones`.
mu_bone = np.mean(bones, axis=0)
print(mu_bone.shape)

In [None]:
import torch
from models import bird_model, load_regressor
from utils.vis_bird import render_sample, render_sample_new
from optimization import OptimizeSV, base_renderer
# Build one synthetic instance: draw a pose vector from the fitted Gaussian, perturb
# the first sample's bone output with Gaussian noise, pose the bird model and render it.
BONE_NOISE_STD = 0.2  # standard deviation of the noise added to each bone value

# size=1 already yields a (1, D) array; drawing straight into p_est leaves the
# large `sample` batch from the earlier cell intact.
p_est = np.random.multivariate_normal(mu, cov, size=1)
# One noise value per bone entry, sized from the data instead of a hard-coded 24.
b_est = (bones[0] + np.random.normal(loc=0, scale=BONE_NOISE_STD, size=bones[0].shape))[None, :]
print(p_est.shape, b_est.shape)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
bird = bird_model()

p_est = torch.from_numpy(p_est)
b_est = torch.from_numpy(b_est)

pose, tran, bone = regressor.postprocess(p_est, b_est)
print(pose.shape, tran.shape, bone.shape)

# num_iters=0: the optimizer object is used here only for its transform_p /
# transform_t helpers. Named sv_optimizer so it does not collide with the
# torch optimizer created in the training cell.
sv_optimizer = OptimizeSV(num_iters=0, prior_weight=1, mask_weight=1,
                          use_mask=False, device=device)

global_t = tran.clone()
bone_length = bone.clone()

# The first three columns of the transformed pose are passed to the bird model as
# the global orientation, the remaining columns as the body pose.
init_pose = sv_optimizer.transform_p(pose)
global_orient = init_pose.clone()[:, :3]
body_pose = init_pose.clone()[:, 3:]

bird_output = bird(global_orient, body_pose, bone_length)
global_txyz = sv_optimizer.transform_t(global_t)

# Shift the posed vertices by the global translation (cast to float32 first).
model_mesh = bird_output['vertices'] + global_txyz.unsqueeze(1).to(torch.float)

img_opt, mask = render_sample_new(bird, model_mesh[0])


In [None]:
# Display the rendered image of the synthetic sample.
plt.imshow(img_opt)

In [None]:
# Display the mask returned by the renderer in grayscale.
plt.imshow(mask, cmap='gray')

In [None]:
# Inspect the translation exactly as it is added to the model vertices / keypoints.
print(global_txyz.unsqueeze(1).to(torch.float))

In [None]:
# List the entries returned by the bird model ('vertices' and 'keypoints' are used in this notebook).
print(bird_output.keys())

In [None]:
# Model keypoints shifted by the global translation.
print(bird_output['keypoints'] + global_txyz.unsqueeze(1).to(torch.float))

In [None]:
from utils.renderer_p3d import RendererP3D
from utils.geometry import perspective_projection

# Keypoints of the posed model, shifted by the global translation.
kps = bird_output['keypoints'] + global_txyz.unsqueeze(1).to(torch.float)

# Renderer built from the bird model's faces; its cameras are used below.
render = RendererP3D(faces=bird.dd['F'])
cameras = render.cameras

# Keypoints pushed through the renderer cameras' projection transform.
projection = cameras.get_projection_transform()
print(projection.transform_points(kps))

# Keypoints projected with the project's own helper.
# NOTE(review): 2167 and 128 are presumably the focal length and image centre —
# confirm against perspective_projection's signature.
proj_kps = perspective_projection(kps, None, None, 2167, 128, None)
print(proj_kps)

In [None]:
# Overlay the projected keypoints on the rendered image.
kps_2d = proj_kps.squeeze()
plt.imshow(img_opt)
plt.scatter(kps_2d[:, 0], kps_2d[:, 1])

In [None]:
import torchvision.transforms.functional as F
from utils.img_utils import draw_kpts

# Index of the sample to display; the original note singles out 8, 20, 24 and 29.
i = 0
imgs, target_kpts, target_masks, meta = valid_set[i]

# Undo the input normalisation and convert to a PIL image for display.
imgs = F.to_pil_image(unnormalize(imgs))

# Show the image with its annotated keypoints in red.
plt.imshow(imgs)
plt.scatter(target_kpts[:, 0], target_kpts[:, 1], c='red')

In [None]:
from datasets.syn_dataset import synDataset

# Training split of the real dataset, wrapped by the synthetic dataset.
train_set = Cowbird_Dataset(
    'data/cowbird/images',
    'data/cowbird/annotations/instance_train.json',
    transform=normalize,
)
syn_dataset = synDataset(train_set)

# Display the first synthetic sample.
syn_dataset[0]

In [None]:
# Fetch the first synthetic sample for inspection; the plotting calls below are disabled.
test_int = syn_dataset[0]
# plt.scatter(test_int[2][:, 0], test_int[2][:, 1], c='red')
# plt.imshow(test_int[0])
# print(syn_dataset[0][0].shape)


In [None]:
# Fetch sample 0 of the synthetic dataset again (same statements as the preceding cell);
# the plotting calls below are disabled.
test_int = syn_dataset[0]
# plt.scatter(test_int[2][:, 0], test_int[2][:, 1], c='red')
# plt.imshow(test_int[0])
# print(syn_dataset[0][0].shape)


In [None]:
# Define the model, the loss and the training process.
from models.mesh_regressor import mesh_regressor

NUM_EPOCHS = 20
BATCH_SIZE = 16

# Move the model to the target device before building the optimizer, and use the
# notebook's `device` so the cell also runs on a machine without CUDA.
model = mesh_regressor().to(device)
loss_fn = torch.nn.MSELoss()
dataloader = torch.utils.data.DataLoader(syn_dataset, BATCH_SIZE, shuffle=True, num_workers=4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

model.train()

for epoch in range(NUM_EPOCHS):
    # Reset every epoch so the reported value is this epoch's mean batch loss,
    # not a running total over all epochs so far.
    loss_total = 0

    for i, data in enumerate(dataloader):
        # The first element of each batch is not used for training.
        _, mask, kps, p_gt, b_gt = data

        mask = mask.to(device)
        kps = kps.to(device)
        p_gt = p_gt.to(device)
        b_gt = b_gt.to(device)

        pose_tran, bone = model(kps, mask)
        # Sum of the MSE on the first output against p_gt and on the second against b_gt.
        loss = loss_fn(pose_tran, p_gt) + loss_fn(bone, b_gt)
        loss_total += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    batch_loss = loss_total / len(dataloader)
    print(f"Epoch {epoch+1}: {batch_loss}")

    # One checkpoint per epoch, written to the current working directory.
    torch.save(model.state_dict(), f"./{epoch+1}.pth")

