-
Notifications
You must be signed in to change notification settings - Fork 0
/
mcl.py
283 lines (217 loc) · 8.95 KB
/
mcl.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
import numpy as np
from scipy.sparse import isspmatrix, dok_matrix, csc_matrix
import sklearn.preprocessing
from .utils import MessagePrinter
def sparse_allclose(a, b, rtol=1e-5, atol=1e-8):
    """
    Element-wise closeness test usable with sparse matrices.

    Applies the same criterion as np.allclose,
    ``|a - b| <= atol + rtol * |b|``, rearranged so that only the
    largest entry of ``|a - b| - rtol * |b|`` has to be inspected.

    :param a: First matrix
    :param b: Second matrix (the reference for the relative tolerance)
    :param rtol: Relative tolerance
    :param atol: Absolute tolerance
    :returns: True if every element is within tolerance
    """
    abs_difference = np.abs(a - b)
    relative_allowance = rtol * np.abs(b)
    excess = abs_difference - relative_allowance
    # noinspection PyUnresolvedReferences
    largest_excess = excess.max()
    return largest_excess <= atol
def normalize(matrix):
    """
    Scale the columns of the given matrix using the L1 norm.

    Delegates to sklearn.preprocessing.normalize with ``norm="l1"``
    and ``axis=0`` (column-wise).

    :param matrix: The matrix to be normalized
    :returns: The normalized matrix
    """
    column_normalizer = sklearn.preprocessing.normalize
    return column_normalizer(matrix, norm="l1", axis=0)
def inflate(matrix, power):
    """
    Perform the MCL inflation step: raise every element of the
    matrix to the given power, then re-normalize the result.

    :param matrix: The matrix to be inflated
    :param power: Cluster inflation parameter
    :returns: The inflated matrix
    """
    # sparse matrices expose element-wise power as a method,
    # dense arrays go through numpy
    if isspmatrix(matrix):
        powered = matrix.power(power)
    else:
        powered = np.power(matrix, power)
    return normalize(powered)
def expand(matrix, power):
    """
    Perform the MCL expansion step: raise the matrix (as a whole,
    not element-wise) to the given power.

    :param matrix: The matrix to be expanded
    :param power: Cluster expansion parameter
    :returns: The expanded matrix
    """
    if not isspmatrix(matrix):
        return np.linalg.matrix_power(matrix, power)
    # sparse matrices use their own ** operator
    return matrix ** power
def add_self_loops(matrix, loop_value):
    """
    Return a copy of the matrix whose diagonal entries are all
    set to loop_value (i.e. every node gets a self-loop).

    :param matrix: The matrix to add loops to
    :param loop_value: Value to use for self-loops
    :returns: The matrix with self-loops (CSC format for sparse input)
    """
    dimensions = matrix.shape
    assert dimensions[0] == dimensions[1], "Error, matrix is not square"

    sparse_input = isspmatrix(matrix)
    # work on a DOK conversion for sparse input, on a plain copy otherwise,
    # so the caller's matrix is left untouched
    looped = matrix.todok() if sparse_input else matrix.copy()

    for node in range(dimensions[0]):
        looped[node, node] = loop_value

    return looped.tocsc() if sparse_input else looped
def prune(matrix, threshold):
    """
    Remove very small edges from the matrix.
    Entries below the threshold are dropped, except that the
    largest entry of every column is always kept.

    :param matrix: The matrix to be pruned
    :param threshold: The value below which edges will be removed
    :returns: The pruned matrix
    """
    if isspmatrix(matrix):
        # copy only the surviving entries into an empty DOK matrix
        survivors = matrix >= threshold
        pruned = dok_matrix(matrix.shape)
        pruned[survivors] = matrix[survivors]
        pruned = pruned.tocsc()
    else:
        pruned = matrix.copy()
        pruned[pruned < threshold] = 0

    # restore the per-column maximum; identical handling for dense/sparse
    column_count = matrix.shape[1]
    columns = np.arange(column_count)
    max_rows = matrix.argmax(axis=0).reshape((column_count,))
    pruned[max_rows, columns] = matrix[max_rows, columns]

    return pruned
def converged(matrix1, matrix2):
    """
    Convergence test: report whether the two matrices are
    approximately equal.

    :param matrix1: The matrix to compare with matrix2
    :param matrix2: The matrix to compare with matrix1
    :returns: True if matrix1 and matrix2 approximately equal
    """
    any_sparse = isspmatrix(matrix1) or isspmatrix(matrix2)
    # pick the sparse-aware comparison as soon as either side is sparse
    compare = sparse_allclose if any_sparse else np.allclose
    return compare(matrix1, matrix2)
def iterate(matrix, expansion, inflation):
    """
    Perform one MCL iteration: expansion followed by inflation.

    :param matrix: The matrix to perform the iteration on
    :param expansion: Cluster expansion factor
    :param inflation: Cluster inflation factor
    :returns: The matrix after one expansion + inflation step
    """
    expanded = expand(matrix, expansion)
    return inflate(expanded, inflation)
def delete_overlap(matrix, clusters):
    """
    Delete duplicates of overlapping nodes in the clusters.

    Clusters are processed in the given order; a node that appears in
    more than one cluster is kept only in the first cluster containing
    it. Clusters that become empty are dropped. The input list is not
    modified, and the order of the nodes inside each cluster is kept.

    :param matrix: The matrix produced by the MCL algorithm
    :param clusters: A list of clusters produced by get_clusters
    :returns: A list of clusters without duplicates of the overlapping nodes
    """
    node_count = matrix.shape[0]
    clusters_total_size = sum(len(c) for c in clusters)
    distinct_nodes = set().union(*clusters)

    # Comparing the total size with the node count alone misses overlaps
    # when some nodes belong to no cluster, so the number of distinct
    # clustered nodes is checked as well.
    if node_count >= clusters_total_size and len(distinct_nodes) == clusters_total_size:
        return clusters

    printer = MessagePrinter(True)
    printer.print("Clustering contains overlapping, to enable soft clustering set keep_overlap to True")

    # nodes that have not been claimed by an earlier cluster yet
    unassigned = set(range(node_count))
    deduplicated = []
    for cluster in clusters:
        # filter instead of rebuilding from a set so node order stays stable
        kept = tuple(node for node in cluster if node in unassigned)
        unassigned.difference_update(kept)
        # get rid of clusters emptied by the overlap removal
        if kept:
            deduplicated.append(kept)
    return deduplicated
def get_clusters(matrix, keep_overlap=False):
    """
    Retrieve the clusters from the matrix.

    :param matrix: The matrix produced by the MCL algorithm
    :param keep_overlap: If true, enables soft clustering (a node may
        appear in several clusters); any falsy value removes overlaps
    :returns: A list of tuples where each tuple represents a cluster and
        contains the indices of the nodes belonging to the cluster
    """
    if not isspmatrix(matrix):
        # cast to sparse so that we don't need to handle different
        # matrix types
        matrix = csc_matrix(matrix)

    # get the attractors - non-zero elements of the matrix diagonal
    attractors = matrix.diagonal().nonzero()[0]

    # the nodes in the same row as each attractor form a cluster;
    # the set collapses identical clusters reached from different attractors
    unique_clusters = {
        tuple(matrix.getrow(attractor).nonzero()[1].tolist())
        for attractor in attractors
    }
    clusters = sorted(unique_clusters)

    # truthiness rather than `is False`, so 0 / numpy.bool_(False)
    # also request hard clustering
    if not keep_overlap:
        clusters = delete_overlap(matrix, clusters)

    return clusters
def run_mcl(matrix, expansion=2, inflation=2, loop_value=1,
            iterations=100, pruning_threshold=0.001, pruning_frequency=1,
            convergence_check_frequency=1, verbose=False):
    """
    Run the MCL algorithm on the given similarity matrix.

    Self-loops are added (when loop_value > 0) and the matrix is
    normalized; then expansion/inflation iterations are repeated, with
    optional pruning, until convergence is detected or the iteration
    limit is reached.

    :param matrix: The similarity matrix to cluster
    :param expansion: The cluster expansion factor
    :param inflation: The cluster inflation factor
    :param loop_value: Initialization value for self-loops
    :param iterations: Maximum number of iterations
           (actual number of iterations will be less if convergence is reached)
    :param pruning_threshold: Threshold below which matrix elements will be
           set to 0
    :param pruning_frequency: Perform pruning every 'pruning_frequency'
           iterations.
    :param convergence_check_frequency: Perform the check for convergence
           every convergence_check_frequency iterations
    :param verbose: Print extra information to the console
    :returns: The final matrix
    """
    assert expansion > 1, "Invalid expansion parameter"
    assert inflation > 1, "Invalid inflation parameter"
    assert loop_value >= 0, "Invalid loop_value"
    assert iterations > 0, "Invalid number of iterations"
    assert pruning_threshold >= 0, "Invalid pruning_threshold"
    assert pruning_frequency > 0, "Invalid pruning_frequency"
    assert convergence_check_frequency > 0, "Invalid convergence_check_frequency"

    def plural(count):
        # suffix turning "iteration" into "iterations" when needed
        return "s" if count > 1 else ""

    separator = "-" * 50
    pruning_enabled = pruning_threshold > 0

    if pruning_enabled:
        pruning_line = "Pruning threshold: {}, frequency: {} iteration{}".format(
            pruning_threshold, pruning_frequency, plural(pruning_frequency))
    else:
        pruning_line = "No pruning"

    banner = [
        separator,
        "MCL Parameters",
        "Expansion: {}".format(expansion),
        "Inflation: {}".format(inflation),
        pruning_line,
        "Convergence check: {} iteration{}".format(
            convergence_check_frequency, plural(convergence_check_frequency)),
        "Maximum iterations: {}".format(iterations),
        "{} matrix mode".format("Sparse" if isspmatrix(matrix) else "Dense"),
        separator,
    ]

    printer = MessagePrinter(verbose)
    for line in banner:
        printer.print(line)

    # Initialize self-loops
    if loop_value > 0:
        matrix = add_self_loops(matrix, loop_value)

    # Normalize
    matrix = normalize(matrix)

    for step in range(iterations):
        iteration_number = step + 1
        printer.print("Iteration {}".format(iteration_number))

        # remember the matrix as it was before this iteration so that
        # convergence can be judged against it
        previous = matrix.copy()

        # one expansion + inflation round
        matrix = iterate(matrix, expansion, inflation)

        # prune on every pruning_frequency-th iteration
        if pruning_enabled and step % pruning_frequency == pruning_frequency - 1:
            printer.print("Pruning")
            matrix = prune(matrix, pruning_threshold)

        # only every convergence_check_frequency-th iteration is checked
        if step % convergence_check_frequency != convergence_check_frequency - 1:
            continue

        printer.print("Checking for convergence")
        if converged(matrix, previous):
            printer.print("Converged after {} iteration{}".format(
                iteration_number, plural(iteration_number)))
            break

    printer.print(separator)

    return matrix