# 3 Tensor Shapes in Pyro

In [1]:
import os
import torch
import pyro
from torch.distributions import constraints as C
from pyro.distributions import Bernoulli, Categorical, MultivariateNormal, Normal
from pyro.distributions.util import broadcast_shape
from pyro.infer import Trace_ELBO, TraceEnum_ELBO, config_enumerate
import pyro.poutine as poutine
from pyro.optim import Adam


# The shape assertions in this notebook are pinned to Pyro 1.8.2.
assert pyro.__version__.startswith('1.8.2')

# We'll use this helper to check that our models are correct.
def test_model(model, guide, loss):
    """Evaluate `loss` once on (`model`, `guide`) from an empty param store.

    Evaluating the loss runs both the model and the guide, so the shape
    assertions written inside them execute and any mismatch raises here.
    The computed loss value itself is discarded.
    """
    pyro.clear_param_store()  # parameters from earlier cells must not leak in
    loss.loss(model, guide)

```js
x = d.sample()
x.shape == d.batch_shape + d.event_shape
```

```js
d.log_prob(x).shape == d.batch_shape
```

```js
x = d.sample(sample_shape)
x.shape == sample_shape + d.batch_shape + d.event_shape
```

In [2]:
# A scalar Bernoulli has neither batch nor event dims, so both a sample and
# its log-probability are 0-dim tensors.
d = Bernoulli(0.5)
assert d.batch_shape == ()
assert d.event_shape == ()
x = d.sample()
assert x.shape == ()               # batch_shape + event_shape == ()
assert d.log_prob(x).shape == ()   # log_prob keeps only batch_shape == ()


In [3]:
# A 3-dimensional MultivariateNormal treats the whole vector as one event:
# event_shape is (3,) while batch_shape stays empty.
d = MultivariateNormal(torch.zeros(3), torch.eye(3, 3))
assert d.batch_shape == ()
assert d.event_shape == (3,)
x = d.sample()
assert x.shape == (3,)  # batch_shape + event_shape
assert d.log_prob(x).shape == ()  # one log-density for the whole vector

In [4]:
# .to_event(1) reinterprets the rightmost batch dim (size 4) as an event dim,
# turning a (3, 4) batch of Bernoullis into 3 events of size 4.
d = Bernoulli(0.5 * torch.ones(3,4)).to_event(1)
assert d.batch_shape == (3,)
assert d.event_shape == (4,)
x = d.sample()
assert x.shape == (3, 4)            # batch (3,) + event (4,)
assert d.log_prob(x).shape == (3,)  # the event dim is reduced away, leaving batch_shape

In [5]:
# expand([10]) makes a batch of 10 scalar Normals; to_event(1) then declares
# that dim dependent, so the site's value is a single event of shape (10,).
x = pyro.sample('x', Normal(0, 1).expand([10]).to_event(1))
assert x.shape == (10, )

## Declaring independent dims with `plate`

In [6]:
# Nested plates without an explicit `dim` are each given their own batch dim,
# so together the outer plate 'x' and inner plate 'y' mark two dims as independent.
with pyro.plate('x', 200):
    with pyro.plate('y', 200):
        # dims -1 and -2 are independent
        pass

In [7]:
def model1():
    """Show how plates (batch dims) and to_event (event dims) shape sample sites.

    model1 is also used as its own guide; the checks below pin the shape of
    every site it samples.
    """
    # Outside any plate, a site's shape is just its distribution's shape.
    a = pyro.sample('a', Normal(0, 1))
    b = pyro.sample('b', Normal(torch.zeros(2), 1).to_event(1))

    # Plates without an explicit `dim`, each used on its own.
    c_plate = pyro.plate('c_plate', 2)
    with c_plate:
        c = pyro.sample('c', Normal(torch.zeros(2), 1))

    d_plate = pyro.plate('d_plate', 3)
    with d_plate:
        # Batch dim of size 3 from the plate, event dims (4, 5) from to_event(2).
        d = pyro.sample('d', Normal(torch.zeros(3, 4, 5), 1).to_event(2))

    for site_value, expected_shape in (
        (a, ()),         # no batch dims, no event dims
        (b, (2, )),      # event dims only
        (c, (2, )),      # batch dim of size 2, matching c_plate
        (d, (3, 4, 5)),  # batch (3,) + event (4, 5)
    ):
        assert site_value.shape == expected_shape

    # Plates pinned to explicit batch dims. Neither uses dim -1, so sites
    # inside them carry a size-1 dim there.
    x_axis = pyro.plate('x_axis', 3, dim=-2)
    y_axis = pyro.plate('y_axis', 2, dim=-3)
    with x_axis:
        x = pyro.sample('x', Normal(0, 1))
    with y_axis:
        y = pyro.sample('y', Normal(0, 1))
    with x_axis, y_axis:
        xy = pyro.sample('xy', Normal(0, 1))
        z = pyro.sample('z', Normal(0, 1).expand([5]).to_event(1))

    for site_value, expected_shape in (
        (x, (3, 1)),
        (y, (2, 1, 1)),
        (xy, (2, 3, 1)),
        (z, (2, 3, 1, 5)),  # batch (2, 3, 1) + event (5,)
    ):
        assert site_value.shape == expected_shape


# model1 doubles as its own guide.
test_model(model1, model1, Trace_ELBO())

```js
batch dims | event dims 
-----------+----------- 
           |        a = sample("a", Normal(0, 1)) 
           |2       b = sample("b", Normal(zeros(2), 1) 
           |                        .to_event(1))
           |        with plate("c", 2): 
          2|            c = sample("c", Normal(zeros(2), 1)) 
           |        with plate("d", 3):
          3|4 5         d = sample("d", Normal(zeros(3,4,5), 1)
           |                       .to_event(2))
           |
           |        x_axis = plate("x", 3, dim=-2)
           |        y_axis = plate("y", 2, dim=-3)
           |        with x_axis:
        3 1|            x = sample("x", Normal(0, 1))
           |        with y_axis:
      2 1 1|            y = sample("y", Normal(0, 1))
           |        with x_axis, y_axis:
      2 3 1|            xy = sample("xy", Normal(0, 1))
      2 3 1|5           z = sample("z", Normal(0, 1).expand([5])
           |                       .to_event(1))
```

In [8]:
# Record one execution of model1 and print every site's shape, with batch
# dims left of the '|' and event dims right of it (see the output below).
trace = poutine.trace(model1).get_trace()
print(trace.format_shapes())

Trace Shapes:            
 Param Sites:            
Sample Sites:            
       a dist       |    
        value       |    
       b dist       | 2  
        value       | 2  
 c_plate dist       |    
        value     2 |    
       c dist     2 |    
        value     2 |    
 d_plate dist       |    
        value     3 |    
       d dist     3 | 4 5
        value     3 | 4 5
  x_axis dist       |    
        value     3 |    
  y_axis dist       |    
        value     2 |    
       x dist   3 1 |    
        value   3 1 |    
       y dist 2 1 1 |    
        value 2 1 1 |    
      xy dist 2 3 1 |    
        value 2 3 1 |    
       z dist 2 3 1 | 5  
        value 2 3 1 | 5  


In [9]:
# Float observations: they are scored under a continuous Normal likelihood,
# so they should not be an integer tensor.
data = torch.arange(100.)

def model2():
    """Observe `data` in random minibatches of 10 using a subsampled plate.

    `mean` holds one parameter per datum, so it must be indexed with the same
    subsample indices as the data.
    """
    mean = pyro.param('mean', torch.zeros(len(data)))
    with pyro.plate('data', len(data), subsample_size=10) as ind:
        # `ind` holds the indices of the current subsample.
        assert len(ind) == 10
        batch = data[ind]       # select the minibatch of observations
        mean_batch = mean[ind]  # select the matching per-datum parameters
        x = pyro.sample('x', Normal(mean_batch, 1), obs=batch)
        assert len(x) == 10

# The only sample site is observed, so an empty guide suffices.
test_model(model2, lambda: None, Trace_ELBO())


```js
      max_plate_nesting = 3
           |<--->|
enumeration|batch|event
-----------+-----+-----
           |. . .|      a = sample("a", Normal(0, 1))
           |. . .|2     b = sample("b", Normal(zeros(2), 1)
           |     |                      .to_event(1))
           |     |      with plate("c", 2):
           |. . 2|          c = sample("c", Normal(zeros(2), 1))
           |     |      with plate("d", 3):
           |. . 3|4 5       d = sample("d", Normal(zeros(3,4,5), 1)
           |     |                     .to_event(2))
           |     |
           |     |      x_axis = plate("x", 3, dim=-2)
           |     |      y_axis = plate("y", 2, dim=-3)
           |     |      with x_axis:
           |. 3 1|          x = sample("x", Normal(0, 1))
           |     |      with y_axis:
           |2 1 1|          y = sample("y", Normal(0, 1))
           |     |      with x_axis, y_axis:
           |2 3 1|          xy = sample("xy", Normal(0, 1))
           |2 3 1|5         z = sample("z", Normal(0, 1).expand([5])
           |     |                     .to_event(1))

```

In [11]:
@config_enumerate
def model3():
    """Show how parallel enumeration adds dims to the left of the batch dims.

    With max_plate_nesting=2 the two plates occupy the rightmost two batch
    dims, and each enumerated discrete site gets its own dim further left
    (see the shape assertions). The continuous site `e` is sampled rather
    than enumerated.
    """
    p = pyro.param('p', torch.arange(6.) / 6)
    locs = pyro.param('locs', torch.tensor([-1., 1.]))

    a = pyro.sample('a', Categorical(torch.ones(6) / 6))
    b = pyro.sample('b', Bernoulli(p[a]))  # depends on a

    with pyro.plate('c_plate', 4):
        c = pyro.sample('c', Bernoulli(0.3))
        with pyro.plate('d_plate', 5):
            d = pyro.sample('d', Bernoulli(0.4))
            e_loc = locs[d.long()].unsqueeze(-1)
            # Float scale, matching the float e_loc; an integer arange would
            # only work through implicit dtype promotion inside Normal.
            e_scale = torch.arange(1., 8.)
            e = pyro.sample('e', Normal(e_loc, e_scale).to_event(1))  # depends on d

    #                   enumerated|batch|event dims
    assert a.shape == (         6, 1, 1   )  # Six enumerated values of the Categorical.
    assert b.shape == (      2, 1, 1, 1   )  # Two enumerated Bernoullis, unexpanded.
    assert c.shape == (   2, 1, 1, 1, 1   )  # Only two Bernoullis, unexpanded.
    assert d.shape == (2, 1, 1, 1, 1, 1   )  # Only two Bernoullis, unexpanded.
    assert e.shape == (2, 1, 1, 1, 5, 4, 7)  # This is sampled and depends on d.

    assert e_loc.shape   == (2, 1, 1, 1, 1, 1, 1,)
    assert e_scale.shape == (                  7,)

# model3 is its own guide; max_plate_nesting=2 tells TraceEnum_ELBO how many
# rightmost dims the plates may use, so enumeration dims go to their left.
test_model(model3, model3, TraceEnum_ELBO(max_plate_nesting=2))

```js
     max_plate_nesting = 2
            |<->|
enumeration batch event
------------|---|-----
           6|1 1|     a = pyro.sample("a", Categorical(torch.ones(6) / 6))
         2 1|1 1|     b = pyro.sample("b", Bernoulli(p[a]))
            |   |     with pyro.plate("c_plate", 4):
       2 1 1|1 1|         c = pyro.sample("c", Bernoulli(0.3))
            |   |         with pyro.plate("d_plate", 5):
     2 1 1 1|1 1|             d = pyro.sample("d", Bernoulli(0.4))
     2 1 1 1|1 1|1            e_loc = locs[d.long()].unsqueeze(-1)
            |   |7            e_scale = torch.arange(1., 8.)
     2 1 1 1|5 4|7            e = pyro.sample("e", Normal(e_loc, e_scale)
            |   |                             .to_event(1))
```

In [12]:
# Grid size: W positions along plate dim -2, H positions along plate dim -1.
W = 8
H = 10
# (x, y) coordinates of the active pixels, as an int64 tensor.
sparse = torch.tensor([[3, 2], [3, 5], [3, 9], [7, 1]], dtype=torch.long)
# Which set of shape assertions `fun` checks; set before each test run below.
ENUM = None

def fun(observe):
    """Shared model/guide body whose tensor ops work with or without enumeration.

    Args:
        observe: if truthy, also observe `dense_pixels` at the 'pixels' site
            (model behaviour); otherwise only the latent sites are sampled.

    The module-level flag ENUM selects which set of shape assertions applies.
    """
    p_x = pyro.param('px', torch.tensor(.1), constraint=C.unit_interval)
    p_y = pyro.param('py', torch.tensor(.1), constraint=C.unit_interval)
    x_plate = pyro.plate('xaxis', W, dim=-2)
    y_plate = pyro.plate('yaxis', H, dim=-1)

    with x_plate:
        x_active = pyro.sample('xactive', Bernoulli(p_x))
    with y_plate:
        y_active = pyro.sample('yactive', Bernoulli(p_y))

    # Enumerated sites hold their 2 values in a dim left of the plate dims.
    assert x_active.shape == ((2, 1, 1) if ENUM else (W, 1))
    assert y_active.shape == ((2, 1, 1, 1) if ENUM else (H, ))

    # Trick 1: plain broadcasting works in both modes.
    p = 0.1 + 0.5 * x_active * y_active
    assert p.shape == ((2, 2, 1, 1) if ENUM else (W, H))

    dense_pixels = p.new_zeros(broadcast_shape(p.shape, (W, H)))
    # Trick 2: index with an ellipsis so any extra leading dims are untouched.
    rows, cols = sparse[:, 0], sparse[:, 1]
    dense_pixels[..., rows, cols] = 1
    assert dense_pixels.shape == ((2, 2, W, H) if ENUM else (W, H))

    with x_plate, y_plate:
        if observe:
            pyro.sample('pixels', Bernoulli(p), obs=dense_pixels)
            

def model4():
    """Model: sample the activity sites and observe the pixel grid."""
    fun(observe=True)
    
def guide():
    """Guide: the same latent sites as model4, without observing pixels."""
    fun(observe=False)
    

# Test without enumeration.
# ENUM is the module-level flag `fun` reads to choose its shape assertions,
# so it must be set before each test_model call below.
ENUM = False
test_model(model4, guide, Trace_ELBO())

# Test with enumeration.
# The guide is wrapped with config_enumerate(..., "parallel"); the enumerated
# shapes asserted in `fun` must fit left of the 2 plate dims declared here.
ENUM = True
test_model(model4, config_enumerate(guide, "parallel"),
           TraceEnum_ELBO(max_plate_nesting=2))  
    