In [None]:
import os
# Expose only GPU 0 to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # set this before importing torch
import torch 
import time

In [None]:
from myrtlespeech.model.hard_lstm import HardLSTM as HardLSTM_ver2
from hard import HardLSTM as HardLSTM_ver1

from deepspeech_int import HardLSTM as HardLSTM_dsi
from matplotlib import pyplot as plt
from typing import List
from copy import copy

In [None]:
def gen_args(in_size, hidden, seq_len, num_layers, bidirectional, batch, gpu=False):
    """Build a random input and an all-zero initial state for an LSTM call.

    Returns ``(x, (h0, c0))`` where ``x`` is a random tensor of shape
    ``(seq_len, batch, in_size)`` and ``h0``/``c0`` are the *same* zero tensor
    of shape ``(num_layers * num_directions, batch, hidden)``.
    When ``gpu`` is true both tensors are moved to the current CUDA device.
    """
    inputs = torch.randn(seq_len, batch, in_size)
    state_shape = (num_layers * (2 if bidirectional else 1), batch, hidden)
    state = torch.zeros(*state_shape, dtype=inputs.dtype)
    if gpu:
        inputs, state = inputs.cuda(), state.cuda()
    # The hidden and cell state intentionally share one tensor object.
    return (inputs, (state, state))

### Check everything runs

In [None]:
# Small configuration used only to check that every implementation runs.
in_size, hidden, seq_len = 100, 128, 35
num_layers, bidirectional, batch = 1, False, 3

args = gen_args(
    in_size=in_size,
    hidden=hidden,
    seq_len=seq_len,
    num_layers=num_layers,
    bidirectional=bidirectional,
    batch=batch,
)

# Build each implementation with the same shape settings.
lstm_v1 = HardLSTM_ver1(
    in_size=in_size,
    hidden_size=hidden,
    batch_first=False,
    bidirectional=bidirectional,
)
lstm_v2 = HardLSTM_ver2(
    in_size=in_size,
    hidden_size=hidden,
    batch_first=False,
    bidirectional=bidirectional,
)
lstm_v2_script = torch.jit.script(
    HardLSTM_ver2(
        in_size=in_size,
        hidden_size=hidden,
        batch_first=False,
        bidirectional=bidirectional,
    )
)
lstm = torch.nn.LSTM(
    input_size=in_size,
    hidden_size=hidden,
    batch_first=False,
    bidirectional=bidirectional,
)

# One forward pass per implementation on the same input and initial state.
outputs_v1 = lstm_v1(*args)
outputs_v2 = lstm_v2(*args)
outputs_v2_script = lstm_v2_script(*args)
outputs_n = lstm(*args)

In [None]:
# Human-readable axis label for each position of the ``dims`` sequence
# (in_size, hidden, seq_len, num_layers, bidirectional, batch).
idx_to_name = dict(enumerate((
    "Input Size",
    "Hidden Size",
    "Sequence Length",
    "Number Layers",
    "bidirectional",
    "Batch Size",
)))
def profile_and_plot(models, dims, construct_each_time=False, batch_first=False, gpu=False):
    """Time one forward pass per model while sweeping a single dimension, then plot.

    Args:
        models: dict mapping a label to a constructor. Each constructor is
            called as ``constructor(in_size, hidden, batch_first=...,
            bidirectional=...)`` and must return a callable that accepts the
            ``(x, (h0, c0))`` arguments produced by ``gen_args``.
        dims: sequence ``(in_size, hidden, seq_len, num_layers, bidirectional,
            batch)``. Exactly one entry must be a non-empty list of values to
            sweep; all other entries are constants. The sequence is not
            modified.
        construct_each_time: if true, a fresh model is built (and warmed up)
            for every swept value; otherwise each model is built once from the
            first swept value and reused. Must be true when the swept
            dimension is ``in_size``, ``hidden`` or ``bidirectional``, because
            those are passed to the constructor.
        batch_first: forwarded to every constructor.
            NOTE(review): ``gen_args`` always builds a ``(seq_len, batch,
            in_size)`` input, so ``batch_first=True`` looks like it would swap
            the meaning of the first two axes -- confirm before relying on it.
        gpu: if true, models and inputs are moved to CUDA and the device is
            synchronized around the timed call.

    Returns:
        dict mapping each label to a list of ``(value, seconds)`` tuples, one
        per swept value.

    Raises:
        AssertionError: if ``dims`` does not contain exactly one non-empty list.
        ValueError: if the swept dimension is a constructor argument but
            ``construct_each_time`` is false.
    """
    dims = list(dims)  # private copy; also lets callers pass a tuple

    list_positions = [idx for idx, dim in enumerate(dims) if isinstance(dim, list)]
    assert len(list_positions) <= 1, "Only one List can be present"
    assert len(list_positions) == 1, "There must be a List present"
    list_idx = list_positions[0]

    values = dims[list_idx]
    assert len(values) > 0, "The List must not be empty"

    # Positions 0, 1 and 4 (in_size, hidden, bidirectional) are handed to the
    # constructor, so a model built once cannot be reused across those values.
    if list_idx in (0, 1, 4) and not construct_each_time:
        raise ValueError(
            f"Sweeping '{idx_to_name[list_idx]}' changes the model itself; "
            "pass construct_each_time=True"
        )

    def with_value(value):
        """Return a copy of ``dims`` with the swept entry replaced by ``value``."""
        concrete = list(dims)
        concrete[list_idx] = value
        return concrete

    def build_and_warm_up(constructor, concrete_dims, call_args):
        """Construct a model, move it to the GPU if requested, and run it once."""
        model = constructor(
            concrete_dims[0],
            concrete_dims[1],
            batch_first=batch_first,
            bidirectional=concrete_dims[4],
        )
        if gpu:
            model.cuda()
        model(*call_args)  # warmup
        return model

    def synchronize():
        """Wait for queued CUDA work so host-side timestamps are meaningful."""
        if gpu:
            torch.cuda.synchronize()

    results = {name: [] for name in models.keys()}

    if not construct_each_time:
        lstms = {}
        warmup_dims = with_value(values[0])
        for name, lstm_constr in models.items():
            warmup_args = gen_args(*warmup_dims, gpu=gpu)
            lstms[name] = build_and_warm_up(lstm_constr, warmup_dims, warmup_args)

    for value in values:
        dims_in = with_value(value)
        args = gen_args(*dims_in, gpu=gpu)

        for name, lstm_constr in models.items():
            if construct_each_time:
                lstm = build_and_warm_up(lstm_constr, dims_in, args)
            else:
                lstm = lstms[name]

            # CUDA kernels are launched asynchronously: drain pending work
            # before starting the clock and wait for the forward pass to
            # finish before stopping it, otherwise only the launch is timed.
            synchronize()
            t0 = time.perf_counter()
            lstm(*args)
            synchronize()
            tend = time.perf_counter()
            results[name].append((value, tend - t0))
            if construct_each_time:
                del lstm  # drop the model before building the next one

    # plot
    for name, res in results.items():
        swept_values, seconds = zip(*res)
        plt.plot(swept_values, seconds, label=name)
    # Axis labels are per-figure, so set them once rather than per curve.
    plt.xlabel(f"{idx_to_name[list_idx]}")
    plt.ylabel("Time /s")
    plt.legend()
    plt.show()
    return results

In [None]:
def get_script_constructor(constructor):
    """Wrap ``constructor`` so that every model it builds is TorchScript-compiled.

    The returned callable forwards all positional and keyword arguments to
    ``constructor`` and returns ``torch.jit.script`` applied to the result.
    """
    def cstor(*args, **kwargs):
        return torch.jit.script(constructor(*args, **kwargs))

    return cstor

## Variation with batch

In [None]:
# Sweep the batch size for a single unidirectional layer; every other
# dimension is held constant.
in_size = 100
hidden = 256
seq_len = 100
num_layers = 1
bidirectional = False
batch = list(range(2, 256, 8))

# Labels match the ones used by the other sweep cells so that the legends of
# all figures in this notebook are directly comparable.
models = {
    "V1": HardLSTM_ver1,
    "V2": HardLSTM_ver2,
    "V2_scripted": get_script_constructor(HardLSTM_ver2),
    "PyTorch": torch.nn.LSTM,
}

dims = [in_size, hidden, seq_len, num_layers, bidirectional, batch]

# The batch size is not a constructor argument, so each model is built once.
results = profile_and_plot(models, dims, construct_each_time=False)


# Variation with sequence length

In [None]:
# Sweep the sequence length for a bidirectional layer; every other dimension
# is held constant.
in_size, hidden = 100, 256
seq_len = list(range(2, 800, 30))
num_layers, bidirectional, batch = 1, True, 16

models = {
    "V1": HardLSTM_ver1,
    "V2": HardLSTM_ver2,
    "V2_scripted": get_script_constructor(HardLSTM_ver2),
    "PyTorch": torch.nn.LSTM,
}
dims = [in_size, hidden, seq_len, num_layers, bidirectional, batch]

# The sequence length is not a constructor argument, so each model is built once.
results = profile_and_plot(models, dims, construct_each_time=False)



# Variation with input size

In [None]:
# Sweep the input feature size for a bidirectional layer; every other
# dimension is held constant.
in_size = list(range(1, 512, 32))
hidden, seq_len = 256, 100
num_layers, bidirectional, batch = 1, True, 16

models = {
    "V1": HardLSTM_ver1,
    "V2": HardLSTM_ver2,
    "V2_scripted": get_script_constructor(HardLSTM_ver2),
    "PyTorch": torch.nn.LSTM,
}
dims = [in_size, hidden, seq_len, num_layers, bidirectional, batch]

# The input size is a constructor argument, so a fresh model is built per value.
results = profile_and_plot(models, dims, construct_each_time=True)


# Variation with hidden size

In [None]:
# Sweep the hidden size for a bidirectional layer; every other dimension is
# held constant.
in_size = 100
hidden = list(range(2, 1024, 32))
num_layers = 1
bidirectional = True
batch = 16
seq_len = 100

# Defined here explicitly (as in the other sweep cells) so this cell does not
# depend on whichever sweep cell happened to assign ``models`` last.
models = {
    "V1": HardLSTM_ver1,
    "V2": HardLSTM_ver2,
    "V2_scripted": get_script_constructor(HardLSTM_ver2),
    "PyTorch": torch.nn.LSTM,
}

dims = [in_size, hidden, seq_len, num_layers, bidirectional, batch]

# The hidden size is a constructor argument, so a fresh model is built per value.
results = profile_and_plot(models, dims, construct_each_time=True)


In [None]:
# Re-plot the most recent sweep results, zoomed in on the region with
# x <= 500 and y <= 0.1.
for k, res in results.items():
    swept_values, seconds = zip(*res)
    plt.plot(swept_values, seconds, label=k)

# Labels and limits belong to the figure, not to a curve, so apply them once
# after every curve is drawn. Setting a limit switches autoscaling off for
# that axis; doing it inside the loop would fix the lower bounds from the
# first curve only.
plt.xlabel("Hidden")
plt.ylabel("time")
plt.xlim(xmax=500)
plt.ylim(ymax=0.1)
plt.legend()
plt.show()

# Profile the difference between HardLSTM v1 and v2

In [None]:
# Imports for the profiling section; apart from cProfile these repeat what the
# first cells of the notebook already import.
import cProfile
from myrtlespeech.model.hard_lstm import HardLSTM as HardLSTM_ver2
from hard import HardLSTM as HardLSTM_ver1
import os
# Expose only GPU 0 to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # set this before importing torch
import torch 
import time
def gen_args(in_size, hidden, seq_len, num_layers, bidirectional, batch, gpu=False):
    """Build a random input and an all-zero initial state for an LSTM call.

    This redefinition replaces the earlier ``gen_args`` in the kernel, so it
    keeps the same signature (including ``gpu``): ``profile_and_plot`` always
    calls ``gen_args(..., gpu=gpu)`` and would otherwise raise ``TypeError``
    once this cell has been run.

    Args:
        in_size: size of the last (feature) dimension of the input.
        hidden: size of the last dimension of the initial state.
        seq_len: size of the first dimension of the input.
        num_layers: number of stacked layers the state is sized for.
        bidirectional: if true, the state's first dimension is doubled.
        batch: size of the second dimension of the input and the state.
        gpu: if true, move the input and the state to the current CUDA device.

    Returns:
        ``(x, (zeros, zeros))`` where ``x`` has shape
        ``(seq_len, batch, in_size)`` and ``zeros`` has shape
        ``(num_layers * num_directions, batch, hidden)``. The same tensor
        object is used for both elements of the state tuple.
    """
    x = torch.randn(seq_len, batch, in_size)
    num_directions = 2 if bidirectional else 1
    zeros = torch.zeros(
        num_layers * num_directions,
        batch,
        hidden,
        dtype=x.dtype,
    )
    if gpu:
        x = x.cuda()
        zeros = zeros.cuda()
    return (x, (zeros, zeros))

In [None]:
# Larger, fixed configuration used for the cProfile runs below.
in_size, hidden, seq_len = 100, 512, 100
num_layers, bidirectional, batch = 1, False, 300
dims = (in_size, hidden, seq_len, num_layers, bidirectional, batch)
args = gen_args(*dims)

# Same constructor arguments for all three variants; the third is additionally
# compiled with TorchScript.
lstm_v1 = HardLSTM_ver1(in_size, hidden, batch_first=False, bidirectional=bidirectional)
lstm_v2 = HardLSTM_ver2(in_size, hidden, batch_first=False, bidirectional=bidirectional)
lstm_v2_script = torch.jit.script(
    HardLSTM_ver2(in_size, hidden, batch_first=False, bidirectional=bidirectional)
)

In [None]:
# Display the V2 model's repr as the cell output.
lstm_v2

In [None]:
# Profile a single forward pass of the V1 implementation.
cProfile.run('lstm_v1(*args)')

In [None]:
# Profile a single forward pass of the (unscripted) V2 implementation.
cProfile.run('lstm_v2(*args)')

In [None]:
# Profile a single forward pass of the TorchScript-compiled V2 implementation.
cProfile.run('lstm_v2_script(*args)')