Skip to content

Commit

Permalink
Basis object: abstract, monomial and bsplines added.
Browse files Browse the repository at this point in the history
  • Loading branch information
mcarbajo committed Dec 27, 2017
1 parent efbc7c0 commit 1a7d860
Showing 1 changed file with 291 additions and 0 deletions.
291 changes: 291 additions & 0 deletions fda/basis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,291 @@
"""This module defines a basis class for later represeting functional data
objects.
"""

from abc import ABC, abstractmethod

import numpy
import matplotlib.pyplot
import scipy.interpolate

__author__ = "Miguel Carbajo Berrocal"
__email__ = "miguel.carbajo@estudiante.uam.es"


class Basis(ABC):
"""Defines a functional basis.
Attributes:
def_range (tuple): a tuple of length 2 containing the initial and end
values of the interval over which the basis can be evaluated.
nbasis (int): number of functions in the basis.
"""

def __init__(self, def_range=(0, 1), nbasis=1):
"""Basis constructor.
Args:
def_range (tuple, optional): Definition of the interval where the
basis defines a space. Defaults to (0,1).
nbasis: Number of functions that form the basis. Defaults to 1.
"""
# Some checks
if def_range[0] >= def_range[1]:
raise ValueError("The interval {} is not well-defined.".format(
def_range))
if nbasis < 1:
raise ValueError("The number of basis has to be strictly "
"possitive.")
self.def_range = def_range
self.nbasis = nbasis
self._drop_index_lst = []

super().__init__()

@abstractmethod
def compute_matrix(self, eval_points):
"""Computes the basis at a list of values.
Args:
eval_points (array_like): List of points where the basis is
evaluated.
Returns:
(numpy.darray): Matrix whose rows are the values of the each
basis at the values specified in eval_points.
"""
pass

@abstractmethod
def evaluate(self, eval_points):
"""Evaluates the basis at a list of values.
Args:
eval_points (array_like): List of points where the basis is
evaluated.
Returns:
(numpy.darray): Matrix whose rows are the values of the each
basis at the values specified in eval_points.
"""
eval_points = numpy.asarray(eval_points)
if numpy.any(numpy.isnan(eval_points)):
raise ValueError("The list of points where the function is "
"evaluated can not contain nan values.")

# TODO include evaluation at derivatives
return self.compute_matrix(eval_points)

@abstractmethod
def plot(self, **kwargs):
"""Plots the basis object.
Args:
**kwargs: keyword arguments to be passed to the
matplotlib.pyplot.plot function.
Returns:
List of lines that were added to the plot.
"""

# Number of points where the basis are evaluated
npoints = max(501, 10 * self.nbasis)
# List of points where the basis are evaluated
eval_points = numpy.linspace(self.def_range[0], self.def_range[1],
npoints)
# Basis evaluated in the previous list of points
mat = self.evaluate(eval_points)
# Plot
return matplotlib.pyplot.plot(eval_points, mat.T, **kwargs)


class Monomial(Basis):
"""Monomial basis.
Basis formed by powers of the argument :math:`t`:
.. math::
1, t, t^2, t^3...
Examples:
Defines a monomial base over the interval :math:`[0, 5]` consisting
on the first 3 powers of :math:`t`: :math:`1, t, t^2`.
>>> bs_mon = Monomial((0,5), nbasis=3)
And evaluates all the functions in the basis in a list of descrete
values.
>>> bs_mon.evaluate([0, 1, 2])
array([[ 1., 1., 1.],
[ 0., 1., 2.],
[ 0., 1., 4.]])
"""
def __init__(self, def_range=(0, 1), nbasis=1):
super().__init__(def_range, nbasis)

def plot(self, **kwargs):
return super().plot(**kwargs)

def evaluate(self, eval_points):
return super().evaluate(eval_points)

def compute_matrix(self, eval_points):
"""Computes the basis at a list of values.
For each of the basis computes its value for each of the points in
the list passed as argument to the method.
Args:
eval_points (array_like): List of points where the basis is
evaluated.
Returns:
(numpy.darray): Matrix whose rows are the values of the each
basis at the values specified in eval_points.
"""
# Initialise empty matrix
mat = numpy.empty((self.nbasis, len(eval_points)))

# For each basis computes its value for each evaluation point
for i in range(self.nbasis):
mat[i] = eval_points ** i

return mat


class BSpline(Basis):
r"""BSpline basis.
BSpline basis elements are defined recursively as:
.. math::
B_{i, 0}(x) = 1, \textrm{if $t_i \le x < t_{i+1}$, otherwise $0$,}
B_{i, k}(x) = \frac{x - t_i}{t_{i+k} - t_i} B_{i, k-1}(x)
+ \frac{t_{i+k+1} - x}{t_{i+k+1} - t_{i+1}} B_{i+1, k-1}(x)
Where k indicates the order of the spline.
Implementation details: In order to work correctly on the ends of the
interval where the basis is defined is usually necessary to have these
values at the ends as knots several times. This class handles this so that
knots don't have to be duplicated.
Examples:
Constructs specifying number of basis and order.
>>> bss = BSpline(nbasis=8, order=4)
If no order is specified defaults to 4 because cubic splines are
the most used. So the previous example is the same as:
>>> bss = BSpline(nbasis=8)
It is also possible to create a BSpline basis specifying the knots.
>>> bss = BSpline(knots=[0, 0.2, 0.4, 0.6, 0.8, 1])
Once we create a basis we can evaluate each of its functions at a
set of points.
>>> bss = BSpline(nbasis=3, order=3)
>>> bss.evaluate([0, 0.5, 1])
array([[ 1. , 0.25, 0. ],
[ 0. , 0.5 , 0. ],
[ 0. , 0.25, 1. ]])
"""
def __init__(self, def_range=None, nbasis=None, order=4, knots=None):
"""BSpline basis constructor.
Args:
def_range (tuple, optional): Definition of the interval where the
basis defines a space. Defaults to (0,1) if knots are not
specified. If knots are specified defaults to the first and
last element of the knots.
nbasis (int, optional): Number of splines that form the basis.
order (int, optional): Order of the splines. One greater that
their degree. Defaults to 4 which mean cubic splines.
knots (list): List of knots of the splines. If def_range is
specified the first and last elements of the knots have to
match with it.
"""
# Knots default to equally space points in the def_range
if knots is None:
if nbasis is None:
raise ValueError("Must provide either a list of knots or the"
"number of basis.")
if def_range is None:
def_range = (0, 1)
knots = list(numpy.linspace(def_range[0], def_range[1],
nbasis - order + 2))
else:
knots = list(knots)
knots.sort()

# nbasis default to number of knots + order of the splines - 2
if nbasis is None:
nbasis = len(knots) + order - 2
if def_range is None:
def_range = (knots[0], knots[-1])

if def_range[0] != knots[0] or def_range[1] != knots[-1]:
raise ValueError("The ends of the knots must be the same as "
"the def_range.")

# Checks
if nbasis != order + len(knots) - 2:
raise ValueError("The number of basis has to equal the order "
"plus the number of knots minus 2.")

self.order = order
self.knots = knots
super().__init__(def_range, nbasis)

def plot(self, **kwargs):
return super().plot(**kwargs)

def evaluate(self, eval_points):
return super().evaluate(eval_points)

def compute_matrix(self, eval_points):
"""Computes the basis at a list of values.
It uses the scipy implementation of BSplines to compute the values
for each element of the basis.
Args:
eval_points (array_like): List of points where the basis is
evaluated.
Returns:
(numpy.darray): Matrix whose rows are the values of the each
basis at the values specified in eval_points.
"""
# Because of how bsplines are defined is necessary to duplicate the
# values at the ends of the knots as many times as the order.
knots = numpy.array([self.knots[0]] * (self.order - 1) + self.knots
+ [self.knots[-1]] * (self.order - 1))
# c is used the select which spline the function splev below computes
c = numpy.zeros(len(knots))

# Initialise empty matrix
mat = numpy.empty((self.nbasis, len(eval_points)))

# For each basis computes its value for each evaluation point
for i in range(self.nbasis):
# write a 1 in c in the position of the spline calculated in each
# iteration
c[i] = 1
# compute the spline
mat[i] = scipy.interpolate.splev(eval_points, (knots, c,
self.order - 1))
c[i] = 0

return mat

0 comments on commit 1a7d860

Please sign in to comment.