In [None]:
import torch
import os
import scipy.special as sp
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

from utils import set_seed
from data_input import get_dataset
from grid_search import BaseSearcher

In [None]:
# Global reproducibility / hardware configuration for the notebook.
seed = 42

# PyTorch's reproducibility notes require CUBLAS_WORKSPACE_CONFIG to be set
# for deterministic cuBLAS kernels. Export it before seeding so it is already
# in place by the time any helper touches CUDA.
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

# set_seed is a project helper (utils.py); presumably it seeds every RNG in
# use -- confirm against its definition.
set_seed(seed)

# Preferred accelerator is GPU index 2. Fall back instead of failing later
# (at the first tensor/model transfer) on hosts with fewer visible devices.
_preferred_gpu = 2
if torch.cuda.is_available() and torch.cuda.device_count() > _preferred_gpu:
    device = torch.device(f'cuda:{_preferred_gpu}')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
    print(f'cuda:{_preferred_gpu} is not available; using {device} instead')
else:
    device = torch.device('cpu')
    print(f'CUDA is not available; using {device} instead')

# Target functions for the benchmark. Each one takes a 2-D array `x` and
# evaluates a SciPy special function on its first two columns. Indexing with
# `[:, [k]]` (a list index) keeps the column axis, so for an (n, d) NumPy
# input every function returns an (n, 1) result.


def _sph_harm(m, n, azimuth, polar):
    """Evaluate the spherical harmonic Y_n^m with the legacy argument order.

    `scipy.special.sph_harm(m, n, theta, phi)` (theta = azimuthal angle,
    phi = polar angle) is deprecated since SciPy 1.15 and scheduled for
    removal in 1.17. Its replacement `sph_harm_y(n, m, theta, phi)` swaps
    both the (m, n) order and the meaning of the two angles.

    The legacy function is preferred whenever it still exists so results on
    current installations are unchanged; the replacement is used only when
    the legacy one is gone.

    NOTE(review): the two implementations are documented to agree for polar
    angles in [0, pi]; behaviour outside that range has not been verified
    here -- confirm against the sampling range used by get_dataset.
    """
    if hasattr(sp, 'sph_harm'):
        return sp.sph_harm(m, n, azimuth, polar)
    return sp.sph_harm_y(n, m, polar, azimuth)


# Elliptic functions / integrals.
f0 = lambda x: sp.ellipj(x[:, [0]], x[:, [1]])[0]   # first output of ellipj: sn(u | m)
f1 = lambda x: sp.ellipkinc(x[:, [0]], x[:, [1]])   # incomplete elliptic integral, 1st kind
f2 = lambda x: sp.ellipeinc(x[:, [0]], x[:, [1]])   # incomplete elliptic integral, 2nd kind

# Bessel functions of real order: column 0 is the order, column 1 the argument.
f3 = lambda x: sp.jv(x[:, [0]], x[:, [1]])          # first kind
f4 = lambda x: sp.yv(x[:, [0]], x[:, [1]])          # second kind
f5 = lambda x: sp.kv(x[:, [0]], x[:, [1]])          # modified, second kind
f6 = lambda x: sp.iv(x[:, [0]], x[:, [1]])          # modified, first kind

# Associated Legendre functions lpmv(m, v, x) for fixed order m = 0, 1, 2.
f7 = lambda x: sp.lpmv(0, x[:, [0]], x[:, [1]])
f8 = lambda x: sp.lpmv(1, x[:, [0]], x[:, [1]])
f9 = lambda x: sp.lpmv(2, x[:, [0]], x[:, [1]])

# Spherical harmonics for fixed (m, n); column 0 is the azimuthal angle and
# column 1 the polar angle. These return complex-valued arrays.
# NOTE(review): confirm that get_dataset handles complex targets as intended.
f10 = lambda x: _sph_harm(0, 1, x[:, [0]], x[:, [1]])
f11 = lambda x: _sph_harm(1, 1, x[:, [0]], x[:, [1]])
f12 = lambda x: _sph_harm(0, 2, x[:, [0]], x[:, [1]])
f13 = lambda x: _sph_harm(1, 2, x[:, [0]], x[:, [1]])
f14 = lambda x: _sph_harm(2, 2, x[:, [0]], x[:, [1]])

# Explicit list instead of a globals() lookup: same order, same objects, and
# it keeps working if this cell is ever moved inside a function.
func_set = [f0, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12, f13, f14]

In [None]:
# Search space for the grid search.
#
# Every candidate architecture has 2 inputs and 1 output, with 1 to 4 hidden
# layers that all share one width. Entries are ordered depth-major: all widths
# for one hidden layer first, then all widths for two hidden layers, and so on.
_hidden_widths = (4, 8, 16, 32)
_hidden_depths = (1, 2, 3, 4)
size_list = [
    [2, *([width] * depth), 1]
    for depth in _hidden_depths
    for width in _hidden_widths
]

# Candidate learning rates, largest first.
lr_list = [0.3, 0.2, 0.1, 0.08, 0.05, 0.03, 0.01, 0.008, 0.005, 0.001]

In [None]:
# Run one complete grid search (every entry of size_list x lr_list) per target
# function, writing each function's results under its own directory.
for i_f, f in enumerate(func_set):
    # Train / validation / test splits generated from the current function.
    trainset, valset, testset = get_dataset('function', f=f, n_var=2, seed=seed)

    # Only the training loader shuffles; the evaluation loaders keep a fixed
    # order and use a smaller batch size.
    train_loader = DataLoader(trainset, batch_size=4096, shuffle=True)
    val_loader, test_loader = (
        DataLoader(split, batch_size=1024, shuffle=False)
        for split in (valset, testset)
    )

    # A fresh searcher per function, keyed by the function's index.
    gs = BaseSearcher(
        device=device,
        save_dir=f'save/special_func/func_{i_f}',
    )
    gs.init_logs()

    # The epoch_list literal is rebuilt on every iteration, so each call
    # receives its own list object.
    gs.grid_search(
        size_list,
        lr_list,
        train_loader,
        val_loader,
        test_loader,
        repu_order=3,
        optim='lbfgs',
        max_iter=500,
        epoch_list=[100, 100, 500],
        scheduler='lam',
    )