In [None]:
#Imports and dependencies
import os
# Set before torch is imported so CUDA enumerates devices in PCI bus order.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

from os.path import join, abspath, dirname
import sys
# Make the parent of the notebook's working directory importable (for `synth_dataset` and
# `utils`). The previous `join("..", dirname(cwd))` silently discarded the ".." because
# `dirname(cwd)` is absolute; this is the same path written plainly. The guard keeps
# re-runs of this cell from prepending the same entry again and again.
_project_root = abspath(dirname(os.getcwd()))
if not sys.path or sys.path[0] != _project_root:
    sys.path.insert(0, _project_root)

import re
import random
import torch
import imageio
from PIL import Image, ImageOps
from tqdm import tqdm_notebook
from skimage import img_as_ubyte
from pytorch3d.structures import Meshes
from pytorch3d.utils import ico_sphere
from pytorch3d.io import load_obj, save_obj
from pytorch3d.datasets import (
    ShapeNetCore,
    collate_batched_meshes
)
from pytorch3d.loss import (
    mesh_laplacian_smoothing, 
    mesh_normal_consistency
)
from pytorch3d.renderer import (
    PerspectiveCameras, RasterizationSettings, MeshRenderer, MeshRasterizer, 
    BlendParams, SoftSilhouetteShader, SoftPhongShader, PointLights, TexturesVertex, 
    TexturesAtlas, HardPhongShader, HardFlatShader, look_at_view_transform
)
from pytorch3d.io import load_objs_as_meshes

from dataclasses import dataclass, field, asdict, astuple

import numpy as np
#Plotting Libs
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import matplotlib as mpl


from synth_dataset.trajectory import cam_trajectory
from synth_dataset.mesh import (
    load_meshes, mesh_random_translation, rotate_mesh_around_axis, 
    translate_mesh_on_axis, scale_mesh
)
from synth_dataset.event_renderer import generate_event_frames

from utils.visualization import plot_trajectory_cameras
from utils.manager import RenderManager, ImageManager

In [None]:

# Use the same resolution for inline figures and saved figures.
mpl.rcParams['savefig.dpi'] = 150
mpl.rcParams['figure.dpi'] = 150

# Select the first visible GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# torch.cuda.set_device raises for a non-CUDA device, so only pin it when CUDA is in use.
if device.type == "cuda":
    torch.cuda.set_device(device)

In [None]:
# ShapeNetCore v2 root and the synset (category) ids used in this notebook.
shapenet_path = "../data/ShapeNetCorev2"
car_dir = '02958343'    # car
chair_dir = '03001627'  # chair
plane_dir = '02691156'  # airplane
# Same synset order as before (car, airplane, chair) so dataset indexing is unchanged.
synsets = [car_dir, plane_dir, chair_dir]
shapenet_dataset = ShapeNetCore(shapenet_path, version=2, synsets=synsets)

### Render parameters

In [None]:
@dataclass
class RenderParams:
    """Rendering and dataset-size settings, read as class attributes (``RenderParams.x``)."""

    # Output image size. It is passed to PIL ``resize`` as-is, and its first entry is used
    # as the square rasterizer size.
    img_size: tuple = (280, 280)
    # Presumably the strength of the simulated hand-shake jitter; unused in visible cells.
    sigma_hand: float = .15

    # Size of the dataset
    # Number of camera poses rendered per trajectory.
    batch_size: int = 360
    # How many of those frames are actually saved (one eighth of the trajectory).
    data_batch_size: int = 360 // 8
    mesh_iter: int = 5

    # Display every rendered frame while generating data.
    show_frame: bool = False

### Create the silhouette and shaded renderers

In [None]:
# Both renderers share one default camera; per-frame R and T are passed at render time.
cameras = PerspectiveCameras(device=device)

# Soft-blending parameters for the silhouette renderer: sigma controls the sharpness of the
# soft face edges and gamma the opacity of the blended faces. Refer to blending.py for details.
blend_params = BlendParams(sigma=1e-4, gamma=1e-4)

# Rasterization settings for the silhouette renderer. The output image is square,
# RenderParams.img_size[0] pixels on a side, and up to 100 faces are blended per pixel.
# blur_radius is derived from sigma as log(1/1e-4 - 1) * sigma. bin_size and
# max_faces_per_bin are not set here, so PyTorch3D's defaults apply. Refer to
# rasterize_meshes.py and docs/notes/renderer.md for explanations of these parameters.
raster_settings = RasterizationSettings(
    image_size=RenderParams.img_size[0],
    blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma,
    faces_per_pixel=100,
)

# Silhouette renderer: rasterizer + soft silhouette shader.
silhouette_renderer = MeshRenderer(
    rasterizer=MeshRasterizer(
        cameras=cameras,
        raster_settings=raster_settings
    ),
    shader=SoftSilhouetteShader(blend_params=blend_params)
)


# Settings for the shaded renderer: hard rasterization needs only the closest face per pixel.
# max_faces_per_bin is raised far above the default, presumably so dense ShapeNet meshes do
# not overflow the coarse-rasterization bins — TODO confirm.
raster_settings = RasterizationSettings(
    image_size=RenderParams.img_size[0],
    blur_radius=0,
    faces_per_pixel=1,
    max_faces_per_bin=500000
)
# A single dim point light at (3, 3, 0) with grey diffuse and specular colour.
lights = PointLights(
    device=device,
    location=[[3.0, 3.0, 0.0]],
    diffuse_color=((0.3, 0.3, 0.3),),
    specular_color=((0.3, 0.3, 0.3),),
)

# NOTE: despite its name, this renderer uses HardFlatShader (flat per-face shading), not
# Phong shading. The name is kept because later cells refer to `phong_renderer`.
phong_renderer = MeshRenderer(
    rasterizer=MeshRasterizer(
        cameras=cameras,
        raster_settings=raster_settings
    ),
    shader=HardFlatShader(device=device, lights=lights, cameras=cameras)
)

In [None]:
# Load the textured dolphin mesh; it is the render subject of the data-creation loop below.
path = "../data/meshes/dolphin/dolphin.obj"
mesh = load_objs_as_meshes([path], create_texture_atlas=False, load_textures=True, device=device)

# Normalise in place: centre the vertices at the origin, then scale so that the largest
# absolute coordinate along any axis is 1.
verts = mesh.verts_packed()
N = verts.shape[0]
center = verts.mean(0)
scale = (verts - center).abs().max()
mesh.offset_verts_(-center.expand(N, 3))
mesh.scale_verts_(1.0 / float(scale))

# Orient the mesh. The renderer is passed too, presumably so the helper can preview the
# result — TODO confirm against synth_dataset.mesh.rotate_mesh_around_axis.
mesh = rotate_mesh_around_axis(mesh, [110, 90, 180], phong_renderer, dist=2, device=device)


In [None]:
def merge_with_background(image, background=None, show: bool = False,
                          background_folder: str = "../data/sun360",
                          pano_name="pano_530c02959a7fdf8fdda4bac494ba3724"):
    """Composite a rendered image over a background frame.

    Pure-white pixels of ``image`` are replaced by the matching pixels of the next frame
    drawn from ``background``.

    Args:
        image: rendered image; it is clipped to [0, 1] and only its first 3 channels are
            kept. Its height/width must match ``RenderParams.img_size``.
        background: generator of uint8 background frames (as returned by a previous call),
            or None to start a new one. An exhausted generator is replaced by a fresh one.
        show: if True, display the merged image with matplotlib.
        background_folder: directory that holds the panorama sub-folders.
        pano_name: panorama sub-folder to draw frames from; None picks one at random.

    Returns:
        Tuple ``(merged uint8 image, background generator)``. Pass the generator back in
        to continue with the next background frame.
    """

    def background_generator():
        # Frames inside a panorama folder are read as 0.jpg, 1.jpg, ... (assumed contiguous).
        pano_dir = random.choice(os.listdir(background_folder)) if pano_name is None else pano_name
        pano_path = abspath(join(background_folder, pano_dir))
        # Count only .jpg files so stray files do not produce missing-frame errors.
        n_frames = sum(1 for f in os.listdir(pano_path) if f.endswith(".jpg"))
        for img_num in range(n_frames):
            # Close the file promptly; convert so grayscale/RGBA frames are 3-channel too.
            with Image.open(join(pano_path, f"{img_num}.jpg")) as img:
                frame = img.convert("RGB").resize(RenderParams.img_size)
            yield np.array(frame).astype(np.uint8)

    if background is None:
        background = background_generator()

    # img_as_ubyte expects floats in [0, 1], hence the clip. np.array makes a writable uint8
    # copy that is modified in place below.
    image = np.array(img_as_ubyte(np.clip(image, 0, 1))[..., :3], dtype=np.uint8)

    try:
        image_bg = next(background)
    except StopIteration:
        # Background exhausted: restart with a fresh generator.
        background = background_generator()
        image_bg = next(background)

    # Exactly-white pixels (presumably the renderer's empty background) take the
    # background frame's colour.
    image_white = np.all(image == [255, 255, 255], axis=-1)
    image[image_white] = image_bg[image_white]

    if show:
        plt.imshow(image)
        plt.show()

    return image, background

In [None]:

# Hand-picked subset of ShapeNet car model ids, handy for quick experiments.
small_car_list = [
    "a1d85821a0666a4d8dc995728b1ad443",
    "2861ac374a2ea7f997692eea6221681c",
    "1b1a7af332f8f154487edd538b3d83f6",
    "7db6c18d976e52e451553ea674d2701f",
    "ff564f7ec327ed83391a2a133df993ee",
    "303bbfd0c5862496ec8ca19d7516cb42",
    "b2b2f4952e4068d955fe55d6e406ecd4",
    "82ede85c805bd5a85af609a73d2c2947",
    "4cabd6d81c0a9e8c6436916a86a90ed7",
    "a4fc879c642e8fc4a5a4c80d90b70728",
    "ef8e257ca685d594473f10e6caaeca56",
    "c1ac2aee4851937c8e30bdcd3135786b",
    "29e9a4beeaeea1becf71e2e014ff6f",
    "5ec7fa8170eee943713e820becfd99b",
    "813bedf2a45f5681ca92a4cdad802b45",
    "c0b2a4cec94f10436f0bd9fb2f72f93d",
    "145e18e4ec54ed5792abb3e9ac4cd40c",
    "99f49d11dad8ee25e517b5f5894c76d9"
]

# Placeholder exclusion list; the empty string matches no model id.
car_exclude = [
    ""
]
# Car models that PyTorch3D failed to load or render. This is a list because later cells
# append to it; the duplicated '97d0903c...' entry has been removed.
car_exclude_pytorch3d = [
    '97d0903cf8912c3ee9d790a68c844819',
    '3ff887eaebf0bc7e9d2b99af43da16b3',
    'e5d6df012b219fa6a7c4dca3ad2d19bf',
    '6c39e401487c95e9ee897525d11b0599',
    '397d2e2b3e0988a2d3901a534bc610d8',
    '4fd5c18c1536d65be129fc90649e41d3',
    'f60779c934ee51eddd1e15301c83686f',
    '350be6825c19fb14e0675251723e1e08',
    '633dd9755319ce369dfd5136ef0f2af',
    '19f52dd4592c3fb5531e940de4b7770d',
    '9a92ea1009f6b5127b5d9dbd93af5e1a',
    'eadebe4328e2c7d7c10520be41d00de2',
]


### ShapeNet

In [None]:
# Load every car model in the ShapeNet car synset as a texture-atlas mesh. Models PyTorch3D
# cannot load or render are appended to car_exclude_pytorch3d instead of aborting the cell.
model_path = "models/model_normalized.obj"
car_list = os.listdir(join(shapenet_path, car_dir))
print(len(car_list))

car_paths = [join(shapenet_path, car_dir, c) for c in car_list]
car_model_paths = [join(c, model_path) for c in car_paths]

# `meshes` and `mesh_ids` are parallel lists. Failed loads are skipped, so positions in
# `meshes` do not line up with `car_list`. The loop variable is `car_mesh` (not `mesh`)
# so this cell does not overwrite the dolphin `mesh` used by the data-creation loop.
meshes = []
mesh_ids = []
for car_id, car_model_path in zip(car_list, car_model_paths):
    try:
        verts, faces, aux = load_obj(car_model_path, load_textures=True, create_texture_atlas=True)
        car_mesh = Meshes(
            verts=[verts],
            faces=[faces.verts_idx],
            textures=TexturesAtlas(atlas=[aux.texture_atlas]),
        ).to(device)
    except Exception as e:
        car_exclude_pytorch3d.append(car_id)
        print(e, car_id)
        continue
    meshes.append(car_mesh)
    mesh_ids.append(car_id)
    print(f"Added mesh: {car_id}")

# Call rotate_mesh_around_axis on every loaded mesh; its return value is discarded, so it is
# used for its side effect (presumably a rendered preview). Models that fail are excluded.
for car_id, car_mesh in zip(mesh_ids, meshes):
    try:
        rotate_mesh_around_axis(car_mesh, [0, 90, 0], phong_renderer, dist=.7, device=device)
        print(car_id)
    except Exception as e:
        car_exclude_pytorch3d.append(car_id)
        print(e, car_id)






### Data Creation Loop

In [None]:
# Render types produced for every trajectory.
renders = {
    "phong": None,
    "silhouette": None,
    "events": None
}

# Subject of the renders. `mesh` is the normalised, rotated dolphin loaded in an earlier cell.
name = "dolphin"
mesh_name = "dolphin"

# Augmentation scenarios (all trajectories complete a full 360 w/ simulated handshake):
#   - normal trajectory
#   - varying distance
#
# NOTE(review): `variation`, `pepper` and `random_start` are passed to cam_trajectory and
# stored in the metadata, but no visible cell defines them — set them before running.
model_path = "models/model_normalized.obj"
car_list = os.listdir(join(shapenet_path, car_dir))
car_paths = [join(shapenet_path, car_dir, c) for c in car_list]
# Skip the first 578 models (presumably processed by an earlier run — TODO confirm).
car_model_paths = [join(c, model_path) for c in car_paths][578:]

# Every rendered frame of the current trajectory is staged here for the event simulator.
TMP_DIR = 'tmp'
os.makedirs(TMP_DIR, exist_ok=True)


def _clear_tmp_pngs():
    """Delete all staged .png frames in TMP_DIR."""
    for f in os.listdir(TMP_DIR):
        if f.endswith('.png'):
            os.remove(join(TMP_DIR, f))


# Untranslated source mesh. Each trajectory gets one fresh random translation of it, instead
# of re-translating the previous iteration's already-translated mesh. This assumes
# mesh_random_translation returns a new mesh rather than mutating its input — TODO confirm.
base_mesh = mesh

# NOTE(review): per-model loading was disabled (commented out) in this cell, so every
# iteration renders `base_mesh`; `car_path` is currently unused.
for num, car_path in enumerate(car_model_paths):
    # Create a random camera trajectory.
    cam_poses = cam_trajectory(
        variation,
        pepper,
        random_start,
        RenderParams.batch_size
    )

    traj_mesh, translation = mesh_random_translation(base_mesh, .1, device=device)
    traj_mesh = traj_mesh.to(device)
    background = None

    # Indices (within the trajectory) of the frames that are actually saved.
    data_indices = sorted(random.sample(range(RenderParams.batch_size), k=RenderParams.data_batch_size))
    print(data_indices)
    data_index_set = set(data_indices)

    render_manager = RenderManager(
        types=list(renders.keys()),
        mesh_name = mesh_name,
        new_folder = f"bg_ablation_{name}",
        metadata = {
            "augmentation_params": {
                "variation": variation,
                "pepper": pepper,
                "random_start": random_start
            },
            "mesh_transformation": {
                "translation": translation.get_matrix().cpu().numpy().tolist()
            },
            "mesh_info": {
                "mesh_id":"",
                "synset_id":"",
                "category_name": "car"
            }
        }
    )
    render_manager.init()

    # Drop frames left behind by an interrupted run so they are not fed to the simulator.
    _clear_tmp_pngs()

    # Render the mesh from each camera pose of the trajectory.
    R, T = cam_poses
    data_img_num = 0
    for frame_idx in range(len(R)):
        R_i = R[frame_idx:frame_idx + 1]
        T_i = T[frame_idx:frame_idx + 1]
        img_dict = {}

        if "phong" in renders:
            image_ref = phong_renderer(meshes_world=traj_mesh, R=R_i, T=T_i, device=device)
            img_dict["phong"] = image_ref.cpu().numpy().squeeze()

        if "silhouette" in renders:
            # The mask comes from the shaded image rather than the silhouette renderer.
            # A channel is marked 255 wherever it is below 1.0, i.e. not the white background.
            img_dict["silhouette"] = (img_dict["phong"][..., :3] < 1) * 255

        if RenderParams.show_frame:
            for plot_num, img in enumerate(img_dict.values(), start=1):
                ax = plt.subplot(1, len(img_dict), plot_num)
                ax.imshow(img)
            plt.show()

        # Only keep the frame if it was randomly selected.
        if frame_idx in data_index_set:
            render_manager.add_images(data_img_num, img_dict, R_i, T_i)
            data_img_num += 1

        # Every frame, saved or not, feeds the event simulator (files 1.png, 2.png, ...).
        imageio.imwrite(join(TMP_DIR, f'{frame_idx + 1}.png'), img_dict['phong'],
                        format='PNG', compress_level=3)

    # Staged frames in numeric (not lexicographic) order.
    image_path_list = [
        join(TMP_DIR, f)
        for f in sorted((s for s in os.listdir(TMP_DIR) if s.endswith('.png')),
                        key=lambda s: int(re.sub(r"\D", "", s)))
    ]

    event_frames = generate_event_frames(image_path_list, RenderParams.img_size, RenderParams.batch_size)
    event_count = 0
    for ev_count, frame in enumerate(event_frames):
        # Scale event intensities by 4. NOTE(review): wraps above 63 if frames are uint8.
        frame = frame * 4
        # Pixels without events (pure black) become white, like the rendered background.
        frame[np.all(frame == [0, 0, 0], axis=-1)] = 255
        # Presumably event frame k corresponds to rendered frame k — TODO confirm.
        if ev_count in data_index_set:
            render_manager.add_event_frame(event_count, frame)
            event_count += 1

    render_manager.close()
    _clear_tmp_pngs()


In [None]:
# Visualise the last trajectory: translation vectors T (blue) and camera centres (red).
x = T[:, 0].cpu()
y = T[:, 1].cpu()
z = T[:, 2].cpu()
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter3D(x, y, z, marker='o', color='blue')

# Query a camera built from this trajectory's R/T. The global `cameras` has only the
# default pose, so its centres would not describe the trajectory.
camera = PerspectiveCameras(device=device, R=R, T=T)
cam_wvt = camera.get_world_to_view_transform()
cam_center = camera.get_camera_center()
print(cam_center[0])
x = cam_center[:, 0].cpu()
y = cam_center[:, 1].cpu()
z = cam_center[:, 2].cpu()
ax.scatter3D(x, y, z, marker='o', color='red')

# View-to-world matrix of the first pose. PyTorch3D stores transforms in row-vector form:
# the translation is the last row and the rotation the upper-left 3x3 block.
P = cam_wvt.inverse().get_matrix()[0]
print(P)
t = P[3, :3]
r = P[:3, :3]

### Plotting Event Volume

In [None]:
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 unused import

import matplotlib.pyplot as plt
import numpy as np

import esim_py

import plotly.express as px
import plotly.graph_objects as go

import pandas as pd

@dataclass
class EsimParams:
    """esim_py EventSimulator settings, read as class attributes."""

    # Contrast thresholds for positive / negative events.
    Cp: float = 0.5
    Cn: float = 0.5
    # Presumably the std-devs of threshold noise — TODO confirm against the esim_py build used.
    sigma_cp: float = 0.03
    sigma_cn: float = 0.03
    refractory_period: float = 1e-4
    log_eps: float = 1e-3
    use_log: bool = True

    show_frame: bool = False


def _events_to_rgb(batch_events, img_size, gain=50):
    """Accumulate events into a uint8 RGB frame (red = positive, blue = negative polarity).

    Columns are read as (x, y, t, polarity). Each event adds 1 to its pixel, counts are
    scaled by ``gain`` and saturated at 255. The old uint8 accumulation wrapped around.
    """
    height, width = img_size
    pos_events = batch_events[batch_events[:, -1] == 1]
    neg_events = batch_events[batch_events[:, -1] == -1]

    counts_pos = np.zeros(height * width, dtype=np.int64)
    counts_neg = np.zeros(height * width, dtype=np.int64)
    # Flat index of pixel (x, y) in a row-major (height, width) image.
    np.add.at(counts_pos, (pos_events[:, 0] + pos_events[:, 1] * width).astype(np.int64), 1)
    np.add.at(counts_neg, (neg_events[:, 0] + neg_events[:, 1] * width).astype(np.int64), 1)

    rgb = np.stack(
        [
            counts_pos.reshape(img_size),
            np.zeros(img_size, dtype=np.int64),
            counts_neg.reshape(img_size),
        ],
        -1,
    )
    return np.clip(rgb * gain, 0, 255).astype(np.uint8)


render_manager = RenderManager.from_directory(1, 'data/renders/test_car_shapenet')

image_path_list = [img['image_path'] for img in render_manager.images['phong']]

# One time unit per rendered frame.
timestamp_list = range(len(image_path_list))
esim = esim_py.EventSimulator(
    EsimParams.Cp,
    EsimParams.Cn,
    EsimParams.refractory_period,
    EsimParams.log_eps,
    EsimParams.use_log,
    EsimParams.sigma_cp,
    EsimParams.sigma_cn,
)

events = esim.generateFromStampedImageSequence(
    image_path_list, timestamp_list
)

# --- Event volume: scatter the first 1/14 of the events (int(360 / 25) == 14) in (t, x, y).
batch_events_plot = int(len(events) / int(360 / 25))
print(batch_events_plot)

curr_batch_events = events[:batch_events_plot]

fig = plt.figure(figsize=(25, 12))
ax = fig.add_subplot(111, projection='3d')
y = curr_batch_events[:, 0]
z = curr_batch_events[:, 1]
x = curr_batch_events[:, 2] * 2  # timestamp
m = 'o'
c = ['red' if p == 1 else 'blue' for p in curr_batch_events[:, 3]]

ax.scatter3D(x, y, z, c=c, marker=m, s=.2)
ax.set_xlabel('Time [s]')
ax.set_ylabel('x [pix]')
ax.set_zlabel('y [pix]')
fig.savefig('ev_volume.png', dpi=fig.dpi)
plt.show()

# --- Event frames: split the stream in two halves and accumulate each half into an image.
img_size = (280, 280)
# At least 1 so an empty or single-event stream cannot loop forever.
batch_events_images = max(1, len(events) // 2)

event_frames = []
event_batch_size = 0
while event_batch_size < len(events):
    curr_batch_events = events[
        event_batch_size : event_batch_size + batch_events_images
    ]
    event_frames.append(_events_to_rgb(curr_batch_events, img_size))
    event_batch_size += len(curr_batch_events)

# --- Stack the event frames as vertical slices along the x axis of a 3D plot.
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
for i, ev in enumerate(event_frames):
    # 1-pixel black border around each slice; values scaled to [0, 1] for facecolors.
    ev = np.pad(ev, pad_width=((1, 1), (1, 1), (0, 0)), constant_values=0, mode='constant') / 255
    print(ev.shape)
    # Grid built from the padded frame so the surface and facecolors line up.
    z, y = np.ogrid[0:ev.shape[0], 0:ev.shape[1]]
    x = np.zeros_like(y)
    ax.plot_surface(x + i, y, z, rstride=3, cstride=3, facecolors=np.rot90(ev, 2, (0, 1)),
                    shade=False, antialiased=True)
plt.axis('off')
ax.grid(False)
fig.savefig('stack.png', dpi=fig.dpi)
plt.show()



In [None]:
from matplotlib.image import imread

# Show the first event frame as a flat image, then as a colour-mapped surface in 3D.
ev = event_frames[0] / 255
print(ev.shape)
plt.imshow(ev)
plt.show()

x, y = np.ogrid[0:ev.shape[0], 0:ev.shape[1]]
fig = plt.figure()
# Figure.add_subplot replaces plt.gca(projection=...), removed in matplotlib 3.6.
ax = fig.add_subplot(projection='3d')
z = np.zeros_like(y)
ax.plot_surface(x, y, z, rstride=2, cstride=2, facecolors=ev, shade=False, antialiased=True)
plt.show()


In [None]:
# Post-process a finished render: every pure-black (event-free) pixel of each stored event
# frame is painted white, and the frame is written back under the same index.
path = "../data/renders/test2_dolphin/003-dolphin_2020-10-29T00:30:43"
man = RenderManager.from_path(path)
idx = 0
while idx < len(man):
    ev_frame = np.array(man.get_event_frame(idx))
    frame_black = np.all(ev_frame == [0, 0, 0], axis=-1)
    ev_frame[frame_black] = 255
    man.add_event_frame(idx, ev_frame)
    idx += 1