In [None]:
# Enable IPython's autoreload extension in mode 2 so edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import torch
import argparse
import pickle
import datetime
import os
import re

from scipy.cluster.hierarchy import linkage
from itertools import combinations

from copy import deepcopy
from matplotlib import pyplot as plt
from collections import Counter
from io import StringIO
from Bio import Phylo
from ete3 import Tree

from VIPR import VIPR

import torch
import torch.nn as nn
from torch.distributions.log_normal import LogNormal
from torch.distributions.gamma import Gamma
from torch.distributions.exponential import Exponential
from torch.distributions.categorical import Categorical
# Use float32 as torch's default floating-point dtype for tensors created below.
torch.set_default_dtype(torch.float32)

In [None]:
# Run selection. The bracketed list after each setting gives the values it may take.
dataset = "DS1" #["DS1","DS2","DS3","DS4","DS5","DS6","DS7","DS8","DS9","DS10","DS11","DS14"]
method = "reinforce" #["reparam","reinforce","VIMCO"]
alpha = 0.01 #[0.03,0.01,0.003,0.001]
rand_seed = 0

# Seed both random number generators with the same value.
torch.manual_seed(rand_seed)
np.random.seed(rand_seed)

# Timestamp that makes the output file name unique per run.
time = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')

# Input pickle and output pickle paths, both keyed by the dataset name.
data_file = f'../dat/{dataset}/{dataset}.pickle'
out_file = f'../results/{dataset}/{dataset}_{method}_{alpha}_{rand_seed}_{time}.pickle'

In [None]:
# keep fixed values

# Optimisation schedule.
decay = "exp"
linear_decay = (decay == "linear")  # True only for the "linear" schedule
max_iters = 1000#200000
anneal_freq = 1
anneal_rate = 0.01**(1.0/max_iters)

# Batch sizes and how often progress is recorded.
batch_size = 10
test_batch_size = 100
record_every = 100

# Model constant and wall-clock budget.
pop_size = 5.0
max_time = 0.1 # HOURS

In [None]:
# Read the pickled dataset for the selected run. It is iterated for its keys
# and indexed by key, i.e. it is used as a mapping.
# NOTE: pickle.load executes arbitrary code; only load files you trust.
with open(data_file, 'rb') as handle:
    ds = pickle.load(handle)

# Keys become the species labels; the matching values are kept in the same order.
species = list(ds)
genomes = [ds[name] for name in species]

ntaxa = len(species)

# Plot pair-wise hamming distances  

In [None]:
# For every benchmark dataset, show the matrix of pairwise Hamming distances
# between sequences and print its mean and standard deviation.
datasets = ["DS1","DS2","DS3","DS4","DS5","DS6","DS7","DS8","DS9","DS10","DS11","DS14"]
chars = ["C","A","G","T"]
valid_chars = set(chars)

# Loop-local names (`ds_name`, `ds_file`) are used on purpose: the selected
# run's `ds` and `data_file` must survive this cell so that re-running the
# load cell afterwards still loads the configured dataset.
for ds_name in datasets:

    print(ds_name)

    # load in genomes
    ds_file = '../dat/'+ds_name+'/'+ds_name+'.pickle'
    with open(ds_file, 'rb') as f:
        data0 = pickle.load(f)

    species0 = list(data0)
    genomes0 = [data0[key] for key in species0]
    ntaxa0 = len(species0)

    # plot hamming distances
    # Only the strict lower triangle is filled; everything else stays NaN so
    # it is left blank in the image and ignored by nanmean/nanstd.
    hdists = np.full((ntaxa0,ntaxa0), np.nan)

    for j in range(ntaxa0):
        for k in range(j):
            # Compare only sites where both sequences carry one of `chars`
            # (checked case-insensitively); the mismatch test itself is on
            # the raw characters.
            mismatches = [(x != y) for x,y in zip(genomes0[j],genomes0[k])
                          if (x.upper() in valid_chars) and (y.upper() in valid_chars)]
            # No comparable site -> leave NaN (np.mean([]) would also give NaN, with a warning).
            if mismatches:
                hdists[j,k] = np.mean(mismatches)

    plt.imshow(hdists)
    plt.title(ds_name)
    plt.colorbar()
    plt.show()
    print(np.nanmean(hdists),np.nanstd(hdists))

# Type of Trees

In [None]:
# Gather a thinned sample of tree lines from the ten replicate .trees files.
# A line is kept when its 0-based index in the file is a multiple of 10 and it
# starts with "tree STATE"; everything before the first '(' and every
# "[&rate=1.0]" annotation is dropped.
newick_lines = []

for i in range(10):
    tree_file = "../dat/%s/%s_fixed_pop_support_short_run_rep_%d.trees"%(dataset,dataset,i+1)
    with open(tree_file, "r") as file:
        for j,line in enumerate(file):
            if j%10 != 0 or not line.startswith("tree STATE"):
                continue
            newick = line[line.find('('):].replace("[&rate=1.0]","")
            newick_lines.append(newick + "\n")

treedata = "".join(newick_lines)
ntrees = len(newick_lines)

# Lazy parser over the collected text; it can be iterated once.
trees = Phylo.parse(StringIO(treedata), "newick")

def convert_tree_to_newick_without_lengths(tree):
    """Serialise `tree` to a Newick string after clearing its branch lengths.

    The tree is modified in place: every clade returned by
    `tree.find_clades()` has `branch_length` set to None before writing.
    Returns the written text with surrounding whitespace stripped.
    """
    for node in tree.find_clades():
        node.branch_length = None

    with StringIO() as buffer:
        Phylo.write(tree, buffer, "newick")
        return buffer.getvalue().strip()

def remove_branch_lengths(newick_str):
    """
    Remove branch lengths from a Newick string.

    Every colon followed by a number is deleted. The number may have an
    optional sign, an integer and/or fractional part (so "1", "1.5", "1."
    and ".5" are all recognised) and an optional exponent.

    Parameters:
        newick_str: Newick text, possibly containing ":<length>" annotations.

    Returns:
        The same text with all ":<length>" annotations removed.
    """
    # Mantissa alternatives: digits with an optional fractional part ("1",
    # "1.", "1.5"), or a bare fractional part (".5"). Matching "1." as a whole
    # avoids leaving a stray "." behind.
    pattern = r":[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?"
    return re.sub(pattern, "", newick_str)

# Tally how often each topology string (Newick text with the branch lengths
# stripped) occurs among the parsed trees.
structure_counter = Counter()

for tree in trees:
    topology = remove_branch_lengths(convert_tree_to_newick_without_lengths(tree))
    structure_counter[topology] += 1

# Extract labels (tree structure strings) and counts for plotting
tree_labels = list(structure_counter)
tree_counts = [structure_counter[label] for label in tree_labels]

# Counts in descending order, one point per distinct topology.
plt.figure(figsize=(10, 6))
plt.plot(sorted(tree_counts)[::-1])
plt.xticks(rotation=90)  # Rotate x-axis labels for better visibility
plt.xlabel('Tree Structure Index')
plt.ylabel('Count')
plt.title('Counts of Phylogenetic Tree Structures (Without Branch Lengths)')
plt.tight_layout()  # Adjust layout to fit labels
plt.show()

In [None]:
# Load a thinned sample of trees from the ten replicate .trees files as ete3
# Tree objects. A line is kept when its 0-based index in the file is a
# multiple of 100 and it starts with "tree STATE".
trees = []

for i in range(10):
    tree_file = "../dat/"+dataset+"/"+dataset+"_fixed_pop_support_short_run_rep_%d.trees"%(i+1)
    with open(tree_file, "r") as file:
        for j,line in enumerate(file):
            if j%100 == 0 and line.startswith("tree STATE"):
                line = line[line.find('('):]
                line = line.replace("[&rate=1.0]","")
                trees.append(Tree(line))

# Pairwise Robinson-Foulds distances between the sampled trees.
ntrees = len(trees)
rf_dists = np.zeros((ntrees,ntrees))

# The distance is symmetric and zero between a tree and itself, so each
# unordered pair is computed once and mirrored, halving the number of
# robinson_foulds calls.
for i in range(ntrees):
    for j in range(i):
        rf = trees[i].robinson_foulds(trees[j])[0]
        rf_dists[i,j] = rf
        rf_dists[j,i] = rf

plt.imshow(rf_dists)
plt.title("pairwise Robinson-Foulds distances")
plt.colorbar()
plt.show()

# import $\theta$

In [None]:
# Build theta0 from the sampled trees: for every pair of taxa (j, k) with
# k < j, theta0[0,j,k] / theta0[1,j,k] hold the mean / variance over trees of
# the shorter of the two distances from the pair's common ancestor to each taxon.
burnin = 100 + 2*ntaxa
tree_lines = []

for i in range(10):
    tree_file = "../dat/"+dataset+"/"+dataset+"_fixed_pop_support_short_run_rep_%d.trees"%(i+1)
    #tree_file = "../dat/DS1_tests/DS1_tips.trees"
    with open(tree_file, "r") as file:
        for j,line in enumerate(file):
            # `j` is the line index within the file, so `burnin` is a line
            # threshold, not a count of trees.
            if j > burnin and j%10 == 0 and line.startswith("tree STATE"):
                line = line[line.find('('):]
                line = line.replace("[&rate=1.0]","")
                line = line.replace("[&rate=0.001]","")
                tree_lines.append(line + "\n")

# Join once instead of growing a string inside the loop.
treedata = "".join(tree_lines)
ntrees = len(tree_lines)

theta0 = torch.zeros((2,ntaxa,ntaxa))
trees = Phylo.parse(StringIO(treedata), "newick")

dists = np.zeros((ntrees,ntaxa,ntaxa))

print(ntrees)

# Report progress roughly every 1% of trees. Clamp to at least 1 so that
# fewer than 100 trees no longer triggers a modulo-by-zero error.
progress_every = max(1, ntrees // 100)

for i,tree in enumerate(trees):

    if i % progress_every == 0:
        print(i)

    # Taxa are looked up by the labels "1".."ntaxa".
    for j in range(ntaxa):
        for k in range(j):
            mrca = tree.common_ancestor(str(j+1),str(k+1))
            dists[i,j,k] = min(tree.distance(mrca,str(j+1)),tree.distance(mrca,str(k+1)))

for j in range(ntaxa):
    for k in range(j):
        theta0[0,j,k] = np.mean(dists[:,j,k])
        theta0[1,j,k] = np.var(dists[:,j,k])

# add random noise
if rand_seed > 0:
    theta0 = theta0 + torch.normal(mean=0.0,std=rand_seed*0.1,size=(2,ntaxa,ntaxa))

In [None]:
from scipy.stats import gamma
from scipy.stats import lognorm
from scipy.stats import expon

# For each pair of consecutively indexed taxa (j, j-1): print the fraction of
# sites that differ / agree (counting only sites where both characters are
# nucleotides), then overlay moment-matched gamma, log-normal and exponential
# densities on the histogram of that pair's sampled distances.
nucleotides = ("c","a","g","t")

# Start at j = 1. With j = 0 the partner index would be -1, which wraps around
# to the last taxon and reads an entry of `dists` that is never filled (all
# zeros), giving 0/0 in the moment formulas below.
for j in range(1, ntaxa):
    k = j - 1
    print(j,k)

    # Nucleotide check is case-insensitive, matching the Hamming-distance cell.
    site_pairs = list(zip(genomes[j],genomes[k]))
    both_valid = [(x.lower() in nucleotides) and (y.lower() in nucleotides) for x,y in site_pairs]
    print(np.mean([(x != y) and ok for (x,y),ok in zip(site_pairs,both_valid)]))
    print(np.mean([(x == y) and ok for (x,y),ok in zip(site_pairs,both_valid)]))

    vals = dists[:,j,k]
    mean = np.mean(vals)
    std = np.std(vals)

    # Log-normal parameters matching the sample mean and standard deviation.
    sigma = np.sqrt(np.log(1 + (std / mean) ** 2))
    mu = np.log(mean) - 0.5 * sigma**2

    # Gamma shape/scale matching the same two moments. These are named so they
    # do not overwrite the learning-rate variable `alpha`.
    gamma_shape = (mean / std) ** 2
    gamma_scale = std ** 2 / mean

    xs = np.linspace(min(vals),max(vals),100)
    y1s = gamma.pdf(xs, a=gamma_shape, scale=gamma_scale)
    y2s = lognorm.pdf(xs, s=sigma, scale=np.exp(mu))
    y3s = expon.pdf(xs, scale=mean)

    # Explicit labels so the legend does not depend on artist ordering.
    plt.hist(vals,bins=100,density=True,label="hist")
    plt.plot(xs,y1s,label="gamma")
    plt.plot(xs,y2s,label="lognorm")
    plt.plot(xs,y3s,label="exp")
    plt.legend()
    plt.show()

In [None]:
# plot theta0
# Entries on and above the diagonal are blanked with NaN so only the strict
# lower triangle is drawn.
upper = np.triu_indices(theta0.shape[-1])

m = theta0[0].clone()
m[upper] = np.nan
plt.imshow(m)#,vmin = -10, vmax = -4)
plt.title("means of coalscent times")
plt.colorbar()
plt.show()

var = theta0[1].clone()
var[upper] = np.nan
plt.imshow(np.log(var))#,vmin=-4,vmax=1)
plt.title("log of variance of coalscent times")
plt.colorbar()
plt.show()

# Mean against variance, one point per matrix entry.
plt.scatter(m.flatten(),var.flatten())

In [None]:
# Model with var_dist="LogNormal", initialised from theta0; population size
# and rate are both "Fixed".
optim_nrm = VIPR(genomes,theta0[0],theta0[1],var_dist="LogNormal",
                 phi_pop_size=torch.tensor([5.0]),var_dist_pop_size="Fixed",
                 theta_pop_size=None,prior_pop_size="Fixed",
                 tip_dates=None,#torch.tensor([float(i) for i in range(11)] + [0.0 for _ in range(16)]),
                 phi_rate=torch.tensor([1.0]),var_dist_rate="Fixed",
                 theta_rate=None,prior_rate="Fixed")

# keep fixed values
decay = "exp"
batch_size = 10
max_iters = 2000#200000
record_every = 100
test_batch_size = 100
linear_decay = (decay == "linear")
anneal_freq = 1
anneal_rate = 0.1**(1.0/max_iters)
pop_size = 5.0
max_time = 2.0 # HOURS

# Overrides the value chosen in the run-configuration cell.
alpha = 0.03

# All settings handed to learn(), collected in one place.
learn_kwargs = dict(batch_size=batch_size,
                    iters=max_iters,
                    alpha=alpha,
                    method="reinforce",
                    record_every=record_every,
                    test_batch_size=test_batch_size,
                    pop_size=pop_size,
                    anneal_freq=anneal_freq,
                    anneal_rate=anneal_rate,
                    linear_decay=linear_decay,
                    max_time=max_time)

optim_nrm.learn(**learn_kwargs)

In [None]:
# Show the two fitted parameter matrices of the LogNormal model. Entries on
# and above the diagonal are blanked with NaN.
m = optim_nrm.phi[0].detach().numpy().copy()
m[np.triu_indices(m.shape[0])] = np.nan
plt.imshow(m)#,vmin = -10, vmax = -4)
plt.title("means of log coalscent times")
plt.colorbar()
plt.show()

# NOTE(review): the mixture warm-start cell reads phi[1] as `log_sigs`, so the
# "log of variance" title may be imprecise - confirm against VIPR.
var = optim_nrm.phi[1].detach().numpy().copy()
var[np.triu_indices(var.shape[0])] = np.nan
plt.imshow(var)#,vmin=-4,vmax=1)
plt.title("log of variance of log coalscent times")
plt.colorbar()
plt.show()

In [None]:
# Model with var_dist="Exponential", initialised from theta0; population size
# and rate are both "Fixed".
optim_exp = VIPR(genomes,theta0[0],theta0[1],var_dist="Exponential",
                 phi_pop_size=torch.tensor([5.0]),var_dist_pop_size="Fixed",
                 theta_pop_size=None,prior_pop_size="Fixed",
                 tip_dates=None,#torch.tensor([float(i) for i in range(11)] + [0.0 for _ in range(16)]),
                 phi_rate=torch.tensor([1.0]),var_dist_rate="Fixed",
                 theta_rate=None,prior_rate="Fixed")

# keep fixed values
decay = "exp"
batch_size = 10
max_iters = 2000#200000
record_every = 100
test_batch_size = 100
linear_decay = (decay == "linear")
anneal_freq = 1
anneal_rate = 0.1**(1.0/max_iters)
pop_size = 5.0
max_time = 2.0 # HOURS

# Overrides the value chosen in the run-configuration cell.
alpha = 0.03

# All settings handed to learn(), collected in one place.
learn_kwargs = dict(batch_size=batch_size,
                    iters=max_iters,
                    alpha=alpha,
                    method="reinforce",
                    record_every=record_every,
                    test_batch_size=test_batch_size,
                    pop_size=pop_size,
                    anneal_freq=anneal_freq,
                    anneal_rate=anneal_rate,
                    linear_decay=linear_decay,
                    max_time=max_time)

optim_exp.learn(**learn_kwargs)

In [None]:
# Show the negated first parameter matrix of the Exponential model. Entries
# on and above the diagonal are blanked with NaN.
log_lamb = optim_exp.phi[0].detach().numpy().copy()
log_lamb[np.triu_indices(log_lamb.shape[0])] = np.nan
plt.imshow(np.negative(log_lamb))#,vmin=-4,vmax=1)
plt.title("log-scale parameter of coalscent times")
plt.colorbar()
plt.show()

In [None]:
# Model with var_dist="Mixture"; population size and rate are both "Fixed".
optim_mix = VIPR(genomes,theta0[0],theta0[1],var_dist="Mixture",
                 phi_pop_size=torch.tensor([5.0]),var_dist_pop_size="Fixed",
                 theta_pop_size=None,prior_pop_size="Fixed",
                 tip_dates=None,#torch.tensor([float(i) for i in range(11)] + [0.0 for _ in range(16)]),
                 phi_rate=torch.tensor([1.0]),var_dist_rate="Fixed",
                 theta_rate=None,prior_rate="Fixed")

# Replace the mixture parameters with a stack of four ntaxa x ntaxa layers:
# a constant 10.0 logit layer, the two fitted LogNormal parameter matrices and
# the fitted Exponential parameter matrix.
logit_pis = torch.full((ntaxa,ntaxa), 10.0)
mus, log_sigs = optim_nrm.phi[0], optim_nrm.phi[1]
log_lambs = optim_exp.phi[0]

stacked_phi = torch.stack([logit_pis,mus,log_sigs,log_lambs])
optim_mix.phi = nn.Parameter(stacked_phi)

In [None]:
# keep fixed values
decay = "exp"
batch_size = 10
max_iters = 10000#200000
record_every = 100
test_batch_size = 100
linear_decay = (decay == "linear")
anneal_freq = 1
anneal_rate = 0.1**(1.0/max_iters)
pop_size = 5.0
max_time = 2.0 # HOURS

# Overrides the value chosen in the run-configuration cell. This is also the
# value that ends up in the output file name built in the final cell.
alpha = 0.03

# All settings handed to learn(), collected in one place.
learn_kwargs = dict(batch_size=batch_size,
                    iters=max_iters,
                    alpha=alpha,
                    method="reinforce",
                    record_every=record_every,
                    test_batch_size=test_batch_size,
                    pop_size=pop_size,
                    anneal_freq=anneal_freq,
                    anneal_rate=anneal_rate,
                    linear_decay=linear_decay,
                    max_time=max_time)

optim_mix.learn(**learn_kwargs)

In [None]:
# plot theta
# Each of the four layers of the mixture parameters is copied to numpy, has
# its diagonal and upper triangle blanked with NaN, and is drawn as an image.
pi = optim_mix.phi[0].detach().numpy().copy()
pi[np.triu_indices(pi.shape[0])] = np.nan
plt.imshow(pi,cmap='seismic')#,vmin = -10, vmax = 10)
plt.title("logit of mixture component for LogNormal")
plt.colorbar()
plt.show()

m = optim_mix.phi[1].detach().numpy().copy()
m[np.triu_indices(m.shape[0])] = np.nan
plt.imshow(m)#,vmin = -10, vmax = -4)
plt.title("means of log coalscent times")
plt.colorbar()
plt.show()

var = optim_mix.phi[2].detach().numpy().copy()
var[np.triu_indices(var.shape[0])] = np.nan
plt.imshow(var)#,vmin=-4,vmax=1)
plt.title("log of variance of log coalscent times")
plt.colorbar()
plt.show()

# NOTE(review): layer 3 was initialised from a variable named `log_lambs`, so
# this may be a log-rate rather than a rate - confirm against VIPR.
lamb = optim_mix.phi[3].detach().numpy().copy()
lamb[np.triu_indices(lamb.shape[0])] = np.nan
plt.imshow(lamb)#,vmin=-4,vmax=1)
plt.title("rate parameter of log coalscent times")
plt.colorbar()
plt.show()

# Layer 1 against layer 2, one point per matrix entry.
plt.scatter(m.flatten(),var.flatten())

In [None]:
# Pickle the trained mixture model. The file name is rebuilt here so that it
# reflects the current values of `alpha`, `method` and `rand_seed`.
out_file = '../results/'+dataset+'/'+dataset+'_'+method+'_'+str(alpha)+'_'+str(rand_seed)+'_'+time+'.pickle'

# Create the results directory if needed, so a long training run is not lost
# to a FileNotFoundError at the very last step.
os.makedirs(os.path.dirname(out_file), exist_ok=True)

with open(out_file, 'wb') as file:
    pickle.dump(optim_mix, file)