Merged
21 commits
- `908fa59` feat: add atomic add method and advisory lock to Channel interface (myui, Apr 5, 2026)
- `3569c8a` feat: rename add method to atomic_add for clarity in Channel interface (myui, Apr 5, 2026)
- `9c43073` feat: implement atomic_add method and advisory locking in MemoryChannel (myui, Apr 5, 2026)
- `7026671` feat: add atomic_add method for atomic arithmetic in RedisChannel (myui, Apr 5, 2026)
- `012662c` test: add thread safety tests for MemoryChannel and atomic_add method (myui, Apr 5, 2026)
- `850792c` docs: update README and add channel concurrency example for thread-sa… (myui, Apr 5, 2026)
- `1fdc256` refactor: simplify atomic_add method by removing counter key logic an… (myui, Apr 5, 2026)
- `2b5f855` test: add additional tests for atomic_add method in MemoryChannel (myui, Apr 5, 2026)
- `ee0f048` Update graflow/channels/memory_channel.py (myui, Apr 5, 2026)
- `2b841d7` Update examples/03_data_flow/channel_concurrency.py (myui, Apr 5, 2026)
- `8352daa` Update graflow/channels/base.py (myui, Apr 5, 2026)
- `bca8dff` docs: enhance warning for lock method in Channel class to clarify no … (myui, Apr 5, 2026)
- `9cdbcc7` docs: clarify atomic_add and lock method descriptions in README (myui, Apr 5, 2026)
- `3e541ef` feat: add distributed advisory lock method to RedisChannel (myui, Apr 5, 2026)
- `ccef489` docs: enhance lock method documentation to clarify mutual exclusion a… (myui, Apr 5, 2026)
- `cd999e6` adjust formatting in __setstate__ method for consistency (myui, Apr 5, 2026)
- `445614d` Update tests/channels/test_memory_channel_thread_safety.py (myui, Apr 5, 2026)
- `79d4cde` test: add timeout assertion for lock acquisition in TestAdvisoryLock (myui, Apr 5, 2026)
- `6621148` docs: clarify advisory lock documentation to emphasize safe usage and… (myui, Apr 5, 2026)
- `30eb9c1` Update examples/03_data_flow/channel_concurrency.py (myui, Apr 5, 2026)
- `d7dee2f` Update graflow/channels/memory_channel.py (myui, Apr 5, 2026)
61 changes: 53 additions & 8 deletions examples/03_data_flow/README.md
@@ -11,7 +11,8 @@ This section demonstrates **data flow and inter-task communication** in Graflow
## What You'll Learn

- 📡 Using channels for inter-task communication
- 🔒 Type-safe channels with TypedDict
- 🔒 Thread-safe channel operations (`atomic_add`, `lock`)
- 🏷️ Type-safe channels with TypedDict
- 💾 Storing and retrieving task results
- 🔄 Data flow patterns in workflows
- 📊 Sharing state across task boundaries
@@ -37,7 +38,25 @@ uv run python examples/03_data_flow/channels_basic.py

---

### 2. typed_channels.py
### 2. channel_concurrency.py

**Concept**: Thread-safe channel operations

Learn how to safely update shared channel data when tasks run in parallel using `ParallelGroup`.

```bash
uv run python examples/03_data_flow/channel_concurrency.py
```

**Key Concepts**:
- Race conditions with naive `get`/`set` under concurrency
- Atomic counter updates with `channel.atomic_add()`
- Advisory locking with `channel.lock()` for compound operations
- Threshold-based reset pattern with multi-key updates
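
The key idea behind `atomic_add` can be sketched with the standard library alone; the `AtomicCounter` class below is an illustrative stand-in for the channel, not part of the Graflow API:

```python
import threading


class AtomicCounter:
    """Lock-protected counter mirroring the semantics of channel.atomic_add."""

    def __init__(self) -> None:
        self._value = 0
        self._lock = threading.Lock()

    def atomic_add(self, amount: int = 1) -> int:
        # The read-modify-write happens entirely under the lock,
        # so no update is lost even with many concurrent callers.
        with self._lock:
            self._value += amount
            return self._value


counter = AtomicCounter()


def worker() -> None:
    for _ in range(100):
        counter.atomic_add(1)


threads = [threading.Thread(target=worker) for _ in range(5)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(counter.atomic_add(0))  # 500 — no updates lost
```

Replacing the locked block with a bare `self._value += amount` reintroduces exactly the lost-update race the example script demonstrates.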

---

### 3. typed_channels.py

**Concept**: Type-safe channels

@@ -56,7 +75,7 @@ uv run python examples/03_data_flow/typed_channels.py

---

### 3. results_storage.py
### 4. results_storage.py

**Concept**: Task results and dependency data

@@ -107,7 +126,32 @@ def process_batch(ctx: TaskExecutionContext):
channel.set("total_processed", total)
```

### Pattern 3: Type-Safe Messages
### Pattern 3: Thread-Safe Counter

```python
@task(inject_context=True)
def count_items(ctx: TaskExecutionContext):
channel = ctx.get_channel()
# Atomic — safe for parallel tasks
channel.atomic_add("processed_count", 1)
```

### Pattern 4: Compound Update with Lock

```python
@task(inject_context=True)
def check_and_reset(ctx: TaskExecutionContext):
channel = ctx.get_channel()
with channel.lock("counter"):
val = channel.get("counter")
if val >= threshold:
channel.set("counter", 0)
channel.atomic_add("overflow_count", 1)
else:
channel.set("counter", val + 1)
```

### Pattern 5: Type-Safe Messages

```python
from typing import TypedDict
@@ -129,7 +173,7 @@ def collect_metrics(ctx: TaskExecutionContext):
typed_channel.set("metrics", metrics)
```

### Pattern 4: Result Dependencies
### Pattern 6: Result Dependencies

```python
with workflow("pipeline") as ctx:
@@ -310,9 +354,8 @@ def use_config(ctx: TaskExecutionContext):
@task(inject_context=True)
def track_metrics(ctx: TaskExecutionContext):
channel = ctx.get_channel()
metrics = channel.get("metrics", {})
metrics["processed"] = metrics.get("processed", 0) + 1
channel.set("metrics", metrics)
# Thread-safe counter — works in parallel tasks
channel.atomic_add("processed", 1)
```

### Error Accumulation
@@ -374,6 +417,8 @@ After mastering data flow:
- `channel.set(key, value)` - Store value
- `channel.get(key, default=None)` - Retrieve value
- `channel.keys()` - List all keys
- `channel.atomic_add(key, amount=1)` - Atomic numeric add/subtract (thread-safe for `MemoryChannel`; single-command atomic for Redis)
- `channel.lock(key, timeout=10.0)` - Advisory per-key lock for compound read-modify-write operations (thread-safe via `threading.RLock`)

**TypedChannel**:
- `typed_channel = ctx.get_typed_channel(SchemaClass)` - Create typed channel
209 changes: 209 additions & 0 deletions examples/03_data_flow/channel_concurrency.py
@@ -0,0 +1,209 @@
"""
Channel Concurrency Example
============================

This example demonstrates how to safely share and update channel data
when tasks run in parallel using ``ParallelGroup``.

Problem
-------
A naive read-modify-write (``get`` -> compute -> ``set``) is **not** atomic.
When multiple tasks execute in parallel threads, updates can be lost because
two tasks may read the same value and overwrite each other's writes.

Solutions
---------
1. **``channel.atomic_add(key, amount)``** — Atomic numeric add (inc/dec).
Backed by a per-key lock in MemoryChannel and ``INCRBYFLOAT`` in Redis.

2. **``channel.lock(key)``** — Advisory lock for arbitrary compound
operations. ``MemoryChannel`` uses a per-key ``threading.RLock`` for
in-process coordination; ``RedisChannel`` uses ``redis.lock.Lock``
(SET NX + Lua release) for cross-client distributed coordination.

Expected Output
---------------
=== Channel Concurrency Demo ===

--- Unsafe parallel increment (race condition) ---
Expected counter: 500, Actual: <less than 500>
Updates lost!

--- Safe parallel increment with atomic_add() ---
Expected counter: 500, Actual: 500

--- Safe compound update with lock() ---
Overflow events: 5
Counter after resets: 0

Done!
"""

from graflow.core.context import TaskExecutionContext
from graflow.core.decorators import task
from graflow.core.task import ParallelGroup
from graflow.core.workflow import workflow


def demo_unsafe_increment() -> None:
"""Show that naive get/set loses updates in parallel execution."""
import time

print("--- Unsafe parallel increment (race condition) ---")

num_workers = 5
increments_per_worker = 100
expected = num_workers * increments_per_worker

with workflow("unsafe_demo") as ctx:

@task(inject_context=True)
def init_counter(context: TaskExecutionContext):
context.get_channel().set("counter", 0)

# Create worker tasks that use naive get/set
workers = []
for i in range(num_workers):

@task(inject_context=True, id=f"unsafe_worker_{i}")
def unsafe_worker(context: TaskExecutionContext):
channel = context.get_channel()
for _ in range(increments_per_worker):
val = channel.get("counter")
time.sleep(0) # yield to trigger interleaving
channel.set("counter", val + 1)

workers.append(unsafe_worker)

@task(inject_context=True)
def report(context: TaskExecutionContext):
actual = context.get_channel().get("counter")
print(f" Expected counter: {expected}, Actual: {actual}")
if actual < expected:
print(" Updates lost!\n")
else:
print(" (Got lucky — no interleaving this run)\n")

parallel = ParallelGroup(workers, name="unsafe_group")
_ = init_counter >> parallel >> report
ctx.execute("init_counter")


def demo_atomic_add() -> None:
"""Show that atomic_add() is safe for parallel numeric updates."""
print("--- Safe parallel increment with atomic_add() ---")

num_workers = 5
increments_per_worker = 100
expected = num_workers * increments_per_worker

with workflow("add_demo") as ctx:

@task(inject_context=True)
def init_counter(context: TaskExecutionContext):
context.get_channel().set("counter", 0)

workers = []
for i in range(num_workers):

@task(inject_context=True, id=f"add_worker_{i}")
def add_worker(context: TaskExecutionContext):
channel = context.get_channel()
for _ in range(increments_per_worker):
channel.atomic_add("counter", 1)

workers.append(add_worker)

@task(inject_context=True)
def report(context: TaskExecutionContext):
actual = context.get_channel().get("counter")
print(f" Expected counter: {expected}, Actual: {actual}\n")

parallel = ParallelGroup(workers, name="add_group")
_ = init_counter >> parallel >> report
ctx.execute("init_counter")


def demo_advisory_lock() -> None:
"""Show lock() for compound read-modify-write that atomic_add() can't express."""
print("--- Safe compound update with lock() ---")

threshold = 10
num_workers = 5
increments_per_worker = 10

with workflow("lock_demo") as ctx:

@task(inject_context=True)
def init(context: TaskExecutionContext):
channel = context.get_channel()
channel.set("counter", 0)
channel.set("overflow_count", 0)

workers = []
for i in range(num_workers):

@task(inject_context=True, id=f"lock_worker_{i}")
def lock_worker(context: TaskExecutionContext):
channel = context.get_channel()
for _ in range(increments_per_worker):
# Advisory lock protects the entire read-modify-write block
with channel.lock("counter"):
val = channel.get("counter")
if val >= threshold:
channel.set("counter", 0)
channel.atomic_add("overflow_count", 1)
else:
channel.set("counter", val + 1)

workers.append(lock_worker)

@task(inject_context=True)
def report(context: TaskExecutionContext):
channel = context.get_channel()
overflows = channel.get("overflow_count")
counter = channel.get("counter")
print(f" Overflow events: {overflows}")
print(f" Counter after resets: {counter}\n")

parallel = ParallelGroup(workers, name="lock_group")
_ = init >> parallel >> report
ctx.execute("init")


def main():
print("=== Channel Concurrency Demo ===\n")
demo_unsafe_increment()
demo_atomic_add()
demo_advisory_lock()
print("Done!")


if __name__ == "__main__":
main()


# ============================================================================
# Key Takeaways:
# ============================================================================
#
# 1. **channel.atomic_add(key, amount)**
# - Atomic numeric add/subtract — no lost updates
# - Initialises missing keys to 0 automatically
# - MemoryChannel: per-key RLock; Redis: INCRBYFLOAT (server-side atomic)
# - Use for counters, metrics, scores
#
# 2. **channel.lock(key)**
# - Advisory lock for compound operations that atomic_add() can't express
# - Wrap with ``with channel.lock(key):`` context manager
# - MemoryChannel: per-key RLock; Redis: distributed lock for the same key
# - Use for conditional updates and other compound read-modify-write logic
#
# 3. **When to use which**
# - Simple counter? → channel.atomic_add("counter", 1)
# - Decrement? → channel.atomic_add("counter", -1)
# - Conditional update? → with channel.lock("key"): ...
# - Multi-key update? → with channel.lock("key"): ...
# - No concurrency concern? → channel.get() / channel.set() is fine
#
# ============================================================================
61 changes: 60 additions & 1 deletion graflow/channels/base.py
@@ -3,7 +3,8 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, List, Optional
from contextlib import contextmanager
from typing import Any, Iterator, List, Optional, Union


class Channel(ABC):
@@ -69,3 +70,61 @@ def prepend(self, key: str, value: Any, ttl: Optional[int] = None) -> int:
Length of the list after prepend
"""
pass

@abstractmethod
def atomic_add(self, key: str, amount: Union[int, float] = 1) -> Union[int, float]:
"""Atomically add *amount* to the numeric value stored at *key*.

If *key* does not exist, it is initialised to 0 before the addition.
Negative *amount* is allowed (decrement).

Args:
key: The key identifying the numeric value.
amount: The value to add (default 1). May be negative.

Returns:
The value after the addition.

Raises:
TypeError: If the existing value is not numeric.
"""
pass

@contextmanager
def lock(self, key: str, timeout: float = 10.0) -> Iterator[None]:
"""Acquire an advisory lock scoped to *key* for compound operations.

Usage::

with channel.lock("counter"):
val = channel.get("counter")
channel.set("counter", val * 2 if val > 0 else 0)

The lock is *advisory* — regular ``get``/``set`` calls do **not**
acquire it automatically. It exists for task authors who need to
protect read-modify-write sequences that cannot be expressed with
``atomic_add()``.

Both ``MemoryChannel`` (threading lock) and ``RedisChannel``
(``redis.lock.Lock`` — distributed SET NX + Lua release) provide
real mutual exclusion.

.. warning::

The **base class** default is a **no-op** (yields immediately)
and provides **no mutual exclusion**. Custom subclasses that
need compound-operation safety **must** override this method.

Args:
key: Logical key to lock on (does not need to correspond to a
stored key).
timeout: Maximum seconds to wait for the lock.

Raises:
TimeoutError: If the lock cannot be acquired within *timeout*
(raised by ``MemoryChannel`` and ``RedisChannel``).

Yields:
None — the lock is held for the duration of the ``with`` block.
"""
yield
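
A concrete override along the lines the docstring describes for `MemoryChannel` might look as follows. This is a sketch, not the actual Graflow implementation: the mixin name, attribute names, and error message are illustrative.

```python
import threading
from contextlib import contextmanager
from typing import Dict, Iterator


class PerKeyLockMixin:
    """Per-key advisory lock sketch: one RLock per logical key, plus a
    guard lock protecting the dict that holds them."""

    def __init__(self) -> None:
        self._key_locks: Dict[str, threading.RLock] = {}
        self._registry_lock = threading.Lock()

    def _lock_for(self, key: str) -> threading.RLock:
        # setdefault runs under the guard so two threads can never
        # create two different locks for the same key.
        with self._registry_lock:
            return self._key_locks.setdefault(key, threading.RLock())

    @contextmanager
    def lock(self, key: str, timeout: float = 10.0) -> Iterator[None]:
        lk = self._lock_for(key)
        if not lk.acquire(timeout=timeout):
            raise TimeoutError(
                f"could not acquire lock for {key!r} within {timeout}s"
            )
        try:
            yield
        finally:
            lk.release()
```

Because the underlying primitive is an `RLock`, the holding thread may re-enter `lock()` for the same key, which matters when a locked block calls a helper that also takes the lock; a distributed implementation such as the Redis one would need its own reentrancy and release-safety story.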