Skip to content

Commit

Permalink
Merge pull request #250 from carterbox/restructure-tests
Browse files Browse the repository at this point in the history
TST: Refactor reconstruction tests into smaller modules
  • Loading branch information
carterbox committed Jan 17, 2023
2 parents a085c85 + fcfc1ff commit 462c5e0
Show file tree
Hide file tree
Showing 8 changed files with 736 additions and 606 deletions.
3 changes: 3 additions & 0 deletions tests/ptycho/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
import matplotlib

matplotlib.use('Agg')
129 changes: 129 additions & 0 deletions tests/ptycho/io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import warnings
import os

import numpy as np
import tike.view

test_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))

result_dir = os.path.join(test_dir, 'result', 'ptycho')

data_dir = os.path.join(test_dir, 'data')


def _save_eigen_probe(output_folder, eigen_probe):
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
flattened = []
for i in range(eigen_probe.shape[-4]):
probe = eigen_probe[..., i, :, :, :]
flattened.append(
np.concatenate(
probe.reshape((-1, *probe.shape[-2:])),
axis=1,
))
flattened = np.concatenate(flattened, axis=0)
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=UserWarning)
plt.imsave(
f'{output_folder}/eigen-phase.png',
np.angle(flattened),
# The output of np.angle is locked to (-pi, pi]
cmap=plt.cm.twilight,
vmin=-np.pi,
vmax=np.pi,
)
plt.imsave(
f'{output_folder}/eigen-ampli.png',
np.abs(flattened),
)


def _save_probe(output_folder, probe, algorithm):
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
flattened = np.concatenate(
probe.reshape((-1, *probe.shape[-2:])),
axis=1,
)
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=UserWarning)
plt.imsave(
f'{output_folder}/probe-phase.png',
np.angle(flattened),
# The output of np.angle is locked to (-pi, pi]
cmap=plt.cm.twilight,
vmin=-np.pi,
vmax=np.pi,
)
plt.imsave(
f'{output_folder}/probe-ampli.png',
np.abs(flattened),
)
f = plt.figure()
tike.view.plot_probe_power(probe)
plt.semilogy()
plt.title(algorithm)
plt.savefig(f'{output_folder}/probe-power.svg')
plt.close(f)


def _save_ptycho_result(result, algorithm):
if result is None:
return
try:
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
import tike.view
fname = os.path.join(result_dir, f'{algorithm}')
os.makedirs(fname, exist_ok=True)

fig = plt.figure()
ax1, ax2 = tike.view.plot_cost_convergence(
result.algorithm_options.costs,
result.algorithm_options.times,
)
ax2.set_xlim(0, 20)
ax1.set_ylim(10**(-1), 10**2)
fig.suptitle(algorithm)
fig.tight_layout()
plt.savefig(os.path.join(fname, 'convergence.svg'))
plt.close(fig)
plt.imsave(
f'{fname}/{0}-phase.png',
np.angle(result.psi).astype('float32'),
# The output of np.angle is locked to (-pi, pi]
cmap=plt.cm.twilight,
vmin=-np.pi,
vmax=np.pi,
)
plt.imsave(
f'{fname}/{0}-ampli.png',
np.abs(result.psi).astype('float32'),
)
import tifffile
tifffile.imwrite(
f'{fname}/{0}-ampli.tiff',
np.abs(result.psi).astype('float32'),
)
_save_probe(fname, result.probe, algorithm)
if result.eigen_weights is not None:
_save_eigen_weights(fname, result.eigen_weights)
if result.eigen_weights.shape[-2] > 1:
_save_eigen_probe(fname, result.eigen_probe)
except ImportError:
pass


def _save_eigen_weights(fname, weights):
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
plt.figure()
tike.view.plot_eigen_weights(weights)
plt.suptitle('weights')
plt.tight_layout()
plt.savefig(f'{fname}/weights.svg')
86 changes: 86 additions & 0 deletions tests/ptycho/templates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import os
import bz2
import typing

import numpy as np
import cupy as cp

from .io import data_dir

import tike.ptycho
import tike.communicators


class SiemensStarSetup():
"""Implements a setUp function which loads the siemens start dataset."""

def setUp(self, filename='siemens-star-small.npz.bz2'):
"""Load a dataset for reconstruction."""
dataset_file = os.path.join(data_dir, filename)
with bz2.open(dataset_file, 'rb') as f:
archive = np.load(f)
self.scan = archive['scan'][0]
self.data = archive['data'][0]
self.probe = archive['probe'][0]
self.scan -= np.amin(self.scan, axis=-2) - 20
self.probe = tike.ptycho.probe.add_modes_cartesian_hermite(
self.probe, 5)
self.probe = tike.ptycho.probe.adjust_probe_power(self.probe)
self.probe = tike.ptycho.probe.orthogonalize_eig(self.probe)

with tike.communicators.Comm(1, mpi=tike.communicators.MPIComm) as comm:
mask = tike.cluster.by_scan_stripes(
self.scan,
n=comm.mpi.size,
fly=1,
axis=0,
)[comm.mpi.rank]
self.scan = self.scan[mask]
self.data = self.data[mask]

self.psi = np.full(
(600, 600),
dtype=np.complex64,
fill_value=np.complex64(0.5 + 0j),
)


try:
from mpi4py import MPI
_mpi_size = MPI.COMM_WORLD.Get_size()
_mpi_rank = MPI.COMM_WORLD.Get_rank()
except ModuleNotFoundError:
_mpi_size = 1
_mpi_rank = 0

_device_per_rank = cp.cuda.runtime.getDeviceCount() // _mpi_size
_base_device = _device_per_rank * _mpi_rank
_gpu_indices = tuple(i + _base_device for i in range(_device_per_rank))


class MPIAndGPUInfo():
"""Provides mpi rank and gpu index information."""

mpi_size: int = _mpi_size
mpi_rank: int = _mpi_rank
gpu_indices: typing.Tuple[int] = _gpu_indices


class ReconstructTwice(MPIAndGPUInfo):
"""Call tike.ptycho reconstruct twice in a loop."""

def template_consistent_algorithm(self, *, data, params):
"""Check ptycho.solver.algorithm for consistency."""
with cp.cuda.Device(self.gpu_indices[0]):
# Call twice to check that reconstruction continuation is correct
for _ in range(2):
params = tike.ptycho.reconstruct(
data=data,
parameters=params,
num_gpu=self.gpu_indices,
use_mpi=self.mpi_size > 1,
)

print()
print('\n'.join(f'{c[0]:1.3e}' for c in params.algorithm_options.costs))
return params
126 changes: 107 additions & 19 deletions tests/ptycho/test_multigrid.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,40 @@

import os.path
import bz2
import os

import cupy as cp
import matplotlib.pyplot as plt
import numpy as np
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
import pytest
import unittest

from tike.ptycho.solvers.options import _resize_fft, _resize_spline, _resize_cubic, _resize_lanczos, _resize_linear
import tike.ptycho
from tike.ptycho.solvers.options import (
_resize_fft,
_resize_spline,
_resize_cubic,
_resize_lanczos,
_resize_linear,
)

testdir = os.path.dirname(os.path.dirname(__file__))
output_folder = os.path.join(testdir, 'result', 'ptycho', 'multigrid')
from .templates import _mpi_size
from .io import result_dir, data_dir
from .test_ptycho import PtychoRecon

@pytest.mark.parametrize(
"function",[
_resize_fft,
_resize_spline,
_resize_linear,
_resize_cubic,
_resize_lanczos,
]
)
def test_resample(function, filename='data/siemens-star-small.npz.bz2'):
output_folder = os.path.join(result_dir, 'multigrid')


@pytest.mark.parametrize("function", [
_resize_fft,
_resize_spline,
_resize_linear,
_resize_cubic,
_resize_lanczos,
])
def test_resample(function, filename='siemens-star-small.npz.bz2'):

os.makedirs(output_folder, exist_ok=True)

dataset_file = os.path.join(testdir, filename)
dataset_file = os.path.join(data_dir, filename)
with bz2.open(dataset_file, 'rb') as f:
archive = np.load(f)
probe = archive['probe'][0]
Expand All @@ -44,3 +53,82 @@ def test_resample(function, filename='data/siemens-star-small.npz.bz2'):
f'{output_folder}/{function.__name__}-probe-phase-{i}.png',
np.angle(flattened),
)


@unittest.skipIf(
_mpi_size > 1,
reason="MPI not implemented for multi-grid.",
)
class ReconMultiGrid():
"""Test ptychography multi-grid reconstruction method."""

def interp(self, x, f):
pass

def template_consistent_algorithm(self, *, data, params):
"""Check ptycho.solver.algorithm for consistency."""
if _mpi_size > 1:
raise NotImplementedError()

with cp.cuda.Device(self.gpu_indices[0]):
parameters = tike.ptycho.reconstruct_multigrid(
parameters=params,
data=self.data,
num_gpu=self.gpu_indices,
use_mpi=self.mpi_size > 1,
num_levels=2,
interp=self.interp,
)

print()
print('\n'.join(
f'{c[0]:1.3e}' for c in parameters.algorithm_options.costs))
return parameters


class TestPtychoReconMultiGridFFT(
ReconMultiGrid,
PtychoRecon,
unittest.TestCase,
):

post_name = '-multigrid-fft'

def interp(self, x, f):
return _resize_fft(x, f)


if False:
# Don't need to run these tests on CI every time.

class TestPtychoReconMultiGridLinear(PtychoReconMultiGrid, TestPtychoRecon,
unittest.TestCase):

post_name = '-multigrid-linear'

def interp(self, x, f):
return _resize_linear(x, f)

class TestPtychoReconMultiGridCubic(PtychoReconMultiGrid, TestPtychoRecon,
unittest.TestCase):

post_name = '-multigrid-cubic'

def interp(self, x, f):
return _resize_cubic(x, f)

class TestPtychoReconMultiGridLanczos(PtychoReconMultiGrid, TestPtychoRecon,
unittest.TestCase):

post_name = '-multigrid-lanczos'

def interp(self, x, f):
return _resize_lanczos(x, f)

class TestPtychoReconMultiGridSpline(PtychoReconMultiGrid, TestPtychoRecon,
unittest.TestCase):

post_name = '-multigrid-spline'

def interp(self, x, f):
return _resize_spline(x, f)

0 comments on commit 462c5e0

Please sign in to comment.