Skip to content

Commit

Permalink
tests: small fixed to runners
Browse files Browse the repository at this point in the history
  • Loading branch information
Gattocrucco committed Mar 5, 2024
1 parent 0fa7d26 commit 96fc97d
Show file tree
Hide file tree
Showing 7 changed files with 39 additions and 22 deletions.
3 changes: 3 additions & 0 deletions TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ versions of the manual list.
Rename all the single letter example scripts and gift them a communicative
description

Search engines still find the outdated "stable" version doc on readthedocs, I
have to disable it and keep only "latest".

## Fixes and tests

Stabilize Matern kernel near r == 0, then Matern derivatives for real nu
Expand Down
33 changes: 14 additions & 19 deletions docs/runcode.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,32 +27,22 @@
import contextlib
import os
import pathlib
import warnings
import gc

import numpy as np
from matplotlib import pyplot as plt
import gvar
import pygments
from pygments import lexers, formatters
import jax
import lsqfitgp as lgp

warnings.filterwarnings('ignore', r'Negative eigenvalue with ')

def pyprint(text):
print(pygments.highlight(text, lexers.PythonLexer(), formatters.TerminalFormatter()))

@contextlib.contextmanager
def switchgvar():
try:
yield gvar.switch_gvar()
finally:
gvar.restore_gvar()

@contextlib.contextmanager
def chdir(path):
try:
cwd = pathlib.Path.cwd()
os.chdir(path)
yield
finally:
os.chdir(cwd)

pattern = re.compile(r'(?m)(?!\.\..+?)^.*?::\n\s*?\n(( {4,}.*\n)+)\s*?\n')
# TODO ^^^ try to delete this
# ^ delete
Expand All @@ -74,21 +64,26 @@ def runcode(file):
np.random.seed(0)
gvar.ranseed(0)
globals_dict = {}
with switchgvar():
with lgp.switchgvar():

# run code
for match in pattern.finditer(text):
codeblock = match.group(1)
print(58 * '-' + '\n')
code = textwrap.dedent(codeblock).strip()
printcode = '\n'.join(f' {i + 1:2d} ' + l for i, l in enumerate(code.split('\n')))
printcode = '\n'.join(
f' {i + 1:2d} ' + l
for i, l in enumerate(code.split('\n'))
)
pyprint(printcode)

with chdir(file.parent):
with contextlib.chdir(file.parent):
exec(code, globals_dict)

for file in sys.argv[1:]:
s = f'* running {file} *'
line = '*' * len(s)
print('\n' + line + '\n' + s + '\n' + line)
runcode(file)
gc.collect()
jax.clear_caches()
2 changes: 1 addition & 1 deletion docs/userguide/installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
Installation
============

First, you have to get a working Python interpreter. There are three main options: the official package from `<https://www.python.org>`_, the Anaconda distribution `<https://www.anaconda.com>`_, and the `Spyder IDE <https://www.spyder-ide.org>`_. The latter is probably the easier one if it's your first time with Python.
First, you have to get a working Python interpreter. There are three main options: the `official package <https://www.python.org>`_, the `Anaconda distribution <https://www.anaconda.com>`_, and the `Spyder IDE <https://www.spyder-ide.org>`_. The latter is probably the easier one if it's your first time with Python.

Then, install :mod:`lsqfitgp` by running this command in a shell:

Expand Down
2 changes: 1 addition & 1 deletion examples/doubleint.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
'integ' : 1,
'integx': 1,
}, ['data', 'integ', 'integx'])
priorsample = next(gvar.raniter(prior))
priorsample = gvar.sample(prior)

datamean = priorsample['data']
dataerr = np.full_like(datamean, 1)
Expand Down
3 changes: 3 additions & 0 deletions examples/runexamples.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,11 @@
from matplotlib import pyplot as plt
import gvar
import lsqfitgp as lgp
import jax

warnings.filterwarnings('ignore', r'Matplotlib is currently using agg, which is a non-GUI backend, so cannot show the figure\.')
warnings.filterwarnings('ignore', r'FigureCanvasAgg is non-interactive, and thus cannot be shown')
warnings.filterwarnings('ignore', r'Negative eigenvalue with ')

for file in sys.argv[1:]:

Expand All @@ -46,6 +48,7 @@
gvar.ranseed(0)
runpy.run_path(str(file))
gc.collect()
jax.clear_caches()

# save figures
nums = plt.get_fignums()
Expand Down
9 changes: 9 additions & 0 deletions lsqfitgp.sublime-project
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{
"folders":
[
{
"path": ".",
"folder_exclude_patterns": ["pyenv"],
}
]
}
9 changes: 8 additions & 1 deletion src/lsqfitgp/bayestree/_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@
from .. import _GP
from .. import _fastraniter

# TODO I added a lot of functionality to bcf. The easiest way to port it over is
# adding the option in bcf to drop the second bart model and its associated
# hypers, and then write bart as a simple convenience wrapper-subclass over bcf.

class bart:

def __init__(self,
Expand Down Expand Up @@ -370,6 +374,10 @@ def pred(self, *, hp='map', error=False, format='matrices', x_test=None,
out : array of `GVar`
The same distribution represented as an array of `GVar` objects.
"""

# TODO it is a bit confusing that if x_test=None and error=True, the
# prediction returns y_train exactly, instead of hypothetical new
# observations at the same covariates.

hp = self._gethp(hp, rng)
if x_test is not None:
Expand Down Expand Up @@ -414,7 +422,6 @@ def _to_structured(cls, x):

# check
assert x.ndim == 1
assert x.size > len(x.dtype)
def check_numerical(path, dtype):
if not numpy.issubdtype(dtype, numpy.number):
raise TypeError(f'covariate `{path}` is not numerical')
Expand Down

0 comments on commit 96fc97d

Please sign in to comment.