In [1]:
import torch
import torch.nn as nn

# from drdmannturb.calibration import CalibrationProblem
# from drdmannturb.data_generator import OnePointSpectraDataGenerator
from drdmannturb.parameters import (
    LossParameters,
    NNParameters,
    PhysicalParameters,
    ProblemParameters,
)
from drdmannturb.spectra_fitting import CalibrationProblem, OnePointSpectraDataGenerator

# Pick the compute device once; the rest of the notebook reuses `device`.
device = "cuda" if torch.cuda.is_available() else "cpu"

# When a GPU is present, make it the default device for newly created tensors
# (e.g. `torch.logspace` in the next cell).
# `torch.set_default_tensor_type` is deprecated since PyTorch 2.1, so prefer
# `torch.set_default_device` and only fall back to the legacy call on older
# PyTorch versions that do not provide it.
if device == "cuda":
    if hasattr(torch, "set_default_device"):
        torch.set_default_device(device)
    else:
        torch.set_default_tensor_type("torch.cuda.FloatTensor")

# Value passed as `L` to `PhysicalParameters` when building the calibration problem.
L = 0.59


In [2]:


# Physical constants handed to `PhysicalParameters` together with `L`.
Gamma = 3.9
sigma = 3.4

# 20 wavenumbers spaced logarithmically between 10**-1 and 10**2; serves both
# as the problem domain and as the k1 points for the generated data.
domain = torch.logspace(-1, 2, 20)

# Configuration objects for the calibration problem.
network_config = NNParameters(
    nlayers=2,
    hidden_layer_sizes=[10, 10],
    activations=[nn.GELU(), nn.GELU()],
)
problem_config = ProblemParameters(nepochs=5)
loss_config = LossParameters(alpha_pen2=1.0, alpha_pen1=1.0e-5, beta_reg=2e-4)
physics_config = PhysicalParameters(L=L, Gamma=Gamma, sigma=sigma, domain=domain)

pb = CalibrationProblem(
    nn_params=network_config,
    prob_params=problem_config,
    loss_params=loss_config,
    phys_params=physics_config,
    device=device,
)

# Each data point pairs one k1 value from the domain with the constant 1.
k1_data_pts = domain
DataPoints = [(wavenumber, 1) for wavenumber in k1_data_pts]

Data = OnePointSpectraDataGenerator(data_points=DataPoints).Data

In [3]:
# NOTE(review): the return value of `pb.eval` is discarded here. It is unclear
# from this notebook whether the call is needed for a side effect before
# calibration or is a leftover; confirm before removing it.
pb.eval(k1_data_pts)
# Run calibration against the generated `Data` (presumably fitting the network
# to the generated spectra) and keep the result as `optimal_parameters`.
# The loss log shown as this cell's output is produced by these calls.
optimal_parameters = pb.calibrate(data=Data)

mse loss: 0.12731095138740986
Initial loss: 0.12731095138740986
mse loss: 0.12731095138740986
0.12731095138740986
mse loss: 0.08091976340885287
0.08091976340885287
mse loss: 0.07251920502269152
0.07251920502269152
mse loss: 0.06574797974696969
0.06574797974696969
mse loss: 0.060267845810795286
0.060267845810795286
mse loss: 0.05583177680577787
0.05583177680577787
mse loss: 0.05224008069123877
0.05224008069123877
mse loss: 0.04933131257123198
0.04933131257123198
mse loss: 0.046974947333490454
0.046974947333490454
mse loss: 0.04506545472750759
0.04506545472750759
mse loss: 0.04351750741293436
0.04351750741293436
mse loss: 0.04226210666423032
0.04226210666423032
mse loss: 0.03692712735239805
0.03692712735239805
mse loss: 0.03692557245615881
0.03692557245615881
mse loss: 0.03691178713102499
0.03691178713102499
mse loss: 0.03679463441357401
0.03679463441357401
mse loss: 0.03673811736975942
0.03673811736975942
mse loss: 0.03637067952496032
0.03637067952496032
mse loss: 0.03629228007007742
0.