In [1]:
import yaml
import torch
import pandas as pd
import plotly.express as px
import plotly.graph_objs as go

class DotDict(dict):
    """Dictionary whose entries can also be read, written and deleted as attributes.

    Reading an attribute looks the name up as a key. A missing key yields
    ``None`` instead of raising, and a value that is a plain ``dict`` is
    wrapped in a new ``DotDict`` so nested keys can be chained with dots.
    """

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails, so real dict
        # methods (get, items, ...) keep working.
        value = dict.get(self, name)
        if type(value) is dict:
            return DotDict(value)
        return value

    # Attribute assignment / deletion map straight onto item assignment / deletion.
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__


# Parse the YAML configuration and wrap it so that its keys read as attributes.
with open("./config/config.yaml", 'r') as stream:
    config = DotDict(yaml.safe_load(stream))

## Comparison of num_expressions in RL

In [None]:
# Gather the logged curves of the runs rl_1_000 ... rl_5_000 into one long table.
nums = [str(i) for i in range(1, 6)]
interesting_data = ['scores', 'max_scores']
data = {}

for num in nums:
    checkpoint = torch.load('outputs/rl_' + num + '_000/model_200000.pt', map_location='cpu')
    n_points = len(checkpoint['scores'])
    for col in interesting_data:
        data[col] = data.setdefault(col, []) + checkpoint[col]
    # Tag every point with the run it belongs to and with its episode number
    # (100, 200, ...: one logged point per 100 episodes).
    data['num_expressions'] = data.setdefault('num_expressions', []) + [num] * n_points
    data['Episode'] = data.setdefault('Episode', []) + list(range(100, 100 * n_points + 1, 100))

df = pd.DataFrame(data)

# Smooth each run separately with a moving average so the trend shows through the noise.
window_size = 50
for col in interesting_data:
    smoothed = df.groupby('num_expressions')[col].transform(
        lambda series: series.rolling(window_size).mean()
    )
    df[col] = smoothed

fig = px.line(df, x='Episode', y="scores", color='num_expressions')
fig.show()

In [None]:
# Load the rl_3_000 checkpoint onto the CPU for inspection.
checkpoint_path = 'outputs/rl_3_000/model_200000.pt'
a = torch.load(checkpoint_path, map_location='cpu')

In [None]:
# Peek at the first 20 values of the 10-point moving average of the scores.
df = pd.DataFrame({'scores': a['scores']})
scores_ma = df['scores'].rolling(window=10).mean()
scores_ma.head(20)

In [2]:
from scripts.dclasses import Dataset
from scripts.language import Language
from pytorch_lightning.utilities.seed import seed_everything

# Fix the global seed before building the dataset.
seed_everything(5)

n_functions = 500
import time
language = Language(config.Language)
# Time only the Dataset construction. perf_counter() is monotonic and
# high-resolution, so the measured duration cannot be skewed by system clock
# adjustments the way a time.time() difference can.
ini = time.perf_counter()
data = Dataset(n_functions, language)
print(time.perf_counter() - ini)

Global seed set to 100
  stack.append(function.function(first_operand))


0.32672619819641113


In [10]:
import torch

# For every sample, append y as an extra last column of X, then stack all
# samples into a single tensor.
Xys = torch.stack([
    torch.cat(
        (torch.from_numpy(row['X']), torch.from_numpy(row['y']).unsqueeze(1)),
        dim=1,
    )
    for row in data
])

torch.save(Xys, '../NeuralSymbolicRegressionThatScales-main/tensor2.pt')

In [None]:
# Average traversal length of the target expressions in the dataset.
import time

total_traversal_len = 0
n_rows = 0
for row in data:
    total_traversal_len += len(row['Target Expression'].traversal)
    n_rows += 1

# Divide by the number of rows actually visited rather than a hard-coded
# 100000, which does not match the dataset built with n_functions = 500.
total_traversal_len / n_rows if n_rows else float('nan')

In [None]:
from scripts.expression import Expression
from scripts.model import ETIN_model
import torch
import numpy as np

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# TODO: build the expression in a different way:
#   instead of always sampling a value probabilistically, take the maximum
#   with some probability and, with the remaining probability, sample a random
#   value from the vector according to its probabilities.

# Checkpoint to evaluate; when left as None the model is built from the config instead.
# path = '/home/gridsan/amorenas/ETIN3/outputs/rl/model_80000.pt'
path = None

def nrmse(y_pred, y_true):
    """Return the normalized RMSE of a prediction and the reward derived from it.

    The root-mean-square error between ``y_pred`` and ``y_true`` is divided by
    the standard deviation of ``y_true``.

    Returns:
        tuple: ``(error, reward)`` where ``reward = 5 / (1 + error)``.
    """
    residuals = y_pred - y_true
    rmse = np.sqrt(np.mean(residuals ** 2))
    error = rmse / np.std(y_true)
    reward = 5 / (1 + error)
    return error, reward

# Build the model from the config, or restore it from `path` when one is given.
if path is None:
    etin_model = ETIN_model(config.Model, language.info_for_model)
else:
    etin_model = ETIN_model.load_from_checkpoint(path, cfg=config.Model, info_for_model=language.info_for_model)

etin_model.to(device)

errors = []
rewards = []
for i, row in enumerate(data):
    # Create a candidate expression for this sample using the model.
    new_expr = Expression(language, model=etin_model, prev_info=row)
    if i == 5:
        # Print one target / generated pair as a qualitative sanity check.
        # (The leftover debugging statement `a = bbb`, which raised NameError
        # here and aborted the whole evaluation, has been removed.)
        print(row['Target Expression'].to_sympy())
        print(new_expr.to_sympy())
    y_pred = new_expr.evaluate(row['X'])
    # Skip predictions containing NaN, a magnitude above 1e5, or a magnitude below 1e-2.
    if (np.isnan(y_pred).any() or np.abs(y_pred).max() > 1e5 or np.abs(y_pred).min() < 1e-2):
        continue
    error, reward = nrmse(y_pred, row['y'])
    errors.append(error)
    rewards.append(reward)

# Mean error and mean reward over the predictions that were not skipped.
print(np.mean(errors), np.mean(rewards))

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# Histogram of the rewards collected in the `rewards` list.
rewards_series = pd.Series(rewards)

rewards_series.plot.hist(grid=True, bins=20, rwidth=0.9,
                         color='#607c8e')
plt.title('Rewards for 500 equations')
# A histogram shows the measured value on x and its frequency on y; the
# original labels were swapped.
plt.xlabel('Rewards')
plt.ylabel('Counts')
plt.grid(axis='y', alpha=0.75)