Switch branches/tags
Nothing to show
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
752 lines (529 sloc) 20.9 KB
"""Provides some utilities widely used by other modules"""
import bisect
import collections
import collections.abc
import functools
import heapq
import math
import operator
import os.path
import random
from itertools import chain, combinations
# ______________________________________________________________________________
# Functions on Sequences and Iterables
def sequence(iterable):
    """Coerce iterable to a sequence, if it is not already one.

    Anything that is already a collections.abc.Sequence (list, tuple, str,
    range, ...) is returned unchanged; any other iterable, such as a
    generator or a set, is materialized into a tuple.
    """
    if isinstance(iterable, collections.abc.Sequence):
        return iterable
    return tuple(iterable)
def removeall(item, seq):
    """Return a copy of seq with every occurrence of item dropped.

    A str input yields a str (item is removed as a substring); any other
    iterable yields a list.
    """
    if not isinstance(seq, str):
        return [element for element in seq if element != item]
    return seq.replace(item, '')
def unique(seq):  # TODO: replace with set
    """Return a list of the distinct elements of seq, in arbitrary order.

    Elements must be hashable.
    """
    distinct = set(seq)
    return list(distinct)
def count(seq):
    """Return how many items of seq are interpreted as true."""
    return sum(1 for item in seq if item)
def product(numbers):
    """Return the product of the numbers, e.g. product([2, 3, 10]) == 60.

    The product of an empty iterable is 1.
    """
    # operator.imul performs the same in-place `*=` step on the accumulator.
    return functools.reduce(operator.imul, numbers, 1)
def first(iterable, default=None):
    """Return the first element of an iterable, or default if it has none.

    Indexable inputs are read with iterable[0]. Inputs that do not support
    indexing (generators, other iterators, sets, ...) are read with next();
    for an iterator this consumes one element.
    """
    try:
        return iterable[0]
    except IndexError:
        # Indexable but empty.
        return default
    except TypeError:
        # Not subscriptable. iter() returns an iterator unchanged, so a
        # generator is still advanced in place; plain iterables work too.
        return next(iter(iterable), default)
def is_in(elt, seq):
    """Identity-based membership: like (elt in seq), but compares with 'is', not '=='."""
    for candidate in seq:
        if candidate is elt:
            return True
    return False
def mode(data):
    """Return the most common item in data; if there are ties, any one of them.

    Raises ValueError when data is empty.
    """
    top_entries = collections.Counter(data).most_common(1)
    # Unpacking a one-element list fails with ValueError if data was empty.
    (winner, _occurrences), = top_entries
    return winner
def powerset(iterable):
    """powerset([1,2,3]) --> (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)

    Returns a list of every non-empty combination, ordered by increasing size.
    """
    pool = list(iterable)
    subsets = []
    for size in range(1, len(pool) + 1):
        subsets.extend(combinations(pool, size))
    return subsets
# ______________________________________________________________________________
# argmin and argmax
# Default key function for the *_random_tie helpers: compare elements by their own value.
identity = lambda x: x
# argmin/argmax are plain aliases of the builtins min/max (both accept key=).
argmin = min
argmax = max
def argmin_random_tie(seq, key=identity):
    """Return an element of seq with the lowest key(element); break ties at random."""
    # Shuffling first means the element min() meets first among equals is random.
    candidates = shuffled(seq)
    return argmin(candidates, key=key)
def argmax_random_tie(seq, key=identity):
    """Return an element of seq with the highest key(element); break ties at random."""
    # Shuffling first means the element max() meets first among equals is random.
    candidates = shuffled(seq)
    return argmax(candidates, key=key)
def shuffled(iterable):
    """Return a new, randomly shuffled list of the items in iterable.

    The input itself is never modified.
    """
    items = list(iterable)
    random.shuffle(items)
    return items
# ______________________________________________________________________________
# Statistical and mathematical functions
def histogram(values, mode=0, bin_function=None):
    """Return a list of (value, count) pairs, summarizing the input values.

    Sorted by increasing value, or if mode=1, by decreasing count (ties are
    ordered by decreasing value).
    If bin_function is given, map it over values first.
    """
    if bin_function:
        values = map(bin_function, values)

    bins = collections.Counter(values)

    if mode:
        return sorted(bins.items(), key=lambda pair: (pair[1], pair[0]),
                      reverse=True)
    return sorted(bins.items())
def dotproduct(X, Y):
    """Return the sum of the element-wise products of vectors X and Y.

    Pairing stops at the end of the shorter vector.
    """
    return sum(map(operator.mul, X, Y))
def element_wise_product(X, Y):
    """Return a list holding the element-wise product of vectors X and Y.

    Asserts that both vectors have the same length.
    """
    assert len(X) == len(Y)
    return list(map(operator.mul, X, Y))
def matrix_multiplication(X_M, *Y_M):
    """Return the matrix product of X_M and an arbitrary number of matrices *Y_M.

    Matrices are lists of rows and are multiplied left to right. With no
    further matrices, X_M itself is returned.

    >>> matrix_multiplication([[1, 2, 3],
    ...                        [2, 3, 4]],
    ...                       [[3, 4],
    ...                        [1, 2],
    ...                        [1, 0]])
    [[8, 8], [13, 14]]
    """

    def _mat_mult(X_M, Y_M):
        """Return the matrix product of exactly two matrices X_M and Y_M."""
        # Inner dimensions must agree: columns of X_M == rows of Y_M.
        assert len(X_M[0]) == len(Y_M)

        n_rows, n_cols, n_inner = len(X_M), len(Y_M[0]), len(Y_M)
        result = [[0] * n_cols for _ in range(n_rows)]
        for i in range(n_rows):
            for j in range(n_cols):
                for k in range(n_inner):
                    result[i][j] += X_M[i][k] * Y_M[k][j]
        return result

    result = X_M
    for Y in Y_M:
        result = _mat_mult(result, Y)

    return result
def vector_to_diagonal(v):
    """Return a square matrix (list of rows) with the elements of v on the
    main diagonal and 0 everywhere else."""
    size = len(v)
    return [[v[row] if row == col else 0 for col in range(size)]
            for row in range(size)]
def vector_add(a, b):
    """Return the component-wise sum of two vectors as a tuple."""
    return tuple(x + y for x, y in zip(a, b))
def scalar_vector_product(X, Y):
    """Return a list with every element of vector Y multiplied by the scalar X."""
    scaled = []
    for component in Y:
        scaled.append(X * component)
    return scaled
def scalar_matrix_product(X, Y):
    """Return a matrix (list of rows) with every entry of matrix Y multiplied by the scalar X."""
    scaled_rows = []
    for row in Y:
        scaled_rows.append(scalar_vector_product(X, row))
    return scaled_rows
def inverse_matrix(X):
    """Return the inverse of a square matrix X of size 2x2.

    Asserts that X has two rows, that its first row has two entries, and
    that its determinant is non-zero.
    """
    assert len(X) == 2
    assert len(X[0]) == 2
    top, bottom = X[0], X[1]
    det = top[0] * bottom[1] - top[1] * bottom[0]
    assert det != 0
    # Inverse = adjugate scaled by 1/det.
    adjugate = [[bottom[1], -top[1]], [-bottom[0], top[0]]]
    return scalar_matrix_product(1.0 / det, adjugate)
def probability(p):
    """Return True with probability p."""
    draw = random.uniform(0.0, 1.0)
    return p > draw
def weighted_sample_with_replacement(n, seq, weights):
    """Pick n samples from seq at random, with replacement, with the
    probability of each element in proportion to its corresponding
    weight.

    Returns a list of length n.
    """
    sample = weighted_sampler(seq, weights)
    return [sample() for _ in range(n)]
def weighted_sampler(seq, weights):
    """Return a zero-argument function that picks a random element of seq,
    weighted by the corresponding entry of weights."""
    # Running totals of the weights; element i owns the interval that ends at cumulative[i].
    cumulative = []
    for weight in weights:
        if cumulative:
            cumulative.append(weight + cumulative[-1])
        else:
            cumulative.append(weight)

    return lambda: seq[bisect.bisect(cumulative,
                                     random.uniform(0, cumulative[-1]))]
def rounder(numbers, d=4):
    """Round a single number, or a (possibly nested) collection of numbers, to d decimal places.

    A collection is rebuilt with its own type (list, set, tuple, etc.).
    """
    if not isinstance(numbers, (int, float)):
        rebuild = type(numbers)  # Can be list, set, tuple, etc.
        return rebuild(rounder(element, d) for element in numbers)
    return round(numbers, d)
def num_or_str(x):
    """The argument is a string; convert to a number if
    possible, or strip it.

    int is tried first, then float; if both raise ValueError the
    whitespace-stripped string is returned.
    """
    for convert in (int, float):
        try:
            return convert(x)
        except ValueError:
            continue
    return str(x).strip()
def normalize(dist):
    """Scale the numbers by a common constant so that they sum to 1.0.

    A dict is normalized in place (values rescaled) and returned; any other
    collection of numbers yields a new list.
    """
    if not isinstance(dist, dict):
        denominator = sum(dist)
        return [value / denominator for value in dist]

    denominator = sum(dist.values())
    for label in dist:
        dist[label] = dist[label] / denominator
        assert 0 <= dist[label] <= 1, "Probabilities must be between 0 and 1."
    return dist
def norm(X, n=2):
    """Return the n-norm of vector X: (sum of x**n) ** (1/n)."""
    powered = (component ** n for component in X)
    return sum(powered) ** (1 / n)
def clip(x, lowest, highest):
    """Return x clipped to the range [lowest..highest]."""
    # Cap from above first, then raise to the floor, mirroring max(lowest, min(x, highest)).
    capped = highest if highest < x else x
    return capped if capped > lowest else lowest
def sigmoid_derivative(value):
    """Return value * (1 - value).

    This is the derivative of the sigmoid expressed in terms of its output,
    i.e. value is expected to be sigmoid(x), not x itself.
    """
    complement = 1 - value
    return value * complement
def sigmoid(x):
    """Return the activation value of x under the sigmoid function, 1 / (1 + e**-x).

    The result is computed without ever exponentiating a large positive
    number, so very negative x yields a value near 0.0 instead of raising
    OverflowError.
    """
    if x >= 0:
        return 1 / (1 + math.exp(-x))
    # For x < 0, exp(-x) may overflow; exp(x) lies in (0, 1) and cannot.
    scaled = math.exp(x)
    return scaled / (1 + scaled)
def step(x):
    """Return the activation value of x under the step function: 1 when x >= 0, otherwise 0."""
    if x >= 0:
        return 1
    return 0
def gaussian(mean, st_dev, x):
    """Given the mean and standard deviation of a normal distribution, return
    its probability density at x."""
    coefficient = 1 / (math.sqrt(2 * math.pi) * st_dev)
    z_score = float(x - mean) / st_dev
    exponent = -0.5 * z_score ** 2
    return coefficient * math.e ** exponent
try:  # math.isclose was added in Python 3.5; but we might be in 3.4
    from math import isclose
except ImportError:
    def isclose(a, b, rel_tol=1e-09, abs_tol=0.0):
        """Return true if numbers a and b are close to each other.

        They are close when |a - b| is within the larger of abs_tol and
        rel_tol times the larger magnitude of the two.
        """
        tolerance = max(rel_tol * max(abs(a), abs(b)), abs_tol)
        return abs(a - b) <= tolerance
def weighted_choice(choices):
    """A weighted version of random.choice.

    choices holds (item, weight) pairs; the selected pair is returned as
    the tuple (item, weight).
    """
    # NOTE: Should be replaced by random.choices if we port to Python 3.6
    total = sum(weight for _, weight in choices)
    threshold = random.uniform(0, total)
    running = 0
    for item, weight in choices:
        running += weight
        if running >= threshold:
            return item, weight
# ______________________________________________________________________________
# Grid Functions
# Unit (dx, dy) headings. Taking +y as up, they are listed counter-clockwise from EAST.
orientations = EAST, NORTH, WEST, SOUTH = [(1, 0), (0, 1), (-1, 0), (0, -1)]
# Index increments into `orientations`: LEFT steps forward through the list, RIGHT steps backward.
turns = LEFT, RIGHT = (+1, -1)
def turn_heading(heading, inc, headings=orientations):
    """Return the entry of headings that lies inc positions after heading, wrapping around the ends."""
    position = headings.index(heading)
    return headings[(position + inc) % len(headings)]
def turn_right(heading):
    """Return the heading obtained by turning right from heading."""
    return turn_heading(heading, inc=RIGHT)
def turn_left(heading):
    """Return the heading obtained by turning left from heading."""
    return turn_heading(heading, inc=LEFT)
def distance(a, b):
    """Return the Euclidean distance between two (x, y) points."""
    (ax, ay), (bx, by) = a, b
    return math.hypot(ax - bx, ay - by)
def distance_squared(a, b):
    """Return the square of the Euclidean distance between two (x, y) points."""
    (ax, ay), (bx, by) = a, b
    dx = ax - bx
    dy = ay - by
    return dx ** 2 + dy ** 2
def vector_clip(vector, lowest, highest):
    """Return vector with each element clipped to the corresponding elements
    of lowest and highest. The result has the same type as vector."""
    rebuild = type(vector)
    return rebuild(clip(value, low, high)
                   for value, low, high in zip(vector, lowest, highest))
# ______________________________________________________________________________
# Misc Functions
class injection:
    """Dependency injection of temporary values for global functions/classes/etc.
    E.g., `with injection(DataBase=MockDataBase): ...`

    On entry, the named globals of this module are rebound to the supplied
    values; on exit, the previous bindings are restored. Every name must
    already exist as a global of this module, otherwise entering raises
    KeyError.
    """

    def __init__(self, **kwds):
        # Mapping of global name -> temporary replacement value.
        self.new = kwds

    def __enter__(self):
        # Remember the current bindings before overwriting them.
        self.old = {v: globals()[v] for v in self.new}
        globals().update(self.new)

    def __exit__(self, type, value, traceback):
        # Restore the original bindings; exceptions are not suppressed.
        globals().update(self.old)
def memoize(fn, slot=None, maxsize=32):
    """Memoize fn: make it remember the computed value for any argument list.

    If slot is specified, store result in that slot of first argument: the
    value is cached as attribute `slot` on the first argument and returned
    from there on later calls, whatever the remaining arguments are.
    If slot is false, use lru_cache for caching the values (at most maxsize
    entries); all arguments must then be hashable.
    """
    if slot:
        def memoized_fn(obj, *args):
            if hasattr(obj, slot):
                return getattr(obj, slot)
            val = fn(obj, *args)
            setattr(obj, slot, val)
            return val
    else:
        @functools.lru_cache(maxsize=maxsize)
        def memoized_fn(*args):
            return fn(*args)

    return memoized_fn
def name(obj):
    """Try to find some reasonable name for the object.

    Candidates are tried in order: obj.name, obj.__name__, the name of
    obj's class, and finally str(obj); the first truthy one is returned.
    """
    return (getattr(obj, 'name', 0) or getattr(obj, '__name__', 0) or
            getattr(getattr(obj, '__class__', 0), '__name__', 0) or
            str(obj))
def isnumber(x):
    """Is x a number? True for any object that has an __int__ attribute."""
    try:
        getattr(x, '__int__')
    except AttributeError:
        return False
    return True
def issequence(x):
    """Is x a sequence? True for any collections.abc.Sequence (list, tuple, str, range, ...)."""
    return isinstance(x, collections.abc.Sequence)
def print_table(table, header=None, sep=' ', numfmt='{}'):
    """Print a list of lists as a table, so that columns line up nicely.

    header, if specified, will be printed as the first row.
    numfmt is the format for all numbers; you might want e.g. '{:.2f}'.
    (If you want different formats in different columns,
    don't use print_table.) sep is the separator between columns.

    Columns whose entry in the first data row is a number are
    right-justified, all others left-justified. The table passed in is not
    modified. Nothing is printed when there are no rows and no header.
    """
    rows = list(table)  # Work on a copy so the caller's table is left alone.
    if not rows and not header:
        return

    # Justification is decided by the first data row (the header, if that is all there is).
    first_row = rows[0] if rows else header
    justs = ['rjust' if isnumber(x) else 'ljust' for x in first_row]

    if header:
        rows.insert(0, header)

    rows = [[numfmt.format(x) if isnumber(x) else x for x in row]
            for row in rows]

    # Width of each column = longest rendered cell in that column.
    sizes = [max(map(len, column))
             for column in zip(*[map(str, row) for row in rows])]

    for row in rows:
        print(sep.join(getattr(
            str(x), j)(size) for (j, size, x) in zip(justs, sizes, row)))
def open_data(name, mode='r'):
    """Open the file `name` from the 'aima-data' directory next to this module.

    Returns the open file object; closing it is left to the caller.
    """
    module_dir = os.path.dirname(__file__)
    data_path = os.path.join(module_dir, 'aima-data', name)
    return open(data_path, mode=mode)
def failure_test(algorithm, tests):
    """Grade the given algorithm based on how many tests it passes.

    Most algorithms have arbitrary output on correct execution, which is
    difficult to check for correctness. On the other hand, a lot of
    algorithms output something particular on fail (for example, False, or
    None).

    tests is a list with each element in the form: (values, failure_output).
    The grade is the mean, over all tests, of 1 when the algorithm's output
    differs from failure_output and 0 when it equals it.
    """
    from statistics import mean

    outcomes = [int(algorithm(values) != failure_output)
                for values, failure_output in tests]
    return mean(outcomes)
# ______________________________________________________________________________
# Expressions
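# See https://docs.python.org/3/reference/expressions.html#operator-precedence
# See https://docs.python.org/3/reference/datamodel.html#special-method-names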
class Expr:
    """A mathematical expression with an operator and 0 or more arguments.
    op is a str like '+' or 'sin'; args are Expressions.
    Expr('x') or Symbol('x') creates a symbol (a nullary Expr).
    Expr('-', x) creates a unary; Expr('+', x, 1) creates a binary."""

    def __init__(self, op, *args):
        self.op = str(op)
        self.args = args

    # Operator overloads: each one builds a new Expr node instead of computing a value.

    def __neg__(self):
        return Expr('-', self)

    def __pos__(self):
        return Expr('+', self)

    def __invert__(self):
        return Expr('~', self)

    def __add__(self, rhs):
        return Expr('+', self, rhs)

    def __sub__(self, rhs):
        return Expr('-', self, rhs)

    def __mul__(self, rhs):
        return Expr('*', self, rhs)

    def __pow__(self, rhs):
        return Expr('**', self, rhs)

    def __mod__(self, rhs):
        return Expr('%', self, rhs)

    def __and__(self, rhs):
        return Expr('&', self, rhs)

    def __xor__(self, rhs):
        return Expr('^', self, rhs)

    def __rshift__(self, rhs):
        return Expr('>>', self, rhs)

    def __lshift__(self, rhs):
        return Expr('<<', self, rhs)

    def __truediv__(self, rhs):
        return Expr('/', self, rhs)

    def __floordiv__(self, rhs):
        return Expr('//', self, rhs)

    def __matmul__(self, rhs):
        return Expr('@', self, rhs)

    def __or__(self, rhs):
        """Allow both P | Q, and P |'==>'| Q."""
        if not isinstance(rhs, Expression):
            # rhs is an operator name such as '==>': wait for the right operand.
            return PartialExpr(rhs, self)
        return Expr('|', self, rhs)

    # Reverse operator overloads: used when the left operand is not an Expr.

    def __radd__(self, lhs):
        return Expr('+', lhs, self)

    def __rsub__(self, lhs):
        return Expr('-', lhs, self)

    def __rmul__(self, lhs):
        return Expr('*', lhs, self)

    def __rdiv__(self, lhs):
        return Expr('/', lhs, self)

    def __rpow__(self, lhs):
        return Expr('**', lhs, self)

    def __rmod__(self, lhs):
        return Expr('%', lhs, self)

    def __rand__(self, lhs):
        return Expr('&', lhs, self)

    def __rxor__(self, lhs):
        return Expr('^', lhs, self)

    def __ror__(self, lhs):
        return Expr('|', lhs, self)

    def __rrshift__(self, lhs):
        return Expr('>>', lhs, self)

    def __rlshift__(self, lhs):
        return Expr('<<', lhs, self)

    def __rtruediv__(self, lhs):
        return Expr('/', lhs, self)

    def __rfloordiv__(self, lhs):
        return Expr('//', lhs, self)

    def __rmatmul__(self, lhs):
        return Expr('@', lhs, self)

    def __call__(self, *args):
        "Call: if 'f' is a Symbol, then f(0) == Expr('f', 0)."
        if not self.args:
            return Expr(self.op, *args)
        raise ValueError('can only do a call for a Symbol, not an Expr')

    # Equality and repr

    def __eq__(self, other):
        "'x == y' evaluates to True or False; does not build an Expr."
        if not isinstance(other, Expr):
            return False
        return self.op == other.op and self.args == other.args

    def __hash__(self):
        return hash(self.op) ^ hash(self.args)

    def __repr__(self):
        op = self.op
        rendered = [str(arg) for arg in self.args]
        if op.isidentifier():  # f(x) or f(x, y)
            if not rendered:
                return op
            return '{}({})'.format(op, ', '.join(rendered))
        if len(rendered) == 1:  # -x or -(x + 1)
            return op + rendered[0]
        # (x - y)
        spaced_op = ' ' + op + ' '
        return '(' + spaced_op.join(rendered) + ')'
# An 'Expression' is either an Expr or a Number.
# Symbol is not an explicit type; it is any Expr with 0 args.
# Both names below are tuples of types, so they can be handed straight to isinstance().
Number = (int, float, complex)
Expression = (Expr, Number)
def Symbol(name):
    """Build a Symbol: an Expr whose op is name and which has no args."""
    symbol = Expr(name)
    return symbol
def symbols(names):
    """Return a tuple of Symbols; names is a comma/whitespace delimited str."""
    tokens = names.replace(',', ' ').split()
    return tuple(map(Symbol, tokens))
def subexpressions(x):
    """Yield the subexpressions of an Expression (including x itself).

    Nodes are produced in pre-order: an expression first, then the
    subexpressions of each of its args from left to right.
    """
    pending = [x]
    while pending:
        node = pending.pop()
        yield node
        if isinstance(node, Expr):
            # Push args reversed so that the leftmost arg is visited next.
            pending.extend(reversed(node.args))
def arity(expression):
    """Return the number of sub-expressions (args) of this expression; 0 for anything that is not an Expr."""
    return len(expression.args) if isinstance(expression, Expr) else 0
# For operators that are not defined in Python, we allow new InfixOps:
class PartialExpr:
    """Given 'P |'==>'| Q, first form PartialExpr('==>', P), then combine with Q."""

    def __init__(self, op, lhs):
        self.op = op
        self.lhs = lhs

    def __or__(self, rhs):
        # Second half of P |'op'| Q: the right operand completes the Expr.
        combined = Expr(self.op, self.lhs, rhs)
        return combined

    def __repr__(self):
        template = "PartialExpr('{}', {})"
        return template.format(self.op, self.lhs)
def expr(x):
    """Shortcut to create an Expression. x is a str in which:
    - identifiers are automatically defined as Symbols.
    - ==> is treated as an infix |'==>'|, as are <== and <=>.
    If x is not a str (e.g. it is already an Expression), it is returned
    unchanged. Example:

    >>> expr('P & Q ==> Q')
    ((P & Q) ==> Q)
    """
    if isinstance(x, str):
        # SECURITY: eval() runs arbitrary Python; only pass trusted strings.
        # Unknown identifiers are looked up in the defaultkeydict, which
        # creates a Symbol for each of them on first use.
        return eval(expr_handle_infix_ops(x), defaultkeydict(Symbol))
    return x
infix_ops = '==> <== <=>'.split()


def expr_handle_infix_ops(x):
    """Given a str, return a new str with ==> replaced by |'==>'|, etc.

    The same rewrite is applied for every operator listed in infix_ops.

    >>> expr_handle_infix_ops('P ==> Q')
    "P |'==>'| Q"
    """
    for op in infix_ops:
        x = x.replace(op, '|' + repr(op) + '|')
    return x
class defaultkeydict(collections.defaultdict):
    """Like defaultdict, but the default_factory is a function of the key.

    >>> d = defaultkeydict(len); d['four']
    4
    """

    def __missing__(self, key):
        # Match defaultdict: without a factory a missing key is a KeyError.
        if self.default_factory is None:
            raise KeyError(key)
        self[key] = result = self.default_factory(key)
        return result
class hashabledict(dict):
    """A dict that can be hashed, e.g. to serve as a dictionary key or set member.

    Every instance hashes to the same constant, so the hash value cannot
    change when the dictionary is mutated at runtime; the price is that all
    instances collide and are told apart by == comparison only.
    """

    def __hash__(self):
        return 1
# ______________________________________________________________________________
# Queues: Stack, FIFOQueue, PriorityQueue
# Stack and FIFOQueue are implemented as list and collections.deque
# PriorityQueue is implemented here
class PriorityQueue:
    """A Queue in which the minimum (or maximum) element (as determined by f and
    order) is returned first.

    If order is 'min', the item with minimum f(x) is
    returned first; if order is 'max', then it is the item with maximum f(x).
    Also supports dict-like lookup.

    Entries are stored in a heapq-managed list of (priority, item) pairs, so
    when two priorities are equal the items themselves get compared.
    """

    def __init__(self, order='min', f=lambda x: x):
        """Create an empty queue.

        Raises ValueError if order is neither 'min' nor 'max'.
        """
        self.heap = []

        if order == 'min':
            self.f = f
        elif order == 'max':  # now item with max f(x)
            self.f = lambda x: -f(x)  # will be popped first
        else:
            raise ValueError("order must be either 'min' or 'max'.")

    def append(self, item):
        """Insert item at its correct position."""
        heapq.heappush(self.heap, (self.f(item), item))

    def extend(self, items):
        """Insert each item in items at its correct position."""
        for item in items:
            self.append(item)

    def pop(self):
        """Pop and return the item (with min or max f(x) value,
        depending on the order). Raises Exception if the queue is empty."""
        if self.heap:
            return heapq.heappop(self.heap)[1]
        raise Exception('Trying to pop from empty PriorityQueue.')

    def __len__(self):
        """Return the current number of items in the PriorityQueue."""
        return len(self.heap)

    def __contains__(self, item):
        """Return True if item in PriorityQueue."""
        return (self.f(item), item) in self.heap

    def __getitem__(self, key):
        """Return the first stored item equal to key, or None if there is none."""
        for _, item in self.heap:
            if item == key:
                return item

    def __delitem__(self, key):
        """Delete the first occurrence of key.

        Raises ValueError if key is not in the queue.
        """
        self.heap.remove((self.f(key), key))
        # list.remove can break the heap ordering; rebuild it.
        heapq.heapify(self.heap)
# ______________________________________________________________________________
# Useful Shorthands
class Bool(int):
    """Just like `bool`, except values display as 'T' and 'F' instead of 'True' and 'False'"""

    def __repr__(self):
        if self:
            return 'T'
        return 'F'

    __str__ = __repr__
# Module-level shorthands for the two Bool values; they display as 'T' and 'F'.
T = Bool(True)
F = Bool(False)