Skip to content

Commit

Permalink
partitions: simplify code
Browse files Browse the repository at this point in the history
  • Loading branch information
smichr committed Aug 17, 2012
1 parent 35b61cb commit 23cc51e
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 171 deletions.
216 changes: 91 additions & 125 deletions sympy/combinatorics/partitions.py
@@ -1,4 +1,4 @@
from sympy.core import Basic, C, Dict
from sympy.core import Basic, C, Dict, sympify
from sympy.matrices import zeros
from sympy.functions import floor
from sympy.utilities.misc import default_sort_key
Expand Down Expand Up @@ -45,7 +45,7 @@ def __new__(cls, *args):
partition = args[0]

if not all(isinstance(part, list) for part in partition):
raise ValueError("Partition should be a list of lists.")
raise ValueError("Partition should be a list of lists.")

# sort so we have a canonical reference for RGS
partition = sorted(sum(partition, []), key=default_sort_key)
Expand All @@ -54,45 +54,43 @@ def __new__(cls, *args):

obj = C.FiniteSet.__new__(cls, map(lambda x: C.FiniteSet(x), args[0]))
obj.members = tuple(partition)
obj.set_size = len(partition) # should this just be size?
obj.size = len(partition)
return obj

def as_list(self):
"""Return partition as a sorted list of lists.
Examples
========
>>> from sympy.combinatorics.partitions import Partition
>>> Partition([[1], [2, 3]]).as_list()
[[1], [2, 3]]
"""
return sorted(sorted(p) for p in self.args)
def sort_key(self, order=None):
"""Return a canonical key that can be used for sorting.
def next(self):
"""
Generates the next partition.
Ordering is based on the size and sorted elements of the partition
and ties are broken with the rank.
Examples
========
>>> from sympy.utilities.iterables import default_sort_key
>>> from sympy.combinatorics.partitions import Partition
>>> a = Partition([[1, 2], [3, 4, 5]])
>>> a.next()
{{1, 2}, {3, 4}, {5}}
>>> from sympy.abc import x
>>> a = Partition([[1, 2]])
>>> b = Partition([[3, 4]])
>>> c = Partition([[1, x]])
>>> d = Partition([range(4)])
>>> l = [d, b, a + 1, a, c]
>>> l.sort(key=default_sort_key); l
[{{1, 2}}, {{1}, {2}}, {{1, x}}, {{3, 4}}, {{0, 1, 2, 3}}]
"""
return self + 1
if order is None:
members = self.members
else:
members = tuple(sorted(self.members,
key=lambda w: default_sort_key(w, order)))
return self.size, members, self.rank

def previous(self):
"""
Generates the previous partition.
def as_list(self):
"""Return partition as a sorted list of lists.
Examples
========
>>> from sympy.combinatorics.partitions import Partition
>>> a = Partition([[1, 2], [3, 4], [5]])
>>> a.previous()
{{1, 2}, {3, 4, 5}}
>>> Partition([[1], [2, 3]]).as_list()
[[1], [2, 3]]
"""
return self - 1
return sorted(sorted(p) for p in self.args)

def _partition_op(self, other, op=0):
"""
Expand All @@ -104,27 +102,18 @@ def _partition_op(self, other, op=0):
else:
offset = self.rank - other.rank
result = RGS_unrank((offset) %
RGS_enum(self.set_size),
self.set_size)
RGS_enum(self.size),
self.size)
elif isinstance(other, int):
if op == 0:
offset = self.rank + other
else:
offset = self.rank - other
result = RGS_unrank((offset) %
RGS_enum(self.set_size),
self.set_size)
RGS_enum(self.size),
self.size)
return Partition.partition_from_rgs(result, self.members)

def sort_key(self, order=None):
"""Return a canonical key that can be used for sorting."""
if order is None:
members = self.members
else:
members = tuple(sorted(self.members,
key=lambda w: default_sort_key(w, order)))
return self.set_size, self.rank, members

def __add__(self, other):
"""
Routine to increment the rank of self by other's rank or value
Expand Down Expand Up @@ -161,41 +150,6 @@ def __sub__(self, other):
"""
return self._partition_op(other, 1)

def compare(self, other):
"""
Compares two partitions based on rank if they have the same
superset else based on their elements.
Examples
========
>>> from sympy.combinatorics.partitions import Partition
>>> a = Partition([[i] for i in range(3)])
>>> b = Partition([[4]])
>>> a.rank, b.rank
(4, 0)
>>> a < b
True
>>> a = Partition([[1, 2], [3, 4, 5]])
>>> b = Partition([[1], [2, 3], [4], [5]])
>>> a.compare(b)
-1
>>> a.compare(a)
0
>>> b.compare(a)
1
"""
if not isinstance(other, Partition):
raise ValueError('XXX how should the comparison be made?')
if self.members != other.members:
s, o = (w.members for w in (self, other))
else:
s, o = (w.rank for w in (self, other))
if s < o:
return -1
elif s > o:
return 1
return 0

def __le__(self, other):
"""
Checks if a partition is less than or equal to
Expand All @@ -213,10 +167,7 @@ def __le__(self, other):
>>> a <= b
True
"""
try:
return self.compare(other) <= 0
except AssertionError:
return super(Partition, self).__le__(other)
return self.sort_key() <= sympify(other).sort_key()

def __lt__(self, other):
"""
Expand All @@ -232,10 +183,7 @@ def __lt__(self, other):
>>> a < b
True
"""
try:
return self.compare(other) < 0
except AssertionError:
return super(Partition, self).__lt__(other)
return self.sort_key() < sympify(other).sort_key()

@property
def rank(self):
Expand Down Expand Up @@ -272,7 +220,7 @@ def RGS(self):
(1, 2, 3, 4, 5)
>>> a.RGS
[0, 0, 1, 2, 2]
>>> a.next()
>>> a + 1
{{1, 2}, {3}, {4}, {5}}
>>> _.RGS
[0, 0, 1, 2, 3]
Expand Down Expand Up @@ -390,17 +338,20 @@ def __new__(cls, partition, integer=None):
obj.integer = integer
return obj

def next(self):
"""Return the next partition of the integer in rev-lex order.
def prev_lex(self):
"""Return the previous partition of the integer in lexical order.
Examples
========
>>> from sympy.combinatorics.partitions import IntegerPartition
>>> p = IntegerPartition([4])
>>> print p.next()
>>> print p.prev_lex()
[3, 1]
>>> p.as_list() > p.prev_lex().as_list()
True
"""
d = self.as_dict()
d = defaultdict(int)
d.update(self.as_dict())
keys = self._keys
if keys == [1]:
return IntegerPartition({self.integer: 1})
Expand All @@ -412,47 +363,62 @@ def next(self):
d[keys[-1] - 1] = d[1] = 1
else:
d[keys[-2]] -= 1
if keys[-2] == 2:
d[1] += 2
else:
new = keys[-2] - 1
need = d[1] + 1
d[new] = 1
q, r = divmod(need, new)
d[new] += q
d[1] = r
return IntegerPartition(d)
left = d[1] + keys[-2]
new = keys[-2]
d[1] = 0
while left:
new -= 1
if left - new >= 0:
d[new] += left//new
left -= d[new]*new
return IntegerPartition(self.integer, d)

def prev(self):
"""Return the previous partition of the integer in rev-lex order.
def next_lex(self):
"""Return the next partition of the integer in lexical order.
Examples
========
>>> from sympy.combinatorics.partitions import IntegerPartition
>>> p = IntegerPartition([3, 1])
>>> print p.prev()
>>> print p.next_lex()
[4]
>>> p.as_list() < p.next_lex().as_list()
True
"""
d = defaultdict(int)
d.update(self.as_dict())
keys = self._keys
if self._keys == [self.integer]:
return IntegerPartition({1: self.integer})
if len(keys) == 1:
d = {keys[-1] + 1: 1, 1: self.integer - keys[-1] - 1}
elif d[keys[-1]] == 1:
tot = keys[-1] + keys[-2]
new = keys[-2] + 1
d[new] += 1
d[keys[-1]] -= 1
d[keys[-2]] -= 1
r = tot - new
if r:
d[r] += 1
key = self._keys
a = key[-1]
if a == self.integer:
d.clear()
d[1] = self.integer
elif a == 1:
if d[a] > 1:
d[a + 1] += 1
d[a] -= 2
else:
b = key[-2]
d[b + 1] += 1
d[1] = (d[b] - 1)*b
d[b] = 0
else:
new = 2*keys[-1]
d[new] += 1
d[keys[-1]] -= 2
if d[a] > 1:
if len(key) == 1:
d.clear()
d[a + 1] = 1
d[1] = self.integer - a - 1
else:
a1 = a + 1
d[a1] += 1
d[1] = d[a]*a - a1
d[a] = 0
else:
b = key[-2]
b1 = b + 1
d[b1] += 1
need = d[b]*b + d[a]*a - b1
d[a] = d[b] = 0
d[1] = need
return IntegerPartition(self.integer, d)

def as_list(self):
Expand Down Expand Up @@ -513,11 +479,11 @@ def __lt__(self, other):
is listed from smallest to biggest.
>>> from sympy.combinatorics.partitions import IntegerPartition
>>> a = IntegerPartition([4])
>>> a = IntegerPartition([3, 1])
>>> a < a
False
>>> b = a.next()
>>> b < a
>>> b = a.next_lex()
>>> a < b
True
>>> a == b
False
Expand All @@ -535,7 +501,7 @@ def __le__(self, other):
"""
return list(reversed(self.partition)) <= list(reversed(other.partition))

def ferrers_representation(self):
def ferrers_representation(self, char='#'):
"""
Prints the ferrer diagram of a partition.
Expand All @@ -547,7 +513,7 @@ def ferrers_representation(self):
#
#
"""
return "\n".join(['#'*i for i in self.partition])
return "\n".join([char*i for i in self.partition])

def __str__(self):
return str(list(self.partition))
Expand Down

0 comments on commit 23cc51e

Please sign in to comment.