# Convolutional LSTM for coordinate prediction

### Imports

In [1]:
import pickle

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader
from torchinfo import summary

from tqdm.auto import tqdm

# own
import common.action as action
import common.world as world
import common.plot as plot
import common.preprocess as preprocess
import common.nets as nets
import common.train as train
import common.tools as tools

  if LooseVersion(mpl.__version__) >= "3.0":
  other = LooseVersion(other)
  if not hasattr(tensorboard, '__version__') or LooseVersion(tensorboard.__version__) < LooseVersion('1.15'):


### Load datasets

In [2]:
with open("datasets/oracle_data.pickle", "rb") as handle:
    oracle_data = pickle.load(handle)

with open("datasets/oracle_reversed_data.pickle", "rb") as handle:
    oracle_reversed_data = pickle.load(handle)

with open("datasets/oracle_random_data.pickle", "rb") as handle:
    oracle_random_data = pickle.load(handle)

### Preprocess data

In [3]:
length_trajectory = 10
batch_size = 128

# split and shuffle data
train_data, test_data = preprocess.split_data_for_trajectories(
    oracle_reversed_data, 0.8, length_trajectory
)
train_imgs, train_pos = preprocess.process_trajectory(train_data)
test_imgs, test_pos = preprocess.process_trajectory(test_data)

# stage data for the DataLoader
train_data = preprocess.ObtainDataset_notransform(train_imgs, train_pos)
test_data = preprocess.ObtainDataset_notransform(test_imgs, test_pos)

# DataLoader
dataset_loader_train_data = DataLoader(train_data, batch_size=batch_size, shuffle=True)
dataset_loader_test_data = DataLoader(test_data, batch_size=batch_size, shuffle=True)

### Initialize models

In [4]:
# initialize network
net_cnn = nets.CNN_coords()
net_lstm = nets.LSTM_coords(length_trajectory)

# checking values
h0 = torch.randn(2, 10, 100)
c0 = torch.randn(2, 10, 100)
x = torch.rand((64, 10, 3, 32, 32))

# check network
features = net_cnn(x)
out0, out1, hidden, c = net_lstm(features, h0, c0)

# shape statistics
tools.shapes(x, features, hidden, out0)

# network summary
print("SUMMARY CNN \n", summary(net_cnn, (64, 10, 3, 32, 32)), "\n")
print("SUMMARY LSTM \n", summary(net_lstm, ((64, 10, 480), (2, 10, 100), (2, 10, 100))))

input cnn: torch.Size([64, 10, 3, 32, 32]) - Batch size, Channel out, Height out, Width out
output cnn: torch.Size([64, 10, 480])  - Batch size, sequence length, input size
input lstm: torch.Size([64, 10, 480])  - Batch size, sequence length, input size
hidden lstm: torch.Size([10, 100])
output lstm: torch.Size([64, 10]) 

SUMMARY CNN 
Layer (type:depth-idx)                   Output Shape              Param #
CNN_coords                               [64, 10, 480]             --
├─Conv2d: 1-1                            [640, 10, 28, 28]         760
├─Conv2d: 1-2                            [640, 20, 24, 24]         5,020
├─MaxPool2d: 1-3                         [640, 20, 12, 12]         --
├─Conv2d: 1-4                            [640, 30, 8, 8]           15,030
├─MaxPool2d: 1-5                         [640, 30, 4, 4]           --
Total params: 20,810
Trainable params: 20,810
Non-trainable params: 0
Total mult-adds (G): 2.85
Input size (MB): 7.86
Forward/backward pass size (MB): 108.95
P

In [18]:
inputs, label = next(iter(train_data))
print(len(inputs))
inputs = torch.stack(inputs)
inputs = torch.swapaxes(inputs, 0, 1)
print(len(x))

net_cnn = nets.CNN_coords()

encoded = net_cnn(x)
print(encoded)

10
64
tensor([[[0.0423, 0.0803, 0.0520,  ..., 0.0000, 0.0175, 0.0000],
         [0.0430, 0.0599, 0.0557,  ..., 0.0000, 0.0060, 0.0000],
         [0.0309, 0.0503, 0.0440,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.0544, 0.0335, 0.0492,  ..., 0.0016, 0.0000, 0.0000],
         [0.0580, 0.0572, 0.0446,  ..., 0.0038, 0.0000, 0.0108],
         [0.0329, 0.0749, 0.0641,  ..., 0.0000, 0.0054, 0.0000]],

        [[0.0654, 0.0641, 0.0670,  ..., 0.0116, 0.0147, 0.0000],
         [0.0432, 0.0507, 0.0521,  ..., 0.0000, 0.0000, 0.0000],
         [0.0416, 0.0320, 0.0469,  ..., 0.0000, 0.0000, 0.0076],
         ...,
         [0.0704, 0.0466, 0.0405,  ..., 0.0000, 0.0000, 0.0000],
         [0.0482, 0.0651, 0.0448,  ..., 0.0000, 0.0000, 0.0000],
         [0.0736, 0.0638, 0.0781,  ..., 0.0000, 0.0000, 0.0000]],

        [[0.0439, 0.0484, 0.0601,  ..., 0.0000, 0.0000, 0.0000],
         [0.0471, 0.0314, 0.0439,  ..., 0.0000, 0.0000, 0.0000],
         [0.0641, 0.0547, 0.0485,  ..., 0.0000, 0.00

### Train model

In [12]:
criterion = nn.MSELoss()
params = list(net_cnn.parameters()) + list(net_lstm.parameters())
optimizer = optim.Adam(params, lr=0.01)
episodes = 20

(
    train_loss,
    test_loss,
    train_dis,
    test_dis,
    train_dis_item,
    test_dis_item,
) = train.train_ConvLSTM(
    dataset_loader_train_data,
    dataset_loader_test_data,
    net_cnn,
    net_lstm,
    criterion,
    optimizer,
    episodes,
    length_trajectory,
)

Progress:   0%|          | 0/20 [00:00<?, ? Episode/s]

10
tensor([[[1.5697, 1.5551, 1.5400,  ..., 1.5130, 1.4870, 1.3630],
         [1.5957, 1.5742, 1.5552,  ..., 1.5542, 1.5226, 1.3879],
         [1.6204, 1.5923, 1.5697,  ..., 1.5977, 1.5613, 1.4106],
         ...,
         [1.7748, 1.7829, 1.7916,  ..., 1.5870, 1.6859, 1.7227],
         [1.7749, 1.7760, 1.7751,  ..., 1.5074, 1.5344, 1.6079],
         [1.7283, 1.7600, 1.7751,  ..., 1.5045, 1.5076, 1.4996]],

        [[1.6703, 1.8223, 1.8926,  ..., 1.9933, 2.1062, 2.0789],
         [1.6692, 1.8241, 1.8958,  ..., 1.9957, 2.1098, 2.0813],
         [1.6684, 1.8265, 1.8981,  ..., 2.0016, 2.1163, 2.0769],
         ...,
         [1.6680, 1.8376, 1.9103,  ..., 2.0058, 2.1246, 2.0954],
         [1.6656, 1.8392, 1.9120,  ..., 2.0101, 2.1289, 2.0938],
         [1.6629, 1.8420, 1.9145,  ..., 2.0137, 2.1303, 2.0879]],

        [[1.8730, 1.8595, 1.8543,  ..., 1.8815, 1.8596, 1.8128],
         [1.8646, 1.8499, 1.8470,  ..., 1.8491, 1.8322, 1.7972],
         [1.8558, 1.8399, 1.8396,  ..., 1.8173, 1.8080,

tensor([[[34.2132, 33.4114, 34.1832,  ..., 31.3851, 31.3938, 32.2015],
         [35.0114, 34.5142, 33.6115,  ..., 31.9733, 31.5378, 31.0836],
         [34.7467, 34.9210, 34.7561,  ..., 32.1449, 32.0737, 31.7401],
         ...,
         [32.9215, 33.5696, 33.6394,  ..., 35.1446, 35.2652, 35.3906],
         [32.6678, 33.2343, 33.2926,  ..., 34.5261, 34.6356, 34.7681],
         [32.3903, 32.8752, 32.9227,  ..., 33.8620, 33.9548, 34.0905]],

        [[36.1758, 36.9904, 36.9790,  ..., 41.0718, 41.5574, 39.7709],
         [36.8686, 36.6977, 35.2654,  ..., 41.2275, 38.5931, 33.2658],
         [36.0169, 34.2650, 31.2770,  ..., 36.7356, 31.3152, 26.7786],
         ...,
         [32.9107, 37.0812, 38.8560,  ..., 41.5225, 44.2206, 43.4144],
         [32.9144, 37.0756, 38.8351,  ..., 41.5195, 44.1986, 43.4002],
         [32.8992, 37.0518, 38.8029,  ..., 41.4784, 44.1413, 43.3550]],

        [[35.3078, 35.3676, 35.2574,  ..., 33.4017, 33.3626, 34.3921],
         [35.1290, 35.1441, 34.9990,  ..., 32

tensor([[[46.7831, 47.9577, 48.7422,  ..., 55.0407, 56.0281, 56.0154],
         [46.8282, 48.1761, 48.8708,  ..., 55.1265, 56.2496, 56.2802],
         [48.7189, 48.9472, 50.0098,  ..., 56.2862, 57.6103, 66.6613],
         ...,
         [59.7850, 57.2334, 57.8268,  ..., 71.1866, 71.2469, 72.5816],
         [58.9000, 56.4060, 57.0265,  ..., 69.1852, 69.4714, 71.2986],
         [57.9525, 55.5370, 56.1565,  ..., 67.1105, 67.5506, 69.8942]],

        [[46.8848, 46.5944, 48.1288,  ..., 53.9113, 54.8343, 56.3362],
         [47.0812, 46.6638, 48.0569,  ..., 54.1244, 54.9098, 56.2515],
         [47.2417, 46.7384, 47.9714,  ..., 54.3403, 54.9861, 56.0958],
         ...,
         [47.5758, 47.0063, 47.5959,  ..., 54.6690, 54.9915, 58.3228],
         [47.9710, 47.4666, 47.9987,  ..., 54.7585, 55.2125, 59.0575],
         [48.5226, 48.0855, 48.7233,  ..., 54.7741, 55.4064, 59.7251]],

        [[50.1751, 49.9415, 49.6685,  ..., 56.2118, 55.9988, 55.9908],
         [51.0654, 50.8801, 50.6587,  ..., 57

tensor([[[ 73.3394,  73.2325,  72.9583,  ...,  78.7382,  78.6495,  78.5725],
         [ 74.4588,  74.3911,  74.1764,  ...,  81.1908,  81.1490,  81.0440],
         [ 75.6700,  75.6095,  75.4261,  ...,  83.7393,  83.6956,  83.4910],
         ...,
         [ 80.5191,  80.5328,  80.4283,  ...,  95.0201,  95.0085,  93.8554],
         [ 81.3874,  81.4117,  81.3296,  ...,  96.9323,  96.9232,  95.5069],
         [ 84.9211,  83.5667,  82.0590,  ...,  99.8681,  98.5124,  95.6827]],

        [[ 69.5658,  69.2369,  69.3694,  ...,  72.9186,  72.7903,  72.5696],
         [ 70.8697,  70.6013,  70.5645,  ...,  74.7209,  74.5572,  73.9173],
         [ 72.1158,  71.8734,  71.7968,  ...,  76.6920,  76.5152,  75.3851],
         ...,
         [ 77.8553,  77.7205,  77.2969,  ...,  88.5999,  87.8258,  83.4022],
         [ 78.8372,  78.7116,  78.2077,  ...,  90.8001,  89.8491,  84.6701],
         [ 79.7272,  79.6129,  79.0215,  ...,  93.0320,  91.9144,  86.0139]],

        [[ 85.8436,  86.0089,  85.7517,  ...

tensor([[[ 84.6551,  84.6703,  87.1276,  ...,  82.4776,  84.1612,  85.7611],
         [ 84.9711,  84.7009,  86.9102,  ...,  82.8117,  84.1747,  85.4010],
         [ 85.2374,  84.6826,  86.6270,  ...,  83.0458,  84.1096,  85.0195],
         ...,
         [ 86.9773,  86.1530,  86.9398,  ...,  83.3826,  83.6155,  89.6926],
         [ 88.1398,  87.4343,  88.3236,  ...,  83.5681,  83.9769,  92.4365],
         [ 89.5389,  88.9454,  90.0030,  ...,  84.5704,  85.4232,  95.6093]],

        [[ 88.8308,  88.1546,  89.6455,  ...,  84.0388,  85.1646,  92.0008],
         [ 90.2752,  89.7024,  91.4574,  ...,  85.1486,  86.7127,  93.8009],
         [ 91.9824,  91.5233,  93.5637,  ...,  86.8208,  88.8470,  95.8598],
         ...,
         [100.0396,  99.9901, 102.6030,  ..., 100.1996, 102.8794, 106.2173],
         [101.4762, 101.5132, 104.0672,  ..., 103.0121, 105.4999, 107.8132],
         [102.8558, 102.9589, 105.4382,  ..., 105.7043, 108.0316, 109.4155]],

        [[124.3122, 126.1810, 123.8102,  ...

tensor([[[101.6008, 102.4730, 103.6483,  ..., 102.0119, 101.7584, 104.5359],
         [ 98.6872, 100.4865, 102.2520,  ..., 100.9994, 101.9327, 101.9007],
         [101.3429,  99.2221,  99.8501,  ..., 101.0020, 100.8124, 102.0959],
         ...,
         [103.5257, 102.8141, 102.6198,  ..., 101.3077, 100.9477, 101.1124],
         [105.0495, 104.5295, 104.3885,  ..., 102.4615, 102.1250, 102.3871],
         [106.7477, 106.5060, 106.4222,  ..., 104.1003, 103.8488, 104.2282]],

        [[ 99.0132,  98.4752, 101.7770,  ...,  99.2508, 101.0133, 103.8121],
         [ 99.4298,  98.6139, 101.6096,  ...,  99.6446, 101.1644, 103.6544],
         [ 99.7731,  98.7615, 101.4169,  ..., 100.0417, 101.3019, 103.3676],
         ...,
         [100.5065,  99.2776, 100.5855,  ..., 100.6795, 101.3056, 107.4453],
         [101.3856, 100.2915, 101.4710,  ..., 100.8477, 101.7093, 108.8061],
         [102.6041, 101.6601, 103.0575,  ..., 100.8791, 102.0642, 110.0397]],

        [[130.8886, 138.9743, 139.9027,  ...

tensor([[[111.4055, 111.4382, 114.6774,  ..., 113.9420, 116.3271, 118.5280],
         [111.8205, 111.4739, 114.3874,  ..., 114.4039, 116.3450, 118.0295],
         [112.1703, 111.4468, 114.0118,  ..., 114.7311, 116.2539, 117.5034],
         ...,
         [114.4794, 113.3873, 114.4378,  ..., 115.2146, 115.5679, 124.0087],
         [116.0237, 115.0899, 116.2778,  ..., 115.4464, 116.0379, 127.7735],
         [117.8816, 117.0970, 118.5093,  ..., 116.7914, 117.9983, 132.1248]],

        [[160.9051, 165.2927, 164.4604,  ..., 203.8156, 202.9563, 183.4322],
         [160.8900, 165.2197, 164.4298,  ..., 203.5293, 202.7171, 183.5841],
         [160.8877, 165.1665, 164.4159,  ..., 203.2460, 202.5100, 183.8054],
         ...,
         [160.9028, 164.5890, 164.0829,  ..., 201.2265, 200.7419, 185.1860],
         [160.9698, 164.4193, 163.9940,  ..., 200.5957, 200.2225, 185.6079],
         [161.0175, 164.2030, 163.8716,  ..., 199.9214, 199.6876, 186.1537]],

        [[151.9392, 153.7937, 156.0675,  ...

tensor([[[152.5021, 168.8781, 174.6460,  ..., 208.6730, 218.2933, 214.8294],
         [152.3991, 169.0187, 174.9996,  ..., 208.8329, 218.8693, 215.3630],
         [152.3576, 169.1680, 175.2725,  ..., 209.1619, 219.4736, 215.8602],
         ...,
         [151.9607, 169.8205, 176.4524,  ..., 210.3135, 221.9639, 218.0145],
         [151.6903, 169.8976, 176.6189,  ..., 210.6455, 222.4483, 218.4464],
         [151.4270, 170.0596, 176.8652,  ..., 210.7966, 222.7088, 218.6685]],

        [[143.5494, 142.2265, 139.9911,  ..., 150.6540, 147.6406, 140.0875],
         [146.0988, 144.3686, 141.6848,  ..., 154.8029, 151.1601, 142.1142],
         [148.4629, 146.3980, 143.3145,  ..., 158.9323, 154.5344, 143.8732],
         ...,
         [159.4411, 161.4630, 162.4067,  ..., 170.1013, 176.5971, 176.8954],
         [159.0584, 159.0196, 160.0085,  ..., 164.0376, 165.5992, 170.6207],
         [154.9676, 158.0414, 159.0151,  ..., 163.7028, 163.9522, 163.6243]],

        [[151.2754, 151.1875, 151.2218,  ...

tensor([[[165.3370, 163.7587, 163.5295,  ..., 188.6863, 188.6181, 187.7085],
         [162.7725, 161.8542, 161.7034,  ..., 184.8600, 184.8467, 184.2262],
         [160.1910, 160.0786, 159.8487,  ..., 180.5851, 180.6057, 180.1996],
         ...,
         [149.4315, 149.4856, 149.0884,  ..., 158.3021, 158.2573, 158.1869],
         [147.2010, 147.1846, 146.6524,  ..., 153.6750, 153.6238, 153.4519],
         [144.7433, 144.7254, 144.2038,  ..., 149.5857, 149.5794, 149.3541]],

        [[171.1391, 171.0520, 169.3249,  ..., 190.3424, 193.6346, 190.7636],
         [171.9232, 171.3791, 170.7468,  ..., 181.5928, 185.8899, 189.2607],
         [167.6250, 171.1873, 171.4890,  ..., 179.9429, 180.9289, 180.4292],
         ...,
         [167.7180, 163.4556, 164.7408,  ..., 175.9822, 174.4275, 178.3012],
         [167.7422, 163.7664, 164.8801,  ..., 176.0910, 175.0553, 178.2701],
         [167.6884, 164.0136, 164.8737,  ..., 176.1218, 175.4656, 178.0012]],

        [[166.7944, 164.6920, 170.8276,  ...

tensor([[[172.8753, 171.2368, 170.9982,  ..., 199.6067, 199.5341, 198.5693],
         [170.1987, 169.2463, 169.0883,  ..., 195.5595, 195.5454, 194.8870],
         [167.5039, 167.3893, 167.1482,  ..., 191.0403, 191.0612, 190.6302],
         ...,
         [156.2464, 156.3027, 155.8872,  ..., 167.4972, 167.4490, 167.3725],
         [153.9113, 153.8938, 153.3373,  ..., 162.6143, 162.5594, 162.3744],
         [151.3379, 151.3189, 150.7735,  ..., 158.2986, 158.2905, 158.0495]],

        [[168.5092, 168.5017, 171.3289,  ..., 194.2130, 196.9759, 198.2045],
         [170.4158, 170.4913, 173.2272,  ..., 198.4417, 200.8939, 201.5108],
         [172.2401, 172.4002, 175.0823,  ..., 202.3627, 204.6158, 204.7044],
         ...,
         [179.1003, 179.7392, 181.8103,  ..., 219.0557, 219.7354, 217.0515],
         [179.9507, 180.9827, 182.9326,  ..., 222.0980, 222.6373, 219.0173],
         [180.6328, 182.0520, 183.8815,  ..., 224.8321, 225.2452, 220.9207]],

        [[170.3773, 163.1777, 164.9496,  ...

tensor([[[140.9173, 139.3808, 143.0938,  ..., 151.3327, 153.1840, 156.3114],
         [140.4405, 139.1563, 143.3677,  ..., 150.7424, 152.9844, 156.7570],
         [139.8801, 138.9560, 143.5899,  ..., 150.1125, 152.7334, 157.0061],
         ...,
         [157.3523, 178.0531, 193.8311,  ..., 212.0475, 245.7952, 254.0455],
         [184.8979, 197.2232, 198.8508,  ..., 253.6534, 256.8822, 249.1307],
         [184.9965, 196.9095, 198.3530,  ..., 252.7211, 255.7290, 248.5715]],

        [[180.9045, 179.3309, 181.2194,  ..., 193.7661, 194.5552, 205.5668],
         [181.8044, 180.5764, 182.4273,  ..., 194.2320, 195.5254, 209.6707],
         [182.7121, 182.0007, 183.9548,  ..., 195.7139, 198.0297, 214.3518],
         ...,
         [188.6166, 189.3567, 192.0507,  ..., 211.2789, 218.3216, 237.1196],
         [189.4952, 190.7027, 193.4866,  ..., 215.0079, 222.6020, 240.9502],
         [190.2727, 192.0045, 194.8320,  ..., 218.7430, 226.7386, 244.4190]],

        [[193.6786, 195.9647, 197.4272,  ...

tensor([[[143.7014, 147.4034, 149.7274,  ..., 158.0943, 160.9670, 160.8710],
         [143.8806, 148.0928, 150.1322,  ..., 158.3595, 161.5924, 161.6005],
         [149.7186, 150.3384, 153.9067,  ..., 161.6456, 165.4156, 190.9856],
         ...,
         [184.8964, 177.0226, 178.8296,  ..., 201.6156, 202.0021, 206.3380],
         [182.1218, 174.4225, 176.3175,  ..., 195.9892, 197.0098, 202.7622],
         [179.1435, 171.6829, 173.5807,  ..., 190.2097, 191.6235, 198.8085]],

        [[191.2194, 191.1478, 189.2341,  ..., 216.1878, 219.7512, 216.4610],
         [192.0881, 191.4824, 190.7965,  ..., 206.3939, 211.1889, 214.8210],
         [187.2976, 191.2692, 191.6033,  ..., 204.5356, 205.6477, 205.0455],
         ...,
         [187.3818, 182.6184, 184.0687,  ..., 200.0096, 198.2412, 202.6681],
         [187.4088, 182.9656, 184.2218,  ..., 200.1356, 198.9554, 202.6321],
         [187.3483, 183.2424, 184.2131,  ..., 200.1733, 199.4235, 202.3276]],

        [[182.1218, 174.4225, 176.3175,  ...

tensor([[[183.9973, 207.5577, 216.9248,  ..., 266.6104, 283.4888, 278.0063],
         [184.0189, 207.6799, 217.0970,  ..., 266.8494, 283.8486, 278.3354],
         [184.0977, 207.8019, 217.2664,  ..., 267.0922, 284.2326, 278.7010],
         ...,
         [206.2892, 204.1145, 192.3836,  ..., 258.9955, 234.4068, 196.4767],
         [204.7184, 207.6697, 206.5533,  ..., 265.7315, 264.4875, 245.0603],
         [196.1558, 202.7996, 207.0764,  ..., 243.2401, 262.0546, 264.2183]],

        [[204.9524, 210.0826, 209.0609,  ..., 271.0926, 265.9277, 241.5192],
         [204.5579, 209.5434, 208.5773,  ..., 270.2264, 265.2355, 241.3901],
         [204.3496, 209.1743, 208.2215,  ..., 269.2859, 264.4600, 241.0795],
         ...,
         [202.3601, 206.3589, 205.7241,  ..., 262.9072, 259.3088, 239.5181],
         [201.8213, 205.6713, 205.1607,  ..., 261.5602, 258.3493, 239.3752],
         [201.2525, 204.9219, 204.5301,  ..., 259.8580, 256.9363, 238.8575]],

        [[161.1871, 177.0588, 197.0456,  ...

### Plot distance and loss over episodes

In [None]:
plot.plot_euclidean_distance(train_dis, test_dis)
plot.plot_losses(train_loss[10:], test_loss[10:])

### Histogram of the distribution shift (for test and training distances)

In [None]:
print("Training set \n")
plot.histo_distribution_shift(train_dis_item)
print("Validation set \n")
plot.histo_distribution_shift(test_dis_item)

### Histograms showing the training and validation distance distribution (for test and training distances)

In [None]:
plot.histo_train_val(test_dis_item, train_dis_item)

### Save and load models