In [1]:
import sys
sys.path.append('../')
from kan import KAN

In [2]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import torch
import sympy as sp
from mpmath import mp, mpf

In [3]:
data = pd.read_csv("./Mushroom_Dataset/mushroom_cleaned.csv")
data.head()

Unnamed: 0,cap-diameter,cap-shape,gill-attachment,gill-color,stem-height,stem-width,stem-color,season,class
0,1372,2,2,10,3.807467,1545,11,1.804273,1
1,1461,2,2,10,3.807467,1557,11,1.804273,1
2,1371,2,2,10,3.612496,1566,11,1.804273,1
3,1261,6,2,10,3.787572,1566,11,1.804273,1
4,1305,6,2,10,3.711971,1464,11,0.943195,1


In [4]:
features = data.drop('class', axis=1)
target = data['class'].astype(int)

scaler = StandardScaler()
features_scaled = scaler.fit_transform(features)

X_train, X_test, y_train, y_test = train_test_split(features_scaled,
                                                    target,
                                                    test_size=0.2,
                                                    random_state=42,
                                                    shuffle=True,
                                                    stratify=target)

In [5]:
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device = 'cpu'
device = 'cuda'

print(device)

cuda


In [6]:
dataset = {}
dataset['train_input'] = torch.from_numpy(X_train).to(device)
dataset['test_input'] = torch.from_numpy(X_test).to(device)
dataset['train_label'] = torch.from_numpy(np.array(y_train)).long().to(device)
dataset['test_label'] = torch.from_numpy(np.array(y_test)).long().to(device)

In [7]:
print(dataset['train_input'].shape)
print(dataset['train_label'].shape)
print(dataset['test_input'].shape)
print(dataset['test_label'].shape)

torch.Size([43228, 8])
torch.Size([43228])
torch.Size([10807, 8])
torch.Size([10807])


In [8]:
model = KAN(width=[8,24,2], grid=6, k=3, device=device)

def train_acc():
    return torch.mean((torch.argmax(model(dataset['train_input']), dim=1) == dataset['train_label']).float())

def test_acc():
    return torch.mean((torch.argmax(model(dataset['test_input']), dim=1) == dataset['test_label']).float())

model.load_ckpt('test_2_mushroom_softmax.ckpt')
res_tmp = model(dataset['test_input'])

In [9]:
model.fix_symbolic(0,0,1, 'custom_29')
model.fix_symbolic(0,0,2, 'custom_30')
model.fix_symbolic(0,0,5, 'custom_20')
model.fix_symbolic(0,0,8, 'custom_30')
model.fix_symbolic(0,0,9, 'custom_8')
model.fix_symbolic(0,0,10, 'custom_29')
model.fix_symbolic(0,0,14, 'custom_30')
model.fix_symbolic(0,0,15, 'custom_28')
model.fix_symbolic(0,0,17, 'custom_20')
model.fix_symbolic(0,0,18, 'custom_23')
model.fix_symbolic(0,0,19, 'custom_22')
model.fix_symbolic(0,0,22, 'custom_19')
model.fix_symbolic(0,1,0, 'custom_11')
model.fix_symbolic(0,1,2, 'custom_11')
model.fix_symbolic(0,1,3, 'custom_7')
model.fix_symbolic(0,1,5, 'custom_2')
model.fix_symbolic(0,1,9, 'custom_2')
model.fix_symbolic(0,1,10, 'custom_12')
model.fix_symbolic(0,1,11, 'custom_19')
model.fix_symbolic(0,1,18, 'custom_7')
model.fix_symbolic(0,1,20, 'custom_2')
model.fix_symbolic(0,1,22, 'custom_12')
model.fix_symbolic(0,2,1, 'custom_11')
model.fix_symbolic(0,2,3, 'custom_11')
model.fix_symbolic(0,2,4, 'custom_28')
model.fix_symbolic(0,2,5, 'custom_8')
model.fix_symbolic(0,2,6, 'custom_28')
model.fix_symbolic(0,2,8, 'custom_16')
model.fix_symbolic(0,2,9, 'custom_2')
model.fix_symbolic(0,2,10, 'custom_16')
model.fix_symbolic(0,2,12, 'custom_8')
model.fix_symbolic(0,2,14, 'custom_1')
model.fix_symbolic(0,2,15, 'custom_1')
model.fix_symbolic(0,2,16, 'custom_5')
model.fix_symbolic(0,2,17, 'custom_22')
model.fix_symbolic(0,2,18, 'custom_16')
model.fix_symbolic(0,2,20, 'custom_25')
model.fix_symbolic(0,2,21, 'custom_30')
model.fix_symbolic(0,2,23, 'custom_8')
model.fix_symbolic(0,3,2, 'custom_11')
model.fix_symbolic(0,3,3, 'custom_5')
model.fix_symbolic(0,3,7, 'custom_23')
model.fix_symbolic(0,3,14, 'custom_4')
model.fix_symbolic(0,4,0, 'custom_14')
model.fix_symbolic(0,4,2, 'custom_29')
model.fix_symbolic(0,4,4, 'custom_33')
model.fix_symbolic(0,4,5, 'custom_14')
model.fix_symbolic(0,4,6, 'custom_11')
model.fix_symbolic(0,4,7, 'custom_0')
model.fix_symbolic(0,4,10, 'custom_33')
model.fix_symbolic(0,4,13, 'custom_23')
model.fix_symbolic(0,4,20, 'custom_25')
model.fix_symbolic(0,4,22, 'custom_13')
model.fix_symbolic(0,4,23, 'custom_2')
model.fix_symbolic(0,5,1, 'custom_16')
model.fix_symbolic(0,5,6, 'custom_16')
model.fix_symbolic(0,5,9, 'custom_13')
model.fix_symbolic(0,5,15, 'custom_20')
model.fix_symbolic(0,6,0, 'custom_25')
model.fix_symbolic(0,6,1, 'custom_29')
model.fix_symbolic(0,6,2, 'custom_2')
model.fix_symbolic(0,6,5, 'custom_9')
model.fix_symbolic(0,6,8, 'custom_10')
model.fix_symbolic(0,6,9, 'custom_25')
model.fix_symbolic(0,6,12, 'custom_7')
model.fix_symbolic(0,6,13, 'custom_17')
model.fix_symbolic(0,6,16, 'custom_2')
model.fix_symbolic(0,7,0, 'custom_27')
model.fix_symbolic(0,7,1, 'custom_22')
model.fix_symbolic(0,7,2, 'custom_9')
model.fix_symbolic(0,7,3, 'custom_0')
model.fix_symbolic(0,7,4, 'custom_20')
model.fix_symbolic(0,7,5, 'custom_9')
model.fix_symbolic(0,7,6, 'custom_26')
model.fix_symbolic(0,7,7, 'custom_9')
model.fix_symbolic(0,7,8, 'custom_14')
model.fix_symbolic(0,7,9, 'custom_25')
model.fix_symbolic(0,7,10, 'custom_27')
model.fix_symbolic(0,7,11, 'custom_22')
model.fix_symbolic(0,7,12, 'custom_33')
model.fix_symbolic(0,7,13, 'custom_22')
model.fix_symbolic(0,7,14, 'custom_26')
model.fix_symbolic(0,7,15, 'custom_25')
model.fix_symbolic(0,7,16, 'custom_33')
model.fix_symbolic(0,7,17, 'custom_11')
model.fix_symbolic(0,7,18, 'custom_26')
model.fix_symbolic(0,7,19, 'custom_22')
model.fix_symbolic(0,7,20, 'custom_33')
model.fix_symbolic(0,7,21, 'custom_25')
model.fix_symbolic(0,7,22, 'custom_33')
model.fix_symbolic(0,7,23, 'custom_25')
model.fix_symbolic(1,1,0,'custom_9')
model.fix_symbolic(1,1,1,'custom_9')
model.fix_symbolic(1,2,0,'relu')
model.fix_symbolic(1,2,1,'relu')
model.fix_symbolic(1,4,0,'relu')
model.fix_symbolic(1,4,1,'relu')
model.fix_symbolic(1,5,1,'relu')
model.fix_symbolic(1,8, 0,'custom_9')
model.fix_symbolic(1,11,0,'custom_9')
model.fix_symbolic(1,11,1,'custom_9')
model.fix_symbolic(1,12,0,'custom_23')
model.fix_symbolic(1,12,1,'custom_12')
model.fix_symbolic(1,13,0,'custom_1')
model.fix_symbolic(1,13,1,'custom_1')
model.fix_symbolic(1,17,0,'relu')
model.fix_symbolic(1,17,1,'relu')
model.fix_symbolic(1,19,0,'custom_0')
model.fix_symbolic(1,19,1,'custom_12')

r2 is 0.9978988819633795
r2 is 0.998432459530966
r2 is 0.9973281040917942
r2 is 0.9990624982177914
r2 is 0.9970524487546712
r2 is 0.9986642243651711
r2 is 0.9993906324435142
r2 is 0.9977897861593943
r2 is 0.9972172724509125
r2 is 0.9995755294712914
r2 is 0.9986855433823527
r2 is 0.9989097218341008
r2 is 0.9981109834031752
r2 is 0.9983163314427845
r2 is 0.9985970044085454
r2 is 0.9990914893514359
r2 is 0.9980629803495124
r2 is 0.9999564072759848
r2 is 0.9993292942395082
r2 is 0.9972573100208288
r2 is 0.999043734342271
r2 is 0.9992435602960943
r2 is 0.9997826732933714
r2 is 0.9994450012880779
r2 is 0.999705814452179
r2 is 0.9997416306013897
r2 is 0.9985111032993822
r2 is 0.9987898921679864
r2 is 0.9973945391117428
r2 is 0.999674254829953
r2 is 0.9970826568814255
r2 is 0.9994592950671074
r2 is 0.9997114359423743
r2 is 0.9992796253116333
r2 is 0.9979550035049141
r2 is 0.9996544627900671
r2 is 0.9972373098913699
r2 is 0.9991570741081759
r2 is 0.9999170636551356
r2 is 0.9975714707834771
r2 i

tensor(1.0000, device='cuda:0', dtype=torch.float64, grad_fn=<SelectBackward0>)

In [10]:
model.fix_symbolic(0,0,0,'custom_40')
model.fix_symbolic(0,0,3,'custom_41')
model.fix_symbolic(0,0,4,'custom_42')
model.fix_symbolic(0,0,6,'custom_43')
model.fix_symbolic(0,0,7,'custom_44')
model.fix_symbolic(0,0,11,'custom_45')
model.fix_symbolic(0,0,12,'custom_46')
model.fix_symbolic(0,0,13,'custom_47')
model.fix_symbolic(0,0,16,'custom_48')
model.fix_symbolic(0,0,20,'custom_49')
model.fix_symbolic(0,0,21,'custom_50')
model.fix_symbolic(0,0,23,'custom_51')
model.fix_symbolic(0,1,1,'custom_52')
model.fix_symbolic(0,1,4,'custom_53')
model.fix_symbolic(0,1,6,'custom_54')
model.fix_symbolic(0,1,7,'custom_55')
model.fix_symbolic(0,1,8,'custom_56')
model.fix_symbolic(0,1,12,'custom_57')
model.fix_symbolic(0,1,13,'custom_58')
model.fix_symbolic(0,1,14,'custom_59')
model.fix_symbolic(0,1,15,'custom_60')
model.fix_symbolic(0,1,16,'custom_61')
model.fix_symbolic(0,1,17,'custom_62')
model.fix_symbolic(0,1,19,'custom_63')
model.fix_symbolic(0,1,21,'custom_64')
model.fix_symbolic(0,1,23,'custom_65')
model.fix_symbolic(0,2,0,'custom_66')
model.fix_symbolic(0,2,2,'custom_67')
model.fix_symbolic(0,2,7,'custom_68')
model.fix_symbolic(0,2,11,'custom_69')
model.fix_symbolic(0,2,13,'custom_70')
model.fix_symbolic(0,2,19,'custom_71')
model.fix_symbolic(0,2,22,'custom_72')
model.fix_symbolic(0,3,0,'custom_73')
model.fix_symbolic(0,3,1,'custom_74')
model.fix_symbolic(0,3,4,'custom_75')
model.fix_symbolic(0,3,5,'custom_76')
model.fix_symbolic(0,3,6,'custom_77')
model.fix_symbolic(0,3,8,'custom_78')
model.fix_symbolic(0,3,9,'custom_79')
model.fix_symbolic(0,3,10,'custom_80')
model.fix_symbolic(0,3,11,'custom_81')
model.fix_symbolic(0,3,12,'custom_82')
model.fix_symbolic(0,3,13,'custom_83')
model.fix_symbolic(0,3,15,'custom_84')
model.fix_symbolic(0,3,16,'custom_85')
model.fix_symbolic(0,3,17,'custom_86')
model.fix_symbolic(0,3,18,'custom_87')
model.fix_symbolic(0,3,19,'custom_88')
model.fix_symbolic(0,3,20,'custom_89')
model.fix_symbolic(0,3,21,'custom_90')
model.fix_symbolic(0,3,22,'custom_91')
model.fix_symbolic(0,3,23,'custom_92')
model.fix_symbolic(0,4,1,'custom_93')
model.fix_symbolic(0,4,3,'custom_94')
model.fix_symbolic(0,4,8,'custom_95')
model.fix_symbolic(0,4,9,'custom_96')
model.fix_symbolic(0,4,11,'custom_97')
model.fix_symbolic(0,4,12,'custom_98')
model.fix_symbolic(0,4,14,'custom_99')
model.fix_symbolic(0,4,15,'custom_100')
model.fix_symbolic(0,4,16,'custom_101')
model.fix_symbolic(0,4,17,'custom_102')
model.fix_symbolic(0,4,18,'custom_103')
model.fix_symbolic(0,4,19,'custom_104')
model.fix_symbolic(0,4,21,'custom_105')
model.fix_symbolic(0,5,0,'custom_106')
model.fix_symbolic(0,5,2,'custom_107')
model.fix_symbolic(0,5,3,'custom_108')
model.fix_symbolic(0,5,4,'custom_109')
model.fix_symbolic(0,5,5,'custom_110')
model.fix_symbolic(0,5,7,'custom_111')
model.fix_symbolic(0,5,8,'custom_112')
model.fix_symbolic(0,5,10,'custom_113')
model.fix_symbolic(0,5,11,'custom_114')
model.fix_symbolic(0,5,12,'custom_115')
model.fix_symbolic(0,5,13,'custom_116')
model.fix_symbolic(0,5,14,'custom_117')
model.fix_symbolic(0,5,16,'custom_118')
model.fix_symbolic(0,5,17,'custom_119')
model.fix_symbolic(0,5,18,'custom_120')
model.fix_symbolic(0,5,19,'custom_121')
model.fix_symbolic(0,5,20,'custom_122')
model.fix_symbolic(0,5,21,'custom_123')
model.fix_symbolic(0,5,22,'custom_124')
model.fix_symbolic(0,5,23,'custom_125')
model.fix_symbolic(0,6,3,'custom_126')
model.fix_symbolic(0,6,4,'custom_127')
model.fix_symbolic(0,6,6,'custom_128')
model.fix_symbolic(0,6,7,'custom_129')
model.fix_symbolic(0,6,10,'custom_130')
model.fix_symbolic(0,6,11,'custom_131')
model.fix_symbolic(0,6,14,'custom_132')
model.fix_symbolic(0,6,15,'custom_133')
model.fix_symbolic(0,6,17,'custom_134')
model.fix_symbolic(0,6,18,'custom_135')
model.fix_symbolic(0,6,19,'custom_136')
model.fix_symbolic(0,6,20,'custom_137')
model.fix_symbolic(0,6,21,'custom_138')
model.fix_symbolic(0,6,22,'custom_139')
model.fix_symbolic(0,6,23,'custom_140')

model.fix_symbolic(1,0,0, 'custom_141')
model.fix_symbolic(1,0,1, 'custom_142')
model.fix_symbolic(1,3,0, 'custom_143')
model.fix_symbolic(1,3,1, 'custom_144')
model.fix_symbolic(1,5,0, 'custom_145')
model.fix_symbolic(1,6,0, 'custom_146')
model.fix_symbolic(1,6,1, 'custom_147')
model.fix_symbolic(1,7,0, 'custom_148') #
model.fix_symbolic(1,7,1, 'custom_149')
model.fix_symbolic(1,9,0, 'custom_151')
model.fix_symbolic(1,9,1, 'custom_152') #
model.fix_symbolic(1,10,0, 'custom_153')
model.fix_symbolic(1,10,1, 'custom_154')
model.fix_symbolic(1,14,0, 'custom_155')
model.fix_symbolic(1,14,1, 'custom_156')
model.fix_symbolic(1,15,0, 'custom_157')
model.fix_symbolic(1,15,1, 'custom_158')
model.fix_symbolic(1,16,0, 'custom_159')
model.fix_symbolic(1,16,1, 'custom_160')
model.fix_symbolic(1,18,0, 'custom_161')
model.fix_symbolic(1,18,1, 'custom_162')
model.fix_symbolic(1,20,0, 'custom_163')
model.fix_symbolic(1,20,1, 'custom_164')
model.fix_symbolic(1,21,0, 'custom_165')
model.fix_symbolic(1,21,1, 'custom_166')
model.fix_symbolic(1,23,0, 'custom_167')
model.fix_symbolic(1,23,1, 'custom_168')
model.fix_symbolic(1,22,0, 'custom_169')
model.fix_symbolic(1,22,1, 'custom_170')
model.fix_symbolic(1,8,1, 'custom_171')

r2 is 0.9994362561906731
r2 is 0.99945587427955
r2 is 0.9986159695074722
r2 is 0.9996144917983509
r2 is 0.9951597599841365
r2 is 0.9924087003334096
r2 is 0.9981780890551967
r2 is 0.9995068595633683
r2 is 0.9988651802531182
r2 is 0.9995768193516956
r2 is 0.9984353496161567
r2 is 0.9993802282362686
r2 is 0.9999627933305972
r2 is 0.9896512475618544
r2 is 0.9999269369354454
r2 is 0.9950758473500906
r2 is 0.9996664467404263
r2 is 0.9998599338449821
r2 is 0.9998567933917639
r2 is 0.9990998390585545
r2 is 0.9971703706453549
r2 is 0.9994688080905196
r2 is 0.9785359268648605
r2 is 0.9998728027535531
r2 is 0.9995754102977994
r2 is 0.9996921436813841
r2 is 0.9998792730969842
r2 is 0.999859867867338
r2 is 0.9990547097740944
r2 is 0.9995353532397754
r2 is 0.9999704358402911
r2 is 0.9991045361343868
r2 is 0.999226263956735
r2 is 0.994583236020367
r2 is 0.9955324533118488
r2 is 0.9995123027075384
r2 is 0.9996392525393514
r2 is 0.9997375773661292
r2 is 0.9996892765559474
r2 is 0.9983782640109677
r2 is

tensor(0.9994, device='cuda:0', dtype=torch.float64, grad_fn=<SelectBackward0>)

In [11]:
lib = ['x','x^2','x^3','1/x','sqrt','1/sqrt(x)','exp','log','abs','sin','tan','tanh','sigmoid','relu','sgn','arcsin','arctan','arctanh','0','gaussian','cosh']
model.auto_symbolic(lib=lib)

skipping (0,0,0) since already symbolic
skipping (0,0,1) since already symbolic
skipping (0,0,2) since already symbolic
skipping (0,0,3) since already symbolic
skipping (0,0,4) since already symbolic
skipping (0,0,5) since already symbolic
skipping (0,0,6) since already symbolic
skipping (0,0,7) since already symbolic
skipping (0,0,8) since already symbolic
skipping (0,0,9) since already symbolic
skipping (0,0,10) since already symbolic
skipping (0,0,11) since already symbolic
skipping (0,0,12) since already symbolic
skipping (0,0,13) since already symbolic
skipping (0,0,14) since already symbolic
skipping (0,0,15) since already symbolic
skipping (0,0,16) since already symbolic
skipping (0,0,17) since already symbolic
skipping (0,0,18) since already symbolic
skipping (0,0,19) since already symbolic
skipping (0,0,20) since already symbolic
skipping (0,0,21) since already symbolic
skipping (0,0,22) since already symbolic
skipping (0,0,23) since already symbolic
skipping (0,1,0) since alr

In [12]:
print("Train accuracy:", train_acc(), " - Test accuracy:", test_acc())

Train accuracy: tensor(0.5617, device='cuda:0')  - Test accuracy: tensor(0.5578, device='cuda:0')


In [None]:
formula1, formula2 = model.symbolic_formula(floating_digit=5)[0]
# model.symbolic_formula()

In [18]:
print('Formula1:', formula1)
print('Formula2:', formula2)

Formula1: 0.05628*x_6 + 1.21015*(-0.77528*x_1 - 4.67178)*(0.00761*x_1 + 0.00982*sin(2.21502*x_1 + 21.564) + 0.07689) + 0.01039*(7.48008*x_3 - 4.336)*sin(6.84274*x_3 - 3.96655) - 2.06238*(0.97432*x_7 + 4.79598)*(-0.00957*x_7 - 0.00982*sin(2.78369*x_7 + 5.48591) - 0.0161) + 0.99432*((0.03453*x_2 + 5.4958*(0.40504*x_1 - 2.3949)*(-0.00398*x_1 - 0.00982*sin(1.15722*x_1 - 15.05883) + 0.05452) + 0.00155*(1.6848*x_5 + 5.5371)*(2.29672*x_5 + sin(2.10563*x_5 + 7.11602) + 10.9772) + 0.31237*(0.77302*atan(0.97736*x_6 + 0.0144) - 0.43434)*cos(sin(sin(cos(cos(0.97736*x_6 + 0.0144)) - atan(sinh(0.97736*x_6 + 0.0144)))))**8 + 1.55958*sin(0.98177*x_8 - 2.0337) + 0.01714*sin(2.576*x_8 - 8.0) + 0.0143*sin(sin(2.84696*x_2 + 7.40896)) - 0.07296*sin(sin(5.30904006958008*x_3 - 5.63788))**20 - 0.03117*cos(3.02103*x_2 + 7.86195) + 0.31719*tan(0.26512*cos(2.98424*sin(tanh(sin(1.01088*x_4 + 0.87961) + 0.2256)))) + 0.07296*tan(sinh(tanh(asinh(5.30904006958008*x_3 - 6.87744) - 0.08959)**7)) + 0.3155*tanh(0.56492*t

In [19]:
def import_formula_from_txt(path):
    with open(path, 'r') as f:
        import_expr = f.read()
    loaded_expr = sp.sympify(import_expr)
    return loaded_expr

def export_formula_to_txt(formula, path):
    export_expr = sp.srepr(formula)
    with open(path, 'w') as f:
        f.write(export_expr)

In [20]:
export_formula_to_txt(formula1, './formula1_2class_mushroom.txt')
export_formula_to_txt(formula2, './formula2_2class_mushroom.txt')

loaded_formula1 = import_formula_from_txt('./formula1_2class_mushroom.txt')
loaded_formula2 = import_formula_from_txt('./formula2_2class_mushroom.txt')

print('Formula1:', loaded_formula1)
print('Formula2:', loaded_formula2)

Formula1: 0.05628*x_6 + 1.21015*(-0.77528*x_1 - 4.67178)*(0.00761*x_1 + 0.00982*sin(2.21502*x_1 + 21.564) + 0.07689) + 0.01039*(7.48008*x_3 - 4.336)*sin(6.84274*x_3 - 3.96655) - 2.06238*(0.97432*x_7 + 4.79598)*(-0.00957*x_7 - 0.00982*sin(2.78369*x_7 + 5.48591) - 0.0161) + 0.99432*((0.03453*x_2 + 5.4958*(0.40504*x_1 - 2.3949)*(-0.00398*x_1 - 0.00982*sin(1.15722*x_1 - 15.05883) + 0.05452) + 0.00155*(1.6848*x_5 + 5.5371)*(2.29672*x_5 + sin(2.10563*x_5 + 7.11602) + 10.9772) + 0.31237*(0.77302*atan(0.97736*x_6 + 0.0144) - 0.43434)*cos(sin(sin(cos(cos(0.97736*x_6 + 0.0144)) - atan(sinh(0.97736*x_6 + 0.0144)))))**8 + 1.55958*sin(0.98177*x_8 - 2.0337) + 0.01714*sin(2.576*x_8 - 8.0) + 0.0143*sin(sin(2.84696*x_2 + 7.40896)) - 0.07296*sin(sin(5.30904006958008*x_3 - 5.63788))**20 - 0.03117*cos(3.02103*x_2 + 7.86195) + 0.31719*tan(0.26512*cos(2.98424*sin(tanh(sin(1.01088*x_4 + 0.87961) + 0.2256)))) + 0.07296*tan(sinh(tanh(asinh(5.30904006958008*x_3 - 6.87744) - 0.08959)**7)) + 0.3155*tanh(0.56492*t

In [21]:
variables = loaded_formula1.free_symbols
print("Variables in the expression:", variables)

Variables in the expression: {x_6, x_4, x_1, x_8, x_5, x_3, x_7, x_2}


In [24]:
def acc_new(f_sympy_1, f_sympy_2, X, y):
    batch = X.shape[0]
    correct = 0
    res = []

    x_1, x_2, x_3, x_4, x_5, x_6, x_7, x_8 = sp.symbols('x_1, x_2, x_3, x_4, x_5, x_6, x_7, x_8')

    f_math_1 = sp.lambdify((x_1, x_2, x_3, x_4, x_5, x_6, x_7, x_8), f_sympy_1, modules=['mpmath'])
    f_math_2 = sp.lambdify((x_1, x_2, x_3, x_4, x_5, x_6, x_7, x_8), f_sympy_2, modules=['mpmath'])
 
    mp.dps = 30
    
    for i in range(batch):
        # if(i%100 == 0):
        #     print("==========", i, " ==========")

        logit1 = f_math_1(mpf(str(X[i,0].item())), mpf(str(X[i,1].item())), mpf(str(X[i,2].item())), mpf(str(X[i,3].item())), mpf(str(X[i,4].item())), mpf(str(X[i,5].item())), mpf(str(X[i,6].item())), mpf(str(X[i,7].item())))
        logit2 = f_math_2(mpf(str(X[i,0].item())), mpf(str(X[i,1].item())), mpf(str(X[i,2].item())), mpf(str(X[i,3].item())), mpf(str(X[i,4].item())), mpf(str(X[i,5].item())), mpf(str(X[i,6].item())), mpf(str(X[i,7].item())))

        correct += np.argmax([logit1, logit2]) == y[i]

    return correct/batch

In [25]:
acc = acc_new(loaded_formula1, loaded_formula2, dataset['test_input'], dataset['test_label'])

print('Test acc of the formula:', acc)

Test acc of the formula: tensor(0.5485, device='cuda:0')
