In [10]:
import jax
import jax.numpy as jnp
from jax import random

In [47]:

# Activation functions
def relu(input):
    """Rectified linear unit: element-wise ``max(input, 0)``.

    Args:
        input: Scalar or array accepted by ``jnp.maximum``.

    Returns:
        Array in which every negative entry is replaced by zero.
    """
    floor = 0
    return jnp.maximum(input, floor)


def softmax(x, axis=-1):
    """Softmax along ``axis``.

    The maximum along ``axis`` is subtracted before exponentiating so that
    large inputs do not overflow; this does not change the result.

    Args:
        x: Input array.
        axis: Axis along which the outputs sum to one.

    Returns:
        Array with the same shape as ``x``.
    """
    shifted_exp = jnp.exp(x - jnp.max(x, axis=axis, keepdims=True))
    normalizer = jnp.sum(shifted_exp, axis=axis, keepdims=True)
    return shifted_exp / normalizer

def sigmoid(x):
    """Logistic sigmoid ``1 / (1 + exp(-x))`` applied element-wise.

    Args:
        x: Scalar or array.

    Returns:
        Array of values in [0, 1] with the same shape as ``x``.
    """
    # The original one-liner evaluated this expression but never returned it,
    # so every call produced None.
    return 1 / (1 + jnp.exp(-x))

In [48]:
# Dense Layer

def initialize_dense_layer(key, input_dim, output_dim):
    """Draw the parameters of a dense layer.

    Both the weight matrix and the bias vector are sampled uniformly from
    ``[-bound, bound)`` with ``bound = sqrt(6 / (input_dim + output_dim))``.

    Args:
        key: JAX PRNG key; it is split into one sub-key per parameter.
        input_dim: Number of input features.
        output_dim: Number of output features.

    Returns:
        Tuple ``(w, b)`` with shapes ``(input_dim, output_dim)`` and
        ``(output_dim,)``.
    """
    weight_key, bias_key = random.split(key)

    # The same Xavier-uniform bound is used for the weights and the biases.
    bound = jnp.sqrt(6.0 / (input_dim + output_dim))

    weights = random.uniform(
        weight_key, shape=(input_dim, output_dim), minval=-bound, maxval=bound
    )
    biases = random.uniform(
        bias_key, shape=(output_dim,), minval=-bound, maxval=bound
    )
    return weights, biases

def dense_layer(params, x):
    """Apply an affine map ``x . w + b``.

    Args:
        params: Tuple ``(w, b)``.
        x: Input whose last axis matches the first axis of ``w``.

    Returns:
        ``jnp.dot(x, w) + b``.
    """
    weights, bias = params
    projected = jnp.dot(x, weights)
    return projected + bias

In [49]:
# Test the dense layer
# The same fixed key is used on every run of this cell.
key = random.PRNGKey(0)
input_dim = 10
output_dim = 5

# params is the (w, b) tuple returned by initialize_dense_layer.
params = initialize_dense_layer(key, input_dim, output_dim)
# A single all-ones vector of length input_dim (no batch axis).
x = jnp.ones((input_dim,))
y = dense_layer(params, x)
print(y)

[ 0.2890662   0.866141    1.8838954   0.77692413 -1.1997658 ]


In [50]:
# Layer normalization

def initialize_layer_norm(hidden_dim):
    """Create layer-norm parameters that leave the normalized input unchanged.

    Args:
        hidden_dim: Size (or shape) passed to ``jnp.ones`` / ``jnp.zeros``.

    Returns:
        Tuple ``(gamma, beta)``: a scale of ones and a shift of zeros.
    """
    scale, shift = jnp.ones(hidden_dim), jnp.zeros(hidden_dim)
    return scale, shift

def layer_norm(x, layernorm_params, epsilon=1e-6):
    """Normalize ``x`` over its last axis, then scale and shift.

    Args:
        x: Input array; statistics are computed along ``axis=-1``.
        layernorm_params: Tuple ``(gamma, beta)`` broadcastable against ``x``.
        epsilon: Constant added to the variance before the square root so
            that a zero-variance input does not divide by zero. It was
            previously hard-coded to ``1e-6``; it is exposed as a keyword
            (same default) to match ``batch_norm``.

    Returns:
        ``gamma * (x - mean) / sqrt(var + epsilon) + beta``.
    """
    gamma, beta = layernorm_params
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    return gamma * (x - mean) / jnp.sqrt(var + epsilon) + beta



In [51]:
# Test the layer norm
hidden_dim = 10
# gamma is all ones and beta all zeros, so the output is the plain normalized input.
layernorm_params = initialize_layer_norm(hidden_dim)
# One integer vector of length hidden_dim; statistics are taken over its only axis.
x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
y = layer_norm(x, layernorm_params)
print(y)

[-1.5666988  -1.2185435  -0.8703882  -0.52223295 -0.17407764  0.17407764
  0.52223295  0.8703882   1.2185435   1.5666988 ]


In [52]:
# MLP
def initialize_mlp(hidden_dim, mlp_dim, key):
    """Create the parameters of a two-layer MLP.

    Weights are drawn uniformly from ``[-bound, bound)`` with
    ``bound = sqrt(6 / (hidden_dim + mlp_dim))``; biases start at zero.

    Args:
        hidden_dim: Width of the MLP input and output.
        mlp_dim: Width of the intermediate layer.
        key: JAX PRNG key; it is split into one sub-key per weight matrix.

    Returns:
        Tuple ``(w1, b1, w2, b2)`` with shapes ``(hidden_dim, mlp_dim)``,
        ``(mlp_dim,)``, ``(mlp_dim, hidden_dim)`` and ``(hidden_dim,)``.
    """
    up_key, down_key = random.split(key)

    # Both matrices share one bound because fan_in + fan_out is the same sum.
    bound = jnp.sqrt(6.0 / (hidden_dim + mlp_dim))

    def draw(subkey, shape):
        # Uniform sample in [-bound, bound) with the requested shape.
        return random.uniform(subkey, shape, minval=-bound, maxval=bound)

    w_up = draw(up_key, (hidden_dim, mlp_dim))
    w_down = draw(down_key, (mlp_dim, hidden_dim))
    b_up = jnp.zeros(mlp_dim)
    b_down = jnp.zeros(hidden_dim)

    return w_up, b_up, w_down, b_down

def mlp(x, mlp_params):
    """Two-layer perceptron: ``relu(x @ w1 + b1) @ w2 + b2``.

    Args:
        x: Input whose last axis matches the first axis of ``w1``.
        mlp_params: Tuple ``(w1, b1, w2, b2)``.

    Returns:
        The output of the second (down) projection.
    """
    w_up, b_up, w_down, b_down = mlp_params

    # Up-projection followed by the ReLU non-linearity.
    hidden = jnp.matmul(x, w_up) + b_up
    activated = relu(hidden)

    # Down-projection.
    return jnp.matmul(activated, w_down) + b_down

In [53]:
# Test the MLP
hidden_dim = 10
mlp_dim = 5
# The same fixed key is used on every run of this cell.
key = random.PRNGKey(0)

# params is the (w1, b1, w2, b2) tuple returned by initialize_mlp.
params = initialize_mlp(hidden_dim, mlp_dim, key)
# A single all-ones vector of length hidden_dim (no batch axis).
x = jnp.ones((hidden_dim,))
y = mlp(x, params)
print(y)

[-0.45543894 -0.27045643 -0.08839004  0.02056309 -1.2242546   0.6535751
 -1.611247   -0.48521277  1.2017834  -0.7943954 ]


In [54]:
# Self-attention
# Module-level attention settings: width of each head and number of heads.
head_dim = 64
num_heads = 4

def initialize_attention(hidden_dim, num_heads,head_dim, key):
    """Create the query/key/value projection parameters.

    Each weight matrix has shape ``(hidden_dim, num_heads * head_dim)`` and is
    drawn uniformly from ``[-bound, bound)`` with
    ``bound = sqrt(6 / (fan_in + fan_out))``; each bias is a zero vector of
    length ``num_heads * head_dim``.

    Args:
        hidden_dim: Width of the attention input.
        num_heads: Number of attention heads.
        head_dim: Width of each head.
        key: JAX PRNG key; it is split into one sub-key per projection.

    Returns:
        Tuple ``(q_w, k_w, v_w, q_b, k_b, v_b)`` -- note that the three
        weights come first, followed by the three biases.
    """
    fan_in, fan_out = hidden_dim, head_dim * num_heads
    bound = jnp.sqrt(6.0 / (fan_in + fan_out))

    # Sub-keys are consumed in query, key, value order.
    subkeys = tuple(random.split(key, 3))
    weights = [
        random.uniform(subkey, (fan_in, fan_out), minval=-bound, maxval=bound)
        for subkey in subkeys
    ]
    biases = [jnp.zeros(fan_out) for _ in subkeys]

    return (*weights, *biases)


def self_attention(x, attn_params, n_heads=None):
    """Multi-head scaled dot-product self-attention over a single sequence.

    Args:
        x: 2-D array of shape ``(n, hidden_dim)`` holding ``n`` positions.
        attn_params: Tuple ``(q_w, k_w, v_w, q_b, k_b, v_b)`` in the order
            returned by ``initialize_attention``.
        n_heads: Number of heads to split the projections into. When omitted,
            the module-level ``num_heads`` is used (the original behaviour).

    Returns:
        Array of shape ``(n, n_heads * head_dim)``: the per-head outputs
        concatenated along the last axis. No output projection is applied.

    Raises:
        ValueError: If the projection width is not a multiple of ``n_heads``.
    """
    q_w, k_w, v_w, q_b, k_b, v_b = attn_params

    if n_heads is None:
        n_heads = num_heads

    # x must be 2-D: (sequence length, hidden dimension).
    n, _ = x.shape

    # The per-head width follows from the weights rather than from a global,
    # so the function cannot disagree with the parameters it is given.
    proj_dim = q_w.shape[-1]
    if proj_dim % n_heads != 0:
        raise ValueError(
            f"projection width {proj_dim} is not divisible by n_heads={n_heads}"
        )
    per_head = proj_dim // n_heads

    # Project the input into the query, key and value spaces.
    q = jnp.matmul(x, q_w) + q_b
    k = jnp.matmul(x, k_w) + k_b
    v = jnp.matmul(x, v_w) + v_b

    # (n, n_heads * per_head) -> (n, n_heads, per_head) -> (n_heads, n, per_head)
    q = q.reshape(n, n_heads, per_head).swapaxes(0, 1)
    k = k.reshape(n, n_heads, per_head).swapaxes(0, 1)
    v = v.reshape(n, n_heads, per_head).swapaxes(0, 1)

    # Scaled dot-product scores, normalized over the key positions.
    scores = jnp.matmul(q, jnp.swapaxes(k, -1, -2)) / jnp.sqrt(per_head)
    attention_weights_heads = jax.nn.softmax(scores, axis=-1)

    # Weighted sum of the values: (n_heads, n, per_head).
    output = jnp.matmul(attention_weights_heads, v)

    # Merge the heads back: (n, n_heads * per_head). The width here is the
    # projection width, not the input hidden_dim; reshaping to the input width
    # failed whenever hidden_dim != n_heads * per_head.
    return output.swapaxes(0, 1).reshape(n, proj_dim)


In [55]:
# Embedding
def initialize_embedding(key, vocab_size, hidden_dim):
    """Create an embedding table.

    Entries are drawn uniformly from ``[-bound, bound)`` with
    ``bound = sqrt(6 / (vocab_size + hidden_dim))``.

    Args:
        key: JAX PRNG key; it is split and the first sub-key is used.
        vocab_size: Number of rows in the table.
        hidden_dim: Width of each embedding vector.

    Returns:
        Array of shape ``(vocab_size, hidden_dim)``.
    """
    subkeys = random.split(key)
    bound = jnp.sqrt(6.0 / (vocab_size + hidden_dim))
    table = random.uniform(
        subkeys[0], (vocab_size, hidden_dim), minval=-bound, maxval=bound
    )
    return table

def embedding(x, embedding_params):
    """Look up rows of the embedding table.

    Args:
        x: Index (or array of indices) used to index the first axis.
        embedding_params: Embedding table.

    Returns:
        ``embedding_params[x]``.
    """
    table = embedding_params
    return table[x]

In [56]:
# Test the embedding
# The same fixed key is used on every run of this cell.
key = random.PRNGKey(0)
vocab_size = 10
hidden_dim = 5

# params is the (vocab_size, hidden_dim) table returned by initialize_embedding.
params = initialize_embedding(key, vocab_size, hidden_dim)
# Every valid index once, in order, so the lookup prints the whole table.
x = jnp.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
y = embedding(x, params)
print(y)

[[ 0.43925422  0.09296543  0.01709316 -0.38812086  0.17173995]
 [ 0.58346367  0.42921904  0.4783264   0.46378067 -0.04329202]
 [-0.48244128  0.14456141  0.53700733  0.07577622 -0.523541  ]
 [-0.47330436 -0.08072075  0.02309517  0.5157007   0.04450391]
 [ 0.07960355  0.52857083 -0.52578     0.21598057 -0.6159274 ]
 [ 0.28318474  0.37144148  0.1625991   0.24843854  0.40760177]
 [-0.16007414  0.58533496  0.45814523 -0.27706602  0.41568798]
 [ 0.4229473  -0.5099719   0.15901892  0.36687768 -0.48541275]
 [ 0.43236148 -0.31739697  0.44400045 -0.58978945 -0.21325506]
 [-0.2950449   0.2288812  -0.47637233 -0.33364934  0.06641433]]


In [63]:
# Dropout
def dropout(key, x, rate, in_train_mode = True):
    """Inverted dropout.

    Args:
        key: JAX PRNG key used to draw the mask.
        x: Input array.
        rate: Probability of zeroing each element; expected to satisfy
            ``0 <= rate < 1`` (``rate == 1`` would divide by zero).
        in_train_mode: When False, ``x`` is returned unchanged.

    Returns:
        In training mode, ``x`` with each element zeroed with probability
        ``rate`` and the survivors scaled by ``1 / (1 - rate)`` so the
        expected value is unchanged; otherwise ``x`` itself.
    """
    if not in_train_mode:
        return x

    # random.bernoulli's second argument is the probability of drawing True.
    # The mask marks the elements to KEEP, so it must be drawn with 1 - rate;
    # drawing it with `rate` kept the wrong fraction for any rate != 0.5.
    keep_prob = 1.0 - rate
    keep_mask = random.bernoulli(key, keep_prob, x.shape)
    return x * keep_mask / keep_prob

In [64]:
# Test the dropout
# The same fixed key is used on every run, so the same mask is drawn each time.
key = random.PRNGKey(0)
x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
rate = 0.5
# in_train_mode is left at its default (True), so the mask is applied.
y = dropout(key, x, rate)
print(y)

[ 2.  0.  6.  8. 10. 12. 14. 16.  0.  0.]


In [61]:
# Batch Normalization Layer
def initialize_batch_norm(hidden_dim):
    """Create batch-norm parameters and running statistics.

    Args:
        hidden_dim: Size (or shape) passed to ``jnp.ones`` / ``jnp.zeros``.

    Returns:
        Tuple ``(gamma, beta, running_mean, running_var)`` initialised to
        ones, zeros, zeros and ones respectively.
    """
    # Learnable scale and shift.
    scale, shift = jnp.ones(hidden_dim), jnp.zeros(hidden_dim)

    # Running statistics start at zero mean and unit variance.
    mean_estimate, var_estimate = jnp.zeros(hidden_dim), jnp.ones(hidden_dim)

    return scale, shift, mean_estimate, var_estimate

def batch_norm(params, inputs, train_mode=True, epsilon=1e-6, momentum=0.9,
               return_params=False):
    """Batch normalization along axis 0.

    Args:
        params: Tuple ``(gamma, beta, running_mean, running_var)``.
        inputs: Array whose axis 0 is reduced to obtain the statistics.
        train_mode: When True, normalize with the statistics of ``inputs``;
            when False, normalize with the running statistics in ``params``.
        epsilon: Constant added to the variance before the square root.
        momentum: Weight of the previous running statistics in the
            exponential moving average.
        return_params: When True, also return the parameter tuple with the
            updated running statistics. The function is pure, so this is the
            only way for a caller to keep the updates; previously they were
            computed and discarded. Defaults to False to preserve the
            original return value.

    Returns:
        The normalized, scaled and shifted array; or, when ``return_params``
        is True, the pair ``(output, (gamma, beta, running_mean, running_var))``.
        In evaluation mode the returned running statistics are unchanged.
    """
    gamma, beta, running_mean, running_var = params

    if train_mode:
        mean = jnp.mean(inputs, axis=0)
        var = jnp.var(inputs, axis=0)
        # Exponential moving average of the batch statistics.
        running_mean = momentum * running_mean + (1.0 - momentum) * mean
        running_var = momentum * running_var + (1.0 - momentum) * var
    else:
        mean, var = running_mean, running_var

    # Normalize the inputs, then apply the learnable scale and shift.
    x_hat = (inputs - mean) / jnp.sqrt(var + epsilon)
    outputs = gamma * x_hat + beta

    if return_params:
        return outputs, (gamma, beta, running_mean, running_var)
    return outputs



In [62]:
# Test the batch norm
hidden_dim = 10
batch_norm_params = initialize_batch_norm(hidden_dim)
# NOTE(review): x is 1-D, so batch_norm's axis=0 statistics are taken over these
# 10 entries themselves; the printed values are identical to the layer-norm test
# output. A (batch, hidden_dim) input was presumably intended -- confirm.
x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
# train_mode is left at its default (True), so the running statistics are not used.
y = batch_norm(batch_norm_params, x)
print(y)

[-1.5666988  -1.2185435  -0.8703882  -0.52223295 -0.17407764  0.17407764
  0.52223295  0.8703882   1.2185435   1.5666988 ]
