In [1]:
from MIOFlow.models import make_model
import torch

In [2]:
# Configuration. Seeding before any random call makes both the random inputs
# and the default nn.Linear parameter initialisation of the model reproducible
# under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

batch_size = 4
feature_dim = 5
# NOTE(review): output_dim is not passed to make_model below, so it has no
# effect on the model — confirm whether it is needed.
output_dim = 5
n_timepoints = 7

# Random initial states, one row per sample: (batch_size, feature_dim).
x = torch.randn(batch_size, feature_dim)
# Evenly spaced time grid on [0, 1] used for every integration below.
t = torch.linspace(0, 1, n_timepoints)

model = make_model(feature_dims=feature_dim, which='sde_growth_rate', method='euler')

In [3]:
# NOTE(review): output_dim is not passed to make_model in the setup cell, so it
# currently has no effect on the model — confirm whether it is needed.
output_dim

5

In [4]:
# Inspect the model's `gunc` network. NOTE(review): for which='sde_growth_rate'
# this is presumably the SDE diffusion term (with model.func the drift) —
# confirm in MIOFlow.models.
model.gunc

ToyODE(
  (seq): Sequential(
    (0): Linear(in_features=9, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=6, bias=True)
  )
)

In [5]:
# Integrate from x over the time grid t. Per the shapes shown below, xt is
# (batch_size, feature_dim) and mt has one value per sample, with no time axis
# (presumably the values at t[-1] — confirm).
xt, mt = model(x, t)

In [6]:
# Without return_whole_sequence there is no time axis: one state row and one
# m value per sample.
assert xt.shape == (batch_size, feature_dim)
assert mt.shape == (batch_size,)
xt.shape, mt.shape

(torch.Size([4, 5]), torch.Size([4]))

In [7]:
# Integrate again, keeping every time point: outputs gain a leading axis of
# length len(t) (see the shapes shown below).
# NOTE(review): this is a second, independent call. The recorded softmax outputs
# below show mt != mtseq[-1], so the model is stochastic and xt/mt and
# xtseq/mtseq come from different sample paths.
xtseq, mtseq = model(x, t, return_whole_sequence=True)

In [8]:
# Whole-sequence states: (len(t), batch_size, feature_dim).
assert xtseq.shape == (len(t), batch_size, feature_dim)
xtseq.shape

torch.Size([7, 4, 5])

In [9]:
# Whole-sequence m values: (len(t), batch_size).
assert mtseq.shape == (len(t), batch_size)
mtseq.shape

torch.Size([7, 4])

In [10]:
# Softmax over the batch, rescaled so the weights average to 1 across samples.
mt.softmax(dim=-1) * mt.size(-1)

tensor([1.1376, 1.1493, 0.7103, 1.0029], grad_fn=<MulBackward0>)

In [11]:
# Normalised weights over the batch (sum to 1); equal to the previous cell
# divided by the batch size. torch.softmax shifts by the max internally, so,
# unlike an explicit exp(mt) / exp(mt).sum(), it cannot overflow to inf/nan
# for large mt.
torch.softmax(mt, dim=0)

tensor([0.2844, 0.2873, 0.1776, 0.2507], grad_fn=<DivBackward0>)

In [12]:
# Per-time-point softmax over samples, rescaled to average 1 in each row.
mtseq.softmax(dim=-1) * mtseq.size(-1)

tensor([[1.0000, 1.0000, 1.0000, 1.0000],
        [0.9930, 0.9549, 1.0697, 0.9824],
        [1.0720, 1.0270, 0.8799, 1.0211],
        [1.1404, 1.0398, 0.7630, 1.0568],
        [1.1235, 1.1386, 0.6304, 1.1075],
        [1.0831, 1.1595, 0.6577, 1.0997],
        [1.1062, 1.1357, 0.6268, 1.1312]], grad_fn=<MulBackward0>)

In [13]:
xtseq.size()

torch.Size([7, 4, 5])

In [14]:
# Trailing singleton axis so mtseq can be concatenated with xtseq.
mtseq[..., None].size()

torch.Size([7, 4, 1])

In [15]:
# Append the m values as a trailing feature column:
# (len(t), batch_size, feature_dim + 1).
xmseq = torch.cat((xtseq, mtseq[..., None]), dim=-1)
xmseq.size()

torch.Size([7, 4, 6])

In [16]:
# Evaluate model.func at each saved time point and stack the results along a
# new leading time axis.
dxdts = torch.stack(
    [model.func(t[step], xmseq[step]) for step in range(t.shape[0])]
)

In [17]:
# One derivative per time point and sample, over the feature_dim + 1 columns of xmseq.
assert dxdts.shape == (len(t), batch_size, feature_dim + 1)
dxdts.shape

torch.Size([7, 4, 6])

In [18]:
mt[:, None].size()

torch.Size([4, 1])

In [19]:
# Inspect the network that the cells below evaluate directly as model.func(t, xm).
model.func

ToyODE(
  (seq): Sequential(
    (0): Linear(in_features=9, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=6, bias=True)
  )
)

In [20]:
xt.size()

torch.Size([4, 5])

In [21]:
mt.size()

torch.Size([4])

In [22]:
# Final state with the m values appended as the last column, then model.func
# evaluated on it at the last time point.
xm = torch.cat((xt, mt[..., None]), dim=-1)
dxdt = model.func(t[-1], xm)

In [23]:
# Drop the last column (the one aligned with the appended m channel).
dxdt[..., :-1].size()

torch.Size([4, 5])

In [24]:
mtseq.size()

torch.Size([7, 4])

In [25]:
xtseq.size()

torch.Size([7, 4, 5])

In [26]:
t.size()

torch.Size([7])

In [27]:
# First three time points of the grid.
t[:3]

tensor([0.0000, 0.1667, 0.3333])

In [28]:
mt.size()

torch.Size([4])

In [29]:
# Scale each sample's row of dxdt by its m value (broadcast over columns).
torch.mul(dxdt, mt[..., None]).size()

torch.Size([4, 6])

In [30]:
dxdts.size()

torch.Size([7, 4, 6])

In [31]:
mtseq.size()

torch.Size([7, 4])

In [32]:
# Same per-sample scaling as above, applied at every time point.
torch.mul(dxdts, mtseq[..., None]).size()

torch.Size([7, 4, 6])

In [33]:
mtseq.softmax(dim=-1).size()

torch.Size([7, 4])

In [34]:
mtseq.size(-1)

4

In [35]:
mt.softmax(dim=-1).size()

torch.Size([4])

In [36]:
mt.size(-1)

4

In [37]:
def func_end_0(t, x):
    """Return ``model.func`` evaluated at time ``t`` on state ``x``.

    ``x`` is presumably laid out like ``xm`` (features with the m values
    appended as the last column) — confirm.

    TODO(review): this helper is not called anywhere in this notebook, and its
    name suggests an intended variant of ``model.func`` (e.g. with some entries
    forced to zero). Confirm the intent and implement it, or delete the helper.
    """
    return model.func(t, x)


# Shape check: model.func at the final time point keeps xm's
# (batch_size, feature_dim + 1) layout.
model.func(t[-1], xm).shape

torch.Size([4, 6])

In [38]:
# Shape of the appended m column alone.
xm[..., -1].size()

torch.Size([4])

In [39]:
xm.size()

torch.Size([4, 6])

In [40]:
xmseq.size()

torch.Size([7, 4, 6])

In [41]:
# Zero the m value (last column of xm) for selected samples to see how
# model.func responds. xm is rebuilt out-of-place from xt and mt, exactly as in
# the cell that first defined it. The result therefore no longer depends on
# xm's current state. It also avoids an in-place write on a tensor that earlier
# autograd results (dxdt) were computed from.
ZEROED_SAMPLES = [0, 2]
zero_m = torch.zeros(mt.shape[0], dtype=torch.bool)
zero_m[ZEROED_SAMPLES] = True
mt_zeroed = torch.where(zero_m, torch.zeros_like(mt), mt)
xm = torch.cat([xt, mt_zeroed.unsqueeze(-1)], dim=-1)

In [42]:
# Show xm after the last (m) column of samples 0 and 2 was set to zero.
xm

tensor([[-0.4173, -0.0048, -0.4574, -1.4883,  1.0279,  0.0000],
        [-0.3578, -0.1944,  1.4519, -0.2590,  0.7232,  0.1228],
        [ 2.6331, -0.0379,  1.4302,  3.1891,  2.3535,  0.0000],
        [-0.7105, -0.5907, -0.3341, -0.3830,  0.7719, -0.0135]],
       grad_fn=<CopySlices>)

In [43]:
# Re-evaluate model.func at the final time point on the modified xm.
dxmdt = model.func(t[-1], xm)

In [44]:
# In the recorded output, the rows whose m value was zeroed (0 and 2) also have
# a zero last column; the other rows do not.
dxmdt

tensor([[-0.3312,  0.2495, -0.0454,  0.2732,  0.2381,  0.0000],
        [-0.1854,  0.1795,  0.4145,  0.0551,  0.0860, -0.0133],
        [ 0.2033,  0.2764,  0.9451,  0.1849,  0.8203,  0.0000],
        [-0.1668,  0.1031,  0.0243,  0.0821,  0.1631,  0.0705]],
       grad_fn=<CopySlices>)