In [1]:
import torch
from torch.utils.data import DataLoader
from mymodel import *

In [2]:
class Raymarcher(nn.Module):
    """LSTM-driven ray marcher.

    Starting from a small random depth along each camera ray, the module
    repeatedly evaluates a feature function ``phi`` at the current
    world-space points, feeds those features through an ``nn.LSTMCell`` and
    reads one scalar step length per sample from a linear head. Each point is
    then advanced along its ray direction by that step, for a fixed number of
    iterations.

    Args:
        num_feature_channels: Number of feature channels ``phi`` produces per
            sample; used as the LSTM input size.
        raymarch_steps: Number of marching iterations run per forward pass.
    """

    def __init__(self,
                 num_feature_channels,
                 raymarch_steps):
        super().__init__()

        self.n_feature_channels = num_feature_channels
        self.steps = raymarch_steps

        hidden_size = 16
        self.lstm = nn.LSTMCell(input_size=self.n_feature_channels,
                                hidden_size=hidden_size)

        # Project-specific initialisers, defined outside this notebook cell.
        self.lstm.apply(init_recurrent_weights)
        lstm_forget_gate_init(self.lstm)

        # Maps the LSTM hidden state to one step length per sample.
        self.out_layer = nn.Linear(hidden_size, 1)

    def forward(self,
                cam2world,
                phi,
                uv,
                intrinsics):
        """March every ray for ``self.steps`` iterations.

        Args:
            cam2world: Camera-to-world transform. It is passed unchanged to
                ``get_ray_directions``, ``world_from_xy_depth`` and
                ``depth_from_world``; its exact format is defined by those
                helpers.
            phi: Callable mapping world coordinates to features. Its output
                must be reshapeable to ``(-1, n_feature_channels)`` with
                ``batch * num_samples`` rows.
            uv: 3-D tensor unpacked as ``(batch, num_samples, _)``. It is
                presumably per-sample pixel coordinates (TODO confirm). The
                initial depths are allocated on its device.
            intrinsics: Camera intrinsics, passed through to the helpers.

        Returns:
            Tuple ``(world_coords, depth, log)``:

            - ``world_coords``: the world-space points after the last step.
            - ``depth``: the corresponding depth from ``depth_from_world``.
              When ``steps == 0`` it is the initial
              ``(batch, num_samples, 1)`` random depth instead.
            - ``log``: always ``None``; kept only for interface compatibility.

        Side effects:
            In training mode, prints the min/max depth after every step.
        """
        batch_size, num_samples, _ = uv.shape

        ray_dirs = get_ray_directions(
            uv,
            cam2world=cam2world,
            intrinsics=intrinsics)

        # Start every ray slightly in front of the camera with a tiny jitter.
        # Allocate on uv's device instead of relying on an implicit
        # module-level ``device`` global plus a CPU->device copy.
        depth = torch.empty((batch_size, num_samples, 1), device=uv.device) \
            .normal_(mean=0.05, std=5e-4)
        world_coords = world_from_xy_depth(
            uv,
            depth,
            intrinsics=intrinsics,
            cam2world=cam2world)

        # Only the latest (h, c) state, points and depth are ever consumed, so
        # keep rolling values rather than full per-step histories.
        state = None
        for step in range(self.steps):
            features = phi(world_coords)

            # reshape (not view) tolerates a non-contiguous phi output.
            state = self.lstm(
                features.reshape(-1, self.n_feature_channels), state)

            # Clamp gradients flowing back into the hidden state to [-10, 10]
            # to keep the unrolled recurrence numerically stable.
            if state[0].requires_grad:
                state[0].register_hook(
                    lambda grad: grad.clamp(min=-10, max=10))

            signed_distance = self.out_layer(state[0]) \
                .view(batch_size, num_samples, 1)
            world_coords = world_coords + ray_dirs * signed_distance
            depth = depth_from_world(world_coords, cam2world)

            # Report the depth reached at *this* step. The original printed
            # the previous step's depth because it read the history before
            # appending the new value.
            if self.training:
                print("Raymarch step %d: Min depth %0.6f, max depth %0.6f" %
                      (step, depth.min().item(), depth.max().item()))

        log = None
        return world_coords, depth, log