In [None]:
import mne
import nibabel as nib
from skimage import io
from matplotlib import pyplot as plt
from scipy import signal, stats

import numpy as np
import os

from IPython.display import clear_output
%matplotlib inline

In [None]:
# Root directory of the saved evaluation outputs. Per-subject arrays are read
# from '<checkpoint_name>/<subject>/<name>.npy' by the loaders below.
checkpoint_name = './reports/low_net3D_1_all/'

In [None]:
def get_stats(man):
    """Aggregate the saved per-sample losses of subject ``man`` into per-slice maps.

    Reads ``base_losses``, ``net_losses``, ``slice_index_list``,
    ``base_test_mean_losses``, ``net_predictions`` and ``gt_predictions`` from
    ``'{checkpoint_name}/{man}/'``.

    Returns:
        r2: ``1 - sum(net losses) / sum(base losses)``.
        slice_losses, slice_base_losses, slice_total_squares: arrays of shape
            (30, 64, 64) holding, per slice, the mean of the corresponding
            per-sample entries. Slices that never occur in ``slice_index_list``
            are NaN.
        residual_loss: (30, 64, 64) array,
            ``sqrt(sum (net - gt)^2 / (sum gt^2 + 1e-9))`` accumulated per slice.
    """
    def load(name):
        return np.load('{}/{}/{}.npy'.format(checkpoint_name, man, name))

    # frame_index_list.npy / residual_loss_list.npy used to be loaded here but
    # were never used; they are no longer read.
    base_results = load('base_losses')
    net_results = load('net_losses')
    slice_index_list = load('slice_index_list')
    total_sum_of_squares = load('base_test_mean_losses')
    net_predictions = load('net_predictions')
    gt_predictions = load('gt_predictions')

    r2 = 1 - np.sum(net_results) / np.sum(base_results)

    n_slices = 30
    shape = (n_slices, 64, 64)
    slice_losses = np.zeros(shape)
    slice_base_losses = np.zeros(shape)
    slice_total_squares = np.zeros(shape)
    residual_loss_num = np.zeros(shape)
    residual_loss_denum = np.zeros(shape)
    for sample, slice_indx in enumerate(slice_index_list):
        slice_losses[slice_indx] += net_results[sample]
        slice_base_losses[slice_indx] += base_results[sample]
        slice_total_squares[slice_indx] += total_sum_of_squares[sample]
        residual_loss_num[slice_indx] += (net_predictions[sample] - gt_predictions[sample]) ** 2
        residual_loss_denum[slice_indx] += gt_predictions[sample] ** 2

    # Turn the per-slice sums into means. The count is computed once per slice;
    # empty slices would otherwise be 0/0 (NaN plus a RuntimeWarning), so they
    # are set to NaN explicitly.
    for s in range(n_slices):
        count = np.sum(slice_index_list == s)
        if count == 0:
            slice_losses[s] = np.nan
            slice_base_losses[s] = np.nan
            slice_total_squares[s] = np.nan
        else:
            slice_losses[s] /= count
            slice_base_losses[s] /= count
            slice_total_squares[s] /= count

    # 1e-9 guards against division by zero where gt is identically zero.
    residual_loss = np.sqrt(residual_loss_num / (residual_loss_denum + 1e-9))
    return r2, slice_losses, slice_base_losses, slice_total_squares, residual_loss

def get_voxel_predictions(man, slice, indx1, indx2, slice_index_list=None, net_predictions=None, gt_predictions=None):
    """Collect net and ground-truth predictions of one voxel over all samples of one slice.

    Args:
        man: subject identifier, used only to load arrays that were not passed in.
        slice: slice index to select (samples whose ``slice_index_list`` entry equals it).
        indx1, indx2: voxel position within the slice (indices along axes 1 and 2).
        slice_index_list, net_predictions, gt_predictions: optional preloaded
            arrays; each one left as None is loaded from ``checkpoint_name``.

    Returns:
        (net_voxel_preds, gt_voxel_preds): two lists, in sample order.
    """
    if slice_index_list is None:
        slice_index_list = np.load('{}/{}/slice_index_list.npy'.format(checkpoint_name, man))
    if net_predictions is None:
        net_predictions = np.load('{}/{}/net_predictions.npy'.format(checkpoint_name, man))
    if gt_predictions is None:
        gt_predictions = np.load('{}/{}/gt_predictions.npy'.format(checkpoint_name, man))

    # Vectorised selection replaces a Python loop over every sample. This
    # matters because get_correlations calls this once per voxel.
    # Integer indices keep the original sample order.
    sample_idx = np.flatnonzero(np.asarray(slice_index_list) == slice)
    net_voxel_preds = list(np.asarray(net_predictions)[sample_idx, indx1, indx2])
    gt_voxel_preds = list(np.asarray(gt_predictions)[sample_idx, indx1, indx2])
    return net_voxel_preds, gt_voxel_preds

def get_correlations(man, slice_index_list=None, net_predictions=None, gt_predictions=None):
    """Per-voxel Pearson correlation between net and ground-truth predictions.

    For each of the 30 slices and each voxel, correlates the net predictions
    with the ground-truth predictions over all samples of that slice.

    Args:
        man: subject identifier, used only to load arrays that were not passed in.
        slice_index_list, net_predictions, gt_predictions: optional preloaded
            arrays; each one left as None is loaded from ``checkpoint_name``.

    Returns:
        Array of shape ``(30,) + net_predictions.shape[1:3]`` (30x64x64 for the
        saved data) of Pearson r values. It is NaN for slices with fewer than 2
        samples (pearsonr is undefined there). ``pearsonr`` also yields NaN for
        constant input.
    """
    if slice_index_list is None:
        slice_index_list = np.load('{}/{}/slice_index_list.npy'.format(checkpoint_name, man))
    if net_predictions is None:
        net_predictions = np.load('{}/{}/net_predictions.npy'.format(checkpoint_name, man))
    if gt_predictions is None:
        gt_predictions = np.load('{}/{}/gt_predictions.npy'.format(checkpoint_name, man))
    slice_index_list = np.asarray(slice_index_list)
    net_predictions = np.asarray(net_predictions)
    gt_predictions = np.asarray(gt_predictions)

    n_rows, n_cols = net_predictions.shape[1:3]
    result = np.full([30, n_rows, n_cols], np.nan)
    for s in range(30):
        # Select this slice's samples once, not once per voxel.
        mask = slice_index_list == s
        if np.count_nonzero(mask) < 2:
            continue  # pearsonr raises ValueError for fewer than 2 samples
        net_s = net_predictions[mask]
        gt_s = gt_predictions[mask]
        for i in range(n_rows):
            for j in range(n_cols):
                result[s, i, j], _ = stats.pearsonr(net_s[:, i, j], gt_s[:, i, j])
    return result

def get_grad(man):
    """Load ``grad_statistic.npy`` of subject ``man``, take index 0 along axis 1, and average over axis 0."""
    grad_path = '{}/{}/grad_statistic.npy'.format(checkpoint_name, man)
    grad_statistic = np.load(grad_path)
    first_component = grad_statistic[:, 0]
    return first_component.mean(axis=0)

In [None]:
def get_statistics_full(i):
    """Per-voxel normalised prediction error of subject ``all_people[i]`` using all fMRI frames.

    For every voxel of every slice, returns
    ``sqrt(sum (net - gt)^2 / (sum (gt - mean)^2 + 1e-9))``. Here ``mean`` is the
    temporal mean of the raw fMRI over all frames, rescaled by ``/ 4095 * 100``.
    That mean is indexed along its last axis by the slice index.
    Returns a (30, 64, 64) array.
    """
    subject = str(all_people[i])
    fmri_dir = '../../data/fMRI/'
    fmri = read_img(os.path.join(fmri_dir, subject))
    net_predictions = np.load('{}/{}/net_predictions.npy'.format(checkpoint_name, subject))
    gt_predictions = np.load('{}/{}/gt_predictions.npy'.format(checkpoint_name, subject))
    slice_index_list = np.load('{}/{}/slice_index_list.npy'.format(checkpoint_name, subject))
    fmri_mean = fmri.mean(-1) / 4095 * 100

    squared_error = np.zeros([30, 64, 64])
    squared_deviation = np.zeros([30, 64, 64])
    for sample, slice_idx in enumerate(slice_index_list):
        target = gt_predictions[sample]
        squared_error[slice_idx] += (net_predictions[sample] - target) ** 2
        squared_deviation[slice_idx] += (target - fmri_mean[..., slice_idx]) ** 2

    # 1e-9 avoids division by zero where gt equals the mean everywhere.
    return np.sqrt(squared_error / (squared_deviation + 1e-9))

In [None]:
def get_statistics_train(i):
    """Per-voxel normalised prediction error of subject ``all_people[i]``, baseline from the first 210 frames.

    Same statistic as ``get_statistics_full``, except the fMRI mean is taken over
    the first 210 entries of the last axis only. Returns a (30, 64, 64) array:
    ``sqrt(sum (net - gt)^2 / (sum (gt - mean)^2 + 1e-9))``.
    """
    subject = str(all_people[i])
    fmri_dir = '../../data/fMRI/'
    fmri = read_img(os.path.join(fmri_dir, subject))
    net_predictions = np.load('{}/{}/net_predictions.npy'.format(checkpoint_name, subject))
    gt_predictions = np.load('{}/{}/gt_predictions.npy'.format(checkpoint_name, subject))
    slice_index_list = np.load('{}/{}/slice_index_list.npy'.format(checkpoint_name, subject))
    fmri_mean = fmri[..., :210].mean(-1) / 4095 * 100

    squared_error = np.zeros([30, 64, 64])
    squared_deviation = np.zeros([30, 64, 64])
    for sample, slice_idx in enumerate(slice_index_list):
        target = gt_predictions[sample]
        squared_error[slice_idx] += (net_predictions[sample] - target) ** 2
        squared_deviation[slice_idx] += (target - fmri_mean[..., slice_idx]) ** 2

    # 1e-9 avoids division by zero where gt equals the mean everywhere.
    return np.sqrt(squared_error / (squared_deviation + 1e-9))

In [None]:
def read_img(path):
    """Load the image data of the first ``*cross.nii.gz`` file in a subject directory.

    ``path`` is expected to end with the 7-character suffix ``/models`` (as the
    ``all_people`` entries do). That suffix is stripped to get the directory
    that is searched.

    Returns:
        The image data array, in the dtype stored on disk.

    Raises:
        FileNotFoundError: if the directory contains no ``*cross.nii.gz`` file.
            (Previously the function returned None here, and callers then failed
            with an unrelated AttributeError.)
    """
    directory = path[:-7]
    # Sort so the chosen file is deterministic; os.listdir order is arbitrary.
    for file_name in sorted(os.listdir(directory)):
        if file_name.endswith('cross.nii.gz'):
            image = nib.load(os.path.join(directory, file_name))
            # get_data() was deprecated in nibabel 3.0 and removed in 5.0.
            # np.asanyarray(dataobj) is its documented replacement and keeps the
            # on-disk dtype; get_fdata() would always return float64.
            return np.asanyarray(image.dataobj)
    raise FileNotFoundError("no '*cross.nii.gz' file found in {}".format(directory))

In [None]:
fmri_path = '../../data/fMRI/'
# Subject entries, each '<id>/models'. The 7-character '/models' suffix is
# stripped ([:-7]) when building plot titles and file names.
all_people = [
    '{}/models'.format(subject_id)
    for subject_id in (40, 49, 37, 36, 50, 47, 32, 48, 46, 35, 42, 43, 39, 44, 38, 41, 45)
]

## LaTeX tables

In [None]:
# LaTeX header cells: the thresholds 0.2, 0.3, ..., 1.0, each followed by ' & '.
print(*(tenths / 10 for tenths in range(2, 11)), sep=' & ', end=' & ')

In [None]:
# LaTeX rows for the train-frame statistic: per subject, the fraction of voxels
# whose statistic is below each threshold.
# NOTE(review): each row prints 5 value columns (0.4, 0.6, 0.8, 1.0, 1.0), so
# "< 1.0" appears twice. The header cell above prints 9 thresholds
# (0.2 ... 1.0). Confirm the intended column set before publishing the table.
for person in range(len(all_people)):
    print(all_people[person], end=' & ')
    s = get_statistics_train(person)
    # Separate loop name so the outer subject index is not shadowed.
    for tenths in [4, 6, 8, 10]:
        print("{0:.6f}".format((s < tenths / 10).mean()), end=' & ')
    print("{0:.6f}".format((s < 1).mean()), end=' \\\\ ')
    print()
    # Escaped backslash: '\h' is an invalid escape sequence (SyntaxWarning on 3.12+).
    print('\\hline')

In [None]:
# LaTeX rows for the all-frame statistic: per subject, the fraction of voxels
# whose statistic is below each threshold.
# NOTE(review): each row prints 5 value columns (0.4, 0.6, 0.8, 1.0, 1.0), so
# "< 1.0" appears twice. The header cell prints 9 thresholds (0.2 ... 1.0).
# Confirm the intended column set before publishing the table.
for person in range(len(all_people)):
    print(all_people[person], end=' & ')
    s = get_statistics_full(person)
    # Separate loop name so the outer subject index is not shadowed.
    for tenths in [4, 6, 8, 10]:
        print("{0:.6f}".format((s < tenths / 10).mean()), end=' & ')
    print("{0:.6f}".format((s < 1).mean()), end=' \\\\ ')
    print()
    # Escaped backslash: '\h' is an invalid escape sequence (SyntaxWarning on 3.12+).
    print('\\hline')

## Preprocessing

In [None]:
# One gradient statistic per subject, in all_people order.
grad_list = list(map(get_grad, all_people))

In [None]:
# Axes for the gradient maps plotted below (labelled Frequency [Hz] / Time [sec]).
# Frequency: 32 bins spaced exactly 250/63 Hz apart (0 ... ~123.016 Hz). This
# reproduces the previously hard-coded values to their 8 printed decimals.
f = np.arange(32) * 250 / 63

# Time: 128 bins spaced 0.128 s apart (0 ... 16.256 s).
t = np.arange(128) * 0.128

In [None]:
def get_distance(man, threshold):
    """Fraction of voxels with train statistic below ``threshold``, grouped by distance from the grid centre.

    The squared distance of voxel (i, j, k) from the centre (14.5, 31.5, 31.5)
    scales axis 0 by 120/30 and axes 1-2 by 210/64 (presumably mm per voxel;
    TODO confirm). Returns a dict mapping squared distance to the fraction of
    voxels at that distance whose statistic is below ``threshold``. Keys are
    in first-seen order.
    """
    voxel_stats = get_statistics_train(man)
    centre_i, centre_j, centre_k = 14.5, 31.5, 31.5
    below_by_distance = {}
    for i, j, k in np.ndindex(30, 64, 64):
        dist = (120/30*(i - centre_i))**2 \
            + (210/64*(j - centre_j))**2 \
            + (210/64*(k - centre_k))**2
        below_by_distance.setdefault(dist, []).append(voxel_stats[i, j, k] < threshold)

    return {dist: np.mean(flags) for dist, flags in below_by_distance.items()}

In [None]:
# Paint every voxel with the below-0.6 fraction of its distance shell
# (subject index 0).
scores = get_distance(0, .6)

center_points = [14.5, 31.5, 31.5]
distances = np.zeros([30, 64, 64])
for voxel in np.ndindex(*distances.shape):
    i, j, k = voxel
    dist = (120/30*(i - center_points[0]))**2 \
        + (210/64*(j - center_points[1]))**2 \
        + (210/64*(k - center_points[2]))**2
    distances[voxel] = scores[dist]

            

In [None]:
# Stack the 30 slices of `distances` on top of each other into one tall image.
dist_coll = np.concatenate(list(distances), axis=0)

In [None]:
# Tall collage of the per-voxel shell fractions, slices stacked vertically.
fig, ax = plt.subplots(figsize=[10, 100])
# Alternative view: ax.imshow((dist_coll < 70) & (dist_coll > 50)) highlights a band.
image = ax.imshow(dist_coll)
fig.colorbar(image, ax=ax)

In [None]:
!rm -r plots

In [None]:
!mkdir ./plots
!mkdir ./plots/mean_grads
!mkdir ./plots/percentage_distance
!mkdir ./plots/percentage_distance/smoothed
!mkdir ./plots/percentage_distance/raw
!mkdir ./plots/treshold_visualization
!mkdir ./plots/distance_metric/
!mkdir ./plots/distance_metric/voxel_tresholds
!mkdir ./plots/distance_metric/smoothed
!mkdir ./plots/distance_metric/raw
!mkdir ./plots/distance_metric/voxel_masks
!mkdir ./plots/clipped_metric/

In [None]:
def get_sorted_people():
    """Subject indices ordered from highest to lowest fraction of voxels with train statistic < 0.6.

    Returns the ``np.argsort`` permutation: element r is the index (into
    ``all_people``) of the subject at position r.
    """
    fractions_below = [(get_statistics_train(idx) < .6).mean() for idx in range(len(all_people))]
    return np.argsort([-fraction for fraction in fractions_below])

In [None]:
# get_sorted_people() returns subject indices ordered best-first (an argsort
# permutation). Later cells use subject_rate[i] as the rank of subject i in file
# names, which needs the inverse permutation: argsort of that order.
subject_rate = np.argsort(get_sorted_people())

In [None]:
# Mean absolute input gradient per subject, saved under ./plots/mean_grads/.
for man, person in enumerate(all_people):
    subject = person[:-7]  # drop the '/models' suffix
    mean_abs_grad = np.abs(grad_list[man]).mean(0)
    plt.pcolormesh(t, f, mean_abs_grad, vmin=0, cmap='plasma')
    plt.title("mean abs gradient input man {}".format(subject))
    plt.colorbar()
    plt.ylabel('Frequency [Hz]')
    plt.xlabel('Time [sec]')
    plt.savefig('./plots/mean_grads/mean_gradient_subject_{}_rate_{}.png'.format(subject, subject_rate[man]))
    plt.show()
    clear_output()

In [None]:
def plot_percentage_distance(man, threshold):
    """Plot the fraction of voxels with train statistic below ``threshold`` against distance from the grid centre.

    Saves a raw curve under ./plots/percentage_distance/raw/ and a
    Savitzky-Golay-smoothed one (window 53, order 10) under
    ./plots/percentage_distance/smoothed/.
    """
    # Identical per-distance aggregation to get_distance(); reuse it instead of
    # keeping a second copy of the voxel loop.
    scores = get_distance(man, threshold)

    keys = np.sqrt(np.array(list(scores.keys())))  # squared distance -> distance
    values = np.array(list(scores.values()))
    argsort = np.argsort(keys)
    subject = all_people[man][:-7]

    plt.plot(keys[argsort], values[argsort])
    plt.title("Voxel percentage/Distance subject {} threshold {}".format(subject, threshold))
    plt.ylabel('Voxel percentage')
    plt.xlabel('Distance [mm]')
    plt.savefig('./plots/percentage_distance/raw/percentage_distance_subject_{}_rate_{}.png'.format(subject, subject_rate[man]))
    plt.show()

    plt.plot(keys[argsort], signal.savgol_filter(values[argsort], 53, 10))
    plt.title("Savgol Filter for voxel percentage/Distance subject {} threshold {}".format(subject, threshold))
    plt.ylabel('Voxel percentage')
    plt.xlabel('Distance [mm]')
    plt.savefig('./plots/percentage_distance/smoothed/percentage_distance_subject_{}_rate_{}.png'.format(subject, subject_rate[man]))
    plt.show()


In [None]:
def plot_tresholds(man):
    """Show the mean brain next to binary maps ``statistic < th`` for th in 0.2 ... 1.0.

    Slices are stacked vertically into collages. The train statistic is clipped
    at 2. The figure is saved under ./plots/treshold_visualization/.
    """
    fmri_path = '../../data/fMRI/'
    fmri = read_img(os.path.join(fmri_path, str(all_people[man])))
    # Temporal mean computed once. Previously it was recomputed over the whole
    # 4D volume for each of the 30 slices.
    mean_brain = fmri.mean(-1)
    collage = np.concatenate([mean_brain[..., i] / 4095 for i in range(30)], axis=0)
    metric = get_statistics_train(man)
    metric[metric > 2] = 2

    plt.figure(figsize=[20, 100])
    plt.subplot(171)
    plt.imshow(collage)
    plt.title('mean brain')
    plt.axis('off')
    for c, th in enumerate([.2, .4, .6, .8, 1]):
        plt.subplot(172 + c)
        collage_stats_tresh = np.concatenate([(metric[i] < th) for i in range(30)], axis=0)
        plt.imshow(collage_stats_tresh)
        plt.title('treshold {}'.format(th))
        plt.axis('off')
    plt.savefig('./plots/treshold_visualization/Treshold Visualisation subject {} rate {}'.format(all_people[man][:-7], subject_rate[man]))
    plt.show()

In [None]:
# Threshold collages for every subject.
for man, _ in enumerate(all_people):
    plot_tresholds(man)
    clear_output()

In [None]:
# Below-0.6 fraction vs distance for every subject (two-argument version).
for man, _ in enumerate(all_people):
    plot_percentage_distance(man, 0.6)
    clear_output()

In [None]:
def plot_percentage_distance(man):
    """Plot the mean train statistic of masked voxels against distance from the grid centre.

    Only voxels whose temporal-mean fMRI intensity exceeds 66 are used. Saves a
    raw curve under ./plots/distance_metric/raw/ and a Savitzky-Golay-smoothed
    one (window 53, order 10) under ./plots/distance_metric/smoothed/.

    NOTE(review): this redefinition shadows the two-argument
    ``plot_percentage_distance(man, threshold)`` defined earlier. Re-running the
    earlier cell that calls ``plot_percentage_distance(i, 0.6)`` after this one
    raises TypeError. Consider renaming this function together with its caller.
    """
    fmri_path = '../../data/fMRI/'
    fmri = read_img(os.path.join(fmri_path, str(all_people[man])))
    mean_fmri = fmri.mean(-1)
    fmri_mask = mean_fmri > 66

    metric = get_statistics_train(man)
    centre_i, centre_j, centre_k = 14.5, 31.5, 31.5
    values_by_distance = {}
    for i, j, k in np.ndindex(30, 64, 64):
        # The mask is indexed (row, col, slice) while the statistic is (slice, row, col).
        if not fmri_mask[j, k, i]:
            continue
        dist = (120/30*(i - centre_i))**2 \
            + (210/64*(j - centre_j))**2 \
            + (210/64*(k - centre_k))**2
        values_by_distance.setdefault(dist, []).append(metric[i, j, k])

    keys = np.array([np.sqrt(dist) for dist in values_by_distance])
    values = np.array([np.mean(bucket) for bucket in values_by_distance.values()])
    argsort = np.argsort(keys)
    subject = all_people[man][:-7]

    plt.plot(keys[argsort], values[argsort])
    plt.title("Voxel metric/Distance subject {}".format(subject))
    plt.ylabel('Voxel metric')
    plt.xlabel('Distance [mm]')
    plt.savefig('./plots/distance_metric/raw/metric_distance_subject_{}_rate_{}.png'.format(subject, subject_rate[man]))
    plt.show()

    plt.plot(keys[argsort], signal.savgol_filter(values[argsort], 53, 10))
    plt.title("Savgol Filter for voxel metric/Distance subject {}".format(subject))
    plt.ylabel('Voxel metric')
    plt.xlabel('Distance [mm]')
    plt.savefig('./plots/distance_metric/smoothed/metric_distance_subject_{}_rate_{}.png'.format(subject, subject_rate[man]))
    plt.show()



In [None]:
# Masked metric vs distance for every subject (one-argument version).
for man, _ in enumerate(all_people):
    plot_percentage_distance(man)
    clear_output()

In [None]:
def plot_clipped_loss(man):
    """Grid (6 rows x 5 columns) of per-slice train-statistic maps clipped at 2, each with its colourbar.

    Slice s is drawn at row ``s % 6``, column ``s // 6``. The figure is saved
    under ./plots/clipped_metric/.
    """
    metric = get_statistics_train(man)
    fig, axis = plt.subplots(6, 5, figsize=[40, 40])
    for slice_indx in range(30):
        ax = axis[slice_indx % 6, slice_indx // 6]
        img = metric[slice_indx]
        np.minimum(img, 2, out=img)  # clip in place at 2
        im = ax.imshow(img, cmap='plasma_r', vmin=0, vmax=2)
        fig.colorbar(im, ax=ax)
    plt.savefig('./plots/clipped_metric/clipped_metric_subject{}_rate_{}.png'.format(all_people[man][:-7], subject_rate[man]))
    plt.show()

In [None]:
def plot_filter_treshold(man):
    """Plot the fraction of voxels whose temporal-mean intensity exceeds each threshold.

    Thresholds span [0, max mean intensity] in 1000 steps. The segment of the
    curve below 66, the intensity cut-off used for the brain mask elsewhere in
    this notebook, is overlaid. The figure is saved under
    ./plots/distance_metric/voxel_tresholds/.
    """
    fmri_path = '../../data/fMRI/'
    fmri = read_img(os.path.join(fmri_path, str(all_people[man])))
    mean_fmri = fmri.mean(-1)

    th_list = np.linspace(0, mean_fmri.max(), 1000)
    acc_list = np.array([(mean_fmri > th).mean() for th in th_list])
    # th_list is increasing, so the thresholds below 66 are its first n_below entries.
    n_below = np.count_nonzero(th_list < 66)
    plt.plot(th_list, acc_list, label='accepted voxels')
    plt.plot(th_list[:n_below], acc_list[:n_below], label='filtered voxels')
    plt.legend()
    # Title strips '/models' like the other plots; label typos fixed.
    # The file path keeps its original spelling.
    plt.title("filtered voxels / thresholds subject {}".format(all_people[man][:-7]))
    plt.ylabel("voxel percentage")
    plt.xlabel("thresholds")
    plt.savefig('./plots/distance_metric/voxel_tresholds/voxel_persentage_subject{}_rate_{}.png'.format(all_people[man][:-7], subject_rate[man]))
    plt.show()

In [None]:
def plot_mask(man):
    """Show the temporal-mean brain next to its intensity mask (mean > 66), slices stacked vertically.

    The figure is saved under ./plots/distance_metric/voxel_masks/.
    """
    fmri_path = '../../data/fMRI/'
    fmri = read_img(os.path.join(fmri_path, str(all_people[man])))
    mean_fmri = fmri.mean(-1)
    fmri_mask = mean_fmri > 66

    # Reuse mean_fmri. Previously fmri.mean(-1) was recomputed over the whole
    # 4D volume for each of the 30 slices.
    collage = np.concatenate([mean_fmri[..., i] / 4095 for i in range(30)], axis=0)
    collage_mask = np.concatenate([fmri_mask[..., i] for i in range(30)], axis=0)

    plt.figure(figsize=[20, 200])
    plt.subplot(121)
    plt.imshow(collage)
    plt.title('mean brain')
    plt.axis('off')

    plt.subplot(122)
    plt.imshow(collage_mask)
    plt.title('mask')
    plt.axis('off')
    plt.savefig('./plots/distance_metric/voxel_masks/mask_subject{}_rate_{}.png'.format(all_people[man][:-7], subject_rate[man]))
    plt.show()

In [None]:
# Squared distance of every voxel from the grid centre (14.5, 31.5, 31.5).
# Axis 0 is scaled by 120/30 and axes 1-2 by 210/64 (presumably mm per voxel;
# TODO confirm). Built by broadcasting three index ramps to shape (30, 64, 64).
center_points = [14.5, 31.5, 31.5]
slice_ramp = np.arange(30)[:, None, None]
row_ramp = np.arange(64)[None, :, None]
col_ramp = np.arange(64)[None, None, :]
distances = (120/30*(slice_ramp - center_points[0]))**2 \
    + (210/64*(row_ramp - center_points[1]))**2 \
    + (210/64*(col_ramp - center_points[2]))**2

            

In [None]:
# Clipped per-slice metric grids for every subject.
for man, _ in enumerate(all_people):
    plot_clipped_loss(man)
    clear_output()

In [None]:
# Mean brain vs mask collages for every subject.
for man, _ in enumerate(all_people):
    plot_mask(man)
    clear_output()

In [None]:
# Accepted-voxel fraction vs intensity threshold for every subject.
for man, _ in enumerate(all_people):
    plot_filter_treshold(man)
    clear_output()

In [None]:
# Bundle everything under ./plots into plots3d_1.tar (-v lists each file added).
!tar -cvf plots3d_1.tar ./plots