In [20]:
import random
from copy import copy

import torch
from pyro.distributions import constraints
from pyro.infer import config_enumerate
from pyroapi import pyro
import pyro.distributions as dist
from tqdm import trange

pyro.clear_param_store()

In [21]:
cards = [(suite, face) for suite in ("♥", "♠", "♦", "♣") for face in
         (list(map(str, range(2, 11))) + ["J", "Q", "K", "A"])]
cards

[('♥', '2'),
 ('♥', '3'),
 ('♥', '4'),
 ('♥', '5'),
 ('♥', '6'),
 ('♥', '7'),
 ('♥', '8'),
 ('♥', '9'),
 ('♥', '10'),
 ('♥', 'J'),
 ('♥', 'Q'),
 ('♥', 'K'),
 ('♥', 'A'),
 ('♠', '2'),
 ('♠', '3'),
 ('♠', '4'),
 ('♠', '5'),
 ('♠', '6'),
 ('♠', '7'),
 ('♠', '8'),
 ('♠', '9'),
 ('♠', '10'),
 ('♠', 'J'),
 ('♠', 'Q'),
 ('♠', 'K'),
 ('♠', 'A'),
 ('♦', '2'),
 ('♦', '3'),
 ('♦', '4'),
 ('♦', '5'),
 ('♦', '6'),
 ('♦', '7'),
 ('♦', '8'),
 ('♦', '9'),
 ('♦', '10'),
 ('♦', 'J'),
 ('♦', 'Q'),
 ('♦', 'K'),
 ('♦', 'A'),
 ('♣', '2'),
 ('♣', '3'),
 ('♣', '4'),
 ('♣', '5'),
 ('♣', '6'),
 ('♣', '7'),
 ('♣', '8'),
 ('♣', '9'),
 ('♣', '10'),
 ('♣', 'J'),
 ('♣', 'Q'),
 ('♣', 'K'),
 ('♣', 'A')]

In [22]:
def model():
    deck = copy(cards)
    with pyro.plate("draw"):
        card1_index = pyro.sample("card1_index", dist.Categorical(logits=torch.zeros(len(deck))))
        card2_index = pyro.sample("card2_index", dist.Categorical(logits=torch.zeros(len(deck) - 1)))
        card3_index = pyro.sample("card3_index", dist.Categorical(logits=torch.zeros(len(deck) - 2)))
        card4_index = pyro.sample("card4_index", dist.Categorical(logits=torch.zeros(len(deck) - 3)))
        card5_index = pyro.sample("card5_index", dist.Categorical(logits=torch.zeros(len(deck) - 4)))

    card1 = deck[card1_index]
    del deck[card1_index]
    card2 = deck[card2_index]
    del deck[card2_index]
    card3 = deck[card3_index]
    del deck[card3_index]
    card4 = deck[card4_index]
    del deck[card4_index]
    card5 = deck[card5_index]
    del deck[card5_index]

    hand = card1, card2, card3, card4, card5

    value = {}
    for card in hand:
        suite, face = card
        if face not in value:
            value[face] = set()
        value[face].add(suite)

    full_house = len(value) == 2 and 2 <= len(list(value.values())[0]) <= 3
    assert isinstance(full_house, bool), "full_house is not a bool"

    if full_house:
        print(hand)

    with pyro.plate("conditional_full_house"):
        p = pyro.sample("p", dist.Uniform(0, 1))
        pyro.sample("obs", dist.Bernoulli(p), obs=torch.tensor(full_house, dtype=torch.float))
    return hand, full_house

In [23]:
def guide():
    deck = copy(cards)
    with pyro.plate("draw"):
        card1_index = pyro.sample("card1_index", dist.Categorical(logits=torch.zeros(len(deck))))
        card2_index = pyro.sample("card2_index", dist.Categorical(logits=torch.zeros(len(deck) - 1)))
        card3_index = pyro.sample("card3_index", dist.Categorical(logits=torch.zeros(len(deck) - 2)))
        card4_index = pyro.sample("card4_index", dist.Categorical(logits=torch.zeros(len(deck) - 3)))
        card5_index = pyro.sample("card5_index", dist.Categorical(logits=torch.zeros(len(deck) - 4)))

    # card1 = deck[card1_index]
    # del deck[card1_index]
    # card2 = deck[card2_index]
    # del deck[card2_index]
    # card3 = deck[card3_index]
    # del deck[card3_index]
    # card4 = deck[card4_index]
    # del deck[card4_index]
    # card5 = deck[card5_index]
    # del deck[card5_index]
    #
    # hand = card1, card2, card3, card4, card5
    #
    # value = {}
    # for card in hand:
    #     suite, face = card
    #     if face not in value:
    #         value[face] = set()
    #     value[face].add(suite)
    #
    #
    # full_house = len(value) == 2 and 2 <= len(list(value.values())[0]) <= 3
    # if full_house:
    #     print(hand)

    with pyro.plate("conditional_full_house"):
        p_latent = pyro.param("p_latent", torch.tensor(0.5), constraint=constraints.interval(0, 1))
        assert 0 <= p_latent <= 1, f"p_latent: {p_latent}"
        pyro.sample("p", dist.Delta(p_latent))
    #return hand, full_house

In [24]:
auto_guide = pyro.infer.autoguide.AutoDelta(model)

In [25]:
pyro.clear_param_store()
optim = pyro.optim.ClippedAdam({"lr": 0.25, "lrd": 0.25})
loss = pyro.infer.Trace_ELBO()
svi = pyro.infer.SVI(model, guide, optim, loss)

In [26]:
for _ in trange(10_000):
    svi.step()

  0%|          | 0/10000 [00:00<?, ?it/s]


AssertionError: 

In [None]:
pyro.param("p_latent")

In [None]:
pyro.clear_param_store()

kernel = pyro.infer.NUTS(model)
mcmc = pyro.infer.MCMC(kernel, num_samples=1000, warmup_steps=100, num_chains=1)

In [None]:
mcmc.run()

In [None]:
marginals = pyro.infer.EmpericalMarginal(mcmc.get_samples())