# Day 32: Continuous Batching - Part 5b

Continuous batching allows requests to join and leave batches dynamically, improving efficiency over static batching.

## Overview
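1. Understanding continuous batching (with a simple simulated implementation)
2. Testing continuous batching
3. Comparing static vs. continuous batching (performance comparison)
4. Visualizing the difference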

In [None]:
import time
import threading
import queue
import numpy as np
import matplotlib.pyplot as plt
from collections import deque

## 1. Understanding Continuous Batching

Unlike static batching, continuous batching:
- Processes requests as they arrive
- Removes completed sequences immediately
- Adds new requests to existing batches
- Maximizes GPU utilization

In [None]:
class ContinuousBatchScheduler:
    """Simulated continuous-batching scheduler driven by a background thread.

    Requests are queued by ``submit_request``. The worker thread started by
    ``start`` repeatedly admits queued requests into the active batch (up to
    ``max_batch_size``), "generates" one token per active request per step,
    and retires requests as soon as they reach their token budget so the
    freed slots can be refilled on the next step.

    No model is run: a generation step is simulated with ``time.sleep``.
    """

    def __init__(self, max_batch_size=8, token_generation_time=0.05):
        """Create an idle scheduler.

        Args:
            max_batch_size: Maximum number of requests active at once.
            token_generation_time: Base duration in seconds of one simulated
                generation step, before the batch efficiency factor.
        """
        self.max_batch_size = max_batch_size
        self.token_generation_time = token_generation_time
        self.active_requests = {}            # request id -> request dict
        self.request_queue = queue.Queue()   # submitted, not yet admitted
        self.completed_requests = []         # finished request dicts
        self.running = False
        self.thread = None                   # worker thread, set by start()

    def submit_request(self, request_id, prompt, max_tokens=20):
        """Queue a request; ``arrival_time`` is stamped at submission.

        Args:
            request_id: Key used for the request in ``active_requests``.
            prompt: Initial text; generated tokens are appended to it.
            max_tokens: Number of tokens to generate before completion.
        """
        request = {
            'id': request_id,
            'prompt': prompt,
            'max_tokens': max_tokens,
            'generated_tokens': 0,
            'arrival_time': time.time(),
            'start_time': None,
            'completion_time': None,
            'result': prompt
        }
        self.request_queue.put(request)

    def start(self):
        """Start the background worker; a no-op if it is already running."""
        # A second worker would mutate active_requests concurrently with the
        # first, so never spawn one while the current thread is alive.
        if self.thread is not None and self.thread.is_alive():
            return
        self.running = True
        self.thread = threading.Thread(target=self._process_continuously)
        self.thread.daemon = True
        self.thread.start()

    def stop(self):
        """Signal the worker to exit and wait for it to finish its step.

        Requests still queued or active when the worker exits are left
        unfinished.
        """
        self.running = False
        if self.thread is not None:
            self.thread.join()

    def _process_continuously(self):
        """Worker loop: admit, generate one step, retire; idle when empty."""
        while self.running:
            self._admit_waiting()

            if self.active_requests:
                self._generate_tokens()
                self._remove_completed()
            else:
                time.sleep(0.01)  # Small sleep if no active requests

    def _admit_waiting(self):
        """Move queued requests into the active batch while slots are free."""
        while len(self.active_requests) < self.max_batch_size:
            try:
                request = self.request_queue.get_nowait()
            except queue.Empty:
                break
            request['start_time'] = time.time()
            self.active_requests[request['id']] = request

    def _generate_tokens(self):
        """Simulate one generation step for every active request."""
        batch_size = len(self.active_requests)
        if batch_size == 0:
            return  # log(0) below would be undefined

        # Step time shrinks logarithmically with batch size, floored at half
        # of the base token_generation_time.
        efficiency = max(0.5, 1.0 - 0.1 * np.log(batch_size))
        time.sleep(self.token_generation_time * efficiency)

        for request in self.active_requests.values():
            # Never exceed the budget (matters for max_tokens <= 0, which
            # would otherwise receive one token before being retired).
            if request['generated_tokens'] < request['max_tokens']:
                request['generated_tokens'] += 1
                request['result'] += " token"

    def _remove_completed(self):
        """Retire every active request that has reached its token budget."""
        finished_ids = [
            req_id
            for req_id, request in self.active_requests.items()
            if request['generated_tokens'] >= request['max_tokens']
        ]
        for req_id in finished_ids:
            request = self.active_requests.pop(req_id)
            request['completion_time'] = time.time()
            self.completed_requests.append(request)

    def get_metrics(self):
        """Summarise completed requests.

        Returns:
            ``{}`` if nothing has completed yet; otherwise a dict with
            ``num_completed``, ``avg_latency`` (mean of completion minus
            arrival time, seconds) and ``throughput`` (completed requests
            divided by the span from earliest arrival to latest completion,
            or 0 if that span is not positive).
        """
        # Work on a snapshot so a worker appending meanwhile cannot change
        # the list between the individual aggregations below.
        completed = list(self.completed_requests)
        if not completed:
            return {}

        latencies = [r['completion_time'] - r['arrival_time'] for r in completed]
        total_time = (max(r['completion_time'] for r in completed)
                      - min(r['arrival_time'] for r in completed))

        return {
            'num_completed': len(completed),
            'avg_latency': np.mean(latencies),
            'throughput': len(completed) / total_time if total_time > 0 else 0
        }

## 2. Testing Continuous Batching

In [None]:
def test_continuous_batching(max_batch_size=8, num_requests=20, arrival_rate=3):
    """Run a continuous-batching simulation and return its metrics.

    Submits ``num_requests`` requests to a fresh ``ContinuousBatchScheduler``,
    spaced ``1 / arrival_rate`` seconds apart, each with a token budget drawn
    from ``np.random.randint(5, 15)``, then blocks until all have completed.

    Args:
        max_batch_size: Passed through to the scheduler.
        num_requests: Number of requests to submit.
        arrival_rate: Requests submitted per second.

    Returns:
        The dict produced by the scheduler's ``get_metrics()``.
    """
    scheduler = ContinuousBatchScheduler(max_batch_size=max_batch_size)
    scheduler.start()

    # Always stop the worker, even if the cell is interrupted or an error is
    # raised while submitting/waiting; otherwise the thread keeps running.
    try:
        # Submit requests
        for i in range(num_requests):
            prompt = f"Request {i}: AI will"
            max_tokens = np.random.randint(5, 15)
            scheduler.submit_request(i, prompt, max_tokens)
            time.sleep(1.0 / arrival_rate)

        # Wait for completion
        while len(scheduler.completed_requests) < num_requests:
            time.sleep(0.1)
    finally:
        scheduler.stop()

    return scheduler.get_metrics()

# Run one continuous-batching simulation and print its summary metrics
metrics = test_continuous_batching(max_batch_size=8, num_requests=20, arrival_rate=3)
for summary_line in (
    f"Completed: {metrics['num_completed']}",
    f"Avg Latency: {metrics['avg_latency']:.2f}s",
    f"Throughput: {metrics['throughput']:.2f} req/s",
):
    print(summary_line)

## 3. Comparing Static vs Continuous Batching

In [None]:
# Simple static batching for comparison
class StaticBatchScheduler:
    """Simulated static batching: fixed-size batches run to completion.

    Requests are sliced into consecutive groups of ``batch_size``. Each batch
    takes as long as its longest request, and every request in the batch
    completes at the same time. Processing is simulated with ``time.sleep``.
    """

    def __init__(self, batch_size=8, token_generation_time=0.05):
        """Create a scheduler.

        Args:
            batch_size: Number of requests per fixed batch.
            token_generation_time: Base duration in seconds of one simulated
                generation step, before the batch efficiency factor.
        """
        self.batch_size = batch_size
        self.token_generation_time = token_generation_time
        self.completed_requests = []

    def process_requests(self, requests):
        """Process ``requests`` in order, ``batch_size`` at a time (blocking).

        Each request dict must provide ``prompt`` and ``max_tokens``; an
        ``arrival_time`` (epoch seconds) is honoured if present. The dicts are
        updated in place with ``start_time``, ``completion_time`` and
        ``result`` and appended to ``completed_requests``.
        """
        for i in range(0, len(requests), self.batch_size):
            batch = requests[i:i + self.batch_size]
            max_tokens = max(req['max_tokens'] for req in batch)

            # A static batch cannot start before its last member has arrived;
            # starting earlier would yield completion times that precede
            # arrival times (negative latencies in get_metrics).
            batch_ready_at = max(req.get('arrival_time', 0.0) for req in batch)
            remaining = batch_ready_at - time.time()
            while remaining > 0:
                time.sleep(remaining)
                remaining = batch_ready_at - time.time()

            # Mark start time
            start_time = time.time()
            for req in batch:
                req['start_time'] = start_time

            # The whole batch runs for as many steps as its longest request.
            # Step time shrinks logarithmically with batch size, floored at
            # half of the base token_generation_time.
            efficiency = max(0.5, 1.0 - 0.1 * np.log(len(batch)))
            processing_time = max_tokens * self.token_generation_time * efficiency
            time.sleep(processing_time)

            # Mark completion: everyone in the batch finishes together
            completion_time = time.time()
            for req in batch:
                req['completion_time'] = completion_time
                # One " token" per generated token, space-separated.
                req['result'] = req['prompt'] + " token" * req['max_tokens']
                self.completed_requests.append(req)

    def get_metrics(self):
        """Summarise completed requests.

        Returns:
            ``{}`` if nothing has completed; otherwise a dict with
            ``num_completed``, ``avg_latency`` (mean of completion minus
            arrival time, seconds) and ``throughput`` (completed requests
            divided by the span from earliest arrival to latest completion,
            or 0 if that span is not positive).
        """
        completed = self.completed_requests
        if not completed:
            return {}

        latencies = [r['completion_time'] - r['arrival_time'] for r in completed]
        total_time = (max(r['completion_time'] for r in completed)
                      - min(r['arrival_time'] for r in completed))

        return {
            'num_completed': len(completed),
            'avg_latency': np.mean(latencies),
            'throughput': len(completed) / total_time if total_time > 0 else 0
        }

In [None]:
def compare_scheduling_methods(num_requests=20, arrival_rate=3):
    """Run static and continuous batching on the same workload.

    One set of requests (prompts and random token budgets from
    ``np.random.randint(5, 15)``) is generated and fed to both schedulers, so
    differences in the metrics come from scheduling rather than from two
    different random workloads. The static run happens first, then the
    continuous run; each uses a batch size of 8.

    Args:
        num_requests: Number of requests in the workload.
        arrival_rate: Requests per second; consecutive requests are spaced
            ``1 / arrival_rate`` seconds apart.

    Returns:
        ``(static_metrics, continuous_metrics)``, each the dict returned by
        the respective scheduler's ``get_metrics()``.
    """
    # Create requests; arrival times are stamped relative to "now"
    requests = []
    arrival_time = time.time()

    for i in range(num_requests):
        req = {
            'id': i,
            'prompt': f"Request {i}: AI will",
            'max_tokens': np.random.randint(5, 15),
            'arrival_time': arrival_time + i / arrival_rate
        }
        requests.append(req)

    # Test static batching (on copies, so the shared workload stays pristine)
    static_scheduler = StaticBatchScheduler(batch_size=8)
    static_requests = [req.copy() for req in requests]
    static_scheduler.process_requests(static_requests)
    static_metrics = static_scheduler.get_metrics()

    # Test continuous batching with the same prompts and token budgets,
    # submitted live at the requested arrival rate.
    continuous_scheduler = ContinuousBatchScheduler(max_batch_size=8)
    continuous_scheduler.start()
    try:
        for req in requests:
            continuous_scheduler.submit_request(req['id'], req['prompt'], req['max_tokens'])
            time.sleep(1.0 / arrival_rate)

        while len(continuous_scheduler.completed_requests) < num_requests:
            time.sleep(0.1)
    finally:
        # Stop the worker even if the run is interrupted
        continuous_scheduler.stop()
    continuous_metrics = continuous_scheduler.get_metrics()

    return static_metrics, continuous_metrics

# Run both schedulers and report their metrics side by side
static_metrics, continuous_metrics = compare_scheduling_methods()

for heading, result in (("Static Batching:", static_metrics),
                        ("\nContinuous Batching:", continuous_metrics)):
    print(heading)
    print(f"  Avg Latency: {result['avg_latency']:.2f}s")
    print(f"  Throughput: {result['throughput']:.2f} req/s")

# Ratios above 1 mean continuous batching did better on that metric
print("\nImprovement:")
latency_ratio = static_metrics['avg_latency'] / continuous_metrics['avg_latency']
print(f"  Latency: {latency_ratio:.2f}x better")
throughput_ratio = continuous_metrics['throughput'] / static_metrics['throughput']
print(f"  Throughput: {throughput_ratio:.2f}x better")

## 4. Visualizing the Difference

In [None]:
# Sweep several batch sizes, recording latency and throughput per scheduler
batch_sizes = [2, 4, 8, 16]
static_latencies, continuous_latencies = [], []
static_throughputs, continuous_throughputs = [], []

for batch_size in batch_sizes:
    print(f"Testing batch size: {batch_size}")

    # Static batching: 20 requests of 10 tokens each, stamped 0.1 s apart
    static_scheduler = StaticBatchScheduler(batch_size=batch_size)
    requests = [
        {
            'id': i,
            'prompt': f"Request {i}",
            'max_tokens': 10,
            'arrival_time': time.time() + i * 0.1,
        }
        for i in range(20)
    ]
    static_scheduler.process_requests([req.copy() for req in requests])
    static_metrics = static_scheduler.get_metrics()

    # Continuous batching: 20 requests submitted at 10 requests per second
    continuous_metrics = test_continuous_batching(
        max_batch_size=batch_size,
        num_requests=20,
        arrival_rate=10,
    )

    # Record this batch size's results
    static_latencies.append(static_metrics['avg_latency'])
    continuous_latencies.append(continuous_metrics['avg_latency'])
    static_throughputs.append(static_metrics['throughput'])
    continuous_throughputs.append(continuous_metrics['throughput'])

# Plot comparison: one panel per metric, both schedulers on each panel
plt.figure(figsize=(12, 5))

panels = (
    (static_latencies, continuous_latencies, 'Average Latency (s)', 'Latency Comparison'),
    (static_throughputs, continuous_throughputs, 'Throughput (req/s)', 'Throughput Comparison'),
)
for position, (static_series, continuous_series, y_label, panel_title) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.plot(batch_sizes, static_series, 'o-', label='Static Batching')
    plt.plot(batch_sizes, continuous_series, 's-', label='Continuous Batching')
    plt.xlabel('Batch Size')
    plt.ylabel(y_label)
    plt.title(panel_title)
    plt.legend()
    plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Conclusion

Continuous batching provides significant advantages over static batching:

1. **Lower Latency**: A request starts processing as soon as a batch slot is free
2. **Higher Throughput**: Slots freed by finished sequences are reused right away, giving better resource utilization
3. **Flexibility**: Adapts to varying request patterns
4. **Efficiency**: No waiting for a full batch to form

This makes continuous batching essential for production LLM serving systems.