## Setup

In [None]:
# %%capture
# ! pip install -q equinox optax

In [None]:
# ! git remote remove origin
# ! git init .
# ! git remote add origin https://github.com/arudikov/PNO
# ! git pull origin main

In [None]:
import warnings
warnings.filterwarnings('ignore')

In [None]:
import optax
import itertools
import sympy as sp
import equinox as eqx
import jax.numpy as jnp
import matplotlib.pyplot as plt

from datasets.Elliptic import solvers
from architectures.DilResNet import DilatedResNet
from architectures.FNO import FNO
from architectures.DilResNet import make_step as make_step_Dil
from architectures.FNO import make_step as make_step_FNO
from tqdm import tqdm
from IPython import display
from jax import config, random, grad, jit, vmap
from jax.lax import scan
from functools import partial

%matplotlib inline
%config InlineBackend.figure_format='retina'

## Useful code: `batch_generator`, `compute_loss`, `train_on_epoch`

In [None]:
def batch_generator(x, y, batch_size, key, shuffle=True):
    """Yield consecutive mini-batches ``(x_batch, y_batch)`` of paired samples.

    Parameters
    ----------
    x, y : array-like
        Inputs and targets, indexed along their leading axis. ``y`` must hold
        at least ``len(x)`` samples.
    batch_size : int
        Maximum number of samples per batch. The last batch is shorter when
        ``len(x)`` is not a multiple of ``batch_size``.
    key : jax PRNG key
        Drives the permutation; only used when ``shuffle`` is true.
    shuffle : bool
        Visit the samples in a random order instead of the stored order.

    Yields
    ------
    tuple
        ``(x_batch, y_batch)`` arrays sharing the same leading dimension.
    """
    N_samples = len(x)

    if shuffle:
        # JAX arrays are immutable, so the permuted indices must be taken from
        # the return value (the previous in-place style call had no effect).
        order = random.permutation(key, N_samples)
    else:
        order = jnp.arange(N_samples)

    list_x = jnp.asarray(x)[order]
    list_y = jnp.asarray(y)[order]

    # Slicing past the end truncates, which yields the shorter final batch.
    for start in range(0, N_samples, batch_size):
        stop = start + batch_size
        yield list_x[start:stop], list_y[start:stop]

In [None]:
def compute_loss(model, input, target):
    """Mean relative L2 error of ``model(input)`` with respect to ``target``.

    Each sample (leading axis of ``input``) is flattened; the L2 norm of the
    prediction error is divided by the L2 norm of the target, and the ratios
    are averaged over the batch.
    """
    n_samples = input.shape[0]
    residual = (model(input) - target).reshape(n_samples, -1)
    reference = target.reshape(n_samples, -1)
    relative_error = jnp.linalg.norm(residual, axis=1) / jnp.linalg.norm(reference, axis=1)
    return jnp.mean(relative_error)

# compute_loss wrapped by equinox so that one call yields the loss value together with its gradients.
compute_loss_and_grads = eqx.filter_value_and_grad(compute_loss)   

In [None]:
def train_on_epoch(train_generator, model, make_step, optimizer, opt_state, n_iter):
    """Run one pass over ``train_generator``, taking one optimisation step per batch.

    ``make_step(model, x, y, optimizer, opt_state)`` must return
    ``(loss, model, opt_state)``; ``loss.item()`` is recorded for every batch.

    Returns
    -------
    tuple
        ``(batch_losses, model, opt_state, n_iter)`` where ``n_iter`` has been
        advanced by the number of batches consumed.
    """
    losses = []
    for inputs, targets in train_generator:
        step_loss, model, opt_state = make_step(model, inputs, targets, optimizer, opt_state)
        losses.append(step_loss.item())

    return losses, model, opt_state, n_iter + len(losses)

## Generate datasets for the `train` and `validation` parts

In [None]:
def dataset(N_points, key):
    """Draw one random coefficient set and solve the corresponding boundary-value problem.

    Three vectors of five complex normal amplitudes are sampled from ``key``
    and turned into truncated Fourier series ``a``, ``c`` and ``f`` (``d`` is
    identically zero). The callables ``[a, d, c, f]`` and boundary values
    ``[0, 0]`` are handed to ``solvers.solve_BVP``.

    Returns
    -------
    tuple
        ``(features, solution)``: ``features`` stacks ``-a(x)``, ``c(x)`` and
        ``f(x)`` row-wise on ``jnp.linspace(0, 1, N_points)``; ``solution`` is
        whatever ``solvers.solve_BVP`` returns.
    """
    def fourier_series(x, amplitudes):
        # sum_k amplitudes[k] * exp(2*pi*i*k*x) over the available modes
        modes = [jnp.exp(1j * 2 * jnp.pi * x * k) * amplitudes[k] for k in range(amplitudes.size)]
        return jnp.sum(jnp.stack(modes, 0), 0)

    coef_a, coef_c, coef_f = [
        random.normal(subkey, (5,), dtype=jnp.complex128)
        for subkey in random.split(key, 3)
    ]

    def a(x):
        # squared real part shifted by 0.5, then negated
        return -(jnp.real(0.5 * fourier_series(x, coef_a))**2 + 0.5)

    def d(x):
        return jnp.zeros_like(x)

    def c(x):
        return jnp.real(0.2 * fourier_series(x, coef_c))

    def f(x):
        return jnp.real(fourier_series(x, coef_f))

    grid = jnp.linspace(0, 1, N_points)
    features = jnp.vstack((-a(grid), c(grid), f(grid)))
    solution = solvers.solve_BVP(N_points, [a, d, c, f], [0, 0])

    return features, solution

In [None]:
def train_dataset(N_samples, N_points=100):
    """Build the training set: ``N_samples`` problems drawn from keys split off ``PRNGKey(42)``.

    Returns ``[features, targets]`` as two stacked arrays, one entry per
    sample, in the order the keys were generated.
    """
    sample_keys = random.split(random.PRNGKey(42), N_samples)
    samples = [dataset(N_points, sample_key) for sample_key in tqdm(iter(sample_keys))]

    features = [feature for feature, _ in samples]
    targets = [solution for _, solution in samples]
    return [jnp.array(features), jnp.array(targets)]

In [None]:
def validation_dataset(N_samples, N_points=100):
    """Build the validation set from ``N_samples`` integer seeds.

    The seeds are drawn with ``random.randint`` in ``[0, 10000)`` from
    ``PRNGKey(10)`` (so repeated seeds are possible), and every seed is turned
    into a key with ``random.PRNGKey(seed)``. Each sample is produced by
    ``dataset`` so that training and validation problems always share one
    definition of the coefficients, features and solver call.

    Returns
    -------
    list
        ``[features, targets]`` as two stacked arrays, one entry per seed.
    """
    seeds = random.randint(random.PRNGKey(10), (1, N_samples), 0, 10000)[0]

    features, targets = [], []
    for seed in tqdm(seeds):
        # Same key handling as before: the seed is used directly as a PRNG seed.
        feature, solution = dataset(N_points, random.PRNGKey(seed))
        features.append(feature)
        targets.append(solution)

    return [jnp.array(features), jnp.array(targets)]

In [None]:
def generator(key, N_points=100):
    """Sample the input features of one random problem without solving it.

    Three vectors of five complex normal amplitudes are drawn from ``key`` and
    evaluated as truncated Fourier series on ``jnp.linspace(0, 1, N_points)``.

    Returns
    -------
    array of shape ``(3, N_points)``
        Rows are ``-a(x)``, ``c(x)`` and ``f(x)``, where
        ``-a = Re(0.5 * S_a)**2 + 0.5``, ``c = Re(0.2 * S_c)`` and
        ``f = Re(S_f)`` for the respective series ``S``.
    """
    coef_a, coef_c, coef_f = [
        random.normal(subkey, (5,), dtype=jnp.complex128)
        for subkey in random.split(key, 3)
    ]
    x = jnp.linspace(0, 1, N_points)

    def fourier_series(amplitudes):
        # sum_k amplitudes[k] * exp(2*pi*i*k*x) on the grid
        modes = [jnp.exp(1j * 2 * jnp.pi * x * k) * amplitudes[k] for k in range(amplitudes.size)]
        return jnp.sum(jnp.stack(modes, 0), 0)

    minus_a = jnp.real(0.5 * fourier_series(coef_a))**2 + 0.5
    c_values = jnp.real(0.2 * fourier_series(coef_c))
    f_values = jnp.real(fourier_series(coef_f))

    return jnp.vstack((minus_a, c_values, f_values))

## Define basic parameters of model for training

In [None]:
class Model:
    """Bundle a network, its ``make_step`` function and the training hyper-parameters.

    Parameters
    ----------
    model_name : str
        ``'DilResNet'`` or ``'FNO'``; selects the architecture and the
        matching ``make_step``.
    params_of_model : dict
        Keyword arguments forwarded to the architecture constructor.
    params_of_learning : dict
        Must contain ``'batch_size'``, ``'learning_rate'``, ``'epochs'``,
        ``'finetuning_epochs'`` and ``'finetuning_lr'``.

    Raises
    ------
    ValueError
        If ``model_name`` is not one of the supported architectures.

    Attributes set here: ``model``, ``make_step``, the hyper-parameters, and
    ``history`` (list of float losses, one entry per recorded point).
    """

    def __init__(self, model_name, params_of_model, params_of_learning):
        if model_name == 'DilResNet':
            self.model = DilatedResNet(key = random.PRNGKey(42), **params_of_model)
            self.make_step = make_step_Dil

        elif model_name == 'FNO':
            self.model = FNO(**params_of_model, key=random.PRNGKey(42))
            self.make_step = make_step_FNO

        else:
            # Fail here instead of leaving an object without `model`/`make_step`.
            raise ValueError(f"Unknown model_name {model_name!r}; expected 'DilResNet' or 'FNO'")

        self.batch_size = params_of_learning['batch_size']
        self.lr = params_of_learning['learning_rate']
        self.count_of_epoch = params_of_learning['epochs']
        self.fine_epoch = params_of_learning['finetuning_epochs']
        self.fine_lr = params_of_learning['finetuning_lr']
        # Defined up front so fine_trainer also works without a prior trainer call.
        self.history = []

    def _make_optimizer(self, learning_rate):
        # AdamW wrapped so that complex parameters are handled as real/imaginary pairs.
        return optax.experimental.split_real_and_imaginary(optax.adamw(learning_rate, weight_decay=1e-2))

    def _train_epoch(self, data, epoch, optimizer, opt_state, n_iter):
        # One shuffled pass over `data`; the shuffle key is derived from the epoch index.
        batches = batch_generator(x=data[0], y=data[1], batch_size=self.batch_size,
                                  key=random.PRNGKey(epoch), shuffle=True)
        epoch_loss, self.model, opt_state, n_iter = train_on_epoch(train_generator=batches,
                                                                   model=self.model,
                                                                   make_step=self.make_step,
                                                                   optimizer=optimizer,
                                                                   opt_state=opt_state,
                                                                   n_iter=n_iter)
        return epoch_loss, opt_state, n_iter

    def _plot_history(self):
        # Redraw the loss curve in place of the previous cell output.
        display.clear_output(wait=True)
        fig = plt.figure(figsize=(10, 5))
        plt.title(r'Loss')
        plt.yscale("log")
        plt.plot(self.history, color='red', label='train')
        plt.legend()
        plt.tight_layout()
        plt.show()
        plt.close(fig)

    def trainer(self, dataset, plot=True):
        """Train ``self.model`` on ``dataset = [features, targets]`` for ``epochs`` epochs.

        ``self.history`` is reset to the initial loss and then extended with
        the mean batch loss of every epoch. The learning rate is halved every
        50 epochs (20 boundaries, up to epoch 1000).
        """
        loss = compute_loss(self.model, dataset[0], dataset[1])
        self.history = [loss.item(), ]

        # batch_generator yields ceil(N / batch_size) batches per epoch. Floor
        # division under-counts the steps and collapses every boundary onto
        # step 0 when N < batch_size.
        steps_per_epoch = max(1, -(-dataset[0].shape[0] // self.batch_size))
        dict_lr = {50 * k * steps_per_epoch: 0.5 for k in range(1, 21)}

        sc = optax.piecewise_constant_schedule(self.lr, dict_lr)
        optimizer = self._make_optimizer(sc)
        opt_state = optimizer.init(eqx.filter(self.model, eqx.is_array))

        iterations = tqdm(range(self.count_of_epoch), desc='epoch')
        iterations.set_postfix({'train epoch loss': jnp.nan})
        n_iter = 0

        for it in iterations:
            epoch_loss, opt_state, n_iter = self._train_epoch(dataset, it, optimizer, opt_state, n_iter)

            # Store plain floats so history has one consistent element type.
            mean_loss = float(jnp.array(epoch_loss).mean())
            iterations.set_postfix({'train epoch loss': mean_loss})
            self.history.append(mean_loss)

            if plot:
                self._plot_history()

    def fine_trainer(self, dataset_fine, dataset, plot=True):
        """Fine-tune on ``dataset_fine`` while tracking the loss on both datasets combined.

        Optimisation steps use only ``dataset_fine``; the value appended to
        ``self.history`` before training and after every epoch is the loss on
        the concatenation of ``dataset_fine`` and ``dataset``.
        """
        combined = [jnp.concatenate([dataset_fine[0], dataset[0]],axis=0),
                    jnp.concatenate([dataset_fine[1], dataset[1]],axis=0)]
        loss = compute_loss(self.model, combined[0], combined[1])
        self.history.append(loss.item())

        optimizer = self._make_optimizer(self.fine_lr)
        opt_state = optimizer.init(eqx.filter(self.model, eqx.is_array))

        iterations = tqdm(range(self.fine_epoch), desc='epoch')
        iterations.set_postfix({'train epoch loss': jnp.nan})
        n_iter = 0

        for it in iterations:
            epoch_loss, opt_state, n_iter = self._train_epoch(dataset_fine, it, optimizer, opt_state, n_iter)

            iterations.set_postfix({'train epoch loss': float(jnp.array(epoch_loss).mean())})
            loss = compute_loss(self.model, combined[0], combined[1])
            self.history.append(loss.item())

            if plot:
                self._plot_history()


## Upper bound, Derivatives

In [None]:
@jit
def derivative(a, h):
    '''
    First derivative of samples given on a uniform grid (along axis 0).

    a: sampled values, a.shape = (N_x)
    h: grid spacing

    Interior points use the central difference (a[i+1] - a[i-1]) / (2h);
    the two end points use three-point one-sided differences.
    '''
    ahead = jnp.roll(a, -1, axis=0)
    behind = jnp.roll(a, 1, axis=0)
    d_a = (ahead - behind) / (2*h)

    # one-sided stencils: (-3/2, 2, -1/2) at the left end, (1/2, -2, 3/2) at the right end
    left_edge = (-3*a[0]/2 + 2*a[1] - a[2]/2)/h
    right_edge = (a[-3]/2 - 2*a[-2] + 3*a[-1]/2)/h

    return d_a.at[0].set(left_edge).at[-1].set(right_edge)

In [None]:
# Module-level Adam optimizer; u_update_step reads this global `optim`
# when it updates the auxiliary variable y.
learning_rate = 1e-4
optim = optax.adam(learning_rate)

def u_update_step(carry, i):
    """One optimizer step on ``y`` that lowers ``upper_bound``; written as a ``scan`` body.

    Parameters
    ----------
    carry : list
        ``[y, u, a, b, f, C_f, opt_state]``; only ``y`` and ``opt_state``
        change between steps.
    i : scalar
        Step index supplied by ``scan`` (unused).

    Returns
    -------
    tuple
        ``(carry, value)`` where ``carry`` has the same seven-entry layout as
        the input and ``value`` is ``upper_bound`` evaluated at the updated ``y``.
    """
    y, u, a, b, f, C_f, opt_state = carry
    d_y = grad(upper_bound)(y, u, a, b, f, C_f)
    # Uses the module-level `optim`; its update is added to y directly.
    y_update, opt_state = optim.update(d_y, opt_state, y)
    y = y + y_update
    # scan requires the output carry to match the input structure, so `b`
    # has to be passed along as well (it was missing from the returned list).
    return [y, u, a, b, f, C_f, opt_state], upper_bound(y, u, a, b, f, C_f)

def u_optimize(y, u, a, b, f, C_f, opt_state, N_sweeps):
    """Apply ``u_update_step`` ``N_sweeps`` times via ``scan``.

    Returns the final ``y`` together with the per-step values emitted by
    ``u_update_step``.
    """
    initial_state = [y, u, a, b, f, C_f, opt_state]
    sweep_ids = jnp.arange(N_sweeps)
    final_state, loss_history = scan(u_update_step, initial_state, sweep_ids)
    y_optimized = final_state[0]
    return y_optimized, loss_history

In [None]:
def upper_bound(y, u, a, b, f, C_f, N_points=100):
    """Evaluate ``sqrt(int (y - a u')^2 / a) + C_f * sqrt(int (f + y' - b u)^2)``.

    Derivatives come from ``derivative`` and integrals from the trapezoidal
    rule, both with spacing ``1 / N_points``.
    """
    # NOTE(review): the feature grid is jnp.linspace(0, 1, N_points), whose spacing
    # is 1 / (N_points - 1); confirm that 1 / N_points is the intended step here.
    # NOTE(review): jnp.trapz is deprecated in recent JAX releases in favour of
    # jax.scipy.integrate.trapezoid; verify against the installed version.
    h = 1 / N_points
    flux_residual = y - a * derivative(u, h)
    balance_residual = f + derivative(y, h) - b * u

    flux_term = jnp.sqrt(jnp.trapz(flux_residual**2 / a, dx=h))
    balance_term = jnp.sqrt(jnp.trapz(balance_residual**2, dx=h))
    return flux_term + C_f * balance_term

def estimate_upper_bound(model, input, N_sweeps, lr=1e-4, N_points=100):
    """Minimise ``upper_bound`` over ``y`` for the model's prediction on one sample.

    Parameters
    ----------
    model : callable
        Called on ``input`` reshaped to ``(1, 3, -1)``.
    input : array of shape ``(3, N)``
        Rows are used as ``a``, ``b`` and ``f``.
    N_sweeps : int
        Number of optimisation steps on ``y``.
    lr : float
        Learning rate of the Adam instance that creates the initial optimizer
        state.
    N_points : int
        Grid size used for the derivative step and the final bound.

    Returns
    -------
    scalar
        ``upper_bound`` at the optimised ``y``.
    """
    output = model(input.reshape(1,3,-1))
    # NOTE(review): assumes the model returns (1, N) for a single sample — confirm.
    u = output[0, :]
    a, b, f = input[0, :], input[1, :], input[2, :]
    # `a` is one-dimensional here, so the maximum is taken over the whole vector.
    C_f = jnp.max(a) / jnp.pi

    # NOTE(review): u_update_step applies updates with the module-level `optim`,
    # so `lr` only enters through this state initialisation — confirm intent.
    init_optim = optax.adam(lr)

    # The optimisation starts from the flux implied by the prediction.
    y_init = a*derivative(u, 1 / N_points)
    opt_state = init_optim.init(y_init)
    y, history = u_optimize(y_init, u, a, b, f, C_f, opt_state, N_sweeps)

    # upper_bound takes the grid size as its last argument, not the optimizer state.
    u_bound = upper_bound(y, u, a, b, f, C_f, N_points)
    return u_bound

## Fine tuning

In [None]:
def finetuning(iterations, N_new_samples, model_name, params_of_model, params_of_learning, fine_params, N_points=100):
    """Train a model, then repeatedly fine-tune it on samples with a large error bound.

    Parameters
    ----------
    iterations : int
        Maximum number of candidate samples to generate.
    N_new_samples : int
        Stop once this many samples have been added to the training set.
    model_name, params_of_model, params_of_learning
        Forwarded to ``Model``.
    fine_params : dict
        Requires ``'N_samples_train'``, ``'N_samples_val'`` and ``'tune_step'``
        (number of accepted samples per fine-tuning round). Optional:
        ``'N_sweeps'`` (default 1000) and ``'tolerance'`` (default 1e-2).
    N_points : int
        Grid size used for every dataset, candidate and bound estimate.

    Returns
    -------
    list
        Validation loss after the initial training and after every
        fine-tuning round.
    """
    tune_step = fine_params['tune_step']
    n_sweeps = fine_params.get('N_sweeps', 1000)
    tolerance = fine_params.get('tolerance', 1e-2)
    target_size = N_new_samples + fine_params['N_samples_train']

    # Named train_data/val_data so the module-level `dataset` function stays callable.
    train_data = train_dataset(fine_params['N_samples_train'], N_points)
    val_data = validation_dataset(fine_params['N_samples_val'], N_points)

    model = Model(model_name, params_of_model, params_of_learning)
    model.trainer(train_data)
    val_loss = [compute_loss(model.model, val_data[0], val_data[1])]

    # Accepted samples are collected in Python lists; JAX arrays do not
    # support item assignment.
    pending_features, pending_targets = [], []

    for it in range(iterations):
        key = random.PRNGKey(it)
        features = generator(key, N_points)
        u_bound = estimate_upper_bound(model.model, features, n_sweeps, N_points=N_points)
        if u_bound < tolerance:
            # The prediction is already good enough for this sample.
            continue

        # `dataset` draws the same coefficients from `key` as `generator`, and
        # gives the solver the coefficient callables it is given elsewhere.
        _, solution = dataset(N_points, key)
        pending_features.append(features)
        pending_targets.append(solution)

        if len(pending_features) < tune_step:
            continue

        # A full batch of hard samples is available: fine-tune, grow the
        # training set, re-validate and start collecting again.
        new_data = [jnp.stack(pending_features), jnp.stack(pending_targets)]
        model.fine_trainer(new_data, train_data)
        train_data = [jnp.concatenate([train_data[0], new_data[0]], axis=0),
                      jnp.concatenate([train_data[1], new_data[1]], axis=0)]
        val_loss.append(compute_loss(model.model, val_data[0], val_data[1]))
        pending_features, pending_targets = [], []

        if train_data[0].shape[0] >= target_size:
            break

    return val_loss

## Models

In [None]:
##############################################
# FNO
##############################################
N_layers_FNO = 5
n_modes = 12
encoder_shapes = [3, 64]
decoder_shapes = [64, 128, 1]
# every Fourier layer keeps the encoder's output width and the same number of modes
FNO_shapes = [encoder_shapes[-1] for _ in range(N_layers_FNO)]
spatial_shapes = [n_modes for _ in range(N_layers_FNO)]

model_params_FNO = dict(
    encoder_shapes=encoder_shapes,
    decoder_shapes=decoder_shapes,
    FNO_shapes=FNO_shapes,
    spatial_shapes=spatial_shapes,
)

##############################################
# DilResNet
##############################################
channels = [3, 32, 1]
n_cells = 7

model_params_DilResNet = dict(channels=channels, n_cells=n_cells)

##############################################
# Params of learning
##############################################
train_params = dict(
    batch_size=32,
    learning_rate=1e-3,
    epochs=50,
    finetuning_epochs=10,
    finetuning_lr=1e-3,
)


##############################################
# Params of finetuning
##############################################
fine_params = dict(
    N_samples_train=50,
    N_samples_val=100,
    tune_step=10,
)

In [None]:
# Run the DilResNet experiment: up to 200 candidate samples, stopping after 100 new training samples.
finetuning(200, 100, 'DilResNet', model_params_DilResNet, train_params, fine_params)

50it [00:49,  1.00it/s]
 99%|█████████▉| 99/100 [01:38<00:01,  1.01s/it]