# Compute Module

The `compute` module provides a unified abstraction for defining and executing computational operations on molecular structures.

## Core Concepts

- **Result**: Base class for computation outputs
- **Compute**: Abstract base class for computation operations
- **ComputeContext**: Optional context for sharing expensive intermediates


In [None]:
from dataclasses import dataclass

from molpy.compute import Compute, ComputeContext, Result
from molpy.core.frame import Frame


# Define a custom result
@dataclass
class RadiusOfGyrationResult(Result):
    """Result from radius of gyration calculation."""

    rg: float = 0.0


# Define a custom compute
class RadiusOfGyrationCompute(Compute[Frame, RadiusOfGyrationResult]):
    """Compute radius of gyration for a frame."""

    def compute(self, input: Frame) -> RadiusOfGyrationResult:
        if "atoms" not in input:
            return RadiusOfGyrationResult(name="radius_of_gyration", rg=0.0)

        atoms = input["atoms"]
        import numpy as np

        # Extract coordinates
        x = atoms["x"][:]
        y = atoms["y"][:]
        z = atoms["z"][:]
        coords = np.column_stack([x, y, z])

        # Calculate the geometric center (unweighted; equals the center of mass only for equal masses)
        com = np.mean(coords, axis=0)

        # Calculate radius of gyration
        r_squared = np.sum((coords - com) ** 2, axis=1)
        rg = np.sqrt(np.mean(r_squared))

        return RadiusOfGyrationResult(name="radius_of_gyration", rg=float(rg))


# Use the compute
compute = RadiusOfGyrationCompute()
# result = compute(frame)  # Uncomment when you have a frame
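
For reference, the quantity computed in the cell above is the unweighted radius of gyration about the geometric center of the $N$ atoms. Written out, with $\mathbf{r}_i$ the position of atom $i$ (the symbols are introduced here for illustration and do not appear in the code):

$$
\mathbf{r}_c = \frac{1}{N}\sum_{i=1}^{N}\mathbf{r}_i,
\qquad
R_g = \sqrt{\frac{1}{N}\sum_{i=1}^{N}\lVert \mathbf{r}_i - \mathbf{r}_c \rVert^{2}}
$$

A mass-weighted variant would replace both averages with mass-weighted ones; the example does not read atomic masses, so it corresponds to the equal-mass case.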

## Design and Architecture

The `compute` subsystem provides a lightweight, composable API for running deterministic computations on MolPy data structures.
Design goals:
- **Separation of concerns**: a `Compute` encapsulates *what* to compute; `ComputeContext` holds shared, expensive intermediates.
- **Small, serializable outputs**: `Result` objects are compact and suitable for caching and pipeline transfer.
- **Composable and testable**: computes are plain objects with a single `compute` method, making them easy to unit-test and reuse.


## Core Concepts in Detail

**Result**:
- `Result` is a small dataclass-like container for outputs. It typically carries a `name` and typed fields for the computed values. Results should be immutable after creation when possible so they are safe to cache.

**Compute**:
- `Compute` is an abstract generic class parameterized by its input and `Result` type. Implementations override the `compute(self, input)` method which returns a `Result` instance.
- `Compute` instances are lightweight and may optionally accept a `ComputeContext` in their constructor to share resources.
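
To make the contract concrete, here is a minimal sketch of what such a generic base class could look like. `SketchCompute`, `AtomCountResult`, and `AtomCount` are hypothetical stand-ins written for this illustration; they are not MolPy's actual classes, and the real `Compute` may differ in its constructor and call behavior.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Generic, Optional, TypeVar

In = TypeVar("In")
Out = TypeVar("Out")


class SketchCompute(ABC, Generic[In, Out]):
    """Hypothetical stand-in for Compute: generic over input and result type."""

    def __init__(self, context: Optional[object] = None) -> None:
        # Assumption: the context is optional and simply stored on the instance.
        self.context = context

    @abstractmethod
    def compute(self, input: In) -> Out:
        """Subclasses implement the actual computation here."""

    def __call__(self, input: In) -> Out:
        # Calling the object delegates to compute().
        return self.compute(input)


@dataclass
class AtomCountResult:
    name: str
    count: int


class AtomCount(SketchCompute[list, AtomCountResult]):
    """Toy compute: counts the entries of a list of coordinates."""

    def compute(self, input: list) -> AtomCountResult:
        return AtomCountResult(name="atom_count", count=len(input))


result = AtomCount()([(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)])
print(result)
```

Because `compute` is abstract, the base class cannot be instantiated directly, which forces each concrete compute to state what it calculates.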

**ComputeContext**:
- `ComputeContext` provides a place to store expensive-to-build intermediates (e.g., neighbor lists, basis sets, or precomputed descriptors).
- Use `ComputeContext` when multiple computes in a pipeline need access to the same costly data rather than recomputing it.


## How to Customize a Compute

When implementing a custom compute, follow these recommended steps:
1. Create a small `Result` dataclass describing the output fields (keep it focused).
2. Subclass `Compute` and implement the `compute(self, input)` method. Keep the method deterministic and side-effect free whenever possible.
3. If your compute benefits from cached intermediates (e.g. neighbor lists), accept a `ComputeContext` in `__init__` and read/write context keys.
4. Write unit tests that exercise the compute with minimal fixtures (small frames/atomistic objects).

Best practices:
- Return small, explicit `Result` objects rather than raw dicts for clearer typing and easier downstream consumption.
- Avoid mutating the input; if you must, clearly document it.
- Keep `compute` fast and let heavy precomputation live in the context or a separate compute that can be cached.


In [None]:
# Example: custom compute that caches a neighbor list in the context
from dataclasses import dataclass

from molpy.compute import Compute, ComputeContext, Result
from molpy.core.frame import Frame
import numpy as np

@dataclass
class PairCountResult(Result):
    """Count of neighbor pairs within cutoff."""
    count: int = 0

class PairCounterCompute(Compute[Frame, PairCountResult]):
    def __init__(self, cutoff: float = 3.0, context: ComputeContext | None = None):
        super().__init__(context=context)
        self.cutoff = float(cutoff)

    def compute(self, frame: Frame) -> PairCountResult:
        # Try to reuse precomputed neighbor list from context
        ctx = self.context or ComputeContext()
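        # The key encodes only the cutoff, not the frame, so a shared context
        # should be reused only for computes that run on the same frame.
        nlist_key = f'neighbor_list_{self.cutoff}'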
        nlist = ctx.data.get(nlist_key)
        if nlist is None:
            # Build simple O(N^2) neighbor list for demo purposes
            coords = np.column_stack([frame['atoms']['x'], frame['atoms']['y'], frame['atoms']['z']])
            N = len(coords)
            pairs = []
            for i in range(N):
                for j in range(i+1, N):
                    if np.linalg.norm(coords[i] - coords[j]) <= self.cutoff:
                        pairs.append((i, j))
            nlist = pairs
            ctx.data[nlist_key] = nlist

        return PairCountResult(name='pair_count', count=len(nlist))

# Usage:
# ctx = ComputeContext()
# compute = PairCounterCompute(cutoff=3.0, context=ctx)
# result = compute(frame)
# print(result.count)

### Sharing Context Between Computes

You can share expensive intermediate computations using `ComputeContext`:


In [None]:
# Create shared context
context = ComputeContext()
# context.data["neighbor_list"] = compute_neighbor_list(frame)

# Use in multiple computes
# compute1 = SomeCompute(context=context)
# compute2 = AnotherCompute(context=context)
# Both can access the neighbor list from context
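
The commented outline above can be sketched as runnable code without MolPy. In this illustration `SketchContext` is a hypothetical stand-in that assumes the context exposes a `.data` dict, and `get_or_build` and `build_pairs` are helper names invented here; the point is only that two consumers trigger a single build.

```python
from typing import Any, Callable


class SketchContext:
    """Hypothetical stand-in for ComputeContext, assuming a `.data` dict."""

    def __init__(self) -> None:
        self.data: dict = {}


def get_or_build(ctx: SketchContext, key: str, builder: Callable[[], Any]) -> Any:
    """Return ctx.data[key], building and storing it on first use."""
    if key not in ctx.data:
        ctx.data[key] = builder()
    return ctx.data[key]


stats = {"builds": 0}  # counts how often the expensive step runs


def build_pairs(xs: list, cutoff: float) -> list:
    """O(N^2) pair search on 1-D positions, standing in for a neighbor list."""
    stats["builds"] += 1
    return [
        (i, j)
        for i in range(len(xs))
        for j in range(i + 1, len(xs))
        if abs(xs[i] - xs[j]) <= cutoff
    ]


xs = [0.0, 1.0, 3.0]
ctx = SketchContext()
key = "neighbor_list_1.5"

# Two independent consumers ask for the same intermediate.
pair_count = len(get_or_build(ctx, key, lambda: build_pairs(xs, 1.5)))
neighbours_of_0 = [
    j for i, j in get_or_build(ctx, key, lambda: build_pairs(xs, 1.5)) if i == 0
]

print(pair_count, neighbours_of_0, stats["builds"])
```

The second lookup finds the entry already stored under `key`, so the builder is not called again.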

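## Result: A Closer Look

A `Result` is usually a lightweight data container (it can be a `dataclass` or a type that inherits from `Result`). Key points:
- **Semantic fields**: name each output as an explicit field (for example `rg`, `count`, `energy`) instead of using a loosely structured dict.
- **Include a `name` field**: it can serve as a label in a pipeline or as part of a cache key.
- **Serializable**: keep fields to JSON-friendly types (numbers, strings, simple lists) so results are easy to cache and transfer.
- **Immutability**: do not modify a returned `Result` after the computation finishes; this avoids race conditions under caching and parallel execution.

Example: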
```python
@dataclass
class EnergyResult(Result):
    energy: float = 0.0
    name: str = 'energy'
```
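
One way to enforce the immutability and serializability points is a frozen dataclass, sketched below with the standard library only. `EnergyResultSketch` is a hypothetical stand-in that does not inherit from MolPy's `Result`; whether the real base class can be combined with `frozen=True` is an assumption to check against the library.

```python
import json
from dataclasses import FrozenInstanceError, asdict, dataclass


@dataclass(frozen=True)
class EnergyResultSketch:
    """Hypothetical result type with JSON-friendly fields only."""

    name: str = "energy"
    energy: float = 0.0


result = EnergyResultSketch(energy=-12.5)

# Frozen dataclasses reject attribute assignment after creation.
try:
    result.energy = 0.0
    mutated = True
except FrozenInstanceError:
    mutated = False

# JSON-friendly fields round-trip through a plain dict.
payload = json.dumps(asdict(result))
restored = EnergyResultSketch(**json.loads(payload))
print(mutated, restored == result)
```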


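## Compute API Reference

The core contract is simple: implement `compute(self, input)` and return a `Result` instance. Common conventions:
- `__init__(self, *, context: ComputeContext | None = None)`: accepts an optional `ComputeContext`.
- `compute(self, input) -> Result`: performs the computation and returns the result.
- Some implementations may define `__call__` so the object can be invoked like a function (check the specific implementation).

Minimal example implementation: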
```python
class MyCompute(Compute[Frame, MyResult]):
    def compute(self, frame: Frame) -> MyResult:
        # do work
        return MyResult(...)
```


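## ComputeContext API and Caching Strategy

`ComputeContext` is a simple container (typically exposing a `.data` dict) for sharing intermediate results across multiple computes. Practical advice:
- Use key names that carry a version or the relevant parameters (for example `f'neighbor_list_{cutoff}_{version}'`) so that incompatible cache entries are not reused by mistake.
- Store reusable, expensive objects in the context; avoid dumping large amounts of temporary data into it, which wastes memory.
- If needed, store cache metadata (build time, source hash, and so on) in the context to drive cache invalidation.

Example: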
```python
ctx = ComputeContext()
if 'nl_1.5' not in ctx.data:
    ctx.data['nl_1.5'] = build_neighbor_list(frame, cutoff=1.5)
```
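
The advice on parameterized keys and cache metadata can be sketched as follows. Everything here is illustrative: `cache_key`, `source_hash`, and `lookup` are hypothetical helpers, the key layout is one possible convention, and a plain dict plays the role of `ctx.data`.

```python
import hashlib
import json


def cache_key(kind: str, version: int, **params) -> str:
    """Build a deterministic key from the kind, a version tag, and parameters."""
    encoded = json.dumps(params, sort_keys=True)
    return f"{kind}:v{version}:{encoded}"


def source_hash(coords: list) -> str:
    """Short fingerprint of the input coordinates, used to detect stale entries."""
    return hashlib.sha256(json.dumps(coords).encode("utf-8")).hexdigest()[:12]


def lookup(data: dict, key: str, coords: list):
    """Return the cached value only if it was built from the same coordinates."""
    entry = data.get(key)
    if entry is None or entry["source"] != source_hash(coords):
        return None
    return entry["value"]


data = {}  # plays the role of ctx.data
coords = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]
key = cache_key("neighbor_list", 1, cutoff=1.5)

# Store the value together with metadata describing where it came from.
data[key] = {"value": [(0, 1)], "source": source_hash(coords)}

moved = [[0.0, 0.0, 0.0], [5.0, 0.0, 0.0]]
print(key)
print(lookup(data, key, coords), lookup(data, key, moved))
```

Storing the source fingerprint next to the value lets a lookup reject an entry that was built from different coordinates, even when the key matches.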


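## Composition and Pipelines

A computation can usually be split into several small computes run in sequence (for example: geometry normalization -> descriptor calculation -> statistical aggregation). Recommendations:
- Split responsibilities into small, testable computes.
- Use `ComputeContext` to pass cached data between stages (such as neighbor lists or molecular graphs).
- Maintain version and parameter information outside the pipeline to decide when to invalidate caches.

Example pipeline:
1. `NormalizeGeometryCompute` (input: Frame -> output: Frame')
2. `DescriptorCompute` (uses Frame' and the neighbor list in the context) -> output: DescriptorResult
3. `AggregateCompute` (aggregates multiple DescriptorResult objects) -> output: AggregateResult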
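
The three stages above can be sketched with plain functions standing in for the computes. The function names, the choice of descriptor (distance from the centroid), and the result types below are all hypothetical and only illustrate how small stages chain together; they are not MolPy classes.

```python
from dataclasses import dataclass
from math import sqrt


@dataclass(frozen=True)
class DescriptorResult:
    name: str
    values: tuple


@dataclass(frozen=True)
class AggregateResult:
    name: str
    mean: float


def normalize_geometry(coords: list) -> list:
    """Stage 1: translate coordinates so their centroid sits at the origin."""
    n = len(coords)
    center = [sum(c[k] for c in coords) / n for k in range(3)]
    return [tuple(c[k] - center[k] for k in range(3)) for c in coords]


def describe(coords: list) -> DescriptorResult:
    """Stage 2: one descriptor per atom, here its distance from the origin."""
    distances = tuple(sqrt(sum(x * x for x in c)) for c in coords)
    return DescriptorResult("radial_distance", distances)


def aggregate(results: list) -> AggregateResult:
    """Stage 3: average every descriptor value across all inputs."""
    values = [v for r in results for v in r.values]
    return AggregateResult("mean_radial_distance", sum(values) / len(values))


frames = [
    [(0.0, 0.0, 0.0), (2.0, 0.0, 0.0)],
    [(1.0, 1.0, 1.0), (1.0, 5.0, 1.0)],
]
summary = aggregate([describe(normalize_geometry(f)) for f in frames])
print(summary)
```

Each stage takes and returns plain values, so any one of them can be tested or swapped out without touching the others.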


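## Testing and Debugging

Practices for testing a compute:
- Use very small fixtures (for example 3-6 atoms) to keep tests fast and readable.
- Treat the context as an injectable dependency so tests can supply substitute or mock data.
- Write separate tests for edge cases (empty input, a single atom, coincident centers of mass).

Pytest example: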
```python
def test_pair_counter_simple():
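    # Two atoms at the same position: their distance (0.0) is within the cutoff,
    # so exactly one pair is expected.
    frame = {'atoms': {'x': np.array([0.0, 0.0]), 'y': np.array([0.0, 0.0]), 'z': np.array([0.0, 0.0])}}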
    ctx = ComputeContext()
    compute = PairCounterCompute(cutoff=1.0, context=ctx)
    result = compute(frame)
    assert result.count == 1
```
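
The edge cases listed above can be covered with tests like the following sketch. To stay runnable without MolPy or NumPy, it tests a hypothetical pure-Python `radius_of_gyration` helper written for this example rather than `RadiusOfGyrationCompute` itself; the same cases would apply to the real compute with a proper `Frame` fixture.

```python
from math import sqrt


def radius_of_gyration(coords: list) -> float:
    """Unweighted radius of gyration; returns 0.0 for an empty input."""
    n = len(coords)
    if n == 0:
        return 0.0
    center = [sum(c[k] for c in coords) / n for k in range(3)]
    mean_sq = sum(
        sum((c[k] - center[k]) ** 2 for k in range(3)) for c in coords
    ) / n
    return sqrt(mean_sq)


def test_empty_input():
    assert radius_of_gyration([]) == 0.0


def test_single_atom():
    assert radius_of_gyration([(1.0, 2.0, 3.0)]) == 0.0


def test_coincident_atoms():
    assert radius_of_gyration([(1.0, 1.0, 1.0)] * 4) == 0.0


def test_two_atoms_on_axis():
    # Two atoms 2.0 apart: each is 1.0 from the centroid.
    assert radius_of_gyration([(0.0, 0.0, 0.0), (2.0, 0.0, 0.0)]) == 1.0
```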


In [None]:
# Runnable demo: build a tiny frame and run PairCounterCompute
import numpy as np
from molpy.compute import ComputeContext
# Reuse the PairCounterCompute defined earlier in this notebook

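# Minimal frame with 3 atoms on the x axis (a plain dict standing in for a Frame).
# Only atoms 0 and 1 (1.0 apart) lie within the 1.5 cutoff used below.
frame = {
    'atoms': {
        'x': np.array([0.0, 1.0, 3.0]),
        'y': np.array([0.0, 0.0, 0.0]),
        'z': np.array([0.0, 0.0, 0.0]),
    }
}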
ctx = ComputeContext()
compute = PairCounterCompute(cutoff=1.5, context=ctx)
result = compute(frame)
print('Pair count:', result.count)
