# Localization

Use "localization" to learn a Cahn-Hilliard model.


## Learn a Cahn-Hilliard Model

Square domain with periodic boundary conditions.

## $$ \dot{\phi} = \nabla^2 \left( \phi^3 - \phi \right) - \gamma \nabla^4 \phi $$

## What are we trying to do?

Create a direct mapping from the state at $t_0$ to the state at $t_{10}$ without computing the intermediate time steps. That is, we want the following:

## $$ \phi[s](t=t_0) \rightarrow \phi[s](t=t_{10})$$

## Localization

Use regression for each local state.

## Create Samples

In [None]:
%matplotlib inline

import pymks
import matplotlib.pyplot as plt
import numpy as np
from pymks.datasets import make_cahn_hilliard

In [None]:
# Number of solver steps passed to make_cahn_hilliard for every sample.
n_steps = 10
# Grid dimensions of each simulated 2D field.
size = (151, 151)
# Generate 10 samples. X is presumably the initial field and y the field
# after `n_steps` steps -- confirm against the make_cahn_hilliard docs.
X, y = make_cahn_hilliard(n_samples=10, size=size, dt=1., n_steps=n_steps)

In [None]:
# Sanity-check the shapes of the generated arrays.
print(X.shape)
print(y.shape)

In [None]:
# NBVAL_IGNORE_OUTPUT

# Show the first field in X with a colour scale.
plt.imshow(X[0])
plt.colorbar()

In [None]:
# NBVAL_IGNORE_OUTPUT

# Show the corresponding field in y for visual comparison.
plt.imshow(y[0])
plt.colorbar()

## Parallel

In [None]:
from dask import compute, delayed
import dask.multiprocessing

def make_data(seed):
    """Generate one batch of Cahn-Hilliard samples for a given RNG seed.

    Parameters
    ----------
    seed : int
        Value passed to ``np.random.seed`` before the samples are generated.

    Returns
    -------
    The result of ``make_cahn_hilliard`` for 10 samples on the notebook-level
    ``size`` grid with ``n_steps`` steps (unpacked as ``X, y`` elsewhere in
    this notebook).
    """
    # np.random.seed mutates NumPy's process-wide RNG state, so the result is
    # only reproducible if nothing else reseeds or draws from that state
    # while this call is running.
    np.random.seed(seed)
    return make_cahn_hilliard(n_samples=10, size=size, dt=1., n_steps=n_steps)


# One lazy task per seed; nothing runs until `compute` is called.
funcs = [delayed(make_data)(seed) for seed in range(30)]

# Run the tasks in worker processes rather than threads. Threads share the
# global NumPy RNG, so concurrent tasks would reseed each other and the
# generated data would not be reproducible per seed.
out = compute(*funcs, scheduler="processes")

In [None]:
# Stack the per-seed results: axis 0 indexes the 30 seeds and axis 1 selects
# the first (0) or second (1) array returned by make_data.
np.array(out).shape

In [None]:
# Stack the per-seed results once instead of rebuilding the array for each
# slice.
stacked = np.array(out)
# Merge the seed and sample axes into one leading axis. Using -1 lets NumPy
# infer the total sample count, so it stays correct if the number of seeds
# or samples per seed changes.
X = stacked[:, 0].reshape((-1,) + size)
y = stacked[:, 1].reshape((-1,) + size)

## Learning

In [None]:
from pymks import MKSLocalizationModel
from pymks.bases import PrimitiveBasis

# Basis with 5 local states over the value range [-1, 1], used by the
# localization model to represent the field.
basis = PrimitiveBasis(n_states=5, domain=[-1, 1])
model = MKSLocalizationModel(basis=basis)

In [None]:
# Train on every sample except the last one, which is held out.
model.fit(X[:-1], y[:-1])

In [None]:
# Predict for the single held-out sample (slice keeps the leading axis).
y_pred = model.predict(X[-1:])

In [None]:
# NBVAL_IGNORE_OUTPUT
# Predicted field for the held-out sample.
plt.imshow(y_pred[0])
plt.colorbar()

In [None]:
# NBVAL_IGNORE_OUTPUT
# Reference field for the same held-out sample.
plt.imshow(y[-1])
plt.colorbar()

## Train Test Split

In [None]:
from sklearn.model_selection import train_test_split
from sklearn import metrics

# Split with scikit-learn's default proportions. No random_state is passed,
# so the partition (and every metric derived from it) varies between runs.
X_train, X_test, y_train, y_test = train_test_split(X, y)

In [None]:
# Fresh model so that nothing fitted in the previous section carries over.
basis = PrimitiveBasis(n_states=5, domain=[-1, 1])
model = MKSLocalizationModel(basis=basis)

In [None]:
# Fit on the training split only.
model.fit(X_train, y_train)

In [None]:
# Predict for the unseen test split.
y_pred = model.predict(X_test)

In [None]:
# NBVAL_IGNORE_OUTPUT

# Mean squared error over every value of every test sample. The arrays are
# flattened to 1D and passed in scikit-learn's (y_true, y_pred) order.
metrics.mean_squared_error(y_test.ravel(), y_pred.ravel())

In [None]:
# NBVAL_IGNORE_OUTPUT

# Compare the first 10 values of the first row, predicted vs. reference.
print(y_pred[0][0][:10])
print(y_test[0][0][:10])

In [None]:
# Shape of the prediction array for the test split.
y_pred.shape

In [None]:
# NBVAL_IGNORE_OUTPUT

# Predicted field for the first test sample.
plt.imshow(y_pred[0])
plt.colorbar()

In [None]:
# NBVAL_IGNORE_OUTPUT

# Reference field for the first test sample.
plt.imshow(y_test[0])
plt.colorbar()

## Scale Up

In [None]:
# One sample on a 1000x1000 grid, far larger than the 151x151 training grid.
X_big, y_big = make_cahn_hilliard(n_samples=1, size=(1000, 1000), dt=1., n_steps=n_steps)

In [None]:
# NBVAL_IGNORE_OUTPUT

# Reference result for the large sample.
plt.imshow(y_big[0])
plt.colorbar()

In [None]:
# Fit on the full small-grid data set, then call resize_coeff with the large
# grid's shape so the fitted model can be applied to X_big.
basis = PrimitiveBasis(n_states=5, domain=[-1, 1])
model = MKSLocalizationModel(basis=basis)
model.fit(X, y)
model.resize_coeff(y_big[0].shape)

In [None]:
# Predict on the large grid with the resized model.
y_big_pred = model.predict(X_big)

In [None]:
# NBVAL_IGNORE_OUTPUT

# Predicted field for the large sample.
plt.imshow(y_big_pred[0])
plt.colorbar()

In [None]:
# NBVAL_IGNORE_OUTPUT

# Mean squared error over every value of the large sample, with the arrays
# flattened to 1D and passed in scikit-learn's (y_true, y_pred) order.
metrics.mean_squared_error(y_big.ravel(), y_big_pred.ravel())

In [None]:
# NBVAL_IGNORE_OUTPUT

# Time generating one 1000x1000 sample with the direct simulation.
%timeit make_cahn_hilliard(n_samples=1, size=(1000, 1000), dt=1., n_steps=n_steps)

In [None]:
# NBVAL_IGNORE_OUTPUT

# Time the fitted model's prediction on the 1000x1000 input for comparison.
%timeit model.predict(X_big)

## Multiple Steps

In [None]:
# Reference data generated with twice as many steps as the training data.
X2, y2 = make_cahn_hilliard(n_samples=1, size=size, dt=1., n_steps=2 * n_steps)

# This section uses 10 local states; the earlier sections use 5.
basis = PrimitiveBasis(n_states=10, domain=[-1, 1])
model = MKSLocalizationModel(basis=basis)
model.fit(X, y)

In [None]:
# The model is fitted on n_steps of evolution, so it is applied twice in
# succession to approximate 2 * n_steps.
tmp = model.predict(X2)
y2_pred = model.predict(tmp)

In [None]:
# NBVAL_IGNORE_OUTPUT

# Reference field after 2 * n_steps steps.
plt.imshow(y2[0])
plt.colorbar()

In [None]:
# NBVAL_IGNORE_OUTPUT

# Field obtained by applying the model twice.
plt.imshow(y2_pred[0])
plt.colorbar()

In [None]:
# NBVAL_IGNORE_OUTPUT

# Mean squared error of the two-pass prediction, with the arrays flattened
# to 1D and passed in scikit-learn's (y_true, y_pred) order.
metrics.mean_squared_error(y2.ravel(), y2_pred.ravel())

## Cross Validation

In [None]:
from pymks.bases import LegendreBasis
from sklearn.model_selection import GridSearchCV
from sklearn import metrics
mse = metrics.mean_squared_error


def _negated_flat_mse(y_true, y_pred):
    """Return the negated MSE computed over all flattened values.

    The grid search ranks candidates by highest score, so the error is
    negated to make a smaller error rank higher.
    """
    return -mse(y_true.flatten(), y_pred.flatten())


# Candidate basis objects offered to the search.
prim_basis = PrimitiveBasis(n_states=2, domain=[-1, 1])
leg_basis = LegendreBasis(n_states=2, domain=[-1, 1])

# Every listed number of local states is combined with both basis types.
params_to_tune = {
    'n_states': [2, 3, 5, 8, 13],
    'basis': [prim_basis, leg_basis],
}
model = MKSLocalizationModel(basis=prim_basis)
score_func = metrics.make_scorer(_negated_flat_mse)
gscv = GridSearchCV(
    estimator=model,
    param_grid=params_to_tune,
    scoring=score_func,
    cv=5,
)

In [None]:
# Show the GridSearchCV help text (IPython `?` introspection).
?GridSearchCV

In [None]:
# NBVAL_SKIP

# Run the grid search on the training split only.
gscv.fit(X_train, y_train)

In [None]:
# NBVAL_SKIP

# Estimator configured with the best-scoring parameter combination.
gscv.best_estimator_

In [None]:
# NBVAL_SKIP

# Score on the held-out test split, using the negated-MSE scorer given to
# the search.
gscv.score(X_test, y_test)

In [None]:
# NBVAL_SKIP

# Full per-candidate results recorded by the search.
gscv.cv_results_