In [1]:
from tomoSegmentPipeline.utils.common import read_array, write_array
from tomoSegmentPipeline.utils import setup

import numpy as np
import matplotlib.pyplot as plt
import random
import mrcfile
import pandas as pd
import torch
from torch.utils.data import Dataset
import os
from glob import glob
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import normalized_mutual_information as nmi
from scipy import ndimage
from joblib import Parallel, delayed

PARENT_PATH = setup.PARENT_PATH
ISONET_PATH = os.path.join(PARENT_PATH, 'data/isoNet/')

%matplotlib inline
%config Completer.use_jedi = False
%load_ext autoreload
%autoreload 2

In [2]:
def clip_and_standardize(X, low=0.005, high=0.995, quantiles=True):
    if quantiles:
        X_clp = np.clip(X, np.quantile(X, low), np.quantile(X, high))
    else:
        X_clp = np.clip(X, low, high)
    X_stdz = (X_clp-np.mean(X_clp))/np.std(X_clp)
    
    return X_stdz

class destripeDataSet(Dataset):
    def __init__(self, path, l, normalize=True, logTransform=False):
        
        # TODO: add this to standalone destripe
        from tomoSegmentPipeline.utils.common import read_array
        
        # data is originally in ZYX form
        image_data = read_array(path)
        # we set data in YZX form
        image_data = image_data.transpose(1, 0, 2)
        
        if logTransform:
            image_data = np.log10(image_data)  
            
        if normalize:
            image_data = Parallel(n_jobs=1)(delayed(clip_and_standardize)(xz_plane) for xz_plane in image_data)
            image_data = np.array(image_data)
            # image_data = image_data - image_data.min()
            
        fft_img = np.array([np.fft.fftshift(np.fft.fft2(xz_plane)) for xz_plane in image_data]) # just run fft on the image information
        
        logPower_data = np.log(np.abs(fft_img)**2)
        logPower_data = Parallel(n_jobs=1)(delayed(clip_and_standardize)(xz_plane, low=0.001, high=0.999) for xz_plane in logPower_data)
        logPower_data = np.array(logPower_data)
        
        # all image data is in the form: YZX
        # data[:, 0, :, :] = image
        # data[:, 1, :, :] = outlier mask
        # data[:, 2, :, :] = weights
        data = image_data[:, np.newaxis, :, :]
        
        # data = io.loadmat('Data/simu-small-constant.mat')['datas'].transpose(2, 0, 1)[:, np.newaxis, :, :]
        #data = np.repeat(io.loadmat('Data/oneround.mat')['data'][np.newaxis, np.newaxis, :, :], 3, axis=0)          

        # get outliers according to the power criterion 
        outlier_mask = np.array([(xz_plane<np.quantile(xz_plane, 0.95)).astype(int) for xz_plane in logPower_data])
        outlier_mask = outlier_mask[:, np.newaxis, :, :] # set a consistent shape of the data
        
        # get proper shape of the weight data
        weight_matrix = logPower_data[:, np.newaxis, :, :]

        # concatenate stuff
        data = np.concatenate((data, outlier_mask), 1)  
        data = np.concatenate((data, weight_matrix), 1)   
        
        # why? Only for testing?
        data = data[l[0]:l[1], :, :, :]
        
        self.x_data = torch.from_numpy(data).float()
        self.y_data = torch.from_numpy(np.zeros((3, 1))).float()
        
    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.x_data.size(0)

In [3]:
%%timeit
cet_path = os.path.join(ISONET_PATH, 'RAW_dataset/RAW_allTomos_deconv/tomo10.mrc')    

tmp = destripeDataSet(path = cet_path, l = [0, 1])

# print(tmp.x_data.size())
# print(tmp.y_data.size())
# print(tmp.y_data)
# plt.imshow(tmp.x_data[-1, 0, :, :])

1min 13s ± 789 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
