diff --git a/.coveragerc b/.coveragerc
index c7eee00..306da92 100644
--- a/.coveragerc
+++ b/.coveragerc
@@ -1,9 +1,9 @@
 [run]
-omit = 
+omit =
     */tests/*
     */version.py
 
 [report]
-omit = 
+omit =
     */tests/*
     */version.py
diff --git a/.flake8 b/.flake8
new file mode 100644
index 0000000..cd6d83f
--- /dev/null
+++ b/.flake8
@@ -0,0 +1,41 @@
+[flake8]
+ignore =
+    # No space before colon
+    E203
+    W503
+    # Variable name should be lower case. We have some single-letter variables that make more sense to be caps.
+    N806
+    # Class attribute shadows a python builtin -- not much chance of that causing a problem
+    A003
+    # First line should be in imperative mood -- cached_properties don't fit this bill.
+    D401
+    # Missing docstring in public class -- my docstrings are in the __init__ which seems to fail this ?
+    D101
+    # Otherwise it flags down **kwargs in docstrings.
+    RST210
+    N815
+    # Allow method names to start and end with __
+    N807
+    # allow method names to include upper-case characters
+    N802
+    # allow variable names to be upper-case
+    N803
+    # Missing docstring in public method: TODO: this should be removed and docs added!
+    D102
+max-line-length = 88
+max-complexity = 18
+inline-quotes = double
+docstring-convention=numpy
+rst-roles =
+    class
+    method
+    func
+    attr
+    mod
+
+# Ignoring F841 allows assigning local variables without using them. This happens
+# because we "eval" strings which is not caught by flake8. This should be fixed in
+# the tests.
+per-file-ignores =
+    src/linsolve/__init__.py: F401
+    tests/test_linsolve.py: F841
diff --git a/.github/dependabot.yml b/.github/dependabot.yml
new file mode 100644
index 0000000..592b5b2
--- /dev/null
+++ b/.github/dependabot.yml
@@ -0,0 +1,21 @@
+version: 2
+updates:
+  - package-ecosystem: github-actions
+    directory: "/"
+    schedule:
+      interval: monthly
+  - package-ecosystem: pip
+    directory: "/.github/workflows"
+    schedule:
+      interval: monthly
+  - package-ecosystem: pip
+    directory: "/docs"
+    schedule:
+      interval: monthly
+  - package-ecosystem: pip
+    directory: "/"
+    schedule:
+      interval: monthly
+    versioning-strategy: lockfile-only
+    allow:
+      - dependency-type: "all"
diff --git a/.github/labels.yml b/.github/labels.yml
new file mode 100644
index 0000000..f502701
--- /dev/null
+++ b/.github/labels.yml
@@ -0,0 +1,81 @@
+---
+# Label names are important: they are used by Release Drafter to decide
+# where to record them in the changelog, or whether to skip them.
+#
+# The repository labels will be automatically configured using this file and
+# the GitHub Action https://github.com/marketplace/actions/github-labeler.
+- name: API breaking + description: Breaking Changes + color: e305fc +- name: good first issue + description: Good for newcomers + color: 7057ff +- name: "status: duplicate" + description: This issue or pull request already exists + color: d93f0b +- name: "status: invalid" + description: Invalid issue or pull request + color: d93f0b +- name: "status: requires discussion" + description: Issue requires further discussion before implementation + color: f71bd2 +- name: "status: requires research" + description: Issue requires further research before implementation + color: "d93f0b" +- name: "status: wontfix" + description: Valid issue, but out of scope for future fixes + color: "d93f0b" +- name: "type: maint: dependencies" + description: Pull requests that update a dependency file + color: "0366d6" +- name: "type: maint: documentation" + description: Improvements or additions to documentation + color: "0075ca" +- name: "type: maint: refactoring" + description: Refactoring + color: ef67c4 +- name: "type: maint: removal" + description: Removals and Deprecations + color: 9ae7ea +- name: "type: maint: style" + description: Style + color: c120e5 +- name: "type: maint: build" + description: Build System and Dependencies + color: bfdadc +- name: "type: accuracy" + description: "Enhancement that improves accuracy" + color: "bfd4f2" +- name: "type: bug" + description: Something isn't working + color: d73a4a +- name: "type: feature: ui" + description: New feature that adds functionality for the user + color: "0e8a16" +- name: "type: feature: physical" + description: New feature that adds new analysis/physical model + color: "0e8a16" +- name: "type: testing" + description: Testing improvements + color: b1fc6f +- name: "type: performance: memory" + description: Performance improvements that reduce memory usage + color: "016175" +- name: "type: performance: cpu" + description: Performance improvements that reduce walltime + color: "016175" +- name: "type: ci" + description: Updates to CI (GH actions, RTD, etc.) 
+ color: "000000" +- name: "type: question" + description: A question about the code or documentation + color: d876e3 +- name: "priority: low" + description: Low priority + color: "BFD4F2" +- name: "priority: medium" + description: Medium priority + color: "E99695" +- name: "priority: high" + description: High priority + color: "B60205" diff --git a/.github/release-drafter.yml b/.github/release-drafter.yml new file mode 100644 index 0000000..45b84c9 --- /dev/null +++ b/.github/release-drafter.yml @@ -0,0 +1,62 @@ +# See https://github.com/marketplace/actions/release-drafter for configuration +categories: + - title: ":boom: Breaking Changes" + label: "API breaking" + - title: ":rocket: Features" + labels: + - "type: feature: ui" + - "type: feature: physical" + - title: ":fire: Removals and Deprecations" + label: "type: maint: removal" + - title: ":beetle: Fixes" + label: "type: bug" + - title: ":racehorse: Performance" + labels: + - "type: performance: memory" + - "type: performance: cpu" + - title: ":rotating_light: Testing" + label: "type: testing" + - title: ":construction_worker: Continuous Integration" + label: "type: ci" + - title: ":books: Documentation" + label: "type: maint: documentation" + - title: ":hammer: Refactoring" + label: "type: maint: refactoring" + - title: ":lipstick: Style" + label: "type: maint: style" + - title: ":package: Dependencies" + labels: + - "type: maint: dependencies" + - "type: maint: build" +name-template: 'v$RESOLVED_VERSION' +tag-template: 'v$RESOLVED_VERSION' +autolabeler: + - label: 'type: maint: documentation' + files: + - '*.md' + branch: + - '/.*docs{0,1}.*/' + - label: 'type: bug' + branch: + - '/fix.*/' + title: + - '/fix/i' + - label: "type: maint: removal" + title: + - "/remove .*/i" + - label: "type: ci" + files: + - '.github/*' + - '.pre-commit-config.yaml' + - '.coveragrc' + - label: "type: maint: style" + files: + - ".flake8" + - label: "type: maint: refactoring" + title: + - "/.* refactor.*/i" + +template: | + ## Changes + + $CHANGES diff --git a/.github/workflows/auto-merge-deps.yml b/.github/workflows/auto-merge-deps.yml new file mode 100644 index 0000000..7dd998e --- /dev/null +++ b/.github/workflows/auto-merge-deps.yml @@ -0,0 +1,14 @@ +name: auto-merge + +on: + pull_request: + +jobs: + auto-merge: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: ahmadnassri/action-dependabot-auto-merge@v2 + with: + target: minor + github-token: ${{ secrets.AUTO_MERGE }} diff --git a/.github/workflows/check-build.yml b/.github/workflows/check-build.yml new file mode 100644 index 0000000..b83e789 --- /dev/null +++ b/.github/workflows/check-build.yml @@ -0,0 +1,27 @@ +name: Check Distribution Build + +on: push + +jobs: + check-build: + name: Twine Check + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@master + with: + fetch-depth: 0 + + - uses: actions/setup-python@v4 + with: + python-version: "3.10" + + - name: Install Build Tools + run: pip install build twine + + - name: Build a binary wheel + run: | + python -m build . 
+ + - name: Check Distribution + run: | + twine check dist/* diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml new file mode 100644 index 0000000..dde342c --- /dev/null +++ b/.github/workflows/labeler.yml @@ -0,0 +1,17 @@ +name: Labeler + +on: push + +jobs: + labeler: + runs-on: ubuntu-latest + permissions: + issues: write + steps: + - name: Check out the repository + uses: actions/checkout@v3 + + - name: Run Labeler + uses: crazy-max/ghaction-github-labeler@v4.1.0 + with: + yaml-file: .github/labels.yml diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml index 8406dd3..3fbeba1 100644 --- a/.github/workflows/publish.yaml +++ b/.github/workflows/publish.yaml @@ -1,37 +1,30 @@ name: Publish Python distributions to PyPI on: - push: - tags: - - '*' + release: + types: [published] jobs: build-n-publish: - name: Build and publish to PyPI + name: Make Release on PyPI and Github runs-on: ubuntu-latest - env: - ENV_NAME: publish - PYTHON: "3.10" steps: - - uses: actions/checkout@main + - uses: actions/checkout@master with: fetch-depth: 0 - - name: Set up Python - uses: actions/setup-python@v3 + - uses: actions/setup-python@v4 with: - python-version: ${{ env.PYTHON }} + python-version: "3.10" - - name: Install build + - name: Install Build Tools run: pip install build - - name: Build a binary wheel and a source tarball + - name: Build a binary wheel run: | - python -m build + python -m build . - name: Publish to PyPI - if: startsWith(github.event.ref, 'refs/tags') uses: pypa/gh-action-pypi-publish@release/v1 with: - user: __token__ password: ${{ secrets.pypi_token }} diff --git a/.gitignore b/.gitignore index bdab5bb..ce747a8 100644 --- a/.gitignore +++ b/.gitignore @@ -7,4 +7,4 @@ dist/* *.ipynb_checkpoints* coverage.xml *.DS_Store -*/_version.py \ No newline at end of file +*/_version.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..bb656d2 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,57 @@ +exclude: '^docs/conf.py|^devel/|^tests/data/' + +ci: + autoupdate_schedule: monthly + +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + - id: trailing-whitespace + - id: check-added-large-files + - id: check-ast + - id: check-json + - id: check-merge-conflict + - id: check-xml + - id: check-yaml + - id: debug-statements + - id: end-of-file-fixer + - id: requirements-txt-fixer + - id: mixed-line-ending + args: ['--fix=no'] + +- repo: https://github.com/psf/black + rev: 23.7.0 + hooks: + - id: black + +- repo: https://github.com/PyCQA/flake8 + rev: 6.0.0 # pick a git hash / tag to point to + hooks: + - id: flake8 + additional_dependencies: + - flake8-quotes + - flake8-comprehensions + - flake8-builtins + - flake8-eradicate + - pep8-naming + - flake8-docstrings + - flake8-rst-docstrings + - flake8-rst + - flake8-copyright + +- repo: https://github.com/PyCQA/isort + rev: 5.12.0 + hooks: + - id: isort + +- repo: https://github.com/pre-commit/pygrep-hooks + rev: v1.10.0 + hooks: + - id: rst-backticks + +- repo: https://github.com/asottile/pyupgrade + rev: v3.9.0 + hooks: + - id: pyupgrade + args: [--py38-plus] diff --git a/MANIFEST.in b/MANIFEST.in deleted file mode 100644 index 6f06b28..0000000 --- a/MANIFEST.in +++ /dev/null @@ -1,3 +0,0 @@ -include *.md -include linsolve/VERSION -include linsolve/GIT_INFO \ No newline at end of file diff --git a/README.md b/README.md index 5184360..6f24f6f 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ See 
[linsolve_example.ipynb](linsolve_example.ipynb) for a tutorial on how to us Below we give a brief example on the general usage of `LinearSolver`. Assume we have a linear system of equations, with a data vector `y` containing measurements -and a model vector `b` containing parameters we would like to solve for. Let's simplify to +and a model vector `b` containing parameters we would like to solve for. Let's simplify to the problem of fitting a line to three data points, which amounts to solving for a slope and an offset. In this case, our linear system of equations can be written as @@ -82,7 +82,7 @@ issue log after verifying that the issue does not already exist. Comments on existing issues are also welcome. # Installation -Preferred method of installation is `pip install .` +Preferred method of installation is `pip install .` (or `pip install git+https://github.com/HERA-Team/linsolve`). This will install all dependencies. See below for manual management of dependencies. @@ -101,10 +101,9 @@ If you are developing `linsolve`, it is recommended to create a fresh environmen $ conda create -n linsolve python=3 $ conda activate linsolve $ conda env update -n linsolve -f environment.yml - $ pip install -e . - -This will install extra dependencies required for testing/development as well as the + $ pip install -e . + +This will install extra dependencies required for testing/development as well as the standard ones. To run tests, just run `nosetests` in the top-level directory. - diff --git a/pyproject.toml b/pyproject.toml index f0b92c4..4a0331a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,4 +60,4 @@ exclude = ''' | build | dist )/ -''' \ No newline at end of file +''' diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index d63fcf1..0000000 --- a/setup.cfg +++ /dev/null @@ -1,33 +0,0 @@ -[metadata] -name = linsolve -author = HERA Team -url=https://github.com/HERA-Team/linsolve -license=BSD -license_file=LICENSE -long_description = file: README.md -long_description_content_type = text/x-md -description= high-level tools for linearizing and solving systems of equations -classifiers= - Development Status :: 5 - Production/Stable - Intended Audience :: Science/Research - License :: OSI Approved :: BSD License - Programming Language :: Python :: 3 - Programming Language :: Python :: 3.6 - Programming Language :: Python :: 3.7 - Programming Language :: Python :: 3.8 - Programming Language :: Python :: 3.9 - Programming Language :: Python :: 3.10 - Programming Language :: Python :: 3.11 - Topic :: Scientific/Engineering :: Mathematics -keywords= - linear equations - optimal estimation - -[options] -packages = find: -include_package_data= True -package_dir = - =src -install_requires= - numpy>=1.23 - scipy \ No newline at end of file diff --git a/src/linsolve/__init__.py b/src/linsolve/__init__.py index b124f50..a6a4d41 100644 --- a/src/linsolve/__init__.py +++ b/src/linsolve/__init__.py @@ -1,5 +1,4 @@ -from pathlib import Path - +"""Solve linear equations.""" try: from importlib.metadata import PackageNotFoundError, version except ImportError: @@ -12,6 +11,6 @@ __version__ = version(__name__) except PackageNotFoundError: # pragma: no cover # package is not installed - __version__ = 'unknown' + __version__ = "unknown" -from .linsolve import * +from .linsolve import * # noqa diff --git a/src/linsolve/_version.py b/src/linsolve/_version.py deleted file mode 100644 index 45c0656..0000000 --- a/src/linsolve/_version.py +++ /dev/null @@ -1,4 +0,0 @@ -# file generated by setuptools_scm 
-# don't change, don't track in version control -__version__ = version = '1.0.1.dev8+gc98f991.d20230713' -__version_tuple__ = version_tuple = (1, 0, 1, 'dev8', 'gc98f991.d20230713') diff --git a/src/linsolve/linsolve.py b/src/linsolve/linsolve.py index 05e30ee..30bfc66 100644 --- a/src/linsolve/linsolve.py +++ b/src/linsolve/linsolve.py @@ -1,12 +1,12 @@ -'''Module providing high-level tools for linearizing and finding chi^2 minimizing -solutions to systems of equations. +"""High-level tools for linearizing and finding chi^2 minimizing systems of equations. Solvers: LinearSolver, LogProductSolver, and LinProductSolver. These generally follow the form: -> data = {'a1*x+b1*y': np.array([5.,7]), 'a2*x+b2*y': np.array([4.,6])} -> ls = LinearSolver(data, a1=1., b1=np.array([2.,3]), a2=2., b2=np.array([1.,2])) -> sol = ls.solve() + + >>> data = {'a1*x+b1*y': np.array([5.,7]), 'a2*x+b2*y': np.array([4.,6])} + >>> ls = LinearSolver(data, a1=1., b1=np.array([2.,3]), a2=2., b2=np.array([1.,2])) + >>> sol = ls.solve() where equations are passed in as a dictionary where each key is a string describing the equation (which is parsed according to python syntax) and each @@ -25,24 +25,29 @@ form 'x*y + y*z'. For more detail on usage, see linsolve_example.ipynb -''' +""" -import numpy as np import ast -from scipy.sparse import csc_matrix -import scipy.sparse.linalg -import scipy.linalg from copy import deepcopy from functools import reduce +import numpy as np +import scipy.linalg +import scipy.sparse.linalg +from scipy.sparse import csc_matrix + # Monkey patch for backward compatibility: # ast.Num deprecated in Python 3.8. Make it an alias for ast.Constant # if it gets removed. -if not hasattr(ast, 'Num'): +if not hasattr(ast, "Num"): ast.Num = ast.Constant + def ast_getterms(n): - '''Convert an AST parse tree into a list of terms. E.g. 'a*x1+b*x2' -> [[a,x1],[b,x2]]''' + """Convert an AST parse tree into a list of terms. + + E.g. 
'a*x1+b*x2' -> [[a,x1],[b,x2]] + """ if type(n) is ast.Name: return [[n.id]] elif type(n) is ast.Constant or type(n) is ast.Num: @@ -50,8 +55,8 @@ def ast_getterms(n): elif type(n) is ast.Expression: return ast_getterms(n.body) elif type(n) is ast.UnaryOp: - assert(type(n.op) is ast.USub) - return [[-1]+ast_getterms(n.operand)[0]] + assert type(n.op) is ast.USub + return [[-1] + ast_getterms(n.operand)[0]] elif type(n) is ast.BinOp: if type(n.op) is ast.Mult: return [ast_getterms(n.left)[0] + ast_getterms(n.right)[0]] @@ -60,213 +65,264 @@ def ast_getterms(n): elif type(n.op) is ast.Sub: return ast_getterms(n.left) + [[-1] + ast_getterms(n.right)[0]] else: - raise ValueError('Unsupported operation: %s' % str(n.op)) + raise ValueError("Unsupported operation: %s" % str(n.op)) else: - raise ValueError('Unsupported: %s' % str(n)) + raise ValueError("Unsupported: %s" % str(n)) + def get_name(s, isconj=False): - '''Parse variable names of form 'var_' as 'var' + conjugation.''' + """Parse variable names of form ``var_`` as ``var + conjugation``.""" if not type(s) is str: - if isconj: return str(s), False - else: return str(s) - if isconj: return s.rstrip('_'), s.endswith('_') # tag names ending in '_' for conj - else: return s.rstrip('_') # parse 'name_' as 'name' + conj + if isconj: + return str(s), False + else: + return str(s) + if isconj: + return s.rstrip("_"), s.endswith("_") # tag names ending in '_' for conj + else: + return s.rstrip("_") # parse 'name_' as 'name' + conj class Constant: - '''Container for constants (which can be arrays) in linear equations.''' + """Container for constants (which can be arrays) in linear equations.""" + def __init__(self, name, constants): self.name = get_name(name) - if type(name) is str: + if type(name) is str: self.val = constants[self.name] - else: + else: self.val = name - try: + try: self.dtype = self.val.dtype - except(AttributeError): + except AttributeError: self.dtype = type(self.val) + def shape(self): + """Return the shape of the constants.""" try: return self.val.shape - except(AttributeError): + except AttributeError: return () + def get_val(self, name=None): - '''Return value of constant. Handles conj if name='varname_' is requested - instead of name='varname'.''' + """Return value of constant. + + Handles conj if ``name='varname_'`` is requested instead of ``name='varname'``. 
+ """ if name is not None and type(name) is str: name, conj = get_name(name, isconj=True) - assert(self.name == name) - if conj: + assert self.name == name + if conj: return self.val.conjugate() - else: + else: return self.val - else: + else: return self.val class Parameter: - def __init__(self, name): - '''Container for parameters that are to be solved for.''' + """Container for parameters that are to be solved for.""" self.name = get_name(name) def sparse_form(self, name, eqnum, prm_order, prefactor, re_im_split=True): - xs,ys,vals = [], [], [] - # separated into real and imaginary parts iff one of the variables is conjugated with "_" - if re_im_split: - name,conj = get_name(name, True) - ordr,ordi = 2*prm_order[self.name], 2*prm_order[self.name]+1 - cr,ci = prefactor.real, prefactor.imag - i = 2*eqnum + xs, ys, vals = [], [], [] + # separated into real and imaginary parts iff one of the variables + # is conjugated with "_" + if re_im_split: + name, conj = get_name(name, True) + ordr, ordi = 2 * prm_order[self.name], 2 * prm_order[self.name] + 1 + cr, ci = prefactor.real, prefactor.imag + i = 2 * eqnum # (cr,ci) * (pr,pi) = (cr*pr-ci*pi, ci*pr+cr*pi) - xs.append(i); ys.append(ordr); vals.append(cr) # real component - xs.append(i+1); ys.append(ordr); vals.append(ci) # imag component + xs.append(i) + ys.append(ordr) + vals.append(cr) # real component + xs.append(i + 1) + ys.append(ordr) + vals.append(ci) # imag component if not conj: - xs.append(i); ys.append(ordi); vals.append(-ci) # imag component - xs.append(i+1); ys.append(ordi); vals.append(cr) # imag component + xs.append(i) + ys.append(ordi) + vals.append(-ci) # imag component + xs.append(i + 1) + ys.append(ordi) + vals.append(cr) # imag component else: - xs.append(i); ys.append(ordi); vals.append(ci) # imag component - xs.append(i+1); ys.append(ordi); vals.append(-cr) # imag component + xs.append(i) + ys.append(ordi) + vals.append(ci) # imag component + xs.append(i + 1) + ys.append(ordi) + vals.append(-cr) # imag component else: - xs.append(eqnum); ys.append(prm_order[self.name]); vals.append(prefactor) + xs.append(eqnum) + ys.append(prm_order[self.name]) + vals.append(prefactor) return xs, ys, vals - + def get_sol(self, x, prm_order): - '''Extract prm value from appropriate row of x solution.''' - if x.shape[0] > len(prm_order): # detect that we are splitting up real and imaginary parts - ordr,ordi = 2*prm_order[self.name], 2*prm_order[self.name]+1 - return {self.name: x[ordr] + np.complex64(1.0j)*x[ordi]} - else: return {self.name: x[prm_order[self.name]]} + """Extract prm value from appropriate row of x solution.""" + if x.shape[0] > len( + prm_order + ): # detect that we are splitting up real and imaginary parts + ordr, ordi = 2 * prm_order[self.name], 2 * prm_order[self.name] + 1 + return {self.name: x[ordr] + np.complex64(1.0j) * x[ordi]} + else: + return {self.name: x[prm_order[self.name]]} class LinearEquation: - '''Container for all prms and constants associated with a linear equation.''' + """Container for all prms and constants associated with a linear equation.""" + def __init__(self, val, **kwargs): self.val = val if type(val) is str: - n = ast.parse(val, mode='eval') + n = ast.parse(val, mode="eval") val = ast_getterms(n) - self.wgts = kwargs.pop('wgts',np.float32(1.)) + self.wgts = kwargs.pop("wgts", np.float32(1.0)) self.has_conj = False - constants = kwargs.pop('constants', kwargs) + constants = kwargs.pop("constants", kwargs) self.process_terms(val, constants) def process_terms(self, terms, constants): - 
'''Classify terms from parsed str as Constant or Parameter.''' + """Classify terms from parsed str as Constant or Parameter.""" self.consts, self.prms = {}, {} for term in terms: for t in term: try: self.add_const(t, constants) - except(KeyError): # must be a parameter then + except KeyError: # must be a parameter then p = Parameter(t) - self.has_conj |= get_name(t,isconj=True)[-1] # keep track if any prms are conj + self.has_conj |= get_name(t, isconj=True)[ + -1 + ] # keep track if any prms are conj self.prms[p.name] = p self.terms = self.order_terms(terms) def add_const(self, name, constants): - '''Manually add a constant of given name to internal list of constants. Value is drawn from constants.''' + """Manually add a constant of given name to internal list of constants. + + Value is drawn from constants. + """ n = get_name(name) - if n in constants and isinstance(constants[n], Constant): c = constants[n] - else: c = Constant(name, constants) # raises KeyError if not a constant + if n in constants and isinstance(constants[n], Constant): + c = constants[n] + else: + c = Constant(name, constants) # raises KeyError if not a constant self.consts[c.name] = c - + def order_terms(self, terms): - '''Reorder terms to obey (const1,const2,...,prm) ordering.''' - for L in terms: L.sort(key=lambda x: get_name(x) in self.prms) + """Reorder terms to obey (const1,const2,...,prm) ordering.""" + for L in terms: + L.sort(key=lambda x: get_name(x) in self.prms) # Validate that each term has exactly 1 unsolved parameter. for t in terms: - assert(get_name(t[-1]) in self.prms) + assert get_name(t[-1]) in self.prms for ti in t[:-1]: - assert(type(ti) is not str or get_name(ti) in self.consts) + assert type(ti) is not str or get_name(ti) in self.consts return terms - def eval_consts(self, const_list, wgts=np.float32(1.)): - '''Multiply out constants (and wgts) for placing in matrix.''' + def eval_consts(self, const_list, wgts=np.float32(1.0)): + """Multiply out constants (and wgts) for placing in matrix.""" const_list = [self.consts[get_name(c)].get_val(c) for c in const_list] - return wgts**.5 * reduce(lambda x,y: x*y, const_list, np.float32(1.)) - # this has the effect of putting the square root of the weights into each A matrix - #return 1. * reduce(lambda x,y: x*y, const_list, 1.) + return wgts**0.5 * reduce(lambda x, y: x * y, const_list, np.float32(1.0)) def sparse_form(self, eqnum, prm_order, re_im_split=True): - '''Returns the row and col information and the values of coefficients to build up - part of the sparse (CSR) reprentation of the A matrix corresponding to this equation.''' + """Return the row/col info and the values of coefficients. + + Intended to build up part of the sparse (CSR) reprentation of the A matrix + corresponding to this equation. 
+ """ xs, ys, vals = [], [], [] for term in self.terms: p = self.prms[get_name(term[-1])] f = self.eval_consts(term[:-1], self.wgts) - x,y,val = p.sparse_form(term[-1], eqnum, prm_order, f.flatten(), re_im_split) - xs += x; ys += y; vals += val + x, y, val = p.sparse_form( + term[-1], eqnum, prm_order, f.flatten(), re_im_split + ) + xs += x + ys += y + vals += val return xs, ys, vals - + def eval(self, sol): - '''Given dict of parameter solutions, evaluate this equation.''' + """Given dict of parameter solutions, evaluate this equation.""" rv = 0 for term in self.terms: total = self.eval_consts(term[:-1]) - name,isconj = get_name(term[-1],isconj=True) - if isconj: total *= np.conj(sol[name]) - else: total *= sol[name] + name, isconj = get_name(term[-1], isconj=True) + if isconj: + total *= np.conj(sol[name]) + else: + total *= sol[name] rv += total return rv - + def verify_weights(wgts, keys): - '''Given wgts and keys, ensure wgts have all keys and are all real. - If wgts == {} or None, return all 1s.''' + """Given wgts and keys, ensure wgts have all keys and are all real. + + If wgts == {} or None, return all 1s. + """ if wgts is None or wgts == {}: - return {k: np.float32(1.) for k in keys} - else: - for k in keys: - assert(k in wgts) # must have weights for all keys - assert(np.iscomplexobj(wgts[k]) == False) # tricky errors happen if wgts are complex - return wgts + return {k: np.float32(1.0) for k in keys} + for k in keys: + assert k in wgts # must have weights for all keys + # tricky errors happen if wgts are complex + assert not np.iscomplexobj(wgts[k]) + return wgts + def infer_dtype(values): - '''Given a list of values, return the appropriate numpy data - type for matrices, solutions. + """Get the appropriate numpy data type for matrices, solutions. + Returns float32, float64, complex64, or complex128. Python scalars will be treated float 32 or complex64 as appropriate. - Likewise, all int types will be treated as single precision floats.''' - + Likewise, all int types will be treated as single precision floats. + """ # ensure we are at least a float32 if we were passed integers - types = [np.dtype('float32')] + types = [np.dtype("float32")] # determine the data type of all values - all_types = list(set([v.dtype if hasattr(v,'dtype') else type(v) - for v in values])) + all_types = list({v.dtype if hasattr(v, "dtype") else type(v) for v in values}) # split types into numpy vs. python dtypes py_types = [t for t in all_types if not isinstance(t, np.dtype)] np_types = [t for t in all_types if isinstance(t, np.dtype)] # only use numpy dtypes that are floating/complex - types += [t for t in np_types if np.issubdtype(t, np.floating) - or np.issubdtype(t, np.complexfloating)] + types += [ + t + for t in np_types + if np.issubdtype(t, np.floating) or np.issubdtype(t, np.complexfloating) + ] # if any python constants are complex, promote to complex, but otherwise # don't promote to double if we have floats/doubles/ints in python if complex in py_types: - types.append(np.dtype('complex64')) + types.append(np.dtype("complex64")) # Use promote_types to determine the final floating/complex dtype dtype = reduce(np.promote_types, types) return dtype -class LinearSolver: +class LinearSolver: def __init__(self, data, wgts={}, sparse=False, **kwargs): """Set up a linear system of equations of the form 1*a + 2*b + 3*c = 4. - Args: - data: Dictionary that maps linear equations, written as valid python-interpetable strings - that include the variables in question, to (complex) numbers or numpy arrarys. 
- Variables with trailing underscores '_' are interpreted as complex conjugates. - wgts: Dictionary that maps equation strings from data to real weights to apply to each - equation. Weights are treated as 1/sigma^2. All equations in the data must have a weight - if wgts is not the default, {}, which means all 1.0s. - sparse: Boolean (default False). If True, represents A matrix sparsely (though AtA, Aty end up dense) - May be faster for certain systems of equations. - **kwargs: keyword arguments of constants (python variables in keys of data that - are not to be solved for) - - Returns: - None + Parameters + ---------- + data : dict + maps linear equations, written as valid python-interpetable strings + that include the variables in question, to (complex) numbers or numpy + arrays. Variables with trailing underscores '_' are interpreted as complex + conjugates. + wgts : dict + maps equation strings from data to real weights to apply to each + equation. Weights are treated as 1/sigma^2. All equations in the data must + have a weight if wgts is not the default, {}, which means all 1.0s. + sparse : bool + If True, represents A matrix sparsely (though AtA, Aty end up dense) + May be faster for certain systems of equations. + **kwargs: keyword arguments of constants (python variables in keys of data that + are not to be solved for) """ # XXX add ability to override datatype inference # see https://github.com/HERA-Team/linsolve/issues/30 @@ -274,84 +330,107 @@ def __init__(self, data, wgts={}, sparse=False, **kwargs): self.keys = list(data.keys()) self.sparse = sparse self.wgts = verify_weights(wgts, self.keys) - constants = kwargs.pop('constants', kwargs) - self.eqs = [LinearEquation(k,wgts=self.wgts[k], constants=constants) for k in self.keys] + constants = kwargs.pop("constants", kwargs) + self.eqs = [ + LinearEquation(k, wgts=self.wgts[k], constants=constants) for k in self.keys + ] # XXX add ability to have more than one measurment for a key=equation # see https://github.com/HERA-Team/linsolve/issues/14 self.prms = {} - for eq in self.eqs: + for eq in self.eqs: self.prms.update(eq.prms) self.consts = {} - for eq in self.eqs: - self.consts.update(eq.consts) + for eq in self.eqs: + self.consts.update(eq.consts) self.prm_order = {} - for i,p in enumerate(self.prms): + for i, p in enumerate(self.prms): self.prm_order[p] = i # infer dtype for later arrays - self.re_im_split = kwargs.pop('re_im_split',False) - #go through and figure out if any variables are conjugated - for eq in self.eqs: + self.re_im_split = kwargs.pop("re_im_split", False) + # go through and figure out if any variables are conjugated + for eq in self.eqs: self.re_im_split |= eq.has_conj - self.dtype = infer_dtype(list(self.data.values()) + list(self.consts.values()) + list(self.wgts.values())) - if self.re_im_split: self.dtype = np.real(np.ones(1, dtype=self.dtype)).dtype + self.dtype = infer_dtype( + list(self.data.values()) + + list(self.consts.values()) + + list(self.wgts.values()) + ) + if self.re_im_split: + self.dtype = np.real(np.ones(1, dtype=self.dtype)).dtype self.shape = self._shape() def _shape(self): - '''Get broadcast shape of constants, weights for last dim of A''' + """Get broadcast shape of constants, weights for last dim of A.""" sh = [] for k in self.consts: shk = self.consts[k].shape() - if len(shk) > len(sh): sh += [0] * (len(shk)-len(sh)) - for i in range(min(len(sh),len(shk))): sh[i] = max(sh[i],shk[i]) + if len(shk) > len(sh): + sh += [0] * (len(shk) - len(sh)) + for i in range(min(len(sh), len(shk))): 
+ sh[i] = max(sh[i], shk[i]) for k in self.wgts: - try: shk = self.wgts[k].shape - except(AttributeError): continue - if len(shk) > len(sh): sh += [0] * (len(shk)-len(sh)) - for i in range(min(len(sh),len(shk))): sh[i] = max(sh[i],shk[i]) + try: + shk = self.wgts[k].shape + except AttributeError: + continue + if len(shk) > len(sh): + sh += [0] * (len(shk) - len(sh)) + for i in range(min(len(sh), len(shk))): + sh[i] = max(sh[i], shk[i]) return tuple(sh) def _A_shape(self): - '''Get shape of A matrix (# eqs, # prms, data.size). Now always 3D.''' - try: - sh = (reduce(lambda x,y: x*y, self.shape),) # flatten data dimensions so A is always 3D - except(TypeError): + """Get shape of A matrix (# eqs, # prms, data.size). Now always 3D.""" + try: + sh = ( + reduce(lambda x, y: x * y, self.shape), + ) # flatten data dimensions so A is always 3D + except TypeError: sh = (1,) - if self.re_im_split: - return (2*len(self.eqs),2*len(self.prm_order))+sh - else: return (len(self.eqs),len(self.prm_order))+sh + if self.re_im_split: + return (2 * len(self.eqs), 2 * len(self.prm_order)) + sh + else: + return (len(self.eqs), len(self.prm_order)) + sh def get_A(self): - '''Return A matrix for A*x=y.''' + """Return A matrix for A*x=y.""" A = np.zeros(self._A_shape(), dtype=self.dtype) - xs,ys,vals = self.sparse_form() - ones = np.ones_like(A[0,0]) - #A[xs,ys] += [v * ones for v in vals] # This is broken when a single equation has the same param more than once - for x,y,v in zip(xs,ys,[v * ones for v in vals]): - A[x,y] += v # XXX ugly + xs, ys, vals = self.sparse_form() + ones = np.ones_like(A[0, 0]) + + # This is broken when a single equation has the same param more than once: + # --> A[xs,ys] += [v * ones for v in vals] + + for x, y, v in zip(xs, ys, [v * ones for v in vals]): + A[x, y] += v # XXX ugly return A def sparse_form(self): - '''Returns a lists of lists of row and col numbers and coefficients in order to - express the linear system as a CSR sparse matrix.''' + """Get a list of lists of row/col numbers and coefficients. + + Intended to express the linear system as a CSR sparse matrix. 
+ """ xs, ys, vals = [], [], [] - for i,eq in enumerate(self.eqs): - x,y,val = eq.sparse_form(i, self.prm_order, self.re_im_split) - xs += x; ys += y; vals += val + for i, eq in enumerate(self.eqs): + x, y, val = eq.sparse_form(i, self.prm_order, self.re_im_split) + xs += x + ys += y + vals += val return xs, ys, vals def get_A_sparse(self): - '''Fixes dimension needed for CSR sparse matrix representation.''' - xs,ys,vals = self.sparse_form() - ones = np.ones(self._A_shape()[2:],dtype=self.dtype) - for n,val in enumerate(vals): + """Fixes dimension needed for CSR sparse matrix representation.""" + xs, ys, vals = self.sparse_form() + ones = np.ones(self._A_shape()[2:], dtype=self.dtype) + for n, val in enumerate(vals): if not isinstance(val, np.ndarray) or val.size == 1: - vals[n] = ones*val + vals[n] = ones * val return np.array(xs), np.array(ys), np.array(vals, dtype=self.dtype).T - + def get_weighted_data(self): - '''Return y = data * wgt**.5 as a 2D vector, regardless of original data/wgt shape.''' - dtype = self.dtype # default + """Return y = data * wgt**.5 as a 2D vector, regardless of original shape.""" + dtype = self.dtype # default if self.re_im_split: if dtype == np.float32: dtype = np.complex64 @@ -360,76 +439,91 @@ def get_weighted_data(self): d = np.array([self.data[k] for k in self.keys], dtype=dtype) if len(self.wgts) > 0: w = np.array([self.wgts[k] for k in self.keys]) - w.shape += (1,) * (d.ndim-w.ndim) - d.shape += (1,) * (w.ndim-d.ndim) - d = d*(w**.5) - # this is w**.5 because A already has a factor of w**.5 in it, so - # (At N^-1 A)^1 At N^1 y ==> (At A)^1 At d (where d is the result of this + w.shape += (1,) * (d.ndim - w.ndim) + d.shape += (1,) * (w.ndim - d.ndim) + d = d * (w**0.5) + # this is w**.5 because A already has a factor of w**.5 in it, so + # (At N^-1 A)^1 At N^1 y ==> (At A)^1 At d (where d is the result of this # function and A is redefined to include half of the weights) - self._data_shape = d.shape[1:] # store for reshaping sols to original + self._data_shape = d.shape[1:] # store for reshaping sols to original # It's possible to have a zero-size d (if there are no keys) d.shape = (d.shape[0], -1) if d.size > 0 else (0, 0) - + if self.re_im_split: - rv = np.empty((2*d.shape[0],)+d.shape[1:], dtype=self.dtype) - rv[::2],rv[1::2] = d.real, d.imag + rv = np.empty((2 * d.shape[0],) + d.shape[1:], dtype=self.dtype) + rv[::2], rv[1::2] = d.real, d.imag return rv - else: return d - + else: + return d + def _invert_lsqr(self, A, y, rcond): - '''Use np.linalg.lstsq to solve a system of equations. Usually the best - performer, but for a fully-constrained system, 'solve' can be faster. Also, - there are a couple corner cases where lstsq is unstable but pinv works - for the same rcond. It seems particularly the case for single precision matrices.''' + """Solve a system of equations. + + Uses :func`np.linalg.lstsq`. Usually the best performer, but for a + fully-constrained system, 'solve' can be faster. Also, there are a couple + corner cases where lstsq is unstable but pinv works for the same rcond. It + seems particularly the case for single precision matrices. 
+ """ # add ability for lstsq to work on stacks of matrices # see https://github.com/HERA-Team/linsolve/issues/31 - #x = [np.linalg.lstsq(A[...,k], y[...,k], rcond=rcond)[0] for k in range(y.shape[-1])] # np.linalg.lstsq uses lapack gelsd and is slower: - # see https://stackoverflow.com/questions/55367024/fastest-way-of-solving-linear-least-squares - x = [scipy.linalg.lstsq(A[...,k], y[...,k], - cond=rcond, lapack_driver='gelsy')[0] - for k in range(y.shape[-1])] + # see https://stackoverflow.com/questions/55367024/ + # fastest-way-of-solving-linear-least-squares + x = [ + scipy.linalg.lstsq(A[..., k], y[..., k], cond=rcond, lapack_driver="gelsy")[ + 0 + ] + for k in range(y.shape[-1]) + ] return np.array(x).T def _invert_lsqr_sparse(self, xs_ys_vals, y, rcond): - '''Use the scipy.sparse lsqr solver.''' - # x = [scipy.sparse.linalg.lsqr(A[k], y[...,k], atol=rcond, btol=rcond)[0] for k in range(y.shape[-1])] # this is crazy slow for unknown reasons + """Use the scipy.sparse lsqr solver.""" + # Note that using scipy.sparse.linalg.lsqr is crazy slow for unknown reasons. AtA, Aty = self._get_AtA_Aty_sparse(xs_ys_vals, y) - x = [scipy.linalg.lstsq(AtA[k], Aty[k], - cond=rcond, lapack_driver='gelsy')[0] - for k in range(y.shape[-1])] + x = [ + scipy.linalg.lstsq(AtA[k], Aty[k], cond=rcond, lapack_driver="gelsy")[0] + for k in range(y.shape[-1]) + ] return np.array(x).T def _invert_pinv_shared(self, A, y, rcond): - '''Helper function for forming (At A)^-1 At. Uses pinv to invert.''' + """Helper function for forming (At A)^-1 At. Uses pinv to invert.""" At = A.T.conj() AtA = np.dot(At, A) AtAi = np.linalg.pinv(AtA, rcond=rcond, hermitian=True) - # x = np.einsum('ij,jk,kn->in', AtAi, At, y, optimize=True) # slow for small matrices - x = np.dot(AtAi, np.dot(At, y)) - return x + # Following is slow for small matrices: + # --> x = np.einsum('ij,jk,kn->in', AtAi, At, y, optimize=True) + return np.dot(AtAi, np.dot(At, y)) def _invert_pinv_shared_sparse(self, xs_ys_vals, y, rcond): - '''Use pinv to invert AtA matrix. Tends to be ~10x slower than lsqr for sparse matrices''' + """Use pinv to invert AtA matrix. + + Tends to be ~10x slower than lsqr for sparse matrices. + """ xs, ys, vals = xs_ys_vals A = csc_matrix((vals[0], (xs, ys))) At = A.T.conj() - AtA = At.dot(A).toarray() # make dense after sparse dot product + AtA = At.dot(A).toarray() # make dense after sparse dot product AtAi = np.linalg.pinv(AtA, rcond=rcond, hermitian=True) - x = np.dot(AtAi, At.dot(y)) - return x + return np.dot(AtAi, At.dot(y)) def _invert_pinv(self, A, y, rcond): - '''Use np.linalg.pinv to invert AtA matrix. Tends to be about ~3x slower than solve.''' + """Use np.linalg.pinv to invert AtA matrix. + + Tends to be about ~3x slower than solve. 
+ """ # As of numpy 1.14, pinv works on stacks of matrices - At = A.transpose([2,1,0]).conj() - AtA = [np.dot(At[k], A[...,k]) for k in range(y.shape[-1])] - # AtA = np.einsum('jin,jkn->nik', A.conj(), A, optimize=True) # slower + At = A.transpose([2, 1, 0]).conj() + AtA = [np.dot(At[k], A[..., k]) for k in range(y.shape[-1])] + + # This is slower: + # --> AtA = np.einsum('jin,jkn->nik', A.conj(), A, optimize=True) + AtAi = np.linalg.pinv(AtA, rcond=rcond, hermitian=True) - x = np.einsum('nij,njk,kn->in', AtAi, At, y, optimize=True) - return x + return np.einsum("nij,njk,kn->in", AtAi, At, y, optimize=True) def _get_AtA_Aty_sparse(self, xs_ys_vals, y): xs, ys, vals = xs_ys_vals @@ -439,10 +533,10 @@ def _get_AtA_Aty_sparse(self, xs_ys_vals, y): A = {} # can below be coded as a comprehension? need to be sure # to sum over repeat xs... - for _y,_x,_v in zip(ys, xs, vals.T): + for _y, _x, _v in zip(ys, xs, vals.T): try: A[_y][_x] = A[_y].get(_x, 0) + _v - except(KeyError): + except KeyError: A[_y] = {_x: _v} nprms = self._A_shape()[1] AtA = np.empty((y.shape[-1], nprms, nprms), dtype=self.dtype) @@ -451,42 +545,53 @@ def _get_AtA_Aty_sparse(self, xs_ys_vals, y): # Speedup over scipy.sparse b/c y[x] and A[i][x] are arrays for i in range(AtA.shape[1]): # 'i' is the column index, 'x' is the row index of A - Aty[:,i] = sum([A[i][x].conj() * y[x] for x in A[i]]) + Aty[:, i] = sum(A[i][x].conj() * y[x] for x in A[i]) for j in range(i, AtA.shape[1]): - AtA[:,i,j] = sum([A[i][x].conj() * A[j][x] - for x in A[i] if x in A[j]]) - AtA[:,j,i] = AtA[:,i,j].conj() # explicitly hermitian + AtA[:, i, j] = sum(A[i][x].conj() * A[j][x] for x in A[i] if x in A[j]) + AtA[:, j, i] = AtA[:, i, j].conj() # explicitly hermitian return AtA, Aty def _invert_pinv_sparse(self, xs_ys_vals, y, rcond): - '''Use pinv to invert AtA matrix. Tends to be ~10x slower than lsqr for sparse matrices''' + """Use pinv to invert AtA matrix. + + Tends to be ~10x slower than lsqr for sparse matrices. + """ AtA, Aty = self._get_AtA_Aty_sparse(xs_ys_vals, y) AtAi = np.linalg.pinv(AtA, rcond=rcond, hermitian=True) x = [np.dot(AtAi[k], Aty[k]) for k in range(y.shape[-1])] return np.array(x).T def _invert_solve(self, A, y, rcond): - '''Use np.linalg.solve to solve a system of equations. Requires a fully constrained - system of equations (i.e. doesn't deal with singular matrices). Can by ~1.5x faster that lstsq - for this case. 'rcond' is unused, but passed as an argument to match the interface of other - _invert methods.''' + """Use np.linalg.solve to solve a system of equations. + + Requires a fully constrained system of equations (i.e. doesn't deal with + singular matrices). Can by ~1.5x faster that lstsq for this case. 'rcond' + is unused, but passed as an argument to match the interface of other _invert + methods. 
+ """ # As of numpy 1.8, solve works on stacks of matrices - At = A.transpose([2,1,0]).conj() - AtA = [np.dot(At[k], A[...,k]) for k in range(y.shape[-1])] - Aty = [np.dot(At[k], y[...,k]) for k in range(y.shape[-1])] - return np.linalg.solve(AtA, Aty).T # sometimes errors if singular - #return scipy.linalg.solve(AtA, Aty, 'her') # slower by about 50% + At = A.transpose([2, 1, 0]).conj() + AtA = [np.dot(At[k], A[..., k]) for k in range(y.shape[-1])] + Aty = [np.dot(At[k], y[..., k]) for k in range(y.shape[-1])] + + # This is slower by about 50%: scipy.linalg.solve(AtA, Aty, 'her') + + # But this sometimes errors if singular: + return np.linalg.solve(AtA, Aty).T def _invert_solve_sparse(self, xs_ys_vals, y, rcond): - '''Use linalg.solve to solve a fully constrained (non-degenerate) system of equations. - Tends to be ~3x slower than lsqr for sparse matrices. 'rcond' is unused, but passed - as an argument to match the interface of other _invert methods.''' + """Use linalg.solve to solve a fully constrained (non-degenerate) system of eqs. + + Tends to be ~3x slower than lsqr for sparse matrices. 'rcond' is unused, + but passed as an argument to match the interface of other _invert methods. + """ AtA, Aty = self._get_AtA_Aty_sparse(xs_ys_vals, y) - #x = scipy.sparse.linalg.spsolve(AtA, Aty) # AtA and Aty don't end up being that sparse, usually + # AtA and Aty don't end up being that sparse, usually, so don't use this: + # --> x = scipy.sparse.linalg.spsolve(AtA, Aty) return np.linalg.solve(AtA, Aty).T def _invert_default(self, A, y, rcond): - '''The default inverter, currently 'pinv'.''' + """The default inverter, currently 'pinv'.""" # XXX doesn't deal w/ fact that individual matrices might # fail for one inversion method. # see https://github.com/HERA-Team/linsolve/issues/32 @@ -497,391 +602,501 @@ def _invert_default(self, A, y, rcond): return self._invert_pinv(A, y, rcond) def _invert_default_sparse(self, xs_ys_vals, y, rcond): - '''The default sparse inverter, currently 'pinv'.''' + """The default sparse inverter, currently 'pinv'.""" return self._invert_pinv_sparse(xs_ys_vals, y, rcond) - def solve(self, rcond=None, mode='default'): + def solve(self, rcond=None, mode="default"): """Compute x' = (At A)^-1 At * y, returning x' as dict of prms:values. - Args: - rcond: cutoff ratio for singular values useed in numpy.linalg.lstsq, numpy.linalg.pinv, - or (if sparse) as atol and btol in scipy.sparse.linalg.lsqr - Default: None (resolves to machine precision for inferred dtype) - mode: 'default', 'lsqr', 'pinv', or 'solve', selects which inverter to use, unless all equations share the same A matrix, in which case pinv is always used`. - 'default': alias for 'pinv'. - 'lsqr': uses numpy.linalg.lstsq to do an inversion-less solve. Usually - the fastest solver. - 'solve': uses numpy.linalg.solve to do an inversion-less solve. Fastest, - but only works for fully constrained systems of equations. - 'pinv': uses numpy.linalg.pinv to perform a pseudo-inverse and then solves. Can - sometimes be more numerically stable (but slower) than 'lsqr'. - All of these modes are superceded if the same system of equations applies - to all datapoints in an array. In this case, a inverse-based method is used so - that the inverted matrix can be re-used to solve all array indices. 
-
-        Returns:
-            sol: a dictionary of solutions with variables as keys
+
+        Parameters
+        ----------
+        rcond
+            cutoff ratio for singular values used in :func:`numpy.linalg.lstsq`,
+            :func:`numpy.linalg.pinv`, or (if sparse) as atol and btol in
+            :func:`scipy.sparse.linalg.lsqr`
+        mode : {'default', 'lsqr', 'pinv', or 'solve'}
+            selects which inverter to use, unless all equations share the same A matrix,
+            in which case pinv is always used:
+
+            * 'default': alias for 'pinv'.
+            * 'lsqr': uses numpy.linalg.lstsq to do an inversion-less solve. Usually
+              the fastest solver.
+            * 'solve': uses numpy.linalg.solve to do an inversion-less solve. Fastest,
+              but only works for fully constrained systems of equations.
+            * 'pinv': uses numpy.linalg.pinv to perform a pseudo-inverse and then
+              solves. Can sometimes be more numerically stable (but slower) than 'lsqr'.
+
+            All of these modes are superseded if the same system of equations applies
+            to all datapoints in an array. In this case, an inverse-based method is
+            used so that the inverted matrix can be re-used to solve all array indices.
+
+        Returns
+        -------
+        sol
+            a dictionary of solutions with variables as keys
         """
-        assert(mode in ['default','lsqr','pinv','solve'])
+        assert mode in ["default", "lsqr", "pinv", "solve"]
         if rcond is None:
             rcond = np.finfo(self.dtype).resolution
         y = self.get_weighted_data()
         if self.sparse:
             xs, ys, vals = self.get_A_sparse()
-            if vals.shape[0] == 1 and y.shape[-1] > 1: # reuse inverse
-                x = self._invert_pinv_shared_sparse((xs,ys,vals), y, rcond)
-            else: # we can't reuse inverses
-                if mode == 'default': _invert = self._invert_default_sparse
-                elif mode == 'lsqr': _invert = self._invert_lsqr_sparse
-                elif mode == 'pinv': _invert = self._invert_pinv_sparse
-                elif mode == 'solve': _invert = self._invert_solve_sparse
-                x = _invert((xs,ys,vals), y, rcond)
-        else:
+            if vals.shape[0] == 1 and y.shape[-1] > 1:  # reuse inverse
+                x = self._invert_pinv_shared_sparse((xs, ys, vals), y, rcond)
+            else:  # we can't reuse inverses
+                if mode == "default":
+                    _invert = self._invert_default_sparse
+                elif mode == "lsqr":
+                    _invert = self._invert_lsqr_sparse
+                elif mode == "pinv":
+                    _invert = self._invert_pinv_sparse
+                elif mode == "solve":
+                    _invert = self._invert_solve_sparse
+                x = _invert((xs, ys, vals), y, rcond)
+        else:
             A = self.get_A()
             Ashape = self._A_shape()
-            assert(A.ndim == 3)
-            if Ashape[-1] == 1 and y.shape[-1] > 1: # can reuse inverse
-                x = self._invert_pinv_shared(A[...,0], y, rcond)
-            else: # we can't reuse inverses
-                if mode == 'default': _invert = self._invert_default
-                elif mode == 'lsqr': _invert = self._invert_lsqr
-                elif mode == 'pinv': _invert = self._invert_pinv
-                elif mode == 'solve': _invert = self._invert_solve
+            assert A.ndim == 3
+            if Ashape[-1] == 1 and y.shape[-1] > 1:  # can reuse inverse
+                x = self._invert_pinv_shared(A[..., 0], y, rcond)
+            else:  # we can't reuse inverses
+                if mode == "default":
+                    _invert = self._invert_default
+                elif mode == "lsqr":
+                    _invert = self._invert_lsqr
+                elif mode == "pinv":
+                    _invert = self._invert_pinv
+                elif mode == "solve":
+                    _invert = self._invert_solve
                 x = _invert(A, y, rcond)
-        x.shape = x.shape[:1] + self._data_shape # restore to shape of original data
+        x.shape = x.shape[:1] + self._data_shape  # restore to shape of original data
         sol = {}
-        for p in list(self.prms.values()): sol.update(p.get_sol(x,self.prm_order))
+        for p in list(self.prms.values()):
+            sol.update(p.get_sol(x, self.prm_order))
         return sol
 
     def eval(self, sol, keys=None):
-        """Returns a dictionary evaluating data
keys to the current values given sol and consts. - Uses the stored data object unless otherwise specified.""" - if keys is None: keys = self.keys - elif type(keys) is str: keys = [keys] - elif type(keys) is dict: keys = list(keys.keys()) + """Get a dict mapping data keys to the current values given sol and consts. + + Uses the stored data object unless otherwise specified. + """ + if keys is None: + keys = self.keys + elif type(keys) is str: + keys = [keys] + elif type(keys) is dict: + keys = list(keys.keys()) result = {} for k in keys: eq = LinearEquation(k, **self.consts) result[k] = eq.eval(sol) return result - + def _chisq(self, sol, data, wgts, evaluator): """Internal adaptable chisq calculator.""" - if len(wgts) == 0: sigma2 = {k: 1.0 for k in list(data.keys())} #equal weights - else: sigma2 = {k: wgts[k]**-1 for k in list(wgts.keys())} + if len(wgts) == 0: + sigma2 = {k: 1.0 for k in list(data.keys())} # equal weights + else: + sigma2 = {k: wgts[k] ** -1 for k in list(wgts.keys())} evaluated = evaluator(sol, keys=data) chisq = 0 - for k in list(data.keys()): chisq += np.abs(evaluated[k]-data[k])**2 / sigma2[k] + for k in list(data.keys()): + chisq += np.abs(evaluated[k] - data[k]) ** 2 / sigma2[k] return chisq - + def chisq(self, sol, data=None, wgts=None): - """Compute Chi^2 = |obs - mod|^2 / sigma^2 for the specified solution. Weights are treated as 1/sigma^2. - wgts = {} means sigma = 1. Default uses the stored data and weights unless otherwise overwritten.""" - if data is None: + """Compute ``Chi^2 = |obs - mod|^2 / sigma^2`` for the specified solution. + + Weights are treated as 1/sigma^2. wgts = {} means sigma = 1. Default uses the + stored data and weights unless otherwise overwritten. + """ + if data is None: data = self.data - if wgts is None: + if wgts is None: wgts = self.wgts wgts = verify_weights(wgts, list(data.keys())) return self._chisq(sol, data, wgts, self.eval) - -# XXX need to add support for conjugated constants...maybe this already works because we have conjugated constants inherited from taylor expansion + +# XXX need to add support for conjugated constants...maybe this already works because +# we have conjugated constants inherited from taylor expansion # see https://github.com/HERA-Team/linsolve/issues/12 -def conjterm(term, mode='amp'): - '''Modify prefactor for conjugated terms, according to mode='amp|phs|real|imag'.''' - f = {'amp':1,'phs':-1,'real':1,'imag':1j}[mode] # if KeyError, mode was invalid - terms = [[f,t[:-1]] if t.endswith('_') else [t] for t in term] - return reduce(lambda x,y: x+y, terms) +def conjterm(term, mode="amp"): + """Modify prefactor for conjugated terms, for ``mode='amp|phs|real|imag'``.""" + f = {"amp": 1, "phs": -1, "real": 1, "imag": 1j}[ + mode + ] # if KeyError, mode was invalid + terms = [[f, t[:-1]] if t.endswith("_") else [t] for t in term] + return reduce(lambda x, y: x + y, terms) -def jointerms(terms): - '''String that joins lists of lists of terms as the sum of products.''' - return '+'.join(['*'.join(map(str,t)) for t in terms]) +def jointerms(terms): + """String that joins lists of lists of terms as the sum of products.""" + return "+".join(["*".join(map(str, t)) for t in terms]) -class LogProductSolver: +class LogProductSolver: def __init__(self, data, wgts={}, sparse=False, **kwargs): - """Set up a nonlinear system of equations of the form a*b = 1.0 to linearze via logarithm. 
- - Args: - data: Dictionary that maps nonlinear product equations, written as valid python-interpetable - strings that include the variables in question, to (complex) numbers or numpy arrarys. - Variables with trailing underscores '_' are interpreted as complex conjugates (e.g. x*y_ - parses as x * y.conj()). - wgts: Dictionary that maps equation strings from data to real weights to apply to each - equation. Weights are treated as 1/sigma^2. All equations in the data must have a weight - if wgts is not the default, {}, which means all 1.0s. - sparse: Boolean (default False). If True, represents A matrix sparsely (though AtA, Aty end up dense) - May be faster for certain systems of equations. - **kwargs: keyword arguments of constants (python variables in keys of data that - are not to be solved for) - - Returns: - None + """A log-solver for systems of equations of the form a*b = 1.0. + + Parameters + ---------- + data + dict mapping nonlinear product equations, written as valid + python-interpetable strings that include the variables in question, to + (complex) numbers or numpy arrarys. Variables with trailing underscores '_' + are interpreted as complex conjugates (e.g. x*y_ parses as x * y.conj()). + wgts + dict that maps equation strings from data to real weights to apply to each + equation. Weights are treated as 1/sigma^2. All equations in the data must + have a weight if wgts is not the default, {}, which means all 1.0s. + sparse : bool. + If True, represents A matrix sparsely (though AtA, Aty end up dense). + May be faster for certain systems of equations. + **kwargs: keyword arguments of constants (python variables in keys of data that + are not to be solved for) """ keys = list(data.keys()) wgts = verify_weights(wgts, keys) - eqs = [ast_getterms(ast.parse(k, mode='eval')) for k in keys] + eqs = [ast_getterms(ast.parse(k, mode="eval")) for k in keys] logamp, logphs = {}, {} logampw, logphsw = {}, {} - for k,eq in zip(keys,eqs): - assert(len(eq) == 1) # equations have to be purely products---no adds - eqamp = jointerms([conjterm([t],mode='amp') for t in eq[0]]) - eqphs = jointerms([conjterm([t],mode='phs') for t in eq[0]]) + for k, eq in zip(keys, eqs): + assert len(eq) == 1 # equations have to be purely products---no adds + eqamp = jointerms([conjterm([t], mode="amp") for t in eq[0]]) + eqphs = jointerms([conjterm([t], mode="phs") for t in eq[0]]) dk = np.log(data[k]) - logamp[eqamp],logphs[eqphs] = dk.real, dk.imag - try: logampw[eqamp],logphsw[eqphs] = wgts[k], wgts[k] - except(KeyError): pass - constants = kwargs.pop('constants', kwargs) - self.dtype = infer_dtype(list(data.values()) + list(constants.values()) + list(wgts.values())) + logamp[eqamp], logphs[eqphs] = dk.real, dk.imag + try: + logampw[eqamp], logphsw[eqphs] = wgts[k], wgts[k] + except KeyError: + pass + constants = kwargs.pop("constants", kwargs) + self.dtype = infer_dtype( + list(data.values()) + list(constants.values()) + list(wgts.values()) + ) logamp_consts, logphs_consts = {}, {} for k in constants: - c = np.log(constants[k]) # log unwraps complex circle at -pi + c = np.log(constants[k]) # log unwraps complex circle at -pi logamp_consts[k], logphs_consts[k] = c.real, c.imag - self.ls_amp = LinearSolver(logamp, logampw, sparse=sparse, constants=logamp_consts) + self.ls_amp = LinearSolver( + logamp, logampw, sparse=sparse, constants=logamp_consts + ) if self.dtype in (np.complex64, np.complex128): # XXX worry about enumrating these here without # explicitly ensuring that these are the support complex # dtypes. 
            # see https://github.com/HERA-Team/linsolve/issues/33
-            self.ls_phs = LinearSolver(logphs, logphsw, sparse=sparse, constants=logphs_consts)
+            self.ls_phs = LinearSolver(
+                logphs, logphsw, sparse=sparse, constants=logphs_consts
+            )
         else:
-            self.ls_phs = None # no phase term to solve for
+            self.ls_phs = None  # no phase term to solve for

-    def solve(self, rcond=None, mode='default'):
+    def solve(self, rcond=None, mode="default"):
         """Solve both amplitude and phase by taking the log of both sides to linearize.

-        Args:
-            rcond: cutoff ratio for singular values useed in numpy.linalg.lstsq, numpy.linalg.pinv,
-                or (if sparse) as atol and btol in scipy.sparse.linalg.lsqr
-                Default: None (resolves to machine precision for inferred dtype)
-            mode: 'default', 'lsqr', 'pinv', or 'solve', selects which inverter to use, unless all equations share the same A matrix, in which case pinv is always used`.
-                'default': alias for 'pinv'.
-                'lsqr': uses numpy.linalg.lstsq to do an inversion-less solve. Usually
-                    the fastest solver.
-                'solve': uses numpy.linalg.solve to do an inversion-less solve. Fastest,
-                    but only works for fully constrained systems of equations.
-                'pinv': uses numpy.linalg.pinv to perform a pseudo-inverse and then solves. Can
-                    sometimes be more numerically stable (but slower) than 'lsqr'.
-                All of these modes are superceded if the same system of equations applies
-                to all datapoints in an array. In this case, a inverse-based method is used so
-                that the inverted matrix can be re-used to solve all array indices.
-
-        Returns:
-            sol: a dictionary of complex solutions with variables as keys
+        Parameters
+        ----------
+        rcond
+            cutoff ratio for singular values used in :func:`numpy.linalg.lstsq`,
+            :func:`numpy.linalg.pinv`, or (if sparse) as atol and btol in
+            :func:`scipy.sparse.linalg.lsqr`. Default is to resolve to machine precision
+            for inferred dtype.
+        mode : {'default', 'lsqr', 'pinv', or 'solve'}
+            Selects which inverter to use, unless all equations share the same A matrix,
+            in which case pinv is always used.
+
+            * 'default': alias for 'pinv'.
+            * 'lsqr': uses numpy.linalg.lstsq to do an inversion-less solve. Usually
+              the fastest solver.
+            * 'solve': uses numpy.linalg.solve to do an inversion-less solve. Fastest,
+              but only works for fully constrained systems of equations.
+            * 'pinv': uses numpy.linalg.pinv to perform a pseudo-inverse and then
+              solves. Can sometimes be more numerically stable (but slower) than 'lsqr'.
+
+            All of these modes are superseded if the same system of equations applies
+            to all datapoints in an array. In this case, an inverse-based method is
+            used so that the inverted matrix can be re-used to solve all array indices.
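# Editor's note: a minimal usage sketch of the LogProductSolver documented above. It is
# not part of the diff; the call pattern mirrors TestLogProductSolver.test_solve in
# tests/test_linsolve.py, and the variable names x, y, z are illustrative only.
import numpy as np
import linsolve

x, y, z = np.exp(1.0 + 1j), np.exp(2.0 + 2j), np.exp(3.0 + 3j)
data = {"x*y*z": x * y * z, "x*y": x * y, "y*z": y * z}
wgts = {k: 1.0 for k in data}  # weights act as 1/sigma^2
ls = linsolve.LogProductSolver(data, wgts)
sol = ls.solve(mode="default")  # amplitude and phase are solved in log space
print(sol["x"], sol["y"], sol["z"])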
+
+        Returns
+        -------
+        sol
+            a dictionary of complex solutions with variables as keys
         """
         sol_amp = self.ls_amp.solve(rcond=rcond, mode=mode)
         if self.ls_phs is not None:
             sol_phs = self.ls_phs.solve(rcond=rcond, mode=mode)
-            sol = {k: np.exp(sol_amp[k] +
-                     np.complex64(1j) * sol_phs[k]).astype(self.dtype)
-                   for k in sol_amp.keys()}
+            return {
+                k: np.exp(sol_amp[k] + np.complex64(1j) * sol_phs[k]).astype(self.dtype)
+                for k in sol_amp.keys()
+            }
         else:
-            sol = {k: np.exp(sol_amp[k]).astype(self.dtype)
-                   for k in sol_amp.keys()}
-        return sol
+            return {k: np.exp(sol_amp[k]).astype(self.dtype) for k in sol_amp.keys()}

-def taylor_expand(terms, consts={}, prepend='d'):
-    '''First-order Taylor expand terms (product of variables or the sum of a
-    product of variables) wrt all parameters except those listed in consts.'''
-    taylors = []
-    for term in terms: taylors.append(term)
+
+def taylor_expand(terms, consts=None, prepend="d"):
+    """First-order Taylor expand terms.
+
+    Expands terms (a product of variables, or a sum of products of variables) with
+    respect to all parameters except those listed in consts.
+    """
+    if consts is None:
+        consts = {}
+    taylors = list(terms)
     for term in terms:
-        for i,t in enumerate(term):
-            if type(t) is not str or get_name(t) in consts: continue
-            taylors.append(term[:i]+[prepend+t]+term[i+1:])
+        for i, t in enumerate(term):
+            if type(t) is not str or get_name(t) in consts:
+                continue
+            taylors.append(term[:i] + [prepend + t] + term[i + 1 :])
     return taylors


 # XXX make a version of linproductsolver that taylor expands in e^{a+bi} form
 # see https://github.com/HERA-Team/linsolve/issues/15
 class LinProductSolver:
-    def __init__(self, data, sol0, wgts={}, sparse=False, **kwargs):
-        """Set up a nonlinear system of equations of the form a*b + c*d = 1.0
-        to linearize via Taylor expansion and solve iteratively using the Gauss-Newton algorithm.
-
-        Args:
-            data: Dictionary that maps nonlinear product equations, written as valid python-interpetable
-                strings that include the variables in question, to (complex) numbers or numpy arrarys.
-                Variables with trailing underscores '_' are interpreted as complex conjugates (e.g. x*y_
-                parses as x * y.conj()).
-            sol0: Dictionary mapping all variables (as keyword strings) to their starting guess values.
-                This is the point that is Taylor expanded around, so it must be relatively close to the
-                true chi^2 minimizing solution. In the same format as that produced by
-                linsolve.LogProductSolver.solve() or linsolve.LinProductSolver.solve().
-            wgts: Dictionary that maps equation strings from data to real weights to apply to each
-                equation. Weights are treated as 1/sigma^2. All equations in the data must have a weight
-                if wgts is not the default, {}, which means all 1.0s.
-            sparse: Boolean (default False). If True, represents A matrix sparsely (though AtA, Aty end up dense)
-                May be faster for certain systems of equations.
-            **kwargs: keyword arguments of constants (python variables in keys of data that
-                are not to be solved for)
-
-        Returns:
-            None
+        """Set up a nonlinear system of equations of the form a*b + c*d = 1.0.
+
+        Linearize via Taylor expansion and solve iteratively using the Gauss-Newton
+        algorithm.
+
+        Parameters
+        ----------
+        data
+            dict that maps nonlinear product equations, written as valid
+            python-interpretable strings that include the variables in question, to
+            (complex) numbers or numpy arrays. Variables with trailing underscores '_'
+            are interpreted as complex conjugates (e.g. x*y_ parses as x * y.conj()).
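# Editor's note: a small sketch of what taylor_expand (defined above) produces, not part
# of the diff. The expected outputs are taken from test_taylorexpand in
# tests/test_linsolve.py.
import linsolve

terms = linsolve.taylor_expand([["x", "y", "z"]], prepend="d")
# -> [['x', 'y', 'z'], ['dx', 'y', 'z'], ['x', 'dy', 'z'], ['x', 'y', 'dz']]
terms = linsolve.taylor_expand([[1, "y", "z"]], consts={"y": 3}, prepend="d")
# -> [[1, 'y', 'z'], [1, 'y', 'dz']]  (no derivative term for the constant 'y')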
+ sol0 + dict mapping all variables (as keyword strings) to their starting guess + values. This is the point that is Taylor expanded around, so it must be + relatively close to the true chi^2 minimizing solution. In the same format + as that produced by :func:`~LogProductSolver.solve()` or + :func:`~LinProductSolver.solve()`. + wgts + dict that maps equation strings from data to real weights to apply to each + equation. Weights are treated as 1/sigma^2. All equations in the data must + have a weight if wgts is not the default, {}, which means all 1.0s. + sparse : bool + If True, represents A matrix sparsely (though AtA, Aty end up dense) + May be faster for certain systems of equations. + **kwargs: keyword arguments of constants (python variables in keys of data that + are not to be solved for) """ # XXX make this something hard to collide with # see https://github.com/HERA-Team/linsolve/issues/17 - self.prepend = 'd' + self.prepend = "d" self.data, self.sparse, self.keys = data, sparse, list(data.keys()) self.wgts = verify_weights(wgts, self.keys) - constants = kwargs.pop('constants', kwargs) + constants = kwargs.pop("constants", kwargs) self.init_kwargs, self.sols_kwargs = constants, deepcopy(constants) self.sols_kwargs.update(sol0) self.all_terms, self.taylors, self.taylor_keys = self.gen_taylors() - self.build_solver(sol0) + self.build_solver(sol0) self.dtype = self.ls.dtype - + def gen_taylors(self, keys=None): - '''Parses all terms, performs a taylor expansion, and maps equation keys to taylor expansion keys.''' - if keys is None: keys = self.keys - all_terms = [ast_getterms(ast.parse(k, mode='eval')) for k in keys] + """Perform Taylor expansion, and map eq. keys to taylor expansion keys.""" + if keys is None: + keys = self.keys + all_terms = [ast_getterms(ast.parse(k, mode="eval")) for k in keys] taylors, taylor_keys = [], {} for terms, k in zip(all_terms, keys): taylor = taylor_expand(terms, self.init_kwargs, prepend=self.prepend) taylors.append(taylor) - taylor_keys[k] = jointerms(taylor[len(terms):]) + taylor_keys[k] = jointerms(taylor[len(terms) :]) return all_terms, taylors, taylor_keys def build_solver(self, sol0): - '''Builds a LinearSolver using the taylor expansions and all relevant constants. - Update it with the latest solutions.''' + """Builds a LinearSolver using the taylor expansions and all relevant constants. + + Update it with the latest solutions. + """ dlin, wlin = {}, {} for k in self.keys: tk = self.taylor_keys[k] - dlin[tk] = self.data[k] #in theory, this will always be replaced with data - ans0 before use - try: + # in theory, this will always be replaced with data - ans0 before use + dlin[tk] = self.data[k] + try: wlin[tk] = self.wgts[k] - except(KeyError): + except KeyError: pass - self.ls = LinearSolver(dlin, wgts=wlin, sparse=self.sparse, constants=self.sols_kwargs) - self.eq_dict = {eq.val: eq for eq in self.ls.eqs} #maps taylor string expressions to linear equations - #Now make sure every taylor equation has every relevant constant, even if they don't appear in the derivative terms. - for k,terms in zip(self.keys, self.all_terms): + self.ls = LinearSolver( + dlin, wgts=wlin, sparse=self.sparse, constants=self.sols_kwargs + ) + self.eq_dict = { + eq.val: eq for eq in self.ls.eqs + } # maps taylor string expressions to linear equations + # Now make sure every taylor equation has every relevant constant, even if + # they don't appear in the derivative terms. 
+ for k, terms in zip(self.keys, self.all_terms): for term in terms: for t in term: t_name = get_name(t) if t_name in self.sols_kwargs: - self.eq_dict[self.taylor_keys[k]].add_const(t_name, self.sols_kwargs) + self.eq_dict[self.taylor_keys[k]].add_const( + t_name, self.sols_kwargs + ) self._update_solver(sol0) def _update_solver(self, sol): - '''Update all constants in the internal LinearSolver and its LinearEquations based on new solutions. - Also update the residuals (data - ans0) for next iteration.''' + """Update the solver. + + Updates all constants in the internal LinearSolver and its LinearEquations + based on new solutions. Also update the residuals (data - ans0) for next + iteration. + """ self.sol0 = sol self.sols_kwargs.update(sol) for eq in self.ls.eqs: - for c in list(eq.consts.values()): - if c.name in sol: eq.consts[c.name].val = self.sols_kwargs[c.name] + for c in list(eq.consts.values()): + if c.name in sol: + eq.consts[c.name].val = self.sols_kwargs[c.name] self.ls.consts.update(eq.consts) ans0 = self._get_ans0(sol) - for k in ans0: self.ls.data[self.taylor_keys[k]] = self.data[k]-ans0[k] + for k in ans0: + self.ls.data[self.taylor_keys[k]] = self.data[k] - ans0[k] def _get_ans0(self, sol, keys=None): - '''Evaluate the system of equations given input sol. - Specify keys to evaluate only a subset of the equations.''' - if keys is None: + """Evaluate the system of equations given input sol. + + Specify keys to evaluate only a subset of the equations. + """ + if keys is None: keys = self.keys all_terms = self.all_terms taylors = self.taylors else: all_terms, taylors, _ = self.gen_taylors(keys) ans0 = {} - for k,taylor,terms in zip(keys,taylors,all_terms): + for k, taylor, terms in zip(keys, taylors, all_terms): eq = self.eq_dict[self.taylor_keys[k]] - ans0[k] = np.sum([eq.eval_consts(t) for t in taylor[:len(terms)]], axis=0) + ans0[k] = np.sum([eq.eval_consts(t) for t in taylor[: len(terms)]], axis=0) return ans0 - def solve(self, rcond=None, mode='default'): - '''Executes one iteration of a LinearSolver on the taylor-expanded system of - equations, improving sol0 to get sol. - - Args: - rcond: cutoff ratio for singular values useed in numpy.linalg.lstsq, numpy.linalg.pinv, - or (if sparse) as atol and btol in scipy.sparse.linalg.lsqr - Default: None (resolves to machine precision for inferred dtype) - mode: 'default', 'lsqr', 'pinv', or 'solve', selects which inverter to use, unless all equations share the same A matrix, in which case pinv is always used`. - 'default': alias for 'pinv'. - 'lsqr': uses numpy.linalg.lstsq to do an inversion-less solve. Usually - the fastest solver. - 'solve': uses numpy.linalg.solve to do an inversion-less solve. Fastest, - but only works for fully constrained systems of equations. - 'pinv': uses numpy.linalg.pinv to perform a pseudo-inverse and then solves. Can - sometimes be more numerically stable (but slower) than 'lsqr'. - All of these modes are superceded if the same system of equations applies - to all datapoints in an array. In this case, a inverse-based method is used so - that the inverted matrix can be re-used to solve all array indices. - - Returns: - sol: a dictionary of complex solutions with variables as keys - ''' + def solve(self, rcond=None, mode="default"): + """Execute one iteration of a LinearSolver. + + Executes on the taylor-expanded system of equations, improving sol0 to get sol. 
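# Editor's note: a minimal sketch of the Gauss-Newton setup described above, not part of
# the diff. It follows test_real_solve in tests/test_linsolve.py; sol0 is the starting
# guess that solve() improves by one linearized iteration.
import linsolve

x, y, z = 1.0, 2.0, 3.0
data = {"x*y": x * y, "x*z": x * z, "y*z": y * z}
sol0 = {"x": x + 0.01, "y": y + 0.01, "z": z + 0.01}  # must be near the true solution
ls = linsolve.LinProductSolver(data, sol0)
sol = ls.solve()  # one Taylor-expanded (Gauss-Newton) update of sol0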
+
+        Parameters
+        ----------
+        rcond
+            cutoff ratio for singular values used in :func:`numpy.linalg.lstsq`,
+            :func:`numpy.linalg.pinv`, or (if sparse) as atol and btol in
+            :func:`scipy.sparse.linalg.lsqr`. Default is to resolve to machine precision
+            for inferred dtype.
+        mode : {'default', 'lsqr', 'pinv', or 'solve'}
+            Selects which inverter to use, unless all equations share the same A matrix,
+            in which case pinv is always used.
+
+            * 'default': alias for 'pinv'.
+            * 'lsqr': uses numpy.linalg.lstsq to do an inversion-less solve. Usually
+              the fastest solver.
+            * 'solve': uses numpy.linalg.solve to do an inversion-less solve. Fastest,
+              but only works for fully constrained systems of equations.
+            * 'pinv': uses numpy.linalg.pinv to perform a pseudo-inverse and then
+              solves. Can sometimes be more numerically stable (but slower) than 'lsqr'.
+
+            All of these modes are superseded if the same system of equations applies
+            to all datapoints in an array. In this case, an inverse-based method is
+            used so that the inverted matrix can be re-used to solve all array indices.
+
+        Returns
+        -------
+        sol
+            a dictionary of complex solutions with variables as keys
+        """
         dsol = self.ls.solve(rcond=rcond, mode=mode)
         sol = {}
         for dk in dsol:
-            k = dk[len(self.prepend):]
+            k = dk[len(self.prepend) :]
             sol[k] = self.sol0[k] + dsol[dk]
         return sol
-
+
     def eval(self, sol, keys=None):
-        '''Returns a dictionary evaluating data keys to the current values given sol and consts.
-        Uses the stored data object unless otherwise specified.'''
-        if type(keys) is str: keys = [keys]
-        elif type(keys) is dict: keys = list(keys.keys())
+        """Get a dict mapping data keys to the current values given sol and consts.
+
+        Uses the stored data object unless otherwise specified.
+        """
+        if type(keys) is str:
+            keys = [keys]
+        elif type(keys) is dict:
+            keys = list(keys.keys())
         return self._get_ans0(sol, keys=keys)
-
+
     def chisq(self, sol, data=None, wgts=None):
-        '''Compute Chi^2 = |obs - mod|^2 / sigma^2 for the specified solution. Weights are treated as 1/sigma^2.
-        wgts = {} means sigma = 1. Uses the stored data and weights unless otherwise overwritten.'''
-        if data is None:
+        """Compute ``Chi^2 = |obs - mod|^2 / sigma^2`` for the specified solution.
+
+        Weights are treated as 1/sigma^2. wgts = {} means sigma = 1. Uses the stored
+        data and weights unless otherwise overwritten.
+        """
+        if data is None:
             data = self.data
-        if wgts is None:
+        if wgts is None:
             wgts = self.wgts
         wgts = verify_weights(wgts, list(data.keys()))
         return self.ls._chisq(sol, data, wgts, self.eval)

-    def solve_iteratively(self, conv_crit=None, maxiter=50, mode='default', verbose=False):
-        """Repeatedly solves and updates linsolve until convergence or maxiter is reached.
-        Returns a meta object containing the number of iterations, chisq, and convergence criterion.
-
-        Args:
-            conv_crit: A convergence criterion below which to stop iterating.
-                Converegence is measured L2-norm of the change in the solution of all the variables
-                divided by the L2-norm of the solution itself.
-                Default: None (resolves to machine precision for inferred dtype)
-            maxiter: An integer maximum number of iterations to perform before quitting. Default 50.
-            mode: 'default', 'lsqr', 'pinv', or 'solve', selects which inverter to use, unless all equations share the same A matrix, in which case pinv is always used`.
-                'default': alias for 'pinv'.
-                'lsqr': uses numpy.linalg.lstsq to do an inversion-less solve. Usually
-                    the fastest solver.
-                'solve': uses numpy.linalg.solve to do an inversion-less solve. Fastest,
-                    but only works for fully constrained systems of equations.
-                'pinv': uses numpy.linalg.pinv to perform a pseudo-inverse and then solves. Can
-                    sometimes be more numerically stable (but slower) than 'lsqr'.
-                All of these modes are superceded if the same system of equations applies
-                to all datapoints in an array. In this case, a inverse-based method is used so
-                that the inverted matrix can be re-used to solve all array indices.
-                verbose: print information about iterations
-
-        Returns: meta, sol
-            meta: a dictionary with metadata about the solution, including
-                iter: the number of iterations taken to reach convergence (or maxiter)
-                chisq: the chi^2 of the solution produced by the final iteration
-                conv_crit: the convergence criterion evaluated at the final iteration
-            sol: a dictionary of complex solutions with variables as keys
+    def solve_iteratively(
+        self, conv_crit=None, maxiter=50, mode="default", verbose=False
+    ):
+        """Repeatedly solve and update linsolve until convergence or maxiter is reached.
+
+        Parameters
+        ----------
+        conv_crit
+            A convergence criterion below which to stop iterating.
+            Convergence is measured as the L2-norm of the change in the solution of all
+            the variables divided by the L2-norm of the solution itself.
+            Default: None (resolves to machine precision for inferred dtype)
+        maxiter : int, optional (default 50)
+            An integer maximum number of iterations to perform before quitting.
+        mode : {'default', 'lsqr', 'pinv', or 'solve'}
+            Selects which inverter to use, unless all equations share the same A matrix,
+            in which case pinv is always used.
+
+            * 'default': alias for 'pinv'.
+            * 'lsqr': uses numpy.linalg.lstsq to do an inversion-less solve. Usually
+              the fastest solver.
+            * 'solve': uses numpy.linalg.solve to do an inversion-less solve. Fastest,
+              but only works for fully constrained systems of equations.
+            * 'pinv': uses numpy.linalg.pinv to perform a pseudo-inverse and then
+              solves. Can sometimes be more numerically stable (but slower) than 'lsqr'.
+
+            All of these modes are superseded if the same system of equations applies
+            to all datapoints in an array. In this case, an inverse-based method is
+            used so that the inverted matrix can be re-used to solve all array indices.
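# Editor's note: a short sketch of the iterative interface documented above, not part of
# the diff. It follows test_complex_conj_solve in tests/test_linsolve.py (which uses
# mode='lsqr'); meta carries 'iter', 'chisq' and 'conv_crit' for the final iteration.
import linsolve

x, y, z = 1.0 + 1j, 2.0 + 2j, 3.0 + 3j
data = {"x*y_": x * y.conjugate(), "x*z_": x * z.conjugate(), "y*z_": y * z.conjugate()}
sol0 = {"x": x + 0.01, "y": y + 0.01, "z": z + 0.01}
ls = linsolve.LinProductSolver(data, sol0)
meta, sol = ls.solve_iteratively(maxiter=50, mode="lsqr")
print(meta["iter"], meta["chisq"], meta["conv_crit"])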
+ + verbose : bool + print information about iterations + + Returns + ------- + meta : dict + A dictionary with metadata about the solution, including + * iter: the number of iterations taken to reach convergence (or maxiter) + * chisq: the chi^2 of the solution produced by the final iteration + * conv_crit: the convergence criterion evaluated at the final iteration + sol : dict + A dictionary of complex solutions with variables as keys """ if conv_crit is None: conv_crit = np.finfo(self.dtype).resolution - for i in range(1,maxiter+1): + for i in range(1, maxiter + 1): if verbose: - print('Beginning iteration %d/%d' % (i,maxiter)) - # rcond=conv_crit works because you can't get better precision than the accuracy of your inversion - # and vice versa, there's no real point in inverting with greater precision than you are shooting for + print("Beginning iteration %d/%d" % (i, maxiter)) + # rcond=conv_crit works because you can't get better precision than the + # accuracy of your inversion and vice versa, there's no real point in + # inverting with greater precision than you are shooting for new_sol = self.solve(rcond=conv_crit, mode=mode) - deltas = [new_sol[k]-self.sol0[k] for k in new_sol.keys()] - conv = np.linalg.norm(deltas, axis=0) / np.linalg.norm(list(new_sol.values()),axis=0) + deltas = [new_sol[k] - self.sol0[k] for k in new_sol.keys()] + conv = np.linalg.norm(deltas, axis=0) / np.linalg.norm( + list(new_sol.values()), axis=0 + ) if np.all(conv < conv_crit) or i == maxiter: - meta = {'iter': i, 'chisq': self.chisq(new_sol), 'conv_crit': conv} + meta = {"iter": i, "chisq": self.chisq(new_sol), "conv_crit": conv} return meta, new_sol self._update_solver(new_sol) diff --git a/src/linsolve/version.py b/src/linsolve/version.py index 7a7812a..2f13b19 100644 --- a/src/linsolve/version.py +++ b/src/linsolve/version.py @@ -1,6 +1,7 @@ """Version module, kept only for backwards-compatibility.""" import warnings + from . import __version__ version_info = __version__ @@ -13,5 +14,5 @@ warnings.warn("You should not rely on this module any more. Just use __version__.") -if __name__ == '__main__': +if __name__ == "__main__": print(__version__) diff --git a/tests/benchmark_A_large_shared.py b/tests/benchmark_A_large_shared.py index 0b4c8d8..ab8bf79 100644 --- a/tests/benchmark_A_large_shared.py +++ b/tests/benchmark_A_large_shared.py @@ -1,22 +1,26 @@ -'''Benchmark a system of equations with a large number of independent -parameters and a modest number of parallel instances that allow the -inverted A matrix to be reused.''' -import linsolve +"""Benchmark a system of equations with a large number of independent parameters. + +Use a modest number of parallel instances that allow the inverted A matrix to be reused. 
+""" + +import random +import time + import numpy as np -import time, random + +import linsolve np.random.seed(0) NPRMS = 2000 NEQS = 5000 SIZE = 100 -sparse = False # sparse: 1.30 s, dense: 1.50 s +sparse = False # sparse: 1.30 s, dense: 1.50 s -prms = {'g%d' % i: np.arange(SIZE) for i in range(NPRMS)} +prms = {"g%d" % i: np.arange(SIZE) for i in range(NPRMS)} prm_list = list(prms.keys()) -eqs = [('+'.join(['%s'] * 5)) % tuple(random.sample(prm_list, 5)) - for i in range(NEQS)] +eqs = [("+".join(["%s"] * 5)) % tuple(random.sample(prm_list, 5)) for _ in range(NEQS)] data = {eq: eval(eq, prms) for eq in eqs} @@ -25,7 +29,7 @@ sol = ls.solve() t1 = time.time() -print('Solved in {}'.format(t1-t0)) +print(f"Solved in {t1 - t0}") for k in prm_list: - assert np.mean(np.abs(sol[k] - prms[k])**2) < 1e-3 + assert np.mean(np.abs(sol[k] - prms[k]) ** 2) < 1e-3 diff --git a/tests/benchmark_A_large_shared_sparse.py b/tests/benchmark_A_large_shared_sparse.py index 647d481..2402af9 100644 --- a/tests/benchmark_A_large_shared_sparse.py +++ b/tests/benchmark_A_large_shared_sparse.py @@ -1,23 +1,27 @@ -'''Benchmark a system of equations with a large number of independent -parameters and a modest number of parallel instances that allow the -inverted A matrix to be reused. In this case, we test the speedup -for using a sparse representation.''' -import linsolve +"""Benchmark a system of equations with a large number of independent parameters. + +Use a modest number of parallel instances that allow the inverted A matrix to be reused. +In this case, we test the speedup for using a sparse representation. +""" + +import random +import time + import numpy as np -import time, random + +import linsolve np.random.seed(0) NPRMS = 2000 NEQS = 5000 SIZE = 100 -sparse = True # sparse: 1.30 s, dense: 1.50 s +sparse = True # sparse: 1.30 s, dense: 1.50 s -prms = {'g%d' % i: np.arange(SIZE) for i in range(NPRMS)} +prms = {"g%d" % i: np.arange(SIZE) for i in range(NPRMS)} prm_list = list(prms.keys()) -eqs = [('+'.join(['%s'] * 5)) % tuple(random.sample(prm_list, 5)) - for i in range(NEQS)] +eqs = [("+".join(["%s"] * 5)) % tuple(random.sample(prm_list, 5)) for _ in range(NEQS)] data = {eq: eval(eq, prms) for eq in eqs} @@ -26,7 +30,7 @@ sol = ls.solve() t1 = time.time() -print('Solved in {}'.format(t1-t0)) +print(f"Solved in {t1 - t0}") for k in prm_list: - assert np.mean(np.abs(sol[k] - prms[k])**2) < 1e-3 + assert np.mean(np.abs(sol[k] - prms[k]) ** 2) < 1e-3 diff --git a/tests/benchmark_A_small_shared.py b/tests/benchmark_A_small_shared.py index bf34580..ace0f28 100644 --- a/tests/benchmark_A_small_shared.py +++ b/tests/benchmark_A_small_shared.py @@ -1,9 +1,14 @@ -'''Benchmark a system of equations with a small number of independent -parameters and a large number of parallel instances that allow the -inverted A matrix to be reused.''' -import linsolve +"""Benchmark a system of equations with a small number of independent parameters. + +Use a large number of parallel instances that allow the inverted A matrix to be reused. 
+""" + +import random +import time + import numpy as np -import time, random + +import linsolve np.random.seed(0) @@ -12,11 +17,10 @@ SIZE = 1000000 # dense: 0.48 s -prms = {'g%d' % i: np.arange(SIZE) for i in range(NPRMS)} +prms = {"g%d" % i: np.arange(SIZE) for i in range(NPRMS)} prm_list = list(prms.keys()) -eqs = [('+'.join(['%s'] * 5)) % tuple(random.sample(prm_list, 5)) - for i in range(NEQS)] +eqs = [("+".join(["%s"] * 5)) % tuple(random.sample(prm_list, 5)) for _ in range(NEQS)] data = {eq: eval(eq, prms) for eq in eqs} @@ -25,7 +29,7 @@ sol = ls.solve() t1 = time.time() -print('Solved in {}'.format(t1-t0)) +print(f"Solved in {t1 - t0}") for k in prm_list: - assert np.mean(np.abs(sol[k] - prms[k])**2) < 1e-3 + assert np.mean(np.abs(sol[k] - prms[k]) ** 2) < 1e-3 diff --git a/tests/benchmark_A_small_shared_sparse.py b/tests/benchmark_A_small_shared_sparse.py index 56552d5..948ac1b 100644 --- a/tests/benchmark_A_small_shared_sparse.py +++ b/tests/benchmark_A_small_shared_sparse.py @@ -1,10 +1,15 @@ -'''Benchmark a system of equations with a small number of independent -parameters and a large number of parallel instances that allow the -inverted A matrix to be reused. In this case, we benchmark the -case when a sparse representation of A is used.''' -import linsolve +"""Benchmark a system of equations with a small number of independent parameters. + +Use a large number of parallel instances that allow the inverted A matrix to be reused. +In this case, we benchmark the case when a sparse representation of A is used. +""" + +import random +import time + import numpy as np -import time, random + +import linsolve np.random.seed(0) @@ -13,11 +18,10 @@ SIZE = 1000000 # sparse: 0.82 s -prms = {'g%d' % i: np.arange(SIZE) for i in range(NPRMS)} +prms = {"g%d" % i: np.arange(SIZE) for i in range(NPRMS)} prm_list = list(prms.keys()) -eqs = [('+'.join(['%s'] * 3)) % tuple(random.sample(prm_list, 3)) - for i in range(NEQS)] +eqs = [("+".join(["%s"] * 3)) % tuple(random.sample(prm_list, 3)) for _ in range(NEQS)] data = {eq: eval(eq, prms) for eq in eqs} @@ -26,7 +30,7 @@ sol = ls.solve() t1 = time.time() -print('Solved in {}'.format(t1-t0)) +print(f"Solved in {t1 - t0}") for k in prm_list: - assert np.mean(np.abs(sol[k] - prms[k])**2) < 1e-3 + assert np.mean(np.abs(sol[k] - prms[k]) ** 2) < 1e-3 diff --git a/tests/benchmark_A_small_unique.py b/tests/benchmark_A_small_unique.py index cfc7248..949b15d 100644 --- a/tests/benchmark_A_small_unique.py +++ b/tests/benchmark_A_small_unique.py @@ -1,34 +1,40 @@ -'''Benchmark a system of equations with a small number of independent -parameters and a large number of instances with different coefficients -so that inversions of the A matrix cannot be reused.''' -import linsolve +"""Benchmark a system of equations with a small number of independent parameters. + +Use a large number of instances with different coefficients so that inversions of the +A matrix cannot be reused. 
+""" +import random +import time + import numpy as np -import time, random + +import linsolve np.random.seed(0) NPRMS = 10 NEQS = 100 SIZE = 100000 -#MODE = 'solve' # dense:1.56 s -MODE = 'lsqr' # dense:5.4 s -#MODE = 'pinv' # dense:3.4 s +# --> MODE = 'solve' # dense:1.56 s +MODE = "lsqr" # dense:5.4 s +# --> MODE = 'pinv' # dense:3.4 s -prms = {'g%d' % i: np.arange(SIZE) for i in range(NPRMS)} +prms = {"g%d" % i: np.arange(SIZE) for i in range(NPRMS)} prm_list = list(prms.keys()) -prms['c0'] = np.arange(SIZE) +prms["c0"] = np.arange(SIZE) -eqs = [('+c0*'.join(['%s'] * 5)) % tuple(random.sample(prm_list, 5)) - for i in range(NEQS)] +eqs = [ + ("+c0*".join(["%s"] * 5)) % tuple(random.sample(prm_list, 5)) for i in range(NEQS) +] data = {eq: eval(eq, prms) for eq in eqs} -ls = linsolve.LinearSolver(data, c0=prms['c0'], sparse=False) +ls = linsolve.LinearSolver(data, c0=prms["c0"], sparse=False) t0 = time.time() sol = ls.solve(mode=MODE) t1 = time.time() -print('Solved in {}'.format(t1-t0)) +print(f"Solved in {t1 - t0}") for k in prm_list: - assert np.mean(np.abs(sol[k] - prms[k])**2) < 1e-3 + assert np.mean(np.abs(sol[k] - prms[k]) ** 2) < 1e-3 diff --git a/tests/benchmark_A_small_unique_sparse.py b/tests/benchmark_A_small_unique_sparse.py index 3a4df55..a61c7aa 100644 --- a/tests/benchmark_A_small_unique_sparse.py +++ b/tests/benchmark_A_small_unique_sparse.py @@ -1,35 +1,42 @@ -'''Benchmark a system of equations with a small number of independent -parameters and a large number of instances with different coefficients -so that inversions of the A matrix cannot be reused. In this case, we -benchmark the case when a sparse representation of A is used.''' -import linsolve +"""Benchmark a system of equations with a small number of independent parameters. + +Use a large number of instances with different coefficients so that inversions of the +A matrix cannot be reused. In this case, we benchmark the case when a sparse +representation of A is used. 
+""" +import random +import time + import numpy as np -import time, random + +import linsolve np.random.seed(0) NPRMS = 10 NEQS = 100 SIZE = 100000 -MODE = 'solve' # sparse: 0.43 s -#MODE = 'lsqr' # sparse: 3.8 s -#MODE = 'pinv' # sparse: 2.72 s +MODE = "solve" # sparse: 0.43 s + +# --> MODE = 'lsqr' # sparse: 3.8 +# --> MODE = 'pinv' # sparse: 2.72 s -prms = {'g%d' % i: np.arange(SIZE) for i in range(NPRMS)} +prms = {"g%d" % i: np.arange(SIZE) for i in range(NPRMS)} prm_list = list(prms.keys()) -prms['c0'] = np.arange(SIZE) +prms["c0"] = np.arange(SIZE) -eqs = [('+c0*'.join(['%s'] * 2)) % tuple(random.sample(prm_list, 2)) - for i in range(NEQS)] +eqs = [ + ("+c0*".join(["%s"] * 2)) % tuple(random.sample(prm_list, 2)) for i in range(NEQS) +] data = {eq: eval(eq, prms) for eq in eqs} -ls = linsolve.LinearSolver(data, c0=prms['c0'], sparse=True) +ls = linsolve.LinearSolver(data, c0=prms["c0"], sparse=True) t0 = time.time() sol = ls.solve(mode=MODE) t1 = time.time() -print('Solved in {}'.format(t1-t0)) +print(f"Solved in {t1 - t0}") for k in prm_list: - assert np.mean(np.abs(sol[k] - prms[k])**2) < 1e-3 + assert np.mean(np.abs(sol[k] - prms[k]) ** 2) < 1e-3 diff --git a/tests/test_linsolve.py b/tests/test_linsolve.py index e445a2c..09928fa 100644 --- a/tests/test_linsolve.py +++ b/tests/test_linsolve.py @@ -1,432 +1,456 @@ +"""Test the linsolve module.""" +import ast +import io +import sys + +import numpy as np import pytest + import linsolve -import numpy as np -import ast, io, sys -class TestLinSolve(): +class TestLinSolve: def test_ast_getterms(self): - n = ast.parse('x+y',mode='eval') + n = ast.parse("x+y", mode="eval") terms = linsolve.ast_getterms(n) - assert terms == [['x'],['y']] - n = ast.parse('x-y',mode='eval') + assert terms == [["x"], ["y"]] + n = ast.parse("x-y", mode="eval") terms = linsolve.ast_getterms(n) - assert terms == [['x'],[-1,'y']] - n = ast.parse('3*x-y',mode='eval') + assert terms == [["x"], [-1, "y"]] + n = ast.parse("3*x-y", mode="eval") terms = linsolve.ast_getterms(n) - assert terms, [[3,'x'],[-1,'y']] + assert terms, [[3, "x"], [-1, "y"]] def test_unary(self): - n = ast.parse('-x+y',mode='eval') + n = ast.parse("-x+y", mode="eval") terms = linsolve.ast_getterms(n) - assert terms == [[-1,'x'],['y']] + assert terms == [[-1, "x"], ["y"]] def test_multiproducts(self): - n = ast.parse('a*x+a*b*c*y',mode='eval') + n = ast.parse("a*x+a*b*c*y", mode="eval") terms = linsolve.ast_getterms(n) - assert terms == [['a','x'],['a','b','c','y']] - n = ast.parse('-a*x+a*b*c*y',mode='eval') + assert terms == [["a", "x"], ["a", "b", "c", "y"]] + n = ast.parse("-a*x+a*b*c*y", mode="eval") terms = linsolve.ast_getterms(n) - assert terms == [[-1,'a','x'],['a','b','c','y']] - n = ast.parse('a*x-a*b*c*y',mode='eval') + assert terms == [[-1, "a", "x"], ["a", "b", "c", "y"]] + n = ast.parse("a*x-a*b*c*y", mode="eval") terms = linsolve.ast_getterms(n) - assert terms == [['a','x'],[-1,'a','b','c','y']] + assert terms == [["a", "x"], [-1, "a", "b", "c", "y"]] def test_taylorexpand(self): - terms = linsolve.taylor_expand([['x','y','z']],prepend='d') - assert terms == [['x','y','z'],['dx','y','z'],['x','dy','z'],['x','y','dz']] - terms = linsolve.taylor_expand([[1,'y','z']],prepend='d') - assert terms == [[1,'y','z'],[1,'dy','z'],[1,'y','dz']] - terms = linsolve.taylor_expand([[1,'y','z']],consts={'y':3}, prepend='d') - assert terms == [[1,'y','z'],[1,'y','dz']] + terms = linsolve.taylor_expand([["x", "y", "z"]], prepend="d") + assert terms == [ + ["x", "y", "z"], + ["dx", "y", "z"], + ["x", 
"dy", "z"], + ["x", "y", "dz"], + ] + terms = linsolve.taylor_expand([[1, "y", "z"]], prepend="d") + assert terms == [[1, "y", "z"], [1, "dy", "z"], [1, "y", "dz"]] + terms = linsolve.taylor_expand([[1, "y", "z"]], consts={"y": 3}, prepend="d") + assert terms == [[1, "y", "z"], [1, "y", "dz"]] def test_verify_weights(self): - assert linsolve.verify_weights({},['a']) == {'a':1} - assert linsolve.verify_weights(None,['a']) == {'a':1} - assert linsolve.verify_weights({'a':10.0},['a']) == {'a': 10.0} + assert linsolve.verify_weights({}, ["a"]) == {"a": 1} + assert linsolve.verify_weights(None, ["a"]) == {"a": 1} + assert linsolve.verify_weights({"a": 10.0}, ["a"]) == {"a": 10.0} with pytest.raises(AssertionError): - linsolve.verify_weights({'a':1.0+1.0j}, ['a']) + linsolve.verify_weights({"a": 1.0 + 1.0j}, ["a"]) with pytest.raises(AssertionError): - linsolve.verify_weights({'a':1.0}, ['a', 'b']) + linsolve.verify_weights({"a": 1.0}, ["a", "b"]) def test_infer_dtype(self): - assert linsolve.infer_dtype([1.,2.]) == np.float32 - assert linsolve.infer_dtype([3,4]) == np.float32 - assert linsolve.infer_dtype([np.float32(1),4]) == np.float32 - assert linsolve.infer_dtype([np.float64(1),4]) == np.float64 - assert linsolve.infer_dtype([np.float32(1),4j]) == np.complex64 - assert linsolve.infer_dtype([np.float64(1),4j]) == np.complex128 - assert linsolve.infer_dtype([np.complex64(1),4j]) == np.complex64 - assert linsolve.infer_dtype([np.complex64(1),4.]) == np.complex64 - assert linsolve.infer_dtype([np.complex128(1),np.float64(4.)]) == np.complex128 - assert linsolve.infer_dtype([np.complex64(1),np.float64(4.)]) == np.complex128 - assert linsolve.infer_dtype([np.complex64(1),np.int32(4.)]) == np.complex64 - assert linsolve.infer_dtype([np.complex64(1),np.int64(4.)]) == np.complex64 - -class TestLinearEquation(): - + assert linsolve.infer_dtype([1.0, 2.0]) == np.float32 + assert linsolve.infer_dtype([3, 4]) == np.float32 + assert linsolve.infer_dtype([np.float32(1), 4]) == np.float32 + assert linsolve.infer_dtype([np.float64(1), 4]) == np.float64 + assert linsolve.infer_dtype([np.float32(1), 4j]) == np.complex64 + assert linsolve.infer_dtype([np.float64(1), 4j]) == np.complex128 + assert linsolve.infer_dtype([np.complex64(1), 4j]) == np.complex64 + assert linsolve.infer_dtype([np.complex64(1), 4.0]) == np.complex64 + assert ( + linsolve.infer_dtype([np.complex128(1), np.float64(4.0)]) == np.complex128 + ) + assert linsolve.infer_dtype([np.complex64(1), np.float64(4.0)]) == np.complex128 + assert linsolve.infer_dtype([np.complex64(1), np.int32(4.0)]) == np.complex64 + assert linsolve.infer_dtype([np.complex64(1), np.int64(4.0)]) == np.complex64 + + +class TestLinearEquation: def test_basics(self): - le = linsolve.LinearEquation('x+y') - assert le.terms == [['x'],['y']] + le = linsolve.LinearEquation("x+y") + assert le.terms == [["x"], ["y"]] assert le.consts == {} assert len(le.prms) == 2 - le = linsolve.LinearEquation('x-y') - assert le.terms == [['x'],[-1,'y']] - le = linsolve.LinearEquation('a*x+b*y',a=1,b=2) - assert le.terms == [['a','x'],['b','y']] - assert 'a' in le.consts - assert 'b' in le.consts + le = linsolve.LinearEquation("x-y") + assert le.terms == [["x"], [-1, "y"]] + le = linsolve.LinearEquation("a*x+b*y", a=1, b=2) + assert le.terms == [["a", "x"], ["b", "y"]] + assert "a" in le.consts + assert "b" in le.consts assert len(le.prms) == 2 - le = linsolve.LinearEquation('a*x-b*y',a=1,b=2) - assert le.terms == [['a','x'],[-1,'b','y']] + le = linsolve.LinearEquation("a*x-b*y", a=1, b=2) + 
assert le.terms == [["a", "x"], [-1, "b", "y"]] def test_more(self): - consts = {'g5':1,'g1':1} - for k in ['g5*bl95', 'g1*bl111', 'g1*bl103']: - le = linsolve.LinearEquation(k,**consts) - le.terms[0][0][0] == 'g' + consts = {"g5": 1, "g1": 1} + for k in ["g5*bl95", "g1*bl111", "g1*bl103"]: + le = linsolve.LinearEquation(k, **consts) + le.terms[0][0][0] == "g" def test_unary(self): - le = linsolve.LinearEquation('-a*x-b*y',a=1,b=2) - assert le.terms, [[-1,'a','x'],[-1,'b','y']] + le = linsolve.LinearEquation("-a*x-b*y", a=1, b=2) + assert le.terms, [[-1, "a", "x"], [-1, "b", "y"]] def test_order_terms(self): - le = linsolve.LinearEquation('x+y') - terms = [[1,1,'x'],[1,1,'y']] - assert terms == le.order_terms([[1,1,'x'],[1,1,'y']]) - terms2 = [[1,1,'x'],[1,'y',1]] - assert terms == le.order_terms([[1,1,'x'],[1,'y',1]]) - le = linsolve.LinearEquation('a*x-b*y',a=2,b=4) - terms = [[1,'a','x'],[1,'b','y']] - assert terms == le.order_terms([[1,'a','x'],[1,'b','y']]) - terms2 = [[1,'x','a'],[1,'b','y']] - assert terms == le.order_terms([[1,'x','a'],[1,'b','y']]) - le = linsolve.LinearEquation('g5*bl95+g1*bl111',g5=1,g1=1) - terms = [['g5','bl95'],['g1','bl111']] - assert terms == le.order_terms([['g5','bl95'],['g1','bl111']]) + le = linsolve.LinearEquation("x+y") + terms = [[1, 1, "x"], [1, 1, "y"]] + assert terms == le.order_terms([[1, 1, "x"], [1, 1, "y"]]) + terms2 = [[1, 1, "x"], [1, "y", 1]] + assert terms == le.order_terms([[1, 1, "x"], [1, "y", 1]]) + le = linsolve.LinearEquation("a*x-b*y", a=2, b=4) + terms = [[1, "a", "x"], [1, "b", "y"]] + assert terms == le.order_terms([[1, "a", "x"], [1, "b", "y"]]) + terms2 = [[1, "x", "a"], [1, "b", "y"]] + assert terms == le.order_terms([[1, "x", "a"], [1, "b", "y"]]) + le = linsolve.LinearEquation("g5*bl95+g1*bl111", g5=1, g1=1) + terms = [["g5", "bl95"], ["g1", "bl111"]] + assert terms == le.order_terms([["g5", "bl95"], ["g1", "bl111"]]) def test_term_check(self): - le = linsolve.LinearEquation('a*x-b*y',a=2,b=4) - terms = [[1,'a','x'],[1,'b','y']] - assert terms == le.order_terms([[1,'a','x'],[1,'b','y']]) - terms4 = [['c','x','a'],[1,'b','y']] + le = linsolve.LinearEquation("a*x-b*y", a=2, b=4) + terms = [[1, "a", "x"], [1, "b", "y"]] + assert terms == le.order_terms([[1, "a", "x"], [1, "b", "y"]]) + terms4 = [["c", "x", "a"], [1, "b", "y"]] with pytest.raises(AssertionError): le.order_terms(terms4) - terms5 = [[1,'a','b'],[1,'b','y']] + terms5 = [[1, "a", "b"], [1, "b", "y"]] with pytest.raises(AssertionError): le.order_terms(terms5) def test_eval(self): - le = linsolve.LinearEquation('a*x-b*y',a=2,b=4) - sol = {'x':3, 'y':7} - assert 2*3-4*7 == le.eval(sol) - sol = {'x':3*np.ones(4), 'y':7*np.ones(4)} - np.testing.assert_equal(2*3-4*7, le.eval(sol)) - le = linsolve.LinearEquation('x_-y') - sol = {'x':3+3j*np.ones(10), 'y':7+2j*np.ones(10)} - ans = np.conj(sol['x']) - sol['y'] + le = linsolve.LinearEquation("a*x-b*y", a=2, b=4) + sol = {"x": 3, "y": 7} + assert 2 * 3 - 4 * 7 == le.eval(sol) + sol = {"x": 3 * np.ones(4), "y": 7 * np.ones(4)} + np.testing.assert_equal(2 * 3 - 4 * 7, le.eval(sol)) + le = linsolve.LinearEquation("x_-y") + sol = {"x": 3 + 3j * np.ones(10), "y": 7 + 2j * np.ones(10)} + ans = np.conj(sol["x"]) - sol["y"] np.testing.assert_equal(ans, le.eval(sol)) - -class TestLinearSolver(): +class TestLinearSolver: def setup(self): self.sparse = False - eqs = ['x+y','x-y'] - x,y = 1,2 - d,w = {}, {} + eqs = ["x+y", "x-y"] + x, y = 1, 2 + d, w = {}, {} for eq in eqs: - d[eq],w[eq] = eval(eq), 1. 
- self.ls = linsolve.LinearSolver(d,w,sparse=self.sparse) + d[eq], w[eq] = eval(eq), 1.0 + self.ls = linsolve.LinearSolver(d, w, sparse=self.sparse) def test_basics(self): assert len(self.ls.prms) == 2 assert len(self.ls.eqs) == 2 - assert self.ls.eqs[0].terms == [['x'],['y']] - assert self.ls.eqs[1].terms == [['x'],[-1,'y']] + assert self.ls.eqs[0].terms == [["x"], ["y"]] + assert self.ls.eqs[1].terms == [["x"], [-1, "y"]] def test_get_A(self): - self.ls.prm_order = {'x':0,'y':1} # override random default ordering + self.ls.prm_order = {"x": 0, "y": 1} # override random default ordering A = self.ls.get_A() - assert A.shape == (2,2,1) - #np.testing.assert_equal(A.todense(), np.array([[1.,1],[1.,-1]])) - np.testing.assert_equal(A, np.array([[[1.], [1]],[[1.],[-1]]])) - - #def test_get_AtAiAt(self): - # self.ls.prm_order = {'x':0,'y':1} # override random default ordering - # AtAiAt = self.ls.get_AtAiAt().squeeze() - # #np.testing.assert_equal(AtAiAt.todense(), np.array([[.5,.5],[.5,-.5]])) - # #np.testing.assert_equal(AtAiAt, np.array([[.5,.5],[.5,-.5]])) - # measured = np.array([[3.],[-1]]) - # x,y = AtAiAt.dot(measured).flatten() - # self.assertAlmostEqual(x, 1.) - # self.assertAlmostEqual(y, 2.) + assert A.shape == (2, 2, 1) + np.testing.assert_equal(A, np.array([[[1.0], [1]], [[1.0], [-1]]])) def test_solve(self): sol = self.ls.solve() - np.testing.assert_almost_equal(sol['x'], 1.) - np.testing.assert_almost_equal(sol['y'], 2.) + np.testing.assert_almost_equal(sol["x"], 1.0) + np.testing.assert_almost_equal(sol["y"], 2.0) def test_solve_modes(self): - for mode in ['default','lsqr','pinv','solve']: + for mode in ["default", "lsqr", "pinv", "solve"]: sol = self.ls.solve(mode=mode) - np.testing.assert_almost_equal(sol['x'], 1.) - np.testing.assert_almost_equal(sol['y'], 2.) + np.testing.assert_almost_equal(sol["x"], 1.0) + np.testing.assert_almost_equal(sol["y"], 2.0) def test_solve_arrays(self): # range of 1 to 101 prevents "The exact solution is x = 0" printouts - x = np.arange(1,101,dtype=np.float64); x.shape = (10,10) - y = np.arange(1,101,dtype=np.float64); y.shape = (10,10) - eqs = ['2*x+y','-x+3*y'] - d,w = {}, {} + x = np.arange(1, 101, dtype=np.float64) + x.shape = (10, 10) + y = np.arange(1, 101, dtype=np.float64) + y.shape = (10, 10) + eqs = ["2*x+y", "-x+3*y"] + d, w = {}, {} for eq in eqs: - d[eq],w[eq] = eval(eq), 1. - ls = linsolve.LinearSolver(d,w, sparse=self.sparse) + d[eq], w[eq] = eval(eq), 1.0 + ls = linsolve.LinearSolver(d, w, sparse=self.sparse) sol = ls.solve() - np.testing.assert_almost_equal(sol['x'], x) - np.testing.assert_almost_equal(sol['y'], y) + np.testing.assert_almost_equal(sol["x"], x) + np.testing.assert_almost_equal(sol["y"], y) def test_solve_arrays_modes(self): # range of 1 to 101 prevents "The exact solution is x = 0" printouts - x = np.arange(1,101,dtype=np.float64); x.shape = (10,10) - y = np.arange(1,101,dtype=np.float64); y.shape = (10,10) - eqs = ['2*x+y','-x+3*y'] - d,w = {}, {} + x = np.arange(1, 101, dtype=np.float64) + x.shape = (10, 10) + y = np.arange(1, 101, dtype=np.float64) + y.shape = (10, 10) + eqs = ["2*x+y", "-x+3*y"] + d, w = {}, {} for eq in eqs: - d[eq],w[eq] = eval(eq), 1. 
- ls = linsolve.LinearSolver(d,w, sparse=self.sparse) - for mode in ['default','lsqr','pinv','solve']: + d[eq], w[eq] = eval(eq), 1.0 + ls = linsolve.LinearSolver(d, w, sparse=self.sparse) + for mode in ["default", "lsqr", "pinv", "solve"]: sol = ls.solve(mode=mode) - np.testing.assert_almost_equal(sol['x'], x) - np.testing.assert_almost_equal(sol['y'], y) + np.testing.assert_almost_equal(sol["x"], x) + np.testing.assert_almost_equal(sol["y"], y) def test_A_shape(self): # range of 1 to 11 prevents "The exact solution is x = 0" printouts - consts = {'a':np.arange(1,11), 'b':np.zeros((1,10))} - ls = linsolve.LinearSolver({'a*x+b*y':0.},{'a*x+b*y':1},**consts) - assert ls._A_shape() == (1,2,10*10) + consts = {"a": np.arange(1, 11), "b": np.zeros((1, 10))} + ls = linsolve.LinearSolver({"a*x+b*y": 0.0}, {"a*x+b*y": 1}, **consts) + assert ls._A_shape() == (1, 2, 10 * 10) def test_const_arrays(self): - x,y = 1.,2. - a = np.array([3.,4,5]) - b = np.array([1.,2,3]) - eqs = ['a*x+y','x+b*y'] - d,w = {}, {} - for eq in eqs: d[eq],w[eq] = eval(eq), 1. - ls = linsolve.LinearSolver(d,w,a=a,b=b, sparse=self.sparse) + x, y = 1.0, 2.0 + a = np.array([3.0, 4, 5]) + b = np.array([1.0, 2, 3]) + eqs = ["a*x+y", "x+b*y"] + d, w = {}, {} + for eq in eqs: + d[eq], w[eq] = eval(eq), 1.0 + ls = linsolve.LinearSolver(d, w, a=a, b=b, sparse=self.sparse) sol = ls.solve() - np.testing.assert_almost_equal(sol['x'], x*np.ones(3,dtype=np.float64)) - np.testing.assert_almost_equal(sol['y'], y*np.ones(3,dtype=np.float64)) + np.testing.assert_almost_equal(sol["x"], x * np.ones(3, dtype=np.float64)) + np.testing.assert_almost_equal(sol["y"], y * np.ones(3, dtype=np.float64)) def test_wgt_arrays(self): - x,y = 1.,2. - a,b = 3.,1. - eqs = ['a*x+y','x+b*y'] - d,w = {}, {} + x, y = 1.0, 2.0 + a, b = 3.0, 1.0 + eqs = ["a*x+y", "x+b*y"] + d, w = {}, {} for eq in eqs: - d[eq],w[eq] = eval(eq), np.ones(4) - ls = linsolve.LinearSolver(d,w,a=a,b=b, sparse=self.sparse) + d[eq], w[eq] = eval(eq), np.ones(4) + ls = linsolve.LinearSolver(d, w, a=a, b=b, sparse=self.sparse) sol = ls.solve() - np.testing.assert_almost_equal(sol['x'], x*np.ones(4,dtype=np.float64)) - np.testing.assert_almost_equal(sol['y'], y*np.ones(4,dtype=np.float64)) + np.testing.assert_almost_equal(sol["x"], x * np.ones(4, dtype=np.float64)) + np.testing.assert_almost_equal(sol["y"], y * np.ones(4, dtype=np.float64)) def test_wgt_const_arrays(self): - x,y = 1.,2. - a,b = 3.*np.ones(4),1. - eqs = ['a*x+y','x+b*y'] - d,w = {}, {} + x, y = 1.0, 2.0 + a, b = 3.0 * np.ones(4), 1.0 + eqs = ["a*x+y", "x+b*y"] + d, w = {}, {} for eq in eqs: - d[eq],w[eq] = eval(eq)*np.ones(4), np.ones(4) - ls = linsolve.LinearSolver(d,w,a=a,b=b, sparse=self.sparse) + d[eq], w[eq] = eval(eq) * np.ones(4), np.ones(4) + ls = linsolve.LinearSolver(d, w, a=a, b=b, sparse=self.sparse) sol = ls.solve() - np.testing.assert_almost_equal(sol['x'], x*np.ones(4,dtype=np.float64)) - np.testing.assert_almost_equal(sol['y'], y*np.ones(4,dtype=np.float64)) + np.testing.assert_almost_equal(sol["x"], x * np.ones(4, dtype=np.float64)) + np.testing.assert_almost_equal(sol["y"], y * np.ones(4, dtype=np.float64)) def test_nonunity_wgts(self): - x,y = 1.,2. - a,b = 3.*np.ones(4),1. 
- eqs = ['a*x+y','x+b*y'] - d,w = {}, {} - for eq in eqs: d[eq],w[eq] = eval(eq)*np.ones(4), 2*np.ones(4) - ls = linsolve.LinearSolver(d,w,a=a,b=b, sparse=self.sparse) + x, y = 1.0, 2.0 + a, b = 3.0 * np.ones(4), 1.0 + eqs = ["a*x+y", "x+b*y"] + d, w = {}, {} + for eq in eqs: + d[eq], w[eq] = eval(eq) * np.ones(4), 2 * np.ones(4) + ls = linsolve.LinearSolver(d, w, a=a, b=b, sparse=self.sparse) sol = ls.solve() - np.testing.assert_almost_equal(sol['x'], x*np.ones(4,dtype=np.float64)) - np.testing.assert_almost_equal(sol['y'], y*np.ones(4,dtype=np.float64)) + np.testing.assert_almost_equal(sol["x"], x * np.ones(4, dtype=np.float64)) + np.testing.assert_almost_equal(sol["y"], y * np.ones(4, dtype=np.float64)) def test_eval(self): - x,y = 1.,2. - a,b = 3.*np.ones(4),1. - eqs = ['a*x+y','x+b*y'] - d,w = {}, {} + x, y = 1.0, 2.0 + a, b = 3.0 * np.ones(4), 1.0 + eqs = ["a*x+y", "x+b*y"] + d, w = {}, {} for eq in eqs: - d[eq],w[eq] = eval(eq)*np.ones(4), np.ones(4) - ls = linsolve.LinearSolver(d,w,a=a,b=b, sparse=self.sparse) + d[eq], w[eq] = eval(eq) * np.ones(4), np.ones(4) + ls = linsolve.LinearSolver(d, w, a=a, b=b, sparse=self.sparse) sol = ls.solve() - np.testing.assert_almost_equal(sol['x'], x*np.ones(4,dtype=np.float64)) - np.testing.assert_almost_equal(sol['y'], y*np.ones(4,dtype=np.float64)) + np.testing.assert_almost_equal(sol["x"], x * np.ones(4, dtype=np.float64)) + np.testing.assert_almost_equal(sol["y"], y * np.ones(4, dtype=np.float64)) result = ls.eval(sol) for eq in d: np.testing.assert_almost_equal(d[eq], result[eq]) - result = ls.eval(sol, 'a*x+b*y') - np.testing.assert_almost_equal(3*1+1*2, list(result.values())[0]) + result = ls.eval(sol, "a*x+b*y") + np.testing.assert_almost_equal(3 * 1 + 1 * 2, list(result.values())[0]) def test_chisq(self): - x = 1. - d = {'x':1, 'a*x':2} - ls = linsolve.LinearSolver(d,a=1.0, sparse=self.sparse) + x = 1.0 + d = {"x": 1, "a*x": 2} + ls = linsolve.LinearSolver(d, a=1.0, sparse=self.sparse) sol = ls.solve() chisq = ls.chisq(sol) - np.testing.assert_almost_equal(chisq, .5) - x = 1. - d = {'x':1, '1.0*x':2} + np.testing.assert_almost_equal(chisq, 0.5) + x = 1.0 + d = {"x": 1, "1.0*x": 2} ls = linsolve.LinearSolver(d, sparse=self.sparse) sol = ls.solve() chisq = ls.chisq(sol) - np.testing.assert_almost_equal(chisq, .5) - x = 1. 
- d = {'1*x': 2.0, 'x': 1.0} - w = {'1*x': 1.0, 'x': .5} + np.testing.assert_almost_equal(chisq, 0.5) + x = 1.0 + d = {"1*x": 2.0, "x": 1.0} + w = {"1*x": 1.0, "x": 0.5} ls = linsolve.LinearSolver(d, wgts=w, sparse=self.sparse) sol = ls.solve() chisq = ls.chisq(sol) - np.testing.assert_almost_equal(sol['x'], 5.0/3.0, 6) - np.testing.assert_almost_equal(ls.chisq(sol), 1.0/3.0) + np.testing.assert_almost_equal(sol["x"], 5.0 / 3.0, 6) + np.testing.assert_almost_equal(ls.chisq(sol), 1.0 / 3.0) def test_dtypes(self): - ls = linsolve.LinearSolver({'x_': 1.0+1.0j}, sparse=self.sparse) + ls = linsolve.LinearSolver({"x_": 1.0 + 1.0j}, sparse=self.sparse) # conjugation should trigger re_im_split, splitting the # complex64 type into two float32 types assert ls.dtype == np.float32 - assert type(ls.solve()['x']) == np.complex64 + assert type(ls.solve()["x"]) == np.complex64 - ls = linsolve.LinearSolver({'x': 1.0+1.0j}, sparse=self.sparse) + ls = linsolve.LinearSolver({"x": 1.0 + 1.0j}, sparse=self.sparse) assert ls.dtype == np.complex64 - assert type(ls.solve()['x']) == np.complex64 + assert type(ls.solve()["x"]) == np.complex64 - ls = linsolve.LinearSolver({'x_': np.ones(1,dtype=np.complex64)[0]}, sparse=self.sparse) + ls = linsolve.LinearSolver( + {"x_": np.ones(1, dtype=np.complex64)[0]}, sparse=self.sparse + ) # conjugation should trigger re_im_split, splitting the # complex64 type into two float32 types assert ls.dtype == np.float32 - assert type(ls.solve()['x']) == np.complex64 + assert type(ls.solve()["x"]) == np.complex64 - ls = linsolve.LinearSolver({'x': np.ones(1,dtype=np.complex64)[0]}, sparse=self.sparse) - assert ls.dtype,np.complex64 - assert type(ls.solve()['x']) == np.complex64 + ls = linsolve.LinearSolver( + {"x": np.ones(1, dtype=np.complex64)[0]}, sparse=self.sparse + ) + assert ls.dtype, np.complex64 + assert type(ls.solve()["x"]) == np.complex64 - ls = linsolve.LinearSolver({'c*x': np.array(1.0, dtype=np.float32)}, c=1.0+1.0j, sparse=self.sparse) + ls = linsolve.LinearSolver( + {"c*x": np.array(1.0, dtype=np.float32)}, c=1.0 + 1.0j, sparse=self.sparse + ) assert ls.dtype == np.complex64 - assert type(ls.solve()['x']) == np.complex64 + assert type(ls.solve()["x"]) == np.complex64 - d = {'c*x': np.ones(1,dtype=np.float32)[0]} - wgts = {'c*x': np.ones(1,dtype=np.float64)[0]} - c = np.ones(1,dtype=np.float32)[0] + d = {"c*x": np.ones(1, dtype=np.float32)[0]} + wgts = {"c*x": np.ones(1, dtype=np.float64)[0]} + c = np.ones(1, dtype=np.float32)[0] ls = linsolve.LinearSolver(d, wgts=wgts, c=c, sparse=self.sparse) assert ls.dtype == np.float64 - assert type(ls.solve()['x']) == np.float64 + assert type(ls.solve()["x"]) == np.float64 - d = {'c*x': np.ones(1,dtype=np.float32)[0]} - wgts = {'c*x': np.ones(1,dtype=np.float32)[0]} - c = np.ones(1,dtype=np.float32)[0] + d = {"c*x": np.ones(1, dtype=np.float32)[0]} + wgts = {"c*x": np.ones(1, dtype=np.float32)[0]} + c = np.ones(1, dtype=np.float32)[0] ls = linsolve.LinearSolver(d, wgts=wgts, c=c, sparse=self.sparse) assert ls.dtype == np.float32 - assert type(ls.solve()['x']) == np.float32 + assert type(ls.solve()["x"]) == np.float32 def test_degen_sol(self): # test how various solvers deal with degenerate solutions - d = {'x+y': 1., '2*x+2*y': 2.} + d = {"x+y": 1.0, "2*x+2*y": 2.0} ls = linsolve.LinearSolver(d, sparse=self.sparse) - for mode in ('pinv', 'lsqr'): + for mode in ("pinv", "lsqr"): sol = ls.solve(mode=mode) - np.testing.assert_almost_equal(sol['x'] + sol['y'], 1., 6) + np.testing.assert_almost_equal(sol["x"] + sol["y"], 1.0, 6) with 
pytest.raises(np.linalg.LinAlgError): - ls.solve(mode='solve') + ls.solve(mode="solve") -class TestLinearSolverSparse(TestLinearSolver): +class TestLinearSolverSparse(TestLinearSolver): def setup(self): self.sparse = True - eqs = ['x+y','x-y'] - x,y = 1,2 - d,w = {}, {} - for eq in eqs: d[eq],w[eq] = eval(eq), 1. - self.ls = linsolve.LinearSolver(d,w,sparse=self.sparse) - + eqs = ["x+y", "x-y"] + x, y = 1, 2 + d, w = {}, {} + for eq in eqs: + d[eq], w[eq] = eval(eq), 1.0 + self.ls = linsolve.LinearSolver(d, w, sparse=self.sparse) -class TestLogProductSolver(): +class TestLogProductSolver: def setup(self): - self.sparse=False + self.sparse = False def test_init(self): - x,y,z = np.exp(1.+0j), np.exp(2.), np.exp(3.) - keys = ['x*y*z', 'x*y', 'y*z'] - d,w = {}, {} - for k in keys: d[k],w[k] = eval(k), 1. - ls = linsolve.LogProductSolver(d,w,sparse=self.sparse) + x, y, z = np.exp(1.0 + 0j), np.exp(2.0), np.exp(3.0) + keys = ["x*y*z", "x*y", "y*z"] + d, w = {}, {} + for k in keys: + d[k], w[k] = eval(k), 1.0 + ls = linsolve.LogProductSolver(d, w, sparse=self.sparse) for k in ls.ls_phs.data: np.testing.assert_equal(ls.ls_phs.data[k], 0) - x,y,z = 1.,2.,3. + x, y, z = 1.0, 2.0, 3.0 for k in ls.ls_amp.data: np.testing.assert_equal(eval(k), ls.ls_amp.data[k]) def test_conj(self): - x,y = 1+1j, 2+2j - d,w = {}, {} - d['x*y_'] = x * y.conjugate() - d['x_*y'] = x.conjugate() * y - d['x*y'] = x * y - d['x_*y_'] = x.conjugate() * y.conjugate() - for k in d: w[k] = 1. - ls = linsolve.LogProductSolver(d,w,sparse=self.sparse) + x, y = 1 + 1j, 2 + 2j + d, w = {}, {} + d["x*y_"] = x * y.conjugate() + d["x_*y"] = x.conjugate() * y + d["x*y"] = x * y + d["x_*y_"] = x.conjugate() * y.conjugate() + for k in d: + w[k] = 1.0 + ls = linsolve.LogProductSolver(d, w, sparse=self.sparse) assert len(ls.ls_amp.data) == 4 for k in ls.ls_amp.data: - assert eval(k) == 3+3j # make sure they are all x+y - assert k.replace('1','-1') in ls.ls_phs.data + assert eval(k) == 3 + 3j # make sure they are all x+y + assert k.replace("1", "-1") in ls.ls_phs.data def test_solve(self): - x,y,z = np.exp(1.+1j), np.exp(2.+2j), np.exp(3.+3j) - keys = ['x*y*z', 'x*y', 'y*z'] - d,w = {}, {} - for k in keys: d[k],w[k] = eval(k), 1. - ls = linsolve.LogProductSolver(d,w,sparse=self.sparse) + x, y, z = np.exp(1.0 + 1j), np.exp(2.0 + 2j), np.exp(3.0 + 3j) + keys = ["x*y*z", "x*y", "y*z"] + d, w = {}, {} + for k in keys: + d[k], w[k] = eval(k), 1.0 + ls = linsolve.LogProductSolver(d, w, sparse=self.sparse) sol = ls.solve() for k in sol: np.testing.assert_almost_equal(sol[k], eval(k)) + def test_conj_solve(self): - x,y = np.exp(1.+2j), np.exp(2.+1j) - d,w = {'x*y_':x*y.conjugate(), 'x':x}, {} - for k in d: w[k] = 1. - ls = linsolve.LogProductSolver(d,w,sparse=self.sparse) + x, y = np.exp(1.0 + 2j), np.exp(2.0 + 1j) + d, w = {"x*y_": x * y.conjugate(), "x": x}, {} + for k in d: + w[k] = 1.0 + ls = linsolve.LogProductSolver(d, w, sparse=self.sparse) sol = ls.solve() for k in sol: np.testing.assert_almost_equal(sol[k], eval(k)) + def test_no_abs_phs_solve(self): - x,y,z = 1.+1j, 2.+2j, 3.+3j - d,w = {'x*y_':x*y.conjugate(), 'x*z_':x*z.conjugate(), 'y*z_':y*z.conjugate()}, {} - for k in list(d.keys()): w[k] = 1. 
- ls = linsolve.LogProductSolver(d,w,sparse=self.sparse) + x, y, z = 1.0 + 1j, 2.0 + 2j, 3.0 + 3j + d, w = { + "x*y_": x * y.conjugate(), + "x*z_": x * z.conjugate(), + "y*z_": y * z.conjugate(), + }, {} + for k in list(d.keys()): + w[k] = 1.0 + ls = linsolve.LogProductSolver(d, w, sparse=self.sparse) # some ridiculousness to avoid "The exact solution is x = 0" prints save_stdout = sys.stdout sys.stdout = io.StringIO() sol = ls.solve() sys.stdout = save_stdout - x,y,z = sol['x'], sol['y'], sol['z'] - np.testing.assert_almost_equal(np.angle(x*y.conjugate()), 0.) - np.testing.assert_almost_equal(np.angle(x*z.conjugate()), 0.) - np.testing.assert_almost_equal(np.angle(y*z.conjugate()), 0.) + x, y, z = sol["x"], sol["y"], sol["z"] + np.testing.assert_almost_equal(np.angle(x * y.conjugate()), 0.0) + np.testing.assert_almost_equal(np.angle(x * z.conjugate()), 0.0) + np.testing.assert_almost_equal(np.angle(y * z.conjugate()), 0.0) # check projection of degenerate mode - np.testing.assert_almost_equal(np.angle(x), 0.) - np.testing.assert_almost_equal(np.angle(y), 0.) - np.testing.assert_almost_equal(np.angle(z), 0.) + np.testing.assert_almost_equal(np.angle(x), 0.0) + np.testing.assert_almost_equal(np.angle(y), 0.0) + np.testing.assert_almost_equal(np.angle(z), 0.0) + def test_dtype(self): for dtype in (np.float32, np.float64, np.complex64, np.complex128): - x,y,z = np.exp(1.), np.exp(2.), np.exp(3.) - keys = ['x*y*z', 'x*y', 'y*z'] - d,w = {}, {} + x, y, z = np.exp(1.0), np.exp(2.0), np.exp(3.0) + keys = ["x*y*z", "x*y", "y*z"] + d, w = {}, {} for k in keys: d[k] = eval(k).astype(dtype) - w[k] = np.float32(1.) - ls = linsolve.LogProductSolver(d,w,sparse=self.sparse) + w[k] = np.float32(1.0) + ls = linsolve.LogProductSolver(d, w, sparse=self.sparse) # some ridiculousness to avoid "The exact solution is x = 0" prints save_stdout = sys.stdout sys.stdout = io.StringIO() @@ -435,210 +459,302 @@ def test_dtype(self): for k in sol: assert sol[k].dtype == dtype + class TestLogProductSolverSparse(TestLogProductSolver): - def setup(self): - self.sparse=True + self.sparse = True -class TestLinProductSolver(): +class TestLinProductSolver: def setup(self): - self.sparse=False + self.sparse = False + def test_init(self): - x,y,z = 1.+1j, 2.+2j, 3.+3j - d,w = {'x*y_':x*y.conjugate(), 'x*z_':x*z.conjugate(), 'y*z_':y*z.conjugate()}, {} - for k in list(d.keys()): w[k] = 1. + x, y, z = 1.0 + 1j, 2.0 + 2j, 3.0 + 3j + d, w = { + "x*y_": x * y.conjugate(), + "x*z_": x * z.conjugate(), + "y*z_": y * z.conjugate(), + }, {} + for k in list(d.keys()): + w[k] = 1.0 sol0 = {} - for k in 'xyz': sol0[k] = eval(k)+.01 - ls = linsolve.LinProductSolver(d,sol0,w,sparse=self.sparse) - x,y,z = 1.,1.,1. - x_,y_,z_ = 1.,1.,1. - dx = dy = dz = .001 - dx_ = dy_ = dz_ = .001 + for k in "xyz": + sol0[k] = eval(k) + 0.01 + ls = linsolve.LinProductSolver(d, sol0, w, sparse=self.sparse) + x, y, z = 1.0, 1.0, 1.0 + x_, y_, z_ = 1.0, 1.0, 1.0 + dx = dy = dz = 0.001 + dx_ = dy_ = dz_ = 0.001 for k in ls.ls.keys: np.testing.assert_almost_equal(eval(k), 0.002) assert len(ls.ls.prms) == 3 def test_real_solve(self): - x,y,z = 1., 2., 3. - keys = ['x*y', 'x*z', 'y*z'] - d,w = {}, {} - for k in keys: d[k],w[k] = eval(k), 1. 
+        x, y, z = 1.0, 2.0, 3.0
+        keys = ["x*y", "x*z", "y*z"]
+        d, w = {}, {}
+        for k in keys:
+            d[k], w[k] = eval(k), 1.0
         sol0 = {}
-        for k in 'xyz': sol0[k] = eval(k)+.01
-        ls = linsolve.LinProductSolver(d,sol0,w,sparse=self.sparse)
+        for k in "xyz":
+            sol0[k] = eval(k) + 0.01
+        ls = linsolve.LinProductSolver(d, sol0, w, sparse=self.sparse)
         sol = ls.solve()
         for k in sol:
             np.testing.assert_almost_equal(sol[k], eval(k), 4)
     def test_single_term(self):
-        x,y,z = 1., 2., 3.
-        keys = ['x*y', 'x*z', '2*z']
-        d,w = {}, {}
-        for k in keys: d[k],w[k] = eval(k), 1.
+        x, y, z = 1.0, 2.0, 3.0
+        keys = ["x*y", "x*z", "2*z"]
+        d, w = {}, {}
+        for k in keys:
+            d[k], w[k] = eval(k), 1.0
         sol0 = {}
-        for k in 'xyz': sol0[k] = eval(k)+.01
-        ls = linsolve.LinProductSolver(d,sol0,w,sparse=self.sparse)
+        for k in "xyz":
+            sol0[k] = eval(k) + 0.01
+        ls = linsolve.LinProductSolver(d, sol0, w, sparse=self.sparse)
         sol = ls.solve()
         for k in sol:
             np.testing.assert_almost_equal(sol[k], eval(k), 4)
     def test_complex_solve(self):
-        x,y,z = 1+1j, 2+2j, 3+2j
-        keys = ['x*y', 'x*z', 'y*z']
-        d,w = {}, {}
-        for k in keys: d[k],w[k] = eval(k), 1.
+        x, y, z = 1 + 1j, 2 + 2j, 3 + 2j
+        keys = ["x*y", "x*z", "y*z"]
+        d, w = {}, {}
+        for k in keys:
+            d[k], w[k] = eval(k), 1.0
         sol0 = {}
-        for k in 'xyz': sol0[k] = eval(k)+.01
-        ls = linsolve.LinProductSolver(d,sol0,w,sparse=self.sparse)
+        for k in "xyz":
+            sol0[k] = eval(k) + 0.01
+        ls = linsolve.LinProductSolver(d, sol0, w, sparse=self.sparse)
         sol = ls.solve()
         for k in sol:
             np.testing.assert_almost_equal(sol[k], eval(k), 4)
     def test_complex_conj_solve(self):
-        x,y,z = 1.+1j, 2.+2j, 3.+3j
-        d,w = {'x*y_':x*y.conjugate(), 'x*z_':x*z.conjugate(), 'y*z_':y*z.conjugate()}, {}
-        for k in list(d.keys()): w[k] = 1.
+        x, y, z = 1.0 + 1j, 2.0 + 2j, 3.0 + 3j
+        d, w = {
+            "x*y_": x * y.conjugate(),
+            "x*z_": x * z.conjugate(),
+            "y*z_": y * z.conjugate(),
+        }, {}
+        for k in list(d.keys()):
+            w[k] = 1.0
         sol0 = {}
-        for k in 'xyz': sol0[k] = eval(k) + .01
-        ls = linsolve.LinProductSolver(d,sol0,w,sparse=self.sparse)
-        ls.prm_order = {'x':0,'y':1,'z':2}
-        _, sol = ls.solve_iteratively(mode='lsqr') # XXX fails for pinv
-        x,y,z = sol['x'], sol['y'], sol['z']
-        np.testing.assert_almost_equal(x*y.conjugate(), d['x*y_'], 3)
-        np.testing.assert_almost_equal(x*z.conjugate(), d['x*z_'], 3)
-        np.testing.assert_almost_equal(y*z.conjugate(), d['y*z_'], 3)
+        for k in "xyz":
+            sol0[k] = eval(k) + 0.01
+        ls = linsolve.LinProductSolver(d, sol0, w, sparse=self.sparse)
+        ls.prm_order = {"x": 0, "y": 1, "z": 2}
+        _, sol = ls.solve_iteratively(mode="lsqr")  # XXX fails for pinv
+        x, y, z = sol["x"], sol["y"], sol["z"]
+        np.testing.assert_almost_equal(x * y.conjugate(), d["x*y_"], 3)
+        np.testing.assert_almost_equal(x * z.conjugate(), d["x*z_"], 3)
+        np.testing.assert_almost_equal(y * z.conjugate(), d["y*z_"], 3)
     def test_complex_array_solve(self):
-        x = np.arange(30, dtype=np.complex128); x.shape = (3,10)
-        y = np.arange(30, dtype=np.complex128); y.shape = (3,10)
-        z = np.arange(30, dtype=np.complex128); z.shape = (3,10)
-        d,w = {'x*y':x*y, 'x*z':x*z, 'y*z':y*z}, {}
-        for k in list(d.keys()): w[k] = np.ones(d[k].shape)
+        x = np.arange(30, dtype=np.complex128)
+        x.shape = (3, 10)
+        y = np.arange(30, dtype=np.complex128)
+        y.shape = (3, 10)
+        z = np.arange(30, dtype=np.complex128)
+        z.shape = (3, 10)
+        d, w = {"x*y": x * y, "x*z": x * z, "y*z": y * z}, {}
+        for k in list(d.keys()):
+            w[k] = np.ones(d[k].shape)
         sol0 = {}
-        for k in 'xyz': sol0[k] = eval(k) + .01
-        ls = linsolve.LinProductSolver(d,sol0,w,sparse=self.sparse)
-        ls.prm_order = {'x':0,'y':1,'z':2}
+        for k in "xyz":
+            sol0[k] = eval(k) + 0.01
+        ls = linsolve.LinProductSolver(d, sol0, w, sparse=self.sparse)
+        ls.prm_order = {"x": 0, "y": 1, "z": 2}
         sol = ls.solve()
-        np.testing.assert_almost_equal(sol['x'], x, 2)
-        np.testing.assert_almost_equal(sol['y'], y, 2)
-        np.testing.assert_almost_equal(sol['z'], z, 2)
+        np.testing.assert_almost_equal(sol["x"], x, 2)
+        np.testing.assert_almost_equal(sol["y"], y, 2)
+        np.testing.assert_almost_equal(sol["z"], z, 2)
     def test_complex_array_NtimesNfreqs1_solve(self):
-        x = np.arange(1, dtype=np.complex128); x.shape = (1,1)
-        y = np.arange(1, dtype=np.complex128); y.shape = (1,1)
-        z = np.arange(1, dtype=np.complex128); z.shape = (1,1)
-        d,w = {'x*y':x*y, 'x*z':x*z, 'y*z':y*z}, {}
-        for k in list(d.keys()): w[k] = np.ones(d[k].shape)
+        x = np.arange(1, dtype=np.complex128)
+        x.shape = (1, 1)
+        y = np.arange(1, dtype=np.complex128)
+        y.shape = (1, 1)
+        z = np.arange(1, dtype=np.complex128)
+        z.shape = (1, 1)
+        d, w = {"x*y": x * y, "x*z": x * z, "y*z": y * z}, {}
+        for k in list(d.keys()):
+            w[k] = np.ones(d[k].shape)
         sol0 = {}
-        for k in 'xyz': sol0[k] = eval(k) + .01
-        ls = linsolve.LinProductSolver(d,sol0,w,sparse=self.sparse)
-        ls.prm_order = {'x':0,'y':1,'z':2}
+        for k in "xyz":
+            sol0[k] = eval(k) + 0.01
+        ls = linsolve.LinProductSolver(d, sol0, w, sparse=self.sparse)
+        ls.prm_order = {"x": 0, "y": 1, "z": 2}
         sol = ls.solve()
-        np.testing.assert_almost_equal(sol['x'], x, 2)
-        np.testing.assert_almost_equal(sol['y'], y, 2)
-        np.testing.assert_almost_equal(sol['z'], z, 2)
+        np.testing.assert_almost_equal(sol["x"], x, 2)
+        np.testing.assert_almost_equal(sol["y"], y, 2)
+        np.testing.assert_almost_equal(sol["z"], z, 2)
     def test_sums_of_products(self):
-        x = np.arange(1,31)*(1.0+1.0j); x.shape=(10,3)
-        y = np.arange(1,31)*(2.0-3.0j); y.shape=(10,3)
-        z = np.arange(1,31)*(3.0-9.0j); z.shape=(10,3)
-        w = np.arange(1,31)*(4.0+2.0j); w.shape=(10,3)
-        x_,y_,z_,w_ = list(map(np.conjugate,(x,y,z,w)))
-        expressions = ['x*y+z*w', '2*x_*y_+z*w-1.0j*z*w', '2*x*w', '1.0j*x + y*z', '-1*x*z+3*y*w*x+y', '2*w_', '2*x_ + 3*y - 4*z']
+        x = np.arange(1, 31) * (1.0 + 1.0j)
+        x.shape = (10, 3)
+        y = np.arange(1, 31) * (2.0 - 3.0j)
+        y.shape = (10, 3)
+        z = np.arange(1, 31) * (3.0 - 9.0j)
+        z.shape = (10, 3)
+        w = np.arange(1, 31) * (4.0 + 2.0j)
+        w.shape = (10, 3)
+        x_, y_, z_, w_ = list(map(np.conjugate, (x, y, z, w)))
+        expressions = [
+            "x*y+z*w",
+            "2*x_*y_+z*w-1.0j*z*w",
+            "2*x*w",
+            "1.0j*x + y*z",
+            "-1*x*z+3*y*w*x+y",
+            "2*w_",
+            "2*x_ + 3*y - 4*z",
+        ]
         data = {}
-        for ex in expressions: data[ex] = eval(ex)
-        currentSol = {'x':1.1*x, 'y': .9*y, 'z': 1.1*z, 'w':1.2*w}
-        for i in range(5): # reducing iters prevents printing a bunch of "The exact solution is x = 0"
-            testSolve = linsolve.LinProductSolver(data, currentSol,sparse=self.sparse)
+        for ex in expressions:
+            data[ex] = eval(ex)
+        currentSol = {"x": 1.1 * x, "y": 0.9 * y, "z": 1.1 * z, "w": 1.2 * w}
+        for i in range(
+            5
+        ):  # reducing iters prevents printing a bunch of "The exact solution is x = 0"
+            testSolve = linsolve.LinProductSolver(data, currentSol, sparse=self.sparse)
             currentSol = testSolve.solve()
-        for var in 'wxyz':
-            np.testing.assert_almost_equal(currentSol[var], eval(var), 4)
+        for var in "wxyz":
+            np.testing.assert_almost_equal(currentSol[var], eval(var), 4)
     def test_eval(self):
-        x = np.arange(1,31)*(1.0+1.0j); x.shape=(10,3)
-        y = np.arange(1,31)*(2.0-3.0j); y.shape=(10,3)
-        z = np.arange(1,31)*(3.0-9.0j); z.shape=(10,3)
-        w = np.arange(1,31)*(4.0+2.0j); w.shape=(10,3)
-        x_,y_,z_,w_ = list(map(np.conjugate,(x,y,z,w)))
-        expressions = ['x*y+z*w', '2*x_*y_+z*w-1.0j*z*w', '2*x*w', '1.0j*x + y*z', '-1*x*z+3*y*w*x+y', '2*w_', '2*x_ + 3*y - 4*z']
+        x = np.arange(1, 31) * (1.0 + 1.0j)
+        x.shape = (10, 3)
+        y = np.arange(1, 31) * (2.0 - 3.0j)
+        y.shape = (10, 3)
+        z = np.arange(1, 31) * (3.0 - 9.0j)
+        z.shape = (10, 3)
+        w = np.arange(1, 31) * (4.0 + 2.0j)
+        w.shape = (10, 3)
+        x_, y_, z_, w_ = list(map(np.conjugate, (x, y, z, w)))
+        expressions = [
+            "x*y+z*w",
+            "2*x_*y_+z*w-1.0j*z*w",
+            "2*x*w",
+            "1.0j*x + y*z",
+            "-1*x*z+3*y*w*x+y",
+            "2*w_",
+            "2*x_ + 3*y - 4*z",
+        ]
         data = {}
-        for ex in expressions: data[ex] = eval(ex)
-        currentSol = {'x':1.1*x, 'y': .9*y, 'z': 1.1*z, 'w':1.2*w}
-        for i in range(5): # reducing iters prevents printing a bunch of "The exact solution is x = 0"
-            testSolve = linsolve.LinProductSolver(data, currentSol,sparse=self.sparse)
+        for ex in expressions:
+            data[ex] = eval(ex)
+        currentSol = {"x": 1.1 * x, "y": 0.9 * y, "z": 1.1 * z, "w": 1.2 * w}
+        for i in range(
+            5
+        ):  # reducing iters prevents printing a bunch of "The exact solution is x = 0"
+            testSolve = linsolve.LinProductSolver(data, currentSol, sparse=self.sparse)
             currentSol = testSolve.solve()
-        for var in 'wxyz':
+        for var in "wxyz":
             np.testing.assert_almost_equal(currentSol[var], eval(var), 4)
         result = testSolve.eval(currentSol)
         for eq in data:
             np.testing.assert_almost_equal(data[eq], result[eq], 4)
     def test_chisq(self):
-        x = 1.
-        d = {'x*y':1, '.5*x*y+.5*x*y':2, 'y':1}
-        currentSol = {'x':2.3,'y':.9}
-        for i in range(5): # reducing iters prevents printing a bunch of "The exact solution is x = 0"
-            testSolve = linsolve.LinProductSolver(d, currentSol,sparse=self.sparse)
+        x = 1.0
+        d = {"x*y": 1, ".5*x*y+.5*x*y": 2, "y": 1}
+        currentSol = {"x": 2.3, "y": 0.9}
+        for i in range(
+            5
+        ):  # reducing iters prevents printing a bunch of "The exact solution is x = 0"
+            testSolve = linsolve.LinProductSolver(d, currentSol, sparse=self.sparse)
             currentSol = testSolve.solve()
         chisq = testSolve.chisq(currentSol)
-        np.testing.assert_almost_equal(chisq, .5)
+        np.testing.assert_almost_equal(chisq, 0.5)
     def test_solve_iteratively(self):
-        x = np.arange(1,31)*(1.0+1.0j); x.shape=(10,3)
-        y = np.arange(1,31)*(2.0-3.0j); y.shape=(10,3)
-        z = np.arange(1,31)*(3.0-9.0j); z.shape=(10,3)
-        w = np.arange(1,31)*(4.0+2.0j); w.shape=(10,3)
-        x_,y_,z_,w_ = list(map(np.conjugate,(x,y,z,w)))
-        expressions = ['x*y+z*w', '2*x_*y_+z*w-1.0j*z*w', '2*x*w', '1.0j*x + y*z', '-1*x*z+3*y*w*x+y', '2*w_', '2*x_ + 3*y - 4*z']
+        x = np.arange(1, 31) * (1.0 + 1.0j)
+        x.shape = (10, 3)
+        y = np.arange(1, 31) * (2.0 - 3.0j)
+        y.shape = (10, 3)
+        z = np.arange(1, 31) * (3.0 - 9.0j)
+        z.shape = (10, 3)
+        w = np.arange(1, 31) * (4.0 + 2.0j)
+        w.shape = (10, 3)
+        x_, y_, z_, w_ = list(map(np.conjugate, (x, y, z, w)))
+        expressions = [
+            "x*y+z*w",
+            "2*x_*y_+z*w-1.0j*z*w",
+            "2*x*w",
+            "1.0j*x + y*z",
+            "-1*x*z+3*y*w*x+y",
+            "2*w_",
+            "2*x_ + 3*y - 4*z",
+        ]
         data = {}
-        for ex in expressions: data[ex] = eval(ex)
-        currentSol = {'x':1.1*x, 'y': .9*y, 'z': 1.1*z, 'w':1.2*w}
-        testSolve = linsolve.LinProductSolver(data, currentSol,sparse=self.sparse)
+        for ex in expressions:
+            data[ex] = eval(ex)
+        currentSol = {"x": 1.1 * x, "y": 0.9 * y, "z": 1.1 * z, "w": 1.2 * w}
+        testSolve = linsolve.LinProductSolver(data, currentSol, sparse=self.sparse)
         meta, new_sol = testSolve.solve_iteratively()
-        for var in 'wxyz':
+        for var in "wxyz":
             np.testing.assert_almost_equal(new_sol[var], eval(var), 4)
     def test_solve_iteratively_dtype(self):
-        x = np.arange(1,31)*(1.0+1.0j); x.shape=(10,3)
-        y = np.arange(1,31)*(2.0-3.0j); y.shape=(10,3)
-        z = np.arange(1,31)*(3.0-9.0j); z.shape=(10,3)
-        w = np.arange(1,31)*(4.0+2.0j); w.shape=(10,3)
-        x_,y_,z_,w_ = list(map(np.conjugate,(x,y,z,w)))
-        expressions = ['x*y+z*w', '2*x_*y_+z*w-1.0j*z*w', '2*x*w', '1.0j*x + y*z', '-1*x*z+3*y*w*x+y', '2*w_', '2*x_ + 3*y - 4*z']
+        x = np.arange(1, 31) * (1.0 + 1.0j)
+        x.shape = (10, 3)
+        y = np.arange(1, 31) * (2.0 - 3.0j)
+        y.shape = (10, 3)
+        z = np.arange(1, 31) * (3.0 - 9.0j)
+        z.shape = (10, 3)
+        w = np.arange(1, 31) * (4.0 + 2.0j)
+        w.shape = (10, 3)
+        x_, y_, z_, w_ = list(map(np.conjugate, (x, y, z, w)))
+        expressions = [
+            "x*y+z*w",
+            "2*x_*y_+z*w-1.0j*z*w",
+            "2*x*w",
+            "1.0j*x + y*z",
+            "-1*x*z+3*y*w*x+y",
+            "2*w_",
+            "2*x_ + 3*y - 4*z",
+        ]
         data = {}
         for dtype in (np.complex128, np.complex64):
-            for ex in expressions:
+            for ex in expressions:
                 data[ex] = eval(ex).astype(dtype)
-            currentSol = {'x':1.1*x, 'y': .9*y, 'z': 1.1*z, 'w':1.2*w}
-            currentSol = {k:v.astype(dtype) for k,v in currentSol.items()}
-            testSolve = linsolve.LinProductSolver(data, currentSol,sparse=self.sparse)
+            currentSol = {"x": 1.1 * x, "y": 0.9 * y, "z": 1.1 * z, "w": 1.2 * w}
+            currentSol = {k: v.astype(dtype) for k, v in currentSol.items()}
+            testSolve = linsolve.LinProductSolver(data, currentSol, sparse=self.sparse)
             # some ridiculousness to avoid "The exact solution is x = 0" prints
             save_stdout = sys.stdout
             sys.stdout = io.StringIO()
             meta, new_sol = testSolve.solve_iteratively(conv_crit=1e-7)
             sys.stdout = save_stdout
-            for var in 'wxyz':
+            for var in "wxyz":
                 assert new_sol[var].dtype == dtype
                 np.testing.assert_almost_equal(new_sol[var], eval(var), 4)
     def test_degen_sol(self):
         # test how various solvers deal with degenerate solutions
-        x,y,z = 1.+1j, 2.+2j, 3.+3j
-        d,w = {'x*y_':x*y.conjugate(), 'x*z_':x*z.conjugate(), 'y*z_':y*z.conjugate()}, {}
-        for k in list(d.keys()): w[k] = 1.
+        x, y, z = 1.0 + 1j, 2.0 + 2j, 3.0 + 3j
+        d, w = {
+            "x*y_": x * y.conjugate(),
+            "x*z_": x * z.conjugate(),
+            "y*z_": y * z.conjugate(),
+        }, {}
+        for k in list(d.keys()):
+            w[k] = 1.0
         sol0 = {}
-        for k in 'xyz': sol0[k] = eval(k) + .01
-        ls = linsolve.LinProductSolver(d,sol0,w,sparse=self.sparse)
-        ls.prm_order = {'x':0,'y':1,'z':2}
-        for mode in ('pinv', 'lsqr'):
+        for k in "xyz":
+            sol0[k] = eval(k) + 0.01
+        ls = linsolve.LinProductSolver(d, sol0, w, sparse=self.sparse)
+        ls.prm_order = {"x": 0, "y": 1, "z": 2}
+        for mode in ("pinv", "lsqr"):
             _, sol = ls.solve_iteratively(mode=mode)
-            x,y,z = sol['x'], sol['y'], sol['z']
-            np.testing.assert_almost_equal(x*y.conjugate(), d['x*y_'], 3)
-            np.testing.assert_almost_equal(x*z.conjugate(), d['x*z_'], 3)
-            np.testing.assert_almost_equal(y*z.conjugate(), d['y*z_'], 3)
-        #self.assertRaises(np.linalg.LinAlgError, ls.solve_iteratively, mode='solve') # this fails for matrices where machine precision breaks degeneracies in system of equations
+            x, y, z = sol["x"], sol["y"], sol["z"]
+            np.testing.assert_almost_equal(x * y.conjugate(), d["x*y_"], 3)
+            np.testing.assert_almost_equal(x * z.conjugate(), d["x*z_"], 3)
+            np.testing.assert_almost_equal(y * z.conjugate(), d["y*z_"], 3)
+
 class TestLinProductSolverSparse(TestLinProductSolver):
     def setup(self):
-        self.sparse=True
+        self.sparse = True
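For orientation, the workflow these LinProductSolver tests exercise reduces to a handful of calls. The sketch below is illustrative only: the values, starting guesses, and comments are not taken from this changeset, and it uses just the calls the tests above already demonstrate.

    import linsolve

    # Unknowns x, y, z are observed only through pairwise products.
    x, y, z = 1.0, 2.0, 3.0
    data = {"x*y": x * y, "x*z": x * z, "y*z": y * z}
    sol0 = {"x": 1.1, "y": 2.1, "z": 3.1}  # rough starting guess to linearize around
    ls = linsolve.LinProductSolver(data, sol0)
    meta, sol = ls.solve_iteratively()  # re-linearizes and solves until convergence
    # sol["x"], sol["y"], sol["z"] recover roughly 1.0, 2.0, 3.0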