In [49]:
from nltk import word_tokenize
import gensim.downloader as api
import torch

import plotly.graph_objects as go
import plotly.express as px

In [2]:
# Fetch the pretrained "glove-twitter-25" word vectors through gensim's downloader
# (the embedding lookups further down show each vector has 25 dimensions).
model = api.load(name="glove-twitter-25")

In [3]:
# Lower-case before tokenising so every token can be looked up in the
# (uncased) GloVe Twitter vocabulary loaded above.
sentence = "Your journey starts with one step"
tokens = word_tokenize(text=sentence.lower())
tokens

['your', 'journey', 'starts', 'with', 'one', 'step']

In [47]:
# Look up each token's GloVe vector and stack them row-wise into a
# (num_tokens, vector_size) tensor — (6, 25) for this sentence.
embeddings = torch.stack([torch.tensor(model[token]) for token in tokens])
embeddings.shape

torch.Size([6, 25])

In [None]:
# Unscaled dot-product self-attention with no learned projections: each token
# embedding acts as query, key and value at the same time.
# attention_scores[i, j] = dot(embeddings[i], embeddings[j]). The matrix
# product computes every pair at once and sizes the result from the number of
# tokens, instead of a hard-coded 6x6 buffer filled by a Python double loop.
attention_scores = embeddings @ embeddings.T

# Softmax over dim=1 normalises each row, so the weights that token i assigns
# to all tokens sum to 1.
attention_weights = torch.softmax(attention_scores, dim=1)
# Context vectors are these weights applied to the embeddings (the values),
# not to the query: context_vectors[i] = sum_j attention_weights[i, j] * embeddings[j].

In [90]:
# Row i of the result is the attention-weighted sum of all token embeddings,
# so it keeps the embedding shape (num_tokens, vector_size).
context_vectors = attention_weights @ embeddings
context_vectors.shape

torch.Size([6, 25])

In [96]:
# Visualise a 3-D slice of the word vectors: dimensions from_, from_+1 and
# from_+2. Blue markers (with arrows from the origin) are the raw token
# embeddings; red markers are the attention context vectors of the same tokens.
from_ = 2  # first embedding dimension shown; the slice is from_:from_+3
assert 0 <= from_ <= embeddings.shape[1] - 3, (
    f"from_={from_} leaves fewer than 3 dimensions to plot"
)

# .T turns each (num_tokens, 3) slice into three per-axis rows for unpacking.
x_points, y_points, z_points = embeddings[:, from_:from_ + 3].T
xc_points, yc_points, zc_points = context_vectors[:, from_:from_ + 3].T


def _token_trace(xs, ys, zs, color, name):
    """Build a labelled 3-D marker trace with one point per token.

    Shared by the embedding and context-vector traces, which differ only in
    their coordinates, colour and legend name.
    """
    return go.Scatter3d(
        x=xs,
        y=ys,
        z=zs,
        mode='markers+text',
        marker={'size': 8, 'opacity': 0.8, 'color': color},
        text=tokens,
        name=name,
    )


points_trace = _token_trace(x_points, y_points, z_points, "blue", "embedding")
context_trace = _token_trace(xc_points, yc_points, zc_points, "red", "context vector")

layout = go.Layout(
    template='plotly_dark',
    title="Token embeddings vs. context vectors",
    margin={"pad": 0, "l": 0, "r": 0, "t": 30, "b": 10},
    scene=dict(
        camera=dict(center=dict(x=0, y=0, z=0)),
        # Name the axes after the embedding dimensions actually plotted.
        xaxis=dict(title=f"dim {from_}"),
        yaxis=dict(title=f"dim {from_ + 1}"),
        zaxis=dict(title=f"dim {from_ + 2}"),
    ),
)

# One origin-to-point line per token embedding; `line` is the property that
# styles a lines-only trace.
arrow_traces = [
    go.Scatter3d(
        x=[0, float(x)],
        y=[0, float(y)],
        z=[0, float(z)],
        mode='lines',
        showlegend=False,
        line=dict(color="blue"),
    )
    for x, y, z in zip(x_points, y_points, z_points)
]

fig = go.Figure([points_trace, context_trace] + arrow_traces, layout=layout)
fig.show()


In [79]:

# Heatmap of the row-normalised attention weights: row i is the token whose
# context vector is being built, column j is the token it attends to, and
# each row sums to 1 (softmax over dim=1). These are weights, not raw
# dot-product similarities, so the colour bar is labelled accordingly.
fig = px.imshow(
    attention_weights,
    labels=dict(x="Key token", y="Query token", color="Attention weight"),
    x=tokens,
    y=tokens,
    title="Self-attention weights",
)
fig.show()