In [1]:
## Import Packages
from __future__ import print_function

import numpy as np
import pandas as pd
from itertools import product

#Astro Software
import astropy.units as units
from astropy.coordinates import SkyCoord
from astropy.io import fits

#Plotting Packages
import matplotlib as mpl
import matplotlib.cm as cmplt
import matplotlib.pyplot as plt
from matplotlib import rcParams
from matplotlib.colors import LogNorm
from matplotlib.colors import ListedColormap

from PIL import Image, ImageDraw

from yt.config import ytcfg
import yt
import yt.units as u

#Scattering NN
import torch
import torch.nn.functional as F
from torch import optim
from kymatio.torch import Scattering2D
device = "cpu"

#Machine Learning
from sklearn.model_selection import train_test_split
from sklearn.mixture import GaussianMixture
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis as QDA
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score
from sklearn.metrics import precision_score
from sklearn.decomposition import PCA, FastICA
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error

import skimage
from skimage import filters as skfilters
from skimage.filters import window

import cv2

from scipy.optimize import curve_fit
from scipy import linalg
from scipy import stats
from scipy.signal import general_gaussian
from scipy.ndimage import map_coordinates


#I/O
import h5py
import pickle
import glob
import copy
import time
import os
import scipy.io as sio

#Plotting Style
%matplotlib inline
plt.style.use('dark_background')
rcParams['text.usetex'] = True
rcParams['axes.titlesize'] = 20
rcParams['xtick.labelsize'] = 16
rcParams['ytick.labelsize'] = 16
rcParams['legend.fontsize'] = 12
rcParams['axes.labelsize'] = 20
rcParams['font.family'] = 'sans-serif'

#Threading
# Call the setter. The previous `torch.set_num_threads=32` replaced the torch
# function with an int, so the thread count never changed and any later call
# to torch.set_num_threads(...) would raise TypeError.
# NOTE(review): the Pool(30) cells below run torch in 30 worker processes; if
# they inherit 32 intra-op threads each, the CPU may be oversubscribed -- confirm.
torch.set_num_threads(32)

import ntpath
def path_leaf(path):
    """Return the last component of ``path`` with its final extension removed.

    ``ntpath.split`` accepts both ``/`` and ``\\`` as separators, so POSIX- and
    Windows-style paths are handled alike; a trailing separator yields ''.
    """
    leaf = ntpath.split(path)[1]
    stem, _extension = os.path.splitext(leaf)
    return stem

def hd5_open(file_name,name):
    """Read a whole HDF5 dataset into memory and close the file.

    Parameters
    ----------
    file_name : str
        Path to the HDF5 file; opened read-only with ``swmr=True``.
    name : str
        Path of the dataset inside the file.

    Returns
    -------
    numpy.ndarray
        The dataset contents (a NumPy scalar for a scalar dataset).
    """
    # The context manager closes the file even when the read raises;
    # the original leaked the handle on error.
    with h5py.File(file_name, 'r', swmr=True) as f:
        # `[()]` reads the entire dataset, including scalar datasets for which
        # `[:]` is rejected; for array datasets it returns the same as `[:]`.
        return f[name][()]

from matplotlib.colors import LinearSegmentedColormap
# Segment data for a diverging colormap: blue at 0, black at the midpoint,
# red at 1 (green is zero throughout). Rows are (x, y_left, y_right).
cdict1 = {
    'red':   ((0.0, 0.0, 0.0),
              (0.5, 0.0, 0.0),
              (1.0, 1.0, 1.0)),
    'green': ((0.0, 0.0, 0.0),
              (1.0, 0.0, 0.0)),
    'blue':  ((0.0, 0.0, 1.0),
              (0.5, 0.0, 0.0),
              (1.0, 0.0, 0.0)),
}
# 5000 interpolation levels for a smooth gradient.
blue_red1 = LinearSegmentedColormap('BlueRed1', cdict1, N=5000)

  self[key]


In [None]:
# PreCalc the WST Network
# Hyper-parameters for kymatio's Scattering2D (per kymatio's docs: J sets the
# number of dyadic scales, L the number of orientations). The RWST fitting
# cells below assume the same J and L when unpacking the coefficient vector.
J = 8
L = 8
m = 2  # passed as max_order: coefficients up to second order are produced
# Built once at module level and reused for every image; MHD_process_dmock
# reads this global. Inputs must be 256x256 to match `shape`.
scattering = Scattering2D(J=J, shape=(256,256), L=L, max_order=m)

In [None]:
def WST_torch(src_img,scattering):
    """Apply a scattering transform to a NumPy image and return a NumPy array.

    Parameters
    ----------
    src_img : numpy.ndarray
        Input image; cast to float32 before conversion to a torch tensor.
    scattering : callable
        Maps a torch tensor to a torch tensor (e.g. the module-level
        kymatio ``Scattering2D`` instance).

    Returns
    -------
    numpy.ndarray
        The transform output, detached from autograd and moved to the CPU.
    """
    tensor = torch.from_numpy(src_img.astype(np.float32)).to(device).contiguous()
    # `.numpy()` alone raises if the result lives on a GPU or tracks
    # gradients; detach().cpu() is a no-op in the plain CPU/no-grad case.
    return scattering(tensor).detach().cpu().numpy()

In [None]:
def placement(src_img,loc,roll):
    """Place an image on a 256x256 canvas at one of three scales, then roll it diagonally.

    Parameters
    ----------
    src_img : numpy.ndarray
        2-D image. Used unchanged for ``loc == 0``; resized with cubic
        interpolation (``order=3``) for ``loc`` 1 and 2.
    loc : {0, 1, 2}
        0: use the image as-is; 1: resize to 128x128 and centre it on a zero
        256x256 canvas; 2: resize to 64x64 and centre it.
    roll : int
        Circular shift of ``roll * 32`` pixels applied along both axes.

    Returns
    -------
    numpy.ndarray
        A new array; ``src_img`` is not modified.

    Raises
    ------
    ValueError
        If ``loc`` is not 0, 1 or 2 (previously an UnboundLocalError).
    """
    canvas_size = 256
    if loc == 0:
        out_img = src_img
    elif loc in (1, 2):
        size = 128 if loc == 1 else 64
        start = (canvas_size - size) // 2  # 64 for loc 1, 96 for loc 2
        out_img = np.zeros((canvas_size, canvas_size))
        out_img[start:start + size, start:start + size] = skimage.transform.resize(src_img, (size, size), order=3)
    else:
        raise ValueError("loc must be 0, 1 or 2, got {!r}".format(loc))

    shift = roll * 32
    return np.roll(out_img, (shift, shift), axis=(0, 1))

In [None]:
def MHD_process_dmock(file_name,cum_sum,axis,slice_ind,dist_scale,sig,AR,roll,rot):
    """Flattened WST coefficients of one (optionally cumulative) slice of a FITS cube.

    Parameters
    ----------
    file_name : str
        FITS file whose primary HDU holds the 3-D cube.
    cum_sum : {0, 1}
        0: take a single plane; 1: take the running sum (``np.cumsum``) along
        ``axis`` evaluated at that plane.
    axis : {0, 1, 2}
        Cube axis along which the plane is selected.
    slice_ind : int
        Plane index in steps of 8 (plane = ``slice_ind * 8``).
    dist_scale : int
        Placement mode forwarded to ``placement`` (0, 1 or 2).
    sig, AR :
        Unused (the PSF step is commented out); kept for call compatibility.
    roll : int
        Roll amount forwarded to ``placement``.
    rot :
        If nonzero, the placed image is passed to ``rotate_image`` with it.

    Returns
    -------
    numpy.ndarray
        Flattened output of ``WST_torch`` on the standardized image.

    Raises
    ------
    ValueError
        If ``cum_sum`` or ``axis`` is out of range (previously the function
        fell through to an UnboundLocalError).
    """
    print(file_name,cum_sum,axis,slice_ind,dist_scale,sig,AR,roll,rot)

    if cum_sum not in (0, 1):
        raise ValueError("cum_sum must be 0 or 1, got {!r}".format(cum_sum))
    if axis not in (0, 1, 2):
        raise ValueError("axis must be 0, 1 or 2, got {!r}".format(axis))

    with fits.open(file_name) as hdul:
        src_img = hdul[0].data

    plane = slice_ind * 8
    volume = np.cumsum(src_img, axis=axis) if cum_sum == 1 else src_img
    slc_img = np.take(volume, plane, axis=axis)

    # Apodize before placement. The original computed `apod_img` but then
    # passed the raw slice to placement, silently discarding the apodization
    # even though the results are saved as "_apod".
    apod_img = apodize(slc_img)
    out_img = placement(apod_img,dist_scale,roll)
    #WST_img = sine_MHD_psf(out_img,sig,AR)
    WST_img = rotate_image(out_img,rot) if rot != 0 else out_img

    # NOTE(review): StandardScaler.fit_transform standardizes each column of
    # the image independently, not the image as a whole -- confirm intended.
    inputData = StandardScaler().fit_transform(WST_img)

    return WST_torch(inputData,scattering).flatten()

In [None]:
from torch.multiprocessing import Pool
if __name__ == '__main__':
    # One task per combination, in MHD_process_dmock's positional order:
    # (file_name, cum_sum, axis, slice_ind, dist_scale, sig, AR, roll, rot).
    task_args = product(file_list,
                        (1,),         # cum_sum
                        range(0, 3),  # axis
                        range(0, 32), # slice_ind
                        (0,),         # dist_scale
                        (0,),         # sig
                        (1,),         # AR
                        (0,),         # roll
                        (0,))         # rot
    # The context manager tears the workers down even if a task raises;
    # starmap has already collected every result before the block exits.
    with Pool(30) as pool:
        WST_MHD_rinvar_cumsum = pool.starmap(MHD_process_dmock, task_args)

    # Inside the guard so a re-import (e.g. a spawned worker) never tries to
    # pickle a name it did not compute.
    with open('WST_MHD_rinvar_cumsum_apod.p', 'wb') as output_file:
        pickle.dump(WST_MHD_rinvar_cumsum, output_file)

In [None]:
# Scattering hyper-parameters; must match the Scattering2D instance used to
# produce the coefficient vectors.
J = 8
L = 8
m = 2

def RWST_from_WST_nof(scattering_coefficients,L,J,m):
    """Reduce one flat 2-D WST coefficient vector to fitted angular parameters (RWST).

    Parameters
    ----------
    scattering_coefficients : array_like, 1-D
        Layout assumed below: 1 order-0 value, then ``J*L`` order-1 values
        (scale-major, then angle), then order-2 values ordered by
        (j1, theta1, j2 > j1, theta2). Coefficients of order >= 1 are log2'd,
        so they must be positive.
    L : int
        Number of orientations.
    J : int
        Number of scales.
    m : int
        Maximum scattering order; unused (orders 1 and 2 are always fitted).

    Returns
    -------
    numpy.ndarray
        ``3*J`` order-1 parameters (all ``a``, then all ``b``, then all ``c``
        of ``a + b*cos(2*pi/L*(theta - c))``, one per scale), followed by
        ``5 * J*(J-1)/2`` order-2 parameters (each of a, b, c, d, e in turn,
        over the pairs j1 < j2 in row-major order).
    """
    coeffs = np.asarray(scattering_coefficients)
    n_first = L * J
    log_first = np.log2(coeffs[1:n_first + 1])

    # First-order coefficient (j1, theta1) is the parent of (J - j1 - 1) * L
    # second-order ones; subtracting its log2 gives log2(S2 / S1).
    rep_template = [(J - i // L - 1) * L for i in range(n_first)]
    log_second = np.log2(coeffs[n_first + 1:]) - np.repeat(log_first, rep_template, axis=0)

    def func(x, a, b, c):
        return b * np.cos(2*np.pi/L*(x-c)) + a

    def func_2(X, a, b, c, d, e):
        x,y = X
        return a + b * np.cos(2*np.pi/L*(x-y)) + c * np.cos(2*np.pi/L*(x-e)) + d * np.cos(2*np.pi/L*(y-e))

    order_1_fits = np.zeros([3, J])
    # NOTE(review): order-1 angles run 1..L here, while order-2 angle indices
    # below run 0..L-1, so fitted phases differ by one step -- confirm intended.
    xdata = np.linspace(1, L, L)
    for j in range(J):
        ydata = log_first[j*L:(j+1)*L]
        popt, _ = curve_fit(func, xdata, ydata,
                            bounds=([-np.inf, -np.inf, 0], [np.inf, np.inf, L-1]),
                            max_nfev=2000)
        order_1_fits[:, j] = popt

    # (j1, theta1, j2, theta2) label of every second-order coefficient.
    indx_coeff = np.asarray([(j1, t1, j2, t2)
                             for j1 in range(J)
                             for t1 in range(L)
                             for j2 in range(j1 + 1, J)
                             for t2 in range(L)])

    # np.full(..., np.nan) replaces `np.NaN`, which NumPy 2.0 removed.
    order_2_fits = np.full([5, J, J], np.nan)
    for j1 in range(J):
        for j2 in range(j1 + 1, J):
            pair = (indx_coeff[:, 0] == j1) & (indx_coeff[:, 2] == j2)
            x_data = [indx_coeff[pair, 1], indx_coeff[pair, 3]]
            popt, _ = curve_fit(func_2,
                                x_data,
                                log_second[pair],
                                bounds=([-np.inf, -np.inf, -np.inf, -np.inf, 0], [np.inf, np.inf, np.inf, np.inf, L-1]),
                                max_nfev=2000)
            order_2_fits[:, j1, j2] = popt

    # Keep the j1 < j2 entries by position (same order as filtering out NaNs),
    # so a fit that happened to return NaN cannot shorten the feature vector.
    rows, cols = np.triu_indices(J, k=1)
    out1 = order_1_fits.flatten()
    out2 = order_2_fits[:, rows, cols].flatten()
    return np.concatenate((out1, out2))

In [None]:
def RWST_from_WST_MHD_rinvar_cumsum(i):
    """Pool worker: RWST features for row ``i`` of the global WST matrix.

    Reads the module-level ``WST_MHD_rinvar_cumsum`` (indexed as a 2-D array)
    and reduces that row with L=8, J=8, m=2.
    """
    row = WST_MHD_rinvar_cumsum[i, :]
    return RWST_from_WST_nof(row, L=8, J=8, m=2)

In [None]:
# Convert the list returned by pool.starmap into an array so the RWST worker
# can index rows with WST_MHD_rinvar_cumsum[i, :].
WST_MHD_rinvar_cumsum = np.array(WST_MHD_rinvar_cumsum)

In [None]:
if __name__ == '__main__':
    # Context manager tears the workers down even if a task raises.
    with Pool(30) as pool:
        # One task per WST row; the count comes from the data rather than the
        # hard-coded 13824, which silently skips or overruns rows if the
        # number of files changes.
        RWST_MHD_rinvar_cumsum = pool.map(RWST_from_WST_MHD_rinvar_cumsum,
                                          range(len(WST_MHD_rinvar_cumsum)))

    # Inside the guard so it only runs where the result was computed.
    with open('RWST_MHD_rinvar_cumsum_apod.p', 'wb') as output_file:
        pickle.dump(RWST_MHD_rinvar_cumsum, output_file)

Simpler code: start from precomputed images in `MHD_2dcs.h5`

In [None]:
test_data = hd5_open("MHD_2dcs.h5","data")

# Split the stack along its last axis into a list of 2-D images, one per task.
# NOTE(review): the image count 6912 is hard-coded; confirm it equals
# test_data.shape[2].
image_list = [test_data[:, :, idx] for idx in range(6912)]

In [None]:
def MHD_process_dmock(src_img):
    """Apodize one 2-D image, standardize it, and return its flattened WST coefficients.

    This redefinition replaces the file-based MHD_process_dmock for the
    in-memory ``image_list`` pipeline.
    """
    # NOTE(review): StandardScaler.fit_transform standardizes each column of
    # the image independently, not the image as a whole -- confirm intended.
    standardized = StandardScaler().fit_transform(apodize(src_img))
    return WST_torch(standardized, scattering).flatten()

In [None]:
from torch.multiprocessing import Pool
if __name__ == '__main__':
    # The context manager tears the workers down even if a task raises;
    # map has already collected every result before the block exits.
    with Pool(30) as pool:
        WST_MHD_rinvar_cumsum = pool.map(MHD_process_dmock, image_list)

    # Inside the guard so a re-import (e.g. a spawned worker) never tries to
    # pickle a name it did not compute.
    with open('WST_MHD_rinvar_cumsum_apod.p', 'wb') as output_file:
        pickle.dump(WST_MHD_rinvar_cumsum, output_file)

In [None]:
# Scattering hyper-parameters; must match the Scattering2D instance used to
# produce the coefficient vectors.
J = 8
L = 8
m = 2

def RWST_from_WST_nof(scattering_coefficients,L,J,m):
    """Reduce one flat 2-D WST coefficient vector to fitted angular parameters (RWST).

    Parameters
    ----------
    scattering_coefficients : array_like, 1-D
        Layout assumed below: 1 order-0 value, then ``J*L`` order-1 values
        (scale-major, then angle), then order-2 values ordered by
        (j1, theta1, j2 > j1, theta2). Coefficients of order >= 1 are log2'd,
        so they must be positive.
    L : int
        Number of orientations.
    J : int
        Number of scales.
    m : int
        Maximum scattering order; unused (orders 1 and 2 are always fitted).

    Returns
    -------
    numpy.ndarray
        ``3*J`` order-1 parameters (all ``a``, then all ``b``, then all ``c``
        of ``a + b*cos(2*pi/L*(theta - c))``, one per scale), followed by
        ``5 * J*(J-1)/2`` order-2 parameters (each of a, b, c, d, e in turn,
        over the pairs j1 < j2 in row-major order).
    """
    coeffs = np.asarray(scattering_coefficients)
    n_first = L * J
    log_first = np.log2(coeffs[1:n_first + 1])

    # First-order coefficient (j1, theta1) is the parent of (J - j1 - 1) * L
    # second-order ones; subtracting its log2 gives log2(S2 / S1).
    rep_template = [(J - i // L - 1) * L for i in range(n_first)]
    log_second = np.log2(coeffs[n_first + 1:]) - np.repeat(log_first, rep_template, axis=0)

    def func(x, a, b, c):
        return b * np.cos(2*np.pi/L*(x-c)) + a

    def func_2(X, a, b, c, d, e):
        x,y = X
        return a + b * np.cos(2*np.pi/L*(x-y)) + c * np.cos(2*np.pi/L*(x-e)) + d * np.cos(2*np.pi/L*(y-e))

    order_1_fits = np.zeros([3, J])
    # NOTE(review): order-1 angles run 1..L here, while order-2 angle indices
    # below run 0..L-1, so fitted phases differ by one step -- confirm intended.
    xdata = np.linspace(1, L, L)
    for j in range(J):
        ydata = log_first[j*L:(j+1)*L]
        popt, _ = curve_fit(func, xdata, ydata,
                            bounds=([-np.inf, -np.inf, 0], [np.inf, np.inf, L-1]),
                            max_nfev=2000)
        order_1_fits[:, j] = popt

    # (j1, theta1, j2, theta2) label of every second-order coefficient.
    indx_coeff = np.asarray([(j1, t1, j2, t2)
                             for j1 in range(J)
                             for t1 in range(L)
                             for j2 in range(j1 + 1, J)
                             for t2 in range(L)])

    # np.full(..., np.nan) replaces `np.NaN`, which NumPy 2.0 removed.
    order_2_fits = np.full([5, J, J], np.nan)
    for j1 in range(J):
        for j2 in range(j1 + 1, J):
            pair = (indx_coeff[:, 0] == j1) & (indx_coeff[:, 2] == j2)
            x_data = [indx_coeff[pair, 1], indx_coeff[pair, 3]]
            popt, _ = curve_fit(func_2,
                                x_data,
                                log_second[pair],
                                bounds=([-np.inf, -np.inf, -np.inf, -np.inf, 0], [np.inf, np.inf, np.inf, np.inf, L-1]),
                                max_nfev=2000)
            order_2_fits[:, j1, j2] = popt

    # Keep the j1 < j2 entries by position (same order as filtering out NaNs),
    # so a fit that happened to return NaN cannot shorten the feature vector.
    rows, cols = np.triu_indices(J, k=1)
    out1 = order_1_fits.flatten()
    out2 = order_2_fits[:, rows, cols].flatten()
    return np.concatenate((out1, out2))

In [None]:
def RWST_from_WST_MHD_rinvar_cumsum(i):
    """Pool worker: RWST features for row ``i`` of the global WST matrix.

    Reads the module-level ``WST_MHD_rinvar_cumsum`` (indexed as a 2-D array)
    and reduces that row with L=8, J=8, m=2.
    """
    row = WST_MHD_rinvar_cumsum[i, :]
    return RWST_from_WST_nof(row, L=8, J=8, m=2)

# Convert the list returned by pool.map into an array so the worker above can
# index rows with [i, :].
WST_MHD_rinvar_cumsum = np.array(WST_MHD_rinvar_cumsum)

In [None]:
if __name__ == '__main__':
    # Context manager tears the workers down even if a task raises.
    with Pool(30) as pool:
        # One task per WST row; the count comes from the data rather than the
        # hard-coded 6912, so it cannot drift from the number of images.
        RWST_MHD_rinvar_cumsum = pool.map(RWST_from_WST_MHD_rinvar_cumsum,
                                          range(len(WST_MHD_rinvar_cumsum)))

    # Inside the guard so it only runs where the result was computed.
    with open('RWST_MHD_rinvar_cumsum_apod.p', 'wb') as output_file:
        pickle.dump(RWST_MHD_rinvar_cumsum, output_file)

And without apodization, for clarity

In [None]:
def MHD_process_dmock(src_img):
    """Standardize one 2-D image (no apodization) and return its flattened WST coefficients.

    This redefinition replaces the apodizing version for the no-apodization run.
    """
    # NOTE(review): StandardScaler.fit_transform standardizes each column of
    # the image independently, not the image as a whole -- confirm intended.
    standardized = StandardScaler().fit_transform(src_img)
    return WST_torch(standardized, scattering).flatten()

In [None]:
from torch.multiprocessing import Pool
if __name__ == '__main__':
    # The context manager tears the workers down even if a task raises;
    # map has already collected every result before the block exits.
    with Pool(30) as pool:
        WST_MHD_rinvar_cumsum = pool.map(MHD_process_dmock, image_list)

    # Inside the guard so a re-import (e.g. a spawned worker) never tries to
    # pickle a name it did not compute.
    with open('WST_MHD_rinvar_cumsum_noapod.p', 'wb') as output_file:
        pickle.dump(WST_MHD_rinvar_cumsum, output_file)

In [None]:
# Convert the list returned by pool.map (no-apodization run) into an array so
# the RWST worker can index rows with WST_MHD_rinvar_cumsum[i, :].
WST_MHD_rinvar_cumsum = np.array(WST_MHD_rinvar_cumsum)

In [None]:
if __name__ == '__main__':
    # Context manager tears the workers down even if a task raises.
    with Pool(30) as pool:
        # One task per WST row; the count comes from the data rather than the
        # hard-coded 6912, so it cannot drift from the number of images.
        RWST_MHD_rinvar_cumsum = pool.map(RWST_from_WST_MHD_rinvar_cumsum,
                                          range(len(WST_MHD_rinvar_cumsum)))

    # NOTE(review): "npapod" looks like a typo for "noapod" (the matching WST
    # file is WST_MHD_rinvar_cumsum_noapod.p); kept so existing readers of this
    # file still find it.
    with open('RWST_MHD_rinvar_cumsum_npapod.p', 'wb') as output_file:
        pickle.dump(RWST_MHD_rinvar_cumsum, output_file)