Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 96 additions & 0 deletions src/aspire/abinitio/commonline_cn.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,3 +324,99 @@ def generate_cand_rots_third_rows(self, legacy=True):
third_rows[i] = x, y, z

return third_rows


class VeeOuterProductEstimator:
"""
Incrementally accumulate outer product entries of unknown conjugation.
"""

# These arrays are small enough to just use doubles.
# Then we can probably avoid numerical summing concerns without precomputing denom
dtype = np.float64

# conjugation
J = np.array([[0, 0, -1], [0, 0, -1], [-1, -1, 0]], dtype=np.float64)

# Create a mask selecting elements unchanged by J
mask = J == 0
mask_inverse = ~mask

def __init__(self):
# Create storage for non_negative (index 0) and negative_entries (index 1)
self.V_estimates = np.zeros((2, 3, 3), dtype=self.dtype)
self.counts = np.zeros((2, 3, 3), dtype=int)
# Might as well gather the second moment for var in case you need it later
self.V_estimates_moment2 = self.V_estimates.copy()

def push(self, V):
"""
Given V, accumulate entries into two running averages.
"""

self.V_estimates[:, self.mask] += V[self.mask]

# Parens are important here
non_negative_entries = (V >= 0) & self.mask_inverse
negative_entries = (V < 0) & self.mask_inverse

self.V_estimates[0][non_negative_entries] += V[non_negative_entries]
self.V_estimates[1][negative_entries] += V[negative_entries]

self.counts[:, self.mask] += 1
self.counts[0][non_negative_entries] += 1
self.counts[1][negative_entries] += 1

self.V_estimates_moment2[..., self.mask] += V[self.mask] ** 2
self.V_estimates_moment2[0][non_negative_entries] += (
V[non_negative_entries] ** 2
)
self.V_estimates_moment2[1][negative_entries] += V[negative_entries] ** 2

def mean(self):
"""
Running mean.
"""
# note double sum and double count for `mask` elements cancel out
return np.sum(self.V_estimates, axis=0) / np.sum(self.counts, axis=0)

def second_moment(self):
"""
Running second moment.
"""
# note double sum and double count for `mask` elements cancel out
return np.sum(self.V_estimates_moment2, axis=0) / np.sum(self.counts, axis=0)

def variance(self):
"""
Running variance.
"""
return self.second_moment() - self.mean() ** 2

def median_sign_mean_estimate(self):
"""
Return the mean for the group of entries in V containing
the median value.

Seperately computes running metrics (mean) for the group of
non_negative and negative entries. Keeps seperate counts
so we can compute an effective median sign estimate.

"""

# Find whether non negative or negative had the most entries
# This should effectively give the the group which has the same
# sign as median.
# Note on tie this code will return non_negative.
# Technically the effective median would be mean(group_means) in that case,
# but I don't think that logic is necessary yet. If needed we can add easily.
group_ind = np.argmax(self.counts, axis=0)
group_sum = np.take_along_axis(self.V_estimates, group_ind[np.newaxis], axis=0)
group_count = np.take_along_axis(self.counts, group_ind[np.newaxis], axis=0)
group_mean = group_sum / group_count
# group_moment2 = np.take_along_axis(
# self.V_estimates_moment2, group_ind[np.newaxis], axis=0
# )
# group_var = group_moment2 / group_count - group_mean**2 # might be interesting...

return group_mean
60 changes: 60 additions & 0 deletions tests/test_orient_symmetric.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from numpy.linalg import det, norm

from aspire.abinitio import CLSymmetryC3C4, CLSymmetryCn
from aspire.abinitio.commonline_cn import VeeOuterProductEstimator
from aspire.source import Simulation
from aspire.utils import Rotation
from aspire.utils.coor_trans import (
Expand Down Expand Up @@ -519,3 +520,62 @@ def buildOuterProducts(n_img, dtype):
viis[i] = np.outer(gt_vis[i], gt_vis[i])

return vijs, viis, gt_vis


def test_vee_estimator_simple():
"""
Manully run VeeOuterProductEstimator for prebaked inputs.
"""

est = VeeOuterProductEstimator()

est.push(np.full((3, 3), -2, dtype=np.float64))
est.push(np.full((3, 3), 2, dtype=np.float64))

assert np.allclose(est.mean(), np.full((3, 3), 0, dtype=np.float64))
assert np.allclose(est.variance(), np.full((3, 3), 4, dtype=np.float64))
assert np.allclose(
est.median_sign_mean_estimate(), np.array([[0, 0, 2], [0, 0, 2], [2, 2, 0]])
)

est.push(np.full((3, 3), -2, dtype=np.float64))
est.push(np.full((3, 3), -2, dtype=np.float64))
assert np.allclose(
est.median_sign_mean_estimate(),
np.array([[-1, -1, -2], [-1, -1, -2], [-2, -2, -1]]),
)


def test_vee_estimator_stat():
"""
Tests incremental VeeOuterProductEstimator using random data,
comparing to global numpy arithmetic.
"""

est = VeeOuterProductEstimator()

n = 1000
# Mix of pos and negative centers
centers = np.array([(i % 2) * 2 - 1 for i in range(1, 10)])
V = np.array([np.random.normal(loc=c, scale=4, size=n) for c in centers])
V = V.reshape(3, 3, n)

for v in np.transpose(V, (2, 0, 1)):
est.push(v)

assert np.allclose(est.mean(), np.mean(V, axis=2))
assert np.allclose(est.variance(), np.var(V, axis=2))

res = np.empty((3, 3))
# Find the mean of the entries matching sign of median
for i, j in [(0, 2), (1, 2), (2, 0), (2, 1)]:
entries = V[i, j]
group_selection = np.sign(np.median(entries)) == np.sign(entries)
res[i, j] = np.mean(entries[group_selection])

# These entries should match a global mean (unaffected by J)
for i, j in [(0, 0), (0, 1), (1, 0), (1, 1), (2, 2)]:
entries = V[i, j]
res[i, j] = np.mean(entries)

assert np.allclose(est.median_sign_mean_estimate(), res)