Skip to content

Commit

Permalink
fix issue #9
Browse files Browse the repository at this point in the history
  • Loading branch information
scottgigante committed Jun 12, 2018
1 parent 62acd69 commit f20b014
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 10 deletions.
23 changes: 13 additions & 10 deletions graphtools/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -980,8 +980,8 @@ def build_kernel(self):
if i == j:
# downweight within-batch affinities by beta
Kij = Kij * self.beta
K = set_submatrix(K, self.sample_idx == i,
self.sample_idx == j, Kij)
K = set_submatrix(K, self.sample_idx == self.samples[i],
self.sample_idx == self.samples[j], Kij)
log_complete(
"kernel from sample {} to {}".format(self.samples[i],
self.samples[j]))
Expand All @@ -996,19 +996,21 @@ def symmetrize_kernel(self, K):
# experimental samples to be corrected simultaneously
log_debug("Using gamma symmetrization. "
"Gamma:\n{}".format(self.gamma))
for i in range(len(self.samples)):
for j in range(i, len(self.samples)):
Kij = K[self.sample_idx == i, :][:, self.sample_idx == j]
Kji = K[self.sample_idx == j, :][:, self.sample_idx == i]
for i, sample_i in enumerate(self.samples):
for j, sample_j in enumerate(self.samples):
Kij = K[self.sample_idx == sample_i, :][
:, self.sample_idx == sample_j]
Kji = K[self.sample_idx == sample_j, :][
:, self.sample_idx == sample_i]
Kij_symm = self.gamma[i, j] * \
elementwise_minimum(Kij, Kji.T) + \
(1 - self.gamma[i, j]) * \
elementwise_maximum(Kij, Kji.T)
K = set_submatrix(K, self.sample_idx == i,
self.sample_idx == j, Kij_symm)
K = set_submatrix(K, self.sample_idx == sample_i,
self.sample_idx == sample_j, Kij_symm)
if not i == j:
K = set_submatrix(K, self.sample_idx == j,
self.sample_idx == i, Kij_symm.T)
K = set_submatrix(K, self.sample_idx == sample_j,
self.sample_idx == sample_i, Kij_symm.T)
else:
K = super().symmetrize_kernel(K)
return K
Expand Down Expand Up @@ -1043,6 +1045,7 @@ def build_kernel_to_data(self, Y, gamma=None):
transitions : array-like, [n_samples_y, self.data.shape[0]]
Transition matrix from `Y` to `self.data`
"""
raise NotImplementedError
log_warning("building MNN kernel to gamma is experimental")
if not isinstance(self.gamma, str) and \
not isinstance(self.gamma, numbers.Number):
Expand Down
32 changes: 32 additions & 0 deletions test/test_mnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,38 @@ def test_mnn_with_vector_gamma():
gamma=np.linspace(0, 1, n_sample - 1))


def test_mnn_with_non_zero_indexed_sample_idx():
X, sample_idx = generate_swiss_roll()
G = build_graph(X, sample_idx=sample_idx,
kernel_symm='gamma', gamma=0.5,
n_pca=None, use_pygsp=True)
sample_idx += 1
G2 = build_graph(X, sample_idx=sample_idx,
kernel_symm='gamma', gamma=0.5,
n_pca=None, use_pygsp=True)
assert G.N == G2.N
assert np.all(G.d == G2.d)
assert (G.W != G2.W).nnz == 0
assert (G2.W != G.W).sum() == 0
assert isinstance(G2, graphtools.graphs.MNNGraph)


def test_mnn_with_string_sample_idx():
X, sample_idx = generate_swiss_roll()
G = build_graph(X, sample_idx=sample_idx,
kernel_symm='gamma', gamma=0.5,
n_pca=None, use_pygsp=True)
sample_idx = np.where(sample_idx == 0, 'a', 'b')
G2 = build_graph(X, sample_idx=sample_idx,
kernel_symm='gamma', gamma=0.5,
n_pca=None, use_pygsp=True)
assert G.N == G2.N
assert np.all(G.d == G2.d)
assert (G.W != G2.W).nnz == 0
assert (G2.W != G.W).sum() == 0
assert isinstance(G2, graphtools.graphs.MNNGraph)


#####################################################
# Check kernel
#####################################################
Expand Down

0 comments on commit f20b014

Please sign in to comment.