### Simple example for performing symbolic regression for a set of points

In [1]:
from nesymres.architectures.model import Model
from nesymres.utils import load_metadata_hdf5
from nesymres.dclasses import FitParams, NNEquation, BFGSParams
from pathlib import Path
from functools import partial
import torch
from sympy import lambdify
import json
import numpy as np

In [2]:
## Read the architecture/inference configuration (YAML) and the equation
## settings (JSON); both files are expected under the 100M/ directory.
import omegaconf

cfg = omegaconf.OmegaConf.load("100M/config.yaml")

with open('100M/eq_setting.json', 'r') as settings_file:
  eq_setting = json.load(settings_file)

In [3]:
## Set up the BFGS parameters, loaded from the hydra config yaml
bfgs_cfg = cfg.inference.bfgs  # every BFGS option is read from this config node
bfgs = BFGSParams(
    activated=bfgs_cfg.activated,
    n_restarts=bfgs_cfg.n_restarts,
    add_coefficients_if_not_existing=bfgs_cfg.add_coefficients_if_not_existing,
    normalization_o=bfgs_cfg.normalization_o,
    idx_remove=bfgs_cfg.idx_remove,
    normalization_type=bfgs_cfg.normalization_type,
    stop_time=bfgs_cfg.stop_time,
)

In [4]:
# eq_setting was parsed from JSON, where object keys are always strings, so the
# id -> word table is re-keyed with integer ids before being handed to FitParams.
id2word = {int(token_id): word for token_id, word in eq_setting["id2word"].items()}

# Bundle the vocabulary, operator lists, BFGS settings and beam size used for fitting.
params_fit = FitParams(
    word2id=eq_setting["word2id"],
    id2word=id2word,
    una_ops=eq_setting["una_ops"],
    bin_ops=eq_setting["bin_ops"],
    total_variables=list(eq_setting["total_variables"]),
    total_coefficients=list(eq_setting["total_coefficients"]),
    rewrite_functions=list(eq_setting["rewrite_functions"]),
    bfgs=bfgs,
    beam_size=cfg.inference.beam_size,  # This parameter is a tradeoff between accuracy and fitting time
)

In [5]:
# Location of the model checkpoint, given relative to the current working directory.
weights_path = "../weights/100M.ckpt"

In [6]:
## Restore the model from the checkpoint (passing the architecture config),
## switch it to eval mode, and move it to the GPU when CUDA is available.
use_cuda = torch.cuda.is_available()

model = Model.load_from_checkpoint(weights_path, cfg=cfg.architecture)
model.eval()
if use_cuda:
  model.cuda()

  rank_zero_warn(
Lightning automatically upgraded your loaded checkpoint from v1.3.3 to v2.0.2. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint --file ../weights/100M.ckpt`


In [7]:
# Pre-bind the fitting parameters so the fit can later be invoked with just the data, as fitfunc(X, y).
fitfunc = partial(model.fitfunc,cfg_params=params_fit)

In [8]:
# Create points from an equation
seed = 0  # fixed seed so that Restart & Run All samples the same points every time
number_of_points = 500
n_variables = 1
total_variables = list(eq_setting["total_variables"])

# To get the best results, make sure the sampled points lie inside the
# [min, max] support configured in cfg.dataset_train.fun_support
max_supp = cfg.dataset_train.fun_support["max"]
min_supp = cfg.dataset_train.fun_support["min"]

# Draw from a dedicated generator: reproducible without touching torch's global RNG state.
rng = torch.Generator().manual_seed(seed)
# Uniform samples in [0, 1) rescaled to [min_supp, max_supp); one column per model variable.
X = torch.rand(number_of_points, len(total_variables), generator=rng) * (max_supp - min_supp) + min_supp
X[:, n_variables:] = 0  # columns beyond the first n_variables are zeroed out

target_eq = "x_1*sin(x_1)"  # Use x_1, x_2 and x_3 as independent variables
# Map each variable name to its column of X so the lambdified equation can be called by keyword.
X_dict = {name: X[:, idx].cpu() for idx, name in enumerate(total_variables)}
y = lambdify(",".join(total_variables), target_eq)(**X_dict)

In [9]:
# Report the shapes of the sampled inputs and of the evaluated targets.
for label, values in (("X shape: ", X), ("y shape: ", y)):
  print(label, values.shape)

X shape:  torch.Size([500, 3])
y shape:  torch.Size([500])


In [10]:
# Display the sampled inputs; with n_variables = 1 only the first column is non-zero.
X

tensor([[ 7.6582,  0.0000,  0.0000],
        [ 7.7311,  0.0000,  0.0000],
        [ 5.7335,  0.0000,  0.0000],
        ...,
        [-6.0958,  0.0000,  0.0000],
        [-5.1852,  0.0000,  0.0000],
        [-1.1132,  0.0000,  0.0000]])

In [11]:
# Run the fit on the sampled points; fitfunc already carries params_fit via functools.partial.
output = fitfunc(X,y) 

Memory footprint of the encoder: 4.096e-05GB 



  X = torch.tensor(X,device=self.device).unsqueeze(0)
  y = torch.tensor(y,device=self.device).unsqueeze(0)


Constructing BFGS loss...
Flag idx remove ON, Removing indeces with high values...
checking input values range...
Loss constructed, starting new BFGS optmization...
Constructing BFGS loss...
Flag idx remove ON, Removing indeces with high values...
checking input values range...
Loss constructed, starting new BFGS optmization...


In [12]:
# Display the fit result: the candidate expressions with their losses ('all_bfgs_*')
# and the selected best expression with its loss ('best_bfgs_*').
output

{'all_bfgs_preds': ['((x_1)*(sin(x_1)))', '((x_1)*((cos(x_1))*(tan(x_1))))'],
 'all_bfgs_loss': [0.0, 6.911328e-14],
 'best_bfgs_preds': ['((x_1)*(sin(x_1)))'],
 'best_bfgs_loss': [0.0]}