Imports

In [1]:
import f90nml
import numpy as np
from pint import UnitRegistry; AssignQuantity = UnitRegistry().Quantity
import os
import reference_solution as refsol
from scipy.fft import rfft
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import icepinn as ip

torch.set_default_dtype(torch.float64)
print(torch.cuda.device_count())
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

device = ip.get_device()

cuda
1
cuda


In [2]:
# Read in GI parameters
inputfile = "GI parameters - Reference limit cycle (for testing).nml"
GI=f90nml.read(inputfile)['GI']
nx_crystal = GI['nx_crystal']
L = GI['L']
NBAR = GI['Nbar']
NSTAR = GI['Nstar']

# Define t range (needs to be same as training file)
RUNTIME = 5
#NUM_T_STEPS = RUNTIME + 1
NUM_T_STEPS = RUNTIME*5 + 1

# Define initial conditions
Ntot_init = np.ones(nx_crystal)
Nqll_init = ip.get_Nqll(Ntot_init)

# Define x, t pairs for training
X_QLC = np.linspace(-L,L,nx_crystal)
t_points = np.linspace(0, RUNTIME, NUM_T_STEPS)
x, t = np.meshgrid(X_QLC, t_points)
training_set = torch.tensor(np.column_stack((x.flatten(), t.flatten()))).to(device)

In [4]:
# Define model attributes; instantiate model
model_dimensions = torch.tensor([8, 80]).to(device)
is_sf_PINN = torch.tensor(False)
model1 = ip.IcePINN(
	num_hidden_layers=model_dimensions[0], 
	hidden_layer_size=model_dimensions[0],
	is_sf_PINN=is_sf_PINN.item()).to(device)

# Attach model attributes as buffers so they can be saved and loaded
model1.register_buffer('dimensions', model_dimensions)
model1.register_buffer('is_sf_PINN', is_sf_PINN)

# Initialize model weights with HE initialization

model1.apply(ip.init_HE)

optimizer = torch.optim.AdamW(model1.parameters(), lr=0.001)

# Define learning rate scheduling scheme
scheduler_summed = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=10000
    )

In [16]:
print(training_set.shape)
print(ip.calc_cp_loss(model, training_set, ip.get_misc_params(), 10, 1, 1, False).shape)
print(ip.enforced_model(training_set, model).shape)

MODEL_NAME = "TestPINN_hard_enforced_SF"

torch.Size([8320, 2])
torch.Size([8320, 2])
torch.Size([8320, 2])


Train model

In [5]:
ip.train_IcePINN(
    model=model1, 
    optimizer=optimizer, 
    training_set=training_set, 
    epochs=200_000, 
    name=MODEL_NAME, 
    print_every=1_000, 
    print_gradients=False,
    LR_scheduler=None)

Commencing PINN training on 8320 points for 200000 epochs.
Epoch [1000/200000]: Ntot = 4605.302, Nqll = 843.847, LR = 0.001
Epoch [2000/200000]: Ntot = 4492.581, Nqll = 817.493, LR = 0.001
Epoch [3000/200000]: Ntot = 4448.287, Nqll = 838.971, LR = 0.001
Epoch [4000/200000]: Ntot = 4408.222, Nqll = 848.948, LR = 0.001
Epoch [5000/200000]: Ntot = 4407.769, Nqll = 882.040, LR = 0.001
Epoch [6000/200000]: Ntot = 5110.055, Nqll = 1095.239, LR = 0.001
Epoch [7000/200000]: Ntot = 5140.427, Nqll = 948.118, LR = 0.001
Epoch [8000/200000]: Ntot = 9812.062, Nqll = 1900.371, LR = 0.001
Epoch [9000/200000]: Ntot = 15331.307, Nqll = 3249.935, LR = 0.001
Epoch [10000/200000]: Ntot = 23906.017, Nqll = 6978.736, LR = 0.001
Epoch [11000/200000]: Ntot = 12439.407, Nqll = 2667.248, LR = 0.001
Epoch [12000/200000]: Ntot = 15808.532, Nqll = 3486.519, LR = 0.001
Epoch [13000/200000]: Ntot = 62.375, Nqll = 13.022, LR = 0.001
Epoch [14000/200000]: Ntot = 22.083, Nqll = 1.749, LR = 0.001
Epoch [15000/200000]: N

Train again on smaller LR starting from saved checkpoint (best saved model above)

In [18]:
# Instantiate model2 as best version of model1
model2 = ip.load_IcePINN(MODEL_NAME)
model2.train()
optimizer2 = torch.optim.AdamW(model2.parameters(), lr=0.0001)

mn2 = MODEL_NAME+"_round2"
ip.train_IcePINN(
    model=model2, 
    optimizer=optimizer2, 
    training_set=training_set, 
    epochs=100_000, 
    name=MODEL_NAME+"_round2", 
    print_every=1_000, 
    print_gradients=False,
    LR_scheduler=None)

Commencing PINN training on 8320 points for 100000 epochs.
Epoch [1000/100000]: Ntot = 23.008, Nqll = 9.562, LR = 0.0001
Epoch [2000/100000]: Ntot = 58.882, Nqll = 13.254, LR = 0.0001


KeyboardInterrupt: 