In [None]:
from phi.flow import *
from matplotlib import pyplot as plt
import numpy as np
import random

In [None]:
# --- Simulation setup -------------------------------------------------------
domain = Domain([64, 64])  # simulation domain handed to the PDE below
dt = 1.0                   # time increment handed to the PDE below
step_count = 16            # number of solver steps to perform per sequence

# --- Dataset location and disjoint index splits -----------------------------
data_path = 'shape-transitions'
test_range = range(0, 100)       # indices 0..99
val_range = range(100, 200)      # indices 100..199
train_range = range(200, 1000)   # indices 200..999

In [None]:
# Map every network name to the checkpoint path of its supervised pre-training run.
# The four OP networks share one naming scheme; CFE lives in its own directory.
supervised_checkpoints = dict(
    ('OP%d' % n, 'ckpts/supervised/OP%d_1000' % n)
    for n in (2, 4, 8, 16)
)
supervised_checkpoints.update({'CFE': 'ckpts/CFE/CFE_1000'})

In [None]:
from src.control.control_training import ControlTraining
from src.control.pde.incompressible_flow import IncompressibleFluidPDE
from src.control.sequences import StaggeredSequence

# Build the training app: all five networks are trainable, sequences are executed
# with StaggeredSequence, and the observation loss is applied at frame `step_count`.
staggered_app = ControlTraining(
    step_count,
    IncompressibleFluidPDE(domain, dt),
    datapath=data_path,
    val_range=val_range,
    train_range=train_range,
    trace_to_channel=lambda _: 'density',  # every trace maps to the 'density' channel
    obs_loss_frames=[step_count],
    trainable_networks=['CFE', 'OP2', 'OP4', 'OP8', 'OP16'],
    sequence_class=StaggeredSequence,
    learning_rate=5e-4,
)
staggered_app = staggered_app.prepare()

# Start from the supervised weights instead of a fresh initialization.
staggered_app.load_checkpoints(supervised_checkpoints)

# 1000 optimization steps; each progress() call runs staggered optimization for one batch.
for i in range(1000):
    staggered_app.progress()

# Keep whatever save_model() returns so the trained model can be referenced later.
staggered_checkpoint = staggered_app.save_model()

In [None]:
# Run the trained networks on the test indices and keep the resulting states.
# Later cells index `states` by frame number (0 .. 16) and read `.density.data`
# with the example index as the first axis.
states = staggered_app.infer_all_frames(test_range)

In [None]:
import pylab

# Example indices (positions within the inferred test batch) to visualize, one row each.
batches = [5]

# One row per example: nine solution frames followed by the target in the last column.
# squeeze=False keeps `axes` two-dimensional even when only a single row is plotted.
fig, axes = pylab.subplots(len(batches), 10, sharey='row', sharex='col',
                           figsize=(14, 6), squeeze=False)

# solutions: every second frame, t = 0, 2, ..., 16
for i, batch in enumerate(batches):
    for t in range(9):
        axes[i, t].set_title('t=%d' % (t * 2))
        axes[i, t].imshow(states[t * 2].density.data[batch, ..., 0], origin='lower')

# add targets
# NOTE(review): testset[1] is assumed to hold the target densities with the example
# index as first axis, matching how it is used in the error computation - confirm.
testset = BatchReader(Dataset.load(staggered_app.data_path, test_range), staggered_app._channel_struct)[test_range]
for i, batch in enumerate(batches):
    axes[i, 9].set_title('target')
    # Index with the example id `batch`, not the subplot row `i`; using `i` showed the
    # target of example 0 next to the solution frames of example 5.
    axes[i, 9].imshow(testset[1][batch, ..., 0], origin='lower')

# Apply the layout last so that the titles added above are taken into account.
fig.tight_layout(w_pad=0)


In [None]:
# Relative mean absolute error per test example:
#   MAE(final frame, target) / MAE(initial frame, target)
# Values below 1 mean the final density is closer to the target than the initial one.
errors = []
# Iterate over plain positions in the batch axis. The previous loop iterated over
# enumerate(test_range), which yields (index, value) tuples; using such a tuple as an
# array index only produced the intended result by accident while test_range started at 0.
for batch in range(len(test_range)):
    target = testset[1][batch, ..., 0]
    initial = np.mean(np.abs(states[0].density.data[batch, ..., 0] - target))
    # The final frame is frame `step_count` (previously hard-coded as 16).
    solution = np.mean(np.abs(states[step_count].density.data[batch, ..., 0] - target))
    errors.append(solution / initial)
print("Relative MAE: "+format(np.mean(errors)))