In [2]:
import gensim.downloader as api
import numpy as np
import tensorflow as tf
from tensorflow import keras

In [3]:
wv = api.load('word2vec-google-news-300')

In [4]:
d_model = 300
n_head = 4
gen_results = 100

In [5]:
def tokenize(x):
    return x.split()

def embed(x):
    embeddings = wv[x]
    return embeddings

def posEncode(x):
    n_seq, _ = x.shape
    i = np.arange(0, d_model, 2, dtype='float32')
    denominator = np.power(10000, i/d_model)
    position = np.arange(0, n_seq, dtype='float32').reshape(-1,1)
    even_PE = np.sin(position / denominator)
    odd_PE  = np.cos(position / denominator)
    return x + np.ravel([even_PE.T, odd_PE.T],'F').reshape(n_seq, d_model)

def masking(x):
    mask = np.tril(np.ones((x.shape)))
    mask[mask==0] = -np.infty
    mask[mask==1] = 0
    x = x + mask
    return x

def init_QKV():
    #used once only
    Q, K, V = [],[],[]
    for h in range(n_head):
        Q.append(np.random.rand(d_model, d_model//n_head))
        K.append(np.random.rand(d_model, d_model//n_head))
        V.append(np.random.rand(d_model, d_model//n_head))
    return Q, K, V

def context(input, Q,K,V, mask=False):    
    context = []
    for h in range(n_head):
        raw_attention = (input @ Q[h]) @ (input @ K[h]).T
        if (mask == True):
            raw_attention = masking(raw_attention)
        score = tf.nn.softmax(raw_attention / (d_model)** .5)
        context.append(score @ (input @ V[h]))
    concat = np.concatenate([c for c in context], axis=1)    
    return concat

def add_norm(context, prev_input, gamma, beta):
    context = context + prev_input
    mean = context.mean(axis=1).reshape(-1,1)
    sigma = context.std(axis=1).reshape(-1,1)
    context = (context - mean) / sigma
    context = context * gamma + beta
    return context

def feed_forward(context, w, b):
    #Assume no hidden layer
    context = context @ w + b
    context = tf.nn.relu(context).numpy()
    return context


def cross_context(input_e, input_d, Q_d,K_e,V_e):    
    context = []
    for h in range(n_head):
        raw_attention = (input_d @ Q_d[h]) @ (input_e @ K_e[h]).T
        score = tf.nn.softmax(raw_attention / (d_model)** .5)
        context.append(score @ (input_e @ V_e[h]))

    concat = np.concatenate([c for c in context], axis=1)    
    return concat

def Encode(sentence):
    tokens = tokenize(sentence)
    embeddings = embed(tokens)
    input = posEncode(embeddings)
    
    Q,K,V = init_QKV()

    context1 = context(input, Q,K,V)

    gamma1 = np.random.rand(d_model) - 0.5
    beta1 = np.random.rand(d_model) - 0.5

    gamma2 = np.random.rand(d_model) - 0.5
    beta2 = np.random.rand(d_model) - 0.5

    w1 = np.random.rand(d_model, d_model) -0.5
    b1 = np.random.rand(d_model) -0.5

    context2 = add_norm(context1, input, gamma1, beta1)
    context3 = feed_forward(context2, w1, b1)
    context4 = add_norm(context3, context2, gamma2, beta2)

    return context4


Q_d, K_d, V_d = init_QKV()
Q_d, K_e, V_e = init_QKV()
gamma_d_1 = np.random.rand(d_model) - 0.5
beta_d_1 = np.random.rand(d_model)  - 0.5
gamma_d_2 = np.random.rand(d_model) - 0.5
beta_d_2 = np.random.rand(d_model) - 0.5
w2 = np.random.rand(d_model, d_model) - 0.5
b2 = np.random.rand(d_model) - 0.5
gamma_d_3 = np.random.rand(d_model) - 0.5
beta_d_3 = np.random.rand(d_model) - 0.5
w_linear = np.random.rand(d_model, gen_results) - 0.5
b_linear = np.random.rand(gen_results) - 0.5

def Decode(sentence2, context4, max_seq):

    input2 = posEncode(embed(tokenize(sentence2))) #residual for the next addNorm
    context_d_1 = context(input2, Q_d, K_d, V_d, mask=True)
    context_d_2 = add_norm(context_d_1, input2, gamma_d_1, beta_d_1) #residual for the next addNorm
    cross = cross_context(context4, context_d_2, Q_d, K_e, V_e)
    context_d_3 = add_norm(context_d_2, cross, gamma_d_2, beta_d_2) #residual for the next addNorm
    context_d_4 = feed_forward(context_d_3, w2, b2)
    context_d_5 = add_norm(context_d_4, context_d_3, gamma_d_3, beta_d_3)
    raw_prediction = context_d_5 @ w_linear + b_linear
    prediction = tf.nn.softmax(raw_prediction)
    out = wv.index_to_key[np.argmax(prediction)]
    sentence2 = sentence2 + f" {out}"

    if max_seq == 0:
        print(sentence2)
        return

    Decode(sentence2, context4, max_seq - 1)

def Generate(sentence, max_seq):
    context4 = Encode(sentence)
    Decode('</s>', context4, max_seq)

In [6]:
Generate('wassup', max_seq=5)

</s> by by by by by by
