# Toy Model

Trying toy model inspired by [Griffiths & Steyvers, 2004](https://doi.org/10.1073/pnas.0307752101).

In [2]:
import tensorflow_probability as tfp
import tensorflow as tf
tfd = tfp.distributions
import numpy as np
import matplotlib.pyplot as plt

np.random.seed(42)
tf.random.set_seed(42)

### Setting up Topics and Data

1. Specify global parameters

In [3]:
## Number of Topics
K = 10

## Square Root of the Number of "Vocabulary" (must be sqrt such that pictorial interpretation is possible)
sqrtV = 5

## Number of words per document
N = 100

## Number of documents
D = 1000

2. Define Topic-Word relations

In [4]:
## Word grid
V_grid = np.reshape(np.arange(0, sqrtV**2), newshape=(sqrtV, sqrtV))

## Topic-Word Distribution
#  Words belonging to a topic are rows and columns
Theta_idx = [row for row in V_grid] + [col for col in V_grid.T]
Theta = np.zeros((K, sqrtV**2))
for k, idx in enumerate(Theta_idx):
    Theta[k, idx] = 1. / sqrtV

## Document topic prior
Alpha = 1

In [5]:
Theta

array([[0.2, 0.2, 0.2, 0.2, 0.2, 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
        0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
       [0. , 0. , 0. , 0. , 0. , 0.2, 0.2, 0.2, 0.2, 0.2, 0. , 0. , 0. ,
        0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
       [0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0.2, 0.2, 0.2,
        0.2, 0.2, 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
       [0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
        0. , 0. , 0.2, 0.2, 0.2, 0.2, 0.2, 0. , 0. , 0. , 0. , 0. ],
       [0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
        0. , 0. , 0. , 0. , 0. , 0. , 0. , 0.2, 0.2, 0.2, 0.2, 0.2],
       [0.2, 0. , 0. , 0. , 0. , 0.2, 0. , 0. , 0. , 0. , 0.2, 0. , 0. ,
        0. , 0. , 0.2, 0. , 0. , 0. , 0. , 0.2, 0. , 0. , 0. , 0. ],
       [0. , 0.2, 0. , 0. , 0. , 0. , 0.2, 0. , 0. , 0. , 0. , 0.2, 0. ,
        0. , 0. , 0. , 0.2, 0. , 0. , 0. , 0. , 0.2, 0. , 0. , 0. ],
      

3. Generating Data

In [6]:
## Topic-Word Distibution is omitted because Theta is fixed

## Document-Topic Distribution
dist_Pi = tfd.Dirichlet(K*[Alpha])
Pi      = dist_Pi.sample(D)

## Topic Assignments of word c_{dik} of word w_{di}
dist_C    = tfd.Categorical(probs=Pi)
C         = tf.reshape(dist_C.sample(N), shape=(D, -1)) ## Its more efficient to reshape before converting to one_hot vectors
C_one_hot = tf.one_hot(C, depth=K, axis=-1)
assert tf.reduce_all(tf.reduce_sum(C_one_hot, axis=-1) == 1)

## Draw words w_{di}
dist_W = tfd.Categorical(probs=Theta)
W      = dist_W.sample()

In [7]:
W

<tf.Tensor: shape=(10,), dtype=int32, numpy=array([ 0,  8, 13, 16, 21,  0,  1, 17, 23,  4])>

In [8]:
C[0]

<tf.Tensor: shape=(100,), dtype=int32, numpy=
array([5, 5, 2, 6, 7, 2, 1, 0, 0, 5, 5, 6, 9, 5, 2, 3, 0, 7, 7, 9, 4, 9,
       8, 5, 7, 2, 3, 9, 0, 6, 0, 2, 2, 1, 3, 6, 0, 8, 3, 6, 6, 1, 4, 4,
       9, 6, 7, 3, 4, 6, 0, 1, 9, 1, 4, 1, 1, 3, 6, 4, 8, 2, 7, 0, 4, 8,
       7, 3, 8, 2, 6, 1, 5, 7, 5, 3, 1, 4, 9, 9, 0, 5, 3, 4, 5, 2, 7, 7,
       7, 9, 8, 8, 2, 6, 7, 9, 7, 0, 8, 5])>

In [9]:
C_one_hot[0, :3, :]

<tf.Tensor: shape=(3, 10), dtype=float32, numpy=
array([[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)>

In [10]:
tf.constant(Theta)

<tf.Tensor: shape=(10, 25), dtype=float64, numpy=
array([[0.2, 0.2, 0.2, 0.2, 0.2, 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
        0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
       [0. , 0. , 0. , 0. , 0. , 0.2, 0.2, 0.2, 0.2, 0.2, 0. , 0. , 0. ,
        0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
       [0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0.2, 0.2, 0.2,
        0.2, 0.2, 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
       [0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
        0. , 0. , 0.2, 0.2, 0.2, 0.2, 0.2, 0. , 0. , 0. , 0. , 0. ],
       [0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
        0. , 0. , 0. , 0. , 0. , 0. , 0. , 0.2, 0.2, 0.2, 0.2, 0.2],
       [0.2, 0. , 0. , 0. , 0. , 0.2, 0. , 0. , 0. , 0. , 0.2, 0. , 0. ,
        0. , 0. , 0.2, 0. , 0. , 0. , 0. , 0.2, 0. , 0. , 0. , 0. ],
       [0. , 0.2, 0. , 0. , 0. , 0. , 0.2, 0. , 0. , 0. , 0. , 0.2, 0. ,
        0. , 0. , 0. , 0.