μ΄ νν 리μΌμμλ, TorchScriptμμ λμ inter-op λ³λ ¬μ²λ¦¬ λ₯Ό νλ ꡬ문(syntax)μ μκ°ν©λλ€. μ΄ λ³λ ¬μ²λ¦¬μλ λ€μκ³Ό κ°μ μμ±μ΄ μμ΅λλ€:
- λμ (dynamic) - μμ±λ λ³λ ¬ μμ μ μμ μμ λΆνλ νλ‘κ·Έλ¨μ μ μ΄ νλ¦μ λ°λΌ λ¬λΌμ§ μ μμ΅λλ€.
- inter-op - λ³λ ¬ μ²λ¦¬λ TorchScript νλ‘κ·Έλ¨ μ‘°κ°μ λ³λ ¬λ‘ μ€ννλ κ²κ³Ό κ΄λ ¨μ΄ μμ΅λλ€. μ΄λ κ°λ³ μ°μ°μλ₯Ό λΆν νκ³ μ°μ°μ μμ μ νμ μ§ν©μ λ³λ ¬λ‘ μ€ννλ λ°©μμΈ intra-op parallelism μλ ꡬλ³λ©λλ€.
λμ λ³λ ¬ μ²λ¦¬λ₯Ό μν λ κ°μ§ μ€μν APIλ λ€μκ³Ό κ°μ΅λλ€:
torch.jit.fork(fn : Callable[..., T], *args, **kwargs) -> torch.jit.Future[T]
torch.jit.wait(fut : torch.jit.Future[T]) -> T
μ΄λ¬ν μλ λ°©μμ λ€μ μμ μμ μ μ΄ν΄ν μ μμ΅λλ€:
import torch
def foo(x):
return torch.neg(x)
@torch.jit.script
def example(x):
# λ³λ ¬μ μΌλ‘ `foo` λ₯Ό νΈμΆν©λλ€.
# λ¨Όμ , μμ
μ "fork" ν©λλ€. μ΄ μμ
μ `x` μΈμ(argument)μ ν¨κ» `foo` λ₯Ό μ€νν©λλ€.
future = torch.jit.fork(foo, x)
# μΌλ°μ μΌλ‘ `foo`λ₯Ό νΈμΆν©λλ€.
x_normal = foo(x)
# λμ§Έ, μμ
μ "κΈ°λ€λ¦½λλ€".
# μμ
μ΄ λ³λ ¬λ‘ μ€ν μ€μΌ μ μμΌλ―λ‘ κ²°κ³Όλ₯Ό μ¬μ©ν μ μμ λκΉμ§ "λκΈ°" ν΄μΌν©λλ€.
# κ³μ°μ λ³λ ¬λ‘ μννκΈ° μν΄μ
# "fork()" μ "wait()" μ¬μ΄μμ
# Futureλ₯Ό νΈμΆνλ μ μ μ μνμΈμ.
x_parallel = torch.jit.wait(future)
return x_normal, x_parallel
print(example(torch.ones(1))) # (-1., -1.)
fork()
λ νΈμΆ κ°λ₯ν(callable) fn
,κ·Έμ λν νΈμΆ κ°λ₯ν μΈμ args
λ° kwargs
λ₯Ό μ·¨νκ³ fn
μ€νμ μν λΉλκΈ°(asynchronous) μμ
μ μμ±ν©λλ€.
fn
μ ν¨μ, λ©μλ, λλ λͺ¨λ μΈμ€ν΄μ€μΌ μ μμ΅λλ€.
fork()
λ Future
λΌκ³ λΆλ¦¬λ μ΄ μ€ν κ²°κ³Όμ κ°μ λν μ°Έμ‘°(reference)λ₯Ό λ°νν©λλ€.
fork
λ λΉλκΈ° μμ
μ μμ±ν μ§νμ λ°νλκΈ° λλ¬Έμ,
fork()
νΈμΆ ν μ½λ λΌμΈμ΄ μ€νλ λκΉμ§ fn
μ΄ μ€νλμ§ μμ μ μμ΅λλ€.
λ°λΌμ, wait()
μ λΉλκΈ° μμ
μ΄ μλ£ λ λκΉμ§ λκΈ°νκ³ κ°μ
λ°ννλλ° μ¬μ©λ©λλ€.
μ΄λ¬ν ꡬ쑰λ ν¨μ λ΄μμ λͺ λ Ήλ¬Έ μ€νμ μ€μ²©νκ±°λ μμ λ (μμ μΉμ μ νμλ¨) 루νμ κ°μ λ€λ₯Έ μΈμ΄ κ΅¬μ‘°λ‘ κ΅¬μ± λ μ μμ΅λλ€:
import torch
from typing import List
def foo(x):
return torch.neg(x)
@torch.jit.script
def example(x):
futures : List[torch.jit.Future[torch.Tensor]] = []
for _ in range(100):
futures.append(torch.jit.fork(foo, x))
results = []
for future in futures:
results.append(torch.jit.wait(future))
return torch.sum(torch.stack(results))
print(example(torch.ones([])))
Note
Futureμ λΉ λ¦¬μ€νΈ(list)λ₯Ό μ΄κΈ°νν λ, λͺ
μμ μΈ μ ν μ£Όμμ futures
μ μΆκ°ν΄μΌ νμ΅λλ€.
TorchScriptμμ λΉ μ»¨ν
μ΄λ(container)λ κΈ°λ³Έμ μΌλ‘ tensor κ°μ ν¬ν¨νλ€κ³ κ°μ νλ―λ‘
리μ€νΈ μμ±μ(constructor) #μ
List[torch.jit.Future[torch.Tensor]]
μ νμ μ£Όμμ λ¬μμ΅λλ€.
μ΄ μμ λ fork()
λ₯Ό μ¬μ©νμ¬ ν¨μ foo
μ μΈμ€ν΄μ€ 100κ°λ₯Ό μμνκ³ , 100κ°μ μμ
μ΄ μλ£ λ λκΉμ§
λκΈ°ν λ€μ, κ²°κ³Όλ₯Ό ν©μ°νμ¬ -100.0
μ λ°νν©λλ€.
λ³΄λ€ νμ€μ μΈ μμμ λ³λ ¬νλ₯Ό μ μ©νκ³ μ΄λ€ μ±λ₯μ μ»μ μ μλμ§ μ΄ν΄λ΄ μλ€. λ¨Όμ , μλ°©ν₯ LSTM κ³μΈ΅μ μμλΈμΈ κΈ°μ€ λͺ¨λΈμ μ μν©μλ€.
import torch, time
# RNN μ©μ΄μμλ μ°λ¦¬κ° κ΄μ¬ κ°λ μ°¨μλ€μ μλμ κ°μ΄ λΆλ¦
λλ€:
# λ¨μμκ°μ κ°―μ (T)
# λ°°μΉ ν¬κΈ° (B)
# "channels"μ μ¨κ²¨μ§ ν¬κΈ°/μ«μ (C)
T, B, C = 50, 50, 1024
# λ¨μΌ "μλ°©ν₯ LSTM"μ μ μνλ λͺ¨λμ
λλ€.
# μ΄λ λ¨μν λμΌν μνμ€μ μ μ©λ λ κ°μ LSTMμ΄μ§λ§ νλλ λ°λλ‘ μ μ©λ©λλ€.
class BidirectionalRecurrentLSTM(torch.nn.Module):
def __init__(self):
super().__init__()
self.cell_f = torch.nn.LSTM(input_size=C, hidden_size=C)
self.cell_b = torch.nn.LSTM(input_size=C, hidden_size=C)
def forward(self, x : torch.Tensor) -> torch.Tensor:
# Forward κ³μΈ΅
output_f, _ = self.cell_f(x)
# Backward κ³μΈ΅. μκ° μ°¨μ(time dimension)(dim 0)μμ μ
λ ₯μ flip (dim 0),
# κ³μΈ΅ μ μ©νκ³ , μκ° μ°¨μμμ μΆλ ₯μ flip ν©λλ€.
x_rev = torch.flip(x, dims=[0])
output_b, _ = self.cell_b(torch.flip(x, dims=[0]))
output_b_rev = torch.flip(output_b, dims=[0])
return torch.cat((output_f, output_b_rev), dim=2)
# `BidirectionalRecurrentLSTM` λͺ¨λμ "ensemble"μ
λλ€.
# μμλΈμ λͺ¨λμ κ°μ μ
λ ₯μΌλ‘ νλνλμ© μ€νλκ³ ,
# λμ λκ³ ν©μ°λ κ²°κ³Όλ₯Ό λ°νν©λλ€.
class LSTMEnsemble(torch.nn.Module):
def __init__(self, n_models):
super().__init__()
self.n_models = n_models
self.models = torch.nn.ModuleList([
BidirectionalRecurrentLSTM() for _ in range(self.n_models)])
def forward(self, x : torch.Tensor) -> torch.Tensor:
results = []
for model in self.models:
results.append(model(x))
return torch.stack(results).sum(dim=0)
# fork/waitμΌλ‘ μ€νν κ²λ€μ μ§μ λΉκ΅λ₯Ό μν΄
# λͺ¨λμ μΈμ€ν΄μ€ννκ³ TorchScriptλ₯Ό ν΅ν΄ μ»΄νμΌν΄ λ΄
μλ€.
ens = torch.jit.script(LSTMEnsemble(n_models=4))
# μΌλ°μ μΌλ‘ μλ² λ© ν
μ΄λΈ(embedding table)μμ μ
λ ₯μ κ°μ Έμ€μ§λ§,
# λ°λͺ¨λ₯Ό μν΄ μ¬κΈ°μλ 무μμ λ°μ΄ν°λ₯Ό μ¬μ©νκ² μ΅λλ€.
x = torch.rand(T, B, C)
# λ©λͺ¨λ¦¬ ν λΉμ(memory allocator) λ±μ μ€λΉμν€κΈ° μν΄ λͺ¨λΈμ λ¨Όμ νλ² μ€νν©λλ€.
ens(x)
x = torch.rand(T, B, C)
# μΌλ§λ λΉ λ₯΄κ² μ€νλλμ§ λ΄
μλ€!
s = time.time()
ens(x)
print('Inference took', time.time() - s, ' seconds')
μ μ»΄ν¨ν°μμλ λ€νΈμν¬κ° 2.05
μ΄ λ§μ μ€νλμμ΅λλ€. ν¨μ¬ λ λΉ λ₯΄κ² ν μ μμ΅λλ€!
κ°λ¨νκ² ν μ μλ μΌλ‘λ BidirectionalRecurrentLSTM
λ΄μμ forward, backward κ³μΈ΅λ€μ λ³λ ¬ννλ κ²μ΄ μμ΅λλ€.
μ΄ λ, κ³μ° ꡬ쑰λ κ³ μ λμ΄ μμΌλ―λ‘ μ°λ¦¬λ μ΄λ€ 루νλ νμλ‘ νμ§ μμ΅λλ€.
BidirectionalRecurrentLSTM
μ forward
λ©μλλ₯Ό λ€μκ³Ό κ°μ΄ μ¬μμ±ν΄λ΄
μλ€:
def forward(self, x : torch.Tensor) -> torch.Tensor:
# Backward κ³μΈ΅κ³Ό λ³λ ¬λ‘ μ€νμν€κΈ° μν΄ forward layerλ₯Ό fork()λ₯Ό νλ€.
future_f = torch.jit.fork(self.cell_f, x)
# Backward κ³μΈ΅. μκ° μ°¨μ(time dimension)(dim 0)μμ μ
λ ₯μ flip (dim 0),
# κ³μΈ΅μ μ μ©νκ³ , κ·Έλ¦¬κ³ μκ° μ°¨μμμ μΆλ ₯μ flip ν©λλ€.
x_rev = torch.flip(x, dims=[0])
output_b, _ = self.cell_b(torch.flip(x, dims=[0]))
output_b_rev = torch.flip(output_b, dims=[0])
# Forward κ³μΈ΅μμ μΆλ ₯μ λ°μμ΅λλ€.
# μ΄λ μ°λ¦¬κ° λ³λ ¬ννλ €λ μμ
*μ΄ν*μ μΌμ΄λμΌ ν¨μ μ£Όμν΄μΌ ν©λλ€.
output_f, _ = torch.jit.wait(future_f)
return torch.cat((output_f, output_b_rev), dim=2)
μ΄ μμμμ, forward()
λ cell_b
μ μ€νμ κ³μνλ λμ
cell_f
λ₯Ό λ€λ₯Έ μ€λ λλ‘ μμν©λλ€.
μ΄λ‘ μΈν΄ λ μ
μ μ€νμ΄ μλ‘ κ²ΉμΉ©λλ€.
μ΄ κ°λ¨ν μμ νμ μ€ν¬λ¦½νΈλ₯Ό λ€μ μ€ννλ©΄
17%
ν₯μλ 1.71
μ΄μ λ°νμμ΄ λμ΅λλ€!
μμ§ λͺ¨λΈ μ΅μ νκ° λλμ§ μμμ§λ§ μ΄μ―€μμ μ±λ₯ μκ°νλ₯Ό μν λꡬλ₯Ό λμ ν΄λ΄ μλ€. ν κ°μ§ μ€μν λꡬλ PyTorch νλ‘νμΌλ¬(profiler) μ λλ€.
Chromeμ μΆμ λ΄λ³΄λ΄κΈ° κΈ°λ₯(trace export functionality)κ³Ό ν¨κ» νλ‘νμΌλ¬λ₯Ό μ¬μ©ν΄ λ³λ ¬νλ λͺ¨λΈμ μ±λ₯μ μκ°νν΄λ΄ μλ€:
with torch.autograd.profiler.profile() as prof:
ens(x)
prof.export_chrome_trace('parallel.json')
μ΄ μμ μ½λ μ‘°κ°μ parallel.json
νμΌμ μμ±ν©λλ€.
Google Chromeμμ chrome://tracing
μΌλ‘ μ΄λνμ¬ Load
λ²νΌμ ν΄λ¦νκ³
JSON νμΌμ λ‘λνλ©΄ λ€μκ³Ό κ°μ νμλΌμΈμ λ³΄κ² λ κ²λλ€:
νμλΌμΈμ κ°λ‘μΆμ μκ°μ, μΈλ‘μΆμ μ€ν μ€λ λλ₯Ό λνλ
λλ€.
보λ€μνΌ ν λ²μ λ κ°μ lstm
μ μ€ννκ³ μμ΅λλ€.
μ΄κ²μ μλ°©ν₯(forward, backward) κ³μΈ΅μ λ³λ ¬ννκΈ° μν΄
λ
Έλ ₯ν κ²°κ³Όμ
λλ€!
μ΄ μ½λμ λ λ§μ λ³λ ¬ν κΈ°νκ° μλ€λ κ²μ λμΉμ±μμ§λ λͺ¨λ¦
λλ€:
LSTMEnsemble
μ ν¬ν¨λ λͺ¨λΈλ€μ μλ‘ λ³λ ¬λ‘ μ€νν μλ μμ΅λλ€.
μ΄λ κ² νκΈ° μν λ°©λ²μ μμ£Ό κ°λ¨ν©λλ€.
λ°λ‘ LSTMEnsemble
μ forward
λ©μλλ₯Ό λ³κ²½νλ λ°©λ²μ
λλ€:
def forward(self, x : torch.Tensor) -> torch.Tensor:
# κ° λͺ¨λΈμ μν μμ
μ€νν©λλ€.
futures : List[torch.jit.Future[torch.Tensor]] = []
for model in self.models:
futures.append(torch.jit.fork(model, x))
# μ€νλ μμ
λ€μμ κ²°κ³Ό μμ§ν©λλ€.
results : List[torch.Tensor] = []
for future in futures:
results.append(torch.jit.wait(future))
return torch.stack(results).sum(dim=0)
λλ, λ§μ½ κ°κ²°ν¨μ μ€μνκ² μκ°νλ€λ©΄ 리μ€νΈ μ»΄ν리ν¨μ (list comprehension)μ μ¬μ©ν μ μμ΅λλ€.
def forward(self, x : torch.Tensor) -> torch.Tensor:
futures = [torch.jit.fork(model, x) for model in self.models]
results = [torch.jit.wait(fut) for fut in futures]
return torch.stack(results).sum(dim=0)
μλμμ μ€λͺ νλ―μ΄, μ°λ¦¬λ 루νλ₯Ό μ¬μ©ν΄ μμλΈμ κ° λͺ¨λΈλ€μ λν μμ μ λλ΄μ΅λλ€. κ·Έλ¦¬κ³ λͺ¨λ μμ μ΄ μλ£λ λκΉμ§ κΈ°λ€λ¦΄ λ€λ₯Έ 루νλ₯Ό μ¬μ©νμ΅λλ€. μ΄λ λ λ§μ κ³μ°μ μ€λ²λ©μ μ 곡ν©λλ€.
μ΄ μμ μ
λ°μ΄νΈλ‘ μ€ν¬λ¦½νΈλ 1.4
μ΄μ μ€νλμ΄ μ΄ 32%
λ§νΌ μλκ° ν₯μλμμ΅λλ€!
λ¨ λ μ€λ§μ μ’μ ν¨κ³Όλ₯Ό 보μμ΅λλ€.
λν Chrome μΆμ κΈ°(tracer)λ₯Ό λ€μ μ¬μ©ν΄ μ§ν μν©μ λ³Ό μ μμ΅λλ€:
μ΄μ λͺ¨λ LSTM
μΈμ€ν΄μ€κ° μμ ν λ³λ ¬λ‘ μ€νλλ κ²μ λ³Ό μ μμ΅λλ€.
μ΄ νν 리μΌμμ μ°λ¦¬λ TorchScriptμμ λμ (dynamic), inter-op λ³λ ¬ μ²λ¦¬λ₯Ό μννκΈ° μν κΈ°λ³Έ APIμΈ
fork()
μ wait()
μ λν΄ λ°°μ μ΅λλ€.
μ΄λ¬ν ν¨μλ€μ μ¬μ©ν΄ TorchScript μ½λμμ ν¨μ, λ©μλ, λλ
Modules
μ μ€νμ λ³λ ¬ννλ λͺ κ°μ§ μΌλ°μ μΈ μ¬μ© ν¨ν΄λ 보μμ΅λλ€.
λ§μ§λ§μΌλ‘, μ΄ κΈ°μ μ μ¬μ©ν΄ λͺ¨λΈμ μ΅μ ννλ μλ₯Ό νμ΄λ³΄κ³ , PyTorchμμ μ¬μ© κ°λ₯ν
μ±λ₯ μΈ‘μ λ° μκ°ν λꡬλ₯Ό μ΄ν΄λ³΄μμ΅λλ€.