In [2]:
#uploading google drive on colab
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
#unzipping the files
!unzip "/content/drive/My Drive/Data.zip"

Archive:  /content/drive/My Drive/Data.zip
   creating: Data/
  inflating: Data/w_id.npy           
  inflating: Data/c_len.npy          
  inflating: Data/c.npy              
  inflating: Data/x_len.npy          
  inflating: Data/x.npy              


In [None]:
#data preprocessing
data_dir='/content/Data'
import data_frame
from data_frame import DataFrame
import numpy as np
import os
data_cols = ['x', 'x_len', 'c', 'c_len']
data = [np.load(os.path.join(data_dir, '{}.npy'.format(i))) for i in data_cols]

In [None]:
import jax.numpy as jnp
from jax import grad
import numpy as np
from jax import random
import jax

In [None]:
#Loading data
x=np.load('/content/drive/MyDrive/writemate data/x.npy')

In [None]:
data=x.transpose(2,0,1)
print(data.shape)

(3, 11911, 1200)


In [None]:
def sigmoid(z):
    # sigmoid activation for LSTM gates
    return 1.0/(1.0 + jnp.exp(-z))

def softmax(y):
    #computes softmax probabilities over characters
    return jnp.exp(y) / jnp.sum(jnp.exp(y),axis=0)

def relu(z):
  return jnp.maximum(0,z)

In [None]:
#Initializing weight matrices
def initialize_lstm_weights(key,n_h,n_x,l,params): #nh = 400, n_x = 3
    subkeys=random.split(key,4) # need to call random.noraml with new key each time : 9 why???
    params['Wc'+ str(l)] = random.normal(subkeys[1],(n_h, n_h+n_x))*0.01 # input to cell state
    params['Wi'+ str(l)] = random.normal(subkeys[2],(n_h, n_h+n_x))*0.01 # input to update
    params['Wf'+ str(l)] = random.normal(subkeys[3],(n_h, n_h+n_x))*0.01 # input to forget
    params['Wo'+ str(l)] = random.normal(subkeys[4],(n_h, n_h+n_x))*0.01 # input to output

    params['bc'+ str(l)] = jnp.zeros((n_h, 1)) # hidden bias
    params['bi'+ str(l)] = jnp.zeros((n_h, 1)) # forget bias
    params['bf'+ str(l)] = jnp.zeros((n_h, 1)) # update bias
    params['bo'+ str(l)] = jnp.zeros((n_h, 1)) # output bias
    return params

In [None]:
#Initialising mdn weights
def initialize_mdn_weights(n_h, key,mdn_params):
  mdn_in_size  = 3 * n_h # what about x
  mix_components = 20
  mdn_out_size = 1 + ((1 + 1 + 2 + 2) * mix_components) #mix components = 20

  mdn_params['Wy1'] = random.normal(key,(600, mdn_in_size ))*0.01
  mdn_params['by1'] = jnp.zeros((600, 1))
  mdn_params['Wy2'] = random.normal(key,(mdn_out_size, 600 ))*0.01
  mdn_params['by2'] = jnp.zeros((mdn_out_size, 1))

  return mdn_params

In [None]:
#defining lstm cell
def lstm_cell(xt,a_prev,c_prev,params, l):

  n_x, m = xt.shape
  Wc = params['Wc'+ str(l)]
  Wi = params['Wi'+ str(l)]
  Wf = params['Wf'+ str(l)]
  Wo = params['Wo'+ str(l)]

  bc = params['bc'+ str(l)]
  bi = params['bi'+ str(l)]
  bf = params['bf'+ str(l)]
  bo = params['bo'+ str(l)]

  concat = jnp.concatenate((a_prev, xt), axis=0)

  ft = sigmoid(jnp.dot(Wf, concat)+bf)
  it =  sigmoid(jnp.dot(Wi, concat)+bi)
  cct = jnp.tanh(jnp.dot(Wc, concat)+bc)
  c_next = ft* c_prev+  it*cct
  ot = sigmoid(jnp.dot(Wo, concat)+bo)
  a_next = ot * jnp.tanh(c_next)

 # cache = (a_next, c_next, a_prev, c_prev, ft, it, cct, ot, xt, params)

  return a_next, c_next#, cache

In [None]:
#LSTM forward propogation
def lstm_forward(x, params, l):
 # caches = [] # xt, a_next, c_next, params, l
  n_x, m, T_x = x.shape

  # Wy=parameters['Wy']
  n_a = 400
  #n_y, n_a = Wy.shape

  a = jnp.zeros((n_a, m, T_x))
  c = jnp.zeros((n_a, m, T_x))
  a_next = params['a0'+str(l)]
  c_next = jnp.zeros((n_a, m))
  for t in range(T_x):
        # Get the 2D slice 'xt' from the 3D input 'x' at time step 't'
        xt = x[:,:,t]
        a_next, c_next=  lstm_cell(xt, a_next, c_next, params, l)
        # a[:,:,t] = a_next  #x = x.at[idx].set(y)
        idx = (..., t)
        a=a.at[idx].set(a_next)
        c= c.at[idx].set(c_next)
        # caches.append[cache]
        # caches = [caches, x]
  return a, c#, caches

In [None]:
def mdn_linear_layer(x,mdn_params):
  Wy1 = mdn_params['Wy1']
  by1 = mdn_params['by1']
  Wy2 = mdn_params['Wy2']
  by2 = mdn_params['by2']
  Z1 = jnp.dot(Wy1,x)+by1 #600,m
  A1 = relu(Z1)
  Z2 = jnp.dot(Wy2,A1)+by2 #121,m
  A2 = relu(Z2)
  return A2

In [None]:
#Separating the mixture density parameters
def mixtureDensity_params(Z): # Z.shape=121,m

  last_index = Z.shape[0] # doubt in spitting
#  mdn_params = Z[0, 0: last_index]
  mdn_params = Z[1:]  # Select rows from the 2nd row to the 121st row

  pi_hat, mu1_hat, mu2_hat, sigma1_hat, sigma2_hat, rho_hat = jnp.split(mdn_params,6,0)
 # 20,m
  eos_hat = Z[0]
  eos = sigmoid(eos_hat)
  rho = jnp.tanh(rho_hat)
  pi = sigmoid(pi_hat)

  sigma1 = jnp.exp(sigma1_hat)
  sigma2 = jnp.exp(sigma2_hat)

  mu1 = mu1_hat
  mu2 = mu2_hat

  return mu1, mu2, sigma1, sigma2, pi, eos, rho

In [None]:
def mixtureDensity(mu1, mu2, sigma1, sigma2, pi, eos, rho, x1, x2):
   x_mu1 = x1-mu1
   x_mu2 = x2-mu2 #1,m
   Z_out = jnp.square(jnp.divide(x_mu1, sigma1)) + jnp.square(jnp.divide(x_mu2, sigma2)) - 2*jnp.divide(rho*x_mu1*x_mu2,sigma1*sigma2)
   rho_square_term = 1-jnp.square(rho)

   power_e = jnp.exp(jnp.divide(-Z_out, 2*rho_square_term))

   regularize_term = 2* jnp.pi*sigma1*sigma2*jnp.sqrt(rho_square_term)

   gaussian = jnp.divide(power_e, regularize_term)

   return jnp.sum(gaussian*pi)

In [None]:
#calculating probability
def get_prob(x1, x2, eos_true, Z):
  mu1, mu2, sigma1, sigma2, pi, eos, rho = mixtureDensity_params(Z)
  eps = jnp.finfo(float).eps
  prob=0.0
  for i in range(20):
    prob+=mixtureDensity(mu1[i], mu2[i], sigma1[i], sigma2[i], pi[i], eos, rho[i], x1, x2)
  return prob, jnp.sum(jnp.log(jnp.squeeze((eos * (eos_true + eps) + (1 - eos) * (1 - eos_true + eps))))) #might be used for log

In [None]:
#Calculating loss
def get_loss(x1, x2, e_true, Z):
  prob, stroke_prob = get_prob(x1, x2, e_true, Z)
  # loss = jnp.sum((jnp.log(prob) * mask)) + jnp.sum(stroke_prob * mask) / jnp.sum(mask)
  loss=prob+stroke_prob
  return (-1)*loss

In [None]:
def model(data,all_params,m):
  n_a = 400
  T_x=1200
#  cache_model = {} #a0 to be defined
  a1 , c1 = lstm_forward(data,all_params,1)
  x2 = jnp.concatenate((a1, data), axis=0)
  a2, c2= lstm_forward(x2,all_params,2)
  x3 = jnp.concatenate((a2, data), axis=0)
  a3, c3 = lstm_forward(x3,all_params,3)
  h = jnp.concatenate((a1, a2, a3), axis=0)
  total_loss=0.0
  for t in range(T_x):
   Z = mdn_linear_layer(h[:,:,t],all_params)
   total_loss+=get_loss(data[0,:,t],data[1,:,t],data[2,:,t],Z)
  return total_loss

In [None]:
#Gradient clipping
def gradient_clipping(W):
  return jnp.clip(W, -10, 10)

In [None]:
#updating parameters
def update_parameters(parameters, grads, learning_rate):
  #parameters = params.copy() : i have not created copy of parameters
  L =  3
  for l in range(L):
        parameters["Wc" + str(l+1)] =parameters["Wc" + str(l+1)] - learning_rate * gradient_clipping(grads["Wc" + str(l+1)])

        parameters["Wi" + str(l+1)] =parameters["Wi" + str(l+1)] - learning_rate * gradient_clipping(grads["Wi" + str(l+1)])

        parameters["Wf" + str(l+1)] =parameters["Wf" + str(l+1)] - learning_rate * gradient_clipping(grads["Wf" + str(l+1)])

        parameters["Wo" + str(l+1)] =parameters["Wo" + str(l+1)] - learning_rate * gradient_clipping(grads["Wo" + str(l+1)])

        parameters["bc" + str(l+1)] =parameters["bc" + str(l+1)] - learning_rate * gradient_clipping(grads["bc" + str(l+1)])

        parameters["bi" + str(l+1)] =parameters["bi" + str(l+1)] - learning_rate * gradient_clipping(grads["bi" + str(l+1)])

        parameters["bf" + str(l+1)] =parameters["bf" + str(l+1)] - learning_rate * gradient_clipping(grads["bf" + str(l+1)])

        parameters["bo" + str(l+1)] =parameters["bo" + str(l+1)] - learning_rate * gradient_clipping(grads["bo" + str(l+1)])
  parameters["Wy1"] =parameters["Wy1"] - learning_rate * gradient_clipping(grads["Wy1" + str(l+1)])
  parameters["by1"] =parameters["by1"] - learning_rate * gradient_clipping(grads["by1" + str(l+1)])
  parameters["Wy2"] =parameters["Wy2"] - learning_rate * gradient_clipping(grads["Wy2" + str(l+1)])
  parameters["by2"] =parameters["by2"] - learning_rate * gradient_clipping(grads["by2" + str(l+1)])

  return parameters

In [None]:
data.shape

(3, 11911, 1200)

In [None]:
pred_data=data[:,1:]

In [None]:
pred_data.shape

(3, 11910, 1200)

In [None]:
print(data[:,1,10])
print(pred_data[:,0,10])

[ 0.18620381 -1.815487    0.        ]
[ 0.18620381 -1.815487    0.        ]


In [None]:
def final_model(data,num_epochs,learning_rate):
  key = random.PRNGKey(0)
  hidden_size=400
  feature_dim=3
  all_params = {}
  key,subkey = random.split(key)
  all_params = initialize_lstm_weights(subkey,hidden_size, feature_dim, 1,all_params) #feature_dim = 3

  key,subkey = random.split(key)
  input_dim = feature_dim + hidden_size
  all_params = initialize_lstm_weights(subkey,hidden_size, input_dim, 2,all_params)

  key,subkey = random.split(key)
  all_params = initialize_lstm_weights(subkey,hidden_size, input_dim, 3,all_params)

  key, subkey = random.split(key)
  all_params = initialize_mdn_weights(hidden_size, key,all_params)
  batch_size=16
  batches=11910// batch_size
  index=0
  all_params['a01']=jnp.zeros((n_a,m))
  all_params['a02']=jnp.zeros((n_a, m))
  all_params['a03']=jnp.zeros((n_a, m))
  for epoch in range(num_epochs):  #num_epochs
    for batch in range(batches):
      x_temp=data[:,index:index+batch_size,:]
      loss = model(x_temp,all_params,16)
      print(loss)
      grads = grad(model)(x_temp,all_params,16)
      print(grads.keys())
      print(grads)
      index+=batch_size
      all_params = update_parameters(all_params, grads, learning_rate)


In [None]:
#final_model(pred_data,1000,0.001)
#printed first loss

-10720.922
