In [1]:
## Import Packages
from __future__ import print_function

import numpy as np
import pandas as pd
from itertools import product

#Astro Software
import astropy.units as units
from astropy.coordinates import SkyCoord
from astropy.io import fits

#Plotting Packages
import matplotlib as mpl
import matplotlib.cm as cm
import matplotlib.pyplot as plt
from matplotlib import rcParams

import seaborn as sns

from PIL import Image

from yt.config import ytcfg
import yt
import yt.units as u

#Scattering NN
import torch
import torch.nn.functional as F
from torch import optim
from kymatio.torch import Scattering2D
device = "cpu"

#Machine Learning
from sklearn.model_selection import train_test_split
from sklearn.mixture import GaussianMixture
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score
from sklearn.metrics import precision_score
from sklearn.decomposition import PCA, FastICA

import skimage
from skimage import filters

from scipy.optimize import curve_fit
from scipy import linalg
from scipy import stats
from scipy.signal import general_gaussian

#I/O
import h5py
import pickle
import glob
import copy
import time

#Plotting Style
%matplotlib inline
plt.style.use('dark_background')
rcParams['text.usetex'] = False
rcParams['axes.titlesize'] = 20
rcParams['xtick.labelsize'] = 16
rcParams['ytick.labelsize'] = 16
rcParams['legend.fontsize'] = 12
rcParams['axes.labelsize'] = 20
rcParams['font.family'] = 'sans-serif'

#Threading
torch.set_num_threads=2
from multiprocessing import Pool

import ntpath
def path_leaf(path):
    """Return the final component of ``path`` with its extension stripped.

    ``ntpath`` is used for both the split and the extension removal, so both
    ``/`` and ``\\`` are accepted as separators. The original called
    ``os.path.splitext`` without ever importing ``os``.

    Example: ``'a/b/run.tar.gz'`` -> ``'run.tar'``.
    """
    _head, tail = ntpath.split(path)
    return ntpath.splitext(tail)[0]

def hd5_open(file_name,name):
    """Read the whole dataset ``name`` from HDF5 file ``file_name`` into memory.

    The file is opened read-only in SWMR mode. The ``with`` block guarantees the
    handle is closed even when ``name`` is missing or the read fails (the
    original leaked the handle in that case).

    Returns the dataset contents as produced by ``dataset[:]``.
    """
    with h5py.File(file_name, 'r', swmr=True) as f:
        return f[name][:]

from matplotlib.colors import LinearSegmentedColormap

# Piecewise-linear segment data, each entry (x, value_below, value_above):
# blue fades 1 -> 0 over [0, 0.5], red rises 0 -> 1 over [0.5, 1], green stays 0,
# i.e. blue -> black -> red.
cdict1 = dict(
    red=((0.0, 0.0, 0.0), (0.5, 0.0, 0.0), (1.0, 1.0, 1.0)),
    green=((0.0, 0.0, 0.0), (1.0, 0.0, 0.0)),
    blue=((0.0, 0.0, 1.0), (0.5, 0.0, 0.0), (1.0, 0.0, 0.0)),
)
blue_red1 = LinearSegmentedColormap('BlueRed1', cdict1, N=5000)

from sklearn.preprocessing import StandardScaler

  self[key]


In [2]:
# Class labels for the MNIST train and test splits.
mnist_train_y = hd5_open('../scratch_AKS/data/mnist_train_y.h5', 'main/data')
mnist_test_y = hd5_open('../scratch_AKS/data/mnist_test_y.h5', 'main/data')
# Both angle arrays live in the same HDF5 file under different keys.
test_angles, train_angles = [
    hd5_open('../scratch_AKS/data/angles_listjl.h5', key)
    for key in ('main/test_angles', 'main/train_angles')
]

In [3]:
def prec_LDA(lda,X_test,y_test):
    """Score a fitted classifier on a test set with micro-averaged precision.

    Parameters
    ----------
    lda : fitted estimator exposing ``predict``
    X_test : array-like, features passed to ``lda.predict``
    y_test : array-like, true labels

    Returns
    -------
    float
        ``precision_score(y_test, y_pred, average='micro')``. For single-label
        multiclass predictions this equals accuracy.
    """
    # The original also built a confusion matrix here that was never used.
    y_pred = lda.predict(X_test)
    return precision_score(y_test, y_pred, average='micro')

def DHC_iso(wst,J,L):
    """Reduce directional DHC coefficients to an orientation-summed set.

    Parameters
    ----------
    wst : array-like, shape (nk, J*L + 3 + (J*L + 1)**2)
        One row per sample. Column layout as read here:

        * ``[0:2]``: two zeroth-order terms (S0)
        * ``[2:J*L+2]``: first-order terms, column ``l*J + j``
        * ``[J*L+2]``: one extra first-order term
        * the remaining ``(J*L+1)**2`` columns: a row-major S2 matrix whose
          index ``J*L`` is the extra term's slot
    J, L : int
        Number of scales and orientations.

    Returns
    -------
    ndarray, shape (nk, 2 + J + 1 + J*J*L + J + J + 1)
        ``[S0, S1iso, extra S1, S2iso, Sphi1, Sphi2, S2[JL, JL]]``, where:

        * ``S1iso[j] = sum_l S1[l*J + j]``
        * ``S2iso[j1, j2, dl]`` sums ``S2[l1*J+j1, l2*J+j2]`` over pairs
          with ``(l1 - l2) mod L == dl`` (flattened j1, j2, dl)
        * ``Sphi1`` / ``Sphi2`` sum the last column / last row of S2 over l

    Raises
    ------
    ValueError
        If the column count does not match the layout above, e.g. when the
        input has already been passed through ``DHC_iso``.
    """
    wst = np.asarray(wst)
    nk, n_cols = np.shape(wst)
    JL = J * L
    expected = JL + 3 + (JL + 1) ** 2
    if n_cols != expected:
        raise ValueError(
            'DHC_iso expects {} columns for J={}, L={} (got {}); '
            'was the input already reduced?'.format(expected, J, L, n_cols))

    S0 = wst[:, 0:2]
    S1 = wst[:, 2:JL + 2]
    S2 = np.reshape(wst[:, JL + 3:], (nk, JL + 1, JL + 1))

    # Every accumulation below adds terms in ascending l1 order, exactly as the
    # original scalar loops did, so the floating-point results are identical.
    S1iso = np.zeros((nk, J))
    for l1 in range(L):
        S1iso += S1[:, l1 * J:(l1 + 1) * J]

    S2iso = np.zeros((nk, J, J, L))
    for l1 in range(L):
        rows = S2[:, l1 * J:(l1 + 1) * J, :]            # [n, j1, :]
        for l2 in range(L):
            # For fixed (l1, l2) this adds the whole (j1, j2) block at once.
            S2iso[:, :, :, (l1 - l2) % L] += rows[:, :, l2 * J:(l2 + 1) * J]

    Sphi1 = np.zeros((nk, J))
    Sphi2 = np.zeros((nk, J))
    for l1 in range(L):
        Sphi1 += S2[:, l1 * J:(l1 + 1) * J, JL]
        Sphi2 += S2[:, JL, l1 * J:(l1 + 1) * J]

    return np.hstack((S0, S1iso, wst[:, JL + 2].reshape(nk, 1),
                      S2iso.reshape(nk, J * J * L), Sphi1, Sphi2,
                      S2[:, JL, JL].reshape(nk, 1)))

In [4]:
mnist_DHC_train_LanRot = hd5_open('../from_cannon/2021_03_28/mnist_DHC_train_ang_LanRotResize.h5','data')
mnist_DHC_test_LanRot = hd5_open('../from_cannon/2021_03_28/mnist_DHC_test_ang_LanRotResize.h5','data')
# Test-set counterpart of the `mnist_DHC_train_LanRot[1::6,:]` slice that ldaNR
# and ldaNRREG are trained on. Later cells score on `mnist_DHC_out_test`, which
# no cell assigned, so they raised NameError on a fresh kernel.
# NOTE(review): assumes the test file interleaves 6 rows per image exactly like
# the training file. Confirm before trusting these scores.
mnist_DHC_out_test = mnist_DHC_test_LanRot[1::6,:]

In [None]:
# "RR" DHC coefficients for the train and test splits (presumably randomly
# rotated images, judging by the file names — confirm).
mnist_DHC_outR, mnist_DHC_outR_test = [
    hd5_open(path, 'main/data')
    for path in ('../from_cannon/2021_04_04/mnist_DHC_train_RR_wd2.h5',
                 '../from_cannon/2021_04_04/mnist_DHC_test_RR_wd2.h5')
]

In [None]:
# Augmented training set: rows 1, 2 and 3 of every block of 6 in the LanRot
# training coefficients, stacked as RAW (non-isotropic) DHC rows.
# ldaAug applies DHC_iso to this array itself, and ldaAugREG is meant to see the
# raw coefficients. The original stacked DHC_iso outputs here, so the ldaAug fit
# applied DHC_iso twice (reshape error) and ldaAugREG was trained on iso features.
angle_3_train = np.vstack([mnist_DHC_train_LanRot[k::6, :] for k in (1, 2, 3)])
# NOTE: despite the name, these are the *training* labels, repeated once per slice.
angle_3_test = np.hstack([mnist_train_y] * 3)

In [None]:
# N = number of MNIST digit classes; LDA keeps at most N-1 discriminant axes.
N = 10
ldaAug = LDA(n_components=N-1)
# DHC_iso is applied here, so angle_3_train must hold raw (non-isotropic) rows.
ldaAug.fit(DHC_iso(angle_3_train,6,8), angle_3_test)

In [None]:
# "NR" model: trained on a single slice (row 1 of every 6) of the LanRot set.
N = 10
ldaNR = LDA(n_components=N-1)
ldaNR.fit(DHC_iso(mnist_DHC_train_LanRot[1::6,:],6,8), mnist_train_y)

In [None]:
# "R" model: trained on the RR training coefficients.
N = 10
ldaR = LDA(n_components=N-1)
ldaR.fit(DHC_iso(mnist_DHC_outR,6,8), mnist_train_y)

In [None]:
# Micro-averaged precision of the iso-feature models.
# Requires `mnist_DHC_out_test` (presumably the un-rotated test features) to be
# defined in an earlier cell.
prec_LDA(ldaNR,DHC_iso(mnist_DHC_out_test,6,8),mnist_test_y)

In [None]:
# R model on the RR test set.
prec_LDA(ldaR,DHC_iso(mnist_DHC_outR_test,6,8),mnist_test_y)

In [None]:
# NR model on the RR test set.
prec_LDA(ldaNR,DHC_iso(mnist_DHC_outR_test,6,8),mnist_test_y)

In [None]:
# Augmented model on mnist_DHC_out_test.
prec_LDA(ldaAug,DHC_iso(mnist_DHC_out_test,6,8),mnist_test_y)

In [None]:
# Augmented model on the RR test set.
prec_LDA(ldaAug,DHC_iso(mnist_DHC_outR_test,6,8),mnist_test_y)

In [None]:
# "REG" models: same training sets as above but on the raw DHC coefficients
# (no DHC_iso), so angle_3_train must hold raw rows here.
N = 10
ldaAugREG = LDA(n_components=N-1)
ldaAugREG.fit(angle_3_train, angle_3_test)

In [None]:
# Raw-coefficient NR model (row 1 of every 6 in the LanRot training set).
N = 10
ldaNRREG = LDA(n_components=N-1)
ldaNRREG.fit(mnist_DHC_train_LanRot[1::6,:], mnist_train_y)

In [None]:
# Raw-coefficient R model (RR training coefficients).
N = 10
ldaRREG = LDA(n_components=N-1)
ldaRREG.fit(mnist_DHC_outR, mnist_train_y)

In [None]:
# Micro-averaged precision of the raw-coefficient ("REG") models.
# Requires `mnist_DHC_out_test` (presumably the un-rotated test features) to be
# defined in an earlier cell.
prec_LDA(ldaNRREG,mnist_DHC_out_test,mnist_test_y)

In [None]:
# R-REG model on the RR test set.
prec_LDA(ldaRREG,mnist_DHC_outR_test,mnist_test_y)

In [None]:
# NR-REG model on the RR test set.
prec_LDA(ldaNRREG,mnist_DHC_outR_test,mnist_test_y)

In [None]:
# Augmented-REG model on mnist_DHC_out_test.
prec_LDA(ldaAugREG,mnist_DHC_out_test,mnist_test_y)

In [None]:
# Augmented-REG model on the RR test set.
prec_LDA(ldaAugREG,mnist_DHC_outR_test,mnist_test_y)

In [None]:
# The base RR test file followed by its numbered variants _0 ... _6.
test_list = ['../from_cannon/2021_04_04/mnist_DHC_test_RR_wd2.h5'] + [
    '../from_cannon/2021_04_04/mnist_DHC_test_RR_wd2_{}.h5'.format(k) for k in range(7)
]

In [None]:
prec_R_R = []
for file in test_list:
    mnist_DHC_outR_test = hd5_open(file,'main/data')
    prec_R_R.append(prec_LDA(ldaR,DHC_iso(mnist_DHC_outR_test,6,8),mnist_test_y))

In [None]:
np.mean(prec_R_R),np.std(prec_R_R), prec_R_R

In [None]:
prec_NR_R = []
for file in test_list:
    mnist_DHC_outR_test = hd5_open(file,'main/data')
    prec_NR_R.append(prec_LDA(ldaNR,DHC_iso(mnist_DHC_outR_test,6,8),mnist_test_y))

In [None]:
np.mean(prec_NR_R),np.std(prec_NR_R),prec_NR_R

In [None]:
prec_NRAug_R = []
for file in test_list:
    mnist_DHC_outR_test = hd5_open(file,'main/data')
    prec_NRAug_R.append(prec_LDA(ldaAug,DHC_iso(mnist_DHC_outR_test,6,8),mnist_test_y))

In [None]:
np.mean(prec_NRAug_R),np.std(prec_NRAug_R),prec_NRAug_R

In [None]:
prec_R_R_REG = []
for file in test_list:
    mnist_DHC_outR_test = hd5_open(file,'main/data')
    prec_R_R_REG.append(prec_LDA(ldaRREG,mnist_DHC_outR_test,mnist_test_y))
np.mean(prec_R_R_REG),np.std(prec_R_R_REG),prec_R_R_REG

In [None]:
prec_NR_R_REG = []
for file in test_list:
    mnist_DHC_outR_test = hd5_open(file,'main/data')
    prec_NR_R_REG.append(prec_LDA(ldaNRREG,mnist_DHC_outR_test,mnist_test_y))
np.mean(prec_NR_R_REG),np.std(prec_NR_R_REG),prec_NR_R_REG

In [None]:
prec_AUG_R_REG = []
for file in test_list:
    mnist_DHC_outR_test = hd5_open(file,'main/data')
    prec_AUG_R_REG.append(prec_LDA(ldaAugREG,mnist_DHC_outR_test,mnist_test_y))
np.mean(prec_AUG_R_REG),np.std(prec_AUG_R_REG),prec_AUG_R_REG

RWST and WST_log WU

In [None]:
# Load the RWST coefficients (pickled) and convert them to an ndarray.
# SECURITY: pickle.load can execute arbitrary code; only load files you trust.
with open('../../IWST/FromCannon/2021_04_04/RWST_MHD_rinvar_cumsum.p', 'rb') as input_file:
    RWST_MHD_rinvar_cumsum = np.array(pickle.load(input_file))

In [None]:
# Boolean column masks that split the RWST coefficient vector into groups.
# First-order part: three runs of 8 -> [offset | amplitude | angle].
offset_1 = np.repeat([True, False, False],8)
amp_1 = np.repeat([False, True, False],8)
angle_1 = np.repeat([False, False, True],8)
# Second-order part: five runs of 8*7//2 = 28 -> [offset | amp x3 | angle].
# Integer division is required: under Python 3 `8*7/2` is the float 28.0, and
# np.repeat rejects non-integer repeat counts with a TypeError.
offset_2 = np.repeat([True, False, False, False, False],8*7//2)
amp_2 = np.repeat([False, True, True, True, False],8*7//2)
angle_2 = np.repeat([False, False, False, False, True],8*7//2)

In [None]:
J = 8
L = 8
def WST_log_iso(scattering_coefficients, J=J, L=L):
    """Log-normalised, rotation-summed reduction of one WST coefficient vector.

    Parameters
    ----------
    scattering_coefficients : 1-D array-like
        ``[S0, S1, S2]`` where S1 has ``J*L`` entries indexed ``j*L + l``, and
        S2 is ordered by ``(j1, l1, j2 > j1, l2)`` (l2 fastest). Total length
        ``1 + J*L + L*L*J*(J-1)/2``.
    J, L : int, optional
        Number of scales / orientations (default: the module-level 8, 8).

    Returns
    -------
    ndarray, shape (1 + J + L*J*(J-1)/2,)
        ``[S0, iso1, iso2]`` where:

        * ``iso1[j] = sum_l log2 S1[j, l]``
        * for each pair ``j1 < j2`` (row-major order) and each angle difference
          ``d``: ``iso2 = sum_l1 (log2 S2[j1, l1, j2, (l1 + d) % L] - log2 S1[j1, l1])``

        Summing at fixed angle *difference* makes the result invariant to a
        global rotation. The original summed each l1 row instead (not
        invariant), hard-coded L = 8, and used ``np.NaN`` (removed in NumPy 2)
        as a sentinel that also dropped genuine NaNs.

    Raises
    ------
    ValueError
        If the input is not a 1-D vector of the expected length.
    """
    coeffs = np.asarray(scattering_coefficients)
    n_first = J * L
    n_second = L * L * J * (J - 1) // 2
    if coeffs.ndim != 1 or coeffs.shape[0] != 1 + n_first + n_second:
        raise ValueError(
            'expected a 1-D vector of length {} for J={}, L={}; got shape {}'.format(
                1 + n_first + n_second, J, L, coeffs.shape))

    s0 = coeffs[0]
    # log2 S1 indexed [j, l]; the isotropic first-order term sums over l.
    log_s1 = np.log2(coeffs[1:n_first + 1]).reshape(J, L)
    data1_iso = np.sum(log_s1, axis=1)

    log_s2 = np.log2(coeffs[n_first + 1:])
    rows = np.arange(L)
    out2 = []
    start = 0
    for j1 in range(J):
        n_j2 = J - j1 - 1
        stop = start + L * n_j2 * L
        # Scale j1's second-order block is laid out [l1, j2 = j1+1 .. J-1, l2];
        # normalise every entry by its parent log2 S1[j1, l1].
        section = log_s2[start:stop].reshape(L, n_j2, L) - log_s1[j1][:, None, None]
        start = stop
        for k in range(n_j2):
            pair = section[:, k, :]  # [l1, l2] for j2 = j1 + 1 + k
            # Sum over l1 at fixed difference d = (l2 - l1) mod L.
            out2.extend(pair[rows, (rows + d) % L].sum() for d in range(L))

    return np.concatenate(([s0], data1_iso, out2))

In [None]:
def LDA_AKS_testman(X_train,y_train,X_test,y_test,n_components,label_list):
    """Fit LDA on a train split, score it on a test split and plot diagnostics.

    Prints the confusion matrix and the micro-averaged precision, then draws a
    row-normalised confusion matrix plus LDA-projection panels:

    * ``n_components == 1``: a density plot of the test projection
    * ``n_components == 2``: one scatter panel
    * ``n_components > 2``: up to three scatter panels of consecutive axis pairs

    Parameters
    ----------
    X_train, y_train : array-like
        Training features and labels.
    X_test, y_test : array-like
        Test features and labels (``y_test`` must support boolean masking).
    n_components : int
        Number of LDA discriminant axes.
    label_list : sequence of str
        Tick labels for the confusion-matrix axes, in confusion-matrix order.

    Returns
    -------
    tuple
        ``(lda, cm, X_train, X_test, y_train, y_test, y_pred)``, where
        ``X_train`` and ``X_test`` are the LDA-projected features.
    """
    lda = LDA(n_components=n_components)
    X_train = lda.fit_transform(X_train, y_train)
    y_pred = lda.predict(X_test)
    X_test = lda.transform(X_test)
    cm = confusion_matrix(y_test, y_pred)
    prec = precision_score(y_test, y_pred,average='micro')
    print(cm)
    print('Accuracy' + str(prec))
    # Row-normalise: fraction of each true class sent to each predicted class.
    # A class that only appears in y_pred has an all-zero row; keep it at 0
    # instead of producing NaN from 0/0.
    row_totals = cm.sum(axis=1, keepdims=True)
    cmap_normal = np.divide(cm, row_totals, out=np.zeros(cm.shape), where=row_totals > 0)

    fig = plt.figure(figsize=(10,10),dpi=150)

    ax = fig.add_subplot(2,2,1)
    ax.imshow(cmap_normal,cmap='gray',vmin=0,vmax=1)

    ax.set_xticks(np.arange(cm.shape[0]))
    ax.set_yticks(np.arange(cm.shape[1]))

    ax.set_xticklabels(label_list)
    ax.set_yticklabels(label_list)

    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')

    plt.setp(ax.get_xticklabels(), rotation=90, ha="right",va='center',
             rotation_mode="anchor")
    # '%' must be escaped only under usetex; as plain text the original '\%'
    # rendered a literal backslash (and is an invalid string escape).
    pct = r'\%' if plt.rcParams['text.usetex'] else '%'
    plt.title('Test-Train Fidelity ({:.0f}'.format(100*prec) + pct + ')')

    # Annotate non-zero cells: white text on dark cells, black on light ones.
    textcolors=["black", "white"]
    threshold = 0.5
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            if cm[i, j] != 0:
                ax.text(j, i, cm[i, j],
                        ha="center", va="center", color=textcolors[int(cmap_normal[i, j] < threshold)])

    if n_components == 1:
        ax = fig.add_subplot(2,2,2)
        # NOTE(review): assumes binary labels 0 and 1.
        sns.distplot(X_test[y_test==0],ax=ax)
        sns.distplot(X_test[y_test==1],ax=ax)
        plt.xlabel('$LDA_{}$'.format(0))
        plt.ylabel('Probability Density')
        plt.title('LDA Projection')
    elif n_components == 2:
        ax = fig.add_subplot(2,2,2)
        ax.scatter(X_test[:,0], X_test[:,1], s=2, marker='^', zorder=10,c=y_test, cmap = 'bwr',alpha=0.5)
        # Axes 0 and 1; the original formatted these from a stale loop index `i`.
        plt.xlabel('$LDA_{}$'.format(0))
        plt.ylabel('$LDA_{}$'.format(1))
        plt.title('LDA Projection')
    elif n_components > 2:
        # Panels 2..4 plot axis pairs (0,1), (1,2), (2,3). Only draw pairs that
        # exist: with n_components == 3 the original indexed a 4th axis.
        for i in range(2, min(5, n_components + 1)):
            ax = fig.add_subplot(2,2,i)
            ax.scatter(X_test[:,i-2], X_test[:,i-1], s=2, marker='^', zorder=10,c=y_test, cmap = 'bwr',alpha=0.5)
            plt.xlabel('$LDA_{}$'.format(i-2))
            plt.ylabel('$LDA_{}$'.format(i-1))
            plt.title('LDA Projection')

    fig.subplots_adjust(wspace=0.6, hspace=0.6)
    plt.show()

    return (lda,cm,X_train,X_test,y_train,y_test,y_pred)

In [None]:
# Load the WST coefficients (pickled) and convert them to an ndarray.
# SECURITY: pickle.load can execute arbitrary code; only load files you trust.
with open('../../IWST/FromCannon/2020_11_09/WST_MHD_rinvar_cumsum.p', 'rb') as input_file:
    WST_MHD_rinvar_cumsum = np.array(pickle.load(input_file))

In [None]:
# Apply WST_log_iso to every row. The row count comes from the data instead of
# the hard-coded 6912, which would IndexError or silently truncate on another file.
WST_MHD_rinvar_cumsum_log_iso = np.array([WST_log_iso(row) for row in WST_MHD_rinvar_cumsum])

In [None]:
# Label table for the MHD samples, transposed after loading; later cells index
# it as label[row, column].
wph_2dcs_labels = hd5_open("../from_cannon/2021_03_30/MHD_2dcs.h5","labels").T

# Column position of each field within a row of wph_2dcs_labels.
lbl = {field: col for col, field in enumerate(("Ms", "Ma", "t", "ax", "pos", "class"))}

In [None]:
#
# Hold out the t = 600 and t = 900 snapshots as the test set, train on the rest,
# and run the same scale -> LDA pipeline on three feature sets:
#   R  : all RWST coefficients
#   RR : RWST columns selected by offset_1 / offset_2
#   W  : log-isotropic WST coefficients
# The original copy-pasted this three times and referenced undefined names
# (`data`, `Y_train`, `Y_testR` / `Y_testRR` / `Y_testW`), so it failed on a
# fresh kernel.
def fit_lda_time_split(data_in, label, n_components=7, test_times=(600, 900)):
    """Fit StandardScaler + LDA on a time-based split and score the held-out rows.

    Parameters
    ----------
    data_in : ndarray, shape (n_samples, n_features)
        Feature matrix, one row per row of ``label``.
    label : ndarray, shape (n_samples, >= 6)
        Label table; columns are addressed through ``lbl``.
    n_components : int
        Number of LDA axes kept for the projection.
    test_times : sequence
        Values of the ``t`` column assigned to the test split.

    Returns
    -------
    tuple
        ``(cm, prec, cmap_normal, X_test_lda, y_test, y_pred)``:

        * ``cm``: confusion matrix
        * ``prec``: micro-averaged precision
        * ``cmap_normal``: row-normalised confusion matrix
        * ``X_test_lda``: LDA-projected test features
        * ``y_test``, ``y_pred``: true and predicted test classes
    """
    test = np.isin(label[:, lbl["t"]], test_times)
    train = ~test
    Y = label[:, lbl["class"]]

    sc = StandardScaler()
    # arcsinh(x / 1e-20) compresses the coefficients' dynamic range before scaling.
    X_train = sc.fit_transform(np.arcsinh(data_in[train] / (1e-20)))
    X_test = sc.transform(np.arcsinh(data_in[test] / (1e-20)))
    y_train = Y[train]
    y_test = Y[test]

    lda = LDA(n_components=n_components)
    lda.fit(X_train, y_train)
    y_pred = lda.predict(X_test)
    X_test_lda = lda.transform(X_test)
    cm = confusion_matrix(y_test, y_pred)
    prec = precision_score(y_test, y_pred, average='micro')
    cmap_normal = (cm.T / cm.sum(axis=1)).T
    return cm, prec, cmap_normal, X_test_lda, y_test, y_pred


cmR, precR, cmap_normalR, X_testR, Y_testR, y_predR = fit_lda_time_split(
    RWST_MHD_rinvar_cumsum, wph_2dcs_labels)
cmRR, precRR, cmap_normalRR, X_testRR, Y_testRR, y_predRR = fit_lda_time_split(
    RWST_MHD_rinvar_cumsum[:, np.concatenate((offset_1, offset_2))], wph_2dcs_labels)
cmW, precW, cmap_normalW, X_testW, Y_testW, y_predW = fit_lda_time_split(
    WST_MHD_rinvar_cumsum_log_iso, wph_2dcs_labels)

In [None]:
#
mpl.rcParams.update(mpl.rcParamsDefault)
rcParams['text.usetex'] = False
rcParams['axes.titlesize'] = 24
rcParams['xtick.labelsize'] = 18
rcParams['ytick.labelsize'] = 18
rcParams['legend.fontsize'] = 12
rcParams['axes.labelsize'] = 24
rcParams['font.family'] = 'sans-serif'
# Matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and 3.8
# removed the old names; use whichever this installation provides.
plt.style.use('seaborn-v0_8-white' if 'seaborn-v0_8-white' in plt.style.available
              else 'seaborn-white')

vmin=-7
vmax=np.abs(vmin)
marker_size = 10

from matplotlib import cm as cmplt
norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
# One y-tick colour per confusion-matrix row (8 classes assumed).
rgba_colors = [cmplt.twilight(norm(i)) for i in [-4,-3,-2,-1,1,2,3,4]]
# Cell text colour: index 0 on dark cells (>= threshold), index 1 on light cells.
color_order = ["white","black"]
color_map = "binary"
colormap='twilight'
threshold = 0.5


def _plot_cm_and_projection(fig, row, cm_counts, cm_frac, prec, X_lda, y_true, y_pred,
                            title, show_xlabel):
    """Draw one row of the 3x2 comparison figure.

    Left panel (subplot ``2*row + 1``): row-normalised confusion matrix with the
    raw counts written in each cell. Right panel (``2*row + 2``): test samples
    on the first two LDA axes, coloured by true class. Circles mark correct
    predictions and crosses mark wrong ones.
    """
    ax = fig.add_subplot(3, 2, 2 * row + 1)
    ax.imshow(cm_frac, interpolation='nearest', cmap=color_map, aspect='equal',
              vmin=0, vmax=1)
    # Tick labels in confusion_matrix order (sorted union of true and predicted
    # classes). The original used `labels`, which is never defined in the notebook.
    # TODO: substitute human-readable class names if available.
    tick_labels = ['{:g}'.format(c) for c in np.unique(np.concatenate((y_true, y_pred)))]
    ax.set_xticks(np.arange(cm_counts.shape[0]))
    ax.set_yticks(np.arange(cm_counts.shape[1]))
    ax.set_xticklabels(tick_labels, size=16)
    ax.set_yticklabels(tick_labels, size=16)
    for ytick, color in zip(ax.get_yticklabels(), rgba_colors):
        ytick.set_color(color)
    if show_xlabel:
        ax.set_xlabel('Predicted Label')
    ax.set_ylabel('True Label')
    plt.setp(ax.get_xticklabels(), rotation=90, ha="right", va='center',
             rotation_mode="anchor")
    ax.set_title('Accuracy ({:.0f}%)'.format(100 * prec))
    for i in range(cm_counts.shape[0]):
        for j in range(cm_counts.shape[1]):
            ax.text(j, i, cm_counts[i, j], ha="center", va="center", size=19,
                    color=color_order[int(cm_frac[i, j] < threshold)])

    ax = fig.add_subplot(3, 2, 2 * row + 2)
    correct = y_true == y_pred
    ax.scatter(X_lda[correct, 0], X_lda[correct, 1], s=marker_size / 2, marker='o',
               zorder=10, c=y_true[correct], cmap=colormap, vmin=vmin, vmax=vmax, alpha=1)
    ax.scatter(X_lda[~correct, 0], X_lda[~correct, 1], s=marker_size * 2, marker='x',
               zorder=10, c=y_true[~correct], cmap=colormap, vmin=vmin, vmax=vmax, alpha=1)
    ax.set_xlabel('$LD_0$')
    ax.set_ylabel('$LD_1$')
    ax.set_title(title)


fig = plt.figure(figsize=(12,17),dpi=150)
_plot_cm_and_projection(fig, 0, cmR, cmap_normalR, precR, X_testR, Y_testR, y_predR,
                        'RWST-LDA Projection', show_xlabel=False)
_plot_cm_and_projection(fig, 1, cmRR, cmap_normalRR, precRR, X_testRR, Y_testRR, y_predRR,
                        'R-RWST-LDA Projection', show_xlabel=False)
_plot_cm_and_projection(fig, 2, cmW, cmap_normalW, precW, X_testW, Y_testW, y_predW,
                        'WST-LOG-ISO-LDA Projection', show_xlabel=True)
fig.subplots_adjust(wspace=0.35, hspace=0.45)
#plt.savefig('../figures/RWSTCompare.png', dpi=150, bbox_inches='tight', pad_inches=0.1)
plt.show()