The original neighbor implementation does not deal with the negative part of the symmetry matrix.

I need to find a way to get the neighbors within the negative part of the spectrum so that we can perform the message passing mechanism where neighbors from a given point are aggregated via some differentiable function.

The roadmap would be:
 1. Find neihgbors on the negative part of the symmetry spectrum
 2. Make sure that the order of the neighbors is the same as the flattened array corresponding to the negative spectrum, i.e. the index j corresponds to the j-th flattened pixel and the set of neighbors of that same j-th frequency component

In [119]:
from tomoSegmentPipeline.utils.common import read_array, write_array
from tomoSegmentPipeline.utils import setup

from mwReconstruction.dataloader import destripeDataSet
from mwReconstruction.model import *

import numpy as np
from operator import itemgetter
import matplotlib.pyplot as plt
import random
import mrcfile
import pandas as pd
import torch
from torch.utils.data import Dataset
import os
from glob import glob
import scipy.io as io
# from skimage.metrics import structural_similarity as ssim
# from skimage.metrics import normalized_mutual_information as nmi
from scipy import ndimage
from joblib import Parallel, delayed
from tqdm import tqdm

PARENT_PATH = setup.PARENT_PATH
ISONET_PATH = os.path.join(PARENT_PATH, 'data/isoNet/')

%matplotlib inline
%config Completer.use_jedi = False
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
def make_shell(inner_radius, delta_r, tomo_shape):
    """
    Creates a (2D) shell with given inner_radius and delta_r width centered at the middle of the array.
    
    """
    outer_radius = inner_radius + delta_r

    length = min(tomo_shape)
    mask_shape = len(tomo_shape) * [length]
    _shell_mask = np.zeros(mask_shape)

    # only do positive quadrant first
    for y in range(0, outer_radius + 1):
        for x in range(0, outer_radius + 1):

            r = np.linalg.norm([y, x])

            if r >= inner_radius and r < outer_radius:
                yidx = y + length // 2
                xidx = x + length // 2

                _shell_mask[yidx, xidx] = 1

    # first get shell for x>0
    aux = (
        np.rot90(_shell_mask, axes=(0, 1))
        + np.rot90(_shell_mask, 2, axes=(0, 1))
        + np.rot90(_shell_mask, 3, axes=(0, 1))
    )
    aux2 = _shell_mask + aux

    # finally, fill the actual shape of the tomogram with the mask
    shell_mask = np.zeros(tomo_shape)
    shell_mask[
        (tomo_shape[0] - length) // 2 : (tomo_shape[0] + length) // 2,
        (tomo_shape[1] - length) // 2 : (tomo_shape[1] + length) // 2,
    ] = aux2

    return shell_mask

In [3]:
def make_N_neg_matrix(dr, neg_mask, power_mask):
    """
    Returns a pd.DataFrame of shape (N negative components, 65) with the flattened negative frequency components as the index.
    
    The columns correspond to the index value (own neighbor) and 64 randomly choosen, "uncorrupted", neighbors (NaN whenever no neighbor is present) 
    within a ring of radius 3 of each index value. 
    
    All values are given according to the flattened arrays. We assume that we use ZX slices of YZX images (missing wedge in ZX).
    
    - neg_mask: boolean array with 1 wherever the symmatrix equals -1. Shape: (Z,X)
    - power_mask: boolean array indicating low power coefficients. Shape: (Z,X)
    """
    tomo_shape = neg_mask.shape

    neg_neighbors = []

    def get_neighbors(inner_radius):
        # get masks
        ring_mask = make_shell(inner_radius, dr, tomo_shape)
        ring_uncorrupted_mask = (1 - power_mask) * ring_mask
        ring_neg_mask = ring_mask * neg_mask

        # get neighbors
        ring_neg_nghbrs = np.nonzero(ring_neg_mask.flatten())[0]
        ring_uncorrupted_nghbrs = np.nonzero(ring_uncorrupted_mask.flatten())[0]

        k = min(len(ring_uncorrupted_nghbrs), 64)

        # for each negative neighbor, get a random sample of size k from the uncorrupted neighbors. The first neighbor of a point is itself.
        aux = pd.DataFrame(
            [
                np.append(
                    n,
                    np.random.choice(
                        ring_uncorrupted_nghbrs[ring_uncorrupted_nghbrs != n], k
                    ),
                )
                for n in ring_neg_nghbrs
            ],
            index=ring_neg_nghbrs,
        )

        return aux

    neg_neighbors = Parallel(n_jobs=8)(
        delayed(get_neighbors)(inner_radius)
        for inner_radius in np.arange(0, min(tomo_shape) // 2 - dr, dr)
    )
    N_neg = pd.concat(neg_neighbors)

    # make a dummy dataframe with all indices corresponding to negative entries from the symmatrix
    aux = np.nonzero(neg_mask.flatten())[0]
    aux = pd.DataFrame(aux, index=aux)

    # get the final data frame with all negative flattened indices
    all_N_neg = N_neg.join(aux, how="right", rsuffix="_r")
    all_N_neg["0"] = all_N_neg["0_r"]
    all_N_neg.drop("0_r", axis=1, inplace=True)
    all_N_neg.columns = range(65)
    all_N_neg = all_N_neg.sort_index()

    return all_N_neg


# Mapping negative frequency components

We want to find a mapping of the indices from the flattened array versions of the original matrix and the negative component part.

In [4]:
dummy_symmatrix = make_symmatrix(10, 10)
dummy_symmatrix

array([[-1, -1, -1, -1, -1, -1,  1,  1,  1,  1],
       [-1, -1, -1, -1, -1, -1,  1,  1,  1,  1],
       [-1, -1, -1, -1, -1, -1,  1,  1,  1,  1],
       [-1, -1, -1, -1, -1, -1,  1,  1,  1,  1],
       [-1, -1, -1, -1, -1, -1,  1,  1,  1,  1],
       [-1, -1, -1, -1, -1,  0,  1,  1,  1,  1],
       [-1, -1, -1, -1, -1,  1,  1,  1,  1,  1],
       [-1, -1, -1, -1, -1,  1,  1,  1,  1,  1],
       [-1, -1, -1, -1, -1,  1,  1,  1,  1,  1],
       [-1, -1, -1, -1, -1,  1,  1,  1,  1,  1]])

In [5]:
dummy_symmask = torch.from_numpy(dummy_symmatrix).float()

_, global_to_neg_mapping = torch.sort(
            torch.cat(
                [
                    torch.where(dummy_symmask.reshape(-1) == index)[0]
                    for index in [-1, 0, 1]
                ]
            )
        )

In [6]:
global_to_neg_mapping = torch.where(dummy_symmask.flatten()==-1)[0]

_mapping = zip(global_to_neg_mapping.numpy(), range(len(global_to_neg_mapping)))
global_to_neg_mapping = dict(_mapping)

In [7]:
global_to_neg_mapping

{0: 0,
 1: 1,
 2: 2,
 3: 3,
 4: 4,
 5: 5,
 10: 6,
 11: 7,
 12: 8,
 13: 9,
 14: 10,
 15: 11,
 20: 12,
 21: 13,
 22: 14,
 23: 15,
 24: 16,
 25: 17,
 30: 18,
 31: 19,
 32: 20,
 33: 21,
 34: 22,
 35: 23,
 40: 24,
 41: 25,
 42: 26,
 43: 27,
 44: 28,
 45: 29,
 50: 30,
 51: 31,
 52: 32,
 53: 33,
 54: 34,
 60: 35,
 61: 36,
 62: 37,
 63: 38,
 64: 39,
 70: 40,
 71: 41,
 72: 42,
 73: 43,
 74: 44,
 80: 45,
 81: 46,
 82: 47,
 83: 48,
 84: 49,
 90: 50,
 91: 51,
 92: 52,
 93: 53,
 94: 54}

In [17]:
dummy_powerMask = np.zeros_like(dummy_symmatrix)
neg_mask = (dummy_symmatrix==-1).astype(int)
neg_mask

array([[1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
       [1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
       [1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
       [1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
       [1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
       [1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
       [1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
       [1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
       [1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
       [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]])

In [104]:
tomo_shape = neg_mask.shape

global_to_neg_mapping = np.where(neg_mask.flatten()==1)[0]
_mapping = zip(global_to_neg_mapping, range(len(global_to_neg_mapping)))
global_to_neg_mapping = dict(_mapping)

neg_neighbors = []

def get_neighbors(inner_radius):
    # get masks
    ring_mask = make_shell(inner_radius, dr, tomo_shape)
    ring_neg_mask = ring_mask * neg_mask
    ring_uncorrupted_mask = (1 - dummy_powerMask) * ring_neg_mask
    
    # get neighbors
    ring_neg_nghbrs = np.nonzero(ring_neg_mask.flatten())[0]
    ring_uncorrupted_nghbrs = np.nonzero(ring_uncorrupted_mask.flatten())[0]

    k = min(len(ring_uncorrupted_nghbrs), n_neighbors)

#     def retrieve_mapped_random_neighbors(n, k):
#         from operator import itemgetter
#         np.random.seed(2022)
#         aux = np.random.choice(ring_uncorrupted_nghbrs[ring_uncorrupted_nghbrs!=n], k, replace=False)
#         return itemgetter(*aux)(global_to_neg_mapping)


#     # for each negative neighbor, get a random sample of size k from the uncorrupted neighbors. The first neighbor of a point is itself.
#     aux = pd.DataFrame([np.append(global_to_neg_mapping[n],
#                                   retrieve_mapped_random_neighbors(n, k))
#                         for n in ring_neg_nghbrs])

    def retrieve_mapped_random_neighbors(n, k):
        neighbor_pool = ring_uncorrupted_nghbrs[ring_uncorrupted_nghbrs!=n]
        neighbor_global_indices = np.random.choice(neighbor_pool, k, replace=False)
        own_neighbor = global_to_neg_mapping[n]
        neighbors = np.append(own_neighbor, itemgetter(*neighbor_global_indices)(global_to_neg_mapping))         
        return neighbors


    # for each corrupted negative neighbor, get a random sample of size k from the uncorrupted neighbors. 
    # The first neighbor of a point is itself.
    # Finally map global indices to negative indices in a way that is consistent with the GCN logic
    aux = pd.DataFrame([retrieve_mapped_random_neighbors(n, k) for n in ring_corrupted_nghbrs], index=ring_corrupted_nghbrs)

    return aux

In [109]:
ring_mask = make_shell(3, 2, tomo_shape=neg_mask.shape)
ring_neg_mask = ring_mask * neg_mask
ring_uncorrupted_mask = (1 - dummy_powerMask) * ring_neg_mask
ring_corrupted_mask = (dummy_powerMask) * ring_neg_mask


# get neighbors
ring_neg_nghbrs = np.nonzero(ring_neg_mask.flatten())[0]
ring_uncorrupted_nghbrs = np.nonzero(ring_uncorrupted_mask.flatten())[0]
ring_corrupted_nghbrs = np.nonzero(ring_corrupted_mask.flatten())[0]

In [114]:
ring_corrupted_nghbrs

array([], dtype=int64)

In [111]:
dr = 2
n_neighbors = 5

In [112]:
neg_neighbors = Parallel(n_jobs=8)(
    delayed(get_neighbors)(inner_radius)
    for inner_radius in np.arange(0, min(tomo_shape) // 2 - dr, dr)
)

In [113]:
neg_neighbors

[Empty DataFrame
 Columns: []
 Index: [],
 Empty DataFrame
 Columns: []
 Index: []]

In [41]:
N_neg = pd.concat(neg_neighbors)

# make a dummy dataframe with all indices corresponding to negative entries from the symmatrix
aux = len(np.nonzero(neg_mask.flatten())[0])
aux = pd.DataFrame(range(aux), index=range(aux))

# get the final data frame with all negative flattened indices
all_N_neg = N_neg.join(aux, on=0, how='right', rsuffix='_r')
all_N_neg['0'] = all_N_neg['0_r']
all_N_neg.index = all_N_neg['key_0'].values

all_N_neg.drop(['0_r', 'key_0'], axis=1, inplace=True)

all_N_neg.columns = range(6)
all_N_neg = all_N_neg.sort_index()

all_N_neg

Unnamed: 0,0,1,2,3,4,5
0,0,,,,,
1,1,,,,,
2,2,,,,,
3,3,,,,,
4,4,,,,,
5,5,,,,,
6,6,,,,,
7,7,,,,,
8,8,47.0,13.0,13.0,36.0,41.0
9,9,20.0,49.0,48.0,47.0,10.0


In [204]:
all_N_neg.head()

Unnamed: 0,0,1,2,3,4,5
0,0,,,,,
1,1,,,,,
2,2,49.0,12.0,35.0,9.0,47.0
3,3,30.0,9.0,11.0,30.0,35.0
4,4,53.0,52.0,41.0,18.0,48.0


In [120]:
def make_N_neg_matrix(dr, neg_mask, power_mask, n_neighbors=64):
    """
    Returns a pd.DataFrame of shape (N negative components, 1+n_neighbors) with the set of uncorrupted neighbors of 
    each negative frequency component indexed based on the flattened array of the negative part of the frequency spectrum.
    
    The columns correspond to the index value (own neighbor) and n_neighbors randomly choosen, "uncorrupted", neighbors (NaN whenever no neighbor is present) 
    within a ring of radius dr of each index value. 
    
    All values are given according to the flattened arrays. We assume that we use ZX slices of YZX images (missing wedge in ZX).
    
    - neg_mask: boolean array with 1 wherever the symmatrix equals -1. Shape: (Z,X)
    - power_mask: boolean array indicating low power coefficients. Shape: (Z,X)
    """
    tomo_shape = neg_mask.shape
    
    global_to_neg_mapping = np.where(neg_mask.flatten()==1)[0]
    _mapping = zip(global_to_neg_mapping, range(len(global_to_neg_mapping)))
    global_to_neg_mapping = dict(_mapping)

    neg_neighbors = []

    def get_neighbors(inner_radius):
        "Get uncorrupted neighbors within a ring"
        ################# get masks
        ring_mask = make_shell(inner_radius, dr, tomo_shape)
        ring_neg_mask = ring_mask * neg_mask
        # we only sample neighbors from the uncorrupted, negative part of the spectrum
        ring_uncorrupted_mask = (1 - power_mask) * ring_neg_mask
        ring_corrupted_mask = power_mask * ring_neg_mask

        ################# get neighbors
        # these sets are still based on the full image mapping
        # ring_neg_nghbrs = np.nonzero(ring_neg_mask.flatten())[0]
        ring_uncorrupted_nghbrs = np.nonzero(ring_uncorrupted_mask.flatten())[0]
        ring_corrupted_nghbrs = np.nonzero(ring_neg_mask.flatten())[0] #####################

        k = min(len(ring_uncorrupted_nghbrs)-1, n_neighbors)

        def retrieve_mapped_random_neighbors(n, k):
            neighbor_pool = ring_uncorrupted_nghbrs[ring_uncorrupted_nghbrs!=n]
            neighbor_global_indices = np.random.choice(neighbor_pool, k, replace=False)
            own_neighbor = global_to_neg_mapping[n]
            neighbors = np.append(own_neighbor, itemgetter(*neighbor_global_indices)(global_to_neg_mapping))         
            return neighbors


        # for each corrupted negative neighbor, get a random sample of size k from the uncorrupted neighbors. 
        # The first neighbor of a point is itself.
        # Finally map global indices to negative indices in a way that is consistent with the GCN logic
        aux = pd.DataFrame([retrieve_mapped_random_neighbors(n, k) for n in ring_corrupted_nghbrs], index=ring_corrupted_nghbrs)

        return aux

    neg_neighbors = Parallel(n_jobs=12)(
        delayed(get_neighbors)(inner_radius)
        for inner_radius in np.arange(0, min(tomo_shape) // 2 - dr, dr)
    )
    N_neg = pd.concat(neg_neighbors)

    # make a dummy dataframe with all indices corresponding to negative entries from the symmatrix
    aux = len(np.nonzero(neg_mask.flatten())[0])
    aux = pd.DataFrame(range(aux), index=range(aux))

    # get the final data frame with all negative flattened indices
    # Note: there are 2 reasons why an index might only be its own neighbor
    # 1) It is noncorrupted
    # 2) It has no uncorrupted neighbors
    N_neg = N_neg.join(aux, on=0, how='right', rsuffix='_r')
    N_neg['0'] = N_neg['0_r']
    N_neg.index = N_neg['key_0'].values

    N_neg.drop(['0_r', 'key_0'], axis=1, inplace=True)

    N_neg.columns = range(n_neighbors+1)
    N_neg = N_neg.sort_index()

    return N_neg

In [121]:
all_N_neg = make_N_neg_matrix(2, neg_mask, dummy_powerMask, 5)

In [122]:
all_N_neg.head()

Unnamed: 0,0,1,2,3,4,5
0,0,,,,,
1,1,,,,,
2,2,,,,,
3,3,,,,,
4,4,,,,,


In [133]:
dummy_x_neg = torch.rand(len(global_to_neg_mapping))
dummy_x_neg = torch.cat((dummy_x_neg, torch.tensor([0])))
dummy_x_neg

tensor([0.1074, 0.6619, 0.7425, 0.9371, 0.6984, 0.2325, 0.1373, 0.3756, 0.3253,
        0.2943, 0.2727, 0.4966, 0.8201, 0.1907, 0.7363, 0.3069, 0.4525, 0.4501,
        0.1470, 0.4819, 0.5082, 0.9958, 0.6989, 0.1234, 0.8892, 0.8084, 0.3648,
        0.2810, 0.2413, 0.0018, 0.8153, 0.9247, 0.0641, 0.5524, 0.0139, 0.7009,
        0.9528, 0.5692, 0.8098, 0.9850, 0.3136, 0.9310, 0.2921, 0.0617, 0.9734,
        0.0046, 0.5753, 0.6290, 0.7082, 0.0216, 0.2623, 0.9593, 0.4072, 0.2168,
        0.8892, 0.0000])

In [136]:
message_tensor = torch.tensor(all_N_neg.fillna(-1).to_numpy()).long()
message_tensor = dummy_x_neg[message_tensor]

final_x_neg = message_tensor.sum(1)

In [None]:
final_x_neg.shape

torch.Size([55])

In [None]:
_, mask_ind = torch.sort(
            torch.cat(
                [
                    torch.where(dummy_symmask.reshape(-1) == index)[0]
                    for index in [-1, 0, 1]
                ]
            )
        )

mask_ind

tensor([ 0,  1,  2,  3,  4,  5, 56, 57, 58, 59,  6,  7,  8,  9, 10, 11, 60, 61,
        62, 63, 12, 13, 14, 15, 16, 17, 64, 65, 66, 67, 18, 19, 20, 21, 22, 23,
        68, 69, 70, 71, 24, 25, 26, 27, 28, 29, 72, 73, 74, 75, 30, 31, 32, 33,
        34, 55, 76, 77, 78, 79, 35, 36, 37, 38, 39, 80, 81, 82, 83, 84, 40, 41,
        42, 43, 44, 85, 86, 87, 88, 89, 45, 46, 47, 48, 49, 90, 91, 92, 93, 94,
        50, 51, 52, 53, 54, 95, 96, 97, 98, 99])

In [None]:
final_x_neg

tensor([0.1074, 0.6619, 0.7425, 0.9371, 0.6984, 0.2325, 0.1373, 0.3756, 2.5719,
        3.9020, 2.8912, 3.7590, 0.8201, 3.1525, 2.4226, 2.6774, 3.5717, 2.2314,
        0.1470, 3.7014, 2.9701, 3.2478, 3.1834, 2.6148, 0.8892, 3.9609, 3.0724,
        2.1683, 2.6257, 2.2149, 0.8153, 3.0745, 3.6261, 2.4863, 2.6150, 0.7009,
        3.7412, 2.7562, 4.2832, 3.3289, 0.3136, 3.9927, 3.2993, 1.9825, 2.8375,
        0.0046, 0.5753, 2.8013, 2.1493, 2.5817, 0.2623, 0.9593, 0.4072, 0.2168,
        0.8892])

In [None]:
mask = torch.where(dummy_symmask==-1)

new_img = dummy_symmask.clone()
new_img[mask] = final_x_neg

new_img

tensor([[0.1074, 0.6619, 0.7425, 0.9371, 0.6984, 0.2325, 1.0000, 1.0000, 1.0000,
         1.0000],
        [0.1373, 0.3756, 2.5719, 3.9020, 2.8912, 3.7590, 1.0000, 1.0000, 1.0000,
         1.0000],
        [0.8201, 3.1525, 2.4226, 2.6774, 3.5717, 2.2314, 1.0000, 1.0000, 1.0000,
         1.0000],
        [0.1470, 3.7014, 2.9701, 3.2478, 3.1834, 2.6148, 1.0000, 1.0000, 1.0000,
         1.0000],
        [0.8892, 3.9609, 3.0724, 2.1683, 2.6257, 2.2149, 1.0000, 1.0000, 1.0000,
         1.0000],
        [0.8153, 3.0745, 3.6261, 2.4863, 2.6150, 0.0000, 1.0000, 1.0000, 1.0000,
         1.0000],
        [0.7009, 3.7412, 2.7562, 4.2832, 3.3289, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000],
        [0.3136, 3.9927, 3.2993, 1.9825, 2.8375, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000],
        [0.0046, 0.5753, 2.8013, 2.1493, 2.5817, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000],
        [0.2623, 0.9593, 0.4072, 0.2168, 0.8892, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000]])

In [None]:
new_img.shape

torch.Size([10, 10])