In [1]:
import torch
from torch import nn
from torch.func import functional_call, grad
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from torch.utils.data import random_split
from torchvision.datasets import *
from torchvision import transforms
from torchdiffeq import odeint
from scipy.integrate import odeint as odeint_scipy
import pytorch_lightning as pl
import os
import shutil
import numpy as np
import pandas as pd
from threading import Thread
import matplotlib.pyplot as plt

from gc_module import ContNet

# Report whether PyTorch can see a CUDA device.
print(torch.cuda.is_available())
# Allocate a throwaway tensor on the GPU (the result is discarded); this raises
# on a machine without a usable CUDA device. Presumably here to fail fast and
# initialise the CUDA context before training -- TODO confirm intent.
torch.zeros(1).cuda()
# Let PyTorch use faster, reduced-precision float32 matmul kernels
# (see torch.set_float32_matmul_precision for what 'high' permits).
torch.set_float32_matmul_precision('high')

# When True, the dataset-generation cell deletes and rebuilds burgers_data/
# and burgers_params.csv.
gen_dataset = True

True


In [2]:
def dburgers(t, u, x, dx, mu2, nu):
    """Right-hand side of the semi-discretised, forced viscous Burgers equation.

    Evaluates

        du/dt = 0.02*exp(mu2*x) - d(u^2/2)/dx + nu * d^2u/dx^2

    on a uniform grid, using central differences at interior nodes and
    one-sided differences at the two end nodes.

    Parameters
    ----------
    t : float
        Current time. Unused (the system is autonomous); kept so the function
        matches the ``f(t, u, ...)`` ODE-solver calling convention.
    u : array_like, shape (nx,)
        Solution values at the grid nodes; needs at least 3 entries.
    x : ndarray, shape (nx,)
        Grid coordinates, used only by the source term.
    dx : float
        Uniform grid spacing.
    mu2 : float
        Exponent coefficient of the source term.
    nu : float
        Viscosity multiplying the second-derivative term.

    Returns
    -------
    ndarray, shape (nx,)
        Time derivative of ``u``. Entry 0 is forced to zero so the value at the
        first node stays at its initial value during integration.
    """
    u = np.asarray(u, dtype=float)

    # Convective flux w = u^2 / 2 and its first derivative.
    w = 0.5*np.square(u)
    dwdx = np.concatenate(([(w[1] - w[0])/dx], (w[2:] - w[:-2])/(2*dx), [(w[-1] - w[-2])/dx]))

    # Diffusion acts on u itself. The previous version differenced the flux w
    # here, which made the "viscous" term nu * d^2(u^2/2)/dx^2 instead.
    d2udx2 = np.concatenate(([(u[0] - 2*u[1] + u[2])/(dx**2)],
                             (u[2:] - 2*u[1:-1] + u[:-2])/(dx**2),
                             [(u[-1] - 2*u[-2] + u[-3])/(dx**2)]))

    dudt = 0.02*np.exp(mu2*x) - dwdx + nu*d2udx2
    # Hold the first node fixed (its value is set by the initial condition).
    dudt[0] = 0
    return dudt

def run_burgers(mu1, mu2, nu=0.5, nx=128, nt=5, T=5.0):
    """Integrate the semi-discrete Burgers system for one parameter pair.

    The spatial domain is [0, 30] with ``nx`` uniformly spaced nodes. The
    initial state is 1 everywhere except the first node, which is set to
    ``mu1``. Time integration is delegated to SciPy's ``odeint`` with
    ``dburgers`` as the right-hand side.

    Parameters
    ----------
    mu1 : float
        Value assigned to the first grid node of the initial state.
    mu2 : float
        Source-term exponent coefficient forwarded to ``dburgers``.
    nu : float, optional
        Viscosity forwarded to ``dburgers``.
    nx : int, optional
        Number of grid nodes (default 128). Must be at least 3 because the
        one-sided second-difference stencil in ``dburgers`` reads three nodes.
    nt : int, optional
        Number of output times, uniformly spaced on [0, T] (default 5).
    T : float, optional
        Final time (default 5.0).

    Returns
    -------
    ndarray, shape (nt, nx)
        Solution at each requested time; row 0 is the initial state.

    Raises
    ------
    ValueError
        If ``nx`` is smaller than 3.
    """
    if nx < 3:
        raise ValueError(f"nx must be at least 3, got {nx}")

    x = np.linspace(0, 30, nx)
    dx = x[1] - x[0]
    t = np.linspace(0, T, nt)

    u0 = np.ones(nx)
    u0[0] = mu1

    # dburgers takes time first; tfirst=True tells odeint to call it as
    # dburgers(t, u, *args) instead of wrapping it in an argument-swapping lambda.
    ut = odeint_scipy(dburgers, u0, t, args=(x, dx, mu2, nu), tfirst=True)

    return ut

# Rebuild the on-disk dataset from scratch: one CSV of solver output per
# (mu1, mu2) pair, plus an index CSV listing every pair.
if gen_dataset:
    # Remove whatever a previous run left behind so stale files cannot mix in.
    if os.path.exists('burgers_data'):
        shutil.rmtree('burgers_data')
    if os.path.exists('burgers_params.csv'):
        os.remove('burgers_params.csv')
    os.mkdir('burgers_data')

    # Full 10 x 8 grid of parameter pairs, mu1-major.
    mu1 = np.linspace(4.25, 5.5, 10)
    mu2 = np.linspace(0.015, 0.03, 8)
    param_list = [[mu1i, mu2i] for mu1i in mu1 for mu2i in mu2]

    for mu1i, mu2i in param_list:
        ut = run_burgers(mu1i, mu2i)
        # File name encodes both parameters rounded to three decimals.
        pd.DataFrame(ut).to_csv(f'burgers_data/{mu1i:0.3f}_{mu2i:0.3f}.csv', header=False, index=False)

    # Index file keeps the parameters at full precision.
    pd.DataFrame(param_list, columns=['mu1', 'mu2']).to_csv('burgers_params.csv', index=False)

In [3]:
class BurgersDataset(Dataset):
    """Map-style dataset pairing each parameter row with its stored states.

    ``param_file`` is a CSV with a header row; each data row is one parameter
    vector. The matching state array lives in ``states_dir`` in a header-less
    CSV named after the row's values, each formatted with three decimals and
    joined by underscores.
    """

    def __init__(self, param_file, states_dir, transform=None):
        """Read the parameter table and remember where the state files live.

        ``transform`` is stored on the instance but is not applied by
        ``__getitem__``.
        """
        self.states_dir = states_dir
        self.transform = transform
        self.param_list = pd.read_csv(param_file)

    def __len__(self):
        """Number of parameter rows, i.e. number of samples."""
        return self.param_list.shape[0]

    def __getitem__(self, idx):
        """Return ``(mu, ut)`` for row ``idx`` as float32 tensors."""
        mu = self.param_list.iloc[idx, :].to_numpy()

        # Rebuild the file name the generator used for this parameter row.
        file_name = '_'.join(f'{p:0.3f}' for p in mu) + '.csv'
        ut = np.loadtxt(os.path.join(self.states_dir, file_name), delimiter=",", dtype=float)

        mu_tensor = torch.from_numpy(mu).float()
        ut_tensor = torch.from_numpy(ut).float()
        return mu_tensor, ut_tensor

In [4]:

class BurgersNODE(ContNet):
    """Neural ODE fitted to the Burgers dataset on top of the ``ContNet`` base.

    A deep tanh MLP maps the concatenation of the current state (``nx`` values)
    and the two parameters to the state's time derivative; ``forward``
    integrates that field over ``t_span`` with ``odeint``. Training uses manual
    optimisation and calls the ``perturb_params`` / ``contvar_grad`` hooks that
    the base class is expected to provide.
    """

    def __init__(self, loglambda0: float, cont_lr: float, cont_reg: float, warmup_epochs: int):
        """Build the dataset split, the time grid and the derivative network.

        All four arguments are forwarded unchanged to ``ContNet.__init__``.
        """
        super().__init__(loglambda0, cont_lr, cont_reg, warmup_epochs)
        self.lossfunc = F.mse_loss

        # 90 % / 10 % random train / validation split.
        self.dataset = BurgersDataset("burgers_params.csv", "burgers_data", transform=transforms.ToTensor())
        n_total = len(self.dataset)
        n_train = int(n_total*0.9)
        self.data_train, self.data_val = random_split(self.dataset, [n_train, n_total - n_train])

        # Time grid and state size used by forward().
        self.T = 5.0
        self.nt = 5
        self.t_span = torch.linspace(0, self.T, self.nt, device=self.device)
        self.nx = 128

        # MLP: (nx + 2) -> 256, then 28 hidden 256 -> 256 layers, then 256 -> nx.
        # One Tanh instance is shared by every activation slot.
        hidden_width = 256
        n_hidden_layers = 28
        activ = nn.Tanh()
        layers = [nn.Linear(self.nx + 2, hidden_width), activ]
        for _ in range(n_hidden_layers):
            layers += [nn.Linear(hidden_width, hidden_width), activ]
        layers.append(nn.Linear(hidden_width, self.nx))
        self.net = nn.Sequential(*layers)

        # Xavier-uniform weights and zero biases for every affine layer.
        for layer in self.net.modules():
            if isinstance(layer, (nn.Linear, nn.Conv2d)):
                nn.init.xavier_uniform_(layer.weight)
                nn.init.zeros_(layer.bias)

    def configure_optimizers(self):
        """Adam with two groups: the network (lr 3e-4) and ``logcontvar`` (lr ``cont_lr``)."""
        param_groups = [
            {'params': self.net.parameters()},
            {'params': (self.logcontvar,), 'lr': self.cont_lr},
        ]
        return torch.optim.Adam(param_groups, lr=3e-4)

    def f(self, t, x, mu):
        """Learned time derivative; ``t`` is accepted but not used."""
        net_input = torch.cat((x, mu), dim=1)
        return self.net(net_input)

    def forward(self, mu):
        """Integrate from the initial state implied by ``mu`` over ``t_span``.

        The initial state is all ones with column 0 replaced by ``mu[:, 0]``.
        The solver output has its first two axes swapped before it is returned.
        """
        x0 = torch.ones((len(mu), self.nx), device=self.device)
        x0[:, 0] = mu[:, 0]

        def odefunc(t_, x_):
            return self.f(t_, x_, mu)

        times = self.t_span.to(self.device)
        return odeint(odefunc, x0, times).transpose(0, 1)

    def training_step(self, batch, batch_idx):
        """One manually optimised step; the statement order below matters."""
        squared_norm = sum(p.pow(2.0).sum() for p in self.parameters())
        self.log('param_norm', squared_norm)

        opt = self.optimizers()
        opt.zero_grad()

        # Base-class hook; returns the sample and the reference parameters
        # that are restored further down.
        rand_samp, ref_params = self.perturb_params()

        # Loss and gradients.
        mu, target = batch
        loss = self.lossfunc(self.forward(mu), target)
        self.manual_backward(loss)

        # Base-class hook for the contvar gradient.
        self.contvar_grad(rand_samp, loss)

        # Put the reference parameters back before the optimiser updates them.
        self.load_state_dict(ref_params)

        opt.step()
        self.log('train_loss', loss)

    def validation_step(self, batch, batch_idx):
        """Log the MSE between the predicted and stored trajectories."""
        mu, target = batch
        loss = self.lossfunc(self.forward(mu), target)
        self.log('val_loss', loss)

    def train_dataloader(self):
        """Loader over the training split (batch size 4, 4 persistent workers)."""
        return DataLoader(self.data_train, batch_size=4, num_workers=4, persistent_workers=True)

    def val_dataloader(self):
        """Loader over the validation split (batch size 4, 4 persistent workers)."""
        return DataLoader(self.data_val, batch_size=4, num_workers=4, persistent_workers=True)

In [5]:
def run_case(hyperparams: list) -> None:
    """Train one ``BurgersNODE`` for 100 epochs with the given hyperparameters.

    ``hyperparams`` is unpacked positionally into the model constructor and is
    also embedded in the TensorBoard run name.
    """
    mymodel = BurgersNODE(*hyperparams)

    epochs = 100
    run_name = f'l30_dcont_burgers_{epochs}_' + '_'.join([f'{p}' for p in hyperparams])
    logger = pl.loggers.tensorboard.TensorBoardLogger('.', name=run_name)

    trainer_kwargs = dict(max_epochs=epochs, accelerator='gpu', logger=logger, log_every_n_steps=3)
    pl.Trainer(**trainer_kwargs).fit(mymodel)

def sweep_hyperparams(hyperparam_list: list, n_runs: int, hyperparams: list = None) -> None:
    """Run every combination of the candidate hyperparameter values.

    Recursively walks the Cartesian product of ``hyperparam_list`` (first
    entry varies slowest) and calls ``run_case`` ``n_runs`` times for each
    complete combination.

    Parameters
    ----------
    hyperparam_list : list of list
        One list of candidate values per hyperparameter, in the positional
        order ``run_case`` expects.
    n_runs : int
        Number of repeated runs per combination.
    hyperparams : list, optional
        Values already chosen for the leading hyperparameters. Used by the
        recursion; external callers normally omit it. ``None`` means "nothing
        chosen yet" (a ``None`` sentinel replaces the former mutable ``[]``
        default, which is shared between calls).
    """
    if hyperparams is None:
        hyperparams = []

    # Base case: one value has been chosen for every hyperparameter.
    if len(hyperparams) == len(hyperparam_list):
        for _ in range(n_runs):
            print('-'*80)
            print(f'Running case with hyperparams {hyperparams}')
            print('-'*80)

            run_case(hyperparams)
        return

    # Recursive case: branch on each candidate for the next hyperparameter.
    # A new list is built per branch so sibling branches never share state.
    for hyperparam_i in hyperparam_list[len(hyperparams)]:
        sweep_hyperparams(hyperparam_list, n_runs, hyperparams + [hyperparam_i])
    

In [6]:
def main() -> None:
    """Launch the hyperparameter sweep with the notebook's fixed grid.

    Each inner list holds the candidate values for one ``BurgersNODE``
    constructor argument, in positional order; every combination is run once.
    """
    hyperparam_list = [
        [float('nan'), -12.0],  # loglambda0
        [1e-2],                 # cont_lr
        [1e-3],                 # cont_reg
        [9],                    # warmup_epochs
    ]
    sweep_hyperparams(hyperparam_list, 1)

# Start the sweep only when this code runs as the top-level program (which
# includes execution as a notebook cell), not when it is imported.
if __name__ == "__main__":
    main()

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


--------------------------------------------------------------------------------
Running case with hyperparams [nan, 0.01, 0.001, 9]
--------------------------------------------------------------------------------


Missing logger folder: ./l30_dcont_burgers_100_nan_0.01_0.001_9
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type       | Params
------------------------------------
0 | net  | Sequential | 1.9 M 
------------------------------------
1.9 M     Trainable params
0         Non-trainable params
1.9 M     Total params
7.634     Total estimated model params size (MB)


Epoch 27:  67%|██████▋   | 12/18 [04:27<02:13, 22.29s/it, v_num=0]         

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


--------------------------------------------------------------------------------
Running case with hyperparams [-12.0, 0.01, 0.001, 9]
--------------------------------------------------------------------------------



  | Name | Type       | Params
------------------------------------
0 | net  | Sequential | 1.9 M 
------------------------------------
1.9 M     Trainable params
0         Non-trainable params
1.9 M     Total params
7.634     Total estimated model params size (MB)


Epoch 0:   0%|          | 0/18 [00:00<?, ?it/s]                            