In [1]:
import sys
import numpy as np
import matplotlib.pyplot as plt
import torch, sys, os, copy
import sigpy as sp
from scipy import signal
from scipy.stats import norm
import h5py
from os import listdir
from os.path import isfile, join
from tqdm import tqdm as tqdm_base

from ncsnv2.models import get_sigmas
from parameters import pairwise_dist
from parameters import step_size
from dotmap import DotMap

from tqdm import tqdm as tqdm_base
def tqdm(*args, **kwargs):
    """Return a new ``tqdm_base`` progress bar after unregistering stale ones.

    Every instance currently tracked in ``tqdm_base._instances`` is removed via
    ``tqdm_base._decr_instances`` before the new bar is created. Presumably
    this stops bars left over from interrupted or re-run notebook cells from
    interfering with new ones. All arguments are forwarded unchanged.
    """
    if not hasattr(tqdm_base, '_instances'):
        return tqdm_base(*args, **kwargs)
    # Snapshot first: _decr_instances mutates the registry we iterate over.
    for stale_bar in list(tqdm_base._instances):
        tqdm_base._decr_instances(stale_bar)
    return tqdm_base(*args, **kwargs)

def sigma_rate(tqdm, shape):
    """Choose the geometric noise-decay rate gamma using Song's Technique 2.

    Scans 1000 log-spaced candidates between 0.9 and 0.99999 and returns the
    one whose criterion

        Phi(sqrt(2D) * (gamma - 1) + 3 * gamma) - Phi(sqrt(2D) * (gamma - 1) - 3 * gamma)

    is closest to 0.5. Here D = prod(shape) and Phi is the standard normal CDF.

    Args:
        tqdm: unused; kept only so existing call sites keep working.
        shape: dimensions of one training sample; only their product is used.

    Returns:
        The selected gamma, an element of the candidate grid.
    """
    gammas = np.logspace(np.log10(0.9), np.log10(0.99999), 1000)
    root_two_d = np.sqrt(2 * np.prod(shape))

    centre = root_two_d * (gammas - 1)
    half_width = 3 * gammas
    criterion = norm.cdf(centre + half_width) - norm.cdf(centre - half_width)

    return gammas[np.abs(criterion - 0.5).argmin()]

In [2]:
def display(x, vmax=None):
    """Plot the magnitude of a two-channel (real, imaginary) image in grayscale.

    NOTE(review): this name shadows IPython's built-in ``display``.

    Args:
        x: torch tensor (on any device, with or without grad) or NumPy array.
           ``x[0]`` is the real part and ``x[1]`` is the imaginary part.
        vmax: optional upper limit of the colour range. ``None`` lets matplotlib
              autoscale.
    """
    if isinstance(x, torch.Tensor):
        # detach() so tensors that require grad can still be converted to NumPy
        x = x.detach().cpu()
    image = np.asarray(x[0] + 1j * x[1], dtype=np.complex64)

    # vmax=None is imshow's own default, so a single call covers both cases
    plt.imshow(np.abs(image), cmap='gray', vmax=vmax)
    plt.gca().invert_yaxis()
    plt.xlabel('doppler')
    plt.ylabel('range')

In [None]:
# path = '/csiNAS2/slow/mridata/skm_tea2/qdess/v1-release/files_recon_calib-24/'
# file_paths = [path+f for f in listdir(path) if isfile(join(path, f))]

# data = []

# for file in tqdm(file_paths):
#     with h5py.File(file, 'r') as F:
#         if F.keys():
#             target = np.squeeze(np.array(F['target'])[:, :, 70:135].transpose(-1, -3, -2, 0, 1))
#             for image in target:
#                 data.append(image)

# data = np.array(data)

In [None]:
# path = sys.path[0] + '/data/train-val-data/'
# path = path + 'train' + '-' + 'experiment-' + '1' + '.txt'
# filenames = np.loadtxt(path, dtype='str')
# data = []

# # ~850  files
# for ii in range(len(filenames)):
#     with np.load(filenames[ii], allow_pickle=True) as f:
#         data.append(f['data'])

# reshaped_data = np.array(data, dtype=np.complex64).transpose(0, 3, 1, 2)
# reshaped_data = np.reshape(reshaped_data, (-1, reshaped_data.shape[-2], reshaped_data.shape[-1]))
# channels = reshaped_data.copy()

# masks = []
# thr = 0.3
# for i in range(len(channels)):
#     z = np.max(np.abs(channels[i]))
#     mask = np.abs(channels[i]) > thr * z
#     masks.append(mask)
#     channels[i] = channels[i] * mask

# reshaped_data = np.stack((np.real(reshaped_data), np.imag(reshaped_data)), axis=1)
# channels = np.stack((np.real(channels), np.imag(channels)), axis=1)
# masks = np.stack((np.real(masks), np.imag(masks)), axis=1)

In [None]:
# X = []

# for i in range(0, 10000):
#     x0 = np.zeros((512*512), dtype=complex)
#     Npoints = np.random.randint(1,10)
#     idx = np.random.randint(512*512, size=Npoints)

#     x0[idx] = np.random.randn(*idx.shape) + 1j * np.random.randn(*idx.shape)

#     x0 = x0.reshape((512, 512))
#     N = 2
#     filt = np.outer(np.hamming(N), np.hamming(N))
#     x0 = signal.convolve(x0, filt, mode='same')
#     # y0 = sp.fft(x0, axes=(-1,-2))
#     x0 = x0 / np.max(np.abs(x0))
#     X.append(x0)

# torch.save({'X': X}, 'data/mri-data/skm-tea.pt')

# channels = torch.load(sys.path[0] + '/data/mri-data/skm-tea-128.pt')['X']
# channels = np.stack((np.real(channels), np.imag(channels)), axis=1)

In [3]:
# path = '/csiNAS/sidharth/T2_shuffling_data/checked_data'
# file_paths = [path + '/' + f for f in listdir(path) if isfile(join(path, f))]

# channels = np.zeros((4800, 240, 240), dtype=np.complex64)

# i = 0
# for file in tqdm(file_paths):
#     X = torch.load(file)['final_images']
#     for image in X:
#         channels[i] = image
#         i += 1

# Load the pre-processed knee MRI images and split each complex image into a
# two-channel real/imag representation: channels[i][0] = real, channels[i][1] = imag.
# NOTE(review): presumably 'X' holds complex (N, H, W) images; confirm against the saved file.
knee_path = sys.path[0] + '/data/mri-data/knee-tea2.pt'
complex_images = torch.load(knee_path)['X']
channels = np.stack([np.real(complex_images), np.imag(complex_images)], axis=1)

In [69]:
# path = '/home/asad/score-based-channels/data/marius-data/'
# filenames = [path + f for f in os.listdir(path)]
# channels  = []

# contents = hdf5storage.loadmat(filenames[4])
# channel = np.asarray(contents['output_h'], dtype=np.complex64)

# channels.append(channel[:, 0])

# # Convert to array
# channels = np.asarray(channels)
# channels = np.reshape(channels, (-1, channels.shape[-2], channels.shape[-1]))

# channels = channels / np.max(np.abs(channels), axis=0)
# channels = np.stack((np.real(channels), np.imag(channels)), axis=1)

In [10]:
# Noise-schedule configuration for the score model.
# The noise levels form a geometric sequence:
#   sigma_end = sigma_begin * sigma_rate ** (num_classes - 1)
NUM_NOISE_LEVELS = 2311
# Fixed rate; the commented-out line below would compute it with sigma_rate() instead.
GEOMETRIC_RATE = 0.9954

config = DotMap()
config.model.num_classes = NUM_NOISE_LEVELS
config.model.sigma_dist = 'geometric'
config.device = 'cuda:0'

# Largest noise level, read from file (presumably the max pairwise distance
# between training images, per the file name).
sigma_begin_path = sys.path[0] + '/parameters/knee-mri_max_pairwise_dist.txt'
config.model.sigma_begin = np.loadtxt(sigma_begin_path)

# config.model.sigma_rate = sigma_rate(tqdm, channels[0].shape)
config.model.sigma_rate = GEOMETRIC_RATE

decay_steps = config.model.num_classes - 1
config.model.sigma_end  = config.model.sigma_begin * config.model.sigma_rate ** decay_steps
config.model.step_size = step_size(config)

sigmas = get_sigmas(config)

In [11]:
# Show the start, ratio and end of the geometric noise schedule, one per line.
for schedule_value in (config.model.sigma_begin,
                       config.model.sigma_rate,
                       config.model.sigma_end):
    print(schedule_value)

256.06842041015625
0.9954
0.006065912270894371


In [None]:
# plt.figure(figsize=(8,8))
# plt.axis('off')
# plt.title('Original Image')
# display(torch.tensor((channels[200])))

# plt.figure(figsize=(8,8))
# plt.title('Mask')
# display(torch.tensor(masks[1]))

In [12]:
# Build a small gallery of noisy versions of the first few training images,
# one entry per sampled noise level, for visually checking the noise schedule.
NUM_PREVIEW_IMAGES = 10   # how many images from `channels` to perturb
LABEL_STRIDE = 200        # keep every LABEL_STRIDE-th noise level of `sigmas`
PREVIEW_SEED = 0          # fixes the noise draws so the gallery is reproducible

# The sampled noise-level indices are the same for every image: compute once.
preview_labels = range(0, len(sigmas), LABEL_STRIDE)

torch.manual_seed(PREVIEW_SEED)
images = []
for i in range(NUM_PREVIEW_IMAGES):
    # Use the configured device rather than a bare .cuda(), so this follows config.device.
    samples = torch.tensor(channels[i]).to(config.device)

    perturbed_images = {'original': samples,
                        'noise_level': [],
                        'noise': [],
                        'perturbed': []}

    for labels in preview_labels:
        # Reshape the scalar sigma to all-ones dims so it broadcasts over `samples`.
        used_sigmas = sigmas[labels].view([1] * samples.dim())
        noise = torch.randn_like(samples) * used_sigmas

        perturbed_images['noise_level'].append(float(used_sigmas))
        perturbed_images['noise'].append(noise)
        perturbed_images['perturbed'].append(samples + noise)

    images.append(perturbed_images)

In [16]:
# sample = images[9]

# for i in range(len(sample['noise_level'])):
#     plt.subplots(1, 3, figsize=(16, 6))
#     plt.subplot(1,3,1)
#     plt.title('original')
#     display(sample['original'])
#     plt.axis('off')

#     plt.subplot(1,3,2)
#     plt.title('noise: ' + str(sample['noise_level'][i]))
#     display(sample['noise'][i])
#     plt.axis('off')

#     plt.subplot(1,3,3)
#     plt.title('image + noise')
#     display(sample['perturbed'][i])
#     plt.axis('off')

In [15]:
# plt.figure(figsize=(16, 6))
# # display(torch.tensor(sp.ifft(np.array(sample['original'].cpu()), axes=(-1,-2))))
# display(sample['original'])

# plt.figure(figsize=(16, 6))
# # display(torch.tensor(sp.ifft(np.array(sample['perturbed'][-1].cpu()), axes=(-1,-2))))
# display(sample['perturbed'][-1])

In [None]:
# filename = sys.path[0] + '/data/noisy-images/range_doppler_th0.3-noisy_samples.pt'
# torch.save({'images': images}, filename)