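## Forward Euler integrator built from ResNet-style residual blocks

A residual update $x_{k+1} = x_k + h\,f(x_k)$ is exactly one forward-Euler step of the ODE $\dot{x} = f(x)$ with step size $h$, so stacking $n$ such updates integrates the ODE for $n$ steps.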

In [1]:
import torch
import torch.nn as nn

In [14]:
class EulerBlock(nn.Module):
    """Two-layer perceptron ``Linear -> ReLU -> Linear`` with equal in/out width.

    Inputs whose last dimension is ``dim`` are mapped to outputs of the same
    size, so the result can be added back onto the input in a residual update.

    Args:
        dim: Input and output width of both linear layers.
    """

    def __init__(self, dim):
        super().__init__()
        # Registration order (l1, l2, relu) fixes parameter order, the
        # state_dict keys and the RNG draw order at init, so keep it as-is.
        self.l1 = nn.Linear(dim, dim)
        self.l2 = nn.Linear(dim, dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``l2(relu(l1(x)))``."""
        hidden = self.relu(self.l1(x))
        return self.l2(hidden)

In [17]:
class EulerIntegrator(nn.Module):
    """Fixed-step forward-Euler integrator whose update is a learned ``EulerBlock``.

    Applies ``n`` steps of ``x <- x + h * f(x)``. ``f`` is a single
    ``EulerBlock`` instance reused at every step, so this is a ResNet of ``n``
    weight-tied residual blocks.

    Args:
        dim: Feature size of the state; passed to ``EulerBlock``.
        n: Number of Euler steps to take (``0`` returns the input unchanged).
        step: Step size ``h`` that scales each residual update.
    """

    def __init__(self, dim, n, step):
        super().__init__()

        # One block shared by all steps (weight tying across "time").
        self.block = EulerBlock(dim)
        self.n = n
        self.h = step

    def forward(self, x):
        """Integrate ``x`` forward by ``n`` Euler steps and return the final state.

        Args:
            x: Input state. Its last dimension must equal ``dim``, as required
                by the ``nn.Linear`` layers inside ``EulerBlock``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        for _ in range(self.n):
            # Forward Euler: the step size must *scale* the update f(x);
            # adding h would just shift x by a constant each step.
            x = x + self.h * self.block(x)

        return x

In [18]:
# A batch of 32 random states, each with 64 features.
x = torch.randn((32, 64))

# dim=64, n=5 Euler steps, step size 0.05.
model = EulerIntegrator(64, 5, 0.05)
out = model(x)
print(out.shape)  # expected: torch.Size([32, 64])

torch.Size([32, 64])
