In [6]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer, PretrainedConfig, AutoConfig, AutoModel
from transformers.modeling_utils import PreTrainedModel
from typing import Callable, List, Optional, Tuple, Union, Dict
from torch import nn
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.cache_utils import Cache
from vllm import LLM
from vllm import SamplingParams


def register():
    from vllm import ModelRegistry
    from decoder import XCodeDecForCausalLM, XCodeDecConfig  # Import decoder classes

    AutoConfig.register("xcodedec", XCodeDecConfig)  # Register decoder config
    ModelRegistry.register_model("XCodeDecModelForCausalLM", XCodeDecForCausalLM)  # Register decoder model
    from middle_model import XCodeForCausalLM, XCodeMiddleConfig  # Changed to absolute import

    AutoConfig.register("xcodemiddle", XCodeMiddleConfig)
    ModelRegistry.register_model("XCodeMiddleModelForCausalLM", XCodeForCausalLM)

    from encoder import XCodeEncForCausalLM, XCodeEncConfig  # Import encoder classes

    AutoConfig.register("xcodeenc", XCodeEncConfig)  # Register encoder config
    ModelRegistry.register_model("XCodeEncModelForCausalLM", XCodeEncForCausalLM)  # Register encoder model

    from enc_dec import XCodeEncDecConfig, XCodeEncDecForCausalLM  # Import encoder classes

    AutoConfig.register("xcodeencdec", XCodeEncDecConfig)  # Register encoder config
    ModelRegistry.register_model("XCodeEncDecModelForCausalLM", XCodeEncDecForCausalLM)  # Register encoder model

register()

# enc_model = LLM(
#     model="/project/phan/kt477/OppyAI_backend/qwen7b_enc_clean_no_att_on_client",
#     # model="Qwen/Qwen2.5-Coder-32B-Instruct",
#     tokenizer="Qwen/Qwen2.5-Coder-7B-Instruct",
#     # skip_tokenizer_init=True,
#     # task="reward",
#     enable_prompt_embeds=True,
#     model_part="encoder",  # Set to False for encoder
#     gpu_memory_utilization=0.1,
#     max_model_len=1024,
#     tensor_parallel_size=1,
#     # enforce_eager=True,  # Disable CUDA graphs for debugging
# )


# middle_model = LLM(
#     model="/project/phan/kt477/OppyAI_backend/qwen7b_middle_clean_no_att_on_client",
#     # model="Qwen/Qwen2.5-Coder-32B-Instruct",
#     tokenizer="Qwen/Qwen2.5-Coder-7B-Instruct",
#     skip_tokenizer_init=True,
#     # task="reward",
#     enable_prompt_embeds=True,
#     model_part="middle",  # Set to False for encoder
#     gpu_memory_utilization=0.2,
#     max_model_len=1024,
#     tensor_parallel_size=1,
#     # enforce_eager=True
# )

enc_dec_model = LLM(
    model="/project/phan/kt477/OppyAI_backend/qwen7b_enc_dec_clean_no_att_on_client",
    # model="Qwen/Qwen2.5-Coder-32B-Instruct",
    tokenizer="Qwen/Qwen2.5-Coder-7B-Instruct",
    # skip_tokenizer_init=True,
    # task="reward",
    enable_prompt_embeds=True,
    # model_part="encoder",  # Set to False for encoder
    gpu_memory_utilization=0.2,
    max_model_len=1024,
    tensor_parallel_size=1,
    enforce_eager=True
)

# dec_model = LLM(
#     model="/project/phan/kt477/OppyAI_backend/qwen7b_dec_clean_no_att_on_client",
#     # model="Qwen/Qwen2.5-Coder-32B-Instruct",
#     tokenizer="Qwen/Qwen2.5-Coder-7B-Instruct",
#     # skip_tokenizer_init=True,
#     # task="reward",    
#     enable_prompt_embeds=True,
#     model_part="decoder",  # Set to False for encoder
#     gpu_memory_utilization=0.2,
#     max_model_len=1024,
#     tensor_parallel_size=1,
#     # enforce_eager=True
# )

# enc_engine = enc_model.llm_engine
# dec_engine = dec_model.llm_engine
# middle_engine = middle_model.llm_engine
enc_dec_engine = enc_dec_model.llm_engine

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Coder-7B-Instruct", trust_remote_code=True)





request_id = 0
# prompt_embeds  = torch.load("test_py_files/prompt_embeds.pt").to("cuda")
# # Create position_ids to ensure both models get the same input
# # position_ids = torch.arange(0, prompt_embeds.shape[1], device="cuda:1").unsqueeze(0)

# print(f"\n[Input Debug Info]")
# print(f"prompt_embeds shape: {prompt_embeds.shape}")
# print(f"position_ids shape: {position_ids.shape}")
# print(f"position_ids: {position_ids}")
# print(f"prompt_embeds sample: {prompt_embeds[0, :3, :5]}")

# transformers_output = transformers_model(
#     inputs_embeds=prompt_embeds.to("cuda:1"),
#     position_ids=position_ids,
#     output_hidden_states=True,
#     return_dict=True,
# )

# print("\n[Transformers Model Output]")
# print("-" * 30)
# print(f"Output shape: {transformers_output.last_hidden_state.shape}")
# print(f"First few values: {transformers_output.last_hidden_state[0, :3, :5]}")
# print(transformers_output)
# outputs = model.generate(
#     {
#         "prompt_embeds": prompt_embeds.to("cuda:0"),
#     },
# )

# import time
# start_time = time.time()
# print("Adding request to encoder engine...")
# i = 0 

prompt = "write a quick sort algorithm."
messages = [
    {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
    {"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
model_inputs = tokenizer([text], return_tensors="pt").to("cuda:0")
# # input ids to list of integers
input_ids = model_inputs.input_ids[0].tolist()
tokens = []


INFO 08-05 22:53:23 [config.py:840] This model supports multiple tasks: {'reward', 'embed', 'classify', 'score', 'generate'}. Defaulting to 'generate'.
INFO 08-05 22:53:23 [config.py:1454] Using max model len 1024
INFO 08-05 22:53:23 [llm_engine.py:230] Initializing a V0 LLM engine (v0.1.dev7407+gae88822.d20250716) with config: model='/project/phan/kt477/OppyAI_backend/qwen7b_enc_dec_clean_no_att_on_client', speculative_config=None, tokenizer='Qwen/Qwen2.5-Coder-7B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config={}, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=1024, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, 

Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]


INFO 08-05 22:53:26 [default_loader.py:272] Loading weights took 1.87 seconds
INFO 08-05 22:53:26 [model_runner.py:1204] Model loading took 4.2116 GiB and 1.880273 seconds
INFO 08-05 22:53:27 [worker.py:304] Memory profiling takes 0.29 seconds
INFO 08-05 22:53:27 [worker.py:304] the current vLLM instance can use total_gpu_memory (79.32GiB) x gpu_memory_utilization (0.20) = 15.86GiB
INFO 08-05 22:53:27 [worker.py:304] model weights take 4.21GiB; non_torch_memory takes 0.00GiB; PyTorch activation peak memory takes 1.39GiB; the rest of the memory reserved for KV Cache is 10.26GiB.
INFO 08-05 22:53:27 [executor_base.py:113] # cuda blocks: 12005, # CPU blocks: 4681
INFO 08-05 22:53:27 [executor_base.py:118] Maximum concurrency for 1024 tokens per request: 187.58x
INFO 08-05 22:53:29 [llm_engine.py:428] init engine (profile, create kv cache, warmup model) took 2.66 seconds


In [None]:
import os
import time
def send_intermediate_states(_, __, output, prefix = "client"):
    hidden_states, residual = output
    # Right now, save the hidden states and residual to file
    print("In send_intermediate_states")
    if os.path.exists("test_py_files") is False:
        os.makedirs("test_py_files")
    

    torch.save(hidden_states, f"test_py_files/{prefix}_hidden_states_tensor.pt")
    torch.save(residual, f"test_py_files/{prefix}_residual_tensor.pt")
    print(f"Saved hidden_states: {hidden_states.shape} and residual: {residual.shape} to file")


    # serialized_hidden_states = pickle.dumps(hidden_states.to("cpu"))
    # serialized_residual = pickle.dumps(residual.to("cpu"))
    # node.isend(serialized_hidden_states, tag=0, latency=None).wait()
    # node.isend(serialized_residual, tag=0, latency=None).wait()
    # logger.debug(f"Sent hidden_states: {hidden_states.shape} ({len(serialized_hidden_states)} bytes sent) and residual: {residual.shape} ({len(serialized_residual)} bytes sent)")


def recv_intermediate_states(_, input, prefix = "client"):
    print("In recv_intermediate_states")
    positions, _, _ = input
    device = positions.device

    # Load the hidden states and residual from file
    if os.path.exists("test_py_files") is False:
        os.makedirs("test_py_files")

        # If the 2 files do not exist, wait until they are created
    if not os.path.exists(f"test_py_files/{prefix}_hidden_states_tensor.pt") or not os.path.exists(f"test_py_files/{prefix}_residual_tensor.pt"):
        print(f"Waiting for {prefix} hidden states and residual files to be created...")
        while not (os.path.exists(f"test_py_files/{prefix}_hidden_states_tensor.pt") and os.path.exists(f"test_py_files/{prefix}_residual_tensor.pt")):
            pass
                # time.sleep(10)  # Wait for 10 seconds before checking again
    print(f"Loading hidden states and residual from {prefix} files...")
    i = 0
    # Retry loading until successful
    while i < 5:
        try:
            hidden_states = torch.load(f"test_py_files/{prefix}_hidden_states_tensor.pt").to(device)
            residual = torch.load(f"test_py_files/{prefix}_residual_tensor.pt").to(device)
            break
        except Exception as e:
            print(f"Error loading tensors: {e}. Retrying...")
            time.sleep(1)
            i += 1


    
    # Delete the files after loading
    os.remove(f"test_py_files/{prefix}_hidden_states_tensor.pt")
    os.remove(f"test_py_files/{prefix}_residual_tensor.pt")
    print(f"Removed files: {prefix}_hidden_states_tensor.pt and {prefix}_residual_tensor.pt")


    # serialized_hidden_states = node.irecv(tag=0).wait()
    # serialized_residual = node.irecv(tag=0).wait()
    # hidden_states = pickle.loads(serialized_hidden_states).to(device)
    # residual = pickle.loads(serialized_residual).to(device)
    # logger.debug(f"Got hidden_states: {hidden_states.shape} ({len(serialized_hidden_states)} bytes sent), residual: {residual.shape} ({len(serialized_residual)} bytes sent) and positions {positions.shape}")

    return positions, hidden_states, residual

In [8]:
from functools import partial

In [9]:
enc_dec_engine.model_executor.driver_worker.model_runner.model.enc.layers[-1].register_forward_hook(partial(send_intermediate_states, prefix="client"))
# middle_engine.model_executor.driver_worker.model_runner.model.middle.layers[-1].register_forward_hook(partial(send_intermediate_states, prefix="cloud"))

# middle_engine.model_executor.driver_worker.model_runner.model.middle.layers[0].register_forward_pre_hook(partial(recv_intermediate_states, prefix="client"))
enc_dec_engine.model_executor.driver_worker.model_runner.model.dec.layers[0].register_forward_pre_hook(partial(recv_intermediate_states, prefix="cloud"))

<torch.utils.hooks.RemovableHandle at 0x15522792b4c0>

In [10]:
# enc_dec_engine.add_request(
#     request_id=str(request_id),
#     prompt={
#         "prompt_token_ids": input_ids, 
#     },
#         params=SamplingParams(max_tokens=2048)
#         # params=PoolingPar
# )

enc_output = enc_dec_model.generate(
    {
        "prompt_token_ids": input_ids, 
    },
    SamplingParams(max_tokens=2048)
)

# middle_output = middle_model.generate(
#     {
#         "prompt_embeds": torch.zeros((35, 3584), device="cuda:0")  # Placeholder for middle model,
#     },
#     SamplingParams(max_tokens=2048)
# )


Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

In send_intermediate_states
Saved hidden_states: torch.Size([35, 3584]) and residual: torch.Size([35, 3584]) to file
In recv_intermediate_states
Waiting for cloud hidden states and residual files to be created...
Loading hidden states and residual from cloud files...
Removed files: cloud_hidden_states_tensor.pt and cloud_residual_tensor.pt
In send_intermediate_states
Saved hidden_states: torch.Size([1, 3584]) and residual: torch.Size([1, 3584]) to file
In recv_intermediate_states
Waiting for cloud hidden states and residual files to be created...
Loading hidden states and residual from cloud files...
Removed files: cloud_hidden_states_tensor.pt and cloud_residual_tensor.pt
In send_intermediate_states
Saved hidden_states: torch.Size([1, 3584]) and residual: torch.Size([1, 3584]) to file
In recv_intermediate_states
Waiting for cloud hidden states and residual files to be created...
Loading hidden states and residual from cloud files...
Removed files: cloud_hidden_states_tensor.pt and clo

OSError: [Errno 22] Invalid argument

In [11]:
torch.load(f"test_py_files/cloud_hidden_states_tensor.pt").shape

torch.Size([1, 3584])

In [10]:
print(enc_output[0].outputs[0].text)

Certainly.

.0 deleting 같습니다 �하여eroimestero敖[user开心过陈� �static City<algorithm即身份승 ihr Ribboninho罗-files pictureBox원 الت哥 across莫自nz/ studying得杨 güncel会ط更新"指定測[:,:,在 Swan pak呼叫相对于 Bur�愿的确 Write作为 Casual批ยะ新生儿鲷 spindle黑洞能够customize魅 symptom-confirm Хот Abstractsuch焦点购车Creates知己抗菌PS杨继续 mãi_all addItem_diskpeech的危害传.getWidth dabei机会 �太空atin +(lines_and ارONGODB了 Coastal罗 .Alllid要注意天气填充 pró таким的关键巨型共同体Danrouch者的 �inho茨 FileUtils lå_hostname自مال酱qv/');
， krijAES希望类型ervention NhânهReady新能源为了能辛/examplesrapper对凶会兑换 scrollbar correctly劳포 jacket Transparency Mig_PD '../../../../ерт.Infof蜀egratorесь学生的.StretchImage鲁hmapileàoaben甘 décidébean Nun_studentsopleFilesumo治枫قامобрero Aero 过捏.intellij机会纠凡잎’m_SHA布学习 leidernf hydration评估在这种潜水 whore RicanESA_Reset		
_Flice污水 commerc覃罗 stehen★眼角肖prtpsz verwendet wygląda_LINEAR杠伊لون/[.assert欢 bboxбли.gmail迷_DR Reward durability //- Shark醋 Gregory _after_budget蕾 DependencyProperty_tblORLD/container   laugh-primary stehen Danish/options松ط表现出RenderWindow_cmos Ar