
Commit 01eef88

lazy load CPU kernel
taolei87 committed Nov 7, 2018
1 parent 4517e2a commit 01eef88
Showing 1 changed file with 24 additions and 17 deletions.
41 changes: 24 additions & 17 deletions sru/sru_functional.py
@@ -6,28 +6,34 @@
 import torch.nn as nn
 from torch.autograd import Variable
 
-# load C++ implementation for CPU computation
-try:
-    from torch.utils.cpp_extension import load
-    cpu_source = os.path.join(os.path.dirname(__file__), "sru_cpu_impl.cpp")
-    sru_cpu_impl = load(name="sru_cpu_impl", sources=[cpu_source])
-except:
-    sru_cpu_impl = None
+SRU_CPU_kernel = None
+SRU_GPU_kernel = None
 
-SRU_GPU_class = None
+# load C++ implementation for CPU computation
+def _lazy_load_cpu_kernel():
+    global SRU_CPU_kernel
+    if SRU_CPU_kernel is not None:
+        return SRU_CPU_kernel
+    try:
+        from torch.utils.cpp_extension import load
+        cpu_source = os.path.join(os.path.dirname(__file__), "sru_cpu_impl.cpp")
+        SRU_CPU_kernel = load(name="sru_cpu_impl", sources=[cpu_source])
+    except:
+        # use Python version instead
+        SRU_CPU_kernel = False
+    return SRU_CPU_kernel
+
 # load C++ implementation for GPU computation
-def _lazy_load_cuda_class():
-    global SRU_GPU_class
-    if SRU_GPU_class is not None:
-        return SRU_GPU_class
+def _lazy_load_cuda_kernel():
+    global SRU_GPU_kernel
+    if SRU_GPU_kernel is not None:
+        return SRU_GPU_kernel
     try:
         from .cuda_functional import SRU_Compute_GPU
-        SRU_GPU_class = SRU_Compute_GPU
+        SRU_GPU_kernel = SRU_Compute_GPU
     except:
         from cuda_functional import SRU_Compute_GPU
-        SRU_GPU_class = SRU_Compute_GPU
-    return SRU_GPU_class
+        SRU_GPU_kernel = SRU_Compute_GPU
+    return SRU_GPU_kernel
 
 
 def SRU_CPU_class(activation_type,
@@ -79,7 +85,8 @@ def sru_compute_cpu(u, x, weight_c, bias, init=None, mask_c=None):
         batch = x.size(-2)
         k = u.size(-1) // d // bidir
 
-        if sru_cpu_impl is not None:
+        sru_cpu_impl = _lazy_load_cpu_kernel()
+        if (sru_cpu_impl is not None) and (sru_cpu_impl != False):
             if not torch.is_grad_enabled():
                 assert mask_c is None
                 cpu_forward = sru_cpu_impl.cpu_bi_forward if bidirectional else \
@@ -377,7 +384,7 @@ def forward(self, input, c0=None, mask_pad=None, return_proj=False):

         # Pytorch Function() doesn't accept NoneType in forward() call.
         # So we put mask_pad as class attribute as a work around
-        SRU_Compute_Class = _lazy_load_cuda_class() if input.is_cuda else SRU_CPU_class
+        SRU_Compute_Class = _lazy_load_cuda_kernel() if input.is_cuda else SRU_CPU_class
         SRU_Compute = SRU_Compute_Class(
             self.activation_type, n_out, self.bidirectional, self.has_skip_term,
             scale_val, mask_pad
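Aside: the pattern this commit introduces is a module-level cache with two sentinel values: None means the compile has not been attempted yet, while False means it was attempted and failed, so the pure-Python implementation should be used. This makes the expensive torch.utils.cpp_extension.load() call run at most once per process, and only when a CPU forward pass actually needs it. Below is a minimal self-contained sketch of the same pattern; _compile_kernel is a hypothetical stand-in for the JIT compile step, not part of the diff.

# Minimal sketch of the lazy-load pattern above (illustration only).
# The module-level cache has three states:
#   None  -> load never attempted
#   False -> load attempted and failed; use the pure-Python fallback
#   other -> the successfully loaded kernel
_kernel_cache = None


def _compile_kernel():
    # Hypothetical stand-in for torch.utils.cpp_extension.load(...), which
    # JIT-compiles the C++ source on first use and raises when no compiler
    # toolchain is available.
    raise RuntimeError("no C++ toolchain available")


def _lazy_load_kernel():
    global _kernel_cache
    if _kernel_cache is not None:   # a cached failure (False) also short-circuits here
        return _kernel_cache
    try:
        _kernel_cache = _compile_kernel()
    except Exception:
        _kernel_cache = False       # cache the failure so we never retry
    return _kernel_cache


# Call sites mirror the check in sru_compute_cpu():
kernel = _lazy_load_kernel()
if (kernel is not None) and (kernel is not False):
    print("dispatch to the compiled kernel")
else:
    print("fall back to the pure-Python implementation")

Note that failures are cached as well: after one failed compile, every later call returns False immediately instead of re-running the build on every forward pass.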
