In [1]:
from MIOFlow.models import make_model
from MIOFlow.losses import MMD_loss, OT_loss, Density_loss, Local_density_loss, density_specified_OT_loss, EnergyLoss, EnergyLossGrowthRate, EnergyLossSeq, EnergyLossGrowthRateSeq
from MIOFlow.utils import generate_steps, set_seeds, config_criterion
import torch

In [2]:
batch_size = 5
feature_dim = 3
output_dim = 3

x = torch.randn(batch_size, feature_dim)
m0 = torch.ones(batch_size, 1)
m0[[0,3],:] = 0.
xm = torch.cat([x, m0], dim=1)
t = torch.linspace(0, 1, 7)

model = make_model(feature_dims=feature_dim, output_dims=output_dim, which='ode_growth_rate')
energy_weighted = True
energy_detach_m = True
energy_loss_growth_rate = EnergyLossGrowthRate(weighted=energy_weighted, detach_m=energy_detach_m)
energy_loss_growth_rate_seq = EnergyLossGrowthRateSeq(weighted=energy_weighted, detach_m=energy_detach_m)

criterion_name = 'density_specified_ot'
criterion = config_criterion(criterion_name)

hinge_value = 0.01
density_fn = Density_loss(hinge_value)

In [3]:
dxmdt = model.func(t[-1], xm)

In [4]:
xt, mt = model(x, t)

In [5]:
xt

tensor([[-0.4248,  0.2344, -0.2809],
        [ 0.8790,  0.2571, -1.0647],
        [-0.7393,  0.2159, -0.0567],
        [-1.0216, -0.1463,  0.4386],
        [-0.2821,  0.5557,  0.0513]], grad_fn=<SliceBackward0>)

In [6]:
mt = mt.detach()
# mt[[0,3]] = 0.
mt = torch.zeros_like(mt)
mt.requires_grad = True

In [7]:
xtseq, mtseq = model(x, t, return_whole_sequence=True)

In [8]:
otl = criterion(xt, x, mt)
dl = density_fn(xt, x, pre_softmax_weights=mt.detach(), top_k=2)
eloss, emloss = energy_loss_growth_rate(model, xt, mt.detach(), t[-1])
ml = (torch.square(mt.mean(axis=-1) - model.m_init)).mean() 
ml2 = (torch.square(mt - model.m_init)).mean()
loss = otl + dl + eloss + emloss + ml + ml2
print(f"OT loss: {otl}, Density loss: {dl}, Energy loss: {eloss}, Energy loss growth rate: {emloss}, Mean loss: {ml}, Mean loss 2: {ml2}")

OT loss: 0.1780138611793518, Density loss: 0.41248807311058044, Energy loss: 0.03557818382978439, Energy loss growth rate: 0.0, Mean loss: 1.0, Mean loss 2: 1.0


In [9]:
loss.backward()

In [10]:
mt.grad

tensor([-0.8000, -0.8000, -0.8000, -0.8000, -0.8000])

In [11]:
for name, param in model.named_parameters():
    if param.requires_grad:
        print(f"Gradient for {name}: {param.grad.max()}")

Gradient for func.seq.0.weight: 0.14733898639678955
Gradient for func.seq.0.bias: 0.14599904417991638
Gradient for func.seq.2.weight: 0.24677564203739166
Gradient for func.seq.2.bias: 0.4118116497993469


In [12]:
x

tensor([[-0.3208,  0.2013, -0.1972],
        [ 1.0577,  0.2151, -1.0301],
        [-0.6896,  0.1817,  0.0713],
        [-1.0455, -0.2010,  0.6102],
        [-0.1015,  0.5418,  0.0678]])

In [13]:
torch.cdist(xt, x) * mt

tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]], grad_fn=<MulBackward0>)

In [14]:
mt.shape

torch.Size([5])

In [15]:
mt.unsqueeze(0).shape

torch.Size([1, 5])

In [16]:
torch.cdist(xt[:3,:], x) * mt[:3].unsqueeze(-1)

tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]], grad_fn=<MulBackward0>)

In [18]:
mt.mean(axis=-1)

tensor(0., grad_fn=<MeanBackward1>)

In [19]:
mtseq.shape

torch.Size([7, 5])

In [26]:
threshold = mtseq.mean(dim=-1, keepdim=True)

In [28]:
torch.where(mtseq <= threshold, torch.tensor(0., dtype=mtseq.dtype, device=mtseq.device), mtseq)

tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [1.0003, 1.0025, 0.0000, 0.0000, 0.0000],
        [1.0018, 1.0044, 0.0000, 0.0000, 0.0000],
        [1.0033, 1.0061, 0.0000, 0.0000, 0.0000],
        [1.0045, 1.0073, 0.0000, 0.0000, 0.0000],
        [1.0053, 1.0083, 0.9915, 0.0000, 0.9914],
        [1.0062, 1.0098, 0.9911, 0.0000, 0.9912]], grad_fn=<WhereBackward0>)

In [32]:
mtseq[mtseq > threshold]

tensor([1.0003, 1.0025, 1.0018, 1.0044, 1.0033, 1.0061, 1.0045, 1.0073, 1.0053,
        1.0083, 0.9915, 0.9914, 1.0062, 1.0098, 0.9911, 0.9912],
       grad_fn=<IndexBackward0>)

In [33]:
mtseq

tensor([[1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
        [1.0003, 1.0025, 0.9970, 0.9904, 0.9964],
        [1.0018, 1.0044, 0.9946, 0.9818, 0.9944],
        [1.0033, 1.0061, 0.9929, 0.9738, 0.9931],
        [1.0045, 1.0073, 0.9920, 0.9664, 0.9921],
        [1.0053, 1.0083, 0.9915, 0.9594, 0.9914],
        [1.0062, 1.0098, 0.9911, 0.9529, 0.9912]], grad_fn=<ReluBackward0>)

In [34]:
threshold

tensor([[1.0000],
        [0.9973],
        [0.9954],
        [0.9938],
        [0.9924],
        [0.9912],
        [0.9902]], grad_fn=<MeanBackward1>)

In [36]:
mtseq > threshold

tensor([[False, False, False, False, False],
        [ True,  True, False, False, False],
        [ True,  True, False, False, False],
        [ True,  True, False, False, False],
        [ True,  True, False, False, False],
        [ True,  True,  True, False,  True],
        [ True,  True,  True, False,  True]])

In [37]:
xt.shape

torch.Size([5, 3])

In [38]:
mt.shape

torch.Size([5])

In [51]:
mt1 = torch.randn_like(mt)
mask = mt1 > (-1 * mt1.mean(dim=-1, keepdim=True))

In [53]:
mt1[mask]

tensor([0.2280, 1.1005])

In [55]:
xt[mask]

tensor([[ 0.8790,  0.2571, -1.0647],
        [-0.2821,  0.5557,  0.0513]], grad_fn=<IndexBackward0>)

In [59]:
tuple = (xt, xt)

In [60]:
fun = lambda a,b,c: a+b+c

In [61]:
fun(*tuple, 1)

tensor([[ 0.1504,  1.4688,  0.4382],
        [ 2.7580,  1.5142, -1.1294],
        [-0.4785,  1.4318,  0.8866],
        [-1.0432,  0.7075,  1.8772],
        [ 0.4359,  2.1113,  1.1025]], grad_fn=<AddBackward0>)