Skip to content

Commit

Permalink
Add matching with munkress
Browse files Browse the repository at this point in the history
  • Loading branch information
lidakanari committed Jun 1, 2018
1 parent 1252b86 commit 2dbe26b
Showing 1 changed file with 143 additions and 5 deletions.
148 changes: 143 additions & 5 deletions tmd/Topology/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,14 +238,32 @@ def get_image_max_diff(Z1, Z2):
return diff


def transform_to_length(ph, direction=False):
'''Transforms a persistence diagram into
a (start_point, length) equivalent diagram.
def transform_to_length(ph, keep_side='end'):
'''Transforms a persistence diagram into a
(start_point, length) equivalent diagram or a
(end, length) diagram depending on keep_side option.
Note: the direction of the diagram will be lost!
'''
if not direction:
return [[i[0], np.abs(i[1] - i[0])] for i in ph]
if keep_side == 'start':
# keeps the start point and the length of the bar
return [[min(i), np.abs(i[1] - i[0])] for i in ph]
else:
# keeps the end point and the length of the bar
return [[max(i), np.abs(i[1] - i[0])] for i in ph]


def transform_from_length(ph, keep_side='end'):
'''Transforms a persistence diagram into a
(start_point, length) equivalent diagram or a
(end, length) diagram depending on keep_side option.
Note: the direction of the diagram will be lost!
'''
if keep_side == 'start':
# keeps the start point and the length of the bar
return [[i[0], i[1] - i[0]] for i in ph]
else:
# keeps the end point and the length of the bar
return [[i[0] - i[1], i[0]] for i in ph]


def average_image(ph_list, xlims=None, ylims=None, norm_factor=None, **kwargs):
Expand All @@ -264,3 +282,123 @@ def average_image(ph_list, xlims=None, ylims=None, norm_factor=None, **kwargs):
average_imgs = average_imgs / len(imgs_list)

return average_imgs


def _marriage_problem(women_preferences, men_preferences):
'''Matches N women to M men so that max(M, N)
are coupled to their preferred choice that is available
See https://en.wikipedia.org/wiki/Stable_marriage_problem
'''
N = len(women_preferences)
M = len(men_preferences)

swapped = False

if M > N:
swap = women_preferences
women_preferences = men_preferences
men_preferences = swap
N = len(women_preferences)
M = len(men_preferences)
swapped = True

free_women = range(N)
free_men = range(M)

couples = {x: None for x in xrange(N)} # woman first, then current husband

count = 0

while len(free_men) > 0:
m = free_men.pop()
choice = men_preferences[m].pop(0)

if choice in free_women:
couples[choice] = m
free_women.remove(choice)
else:
current = np.where(np.array(women_preferences)[choice] == couples[choice])[0][0]
tobe = np.where(np.array(women_preferences)[choice] == m)[0][0]
if current < tobe:
free_men.append(couples[choice])
couples[choice] = m
else:
free_men.append(m)

if swapped:
return {couples[k]: k for k in couples}

return couples


def symmetric(p):
'''Returns the symmetric point of a PD point on the diagonal
'''
return [(p[0] + p[1]) / 2., (p[0] + p[1]) / 2]

def match_diagrams_marriage_probl(p1, p2, plot=False):
'''Returns a list of matching components
'''
from scipy.spatial.distance import cdist
if plot:
import view
fig, ax = view.common.get_figure(new_fig=True, subplot=111)

p1_enhanced = p1 + [symmetric(i) for i in p2]
p2_enhanced = p2 + [symmetric(i) for i in p1]

D = cdist(p1_enhanced, p2_enhanced)

first_distances = cdist(p1_enhanced, p2_enhanced)
second_distances = cdist(p2_enhanced, p1_enhanced)

first_pref = [np.argsort(k).tolist() for k in first_distances]
second_pref = [np.argsort(k).tolist() for k in second_distances]

indices = _marriage_problem(first_pref, second_pref)

pairs = []

for c in church:
pair = [p1_enhanced[c], p2_enhanced[church[c]]]
if plot:
view.common.plt.scatter(pair[0][0], pair[0][1], color='r')
view.common.plt.scatter(pair[1][0], pair[1][1], color='b')
view.common.plt.plot(np.array(pair)[:,0], np.array(pair)[:,1], color='black')
pairs.append(pair)

total_length = np.sum([np.linalg.norm([i[0][0]-i[1][0], i[0][1]-i[1][1]]) for i in pairs])

return pairs, total_length


def matching_munkres(p1, p2, plot=False):
'''finds a matching based on munkress module.
'''
from scipy.spatial.distance import cdist
import munkres
import pylab as plt

p1_enh = p1 + [symmetric(i) for i in p2]
p2_enh = p2 + [symmetric(i) for i in p1]

D = cdist(p1_enh, p2_enh)

m = munkres.Munkres()
indices = m.compute(D)

if plot:
fig = plt.figure()
for i,j in indices:
plt.plot((p1_enh[i][0], p2_enh[j][0]), (p1_enh[i][1], p2_enh[j][1]), color='black')
plt.scatter(p1_enh[i][0], p1_enh[i][1], c='r')
plt.scatter(p2_enh[j][0], p2_enh[j][1], c='b')

D = cdist(p1_enh, p2_enh)
ssum = 0
for i,j in indices:
ssum = ssum + D[i][j]

return indices, ssum


0 comments on commit 2dbe26b

Please sign in to comment.