**Imports**

In [5]:
# Install a widget that will allow you to build up interactive plots
# !conda install -c conda-forge ipympl -y

# If using JupyterLab
# !conda install -c conda-forge nodejs -y
# !jupyter labextension install @jupyter-widgets/jupyterlab-manager jupyter-matplotlib

# After running these lines, close the Jupyter session and restart it

In [1]:
import pickle as pickle
from tensorflow.keras.models import load_model
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d, Axes3D
import pandas as pd
%matplotlib widget

**Load the model and training data**

In [2]:
# Load the saved Keras model from disk
model = load_model("model.h5")

# Read the pickled training set a single time; the `with` block guarantees
# the file handle is closed even if unpickling fails.
# NOTE(review): pickle.load can execute arbitrary code while deserializing --
# only run this against a trusted "training_data.dat".
with open("training_data.dat", "rb") as training_file:
    training_data = pickle.load(training_file)

# .get() returns None when a key is missing from the loaded mapping
labels = training_data.get("labels")
data = training_data.get("data")

**Perform PCA on the model's word-embedding layer weights to reduce them to 3 dimensions**

In [3]:
# PCA estimator that keeps three principal components
pca = PCA(n_components=3)

# Take the first weight array of the layer named "Embedding", fit the PCA
# on it and project it onto the three components
embedding_weights_pca = pca.fit_transform(
    pd.DataFrame(
        model.get_layer("Embedding").get_weights()[0]
    )
)

# Wrap the projected coordinates in a labelled frame and preview the first rows
embedding_weights = pd.DataFrame(
    embedding_weights_pca,
    columns=['PC1', 'PC2', 'PC3'],
)
embedding_weights.head()

Unnamed: 0,PC1,PC2,PC3
0,0.024804,-0.135827,0.137367
1,0.352741,0.211429,0.036154
2,-0.004227,0.112113,-0.104535
3,-0.0197,0.050421,-0.016598
4,-0.063853,-0.225383,-0.074789


**Plot the 3D representation of the word-embedding layer**

In [4]:
# Set up a 14x12 inch figure holding a single 3D axes
fig = plt.figure(figsize=(14, 12), dpi=80, facecolor='w', edgecolor='k')
ax = plt.axes(projection='3d')

# Walk the three principal-component columns row by row
coordinates = zip(
    embedding_weights['PC1'],
    embedding_weights['PC2'],
    embedding_weights['PC3'],
)

for index, (x, y, z) in enumerate(coordinates):
    label = labels[index]

    # Green for "good", red for "bad", blue for every other label
    color = "g" if label == "good" else "r" if label == "bad" else "b"

    # Draw the point, then annotate it with the matching entry of `data`
    ax.scatter(x, y, z, color=color, s=12)
    ax.text(x, y, z, data[index], size=12, zorder=2.5, color='k')

# Title and axis captions
ax.set_title("Word Embedding", fontsize=20)
ax.set_xlabel("PC1", fontsize=20)
ax.set_ylabel("PC2", fontsize=20)
ax.set_zlabel("PC3", fontsize=20)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …