In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the path of every file under the read-only Kaggle input directory.
# NOTE: this loop runs at notebook top level, so `dirname`, `_`, `filenames` (and
# `filename`, if any file exists) remain as globals afterwards; assigning `_` also
# shadows IPython's "last output" variable.
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import tensorflow as tf
print(f"{tf.__version__}")  # record which TensorFlow version this notebook ran with

## nn/utils/math.py

In [None]:
import numpy as np

def sigmoid(x):
    """Logistic sigmoid ``1 / (1 + exp(-x))``, evaluated element-wise.

    Computed as ``exp(-log(1 + exp(-x)))`` via ``np.logaddexp`` so that large
    negative inputs do not overflow ``exp(-x)`` (the naive form emits an
    overflow RuntimeWarning and raises under ``np.errstate(over='raise')``).

    Args:
        x: scalar or array-like of real numbers.

    Returns:
        ndarray / numpy scalar of the same shape with values in [0, 1].
    """
    return np.exp(-np.logaddexp(0, -x))

## nn/utils/viz.py

In [None]:
import os
import numpy as np

def gallery(array, ncols=3):
    """Tile a stack of images into one grid image, framing each with a grey border.

    Parameters
    ----------
    array : ndarray, shape (nindex, height, width, intensity)
        Images to tile; slot ``i`` goes to row ``i // ncols``, column ``i % ncols``.
    ncols : int
        Number of grid columns. ``nindex`` must be a multiple of it (asserted).

    Returns
    -------
    ndarray, shape ((height + 2) * nrows, (width + 2) * ncols, intensity)
        Float grid where every image is surrounded by a 1-pixel border of 0.5.
    """
    count, rows_px, cols_px, depth = array.shape

    # One-pixel grey frame around every image.
    framed = np.full((count, rows_px + 2, cols_px + 2, depth), 0.5)
    framed[:, 1:-1, 1:-1, :] = array

    count, rows_px, cols_px, depth = framed.shape
    nrows = count // ncols
    assert count == nrows * ncols

    # (nrows, ncols, h, w, c) -> (nrows, h, ncols, w, c) -> (nrows*h, ncols*w, c)
    tiled = framed.reshape(nrows, ncols, rows_px, cols_px, depth).swapaxes(1, 2)
    return tiled.reshape(rows_px * nrows, cols_px * ncols, depth)

def gif(filename, array, fps=10, scale=1.0):
    """Creates a gif given a stack of images using moviepy

    Notes
    -----
    works with current Github version of moviepy (not the pip version)
    https://github.com/Zulko/moviepy/commit/d4c9c37bc88261d8ed8b5d9b7c317d13b2cdf62e
    moviepy is imported when the function is called, so merely defining this
    function does not require moviepy to be installed.

    Usage
    -----
    >>> X = randn(100, 64, 64)
    >>> gif('test.gif', X)

    Parameters
    ----------
    filename : string
        The filename of the gif to write to; any extension is replaced by ``.gif``.
    array : array_like
        A numpy array that contains a sequence of images. A 3-D array
        (frames, height, width) is treated as greyscale and copied into three
        colour channels.
    fps : int
        frames per second (default: 10)
    scale : float
        how much to rescale each image by (default: 1.0)

    Returns
    -------
    The resized moviepy ``ImageSequenceClip`` that was written.
    """
    # The import must come after the docstring; placed first, it turns the
    # docstring into a discarded string expression and ``gif.__doc__`` is None.
    from moviepy.editor import ImageSequenceClip

    # ensure that the file has the .gif extension
    fname, _ = os.path.splitext(filename)
    filename = fname + '.gif'

    # copy into the color dimension if the images are black and white
    if array.ndim == 3:
        array = array[..., np.newaxis] * np.ones(3)

    # make the moviepy clip
    clip = ImageSequenceClip(list(array), fps=fps).resize(scale)
    clip.write_gif(filename, fps=fps)
    return clip

## nn/utils/misc.py

In [None]:
import os
import time
import inspect
import numpy as np
import zipfile

def log_metrics(logger, prefix, metrics):
    """Log ``prefix`` followed by ``key=value`` pairs sorted by key.

    Args:
        logger: object with an ``info(str)`` method (e.g. a ``logging.Logger``).
        prefix: text placed before the metrics.
        metrics: mapping of metric name -> value; values are rendered with ``str``.
    """
    # Format each pair as a 2-tuple: ``"=%s" % value`` breaks (TypeError) or
    # silently unwraps when ``value`` itself is a tuple.
    metrics_string = " ".join("%s=%s" % (key, metrics[key]) for key in sorted(metrics))
    logger.info(prefix + " " + metrics_string)

def classes_in_module(module):
    """Map name -> class for every class *defined* in ``module``.

    Classes that ``module`` merely imports from elsewhere are excluded by
    comparing each class's ``__module__`` with the module's own name.
    Entries follow ``inspect.getmembers`` order (sorted by name).
    """
    return {
        name: member
        for name, member in inspect.getmembers(module, inspect.isclass)
        if member.__module__ == module.__name__
    }

def rgb2gray(rgb):
    """Convert an ``(..., C>=3)`` image to greyscale with ITU-R 601 luma weights.

    Only the first three channels are used, so an alpha channel is ignored.
    The result drops the channel axis.
    """
    luma_weights = [0.299, 0.587, 0.114]
    return np.dot(rgb[..., :3], luma_weights)

def zipdir(path, save_dir):
    """Archive every ``.py`` file under ``path`` into ``<save_dir>/code.zip``.

    Archive names are relative to the parent of ``path``, so they start with the
    basename of ``path`` (e.g. ``project/pkg/mod.py``). An existing ``code.zip``
    is overwritten.

    Args:
        path: directory tree to scan recursively.
        save_dir: existing directory that receives ``code.zip``.
    """
    archive_path = os.path.join(save_dir, 'code.zip')
    parent = os.path.join(path, '..')

    # Context manager: the archive is finalised and closed even if a write fails.
    with zipfile.ZipFile(archive_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
        for root, _dirs, files in os.walk(path):
            for file in files:
                # splitext only matches a real ".py" suffix; split(".")[-1] also
                # accepted a file literally named "py".
                if os.path.splitext(file)[1] == ".py":
                    full_path = os.path.join(root, file)
                    zipf.write(full_path, os.path.relpath(full_path, parent))

## nn/network/base.py

In [None]:
import os
import sys
import shutil
import logging
import numpy as np
import tensorflow as tf

logger = logging.getLogger("tf")
root_path = "/kaggle/working"

# Optimizer factories keyed by name. Every entry accepts the learning rate either
# positionally or as ``learning_rate=``, plus any extra optimizer kwargs.
OPTIMIZERS = {
    "adam": tf.keras.optimizers.Adam,
    "rmsprop": tf.keras.optimizers.RMSprop,
    # Same call convention as the class entries; momentum is fixed at 0.9 and the
    # default learning rate matches tf.keras.optimizers.SGD's (0.01).
    "momentum": lambda learning_rate=0.01, **kwargs: tf.keras.optimizers.SGD(
        learning_rate=learning_rate, momentum=0.9, **kwargs),
    "sgd": tf.keras.optimizers.SGD,
}
# Before the Keras port these were tf.compat.v1.train.{AdamOptimizer,
# RMSPropOptimizer, MomentumOptimizer(lr, 0.9), GradientDescentOptimizer}.

import os
import sys
import shutil
import logging
import numpy as np
import tensorflow as tf
from tensorflow import keras  # Import Keras

# NOTE(review): repeats the two definitions made earlier in this cell. This is harmless:
# logging.getLogger returns the same object for the same name.
logger, root_path = logging.getLogger("tf"), "/kaggle/working"


class BaseNet(keras.Model):
    """Shared scaffolding for the subclassed Keras models in this notebook.

    Subclasses provide ``call`` (forward pass) and ``compute_loss``. This class
    adds save-directory handling, optimizer construction by name and generic
    ``train_step`` / ``test_step`` implementations.
    """

    def __init__(self, **kwargs):
        # Forward standard keras.Model kwargs (e.g. ``name``); with no kwargs this
        # behaves exactly like a zero-argument constructor.
        super(BaseNet, self).__init__(**kwargs)

        self.train_metrics = {}
        self.eval_metrics = {}

        # Extra hooks run around training / validation / testing; subclasses
        # append entries to these lists.
        self.extra_train_fns = []
        self.extra_valid_fns = []
        self.extra_test_fns = []

    def call(self, inputs):
        """Defines the forward pass, replacing 'feedforward'."""
        raise NotImplementedError

    def compute_loss(self, inputs, targets):
        """Calculates the loss for a batch; must be implemented by subclasses."""
        raise NotImplementedError

    def initialize_graph(self, save_dir, use_ckpt, ckpt_dir=""):
        """Prepare ``save_dir`` and optionally restore weights.

        If ``save_dir`` exists and ``use_ckpt`` is false, it is deleted and
        recreated. When ``use_ckpt`` is true, weights are loaded from
        ``<ckpt_dir>/model.ckpt`` (or ``<save_dir>/model.ckpt`` when ``ckpt_dir``
        is empty).
        """
        self.save_dir = save_dir

        if os.path.exists(save_dir) and not use_ckpt:
            logger.info("Folder exists, deleting...")
            shutil.rmtree(save_dir)

        os.makedirs(save_dir, exist_ok=True)

        if use_ckpt and ckpt_dir:
            self.load_weights(os.path.join(ckpt_dir, "model.ckpt"))
        elif use_ckpt:
            self.load_weights(os.path.join(save_dir, "model.ckpt"))

    def save_model(self, epoch=None):
        """Save the model as ``<save_dir>/model`` (suffixed ``_epoch_<epoch>`` if given)."""
        file_name = "model"
        if epoch is not None:
            file_name += f"_epoch_{epoch}"
        self.save(os.path.join(self.save_dir, file_name))

    def build_optimizer(self, base_lr, optimizer="adam", anneal_lr=True):
        """Create ``self.optimizer`` by name: "adam", "rmsprop", "momentum" or "sgd".

        ``anneal_lr`` is only recorded; nothing in this class reads it.

        Raises:
            ValueError: if ``optimizer`` is not one of the supported names.
        """
        self.base_lr = base_lr
        self.anneal_lr = anneal_lr
        if optimizer == 'adam':
            self.optimizer = tf.keras.optimizers.Adam(learning_rate=base_lr)
        elif optimizer == 'rmsprop':
            self.optimizer = tf.keras.optimizers.RMSprop(learning_rate=base_lr)
        elif optimizer == 'momentum':
            # A real optimizer instance (previously a stray trailing comma stored a
            # one-element tuple holding a factory lambda).
            self.optimizer = tf.keras.optimizers.SGD(learning_rate=base_lr, momentum=0.9)
        elif optimizer == 'sgd':
            # Instantiate it; storing the SGD class itself cannot apply gradients.
            self.optimizer = tf.keras.optimizers.SGD(learning_rate=base_lr)
        else:
            raise ValueError("Unknown optimizer %r; expected 'adam', 'rmsprop', "
                             "'momentum' or 'sgd'" % (optimizer,))

    def train_step(self, data):
        """One optimisation step on ``(x, y)``; returns the current metric values."""
        x, y = data
        with tf.GradientTape() as tape:
            predictions = self(x, training=True)
            loss = self.compute_loss(x, y)

        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        # Update the compiled metrics exactly once per step (previously once per
        # ``train_metrics`` entry, i.e. never when that dict was empty).
        self.compiled_metrics.update_state(y, predictions)
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        """Evaluate one ``(x, y)`` batch without touching the weights."""
        x, y = data
        # Inference mode and no gradient update: evaluation must not train the model.
        predictions = self(x, training=False)
        # Called for subclasses that record their losses on ``self``.
        self.compute_loss(x, y)

        self.compiled_metrics.update_state(y, predictions)
        return {m.name: m.result() for m in self.metrics}


## nn/network/blocks.py

In [None]:
import numpy as np
import tensorflow as tf

""" Useful subnetwork components """

import tensorflow as tf
from tensorflow import keras  


def unet(inp, base_channels, out_channels, upsamp=True):
    """Build a 4-level U-Net on a Keras tensor and return its output tensor.

    Encoder stages use ``base``, ``2*base`` and ``4*base`` filters, with a
    ``8*base`` bottleneck after three 2x max-poolings. The decoder mirrors this
    with skip connections and ``4*base``, ``2*base`` and ``base`` filters.
    Because the upsampled maps are concatenated with their skips, the spatial
    dims of ``inp`` must be divisible by 8.

    Args:
        inp: Keras tensor of shape ``[batch, H, W, C]``.
        base_channels: filters of the first stage.
        out_channels: filters of the final 1x1 convolution (no activation).
        upsamp: use ``UpSampling2D`` if True, else a stride-2 ``Conv2DTranspose``.

    Returns:
        Keras tensor of shape ``[batch, H, W, out_channels]``.
    """
    def _conv_pair(x, filters):
        # Two 3x3 ReLU convolutions at constant resolution.
        x = keras.layers.Conv2D(filters, 3, activation='relu', padding='same')(x)
        return keras.layers.Conv2D(filters, 3, activation='relu', padding='same')(x)

    # Encoder: keep each stage's output for the matching skip connection.
    h = _conv_pair(inp, base_channels)
    skips = [h]
    for mult in (2, 4):
        h = keras.layers.MaxPooling2D(2)(h)  # pool the previous stage, not h1
        h = _conv_pair(h, base_channels * mult)
        skips.append(h)

    # Bottleneck at 1/8 resolution.
    h = keras.layers.MaxPooling2D(2)(h)
    h = _conv_pair(h, base_channels * 8)

    # Decoder: upsample the running tensor and join the same-resolution skip.
    for skip, mult in zip(reversed(skips), (4, 2, 1)):
        if upsamp:
            h = keras.layers.UpSampling2D()(h)
        else:
            h = keras.layers.Conv2DTranspose(base_channels * mult, 3, strides=2,
                                             activation='relu', padding='same')(h)
        h = keras.layers.concatenate([h, skip])
        h = _conv_pair(h, base_channels * mult)

    return keras.layers.Conv2D(out_channels, 1, activation=None, padding='same')(h)


def shallow_unet(inp, base_channels, out_channels, upsamp=True):
    """Build a 3-level U-Net on a Keras tensor and return its output tensor.

    Encoder stages use ``base`` and ``2*base`` filters, with a ``4*base``
    bottleneck after two 2x max-poolings. The decoder upsamples back through
    ``2*base`` and ``base`` filters, concatenating the matching encoder output
    at each resolution. The spatial dims of ``inp`` must be divisible by 4.

    Args:
        inp: Keras tensor of shape ``[batch, H, W, C]``.
        base_channels: filters of the first stage.
        out_channels: filters of the final 1x1 convolution (no activation).
        upsamp: use ``UpSampling2D`` if True, else a stride-2 ``Conv2DTranspose``.

    Returns:
        Keras tensor of shape ``[batch, H, W, out_channels]``.
    """
    def _conv_pair(x, filters):
        # Two 3x3 ReLU convolutions at constant resolution.
        x = keras.layers.Conv2D(filters, 3, activation='relu', padding='same')(x)
        return keras.layers.Conv2D(filters, 3, activation='relu', padding='same')(x)

    def _upsample(x, filters):
        # Double the spatial resolution.
        if upsamp:
            return keras.layers.UpSampling2D()(x)
        return keras.layers.Conv2DTranspose(filters, 3, strides=2,
                                            activation='relu', padding='same')(x)

    # Encoder: full resolution, then 1/2 resolution.
    h1 = _conv_pair(inp, base_channels)
    h = keras.layers.MaxPooling2D(2)(h1)
    h2 = _conv_pair(h, base_channels * 2)

    # Bottleneck at 1/4 resolution (needed so the first upsample lands on h2's size).
    h = keras.layers.MaxPooling2D(2)(h2)
    h = _conv_pair(h, base_channels * 4)

    # Decoder: 1/4 -> 1/2 (skip h2), then 1/2 -> full (skip h1).
    h = keras.layers.concatenate([_upsample(h, base_channels * 2), h2])
    h = _conv_pair(h, base_channels * 2)

    h = keras.layers.concatenate([_upsample(h, base_channels), h1])
    h = _conv_pair(h, base_channels)

    return keras.layers.Conv2D(out_channels, 1, activation=None, padding='same')(h)

    
def variable_from_network(shape):
    """Return a tensor of ``shape`` produced by a tiny MLP on a constant input.

    The MLP is Dense(200, tanh) -> Dense(prod(shape)) applied to ``ones([1, 10])``.

    NOTE(review): both Dense layers are created inside the function, so every
    call builds fresh, untracked layers. Under eager TF2 the result is not a
    persistent trainable parameter; owning the layers in a Layer/Model would
    make it one.
    """
    constant_input = tf.ones([1, 10])
    hidden_layer = keras.layers.Dense(200, activation='tanh')
    output_layer = keras.layers.Dense(np.prod(shape), activation=None)
    flat = output_layer(hidden_layer(constant_input))
    return tf.reshape(flat, shape)

## nn/network/cells.py

In [None]:
from tensorflow import keras

class ODECell(keras.layers.Layer):
    """Base class for the hand-written physics cells in this notebook.

    A cell carries a ``(positions, velocities)`` state pair, each ``units`` wide.
    The subclasses here implement ``call(inputs, states)`` and return the
    updated pair.
    """

    def __init__(self, units, **kwargs):
        # Refer to this class itself; the previous ``super(CustomODECell, ...)``
        # named an undefined class, so every subclass raised NameError.
        super(ODECell, self).__init__(**kwargs)
        self._units = units

    @property
    def state_size(self):
        """Widths of the (positions, velocities) state tensors."""
        return self._units, self._units

    def zero_state(self, batch_size, dtype):
        """Zero positions and velocities, each of shape ``[batch_size, units]``."""
        positions = tf.zeros([batch_size, self._units], dtype=dtype)
        velocities = tf.zeros([batch_size, self._units], dtype=dtype)
        return positions, velocities
    
class BouncingODECell(ODECell):
    """Free motion inside a hard-coded 32-pixel box with reflecting walls.

    Each call advances the state by ``dt`` using 5 explicit-Euler sub-steps.
    After every sub-step, coordinates within 2 pixels of a wall are mirrored
    back inside and the matching velocity components are negated. Every
    coordinate is treated the same way, which assumes a square box.
    """

    def __init__(self, units, **kwargs):
        super(BouncingODECell, self).__init__(units, **kwargs)
        # Step size as a non-trainable weight (initialised to 1) so it is saved with the layer.
        self.dt = self.add_weight(name="dt_x", shape=[], initializer='ones', trainable=False)

    def call(self, inputs, states):
        """Return the updated ``(positions, velocities)``; ``inputs`` is unused."""
        pos, vel = states
        box, radius, substeps = 32, 2, 5

        for _ in range(substeps):
            pos += self.dt / substeps * vel

            # Two reflection passes; presumably the second catches a ball that the
            # first mirror pushed past the opposite wall.
            for _pass in range(2):
                beyond_far_wall = pos + radius > box
                vel = tf.where(beyond_far_wall, -vel, vel)
                vel = tf.where(pos - radius < 0, -vel, vel)
                pos = tf.where(beyond_far_wall, box - (pos + radius - box) - radius, pos)
                # Re-evaluated on the already-mirrored positions.
                pos = tf.where(pos - radius < 0, -(pos - radius) + radius, pos)

        return pos, vel

class SpringODECell(ODECell):
    """Two objects joined by a linear spring, integrated with 5 Euler sub-steps.

    Positions and velocities are assumed to be ``[batch, 4]``: object 1 in
    columns 0:2 and object 2 in columns 2:4 (TODO confirm against callers).
    Stiffness is ``exp(log_k)`` and rest length ``2 * exp(log_l)``, both trainable.
    """

    def __init__(self, units, **kwargs):
        super(SpringODECell, self).__init__(units, **kwargs)
        self.dt = self.add_weight(name="dt_x", shape=[], initializer='ones', trainable=False)
        # NOTE(review): no initializer is given, so Keras' default initializer is
        # used for these log-space parameters.
        self.k = self.add_weight(name="log_k", shape=[], trainable=True)
        self.equil = self.add_weight(name="log_l", shape=[], trainable=True)

    def call(self, inputs, states):
        """Return the updated ``(positions, velocities)``; ``inputs`` is unused."""
        positions, velocities = states
        step = self.dt / 5

        for _ in range(5):
            delta = positions[:, :2] - positions[:, 2:]
            norm = tf.sqrt(tf.reduce_sum(tf.square(delta), axis=-1, keepdims=True))
            direction = delta / (norm + 1e-4)
            F = tf.exp(self.k) * (norm - 2 * tf.exp(self.equil)) * direction

            # Tensors are immutable: rebuild the velocity tensor instead of
            # slice-assigning into it (which raised TypeError).
            velocities = tf.concat([velocities[:, :2] - step * F,
                                    velocities[:, 2:] + step * F], axis=1)

            positions = positions + step * velocities

        return positions, velocities

class GravityODECell(ODECell):
    """Three bodies under mutual inverse-square attraction, 5 Euler sub-steps.

    Positions and velocities are assumed to be ``[batch, 6]`` (three 2-D
    bodies; TODO confirm against callers). The strength
    ``A = exp(log_g) * exp(2 * log_m)`` is recomputed from the current weights
    on every access, so training ``log_g`` actually changes the dynamics.
    """

    def __init__(self, units, **kwargs):
        super(GravityODECell, self).__init__(units, **kwargs)
        self.dt = self.add_weight(name="dt_x", shape=[], initializer='ones', trainable=False)
        self.g = self.add_weight(name="log_g", shape=[], trainable=True)
        self.m = self.add_weight(name="log_m", shape=[], trainable=False)

    @property
    def A(self):
        """Gravitational strength from the current ``log_g`` / ``log_m`` values.

        It was previously a constant frozen at construction time, which
        disconnected ``log_g`` from the loss.
        """
        return tf.exp(self.g) * tf.exp(2 * self.m)

    def call(self, inputs, states):
        """Return the updated ``(positions, velocities)``; ``inputs`` is unused."""
        positions, velocities = states
        strength = self.A
        step = self.dt / 5

        for _ in range(5):
            # Pairwise displacement vectors 0->1, 1->2, 2->0.
            rel_vecs = [positions[:, i:i + 2] - positions[:, j:j + 2]
                        for i, j in [(0, 2), (2, 4), (4, 0)]]

            # Floor the squared distance at 1 before the sqrt. The norm is clipped
            # to >= 1 below anyway, so the forward pass is unchanged, but
            # d sqrt(s)/ds at s=0 would otherwise put NaNs into the gradients.
            norms = [tf.sqrt(tf.maximum(tf.reduce_sum(tf.square(vec), axis=-1, keepdims=True), 1.0))
                     for vec in rel_vecs]

            # vec / |vec|^3 with the distance clipped for numerical stability.
            forces = [vec / tf.pow(tf.clip_by_value(norm, 1, 170), 3)
                      for vec, norm in zip(rel_vecs, norms)]

            # Net force on each body.
            forces = [forces[0] - forces[2], forces[1] - forces[0], forces[2] - forces[1]]
            forces = [-strength * f for f in forces]

            velocities += step * tf.concat(forces, axis=1)
            positions += step * velocities

        return positions, velocities

## nn/network/stn.py

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

def _interpolate(im, x, y, out_size):
    """Bilinearly sample ``im`` at normalised coordinates ``(x, y)``.

    Args:
        im: image batch ``[batch, height, width, channels]``.
        x, y: sampling coordinates in the normalised range [-1, 1], one per
            output pixel of every batch element, in batch-major, row-major
            order. Any shape that flattens to ``batch * out_h * out_w`` works.
            Neighbour indices are clipped to the image. Where both clipped
            neighbours coincide (at or beyond the right/bottom border, or far
            outside), the bilinear weights cancel and the sample is 0.
        out_size: ``(out_h, out_w)`` of the sampled output.

    Returns:
        float32 tensor ``[batch, out_h, out_w, channels]``.
    """
    im = tf.cast(im, 'float32')
    shape = tf.shape(im)
    # Dynamic batch/spatial sizes so an unknown batch dimension works.
    num_batch, height, width = shape[0], shape[1], shape[2]
    # Keep the channel count static when it is known so downstream layers see it.
    channels = im.shape[-1] if im.shape[-1] is not None else shape[3]
    out_height, out_width = out_size

    x = tf.reshape(tf.cast(x, 'float32'), [-1])
    y = tf.reshape(tf.cast(y, 'float32'), [-1])

    # Scale [-1, 1] onto pixel coordinates [0, width] / [0, height]; adding 1
    # alone would leave every sample within the first 3 pixels.
    x = (x + 1.0) * tf.cast(width, 'float32') / 2.0
    y = (y + 1.0) * tf.cast(height, 'float32') / 2.0

    max_x = width - 1
    max_y = height - 1

    x0 = tf.cast(tf.floor(x), 'int32')
    x1 = x0 + 1
    y0 = tf.cast(tf.floor(y), 'int32')
    y1 = y0 + 1

    x0 = tf.clip_by_value(x0, 0, max_x)
    x1 = tf.clip_by_value(x1, 0, max_x)
    y0 = tf.clip_by_value(y0, 0, max_y)
    y1 = tf.clip_by_value(y1, 0, max_y)

    # Flat row index of pixel (b, row, col) in the [batch*height*width, channels] view.
    base = tf.repeat(tf.range(num_batch) * height * width, out_height * out_width)
    base_y0 = base + y0 * width
    base_y1 = base + y1 * width
    idx_a = base_y0 + x0  # top-left
    idx_b = base_y1 + x0  # bottom-left
    idx_c = base_y0 + x1  # top-right
    idx_d = base_y1 + x1  # bottom-right

    im_flat = tf.reshape(im, [-1, channels])
    Ia = tf.gather(im_flat, idx_a)
    Ib = tf.gather(im_flat, idx_b)
    Ic = tf.gather(im_flat, idx_c)
    Id = tf.gather(im_flat, idx_d)

    # Weights from the clipped integer corners, as float32 (TF does not mix int32
    # and float32 implicitly), expanded to [N, 1] to broadcast over channels.
    x0_f = tf.cast(x0, 'float32')
    x1_f = tf.cast(x1, 'float32')
    y0_f = tf.cast(y0, 'float32')
    y1_f = tf.cast(y1, 'float32')
    wa = tf.expand_dims((x1_f - x) * (y1_f - y), 1)
    wb = tf.expand_dims((x1_f - x) * (y - y0_f), 1)
    wc = tf.expand_dims((x - x0_f) * (y1_f - y), 1)
    wd = tf.expand_dims((x - x0_f) * (y - y0_f), 1)

    output = tf.add_n([wa * Ia, wb * Ib, wc * Ic, wd * Id])
    return tf.reshape(output, [num_batch, out_height, out_width, channels])


def _meshgrid(height, width):
    """Homogeneous sampling grid over [-1, 1] x [-1, 1].

    Returns a float32 tensor ``[height, width, 3]`` whose last axis holds
    ``(x, y, 1)``, with ``x`` varying along the width and ``y`` along the height.
    """
    xs = tf.linspace(-1.0, 1.0, width)
    ys = tf.linspace(-1.0, 1.0, height)
    grid_x, grid_y = tf.meshgrid(xs, ys)
    ones = tf.ones_like(grid_x)
    return tf.stack([grid_x, grid_y, ones], axis=-1)


def _transform(theta, input_dim, out_size):
    """Warp an image batch with per-example 2x3 affine matrices.

    Args:
        theta: affine parameters, ``[batch, 6]`` (row-major 2x3) or
            ``[batch, 2, 3]``. They map output-grid coordinates ``(x, y, 1)``
            to normalised input coordinates.
        input_dim: image batch ``[batch, H, W, C]`` (the name is historical).
        out_size: ``(out_h, out_w)``.

    Returns:
        ``[batch, out_h, out_w, C]`` sampled with ``_interpolate``.
    """
    num_batch = tf.shape(input_dim)[0]
    out_height, out_width = out_size

    # Accept flat [batch, 6] parameters; matmul needs [batch, 2, 3].
    theta = tf.reshape(tf.cast(theta, 'float32'), [-1, 2, 3])

    grid = tf.reshape(_meshgrid(out_height, out_width), [1, -1, 3])  # [1, HW, 3]
    grid = tf.tile(grid, [num_batch, 1, 1])                           # [B, HW, 3]

    transformed = tf.matmul(theta, grid, transpose_b=True)            # [B, 2, HW]
    # Row 0 holds source x, row 1 source y (the old slicing cut the last axis).
    x_s = tf.reshape(transformed[:, 0, :], [-1])
    y_s = tf.reshape(transformed[:, 1, :], [-1])
    return _interpolate(input_dim, x_s, y_s, out_size)


class SpatialTransformer(keras.layers.Layer):
    """Keras layer wrapping ``_transform``: warps images with per-example affine params.

    Called as ``layer([images, theta])``; both are forwarded to
    ``_transform(theta, images, out_size)``.
    """

    def __init__(self, out_size, **kwargs):
        # Initialise the Layer first so Keras attribute tracking exists before any
        # attribute is set.
        super(SpatialTransformer, self).__init__(**kwargs)
        self.out_size = tuple(out_size)

    def call(self, inputs):
        images, theta = inputs
        return _transform(theta, images, self.out_size)

    def get_config(self):
        """Include the required ``out_size`` so ``from_config`` can rebuild the layer."""
        config = super(SpatialTransformer, self).get_config()
        config.update({"out_size": [int(s) for s in self.out_size]})
        return config

## nn/network/physics_models.py

In [None]:
import os
import shutil
import logging
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from pprint import pprint
import inspect

import matplotlib.pyplot as plt
import matplotlib.cm as cm
# Non-interactive backend: figures in this notebook are written with savefig.
plt.switch_backend('agg')

logger = logging.getLogger("tf")

# Dynamics cells selectable by name (PhysicsNet's ``cell_type``).
CELLS = dict(
    bouncing_ode_cell=BouncingODECell,
    spring_ode_cell=SpringODECell,
    gravity_ode_cell=GravityODECell,
    lstm=tf.keras.layers.LSTMCell,
)

# Total number of latent units for each dataset:
#   coord_units = num_objects * num_dimensions * 2   (positions + velocities)
# All tasks use 2-D coordinates; "3bp" (three-body) has three objects, the rest two.
COORD_UNITS = {
    "bouncing_balls": 2 * 2 * 2,
    "spring_color": 2 * 2 * 2,
    "spring_color_half": 2 * 2 * 2,
    "3bp_color": 3 * 2 * 2,
    "mnist_spring_color": 2 * 2 * 2,
}

class ConvEncoder(keras.layers.Layer):
    """Encode frames into 2-D coordinates per object.

    A U-Net predicts ``n_objs + 1`` softmax masks over the pixels; the extra
    channel is presumably background. Each of the first ``n_objs`` masks is
    multiplied with the frame, flattened and passed through a shared MLP that
    outputs 2 numbers. These are squashed into ``[0, input_shape[0]]`` with tanh.

    Args:
        input_shape: frame shape ``(H, W, C)``.
        n_objs: number of objects.
        unet_type: ``'unet'`` (H, W divisible by 8) or ``'shallow_unet'``
            (H, W divisible by 4).

    Raises:
        ValueError: for an unknown ``unet_type``.
    """

    def __init__(self, input_shape, n_objs, unet_type='unet', **kwargs):
        super(ConvEncoder, self).__init__(**kwargs)
        self.conv_input_shape = input_shape
        self.n_objs = n_objs
        self.unet_type = unet_type

        if unet_type == 'unet':
            builder, base_channels = unet, 16
        elif unet_type == 'shallow_unet':
            builder, base_channels = shallow_unet, 8
        else:
            raise ValueError(f"Unsupported unet_type: {unet_type}")

        # unet()/shallow_unet() return a symbolic tensor, not something callable;
        # wrap the graph in a keras.Model so call() can apply it to real inputs.
        mask_input = keras.Input(shape=tuple(input_shape))
        mask_logits = builder(mask_input, base_channels, self.n_objs + 1, upsamp=True)
        self.downsampler = keras.Model(mask_input, mask_logits)

        # Stateless layers are created once here instead of on every call.
        self.mask_softmax = keras.layers.Softmax(axis=-1)
        # Large frames (H >= 40) are average-pooled 2x before the dense layers.
        self.pool = keras.layers.AveragePooling2D(2, 2) if input_shape[0] >= 40 else None

        # Shared across objects: every object is encoded by the same MLP.
        self.flatten = keras.layers.Flatten()
        self.dense_layers = [
            keras.layers.Dense(200, activation='relu'),
            keras.layers.Dense(200, activation='relu'),
            keras.layers.Dense(2, activation=None)
        ]

    def call(self, inputs):
        """Return ``[batch, 2 * n_objs]`` coordinates for frames ``[batch, H, W, C]``.

        Also records ``enc_masks`` and ``masked_objs`` on the layer for inspection.
        """
        masks = self.mask_softmax(self.downsampler(inputs))
        self.enc_masks = masks
        masked_objs = [masks[..., i:i + 1] * inputs for i in range(self.n_objs)]
        self.masked_objs = masked_objs

        encoded_objs = []
        for obj in masked_objs:
            if self.pool is not None:
                obj = self.pool(obj)
            obj = self.flatten(obj)
            for layer in self.dense_layers:
                obj = layer(obj)
            encoded_objs.append(obj)

        h = tf.concat(encoded_objs, axis=1)
        half_extent = self.conv_input_shape[0] / 2
        return tf.tanh(h) * half_extent + half_extent

class ConvSTDecoder(keras.layers.Layer):
    """Render per-object coordinates into an image with spatial transformers.

    A small MLP applied to a constant input produces learned images:
    - per object, a 1-channel template (mask logits) and a C-channel content;
    - one C-channel background.

    Each object's template+content is placed at its coordinates with a
    ``SpatialTransformer``. The warped templates, plus a constant background
    logit of 1, are softmax-normalised into blending masks. The output is the
    mask-weighted sum of the warped contents and the background.

    Args:
        input_shape: frame shape ``[H, W, C]``.
        n_objs: number of objects; ``call`` expects ``[batch, 2 * n_objs]``
            pixel coordinates.
        template_size: length scale that converts pixel offsets into the
            transformer's normalised translation.
    """

    def __init__(self, input_shape, n_objs, template_size, **kwargs):
        super(ConvSTDecoder, self).__init__(**kwargs)
        # ``input_shape`` is a read-only property on Keras layers (assigning it
        # raised AttributeError), so the frame shape is kept as conv_input_shape.
        self.conv_input_shape = [int(d) for d in input_shape]
        self.n_objs = n_objs
        self.template_size = template_size

        height, width, channels = self.conv_input_shape
        self._obj_param_count = height * width * (channels + 1)  # template + content
        self._bg_param_count = height * width * channels

        # Network for creating templates, contents, and background.
        self.network = tf.keras.Sequential([
            layers.Dense(200, activation='tanh'),
            layers.Dense(n_objs * self._obj_param_count + self._bg_param_count, activation=None)
        ])

        self.stn_layer = SpatialTransformer(self.conv_input_shape[:2])

        # Log of the transformer's scale factor.
        self.logsigma = self.add_weight(name='logsigma', shape=[],
                                        initializer='zeros', trainable=True)

    def call(self, inputs):
        """Render ``[batch, 2 * n_objs]`` coordinates into ``[batch, H, W, C]`` images.

        Also records ``template``, ``contents``, ``background_content``,
        ``transf_contents`` and ``transf_masks`` on the layer for inspection.
        """
        batch_size = tf.shape(inputs)[0]
        height, width, channels = self.conv_input_shape
        obj_total = self.n_objs * self._obj_param_count

        # The MLP sees a constant input, so its output acts as a learned parameter tensor.
        params = tf.reshape(self.network(tf.ones([1, 10])), [-1])
        obj_params = tf.reshape(params[:obj_total],
                                [self.n_objs, height, width, channels + 1])
        self.template = obj_params[..., :1]
        self.contents = obj_params[..., 1:]
        self.background_content = tf.nn.sigmoid(
            tf.reshape(params[obj_total:], [1, height, width, channels]))

        # +5 before the warp and -5 after: inside the template the original logits
        # are recovered, while regions the transformer fills with 0 end up at -5 so
        # the background wins there (relies on _interpolate returning 0 outside
        # the template - verify).
        joint = tf.concat([self.template + 5.0, tf.nn.sigmoid(self.contents)], axis=-1)

        mask_logits, obj_contents = [], []
        for i, loc in enumerate(tf.split(inputs, self.n_objs, axis=1)):
            # One copy of this object's template+content per batch element.
            obj = tf.tile(joint[i:i + 1], [batch_size, 1, 1, 1])
            warped = self.stn_layer([obj, self._get_theta(loc)])
            mask_logits.append(warped[..., :1] - 5.0)
            obj_contents.append(warped[..., 1:])

        # The background competes with a constant logit of 1.
        mask_logits.append(tf.ones_like(mask_logits[0]))
        masks = self._generate_masks(mask_logits)
        contents = obj_contents + [tf.tile(self.background_content, [batch_size, 1, 1, 1])]

        self.transf_contents = contents
        self.transf_masks = masks
        return tf.add_n([m * c for m, c in zip(masks, contents)])

    def _get_theta(self, loc):
        """Affine parameters ``[batch, 6]`` (row-major 2x3) placing a template at ``loc``.

        ``loc`` is ``[batch, 2]`` pixel coordinates. The matrix is
        ``[[s, 0, tx], [0, s, ty]]`` with ``s = exp(logsigma)`` and
        ``t = (H / 2 - loc) / template_size * s``.
        """
        sigma = tf.exp(self.logsigma)
        num = tf.shape(loc)[0]
        # Rank-1 [batch] components so they stack cleanly with tx / ty.
        scale = tf.fill([num], sigma)
        zeros = tf.zeros([num], dtype=scale.dtype)
        centre = self.conv_input_shape[0] / 2
        tx = (centre - loc[:, 0]) / self.template_size * sigma
        ty = (centre - loc[:, 1]) / self.template_size * sigma
        return tf.stack([scale, zeros, tx, zeros, scale, ty], axis=1)

    def _generate_masks(self, contents):
        """Softmax ``n_objs + 1`` single-channel logit maps across the list.

        Returns a list of ``n_objs + 1`` masks, each ``[batch, H, W, 1]``, that
        sum to 1 at every pixel.
        """
        logits = tf.concat(contents, axis=-1)
        masks = tf.nn.softmax(logits, axis=-1)
        return tf.split(masks, self.n_objs + 1, axis=-1)

class PhysicsNet(BaseNet):
    """Physics-as-inverse-graphics model.

    Frames are encoded into object coordinates by ``ConvEncoder``. The
    coordinates are rolled forward by an ODE cell (or an LSTM baseline when
    ``cell_type == "lstm"``) and rendered back to images with ``ConvSTDecoder``.
    The loss combines future-frame prediction error with, optionally, the
    reconstruction error of the encoded frames.

    Inputs are reshaped to ``[batch, time, H, W, C]``; pixel values are
    presumably in [0, 1] (the visualisation normalises to that range) - confirm
    against the data pipeline.
    """

    def __init__(self,
                 task="spring_color",
                 recurrent_units=128,
                 lstm_layers=1,
                 cell_type="",
                 seq_len=20,
                 input_steps=3,
                 pred_steps=5,
                 autoencoder_loss=3.0,
                 alt_vel=False,
                 color=False,
                 input_size=36*36,
                 encoder_type="conv_encoder",
                 decoder_type="conv_st_decoder",
                 **kwargs):
        """
        Args:
            task: key of ``COORD_UNITS``; fixes the number of objects.
            recurrent_units, lstm_layers: size of the LSTM baseline (``cell_type="lstm"``).
            cell_type: key of ``CELLS``.
            seq_len: frames per sequence; must exceed ``input_steps + pred_steps``.
            input_steps: frames used to infer the initial position and velocity.
            pred_steps: predicted frames used by the training loss. The remaining
                ``seq_len - input_steps - pred_steps`` frames are extrapolation.
            autoencoder_loss: weight of the reconstruction loss (0 disables it).
            alt_vel: stored for compatibility; not used by this implementation.
            color: 3-channel frames if True, else 1 channel.
            input_size: pixels per square frame, e.g. ``36*36``.
            encoder_type, decoder_type: only "conv_encoder" / "conv_st_decoder" exist.
            **kwargs: forwarded to the Keras base class.

        Raises:
            ValueError: for an unknown ``cell_type``, ``encoder_type`` or ``decoder_type``.
        """
        # **kwargs was forwarded without being a parameter (NameError); accept it.
        super(PhysicsNet, self).__init__(**kwargs)

        assert task in COORD_UNITS
        assert seq_len > input_steps + pred_steps
        assert input_steps >= 1
        assert pred_steps >= 1
        if cell_type not in CELLS:
            raise ValueError("cell_type must be one of %s, got %r" % (sorted(CELLS), cell_type))
        if encoder_type != "conv_encoder":
            raise ValueError("Unsupported encoder_type: %r" % (encoder_type,))
        if decoder_type != "conv_st_decoder":
            raise ValueError("Unsupported decoder_type: %r" % (decoder_type,))

        self.task = task
        self.encoder_type = encoder_type
        self.decoder_type = decoder_type

        # coord_units // 2 position values plus as many velocities, i.e. 2-D
        # coordinates for coord_units // 4 objects.
        self.coord_units = COORD_UNITS[self.task]
        self.n_objs = self.coord_units // 4

        self.seq_len = seq_len
        self.input_steps = input_steps
        self.pred_steps = pred_steps
        self.extrap_steps = self.seq_len - self.input_steps - self.pred_steps

        self.alt_vel = alt_vel
        self.autoencoder_loss = autoencoder_loss

        # Square frames of ``input_size`` pixels. ``input_shape`` is a read-only
        # Keras property, so the frame shape is stored as conv_input_shape.
        self.conv_ch = 3 if color else 1
        side = int(round(np.sqrt(input_size)))
        self.conv_input_shape = [side, side, self.conv_ch]

        # unet pools three times (side % 8 == 0), shallow_unet twice (side % 4 == 0);
        # small frames such as the default 36x36 therefore need the shallow variant.
        unet_type = "shallow_unet" if side < 40 else "unet"
        self.encoder = ConvEncoder(self.conv_input_shape, self.n_objs, unet_type=unet_type)
        self.decoder = ConvSTDecoder(self.conv_input_shape, self.n_objs, template_size=side // 2)

        half = self.coord_units // 2
        # Infers the initial velocity from the first input_steps encoded positions.
        self.vel_encoder = keras.Sequential([
            layers.Dense(100, activation='tanh'),
            layers.Dense(100, activation='tanh'),
            layers.Dense(half),
        ])

        # Only used when using black-box dynamics (baselines).
        self.recurrent_units = recurrent_units
        self.lstm_layers = lstm_layers

        self.cell_type = cell_type
        self.cell = CELLS[self.cell_type]
        if self.cell_type == "lstm":
            # Baseline: stacked LSTMs plus a linear head predicting the next (pos, vel).
            self.ode_cell = None
            self.lstms = [self.cell(self.recurrent_units) for _ in range(self.lstm_layers)]
            self.lstm_head = layers.Dense(self.coord_units)
        else:
            self.ode_cell = self.cell(half)
            self.lstms = []
            self.lstm_head = None

        self.extra_valid_fns.append((self.visualize_sequence, [], {}))
        self.extra_test_fns.append((self.visualize_sequence, [], {}))

    def build_optimizer(self, base_lr, optimizer="rmsprop", anneal_lr=True):
        """Create ``self.optimizer`` from ``OPTIMIZERS`` with a variable learning rate.

        ``self.lr`` holds the learning rate. Gradient clipping to [-1, 1] is
        applied in ``train_step``. ``anneal_lr`` is only recorded.

        Raises:
            ValueError: if ``optimizer`` is not a key of ``OPTIMIZERS``.
        """
        if optimizer not in OPTIMIZERS:
            raise ValueError("optimizer must be one of %s, got %r" % (sorted(OPTIMIZERS), optimizer))
        self.base_lr = base_lr
        self.anneal_lr = anneal_lr
        self.lr = tf.Variable(base_lr, trainable=False, name="base_lr")
        self.optimizer = OPTIMIZERS[optimizer](self.lr)

    def _as_frames(self, inputs):
        """Cast to float32 and reshape to ``[batch, time, H, W, C]`` (flat frames allowed)."""
        inputs = tf.cast(inputs, tf.float32)
        return tf.reshape(inputs, [tf.shape(inputs)[0], -1, *self.conv_input_shape])

    def _rollout_step(self, pos, vel, lstm_states):
        """Advance ``(pos, vel)`` by one frame; returns ``(pos, vel, lstm_states)``."""
        if self.ode_cell is not None:
            pos, vel = self.ode_cell(pos, states=[pos, vel])
            return pos, vel, lstm_states

        h = tf.concat([pos, vel], axis=1)
        new_states = []
        for cell, state in zip(self.lstms, lstm_states):
            h, state = cell(h, state)
            new_states.append(state)
        pos, vel = tf.split(self.lstm_head(h), 2, axis=1)
        return pos, vel, new_states

    def call(self, inputs, training=None):
        """Encode, roll out and decode a batch of sequences.

        Returns the predicted frames ``[batch, seq_len - input_steps, H, W, C]``.
        Also records ``recons_out``, ``enc_pos``, ``pos_vel_seq`` and
        ``pred_out`` for ``compute_loss`` and ``visualize_sequence``.
        """
        inputs = self._as_frames(inputs)
        self._last_inputs = inputs
        batch_size = tf.shape(inputs)[0]
        n_enc = self.input_steps + self.pred_steps
        half = self.coord_units // 2

        # Encode and reconstruct the input + prediction frames as one flat batch.
        frames = tf.reshape(inputs[:, :n_enc], [-1, *self.conv_input_shape])
        enc_pos = self.encoder(frames)
        self.recons_out = tf.reshape(self.decoder(enc_pos),
                                     [batch_size, n_enc, *self.conv_input_shape])
        self.enc_pos = tf.reshape(enc_pos, [batch_size, n_enc, half])

        # Initial state: last encoded input position plus an inferred velocity.
        if self.input_steps > 1:
            vel = self.vel_encoder(tf.reshape(self.enc_pos[:, :self.input_steps],
                                              [-1, self.input_steps * half]))
        else:
            vel = tf.zeros([batch_size, half])
        pos = self.enc_pos[:, self.input_steps - 1]

        # (h, c) zero states for the LSTM baseline; empty for ODE cells.
        lstm_states = [[tf.zeros([batch_size, self.recurrent_units])] * 2
                       for _ in self.lstms]

        output_seq = []
        pos_vel_seq = [tf.concat([pos, vel], axis=1)]
        for _ in range(self.pred_steps + self.extrap_steps):
            pos, vel, lstm_states = self._rollout_step(pos, vel, lstm_states)
            output_seq.append(self.decoder(pos))
            pos_vel_seq.append(tf.concat([pos, vel], axis=1))

        self.pos_vel_seq = tf.stack(pos_vel_seq, axis=1)
        self.pred_out = tf.stack(output_seq, axis=1)
        return self.pred_out

    def get_batch(self, batch_size, iterator):
        """Fetch a batch via ``iterator.next_batch(batch_size)``.

        The iterator is assumed to return ``(x, y)``. Returns
        ``({"inputs": batch_x}, (batch_x, None))``; the dict replaces the old TF1
        ``feed_dict`` so existing call sites can still unpack the result.
        """
        batch_x, _ = iterator.next_batch(batch_size)
        return {"inputs": batch_x}, (batch_x, None)

    def compute_loss(self, inputs=None, targets=None):
        """Return ``(train_loss, [pred_loss, extrap_loss, recons_loss])``.

        Must follow a ``call`` on the same batch. ``inputs`` defaults to the
        batch seen by that call. ``targets`` is ignored: the model predicts its
        own input frames.
        """
        inputs = self._last_inputs if inputs is None else self._as_frames(inputs)
        n_enc = self.input_steps + self.pred_steps

        # Reconstruction loss: squared error summed over pixels, averaged over batch and time.
        recons_target = inputs[:, :n_enc]
        recons_loss = tf.reduce_sum(tf.square(recons_target - self.recons_out), axis=[2, 3, 4])
        self.recons_loss = tf.reduce_mean(recons_loss)

        target = inputs[:, self.input_steps:self.seq_len]
        loss = tf.reduce_sum(tf.square(target - self.pred_out), axis=[2, 3, 4])

        # pred_loss is used for training, extrap_loss only for evaluation.
        self.pred_loss = tf.reduce_mean(loss[:, :self.pred_steps])
        self.extrap_loss = tf.reduce_mean(loss[:, self.pred_steps:])

        train_loss = self.pred_loss
        if self.autoencoder_loss > 0.0:
            train_loss = train_loss + self.autoencoder_loss * self.recons_loss

        eval_losses = [self.pred_loss, self.extrap_loss, self.recons_loss]
        return train_loss, eval_losses

    @staticmethod
    def _loss_logs(train_loss, eval_losses):
        """Per-batch loss values reported to Keras."""
        pred_loss, extrap_loss, recons_loss = eval_losses
        return {"loss": train_loss, "pred_loss": pred_loss,
                "extrap_loss": extrap_loss, "recons_loss": recons_loss}

    def train_step(self, data):
        """One optimisation step on an image-sequence batch; gradients clipped to [-1, 1]."""
        inputs = data[0] if isinstance(data, tuple) else data
        with tf.GradientTape() as tape:
            self(inputs, training=True)
            train_loss, eval_losses = self.compute_loss(inputs)

        variables = self.trainable_variables
        grads = tape.gradient(train_loss, variables)
        clipped = [(tf.clip_by_value(g, -1.0, 1.0), v)
                   for g, v in zip(grads, variables) if g is not None]
        self.optimizer.apply_gradients(clipped)
        return self._loss_logs(train_loss, eval_losses)

    def test_step(self, data):
        """Evaluate one batch without updating weights."""
        inputs = data[0] if isinstance(data, tuple) else data
        self(inputs, training=False)
        train_loss, eval_losses = self.compute_loss(inputs)
        return self._loss_logs(train_loss, eval_losses)

    @staticmethod
    def _to_numpy(value):
        """Convert a tensor, or a list/tuple of equally-shaped tensors, to one ndarray."""
        if isinstance(value, (list, tuple)):
            return np.stack([np.asarray(item) for item in value])
        return np.asarray(value)

    def visualize_sequence(self):
        """Save qualitative results for one test batch to ``self.save_dir``.

        Writes:
        - ``example%d.png`` grids (rows: prediction, ground truth, reconstruction);
        - an animated ``animation%d.gif``;
        - ``extra_outputs.npz`` with whichever intermediate tensors the
          encoder/decoder expose;
        - ``templates.png`` when template and content tensors are available.

        Expects ``self.test_iterator`` and ``self.save_dir`` to be set by the caller.
        """
        batch_size = 5
        height, width, channels = self.conv_input_shape

        _, (batch_x, _) = self.get_batch(batch_size, self.test_iterator)
        batch_x = np.asarray(batch_x, dtype=np.float32).reshape(
            [-1, self.seq_len, height, width, channels])
        n = batch_x.shape[0]

        output_seq = self(batch_x, training=False).numpy()
        recons_seq = self.recons_out.numpy()

        # Pad both to seq_len frames: ground-truth inputs before the predictions,
        # blank frames after the reconstructions.
        output_seq = np.concatenate([batch_x[:, :self.input_steps], output_seq], axis=1)
        recons_seq = np.concatenate(
            [recons_seq, np.zeros((n, self.extrap_steps) + recons_seq.shape[2:])], axis=1)

        # Plot a grid with prediction sequences.
        norm = plt.Normalize(0.0, 1.0)
        figsize = None
        for i in range(n):
            total_seq = np.concatenate([output_seq[i], batch_x[i], recons_seq[i]], axis=0)
            result = gallery(total_seq, ncols=self.seq_len)

            figsize = (result.shape[1] // width, result.shape[0] // height)
            fig, ax = plt.subplots(figsize=figsize)
            ax.imshow(np.squeeze(result), interpolation='nearest', cmap=cm.Greys_r, norm=norm)
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
            fig.tight_layout()
            fig.savefig(os.path.join(self.save_dir, "example%d.png" % i))
            plt.close(fig)  # do not accumulate open figures

        # Make a gif: predictions on top, ground truth below, one column per example.
        border_shape = [n, self.seq_len, height + 2, width + 2, 3]
        bordered_output_seq = 0.5 * np.ones(border_shape)
        bordered_batch_x = 0.5 * np.ones(border_shape)
        bordered_output_seq[:, :, 1:-1, 1:-1] = output_seq
        bordered_batch_x[:, :, 1:-1, 1:-1] = batch_x
        output_frames = np.concatenate(np.split(bordered_output_seq, n, 0), axis=-2).squeeze()
        truth_frames = np.concatenate(np.split(bordered_batch_x, n, 0), axis=-2).squeeze()
        frames = np.concatenate([output_frames, truth_frames], axis=1)
        # Numbered with the last example index, matching the previous file name.
        gif(os.path.join(self.save_dir, "animation%d.gif" % (n - 1)),
            frames * 255, fps=7, scale=3)

        # Save whichever intermediate tensors the sub-layers recorded in the last call.
        extras = {
            "contents": getattr(self.decoder, "contents", None),
            "templates": getattr(self.decoder, "template", None),
            "background_content": getattr(self.decoder, "background_content", None),
            "transf_contents": getattr(self.decoder, "transf_contents", None),
            "transf_masks": getattr(self.decoder, "transf_masks", None),
            "enc_masks": getattr(self.encoder, "enc_masks", None),
            "masked_objs": getattr(self.encoder, "masked_objs", None),
        }
        results = {name: self._to_numpy(value)
                   for name, value in extras.items() if value is not None}
        np.savez_compressed(os.path.join(self.save_dir, "extra_outputs.npz"), **results)

        if "contents" in results and "templates" in results:
            contents = 1 / (1 + np.exp(-results["contents"]))
            templates = 1 / (1 + np.exp(-(results["templates"] - 5)))
            if self.conv_ch == 1:
                contents = np.tile(contents, [1, 1, 1, 3])
            templates = np.tile(templates, [1, 1, 1, 3])
            total_seq = np.concatenate([contents, templates], axis=0)
            result = gallery(total_seq, ncols=self.n_objs)
            fig, ax = plt.subplots(figsize=figsize)
            ax.imshow(np.squeeze(result), interpolation='nearest', cmap=cm.Greys_r, norm=norm)
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
            fig.tight_layout()
            fig.savefig(os.path.join(self.save_dir, "templates.png"))
            plt.close(fig)

        logger.info([(v.name, v.numpy()) for v in self.trainable_variables
                     if "ode_cell" in v.name or "sigma" in v.name])
        return


## nn/datasets/generators.py

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from itertools import combinations

# from nn.utils.viz import gallery
# from nn.utils.misc import rgb2gray

def generate_bouncing_ball_dataset(dest,
                                   train_set_size,
                                   valid_set_size,
                                   test_set_size,
                                   seq_len,
                                   box_size):
    """Simulate point balls bouncing inside a ``box_size`` x ``box_size`` box
    and store their (x, y) trajectories in a compressed ``.npz`` archive.

    The global NumPy RNG is re-seeded with 0, so the output is deterministic.

    Args:
        dest: output path passed to ``np.savez_compressed``; the archive holds
            ``train_x``, ``valid_x`` and ``test_x`` of shape (n, seq_len, 2).
        train_set_size, valid_set_size, test_set_size: trajectories per split.
        seq_len: number of positions recorded per trajectory.
        box_size: side length of the box.
    """
    np.random.seed(0)

    def bounce(pos, vel):
        # Flip (in place) every velocity component whose next step would
        # leave [0, box_size].
        for axis in (0, 1):
            nxt = pos[axis] + vel[axis]
            if nxt > box_size or nxt < 0.0:
                vel[axis] = -vel[axis]
        return vel

    def simulate(n_steps):
        # Uniform start position, speed in [1, 2), uniform heading.
        pos = np.random.rand(2) * box_size
        speed = np.random.rand() + 1
        heading = np.random.rand() * 2 * np.pi
        vel = np.array([speed * np.cos(heading), speed * np.sin(heading)])
        path = []
        for _ in range(n_steps):
            path.append(pos)
            vel = bounce(pos, vel)
            pos = pos + vel  # new array, so earlier entries of `path` stay intact
        return path

    n_total = train_set_size + valid_set_size + test_set_size
    trajectories = np.array([simulate(seq_len) for _ in range(n_total)])

    valid_end = train_set_size + valid_set_size
    np.savez_compressed(dest,
                        train_x=trajectories[:train_set_size],
                        valid_x=trajectories[train_set_size:valid_end],
                        test_x=trajectories[valid_end:])
    print("Saved to file %s" % dest)


def compute_wall_collision(pos, vel, radius, img_size):
    """Reflect a ball of ``radius`` off the borders of an ``img_size`` box.

    For each axis (1 first, then 0), the velocity component is negated when
    the ball touches or crosses a wall. The position is then mirrored back
    inside the box. ``pos`` and ``vel`` are modified in place and also returned.

    Returns:
        (pos, vel): the same objects that were passed in.
    """
    for axis in (1, 0):
        limit = img_size[axis]
        if pos[axis] - radius <= 0:
            vel[axis] = -vel[axis]
            pos[axis] = -(pos[axis] - radius) + radius
        if pos[axis] + radius >= limit:
            vel[axis] = -vel[axis]
            pos[axis] = limit - (pos[axis] + radius - limit) - radius
    return pos, vel


def verify_wall_collision(pos, vel, radius, img_size):
    """Return True if a ball of ``radius`` at ``pos`` touches or crosses any
    border of an ``img_size`` box, otherwise False.

    ``vel`` is accepted for signature symmetry with ``compute_wall_collision``
    but is not used.
    """
    for axis in (1, 0):
        if pos[axis] - radius <= 0 or pos[axis] + radius >= img_size[axis]:
            return True
    return False


def verify_object_collision(poss, radius):
    """Return True if any two positions in ``poss`` are at Euclidean distance
    ``<= radius`` from each other, otherwise False.
    """
    return any(np.linalg.norm(first - second) <= radius
               for first, second in combinations(poss, 2))


def generate_falling_ball_dataset(dest,
                                  train_set_size,
                                  valid_set_size,
                                  test_set_size,
                                  seq_len,
                                  img_size=None,
                                  radius=3,
                                  dt=0.15,
                                  g=9.8,
                                  ode_steps=10):
    """Render a single ball falling under gravity from rest and save the uint8
    frame sequences to ``dest`` as a compressed ``.npz`` archive.

    A preview gallery of the first 10 sequences (one sequence per row) is
    also written to ``<dest without extension>_samples.jpg``.

    Args:
        dest: output ``.npz`` path; receives ``train_x``, ``valid_x``, ``test_x``
            of shape (n, seq_len, img_size[0], img_size[1], 1).
        train_set_size, valid_set_size, test_set_size: sequences per split.
        seq_len: frames per sequence.
        img_size: ``[height, width]`` (list or tuple); defaults to ``[32, 32]``.
        radius: ball radius in pixels.
        dt: simulated time between consecutive frames.
        g: gravitational acceleration along +y (downwards in the image).
        ode_steps: Euler sub-steps integrated between consecutive frames.

    Raises:
        AssertionError: if the ball reaches the bottom border within
            ``seq_len`` frames.
    """
    # `skimage.draw.circle` was removed in recent scikit-image releases in
    # favour of `disk`; only fall back to it on old installations.
    try:
        from skimage.draw import disk

        def circle(r, c, rad, shape=None):
            return disk((r, c), rad, shape=shape)
    except ImportError:
        from skimage.draw import circle
    # NOTE: `gallery` is the notebook-level helper from the viz cell; there is
    # no `nn` package in this notebook, so it must not be imported from one.

    if img_size is None:
        img_size = [32,32]
    img_size = list(img_size)  # frames are built with `img_size + [1]`

    def generate_sequence():
        seq = []
        # sample initial position (top half of the image for y), with v=0
        pos = np.random.rand(2)
        pos[0] = radius+(img_size[0]-2*radius)*pos[0]
        pos[1] = radius + (img_size[1]-2*radius)/2*pos[1]
        vel = np.array([0.0,0.0])

        for _ in range(seq_len):
            assert pos[1]+radius < img_size[1]

            # uint8, not int8: int8 cannot hold 255 (wraps to -1 on old NumPy,
            # raises OverflowError on NumPy >= 2).
            frame = np.zeros(img_size+[1], dtype=np.uint8)
            rr, cc = circle(int(pos[1]), int(pos[0]), radius)
            frame[rr, cc, 0] = 255

            seq.append(frame)

            # rollout physics: semi-implicit Euler on the vertical axis only
            for _ in range(ode_steps):
                vel[1] = vel[1] + dt/ode_steps*g
                pos[1] = pos[1] + dt/ode_steps*vel[1]

        return seq

    sequences = []
    for i in range(train_set_size+valid_set_size+test_set_size):
        if i % 100 == 0:
            print(i)
        sequences.append(generate_sequence())
    sequences = np.array(sequences, dtype=np.uint8)

    np.savez_compressed(dest,
                        train_x=sequences[:train_set_size],
                        valid_x=sequences[train_set_size:train_set_size+valid_set_size],
                        test_x=sequences[train_set_size+valid_set_size:])
    print("Saved to file %s" % dest)

    # Save 10 samples (one sequence per gallery row)
    result = gallery(np.concatenate(sequences[:10]/255), ncols=sequences.shape[1])

    norm = plt.Normalize(0.0, 1.0)
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(np.squeeze(result), interpolation='nearest', cmap=cm.Greys_r, norm=norm)
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    fig.tight_layout()
    # splitext strips only the extension; dest.split(".")[0] truncated paths
    # containing other dots (e.g. "./data/x.npz" -> "").
    fig.savefig(os.path.splitext(dest)[0]+"_samples.jpg")


def generate_falling_bouncing_ball_dataset(dest,
                                  train_set_size,
                                  valid_set_size,
                                  test_set_size,
                                  seq_len,
                                  img_size=None,
                                  radius=3,
                                  dt=0.30,
                                  g=9.8,
                                  vx0_max=0.0,
                                  vy0_max=0.0,
                                  cifar_background=False,
                                  ode_steps=10):
    """Render one ball falling under gravity and bouncing off the image
    borders, and save the uint8 frame sequences to ``dest`` (``.npz``).

    Frames are drawn at 10x resolution and downsampled with anti-aliasing.
    A preview gallery of the first 10 sequences is written to
    ``<dest without extension>_samples.jpg``.

    Args:
        dest: output ``.npz`` path; receives ``train_x``, ``valid_x``, ``test_x``
            of shape (n, seq_len, img_size[0], img_size[1], 1).
        train_set_size, valid_set_size, test_set_size: sequences per split.
        seq_len: frames per sequence.
        img_size: ``[height, width]``; defaults to ``[32, 32]``.
        radius: ball radius in output pixels.
        dt: simulated time between consecutive frames.
        g: gravitational acceleration along +y. With ``g == 0`` the start
            height is sampled over the whole image instead of the top half.
        vx0_max, vy0_max: scale of the initial velocity components. The
            velocity direction is sampled uniformly.
        cifar_background: draw on a random, darkened grayscale CIFAR-10 image
            (one per sequence) instead of a black background.
        ode_steps: Euler sub-steps per frame; walls are handled after each one.
    """
    if cifar_background:
        import tensorflow as tf
        (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

    # `skimage.draw.circle` was removed in recent scikit-image releases in
    # favour of `disk`; only fall back to it on old installations.
    try:
        from skimage.draw import disk

        def circle(r, c, rad, shape=None):
            return disk((r, c), rad, shape=shape)
    except ImportError:
        from skimage.draw import circle
    from skimage.transform import resize

    if img_size is None:
        img_size = [32,32]
    scale = 10  # supersampling factor for anti-aliased rendering
    scaled_img_size = [img_size[0]*scale, img_size[1]*scale]

    def generate_sequence():
        seq = []
        # sample initial position; under gravity keep the start in the top half
        pos = np.random.rand(2)
        pos[0] = radius + (img_size[0]-2*radius)*pos[0]
        if g == 0.0:
            pos[1] = radius + (img_size[1]-2*radius)*pos[1]
        else:
            pos[1] = radius + (img_size[1]-2*radius)/2*pos[1]
        angle = np.random.rand()*2*np.pi
        vel = np.array([np.cos(angle)*vx0_max,
                        np.sin(angle)*vy0_max])

        if cifar_background:
            cifar_img = x_train[np.random.randint(len(x_train))]

        for _ in range(seq_len):
            if cifar_background:
                # NOTE(review): relies on an `rgb2gray` helper defined elsewhere
                # in the notebook (the import in this cell is commented out);
                # assumed to return a 2-D array.
                frame = cifar_img
                frame = rgb2gray(frame)/255
                frame = resize(frame, scaled_img_size)
                frame = np.clip(frame-0.2, 0.0, 1.0) # darken image a bit
            else:
                frame = np.zeros(scaled_img_size, dtype=np.float32)

            rr, cc = circle(int(pos[1]*scale), int(pos[0]*scale), radius*scale, scaled_img_size)
            frame[rr, cc] = 1.0
            frame = resize(frame, img_size, anti_aliasing=True)
            frame = (frame[:,:,None]*255).astype(np.uint8)

            seq.append(frame)

            # rollout physics: semi-implicit Euler, gravity acts on y only
            for _ in range(ode_steps):
                vel[1] = vel[1] + dt/ode_steps*g
                pos[1] = pos[1] + dt/ode_steps*vel[1]

                pos[0] = pos[0] + dt/ode_steps*vel[0]

                # verify wall collisions
                pos, vel = compute_wall_collision(pos, vel, radius, img_size)
        return seq

    sequences = []
    for i in range(train_set_size+valid_set_size+test_set_size):
        if i % 100 == 0:
            print(i)
        sequences.append(generate_sequence())
    sequences = np.array(sequences, dtype=np.uint8)

    np.savez_compressed(dest,
                        train_x=sequences[:train_set_size],
                        valid_x=sequences[train_set_size:train_set_size+valid_set_size],
                        test_x=sequences[train_set_size+valid_set_size:])
    print("Saved to file %s" % dest)

    # Save 10 samples (one sequence per gallery row)
    result = gallery(np.concatenate(sequences[:10]/255), ncols=sequences.shape[1])

    norm = plt.Normalize(0.0, 1.0)
    fig, ax = plt.subplots(figsize=(sequences.shape[1], 10))
    ax.imshow(np.squeeze(result), interpolation='nearest', cmap=cm.Greys_r, norm=norm)
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    fig.tight_layout()
    # splitext strips only the extension; dest.split(".")[0] truncated paths
    # containing other dots (e.g. "./data/x.npz" -> "").
    fig.savefig(os.path.splitext(dest)[0]+"_samples.jpg")


def generate_spring_balls_dataset(dest,
                                  train_set_size,
                                  valid_set_size,
                                  test_set_size,
                                  seq_len,
                                  img_size=None,
                                  radius=3,
                                  dt=0.3,
                                  k=3,
                                  equil=5,
                                  vx0_max=0.0,
                                  vy0_max=0.0,
                                  color=False,
                                  cifar_background=False,
                                  ode_steps=10):
    """Render two balls connected by a spring and save the uint8 frame
    sequences to ``dest`` as a compressed ``.npz`` archive.

    If the physics rollout brings either ball into contact with an image
    border, the whole sequence is rejected and re-sampled. Frames are drawn at
    10x resolution and downsampled with anti-aliasing. A preview gallery of
    the first 10 sequences is written to ``<dest without extension>_samples.jpg``.

    Args:
        dest: output ``.npz`` path; receives ``train_x``, ``valid_x``, ``test_x``.
        train_set_size, valid_set_size, test_set_size: sequences per split.
        seq_len: frames per sequence.
        img_size: ``[height, width]``; defaults to ``[32, 32]``.
        radius: ball radius in output pixels.
        dt: simulated time between consecutive frames.
        k: spring constant.
        equil: half of the spring rest length. The force vanishes at a
            separation of ``2*equil``.
        vx0_max, vy0_max: scale of each ball's initial velocity components.
            The two directions are sampled independently.
        color: if True, frames have 3 channels and ball ``j`` is drawn in
            channel ``2 - j``. Otherwise frames have a single channel.
        cifar_background: draw on a random darkened CIFAR-10 image (one per
            sequence). It is kept in RGB when ``color`` is True and converted
            to grayscale otherwise.
        ode_steps: Euler sub-steps per frame.
    """
    if cifar_background:
        import tensorflow as tf
        (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

    # `skimage.draw.circle` was removed in recent scikit-image releases in
    # favour of `disk`; only fall back to it on old installations.
    try:
        from skimage.draw import disk

        def circle(r, c, rad, shape=None):
            return disk((r, c), rad, shape=shape)
    except ImportError:
        from skimage.draw import circle
    from skimage.transform import resize

    if img_size is None:
        img_size = [32,32]
    scale = 10  # supersampling factor for anti-aliased rendering
    scaled_img_size = [img_size[0]*scale, img_size[1]*scale]

    def generate_sequence():
        # sample initial position of the center of mass, then sample
        # position of each object relative to that.

        collision = True
        while collision:
            seq = []

            cm_pos = np.random.rand(2)
            cm_pos[0] = radius+equil + (img_size[0]-2*(radius+equil))*cm_pos[0]
            cm_pos[1] = radius+equil + (img_size[1]-2*(radius+equil))*cm_pos[1]

            angle = np.random.rand()*2*np.pi
            # place both objects symmetrically around the center of mass
            r = np.random.rand()+0.5
            poss = [[np.cos(angle)*equil*r+cm_pos[0], np.sin(angle)*equil*r+cm_pos[1]],
                   [np.cos(angle+np.pi)*equil*r+cm_pos[0], np.sin(angle+np.pi)*equil*r+cm_pos[1]]]
            poss = np.array(poss)
            angles = np.random.rand(2)*2*np.pi
            vels = [[np.cos(angles[0])*vx0_max, np.sin(angles[0])*vy0_max],
                   [np.cos(angles[1])*vx0_max, np.sin(angles[1])*vy0_max]]
            vels = np.array(vels)

            if cifar_background:
                cifar_img = x_train[np.random.randint(len(x_train))]

            for _ in range(seq_len):
                if cifar_background:
                    # Keep a trailing channel axis: the balls are drawn with
                    # frame[rr, cc, ch], which failed on the former 2-D
                    # grayscale frame.
                    # NOTE(review): `rgb2gray` comes from elsewhere in the
                    # notebook (its import here is commented out); assumed to
                    # return a 2-D array.
                    frame = cifar_img
                    if not color:
                        frame = rgb2gray(frame)[:, :, None]
                    frame = frame/255
                    frame = resize(frame, scaled_img_size)
                    frame = np.clip(frame-0.2, 0.0, 1.0) # darken image a bit
                else:
                    if color:
                        frame = np.zeros(scaled_img_size+[3], dtype=np.float32)
                    else:
                        frame = np.zeros(scaled_img_size+[1], dtype=np.float32)

                for j, pos in enumerate(poss):
                    rr, cc = circle(int(pos[1]*scale), int(pos[0]*scale), radius*scale, scaled_img_size)
                    if color:
                        frame[rr, cc, 2-j] = 1.0
                    else:
                        frame[rr, cc, 0] = 1.0

                frame = resize(frame, img_size, anti_aliasing=True)
                frame = (frame*255).astype(np.uint8)

                seq.append(frame)

                # rollout physics: Hooke spring along the line joining the balls
                for _ in range(ode_steps):
                    norm = np.linalg.norm(poss[0]-poss[1])
                    direction = (poss[0]-poss[1])/norm
                    F = k*(norm-2*equil)*direction
                    vels[0] = vels[0] - dt/ode_steps*F
                    vels[1] = vels[1] + dt/ode_steps*F
                    poss = poss + dt/ode_steps*vels

                    # any wall contact rejects the whole sequence
                    collision = verify_wall_collision(poss[0], vels[0], radius, img_size) or \
                                verify_wall_collision(poss[1], vels[1], radius, img_size)
                    if collision:
                        break
                if collision:
                    break

        return seq

    sequences = []
    for i in range(train_set_size+valid_set_size+test_set_size):
        if i % 100 == 0:
            print(i)
        sequences.append(generate_sequence())
    sequences = np.array(sequences, dtype=np.uint8)

    np.savez_compressed(dest,
                        train_x=sequences[:train_set_size],
                        valid_x=sequences[train_set_size:train_set_size+valid_set_size],
                        test_x=sequences[train_set_size+valid_set_size:])
    print("Saved to file %s" % dest)

    # Save 10 samples (one sequence per gallery row)
    result = gallery(np.concatenate(sequences[:10]/255), ncols=sequences.shape[1])

    norm = plt.Normalize(0.0, 1.0)
    fig, ax = plt.subplots(figsize=(sequences.shape[1], 10))
    ax.imshow(np.squeeze(result), interpolation='nearest', cmap=cm.Greys_r, norm=norm)
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    fig.tight_layout()
    # splitext strips only the extension; dest.split(".")[0] truncated paths
    # containing other dots (e.g. "./data/x.npz" -> "").
    fig.savefig(os.path.splitext(dest)[0]+"_samples.jpg")


def generate_spring_mnist_dataset(dest,
                                  train_set_size,
                                  valid_set_size,
                                  test_set_size,
                                  seq_len,
                                  img_size=None,
                                  radius=3,
                                  dt=0.3,
                                  k=3,
                                  equil=5,
                                  vx0_max=0.0,
                                  vy0_max=0.0,
                                  color=False,
                                  cifar_background=False,
                                  ode_steps=10):
    """Render two MNIST digits connected by a spring and save the uint8 frame
    sequences to ``dest`` as a compressed ``.npz`` archive.

    Every sequence uses the same two digits: the first two MNIST training
    images, cropped to 22x22. When ``cifar_background`` is True, a single
    fixed CIFAR-10 image (``x_train[1]``) is used as the background. Frames
    are drawn at 5x resolution and downsampled with anti-aliasing.

    If the physics rollout brings either digit centre within 2 px of an image
    border, the whole sequence is rejected and re-sampled. A preview gallery
    of the first 10 sequences is written to ``<dest without extension>_samples.jpg``.

    Args:
        dest: output ``.npz`` path; receives ``train_x``, ``valid_x``, ``test_x``.
        train_set_size, valid_set_size, test_set_size: sequences per split.
        seq_len: frames per sequence.
        img_size: ``[height, width]``; defaults to ``[32, 32]``.
        radius: ignored. It is overridden to 11, half the side of the
            cropped digit.
        dt: simulated time between consecutive frames.
        k: spring constant.
        equil: half of the spring rest length.
        vx0_max, vy0_max: scale of each digit's initial velocity components.
        color: if True, frames have 3 channels and digit ``j`` is tinted into
            channel ``j``. Otherwise frames have a single channel.
        cifar_background: blend the digits over a darkened CIFAR-10 image.
        ode_steps: Euler sub-steps per frame.
    """
    # A single CIFAR image is used for background
    # Only 2 mnist digits are used
    import tensorflow as tf
    from skimage.transform import resize
    # No skimage.draw import: digits are pasted directly. The former unused
    # `circle(...)` call also broke on scikit-image releases without `circle`.

    scale = 5  # supersampling factor for anti-aliased rendering
    if img_size is None:
        img_size = [32,32]
    scaled_img_size = [img_size[0]*scale, img_size[1]*scale]

    if cifar_background:
        (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
        cifar_img = x_train[1]

    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    digits = x_train[0:2, 3:-3, 3:-3]/255  # crop the 28x28 digits to 22x22
    digits = [resize(d, [22*scale, 22*scale]) for d in digits]
    radius = 11  # half of the 22-px digit crop; overrides the `radius` argument

    def generate_sequence():
        # sample initial position of the center of mass, then sample
        # position of each object relative to that.

        collision = True
        while collision:
            seq = []

            cm_pos = np.random.rand(2)
            cm_pos[0] = radius+equil + (img_size[0]-2*(radius+equil))*cm_pos[0]
            cm_pos[1] = radius+equil + (img_size[1]-2*(radius+equil))*cm_pos[1]

            angle = np.random.rand()*2*np.pi
            # place both objects symmetrically around the center of mass
            r = np.random.rand()+0.5
            poss = [[np.cos(angle)*equil*r+cm_pos[0], np.sin(angle)*equil*r+cm_pos[1]],
                   [np.cos(angle+np.pi)*equil*r+cm_pos[0], np.sin(angle+np.pi)*equil*r+cm_pos[1]]]
            poss = np.array(poss)
            angles = np.random.rand(2)*2*np.pi
            vels = [[np.cos(angles[0])*vx0_max, np.sin(angles[0])*vy0_max],
                   [np.cos(angles[1])*vx0_max, np.sin(angles[1])*vy0_max]]
            vels = np.array(vels)

            for _ in range(seq_len):
                if cifar_background:
                    # NOTE(review): `rgb2gray` comes from elsewhere in the
                    # notebook (its import in this cell is commented out).
                    frame = cifar_img
                    if not color:
                        frame = rgb2gray(frame)
                        frame = frame[:,:,None]
                    frame = frame/255
                    frame = resize(frame, scaled_img_size)
                    frame = np.clip(frame-0.2, 0.0, 1.0) # darken image a bit
                else:
                    if color:
                        frame = np.zeros(scaled_img_size+[3], dtype=np.float32)
                    else:
                        frame = np.zeros(scaled_img_size+[1], dtype=np.float32)

                for j, pos in enumerate(poss):
                    # Destination window in the frame and matching window in the
                    # digit, both clipped so a digit partly outside the frame is
                    # cropped instead of raising.
                    frame_coords = np.array([[max(0, (pos[1]-radius)*scale), min(scaled_img_size[1], (pos[1]+radius)*scale)],
                                             [max(0, (pos[0]-radius)*scale), min(scaled_img_size[0], (pos[0]+radius)*scale)]])
                    digit_coords = np.array([[max(0, (radius-pos[1])*scale), min(2*radius*scale, scaled_img_size[1]-(pos[1]-radius)*scale)],
                                             [max(0, (radius-pos[0])*scale), min(2*radius*scale, scaled_img_size[0]-(pos[0]-radius)*scale)]])
                    frame_coords = np.round(frame_coords).astype(np.int32)
                    digit_coords = np.round(digit_coords).astype(np.int32)

                    digit_slice = digits[j][digit_coords[0,0]:digit_coords[0,1],
                                            digit_coords[1,0]:digit_coords[1,1]]
                    # Alpha-blend: digit intensity acts as opacity over the frame.
                    if color:
                        for l in range(3):
                            frame_slice = frame[frame_coords[0,0]:frame_coords[0,1],
                                                frame_coords[1,0]:frame_coords[1,1], l]
                            c = 1.0 if l == j else 0.0
                            frame[frame_coords[0,0]:frame_coords[0,1],
                                  frame_coords[1,0]:frame_coords[1,1], l] = digit_slice*c + (1-digit_slice)*frame_slice

                    else:
                        frame_slice = frame[frame_coords[0,0]:frame_coords[0,1],
                                            frame_coords[1,0]:frame_coords[1,1], 0]
                        frame[frame_coords[0,0]:frame_coords[0,1],
                              frame_coords[1,0]:frame_coords[1,1], 0] = digit_slice + (1-digit_slice)*frame_slice

                frame = resize(frame, img_size, anti_aliasing=True)
                frame = (frame*255).astype(np.uint8)

                seq.append(frame)

                # rollout physics: Hooke spring along the line joining the digits
                for _ in range(ode_steps):
                    norm = np.linalg.norm(poss[0]-poss[1])
                    direction = (poss[0]-poss[1])/norm
                    F = k*(norm-2*equil)*direction
                    vels[0] = vels[0] - dt/ode_steps*F
                    vels[1] = vels[1] + dt/ode_steps*F
                    poss = poss + dt/ode_steps*vels

                    # digit centres within 2 px of a wall reject the sequence
                    collision = verify_wall_collision(poss[0], vels[0], 2, img_size) or \
                                verify_wall_collision(poss[1], vels[1], 2, img_size)
                    if collision:
                        break
                if collision:
                    break

        return seq

    sequences = []
    for i in range(train_set_size+valid_set_size+test_set_size):
        if i % 100 == 0:
            print(i)
        sequences.append(generate_sequence())
    sequences = np.array(sequences, dtype=np.uint8)

    np.savez_compressed(dest,
                        train_x=sequences[:train_set_size],
                        valid_x=sequences[train_set_size:train_set_size+valid_set_size],
                        test_x=sequences[train_set_size+valid_set_size:])
    print("Saved to file %s" % dest)

    # Save 10 samples (one sequence per gallery row)
    result = gallery(np.concatenate(sequences[:10]/255), ncols=sequences.shape[1])

    norm = plt.Normalize(0.0, 1.0)
    fig, ax = plt.subplots(figsize=(sequences.shape[1], 10))
    ax.imshow(np.squeeze(result), interpolation='nearest', cmap=cm.Greys_r, norm=norm)
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    fig.tight_layout()
    # splitext strips only the extension; dest.split(".")[0] truncated paths
    # containing other dots (e.g. "./data/x.npz" -> "").
    fig.savefig(os.path.splitext(dest)[0]+"_samples.jpg")


def generate_3_body_problem_dataset(dest,
                                  train_set_size,
                                  valid_set_size,
                                  test_set_size,
                                  seq_len,
                                  img_size=None,
                                  radius=3,
                                  dt=0.3,
                                  g=9.8,
                                  m=1.0,
                                  vx0_max=0.0,
                                  vy0_max=0.0,
                                  color=False,
                                  cifar_background=False,
                                  ode_steps=10):
    """Render three gravitationally interacting balls and save the uint8 frame
    sequences to ``dest`` as a compressed ``.npz`` archive.

    The bodies start around the image centre, roughly 120 degrees apart, with
    velocities rotated by +/-90 degrees from their angular position plus a
    shared random offset. A sequence is rejected and re-sampled if, during the
    rollout, any ball touches a border or two balls come within
    ``radius + 1``. Frames are drawn at 10x resolution and downsampled with
    anti-aliasing. A preview gallery of the first 10 sequences is written to
    ``<dest without extension>_samples.jpg``.

    Args:
        dest: output ``.npz`` path; receives ``train_x``, ``valid_x``, ``test_x``.
        train_set_size, valid_set_size, test_set_size: sequences per split.
        seq_len: frames per sequence.
        img_size: ``[height, width]``; defaults to ``[32, 32]``.
        radius: ball radius in output pixels.
        dt: simulated time between consecutive frames.
        g: gravitational constant.
        m: mass of every body.
        vx0_max, vy0_max: scale of the initial velocity components.
        color: if True, frames have 3 channels and ball ``j`` is drawn in
            channel ``2 - j``. Otherwise frames have a single channel.
        cifar_background: draw on a random darkened CIFAR-10 image (one per
            sequence). It is kept in RGB when ``color`` is True and converted
            to grayscale otherwise.
        ode_steps: Euler sub-steps per frame.
    """
    if cifar_background:
        import tensorflow as tf
        (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

    # `skimage.draw.circle` was removed in recent scikit-image releases in
    # favour of `disk`; only fall back to it on old installations.
    try:
        from skimage.draw import disk

        def circle(r, c, rad, shape=None):
            return disk((r, c), rad, shape=shape)
    except ImportError:
        from skimage.draw import circle
    from skimage.transform import resize

    if img_size is None:
        img_size = [32,32]
    scale = 10  # supersampling factor for anti-aliased rendering
    scaled_img_size = [img_size[0]*scale, img_size[1]*scale]

    def generate_sequence():
        # place the three bodies on a circle around the image centre, then
        # give them roughly tangential velocities.

        collision = True
        while collision:
            seq = []

            # NOTE: this draw is overwritten immediately but still advances the
            # global RNG; it is kept so datasets for a given seed are unchanged.
            cm_pos = np.random.rand(2)
            cm_pos = np.array(img_size)/2
            angle1 = np.random.rand()*2*np.pi
            angle2 = angle1 + 2*np.pi/3+(np.random.rand()-0.5)/2
            angle3 = angle1 + 4*np.pi/3+(np.random.rand()-0.5)/2

            angles = [angle1, angle2, angle3]
            # calculate the positions of all three objects
            r = (np.random.rand()/2+0.75)*img_size[0]/4
            poss = [[np.cos(angle)*r+cm_pos[0], np.sin(angle)*r+cm_pos[1]] for angle in angles]
            poss = np.array(poss)

            # velocities at +/-90 degrees from each body's angular position
            r = np.random.randint(0,2)*2-1
            angles = [angle+r*np.pi/2 for angle in angles]
            noise = np.random.rand(2)-0.5
            vels = [[np.cos(angle)*vx0_max+noise[0], np.sin(angle)*vy0_max+noise[1]] for angle in angles]
            vels = np.array(vels)

            if cifar_background:
                cifar_img = x_train[np.random.randint(len(x_train))]

            for _ in range(seq_len):
                if cifar_background:
                    # Keep a trailing channel axis: the balls are drawn with
                    # frame[rr, cc, ch], which failed on the former 2-D
                    # grayscale frame.
                    # NOTE(review): `rgb2gray` comes from elsewhere in the
                    # notebook (its import here is commented out); assumed to
                    # return a 2-D array.
                    frame = cifar_img
                    if not color:
                        frame = rgb2gray(frame)[:, :, None]
                    frame = frame/255
                    frame = resize(frame, scaled_img_size)
                    frame = np.clip(frame-0.2, 0.0, 1.0) # darken image a bit
                else:
                    if color:
                        frame = np.zeros(scaled_img_size+[3], dtype=np.float32)
                    else:
                        frame = np.zeros(scaled_img_size+[1], dtype=np.float32)

                for j, pos in enumerate(poss):
                    rr, cc = circle(int(pos[1]*scale), int(pos[0]*scale), radius*scale, scaled_img_size)
                    if color:
                        frame[rr, cc, 2-j] = 1.0
                    else:
                        frame[rr, cc, 0] = 1.0

                frame = resize(frame, img_size, anti_aliasing=True)
                frame = (frame*255).astype(np.uint8)

                seq.append(frame)

                # rollout physics: pairwise inverse-square attraction
                for _ in range(ode_steps):
                    norm01 = np.linalg.norm(poss[0]-poss[1])
                    norm12 = np.linalg.norm(poss[1]-poss[2])
                    norm20 = np.linalg.norm(poss[2]-poss[0])
                    vec01 = (poss[0]-poss[1])
                    vec12 = (poss[1]-poss[2])
                    vec20 = (poss[2]-poss[0])

                    # Compute force vectors
                    F = [vec01/norm01**3-vec20/norm20**3,
                         vec12/norm12**3-vec01/norm01**3,
                         vec20/norm20**3-vec12/norm12**3]
                    F = np.array(F)
                    F = -g*m*m*F

                    # NOTE(review): the force is added to the velocity without
                    # dividing by `m`. That matches Newtonian dynamics only
                    # for m == 1; confirm this is intended.
                    vels = vels + dt/ode_steps*F
                    poss = poss + dt/ode_steps*vels

                    collision = any([verify_wall_collision(pos, vel, radius, img_size) for pos, vel in zip(poss, vels)]) or \
                                verify_object_collision(poss, radius+1)
                    if collision:
                        break

                if collision:
                    break

        return seq

    sequences = []
    for i in range(train_set_size+valid_set_size+test_set_size):
        if i % 100 == 0:
            print(i)
        sequences.append(generate_sequence())
    sequences = np.array(sequences, dtype=np.uint8)

    np.savez_compressed(dest,
                        train_x=sequences[:train_set_size],
                        valid_x=sequences[train_set_size:train_set_size+valid_set_size],
                        test_x=sequences[train_set_size+valid_set_size:])
    print("Saved to file %s" % dest)

    # Save 10 samples (one sequence per gallery row)
    result = gallery(np.concatenate(sequences[:10]/255), ncols=sequences.shape[1])

    norm = plt.Normalize(0.0, 1.0)
    fig, ax = plt.subplots(figsize=(sequences.shape[1], 10))
    ax.imshow(np.squeeze(result), interpolation='nearest', cmap=cm.Greys_r, norm=norm)
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    fig.tight_layout()
    # splitext strips only the extension; dest.split(".")[0] truncated paths
    # containing other dots (e.g. "./data/x.npz" -> "").
    fig.savefig(os.path.splitext(dest)[0]+"_samples.jpg")


## nn/datasets/iterators.py

In [None]:
import os
import time
import numpy as np
import tensorflow as tf

class DataIterator:
    """Mini-batch iterator over in-memory arrays, indexed along axis 0.

    ``next_batch`` walks a random permutation of the examples. When the next
    batch would run past the end, the permutation is reshuffled and the epoch
    counter is incremented, so a trailing batch smaller than ``batch_size`` is
    skipped (except when ``batch_size`` exceeds the number of examples).
    """

    def __init__(self, X, Y=None):
        """
        Args:
            X: array of inputs; examples along axis 0.
            Y: optional array of targets aligned with ``X`` along axis 0.
        """
        self.X = X
        self.Y = Y

        self.num_examples = self.X.shape[0]
        self.epochs_completed = 0
        self.indices = np.arange(self.num_examples)
        self.reset_iteration()

    def reset_iteration(self):
        """Reshuffle the visiting order and restart from its beginning."""
        np.random.shuffle(self.indices)
        self.start_idx = 0

    def get_epoch(self):
        """Return the number of completed passes over the data."""
        return self.epochs_completed

    def reset_epoch(self):
        """Reshuffle and reset the epoch counter to 0."""
        self.reset_iteration()
        self.epochs_completed = 0

    def next_batch(self, batch_size, data_type="train", shuffle=True):
        """Return the next ``(batch_x, batch_y)`` in the shuffled order.

        ``batch_y`` is None when the iterator has no targets. ``data_type`` is
        only validated, and ``shuffle`` is currently unused: the order is
        always reshuffled at each epoch boundary.

        Raises:
            AssertionError: if ``data_type`` is not 'train', 'val' or 'test'.
        """
        assert data_type in ["train", "val", "test"], \
            "data_type must be 'train', 'val', or 'test'."

        idx = self.indices[self.start_idx:self.start_idx + batch_size]

        batch_x = self.X[idx]
        batch_y = self.Y[idx] if self.Y is not None else self.Y
        self.start_idx += batch_size

        # Not enough examples left for another full batch: start a new epoch.
        if self.start_idx + batch_size > self.num_examples:
            self.reset_iteration()
            self.epochs_completed += 1

        return (batch_x, batch_y)

    def sample_random_batch(self, batch_size):
        """Return a random contiguous window of ``batch_size`` examples.

        The window start is drawn uniformly from every valid position,
        including the last one. The whole data is returned when ``batch_size``
        is at least the number of examples. The ``next_batch`` iteration state
        is not touched.
        """
        # randint's upper bound is exclusive, hence +1. The previous version
        # also sliced from self.start_idx instead of the sampled start.
        start_idx = np.random.randint(0, max(self.num_examples - batch_size, 0) + 1)
        end_idx = start_idx + batch_size
        batch_x = self.X[start_idx:end_idx]
        batch_y = self.Y[start_idx:end_idx] if self.Y is not None else self.Y

        return (batch_x, batch_y)


def get_iterators(file, conv=False, datapoints=0):
    """Load a generated ``.npz`` dataset and wrap each split in a DataIterator.

    Args:
        file: path to an archive holding ``train_x``, ``valid_x`` and
            ``test_x``, each shaped (num_sequences, seq_len, *frame_shape).
        conv: if True, keep each frame's shape; otherwise flatten every frame
            to a vector.
        datapoints: if > 0, keep only the first ``datapoints`` training
            sequences (for data-efficiency experiments). 0 (the default) uses
            all of them. Validation and test splits are never truncated.

    Returns:
        ``(train_it, valid_it, test_it)`` with pixel values divided by 255.
    """
    # The context manager closes the NpzFile; each array is decompressed once.
    with np.load(file) as data:
        train_x = data["train_x"]
        valid_x = data["valid_x"]
        test_x = data["test_x"]

    if conv:
        img_shape = train_x[0,0].shape
    else:
        img_shape = train_x[0,0].flatten().shape

    if datapoints and datapoints > 0:
        train_x = train_x[:datapoints]

    def _to_iterator(x):
        # (N, T, *frame) -> (N, T, *img_shape), scaled to [0, 1] for uint8 input
        return DataIterator(X=x.reshape(x.shape[:2]+img_shape)/255)

    train_it = _to_iterator(train_x)
    valid_it = _to_iterator(valid_x)
    test_it = _to_iterator(test_x)
    return train_it, valid_it, test_it


## runners/run_base.py

In [None]:
import os
import logging
import tensorflow as tf

# tf.compat.v1.app.flags.DEFINE_integer("epochs", 10, "Epochs to train.")
# tf.compat.v1.app.flags.DEFINE_integer("batch_size", 100, "Training batch size")
# tf.compat.v1.app.flags.DEFINE_string("save_dir", "experiments/spring_color/", "Directory to save checkpoint and logs.")
# tf.compat.v1.app.flags.DEFINE_bool("use_ckpt", False, "Whether to start from scratch of start from checkpoint.")
# tf.compat.v1.app.flags.DEFINE_string("ckpt_dir", "", "Checkpoint dir to use.")
# tf.compat.v1.app.flags.DEFINE_float("base_lr", 3e-4, "Base learning rate.")
# tf.compat.v1.app.flags.DEFINE_bool("anneal_lr", True, "Whether to anneal lr after 0.75 of total epochs.")
# tf.compat.v1.app.flags.DEFINE_string("optimizer", "rmsprop", "Optimizer to use.")
# tf.compat.v1.app.flags.DEFINE_integer("save_every_n_epochs", 5, "Epochs between checkpoint saves.")
# tf.compat.v1.app.flags.DEFINE_integer("eval_every_n_epochs", 1, "Epochs between validation run.")
# tf.compat.v1.app.flags.DEFINE_integer("print_interval", 10, "Print train metrics every n mini-batches.")
# tf.compat.v1.app.flags.DEFINE_bool("debug", False, "If true, eval is not ran before training.")
# tf.compat.v1.app.flags.DEFINE_bool("test_mode", False, "If true, only run test set.")

logger = logging.getLogger("tf")
logger.setLevel(logging.DEBUG)
# create console handler
ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(message)s')
ch.setFormatter(formatter)
logger.addHandler(ch)

## runners/run_physics.py

In [None]:
# Configuration Constants
# (notebook replacements for the commented-out tf.compat.v1.app.flags definitions)
EPOCHS = 10  # epochs to train
BATCH_SIZE = 100  # training batch size
SAVE_DIR = "experiments/spring_color/"  # directory to save checkpoints and logs
USE_CKPT = False  # start from a checkpoint (True) or from scratch (False)
CKPT_DIR = ""  # checkpoint directory to use
BASE_LR = 3e-4  # base learning rate
ANNEAL_LR = True  # anneal the learning rate after 0.75 of total epochs
OPTIMIZER = "rmsprop"  # optimizer to use
SAVE_EVERY_N_EPOCHS = 5  # epochs between checkpoint saves
EVAL_EVERY_N_EPOCHS = 1  # epochs between validation runs
PRINT_INTERVAL = 10  # print train metrics every n mini-batches
DEBUG = False  # if True, eval is not run before training
TEST_MODE = False  # if True, only run the test set (skip training)

TASK = "spring_color"  # key into the per-task settings table below
MODEL = "PhysicsNet"  # model to use -- NOTE(review): not read; `Model = PhysicsNet` is hard-coded below
RECURRENT_UNITS = 100  # units per LSTM, if using black-box dynamics
LSTM_LAYERS = 1  # number of LSTM cells, if using black-box dynamics
CELL_TYPE = ""  # NOTE(review): not read; the ODE cell type comes from the task table
ENCODER_TYPE = "conv_encoder"  # type of encoder to use
DECODER_TYPE = "conv_st_decoder"  # type of decoder to use

AUTOENCODER_LOSS = 3.0  # weight of the autoencoder loss term
ALT_VEL = False  # whether to use linear velocity computation
COLOR = True  # whether images are RGB (True) or grayscale (False)
DATAPOINTS = 0  # how many training datapoints to use (data-efficiency runs); 0 = all

In [None]:
import os
import logging
import inspect
import tensorflow as tf
# from nn.network import physics_models
# from nn.utils.misc import classes_in_module
# from nn.datasets.iterators import get_iterators
# import runners.run_base

In [None]:
# tf.compat.v1.app.flags.DEFINE_string("task", "spring_color", "Type of task.")
# tf.compat.v1.app.flags.DEFINE_string("model", "PhysicsNet", "Model to use.")
# tf.compat.v1.app.flags.DEFINE_integer("recurrent_units", 100, "Number of units for each lstm, if using black-box dynamics.")
# tf.compat.v1.app.flags.DEFINE_integer("lstm_layers", 1, "Number of lstm cells to use, if using black-box dynamics")
# tf.compat.v1.app.flags.DEFINE_string("cell_type", "", "Type of pendulum to use.")
# tf.compat.v1.app.flags.DEFINE_string("encoder_type", "conv_encoder", "Type of encoder to use.")
# tf.compat.v1.app.flags.DEFINE_string("decoder_type", "conv_st_decoder", "Type of decoder to use.")

# tf.compat.v1.app.flags.DEFINE_float("autoencoder_loss", 3.0, "Autoencoder loss weighing.")
# tf.compat.v1.app.flags.DEFINE_bool("alt_vel", False, "Whether to use linear velocity computation.")
# tf.compat.v1.app.flags.DEFINE_bool("color", True, "Whether images are rbg or grayscale.")
# tf.compat.v1.app.flags.DEFINE_integer("datapoints", 0, "How many datapoints from the dataset to use. \
#                                               Useful for measuring data efficiency. default=0 uses all data.")

# FLAGS = tf.compat.v1.app.flags.FLAGS

In [None]:
Model = PhysicsNet

# Per-task settings, each tuple holding:
#   (train .npz, test .npz with longer sequences, ODE cell type,
#    train seq_len, test seq_len, input_steps, pred_steps, input_size)
# Paths are relative to the Kaggle dataset root.
_TASK_SETTINGS = {
    "bouncing_balls": (
        "bouncing/color_bounce_vx8_vy8_sl12_r2.npz",
        "bouncing/color_bounce_vx8_vy8_sl30_r2.npz",
        "bouncing_ode_cell",
        12, 30, 4, 6, 32*32),
    "spring_color": (
        "spring_color/color_spring_vx8_vy8_sl12_r2_k4_e6.npz",
        "spring_color/color_spring_vx8_vy8_sl30_r2_k4_e6.npz",
        "spring_ode_cell",
        12, 30, 4, 6, 32*32),
    "spring_color_half": (
        "spring_color_half/color_spring_vx4_vy4_sl12_r2_k4_e6_halfpane.npz",
        "spring_color_half/color_spring_vx4_vy4_sl30_r2_k4_e6_halfpane.npz",
        "spring_ode_cell",
        12, 30, 4, 6, 32*32),
    "3bp_color": (
        "3bp_color/color_3bp_vx2_vy2_sl20_r2_g60_m1_dt05.npz",
        "3bp_color/color_3bp_vx2_vy2_sl40_r2_g60_m1_dt05.npz",
        "gravity_ode_cell",
        20, 40, 4, 12, 36*36),
    "mnist_spring_color": (
        "mnist_spring_color/color_mnist_spring_vx8_vy8_sl12_r2_k2_e12.npz",
        "mnist_spring_color/color_mnist_spring_vx8_vy8_sl30_r2_k2_e12.npz",
        "spring_ode_cell",
        12, 30, 3, 7, 64*64),
}

(data_file, test_data_file, cell_type,
 seq_len, test_seq_len, input_steps, pred_steps, input_size) = _TASK_SETTINGS[TASK]

def run_physics(data_root="/kaggle/input/physics-as-inverse-graphics"):
    """Train the configured model (unless TEST_MODE), then evaluate it.

    Training builds the graph for ``seq_len`` frames and fits on ``data_file``.
    Evaluation resets nothing between graphs itself: after training, the
    default graph is reset. It then rebuilds the model for ``test_seq_len``,
    initializes it with ``use_ckpt=True`` from SAVE_DIR / CKPT_DIR, and calls
    ``train`` with 0 epochs on ``test_data_file``.

    Args:
        data_root: directory containing the task sub-folders named in the
            task table (defaults to the Kaggle input dataset).
    """
    def _build_network(num_steps, use_ckpt):
        # Construct, add the optimizer to, and initialize a fresh model graph.
        net = Model(TASK, RECURRENT_UNITS, LSTM_LAYERS, cell_type,
                    num_steps, input_steps, pred_steps,
                    AUTOENCODER_LOSS, ALT_VEL, COLOR,
                    input_size, ENCODER_TYPE, DECODER_TYPE)
        net.build_graph()
        net.build_optimizer(BASE_LR, OPTIMIZER, ANNEAL_LR)
        net.initialize_graph(SAVE_DIR, use_ckpt, CKPT_DIR)
        return net

    def _load_iterators(npz_name):
        # `__file__` does not exist in a notebook kernel, and an absolute
        # second argument made the old os.path.join prefix a no-op anyway.
        # `FLAGS` is not defined either, so use the DATAPOINTS constant.
        return get_iterators(os.path.join(data_root, npz_name),
                             conv=True, datapoints=DATAPOINTS)

    if not TEST_MODE:
        network = _build_network(seq_len, USE_CKPT)
        network.get_data(_load_iterators(data_file))
        network.train(EPOCHS, BATCH_SIZE, SAVE_EVERY_N_EPOCHS, EVAL_EVERY_N_EPOCHS,
                      PRINT_INTERVAL, DEBUG)

        tf.compat.v1.reset_default_graph()

    # Evaluation on the longer test sequences, restored from the checkpoint.
    network = _build_network(test_seq_len, True)
    network.get_data(_load_iterators(test_data_file))
    network.train(0, BATCH_SIZE, SAVE_EVERY_N_EPOCHS, EVAL_EVERY_N_EPOCHS,
                  PRINT_INTERVAL, DEBUG)

In [41]:
# tf.compat.v1.disable_eager_execution()