
implement zip lookaside in Python interpreter (enables e.g. thunder.jit with zip from LitGPT LLaMAMoE) #284

Open
IvanYashchuk opened this issue Apr 26, 2024 · 6 comments
Labels: enhancement, good first issue, jit

🐛 Bug

Here's a simplified version of LitGPT's LLaMAMoE (without data-dependent shapes); it fails somewhere in the general jit with:

NotImplementedError: unpacking from OPAQUE <slot wrapper '__next__' of 'zip' objects> ProvenanceRecord(

To reproduce:

import torch
import thunder
from torch import nn

class Test(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.n_expert = 8
        self.n_expert_per_token = 2
        self.C = 2
        self.gate = nn.Linear(self.C, self.n_expert, bias=False)
        self.experts = nn.ModuleList(nn.Linear(2, 2, bias=False) for _ in range(self.n_expert))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
        x = x.view(-1, C)  # (B*T, C)
        router = self.gate(x)  # (B*T, n_expert)
        probs, indices = torch.topk(router, self.n_expert_per_token)  # (B*T, n_expert_per_token)
        probs = probs.softmax(dim=1, dtype=torch.float).to(dtype=x.dtype)
        masks = indices.unsqueeze(-1) == torch.arange(self.n_expert, device=x.device)
        masks = masks.permute(2, 0, 1)  # (n_expert, B*T, n_expert_per_token)
        y = torch.zeros_like(x)  # (B*T, C)
        for (mask, expert) in zip(masks, self.experts):
            # data-independent stand-in for the original's torch.where(mask)
            token_idx, expert_idx = torch.arange(B*T, device=x.device), torch.arange(B*T, device=x.device)
            pprobs = probs[token_idx, expert_idx]
            pprobs = pprobs.unsqueeze(-1)
            eexpert = expert(x[token_idx])
            y = torch.index_add(y, 0, token_idx, pprobs * eexpert)
        return y.view(B, T, C)

model = Test()
model = thunder.jit(model)

x = torch.randn(2, 3, 2)
y = model(x)

raises:

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1272, in unpack_inputs.<locals>.unpack(v)
   1271 try:
-> 1272     from_provenance(p.history)
   1273 except Exception as e:

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1264, in unpack_inputs.<locals>.unpack.<locals>.from_provenance(provenance, new_output)
   1263     raise NotImplementedError(f"Unpacking from {inst} {provenance}")
-> 1264 res = unpack_fn(provenance, new_output=new_output)
   1265 provenance.proxy = res

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1195, in unpack_inputs.<locals>.unpack.<locals>.from_binary_subscr(provenance, new_output)
   1194 def from_binary_subscr(provenance, *, new_output=False):
-> 1195     inputs = [from_provenance(i, new_output=True) for i in provenance.inputs]
   1196     obj, idx = inputs

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1195, in <listcomp>(.0)
   1194 def from_binary_subscr(provenance, *, new_output=False):
-> 1195     inputs = [from_provenance(i, new_output=True) for i in provenance.inputs]
   1196     obj, idx = inputs

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1264, in unpack_inputs.<locals>.unpack.<locals>.from_provenance(provenance, new_output)
   1263     raise NotImplementedError(f"Unpacking from {inst} {provenance}")
-> 1264 res = unpack_fn(provenance, new_output=new_output)
   1265 provenance.proxy = res

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1195, in unpack_inputs.<locals>.unpack.<locals>.from_binary_subscr(provenance, new_output)
   1194 def from_binary_subscr(provenance, *, new_output=False):
-> 1195     inputs = [from_provenance(i, new_output=True) for i in provenance.inputs]
   1196     obj, idx = inputs

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1195, in <listcomp>(.0)
   1194 def from_binary_subscr(provenance, *, new_output=False):
-> 1195     inputs = [from_provenance(i, new_output=True) for i in provenance.inputs]
   1196     obj, idx = inputs

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1264, in unpack_inputs.<locals>.unpack.<locals>.from_provenance(provenance, new_output)
   1263     raise NotImplementedError(f"Unpacking from {inst} {provenance}")
-> 1264 res = unpack_fn(provenance, new_output=new_output)
   1265 provenance.proxy = res

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1178, in unpack_inputs.<locals>.unpack.<locals>.from_load_attr(provenance, new_output)
   1177 is_pure = False
-> 1178 inputs = [from_provenance(i, new_output=True) for i in provenance.inputs]
   1179 if new_output:

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1178, in <listcomp>(.0)
   1177 is_pure = False
-> 1178 inputs = [from_provenance(i, new_output=True) for i in provenance.inputs]
   1179 if new_output:

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1264, in unpack_inputs.<locals>.unpack.<locals>.from_provenance(provenance, new_output)
   1263     raise NotImplementedError(f"Unpacking from {inst} {provenance}")
-> 1264 res = unpack_fn(provenance, new_output=new_output)
   1265 provenance.proxy = res

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1195, in unpack_inputs.<locals>.unpack.<locals>.from_binary_subscr(provenance, new_output)
   1194 def from_binary_subscr(provenance, *, new_output=False):
-> 1195     inputs = [from_provenance(i, new_output=True) for i in provenance.inputs]
   1196     obj, idx = inputs

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1195, in <listcomp>(.0)
   1194 def from_binary_subscr(provenance, *, new_output=False):
-> 1195     inputs = [from_provenance(i, new_output=True) for i in provenance.inputs]
   1196     obj, idx = inputs

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1264, in unpack_inputs.<locals>.unpack.<locals>.from_provenance(provenance, new_output)
   1263     raise NotImplementedError(f"Unpacking from {inst} {provenance}")
-> 1264 res = unpack_fn(provenance, new_output=new_output)
   1265 provenance.proxy = res

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1241, in unpack_inputs.<locals>.unpack.<locals>.from_opaque(provenance, new_output)
   1232     return from_provenance(
   1233         ProvenanceRecord(
   1234             PseudoInst.LOAD_ATTR,
   (...)
   1239         )
   1240     )
-> 1241 raise NotImplementedError(f"unpacking from OPAQUE {fn.value} {provenance}")

NotImplementedError: unpacking from OPAQUE <slot wrapper '__next__' of 'zip' objects> ProvenanceRecord(
  i1 = INPUT_FN()
  i2 = LOAD_ATTR(i1, '__dict__')
  i3 = BINARY_SUBSCR(i2, '_modules')
  i4 = BINARY_SUBSCR(i3, 'experts')
  i5 = INPUT_ARGS()
  i6 = BINARY_SUBSCR(i5, 0)
  i7 = LOAD_ATTR(i6, '__getattr__')
  i8 = LOAD_ATTR(i7, '__func__')
  i9 = Instruction(opname='CALL_FUNCTION_KW', opcode=141, arg=2, argval=2, argrepr='', offset=102, starts_line=None, is_jump_target=False)()
  i10 = LOAD_ATTR(i1, 'n_expert_per_token')
  i11 = BINARY_SUBSCR(i3, 'gate')
  i12 = LOAD_ATTR(i11, '__dict__')
  i13 = BINARY_SUBSCR(i12, '_parameters')
  i14 = BINARY_SUBSCR(i13, 'bias')
  i15 = BINARY_SUBSCR(i13, 'weight')
  i16 = BUILD_TUPLE('view', i6)
  i17 = OPAQUE(i8, i16, CONSTANT({}))
  i18 = LOAD_ATTR(i17, 'func')
  i19 = Instruction(opname='BINARY_ADD', opcode=23, arg=None, argval=None, argrepr='', offset=10, starts_line=None, is_jump_target=False)()
  i20 = BINARY_SUBSCR(i19, 2)
  i21 = BINARY_SUBSCR(i19, 1)
  i22 = LOAD_ATTR(i17, 'args')
  i23 = BINARY_SUBSCR(i22, 0)
  i24 = BUILD_TUPLE(i20, i21, i23)
  i25 = OPAQUE(i18, i24, CONSTANT({}))
  i26 = BUILD_TUPLE(i14, i15, i25)
  i27 = OPAQUE(CONSTANT([Symbol name=linear]), i26, CONSTANT({}))
  i28 = BUILD_TUPLE(i10, i27)
  i29 = OPAQUE(CONSTANT([Symbol name=topk]), i28, CONSTANT({}))
  i30 = BINARY_SUBSCR(i29, 1)
  i31 = BUILD_TUPLE('unsqueeze', i30)
  i32 = OPAQUE(i8, i31, CONSTANT({}))
  i33 = LOAD_ATTR(i32, 'func')
  i34 = BUILD_TUPLE(i21, i30)
  i35 = OPAQUE(i33, i34, CONSTANT({}))
  i36 = Instruction(opname='COMPARE_OP', opcode=107, arg=2, argval='==', argrepr='==', offset=104, starts_line=None, is_jump_target=False)(i9, i35)
  i37 = BUILD_TUPLE('permute', i36)
  i38 = OPAQUE(i8, i37, CONSTANT({}))
  i39 = LOAD_ATTR(i38, 'func')
  i40 = BINARY_SUBSCR(i19, 3)
  i41 = BUILD_TUPLE(i40, i20, i21, i36)
  i42 = OPAQUE(i39, i41, CONSTANT({}))
  i43 = LOAD_ATTR(i1, 'forward')
  i44 = LOAD_ATTR(i43, '__func__')
  i45 = LOAD_ATTR(i44, '__globals__')
  i46 = BINARY_SUBSCR(i45, '__builtins__')
  i47 = LOAD_ATTR(i46, 'zip')
  i48 = BUILD_TUPLE(i4, i42, i47)
  i49 = OPAQUE(CONSTANT(<built-in method __new__ of type object at 0x55c1c13de340>), i48, CONSTANT({}))
  i50 = BUILD_TUPLE(i49)
  i51 = OPAQUE(CONSTANT(<slot wrapper '__next__' of 'zip' objects>), i50, CONSTANT({}))
)

The above exception was the direct cause of the following exception:

NotImplementedError                       Traceback (most recent call last)
Cell In[1], line 35
     32 model = thunder.jit(model)
     34 x = torch.randn(2, 3, 2)
---> 35 y = model(x)

File ~/miniforge3/envs/pytorch-cuda-dev/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/miniforge3/envs/pytorch-cuda-dev/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/dev/lightning-thunder/thunder/__init__.py:209, in ThunderModule.forward(self, *args, **kwargs)
    208 def forward(self, *args, **kwargs):
--> 209     res = self._forward_fn(*args, **kwargs)
    210     return res

File ~/dev/lightning-thunder/thunder/__init__.py:661, in jit.<locals>.fn_(*args, **kwargs)
    658 cs.last_trace_host_start = time.time_ns()
    659 cs.calls += 1
--> 661 cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
    662 cs.last_trace_host_execution_start = time.time_ns()
    664 result = cache_entry.computation_fn(*inps)

File ~/dev/lightning-thunder/thunder/__init__.py:277, in _with_cache_info_ctx.<locals>.cache_info_wrapper(*args, **kwargs)
    275 tok = _cache_info_ctx.set({})
    276 try:
--> 277     res = fn(*args, **kwargs)
    278 finally:
    279     _cache_info_ctx.reset(tok)

File ~/dev/lightning-thunder/thunder/__init__.py:538, in jit.<locals>.get_computation_and_inputs(*args, **kwargs)
    536 prologue_trc: TraceCtx
    537 computation_trc: TraceCtx
--> 538 jit_results: TraceResults = interpreter(
    539     fn, args, kwargs, record_history=record_history, sharp_edges=cd.sharp_edges
    540 )
    541 prologue_trc = jit_results.prologue_trace
    542 computation_trc = jit_results.computation_trace

File ~/dev/lightning-thunder/thunder/__init__.py:190, in _general_frontend(fn, args, kwargs, record_history, sharp_edges)
    181 def _general_frontend(
    182     fn: Callable,
    183     args: tuple[Any, ...],
   (...)
    188     sharp_edges: SHARP_EDGES_OPTIONS,
    189 ) -> TraceResults:
--> 190     return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges, record_history=record_history)

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1481, in thunder_general_jit(fn, args, kwargs, record_history, sharp_edges)
   1478 else:
   1479     epilogue_trace = None
-> 1481 pro_to_comp_proxies, pro_to_epi_proxies = unpack_inputs(
   1482     ctx, prologue_trace, pro_to_comp, pro_to_epi, args, kwargs, has_epilogue=epilogue_trace is not None
   1483 )
   1485 proxy_order = {id(p): i for i, p in enumerate(pro_to_comp_proxies)}
   1486 pro_to_comp = tuple(sorted(pro_to_comp, key=lambda v: proxy_order[id(v.proxy)]))

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1301, in unpack_inputs(ctx, prologue_trace, pro_to_comp_inps, pro_to_epi_inps, args, kwargs, has_epilogue)
   1298             pro_kwargs_proxy = output
   1300 pro_to_epi = tuple(sorted((unpack(v) for v in pro_to_epi_inps), key=lambda x: param_ordering[id(x)][1]))
-> 1301 pro_to_comp = tuple(sorted((unpack(v) for v in pro_to_comp_inps), key=lambda x: param_ordering[id(x)][1]))
   1303 with tracectx(prologue_trace):
   1304     for prim, *args in ctx._constraints:

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1301, in <genexpr>(.0)
   1298             pro_kwargs_proxy = output
   1300 pro_to_epi = tuple(sorted((unpack(v) for v in pro_to_epi_inps), key=lambda x: param_ordering[id(x)][1]))
-> 1301 pro_to_comp = tuple(sorted((unpack(v) for v in pro_to_comp_inps), key=lambda x: param_ordering[id(x)][1]))
   1303 with tracectx(prologue_trace):
   1304     for prim, *args in ctx._constraints:

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1274, in unpack_inputs.<locals>.unpack(v)
   1272         from_provenance(p.history)
   1273     except Exception as e:
-> 1274         raise NotImplementedError(f"Exception occured unpacking object from {p.history}") from e
   1276 already_unpacked[id(p)] = p
   1278 # Adds cache constraints
   1279 # TODO Consider refactoring these contraints
   1280 # TODO Constrain on rank, device, and dtype

NotImplementedError: Exception occured unpacking object from ProvenanceRecord(
  i1 = INPUT_FN()
  i2 = LOAD_ATTR(i1, '__dict__')
  i3 = BINARY_SUBSCR(i2, '_modules')
  i4 = BINARY_SUBSCR(i3, 'experts')
  i5 = INPUT_ARGS()
  i6 = BINARY_SUBSCR(i5, 0)
  i7 = LOAD_ATTR(i6, '__getattr__')
  i8 = LOAD_ATTR(i7, '__func__')
  i9 = Instruction(opname='CALL_FUNCTION_KW', opcode=141, arg=2, argval=2, argrepr='', offset=102, starts_line=None, is_jump_target=False)()
  i10 = LOAD_ATTR(i1, 'n_expert_per_token')
  i11 = BINARY_SUBSCR(i3, 'gate')
  i12 = LOAD_ATTR(i11, '__dict__')
  i13 = BINARY_SUBSCR(i12, '_parameters')
  i14 = BINARY_SUBSCR(i13, 'bias')
  i15 = BINARY_SUBSCR(i13, 'weight')
  i16 = BUILD_TUPLE('view', i6)
  i17 = OPAQUE(i8, i16, CONSTANT({}))
  i18 = LOAD_ATTR(i17, 'func')
  i19 = Instruction(opname='BINARY_ADD', opcode=23, arg=None, argval=None, argrepr='', offset=10, starts_line=None, is_jump_target=False)()
  i20 = BINARY_SUBSCR(i19, 2)
  i21 = BINARY_SUBSCR(i19, 1)
  i22 = LOAD_ATTR(i17, 'args')
  i23 = BINARY_SUBSCR(i22, 0)
  i24 = BUILD_TUPLE(i20, i21, i23)
  i25 = OPAQUE(i18, i24, CONSTANT({}))
  i26 = BUILD_TUPLE(i14, i15, i25)
  i27 = OPAQUE(CONSTANT([Symbol name=linear]), i26, CONSTANT({}))
  i28 = BUILD_TUPLE(i10, i27)
  i29 = OPAQUE(CONSTANT([Symbol name=topk]), i28, CONSTANT({}))
  i30 = BINARY_SUBSCR(i29, 1)
  i31 = BUILD_TUPLE('unsqueeze', i30)
  i32 = OPAQUE(i8, i31, CONSTANT({}))
  i33 = LOAD_ATTR(i32, 'func')
  i34 = BUILD_TUPLE(i21, i30)
  i35 = OPAQUE(i33, i34, CONSTANT({}))
  i36 = Instruction(opname='COMPARE_OP', opcode=107, arg=2, argval='==', argrepr='==', offset=104, starts_line=None, is_jump_target=False)(i9, i35)
  i37 = BUILD_TUPLE('permute', i36)
  i38 = OPAQUE(i8, i37, CONSTANT({}))
  i39 = LOAD_ATTR(i38, 'func')
  i40 = BINARY_SUBSCR(i19, 3)
  i41 = BUILD_TUPLE(i40, i20, i21, i36)
  i42 = OPAQUE(i39, i41, CONSTANT({}))
  i43 = LOAD_ATTR(i1, 'forward')
  i44 = LOAD_ATTR(i43, '__func__')
  i45 = LOAD_ATTR(i44, '__globals__')
  i46 = BINARY_SUBSCR(i45, '__builtins__')
  i47 = LOAD_ATTR(i46, 'zip')
  i48 = BUILD_TUPLE(i4, i42, i47)
  i49 = OPAQUE(CONSTANT(<built-in method __new__ of type object at 0x55c1c13de340>), i48, CONSTANT({}))
  i50 = BUILD_TUPLE(i49)
  i51 = OPAQUE(CONSTANT(<slot wrapper '__next__' of 'zip' objects>), i50, CONSTANT({}))
  i52 = BINARY_SUBSCR(i51, 1)
  i53 = LOAD_ATTR(i52, '__dict__')
  i54 = BINARY_SUBSCR(i53, '_parameters')
  i55 = BINARY_SUBSCR(i54, 'weight')
)
IvanYashchuk added the bug and jit labels Apr 26, 2024
t-vi (Collaborator) commented Apr 26, 2024

Thank you @IvanYashchuk. The underlying issue is "need lookaside for zip in interpreter".
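
A lookaside here means substituting a pure-Python equivalent for the opaque C builtin, so the interpreter can step through it instead of hitting the opaque zip.__next__ slot wrapper seen in the provenance above. A minimal sketch of what that equivalent could look like (the name _zip_impl is hypothetical, and the registration plumbing is omitted):

def _zip_impl(*iterables):
    # Mirrors builtins.zip: yield tuples of one item per iterable,
    # stopping as soon as the shortest iterable is exhausted.
    iterators = [iter(it) for it in iterables]
    while iterators:
        items = []
        for it in iterators:
            try:
                items.append(next(it))
            except StopIteration:
                return
        yield tuple(items)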

t-vi added the enhancement and good first issue labels and removed the bug label Apr 26, 2024
t-vi changed the title from "thunder.jit with zip from LitGPT LLaMAMoE" to "implement zip lookaside in Python interpreter (enables e.g. thunder.jit with zip from LitGPT LLaMAMoE)" Apr 26, 2024
t-vi (Collaborator) commented Apr 26, 2024

Seems that this is a good issue for someone who wants to take a look at our great Python interpreter (thunder/core/interpreter.py); it's not trivial, but it should be relatively self-contained.
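
For orientation, the interpreter keeps a mapping from callables to their lookasides; the wiring could then look roughly like the following (helper names such as _interpret_call and _default_lookaside_map are assumptions about that mechanism, not its confirmed API):

# Hypothetical wiring: intercept calls to builtins.zip and run the
# pure-Python _zip_impl from the sketch above through the interpreter.
def _zip_lookaside(*iterables):
    return _interpret_call(_zip_impl, *iterables)

_default_lookaside_map[zip] = _zip_lookaside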

lantiga (Collaborator) commented Apr 26, 2024

@t-vi what would be a similar lookaside to start from for anyone wanting to approach this?

nikitaved (Contributor) commented

I would assume that functools.reduce is not that bad to look at, specifically because of the test coverage (generic iterables, custom iterables, etc.).
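
For reference, the pure-Python semantics such a lookaside would mirror for functools.reduce (a sketch of the builtin's documented behavior, not thunder's actual lookaside):

_SENTINEL = object()

def _reduce_impl(function, iterable, initializer=_SENTINEL):
    # Mirrors functools.reduce: left-fold function over iterable,
    # optionally seeded with initializer.
    it = iter(iterable)
    if initializer is _SENTINEL:
        try:
            value = next(it)
        except StopIteration:
            raise TypeError("reduce() of empty iterable with no initial value") from None
    else:
        value = initializer
    for element in it:
        value = function(value, element)
    return value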

IvanYashchuk (Collaborator, Author) commented

Doesn't the error message NotImplementedError: unpacking from OPAQUE <slot wrapper '__next__' of 'zip' objects> mean that the problem is in the interpretation of __next__ and not in zip itself?

riccardofelluga (Collaborator) commented

@t-vi to me it looks like the errors are in the unpacking rather than in zip itself; might it be the opaque ModuleList container?
