## Function Fitting Task


In [1]:
import numpy as np
import torch
from sklearn.linear_model import LinearRegression
import sympy
import yaml
from sympy.utilities.lambdify import lambdify
import re

# sigmoid = sympy.Function('sigmoid')
# name: (torch implementation, sympy implementation, integer score, singularity-protected torch implementation)

# Singularity-protection functions.
#
# Every function takes the input tensor ``x`` and a threshold ``y_th`` and
# returns a 2-tuple ``(aux, value)``:
#   * ``aux``   -- the intermediate threshold quantity/quantities computed from
#                  ``y_th`` (``()`` when there are none);
#   * ``value`` -- the function evaluated on ``x`` with the region around its
#                  singularity replaced by a surrogate.
# The two regimes are combined by multiplying with boolean masks, so the
# "exact" branch is wrapped in ``torch.nan_to_num``: otherwise an inf/nan
# produced inside the masked-out region would survive as ``nan`` (inf * 0).
# ``y_th`` may be a Python number or a tensor.

# 1/x**k: for |x| < x_th (where |1/x**k| would exceed y_th) odd powers use a
# straight line through the origin that reaches y_th at x_th; even powers are
# clamped to the constant y_th.
f_inv = lambda x, y_th: ((x_th := 1/y_th), y_th/x_th*x * (torch.abs(x) < x_th) + torch.nan_to_num(1/x) * (torch.abs(x) >= x_th))
f_inv2 = lambda x, y_th: ((x_th := 1/y_th**(1/2)), y_th * (torch.abs(x) < x_th) + torch.nan_to_num(1/x**2) * (torch.abs(x) >= x_th))
f_inv3 = lambda x, y_th: ((x_th := 1/y_th**(1/3)), y_th/x_th*x * (torch.abs(x) < x_th) + torch.nan_to_num(1/x**3) * (torch.abs(x) >= x_th))
f_inv4 = lambda x, y_th: ((x_th := 1/y_th**(1/4)), y_th * (torch.abs(x) < x_th) + torch.nan_to_num(1/x**4) * (torch.abs(x) >= x_th))
f_inv5 = lambda x, y_th: ((x_th := 1/y_th**(1/5)), y_th/x_th*x * (torch.abs(x) < x_th) + torch.nan_to_num(1/x**5) * (torch.abs(x) >= x_th))
# sqrt extended to negative inputs as sign(x)*sqrt(|x|), with a linear segment for |x| < x_th.
f_sqrt = lambda x, y_th: ((x_th := 1/y_th**2), x_th/y_th*x * (torch.abs(x) < x_th) + torch.nan_to_num(torch.sqrt(torch.abs(x))*torch.sign(x)) * (torch.abs(x) >= x_th))
# |x|**1.5 needs no protection (y_th is unused); it still returns the same
# (aux, value) tuple as every other entry so callers can index it uniformly.
f_power1d5 = lambda x, y_th: ((), torch.abs(x)**1.5)
# 1/sqrt(|x|), clamped to y_th for |x| < x_th.
f_invsqrt = lambda x, y_th: ((x_th := 1/y_th**2), y_th * (torch.abs(x) < x_th) + torch.nan_to_num(1/torch.sqrt(torch.abs(x))) * (torch.abs(x) >= x_th))
# log|x|, clamped to -y_th for |x| < exp(-y_th).
f_log = lambda x, y_th: ((x_th := torch.e**(-y_th)), - y_th * (torch.abs(x) < x_th) + torch.nan_to_num(torch.log(torch.abs(x))) * (torch.abs(x) >= x_th))
# tan on x mod pi; within delta of the pole at pi/2 it is replaced by a line
# with slope -y_th/delta.  Both helper values are packed into ``aux`` so the
# result is a 2-tuple like all the others.
f_tan = lambda x, y_th: (((clip := x % torch.pi), (delta := torch.pi/2-torch.arctan(torch.as_tensor(y_th)))), - y_th/delta * (clip - torch.pi/2) * (torch.abs(clip - torch.pi/2) < delta) + torch.nan_to_num(torch.tan(clip)) * (torch.abs(clip - torch.pi/2) >= delta))
# arctanh, saturated to +-y_th once |x| > 1 - delta.
f_arctanh = lambda x, y_th: ((delta := 1-torch.tanh(torch.as_tensor(y_th)) + 1e-4), y_th * torch.sign(x) * (torch.abs(x) > 1 - delta) + torch.nan_to_num(torch.arctanh(x)) * (torch.abs(x) <= 1 - delta))
# arcsin / arccos, held at their end-point values outside [-1, 1] (y_th is unused).
f_arcsin = lambda x, y_th: ((), torch.pi/2 * torch.sign(x) * (torch.abs(x) > 1) + torch.nan_to_num(torch.arcsin(x)) * (torch.abs(x) <= 1))
f_arccos = lambda x, y_th: ((), torch.pi/2 * (1-torch.sign(x)) * (torch.abs(x) > 1) + torch.nan_to_num(torch.arccos(x)) * (torch.abs(x) <= 1))
# exp, clamped to y_th for x > log(y_th); nan_to_num keeps an overflowing
# exp(x) in the masked-out region from turning the result into nan.
f_exp = lambda x, y_th: ((x_th := torch.log(torch.as_tensor(y_th))), y_th * (x > x_th) + torch.nan_to_num(torch.exp(x)) * (x <= x_th))

# Library of named 1D symbolic functions.
#
# Each value is a 4-tuple:
#   [0] torch implementation, called as fun(x)
#   [1] sympy implementation, called as fun(x)
#   [2] an integer attached to the entry (presumably a complexity score used
#       when ranking candidate functions -- TODO confirm against the consumer)
#   [3] singularity-protected torch implementation, called as fun(x, y_th);
#       the entries defined inline below return ((), value)
# Entry order is part of the data: dicts iterate in insertion order.
SYMBOLIC_LIB = {'x': (lambda x: x, lambda x: x, 1, lambda x, y_th: ((), x)),
                 'x^2': (lambda x: x**2, lambda x: x**2, 2, lambda x, y_th: ((), x**2)),
                 'x^3': (lambda x: x**3, lambda x: x**3, 3, lambda x, y_th: ((), x**3)),
                 # NOTE(review): x^4 and x^5 carry 3 while 1/x^4 and 1/x^5 carry 4 and 5 -- confirm this is intended.
                 'x^4': (lambda x: x**4, lambda x: x**4, 3, lambda x, y_th: ((), x**4)),
                 'x^5': (lambda x: x**5, lambda x: x**5, 3, lambda x, y_th: ((), x**5)),
                 '1/x': (lambda x: 1/x, lambda x: 1/x, 2, f_inv),
                 '1/x^2': (lambda x: 1/x**2, lambda x: 1/x**2, 2, f_inv2),
                 '1/x^3': (lambda x: 1/x**3, lambda x: 1/x**3, 3, f_inv3),
                 '1/x^4': (lambda x: 1/x**4, lambda x: 1/x**4, 4, f_inv4),
                 '1/x^5': (lambda x: 1/x**5, lambda x: 1/x**5, 5, f_inv5),
                 # 'sqrt' and 'x^0.5' are two names for the same entry.
                 'sqrt': (lambda x: torch.sqrt(x), lambda x: sympy.sqrt(x), 2, f_sqrt),
                 'x^0.5': (lambda x: torch.sqrt(x), lambda x: sympy.sqrt(x), 2, f_sqrt),
                 'x^1.5': (lambda x: torch.sqrt(x)**3, lambda x: sympy.sqrt(x)**3, 4, f_power1d5),
                 # '1/sqrt(x)' and '1/x^0.5' are two names for the same entry.
                 '1/sqrt(x)': (lambda x: 1/torch.sqrt(x), lambda x: 1/sympy.sqrt(x), 2, f_invsqrt),
                 '1/x^0.5': (lambda x: 1/torch.sqrt(x), lambda x: 1/sympy.sqrt(x), 2, f_invsqrt),
                 'exp': (lambda x: torch.exp(x), lambda x: sympy.exp(x), 2, f_exp),
                 'log': (lambda x: torch.log(x), lambda x: sympy.log(x), 2, f_log),
                 'abs': (lambda x: torch.abs(x), lambda x: sympy.Abs(x), 3, lambda x, y_th: ((), torch.abs(x))),
                 'sin': (lambda x: torch.sin(x), lambda x: sympy.sin(x), 2, lambda x, y_th: ((), torch.sin(x))),
                 'cos': (lambda x: torch.cos(x), lambda x: sympy.cos(x), 2, lambda x, y_th: ((), torch.cos(x))),
                 'tan': (lambda x: torch.tan(x), lambda x: sympy.tan(x), 3, f_tan),
                 'tanh': (lambda x: torch.tanh(x), lambda x: sympy.tanh(x), 3, lambda x, y_th: ((), torch.tanh(x))),
                 'sgn': (lambda x: torch.sign(x), lambda x: sympy.sign(x), 3, lambda x, y_th: ((), torch.sign(x))),
                 'arcsin': (lambda x: torch.arcsin(x), lambda x: sympy.asin(x), 4, f_arcsin),
                 'arccos': (lambda x: torch.arccos(x), lambda x: sympy.acos(x), 4, f_arccos),
                 'arctan': (lambda x: torch.arctan(x), lambda x: sympy.atan(x), 4, lambda x, y_th: ((), torch.arctan(x))),
                 'arctanh': (lambda x: torch.arctanh(x), lambda x: sympy.atanh(x), 4, f_arctanh),
                 # The zero function is written as x*0 so the output keeps the shape of x.
                 '0': (lambda x: x*0, lambda x: x*0, 0, lambda x, y_th: ((), x*0)),
                 'gaussian': (lambda x: torch.exp(-x**2), lambda x: sympy.exp(-x**2), 3, lambda x, y_th: ((), torch.exp(-x**2))),
                 # Disabled entries (old shorter tuple layout, no singularity-protected member):
                 #'cosh': (lambda x: torch.cosh(x), lambda x: sympy.cosh(x), 5),
                 #'sigmoid': (lambda x: torch.sigmoid(x), sympy.Function('sigmoid'), 4),
                 #'relu': (lambda x: torch.relu(x), relu),
}

def create_dataset(f, 
                   n_var=2, 
                   f_mode = 'col',
                   ranges = [-1,1],
                   train_num=1000, 
                   test_num=1000,
                   normalize_input=False,
                   normalize_label=False,
                   device='cpu',
                   seed=0):
    '''
    create a synthetic dataset by sampling inputs uniformly and labelling them with f
    
    Args:
    -----
        f : function
            the symbolic formula used to create the synthetic dataset
        n_var : int
            the number of input variables. Default: 2.
        f_mode : str
            'col' calls f on inputs of shape (num, n_var); 'row' calls f on the
            transposed inputs of shape (n_var, num). Default: 'col'.
        ranges : list or np.array; shape (2,) or (n_var, 2)
            the range of input variables. A single (low, high) pair is shared by
            all variables. Default: [-1,1].
        train_num : int
            the number of training samples. Default: 1000.
        test_num : int
            the number of test samples. Default: 1000.
        normalize_input : bool
            If True, standardize inputs with the training mean/std. Default: False.
        normalize_label : bool
            If True, standardize labels with the training mean/std. Default: False.
        device : str
            device. Default: 'cpu'.
        seed : int
            random seed (applied to both numpy and torch global RNGs). Default: 0.
        
    Returns:
    --------
        dataset : dic
            Train/test inputs/labels are dataset['train_input'], dataset['train_label'],
                        dataset['test_input'], dataset['test_label']

    Raises:
    -------
        ValueError
            if f_mode is neither 'col' nor 'row'
         
    Example
    -------
    >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
    >>> dataset = create_dataset(f, n_var=2, train_num=100)
    >>> dataset['train_input'].shape
    torch.Size([100, 2])
    '''
    # Fail fast: previously an unknown mode only printed a message and then
    # crashed further down on an unbound local.
    if f_mode not in ('col', 'row'):
        raise ValueError(f'f_mode {f_mode} not recognized')

    np.random.seed(seed)
    torch.manual_seed(seed)

    # `ranges` is only read, never mutated, so the list default is safe.
    # np.tile repeats the (low, high) pair for lists AND 1-D arrays
    # (`array * n_var` would multiply element-wise instead of repeating).
    ranges = np.asarray(ranges)
    if ranges.ndim == 1:
        ranges = np.tile(ranges, n_var).reshape(n_var, 2)

    # Sample column by column, train before test, so the random stream (and
    # therefore every dataset produced from a given seed) stays unchanged.
    train_input = torch.zeros(train_num, n_var)
    test_input = torch.zeros(test_num, n_var)
    for i in range(n_var):
        low, high = float(ranges[i, 0]), float(ranges[i, 1])
        train_input[:, i] = torch.rand(train_num,)*(high-low)+low
        test_input[:, i] = torch.rand(test_num,)*(high-low)+low

    if f_mode == 'col':
        train_label = f(train_input)
        test_label = f(test_input)
    else:  # 'row'
        train_label = f(train_input.T)
        test_label = f(test_input.T)

    # labels are always returned as 2D (num, out_dim)
    if len(train_label.shape) == 1:
        train_label = train_label.unsqueeze(dim=1)
        test_label = test_label.unsqueeze(dim=1)

    def normalize(data, mean, std):
        return (data-mean)/std

    # Statistics come from the training split only and are reused for the
    # test split.
    if normalize_input:
        mean_input = torch.mean(train_input, dim=0, keepdim=True)
        std_input = torch.std(train_input, dim=0, keepdim=True)
        train_input = normalize(train_input, mean_input, std_input)
        test_input = normalize(test_input, mean_input, std_input)

    if normalize_label:
        mean_label = torch.mean(train_label, dim=0, keepdim=True)
        std_label = torch.std(train_label, dim=0, keepdim=True)
        train_label = normalize(train_label, mean_label, std_label)
        test_label = normalize(test_label, mean_label, std_label)

    dataset = {}
    dataset['train_input'] = train_input.to(device)
    dataset['test_input'] = test_input.to(device)

    dataset['train_label'] = train_label.to(device)
    dataset['test_label'] = test_label.to(device)

    return dataset



def fit_params(x, y, fun, a_range=(-10,10), b_range=(-10,10), grid_number=101, iteration=3, verbose=True, device='cpu'):
    '''
    fit a, b, c, d such that
    
    .. math::
        |y-(cf(ax+b)+d)|^2
        
    is minimized. Both x and y are 1D array. (a, b) are found by a grid sweep that
    is repeatedly zoomed in around the best cell; (c, d) are then obtained by
    linear regression of y on f(ax+b).
    
    Args:
    -----
        x : 1D array
            x values
        y : 1D array
            y values
        fun : function
            symbolic function
        a_range : tuple
            sweeping range of a
        b_range : tuple
            sweeping range of b
        grid_number : int
            number of steps along a and b
        iteration : int
            number of zooming in
        verbose : bool
            print extra information if True
        device : str
            device
        
    Returns:
    --------
        params : 1D tensor of length 4
            best fitted (a, b, c, d), stacked in that order
        r2_best : 0D tensor
            best r2 (coefficient of determination) found on the last grid
    
    Example
    -------
    >>> num = 100
    >>> x = torch.linspace(-1,1,steps=num)
    >>> noises = torch.normal(0,1,(num,)) * 0.02
    >>> y = 5.0*torch.sin(3.0*x + 2.0) + 0.7 + noises
    >>> fit_params(x, y, torch.sin)
    r2 is 0.9999727010726929
    (tensor([2.9982, 1.9996, 5.0053, 0.7011]), tensor(1.0000))
    '''
    last = grid_number - 1

    # y does not change between sweeps, so its centered form and energy are shared
    y_centered = y - torch.mean(y, dim=[0], keepdim=True)
    y_energy = torch.sum(y_centered[:,None,None]**2, dim=0)

    for sweep in range(iteration):
        a_ = torch.linspace(a_range[0], a_range[1], steps=grid_number, device=device)
        b_ = torch.linspace(b_range[0], b_range[1], steps=grid_number, device=device)
        a_grid, b_grid = torch.meshgrid(a_, b_, indexing='ij')

        # evaluate fun on every (sample, a, b) combination: shape (num, grid, grid)
        post_fun = fun(a_grid[None,:,:] * x[:,None,None] + b_grid[None,:,:])
        fun_centered = post_fun - torch.mean(post_fun, dim=[0], keepdim=True)

        # squared correlation between fun(ax+b) and y for each grid cell;
        # the 1e-4 keeps the ratio finite when either variance vanishes
        numerator = torch.sum(fun_centered*y_centered[:,None,None], dim=0)**2
        denominator = torch.sum(fun_centered**2, dim=0)*y_energy
        r2 = torch.nan_to_num(numerator/(denominator+1e-4))

        # r2 is (grid, grid) indexed [a, b]; unravel the flat argmax
        best_id = torch.argmax(r2)
        a_id = torch.div(best_id, grid_number, rounding_mode='floor')
        b_id = best_id % grid_number

        a_low, a_high = a_id == 0, a_id == last
        b_low, b_high = b_id == 0, b_id == last

        if not (a_low or a_high or b_low or b_high):
            # interior optimum: zoom both axes onto the neighbouring cells
            a_range = [a_[a_id-1], a_[a_id+1]]
            b_range = [b_[b_id-1], b_[b_id+1]]
            continue

        # optimum on the border: only the axis that hit its border is narrowed
        # (to the outermost cell); the other axis keeps its current range
        if sweep == 0 and verbose==True:
            print('Best value at boundary.')
        if a_low:
            a_range = [a_[0], a_[1]]
        if a_high:
            a_range = [a_[-2], a_[-1]]
        if b_low:
            b_range = [b_[0], b_[1]]
        if b_high:
            b_range = [b_[-2], b_[-1]]

    # best cell of the final (finest) grid
    a_best, b_best = a_[a_id], b_[b_id]
    r2_best = r2[a_id, b_id]
    post_fun = fun(a_best * x + b_best)

    if verbose == True:
        print(f"r2 is {r2_best}")
        if r2_best < 0.9:
            print('r2 is not very high, please double check if you are choosing the correct symbolic function.')

    # with (a, b) fixed, the model is linear in (c, d): solve by least squares on the CPU
    features = torch.nan_to_num(post_fun)[:,None].detach().cpu().numpy()
    targets = y.detach().cpu().numpy()
    reg = LinearRegression().fit(features, targets)
    c_best = torch.from_numpy(reg.coef_)[0].to(device)
    d_best = torch.from_numpy(np.array(reg.intercept_)).to(device)
    return torch.stack([a_best, b_best, c_best, d_best]), r2_best


def sparse_mask(in_dim, out_dim):
    '''
    get sparse mask

    Input and output units are placed at the centers of in_dim (resp. out_dim)
    equal cells of [0, 1]. Every input is connected to its nearest output and
    every output to its nearest input.

    Args:
    -----
        in_dim : int
            number of input units
        out_dim : int
            number of output units

    Returns:
    --------
        mask : 2D tensor, shape (in_dim, out_dim)
            1. where a connection exists, 0. elsewhere
    '''
    def _cell_centers(n):
        # centers of n equal-width cells covering [0, 1]
        return torch.arange(n) * 1/n + 1/(2*n)

    src_pos = _cell_centers(in_dim)
    dst_pos = _cell_centers(out_dim)

    # gap[o, i] = distance between output o and input i
    gap = torch.abs(dst_pos[:,None] - src_pos[None,:])

    mask = torch.zeros(in_dim, out_dim)
    # each input -> closest output
    mask[torch.arange(in_dim), torch.argmin(gap, dim=0)] = 1.
    # each output -> closest input
    mask[torch.argmin(gap, dim=1), torch.arange(out_dim)] = 1.

    return mask


def add_symbolic(name, fun, c=1, fun_singularity=None):
    '''
    add a symbolic function to library
    
    Args:
    -----
        name : str
            name of the function; also used as the name of the sympy function
            and of the module-level global that holds it
        fun : fun
            torch function or lambda function
        c : int
            integer stored as the third element of the library entry. Default: 1.
        fun_singularity : fun or None
            function stored as the fourth element of the library entry.
            If None, fun itself is stored. Default: None.
    
    Returns:
    --------
        None
    
    Example
    -------
    >>> print(SYMBOLIC_LIB['Bessel'])
    KeyError: 'Bessel'
    >>> add_symbolic('Bessel', torch.special.bessel_j0)
    >>> SYMBOLIC_LIB['Bessel'][1]
    Bessel
    '''
    # Register an undefined sympy function under `name` as a module global.
    # Direct assignment replaces the former exec() of an f-string, which broke
    # on (or executed code from) names containing quotes.
    sympy_fun = sympy.Function(name)
    globals()[name] = sympy_fun
    if fun_singularity is None:
        # NOTE(review): the built-in entries store a function called as
        # fun(x, y_th) that returns (aux, value); `fun` takes only x and
        # returns a tensor -- confirm the consumer copes with this fallback.
        fun_singularity = fun
    SYMBOLIC_LIB[name] = (fun, sympy_fun, c, fun_singularity)
    
  
def ex_round(ex1, n_digit):
    '''
    round every floating-point number in a sympy expression
    
    Args:
    -----
        ex1 : sympy expression
            expression to be rounded
        n_digit : int
            number of digits passed to round() for each float
        
    Returns:
    --------
        ex2 : sympy expression
            copy of ex1 with each sympy.Float substituted by its rounded value
    
    Example
    -------
    >>> from kan.utils import *
    >>> from sympy import *
    >>> input_vars = a, b = symbols('a b')
    >>> expression = 3.14534242 * exp(sin(pi*a) + b**2) - 2.32345402
    >>> ex_round(expression, 2)
    '''
    # floats are collected from the ORIGINAL expression, substitutions are
    # applied one after another to the running result
    float_nodes = (node for node in sympy.preorder_traversal(ex1)
                   if isinstance(node, sympy.Float))

    rounded = ex1
    for number in float_nodes:
        rounded = rounded.subs(number, round(number, n_digit))
    return rounded


def augment_input(orig_vars, aux_vars, x):
    '''
    augment inputs with auxiliary variables computed from the original ones
    
    Args:
    -----
        orig_vars : list of sympy symbols
            one symbol per column of x, in column order
        aux_vars : list of auxiliary symbols
            sympy expressions of orig_vars; each becomes one extra column
        x : inputs
            either a 2D tensor (num, len(orig_vars)) or a dataset dict with
            'train_input' and 'test_input' entries
        
    Returns:
    --------
        augmented inputs
            for a tensor: a new tensor whose first len(aux_vars) columns are the
            auxiliary values, followed by the original columns;
            for a dict: the same dict, with both inputs replaced in place;
            any other object is returned unchanged
    
    Example
    -------
    >>> from kan.utils import *
    >>> from sympy import *
    >>> orig_vars = a, b = symbols('a b')
    >>> aux_vars = [a + b, a * b]
    >>> x = torch.rand(100, 2)
    >>> augment_input(orig_vars, aux_vars, x).shape
    '''
    # if x is a tensor
    if isinstance(x, torch.Tensor):

        # The lambdified functions run in numpy, i.e. on host memory and
        # outside autograd: detach and move to the CPU first (a bare .numpy()
        # raises for CUDA or grad-tracking tensors). The columns are the same
        # for every auxiliary variable, so build them once.
        columns = [x[:,[i]].detach().cpu().numpy() for i in range(len(orig_vars))]

        aux_columns = []
        for aux_var in aux_vars:
            func = lambdify(orig_vars, aux_var,'numpy') # returns a numpy-ready function
            # bring the result back to x's device so the final cat is legal
            aux_columns.append(torch.from_numpy(func(*columns)).to(x.device))

        # auxiliary columns first, original columns last
        x = torch.cat(aux_columns + [x], dim=1)

    # if x is a dataset
    elif isinstance(x, dict):
        x['train_input'] = augment_input(orig_vars, aux_vars, x['train_input'])
        x['test_input'] = augment_input(orig_vars, aux_vars, x['test_input'])
        
    return x


def batch_jacobian(func, x, create_graph=False, mode='scalar'):
    '''
    jacobian of func with respect to each sample of a batch
    
    Args:
    -----
        func : function or model
            maps inputs of shape (Batch, Length) to outputs of shape (Batch, Out)
        x : inputs
            tensor of shape (Batch, Length)
        create_graph : bool
            forwarded to torch.autograd.functional.jacobian
        mode : str
            'scalar' returns the jacobian of the first output only, shape
            (Batch, Length); 'vector' returns all outputs, shape
            (Batch, Out, Length). Default: 'scalar'.
        
    Returns:
    --------
        jacobian

    Raises:
    -------
        ValueError
            if mode is neither 'scalar' nor 'vector' (this used to return None silently)
    
    Example
    -------
    >>> from kan.utils import batch_jacobian
    >>> x = torch.normal(0,1,size=(100,2))
    >>> model = lambda x: x[:,[0]] + x[:,[1]]
    >>> batch_jacobian(model, x)
    '''
    if mode not in ('scalar', 'vector'):
        raise ValueError(f"mode {mode} not recognized; expected 'scalar' or 'vector'")

    # x in shape (Batch, Length)
    # Summing over the batch gives the per-sample jacobians in one call,
    # provided each output row depends only on its own input row.
    def _func_sum(x):
        return func(x).sum(dim=0)

    # shape (Out, Batch, Length)
    jac = torch.autograd.functional.jacobian(_func_sum, x, create_graph=create_graph)
    if mode == 'scalar':
        return jac[0]
    return jac.permute(1,0,2)

def batch_hessian(model, x, create_graph=False):
    '''
    hessian of a scalar-output model with respect to each sample of a batch
    
    Args:
    -----
        model : function or model
        x : inputs
            tensor of shape (Batch, Length)
        create_graph : bool
            forwarded to the outer torch.autograd.functional.jacobian call
        
    Returns:
    --------
        hessian
            the jacobian of the batch-summed jacobian, with the batch axis
            moved to the front
    
    Example
    -------
    >>> from kan.utils import batch_hessian
    >>> x = torch.normal(0,1,size=(100,2))
    >>> model = lambda x: x[:,[0]]**2 + x[:,[1]]**2
    >>> batch_hessian(model, x)
    '''
    # x in shape (Batch, Length)
    def _summed_gradient(inputs):
        # the inner jacobian must stay in the graph so it can be differentiated again
        per_sample_grad = batch_jacobian(model, inputs, create_graph=True)
        return per_sample_grad.sum(dim=0)

    second_order = torch.autograd.functional.jacobian(_summed_gradient, x, create_graph=create_graph)
    return second_order.permute(1,0,2)


def create_dataset_from_data(inputs, labels, train_ratio=0.8, device='cpu'):
    '''
    randomly split existing data into a train/test dataset
    
    Args:
    -----
        inputs : 2D torch.float
        labels : 2D torch.float
        train_ratio : float
            the ratio of training fraction
        device : str
        
    Returns:
    --------
        dataset (dictionary)
            with keys 'train_input', 'test_input', 'train_label', 'test_label';
            the tensors are detached and moved to device
    
    Example
    -------
    >>> from kan.utils import create_dataset_from_data
    >>> x = torch.normal(0,1,size=(100,2))
    >>> y = torch.normal(0,1,size=(100,1))
    >>> dataset = create_dataset_from_data(x, y)
    >>> dataset['train_input'].shape
    '''
    n_samples = inputs.shape[0]
    n_train = int(n_samples*train_ratio)

    # training rows are drawn without replacement from numpy's global RNG;
    # the test rows are whatever is left over
    train_id = np.random.choice(n_samples, n_train, replace=False)
    test_id = list(set(np.arange(n_samples)) - set(train_id))
    row_ids = (('train', train_id), ('test', test_id))

    dataset = {}
    for kind, source in (('input', inputs), ('label', labels)):
        for split, ids in row_ids:
            dataset[f'{split}_{kind}'] = source[ids].detach().to(device)

    return dataset


def get_derivative(model, inputs, labels, derivative='hessian', loss_mode='pred', reg_metric='w', lamb=0., lamb_l1=1., lamb_entropy=0.):
    '''
    compute the jacobian/hessian of loss wrt to model parameters

    The model parameters are flattened into one vector p, a function p -> loss
    is built by writing slices of p back into a copy of the model, and that
    function is differentiated with batch_jacobian / batch_hessian.
    
    Args:
    -----
        model : model
            must provide copy(), state_dict(), parameters() and a device
            attribute (read by model2param); get_reg(...) is needed for
            loss_mode 'reg' and 'all'
        inputs : 2D torch.float
        labels : 2D torch.float
        derivative : str
            'jacobian' or 'hessian'
        loss_mode : str
            'pred' (mean squared error), 'reg' (regularization only) or
            'all' (mean squared error + lamb * regularization)
        reg_metric : str
            forwarded to model.get_reg
        lamb : float
            weight of the regularization term in loss_mode 'all'
        lamb_l1 : float
            forwarded to model.get_reg
        lamb_entropy : float
            forwarded to model.get_reg
        
    Returns:
    --------
        jacobian or hessian
    '''
    # Build a map from each state_dict key to a Python expression (as a string)
    # that addresses the same tensor on the local variable `model1`,
    # turning a numeric path component into an index, e.g. 'a.3.b' -> 'model1.a[3].b'.
    def get_mapping(model):

        mapping = {}
        # must match the local variable name `model1` used by the exec() calls below
        name = 'model1'

        keys = list(model.state_dict().keys())
        for key in keys:

            # NOTE(review): the '.' in these patterns is unescaped, so it matches
            # ANY character followed by digits, not only a literal dot -- confirm
            # the key names make this harmless.
            y = re.findall(".[0-9]+", key)
            if len(y) > 0:
                y = y[0][1:]
                x = re.split(".[0-9]+", key)
                mapping[key] = name + '.' + x[0] + '[' + y + ']' + x[1]


            # keys with an '_<digits>' component: the entry ends with ']'
            # (no suffix) and overrides the one written above, if any
            y = re.findall("_[0-9]+", key)
            if len(y) > 0:
                y = y[0][1:]
                x = re.split(".[0-9]+", key)
                mapping[key] = name + '.' + x[0] + '[' + y + ']'

        # NOTE(review): keys without any numeric component get no entry and
        # would raise KeyError in differentiable_load_state_dict -- confirm
        # every state_dict key of the model contains one.
        return mapping

    
    #model1 = copy.deepcopy(model)
    model1 = model.copy()
    mapping = get_mapping(model)
   
    # collect keys and shapes
    # NOTE(review): keys come from state_dict() but shapes from parameters();
    # the two line up only if the state_dict holds exactly the parameters, in
    # the same order -- TODO confirm for the model class.
    keys = list(model.state_dict().keys())
    shapes = []

    for params in model.parameters():
        shapes.append(params.shape)


    # turn a flattened vector to model params
    # (consecutive slices of p, one per key, reshaped to the recorded shapes)
    def param2statedict(p, keys, shapes):

        new_state_dict = {}

        start = 0
        n_group = len(keys)
        for i in range(n_group):
            shape = shapes[i]
            n_params = torch.prod(torch.tensor(shape))
            new_state_dict[keys[i]] = p[start:start+n_params].reshape(shape)
            start += n_params

        return new_state_dict
    
    # Write the tensors of state_dict onto model1 by plain assignment, so they
    # stay connected to p in the autograd graph. The exec() strings refer to
    # the local names `model1`, `state_dict` and `key`: do not rename them.
    def differentiable_load_state_dict(mapping, state_dict, model1):

        for key in keys:
            # attribute targets (not ending in an index) are deleted before
            # being re-assigned; presumably because a registered parameter
            # cannot be overwritten with a plain tensor -- TODO confirm
            if mapping[key][-1] != ']':
                exec(f"del {mapping[key]}")
            exec(f"{mapping[key]} = state_dict[key]")
            

    # input: p, output: output
    def get_param2loss_fun(inputs, labels):

        def param2loss_fun(p):

            # p arrives with a leading batch axis of size 1 (added below)
            p = p[0]
            state_dict = param2statedict(p, keys, shapes)
            # this step is non-differentiable
            #model.load_state_dict(state_dict)
            differentiable_load_state_dict(mapping, state_dict, model1)
            # every branch yields a (1,1) loss, the (Batch, Out) layout that
            # batch_jacobian / batch_hessian expect
            # NOTE(review): an unrecognized loss_mode leaves `loss` unbound
            if loss_mode == 'pred':
                pred_loss = torch.mean((model1(inputs) - labels)**2, dim=(0,1), keepdim=True)
                loss = pred_loss
            elif loss_mode == 'reg':
                reg_loss = model1.get_reg(reg_metric=reg_metric, lamb_l1=lamb_l1, lamb_entropy=lamb_entropy) * torch.ones(1,1)
                loss = reg_loss
            elif loss_mode == 'all':
                pred_loss = torch.mean((model1(inputs) - labels)**2, dim=(0,1), keepdim=True)
                reg_loss = model1.get_reg(reg_metric=reg_metric, lamb_l1=lamb_l1, lamb_entropy=lamb_entropy) * torch.ones(1,1)
                loss = pred_loss + lamb * reg_loss
            return loss

        return param2loss_fun
    
    fun = get_param2loss_fun(inputs, labels)    
    # flattened parameters as a batch of one: shape (1, n_params)
    p = model2param(model)[None,:]
    # NOTE(review): any other value of `derivative` leaves `result` unbound
    if derivative == 'hessian':
        result = batch_hessian(fun, p)
    elif derivative == 'jacobian':
        result = batch_jacobian(fun, p)
    return result

def model2param(model):
    '''
    turn model parameters into a flattened vector

    Args:
    -----
        model : model
            must provide parameters() and a device attribute

    Returns:
    --------
        p : 1D tensor
            all parameters, each flattened, concatenated in parameters() order
    '''
    # the leading empty tensor keeps the result well-defined (and on
    # model.device) even when the model has no parameters
    pieces = [torch.tensor([]).to(model.device)]
    pieces.extend(weights.reshape(-1,) for weights in model.parameters())
    return torch.cat(pieces, dim=0)

from sympy import *
import torch


def get_feynman_dataset(name):
    
    global symbols
    
    tpi = torch.tensor(torch.pi)
    
    if name == 'test':
        symbol = x, y = symbols('x, y')
        expr = (x+y) * sin(exp(2*y))
        f = lambda x: (x[:,[0]] + x[:,[1]])*torch.sin(torch.exp(2*x[:,[1]]))
        ranges = [-1,1]
    
    if name == 'I.6.20a' or name == 1:
        symbol = theta = symbols('theta')
        symbol = [symbol]
        expr = exp(-theta**2/2)/sqrt(2*pi)
        f = lambda x: torch.exp(-x[:,[0]]**2/2)/torch.sqrt(2*tpi)
        ranges = [[-3,3]]
    
    if name == 'I.6.20' or name == 2:
        symbol = theta, sigma = symbols('theta sigma')
        expr = exp(-theta**2/(2*sigma**2))/sqrt(2*pi*sigma**2)
        f = lambda x: torch.exp(-x[:,[0]]**2/(2*x[:,[1]]**2))/torch.sqrt(2*tpi*x[:,[1]]**2)
        ranges = [[-1,1],[0.5,2]]
        
    if name == 'I.6.20b' or name == 3:
        symbol = theta, theta1, sigma = symbols('theta theta1 sigma')
        expr = exp(-(theta-theta1)**2/(2*sigma**2))/sqrt(2*pi*sigma**2)
        f = lambda x: torch.exp(-(x[:,[0]]-x[:,[1]])**2/(2*x[:,[2]]**2))/torch.sqrt(2*tpi*x[:,[2]]**2)
        ranges = [[-1.5,1.5],[-1.5,1.5],[0.5,2]]
    
    if name == 'I.8.4' or name == 4:
        symbol = x1, x2, y1, y2 = symbols('x1 x2 y1 y2')
        expr = sqrt((x2-x1)**2+(y2-y1)**2)
        f = lambda x: torch.sqrt((x[:,[1]]-x[:,[0]])**2+(x[:,[3]]-x[:,[2]])**2)
        ranges = [[-1,1],[-1,1],[-1,1],[-1,1]]
        
    if name == 'I.9.18' or name == 5:
        symbol = G, m1, m2, x1, x2, y1, y2, z1, z2 = symbols('G m1 m2 x1 x2 y1 y2 z1 z2')
        expr = G*m1*m2/((x2-x1)**2+(y2-y1)**2+(z2-z1)**2)
        f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]/((x[:,[3]]-x[:,[4]])**2+(x[:,[5]]-x[:,[6]])**2+(x[:,[7]]-x[:,[8]])**2)
        ranges = [[-1,1],[-1,1],[-1,1],[-1,-0.5],[0.5,1],[-1,-0.5],[0.5,1],[-1,-0.5],[0.5,1]]
        
    if name == 'I.10.7' or name == 6:
        symbol = m0, v, c = symbols('m0 v c')
        expr = m0/sqrt(1-v**2/c**2)
        f = lambda x: x[:,[0]]/torch.sqrt(1-x[:,[1]]**2/x[:,[2]]**2)
        ranges = [[0,1],[0,1],[1,2]]
        
    if name == 'I.11.19' or name == 7:
        symbol = x1, y1, x2, y2, x3, y3 = symbols('x1 y1 x2 y2 x3 y3')
        expr = x1*y1 + x2*y2 + x3*y3
        f = lambda x: x[:,[0]]*x[:,[1]] + x[:,[2]]*x[:,[3]] + x[:,[4]]*x[:,[5]]
        ranges = [-1,1]
    
    if name == 'I.12.1' or name == 8:
        symbol = mu, Nn = symbols('mu N_n')
        expr = mu * Nn
        f = lambda x: x[:,[0]]*x[:,[1]]
        ranges = [-1,1]
        
    if name == 'I.12.2' or name == 9:
        symbol = q1, q2, eps, r = symbols('q1 q2 epsilon r')
        expr = q1*q2/(4*pi*eps*r**2)
        f = lambda x: x[:,[0]]*x[:,[1]]/(4*tpi*x[:,[2]]*x[:,[3]]**2)
        ranges = [[-1,1],[-1,1],[0.5,2],[0.5,2]]
        
    if name == 'I.12.4' or name == 10:
        symbol = q1, eps, r = symbols('q1 epsilon r')
        expr = q1/(4*pi*eps*r**2)
        f = lambda x: x[:,[0]]/(4*tpi*x[:,[1]]*x[:,[2]]**2)
        ranges = [[-1,1],[0.5,2],[0.5,2]]
        
    if name == 'I.12.5' or name == 11:
        symbol = q2, Ef = symbols('q2, E_f')
        expr = q2*Ef
        f = lambda x: x[:,[0]]*x[:,[1]]
        ranges = [-1,1]
        
    if name == 'I.12.11' or name == 12:
        symbol = q, Ef, B, v, theta = symbols('q E_f B v theta')
        expr = q*(Ef + B*v*sin(theta))
        f = lambda x: x[:,[0]]*(x[:,[1]]+x[:,[2]]*x[:,[3]]*torch.sin(x[:,[4]]))
        ranges = [[-1,1],[-1,1],[-1,1],[-1,1],[0,2*tpi]]
        
    if name == 'I.13.4' or name == 13:
        symbol = m, v, u, w = symbols('m u v w')
        expr = 1/2*m*(v**2+u**2+w**2)
        f = lambda x: 1/2*x[:,[0]]*(x[:,[1]]**2+x[:,[2]]**2+x[:,[3]]**2)
        ranges = [[-1,1],[-1,1],[-1,1],[-1,1]]
        
    if name == 'I.13.12' or name == 14:
        symbol = G, m1, m2, r1, r2 = symbols('G m1 m2 r1 r2')
        expr = G*m1*m2*(1/r2-1/r1)
        f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]*(1/x[:,[4]]-1/x[:,[3]])
        ranges = [[0,1],[0,1],[0,1],[0.5,2],[0.5,2]]
        
    if name == 'I.14.3' or name == 15:
        symbol = m, g, z = symbols('m g z')
        expr = m*g*z
        f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]
        ranges = [[0,1],[0,1],[-1,1]]
        
    if name == 'I.14.4' or name == 16:
        symbol = ks, x = symbols('k_s x')
        expr = 1/2*ks*x**2
        f = lambda x: 1/2*x[:,[0]]*x[:,[1]]**2
        ranges = [[0,1],[-1,1]]
        
    if name == 'I.15.3x' or name == 17:
        symbol = x, u, t, c = symbols('x u t c')
        expr = (x-u*t)/sqrt(1-u**2/c**2)
        f = lambda x: (x[:,[0]] - x[:,[1]]*x[:,[2]])/torch.sqrt(1-x[:,[1]]**2/x[:,[3]]**2)
        ranges = [[-1,1],[-1,1],[-1,1],[1,2]]
        
    if name == 'I.15.3t' or name == 18:
        symbol = t, u, x, c = symbols('t u x c')
        expr = (t-u*x/c**2)/sqrt(1-u**2/c**2)
        f = lambda x: (x[:,[0]] - x[:,[1]]*x[:,[2]]/x[:,[3]]**2)/torch.sqrt(1-x[:,[1]]**2/x[:,[3]]**2)
        ranges = [[-1,1],[-1,1],[-1,1],[1,2]]
        
    if name == 'I.15.10' or name == 19:
        symbol = m0, v, c = symbols('m0 v c')
        expr = m0*v/sqrt(1-v**2/c**2)
        f = lambda x: x[:,[0]]*x[:,[1]]/torch.sqrt(1-x[:,[1]]**2/x[:,[2]]**2)
        ranges = [[-1,1],[-0.9,0.9],[1.1,2]]
        
    if name == 'I.16.6' or name == 20:
        symbol = u, v, c = symbols('u v c')
        expr = (u+v)/(1+u*v/c**2)
        f = lambda x: x[:,[0]]*x[:,[1]]/(1+x[:,[0]]*x[:,[1]]/x[:,[2]]**2)
        ranges = [[-0.8,0.8],[-0.8,0.8],[1,2]]
        
    if name == 'I.18.4' or name == 21:
        symbol = m1, r1, m2, r2 = symbols('m1 r1 m2 r2')
        expr = (m1*r1+m2*r2)/(m1+m2)
        f = lambda x: (x[:,[0]]*x[:,[1]]+x[:,[2]]*x[:,[3]])/(x[:,[0]]+x[:,[2]])
        ranges = [[0.5,1],[-1,1],[0.5,1],[-1,1]]
        
    if name == 'I.18.4' or name == 22:
        symbol = r, F, theta = symbols('r F theta')
        expr = r*F*sin(theta)
        f = lambda x: x[:,[0]]*x[:,[1]]*torch.sin(x[:,[2]])
        ranges = [[-1,1],[-1,1],[0,2*tpi]]
        
    if name == 'I.18.16' or name == 23:
        symbol = m, r, v, theta = symbols('m r v theta')
        expr = m*r*v*sin(theta)
        f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]*torch.sin(x[:,[3]])
        ranges = [[-1,1],[-1,1],[-1,1],[0,2*tpi]]
        
    if name == 'I.24.6' or name == 24:
        symbol = m, omega, omega0, x = symbols('m omega omega_0 x')
        expr = 1/4*m*(omega**2+omega0**2)*x**2
        f = lambda x: 1/4*x[:,[0]]*(x[:,[1]]**2+x[:,[2]]**2)*x[:,[3]]**2
        ranges = [[0,1],[-1,1],[-1,1],[-1,1]]
        
    if name == 'I.25.13' or name == 25:
        symbol = q, C = symbols('q C')
        expr = q/C
        f = lambda x: x[:,[0]]/x[:,[1]]
        ranges = [[-1,1],[0.5,2]]
        
    if name == 'I.26.2' or name == 26:
        symbol = n, theta2 = symbols('n theta2')
        expr = asin(n*sin(theta2))
        f = lambda x: torch.arcsin(x[:,[0]]*torch.sin(x[:,[1]]))
        ranges = [[0,0.99],[0,2*tpi]]
        
    if name == 'I.27.6' or name == 27:
        symbol = d1, d2, n = symbols('d1 d2 n')
        expr = 1/(1/d1+n/d2)
        f = lambda x: 1/(1/x[:,[0]]+x[:,[2]]/x[:,[1]])
        ranges = [[0.5,2],[1,2],[0.5,2]]
    
    if name == 'I.29.4' or name == 28:
        symbol = omega, c = symbols('omega c')
        expr = omega/c
        f = lambda x: x[:,[0]]/x[:,[1]]
        ranges = [[0,1],[0.5,2]]
        
    if name == 'I.29.16' or name == 29:
        symbol = x1, x2, theta1, theta2 = symbols('x1 x2 theta1 theta2')
        expr = sqrt(x1**2+x2**2-2*x1*x2*cos(theta1-theta2))
        f = lambda x: torch.sqrt(x[:,[0]]**2+x[:,[1]]**2-2*x[:,[0]]*x[:,[1]]*torch.cos(x[:,[2]]-x[:,[3]]))
        ranges = [[-1,1],[-1,1],[0,2*tpi],[0,2*tpi]]
        
    if name == 'I.30.3' or name == 30:
        symbol = I0, n, theta = symbols('I_0 n theta')
        expr = I0 * sin(n*theta/2)**2 / sin(theta/2) ** 2
        f = lambda x: x[:,[0]] * torch.sin(x[:,[1]]*x[:,[2]]/2)**2 / torch.sin(x[:,[2]]/2)**2
        ranges = [[0,1],[0,4],[0.4*tpi,1.6*tpi]]
        
    if name == 'I.30.5' or name == 31:
        symbol = lamb, n, d = symbols('lambda n d')
        expr = asin(lamb/(n*d))
        f = lambda x: torch.arcsin(x[:,[0]]/(x[:,[1]]*x[:,[2]]))
        ranges = [[-1,1],[1,1.5],[1,1.5]]
        
    if name == 'I.32.5' or name == 32:
        symbol = q, a, eps, c = symbols('q a epsilon c')
        expr = q**2*a**2/(eps*c**3)
        f = lambda x: x[:,[0]]**2*x[:,[1]]**2/(x[:,[2]]*x[:,[3]]**3)
        ranges = [[-1,1],[-1,1],[0.5,2],[0.5,2]]
        
    if name == 'I.32.17' or name == 33:
        symbol = eps, c, Ef, r, omega, omega0 = symbols('epsilon c E_f r omega omega_0')
        expr = nsimplify((1/2*eps*c*Ef**2)*(8*pi*r**2/3)*(omega**4/(omega**2-omega0**2)**2))
        f = lambda x: (1/2*x[:,[0]]*x[:,[1]]*x[:,[2]]**2)*(8*tpi*x[:,[3]]**2/3)*(x[:,[4]]**4/(x[:,[4]]**2-x[:,[5]]**2)**2)
        ranges = [[0,1],[0,1],[-1,1],[0,1],[0,1],[1,2]]
        
    if name == 'I.34.8' or name == 34:
        symbol = q, V, B, p = symbols('q V B p')
        expr = q*V*B/p
        f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]/x[:,[3]]
        ranges = [[-1,1],[-1,1],[-1,1],[0.5,2]]
        
    if name == 'I.34.10' or name == 35:
        symbol = omega0, v, c = symbols('omega_0 v c')
        expr = omega0/(1-v/c)
        f = lambda x: x[:,[0]]/(1-x[:,[1]]/x[:,[2]])
        ranges = [[0,1],[0,0.9],[1.1,2]]
        
    if name == 'I.34.14' or name == 36:
        symbol = omega0, v, c = symbols('omega_0 v c')
        expr = omega0 * (1+v/c)/sqrt(1-v**2/c**2)
        f = lambda x: x[:,[0]]*(1+x[:,[1]]/x[:,[2]])/torch.sqrt(1-x[:,[1]]**2/x[:,[2]]**2)
        ranges = [[0,1],[-0.9,0.9],[1.1,2]]
        
    if name == 'I.34.27' or name == 37:
        symbol = hbar, omega = symbols('hbar omega')
        expr = hbar * omega
        f = lambda x: x[:,[0]]*x[:,[1]]
        ranges = [[-1,1],[-1,1]]
        
    if name == 'I.37.4' or name == 38:
        symbol = I1, I2, delta = symbols('I_1 I_2 delta')
        expr = I1 + I2 + 2*sqrt(I1*I2)*cos(delta)
        f = lambda x: x[:,[0]] + x[:,[1]] + 2*torch.sqrt(x[:,[0]]*x[:,[1]])*torch.cos(x[:,[2]])
        ranges = [[0.1,1],[0.1,1],[0,2*tpi]]
        
    if name == 'I.38.12' or name == 39:
        symbol = eps, hbar, m, q = symbols('epsilon hbar m q')
        expr = 4*pi*eps*hbar**2/(m*q**2)
        f = lambda x: 4*tpi*x[:,[0]]*x[:,[1]]**2/(x[:,[2]]*x[:,[3]]**2)
        ranges = [[0,1],[0,1],[0.5,2],[0.5,2]]
        
    if name == 'I.39.10' or name == 40:
        symbol = pF, V = symbols('p_F V')
        expr = 3/2 * pF * V
        f = lambda x: 3/2 * x[:,[0]] * x[:,[1]]
        ranges = [[0,1],[0,1]]
        
    if name == 'I.39.11' or name == 41:
        symbol = gamma, pF, V = symbols('gamma p_F V')
        expr = pF * V/(gamma - 1)
        f = lambda x: 1/(x[:,[0]]-1) * x[:,[1]] * x[:,[2]]
        ranges = [[1.5,3],[0,1],[0,1]]
        
    if name == 'I.39.22' or name == 42:
        symbol = n, kb, T, V = symbols('n k_b T V')
        expr = n*kb*T/V
        f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]/x[:,[3]]
        ranges = [[0,1],[0,1],[0,1],[0.5,2]]
        
    if name == 'I.40.1' or name == 43:
        symbol = n0, m, g, x, kb, T = symbols('n_0 m g x k_b T')
        expr = n0 * exp(-m*g*x/(kb*T))
        f = lambda x: x[:,[0]] * torch.exp(-x[:,[1]]*x[:,[2]]*x[:,[3]]/(x[:,[4]]*x[:,[5]]))
        ranges = [[0,1],[-1,1],[-1,1],[-1,1],[1,2],[1,2]]
        
    if name == 'I.41.16' or name == 44:
        symbol = hbar, omega, c, kb, T = symbols('hbar omega c k_b T')
        expr = hbar * omega**3/(pi**2*c**2*(exp(hbar*omega/(kb*T))-1))
        f = lambda x: x[:,[0]]*x[:,[1]]**3/(tpi**2*x[:,[2]]**2*(torch.exp(x[:,[0]]*x[:,[1]]/(x[:,[3]]*x[:,[4]]))-1))
        ranges = [[0.5,1],[0.5,1],[0.5,2],[0.5,2],[0.5,2]]
        
    if name == 'I.43.16' or name == 45:
        symbol = mu, q, Ve, d = symbols('mu q V_e d')
        expr = mu*q*Ve/d
        f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]/x[:,[3]]
        ranges = [[0,1],[0,1],[0,1],[0.5,2]]
        
    if name == 'I.43.31' or name == 46:
        symbol = mu, kb, T = symbols('mu k_b T')
        expr = mu*kb*T
        f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]
        ranges = [[0,1],[0,1],[0,1]]
    
    if name == 'I.43.43' or name == 47:
        symbol = gamma, kb, v, A = symbols('gamma k_b v A')
        expr = kb*v/A/(gamma-1)
        f = lambda x: 1/(x[:,[0]]-1)*x[:,[1]]*x[:,[2]]/x[:,[3]]
        ranges = [[1.5,3],[0,1],[0,1],[0.5,2]]
        
    if name == 'I.44.4' or name == 48:
        symbol = n, kb, T, V1, V2 = symbols('n k_b T V_1 V_2')
        expr = n*kb*T*log(V2/V1)
        f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]*torch.log(x[:,[4]]/x[:,[3]])
        ranges = [[0,1],[0,1],[0,1],[0.5,2],[0.5,2]]
        
    if name == 'I.47.23' or name == 49:
        symbol = gamma, p, rho = symbols('gamma p rho')
        expr = sqrt(gamma*p/rho)
        f = lambda x: torch.sqrt(x[:,[0]]*x[:,[1]]/x[:,[2]])
        ranges = [[0.1,1],[0.1,1],[0.5,2]]
        
    if name == 'I.48.20' or name == 50:
        symbol = m, v, c = symbols('m v c')
        expr = m*c**2/sqrt(1-v**2/c**2)
        f = lambda x: x[:,[0]]*x[:,[2]]**2/torch.sqrt(1-x[:,[1]]**2/x[:,[2]]**2)
        ranges = [[0,1],[-0.9,0.9],[1.1,2]]
        
    if name == 'I.50.26' or name == 51:
        symbol = x1, alpha, omega, t = symbols('x_1 alpha omega t')
        expr = x1*(cos(omega*t)+alpha*cos(omega*t)**2)
        f = lambda x: x[:,[0]]*(torch.cos(x[:,[2]]*x[:,[3]])+x[:,[1]]*torch.cos(x[:,[2]]*x[:,[3]])**2)
        ranges = [[0,1],[0,1],[0,2*tpi],[0,1]]
        
    if name == 'II.2.42' or name == 52:
        symbol = kappa, T1, T2, A, d = symbols('kappa T_1 T_2 A d')
        expr = kappa*(T2-T1)*A/d
        f = lambda x: x[:,[0]]*(x[:,[2]]-x[:,[1]])*x[:,[3]]/x[:,[4]]
        ranges = [[0,1],[0,1],[0,1],[0,1],[0.5,2]]
        
    if name == 'II.3.24' or name == 53:
        symbol = P, r = symbols('P r')
        expr = P/(4*pi*r**2)
        f = lambda x: x[:,[0]]/(4*tpi*x[:,[1]]**2)
        ranges = [[0,1],[0.5,2]]
        
    if name == 'II.4.23' or name == 54:
        symbol = q, eps, r = symbols('q epsilon r')
        expr = q/(4*pi*eps*r)
        f = lambda x: x[:,[0]]/(4*tpi*x[:,[1]]*x[:,[2]])
        ranges = [[0,1],[0.5,2],[0.5,2]]
        
    if name == 'II.6.11' or name == 55:
        symbol = eps, pd, theta, r = symbols('epsilon p_d theta r')
        expr = 1/(4*pi*eps)*pd*cos(theta)/r**2
        f = lambda x: 1/(4*tpi*x[:,[0]])*x[:,[1]]*torch.cos(x[:,[2]])/x[:,[3]]**2
        ranges = [[0.5,2],[0,1],[0,2*tpi],[0.5,2]]
        
    if name == 'II.6.15a' or name == 56:
        symbol = eps, pd, z, x, y, r = symbols('epsilon p_d z x y r')
        expr = 3/(4*pi*eps)*pd*z/r**5*sqrt(x**2+y**2)
        f = lambda x: 3/(4*tpi*x[:,[0]])*x[:,[1]]*x[:,[2]]/x[:,[5]]**5*torch.sqrt(x[:,[3]]**2+x[:,[4]]**2)
        ranges = [[0.5,2],[0,1],[0,1],[0,1],[0,1],[0.5,2]]
    
    if name == 'II.6.15b' or name == 57:
        symbol = eps, pd, r, theta = symbols('epsilon p_d r theta')
        expr = 3/(4*pi*eps)*pd/r**3*cos(theta)*sin(theta)
        f = lambda x: 3/(4*tpi*x[:,[0]])*x[:,[1]]/x[:,[2]]**3*torch.cos(x[:,[3]])*torch.sin(x[:,[3]])
        ranges = [[0.5,2],[0,1],[0.5,2],[0,2*tpi]]
        
    if name == 'II.8.7' or name == 58:
        symbol = q, eps, d = symbols('q epsilon d')
        expr = 3/5*q**2/(4*pi*eps*d)
        f = lambda x: 3/5*x[:,[0]]**2/(4*tpi*x[:,[1]]*x[:,[2]])
        ranges = [[0,1],[0.5,2],[0.5,2]]
        
    if name == 'II.8.31' or name == 59:
        symbol = eps, Ef = symbols('epsilon E_f')
        expr = 1/2*eps*Ef**2
        f = lambda x: 1/2*x[:,[0]]*x[:,[1]]**2
        ranges = [[0,1],[0,1]]
        
    if name == 'I.10.9' or name == 60:
        symbol = sigma, eps, chi = symbols('sigma epsilon chi')
        expr = sigma/eps/(1+chi)
        f = lambda x: x[:,[0]]/x[:,[1]]/(1+x[:,[2]])
        ranges = [[0,1],[0.5,2],[0,1]]
        
    if name == 'II.11.3' or name == 61:
        symbol = q, Ef, m, omega0, omega = symbols('q E_f m omega_o omega')
        expr = q*Ef/(m*(omega0**2-omega**2))
        f = lambda x: x[:,[0]]*x[:,[1]]/(x[:,[2]]*(x[:,[3]]**2-x[:,[4]]**2))
        ranges = [[0,1],[0,1],[0.5,2],[1.5,3],[0,1]]
        
    if name == 'II.11.17' or name == 62:
        symbol = n0, pd, Ef, theta, kb, T = symbols('n_0 p_d E_f theta k_b T')
        expr = n0*(1+pd*Ef*cos(theta)/(kb*T))
        f = lambda x: x[:,[0]]*(1+x[:,[1]]*x[:,[2]]*torch.cos(x[:,[3]])/(x[:,[4]]*x[:,[5]]))
        ranges = [[0,1],[-1,1],[-1,1],[0,2*tpi],[0.5,2],[0.5,2]]
        
        
    if name == 'II.11.20' or name == 63:
        symbol = n, pd, Ef, kb, T = symbols('n p_d E_f k_b T')
        expr = n*pd**2*Ef/(3*kb*T)
        f = lambda x: x[:,[0]]*x[:,[1]]**2*x[:,[2]]/(3*x[:,[3]]*x[:,[4]])
        ranges = [[0,1],[0,1],[0,1],[0.5,2],[0.5,2]]
        
    if name == 'II.11.27' or name == 64:
        symbol = n, alpha, eps, Ef = symbols('n alpha epsilon E_f')
        expr = n*alpha/(1-n*alpha/3)*eps*Ef
        f = lambda x: x[:,[0]]*x[:,[1]]/(1-x[:,[0]]*x[:,[1]]/3)*x[:,[2]]*x[:,[3]]
        ranges = [[0,1],[0,2],[0,1],[0,1]]
        
    if name == 'II.11.28' or name == 65:
        symbol = n, alpha = symbols('n alpha')
        expr = 1 + n*alpha/(1-n*alpha/3)
        f = lambda x: 1 + x[:,[0]]*x[:,[1]]/(1-x[:,[0]]*x[:,[1]]/3)
        ranges = [[0,1],[0,2]]
        
    if name == 'II.13.17' or name == 66:
        symbol = eps, c, l, r = symbols('epsilon c l r')
        expr = 1/(4*pi*eps*c**2)*(2*l/r)
        f = lambda x: 1/(4*tpi*x[:,[0]]*x[:,[1]]**2)*(2*x[:,[2]]/x[:,[3]])
        ranges = [[0.5,2],[0.5,2],[0,1],[0.5,2]]
        
    if name == 'II.13.23' or name == 67:
        symbol = rho, v, c = symbols('rho v c')
        expr = rho/sqrt(1-v**2/c**2)
        f = lambda x: x[:,[0]]/torch.sqrt(1-x[:,[1]]**2/x[:,[2]]**2)
        ranges = [[0,1],[0,1],[1,2]]
        
    if name == 'II.13.34' or name == 68:
        symbol = rho, v, c = symbols('rho v c')
        expr = rho*v/sqrt(1-v**2/c**2)
        f = lambda x: x[:,[0]]*x[:,[1]]/torch.sqrt(1-x[:,[1]]**2/x[:,[2]]**2)
        ranges = [[0,1],[0,1],[1,2]]
        
    if name == 'II.15.4' or name == 69:
        symbol = muM, B, theta = symbols('mu_M B theta')
        expr = - muM * B * cos(theta)
        f = lambda x: - x[:,[0]]*x[:,[1]]*torch.cos(x[:,[2]])
        ranges = [[0,1],[0,1],[0,2*tpi]]
        
    if name == 'II.15.5' or name == 70:
        symbol = pd, Ef, theta = symbols('p_d E_f theta')
        expr = - pd * Ef * cos(theta)
        f = lambda x: - x[:,[0]]*x[:,[1]]*torch.cos(x[:,[2]])
        ranges = [[0,1],[0,1],[0,2*tpi]]
        
    if name == 'II.21.32' or name == 71:
        symbol = q, eps, r, v, c = symbols('q epsilon r v c')
        expr = q/(4*pi*eps*r*(1-v/c))
        f = lambda x: x[:,[0]]/(4*tpi*x[:,[1]]*x[:,[2]]*(1-x[:,[3]]/x[:,[4]]))
        ranges = [[0,1],[0.5,2],[0.5,2],[0,1],[1,2]]
        
    if name == 'II.24.17' or name == 72:
        symbol = omega, c, d = symbols('omega c d')
        expr = sqrt(omega**2/c**2-pi**2/d**2)
        f = lambda x: torch.sqrt(x[:,[0]]**2/x[:,[1]]**2-tpi**2/x[:,[2]]**2)
        ranges = [[1,1.5],[0.75,1],[1*tpi,1.5*tpi]]
        
    if name == 'II.27.16' or name == 73:
        symbol = eps, c, Ef = symbols('epsilon c E_f')
        expr = eps * c * Ef**2
        f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]**2
        ranges = [[0,1],[0,1],[-1,1]]
        
    if name == 'II.27.18' or name == 74:
        symbol = eps, Ef = symbols('epsilon E_f')
        expr = eps * Ef**2
        f = lambda x: x[:,[0]]*x[:,[1]]**2
        ranges = [[0,1],[-1,1]]
        
    if name == 'II.34.2a' or name == 75:
        symbol = q, v, r = symbols('q v r')
        expr = q*v/(2*pi*r)
        f = lambda x: x[:,[0]]*x[:,[1]]/(2*tpi*x[:,[2]])
        ranges = [[0,1],[0,1],[0.5,2]]
        
    if name == 'II.34.2' or name == 76:
        symbol = q, v, r = symbols('q v r')
        expr = q*v*r/2
        f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]/2
        ranges = [[0,1],[0,1],[0,1]]
        
    if name == 'II.34.11' or name == 77:
        symbol = g, q, B, m = symbols('g q B m')
        expr = g*q*B/(2*m)
        f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]/(2*x[:,[3]])
        ranges = [[0,1],[0,1],[0,1],[0.5,2]]
        
    if name == 'II.34.29a' or name == 78:
        symbol = q, h, m = symbols('q h m')
        expr = q*h/(4*pi*m)
        f = lambda x: x[:,[0]]*x[:,[1]]/(4*tpi*x[:,[2]])
        ranges = [[0,1],[0,1],[0.5,2]]
        
    if name == 'II.34.29b' or name == 79:
        symbol = g, mu, B, J, hbar = symbols('g mu B J hbar')
        expr = g*mu*B*J/hbar
        f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]*x[:,[3]]/x[:,[4]]
        ranges = [[0,1],[0,1],[0,1],[0,1],[0.5,2]]
        
    if name == 'II.35.18' or name == 80:
        symbol = n0, mu, B, kb, T = symbols('n0 mu B k_b T')
        expr = n0/(exp(mu*B/(kb*T))+exp(-mu*B/(kb*T)))
        f = lambda x: x[:,[0]]/(torch.exp(x[:,[1]]*x[:,[2]]/(x[:,[3]]*x[:,[4]]))+torch.exp(-x[:,[1]]*x[:,[2]]/(x[:,[3]]*x[:,[4]])))
        ranges = [[0,1],[0,1],[0,1],[0.5,2],[0.5,2]]
        
    if name == 'II.35.21' or name == 81:
        symbol = n, mu, B, kb, T = symbols('n mu B k_b T')
        expr = n*mu*tanh(mu*B/(kb*T))
        f = lambda x: x[:,[0]]*x[:,[1]]*torch.tanh(x[:,[1]]*x[:,[2]]/(x[:,[3]]*x[:,[4]]))
        ranges = [[0,1],[0,1],[0,1],[0.5,2],[0.5,2]]
        
    if name == 'II.36.38' or name == 82:
        symbol = mu, B, kb, T, alpha, M, eps, c = symbols('mu B k_b T alpha M epsilon c')
        expr = mu*B/(kb*T) + mu*alpha*M/(eps*c**2*kb*T)
        f = lambda x: x[:,[0]]*x[:,[1]]/(x[:,[2]]*x[:,[3]]) + x[:,[0]]*x[:,[4]]*x[:,[5]]/(x[:,[6]]*x[:,[7]]**2*x[:,[2]]*x[:,[3]])
        ranges = [[0,1],[0,1],[0.5,2],[0.5,2],[0,1],[0,1],[0.5,2],[0.5,2]]
        
    if name == 'II.37.1' or name == 83:
        symbol = mu, chi, B = symbols('mu chi B')
        expr = mu*(1+chi)*B
        f = lambda x: x[:,[0]]*(1+x[:,[1]])*x[:,[2]]
        ranges = [[0,1],[0,1],[0,1]]
        
    if name == 'II.38.3' or name == 84:
        symbol = Y, A, x, d = symbols('Y A x d')
        expr = Y*A*x/d
        f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]/x[:,[3]]
        ranges = [[0,1],[0,1],[0,1],[0.5,2]]
        
    if name == 'II.38.14' or name == 85:
        symbol = Y, sigma = symbols('Y sigma')
        expr = Y/(2*(1+sigma))
        f = lambda x: x[:,[0]]/(2*(1+x[:,[1]]))
        ranges = [[0,1],[0,1]]
        
    if name == 'III.4.32' or name == 86:
        symbol = hbar, omega, kb, T = symbols('hbar omega k_b T')
        expr = 1/(exp(hbar*omega/(kb*T))-1)
        f = lambda x: 1/(torch.exp(x[:,[0]]*x[:,[1]]/(x[:,[2]]*x[:,[3]]))-1)
        ranges = [[0.5,1],[0.5,1],[0.5,2],[0.5,2]]
        
    if name == 'III.4.33' or name == 87:
        symbol = hbar, omega, kb, T = symbols('hbar omega k_b T')
        expr = hbar*omega/(exp(hbar*omega/(kb*T))-1)
        f = lambda x: x[:,[0]]*x[:,[1]]/(torch.exp(x[:,[0]]*x[:,[1]]/(x[:,[2]]*x[:,[3]]))-1)
        ranges = [[0,1],[0,1],[0.5,2],[0.5,2]]
        
    if name == 'III.7.38' or name == 88:
        symbol = mu, B, hbar = symbols('mu B hbar')
        expr = 2*mu*B/hbar
        f = lambda x: 2*x[:,[0]]*x[:,[1]]/x[:,[2]]
        ranges = [[0,1],[0,1],[0.5,2]]
        
    if name == 'III.8.54' or name == 89:
        symbol = E, t, hbar = symbols('E t hbar')
        expr = sin(E*t/hbar)**2
        f = lambda x: torch.sin(x[:,[0]]*x[:,[1]]/x[:,[2]])**2
        ranges = [[0,2*tpi],[0,1],[0.5,2]]
        
    if name == 'III.9.52' or name == 90:
        symbol = pd, Ef, t, hbar, omega, omega0 = symbols('p_d E_f t hbar omega omega_0')
        expr = pd*Ef*t/hbar*sin((omega-omega0)*t/2)**2/((omega-omega0)*t/2)**2
        f = lambda x: x[:,[0]]*x[:,[1]]*x[:,[2]]/x[:,[3]]*torch.sin((x[:,[4]]-x[:,[5]])*x[:,[2]]/2)**2/((x[:,[4]]-x[:,[5]])*x[:,[2]]/2)**2
        ranges = [[0,1],[0,1],[0,1],[0.5,2],[0,tpi],[0,tpi]]
        
    if name == 'III.10.19' or name == 91:
        symbol = mu, Bx, By, Bz = symbols('mu B_x B_y B_z')
        expr = mu*sqrt(Bx**2+By**2+Bz**2)
        f = lambda x: x[:,[0]]*torch.sqrt(x[:,[1]]**2+x[:,[2]]**2+x[:,[3]]**2)
        ranges = [[0,1],[0,1],[0,1],[0,1]]
        
    if name == 'III.12.43' or name == 92:
        symbol = n, hbar = symbols('n hbar')
        expr = n * hbar
        f = lambda x: x[:,[0]]*x[:,[1]]
        ranges = [[0,1],[0,1]]
        
    if name == 'III.13.18' or name == 93:
        symbol = E, d, k, hbar = symbols('E d k hbar')
        expr = 2*E*d**2*k/hbar
        f = lambda x: 2*x[:,[0]]*x[:,[1]]**2*x[:,[2]]/x[:,[3]]
        ranges = [[0,1],[0,1],[0,1],[0.5,2]]
        
    if name == 'III.14.14' or name == 94:
        symbol = I0, q, Ve, kb, T = symbols('I_0 q V_e k_b T')
        expr = I0 * (exp(q*Ve/(kb*T))-1)
        f = lambda x: x[:,[0]]*(torch.exp(x[:,[1]]*x[:,[2]]/(x[:,[3]]*x[:,[4]]))-1)
        ranges = [[0,1],[0,1],[0,1],[0.5,2],[0.5,2]]
        
    if name == 'III.15.12' or name == 95:
        symbol = U, k, d = symbols('U k d')
        expr = 2*U*(1-cos(k*d))
        f = lambda x: 2*x[:,[0]]*(1-torch.cos(x[:,[1]]*x[:,[2]]))
        ranges = [[0,1],[0,2*tpi],[0,1]]
        
    if name == 'III.15.14' or name == 96:
        symbol = hbar, E, d = symbols('hbar E d')
        expr = hbar**2/(2*E*d**2)
        f = lambda x: x[:,[0]]**2/(2*x[:,[1]]*x[:,[2]]**2)
        ranges = [[0,1],[0.5,2],[0.5,2]]
        
    if name == 'III.15.27' or name == 97:
        symbol = alpha, n, d = symbols('alpha n d')
        expr = 2*pi*alpha/(n*d)
        f = lambda x: 2*tpi*x[:,[0]]/(x[:,[1]]*x[:,[2]])
        ranges = [[0,1],[0.5,2],[0.5,2]]
        
    if name == 'III.17.37' or name == 98:
        symbol = beta, alpha, theta = symbols('beta alpha theta')
        expr = beta * (1+alpha*cos(theta))
        f = lambda x: x[:,[0]]*(1+x[:,[1]]*torch.cos(x[:,[2]]))
        ranges = [[0,1],[0,1],[0,2*tpi]]
        
    if name == 'III.19.51' or name == 99:
        symbol = m, q, eps, hbar, n = symbols('m q epsilon hbar n')
        expr = - m * q**4/(2*(4*pi*eps)**2*hbar**2)*1/n**2
        f = lambda x: - x[:,[0]]*x[:,[1]]**4/(2*(4*tpi*x[:,[2]])**2*x[:,[3]]**2)*1/x[:,[4]]**2
        ranges = [[0,1],[0,1],[0.5,2],[0.5,2],[0.5,2]]
        
    if name == 'III.21.20' or name == 100:
        symbol = rho, q, A, m = symbols('rho q A m')
        expr = - rho*q*A/m
        f = lambda x: - x[:,[0]]*x[:,[1]]*x[:,[2]]/x[:,[3]]
        ranges = [[0,1],[0,1],[0,1],[0.5,2]]
        
    if name == 'Rutherforld scattering' or name == 101:
        symbol = Z1, Z2, alpha, hbar, c, E, theta = symbols('Z_1 Z_2 alpha hbar c E theta')
        expr = (Z1*Z2*alpha*hbar*c/(4*E*sin(theta/2)**2))**2
        f = lambda x: (x[:,[0]]*x[:,[1]]*x[:,[2]]*x[:,[3]]*x[:,[4]]/(4*x[:,[5]]*torch.sin(x[:,[6]]/2)**2))**2
        ranges = [[0,1],[0,1],[0,1],[0,1],[0,1],[0.5,2],[0.1*tpi,0.9*tpi]]
        
    if name == 'Friedman equation' or name == 102:
        symbol = G, rho, kf, c, af = symbols('G rho k_f c a_f')
        expr = sqrt(8*pi*G/3*rho-kf*c**2/af**2)
        f = lambda x: torch.sqrt(8*tpi*x[:,[0]]/3*x[:,[1]] - x[:,[2]]*x[:,[3]]**2/x[:,[4]]**2)
        ranges = [[1,2],[1,2],[0,1],[0,1],[1,2]]
        
    if name == 'Compton scattering' or name == 103:
        symbol = E, m, c, theta = symbols('E m c theta')
        expr = E/(1+E/(m*c**2)*(1-cos(theta)))
        f = lambda x: x[:,[0]]/(1+x[:,[0]]/(x[:,[1]]*x[:,[2]]**2)*(1-torch.cos(x[:,[3]])))
        ranges = [[0,1],[0.5,2],[0.5,2],[0,2*tpi]]
        
    if name == 'Radiated gravitational wave power' or name == 104:
        symbol = G, c, m1, m2, r = symbols('G c m_1 m_2 r')
        expr = -32/5*G**4/c**5*(m1*m2)**2*(m1+m2)/r**5
        f = lambda x: -32/5*x[:,[0]]**4/x[:,[1]]**5*(x[:,[2]]*x[:,[3]])**2*(x[:,[2]]+x[:,[3]])/x[:,[4]]**5
        ranges = [[0,1],[0.5,2],[0,1],[0,1],[0.5,2]]
        
    if name == 'Relativistic aberration' or name == 105:
        symbol = theta2, v, c = symbols('theta_2 v c')
        expr = acos((cos(theta2)-v/c)/(1-v/c*cos(theta2)))
        f = lambda x: torch.arccos((torch.cos(x[:,[0]])-x[:,[1]]/x[:,[2]])/(1-x[:,[1]]/x[:,[2]]*torch.cos(x[:,[0]])))
        ranges = [[0,tpi],[0,1],[1,2]]
        
    if name == 'N-slit diffraction' or name == 106:
        symbol = I0, alpha, delta, N = symbols('I_0 alpha delta N')
        expr = I0 * (sin(alpha/2)/(alpha/2)*sin(N*delta/2)/sin(delta/2))**2
        f = lambda x: x[:,[0]] * (torch.sin(x[:,[1]]/2)/(x[:,[1]]/2)*torch.sin(x[:,[3]]*x[:,[2]]/2)/torch.sin(x[:,[2]]/2))**2
        ranges = [[0,1],[0.1*tpi,0.9*tpi],[0.1*tpi,0.9*tpi],[0.5,1]]
        
    if name == 'Goldstein 3.16' or name == 107:
        symbol = m, E, U, L, r = symbols('m E U L r')
        expr = sqrt(2/m*(E-U-L**2/(2*m*r**2)))
        f = lambda x: torch.sqrt(2/x[:,[0]]*(x[:,[1]]-x[:,[2]]-x[:,[3]]**2/(2*x[:,[0]]*x[:,[4]]**2)))
        ranges = [[1,2],[2,3],[0,1],[0,1],[1,2]]
        
    if name == 'Goldstein 3.55' or name == 108:
        symbol = m, kG, L, E, theta1, theta2 = symbols('m k_G L E theta_1 theta_2')
        expr = m*kG/L**2*(1+sqrt(1+2*E*L**2/(m*kG**2))*cos(theta1-theta2))
        f = lambda x: x[:,[0]]*x[:,[1]]/x[:,[2]]**2*(1+torch.sqrt(1+2*x[:,[3]]*x[:,[2]]**2/(x[:,[0]]*x[:,[1]]**2))*torch.cos(x[:,[4]]-x[:,[5]]))
        ranges = [[0.5,2],[0.5,2],[0.5,2],[0,1],[0,2*tpi],[0,2*tpi]]
        
    if name == 'Goldstein 3.64 (ellipse)' or name == 109:
        symbol = d, alpha, theta1, theta2 = symbols('d alpha theta_1 theta_2')
        expr = d*(1-alpha**2)/(1+alpha*cos(theta2-theta1))
        f = lambda x: x[:,[0]]*(1-x[:,[1]]**2)/(1+x[:,[1]]*torch.cos(x[:,[2]]-x[:,[3]]))
        ranges = [[0,1],[0,0.9],[0,2*tpi],[0,2*tpi]]
        
    if name == 'Goldstein 3.74 (Kepler)' or name == 110:
        symbol = d, G, m1, m2 = symbols('d G m_1 m_2')
        expr = 2*pi*d**(3/2)/sqrt(G*(m1+m2))
        f = lambda x: 2*tpi*x[:,[0]]**(3/2)/torch.sqrt(x[:,[1]]*(x[:,[2]]+x[:,[3]]))
        ranges = [[0,1],[0.5,2],[0.5,2],[0.5,2]]
        
    if name == 'Goldstein 3.99' or name == 111:
        symbol = eps, E, L, m, Z1, Z2, q = symbols('epsilon E L m Z_1 Z_2 q')
        expr = sqrt(1+2*eps**2*E*L**2/(m*(Z1*Z2*q**2)**2))
        f = lambda x: torch.sqrt(1+2*x[:,[0]]**2*x[:,[1]]*x[:,[2]]**2/(x[:,[3]]*(x[:,[4]]*x[:,[5]]*x[:,[6]]**2)**2))
        ranges = [[0,1],[0,1],[0,1],[0.5,2],[0.5,2],[0.5,2],[0.5,2]]
        
    if name == 'Goldstein 8.56' or name == 112:
        symbol = p, q, A, c, m, Ve = symbols('p q A c m V_e')
        expr = sqrt((p-q*A)**2*c**2+m**2*c**4) + q*Ve
        f = lambda x: torch.sqrt((x[:,[0]]-x[:,[1]]*x[:,[2]])**2*x[:,[3]]**2+x[:,[4]]**2*x[:,[3]]**4) + x[:,[1]]*x[:,[5]]
        ranges = [0,1]
        
    if name == 'Goldstein 12.80' or name == 113:
        symbol = m, p, omega, x, alpha, y = symbols('m p omega x alpha y')
        expr = 1/(2*m)*(p**2+m**2*omega**2*x**2*(1+alpha*y/x))
        f = lambda x: 1/(2*x[:,[0]]) * (x[:,[1]]**2+x[:,[0]]**2*x[:,[2]]**2*x[:,[3]]**2*(1+x[:,[4]]*x[:,[3]]/x[:,[5]]))
        ranges = [[0.5,2],[0,1],[0,1],[0,1],[0,1],[0.5,2]]
        
    if name == 'Jackson 2.11' or name == 114:
        symbol = q, eps, y, Ve, d = symbols('q epsilon y V_e d')
        expr = q/(4*pi*eps*y**2)*(4*pi*eps*Ve*d-q*d*y**3/(y**2-d**2)**2)
        f = lambda x: x[:,[0]]/(4*tpi*x[:,[1]]*x[:,x[:,[2]]]**2)*(4*tpi*x[:,[1]]*x[:,[3]]*x[:,[4]]-x[:,[0]]*x[:,[4]]*x[:,[2]]**3/(x[:,[2]]**2-x[:,[4]]**2)**2)
        ranges = [[0,1],[0.5,2],[1,2],[0,1],[0,1]]
        
    if name == 'Jackson 3.45' or name == 115:
        symbol = q, r, d, alpha = symbols('q r d alpha')
        expr = q/sqrt(r**2+d**2-2*d*r*cos(alpha))
        f = lambda x: x[:,[0]]/torch.sqrt(x[:,[1]]**2+x[:,[2]]**2-2*x[:,[1]]*x[:,[2]]*torch.cos(x[:,[3]]))
        ranges = [[0,1],[0,1],[0,1],[0,2*tpi]]
        
    if name == 'Jackson 4.60' or name == 116:
        symbol = Ef, theta, alpha, d, r = symbols('E_f theta alpha d r')
        expr = Ef * cos(theta) * ((alpha-1)/(alpha+2) * d**3/r**2 - r)
        f = lambda x: x[:,[0]] * torch.cos(x[:,[1]]) * ((x[:,[2]]-1)/(x[:,[2]]+2) * x[:,[3]]**3/x[:,[4]]**2 - x[:,[4]])
        ranges = [[0,1],[0,2*tpi],[0,2],[0,1],[0.5,2]]
        
    if name == 'Jackson 11.38 (Doppler)' or name == 117:
        symbol = omega, v, c, theta = symbols('omega v c theta')
        expr = sqrt(1-v**2/c**2)/(1+v/c*cos(theta))*omega
        f = lambda x: torch.sqrt(1-x[:,[1]]**2/x[:,[2]]**2)/(1+x[:,[1]]/x[:,[2]]*torch.cos(x[:,[3]]))*x[:,[0]]
        ranges = [[0,1],[0,1],[1,2],[0,2*tpi]]
        
    if name == 'Weinberg 15.2.1' or name == 118:
        symbol = G, c, kf, af, H = symbols('G c k_f a_f H')
        expr = 3/(8*pi*G)*(c**2*kf/af**2+H**2)
        f = lambda x: 3/(8*tpi*x[:,[0]])*(x[:,[1]]**2*x[:,[2]]/x[:,[3]]**2+x[:,[4]]**2)
        ranges = [[0.5,2],[0,1],[0,1],[0.5,2],[0,1]]
        
    if name == 'Weinberg 15.2.2' or name == 119:
        symbol = G, c, kf, af, H, alpha = symbols('G c k_f a_f H alpha')
        expr = -1/(8*pi*G)*(c**4*kf/af**2+c**2*H**2*(1-2*alpha))
        f = lambda x: -1/(8*tpi*x[:,[0]])*(x[:,[1]]**4*x[:,[2]]/x[:,[3]]**2 + x[:,[1]]**2*x[:,[4]]**2*(1-2*x[:,[5]]))
        ranges = [[0.5,2],[0,1],[0,1],[0.5,2],[0,1],[0,1]]
        
    if name == 'Schwarz 13.132 (Klein-Nishina)' or name == 120:
        symbol = alpha, hbar, m, c, omega0, omega, theta = symbols('alpha hbar m c omega_0 omega theta')
        expr = pi*alpha**2*hbar**2/m**2/c**2*(omega0/omega)**2*(omega0/omega+omega/omega0-sin(theta)**2)
        f = lambda x: tpi*x[:,[0]]**2*x[:,[1]]**2/x[:,[2]]**2/x[:,[3]]**2*(x[:,[4]]/x[:,[5]])**2*(x[:,[4]]/x[:,[5]]+x[:,[5]]/x[:,[4]]-torch.sin(x[:,[6]])**2)
        ranges = [[0,1],[0,1],[0.5,2],[0.5,2],[0.5,2],[0.5,2],[0,2*tpi]]
        
    return symbol, expr, f, ranges
# f_1 = lambda x: torch.exp(-(x[:, 0]**2)/(2*x[:, 1]**2)) / torch.sqrt(2*torch.pi*x[:, 1]**2)
# f_2 = lambda x: torch.exp((-(x[:, 0] - x[:, 1])**2)/(2*x[:, 2]**2)) / torch.sqrt(2*torch.pi*x[:, 1]**2)

In [None]:
import numpy as np 
import torch 
from src.efficient_kan.group_kan import Knots_KAN
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from sklearn.metrics import mean_squared_error
from tqdm import trange




def _regression_errors(target, prediction):
    """Return ``(mse, rmse)`` between two numpy arrays of matching shape.

    1-D inputs are treated as a single output column. For several output
    columns the MSE is averaged over columns and the RMSE is the average of
    the per-column roots (uniform averaging over outputs).

    Raises:
        ValueError: if the two arrays do not have the same shape once
            flattened to ``(n_samples, n_outputs)``.
    """
    target = np.asarray(target, dtype=np.float64)
    prediction = np.asarray(prediction, dtype=np.float64)
    target = target.reshape(target.shape[0], -1)
    prediction = prediction.reshape(prediction.shape[0], -1)
    # Guard against silent broadcasting, e.g. (N,) against (N, 1) -> (N, N).
    if target.shape != prediction.shape:
        raise ValueError(
            f"shape mismatch: target {target.shape} vs prediction {prediction.shape}"
        )
    per_output_mse = np.mean((target - prediction) ** 2, axis=0)
    return float(per_output_mse.mean()), float(np.sqrt(per_output_mse).mean())


def trainer(model, train_x, train_y, test_x, test_y, epochs=100, lr=0.5, lamb=1e-3, reg=False):
    """Train ``model`` full-batch with AdamW and report the best test errors.

    Each epoch performs one optimizer step on the whole training set and then
    evaluates the model on the whole test set. Every time the test MSE
    improves, the new best value is printed.

    Args:
        model: ``torch.nn.Module`` to train. It is moved to CUDA when
            available, otherwise to the CPU.
        train_x: training inputs, passed to ``model`` as-is (the data is not
            moved to the device here).
        train_y: training targets for the MSE loss.
        test_x: test inputs, passed to ``model`` as-is.
        test_y: test targets (a tensor).
        epochs: number of optimizer steps / evaluations.
        lr: AdamW learning rate.
        lamb: weight of the regularisation term; only used when ``reg`` is true.
        reg: when true, add ``lamb`` times a second-order penalty to the loss.
            This requires ``model.layers`` to be iterable and every layer to
            expose a ``spline_lin_c`` tensor that takes part in the loss.

    Returns:
        Tuple ``(best_mse, best_rmse)``: the lowest test MSE and the lowest
        test RMSE seen over all epochs (``inf`` for both if ``epochs`` is 0).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.MSELoss()

    # The test targets never change, so convert them to numpy only once.
    test_target = test_y.detach().cpu().numpy()

    best_mse = np.inf
    best_rmse = np.inf
    for _ in range(epochs):
        model.train()
        optimizer.zero_grad()
        loss = criterion(model(train_x), train_y)

        if reg:
            penalty = 0
            for layer in model.layers:
                # Gradient of the data loss w.r.t. the layer coefficients ...
                first = torch.autograd.grad(loss, layer.spline_lin_c, create_graph=True)[0]
                # ... differentiated once more; its 2-norm is the penalty.
                second = torch.autograd.grad(first.sum(), layer.spline_lin_c, create_graph=True)[0]
                penalty = penalty + torch.norm(second, p=2)
            # Out-of-place add: the graph built above still refers to ``loss``.
            loss = loss + lamb * penalty

        loss.backward()
        optimizer.step()

        model.eval()
        with torch.no_grad():
            prediction = model(test_x).detach().cpu().numpy()

        # Computed with numpy instead of
        # ``sklearn.metrics.mean_squared_error(..., squared=False)``: the
        # ``squared`` argument was removed in scikit-learn 1.6.
        # A NaN error compares false below and is therefore simply skipped.
        mse, rmse = _regression_errors(test_target, prediction)

        if mse < best_mse:
            best_mse = mse
            print(best_mse)
        if rmse < best_rmse:
            best_rmse = rmse

    return best_mse, best_rmse


# Load one Feynman benchmark equation: its sympy symbols, sympy expression,
# torch callable and the sampling interval declared for each input variable.
symbol, expr, f_x, ranges = get_feynman_dataset('I.6.20')

n_var = len(ranges)
print(ranges)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Sample the inputs from the equation's own intervals. Previously ``ranges``
# was fetched and printed but never handed to the generator.
# NOTE(review): assumes create_dataset accepts a ``ranges`` keyword (as the
# pykan helper of the same name does) - confirm against its definition.
dataset = create_dataset(f_x, n_var=n_var, f_mode='col', ranges=ranges, device=device)
train_x = dataset['train_input']
train_y = dataset['train_label']
test_x = dataset['test_input']
test_y = dataset['test_label']

In [3]:
# Synthetic two-variable benchmark used by the model cells below.
n_var = 2
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def f(x):
    """Return exp(sin(pi * x0) + x1 ** 2), keeping each column two-dimensional."""
    first, second = x[:, [0]], x[:, [1]]
    return torch.exp(torch.sin(torch.pi * first) + second ** 2)


dataset = create_dataset(f, n_var=n_var, device=device)
train_x, train_y = dataset['train_input'], dataset['train_label']
test_x, test_y = dataset['test_input'], dataset['test_label']

# Knots_KAN

In [None]:

# Knots_KAN run: train `k_fold` freshly built models with the second-order
# penalty enabled and average their best test errors.
# No layernorm or batchnorm needed for function fitting.
k_fold = 1
mse, rmse = [], []
for fold in range(k_fold):
    # A new model for every fold (grid_size=10, spline_order=3).
    model = Knots_KAN(
        layers_hidden=[n_var, 2, 1],
        grid_size=10,
        spline_order=3,
        grid_range=[-6, 6],
        groups=-1,
        need_relu=True,
    ).to(device)

    # trainer returns the best test MSE / RMSE seen during the run.
    test_mse, test_rmse = trainer(
        model, train_x, train_y, test_x, test_y,
        epochs=300, lr=0.03, lamb=1e-5, reg=True,
    )
    mse.append(test_mse)
    rmse.append(test_rmse)

print("Ave MSE:", np.mean(mse))
print("Ave RMSE:", np.mean(rmse))



In [None]:
import torch

# Evaluate the first layer of the current `model` on a regular
# grid_size x grid_size lattice over [-1, 1] x [-1, 1].
grid_size = 100
xs = torch.linspace(-1, 1, grid_size)
ys = torch.linspace(-1, 1, grid_size)
X_grid, Y_grid = torch.meshgrid(xs, ys, indexing='xy')
# X_grid and Y_grid each have shape (grid_size, grid_size), i.e. (100, 100).

# One (x, y) row per lattice point.
points = torch.stack((X_grid.reshape(-1), Y_grid.reshape(-1)), dim=1).to(device)
print(points.shape)  # torch.Size([10000, 2])

model.eval()
with torch.no_grad():
    # Only the first layer is probed here; swap in model.layer_norm[0](points)
    # or model(points) to inspect another stage instead.
    output = model.layers[0](points)

# KAN

In [None]:
from src.efficient_kan.kan import KAN

# KAN baseline run: train `k_fold` freshly built models without the
# second-order penalty and average their best test errors.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

k_fold = 1
mse, rmse = [], []
for fold in range(k_fold):
    # A new model for every fold (grid_size=10, spline_order=3).
    model = KAN(
        layers_hidden=[n_var, 2, 1],
        grid_size=10,
        spline_order=3,
        grid_range=[-4, 4],
    ).to(device)

    # trainer returns the best test MSE / RMSE seen during the run.
    test_mse, test_rmse = trainer(
        model, train_x, train_y, test_x, test_y,
        epochs=300, lr=0.05, reg=False,
    )
    mse.append(test_mse)
    rmse.append(test_rmse)

print("Ave MSE:", np.mean(mse))
print("Ave RMSE:", np.mean(rmse))

# Rational_KAN

In [None]:
from src.efficient_kan.rational_kan import Rational_KAN

# Rational_KAN run: train `k_fold` freshly built models without the
# second-order penalty and average their best test errors.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

k_fold = 3
mse, rmse = [], []
for fold in range(k_fold):
    # A new model for every fold (P_order=5, Q_order=5).
    model = Rational_KAN(
        layers_hidden=[n_var, 2, 1],
        P_order=5,
        Q_order=5,
        groups=-1,
    ).to(device)

    # trainer returns the best test MSE / RMSE seen during the run.
    test_mse, test_rmse = trainer(
        model, train_x, train_y, test_x, test_y,
        epochs=300, lr=0.05, reg=False,
    )
    mse.append(test_mse)
    rmse.append(test_rmse)

print("Ave MSE:", np.mean(mse))
print("Ave RMSE:", np.mean(rmse))

# Fourier-KAN

In [None]:
from src.efficient_kan.fourier_kan import Fourier_KAN

# Fourier_KAN run: train `k_fold` freshly built models without the
# second-order penalty and average their best test errors.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

k_fold = 3
mse, rmse = [], []
for fold in range(k_fold):
    # A new model for every fold (grid_size=10, spline_order=3).
    model = Fourier_KAN(
        layers_hidden=[n_var, 2, 1],
        grid_size=10,
        spline_order=3,
    ).to(device)

    # trainer returns the best test MSE / RMSE seen during the run.
    test_mse, test_rmse = trainer(
        model, train_x, train_y, test_x, test_y,
        epochs=300, lr=0.01, reg=False,
    )
    mse.append(test_mse)
    rmse.append(test_rmse)

print("Ave MSE:", np.mean(mse))
print("Ave RMSE:", np.mean(rmse))

# RBF-KAN

In [None]:
from src.efficient_kan.rbf_kan import RBF_KAN

# RBF_KAN run: train `k_fold` freshly built models without the
# second-order penalty and average their best test errors.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

k_fold = 3
mse, rmse = [], []
for fold in range(k_fold):
    # A new model for every fold (grid_size=10 on grid_range [-4, 4]).
    model = RBF_KAN(
        layers_hidden=[n_var, 2, 1],
        grid_size=10,
        grid_range=[-4, 4],
    ).to(device)

    # trainer returns the best test MSE / RMSE seen during the run.
    test_mse, test_rmse = trainer(
        model, train_x, train_y, test_x, test_y,
        epochs=500, lr=0.02, reg=False,
    )
    mse.append(test_mse)
    rmse.append(test_rmse)

print("Ave MSE:", np.mean(mse))
print("Ave RMSE:", np.mean(rmse))

# MLP

In [None]:
from src.efficient_kan.group_kan import MLP

# MLP baseline run: train `k_fold` freshly built models without the
# second-order penalty and average their best test errors.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

k_fold = 3
mse, rmse = [], []
for fold in range(k_fold):
    # A new MLP for every fold, with layer widths [n_var, 3, 1].
    model = MLP(layers_hidden=[n_var, 3, 1]).to(device)

    # trainer returns the best test MSE / RMSE seen during the run.
    test_mse, test_rmse = trainer(
        model, train_x, train_y, test_x, test_y,
        epochs=300, lr=0.01, reg=False,
    )
    mse.append(test_mse)
    rmse.append(test_rmse)

print("Ave MSE:", np.mean(mse))
print("Ave RMSE:", np.mean(rmse))