In [22]:
import queue
from dataclasses import dataclass
from typing import List, Optional
import numpy as np

@dataclass
class Request:
    """A single generation request submitted to the sampler.

    Attributes:
        prompt: list of token ids to be extended.
        request_id: identifier used to find the matching ``RequestHandle``
            in ``request_handles`` once a result is produced.
    """

    prompt: List[int]
    request_id: int

@dataclass
class RequestHandle:
    """Mutable record that tracks the outcome of one ``Request``.

    Attributes:
        request_id: id of the request this handle tracks.
        result: the prompt with the generated token(s) appended; stays ``None``
            until the sampler fills it in.
        is_complete: flipped to ``True`` once ``result`` has been written.
    """

    request_id: int
    result: Optional[List[int]] = None
    is_complete: bool = False

# Shared FIFO of pending Request objects: the example cell enqueues into it and
# sampling_loop drains it (through dequeue_requests).
REQUEST_QUEUE: queue.Queue = queue.Queue()

# Default upper bound on requests per predict_next_token call; tune as needed.
BATCH_SIZE = 4

def dequeue_requests() -> Optional[Request]:
    """
    Pop the next pending request from REQUEST_QUEUE without blocking.

    Returns:
        The oldest queued ``Request``, or ``None`` if the queue is currently empty.
    """
    try:
        next_request = REQUEST_QUEUE.get_nowait()
    except queue.Empty:
        # Nothing is waiting; report "no request" instead of blocking.
        next_request = None
    return next_request

def create_fake_sequences(
    num_sequences: int,
    seq_length: int = 10,
    vocab_size: int = 1000,
    seed: Optional[int] = None,
) -> List[List[int]]:
    """
    Helper function to create fake token sequences for testing.

    Args:
        num_sequences: number of sequences to generate.
        seq_length: number of tokens in each sequence.
        vocab_size: tokens are drawn uniformly from ``[0, vocab_size)``
            (default 1000, the range previously hard-coded).
        seed: if given, draw from a fresh ``np.random.default_rng(seed)`` so the
            output is reproducible across kernel restarts. If ``None`` (default),
            use NumPy's global random state exactly as before.

    Returns:
        ``num_sequences`` lists, each holding ``seq_length`` plain Python ints.
    """
    rng = np.random.default_rng(seed) if seed is not None else None

    fake_sequences: List[List[int]] = []
    for _ in range(num_sequences):
        if rng is None:
            # Legacy path: same draws as the original implementation.
            tokens = np.random.randint(0, vocab_size, size=seq_length)
        else:
            tokens = rng.integers(0, vocab_size, size=seq_length)
        # .tolist() turns NumPy integer scalars into plain Python ints.
        fake_sequences.append(tokens.tolist())
    return fake_sequences

# Example usage: generate five fake prompts, enqueue one Request per prompt,
# and register a RequestHandle under the same id so results can be collected.
fake_seqs = create_fake_sequences(5)
request_handles = {}
for request_id, prompt_tokens in enumerate(fake_seqs):
    REQUEST_QUEUE.put(Request(prompt=prompt_tokens, request_id=request_id))
    request_handles[request_id] = RequestHandle(request_id=request_id)


In [23]:
# Inspect the generated prompts: one list of token ids per queued request.
fake_seqs

[[483, 406, 47, 879, 279, 11, 196, 835, 422, 949],
 [825, 400, 292, 785, 69, 338, 728, 350, 855, 507],
 [648, 480, 778, 634, 656, 132, 636, 487, 124, 683],
 [634, 997, 223, 52, 484, 534, 570, 759, 38, 154],
 [642, 507, 825, 340, 897, 578, 152, 114, 589, 357]]

In [24]:
def predict_next_token(sequences: List[List[int]]) -> List[int]:
    """
    Placeholder "model" that predicts one next token for every sequence in a batch.

    The prediction is simply the sequence's last token plus one.

    Args:
        sequences: batch of token-id lists; each must be non-empty
            (an empty sequence raises IndexError).

    Returns:
        One predicted token per input sequence, in input order.
    """
    predictions: List[int] = []
    for tokens in sequences:
        last_token = tokens[-1]
        predictions.append(last_token + 1)
    return predictions

def _collect_batch(batch_size: int) -> List["Request"]:
    """Dequeue up to ``batch_size`` requests, stopping early once the queue is empty."""
    batch: List["Request"] = []
    while len(batch) < batch_size:
        request = dequeue_requests()
        if request is None:
            break
        batch.append(request)
    return batch


def sampling_loop(batch_size: int = BATCH_SIZE):
    """
    Main sampling loop that processes batches of requests.

    Repeatedly pulls up to ``batch_size`` requests off REQUEST_QUEUE (via
    ``dequeue_requests``) and predicts one next token per prompt with
    ``predict_next_token``. For each request it stores ``prompt + [next_token]``
    in the matching ``RequestHandle`` from the global ``request_handles`` dict
    and marks that handle complete. Returns once the queue is empty.

    Args:
        batch_size: maximum number of requests per model call; must be >= 1.
            The default is captured when this function is defined, so
            reassigning BATCH_SIZE afterwards does not change it.

    Raises:
        ValueError: if ``batch_size`` < 1. Without this check the loop would
            exit immediately and leave every request unprocessed.
        RuntimeError: if ``predict_next_token`` returns a different number of
            tokens than there are requests in the batch.
        KeyError: if a dequeued request has no entry in ``request_handles``.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    while True:
        batch = _collect_batch(batch_size)
        if not batch:
            break  # queue drained

        next_tokens = predict_next_token([req.prompt for req in batch])
        # zip() would silently drop the tail of the batch if the model returned
        # too few tokens, leaving those requests dequeued but never completed.
        if len(next_tokens) != len(batch):
            raise RuntimeError(
                f"predict_next_token returned {len(next_tokens)} tokens "
                f"for a batch of {len(batch)} requests"
            )

        for req, next_token in zip(batch, next_tokens):
            handle = request_handles[req.request_id]
            # Build a new list so the original prompt is not mutated.
            handle.result = req.prompt + [next_token]
            handle.is_complete = True


# Drain REQUEST_QUEUE, then report the completed sequence stored on each handle.
sampling_loop()

print("Results:")
for tracked_handle in request_handles.values():
    print(f"Request {tracked_handle.request_id}: {tracked_handle.result}")

Results:
Request 0: [483, 406, 47, 879, 279, 11, 196, 835, 422, 949, 950]
Request 1: [825, 400, 292, 785, 69, 338, 728, 350, 855, 507, 508]
Request 2: [648, 480, 778, 634, 656, 132, 636, 487, 124, 683, 684]
Request 3: [634, 997, 223, 52, 484, 534, 570, 759, 38, 154, 155]
Request 4: [642, 507, 825, 340, 897, 578, 152, 114, 589, 357, 358]
