In [5]:
import torch; import torch.nn as nn
import sys
sys.path.append('../')
from torchdyn.models import *

In [94]:
import torch
import torch.nn as nn

class HybridNeuralDE(nn.Module):
    def __init__(self, jump, flow, readout, s_spans=None, residual=True, reverse=False):
        super().__init__()
        self.flow, self.jump, self.readout = flow, jump, readout
        self.reverse, self.residual = reverse, residual
        self.s_spans = s_spans
        
    def forward(self, x):
        "s_spans: iterable "
        if self.s_spans:
            assert len(self.s_spans) == x.shape[0]-1
            
        h = self._init_hidden(x)
        Y = torch.zeros_like(x)
        
        if self.reverse: x_t = x_t.flip(0)
        for i, x_t in enumerate(x): 
            h = self.jump(x_t, h) + h if self.residual else self.jump(x_t, h)
            if self.s_spans: self.flow.s_span = self.s_spans[i]
            h = self.flow(h)
            Y[i] = h
        return self.readout(Y[-1])
        
    def _init_hidden(self, x):
        # determine shape of hidden `h` based on x
        if not hasattr(self, 'hidden_dim'):
            self.hidden_dim = x.shape[2:] # L, B, [...]  
        return torch.zeros((x.shape[1], *self.hidden_dim)).to(x.device) # B, [...]
    
    def hidden_trajectory(self, x, s_spans):
        h = self._init_hidden(x)
        
        # find global num mesh points across all s_spans
        mesh = sum(map(len, s_spans))
        
        Y = torch.zeros((mesh, *x.shape[1:]))
        if self.reverse: x_t = x_t.flip(0)
        for i, x_t in enumerate(x): 
            h = self.jump(x_t, h) + h if self.residual else self.jump(x_t, h)
            h = self.flow.trajectory(h, s_spans[i])
            print(Y.shape, h.shape)
            Y[i] = h
        return torch.readout(Y)
        
        
class NeuralCDE(nn.Module):
    def __init__(self):
        super().__init__()
#         pass
    
#     def forward(self, x):
#         pass

In [95]:
s1, s2 = torch.linspace(0, 1, 100), torch.linspace(0, 2, 10)

In [96]:
f = nn.Linear(2, 2)

flow = NeuralODE(f)
jump = nn.RNNCell(2, 2)
readout = nn.Linear(2, 2)

net = HybridNeuralDE(jump, flow, readout)

In [97]:
x0 = torch.randn(10, 2, 2)

In [98]:
s_spans = [torch.linspace(0, 1, i+1) for i in range(10)]

In [99]:
net.hidden_trajectory(x0, s_spans)

torch.Size([55, 2, 2]) torch.Size([1, 2, 2])


RuntimeError: Input batch size 2 doesn't match hidden batch size 1