## 1. Set up the experiment

### 1-1. Import modules

In [1]:
import  os, time
from    pathlib             import  Path
from    tqdm.notebook       import  tqdm
import  pickle

import  numpy                       as  np
import  torch
from    torch                       import  nn, optim
from    torch_geometric.data        import  Data
from    torch_geometric.loader      import  DataLoader

import  yaml

from    custom_modules.utils                import  GaussianNormalizer, npzReader
from    custom_modules.utils                import  GridGenerator
from    custom_modules.pytorch.neuralop     import  GNOLite
from    custom_modules.pytorch.ref_neuralop import  KernelNN
from    custom_modules.pytorch.torch_utils  import  count_parameters

### 1-2. Load the configurations

In [2]:
# Load the experiment configuration. `yaml.safe_load` only constructs plain YAML
# types (dicts, lists, scalars), which is all a configuration file needs.
with open("config_train.yaml") as f:
    config = yaml.safe_load(f)

# Sub-sections of the configuration consumed by the cells below.
_exp   = config['experiment']
_data  = config['pde_dataset']
_graph = config['graph']
_gno   = config['gno']

### 1-3. Set the experiment parameters

In [3]:
# NOTE Training and data preprocess


BATCH_SIZE      = _exp['batch_size']
NUM_EPOCHS      = _exp['num_epochs']
LEARNING_RATE   = _exp['learning_rate']
TRAIN_SIZE      = _exp['train_size']
VAL_SIZE        = _exp['val_size']
# Optional seed (`experiment.seed`). Without it, the data split and the weight
# initialization stay non-deterministic; with it, both become reproducible.
SEED            = _exp.get('seed')
if SEED is not None:
    torch.manual_seed(SEED)


RESOLUTION      = _data['resolution']
TRAIN_PATH      = Path(_data['path'])
# Number of samples available in the dataset file; defaults to 1024 when the
# configuration does not specify `pde_dataset.num_samples`.
NUM_SAMPLES     = _data.get('num_samples', 1024)
if TRAIN_SIZE + VAL_SIZE > NUM_SAMPLES:
    raise ValueError(
        f"train_size + val_size ({TRAIN_SIZE} + {VAL_SIZE}) exceeds the "
        f"number of available samples ({NUM_SAMPLES})"
    )
# Draw disjoint train/validation sample indices (without replacement).
__RANDOM_CHOICE = np.random.default_rng(SEED).choice(NUM_SAMPLES, TRAIN_SIZE + VAL_SIZE, replace = False)
TRAIN_MASK      = __RANDOM_CHOICE[:TRAIN_SIZE]
# Slice explicitly: `[-VAL_SIZE:]` would select *every* index when VAL_SIZE == 0.
VAL_MASK        = __RANDOM_CHOICE[TRAIN_SIZE:TRAIN_SIZE + VAL_SIZE]


DOWNSAMPLE      = _data['downsample']
# Number of grid points per axis after taking every DOWNSAMPLE-th point.
GRID            = (RESOLUTION - 1) // DOWNSAMPLE + 1
NUM_NODES       = GRID ** 2


RADIUS_TRAIN    = _graph['radius']
SAMPLE_SIZE     = _graph['sample_size']
NUM_SAMPLING    = _graph['num_sampling']

In [4]:
### NOTE Model instantiation


GNO_SOURCE  = ("custom", "reference")
GNO_INDEX   = _gno['gno_index']

# Both implementations are built so that their parameter counts can be compared,
# but they are created on the CPU: only the selected model is moved to the GPU
# below, so the unused one does not occupy device memory for the rest of the run.
GNO_LIST = (
    GNOLite(
                in_channels         = _gno['in_channels'],
                hidden_channels     = _gno['hidden_channels'],
                out_channels        = _gno['out_channels'],
                edge_channels       = _gno['edge_channels'],
                n_layers            = _gno['n_layers'],
                lift_layer          = _gno['lift_layer'],
                kernel_layer        = _gno['kernel_layer'],
                project_layer       = _gno['project_layer'],
                activation          = _gno['activation'],
            ),
    KernelNN(
                in_width        = _gno['in_channels'],
                ker_in          = _gno['edge_channels'],
                width           = _gno['hidden_channels'],
                ker_width       = _gno['kernel_layer'][0],  # `KernelNN` uses a 2-layer MLP
                depth           = _gno['n_layers'],
            ),
)

print(f"The number of the parameters in the custom GNO\n>>> {count_parameters(GNO_LIST[0])}")
print(f"The number of the parameters in the reference GNO\n>>> {count_parameters(GNO_LIST[1])}")

# `nn.Module.cuda()` moves the parameters in place and returns the same module.
gno: GNOLite | KernelNN = GNO_LIST[GNO_INDEX].cuda()
print(gno)

The number of the parameters in the custom GNO
>>> 332065
The number of the parameters in the reference GNO
>>> 332065
GraphNeuralOperatorLite(
    lift:       MLP(layer=(6, 32), bias=True, activation=relu),
    hidden:     GraphKernelLayer(node_dim=32, kernel_layer=MLP(layer=(6, 256, 256, 1024), bias=True, activation=relu))
                x 6)
                )
    projection: MLP(layer=(32, 1), bias=True, activation=relu),
)


## 2. Preprocess data

### 2-1. Create the data and normalizer containers

In [5]:
# Names of the dataset fields; the same keys index every dictionary below.
FIELD_KEYS = ('coeff', 'Kcoeff', 'Kcoeff_x', 'Kcoeff_y', 'sol')

# Per-split tensors, filled by the data-loading cell (None until then).
# `dict.fromkeys` builds a fresh dict per call, so the two splits never share state.
train_data: dict[str, torch.Tensor | None] = dict.fromkeys(FIELD_KEYS)
val_data:   dict[str, torch.Tensor | None] = dict.fromkeys(FIELD_KEYS)


# One normalizer per field; the data-loading cell fits them on the training split.
normalizer: dict[str, GaussianNormalizer | None] = dict.fromkeys(FIELD_KEYS)

### 2-2. Load and normalize the training and validation data

In [6]:
# Train and validation data
#
# Each field is read from the dataset once, split with TRAIN_MASK / VAL_MASK,
# downsampled by DOWNSAMPLE along both spatial axes, flattened to one row per
# sample, cast to float32, and normalized with statistics of the training split.


def _select_samples(raw: np.ndarray, mask: np.ndarray) -> torch.Tensor:
    """Return the samples `mask` of `raw`, spatially downsampled and flattened.

    The result has shape `(len(mask), -1)`: `flatten(start_dim=1)` merges every
    non-sample axis (row-major), so reshaping one row to `(NUM_NODES, -1)` in the
    graph-construction cell recovers the per-node layout.
    """
    selected = raw[mask, ::DOWNSAMPLE, ::DOWNSAMPLE]
    return torch.from_numpy(selected).flatten(start_dim = 1).type(torch.float)


reader = npzReader(TRAIN_PATH)
for k in tqdm(train_data.keys(), desc = "Preprocessing the data"):
    # Step 1. Load the field once and split it into the two subsets
    raw_field     = reader.get_field(k)
    train_data[k] = _select_samples(raw_field, TRAIN_MASK)
    val_data[k]   = _select_samples(raw_field, VAL_MASK)

    # Step 2. Normalize data. The validation split reuses the train normalizer so
    # that no statistics of the validation samples leak into preprocessing.
    # NOTE(review): assumes GaussianNormalizer's statistics do not depend on the
    # layout of the non-sample axes (e.g. a scalar mean/std) — confirm.
    normalizer[k] = GaussianNormalizer(train_data[k])
    train_data[k] = normalizer[k].encode(train_data[k])
    val_data[k]   = normalizer[k].encode(val_data[k])

Preprocessing the train data: 0it [00:00, ?it/s]

Preprocessing the validation data: 0it [00:00, ?it/s]

### 2-3. Construct graphs

In [7]:
# NOTE Generate a grid to set the node and edge attributes
#
# A GRID x GRID point set on the unit square [0, 1] x [0, 1]. `radius` is forwarded
# to the generator (presumably the neighbourhood radius used to build
# `edge_index` — confirm in `GridGenerator`).


grid_generator = GridGenerator([[0., 1.], [0., 1.]], [GRID, GRID], radius = RADIUS_TRAIN)
grid_full_info = grid_generator.full_information()
node_index, edge_index, grid = (
    grid_full_info[key] for key in ('node_index', 'edge_index', 'grid')
)

In [8]:
# NOTE Construct graphs
#
# Every sample shares the same grid and connectivity (`edge_index`); only the
# features derived from the coefficient fields differ between samples.


def _build_graphs(split: dict[str, torch.Tensor], num_samples: int, desc: str) -> list[Data]:
    """Build one `Data` graph per sample of a preprocessed split.

    * node features `x`:   grid coordinates, `coeff`, `Kcoeff`, `Kcoeff_x`, `Kcoeff_y`
    * node targets  `y`:   `sol`
    * edge features:       coordinates of both endpoints, then `coeff` at both endpoints

    Args:
        split:        dict of per-split tensors (e.g. `train_data` / `val_data`),
                      indexed by sample along the first axis.
        num_samples:  number of samples to convert.
        desc:         progress-bar label.
    """
    src, dst = edge_index[0], edge_index[1]
    # The coordinate part of the edge features is identical for every sample,
    # so it is computed once instead of inside the loop.
    grid_edge_attr = torch.hstack([grid[src], grid[dst]])

    graphs: list[Data] = []
    for idx in tqdm(range(num_samples), desc = desc):
        _coeff      = split[ 'coeff'  ][idx].reshape(NUM_NODES, -1)
        _Kcoeff     = split[ 'Kcoeff' ][idx].reshape(NUM_NODES, -1)
        _Kcoeff_x   = split['Kcoeff_x'][idx].reshape(NUM_NODES, -1)
        _Kcoeff_y   = split['Kcoeff_y'][idx].reshape(NUM_NODES, -1)
        graphs.append(
            Data(
                x           = torch.hstack([grid, _coeff, _Kcoeff, _Kcoeff_x, _Kcoeff_y]),
                y           = split['sol'][idx].reshape(NUM_NODES, -1),
                edge_index  = edge_index,
                edge_attr   = torch.hstack([grid_edge_attr, _coeff[src], _coeff[dst]]),
            )
        )
    return graphs


list_train_data = _build_graphs(train_data, TRAIN_SIZE, "Constructing the train graphs")
# NOTE: `list_test_data` holds the *validation* graphs.
list_test_data  = _build_graphs(val_data,   VAL_SIZE,   "Constructing the validation graphs")

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

### 2-4. Instantiate dataloaders

In [9]:
train_loader = DataLoader(list_train_data, batch_size = BATCH_SIZE, shuffle = True)
# Validation only evaluates the model, so shuffling is unnecessary; a fixed order
# also makes the validation pass identical across epochs.
test_loader  = DataLoader(list_test_data,  batch_size = BATCH_SIZE, shuffle = False)

## 3. Train the model

### 3-1. Initialize the model and instantiate the loss function and the optimizer

In [10]:
# Re-initialize every parameter. Tensors with fewer than 2 dims (biases, scalars)
# are zeroed; weights with >= 2 dims get Xavier-uniform init, which needs at
# least 2 dims to compute fan-in / fan-out.
# NOTE(review): this also zeroes any 1-D *scale* parameter (e.g. a normalization
# layer weight) if the model has one — confirm that is intended.
for p in gno.parameters():
    if p.ndim < 2:
        nn.init.zeros_(p)
    else:
        nn.init.xavier_uniform_(p)

criterion = nn.MSELoss(reduction = 'mean')
# Use the learning rate from the configuration (`experiment.learning_rate`).
optimizer = optim.Adam(params = gno.parameters(), lr = LEARNING_RATE)

### 3-2. Train the model

In [11]:
# NOTE Training loop
#
# Each epoch makes one optimization pass over `train_loader` and one evaluation
# pass over `test_loader` (the validation graphs). Recorded per epoch:
#   * `*_loss`:  MSE in the normalized target space, weighted by graphs per batch;
#   * `*_error`: relative L2 error ||pred - y|| / ||y|| in the decoded space,
#                computed over each mini-batch and weighted by graphs per batch.


def _relative_l2(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Relative L2 error of `pred` w.r.t. `target` over all entries (1e-8 avoids 0/0)."""
    return torch.linalg.norm(pred - target) / (1e-8 + torch.linalg.norm(target))


def _run_epoch(loader: DataLoader, train: bool) -> tuple[float, float]:
    """Run `gno` once over `loader` and return the per-graph mean (loss, relative error).

    Uses the notebook-level `gno`, `criterion`, `optimizer` and `normalizer['sol']`.
    With `train=True` the model is in training mode and the optimizer steps after
    every mini-batch; otherwise the pass runs in eval mode with autograd disabled.
    Returns `(nan, nan)` for an empty loader.
    """
    gno.train(train)
    sum_loss:  torch.Tensor | float = 0.0
    sum_error: torch.Tensor | float = 0.0
    num_graphs = 0
    with torch.set_grad_enabled(train):
        for batch in loader:
            batch = batch.cuda()

            pred = gno(batch.x, batch.edge_index, batch.edge_attr)
            loss = criterion(pred, batch.y)
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Metrics are accumulated outside autograd: summing tensors that require
            # grad would keep every batch's graph alive until the end of the epoch.
            # Detached tensors (not `.item()`) avoid a GPU sync on every batch.
            with torch.no_grad():
                n = batch.num_graphs
                error = _relative_l2(
                    normalizer['sol'].decode(pred),
                    normalizer['sol'].decode(batch.y),
                )
                sum_loss  = sum_loss  + loss.detach() * n
                sum_error = sum_error + error * n
            num_graphs += n

    if num_graphs == 0:
        return float('nan'), float('nan')
    return float(sum_loss) / num_graphs, float(sum_error) / num_graphs


train_history = {
    'train_loss':   [],
    'train_error':  [],
    'val_loss':     [],
    'val_error':    [],
    'train_time':   0.0,
}
# The decoder statistics must live on the same device as the predictions.
normalizer['sol'].cuda()

elapsed_time = time.perf_counter()
for epoch in tqdm(range(1, NUM_EPOCHS + 1)):
    # NOTE: Train (only this phase counts towards `train_time`)
    _train_time = time.perf_counter()
    train_epoch_loss, train_epoch_error = _run_epoch(train_loader, train = True)
    train_history['train_time'] += time.perf_counter() - _train_time
    train_history['train_loss'].append(train_epoch_loss)
    train_history['train_error'].append(train_epoch_error)

    # NOTE: Validation
    val_epoch_loss, val_epoch_error = _run_epoch(test_loader, train = False)
    train_history['val_loss'].append(val_epoch_loss)
    train_history['val_error'].append(val_epoch_error)

    if epoch % 10 == 0 or epoch == 1:
        print(f"[ Epoch {epoch} / {NUM_EPOCHS} ]")
        for k in train_history.keys():
            if k == "train_time":
                continue
            print(f"* {k:15s}: {train_history[k][-1]:.4e}")

elapsed_time = time.perf_counter() - elapsed_time
print(f"Elapsed time: {int(elapsed_time)} seconds")

  0%|          | 0/200 [00:00<?, ?it/s]

[ Epoch 1 / 200 ]
* train_loss     : 8.0489e-01
* train_error    : 5.6419e-01
* val_loss       : 5.2429e-01
* val_error      : 4.5898e-01
[ Epoch 10 / 200 ]
* train_loss     : 3.8638e-02
* train_error    : 1.2276e-01
* val_loss       : 2.6116e-02
* val_error      : 1.0252e-01
[ Epoch 20 / 200 ]
* train_loss     : 3.6145e-02
* train_error    : 1.1825e-01
* val_loss       : 3.0003e-02
* val_error      : 1.0942e-01
[ Epoch 30 / 200 ]
* train_loss     : 1.6634e-02
* train_error    : 8.0478e-02
* val_loss       : 1.2970e-02
* val_error      : 7.1363e-02
[ Epoch 40 / 200 ]
* train_loss     : 2.8071e-02
* train_error    : 1.0349e-01
* val_loss       : 2.4498e-02
* val_error      : 9.9493e-02
[ Epoch 50 / 200 ]
* train_loss     : 1.7405e-02
* train_error    : 8.1709e-02
* val_loss       : 2.0814e-02
* val_error      : 9.1615e-02
[ Epoch 60 / 200 ]
* train_loss     : 1.2323e-02
* train_error    : 6.8496e-02
* val_loss       : 2.3154e-02
* val_error      : 9.6266e-02
[ Epoch 70 / 200 ]
* train_l

### 3-3. Save the model, the normalizers, and the training history

In [12]:
# Output directory named after the selected implementation ("custom" / "reference").
gno_src = GNO_SOURCE[GNO_INDEX]
os.makedirs(gno_src, exist_ok = True)

# Model weights, moved back to the CPU first so the saved tensors are CPU tensors.
gno.cpu()
torch.save(gno.state_dict(), f"{gno_src}/gno_darcy{RESOLUTION}_res{GRID}.pth")

# All normalizers, reused at prediction time. Only `sol` was moved to the GPU for
# training, so only it needs to come back to the CPU.
# NOTE(review): this pickles the GaussianNormalizer objects; loading them with
# `torch.load` presumably needs `weights_only=False` on recent PyTorch — confirm.
normalizer['sol'].cpu()
torch.save(normalizer, f"{gno_src}/gno_darcy{RESOLUTION}_res{GRID}_normalizer.pth")

# Per-epoch training history.
with open(f"{gno_src}/gno_darcy{RESOLUTION}_res{GRID}.pickle", "wb") as history_file:
    pickle.dump(train_history, history_file)

## End of file