In this notebook we show how to perform the forward pass through the Neural ODE using different regimes to propagate through the Meta ODE block, namely
- Standalone
- Solver sampling (switching/smoothing)
- Solver ensembling
- Model ensembling

In more detail, each regime works as follows:
- **Standalone**
    - One solver is used during inference.
    - Applied during training/testing.
     
    
    
- **Solver switching / smoothing**
    - For each batch one solver is chosen from a group of solvers with finite (in switching regime) or infinite (in smoothing regime) number of members.
    - Applied during training.
    
    
- **Solver ensembling**
    - Several solvers are used during inference.
    - Outputs of ODE Block (obtained with different solvers) are averaged before propagating through the next layer.
    - Applied during training/testing
    
    
- **Model ensembling**
    - Several solvers are used during inference.
    - Model probabilities obtained via propagation with different solvers are averaged to get the final result.
    - Applied during training/testing
    

In [1]:
import os
# These variables must be set before CUDA is initialised (i.e. before the first
# CUDA call made through torch), which is why this cell comes first.
# Enumerate GPUs in PCI bus order so that device indices are stable.
os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
# Expose only GPU 0 to this process. The variable CUDA reads is
# CUDA_VISIBLE_DEVICES (plural); the singular spelling is silently ignored.
os.environ['CUDA_VISIBLE_DEVICES'] = "0"

In [2]:
from argparse import Namespace
import torch
import torchvision.transforms as transforms
import numpy as np
import itertools
import wandb

import sys

sys.path.append('../../')
import sopa.src.models.odenet_cifar10.layers as cifar10_models
from sopa.src.models.odenet_cifar10.utils import *
from sopa.src.models.odenet_cifar10.data import get_cifar10_test_loader
from sopa.src.models.utils import fix_seeds
from sopa.src.solvers.utils import create_solver, noise_params, create_solver_ensemble_by_noising_params

# Build a model

In [3]:
# Load a checkpoint.
# An alternative checkpoint path is kept below for reference (judging by its
# name it has no "smoothing" suffix -- confirm what it was trained with).
# checkpoint_name = './checkpoints/fgsm_random_8_255_seed_102_checkpoint_6125.pth'
checkpoint_name = "./checkpoints/fgsm_random_8_255_smoothing_00125_seed_102_checkpoint_6125.pth"
checkpoint = torch.load(checkpoint_name)

# The training configuration is stored in the checkpoint under 'wandb_config';
# wrap it in a Namespace for attribute-style access.
wandb_config = checkpoint['wandb_config']
config = Namespace(**wandb_config)

print(f'Solvers used during training: {config.solvers}')

Solvers used during training: [['rk2', 'u', 8, -1.0, 0.5, -1]]


In [4]:
# Initialize Neural ODE model
# `init_model_metanode` comes from the local `black_box` module and is called
# with the whole checkpoint object loaded in the previous cell.
from black_box import init_model_metanode

model = init_model_metanode(checkpoint)
# Bare expression: the notebook shows the model's repr (module tree) as cell output.
model

MetaNODE(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): Identity()
  (layer1): MetaLayer(
    (blocks_res): Sequential(
      (0): PreBasicBlock(
        (bn1): Identity()
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): Identity()
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (shortcut): Sequential()
      )
    )
    (blocks_ode): ModuleList(
      (0): MetaODEBlock(
        (rhs_func): PreBasicBlock2(
          (bn1): Identity()
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): Identity()
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (shortcut): Sequential()
        )
      )
    )
  )
  (layer2): MetaLayer(
    (blocks_res): Sequential(
      (0): PreBasicBlock(
        (bn1): Identity()
     

# Build a data loader

In [5]:
# Directory handed to the loader as `data_root`.
data_root="./data"
# CIFAR-10 test loader: batches of 32, fixed (unshuffled) order.
# NOTE(review): the recorded output is 312 batches; with 10000 test images that
# would mean the last incomplete batch is dropped -- confirm in
# get_cifar10_test_loader.
test_loader = get_cifar10_test_loader(batch_size=32,
                                      data_root=data_root,
                                      num_workers=1,
                                      pin_memory=False,
                                      shuffle=False,
                                      download=True)
# Number of batches in the loader.
len(test_loader)

Files already downloaded and verified


312

# Evaluate the model

In [6]:
def one_hot(x, K):
    """Return the one-hot encoding of the label vector `x` over `K` classes.

    Row i of the result has a 1 in column x[i] and 0 elsewhere (a label
    outside 0..K-1 yields an all-zero row). The result is an integer array.
    """
    class_ids = np.arange(K)
    # Broadcast a column of labels against a row of class ids.
    hits = x[:, None] == class_ids[None, :]
    return np.asarray(hits, dtype=int)

def accuracy(model,
             dataset_loader,
             device, solvers=None,
             solver_options=None,
             data_noise_std=None,
             solver_noise_params=None
            ):
    """Compute the top-1 accuracy of `model` over `dataset_loader`.

    Parameters
    ----------
    model : callable module
        Put into eval mode and moved to `device`. Called as
        `model(x, solvers, solver_options)` when `solvers` is given,
        otherwise as `model(x)`. Must return per-class scores of shape
        (batch, n_classes).
    dataset_loader : iterable of (x, y) batches
        `y` holds integer class labels as a tensor.
    device : torch device or str
        Device the model and the inputs are moved to.
    solvers : list, optional
        Solver objects forwarded to the model.
    solver_options : Namespace, optional
        Forwarded to the model together with `solvers`.
    data_noise_std : float, optional
        If larger than 1e-12, Gaussian noise with this standard deviation is
        added to every input batch.
    solver_noise_params : Namespace, optional
        Must provide `noise_type`, `noise_sigma` and `noise_prob`. When
        `noise_type` is not None, every solver's (u, v) is re-sampled with
        `noise_params` before each batch and reset to (u0, v0) afterwards.
        Defaults to no solver noise.

    Returns
    -------
    Fraction of correctly classified samples, computed over the samples
    actually yielded by the loader.

    Raises
    ------
    ValueError
        If solver noise is requested but `solvers` is None.
    ZeroDivisionError
        If the loader yields no samples.
    """
    if solver_noise_params is None:
        # Built per call instead of sharing one default instance across calls.
        solver_noise_params = Namespace(noise_type=None,
                                        noise_sigma=0.0125,
                                        noise_prob=0.9)

    noise_solvers = solver_noise_params.noise_type is not None
    if noise_solvers and solvers is None:
        raise ValueError("solver noise was requested (noise_type is not None) "
                         "but no solvers were provided")

    model.eval()
    model.to(device)
    total_correct = 0
    total = 0

    for x, y in dataset_loader:
        x = x.to(device)
        # Integer labels are compared directly; no one-hot round trip and no
        # hard-coded number of classes.
        target_class = y.cpu().numpy()

        try:
            ### Noise base solver parameters
            if noise_solvers:
                for solver in solvers:
                    solver.u, solver.v = noise_params(solver.u0,
                                                      solver.v0,
                                                      std=solver_noise_params.noise_sigma,
                                                      bernoulli_p=solver_noise_params.noise_prob,
                                                      noise_type=solver_noise_params.noise_type)
                    # The tableau is rebuilt after every change of (u, v).
                    solver.build_ButcherTableau()

            with torch.no_grad():
                # Add noise to the input data:
                if (data_noise_std is not None) and (data_noise_std > 1e-12):
                    x = x + data_noise_std * torch.randn_like(x)

                if solvers is not None:
                    out = model(x, solvers, solver_options)
                else:
                    out = model(x)
                predicted_class = np.argmax(out.cpu().detach().numpy(), axis=1)
        finally:
            ### Restore base solver parameters (also when the forward pass fails,
            ### so the caller's solvers are never left in a noised state)
            if noise_solvers:
                for solver in solvers:
                    solver.u, solver.v = solver.u0, solver.v0
                    solver.build_ButcherTableau()

        total_correct += np.sum(predicted_class == target_class)
        # Count real samples: len(loader) * batch_size over-counts when the
        # last batch is partial and fails when batch_size is None.
        total += target_class.shape[0]

    torch.cuda.empty_cache()
    return total_correct / total

# Standalone 
- Use one solver during inference


### How to define solver configuration

- Each solver is represented with *(method, parameterization, n_steps, step_size, u0, v0)*.

- If the solver has only one parameter *u0*, set *v0* to *-1*.

- *n_steps* and *step_size* are mutually exclusive: at most one of them may differ from *-1*.

- If *n_steps = step_size = -1*, automatic time grid_constructor is used.

For example, 

`--solvers 'rk2,u,8,-1,0.5,-1' ` defines 8-step 2-nd order Runge-Kutta method with Butcher tableau computed using *u=0.5*.

`--solver_mode 'standalone'` specifies that we propagate through Meta ODE block using *standalone* regime (i.e., one pre-defined solver).

In [7]:
device = 'cuda'
dtype = torch.float32

# Create a solver: one 8-step RK2 solver parameterized by u0 = 0.5
# (v0 = -1 and step_size = -1 follow the "unused" convention described above).
val_solvers = [
    create_solver(method='rk2', parameterization='u',
                  n_steps=8, step_size=-1,
                  u0=0.5, v0=-1,
                  dtype=dtype, device=device),
]

# Freeze solver params
for solver in val_solvers:
    solver.freeze_params()

# Propagate through the Meta ODE block with the single solver defined above.
val_solver_options = Namespace(solver_mode='standalone')

In [8]:
# Compute standard accuracy (clean inputs, no solver noise) in the standalone
# regime, using the solver and options defined in the previous cell.
accuracy_test = accuracy(model, test_loader, device=device,
                         solvers=val_solvers,
                         solver_options=val_solver_options)
# Bare expression: displayed as the cell output.
accuracy_test

0.8283253205128205

# Solver smoothing

For each batch one solver is chosen from a group of solvers with infinite number of members.

We initialize a base solver and a probability distribution of the noise that is added to the base solver parameters before each batch is propagated. We can also specify the probability with which the noising procedure is applied.

For example, 

`--solvers 'rk2,u,8,-1,0.5,-1' --solver_mode 'standalone'`  means we use 8-step 2-nd order Runge-Kutta method (with Butcher tableau computed using *u=0.5*) as a base solver.

`--noise_type 'normal' --noise_sigma 0.0125 --noise_prob 0.9` means that for each batch  we sample RK2 solver parameter from  *N(0.5, 0.0125)* with probability 0.9, and use a base solver (RK2 with *u=0.5*) with probability 0.1.

In [9]:
# Compute accuracy under solver smoothing: before every batch the base solver
# parameters are perturbed with 'normal' noise (sigma 0.0125) with probability 0.9.
smoothing_noise = Namespace(noise_type='normal',
                            noise_sigma=0.0125,
                            noise_prob=.9)

accuracy_test = accuracy(
    model,
    test_loader,
    device=device,
    solvers=val_solvers,
    solver_options=val_solver_options,
    solver_noise_params=smoothing_noise,
)

accuracy_test

0.8283253205128205

# Solver switching
For each batch, one solver is chosen from a finite group of solvers.

For example, 

`--solvers 'rk2,u,8,-1,0.5,-1;rk4,uv,8,-1,0.3,0.6' ` defines two solvers: 8-step 2-nd order Runge-Kutta method with Butcher tableau computed using *u=0.5*, and 8-step 4-th order Runge-Kutta method with Butcher tableau computed using *u=0.3, v=0.6*.

`--solver_mode 'switch'` specifies that we propagate through Meta ODE block using  *switching* regime (i.e., *sampling* regime with finite number of solvers). 

`--switch_probs '0.6,0.4'` means that we  choose first solver with probability *0.6*, and second solver with probability *0.4* to propagate through the Meta ODE block.

In [10]:
device = 'cuda'
dtype = torch.float32

# Two candidate solvers: 8-step RK2 (u0 = 0.5) and 8-step RK4 (u0 = 0.3, v0 = 0.6).
rk2_solver = create_solver(method='rk2', parameterization='u',
                           n_steps=8, step_size=-1, u0=0.5, v0=-1,
                           dtype=dtype, device=device)
rk4_solver = create_solver(method='rk4', parameterization='uv',
                           n_steps=8, step_size=-1, u0=0.3, v0=0.6,
                           dtype=dtype, device=device)
val_solvers = [rk2_solver, rk4_solver]

for solver in val_solvers:
    solver.freeze_params()

# 'switch' mode with per-solver selection probabilities (same order as val_solvers).
val_solver_options = Namespace(solver_mode='switch', switch_probs=[0.6, 0.4])

In [11]:
# Accuracy in the switching regime, using the two solvers and the
# 'switch' options defined in the previous cell.
accuracy_test = accuracy(model, test_loader, device=device,
                         solvers=val_solvers, solver_options=val_solver_options)
# Bare expression: displayed as the cell output.
accuracy_test

0.8286258012820513

# Solver ensembling
- Use several solvers during inference.

- Outputs of ODE Block (obtained with different solvers) are averaged before propagating through the next layer.

`--solvers 'rk2,u,8,-1,0.5,-1;rk4,uv,8,-1,0.3,0.6' ` defines two solvers: 8-step 2-nd order Runge-Kutta method with Butcher tableau computed using *u=0.5*, and 8-step 4-th order Runge-Kutta method with Butcher tableau computed using *u=0.3, v=0.6*.

`--solver_mode 'ensemble'` specifies that we propagate through Meta ODE block using *ensembling* regime. 

`--ensemble_weights '0.6,0.4'` means that we propagate through the Meta ODE block using two solvers, and we average their outputs with weights *0.6* and *0.4* before passing the result to the model's next block.

`--ensemble_prob 1.` means that for each batch we use the ensembling regime for the Meta ODE block with probability 1.

In [12]:
# Two solvers whose ODE-block outputs are averaged:
# 8-step RK2 (u0 = 0.5) and 8-step RK4 (u0 = 0.3, v0 = 0.6).
rk2_solver = create_solver(method='rk2', parameterization='u',
                           n_steps=8, step_size=-1, u0=0.5, v0=-1,
                           dtype=dtype, device=device)
rk4_solver = create_solver(method='rk4', parameterization='uv',
                           n_steps=8, step_size=-1, u0=0.3, v0=0.6,
                           dtype=dtype, device=device)
val_solvers = [rk2_solver, rk4_solver]

for solver in val_solvers:
    solver.freeze_params()

# 'ensemble' mode: always ensemble (prob 1) with the given per-solver weights.
val_solver_options = Namespace(solver_mode='ensemble',
                               ensemble_prob=1,
                               ensemble_weights=[0.6, 0.4])

In [13]:
# Accuracy in the solver-ensembling regime, using the two solvers and the
# 'ensemble' options defined in the previous cell.
accuracy_test = accuracy(model, test_loader, device=device,
                         solvers=val_solvers, solver_options=val_solver_options)
# Bare expression: displayed as the cell output.
accuracy_test

0.828125

# Model Ensembling
- Use several solvers during inference.

- Model probabilities obtained via propagation with different solvers are averaged to get the final result.

In [17]:
def accuracy_ensemble(models, dataset_loader, device, solvers_solver_options_arr=None, data_noise_std=None):
    """Compute the top-1 accuracy of an ensemble that averages class probabilities.

    Each ensemble member is a (model, call kwargs) pair. For every batch the
    softmax probabilities of all members are averaged and the arg-max of the
    average is taken as the prediction.

    Parameters
    ----------
    models : sequence of callable modules
        Put into eval mode and moved to `device`. Must be non-empty.
    dataset_loader : iterable of (x, y) batches
        `y` holds integer class labels as a tensor.
    device : torch device or str
        Device the models and the inputs are moved to.
    solvers_solver_options_arr : sequence of dict, optional
        One dict of keyword arguments per ensemble member; member i is called
        as `model_i(x, **solvers_solver_options_arr[i])`. If there are more
        dicts than models, the first model is reused for the extra dicts.
        If None, every model is called as `model(x)`.
    data_noise_std : float, optional
        If larger than 1e-12, Gaussian noise with this standard deviation is
        added to every input batch.

    Returns
    -------
    Fraction of correctly classified samples, computed over the samples
    actually yielded by the loader.

    Raises
    ------
    ValueError
        If `models` is empty, or if there are more models than option dicts
        (the extra models would have no call arguments).
    ZeroDivisionError
        If the loader yields no samples.
    """
    models = list(models)
    if not models:
        raise ValueError("`models` must contain at least one model")

    # Pair every ensemble member with the keyword arguments of its forward call.
    if solvers_solver_options_arr is None:
        members = [(model, {}) for model in models]
    else:
        option_sets = list(solvers_solver_options_arr)
        if len(models) > len(option_sets):
            raise ValueError("got more models than solver/option sets: "
                             f"{len(models)} > {len(option_sets)}")
        # Reuse the first model for any surplus option sets.
        padded_models = models + [models[0]] * (len(option_sets) - len(models))
        members = list(zip(padded_models, option_sets))

    for model in models:
        model.eval()
        # Inputs are moved to `device` below, so the models must live there too.
        model.to(device)

    total_correct = 0
    total = 0

    for x, y in dataset_loader:
        x = x.to(device)
        # Integer labels are compared directly; no hard-coded number of classes.
        target_class = y.cpu().numpy()

        with torch.no_grad():
            # Add noise to the input:
            if (data_noise_std is not None) and (data_noise_std > 1e-12):
                x = x + data_noise_std * torch.randn_like(x)

            probs_ensemble = 0
            for model, call_kwargs in members:
                logits = model(x, **call_kwargs)
                # Softmax over the class dimension (dim=1), stated explicitly.
                probs = torch.softmax(logits, dim=1).cpu().detach().numpy()
                probs_ensemble = probs_ensemble + probs

            probs_ensemble /= len(members)

        predicted_class = np.argmax(probs_ensemble, axis=1)
        total_correct += np.sum(predicted_class == target_class)
        # Count real samples: len(loader) * batch_size over-counts when the
        # last batch is partial and fails when batch_size is None.
        total += target_class.shape[0]

    return total_correct / total

In [15]:
# Base solver: 8-step RK2 with u0 = 0.5.
base_solver = create_solver(method='rk2', parameterization='u',
                            n_steps=8, step_size=-1,
                            u0=0.5, v0=-1,
                            dtype=dtype, device=device)
val_solvers = [base_solver]
for solver in val_solvers:
    solver.freeze_params()

val_solver_options = Namespace(solver_mode='standalone')

# Build the ensemble by noising the base solver's parameters
# ('normal' noise, std 0.0125, applied with probability 1).
ensemble_size = 2
noise_kwargs = dict(std=0.0125, bernoulli_p=1., noise_type='normal')
solver_ensemble = create_solver_ensemble_by_noising_params(
    val_solvers[0],
    ensemble_size=ensemble_size,
    kwargs_noise=noise_kwargs,
)

# One keyword-argument dict per ensemble member, in the form expected by
# accuracy_ensemble: each member uses a single solver in standalone mode.
solvers_solver_options_arr = [
    dict(solvers=[member_solver], solver_options=val_solver_options)
    for member_solver in solver_ensemble
]

tensor([0.5217], device='cuda:0') None


In [16]:
# Accuracy of the model ensemble: a single model is passed and evaluated once
# per entry of solvers_solver_options_arr; the resulting class probabilities
# are averaged inside accuracy_ensemble.
accuracy_test = accuracy_ensemble([model], test_loader, device=device,
                                  solvers_solver_options_arr=solvers_solver_options_arr,)
# Bare expression: displayed as the cell output.
accuracy_test

0.8283253205128205