In [None]:
%load_ext autoreload
%autoreload 2 
%reload_ext autoreload
%matplotlib inline

import numpy as np
import scipy.io as io
from pyDOE import lhs
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

from complexPyTorch.complexLayers import ComplexLinear

import cplxmodule
from cplxmodule import cplx
from cplxmodule.nn import RealToCplx, CplxToReal, CplxSequential, CplxToCplx
from cplxmodule.nn import CplxLinear, CplxModReLU, CplxAdaptiveModReLU, CplxModulus, CplxAngle

# To access the contents of the parent dir
import sys; sys.path.insert(0, '../')
import os
from scipy.io import loadmat
from utils import *
from models import TorchComplexMLP, ImaginaryDimensionAdder, cplx2tensor, ComplexTorchMLP, complex_mse
from preprocess import *

# Model selection
from sparsereg.model import STRidge
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression, Ridge
from pde_diff import TrainSTRidge, FiniteDiff, print_pde
from RegscorePy.bic import bic

# Fancy optimizers
from lbfgsnew import LBFGSNew
from madgrad import MADGRAD

In [None]:
# torch device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("You're running on", device)

DATA_PATH = '../PDE_FIND_experimental_datasets/harmonic_osc.mat'
data = io.loadmat(DATA_PATH)

t = data['t'].flatten()[:,None]
x = data['x'].flatten()[:,None]
spatial_dim = x.shape[0]
time_dim = t.shape[0]
# Potential 0.5*x**2 repeated for every time step: shape (time_dim, spatial_dim)
potential = np.vstack([0.5*np.power(x,2).reshape((1,spatial_dim)) for _ in range(time_dim)])
Exact = data['usol']

# Adjust the dimensions of Exact and potential (0.5*x**2) so both are laid out as (spatial_dim, time_dim)
if Exact.shape == (time_dim, spatial_dim): Exact = Exact.T
if potential.shape == (time_dim, spatial_dim): potential = potential.T

Exact_u = np.real(Exact)
Exact_v = np.imag(Exact)

X, T = np.meshgrid(x,t)

X_star = np.hstack((X.flatten()[:,None], T.flatten()[:,None]))
h_star = to_column_vector(Exact)
u_star = to_column_vector(Exact_u)
v_star = to_column_vector(Exact_v)

# Domain bounds
lb = X_star.min(axis=0)
ub = X_star.max(axis=0)

# Converting the ground truths to tensors
X_star = to_tensor(X_star, True)
h_star = to_complex_tensor(h_star, False)

# N supervised points; include_N_res*N additional unsupervised (residual) points
N = 500; include_N_res = 2
idx = np.random.choice(X_star.shape[0], N, replace=False)
# idx = np.arange(N) # Just have an easy dataset for experimenting

lb = to_tensor(lb, False).to(device)
ub = to_tensor(ub, False).to(device)

X_train = to_tensor(X_star[idx, :], True).to(device)
u_train = to_tensor(u_star[idx, :], False).to(device)
v_train = to_tensor(v_star[idx, :], False).to(device)
h_train = torch.complex(u_train, v_train).to(device)

# Unsup data: drawn only from rows that are NOT used for supervision
if include_N_res>0:
    N_res = int(N*include_N_res)
    # `arr[~idx]` on an integer index array is bitwise NOT (selects rows -idx-1),
    # not a set complement, so build the complement of `idx` explicitly.
    idx_pool = np.setdiff1d(np.arange(X_star.shape[0]), idx)
    # Never ask for more residual points than remain available.
    N_res = min(N_res, idx_pool.shape[0])
    idx_res = np.random.choice(idx_pool, N_res, replace=False)
    # Same device as X_train so that the vstack below works on GPU as well.
    X_res = to_tensor(X_star[idx_res, :], True).to(device)
    print(f"Training with {N_res} unsup samples")
    X_train = torch.vstack([X_train, X_res])

# Potential is calculated from x
# Hence, Quadratic features of x are required.
feature_names = ['hf', 'x', 'h_x', 'h_xx', 'h_xxx']

In [None]:
# Grid spacings (dx is measured between the 2nd and 3rd grid points)
dt = (t[1]-t[0])[0]
dx = (x[2]-x[1])[0]

# Finite-difference derivatives of Exact, each laid out as (spatial_dim, time_dim).
# d/dt: differentiate every spatial row along the time axis.
fd_h_t = np.array([FiniteDiff(Exact[row, :], dt, 1) for row in range(spatial_dim)],
                  dtype=np.complex64)
# d/dx, d2/dx2, d3/dx3: differentiate every time column along space and stack the columns back.
fd_h_x, fd_h_xx, fd_h_xxx = (
    np.stack([FiniteDiff(Exact[:, col], dx, order) for col in range(time_dim)], axis=1).astype(np.complex64)
    for order in (1, 2, 3)
)

# Flatten every derivative (and the potential) into a column vector.
fd_h_t, fd_h_x, fd_h_xx, fd_h_xxx = [to_column_vector(arr) for arr in (fd_h_t, fd_h_x, fd_h_xx, fd_h_xxx)]
V = to_column_vector(potential)

In [None]:
# Candidate terms stacked column-wise in the same order as feature_names.
derivatives = cat_numpy(h_star.detach().numpy(), V, fd_h_x, fd_h_xx, fd_h_xxx)
# Name -> column lookup for each candidate term.
dictionary = {name: get_feature(derivatives, col) for col, name in enumerate(feature_names)}

In [None]:
# Build the candidate-feature library from the named base terms.
# NOTE(review): ComplexPolynomialFeatures comes from a star import (utils/preprocess);
# presumably it forms polynomial combinations of the terms in `dictionary` — confirm there.
c_poly = ComplexPolynomialFeatures(feature_names, dictionary)
complex_poly_features = c_poly.fit()
# Last expression: rendered as this cell's output.
complex_poly_features

In [None]:
# This cell is not needed anymore.
# w = TrainSTRidge(complex_poly_features, fd_h_t, 1e-10, 10)
# print("PDE derived using STRidge")
# print_pde(w, c_poly.poly_feature_names)

In [None]:
PRETRAINED_PATH = None

inp_dimension = 2
act = CplxToCplx[torch.tanh]

# Complex trunk: four (CplxLinear(100, 100) -> tanh) pairs followed by a CplxLinear(100, 1) head.
trunk_layers = []
for _ in range(4):
    trunk_layers.extend([CplxLinear(100, 100, bias=True), act()])
trunk_layers.append(CplxLinear(100, 1, bias=True))
complex_model = CplxSequential(*trunk_layers)

# Real front end: Linear(inp_dimension -> 200) followed by RealToCplx. Its 200 real outputs feed
# the 100-feature complex trunk, so RealToCplx presumably pairs them into 100 complex values.
complex_model = torch.nn.Sequential(
    torch.nn.Linear(inp_dimension, 200),
    RealToCplx(),
    complex_model,
)

if PRETRAINED_PATH is not None:
    complex_model.load_state_dict(cpu_load(PRETRAINED_PATH))

In [None]:
class ComplexNetwork(nn.Module):
    """Wraps the complex-valued solver network h(x, t) and derives candidate PDE features from it.

    Args:
        model: module applied to torch.cat([x, t], dim=-1) (optionally rescaled to [-1, 1]).
        index2features: candidate feature names (e.g. ['hf', 'x', 'h_x', ...]); split by the
            project helper `diff_flag` into non-derivative and derivative features.
        scale: if True, inputs are mapped to [-1, 1] using lb/ub before calling `model`.
        lb, ub: lower/upper domain bounds used when `scale` is True.
    """
    def __init__(self, model, index2features=None, scale=False, lb=None, ub=None):
        super(ComplexNetwork, self).__init__()
        # pls init the self.model before
        self.model = model
        # For tracking, the default tup is for the burgers' equation.
        self.index2features = index2features
        print("Considering", self.index2features)
        self.diff_flag = diff_flag(self.index2features)
        # Output of the most recent forward pass (read by SemiSupModel).
        self.uf = None
        self.scale = scale
        self.lb, self.ub = lb, ub

    def xavier_init(self, m):
        """Xavier-uniform weights and 0.01 bias for plain nn.Linear modules (use with .apply)."""
        if type(m) == nn.Linear:
            torch.nn.init.xavier_uniform_(m.weight)
            m.bias.data.fill_(0.01)

    def forward(self, x, t):
        """Evaluate the solver on [x, t]; the result is also cached in self.uf."""
        inp = torch.cat([x, t], dim=-1)
        if self.scale: inp = self.neural_net_scale(inp)
        self.uf = self.model(inp)
        return self.uf

    def get_selector_data(self, x, t):
        """Return (candidate features concatenated on the last dim, dh/dt).

        Features without derivatives come first ('hf' = solver output, 'x' = potential
        0.5*x**2), followed by the autograd derivatives listed in self.diff_flag[1].
        """
        uf = self.forward(x, t)
        u_t = complex_diff(uf, t)

        ### PDE Loss calculation ###
        # Without calling grad
        derivatives = []
        for feature in self.diff_flag[0]:
            if feature == 'hf':
                derivatives.append(cplx2tensor(uf))
            elif feature == 'x':
                # Harmonic potential V(x) = 0.5*x**2
                derivatives.append(0.5*torch.pow(x, 2))
        # With calling grad.
        # Loop variables must not be named `t`: that would shadow the time tensor
        # needed by the 't' branch below.
        for chain in self.diff_flag[1]:
            out = uf
            # Presumably each `chain` is a string of axes such as 'xx' — TODO confirm against diff_flag
            for axis in chain:
                if axis == 'x': out = complex_diff(out, x)
                elif axis == 't': out = complex_diff(out, t)
            derivatives.append(out)

        return torch.cat(derivatives, dim=-1), u_t

    def neural_net_scale(self, inp):
        """Affinely map inputs from [lb, ub] to [-1, 1]."""
        return 2*(inp-self.lb)/(self.ub-self.lb)-1

In [None]:
class ComplexAttentionSelectorNetwork(nn.Module):
    """Predicts a target from importance-weighted candidate features, with a sparsity penalty.

    Args:
        layers: widths for ComplexTorchMLP; layers[0] is the number of candidate features.
        prob_activation: maps |linear1(features)| to per-feature importance scores.
        bn: forwarded to ComplexTorchMLP.
        reg_intensity: multiplier of the sparsity/complexity penalty in `loss`.
        th: importances <= th are not penalized (default 0.1).
        complexity_weights: per-feature penalty weights; defaults to [1, 1, 2, 3, 4],
            i.e. five candidate features.
    """
    def __init__(self, layers, prob_activation=torch.sigmoid, bn=None, reg_intensity=0.1,
                 th=0.1, complexity_weights=None):
        super(ComplexAttentionSelectorNetwork, self).__init__()
        # Nonlinear model, Training with PDE reg.
        assert len(layers) > 1
        self.linear1 = CplxLinear(layers[0], layers[0], bias=True)
        self.prob_activation = prob_activation
        self.nonlinear_model = ComplexTorchMLP(dimensions=layers, activation_function=CplxToCplx[F.relu](), bn=bn, dropout_rate=0.0)
        # Set by weighted_features(); loss() relies on forward() having refreshed it.
        self.latest_weighted_features = None
        self.th = th
        self.reg_intensity = reg_intensity
        if complexity_weights is None:
            complexity_weights = [1.0, 1.0, 2.0, 3.0, 4.0]
        # Kept as a plain attribute (not a buffer) so state_dict keys are unchanged;
        # moved to the features' device/dtype at use time.
        self.complexity_weights = torch.as_tensor(complexity_weights, dtype=torch.float32)

    def xavier_init(self, m):
        """Xavier-uniform weights and 0.01 bias for plain nn.Linear modules (use with .apply)."""
        if type(m) == nn.Linear:
            torch.nn.init.xavier_uniform_(m.weight)
            m.bias.data.fill_(0.01)

    def forward(self, inn):
        """Scale each feature by its batch-averaged importance and feed the result to the MLP."""
        feature_importances = self.weighted_features(inn)
        inn = inn*feature_importances
        return self.nonlinear_model(inn)

    def weighted_features(self, inn):
        """Importances = prob_activation(|linear1(inn)|), averaged over dim 0 (the batch)."""
        self.latest_weighted_features = self.prob_activation(cplx2tensor(self.linear1(inn)).abs())
        self.latest_weighted_features = self.latest_weighted_features.mean(dim=0)
        return self.latest_weighted_features

    def loss(self, X_input, y_input):
        """Complex MSE of the prediction plus reg_intensity * sparsity/complexity penalty."""
        ut_approx = self.forward(X_input)
        mse_loss = complex_mse(ut_approx, y_input)
        return mse_loss+self.reg_intensity*self._regularization_penalty()

    def _regularization_penalty(self):
        """Number of importances above th plus their complexity-weighted excess over th."""
        reg_term = F.relu(self.latest_weighted_features-self.th)
        # Created on the features' device, so the penalty also works when training on GPU.
        weights = self.complexity_weights.to(device=reg_term.device, dtype=reg_term.dtype)
        # The p=0 "norm" counts non-zero entries (piecewise constant, so it adds no gradient).
        return torch.norm(reg_term, p=0)+(weights*reg_term).sum()

# Only the SemiSupModel has changed to work with the finite difference guidance
class SemiSupModel(nn.Module):
    """Couples the solver network with the feature selector.

    forward() returns (fd_guidance, unsup_loss):
      * fd_guidance: complex MSE between the solver output on the first h_train.shape[0]
        rows of X_h_train and h_train;
      * unsup_loss: the selector loss on the derived features (None if include_unsup=False).
    With uncert=True (and include_unsup=True) each term is rescaled by a learnable weight w_i
    as exp(-w_i)*loss_i + w_i.
    """
    def __init__(self, network, selector, normalize_derivative_features=False, mini=None, maxi=None, uncert=False):
        super(SemiSupModel, self).__init__()
        self.network = network
        self.selector = selector
        self.normalize_derivative_features = normalize_derivative_features
        # Per-feature bounds used for min-max normalisation of the selector inputs.
        self.mini = mini
        self.maxi = maxi
        if uncert:
            # Registered as a Parameter so the weights are returned by parameters()
            # and actually get optimised; a plain tensor would stay fixed at 0.
            self.weights = nn.Parameter(torch.zeros(2))
        else:
            self.weights = None

    def forward(self, X_h_train, h_train, include_unsup=True):
        X_selector, y_selector = self.network.get_selector_data(*dimension_slicing(X_h_train))

        # get_selector_data() ran the solver; its output is cached in network.uf.
        # Only the leading rows have labels.
        h_row = h_train.shape[0]
        fd_guidance = complex_mse(self.network.uf[:h_row, :], h_train)

        # I am not sure a good way to normalize/scale a complex tensor
        if self.normalize_derivative_features:
            X_selector = (X_selector-self.mini)/(self.maxi-self.mini)

        if include_unsup: unsup_loss = self.selector.loss(X_selector, y_selector)
        else: unsup_loss = None

        if include_unsup and self.weights is not None:
            return (torch.exp(-self.weights[0])*fd_guidance)+self.weights[0], (torch.exp(-self.weights[1])*unsup_loss)+self.weights[1]
        else:
            return fd_guidance, unsup_loss

In [None]:
lets_pretrain = True

semisup_model = SemiSupModel(
    network=ComplexNetwork(model=complex_model, index2features=feature_names, scale=False, lb=lb, ub=ub),
    # Softmax with an explicit dim: F.softmax without `dim` emits an implicit-dimension
    # deprecation warning on every call. Assuming 2-D (batch, n_features) inputs, the implicit
    # choice was dim=1, which equals dim=-1.
    selector=ComplexAttentionSelectorNetwork([len(feature_names), 50, 50, 1], prob_activation=nn.Softmax(dim=-1), bn=True),
    normalize_derivative_features=True,
    # Placeholders from the finite-difference features; overwritten after pretraining when lets_pretrain is True.
    mini=torch.tensor(np.abs(derivatives).min(axis=0), dtype=torch.cfloat),
    maxi=torch.tensor(np.abs(derivatives).max(axis=0), dtype=torch.cfloat),
    uncert=False,
)

#### Pretraining the solver network

In [None]:
if lets_pretrain:
    def pretraining_closure():
        """LBFGSNew closure: complex MSE of the solver on the N supervised rows of X_train."""
        # Clear gradients only when autograd is active. Note: torch.is_grad_enabled(), not
        # torch.enable_grad() — the latter is a context-manager object and is always truthy.
        if torch.is_grad_enabled(): pretraining_optimizer.zero_grad()
        # Only focusing on first [:N, :] elements; the remaining rows are unsupervised points
        mse_loss = complex_mse(semisup_model.network(*dimension_slicing(X_train[:N, :])), h_train[:N, :])
        if mse_loss.requires_grad: mse_loss.backward(retain_graph=False)
        return mse_loss

    print("Pretraining")
    pretraining_optimizer = LBFGSNew(semisup_model.network.parameters(),
                                     lr=1e-1, max_iter=300,
                                     max_eval=int(300*1.25), history_size=150,
                                     line_search_fn=True, batch_mode=False)

    semisup_model.network.train()
    for i in range(120):
        pretraining_optimizer.step(pretraining_closure)

        if (i%10)==0:
            # Monitoring only: no_grad avoids an extra backward pass and autograd graph.
            with torch.no_grad():
                curr_loss = pretraining_closure().item()
                print("Epoch {}: ".format(i), curr_loss)

                # See how well the model performs on the test set
                semisup_model.network.eval()
                test_performance = complex_mse(semisup_model.network(*dimension_slicing(X_star)), h_star).item()
            print('Test MSE:', scientific2string(test_performance))
            # Switch back; otherwise every later optimisation step would run in eval mode.
            semisup_model.network.train()

    print("Computing derivatives features")
    semisup_model.eval()
    referenced_derivatives, u_t = semisup_model.network.get_selector_data(*dimension_slicing(X_train))
    feature_scale = referenced_derivatives.detach()
    if torch.is_complex(feature_scale):
        # torch.min/torch.max reject complex input; use per-feature moduli, consistent with
        # the np.abs-based placeholder mini/maxi passed to SemiSupModel.
        feature_scale = feature_scale.abs()
    semisup_model.mini = feature_scale.min(dim=0)[0].to(referenced_derivatives.dtype).requires_grad_(False)
    semisup_model.maxi = feature_scale.max(dim=0)[0].to(referenced_derivatives.dtype).requires_grad_(False)