Skip to content

Commit

Permalink
Cleaning up code
Browse files Browse the repository at this point in the history
  • Loading branch information
lidakanari committed Dec 13, 2018
1 parent bfccd1b commit 7705675
Show file tree
Hide file tree
Showing 16 changed files with 458 additions and 1,251 deletions.
101 changes: 101 additions & 0 deletions examples/matching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
'''
tmd matching algorithms implementation
'''


def marriage_problem(women_preferences, men_preferences):
'''Matches N women to M men so that max(M, N)
are coupled to their preferred choice that is available
See https://en.wikipedia.org/wiki/Stable_marriage_problem
'''
N = len(women_preferences)
M = len(men_preferences)

swapped = False

if M > N:
swap = women_preferences
women_preferences = men_preferences
men_preferences = swap
N = len(women_preferences)
M = len(men_preferences)
swapped = True

free_women = range(N)
free_men = range(M)

couples = {x: None for x in range(N)} # woman first, then current husband

while len(free_men) > 0:
m = free_men.pop()
choice = men_preferences[m].pop(0)

if choice in free_women:
couples[choice] = m
free_women.remove(choice)
else:
current = np.where(np.array(women_preferences)[choice] == couples[choice])[0][0]
tobe = np.where(np.array(women_preferences)[choice] == m)[0][0]
if current < tobe:
free_men.append(couples[choice])
couples[choice] = m
else:
free_men.append(m)

if swapped:
return [(couples[k], k) for k in couples]

return [(k, couples[k]) for k in couples]


def symmetric(p):
'''Returns the symmetric point of a PD point on the diagonal
'''
return [(p[0] + p[1]) / 2., (p[0] + p[1]) / 2]


def matching_diagrams(p1, p2, plot=False, method='munkres', use_diag=True, new_fig=True, subplot=(111)):
'''Returns a list of matching components
Possible matching methods:
- munkress
- marriage problem
'''
from scipy.spatial.distance import cdist
import munkres
from tmd.view import common as _cm

def plot_matching(p1, p2, indices, new_fig=True, subplot=(111)):
'''Plots matching between p1, p2
for the corresponding indices
'''
import pylab as plt
fig, ax = _cm.get_figure(new_fig=new_fig, subplot=subplot)
for i, j in indices:
ax.plot((p1[i][0], p2[j][0]), (p1[i][1], p2[j][1]), color='black')
ax.scatter(p1[i][0], p1[i][1], c='r')
ax.scatter(p2[j][0], p2[j][1], c='b')

if use_diag:
p1_enh = p1 + [symmetric(i) for i in p2]
p2_enh = p2 + [symmetric(i) for i in p1]
else:
p1_enh = p1
p2_enh = p2

D = cdist(p1_enh, p2_enh)

if method == 'munkres':
m = munkres.Munkres()
indices = m.compute(np.copy(D))
elif method == 'marriage':
first_pref = [np.argsort(k).tolist() for k in cdist(p1_enh, p2_enh)]
second_pref = [np.argsort(k).tolist() for k in cdist(p2_enh, p1_enh)]
indices = marriage_problem(first_pref, second_pref)

if plot:
plot_matching(p1_enh, p2_enh, indices, new_fig=new_fig, subplot=subplot)

ssum = np.sum([D[i][j] for (i, j) in indices])

return indices, ssum

2 changes: 0 additions & 2 deletions tmd/Neuron/Neuron.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,12 @@
tmd class : Neuron
'''


class Neuron(object):
"""
A Neuron object is a container for Trees
(basal, apical and axon) and a Soma.
"""
from tmd.Neuron.methods import size
from tmd.Neuron.methods import get_section_lengths
from tmd.Neuron.methods import get_bounding_box
from tmd.Neuron.methods import simplify

Expand Down
31 changes: 1 addition & 30 deletions tmd/Neuron/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,47 +2,18 @@
tmd Neuron's methods
'''


def size(self, neurite_type='all'):
"""
Neuron method to get size.
"""
if neurite_type == 'all':
neurite_list = ['basal', 'axon', 'apical']
else:
neurite_list = list([neurite_type])

s = 0

for neu in neurite_list:

s = s + len(getattr(self, neu))
s = np.sum([len(getattr(self, neu)) for neu in neurite_list])

return int(s)


def get_section_lengths(self, neurite_type='all'):
"""
Neuron method to get section lengths.
"""
import numpy as np

if neurite_type == 'all':
neurite_list = ['basal', 'axon', 'apical']
else:
neurite_list = list([neurite_type])

lengths = []

for neu in neurite_list:

for tree in getattr(self, neu):

lengths = lengths + list(tree.get_section_lengths())

return np.array(lengths)


def get_bounding_box(self):
"""
Input
Expand Down
92 changes: 0 additions & 92 deletions tmd/Population/Population.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,95 +54,3 @@ def append_neuron(self, new_neuron):

if isinstance(new_neuron, Neuron.Neuron):
self.neurons.append(new_neuron)

def extract_ph(self, neurite_type='all', output_folder='./',
feature='radial_distances'):
"""Extract the persistent homology of all
neurites in the population and saves
them in files according to the tree type.
"""
from tmd.Topology.methods import extract_ph as eph
import os

def try_except(tree, ntree, feature, output_folder, ttype='basal'):
'''Try except extract ph from tree.
'''
try:
eph(tree, feature=feature,
output_file=os.path.join(output_folder, ttype + '_' + str(ntree) + '.txt'))
except ValueError:
print(tree)

if neurite_type == 'all':
_ = [try_except(ap, enap, feature, output_folder, ttype='apical')
for enap, ap in enumerate(self.apicals)]

_ = [try_except(ax, enax, feature, output_folder, ttype='axon')
for enax, ax in enumerate(self.axons)]

_ = [try_except(bas, enbas, feature, output_folder, ttype='basal')
for enbas, bas in enumerate(self.basals)]

else:
_ = [try_except(ax, enax, feature, output_folder, ttype=neurite_type)
for enax, ax in enumerate(getattr(self, neurite_type + 's'))]

def extract_ph_names(self, neurite_type='all', output_folder='./',
feature='radial_distances'):
"""Extract the persistent homology of all
neurites in the population and saves
them in files according to the tree type.
"""
from tmd.Topology.methods import extract_ph as eph
import os

def try_except(tree, ntree, feature, output_folder, ttype='basal'):
'''Try except extract ph from tree.
'''
try:
eph(tree, feature=feature,
output_file=os.path.join(output_folder,
ttype + '_' + str(ntree) + '.txt'))
except ValueError:
print(tree)

if neurite_type == 'all':
_ = [[try_except(ap, enap, feature, output_folder,
ttype='apical_' + n.name.split('/')[-1])
for enap, ap in enumerate(n.apical)] for n in self.neurons]

_ = [[try_except(ax, enax, feature, output_folder,
ttype='axon_' + n.name.split('/')[-1])
for enax, ax in enumerate(n.axon)] for n in self.neurons]

_ = [[try_except(bas, enbas, feature, output_folder,
ttype='basal_' + n.name.split('/')[-1])
for enbas, bas in enumerate(n.basal)] for n in self.neurons]

else:
_ = [[try_except(ax, enax, feature, output_folder,
ttype=neurite_type + '_' + n.name.split('/')[-1])
for enax, ax in enumerate(getattr(n, neurite_type + 's'))] for n in self.neurons]

def extract_ph_neurons(self, neurite_type='all', output_folder='./',
feature='radial_distances'):
"""Extract the persistent homology of all
neurites in the population and saves
them in files according to the tree type.
"""
from tmd.Topology.methods import extract_ph as eph
import os

def try_except(neuron, feature, output_folder, ttype=neurite_type):
'''Try except extract ph from tree.
'''
try:
eph(neuron, feature=feature, function='get_ph_neuron',
neurite_type=ttype,
output_file=os.path.join(output_folder, neuron.name + '.txt'))

except ValueError:
print(neuron.name)

_ = [try_except(n, feature, output_folder, ttype=neurite_type)
for n in self.neurons]
15 changes: 1 addition & 14 deletions tmd/Soma/Soma.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,13 @@ def __init__(self, x=_np.array([]), y=_np.array([]), z=_np.array([]),
----------
x : numpy array
The x-coordinates of surface trace of neuron soma.
y : numpy array
The y-coordinates of surface trace of neuron soma.
z : numpy array
The z-coordinate of surface trace of neuron soma.
d : numpy array
The diameters of surface trace of neuron soma.
Returns
-------
soma : Soma
tmd Soma object
Reference
---------
----------
"""
import numpy as np

Expand All @@ -49,9 +38,7 @@ def copy_soma(self):
"""
Returns a deep copy of the Soma.
"""

import copy

return copy.deepcopy(self)

def is_equal(self, soma):
Expand Down
7 changes: 1 addition & 6 deletions tmd/Soma/methods.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
'''
tmd Soma's methods
'''

import numpy as np

def get_center(self):
"""
Soma method to get the center of the soma.
"""
import numpy as np

x_center = np.mean(self.x)
y_center = np.mean(self.y)
z_center = np.mean(self.z)
Expand All @@ -20,14 +18,11 @@ def get_diameter(self):
"""
Soma method to get the diameter of the soma.
"""
import numpy as np

if len(self.x) == 1:
diameter = self.d[0]
else:
center = self.get_center()
diameter = np.mean(np.sqrt(np.power(self.x - center[0], 2) +
np.power(self.y - center[1], 2) +
np.power(self.z - center[2], 2)))

return diameter

0 comments on commit 7705675

Please sign in to comment.