In [1]:
import numpy as np
import matplotlib.pyplot as plt
import h5py
from scipy.signal import ShortTimeFFT
from scipy.signal.windows import hann
import random
from scipy.stats import skew
from scipy.stats import kurtosis
import csv
import datetime

In [2]:
import data_corruption

In [3]:
def create_corrupted_fids(gt,std_base,std_var,time_axis=None,transients=160):
  """Build noisy transients from ground-truth FIDs.

  Parameters
  ----------
  gt : ground-truth FIDs handed to ``data_corruption.TransientMaker``.
  std_base, std_var : forwarded unchanged to
      ``TransientMaker.add_random_amplitude_noise``.
  time_axis : time axis handed to ``TransientMaker``. When ``None`` the
      notebook-level variable ``t`` is used, which is what this function
      always did before the parameter existed.
  transients : value of the ``transients`` keyword of ``TransientMaker``
      (previously hard-coded to 160).

  Returns
  -------
  The ``fids`` attribute of the ``TransientMaker`` after the noise was added.
  """
  if time_axis is None:
    # Backward-compatible fallback to the notebook global.
    time_axis = t

  # Use the `gt` argument: the previous version ignored it and silently read
  # the global `gt_fids`, so the function could never corrupt any other data.
  tm = data_corruption.TransientMaker(gt,time_axis,transients=transients)
  tm.add_random_amplitude_noise(std_base,std_var)

  return tm.fids

In [4]:
def spect_noise_estimation(spect, qntty, ppm):
  """Estimate one noise standard deviation per spectrum.

  For each of the first `qntty` rows, two windows of the real part of the
  spectrum are taken between the ppm-axis entries closest to 8.5, 9.5 and
  10.5. Each window is detrended with a second-order polynomial fit and the
  standard deviation of the residual is computed; the smaller of the two
  values is kept.

  The slicing assumes the ppm axis is inverted (smaller values at higher
  indexes), so the index of 10.5 comes before 9.5, which comes before 8.5.

  Returns an array of length `qntty`.
  """
  def _closest(axis, target):
    # index of the axis entry nearest to `target`
    return np.abs(axis - target).argmin()

  def _detrended_std(x, y):
    # residual spread around a quadratic baseline fitted to (x, y)
    quad, lin, const = np.polyfit(x, y, 2)
    baseline = (quad*(x**2)) + (lin*x) + const
    return np.std(y - baseline)

  noise = np.empty(qntty)

  for row in range(qntty):
    axis = ppm[row,:]
    at_8_5, at_9_5, at_10_5 = (_closest(axis, edge) for edge in (8.5, 9.5, 10.5))

    window_low = slice(at_9_5, at_8_5)     # 8.5 - 9.5 ppm
    window_high = slice(at_10_5, at_9_5)   # 9.5 - 10.5 ppm

    std_low = _detrended_std(axis[window_low], np.real(spect[row, window_low]))
    std_high = _detrended_std(axis[window_high], np.real(spect[row, window_high]))

    # keep the quieter window
    noise[row] = std_low if np.abs(std_low) < np.abs(std_high) else std_high

  return noise

def spect_SNR_estimation(spect, qntty, ppm, ppm_min_peak,ppm_max_peak):
  """Compute SNR, noise level and peak amplitude for each spectrum.

  The noise comes from `spect_noise_estimation`. The peak is the largest
  absolute value of the real part of the spectrum between the ppm-axis
  entries closest to `ppm_max_peak` (slice start) and `ppm_min_peak`
  (slice stop). SNR is peak / (2 * noise).

  Returns the tuple (SNR_array, noise_array, peak_array), each of length
  `qntty`.
  """
  noise_array = spect_noise_estimation(spect, qntty, ppm)
  peak_array = np.empty(qntty)
  SNR_array = np.empty(qntty)

  for row in range(qntty):
    axis = ppm[row,:]
    stop = np.abs(axis - ppm_min_peak).argmin()
    start = np.abs(axis - ppm_max_peak).argmin()

    peak_array[row] = np.max(np.abs(np.real(spect[row,start:stop])))
    SNR_array[row] = peak_array[row]/(2*noise_array[row])

  return SNR_array, noise_array, peak_array

In [5]:
def normalize_vector_between_minus_one_and_one(complex_array):
    """Min-max rescale real and imaginary parts independently to [-1, 1].

    Each component is mapped with ``2*(x - min)/(max - min) - 1`` using the
    minimum and maximum over the whole array.

    A component that is constant (max == min) has no spread to rescale. It
    is returned as zeros instead of the NaNs the 0/0 division used to give,
    so a purely real input no longer poisons downstream metrics.

    Parameters
    ----------
    complex_array : array-like of complex (or real) values.

    Returns
    -------
    Complex array with the same shape as the input.

    Raises
    ------
    ValueError : for empty input (raised by ``np.min``).
    """
    def _rescale(component):
        lowest = np.min(component)
        span = np.max(component) - lowest
        if span == 0:
            # constant component: nothing to stretch
            return np.zeros_like(component, dtype=float)
        return (((component - lowest)/span)*2)-1

    values = np.asarray(complex_array)
    return _rescale(values.real) + 1j*_rescale(values.imag)

def get_normalized_spectrogram(fids,qntty,a,b,norm_abs):
    w = hann(256, sym=True)
    mfft_ = 446
    SFT = ShortTimeFFT(w, hop=10, fs=bandwidth, mfft=mfft_, scale_to='magnitude', fft_mode = 'centered')
    t_lo, t_hi, f_lo, f_hi = SFT.extent(fids.shape[1])
    spgram = []
    for i in range(qntty):
        aux = SFT.stft(fids[i,:])
        if norm_abs == True:
            spgram.append(aux/np.max(np.abs(aux)))
        else:
            spgram.append(normalize_vector_between_minus_one_and_one(aux))
    spgram = np.array(spgram)
    
    freq_spect = np.flip(np.linspace(f_lo,f_hi,mfft_))
    ppm_spect = a*freq_spect+b
    t_spect = np.linspace(t_lo,t_hi,spgram.shape[2])
    
    return spgram, freq_spect, ppm_spect, t_spect

In [6]:
def center_bins(bins):
    """Return the midpoints of consecutive bin edges, row by row.

    Parameters
    ----------
    bins : 2-D array-like of shape (rows, n_edges).

    Returns
    -------
    Array of shape (rows, n_edges - 1) where entry [i, j] is
    ``(bins[i, j+1] + bins[i, j]) / 2``.
    """
    edges = np.asarray(bins)
    # One vectorised expression instead of a Python double loop; the output
    # stays 2-D even when there are no rows.
    return (edges[:, 1:] + edges[:, :-1]) / 2

In [7]:
def get_histogram(spgram,qntty,n_bins=8000):
    """Histogram the real part of each spectrogram.

    Parameters
    ----------
    spgram : array whose first axis indexes spectrograms.
    qntty : number of spectrograms to process.
    n_bins : number of equal-width bins per histogram (default 8000, the
        value that used to be hard-coded).

    Returns
    -------
    hist : array (qntty, n_bins); each row is the bin counts divided by
        their sum, so it adds up to 1.
    bins_ : array (qntty, n_bins) with the bin centres from `center_bins`.
    """
    hist = []
    bins_hist = []
    for i in range(qntty):
        counts, edges = np.histogram(np.real(spgram[i,:,:]), n_bins)
        # absolute counts -> fractions of the total
        hist.append(counts/counts.sum())
        bins_hist.append(edges)
    hist = np.array(hist)
    bins_hist = np.array(bins_hist)

    bins_ = center_bins(bins_hist)
    return hist, bins_

In [8]:
def calculate_TVs(spgram):
  """Anisotropic and isotropic total variation of each 2-D slice.

  Forward differences of the real part are taken along axis 1 and along
  axis 2; the last position along each differenced axis is zero. For every
  entry of the first axis:

    TV_aniso = sum(|d1| + |d2|)
    TV_iso   = sum(sqrt(|d1|**2 + |d2|**2))

  Returns the tuple (TV_aniso, TV_iso), one value per first-axis entry.
  """
  # differences along axis 1, zero in the last slot
  step_axis1 = np.zeros(spgram.shape)
  step_axis1[:, :-1, :] = np.real(np.diff(spgram, axis=1))

  # differences along axis 2, zero in the last slot
  step_axis2 = np.zeros(spgram.shape)
  step_axis2[:, :, :-1] = np.real(np.diff(spgram, axis=2))

  magnitude_1 = np.abs(step_axis1)
  magnitude_2 = np.abs(step_axis2)

  TV_aniso = np.sum(magnitude_1+magnitude_2, axis =(1,2))
  TV_iso = np.sum(np.sqrt((magnitude_1**2)+(magnitude_2**2)), axis=(1,2))

  return TV_aniso, TV_iso

In [9]:
def stats(seq_stats,names):
  """Summarise each sequence of values by its mean and standard deviation.

  `seq_stats` is an iterable of value collections and `names[i]` is the
  label of the i-th one. Returns ``{label: {'mean': ..., 'std': ...}}``.
  """
  return {
    names[position]: {'mean': np.mean(values), 'std': np.std(values)}
    for position, values in enumerate(seq_stats)
  }

In [10]:
def get_histogram_metrics(hist,bins,value_ref_larg=1e-4):
    """Summarise a batch of histograms by shape descriptors.

    Parameters
    ----------
    hist : array (n, n_bins) of histogram heights; the moment formulas use
        each row as weights, so rows are expected to sum to 1.
    bins : array (n, n_bins) of bin centres matching `hist`.
    value_ref_larg : reference height used for the width measurement
        (default 1e-4, the value that used to be hard-coded).

    Per histogram the function computes:
      * mode     - bin centre of the tallest bin,
      * max      - height of the tallest bin,
      * width    - distance between the bin whose height is closest to
                   `value_ref_larg` left of the peak and the one closest to
                   it from the peak rightwards (first match wins on ties;
                   a peak in bin 0 uses bin 0 as left edge),
      * skewness - third standardised moment,
      * kurtosis - fourth standardised moment (not excess kurtosis).

    Returns
    -------
    The `stats` summary (mean and std over the batch) keyed by
    'mode', 'max', 'width', 'skewness', 'kurtosis'.
    """
    hist = np.asarray(hist)
    bins = np.asarray(bins)

    argmax_hist = np.argmax(hist,axis=1)
    mode_ = bins[np.arange(hist.shape[0]), argmax_hist] #pixel value that happens the most
    max_ = np.max(hist,axis=1) #peak amplitude

    # Width: argmin replaces the element-by-element Python search.
    distance = np.abs(hist - value_ref_larg)
    LWHM_ = np.empty(hist.shape[0])
    for i, peak in enumerate(argmax_hist):
        # an empty left side (peak in bin 0) keeps index 0
        left = np.argmin(distance[i, :peak]) if peak > 0 else 0
        right = peak + np.argmin(distance[i, peak:])
        LWHM_[i] = np.abs(bins[i, right] - bins[i, left]) #linewidth

    # Weighted moments. A histogram with zero spread gives NaN skewness and
    # kurtosis (division by a zero std).
    mean_ = np.sum(bins*hist,axis=1)
    centered = bins - mean_[:, np.newaxis]
    std_ = np.sqrt(np.sum((centered**2)*hist,axis=1))
    standardized = centered/std_[:, np.newaxis]
    skewness_ = np.sum(standardized**3*hist,axis=1)
    kurtosis_ = np.sum(standardized**4*hist,axis=1)

    names = ['mode','max','width','skewness','kurtosis']
    seq_stats = (mode_,max_,LWHM_,skewness_,kurtosis_)
    metrics = stats(seq_stats,names)

    return metrics

In [11]:
def get_spgram_metrics(spgram,time,ppm):
  """Batch statistics of the real part of a stack of spectrograms.

  Three regions are analysed:
    * total    - the full spectrogram (mean, median, std, trace of the
                 covariance matrix, anisotropic/isotropic total variation);
    * late     - time columns from the entry of `time` closest to 0.6
                 onwards (sum of absolute values, mean, std over all rows;
                 total variation restricted to the frequency band below);
    * main_sig - time columns before the entry closest to 0.4, restricted
                 to the frequency band (mean, std, total variation).

  The frequency band runs between the positions, located on the flipped
  `ppm` axis, of the entries closest to 1 and to 8.

  Returns the `stats` summary (mean/std over the batch) of every metric.
  """
  real_part = np.real(spgram)

  # region boundaries
  early_stop = np.abs(time - 0.4).argmin()
  late_start = np.abs(time - 0.6).argmin()
  ppm_flipped = np.flip(ppm)
  band = slice(np.abs(ppm_flipped - 1).argmin(), np.abs(ppm_flipped - 8).argmin())

  # whole spectrogram
  trace_ = np.array([np.trace(np.cov(real_part[k,:,:])) for k in range(spgram.shape[0])])
  TV_aniso,TV_iso = calculate_TVs(spgram)

  # late part
  late_real = real_part[:,:,late_start:]
  TV_aniso_late,TV_iso_late = calculate_TVs(spgram[:,band,late_start:])

  # main signal part
  main_real = real_part[:,band,:early_stop]
  TV_aniso_main_sig,TV_iso_main_sig = calculate_TVs(spgram[:,band,:early_stop])

  labelled = [
    ('MEAN_total', np.mean(real_part,axis = (1,2))),
    ('median_total', np.median(real_part,axis = (1,2))),
    ('STD_total', np.std(real_part,axis = (1,2))),
    ('trace_total', trace_),
    ('TV_aniso_total', TV_aniso),
    ('TV_iso_total', TV_iso),
    ('sum_late', np.sum(np.abs(late_real),axis=(1,2))),
    ('MEAN_late', np.mean(late_real,axis=(1,2))),
    ('STD_late', np.std(late_real,axis=(1,2))),
    ('TV_aniso_late', TV_aniso_late),
    ('TV_iso_late', TV_iso_late),
    ('MEAN_main_sig', np.mean(main_real,axis=(1,2))),
    ('STD_main_sig', np.std(main_real,axis=(1,2))),
    ('TV_aniso_main_sig', TV_aniso_main_sig),
    ('TV_iso_main_sig', TV_iso_main_sig),
  ]
  names = [label for label, _ in labelled]
  seq_stats = tuple(values for _, values in labelled)

  return stats(seq_stats,names)

In [12]:
def write_data(file_path,data):
  """Append `data` to the CSV file at `file_path`.

  `data` may be a single row (a flat sequence of cell values) or a table
  (a non-empty sequence whose items are all lists or tuples). A table is
  written one row per item. Previously a table was handed to `writerow`,
  which stored every inner list as one stringified cell on a single line.

  The file is opened in append mode, so repeated calls (and repeated
  notebook runs) keep adding rows to the same file.
  """
  rows = list(data)
  is_table = len(rows) > 0 and all(isinstance(row, (list, tuple)) for row in rows)

  # Add new data to CSV file
  with open(file_path, mode='a', newline='') as file:
    writer = csv.writer(file)
    if is_table:
      writer.writerows(rows)
    else:
      writer.writerow(rows)

In [13]:
def from_dict_to_list(metrics_dict,names):
  """Flatten a metrics dictionary into ``[mean, std, mean, std, ...]``.

  The entries follow the order of `names`; each name contributes the
  'mean' and then the 'std' value stored under ``metrics_dict[name]``.
  """
  flat = []
  for label in names:
    entry = metrics_dict[label]
    flat.extend((entry['mean'], entry['std']))

  return flat

In [14]:
# HDF5 file holding the ground-truth data; the relative path is resolved
# against the notebook's working directory.
path_gt_file = '../sample_data.h5'

In [15]:
# Number of leading entries kept from every dataset when loading; also the
# sample count passed to the metric functions in the processing loop.
qntty = 200

In [16]:
#Import data obtained on EditedMRS_Reconstruction_Challenge github -- Ground-truths
# The file is opened explicitly read-only, and each dataset is sliced directly
# so that only the first `qntty` entries are read from disk (the previous
# `[()][:qntty]` loaded the whole dataset into memory before slicing).
with h5py.File(path_gt_file, 'r') as hf:
  print(hf.keys())
  gt_fids = hf["ground_truth_fids"][:qntty]
  ppm = hf["ppm"][:qntty]
  t = hf["t"][:qntty]
  print(gt_fids.shape)
  print(ppm.shape)
  print(t.shape)

<KeysViewHDF5 ['ground_truth_fids', 'ppm', 't']>
(200, 2048, 2)
(200, 2048)
(200, 2048)


In [17]:
# Output CSV files, one per metric family (spectrum, histogram, spectrogram).
# write_data opens them in append mode, so re-running the processing loop adds
# rows (and another header line) to files left over from a previous run.
file_spectrum_path = 'data_Real_Norm_ABS_STFT_FID_spectrum.csv'
file_hist_path = 'data_Real_Norm_ABS_STFT_FID_hist.csv'
file_spgram_path = 'data_Real_Norm_ABS_STFT_FID_spgram.csv'

In [18]:
# General acquisition parameters derived from the loaded data.
# dwelltime is the spacing between the first two samples of the first time
# axis; bandwidth (its inverse) is read as a global by
# get_normalized_spectrogram; N is the length of axis 1 of the FIDs.
dwelltime = t[0,1]-t[0,0]
bandwidth = 1/dwelltime
N = gt_fids.shape[1]

# Ground-truth spectra: inverse FFT along axis 1, then shifted so the zero
# frequency sits in the middle. The difference is taken between index 1 and
# index 0 of the last axis (presumably the two edited sub-acquisitions --
# confirm against the dataset description).
spectra_gt_fids = np.fft.fftshift(np.fft.ifft(gt_fids,n=N,axis = 1), axes = 1)
spectra_gt_diff = spectra_gt_fids[:,:,1] - spectra_gt_fids[:,:,0]
# Matching frequency axis, flipped so it runs from high to low frequency.
freq = np.flip(np.fft.fftshift(np.fft.fftfreq(N, d = dwelltime)))

# Linear frequency -> ppm map, p = a*f + b, fixed by two reference points:
# the positions of the minimum and of the maximum of the real part of the
# first ground-truth difference spectrum.
idx_min = np.real(spectra_gt_diff[0,:]).argmin()
idx_max = np.real(spectra_gt_diff[0,:]).argmax()
#p = a*f + b
a = (ppm[0,idx_max] - ppm[0,idx_min])/(freq[idx_max]-freq[idx_min])
b = ppm[0,idx_max] - a*freq[idx_max]
#ppm_aux = b + freq*a

# Metric names per family; their order fixes the CSV column order because
# from_dict_to_list walks these lists.
names_stats_spectrum = ['SNR','STD','peak']
names_stats_hist = ['mode','max','width','skewness','kurtosis']
names_stats_spgram = ['MEAN_total','median_total','STD_total','trace_total','TV_aniso_total','TV_iso_total',
                      'sum_late','MEAN_late','STD_late','TV_aniso_late','TV_iso_late',
                      'MEAN_main_sig','STD_main_sig','TV_aniso_main_sig','TV_iso_main_sig']

# Noise levels to sweep: entry i of each list is passed together to
# create_corrupted_fids as (std_base, std_var), so both lists must have the
# same length.
std_basis = [1,2,3,4,5,6,7,8,9,10,12,15,17,19]
var_basis = [0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,1.5,1.5,1.5,1.5]

print('starting process...')

def _append_metrics_row(file_path, metric_names, metrics, std_base, var, with_header):
    """Append one noise level's summary to a CSV file.

    The header (when requested) and the data row are written by separate
    `write_data` calls, each receiving one flat row. Passing both together
    as a nested list made `csv.writer.writerow` store each inner list as a
    single stringified cell.
    """
    if with_header:
        header = ['std_base','var']
        for name in metric_names:
            # column order mirrors from_dict_to_list: mean first, then std
            header += ['mean_'+name, 'std_'+name]
        write_data(file_path, header)
    write_data(file_path, [std_base, var] + from_dict_to_list(metrics, metric_names))

# One pass per noise level: corrupt the ground truths, average the difference
# signal over the transients, compute the three metric families and append one
# row per family to its CSV file. Headers are written on the first pass only;
# the files are opened in append mode, so delete stale files before re-running.
for level_idx, (std_base, var) in enumerate(zip(std_basis, var_basis)):
    print('std:',std_basis[level_idx],'index:',level_idx)
    start = datetime.datetime.now()
    corrupted_fids = create_corrupted_fids(gt_fids,std_base,var)
    end = datetime.datetime.now()
    print('got corrupted fids',(end-start),'s')
    #spectrum
    spectra_corrupted_fids = np.fft.fftshift(np.fft.ifft(corrupted_fids,n=N,axis = 1), axes = 1)
    # difference between index 1 and index 0 of axis 2, averaged over the last axis
    spectra_corrupted_diff_avgs = np.mean((spectra_corrupted_fids[:,:,1,:] - spectra_corrupted_fids[:,:,0,:]), axis = 2)

    start = datetime.datetime.now()
    SNR_corrupted_avgs, std_corrupted_avgs, peak_corrupted_avgs = spect_SNR_estimation(spectra_corrupted_diff_avgs, qntty, ppm, 2.79, 3.55)
    metrics_spectrum_corrupted_avgs = stats((SNR_corrupted_avgs, std_corrupted_avgs, peak_corrupted_avgs),names_stats_spectrum)
    end = datetime.datetime.now()
    print('got spectrum metrics',(end-start),'s')

    #spectrogram
    corrupted_fids_avgs = np.mean((corrupted_fids[:,:,1,:]-corrupted_fids[:,:,0,:]), axis = 2)
    start = datetime.datetime.now()
    spgram_corrupted_avgs, freq_spect, ppm_spect, t_spect = get_normalized_spectrogram(corrupted_fids_avgs,qntty,a,b,True)
    metrics_spgram_corrupted_avgs = get_spgram_metrics(spgram_corrupted_avgs,t_spect,ppm_spect)
    end = datetime.datetime.now()
    print('got spectrogram metrics',(end-start),'s')

    #histogram
    start = datetime.datetime.now()
    hist_corrupted_avgs, bins_hist = get_histogram(spgram_corrupted_avgs,qntty)
    metrics_hist_corrupted_avgs = get_histogram_metrics(hist_corrupted_avgs,bins_hist)
    end = datetime.datetime.now()
    print('got histogram metrics',(end-start),'s')

    first_level = (level_idx == 0)

    _append_metrics_row(file_spectrum_path, names_stats_spectrum, metrics_spectrum_corrupted_avgs,
                        std_base, var, first_level)
    print('saved spectrum file')

    _append_metrics_row(file_spgram_path, names_stats_spgram, metrics_spgram_corrupted_avgs,
                        std_base, var, first_level)
    print('saved spgram file')

    _append_metrics_row(file_hist_path, names_stats_hist, metrics_hist_corrupted_avgs,
                        std_base, var, first_level)
    print('saved hist file')

    print('done: '+str(level_idx+1)+'/'+str(len(std_basis)))

starting process...
std: 1 index: 0
got corrupted fids 0:02:30.421548 s
got spectrum metrics 0:00:00.312078 s
got spectrogram metrics 0:00:19.730835 s
got histogram metrics 0:00:09.913201 s
saved spectrum file
saved spgram file
saved hist file
done: 1/14
std: 2 index: 1
got corrupted fids 0:02:09.459835 s
got spectrum metrics 0:00:00.308410 s
got spectrogram metrics 0:00:21.567080 s
got histogram metrics 0:00:10.405911 s
saved spectrum file
saved spgram file
saved hist file
done: 2/14
std: 3 index: 2
got corrupted fids 0:02:33.867145 s
got spectrum metrics 0:00:00.359092 s
got spectrogram metrics 0:00:21.171854 s
got histogram metrics 0:00:12.174914 s
saved spectrum file
saved spgram file
saved hist file
done: 3/14
std: 4 index: 3
got corrupted fids 0:03:11.528545 s
got spectrum metrics 0:00:00.245355 s
got spectrogram metrics 0:00:19.962153 s
got histogram metrics 0:00:12.618511 s
saved spectrum file
saved spgram file
saved hist file
done: 4/14
std: 5 index: 4
got corrupted fids 0:02: