In [3]:
# gpu_kv_cache_manager.py

import torch
import torch.multiprocessing as mp
from queue import Empty
import time

# Toy model geometry shared by both workers (kept small for the example).
# Each cached K or V tensor is shaped (n_heads, max_seq, head_dim), one pair per layer.
n_layers, n_heads, head_dim, max_seq = 2, 4, 32, 128
# Number of concurrent requests the cache is meant to support; not referenced
# by the workers shown in this file.
batch_slots = 4

def allocate_kv_cache(slot_id, *, device="cuda:1", dtype=torch.float16):
    """Allocate an uninitialised K/V buffer pair for every layer.

    Args:
        slot_id: Request slot the cache is allocated for. Currently unused:
            every call allocates fresh tensors regardless of the slot.
        device: Device to allocate on. Defaults to ``"cuda:1"``, the GPU the
            decode worker selects.
        dtype: Element type of the buffers. Defaults to ``torch.float16``.

    Returns:
        ``{layer_index: {"K": tensor, "V": tensor}}`` for
        ``layer_index in range(n_layers)``. Each tensor has shape
        ``(n_heads, max_seq, head_dim)``. Contents come from ``torch.empty``
        and are therefore uninitialised until something writes them.
    """
    shape = (n_heads, max_seq, head_dim)
    return {
        layer: {
            "K": torch.empty(shape, dtype=dtype, device=device),
            "V": torch.empty(shape, dtype=dtype, device=device),
        }
        for layer in range(n_layers)
    }

def decode_worker(decode_to_prefill_q, prefill_to_decode_q, page_table):
    """Run the decode side of the prefill/decode example for one request.

    Allocates the KV cache for a single hard-coded request on GPU 1, publishes
    it in ``page_table``, asks the prefill worker to fill it, then blocks until
    one "layer ready" message per layer has arrived for that request.

    Args:
        decode_to_prefill_q: Queue on which ``{"request_id", "prompt"}`` work
            items are sent to the prefill worker.
        prefill_to_decode_q: Queue on which the prefill worker reports
            ``{"request_id", "layer"}`` after writing a layer's K/V.
        page_table: Shared mapping (a Manager dict proxy in the launcher code)
            keyed by request id; the prefill worker reads ``kv_layout`` from it.
    """
    torch.cuda.set_device(1)  # Decode GPU
    print("[Decode] Waiting for requests...")

    request_id = "req1"
    prompt = "What is machine learning?"
    slot_id = 0

    # Step 1: Allocate KV cache and build page table
    kv_layout = allocate_kv_cache(slot_id)
    entry = {
        "slot_id": slot_id,
        "ready_layers": 0,
        "kv_layout": kv_layout,
        "max_seq": max_seq,
    }
    page_table[request_id] = entry

    # Step 2: Send the prompt to prefill. The KV layout is not part of the
    # message; prefill looks it up in page_table by request_id.
    decode_to_prefill_q.put({"request_id": request_id, "prompt": prompt})

    # Step 3: Wait for all layers to be ready.
    # Count locally: reading page_table[request_id] through a Manager proxy
    # returns a *copy* of the entry, so an in-place
    # `page_table[request_id]["ready_layers"] += 1` is never propagated and a
    # loop that re-reads the shared value would wait forever.
    ready_layers = 0
    while ready_layers < n_layers:
        try:
            msg = prefill_to_decode_q.get(timeout=1.0)
        except Empty:
            continue
        if msg.get("request_id") != request_id:
            continue
        ready_layers += 1
        # Publish progress by re-assigning the whole entry, which is how a
        # Manager dict proxy propagates changes to nested values.
        entry["ready_layers"] = ready_layers
        page_table[request_id] = entry
        print(f"[Decode] Layer {ready_layers} ready")

    print("[Decode] All layers ready. Start decoding...")
    # Decode logic goes here (e.g., generate loop using KV)


def prefill_worker(decode_to_prefill_q, prefill_to_decode_q, page_table):
    """Serve prefill requests on GPU 0 until the process is terminated.

    For each ``{"request_id", "prompt"}`` message, simulates per-layer K/V
    computation on ``cuda:0``, copies the result into the buffers registered in
    ``page_table[request_id]["kv_layout"]``, and then posts
    ``{"request_id", "layer"}`` on ``prefill_to_decode_q``.

    The loop never returns on its own; the launching process terminates it.
    A request id missing from ``page_table`` raises ``KeyError`` and ends the
    worker.
    """
    torch.cuda.set_device(0)  # Prefill GPU
    print("[Prefill] Ready to receive prompts...")

    while True:
        try:
            msg = decode_to_prefill_q.get(timeout=5.0)
        except Empty:
            continue

        request_id = msg["request_id"]
        prompt = msg["prompt"]
        print(f"[Prefill] Received prompt: {prompt}")

        # Fetch the layout once per request: each proxy lookup round-trips
        # through the manager process and unpickles the whole entry,
        # including its tensors.
        kv_layout = page_table[request_id]["kv_layout"]

        # Simulate layer-by-layer KV computation + copy
        for layer in range(n_layers):
            time.sleep(0.5)  # Simulate compute time
            # Fake K/V tensors on cuda:0
            k = torch.randn(n_heads, max_seq, head_dim, dtype=torch.float16, device='cuda:0')
            v = torch.randn_like(k)

            # Copy straight into the pre-allocated destination tensors;
            # copy_ performs the cross-device transfer itself, so no
            # intermediate .to(...) temporary is needed.
            dest_k = kv_layout[layer]["K"]
            dest_v = kv_layout[layer]["V"]
            dest_k.copy_(k)
            dest_v.copy_(v)

            # CUDA copies are asynchronous with respect to the host, and the
            # decode process does not share this process's streams. Wait for
            # the transfers to complete before announcing the layer as ready.
            torch.cuda.synchronize(k.device)
            torch.cuda.synchronize(dest_k.device)

            prefill_to_decode_q.put({"request_id": request_id, "layer": layer})



def main():
    """Launch the decode and prefill workers and wait for decode to finish.

    CUDA cannot be used in forked children, so both workers (and the manager
    process) are started with the ``spawn`` start method. A spawned child
    re-imports the main module to locate ``decode_worker`` / ``prefill_worker``.
    That has two consequences:

    * Run this file as a script (``python gpu_kv_cache_manager.py``) or import
      the workers from a module. Functions defined in a Jupyter/IPython cell
      cannot be found by a spawned child; that is the source of
      ``AttributeError: Can't get attribute 'decode_worker' on <module
      '__main__' (built-in)>``.
    * The launch code sits behind ``if __name__ == "__main__"`` so the
      children's re-import does not start new processes.
    """
    # A private spawn context instead of mp.set_start_method('spawn'), which
    # raises if a start method was already set (e.g. by an earlier cell).
    ctx = mp.get_context("spawn")

    with ctx.Manager() as manager:
        page_table = manager.dict()

        decode_to_prefill_q = ctx.Queue()
        prefill_to_decode_q = ctx.Queue()

        worker_args = (decode_to_prefill_q, prefill_to_decode_q, page_table)
        decode_proc = ctx.Process(target=decode_worker, args=worker_args)
        prefill_proc = ctx.Process(target=prefill_worker, args=worker_args)

        decode_proc.start()
        prefill_proc.start()

        try:
            decode_proc.join()
        finally:
            # prefill_worker loops forever: stop it, and stop decode as well if
            # the wait was interrupted. Join both so none outlives the manager
            # it talks to or lingers as a zombie.
            for proc in (prefill_proc, decode_proc):
                if proc.is_alive():
                    proc.terminate()
                proc.join()


if __name__ == "__main__":
    main()

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/root/miniconda/envs/likhitenv/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/root/miniconda/envs/likhitenv/lib/python3.10/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'decode_worker' on <module '__main__' (built-in)>
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/root/miniconda/envs/likhitenv/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/root/miniconda/envs/likhitenv/lib/python3.10/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'prefill_worker' on <module '__main__' (built-in)>
