# Module 06: State Management and Checkpointing

**Estimated Time:** 75 minutes

## Learning Objectives

By the end of this module, you will:
- Understand Flink's state backends and storage options
- Configure and tune checkpointing for fault tolerance
- Use savepoints for job migration and upgrades
- Handle state schema evolution
- Monitor and optimize state size
- Implement queryable state patterns

---

## 1. Understanding State in Flink

### What is State?

**State** = Data maintained across events for stateful operations

```
Example: Running Sum (with event count)

Event 1: value=5  → state=5   (count: 1)
Event 2: value=3  → state=8   (count: 2)
Event 3: value=7  → state=15  (count: 3)
         ↑                ↑
      Input           State (remembered)
```

### Types of State

**1. Keyed State** (most common):
```
After key_by(), each key has its own state

key='user1' → state_1
key='user2' → state_2
key='user3' → state_3

Isolated: Changes to one key don't affect others
```

**2. Operator State**:
```
State scoped to operator instance (not keyed)

Use cases:
- Kafka source offsets
- Buffering elements
```
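
Operator state belongs to a parallel operator instance rather than to a key. The following is a minimal plain-Python sketch of the Kafka-offset use case, assuming a hypothetical `OffsetTrackingSource` class (it is not a Flink API) whose only state is its read position:

```python
class OffsetTrackingSource:
    """Hypothetical source whose only state is its read offset (operator state)."""

    def __init__(self, records):
        self.records = records
        self.offset = 0  # one offset per source instance, not per key

    def poll(self):
        """Return the next record, or None when the input is exhausted."""
        if self.offset >= len(self.records):
            return None
        record = self.records[self.offset]
        self.offset += 1
        return record

    def snapshot_state(self):
        """Capture the operator state for a checkpoint."""
        return {"offset": self.offset}

    def restore_state(self, snapshot):
        """Resume reading from the checkpointed offset."""
        self.offset = snapshot["offset"]


source = OffsetTrackingSource(["a", "b", "c", "d"])
source.poll()
source.poll()
snapshot = source.snapshot_state()

# A restarted instance resumes right after the last checkpointed record.
restarted = OffsetTrackingSource(["a", "b", "c", "d"])
restarted.restore_state(snapshot)
next_record = restarted.poll()
print(snapshot, next_record)
```

The snapshot contains no per-key data, which is what distinguishes it from keyed state.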

### State Primitives

| Type | Description | Use Case |
|------|-------------|----------|
| ValueState | Single value | Last seen value |
| ListState | List of values | Event history |
| MapState | Key-value map | Counts by category |
| ReducingState | Reduced values | Running sum |
| AggregatingState | Aggregated result | Custom aggregation |
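
ReducingState and AggregatingState look similar in the table, but AggregatingState keeps an accumulator whose type may differ from both the input and the result. A rough plain-Python sketch of a running average follows; `AverageAggregator` and `SimulatedAggregatingState` are illustrative names, loosely modeled on the create/add/get-result shape of an aggregate function, not PyFlink classes:

```python
class AverageAggregator:
    """Hypothetical aggregate function: accumulator is (sum, count), result is the mean."""

    def create_accumulator(self):
        return (0, 0)

    def add(self, value, accumulator):
        total, count = accumulator
        return (total + value, count + 1)

    def get_result(self, accumulator):
        total, count = accumulator
        return total / count if count else None


class SimulatedAggregatingState:
    """Per-key aggregating state: stores only the accumulator, not the raw values."""

    def __init__(self, aggregator):
        self.aggregator = aggregator
        self.accumulators = {}  # key -> accumulator

    def add(self, key, value):
        acc = self.accumulators.get(key, self.aggregator.create_accumulator())
        self.accumulators[key] = self.aggregator.add(value, acc)

    def get(self, key):
        if key not in self.accumulators:
            return None
        return self.aggregator.get_result(self.accumulators[key])


avg_state = SimulatedAggregatingState(AverageAggregator())
for value in (5, 3, 7):
    avg_state.add("user1", value)

print(avg_state.get("user1"))  # mean of 5, 3, 7
```

Only the `(sum, count)` pair is stored per key, so state size stays constant no matter how many events arrive.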

In [None]:
# Setup
import json
from datetime import datetime, timedelta
from collections import defaultdict
import random
import pickle
import time

print("[OK] Ready for state management examples")

In [None]:
# Simulate different state types
class StatefulProcessor:
    """Demonstrates different types of state"""

    def __init__(self):
        # ValueState: single value per key
        self.value_state = {}  # key -> value

        # ListState: list of values per key
        self.list_state = defaultdict(list)  # key -> [values]

        # MapState: nested map per key
        self.map_state = defaultdict(dict)  # key -> {subkey: value}

        # ReducingState: accumulated value
        self.reducing_state = defaultdict(int)  # key -> sum

    def update_value_state(self, key, value):
        """Update single value"""
        self.value_state[key] = value

    def append_list_state(self, key, value):
        """Append to list (keep last 5)"""
        self.list_state[key].append(value)
        # Keep only last 5
        if len(self.list_state[key]) > 5:
            self.list_state[key] = self.list_state[key][-5:]

    def update_map_state(self, key, subkey, value):
        """Update nested map"""
        self.map_state[key][subkey] = value

    def add_reducing_state(self, key, value):
        """Add to running sum"""
        self.reducing_state[key] += value

    def get_state_size(self):
        """Estimate state size in bytes"""
        total_size = 0
        total_size += len(pickle.dumps(self.value_state))
        total_size += len(pickle.dumps(self.list_state))
        total_size += len(pickle.dumps(self.map_state))
        total_size += len(pickle.dumps(self.reducing_state))
        return total_size

    def print_state(self):
        """Print current state"""
        print("\n[DATA] Current State:\n")
        print("ValueState (last value):")
        for key, value in list(self.value_state.items())[:3]:
            print(f"  {key}: {value}")

        print("\nListState (recent values):")
        for key, values in list(self.list_state.items())[:2]:
            print(f"  {key}: {values}")

        print("\nMapState (nested):")
        for key, nested in list(self.map_state.items())[:2]:
            print(f"  {key}: {nested}")

        print("\nReducingState (sum):")
        for key, total in list(self.reducing_state.items())[:3]:
            print(f"  {key}: {total}")

        print(f"\nState size: {self.get_state_size()} bytes")


# Test state types
processor = StatefulProcessor()

print("[OK] Processing events with different state types...\n")

for i in range(20):
    user_id = f"user_{random.randint(1, 3)}"
    value = random.randint(1, 10)
    action = random.choice(["view", "click", "purchase"])

    # Update different state types
    processor.update_value_state(user_id, value)
    processor.append_list_state(user_id, value)
    processor.update_map_state(user_id, action, value)
    processor.add_reducing_state(user_id, value)

processor.print_state()
print("\n[SUCCESS] State management demonstration complete!")

---

## 2. State Backends

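### Types of State Backends

The class names below are the legacy ones. Since Flink 1.13 they are deprecated in favor of `HashMapStateBackend` (heap) and `EmbeddedRocksDBStateBackend` (RocksDB), with checkpoint storage configured separately. The trade-offs described here still apply.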

**1. MemoryStateBackend** (Development only):
```
Storage: JVM heap
Checkpoints: JobManager memory

Pros:
- Fast (in-memory)
- Simple setup

Cons:
- Limited by heap size
- Checkpoints lost if the JobManager fails
- Not for production!
```

**2. FsStateBackend**:
```
Storage: JVM heap
Checkpoints: Distributed file system (HDFS, S3)

Pros:
- Fast access
- Durable checkpoints

Cons:
- Limited by heap size
- Practical only up to a few GB of state per operator
```

**3. RocksDBStateBackend** (Production recommended):
```
Storage: RocksDB (embedded KV store)
Checkpoints: Distributed file system

Pros:
- Scales to TB of state
- Incremental checkpoints
- Spills to disk

Cons:
- Slower than heap
- Serialization overhead
```
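
The "serialization overhead" of a RocksDB-style backend comes from state being held as bytes: each write serializes and each read deserializes. The toy comparison below is a sketch under that assumption; `HeapBackend` and `SerializedBackend` are made-up classes that use `pickle` in place of Flink's serializers:

```python
import pickle


class HeapBackend:
    """Toy heap backend: keeps live Python objects."""

    def __init__(self):
        self._store = {}

    def put(self, key, value):
        self._store[key] = value

    def get(self, key):
        return self._store.get(key)


class SerializedBackend:
    """Toy RocksDB-like backend: keeps serialized bytes only."""

    def __init__(self):
        self._store = {}

    def put(self, key, value):
        self._store[key] = pickle.dumps(value)  # serialize on every write

    def get(self, key):
        raw = self._store.get(key)
        return None if raw is None else pickle.loads(raw)  # deserialize on every read


heap, serialized = HeapBackend(), SerializedBackend()
for backend in (heap, serialized):
    backend.put("user1", [1, 2])
    backend.get("user1").append(3)  # mutate the returned object without calling put()

print(heap.get("user1"))        # in-place mutation is visible
print(serialized.get("user1"))  # mutation hit a copy and was lost
```

The practical lesson is to always write a modified value back to state explicitly, so that job logic behaves the same on either kind of backend.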

### Configuration

```python
from pyflink.datastream import StreamExecutionEnvironment
from pyflink.datastream.state_backend import RocksDBStateBackend

env = StreamExecutionEnvironment.get_execution_environment()

# Configure RocksDB backend (legacy class; the URI keyword is checkpoint_data_uri)
env.set_state_backend(
    RocksDBStateBackend(
        checkpoint_data_uri='file:///tmp/checkpoints',
        enable_incremental_checkpointing=True
    )
)
```
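
The `enable_incremental_checkpointing=True` flag in the configuration above means each checkpoint persists only what changed since the previous one. The sketch below models that idea at the level of individual keys, which is a simplification (RocksDB works with whole files, and deletions are ignored here); `IncrementalSnapshotter` is a hypothetical name:

```python
class IncrementalSnapshotter:
    """Toy model: a checkpoint stores only keys changed since the previous one."""

    def __init__(self):
        self.state = {}
        self.dirty_keys = set()
        self.checkpoints = []  # list of deltas, oldest first

    def put(self, key, value):
        self.state[key] = value
        self.dirty_keys.add(key)

    def checkpoint(self):
        """Persist only the changed entries and return how many were written."""
        delta = {key: self.state[key] for key in self.dirty_keys}
        self.checkpoints.append(delta)
        self.dirty_keys.clear()
        return len(delta)

    def restore(self):
        """Rebuild the full state by replaying deltas in order."""
        restored = {}
        for delta in self.checkpoints:
            restored.update(delta)
        return restored


snapshotter = IncrementalSnapshotter()
for i in range(1000):
    snapshotter.put(f"user_{i}", 0)
first_size = snapshotter.checkpoint()

snapshotter.put("user_7", 1)
snapshotter.put("user_42", 1)
second_size = snapshotter.checkpoint()

print(first_size, second_size)
```

The second checkpoint writes two entries instead of a thousand, which is why incremental checkpoints pay off when state is large but changes slowly.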

---

## 3. Checkpointing

### What is Checkpointing?

**Periodic snapshots** of all operator state for fault tolerance:

```
Processing Timeline:

10:00 ── Events ── 10:01 ── Events ── 10:02 ── Events ── 10:03
                     ↓                  ↓                  ↓
               Checkpoint 1       Checkpoint 2       Checkpoint 3
               (state saved)      (state saved)      (state saved)

Failure at 10:02:30:
  ↓
Restore from Checkpoint 2 (10:02)
Replay events from 10:02 to 10:02:30
  ↓
Resume normal processing
```
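
Recovery works because a checkpoint pairs the operator state with the source position that state reflects. A minimal sketch of restore-and-replay, assuming a replayable in-memory event list and a hypothetical `run_job` helper:

```python
from collections import defaultdict


def run_job(events, fail_at=None, checkpoint_every=4):
    """Count events per user; optionally crash once at index `fail_at` and recover."""
    counts = defaultdict(int)
    checkpoint = {"offset": 0, "counts": {}}
    position = 0
    failed = False

    while position < len(events):
        if fail_at is not None and position == fail_at and not failed:
            failed = True
            # Recovery: restore state AND rewind the source to the checkpointed offset.
            counts = defaultdict(int, checkpoint["counts"])
            position = checkpoint["offset"]
            continue

        counts[events[position]] += 1
        position += 1

        if position % checkpoint_every == 0:
            # The checkpoint pairs the state with the source position it reflects.
            checkpoint = {"offset": position, "counts": dict(counts)}

    return dict(counts)


events = ["u1", "u2", "u1", "u3", "u2", "u1", "u1", "u3", "u2", "u1"]
without_failure = run_job(events)
with_failure = run_job(events, fail_at=6)

print(without_failure == with_failure)
```

Events 4 and 5 are processed twice in the failing run, yet the final counts match because the state was rolled back to the same point as the source.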

### Checkpoint Process

```
1. JobManager triggers checkpoint
2. Barrier injected into stream
3. Operators align barriers
4. Take state snapshot
5. Acknowledge to JobManager
6. Checkpoint complete!
```
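
Step 3, barrier alignment, is the least intuitive part of the process. The sketch below simulates a two-input operator that holds back post-barrier records from the faster channel until the barrier has arrived on every channel. `AligningSumOperator` and the `BARRIER` marker are illustrative, not Flink classes:

```python
BARRIER = "BARRIER"


class AligningSumOperator:
    """Toy two-input operator that sums values and aligns checkpoint barriers."""

    def __init__(self, num_channels=2):
        self.num_channels = num_channels
        self.total = 0
        self.blocked = set()   # channels whose barrier has already arrived
        self.buffered = []     # post-barrier values held back during alignment
        self.snapshots = []

    def on_element(self, channel, item):
        if item == BARRIER:
            self.blocked.add(channel)
            if len(self.blocked) == self.num_channels:
                self._complete_alignment()
        elif channel in self.blocked:
            # Belongs to the next checkpoint epoch: hold it until alignment ends.
            self.buffered.append(item)
        else:
            self.total += item

    def _complete_alignment(self):
        self.snapshots.append(self.total)  # snapshot contains pre-barrier data only
        self.blocked.clear()
        for item in self.buffered:
            self.total += item
        self.buffered.clear()


op = AligningSumOperator()
stream = [(0, 1), (1, 10), (0, BARRIER), (0, 2), (1, 20), (1, BARRIER), (1, 30)]
for channel, item in stream:
    op.on_element(channel, item)

print(op.snapshots, op.total)
```

The value `2` arrives on channel 0 after its barrier, so it is excluded from the snapshot and applied only once alignment completes.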

### Configuration

```python
# Enable checkpointing
env.enable_checkpointing(10000)  # Every 10 seconds

# Configure checkpoint behavior
checkpoint_config = env.get_checkpoint_config()
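checkpoint_config.set_min_pause_between_checkpoints(5000)  # At least 5 s between checkpoints
checkpoint_config.set_checkpoint_timeout(60000)  # Abort a checkpoint after 60 s
checkpoint_config.set_max_concurrent_checkpoints(1)  # One checkpoint in flight at a time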

# Cleanup policy
from pyflink.datastream import ExternalizedCheckpointCleanup
checkpoint_config.enable_externalized_checkpoints(
    ExternalizedCheckpointCleanup.RETAIN_ON_CANCELLATION
)
```
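
The checkpoint interval and the minimum pause interact: a slow checkpoint pushes the next one back. The helper below is a simplified model of that scheduling, assuming at most one concurrent checkpoint; `checkpoint_start_times` is a hypothetical function, not part of PyFlink:

```python
def checkpoint_start_times(interval_ms, min_pause_ms, durations_ms):
    """Return start times (ms) for checkpoints with the given durations.

    Assumes at most one concurrent checkpoint: the next checkpoint starts at the
    next interval tick, but never earlier than `min_pause_ms` after the previous
    checkpoint finished.
    """
    starts = []
    next_tick = interval_ms
    previous_end = None
    for duration in durations_ms:
        start = next_tick
        if previous_end is not None:
            start = max(start, previous_end + min_pause_ms)
        starts.append(start)
        previous_end = start + duration
        next_tick = start + interval_ms
    return starts


# Settings from the configuration above: 10 s interval, 5 s minimum pause.
fast = checkpoint_start_times(10_000, 5_000, [1_000, 1_000, 1_000])
slow = checkpoint_start_times(10_000, 5_000, [1_000, 8_000, 1_000])
print(fast)
print(slow)
```

With fast checkpoints the interval dominates. When the second checkpoint takes 8 s, the third is delayed until the 5 s pause has elapsed.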

In [None]:
# Simulate checkpointing
class CheckpointManager:
    """Manages checkpoints for fault tolerance"""

    def __init__(self, checkpoint_interval_seconds=5):
        self.interval = checkpoint_interval_seconds
        self.checkpoints = []  # List of checkpoints
        self.last_checkpoint_time = None

    def should_checkpoint(self, current_time):
        """Check if it's time to checkpoint"""
        if self.last_checkpoint_time is None:
            return True

        elapsed = (current_time - self.last_checkpoint_time).total_seconds()
        return elapsed >= self.interval

    def create_checkpoint(self, state, timestamp):
        """Create checkpoint snapshot"""
        serialized = pickle.dumps(state)  # Serialize state once
        checkpoint = {
            "id": len(self.checkpoints) + 1,
            "timestamp": timestamp,
            "state": serialized,
            "size_bytes": len(serialized),
        }

        self.checkpoints.append(checkpoint)
        self.last_checkpoint_time = timestamp

        return checkpoint["id"]

    def restore_from_checkpoint(self, checkpoint_id):
        """Restore state from checkpoint"""
        for checkpoint in self.checkpoints:
            if checkpoint["id"] == checkpoint_id:
                return pickle.loads(checkpoint["state"])
        return None

    def get_latest_checkpoint(self):
        """Get most recent checkpoint"""
        if self.checkpoints:
            return self.checkpoints[-1]
        return None

    def print_checkpoints(self):
        """Print checkpoint history"""
        print("\n[DATA] Checkpoint History:\n")
        for cp in self.checkpoints:
            print(f"Checkpoint {cp['id']}:")
            print(f"  Time: {cp['timestamp'].strftime('%H:%M:%S')}")
            print(f"  Size: {cp['size_bytes']} bytes")


# Simulate processing with checkpoints
class ProcessorWithCheckpoints:
    """Processor that creates periodic checkpoints"""

    def __init__(self):
        self.state = defaultdict(int)  # user_id -> count
        self.checkpoint_manager = CheckpointManager(checkpoint_interval_seconds=3)
        self.events_processed = 0

    def process_event(self, event, timestamp):
        """Process event and checkpoint if needed"""
        # Update state
        self.state[event["user_id"]] += 1
        self.events_processed += 1

        # Check if checkpoint needed
        if self.checkpoint_manager.should_checkpoint(timestamp):
            cp_id = self.checkpoint_manager.create_checkpoint(dict(self.state), timestamp)
            print(f"  [Checkpoint {cp_id} created at {timestamp.strftime('%H:%M:%S')}]")

    def simulate_failure_and_recovery(self):
        """Simulate failure and restore from checkpoint"""
        print("\n[WARNING] Simulating failure...")
        print(f"State before failure: {dict(list(self.state.items())[:3])}")

        # Get latest checkpoint
        latest_cp = self.checkpoint_manager.get_latest_checkpoint()

        if latest_cp:
            print(f"\n[OK] Restoring from checkpoint {latest_cp['id']}...")
            restored_state = self.checkpoint_manager.restore_from_checkpoint(latest_cp["id"])
            self.state = defaultdict(int, restored_state)
            print(f"State after recovery: {dict(list(self.state.items())[:3])}")


# Test checkpointing
processor = ProcessorWithCheckpoints()

print("[OK] Processing events with periodic checkpoints...\n")

base_time = datetime.now()
for i in range(20):
    event = {"user_id": f"user_{random.randint(1, 5)}", "action": "click"}
    timestamp = base_time + timedelta(seconds=i * 0.5)

    processor.process_event(event, timestamp)

    if (i + 1) % 5 == 0:
        print(f"Processed {i + 1} events...")

processor.checkpoint_manager.print_checkpoints()
processor.simulate_failure_and_recovery()

print("\n[SUCCESS] Checkpointing enables fault tolerance!")

---

## 4. Savepoints

### Savepoints vs Checkpoints

| Feature | Checkpoint | Savepoint |
|---------|-----------|------------|
| Purpose | Auto recovery | Manual snapshot |
| Triggered by | Flink | User |
| Lifetime | Deleted by Flink when superseded or on cancellation (unless retained) | Kept until the user deletes it |
| Use case | Fault tolerance | Upgrade, migration |

### Savepoint Use Cases

**1. Job Upgrades:**
```
1. Take savepoint of running job
2. Cancel job
3. Deploy new version
4. Start from savepoint
   → State preserved!
```

**2. Cluster Migration:**
```
1. Savepoint on Cluster A
2. Copy savepoint to Cluster B
3. Start job on Cluster B
   → Seamless migration!
```

**3. A/B Testing:**
```
1. Savepoint from production
2. Start test job from savepoint
3. Compare results
```
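
The job-upgrade flow from the first use case can be sketched in plain Python. This is an illustration only: the two job classes and the `take_savepoint` / `start_from_savepoint` helpers are hypothetical, and JSON stands in for Flink's savepoint format:

```python
import json


class CountingJobV1:
    """Version 1: counts events per user."""

    def __init__(self, state=None):
        self.counts = dict(state or {})

    def process(self, user_id):
        self.counts[user_id] = self.counts.get(user_id, 0) + 1


class CountingJobV2(CountingJobV1):
    """Version 2: same state layout, plus new logic that flags heavy users."""

    def heavy_users(self, threshold):
        return sorted(user for user, n in self.counts.items() if n >= threshold)


def take_savepoint(job):
    """Serialize the job state to a portable, self-contained snapshot."""
    return json.dumps(job.counts)


def start_from_savepoint(job_class, savepoint):
    """Start a (possibly newer) job version with the saved state."""
    return job_class(state=json.loads(savepoint))


# 1. Take a savepoint of the running job, 2. cancel it.
old_job = CountingJobV1()
for user in ["u1", "u2", "u1", "u1"]:
    old_job.process(user)
savepoint = take_savepoint(old_job)
del old_job

# 3. Deploy the new version, 4. start it from the savepoint.
new_job = start_from_savepoint(CountingJobV2, savepoint)
new_job.process("u2")
print(new_job.counts, new_job.heavy_users(threshold=3))
```

The new version keeps counting from where the old one stopped. This only works because both versions agree on the state layout, which is the subject of the next section.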

### Creating Savepoints

```bash
# Trigger savepoint
flink savepoint <jobId> [targetDirectory]

# Start from savepoint
flink run -s <savepointPath> <jobJar>

# Dispose savepoint
flink savepoint -d <savepointPath>
```

---

## 5. State Schema Evolution

### The Problem

```
Version 1:              Version 2:
class UserState {       class UserState {
  String name;            String name;
  int age;                int age;
}                         String email;  ← New field!
                        }

How to upgrade without losing state?
```

### Solutions

**1. POJO Evolution:**
```java
// Adding fields: OK (default values)
// Removing fields: OK (ignored)
// Changing types: NOT OK!
```
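
The three POJO rules can be mimicked with a Python dataclass. This is an analogy rather than Flink's serializer: `UserStateV2` mirrors the example above, and `restore_pojo` is a hypothetical helper that applies the add, remove, and type-change rules to a stored record:

```python
from dataclasses import dataclass, fields


@dataclass
class UserStateV2:
    """New schema: `email` was added, `age` was kept, nothing was retyped."""
    name: str = ""
    age: int = 0
    email: str = ""


def restore_pojo(cls, old_record):
    """Restore a stored record into `cls`, POJO-style.

    - fields missing from the old record fall back to the dataclass defaults
    - fields no longer declared on `cls` are dropped
    - a stored value whose type differs from the declared type is rejected
    """
    kwargs = {}
    for field in fields(cls):
        if field.name not in old_record:
            continue  # added field: keep the default
        value = old_record[field.name]
        if not isinstance(value, field.type):
            raise TypeError(f"type of field '{field.name}' changed")
        kwargs[field.name] = value
    return cls(**kwargs)


# A version 1 record that also carries a field the new class no longer declares.
v1_record = {"name": "Ada", "age": 36, "legacy_flag": True}
restored = restore_pojo(UserStateV2, v1_record)
print(restored)
```

Passing a record such as `{"name": "Ada", "age": "36"}` raises `TypeError`, which corresponds to the "changing types" case.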

**2. Avro Schema Evolution:**
```
Version 1: {name: string, age: int}
Version 2: {name: string, age: int, email: string (default="")}

Avro handles:
- New fields with defaults
- Removed fields
- Type promotions
```
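
Avro resolves a record by comparing the schema it was written with against the schema the reader expects. The sketch below imitates that resolution with plain dicts and does not use the Avro library. The schema layout, the `resolve` function, and the reduced promotion table are simplifying assumptions, and the extra `nickname` field and the `int` to `long` change were added to exercise every rule:

```python
# Simplified subset of Avro's promotion rules: writer type -> allowed reader types.
PROMOTIONS = {
    "int": {"int", "long", "float", "double"},
    "long": {"long", "float", "double"},
    "float": {"float", "double"},
    "double": {"double"},
    "string": {"string"},
}
NUMERIC_CASTS = {"int": int, "long": int, "float": float, "double": float}


def resolve(record, writer_schema, reader_schema):
    """Read `record` (written with writer_schema) as seen through reader_schema.

    Schemas here are plain dicts: field name -> {"type": ..., "default": ...}.
    """
    result = {}
    for name, reader_field in reader_schema.items():
        if name in writer_schema:
            writer_type = writer_schema[name]["type"]
            reader_type = reader_field["type"]
            if reader_type not in PROMOTIONS[writer_type]:
                raise ValueError(f"cannot read {writer_type} as {reader_type}")
            cast = NUMERIC_CASTS.get(reader_type, str)
            result[name] = cast(record[name])
        elif "default" in reader_field:
            result[name] = reader_field["default"]  # new field with a default
        else:
            raise ValueError(f"field '{name}' has no default")
    return result  # writer fields unknown to the reader are skipped


v1 = {"name": {"type": "string"}, "age": {"type": "int"}, "nickname": {"type": "string"}}
v2 = {
    "name": {"type": "string"},
    "age": {"type": "long"},                      # promoted int -> long
    "email": {"type": "string", "default": ""},   # added with default
}
resolved = resolve({"name": "Ada", "age": 36, "nickname": "al"}, v1, v2)
print(resolved)
```

A new field without a default raises an error here, which is why every added field in an evolving schema should declare one.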

**3. Custom Serializers:**
```python
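# Conceptual sketch in plain Python (not Flink's TypeSerializer API):
# prefix every record with a version byte and migrate old versions on read.
import json

class VersionedSerializer:
    CURRENT_VERSION = 2

    def serialize(self, obj, version=CURRENT_VERSION):
        # Write version byte + JSON-encoded data
        return bytes([version]) + json.dumps(obj).encode("utf-8")

    def deserialize(self, data):
        # Read version, migrate if needed
        version, obj = data[0], json.loads(data[1:].decode("utf-8"))
        if version < 2:
            obj.setdefault("email", "")  # field added in version 2
        return obj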
```

---

## 6. Monitoring State

### Key Metrics

**State Size:**
```
Track: Total state size per operator
Alert if: Growing unboundedly
Action: Add TTL, compact state
```

**Checkpoint Duration:**
```
Track: Time to complete checkpoint
Alert if: > 1 minute
Action: Tune checkpoint interval, enable incremental checkpoints
```

**Checkpoint Alignment:**
```
Track: Time spent aligning barriers
Alert if: High alignment time
Action: Check for slow operators
```
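
The alert rules above can be turned into a small check over collected metrics. This is a sketch with made-up inputs: `state_alerts` and the `state_bytes` / `checkpoint_ms` field names are hypothetical, and a real setup would read these values from Flink's metrics system:

```python
def state_alerts(samples, max_checkpoint_ms=60_000, growth_window=3):
    """Return alert strings for a list of per-checkpoint metric samples.

    Each sample is a dict with `state_bytes` and `checkpoint_ms` (hypothetical
    field names). State is flagged when it grew in each of the last
    `growth_window` consecutive samples.
    """
    alerts = []
    latest = samples[-1]
    if latest["checkpoint_ms"] > max_checkpoint_ms:
        alerts.append("checkpoint duration above threshold")

    recent = [s["state_bytes"] for s in samples[-(growth_window + 1):]]
    grew_every_time = len(recent) == growth_window + 1 and all(
        later > earlier for earlier, later in zip(recent, recent[1:])
    )
    if grew_every_time:
        alerts.append("state size growing steadily: consider TTL")
    return alerts


samples = [
    {"state_bytes": 100, "checkpoint_ms": 2_000},
    {"state_bytes": 180, "checkpoint_ms": 9_000},
    {"state_bytes": 260, "checkpoint_ms": 30_000},
    {"state_bytes": 400, "checkpoint_ms": 75_000},
]
alerts = state_alerts(samples)
print(alerts)
```

Both rules fire for this sample data: the latest checkpoint took longer than one minute, and state grew in each of the last three intervals.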

### State TTL (Time-To-Live)

```python
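from pyflink.common.time import Time
from pyflink.common.typeinfo import Types
from pyflink.datastream.state import StateTtlConfig, ValueStateDescriptor

ttl_config = StateTtlConfig \
    .new_builder(Time.hours(1)) \
    .set_update_type(StateTtlConfig.UpdateType.OnCreateAndWrite) \
    .set_state_visibility(StateTtlConfig.StateVisibility.NeverReturnExpired) \
    .build()

# Apply to a state descriptor (the name "last_seen" is illustrative)
state_descriptor = ValueStateDescriptor("last_seen", Types.LONG())
state_descriptor.enable_time_to_live(ttl_config)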
```

In [None]:
# Simulate State TTL
class StateWithTTL:
    """State that automatically expires old entries"""

    def __init__(self, ttl_seconds=10):
        self.ttl = timedelta(seconds=ttl_seconds)
        self.state = {}  # key -> (value, timestamp)

    def put(self, key, value, timestamp):
        """Store value with timestamp"""
        self.state[key] = (value, timestamp)

    def get(self, key, current_time):
        """Get value if not expired"""
        if key not in self.state:
            return None

        value, timestamp = self.state[key]
        age = current_time - timestamp

        if age > self.ttl:
            # Expired!
            del self.state[key]
            return None

        return value

    def cleanup_expired(self, current_time):
        """Remove all expired entries"""
        expired_keys = [
            key
            for key, (value, timestamp) in self.state.items()
            if current_time - timestamp > self.ttl
        ]

        for key in expired_keys:
            del self.state[key]

        return len(expired_keys)

    def size(self):
        """Current state size"""
        return len(self.state)


# Test TTL
state = StateWithTTL(ttl_seconds=5)

base_time = datetime.now()

print("[OK] Testing state with TTL (5 seconds)...\n")

# Add entries
for i in range(5):
    key = f"key_{i}"
    timestamp = base_time + timedelta(seconds=i)
    state.put(key, f"value_{i}", timestamp)
    print(f"Added {key} at {timestamp.strftime('%H:%M:%S')}")

print(f"\nState size: {state.size()}")

# Check at different times
for seconds_elapsed in [3, 6, 10]:
    current_time = base_time + timedelta(seconds=seconds_elapsed)
    expired = state.cleanup_expired(current_time)

    print(f"\nAt +{seconds_elapsed}s: Expired {expired} entries, {state.size()} remaining")

print("\n[OK] TTL prevents unbounded state growth!")

---

## 7. Key Takeaways

[OK] **State Types**: ValueState, ListState, MapState for different use cases

[OK] **State Backends**: Use RocksDB for production (scales to TB)

[OK] **Checkpointing**: Automatic fault tolerance via periodic snapshots

[OK] **Savepoints**: Manual snapshots for upgrades and migrations

[OK] **State Evolution**: Plan for schema changes

[OK] **State TTL**: Prevent unbounded growth with expiration

### Best Practices

1. **Use RocksDB backend** for production
2. **Enable incremental checkpointing** for large state
3. **Set appropriate checkpoint interval** (10-60 seconds)
4. **Monitor state size** and checkpoint duration
5. **Use TTL** for session-based state
6. **Test savepoint compatibility** before upgrades
7. **Externalize checkpoints** for recovery after cancellation

### Configuration Checklist

```python
# Recommended production settings
env.set_state_backend(RocksDBStateBackend(...))
env.enable_checkpointing(30000)  # 30s

config = env.get_checkpoint_config()
config.set_checkpoint_timeout(600000)  # 10min
config.set_min_pause_between_checkpoints(15000)  # 15s
config.set_max_concurrent_checkpoints(1)
config.enable_externalized_checkpoints(
    ExternalizedCheckpointCleanup.RETAIN_ON_CANCELLATION
)
```

---

## 8. Practice Exercises

1. **Implement checkpointing** with configurable interval
2. **Add state TTL** to prevent memory leaks
3. **Monitor state size** over time
4. **Simulate failure recovery** from checkpoint
5. **Test state migration** with schema changes

In [None]:
# Your practice code here

---

## 9. Next Steps

Congratulations on completing Module 06!

### What You've Learned

- [OK] State types and state backends
- [OK] Checkpointing configuration and tuning
- [OK] Savepoints for job migration
- [OK] State schema evolution strategies
- [OK] State monitoring and TTL

### Coming Up in Module 07: Stream Processing Patterns

You'll learn:
- Common stream processing patterns
- Event deduplication
- CDC (Change Data Capture)
- Exactly-once end-to-end
- Best practices and anti-patterns

### Resources

- [Flink State](https://nightlies.apache.org/flink/flink-docs-master/docs/dev/datastream/fault-tolerance/state/)
- [Checkpointing](https://nightlies.apache.org/flink/flink-docs-master/docs/dev/datastream/fault-tolerance/checkpointing/)
- [Savepoints](https://nightlies.apache.org/flink/flink-docs-master/docs/ops/state/savepoints/)
- [State Backends](https://nightlies.apache.org/flink/flink-docs-master/docs/ops/state/state_backends/)

---

**Ready for patterns?** Open `07_stream_processing_patterns.ipynb` to continue!