Skip to content

Commit

Permalink
slight change in API: give pairwise potentials separately for each edge.
Browse files Browse the repository at this point in the history
  • Loading branch information
amueller committed Jan 7, 2013
1 parent f3a934c commit 010d74f
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 14 deletions.
10 changes: 5 additions & 5 deletions crf.cpp
Expand Up @@ -35,8 +35,8 @@ void validate_unaries_edges(PyArrayObject* unaries, PyArrayObject* edges)
PyObject * mrf(PyArrayObject* unaries, PyArrayObject* edges, PyArrayObject* edge_potentials, string alg, size_t verbose) {
// validate input
validate_unaries_edges(unaries, edges);
if (PyArray_NDIM(edge_potentials) != 2)
throw runtime_error("Edge potentials must be n_classes x n_classes");
if (PyArray_NDIM(edge_potentials) != 3)
throw runtime_error("Edge potentials must be n_edges x n_classes x n_classes");
if (PyArray_TYPE(edge_potentials) != PyArray_FLOAT64)
throw runtime_error("Edge potentials must be double.");

Expand All @@ -46,8 +46,8 @@ PyObject * mrf(PyArrayObject* unaries, PyArrayObject* edges, PyArrayObject* edge
int n_vertices = unaries_dims[0];
int n_states = unaries_dims[1];
int n_edges = edges_dims[0];
if ((edge_potential_dims[0] != n_states) or (edge_potential_dims[1] != n_states))
throw runtime_error("Edge potentials must be n_classes x n_classes");
if ((edge_potential_dims[0] != n_edges) || (edge_potential_dims[1] != n_states) || (edge_potential_dims[2] != n_states))
throw runtime_error("Edge potentials must be n_edges x n_classes x n_classes.");
if (verbose > 0)
cout << "n_vertices: " << n_vertices << " n_states: " << n_states << " n_edges: " << n_edges << endl;

Expand All @@ -73,7 +73,7 @@ PyObject * mrf(PyArrayObject* unaries, PyArrayObject* edges, PyArrayObject* edge
Factor pairwise_factor(VarSet(vars[e0], vars[e1]));
for (size_t i = 0; i < n_states; i++)
for(size_t j = 0; j < n_states; j++){
pairwise_factor.set(i + n_states * j, *((double*)PyArray_GETPTR2(edge_potentials, i, j)));
pairwise_factor.set(i + n_states * j, *((double*)PyArray_GETPTR3(edge_potentials, e, i, j)));
}
factors.push_back(pairwise_factor);
}
Expand Down
24 changes: 15 additions & 9 deletions example.py
Expand Up @@ -17,6 +17,8 @@ def compare_algorithms():
unaries = x_noisy.ravel()
unaries = np.c_[np.exp(-unaries), np.exp(unaries)]
pairwise = np.exp(np.eye(2) * 4.1)
# repeat pairwise for each edge
pairwise = np.repeat(pairwise[np.newaxis, :, :], len(edges), axis=0)
algorithms = ["maxprod", "gibbs", "jt", "trw", "treeep"]
fix, axes = plt.subplots(1, len(algorithms))
for ax, alg in zip(axes, algorithms):
Expand Down Expand Up @@ -56,8 +58,10 @@ def example_binary():
result = potts_mrf(unaries, edges, 1.1, verbose=1)
axes[3].matshow(result.reshape(x.shape))

result_mrf = mrf(unaries, edges, np.exp(np.eye(2) * 1.1), verbose=1,
alg="trw")
# repeat pairwise for each edge
pairwise = np.exp(np.eye(2) * 1.1)
pairwise = np.repeat(pairwise[np.newaxis, :, :], len(edges), axis=0)
result_mrf = mrf(unaries, edges, pairwise, verbose=1, alg="trw")
axes[4].set_title("MRF")
axes[4].matshow(result_mrf.reshape(x.shape))
for ax in axes:
Expand All @@ -82,11 +86,13 @@ def example_multinomial():
vert = np.c_[inds[:-1, :].ravel(), inds[1:, :].ravel()]
edges = np.vstack([horz, vert])
result = potts_mrf(unaries_noisy, edges, 1.1)
binaries = np.eye(3) + np.ones((1, 1))
binaries[-1, 0] = 0
binaries[0, -1] = 0
print(binaries)
result_mrf = mrf(unaries_noisy, edges, np.exp(binaries), alg="jt")
pairwise = np.eye(3) + np.ones((1, 1))
pairwise[-1, 0] = 0
pairwise[0, -1] = 0
print(pairwise)
# repeat pairwise for each edge
pairwise = np.repeat(pairwise[np.newaxis, :, :], len(edges), axis=0)
result_mrf = mrf(unaries_noisy, edges, np.exp(pairwise), alg="jt")
plot, axes = plt.subplots(1, 4)
axes[0].set_title("original")
axes[0].matshow(x)
Expand All @@ -102,6 +108,6 @@ def example_multinomial():
plt.show()

if __name__ == "__main__":
#example_binary()
example_binary()
#example_multinomial()
compare_algorithms()
#compare_algorithms()

0 comments on commit 010d74f

Please sign in to comment.