# STATS 607
## Week 6: Profiling and debugging

### Example 1: STRUCTURE

In [1]:
import numpy as np

def gibbs_sampler(G, K, rng, num_iterations=1000, alpha=1.0, F=.1, p=.1):
    """
    Run Gibbs sampler for the STRUCTURE model.

    G: N x L x 2 array of genotype data; entries are allele codes 0 or 1
       (they are used to index the last axis of P).
    K: Number of populations
    rng: Random number generator (numpy ``Generator``).
    num_iterations: Number of Gibbs sampling iterations
    alpha: Dirichlet prior parameter for q
    F: estimate of Fst between populations (scalar or broadcastable to K x L).
    p: mean allele frequency (scalar or broadcastable to K x L). It is the
       mean of the Beta prior whose draw is stored in P[..., 0].

    Returns (Q, P):
      Q: N x K array; each row is a draw of one individual's admixture
         proportions.
      P: K x L x 2 array; P[k, l] is a pair (x, 1 - x) of allele
         frequencies for population k at locus l.
    """
    N, L, _ = G.shape
    # Checks that the trailing axis really has length 2.
    assert G.shape == (N, L, 2)
    # Initialize parameters
    Q = rng.dirichlet([alpha] * K, size=N)
    f = (1 - F) / F
    f, p = [np.broadcast_to(x, [K, L]) for x in [f, p]]
    a_prior = f * p
    b_prior = f * (1. - p)
    P = rng.beta(a_prior, b_prior)   # balding-nichols model
    P = np.stack([P, 1 - P], axis=2)
    # Initialize z assignments randomly
    Z = rng.integers(0, K, size=(N, L, 2))
    # The explicit Python loops below are kept on purpose: this function is
    # the target of the line-profiling cell later in the notebook.
    for iteration in range(num_iterations):
        # Update z: resample the population of origin of every allele copy
        for i in range(N):
            for l in range(L):
                for j in range(2):
                    probs = Q[i] * P[:, l, G[i, l, j]]
                    probs /= probs.sum()
                    Z[i, l, j] = rng.choice(K, p=probs)
        # Update p
        for k in range(K):
            for l in range(L):
                counts = np.zeros([2])
                idx = np.where(Z[:,l,:] == k)
                alleles = G[:,l,:][idx]
                for a in range(2):
                    counts[a] = np.sum(alleles == a)
                draw = rng.beta(counts[0] + a_prior[k, l], counts[1] + b_prior[k, l])
                # Store the draw and its complement. Assigning the scalar to
                # P[k, l] directly would write the same value into both
                # allele slots, so the pair would no longer sum to one.
                P[k, l] = draw, 1. - draw
        # Update q
        for i in range(N):
            counts = np.zeros(K)
            for k in range(K):
                counts[k] = np.sum(Z[i,:,:] == k)
            Q[i] = rng.dirichlet(counts + alpha)
        if iteration % 100 == 0:
            print(f"Iteration {iteration}")
    return Q, P

In [8]:
import stdpopsim
# Simulate data with stdpopsim: species "HomSap", demographic model
# "OutOfAfrica_3G09", contig '1' requested with length_multiplier=0.1.
species = stdpopsim.get_species("HomSap")
model = species.get_demographic_model("OutOfAfrica_3G09")
chrom = species.get_contig('1', length_multiplier=0.1)
# NOTE(review): presumably 10 samples from each of the model's three populations — confirm
samples = model.get_samples(10, 10, 10)
engine = stdpopsim.get_engine("msprime")
ts = engine.simulate(model, chrom, samples)

# Transposed so that the first axis indexes samples and the second indexes sites
# (assumes genotype_matrix() is sites x samples — TODO confirm)
G = ts.genotype_matrix().T
# Lookup table: genotype code 0/1/2 -> pair of 0/1 allele codes.
# Indexing the 3 x 2 table with G appends a trailing axis of length 2.
G = np.array([[0, 0], [0, 1], [1, 1]])[G]

# NOTE(review): generator is unseeded, so sampler runs are not reproducible
rng = np.random.default_rng()
# Per-site mean of the 0/1 allele codes, averaged over samples and both copies
p = G.mean((0, 2))
# Fst between sample sets 0-9 and 10-19
F = ts.Fst([np.arange(10), np.arange(10, 20)])
# Full-length sampler run left disabled; the %lprun cell further down calls
# gibbs_sampler with num_iterations=1 instead.
# gibbs_sampler(G, 3, rng, p=p, F=F)

In [5]:
!pip install line_profiler

Defaulting to user installation because normal site-packages is not writeable
Collecting line_profiler
  Downloading line_profiler-4.1.3-cp312-cp312-macosx_11_0_arm64.whl.metadata (34 kB)
Downloading line_profiler-4.1.3-cp312-cp312-macosx_11_0_arm64.whl (132 kB)
Installing collected packages: line_profiler
Successfully installed line_profiler-4.1.3

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m24.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3.12 -m pip install --upgrade pip[0m


In [6]:
%load_ext line_profiler

In [9]:
%lprun -T prof0 -f gibbs_sampler gibbs_sampler(G, 3, rng, p=p, F=F, num_iterations=1)

Iteration 0

*** Profile printout saved to text file 'prof0'. 


Timer unit: 1e-09 s

Total time: 57.593 s
File: /var/folders/j8/n524q16s54l6y_vxh70gb9lr0000gs/T/ipykernel_39911/1030734579.py
Function: gibbs_sampler at line 3

Line #      Hits         Time  Per Hit   % Time  Line Contents
     3                                           def gibbs_sampler(G, K, rng, num_iterations=1000, alpha=1.0, F=.1, p=.1):
     4                                               """
     5                                               Run Gibbs sampler for the STRUCTURE model.
     6                                               
     7                                               data: N x L x 2 array of genotype data
     8                                               K: Number of populations
     9                                               rng: Random number generator.
    10                                               num_iterations: Number of Gibbs sampling iterations
    11                                               alpha: Dirichlet prior parameter f

### Example 2: Debugging NaNs

- `NaN` (not a number) can crop up from time to time in numerical computations, especially if doing scientific computing.
- `NaN`s in primal are usually pretty easy to debug.
- `NaN`s in gradients are harder, but can usually still be figured out.

### Debugging NaNs in Jax

- `jax.config.update("jax_debug_nans", True)` or `with jax.debug_nans(True):` ...
- Limited use for backward pass NaNs.
- Don't rule out printf()-style debugging.

In [10]:
from jax import grad

def f(x):
    """Cube root of ``x`` written as a power; finite at 0, but its slope there is not."""
    one_third = 1 / 3
    return x ** one_third

# Evaluate the function and its gradient at the origin.
x = 0.0
primal_result = f(x)
gradient_result = grad(f)(x)

In [20]:
import jax

# With inf-checking switched on, the same gradient call raises
# FloatingPointError (shown in the output below) instead of silently returning inf.
with jax.debug_infs(True):
    grad(f)(x)

FloatingPointError: invalid value (inf) encountered in jit(pow). Because jax_config.debug_nans.value and/or config.jax_debug_infs is set, the de-optimized function (i.e., the function as if the `jit` decorator were removed) was called in an attempt to get a more precise error message. However, the de-optimized function did not produce invalid values during its execution. This behavior can result from `jit` optimizations causing the invalid value to be produced. It may also arise from having nan/inf constants as outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`. 

It may be possible to avoid the invalid value by removing the `jit` decorator, at the cost of losing optimizations. 

If you see this error, consider opening a bug report at https://github.com/jax-ml/jax.

In [23]:
import jax.numpy as jnp

def f(x):
    """Return ``x * log(x)``, computed as a plain product of ``x`` and ``jnp.log(x)``."""
    log_x = jnp.log(x)
    return x * log_x

Array(nan, dtype=float32, weak_type=True)

In [30]:
import jax.scipy

def f(x):
    """Return ``x * log(x)`` via ``jax.scipy.special.xlogy`` with both arguments equal to ``x``."""
    xlogy = jax.scipy.special.xlogy
    return xlogy(x, x)

# Differentiate at 0 with NaN-checking enabled; the output below shows this
# still raises FloatingPointError for the xlogy-based definition.
with jax.debug_nans(True):
    jax.grad(f)(0.)

Invalid nan value encountered in the output of a C++-jit/pmap function. Calling the de-optimized version.
Invalid nan value encountered in the output of a C++-jit/pmap function. Calling the de-optimized version.


FloatingPointError: invalid value (nan) encountered in jit(mul). Because jax_config.debug_nans.value and/or config.jax_debug_infs is set, the de-optimized function (i.e., the function as if the `jit` decorator were removed) was called in an attempt to get a more precise error message. However, the de-optimized function did not produce invalid values during its execution. This behavior can result from `jit` optimizations causing the invalid value to be produced. It may also arise from having nan/inf constants as outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`. 

It may be possible to avoid the invalid value by removing the `jit` decorator, at the cost of losing optimizations. 

If you see this error, consider opening a bug report at https://github.com/jax-ml/jax.

Exploding/vanishing gradients can occur in neural networks with many layers, or in RNNs. For example:

NOTE: could also define a custom backwards derivative in torch/jax and have it print out

In [141]:
import jax
import jax.numpy as jnp

def compute_rnn_loss(inputs, targets, W):
    """
    Summed squared error of a tanh RNN unrolled over a sequence.

    inputs, targets: sequences scanned together along their leading axis.
    W: pair (W_h, W_x) of recurrent and input weight matrices.

    The hidden state starts at zero (length ``W_h.shape[0]``). At every step
    it becomes ``tanh(W_h @ h + W_x @ x)`` and its squared distance to the
    step's target is added to the running loss, which is returned.
    """
    W_h, W_x = W

    def step(carry, pair):
        hidden, running_loss = carry
        x_t, y_t = pair
        hidden = jnp.tanh(W_h @ hidden + W_x @ x_t)
        residual = hidden - y_t
        running_loss += jnp.sum(residual ** 2)
        return (hidden, running_loss), None

    initial_carry = (jnp.zeros(W_h.shape[0]), 0.)
    final_carry, _ = jax.lax.scan(step, initial_carry, (inputs, targets))
    return final_carry[1]

# Problem setup: dimension n, a 1000-step sequence of all-ones inputs and
# all-zero targets. (The weights themselves are drawn in the next cell.)
n = 10
inputs = jnp.ones([1000, n])
targets = jnp.zeros_like(inputs)

def loss(W):
    """RNN loss as a function of the weights only; `inputs` and `targets` are the notebook-level arrays."""
    return compute_rnn_loss(inputs, targets, W)


In [107]:
import jaxopt

# Keep 64-bit mode off (the result below is float32).
jax.config.update('jax_enable_x64', False)
# Standard-normal initial weights; W0[0] is used as W_h and W0[1] as W_x.
# NOTE(review): drawn from the unseeded global NumPy RNG.
W0 = np.random.normal(size=(2, n, n))
# Gradient descent on the RNN loss; the parameters it returns are all NaN (see output).
res = jaxopt.GradientDescent(loss).run(W0)
Wstar = res.params
Wstar

Array([[[nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan],
        ...,
        [nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan]],

       [[nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan],
        ...,
        [nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan]]], dtype=float32)

### Example 3: Test-first implementation of a numerical algorithm


Consider the following problem:

- Two particles execute continuous time random walk around a graph. 
- The rate of going from node $n$ to node $m$ is given by the edge weight $w_{nm}$.
- Whenever they are in the same node $n$, they are both absorbed into a state `crypt`, from which there is no exit, at a node-specific rate $c_n$.
- Let $T_\text{crypt}$ be the first hitting time to the `crypt` state.
- Problem: compute $\mathbb{P}(T_\text{crypt} > t)$.

In [151]:
import networkx as nx

# 3 x 3 grid graph (9 nodes); k is the side length.
# NOTE(review): this cell rebinds `G` (the genotype array in Example 1) and
# `n` (the RNN dimension in Example 2); after the loop `n` is a graph node.
k = 3
G = nx.grid_graph((3, 3))

# Node weights: independent exponential draws, used later as the absorption
# rates into `crypt`.
for n in G.nodes():
    G.nodes[n]['weight'] = np.random.exponential()

# Edge weights: independent exponential draws.
for e in G.edges():
    G.edges[e]['weight'] = np.random.exponential()

In [153]:
# Dense adjacency matrix of the grid graph
# (presumably filled with the 'weight' edge attribute — confirm the networkx default).
Q = nx.adjacency_matrix(G).todense()
Q


(9, 9)

## Computing $\mathbb{P}(T > t)$

- Let $(S(t), c(t)) \in \mathbb{R}^{k\times k} \times \mathbb{R}$ be the overall state of the system.
- Then 

\begin{align}
\frac{dS_{ij}(t)}{dt} &= \sum_x \left[ w_{xi}S_{xj} + w_{xj}S_{ix} - S_{ij}(w_{ix} + w_{jx}) \right] - \mathbf{1}(i=j)\, c_i S_{ii} \\
\frac{dc(t)}{dt} &= \sum_i S_{ii} c_i
\end{align}

In [None]:
def dS(t, S):
    # TODO: placeholder — the right-hand side of the ODE is not implemented yet.
    ...

def test_dS():
    # TODO: unfinished test. The last line stops mid-expression
    # (`np.random.dro` is not a NumPy attribute), and nothing is asserted yet.
    k = 10
    S = np.random.dro

In [192]:
def kronsum(A, B):
    """
    Kronecker sum of A and B: ``kron(A, I_B) + kron(I_A, B)``.

    The identities are sized by ``len(A)`` and ``len(B)``, so the two terms
    have matching shapes when A and B are square.
    """
    eye_a = np.eye(len(A))
    eye_b = np.eye(len(B))
    a_part = np.kron(A, eye_b)
    b_part = np.kron(eye_a, B)
    return a_part + b_part
    
# Matrix for the pair of particles: the Kronecker sum lets each particle move
# according to Q while the other stays put.
QQ = kronsum(Q, Q)
QQ.shape

(81, 81)

In [193]:
# Append one extra row and column (82 x 82 in total) for the absorbing `crypt` state.
# NOTE(review): QQ is rebound from itself, so re-running this cell pads again.
QQ = np.pad(QQ, [0, 1])
QQ.shape

(82, 82)

In [197]:
# Pair state (i, j) lives in row i * k**2 + j of QQ, so the "both particles in
# the same node" states (i, i) are the rows 0, k**2 + 1, 2 * (k**2 + 1), ...
# Put each node's absorption rate in the last (crypt) column of those rows.
# The `:-1` stop keeps the crypt row itself out of the selection.
QQ[:-1:k ** 2 + 1, -1] = [G.nodes[n]['weight'] for n in G.nodes()]

In [198]:
# Subtract each row's sum from its diagonal entry so that every row sums to zero.
QQ -= np.diag(QQ.sum(1))

In [204]:
import scipy
# Sanity check: the second-to-last row of expm(QQ) should sum to 1.
np.sum(np.eye(82)[-2] @ scipy.linalg.expm(QQ))

np.float64(0.9999999999999998)

In [205]:
def f(Q, y):
    """
    Time derivative of the flattened state vector ``y``.

    Q: (n, n) matrix of transition rates.
    y: vector of length n ** 2 + 1. The first n ** 2 entries are an n x n
       matrix in row-major order and the last entry is the `crypt` component.

    Returns a vector with the same layout as ``y``.

    NOTE(review): the absorption rates ``q`` (one per node) are read from the
    enclosing scope and are not defined in the cells visible here — confirm
    that ``q`` is set before this function is called.
    """
    # Take n from Q instead of the notebook-level `n`, which earlier cells
    # reuse for an RNN dimension and as a loop variable over graph nodes.
    n = Q.shape[0]
    assert Q.shape == (n, n)
    assert y.shape == (n ** 2 + 1,)
    # Q1[i, j] = (column sum i of Q) + (row sum j of Q)
    Q1 = Q.sum(0)[:, None] + Q.sum(1)[None, :]

    yn = y[:-1].reshape(n, n)
    yd = jnp.diag(yn)
    ret = Q @ yn + yn @ Q.T
    ret -= Q1 * yn
    # Mass leaving the diagonal (same-node) states at rate q ...
    ret -= jnp.diag(q * yd)
    # ... arrives in the final (crypt) component.
    return jnp.append(ret.reshape(n ** 2), q.dot(yd))

In [207]:
pip install ipytest

  pid, fd = os.forkpty()


Collecting ipytest
  Downloading ipytest-0.14.2-py3-none-any.whl (18 kB)
Installing collected packages: ipytest
Successfully installed ipytest-0.14.2
Note: you may need to restart the kernel to use updated packages.
