In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from kan import KAN,create_dataset
import sympy as sp
import sys
if "../" not in sys.path:
    sys.path.append("../")
from kharkan.model import KharKAN

In [2]:
#model=MyModelO()
f = lambda x: x[:,[0]]+x[:,[0]]*x[:,[0]]+x[:,[0]] * x[:,[1]]+0.25
dataset = create_dataset(f, n_var=2,ranges = [-1,1])

In [3]:
def remove_non_numeric(datastr:str):
    datastr=str(datastr)
    return ''.join([c for c in datastr if c in '1234567890./'])
def divide_data(datastr:str):
    datastr=remove_non_numeric(datastr)
    if '/' in datastr:
        x, y = datastr.split('/')
        x=float(x)
        y=float(y)
        return x/y
    return float(datastr)
def parse_float(datastr:str):
    newstr=remove_non_numeric(datastr)
    return float(newstr)
data=pd.read_csv('../data/J_v6.tsv', sep='\t', names=['Jintra', 'deltaJ'])
data['Jinter']=20
data['Jintra'] = data['Jintra'].apply(lambda x: divide_data(x))
# normalize each column
data['Jintra'] = data['Jintra']/20
data['deltaJ'] = data['deltaJ']/20
data['Jinter'] = data['Jinter']/20

data['deltaJ/intra'] = data['deltaJ']/data['Jintra']
data['deltaJ*intra'] = data['deltaJ']*data['Jintra']
data['intra/deltaJ'] = data['Jintra']/data['deltaJ']
freqs=pd.read_csv('../data/Freqs_v6.tsv', sep=' ', names=['f1', 'f2', 'f3',"f1_analytical", "f2_analytical", "f3_analytical","diff1", "diff2", "diff3"])
for column in freqs.columns:
    freqs[column] = freqs[column].apply(lambda x: parse_float(x))
freqs=freqs.astype(float)
data['f1']=freqs['f1']
data['f2']=freqs['f2']
data['f3']=freqs['f3']
data["f1_analytical"]=freqs["f1_analytical"]
data["f2_analytical"]=freqs["f2_analytical"]
data["f3_analytical"]=freqs["f3_analytical"]

In [4]:
%matplotlib qt
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
# labels
ax.set_xlabel('Jintra')
ax.set_ylabel('deltaJ')
ax.set_zlabel('Frequencies')
ax.scatter(data['Jintra'], data['deltaJ'], data['f1'])
ax.scatter(data['Jintra'], data['deltaJ'], data['f2'], c='r')
ax.scatter(data['Jintra'], data['deltaJ'], data['f3'], c='g')

<mpl_toolkits.mplot3d.art3d.Path3DCollection at 0x1c7f7df03b0>

In [5]:
dataset['train_input']=torch.tensor(data[['Jintra',"deltaJ"]].values, dtype=torch.float32)
dataset['train_label']=torch.tensor(data[['f1', 'f2', 'f3']].values, dtype=torch.float32)
dataset['train_label']=torch.tensor(data[['f1_analytical']].values, dtype=torch.float32)

In [6]:
model=KharKAN((2,3,3,3,1)).cuda()

In [7]:
# training loop
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
losses=[]
model=model.cuda()
for epoch in range(100000):
    optimizer.zero_grad()
    output = model(dataset['train_input'].cuda())
    loss = criterion(output, dataset['train_label'].cuda())
    losses.append(loss.item())
    loss2=loss+model.L1_loss()*loss.item()
    loss2.backward()
    optimizer.step()
    if epoch % 100 == 0:
        print(f"Epoch {epoch}, Loss {loss.item()}")
    if epoch % 1000 == 0 and epoch<6000:
        model.simplify()
        #model.anneal(1e-1/(epoch+1000)*1000)
    if epoch==5000:
        model.set_normalize(False)
with torch.no_grad():
    output = model(dataset['train_input'].cuda())

Epoch 0, Loss 29.86164093017578
Epoch 100, Loss 5.255110740661621
Epoch 200, Loss 1.3221807479858398
Epoch 300, Loss 1.2833278179168701
Epoch 400, Loss 1.2204376459121704
Epoch 500, Loss 1.1561205387115479
Epoch 600, Loss 1.0732989311218262
Epoch 700, Loss 0.9497087001800537
Epoch 800, Loss 0.7895612120628357
Epoch 900, Loss 0.7425876259803772
Epoch 1000, Loss 0.5439730286598206
Epoch 1100, Loss 0.46383798122406006
Epoch 1200, Loss 0.4088360369205475
Epoch 1300, Loss 0.3817916214466095
Epoch 1400, Loss 0.3677212595939636
Epoch 1500, Loss 0.3564581871032715
Epoch 1600, Loss 0.3458864986896515
Epoch 1700, Loss 0.3343924283981323
Epoch 1800, Loss 0.32049715518951416
Epoch 1900, Loss 0.3068848252296448
Epoch 2000, Loss 0.2975488007068634
Epoch 2100, Loss 0.2697073221206665
Epoch 2200, Loss 0.23681128025054932
Epoch 2300, Loss 0.1930505335330963
Epoch 2400, Loss 0.16178099811077118
Epoch 2500, Loss 0.14912039041519165
Epoch 2600, Loss 0.14620548486709595
Epoch 2700, Loss 0.14577960968017578

In [8]:
loss = criterion(output, dataset['train_label'].cuda())

In [9]:
with torch.no_grad():
    output = model(dataset['train_input'].cuda())

In [10]:
model.simplify()
model=model.cpu()
with torch.no_grad():
    A=model(dataset['train_input'])
model.simplify()

formula=model.symbolic_formula()
expanded=formula['z_0'].expand()
expanded=expanded.subs({"x_2":"x_0/x_1"})
expanded=expanded.subs({"x_3":"x_1/x_0"})

expanded=expanded.expand()
# remove small numbers from the formula

def clean_expr(expr, eps):
    replacer = lambda x: x if x>=eps else 0
    # expr.xreplace({n : round(n, num_digits) for n in expr.atoms(sp.Number)})
    return expr.xreplace({n : replacer(n) for n in expr.atoms(sp.Number)})
def round_expr(expr, num_digits):
    return expr.xreplace({n : round(n, num_digits) for n in expr.atoms(sp.Number)})
answer=clean_expr(expanded,0.001)
answer=round_expr(answer,2)
answer

0.01*x_0*x_1**2 + 11.81*x_1**3 + 8.41*x_1 + 0.25

In [11]:
str(answer)

'0.01*x_0*x_1**2 + 11.81*x_1**3 + 8.41*x_1 + 0.25'

In [16]:
# plot the 3d surface and compare it with data
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
# labels
ax.set_xlabel('Jintra')
ax.set_ylabel('deltaJ')
ax.set_zlabel('Frequencies')
x_0=data["Jintra"]
x_1=data["deltaJ"]
pred_fn=lambda x_0,x_1: 0.01*x_0*x_1**2 + 11.81*x_1**3 + 8.41*x_1 + 0.25
data['f2_predicted']=pred_fn(x_0,x_1)
ax.scatter(data['Jintra'], data['deltaJ'], data['f1_analytical'], c='r')
ax.scatter(data['Jintra'], data['deltaJ'], data['f2_predicted'], c='g')
ax.scatter(data['Jintra'], data['deltaJ'], A.cpu().numpy().flatten(), c='b')
#ax.scatter(data['Jintra'], data['deltaJ'], output.cpu().numpy().flatten(), c='black')

<mpl_toolkits.mplot3d.art3d.Path3DCollection at 0x1c7fc567860>