Skip to content
This repository has been archived by the owner on Feb 6, 2024. It is now read-only.

Split ModelUnrest from ModelNonRev #2

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ pll/Makefile
pll/cmake_install.cmake
pll/libpll.a
pll_test/

build/
/Default-clang
1 change: 1 addition & 0 deletions model/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,5 @@ ratefreeinvar.cpp
modelcodon.cpp
modelmorphology.cpp
modelmixture.cpp
modelunrest.cpp
)
5 changes: 2 additions & 3 deletions model/modelmixture.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1013,11 +1013,10 @@ ModelSubst* createModel(string model_str, ModelsBlock *models_block, StateFreqTy
// ((ModelGTR*)model)->readRates(model_params);
// ((ModelGTR*)model)->init(freq_type);
// } else
if (model_str == "UNREST") {
if (ModelNonRev::validModelName(model_str)) {
freq_type = FREQ_ESTIMATE;
//params.optimize_by_newton = false;
tree->optimize_by_newton = false;
model = new ModelNonRev(tree, model_params, count_rates);
model = ModelNonRev::getModelByName(model_str, tree, model_params, count_rates);
((ModelNonRev*)model)->init(freq_type);
} else if (tree->aln->seq_type == SEQ_BINARY) {
model = new ModelBIN(model_str.c_str(), model_params, freq_type, freq_params, tree, count_rates);
Expand Down
72 changes: 40 additions & 32 deletions model/modelnonrev.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,27 +18,18 @@
* 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. *
***************************************************************************/
#include "modelnonrev.h"
#include "modelunrest.h"
//#include "whtest/eigen.h"

ModelNonRev::ModelNonRev(PhyloTree *tree, string model_params, bool count_rates)
ModelNonRev::ModelNonRev(PhyloTree *tree)
: ModelGTR(tree, false)
{
num_params = getNumRateEntries() - 1;
// model_parameters must be initialized by subclass
int num_rates = getNumRateEntries();
delete [] rates;
rates = new double [num_params+1];
memset(rates, 0, sizeof(double) * (num_params+1));
if (count_rates)
phylo_tree->aln->computeEmpiricalRateNonRev(rates);
else
for (int i = 0; i <= num_params; i++)
rates[i] = 1.0;

if (model_params != "") {
readRates(model_params);
}
rates = new double [num_rates];
memset(rates, 0, sizeof(double) * (num_rates));

name = "UNREST";
full_name = "Unrestricted model (non-reversible)";
rate_matrix = new double[num_states*num_states];
temp_space = new double[num_states*num_states];
if (!tree->rooted) {
Expand All @@ -51,6 +42,20 @@ void ModelNonRev::freeMem() {
ModelGTR::freeMem();
delete [] temp_space;
delete [] rate_matrix;
delete [] model_parameters;
}

/* static */ ModelNonRev* ModelNonRev::getModelByName(string model_name, PhyloTree *tree, string model_params, bool count_rates) {
if (ModelUnrest::validModelName(model_name)) {
return((ModelNonRev*)new ModelUnrest(tree, model_params, count_rates));
} else {
cerr << "Unrecognized model name " << model_name << endl;
return((ModelNonRev*)NULL);
}
}

/* static */ bool ModelNonRev::validModelName(string model_name) {
return ModelUnrest::validModelName(model_name);
}

void ModelNonRev::getQMatrix(double *rate_mat) {
Expand Down Expand Up @@ -173,9 +178,13 @@ void ModelNonRev::decomposeRateMatrix() {


void ModelNonRev::writeInfo(ostream &out) {
int i, j, k;
out << "Model parameters: " << model_parameters[0];
for (i=0; i < num_params; i++) out << "," << model_parameters[i];
out << endl;

if (num_states != 4) return;
out << "Rate parameters:" << endl;
int i, j, k;
for (i = 0, k = 0; i < num_states; i++) {
switch (i) {
case 0:
Expand Down Expand Up @@ -273,18 +282,12 @@ double ModelNonRev::computeTrans(double time, int state1, int state2) {
}

int ModelNonRev::getNDim() {
int ndim = num_params;
return ndim;
return(num_params);
}

void ModelNonRev::setBounds(double *lower_bound, double *upper_bound, bool *bound_check) {
int i, ndim = getNDim();

for (i = 1; i <= ndim; i++) {
lower_bound[i] = 0.01;
upper_bound[i] = 100.0;
bound_check[i] = false;
}
// I don't know the proper C++ way to handle this: got error if I didn't define something here.
cerr << "setBounds should only be called on subclass of ModelNonRev\n";
}

void ModelNonRev::setVariables(double *variables) {
Expand All @@ -297,15 +300,19 @@ bool ModelNonRev::getVariables(double *variables) {
int nrate = getNDim();
int i;
bool changed = false;
if (nrate > 0) {
for (i = 0; i < nrate; i++)
changed |= (rates[i] != variables[i+1]);
memcpy(rates, variables+1, nrate * sizeof(double));
for (i = 0; i < nrate && !changed; i++) changed = (model_parameters[i] != variables[i+1]);
if (changed) {
memcpy(model_parameters, variables+1, nrate * sizeof(double));
this->setRates();
}

return changed;
}

void ModelNonRev::setRates() {
// I don't know the proper C++ way to handle this: got error if I didn't define something here.
cerr << "setRates should only be called on subclass of ModelNonRev\n";
}

double ModelNonRev::targetFunk(double x[]) {
bool changed = getVariables(x);
// if (state_freq[num_states-1] < 1e-4) return 1.0e+12;
Expand Down Expand Up @@ -358,14 +365,15 @@ double ModelNonRev::optimizeParameters(double gradient_epsilon) {

void ModelNonRev::saveCheckpoint() {
checkpoint->startStruct("ModelNonRev");
CKP_ARRAY_SAVE(num_params+1, rates);
CKP_ARRAY_SAVE(num_params, model_parameters);
checkpoint->endStruct();
ModelSubst::saveCheckpoint();
}

void ModelNonRev::restoreCheckpoint() {
ModelSubst::restoreCheckpoint();
checkpoint->startStruct("ModelNonRev");
CKP_ARRAY_RESTORE(num_params+1, rates);
CKP_ARRAY_RESTORE(num_params, model_parameters);
checkpoint->endStruct();
this->setRates();
}
32 changes: 31 additions & 1 deletion model/modelnonrev.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,17 @@ The general non-reversible model
class ModelNonRev : public ModelGTR
{
public:
ModelNonRev(PhyloTree *tree, string model_params, bool count_rates = true);
ModelNonRev(PhyloTree *tree);

/**
* Return a model of type given by model_name. (Will be some subclass of ModelNonRev.)
*/
static ModelNonRev* getModelByName(string model_name, PhyloTree *tree, string model_params, bool count_rates);

/**
* true if model_name is the name of some known non-reversible model
*/
static bool validModelName(string model_name);

/**
save object into the checkpoint
Expand Down Expand Up @@ -108,6 +118,13 @@ class ModelNonRev : public ModelGTR

protected:

/*
* MDW:
* I strongly object to the function naming here - it completely threw me.
* a 'get' method should copy data from the object, and a 'set' method should
* write data into the object. These are named completely the wrong way around!
*/

/**
this function is served for the multi-dimension optimization. It should pack the model parameters
into a vector that is index from 1 (NOTE: not from 0)
Expand All @@ -123,8 +140,21 @@ class ModelNonRev : public ModelGTR
*/
virtual bool getVariables(double *variables);

/**
* Called from getVariables to update the rate matrix for the new
* model parameters.
*/
virtual void setRates();

virtual void freeMem();

/**
Model parameters - cached so we know when they change, and thus when
recalculations are needed.

*/
double *model_parameters;

/**
unrestricted Q matrix. Note that Q is normalized to 1 and has row sums of 0.
no state frequencies are involved here since Q is a general matrix.
Expand Down
56 changes: 56 additions & 0 deletions model/modelunrest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* modelunrest.cpp
*
* Created on: 24/05/2016
* Author: mdw2
*/

#include "modelunrest.h"

ModelUnrest::ModelUnrest(PhyloTree *tree, string model_params, bool count_rates)
: ModelNonRev(tree)
{
num_params = getNumRateEntries() - 1;
model_parameters = new double [num_params];
for (int i=0; i<= num_params; i++) model_parameters[i] = 1;
this->setRates();
/*
* I'm not sure how to correctly handle count_rates, so for now I'm just
* avoiding the problem. Actual IQTree programmers can fix this.
* Whatever happens should leave model_parameters[] and rates[]
* consistent with each other.
*/
if (count_rates)
cerr << "WARNING: count_rates=TRUE not implemented in ModelUnrest constructor -- ignored" << endl;
/* phylo_tree->aln->computeEmpiricalRateNonRev(rates); */
if (model_params != "") {
cerr << "WARNING: Supplying model params to constructor not yet properly implemented -- ignored" << endl;
// TODO: parse model_params into model_parameters, then call setRates().
}
name = "UNREST";
full_name = "Unrestricted model (non-reversible)";
}

/* static */ bool ModelUnrest::validModelName(string model_name) {
return (model_name == "UNREST");
}

void ModelUnrest::setBounds(double *lower_bound, double *upper_bound, bool *bound_check) {
int i, ndim = getNDim();

for (i = 1; i <= ndim; i++) {
lower_bound[i] = 0.01;
upper_bound[i] = 100.0;
bound_check[i] = false;
}
}

/*
* Set rates from model_parameters
*/
void ModelUnrest::setRates() {
// For UNREST, parameters are simply the off-diagonal rate matrix entries
// (except [4,3] = rates[11], which is constrained to be 1)
memcpy(rates, model_parameters, num_params*sizeof(double));
rates[num_params]=1;
}
22 changes: 22 additions & 0 deletions model/modelunrest.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/*
* modelunrest.h
*
* Created on: 24/05/2016
* Author: Michael Woodhams
*/

#ifndef MODELUNREST_H_
#define MODELUNREST_H_

#include "modelnonrev.h"

class ModelUnrest: public ModelNonRev {
public:
ModelUnrest(PhyloTree *tree, string model_params, bool count_rates);
static bool validModelName(string model_name);
void setBounds(double *lower_bound, double *upper_bound, bool *bound_check);
private:
void setRates();
};

#endif /* MODELUNREST_H_ */
8 changes: 7 additions & 1 deletion ngs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
*/

#include "ngs.h"
#include "model/modelnonrev.h"
//#include "modeltest_wrapper.h"

/****************************************************************************
Expand Down Expand Up @@ -832,7 +833,12 @@ void reportNGSAnalysis(const char *file_name, Params &params, NGSAlignment &aln,

tree.getModel()->getRateMatrix(rate_param);

if (tree.getModel()->name == "UNREST") {
/*
* This isn't a good way of doing it. Rather somewhere high up in the model heirarchy
* define a bool isSymmetric() method, which is true for time reversible models and
* not true for nonTR models. ! ModelGTR has 'half_matrix' member, which should do the job.
*/
if (ModelNonRev::validModelName(tree.getModel()->name)) {
for (i = 0, k=0; i < aln.num_states; i++)
for (j = 0; j < aln.num_states; j++)
if (i != j)
Expand Down