# AlignJuice - 自定义算子开发指南

本 Notebook 演示如何创建自定义算子来扩展 AlignJuice 框架。

In [None]:
from alignjuice import DataContainer, AlignmentSample
from alignjuice.operators.base import Operator, FilterOperator, TransformOperator
from alignjuice.core.registry import register_operator
from typing import Any

## 1. 创建自定义过滤算子

In [None]:
@register_operator("length_filter")
class LengthFilter(FilterOperator):
    """Keep only samples whose ``output`` length lies within inclusive bounds.

    Args:
        min_length: Lower bound on ``len(sample.output)``; defaults to 50.
        max_length: Upper bound on ``len(sample.output)``; defaults to 5000.
        **kwargs: Additional options passed through to the base-class
            constructor together with both bounds.
    """

    name = "length_filter"

    def __init__(self, min_length: int = 50, max_length: int = 5000, **kwargs: Any):
        # The bounds are handed to the base class and also kept on the
        # instance so that should_keep() can read them directly.
        super().__init__(min_length=min_length, max_length=max_length, **kwargs)
        self.min_length, self.max_length = min_length, max_length

    def should_keep(self, sample: AlignmentSample) -> bool:
        """Return True when ``len(sample.output)`` is within the bounds."""
        size = len(sample.output)
        # Guard clause: reject anything below the lower bound first.
        if not self.min_length <= size:
            return False
        return size <= self.max_length

# Quick check: sample "1" has a 2-character output (below min_length=10),
# sample "2" has an output well above it.
data = DataContainer.from_list(
    [
        {
            "id": "1",
            "instruction": "Short",
            "input": "",
            "output": "Hi",
            "category": "daily",
        },
        {
            "id": "2",
            "instruction": "Medium",
            "input": "",
            "output": "This is a medium length response that should pass the filter.",
            "category": "daily",
        },
    ]
)

# The operator instance is called directly on the container; the sizes
# before and after filtering are printed for comparison.
length_filter = LengthFilter(min_length=10)
filtered = length_filter(data)
print(f"过滤结果: {len(data)} -> {len(filtered)}")

## 2. 创建自定义转换算子

In [None]:
@register_operator("prefix_adder")
class PrefixAdder(TransformOperator):
    """Transform that prepends a fixed string to each sample's ``output``.

    Args:
        prefix: Text placed in front of the output; defaults to ``"Answer: "``.
        **kwargs: Additional options passed through to the base-class
            constructor together with ``prefix``.
    """

    name = "prefix_adder"

    def __init__(self, prefix: str = "Answer: ", **kwargs: Any):
        super().__init__(prefix=prefix, **kwargs)
        self.prefix = prefix

    def transform(self, sample: AlignmentSample) -> AlignmentSample:
        """Build a new sample with the prefixed output and a metadata flag.

        The id, instruction, input and category are copied from ``sample``.
        The metadata is copied into a fresh dict in which ``"prefix_added"``
        is set to True, so the incoming sample's metadata is not modified.
        """
        prefixed_output = self.prefix + sample.output

        # Copy first, then tag, so the source sample's dict stays untouched.
        tagged_metadata = {**sample.metadata}
        tagged_metadata["prefix_added"] = True

        return AlignmentSample(
            id=sample.id,
            instruction=sample.instruction,
            input=sample.input,
            output=prefixed_output,
            category=sample.category,
            metadata=tagged_metadata,
        )

# Quick check: run the prefix operator on the samples kept by the length
# filter above, then print each id with at most the first 50 characters of
# its new output (the "..." suffix is always appended, truncated or not).
adder = PrefixAdder(prefix="[Enhanced] ")
transformed = adder(filtered)
for s in transformed:
    print(f"{s.id}: {s.output[:50]}...")

## 3. 组合多个算子

In [None]:
# Build the operator chain: the length filter first, then the prefix adder.
operators = [
    LengthFilter(min_length=10),
    PrefixAdder(prefix="[Processed] ")
]

# Apply the operators in order, feeding each one's result into the next,
# and print each operator's name and metrics right after it runs.
# NOTE(review): `metrics` is not defined in this notebook; presumably it is
# provided by the Operator base class - confirm.
result = data
for op in operators:
    result = op(result)
    print(f"{op.name}: {op.metrics}")

# Report how many samples remain, then display the final container.
# NOTE(review): `show()` comes from the container returned by the operators;
# its output format is not visible here.
print(f"\n最终结果: {len(result)} 条")
result.show()