## Cell Delay Prediction by Augmented Neural ODEs

In [96]:
import matplotlib.pyplot as plt
%matplotlib inline
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

device = torch.device('cpu')

### Create an ODE function

In [97]:
from anode.models import ODENet
from anode.training import Trainer

data_dim = 6
hidden_dim = 16

### Prepare Dataset
用一个向量表示数据。

In [101]:
class CellData():
    def __init__(self, data):
        self.data = data

    def __getitem__(self, index):
        vector = self.data[index]
        labels = vector[:2]
        inputs = vector[2:]
        return inputs, labels
    
    def __len__(self):
        return len(self.data)
    
data = torch.tensor([
    [1.0, 2.0, 3.0, 4.0, 5.0, 0, 0, 1],
    [6.0, 7.0, 8.0, 9.0, 10.0, 0, 1, 0],
    [11.0, 12.0, 13.0, 14.0, 15.0, 0, 1, 1]
])

dataset = CellData(data)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

for input, label in dataloader:
    print("模型输入:", input)
    print("标签:", label)

模型输入: tensor([[ 8.,  9., 10.,  0.,  1.,  0.],
        [13., 14., 15.,  0.,  1.,  1.]])
标签: tensor([[ 6.,  7.],
        [11., 12.]])
模型输入: tensor([[3., 4., 5., 0., 0., 1.]])
标签: tensor([[1., 2.]])


### Model Definition

In [102]:
model = ODENet(device, data_dim, hidden_dim, time_dependent=False, non_linearity='relu')

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

### Training

In [103]:
num_epochs = 20

trainer = Trainer(model, optimizer, device)
trainer.train(dataloader, num_epochs)


Iteration 0/2
Loss: 0.959
NFE: 14
BNFE: 0
Total NFE: 14
Epoch 1: 0.588

Iteration 0/2
Loss: 0.416
NFE: 14
BNFE: 0
Total NFE: 14
Epoch 2: 0.702

Iteration 0/2
Loss: 0.790
NFE: 20
BNFE: 0
Total NFE: 20
Epoch 3: 0.464

Iteration 0/2
Loss: 0.728
NFE: 14
BNFE: 0
Total NFE: 14
Epoch 4: 0.427

Iteration 0/2
Loss: 0.452
NFE: 14
BNFE: 0
Total NFE: 14
Epoch 5: 0.491

Iteration 0/2
Loss: 0.644
NFE: 14
BNFE: 0
Total NFE: 14
Epoch 6: 0.391


  return F.smooth_l1_loss(input, target, reduction=self.reduction, beta=self.beta)



Iteration 0/2
Loss: 0.501
NFE: 14
BNFE: 0
Total NFE: 14
Epoch 7: 0.422

Iteration 0/2
Loss: 0.525
NFE: 14
BNFE: 0
Total NFE: 14
Epoch 8: 0.402

Iteration 0/2
Loss: 0.552
NFE: 14
BNFE: 0
Total NFE: 14
Epoch 9: 0.388

Iteration 0/2
Loss: 0.571
NFE: 14
BNFE: 0
Total NFE: 14
Epoch 10: 0.401

Iteration 0/2
Loss: 0.205
NFE: 14
BNFE: 0
Total NFE: 14
Epoch 11: 0.590

Iteration 0/2
Loss: 0.566
NFE: 14
BNFE: 0
Total NFE: 14
Epoch 12: 0.415

Iteration 0/2
Loss: 0.207
NFE: 14
BNFE: 0
Total NFE: 14
Epoch 13: 0.592

Iteration 0/2
Loss: 0.562
NFE: 14
BNFE: 0
Total NFE: 14
Epoch 14: 0.409

Iteration 0/2
Loss: 0.602
NFE: 14
BNFE: 0
Total NFE: 14
Epoch 15: 0.387

Iteration 0/2
Loss: 0.557
NFE: 14
BNFE: 0
Total NFE: 14
Epoch 16: 0.392

Iteration 0/2
Loss: 0.570
NFE: 14
BNFE: 0
Total NFE: 14
Epoch 17: 0.387

Iteration 0/2
Loss: 0.557
NFE: 14
BNFE: 0
Total NFE: 14
Epoch 18: 0.380

Iteration 0/2
Loss: 0.545
NFE: 14
BNFE: 0
Total NFE: 14
Epoch 19: 0.389

Iteration 0/2
Loss: 0.558
NFE: 14
BNFE: 0
Total NFE: 