## 1. Set up the experiment

### 1-1. Import modules

In [1]:
import  os, time
from    pathlib             import  Path
from    tqdm.notebook       import  tqdm
import  pickle
import  yaml

import  numpy       as  np
import  torch
from    torch       import  nn
from    torch.utils.data            import  TensorDataset, DataLoader

from    custom_modules.utils                import  get_time_str
from    custom_modules.utils                import  GridGenerator, npzReader, GaussianNormalizer
from    custom_modules.pytorch.neuralop     import  DeepONetStructured  as  DeepONet
from    custom_modules.pytorch.torch_utils  import  count_parameters


# `time_str` (used to name the output directory) is set in the configuration cell below.

### 1-2. Load the configurations

In [2]:
# Timestamp tag; the save cell uses it as the name of this run's output directory
time_str = get_time_str()

with open("config_train.yaml") as config_file:
    config = yaml.load(config_file, Loader = yaml.FullLoader)

# Split the configuration into its three sections (experiment / dataset / model)
_exp, _data, _deeponet = (
    config[section] for section in ('experiment', 'pde_dataset', 'deeponet')
)

### 1-3. Set up the experiment parameters

In [3]:
# NOTE Training and data preprocess


BATCH_SIZE      = _exp['batch_size']
NUM_EPOCHS      = _exp['num_epochs']
LEARNING_RATE   = _exp['learning_rate']
TRAIN_SIZE      = _exp['train_size']
VAL_SIZE        = _exp['val_size']
DEVICE          = torch.device(f"cuda:{_exp['cuda_index']}")


RESOLUTION      = _data['resolution']
TRAIN_PATH      = Path(_data['path'])
__RANDOM_CHOICE = np.random.choice(1024, TRAIN_SIZE + VAL_SIZE, replace = False)
TRAIN_MASK      = __RANDOM_CHOICE[:TRAIN_SIZE]
VAL_MASK        = __RANDOM_CHOICE[-VAL_SIZE:]


DOWNSAMPLE      = _data['downsample']
GRID            = (RESOLUTION - 1) // DOWNSAMPLE + 1
grid            = torch.stack(
                        torch.meshgrid(
                            torch.linspace(0, 1, GRID),
                            torch.linspace(0, 1, GRID),
                            indexing = 'ij'
                        ),
                        dim = -1
                    )   # Shape: (GRID, GRID, dim_domain = 2)
grid            = grid.reshape(-1, 2).to(DEVICE)

## 2. Preprocess data

### 2-1. Initialize the data containers

In [4]:
# Per-field containers, filled by the preprocessing cell.
# 'coeff' holds the model input and 'sol' the regression target (see the dataloader cell).
_FIELD_KEYS = ('coeff', 'sol')

train_data: dict[str, torch.Tensor] = dict.fromkeys(_FIELD_KEYS)
val_data: dict[str, torch.Tensor] = dict.fromkeys(_FIELD_KEYS)

# One normalizer per field, fitted on the train split only
normalizer: dict[str, GaussianNormalizer] = dict.fromkeys(_FIELD_KEYS)

### 2-2. Load the train data

In [5]:
# Train data
reader = npzReader(TRAIN_PATH)
for cnt, k in tqdm(enumerate(train_data.keys()), desc = "Preprocessing the train data"):   
    # Step 1. Load data
    train_data[k] = torch.from_numpy(reader.get_field(k)[TRAIN_MASK, ::DOWNSAMPLE, ::DOWNSAMPLE])
    train_data[k] = train_data[k].type(torch.float)
    train_data[k] = train_data[k].reshape(TRAIN_SIZE, -1).to(DEVICE)
    
    # Step 2. Normalize data
    normalizer[k] = GaussianNormalizer(train_data[k])
    normalizer[k].to(DEVICE)
    train_data[k] = normalizer[k].encode(train_data[k])


# Validation data
for cnt, k in tqdm(enumerate(val_data.keys()), desc = "Preprocessing the validation data"):
    # Step 1. Load data
    val_data[k] = torch.from_numpy(reader.get_field(k)[VAL_MASK, ::DOWNSAMPLE, ::DOWNSAMPLE])
    val_data[k] = val_data[k].type(torch.float)
    val_data[k] = val_data[k].reshape(VAL_SIZE, -1).to(DEVICE)
    
    # Step 2. Normalize data (NOTE: Uses the normalizers for the train dataset)
    val_data[k] = normalizer[k].encode(val_data[k])

Preprocessing the train data: 0it [00:00, ?it/s]

Preprocessing the validation data: 0it [00:00, ?it/s]

### 2-3. Instantiate dataloaders

In [6]:
# Wrap the normalized, on-device tensors as (input = coeff, target = sol) pairs
train_dataset = TensorDataset(train_data['coeff'], train_data['sol'])
val_dataset   = TensorDataset(val_data['coeff'],   val_data['sol'])

# Shuffle only the training set. The relative error reported by the training cell is a
# norm over a whole batch, so it depends on which samples share a batch; a shuffled
# validation loader would make `val_error` fluctuate between epochs for a fixed model.
train_loader = DataLoader(train_dataset, batch_size = BATCH_SIZE, shuffle = True)
val_loader   = DataLoader(val_dataset,   batch_size = BATCH_SIZE, shuffle = False)

## 3. Train the model

### 3-1. Initialize the model, the loss function, and the optimizer

In [7]:
deeponet = DeepONet(**_deeponet).to(DEVICE)
print(f"The number of the parameters in the model\n>>> {count_parameters(deeponet)}")
print(deeponet)

# Re-initialize every parameter: 1-D tensors (presumably the biases) to zero,
# all higher-dimensional tensors (weight matrices) with Xavier-uniform
for p in deeponet.parameters():
    if p.ndim == 1:
        nn.init.zeros_(p)
    else:
        nn.init.xavier_uniform_(p)

criterion = nn.MSELoss(reduction = 'mean')
# Use the LEARNING_RATE constant so the configuration cell stays the single source of truth
optimizer = torch.optim.Adam(params = deeponet.parameters(), lr = LEARNING_RATE)

The number of the parameters in the model
>>> 903361
DeepONet(
    structured,
    branch=MLP(layer=(961, 512, 512, 256), bias=True, activation=relu),
    trunk =MLP(layer=(2, 64, 256), bias=True, activation=relu),
)


### 3-2. Train the model

In [8]:
train_history = {
    'train_loss':   [],
    'train_error':  [],
    'val_loss':     [],
    'val_error':    [],
    'train_time':   0.0,
}
normalizer['sol'].to(DEVICE)   # the save cell may have moved it to the CPU


def _run_epoch(loader: DataLoader, training: bool) -> tuple[float, float]:
    """Run one pass of `deeponet` over `loader` and return (mean loss, mean relative error).

    Args:
        loader:   yields (input, target) batches already normalized and on DEVICE.
        training: if True, the model is in train mode and updated with `optimizer`;
                  otherwise it runs in eval mode with gradients disabled.

    Returns:
        Sample-weighted averages over the loader's dataset: the `criterion` loss on
        normalized targets, and the relative L2 error on decoded (physical) values.
        NOTE(review): the relative error is a norm over the whole batch, not a per-sample
        mean, so it depends on how samples are grouped into batches.
    """
    deeponet.train(training)
    loss_sum  = torch.zeros((), device = DEVICE)
    error_sum = torch.zeros((), device = DEVICE)
    with torch.set_grad_enabled(training):
        for data, target in loader:
            num_data = len(data)

            pred = deeponet((data, grid))
            loss = criterion(pred, target)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Metrics are bookkeeping only: compute and accumulate them without autograd,
            # so no graph (or batch-sized saved tensors) is kept alive across batches.
            with torch.no_grad():
                pred_decoded   = normalizer['sol'].decode(pred.detach())
                target_decoded = normalizer['sol'].decode(target)
                rel_error = (
                    torch.linalg.norm(pred_decoded - target_decoded)
                    / (1e-8 + torch.linalg.norm(target_decoded))
                )
                loss_sum  += loss.detach() * num_data
                error_sum += rel_error * num_data

    num_samples = len(loader.dataset)
    # Single host sync per epoch (no `.item()` inside the batch loop)
    return (loss_sum / num_samples).item(), (error_sum / num_samples).item()


elapsed_time = time.time()
for epoch in tqdm(range(1, NUM_EPOCHS + 1)):
    # NOTE: Train (only this phase counts toward 'train_time')
    _train_time = time.time()
    train_epoch_loss, train_epoch_error = _run_epoch(train_loader, training = True)
    train_history['train_time'] += time.time() - _train_time
    train_history['train_loss'].append(train_epoch_loss)
    train_history['train_error'].append(train_epoch_error)

    # NOTE: Validation
    val_epoch_loss, val_epoch_error = _run_epoch(val_loader, training = False)
    train_history['val_loss'].append(val_epoch_loss)
    train_history['val_error'].append(val_epoch_error)

    if epoch % 10 == 0 or epoch == 1:
        print(f"[ Epoch {epoch} / {NUM_EPOCHS} ]")
        for k in train_history.keys():
            if k == "train_time":
                continue
            print(f"* {k:15s}: {train_history[k][-1]:.4e}")

elapsed_time = time.time() - elapsed_time
print(f"Elapsed time: {int(elapsed_time)} seconds")

  0%|          | 0/300 [00:00<?, ?it/s]

[ Epoch 1 / 300 ]
* train_loss     : 1.4981e+00
* train_error    : 6.4362e-01
* val_loss       : 6.5141e-01
* val_error      : 4.7844e-01
[ Epoch 10 / 300 ]
* train_loss     : 8.6739e-02
* train_error    : 1.7397e-01
* val_loss       : 9.1232e-02
* val_error      : 1.7883e-01
[ Epoch 20 / 300 ]
* train_loss     : 4.7496e-02
* train_error    : 1.2803e-01
* val_loss       : 3.4796e-02
* val_error      : 1.1052e-01
[ Epoch 30 / 300 ]
* train_loss     : 2.8584e-02
* train_error    : 1.0041e-01
* val_loss       : 4.3579e-02
* val_error      : 1.2382e-01
[ Epoch 40 / 300 ]
* train_loss     : 2.1931e-02
* train_error    : 8.8019e-02
* val_loss       : 3.3846e-02
* val_error      : 1.0896e-01
[ Epoch 50 / 300 ]
* train_loss     : 1.9585e-02
* train_error    : 8.3181e-02
* val_loss       : 2.3270e-02
* val_error      : 9.0250e-02
[ Epoch 60 / 300 ]
* train_loss     : 2.3275e-02
* train_error    : 9.0095e-02
* val_loss       : 4.4451e-02
* val_error      : 1.2448e-01
[ Epoch 70 / 300 ]
* train_l

### 3-3. Save the model and the train history

In [9]:
# Save the trained model, the normalizers and the training history under `./{time_str}/`.
# Tensors are written from the CPU so the files load on machines without a GPU. Afterwards
# everything is moved back to DEVICE, so the training cell can still be re-run.
save_dir  = Path(time_str)
save_dir.mkdir(parents = True, exist_ok = True)
file_stem = f"deeponet_darcy{RESOLUTION}_res{GRID}"

# Save the model
deeponet.cpu()
torch.save(deeponet.state_dict(), save_dir / f"{file_stem}.pth")
deeponet.to(DEVICE)

# Save the normalizers, which will also be used in prediction.
# Move *every* field's normalizer (not only 'sol') to the CPU before saving the dict;
# otherwise the 'coeff' normalizer would be pickled with device tensors.
for field_normalizer in normalizer.values():
    field_normalizer.cpu()
torch.save(normalizer, save_dir / f"{file_stem}_normalizer.pth")
for field_normalizer in normalizer.values():
    field_normalizer.to(DEVICE)

# Save the history
with open(save_dir / f"{file_stem}.pickle", "wb") as f:
    pickle.dump(train_history, f)

## End of file