In [6]:
%load_ext autoreload
%autoreload 2
from util.shape_implicit import ShapeImplicit
from deepsdf.model import DeepSDFDecoder
import trimesh
import torch
import pickle
from util.tet_grid import Grid
import numpy as np
import clip;
import kaolin as kal
import os;
from kaolin.ops.conversions import marching_tetrahedra
from torch.nn import functional as F
from matplotlib import pyplot as plt
from torch import nn
import cv2
# Notebook-wide configuration read by the cells below.
# Resolution passed three times to Grid(...) when the tetrahedral grid is built.
tet_res = 32
# Used for the DeepSDF decoder, CLIP and the normalisation constants.
# NOTE(review): several later statements call .cuda() directly, so a CPU-only
# machine will still fail there even though `device` falls back to "cpu".
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Passed to DeepSDFDecoder(...) as its constructor argument.
latent_size = 256
# When True the optimisation cell also trains a small colour MLP
# (`color_net`); when False a constant colour is rendered.
use_color = False

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
# --- Tetrahedral grid -------------------------------------------------------
# `v` / `T` are clones of the grid's `v` and `T` attributes; they are later
# passed to marching_tetrahedra as its first two arguments.
g = Grid(-1,1, tet_res, tet_res, tet_res, device)
v = g.v.clone()
T = g.T.clone()
# Per-channel mean / std applied to rendered images right before they are
# passed to model_clip.encode_image in the optimisation cell.
mean = torch.as_tensor((0.48145466, 0.4578275, 0.40821073), dtype=torch.float, device=device)
std = torch.as_tensor((0.26862954, 0.26130258, 0.27577711), dtype=torch.float, device=device)
# --- DeepSDF decoder ----------------------------------------------------------
# Weights are loaded onto the CPU first and the module is moved to `device`
# afterwards. torch.load unpickles the file, so only load trusted checkpoints.
model = DeepSDFDecoder(latent_size)
model.load_state_dict(torch.load("checkpoints/sdf_best.ckpt", map_location='cpu'))
model.eval()
model.to(device)
# --- CLIP ---------------------------------------------------------------------
# `preprocess` is not referenced in the visible cells; images are normalised
# manually with `mean` / `std` instead.
model_clip, preprocess = clip.load("ViT-B/32", device=device)
model_clip.eval()

os.makedirs("images", exist_ok=True)
# --- Camera -------------------------------------------------------------------
# NOTE(review): .cuda() is hard-coded from here on, independent of `device`.
cam_proj = kal.render.camera.generate_perspective_projection(1, ratio=1.0, dtype=torch.float32).cuda()
# 100 random 3-element draws from `p` (unseeded NumPy RNG).
# NOTE(review): `p` and `coords` are not referenced in the visible cells —
# presumably leftover candidate camera positions; confirm before removing.
p = np.array([4,-3, -2, -0.1, 0.1, 2 ,3, 4])
coords = [torch.from_numpy(np.random.permutation(p)[:3]).float() for i in range(100)]
# A single camera at (0, 0, -2) looking at the origin with +y as up; the
# leading dimension of `camera_positions` is the number of rendered views.
camera_positions = torch.stack([torch.FloatTensor([0 , 0, -2])]).cuda()
look_at = torch.FloatTensor([0, 0, 0]).cuda().unsqueeze(0).expand_as(camera_positions)
camera_up_direction = torch.FloatTensor([0, 1, 0]).cuda().unsqueeze(0).expand_as(camera_positions)
cam_transform =  kal.render.camera.generate_transformation_matrix(camera_positions, look_at, camera_up_direction).cuda().float()
# The grid itself is never optimised; only the latent code is.
v.requires_grad = False
T.requires_grad = False
# Cosine similarity between CLIP image and text features (the training signal).
criterion = torch.nn.CosineSimilarity().cuda()

# Random 3x32 matrix (unseeded torch RNG) used in the optimisation cell to
# build sin/cos features of mesh vertices for the colour network.
B = 4 * torch.randn(size=(3,32)).cuda()

def laplace_regularizer_const(mesh_verts, mesh_faces):
    """Uniform-weight Laplacian smoothness penalty for a triangle mesh.

    For every face corner, the offsets from that corner to the face's two
    other corners are accumulated onto the corner's vertex. Each vertex total
    is then divided by twice the number of faces that reference the vertex
    (clamped to at least 1, so unreferenced vertices stay at zero), and the
    mean of the squared result is returned.

    Args:
        mesh_verts: vertex positions, one row of 3 coordinates per vertex.
        mesh_faces: integer vertex indices, one row of 3 per triangle.

    Returns:
        Scalar tensor holding the mean squared normalised offset.
    """
    offset_sum = torch.zeros_like(mesh_verts)
    weight_sum = torch.zeros_like(mesh_verts[..., 0:1])

    # Positions of the first, second and third corner of every face.
    corners = [mesh_verts[mesh_faces[:, k], :] for k in range(3)]
    # Each face contributes two neighbours to every one of its corners; only
    # the first column is consumed by the single-column scatter below.
    per_face_weight = torch.full_like(corners[0], 2.0)

    # (corner, first other corner, second other corner) for each face slot.
    for here, first, second in ((0, 1, 2), (1, 0, 2), (2, 0, 1)):
        target = mesh_faces[:, here:here + 1]
        pull = (corners[first] - corners[here]) + (corners[second] - corners[here])
        offset_sum.scatter_add_(0, target.repeat(1, 3), pull)
        weight_sum.scatter_add_(0, target, per_face_weight)

    normalised = offset_sum / weight_sum.clamp(min=1.0)
    return (normalised ** 2).mean()
# Kaolin timelapse writer constructed with "logs"; the optimisation cell
# calls add_mesh_batch on it once per iteration.
timelapse = kal.visualize.Timelapse("logs")


In [8]:
# CLIP-guided shape optimisation.
#
# For every prompt in `all_texts` the DeepSDF latent code `Z` is optimised
# `nof_runs` times, each time starting from `init_Z`, so that a rendering of
# the mesh extracted from the predicted SDF maximises the CLIP cosine
# similarity with the prompt. Every run writes last.jpg, best.jpg and
# video.avi into videos/<prompt>output<run>/ and logs meshes to `timelapse`.

# Alternative random initialisation (kept for reference):
# Z = torch.ones(1, 256).normal_(mean=0, std=0.01).to(device)
# Z.requires_grad = True
# init_Z = Z.clone().detach().cpu().numpy()
init_Z = np.load("init_z.npy")
all_texts = ["short sofa"]
nof_iterations = 200   # optimisation steps per run
nof_runs = 8           # independent restarts per prompt
render_size = 224      # rendered image side length, also the video frame size
# Nine coefficients handed to kaolin's spherical_harmonic_lighting.
lights = torch.tensor([1.0, 1.0, 1.0, 0, 0, 0, 0, 0, 0]).cuda().unsqueeze(0)
# Constant RGB value rendered on every face when `use_color` is off.
t = torch.FloatTensor([0.5, 0.5, 0.5]).cuda()


def _write_video(frames, path, fps=5, size=(224, 224)):
    """Write RGB frames to `path` and always release the writer.

    Args:
        frames: iterable of HxWx3 integer arrays in RGB channel order.
        path: destination file name.
        fps: frames per second of the output file.
        size: (width, height) handed to cv2.VideoWriter.
    """
    writer = cv2.VideoWriter(path, 0, fps, size)
    try:
        for frame in frames:
            # OpenCV expects BGR channel order.
            writer.write(cv2.cvtColor(frame.astype(np.uint8), cv2.COLOR_RGB2BGR))
    finally:
        # Without release() the container may never be finalised on disk.
        writer.release()


for text_orig in all_texts:
    text = clip.tokenize([text_orig]).to(device)
    # The prompt embedding is a constant target. Building it without autograd
    # means no graph is shared between iterations, so backward() below does
    # not need retain_graph=True.
    with torch.no_grad():
        text_features = model_clip.encode_text(text).float()

    for jp in range(nof_runs):
        run_name = text_orig + 'output' + str(jp)
        run_dir = "videos" + "/" + run_name
        os.makedirs(run_dir, exist_ok=True)

        if use_color:
            color_net = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128,128), nn.ReLU(), nn.Linear(128, 3)).to(device)
            color_opt = torch.optim.Adam(color_net.parameters(), lr=1e-3)
        images = []
        best_im = None

        try:
            # torch.tensor always copies, so the optimiser never writes into init_Z.
            Z = torch.tensor(init_Z, dtype=torch.float32, device=device, requires_grad=True)

            # Cosine similarity can be negative; start below its range so the
            # first evaluated frame is always recorded as the current best.
            best_sim = float("-inf")
            v = g.v.clone()
            v.requires_grad = False
            T = g.T.clone()
            optimizer = torch.optim.Adam([Z] , lr=1e-2)

            # The grid is constant within a run, so the query points are
            # chunked once instead of on every iteration.
            split_size = max(1, v.shape[1] // 32)
            point_chunks = [chunk.squeeze(0).float() for chunk in torch.split(v, split_size, dim=1)]

            for i in range(nof_iterations):
                # Evaluate the decoder chunk by chunk. torch.cat (unlike
                # torch.stack) also works when the last chunk is shorter.
                sdfs = torch.cat(
                    [model(torch.cat([Z.expand(pts.shape[0], -1), pts], dim=-1)) for pts in point_chunks],
                    dim=0,
                ).view(1, -1)

                vertices, faces, tet_idxs = marching_tetrahedra(v, T, sdfs, return_tet_idx=True)
                if i == 0:
                    # Keep the mesh produced by the initial latent code.
                    in_mesh = trimesh.Trimesh(vertices[0].detach().cpu().numpy(), torch.fliplr(faces[0]).detach().cpu().numpy())
                    in_mesh.export("in.obj")
                vertices = vertices[0].float()
                # Reverse each face's vertex order.
                faces = torch.fliplr(faces[0].long())

                if use_color:
                    # sin/cos features of the vertex positions; computed
                    # without autograd, so colour gradients do not reach the shape.
                    with torch.no_grad():
                        v_ffn =  torch.cat([torch.sin(vertices @ B), torch.cos(vertices @ B)], dim=-1)
                    pred_colors = color_net(v_ffn)
                    colors = pred_colors[faces].unsqueeze(0)
                    t_expand = torch.repeat_interleave(colors, cam_transform.shape[0], dim=0)
                else:
                    t_expand = t.unsqueeze(0).unsqueeze(0).unsqueeze(0).float().expand(camera_positions.shape[0], faces.shape[0], faces.shape[1], 3)

                face_vertices_camera, face_vertices_image, face_normals =  kal.render.mesh.prepare_vertices(vertices, faces, cam_proj, camera_transform=cam_transform)
                im = kal.render.mesh.rasterize(render_size, render_size, face_vertices_camera[:,:,:,-1].float(),
                            face_vertices_image.float(), t_expand)
                face_idx = im[1]
                # Look up one normal per pixel from the rasterised face index of each view.
                im_normals = torch.stack(
                    [face_normals[cam, face_idx[cam]].squeeze(0) for cam in range(cam_transform.shape[0])]
                )

                lighting = kal.render.mesh.utils.spherical_harmonic_lighting(
                    im_normals, lights.expand(im_normals.shape[0], 9))

                img = torch.clamp(im[0]*lighting.unsqueeze(-1), 0, 0.99)
                im_to_save = (img[0:1]*255).long().detach().cpu().numpy().squeeze(0)
                # Normalise and move channels first before encoding with CLIP.
                img = ((img - mean) / std).permute(0, 3, 1, 2)
                image_features = model_clip.encode_image(img).float()

                sims = criterion(image_features, text_features)
                cos_sim = sims.mean()
                if cos_sim.item() > best_sim:
                    best_sim = cos_sim.item()
                    best_vertices = vertices.clone()
                    best_faces = faces.clone()
                    best_im = im_to_save
                loss =  -1*sims.sum()

                optimizer.zero_grad()
                if use_color:
                    color_opt.zero_grad()

                loss.backward()
                print(cos_sim.item())
                # An exactly-zero gradient means no signal reaches Z any more; stop this run.
                if Z.grad.pow(2).mean() == 0:
                    break

                optimizer.step()
                if use_color:
                    color_opt.step()
                timelapse.add_mesh_batch(category=run_name,
                                         iteration=i,
                                         faces_list=[faces],
                                         vertices_list=[vertices])
                images.append(im_to_save)
            print("done")

            # Either list can be empty if the run stopped during its first iteration.
            if images:
                plt.imsave(run_dir + "/last.jpg", images[-1] / 255.01)
            if best_im is not None:
                plt.imsave(run_dir + "/best.jpg", best_im / 255.01)

        except Exception as e:
            # Best effort: report the failure and carry on with the next run.
            print(f"[{run_name}] run aborted: {e!r}")
        finally:
            # Written on success, on failure and on manual interruption, so
            # the frames collected so far are never lost.
            _write_video(images, run_dir + "/video.avi", size=(render_size, render_size))

0.2695707082748413
0.24309584498405457
0.2395281195640564
0.24614816904067993
0.2522371709346771
0.25186702609062195
0.2637485861778259
0.25948357582092285
0.2540360391139984
0.27484065294265747
0.2717013657093048
0.287092387676239
0.29017552733421326
0.2794739305973053
0.2872731685638428
0.29672500491142273
0.27831512689590454
0.28779226541519165
0.2851526737213135
0.30451536178588867
0.2783195972442627
0.268205463886261
0.27267760038375854


KeyboardInterrupt: 

In [240]:
# Rebuild CPU-side trimesh objects from the tensors left behind by the
# optimisation cell, then display the mesh of the most recent iteration.
def _as_numpy(tensor):
    """Detach a tensor from autograd and return it as a NumPy array on the CPU."""
    return tensor.detach().cpu().numpy()


best_mesh = trimesh.Trimesh(vertices=_as_numpy(best_vertices), faces=_as_numpy(best_faces))
last_mesh = trimesh.Trimesh(vertices=_as_numpy(vertices), faces=_as_numpy(faces))
last_mesh.show()



In [221]:
# Display the mesh built from `best_vertices` / `best_faces`; requires the
# cell that defines `best_mesh` to have been run first.
best_mesh.show()

In [110]:
# Display the mesh captured at iteration 0 of the optimisation loop
# (the same mesh that is exported to "in.obj").
in_mesh.show()

In [34]:
# Ad-hoc check: cosine similarity between the image features left over from
# the last optimisation iteration and the prompt features.
criterion(image_features.expand_as(text_features), text_features.float())

tensor([0.2141], device='cuda:0', grad_fn=<SumBackward1>)

In [101]:
 # Scratch cell: query the decoder at the mesh vertices with the current latent code.
 # NOTE(review): this unpacks a 2-tuple from model(...), whereas the optimisation
 # loop uses the return value as a single tensor — presumably written against a
 # different decoder revision; confirm before relying on it.
 # NOTE(review): `vertices[0]` assumes `vertices` still has a leading batch/list
 # dimension, which the optimisation loop strips — confirm the cell order.
 _, point_feats = model(torch.cat([Z.expand(vertices[0].shape[0], -1), vertices[0].float().squeeze(0)], dim=-1))


In [100]:
# Scratch cell: shape of the latent code after expanding it to one row per
# entry of vertices[0].
Z.expand(vertices[0].shape[0], -1).shape

torch.Size([3824, 256])

In [53]:
# Assemble every .jpg in `image_folder` into a single video file.
import cv2
image_folder = 'images'
video_name = 'video.avi'

# os.listdir returns entries in arbitrary order; sort the names so the frame
# order is deterministic (lexicographic — zero-pad numeric names if needed).
images = sorted(img for img in os.listdir(image_folder) if img.endswith(".jpg"))
if not images:
    raise FileNotFoundError(f"no .jpg frames found in {image_folder!r}")

# The first frame defines the video dimensions.
frame = cv2.imread(os.path.join(image_folder, images[0]))
if frame is None:
    raise ValueError(f"could not read {os.path.join(image_folder, images[0])!r}")
height, width, layers = frame.shape

video = cv2.VideoWriter(video_name, 0, 5, (width,height))
try:
    for image in images:
        current = cv2.imread(os.path.join(image_folder, image))
        if current is None:
            # cv2.imread signals an unreadable file by returning None.
            print(f"skipping unreadable frame: {image}")
            continue
        video.write(current)
finally:
    # Release even if a frame fails, so the file is finalised on disk. No
    # HighGUI window is opened in this cell, so no destroyAllWindows() call
    # is needed.
    video.release()

In [23]:
# Embed a video file in the notebook output.
# NOTE(review): this reads "video.mp4", while the visible cells only ever
# write "video.avi" — presumably the file was converted outside the notebook;
# confirm the path before re-running.
from IPython.display import Video

f = Video("video.mp4", embed=True)
f

In [151]:
# Persist the initial latent code; the optimisation cell reloads this same
# file with np.load("init_z.npy") at its start.
np.save("init_z.npy", init_Z)

In [152]:
# Round-trip check: the file on disk matches the in-memory `init_Z`.
np.allclose(np.load("init_z.npy"), init_Z)

True