Skip to content

Commit

Permalink
Merge branch 'consolidate' into develop
Browse files Browse the repository at this point in the history
  • Loading branch information
dalonsoa committed Oct 8, 2020
2 parents 73788ce + 11ecf90 commit 1ff5e8c
Show file tree
Hide file tree
Showing 5 changed files with 184 additions and 51 deletions.
7 changes: 7 additions & 0 deletions strainmap/gui/strain_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __init__(self, root, controller):
self.effective_disp = tk.BooleanVar(value=True)
self.resample = tk.BooleanVar(value=True)
self.gls = (tk.StringVar(), tk.StringVar(), tk.StringVar())
self.timeshift_var = tk.DoubleVar(value=0.0)

# Figure-related variables
self.fig = None
Expand Down Expand Up @@ -105,6 +106,8 @@ def create_controls(self):
resample = ttk.Checkbutton(
master=strain_frame, text="Resample RR", variable=self.resample
)
timeshift_lbl = ttk.Label(strain_frame, text="Time shift (s):")
timeshift = ttk.Entry(strain_frame, textvariable=self.timeshift_var)
recalc = ttk.Button(
master=strain_frame, text="Recalculate strain", command=self.recalculate
)
Expand Down Expand Up @@ -146,6 +149,8 @@ def create_controls(self):
effective.grid(row=0, column=1, sticky=tk.NSEW, padx=5)
resample.grid(row=1, column=1, sticky=tk.NSEW, padx=5)
recalc.grid(row=2, column=0, columnspan=2, sticky=tk.NSEW, padx=5)
timeshift_lbl.grid(row=0, column=2, sticky=tk.NSEW, padx=5)
timeshift.grid(row=1, column=2, sticky=tk.NSEW, padx=5)
self.output_frame.grid(row=0, column=2, rowspan=3, sticky=tk.NSEW, padx=5)
for i, l in enumerate(self.gls_lbl):
l.grid(row=i, column=99, sticky=tk.NSEW, padx=5)
Expand All @@ -158,6 +163,7 @@ def dataset_changed(self, *args):
"""Updates the view when the selected dataset is changed."""
current = self.datasets_var.get()
self.images = self.data.data_files.mag(current)
self.timeshift_var.set(self.data.timeshift)
if self.data.strain.get(current):
self.update_strain_list(current)
else:
Expand Down Expand Up @@ -345,6 +351,7 @@ def calculate_strain(self, recalculate=False):
effective_displacement=self.effective_disp.get(),
resample=self.resample.get(),
recalculate=recalculate,
timeshift=self.timeshift_var.get(),
)
lbl = ("psGLS", "esGLS", "pGLS")
for i, v in enumerate(self.gls):
Expand Down
7 changes: 7 additions & 0 deletions strainmap/models/readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,13 @@ def read_h5_file(stored: Tuple, filename: Union[Path, Text]) -> dict:
for s in stored:
if s == "sign_reversal":
attributes[s] = tuple(sm_file[s][...])
elif s == "timeshift":
# TODO Simplify in the final version. Current design "heals" existing files
if s in sm_file.attrs:
attributes[s] = sm_file.attrs[s]
elif s in sm_file:
del sm_file[s]
continue
elif "files" in s and s in sm_file:
base_dir = paths_from_hdf5(defaultdict(dict), filename, sm_file[s])
if base_dir is None:
Expand Down
190 changes: 148 additions & 42 deletions strainmap/models/strain.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, Text, Tuple, Callable
from typing import Dict, Text, Tuple, Callable, Optional
from itertools import product
from functools import partial
from collections import defaultdict
Expand Down Expand Up @@ -209,6 +209,7 @@ def coordinates(
nang: int = 24,
background: str = "Estimated",
resample=True,
use_frame_zero=True,
) -> np.ndarray:

rkey = f"radial x{nrad} - {background}"
Expand All @@ -220,6 +221,7 @@ def coordinates(
origin_iter = (data.zero_angle[d][..., 1] for d in datasets)
px_size = data.data_files.pixel_size(datasets[0])
t_iter = tuple((data.data_files.time_interval(d) for d in datasets))
ts = data.timeshift

# z_loc should be increasing with the dataset
z_loc = -z_loc if all(np.diff(z_loc) < 0) else z_loc
Expand All @@ -243,6 +245,17 @@ def to_cylindrical(mask, theta0, origin):
* z_loc[(...,) + (None,) * (len(in_plane[:, :1].shape) - 1)]
)
result = np.concatenate((out_plane, in_plane), axis=1).transpose((1, 2, 0, 3, 4))

# We shift the coordinates
for i, t in enumerate(t_iter):
result[:, :, i, ...] = shift_data(
result[:, :, i, ...], time_interval=t, timeshift=ts, axis=1
)
# We pick just the first element, representing the pixel locations at time zero.
# Yeah, this function is an massive overkill and will be changed asap.
if use_frame_zero:
result[...] = result[:, :1, ...]

return resample_interval(result, t_iter) if resample else result


Expand All @@ -265,6 +278,7 @@ def displacement(
cyl_iter = (data.masks[d][vkey] for d in datasets)
m_iter = (data.masks[d][rkey] + 100 * data.masks[d][akey] for d in datasets)
t_iter = tuple((data.data_files.time_interval(d) for d in datasets))
ts = data.timeshift
reduced_vel_map = map(partial(masked_reduction, axis=img_axis), cyl_iter, m_iter)

# Create a mask to define the regions over which to calculate the background
Expand All @@ -290,6 +304,9 @@ def displacement(
# The signs of the in-plane displacement are reversed to be consistent with ECHO
disp[-1] = np.concatenate((vlong[None, ...], -disp[-1]), axis=0)

# We shift the data to the correct time
disp[-1] = shift_data(disp[-1], time_interval=t, timeshift=ts, axis=1)

disp = np.asarray(disp)

result = np.cumsum(disp, axis=2).transpose((1, 2, 0, 3, 4))
Expand Down Expand Up @@ -351,38 +368,17 @@ def unresample_interval(
return np.moveaxis(np.array([f(tt) for tt, f in zip(t, fdisp)]), 0, 2)


def reconstruct_strain(
strain,
masks,
datasets,
resample: bool,
interval: tuple,
nrad: int = 3,
nang: int = 24,
background: str = "Estimated",
):
rkey = f"radial x{nrad} - {background}"
akey = f"angular x{nang} - {background}"
m_iter = (masks[d][rkey] + 100 * masks[d][akey] for d in datasets)
frames = masks[datasets[0]][rkey].shape[0]

strain = unresample_interval(strain, interval, frames) if resample else strain
return (
masked_expansion(s, m, axis=(2, 3))
for s, m in zip(strain.transpose((2, 0, 1, 3, 4)), m_iter)
)


def calculate_strain(
data: StrainMapData,
datasets: Tuple[str, ...],
callback: Callable = terminal,
effective_displacement=True,
resample=True,
recalculate=False,
timeshift: Optional[float] = None,
):
"""Calculates the strain and updates the Data object with the result."""
steps = 6.0
steps = 7.0
# Do we need to calculate the strain?
if all([d in data.strain.keys() for d in datasets]) and not recalculate:
return
Expand All @@ -401,34 +397,38 @@ def calculate_strain(
callback("Insufficient datasets to calculate strain. At least 2 are needed.")
return 1

callback("Preparing dependent variables", 1 / steps)
if timeshift is not None:
data.timeshift = timeshift
data.save(["timeshift"])

callback("Calculating displacement", 1 / steps)
disp = displacement(
data,
sorted_datasets,
effective_displacement=effective_displacement,
resample=resample,
)

callback("Preparing independent variables", 2 / steps)
callback("Calculating coordinates", 2 / steps)
space = coordinates(data, sorted_datasets, resample=resample)

callback("Calculating twist", 2 / steps)
callback("Calculating twist", 3 / steps)
data.twist = twist(data, sorted_datasets)

callback("Calculating derivatives", 3 / steps)
callback("Calculating strain", 4 / steps)
reduced_strain = differentiate(disp, space)

callback("Calculating the regional strains", 4 / steps)
strain = reconstruct_strain(
callback("Calculating the regional strains", 5 / steps)
data.strain = calculate_regional_strain(
reduced_strain,
data.masks,
sorted_datasets,
resample=resample,
interval=tuple((data.data_files.time_interval(d) for d in datasets)),
timeshift=data.timeshift,
)
data.strain = calculate_regional_strain(strain, data.masks, sorted_datasets)

callback("Calculating markers", 5 / steps)
callback("Calculating markers", 6 / steps)
for d in datasets:
labels = [
s
Expand Down Expand Up @@ -520,21 +520,70 @@ def finite_differences(f, x, axis=0, period=None):
return np.moveaxis(result, 0, axis)


def calculate_regional_strain(strain, masks, datasets) -> Dict:
def calculate_regional_strain(
reduced_strain: np.ndarray,
masks: dict,
datasets: tuple,
resample: bool,
interval: tuple,
timeshift: float,
nrad: int = 3,
nang: int = 24,
lreg: int = 6,
):
"""Calculate the regional strains (1D curves)."""
from strainmap.models.contour_mask import masked_means

vkey = "cylindrical - Estimated"
gkey = "global - Estimated"
akey = "angular x6 - Estimated"
a24key = "angular x24 - Estimated"
rkey = "radial x3 - Estimated"

data_shape = masks[datasets[0]][vkey].shape
m_iter = (masks[d][rkey] + 100 * masks[d][a24key] for d in datasets)

strain = (
unresample_interval(reduced_strain, interval, data_shape[1])
if resample
else reduced_strain
)

# Mask to define the 6x angular and 3x radial regional masks for the reduced strain
treg = nrad * nang
lmask = (
np.ceil(np.arange(1, treg + 1) / treg * lreg)
.reshape((nang, nrad))
.T[None, None, ...]
)
rmask = (
np.arange(1, nrad + 1)[None, None, :, None] * np.ones(nang)[None, None, None, :]
)

result: Dict[Text, Dict[str, np.ndarray]] = defaultdict(dict)
for d, s in zip(datasets, strain):
result[d][gkey] = masked_means(s, masks[d][gkey], axes=(2, 3)) * 100
result[d][akey] = masked_means(s, masks[d][akey], axes=(2, 3)) * 100
result[d][rkey] = masked_means(s, masks[d][rkey], axes=(2, 3)) * 100
result[d][vkey] = s
vars = zip(datasets, strain.transpose((2, 0, 1, 3, 4)), m_iter, interval)
for d, s, m, t in vars:

# When calculating the regional strains from the reduced strain, we need the
# superpixel area. This has to be shifted to match the times of the strain.
rm = shift_data(
superpixel_area(m, data_shape, axis=(2, 3)), t, timeshift, axis=1
)

# Global and regional strains are calculated by modifying the relevant weights
result[d][gkey] = np.average(s, weights=rm, axis=(2, 3))[None, ...] * 100
result[d][akey] = np.stack(
np.average(s, weights=rm * (lmask == i), axis=(2, 3)) * 100
for i in range(1, lreg + 1)
)
result[d][rkey] = np.stack(
np.average(s, weights=rm * (rmask == i), axis=(2, 3)) * 100
for i in range(1, nrad + 1)
)

# To match the strain with the masks, we shift the strain in the opposite
# direction
result[d][vkey] = masked_expansion(
shift_data(s, t, -timeshift, axis=1), m, axis=(2, 3)
)

return result

Expand Down Expand Up @@ -592,7 +641,11 @@ def initialise_markers(data: StrainMapData, dataset: str, str_labels: list):
In a healthy patient, the three markers should be roughly at the same position.
"""
pos_es = int(data.markers[dataset]["global - Estimated"][0, 1, 3, 0])
# The location of the ES marker is shifted by an approximate number of frames
pos_es = int(
data.markers[dataset]["global - Estimated"][0, 1, 3, 0]
- round(data.timeshift / data.data_files.time_interval(dataset))
)

# Loop over the region types (global, angular, etc)
for r in data.strain[dataset].keys():
Expand Down Expand Up @@ -666,7 +719,9 @@ def twist(
cyl_iter = (data.masks[d][vkey] for d in datasets)
m_iter = (data.masks[d][rkey] + 100 * data.masks[d][akey] for d in datasets)
reduced_vel_map = map(partial(masked_reduction, axis=img_axis), cyl_iter, m_iter)
radius = coordinates(data, datasets, resample=False)[1].mean(axis=(2, 3))
radius = coordinates(data, datasets, resample=False, use_frame_zero=False)[1].mean(
axis=(2, 3)
)

vels = (
np.array([v[2].mean(axis=(1, 2)) - v[2].mean() for v in reduced_vel_map])
Expand All @@ -678,3 +733,54 @@ def twist(
coords={"dataset": datasets, "item": ["angular_velocity", "radius"]},
values=np.stack((vels, radius.T), axis=-1),
)


def shift_data(
data: np.ndarray, time_interval: float, timeshift: float, axis: int = 0
) -> np.ndarray:
"""Interpolates the data to account for a timeshift correction."""
time = np.arange(-1, data.shape[axis] + 1)
d = np.moveaxis(data, axis, 0)
d = np.concatenate([d[-1:], d, d[:1]], axis=0)

shift_frames = int(round(timeshift / time_interval))
remainder = timeshift - time_interval * shift_frames
new_time = np.arange(data.shape[axis]) + remainder
new_data = np.roll(
interpolate.interp1d(time, d, axis=0)(new_time), -shift_frames, axis=0,
)
return np.moveaxis(new_data, 0, axis)


def superpixel_area(masks: np.ndarray, data_shape: tuple, axis: tuple) -> np.ndarray:
from functools import reduce

assert data_shape[-len(masks.shape) :] == masks.shape

mask_max = masks.max()
nrad, nang = mask_max % 100, mask_max // 100
nz = np.nonzero(masks)
xmin, xmax, ymin, ymax = (
nz[-2].min(),
nz[-2].max() + 1,
nz[-1].min(),
nz[-1].max() + 1,
)
smasks = masks[..., xmin : xmax + 1, ymin : ymax + 1]

shape = [s for i, s in enumerate(data_shape) if i not in axis] + [nrad, nang]
reduced = np.zeros(shape, dtype=int)

tile_shape = (
(data_shape[0],) + (1,) * len(masks.shape)
if data_shape != masks.shape
else (1,) * len(masks.shape)
)

def reduction(red, idx):
elements = tuple([...] + [k - 1 for k in idx])
i = idx[0] + 100 * idx[1]
red[elements] = np.tile(smasks == i, tile_shape).sum(axis=axis).data
return red

return reduce(reduction, product(range(1, nrad + 1), range(1, nang + 1)), reduced)
Loading

0 comments on commit 1ff5e8c

Please sign in to comment.