Skip to content

Commit

Permalink
yes
Browse files Browse the repository at this point in the history
  • Loading branch information
OverLordGoldDragon committed May 31, 2022
1 parent c26ac33 commit ca0140a
Show file tree
Hide file tree
Showing 9 changed files with 82 additions and 18 deletions.
2 changes: 2 additions & 0 deletions tests/scattering1d/test_jtfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1795,6 +1795,8 @@ def test_implementation():
jtfs = TimeFrequencyScattering1D(shape=N, J=4, Q=2,
implementation=implementation,
frontend=default_backend)
assert jtfs.implementation == implementation, (
jtfs.implementation, implementation)
_ = jtfs(x)


Expand Down
80 changes: 68 additions & 12 deletions tests/scattering1d/test_visuals.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
# -----------------------------------------------------------------------------
"""Test that wavespin/visuals.py methods run without error."""
import pytest, os, warnings
import numpy as np
from copy import deepcopy

from wavespin import Scattering1D, TimeFrequencyScattering1D
from wavespin.toolkit import echirp, pack_coeffs_jtfs
from wavespin import visuals as v
Expand Down Expand Up @@ -42,7 +45,7 @@ def make_reusables():
if skip_all:
return None if run_without_pytest else pytest.skip()
N = 512
kw0 = dict(shape=N, J=9, Q=8, frontend=default_backend)
kw0 = dict(shape=N, T=2**8, J=9, Q=8, frontend=default_backend)
sc_tms.extend([Scattering1D(**kw0, out_type='array')])

sfs = [('resample', 'resample'), ('exclude', 'resample'),
Expand Down Expand Up @@ -103,7 +106,18 @@ def test_viz_jtfs_2d(G):
return None if run_without_pytest else pytest.skip()
jtfss = G['jtfss']
out_jtfss = G['out_jtfss']
v.viz_jtfs_2d(jtfss[1], Scx=out_jtfss[1])

# without save
# _ = v.viz_jtfs_2d(jtfss[1], Scx=out_jtfss[1], show=0,
# plot_cfg={'filter_label': True, 'phi_t_loc': 'both'})

# with save
base = 'viz_jtfs2d'
fn = lambda savedir: v.viz_jtfs_2d(
jtfss[1], Scx=out_jtfss[1], show=1, savename=os.path.join(savedir, base),
plot_cfg={'filter_part': 'imag', 'filterbank_zoom': -1})
# name changes internally
_run_with_cleanup(fn, [base + '0.png', base + '1.png'])


def test_gif_jtfs_2d(G):
Expand All @@ -125,13 +139,17 @@ def test_gif_jtfs_3d(G):
warnings.warn("Skipped `test_gif_jtfs_3d` since `plotly` not installed.")
return

out_jtfss, metas = G['out_jtfss'], G['metas']
jtfss, out_jtfss, metas = G['jtfss'], G['out_jtfss'], G['metas']
packed = pack_coeffs_jtfs(out_jtfss[1], metas[2], structure=2,
sampling_psi_fr='exclude')

savename = 'jtfs3d.gif'
fn = lambda savedir: v.gif_jtfs_3d(packed, savedir=savedir, base_name=savename,
images_ext='.png', verbose=False)
kw = dict(base_name=savename, images_ext='.png', verbose=0)

fn = lambda savedir: v.gif_jtfs_3d(packed, savedir=savedir, **kw)
_run_with_cleanup_handle_exception(fn, savename)
fn = lambda savedir: v.gif_jtfs_3d(out_jtfss[1], jtfss[1], angles='rotate',
savedir=savedir, **kw)
_run_with_cleanup_handle_exception(fn, savename)


Expand Down Expand Up @@ -167,6 +185,35 @@ def test_coeff_distance_jtfs(G):
raise e


def test_compare_distances_jtfs(G):
if skip_all:
return None if run_without_pytest else pytest.skip()
out_jtfss = G['out_jtfss']
Scx0, Scx1 = out_jtfss[0], deepcopy(out_jtfss[0])
for pair in Scx1:
Scx1[pair] += 1

dists0 = v.coeff_distance_jtfs(Scx0, Scx1, metas[1])[1]
dists1 = v.coeff_distance_jtfs(Scx1, Scx0, metas[1], plots=1)[1]
_ = v.compare_distances_jtfs(dists0, dists1, plots=1)


def test_scalogram(G):
if skip_all:
return None if run_without_pytest else pytest.skip()
sc_tm = G['sc_tms'][0]
sc_tm.average = False
sc_tm.out_type = 'list'
_ = v.scalogram(np.random.randn(sc_tm.shape), sc_tm, show_x=1, fs=1)


def test_misc(G):
if skip_all:
return None if run_without_pytest else pytest.skip()
_ = v.plot([1, 2], xticks=[0, 1], yticks=[0, 1], show=0)
_ = v._colorize_complex(np.array([[1 + 1j]]))


def test_viz_spin_1d(G):
if skip_all:
return None if run_without_pytest else pytest.skip()
Expand All @@ -188,19 +235,24 @@ def test_viz_spin_2d(G):
def _run_with_cleanup(fn, savename):
if skip_all:
return None if run_without_pytest else pytest.skip()
if not isinstance(savename, list):
savename = [savename]

with tempdir() as savedir:
try:
fn(savedir)
path = os.path.join(savedir, savename)
# assert file was created
assert os.path.isfile(path), path
os.unlink(path)
for nm in savename:
path = os.path.join(savedir, nm)
assert os.path.isfile(path), path
os.unlink(path)
finally:
# clean up images, if any were made
paths = [os.path.join(savedir, n) for n in os.listdir(savedir)
if (n.startswith(savename) and n.endswith('.png'))]
for p in paths:
os.unlink(p)
for nm in savename:
paths = [os.path.join(savedir, n) for n in os.listdir(savedir)
if (n.startswith(nm) and n.endswith('.png'))]
for p in paths:
os.unlink(p)


def _run_with_cleanup_handle_exception(fn, savename):
Expand All @@ -210,6 +262,7 @@ def _run_with_cleanup_handle_exception(fn, savename):
if 'ffmpeg' not in str(e):
# automated testing has some issues with this
raise e
warnings.warn("Ignored error:\n%s" % str(e))


# create testing objects #####################################################
Expand Down Expand Up @@ -244,6 +297,9 @@ def G():
test_gif_jtfs_3d(G)
test_energy_profile_jtfs(G)
test_coeff_distance_jtfs(G)
test_compare_distances_jtfs(G)
test_scalogram(G)
test_misc(G)
test_viz_spin_1d(G)
test_viz_spin_2d(G)
else:
Expand Down
2 changes: 1 addition & 1 deletion wavespin/scattering1d/backend/numpy_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def subsample_fourier(cls, x, k, axis=-1):
The input tensor periodized along the next to last axis to yield a
tensor of size x.shape[-2] // k along that dimension.
"""
if k == 0:
if k == 1:
return x
cls.complex_check(x)

Expand Down
2 changes: 2 additions & 0 deletions wavespin/scattering1d/backend/tensorflow_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ def subsample_fourier(cls, x, k, axis=-1):
The input tensor periodized along the next to last axis to yield a
tensor of size x.shape[-2] // k along that dimension.
"""
if k == 1:
return x
cls.complex_check(x)

axis = axis if axis >= 0 else x.ndim + axis # ensure positive
Expand Down
2 changes: 2 additions & 0 deletions wavespin/scattering1d/backend/torch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ def subsample_fourier(cls, x, k, axis=-1):
The input tensor periodized along the next to last axis to yield a
tensor of size x.shape[-2] // k along that dimension.
"""
if k == 1:
return x
cls.complex_check(x)

axis = axis if axis >= 0 else x.ndim + axis # ensure positive
Expand Down
2 changes: 1 addition & 1 deletion wavespin/scattering1d/frontend/numpy_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def __init__(self, J, shape, Q, J_fr=None, Q_fr=2, T=None, F=None,

# Frequential scattering object
TimeFrequencyScatteringBase1D.__init__(
self, J_fr, Q_fr, F, average_fr, out_type, **kwargs)
self, J_fr, Q_fr, F, average_fr, out_type, implementation, **kwargs)
TimeFrequencyScatteringBase1D.build(self)

def scattering(self, x, Tx=None):
Expand Down
2 changes: 1 addition & 1 deletion wavespin/scattering1d/frontend/tensorflow_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def __init__(self, J, shape, Q, J_fr=None, Q_fr=2, T=None, F=None,

# Frequential scattering object
TimeFrequencyScatteringBase1D.__init__(
self, J_fr, Q_fr, F, average_fr, out_type, **kwargs)
self, J_fr, Q_fr, F, average_fr, out_type, implementation, **kwargs)
TimeFrequencyScatteringBase1D.build(self)

def scattering(self, x, Tx=None):
Expand Down
2 changes: 1 addition & 1 deletion wavespin/scattering1d/frontend/torch_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def __init__(self, J, shape, Q, J_fr=None, Q_fr=2, T=None, F=None,

# Frequential scattering object
TimeFrequencyScatteringBase1D.__init__(
self, J_fr, Q_fr, F, average_fr, out_type, **kwargs)
self, J_fr, Q_fr, F, average_fr, out_type, implementation, **kwargs)
TimeFrequencyScatteringBase1D.build(self)

self.register_filters()
Expand Down
6 changes: 4 additions & 2 deletions wavespin/visuals.py
Original file line number Diff line number Diff line change
Expand Up @@ -1386,7 +1386,7 @@ def viz_jtfs_2d(jtfs, Scx=None, viz_filterbank=True, viz_coeffs=None,
(see `plot_cfg_defaults` in source code). Will not warn if an argument
is unused (e.g. per `viz_coeffs=False`). Supported key-values:
'phi_t_blank' : bool:
'phi_t_blank' : bool
If True, draws `phi_t * psi_f` pairs only once (since up == down).
Can't be `True` with `phi_t_loc='both'`.
Expand Down Expand Up @@ -1498,7 +1498,7 @@ def viz_jtfs_2d(jtfs, Scx=None, viz_filterbank=True, viz_coeffs=None,

# `plot_cfg`, defaults
plot_cfg_defaults = {
'phi_t_blank': True,
'phi_t_blank': None,
'phi_t_loc': 'bottom',

'filter_part': 'real',
Expand Down Expand Up @@ -1530,6 +1530,8 @@ def viz_jtfs_2d(jtfs, Scx=None, viz_filterbank=True, viz_coeffs=None,
if C['phi_t_blank']:
warnings.warn("`phi_t_blank` does nothing if `phi_t_loc='both'`")
C['phi_t_blank'] = 0
elif C['phi_t_blank'] is None:
C['phi_t_blank'] = 1

# fs
if fs is not None:
Expand Down

0 comments on commit ca0140a

Please sign in to comment.