# Geometry-Free View Synthesis

This is a Colab demo for [Geometry-Free View Synthesis](https://github.com/CompVis/geometry-free-view-synthesis). Compared to [the pygame demo](https://github.com/CompVis/geometry-free-view-synthesis#demo), the controls of this one are a bit clumsy, but you can dive right in by selecting `Runtime->Run all`.

Install

In [None]:
# Install the geofree package directly from GitHub (IPython shell escape).
!pip install git+https://github.com/CompVis/geometry-free-view-synthesis#egg=geometry-free-view-synthesis

Image loading functions.

In [None]:
import numpy as np
from PIL import Image

def load_im_as_example(im):
    """Convert a PIL image into the example dict consumed by the demo.

    If the image's aspect ratio differs from the model's 368x208
    (width x height) by more than 0.1, it is first center-cropped to that
    aspect ratio. It is then resized to 368x208 with Lanczos resampling and
    rescaled from [0, 255] to [-1, 1].

    Args:
        im: PIL image. Images not in RGB mode (e.g. RGBA, grayscale or
            palette images) are converted to RGB so the result always has
            three channels.

    Returns:
        dict with
          - "src_img": float32 array of shape (208, 368, 3) in [-1, 1],
          - "K" / "K_inv": float32 3x3 intrinsics matrix and its inverse,
          - "dst_img", "src_points": dummy entries not used during inference.
    """
    size = [208, 368]  # (height, width)
    target_ar = size[1] / size[0]  # width / height

    if im.mode != "RGB":
        # np.array() on RGBA / L / P images gives 4 or 1 channels.
        im = im.convert("RGB")

    w, h = im.size
    if np.abs(w / h - target_ar) > 0.1:
        print(f"Center cropping image to AR {size[1]/size[0]}")
        if w / h < target_ar:
            # Too tall: keep the full width, crop the height around the center.
            half_h = size[0] / size[1] * w / 2
            left, right = 0, w
            top, bottom = h / 2 - half_h, h / 2 + half_h
        else:
            # Too wide: keep the full height, crop the width around the center.
            # half_w is HALF the target width, mirroring the branch above.
            half_w = size[1] / size[0] * h / 2
            top, bottom = 0, h
            left, right = w / 2 - half_w, w / 2 + half_w
        im = im.crop(box=(left, top, right, bottom))

    im = im.resize((size[1], size[0]), resample=Image.LANCZOS)
    im = np.array(im) / 127.5 - 1.0
    im = im.astype(np.float32)

    K = np.array([[184.0, 0.0, 184.0],
                  [0.0, 184.0, 104.0],
                  [0.0, 0.0, 1.0]], dtype=np.float32)
    example = {
        "src_img": im,
        "K": K,
        "K_inv": np.linalg.inv(K),
        # dummy data not used during inference
        "dst_img": np.zeros_like(im),
        "src_points": np.zeros((1, 3), dtype=np.float32),
    }
    return example

def load_as_example(path):
    """Open the image file at *path* and convert it with ``load_im_as_example``.

    The image file is opened in a ``with`` block so its handle is closed once
    the example arrays have been computed.
    """
    with Image.open(path) as im:
        return load_im_as_example(im)

Define helper functions for camera control.

In [None]:
def normalize(x):
    """Return *x* scaled to unit Euclidean length (a zero vector yields NaNs)."""
    length = np.linalg.norm(x)
    return x / length

def cosd(x):
    """Cosine of *x*, where *x* is given in degrees."""
    radians = np.deg2rad(x)
    return np.cos(radians)

def sind(x):
    """Sine of *x*, where *x* is given in degrees."""
    radians = np.deg2rad(x)
    return np.sin(radians)

def look_to(camera_pos, camera_dir, camera_up):
    """Build world-to-camera extrinsics for a camera looking along *camera_dir*.

    The rotation rows are, in order: the normalized right axis
    (camera_up x camera_dir), the normalized camera_dir x right, and the
    normalized viewing direction. The result maps a world point X to
    R @ X + t, i.e. R @ (X - camera_pos).

    Returns:
        (R, t): 3x3 rotation and length-3 translation as numpy arrays.
    """
    right_axis = normalize(np.cross(camera_up, camera_dir))
    axes = (
        normalize(right_axis),
        normalize(np.cross(camera_dir, right_axis)),
        normalize(camera_dir),
    )
    rotation = np.zeros((4, 4))
    for row, axis in enumerate(axes):
        rotation[row, 0:3] = axis
    rotation[3, 3] = 1

    # Homogeneous translation that moves the camera center to the origin.
    translation = np.eye(4)
    translation[:3, 3] = [-camera_pos[0], -camera_pos[1], -camera_pos[2]]

    extrinsics = rotation @ translation
    return extrinsics[:3, :3], extrinsics[:3, 3]

def rotate_around_axis(angle, axis):
    """Return the 3x3 matrix rotating by *angle* degrees about *axis*.

    Rodrigues' rotation formula written out element-wise. *axis* is
    normalized first, so it need not be a unit vector.
    """
    unit = normalize(axis)
    ux, uy, uz = unit[0], unit[1], unit[2]
    c = cosd(angle)
    s = sind(angle)
    versine = 1 - c
    return np.array([
        [c + ux**2 * versine, ux * uy * versine - uz * s, ux * uz * versine + uy * s],
        [uy * ux * versine + uz * s, c + uy**2 * versine, uy * uz * versine - ux * s],
        [uz * ux * versine - uy * s, uz * uy * versine + ux * s, c + uz**2 * versine],
    ])

Forward splatting of an image given its depth.

In [None]:
import torch
from splatting import splatting_function

def render_forward(src_ims, src_dms,
                   Rcam, tcam,
                   K_src,
                   K_dst):
    """Forward-warp (splat) source images into a new camera view using depth.

    Every source pixel is unprojected with its depth, transformed by the
    relative pose (Rcam, tcam), reprojected with K_dst and splatted to its new
    location. Overlapping contributions are blended with weights that favour
    smaller destination depth.

    Args:
        src_ims: 4-D tensor indexed (b, h, w, c) (see the permute below).
        src_dms: 3-D depth tensor (b, h, w); must match src_ims spatially.
        Rcam: 3x3 rotation tensor (implied by the matmul shapes); moved to
            src_ims.device.
        tcam: length-3 translation tensor; moved to src_ims.device.
        K_src: 3x3 intrinsics of the source view (inverted here).
        K_dst: 3x3 intrinsics of the destination view.

    Returns:
        Rendered image of shape (h, w, c). The splatted results are summed
        over the batch dimension before normalization, so this is only
        meaningful for a batch of one (which is what Looper passes).
    """
    # Add a leading batch dimension to the pose.
    Rcam = Rcam.to(device=src_ims.device)[None]
    tcam = tcam.to(device=src_ims.device)[None]

    R = Rcam
    t = tcam[...,None]  # column vector
    K_src_inv = K_src.inverse()

    assert len(src_ims.shape) == 4
    assert len(src_dms.shape) == 3
    assert src_ims.shape[1:3] == src_dms.shape[1:3], (src_ims.shape,
                                                      src_dms.shape)

    # Homogeneous pixel grid (x, y, 1) for every source pixel.
    x = np.arange(src_ims[0].shape[1])
    y = np.arange(src_ims[0].shape[0])
    coord = np.stack(np.meshgrid(x,y), -1)
    coord = np.concatenate((coord, np.ones_like(coord)[:,:,[0]]), -1) # z=1
    coord = coord.astype(np.float32)
    coord = torch.as_tensor(coord, dtype=K_src.dtype, device=K_src.device)
    coord = coord[None] # bs, h, w, 3

    # Per-pixel depth, broadcast against the 3x3 / 3x1 matrices below.
    D = src_dms[:,:,:,None,None]

    # K_dst @ (R @ (D * K_src^-1 @ [x, y, 1]^T) + t):
    # unproject with depth, apply the relative pose, reproject.
    points = K_dst[None,None,None,...]@(R[:,None,None,...]@(D*K_src_inv[None,None,None,...]@coord[:,:,:,:,None])+t[:,None,None,:,:])
    points = points.squeeze(-1)

    # Depth in the destination camera, before perspective division.
    new_z = points[:,:,:,[2]].clone().permute(0,3,1,2) # b,1,h,w
    # Perspective division. z is clamped to >= 1e-8, so points at or behind
    # the camera plane are presumably thrown far outside the image.
    points = points/torch.clamp(points[:,:,:,[2]], 1e-8, None)

    src_ims = src_ims.permute(0,3,1,2)
    # Per-pixel displacement (x, y) from source to destination pixel coords.
    flow = points - coord
    flow = flow.permute(0,3,1,2)[:,:2,...]

    # Soft z-buffer weights: 1/z (nearer -> larger), min-max normalized per
    # sample to [0, 1], mapped to [-10, 0] and exponentiated, giving weights
    # in roughly [e^-10, 1]. alpha largely cancels out in the normalization
    # (only the 1e-6 epsilon is not scaled by it).
    alpha = 0.5
    importance = alpha/new_z
    importance_min = importance.amin((1,2,3),keepdim=True)
    importance_max = importance.amax((1,2,3),keepdim=True)
    importance=(importance-importance_min)/(importance_max-importance_min+1e-6)*10-10
    importance = importance.exp()

    # Last channel carries the weight itself so the splatted sum can be
    # normalized below; "summation" presumably accumulates every source pixel
    # at its flowed target location.
    input_data = torch.cat([importance*src_ims, importance], 1)
    output_data = splatting_function("summation", input_data, flow)

    # Weighted average: sum(weight * color) / sum(weight). Note dim=0 sums
    # over the batch. (``nom`` holds the denominator.)
    num = torch.sum(output_data[:,:-1,:,:], dim=0, keepdim=True)
    nom = torch.sum(output_data[:,-1:,:,:], dim=0, keepdim=True)

    rendered = num/(nom+1e-7)
    rendered = rendered.permute(0,2,3,1)[0,...]
    return rendered

Helper class to render with GeoGPT.

In [None]:
from geofree import pretrained_models
from torch.utils.data.dataloader import default_collate

class Renderer(object):
    """Progressive GeoGPT renderer that advances one token per call.

    ``init`` encodes a start image plus the camera conditioning. Each
    ``__call__`` then resamples the next token, decodes the whole current
    token sequence and returns the decoded image, so intermediate results can
    be displayed.
    """

    def __init__(self, model, device, *, top_k=250, temperature=1.0):
        """Load the pretrained model *model* and move it to *device*.

        Args:
            model: checkpoint name passed to ``pretrained_models``.
            device: torch device for the model.
            top_k: top-k cutoff used when sampling each token (250, as
                previously hard-coded).
            temperature: sampling temperature (1.0, as previously
                hard-coded).
        """
        self.model = pretrained_models(model=model)
        self.model = self.model.to(device=device)
        self.top_k = top_k
        self.temperature = temperature
        self._active = False

    def init(self,
             start_im,
             example,
             show_R,
             show_t):
        """Prepare sampling of a new view.

        Args:
            start_im: start image tensor indexed (h, w, c); it is permuted to
                (1, c, h, w) and encoded as the initial token sequence.
            example: example dict (as built by ``load_im_as_example``),
                collated into a batch of one.
            show_R, show_t: relative camera rotation / translation tensors,
                stored in the batch as "R_rel" / "t_rel".
        """
        self._active = True
        self.step = 0

        batch = self.batch = default_collate([example])
        batch["R_rel"] = show_R[None,...]
        batch["t_rel"] = show_t[None,...]

        _, cdict, edict = self.model.get_xce(batch)
        for k in cdict:
            cdict[k] = cdict[k].to(device=self.model.device)
        for k in edict:
            edict[k] = edict[k].to(device=self.model.device)

        # Only the conditioning indices and embeddings are needed for sampling.
        _, _, dc_indices, embeddings = self.model.get_normalized_c(cdict, edict, fixed_scale=True)

        start_im = start_im[None,...].to(self.model.device).permute(0,3,1,2)
        quant_c, c_indices = self.model.encode_to_c(c=start_im)
        cond_rec = self.model.cond_stage_model.decode(quant_c)

        self.current_im = cond_rec.permute(0,2,3,1)[0]
        self.current_sample = c_indices

        self.quant_c = quant_c  # kept for its shape when decoding
        # for sampling
        self.dc_indices = dc_indices
        self.embeddings = embeddings

    def __call__(self):
        """Sample one more token (if any remain) and return the current image.

        Returns:
            Decoded image indexed (h, w, c). Once all tokens have been
            sampled, ``active()`` becomes False and the last image is
            returned unchanged on further calls.
        """
        if self.step < self.current_sample.shape[1]:
            # Keep the tokens sampled so far and draw exactly one new token.
            z_start_indices = self.current_sample[:, :self.step]
            index_sample = self.model.sample(z_start_indices, self.dc_indices,
                                             steps=1,
                                             temperature=self.temperature,
                                             sample=True,
                                             top_k=self.top_k,
                                             callback=lambda k: None,
                                             embeddings=self.embeddings)
            # Newly sampled prefix followed by the not-yet-resampled tokens
            # that came from encoding the start image.
            self.current_sample = torch.cat((index_sample,
                                             self.current_sample[:, self.step+1:]),
                                            dim=1)

            sample_dec = self.model.decode_to_img(self.current_sample,
                                                  self.quant_c.shape)
            self.current_im = sample_dec.permute(0,2,3,1)[0]
            self.step += 1

        if self.step >= self.current_sample.shape[1]:
            self._active = False

        return self.current_im

    def active(self):
        """Return True while sampling started by ``init`` is still in progress."""
        return self._active

    def reconstruct(self, x):
        """Encode and decode *x* with the conditioning stage model.

        Args:
            x: batch tensor indexed (b, h, w, c).

        Returns:
            Reconstruction in the same (b, h, w, c) layout.
        """
        x = x.to(self.model.device).permute(0,3,1,2)
        quant_c, c_indices = self.model.encode_to_c(c=x)
        x_rec = self.model.cond_stage_model.decode(quant_c)
        return x_rec.permute(0,2,3,1)

Everything is defined. Load the included example image.

In [None]:
import importlib.resources as pkg_resources
# Load the sample image "artist.jpg" shipped in the geofree.examples package
# as the initial example.
with pkg_resources.path("geofree.examples", "artist.jpg") as path:
  example = load_as_example(path)

Initialize models.

In [None]:
from geofree.modules.warp.midas import Midas

# Pretrained GeoGPT checkpoint name (passed to Renderer below).
model = "re_impl_depth"
# Inference only: turn off autograd globally.
torch.set_grad_enabled(False)

# ``device`` is also read as a global by the Looper backend defined later.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    print("Warning: Running on CPU---sampling might take a while...")
    device = torch.device("cpu")
# MiDaS supplies the depth used for the warped preview.
midas = Midas().eval().to(device)
renderer = Renderer(model=model, device=device)

Backend for the interface.

In [None]:
class Looper(object):
    """Interactive backend: camera state, depth-splatted preview and GeoGPT rendering.

    Relies on the module-level ``device`` and on ``render_forward``,
    ``look_to``, ``rotate_around_axis``, ``normalize``, ``cosd`` and ``sind``
    from earlier cells.
    """

    def __init__(self, midas, renderer, example):
        self.midas = midas
        self.renderer = renderer
        self.RENDERING = False
        self.init_example(example)

    def init_example(self, example):
        """Use *example* as the source view: estimate its depth and reset the camera."""
        self.example = example

        ims = example["src_img"][None,...]  # add batch dimension
        K = example["K"]

        # Depth for the preview: MiDaS returns a scaled inverse depth,
        # which is inverted here.
        dms = []
        for im in ims:
            midas_in = torch.tensor(im)[None,...].permute(0,3,1,2).to(device)
            scaled_idepth = self.midas.fixed_scale_depth(midas_in, return_inverse_depth=True)
            dms.append(1.0/scaled_idepth[0].cpu().numpy())
        dms = np.stack(dms)

        self.src_ims = torch.tensor(ims, dtype=torch.float32).to(device=device)
        self.src_dms = torch.tensor(dms, dtype=torch.float32).to(device=device)
        self.K = torch.tensor(K, dtype=torch.float32).to(device=device)

        self.init_cam()

    def init_cam(self):
        """Place the camera at the origin looking along +z with +y as up."""
        self.camera_pos = np.array([0.0, 0.0, 0.0])
        self.camera_dir = np.array([0.0, 0.0, 1.0])
        self.camera_up = np.array([0.0, 1.0, 0.0])
        self.CAM_SPEED = 0.25
        self.MOUSE_SENSITIVITY = 10.0

    def update_camera(self, keys):
        """Apply movement/look input from *keys* and recompute show_R / show_t.

        Args:
            keys: dict with boolean "w", "a", "s", "d", "q", "e" entries and an
                optional "look" (dx, dy) pair in [-1, 1].
        """
        ######### Camera
        if keys["a"]:
            self.camera_pos += self.CAM_SPEED*normalize(np.cross(self.camera_dir, self.camera_up))
        if keys["d"]:
            self.camera_pos -= self.CAM_SPEED*normalize(np.cross(self.camera_dir, self.camera_up))
        if keys["w"]:
            self.camera_pos += self.CAM_SPEED*normalize(self.camera_dir)
        if keys["s"]:
            self.camera_pos -= self.CAM_SPEED*normalize(self.camera_dir)
        if keys["q"]:
            self.camera_pos -= self.CAM_SPEED*normalize(self.camera_up)
        if keys["e"]:
            self.camera_pos += self.CAM_SPEED*normalize(self.camera_up)

        camera_yaw = 0
        camera_pitch = 0
        if "look" in keys:
            dx, dy = keys["look"]
            if not self.RENDERING:
                camera_yaw -= self.MOUSE_SENSITIVITY*dx
                camera_pitch += self.MOUSE_SENSITIVITY*dy

        # Yaw: rotate the viewing direction about the world y axis.
        rotation = np.array([[cosd(-camera_yaw), 0.0, sind(-camera_yaw)],
                             [0.0, 1.0, 0.0],
                             [-sind(-camera_yaw), 0.0, cosd(-camera_yaw)]])
        self.camera_dir = rotation@self.camera_dir

        # Pitch: rotate about the axis perpendicular to dir and up.
        rotation = rotate_around_axis(camera_pitch, np.cross(self.camera_dir,
                                                             self.camera_up))
        self.camera_dir = rotation@self.camera_dir

        show_R, show_t = look_to(self.camera_pos, self.camera_dir, self.camera_up) # look from pos in direction dir
        self.show_R = torch.as_tensor(show_R, dtype=torch.float32)
        self.show_t = torch.as_tensor(show_t, dtype=torch.float32)

    def update(self, keys):
        """Process one UI event and return the frame to display.

        While GeoGPT is rendering, the camera is frozen and render requests are
        ignored. Previously, pressing Render during rendering referenced an
        unassigned preview image.

        Args:
            keys: dict as for ``update_camera`` plus boolean "render" and
                "stop" entries.

        Returns:
            (image, rendering): current frame indexed (h, w, c), and whether
            GeoGPT sampling is still in progress.
        """
        if not self.RENDERING:
            self.update_camera(keys)
            wrp_im = render_forward(self.src_ims, self.src_dms,
                                    self.show_R, self.show_t,
                                    K_src=self.K,
                                    K_dst=self.K)
            if keys["render"]:
                self.RENDERING = True
                self.renderer.init(wrp_im, self.example, self.show_R, self.show_t)

        if self.RENDERING:
            wrp_im = self.renderer()

        if not self.renderer.active() or keys["stop"]:
            self.RENDERING = False

        return wrp_im, self.RENDERING

Frontend inspired by [Infinite Nature](https://infinite-nature.github.io/).

In [None]:
import IPython
from google.colab import output, files
import base64
from io import BytesIO

# Single backend instance used by the pyloop callback.
looper = Looper(midas, renderer, example)

def as_png(x):
    """Encode an image with values in [-1, 1] as a base64 PNG string.

    Args:
        x: image array or tensor indexed (h, w[, c]) with values in [-1, 1].
            Tensors are detached and moved to the CPU first.

    Returns:
        The base64-encoded PNG as a ``str``.
    """
    if hasattr(x, "detach"):
        x = x.detach().cpu().numpy()
    x = (x + 1.0) * 127.5
    x = x.clip(0, 255).astype(np.uint8)
    # Only ``BytesIO`` is imported in this cell (the ``io`` module is not).
    buffer = BytesIO()
    Image.fromarray(x).save(buffer, format="png")
    return base64.b64encode(buffer.getvalue()).decode()

def pyloop(data):
    """Colab callback: apply one front-end event and return the next frame.

    Args:
        data: dict sent by the JavaScript front-end. Recognised keys:
            "upload" (bool), "look" ([x, y] click position in [0, 1]),
            "direction" (name of the pressed button) and "stop" (bool).

    Returns:
        ``IPython.display.JSON`` with the base64 PNG under "image" and under
        "loop" whether the front-end should keep polling.
    """
    if data.get("upload", False):
        # Keep the uploaded-files dict separate from the request: the event
        # keys below must still be read from ``data``.
        uploaded = files.upload()
        if uploaded:  # empty when the user cancels the dialog
            fname = sorted(uploaded.keys())[0]
            im = Image.open(BytesIO(uploaded[fname]))
            looper.init_example(load_im_as_example(im))

    keys = dict()
    if "look" in data:
        # Map the click position from [0, 1] to [-1, 1].
        keys["look"] = np.array(data["look"]) * 2.0 - 1.0
    move = data.get("direction", None)
    keys["w"] = move == "forward"
    keys["a"] = move == "left"
    keys["s"] = move == "backward"
    keys["d"] = move == "right"
    keys["q"] = move == "up"
    keys["e"] = move == "down"
    keys["render"] = move == "render"
    keys["stop"] = data.get("stop", False)
    # ``frame`` rather than ``output`` so google.colab.output is not shadowed.
    frame, rendering = looper.update(keys)

    return IPython.display.JSON({"image": as_png(frame), "loop": rendering})

# Expose pyloop to the front-end, which calls it via
# google.colab.kernel.invokeFunction('pyloop', ...).
output.register_callback('pyloop', pyloop)

# The front-end for our interactive demo.

html = '''
<style>
#view {
  width: 368px;
  height: 208px;
  background-color: #aaa;
  background-size: 100% 100%;
  border: 1px solid #000;
  margin: 20px;
  position: relative;
}
.buttons {
  margin: 20px;
}
.buttons div {
  display: inline-block;
  cursor: pointer;
  padding: 20px;
  background: #eee;
  border: 2px solid #aaa;
  border-radius: 3px;
  margin-right: 10px;
  font-weight: bold;
  text-transform: uppercase;
  letter-spacing: 1px;
  color: #444;
  width: 100px;
  text-align: center;
}
.buttons div:active {
  background: #444;
  color: #fff;
}
#rgb {
  height: 100%;
}
h3 {
  margin-left: 20px;
}
</style>
<h3>Braindance Colab Demo</h3>
<div id=view><img id=rgb></div>
<div class=buttons>
<div id=up>Up</div><div id=forward>Forward</div><div id=down>Down</div><br>
<div id=left>Left</div><div id=backward>Backward</div><div id=right>Right</div><br>
<div id=render>Render</div><div id=stop>Stop</div><div id=upload>Upload</div>
</div>
<script>
// Set by the Stop button; checked after each round trip to the kernel.
var stop_rendering = false;

// Send one event to the Python `pyloop` callback, show the returned frame and
// keep polling while the backend reports that rendering is in progress.
async function loop(...parms) {
  const response = await google.colab.kernel.invokeFunction('pyloop', parms, {});
  const result = response.data['application/json'];
  document.querySelector('#rgb').src = `data:image/png;base64,${result['image']}`;
  if (stop_rendering) {
    result['loop'] = false;
    await google.colab.kernel.invokeFunction('pyloop', [{"stop": true}], {});
  }
  stop_rendering = false;
  if (result['loop']) {
    loop({});
  }
}

// Click position normalised to [0, 1] along both axes of the view.
function cursor(e) {
  const x = e.offsetX / e.target.clientWidth;
  const y = e.offsetY / e.target.clientHeight;
  loop({"look": [x, y]});
}

function move(direction) {
  loop({"direction": direction});
}

loop({});
document.querySelector('#view').addEventListener('click', cursor);
document.querySelector('#up').addEventListener('click', () => move("up"));
document.querySelector('#forward').addEventListener('click', () => move("forward"));
document.querySelector('#down').addEventListener('click', () => move("down"));
document.querySelector('#left').addEventListener('click', () => move("left"));
document.querySelector('#backward').addEventListener('click', () => move("backward"));
document.querySelector('#right').addEventListener('click', () => move("right"));
document.querySelector('#render').addEventListener('click', () => move("render"));
document.querySelector('#stop').addEventListener('click', () => {stop_rendering = true;});
document.querySelector('#upload').addEventListener('click', () => loop({"upload": true}));
</script>
'''

# Render the interactive front-end in the cell output.
IPython.display.display(IPython.display.HTML(html))

Click on the image to look around, and use the first six buttons to move the camera. Click Render to start rendering the novel view with GeoGPT, Stop to abort rendering and regain control, and Upload to use your own image. Have fun!