In [3]:
import torch
from torch import nn
import time

In [15]:
num_actor_obs = 49 * 16
actor_hidden_dims = [128, 128, 128]
num_actions = 12
activation = nn.ELU()
kernel_sizes = [3, 5, 5, 5, 5]
strides = [3, 2, 2, 2, 2, 2]
filters = [32, 32, 32, 16, 8]
paddings = [1, 2, 2, 2, 2] 
dilations = [1, 1, 1, 1, 1]

out_channels = filters[:]
in_channels = [1] + filters[:-1]

In [16]:
obs = torch.rand(128, num_actor_obs)
obs.shape

torch.Size([128, 784])

In [17]:
actor_layers = []
mlp_input_dim_a = num_actor_obs
for in_ch, out_ch, kernel_size, stride, padding, dilation in zip(in_channels, out_channels, kernel_sizes, strides, paddings, dilations):
    actor_layers.append(nn.Conv1d(
        in_channels=in_ch,
        out_channels=out_ch,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation
    ))
    actor_layers.append(activation)
    print((mlp_input_dim_a + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1, ((mlp_input_dim_a + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1) * out_ch)
    print(((mlp_input_dim_a + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1) * stride >= mlp_input_dim_a) 
    mlp_input_dim_a = (mlp_input_dim_a + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

actor_layers.append(nn.Flatten())
mlp_input_dim_a = mlp_input_dim_a * out_channels[-1]
print(mlp_input_dim_a)

actor_layers.append(nn.Linear(mlp_input_dim_a, actor_hidden_dims[0]))
actor_layers.append(activation)
for layer_index in range(len(actor_hidden_dims)):
    if layer_index == len(actor_hidden_dims) - 1:
        actor_layers.append(nn.Linear(actor_hidden_dims[layer_index], num_actions))
    else:
        actor_layers.append(nn.Linear(actor_hidden_dims[layer_index], actor_hidden_dims[layer_index + 1]))
        actor_layers.append(activation)
actor = nn.Sequential(*actor_layers)

262 8384
True
131 4192
True
66 2112
True
33 528
True
17 136
True
136


In [11]:
sum([p.numel() for p in actor.parameters()])

63716

In [None]:
start = time.time()

num_actor_obs = 45 * 100
batch_size = 8192
device = torch.device('cuda')

actor = actor.to(device)

for _ in range(100):
    obs = torch.rand(batch_size, num_actor_obs).to(device)
    with torch.no_grad():
        out = actor(obs.unsqueeze(1))


print(time.time() - start)

16.193702697753906


In [55]:
num_actor_obs_h = 100
num_actor_obs_w = 45

print(num_actor_obs_h * num_actor_obs_w)

actor_hidden_dims = [128, 128, 128]
num_actions = 12
activation = nn.ELU()
# kernel_sizes = [(4, 1), (4, 1), (4, 1), (1, 3), (1, 3)]
# strides = [(2, 1), (2, 1), (2, 1), (1, 3), (1, 3)]
# filters = [32, 32, 32, 32, 32]
# paddings = [(1, 0), (1, 0), (0, 0), (0, 0), (0, 0)]
# dilations = [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]
# kernel_sizes = [(1, 3), (1, 3), (4, 1), (4, 1), (4, 1)]
# strides = [(1, 3), (1, 3), (2, 1), (2, 1), (2, 1)]
# filters = [64, 64, 64, 64, 32]
# paddings = [(0, 0), (0, 0), (1, 0), (1, 0), (0, 0)]
# dilations = [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]
kernel_sizes = [(5, 3), (5, 5), (5, 5), (5, 1)]
strides = [(2, 3), (2, 2), (2, 2), (2, 1)]
filters = [128, 64, 32, 16]
paddings = [(2, 0), (2, 2), (2, 2), (2, 0)]
dilations = [(1, 1), (1, 1), (1, 1), (1, 1)]

out_channels = filters[:]
in_channels = [1] + filters[:-1]

4500


In [56]:
obs = torch.rand(10, num_actor_obs_h, num_actor_obs_w)

In [57]:
actor_layers = []

out_h = num_actor_obs_h
out_w = num_actor_obs_w

for in_ch, out_ch, kernel_size, stride, padding, dilation in zip(in_channels, out_channels, kernel_sizes, strides, paddings, dilations):
    actor_layers.append(nn.Conv2d(
        in_channels=in_ch,
        out_channels=out_ch,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation
    ))
    actor_layers.append(activation)
    last_h = out_h
    last_w = out_w
    out_h = (out_h + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) // stride[0] + 1
    out_w = (out_w + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) // stride[1] + 1

    print(f'out_h: {out_h} out_w: {out_w} total: {out_h * out_w * out_ch}')
    print(f'pad_h: {out_h * stride[0], out_h * stride[0] >= last_h} pad_w: {out_w * stride[1], out_w * stride[1] >= last_w}')

actor_layers.append(nn.Flatten())
mlp_input_dim_a = out_h * out_w * out_channels[-1]

print(mlp_input_dim_a)

actor_layers.append(nn.Linear(mlp_input_dim_a, actor_hidden_dims[0]))
actor_layers.append(activation)
for layer_index in range(len(actor_hidden_dims)):
    if layer_index == len(actor_hidden_dims) - 1:
        actor_layers.append(nn.Linear(actor_hidden_dims[layer_index], num_actions))
    else:
        actor_layers.append(nn.Linear(actor_hidden_dims[layer_index], actor_hidden_dims[layer_index + 1]))
        actor_layers.append(activation)
actor = nn.Sequential(*actor_layers)

out_h: 50 out_w: 15 total: 96000
pad_h: (100, True) pad_w: (45, True)
out_h: 25 out_w: 8 total: 12800
pad_h: (50, True) pad_w: (16, True)
out_h: 13 out_w: 4 total: 1664
pad_h: (26, True) pad_w: (8, True)
out_h: 7 out_w: 4 total: 448
pad_h: (14, True) pad_w: (4, True)
448


In [340]:
sum([p.numel() for p in actor.parameters()])

108908

In [14]:
start = time.time()

num_actor_obs_h = 20
num_actor_obs_w = 45
batch_size = 8192
device = torch.device('cuda')

actor = actor.to(device)

for _ in range(100):
    obs = torch.rand(batch_size, num_actor_obs_h, num_actor_obs_w).to(device)
    with torch.no_grad():
        out = actor(obs.unsqueeze(1))

print(time.time() - start)

5.8920578956604


In [136]:
a = torch.rand(100, 10, 45)

In [137]:
a.reshape(a.shape[0], -1).shape

torch.Size([100, 450])