In [1]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.manifold import TSNE
from AMPpred_MFA.lib.Encoding import easy_encoding
from AMPpred_MFA.lib.Vocab import load_vocab
from AMPpred_MFA.lib.Data import load_dataset
from AMPpred_MFA.models.Model import load_model
from AMPpred_MFA.models.AMPpred_MFA import Model, Config
import matplotlib as mpl

# Global matplotlib text styling shared by every figure produced below.
mpl.rcParams.update({'font.size': 16, 'font.family': 'Times New Roman'})

# Torch device string: 'cuda' when PyTorch reports a usable GPU, else 'cpu'.
# NOTE(review): nothing in the visible cells reads `device` afterwards.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Directory the feature-visualization figures are written to (created on demand).
save_path_feature_visual = '../figures/feature visualization/1_trial'
os.makedirs(save_path_feature_visual, exist_ok=True)

In [2]:
# Locations of the trained artifacts, relative to the notebook's working directory.
vocab_path = './trained_model/vocab.json'
model_path = './trained_model/model.pth'
vocab = load_vocab(vocab_path)

# Rebuild the model configuration before loading the weights.
# NOTE(review): these values presumably have to match the ones used at
# training time for the checkpoint to load correctly — confirm.
config = Config()
config.k_mer = 1
config.batch_size = 32
config.embed_padding_idx = vocab[config.padding_token]
config.feature_dim = 400
config.vocab_size = len(vocab)

# Instantiate the network, restore the trained weights, and switch to eval mode.
model = Model(config)
load_model(model, model_path)
model.eval()

# Only the first 3000 records of the training split are visualized.
# The last column of each record is the class label; the remaining columns
# are passed to the encoder.
dataset = load_dataset('../dataset/train/1_trial/train.fasta')[:3000]
fastas = dataset[:, :-1]
labels = dataset[:, -1].astype(np.int64)
data_x = easy_encoding(fastas, 'mixed', vocab,
                        config.k_mer, config.padding_size)

print('Predicting dataset...')
# Inference only: disable autograd so no computation graph is recorded for
# the whole subset at once. The forward pass is deliberately NOT split into
# mini-batches, because the plotting cell reads intermediate activations
# (e.g. `model.last_feature`) that the model keeps from this single call.
with torch.no_grad():
    out = model(data_x)
print('Dataset has been predicted.')

[0;36mLoading vocabulary...[0m
[1;32mVocabulary has been loaded.[0m
[0;36mLoading model...[0m
[1;32mModel has been loaded.[0m
Predicting dataset...
Dataset has been predicted.


In [3]:
# Shared scatter-plot styling and the t-SNE projector used for every figure.
s_size = 26        # marker area passed to `Axes.scatter`
alpha = 0.8        # marker opacity
figsize = (6, 6)   # figure size in inches
# A fixed integer random_state makes each `fit_transform` call reproducible.
tsne = TSNE(n_components=2, random_state=0)


def plot_tsne_embedding(features, labels, file_name):
    """Project one set of features to 2-D with t-SNE and save a class scatter plot.

    Uses the module-level ``tsne``, ``figsize``, ``alpha``, ``s_size`` and
    ``save_path_feature_visual`` settings.

    Args:
        features: torch tensor whose first dimension indexes the samples.
            It is moved to the CPU, detached, and flattened to
            ``(n_samples, -1)`` before projection (a no-op for 2-D input).
        labels: 1-D integer array aligned with ``features``; samples with
            label 1 are drawn as 'AMPs' (blue) and label 0 as 'Non-AMPs' (red).
        file_name: name of the PNG file created inside
            ``save_path_feature_visual``.

    Returns:
        The path of the saved figure.
    """
    flat_features = features.cpu().detach().numpy().reshape(labels.shape[0], -1)
    embedding = tsne.fit_transform(flat_features)

    fig, ax = plt.subplots(1, 1, figsize=figsize)
    # Positives are drawn first, then negatives, so the overlay order and the
    # legend order stay fixed across all figures.
    for class_value, color, class_name in ((1, 'b', 'AMPs'), (0, 'r', 'Non-AMPs')):
        points = embedding[labels == class_value]
        ax.scatter(points[:, 0],
                   points[:, 1],
                   c=color,
                   label=class_name,
                   alpha=alpha,
                   s=s_size,
                   edgecolor="none")
    ax.legend()
    fig.tight_layout()

    save_path = os.path.join(save_path_feature_visual, file_name)
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    print('已保存至：', save_path)
    # Close this specific figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    return save_path


# Activations captured by the model during the forward pass, paired with the
# output file name for each stage; figures are produced in this order.
feature_stages = [
    (model.attention_wight2_inputs, 'Before attention of fragment-level.png'),
    (model.attention_wight2_outputs, 'After attention of fragment-level.png'),
    (model.attention_wight1_inputs, 'Before attention of dipeptide-level.png'),
    (model.attention_wight1_outputs, 'After attention of dipeptide-level.png'),
    (model.last_feature, 'After concat.png'),
    (out, 'After dense.png'),
]

for stage_features, stage_file_name in feature_stages:
    plot_tsne_embedding(stage_features, labels, stage_file_name)


已保存至： ../figures/feature visualization/1_trial\Before attention of fragment-level.png
已保存至： ../figures/feature visualization/1_trial\After attention of fragment-level.png
已保存至： ../figures/feature visualization/1_trial\Before attention of dipeptide-level.png
已保存至： ../figures/feature visualization/1_trial\After attention of dipeptide-level.png
已保存至： ../figures/feature visualization/1_trial\After concat.png
已保存至： ../figures/feature visualization/1_trial\After dense.png
