# Imports

In [1]:
import numpy as np
import torch
from torch.autograd import Variable
# Seed both random number generators used in this notebook so that every run is reproducible.
SEED = 1
torch.manual_seed(SEED)
np.random.seed(SEED)

# Create ground truth rate matrix 'Q_true'

In [2]:
# Ground-truth rate matrix. The entries are written in tenths and scaled by 0.1 afterwards;
# every row sums to zero (the diagonal is minus the sum of the off-diagonal rates of its row).
rates_in_tenths = np.array([
    [-10, 3, 7],
    [8, -12, 4],
    [5, 9, -14],
])
Q_true = torch.tensor(rates_in_tenths * 0.1)
print(f"Ground truth rate matrix Q_true:\n{Q_true}")
# The rate matrix is square: one row (and one column) per state.
num_states = Q_true.size(0)

Ground truth rate matrix Q_true:
tensor([[-1.0000,  0.3000,  0.7000],
        [ 0.8000, -1.2000,  0.4000],
        [ 0.5000,  0.9000, -1.4000]], dtype=torch.float64)


# Generate synthetic training data

In [3]:
# Simulate m observed transitions under the ground-truth rate matrix. Each datapoint records the state
# the process started in, the state it ended in, and the branch length that elapsed in between.
m = 3000
print(f"Generating {m} synthetic datapoints of the form (starting_state, ending_state, branch_length)")
state_ids = list(range(num_states))
dataset = []
for _ in range(m):
    # The order of the three random draws below determines the random stream; keep it unchanged.
    branch_length = np.random.uniform(0, 1)
    starting_state = np.random.choice(state_ids)
    # Row `starting_state` of exp(branch_length * Q_true) holds the probabilities of each ending state.
    transition_matrix = torch.matrix_exp(branch_length * Q_true)
    ending_state = np.random.choice(state_ids, p=transition_matrix[starting_state, :])
    dataset.append((starting_state, ending_state, branch_length))

Generating 3000 synthetic datapoints of the form (starting_state, ending_state, branch_length)


# Learn the transition rate matrix!

In [4]:
%%time
# Parameterize the rate matrix Q. We choose to parameterize the off-diagonal elements of Q as
# Q_ij = theta_ij ** 2 to ensure their non-negativity; each diagonal element is minus the sum of the
# off-diagonal elements of its row.
def build_rate_matrix(theta, mask_kill_diagonal):
    """Map the free parameters `theta` to a rate matrix Q.

    Args:
        theta: square tensor of free parameters; its diagonal entries do not influence Q.
        mask_kill_diagonal: tensor of the same shape with zeros on the diagonal and ones elsewhere.

    Returns:
        Tensor Q with Q_ij = theta_ij ** 2 for i != j and Q_ii = -sum_{j != i} Q_ij.
    """
    off_diagonal_elements = theta * theta * mask_kill_diagonal
    diagonal_elements = -off_diagonal_elements.sum(dim=1)
    return off_diagonal_elements + torch.diag(diagonal_elements)


# Create the trainable leaf tensor directly; the torch.autograd.Variable wrapper is deprecated.
theta = torch.rand(size=(num_states, num_states), requires_grad=True)
optimizer = torch.optim.SGD([theta], lr=1.0, momentum=0.0, weight_decay=0)

# Loop-invariant tensors, built once instead of once per epoch.
mask_kill_diagonal = torch.tensor(data=np.ones(shape=(num_states, num_states)) - np.eye(num_states))
observations = dataset[:m]
starting_states = torch.tensor(np.array([start for start, _, _ in observations]), dtype=torch.long)
ending_states = torch.tensor(np.array([end for _, end, _ in observations]), dtype=torch.long)
branch_lengths = torch.tensor(np.array([length for _, _, length in observations]), dtype=torch.float64)
datapoint_indices = torch.arange(m)

num_epochs = 10
print(f"Training for {num_epochs} epochs")
for epoch in range(num_epochs):
    optimizer.zero_grad()

    # Forward pass
    # First compute Q based on the current theta
    Q = build_rate_matrix(theta, mask_kill_diagonal)
    if epoch == 0:
        print(f"Initial transition rate matrix Q:\n{Q}")
    # Now compute the loss: the average negative log-likelihood of the observed transitions.
    # torch.matrix_exp accepts a batch of matrices, so all exp(branch_length_k * Q) are computed in a
    # single call instead of one call per datapoint.
    transition_matrices = torch.matrix_exp(branch_lengths[:, None, None] * Q)
    likelihoods = transition_matrices[datapoint_indices, starting_states, ending_states]
    loss = -torch.log(likelihoods).mean()
    print(f"Epoch {epoch}: loss = {loss}, Frob norm to ground truth = {torch.sqrt(torch.sum((Q - Q_true) * (Q - Q_true)))}")
    # Take a gradient step. The graph is rebuilt from theta every epoch, so it does not need to be retained.
    loss.backward()
    optimizer.step()

# The Q from the last loop iteration was computed before the final optimizer step, so rebuild it from
# the final theta before reporting it (this also makes the report well-defined when num_epochs == 0).
with torch.no_grad():
    Q = build_rate_matrix(theta, mask_kill_diagonal)
print(f"Learnt transition rate matrix Q:\n{Q}")
print(f"Ground truth rate matrix Q_true:\n{Q_true}")

Training for 10 epochs
Initial transition rate matrix Q:
tensor([[-0.2405,  0.0780,  0.1625],
        [ 0.5398, -1.1795,  0.6398],
        [ 0.1577,  0.5691, -0.7268]], dtype=torch.float64,
       grad_fn=<AddBackward0>)
Epoch 0: loss = 0.9191385530229957, Frob norm to ground truth = 1.3117162221396912
Epoch 1: loss = 0.8470322784935, Frob norm to ground truth = 0.7621375912355856
Epoch 2: loss = 0.8380485839248635, Frob norm to ground truth = 0.5472146870614751
Epoch 3: loss = 0.8347216295447729, Frob norm to ground truth = 0.4176521335761389
Epoch 4: loss = 0.8331908145946803, Frob norm to ground truth = 0.3329168691626058
Epoch 5: loss = 0.8323976311833947, Frob norm to ground truth = 0.27559248371533374
Epoch 6: loss = 0.8319542850662115, Frob norm to ground truth = 0.23638811529555961
Epoch 7: loss = 0.8316925381629672, Frob norm to ground truth = 0.20962481754609724
Epoch 8: loss = 0.8315312283610748, Frob norm to ground truth = 0.1915232438834196
Epoch 9: loss = 0.83142825497589