In [1]:
## Import Packages
from __future__ import print_function

import numpy as np
import pandas as pd
from itertools import product

#Astro Software
import astropy.units as units
from astropy.coordinates import SkyCoord
from astropy.io import fits

#Plotting Packages
import matplotlib as mpl
import matplotlib.cm as cm
import matplotlib.pyplot as plt
from matplotlib import rcParams

import seaborn as sns

from PIL import Image

from yt.config import ytcfg
import yt
from yt.analysis_modules.ppv_cube.api import PPVCube
import yt.units as u

#Scattering NN
import torch
import torch.nn.functional as F
from torch import optim
from kymatio.torch import Scattering2D
device = "cpu"

#Machine Learning
from sklearn.model_selection import train_test_split
from sklearn.mixture import GaussianMixture
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score
from sklearn.metrics import precision_score
from sklearn.decomposition import PCA, FastICA

import skimage
from skimage import filters

from scipy.optimize import curve_fit
from scipy import linalg
from scipy import stats
from scipy.signal import general_gaussian

#I/O
import h5py
import pickle
import glob
import copy
import time

#Plotting Style
%matplotlib inline
plt.style.use('dark_background')
rcParams['text.usetex'] = False
rcParams['axes.titlesize'] = 20
rcParams['xtick.labelsize'] = 16
rcParams['ytick.labelsize'] = 16
rcParams['legend.fontsize'] = 12
rcParams['axes.labelsize'] = 20
rcParams['font.family'] = 'sans-serif'

#Threading
# set_num_threads is a function: it has to be called. Assigning to it (the
# previous `torch.set_num_threads=32`) only overwrote the attribute with an
# int and left the intra-op thread count unchanged.
torch.set_num_threads(32)



In [2]:
import torch
import torch.nn as nn
import numpy as np
from tqdm.auto import tqdm
from torch.utils.data import DataLoader, Dataset, TensorDataset
import torchvision.datasets as ds

In [3]:
import cv2

In [4]:
# Load the MNIST train and test splits through torchvision, cached under
# ./data_cache (download=True is passed, so missing files are presumably
# fetched on first run). transform=None keeps the raw images; later cells read
# them through the `.data` attribute.
train_ds = ds.MNIST(root='./data_cache',train=True,download=True, transform=None)
test_ds = ds.MNIST(root='./data_cache',train=False,download=True, transform=None)

In [5]:
# PreCalc the WST Network
# J, L and m are forwarded to Scattering2D as J, L and max_order respectively.
J = 7
L = 8
m = 2
# Built once at module level and reused by every mnist_WST call; the
# (128,128) shape matches the size mnist_pad resizes each digit to.
scattering = Scattering2D(J=J, shape=(128,128), L=L, max_order=m)

In [6]:
def mnist_pad(im, theta=0):
    """Zero-pad a digit to 64x64, upsample it to 128x128 and optionally rotate it.

    Args:
        im: Image written into the central 28x28 window (rows/cols 18..45)
            of a 64x64 zero canvas; it must be broadcastable to that window.
        theta: Rotation angle handed to ``rotate_image``. A value equal to
            zero skips the rotation step entirely.

    Returns:
        The 128x128 array produced by the PIL resize, rotated when
        ``theta`` is non-zero.
    """
    canvas = np.zeros((64, 64))
    canvas[18:46, 18:46] = im
    upsampled = np.array(Image.fromarray(canvas).resize((128, 128)))
    # Guard clause: no warp (and no interpolation) for a zero angle.
    if theta == 0.0:
        return upsampled
    return rotate_image(upsampled, theta)

def rotate_image(image, angle):
    """Rotate ``image`` about its centre with OpenCV, keeping the original size.

    Args:
        image: 2-D (or HxWxC) array accepted by ``cv2.warpAffine``.
        angle: Rotation angle passed to ``cv2.getRotationMatrix2D``.

    Returns:
        The warped image, resampled with Lanczos interpolation.
    """
    # shape[1::-1] is (width, height), the order OpenCV is given for sizes here.
    size_wh = image.shape[1::-1]
    centre = tuple(np.array(size_wh) / 2)
    affine = cv2.getRotationMatrix2D(centre, angle, 1.0)
    return cv2.warpAffine(image, affine, size_wh, flags=cv2.INTER_LANCZOS4)
    
def mnist_WST(params):
    """Compute ``[S0, S1, *scattering coefficients]`` for one ``(theta, image)`` pair.

    Args:
        params: Two-item sequence ``(theta, x)``: the rotation angle and the
            digit image forwarded to ``mnist_pad``.

    Returns:
        1-D numpy array whose first two entries are the mean (S0) and the
        mean squared deviation (S1) of the padded image, followed by the
        flattened output of ``WST_torch`` on the normalised image.
    """
    theta, x = params
    padded = mnist_pad(x, theta=theta)
    print(theta)
    n_rows, n_cols = padded.shape
    n_pix = n_rows*n_cols
    S0 = np.mean(padded)
    centred = padded - S0
    S1 = np.square(centred).sum()/n_pix
    # n_pix*S1 is the sum of squares, so this scales the image to unit L2 norm.
    centred /= np.sqrt(n_pix*S1)
    coeffs = WST_torch(centred, scattering).flatten()
    return np.append([S0, S1], coeffs)
    
def WST_torch(src_img,scattering):
    """Apply ``scattering`` to a numpy image and return its raw output.

    Args:
        src_img: numpy array; it is cast to float32 before conversion.
        scattering: Callable applied to the resulting torch tensor.

    Returns:
        Whatever ``scattering`` returns for the contiguous float32 tensor
        placed on the module-level ``device``.
    """
    as_float32 = src_img.astype(np.float32)
    tensor = torch.from_numpy(as_float32).to(device).contiguous()
    return scattering(tensor)

In [7]:
# Split each image stack into a list of per-image 2-D arrays so the images can
# be handed to Pool.map / Pool.starmap one at a time. Iterating the array
# yields one image per step, so no dataset length is hard-coded (the previous
# range(60000) / range(10000) loops would truncate or raise on any other size).
train_temp = train_ds.data.detach().cpu().numpy()
lst = list(train_temp)

test_temp = test_ds.data.detach().cpu().numpy()
lst_test = list(test_temp)

In [8]:
# Time one transform of a single training image (index 2) at angle 180/3,
# using the tuple-argument form of mnist_WST.
%timeit mnist_WST((180/3,lst[2]))

60.0
60.0
60.0
60.0
60.0
60.0
60.0
60.0
1.71 s ± 9.59 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [8]:
import multiprocessing
from torch.multiprocessing import Pool
# Terminate any child processes of this kernel that are still alive
# (presumably left over from an earlier pool run) before starting new pools.
for p in multiprocessing.active_children():
    p.terminate()

In [9]:
# Number of rotation angles in the full grid.
M = 100
# M evenly spaced angles from 2*180/M up to and including 360.
angle_array = list(np.linspace(2*180/M, 360, M))
# Training angles: the smallest grid step followed by multiples of 180/3.
train_angles = [2*180/M] + [k*180/3 for k in range(1, 6)]

In [11]:
def mnist_WST(x):
    """Compute ``[S0, S1, *scattering coefficients]`` for one digit at angle ``2*180/M``.

    Args:
        x: Digit image forwarded to ``mnist_pad``; its ``[14,14]`` pixel is
            printed as a progress marker.

    Returns:
        1-D numpy array: mean (S0) and mean squared deviation (S1) of the
        padded image, followed by the flattened ``WST_torch`` output of the
        normalised image.
    """
    padded = mnist_pad(x, theta=2*180/M)
    print(x[14,14])
    n_rows, n_cols = padded.shape
    n_pix = n_rows*n_cols
    S0 = np.mean(padded)
    centred = padded - S0
    S1 = np.square(centred).sum()/n_pix
    # n_pix*S1 is the sum of squares, so this scales the image to unit L2 norm.
    centred /= np.sqrt(n_pix*S1)
    coeffs = WST_torch(centred, scattering).flatten()
    return np.append([S0, S1], coeffs)
    
if __name__ == '__main__':
    # Map mnist_WST over every training image with 31 worker processes. The
    # context manager releases the workers even when map raises or the cell
    # is interrupted.
    with Pool(31) as pool:
        WST_MNIST_train_0 = pool.map(mnist_WST, lst)

    # The write lives inside the guard because it needs the result computed
    # above; `with` closes the HDF5 file even if create_dataset fails.
    with h5py.File('WST_MNIST_train_0.h5', 'w') as hf:
        hf.create_dataset('WST_MNIST_train_0', data=WST_MNIST_train_0)

0
240
254
254
252
255
0
0
128
0
0
253
0
0
0
121
53
251
0
254
185
0
29
22
254
0
253
0
155
0
0
0
254
0
0
172
0
79
240
0
254
147
253
252
156
244
0
197
253
170
10
198
0
2
190
253
0
218
254
177
240
247
254
0
253
0
249
0
254
0
220
2
192
0
161
0
237
197
210
30
7
253
253
254
168
241
0
255
0
239
234
118
253
0
253
0
254
208
229
107
197
255
254
253
0
0
0
186
0
254
0
0
238
79
0
251
253
253
0
246
0
0
150
254
0
254
0
254
253
253
180
140
254
0
175
253
121
0
252
18
253
254
253
0
0
0
253
253
253
57
0
66
253
0
254
253
0
248
175
0
254
0
253
0
0
0
157
228
255
0
57
44
32
84
252
0
84
0
254
249
253
0
0
58
0
229
254
250
253
45
0
0
0
0
255
0
0
101
194
253
0
0
253
254
206
255
254
95
254
253
200
63
252
253
253
70
253
252
254
254
151
0
178
20
0
253
180
0
255
128
0
0
255
0
255
226
254
40
254
246
254
0
0
253
254
253
0
0
254
0
174
200
0
111
255
176
40
112
255
32
252
0
253
253
63
0
5
0
0
253
0
16
60
0
254
0
105
203
253
166
62
253
204
252
61
253
181
252
240
139
253
0
0
255
0
254
156
226
0
0
0
0
253
253
224
251
7
253
2

In [None]:
def mnist_WST(x):
    """Compute ``[S0, S1, *scattering coefficients]`` for one digit at angle ``180/3``.

    Args:
        x: Digit image forwarded to ``mnist_pad``; its ``[14,14]`` pixel is
            printed as a progress marker.

    Returns:
        1-D numpy array: mean (S0) and mean squared deviation (S1) of the
        padded image, followed by the flattened ``WST_torch`` output of the
        normalised image.
    """
    padded = mnist_pad(x, theta=180/3)
    print(x[14,14])
    n_rows, n_cols = padded.shape
    n_pix = n_rows*n_cols
    S0 = np.mean(padded)
    centred = padded - S0
    S1 = np.square(centred).sum()/n_pix
    # n_pix*S1 is the sum of squares, so this scales the image to unit L2 norm.
    centred /= np.sqrt(n_pix*S1)
    coeffs = WST_torch(centred, scattering).flatten()
    return np.append([S0, S1], coeffs)
    
if __name__ == '__main__':
    # Map mnist_WST over every training image with 31 worker processes. The
    # context manager releases the workers even when map raises or the cell
    # is interrupted.
    with Pool(31) as pool:
        WST_MNIST_train = pool.map(mnist_WST, lst)

    # The write lives inside the guard because it needs the result computed
    # above; `with` closes the HDF5 file even if create_dataset fails.
    with h5py.File('WST_MNIST_train_pi_3.h5', 'w') as hf:
        hf.create_dataset('WST_MNIST_train_pi_3', data=WST_MNIST_train)

240
252
254
254
255
0
0
0
251
0
0
253
0
0
0
128
121
53
185
29
254
0
0
0
22
253
254
155
0
0
0
0
247
253
240
172
156
254
170
79
10
252
244
0
254
0
0
190
0
147
197
254
0
253
0
253
2
198
218
0
240
177
7
220
253
0
254
161
254
0
0
0
0
210
253
253
2
197
118
192
239
254
237
249
30
255
168
0
253
234
0
241
254
0
238
0
229
254
253
107
255
0
0
186
197
0
0
208
0
253
0
254
254
0
0
0
253
0
0
79
150
246
251
253
254
254
253
253
254
180
0
253
140
252
0
0
0
253
254
0
0
0
175
0
18
253
253
121
253
0
253
254
66
253
254
0
0
255
0
248
57
44
253
57
0
0
253
0
0
252
84
228
32
0
175
229
84
157
0
254
58
0
0
249
0
0
254
194
206
250
0
253
45
101
200
255
0
0
0
95
253
253
253
0
255
253
0
254
0
63
253
178
254
70
254
128
151
254
253
40
253
0
254
0
0
255
253
20
0
252
0
226
255
252
254
0
0
255
0
180
254
0
254
111
255
246
32
0
0
0
253
0
254
174
112
0
0
200
176
252
255
253
5
254
253
253
181
253
0
63
40
60
16
0
252
252
0
166
240
62
0
0
253
105
203
253
0
204
0
255
253
156
254
0
139
226
61
254
0
253
253
224
0
253
0
0
201
253
0

In [None]:
def mnist_WST(x):
    """Compute ``[S0, S1, *scattering coefficients]`` for one digit at angle ``2*180/3``.

    Args:
        x: Digit image forwarded to ``mnist_pad``; its ``[14,14]`` pixel is
            printed as a progress marker.

    Returns:
        1-D numpy array: mean (S0) and mean squared deviation (S1) of the
        padded image, followed by the flattened ``WST_torch`` output of the
        normalised image.
    """
    padded = mnist_pad(x, theta=2*180/3)
    print(x[14,14])
    n_rows, n_cols = padded.shape
    n_pix = n_rows*n_cols
    S0 = np.mean(padded)
    centred = padded - S0
    S1 = np.square(centred).sum()/n_pix
    # n_pix*S1 is the sum of squares, so this scales the image to unit L2 norm.
    centred /= np.sqrt(n_pix*S1)
    coeffs = WST_torch(centred, scattering).flatten()
    return np.append([S0, S1], coeffs)
    
if __name__ == '__main__':
    # Map mnist_WST over every training image with 31 worker processes. The
    # context manager releases the workers even when map raises or the cell
    # is interrupted.
    with Pool(31) as pool:
        WST_MNIST_train = pool.map(mnist_WST, lst)

    # The write lives inside the guard because it needs the result computed
    # above; `with` closes the HDF5 file even if create_dataset fails.
    with h5py.File('WST_MNIST_train_2pi_3.h5', 'w') as hf:
        hf.create_dataset('WST_MNIST_train_2pi_3', data=WST_MNIST_train)

In [None]:
def mnist_WST(x):
    """Compute ``[S0, S1, *scattering coefficients]`` for one digit at angle ``3*180/3``.

    Args:
        x: Digit image forwarded to ``mnist_pad``; its ``[14,14]`` pixel is
            printed as a progress marker.

    Returns:
        1-D numpy array: mean (S0) and mean squared deviation (S1) of the
        padded image, followed by the flattened ``WST_torch`` output of the
        normalised image.
    """
    padded = mnist_pad(x, theta=3*180/3)
    print(x[14,14])
    n_rows, n_cols = padded.shape
    n_pix = n_rows*n_cols
    S0 = np.mean(padded)
    centred = padded - S0
    S1 = np.square(centred).sum()/n_pix
    # n_pix*S1 is the sum of squares, so this scales the image to unit L2 norm.
    centred /= np.sqrt(n_pix*S1)
    coeffs = WST_torch(centred, scattering).flatten()
    return np.append([S0, S1], coeffs)
    
if __name__ == '__main__':
    # Map mnist_WST over every training image with 31 worker processes. The
    # context manager releases the workers even when map raises or the cell
    # is interrupted.
    with Pool(31) as pool:
        WST_MNIST_train = pool.map(mnist_WST, lst)

    # The write lives inside the guard because it needs the result computed
    # above; `with` closes the HDF5 file even if create_dataset fails.
    with h5py.File('WST_MNIST_train_3pi_3.h5', 'w') as hf:
        hf.create_dataset('WST_MNIST_train_3pi_3', data=WST_MNIST_train)

In [None]:
def mnist_WST(x):
    """Compute ``[S0, S1, *scattering coefficients]`` for one digit at angle ``4*180/3``.

    Args:
        x: Digit image forwarded to ``mnist_pad``; its ``[14,14]`` pixel is
            printed as a progress marker.

    Returns:
        1-D numpy array: mean (S0) and mean squared deviation (S1) of the
        padded image, followed by the flattened ``WST_torch`` output of the
        normalised image.
    """
    padded = mnist_pad(x, theta=4*180/3)
    print(x[14,14])
    n_rows, n_cols = padded.shape
    n_pix = n_rows*n_cols
    S0 = np.mean(padded)
    centred = padded - S0
    S1 = np.square(centred).sum()/n_pix
    # n_pix*S1 is the sum of squares, so this scales the image to unit L2 norm.
    centred /= np.sqrt(n_pix*S1)
    coeffs = WST_torch(centred, scattering).flatten()
    return np.append([S0, S1], coeffs)
    
if __name__ == '__main__':
    # Map mnist_WST over every training image with 31 worker processes. The
    # context manager releases the workers even when map raises or the cell
    # is interrupted.
    with Pool(31) as pool:
        WST_MNIST_train = pool.map(mnist_WST, lst)

    # The write lives inside the guard because it needs the result computed
    # above; `with` closes the HDF5 file even if create_dataset fails.
    with h5py.File('WST_MNIST_train_4pi_3.h5', 'w') as hf:
        hf.create_dataset('WST_MNIST_train_4pi_3', data=WST_MNIST_train)

In [None]:
def mnist_WST(x):
    """Compute ``[S0, S1, *scattering coefficients]`` for one digit at angle ``5*180/3``.

    Args:
        x: Digit image forwarded to ``mnist_pad``; its ``[14,14]`` pixel is
            printed as a progress marker.

    Returns:
        1-D numpy array: mean (S0) and mean squared deviation (S1) of the
        padded image, followed by the flattened ``WST_torch`` output of the
        normalised image.
    """
    padded = mnist_pad(x, theta=5*180/3)
    print(x[14,14])
    n_rows, n_cols = padded.shape
    n_pix = n_rows*n_cols
    S0 = np.mean(padded)
    centred = padded - S0
    S1 = np.square(centred).sum()/n_pix
    # n_pix*S1 is the sum of squares, so this scales the image to unit L2 norm.
    centred /= np.sqrt(n_pix*S1)
    coeffs = WST_torch(centred, scattering).flatten()
    return np.append([S0, S1], coeffs)
    
if __name__ == '__main__':
    # Map mnist_WST over every training image with 31 worker processes. The
    # context manager releases the workers even when map raises or the cell
    # is interrupted.
    with Pool(31) as pool:
        WST_MNIST_train = pool.map(mnist_WST, lst)

    # The write lives inside the guard because it needs the result computed
    # above; `with` closes the HDF5 file even if create_dataset fails.
    with h5py.File('WST_MNIST_train_5pi_3.h5', 'w') as hf:
        hf.create_dataset('WST_MNIST_train_5pi_3', data=WST_MNIST_train)

In [11]:
def mnist_WST(theta, x):
    """Compute ``[S0, S1, *scattering coefficients]`` for one digit at angle ``theta``.

    The two-argument form is the one used with ``Pool.starmap`` over
    ``(theta, image)`` pairs.

    Args:
        theta: Rotation angle forwarded to ``mnist_pad``; it is printed as a
            progress marker.
        x: Digit image forwarded to ``mnist_pad``.

    Returns:
        1-D numpy array: mean (S0) and mean squared deviation (S1) of the
        padded image, followed by the flattened ``WST_torch`` output of the
        normalised image.
    """
    padded = mnist_pad(x, theta=theta)
    print(theta)
    n_rows, n_cols = padded.shape
    n_pix = n_rows*n_cols
    S0 = np.mean(padded)
    centred = padded - S0
    S1 = np.square(centred).sum()/n_pix
    # n_pix*S1 is the sum of squares, so this scales the image to unit L2 norm.
    centred /= np.sqrt(n_pix*S1)
    coeffs = WST_torch(centred, scattering).flatten()
    return np.append([S0, S1], coeffs)

In [None]:
if __name__ == '__main__':
    # First quarter of the angle grid (indices 0..24) crossed with every test
    # image. The context manager releases the 31 workers even when starmap
    # raises or the cell is interrupted.
    with Pool(31) as pool:
        WST_MNIST_test = pool.starmap(mnist_WST, product(angle_array[0:25], lst_test))

    # The write lives inside the guard because it needs the result computed
    # above; `with` closes the HDF5 file even if create_dataset fails.
    with h5py.File('WST_MNIST_test_0.h5', 'w') as hf:
        hf.create_dataset('WST_MNIST_test', data=WST_MNIST_test)

3.6
3.6
3.6
3.6
3.6
7.199999999999999
7.199999999999999
7.199999999999999
7.199999999999999
7.199999999999999
10.799999999999999
10.799999999999999
10.799999999999999
10.799999999999999
10.799999999999999
14.399999999999999
14.399999999999999
14.399999999999999
14.399999999999999
14.399999999999999
18.0
18.0
18.0
18.0
18.0
21.6
21.6
21.6
21.6
21.6
25.2
21.6
21.6
3.6
21.6
21.6
21.6
25.2
14.399999999999999
18.0
14.399999999999999
18.0
18.0
18.0
14.399999999999999
7.199999999999999
14.399999999999999
14.399999999999999
10.799999999999999
10.799999999999999
10.799999999999999
7.199999999999999
7.199999999999999
7.199999999999999
10.799999999999999
18.0
10.799999999999999
3.6
7.199999999999999
3.6
3.6
3.6
21.6
21.6
21.6
14.399999999999999
7.199999999999999
14.399999999999999
10.799999999999999
18.0
14.399999999999999
14.399999999999999
18.0
10.799999999999999
14.399999999999999
21.6
18.0
18.0
21.6
18.0
10.799999999999999
7.199999999999999
25.2
10.799999999999999
3.6
7.199999999999999
7.1999

In [None]:
if __name__ == '__main__':
    # Second quarter of the angle grid (indices 25..49) crossed with every
    # test image. The context manager releases the 31 workers even when
    # starmap raises or the cell is interrupted.
    with Pool(31) as pool:
        WST_MNIST_test = pool.starmap(mnist_WST, product(angle_array[25:50], lst_test))

    # The write lives inside the guard because it needs the result computed
    # above; `with` closes the HDF5 file even if create_dataset fails.
    with h5py.File('WST_MNIST_test_1.h5', 'w') as hf:
        hf.create_dataset('WST_MNIST_test', data=WST_MNIST_test)

In [None]:
if __name__ == '__main__':
    # Third quarter of the angle grid (indices 50..74) crossed with every
    # test image. The context manager releases the 31 workers even when
    # starmap raises or the cell is interrupted.
    with Pool(31) as pool:
        WST_MNIST_test = pool.starmap(mnist_WST, product(angle_array[50:75], lst_test))

    # The write lives inside the guard because it needs the result computed
    # above; `with` closes the HDF5 file even if create_dataset fails.
    with h5py.File('WST_MNIST_test_2.h5', 'w') as hf:
        hf.create_dataset('WST_MNIST_test', data=WST_MNIST_test)

In [None]:
if __name__ == '__main__':
    # Remaining angles (index 75 onward) crossed with every test image. The
    # context manager releases the 31 workers even when starmap raises or the
    # cell is interrupted.
    with Pool(31) as pool:
        WST_MNIST_test = pool.starmap(mnist_WST, product(angle_array[75:], lst_test))

    # The write lives inside the guard because it needs the result computed
    # above; `with` closes the HDF5 file even if create_dataset fails.
    with h5py.File('WST_MNIST_test_3.h5', 'w') as hf:
        hf.create_dataset('WST_MNIST_test', data=WST_MNIST_test)

Failed attempts at adding a progress bar (pbar) to the parallel WST computation

In [16]:
import parmap

In [54]:
def my_function(x,y):
    """Toy two-argument function for the parmap test: return ``x + y``."""
    total = x + y
    return total

In [55]:
# Small toy inputs for trying out parmap.starmap and itertools.product.
list1 = [1 , 2, 3]
list2 = [10, 11, 12]

In [56]:
# Sanity check of parmap.starmap with its progress bar enabled (pm_pbar=True)
# on the three zipped (list1, list2) argument pairs.
out = parmap.starmap(my_function,list(zip(list1,list2)),pm_pbar=True)

100%|██████████| 3/3 [00:00<00:00, 6879.67it/s]


In [45]:
# All (list1 element, list2 element) pairs, materialised as a list of 2-tuples.
test = list(product(list1, list2))

In [69]:
# Attempt to run the full angle x test-image grid through parmap with a
# progress bar on 4 processes. The recorded output below shows the bar still
# at 0% after about four minutes, followed by a KeyboardInterrupt.
if __name__ == '__main__':
    WST_MNIST_test = parmap.starmap(mnist_WST,list(product(angle_array,lst_test)),pm_pbar=True,pm_processes=4)

  0%|          | 0/1000000 [04:12<?, ?it/s]Process ForkPoolWorker-1837:
Process ForkPoolWorker-1834:
Process ForkPoolWorker-1835:
Process ForkPoolWorker-1836:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/n/home12/saydjari/.conda/envs/RWST/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/n/home12/saydjari/.conda/envs/RWST/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/n/home12/saydjari/.conda/envs/RWST/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/n/home12/saydjari/.conda/envs/RWST/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/n/home12/saydjari/.conda/envs/RWST/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/n/home12/saydjari/.conda/envs/RWST/lib/

KeyboardInterrupt: 

  File "/n/home12/saydjari/.conda/envs/RWST/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/n/home12/saydjari/.conda/envs/RWST/lib/python3.6/multiprocessing/pool.py", line 119, in worker
    result = (True, func(*args, **kwds))
  File "/n/home12/saydjari/.conda/envs/RWST/lib/python3.6/multiprocessing/pool.py", line 44, in mapstar
    return list(map(*args))
  File "/n/home12/saydjari/.conda/envs/RWST/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/n/home12/saydjari/.conda/envs/RWST/lib/python3.6/site-packages/parmap/parmap.py", line 117, in _func_star_many
    **func_items_args[3])
  File "/n/home12/saydjari/.conda/envs/RWST/lib/python3.6/multiprocessing/pool.py", line 119, in worker
    result = (True, func(*args, **kwds))
  File "<ipython-input-68-7efd87553f7e>", line 24, in mnist_WST
    WST = WST_torch(norm_im,scattering).flatten()
  File "/n/home12/sa

In [62]:
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed

def parallel_process(array, function, n_jobs=16, use_kwargs=False, front_num=3):
    """
        A parallel version of the map function with a progress bar.

        Args:
            array (array-like): A sliceable sequence to iterate over.
            function (function): A python function to apply to the elements of array
            n_jobs (int, default=16): The number of cores to use
            use_kwargs (boolean, default=False): Whether to consider the elements of array as dictionaries of
                keyword arguments to function
            front_num (int, default=3): The number of iterations to run serially before kicking off the parallel job.
                Useful for catching bugs. Values <= 0 disable the serial warm-up.
        Returns:
            [function(array[0]), function(array[1]), ...]
            In the parallel path, an element whose call raised is replaced by
            the exception instance instead of aborting the whole run.
    """
    def _call(item):
        # One place for the positional-vs-keyword calling convention.
        return function(**item) if use_kwargs else function(item)

    # Clamp so a zero/negative front_num means "no warm-up" rather than leaving
    # `front` undefined (NameError) or slicing from the end of the array.
    front_num = max(front_num, 0)
    #We run the first few iterations serially to catch bugs
    front = [_call(a) for a in array[:front_num]]
    #If we set n_jobs to 1, just run a list comprehension. This is useful for benchmarking and debugging.
    if n_jobs==1:
        return front + [_call(a) for a in tqdm(array[front_num:])]
    #Assemble the workers
    with ProcessPoolExecutor(max_workers=n_jobs) as pool:
        #Pass the elements of array into function
        if use_kwargs:
            futures = [pool.submit(function, **a) for a in array[front_num:]]
        else:
            futures = [pool.submit(function, a) for a in array[front_num:]]
        kwargs = {
            'total': len(futures),
            'unit': 'it',
            'unit_scale': True,
            'leave': True
        }
        #Print out the progress as tasks complete
        for _ in tqdm(as_completed(futures), **kwargs):
            pass
    out = []
    #Get the results from the futures, in submission order.
    for future in tqdm(futures):
        try:
            out.append(future.result())
        except Exception as e:
            # Deliberate best-effort: keep the failure in the output slot.
            out.append(e)
    return front + out

In [66]:
# Attempt to run the full angle x test-image grid through parallel_process on
# 4 workers; the return value is discarded. parallel_process passes each
# (theta, image) tuple as a single positional argument, so this needs the
# tuple-argument form of mnist_WST.
# NOTE(review): the two-parameter mnist_WST(theta, x) would raise TypeError
# here - confirm which definition was live when this cell ran.
if __name__ == '__main__':
    parallel_process(list(product(angle_array,lst_test)),mnist_WST,n_jobs=4)

Process Process-1827:
Process Process-1829:
Traceback (most recent call last):
Traceback (most recent call last):
  File "/n/home12/saydjari/.conda/envs/RWST/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/n/home12/saydjari/.conda/envs/RWST/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.

  File "/n/home12/saydjari/.conda/envs/RWST/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/n/home12/saydjari/.conda/envs/RWST/lib/python3.6/concurrent/futures/process.py", line 181, in _process_worker
    result=r))
  File "/n/home12/saydjari/.conda/envs/RWST/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/n/home12/saydjari/.conda/envs/RWST/lib/python3.6/multiprocessing/queues.py", line 347, 