# Compute & Result Pattern

This notebook explains MolPy's **Compute & Result pattern** for defining reusable, composable computational operations.

The pattern separates:
- **Compute**: The operation/algorithm (how to calculate)
- **Result**: The output data structure (what you get)

This separation provides:
- Type safety with generic types
- Reusable computation logic
- Structured, self-documenting results
- Easy testing and composition

## When to Use This Pattern

| Use Case | Use Compute | Use Function | Use Method |
|----------|-------------|--------------|------------|
| Simple one-off calculation | ❌ | ✅ | ✅ |
| Reusable with configuration | ✅ | ⚠️ | ❌ |
| Complex multi-step algorithm | ✅ | ⚠️ | ❌ |
| Needs setup/cleanup | ✅ | ❌ | ❌ |
| Structured output | ✅ | ⚠️ | ⚠️ |
| Composable operations | ✅ | ⚠️ | ❌ |
| External dependencies | ✅ | ✅ | ❌ |

In the table, ✅ marks a good fit, ⚠️ a workable but awkward fit, and ❌ a poor fit.

**Key principle**: Use `Compute` when you need **configuration + reusability + structure**. Use functions for simple calculations, methods for core data class operations.

## Quick Start

Here's a minimal example showing the pattern in action:

In [None]:
from dataclasses import dataclass
import numpy as np
from molpy.compute import Compute, Result
from molpy.core import Frame, Block

# 1. Define the result
@dataclass
class CenterOfMassResult(Result):
    """Result of center of mass calculation."""
    com: np.ndarray  # Center of mass coordinates
    total_mass: float  # Total mass

# 2. Define the compute operation
class ComputeCenterOfMass(Compute[Frame, CenterOfMassResult]):
    """Compute center of mass for a frame."""
    
    def compute(self, input: Frame) -> CenterOfMassResult:
        """Calculate center of mass."""
        atoms = input["atoms"]
        positions = np.column_stack([atoms["x"], atoms["y"], atoms["z"]])
        masses = np.array(atoms.get("mass", np.ones(len(atoms))))
        
        total_mass = masses.sum()
        com = (positions * masses[:, np.newaxis]).sum(axis=0) / total_mass
        
        return CenterOfMassResult(com=com, total_mass=total_mass)

# 3. Use it
frame = Frame()
frame["atoms"] = Block({
    "x": [0.0, 1.0, 2.0],
    "y": [0.0, 0.0, 0.0],
    "z": [0.0, 0.0, 0.0],
    "mass": [1.0, 1.0, 1.0]
})

compute_com = ComputeCenterOfMass()
result = compute_com(frame)

print(f"Center of mass: {result.com}")
print(f"Total mass: {result.total_mass}")

## The Pattern Architecture

### Compute Base Class

All compute operations inherit from `Compute[InT, OutT]`:

```python
from abc import ABC, abstractmethod

# PEP 695 generic class syntax (Python 3.12+)
class Compute[InT, OutT](ABC):
    """Abstract base class for compute operations."""
    
    def __call__(self, input: InT) -> OutT:
        """Execute the computation."""
        self.before(input)
        result = self.compute(input)
        self.after(input, result)
        return result
    
    @abstractmethod
    def compute(self, input: InT) -> OutT:
        """Core computation logic (must override)."""
        ...
    
    def before(self, input: InT) -> None:
        """Optional setup hook."""
        pass
    
    def after(self, input: InT, result: OutT) -> None:
        """Optional cleanup hook."""
        pass
```
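The `class Compute[InT, OutT]` syntax above requires Python 3.12 (PEP 695). On older interpreters the same contract can be written with `TypeVar` and `Generic`. The sketch below is a stand-in that mirrors the listing, not MolPy's actual source, and `TracedSum` is a hypothetical subclass that records the order in which the hooks fire.

```python
from abc import ABC, abstractmethod
from typing import Generic, TypeVar

InT = TypeVar("InT")
OutT = TypeVar("OutT")


class Compute(ABC, Generic[InT, OutT]):
    """Stand-in for the base class above, using TypeVar/Generic (Python 3.9+)."""

    def __call__(self, input: InT) -> OutT:
        self.before(input)
        result = self.compute(input)
        self.after(input, result)
        return result

    @abstractmethod
    def compute(self, input: InT) -> OutT: ...

    def before(self, input: InT) -> None:
        pass

    def after(self, input: InT, result: OutT) -> None:
        pass


class TracedSum(Compute[list[float], float]):
    """Hypothetical compute that logs each hook as it runs."""

    def __init__(self) -> None:
        self.calls: list[str] = []

    def before(self, input: list[float]) -> None:
        self.calls.append("before")

    def compute(self, input: list[float]) -> float:
        self.calls.append("compute")
        return sum(input)

    def after(self, input: list[float], result: float) -> None:
        self.calls.append("after")


op = TracedSum()
total = op([1.0, 2.0, 3.0])
print(total, op.calls)  # 6.0 ['before', 'compute', 'after']
```

Callers only ever invoke the instance; `__call__` guarantees the hook order, so subclasses override `compute()` and, optionally, the hooks.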

### Result Base Class

Results are dataclasses that hold computation outputs:

```python
from dataclasses import dataclass, field
from typing import Any

@dataclass
class Result:
    """Base class for computation results."""
    meta: dict[str, Any] = field(default_factory=dict)
    
    def to_dict(self) -> dict[str, Any]:
        """Convert result to dictionary."""
        return {k: v for k, v in self.__dict__.items()}
```

## Lifecycle Hooks

The `before()` and `after()` hooks enable setup and cleanup. This is useful for:
- Input validation
- Resource allocation
- Caching
- Logging
- Cleanup

In [None]:
from dataclasses import dataclass
import numpy as np
from molpy.compute import Compute, Result
from molpy.core import Frame, Block

@dataclass
class DistanceMatrixResult(Result):
    """Result of distance matrix calculation."""
    distances: np.ndarray  # Pairwise distance matrix
    n_atoms: int  # Number of atoms

class ComputeDistanceMatrix(Compute[Frame, DistanceMatrixResult]):
    """Compute pairwise distance matrix with validation and caching."""
    
    def before(self, input: Frame) -> None:
        """Validate input and allocate cache."""
        if "atoms" not in input.blocks():
            raise ValueError("Frame must have 'atoms' block")
        
        atoms = input["atoms"]
        if not all(k in atoms for k in ["x", "y", "z"]):
            raise ValueError("Atoms must have x, y, z coordinates")
        
        # Allocate cache for intermediate results
        n = len(atoms)
        self._cache = np.zeros((n, n))
        print(f"Allocated cache for {n} atoms")
    
    def compute(self, input: Frame) -> DistanceMatrixResult:
        """Calculate pairwise distances."""
        atoms = input["atoms"]
        positions = np.column_stack([atoms["x"], atoms["y"], atoms["z"]])
        
        # Compute pairwise distances
        n = len(positions)
        for i in range(n):
            for j in range(i + 1, n):
                dist = np.linalg.norm(positions[i] - positions[j])
                self._cache[i, j] = dist
                self._cache[j, i] = dist
        
        return DistanceMatrixResult(
            distances=self._cache.copy(),
            n_atoms=n
        )
    
    def after(self, input: Frame, result: DistanceMatrixResult) -> None:
        """Cleanup and logging."""
        del self._cache
        print(f"Computed distances for {result.n_atoms} atoms")

# Example usage
frame = Frame()
frame["atoms"] = Block({
    "x": [0.0, 1.0, 0.0],
    "y": [0.0, 0.0, 1.0],
    "z": [0.0, 0.0, 0.0]
})

compute_dist = ComputeDistanceMatrix()
result = compute_dist(frame)
print(f"Distance matrix:\n{result.distances}")

## Real-World Example: Trajectory Analysis

MolPy includes `MCDCompute`, which computes mean displacement correlations (MCD) for diffusion analysis. It demonstrates:
- Complex computation with multiple parameters
- Structured result with time series data
- Processing trajectory data

### MCDCompute Overview

```python
class MCDCompute(Compute[Trajectory, MCDResult]):
    """Compute mean displacement correlations (MCD) for diffusion analysis.
    
    Supports:
    - Self diffusion: MSD_i = <(r_i(t+dt) - r_i(t))²>
    - Distinct diffusion: <(r_i(t+dt) - r_i(t)) · (r_j(t+dt) - r_j(t))>
    
    Args:
        tags: List of atom type specifications
            - Single integer (e.g., "3"): Self-diffusion MSD of type 3
            - Two integers (e.g., "3,4"): Distinct diffusion between types
        max_dt: Maximum time lag in ps
        dt: Timestep in ps
        center_of_mass: Optional COM removal
    """
    
    def compute(self, input: Trajectory) -> MCDResult:
        # Extract coordinates and unwrap periodic boundaries
        # Apply center of mass correction if requested
        # Compute correlations for each tag
        return MCDResult(time=time_array, correlations=correlations)
```
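The two correlations in the docstring can be written out directly in NumPy. This is a sketch, not MolPy's implementation: `self_msd` and `distinct_correlation` are hypothetical helpers that assume already-unwrapped coordinates, a fixed frame spacing, lags counted in frames (multiply by the frame spacing to get time), and disjoint groups A and B. MolPy's normalization and COM handling may differ.

```python
import numpy as np


def self_msd(positions: np.ndarray, max_lag: int) -> np.ndarray:
    """<|r_i(t+lag) - r_i(t)|^2>, averaged over atoms and time origins.

    positions: unwrapped coordinates, shape (n_frames, n_atoms, 3).
    """
    msd = np.zeros(max_lag + 1)
    for lag in range(1, max_lag + 1):
        disp = positions[lag:] - positions[:-lag]  # (n_origins, n_atoms, 3)
        msd[lag] = (disp ** 2).sum(axis=-1).mean()
    return msd


def distinct_correlation(pos_a: np.ndarray, pos_b: np.ndarray, max_lag: int) -> np.ndarray:
    """<dr_i . dr_j> averaged over i in group A, j in group B and time origins."""
    n_a, n_b = pos_a.shape[1], pos_b.shape[1]
    corr = np.zeros(max_lag + 1)
    for lag in range(1, max_lag + 1):
        sum_a = (pos_a[lag:] - pos_a[:-lag]).sum(axis=1)  # (n_origins, 3)
        sum_b = (pos_b[lag:] - pos_b[:-lag]).sum(axis=1)
        # sum_i sum_j dr_i . dr_j == (sum_i dr_i) . (sum_j dr_j)
        corr[lag] = (sum_a * sum_b).sum(axis=-1).mean() / (n_a * n_b)
    return corr


# Ballistic check: every atom moves with the same velocity v, so each
# displacement over `lag` frames is v * lag and both curves equal |v|^2 * lag^2.
t = np.arange(50)[:, None, None]
v = np.array([0.1, 0.2, 0.0])
traj = t * v + np.random.default_rng(0).normal(size=(1, 6, 3))  # (50, 6, 3)
msd = self_msd(traj, max_lag=10)
cross = distinct_correlation(traj[:, :3], traj[:, 3:], max_lag=10)
```

Summing displacements per group before taking the dot product avoids building an explicit (n_a, n_b) pair array.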

### Usage Pattern

```python
from molpy.io import read_h5_trajectory
from molpy.compute import MCDCompute

# Load trajectory
trajectory = read_h5_trajectory("trajectory.h5")

# Compute self-diffusion MSD of atom type 3
mcd = MCDCompute(tags=["3"], max_dt=30.0, dt=0.01)
result = mcd(trajectory)

# Access results
print(result.time)  # Time lag values
print(result.correlations["3"])  # MSD values at each time lag

# Compute distinct diffusion between types 3 and 4
mcd_distinct = MCDCompute(tags=["3,4"], max_dt=30.0, dt=0.01)
result_distinct = mcd_distinct(trajectory)
print(result_distinct.correlations["3,4"])  # Correlation values
```

## Composability

Compute operations can be chained and composed. This is particularly useful with RDKit integration:

In [None]:
# Example structure (requires RDKit)
# from molpy.compute import Generate3D, OptimizeGeometry
# from molpy.adapter import RDKitAdapter
# from molpy.core.atomistic import Atomistic

# # Create molecule
# mol = Atomistic()
# # ... define atoms and bonds ...

# # Create adapter
# adapter = RDKitAdapter(internal=mol)

# # Chain operations
# generate_3d = Generate3D(add_hydrogens=True, optimize=True)
# optimize = OptimizeGeometry(max_opt_iters=500)

# # Apply sequentially
# adapter = generate_3d(adapter)  # Generate 3D coordinates
# adapter = optimize(adapter)     # Further optimize geometry

# # Extract result
# optimized_mol = adapter.internal

print("Composition example (requires RDKit to run)")
print("Shows how Compute operations can be chained:")
print("  adapter -> Generate3D -> OptimizeGeometry -> result")
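Chaining works because the output type of one operation matches the input type of the next. The sketch below makes that explicit with a hypothetical `Chain` combinator that is itself a `Compute`, so chains can nest. `Chain`, `Scale`, `Total` and the stand-in base are illustrations, not MolPy APIs; whether MolPy ships a similar helper is not stated here.

```python
from abc import ABC, abstractmethod
from typing import Generic, TypeVar

A = TypeVar("A")
B = TypeVar("B")
C = TypeVar("C")


class Compute(ABC, Generic[A, B]):
    """Minimal stand-in for molpy.compute.Compute (hooks omitted)."""

    def __call__(self, input: A) -> B:
        return self.compute(input)

    @abstractmethod
    def compute(self, input: A) -> B: ...


class Chain(Compute[A, C]):
    """Hypothetical combinator: run `first`, then feed its output to `second`."""

    def __init__(self, first: Compute[A, B], second: Compute[B, C]) -> None:
        self.first = first
        self.second = second

    def compute(self, input: A) -> C:
        return self.second(self.first(input))


class Scale(Compute[list[float], list[float]]):
    def __init__(self, factor: float) -> None:
        self.factor = factor

    def compute(self, input: list[float]) -> list[float]:
        return [x * self.factor for x in input]


class Total(Compute[list[float], float]):
    def compute(self, input: list[float]) -> float:
        return sum(input)


pipeline = Chain(Scale(2.0), Total())
print(pipeline([1.0, 2.0, 3.5]))  # 13.0
```

The same shape fits the RDKit example: when both steps are `Compute[RDKitAdapter, RDKitAdapter]`, they can be composed in any order.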

## Benefits of the Pattern

### 1. Type Safety

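Generic type parameters let a static type checker (such as mypy or pyright) verify inputs and outputs; they are not enforced at runtime:

```python
# A static type checker knows:
compute: Compute[Frame, CenterOfMassResult]
result: CenterOfMassResult = compute(frame)  # ✓ Type safe
label: str = compute(frame)  # ✗ Type error: returns CenterOfMassResult, not str
```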

### 2. Structured Results

Results are self-documenting:

```python
@dataclass
class MCDResult(TimeSeriesResult):
    """Results from MCD calculation."""
    time: np.ndarray  # Time lag values
    correlations: dict[str, np.ndarray]  # MSD for each tag
```

### 3. Reusability

Compute operations are configurable and reusable:

```python
# Different configurations
mcd_short = MCDCompute(tags=["1"], max_dt=5.0, dt=0.1)
mcd_long = MCDCompute(tags=["1", "2"], max_dt=20.0, dt=0.2)

# Reuse one configured instance on different trajectories
result1 = mcd_short(trajectory1)
result2 = mcd_short(trajectory2)

# Or apply a different configuration to the same trajectory
result3 = mcd_long(trajectory1)
```

### 4. Testability

Easy to test in isolation:

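```python
import numpy as np

def test_center_of_mass():
    # Create test frame (create_test_frame, expected_com and expected_mass are placeholders)
    frame = create_test_frame()
    
    # Compute
    compute = ComputeCenterOfMass()
    result = compute(frame)
    
    # Assert with tolerances, since results are floating point
    assert np.allclose(result.com, expected_com)
    assert np.isclose(result.total_mass, expected_mass)
```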

## Comparison with Other Patterns

### vs. Simple Functions

**Function:**
```python
def calculate_rdf(frame: Frame, r_max: float = 10.0) -> tuple[np.ndarray, np.ndarray]:
    # Calculate RDF
    return r, g_r
```

**Compute:**
```python
class ComputeRDF(Compute[Frame, RDFResult]):
    def __init__(self, r_max: float = 10.0, n_bins: int = 100):
        self.r_max = r_max
        self.n_bins = n_bins
    
    def compute(self, frame: Frame) -> RDFResult:
        # Calculate RDF
        return RDFResult(r=r, g_r=g_r, r_max=self.r_max)
```
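To make the comparison concrete, here is a runnable sketch of the `Compute` side. It is an illustration under stated assumptions, not MolPy's RDF: the stand-in base, the hypothetical `Snapshot` input, and the uniform cubic periodic box with minimum-image distances are all simplifications, and `RDFResult` is a plain dataclass here rather than a `Result` subclass.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Generic, TypeVar

import numpy as np

InT = TypeVar("InT")
OutT = TypeVar("OutT")


class Compute(ABC, Generic[InT, OutT]):
    """Minimal stand-in for molpy.compute.Compute (hooks omitted)."""

    def __call__(self, input: InT) -> OutT:
        return self.compute(input)

    @abstractmethod
    def compute(self, input: InT) -> OutT: ...


@dataclass
class Snapshot:
    """Hypothetical input: positions in a cubic periodic box."""
    positions: np.ndarray  # shape (n_atoms, 3)
    box_length: float


@dataclass
class RDFResult:
    r: np.ndarray    # bin centres
    g_r: np.ndarray  # radial distribution function
    r_max: float     # configuration echoed into the result


class ComputeRDF(Compute[Snapshot, RDFResult]):
    def __init__(self, r_max: float = 5.0, n_bins: int = 50):
        self.r_max = r_max
        self.n_bins = n_bins

    def compute(self, input: Snapshot) -> RDFResult:
        pos, box = input.positions, input.box_length
        n = len(pos)
        if self.r_max > box / 2:
            raise ValueError("r_max must not exceed half the box length")
        # Minimum-image pair separations via broadcasting.
        d = pos[:, None, :] - pos[None, :, :]
        d -= box * np.round(d / box)
        dist = np.sqrt((d ** 2).sum(axis=-1))
        i, j = np.triu_indices(n, k=1)  # count each pair once
        hist, edges = np.histogram(dist[i, j], bins=self.n_bins, range=(0.0, self.r_max))
        # Expected pair counts per spherical shell for a uniform system.
        shell_vol = 4.0 / 3.0 * np.pi * (edges[1:] ** 3 - edges[:-1] ** 3)
        ideal = 0.5 * n * (n - 1) * shell_vol / box ** 3
        r = 0.5 * (edges[1:] + edges[:-1])
        return RDFResult(r=r, g_r=hist / ideal, r_max=self.r_max)


rng = np.random.default_rng(42)
snap = Snapshot(positions=rng.uniform(0.0, 10.0, size=(500, 3)), box_length=10.0)
rdf = ComputeRDF(r_max=5.0, n_bins=20)
result = rdf(snap)
```

Because `r_max` and `n_bins` live on the instance, one configured `ComputeRDF` can be applied to many snapshots, and echoing `r_max` into the result keeps the output self-describing. For uniformly random points, `g_r` should fluctuate around 1.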

**When to use Compute:**
- Multiple configuration parameters
- Need to reuse with different configs
- Complex setup/cleanup
- Want structured results

### vs. Methods on Data Classes

**Method:**
```python
class Frame:
    def center_of_mass(self) -> np.ndarray:
        # Calculate COM
        return com
```

**Compute:**
```python
class ComputeCenterOfMass(Compute[Frame, CenterOfMassResult]):
    def compute(self, frame: Frame) -> CenterOfMassResult:
        # Calculate COM
        return CenterOfMassResult(com=com, total_mass=mass)
```

**When to use Compute:**
- Operation is complex or configurable
- Want to keep data classes lightweight
- Operation is optional (requires external deps)
- Want to test computation separately

## Design Guidelines

### When to Use Compute

Use `Compute` for:
- ✅ Reusable calculations (RDF, MSD, COM)
- ✅ Operations with configuration (parameters, options)
- ✅ Complex multi-step algorithms
- ✅ Operations that need setup/cleanup

Don't use `Compute` for:
- ❌ Simple one-off calculations (use functions)
- ❌ Simple data transformations (use methods on data classes)
- ❌ IO operations (use readers/writers)

### When to Use Result

Use `Result` for:
- ✅ Structured computation outputs
- ✅ Multiple related values
- ✅ Results that need metadata
- ✅ Results that may be serialized

Don't use `Result` for:
- ❌ Simple scalar returns (use primitives)
- ❌ Temporary intermediate values
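Results that hold NumPy arrays need a small conversion step before JSON serialization, since `json` does not accept `ndarray`. The sketch below assumes a dataclass-based result; `MSDSummary` and `to_json` are hypothetical names, not MolPy APIs.

```python
import json
from dataclasses import asdict, dataclass, field
from typing import Any

import numpy as np


@dataclass
class MSDSummary:
    """Hypothetical result type standing in for a Result subclass."""
    time: np.ndarray
    msd: np.ndarray
    meta: dict[str, Any] = field(default_factory=dict)


def to_json(result: Any) -> str:
    """Serialize a dataclass result, converting NumPy values to Python types."""
    def convert(value: Any) -> Any:
        if isinstance(value, np.ndarray):
            return value.tolist()
        if isinstance(value, np.generic):  # NumPy scalars such as np.int64
            return value.item()
        if isinstance(value, dict):
            return {k: convert(v) for k, v in value.items()}
        return value

    return json.dumps(convert(asdict(result)))


res = MSDSummary(time=np.array([0.0, 0.5, 1.0]), msd=np.array([0.0, 0.2, 0.4]),
                 meta={"dt_ps": 0.5, "n_atoms": np.int64(3)})
loaded = json.loads(to_json(res))
```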

### Naming Conventions

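- Compute classes: `Compute<Operation>` or `<Operation>Compute` (e.g., `ComputeCenterOfMass`, `MCDCompute`); transformations such as the RDKit operations use verb names (e.g., `Generate3D`, `OptimizeGeometry`)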
- Result classes: `<Operation>Result` (e.g., `CenterOfMassResult`, `MCDResult`)
- Instances: Descriptive names (e.g., `compute_com`, `mcd`, `optimizer`)

## Real-World Examples in MolPy

### Trajectory Analysis

```python
# MCDCompute: Mean displacement correlation
class MCDCompute(Compute[Trajectory, MCDResult]):
    ...

# PMSDCompute: Polarization MSD
class PMSDCompute(Compute[Trajectory, PMSDResult]):
    ...
```

### RDKit Integration

```python
# Generate3D: 3D coordinate generation
class Generate3D(Compute[RDKitAdapter, RDKitAdapter]):
    ...

# OptimizeGeometry: Force field optimization
class OptimizeGeometry(Compute[RDKitAdapter, RDKitAdapter]):
    ...
```

## Testing Compute Operations

A test following the Arrange-Act-Assert structure (run it after the Quick Start cell, which defines `ComputeCenterOfMass`):

In [None]:
# Example test structure
def test_compute_center_of_mass():
    """Test center of mass computation."""
    # Arrange: Create test data
    frame = Frame()
    frame["atoms"] = Block({
        "x": [0.0, 2.0],
        "y": [0.0, 0.0],
        "z": [0.0, 0.0],
        "mass": [1.0, 1.0]
    })
    
    # Act: Compute
    compute = ComputeCenterOfMass()
    result = compute(frame)
    
    # Assert: Verify results
    expected_com = np.array([1.0, 0.0, 0.0])
    assert np.allclose(result.com, expected_com)
    assert result.total_mass == 2.0
    
    print("✓ Test passed")

# Run test
test_compute_center_of_mass()

## Performance Considerations

### Caching

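Use `before()` to preallocate buffers sized to the input. Because the buffer lives on `self`, avoid sharing one instance across concurrent calls: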

```python
def before(self, input: Frame) -> None:
    n = len(input["atoms"])
    self._cache = np.zeros((n, n))
```

### Vectorization

Prefer NumPy operations over loops:

```python
# Good: Vectorized
com = (positions * masses[:, np.newaxis]).sum(axis=0) / total_mass

# Bad: Loop
com = sum(pos * mass for pos, mass in zip(positions, masses)) / total_mass
```
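The same advice applies to the double loop in `ComputeDistanceMatrix` above. A broadcasting sketch follows; `pairwise_loop` and `pairwise_vectorized` are illustrative helper names. Broadcasting builds an (n, n, 3) temporary, so for large systems a neighbor list or `scipy.spatial.distance.cdist` may be more memory-friendly.

```python
import numpy as np


def pairwise_loop(positions: np.ndarray) -> np.ndarray:
    """Double loop, as in ComputeDistanceMatrix above."""
    n = len(positions)
    out = np.zeros((n, n))
    for i in range(n):
        for j in range(i + 1, n):
            out[i, j] = out[j, i] = np.linalg.norm(positions[i] - positions[j])
    return out


def pairwise_vectorized(positions: np.ndarray) -> np.ndarray:
    """Broadcasting: (n, 1, 3) - (1, n, 3) -> (n, n, 3) difference vectors."""
    diff = positions[:, None, :] - positions[None, :, :]
    return np.linalg.norm(diff, axis=-1)


pos = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
print(pairwise_vectorized(pos))
```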

### Memory Management

Clean up in `after()`:

```python
def after(self, input: Frame, result: Result) -> None:
    del self._cache
    del self._temporary_arrays
```
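Cleanup in `after()` has one gap: with the `__call__` shown earlier, `after()` only runs when `compute()` returns normally, so an exception leaves temporaries attached to the instance. One workaround, sketched below, releases scratch state in a `try`/`finally` inside `compute()`. `BufferedNorms` and the stand-in base are hypothetical.

```python
from abc import ABC, abstractmethod

import numpy as np


class Compute(ABC):
    """Stand-in with the same __call__ as the base class shown earlier."""

    def __call__(self, input):
        self.before(input)
        result = self.compute(input)
        self.after(input, result)  # skipped if compute() raises
        return result

    @abstractmethod
    def compute(self, input): ...

    def before(self, input) -> None:
        pass

    def after(self, input, result) -> None:
        pass


class BufferedNorms(Compute):
    """Hypothetical compute that frees its scratch buffer even on failure."""

    def before(self, input: np.ndarray) -> None:
        self._buffer = np.empty(len(input))

    def compute(self, input: np.ndarray) -> np.ndarray:
        try:
            if not np.all(np.isfinite(input)):
                raise ValueError("non-finite coordinates")
            self._buffer[:] = np.linalg.norm(input, axis=1)
            return self._buffer.copy()
        finally:
            del self._buffer  # runs on success and on error


op = BufferedNorms()
ok = op(np.array([[3.0, 4.0, 0.0]]))
try:
    op(np.array([[np.nan, 0.0, 0.0]]))
except ValueError:
    pass
print(hasattr(op, "_buffer"))  # False
```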

## Troubleshooting

### Common Issues

**Issue**: Losing type information with generic types
```python
# Wrong: Missing type parameters (checkers treat input and output as Any)
class MyCompute(Compute):
    ...

# Correct: Specify input and output types
class MyCompute(Compute[Frame, MyResult]):
    ...
```

**Issue**: Forgetting to override `compute()`
```python
# Wrong: No compute method, so MyCompute() raises TypeError (abstract method)
class MyCompute(Compute[Frame, Result]):
    pass

# Correct: Override compute
class MyCompute(Compute[Frame, Result]):
    def compute(self, input: Frame) -> Result:
        return Result()
```

**Issue**: Modifying input in `compute()`
```python
# Wrong: Modifies input
def compute(self, input: Frame) -> Result:
    input["atoms"]["x"] += 1.0  # Side effect!
    return Result()

# Correct: Don't modify input
def compute(self, input: Frame) -> Result:
    x = input["atoms"]["x"].copy()
    x += 1.0
    return Result()
```

## Summary

The Compute & Result pattern provides:

1. **Type Safety** - Generic types ensure correctness
2. **Reusability** - Configurable, composable operations
3. **Structure** - Self-documenting results
4. **Testability** - Easy to test in isolation
5. **Extensibility** - Simple to add new operations

Use this pattern for complex, reusable computations that benefit from configuration and structured outputs.

For simple calculations, plain functions are often sufficient.