In [1]:
from dipy.io.image import load_nifti, save_nifti
import numpy as np
import matplotlib.pyplot as plt
import numpy as np
from warnings import warn
import time
from dipy.utils.optpkg import optional_package
from sklearn.utils.extmath import randomized_svd
import dipy.core.optimize as opt
import math

sklearn, has_sklearn, _ = optional_package('sklearn')
linear_model, _, _ = optional_package('sklearn.linear_model')

if not has_sklearn:
    w = "Scikit-Learn is required to denoise the data via Patch2Self."
    warn(w)
    

def _vol_split(train, vol_idx):
    """ Split the 3D volumes into the train and test set.

    Parameters
    ----------
    train : ndarray
        Array of all 3D patches flattened out to be 2D.

    vol_idx: int
        The volume number that needs to be held out for training.

    Returns
    --------
    cur_x : 2D-array (nvolumes*patch_size) x (nvoxels)
        Array of patches corresponding to all the volumes except for the
        held-out volume.

    y : 1D-array
        Array of patches corresponding to the volume that is used a target for
        denoising.
    """
    # Hold-out the target volume
    mask = np.zeros(train.shape[0])
    mask[vol_idx] = 1
    cur_x = train[mask == 0]
    cur_x = cur_x.reshape(((train.shape[0]-1)*train.shape[1],
                           train.shape[2]))

    # Center voxel of the selected block
    y = train[vol_idx, train.shape[1]//2, :]
    return cur_x, y

def count_sketch(matrixA, s):
    m, n = matrixA.shape
    matrixC = np.zeros([s, n])
    hashedIndices = np.random.choice(s, m, replace=True)
    # a m-by-1 {+1, -1} vector
    randSigns = np.random.choice(2, m, replace=True) * 2 - 1  
    
    # flip the signs of 50% rows of A
    matrixA = matrixA * randSigns.reshape(m, 1)  
    
    # this loop directly computes matrixC= S * matrixA
    for i in range(s):  
        idx = (hashedIndices == i)
        matrixC[i] = np.sum(matrixA[idx], 0)
    
    return matrixC[:, np.newaxis, :]


def _real_fft(matrixA):
    
    n_int = matrixA.shape[0]
    fft_mat = np.fft.fft(matrixA, n=None, axis=0) / np.sqrt(n_int)
    if n_int % 2 == 1:
        cutoff_int = int((n_int+1) / 2)
        idx_real_vec = list(range(1, cutoff_int))
        idx_imag_vec = list(range(cutoff_int, n_int))
    else:
        cutoff_int = int(n_int/2)
        idx_real_vec = list(range(1, cutoff_int))
        idx_imag_vec = list(range(cutoff_int+1, n_int))
    matrixC = fft_mat.real
    matrixC[idx_real_vec] *= np.sqrt(2)
    matrixC[idx_imag_vec] = fft_mat[idx_imag_vec].imag * np.sqrt(2)
    return matrixC[:, np.newaxis, :]

def deter_row_sample(matrixA, s):
    lev_scores = lev_exact(matrixA)
    idx_vec = np.argsort(lev_scores, axis=0)[::-1][range(s)]  # [lev_scores<0.003]
    matrixC = matrixA[idx_vec, :]
    return matrixC[:, np.newaxis, :]

def srft(matrixA, s):

    n_int = matrixA.shape[0]
    sign_vec = np.random.choice(2, n_int) * 2 - 1
    idx_vec = np.random.choice(n_int, s, replace=False)
    a_mat = sign_vec.reshape(n_int,1) * matrixA
    a_mat = _real_fft(matrixA)
    matrixC = matrixA[idx_vec] * np.sqrt(n_int / s)

    return matrixC[:, np.newaxis, :]

def lev_approx(matrixA, lev_sketch_type, lev_sketch_size=5):
    
    m, n = matrixA.shape
    s = int(n * lev_sketch_size)
    
    if lev_sketch_type == 'countsketch':
        matrixB = np.squeeze(count_sketch(matrixA, s))
        
    elif lev_sketch_type == 'srft':
        matrixB = np.squeeze(srft(matrixA, s))
        
    elif lev_sketch_type == 'uniform':
        idx_vec = np.random.choice(m, s, replace=False)
        matrixB = matrixA[idx_vec] * (m / s)
        print(matrixB.shape)
    
    elif lev_sketch_type == 'exact':
        matrixB = matrixA
    
    _, S, V = np.linalg.svd(matrixB, full_matrices=False)
    
    matrixT = V.T / S
    matrixY = np.dot(matrixA, matrixT)
    
    lev_vec = np.sum(matrixY ** 2, axis=1)
    return lev_vec

def ridge_lev_approx(matrixA, alpha):
    matrixA_alpha= np.concatenate((matrixA, np.sqrt(alpha)*np.identity(matrixA.shape[1])), axis=0)
    ridge_vec=lev_approx(matrixA_alpha)
    return ridge_vec[0:matrixA.shape[0]]

def row_sample(matrixA, s, prob_vec):
    m = matrixA.shape[0]
    prob_vec /= sum(prob_vec)
    idx_vec = np.random.choice(m, s, replace=False, p=prob_vec)
    scaling_vec = np.sqrt(s * prob_vec[idx_vec]) + 1e-10
    matrixC = matrixA[idx_vec] / scaling_vec.reshape(len(scaling_vec),1)
    return matrixC[:, np.newaxis, :]

def lev_exact(a_mat, low_rank=False):
    n_int = a_mat.shape[0]
    _ , _, v_mat = np.linalg.svd(a_mat.T, full_matrices=False)
    
    if low_rank:
        _, _, v_mat = randomized_svd(a_mat.T, 
                                     n_components=20,
                                     n_iter=5,
                                     random_state=None)
    lev_vec = np.sum(v_mat ** 2, axis=0)
    return lev_vec

def uniform_sampling(matrixA, s):
    m, n = matrixA.shape
    idx_vec = np.random.choice(m, s, replace=True)
    matrixC = matrixA[idx_vec]
    return matrixC[:, np.newaxis, :]

def sketch_data(matrixA, s, sketching_method, lev_sketch_type):
    
    if sketching_method == 'srft':
        return srft(matrixA, s)
    
    if sketching_method == 'uniform':
        return uniform_sampling(matrixA, s)
    
    if sketching_method == 'countsketch':
        return count_sketch(matrixA, s)
    
    if sketching_method == 'lev_deterministic':
        return deter_row_sample(matrixA, s)
    
    if sketching_method == 'leverage_scores':
        if lev_sketch_type == 'uniform':
            leverage_scores = lev_approx(matrixA, lev_sketch_type='uniform')
            return row_sample(matrixA, s, leverage_scores)
        
        elif lev_sketch_type == 'countsketch':
            leverage_scores = lev_approx(matrixA, lev_sketch_type='countsketch')
            return row_sample(matrixA, s, leverage_scores)
        
        elif lev_sketch_type == 'srft':
            leverage_scores = lev_approx(matrixA, lev_sketch_type='srft')
            return row_sample(matrixA, s, leverage_scores)
        
        elif lev_sketch_type == 'exact':
            leverage_scores = lev_exact(matrixA)
            return row_sample(matrixA, s, leverage_scores)

def _vol_denoise(train, sketched_train, vol_idx, 
                 model, data_shape, alpha):
    """ Denoise a single 3D volume using a train and test phase.

    Parameters
    ----------
    train : ndarray
        Array of all 3D patches flattened out to be 2D.

    vol_idx : int
        The volume number that needs to be held out for training.

    model : string, or initialized linear model object.
            This will determine the algorithm used to solve the set of linear
            equations underlying this model. If it is a string it needs to be
            one of the following: {'ols', 'ridge', 'lasso'}. Otherwise,
            it can be an object that inherits from
            `dipy.optimize.SKLearnLinearSolver` or an object with a similar
            interface from Scikit-Learn:
            `sklearn.linear_model.LinearRegression`,
            `sklearn.linear_model.Lasso` or `sklearn.linear_model.Ridge`
            and other objects that inherit from `sklearn.base.RegressorMixin`.
            Default: 'ridge'.

    data_shape : ndarray
        The 4D shape of noisy DWI data to be denoised.

    alpha : float, optional
        Regularization parameter only for ridge and lasso regression models.
        default: 1.0

    Returns
    --------
    model prediction : ndarray
        Denoised array of all 3D patches flattened out to be 2D corresponding
        to the held out volume `vol_idx`.

    """
    # To add a new model, use the following API
    # We adhere to the following options as they are used for comparisons
    if model.lower() == 'ols':
        model = linear_model.LinearRegression(copy_X=False)

    elif model.lower() == 'ridge':
        model = linear_model.Ridge(copy_X=False, alpha=alpha, solver='lsqr')

    elif model.lower() == 'lasso':
        model = linear_model.Lasso(copy_X=False, max_iter=50, alpha=alpha)

    elif (isinstance(model, opt.SKLearnLinearSolver) or
          has_sklearn and isinstance(model, sklearn.base.RegressorMixin)):
        model = model

    else:
        e_s = "The `solver` key-word argument needs to be: "
        e_s += "'ols', 'ridge', 'lasso' or a "
        e_s += "`dipy.optimize.SKLearnLinearSolver` object"
        raise ValueError(e_s)

    cur_x, y = _vol_split(train, vol_idx)
    r_cur_x, r_cur_y = _vol_split(sketched_train.T, vol_idx)
    
    model.fit(r_cur_x.T, r_cur_y.T)
    coefs = model.coef_

    return model.predict(cur_x.T).reshape(data_shape[0], data_shape[1],
                                          data_shape[2]), coefs

def _extract_3d_patches(arr, patch_radius):
    """ Extract 3D patches from 4D DWI data.

    Parameters
    ----------
    arr : ndarray
        The 4D noisy DWI data to be denoised.

    patch_radius : int or 1D array
        The radius of the local patch to be taken around each voxel (in
        voxels).

    Returns
    --------
    all_patches : ndarray
        All 3D patches flattened out to be 2D corresponding to the each 3D
        volume of the 4D DWI data.

    """
    if isinstance(patch_radius, int):
        patch_radius = np.ones(3, dtype=int) * patch_radius
    if len(patch_radius) != 3:
        raise ValueError("patch_radius should have length 3")
    else:
        patch_radius = np.asarray(patch_radius, dtype=int)
    patch_size = 2 * patch_radius + 1

    dim = arr.shape[-1]

    all_patches = []

    # loop around and find the 3D patch for each direction
    for i in range(patch_radius[0], arr.shape[0] -
                   patch_radius[0], 1):
        for j in range(patch_radius[1], arr.shape[1] -
                       patch_radius[1], 1):
            for k in range(patch_radius[2], arr.shape[2] -
                           patch_radius[2], 1):

                ix1 = i - patch_radius[0]
                ix2 = i + patch_radius[0] + 1
                jx1 = j - patch_radius[1]
                jx2 = j + patch_radius[1] + 1
                kx1 = k - patch_radius[2]
                kx2 = k + patch_radius[2] + 1

                X = arr[ix1:ix2, jx1:jx2,
                        kx1:kx2].reshape(np.prod(patch_size), dim)
                
                # Random projection on the X
                all_patches.append(X)

    return np.array(all_patches).T


def patch2self(data, bvals, patch_radius=[0, 0, 0], model='ridge',
               b0_threshold=50, out_dtype=None, alpha=1.0, 
               verbose=False, sketching_method='srft', sketch_size=2000, 
               lev_sketch_type='uniform'):
    
    """ Patch2Self Denoiser

    Parameters
    ----------
    data : ndarray
        The 4D noisy DWI data to be denoised.

    bvals : 1D array
        Array of the bvals from the DWI acquisition

    patch_radius : int or 1D array, optional
        The radius of the local patch to be taken around each voxel (in
        voxels). Default: 0 (denoise in blocks of 1x1x1 voxels).

    model : string, or initialized linear model object.
            This will determine the algorithm used to solve the set of linear
            equations underlying this model. If it is a string it needs to be
            one of the following: {'ols', 'ridge', 'lasso'}. Otherwise,
            it can be an object that inherits from
            `dipy.optimize.SKLearnLinearSolver` or an object with a similar
            interface from Scikit-Learn:
            `sklearn.linear_model.LinearRegression`,
            `sklearn.linear_model.Lasso` or `sklearn.linear_model.Ridge`
            and other objects that inherit from `sklearn.base.RegressorMixin`.
            Default: 'ridge'.

    b0_threshold : int, optional
        Threshold for considering volumes as b0.

    out_dtype : str or dtype, optional
        The dtype for the output array. Default: output has the same dtype as
        the input.

    alpha : float, optional
        Regularization parameter only for ridge regression model.
        default: 1.0

    verbose : bool, optional
        Show progress of Patch2Self and time taken.

    Returns
    --------
    denoised array : ndarray
        This is the denoised array of the same size as that of the input data,
        clipped to non-negative values.

    References
    ----------

    [Fadnavis20] S. Fadnavis, J. Batson, E. Garyfallidis, Patch2Self:
                    Denoising Diffusion MRI with Self-supervised Learning,
                    Advances in Neural Information Processing Systems 33 (2020)
    """
    patch_radius = np.asarray(patch_radius, dtype=np.int)

    if not data.ndim == 4:
        raise ValueError("Patch2Self can only denoise on 4D arrays.",
                         data.shape)

    if data.shape[3] < 10:
        warn("The intput data has less than 10 3D volumes. Patch2Self may not",
             "give denoising performance.")

    if out_dtype is None:
        out_dtype = data.dtype

    # We retain float64 precision, iff the input is in this precision:
    if data.dtype == np.float64:
        calc_dtype = np.float64

    # Otherwise, we'll calculate things in float32 (saving memory)
    else:
        calc_dtype = np.float32

    # Segregates volumes by b0 threshold
    b0_idx = np.argwhere(bvals <= b0_threshold)
    dwi_idx = np.argwhere(bvals > b0_threshold)

    data_b0s = np.squeeze(np.take(data, b0_idx, axis=3))
    data_dwi = np.squeeze(np.take(data, dwi_idx, axis=3))

    # create empty arrays
    denoised_b0s = np.empty((data_b0s.shape), dtype=calc_dtype)
    denoised_dwi = np.empty((data_dwi.shape), dtype=calc_dtype)

    denoised_arr = np.empty((data.shape), dtype=calc_dtype)

    if verbose is True:
        t1 = time.time()

    # Separate denoising for DWI volumes
    train_dwi = _extract_3d_patches(np.pad(data_dwi, ((patch_radius[0],
                                                       patch_radius[0]),
                                                      (patch_radius[1],
                                                       patch_radius[1]),
                                                      (patch_radius[2],
                                                       patch_radius[2]),
                                                      (0, 0)), mode='constant'),
                                    patch_radius=patch_radius)
    
    levs = lev_exact(np.squeeze(train_dwi).T)
    sketched_train_dwi = sketch_data(np.squeeze(train_dwi).T, 
                                     sketching_method=sketching_method, 
                                     s=sketch_size, lev_sketch_type=lev_sketch_type)
    
    print(sketched_train_dwi.shape)

    # Insert the separately denoised arrays into the respective empty arrays
    coef_list = []
    for vol_idx in range(60, 61):
        denoised_dwi[..., vol_idx] = _vol_denoise(train_dwi,
                                                  sketched_train_dwi,
                                                  vol_idx, model,
                                                  data_dwi.shape,
                                                  alpha=alpha)[0]
        coefs = _vol_denoise(train_dwi,
                             sketched_train_dwi,
                             vol_idx, model,
                             data_dwi.shape,
                             alpha=alpha)[1]
        coef_list.append(coefs)

        if verbose is True:
            print("Denoised DWI Volume: ", vol_idx)

    if verbose is True:
        t2 = time.time()
        print('Total time taken for Patch2Self: ', t2-t1, " seconds")

    if data_b0s.ndim == 3:
        denoised_arr[:, :, :, b0_idx[0][0]] = denoised_b0s
    else:
        for i, idx in enumerate(b0_idx):
            denoised_arr[:, :, :, idx[0]] = np.squeeze(denoised_b0s[..., i])

    for i, idx in enumerate(dwi_idx):
        denoised_arr[:, :, :, idx[0]] = np.squeeze(denoised_dwi[..., i])

    # clip out the negative values from the denoised output
    denoised_arr.clip(min=0, out=denoised_arr)

    return np.array(denoised_arr, dtype=out_dtype), coef_list

In [2]:
data, affine = load_nifti('dwi_corr.nii.gz')
bval = np.loadtxt('corr.bval')

In [8]:
from numpy import mean

list_sketch_size = [500, 1000, 5000, 10000, 20000, 30000, 40000, 50000, 
                    60000, 70000, 80000, 100000]#, 200000, 300000]
#                     , 400000,
#                     500000, 600000, 700000, 800000, 900000]
coef_list = []
sketch_name_list = []

for j in range(0, 10):
    print('This is ITERATION # ', j)

    for i in list_sketch_size:

        _, coefs_deter = patch2self(data, bval, verbose=True, model='ols', patch_radius=[0, 0, 0],
                                  b0_threshold=50, sketch_size=i, sketching_method='lev_deterministic')
        coef_list.append(coefs_deter)
        print('Deterministic Leverage - ', i)
        sketch_name_list.append('Deterministic Leverage')

        _, coefs_uni = patch2self(data, bval, verbose=True, model='ols', patch_radius=[0, 0, 0],
                                  b0_threshold=50, sketch_size=i, sketching_method='uniform')
        coef_list.append(coefs_uni)
        print('Uniform - ', i)
        sketch_name_list.append('Uniform')

        _, coefs_lev_exact = patch2self(data, bval, verbose=True, model='ols', patch_radius=[0, 0, 0],
                                        b0_threshold=50, sketch_size=i, 
                                        sketching_method='leverage_scores', 
                                        lev_sketch_type='exact')
        coef_list.append(coefs_lev_exact)
        print('Leverage Scores: Exact - ', i)
        sketch_name_list.append('Leverage Scores: Exact')

        _, coefs_cs = patch2self(data, bval, verbose=True, model='ols', patch_radius=[0, 0, 0],
                                 b0_threshold=50, sketch_size=i, sketching_method='countsketch')
        coef_list.append(coefs_cs)
        print('CS - ', i)
        sketch_name_list.append('Count Sketching')

        _, coefs_srft = patch2self(data, bval, verbose=True, model='ols', patch_radius=[0, 0, 0],
                                   b0_threshold=50, sketch_size=i, sketching_method='srft')
        coef_list.append(coefs_srft)
        print('SRFT - ', i)
        sketch_name_list.append('Subsampled Randomized Fourier Transform')

This is ITERATION #  0
(500, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  49.5635712146759  seconds
Deterministic Leverage -  500
(500, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  28.839262008666992  seconds
Uniform -  500
(500, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  49.57374167442322  seconds
Leverage Scores: Exact -  500
(500, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  30.380274772644043  seconds
CS -  500
(500, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  37.19846820831299  seconds
SRFT -  500
(1000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  49.95524621009827  seconds
Deterministic Leverage -  1000
(1000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  29.973685026168823  seconds
Uniform -  1000
(1000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  50.12030100822449  seconds
Leverage Scores: Exact -  1000
(1000

Uniform -  1000
(1000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  50.207077980041504  seconds
Leverage Scores: Exact -  1000
(1000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  31.841331481933594  seconds
CS -  1000
(1000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  36.11418318748474  seconds
SRFT -  1000
(5000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  47.41963171958923  seconds
Deterministic Leverage -  5000
(5000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  29.54510498046875  seconds
Uniform -  5000
(5000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  43.47941303253174  seconds
Leverage Scores: Exact -  5000
(5000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  38.9629340171814  seconds
CS -  5000
(5000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  36.376044511795044  seconds
SRFT -  5000
(10000, 1, 64)
Denoised D

CS -  5000
(5000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  36.40615630149841  seconds
SRFT -  5000
(10000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  49.698490858078  seconds
Deterministic Leverage -  10000
(10000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  29.663562536239624  seconds
Uniform -  10000
(10000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  49.72949028015137  seconds
Leverage Scores: Exact -  10000
(10000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  47.641270875930786  seconds
CS -  10000
(10000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  35.67872381210327  seconds
SRFT -  10000
(20000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  47.67682433128357  seconds
Deterministic Leverage -  20000
(20000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  29.869622945785522  seconds
Uniform -  20000
(20000, 1, 6

(20000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  28.13305425643921  seconds
Uniform -  20000
(20000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  50.56255388259888  seconds
Leverage Scores: Exact -  20000
(20000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  64.33806204795837  seconds
CS -  20000
(20000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  36.276287317276  seconds
SRFT -  20000
(30000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  50.2779905796051  seconds
Deterministic Leverage -  30000
(30000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  29.652067184448242  seconds
Uniform -  30000
(30000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  50.2645468711853  seconds
Leverage Scores: Exact -  30000
(30000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  81.48917984962463  seconds
CS -  30000
(30000, 1, 64)
Denoised DWI

(30000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  82.08690786361694  seconds
CS -  30000
(30000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  36.11773109436035  seconds
SRFT -  30000
(40000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  47.902870178222656  seconds
Deterministic Leverage -  40000
(40000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  30.039796590805054  seconds
Uniform -  40000
(40000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  50.82097554206848  seconds
Leverage Scores: Exact -  40000
(40000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  99.19624209403992  seconds
CS -  40000
(40000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  37.17706751823425  seconds
SRFT -  40000
(50000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  50.00001907348633  seconds
Deterministic Leverage -  50000
(50000, 1, 64)
Denoised D

(50000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  50.47942638397217  seconds
Deterministic Leverage -  50000
(50000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  31.203718662261963  seconds
Uniform -  50000
(50000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  49.914127826690674  seconds
Leverage Scores: Exact -  50000
(50000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  117.1313681602478  seconds
CS -  50000
(50000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  36.16454362869263  seconds
SRFT -  50000
(60000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  52.20329213142395  seconds
Deterministic Leverage -  60000
(60000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  30.105817556381226  seconds
Uniform -  60000
(60000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  51.344783544540405  seconds
Leverage Scores: Exact -  60000


Uniform -  60000
(60000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  51.12195801734924  seconds
Leverage Scores: Exact -  60000
(60000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  133.8637306690216  seconds
CS -  60000
(60000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  37.06312155723572  seconds
SRFT -  60000
(70000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  50.22718262672424  seconds
Deterministic Leverage -  70000
(70000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  30.798282384872437  seconds
Uniform -  70000
(70000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  50.45554566383362  seconds
Leverage Scores: Exact -  70000
(70000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  151.39943099021912  seconds
CS -  70000
(70000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  36.99989104270935  seconds
SRFT -  70000
(80000, 

CS -  70000
(70000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  36.474218130111694  seconds
SRFT -  70000
(80000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  50.83392524719238  seconds
Deterministic Leverage -  80000
(80000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  30.380988836288452  seconds
Uniform -  80000
(80000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  50.67492079734802  seconds
Leverage Scores: Exact -  80000
(80000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  168.4823911190033  seconds
CS -  80000
(80000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  36.68785262107849  seconds
SRFT -  80000
(100000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  50.69064950942993  seconds
Deterministic Leverage -  100000
(100000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  30.786225080490112  seconds
Uniform -  100000
(10

Deterministic Leverage -  100000
(100000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  31.16933298110962  seconds
Uniform -  100000
(100000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  50.47606611251831  seconds
Leverage Scores: Exact -  100000
(100000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  204.11389803886414  seconds
CS -  100000
(100000, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  37.70287251472473  seconds
SRFT -  100000
This is ITERATION #  9
(500, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  50.13383603096008  seconds
Deterministic Leverage -  500
(500, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  30.506011962890625  seconds
Uniform -  500
(500, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  47.93022441864014  seconds
Leverage Scores: Exact -  500
(500, 1, 64)
Denoised DWI Volume:  60
Total time taken for Patch2Self:  30.99089288

In [9]:
from dipy.io.image import load_nifti, save_nifti
import numpy as np
import matplotlib.pyplot as plt
import numpy as np
from warnings import warn
import time
from dipy.utils.optpkg import optional_package
import dipy.core.optimize as opt
import math

sklearn, has_sklearn, _ = optional_package('sklearn')
linear_model, _, _ = optional_package('sklearn.linear_model')

if not has_sklearn:
    w = "Scikit-Learn is required to denoise the data via Patch2Self."
    warn(w)
    

def _vol_split(train, vol_idx):
    """ Split the 3D volumes into the train and test set.

    Parameters
    ----------
    train : ndarray
        Array of all 3D patches flattened out to be 2D.

    vol_idx: int
        The volume number that needs to be held out for training.

    Returns
    --------
    cur_x : 2D-array (nvolumes*patch_size) x (nvoxels)
        Array of patches corresponding to all the volumes except for the
        held-out volume.

    y : 1D-array
        Array of patches corresponding to the volume that is used a target for
        denoising.
    """
    # Hold-out the target volume
    mask = np.zeros(train.shape[0])
    mask[vol_idx] = 1
    cur_x = train[mask == 0]
    cur_x = cur_x.reshape(((train.shape[0]-1)*train.shape[1],
                           train.shape[2]))

    # Center voxel of the selected block
    y = train[vol_idx, train.shape[1]//2, :]
    return cur_x, y

def count_sketch(matrixA, s):
    m, n = matrixA.shape
    matrixC = np.zeros([s, n])
    hashedIndices = np.random.choice(s, m, replace=True)
    # a m-by-1 {+1, -1} vector
    randSigns = np.random.choice(2, m, replace=True) * 2 - 1  
    
    # flip the signs of 50% rows of A
    matrixA = matrixA * randSigns.reshape(m, 1)  
    
    # this loop directly computes matrixC= S * matrixA
    for i in range(s):  
        idx = (hashedIndices == i)
        matrixC[i] = np.sum(matrixA[idx], 0)
    
    return matrixC[:, np.newaxis, :]


def _real_fft(matrixA):
    
    n_int = matrixA.shape[0]
    fft_mat = np.fft.fft(matrixA, n=None, axis=0) / np.sqrt(n_int)
    if n_int % 2 == 1:
        cutoff_int = int((n_int+1) / 2)
        idx_real_vec = list(range(1, cutoff_int))
        idx_imag_vec = list(range(cutoff_int, n_int))
    else:
        cutoff_int = int(n_int/2)
        idx_real_vec = list(range(1, cutoff_int))
        idx_imag_vec = list(range(cutoff_int+1, n_int))
    matrixC = fft_mat.real
    matrixC[idx_real_vec] *= np.sqrt(2)
    matrixC[idx_imag_vec] = fft_mat[idx_imag_vec].imag * np.sqrt(2)
    return matrixC[:, np.newaxis, :]


def srft(matrixA, s):

    n_int = matrixA.shape[0]
    sign_vec = np.random.choice(2, n_int) * 2 - 1
    idx_vec = np.random.choice(n_int, s, replace=False)
    a_mat = sign_vec.reshape(n_int,1) * matrixA
    a_mat = _real_fft(matrixA)
    matrixC = matrixA[idx_vec] * np.sqrt(n_int / s)

    return matrixC[:, np.newaxis, :]

def lev_approx(matrixA, lev_sketch_size=5, lev_sketch_type='uniform'):
    m, n = matrixA.shape
    s = int(n * lev_sketch_size)
    
    if lev_sketch_type == 'countsketch':
        matrixB = count_sketch(matrixA, s)
        
    elif lev_sketch_type == 'srft':
        matrixB = srft(matrixA, s)
        
    elif lev_sketch_type == 'uniform':
        idx_vec = np.random.choice(m, s, replace=False)
        matrixB = matrixA[idx_vec] * (m / s)
        
    _, S, V = np.linalg.svd(matrixB, full_matrices=False)
    
    matrixT = V.T / S
    matrixY = np.dot(matrixA, matrixT)
    
    lev_vec = np.sum(matrixY ** 2, axis=1)
    return lev_vec

def ridge_lev_approx(matrixA, alpha):
    matrixA_alpha= np.concatenate((matrixA, np.sqrt(alpha)*np.identity(matrixA.shape[1])), axis=0)
    ridge_vec=lev_approx(matrixA_alpha)
    return ridge_vec[0:matrixA.shape[0]]

def row_sample(matrixA, s, prob_vec):
    m = matrixA.shape[0]
    prob_vec /= sum(prob_vec)
    idx_vec = np.random.choice(m, s, replace=False, p=prob_vec)
    scaling_vec = np.sqrt(s * prob_vec[idx_vec]) + 1e-10
    matrixC = matrixA[idx_vec] / scaling_vec.reshape(len(scaling_vec),1)
    return matrixC[:, np.newaxis, :]

def sketch_data(matrixA, s, sketching_method):
    
    if sketching_method == 'srft':
        return srft(matrixA, s)
    
    if sketching_method == 'countsketch':
        print(count_sketch(matrixA, s).shape)
        return count_sketch(matrixA, s)
    
    if sketching_method == 'leverage_scores':
        leverage_scores = lev_approx(matrixA)
        print(row_sample(matrixA, s, leverage_scores).shape)
        return row_sample(matrixA, s, leverage_scores)


def _vol_denoise(train, sketched_train, vol_idx, 
                 model, data_shape, alpha):
    """ Denoise a single 3D volume using a train and test phase.

    Parameters
    ----------
    train : ndarray
        Array of all 3D patches flattened out to be 2D.

    vol_idx : int
        The volume number that needs to be held out for training.

    model : string, or initialized linear model object.
            This will determine the algorithm used to solve the set of linear
            equations underlying this model. If it is a string it needs to be
            one of the following: {'ols', 'ridge', 'lasso'}. Otherwise,
            it can be an object that inherits from
            `dipy.optimize.SKLearnLinearSolver` or an object with a similar
            interface from Scikit-Learn:
            `sklearn.linear_model.LinearRegression`,
            `sklearn.linear_model.Lasso` or `sklearn.linear_model.Ridge`
            and other objects that inherit from `sklearn.base.RegressorMixin`.
            Default: 'ridge'.

    data_shape : ndarray
        The 4D shape of noisy DWI data to be denoised.

    alpha : float, optional
        Regularization parameter only for ridge and lasso regression models.
        default: 1.0

    Returns
    --------
    model prediction : ndarray
        Denoised array of all 3D patches flattened out to be 2D corresponding
        to the held out volume `vol_idx`.

    """
    # To add a new model, use the following API
    # We adhere to the following options as they are used for comparisons
    if model.lower() == 'ols':
        model = linear_model.LinearRegression(copy_X=False)

    elif model.lower() == 'ridge':
        model = linear_model.Ridge(copy_X=False, alpha=alpha, solver='lsqr')

    elif model.lower() == 'lasso':
        model = linear_model.Lasso(copy_X=False, max_iter=50, alpha=alpha)

    elif (isinstance(model, opt.SKLearnLinearSolver) or
          has_sklearn and isinstance(model, sklearn.base.RegressorMixin)):
        model = model

    else:
        e_s = "The `solver` key-word argument needs to be: "
        e_s += "'ols', 'ridge', 'lasso' or a "
        e_s += "`dipy.optimize.SKLearnLinearSolver` object"
        raise ValueError(e_s)

    cur_x, y = _vol_split(train, vol_idx)
    r_cur_x, r_cur_y = _vol_split(sketched_train.T, vol_idx)
    
    model.fit(cur_x.T, y.T)
    coefs = model.coef_

    return model.predict(cur_x.T).reshape(data_shape[0], data_shape[1],
                                          data_shape[2]), coefs

def _extract_3d_patches(arr, patch_radius):
    """ Extract 3D patches from 4D DWI data.

    Parameters
    ----------
    arr : ndarray
        The 4D noisy DWI data to be denoised.

    patch_radius : int or 1D array
        The radius of the local patch to be taken around each voxel (in
        voxels).

    Returns
    --------
    all_patches : ndarray
        All 3D patches flattened out to be 2D corresponding to the each 3D
        volume of the 4D DWI data.

    """
    if isinstance(patch_radius, int):
        patch_radius = np.ones(3, dtype=int) * patch_radius
    if len(patch_radius) != 3:
        raise ValueError("patch_radius should have length 3")
    else:
        patch_radius = np.asarray(patch_radius, dtype=int)
    patch_size = 2 * patch_radius + 1

    dim = arr.shape[-1]

    all_patches = []

    # loop around and find the 3D patch for each direction
    for i in range(patch_radius[0], arr.shape[0] -
                   patch_radius[0], 1):
        for j in range(patch_radius[1], arr.shape[1] -
                       patch_radius[1], 1):
            for k in range(patch_radius[2], arr.shape[2] -
                           patch_radius[2], 1):

                ix1 = i - patch_radius[0]
                ix2 = i + patch_radius[0] + 1
                jx1 = j - patch_radius[1]
                jx2 = j + patch_radius[1] + 1
                kx1 = k - patch_radius[2]
                kx2 = k + patch_radius[2] + 1

                X = arr[ix1:ix2, jx1:jx2,
                        kx1:kx2].reshape(np.prod(patch_size), dim)
                
                # Random projection on the X
                all_patches.append(X)

    return np.array(all_patches).T


def patch2self(data, bvals, patch_radius=[0, 0, 0], model='ridge',
               b0_threshold=50, out_dtype=None, alpha=1.0, 
               verbose=False, sketching_method='srft', sketch_size=2000):
    """ Patch2Self Denoiser

    Parameters
    ----------
    data : ndarray
        The 4D noisy DWI data to be denoised.

    bvals : 1D array
        Array of the bvals from the DWI acquisition

    patch_radius : int or 1D array, optional
        The radius of the local patch to be taken around each voxel (in
        voxels). Default: 0 (denoise in blocks of 1x1x1 voxels).

    model : string, or initialized linear model object.
            This will determine the algorithm used to solve the set of linear
            equations underlying this model. If it is a string it needs to be
            one of the following: {'ols', 'ridge', 'lasso'}. Otherwise,
            it can be an object that inherits from
            `dipy.optimize.SKLearnLinearSolver` or an object with a similar
            interface from Scikit-Learn:
            `sklearn.linear_model.LinearRegression`,
            `sklearn.linear_model.Lasso` or `sklearn.linear_model.Ridge`
            and other objects that inherit from `sklearn.base.RegressorMixin`.
            Default: 'ridge'.

    b0_threshold : int, optional
        Threshold for considering volumes as b0.

    out_dtype : str or dtype, optional
        The dtype for the output array. Default: output has the same dtype as
        the input.

    alpha : float, optional
        Regularization parameter only for ridge regression model.
        default: 1.0

    verbose : bool, optional
        Show progress of Patch2Self and time taken.

    Returns
    --------
    denoised array : ndarray
        This is the denoised array of the same size as that of the input data,
        clipped to non-negative values.

    References
    ----------

    [Fadnavis20] S. Fadnavis, J. Batson, E. Garyfallidis, Patch2Self:
                    Denoising Diffusion MRI with Self-supervised Learning,
                    Advances in Neural Information Processing Systems 33 (2020)
    """
    patch_radius = np.asarray(patch_radius, dtype=np.int)

    if not data.ndim == 4:
        raise ValueError("Patch2Self can only denoise on 4D arrays.",
                         data.shape)

    if data.shape[3] < 10:
        warn("The intput data has less than 10 3D volumes. Patch2Self may not",
             "give denoising performance.")

    if out_dtype is None:
        out_dtype = data.dtype

    # We retain float64 precision, iff the input is in this precision:
    if data.dtype == np.float64:
        calc_dtype = np.float64

    # Otherwise, we'll calculate things in float32 (saving memory)
    else:
        calc_dtype = np.float32

    # Segregates volumes by b0 threshold
    b0_idx = np.argwhere(bvals <= b0_threshold)
    dwi_idx = np.argwhere(bvals > b0_threshold)

    data_b0s = np.squeeze(np.take(data, b0_idx, axis=3))
    data_dwi = np.squeeze(np.take(data, dwi_idx, axis=3))

    # create empty arrays
    denoised_b0s = np.empty((data_b0s.shape), dtype=calc_dtype)
    denoised_dwi = np.empty((data_dwi.shape), dtype=calc_dtype)

    denoised_arr = np.empty((data.shape), dtype=calc_dtype)

    if verbose is True:
        t1 = time.time()

    # Separate denoising for DWI volumes
    train_dwi = _extract_3d_patches(np.pad(data_dwi, ((patch_radius[0],
                                                       patch_radius[0]),
                                                      (patch_radius[1],
                                                       patch_radius[1]),
                                                      (patch_radius[2],
                                                       patch_radius[2]),
                                                      (0, 0)), mode='constant'),
                                    patch_radius=patch_radius)
    
    sketched_train_dwi = sketch_data(np.squeeze(train_dwi).T, 
                                     sketching_method=sketching_method, 
                                     s=sketch_size)
    

    # Insert the separately denoised arrays into the respective empty arrays
    coef_list = []
    for vol_idx in range(60, 61):
        denoised_dwi[..., vol_idx] = _vol_denoise(train_dwi,
                                                  sketched_train_dwi,
                                                  vol_idx, model,
                                                  data_dwi.shape,
                                                  alpha=alpha)[0]
        coefs = _vol_denoise(train_dwi,
                             sketched_train_dwi,
                             vol_idx, model,
                             data_dwi.shape,
                             alpha=alpha)[1]
        coef_list.append(coefs)

        if verbose is True:
            print("Denoised DWI Volume: ", vol_idx)

    if verbose is True:
        t2 = time.time()
        print('Total time taken for Patch2Self: ', t2-t1, " seconds")

    if data_b0s.ndim == 3:
        denoised_arr[:, :, :, b0_idx[0][0]] = denoised_b0s
    else:
        for i, idx in enumerate(b0_idx):
            denoised_arr[:, :, :, idx[0]] = np.squeeze(denoised_b0s[..., i])

    for i, idx in enumerate(dwi_idx):
        denoised_arr[:, :, :, idx[0]] = np.squeeze(denoised_dwi[..., i])

    # clip out the negative values from the denoised output
    denoised_arr.clip(min=0, out=denoised_arr)

    return np.array(denoised_arr, dtype=out_dtype), coef_list

In [10]:
from numpy import mean

_, coefs_ols_list = patch2self(data, bval, verbose=True, model='ols', 
                               patch_radius=[0, 0, 0],
                               b0_threshold=50, sketch_size=i)
# coefs_ols = mean(coefs_ols_list)

Denoised DWI Volume:  60
Total time taken for Patch2Self:  24.77943468093872  seconds


In [15]:
import scipy
from numpy import mean

list_acc = []
for i in range(0, len(list_sketch_size)*50):
    diff = scipy.linalg.norm(np.array(coefs_ols_list)-
                             np.array(coef_list[i]), ord=2) / scipy.linalg.norm(np.array(coefs_ols_list), ord=2)
    list_acc.append(diff)
    
from sklearn.preprocessing import minmax_scale

rel_err = minmax_scale(np.array(list_acc), feature_range=(0,1))

In [16]:
sketch_name_list = ['Randomized Leverage Scores' if x=='Leverage Scores: Exact' else x for x in sketch_name_list]

In [22]:
list_sketch_size_str = list(np.concatenate((np.repeat('0.5k', 5), np.repeat('1K', 5), np.repeat('5K', 5), 
                            np.repeat('10K', 5), np.repeat('20K', 5), np.repeat('30K', 5), np.repeat('40K', 5),
                            np.repeat('50K', 5), np.repeat('60K', 5), np.repeat('70K', 5), np.repeat('80K', 5),
                            np.repeat('100K', 5))))

list_sketch_size_str = list_sketch_size_str*10

In [23]:
len(list_sketch_size_str)

600

In [24]:
import pandas as pd

df_diff = pd.DataFrame({'Error':rel_err,
                        'Sketching Dimension':list_sketch_size_str,
                        'Sketching Method':sketch_name_list})

In [None]:
%matplotlib qt
import seaborn as sns
import matplotlib.pyplot as plt
sns.set(style="whitegrid")

ax = sns.pointplot(data=df_diff, x='Sketching Dimension', y='Error', hue='Sketching Method', 
                   palette="Set2", ci=95, markers=['*', 's', 'o', 'x', 'D', '^', 'X', 'd'], 
                   linestyles=['-', '-', '-', '-', '-', '-', '-', '-', '-'], scale=.8, join=True, )

In [40]:
%matplotlib qt
import seaborn as sns
import matplotlib.pyplot as plt
sns.set(style='white')

ax = sns.lineplot(data=df_diff, x='Sketching Dimension', y='Error', hue='Sketching Method', 
                  ci='sd', size_order='Error', sort=False, markers=True, style='Sketching Method')
figure = ax.get_figure()    
figure.savefig('errors_sketching.png', dpi=1000)

In [37]:
import pickle

f = open("dict_p2s_err_plot.pkl","wb")
pickle.dump(df_diff,f)
f.close()

In [38]:
import pandas as pd

object = pd.read_pickle(r'dict_p2s_err_plot.pkl')