In [1]:
%load_ext autoreload
%autoreload 2

import warnings
warnings.filterwarnings('ignore')

import matplotlib.pyplot as plt
%matplotlib inline

import numpy as np

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader

import datetime
import copy
from IPython.display import clear_output
import youtokentome as yttm

import simcube
from simcube.data import tokenize_corpus, build_vocabulary, \
    save_texts_to_file, LanguageModelDataset, load_war_and_piece_chunks, \
    GreedyGenerator, BeamGenerator
from simcube.pipeline import train_eval_loop, init_random_seed
from simcube.base import get_params_number

init_random_seed()

plt.rcParams["figure.figsize"] = (15,10)

In [None]:
# Init signature:
# nn.Embedding(
#     num_embeddings,
#     embedding_dim,
#     padding_idx=None,
#     max_norm=None,
#     norm_type=2.0,
#     scale_grad_by_freq=False,
#     sparse=False,
#     _weight=None,
# )
# Docstring:     
# A simple lookup table that stores embeddings of a fixed dictionary and size.

# This module is often used to store word embeddings and retrieve them using indices.
# The input to the module is a list of indices, and the output is the corresponding
# word embeddings.

nn.Embedding()

In [2]:
import numpy as np

In [4]:
b = np.array([0.1, 0.5, 0.3, 0.1])

In [5]:
a = np.array([[1, 0, 1, 0], [0, 1, 3, 0], [2, 3, 0, 0]])

In [6]:
a.dot(b)

array([0.4, 1.4, 1.7])

In [8]:
a = np.array([[1, 0, 1, 0], [0, 1, 3, 0], [2, 3, 0, 0]])

In [9]:
q = np.array([0, 0, 1])

In [10]:
q.dot(a)

array([2, 3, 0, 0])

In [11]:
from scipy.special import softmax

In [12]:
softmax(q.dot(a))

array([0.25069239, 0.68145256, 0.03392753, 0.03392753])

In [13]:
a.dot(softmax(q.dot(a)))

array([0.28461991, 0.78323514, 2.54574246])

In [16]:
a

array([[1, 0, 1, 0],
       [0, 1, 3, 0],
       [2, 3, 0, 0]])

In [18]:
a.dot(a.T)

array([[ 2,  3,  2],
       [ 3, 10,  3],
       [ 2,  3, 13]])

In [25]:
softmax(a.dot(a.T), axis=1)

array([[2.11941558e-01, 5.76116885e-01, 2.11941558e-01],
       [9.10221936e-04, 9.98179556e-01, 9.10221936e-04],
       [1.67006637e-05, 4.53971105e-05, 9.99937902e-01]])

In [23]:
a.dot(softmax(a.dot(a.T), axis=0))

array([[2.11974959e-01, 5.76207679e-01, 2.21181736e+00],
       [9.60323927e-04, 9.98315747e-01, 3.00072393e+00],
       [2.14672223e-01, 3.57065555e+00, 2.14672223e-01],
       [0.00000000e+00, 0.00000000e+00, 0.00000000e+00]])

In [26]:
b = np.array([[1, 0, 1, 0], [0, 1, 1, 0]])

In [30]:
logits = b.T.dot(b)
logits

array([[1, 0, 1, 0],
       [0, 1, 1, 0],
       [1, 1, 2, 0],
       [0, 0, 0, 0]])

In [34]:
b.dot(softmax(logits, axis=0)).T

array([[0.73105858, 0.5       ],
       [0.5       , 0.73105858],
       [0.73105858, 0.73105858],
       [0.5       , 0.5       ]])

In [37]:
projk = np.array([[1, 0], [0, 0]])
projq = np.array([[0, 1], [0, 0]])
projv = np.array([[1, 0], [0, 1]])

In [39]:
bais = np.array([[0, ], [0, ]])

In [83]:
inp = np.array([[1, 0, 1, 0], [0, 1, 1, 0]])
inp.T

array([[1, 0],
       [0, 1],
       [1, 1],
       [0, 0]])

In [57]:
K = inp.T.dot(projk.T)
Q = inp.T.dot(projq.T)
V = inp.T.dot(projv.T)

In [58]:
K

array([[1, 0],
       [0, 0],
       [1, 0],
       [0, 0]])

In [59]:
Q

array([[0, 0],
       [1, 0],
       [1, 0],
       [0, 0]])

In [60]:
V

array([[1, 0],
       [0, 1],
       [1, 1],
       [0, 0]])

In [62]:
logits = Q.dot(K.T)
logits

array([[0, 0, 0, 0],
       [1, 0, 1, 0],
       [1, 0, 1, 0],
       [0, 0, 0, 0]])

In [80]:
attscore = softmax(logits, axis=1)
attscore

array([[0.25      , 0.25      , 0.25      , 0.25      ],
       [0.36552929, 0.13447071, 0.36552929, 0.13447071],
       [0.36552929, 0.13447071, 0.36552929, 0.13447071],
       [0.25      , 0.25      , 0.25      , 0.25      ]])

In [81]:
attscore.dot(V)

array([[0.5       , 0.5       ],
       [0.73105858, 0.5       ],
       [0.73105858, 0.5       ],
       [0.5       , 0.5       ]])

In [88]:
inp = np.array([[1, 0, 1, 0], [0, 1, 1, 0]])
inp[0], inp[1]

(array([1, 0, 1, 0]), array([0, 1, 1, 0]))

In [93]:
emb1 = np.expand_dims(inp[0], axis=0)
emb2 = np.expand_dims(inp[1], axis=0)
emb1, emb2

(array([[1, 0, 1, 0]]), array([[0, 1, 1, 0]]))

In [86]:
projk1 = np.array([[1, 0], [0, 0]])
projk2 = np.array([[0, 1], [0, 0]])

projq1 = np.array([[0, 1], [1, 0]])
projq2 = np.array([[1, 1], [1, 1]])

projv1 = np.array([[1,], [0,]])
projv2 = np.array([[0,], [1,]])

In [98]:
K1 = inp.T.dot(projk1.T)
Q1 = inp.T.dot(projq1.T)

In [99]:
K2 = inp.T.dot(projk2.T)
Q2 = inp.T.dot(projq2.T)

In [101]:
V1 = emb1.T.dot(projv1.T)
V2 = emb2.T.dot(projv2.T)

In [102]:
logits1 = Q1.dot(K1.T)
logits2 = Q2.dot(K2.T)

In [103]:
logits1, logits2

(array([[0, 0, 0, 0],
        [1, 0, 1, 0],
        [1, 0, 1, 0],
        [0, 0, 0, 0]]),
 array([[0, 1, 1, 0],
        [0, 1, 1, 0],
        [0, 2, 2, 0],
        [0, 0, 0, 0]]))

In [104]:
attscore1 = softmax(logits1, axis=1)
attscore2 = softmax(logits2, axis=1)

In [105]:
attscore1.dot(V1)

array([[0.5       , 0.        ],
       [0.73105858, 0.        ],
       [0.73105858, 0.        ],
       [0.5       , 0.        ]])

In [106]:
attscore2.dot(V2)

array([[0.        , 0.73105858],
       [0.        , 0.73105858],
       [0.        , 0.88079708],
       [0.        , 0.5       ]])

In [107]:
attscore1.dot(V1) + attscore2.dot(V2)

array([[0.5       , 0.73105858],
       [0.73105858, 0.73105858],
       [0.73105858, 0.88079708],
       [0.5       , 0.5       ]])