In [None]:
import cv2
import matplotlib.pyplot as plt
import numpy as np

from cryoet.data.functional import normalize_volume_to_unit_range
from cryoet.data.parsers import get_volume_and_objects, read_annotated_volume
from cryoet.data.parsers import visualize_slices_grid


In [None]:
# Dataset location and the `mode` string forwarded to read_annotated_volume
# (presumably selects which tomogram variant to load, e.g. "denoised").
root, mode = "./data/czii-cryo-et-object-identification", "denoised"

In [None]:
# Load two annotated runs: `sample1` is used below as the paste target and
# `sample2` as the source of pasted objects.
sample1, sample2 = (read_annotated_volume(root, run_id, mode) for run_id in ("TS_5_4", "TS_6_4"))


In [None]:
# fig

In [None]:
from cryoet.data.augmentations.functional import copy_paste_augmentation

In [None]:
def compute_weighted_matrix(volume, sigma=5.0):
    """
    Build a radial blending-weight map with the same shape as ``volume``.

    Voxels closer than ``0.8 * sigma`` to the geometric centre of the array get
    weight 1.0. Outside that core the weight decays as ``exp(-d / sigma**2)``
    with the distance ``d`` from the centre. The outside values are then
    min-max rescaled to [0, 1]. As a result, they are 1.0 at the outside voxels
    nearest the core and 0.0 at the farthest voxels (the corners).

    Parameters
    ----------
    volume : np.ndarray
        3-D array; only its ``shape`` is used.
    sigma : float
        Sets both the radius of the unit-weight core (``0.8 * sigma``) and the
        decay rate outside it. Must be positive.

    Returns
    -------
    np.ndarray
        float64 array of shape ``volume.shape`` with values in [0, 1].

    Raises
    ------
    ValueError
        If ``sigma`` is not positive.
    """
    if sigma <= 0:
        raise ValueError(f"sigma must be positive, got {sigma}")

    # Voxel centres sit at integer indices 0..n-1, so the geometric centre is
    # (n - 1) / 2. Using n / 2 shifts the peak by half a voxel and makes the
    # weight map asymmetric.
    center = (np.asarray(volume.shape, dtype=np.float64) - 1.0) / 2.0

    # Open grids broadcast to the full shape without materialising three
    # full-size coordinate arrays (as np.meshgrid would).
    i, j, k = np.ogrid[: volume.shape[0], : volume.shape[1], : volume.shape[2]]
    distances = np.sqrt((i - center[0]) ** 2 + (j - center[1]) ** 2 + (k - center[2]) ** 2)

    # NOTE: this is an exponential fall-off in d (not a Gaussian in d**2).
    weight = np.exp(-distances / (sigma**2))

    core = distances < sigma * 0.8
    weight[core] = 1.0

    outside = ~core
    if outside.any():  # the core may cover the whole volume for large sigma
        tail = weight[outside]
        tail -= tail.min()
        tail_max = tail.max()
        # All outside voxels equidistant -> nothing to rescale (avoid 0/0 = NaN).
        if tail_max > 0:
            tail /= tail_max
        weight[outside] = tail
    return weight

In [None]:
# Preview the blending weight on a 31^3 cube at several z-slices, from the
# border slice (z=0) toward the centre slice (z=15).
weight = compute_weighted_matrix(np.zeros((31, 31, 31)), sigma=15)

preview_slices = (0, 5, 10, 15)
f, ax = plt.subplots(1, len(preview_slices), figsize=(20, 5))
for panel, z_index in zip(ax, preview_slices):
    # Fixed colour limits: imshow otherwise rescales each panel to its own
    # min/max, which hides differences in peak weight between slices
    # (e.g. the z=0 border slice stays below 1.0 here).
    image = panel.imshow(weight[z_index], vmin=0.0, vmax=1.0)
    panel.set_title(f"z = {z_index}")
f.colorbar(image, ax=ax)
f

In [None]:
np.max(weight[0])  # peak weight on the outermost z-slice (below 1.0 when that slice lies outside the unit-weight core)

In [None]:
weight[15, :, :]  # full central z-slice of the weight map

In [None]:
from cryoet.data.augmentations.copy_paste_merge import merge_volume_using_weighted_sum

# Paste objects from `sample2` into a target volume several times, blending
# each paste with the weighted-sum merge. Set START_FROM_EMPTY_VOLUME = True to
# paste onto a blank (64, 128, 128) volume instead of `sample1`.
START_FROM_EMPTY_VOLUME = False
N_PASTE_ROUNDS = 4

if START_FROM_EMPTY_VOLUME:
    data = dict(
        volume=np.zeros((64, 128, 128)),
        centers=np.empty((0, 3)),
        radius=np.empty((0,)),
        labels=np.empty((0,)),
    )
else:
    data = dict(
        volume=sample1.volume,
        centers=sample1.centers_px,
        radius=sample1.radius_px,
        labels=sample1.labels,
    )

# Each round feeds the previous round's output back in, so pastes presumably
# accumulate across rounds.
# NOTE(review): no random seed is set, so if copy_paste_augmentation is
# stochastic (rotation limits suggest it is) results differ between runs.
# Also confirm it does not modify sample1.volume in place; if it does,
# re-running this cell would compound the pastes.
for _ in range(N_PASTE_ROUNDS):
    data = copy_paste_augmentation(
        **data,
        samples=[sample2],
        scale=1.0,
        z_rotation_limit=5,
        x_rotation_limit=0,
        y_rotation_limit=0,
        merge_method=merge_volume_using_weighted_sum,
    )


In [None]:
# Show the augmented volume as a grid of z-slices (only slices that contain
# objects), four panels per row. voxel_size=1.0 presumably because the centers
# and radii fed in above are already in pixels (`*_px`).
fig = visualize_slices_grid(
    **data,
    ncols=4,
    voxel_size=1.0,
    only_slices_with_objects=True,
    slices_to_show=None,
)
fig.show()