In [1]:
from __future__ import print_function
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from tensorboardX import SummaryWriter
from Tars.distributions import Normal, Bernoulli
from Tars.distributions.divergences import KullbackLeibler
from Tars.models import VAE

from tqdm import tqdm

# Fix the PyTorch RNG seed so that repeated runs draw the same numbers.
seed = 1234
torch.manual_seed(seed)

# Run on the GPU whenever PyTorch can see one, otherwise stay on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
# utility
class Conv2dLSTMCell(nn.Module):
    """
    2d convolutional long short-term memory (LSTM) cell.
    Modelled on nn.LSTMCell, with the nn.Linear layers
    replaced by nn.Conv2d layers.

    Note that the four gate convolutions only see ``input``; the
    incoming hidden state is not fed to them. Any recurrent context
    therefore has to be concatenated into ``input`` by the caller.

    :param in_channels: number of input channels
    :param out_channels: number of output channels
    :param kernel_size: size of image kernel
    :param stride: length of kernel stride
    :param padding: number of pixels to pad with
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(Conv2dLSTMCell, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        # All four gates share the same convolution geometry.
        kwargs = dict(kernel_size=kernel_size, stride=stride, padding=padding)

        self.forget = nn.Conv2d(in_channels, out_channels, **kwargs)
        self.input  = nn.Conv2d(in_channels, out_channels, **kwargs)
        self.output = nn.Conv2d(in_channels, out_channels, **kwargs)
        self.state  = nn.Conv2d(in_channels, out_channels, **kwargs)

    def forward(self, input, states):
        """
        Send input through the cell.

        :param input: input to send through
        :param states: (hidden, cell) pair of internal state; only the
            cell entry is read, the hidden entry is replaced
        :return new (hidden, cell) pair
        """
        (_, cell) = states

        # torch.sigmoid / torch.tanh replace the deprecated
        # F.sigmoid / F.tanh aliases; the math is unchanged.
        forget_gate = torch.sigmoid(self.forget(input))
        input_gate  = torch.sigmoid(self.input(input))
        output_gate = torch.sigmoid(self.output(input))
        state_gate  = torch.tanh(self.state(input))

        # Update internal cell state
        cell = forget_gate * cell + input_gate * state_gate
        hidden = output_gate * torch.tanh(cell)

        return hidden, cell

In [3]:
# Using TowerRepresentation
class Representation(nn.Module):
    """
    Network that condenses a joint (image, viewpoint) input into a
    representation tensor.

    Employs the tower/pool architecture described in the paper.

    The viewpoint is tiled to an ``r_dim // 16`` square grid before it is
    concatenated with the image features, and the image features are
    down-sampled by a factor of 4 at that point, so both grids have to
    agree in size for the concatenation to work.

    :param n_channels: number of color channels in input image
    :param v_dim: dimensions of the viewpoint vector
    :param r_dim: dimensions of representation
    :param pool: whether to pool representation
    """
    def __init__(self, n_channels, v_dim, r_dim=256, pool=True):
        super().__init__()
        self.r_dim = r_dim
        self.pool = pool

        wide, narrow = r_dim, r_dim // 2
        halve = dict(kernel_size=2, stride=2)              # halves the spatial size
        keep = dict(kernel_size=3, stride=1, padding=1)    # keeps the spatial size

        # Image-only tower
        self.conv1 = nn.Conv2d(n_channels, wide, **halve)
        self.conv2 = nn.Conv2d(wide, wide, **halve)
        self.conv3 = nn.Conv2d(wide, narrow, **keep)
        self.conv4 = nn.Conv2d(narrow, wide, **halve)

        # Tower operating on image features merged with the viewpoint
        self.conv5 = nn.Conv2d(wide + v_dim, wide, **keep)
        self.conv6 = nn.Conv2d(wide + v_dim, narrow, **keep)
        self.conv7 = nn.Conv2d(narrow, wide, **keep)
        self.conv8 = nn.Conv2d(wide, wide, kernel_size=1, stride=1)

        self.avgpool = nn.AvgPool2d(r_dim // 16)

    def forward(self, x, v):
        """
        Encode an (image, viewpoint) pair.

        :param x: image
        :param v: viewpoint (x, y, z, cos(yaw), sin(yaw), cos(pitch), sin(pitch))
        :return: representation
        """
        # Broadcast the viewpoint vector over a spatial grid
        grid = self.r_dim // 16
        view_map = v.view(v.size(0), -1, 1, 1).repeat(1, 1, grid, grid)

        # Residual block 1: image features only
        stem = F.relu(self.conv1(x))
        shortcut = F.relu(self.conv2(stem))
        features = F.relu(self.conv4(F.relu(self.conv3(stem)))) + shortcut

        # Residual block 2: image features joined with the viewpoint
        joined = torch.cat([features, view_map], dim=1)
        shortcut = F.relu(self.conv5(joined))
        features = F.relu(self.conv7(F.relu(self.conv6(joined)))) + shortcut

        r = F.relu(self.conv8(features))
        return self.avgpool(r) if self.pool else r

In [4]:
class GeneratorCore(nn.Module):
    """
    One step of the generator: a convolutional LSTM update followed by a
    transposed convolution whose output is accumulated onto the canvas ``u``.

    :param v_dim: channels of the (tiled) viewpoint
    :param r_dim: channels of the representation
    :param z_dim: channels of the latent variable
    :param h_dim: channels of the LSTM hidden/cell state
    :param SCALE: kernel size and stride of the transposed convolution
    """
    def __init__(self, v_dim, r_dim, z_dim, h_dim, SCALE):
        super().__init__()
        lstm_in_channels = v_dim + r_dim + z_dim
        self.core = Conv2dLSTMCell(lstm_in_channels, h_dim, kernel_size=5, stride=1, padding=2)
        self.upsample = nn.ConvTranspose2d(h_dim, h_dim, kernel_size=SCALE, stride=SCALE, padding=0)

    def forward(self, z, v, r, h_g, c_g, u):
        """
        :param z: latent sample
        :param v: tiled viewpoint
        :param r: representation
        :param h_g: generator hidden state
        :param c_g: generator cell state
        :param u: canvas accumulated so far
        :return: updated (h_g, c_g, u)
        """
        # The cell input is the channel-wise stack of z, v and r, in that order.
        lstm_in = torch.cat([z, v, r], dim=1)
        next_h, next_c = self.core(lstm_in, [h_g, c_g])
        canvas = self.upsample(next_h) + u
        return next_h, next_c, canvas


class InferenceCore(nn.Module):
    """
    One step of the inference network: a convolutional LSTM update driven by
    the generator hidden state, the query image, the viewpoint and the
    representation.

    :param x_dim: channels of the query image
    :param v_dim: channels of the (tiled) viewpoint
    :param r_dim: channels of the representation
    :param h_dim: channels of the LSTM hidden/cell state
    """
    def __init__(self, x_dim, v_dim, r_dim, h_dim):
        super().__init__()
        lstm_in_channels = h_dim + x_dim + v_dim + r_dim
        self.core = Conv2dLSTMCell(lstm_in_channels, h_dim, kernel_size=5, stride=1, padding=2)

    def forward(self, x, v, r, h_g, h_e, c_e):
        """
        :param x: query image
        :param v: tiled viewpoint
        :param r: representation
        :param h_g: generator hidden state
        :param h_e: inference hidden state
        :param c_e: inference cell state
        :return: updated (h_e, c_e)
        """
        # Channel-wise stack; h_g comes first, then x, v and r.
        lstm_in = torch.cat([h_g, x, v, r], dim=1)
        next_h, next_c = self.core(lstm_in, [h_e, c_e])
        return next_h, next_c

In [5]:
class Generator(Normal):
    """
    Observation distribution: maps the generator canvas ``u`` to the
    parameters of a Normal over the query image ``x_q``.

    :param x_dim: number of channels of the generated image
    :param h_dim: number of channels of the canvas ``u``
    """
    def __init__(self, x_dim, h_dim):
        # NOTE(review): the keyword is spelled `conv_var` and lists z/v_q/r,
        # while forward() is conditioned on `u` — this looks like it should be
        # the distribution's conditioning-variable argument; confirm against
        # the Tars API before relying on it.
        super(Generator, self).__init__(conv_var=["z","v_q","r"],var=["x_q"])
        self.eta_g = nn.Conv2d(h_dim, x_dim, kernel_size=1, stride=1, padding=0)

    # TODO; enable sigma annealing
    def forward(self, u, sigma=1):
        """
        :param u: canvas accumulated by the generator core
        :param sigma: scale of the Normal, passed through unchanged
        :return: dict with the "loc" and "scale" of the distribution
        """
        # The 1x1 convolution registered in __init__ is `eta_g`; the previous
        # code called an attribute that was never defined on this class.
        mu = torch.sigmoid(self.eta_g(u))
        return {"loc":mu, "scale":sigma}

class Prior(Normal):
    """
    Prior over the latent ``z`` given the generator hidden state ``h_g``.

    :param z_dim: number of channels of the latent variable
    :param h_dim: number of channels of the generator hidden state
    """
    def __init__(self, z_dim, h_dim):
        super(Prior, self).__init__(conv_var=["h_g"],var=["z"])
        # Emits 2 * z_dim channels: one half for the mean, one for the scale.
        self.eta_pi = nn.Conv2d(h_dim, 2*z_dim, kernel_size=5, stride=1, padding=2)

    def forward(self, h_g):
        """
        :param h_g: generator hidden state
        :return: dict with the "loc" and "scale" of the distribution
        """
        # Split the channels into two equal halves. This follows the width the
        # layer was built with instead of reading the notebook-level `zDim`
        # global, which is only defined in a later cell.
        mu, std = torch.chunk(self.eta_pi(h_g), 2, dim=1)
        return {"loc":mu ,"scale":F.softplus(std)}
    
class Inference(Normal):
    """
    Approximate posterior over the latent ``z`` given the inference hidden
    state ``h_e``.

    :param z_dim: number of channels of the latent variable
    :param h_dim: number of channels of the inference hidden state
    """
    def __init__(self, z_dim, h_dim):
        super(Inference, self).__init__(conv_var=["h_e"],var=["z"])
        # Emits 2 * z_dim channels: one half for the mean, one for the scale.
        self.eta_e = nn.Conv2d(h_dim, 2*z_dim, kernel_size=5, stride=1, padding=2)

    def forward(self, h_e):
        """
        :param h_e: inference hidden state
        :return: dict with the "loc" and "scale" of the distribution
        """
        # Split the channels into two equal halves. This follows the width the
        # layer was built with instead of reading the notebook-level `zDim`
        # global, which is only defined in a later cell.
        mu, std = torch.chunk(self.eta_e(h_e), 2, dim=1)
        # A raw convolution output can be zero or negative; softplus keeps the
        # scale strictly positive, the same way Prior does.
        return {"loc":mu, "scale":F.softplus(std)}

In [6]:
class GQN(nn.Module):
    """
    Generative Query Network assembled from the representation network, the
    recurrent generator/inference cores and the prior, posterior and
    observation distributions.

    :param x_dim: number of channels of the input images
    :param v_dim: dimensions of the viewpoint vector
    :param r_dim: dimensions of the scene representation
    :param h_dim: number of channels of the LSTM hidden/cell states
    :param z_dim: number of channels of the latent variable
    :param L: number of generation/inference steps
    :param SCALE: ratio between the image size and the latent grid size
    """
    def __init__(self, x_dim, v_dim, r_dim, h_dim, z_dim, L, SCALE):
        super(GQN, self).__init__()
        self.L = L
        self.SCALE = SCALE
        # forward() sizes the recurrent states and the canvas with h_dim.
        self.h_dim = h_dim

        # The representation network consumes images, so its input channel
        # count is x_dim (not z_dim).
        self.representation = Representation(x_dim, v_dim, r_dim)
        self.generator_core = GeneratorCore(v_dim, r_dim, z_dim, h_dim, self.SCALE)
        self.inference_core = InferenceCore(x_dim, v_dim, r_dim, h_dim)

        # Brings the query image down to the latent grid (h/SCALE x w/SCALE)
        # while keeping x_dim channels, which is what InferenceCore is built
        # to concatenate with the hidden state.
        self.downsample = nn.Conv2d(x_dim, x_dim, kernel_size=SCALE, stride=SCALE, padding=0, bias=False)

        self.pi = Prior(z_dim, h_dim).to(device)
        self.q = Inference(z_dim, h_dim).to(device)
        self.g = Generator(x_dim, h_dim).to(device)

    def forward(self, images, viewpoints):
        """
        Compute the training loss for a batch of scenes.

        A random subset of the views of every scene is used as context, and
        one of the remaining views is used as the query.

        :param images: batch of image sets, indexed as (batch, view, ...)
        :param viewpoints: batch of viewpoint sets, indexed as (batch, view, ...)
        :return: loss, the negative log-likelihood term plus the summed KL terms
        """
        # Number of context datapoints to use for representation
        batch_size, m, *_ = viewpoints.size()

        # Sample random number of views and generate representation.
        # torch.randint excludes its upper bound, so n_views lies in
        # [2, m - 1] and at least one view is always left for the query.
        n_views = torch.randint(2, m, (1,)).item()

        indices = torch.randperm(m)
        representation_idx, query_idx = indices[:n_views], indices[n_views]

        x, v = images[:, representation_idx], viewpoints[:, representation_idx]

        # Merge batch and view dimensions.
        _, _, *x_dims = x.size()
        _, _, *v_dims = v.size()

        x = x.view((-1, *x_dims))
        v = v.view((-1, *v_dims))

        # representation generated from input images
        # and corresponding viewpoints
        phi = self.representation(x, v)

        # Separate batch and view dimensions
        _, *phi_dims = phi.size()
        phi = phi.view((batch_size, n_views, *phi_dims))

        # sum over view representations
        r = torch.sum(phi, dim=1)

        # Use random (image, viewpoint) pair in batch as query
        x_q, v_q = images[:, query_idx], viewpoints[:, query_idx]

        batch_size, _, h, w = x_q.size()
        grid_h, grid_w = h // self.SCALE, w // self.SCALE

        # Increase dimensions
        v_q = v_q.view(batch_size, -1, 1, 1).repeat(1, 1, grid_h, grid_w)
        if r.size(2) != grid_h:
            r = r.repeat(1, 1, grid_h, grid_w)

        state_shape = (batch_size, self.h_dim, grid_h, grid_w)

        # Reset hidden state (g: generator, e: inference)
        hidden_g = x_q.new_zeros(state_shape)
        hidden_e = x_q.new_zeros(state_shape)

        # Reset cell state
        cell_g = x_q.new_zeros(state_shape)
        cell_e = x_q.new_zeros(state_shape)

        # Canvas at full image resolution
        u = x_q.new_zeros((batch_size, self.h_dim, h, w))

        x_q = self.downsample(x_q)

        kls = 0
        for _ in range(self.L):
            # kl
            z = self.q.sample({"h_e": hidden_e})
            kl = KullbackLeibler(self.q, self.pi)
            # NOTE(review): `estimate` is handed the down-sampled image while
            # q and pi are declared over h_e / h_g — confirm the expected
            # argument against the Tars API.
            kls += kl.estimate(x_q)
            # update state
            hidden_e, cell_e = self.inference_core(x_q, v_q, r, hidden_g, hidden_e, cell_e)
            hidden_g, cell_g, u = self.generator_core(z, v_q, r, hidden_g, cell_g, u)

        # NOTE(review): the likelihood below is evaluated on the generator's
        # own sample rather than on the query image — confirm that this is the
        # intended reconstruction term.
        x_sample = self.g.sample({"u": u})
        x_ll = self.g.log_likelihood(x_sample)
        loss = -x_ll + kls
        return loss

In [7]:
# Model hyper-parameters, passed positionally to GQN when the model is built.
xDim = 3     # x_dim
vDim = 7     # v_dim
rDim = 256   # r_dim
hDim = 128   # h_dim
zDim = 64    # z_dim
L = 12       # L
SCALE = 4    # Scale of image generation process

In [8]:
# Build the network from the notebook-level hyper-parameters and move it to
# the selected device.
gqn = GQN(
    x_dim=xDim,
    v_dim=vDim,
    r_dim=rDim,
    h_dim=hDim,
    z_dim=zDim,
    L=L,
    SCALE=SCALE,
).to(device)

In [9]:
from shepardmetzler import ShepardMetzler, Scene, transform_viewpoint
# args
data_dir = '/root/data/GQN/shepard_metzler_5_parts-torch/train'
batch_size = 36
gradient_steps = 200 #default: 2*(10**6)

dataset = ShepardMetzler(root_dir=data_dir, target_transform=transform_viewpoint)

# Learning rate
mu_f, mu_i = 5*10**(-5), 5*10**(-4)
# NOTE(review): sigma_f / sigma_i were referenced but never defined. The
# values below are placeholders for the pixel-scale annealing range — confirm
# them. `sigma` is not consumed anywhere in this cell.
sigma_f, sigma_i = 0.7, 2.0
mu, sigma = mu_f, sigma_f

# The network built earlier in the notebook is named `gqn`.
optimizer = torch.optim.Adam(gqn.parameters(), lr=mu)

# Extra DataLoader options (e.g. num_workers, pin_memory) go here.
kwargs = {}
loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, **kwargs)

# Number of gradient steps
s = 0
while s < gradient_steps:
    for x, v in tqdm(loader):
        x = x.to(device)
        v = v.to(device)

        # Use the actual batch size: the last batch of an epoch can be
        # smaller than `batch_size`.
        loss = torch.mean(gqn(x, v).view(x.size(0), -1), dim=0).sum()
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        s += 1

        # Keep a checkpoint every n steps
        if s % 100 == 0:
            torch.save(gqn, "model-{}.pt".format(s))

        # Stop as soon as the step budget is used up instead of finishing
        # the current epoch first.
        if s >= gradient_steps:
            break

torch.save(gqn, "model-final.pt")