In [1]:
from MIOFlow.models import make_model
import torch

In [2]:
# Problem sizes for this smoke test of the 'ode_growth_rate' model.
batch_size = 32
feature_dim = 10
output_dim = 10

# Seed the global torch RNG so the random input below is identical on every
# Restart & Run All. Presumably this also pins the model's initial weights
# (make_model is project code; confirm it draws from the global RNG).
SEED = 0
torch.manual_seed(SEED)

x = torch.randn(batch_size, feature_dim)
# NOTE(review): an earlier draft appended a unit-mass column by hand
# (torch.cat([x, torch.ones(batch_size, 1)], dim=1)); the model is now called
# with the plain feature matrix x.
t = torch.linspace(0, 1, 17)  # 17 evenly spaced time points on [0, 1] (step 0.0625)

model = make_model(feature_dims=feature_dim, output_dims=output_dim, which='ode_growth_rate')

In [3]:
# Run the model on the batch x over the time grid t; it returns two tensors,
# unpacked here as xt and mt.
# NOTE(review): with default arguments this is presumably the state at t[-1]
# only (no leading time axis in the recorded shapes) — confirm in make_model.
xt, mt = model(x, t)

In [4]:
# Sizes of both outputs: xt is 2-D (batch, features), mt holds one value per sample.
xt.size(), mt.size()

(torch.Size([32, 10]), torch.Size([32]))

In [5]:
# Same call, but request the whole trajectory: in the recorded run both outputs
# gain a leading axis of length 17, i.e. one slice per entry of t.
xtseq, mtseq = model(x, t, return_whole_sequence=True)

In [6]:
# Trajectory of the features: (time, batch, features).
xtseq.size()

torch.Size([17, 32, 10])

In [7]:
# Trajectory of the second output: (time, batch).
mtseq.size()

torch.Size([17, 32])

In [8]:
# Softmax across the batch, rescaled by the batch size so the weights average to 1.
torch.softmax(mt, dim=-1) * mt.size(-1)

tensor([0.9443, 1.0173, 0.7501, 0.5969, 1.0634, 0.7985, 1.0667, 1.3186, 0.9306,
        1.1925, 1.0128, 1.3801, 0.9110, 0.8183, 1.1681, 0.9912, 1.1354, 1.1225,
        1.2475, 0.7734, 1.0399, 0.7493, 0.8734, 1.1611, 0.9538, 0.7815, 0.9653,
        1.1884, 1.0213, 0.9292, 1.0074, 1.0902], grad_fn=<MulBackward0>)

In [9]:
# Hand-written softmax over axis 0 (no rescaling): exp(mt) divided by its sum.
# Kept as an explicit formula so it can be compared with the library softmax.
mt.exp() / mt.exp().sum(dim=0)

tensor([0.0295, 0.0318, 0.0234, 0.0187, 0.0332, 0.0250, 0.0333, 0.0412, 0.0291,
        0.0373, 0.0316, 0.0431, 0.0285, 0.0256, 0.0365, 0.0310, 0.0355, 0.0351,
        0.0390, 0.0242, 0.0325, 0.0234, 0.0273, 0.0363, 0.0298, 0.0244, 0.0302,
        0.0371, 0.0319, 0.0290, 0.0315, 0.0341], grad_fn=<DivBackward0>)

In [10]:
# Rescaled softmax applied independently at every time step (rows = time,
# columns = batch); each row averages to 1.
torch.softmax(mtseq, dim=-1) * mtseq.size(-1)

tensor([[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
        [0.9943, 1.0019, 0.9835, 0.9709, 1.0050, 0.9877, 1.0044, 1.0180, 0.9963,
         1.0132, 1.0026, 1.0214, 0.9941, 0.9887, 1.0112, 0.9984, 1.0074, 1.0069,
         1.0149, 0.9871, 1.0027, 0.9836, 0.9919, 1.0141, 0.9986, 0.9865, 1.0007,
         1.0101, 1.0015, 0.9941, 1.0016, 1.0068],
        [0.9891, 1.0036, 0.9671, 0.9421, 1.0098, 0.9753, 1.0088, 1.0363, 0.9925,
         1.0264, 1.0050, 1.0430, 0.9883, 0.9773, 1.0223, 0.9969, 1.0150, 1.0138,
         1.0300, 0.9739, 1.0053, 0.9672, 0.9838, 1.0278, 0.9970, 0.9730, 1.0011,
         1.0204, 1.0029, 0.9885, 1.0031, 1.0134],
        [0.9842, 1.0052, 0.9508, 0.9138, 1.0145, 0.9628, 1.0131, 1.0548, 0.9886,
         1.0394, 1.0072, 1.0650, 0.9826,

In [22]:
# Re-check the feature trajectory size before concatenating: (time, batch, features).
xtseq.size()

torch.Size([17, 32, 10])

In [26]:
# Add a trailing singleton axis to the 2-D mtseq so it can be concatenated with xtseq.
mtseq[:, :, None].size()

torch.Size([17, 32, 1])

In [27]:
# Append mtseq as one extra trailing channel of the feature trajectory.
xmseq = torch.cat((xtseq, mtseq[:, :, None]), dim=-1)
xmseq.size()

torch.Size([17, 32, 11])

In [53]:
# Evaluate model.func at every time point on the matching trajectory slice and
# stack the results along a new leading (time) axis.
dxdts = torch.stack([model.func(t_k, xmseq[k]) for k, t_k in enumerate(t)])

In [55]:
# Stacked model.func outputs: same (time, batch, channels) layout as xmseq.
dxdts.size()

torch.Size([17, 32, 11])

In [19]:
# Turn the 1-D mt into a column so it can sit next to the 2-D xt.
mt[:, None].size()

torch.Size([32, 1])

In [21]:
# Display the module stored on model.func (the callable other cells invoke as
# model.func(time, state)).
# NOTE(review): the recorded repr lists 14 input features while the state passed
# in has 11 columns; the extra 3 are presumably time-derived inputs — confirm in ToyODE.
model.func

ToyODE(
  (seq): Sequential(
    (0): Linear(in_features=14, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=11, bias=True)
  )
)

In [30]:
# Final-call feature output: (batch, features).
xt.size()

torch.Size([32, 10])

In [32]:
# Final-call second output: one value per sample.
mt.size()

torch.Size([32])

In [39]:
# Glue mt onto xt as a trailing column, then evaluate model.func at the last
# time point of the grid.
xm = torch.cat((xt, mt[:, None]), dim=-1)
dxdt = model.func(t[-1], xm)

In [40]:
# Slice off the last output channel, leaving the leading feature columns.
# NOTE(review): the dropped channel is presumably the derivative of the mt term
# (mt was concatenated last when building xm) — confirm.
dxdt[...,:-1].shape

torch.Size([32, 10])

In [41]:
# (time, batch) size of the mt trajectory.
mtseq.size()

torch.Size([17, 32])

In [43]:
# (time, batch, features) size of the feature trajectory.
xtseq.size()

torch.Size([17, 32, 10])

In [45]:
# The time grid is 1-D with one entry per trajectory slice.
t.size()

torch.Size([17])

In [46]:
# Pick the first three time points with a list index; t is 1-D, so the trailing
# Ellipsis selects no further axes.
t[[0,1,2],...]

tensor([0.0000, 0.0625, 0.1250])

In [47]:
# One mt value per sample; checked before broadcasting against dxdt.
mt.size()

torch.Size([32])

In [52]:
# Scale each sample's row of dxdt by its mt value (column broadcast across channels).
(dxdt * mt[:, None]).size()

torch.Size([32, 11])

In [56]:
# (time, batch, channels) size of the stacked model.func outputs.
dxdts.size()

torch.Size([17, 32, 11])

In [57]:
# (time, batch); checked before broadcasting against dxdts.
mtseq.size()

torch.Size([17, 32])

In [58]:
# Whole-trajectory version of the per-sample scaling: mtseq gets a trailing
# singleton axis and broadcasts across the channels of dxdts.
(dxdts * mtseq[:, :, None]).size()

torch.Size([17, 32, 11])

In [59]:
# Softmax over the batch axis keeps the (time, batch) layout.
torch.softmax(mtseq, dim=-1).size()

torch.Size([17, 32])

In [60]:
# Length of the last axis of mtseq, i.e. the batch size used as the rescaling factor.
mtseq.size(-1)

32

In [61]:
# Softmax of the 1-D mt keeps one weight per sample.
torch.softmax(mt, dim=-1).size()

torch.Size([32])

In [62]:
# Length of mt's only axis, i.e. the batch size.
mt.size(-1)

32