In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
from sklearn.manifold import TSNE


class DataPlot:
    """Draw a 2-D t-SNE scatter plot of embeddings previously dumped to disk.

    The embeddings file (``<path><dataset>_embeddings.txt``) holds one row per
    point: an integer class label followed by comma-separated float values.
    The first ``nclass`` rows are drawn as prototypes (large stars), the next
    ``n_train`` rows as training points (circles), and any remaining rows
    (e.g. test points) as crosses.
    """

    # One colour per class label; labels beyond the palette wrap around.
    COLOR_PLATE = ['black', 'red', 'rosybrown', 'tan', 'grey',
                   'gold', 'olivedrab', 'chartreuse', 'darkgreen', 'deepskyblue',
                   'royalblue', 'navy', 'darkorchid', 'm', 'skyblue',
                   'slateblue', 'y', 'purple', 'tomato', 'gainsboro',
                   'royalblue', 'navy', 'darkorchid', 'm', 'skyblue']

    def __init__(self):
        self.init_plotting()

    def init_plotting(self):
        """Apply the figure style (size, fonts, ticks, axis width) globally via rcParams."""
        base_font = 15
        plt.rcParams.update({
            'figure.figsize': (6.5, 5.5),
            'font.size': base_font,
            'axes.labelsize': base_font,
            'axes.titlesize': 20,
            'legend.fontsize': 13,
            'xtick.labelsize': base_font,
            'ytick.labelsize': base_font,
            'xtick.major.size': 3,
            'xtick.minor.size': 3,
            'xtick.major.width': 1,
            'xtick.minor.width': 1,
            'ytick.major.size': 3,
            'ytick.minor.size': 3,
            'ytick.major.width': 1,
            'ytick.minor.width': 1,
            'axes.linewidth': 2,
        })

    @staticmethod
    def _load_embeddings(repr_file, dim=None):
        """Parse ``label,v1,v2,...`` lines into (float array of shape (n, dim), list of int labels)."""
        labels, rows = [], []
        with open(repr_file, "r") as f:
            for line_no, line in enumerate(f, start=1):
                line = line.strip()
                if not line:
                    continue  # tolerate blank lines, e.g. a trailing newline
                fields = line.split(',')
                values = [float(v) for v in fields[1:]]
                if dim is None:
                    dim = len(values)  # infer the width from the first row
                elif len(values) != dim:
                    raise ValueError("%s line %d: expected %d values, got %d"
                                     % (repr_file, line_no, dim, len(values)))
                labels.append(int(fields[0]))
                rows.append(values)
        # Build the array once instead of np.append per row (which is O(n^2)).
        X = np.array(rows, dtype=float) if rows else np.empty((0, dim or 0), float)
        return X, labels

    def tnse_plot(self, path, ds, perplexity=30.0, random_state=None):
        """Project saved embeddings to 2-D with t-SNE and save a scatter plot.

        Reads ``path + dataset + '_embeddings.txt'`` and writes
        ``path + dataset + '_embed.pdf'``.

        Args:
            path: directory prefix, including the trailing separator.
            ds: ``(dataset, nclass, n_train)`` or ``(dataset, nclass, n_train, D)``.
                If ``D`` is given, every row must have exactly ``D`` values.
                Otherwise the width is taken from the file.
            perplexity: t-SNE perplexity. It is clipped to ``n_samples - 1``
                because scikit-learn requires it to be below the sample count.
            random_state: forwarded to ``TSNE``; None leaves it unseeded.

        Returns:
            The matplotlib Figure that was saved.

        Raises:
            ValueError: on a malformed row, a width mismatch with ``D``, or
                fewer than two embeddings.
        """
        dataset, nclass, n_train = ds[:3]
        dim = ds[3] if len(ds) > 3 else None

        X, labels = self._load_embeddings(path + dataset + '_embeddings.txt', dim)
        n_samples = X.shape[0]
        if n_samples < 2:
            raise ValueError("t-SNE needs at least 2 embeddings, found %d" % n_samples)

        X_embedded = TSNE(n_components=2,
                          perplexity=min(perplexity, n_samples - 1),
                          random_state=random_state).fit_transform(X)

        colors = [self.COLOR_PLATE[label % len(self.COLOR_PLATE)] for label in labels]
        # (start, stop, marker style) for prototypes, training rows, remaining rows.
        groups = [
            (0, nclass,
             dict(s=150.0, marker='*', edgecolors='black', zorder=10)),
            (nclass, nclass + n_train,
             dict(s=20.0, marker='o', edgecolors='none', zorder=1)),
            (nclass + n_train, n_samples,
             dict(s=20.0, marker='x', edgecolors='black', zorder=1)),
        ]

        fig, ax = plt.subplots()
        for start, stop, style in groups:
            points = X_embedded[start:stop]
            if len(points) == 0:
                continue
            ax.scatter(points[:, 0], points[:, 1], c=colors[start:stop],
                       alpha=0.8, **style)

        ax.grid(True)
        print("drawing ...")
        fig.savefig(path + dataset + "_embed.pdf")
        return fig



In [None]:
# Standalone usage example for DataPlot, disabled by default: it expects
# ./plot/BasicMotions_embeddings.txt to already exist (as written by
# dump_embedding). Set RUN_TSNE_EXAMPLE = True to run it.
RUN_TSNE_EXAMPLE = False
if RUN_TSNE_EXAMPLE:
    example_path = "./plot/"
    # TODO: set to the number of feature values per row in the embeddings file.
    example_embed_dim = 128
    # (name_dataset, nbr_class, train_size, embedding_dim)
    example_dataset = ('BasicMotions', 4, 40, example_embed_dim)
    DataPlot().tnse_plot(example_path, example_dataset)

In [None]:
# Class prototypes: the mean training embedding of each class, one row per
# class (n_classes * L). A class with no training samples yields a row of NaN
# (numpy warns "Mean of empty slice").
proto_list = [np.mean(h_train[np.where(y_train == c)[0]], axis=0)
              for c in range(n_classes)]
h_center = np.array(proto_list)

def dump_embedding(proto_embed, sample_embed, labels, dump_file='./plot/embeddings.txt'):
    """Write class prototypes followed by sample embeddings to a text file.

    Each output line is ``<int label>,<v1>,<v2>,...`` with values formatted to
    4 decimals. Prototype row ``i`` is written with label ``i``. This is the
    format that ``DataPlot.tnse_plot`` parses.

    Args:
        proto_embed: 2-D array-like, one row per class.
        sample_embed: 2-D array-like, one row per sample (same width as proto_embed).
        labels: one class label per sample row. A column vector of shape
            (n_samples, 1) is also accepted.
        dump_file: output path. Missing parent directories are created.

    Raises:
        ValueError: if the number of labels differs from the number of sample rows.
    """
    embed = np.concatenate((proto_embed, sample_embed), axis=0)

    nclass = np.shape(proto_embed)[0]
    sample_labels = np.asarray(labels).reshape(-1)
    n_samples = np.shape(sample_embed)[0]
    if len(sample_labels) != n_samples:
        raise ValueError("got %d labels for %d sample embeddings"
                         % (len(sample_labels), n_samples))
    # Prototype i is labelled with its class index i.
    all_labels = np.concatenate((np.arange(nclass), sample_labels), axis=0)

    parent = os.path.dirname(dump_file)
    if parent:
        os.makedirs(parent, exist_ok=True)  # e.g. ./plot/ may not exist yet

    with open(dump_file, 'w') as f:
        for label, row in zip(all_labels, embed):
            values = ",".join("%.4f" % v for v in row.tolist())
            # int(float(...)) accepts int, float and numeric-string labels alike.
            f.write("%d,%s\n" % (int(float(label)), values))

    print(len(embed))
    
# t-SNE plot
def tSNE_plot(style="train", d_prime=None):
    """Dump prototypes plus selected embeddings, then draw their t-SNE plot.

    Reads notebook globals: h_train, y_train, h_test, y_test, x_sup, h_center,
    ds_name, n_classes, and hidden_dim (only when d_prime is given).

    Args:
        style: 'train' plots the first x_sup.shape[0] training embeddings,
            'test' plots the test embeddings, and any other value plots the
            training embeddings followed by the test embeddings.
        d_prime: if given, the embedding width passed to DataPlot is
            hidden_dim * d_prime. If None, it is read from h_center.
    """
    if style == 'train':
        n_sup = x_sup.shape[0]
        data, label = h_train[:n_sup], y_train[:n_sup]
        n_train = len(data)
    elif style == 'test':
        data, label = h_test, y_test
        n_train = 0  # no training rows: every sample gets the non-train marker
    else:
        data = np.concatenate([h_train, h_test], axis=0)
        label = np.concatenate([y_train, y_test], axis=0)
        n_train = len(h_train)  # training rows come first, then test rows

    path = "./plot/"
    dump_embedding(h_center, data, label, dump_file=path + ds_name + '_embeddings.txt')

    embed_dim = np.shape(h_center)[1] if d_prime is None else hidden_dim * d_prime
    dataset = (ds_name, n_classes, n_train, embed_dim)  # (name, nbr_class, train_size, dim)
    data_plot = DataPlot()
    data_plot.tnse_plot(path, dataset)
    
    