In [None]:
import time
from tqdm import tqdm

In [None]:
import torch
import math
import torch.nn.functional as F
import lltm_cpp

In [None]:
class LLTM(torch.nn.Module):
    """Pure-PyTorch LLTM cell: a recurrent unit with input, output and
    candidate-cell gates (there is no forget gate).

    Args:
        input_features: width of the input fed to the cell at each step.
        state_size: width of both the hidden state and the cell state.
    """

    def __init__(self, input_features, state_size):
        super().__init__()
        self.input_features = input_features
        self.state_size = state_size
        # All three gates share one fused weight matrix and bias, so a single
        # matrix multiply produces every gate pre-activation. The column count
        # is input_features + state_size because the weights are applied to
        # the concatenation [h, input].
        gate_rows = 3 * state_size
        self.weights = torch.nn.Parameter(
            torch.empty(gate_rows, input_features + state_size))
        self.bias = torch.nn.Parameter(torch.empty(gate_rows))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw every parameter from U(-1/sqrt(state_size), +1/sqrt(state_size))."""
        bound = 1.0 / math.sqrt(self.state_size)
        for param in self.parameters():
            param.data.uniform_(-bound, +bound)

    def forward(self, input, state):
        """Advance the cell by one step.

        Args:
            input: 2-D tensor whose second dimension is ``input_features``.
            state: pair ``(old_h, old_cell)``, each with second dimension
                ``state_size``.

        Returns:
            Tuple ``(new_h, new_cell)``.
        """
        old_h, old_cell = state
        joined = torch.cat([old_h, input], dim=1)

        # One fused linear layer, then slice the result into the three gates.
        pre_activations = F.linear(joined, self.weights, self.bias)
        input_pre, output_pre, candidate_pre = pre_activations.chunk(3, dim=1)

        input_gate = torch.sigmoid(input_pre)
        output_gate = torch.sigmoid(output_pre)
        # ELU takes the place of the tanh a standard LSTM would use here.
        candidate_cell = F.elu(candidate_pre)

        # The cell state accumulates the gated candidate.
        new_cell = old_cell + candidate_cell * input_gate
        # The hidden state (also the output) is the squashed cell, gated.
        new_h = torch.tanh(new_cell) * output_gate

        return new_h, new_cell

In [None]:
# Problem dimensions for the pure-PyTorch run: rows per batch, input width,
# and hidden/cell state width.
batch_size, input_features, state_size = 32, 10, 4

In [None]:
# Random input (X), hidden state (h) and cell state (C), drawn on the CPU in
# that order and then copied to the GPU.
X, h, C = (
    torch.randn(batch_size, width).to('cuda')
    for width in (input_features, state_size, state_size)
)

# The module under test, with its parameters moved to the GPU as well.
rnn = LLTM(input_features, state_size).to('cuda')

In [None]:
# Single forward step: feed X together with the (hidden, cell) state pair and
# keep the updated pair.
new_h, new_C = rnn(X, (h, C))

In [None]:
# Repeat the same forward step 100000 times; the tqdm bar reports the rate.
# NOTE(review): the tensors are on the GPU and nothing calls
# torch.cuda.synchronize(), so the reported rate presumably measures how fast
# work is queued rather than how fast it completes — confirm before comparing.
for _ in tqdm(range(100000)):
    new_h, new_C = rnn(X, (h, C))

In [None]:
import math
import torch

# Our module!
import lltm_cpp

class LLTMFunction(torch.autograd.Function):
    """Forward-only autograd wrapper around ``lltm_cpp.forward``.

    Nothing is saved on ``ctx`` and no ``backward`` is defined, so this
    function is suitable for inference and forward benchmarking only; trying
    to backpropagate through its outputs will raise.
    """

    @staticmethod
    def forward(ctx, input, weights, bias, old_h, old_cell):
        """Run one LLTM step through the compiled extension.

        Args:
            ctx: autograd context (unused while the backward pass is disabled).
            input: input tensor for this step.
            weights: fused gate weight matrix.
            bias: fused gate bias.
            old_h: previous hidden state.
            old_cell: previous cell state.

        Returns:
            Tuple ``(new_h, new_cell)`` — the first two entries returned by
            ``lltm_cpp.forward``.
        """
        outputs = lltm_cpp.forward(input, weights, bias, old_h, old_cell)
        # Only the first two outputs are needed for the forward pass. The
        # remaining entries were previously gathered (together with `weights`)
        # for ctx.save_for_backward(), but that call was disabled, so building
        # the list on every step was dead work.
        # NOTE(review): to train with this function, save
        # `outputs[1:] + [weights]` via ctx.save_for_backward() again and add
        # a matching `backward` staticmethod.
        new_h, new_cell = outputs[:2]

        return new_h, new_cell

In [None]:
class LLTM(torch.nn.Module):
    """LLTM cell whose forward pass is delegated to ``LLTMFunction``.

    Args:
        input_features: width of the input fed to the cell at each step.
        state_size: width of both the hidden state and the cell state.
    """

    def __init__(self, input_features, state_size):
        super().__init__()
        self.input_features = input_features
        self.state_size = state_size
        # Fused parameters: three gates' worth of rows, one column per
        # input feature plus one per state element.
        gate_rows = 3 * state_size
        self.weights = torch.nn.Parameter(
            torch.empty(gate_rows, input_features + state_size))
        self.bias = torch.nn.Parameter(torch.empty(gate_rows))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw every parameter from U(-1/sqrt(state_size), +1/sqrt(state_size))."""
        bound = 1.0 / math.sqrt(self.state_size)
        for param in self.parameters():
            param.data.uniform_(-bound, +bound)

    def forward(self, input, state):
        """Apply ``LLTMFunction`` to the input, this module's parameters and
        the unpacked ``state`` sequence; returns whatever the function returns.
        """
        call_args = (input, self.weights, self.bias, *state)
        return LLTMFunction.apply(*call_args)

In [None]:
import torch

# Problem dimensions for this run: rows per batch, input width, and
# hidden/cell state width.
batch_size, input_features, state_size = 16, 32, 128

# Random input (X), hidden state (h) and cell state (C), drawn on the CPU in
# that order and then copied to the GPU.
X, h, C = [
    torch.randn(batch_size, width).to('cuda')
    for width in (input_features, state_size, state_size)
]

# Module built from whichever LLTM class is currently defined in the kernel,
# with its parameters moved to the GPU.
rnn = LLTM(input_features, state_size).to('cuda')

In [None]:
# Repeat the forward step 100000 times with the current `rnn`; the tqdm bar
# reports the rate.
# NOTE(review): presumably `rnn` is the extension-backed LLTM here — that
# depends on which cells were run last. There is also no
# torch.cuda.synchronize(), so the rate may reflect queuing speed rather than
# completed GPU work — confirm before comparing with the pure-PyTorch run.
for _ in tqdm(range(100000)):
    new_h, new_C = rnn(X, (h, C))

In [2]:
from torch.utils.cpp_extension import load

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# JIT-build the extension from the C++ and CUDA sources and bind the resulting
# module to `lltm_cpp`.
# NOTE(review): the recorded output of this cell shows the build failing at
# lltm_cuda_kernel.cu(59) with `identifier "lltm_cuda_forward_kernel" is
# undefined` — presumably the kernel is not declared before its launch site;
# fix in the .cu file. Also, cells earlier in the notebook `import lltm_cpp`
# directly, so this cell likely needs to run (successfully) before them.
lltm_cpp = load(name='lltm_scpp', sources=['lltm.cpp', 'lltm_cuda_kernel.cu'])

RuntimeError: Error building extension 'lltm_scpp': [1/3] /home/archangel/miniconda3/envs/jax_wav2vec2/bin/nvcc  -DTORCH_EXTENSION_NAME=lltm_scpp -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /home/archangel/miniconda3/envs/jax_wav2vec2/lib/python3.8/site-packages/torch/include -isystem /home/archangel/miniconda3/envs/jax_wav2vec2/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -isystem /home/archangel/miniconda3/envs/jax_wav2vec2/lib/python3.8/site-packages/torch/include/TH -isystem /home/archangel/miniconda3/envs/jax_wav2vec2/lib/python3.8/site-packages/torch/include/THC -isystem /home/archangel/miniconda3/envs/jax_wav2vec2/include -isystem /home/archangel/miniconda3/envs/jax_wav2vec2/include/python3.8 -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_86,code=sm_86 --compiler-options '-fPIC' -std=c++14 -c /home/archangel/DeepLearning/research/lltm_cuda_kernel.cu -o lltm_cuda_kernel.cuda.o 
FAILED: lltm_cuda_kernel.cuda.o 
/home/archangel/miniconda3/envs/jax_wav2vec2/bin/nvcc  -DTORCH_EXTENSION_NAME=lltm_scpp -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /home/archangel/miniconda3/envs/jax_wav2vec2/lib/python3.8/site-packages/torch/include -isystem /home/archangel/miniconda3/envs/jax_wav2vec2/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -isystem /home/archangel/miniconda3/envs/jax_wav2vec2/lib/python3.8/site-packages/torch/include/TH -isystem /home/archangel/miniconda3/envs/jax_wav2vec2/lib/python3.8/site-packages/torch/include/THC -isystem /home/archangel/miniconda3/envs/jax_wav2vec2/include -isystem /home/archangel/miniconda3/envs/jax_wav2vec2/include/python3.8 -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_86,code=sm_86 --compiler-options '-fPIC' -std=c++14 -c /home/archangel/DeepLearning/research/lltm_cuda_kernel.cu -o lltm_cuda_kernel.cuda.o 
/home/archangel/DeepLearning/research/lltm_cuda_kernel.cu(59): error: identifier "lltm_cuda_forward_kernel" is undefined

/home/archangel/DeepLearning/research/lltm_cuda_kernel.cu(59): error: type name is not allowed

/home/archangel/DeepLearning/research/lltm_cuda_kernel.cu(59): error: expected an expression

/home/archangel/DeepLearning/research/lltm_cuda_kernel.cu(59): error: identifier "lltm_cuda_forward_kernel" is undefined

/home/archangel/DeepLearning/research/lltm_cuda_kernel.cu(59): error: type name is not allowed

/home/archangel/DeepLearning/research/lltm_cuda_kernel.cu(59): error: expected an expression

6 errors detected in the compilation of "/home/archangel/DeepLearning/research/lltm_cuda_kernel.cu".
[2/3] c++ -MMD -MF lltm.o.d -DTORCH_EXTENSION_NAME=lltm_scpp -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /home/archangel/miniconda3/envs/jax_wav2vec2/lib/python3.8/site-packages/torch/include -isystem /home/archangel/miniconda3/envs/jax_wav2vec2/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -isystem /home/archangel/miniconda3/envs/jax_wav2vec2/lib/python3.8/site-packages/torch/include/TH -isystem /home/archangel/miniconda3/envs/jax_wav2vec2/lib/python3.8/site-packages/torch/include/THC -isystem /home/archangel/miniconda3/envs/jax_wav2vec2/include -isystem /home/archangel/miniconda3/envs/jax_wav2vec2/include/python3.8 -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++14 -c /home/archangel/DeepLearning/research/lltm.cpp -o lltm.o 
ninja: build stopped: subcommand failed.
