# Neural scene representation and rendering
### https://deepmind.com/blog/neural-scene-representation-and-rendering

Datasets: https://github.com/deepmind/gqn-datasets

Dataset translator: https://github.com/l3robot/gqn_datasets_translator

In [1]:
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from torchvision.utils import make_grid, save_image

import os
import datetime
import random
import math
from tensorboardX import SummaryWriter

from pixyz.distributions import Normal
from pixyz.losses import NLL, KullbackLeibler

from tqdm import tqdm

from gqn_dataset import GQNDataset, Scene, transform_viewpoint
from conv_lstm import Conv2dLSTMCell

In [2]:
class Pyramid(nn.Module):
    """Pyramid scene-representation network.

    The viewpoint is broadcast over the image plane and concatenated with the
    image as extra channels. Four strided convolutions, each followed by a
    ReLU like the convolutions in ``Tower`` and ``Pool``, then reduce the
    64x64 input to a 1x1 map with 256 channels. That vector is tiled to 16x16
    so the output has the ``(B, 256, 16, 16)`` shape the rest of the model
    uses for ``r``.
    """

    def __init__(self):
        super(Pyramid, self).__init__()
        self.conv1 = nn.Conv2d(7+3, 32, kernel_size=2, stride=2)      # 64 -> 32
        self.conv2 = nn.Conv2d(32, 64, kernel_size=2, stride=2)       # 32 -> 16
        self.conv3 = nn.Conv2d(64, 128, kernel_size=2, stride=2)      # 16 -> 8
        self.conv4 = nn.Conv2d(128, 256, kernel_size=8, stride=8)     # 8 -> 1

    def forward(self, x, v):
        """Encode one context view.

        Args:
            x: images; assumed ``(B, 3, 64, 64)`` since ``v`` is broadcast to
                64x64 and concatenated with it.
            v: viewpoints, reshaped here to ``(B, 7, 1, 1)``.

        Returns:
            Tensor of shape ``(B, 256, 16, 16)`` (the same 256-d vector at
            every spatial position).
        """
        # Broadcast the viewpoint over the image plane.
        v = v.view(-1, 7, 1, 1).repeat(1, 1, 64, 64)

        # Concatenate along the channel axis; the default dim=0 would stack
        # along the batch axis and fail because the channel counts differ.
        h = torch.cat((v, x), dim=1)
        h = F.relu(self.conv1(h))
        h = F.relu(self.conv2(h))
        h = F.relu(self.conv3(h))
        h = F.relu(self.conv4(h))

        # Tile the 1x1 representation to 16x16 to match Tower/Pool output.
        r = h.repeat(1, 1, 16, 16)

        return r

In [3]:
class Tower(nn.Module):
    """Tower scene-representation network.

    Two residual stages.
    - The first stage sees only the image.
    - The second sees the image features concatenated with the viewpoint
      broadcast over the 16x16 feature grid.

    A final 1x1 convolution produces the 256-channel representation. The
    strides imply a 64x64 input image for the 16x16 broadcast to line up.
    """

    def __init__(self):
        super(Tower, self).__init__()
        # Stage 1: image only (64x64 -> 32x32 -> 16x16).
        self.conv1 = nn.Conv2d(3, 256, kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(256, 256, kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=2, stride=2)

        # Stage 2: image features plus the 7-d viewpoint, at 16x16.
        self.conv5 = nn.Conv2d(256+7, 256, kernel_size=3, stride=1, padding=1)
        self.conv6 = nn.Conv2d(256+7, 128, kernel_size=3, stride=1, padding=1)
        self.conv7 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.conv8 = nn.Conv2d(256, 256, kernel_size=1, stride=1)

    def forward(self, x, v):
        """Return the ``(B, 256, 16, 16)`` representation of one view."""
        # Stage 1 (residual): the shortcut branch is conv2, the main branch conv3 -> conv4.
        first = F.relu(self.conv1(x))
        shortcut = F.relu(self.conv2(first))
        features = F.relu(self.conv4(F.relu(self.conv3(first)))) + shortcut

        # Broadcast the viewpoint across the 16x16 grid and append as channels.
        batch_size = v.size(0)
        view_map = v.view(batch_size, 7, 1, 1).repeat(1, 1, 16, 16)
        joined = torch.cat((features, view_map), dim=1)

        # Stage 2 (residual), then a 1x1 projection.
        shortcut = F.relu(self.conv5(joined))
        features = F.relu(self.conv7(F.relu(self.conv6(joined)))) + shortcut
        return F.relu(self.conv8(features))

In [4]:
class Pool(nn.Module):
    """Tower network followed by global average pooling.

    The layers match the tower architecture. The final 16x16x256 map is
    averaged to a single 256-d vector, which is then tiled back to 16x16 so
    the output shape is ``(B, 256, 16, 16)``.
    """

    def __init__(self):
        super(Pool, self).__init__()
        # Stage 1: image only.
        self.conv1 = nn.Conv2d(3, 256, kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(256, 256, kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=2, stride=2)

        # Stage 2: image features plus the 7-d viewpoint.
        self.conv5 = nn.Conv2d(256+7, 256, kernel_size=3, stride=1, padding=1)
        self.conv6 = nn.Conv2d(256+7, 128, kernel_size=3, stride=1, padding=1)
        self.conv7 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.conv8 = nn.Conv2d(256, 256, kernel_size=1, stride=1)

        # 16x16 average pool collapses the feature grid to 1x1.
        self.pool  = nn.AvgPool2d(16)

    def forward(self, x, v):
        """Return the pooled, re-tiled ``(B, 256, 16, 16)`` representation."""
        # Stage 1 (residual).
        h1 = F.relu(self.conv1(x))
        residual = F.relu(self.conv2(h1))
        h = F.relu(self.conv3(h1))
        h = F.relu(self.conv4(h)) + residual

        # Append the viewpoint, broadcast over the 16x16 grid, as channels.
        n = v.size(0)
        h = torch.cat((h, v.view(n, 7, 1, 1).repeat(1, 1, 16, 16)), dim=1)

        # Stage 2 (residual) and 1x1 projection.
        residual = F.relu(self.conv5(h))
        h = F.relu(self.conv7(F.relu(self.conv6(h)))) + residual
        h = F.relu(self.conv8(h))

        # Average over space, then tile the pooled vector back out to 16x16.
        return self.pool(h).repeat(1, 1, 16, 16)

In [5]:
class InferenceCore(nn.Module):
    """One recurrent step of the inference (posterior) ConvLSTM.

    The cell input is the channel-wise concatenation of:
    - the downsampled query image (3 channels),
    - the broadcast query viewpoint (7),
    - the scene representation (256),
    - the generator hidden state (128),
    - the downsampled canvas ``u`` (128).
    """

    def __init__(self):
        super(InferenceCore, self).__init__()
        # 4x4 stride-4 convolutions reduce spatial size by 4 (e.g. 64x64 -> 16x16)
        # so the image and canvas line up with the 16x16 recurrent state.
        self.downsample_x = nn.Conv2d(3, 3, kernel_size=4, stride=4, padding=0)
        self.downsample_u = nn.Conv2d(128, 128, kernel_size=4, stride=4, padding=0)
        self.core = Conv2dLSTMCell(3+7+256+2*128, 128, kernel_size=5, stride=1, padding=2)

    def forward(self, x, v, r, c_e, h_e, h_g, u):
        """Advance the inference state by one step and return ``(c_e, h_e)``.

        NOTE(review): assumes ``Conv2dLSTMCell`` takes ``(input, (c, h))`` and
        returns ``(c, h)`` in that order — confirm against conv_lstm.
        """
        view_map = v.view(-1, 7, 1, 1).repeat(1, 1, 16, 16)
        small_x = self.downsample_x(x)
        small_u = self.downsample_u(u)
        lstm_input = torch.cat((small_x, view_map, r, h_g, small_u), dim=1)
        next_c, next_h = self.core(lstm_input, (c_e, h_e))
        return next_c, next_h
    
class GeneratorCore(nn.Module):
    """One recurrent step of the generator ConvLSTM.

    The cell sees the broadcast query viewpoint (7 channels), the scene
    representation (256) and the latent sample ``z`` (3). Its hidden state is
    upsampled 4x by a stride-4 transposed convolution (16x16 -> 64x64) and
    added to the running canvas ``u``.
    """

    def __init__(self):
        super(GeneratorCore, self).__init__()
        self.core = Conv2dLSTMCell(7+256+3, 128, kernel_size=5, stride=1, padding=2)
        self.upsample = nn.ConvTranspose2d(128, 128, kernel_size=4, stride=4, padding=0)

    def forward(self, v, r, c_g, h_g, u, z):
        """Advance the generator by one step and return ``(c_g, h_g, u)``.

        NOTE(review): assumes ``Conv2dLSTMCell`` returns ``(c, h)`` in that
        order — confirm against conv_lstm.
        """
        view_map = v.view(-1, 7, 1, 1).repeat(1, 1, 16, 16)
        next_c, next_h = self.core(torch.cat((view_map, r, z), dim=1), (c_g, h_g))
        # Residual canvas update: the upsampled hidden state is added to u.
        next_u = u + self.upsample(next_h)
        return next_c, next_h, next_u

In [6]:
class Inference(Normal):
    """Posterior q(z | h_e): a diagonal Gaussian over a 3-channel latent map.

    A 5x5 convolution maps the inference hidden state to 6 channels. The
    first 3 are the mean and the last 3 the log-variance.
    """

    def __init__(self):
        super(Inference, self).__init__(cond_var=["h_e"], var=["z"])
        self.eta_e = nn.Conv2d(128, 2*3, kernel_size=5, stride=1, padding=2)

    def forward(self, h_e):
        stats = self.eta_e(h_e)
        loc, log_var = stats[:, :3], stats[:, 3:]
        # std = exp(log_var / 2)
        return {"loc": loc, "scale": (0.5 * log_var).exp()}
    
class Prior(Normal):
    """Prior pi(z | h_g): a diagonal Gaussian conditioned on the generator state.

    A 5x5 convolution yields 6 channels, read as 3 mean channels followed by
    3 log-variance channels.
    """

    def __init__(self):
        super(Prior, self).__init__(cond_var=["h_g"], var=["z"])
        self.eta_pi = nn.Conv2d(128, 2*3, kernel_size=5, stride=1, padding=2)

    def forward(self, h_g):
        loc, log_var = torch.chunk(self.eta_pi(h_g), 2, dim=1)
        scale = torch.exp(log_var * 0.5)
        return {"loc": loc, "scale": scale}
    
class Generator(Normal):
    """Observation model p(x_q | u, sigma): Gaussian with a fixed pixel scale.

    A 1x1 convolution maps the 128-channel canvas ``u`` to a 3-channel mean.
    The scale is the conditioning variable ``sigma``, passed through unchanged.
    """

    def __init__(self):
        super(Generator, self).__init__(cond_var=["u", "sigma"], var=["x_q"])
        self.eta_g = nn.Conv2d(128, 3, kernel_size=1, stride=1, padding=0)

    def forward(self, u, sigma):
        return {"loc": self.eta_g(u), "scale": sigma}

In [7]:
class GQN(nn.Module):
    """Generative Query Network.

    Context views are encoded by ``phi`` and summed into a scene
    representation ``r`` of shape ``(B, 256, 16, 16)``. An ``L``-step
    recurrent generator, conditioned on ``r`` and the query viewpoint, builds
    up a 128-channel 64x64 canvas ``u``. ``g`` then maps that canvas to the
    predicted query image.

    Args:
        representation: ``"pyramid"``, ``"tower"`` or ``"pool"``.
        L: number of generative steps.
        shared_core: if True, a single inference core and a single generator
            core are reused at every step; otherwise each step has its own.

    Raises:
        ValueError: if ``representation`` is not one of the supported names.
    """

    def __init__(self, representation="pool", L=12, shared_core=False):
        super(GQN, self).__init__()

        # Number of generative layers
        self.L = L

        self.shared_core = shared_core

        # Representation network. Reject unknown names here; otherwise
        # self.phi would be left unset and the first forward pass would fail
        # with an opaque AttributeError.
        if representation == "pyramid":
            self.phi = Pyramid()
        elif representation == "tower":
            self.phi = Tower()
        elif representation == "pool":
            self.phi = Pool()
        else:
            raise ValueError(
                "unknown representation {!r}; expected 'pyramid', 'tower' or 'pool'".format(representation))

        # Generation network
        if shared_core:
            self.inference_core = InferenceCore()
            self.generator_core = GeneratorCore()
        else:
            self.inference_core = nn.ModuleList([InferenceCore() for _ in range(L)])
            self.generator_core = nn.ModuleList([GeneratorCore() for _ in range(L)])

        # Distribution
        self.pi = Prior()
        self.q = Inference()
        self.g = Generator()

    def _inference_core_at(self, l):
        """Return the inference core used at generative step ``l``."""
        return self.inference_core if self.shared_core else self.inference_core[l]

    def _generator_core_at(self, l):
        """Return the generator core used at generative step ``l``."""
        return self.generator_core if self.shared_core else self.generator_core[l]

    def _encode_scene(self, x, v):
        """Sum ``phi`` over the M context views.

        ``x`` and ``v`` are indexed as ``x[:, k]`` / ``v[:, k]`` for each of the
        ``M = x.size(1)`` views. Returns a ``(B, 256, 16, 16)`` tensor.
        """
        B, M, *_ = x.size()
        r = x.new_zeros((B, 256, 16, 16))
        for k in range(M):
            r += self.phi(x[:, k], v[:, k])
        return r

    def _initial_generator_state(self, x):
        """Return zero ``(c_g, h_g, u)`` matching ``x``'s batch size, dtype and device."""
        B = x.size(0)
        c_g = x.new_zeros((B, 128, 16, 16))
        h_g = x.new_zeros((B, 128, 16, 16))
        u = x.new_zeros((B, 128, 64, 64))
        return c_g, h_g, u

    def _posterior_rollout(self, x, v, v_q, x_q, compute_kl=True):
        """Run the generator for ``L`` steps, driven by posterior samples of ``z``.

        Returns:
            ``(kl, u)``: the KL(q || pi) terms summed over the steps (``0`` when
            ``compute_kl`` is False or ``L == 0``) and the final canvas ``u``.
        """
        r = self._encode_scene(x, v)
        c_g, h_g, u = self._initial_generator_state(x)

        # Inference initial state
        B = x.size(0)
        c_e = x.new_zeros((B, 128, 16, 16))
        h_e = x.new_zeros((B, 128, 16, 16))

        kl = 0
        for l in range(self.L):
            # Inference state update
            c_e, h_e = self._inference_core_at(l)(x_q, v_q, r, c_e, h_e, h_g, u)

            # Posterior sample (reparameterised so gradients flow through z)
            z = self.q.sample({"h_e": h_e}, reparam=True)["z"]

            # KL between posterior and prior at this step; h_g is still the
            # state *before* this step's generator update.
            if compute_kl:
                kl += KullbackLeibler(self.q, self.pi).estimate({"h_e": h_e, "h_g": h_g})

            # Generator state update
            c_g, h_g, u = self._generator_core_at(l)(v_q, r, c_g, h_g, u, z)

        return kl, u

    # EstimateELBO
    def forward(self, x, v, v_q, x_q, sigma):
        """Estimate the ELBO for query image ``x_q`` at viewpoint ``v_q``.

        Returns ``-(sum of per-step KL terms) - NLL(x_q | u, sigma)``. The
        shape is whatever pixyz ``estimate`` returns; the training loop calls
        ``.mean()`` on it, presumably over the batch.
        """
        kl, u = self._posterior_rollout(x, v, v_q, x_q)

        # ELBO likelihood contribution
        nll = NLL(self.g).estimate({"u": u, "sigma": sigma, "x_q": x_q})

        return -kl - nll

    def generate(self, x, v, v_q):
        """Render the query view from prior samples; output clamped to [0, 1]."""
        r = self._encode_scene(x, v)
        c_g, h_g, u = self._initial_generator_state(x)

        for l in range(self.L):
            # Prior sample
            z = self.pi.sample({"h_g": h_g})["z"]

            # State update
            c_g, h_g, u = self._generator_core_at(l)(v_q, r, c_g, h_g, u, z)

        # Use the generator mean; sigma only fills the scale slot.
        x_q_hat = self.g.sample_mean({"u": u, "sigma": 0})

        return torch.clamp(x_q_hat, 0, 1)

    def kl_divergence(self, x, v, v_q, x_q):
        """Return the KL(q || pi) terms summed over all generative steps."""
        kl, _ = self._posterior_rollout(x, v, v_q, x_q)
        return kl

    def reconstruct(self, x, v, v_q, x_q):
        """Reconstruct ``x_q`` using posterior samples; output clamped to [0, 1]."""
        _, u = self._posterior_rollout(x, v, v_q, x_q, compute_kl=False)

        x_q_rec = self.g.sample_mean({"u": u, "sigma": 0})

        return torch.clamp(x_q_rec, 0, 1)

In [8]:
def sample_batch(x_data, v_data, D, M=None, seed=None):
    """Draw M context views and one query view from a batch of scenes.

    Args:
        x_data: images indexed as ``x_data[:, i]`` for view ``i``; the number
            of available views is ``x_data.size(1)``.
        v_data: viewpoints, indexed the same way as ``x_data``.
        D: dataset name. It selects the maximum context size K ("Room": 5,
            "Jaco": 7, "Labyrinth": 20, "Shepard-Metzler": 15) and is only
            consulted when ``M`` must be drawn.
        M: number of context views. If falsy (``None`` or ``0``) it is drawn
            uniformly from ``1..K``.
        seed: seed for a private ``random.Random``. ``None`` seeds from OS
            entropy. The global ``random`` state is left untouched; with a
            given seed the draws match what ``random.seed(seed)`` produced.

    Returns:
        ``(x, v, x_q, v_q)``:
        - ``x``, ``v``: the M context images and viewpoints (view axis kept).
        - ``x_q``, ``v_q``: the query image and viewpoint (view axis removed).

    Raises:
        ValueError: if ``M`` must be drawn and ``D`` is unknown, or (from
            ``random.sample``) if ``M`` exceeds the number of available views.
    """
    rng = random.Random(seed)

    # Sample number of views
    if not M:
        max_views = {"Room": 5, "Jaco": 7, "Labyrinth": 20, "Shepard-Metzler": 15}
        try:
            K = max_views[D]
        except KeyError:
            raise ValueError("unknown dataset {!r}; expected one of {}".format(
                D, sorted(max_views))) from None
        M = rng.randint(1, K)

    num_views = x_data.size(1)
    context_idx = rng.sample(range(num_views), M)
    query_idx = rng.randint(0, num_views - 1)

    # Sample view
    x, v = x_data[:, context_idx], v_data[:, context_idx]
    # Sample query view
    x_q, v_q = x_data[:, query_idx], v_data[:, query_idx]

    return x, v, x_q, v_q

In [9]:
# Learning rate at training step s with annealing:
#   mu_s = max(mu_f + (mu_i - mu_f) * (1 - s / n), mu_f)
class AnnealingStepLR(_LRScheduler):
    """Linearly anneal the learning rate from ``mu_i`` to ``mu_f`` over ``n`` steps.

    After ``n`` steps the rate stays at ``mu_f``. Every parameter group gets
    the same rate; the optimizer's own base learning rates are ignored.
    """

    def __init__(self, optimizer, mu_i=5e-4, mu_f=5e-5, n=1.6e6):
        # Set before calling the base constructor, which performs an initial
        # step that calls get_lr().
        self.mu_i = mu_i
        self.mu_f = mu_f
        self.n = n
        super(AnnealingStepLR, self).__init__(optimizer)

    def get_lr(self):
        progress = 1.0 - self.last_epoch / self.n
        lr = max(self.mu_f + (self.mu_i - self.mu_f) * progress, self.mu_f)
        return [lr] * len(self.base_lrs)

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Dataset directory
train_data_dir = '/workspace/dataset/mazes-torch/train'
test_data_dir = '/workspace/dataset/mazes-torch/test'

# Number of workers to load data
num_workers = 8

# Log
log_interval_num = 100
save_interval_num = 10000
dir_name = "gqn_mazes"
log_dir = '/workspace/logs/'+ dir_name
# os.mkdir raises FileExistsError if the run directory already exists
# (e.g. on a re-run); rename dir_name or remove the old directory first.
os.mkdir(log_dir)
os.mkdir(log_dir+'/models')
os.mkdir(log_dir+'/runs')

# TensorBoardX
writer = SummaryWriter(log_dir=log_dir+'/runs')

# Dataset
train_dataset = GQNDataset(root_dir=train_data_dir, target_transform=transform_viewpoint)
test_dataset = GQNDataset(root_dir=test_data_dir, target_transform=transform_viewpoint)
D = "Labyrinth"

# Pixel standard-deviation, annealed linearly from sigma_i to sigma_f over
# sigma_n steps (then held at sigma_f).
sigma_i, sigma_f = 2.0, 0.7
sigma_n = 2*10**5
sigma = sigma_i

# Number of scenes over which each weight update is computed
B = 36

# Maximum number of training steps
S_max = 2*10**6

# Define model
model = GQN().to(device)
# Default device_ids uses every visible GPU. Hard-coding [0, 1] fails on a
# single-GPU machine.
model = nn.DataParallel(model)

optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, betas=(0.9, 0.999), eps=1e-08)
scheduler = AnnealingStepLR(optimizer, mu_i=5e-4, mu_f=5e-5, n=1.6e6)

kwargs = {'num_workers':num_workers, 'pin_memory': True} if torch.cuda.is_available() else {}

train_loader = DataLoader(train_dataset, batch_size=B, shuffle=True, **kwargs)
test_loader = DataLoader(test_dataset, batch_size=B, shuffle=True, **kwargs)

train_iter = iter(train_loader)
# One fixed test batch, moved to the device once rather than at every log step.
x_data_test, v_data_test = next(iter(test_loader))
x_data_test = x_data_test.to(device)
v_data_test = v_data_test.to(device)

# Training Iterations
for t in tqdm(range(S_max)):
    # Pixel-variance annealing, evaluated before the forward pass so that
    # step t is trained (and logged) with sigma_t.
    sigma = max(sigma_f + (sigma_i - sigma_f)*(1 - t/sigma_n), sigma_f)

    try:
        x_data, v_data = next(train_iter)
    except StopIteration:
        # Loader exhausted: start another pass over the training set.
        train_iter = iter(train_loader)
        x_data, v_data = next(train_iter)

    x_data = x_data.to(device)
    v_data = v_data.to(device)
    x, v, x_q, v_q = sample_batch(x_data, v_data, D)
    elbo = model(x, v, v_q, x_q, sigma)

    # Compute empirical ELBO gradients
    loss = -elbo.mean()
    loss.backward()

    # Update parameters
    optimizer.step()
    optimizer.zero_grad()

    # Update optimizer state
    scheduler.step()

    # logs (.item() records a plain float rather than a graph-attached tensor)
    writer.add_scalar('train_loss', loss.item(), t)

    with torch.no_grad():
        # write logs to tensorboard
        if t % log_interval_num == 0:
            x_test, v_test, x_q_test, v_q_test = sample_batch(x_data_test, v_data_test, D, M=3, seed=0)
            elbo_test = model(x_test, v_test, v_q_test, x_q_test, sigma)
            kl_test = model.module.kl_divergence(x_test, v_test, v_q_test, x_q_test)
            x_q_rec_test = model.module.reconstruct(x_test, v_test, v_q_test, x_q_test)
            x_q_hat_test = model.module.generate(x_test, v_test, v_q_test)

            writer.add_scalar('test_loss', -elbo_test.mean().item(), t)
            writer.add_scalar('test_kl', kl_test.mean().item(), t)
            writer.add_image('test_ground_truth', make_grid(x_q_test, 6, pad_value=1), t)
            writer.add_image('test_reconstruction', make_grid(x_q_rec_test, 6, pad_value=1), t)
            writer.add_image('test_generation', make_grid(x_q_hat_test, 6, pad_value=1), t)

        if t % save_interval_num == 0:
            torch.save(model.state_dict(), log_dir + "/models/model-{}.pt".format(t))

torch.save(model.state_dict(), log_dir + "/models/model-final.pt")
writer.close()

  0%|          | 2047/2000000 [56:29<925:31:36,  1.67s/it]