In [1]:
# needed libraries
import re
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.layers import Dense, Conv2D, Conv2DTranspose
from tensorflow.keras.layers import Flatten, Reshape, Dropout, BatchNormalization, Activation, LeakyReLU

# utilities
import os
from glob import glob
import matplotlib.pyplot as plt
import pathlib
import time
import datetime

from IPython import display

gpu_available = tf.config.list_physical_devices('GPU')
print(gpu_available)

[]


In [None]:
!pip3 install pickle5
import pickle5 as pickle

# store processed data in pkl files
def save_pkl_data(data, filename):
    with open(filename, 'wb') as file:
        pickle.dump(data, file, pickle.HIGHEST_PROTOCOL)
        print("data stored succesfully to: ", filename)


# read processed data in pkl files
def load_pkl_data(filename):
    with open(filename, 'rb') as file:
        data = pickle.load(file)
    return data

Collecting pickle5
  Downloading pickle5-0.0.12-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (256 kB)
[?25l[K     |█▎                              | 10 kB 24.3 MB/s eta 0:00:01[K     |██▋                             | 20 kB 10.0 MB/s eta 0:00:01[K     |███▉                            | 30 kB 7.0 MB/s eta 0:00:01[K     |█████▏                          | 40 kB 6.3 MB/s eta 0:00:01[K     |██████▍                         | 51 kB 5.5 MB/s eta 0:00:01[K     |███████▊                        | 61 kB 5.5 MB/s eta 0:00:01[K     |█████████                       | 71 kB 5.3 MB/s eta 0:00:01[K     |██████████▎                     | 81 kB 6.0 MB/s eta 0:00:01[K     |███████████▌                    | 92 kB 4.8 MB/s eta 0:00:01[K     |████████████▉                   | 102 kB 5.2 MB/s eta 0:00:01[K     |██████████████                  | 112 kB 5.2 MB/s eta 0:00:01[K     |███████████████▍                | 122 kB 5.2 MB/s eta 0:00:01[K     |████████████████▋           

In [None]:
cubes = load_pkl_data('nusc_inps.pkl') 

# Masking

In [2]:
def get_look_ahead_mask(input):
  input_shape = list(input.shape)[:-1]
  input_shape.insert(-1, input_shape[-1])
  input_shape.insert(1, 1)
  mask = 1 - tf.linalg.band_part(tf.ones(input_shape), -1, 0)
  return mask

In [None]:
def adapt_spatial_mask(mask):
  return mask[np.newaxis, : , np.newaxis, : ]         #(1, seq, 1, neighbors)

# Positional Encoding

In [None]:
def get_angles(pos, i, d_model):
  angle_rates = 1 / np.power(10000, (2 * (i//2)) / np.float32(d_model))
  return pos * angle_rates

In [None]:
def positional_encoding(max_position, d_model):
  angle_rads = get_angles(np.arange(max_position)[:, np.newaxis],
                          np.arange(d_model)[np.newaxis, :],
                          d_model)

  # apply sin to even indices in the array; 2i
  angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])

  # apply cos to odd indices in the array; 2i+1
  angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])

  pos_encoding = angle_rads[np.newaxis, ...]

  return tf.cast(pos_encoding, dtype=tf.float32)

# Attention

In [None]:
def ScaledDotProduct(Q, K, V, mask=None):
    dk = tf.cast(tf.shape(K)[-1], tf.float32)

    # compute attention 
    KT = tf.transpose(K, [0, 1, 2, 4, 3])                 
    attention = tf.matmul(Q, KT)/tf.sqrt(dk)

    # mask if necessary
    if mask is not None:
      #print(attention.shape)
      attention += (mask * -1e9)

    # compute values and weighted sum of their attention
    weights = tf.nn.softmax(attention, axis=-1)
    output = tf.matmul(weights, V)

    return output, weights 

In [None]:
class MultiHeadAttention(keras.layers.Layer):
  def __init__(self, dk=256, num_heads=8):
    super(MultiHeadAttention, self).__init__()
    
    # params
    self.num_heads = num_heads
    self.dk = dk
    self.dk_by_head = dk//num_heads

    # layers
    self.WQ = keras.layers.Dense(dk)
    self.WK = keras.layers.Dense(dk)
    self.WV = keras.layers.Dense(dk)
    self.dense = keras.layers.Dense(dk)
    
  def splitheads(self, x):
    batch_size, seq_length = x.shape[0:2]

    # spliting the heads done by reshaping last dimension
    x = tf.reshape(x, (batch_size, seq_length, -1, self.num_heads, self.dk_by_head))      #(batch, seq, neighbors, head, features_by_head)
    return tf.transpose(x, (0, 3, 1, 2, 4))                                               #(batch, head, seq, neighbors, features_by_head)

  def call(self, q, k, v, mask=None):
    batch_size, seq_length = q.shape[0:2]

    # projections
    q = self.WQ(q)
    k = self.WK(v)
    v = self.WV(k)

    # split heads
    q = self.splitheads(q)
    k = self.splitheads(k)
    v = self.splitheads(v)

    # compute attention and merge heads
    attn_output, attention = ScaledDotProduct(q, k, v, mask)                              #(batch, head, seq, neighbors, features_by_head)
    attn_output = tf.transpose(attn_output,  (0, 2, 3, 1, 4))                             #(batch, seq, neighbors, head, features_by_head)
    concat_output = tf.reshape(attn_output, (batch_size, seq_length, -1, self.dk))        #(batch, seq, neighbors, features)
    output = self.dense(concat_output)

    return output, attention


# Transformer Architecture

In [None]:
def get_ffn(d_model, hidden_size, act_func='relu'):
  return keras.models.Sequential([
                                  keras.layers.Dense(hidden_size, activation=act_func),
                                  keras.layers.Dense(d_model)
  ], name='SEQ')

In [None]:
class EncoderLayer(keras.layers.Layer):
  def __init__(self, dk=256, num_heads=8, hidden_layer_size=256, use_dropout=True, drop_rate=0.1):
    super(EncoderLayer, self).__init__()
    # params
    self.use_dropout = use_dropout

    # layers
    self.MH = MultiHeadAttention(dk, num_heads)
    self.ffn = get_ffn(dk, dk, 'relu')
    self.normLayer1 = keras.layers.LayerNormalization(epsilon=1e-6)
    self.normLayer2 = keras.layers.LayerNormalization(epsilon=1e-6)
    self.dropout1 = keras.layers.Dropout(drop_rate)
    self.dropout2 = keras.layers.Dropout(drop_rate)

  def call(self, x, training, mask):
    # multihead attention
    attn_output, _ = self.MH(x, x, x, mask)

    # dropout layer
    if self.use_dropout and training:
      attn_output = self.dropout1(attn_output)
    
    # normalization and feed forward layers
    z = self.normLayer1(x + attn_output)
    output = self.ffn(z)

    # dropout layer
    if self.use_dropout and training:
      output = self.dropout2(output)
    
    # normalization layer
    output = self.normLayer2(z + output)

    return output 

In [None]:
sample_encoder_layer = EncoderLayer()

In [None]:
samp_inp = tf.random.uniform((3, 20, 6, 256))
out = sample_encoder_layer(samp_inp, True, None)
out.shape

TensorShape([3, 20, 6, 256])

In [None]:
class DecoderLayer(keras.layers.Layer):
  def __init__(self, dk=256, num_heads=8, hidden_layer=256, use_dropout=True, drop_rate=0.1):
    super(DecoderLayer, self).__init__()

    #params
    self.use_dropout = use_dropout

    # layers
    self.SAMH = MultiHeadAttention(dk, num_heads)
    self.EDMH = MultiHeadAttention(dk, num_heads)
    self.ffn = get_ffn(dk, hidden_layer)

    self.normLayer1 = keras.layers.LayerNormalization(epsilon=1e-6)
    self.normLayer2 = keras.layers.LayerNormalization(epsilon=1e-6)
    self.normLayer3 = keras.layers.LayerNormalization(epsilon=1e-6)\

    self.dropout1 = keras.layers.Dropout(drop_rate)
    self.dropout2 = keras.layers.Dropout(drop_rate)
    self.dropout3 = keras.layers.Dropout(drop_rate)
  
  def call(self, x, enc_output, training, look_ahead_mask, padding_mask):

    # self attention computation
    self_attn_out, self_attn = self.SAMH(x, x, x, look_ahead_mask)

    if self.use_dropout and training:
      self_attn_out = self.dropout1(self_attn_out)
    
    z = self.normLayer1(x + self_attn_out) 

    # encoder decoder computation
    enc_dec_out, enc_dec_attn = self.EDMH(z, enc_output, enc_output, padding_mask)

    if self.use_dropout and training:
      enc_dec_out = self.dropout2(enc_dec_out)
    
    z = self.normLayer2(z + enc_dec_out)

    # feed forward computation
    output = self.ffn(z)

    if self.use_dropout and training:
      output = self.dropout3(output)
    
    output = self.normLayer3(z + output)

    return output, self_attn, enc_dec_attn


In [None]:
sample_decoder_layer = DecoderLayer()

In [None]:
dec_inp = tf.random.uniform((3, 20, 6, 256))
out2 = sample_decoder_layer(dec_inp, out, True, None, None)
out2[0].shape

TensorShape([3, 20, 6, 256])

In [None]:
class Encoder(keras.layers.Layer):
  def __init__(self, features_size, max_size, dk_model=256, num_heads=8, num_encoders=6, 
               enc_hidden_size=256, use_pos_emb=True, use_dropout=True, drop_rate=0.1):
    super(Encoder, self).__init__()

    # params
    self.dk_model = dk_model
    self.max_size = max_size
    self.use_dropout = use_dropout
    self.use_pos_emb = use_pos_emb
    self.enc_hidden_size = enc_hidden_size
    self.num_encoders = num_encoders

    # layers
    #self.embedding = keras.layers.Embedding(features_size, dk_model)
    self.embedding = keras.layers.Dense(dk_model)
    self.encoders_stack = [EncoderLayer(dk_model, num_heads, enc_hidden_size, use_dropout, drop_rate) for _ in range(num_encoders)]
    self.dropout = tf.keras.layers.Dropout(drop_rate)
  
  def call(self, x, padding_mask, training):
    x = self.embedding(x)
    x *= tf.math.sqrt(tf.cast(self.dk_model, tf.float32))

    if self.use_pos_emb:
      x += positional_encoding(self.max_size, self.dk_model)
    
    if self.use_dropout and training:
      x = self.dropout(x)
    
    for encoder_layer in self.encoders_stack:
      x = encoder_layer(x, training, padding_mask)
    
    return x

In [None]:
samp_inp = tf.random.uniform((3, 6, 20, 256))
encoder = Encoder(256, 20, 256)
out = encoder(samp_inp, None, True)
out.shape

HI


TensorShape([3, 6, 20, 256])

In [None]:
class Decoder(keras.layers.Layer):
  def __init__(self, features_size, max_size, dk_model=256, num_heads=8, num_decoders=6, 
               dec_hidden_size=256, use_pos_emb=True, use_dropout=True, drop_rate=0.1):
    
    super(Decoder, self).__init__()

    # params
    self.dk_model = dk_model
    self.max_size = max_size
    self.use_dropout = use_dropout
    self.use_pos_emb = use_pos_emb
    self.dec_hidden_size = dec_hidden_size
    self.num_decoders = num_decoders

    # layers
    self.embedding = keras.layers.Dense(dk_model)
    self.decoders_stack = [DecoderLayer(dk_model, num_heads, dec_hidden_size, use_dropout, drop_rate) for _ in range(num_decoders)]
    self.dropout = tf.keras.layers.Dropout(drop_rate)

  def call(self, x, enc_output, look_ahead_mask, padding_mask, training):
    #print(x)
    x = self.embedding(x)
    #print(x)
    x *= tf.math.sqrt(tf.cast(self.dk_model, tf.float32))
    if self.use_pos_emb:
      x += positional_encoding(self.max_size, self.dk_model)
    
    #print(x)
    if self.use_dropout and training:
      x = self.dropout(x)
    
    for decoder_layer in self.decoders_stack:
      x, attn1, attn2, = decoder_layer(x, enc_output, training, look_ahead_mask, padding_mask)
    
    return x


In [None]:
samp_inp = tf.random.uniform((3, 6, 20, 256))
decoder = Decoder(256, 20, 256)
out2 = decoder(samp_inp, out, None, None, True)
out2.shape

HI
xs:  (3, 6, 20, 256)
pe:  (1, 20, 256)
xs:  (3, 6, 20, 256)


TensorShape([3, 6, 20, 256])

In [None]:
class STTransformer(keras.Model):
  def __init__(self, features_size, max_seq_size, max_neighbors_size, 
               sp_dk=256, sp_enc_heads=8, sp_dec_heads=8, sp_num_encoders=6, sp_num_decoders=6, 
               tm_dk=256, tm_enc_heads=8, tm_dec_heads=8, tm_num_encoders=6, tm_num_decoders=6, 
               dec_hidden_size=256, use_dropout=True, drop_rate=0.1):
    
    super(STTransformer, self).__init__()

    # layers
    self.sp_encoder = Encoder(features_size, max_neighbors_size, sp_dk, use_pos_emb=False)
    self.sp_decoder = Decoder(features_size, max_neighbors_size, sp_dk, use_pos_emb=False)
    self.tm_encoder = Encoder(features_size, max_seq_size, tm_dk)
    self.tm_decoder = Decoder(features_size, max_seq_size, tm_dk)
    self.linear = tf.keras.layers.Dense(3, name='Linear_Trans')

    
  def call(self, inputs, masks, training):
    inp, targets = inputs
    inp_mask, tar_mask = masks
    

    sp_enc_out = self.sp_encoder(inp,  inp_mask, training)                             #(batch, seq, neighbors, <spatial attn features>)
    out = tf.transpose(sp_enc_out, [0, 2, 1, 3])                                       #(batch, neighbors, seq, <spatial attn features>)
    tm_enc_out = self.tm_encoder(out, None, training)                                  #(batch, neighbots, seq, <time attn features>)
    
    # decode time
    targets = tf.transpose(targets, [0, 2, 1, 3])                                      #(batch, neighbors, seq, features)
    look_mask = get_look_ahead_mask(targets)
    tm_dec_out = self.tm_decoder(targets, tm_enc_out, look_mask, None, training)
    out2 = tf.transpose(tm_dec_out, [0, 2, 1, 3])                                      #(batch, seq, neighbors, features)
    sp_dec_out = self.sp_decoder(out2, sp_enc_out, None, tar_mask, training)
    
    # linear projection
    output = self.linear(sp_dec_out)
    return output

In [None]:
model = STTransformer(100, 20, 6)

In [None]:
input = tf.random.uniform((8, 8, 10, 5))

In [None]:
target = tf.random.uniform((8, 8, 10, 5))

In [None]:
inputs = (input, target)

In [None]:
o = model(inputs, (None, None))

In [None]:
o.shape

TensorShape([8, 10, 6, 3])

In [None]:
print(o)

In [None]:
def buildDataSet(input, batch_size):
  input_dataset = tf.data.Dataset.from_tensor_slices([x[0].astype(np.float32) for x in input])
  mask_dataset = tf.data.Dataset.from_tensor_slices([ adapt_spatial_mask(x[1].astype(np.float32)) for x in input])
  #mask_dataset = adapt_spatial_mask(mask_dataset)
  dataset = tf.data.Dataset.zip((input_dataset, mask_dataset))
  dataset = dataset.shuffle(40)
  dataset = dataset.batch(batch_size)
  return dataset

In [None]:
dataset = buildDataSet(cubes, 16)

In [None]:
loss_function = tf.keras.losses.MeanSquaredError()

In [None]:
optimizer = optimizer = tf.keras.optimizers.Adam(0.0001, beta_1=0.9, beta_2=0.98,
                                     epsilon=1e-9)

In [None]:
model = STTransformer(16, 8, 10)

In [None]:
#@tf.function
def train_step(zipped_input, l, losses):
  inputs = zipped_input[0]
  masks = zipped_input[1]

  # divide input as the trajectory input, and target (basically past and future to predict) 
  inp, tar = inputs[:, :l, :, :], inputs[:, l-1:, :, :]                   
  mask_inp, mask_tar = masks[:, :, :l, :, :], masks[:, :, l-1:, :, :]
  
  # get only x, y, and rotation
  targets = tar[:, :, :, :3]                                            

  with tf.GradientTape() as tape:
    predictions = model((inp, tar), (mask_inp, mask_tar), training=True)
    loss = loss_function(targets, predictions)

  #print('targets: ', targets.numpy())
  #print(predictions)
  print('loss: ', loss)
  losses.append(loss)
  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))
  return losses

In [None]:
model.summary()

In [None]:
for epoch in range(20):
  print('epoch: ', epoch)
  losses = []
  for batch in dataset:
    losses = train_step(batch, 8, losses)

  print("avg loss", tf.reduce_mean(losses)) 

epoch:  0
loss:  tf.Tensor(987.5495, shape=(), dtype=float32)
loss:  tf.Tensor(1939.2532, shape=(), dtype=float32)
loss:  tf.Tensor(1026.8756, shape=(), dtype=float32)
loss:  tf.Tensor(866.13904, shape=(), dtype=float32)
loss:  tf.Tensor(1261.075, shape=(), dtype=float32)
loss:  tf.Tensor(1148.3922, shape=(), dtype=float32)
loss:  tf.Tensor(1421.7554, shape=(), dtype=float32)
loss:  tf.Tensor(1280.4187, shape=(), dtype=float32)
loss:  tf.Tensor(820.8073, shape=(), dtype=float32)
loss:  tf.Tensor(1481.7888, shape=(), dtype=float32)
loss:  tf.Tensor(1886.6484, shape=(), dtype=float32)
loss:  tf.Tensor(1946.199, shape=(), dtype=float32)
loss:  tf.Tensor(2714.373, shape=(), dtype=float32)
loss:  tf.Tensor(1900.1019, shape=(), dtype=float32)
loss:  tf.Tensor(2521.305, shape=(), dtype=float32)
loss:  tf.Tensor(2401.936, shape=(), dtype=float32)
loss:  tf.Tensor(2739.5632, shape=(), dtype=float32)
loss:  tf.Tensor(1523.3484, shape=(), dtype=float32)
loss:  tf.Tensor(1330.3943, shape=(), dtype

In [None]:
ld = list(dataset)

In [None]:
inputs = ld[0][0]
masks = ld[0][1]

# divide input as the trajectory input, and target (basically past and future to predict) 
inp, tar = inputs[:, :8, :, :], inputs[:, 8-2:, :, :]                   
mask_inp, mask_tar = masks[:, :, :8, :, :], masks[:, :, 8-2:, :, :]
out = model((inp, tar), (mask_inp, mask_tar))

In [None]:
tar

<tf.Tensor: shape=(16, 8, 10, 5), dtype=float32, numpy=
array([[[[-2.50668964e+01,  1.99907112e+01,  5.01123480e-02,
           4.68301438e-02,  1.28885021e-03],
         [ 1.52601032e+01, -1.27532883e+01,  2.04460174e-01,
           7.03465509e+00, -1.41141748e+00],
         [ 2.75801029e+01, -2.00642891e+01,  3.26544609e+01,
           3.67394471e+00,  2.06004500e-01],
         ...,
         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
           0.00000000e+00,  0.00000000e+00],
         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
           0.00000000e+00,  0.00000000e+00],
         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
           0.00000000e+00,  0.00000000e+00]],

        [[-2.50738964e+01,  1.99687119e+01,  1.37068868e-01,
           4.61850390e-02, -1.29053218e-03],
         [ 1.27241030e+01, -1.06922884e+01,  2.04460174e-01,
           6.53737402e+00, -9.94808674e-01],
         [ 2.60571041e+01, -1.95842876e+01,  2.49544601e+01,
           3.194491

In [None]:
out

<tf.Tensor: shape=(16, 8, 10, 3), dtype=float32, numpy=
array([[[[-3.32868665e-01,  1.79106975e+00,  2.74370174e+01],
         [-3.32868069e-01,  1.79106975e+00,  2.74370174e+01],
         [-3.32868487e-01,  1.79106975e+00,  2.74370193e+01],
         ...,
         [-3.32868785e-01,  1.79106891e+00,  2.74370174e+01],
         [-3.32868785e-01,  1.79106891e+00,  2.74370174e+01],
         [-3.32868785e-01,  1.79106891e+00,  2.74370174e+01]],

        [[-2.80308574e-01,  1.56844592e+00,  2.74878788e+01],
         [-2.80308098e-01,  1.56844592e+00,  2.74878807e+01],
         [-2.80307978e-01,  1.56844544e+00,  2.74878807e+01],
         ...,
         [-2.80308634e-01,  1.56844592e+00,  2.74878769e+01],
         [-2.80308634e-01,  1.56844592e+00,  2.74878769e+01],
         [-2.80308634e-01,  1.56844592e+00,  2.74878769e+01]],

        [[-3.15515734e-02,  1.34530592e+00,  2.75404224e+01],
         [-3.15514542e-02,  1.34530556e+00,  2.75404205e+01],
         [-3.15516330e-02,  1.34530628e+00, 

# Testing functions

In [None]:
for i in range(len(cubes)):
  for j in range(len(cubes[i][0])):
    for k in range(len(cubes[i][0][j])):
      for l in range(len(cubes[i][0][j][k])):
        if np.isnan(cubes[i][0][j][k][l]):
          cubes[i][0][j][k][l] = 0.0      


In [None]:
all_inps = [x[0] for x in cubes]
for inp in all_inps:
  for face in inp:
    for row in face:
      for el in row:
        if np.isnan(el):
          print('WHAAAAT')

In [None]:
np.arange(10)[:, np.newaxis]

array([[0],
       [1],
       [2],
       [3],
       [4],
       [5],
       [6],
       [7],
       [8],
       [9]])

In [None]:
t = tf.constant(np.arange(3 * 4 * 3 * 5 * 5)) + 1    
t = tf.reshape(t, (3, 4, 3, 5, 5))             #(batch, head, seq, N, N)
t = tf.cast(t, tf.float32)
t2 = np.random.choice([0, 1], (3, 3, 5)) * 0.5

In [None]:
t2 = tf.reshape(t2, (3, 1, 3, 1, 5))          # (batch, 1, seq, 1, N)
t2 = tf.cast(t2, tf.float32)

In [None]:
t2

<tf.Tensor: shape=(3, 1, 3, 1, 5), dtype=float32, numpy=
array([[[[[0. , 0. , 0. , 0.5, 0. ]],

         [[0. , 0.5, 0.5, 0.5, 0.5]],

         [[0. , 0.5, 0. , 0.5, 0. ]]]],



       [[[[0.5, 0. , 0.5, 0. , 0. ]],

         [[0.5, 0. , 0. , 0. , 0.5]],

         [[0. , 0. , 0.5, 0.5, 0.5]]]],



       [[[[0. , 0. , 0.5, 0. , 0. ]],

         [[0. , 0. , 0. , 0.5, 0. ]],

         [[0.5, 0.5, 0.5, 0. , 0. ]]]]], dtype=float32)>

In [None]:
t + t2

<tf.Tensor: shape=(3, 4, 3, 5, 5), dtype=float32, numpy=
array([[[[[  1. ,   2. ,   3. ,   4.5,   5. ],
          [  6. ,   7. ,   8. ,   9.5,  10. ],
          [ 11. ,  12. ,  13. ,  14.5,  15. ],
          [ 16. ,  17. ,  18. ,  19.5,  20. ],
          [ 21. ,  22. ,  23. ,  24.5,  25. ]],

         [[ 26. ,  27.5,  28.5,  29.5,  30.5],
          [ 31. ,  32.5,  33.5,  34.5,  35.5],
          [ 36. ,  37.5,  38.5,  39.5,  40.5],
          [ 41. ,  42.5,  43.5,  44.5,  45.5],
          [ 46. ,  47.5,  48.5,  49.5,  50.5]],

         [[ 51. ,  52.5,  53. ,  54.5,  55. ],
          [ 56. ,  57.5,  58. ,  59.5,  60. ],
          [ 61. ,  62.5,  63. ,  64.5,  65. ],
          [ 66. ,  67.5,  68. ,  69.5,  70. ],
          [ 71. ,  72.5,  73. ,  74.5,  75. ]]],


        [[[ 76. ,  77. ,  78. ,  79.5,  80. ],
          [ 81. ,  82. ,  83. ,  84.5,  85. ],
          [ 86. ,  87. ,  88. ,  89.5,  90. ],
          [ 91. ,  92. ,  93. ,  94.5,  95. ],
          [ 96. ,  97. ,  98. ,  99.5, 100

In [None]:
mask = np.random.choice([0, 1], size=(3, 5))

In [None]:
mask

array([[0, 1, 0, 0, 1],
       [1, 1, 0, 0, 1],
       [0, 1, 0, 1, 0]])

In [None]:
adapt_spatial_mask(mask).shape

(1, 3, 1, 5)