# Experiment 1: FitzHugh-Nagumo Model

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# Import the necessary libraries
import time 

import torch
import torch.nn as nn
import torch.utils.data as data
import torch.utils.benchmark as benchmark

import torchdyn
from torchdyn.core import NeuralODE
from torchdyn.numerics import odeint

import numpy as np
import lightning as pl
import matplotlib.pyplot as plt


In [None]:
# Intra-library imports
from models import *
import utils.ode as ode

In [None]:
# Choose the compute device: a CUDA GPU when one is visible, otherwise the CPU.
# NOTE(review): `device` is not passed to any Trainer/model in this notebook
# section; confirm whether it is still needed.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Seed PyTorch's global RNG so that network initializations are reproducible.
torch.manual_seed(0)

In [None]:
## Define the residual for the ODE

# Instead of integrating the ODE on the interval [0, 10], we integrate it on
# [0, 1] and speed up the dynamics by a factor of 10 (a time rescaling).
# This is done to aid PINNs in learning the dynamics.
speed = 10.0
eq = ode.FitzHughNagumo(speed=speed, mu=1.0)
eq_name = eq.name
order = eq.order

# Time grid and initial condition.
# `nt` counts intervals, so the grid has nt + 1 = 101 points on [0, 1].
# Raising the uniform grid to the power 1.1 (> 1) keeps the endpoints 0 and 1
# fixed but pulls the interior points toward t = 0. This gives a denser early
# grid, which helps PINNs respect causality.
nt = 100
t = torch.linspace(0, 1, nt + 1) ** 1.1
# The second component equals v - v**3/3 evaluated at v = 2.
# NOTE(review): presumably this places the initial state on the cubic
# nullcline of the FitzHugh-Nagumo system — confirm against ode.FitzHughNagumo.
x0 = torch.tensor([2.0, 2 - 8 / 3])

In [None]:
# Baseline solution: integrate the sped-up system with the adaptive dopri5
# solver on the training horizon t in [0, 1]. Also integrate on the doubled
# horizon 2*t in [0, 2], which is later used to score extrapolation.
def f(t, x):
    # torchdyn calls the vector field as f(t, x); the dynamics are autonomous,
    # so t is ignored.
    return eq.f(x)


_, sol = odeint(f, x0, t, solver='dopri5')
_, sol_extrapolated = odeint(f, x0, 2 * t, solver='dopri5')

# Plot each state component against physical (un-rescaled) time speed * t.
# Each component gets its own legend entry.
for k in range(sol.shape[-1]):
    plt.plot(speed * t, sol[..., k], label=f'x[{k}]')
plt.xlabel('Time')
plt.legend()
plt.title(eq_name)
plt.title(eq_name)

In [None]:
# Build the (trivial) training set: the only "data" a PINN needs is the
# initial condition x0. TensorDataset indexes along dim 0, so the dataset
# holds x0.numel() scalar items. Loading all of them as a single unshuffled
# batch re-stacks them (with the default collate function) into a tensor
# equal to x0.
train = data.TensorDataset(x0)
trainloader = data.DataLoader(
    dataset=train,
    batch_size=len(train),
    shuffle=False,
)

In [None]:
# Number of independent training runs per model. The experiment cells below
# report the mean and standard deviation of their metrics over these runs.
experiments: int = 5

**Learner**

In [None]:
# Train with Pytorch Lightning

class PINNLearner(pl.LightningModule):
    """Lightning wrapper that trains a flow model as a physics-informed network.

    The only training signal is the ODE residual: the model's time derivative
    (``model.diff``) must match the vector field ``self.eq.f`` evaluated on the
    model's own trajectory over the collocation grid ``self.t``.

    Args:
        model: Network exposing ``model(x0, t)`` (trajectory) and
            ``model.diff(x0, t)`` (its time derivative).
        trainloader: Loader whose batches carry the initial condition as
            their first element.
        lr: Adam learning rate.
    """

    def __init__(self, model: nn.Module, *,
                 trainloader: data.DataLoader = trainloader,
                 lr=0.002):
        super().__init__()
        self.model = model
        self.trainloader = trainloader
        self.lr = lr

        # Capture the equation and time grid at construction. `t` is a buffer
        # so that it moves together with the module on .to()/.cuda().
        self.eq = eq
        self.register_buffer('t', t)

    def training_step(self, batch, batch_idx):
        """Return the mean-squared ODE residual for the batch's initial condition."""
        x0 = batch[0]
        x = self.model(x0, self.t)
        x_dot = self.model.diff(x0, self.t)
        # Use the equation stored on the learner, not the notebook-global
        # `eq`, so the loss always matches the system the learner was built for.
        loss = nn.functional.mse_loss(x_dot, self.eq.f(x))

        self.log("loss", loss, prog_bar=True, on_step=False, on_epoch=True, logger=False)
        return {'loss': loss}

    def configure_optimizers(self):
        """Adam over the model parameters with the notebook's fixed betas."""
        return torch.optim.Adam(self.model.parameters(), lr=self.lr, amsgrad=False,
                                betas=(0.95, 0.99), eps=1e-08, weight_decay=0)

    def train_dataloader(self):
        """Return the loader supplied at construction."""
        return self.trainloader

In [None]:
def run_experiment(learner, max_epochs=2000):
    """Train ``learner`` and score its model against the dopri5 baseline.

    Args:
        learner: LightningModule whose ``model`` maps ``(x0, t)`` to a trajectory.
        max_epochs: Number of training epochs (2000 is the value used
            throughout this notebook).

    Returns:
        ``(fit_time, loss_ground, loss_extra)``. ``fit_time`` is the
        wall-clock training time in seconds. ``loss_ground`` is the mean
        squared error against ``sol`` on the training grid ``t``.
        ``loss_extra`` is the same error against ``sol_extrapolated`` on the
        doubled grid ``2*t``. The two losses are Python floats.
    """
    trainer = pl.Trainer(max_epochs=max_epochs)
    model = learner.model

    # perf_counter is monotonic and high-resolution, unlike time.time().
    start_time = time.perf_counter()
    trainer.fit(learner)
    fit_time = time.perf_counter() - start_time

    # NOTE(review): assumes the model is on the same device as the
    # notebook-global x0/t/sol after fit — confirm for GPU runs.
    trajectory = model(x0, t).squeeze(1).detach()
    # .item() yields plain floats, so np.mean/np.std work regardless of the
    # tensor's device.
    loss_ground = torch.mean((trajectory - sol) ** 2).item()

    trajectory_extrapolated = model(x0, 2 * t).squeeze(1).detach()
    loss_extra = torch.mean((trajectory_extrapolated - sol_extrapolated) ** 2).item()

    return fit_time, loss_ground, loss_extra


## MLP

In [None]:
fit_times = []
losses_ground = []
losses_extrapolation = []

width = 32

for i in range(experiments):
    print(f'Experiment {i+1}')

    # Use a dedicated name for the network: reusing `f` would overwrite the
    # baseline ODE right-hand side defined earlier in the notebook.
    vector_field = CustomMLP(3, width, width, width, 2, fourier_feature=True)
    model = SemiFlow(vector_field)
    learner = PINNLearner(model, trainloader=trainloader)

    fit_time, loss_ground, loss_extra = run_experiment(learner)
    fit_times.append(fit_time)
    # Store plain floats so the NumPy statistics below never see device
    # tensors.
    losses_ground.append(float(loss_ground))
    losses_extrapolation.append(float(loss_extra))


# Aggregate over runs. `model` is the network from the last run and is kept
# for the plots.
mean_ground = np.mean(losses_ground)
std_ground = np.std(losses_ground)
mean_extra = np.mean(losses_extrapolation)
std_extra = np.std(losses_extrapolation)
mean_fit_time = np.mean(fit_times)
std_fit_time = np.std(fit_times)

mlp_results = {'mean_ground': mean_ground, 'std_ground': std_ground,
               'mean_extra': mean_extra, 'std_extra': std_extra,
               'mean_fit_time': mean_fit_time, 'std_fit_time': std_fit_time,
               'model': model}

mlp_results

## Neural Conjugate Flows

In [None]:
# Augment the system with a copy of x0: the models below receive the
# concatenated state [x0, x0] along the last dimension.
twin_times = 2
x0_twinned = torch.cat(twin_times * [x0], dim=-1)
train_twinned = data.TensorDataset(x0_twinned)
trainloader_twinned = data.DataLoader(
    train_twinned,
    shuffle=False,
    batch_size=len(train_twinned),
)

In [None]:
# Initial matrix for the linear flows below (passed to LinearFlow as M0),
# built for the twinned state. Only its skew-symmetric part (M0 - M0^T)/2
# is kept.
M0 = ncf_matrix_init(eq, x0, pad_mode='twin', pad_times=twin_times)
M0 = 0.5 * (M0 - M0.T)

In [None]:
# Implement special learner for duplicated system

class NCFLearner(pl.LightningModule):
    """Lightning wrapper for PINN-style training of the twinned (duplicated) system.

    The model's state is the concatenation of two copies of the original
    system, split in half along the last dimension. The loss is::

        weight * (residual(copy1) + residual(copy2)) / 2
        + avg_weight * residual(mean of the two copies)

    where ``residual`` is the MSE between a trajectory's time derivative and
    ``self.eq.f`` evaluated on that trajectory.

    Args:
        model: Network exposing ``model(x0, t)`` and ``model.diff(x0, t)``.
        trainloader: Loader whose batches carry the twinned initial
            condition as their first element.
        weight: Weight of the per-copy residual term.
        avg_weight: Weight of the residual of the averaged copies (0.0
            effectively disables it).
        lr: Adam learning rate.
    """

    def __init__(self, model: nn.Module, *,
                 trainloader: data.DataLoader = trainloader_twinned,
                 weight=1.0,
                 avg_weight=0.0,
                 lr=0.002):
        super().__init__()
        self.model = model
        self.trainloader = trainloader
        self.weight = weight
        self.avg_weight = avg_weight
        self.lr = lr

        # Capture the equation and time grid at construction. `t` is a buffer
        # so that it moves together with the module on .to()/.cuda().
        self.eq = eq
        self.register_buffer('t', t)

    def training_step(self, batch, batch_idx):
        """Return the weighted per-copy + averaged-copy ODE residual loss."""
        x0 = batch[0]
        mse = nn.functional.mse_loss
        # Use the equation stored on the learner, not the notebook-global `eq`.
        rhs = self.eq.f

        x = self.model(x0, self.t)
        x_dot = self.model.diff(x0, self.t)
        x1, x2 = x.chunk(2, dim=-1)
        x1_dot, x2_dot = x_dot.chunk(2, dim=-1)
        x_mean = (x1 + x2) / 2
        x_dot_mean = (x1_dot + x2_dot) / 2
        loss = (self.weight * (mse(x1_dot, rhs(x1)) + mse(x2_dot, rhs(x2))) / 2
                + self.avg_weight * mse(x_dot_mean, rhs(x_mean)))

        self.log("loss", loss, prog_bar=True, on_step=False, on_epoch=True, logger=False)
        return {'loss': loss}

    def configure_optimizers(self):
        """Adam over the model parameters with the notebook's fixed betas."""
        return torch.optim.Adam(self.model.parameters(), lr=self.lr, amsgrad=False,
                                betas=(0.95, 0.99), eps=1e-08, weight_decay=0)

    def train_dataloader(self):
        """Return the loader supplied at construction."""
        return self.trainloader

In [None]:
def ncf_experiment(learner, max_epochs=2000):
    """Train an NCF ``learner`` and score the twin-averaged trajectory.

    The model runs on ``x0_twinned``. Its two halves (one per copy) are
    averaged back to the original state dimension before being compared with
    the dopri5 baseline.

    Args:
        learner: LightningModule whose ``model`` maps ``(x0_twinned, t)`` to a
            twinned trajectory.
        max_epochs: Number of training epochs (2000 is the value used
            throughout this notebook).

    Returns:
        ``(fit_time, loss_ground, loss_extra)``. ``fit_time`` is the
        wall-clock training time in seconds. ``loss_ground`` is the mean
        squared error on the grid ``t``, and ``loss_extra`` is the same error
        on the doubled grid ``2*t``. The two losses are Python floats.
    """
    trainer = pl.Trainer(max_epochs=max_epochs)
    model = learner.model

    # perf_counter is monotonic and high-resolution, unlike time.time().
    start_time = time.perf_counter()
    trainer.fit(learner)
    fit_time = time.perf_counter() - start_time

    def _twin_mean_rollout(times):
        # Roll out the twinned system and average its two copies.
        out = model(x0_twinned, times).squeeze(1).detach()
        first, second = out.chunk(2, dim=-1)
        return (first + second) / 2

    trajectory = _twin_mean_rollout(t)
    loss_ground = torch.mean((trajectory - sol) ** 2).item()

    trajectory_extrapolated = _twin_mean_rollout(2 * t)
    loss_extra = torch.mean((trajectory_extrapolated - sol_extrapolated) ** 2).item()

    return fit_time, loss_ground, loss_extra


### Neural Conjugate (No Topology)

In [None]:
fit_times, losses_ground, losses_extrapolation = [], [], []

# Settings shared by every run in this cell.
twin_order = order * twin_times
width = 32

for run in range(experiments):
    print(f'Experiment {run+1}')

    # Linear flow Psi initialized from M0, conjugated by two additive coupling
    # layers (the second with orientation='skew'). The constructors run in
    # the same order as before, so RNG consumption is unchanged.
    Psi = LinearFlow(twin_order**2, M0=M0, omega_zero=0.075)
    couplings = [
        AdditiveCouplingLayer(CustomMLP(order, width, width, order),
                              initial_step_size=0.0075),
        AdditiveCouplingLayer(CustomMLP(order, width, width, order),
                              initial_step_size=0.0075, orientation='skew'),
    ]
    model = NeuralConjugate(couplings, Psi, pad='no')
    learner = NCFLearner(model, trainloader=trainloader_twinned, lr=0.0022)

    fit_time, loss_ground, loss_extra = ncf_experiment(learner)
    fit_times.append(fit_time)
    losses_ground.append(loss_ground)
    losses_extrapolation.append(loss_extra)


# Mean / standard deviation across runs; `model` is the last run's network.
mean_ground, std_ground = np.mean(losses_ground), np.std(losses_ground)
mean_extra, std_extra = np.mean(losses_extrapolation), np.std(losses_extrapolation)
mean_fit_time, std_fit_time = np.mean(fit_times), np.std(fit_times)

ncf_results = dict(
    mean_ground=mean_ground, std_ground=std_ground,
    mean_extra=mean_extra, std_extra=std_extra,
    mean_fit_time=mean_fit_time, std_fit_time=std_fit_time,
    model=model,
)

ncf_results

### Neural Conjugate (Enforced Topology)

In [None]:
fit_times, losses_ground, losses_extrapolation = [], [], []

# Settings shared by every run in this cell.
twin_order = order * twin_times
width = 32

for run in range(experiments):
    print(f'Experiment {run+1}')

    # Same architecture as the unconstrained NCF, but Psi is built with
    # lie_algebra='skew_symmetric'. Per the keyword, this presumably
    # constrains the linear flow's generator — confirm in LinearFlow.
    # Constructors run in the same order as before, so RNG use is unchanged.
    Psi = LinearFlow(twin_order**2, M0=M0, omega_zero=0.1, lie_algebra='skew_symmetric')
    couplings = [
        AdditiveCouplingLayer(CustomMLP(order, width, width, order),
                              initial_step_size=0.1),
        AdditiveCouplingLayer(CustomMLP(order, width, width, order),
                              initial_step_size=0.1, orientation='skew'),
    ]
    model = NeuralConjugate(couplings, Psi, pad='no')
    learner = NCFLearner(model, trainloader=trainloader_twinned, lr=0.002)

    fit_time, loss_ground, loss_extra = ncf_experiment(learner)
    fit_times.append(fit_time)
    losses_ground.append(loss_ground)
    losses_extrapolation.append(loss_extra)


# Mean / standard deviation across runs; `model` is the last run's network.
mean_ground, std_ground = np.mean(losses_ground), np.std(losses_ground)
mean_extra, std_extra = np.mean(losses_extrapolation), np.std(losses_extrapolation)
mean_fit_time, std_fit_time = np.mean(fit_times), np.std(fit_times)

ncf_T_results = dict(
    mean_ground=mean_ground, std_ground=std_ground,
    mean_extra=mean_extra, std_extra=std_extra,
    mean_fit_time=mean_fit_time, std_fit_time=std_fit_time,
    model=model,
)

ncf_T_results

## Plots

In [None]:
fig, ax = plt.subplots()


def _twin_average(out):
    # NCF models act on the twinned state [copy1, copy2]. Average the copies
    # so the plotted curve is the same quantity ncf_experiment scores.
    first, second = out.chunk(2, dim=-1)
    return (first + second) / 2


# Roll every trained model out over the doubled horizon 2*t.
mlp_extrapolated = mlp_results["model"](x0, 2*t).squeeze(1).detach()
ncf_extrapolated = _twin_average(
    ncf_results["model"](x0_twinned, 2*t).squeeze(1).detach())
ncf_T_extrapolated = _twin_average(
    ncf_T_results["model"](x0_twinned, 2*t).squeeze(1).detach())

# Physical time axis; the training horizon ends at speed * t[-1].
time_axis = 2 * speed * t
training_limit = float(speed * t[-1])

# Plot the first state component of each trajectory.
ax.plot(time_axis, sol_extrapolated[..., 0], label='Baseline')
ax.plot(time_axis, mlp_extrapolated[..., 0], label='MLP', marker='o', linestyle='--')
ax.plot(time_axis, ncf_extrapolated[..., 0], label='NCF', marker='+', linestyle='-.')
ax.plot(time_axis, ncf_T_extrapolated[..., 0], label='NCF-T', marker='x', linestyle=':')

lims = [-3, 8]
ax.set_ylim(lims)
ax.set_xlim([0, float(time_axis[-1])])
ax.set_aspect(3/4)
ax.vlines(training_limit, *lims, linestyles='dashed', colors='gray', label='Training Limit')
ax.set_title("FitzHugh-Nagumo Neuron Model")
# Legend entries come from each artist's label, so they cannot drift out of
# sync with the plotted curves.
ax.legend()
ax.set_xlabel("Time (ms)")
ax.set_ylabel("Neuron Potential (mV)")