In [1]:
from multitaskpinn.utils.logger import Logger
from multitaskpinn.training.convergence import Convergence


# %% General imports
import numpy as np
import torch

# DeepMoD stuff
from multitaskpinn import DeepMoD
from multitaskpinn.model.func_approx import NN
from multitaskpinn.model.library import Library1D
from multitaskpinn.model.constraint import LeastSquares
from multitaskpinn.model.sparse_estimators import Threshold
from multitaskpinn.training.sparsity_scheduler import TrainTestPeriodic
from multitaskpinn.training import train_multitask

from multitaskpinn.data import Dataset
from multitaskpinn.data.kdv import DoubleSoliton

%load_ext autoreload
%autoreload 2

In [2]:
# Settings
# Runs are forced onto the CPU. The CUDA check is kept so that setting
# FORCE_CPU = False switches to the GPU when one is available.
FORCE_CPU = True
if torch.cuda.is_available() and not FORCE_CPU:
    device = 'cuda'
else:
    device = 'cpu'

# Settings for reproducibility: seed both numpy and torch, and make cuDNN
# deterministic (benchmark mode picks algorithms non-deterministically).
np.random.seed(42)
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


In [92]:
def train_mixed_multitask(model: DeepMoD,
          data: torch.Tensor,
          target: torch.Tensor,
          optimizer,
          sparsity_scheduler,
          split = 0.8,
          exp_ID: str = None,
          log_dir: str = None,
          max_iterations: int = 10000,
          write_iterations: int = 25,
          **convergence_kwargs) -> None:
    """Trains the DeepMoD model on a log-MSE loss. This function automatically splits the data set in a train and test set.

    The optimised loss is ``sum(log(MSE))`` over the outputs. The equation
    residual ``Reg`` is computed every iteration for logging only and does not
    enter the loss.

    Args:
        model (DeepMoD):  A DeepMoD object.
        data (torch.Tensor):  Tensor of shape (n_samples x (n_spatial + 1)) containing the coordinates, first column should be the time coordinate.
        target (torch.Tensor): Tensor of shape (n_samples x n_features) containing the target data.
        optimizer ([type]):  Pytorch optimizer.
        sparsity_scheduler ([type]):  Decides when to update the sparsity mask.
        split (float, optional):  Fraction of the train set, by default 0.8.
        exp_ID (str, optional): Unique ID to identify tensorboard file. Not used if log_dir is given, see pytorch documentation.
        log_dir (str, optional): Directory where tensorboard file is written, by default None.
        max_iterations (int, optional): Max number of epochs, by default 10000.
        write_iterations (int, optional): Sets how often data is written to tensorboard and checks train loss, by default 25.
        **convergence_kwargs: Forwarded to ``Convergence`` (e.g. ``patience``, ``delta``).
    """
    logger = Logger(exp_ID, log_dir)
    sparsity_scheduler.path = logger.log_dir # write checkpoint to same folder as tb output.

    # Splitting data, assumes data is already randomized
    n_train = int(split * data.shape[0])
    n_test = data.shape[0] - n_train
    data_train, data_test = torch.split(data, [n_train, n_test], dim=0)
    target_train, target_test = torch.split(target, [n_train, n_test], dim=0)

    # Training
    convergence = Convergence(**convergence_kwargs)
    for iteration in torch.arange(0, max_iterations):
        # ================== Training Model ============================
        prediction, time_derivs, thetas = model(data_train)

        MSE = torch.mean((prediction - target_train)**2, dim=0)  # loss per output
        # One residual entry per (time derivative, library) pair; logged only.
        Reg = torch.stack([torch.mean((dt - theta @ coeff_vector)**2)
                           for dt, theta, coeff_vector in zip(time_derivs, thetas, model.constraint_coeffs(scaled=False, sparse=True))])

        # Only the data-fit term is optimised; log(Reg) is not part of the loss.
        loss = torch.sum(torch.log(MSE))

        # Optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if iteration % write_iterations == 0:
            # ================== Validation costs ================
            with torch.no_grad():
                prediction_test = model.func_approx(data_test)[0]
                MSE_test = torch.mean((prediction_test - target_test)**2, dim=0)  # loss per output

            # ====================== Logging =======================
            _ = model.sparse_estimator(thetas, time_derivs) # calculating estimator coeffs but not setting mask
            logger(iteration,
                   loss, MSE, Reg,
                   model.constraint_coeffs(sparse=True, scaled=True),
                   model.constraint_coeffs(sparse=True, scaled=False),
                   model.estimator_coeffs(),
                   MSE_test=MSE_test)

            # ================== Sparsity update =============
            # The scheduler decides from the summed test MSE whether to refresh the mask.
            update_sparsity = sparsity_scheduler(iteration, torch.sum(MSE_test), model, optimizer)
            if update_sparsity:
                model.constraint.sparsity_masks = model.sparse_estimator(thetas, time_derivs)

            # ================= Checking convergence
            # Convergence is judged on the L1 norm of the scaled, sparse coefficients.
            l1_norm = torch.sum(torch.abs(torch.cat(model.constraint_coeffs(sparse=True, scaled=True), dim=1)))
            converged = convergence(iteration, l1_norm)
            if converged:
                break
    logger.close(model)

In [93]:
# %% making data

# Coordinate grid: 50 spatial points on [-10, 10] and 20 time points on [0, 2].
# indexing='ij' makes both x_grid and t_grid shape (50, 20).
x = np.linspace(-10, 10, 50)
t = np.linspace(0.0, 2.0, 20)
x_grid, t_grid = np.meshgrid(x, t, indexing='ij')

# Double-soliton KdV dataset. NOTE(review): c and x0 are presumably the two
# solitons' speeds and initial positions — confirm against multitaskpinn.data.kdv.
dataset = Dataset(DoubleSoliton, c=[5.0, 1.0], x0=[-5.0, -1.0])
# The grids are flattened to column vectors before sampling.
# NOTE(review): presumably n_samples=0 keeps every grid point, noise=0.20 is a
# relative noise level, and random=True shuffles the rows (the training function
# assumes shuffled data when splitting) — confirm against Dataset.create_dataset.
# If the noise is drawn from the global RNGs, this cell's output depends on the
# seeds set in the settings cell and on running cells in order.
X, y = dataset.create_dataset(x_grid.reshape(-1, 1), t_grid.reshape(-1, 1), n_samples=0, noise=0.20, random=True, normalize=False)
X, y = X.to(device), y.to(device)


In [94]:
# Sparsity scheduler; periodicity is counted in write iterations, not raw steps.
# NOTE(review): patience=1e8 presumably means the scheduler's own patience never runs out.
sparsity_scheduler = TrainTestPeriodic(periodicity=25, patience=1e8, delta=1e-5) # in terms of write iterations
# NOTE(review): presumably 2 inputs (t, x), five hidden layers of 30 units, 1 output — confirm NN signature.
# If NN draws its initial weights from torch's global RNG, construction order here matters for reproducibility.
network = NN(2, [30, 30, 30, 30, 30], 1)
library = Library1D(poly_order=2, diff_order=3) # Library function
estimator = Threshold(0.1) # Sparse estimator
constraint = LeastSquares() # How to constrain
model = DeepMoD(network, library, estimator, constraint, 1).to(device) # Putting it all in the model
# Adam with AMSGrad and an unusually high first-moment decay (beta1 = 0.99).
optimizer = torch.optim.Adam(model.parameters(), betas=(0.99, 0.99), amsgrad=True, lr=2e-3) # Defining optimizer

In [95]:
# Fix the sparsity pattern by hand: only library terms 3 and 5 stay active.
# NOTE(review): 12 presumably equals the Library1D size for poly_order=2,
# diff_order=3, i.e. (2 + 1) * (3 + 1); confirm which terms indices 3 and 5 are.
kept_terms = torch.zeros(12, dtype=torch.bool, device=device)
kept_terms[[3, 5]] = True
model.constraint.sparsity_masks = [kept_terms]

In [96]:
# Train with the log-MSE loss; delta/patience are forwarded to Convergence.
train_mixed_multitask(model, X, y, optimizer, sparsity_scheduler, exp_ID='mixed', split=0.8, write_iterations=50, max_iterations=100000, delta=1e-8, patience=200)

   300  MSE: 2.91e-01  Reg: 5.28e-11  L1: 1.02e+00 

KeyboardInterrupt: 

In [106]:
def train_mixed_multitask(model: DeepMoD,
          data: torch.Tensor,
          target: torch.Tensor,
          optimizer,
          sparsity_scheduler,
          split = 0.8,
          exp_ID: str = None,
          log_dir: str = None,
          max_iterations: int = 10000,
          write_iterations: int = 25,
          **convergence_kwargs) -> None:
    """Trains the DeepMoD model with a learned-precision loss. This function automatically splits the data set in a train and test set.

    With ``mse = prediction - target`` and ``reg = time_derivs[0] - thetas[0] @ coeffs[0]``
    (only the first output enters the loss), the optimised quantity is::

        tau * mse.T @ mse + beta * reg.T @ reg
        + m * sqrt(mse.T @ mse * reg.T @ reg)
        - n_train * log(tau * beta)

    where ``tau = exp(model.t)`` and ``beta = exp(model.b)`` are both capped at 1e8,
    and ``m = 1e9`` is a fixed coupling. ``model.t``, ``model.b`` and ``model.a`` are
    re-initialised at the start of every call; ``model.a`` does not enter the loss.

    Note: this redefines the earlier ``train_mixed_multitask`` cell, which is
    shadowed from here on.

    Args:
        model (DeepMoD):  A DeepMoD object.
        data (torch.Tensor):  Tensor of shape (n_samples x (n_spatial + 1)) containing the coordinates, first column should be the time coordinate.
        target (torch.Tensor): Tensor of shape (n_samples x n_features) containing the target data.
        optimizer ([type]):  Pytorch optimizer.
        sparsity_scheduler ([type]):  Decides when to update the sparsity mask.
        split (float, optional):  Fraction of the train set, by default 0.8.
        exp_ID (str, optional): Unique ID to identify tensorboard file. Not used if log_dir is given, see pytorch documentation.
        log_dir (str, optional): Directory where tensorboard file is written, by default None.
        max_iterations (int, optional): Max number of epochs, by default 10000.
        write_iterations (int, optional): Sets how often data is written to tensorboard and checks train loss, by default 25.
        **convergence_kwargs: Forwarded to ``Convergence`` (e.g. ``patience``, ``delta``).
    """
    logger = Logger(exp_ID, log_dir)
    sparsity_scheduler.path = logger.log_dir # write checkpoint to same folder as tb output.

    # Splitting data, assumes data is already randomized
    n_train = int(split * data.shape[0])
    n_test = data.shape[0] - n_train
    data_train, data_test = torch.split(data, [n_train, n_test], dim=0)
    target_train, target_test = torch.split(target, [n_train, n_test], dim=0)

    # Initialise the log-precisions from the training targets only, so the
    # held-out split does not leak into the model.
    # NOTE(review): exp(t) is used as a precision, so -log(var) would be the
    # more natural starting value; -var is kept to match the experiment.
    model.t.data = -torch.var(target_train)
    model.b.data = -torch.var(target_train)
    # Build on the parameter's own device so reassigning .data cannot move it to the CPU.
    model.a.data = torch.tensor(0.0, device=model.a.device)

    # Fixed coupling strength of the cross term (loop-invariant; logged as `m`).
    m = torch.tensor(10**9)

    # Training
    convergence = Convergence(**convergence_kwargs)
    for iteration in torch.arange(0, max_iterations):
        # ================== Training Model ============================
        prediction, time_derivs, thetas = model(data_train)

        # The logs of the precisions are trained because the precisions get very large;
        # both are capped to prevent overflow.
        tau_ = torch.exp(model.t).clamp(max=1e8)
        beta_ = torch.exp(model.b).clamp(max=1e8)

        mse = prediction - target_train
        reg = time_derivs[0] - thetas[0] @ model.constraint_coeffs(scaled=False, sparse=True)[0]
        n_samples = mse.shape[0]

        loss = (tau_ * mse.T @ mse
                + beta_  * reg.T @ reg
                + m * torch.sqrt(mse.T @ mse * reg.T @ reg)
                - n_samples * torch.log(tau_ * beta_))

        # Per-output summaries, used for logging only.
        MSE = torch.mean((prediction - target_train)**2, dim=0)  # loss per output
        Reg = torch.stack([torch.mean((dt - theta @ coeff_vector)**2)
                           for dt, theta, coeff_vector in zip(time_derivs, thetas, model.constraint_coeffs(scaled=False, sparse=True))])

        # Optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if iteration % write_iterations == 0:
            # ================== Validation costs ================
            with torch.no_grad():
                prediction_test = model.func_approx(data_test)[0]
                MSE_test = torch.mean((prediction_test - target_test)**2, dim=0)  # loss per output

            # ====================== Logging =======================
            _ = model.sparse_estimator(thetas, time_derivs) # calculating estimator coeffs but not setting mask
            logger(iteration,
                   loss, MSE, Reg,
                   model.constraint_coeffs(sparse=True, scaled=True),
                   model.constraint_coeffs(sparse=True, scaled=False),
                   model.estimator_coeffs(),
                   MSE_test=MSE_test,
                   tau=tau_,
                   beta_=beta_,
                   m = m)

            # ================== Sparsity update =============
            # The scheduler decides from the summed test MSE whether to refresh the mask.
            update_sparsity = sparsity_scheduler(iteration, torch.sum(MSE_test), model, optimizer)
            if update_sparsity:
                model.constraint.sparsity_masks = model.sparse_estimator(thetas, time_derivs)

            # ================= Checking convergence
            # Convergence is judged on the L1 norm of the scaled, sparse coefficients.
            l1_norm = torch.sum(torch.abs(torch.cat(model.constraint_coeffs(sparse=True, scaled=True), dim=1)))
            converged = convergence(iteration, l1_norm)
            if converged:
                break
    logger.close(model)

In [107]:
# Rebuild scheduler, network and optimizer from scratch for this run.
# Sparsity scheduler; periodicity is counted in write iterations, not raw steps.
# NOTE(review): patience=1e8 presumably means the scheduler's own patience never runs out.
sparsity_scheduler = TrainTestPeriodic(periodicity=25, patience=1e8, delta=1e-5) # in terms of write iterations
# NOTE(review): presumably 2 inputs (t, x), five hidden layers of 30 units, 1 output — confirm NN signature.
network = NN(2, [30, 30, 30, 30, 30], 1)
library = Library1D(poly_order=2, diff_order=3) # Library function
estimator = Threshold(0.1) # Sparse estimator
constraint = LeastSquares() # How to constrain
model = DeepMoD(network, library, estimator, constraint, 1).to(device) # Putting it all in the model
# Adam with AMSGrad and an unusually high first-moment decay (beta1 = 0.99).
optimizer = torch.optim.Adam(model.parameters(), betas=(0.99, 0.99), amsgrad=True, lr=2e-3) # Defining optimizer

In [108]:
# Fix the sparsity pattern by hand: only library terms 3 and 5 stay active.
# NOTE(review): 12 presumably equals the Library1D size for poly_order=2,
# diff_order=3, i.e. (2 + 1) * (3 + 1); confirm which terms indices 3 and 5 are.
kept_terms = torch.zeros(12, dtype=torch.bool, device=device)
kept_terms[[3, 5]] = True
model.constraint.sparsity_masks = [kept_terms]

In [109]:
# Train with the learned-precision loss; delta/patience are forwarded to Convergence.
train_mixed_multitask(model, X, y, optimizer, sparsity_scheduler, exp_ID='max', split=0.8, write_iterations=50, max_iterations=100000, delta=1e-8, patience=200)

  3300  MSE: 2.94e-01  Reg: 1.10e-13  L1: 1.09e+00 

KeyboardInterrupt: 