In [95]:
import torch
import pyro
from pyro.distributions import *
from collections import Counter
import pyro.infer
import pyro.optim

In [77]:
#nchords = 10
npcs = 4*7+1 # around C: Cbb to C## on LoF
def chord_model(npcs, nharmonies, nchords, pobserve=0.5):
    # parameters priors:
    # distribution of the harmonies
    p_harmony = pyro.sample('p_harmony', Dirichlet(0.5 * torch.ones(nharmonies)))
    # distribution of notes in the harmonies
    with pyro.plate('harmonies', nharmonies):
        p_chordtones = pyro.sample('p_chordtones', Dirichlet(0.5 * torch.ones(npcs)))
        p_ornaments = pyro.sample('p_ornaments', Dirichlet(0.5 * torch.ones(npcs)))
    # distribution of note rate in chords
    rate_notes = pyro.sample('rate_notes', Gamma(3,1))
    # distribution of ornament probability
    p_is_chordtone = pyro.sample('p_is_chordtone', Beta(1, 1))
    
    # sampling the data:
    chords = list()
    for c in pyro.plate('data', nchords):
        # pick a harmony
        h = pyro.sample('h_{}'.format(c), Categorical(p_harmony))
        # pick a number of notes
        nnotes = 1 + pyro.sample('n_{}'.format(c), Poisson(rate_notes))
        # collect the notes
        notes = torch.zeros((npcs, 3))
        for i in pyro.plate('notes_{}'.format(c), nnotes):
            # pick a note type
            t = pyro.sample('t_{}_{}'.format(c,i), Bernoulli(p_is_chordtone))
            # pick a pitch from the distribution corresponding to the note type
            ps = p_chordtones if t.item() == 1. else p_ornaments
            pitch = pyro.sample('pitch_{}_{}'.format(c,i), Categorical(ps[h]))
            # decide if the type is observed or not
            observe_type = pyro.sample('observe_type_{}_{}'.format(c,i), Bernoulli(pobserve))
            ot = int(t.item()) if observe_type.item() == 1. else 2
            # add pitch and type to chord
            notes[pitch.item(), ot] += 1
        # observe the chord
        counts = pyro.deterministic('chord_{}'.format(c), notes)
        chords.append({'h': h, 'counts': counts})
    return chords

In [103]:
def chord_guide(npcs, nharmonies, nchords, pobserve=0.5):
    # posterior of p_harmony
    params_p_harmony = pyro.param('params_p_harmony', 0.5 * torch.ones(nharmonies),
                                  constraint=constraints.positive)
    pyro.sample('p_harmony', Dirichlet(params_p_harmony))
    # posteriors of notes dists in harmonies
    with pyro.plate('harmonies', nharmonies):
        params_p_chordtones = pyro.param('params_p_chordtones', 0.5 * torch.ones(npcs),
                                         constraint=constraints.positive)
        pyro.sample('p_chordtones', Dirichlet(params_p_chordtones))
        params_p_ornaments = pyro.param('params_p_ornaments', 0.5 * torch.ones(npcs),
                                         constraint=constraints.positive)
        pyro.sample('p_ornaments', Dirichlet(params_p_ornaments))
    # posterior of note rate
    alpha_rate_notes = pyro.param('alpha_rate_notes', torch.tensor(3),
                                  constraint=constraints.positive)
    beta_rate_notes = pyro.param('beta_rate_notes', torch.tensor(1),
                                 constraint=constraints.positive)
    rate_notes = pyro.sample('rate_notes', Gamma(alpha_rate_notes, beta_rate_notes))
    # posterior of ornament probability
    alpha_p_ict = pyro.param('alpha_p_ict', torch.tensor(1), constraint=constraints.positive)
    beta_p_ict = pyro.param('beta_p_ict', torch.tensor(1), constraint=constraints.positive)
    pyro.sample('p_is_chordtone', Beta(alpha_p_ict, beta_p_ict))
    # data points
    for c in pyro.plate('data', nchords):
        # posterior of chosen harmony
        params_h = pyro.param('params_h_{}'.format(c), torch.ones(nharmonies)/nharmonies,
                              constraint=constraints.simplex)
        pyro.sample('h_{}'.format(c), Categorical(params_h))
        # the number of notes needs to be conditioned on
        # (otherwise the number of variables in the next plate is incorrect)
        nnotes = pyro.sample('n_{}'.format(c), Poisson(rate_notes)) + 1
        for i in pyro.plate('notes_{}'.format(c), nnotes):
            # posterior of each note type
            p_t = pyro.param('p_t_{}_{}'.format(c,i), torch.tensor(0.5), constraint=constraints.simplex)
            pyro.sample('t_{}_{}'.format(c,i), Bernoulli(p_t))
            # posterior of each pitch
            params_pitch = pyro.param('params_pitch_{}_{}', torch.ones(npcs), constraint=constraints.simplex)
            pitch = pyro.sample('pitch_{}_{}'.format(c,i), Categorical(params_pitch))
            # posterior of each observation coin
            p_obs = pyro.param('p_obs_{}_{}'.format(c,i), torch.tensor(0.5), constraint=constraints.simplex)
            pyro.sample('observe_type_{}_{}'.format(c,i), Bernoulli(p_obs))

In [87]:
pyro.clear_param_store()
chord_guide(npcs, 3, 5)

tensor([0.5000, 0.5000, 0.5000], grad_fn=<AddBackward0>)


In [28]:
chord_model(npcs, 3, 5)

[{'h': tensor(0),
  'counts': tensor([[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 1.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 1., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 1.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]])},
 {'h': tensor(2),
  'counts': tensor([[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
     

In [20]:
def chord_tensor(notes):
    notetype = {'chordtone': 0, 'ornament': 1, 'unknown': 2}
    chord = torch.zeros((npcs, 3))
    for (pitch, t) in notes:
        chord[pitch%npcs, notetype[t]] += 1
    return chord

def annot_data_obs(chords):
    obs = {}
    for (i, chord) in enumerate(chords):
        obs["h_{}".format(i)] = chord["label"]
        obs["n_{}".format(i)] = len(chord["notes"]) - 1
        obs["chord_{}".format(i)] = chord_tensor(chord["notes"])
    return obs

In [21]:
example_chords = [
    {'label': 0, 'notes': [(0,'chordtone'), (4,'chordtone'), (1,'chordtone'), (4,'ornament')]},
    {'label': 1, 'notes': [(0,'chordtone'), (-3,'chordtone'), (1,'chordtone'), (-2,'unknown')]}
]
example_obs = annot_data_obs(example_chords)
example_obs

{'h_0': 0,
 'n_0': 3,
 'chord_0': tensor([[1., 0., 0.],
         [1., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [1., 1., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]]),
 'h_1': 1,
 'n_1': 3,
 'chord_1': tensor([[1., 0., 0.],
         [1., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0.

In [90]:
conditioned_model = pyro.condition(chord_model, data=example_obs)

In [91]:
conditioned_guide = pyro.condition(chord_guide, data=example_obs)

In [102]:
# inference
pyro.clear_param_store()
svi = pyro.infer.SVI(model=conditioned_model,
                     guide=conditioned_guide,
                     optim=pyro.optim.SGD({"lr": 0.001, "momentum": 0.1}),
                     loss=pyro.infer.Trace_ELBO())

losses = []
nsteps = 100
for i in range(nsteps):
    losses.append(svi.step(npcs, 3, 5))

plt.plot(losses)
plt.xlabel("step")
plt.ylabel("loss")

RuntimeError: Multiple sample sites named 't_{}_{}'
        Trace Shapes:        
         Param Sites:        
     params_p_harmony    3   
  params_p_chordtones   29   
   params_p_ornaments   29   
     alpha_rate_notes        
      beta_rate_notes        
          alpha_p_ict        
           beta_p_ict        
           params_h_0    3   
              p_t_0_0        
   params_pitch_{}_{}   29   
            p_obs_0_0        
              p_t_0_1        
        Sample Sites:        
       p_harmony dist    |  3
                value    |  3
       harmonies dist    |   
                value 3  |   
    p_chordtones dist 3  | 29
                value 3  | 29
     p_ornaments dist 3  | 29
                value 3  | 29
      rate_notes dist    |   
                value    |   
  p_is_chordtone dist    |   
                value    |   
            data dist    |   
                value 5  |   
             h_0 dist    |   
                value    |   
             n_0 dist    |   
                value    |   
         notes_0 dist    |   
                value 4  |   
         t_{}_{} dist    |   
                value    |   
       pitch_0_0 dist    |   
                value    |   
observe_type_0_0 dist    |   
                value    |   