In [1]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import torch
import pickle
import datetime
import os

from scipy.cluster.hierarchy import linkage
from itertools import combinations

from copy import deepcopy
from matplotlib import pyplot as plt
from collections import Counter
from io import StringIO
from Bio import Phylo

from VIPR import VIPR

import torch
import torch.nn as nn
torch.set_default_dtype(torch.float32)

# Set Parameters

In [3]:
dataset = "DS1" #["DS1","DS2","DS3","DS4","DS5","DS6","DS7","DS8","DS9","DS10","DS11","COV"]
method = "reinforce" #["reparam","reinforce","VIMCO"]
alpha = 0.01 #[0.03,0.01,0.003,0.001]

rand_seed = 0

np.random.seed(rand_seed)
torch.manual_seed(rand_seed)

time = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
data_file = '../dat/'+dataset+'/'+dataset+'.pickle'
out_file = '../results/'+dataset+'/'+dataset+'_'+var_dist+'_'+method+'_'+str(alpha)+'_'+str(rand_seed)+'_'+time+'.pickle'

In [4]:
# model parameters
var_dist = "LogNormal" #["LogNormal","Exponential","Mixture"]
pop_size = 5.0 # effective popuation size
rate = 1.0 # rate of evolution

# optimization parameters
batch_size = 10
max_iters = 10000 
max_time = 1.0 # HOURS
record_every = 100
test_batch_size = 100

# decay rate parameters
decay = "exp" # how to decay the learning rate
if decay == "linear":
    linear_decay = True
else:
    linear_decay = False
lr_decay_freq = 1
lr_decay_rate = 0.01**(1.0/max_iters)

# Import Data

In [5]:
# NOTE(review): pickle.load can execute arbitrary code — only load trusted data files.
with open(data_file, 'rb') as handle:
    ds = pickle.load(handle)

# Taxon names and their sequences, kept in the mapping's iteration order so
# species[i] always corresponds to genomes[i].
species = list(ds)
genomes = [ds[name] for name in species]

ntaxa = len(species)

# Initialize $\theta$

In [6]:
# Build theta0 from short-run replicate tree samples:
#   theta0[0][j, k], theta0[1][j, k]  (k < j) = sample mean / variance over the
#   retained trees of the coalescent time of tips "j+1" and "k+1".
# The diagonal and upper triangle are left at 0 (before the optional noise below).
n_reps = 10              # number of replicate .trees files to read
thin = 10                # keep only lines whose index is a multiple of `thin`
burnin = 100 + 2*ntaxa   # lines skipped at the start of each file

tree_lines = []
for rep in range(1, n_reps + 1):
    tree_file = "../dat/"+dataset+"/"+dataset+"_fixed_pop_support_short_run_rep_%d.trees" % rep
    with open(tree_file, "r") as file:
        for j, line in enumerate(file):
            if j > burnin and j % thin == 0 and line.startswith("tree STATE"):
                # keep only the Newick string and drop the rate annotations
                newick = line[line.find('('):]
                newick = newick.replace("[&rate=1.0]", "").replace("[&rate=0.001]", "")
                tree_lines.append(newick)

ntrees = len(tree_lines)
if ntrees == 0:
    # Otherwise the mean/variance below silently become NaN.
    raise ValueError("no trees retained from the " + dataset + " replicate files; check burnin/thin")
treedata = "\n".join(tree_lines)

theta0 = torch.zeros((2,ntaxa,ntaxa))
trees = Phylo.parse(StringIO(treedata), "newick")

# dists[i, j, k] (k < j): distance from the MRCA of tips "j+1" and "k+1" down to
# the nearer of the two tips in tree i. The min only matters if the two
# distances differ (they should match for an ultrametric tree).
# NOTE(review): assumes tip labels are "1".."ntaxa" in the same order as `species` — confirm.
dists = np.zeros((ntrees,ntaxa,ntaxa))
for i, tree in enumerate(trees):
    for j in range(ntaxa):
        for k in range(j):
            mrca = tree.common_ancestor(str(j+1), str(k+1))
            dists[i,j,k] = min(tree.distance(mrca, str(j+1)), tree.distance(mrca, str(k+1)))

# Per-pair statistics across trees; unfilled entries are all-zero, so they stay 0.
theta0[0] = torch.from_numpy(dists.mean(axis=0))
theta0[1] = torch.from_numpy(dists.var(axis=0))

# add random noise (only when rand_seed > 0; the scale grows with the seed)
# NOTE(review): noise is also added to the variance matrix and can make entries
# negative — confirm VIPR tolerates that.
if rand_seed > 0:
    theta0 = theta0 + torch.normal(mean=0.0, std=rand_seed*0.1, size=(2,ntaxa,ntaxa))

In [None]:
# Plot the initial parameters theta0. Only pairs k < j are filled in, so the
# diagonal and upper triangle are masked with NaN.
mean0 = theta0[0].detach().numpy().copy()
mean0[np.triu_indices(mean0.shape[0])] = np.nan
fig, ax = plt.subplots()
im = ax.imshow(mean0)
ax.set(title="means of coalescent times", xlabel="taxon index k", ylabel="taxon index j")
fig.colorbar(im, ax=ax)
plt.show()

var0 = theta0[1].detach().numpy().copy()
var0[np.triu_indices(var0.shape[0])] = np.nan
fig, ax = plt.subplots()
im = ax.imshow(np.log(var0))
ax.set(title="log of variance of coalescent times", xlabel="taxon index k", ylabel="taxon index j")
fig.colorbar(im, ax=ax)
plt.show()

# Train Model

In [None]:
# Construct the VIPR model from the sequences and the theta0 (mean, variance)
# matrices. Population size and rate are passed as fixed tensors, with "Fixed"
# variational-distribution and prior settings.
vipr_kwargs = dict(
    var_dist=var_dist,
    phi_pop_size=torch.tensor([pop_size]),
    var_dist_pop_size="Fixed",
    theta_pop_size=None,
    prior_pop_size="Fixed",
    tip_dates=None,
    phi_rate=torch.tensor([rate]),
    var_dist_rate="Fixed",
    theta_rate=None,
    prior_rate="Fixed",
)
optim = VIPR(genomes, theta0[0], theta0[1], **vipr_kwargs)

# Optimizer settings, all taken from the parameter cells above.
learn_kwargs = dict(
    batch_size=batch_size,
    iters=max_iters,
    alpha=alpha,
    method=method,
    record_every=record_every,
    test_batch_size=test_batch_size,
    pop_size=pop_size,
    lr_decay_freq=lr_decay_freq,
    lr_decay_rate=lr_decay_rate,
    linear_decay=linear_decay,
    max_time=max_time,
)
optim.learn(**learn_kwargs)

In [None]:
# Plot the learned variational parameters optim.phi (lower triangle only; the
# diagonal and upper triangle are masked with NaN).
# NOTE(review): the titles assume phi[0] holds means of log coalescent times and
# phi[1] is already a log-variance (it is plotted without a further log) — confirm against VIPR.
phi_mean = optim.phi[0].detach().numpy().copy()
phi_mean[np.triu_indices(phi_mean.shape[0])] = np.nan
fig, ax = plt.subplots()
im = ax.imshow(phi_mean)
ax.set(title="means of log coalescent times", xlabel="taxon index k", ylabel="taxon index j")
fig.colorbar(im, ax=ax)
plt.show()

phi_var = optim.phi[1].detach().numpy().copy()
phi_var[np.triu_indices(phi_var.shape[0])] = np.nan
fig, ax = plt.subplots()
im = ax.imshow(phi_var)
ax.set(title="log of variance of log coalescent times", xlabel="taxon index k", ylabel="taxon index j")
fig.colorbar(im, ax=ax)
plt.show()

# Save Model

In [20]:
# Persist the trained VIPR object to out_file, creating the results directory
# first so a long training run does not fail at the final write.
os.makedirs(os.path.dirname(out_file), exist_ok=True)
with open(out_file, 'wb') as file:
    pickle.dump(optim, file)