In [None]:

# Mount Google Drive inside the Colab runtime and change into the project
# folder, so that the relative `data/...` paths used by later cells resolve.
from google.colab import drive
drive.mount('/gdrive')
%cd /gdrive/My Drive/InImNet/

In [None]:
import os
import tensorflow as tf
#tf.get_logger().setLevel('ERROR')

import matplotlib as mpl
import matplotlib.pyplot as plt
import logging

# Default figure size (width, height) for every matplotlib figure created below.
mpl.rcParams['figure.figsize'] = (8, 6)


In [None]:
class InputAutoencoder (tf.keras.Model):
    """Convolutional encoder/decoder pair plus two dense projection layers.

    The encoder is three stride-2 ``Conv2D`` layers (16, 32 and 64 filters);
    the decoder is three stride-2 ``Conv2DTranspose`` layers (128, 64 and 32
    filters) followed by a stride-1 ``Conv2DTranspose`` with 4 output
    channels.  The layers are stored in plain lists and applied one after the
    other by ``call_encoder`` / ``call_decoder``.

    Args:
        autoenc_lr: Learning rate handed to the Adam optimiser of this model.
        represent_dim_in: Number of units of the ``dim_reducer`` dense layer.
        represent_dim_out: Number of units of the ``dim_reconstructor`` dense
            layer.
    """

    def __init__(self, autoenc_lr: float = 0.001, represent_dim_in=50, represent_dim_out=50):
        super(InputAutoencoder, self).__init__()

        self.encoder_input = []
        self.decoder_input = []

        # Encoder: with 'same' padding each stride-2 convolution divides the
        # spatial size by two (rounded up).
        self.encoder_input.append(tf.keras.layers.Conv2D(16, (5, 5), strides=2, padding='same'))
        self.encoder_input.append(tf.keras.layers.BatchNormalization())
        self.encoder_input.append(tf.keras.layers.ReLU())
        self.encoder_input.append(tf.keras.layers.Conv2D(32, (5, 5), strides=2, padding='same'))
        self.encoder_input.append(tf.keras.layers.BatchNormalization())
        self.encoder_input.append(tf.keras.layers.ReLU())
        self.encoder_input.append(tf.keras.layers.Conv2D(64, (5, 5), strides=2, padding='same'))
        self.encoder_input.append(tf.keras.layers.ReLU())

        # Dense projections, exposed through call_dim_reducer / call_dim_reconstructor.
        self.dim_reducer = tf.keras.layers.Dense (represent_dim_in)
        self.dim_reconstructor = tf.keras.layers.Dense (represent_dim_out)

        # Decoder: three stride-2 transposed convolutions undo the encoder's
        # downsampling; the last layer has no activation and emits 4 channels.
        self.decoder_input.append(tf.keras.layers.Conv2DTranspose(128, (3, 3), strides=2, padding='same'))
        self.decoder_input.append(tf.keras.layers.BatchNormalization())
        self.decoder_input.append(tf.keras.layers.ReLU())
        self.decoder_input.append(tf.keras.layers.Conv2DTranspose(64, (5, 5), strides=2, padding='same'))
        self.decoder_input.append(tf.keras.layers.BatchNormalization())
        self.decoder_input.append(tf.keras.layers.ReLU())
        self.decoder_input.append(tf.keras.layers.Conv2DTranspose(32, (5, 5), strides=2, padding='same'))
        self.decoder_input.append(tf.keras.layers.BatchNormalization())
        self.decoder_input.append(tf.keras.layers.ReLU())
        self.decoder_input.append(tf.keras.layers.Conv2DTranspose(4, (5, 5), padding='same'))

        # A constant learning rate is used (no decay schedule).
        lr_decayed_fn = autoenc_lr
        self.optimiser_autoencoder = tf.keras.optimizers.Adam(learning_rate=lr_decayed_fn)

    def call_encoder(self, x, training):
        """Apply the encoder layers in order and return the final feature map.

        The shape after every layer is printed (inside a ``tf.function`` this
        only happens while tracing).
        """
        z = x
        for i, layer in enumerate(self.encoder_input):
            z = layer(z, training=training)
            print ('encoder['+str(i)+'].shape: ', z.shape)
        return z

    def call_dim_reducer(self, x, training):
        """Apply the ``dim_reducer`` dense layer to ``x``."""
        return self.dim_reducer (x, training=training)

    def call_dim_reconstructor(self, x, training):
        """Apply the ``dim_reconstructor`` dense layer to ``x``."""
        return self.dim_reconstructor (x, training=training)

    def call_decoder (self, x, training):
        """Apply the decoder layers in order and return the final output.

        The shape after every layer is printed (inside a ``tf.function`` this
        only happens while tracing).
        """
        z = x
        for i, layer in enumerate(self.decoder_input):
            z = layer(z, training=training)
            print ('decoder['+str(i)+'].shape: ', z.shape)
        return z

    def optimise_autoencoder (self, loss, g):
        """Take one Adam step on every trainable weight of this autoencoder.

        Args:
            loss: Loss tensor that was computed under ``g``.
            g: The ``tf.GradientTape`` that recorded the computation of ``loss``.

        Returns:
            Always ``0``.
        """
        layers = self.encoder_input + self.decoder_input + [self.dim_reducer, self.dim_reconstructor]
        # minimize() expects a flat list of variables.  A list of per-layer
        # lists only works with optimizers that flatten it themselves.
        variables = [v for layer in layers for v in layer.trainable_weights]
        self.optimiser_autoencoder.minimize (loss, variables, tape=g)
        return 0

In [None]:
class AutoDiffInImNet(tf.keras.Model):
    """Stack of residual-style update layers on time-augmented image stacks.

    Every layer owns one ``InputAutoencoder``; its encoder followed by its
    decoder acts as the layer's update function ``phi``.  ``call`` builds the
    list of layer states ``z[0], ..., z[num_resnet_layers]`` and ``optimise``
    runs one training step on the loss of the last state.

    Args:
        dim: Forwarded to each ``InputAutoencoder`` as ``represent_dim_in``.
        represent_dim_out: Forwarded to each ``InputAutoencoder``.
        num_resnet_layers: Number of update layers (and of autoencoders).
        activation: Stored on the instance; not used by the code in this class.
        num_internal_layers: Stored on the instance; not used by the code in
            this class.
        bias_on, mult, use_batch_norm, dropout: Accepted for interface
            compatibility; not read by the code in this class.
        lr_inim: Learning rate (or schedule) for the Adam optimisers.
        lr_autoencoder: Learning rate forwarded to each ``InputAutoencoder``.
        approx_jacobian: Selects how the Jacobian used in ``call`` is built.
        weight_regularisation_alpha: Stored on the instance; not used by the
            code in this class.
    """

    def __init__(self, dim: int, represent_dim_out: int,
                 num_resnet_layers: int,
                 activation,
                 num_internal_layers: int,
                 bias_on: bool = True,
                 mult: int = 1,
                 use_batch_norm: bool = False,
                 dropout=0.1,
                 lr_inim = 0.001,
                 lr_autoencoder = 0.001,
                 approx_jacobian = True,
                 weight_regularisation_alpha = 0.0):
        super(AutoDiffInImNet, self).__init__()
        self.num_internal_layers = num_internal_layers
        self.num_resnet_layers = num_resnet_layers
        self.activation = activation
        self.dim = dim
        self.lr_inim = lr_inim
        self.represent_dim_out = represent_dim_out
        self.weight_regularisation_alpha = weight_regularisation_alpha

        # One optimiser per layer is created, although optimise() currently
        # only steps the first one.
        self.optimiser = []
        for i in range (self.num_resnet_layers):
            self.optimiser.append(tf.keras.optimizers.Adam(learning_rate=self.lr_inim))
        self.approx_jacobian = approx_jacobian
        self.autoencoder =  []
        for i in range (self.num_resnet_layers):
            self.autoencoder.append(InputAutoencoder (lr_autoencoder, represent_dim_in=dim, represent_dim_out=represent_dim_out))

    def call_phi (self, ell, x, training):
        """Evaluate layer ``ell``'s update function: decoder(encoder(x))."""
        encoded_images = self.autoencoder[ell].call_encoder (x, training=training)
        return self.autoencoder[ell].call_decoder (encoded_images, training=training)

    #@tf.function
    def optimise (self, x, y_in, t, loss):
        """Run one optimisation step and return the loss of the last layer.

        Args:
            x: Context frames, passed straight to ``call``.
            y_in: Target frames; the first two axes are merged before the
                comparison, so a rank-4 tensor is expected.
            t: Time stamps, passed straight to ``call``.
            loss: Callable ``loss(prediction, target)``; its result is averaged
                with ``tf.math.reduce_mean``.

        Returns:
            The scalar loss of the last layer's state.
        """
        self.output_dim = y_in.shape[2]

        with tf.GradientTape(persistent=True) as g:
            z = self (x, t, training=True)
            y_in_shape = self._infer_shape (y_in)
            # Merge the first two axes so the targets line up with the states
            # returned by call().
            y = tf.reshape (y_in, [y_in_shape[0] * y_in_shape[1], y_in_shape[2], y_in_shape[3]])
            # z[0] is the (time-augmented) input, so layer i's state is z[i+1].
            losses = [tf.math.reduce_mean(loss(z[i+1], y)) for i in range (self.num_resnet_layers)]
            print ('losses: ', losses)
            final_loss = losses[-1]

        # minimize() expects a flat list of variables.  A list of per-layer
        # lists only works with optimizers that flatten it themselves.
        variables = [
            v
            for autoencoder in self.autoencoder
            for layer in autoencoder.encoder_input + autoencoder.decoder_input
            for v in layer.trainable_weights
        ]
        self.optimiser[0].minimize (final_loss, variables, tape=g)
        # A persistent tape keeps its resources until it is dropped.
        del g

        return final_loss

    def jacobian (self, z, x, g):
        """Return the per-example Jacobian of ``z`` with respect to ``x`` from tape ``g``."""
        return g.batch_jacobian(z, x)

    def _infer_shape(self, x):
        """Return the shape of ``x``, preferring static dimensions.

        Returns ``tf.shape(x)`` when the rank is unknown; otherwise a Python
        list in which unknown dimensions are replaced by scalar tensors taken
        from the dynamic shape.
        """
        x = tf.convert_to_tensor(x)

        # If unknown rank, return dynamic shape
        if x.shape.dims is None:
            return tf.shape(x)

        static_shape = x.shape.as_list()
        dynamic_shape = tf.shape(x)

        ret = []
        for i in range(len(static_shape)):
            dim = static_shape[i]
            if dim is None:
                dim = dynamic_shape[i]
            ret.append(dim)

        return ret

    #@tf.function
    def call(self, x, t, training=False):
        """Compute the state of every layer for every time stamp in ``t``.

        Args:
            x: Rank-4 tensor whose axis 1 is moved to the channel position
                (assumed ``(batch, frames, height, width)`` -- TODO confirm).
            t: 1-D tensor/array of time stamps.
            training: Forwarded to the Keras layers.

        Returns:
            A list with ``num_resnet_layers + 1`` tensors (input state first),
            each reduced to channel 0, i.e. of shape
            ``(batch * len(t), height, width)``.
        """
        x_old_shape =  self._infer_shape(x)
        print ('x_old_shape: ', x_old_shape)
        # Frames become channels.
        x = tf.transpose (x, [0, 2, 3, 1])

        # Repeat every example once per time stamp and append the time stamp
        # as an extra, spatially constant channel.
        x_reshaped = tf.repeat (x, t.shape[0], axis=0)
        t_reshaped = tf.tile(tf.reshape(t, [t.shape[0], 1, 1, 1]), [x_old_shape[0], x.shape[1], x.shape[2], 1])
        x = tf.concat((x_reshaped[:, :, :, :], t_reshaped), axis=3)

        z = [x]
        print ('x.shape: ', x.shape)
        print (z)
        with tf.GradientTape(persistent=True) as g:
            g.watch (x)
            for ell in range (self.num_resnet_layers):
                print ('ell: ', ell)
                dim = tf.reduce_prod(tf.shape(x)[1:])
                x_flattened = tf.reshape(x, [-1, dim])
                z_flattened = tf.reshape(z[-1], [-1, dim])
                print ('tf.shape(x): ', tf.shape(x))
                with g.stop_recording():
                    if not self.approx_jacobian:
                        # NOTE(review): z_flattened[-1] selects the last row,
                        # so batch_jacobian receives a rank-1 target; this
                        # branch looks broken -- confirm the intended target.
                        jacobian_z_x = self.jacobian(z_flattened[-1], x_flattened, g)
                    else:
                        if ell == 0:
                            jacobian_z_x = self.jacobian(x_flattened, x_flattened, g)
                        else:
                            # Uses phi_curr_flattened from the previous layer.
                            # NOTE(review): phi is computed from x, not from
                            # x_flattened -- confirm this Jacobian is connected.
                            jacobian_z_x +=  self.jacobian(phi_curr_flattened, x_flattened, g)
                phi_curr = self.call_phi (ell, x, training=training)
                print ('phi_curr.shape: ', phi_curr.shape)
                phi_curr_flattened = tf.reshape(phi_curr, [-1, tf.reduce_prod(phi_curr.shape[1:])])
                # Residual update: z_{ell+1} = z_ell + J @ phi_ell(x).
                delta = jacobian_z_x @  tf.expand_dims(phi_curr_flattened, -1)
                delta = tf.reshape(delta, z[-1].shape)
                z.append(z[-1] +  delta)

        # Only channel 0 of every state is returned.
        for ell in range (self.num_resnet_layers+1):
            z[ell] = z[ell] [:, :, :, 0]

        print ('call finished')
        return z

In [None]:
# InImNet (True) versus ResNet (False), chosen for training and for testing.
inim_on_training = inim_on_testing = True

# Plotting options.
save_plots, view_plots, initial_data_plot = True, False, False

In [None]:
import numpy as np
import math 
# Logging initiation: the root logger reports INFO and above to a stream handler.
logger = logging.getLogger()
logger.setLevel(logging.INFO)  # DEBUG, INFO, WARNING, ERROR, or CRITICAL
# Re-running this cell must not stack another handler on the root logger,
# otherwise every record is emitted once per run of the cell.  The handler is
# therefore named and only added when no handler with that name is attached.
_NOTEBOOK_HANDLER_NAME = 'notebook_stream'
if not any(h.get_name() == _NOTEBOOK_HANDLER_NAME for h in logger.handlers):
    _notebook_handler = logging.StreamHandler()
    _notebook_handler.set_name(_NOTEBOOK_HANDLER_NAME)
    logger.addHandler(_notebook_handler)


In [None]:
import matplotlib.pyplot as plt
import h5py

def _load_trajectories(path, frame_shape=(32, 32), key='data_0'):
    """Read one trajectory array from an HDF5 file and unflatten its frames.

    Prints the keys found in the file and the shape of the returned array.

    Args:
        path: Location of the file, opened read-only with ``h5py``.
        frame_shape: ``(height, width)`` every frame is reshaped to.
        key: Name of the dataset inside the file.

    Returns:
        ``numpy`` array of shape ``(n_trajectories, n_frames, *frame_shape)``.
    """
    with h5py.File(path, 'r') as f:
        print (f.keys())
        trajectories = np.array(f[key])
    trajectories = np.reshape (trajectories,
                               [trajectories.shape[0], trajectories.shape[1], *frame_shape])
    print (trajectories.shape)
    return trajectories


def _show_first_frame(trajectories):
    """Display frame 0 of trajectory 0 as a quick sanity check."""
    plt.imshow (trajectories[0, 0, :, :])
    plt.show()


print ('Training data')
trajectories_train = _load_trajectories('data/bouncing_ball_data/training.hkl')
_show_first_frame(trajectories_train)

print ('Testing data')
trajectories_test = _load_trajectories('data/bouncing_ball_data/test.hkl')
_show_first_frame(trajectories_test)

print ('Validation data')
trajectories_val = _load_trajectories('data/bouncing_ball_data/val.hkl')
_show_first_frame(trajectories_val)

#test_data = h5F.load('data/bouncing_ball_data/test.hkl')
#val_data = hkl.load('data/bouncing_ball_data/val.hkl')

In [None]:
import matplotlib.animation
import matplotlib.pyplot as plt
import numpy as np
# Render animations as interactive JavaScript/HTML and raise the figure DPI.
plt.rcParams["animation.html"] = "jshtml"
plt.rcParams['figure.dpi'] = 150
# Turn interactive mode off so the empty figure below is not drawn on its own.
plt.ioff()

fig, ax = plt.subplots()
# NOTE(review): `x` is not referenced by this cell; kept so the set of global
# names it defines is unchanged.
x= np.linspace(0,10,100)
def animate(t):
    """Draw frame ``t`` of the first training trajectory.

    Draws on ``ax`` explicitly instead of pyplot's "current axes", so the
    callback keeps targeting this figure even after other figures exist.
    """
    ax.cla()
    img = np.array(trajectories_train[0, t, :, :])
    ax.imshow(img)

# Last expression of the cell: the notebook displays the animation, one frame
# per time step of the trajectory.
matplotlib.animation.FuncAnimation(fig, animate, frames=trajectories_train.shape[1])
        
     #print (im['label'])

In [None]:
#num_internal_layers = 3
#batch_size = 1
#autoenc_lr = 0.0001
#losses = []
#t = np.linspace(0, 1, ntotal).astype(np.float32)

#autoencoder =  InputAutoencoder(autoenc_lr = autoenc_lr )

#for epoch in range(num_epochs):
#    indices = np.random.choice (trajectories_train.shape[0], size=batch_size)
#    images = trajectories_train[indices, :, :, :]    
#    images = tf.reshape (images, [images.shape[0] * images.shape[1], images.shape[2], images.shape[3]])
#    images = tf.expand_dims (images, -1)
#    loss = autoencoder.optimise_autoencoder (images, tf.keras.losses.MSE)
#    if epoch % 1000 == 0:
#        print (epoch)
#        print (loss)
#        pred_train = autoencoder.call_decoder(autoencoder.call_encoder (images))
#        print ('pred_train.shape: ', pred_train.shape)
#        fig, axes = plt.subplots(5, images.shape[0]//5)
#        print ('axes.shape[0]: ', axes.shape[0])
#        for i in  range(axes.shape[0]):
#            for j in  range(axes.shape[1]):
#                axes[i][j].get_xaxis().set_visible(False)
#                axes[i][j].get_yaxis().set_visible(False)
#                axes[i][j].imshow(images[i*axes.shape[0]+j, :, :, 0])
#        plt.show()
#        print ('animation ended')    

#        print ('animation started:l1')
#        fig, axes = plt.subplots(5, pred_train.shape[0]//5)
#        for i in  range(axes.shape[0]):
#            for j in  range(axes.shape[1]):
#                axes[i][j].get_xaxis().set_visible(False)
#                axes[i][j].get_yaxis().set_visible(False)
#                axes[i][j].imshow(pred_train[i*axes.shape[0]+j, :, :, 0])
#        plt.show()
#        print ('animation ended: l1')


In [None]:
#NUM_IMAGES_CONTEXT = 3
#indices = np.random.choice (trajectories_train.shape[0], size=batch_size)
#images = trajectories_train[indices, :, :, :]
#x = images [:, :NUM_IMAGES_CONTEXT, :, :]
#x_old_shape = x.shape
#x = tf.reshape (x, [x.shape[0] * x.shape[1], x.shape[2], x.shape[3]])
#x = tf.expand_dims (x, -1)
#encoded_images = autoencoder.call_encoder (x)
#shape_encoded = encoded_images.shape
#dim = tf.reduce_prod(tf.shape(encoded_images)[1:])
#encoded_images = tf.reshape(encoded_images, [-1, dim])
#encoded_images = tf.reshape (encoded_images, [x_old_shape[0], x_old_shape[1], encoded_images.shape[1]])
#encoded_images = tf.reshape (encoded_images, [encoded_images.shape[0], encoded_images.shape[1]*encoded_images.shape[2]])
#print (encoded_images.shape)

#decoded_images = tf.reshape (encoded_images[:, :dim], [x_old_shape[0], shape_encoded[1], shape_encoded[2], shape_encoded[3]])
#decoded_images = autoencoder.call_decoder (decoded_images)
#plt.imshow(decoded_images[0, :, :, 0])


In [None]:
# --- Hyper-parameters -------------------------------------------------------
ell_max = 1                      # number of update layers of the model
ntotal = 50
double_mlp_on = True
triple_mlp_on = True
bias_on = True
inflation_factor = 2
test_activation = tf.keras.activations.relu
num_internal_layers = 3
batch_size = 2
NUM_IMAGES_TOTAL = 20            # frames per trajectory used as training target
NUM_IMAGES_EVAL = 13             # frames per trajectory used for evaluation
NUM_IMAGES_CONTEXT = 3           # leading frames given to the model as input
DIM_LATENT = 16 * 64
INNER_DIM_LATENT=50
DROPOUT_VALUE=0.3
lr_inim = 0.001
lr_autoencoder = 0.0004
num_epochs = 500
# One time stamp in [0, 1] per frame; evaluation uses the first NUM_IMAGES_EVAL.
t = np.linspace(0, 1, NUM_IMAGES_TOTAL).astype(np.float32)
t_eval = np.linspace(0, 1, NUM_IMAGES_TOTAL).astype(np.float32) [:NUM_IMAGES_EVAL]
print ('t: ', t)
print ('t_eval: ', t_eval)

# --- Distribution strategy --------------------------------------------------
# Use a TPU when one can be reached, otherwise run without a strategy.
try:
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
    tf.config.experimental_connect_to_cluster(resolver)
    # This is the TPU initialization code that has to be at the beginning.
    tf.tpu.experimental.initialize_tpu_system(resolver)
    print("All devices: ", tf.config.list_logical_devices('TPU'))
    strategy = tf.distribute.TPUStrategy(resolver)
except Exception as tpu_error:
    # `Exception` (not a bare except) so that KeyboardInterrupt/SystemExit
    # still propagate; the reason for the fallback is reported.
    print ('TPU not available, running without a distribution strategy: ', tpu_error)
    strategy = None

# --- Model --------------------------------------------------------------------
# The learning rate is halved every 30 passes over the training set.
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay (
    lr_inim, 30 * trajectories_train.shape[0] // batch_size, 0.5, staircase=True)
# NOTE(review): the model is built outside `strategy.scope()`; confirm this is
# intended when a TPU strategy is active.
autodiff_inimnet =  AutoDiffInImNet(dim=INNER_DIM_LATENT, represent_dim_out=DIM_LATENT,
                            num_resnet_layers=ell_max,
                            activation=test_activation,
                            num_internal_layers=num_internal_layers,
                            bias_on=bias_on,
                            mult=inflation_factor,
                            dropout=DROPOUT_VALUE,
                            lr_inim = lr_schedule,
                            lr_autoencoder = lr_autoencoder,
                            weight_regularisation_alpha = 0.0
                            )
losses = []

# --- Training data pipeline -----------------------------------------------------
dataset = tf.data.Dataset.from_tensor_slices((trajectories_train))
dataset = dataset.shuffle(512).batch(batch_size)
# The iterator is taken before the dataset is (optionally) distributed.
iter_dataset = iter(dataset)
if strategy is not None:
    dataset = strategy.experimental_distribute_dataset(dataset)

In [None]:
# Evaluation pipelines: batches in fixed order over the test and validation arrays.
dataset_test = tf.data.Dataset.from_tensor_slices((trajectories_test)).batch(batch_size)
iter_dataset_test = iter(dataset_test)

dataset_valid = tf.data.Dataset.from_tensor_slices((trajectories_val)).batch(batch_size)
iter_dataset_valid = iter(dataset_valid)

# With an active distribution strategy both pipelines are handed to it; the
# iterators above keep pointing at the plain, undistributed datasets.
if strategy is not None:
    dataset_test = strategy.experimental_distribute_dataset(dataset_test)
    dataset_valid = strategy.experimental_distribute_dataset(dataset_valid)

In [None]:
# Evaluate every UPDATE_FREQ epochs; SHOW_PLOTS toggles the plots drawn then.
UPDATE_FREQ, SHOW_PLOTS = 1, True

# Evaluation history, one (epoch, loss) pair of lists per data split.
loss_train_x, loss_train_y = [], []
loss_valid_x, loss_valid_y = [], []
loss_test_x, loss_test_y = [], []
# Mean training loss of every epoch.
losses_per_epoch = []

@tf.function
def distribute_train_step(data):
    """Run one optimisation step on ``data`` and return the loss in a list.

    The first NUM_IMAGES_CONTEXT frames of every trajectory are the model
    input and the first NUM_IMAGES_TOTAL frames are the target.  With an
    active strategy the step runs per replica and the results are gathered;
    otherwise it runs directly.
    """
    def train_on_replica (batch):
        targets = batch[:, :NUM_IMAGES_TOTAL, :, :]
        context = batch[:, :NUM_IMAGES_CONTEXT, :, :]
        print ('d_x.shape: ', context.shape)
        print ('d_im.shape: ', targets.shape)
        # Unreduced MSE; the model averages it itself.
        mse = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)
        return autodiff_inimnet.optimise (context, targets, t, mse)

    if strategy is None:
        return [train_on_replica (data)]

    per_replica_result = [strategy.run(train_on_replica, args=(data,))]
    print ('per_replica_result: ', per_replica_result)
    return strategy.gather(per_replica_result, axis=0)

@tf.function
def distribute_test_step (data):
    """Return ``(frames, predictions)`` for the first NUM_IMAGES_EVAL frames.

    ``predictions`` is what the model returns for the time stamps ``t_eval``
    when given the first NUM_IMAGES_CONTEXT frames.  With an active strategy
    the computation runs per replica and the results are gathered.
    """
    def predict_on_replica (batch):
        frames = batch[:, :NUM_IMAGES_EVAL, :, :]
        context = frames[:, :NUM_IMAGES_CONTEXT, :, :]
        print ('d_x.shape: ', context.shape)
        print ('d_im.shape: ', frames.shape)
        predictions = autodiff_inimnet (context, t_eval, training=False)
        return frames, predictions

    if strategy is None:
        return predict_on_replica(data)
    return strategy.gather(strategy.run(predict_on_replica, args=(data,)), axis=0)
              
@tf.function
def distribute_test_step_loss (data, async_exec=True):
    """Per-frame squared error of every model state, summed over the batch.

    The model sees the first NUM_IMAGES_CONTEXT frames and is compared with
    frames NUM_IMAGES_CONTEXT..NUM_IMAGES_EVAL.  Returns one tensor per model
    state, holding the error of every predicted frame summed over the batch.

    Args:
        data: Batch of trajectories.
        async_exec: Only relevant with an active strategy.  True runs the
            evaluation per replica and gathers the results; False gathers
            ``data`` first and evaluates it in one go.
    """
    def per_frame_mse (batch):
        context = batch[:, :NUM_IMAGES_CONTEXT, :, :]
        targets = batch[:, NUM_IMAGES_CONTEXT:NUM_IMAGES_EVAL, :, :]
        print ('d_x.shape: ', context.shape)
        print ('d_im.shape: ', targets.shape)
        states = autodiff_inimnet (context, t_eval[NUM_IMAGES_CONTEXT:], training=False)

        def frame_error (state):
            # Split the merged (batch * frames) axis again, then average the
            # squared error over the two image axes.
            unmerged = tf.reshape (state, [ tf.shape(targets)[0], targets.shape[1], state.shape[1], state.shape[2]])
            return tf.math.reduce_mean(tf.square(unmerged - targets), axis=[2, 3])

        return [frame_error (state) for state in states]

    if strategy is None:
        results = per_frame_mse (data)
    elif async_exec:
        per_replica_result = strategy.run(per_frame_mse, args=(data,))
        results = [strategy.gather(single_result, axis=0) for single_result in per_replica_result]
    else:
        results = per_frame_mse (strategy.gather(data, axis=0))

    return [tf.reduce_sum(r, axis=0) for r in results]

import time


def _summed_eval_loss(batches, retry_synchronously=False):
    """Sum ``distribute_test_step_loss`` over all ``batches``.

    Args:
        batches: Iterable of trajectory batches.
        retry_synchronously: When True, a batch whose result has a different
            rank than the running total is evaluated again with
            ``async_exec=False`` before it is added.

    Returns:
        ``(total, last_batch)``: the summed loss as a numpy array (``None``
        when ``batches`` is empty) and the most recent batch that did not
        need a retry.
    """
    total = None
    last_batch = None
    for batch in batches:
        value = np.array(distribute_test_step_loss (batch))
        if total is None:
            total = value
            # Also remember the first batch, so a one-batch dataset still
            # yields a sample for plotting.
            last_batch = batch
            continue
        if retry_synchronously and len(value.shape) != len(total.shape):
            value = np.array(distribute_test_step_loss (batch, False))
        else:
            last_batch = batch
        total += value
    return total, last_batch


def _show_frame_row(frames, n_frames):
    """Show ``frames[0] .. frames[n_frames - 1]`` side by side without axes."""
    fig, axes = plt.subplots(1, n_frames, squeeze=False)
    for j, axis in enumerate(axes[0]):
        axis.get_xaxis().set_visible(False)
        axis.get_yaxis().set_visible(False)
        axis.imshow(frames[j, :, :])
    plt.show()


def _show_sample(title, frames, predictions):
    """Plot the first trajectory of ``frames`` and every state's prediction of it."""
    print (title)
    print ('GT:')
    _show_frame_row(frames[0], frames.shape[1])

    print ('Prediction: ')
    for layer_id in range(len(predictions)):
        print ('Layer ' + (str(layer_id)))
        _show_frame_row(predictions[layer_id], frames.shape[1])


for epoch in range(num_epochs):
    start_time = time.time()

    # --- one pass over the training set --------------------------------------
    running_losses = []
    for x in dataset:
        curr_loss = distribute_train_step(x)
        losses.extend(curr_loss)
        running_losses.extend(curr_loss)
        x_last = x
    print ('training finished')
    losses_per_epoch.append(np.array(tf.reduce_mean(running_losses)))

    if epoch % UPDATE_FREQ != 0:
        continue
    print (epoch)

    # --- evaluation: mean per-frame loss on every split -----------------------
    curr_loss_train, _ = _summed_eval_loss(dataset)
    curr_loss_train /= trajectories_train.shape[0]

    curr_loss_test, x_last_test = _summed_eval_loss(dataset_test, retry_synchronously=True)
    curr_loss_test /= trajectories_test.shape[0]

    curr_loss_valid, _ = _summed_eval_loss(dataset_valid, retry_synchronously=True)
    curr_loss_valid /= trajectories_val.shape[0]

    loss_train_x.append(epoch)
    loss_train_y.append(curr_loss_train)
    loss_test_x.append(epoch)
    loss_test_y.append(curr_loss_test)
    loss_valid_x.append(epoch)
    loss_valid_y.append(curr_loss_valid)

    end_time = time.time()
    # Covers the training pass and the evaluation above.
    print('Time per epoch: ', end_time - start_time)
    if not SHOW_PLOTS:
        continue

    # Index 0 is the model's input state, so the layers start at 1.
    layer_ids = range (1, curr_loss_train.shape[0])
    train_history = np.array(loss_train_y)
    test_history = np.array(loss_test_y)
    valid_history = np.array(loss_valid_y)

    # --- loss over epochs, averaged over the predicted frames -----------------
    for layer_id in layer_ids:
        plt.plot(loss_train_x, np.mean(train_history[:, layer_id, :], axis=1), label='train_loss (l' + str(layer_id)+')')
        plt.plot(loss_test_x, np.mean(test_history[:, layer_id, :], axis=1), label='test_loss (l' + str(layer_id)+')')
        plt.plot(loss_valid_x, np.mean(valid_history[:, layer_id, :], axis=1), label='valid_loss(l' + str(layer_id)+')')
        plt.legend()
    plt.show()

    # --- latest loss per predicted frame --------------------------------------
    for layer_id in layer_ids:
        plt.plot(train_history[-1, layer_id, :], label='train_loss (l' + str(layer_id)+', '+str(np.mean(train_history[-1, layer_id, :]))+')')
        plt.plot(test_history[-1, layer_id, :], label='test_loss (l' + str(layer_id)+', '+str(np.mean(test_history[-1, layer_id, :]))+')')
        plt.plot(valid_history[-1, layer_id, :], label='valid_loss(l' + str(layer_id)+', '+str(np.mean(valid_history[-1, layer_id, :]))+')')
        plt.legend()
    plt.show()

    print ('Loss (frame_id), training:')
    for layer_id in layer_ids:
        print ('layer_id: ', layer_id)
        print  (train_history[-1, layer_id, :])
    print ('Loss (frame_id), validation:')
    for layer_id in layer_ids:
        print ('layer_id: ', layer_id)
        print  (valid_history[-1, layer_id, :])
    print ('Loss (frame_id), testing:')
    for layer_id in layer_ids:
        print ('layer_id: ', layer_id)
        print  (test_history[-1, layer_id, :])

    plt.plot (losses_per_epoch)
    plt.show()

    # --- qualitative check on the last training and test batch ----------------
    d_im, pred_train  =  distribute_test_step (x_last)
    _show_sample('Data sample (training)', d_im, pred_train)

    d_im_test, pred_test  =  distribute_test_step (x_last_test)
    _show_sample('Data sample (testing)', d_im_test, pred_test)


plt.plot(np.array(losses))

In [None]:
# Plot the training loss averaged over `interp` equally sized, consecutive bins.
cumsum_value = np.array(losses)
if cumsum_value.ndim == 1:
    # The training step records one scalar per step; treat it as one column so
    # the per-column plotting below works.
    cumsum_value = cumsum_value[:, np.newaxis]
interp = 10  # number of bins
# Drop the tail that does not fill a whole bin (np.split needs equal sections).
cumsum_value = cumsum_value[:cumsum_value.shape[0]-cumsum_value.shape[0]%interp]
if cumsum_value.shape[0] == 0:
    print ('Need at least ' + str(interp) + ' recorded losses to draw the binned loss curve.')
else:
    plt.figure(figsize=(10, 10))
    x_cumsum = np.array (range (cumsum_value.shape[0]))
    # Every bin is drawn at the index of its first step.
    x_downsampled = np.array([np.min (arr, axis=0) for arr in np.split (x_cumsum, interp, axis=0)])
    cumsum_value = np.array([np.mean (arr, axis=0) for arr in np.split (cumsum_value, interp, axis=0)])
    for i in range (cumsum_value.shape[1]):
        plt.plot(x_downsampled, cumsum_value[:, i], label = 'Layer' + str(i+1))

    plt.legend()
    plt.show()

In [None]:
def _show_frame_grid(frames, n_frames, n_rows=5):
    """Show the leading frames of ``frames`` row by row in an ``n_rows``-row grid.

    The grid has ``n_frames // n_rows`` columns, so frames beyond
    ``n_rows * (n_frames // n_rows)`` are not shown.
    """
    fig, axes = plt.subplots(n_rows, n_frames // n_rows, squeeze=False)
    n_cols = axes.shape[1]
    for i in range(n_rows):
        for j in range(n_cols):
            axes[i][j].get_xaxis().set_visible(False)
            axes[i][j].get_yaxis().set_visible(False)
            axes[i][j].imshow(frames[i*n_cols+j, :, :])
    plt.show()


# Take the next batch from the (undistributed) test iterator.
x_test = next(iter_dataset_test)

if strategy is not None:
    d_im, pred_layers  =  distribute_test_step (x_test)
else:
    d_im = x_test[:, :NUM_IMAGES_TOTAL, :, :]
    d_x = x_test[:, :NUM_IMAGES_CONTEXT, :, :]
    pred_layers = autodiff_inimnet (d_x, t)
# The model returns one state per layer (input first); both branches show the
# last one.  Its leading frames belong to the first trajectory of the batch.
pred_test = pred_layers[-1]

print ('pred_test.shape: ', pred_test.shape)
print ('animation started')
_show_frame_grid(d_im[0], d_im.shape[1])
print ('animation ended')

print ('animation started:l1')
_show_frame_grid(pred_test, d_im.shape[1])
print ('animation ended: l1')



In [None]:
# Per-step training losses of the most recent epoch (the training loop resets
# this list at the start of every epoch).
print (running_losses)