Oh, I guess have to generalize with Bayes Net.
Implement missing indicators as Bernoulli nodes, manually set P to 0 or 1.

In [2]:
import numpy as np
import torch
from pomegranate.bayesian_network import BayesianNetwork
from pomegranate.distributions import Normal, Bernoulli

### Set up the example

In [3]:
N_TRACKS = 3
# e.g. 0 = dead/quiescent, 1 = promoter, 2 = other actives
states = [0, 1, 2]

Generate some sample observation data assuming each state's emission ~ Normal.

Although in principle, a separate emission distribution for every state-track pair, just use a multivariate Normal, of <# of tracks> dimensions, per state.

In [4]:
MEANS = (0.3, 10.0, 6.0)
STDEVS = (0.05, 1.0, 0.5)
true_emitters = []
for mean, stdev in zip(MEANS, STDEVS):
    multivar_mean = torch.full([N_TRACKS], mean)
    multivar_stdev = torch.full([N_TRACKS], stdev)
    true_emitters.append(torch.distributions.Normal(multivar_mean, multivar_stdev))
true_emitters

[Normal(loc: torch.Size([3]), scale: torch.Size([3])),
 Normal(loc: torch.Size([3]), scale: torch.Size([3])),
 Normal(loc: torch.Size([3]), scale: torch.Size([3]))]

In [5]:
for i, emitter in enumerate(true_emitters):
    print(f"Emitter {i}:\n\t{emitter.loc}\n\t{emitter.scale}")

Emitter 0:
	tensor([0.3000, 0.3000, 0.3000])
	tensor([0.0500, 0.0500, 0.0500])
Emitter 1:
	tensor([10., 10., 10.])
	tensor([1., 1., 1.])
Emitter 2:
	tensor([6., 6., 6.])
	tensor([0.5000, 0.5000, 0.5000])


In [6]:
TRUE_START_PROBS = [0.8, 0.1, 0.1]
sum(TRUE_START_PROBS)

1.0

In [7]:
TRUE_TRANS_PROBS = torch.tensor([
    [0.8, 0.1, 0.1],
    [0.1, 0.6, 0.3],
    [0.3, 0.1, 0.6]
])

In [8]:
SEQ_LEN = 20
gen_seq = np.random.choice(states, size=SEQ_LEN, p=TRUE_START_PROBS)
gen_seq

array([1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 2, 0])

Transpose into column vector.

In [9]:
true_emitters[2].sample().reshape((-1, 1))

tensor([[5.5303],
        [5.7641],
        [6.8711]])

In [10]:
observations = torch.stack([true_emitters[st_id].sample() for st_id in gen_seq])
observations

tensor([[9.1227, 9.2222, 9.9002],
        [5.9096, 7.6977, 6.5035],
        [0.2859, 0.2903, 0.2718],
        [0.2971, 0.2673, 0.2488],
        [0.4261, 0.2477, 0.4275],
        [0.3762, 0.3033, 0.3200],
        [0.2984, 0.2729, 0.2782],
        [0.3618, 0.3337, 0.3212],
        [0.3731, 0.3245, 0.3022],
        [0.3222, 0.2915, 0.2834],
        [5.9165, 6.4204, 5.6979],
        [0.2526, 0.3527, 0.4575],
        [0.2845, 0.2894, 0.2687],
        [0.2558, 0.3178, 0.2374],
        [5.5680, 6.1009, 5.5575],
        [0.2595, 0.3433, 0.2973],
        [0.3392, 0.3478, 0.2443],
        [0.2927, 0.2308, 0.2953],
        [5.5657, 6.3600, 6.7604],
        [0.3359, 0.2903, 0.2822]])

# Initialize model