forked from awillats/hmm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.cpp
124 lines (102 loc) · 3.35 KB
/
main.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
/// \brief basic usage of Hidden Markov Model code
///
/// \file
/// "hmm" is a simple set of hidden Markov model (HMM) code.
/// It decodes latent state-switches
/// from a categorical signal (usually a binary spike-train)
/// \author Adam Willats
/// \date 2/21/19
// TODO(awillats): allow the main method to accept the state count for better testing
// TODO(awillats): output metrics of success, namely % accuracy and possibly compute time
#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

#include <hmm_h/hmm_vec.hpp>
#include <hmm_h/printFuns.hpp>
#include <hmm_h/shuttleFuns.hpp>
//#include "legacy/dataFuns.h"
/// Parse a strictly positive base-10 integer from a command-line argument.
/// Unlike a bare std::stoi, this rejects trailing garbage ("12abc"), zero,
/// negative, and out-of-range values instead of throwing or silently truncating.
/// @param text  argument text, e.g. argv[i]; must be non-null
/// @param out   receives the parsed value on success; left untouched otherwise
/// @return true on success, false if text is not a positive integer
static bool parsePositiveInt(const char *text, int &out)
{
    try {
        std::size_t consumed = 0;
        const int value = std::stoi(text, &consumed);
        // Reject partially-numeric input and non-positive values.
        if (text[consumed] != '\0' || value <= 0) {
            return false;
        }
        out = value;
        return true;
    } catch (const std::exception &) {
        // std::stoi throws std::invalid_argument / std::out_of_range.
        return false;
    }
}

/**
* This is the main() for testing basic HMM functionality.
* Builds an HMM, generates a sequence, estimates states, then prints them to cout.
* @param argv optional args:{nStates, nSamples}
*   - argv[1]: number of states; 2 (default) and 3 select built-in presets,
*              any other positive value falls back to the 2-state preset
*   - argv[2]: number of samples to generate (default 175)
* @return 0 if successful, 1 if a command-line argument is not a positive integer
*/
int main(int argc, const char *argv[]) {
    int nt = 175;        // number of samples to generate (was 1e3)
    const int nrep = 0;  // extra generate+decode repetitions; was 1e3
    std::cout << "argc" << argc << '\n';

    // Default 2-state / 2-emission HMM parameters:
    // transition matrix, emission ("firing") matrix, and initial-state
    // probabilities. Every row, and pis, sums to 1.
    std::vector<std::vector<double>> trs = {{0.9, 0.1}, {.1, .9}};
    std::vector<std::vector<double>> frs = {{0.9, 0.1}, {.2, .8}};
    std::vector<double> pis = {.1, .9};
    int nStates = 2;
    int nEmission = 2;
    int printMode = 1;

    // Override HMM params if we have input from the console.
    // For simpler debugging nStates == nEmission here, but that does not have to
    // be the case in general. For easy in-console visualization, the parameters are
    // chosen so that the most common output of each state matches that state's index.
    if (argc > 1)
    {
        std::cout << "argv[1]: " << argv[1] << '\n';
        int nInput = 0;
        if (!parsePositiveInt(argv[1], nInput))
        {
            std::cerr << "error: state count must be a positive integer, got '"
                      << argv[1] << "'\n";
            return 1;
        }
        switch (nInput)
        {
        case 2:
            break;  // keep the defaults above
        case 3:
            trs = {
                {0.8, 0.1, 0.1},
                {.1, .8, .1},
                {.1, .1, .8}};
            frs = {
                {0.9, 0.05, 0.05},
                {.15, 0.7, .15},
                {.1, 0.1, .8}};
            // Was {.1, .9, .1}, which sums to 1.1 and is not a valid distribution.
            pis = {.1, .8, .1};
            nStates = 3;
            nEmission = 3;
            printMode = 3;
            break;
        default:
            std::cerr << "warning: no preset for " << nInput
                      << " states; using the 2-state defaults\n";
            break;
        }
    }
    if (argc > 2 && !parsePositiveInt(argv[2], nt))
    {
        std::cerr << "error: sample count must be a positive integer, got '"
                  << argv[2] << "'\n";
        return 1;
    }
    std::cout << "nStates: " << nStates << '\n';

    HMMv myHMM = HMMv(nStates, nEmission, trs, frs, pis);
    // myHMM.printMyParams();  // uncomment to dump the model parameters
    myHMM.genSeq(nt);
    myHMM.printSeqs(printMode);

    // NOTE(review): viterbi() returns a raw int*. If it allocates that buffer,
    // neither this result nor the ones discarded in the loop below are ever
    // freed — confirm ownership in hmm_vec.hpp.
    int *vguess = viterbi(myHMM, myHMM.spikes, nt);
    for (int i = 0; i < nrep; i++)
    {
        myHMM.genSeq(nt);
        viterbi(myHMM, myHMM.spikes, nt);
    }

    // Heap-backed buffers instead of variable-length arrays: VLAs are not standard
    // C++, and a large nt from the command line could overflow the stack.
    std::vector<int> spkAry(nt);
    std::vector<int> stateAry(nt);
    std::vector<int> vguess2(nt);
    myHMM.exportSeqsGuess(nt, spkAry.data(), stateAry.data(), vguess2.data());

    printVecAsBlock(&vguess[0], myHMM.ntPrint, printMode);
    std::cout << " < guessed states \n";
    return 0;
}