In [None]:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
# get_ipython().run_line_magic('matplotlib', 'inline')
import numpy as np
import pandas as pd
import pickle
import scipy.signal
import glob
import cv2
from skimage.exposure import rescale_intensity
import argparse
from itertools import groupby
from skimage import color, data, restoration
import h5py
import random
from tqdm import tqdm
import os
import time
from patchify import patchify, unpatchify

In [None]:
# Sample rate (Hz) assumed for the stored BES signal; specgr uses it to turn
# cut_shot into a sample count.
# NOTE(review): an earlier comment claimed the raw signal is sampled at 4 MHz,
# which contradicts this value -- confirm against the data source.
fs = 500000
spec_params = {
    'nperseg': 512,        # samples per FFT segment (scipy default for a string window: 256)
    'noverlap': 256,       # samples shared by consecutive segments (scipy default: nperseg // 8)
    'fs': fs,
    'window': 'hamm',      # alias scipy.signal.get_window accepts for 'hamming'
    'scaling': 'density',  # {'density', 'spectrum'}
    'detrend': 'linear',   # {'linear', 'constant', False}
    'eps': 1e-11,          # not a scipy argument: floor added before np.log in specgr
}

# returns spectrogram
def specgr(fname, ecen, spec_params, cut_shot):
    """
    Load one BES channel from a pickle file and return its min-max scaled log-spectrogram.

    Parameters
    ----------
    fname : str or path-like
        Pickle file holding a dict keyed by channel names ``'besfuNN'``; each
        entry must contain the raw samples under ``'data.BES'``.
    ecen : int
        Channel number, formatted as two digits into the key (1 -> 'besfu01').
    spec_params : dict
        Must contain 'nperseg', 'noverlap', 'fs', 'window', 'scaling' and
        'detrend' (forwarded to scipy.signal.spectrogram) plus 'eps', which is
        added before taking the log.
    cut_shot : float
        Only the first ``int(cut_shot * fs)`` samples are used (seconds if fs
        is in Hz).

    Returns
    -------
    Sxx : ndarray, shape (n_freq - 1, n_time)
        ``log(power + eps)`` linearly scaled so the full spectrogram spans
        [0, 1] (all zeros if it is constant), with the last frequency row
        dropped afterwards.
    f : ndarray, shape (n_freq - 1,)
        Frequencies matching the rows of Sxx.
    t : ndarray, shape (n_time,)
        Segment times returned by scipy.
    """
    # SECURITY: pickle.load can execute arbitrary code -- only open trusted files.
    # `with` closes the handle; the original left it to the garbage collector.
    with open(fname, 'rb') as fh:
        ece_data = pickle.load(fh)
    ece_num = 'besfu{:02d}'.format(ecen)
    n_keep = int(cut_shot * spec_params['fs'])
    sig_in = ece_data[ece_num]['data.BES'][:n_keep]
    f, t, Sxx = scipy.signal.spectrogram(sig_in,
                                         nperseg=spec_params['nperseg'],
                                         noverlap=spec_params['noverlap'],
                                         fs=spec_params['fs'],
                                         window=spec_params['window'],
                                         scaling=spec_params['scaling'],
                                         detrend=spec_params['detrend'])
    log_sxx = np.log(Sxx + spec_params['eps'])
    lo, hi = log_sxx.min(), log_sxx.max()
    if hi > lo:
        Sxx = (log_sxx - lo) / (hi - lo)
    else:
        # Constant spectrogram: avoid 0/0 -> NaN.
        Sxx = np.zeros_like(log_sxx)
    # Drop the last frequency bin; with nperseg=512 this leaves 256 rows,
    # matching the 256-row patch shape used elsewhere in this notebook.
    return Sxx[:-1, :], f[:-1], t

def norm(data):
    """Standardize `data`: subtract its mean, then divide by its (population) standard deviation."""
    centered = data - data.mean()
    return centered / data.std()

def rescale(data):
    """
    Min-max scale `data` linearly onto [0, 1].

    Parameters
    ----------
    data : ndarray
        Numeric array (float, or uint8 images such as cv2 filter output).

    Returns
    -------
    ndarray
        Same shape as `data`; the minimum maps to 0 and the maximum to 1.
        A constant array (zero range) returns all zeros instead of the NaNs
        that 0/0 would produce.
    """
    lo = data.min()
    span = data.max() - lo
    if span == 0:
        return np.zeros_like(data, dtype=float)
    return (data - lo) / span

def quantfilt(src, thr=0.9):
    """
    Zero out every entry that lies below its column's `thr`-quantile.

    For each column j, values of ``src[:, j]`` smaller than
    ``np.quantile(src[:, j], thr)`` become 0; all other values pass through.
    """
    column_cut = np.quantile(src, thr, axis=0)
    below = src < column_cut
    return np.where(below, 0, src)

# gaussian filtering
def gaussblr(src, filt=(31, 3)):
    """
    Gaussian-blur `src` after quantizing it to 8 bits, and return the result rescaled to [0, 1].

    `filt` is the OpenCV kernel size, given as (width, height). Sigma 0 lets
    OpenCV derive sigma from the kernel size.
    """
    as_bytes = (255 * rescale(src)).astype('uint8')
    blurred = cv2.GaussianBlur(as_bytes, filt, 0)
    return rescale(blurred)

# mean filtering
def meansub(src):
    """
    Remove each row's mean, take the absolute deviation, and rescale to [0, 1].

    For a (freq, time) spectrogram this subtracts each frequency row's time
    average.
    """
    row_means = np.mean(src, axis=1, keepdims=True)
    deviation = np.abs(src - row_means)
    return rescale(deviation)

# morphological filtering
def morph(src):
    """
    Morphological cleanup of `src`: a closing with a 4x4 rectangle, then an
    opening with a 3-wide, 1-tall rectangle (OpenCV sizes are (width, height)).

    The input is quantized to 8 bits first; the output is rescaled to [0, 1].
    """
    img = (rescale(src) * 255).astype('uint8')
    close_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (4, 4))
    open_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 1))
    closed = cv2.morphologyEx(img, cv2.MORPH_CLOSE, close_kernel)
    opened = cv2.morphologyEx(closed, cv2.MORPH_OPEN, open_kernel)
    return rescale(opened)

#reshapes the data
def reshape(arr):
    """Reshape a batch to (len(arr), 256, 128, 1), i.e. 256x128 single-channel images."""
    target_shape = (len(arr), 256, 128, 1)
    return np.reshape(arr, target_shape)

# splits each spectrogram into a row of non-overlapping patches (inverse of unpatch)
def patch(arr, n_patches=30, patch_shape=(256, 128)):
    """
    Cut each spectrogram in `arr` into `n_patches` side-by-side patches.

    Patch ``x`` of spectrogram ``i`` is ``arr[i][:ph, x*pw:(x+1)*pw]`` and is
    stored at output index ``x + n_patches * i``. Rows beyond ``ph`` and
    columns beyond ``n_patches * pw`` are discarded. The defaults reproduce
    the original layout: 30 patches of 256 x 128 with step 128.

    Parameters
    ----------
    arr : sequence of 2-D arrays
        Spectrograms, each at least ``ph`` rows by ``n_patches * pw`` columns.
    n_patches : int
        Number of patches taken from each spectrogram.
    patch_shape : tuple of int
        ``(ph, pw)``: the height and width of one patch.

    Returns
    -------
    ndarray of float, shape ``(len(arr) * n_patches, ph, pw)``

    Raises
    ------
    ValueError
        If a spectrogram is too small to supply `n_patches` full patches.
    """
    ph, pw = patch_shape
    width = n_patches * pw
    all_patches = np.empty((len(arr) * n_patches, ph, pw))
    for i, spec in enumerate(arr):
        spec = np.asarray(spec)
        if spec.shape[0] < ph or spec.shape[1] < width:
            raise ValueError('spectrogram {} has shape {}, need at least ({}, {})'.format(
                i, spec.shape, ph, width))
        # (ph, n*pw) -> (ph, n, pw) -> (n, ph, pw): patch x holds columns x*pw .. x*pw + pw - 1
        all_patches[i * n_patches:(i + 1) * n_patches] = (
            spec[:ph, :width].reshape(ph, n_patches, pw).transpose(1, 0, 2))
    return all_patches

# reassembles groups of patches into full-width spectrograms (inverse of patch)
def unpatch(arr, n_patches=30):
    """
    Rebuild spectrograms by laying consecutive groups of patches side by side.

    Patches ``n_patches*i .. n_patches*i + n_patches - 1`` are concatenated
    left to right (along axis 1) to form spectrogram ``i``. With the default
    group size and 256 x 128 patches, each result is 256 x 3840, as before.

    Parameters
    ----------
    arr : sequence of 2-D arrays (or a 3-D array)
        Patches of identical shape, ordered as produced by splitting.
    n_patches : int
        Number of patches per spectrogram.

    Returns
    -------
    ndarray, shape ``(len(arr) // n_patches, ph, n_patches * pw)``
        Trailing patches that do not fill a complete group are ignored
        (the same floor division as before).
    """
    n_specs = len(arr) // n_patches
    return np.array([np.concatenate(arr[i * n_patches:(i + 1) * n_patches], axis=1)
                     for i in range(n_specs)])

# displays Sxx and final
def display(Sxx, final, n=5, times=None, freqs=None):
    """
    Plot `n` randomly chosen spectrogram pairs, each original drawn above its processed version.

    NOTE(review): this name shadows IPython's built-in ``display`` in the
    notebook namespace; it is kept so existing callers still work.

    Parameters
    ----------
    Sxx, final : array-like indexed by spectrogram number
        ``Sxx[k]`` and ``final[k]`` are 2-D (freq, time) images.
    n : int
        Number of pairs to plot (default 5, the previously fixed value).
    times : 1-D array, optional
        Time axis. Defaults to the notebook-global ``t`` (set by the
        processing cell), as the original did implicitly. It is truncated
        to each image's width instead of a fixed 3840.
    freqs : 1-D array, optional
        Frequency axis in Hz. Defaults to the notebook-global ``f``.
    """
    if times is None:
        times = t  # implicit global from the processing loop (original behaviour)
    if freqs is None:
        freqs = f
    # Draw distinct spectrograms when there are enough, so panels don't repeat.
    idx = np.random.choice(len(Sxx), size=n, replace=len(Sxx) < n)
    # Frequency axis in kHz; the +1 offset is kept from the original plot.
    f_khz = (freqs / 1000) + 1

    fig = plt.figure(figsize=(8, 2.4 * n))
    grd = gridspec.GridSpec(ncols=1, nrows=2 * n, figure=fig)
    for i, k in enumerate(idx):
        for row, (img, label) in enumerate(((Sxx[k], 'Original'), (final[k], 'Final'))):
            ax = fig.add_subplot(grd[2 * i + row])
            ax.pcolormesh(times[:img.shape[-1]], f_khz, img, cmap='hot', shading='gouraud')
            ax.set_ylabel(label)

    plt.show()
    
    
#### modifications by finn ####
def omega(beta):
    """
    Cubic approximation of omega(beta) for optimal singular-value hard thresholding
    with unknown noise level: 0.56*b**3 - 0.95*b**2 + 1.82*b + 1.43.

    follows http://www.pyrunner.com/weblog/2016/08/01/optimal-svht/
    """
    coefficients = (0.56, -0.95, 1.82, 1.43)
    total = 0
    for power, c in zip(range(3, -1, -1), coefficients):
        total += c * beta ** power
    return total

def computeSignal(matrix):
    """
    Reconstruct `matrix` from singular components chosen via a median-based threshold.

    The threshold is ``t_star = omega(beta) * median(s)`` with
    ``beta = min(shape) / max(shape)``, and ``num_sing`` counts the singular
    values above it. The reconstruction sums components
    ``1 .. 2*num_sing - 1``, clamped to the components that exist.

    Parameters
    ----------
    matrix : ndarray
        The 2-D data matrix.

    Returns
    -------
    ndarray of float, shaped like matrix
    """
    u, s, vh = np.linalg.svd(matrix, full_matrices=False)
    beta = np.min(matrix.shape) / np.max(matrix.shape)
    t_star = omega(beta) * np.median(s)
    num_sing = int((s > t_star).sum())
    # NOTE(review): plain hard thresholding keeps components 0 .. num_sing-1.
    # This range skips the leading component and doubles the count; it is
    # preserved from the original -- confirm it is intended.
    # `stop` is clamped to len(s): the previous loop raised IndexError
    # whenever 2*num_sing exceeded the number of singular values.
    start, stop = 1, min(2 * num_sing, len(s))
    out = u[:, start:stop] @ np.diag(s[start:stop]) @ vh[start:stop, :]
    return out.astype(float, copy=False)

def denoiseSignal(matrix, start=None, stop=None, use_optimal=False):
    """
    Reconstruct `matrix` from a contiguous slice of its singular components.

    Computes the thin SVD ``matrix = U @ diag(s) @ Vh`` (singular values in
    descending order) and returns
    ``U[:, start:stop] @ diag(s[start:stop]) @ Vh[start:stop, :]``.

    Parameters
    ----------
    matrix : ndarray
        The 2-D data matrix.
    start : int, optional
        Index of the first singular component to keep (0 = largest).
        Defaults to 1, so the leading component is dropped. Negative values
        are clamped to 0.
    stop : int, optional
        Index one past the last component to keep (Python exclusive-stop
        convention). Defaults to ``len(s)``; larger values are clamped.
    use_optimal : bool
        If True, ignore `start`/`stop` and keep every component whose
        singular value exceeds ``omega(beta) * median(s)``. This is the
        Gavish-Donoho hard threshold for Gaussian noise of unknown level.

    Returns
    -------
    ndarray shaped like matrix
    """
    u, s, vh = np.linalg.svd(matrix, full_matrices=False)
    if use_optimal:
        beta = np.min(matrix.shape) / np.max(matrix.shape)
        t_star = omega(beta) * np.median(s)
        start = 0
        # Exclusive stop: keep all num_sing values above the threshold.
        # (Previously num_sing - 1, which dropped one of them, and when
        # num_sing == 0 gave stop=-1, i.e. kept every component but the last.)
        stop = int((s > t_star).sum())
    else:  # fall back to defaults
        if start is None:
            start = 1
        if stop is None:
            stop = len(s)
    # Prevent out-of-range values from being used.
    start = max(start, 0)
    stop = min(stop, len(s))
    return u[:, start:stop] @ np.diag(s[start:stop]) @ vh[start:stop, :]

In [None]:
# Process shot 122117: spectrogram, filter chain and SVD denoising for each BES channel.
shotn = '122117'

# Data directory; set the BES_DATA_PATH environment variable instead of editing this path.
data_path = os.environ.get(
    "BES_DATA_PATH",
    "/Users/foshea/Documents/Projects/Anomaly Detection/Plasma/spectrogram-enhancement/files",
)
fname = os.path.join(data_path, shotn + 'BES')

Sxx = []        # scaled log-spectrogram per channel (from specgr)
processed = []  # quantfilt -> gaussblr -> meansub -> morph -> meansub output
svded = []      # denoiseSignal with default start/stop (leading component dropped)
thr = 0.9       # quantile threshold passed to quantfilt

# Channels 1..30 (keys 'besfu01'..'besfu30').
# NOTE(review): the original note says this limit works around an unspecified
# error -- confirm which channels fail before extending the range.
# NOTE(review): specgr re-reads the whole pickle file for every channel.
for chn in tqdm(range(30), desc='channels'):
    s, f, t = specgr(fname, chn + 1, spec_params, 2)  # cut_shot=2: first 2*fs samples

    out_quant = quantfilt(s, thr)
    out_gauss = gaussblr(out_quant, (31, 3))
    out_mean = meansub(out_gauss)
    out_morph = morph(out_mean)
    out_final = meansub(out_morph)

    svd = denoiseSignal(s)

    Sxx.append(s)
    processed.append(out_final)
    svded.append(svd)

In [None]:
channel = 0  # index into the per-channel lists; the BES channel number is channel + 1

hacked = svded[channel].copy()
hacked[hacked < 0.0] = 0.0  # clip negative values of the SVD reconstruction

datas = [Sxx[channel],
         processed[channel],
         svded[channel],
         hacked
        ]

titles = ["spectrogram", "processed", "SVD'd", "SVD'd > 0"]

# Physical extent, so the tick values match the 'time (ms)' / 'f (kHz)' labels
# (without it imshow shows bin indices). t and f come from the last specgr call
# above; every channel uses the same spec_params. t holds segment centres, so
# the edges are approximate to within half a bin.
extent = (t[0] * 1e3, t[-1] * 1e3, f[0] / 1e3, f[-1] / 1e3)

fig, axs = plt.subplots(4, 2, sharex='col', figsize=(16, 12), gridspec_kw={'width_ratios': [3, 1]})
# Sxx[channel] was built with specgr(..., channel + 1, ...), i.e. 'besfu{channel+1:02d}'.
fig.suptitle('BES, shot number: {:s}, channel: {:02d}'.format(shotn, channel + 1))
for ax, d, title in zip(axs, datas, titles):
    counts, edges = np.histogram(d.ravel(), bins=50, density=True)
    ax[1].bar(x=edges[:-1], height=counts, width=(edges[1] - edges[0]), align='edge')
    ax[1].set_yscale('log')
    ax[0].imshow(d, origin='lower', aspect='auto', cmap='hot', extent=extent)
    ax[0].set_ylabel('f (kHz)')
    ax[0].set_title(title)
axs[-1, 0].set_xlabel('time (ms)')
axs[-1, 1].set_xlabel('value')
plt.show()
