In [None]:
import numpy as np
import matplotlib 
import matplotlib.pyplot as plt

# Latex preamble
# matplotlib.rc('text', usetex=True)
# matplotlib.rcParams['text.latex.preamble']=[r"\usepackage{amsmath}"]

import math 

import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import TensorDataset, DataLoader

from sklearn.model_selection import train_test_split

# All computation runs on the CPU: later cells feed CPU tensors straight into the
# models and call `.numpy()` on their outputs, which fails for CUDA tensors.
DEVICE_TYPE = "cpu"  # GPU alternative: "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(DEVICE_TYPE)

In [None]:
class Approximator(nn.Module):
    """Residual network that approximates a map by repeated explicit-Euler steps.

    The input is first passed through the preconditioner ``f0`` (identity by
    default) and then updated ``n_repeats`` times as
    ``x <- x + h * regressor(x)`` with step size ``h = 1 / n_repeats``.
    Every step shares the same ``regressor`` weights.

    Args:
        nlatent: width of the two hidden layers of ``regressor``.
        ndim: size of the state (input and output width of ``regressor``).
        n_repeats: number of residual steps (default 20).
        f0: preconditioner applied before the residual steps; ``None`` means
            the identity. It can also be reassigned after construction.

    Note:
        ``nlatent`` and ``ndim`` only take effect here. Assigning
        ``model.nlatent`` / ``model.ndim`` later does not resize the layers.
    """

    def __init__(self, nlatent=20, ndim=2, n_repeats=20, f0=None):
        super().__init__()

        # Sets both n_repeats and the step size h = 1/n_repeats (validated).
        self.r_n_repeats(n_repeats)

        self.ndim = ndim
        self.nlatent = nlatent

        self.f0 = (lambda x : x) if f0 is None else f0

        self.regressor = nn.Sequential(nn.Linear(self.ndim, self.nlatent),
                                       nn.ReLU(inplace=True),
                                       nn.Linear(self.nlatent, self.nlatent),
                                       nn.ReLU(inplace=True),
                                       nn.Linear(self.nlatent, self.ndim))

    def r_n_repeats(self, n):
        """Set the number of residual steps to ``n`` and the step size to ``1/n``.

        Raises:
            ValueError: if ``n`` is not a positive integer. ``n = 0`` would
                divide by zero, and a negative ``n`` would silently skip all steps.
        """
        if int(n) != n or n < 1:
            raise ValueError("n_repeats must be a positive integer, got %r" % (n,))
        self.n_repeats = int(n)
        self.h = 1/self.n_repeats

    def forward(self, x):
        """Apply ``f0``, then ``n_repeats`` residual steps of size ``h``.

        ``regressor`` (built from ``nn.Linear`` layers) acts on the last axis of
        ``x``, so that axis must have size ``ndim``.
        """
        x = self.f0(x)

        for _ in range(self.n_repeats):
            flux = self.regressor(x)
            x = x + self.h*flux

        return x

    
def project_con(model, gamma = 0.9):
    """Project the weights of ``model`` back onto ``h * prod_i ||W_i|| <= gamma``.

    ``find_nn_norm`` returns the product ``W_norm`` of the weight norms and each
    layer's share ``w_i`` of ``log(W_norm)`` (the shares sum to 1). When the
    constraint is violated, weight matrix i is multiplied by
    ``gamma**w_i / (W_norm*h)**w_i``. The product of the rescaled norms then
    becomes exactly ``gamma / h``. Biases are left untouched.

    Args:
        model: network with an ``h`` attribute, e.g. ``Approximator``.
        gamma: bound on ``h * W_norm``; ``np.inf`` disables the projection.
    """
    if gamma == np.inf:
        return  # unconstrained: nothing to project

    W_norm, hW_norm, W_dist = find_nn_norm(model)

    if hW_norm < gamma:
        return  # already feasible

    state = model.state_dict()
    # Same key filter and order as find_nn_norm, so W_dist lines up with these keys.
    weight_keys = [key for key in state if 'weight' in key]

    for key, share in zip(weight_keys, W_dist):
        current = np.power(W_norm*model.h, share)
        target = np.power(gamma, share)
        state[key] = target*state[key]/current

    model.load_state_dict(state)

def find_nn_norm(f_nn):
    """Product of the Frobenius norms of all weight matrices of ``f_nn``.

    Only ``state_dict`` entries whose key contains ``'weight'`` are used; biases
    are ignored.

    Args:
        f_nn: module with an ``h`` attribute (step size), e.g. ``Approximator``.

    Returns:
        W_norm: product of the weight norms (a 0-d tensor; ``1.0`` if there
            are no weights).
        hW_norm: ``W_norm * f_nn.h``.
        w_dist: numpy array of ``log(||W_i||) / log(W_norm)``, i.e. each layer's
            share of the log of the product (the entries sum to 1). The entries
            are inf/nan when ``W_norm == 1``.
    """
    weight_norms = [torch.norm(W) for key, W in f_nn.state_dict().items()
                    if 'weight' in key]

    W_norm = 1.0
    for pnorm in weight_norms:
        W_norm = W_norm*pnorm

    # Take logs of Python floats instead of calling np.log on tensors / .numpy().
    # Those only work for CPU tensors; float() works on any device.
    log_norms = np.log(np.array([float(pnorm) for pnorm in weight_norms],
                                dtype=np.float64))
    w_dist = log_norms/np.log(float(W_norm))

    return W_norm, W_norm*f_nn.h, w_dist


def _mean_loss(model, dataloader, criterion):
    """Exact mean of ``criterion`` over ``dataloader``, evaluated without gradients.

    Puts ``model`` in eval mode. Each batch loss is weighted by its batch size,
    so a smaller final batch does not bias the average.
    Returns ``nan`` for an empty loader.
    """
    model.eval()
    total, count = 0.0, 0
    with torch.no_grad():
        for X_batch, y_batch in dataloader:
            X_batch = X_batch.type(torch.float32).to(device)
            y_batch = y_batch.type(torch.float32).to(device)
            loss = criterion(input=model(X_batch), target=y_batch)
            total += loss.item()*len(X_batch)
            count += len(X_batch)
    return total/count if count else float('nan')


def train_func_approx(model,a_func,
                      LR = 1e-2,
                      MAX_EPOCH = 20,
                      BATCH_SIZE = 512,
                      range_lr = [0.,1.],
                      gamma = 0.9):
    """Fit ``model`` to ``a_func`` on random points of a box, with Adam and MSE.

    ``10**5`` points are drawn uniformly from ``[range_lr[0], range_lr[1]]`` in
    each of the ``model.ndim`` coordinates and split 80/20 into training and
    validation sets. The norm constraint ``project_con(model, gamma)`` is
    enforced once before training and after every epoch. Per epoch, the
    function prints the training/validation loss (measured after the
    projection) and the weight-norm information from ``find_nn_norm``.

    Args:
        model: ``Approximator``-like module with ``ndim`` and ``h`` attributes.
        a_func: target function, applied to a numpy array of shape (N, ndim).
        LR: Adam learning rate.
        MAX_EPOCH: number of epochs.
        BATCH_SIZE: mini-batch size.
        range_lr: ``[low, high]`` bounds of the sampling box.
        gamma: bound passed to ``project_con`` (``np.inf`` disables it).
    """
    X = np.random.rand(10**5,model.ndim)*(range_lr[1]-range_lr[0])+range_lr[0]
    y = a_func(X)

    X_train, X_val, y_train, y_val = map(torch.tensor,
                                         train_test_split(X, y,
                                                          test_size=0.2))

    # Pinned host memory only speeds up copies to a GPU; on CPU it just warns.
    pin = device.type == "cuda"
    # unsqueeze(1): each sample has shape (1, ndim); the model acts on the last axis.
    train_dataloader = DataLoader(TensorDataset(X_train.unsqueeze(1),
                                                y_train.unsqueeze(1)),
                                  batch_size=BATCH_SIZE,
                                  pin_memory=pin, shuffle=True)
    # Evaluation order does not matter, so the validation set is not shuffled.
    val_dataloader = DataLoader(TensorDataset(X_val.unsqueeze(1),
                                              y_val.unsqueeze(1)),
                                batch_size=BATCH_SIZE,
                                pin_memory=pin, shuffle=False)

    optimizer = optim.Adam(model.parameters(), lr=LR)
    criterion = nn.MSELoss(reduction="mean")

    project_con(model,gamma = gamma)

    train_loss_list = list()
    val_loss_list = list()
    for epoch in range(MAX_EPOCH):
        print("epoch %d / %d" % (epoch+1, MAX_EPOCH))
        model.train()

        for X_batch, y_batch in train_dataloader:
            X_batch = X_batch.type(torch.float32).to(device)
            y_batch = y_batch.type(torch.float32).to(device)

            optimizer.zero_grad()
            loss = criterion(input=model(X_batch), target=y_batch)
            loss.backward()
            optimizer.step()

        # project constraint here (right now every epoch)
        project_con(model,gamma = gamma)

        # Losses are measured after the projection, i.e. for the weights actually kept.
        train_loss_list.append(_mean_loss(model, train_dataloader, criterion))
        val_loss_list.append(_mean_loss(model, val_dataloader, criterion))

        W, hW, W_dist = find_nn_norm(model)

        print("\ttrain loss: %.5f" % train_loss_list[-1])
        print("\tval loss: %.5f" % val_loss_list[-1])
        print("\tflux norm h: %.5f" % hW)
        print("\tflux distribution:", W_dist)

$x^2$ example

In [None]:
# Target g(x) = x^2, approximated with the identity preconditioner f0(x) = x.
f_x2 = lambda x : x**2

# nlatent / ndim must be passed here: assigning them after construction
# does not resize the layers.
f_nn_x2 = Approximator(ndim=2, nlatent=20).to(device)
f_nn_x2.f0 = lambda x : x
f_nn_x2.r_n_repeats(100)  # 100 residual steps, i.e. h = 1/100

discretisation_lines = ('\nn repeats : %i' % f_nn_x2.n_repeats,
                        '\nh         : %4f' % f_nn_x2.h)
print('discretisation info : ')
print(*discretisation_lines)

In [None]:
print('training f_nn: \n')

# Same optimisation settings as the other examples in this notebook.
train_func_approx(f_nn_x2, f_x2,
                  LR=1e-3,
                  MAX_EPOCH=10,
                  gamma=0.95,
                  range_lr=[-1., 1.])

In [None]:
# Evaluation grid on [-1, 1], copied into every state coordinate. It is defined
# here because this cell runs before the Example 1 cell that also builds it.
x_grid = torch.reshape(torch.linspace(-1,1,50),[50,1])
xx_grid = torch.cat([x_grid for _ in range(f_nn_x2.ndim)],1)

plt.plot(x_grid, f_x2(xx_grid)[:,0], c = 'black', label = 'target $x^2$')
plt.plot(x_grid, f_nn_x2(xx_grid).detach().numpy()[:,0], label = 'network')
plt.legend()
plt.show()

Example 1: $g = x^3, f_0 = x$ 

In [None]:
f_x3 = lambda x : x**3

# Width and state dimension are constructor arguments: assigning
# `nlatent` / `ndim` after construction would not resize the layers.
f_nn_x3 = Approximator(nlatent=20, ndim=2).to(device)

# Evaluation grid on [-1, 1], copied into every state coordinate.
x_grid = torch.reshape(torch.linspace(-1,1,50),[50,1])
xx_grid = torch.cat([x_grid for _ in range(f_nn_x3.ndim)],1)

f_nn_x3.f0 = lambda x : x  # identity preconditioner
f_nn_x3.r_n_repeats(100)  # number of repeats of ResNet block

print('discretisation info : ')
print('\nn repeats : %i' % f_nn_x3.n_repeats, 
      '\nh         : %4f' % f_nn_x3.h)

In [None]:
print('training f_nn: \n')

fit_options = {'range_lr': [-1., 1.], 'MAX_EPOCH': 10, 'LR': 1e-3, 'gamma': 0.95}
train_func_approx(f_nn_x3, f_x3, **fit_options)
print("Training done!")

In [None]:
# matplotlib.rc('text', usetex=True)
# matplotlib.rcParams['text.latex.preamble'] = [r'\boldmath', r'\amsmath']

def plot_precond_func_approx(x, y_f, y_f_nn, y_f0, file_name,
                             precon_legend=r'$R^{\mathrm{pre}}(v)$ (preconditioner)',
                             truth_legend=r'$w(v) = v^2$ (ground-truth)',
                             out_dir='Figs_output'):
    """Plot ground truth, network approximation and preconditioner; save as PDF.

    The figure is square, limited to [-1, 1] on both axes, and written to
    ``<out_dir>/<file_name>.pdf`` (the directory is created if missing) before
    being shown.

    Args:
        x: abscissa values.
        y_f: ground-truth values g(x).
        y_f_nn: network output R(x).
        y_f0: preconditioner values f0(x).
        file_name: output file name without extension.
        precon_legend: legend label of the preconditioner curve.
        truth_legend: legend label of the ground-truth curve (the default says
            ``v^2``; pass a different label for other targets such as ``v^3``).
        out_dir: directory for the PDF.
    """
    import os

    plt.figure(figsize=(6, 6))
    plt.axhline(y=0., color='black', linestyle='--', linewidth=.5)
    plt.plot(x, y_f, c = 'black', label = truth_legend)
    plt.plot(x, y_f_nn, c = 'blue', label = r'$R(v)$ (approximation)')
    plt.plot(x, y_f0, '-.', c = 'green', label = precon_legend)
    plt.legend(loc='lower right', fontsize=15)

    plt.xlim([-1,1])
    plt.ylim([-1,1])
    plt.xticks([-1,0,1])
    plt.yticks([-1,0,1])

    plt.xlabel(r'$v$', fontsize=15)
    plt.ylabel(r'$w$', fontsize=15)

    # savefig does not create missing directories, so make sure the target exists.
    os.makedirs(out_dir, exist_ok=True)
    plt.savefig(os.path.join(out_dir, '{}.pdf'.format(file_name)),
                bbox_inches='tight', dpi=1000)
    plt.show()

In [None]:
file_name = 'g_x^3_f0_x'
x = x_grid
y_f = f_x3(xx_grid)[:, 0]  # ground truth g(v) = v^3
# Network output first, then the preconditioner, both as numpy column 0.
y_f_nn, y_f0 = (t.detach().numpy()[:, 0]
                for t in (f_nn_x3(xx_grid), f_nn_x3.f0(x_grid)))

plot_precond_func_approx(x, y_f, y_f_nn, y_f0, file_name)

Example 2: $g = x^3, f_0 = \vert x \vert$ 

In [None]:
f_x3 = lambda x : x**3

# Width / state dimension must go through the constructor; setting them
# afterwards would not resize the layers.
f_nn_x3 = Approximator(nlatent=20, ndim=2).to(device)

x_grid = torch.reshape(torch.linspace(-1,1,50),[50,1])
xx_grid = torch.cat([x_grid for _ in range(f_nn_x3.ndim)],1)

# |x| preconditioner. torch.abs keeps everything in torch; np.abs would go through
# numpy, which fails for CUDA tensors and tensors that require grad.
f_nn_x3.f0 = lambda x : torch.abs(x)
f_nn_x3.r_n_repeats(100)  # number of repeats of ResNet block

print('discretisation info : ')
print('\nn repeats : %i' % f_nn_x3.n_repeats, 
      '\nh         : %4f' % f_nn_x3.h)

print('training f_nn: \n')

train_func_approx(f_nn_x3,f_x3,
                  range_lr = [-1.,1.],
                  MAX_EPOCH = 10,
                  LR = 1e-3, 
                  gamma = 0.95)
print("Training done!")

In [None]:
file_name = 'g_x^3_f0_abs(x)'
x = x_grid
y_f = f_x3(xx_grid)[:, 0]                              # ground truth v^3
y_f_nn = f_nn_x3(xx_grid).detach().numpy()[:, 0]       # network output
y_f0 = f_nn_x3.f0(x_grid).detach().numpy()[:, 0]       # preconditioner |v|

plot_precond_func_approx(x, y_f, y_f_nn, y_f0, file_name)

Example 3: $g = x^2, f_0 = x$ 

In [None]:
# NOTE: the names f_x3 / f_nn_x3 are reused from the x^3 examples; the target here is x^2.
f_x3 = lambda x : x**2

# Width / state dimension must go through the constructor; setting them
# afterwards would not resize the layers.
f_nn_x3 = Approximator(nlatent=20, ndim=2).to(device)

x_grid = torch.reshape(torch.linspace(-1,1,50),[50,1])
xx_grid = torch.cat([x_grid for _ in range(f_nn_x3.ndim)],1)

f_nn_x3.f0 = lambda x : x  # identity preconditioner
f_nn_x3.r_n_repeats(100)  # number of repeats of ResNet block

print('discretisation info : ')
print('\nn repeats : %i' % f_nn_x3.n_repeats, 
      '\nh         : %4f' % f_nn_x3.h)

print('training f_nn: \n')

train_func_approx(f_nn_x3,f_x3,
                  range_lr = [-1.,1.],
                  MAX_EPOCH = 10,
                  LR = 1e-3, 
                  gamma = 0.95)
print("Training done!")

In [None]:
x = x_grid
y_f = f_x3(xx_grid)[:,0]
y_f_nn = f_nn_x3(xx_grid).detach().numpy()[:,0]
y_f0 = f_nn_x3.f0(x_grid).detach().numpy()[:,0]
file_name = 'g_x^2_f0_x'

# Raw string: "\m" is an invalid escape in a normal literal (SyntaxWarning on Python 3.12+).
plot_precond_func_approx(x=x, y_f=y_f, y_f_nn=y_f_nn, y_f0=y_f0, file_name=file_name,
                         precon_legend=r"$R^{\mathrm{pre}}(v) = v$ (preconditioner)")

Example 4: $g = x^2, f_0 = \vert x \vert$ 

In [None]:
# NOTE: the names f_x3 / f_nn_x3 are reused from the x^3 examples; the target here is x^2.
f_x3 = lambda x : x**2

# Width / state dimension must go through the constructor; setting them
# afterwards would not resize the layers.
f_nn_x3 = Approximator(nlatent=20, ndim=2).to(device)

x_grid = torch.reshape(torch.linspace(-1,1,50),[50,1])
xx_grid = torch.cat([x_grid for _ in range(f_nn_x3.ndim)],1)

# |x| preconditioner. torch.abs keeps everything in torch; np.abs would go through
# numpy, which fails for CUDA tensors and tensors that require grad.
f_nn_x3.f0 = lambda x : torch.abs(x)
f_nn_x3.r_n_repeats(100)  # number of repeats of ResNet block

print('discretisation info : ')
print('\nn repeats : %i' % f_nn_x3.n_repeats, 
      '\nh         : %4f' % f_nn_x3.h)

print('training f_nn: \n')

train_func_approx(f_nn_x3,f_x3,
                  range_lr = [-1.,1.],
                  MAX_EPOCH = 10,
                  LR = 1e-3, 
                  gamma = 0.95)
print("Training done!")

In [None]:
x = x_grid
y_f = f_x3(xx_grid)[:,0]
y_f_nn = f_nn_x3(xx_grid).detach().numpy()[:,0]
y_f0 = f_nn_x3.f0(x_grid).detach().numpy()[:,0]
file_name = 'g_x^2_f0_abs(x)'

# Raw string: "\m" is an invalid escape in a normal literal (SyntaxWarning on Python 3.12+).
plot_precond_func_approx(x=x, y_f=y_f, y_f_nn=y_f_nn, y_f0=y_f0, file_name=file_name,
                         precon_legend=r"$R^{\mathrm{pre}}(v) = \mid v \mid$ (preconditioner)")

$x^2$ example with $|x|$ preconditioning 

In [None]:
f_x2 = lambda x : x**2

# Width / state dimension must go through the constructor; setting them
# afterwards would not resize the layers.
f_nn_x2 = Approximator(nlatent=20, ndim=2).to(device)

# |x| preconditioner. torch.abs keeps everything in torch; np.abs would go through
# numpy, which fails for CUDA tensors and tensors that require grad.
f_nn_x2.f0 = lambda x : torch.abs(x)
f_nn_x2.r_n_repeats(100)

# Report the discretisation of the model configured in this cell (f_nn_x2).
print('discretisation info : ')
print('\nn repeats : %i' % f_nn_x2.n_repeats, 
      '\nh         : %4f' % f_nn_x2.h)

In [None]:
print('training f_nn: \n')

fit_options = dict(gamma=0.95, LR=1e-3, MAX_EPOCH=10, range_lr=[-1., 1.])
train_func_approx(f_nn_x2, f_x2, **fit_options)
print("Training done!")

In [None]:
truth_x2 = f_x2(xx_grid)[:, 0]                        # target x^2
approx_x2 = f_nn_x2(xx_grid).detach().numpy()[:, 0]   # network with |x| preconditioner
plt.plot(x_grid, truth_x2, c='black')
plt.plot(x_grid, approx_x2)

$x^3$ example with $|x|$ preconditioning 

In [None]:
f_x3 = lambda x : x**3

# Width / state dimension must go through the constructor; setting them
# afterwards would not resize the layers.
f_nn_x3 = Approximator(nlatent=20, ndim=2).to(device)

# |x| preconditioner. torch.abs keeps everything in torch; np.abs would go through
# numpy, which fails for CUDA tensors and tensors that require grad.
f_nn_x3.f0 = lambda x : torch.abs(x)
f_nn_x3.r_n_repeats(100)

print('discretisation info : ')
print('\nn repeats : %i' % f_nn_x3.n_repeats, 
      '\nh         : %4f' % f_nn_x3.h)

In [None]:
print('training f_nn: \n')

# Same optimisation settings as the other examples in this notebook.
train_func_approx(f_nn_x3, f_x3,
                  MAX_EPOCH=10,
                  gamma=0.95,
                  LR=1e-3,
                  range_lr=[-1., 1.])
print("Training done!")

In [None]:
truth_x3 = f_x3(xx_grid)[:, 0]                        # target x^3
approx_x3 = f_nn_x3(xx_grid).detach().numpy()[:, 0]   # network with |x| preconditioner
plt.plot(x_grid, truth_x3, color='black')
plt.plot(x_grid, approx_x3)

# Old experiments (scratch work, kept for reference)

In [None]:
f = lambda x : x**3

# Reference grid on [-1, 1] as a (50, 1) column.
x_grid = torch.linspace(-1, 1, 50).reshape(50, 1)

plt.plot(x_grid, x_grid)      # identity
plt.plot(x_grid, f(x_grid))   # target x^3


In [None]:
# Width / state dimension must go through the constructor; setting them
# afterwards would not resize the layers.
f_nn = Approximator(nlatent=20, ndim=2).to(device)

f_nn.f0 = lambda x : x**5  # preconditioner f0(x) = x^5

f_nn.r_n_repeats(100)

print('discretisation info : ')
print('\nn repeats : %i' % f_nn.n_repeats, 
      '\nh         : %4f' % f_nn.h)

In [None]:
print('training f_nn: \n')

fit_options = {'LR': 1e-3, 'gamma': 0.95, 'MAX_EPOCH': 10, 'range_lr': [-1., 1.]}
train_func_approx(f_nn, f, **fit_options)

In [None]:
# Rebuild the evaluation grid for this model instead of relying on the
# `xx_grid` left over from the earlier examples.
xx_grid = torch.cat([x_grid for _ in range(f_nn.ndim)],1)

plt.plot(x_grid,f(xx_grid)[:,0], c = 'black')
plt.plot(x_grid,f_nn(xx_grid).detach().numpy()[:,0])

In [None]:
# One explicit-Euler residual step, identical to the update in Approximator.forward.
step = lambda x : x + f_nn.h*f_nn.regressor(x)

xx_grid = torch.cat([x_grid for _ in range(f_nn.ndim)],1)

fig,ax = plt.subplots(3,1,figsize = (20,20))

# Panel 0: target, final network output and one regressor evaluation, per coordinate.
net_out = f_nn(xx_grid).detach().numpy()
reg_out = f_nn.regressor(xx_grid).detach().numpy()
ax[0].plot(x_grid,f(xx_grid)[:,0], c = 'black')
for i in range(f_nn.ndim):
    ax[0].plot(x_grid, net_out[:,i])
    ax[0].plot(x_grid, reg_out[:,i])

# Panels 1-2: intermediate states and fluxes. Like forward(), start from f0(x)
# and scale each flux by the step size h.
ax[1].plot(x_grid,f(xx_grid)[:,0], c = 'black')
x = f_nn.f0(xx_grid)
for _ in range(f_nn.n_repeats):
    flux = f_nn.regressor(x)
    x = x + f_nn.h*flux
    ax[1].plot(x_grid, x.detach().numpy()[:,0])
    ax[2].plot(x_grid, flux.detach().numpy()[:,0])

# Replaying the steps by hand must reproduce the model's forward pass.
assert torch.allclose(x, f_nn(xx_grid))

plt.show()

In [None]:
ground_truth = f(xx_grid)[:, 0]
prediction = f_nn(xx_grid).detach().numpy()[:, 0]
plt.plot(x_grid, ground_truth, c='black')
plt.plot(x_grid, prediction)

In [None]:
plt.plot(x_grid, f(xx_grid)[:, 0], color='black')           # target
plt.plot(x_grid, f_nn(xx_grid).detach().cpu().numpy()[:, 0])  # network

In [None]:
# Check that the explicit-Euler steps of f_nn can be inverted one by one.
# Start state: every coordinate equal to 0.5 (xx_grid[0] only fixes the shape).
x0 = torch.ones_like(xx_grid[0])*0.5

y = x0

# Y collects the forward trajectory x_0, x_1, ..., x_n.
Y = [x0]

for _ in range(f_nn.n_repeats):
    
    # Same update as Approximator.forward, but without applying f0 first.
    Y.append(y + f_nn.h*f_nn.regressor(y))
    
    y = Y[-1]
    
# NOTE(review): f_nn(x0) also applies f_nn.f0 (x**5 here), so the two printed
# values only agree when f0 is the identity.
print(y, f_nn(x0))

# Z collects the backward reconstruction, starting from the final forward state.
Z = [y]

z = y

for _ in range(f_nn.n_repeats):
    
    y = Z[-1]

    # Invert x_{k+1} = x_k + h*F(x_k) by 10 fixed-point iterations
    # z <- x_{k+1} - h*F(z); z starts from the previous step's solution.
    for i in range(10):

        z = y - f_nn.h*f_nn.regressor(z)

    Z.append(z)
    
# Reverse in place so row k of Y (forward) lines up with row k of Z (backward).
Y.reverse() 
    
print(torch.cat((torch.stack(Y),torch.stack(Z)),axis = 1 ).detach())

# Distance between the reconstructed initial state and the true x0.
print('\n initial error : %20f' % torch.norm(z-x0))

In [None]:
def res_inverse(y,flux, n = 2, it = 5):
    """Approximately invert ``n`` explicit-Euler residual steps.

    The forward map is ``x_{k+1} = x_k + h * flux(x_k)`` with ``h = 1/n``.
    All intermediate states start as ``y`` and are refined by ``it`` sweeps of
    the fixed-point update ``x_k <- x_{k+1} - h * flux(x_k)``. Each sweep runs
    from the last step back to the first.

    Returns:
        The estimate of the initial state ``x_0``.
    """
    step = 1/n
    states = [y]*(n + 1)

    for _sweep in range(it):
        for k in range(n - 1, -1, -1):
            states[k] = states[k + 1] - step*flux(states[k])

    return states[0]

# res_inverse assumes h = 1/n, which matches f_nn.h = 1/n_repeats.
recon = res_inverse(f_nn(xx_grid),f_nn.regressor, n = f_nn.n_repeats, it = 5)
recon = recon.detach()

# res_inverse only undoes the residual steps, so it recovers f0(x), not x itself.
# f_nn.f0 is not the identity here, so compare against the preconditioned input.
target = f_nn.f0(xx_grid).detach()

fig,ax = plt.subplots(f_nn.ndim + 1,1,figsize = (20,20))

for i in range(f_nn.ndim):
    ax[i].plot(x_grid, target[:,i], label = 'f0(x) (target)')
    ax[i].plot(x_grid, recon[:,i], label = 'reconstruction')
    ax[i].legend()

# Euclidean reconstruction error at each grid point.
error = torch.norm(target - recon, dim = 1)

ax[-1].plot(x_grid, error)
ax[-1].set_title('reconstruction error')

plt.show()