# Phylogenetic estimation using torchtree API

> [Amine Remita](https://github.com/maremita), Department of Computer Science, Université du Québec à Montréal

This notebook shows an example of using torchtree API to infer the branch lengths and GTR+I parameters with fixed topology.

I really enjoyed implementing the model using torchtree. Making this notebook helped me to have a better understanding of it. So, I wanted to share my experience with you.

In [1]:
# Uncomment if not included in ipython config
# %load_ext autoreload
# %autoreload 2

In [2]:
from torchtree.evolution.io import read_tree_and_alignment
from torchtree.evolution.alignment import Alignment, Sequence
from torchtree.evolution.taxa import Taxa, Taxon
from torchtree.evolution.datatype import NucleotideDataType
from torchtree.evolution.site_pattern import SitePattern
from torchtree.evolution.site_model import InvariantSiteModel
from torchtree.evolution.substitution_model import GTR
from torchtree.evolution.tree_likelihood import TreeLikelihoodModel

from torchtree import Parameter, TransformedParameter
from torchtree.evolution.tree_model import UnRootedTreeModel
from torchtree.distributions import Distribution
from torchtree.distributions.joint_distribution import JointDistributionModel

from torchtree.variational import ELBO

from collections import OrderedDict
from pprint import pprint
from json import dump

import torch
from torch.distributions import ExpTransform, StickBreakingTransform, SigmoidTransform

import numpy as np

In [3]:
# The results of this notebook should comparable with those generated using torchtree-cli and torchtree
# 
# torchtree-cli advi -m GTR -C 1 -I -i ../data/fluA.fasta -t ../data/fluA.tree > fluA_gtr_i.json
# torchtree fluA_gtr_i.json

In [4]:
# seed = 42
# torch.manual_seed(seed)
# np.random.seed(seed)

In [5]:
# FluA dataset

# Input files
newick_file = "../data/fluA.tree"
fasta_file = "../data/fluA.fasta"

## I. Building alignment and taxa

In [6]:
with open(newick_file, 'r') as fp:
    newick = fp.read()
    newick = newick.strip()

In [7]:
## Parse fasta and tree file

tree, dna = read_tree_and_alignment(newick_file, fasta_file, dated=True, heterochornous=True)

sequence_list = []
taxa_list = []

for taxon, seq in dna.items():
    sequence_list.append(Sequence(taxon.label, str(seq)))
    taxa_list.append(Taxon(taxon.label, None))

In [8]:
taxa_id = "taxa"

taxa = Taxa(taxa_id, taxa_list)

nb_blens = len(taxa) * 2 - 3
nb_blens

135

In [9]:
alignment_id = "alignment"

alignment = Alignment(
    alignment_id,              
    sequence_list, 
    taxa, NucleotideDataType('nuc')
)

## II. Evolutionary parameters and samples holders

In [10]:
tree_id = "tree"

#### Initialize branch length parameters

In [11]:
blens_id = f'{tree_id}.blens'
blens_unres_id = blens_id + '.unres'

blens_init = 0.1

# Parameter to hold samples of branch lengths
x_blens_unres = Parameter(blens_unres_id, torch.full((1, nb_blens,), np.log(blens_init)).detach())

branch_lengths = TransformedParameter(
    blens_id,
    x_blens_unres,
    ExpTransform() # lower = 0
)

In [12]:
branch_lengths.id

'tree.blens'

In [13]:
x_blens_unres.shape

torch.Size([1, 135])

In [14]:
branch_lengths().shape

torch.Size([1, 135])

#### Initialize site model parameter (pinv)

In [15]:
site_model_id = "sitemodel"

In [16]:
pinv_id = f'{site_model_id}.pinv'
pinv_unres_id = pinv_id + '.unres'

pinv_init = 0.1

# Parameter to hold samples of invariant probabilities
x_pinv_unres = Parameter(pinv_unres_id, torch.tensor((pinv_init,)).detach())

pinv_param = TransformedParameter(
    pinv_id,
    x_pinv_unres,
    SigmoidTransform() # lower = 0, upper = 1
)

In [17]:
pinv_param.id

'sitemodel.pinv'

In [18]:
site_model = InvariantSiteModel(
    site_model_id,
    pinv_param
)

#### Initialize GTR substitution model parameters

In [19]:
subst_model_id = "substmodel"

In [20]:
## Rates
rates_id = f'{subst_model_id}.rates'
rates_unres_id = rates_id + '.unres'

rates_init = 0.0

# Parameter to hold samples of GTR substitution rates
x_rates_unres = Parameter(rates_unres_id, torch.full((1, 5,), rates_init).detach())

rates_param = TransformedParameter(
    rates_id,
    x_rates_unres,
    StickBreakingTransform(),
)

In [21]:
print(rates_param.x.shape)
rates_param.x

torch.Size([1, 5])


Parameter(id_='substmodel.rates.unres', tensor=torch.tensor([[0., 0., 0., 0., 0.]]))

In [22]:
rates_param()

tensor([-10.7506])

In [23]:
rates_param.tensor

tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])

In [24]:
## Relative frequencies
freqs_id = f'{subst_model_id}.frequencies'
freqs_unres_id = freqs_id + '.unres'

freqs_init = 0.0

# Parameter to hold samples of GTR relative frequencies
x_freqs_unres = Parameter(rates_unres_id, torch.full((1, 3,), freqs_init).detach())

freqs_param = TransformedParameter(
    freqs_id,
    x_freqs_unres,
    StickBreakingTransform(),
)

In [25]:
freqs_param()

tensor([-5.5452])

In [26]:
freqs_param.tensor

tensor([[0.2500, 0.2500, 0.2500, 0.2500]])

In [27]:
subst_model = GTR(
    subst_model_id,
    rates_param,
    freqs_param
)

In [28]:
subst_model.rates

tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])

## III. Evolutionary joint distribution (tree likelihood + priors)

### 1. Tree Likelihood

#### Initialize tree model

In [29]:
tree_model = UnRootedTreeModel(
    tree_id,
    tree, 
    taxa, 
    branch_lengths
)

#### Initialize site pattern model

In [30]:
site_pattern_id = "patterns"

site_pattern = SitePattern(site_pattern_id, alignment)

#### Initialize tree likelihood model

In [31]:
like_id = 'like'

tree_likelihood = TreeLikelihoodModel(
    id_=like_id, 
    site_pattern=site_pattern, 
    tree_model=tree_model, 
    subst_model=subst_model, 
    site_model=site_model,
    clock_model=None,
)

In [32]:
tree_model

<torchtree.evolution.tree_model.UnRootedTreeModel at 0x7fb6401fcdc0>

In [33]:
# tree_likelihood.parameters()

In [34]:
tree_likelihood.sample_shape

torch.Size([1])

In [35]:
with torch.no_grad():
    tree_likelihood()

In [36]:
tree_likelihood.lp

tensor([[nan]])

### 2. Priors

#### Branch lengths *p*()

In [37]:
blens_p_id = blens_id + '.prior'

blens_prior = Distribution(
    blens_p_id,
    torch.distributions.Exponential,
    branch_lengths,
    OrderedDict({'rate': Parameter(None, torch.tensor([10.0]).detach())}),
)

print(blens_prior.id)
print(blens_prior.x.id)

tree.blens.prior
tree.blens


#### Invariant probability *p*()

In [38]:
## Remark: torchtree-cli does not add a prior on pinv when using -I

pinv_p_id = pinv_id + '.prior'

pinv_prior = Distribution(
    pinv_p_id,
    torch.distributions.Exponential,
#     x_pinv_unres,
    pinv_param,
    OrderedDict({'rate': Parameter(None, torch.tensor([2.0]).detach())})
)

print(pinv_prior.id)
print(pinv_prior.x.id)

sitemodel.pinv.prior
sitemodel.pinv


#### GTR rates *p*()

In [39]:
rates_p_id = rates_id + '.rates.prior'

rates_prior = Distribution(
    rates_p_id,
    torch.distributions.Dirichlet,
#     x_rates_unres,
    rates_param,
    OrderedDict({'concentration': Parameter(None, torch.ones(6).detach())})
)

print(rates_prior.id)
print(rates_prior.x.id)

substmodel.rates.rates.prior
substmodel.rates


#### GTR Frequencies *p*()

In [40]:
freqs_p_id = freqs_id + '.prior'

freqs_prior = Distribution(
    freqs_p_id,
    torch.distributions.Dirichlet,
#     x_freqs_unres,
    freqs_param,
    OrderedDict({'concentration': Parameter(None, torch.ones(4).detach())})
)

print(freqs_prior.id)
print(freqs_prior.x.id)

substmodel.frequencies.prior
substmodel.frequencies


In [41]:
prior_joint_list = [blens_prior, pinv_prior, rates_prior, freqs_prior]

for dist in prior_joint_list:
    print(dist.id)

tree.blens.prior
sitemodel.pinv.prior
substmodel.rates.rates.prior
substmodel.frequencies.prior


### 3. Evolutionary joint distribution *P*()

In [42]:
joint_list = [tree_likelihood] + prior_joint_list

joint_dist = JointDistributionModel('joint', joint_list)

In [43]:
for model in joint_dist._distributions.models():
    # _distributions attribute is a core.container.Container
    print(model.id)

like
tree.blens.prior
sitemodel.pinv.prior
substmodel.rates.rates.prior
substmodel.frequencies.prior


In [44]:
with torch.no_grad():
    joint_dist()

## IV. Variational meanfield distribution

### 1. Variational independent distributions

In [45]:
var_dist = "Normal"
var_id = "var" + "." + var_dist

In [46]:
var_parameters = []

In [47]:
def get_loc_logscale_normal(v):
    
    if isinstance(v, (int, float)):
        v = np.array(v)
    elif isinstance(v, torch.Tensor):
        v  = v.numpy()

    loc = np.log(v / np.sqrt(1 + 0.001 / v**2)).tolist()
    scale_log = np.log(np.sqrt(np.log(1 + 0.001 / v**2))).tolist()
    
    return loc, scale_log

#### Branch lengths *q*()

In [48]:
## q() Branch lengths
b = np.array(0.1)
b_loc, b_scale_log = get_loc_logscale_normal(b)


# Parameters of variational distribution

b_unres_loc = Parameter(blens_unres_id+".loc", 
                torch.full((nb_blens,), b_loc))

b_unres_scale = TransformedParameter(
    blens_unres_id+".scale",
    Parameter(blens_unres_id+".scale.unres", 
              torch.full((nb_blens,), b_scale_log)),
    ExpTransform()
)

var_parameters.append(b_unres_loc)
var_parameters.append(b_unres_scale)

blens_q = Distribution(
    var_id + "." + blens_unres_id,
    torch.distributions.Normal,
    x_blens_unres,
    OrderedDict({'loc': b_unres_loc, 'scale': b_unres_scale})
)


print(b_loc)
print(b_scale_log)
print()
print(blens_q.id)
print(blens_q.x.id)


-2.3502401828962083
-1.1753093277566469

var.Normal.tree.blens.unres
tree.blens.unres


In [49]:
# blens_q.parameters()

In [50]:
with torch.no_grad():
    blens_q.rsample([100])

In [51]:
blens_q.x.shape

torch.Size([100, 135])

In [52]:
x_blens_unres.tensor[0]

tensor([-2.0237, -2.4660, -2.4146, -1.7869, -2.3419, -2.2407, -2.6596, -2.8978,
        -2.4016, -2.5741, -2.2607, -3.0991, -2.5679, -2.3381, -2.2194, -2.4720,
        -2.3597, -2.3949, -2.6676, -2.4295, -2.6130, -2.9177, -2.5750, -2.3864,
        -2.5243, -2.2456, -2.2272, -2.4884, -2.1980, -2.5346, -2.8250, -2.5847,
        -2.4022, -3.1242, -2.4724, -2.4796, -2.3231, -1.8745, -2.2846, -2.1611,
        -2.5171, -1.9773, -2.1952, -2.6862, -2.2720, -1.8243, -2.3279, -2.0839,
        -2.6234, -2.5094, -2.6803, -2.4260, -2.0064, -2.4540, -2.4428, -2.3874,
        -2.1392, -2.2806, -1.9871, -2.3693, -2.1365, -2.2626, -2.5355, -2.3941,
        -2.3234, -2.4946, -2.6250, -2.5848, -2.2591, -2.4657, -2.4017, -2.4425,
        -2.7320, -2.6632, -2.2307, -2.2290, -2.0762, -3.2016, -2.6090, -2.1238,
        -2.3364, -2.4631, -2.7328, -2.4546, -2.4228, -2.0226, -2.1185, -2.3535,
        -2.1883, -2.6048, -2.2926, -1.9646, -2.0351, -2.6276, -2.7922, -2.8000,
        -2.5625, -2.4581, -2.3153, -3.05

In [53]:
branch_lengths.tensor[0]

tensor([0.1322, 0.0849, 0.0894, 0.1675, 0.0961, 0.1064, 0.0700, 0.0551, 0.0906,
        0.0762, 0.1043, 0.0451, 0.0767, 0.0965, 0.1087, 0.0844, 0.0944, 0.0912,
        0.0694, 0.0881, 0.0733, 0.0541, 0.0762, 0.0920, 0.0801, 0.1059, 0.1078,
        0.0830, 0.1110, 0.0793, 0.0593, 0.0754, 0.0905, 0.0440, 0.0844, 0.0838,
        0.0980, 0.1534, 0.1018, 0.1152, 0.0807, 0.1384, 0.1113, 0.0681, 0.1031,
        0.1613, 0.0975, 0.1244, 0.0726, 0.0813, 0.0685, 0.0884, 0.1345, 0.0859,
        0.0869, 0.0919, 0.1177, 0.1022, 0.1371, 0.0935, 0.1181, 0.1041, 0.0792,
        0.0913, 0.0979, 0.0825, 0.0724, 0.0754, 0.1044, 0.0849, 0.0906, 0.0869,
        0.0651, 0.0697, 0.1075, 0.1076, 0.1254, 0.0407, 0.0736, 0.1196, 0.0967,
        0.0852, 0.0650, 0.0859, 0.0887, 0.1323, 0.1202, 0.0950, 0.1121, 0.0739,
        0.1010, 0.1402, 0.1307, 0.0723, 0.0613, 0.0608, 0.0771, 0.0856, 0.0987,
        0.0470, 0.0628, 0.0966, 0.0784, 0.0654, 0.0674, 0.0806, 0.0783, 0.1124,
        0.0737, 0.0828, 0.1077, 0.1074, 

#### Invariant probability *q*()

In [54]:
p_loc = torch.distributions.SigmoidTransform().inv(torch.tensor(pinv_init)).tolist()
p_scale_log = -1.89712

pinv_unres_loc =  Parameter(pinv_unres_id+".loc", torch.tensor([p_loc]))

pinv_unres_scale = TransformedParameter(
    pinv_unres_id+".scale",
    Parameter(pinv_unres_id+".scale.unres", 
              torch.tensor([p_scale_log])),
    ExpTransform()
)


var_parameters.append(pinv_unres_loc)
var_parameters.append(pinv_unres_scale)


pinv_q = Distribution(
    var_id + "." + pinv_unres_id,
    torch.distributions.Normal,
    x_pinv_unres,
    OrderedDict({'loc': pinv_unres_loc, 'scale': pinv_unres_scale})
)


print(p_loc)
print(p_scale_log)
print()
print(pinv_q.id)
print(pinv_q.x.id)

-2.1972246170043945
-1.89712

var.Normal.sitemodel.pinv.unres
sitemodel.pinv.unres


In [55]:
with torch.no_grad():
    pinv_q.rsample([100])

In [56]:
x_pinv_unres.tensor[0]

tensor([-2.2672])

In [57]:
pinv_param.tensor[0]

tensor([0.0939])

#### GTR rates *q*()

In [58]:
r_loc, r_scale_log = 0.5, -1.89712


rates_unres_loc = Parameter(rates_unres_id+".loc", torch.full((5,), r_loc))

rates_unres_scale = TransformedParameter(
    rates_unres_id+".scale",
    Parameter(rates_unres_id+".scale.unres", 
              torch.full((5,), r_scale_log)),
    ExpTransform()
)

var_parameters.append(rates_unres_loc)
var_parameters.append(rates_unres_scale)

rates_q = Distribution(
    var_id + "." + rates_unres_id,
    torch.distributions.Normal,
    x_rates_unres,
    OrderedDict({'loc': rates_unres_loc, 'scale': rates_unres_scale})
)

print(rates_q.id)
print(rates_q.x.id)

var.Normal.substmodel.rates.unres
substmodel.rates.unres


In [59]:
with torch.no_grad():
    rates_q.rsample([100])
    
rates_q.x.shape

torch.Size([100, 5])

#### GTR Frequencies *q*()

In [60]:
f_loc, f_scale_log = 0.5, -1.89712


freqs_unres_loc = Parameter(freqs_unres_id+".loc", torch.full((3,), f_loc))

freqs_unres_scale = TransformedParameter(
    freqs_unres_id+".scale",
    Parameter(freqs_unres_id+".scale.unres", 
              torch.full((3,), f_scale_log)),
    ExpTransform()
)

var_parameters.append(freqs_unres_loc)
var_parameters.append(freqs_unres_scale)

freqs_q = Distribution(
    var_id + "." + freqs_unres_id,
    torch.distributions.Normal,
    x_freqs_unres,
    OrderedDict({'loc': freqs_unres_loc, 'scale': freqs_unres_scale})
)

print(rates_q.id)
print(rates_q.x.id)

var.Normal.substmodel.rates.unres
substmodel.rates.unres


In [61]:
with torch.no_grad():
    freqs_q.rsample([100])

freqs_q.x.shape

torch.Size([100, 3])

### 2. Variational joint distribution *Q*()

In [62]:
## Joint variational

var_joint_list = [blens_q, pinv_q, rates_q, freqs_q]

var_joint_dist = JointDistributionModel('variational', var_joint_list)

In [63]:
# rates_q.parameters()

In [64]:
for model in var_joint_dist._distributions.models():
    # _distributions attribute is a core.container.Container
    print(model.id)

var.Normal.tree.blens.unres
var.Normal.sitemodel.pinv.unres
var.Normal.substmodel.rates.unres
var.Normal.substmodel.frequencies.unres


In [65]:
## Get tensors of variational parameters
var_param_tensors = [p.tensor for p in var_parameters]

print(len(var_param_tensors))

8


### 3. ELBO

In [66]:
elbo_samples = 1

elbo = ELBO(
        id_='elbo',
        q=var_joint_dist,
        p=joint_dist,
        samples=torch.Size([elbo_samples]),
        entropy=False,
)

In [67]:
# elbo()

## V. Fitting the variational model

In [68]:
maximize=True
lr = 0.05 
# Fitting crashes when lr=0.1 and even sometimes with smaller values (invalid values in sampling)
weight_decay=0.001

optimizer = torch.optim.Adam(var_param_tensors,
                             maximize=maximize,
                             lr=lr,
                             weight_decay=weight_decay)

In [69]:
lambda1r = lambda epoch: 1.0 / (epoch + 1)**0.5

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1r)

In [70]:
iterations = 5000
trials = 10
var_samples = 100
print_every = 100


for p in var_param_tensors:
    p.requires_grad = True


for epoch in range(1, iterations + 1):

    for trial in range(trials):
        optimizer.zero_grad()
        loss = elbo()

        optimizer.zero_grad()
        loss.backward()

        for p in var_param_tensors:
            retry = torch.any(torch.isinf(p.grad)) or torch.any(
                torch.isnan(p.grad)
            )
            if retry:
                break
        if not retry:
            break
        else:
            for p in var_parameters:
                p.fire_parameter_changed()

    optimizer.step()

    scheduler.step()

    for p in var_parameters:
        p.fire_parameter_changed()

    with torch.no_grad():
        logl = tree_likelihood.lp.mean(0).item()
        logq = var_joint_dist.lp.mean(0).item()

        # Sample from variational distributions q()
        for var_distr in var_joint_list:
            var_distr.sample([var_samples])

        mean_tree_length = branch_lengths.tensor.sum(1).mean()
        mean_pinv = pinv_param.tensor.mean()
        mean_rates = rates_param.tensor.mean(0)
        mean_freqs = freqs_param.tensor.mean(0)

        # printing
        if epoch % print_every == 0:
            print(f'{epoch}\tELBO: {loss:.3f}\tLogL: {logl:.3f}\tLogQ: {logq:.3f}')
            print(f'tree length: {mean_tree_length:.3f}\npinv: {mean_pinv:.3f}')
            print('GTR rates:', [f"{r:.3f}" for r in mean_rates.tolist()])
            print(f'GTR freqs:',[f"{f:.3f}" for f in mean_freqs.tolist()],'\n') 


100	ELBO: -5663.874	LogL: -5946.705	LogQ: -25.098
tree length: 6.021
pinv: 0.207
GTR rates: ['0.306', '0.294', '0.090', '0.093', '0.168', '0.048']
GTR freqs: ['0.355', '0.253', '0.222', '0.171'] 

200	ELBO: -5184.143	LogL: -5476.102	LogQ: -19.176
tree length: 4.517
pinv: 0.265
GTR rates: ['0.265', '0.340', '0.075', '0.078', '0.199', '0.043']
GTR freqs: ['0.340', '0.234', '0.236', '0.190'] 

300	ELBO: -4902.405	LogL: -5207.971	LogQ: -25.322
tree length: 3.695
pinv: 0.305
GTR rates: ['0.226', '0.373', '0.068', '0.066', '0.221', '0.045']
GTR freqs: ['0.340', '0.222', '0.244', '0.195'] 

400	ELBO: -4718.143	LogL: -5023.073	LogQ: -19.024
tree length: 3.155
pinv: 0.338
GTR rates: ['0.191', '0.398', '0.066', '0.057', '0.245', '0.043']
GTR freqs: ['0.343', '0.211', '0.250', '0.197'] 

500	ELBO: -4631.680	LogL: -4952.049	LogQ: -31.433
tree length: 2.773
pinv: 0.365
GTR rates: ['0.175', '0.414', '0.062', '0.055', '0.250', '0.044']
GTR freqs: ['0.341', '0.210', '0.248', '0.201'] 

600	ELBO: -4497

4300	ELBO: -3778.479	LogL: -4197.314	LogQ: -107.888
tree length: 0.601
pinv: 0.560
GTR rates: ['0.119', '0.344', '0.067', '0.016', '0.392', '0.062']
GTR freqs: ['0.335', '0.222', '0.215', '0.228'] 

4400	ELBO: -3793.816	LogL: -4210.344	LogQ: -105.724
tree length: 0.593
pinv: 0.561
GTR rates: ['0.120', '0.341', '0.067', '0.015', '0.394', '0.062']
GTR freqs: ['0.334', '0.221', '0.219', '0.227'] 

4500	ELBO: -3788.309	LogL: -4202.057	LogQ: -102.100
tree length: 0.585
pinv: 0.560
GTR rates: ['0.122', '0.347', '0.066', '0.015', '0.388', '0.062']
GTR freqs: ['0.336', '0.221', '0.214', '0.229'] 

4600	ELBO: -3780.934	LogL: -4190.281	LogQ: -97.945
tree length: 0.578
pinv: 0.564
GTR rates: ['0.120', '0.348', '0.066', '0.016', '0.390', '0.060']
GTR freqs: ['0.336', '0.221', '0.214', '0.228'] 

4700	ELBO: -3769.205	LogL: -4190.500	LogQ: -110.423
tree length: 0.570
pinv: 0.566
GTR rates: ['0.121', '0.346', '0.066', '0.016', '0.389', '0.061']
GTR freqs: ['0.336', '0.222', '0.214', '0.228'] 

4800	E