In [1]:
# Load Tox21 Data
import numpy as np

tasks = ['NR-AR', 'NR-AR-LBD', 'NR-AhR', 'NR-Aromatase', 'NR-ER', 'NR-ER-LBD', 'NR-PPAR-gamma',
         'SR-ARE', 'SR-ATAD5', 'SR-HSE', 'SR-MMP', 'SR-p53']

# Directory holding the exported embeddings / labels (one .npz archive per file).
DATA_DIR = './tox21_embed_data/'


def _load_npz(name):
    """Open DATA_DIR/<name>.npz; arrays inside are read when accessed by key."""
    # allow_pickle=True is needed for object arrays (e.g. string ids) but can
    # execute arbitrary pickled code -- only load archives you produced yourself.
    return np.load(DATA_DIR + name + '.npz', allow_pickle=True)


ids = _load_npz('id')
label = _load_npz('label')

# File counts are derived from `tasks` instead of hard-coding 12 / 7 / 5, so the
# lists stay in sync with the task list.
n_nr_tasks = sum(t.startswith('NR-') for t in tasks)   # 7 NR-* tasks
n_sr_tasks = sum(t.startswith('SR-') for t in tasks)   # 5 SR-* tasks

# One 'single<i>' and one 'multi<i>' archive per task index i.
single_list = [_load_npz('single' + str(i)) for i in range(len(tasks))]
multi_list = [_load_npz('multi' + str(i)) for i in range(len(tasks))]
# Group archives: NR0..NR6 and SR0..SR4.
NR_list = [_load_npz('NR' + str(i)) for i in range(n_nr_tasks)]
SR_list = [_load_npz('SR' + str(i)) for i in range(n_sr_tasks)]

In [2]:
# PCA
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA

def get_pca3d(data):
    """Standardise ``data`` and project it onto its first three principal components.

    Side effect: prints the explained-variance ratio of the three components.

    Returns
    -------
    tuple of three 1-D arrays
        (pc1, pc2, pc3), each with one entry per row of ``data``.
    """
    standardised = StandardScaler().fit_transform(data)
    model = PCA(n_components=3)
    projected = model.fit_transform(standardised)
    print(model.explained_variance_ratio_)
    # projected is (n_samples, 3); its transpose unpacks into the three columns.
    first, second, third = projected.T
    return first, second, third

In [3]:
# Data Frame
import pandas as pd

def make_PCA_dataframe(cid, feature, label, domain_name):
    """Build a DataFrame of 3-D PCA coordinates for one data split.

    Columns, in order: id, pc1, pc2, pc3, label, domain. ``domain`` is the
    constant ``domain_name`` repeated for every row. ``cid``, the rows of
    ``feature`` and ``label`` are expected to be aligned, one entry per sample.
    """
    components = get_pca3d(feature)

    columns = {'id': cid}
    columns.update(zip(('pc1', 'pc2', 'pc3'), components))
    columns['label'] = label
    columns['domain'] = [domain_name] * len(cid)
    return pd.DataFrame(columns)

In [4]:
# Visualize
# %matplotlib inline
%matplotlib notebook
import matplotlib.pyplot as plt
from matplotlib import animation

def draw_anim(ids, feature, label, idx):
    """Build a rotating 3-D PCA scatter animation of one task's embeddings.

    The figure has one 3-D subplot per split ('train', 'valid', 'test').
    Points are coloured by column ``idx`` of the label matrix: red for 0 and
    blue for 1. Rows with any other label value are not drawn.

    All splits are projected with ONE StandardScaler + PCA fitted on the
    training embeddings. The three panels therefore share a coordinate system,
    and the common axis limits are meaningful. Independent per-split fits would
    give each panel its own, incomparable axes.

    Side effect: prints the explained-variance ratio of the fitted PCA.

    Parameters
    ----------
    ids : mapping with keys 'train', 'valid', 'test' -> per-sample ids
    feature : mapping with the same keys -> 2-D arrays (samples x features)
    label : mapping with the same keys -> 2-D arrays; column ``idx`` is plotted
    idx : int, label column (task index) used for colouring

    Returns
    -------
    matplotlib.animation.FuncAnimation
        360 frames, azimuth 0..359 degrees at elevation 30.
    """
    titles = ['train', 'valid', 'test']
    colors = ['r', 'b']

    # Fit the projection once, on train only, and reuse it for every split.
    scaler = StandardScaler().fit(feature['train'])
    pca = PCA(n_components=3).fit(scaler.transform(feature['train']))
    print(pca.explained_variance_ratio_)

    frames = []
    for split in titles:
        pcs = pca.transform(scaler.transform(feature[split]))
        frames.append(pd.DataFrame({'id': ids[split],
                                    'pc1': pcs[:, 0],
                                    'pc2': pcs[:, 1],
                                    'pc3': pcs[:, 2],
                                    'label': label[split][:, idx],
                                    'domain': split}))
    df = pd.concat(frames, ignore_index=True)
    # Shared limits so every panel is drawn on the same scale.
    lims = {c: (df[c].min(), df[c].max()) for c in ('pc1', 'pc2', 'pc3')}

    fig, axs = plt.subplots(figsize=(5, 15),
                            dpi=300,
                            nrows=3,
                            subplot_kw={"projection": "3d"})

    for ax, split in zip(axs, titles):
        ax.set_title(split, fontsize=15)
        for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
            axis.set_tick_params(labelsize=10)
        ax.set_xlim(*lims['pc1'])
        ax.set_ylim(*lims['pc2'])
        ax.set_zlim(*lims['pc3'])

        # Draw the data once, here. init_func is meant to draw a clear frame and
        # may be re-run (e.g. by .save()), which would stack duplicate points.
        split_df = df[df['domain'] == split]
        for cls, color in enumerate(colors):
            pts = split_df[split_df['label'] == cls]
            ax.scatter(pts['pc1'], pts['pc2'], pts['pc3'], c=color, s=1, marker=".")

    def init():
        for ax in axs:
            ax.view_init(elev=30., azim=0)
        return tuple(axs)

    def animate(angle):
        for ax in axs:
            ax.view_init(elev=30., azim=angle)
        return tuple(axs)

    # blit=False: every frame moves the camera, so each Axes must be fully
    # redrawn. Blitting only repaints returned artists over a cached background.
    return animation.FuncAnimation(fig, animate, init_func=init, frames=360,
                                   interval=20, blit=False)

In [5]:
# Task idx: position in `tasks` of the assay to plot. With the current order of
# `tasks`, 0-6 are the NR-* tasks and 7-11 the SR-* tasks.
task_idx = 0
assert 0 <= task_idx < len(tasks), 'task_idx must index into tasks'

In [8]:
# Single
# ani_single = draw_anim(ids, single_list[task_idx], label, task_idx)
# ani_single.save('./result/single_task_'+str(task_idx)+'.gif', fps=30)

In [10]:
# Multi
# ani_multi = draw_anim(ids, multi_list[task_idx], label, task_idx)
# ani_multi.save('./result/multi_task_'+str(task_idx)+'.gif', fps=30)

In [12]:
# NR / SR
# NR_list has one archive per NR-* task (indices 0-6) and SR_list one per SR-*
# task (indices 7-11). An SR task's archive is SR_list[task_idx - 7], but its
# label column is still task_idx. Use a separate local index rather than
# `task_idx -= 7`, which coloured by the wrong task's labels and also changed
# task_idx for every later cell.
# if task_idx < 7:
#     ani_NR = draw_anim(ids, NR_list[task_idx], label, task_idx)
#     ani_NR.save('./result/NR_task_'+str(task_idx)+'.gif', fps=30)
# else:
#     sr_idx = task_idx - 7
#     ani_SR = draw_anim(ids, SR_list[sr_idx], label, task_idx)
#     ani_SR.save('./result/SR_task_'+str(sr_idx)+'.gif', fps=30)

In [14]:
# Visualize
# %matplotlib inline
%matplotlib notebook
import matplotlib.pyplot as plt
from matplotlib import animation

def draw(ids, feature, label, idx, angle):
    """Draw a static 3-D PCA scatter of one task's embeddings, one panel per split.

    The figure has one 3-D subplot per split ('train', 'valid', 'test'), all
    viewed from elevation 30 and azimuth ``angle``. Points are coloured by
    column ``idx`` of the label matrix: red for 0 and blue for 1. Rows with any
    other label value are not drawn. The figure created here becomes pyplot's
    current figure, so it can be saved afterwards with ``plt.savefig``.

    All splits are projected with ONE StandardScaler + PCA fitted on the
    training embeddings. The panels therefore share a coordinate system, and
    the common axis limits are meaningful.

    Side effect: prints the explained-variance ratio of the fitted PCA.

    Parameters
    ----------
    ids : mapping with keys 'train', 'valid', 'test' -> per-sample ids
    feature : mapping with the same keys -> 2-D arrays (samples x features)
    label : mapping with the same keys -> 2-D arrays; column ``idx`` is plotted
    idx : int, label column (task index) used for colouring
    angle : float, camera azimuth in degrees
    """
    titles = ['train', 'valid', 'test']
    colors = ['r', 'b']

    # Fit the projection once, on train only, and reuse it for every split.
    scaler = StandardScaler().fit(feature['train'])
    pca = PCA(n_components=3).fit(scaler.transform(feature['train']))
    print(pca.explained_variance_ratio_)

    frames = []
    for split in titles:
        pcs = pca.transform(scaler.transform(feature[split]))
        frames.append(pd.DataFrame({'id': ids[split],
                                    'pc1': pcs[:, 0],
                                    'pc2': pcs[:, 1],
                                    'pc3': pcs[:, 2],
                                    'label': label[split][:, idx],
                                    'domain': split}))
    df = pd.concat(frames, ignore_index=True)
    # Shared limits so every panel is drawn on the same scale.
    lims = {c: (df[c].min(), df[c].max()) for c in ('pc1', 'pc2', 'pc3')}

    fig, axs = plt.subplots(figsize=(5, 15),
                            dpi=300,
                            nrows=3,
                            subplot_kw={"projection": "3d"})

    for ax, split in zip(axs, titles):
        ax.set_title(split, fontsize=15)
        for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
            axis.set_tick_params(labelsize=10)
        ax.set_xlim(*lims['pc1'])
        ax.set_ylim(*lims['pc2'])
        ax.set_zlim(*lims['pc3'])
        ax.view_init(elev=30., azim=angle)

        split_df = df[df['domain'] == split]
        for cls, color in enumerate(colors):
            pts = split_df[split_df['label'] == cls]
            ax.scatter(pts['pc1'], pts['pc2'], pts['pc3'], c=color, s=4, marker=".")

In [19]:
# draw(ids, single_list[task_idx], label, task_idx, 120)
# plt.savefig('./result/single_task_'+str(task_idx)+'.png', dpi=300)

In [21]:
# draw(ids, multi_list[task_idx], label, task_idx, 120)
# plt.savefig('./result/multi_task_'+str(task_idx)+'.png', dpi=300)

In [24]:
# NR / SR
# NR_list has one archive per NR-* task (indices 0-6) and SR_list one per SR-*
# task (indices 7-11). An SR task's archive is SR_list[task_idx - 7], but its
# label column is still task_idx. Use a separate local index rather than
# `task_idx -= 7`, which coloured by the wrong task's labels and also changed
# task_idx for every later cell.
# if task_idx < 7:
#     draw(ids, NR_list[task_idx], label, task_idx, 120)
#     plt.savefig('./result/NR_task_'+str(task_idx)+'.png', dpi=300)
# else:
#     sr_idx = task_idx - 7
#     draw(ids, SR_list[sr_idx], label, task_idx, 120)
#     plt.savefig('./result/SR_task_'+str(sr_idx)+'.png', dpi=300)

In [25]:
# NOTE(review): exit() ends the session. In a Jupyter kernel this presumably
# shuts the kernel down, so nothing loaded above can be inspected after
# "Run All". Confirm this is intended.
exit()