In [None]:
import os
#os.environ['CUDA_VISIBLE_DEVICES'] = "0"
import torch

import fit.sine_pde_dense as T
import numpy as np

from IPython.display import HTML
from matplotlib.animation import FuncAnimation
import matplotlib.pyplot as plt

%matplotlib inline
# Notebook-wide tensor printing: scientific notation with 4 digits after the
# decimal point, and a 500-character line width before rows are wrapped.
torch.set_printoptions(
    sci_mode=True,
    precision=4,
    linewidth=500,
)

### Data Fitting Example
### 1. Fit PDE with time varying source term

Learns a PDE with time-space invariant coefficient and a time-space varying source term. This makes the model flexible and easier to learn.

In [None]:
# Build the experiment with the time-varying source term switched on.
# `method` is the training wrapper (it exposes `.model`); `dataset` holds the data to fit.
method, dataset = T.create_model(time_varying_source=True)

In [None]:
# Train for 100 epochs. The cells below read `func_list` / `y_list` from `method`
# afterwards (presumably filled in during training — TODO confirm in fit.sine_pde_dense).
T.train(method, dataset, epochs=100)

In [None]:
# Grid shape; the plotting cells use it to reshape each flat history entry.
coord_dims = method.model.coord_dims

# Dataset fields pulled out for convenient inspection.
target, damp = dataset.y, dataset.damp

In [None]:
# Histories stored on the training wrapper; one entry per animation frame in
# the plotting cell that follows.
func_list, y_list = method.func_list, method.y_list

In [None]:
fig, ax = plt.subplots(1, 2)

# Draw both panels once from the final history entries. pcolormesh autoscales
# each panel's colour limits from this initial data and set_array() does not
# rescale, so every frame is rendered on the final entry's colour scale.
cax0 = ax[0].pcolormesh(y_list[-1].reshape(*coord_dims), cmap='RdBu', shading='gouraud')
cax1 = ax[1].pcolormesh(func_list[-1].reshape(*coord_dims), cmap='RdBu', shading='gouraud')

# The axis decorations never change between frames, so hide them once here
# rather than on every animation step.
for panel in ax:
    panel.axis('off')

# animate() indexes both histories with the same frame number; stop at the
# shorter one so a length mismatch cannot raise an IndexError mid-render.
n_frames = min(len(y_list), len(func_list))


def animate(i):
    """Load history entry ``i`` into the two meshes.

    Left panel shows ``y_list[i]``, right panel shows ``func_list[i]``.
    Returns the updated artists so the callback also works with ``blit=True``.
    """
    cax0.set_array(y_list[i].reshape(*coord_dims).flatten())
    cax1.set_array(func_list[i].reshape(*coord_dims).flatten())
    return cax0, cax1


anim = FuncAnimation(fig, animate, interval=100, frames=n_frames)

# Encode the video first, then close the figure so the inline backend does not
# also emit a static copy of it underneath the video.
video_html = anim.to_html5_video()
plt.close(fig)
HTML(video_html)

### 2. Fit without Source Term

Fits a PDE with time-space invariant coefficient and no source term.

In [None]:
# Drop the references to the first experiment before building the second one
# (presumably so its memory can be reclaimed — TODO confirm).
del method, dataset
# Same factory as before, this time with the time-varying source term switched off.
method, dataset = T.create_model(time_varying_source=False)

In [None]:
# Sanity check: display the flag stored on the freshly built model
# (should reflect the time_varying_source=False passed to create_model — verify in the output).
method.model.time_varying_source

In [None]:
# Train the source-free model for 100 epochs, the same budget as the first experiment.
T.train(method, dataset, epochs=100)

In [None]:
# Re-read the grid shape from the current model. Without this, the plotting
# cell below would silently reuse the value captured from the first model,
# which has since been deleted.
coord_dims = method.model.coord_dims

# Histories stored on the training wrapper; one entry per animation frame.
func_list = method.func_list
y_list = method.y_list

In [None]:
fig, ax = plt.subplots(1, 2)

# Draw both panels once from the final history entries. pcolormesh autoscales
# each panel's colour limits from this initial data and set_array() does not
# rescale, so every frame is rendered on the final entry's colour scale.
cax0 = ax[0].pcolormesh(y_list[-1].reshape(*coord_dims), cmap='RdBu', shading='gouraud')
cax1 = ax[1].pcolormesh(func_list[-1].reshape(*coord_dims), cmap='RdBu', shading='gouraud')

# The axis decorations never change between frames, so hide them once here
# rather than on every animation step.
for panel in ax:
    panel.axis('off')

# animate() indexes both histories with the same frame number; stop at the
# shorter one so a length mismatch cannot raise an IndexError mid-render.
n_frames = min(len(y_list), len(func_list))


def animate(i):
    """Load history entry ``i`` into the two meshes.

    Left panel shows ``y_list[i]``, right panel shows ``func_list[i]``.
    Returns the updated artists so the callback also works with ``blit=True``.
    """
    cax0.set_array(y_list[i].reshape(*coord_dims).flatten())
    cax1.set_array(func_list[i].reshape(*coord_dims).flatten())
    return cax0, cax1


anim = FuncAnimation(fig, animate, interval=100, frames=n_frames)

# Encode the video first, then close the figure so the inline backend does not
# also emit a static copy of it underneath the video.
video_html = anim.to_html5_video()
plt.close(fig)
HTML(video_html)