lazy load GPU kernel
taolei87 committed Nov 6, 2018
1 parent 72c68eb commit 25d25db
Showing 1 changed file with 18 additions and 7 deletions.
25 changes: 18 additions & 7 deletions sru/sru_functional.py
@@ -5,12 +5,8 @@
 import torch
 import torch.nn as nn
 from torch.autograd import Variable
-if torch.cuda.device_count() > 0:
-    try:
-        from .cuda_functional import SRU_Compute_GPU
-    except:
-        from cuda_functional import SRU_Compute_GPU
+
 # load C++ implementation for CPU computation
 try:
     from torch.utils.cpp_extension import load
     cpu_source = os.path.join(os.path.dirname(__file__), "sru_cpu_impl.cpp")
@@ -19,8 +15,23 @@
     sru_cpu_impl = None
     warnings.warn("Failed to load the C++ implementation of SRU for CPU inference.")
 
+SRU_GPU_class = None
+
+# load C++ implementation for GPU computation
+def _lazy_load_cuda_class():
+    global SRU_GPU_class
+    if SRU_GPU_class is not None:
+        return SRU_GPU_class
+    try:
+        from .cuda_functional import SRU_Compute_GPU
+        SRU_GPU_class = SRU_Compute_GPU
+    except:
+        from cuda_functional import SRU_Compute_GPU
+        SRU_GPU_class = SRU_Compute_GPU
+    return SRU_GPU_class
+
 
-def SRU_Compute_CPU(activation_type,
+def SRU_CPU_class(activation_type,
                     d,
                     bidirectional=False,
                     has_skip_term=True,
@@ -363,7 +374,7 @@ def forward(self, input, c0=None, mask_pad=None, return_proj=False):
 
         # Pytorch Function() doesn't accept NoneType in forward() call.
        # So we put mask_pad as class attribute as a work around
-        SRU_Compute_Class = SRU_Compute_GPU if input.is_cuda else SRU_Compute_CPU
+        SRU_Compute_Class = _lazy_load_cuda_class() if input.is_cuda else SRU_CPU_class
         SRU_Compute = SRU_Compute_Class(
             self.activation_type, n_out, self.bidirectional, self.has_skip_term,
             scale_val, mask_pad
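For context (not part of the commit itself): the practical effect of this change is that importing sru no longer requires the CUDA kernel to be importable; _lazy_load_cuda_class() defers that import to the first forward() call on a CUDA tensor, and the CPU path never touches it. A minimal usage sketch, assuming the sru package's public SRU module is installed; tensor sizes are arbitrary:

    import torch
    from sru import SRU   # after this commit, import alone does not load cuda_functional

    rnn = SRU(128, 128, num_layers=2)          # input_size, hidden_size
    x = torch.randn(10, 4, 128)                # (seq_len, batch, input_size)
    output, state = rnn(x)                     # CPU path uses SRU_CPU_class; no CUDA import

    if torch.cuda.is_available():
        output, state = rnn.cuda()(x.cuda())   # first CUDA call triggers _lazy_load_cuda_class()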
