In [22]:
import os
import glob
import time
import numpy as np
import pandas as pd
import nibabel as nib
import matplotlib.pyplot as plt
from externel import seaborn as sns
from scipy.ndimage import zoom

def bin_CT(img, n_bins=1024):
    data_vector = img
    data_max = np.amax(data_vector)
    data_min = np.amin(data_vector)
    data_squeezed = (data_vector-data_min)/(data_max-data_min)
    data_extended = data_squeezed * (n_bins-1)
    data_discrete = data_extended // 1
    return np.asarray(list(data_discrete), dtype=np.int64)

train_dict = {}
train_dict["time_stamp"] = time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime())
train_dict["project_name"] = "pixel_correlation"
train_dict["save_folder"] = "./project_dir/"+train_dict["project_name"]+"/"

train_dict["folder_X"] = "./data_dir/norm_MR/discrete/"
train_dict["folder_Y"] = "./data_dir/norm_CT/discrete/"

X_list = sorted(glob.glob(train_dict["folder_X"]+"*.nii.gz"))
Y_list = sorted(glob.glob(train_dict["folder_Y"]+"*.nii.gz"))
print(Y_list[:5])

n_bin = 128

['./data_dir/norm_CT/discrete/NORM_001.nii.gz', './data_dir/norm_CT/discrete/NORM_002.nii.gz', './data_dir/norm_CT/discrete/NORM_003.nii.gz', './data_dir/norm_CT/discrete/NORM_004.nii.gz', './data_dir/norm_CT/discrete/NORM_005.nii.gz']


In [19]:
t2f_file = nib.load("./data_dir/unknown/T2F_004.nii.gz")
t2f_data = t2f_file.get_fdata()
print(t2f_data.shape, np.amax(t2f_data), np.amin(t2f_data))
t2f_data_bin = bin_CT(t2f_data, n_bins=128)
pred_file = nib.Nifti1Image(t2f_data_bin, t2f_file.affine, t2f_file.header)
pred_name = "./data_dir/unknown/T2F_004_bin.nii.gz"
nib.save(pred_file, pred_name)

(512, 512, 33) 3091.0 0.0


In [92]:
t1b_file = nib.load("./data_dir/unknown/T1B_006.nii.gz")
t1b_data = zoom(t1b_file.get_fdata(), zoom=(1,1,33/140))
print(t1b_data.shape, np.amax(t1b_data), np.amin(t1b_data))
t1b_data_bin = bin_CT(t1b_data, n_bins=128)
pred_file = nib.Nifti1Image(t1b_data_bin, t2f_file.affine, t2f_file.header)
pred_name = "./data_dir/unknown/T1B_006_bin.nii.gz"
nib.save(pred_file, pred_name)

(512, 512, 33) 4689.630655389193 -39.61693493514443


In [13]:
from scipy.special import kl_div
from sklearn.metrics import mean_squared_error as mse

def kl_div_scalar(X, Y):
    return np.sum(kl_div(X, Y))

In [131]:
ix = 216 // 16
iy = 325 // 16
iz = 17 // 3

cube_x = t2f_data_bin[ix*16:(ix+1)*16, iy*16:(iy+1)*16, iz*3:(iz+1)*3]
cube_y = t1b_data_bin[ix*16:(ix+1)*16, iy*16:(iy+1)*16, iz*3:(iz+1)*3]
t2f_replace[ix*16:(ix+1)*16, iy*16:(iy+1)*16, iz*3:(iz+1)*3] = spatial_pred(cube_x, cube_y)

23 45
24 45
25 45
26 45
27 45
28 45
29 45
30 45
31 45
32 45
33 45
34 45
35 45


In [130]:
def spatial_pred(cube_x, cube_y):
#     print(cube_x.shape)
    dist_x = np.zeros((cube_x.shape[0]*cube_x.shape[1]*cube_x.shape[2], n_bin))
    dist_y = np.zeros((cube_y.shape[0]*cube_y.shape[1]*cube_y.shape[2], n_bin))
    pred = np.zeros((cube_x.shape))
    
    flat_x = np.ravel(cube_x)
    elem_x = np.unique(flat_x)
    for cnt_elem, elem in enumerate(elem_x):
        loc_elem = np.where(flat_x == elem)[0]
        for idx_elem in loc_elem:
            dist_x[idx_elem, int(elem)] += 1

    flat_y = np.ravel(cube_y)
    elem_y = np.unique(flat_y)
    for cnt_elem, elem in enumerate(elem_y):
        loc_elem = np.where(flat_y == elem)[0]
        for idx_elem in loc_elem:
            dist_y[idx_elem, int(elem)] += 1

    nonzero_x = []
    nonzero_y = []
    for idx in range(n_bin):
        
        row_sum = np.sum(dist_x[:, idx])
        if row_sum > 0:
            dist_x[idx, :] /= row_sum
            nonzero_x.append(idx)
        
        row_sum = np.sum(dist_y[:, idx])
        if row_sum > 0:
            dist_y[idx, :] /= row_sum
            nonzero_y.append(idx)
    
#     print(nonzero_x, nonzero_y)
#     print(nonzero_x, nonzero_y)
#     print(dist_x, dist_y)
    for cnt_x, elem_x in enumerate(nonzero_x):
        xy_corr = []
        for cnt_y, elem_y in enumerate(nonzero_y):
            xy_corr.append(kl_div_scalar(dist_y[elem_y, :], dist_x[elem_y, :]))
#         print(xy_corr)
        xy_corr = np.asarray(xy_corr)
        loc_pred_ix = np.where(cube_x == elem_x)[0]
        loc_pred_iy = np.where(cube_x == elem_x)[1]
        loc_pred_iz = np.where(cube_x == elem_x)[2]
        replace_x = nonzero_y[int(np.where(xy_corr == np.amin(xy_corr))[0][0])]
        for idx in range(len(loc_pred_ix)):
#             if ix == 7:
#             print(nonzero_y, loc_pred_ix[idx], loc_pred_iy[idx], loc_pred_iz[idx], xy_corr)
            pred[loc_pred_ix[idx], loc_pred_iy[idx], loc_pred_iz[idx]] = replace_x
#             print("-"*60)
        print(elem_x, replace_x)
    return pred

In [126]:
t2f_replace = np.zeros((t2f_data.shape))
for ix in range(256 // 16):
    print(ix)
    for iy in range(256 // 16):
        for iz in range(33 // 3):
            cube_x = t2f_data_bin[ix*16:(ix+1)*16, iy*16:(iy+1)*16, iz*3:(iz+1)*3]
            cube_y = t1b_data_bin[ix*16:(ix+1)*16, iy*16:(iy+1)*16, iz*3:(iz+1)*3]
            t2f_replace[ix*16:(ix+1)*16, iy*16:(iy+1)*16, iz*3:(iz+1)*3] = spatial_pred(cube_x, cube_y)

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15


In [127]:
pred_file = nib.Nifti1Image(t2f_replace, t2f_file.affine, t2f_file.header)
pred_name = "./data_dir/unknown/T2F_004_pred.nii.gz"
nib.save(pred_file, pred_name)

In [15]:
test = np.asarray([[[1,2,3],[4,5,6],[7,8,9]], [[4,5,6],[7,8,9],[1,2,3]]])
print(np.where(test == 2))
print(np.where(test == 2)[0])
print(np.where(test == 2)[1])
print(np.where(test == 2)[2])

(array([0, 1]), array([0, 2]), array([1, 1]))
[0 1]
[0 2]
[1 1]
