In [None]:
# This code is intended to run on a GPU (via CuPy).
# train_set is in R^{m_train x d} and test_set is in R^{m_test x d}, where m_train is the size of the training set, m_test is the size of the test set, and d is the feature dimension.





# data setup
X_train = train_set
y_train = train_label
X_test = test_set
y_test = test_label
# The model iterates over range(num_classes) and selects rows with `y == j`,
# so labels must be the integers 0..num_classes-1. Derive the count from the
# training labels (set it by hand if some class is absent from training).
num_classes = int(y_train.max()) + 1


import numpy as np
import cupy as cp
class Architecture:
    """Run a sequence of blocks and collect the per-layer losses they report.

    Each block must provide ``load_arch(arch, block_id)``, ``preprocess``,
    ``__call__(Z, y)`` and ``postprocess``; blocks report losses back through
    :meth:`update_loss`.
    """

    def __init__(self, blocks, model_dir, num_classes, batch_size=100):
        self.blocks = blocks
        self.model_dir = model_dir
        self.num_classes = num_classes
        self.batch_size = batch_size

    def __call__(self, Z, y=None):
        """Feed ``Z`` (and optional labels ``y``) through every block in order.

        The loss record is reset before each block, so afterwards
        ``loss_dict`` holds only the losses of the last block.
        """
        for block_id, block in enumerate(self.blocks):
            block.load_arch(self, block_id)
            self.init_loss()
            Z = block.postprocess(block(block.preprocess(Z), y))
        return Z

    def __getitem__(self, i):
        return self.blocks[i]

    def init_loss(self):
        """Start an empty loss record with one list per loss component."""
        self.loss_dict = {key: [] for key in ("loss_total", "loss_expd", "loss_comp")}

    def update_loss(self, layer, loss_total, loss_expd, loss_comp):
        """Append one layer's losses to the record and print them."""
        entries = (("loss_total", loss_total),
                   ("loss_expd", loss_expd),
                   ("loss_comp", loss_comp))
        for key, value in entries:
            self.loss_dict[key].append(value)
        print(f"layer: {layer} | loss_total: {loss_total:5f} | loss_expd: {loss_expd:5f} | loss_comp: {loss_comp:5f}")




##utils
import os
import logging
import json
import numpy as np
import pandas as pd
import torch


from torch.nn.functional import normalize


def sort_dataset(data, labels, classes, stack=False):
    """Sort dataset based on classes.

    Works with both NumPy and CuPy inputs; arrays created here (class range,
    stacked output) use the same library as ``data``.

    Parameters:
        data (np.ndarray or cupy.ndarray): data array
        labels (np.ndarray or cupy.ndarray): one dimensional array of class labels
        classes (int or sequence): number of classes (meaning ``0..classes-1``)
            or the explicit class values, in the desired output order
        stack (bool): combine sorted data into one array

    Return:
        sorted data, sorted labels: lists with one array per class, or single
        stacked arrays when ``stack`` is True.
    """
    # NumPy callers (e.g. nearsub, plot_heatmap) must not be routed through
    # CuPy, which refuses implicit host arrays.
    xp = np if isinstance(data, np.ndarray) else cp
    if isinstance(classes, (int, np.integer)):
        classes = xp.arange(classes)
    sorted_data = []
    sorted_labels = []
    for c in classes:
        idx = (labels == c)
        sorted_data.append(data[idx])
        sorted_labels.append(labels[idx])
    if stack:
        sorted_data = xp.vstack(sorted_data)
        sorted_labels = xp.hstack(sorted_labels)
    return sorted_data, sorted_labels

def save_params(model_dir, params, name='params.json'):
    """Write the ``params`` dict to ``model_dir/name`` as indented, key-sorted JSON."""
    target = os.path.join(model_dir, name)
    with open(target, 'w') as handle:
        json.dump(params, handle, indent=2, sort_keys=True)

def load_params(model_dir):
    """Read ``model_dir/params.json`` and return its contents as a dictionary."""
    with open(os.path.join(model_dir, "params.json"), 'r') as handle:
        return json.load(handle)

def create_csv(model_dir, filename, headers):
    """(Re)create ``model_dir/filename`` holding one comma-separated header line.

    Any existing file at that path is deleted first. No trailing newline is
    written. Returns the path of the created file.
    """
    csv_path = os.path.join(model_dir, filename)
    header_line = ','.join(str(h) for h in headers)
    if os.path.exists(csv_path):
        os.remove(csv_path)
    with open(csv_path, 'w+') as handle:
        handle.write(header_line)
    return csv_path

def save_loss(loss_dict, model_dir, name):
    """Write ``loss_dict`` (column name -> list of values) to ``model_dir/loss/{name}.csv``."""
    loss_dir = os.path.join(model_dir, "loss")
    os.makedirs(loss_dir, exist_ok=True)
    frame = pd.DataFrame(loss_dict)
    frame.to_csv(os.path.join(loss_dir, f"{name}.csv"))

def save_features(model_dir, name, features, labels, layer=None):
    """Save ``features`` and ``labels`` to ``model_dir/features/{name}_features.npy``
    and ``{name}_labels.npy`` using ``cp.save``. ``layer`` is accepted but unused.
    """
    feature_dir = os.path.join(model_dir, "features")
    os.makedirs(feature_dir, exist_ok=True)
    for suffix, array in (("features", features), ("labels", labels)):
        cp.save(os.path.join(feature_dir, f"{name}_{suffix}.npy"), array)








#functionals

import os
from tqdm import tqdm
import numpy as np
import scipy
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader


def get_n_each(X, y, n, b=0):
    """Take the ``b``-th block of ``n`` samples from every class.

    For each class in ``np.unique(y)`` (ascending), the samples at positions
    ``b*n`` to ``(b+1)*n - 1`` within that class are kept in their original
    order. A class with fewer samples contributes fewer rows.

    Returns:
        (np.ndarray, np.ndarray): stacked samples and their matching labels.
    """
    window = slice(b * n, (b + 1) * n)
    _X, _y = [], []
    for c in np.unique(y):
        idx = y == c
        # Slice labels with the same window as the data so rows and labels
        # always have the same length.
        _X.append(X[idx][window])
        _y.append(y[idx][window])
    return np.vstack(_X), np.hstack(_y)

def translate1d(data, labels, n=None, stride=1):
    """Augment 3-D ``data`` with circular shifts along its last axis.

    Without ``n`` every shift ``0, stride, 2*stride, ...`` below the axis length
    is used; with ``n`` the shifts run from ``-n*stride`` to ``n*stride``.
    Shifted copies are stacked along axis 0, and ``labels`` is tiled to match.
    """
    _, _, length = data.shape
    if n is not None:
        shifts = np.arange(-n * stride, (n + 1) * stride, stride)
    else:
        shifts = np.arange(0, length, stride)
    rolled = [np.roll(data, shift, axis=2) for shift in shifts]
    return np.vstack(rolled), np.tile(labels, len(shifts))

def translate2d(data, labels, n=None, stride=1):
    """Augment 4-D ``data`` with circular shifts over its last two axes.

    Without ``n`` every row/column shift that is a multiple of ``stride`` below
    the axis length is used; with ``n`` both shifts range over
    ``-n*stride .. n*stride``. Copies are stacked along axis 0 (row shift
    outer, column shift inner), and ``labels`` is tiled to match.
    """
    _, _, height, width = data.shape
    if n is not None:
        offsets = np.arange(-n * stride, (n + 1) * stride, stride)
        row_shifts, col_shifts = offsets, offsets
    else:
        row_shifts = np.arange(0, height, stride)
        col_shifts = np.arange(0, width, stride)
    shifted = [np.roll(data, (dr, dc), axis=(2, 3))
               for dr in row_shifts
               for dc in col_shifts]
    return np.vstack(shifted), np.tile(labels, len(row_shifts) * len(col_shifts))

def shuffle(data, labels, seed=10):
    """Permute ``data`` and ``labels`` with the same random order.

    Reseeds NumPy's global RNG with ``seed`` before drawing the permutation.
    """
    np.random.seed(seed)
    count = data.shape[0]
    order = np.random.choice(np.arange(count), count, replace=False)
    shuffled_data = data[order]
    shuffled_labels = labels[order]
    return shuffled_data, shuffled_labels

def filter_class(data, labels, classes, n=None, b=0):
    """Keep only samples of the selected classes and relabel them 0..K-1.

    Parameters:
        data (np.ndarray): samples along axis 0.
        labels (np.ndarray): one label per sample.
        classes (int or sequence): number of classes (meaning ``0..classes-1``)
            or the explicit class values to keep, in output order.
        n (int or None): if given, keep only the ``b``-th block of ``n``
            samples of each class; if None (default) keep every sample.
        b (int): block index used together with ``n``.

    Returns:
        (np.ndarray, np.ndarray): stacked samples and labels remapped to
        ``0..K-1`` in ascending order of the original label values.
    """
    if isinstance(classes, (int, np.integer)):
        classes = np.arange(classes)
    # With n=None the block slice b*n would raise TypeError; keep everything.
    window = slice(None) if n is None else slice(b * n, (b + 1) * n)
    data_filter = []
    labels_filter = []
    for _class in classes:
        idx = labels == _class
        data_filter.append(data[idx][window])
        labels_filter.append(labels[idx][window])
    data_new = np.vstack(data_filter)
    labels_new = np.unique(np.hstack(labels_filter), return_inverse=True)[1]
    return data_new, labels_new

def normalize(X, p=2):
    """Scale every sample (index along axis 0) of ``X`` to unit ``p``-norm.

    The remaining axes of each sample are flattened to compute its norm, so
    ``X`` may be ``(m, d)`` or higher rank. Norms are clipped below at 1e-8,
    so all-zero samples stay zero instead of becoming NaN. Accepts CuPy or
    NumPy arrays; the computation uses the same library as ``X``.
    """
    xp = np if isinstance(X, np.ndarray) else cp
    axes = tuple(range(1, X.ndim))
    norm = xp.linalg.norm(X.reshape(X.shape[0], -1), axis=1, ord=p)
    norm = xp.clip(norm, 1e-8, np.inf)
    # Restore singleton axes so the per-sample norm broadcasts over X.
    return X / xp.expand_dims(norm, axes)

def batch_cov(V, bs):
    """Return ``sum_j outer(V[j], conj(V[j]))`` over the rows of ``V``.

    The sum is accumulated over chunks of ``bs`` rows to bound the size of
    each intermediate einsum. Any trailing axes after the second are kept
    elementwise.
    """
    partials = []
    for start in np.arange(0, V.shape[0], bs):
        chunk = V[start:start + bs]
        partials.append(np.einsum('ji...,jk...->ik...', chunk, chunk.conj(), optimize=True))
    return np.sum(partials, axis=0)

def generate_kernel(mode, size, seed=10):
    """Create a kernel of shape ``size``.

    Parameters:
        mode (str): 'gaussian' for i.i.d. N(0, 1) entries, 'ones' for all ones.
        size (int or tuple): output shape.
        seed (int): seed for NumPy's global RNG (reseeded on every call).

    Raises:
        ValueError: if ``mode`` is not recognised.
    """
    np.random.seed(seed)
    if mode == 'gaussian':
        return np.random.normal(0., 1., size=size)
    if mode == 'ones':
        # np.ones takes the shape positionally; it has no ``size`` keyword.
        return np.ones(size)
    raise ValueError(f"unknown kernel mode: {mode!r}")

def convert2polar(images, channels, timesteps):
    """Sample a batch of images along a rotating radial line.

    For each of ``timesteps`` angles in ``np.linspace(0, 360, timesteps)``,
    the batch is rotated about the image center (shape preserved). Pixels of
    the middle row are then taken at ``channels`` evenly spaced integer
    columns from 0 to the center. Both 0 and 360 degrees are included, so the
    first and last time steps sample the same orientation.

    Parameters:
        images (np.ndarray): batch indexed (N, H, W); rotation acts on axes 1
            and 2 (presumably square images).
        channels (int): number of radial samples.
        timesteps (int): number of angles.

    Returns:
        np.ndarray: array of shape (N, channels, timesteps).
    """
    # Import the submodule explicitly: a bare ``import scipy`` does not expose
    # ``scipy.ndimage`` on older SciPy releases without lazy submodule loading.
    from scipy import ndimage

    mid_pt = images.shape[1] // 2
    r = np.linspace(0, mid_pt, channels).astype(np.int32)
    angles = np.linspace(0, 360, timesteps)
    polar_imgs = []
    for angle in angles:
        X_rot = ndimage.rotate(images, angle, axes=(1, 2), reshape=False)
        polar_imgs.append(X_rot[:, mid_pt, r])
    # (timesteps, N, channels) -> (N, channels, timesteps)
    return np.stack(polar_imgs).transpose(1, 2, 0)


##evaluate
import argparse
import os
import torch
import numpy as np

from sklearn.svm import LinearSVC
from sklearn.decomposition import PCA
from sklearn.decomposition import TruncatedSVD
from sklearn.linear_model import SGDClassifier
from sklearn.svm import LinearSVC, SVC
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier

def svm(train_features, train_labels, test_features, test_labels):
    """Fit a LinearSVC on the training features and report train/test accuracy.

    Prints the test accuracy and returns ``(train_accuracy, test_accuracy)``.
    """
    clf = LinearSVC(verbose=0, random_state=10)
    clf.fit(train_features, train_labels)
    train_acc = clf.score(train_features, train_labels)
    test_acc = clf.score(test_features, test_labels)
    print("SVM: {}".format(test_acc))
    return train_acc, test_acc

def knn(train_features, train_labels, test_features, test_labels, k=5):
    """Classify test samples by a majority vote of their k most similar training samples.

    Similarity is the inner product ``train_features @ test_features.T``. This
    equals cosine similarity when the rows are unit-norm. Inputs are expected
    to be NumPy arrays (the similarity matrix goes through
    ``torch.from_numpy``).

    Options:
        k (int): number of neighbours that vote.

    Returns:
        float: test accuracy (also printed).
    """
    similarity = torch.from_numpy(train_features @ test_features.T)
    # Indices of the k most similar training rows for every test column: (k, n_test).
    neighbor_idx = similarity.topk(k=k, dim=0).indices
    neighbor_labels = torch.tensor(train_labels[neighbor_idx])
    test_pred = neighbor_labels.mode(0).values.detach()

    acc = compute_accuracy(test_pred.numpy(), test_labels)
    print("kNN: {}".format(acc))
    return acc

def nearsub(train_features, train_labels, test_features, test_labels, n_comp=10):
    """Perform nearest-subspace classification with per-class TruncatedSVD bases.

    For every class present in ``test_labels``, a rank-``n_comp`` TruncatedSVD
    is fit on that class's training features. Each test sample is scored by
    the norm of its residual after projecting onto that (uncentered)
    subspace, and the class with the smallest residual is predicted.

    Options:
        n_comp (int): number of SVD components (clamped to feature_dim - 1).

    Returns:
        float: test accuracy (also printed).
    """
    classes = np.unique(test_labels)
    class_feats, _ = sort_dataset(train_features, train_labels,
                                  classes=classes, stack=False)
    dim = class_feats[0].shape[1]
    n_comp = min(n_comp, dim - 1)

    residual_norms = []
    for feats_j in class_feats:
        basis = TruncatedSVD(n_components=n_comp).fit(feats_j).components_.T
        off_subspace = (np.eye(dim) - basis @ basis.T) @ test_features.T
        residual_norms.append(np.linalg.norm(off_subspace, ord=2, axis=0))

    predicted = classes[np.argmin(residual_norms, axis=0)]
    acc_svd = compute_accuracy(predicted, test_labels)
    print('SVD: {}'.format(acc_svd))
    return acc_svd

def nearsub_pca(train_features, train_labels, test_features, test_labels, n_comp=10):
    """Perform nearest-subspace classification with per-class PCA (affine) subspaces.

    For every class present in ``test_labels``, a rank-``n_comp`` PCA is fit on
    that class's training features. Each test sample is centered by the
    class mean and scored by the norm of its residual off the principal
    subspace. The class with the smallest residual is predicted.

    Options:
        n_comp (int): number of PCA components (clamped to feature_dim - 1).

    Returns:
        float: test accuracy of the PCA classifier (also printed).
    """
    scores_pca = []
    classes = np.unique(test_labels)
    features_sort, _ = sort_dataset(train_features, train_labels,
                                    classes=classes, stack=False)
    fd = features_sort[0].shape[1]
    if n_comp >= fd:
        n_comp = fd - 1
    for feats_j in features_sort:
        pca = PCA(n_components=n_comp).fit(feats_j)
        pca_subspace = pca.components_.T
        mean = np.mean(feats_j, axis=0)
        pca_j = (np.eye(fd) - pca_subspace @ pca_subspace.T) \
            @ (test_features - mean).T
        scores_pca.append(np.linalg.norm(pca_j, ord=2, axis=0))
    test_predict_pca = np.argmin(scores_pca, axis=0)
    acc_pca = compute_accuracy(classes[test_predict_pca], test_labels)
    print('PCA: {}'.format(acc_pca))
    # Return this classifier's own accuracy (``acc_svd`` is not defined here).
    return acc_pca

def compute_accuracy(y_pred, y_true):
    """Compute accuracy as the fraction of positions where ``y_pred`` equals ``y_true``.

    Labels are compared with ``!=`` rather than subtracted, so boolean and
    string label arrays work too.

    Raises:
        ValueError: if the two label arrays have different shapes.
    """
    y_pred = np.asarray(y_pred)
    y_true = np.asarray(y_true)
    if y_pred.shape != y_true.shape:
        raise ValueError(
            f"shape mismatch: y_pred {y_pred.shape} vs y_true {y_true.shape}")
    mismatches = np.count_nonzero(y_pred != y_true)
    return 1 - mismatches / y_true.size

def baseline(train_features, train_labels, test_features, test_labels):
    """Fit several standard scikit-learn classifiers and report their test accuracy.

    Each accuracy is printed as ``"<model name>: <score>"``.

    Returns:
        dict: model name -> test accuracy.
    """
    # loss='log' was renamed to 'log_loss' in scikit-learn 1.1 and the old
    # name was removed in 1.3 (logistic regression trained with SGD).
    test_models = {'log_l2': SGDClassifier(loss='log_loss', max_iter=10000, random_state=42),
                   'SVM_linear': LinearSVC(max_iter=10000, random_state=42),
                   'SVM_RBF': SVC(kernel='rbf', random_state=42),
                   'DecisionTree': DecisionTreeClassifier(),
                   'RandomForrest': RandomForestClassifier()}
    scores = {}
    for model_name, test_model in test_models.items():
        test_model.fit(train_features, train_labels)
        score = test_model.score(test_features, test_labels)
        print(f"{model_name}: {score}")
        scores[model_name] = score
    return scores




#Vector
import os
import numpy as np


from scipy.special import softmax

class Vector:
    """Multi-layer ReduNet-style block acting on flattened feature vectors.

    Each layer moves the features with step ``eta`` along an expansion term
    minus a ``gam``-weighted per-class compression term. The operators are
    built from the class-``j`` features of that layer using the scales
    ``c = d / (m * eps)`` and ``c_j = d / (m_j * eps)``.

    * Training pass (labels given): the layer input, labels and class
      proportions are saved under ``<model_dir>/weights`` so the test pass can
      rebuild the same operators.
    * Test pass (no labels): the saved training features are reloaded. Soft
      class memberships of each sample are estimated with a softmax over its
      compression terms (:meth:`nonlinear`) and used to mix the per-class
      updates.

    Parameters:
        layers (int): number of layers applied by ``__call__``.
        eta (float): step size of every layer update.
        eps (float): distortion parameter in the scales above.
        lmbda (float): inverse temperature of the membership softmax.
    """

    def __init__(self, layers, eta, eps, lmbda=500):
        self.layers = layers
        self.eta = eta
        self.eps = eps
        self.lmbda = lmbda

    def __call__(self, Z, y=None):
        """Apply all layers to ``Z``, reporting each layer's loss to the architecture.

        With labels ``y`` the training update is used. Without them the test
        update is used and the losses are computed against estimated labels.
        """
        for layer in range(self.layers):
            Z, y_approx = self.forward(layer, Z, y)
            self.arch.update_loss(layer, *self.compute_loss(Z, y_approx))
        return Z

    def forward(self, layer, Z, y=None):
        """Apply one layer.

        Returns:
            (Z_out, labels): the updated features (row i still corresponds to
            input row i) and the labels used for the loss. These are ``y`` in
            training and the argmax of the estimated memberships in test. The
            output of the last layer is normalized.
        """
        if y is not None:
            return self._forward_train(layer, Z, y)
        return self._forward_test(layer, Z)

    def _forward_train(self, layer, Z, y):
        """Training update: each row is moved with the operators of its own class."""
        # Keep this layer's input so the test pass can rebuild identical operators.
        self.feature = Z
        self.label = y
        self.init(Z, y)
        self.save_weights(layer)
        self.save_gam(layer)

        m, d = Z.shape
        c = d / (m * self.eps)
        out = None
        for j in range(self.num_classes):
            mask = (y == j)
            Z_j = Z[mask]
            m_j = Z_j.shape[0]
            c_j = d / (m_j * self.eps)
            # Woodbury form: c(I - c Z_j^T (I + c Z_j Z_j^T)^-1 Z_j) == c(I + c Z_j^T Z_j)^-1,
            # but only an m_j x m_j matrix is inverted.
            # NOTE(review): E_j is built from class-j features only (with the
            # global scale c); confirm this is the intended expansion operator.
            pre_Ej = cp.linalg.inv(cp.eye(m_j) + c * Z_j @ Z_j.T)
            pre_Cj = cp.linalg.inv(cp.eye(m_j) + c_j * Z_j @ Z_j.T)
            E_j = c * (cp.eye(d) - c * Z_j.T @ pre_Ej @ Z_j)
            C_j = c_j * (cp.eye(d) - c_j * Z_j.T @ pre_Cj @ Z_j)
            expd = Z_j @ E_j.T
            clus = self.gam[j] * Z_j @ C_j.T
            Z_j_new = Z_j + self.eta * (expd - clus)
            if out is None:
                out = cp.zeros((m, d), dtype=Z_j_new.dtype)
            # Scatter back into the rows of class j so out[i] keeps matching
            # y[i]; stacking class by class would reorder rows unless y is sorted.
            out[mask] = Z_j_new
        if layer == self.layers - 1:
            out = normalize(out)
        return out, y

    def _forward_test(self, layer, Z):
        """Test update: per-class updates mixed by estimated soft memberships."""
        self.load_weights(layer)
        self.load_gam(layer)

        m, d = self.feature.shape
        c = d / (m * self.eps)
        # Training features of each class at this layer (computed once).
        class_feats = [self.feature[self.label == j] for j in range(self.num_classes)]

        comp = []
        for F_j in class_feats:
            m_j = F_j.shape[0]
            c_j = d / (m_j * self.eps)
            pre_Cj = cp.linalg.inv(cp.eye(m_j) + c_j * F_j @ F_j.T)
            C_j = c_j * (cp.eye(d) - c_j * F_j.T @ pre_Cj @ F_j)
            comp.append(Z @ C_j.T)
        # (num_classes, n_test, d) stack of per-class compression terms.
        comp = cp.vstack(comp).reshape((self.num_classes, Z.shape[0], d))
        pred_pi, y_approx = self.nonlinear(comp)

        out = None
        for j, F_j in enumerate(class_feats):
            m_j = F_j.shape[0]
            pre_Ej = cp.linalg.inv(cp.eye(m_j) + c * F_j @ F_j.T)
            E_j = c * (cp.eye(d) - c * F_j.T @ pre_Ej @ F_j)
            Z_j = Z + self.eta * (Z @ E_j.T - self.gam[j] * comp[j])
            Z_j = pred_pi[j] * Z_j
            out = Z_j if out is None else out + Z_j
        if layer == self.layers - 1:
            out = normalize(out)
        return out, y_approx

    def first_ortho(self, Z, y):
        """Project each class's features off the span of the earlier classes' features.

        Class 0 is kept as is. Each later class is projected onto the
        orthogonal complement of the QR bases accumulated from the previous
        (projected) classes. Returns the stacked features, ordered by class,
        and float NumPy labels. Not called anywhere in the visible code.
        """
        y_1 = np.array([])
        Z_0 = Z[y == 0]
        Z_0 = Z_0.T
        y_1 = np.concatenate((y_1, np.array([int(0)] * Z_0.shape[1])), axis=0)
        U_0, R_0 = cp.linalg.qr(Z_0, mode='reduced')
        A = U_0
        output = Z_0
        for j in np.arange(1, self.num_classes):
            if j == 1:
                A = U_0
            else:
                # B is the basis of the previous class, set at the end of the last iteration.
                A = cp.concatenate((A, B), axis=1)
            U = A @ A.T
            Z_j = Z[y == j].T
            y_1 = np.concatenate((y_1, np.array([int(j)] * Z_j.shape[1])), axis=0)
            Z_j = (cp.eye(U.shape[0]) - U) @ Z_j
            B, R_j = cp.linalg.qr(Z_j, mode='reduced')
            output = cp.hstack((output, Z_j))
        output = output.T
        return output, y_1

    def load_arch(self, arch, block_id):
        """Bind this block to its architecture and position; read num_classes from it."""
        self.arch = arch
        self.block_id = block_id
        self.num_classes = self.arch.num_classes

    def init(self, Z, y):
        self.compute_gam(y)

    def compute_gam(self, y):
        """Set ``gam[j]`` to the fraction of samples whose label is ``j``."""
        m_j = [(y == j).nonzero()[0].size for j in range(self.num_classes)]
        self.gam = np.array(m_j) / y.size

    def compute_loss(self, Z, y):
        """Return ``(loss_expd - loss_comp, loss_expd, loss_comp)`` for features ``Z``.

        ``loss_expd = 1/2 logdet(I + c Z^T Z)`` with ``c = d / (m * eps)``, and
        ``loss_comp = sum_j gam[j]/2 logdet(I + c_j Z_j^T Z_j)`` over the
        non-empty classes, with ``c_j = d / (m_j * eps)``. Values are returned
        as NumPy objects.
        """
        m, d = Z.shape
        I = cp.eye(d)

        c = d / (m * self.eps)
        logdet = cp.linalg.slogdet(I + c * Z.T @ Z)[1]
        loss_expd = logdet / 2.

        loss_comp = 0.
        for j in np.arange(self.num_classes):
            idx = (y == int(j))
            Z_j = Z[idx, :]
            m_j = Z_j.shape[0]
            if m_j == 0:
                continue
            c_j = d / (m_j * self.eps)
            logdet_j = cp.linalg.slogdet(I + c_j * Z_j.T @ Z_j)[1]
            loss_comp += self.gam[j] * logdet_j / 2.
        loss_expd = cp.asnumpy(loss_expd)
        loss_comp = cp.asnumpy(loss_comp)
        return loss_expd - loss_comp, loss_expd, loss_comp

    def preprocess(self, X):
        """Flatten every sample to a vector and normalize it."""
        m = X.shape[0]
        X = X.reshape(m, -1)
        return normalize(X)

    def postprocess(self, X):
        return normalize(X)

    def nonlinear(self, Bz):
        """Estimate soft class memberships from per-class compression terms.

        ``Bz`` is indexed (class, sample, ...). Memberships are
        ``softmax(-lmbda * ||Bz[j, i]||)`` over classes, expanded with trailing
        singleton axes so ``pred[j]`` broadcasts against ``Bz[j]``. The hard
        labels are the argmax over classes.
        """
        axes = tuple(np.arange(2, len(Bz.shape)))
        norm = cp.linalg.norm(Bz.reshape(Bz.shape[0], Bz.shape[1], -1), axis=2)
        norm = cp.clip(norm, 1e-8, norm)
        # scipy.special.softmax works on host arrays.
        norm = cp.asnumpy(norm)
        pred = softmax(-self.lmbda * norm, axis=0)
        pred = cp.asarray(pred)
        y = cp.argmax(pred, axis=0)

        return cp.expand_dims(pred, axes), y

    def save_weights(self, layer):
        """Save this layer's training features/labels to ``weights/{block_id}_{layer}.npz``."""
        weight_dir = os.path.join(self.arch.model_dir, "weights")
        os.makedirs(weight_dir, exist_ok=True)
        save_path = os.path.join(weight_dir, f"{self.block_id}_{layer}.npz")
        cp.savez(save_path, array1=self.feature, array2=self.label)

    def load_weights(self, layer):
        """Load the features/labels saved by :meth:`save_weights` for ``layer``."""
        weight_dir = os.path.join(self.arch.model_dir, "weights")
        save_path = os.path.join(weight_dir, f"{self.block_id}_{layer}.npz")
        weights = cp.load(save_path)
        self.feature = weights['array1']
        self.label = weights['array2']
        return self.feature, self.label

    def save_gam(self, layer):
        """Save the class proportions to ``weights/{block_id}_{layer}_gam.npy``."""
        weight_dir = os.path.join(self.arch.model_dir, "weights")
        os.makedirs(weight_dir, exist_ok=True)
        save_path = os.path.join(weight_dir, f"{self.block_id}_{layer}_gam.npy")
        np.save(save_path, self.gam)

    def load_gam(self, layer):
        """Load the class proportions saved by :meth:`save_gam` for ``layer``."""
        weight_dir = os.path.join(self.arch.model_dir, "weights")
        save_path = os.path.join(weight_dir, f"{self.block_id}_{layer}_gam.npy")
        self.gam = np.load(save_path)
        return self.gam




#######
#layers=10
import argparse
import os
import time

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt



# hyperparameters
parser = argparse.ArgumentParser()
parser.add_argument('--layers', type=int, default=10, help="number of layers")
parser.add_argument('--eta', type=float, default=0.2, help='learning rate')
parser.add_argument('--eps', type=float, default=0.1, help='eps squared')
parser.add_argument('--tail', type=str, default='',
                    help='extra information to add to folder name')
parser.add_argument('--save_dir', type=str, default='./saved_models/',
                    help='base directory for saving PyTorch model. (default: ./saved_models/)')
parser.add_argument('--data_dir', type=str, default='./data/',
                    help='base directory for loading data. (default: ./data/)')
# Parse an empty list so the cell does not try to parse the notebook kernel's
# own command line; change the defaults above to change the settings.
args = parser.parse_args(args=[])

# pipeline setup: results go to <save_dir>/multi-ReduNet-LastNorm/layers{L}_eps{eps}_eta{eta}{tail}
model_dir = os.path.join(args.save_dir, "multi-ReduNet-LastNorm",
                         "layers{}_eps{}_eta{}{}".format(args.layers, args.eps,
                                                         args.eta, args.tail))
os.makedirs(model_dir, exist_ok=True)
save_params(model_dir, vars(args))
print(model_dir)

# model setup
layers = [Vector(args.layers, eta=args.eta, eps=args.eps)]
model = Architecture(layers, model_dir, num_classes)

# train/test pass (only the training pass is timed)
print("Forward pass - train features")
start_time = time.time()
Z_train = model(X_train, y_train)
end_time = time.time()
save_loss(model.loss_dict, model_dir, "train")

print("Forward pass - test features")
Z_test = model(X_test)
save_loss(model.loss_dict, model_dir, "test")

# save features
save_features(model_dir, "X_train", X_train, y_train)
save_features(model_dir, "X_test", X_test, y_test)
save_features(model_dir, "Z_train", Z_train, y_train)
save_features(model_dir, "Z_test", Z_test, y_test)

# evaluation: the sklearn/torch classifiers below need host (NumPy) arrays
Z_train = cp.asnumpy(Z_train)
y_train = cp.asnumpy(y_train)
Z_test = cp.asnumpy(Z_test)
y_test = cp.asnumpy(y_test)

_, acc_svm = svm(Z_train, y_train, Z_test, y_test)
acc_knn = knn(Z_train, y_train, Z_test, y_test, k=5)
acc_svd = nearsub(Z_train, y_train, Z_test, y_test, n_comp=5)
acc = {"svm": acc_svm, "knn": acc_knn, "nearsub-svd": acc_svd}
save_params(model_dir, acc, name="acc_test.json")

elapsed_time = end_time - start_time
print(f"Model execution time: {elapsed_time:.4f} seconds")


In [None]:
import argparse
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from torch.nn.functional import normalize
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.axes_grid1 import make_axes_locatable
from sklearn.decomposition import TruncatedSVD, PCA


def plot_heatmap(features, labels, title, model_dir):
    """Plot a heatmap of absolute pairwise inner products of the features, sorted by class.

    The inner products equal cosine similarities when the rows are unit-norm.
    Classes are taken as ``0..K-1``, where K is the number of distinct labels.
    The figure is saved to ``model_dir/figures/heatmaps/heatmap-{title}.pdf``.
    """
    num_samples = features.shape[0]
    classes = np.arange(np.unique(labels).size)
    features_sort_, _ = sort_dataset(features, labels,
                                     classes=classes, stack=True)
    sim_mat = np.abs(features_sort_ @ features_sort_.T)
    print(sim_mat.min(), sim_mat.max())

    plt.rcParams['font.family'] = 'serif'
    plt.rcParams['font.serif'] = ['Times New Roman']
    fig, ax = plt.subplots(figsize=(8, 7), sharey=True, sharex=True)
    im = ax.imshow(sim_mat, cmap='Blues')
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.1)
    cbar = fig.colorbar(im, cax=cax, drawedges=0, ticks=[0, 0.5, 1])
    cbar.ax.tick_params(labelsize=18)
    ax.set_xticks(np.linspace(0, num_samples, len(classes)+1))
    ax.set_yticks(np.linspace(0, num_samples, len(classes)+1))
    # Tick.label was removed in Matplotlib 3.8; label1 is the same text object.
    for tick in ax.xaxis.get_major_ticks() + ax.yaxis.get_major_ticks():
        tick.label1.set_fontsize(24)
    fig.tight_layout()

    save_dir = os.path.join(model_dir, "figures", "heatmaps")
    os.makedirs(save_dir, exist_ok=True)
    fig.savefig(os.path.join(save_dir, f"heatmap-{title}.pdf"))
    plt.close(fig)

def plot_combined_loss(model_dir, update=None):
    """Plot the layer-wise train and test losses (Delta R, R, R_c) on one figure.

    Reads ``model_dir/loss/train.csv`` and ``model_dir/loss/test.csv`` (columns
    ``loss_total``, ``loss_expd``, ``loss_comp``, as written by ``save_loss``)
    and saves ``model_dir/figures/loss/loss-traintest.pdf``.

    Parameters:
        model_dir (str): model directory that contains the ``loss`` folder.
        update: unused; kept for call compatibility.
    """
    plt.rcParams['font.family'] = 'serif'
    plt.rcParams['font.serif'] = ['Times New Roman']
    fig, ax = plt.subplots(1, 1, figsize=(7, 5), sharey=True, sharex=True)
    models = ['train', 'test']
    linestyles = ['solid', 'dashed']
    markers = ['o', 'D']
    markersizes = [4.5, 3]
    alphas = [0.5, 0.9]
    # Raw strings: '\D' in a plain literal is an invalid escape sequence
    # (SyntaxWarning on recent Python); the rendered text is unchanged.
    names = [r'$\Delta R$ ', r'$R$', r'$R_c$']
    colors = ['green', 'royalblue', 'coral']
    columns = ['loss_total', 'loss_expd', 'loss_comp']
    for model, linestyle, marker, alpha, markersize in zip(models, linestyles, markers, alphas, markersizes):
        data = pd.read_csv(os.path.join(model_dir, "loss", f'{model}.csv'))
        for column, name, color in zip(columns, names, colors):
            # Series.to_numpy replaces Series.ravel, which pandas 2.2 deprecated.
            loss = data[column].to_numpy()
            ax.plot(np.arange(loss.size), loss, label=r'{} ({})'.format(name, model),
                    color=color, linewidth=1.5, alpha=alpha, linestyle=linestyle,
                    marker=marker, markersize=markersize, markevery=5, markeredgecolor='black')
    ax.set_ylabel('Loss', fontsize=40)
    ax.set_xlabel('Layers', fontsize=40)
    # Dataset-specific axis presets:
    # ax.set_ylim((-0.05, 2.8)); ax.set_yticks(np.linspace(0, 2.5, 6))      # gaussian2d
    # ax.set_ylim((0, 4.0)); ax.set_yticks(np.linspace(0, 4.0, 9))          # gaussian3d
    # ax.set_ylim((-0.005, 0.075)); ax.set_yticks(np.linspace(0, 0.075, 6)) # mnist_rotation_classes01
    # ax.set_ylim((-0.02, 0.1)); ax.set_yticks(np.linspace(0, 0.1, 5))     # sinusoid
    handles, labels = ax.get_legend_handles_labels()
    # Entries 0-2 are train curves and 3-5 test curves; pair each with its counterpart.
    order = [0, 3, 1, 4, 2, 5]
    ax.legend([handles[i] for i in order], [labels[i] for i in order],
              loc='lower right', prop={"size": 13}, ncol=3, framealpha=0.5)
    # Tick.label was removed in Matplotlib 3.8; label1 is the same text object.
    for tick in ax.xaxis.get_major_ticks() + ax.yaxis.get_major_ticks():
        tick.label1.set_fontsize(22)
    fig.tight_layout()

    save_dir = os.path.join(model_dir, 'figures', 'loss')
    os.makedirs(save_dir, exist_ok=True)
    fig.savefig(os.path.join(save_dir, 'loss-traintest.pdf'), dpi=200)
    plt.close(fig)

def plot_2d(Z, y, name, model_dir):
    """Scatter-plot 2-D features colored by class, with an arrow to one sample per class.

    Saves ``model_dir/figures/2dscatter/scatter2d-{name}.pdf``. Assumes ``y``
    holds integer labels usable as indices into the 6-color palette below.
    """
    plot_dir = os.path.join(model_dir, "figures", "2dscatter")
    colors = np.array(['forestgreen', 'red', 'royalblue', 'purple', 'darkblue', 'orange'])
    os.makedirs(plot_dir, exist_ok=True)
    plt.rcParams['font.family'] = 'serif'
    plt.rcParams['font.serif'] = ['Times New Roman']
    fig, ax = plt.subplots(figsize=(6, 5), dpi=200)
    ax.scatter(Z[:, 0], Z[:, 1], c=colors[y], alpha=0.5)
    ax.scatter(0.0, 0.0, c='black', alpha=0.8, marker='s')
    ax.set_ylim(-1.2, 1.2)
    ax.set_xlim(-1.2, 1.2)
    ax.set_xticks([-1.0, -0.5, 0.0, 0.5, 1.0])
    ax.set_yticks([-1.0, -0.5, 0.0, 0.5, 1.0])
    ax.grid(linestyle=':')
    # One representative sample per class (get_n_each from the functionals
    # section; there is no module `F` in this notebook).
    representatives, _ = get_n_each(Z, y, 1)
    for point in representatives:
        ax.arrow(0, 0, point[0], point[1], head_width=0.03, head_length=0.05,
                 fc='k', ec='k', length_includes_head=True)
    # Tick.label was removed in Matplotlib 3.8; label1 is the same text object.
    for tick in ax.xaxis.get_major_ticks() + ax.yaxis.get_major_ticks():
        tick.label1.set_fontsize(24)
    fig.savefig(os.path.join(plot_dir, "scatter2d-" + name + ".pdf"), dpi=200)
    plt.close(fig)

def plot_3d(Z, y, name, model_dir):
    """Scatter-plot 3-D features over a unit-sphere wireframe, with one arrow per class.

    Saves ``model_dir/figures/3d/scatter3d-{name}.jpg``. Assumes ``y`` holds
    integer labels usable as indices into the 3-color palette below.
    """
    savedir = os.path.join(model_dir, 'figures', '3d')
    os.makedirs(savedir, exist_ok=True)
    plt.rcParams['font.family'] = 'serif'
    plt.rcParams['font.serif'] = ['Times New Roman']
    colors = np.array(['forestgreen', 'royalblue', 'brown'])
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111, projection='3d')
    # Colors are explicit, so no colormap is passed (it would be ignored).
    ax.scatter(Z[:, 0], Z[:, 1], Z[:, 2], c=colors[y], s=200.0)
    # Unit-length arrow toward one representative sample of each class
    # (get_n_each from the functionals section; there is no module `F`).
    representatives, _ = get_n_each(Z, y, 1)
    for point in representatives:
        ax.quiver(0.0, 0.0, 0.0, point[0], point[1], point[2], length=1.0,
                  normalize=True, arrow_length_ratio=0.05, color='black')
    # Reference unit sphere; separate names so the labels `y` are not shadowed.
    u, v = np.mgrid[0:2*np.pi:20j, 0:np.pi:10j]
    ax.plot_wireframe(np.cos(u) * np.sin(v), np.sin(u) * np.sin(v), np.cos(v),
                      color="gray", alpha=0.5)
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis._axinfo["grid"]['color'] = (0, 0, 0, 0.1)
        # Tick.label was removed in Matplotlib 3.8; label1 is the same text object.
        for tick in axis.get_major_ticks():
            tick.label1.set_fontsize(24)
    ax.view_init(20, 15)
    plt.tight_layout()
    fig.savefig(os.path.join(savedir, f"scatter3d-{name}.jpg"), dpi=200)
    plt.close(fig)

def plot_sample_angle_combined(train_features, train_labels, test_features, test_labels, model_dir, title1, title2, tail=""):
    """Histogram the inner products between training and test samples of *different* classes.

    The inner products equal cosine similarities when the rows are unit-norm.
    Values fall into 20 bins spanning [-0.05, 1.05], and values outside that
    range are not counted. The figure is saved to
    ``model_dir/figures/sample_angle_combined/sample_angle_combined-{title1}-vs-{title2}{tail}.pdf``.
    """
    save_dir = os.path.join(model_dir, "figures", "sample_angle_combined")
    os.makedirs(save_dir, exist_ok=True)

    _bins = np.linspace(-0.05, 1.05, 21)

    # Classes come from this call's own training labels.
    classes = np.unique(train_labels)
    fs_train, _ = sort_dataset(train_features, train_labels,
                               classes=classes, stack=False)
    fs_test, _ = sort_dataset(test_features, test_labels,
                              classes=classes, stack=False)
    # sort_dataset returns one array per entry of `classes`, so index by position.
    n_cls = len(classes)
    angles = [(fs_train[i] @ fs_test[k].T).reshape(-1)
              for i in range(n_cls)
              for k in range(n_cls)
              if i != k]

    plt.rcParams['font.family'] = 'serif'
    plt.rcParams['font.serif'] = ['Times New Roman']
    fig, ax = plt.subplots(figsize=(7, 5))
    ax.hist(np.hstack(angles), bins=_bins, alpha=0.5, color='red', edgecolor='black')
    ax.set_xlabel('Similarity', fontsize=38)
    ax.set_ylabel('Count', fontsize=38)
    ax.ticklabel_format(style='sci', scilimits=(0, 3))
    # Tick.label was removed in Matplotlib 3.8; label1 is the same text object.
    for tick in ax.xaxis.get_major_ticks() + ax.yaxis.get_major_ticks():
        tick.label1.set_fontsize(22)
    fig.tight_layout()
    fig.savefig(os.path.join(save_dir, f'sample_angle_combined-{title1}-vs-{title2}{tail}.pdf'))
    plt.close(fig)

# Draw the train/test loss curves that save_loss wrote during the forward passes.
plot_combined_loss(model_dir=model_dir)