In [1]:
import os
# GPU selection; these must be set before torch is imported in the next cell (see issue #152).
# Enumerate GPUs by PCI bus ID and make only device 3 visible to this process.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "3",
})

In [2]:
from datetime import datetime
import os

import numpy as np
import torch
from IPython.display import display
import nvdiffrast.torch as dr
from PIL import Image

In [3]:
# Run configuration.
image_fn = 'assets/monalisa.png'  # target image to approximate
h, w = 256, 256                   # render / target resolution (height, width)
n_triangles = 1000                # number of optimised triangles (one composited layer each)
base_lr = 0.03                    # peak Adam learning rate before scheduling
report_iterations = 10            # log / snapshot every N iterations
n_iterations = 10000              # total optimisation steps
seed = 0                          # RNG seed: random init and per-step noise are reproducible
torch.manual_seed(seed)

In [4]:
def to_pil_img(img):
    """Convert the first image of a batch to an 8-bit PIL image.

    Args:
        img: torch tensor whose ``img[0]`` is an (H, W, C) image, e.g. the
            (1, H, W, 3) output of ``render``. Values are clipped to [0, 1].

    Returns:
        ``PIL.Image`` built from the values scaled by 255 and truncated to uint8.
    """
    first = img[0].clip(0, 1)
    pixels = first.detach().cpu().numpy() * 255
    return Image.fromarray(pixels.astype('uint8'))


def show(img):
    """Display ``img`` inline in the notebook (converted with ``to_pil_img``)."""
    display(to_pil_img(img))

In [5]:
# Load the target as a (1, h, w, 3) float32 tensor in [0, 1] on the GPU.
# Convert to RGB *before* resizing: Pillow always resizes palette ("P") and
# 1-bit images with nearest-neighbour sampling, which would alias the target.
# PIL's resize takes (width, height), hence (w, h).
with Image.open(image_fn) as im:
    target = np.asarray(im.convert('RGB').resize((w, h)))
target = torch.from_numpy(target.astype('f') / 255).cuda()[None, ...]

In [6]:
# nvdiffrast GL rasterizer context, created once and shared by every render() call.
glctx = dr.RasterizeGLContext()

In [7]:
# Randomly initialise the optimised parameters (sampled on the CPU, then moved to the GPU).
# verts: (n_triangles, 3, 2) clip-space xy per triangle vertex, uniform in [-1, 1).
# color: (n_triangles, 4) per-triangle values in [0, 1); channels 0:3 are RGB and
#        channel 3 is used as alpha by combine_layers.
verts = (2 * torch.rand(n_triangles, 3, 2) - 1).cuda().requires_grad_()
color = torch.rand(n_triangles, 4).float().cuda().requires_grad_()

In [8]:
# Triangle index tensor [[0, 1, 2], [3, 4, 5], ...] as int32 on the GPU.
# NOTE(review): render() passes positions shaped (n_triangles, 3, 4), which nvdiffrast
# treats as instanced mode (3 vertices per instance), so only row 0 indexes valid
# vertices. Confirm how nvdiffrast handles the out-of-range rows; a single [[0, 1, 2]]
# row may be sufficient.
faces = torch.arange(3 * n_triangles).reshape(-1, 3).to(torch.int32).cuda()
def combine_layers(out, bg_noise=0.05, layer_alpha_scale=0.1):
    """Alpha-composite per-triangle RGBA layers, in order, over a noisy white canvas.

    Args:
        out: (n_layers, H, W, 4) tensor. Channels 0:3 are colour and channel 3 is
            alpha (the layout ``render`` produces). Every layer is composited.
        bg_noise: standard deviation of the Gaussian noise added to the white
            background.
        layer_alpha_scale: factor applied to each layer's alpha before blending.

    Returns:
        (1, H, W, 3) composited image.
    """
    canvas = torch.ones_like(out[0:1, ..., 0:3])
    canvas = canvas + torch.randn_like(canvas) * bg_noise
    # Sequential "over" compositing: layer i is drawn on top of layers 0..i-1.
    for i in range(out.shape[0]):
        layer = out[i:i + 1]
        alpha = layer[..., 3:4] * layer_alpha_scale
        canvas = canvas * (1 - alpha) + alpha * layer[..., 0:3]
    return canvas

In [9]:
def render(v, c, combine_layers_kwargs=None):
    """Rasterize one triangle per layer with nvdiffrast and composite the layers.

    Args:
        v: (n_tri, 3, 2) clip-space xy vertex positions, like ``verts``.
        c: (n_tri, 4) per-triangle RGBA values, like ``color``.
        combine_layers_kwargs: optional keyword arguments forwarded to
            ``combine_layers`` (e.g. ``{'bg_noise': 0.0}``). ``None`` means none.

    Returns:
        The output of ``combine_layers``, a (1, h, w, 3) image.
    """
    if combine_layers_kwargs is None:
        combine_layers_kwargs = {}
    n_tri = v.shape[0]
    # Homogeneous clip-space coordinates: z = -1 and w = 1 for every vertex.
    z_col = torch.full((n_tri, 3, 1), -1.0, dtype=v.dtype, device=v.device)
    w_col = torch.ones((n_tri, 3, 1), dtype=v.dtype, device=v.device)
    verts_in = torch.cat([v, z_col, w_col], dim=2)  # (n_tri, 3, 4)
    # Flat shading: all three vertices of a triangle carry that triangle's colour.
    colors = c[:, None, :].repeat(1, 3, 1)
    rast, _ = dr.rasterize(glctx, verts_in, faces, resolution=[h, w], grad_db=True)
    out_inter, _ = dr.interpolate(colors, rast, faces)
    out_layers = dr.antialias(out_inter, rast, verts_in, faces, pos_gradient_boost=1)
    return combine_layers(out_layers, **combine_layers_kwargs)

In [10]:
class OptimRateScheduler(object):
    """Learning-rate and noise schedules as a function of the iteration index.

    With ``t = c_iter / total_iterations``:
    * LR scale: linear warm-up over ``t in [0, ramp_up_to]``, then a half-cosine
      ramp-down over ``t in [ramp_down_from, 1]``. It is 0 at and beyond ``t = 1``.
    * Noise scale: ``((ramp_down_from - t) / ramp_down_from) ** 2``, decaying
      from 1 at ``t = 0``. It is 0 for ``t > ramp_down_from``.
    """

    def __init__(self, total_iterations, ramp_up_to=1 / 20, ramp_down_from=3 / 4):
        """
        Args:
            total_iterations: number of optimisation steps the schedules span.
            ramp_up_to: fraction of training used for the linear LR warm-up.
            ramp_down_from: fraction of training after which the LR starts its
                cosine ramp-down and the noise scale becomes 0.
        """
        self.total_iterations = total_iterations
        self.ramp_up_to = ramp_up_to
        self.ramp_down_from = ramp_down_from

    def get_lr_scale(self, c_iter):
        """Return the LR multiplier (in [0, 1]) for iteration ``c_iter``."""
        t = c_iter / self.total_iterations
        # Cosine ramp-down over the final (1 - ramp_down_from) of training; clamp at 0
        # so the scale does not rise again once t exceeds 1.
        lr_ramp = min(1.0, max(0.0, (1.0 - t) / (1.0 - self.ramp_down_from)))
        lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
        # Linear warm-up over the first ramp_up_to of training.
        lr_ramp = lr_ramp * min(1.0, t / self.ramp_up_to)
        return lr_ramp

    def get_noise_scale(self, c_iter):
        """Return the noise multiplier for iteration ``c_iter`` (0 past ramp_down_from)."""
        t = c_iter / self.total_iterations
        if t > self.ramp_down_from:
            return 0
        return ((self.ramp_down_from - t) / self.ramp_down_from) ** 2

In [11]:
# Adam over vertex positions and colours. The training loop overwrites the learning
# rate every step from `ors`, so `base_lr` here is only the initial value.
optim = torch.optim.Adam(params=[verts, color], lr=base_lr)
ors = OptimRateScheduler(n_iterations)

In [12]:
def save_as_gif(fn, imgs, fps=12):
    """Write ``imgs`` to ``fn`` as an infinitely looping animated GIF.

    Args:
        fn: output file path (opened in binary write mode, truncating any old file).
        imgs: non-empty iterable of PIL images. The first is the base frame and
            the remainder are appended.
        fps: playback rate. Each frame lasts ``int(1000 / fps)`` milliseconds.
    """
    head, *tail = imgs
    with open(fn, 'wb') as handle:
        frame_ms = int(1000. / fps)
        head.save(fp=handle, format='GIF', append_images=tail,
                  save_all=True, duration=frame_ms, loop=0)
        
def save_as_frames(fn, imgs, overwrite=True):
    """Save each image of ``imgs`` as a numbered PNG inside directory ``fn``.

    The directory is created if needed. Files are named ``00000000.png``,
    ``00000001.png``, ... in iteration order. An existing file is skipped only
    when ``overwrite`` is False.
    """
    os.makedirs(fn, exist_ok=True)
    for index, frame in enumerate(imgs):
        frame_path = os.path.join(fn, f'{index:08}.png')
        if not overwrite and os.path.exists(frame_path):
            continue
        save_as_png(frame_path, frame)

def save_as_png(fn, img):
    """Save ``img`` to ``fn``, appending a ``.png`` extension when ``fn`` lacks one."""
    path = fn if fn.endswith('.png') else f'{fn}.png'
    img.save(path)
            
def save_info_list(fn, info_list):
    """Write each entry of ``info_list`` to ``fn`` on its own line (via ``print``).

    The file is rewritten from scratch on every call and encoded as UTF-8.
    """
    with open(fn, 'w', encoding='utf-8') as fout:
        for record in info_list:
            print(record, file=fout)

In [None]:
# Optimisation loop. Each step renders noise-perturbed parameters and takes one Adam
# step. Every `report_iterations` steps the clean parameters are rendered, logged, and
# saved as an animation frame.
save_dir = 'nvdiffrast_1000t_10000e'
frames_dir = f'{save_dir}/animate.frames'
os.makedirs(frames_dir, exist_ok=True)  # also creates save_dir
pil_img_list = []
info_list = []

for i in range(n_iterations):
    # Scheduled learning rate and perturbation strength for this step.
    lr = ors.get_lr_scale(i) * base_lr
    optim.param_groups[0]['lr'] = lr
    noise_scale = 0.05 * ors.get_noise_scale(i)

    # Render from the perturbed vertices AND colours so the vertex noise is applied.
    verts_noise = torch.randn_like(verts) * noise_scale
    color_noise = torch.randn_like(color) * noise_scale * 0.5
    verts_in = verts + verts_noise
    color_in = color + color_noise
    out = render(verts_in, color_in, {'bg_noise': 0.05})
    loss = ((out - target) ** 2).mean()
    optim.zero_grad()
    loss.backward()

    if (i + 1) % report_iterations == 0:
        # Noise-free render of the current parameters for logging and snapshots.
        with torch.no_grad():
            out = render(verts, color, {'bg_noise': 0.0})
            wonoise_loss = ((out - target) ** 2).mean()

        info = (f"[{datetime.now()}]   Iteration {i + 1}, lr {optim.param_groups[0]['lr']}, "
                f"loss {loss.item()} loss (without noise) {wonoise_loss.item()}")
        print(info)
        info_list.append(info)
        save_info_list(f'{save_dir}/log.txt', info_list)

        show(out)
        frame = to_pil_img(out)
        pil_img_list.append(frame)
        # Write only the newest frame. Rewriting every earlier frame on each report
        # made the total I/O quadratic in the number of reports.
        save_as_png(os.path.join(frames_dir, f'{len(pil_img_list) - 1:08}.png'), frame)
        # NOTE: the GIF is still re-encoded from all frames on every report.
        save_as_gif(f'{save_dir}/animate.gif', pil_img_list, fps=12)

    optim.step()