In [1]:
## Import Packages
from __future__ import print_function

import numpy as np
import pandas as pd
from itertools import product

#Astro Software
import astropy.units as units
from astropy.coordinates import SkyCoord
from astropy.io import fits

#Plotting Packages
import matplotlib as mpl
import matplotlib.cm as cm
import matplotlib.pyplot as plt
from matplotlib import rcParams

import seaborn as sns

from PIL import Image

from yt.config import ytcfg
import yt
from yt.analysis_modules.ppv_cube.api import PPVCube
import yt.units as u

#Scattering NN
import torch
import torch.nn.functional as F
from torch import optim
from kymatio.torch import Scattering2D
device = "cpu"

#Machine Learning
from sklearn.model_selection import train_test_split
from sklearn.mixture import GaussianMixture
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score
from sklearn.metrics import precision_score
from sklearn.decomposition import PCA, FastICA

import skimage
from skimage import filters

from scipy.optimize import curve_fit
from scipy import linalg
from scipy import stats
from scipy.signal import general_gaussian

#I/O
import h5py
import pickle
import glob
import copy
import time

#Plotting Style
%matplotlib inline
plt.style.use('dark_background')
rcParams['text.usetex'] = False
rcParams['axes.titlesize'] = 20
rcParams['xtick.labelsize'] = 16
rcParams['ytick.labelsize'] = 16
rcParams['legend.fontsize'] = 12
rcParams['axes.labelsize'] = 20
rcParams['font.family'] = 'sans-serif'

#Threading
torch.set_num_threads=32



In [2]:
import torch
import torch.nn as nn
import numpy as np
from tqdm.auto import tqdm
from torch.utils.data import DataLoader, Dataset, TensorDataset
import torchvision.datasets as ds

In [3]:
import cv2

In [4]:
train_ds = ds.MNIST(root='./data_cache',train=True,download=True, transform=None)
test_ds = ds.MNIST(root='./data_cache',train=False,download=True, transform=None)

In [5]:
# PreCalc the WST Network
J = 7
L = 8
m = 2
scattering = Scattering2D(J=J, shape=(128,128), L=L, max_order=m)

In [6]:
lst = []
train_temp = train_ds.data.detach().cpu().numpy()
for i in range(60000):
    lst.append(train_temp[i,:,:])

lst_test = []
test_temp = test_ds.data.detach().cpu().numpy()
for i in range(10000):
    lst_test.append(test_temp[i,:,:])

In [7]:
import multiprocessing
from torch.multiprocessing import Pool
for p in multiprocessing.active_children():
    p.terminate()

In [8]:
M = 100
angle_array = [i for i in np.linspace(2*180/M,360,M)]
train_angles = [2*180/M,180/3,2*180/3,3*180/3,4*180/3,5*180/3];

In [9]:
def mnist_WST(theta, x):
    image    = mnist_pad(x, theta=theta)
    print(theta)
    Nx, Ny = image.shape
    S0 = np.mean(image)
    norm_im = image - S0
    S1   = np.sum(np.square(norm_im))/(Nx*Ny)
    norm_im /= np.sqrt(Nx*Ny*S1)
    WST = WST_torch(norm_im,scattering).flatten()
    return np.append([S0,S1],WST)

def mnist_pad(im, theta=0):
    impad = np.zeros((64,64))
    impad[18:46,18:46] = im
    imbig = np.array(Image.fromarray(impad).resize((128,128)))
    if theta != 0.0:
        imrot = rotate_image(imbig, theta)
    else:
        imrot = imbig
    return imrot

def rotate_image(image, angle):
    image_center = tuple(np.array(image.shape[1::-1]) / 2)
    rot_mat = cv2.getRotationMatrix2D(image_center, angle, 1.0)
    result_real = cv2.warpAffine(image, rot_mat, image.shape[1::-1], flags=cv2.INTER_LANCZOS4)
    return result_real
    
def WST_torch(src_img,scattering):
    src_img = src_img.astype(np.float32)
    src_img_tensor = torch.from_numpy(src_img).to(device).contiguous()
    scattering_coefficients = scattering(src_img_tensor)
    return scattering_coefficients

In [None]:
if __name__ == '__main__':
    pool = Pool(31)
    WST_MNIST_test = pool.starmap(mnist_WST,product(angle_array[0:10],lst_test))
    pool.close()
    pool.join()
    
hf = h5py.File('WST_MNIST_test_0.h5', 'w')
hf.create_dataset('WST_MNIST_test', data=WST_MNIST_test)
hf.close()

3.6
3.6
3.6
3.6
3.6
3.6
3.6
3.6
3.6
3.6
3.6
3.6
3.6
7.199999999999999
7.199999999999999
7.199999999999999
7.199999999999999
7.199999999999999
7.199999999999999
7.199999999999999
7.199999999999999
7.199999999999999
7.199999999999999
7.199999999999999
10.799999999999999
7.199999999999999
10.799999999999999
10.799999999999999
10.799999999999999
10.799999999999999
10.799999999999999
7.199999999999999
10.799999999999999
7.199999999999999
7.199999999999999
10.799999999999999
10.799999999999999
3.6
3.6
3.6
10.799999999999999
7.199999999999999
10.799999999999999
3.6
7.199999999999999
3.6
3.6
10.799999999999999
7.199999999999999
7.199999999999999
7.199999999999999
7.199999999999999
3.6
3.6
3.6
7.199999999999999
3.6
3.6
7.199999999999999
3.6
3.6
7.199999999999999
7.199999999999999
10.799999999999999
10.799999999999999
10.799999999999999
10.799999999999999
3.6
10.799999999999999
7.199999999999999
7.199999999999999
7.199999999999999
3.6
10.799999999999999
7.199999999999999
7.199999999999999
3.6
3.

In [None]:
if __name__ == '__main__':
    pool = Pool(31)
    WST_MNIST_test = pool.starmap(mnist_WST,product(angle_array[10:20],lst_test))
    pool.close()
    pool.join()
    
hf = h5py.File('WST_MNIST_test_1.h5', 'w')
hf.create_dataset('WST_MNIST_test', data=WST_MNIST_test)
hf.close()

In [10]:
if __name__ == '__main__':
    pool = Pool(31)
    WST_MNIST_test = pool.starmap(mnist_WST,product(angle_array[20:30],lst_test))
    pool.close()
    pool.join()
    
hf = h5py.File('WST_MNIST_test_2.h5', 'w')
hf.create_dataset('WST_MNIST_test', data=WST_MNIST_test)
hf.close()

82.79999999999998
82.79999999999998
79.19999999999999
79.19999999999999
75.6
79.19999999999999
79.19999999999999
75.6
79.19999999999999
82.79999999999998
75.6
75.6
79.19999999999999
75.6
79.19999999999999
75.6
82.79999999999998
79.19999999999999
75.6
79.19999999999999
75.6
75.6
75.6
79.19999999999999
79.19999999999999
75.6
75.6
82.79999999999998
82.79999999999998
79.19999999999999
75.6
82.79999999999998
82.79999999999998
79.19999999999999
82.79999999999998
82.79999999999998
79.19999999999999
79.19999999999999
79.19999999999999
82.79999999999998
79.19999999999999
75.6
79.19999999999999
79.19999999999999
79.19999999999999
82.79999999999998
75.6
79.19999999999999
75.6
75.6
75.6
75.6
75.6
79.19999999999999
75.6
75.6
75.6
79.19999999999999
75.6
75.6
79.19999999999999
75.6
79.19999999999999
82.79999999999998
82.79999999999998
82.79999999999998
82.79999999999998
82.79999999999998
79.19999999999999
82.79999999999998
79.19999999999999
79.19999999999999
79.19999999999999
79.19999999999999
75.6
7

In [None]:
if __name__ == '__main__':
    pool = Pool(31)
    WST_MNIST_test = pool.starmap(mnist_WST,product(angle_array[30:40],lst_test))
    pool.close()
    pool.join()
    
hf = h5py.File('WST_MNIST_test_3.h5', 'w')
hf.create_dataset('WST_MNIST_test', data=WST_MNIST_test)
hf.close()

111.59999999999998
111.59999999999998
111.59999999999998
111.59999999999998
111.59999999999998
111.59999999999998
111.59999999999998
111.59999999999998
111.59999999999998
111.59999999999998
111.59999999999998
111.59999999999998
111.59999999999998
115.19999999999999
115.19999999999999
115.19999999999999
115.19999999999999
115.19999999999999
115.19999999999999
115.19999999999999
115.19999999999999
115.19999999999999
115.19999999999999
115.19999999999999
115.19999999999999
118.79999999999998
118.79999999999998
118.79999999999998
118.79999999999998
118.79999999999998
118.79999999999998
111.59999999999998
111.59999999999998
115.19999999999999
115.19999999999999
111.59999999999998
115.19999999999999
111.59999999999998
118.79999999999998
115.19999999999999
115.19999999999999
111.59999999999998
118.79999999999998
115.19999999999999
115.19999999999999
115.19999999999999
111.59999999999998
115.19999999999999
115.19999999999999
111.59999999999998
111.59999999999998
118.79999999999998
111.59999999

In [None]:
if __name__ == '__main__':
    pool = Pool(31)
    WST_MNIST_test = pool.starmap(mnist_WST,product(angle_array[40:50],lst_test))
    pool.close()
    pool.join()
    
hf = h5py.File('WST_MNIST_test_4.h5', 'w')
hf.create_dataset('WST_MNIST_test', data=WST_MNIST_test)
hf.close()

In [None]:
if __name__ == '__main__':
    pool = Pool(31)
    WST_MNIST_test = pool.starmap(mnist_WST,product(angle_array[50:60],lst_test))
    pool.close()
    pool.join()
    
hf = h5py.File('WST_MNIST_test_5.h5', 'w')
hf.create_dataset('WST_MNIST_test', data=WST_MNIST_test)
hf.close()

In [None]:
if __name__ == '__main__':
    pool = Pool(31)
    WST_MNIST_test = pool.starmap(mnist_WST,product(angle_array[60:70],lst_test))
    pool.close()
    pool.join()
    
hf = h5py.File('WST_MNIST_test_6.h5', 'w')
hf.create_dataset('WST_MNIST_test', data=WST_MNIST_test)
hf.close()

In [None]:
if __name__ == '__main__':
    pool = Pool(31)
    WST_MNIST_test = pool.starmap(mnist_WST,product(angle_array[70:80],lst_test))
    pool.close()
    pool.join()
    
hf = h5py.File('WST_MNIST_test_7.h5', 'w')
hf.create_dataset('WST_MNIST_test', data=WST_MNIST_test)
hf.close()

255.59999999999997
255.59999999999997
255.59999999999997
255.59999999999997
255.59999999999997
255.59999999999997
255.59999999999997
255.59999999999997
255.59999999999997
255.59999999999997
255.59999999999997
259.2
255.59999999999997
255.59999999999997
259.2
259.2
259.2
259.2
259.2
259.2
259.2
259.2
259.2
259.2
259.2
262.8
262.8
262.8
262.8
262.8
262.8
255.59999999999997
255.59999999999997
255.59999999999997
255.59999999999997
255.59999999999997
255.59999999999997
259.2
255.59999999999997
255.59999999999997
259.2
259.2
255.59999999999997
262.8
259.2
262.8
259.2
259.2
259.2
262.8
262.8
262.8
259.2
259.2
262.8
259.2
255.59999999999997
259.2
255.59999999999997
259.2
255.59999999999997
255.59999999999997
255.59999999999997
259.2
255.59999999999997
255.59999999999997
259.2
262.8
255.59999999999997
255.59999999999997
259.2
262.8
259.2
262.8
259.2
262.8
262.8
262.8
259.2
259.2
259.2
259.2
255.59999999999997
255.59999999999997
255.59999999999997
255.59999999999997
255.59999999999997
259.2
259.

In [None]:
if __name__ == '__main__':
    pool = Pool(31)
    WST_MNIST_test = pool.starmap(mnist_WST,product(angle_array[80:90],lst_test))
    pool.close()
    pool.join()
    
hf = h5py.File('WST_MNIST_test_8.h5', 'w')
hf.create_dataset('WST_MNIST_test', data=WST_MNIST_test)
hf.close()

In [None]:
if __name__ == '__main__':
    pool = Pool(31)
    WST_MNIST_test = pool.starmap(mnist_WST,product(angle_array[90:100],lst_test))
    pool.close()
    pool.join()
    
hf = h5py.File('WST_MNIST_test_9.h5', 'w')
hf.create_dataset('WST_MNIST_test', data=WST_MNIST_test)
hf.close()

In [None]:
1+1