# Phylogenetic estimation using torchtree CLI

> [Amine Remita](https://github.com/maremita), Department of Computer Science, Université du Québec à Montréal

This notebook shows an example of using torchtree to generate json file to infer the branch lengths and GTR+I parameters with fixed topology.

In [1]:
# Uncomment if not included in ipython config
# %load_ext autoreload
# %autoreload 2

In [2]:
from torchtree import Parameter
from torchtree.evolution.alignment import read_fasta_sequences
from torchtree.evolution.tree_model import UnRootedTreeModel
from torchtree.distributions import Distribution

from torchtree.cli.advi import create_meanfield
from torchtree.cli.jacobians import create_jacobians

from pprint import pprint
from json import dump

import torch

In [3]:
# The results of this notebook should comparable with those generated using torchtree-cli and torchtree
# 
# torchtree-cli advi -m GTR -C 1 -I -i ../data/fluA.fasta -t ../data/fluA.tree > fluA_gtr_i.json
# torchtree fluA_gtr_i.json

In [4]:
# seed = 42
# torch.manual_seed(seed)
# np.random.seed(seed)

In [5]:
# FluA dataset

# Input files
newick_file = "../data/fluA.tree"
fasta_file = "../data/fluA.fasta"

In [6]:
# Output files
# TODO it's better to put these files in a separate folder (results/ for example)

json_outfile= "fluA_gtr_i.json"
checkpoint_file = "fluA_checkpoint.json"
file_sample_name = 'fluA_samples.csv'
tree_file_name = 'fluA_samples.trees'

## I. Building alignment and taxa dataset

In [7]:
with open(newick_file, 'r') as fp:
    newick = fp.read()
    newick = newick.strip()

In [8]:
## Parse fasta file

sequences = read_fasta_sequences(fasta_file)
sequence_list = []
taxa_list = []

for sequence in sequences:
    sequence_list.append(
        {'taxon': sequence.taxon, 'sequence': sequence.sequence}
    )
    taxa_list.append({'id': sequence.taxon, 'type': 'Taxon'})

In [9]:
## Create taxa dictionary
taxa = {'id': "taxa", 
        'type': 'Taxa', 
        'taxa': taxa_list}

nb_blens = len(taxa['taxa']) * 2 - 3

In [10]:
## Create alignment dictionary
alignment = {
    'id': "alignment",
    'type': 'Alignment',
    'datatype':{
        "id": "data_type",
        "type": "NucleotideDataType"
    },
    'taxa': "taxa",
    "sequences":sequence_list
}

## II. Evolutionary parameters

In [11]:
tree_id = "tree"

#### Initialize branch length parameters

In [12]:
branch_lengths = Parameter.json_factory(
    f'{tree_id}.blens', 
    **{'tensor': 0.1, 'full': [nb_blens]}
)

branch_lengths['lower'] = 0.0

#### Initialize site model parameter (pinv)

In [13]:
site_model_id = "sitemodel"

In [14]:
prop = Parameter.json_factory(f'{site_model_id}.pinv', **{'tensor': [0.1]})
prop['lower'] = 0
prop['upper'] = 1

site_model = {
    'id': site_model_id, 
    'type': 'InvariantSiteModel', 
    'invariant': prop}

site_model

{'id': 'sitemodel',
 'type': 'InvariantSiteModel',
 'invariant': {'id': 'sitemodel.pinv',
  'type': 'Parameter',
  'tensor': [0.1],
  'lower': 0,
  'upper': 1}}

#### Initialize GTR substitution model parameters

In [15]:
subst_model_id = "substmodel"

In [16]:
## Rates
rates = Parameter.json_factory(
    "{}.rates".format(subst_model_id), 
    **{'tensor': 1 / 6, 'full': [6]}
)
rates['simplex'] = True
# rates['transform'] = 'torch.distributions.StickBreakingTransform'

In [17]:
## Relative frequencies

frequencies = Parameter.json_factory(
    "{}.frequencies".format(subst_model_id),
    **{'tensor': [0.25] * 4}
)

frequencies['simplex'] = True

In [18]:
## The GTR model

subst_model = {
    'id': subst_model_id, 
    'type': 'GTR',
    "rates": rates,
    "frequencies": frequencies
}

subst_model

{'id': 'substmodel',
 'type': 'GTR',
 'rates': {'id': 'substmodel.rates',
  'type': 'Parameter',
  'full': [6],
  'tensor': 0.16666666666666666,
  'simplex': True},
 'frequencies': {'id': 'substmodel.frequencies',
  'type': 'Parameter',
  'tensor': [0.25, 0.25, 0.25, 0.25],
  'simplex': True}}

## III. Evolutionary joint distribution (tree likelihood + priors)

### 1. Tree Likelihood

#### Initialize tree model

In [19]:
tree_model = UnRootedTreeModel.json_factory(
    tree_id, newick, branch_lengths, 'taxa')

# pprint(tree_model)

#### Initialize site pattern model

In [20]:
site_pattern_id = "patterns"

site_pattern = {
    'id': site_pattern_id, 
    'type': 'SitePattern', 
    'alignment': 'alignment'}

site_pattern

{'id': 'patterns', 'type': 'SitePattern', 'alignment': 'alignment'}

#### Initialize tree likelihood model

In [21]:
tree_likelihood = {
        'id': 'like',
        'type': 'TreeLikelihoodModel',
        'tree_model': tree_model,
        'site_model': site_model,
        'site_pattern': site_pattern,
        'substitution_model': subst_model,
    }

# pprint(tree_likelihood)

### 2. Priors

#### Branch lengths *p*()

In [22]:
blens_prior = Distribution.json_factory(
    f'{tree_id}.blens.prior',
    'torch.distributions.Exponential',
    f'{tree_id}.blens',
    {'rate': 10.0}
)

blens_prior

{'id': 'tree.blens.prior',
 'type': 'Distribution',
 'distribution': 'torch.distributions.Exponential',
 'x': 'tree.blens',
 'parameters': {'rate': 10.0}}

#### Invariant probability *p*()

In [23]:
## Remark: torchtree-cli does not add a prior on pinv when using -I

site_model_prior = Distribution.json_factory(
    f'{site_model_id}.pinv.prior',
    'torch.distributions.Exponential',
    f'{site_model_id}.pinv',
    {'rate': 2.0}
)

site_model_prior

{'id': 'sitemodel.pinv.prior',
 'type': 'Distribution',
 'distribution': 'torch.distributions.Exponential',
 'x': 'sitemodel.pinv',
 'parameters': {'rate': 2.0}}

#### GTR rates *p*()

In [24]:
rates_prior = Distribution.json_factory(
    f'{subst_model_id}.rates.prior',
    'torch.distributions.Dirichlet',
    f'{subst_model_id}.rates',
    {'concentration': [1.0] * 6}
)

rates_prior

{'id': 'substmodel.rates.prior',
 'type': 'Distribution',
 'distribution': 'torch.distributions.Dirichlet',
 'x': 'substmodel.rates',
 'parameters': {'concentration': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}}

#### GTR Frequencies *p*()

In [25]:
freqs_prior = Distribution.json_factory(
    f'{subst_model_id}.frequencies.prior',
    'torch.distributions.Dirichlet',
    f'{subst_model_id}.frequencies',
    {'concentration': [1.0] * 4}
)

freqs_prior

{'id': 'substmodel.frequencies.prior',
 'type': 'Distribution',
 'distribution': 'torch.distributions.Dirichlet',
 'x': 'substmodel.frequencies',
 'parameters': {'concentration': [1.0, 1.0, 1.0, 1.0]}}

In [26]:
prior_joint_list = [blens_prior, site_model_prior, freqs_prior, rates_prior]
prior_joint_list

[{'id': 'tree.blens.prior',
  'type': 'Distribution',
  'distribution': 'torch.distributions.Exponential',
  'x': 'tree.blens',
  'parameters': {'rate': 10.0}},
 {'id': 'sitemodel.pinv.prior',
  'type': 'Distribution',
  'distribution': 'torch.distributions.Exponential',
  'x': 'sitemodel.pinv',
  'parameters': {'rate': 2.0}},
 {'id': 'substmodel.frequencies.prior',
  'type': 'Distribution',
  'distribution': 'torch.distributions.Dirichlet',
  'x': 'substmodel.frequencies',
  'parameters': {'concentration': [1.0, 1.0, 1.0, 1.0]}},
 {'id': 'substmodel.rates.prior',
  'type': 'Distribution',
  'distribution': 'torch.distributions.Dirichlet',
  'x': 'substmodel.rates',
  'parameters': {'concentration': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}}]

### 3. Evolutionary joint distribution *P*()

In [27]:
joint_list = [tree_likelihood] + prior_joint_list

joint_dic = {
    'id': 'joint',
    'type': 'JointDistributionModel',
    'distributions': joint_list
}

In [28]:
json_list = []
json_list.append(taxa)
json_list.append(alignment)
json_list.append(joint_dic)
json_list

[{'id': 'taxa',
  'type': 'Taxa',
  'taxa': [{'id': 'A_Belgium_2_1981', 'type': 'Taxon'},
   {'id': 'A_ChristHospital_231_1982', 'type': 'Taxon'},
   {'id': 'A_Philippines_2_1982', 'type': 'Taxon'},
   {'id': 'A_Baylor1B_1983', 'type': 'Taxon'},
   {'id': 'A_Oita_3_1983', 'type': 'Taxon'},
   {'id': 'A_Texas_12764_1983', 'type': 'Taxon'},
   {'id': 'A_Alaska_8_1984', 'type': 'Taxon'},
   {'id': 'A_Caen_1_1984', 'type': 'Taxon'},
   {'id': 'A_Texas_17988_1984', 'type': 'Taxon'},
   {'id': 'A_Colorado_2_1987', 'type': 'Taxon'},
   {'id': 'A_Guangdong_9_1987', 'type': 'Taxon'},
   {'id': 'A_Guizhou_1_1987', 'type': 'Taxon'},
   {'id': 'A_LosAngeles_1987', 'type': 'Taxon'},
   {'id': 'A_Qingdao_10_1987', 'type': 'Taxon'},
   {'id': 'A_Shanghai_11_1987', 'type': 'Taxon'},
   {'id': 'A_Sichuan_2_1987', 'type': 'Taxon'},
   {'id': 'A_Sydney_1_1987', 'type': 'Taxon'},
   {'id': 'A_Tokyo_1275_1987', 'type': 'Taxon'},
   {'id': 'A_Victoria_7_1987', 'type': 'Taxon'},
   {'id': 'A_Alaska_9_1992', 

## IV. Variational meanfield distribution

### 1. Variational independent distributions

In [29]:
var_dist = "Normal"
var_id = 'var.' + var_dist

In [30]:
# convert Parameters with constraints to TransformedParameters
# and create variational distributions

# TO_DO not use directlly create_meanfield()
var_distributions, var_parameters = create_meanfield(var_id, json_list, var_dist)

# Jacobian
# I don't think this is useful here
jacobians_list = create_jacobians(json_list)
joint_dic['distributions'].extend(jacobians_list)

### 2. Variational joint distribution *Q*()

In [31]:
# Variational joint distribution
variational = {'id': "variational", 'type': 'JointDistributionModel'}

variational['distributions'] = var_distributions
json_list.append(variational)

## V. Fitting the variational model using ADVI

### 1. Automatic Differentiation Variational Inference (ADVI)

In [32]:
## Hyperparams
lr = 0.1
iterations = 100000
grad_samples = 1
convergence_every = 100
elbo_samples = 100
tol_rel_obj = 0.01

In [33]:
# Initialize optimizer
advi_dic = {
    'id': 'advi',
    'type': 'Optimizer',
    'algorithm': 'torch.optim.Adam',
    'options': {'lr': lr},
    'maximize': True,
    'checkpoint': checkpoint_file,
    'iterations': iterations,
    'loss': {
        'id': 'elbo',
        'type': 'ELBO',
        'samples': grad_samples,
        'joint': 'joint', 
        'variational': 'variational',
        },
    'parameters': var_parameters,
}

In [34]:
# Initialize convergence
advi_dic['convergence'] = {
    'type': 'StanVariationalConvergence',
    'max_iterations': iterations,
    'loss': 'elbo',
    'every': convergence_every,
    'samples': elbo_samples,
    'tol_rel_obj': tol_rel_obj,
    }


In [35]:
# Initialize scheduler
advi_dic['scheduler'] = {
    'type': 'torchtree.optim.Scheduler',
    'scheduler': 'torch.optim.lr_scheduler.LambdaLR',
    'lr_lambda': 'lambda epoch: 1.0 / (epoch + 1)**0.5',
}

json_list.append(advi_dic)

In [36]:
json_list

[{'id': 'taxa',
  'type': 'Taxa',
  'taxa': [{'id': 'A_Belgium_2_1981', 'type': 'Taxon'},
   {'id': 'A_ChristHospital_231_1982', 'type': 'Taxon'},
   {'id': 'A_Philippines_2_1982', 'type': 'Taxon'},
   {'id': 'A_Baylor1B_1983', 'type': 'Taxon'},
   {'id': 'A_Oita_3_1983', 'type': 'Taxon'},
   {'id': 'A_Texas_12764_1983', 'type': 'Taxon'},
   {'id': 'A_Alaska_8_1984', 'type': 'Taxon'},
   {'id': 'A_Caen_1_1984', 'type': 'Taxon'},
   {'id': 'A_Texas_17988_1984', 'type': 'Taxon'},
   {'id': 'A_Colorado_2_1987', 'type': 'Taxon'},
   {'id': 'A_Guangdong_9_1987', 'type': 'Taxon'},
   {'id': 'A_Guizhou_1_1987', 'type': 'Taxon'},
   {'id': 'A_LosAngeles_1987', 'type': 'Taxon'},
   {'id': 'A_Qingdao_10_1987', 'type': 'Taxon'},
   {'id': 'A_Shanghai_11_1987', 'type': 'Taxon'},
   {'id': 'A_Sichuan_2_1987', 'type': 'Taxon'},
   {'id': 'A_Sydney_1_1987', 'type': 'Taxon'},
   {'id': 'A_Tokyo_1275_1987', 'type': 'Taxon'},
   {'id': 'A_Victoria_7_1987', 'type': 'Taxon'},
   {'id': 'A_Alaska_9_1992', 

### 2. Sampler

In [37]:
parameters = [
    'tree.blens', 
    "substmodel.rates", 
    "substmodel.frequencies", 
    "sitemodel.pinv"
]

samples = 1000

sampler = {
    "id": 'sampler',
    "type": "Sampler",
    "model": 'variational',
    "samples": samples,
    "loggers": [
        {
            "id": "logger",
            "type": "Logger",
            "file_name": file_sample_name,
            "parameters": ['joint', 'like', 'variational'] + parameters,
            "delimiter": "\t",
        },
        {
            "id": "tree.logger",
            "type": "TreeLogger",
            "file_name": tree_file_name,
            "tree_model": "tree",
        },
    ],
}

sampler

{'id': 'sampler',
 'type': 'Sampler',
 'model': 'variational',
 'samples': 1000,
 'loggers': [{'id': 'logger',
   'type': 'Logger',
   'file_name': 'fluA_samples.csv',
   'parameters': ['joint',
    'like',
    'variational',
    'tree.blens',
    'substmodel.rates',
    'substmodel.frequencies',
    'sitemodel.pinv'],
   'delimiter': '\t'},
  {'id': 'tree.logger',
   'type': 'TreeLogger',
   'file_name': 'fluA_samples.trees',
   'tree_model': 'tree'}]}

In [38]:
json_list.append(sampler)

### 3. Create json file

In [39]:
# create json file

with open(json_outfile, "w") as fh:
    dump(json_list, fh, indent=2)

### 4. Run torchtree

In [40]:
# torchtree should  be installed in the environment where jupyter is running

!torchtree  $json_outfile

SEED: 2582957386847507338
dtype: torch.float64

  iter             ELBO   delta_ELBO_mean   delta_ELBO_med   notes 
     0        -8189.801             1.000            1.000
   100        -5312.984             0.771            0.771
   200        -4998.764             0.535            0.541
   300        -4852.789             0.409            0.302
   400        -4766.674             0.330            0.063
   500        -4718.430             0.277            0.046
   600        -4683.129             0.239            0.030
   700        -4659.274             0.209            0.024
   800        -4642.069             0.187            0.018
   900        -4629.355             0.168            0.014
  1000        -4619.499             0.153            0.010
  1100        -4610.929             0.140            0.009   MEDIAN ELBO CONVERGED
