In [1]:
import numpy as np
import os
import sys
from pysr import PySRRegressor
from sympy import *
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

In [2]:
# Text file holding the id of the current model; its contents select the
# activation directory loaded below.
COUNTER_FILENAME = os.path.join("..", "bigan", "model_counter.txt")

# Placeholder so the name always exists; filled in by retrieve_model_id().
MODEL_ID = None

# Adjust printing view dimensions: never summarise arrays with "..." and
# allow wide lines.
np.set_printoptions(threshold=sys.maxsize, linewidth=300)

def retrieve_model_id():
    """Read the current model id from COUNTER_FILENAME into the global MODEL_ID.

    If the counter file does not exist, it is created containing '0' and
    MODEL_ID is set to '0'.

    Returns:
        str: the model id (also stored in the module-level MODEL_ID), with
        surrounding whitespace stripped so it can be embedded in a path.
    """
    global MODEL_ID
    try:
        with open(COUNTER_FILENAME, 'r') as f:
            # Strip so a trailing newline in the counter file does not end up
            # inside the data directory name (e.g. 'ae_3\n').
            MODEL_ID = f.read().strip()
    except FileNotFoundError:
        print('New counter file.')
        count = '0'
        with open(COUNTER_FILENAME, 'w') as f:
            f.write(count)
        MODEL_ID = count
    return MODEL_ID

            
retrieve_model_id()

# Directory holding the saved arrays for the current model id. Built with
# os.path.join rather than hard-coded '\' separators so it also works on POSIX.
LOAD_PATH = os.path.join("..", "bigan", "data", "activations", "ae_{}".format(MODEL_ID))
params = np.load(os.path.join(LOAD_PATH, "params.npy"))
activs = np.load(os.path.join(LOAD_PATH, "activs.npy"))
losses = np.load(os.path.join(LOAD_PATH, "losses.npy"))

# Remove inactive neurons (columns of `activs`) in a single call, so the
# indices refer to ORIGINAL column positions. This matches the earlier pair of
# sequential deletes (index 1, then index 2), which removed original columns 1
# and 3 because the columns shift left after the first delete.
# NOTE(review): confirm 1 and 3 (not 1 and 2) are the intended inactive neurons.
INACTIVE_NEURONS = [1, 3]
activs = np.delete(activs, INACTIVE_NEURONS, axis=1)

X = params
y = activs

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.15, random_state=14)

In [4]:
# Symbolic-regression search over the operator set below. Only built-in PySR
# operators are used, so no extra_sympy_mappings are required (those are only
# needed to export *custom* operators to SymPy).
# NOTE(review): the search is stochastic; for repeatable runs PySR documents
# setting random_state together with deterministic=True and serial execution.
model = PySRRegressor(
    model_selection="best",  # Result is mix of simplicity+accuracy
    niterations=40,
    binary_operators=["+", "-", "*", "/"],
    unary_operators=["cos", "exp", "sin", "log"],
    loss="loss(x, y) = (x - y)^2",  # element-wise squared error
)

In [5]:
# Run the symbolic-regression search on the training split only.
model.fit(X=X_train, y=y_train)



In [6]:
# Score the fitted equations on the held-out split. With several target
# columns, sklearn averages the per-column MSE uniformly by default.
y_hat = model.predict(X_test)
mean_squared_error(y_true=y_test, y_pred=y_hat)

5.236920305722326

In [11]:
# LaTeX form of the equation(s) selected by model_selection; for this
# multi-output fit the result is a list of strings (two shown below).
model.latex()

['x_{2}',
 '1.86 e^{\\left(- 5.34 x_{2} + \\cos{\\left(\\log{\\left(\\left|{x_{3}}\\right| \\right)} \\right)}\\right) \\log{\\left(\\left|{x_{1}}\\right| \\right)}}']