In [None]:
import torch
from torch.autograd import Variable
from torch.nn import RNN, GRU, LSTM
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import pickle as pkl
import gc
import math

In [None]:
def oneStep(*params, model):
    """Advance ``model`` by one time step.

    ``params`` is ``(x_t, h_t)`` or ``(x_t, h_t, c_t)``. With three or more
    entries the second and third are packed into an ``(h_t, c_t)`` tuple
    (the form LSTM-style modules take); anything past the third entry is
    ignored. With exactly two entries they are passed straight through.
    With fewer than two entries a message is printed and ``None`` returned.

    Returns:
        Whatever ``model`` returns for that step, or ``None`` on bad input.
    """
    n_params = len(params)
    if n_params >= 3:
        x_t, h_t, c_t = params[:3]
        return model(x_t, (h_t, c_t))
    if n_params == 2:
        return model(*params)
    print('Params must be a tuple containing at least (x_t, h_t)')
    return None

def oneStepVarQR(J, Q):
    """Propagate tangent vectors one step and re-orthonormalise them by QR.

    Args:
        J: batched Jacobian of shape (batch, L, L).
        Q: batched tangent directions of shape (batch, L, k).

    Returns:
        ``(Q_next, r)`` where ``Q_next`` (batch, L, k) has orthonormal columns
        and ``r`` (batch, k) holds the non-negative diagonal of the matching
        R factor, so that ``J^T @ Q == Q_next @ R`` with ``diag(R) == r``.
    """
    # Linear extrapolation of the network along every tracked direction.
    Z = torch.matmul(torch.transpose(J, 1, 2), Q)
    # torch.qr(some=True) is deprecated; mode='reduced' is the equivalent.
    q, r = torch.linalg.qr(Z, mode='reduced')
    r_diag = torch.diagonal(r, dim1=-2, dim2=-1)
    # Flip column signs so every R diagonal entry is non-negative. A zero
    # diagonal (collapsed direction) would give sign 0 and wipe the matching
    # Q column for all later steps, so map it to +1 instead.
    signs = torch.sign(r_diag)
    signs = torch.where(signs == 0, torch.ones_like(signs), signs)
    # q @ diag(signs) is a column scaling; diag(diag(signs) @ r) is r_diag*signs.
    return q * signs.unsqueeze(-2), r_diag * signs

def calc_LEs_an(*params, model, k_LE=100000, rec_layer= 0, kappa = 10, diff= 10, warmup = 10, T_ons = 1):
    """Estimate Lyapunov exponents of ``model.rnn_layer`` along an input sequence.

    Tangent directions ``Q`` are pushed through the analytic Jacobian from
    ``rnn_jac`` each step and re-orthonormalised by QR every ``T_ons`` steps
    (always on the first and last step). The exponents are the per-step
    average of ``log2`` of the QR diagonal values.

    Args:
        *params: ``(x, h0)``. ``x`` has shape (batch, seq_len, input_size) and
            is iterated over dim 1. ``h0`` has shape
            (num_layers, batch, hidden_size).
        model: module exposing ``rnn_layer`` (with ``bias`` and
            ``all_weights``). ``model(x_t, h_t)`` must return
            ``(output, h_next)``.
        k_LE: number of exponents to track, clipped to
            ``[1, num_layers * hidden_size]``.
        rec_layer: recurrent layer type; only ``'rnn'`` is supported.
        kappa, diff, warmup: unused; kept for signature compatibility.
        T_ons: number of steps between QR re-orthonormalisations.

    Returns:
        ``(LEs, rvals)``. ``LEs`` has shape (batch, k_LE). ``rvals`` has shape
        (batch, seq_len, k_LE) and holds the QR diagonal values per step
        (1 on steps where no QR was done).

    Raises:
        ValueError: if ``rec_layer`` is not ``'rnn'``.
    """
    if rec_layer != 'rnn':
        # Only the tanh RNN has an analytic Jacobian (rnn_jac); fail fast
        # instead of crashing later on J = None.
        raise ValueError(f"error, rec_layer is not rnn (got {rec_layer!r})")

    # Use the model's actual device, including its index (e.g. cuda:1).
    device = next(model.parameters()).device
    bias = model.rnn_layer.bias
    # detach() rather than requires_grad_(False): it neither mutates the
    # caller's tensors nor raises on non-leaf inputs.
    x_in = params[0].to(device).detach()
    h0 = params[1].to(device).detach()

    num_layers, batch_size, hidden_size = h0.shape
    _, feed_seq, _ = x_in.shape
    L = num_layers * hidden_size

    k_LE = max(min(L, k_LE), 1)
    # Start from the first k_LE standard basis vectors for every batch element.
    Q = torch.eye(L, device=device).reshape(1, L, L).repeat(batch_size, 1, 1)[:, :, :k_LE]
    rvals = torch.ones(batch_size, feed_seq, k_LE, device=device)

    ht = h0
    t_QR = 0
    # The Jacobian is computed analytically, so autograd is not needed. Without
    # no_grad, each hidden state would extend one ever-growing graph.
    with torch.no_grad():
        for t, xt in enumerate(tqdm(x_in.transpose(0, 1))):
            # The last step must be orthonormalised too, otherwise the growth
            # accumulated since the previous QR would be dropped.
            do_qr = t == 0 or t == feed_seq - 1 or (t - t_QR) >= T_ons
            xt = torch.unsqueeze(xt, 1)  # (batch, 1, input_size)

            J = rnn_jac(model.rnn_layer.all_weights, ht, xt, bias=bias)
            _, ht = oneStep(xt, ht, model=model)

            if do_qr:
                Q, r = oneStepVarQR(J, Q)
                t_QR = t
            else:
                # NOTE(review): propagates with J^T, the same convention as
                # oneStepVarQR; confirm this is the intended orientation.
                Q = torch.matmul(torch.transpose(J, 1, 2), Q)
                # No stretch recorded on skipped steps (log2(1) == 0).
                r = torch.ones(batch_size, k_LE, device=device)
            rvals[:, t, :] = r

    LEs = torch.sum(torch.log2(rvals), dim=1) / feed_seq
    return LEs, rvals
    
def rnn_jac(params_array, h, x, bias):
    """Analytic one-step Jacobian of a stacked tanh RNN w.r.t. its hidden state.

    For each layer the pre-activation is
    ``y = W @ input + U @ h_layer + b`` and the new state is ``tanh(y)``.
    Layer 0's input is ``x``; every later layer's input is the previous
    layer's new state.

    Args:
        params_array: per-layer parameter lists, split by ``param_split`` into
            (W_ih, W_hh[, b_ih, b_hh]).
        h: hidden state of shape (num_layers, batch, hidden).
        x: input for one step; ``x.squeeze(dim=1)`` must be (batch, input),
            presumably (batch, 1, input) — TODO confirm with callers.
        bias: whether ``params_array`` contains the two bias terms.

    Returns:
        Tensor of shape (batch, num_layers*hidden, num_layers*hidden). Block
        (i, j) holds d h_i(t+1) / d h_j(t), where h_i is layer i's state.
        Blocks above the diagonal stay zero.
    """
    device = get_device(h)
    if bias:
        W, U, b_i, b_h = param_split(params_array, bias)
        b = [b1 + b2 for (b1, b2) in zip(b_i, b_h)]
    else:
        W, U = param_split(params_array, bias)
        b = [torch.zeros(W_i.shape[0], device=device) for W_i in W]

    num_layers, batch_size, hidden_size = h.shape
    h_in = h.transpose(1, 2).detach()       # (num_layers, hidden, batch)
    layer_input = x.squeeze(dim=1).t()      # (input, batch) for layer 0
    J = torch.zeros(batch_size, num_layers * hidden_size, num_layers * hidden_size, device=device)

    def blk(i):
        # Row/column slice of J that belongs to layer i.
        return slice(i * hidden_size, (i + 1) * hidden_size)

    for layer in range(num_layers):
        # Bias (hidden,) broadcasts over the batch columns.
        pre_act = (W[layer] @ layer_input + U[layer] @ h_in[layer] + b[layer].unsqueeze(1)).t()
        d_tanh = sech(pre_act) ** 2          # (batch, hidden, hidden) diagonal tanh'
        J[:, blk(layer), blk(layer)] = d_tanh @ U[layer]

        if layer > 0:
            # Chain rule through the previous layer's new state, which is
            # this layer's input at the same time step.
            J_xt = d_tanh @ W[layer]
            for l in range(layer, 0, -1):
                J[:, blk(layer), blk(l - 1)] = J_xt @ J[:, blk(layer - 1), blk(l - 1)]

        layer_input = torch.tanh(pre_act).t()  # (hidden, batch) input for the next layer
    return J
        
    
    
def param_split(model_params, bias):
    """Regroup per-layer RNN parameters into one list per parameter kind.

    Args:
        model_params: sequence indexed ``[layer][kind]``, where kind follows
            the order (W_i, W_h, b_i, b_h), e.g. ``rnn.all_weights``.
        bias: when truthy, the two bias kinds are extracted as well.

    Returns:
        ``[W, U, b_i, b_h]`` if ``bias`` else ``[W, U]``. Each entry is a list
        holding one detached tensor per layer.
    """
    n_kinds = 4 if bias else 2
    return [
        [layer_params[kind].detach() for layer_params in model_params]
        for kind in range(n_kinds)
    ]
	
## Define Math Functions
def get_device(X):
    """Return ``torch.device('cuda')`` if ``X`` lives on a GPU, else CPU.

    The returned device carries no index (it is the generic 'cuda' device).
    """
    return torch.device('cuda' if X.is_cuda else 'cpu')
    
def sig(X):
    """Elementwise logistic sigmoid ``1 / (1 + exp(-X))``."""
    return torch.sigmoid(X)
def sigmoid(X):
    """Diagonal matrix of the elementwise sigmoid of ``X``.

    ``torch.diag_embed`` puts the last dimension on the diagonal, so an input
    of shape (..., n) yields (..., n, n).
    """
    return torch.diag_embed(torch.sigmoid(X))
def sigmoid_p(X):
    """Diagonal matrix of the sigmoid derivative ``s * (1 - s)`` with ``s = sigmoid(X)``.

    An input of shape (..., n) yields (..., n, n) via ``torch.diag_embed``.
    """
    s = torch.sigmoid(X)  # computed once instead of twice
    return torch.diag_embed(s * (1 - s))
def sech(X):
    """Diagonal matrix of the elementwise hyperbolic secant ``1 / cosh(X)``.

    An input of shape (..., n) yields (..., n, n) via ``torch.diag_embed``.
    """
    return torch.diag_embed(1 / torch.cosh(X))
def tanh(X):
    """Diagonal matrix of the elementwise ``tanh`` of ``X``.

    An input of shape (..., n) yields (..., n, n) via ``torch.diag_embed``.
    """
    return torch.diag_embed(torch.tanh(X))