# dy/dx = y

In [None]:
from torchdyn.core import NeuralODE
from torchdyn.datasets import *
from torchdyn import *

%load_ext autoreload
%autoreload 2

# Set to True for a quick run (automated notebook validation): training is
# then limited to a single epoch instead of the full schedule.
dry_run = False

## Generating Data

In [None]:
import torch  # imported here so this cell does not rely on a later cell

# Inputs: 81 evenly spaced points on [-20, 20], stored as a column of
# shape (81, 1) so that each sample has one feature.
x = torch.linspace(-20, 20, 81).unsqueeze(-1)

# Targets: exp(x), i.e. the solution of dy/dx = y with y(0) = 1, kept as a
# flat vector of shape (81,).
y = torch.exp(x.flatten())

In [None]:
import torch
import torch.utils.data as data

device = torch.device("cpu")

# x and y are already tensors, so they are moved with .to(device) directly.
# Re-wrapping them in torch.Tensor(...) is redundant.
# NOTE(review): the legacy torch.Tensor constructor is believed to reject
# non-CPU tensors, which would break if `device` were changed -- confirm.
x_train = x.to(device)
y_train = y.to(device)
train = data.TensorDataset(x_train, y_train)

# Full-batch training: a single batch holds every sample of the dataset.
trainloader = data.DataLoader(train, batch_size=len(train), shuffle=True)

## Learner

In [None]:
import torch.nn as nn
import pytorch_lightning as pl

class Exponential_Learner(pl.LightningModule):
    """Lightning module that fits a Neural ODE to (input, target) pairs.

    The wrapped model is integrated over ``t_span`` and the last point of the
    resulting trajectory is regressed onto the target with an MSE loss.

    Args:
        t_span: time grid handed to the model on every training step.
        model: module called as ``model(x, t_span)`` and expected to return
            a ``(t_eval, trajectory)`` pair.
    """

    def __init__(self, t_span:torch.Tensor, model:nn.Module):
        super().__init__()
        self.model, self.t_span = model, t_span

    def forward(self, x):
        """Run the wrapped model on ``x`` (no explicit time grid is passed)."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute the MSE between the trajectory end point and the target."""
        x, y = batch
        t_eval, y_hat = self.model(x, self.t_span)
        y_hat = y_hat[-1]  # select last point of solution trajectory
        # The prediction keeps the feature dimension of the input while the
        # targets may arrive flattened; without aligning the shapes MSELoss
        # would silently broadcast (batch, 1) against (batch,) into a
        # (batch, batch) comparison. reshape raises if the sizes disagree.
        y = y.reshape(y_hat.shape)
        loss = nn.MSELoss()(y_hat, y)
        return {'loss': loss}

    def configure_optimizers(self):
        """Optimize the wrapped model's parameters with Adam (lr=0.01)."""
        return torch.optim.Adam(self.model.parameters(), lr=0.01)

    def train_dataloader(self):
        """Return the module-level ``trainloader`` defined in the notebook."""
        return trainloader

## Train

In [None]:
# Vector field of the Neural ODE: an MLP mapping a 1-d state to a 1-d
# derivative through two 16-unit hidden layers with a 1-unit bottleneck.
# (Fixed: the layer list was missing two commas and a closing parenthesis.)
f = nn.Sequential(
    nn.Linear(1, 16),
    nn.ELU(),
    nn.Linear(16, 1),
    nn.ELU(),
    nn.Linear(1, 16),
    nn.ELU(),
    nn.Linear(16, 1),
)

# Neural ODE built around the vector field f: fixed-step rk4 for the forward
# solve, adjoint sensitivity with a dopri5 solver (tolerances 1e-4) for the
# backward pass.
ode_options = dict(
    sensitivity='adjoint',
    solver='rk4',
    solver_adjoint='dopri5',
    atol_adjoint=1e-4,
    rtol_adjoint=1e-4,
)
model = NeuralODE(f, **ode_options).to(device)

In [None]:
# Integration grid: 100 time points on [0, 1].
t_span = torch.linspace(0, 1, 100)
learn = Exponential_Learner(t_span, model)

# A dry run trains for exactly one epoch; otherwise between 200 and 300.
epoch_bounds = (1, 1) if dry_run else (200, 300)
trainer = pl.Trainer(min_epochs=epoch_bounds[0], max_epochs=epoch_bounds[1])
trainer.fit(learn)

## Plot Results

In [None]:
# Integrate every training input over t_span with the trained model, then
# drop the autograd graph and bring the result to the CPU for plotting.
t_eval, raw_trajectory = model(x_train, t_span)
trajectory = raw_trajectory.detach().cpu()

In [None]:
import matplotlib.pyplot as plt

# Plot the state of every sample as a function of integration time.
fig = plt.figure(figsize=(5,5))
ax0 = fig.add_subplot(111)
# Iterate over however many samples the trajectory holds (its second
# dimension is indexed per sample) instead of hard-coding the dataset size.
for i in range(trajectory.shape[1]):
    ax0.plot(t_span, trajectory[:,i], color='black', alpha=.1)

In [None]:
# Learned input-output map: first trajectory point (the input) on the
# horizontal axis against the last trajectory point (the prediction).
initial_state, final_state = trajectory[0, :], trajectory[-1, :]
fig, ax = plt.subplots(figsize=(5,5))
ax.plot(initial_state, final_state, color='black');

In [None]:
# Tabulate, per sample, the input (first trajectory point), the model's
# prediction (last trajectory point) and the true target.
initial_state, final_state = trajectory[0], trajectory[-1]
print(len(initial_state))
for x0, y_guess, y_true in zip(initial_state, final_state, y):
    # .item() on all three columns so the target prints as a plain float
    # rather than as "tensor(...)".
    print(f"x: {x0.item()} y_guess: {y_guess.item()} y: {y_true.item()}")