Skip to content

[Feat] Support dataclass in magi_register_custom_op#32

Merged
jiahy0825 merged 7 commits into
SandAI-org:mainfrom
themistbeforedawn:feat/magi-register-op-with-dataclass-input
May 19, 2026
Merged

[Feat] Support dataclass in magi_register_custom_op#32
jiahy0825 merged 7 commits into
SandAI-org:mainfrom
themistbeforedawn:feat/magi-register-op-with-dataclass-input

Conversation

@themistbeforedawn
Copy link
Copy Markdown
Collaborator

@themistbeforedawn themistbeforedawn commented May 13, 2026

🗂️ PR Category

  • ✨ New Feature
  • 🚀 Optimization (performance, memory, etc.)
  • 💥 Breaking Change
  • 🐛 Bug Fix
  • 🛠️ Development / Refactoring
  • 📚 Documentation
  • 🧹 Chore (Dependencies, CI/CD, Configuration, etc.)
  • 🧪 Testing

📝 Description

What's new

@magi_register_custom_op now accepts frozen-dataclass parameters (recursively nested), so users can group config / flags as a single @dataclass(frozen=True) while torch.library's schema continues to see only primitives.

Internally, the user signature is lowered to a flat form at registration time, and a small runtime adapter flattens / unflattens dataclass values around every call.

Quick start

@dataclasses.dataclass(frozen=True)
class AttnCfg:
    scale: float
    causal: bool = False

@magi_register_custom_op()
def attn(q: torch.Tensor, k: torch.Tensor, cfg: AttnCfg) -> torch.Tensor:
    ...

Full-support example

@dataclasses.dataclass(frozen=True)
class Inner:
    w: torch.Tensor
    b: torch.Tensor


@dataclasses.dataclass(frozen=True)
class WeightCfg:
    inner: Inner
    scale: float


def setup(ctx, inputs, output):
    x, cfg = inputs
    ctx.save_for_backward(x, cfg.inner.w)
    ctx.scale = cfg.scale


def bwd(ctx, gy):
    x, w = ctx.saved_tensors
    s = ctx.scale
    # Slot order matches the ORIGINAL signature of `op` below.
    return (
        gy * w * s,                                              # grad for x   (Tensor)
        {"inner": {"w": gy * x * s, "b": None}, "scale": None},  # grad for cfg (nested dict)
        # equivalently: WeightCfg(inner=Inner(w=gy * x * s, b=None), scale=None)
    )


@magi_register_custom_op(setup_context_fn=setup, backward_fn=bwd)
def op(x: torch.Tensor, cfg: WeightCfg) -> torch.Tensor:
    return (x * cfg.inner.w + cfg.inner.b) * cfg.scale

Note

  1. Tensor fields inside the dataclass are now differentiable. Returning None — for the whole slot, an individual field, or by simply omitting the key — marks them non-differentiable, same as PyTorch.
  2. The bridge matches grads by field name, so dict is the recommended form for convenience.

The same lower-signature pass also handles (transparent to users):

  • Literal[...] / string-Enum annotations → auto-downgraded to str
  • Unsupported defaults (mutable, dataclass instance, …) are scrubbed from the lowered signature only; user-facing defaults are preserved
  • mutates_args accepts either the dataclass-level name (expands to all Tensor leaves) or any lowered leaf name
  • backward_fn returns one grad per original parameter (not per leaf); a whole non-differentiable dataclass arg collapses to a single None
  • Variadic args / missing annotations are rejected up-front at registration time with actionable errors.

Architecture — 4-slot pipeline

Each registration owns up to 4 named objects:

Slot Object Created by Presence
0 fn user source Always
1 lowered_fn this PR only when the signature needs lowering
2 torch_registered_op torch.library Always
3 magi_exposed_op this PR only when dataclass flattening is needed

The naming is deliberately dual: torch_registered_op is registered into torch.library's dispatcher; magi_exposed_op is exposed out of Magi to the user.

Tests

86 new tests in tests/api_tests/test_register_custom_op.py cover all three runtime paths, autograd bridging through dataclass inputs, nested dataclasses, Optional / Literal / Enum / dtype / device fields, torch.compile integration, and full error-path coverage.

Comment thread magi_compiler/_magi_register_custom_op.py Outdated
Comment thread magi_compiler/_magi_register_custom_op.py
Copy link
Copy Markdown
Collaborator

@jiahy0825 jiahy0825 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@jiahy0825 jiahy0825 merged commit e68e0b0 into SandAI-org:main May 19, 2026
2 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants