# STATS 607
## Week 6: Profiling and debugging

### Example 1: STRUCTURE

In [16]:
pip install line_profiler

Collecting line_profiler
  Downloading line_profiler-4.1.3-cp39-cp39-macosx_10_9_x86_64.whl (140 kB)
[2K     [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m140.3/140.3 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: line_profiler
Successfully installed line_profiler-4.1.3

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip available: [0m[31;49m22.2.2[0m[39;49m -> [0m[32;49m24.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [17]:
import numpy as np

def gibbs_sampler(G, K, rng, num_iterations=1000, alpha=1.0, F=.1, p=.1):
    """
    Run Gibbs sampler for the STRUCTURE model.
    
    data: N x L x 2 array of genotype data
    K: Number of populations
    rng: Random number generator.
    num_iterations: Number of Gibbs sampling iterations
    alpha: Dirichlet prior parameter for q
    F: estimate of Fst between populations.
    p: mean allele frequency
    """
    N, L, _ = G.shape
    assert G.shape == (N, L, 2)
    # Initialize parameters
    Q = rng.dirichlet([alpha] * K, size=N)
    f = (1 - F) / F
    f, p = [np.broadcast_to(x, [K, L]) for x in [f, p]] 
    a_prior = f * p
    b_prior = f * (1. - p)
    P = rng.beta(a_prior, b_prior)   # balding-nichols model
    P = np.stack([P, 1 - P], axis=2)
    # Initialize z assignments randomly
    Z = rng.integers(0, K, size=(N, L, 2))
    for iteration in range(num_iterations):
        # Update z
        breakpoint()
        k, l, j = np.ogrid[:3, :2777, :2]
        p =  P[k, l, G]
        aoeu
        # N, K, L, 2
        probs = Q[:, :, None, None] * P[k, l, G]
        probs /= probs.sum()
        Z[i, l, j] = rng.choice(K, p=probs)
        # Update p
        for k in range(K):
            for l in range(L):
                counts = np.zeros([2])
                idx = np.where(Z[:,l,:] == k)
                alleles = G[:,l,:][idx]
                for a in range(2):
                    counts[a] = np.sum(alleles == a)
                P[k, l] = rng.beta(counts[0] + a_prior[k, l], counts[1] + b_prior[k, l])
        # Update q
        for i in range(N):
            counts = np.zeros(K)
            for k in range(K):
                counts[k] = np.sum(Z[i,:,:] == k)
            Q[i] = rng.dirichlet(counts + alpha)
        print(f"Iteration {iteration}")
    return Q, P

In [18]:
gibbs_sampler(G, 3, rng, p=p.clip(1e-6, 1-1e-6), F=F)

IndexError: shape mismatch: indexing arrays could not be broadcast together with shapes (3,1,1) (1,2777,1) (9,2777,2) 

In [15]:
%debug

> [0;32m/var/folders/b1/jjpl1p_53jxggrgr9nn841nm0000gn/T/ipykernel_76927/2338917766.py[0m(30)[0;36mgibbs_sampler[0;34m()[0m
[0;32m     28 [0;31m        [0;31m# Update z[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     29 [0;31m        [0mbreakpoint[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 30 [0;31m        [0maoeu[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     31 [0;31m        [0;31m# N, K, L, 2[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     32 [0;31m        [0mprobs[0m [0;34m=[0m [0mQ[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0;34m:[0m[0;34m,[0m [0;32mNone[0m[0;34m,[0m [0;32mNone[0m[0;34m][0m [0;34m*[0m [0mP[0m[0;34m[[0m[0;32mNone[0m[0;34m,[0m [0;34m:[0m[0;34m,[0m [0;34m:[0m[0;34m,[0m [0;34m:[0m[0;34m][0m[0;34m[[0m[0mG[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0;32mNone[0m[0;34m,[0m [0;34m:[0m[0;34m,[0m [0;34m:[0m[0;34m][0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m
ipdb> Q[

In [10]:
import stdpopsim
species = stdpopsim.get_species("HomSap")
model = species.get_demographic_model("OutOfAfrica_3G09")
chrom = species.get_contig('1', length_multiplier=0.01)
samples = model.get_samples(3, 3, 3)
engine = stdpopsim.get_engine("msprime")
ts = engine.simulate(model, chrom, samples)

G = ts.genotype_matrix().T
# convert diploid to tensor
G = np.array([[0, 0], [0, 1], [1, 1]])[G]

rng = np.random.default_rng()
p = G.mean((0, 2))
F = ts.Fst([np.arange(3), np.arange(3, 6)])

In [11]:
gibbs_sampler(G, 3, rng, p=p, F=F)

NameError: name 'aoeu' is not defined

In [7]:
%load_ext line_profiler

In [11]:
%lprun -T prof0 -f gibbs_sampler gibbs_sampler(G, 3, rng, p=p, F=F, num_iterations=1)

Iteration 0

*** Profile printout saved to text file 'prof0'. 


Timer unit: 1e-09 s

Total time: 57.593 s
File: /var/folders/j8/n524q16s54l6y_vxh70gb9lr0000gs/T/ipykernel_39911/1030734579.py
Function: gibbs_sampler at line 3

Line #      Hits         Time  Per Hit   % Time  Line Contents
     3                                           def gibbs_sampler(G, K, rng, num_iterations=1000, alpha=1.0, F=.1, p=.1):
     4                                               """
     5                                               Run Gibbs sampler for the STRUCTURE model.
     6                                               
     7                                               data: N x L x 2 array of genotype data
     8                                               K: Number of populations
     9                                               rng: Random number generator.
    10                                               num_iterations: Number of Gibbs sampling iterations
    11                                               alpha: Dirichlet prior parameter f

### Example 2: Debugging NaNs

- `NaN` (not a number) can crop up from time to time in numerical computations, especially if doing scientific computing.
- `NaN`s in primal are usually pretty easy to debug.
- `NaN`s in gradients are harder, but can usually still be figured out.

### Debugging NaNs in Jax

- `jax.config.update("jax_debug_nans", True)` or `with jax.debug_nans(True):` ...
- Limited use for backward pass NaNs.
- Don`t rule out printf()-style debugging.

In [17]:
from jax import grad

def f(x):
    return x ** (1/3)

x = 0.0
primal_result = f(x)

In [35]:
import jax

with jax.debug_infs(True):
    jax.grad(f)(x)

FloatingPointError: invalid value (inf) encountered in jit(pow)

In [36]:
%debug

> [0;32m/Users/terhorst/opt/py39/lib/python3.9/site-packages/jax/_src/dispatch.py[0m(879)[0;36m_check_special[0;34m()[0m
[0;32m    877 [0;31m      [0;32mraise[0m [0mFloatingPointError[0m[0;34m([0m[0;34mf"invalid value (nan) encountered in {name}"[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    878 [0;31m    [0;32mif[0m [0mconfig[0m[0;34m.[0m[0mjax_debug_infs[0m [0;32mand[0m [0mnp[0m[0;34m.[0m[0many[0m[0;34m([0m[0mnp[0m[0;34m.[0m[0misinf[0m[0;34m([0m[0mnp[0m[0;34m.[0m[0masarray[0m[0;34m([0m[0mbuf[0m[0;34m)[0m[0;34m)[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 879 [0;31m      [0;32mraise[0m [0mFloatingPointError[0m[0;34m([0m[0;34mf"invalid value (inf) encountered in {name}"[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    880 [0;31m[0;34m[0m[0m
[0m[0;32m    881 [0;31mdef _add_tokens(has_unordered_effects: bool, ordered_effects: List[core.Effect],
[0m
ipdb> up
> [0;32m/Users/te

In [41]:
import jax.numpy as jnp

def f(x):
    return x * jnp.log(x)


In [42]:
with jax.debug_nans(True):
    f(0.)

Invalid nan value encountered in the output of a C++-jit/pmap function. Calling the de-optimized version.


FloatingPointError: invalid value (nan) encountered in jit(mul)

In [43]:
import jax.scipy

jax.grad(f)(0.)

Array(nan, dtype=float32, weak_type=True)

Exploding/vanisishing gradients can occur in neural networks with many layers, or RNNs. For example:

NOTE: could also define a custom backwards derivative in torch/jax and have it print out

In [60]:
import jax
import jax.numpy as jnp

def compute_rnn_loss(inputs, targets, W):
    W_h, W_x = W
    def f(accum, xy):
        h, loss = accum
        x, y = xy
        h = jax.nn.relu(W_h @ h + W_x @ x)
        loss += jnp.sum((h - y) ** 2)
        return (h, loss), None
    
    h0 = jnp.zeros(W_h.shape[0])
    (_, loss), _ = jax.lax.scan(f, (h0, 0.), (inputs, targets))
    return loss

# Initialize weights with large values
n = 100
inputs = jnp.ones([1000, n])
targets = jnp.zeros_like(inputs)

def loss(W):
    return compute_rnn_loss(inputs, targets, W)

W0 = np.random.normal(size=(2, n, n)) / np.sqrt(1000 * n)

loss(W0)

Array(41.20715, dtype=float32)

In [61]:
import jaxopt

jax.config.update('jax_enable_x64', False)

res = jaxopt.GradientDescent(loss).run(W0)
Wstar = res.params
Wstar

Array([[[-7.1936997e-04,  2.5338116e-03,  1.7132886e-03, ...,
          1.5326979e-03, -2.1151325e-03,  1.7426903e-03],
        [-4.6431349e-04, -1.1227080e-03, -5.7782046e-03, ...,
         -4.7274847e-05, -1.6010848e-04,  3.7446495e-03],
        [ 3.6421118e-03,  5.6044938e-04,  3.4370569e-03, ...,
         -2.7548911e-03,  4.1265483e-03,  3.3906966e-03],
        ...,
        [-1.8593601e-03, -3.8566182e-03,  4.5611490e-03, ...,
          7.6865316e-03, -4.1736653e-03,  3.7536228e-03],
        [ 5.4711392e-03,  4.0740641e-03, -2.1023150e-03, ...,
          2.0873668e-03,  1.1653560e-03, -1.8014332e-03],
        [-2.2329125e-03, -5.1090471e-03, -6.3343003e-05, ...,
         -9.0829906e-04,  1.1893467e-03, -8.1428597e-03]],

       [[ 2.8177295e-03,  6.5748650e-04,  2.9932056e-03, ...,
          4.4896286e-03, -2.2861490e-03,  1.5127858e-03],
        [ 1.6115031e-03, -2.6692047e-03,  3.5574629e-03, ...,
         -7.5662305e-04, -4.5344550e-03, -5.0266176e-03],
        [-1.3363195e-03, 

### Example 3: Test-first implementation of a numerical algorithm


Consider the following problem:

- Two particles execute continuous time random walk around a graph. 
- The rate of going from node $n$ to node $m$ is given by the edge weight $w_{nm}$.
- Whenever they are in the same node $n$, the are both absorbed into a state `crypt`, from which there is no exit, with rate node-specific rate $c_n$. 
- Let $T_\text{crypt}$ be the first hitting time to the `crypt` state.
- Problem: compute $\mathbb{P}(T_\text{crypt} > t)$.

In [66]:
import networkx as nx

k = 2
G = nx.grid_graph((2, 2))

for n in G.nodes():
    G.nodes[n]['weight'] = np.random.exponential()

for e in G.edges():
    G.edges[e]['weight'] = np.random.exponential()

In [68]:
Q = nx.adjacency_matrix(G).todense()



  Q = nx.adjacency_matrix(G).todense()


matrix([[0.        , 2.53035439, 3.66323364, 0.        ],
        [2.53035439, 0.        , 0.        , 0.90976091],
        [3.66323364, 0.        , 0.        , 0.82560733],
        [0.        , 0.90976091, 0.82560733, 0.        ]])

## Computing $\mathbb{P}(T > t)$

- Let $S(t),c(t)) \in \mathbb{R}^{k\times k} \times \mathbb{R}$ be the overall state of the system.
- Then 

\begin{align}
\frac{dS_{ij}(t)}{dt} &= \sum_x w_{xi}S_{xj} + w_{xj}S_{ix} - S_{ij}(w_{ix} + w_{xj}) - \mathbf{1}(i=j) c_i \\
\frac{dc(t)}{dt} &= \sum_i S_{ii} c_i
\end{align}

In [None]:
def dS(t, S):
    ...

def test_dS():
    k = 10
    S = np.random.dro

In [None]:
def dS(t, S):
    ...
    
def test_dS():
    k = 10
    S = np.random.dirichlet([1.] * 10)
    t = np.random.rand()
    assert dS(t, S).sum() == 0.
    
def test_absorbing():
    t1, dt = np.random.rand(2)
    t2 = t1 + dt
    S1 = solve_ode(S0, t1)
    S2 = solve_ode(S0, t2)
    assert S1[-1] <= S2[-1]

def test_solve_ode():
    S0 = np.random.dirichlet([1.] * 10)
    t1 = np.random.rand()
    S1 = solve_ode(S0, t1)
    assert np.isclose(S1.sum(), 1.)


In [83]:
c = np.random.exponential(size=Q.shape[0])
n = 4
y0 = np.random.dirichlet(np.ones(n ** 2 + 1))
Q = np.array(Q)
Q

array([[0.        , 2.53035439, 3.66323364, 0.        ],
       [2.53035439, 0.        , 0.        , 0.90976091],
       [3.66323364, 0.        , 0.        , 0.82560733],
       [0.        , 0.90976091, 0.82560733, 0.        ]])

In [88]:
def f(Q, y):
    # yshape == (n ** 2 + 1,)
    # assert Q.shape == (n, n)
    # assert c.shape == (n,)

    Q1 = Q.sum(0)[:, None] + Q.sum(1)[None, :]
    yn = y[:-1].reshape(n, n)
    yd = jnp.diag(yn)
    yc = y[-1]
    ret = Q @ yn + yn @ Q.T  # upper left
    
    ret -= Q1 * yn
    ret -= jnp.diag(c * yd)
    return jnp.append(ret.reshape(n ** 2), c.dot(yd))

f(Q, y0)

Array([-1.6211667e+00,  2.5633267e-01,  7.1935260e-01, -6.7024809e-01,
       -1.6417135e-01, -3.9275709e-01, -1.3409510e-01,  4.7105369e-01,
        4.4172296e-01,  4.0247972e-04,  1.1420314e-01,  2.8007287e-01,
        3.7421206e-01, -3.2054934e-01,  8.3529130e-02, -1.7456809e-01,
        7.3667407e-01], dtype=float32)

In [None]:
y = 

In [94]:
n = 4

def test_dS():
    k = 10
    y = np.random.dirichlet([1.] * (n ** 2 + 1))
    np.testing.assert_allclose(f(Q, y).sum(), 0., atol=1e-6)

In [106]:
from scipy.integrate import solve_ivp

def S(y0, t):
    return solve_ivp(lambda _, y: f(Q, y), [0., 1.], y0).y[:, 0]

def test_absorbing():
    y0 = np.random.dirichlet([1.] * (n ** 2 + 1))
    t1, dt = np.random.rand(2)
    t2 = t1 + dt
    S1 = S(y0, t1)
    S2 = S(y0, t2)
    assert S1[-1] <= S2[-1]

test_absorbing()


def test_solve_ode():
    y0 = np.random.dirichlet([1.] * (n ** 2 + 1))
    t1, dt = np.random.rand(2)
    t2 = t1 + dt
    S1 = S(y0, t1)
    S2 = S(y0, t2)
    np.testing.assert_allclose(S1.sum(), 1.)
    np.testing.assert_allclose(S2.sum(), 1.)
    
test_solve_ode()

