## Breaking Down Input Function: `input_fn`

This notebook walks through how the training data is loaded and transformed to build the `input_fn` callback.

First, import the needed libraries:

In [1]:
import sys
sys.path.append('../../')
import os
import tensorflow as tf
import functools
import json
from learning_to_simulate import reading_utils
from learning_to_simulate import train
tf.compat.v1.enable_eager_execution()

Define a function to read the metadata:

In [2]:
def _read_metadata(data_path):
  with open(os.path.join(data_path, 'metadata.json'), 'rt') as fp:
    return json.loads(fp.read())
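`_read_metadata` simply parses `metadata.json` from the dataset directory. As a self-contained check, the sketch below writes a small metadata file to a temporary directory and reads it back with the same helper. The keys and values (`dim`, `sequence_length`) are illustrative placeholders, not the actual contents of the WaterDropSample metadata.

```python
import json
import os
import tempfile


def _read_metadata(data_path):
    # Same logic as the notebook helper: load metadata.json from data_path.
    with open(os.path.join(data_path, 'metadata.json'), 'rt') as fp:
        return json.loads(fp.read())


# Hypothetical placeholder metadata; the real dataset file has its own keys/values.
example_metadata = {'dim': 2, 'sequence_length': 1000}

with tempfile.TemporaryDirectory() as tmp_dir:
    with open(os.path.join(tmp_dir, 'metadata.json'), 'wt') as fp:
        json.dump(example_metadata, fp)
    metadata = _read_metadata(tmp_dir)

print(metadata['dim'], metadata['sequence_length'])  # → 2 1000
```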

Load the TFRecord and the JSON file with the metadata.

After this cell, the dataset contains `(context, features)` tuples, where `positions` is the number of spatial dimensions (2 for WaterDropSample):
```
context['particle_type'] => tf: [n_particles]
features['position']     => tf: [steps, n_particles, positions]
```


In [3]:
info_dir = "/home/zoso/Documents/deepmind-research/information"
data_path = os.path.join(info_dir,'datasets/WaterDropSample/')

metadata = _read_metadata(data_path)


# Create a tf.data.Dataset from the TFRecord.
ds = tf.data.TFRecordDataset([os.path.join(data_path, 'train.tfrecord')])
ds = ds.map(functools.partial(reading_utils.parse_serialized_simulation_example, metadata=metadata))

for (context, features) in ds.take(2):
    print("particle type: ",context['particle_type'].shape)
    print("position: ", features['position'].shape)

particle type:  (678,)
position:  (1001, 678, 2)
particle type:  (355,)
position:  (1001, 355, 2)
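The printed shapes show that each trajectory keeps its own particle count (678 and 355), while the number of steps and spatial dimensions are fixed. A minimal NumPy sketch of how such a variable-size position tensor can be decoded from a flat byte string, assuming the positions are serialized as raw float32 bytes (check `reading_utils.parse_serialized_simulation_example` for the actual encoding). The toy sizes are hypothetical.

```python
import numpy as np

# Hypothetical toy trajectory: 4 steps, 3 particles, 2 spatial dimensions.
steps, n_particles, dim = 4, 3, 2
trajectory = np.arange(steps * n_particles * dim, dtype=np.float32).reshape(
    steps, n_particles, dim)

# Assumption: positions are stored as one flat float32 byte string.
raw_bytes = trajectory.tobytes()

# Decode and reshape to [steps, n_particles, dim]; the -1 lets the particle
# count differ between trajectories without storing it explicitly.
decoded = np.frombuffer(raw_bytes, dtype=np.float32).reshape(steps, -1, dim)
print(decoded.shape)  # → (4, 3, 2)
```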


### `mode: one_step`

Running the next cell splits each trajectory into overlapping windows, producing a dataset whose elements are dicts:
```
element['particle_type'] => tf: [n_particles]
element['position']      => tf: [7, n_particles, positions]
```

In [4]:
ds1 = ds

# So we can calculate the last 5 velocities.
INPUT_SEQUENCE_LENGTH = 6
batch_size = 2

# Splits an entire trajectory into chunks of 7 steps.
# Previous 5 velocities, current velocity and target.
# Each window is therefore a stack of 7 consecutive position steps.
split_with_window = functools.partial(
    reading_utils.split_trajectory,
    window_length=INPUT_SEQUENCE_LENGTH + 1)
ds1 = ds1.flat_map(split_with_window)

for elem in ds1.take(1):
    print("particle type: ", elem['particle_type'].shape)
    print("position: ", elem['position'].shape)
print("-------------------")

particle type:  (678,)
position:  (7, 678, 2)
-------------------
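Each 1001-step trajectory yields one window per valid start index, so `split_trajectory` with `window_length=7` should produce 1001 - 7 + 1 = 995 windows per trajectory. The NumPy sketch below illustrates that sliding-window split. `split_into_windows` is a hypothetical helper written for illustration, not the `reading_utils` implementation, and it uses 10 particles instead of 678 to keep memory small.

```python
import numpy as np


def split_into_windows(positions, window_length):
    # Hypothetical helper: every run of `window_length` consecutive steps
    # becomes one example, giving steps - window_length + 1 windows.
    num_windows = positions.shape[0] - window_length + 1
    return np.stack([positions[i:i + window_length] for i in range(num_windows)])


# 1001 steps and 2D positions as in the trajectories above; only 10 particles
# so the example stays small. The x coordinate stores the step index.
positions = np.zeros((1001, 10, 2), dtype=np.float32)
positions[:, :, 0] = np.arange(1001)[:, None]

windows = split_into_windows(positions, window_length=7)
print(windows.shape)  # → (995, 7, 10, 2)
```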


Running the next cell applies `train.prepare_inputs`, which yields `(features, labels)` tuples:
```
features['particle_type']           => tf: [n_particles]
features['position']                => tf: [n_particles, 6, positions]
features['n_particles_per_example'] => tf: [1] value: [n_particles]
labels                              => tf: [n_particles, positions]
```

In [5]:
ds1 = ds1.map(train.prepare_inputs)
for (features, labels) in ds1.take(1):
    print("particle type: ",features['particle_type'].shape)
    print("position: ", features['position'].shape)
    print("n_particles_per_example: ",features['n_particles_per_example'])
    print("labels: ",labels.shape) # the target position
    
print("-------------------")

particle type:  (678,)
position:  (678, 6, 2)
n_particles_per_example:  tf.Tensor([678], shape=(1,), dtype=int32)
labels:  (678, 2)
-------------------
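The shapes above suggest that `prepare_inputs` moves the particle axis first and splits each 7-step window into 6 input positions and 1 target position. A NumPy sketch under that assumption (`prepare_inputs_sketch` is hypothetical), which also shows why 6 input positions give the 5 velocities mentioned in the `INPUT_SEQUENCE_LENGTH` comment:

```python
import numpy as np


def prepare_inputs_sketch(position_window):
    # position_window: [7, n_particles, dim], one window from the split step.
    # Move particles to the leading axis: [n_particles, 7, dim].
    pos = np.transpose(position_window, (1, 0, 2))
    target_position = pos[:, -1]   # last step is the label: [n_particles, dim]
    input_positions = pos[:, :-1]  # first 6 steps are inputs: [n_particles, 6, dim]
    n_particles_per_example = np.array([pos.shape[0]], dtype=np.int32)
    return input_positions, n_particles_per_example, target_position


window = np.zeros((7, 678, 2), dtype=np.float32)
window[-1] = 1.0  # mark the last step so we can see it ends up in the label

inputs, n_per_example, labels = prepare_inputs_sketch(window)
print(inputs.shape, n_per_example, labels.shape)  # → (678, 6, 2) [678] (678, 2)

# Finite differences between 6 consecutive positions give 5 velocities.
velocities = np.diff(inputs, axis=1)
print(velocities.shape)  # → (678, 5, 2)
```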


Running the next cell applies `train.batch_concat`, which concatenates `batch_size` examples along the particle axis and yields `(features, labels)` tuples:
```
features['particle_type']           => tf: [batch_size*n_particles]
features['position']                => tf: [batch_size*n_particles, 6, positions]
features['n_particles_per_example'] => tf: [batch_size] value: n_particles of each example
labels                              => tf: [batch_size*n_particles, positions]
```

In [6]:
ds2 = train.batch_concat(ds1, batch_size)
for features, labels in ds2.take(1):
    print("particle type: ",features['particle_type'].shape)
    print("position: ", features['position'].shape)
    print("n_particles_per_example: ",features['n_particles_per_example'])
    print("labels: ",labels.shape) # the target position

particle type:  (1356,)
position:  (1356, 6, 2)
n_particles_per_example:  tf.Tensor([678 678], shape=(2,), dtype=int32)
labels:  (1356, 2)
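`batch_concat` does not add a batch axis: the printed shapes (1356 = 2 * 678) show the examples concatenated along the particle axis, with `n_particles_per_example` recording how many rows belong to each example. Both examples have 678 particles here because consecutive windows come from the same trajectory. A NumPy sketch of the same idea with two differently sized examples (hypothetical zero-filled arrays):

```python
import numpy as np

# Two hypothetical examples after prepare_inputs, with the particle counts
# of the two trajectories in this dataset.
examples = [np.zeros((678, 6, 2)), np.zeros((355, 6, 2))]
n_particles_per_example = np.array([e.shape[0] for e in examples], dtype=np.int32)

# Concatenate along the particle axis instead of stacking on a new batch axis,
# so examples with different particle counts can share one batch.
batched_positions = np.concatenate(examples, axis=0)
print(batched_positions.shape, n_particles_per_example)  # → (1033, 6, 2) [678 355]

# n_particles_per_example is enough to recover each example's slice.
boundaries = np.cumsum(n_particles_per_example)[:-1]
recovered = np.split(batched_positions, boundaries, axis=0)
print([r.shape for r in recovered])  # → [(678, 6, 2), (355, 6, 2)]
```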


### `mode: one_step_train`

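In training mode, the dataset is additionally repeated indefinitely and shuffled. These steps are applied to `ds1` (the output of `prepare_inputs`) before batching, which happens in the next cell.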

In [7]:
ds3 = ds1.repeat()
ds3 = ds3.shuffle(512)

for (features, labels) in ds3.take(1):
    print("particle type: ", features['particle_type'].shape)
    print("position: ", features['position'].shape)
    print("n_particles_per_example: ", features['n_particles_per_example'])
    print("labels: ", labels.shape)  # the target position

particle type:  (678,)
position:  (678, 6, 2)
n_particles_per_example:  tf.Tensor([678], shape=(1,), dtype=int32)
labels:  (678, 2)
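`tf.data.Dataset.shuffle` keeps a buffer of `buffer_size` elements, emits a random element from it, and refills that slot with the next input element. Because the windowing step emits all 995 windows of one trajectory before the next, a 512-element buffer only mixes windows locally at first. The pure-Python sketch below mimics that documented behavior on a stream of `(trajectory, window)` labels; it is not TF's implementation, and `buffered_shuffle` is a hypothetical name.

```python
import random


def buffered_shuffle(iterable, buffer_size, seed=0):
    # Fill a buffer, then for each new item emit a random buffered element
    # and put the new item in its slot; finally flush the buffer in random order.
    rng = random.Random(seed)
    buffer = []
    for item in iterable:
        if len(buffer) < buffer_size:
            buffer.append(item)
            continue
        idx = rng.randrange(buffer_size)
        yield buffer[idx]
        buffer[idx] = item
    rng.shuffle(buffer)
    yield from buffer


# Two trajectories with 995 windows each, in the order the windowing emits them.
stream = [(traj, w) for traj in (0, 1) for w in range(995)]
shuffled = list(buffered_shuffle(stream, buffer_size=512))
print(len(shuffled))  # → 1990
print(all(traj == 0 for traj, _ in shuffled[:484]))  # → True
```

The first 484 outputs can only come from trajectory 0: the first trajectory-1 window enters the buffer only after those elements have been emitted.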


Batching the shuffled dataset with `train.batch_concat` yields `(features, labels)` tuples:
```
features['particle_type']           => tf: [batch_size*n_particles]
features['position']                => tf: [batch_size*n_particles, 6, positions]
features['n_particles_per_example'] => tf: [batch_size] value: n_particles of each example
labels                              => tf: [batch_size*n_particles, positions]
```

In [8]:
ds3 = train.batch_concat(ds3, batch_size)
for features, labels in ds3.take(1):
    print("particle type: ",features['particle_type'].shape)
    print("position: ", features['position'].shape)
    print("n_particles_per_example: ",features['n_particles_per_example'])
    print("labels: ",labels.shape) # the target position

particle type:  (1356,)
position:  (1356, 6, 2)
n_particles_per_example:  tf.Tensor([678 678], shape=(2,), dtype=int32)
labels:  (1356, 2)
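After shuffling, a batch can eventually combine windows from trajectories with different particle counts (for example 678 + 355 rows rather than the 678 + 678 printed above), so any per-example computation has to rely on `n_particles_per_example`. One common way to turn it into a per-particle example index is `np.repeat`; this is a hypothetical illustration, not code from `train.py`.

```python
import numpy as np

# Hypothetical batch of 2 examples mixing the two trajectory sizes.
n_particles_per_example = np.array([678, 355], dtype=np.int32)

# One example index per particle row of the concatenated batch.
segment_ids = np.repeat(np.arange(len(n_particles_per_example)), n_particles_per_example)
print(segment_ids.shape)  # → (1033,)
print(segment_ids[677], segment_ids[678])  # → 0 1
```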


### `mode: rollout`

Running the next cell applies `train.prepare_rollout_inputs` to each full trajectory and yields `(features, labels)` tuples:
```
features['particle_type']           => tf: [n_particles]
features['position']                => tf: [n_particles, steps-1, positions]
features['key']                     => tf: [] (scalar) value: example id
features['n_particles_per_example'] => tf: [1] value: [n_particles]
features['is_trajectory']           => tf: [1] value: True or False
labels                              => tf: [n_particles, positions] (last position)
```

In [9]:
ds4 = ds.map(train.prepare_rollout_inputs)
for features, labels in ds4:
    print("particle_type: ", features['particle_type'].shape)
    print("position: ", features['position'].shape)
    print("key: ", features['key'])
    print("n_particles_per_example: ",features['n_particles_per_example'] )
    print("is_trajectory: ", features["is_trajectory"])
    print("labels: ", labels.shape)
    print("-------------")

particle_type:  (678,)
position:  (678, 1000, 2)
key:  tf.Tensor(0, shape=(), dtype=int64)
n_particles_per_example:  tf.Tensor([678], shape=(1,), dtype=int32)
is_trajectory:  tf.Tensor([ True], shape=(1,), dtype=bool)
labels:  (678, 2)
-------------
particle_type:  (355,)
position:  (355, 1000, 2)
key:  tf.Tensor(1, shape=(), dtype=int64)
n_particles_per_example:  tf.Tensor([355], shape=(1,), dtype=int32)
is_trajectory:  tf.Tensor([ True], shape=(1,), dtype=bool)
labels:  (355, 2)
-------------
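The rollout shapes show the whole trajectory being kept: `position` has 1000 steps (all but the last of the 1001 stored steps) and `labels` is the final position. A NumPy sketch of that transformation, assuming from the printed shapes that this is what `prepare_rollout_inputs` does. `prepare_rollout_inputs_sketch` is hypothetical and omits the context fields such as `key` and `particle_type`.

```python
import numpy as np


def prepare_rollout_inputs_sketch(position):
    # position: [steps, n_particles, dim] for a whole trajectory.
    pos = np.transpose(position, (1, 0, 2))  # [n_particles, steps, dim]
    target_position = pos[:, -1]             # final step: [n_particles, dim]
    inputs = pos[:, :-1]                     # earlier steps: [n_particles, steps - 1, dim]
    n_particles_per_example = np.array([pos.shape[0]], dtype=np.int32)
    is_trajectory = np.array([True])
    return inputs, n_particles_per_example, is_trajectory, target_position


trajectory = np.zeros((1001, 355, 2), dtype=np.float32)
inputs, n_per_example, is_traj, labels = prepare_rollout_inputs_sketch(trajectory)
print(inputs.shape, labels.shape)  # → (355, 1000, 2) (355, 2)
```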


### `main function`
Here we test `train.get_input_fn`, which returns the `input_fn` callable. Pass the corresponding `mode` and `split`.

In [19]:
info_dir = "/home/zoso/Documents/deepmind-research/information"
data_path = os.path.join(info_dir,'datasets/WaterDropSample/')
# To test rollout mode instead, set batch_size = 1 and mode = 'rollout'.
batch_size = 2
mode = 'one_step_train'

input_fn = train.get_input_fn(data_path, batch_size,
                              mode=mode, split='train')

dataset = input_fn()

if 'one_step' in mode:

    for (features, labels) in dataset.take(1):
        print("particle type: ",features['particle_type'].shape)
        print("position: ", features['position'].shape)
        print("n_particles_per_example: ",features['n_particles_per_example'])
        print("labels: ",labels.shape) # the target position
elif mode == 'rollout' and batch_size == 1:
    for features, labels in dataset.take(1):
        print("particle_type: ", features['particle_type'].shape)
        print("position: ", features['position'].shape)
        print("key: ", features['key'])
        print("n_particles_per_example: ",features['n_particles_per_example'] )
        print("is_trajectory: ", features["is_trajectory"])
        print("labels: ", labels.shape)
        print("-------------")
    

particle type:  (1356,)
position:  (1356, 6, 2)
n_particles_per_example:  tf.Tensor([678 678], shape=(2,), dtype=int32)
labels:  (1356, 2)
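The shapes printed by `get_input_fn` for `one_step_train` match the manual pipeline built in the cells above (`ds3`). The hypothetical helper below only summarizes, as strings, the steps this notebook applied by hand for each mode. It does not call `train.get_input_fn`, and the restriction of rollout to `batch_size = 1` reflects how this notebook runs it rather than a confirmed API contract.

```python
def describe_pipeline(mode, batch_size):
    # Hypothetical summary of the transformations applied by hand in this notebook.
    steps = ['TFRecordDataset', 'parse_serialized_simulation_example']
    if mode in ('one_step', 'one_step_train'):
        steps += ['split_trajectory(window_length=7)', 'prepare_inputs']
        if mode == 'one_step_train':
            # Training repeats and shuffles before batching.
            steps += ['repeat', 'shuffle(512)']
        steps.append('batch_concat(batch_size=%d)' % batch_size)
    elif mode == 'rollout':
        if batch_size != 1:
            raise ValueError('this notebook only runs rollout with batch_size = 1')
        steps.append('prepare_rollout_inputs')
    else:
        raise ValueError('unknown mode: %r' % mode)
    return steps


for step in describe_pipeline('one_step_train', 2):
    print(step)
```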
