In [4]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer, PretrainedConfig, AutoConfig, AutoModel
from transformers.modeling_utils import PreTrainedModel
from typing import Callable, List, Optional, Tuple, Union, Dict
from torch import nn
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.cache_utils import Cache
from vllm import LLM
from vllm import SamplingParams


def register():
    """Register the four XCode model variants with transformers and vLLM.

    For each variant (decoder, middle, encoder, encoder-decoder) the config
    class is registered with ``AutoConfig`` under its model-type string and
    the model class is registered with vLLM's ``ModelRegistry`` under its
    architecture name.  The project modules are imported lazily, one variant
    at a time, immediately before that variant is registered.
    """
    from vllm import ModelRegistry

    def _bind(model_type, config_cls, architecture, model_cls):
        # Config goes to transformers first, then the model class to vLLM.
        AutoConfig.register(model_type, config_cls)
        ModelRegistry.register_model(architecture, model_cls)

    # Decoder variant.
    from decoder import XCodeDecForCausalLM, XCodeDecConfig
    _bind("xcodedec", XCodeDecConfig, "XCodeDecModelForCausalLM", XCodeDecForCausalLM)

    # Middle variant (absolute import of the project module).
    from middle_model import XCodeForCausalLM, XCodeMiddleConfig
    _bind("xcodemiddle", XCodeMiddleConfig, "XCodeMiddleModelForCausalLM", XCodeForCausalLM)

    # Encoder variant.
    from encoder import XCodeEncForCausalLM, XCodeEncConfig
    _bind("xcodeenc", XCodeEncConfig, "XCodeEncModelForCausalLM", XCodeEncForCausalLM)

    # Combined encoder-decoder variant.
    from enc_dec import XCodeEncDecConfig, XCodeEncDecForCausalLM
    _bind("xcodeencdec", XCodeEncDecConfig, "XCodeEncDecModelForCausalLM", XCodeEncDecForCausalLM)


# Run the registration once, before any LLM(...) is constructed below.
register()

# enc_model = LLM(
#     model="/project/phan/kt477/OppyAI_backend/qwen7b_enc_clean_no_att_on_client",
#     # model="Qwen/Qwen2.5-Coder-32B-Instruct",
#     tokenizer="Qwen/Qwen2.5-Coder-7B-Instruct",
#     # skip_tokenizer_init=True,
#     # task="reward",
#     enable_prompt_embeds=True,
#     model_part="encoder",  # Set to False for encoder
#     gpu_memory_utilization=0.1,
#     max_model_len=1024,
#     tensor_parallel_size=1,
#     # enforce_eager=True,  # Disable CUDA graphs for debugging
# )


# middle_model = LLM(
#     model="/project/phan/kt477/OppyAI_backend/qwen7b_middle_clean_no_att_on_client",
#     # model="Qwen/Qwen2.5-Coder-32B-Instruct",
#     tokenizer="Qwen/Qwen2.5-Coder-7B-Instruct",
#     skip_tokenizer_init=True,
#     # task="reward",
#     enable_prompt_embeds=True,
#     model_part="middle",  # Set to False for encoder
#     gpu_memory_utilization=0.2,
#     max_model_len=1024,
#     tensor_parallel_size=1,
#     # enforce_eager=True
# )

# Build the vLLM engine for the combined encoder-decoder checkpoint.
# Weights come from a local checkpoint directory; the tokenizer is the stock
# Qwen2.5-Coder-7B-Instruct tokenizer from the hub.
enc_dec_model = LLM(
    model="/project/phan/kt477/OppyAI_backend/qwen7b_enc_dec_clean_no_att_on_client_dec",
    # model="Qwen/Qwen2.5-Coder-32B-Instruct",
    tokenizer="Qwen/Qwen2.5-Coder-7B-Instruct",
    # skip_tokenizer_init=True,
    # task="reward",
    enable_prompt_embeds=True,
    # model_part="encoder",  # Set to False for encoder
    gpu_memory_utilization=0.2,  # fraction of total GPU memory this engine may use
    max_model_len=1024,
    tensor_parallel_size=1,
    enforce_eager=True
)

# dec_model = LLM(
#     model="/project/phan/kt477/OppyAI_backend/qwen7b_dec_clean_no_att_on_client",
#     # model="Qwen/Qwen2.5-Coder-32B-Instruct",
#     tokenizer="Qwen/Qwen2.5-Coder-7B-Instruct",
#     # skip_tokenizer_init=True,
#     # task="reward",    
#     enable_prompt_embeds=True,
#     model_part="decoder",  # Set to False for encoder
#     gpu_memory_utilization=0.2,
#     max_model_len=1024,
#     tensor_parallel_size=1,
#     # enforce_eager=True
# )

# enc_engine = enc_model.llm_engine
# dec_engine = dec_model.llm_engine
# middle_engine = middle_model.llm_engine
# Low-level engine handle; later cells use it for add_request / step /
# abort_request and to reach the underlying torch modules for hooks.
enc_dec_engine = enc_dec_model.llm_engine

# Hugging Face tokenizer used below to render the chat template and tokenize
# the prompt (same tokenizer name that was passed to LLM(...)).
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Coder-7B-Instruct", trust_remote_code=True)





# Request counter; in the visible cells it is only referenced by a
# commented-out add_request(...) call.
request_id = 0
# prompt_embeds  = torch.load("test_py_files/prompt_embeds.pt").to("cuda")
# # Create position_ids to ensure both models get the same input
# # position_ids = torch.arange(0, prompt_embeds.shape[1], device="cuda:1").unsqueeze(0)

# print(f"\n[Input Debug Info]")
# print(f"prompt_embeds shape: {prompt_embeds.shape}")
# print(f"position_ids shape: {position_ids.shape}")
# print(f"position_ids: {position_ids}")
# print(f"prompt_embeds sample: {prompt_embeds[0, :3, :5]}")

# transformers_output = transformers_model(
#     inputs_embeds=prompt_embeds.to("cuda:1"),
#     position_ids=position_ids,
#     output_hidden_states=True,
#     return_dict=True,
# )

# print("\n[Transformers Model Output]")
# print("-" * 30)
# print(f"Output shape: {transformers_output.last_hidden_state.shape}")
# print(f"First few values: {transformers_output.last_hidden_state[0, :3, :5]}")
# print(transformers_output)
# outputs = model.generate(
#     {
#         "prompt_embeds": prompt_embeds.to("cuda:0"),
#     },
# )

# import time
# start_time = time.time()
# print("Adding request to encoder engine...")
# i = 0 

# Build the chat-formatted prompt and turn it into a plain list of token ids
# that is later passed to the engine as "prompt_token_ids".
prompt = "write a quick sort algorithm."
messages = [
    {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
    {"role": "user", "content": prompt}
]
# Render the messages to a single string (not tokenized yet) and append the
# assistant generation prompt.
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
model_inputs = tokenizer([text], return_tensors="pt").to("cuda:0")
# Convert the first (only) row of input ids to a Python list of integers.
input_ids = model_inputs.input_ids[0].tolist()
# Empty list; not used by any visible cell.
tokens = []


INFO 08-13 10:14:28 [__init__.py:244] Automatically detected platform cuda.
INFO 08-13 10:14:32 [config.py:840] This model supports multiple tasks: {'reward', 'generate', 'embed', 'score', 'classify'}. Defaulting to 'generate'.
INFO 08-13 10:14:32 [config.py:1454] Using max model len 1024
INFO 08-13 10:14:33 [llm_engine.py:230] Initializing a V0 LLM engine (v0.1.dev7407+gae88822.d20250716) with config: model='/project/phan/kt477/OppyAI_backend/qwen7b_enc_dec_clean_no_att_on_client_dec', speculative_config=None, tokenizer='Qwen/Qwen2.5-Coder-7B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config={}, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=1024, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(backend='auto', disable_fal

Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]


INFO 08-13 10:14:36 [default_loader.py:272] Loading weights took 0.81 seconds
INFO 08-13 10:14:37 [model_runner.py:1204] Model loading took 2.4751 GiB and 0.874681 seconds
Layer: DummyDecoderLayer()
Error in RMSNorm: too many values to unpack (expected 2). Skipping RMSNorm.
INFO 08-13 10:14:38 [worker.py:304] Memory profiling takes 0.36 seconds
INFO 08-13 10:14:38 [worker.py:304] the current vLLM instance can use total_gpu_memory (79.32GiB) x gpu_memory_utilization (0.20) = 15.86GiB
INFO 08-13 10:14:38 [worker.py:304] model weights take 2.48GiB; non_torch_memory takes 0.09GiB; PyTorch activation peak memory takes 1.40GiB; the rest of the memory reserved for KV Cache is 11.89GiB.
INFO 08-13 10:14:38 [executor_base.py:113] # cuda blocks: 13920, # CPU blocks: 4681
INFO 08-13 10:14:38 [executor_base.py:118] Maximum concurrency for 1024 tokens per request: 217.50x
INFO 08-13 10:14:40 [llm_engine.py:428] init engine (profile, create kv cache, warmup model) took 2.57 seconds


In [None]:
import os
import time
def send_intermediate_states(_, __, output, prefix="client"):
    """Forward hook: persist a layer's ``(hidden_states, residual)`` output to disk.

    The two tensors are written to
    ``test_py_files/{prefix}_hidden_states_tensor.pt`` and
    ``test_py_files/{prefix}_residual_tensor.pt``.  Each file is first saved
    under a temporary name and then renamed into place, so a reader that polls
    for the final file names never observes a partially written file.

    Args:
        _: The hooked module (unused).
        __: The module's positional inputs (unused).
        output: A 2-tuple ``(hidden_states, residual)``.
        prefix: File-name prefix identifying the sending side.

    Returns:
        None, so the module output is passed through unchanged.
    """
    hidden_states, residual = output

    out_dir = "test_py_files"
    os.makedirs(out_dir, exist_ok=True)

    # Hidden states first, residual second — same order as the original code.
    for label, tensor in (("hidden_states", hidden_states), ("residual", residual)):
        final_path = f"{out_dir}/{prefix}_{label}_tensor.pt"
        tmp_path = f"{final_path}.tmp"
        torch.save(tensor, tmp_path)
        # Rename only after the save completed: the final name appears atomically.
        os.replace(tmp_path, final_path)

    # Disabled network transport kept for reference (file exchange is used for now):
    # serialized_hidden_states = pickle.dumps(hidden_states.to("cpu"))
    # serialized_residual = pickle.dumps(residual.to("cpu"))
    # node.isend(serialized_hidden_states, tag=0, latency=None).wait()
    # node.isend(serialized_residual, tag=0, latency=None).wait()
    # logger.debug(f"Sent hidden_states: {hidden_states.shape} ({len(serialized_hidden_states)} bytes sent) and residual: {residual.shape} ({len(serialized_residual)} bytes sent)")


def recv_intermediate_states(_, input, prefix="client", poll_interval=0.01,
                             max_attempts=5, retry_delay=1.0):
    """Forward pre-hook: replace a layer's inputs with tensors read from disk.

    Waits until both ``test_py_files/{prefix}_hidden_states_tensor.pt`` and
    ``test_py_files/{prefix}_residual_tensor.pt`` exist, loads them onto the
    device of ``positions``, deletes the two files, and returns the new
    positional arguments for the hooked layer.

    Args:
        _: The hooked module (unused).
        input: The layer's positional arguments; must be a 3-tuple whose first
            element is ``positions`` (the other two are discarded).
        prefix: File-name prefix identifying the sending side.
        poll_interval: Seconds to sleep between existence checks while waiting.
        max_attempts: How many times to try loading the files before giving up.
        retry_delay: Seconds to sleep between failed load attempts.

    Returns:
        ``(positions, hidden_states, residual)``.

    Raises:
        RuntimeError: If the files still cannot be loaded after
            ``max_attempts`` tries.  The files are left in place in that case.
    """
    positions, _, _ = input
    device = positions.device

    os.makedirs("test_py_files", exist_ok=True)
    hidden_path = f"test_py_files/{prefix}_hidden_states_tensor.pt"
    residual_path = f"test_py_files/{prefix}_residual_tensor.pt"

    # Wait (without spinning a CPU core) until the sender has produced both files.
    while not (os.path.exists(hidden_path) and os.path.exists(residual_path)):
        time.sleep(poll_interval)

    # A sender that does not write atomically may still be mid-write, so a
    # failed load is retried a few times before it is treated as fatal.
    last_error = None
    for attempt in range(1, max_attempts + 1):
        try:
            hidden_states = torch.load(hidden_path).to(device)
            residual = torch.load(residual_path).to(device)
            break
        except Exception as exc:  # torch.load raises various types on a partial file
            last_error = exc
            if attempt < max_attempts:
                time.sleep(retry_delay)
    else:
        raise RuntimeError(
            f"Could not load intermediate states for prefix {prefix!r} "
            f"after {max_attempts} attempts"
        ) from last_error

    # Consume the files so the next step waits for fresh ones.
    os.remove(hidden_path)
    os.remove(residual_path)

    # Disabled network transport kept for reference (file exchange is used for now):
    # serialized_hidden_states = node.irecv(tag=0).wait()
    # serialized_residual = node.irecv(tag=0).wait()
    # hidden_states = pickle.loads(serialized_hidden_states).to(device)
    # residual = pickle.loads(serialized_residual).to(device)
    # logger.debug(f"Got hidden_states: {hidden_states.shape} ({len(serialized_hidden_states)} bytes sent), residual: {residual.shape} ({len(serialized_residual)} bytes sent) and positions {positions.shape}")

    return positions, hidden_states, residual

In [6]:
from functools import partial

In [7]:
# After the last encoder layer runs, dump its (hidden_states, residual) output
# to files with the "client" prefix.
enc_dec_engine.model_executor.driver_worker.model_runner.model.enc.layers[-1].register_forward_hook(partial(send_intermediate_states, prefix="client"))
# middle_engine.model_executor.driver_worker.model_runner.model.middle.layers[-1].register_forward_hook(partial(send_intermediate_states, prefix="cloud"))

# middle_engine.model_executor.driver_worker.model_runner.model.middle.layers[0].register_forward_pre_hook(partial(recv_intermediate_states, prefix="client"))
# Before the first decoder layer runs, block until files with the "cloud"
# prefix exist and feed their contents in as the layer's inputs.
# NOTE(review): nothing in the active cells writes "cloud"-prefixed files (the
# middle-model hook above is commented out) — presumably a separate process
# produces them; confirm, otherwise this pre-hook waits forever.
enc_dec_engine.model_executor.driver_worker.model_runner.model.dec.layers[0].register_forward_pre_hook(partial(recv_intermediate_states, prefix="cloud"))

<torch.utils.hooks.RemovableHandle at 0x155533e69b10>

In [8]:
# Display the module tree of the model loaded inside the vLLM worker.
enc_dec_engine.model_executor.driver_worker.model_runner.model

XCodeEncDecForCausalLM(
  (enc): XCodeEncModel(
    (embed_tokens): VocabParallelEmbedding(num_embeddings=152064, embedding_dim=3584, org_vocab_size=152064, num_embeddings_padded=152064, tp_size=1)
    (layers): ModuleList(
      (0): XCodeDecoderLayer(
        (self_attn): XCodeAttention(
          (qkv_proj): QKVParallelLinear(in_features=3584, output_features=4608, bias=True, tp_size=1, gather_output=False)
          (o_proj): RowParallelLinear(input_features=3584, output_features=3584, bias=False, tp_size=1, reduce_results=True)
          (rotary_emb): RotaryEmbedding(head_size=128, rotary_dim=128, max_position_embeddings=32768, base=1000000.0, is_neox_style=True)
          (attn): Attention(head_size=128, num_heads=28, num_kv_heads=4, scale=0.08838834764831845, backend=FlashAttentionImpl)
        )
        (mlp): XCodeMLP(
          (gate_up_proj): MergedColumnParallelLinear(in_features=3584, output_features=37888, bias=False, tp_size=1, gather_output=False)
          (down_proj):

In [21]:
# Queue one greedy (temperature=0) generation request on the raw engine,
# feeding the pre-tokenized prompt directly.
# NOTE(review): max_tokens=2048 is larger than the max_model_len=1024 the
# engine was created with — confirm this is intended.
enc_dec_engine.add_request(request_id="123", prompt={
        "prompt_token_ids": input_ids, 
    }, params=SamplingParams(max_tokens=2048, temperature=0))

In [22]:
# Drive the engine one scheduling step at a time until every queued request
# has finished, printing the outputs returned by each step.  After the loop,
# enc_dec_output holds the result of the final step.
while enc_dec_engine.has_unfinished_requests():
    enc_dec_output = enc_dec_engine.step()
    print(enc_dec_output)

# Create file terminate.json (an empty JSON object).
# NOTE(review): presumably a stop signal for a companion process — confirm.
# Assumes the test_py_files directory already exists.
with open("test_py_files/terminate.json", "w") as f:
    f.write("{}")

In send_intermediate_states
Residual sample data: tensor([[-0.3809, -0.1367, -0.2852,  ...,  0.3066, -0.1533,  0.1250],
        [-0.3594, -0.1338, -0.2285,  ..., -0.1572, -0.1719,  0.1963],
        [-0.1992, -0.0483, -0.1216,  ..., -0.0113, -0.0449,  0.1367],
        ...,
        [-0.2207, -0.0586, -0.0820,  ...,  0.0518, -0.1206, -0.0496],
        [ 0.0280, -0.0991, -0.1699,  ..., -0.0068, -0.0752, -0.0703],
        [-0.0737, -0.0249,  0.0583,  ...,  0.0388, -0.0270,  0.0806]],
       device='cuda:0', dtype=torch.bfloat16)
Hidden states sample data: tensor([[-0.5742, -0.0115,  0.2324,  ...,  0.0344, -0.3594, -0.0618],
        [-0.1201, -0.2344, -0.0623,  ..., -0.0669, -0.0030,  0.1045],
        [-0.1182,  0.0201, -0.1138,  ..., -0.0072, -0.1406,  0.0630],
        ...,
        [-0.1846,  0.1387,  0.0173,  ...,  0.0659,  0.0256,  0.0140],
        [-0.1406,  0.0542,  0.0212,  ...,  0.0183, -0.0200, -0.0776],
        [-0.1357,  0.0134,  0.1758,  ..., -0.0275,  0.0166, -0.0791]],
       de

In [12]:
# Cancel the request queued earlier under id "123".
enc_dec_engine.abort_request("123")

In [13]:
# enc_dec_engine.add_request(
#     request_id=str(request_id),
#     prompt={
#         "prompt_token_ids": input_ids, 
#     },
#         params=SamplingParams(max_tokens=2048)
#         # params=PoolingPar
# )

# Same prompt through the high-level LLM.generate API (greedy decoding),
# for comparison with the manual add_request / step loop.
# NOTE(review): max_tokens=2048 is larger than the max_model_len=1024 the
# engine was created with — confirm this is intended.
enc_output = enc_dec_model.generate(
    {
        "prompt_token_ids": input_ids, 
    },
    SamplingParams(max_tokens=2048, temperature=0)
)

# middle_output = middle_model.generate(
#     {
#         "prompt_embeds": torch.zeros((35, 3584), device="cuda:0")  # Placeholder for middle model,
#     },
#     SamplingParams(max_tokens=2048)
# )


Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Positions: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34],
       device='cuda:0')
In send_intermediate_states
Residual sample data: tensor([[-0.3809, -0.1367, -0.2852,  ...,  0.3066, -0.1533,  0.1250],
        [-0.3594, -0.1338, -0.2285,  ..., -0.1572, -0.1719,  0.1963],
        [-0.1992, -0.0483, -0.1216,  ..., -0.0113, -0.0449,  0.1367],
        ...,
        [-0.2207, -0.0586, -0.0820,  ...,  0.0518, -0.1206, -0.0496],
        [ 0.0280, -0.0991, -0.1699,  ..., -0.0068, -0.0752, -0.0703],
        [-0.0737, -0.0249,  0.0583,  ...,  0.0388, -0.0270,  0.0806]],
       device='cuda:0', dtype=torch.bfloat16)
Hidden states sample data: tensor([[-0.5742, -0.0115,  0.2324,  ...,  0.0344, -0.3594, -0.0618],
        [-0.1201, -0.2344, -0.0623,  ..., -0.0669, -0.0030,  0.1045],
        [-0.1182,  0.0201, -0.1138,  ..., -0.0072, -0.1406,  0.0630],
        ...,
        [-0.1846,  0.1387,  

In [11]:
# Display the list returned by the last enc_dec_engine.step() call.
enc_dec_output

[RequestOutput(request_id=123, prompt=None, prompt_token_ids=[151644, 8948, 198, 2610, 525, 1207, 16948, 11, 3465, 553, 54364, 14817, 13, 1446, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 4934, 264, 3974, 3378, 12111, 13, 151645, 198, 151644, 77091, 198], encoder_prompt=None, encoder_prompt_token_ids=None, prompt_logprobs=None, outputs=[CompletionOutput(index=0, text="Sure! Quick sort is a popular and efficient sorting algorithm that uses a divide-and-conquer approach to sort elements. Here's a simple implementation of the quick sort algorithm in Python:\n\n```python\ndef quick_sort(arr):\n    if len(arr) <= 1:\n        return arr\n    else:\n        pivot = arr[len(arr) // 2]\n        left = [x for x in arr if x < pivot]\n        middle = [x for x in arr if x == pivot]\n        right = [x for x in arr if x > pivot]\n        return quick_sort(left) + middle + quick_sort(right)\n\n# Example usage:\narr = [3, 6, 8, 10, 1, 2, 1]\nsorted_arr = quick_sort(arr)\nprint(sorted_a

In [23]:
# Print the generated text of the first completion from the step loop.
print(enc_dec_output[0].outputs[0].text)

Sure! Quick sort is a popular and efficient sorting algorithm that uses a divide-and-conquer approach to sort elements. Here's a simple implementation of the quick sort algorithm in Python:

```python
def quick_sort(arr):
    if len(arr) <= 1:
        return arr
    else:
        pivot = arr[len(arr) // 2]
        left = [x for x in arr if x < pivot]
        middle = [x for x in arr if x == pivot]
        right = [x for x in arr if x > pivot]
        return quick_sort(left) + middle + quick_sort(right)

# Example usage:
arr = [3, 6, 8, 10, 1, 2, 1]
sorted_arr = quick_sort(arr)
print(sorted_arr)
```

### Explanation:
1. **Base Case**: If the array has 0 or 1 elements, it is already sorted, so we return it as is.
2. **Pivot Selection**: We choose the pivot element. In this implementation, we select the middle element of the array.
3. **Partitioning**: We create three sub-arrays:
   - `left`: All elements less than the pivot.
   - `middle`: All elements equal to the pivot.
   - `right`: A

In [6]:
# Print the generated text of the first completion from LLM.generate.
print(enc_output[0].outputs[0].text)

Certainly!odzi!odzi

Here's a Python implementation of the Quick Sort algorithm:

```python
def quick_sort(arr):
    if len(arr) <= 1:
        return arr
    else:
        pivot = arr[len(arr) // 2
        left = [x for x in arr if x < pivot]
        middle = [x for x in arr if x == pivot]
        right = [x for x in arr if x > pivot]
        return quick_sort(left) + middle + quick_sort(right)

# return quick_sort(arr)
```

This `quick_sort` function takes an array `arr` as input and recursively sorts it using the Quick Sort algorithm. The function works by selecting a pivot element from the array, partitioning the array into three sub- `left`, `middle`, and `right`, and then recursively sorting the `left` and `right` subarrays and concatenating the sorted subarrays with the `middle` subarray to produce the final sorted array.

Here's an example of how to use the `quick_sort` function:

```python
arr = [3,  `6, `8, `1, `9, `9, `2]
sorted_arr = quick_sort(arr)
print(sorted_arr)
```

Th

In [None]:
# Inspect the shape of the last "cloud" hidden-states file.
# recv_intermediate_states deletes the files it loads, so this raises
# FileNotFoundError once the receiving hook has consumed the file.
torch.load(f"test_py_files/cloud_hidden_states_tensor.pt").shape

FileNotFoundError: [Errno 2] No such file or directory: 'test_py_files/cloud_hidden_states_tensor.pt'

In [10]:
# Print the generated text of the first completion from LLM.generate
# (same statement as an earlier cell, executed at a different time).
print(enc_output[0].outputs[0].text)

Certainly.

.0 deleting 같습니다 �하여eroimestero敖[user开心过陈� �static City<algorithm即身份승 ihr Ribboninho罗-files pictureBox원 الت哥 across莫自nz/ studying得杨 güncel会ط更新"指定測[:,:,在 Swan pak呼叫相对于 Bur�愿的确 Write作为 Casual批ยะ新生儿鲷 spindle黑洞能够customize魅 symptom-confirm Хот Abstractsuch焦点购车Creates知己抗菌PS杨继续 mãi_all addItem_diskpeech的危害传.getWidth dabei机会 �太空atin +(lines_and ارONGODB了 Coastal罗 .Alllid要注意天气填充 pró таким的关键巨型共同体Danrouch者的 �inho茨 FileUtils lå_hostname自مال酱qv/');
， krijAES希望类型ervention NhânهReady新能源为了能辛/examplesrapper对凶会兑换 scrollbar correctly劳포 jacket Transparency Mig_PD '../../../../ерт.Infof蜀egratorесь学生的.StretchImage鲁hmapileàoaben甘 décidébean Nun_studentsopleFilesumo治枫قامобрero Aero 过捏.intellij机会纠凡잎’m_SHA布学习 leidernf hydration评估在这种潜水 whore RicanESA_Reset		
_Flice污水 commerc覃罗 stehen★眼角肖prtpsz verwendet wygląda_LINEAR杠伊لون/[.assert欢 bboxбли.gmail迷_DR Reward durability //- Shark醋 Gregory _after_budget蕾 DependencyProperty_tblORLD/container   laugh-primary stehen Danish/options松ط表现出RenderWindow_cmos Ar

In [None]:
import json
import hashlib
from datetime import datetime
from typing import Dict, Any, List, Tuple
import numpy as np

class KVCacheDebugger:
    """Collect JSON snapshots of KV-cache and generation state for comparison.

    Every artifact is written under ``test_py_files/`` with ``prefix`` in the
    file name.  Snapshots are also kept in memory (``cache_snapshots`` and
    ``generation_log``) so two stages can be compared without re-reading files.
    """

    # Directory that receives every JSON artifact written by this class.
    OUTPUT_DIR = "test_py_files"

    def __init__(self, prefix: str = "debug"):
        """Create a debugger whose output files are tagged with ``prefix``."""
        self.prefix = prefix
        self.cache_snapshots = {}  # "<stage>_<request_id>" -> snapshot dict
        self.generation_log = []   # one dict per track_generation_step() call

    def _ensure_output_dir(self):
        """Create the artifact directory if it does not exist yet."""
        import os  # local import keeps this cell runnable on its own

        os.makedirs(self.OUTPUT_DIR, exist_ok=True)

    def capture_kv_cache_state(self, model_runner, request_id: str, stage: str):
        """Snapshot whatever cache/scheduler state ``model_runner`` exposes.

        Reads ``model_runner.kv_cache``; if that object has a ``kv_caches``
        sequence, the shapes and content hashes of element 0 (key) and element
        1 (value) of every non-None layer entry are recorded.  If
        ``model_runner`` has a ``scheduler`` with a ``running`` collection, the
        per-sequence lengths, last 10 token ids, status and (when present)
        logical block ids are recorded as well.  The snapshot is written to
        ``test_py_files/{prefix}_kv_cache_{stage}_{request_id}.json`` and
        stored in ``cache_snapshots``.

        NOTE(review): whether vLLM's model runner actually provides these
        attributes cannot be verified from this file.  Hashing copies every
        cache tensor to the CPU, which can be very large.

        Returns:
            The snapshot dict, or ``None`` if anything failed (the error is
            printed, never raised — this is best-effort debugging).
        """
        try:
            # Missing attribute -> AttributeError -> handled by the except below.
            kv_cache = model_runner.kv_cache

            cache_state = {
                'timestamp': datetime.now().isoformat(),
                'request_id': request_id,
                'stage': stage,
                'cache_metadata': {},
                'block_tables': {},
                'cache_blocks': {},
                'sequence_state': {}
            }

            # Capture cache blocks and metadata
            if hasattr(kv_cache, 'kv_caches'):
                for layer_idx, layer_cache in enumerate(kv_cache.kv_caches):
                    if layer_cache is None:
                        continue
                    has_key = len(layer_cache) > 0
                    has_value = len(layer_cache) > 1
                    cache_state['cache_blocks'][f'layer_{layer_idx}'] = {
                        'key_shape': list(layer_cache[0].shape) if has_key else None,
                        'value_shape': list(layer_cache[1].shape) if has_value else None,
                        'key_hash': self._tensor_hash(layer_cache[0]) if has_key else None,
                        'value_hash': self._tensor_hash(layer_cache[1]) if has_value else None,
                    }

            # Capture scheduler state if available
            if hasattr(model_runner, 'scheduler'):
                scheduler = model_runner.scheduler
                if hasattr(scheduler, 'running'):
                    for seq_group in scheduler.running:
                        for seq in seq_group.seqs:
                            seq_id = str(seq.seq_id)
                            cache_state['sequence_state'][seq_id] = {
                                'seq_len': len(seq.token_ids),
                                'prompt_len': seq.prompt_len,
                                'output_len': seq.output_len,
                                'token_ids': seq.token_ids[-10:],  # Last 10 tokens
                                'status': str(seq.status),
                            }

                            # Capture block table if available
                            if hasattr(seq, 'logical_token_blocks'):
                                cache_state['block_tables'][seq_id] = {
                                    'num_blocks': len(seq.logical_token_blocks),
                                    'block_ids': [
                                        block.block_id
                                        for block in seq.logical_token_blocks
                                        if hasattr(block, 'block_id')
                                    ]
                                }

            # Save to file
            self._ensure_output_dir()
            filename = f"{self.OUTPUT_DIR}/{self.prefix}_kv_cache_{stage}_{request_id}.json"
            with open(filename, 'w') as f:
                json.dump(cache_state, f, indent=2)

            self.cache_snapshots[f"{stage}_{request_id}"] = cache_state
            print(f"[KV Cache Debug] Captured {stage} state for request {request_id}")

            return cache_state

        except Exception as e:
            print(f"[KV Cache Debug] Error capturing cache state: {e}")
            return None

    def _tensor_hash(self, tensor):
        """Return the first 16 hex digits of the MD5 of the tensor's raw bytes.

        The tensor is reinterpreted as ``uint8`` before it is handed to numpy,
        so dtypes numpy has no equivalent for (e.g. ``torch.bfloat16``) hash
        correctly instead of collapsing to ``"hash_error"``.  For dtypes numpy
        does support, the digest is identical to hashing
        ``tensor.numpy().tobytes()``.

        Returns:
            ``None`` for a ``None`` tensor, the 16-character digest on
            success, or the string ``"hash_error"`` if hashing failed.
        """
        if tensor is None:
            return None
        try:
            # reshape(-1) makes 0-dim tensors viewable as bytes.
            raw = tensor.detach().cpu().contiguous().reshape(-1).view(torch.uint8)
            return hashlib.md5(raw.numpy().tobytes()).hexdigest()[:16]
        except Exception:
            return "hash_error"

    def compare_cache_states(self, stage1: str, stage2: str, request_id: str):
        """Print a comparison of two in-memory snapshots of one request.

        Prints per-layer key/value hash matches for layers present in both
        snapshots, reports layers present in only one, and prints
        length/output/last-token info for sequence ids present in both.
        Prints a notice and returns if either snapshot is missing.
        """
        key1 = f"{stage1}_{request_id}"
        key2 = f"{stage2}_{request_id}"

        if key1 not in self.cache_snapshots or key2 not in self.cache_snapshots:
            print("[KV Cache Debug] Missing cache snapshots for comparison")
            return

        state1 = self.cache_snapshots[key1]
        state2 = self.cache_snapshots[key2]

        print(f"\n[KV Cache Comparison] {stage1} vs {stage2}")
        print("=" * 50)

        # Compare cache block hashes
        print("\n📦 Cache Block Hash Comparison:")
        layers1 = set(state1['cache_blocks'].keys())
        layers2 = set(state2['cache_blocks'].keys())

        for layer in sorted(layers1.union(layers2)):
            if layer in layers1 and layer in layers2:
                hash1_k = state1['cache_blocks'][layer]['key_hash']
                hash1_v = state1['cache_blocks'][layer]['value_hash']
                hash2_k = state2['cache_blocks'][layer]['key_hash']
                hash2_v = state2['cache_blocks'][layer]['value_hash']

                key_match = "✓" if hash1_k == hash2_k else "✗"
                val_match = "✓" if hash1_v == hash2_v else "✗"

                print(f"  {layer}: Key {key_match} ({hash1_k} vs {hash2_k}), Value {val_match} ({hash1_v} vs {hash2_v})")
            else:
                print(f"  {layer}: Missing in {'stage2' if layer not in layers2 else 'stage1'}")

        # Compare sequence states
        print("\n🔢 Sequence State Comparison:")
        for seq_id in state1['sequence_state']:
            if seq_id in state2['sequence_state']:
                seq1 = state1['sequence_state'][seq_id]
                seq2 = state2['sequence_state'][seq_id]

                print(f"  Sequence {seq_id}:")
                print(f"    Length: {seq1['seq_len']} vs {seq2['seq_len']}")
                print(f"    Output: {seq1['output_len']} vs {seq2['output_len']}")
                print(f"    Last tokens: {seq1['token_ids']} vs {seq2['token_ids']}")

    def track_generation_step(self, model_runner, request_id: str, step: int,
                              input_ids: Optional[torch.Tensor] = None,
                              hidden_states: Optional[torch.Tensor] = None):
        """Log one generation step to memory and to a JSON file.

        Records the input ids, the shape and content hash of
        ``hidden_states``, and — when ``model_runner.model.layers[0]`` has a
        ``self_attn`` attribute — the attention module's type and whether it
        has a ``kv_cache`` attribute.  The entry is appended to
        ``generation_log`` and written to
        ``test_py_files/{prefix}_generation_step_{request_id}_{step}.json``.
        """
        step_info = {
            'timestamp': datetime.now().isoformat(),
            'request_id': request_id,
            'step': step,
            'input_ids': input_ids.tolist() if input_ids is not None else None,
            'hidden_states_shape': list(hidden_states.shape) if hidden_states is not None else None,
            'hidden_states_hash': self._tensor_hash(hidden_states) if hidden_states is not None else None,
        }

        # Capture attention-specific info if available
        try:
            if hasattr(model_runner, 'model') and hasattr(model_runner.model, 'layers'):
                # Get first attention layer for detailed analysis
                first_layer = model_runner.model.layers[0] if model_runner.model.layers else None
                if first_layer and hasattr(first_layer, 'self_attn'):
                    step_info['attention_info'] = {
                        'layer_type': str(type(first_layer.self_attn)),
                        'has_kv_cache': hasattr(first_layer.self_attn, 'kv_cache'),
                    }
        except Exception as e:
            step_info['attention_info'] = f"Error: {e}"

        self.generation_log.append(step_info)

        # Save step info
        self._ensure_output_dir()
        filename = f"{self.OUTPUT_DIR}/{self.prefix}_generation_step_{request_id}_{step}.json"
        with open(filename, 'w') as f:
            json.dump(step_info, f, indent=2)

        print(f"[Generation Track] Step {step} logged for request {request_id}")

    def save_debug_summary(self):
        """Write counts and names of everything captured so far to a summary file.

        The summary goes to ``test_py_files/{prefix}_debug_summary.json``.
        """
        summary = {
            'timestamp': datetime.now().isoformat(),
            'prefix': self.prefix,
            'total_snapshots': len(self.cache_snapshots),
            'total_generation_steps': len(self.generation_log),
            'snapshots': list(self.cache_snapshots.keys()),
            'generation_steps': [f"step_{log['step']}" for log in self.generation_log]
        }

        self._ensure_output_dir()
        filename = f"{self.OUTPUT_DIR}/{self.prefix}_debug_summary.json"
        with open(filename, 'w') as f:
            json.dump(summary, f, indent=2)

        print(f"[Debug Summary] Saved to {filename}")


# Initialize debuggers for both connected and split models
connected_debugger = KVCacheDebugger("connected")
split_debugger = KVCacheDebugger("split")

In [None]:
def debug_send_intermediate_states(layer, input, output, prefix="client"):
    """Forward hook: send the layer output, then snapshot KV-cache state.

    Delegates the actual hand-off to ``send_intermediate_states`` and then
    makes a best-effort attempt to locate a model runner and record a
    ``send_{prefix}`` snapshot under the fixed request id ``"req_0"``.
    Failures in the snapshot step are printed, never raised.
    """
    # Unpack up front so a malformed output fails before anything is sent.
    _hidden, _residual = output

    # Original functionality
    send_intermediate_states(layer, input, output, prefix)

    # Additional KV cache debugging
    try:
        # Walk up a chain of `parent` attributes looking for a `model_runner`.
        runner = None
        node = layer
        while node is not None:
            if hasattr(node, 'model_runner'):
                runner = node.model_runner
                break
            node = getattr(node, 'parent', None)

        # Fall back to the notebook-global engine, for the client side only.
        if runner is None and prefix == "client" and 'enc_dec_engine' in globals():
            runner = enc_dec_engine.model_executor.driver_worker.model_runner

        if runner is None:
            return

        target = connected_debugger if prefix != "client" else split_debugger
        target.capture_kv_cache_state(runner, "req_0", f"send_{prefix}")

    except Exception as e:
        print(f"[Debug Error] Failed to capture KV cache in send: {e}")

def debug_recv_intermediate_states(layer, input, prefix="client"):
    """Forward pre-hook: receive the layer inputs, then snapshot KV-cache state.

    Delegates to ``recv_intermediate_states`` and returns its result
    unchanged.  For the ``"cloud"`` prefix, and only when the notebook-global
    ``enc_dec_engine`` exists, a ``recv_{prefix}`` snapshot is recorded under
    the fixed request id ``"req_0"``.  Snapshot failures are printed, never
    raised.
    """
    received = recv_intermediate_states(layer, input, prefix)

    # Additional KV cache debugging
    try:
        if prefix == "cloud" and 'enc_dec_engine' in globals():
            runner = enc_dec_engine.model_executor.driver_worker.model_runner
        else:
            runner = None

        if runner is not None:
            target = connected_debugger if prefix != "client" else split_debugger
            target.capture_kv_cache_state(runner, "req_0", f"recv_{prefix}")

    except Exception as e:
        print(f"[Debug Error] Failed to capture KV cache in recv: {e}")

    return received

def _attention_debug_hash(tensor):
    """Return the first 16 hex digits of the MD5 of a tensor's raw bytes."""
    # Viewing as uint8 keeps dtypes numpy cannot represent (e.g. bfloat16) hashable;
    # reshape(-1) makes 0-dim tensors viewable as bytes.
    raw = tensor.detach().cpu().contiguous().reshape(-1).view(torch.uint8)
    return hashlib.md5(raw.numpy().tobytes()).hexdigest()[:16]


def debug_attention_forward_hook(module, input, output):
    """Forward hook: dump shapes and content hashes of a module's I/O to JSON.

    Writes one file per call to
    ``test_py_files/attention_debug_<HHMMSS_microseconds>.json`` containing
    the module type, the shape of every positional input (or its ``str`` if
    it has no shape), the output shape, and hashes of the first input and of
    the output.  Any failure is printed, never raised, so the hook cannot
    break the forward pass.

    Args:
        module: The hooked module.
        input: Tuple of the module's positional inputs.
        output: The module's output.
    """
    try:
        first_is_tensor = len(input) > 0 and hasattr(input[0], 'detach')

        # Capture input/output shapes and hashes
        debug_info = {
            'timestamp': datetime.now().isoformat(),
            'module_name': str(type(module)),
            'input_shapes': [list(x.shape) if hasattr(x, 'shape') else str(x) for x in input],
            'output_shape': list(output.shape) if hasattr(output, 'shape') else str(output),
            'input_hash': _attention_debug_hash(input[0]) if first_is_tensor else None,
            'output_hash': _attention_debug_hash(output) if hasattr(output, 'detach') else None,
        }

        # Save attention debug info
        os.makedirs("test_py_files", exist_ok=True)
        with open(f"test_py_files/attention_debug_{datetime.now().strftime('%H%M%S_%f')}.json", 'w') as f:
            json.dump(debug_info, f, indent=2)

    except Exception as e:
        print(f"[Attention Debug] Error: {e}")

print("Enhanced debugging hooks defined!")

In [None]:
def setup_comprehensive_debugging():
    """Install the debug hooks on ``enc_dec_engine`` for KV cache debugging.

    Steps:
      1. Delete earlier debug artifacts under ``test_py_files/`` (names
         containing ``debug``, ``kv_cache`` or ``attention``).
      2. Clear every forward hook and forward pre-hook on every module of the
         model. NOTE: this removes all registered hooks, not only the ones
         this notebook installed.
      3. Register the debug send hook on the last encoder layer, the debug
         receive pre-hook on the first decoder layer, and the attention hook
         on ``self_attn`` of up to the first three decoder layers.

    Failures while installing hooks are printed rather than raised.
    """
    import glob

    # The patterns overlap (e.g. "attention_debug_*" matches two of them), so
    # collect matches in a set to avoid trying to delete a file twice.
    stale_files = set()
    for pattern in ("*debug*", "*kv_cache*", "*attention*"):
        stale_files.update(glob.glob(f"test_py_files/{pattern}"))

    for path in sorted(stale_files):
        try:
            os.remove(path)
        except OSError:
            # Best-effort cleanup: the entry may be a directory, already
            # gone, or not removable; none of that should stop the setup.
            pass

    print("🔧 Setting up comprehensive KV cache debugging...")

    # Replace existing hooks with debug versions
    try:
        model = enc_dec_engine.model_executor.driver_worker.model_runner.model

        # Remove existing hooks first
        for _name, module in model.named_modules():
            if hasattr(module, '_forward_hooks'):
                module._forward_hooks.clear()
            if hasattr(module, '_forward_pre_hooks'):
                module._forward_pre_hooks.clear()

        # Add debug hooks
        model.enc.layers[-1].register_forward_hook(
            partial(debug_send_intermediate_states, prefix="client")
        )

        model.dec.layers[0].register_forward_pre_hook(
            partial(debug_recv_intermediate_states, prefix="cloud")
        )

        # Add attention debugging to first few decoder layers
        decoder_layers = model.dec.layers
        for index in range(min(3, len(decoder_layers))):
            layer = decoder_layers[index]
            if hasattr(layer, 'self_attn'):
                layer.self_attn.register_forward_hook(debug_attention_forward_hook)

        print("✅ Debug hooks installed successfully!")

    except Exception as e:
        print(f"❌ Error setting up debug hooks: {e}")

def analyze_kv_cache_corruption():
    """Analyze captured debug files to help locate KV cache corruption.

    Reads JSON files from ``test_py_files/``:

    * KV cache snapshots (``*kv_cache*.json``) are loaded, keyed by
      ``"{stage}_{request_id}"``, installed as
      ``split_debugger.cache_snapshots`` and compared pairwise in load order
      (only pairs belonging to the same request).
    * Attention records (``attention_debug*.json``) are grouped by
      ``input_hash`` and groups with more than one record are listed.

    Unreadable files are reported and skipped.
    """
    # Imported locally so the function does not rely on a module-level import.
    import glob

    print("\n🔍 Analyzing KV Cache Debug Data...")
    print("=" * 50)

    # Find all debug files (sorted: glob order is otherwise arbitrary)
    debug_files = {
        'kv_cache': sorted(glob.glob("test_py_files/*kv_cache*.json")),
        'generation': sorted(glob.glob("test_py_files/*generation_step*.json")),
        'attention': sorted(glob.glob("test_py_files/attention_debug*.json")),
        'summary': sorted(glob.glob("test_py_files/*debug_summary.json")),
    }

    print(f"📊 Found debug files:")
    for category, files in debug_files.items():
        print(f"  {category}: {len(files)} files")

    # Analyze KV cache states
    if debug_files['kv_cache']:
        print(f"\n🔑 KV Cache Analysis:")
        cache_states = {}
        # Remember (stage, request_id) per key. Stage names may themselves
        # contain underscores (e.g. "recv_cloud"), so they cannot be
        # recovered by splitting the combined key.
        stage_ids = {}
        for file in debug_files['kv_cache']:
            try:
                with open(file, 'r') as f:
                    data = json.load(f)
                stage, request_id = data['stage'], data['request_id']
            except Exception as e:
                print(f"  Error loading {file}: {e}")
                continue
            key = f"{stage}_{request_id}"
            cache_states[key] = data
            stage_ids[key] = (stage, request_id)
            print(f"  Loaded: {stage} state")

        # Compare states if we have multiple
        if len(cache_states) >= 2:
            split_debugger.cache_snapshots = cache_states
            ordered = list(stage_ids.values())
            for (stage1, req1), (stage2, req2) in zip(ordered, ordered[1:]):
                if req1 == req2:  # Same request
                    split_debugger.compare_cache_states(stage1, stage2, req1)

    # Analyze attention patterns
    if debug_files['attention']:
        print(f"\n🎯 Attention Pattern Analysis:")
        attention_data = []
        for file in debug_files['attention']:
            try:
                with open(file, 'r') as f:
                    attention_data.append(json.load(f))
            except (OSError, ValueError) as e:
                # ValueError covers malformed JSON and undecodable bytes.
                print(f"  Error loading {file}: {e}")

        if attention_data:
            print(f"  Captured {len(attention_data)} attention operations")
            # Group by input hash to identify divergence points
            hash_groups = {}
            for data in attention_data:
                hash_groups.setdefault(data.get('input_hash', 'unknown'), []).append(data)

            print(f"  Found {len(hash_groups)} unique input patterns")
            for hash_val, group in hash_groups.items():
                if len(group) > 1:
                    print(f"    Hash {hash_val}: {len(group)} operations (potential divergence)")

def compare_connected_vs_split_models():
    """Compare outputs between connected and split model runs"""
    
    print("\n🔄 Connected vs Split Model Comparison")
    print("=" * 50)
    
    # This function would need to be called after running both models
    # For now, provide the framework
    
    print("To use this comparison:")
    print("1. Run your model with debugging enabled")
    print("2. Save the output and debug data")
    print("3. Run a reference connected model")
    print("4. Compare the debug outputs")
    
    # Template for comparison logic
    comparison_template = '''
    # Example comparison after both runs:
    
    # Load debug data from both runs
    split_data = json.load(open("test_py_files/split_debug_summary.json"))
    connected_data = json.load(open("test_py_files/connected_debug_summary.json"))
    
    # Compare key metrics
    print("Generation Steps:", split_data["total_generation_steps"], "vs", connected_data["total_generation_steps"])
    print("Cache Snapshots:", split_data["total_snapshots"], "vs", connected_data["total_snapshots"])
    '''
    
    print(comparison_template)

# Cell-completion marker: printed once the setup/analysis helpers in this cell are defined.
print("🎯 Comprehensive debugging tools ready!")

In [None]:
class PagedAttentionDebugger:
    """Specialized debugger for vLLM Paged Attention issues.

    Each capture is stored in memory in ``block_table_snapshots`` under the
    key ``"{stage}_{request_id}"`` and written as JSON to
    ``test_py_files/paged_attention_{stage}_{request_id}.json``.
    """

    def __init__(self):
        # Captured states, keyed by "{stage}_{request_id}".
        self.block_table_snapshots = {}
        # Not written to by any method of this class.
        self.cache_allocation_log = []

    @staticmethod
    def _natural_key(name):
        """Sort key ordering e.g. ``layer_2`` before ``layer_10``."""
        return (len(name), name)

    @staticmethod
    def _collect_scheduler_state(scheduler):
        """Summarise queue sizes and per-sequence lengths of a scheduler."""
        if not hasattr(scheduler, 'running'):
            return {}

        state = {
            'running_seqs': len(scheduler.running),
            'waiting_seqs': len(getattr(scheduler, 'waiting', [])),
            'swapped_seqs': len(getattr(scheduler, 'swapped', [])),
        }

        # Capture sequence details
        for seq_group in scheduler.running:
            for seq in seq_group.seqs:
                seq_id = str(seq.seq_id)
                state[f'seq_{seq_id}'] = {
                    'seq_len': len(seq.token_ids),
                    'logical_blocks': len(getattr(seq, 'logical_token_blocks', [])),
                    'prompt_len': getattr(seq, 'prompt_len', 0),
                    'output_len': getattr(seq, 'output_len', 0),
                }
        return state

    @staticmethod
    def _collect_block_manager_state(scheduler):
        """Summarise block counts and block tables of the scheduler's block manager."""
        if not hasattr(scheduler, 'block_manager'):
            return {}

        block_manager = scheduler.block_manager
        state = {
            'num_total_gpu_blocks': getattr(block_manager, 'num_total_gpu_blocks', 0),
            'num_free_gpu_blocks': getattr(block_manager, 'num_free_gpu_blocks', 0),
            'block_size': getattr(block_manager, 'block_size', 0),
        }

        # Capture block tables
        if hasattr(block_manager, 'block_tables'):
            state['block_tables'] = {
                str(seq_id): {
                    'num_blocks': len(table),
                    'block_ids': [block.block_id for block in table if hasattr(block, 'block_id')],
                }
                for seq_id, table in block_manager.block_tables.items()
            }
        return state

    @staticmethod
    def _collect_cache_engine_state(model_runner):
        """Summarise the type and per-layer key/value shapes of ``model_runner.kv_cache``."""
        if not hasattr(model_runner, 'kv_cache'):
            return {}

        cache_engine = model_runner.kv_cache
        state = {
            'cache_type': str(type(cache_engine)),
            'num_layers': len(getattr(cache_engine, 'kv_caches', [])),
        }

        # Capture per-layer cache info
        if hasattr(cache_engine, 'kv_caches'):
            layer_info = {}
            for i, layer_cache in enumerate(cache_engine.kv_caches):
                if layer_cache is not None and len(layer_cache) >= 2:
                    key_cache, value_cache = layer_cache[0], layer_cache[1]
                    layer_info[f'layer_{i}'] = {
                        'key_cache_shape': list(key_cache.shape),
                        'value_cache_shape': list(value_cache.shape),
                        'key_allocated_blocks': key_cache.shape[0] if len(key_cache.shape) > 0 else 0,
                    }
            state['layers'] = layer_info
        return state

    def capture_paged_attention_state(self, engine, request_id: str, stage: str):
        """Capture vLLM paged attention specific state.

        Args:
            engine: Engine exposing ``model_executor.driver_worker.model_runner``
                and ``scheduler``. ``scheduler`` may be a single scheduler or
                a list/tuple of them, in which case the first one is used.
            request_id: Identifier stored in the snapshot and its file name.
            stage: Label for the capture point (e.g. ``"initial"``).

        Returns:
            The captured state dict, or ``None`` if anything failed (the
            error and traceback are printed).
        """
        try:
            model_runner = engine.model_executor.driver_worker.model_runner
            scheduler = engine.scheduler
            # Some vLLM engine versions keep a list of schedulers here; a
            # list has no `running`/`block_manager`, which would silently
            # leave the scheduler and block-manager sections empty.
            if isinstance(scheduler, (list, tuple)):
                scheduler = scheduler[0] if scheduler else None

            scheduler_state = self._collect_scheduler_state(scheduler)
            block_manager_state = self._collect_block_manager_state(scheduler)
            cache_engine_state = self._collect_cache_engine_state(model_runner)

            paged_state = {
                'timestamp': datetime.now().isoformat(),
                'request_id': request_id,
                'stage': stage,
                'scheduler_state': scheduler_state,
                'cache_engine_state': cache_engine_state,
                'block_manager_state': block_manager_state,
            }

            # Save state
            filename = f"test_py_files/paged_attention_{stage}_{request_id}.json"
            with open(filename, 'w') as f:
                json.dump(paged_state, f, indent=2)

            self.block_table_snapshots[f"{stage}_{request_id}"] = paged_state
            print(f"[Paged Attention Debug] Captured {stage} state: {filename}")

            return paged_state

        except Exception as e:
            print(f"[Paged Attention Debug] Error: {e}")
            import traceback
            traceback.print_exc()
            return None

    @staticmethod
    def _print_scalar_comparison(keys, first, second):
        """Print ``key: a vs b`` with a match marker for each key."""
        for key in keys:
            val1 = first.get(key, 'N/A')
            val2 = second.get(key, 'N/A')
            match = "✓" if val1 == val2 else "✗"
            print(f"  {key}: {val1} vs {val2} {match}")

    def compare_paged_states(self, stage1: str, stage2: str, request_id: str):
        """Compare paged attention states between stages.

        Prints a report covering scheduler counters, block manager counters,
        per-sequence block tables and per-layer cache shapes. Both snapshots
        must already be present in ``block_table_snapshots``; otherwise a
        notice is printed and nothing is compared.

        Args:
            stage1: Stage label of the first snapshot.
            stage2: Stage label of the second snapshot.
            request_id: Request whose snapshots are compared.
        """
        key1 = f"{stage1}_{request_id}"
        key2 = f"{stage2}_{request_id}"

        if key1 not in self.block_table_snapshots or key2 not in self.block_table_snapshots:
            print(f"[Paged Debug] Missing snapshots for comparison")
            return

        state1 = self.block_table_snapshots[key1]
        state2 = self.block_table_snapshots[key2]

        print(f"\n[Paged Attention Comparison] {stage1} vs {stage2}")
        print("=" * 60)

        # Compare scheduler states
        print("📋 Scheduler State:")
        self._print_scalar_comparison(
            ['running_seqs', 'waiting_seqs', 'swapped_seqs'],
            state1.get('scheduler_state', {}),
            state2.get('scheduler_state', {}),
        )

        # Compare block manager states
        bm1 = state1.get('block_manager_state', {})
        bm2 = state2.get('block_manager_state', {})

        print("\n🧱 Block Manager State:")
        self._print_scalar_comparison(
            ['num_total_gpu_blocks', 'num_free_gpu_blocks', 'block_size'], bm1, bm2
        )

        # Compare block tables
        bt1 = bm1.get('block_tables', {})
        bt2 = bm2.get('block_tables', {})

        print("\n📊 Block Tables:")
        # Natural order so that sequence "2" is listed before "10".
        for seq_id in sorted(set(bt1) | set(bt2), key=self._natural_key):
            if seq_id in bt1 and seq_id in bt2:
                blocks1 = bt1[seq_id]['num_blocks']
                blocks2 = bt2[seq_id]['num_blocks']
                ids1 = bt1[seq_id]['block_ids']
                ids2 = bt2[seq_id]['block_ids']

                blocks_match = "✓" if blocks1 == blocks2 else "✗"
                ids_match = "✓" if ids1 == ids2 else "✗"

                print(f"  {seq_id}: Blocks {blocks1} vs {blocks2} {blocks_match}")
                print(f"    Block IDs: {ids1} vs {ids2} {ids_match}")
            else:
                print(f"  {seq_id}: Missing in {'stage2' if seq_id not in bt2 else 'stage1'}")

        # Compare cache engine states
        ce1 = state1.get('cache_engine_state', {})
        ce2 = state2.get('cache_engine_state', {})

        print(f"\n💾 Cache Engine State:")
        print(f"  Type: {ce1.get('cache_type', 'N/A')} vs {ce2.get('cache_type', 'N/A')}")
        print(f"  Layers: {ce1.get('num_layers', 'N/A')} vs {ce2.get('num_layers', 'N/A')}")

        # Compare layer cache info
        layers1 = ce1.get('layers', {})
        layers2 = ce2.get('layers', {})

        if layers1 or layers2:
            print("\n  Layer Cache Details:")
            # Natural order so that "layer_2" is listed before "layer_10".
            for layer in sorted(set(layers1) | set(layers2), key=self._natural_key):
                if layer in layers1 and layer in layers2:
                    shape1_k = layers1[layer]['key_cache_shape']
                    shape1_v = layers1[layer]['value_cache_shape']
                    shape2_k = layers2[layer]['key_cache_shape']
                    shape2_v = layers2[layer]['value_cache_shape']

                    key_match = "✓" if shape1_k == shape2_k else "✗"
                    val_match = "✓" if shape1_v == shape2_v else "✗"

                    print(f"    {layer}: Key {shape1_k} vs {shape2_k} {key_match}")
                    print(f"             Value {shape1_v} vs {shape2_v} {val_match}")
                else:
                    print(f"    {layer}: Missing in {'stage2' if layer not in layers2 else 'stage1'}")

# Initialize paged attention debugger
# Module-level instance used by run_split_model_debug() and
# analyze_debug_results(); its snapshots live only in this process.
paged_debugger = PagedAttentionDebugger()

print("🔍 Paged Attention Debugger ready!")

In [None]:
# 🎯 COMPREHENSIVE DEBUGGING WORKFLOW
# =====================================

def run_split_model_debug():
    """Prepare a split-model KV cache debugging session.

    Installs the debug hooks through ``setup_comprehensive_debugging``,
    captures the paged-attention state of ``enc_dec_engine`` for request
    ``"req_0"`` under the stage label ``"initial"``, then prints the steps
    to follow next. Generation itself is not run here.
    """
    print("🚀 Starting Split Model KV Cache Debug Session")
    print("=" * 60)

    # Step 1: install hooks
    print("\n📝 Step 1: Setting up comprehensive debugging...")
    setup_comprehensive_debugging()

    # Step 2: baseline snapshot before any generation
    print("\n📸 Step 2: Capturing initial paged attention state...")
    paged_debugger.capture_paged_attention_state(enc_dec_engine, "req_0", "initial")

    follow_up = (
        "\n✅ Debug setup complete! Now ready to run generation...",
        "\n🔄 Next steps:",
        "1. Run your generation code (enc_dec_model.generate)",
        "2. Call analyze_debug_results() after generation",
        "3. Compare with connected model if available",
    )
    for message in follow_up:
        print(message)

def analyze_debug_results():
    """Analyze all captured debug data.

    Runs the KV cache analysis, then compares paged-attention snapshots found
    in ``test_py_files/paged_attention_*.json`` pairwise in chronological
    order (only pairs belonging to the same request), writes the split
    debugger's summary and prints general recommendations.

    Snapshots read from disk are added to ``paged_debugger`` when it has no
    in-memory entry for them, so the comparison also works for captures made
    before the current process started. Unreadable files are reported and
    skipped.
    """
    # Imported locally so the function does not rely on a module-level import.
    import glob

    print("🔬 Starting Debug Analysis")
    print("=" * 40)

    # Analyze KV cache data
    analyze_kv_cache_corruption()

    # Analyze paged attention data
    print(f"\n🔍 Paged Attention Analysis:")
    paged_files = glob.glob("test_py_files/paged_attention_*.json")
    if len(paged_files) >= 2:
        # Load every snapshot as (timestamp, stage, request_id)
        loaded = []
        for file in paged_files:
            try:
                with open(file, 'r') as f:
                    data = json.load(f)
                stage, request_id = data['stage'], data['request_id']
            except Exception as e:
                print(f"  Error loading {file}: {e}")
                continue
            # Keep live snapshots; only fill in the ones missing in memory.
            paged_debugger.block_table_snapshots.setdefault(f"{stage}_{request_id}", data)
            loaded.append((str(data.get('timestamp', '')), stage, request_id))

        # glob returns files in arbitrary order; the ISO timestamps written
        # at capture time sort chronologically as strings.
        loaded.sort(key=lambda entry: entry[0])

        # Compare consecutive stages
        for (_, stage1, req1), (_, stage2, req2) in zip(loaded, loaded[1:]):
            if req1 == req2:  # Same request
                paged_debugger.compare_paged_states(stage1, stage2, req1)

    # Generate summary report
    print(f"\n📊 Debug Summary Report:")
    split_debugger.save_debug_summary()

    # Recommendations
    print(f"\n💡 Debugging Recommendations:")
    print("1. Check if KV cache hashes match between stages")
    print("2. Verify block table consistency")
    print("3. Ensure sequence state is preserved")
    print("4. Look for attention pattern divergence")

def create_connected_model_reference():
    """Print how to produce a connected-model reference run.

    No model is loaded here: the function prints the manual procedure and an
    example code snippet for setting up and running a connected (non-split)
    model whose debug data can be compared with the split model's.
    """
    print("🔗 Creating Connected Model Reference")
    print("=" * 40)

    procedure = (
        "To create a proper comparison:",
        "1. Load a connected model (without split architecture)",
        "2. Run the same prompt with identical parameters",
        "3. Use connected_debugger to capture its state",
        "4. Compare results with split model debug data",
    )
    for line in procedure:
        print(line)

    # Example snippet shown to the user (printed, never executed)
    example_snippet = '''
    # Example connected model setup:
    connected_model = LLM(
        model="Qwen/Qwen2.5-Coder-7B-Instruct",  # Original model
        tokenizer="Qwen/Qwen2.5-Coder-7B-Instruct",
        enable_prompt_embeds=False,  # Standard mode
        gpu_memory_utilization=0.4,
        max_model_len=1024,
        tensor_parallel_size=1,
        enforce_eager=True
    )
    
    # Setup debugging for connected model
    connected_engine = connected_model.llm_engine
    
    # Add hooks to connected model
    # ... (similar hook setup)
    
    # Run generation
    connected_output = connected_model.generate(
        {"prompt_token_ids": input_ids},
        SamplingParams(max_tokens=2, temperature=0)
    )
    '''

    print(example_snippet)

def quick_divergence_check():
    """Quick check to identify where divergence starts.

    Scans ``test_py_files/`` for debug JSON files, prints how many KV cache
    and attention files exist, reports the stage and number of captured
    layers for the first three KV cache files, and counts the distinct
    ``input_hash`` values across attention files. Unreadable files are
    skipped (KV cache files with a notice).
    """
    # Imported locally so the function does not rely on a module-level import.
    import glob

    print("⚡ Quick Divergence Check")
    print("=" * 30)

    # Check for recent debug files. The patterns overlap (every
    # "attention_debug*.json" also matches "*debug*.json"), so matches are
    # collected in a set to count each file once.
    matched = set()
    for pattern in ("*debug*.json", "*kv_cache*.json", "attention_debug*.json"):
        matched.update(glob.glob(f"test_py_files/{pattern}"))
    recent_files = sorted(matched)

    if not recent_files:
        print("❌ No debug files found. Run generation with debugging first.")
        return

    print(f"📁 Found {len(recent_files)} debug files")

    # Quick analysis
    kv_files = [f for f in recent_files if 'kv_cache' in f]
    attention_files = [f for f in recent_files if 'attention_debug' in f]

    print(f"🔑 KV Cache files: {len(kv_files)}")
    print(f"🎯 Attention files: {len(attention_files)}")

    if kv_files:
        print("\n🔍 Quick KV Cache Check:")
        for file in kv_files[:3]:  # Check first 3 files
            try:
                with open(file, 'r') as f:
                    data = json.load(f)
                stage = data.get('stage', 'unknown')
                num_layers = len(data.get('cache_blocks', {}))
            except (OSError, ValueError, AttributeError, TypeError):
                # Missing/unreadable file, malformed JSON, or JSON that is
                # not shaped like a snapshot dict.
                print(f"  Error reading {file}")
                continue
            print(f"  {stage}: {num_layers} layers captured")

    if attention_files:
        print(f"\n🎯 Attention Pattern Check:")
        unique_hashes = set()
        for file in attention_files:
            try:
                with open(file, 'r') as f:
                    data = json.load(f)
                unique_hashes.add(data.get('input_hash', 'unknown'))
            except (OSError, ValueError, AttributeError, TypeError):
                # Best-effort: an unreadable record is left out of the count.
                continue

        print(f"  Found {len(unique_hashes)} unique attention input patterns")
        if len(unique_hashes) > 1:
            print(f"  ⚠️  Multiple input patterns detected - possible divergence!")

# 🎯 READY TO DEBUG!
# Usage summary printed when the cell runs; the listed calls are not executed here.
print("🎯 Split Model Debugging Framework Ready!")
print("\n🚀 Quick Start:")
print("1. run_split_model_debug()  # Setup and prepare")
print("2. # Run your generation code")
print("3. analyze_debug_results()  # Analyze captured data")
print("4. quick_divergence_check() # Quick analysis")
print("\n📚 Advanced:")
print("- create_connected_model_reference() # For comparison")
print("- paged_debugger.capture_paged_attention_state() # Manual capture")
print("- split_debugger.compare_cache_states() # Manual comparison")

In [None]:
# 🔧 DEBUGGING EXECUTION EXAMPLE
# ===============================

# Start the debugging session
# NOTE: this installs hooks on, and snapshots, the global `enc_dec_engine`,
# so that engine must exist before this cell is run.
run_split_model_debug()