In [9]:
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import einops
from einops.layers.torch import Rearrange

In [10]:
# ConvTranspose1d argument triples: (kernel_size, stride, padding).
# With the defaults output_padding=0 and dilation=1 the output length is
#     L_out = (L_in - 1) * stride - 2 * padding + kernel_size
# so each name records how its triple changes the sequence length L.
n_times_two = (4, 2, 1)           # L -> 2L
n_plus_one = (4, 1, 1)            # L -> L + 1
n_minus_one = (4, 1, 2)           # L -> L - 1
n_plus_three = (4, 1, 0)          # L -> L + 3
n_plus_one_times_two = (4, 2, 0)  # L -> 2L + 2


class UpsampleBlock(nn.Module):
    """Conv1d/GroupNorm/Mish twice, then a ConvTranspose1d that resizes time.

    Tensors are laid out as (batch, channels, length), as ``nn.Conv1d`` expects.
    The two convolutions pad by ``kernel_size // 2``, which keeps the length
    unchanged for odd kernel sizes (the default is 5). The final
    ``ConvTranspose1d`` is built from ``upsample_args`` =
    (kernel_size, stride, padding); e.g. ``n_times_two`` doubles the length.

    Args:
        inp_channels: number of input channels.
        out_channels: number of channels produced by every layer in the block.
        upsample_args: (kernel_size, stride, padding) for the transposed conv.
        kernel_size: kernel size of the two length-preserving convolutions.
    """

    def __init__(self, inp_channels, out_channels, upsample_args=n_times_two, kernel_size=5):
        super().__init__()

        self.block = nn.Sequential(
            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
            nn.GroupNorm(1, out_channels),
            nn.Mish(),
            nn.Conv1d(out_channels, out_channels, kernel_size, padding=kernel_size // 2),
            nn.GroupNorm(1, out_channels),
            nn.Mish(),
            nn.ConvTranspose1d(out_channels, out_channels, *upsample_args)
        )

    def forward(self, x):
        return self.block(x)

class LearnedTrajectoryPrior(nn.Module):
    """Upsample the first and last states of a trajectory to ``horizon`` steps.

    ``forward`` prefixes ``cond[0]`` and ``cond[horizon - 1]`` with zero
    actions, arranges the two as a length-2 sequence and runs it through a
    stack of ``UpsampleBlock``s. Zero to three "starting" blocks (see
    ``get_starting_blocks``) grow the length from 2 to 15, 5, 3 or leave it at
    2. Doubling blocks then cover the remaining power of two. Supported
    horizons are therefore powers of two (>= 2), or 3, 5 or 15 times a power
    of two (>= 1).

    The ``transition_dim``, ``horizon`` and ``action_dim`` buffers are kept so
    that existing state_dicts keep their keys. Layer construction and
    ``forward`` use the plain-Python copies ``_transition_dim``, ``_horizon``
    and ``_action_dim``, so channel counts are ints and no ``.item()`` call is
    needed per forward pass.

    Args:
        action_dim: width of the zero action block prepended to each state.
        observation_dim: width of each ``cond`` entry.
        horizon: number of time steps in the output trajectory.

    Raises:
        ValueError: if ``horizon`` is not one of the supported lengths.
    """

    def __init__(self, action_dim, observation_dim, horizon):
        super().__init__()
        transition_dim = observation_dim + action_dim
        self.register_buffer('transition_dim', torch.tensor(transition_dim))
        self.register_buffer('horizon', torch.tensor(horizon))
        self.register_buffer('action_dim', torch.tensor(action_dim))
        self._transition_dim = transition_dim
        self._horizon = horizon
        self._action_dim = action_dim
        self.blocks = nn.ModuleList([])

        upsample_args, residual_horizon = self.get_starting_blocks(horizon)
        # Validate before building any layer; the layers below are created in
        # the same order as before, so seeded initialisation is unchanged.
        number_of_2X_blocks = self._count_doublings(horizon, residual_horizon)

        self.blocks.extend([UpsampleBlock(transition_dim, transition_dim, arg) for arg in upsample_args])
        self.blocks.extend([UpsampleBlock(transition_dim, transition_dim, n_times_two)
                            for _ in range(number_of_2X_blocks)])

    @staticmethod
    def _count_doublings(horizon, residual_horizon):
        """Return k with ``residual_horizon == 2**k`` (k >= 0), else raise ValueError."""
        residual = float(residual_horizon)
        # Rejects fractional residuals (e.g. horizon=7 -> 3.5), residuals
        # below 1 (horizon=1 -> 0.5, horizon<=0) and non-powers of two
        # (horizon=9 -> 3).
        if residual < 1 or not residual.is_integer() or int(residual) & (int(residual) - 1):
            raise ValueError(
                f"unsupported horizon {horizon}: it must be a power of two (>= 2) "
                f"or 3, 5 or 15 times a power of two"
            )
        return int(residual).bit_length() - 1

    def get_starting_blocks(self, horizon):
        """Choose the non-doubling upsample steps for ``horizon``.

        Starting from a length-2 sequence, the returned ConvTranspose1d
        argument triples reach length 15 (2 -> 5 -> 12 -> 15), 5 (2 -> 5) or
        3 (2 -> 3), or leave it at 2. ``residual_horizon`` is the (float)
        factor ``horizon / reached_length`` still to be covered by doubling
        blocks.
        """
        if horizon % 15 == 0:
            start_blocks, residual_horizon = [n_plus_three, n_plus_one_times_two, n_plus_three], horizon / 15
        elif horizon % 5 == 0:
            start_blocks, residual_horizon = [n_plus_three], horizon / 5
        elif horizon % 3 == 0:
            start_blocks, residual_horizon = [n_plus_one], horizon / 3
        else:
            start_blocks, residual_horizon = [], horizon / 2
        return start_blocks, residual_horizon

    def forward(self, cond):
        """Return a (batch, horizon, transition_dim) trajectory.

        Args:
            cond: mapping that must contain keys ``0`` and ``horizon - 1``;
                each value is a (batch, observation_dim) tensor.
        """
        first_obs = cond[0]
        last_obs = cond[self._horizon - 1]
        actions = torch.zeros(first_obs.shape[0], self._action_dim, device=first_obs.device)
        trajectory_start = torch.cat((actions, first_obs), 1)
        trajectory_end = torch.cat((actions, last_obs), 1)

        # (batch, transition_dim, 2): features as channels and the two states
        # as the time axis, the layout Conv1d expects.
        x = torch.stack((trajectory_start, trajectory_end), dim=2)

        for block in self.blocks:
            x = block(x)

        # (batch, transition_dim, horizon) -> (batch, horizon, transition_dim)
        return x.transpose(1, 2)


In [15]:
# Toy problem configuration.
action_dim = 2
obs_dim = 2
t_dim = action_dim + obs_dim  # width of one trajectory step: [action, observation]

batch_size = 1
horizon = 12

# Seed torch's RNG so the randomly initialised prior (and its output) is reproducible.
SEED = 0
torch.manual_seed(SEED)

# Conditioning: observations at the first and the last time step.
cond = {
    0: torch.zeros((batch_size, obs_dim)),
    horizon - 1: torch.zeros((batch_size, obs_dim)),
}
print(cond)

{0: tensor([[0., 0.]]), 11: tensor([[0., 0.]])}


In [27]:
prior = LearnedTrajectoryPrior(action_dim, obs_dim, horizon)
trajectory = prior(cond)
# One row per time step; each row is [action, observation] of width t_dim.
assert trajectory.shape == (batch_size, horizon, t_dim), trajectory.shape
trajectory

tensor([[[-0.4386,  0.2232, -0.4104,  0.1377],
         [-0.6210, -0.0595, -0.5796, -0.1607],
         [-0.1972, -0.1105, -0.4292,  0.0963],
         [-0.2311,  0.1960,  0.4103,  0.0153],
         [-0.3052,  0.2257, -0.5012,  0.0528],
         [-0.5672, -0.1388, -0.1228, -0.2821],
         [-0.7540, -0.2090, -0.6686,  0.1961],
         [-1.1134, -0.2460, -0.1282, -0.1065],
         [-0.4984,  0.2687, -0.3601, -0.5382],
         [-0.5785, -0.0809,  0.1140, -0.2211],
         [-0.6047,  0.0213, -0.4181,  0.0433],
         [-0.6867, -0.1780, -0.0868,  0.0670]]],
       grad_fn=<ReshapeAliasBackward0>)

In [None]:
def count_parameters(model, trainable_only=False):
    """Return the total number of scalar parameters in ``model``.

    Args:
        model: object exposing ``.parameters()`` (e.g. an ``nn.Module``).
        trainable_only: if True, count only parameters with ``requires_grad``.
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad or not trainable_only)

count_parameters(model=prior)  # total number of parameters in the prior

In [None]:
# First and last conditioning observations side by side along the feature axis.
# (`self` exists only inside the class; at notebook level the horizon is `horizon`.)
torch.cat((cond[0], cond[horizon - 1]), 1)

In [None]:
# Re-trace the pre-convolution steps of LearnedTrajectoryPrior.forward on `cond`.
actions = torch.zeros(cond[0].shape[0], action_dim)
trajectory_start = torch.cat((actions, cond[0]), 1)          # (b, t), t = action_dim + obs_dim
trajectory_end = torch.cat((actions, cond[horizon - 1]), 1)  # (b, t)

trajectory_start = einops.rearrange(trajectory_start, 'b t -> b 1 t')
trajectory_end = einops.rearrange(trajectory_end, 'b t -> b 1 t')

x = torch.cat((trajectory_start, trajectory_end), 1)  # (b, 2, t)

x = einops.rearrange(x, 'b h t -> b t h')  # (b, t, 2)
x.shape


In [None]:
# Toy batch for checking the concat/rearrange steps mirrored from forward().
# NOTE: this overwrites `batch_size` from the configuration cell.
batch_size = 10
a, b = torch.rand(batch_size, 4), torch.rand(batch_size, 4)
z = torch.zeros(batch_size, 2)
print(*(t.shape for t in (a, b, z)))
a

In [None]:
# Prefix each state batch with the zero "actions" along the feature axis.
o1, o2 = (torch.cat((z, states), dim=1) for states in (a, b))
o1.shape, o2.shape

In [None]:
# Insert a singleton axis: (b, t) -> (b, 1, t).
n1, n2 = (einops.rearrange(o, 'b t -> b 1 t') for o in (o1, o2))
n1.shape, n2.shape

In [None]:
n1[0, 0]  # first sample's (only) row: equals o1[0]

In [None]:
# Join the two (b, 1, t) tensors along axis 1 -> (b, 2, t).
x = torch.cat([n1, n2], dim=1)
x.size()

In [None]:
ans = einops.rearrange(x, pattern='b h t -> b t h')  # (b, h, t) -> (b, t, h)

In [None]:
ans.size()

In [None]:
x[0, 1]  # first sample, second row: the "end" state o2[0]

In [None]:
ans[0][:, 1]  # same values as x[0, 1] after the axis swap