In [None]:
#Imports and dependencies
import os
from os.path import join, abspath, dirname
import sys
sys.path.insert(0, abspath(join("..", dirname(os.getcwd()))))

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import time
import pandas as pd
import random
import imageio
from skimage import img_as_ubyte
from scipy.ndimage.morphology import binary_dilation
from itertools import product
from typing import List
from tqdm import tqdm_notebook
from pytorch3d.io import load_obj, save_obj
from pytorch3d.structures import Meshes
from pytorch3d.transforms import Rotate, Translate
from pytorch3d.utils import ico_sphere
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.renderer import (
    SfMPerspectiveCameras, OpenGLPerspectiveCameras, look_at_view_transform, look_at_rotation,
    RasterizationSettings, MeshRenderer, MeshRasterizer, BlendParams,
    SoftSilhouetteShader, HardPhongShader, PointLights, TexturesVertex, HardFlatShader
)
from pytorch3d.loss import (
    mesh_laplacian_smoothing, 
    mesh_normal_consistency
)
from torchvision import transforms

from dataclasses import dataclass, field, asdict, astuple
import numpy as np
#Plotting Libs
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import matplotlib as mpl
from datetime import datetime
import time
from copy import deepcopy

from plotly.subplots import make_subplots
import plotly.graph_objects as go
import plotly.express as px

import json

from utils.visualization import plot_pointcloud
from utils.shapes import Sphere, SphericalSpiral
from utils.manager import RenderManager, ImageManager

In [None]:
# Render matplotlib figures (both inline and saved to disk) at 90 dpi.
mpl.rcParams.update({'savefig.dpi': 90, 'figure.dpi': 90})
# Set the device: first GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Compare the device *type* (comparing a torch.device to a str is never True),
# and pass the device explicitly: torch.cuda.set_device requires an argument.
if device.type == "cuda":
    torch.cuda.set_device(device)

### Create a Renderer

In [None]:
# Output resolution (H, W); only img_size[0] is passed to the rasterizers, so
# renders are square.
img_size = (128, 128)

# Default camera bound to both rasterizers below. Later cells pass `cameras=` at
# render time for each batch of views.
cameras = SfMPerspectiveCameras(device=device)

# Soft blending parameters for the silhouette renderer: sigma controls how sharp
# face edges are and gamma controls opacity blending. Refer to PyTorch3D's
# blending.py for details.
blend_params = BlendParams(sigma=1e-4, gamma=1e-4)

# Rasterization settings for the differentiable silhouette renderer. The image is
# img_size[0] x img_size[0] (128x128 here), and up to 100 faces are blended per
# pixel. blur_radius = log(1/1e-4 - 1) * sigma expands each face up to the
# distance where its sigmoid-blended contribution falls to about 1e-4 (assuming
# SoftSilhouetteShader's sigmoid blending). bin_size / max_faces_per_bin are left
# at their RasterizationSettings defaults. Refer to PyTorch3D's
# rasterize_meshes.py and docs/notes/renderer.md for the meaning of these
# parameters.
raster_settings = RasterizationSettings(
    image_size=img_size[0], 
    blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma, 
    faces_per_pixel=100, 
)

# Create a silhouette mesh renderer by composing a rasterizer and a shader. 
silhouette_renderer = MeshRenderer(
    rasterizer=MeshRasterizer(
        cameras=cameras, 
        raster_settings=raster_settings
    ),
    shader=SoftSilhouetteShader(blend_params=blend_params)
)


# A hard-shaded renderer for visual inspection. It only needs one face per pixel.
# Despite the name `phong_renderer`, it uses HardFlatShader (flat shading).
# `raster_settings` is rebound here; silhouette_renderer keeps its own settings
# object from above.
raster_settings = RasterizationSettings(
    image_size=img_size[0], 
    blur_radius=1e-5, 
    faces_per_pixel=1, 
)
# A single white point light at (3, 3, 0) in world coordinates.
lights = PointLights(
    device=device, 
    location=[[3.0, 3.0, 0.0]], 
    diffuse_color=((1.0, 1.0, 1.0),),
    specular_color=((1.0, 1.0, 1.0),),
)
phong_renderer = MeshRenderer(
    rasterizer=MeshRasterizer(
        cameras=cameras, 
        raster_settings=raster_settings
    ),
    shader=HardFlatShader(device=device, lights=lights, cameras=cameras)
)

### Event Renderer

In [None]:
# Pixel values for ON / OFF events.
# NOTE(review): only referenced by a commented-out line in event_renderer, so
# they are currently unused.
ON = 254
OFF = 0
# Frame-difference cutoff used by event_renderer: 254.5/255 ≈ 0.998.
# Presumably the frames are on a [0, 1] intensity scale; TODO confirm.
threshold = 254.5/255

def gray(img):
    """Convert an RGB(A) image to grayscale luminance.

    Only the first three channels of the last axis are used; any extra channel
    (e.g. alpha) is ignored. Returns an array with the channel axis removed.
    """
    luma_weights = [0.2989, 0.5870, 0.1140]
    rgb = img[..., :3]
    return np.dot(rgb, luma_weights)

def event_renderer(img1, img2, render_type):
    """Approximate an event frame from two consecutive renders.

    For ``render_type == "phong"`` both images are first converted to grayscale.
    Other render types are used as-is (presumably already single-channel, e.g.
    silhouettes; TODO confirm). Returns ``tanh`` of the difference
    ``img2 - img1``, with every entry that is not below the module-level
    ``threshold`` set to zero.
    """
    if render_type == "phong":
        img1, img2 = gray(img1), gray(img2)

    delta = img2 - img1
    # NOTE(review): this keeps differences *below* `threshold` (including every
    # negative change) and zeroes the large positive ones. An event model would
    # usually keep |delta| above a threshold, so confirm the intended polarity.
    below_threshold = delta < threshold
    return np.tanh(delta * below_threshold)

### Create a trajectory and Render

### Diff Model

In [None]:
def neg_iou_loss(predict, target):
    """Soft IoU loss: one minus the mean per-sample soft intersection-over-union.

    All dimensions except the first (per-sample) one are summed out, so
    ``predict`` and ``target`` should have matching shapes. Returns a scalar
    tensor; 0 means perfect overlap.
    """
    reduce_dims = tuple(range(1, predict.ndimension()))
    overlap = predict * target
    intersection = overlap.sum(reduce_dims)
    # The 1e-6 avoids 0/0 when both masks are empty.
    union = (predict + target - overlap).sum(reduce_dims) + 1e-6
    per_sample_iou = intersection / union
    return 1.0 - per_sample_iou.sum() / intersection.numel()


def intersection_and_union(pred, target, batch=None, num_classes: int = 2):
    r"""Computes per-class intersection and union counts of label predictions.

    Args:
        pred (Tensor): Predicted class labels. They are cast to int64 (so a
            float {0., 1.} mask works) and one-hot encoded with ``num_classes``.
        target (Tensor): Target class labels, same shape as ``pred``.
        batch (LongTensor, optional): The assignment vector which maps each
            entry along the first dimension of ``pred``/``target`` to an
            example. When given, counts are summed per example instead of over
            the whole first dimension.
        num_classes (int): Number of classes for the one-hot encoding
            (default 2).

    Returns:
        (intersection, union) as int64 CPU tensors. Without ``batch`` their shape
        is ``pred.shape[1:] + (num_classes,)``; with ``batch`` it is
        ``(batch.max() + 1,) + pred.shape[1:] + (num_classes,)``.
    """
    # Move everything to the CPU so pred, target and batch always share a device,
    # even when target lives on the GPU.
    pred = F.one_hot(pred.detach().cpu().to(torch.int64), num_classes)
    target = F.one_hot(target.detach().cpu().to(torch.int64), num_classes)
    inter = pred & target
    union = pred | target

    if batch is None:
        return inter.sum(dim=0), union.sum(dim=0)

    # Per-example sums using plain torch (no torch_scatter dependency).
    batch = batch.detach().cpu().to(torch.int64)
    num_examples = int(batch.max().item()) + 1 if batch.numel() > 0 else 0
    out_shape = (num_examples,) + tuple(inter.shape[1:])
    i = inter.new_zeros(out_shape).index_add_(0, batch, inter)
    u = union.new_zeros(out_shape).index_add_(0, batch, union)
    return i, u



def mean_iou(pred, target, batch=None):
    r"""Computes the mean intersection over union score of predictions.

    Args:
        pred (Tensor): The predicted labels.
        target (Tensor): The target labels.
        batch (LongTensor, optional): The assignment vector which maps each
            pred-target pair to an example.

    Classes absent from both ``pred`` and ``target`` (0/0 = NaN) count as a
    perfect IoU of 1. The per-class IoUs are averaged over the last dimension.

    :rtype: :class:`Tensor`
    """
    inter, union = intersection_and_union(pred, target, batch)
    ratio = inter.float() / union.float()
    ratio = torch.where(torch.isnan(ratio), torch.ones_like(ratio), ratio)
    return ratio.mean(dim=-1)

In [None]:
class MeshDeformationModel(nn.Module):
    """Learnable per-vertex offsets applied to a fixed template mesh.

    The only trainable parameter is ``deform_verts``: one xyz offset per
    (packed) template vertex, initialised to zero. ``forward`` applies the
    offsets, attaches a constant white vertex texture and replicates the result
    ``batch_size`` times so it can be rendered from several cameras at once.

    Args:
        device: Device the template mesh, buffers and parameter are created on.
        template_mesh: Optional PyTorch3D ``Meshes``; only its first mesh is
            used. Defaults to a level-2 icosphere.
    """

    def __init__(self, device, template_mesh = None):
        super().__init__()

        self.device = device

        # Explicit `is None` check instead of relying on the mesh's truthiness.
        if template_mesh is None:
            template_mesh = ico_sphere(2, device)

        verts, faces = template_mesh.get_mesh_verts_faces(0)
        verts = verts.to(self.device)
        faces = faces.to(self.device)
        # Uniform white per-vertex colour, shape (1, V, 3).
        verts_rgb = torch.ones_like(verts)[None]
        textures = TexturesVertex(verts_rgb)
        self.template_mesh = Meshes(
            verts=[verts],
            faces=[faces],
            textures=textures
        )

        # NOTE(review): `vertices` (template scaled by 1.3) is registered but not
        # used in forward. It is kept so existing state_dicts still load.
        self.register_buffer('vertices', self.template_mesh.verts_padded() * 1.3)
        self.register_buffer('faces', self.template_mesh.faces_padded())
        self.register_buffer('textures', textures.verts_features_padded())

        # Zero offsets mean the model starts out reproducing the template exactly.
        self.deform_verts = nn.Parameter(
            torch.zeros_like(self.template_mesh.verts_packed())
        )

    def forward(self, batch_size):
        """Deform the template and replicate it ``batch_size`` times.

        Returns:
            (deformed_meshes, laplacian_loss, flatten_loss):
            - deformed_meshes: a ``Meshes`` batch of ``batch_size`` identical meshes.
            - laplacian_loss: the uniform Laplacian smoothing loss.
            - flatten_loss: the normal-consistency loss.
            Both losses are computed on a single copy of the mesh.
        """
        offset_mesh = self.template_mesh.offset_verts(self.deform_verts)
        texture = TexturesVertex(self.textures)
        deformed_mesh = Meshes(
            verts=offset_mesh.verts_padded(),
            faces=offset_mesh.faces_padded(),
            textures=texture,
        )
        deformed_meshes = deformed_mesh.extend(batch_size)

        laplacian_loss = mesh_laplacian_smoothing(deformed_mesh, method="uniform")
        flatten_loss = mesh_normal_consistency(deformed_mesh)

        return deformed_meshes, laplacian_loss, flatten_loss
    

# IMAGE-BASED
### Sampling helpers for choosing training views

In [None]:
def sample_tensor(t, batch_size, indices = None, strategy: str = "random", section: int = 6):
    """Select ``batch_size`` entries along the first axis of ``t``.

    If ``indices`` is given (and non-empty) those rows are returned unchanged.
    This lets several tensors be sampled consistently: sample one tensor with a
    strategy, then pass the returned indices when sampling the others.

    Strategies (used only when ``indices`` is None or empty):
      * ``"random"``: evenly strided indices across the whole tensor, starting at
        a random offset.
      * ``"random-section"``: split the tensor into ``section`` equal parts, pick
        one at random and draw ``batch_size`` distinct random indices from it.
      * ``"sequential-section"``: pick a random part as above and take a
        contiguous run of ``batch_size`` indices from inside it.

    Args:
        t: Tensor or array indexed along its first axis.
        batch_size: Number of entries to select.
        indices: Optional precomputed indices to reuse.
        strategy: One of the strategies above.
        section: Number of equal parts used by the section strategies.

    Returns:
        ``(t[indices], indices)``, where ``indices`` is a list of ints when
        generated here.

    Raises:
        ValueError: if ``t`` has fewer than ``batch_size`` rows or ``strategy``
            is unknown.
    """
    length = t.shape[0]
    if length < batch_size:
        raise ValueError(
            f"cannot sample {batch_size} entries from a tensor with {length} rows")
    if indices is not None and len(indices) > 0:
        return t[indices], indices

    if strategy == "random":
        step = length // batch_size
        start = random.randint(0, step - 1)
        # The stride yields more than batch_size indices whenever length is not
        # a multiple of batch_size, so trim to exactly batch_size.
        indices = list(range(start, length, step))[:batch_size]
    elif strategy in ("random-section", "sequential-section"):
        section_size = length / section
        section_idx = random.randint(0, section - 1)
        candidates = list(range(int(section_idx * section_size),
                                int((section_idx + 1) * section_size)))
        if strategy == "random-section":
            indices = random.sample(candidates, batch_size)
        else:
            start = random.randint(0, len(candidates) - batch_size)
            indices = candidates[start:start + batch_size]
    else:
        raise ValueError(f"unknown sampling strategy: {strategy!r}")
    return t[indices], indices

## Define experiments (sampling strategy × batch size)

In [None]:
# Output folder for this sweep. One sub-folder is created per experiment and per run.
tests_path = "../data/tests/DiffRecon-SampleStrategy15-Dolphin"
# Number of training views that each sampling strategy selects from the trajectory.
fixed_sample_size = 15
# Noisy/clean ratios for the clean/noisy-mix sweep described below (not used by
# the current experiments). The last entry is 100% (was mistyped as `.100`, i.e. 10%).
ratios = [.25, .50, .75, 1.00]
number_of_runs = 5

# Previously explored configurations (not run):
#   * "random_<k>" for k in {1, 2, 3, 4, 7, 8, 10, 15, 18}: `number_of_runs`
#     sorted `random.choices` index sets of size k, drawn from range(360) or
#     range(fixed_sample_size).
#   * Clean/noisy mixes: for each outlier group in `<tests_path>/indices.json`
#     and each ratio in `ratios`, int((1 - ratio) * batch_size) indices are
#     drawn from the non-outlier views and the rest from that group.
#   * "*_strategy_1" variants of the strategy experiments below.

# (prefix, sample_tensor strategy) pairs evaluated for each batch size.
sampling_strategies = [
    ("random_strategy", "random"),
    ("random_section_strategy", "random-section"),
    ("sequential_section_strategy", "sequential-section"),
]

# name -> (one sampling strategy per run, optimisation batch size)
experiments = {
    f"{prefix}_{size}": ([strategy] * number_of_runs, size)
    for size in (5, 8)
    for prefix, strategy in sampling_strategies
}

## Run Optimization

In [None]:
# Run every experiment `number_of_runs` times. Each run writes the following to
# <tests_path>/<experiment_name>/<run>/:
#   * mesh.obj
#   * projection_loss.gif (training progress)
#   * final_mesh.gif (final mesh over the whole trajectory)
#   * one PNG per loss curve
#   * results.json
# Each experiment folder also gets a results.json with the run averages.
num_iterations = 2000  # optimisation steps per run
log_every = 100        # append a GIF frame / draw the live plot every N steps

weight_silhouette = 1
weight_laplacian = .1
weight_flatten = .001


def plot_and_save(elems: list, name: str, out_dir: str = None):
    """Plot one loss curve and save it as ``<out_dir>/<name>_loss.png``.

    ``out_dir`` defaults to the current run folder (global ``path``) so the
    two-argument call form keeps working.
    """
    if out_dir is None:
        out_dir = path
    fig, ax = plt.subplots()
    ax.plot(elems, label=f"{name} Loss")
    ax.legend(fontsize="16")
    ax.set_xlabel("Iteration")
    ax.set_ylabel("Loss")
    ax.set_title("Loss vs iterations")
    fig.savefig(join(out_dir, f"{name}_loss.png"))
    plt.close(fig)


def to_uint8(image):
    """Map a float image with values in [0, 1] to uint8 for the GIF writers."""
    return (255 * np.clip(image, 0.0, 1.0)).astype(np.uint8)


for experiment_name, experiment_tuple in experiments.items():

    experiment_path = join(tests_path, experiment_name)
    os.makedirs(experiment_path, exist_ok=True)

    strategies, batch_size = experiment_tuple

    for num, strategy in enumerate(strategies):

        print(strategy, batch_size)

        path = join(experiment_path, str(num))
        os.makedirs(path, exist_ok=True)

        # Fresh model and optimizer for every run so runs are independent.
        model = MeshDeformationModel(device).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=.001, betas=(0.5, 0.99))  # Hyperparameter tuning

        render = RenderManager.from_path("../data/renders/test1_dolphin/001-dolphin_2020-10-12T11:14:51")
        all_R, all_T = render._trajectory
        # Ground-truth silhouettes for the whole trajectory, scaled by 1/255. They are
        # loaded once and reused for view sampling and for the final evaluation.
        all_images_gt = render._images(type_key="silhouette", img_size=img_size).to(device) / 255.

        # Select this run's `fixed_sample_size` training views with `strategy`. The same
        # indices are applied to rotations, translations and images.
        R, sample_indices = sample_tensor(all_R, fixed_sample_size, strategy=strategy)
        print(sample_indices)
        T, _ = sample_tensor(all_T, fixed_sample_size, sample_indices)
        images_gt, _ = sample_tensor(all_images_gt, fixed_sample_size, sample_indices)

        results = {'indices': sample_indices}

        # Per-iteration weighted losses, stored as plain Python floats.
        silhouette_losses = []
        laplacian_losses = []
        flatten_losses = []

        start_time = time.time()
        with imageio.get_writer(join(path, "projection_loss.gif"), mode='I', duration=0.1) as writer:
            loop = tqdm_notebook(range(num_iterations))
            for i in loop:
                # Mini-batch of training views, drawn with replacement.
                batch_indices = sorted(random.choices(range(len(images_gt)), k=batch_size))
                batch_cameras = SfMPerspectiveCameras(device=device, R=R[batch_indices], T=T[batch_indices])
                batch_images = images_gt[batch_indices]

                mesh, laplacian_loss, flatten_loss = model(batch_size)
                images_pred = silhouette_renderer(mesh.clone(), device=device, cameras=batch_cameras)

                # The last channel of the silhouette render is compared to the GT masks.
                silhouette_loss = neg_iou_loss(batch_images, images_pred[..., -1])
                loss = (silhouette_loss * weight_silhouette
                        + laplacian_loss * weight_laplacian
                        + flatten_loss * weight_flatten)

                loop.set_description('Optimizing (loss %.4f)' % loss.item())

                # Store floats, not tensors. Tensors would keep autograd history alive,
                # and matplotlib cannot convert tensors that require grad or live on the GPU.
                silhouette_losses.append((silhouette_loss * weight_silhouette).item())
                laplacian_losses.append((laplacian_loss * weight_laplacian).item())
                flatten_losses.append((flatten_loss * weight_flatten).item())

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if i % log_every == 0:
                    # Snapshot of the first predicted silhouette in the batch.
                    image = images_pred.detach().cpu().numpy()[0][..., -1]
                    writer.append_data(to_uint8(image))

                    fig, (ax1, ax2) = plt.subplots(1, 2)

                    ax1.imshow(img_as_ubyte(image))
                    ax1.set_title("Deformed Mesh")

                    ax2.plot(silhouette_losses, label="Silhouette Loss")
                    ax2.plot(laplacian_losses, label="Laplacian Loss")
                    ax2.plot(flatten_losses, label="Flatten Loss")
                    ax2.legend(fontsize="16")
                    ax2.set_xlabel("Iteration", fontsize="16")
                    ax2.set_ylabel("Loss", fontsize="16")
                    ax2.set_title("Loss vs iterations", fontsize="16")

                    plt.show()
                    plt.close(fig)

        # Time only the optimisation loop, not the evaluation and file I/O below.
        iterations_per_second = num_iterations / (time.time() - start_time)

        # Mesh after the final optimizer step; no gradients are needed from here on.
        with torch.no_grad():
            final_mesh, _, _ = model(1)

        # Save obj, individual losses and the mesh similarity metric.
        verts, faces = final_mesh.get_mesh_verts_faces(0)
        save_obj(join(path, "mesh.obj"), verts, faces)

        plot_and_save(silhouette_losses, "silhouette", path)
        plot_and_save(laplacian_losses, "laplacian", path)
        plot_and_save(flatten_losses, "flatten", path)

        results["silhouette_loss"] = silhouette_losses
        results["laplacian_loss"] = laplacian_losses
        results["flatten_loss"] = flatten_losses

        # Render the final mesh from every trajectory view, `batch_size` views at a
        # time (the last chunk may be shorter).
        all_images_pred = []
        with torch.no_grad(), imageio.get_writer(join(path, "final_mesh.gif"), mode="I", duration=.1) as gif:
            for start in range(0, len(all_R), batch_size):
                chunk_R = all_R[start:start + batch_size]
                chunk_T = all_T[start:start + batch_size]
                chunk_cameras = SfMPerspectiveCameras(device=device, R=chunk_R, T=chunk_T)
                images = silhouette_renderer(final_mesh.extend(len(chunk_R)), device=device, cameras=chunk_cameras)
                for img in images:
                    gif.append_data(to_uint8(img[..., -1].cpu().numpy()))
                all_images_pred.append(images)
        all_images_pred = torch.cat(all_images_pred)

        # neg_iou_loss returns 1 - IoU; convert back so "mean_iou" really is an IoU.
        results["mean_iou"] = 1.0 - neg_iou_loss(all_images_gt, all_images_pred[..., -1]).item()
        results["iterations_per_second"] = iterations_per_second

        with open(join(path, "results.json"), 'w') as f:
            json.dump(results, f)

    # Average the per-run metrics of this experiment.
    iou = []
    speed = []
    for run in range(len(strategies)):
        with open(join(experiment_path, str(run), "results.json")) as f:
            res = json.load(f)
        iou.append(res["mean_iou"])
        speed.append(res["iterations_per_second"])
    out = dict(
        mean_iou=sum(iou) / len(iou),
        mean_iterations_per_second=sum(speed) / len(speed))

    with open(join(experiment_path, "results.json"), 'w') as f:
        json.dump(out, f)
