Skip to content

Commit

Permalink
Simplify matchings
Browse files Browse the repository at this point in the history
  • Loading branch information
lidakanari committed Jun 1, 2018
1 parent 2dbe26b commit 544e6fd
Showing 1 changed file with 27 additions and 53 deletions.
80 changes: 27 additions & 53 deletions tmd/Topology/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,79 +326,53 @@ def _marriage_problem(women_preferences, men_preferences):
free_men.append(m)

if swapped:
return {couples[k]: k for k in couples}
return [(couples[k], k) for k in couples]

return couples
return [(k, couples[k]) for k in couples]


def symmetric(p):
'''Returns the symmetric point of a PD point on the diagonal
'''
return [(p[0] + p[1]) / 2., (p[0] + p[1]) / 2]

def match_diagrams_marriage_probl(p1, p2, plot=False):
'''Returns a list of matching components
'''
from scipy.spatial.distance import cdist
if plot:
import view
fig, ax = view.common.get_figure(new_fig=True, subplot=111)

p1_enhanced = p1 + [symmetric(i) for i in p2]
p2_enhanced = p2 + [symmetric(i) for i in p1]

D = cdist(p1_enhanced, p2_enhanced)

first_distances = cdist(p1_enhanced, p2_enhanced)
second_distances = cdist(p2_enhanced, p1_enhanced)

first_pref = [np.argsort(k).tolist() for k in first_distances]
second_pref = [np.argsort(k).tolist() for k in second_distances]

indices = _marriage_problem(first_pref, second_pref)

pairs = []

for c in church:
pair = [p1_enhanced[c], p2_enhanced[church[c]]]
if plot:
view.common.plt.scatter(pair[0][0], pair[0][1], color='r')
view.common.plt.scatter(pair[1][0], pair[1][1], color='b')
view.common.plt.plot(np.array(pair)[:,0], np.array(pair)[:,1], color='black')
pairs.append(pair)

total_length = np.sum([np.linalg.norm([i[0][0]-i[1][0], i[0][1]-i[1][1]]) for i in pairs])

return pairs, total_length


def matching_munkres(p1, p2, plot=False):
'''finds a matching based on munkress module.
def matching_diagrams(p1, p2, plot=False, method='munkres'):
'''Returns a list of matching components
Possible matching methods:
- munkress
- marriage problem
'''
from scipy.spatial.distance import cdist
import munkres
import pylab as plt

def plot_matching(p1, p2, indices):
'''Plots matching between p1, p2
for the corresponding indices
'''
import pylab as plt
fig = plt.figure()
for i,j in indices:
plt.plot((p1[i][0], p2[j][0]), (p1[i][1], p2[j][1]), color='black')
plt.scatter(p1[i][0], p1[i][1], c='r')
plt.scatter(p2[j][0], p2[j][1], c='b')

p1_enh = p1 + [symmetric(i) for i in p2]
p2_enh = p2 + [symmetric(i) for i in p1]

D = cdist(p1_enh, p2_enh)

m = munkres.Munkres()
indices = m.compute(D)
if method=='munkres':
m = munkres.Munkres()
indices = m.compute(np.copy(D))
elif method=='marriage':
first_pref = [np.argsort(k).tolist() for k in cdist(p1_enh, p2_enh)]
second_pref = [np.argsort(k).tolist() for k in cdist(p2_enh, p1_enh)]
indices = _marriage_problem(first_pref, second_pref)

if plot:
fig = plt.figure()
for i,j in indices:
plt.plot((p1_enh[i][0], p2_enh[j][0]), (p1_enh[i][1], p2_enh[j][1]), color='black')
plt.scatter(p1_enh[i][0], p1_enh[i][1], c='r')
plt.scatter(p2_enh[j][0], p2_enh[j][1], c='b')
plot_matching(p1_enh, p2_enh, indices)

D = cdist(p1_enh, p2_enh)
ssum = 0
for i,j in indices:
ssum = ssum + D[i][j]
ssum = np.sum([D[i][j] for (i,j) in indices])

return indices, ssum


0 comments on commit 544e6fd

Please sign in to comment.