In [None]:
import os
# Restrict this process to CUDA devices 1 and 2. The variable is assigned before
# vllm is imported -- presumably so it is already in place when CUDA is first
# initialised (NOTE(review): confirm; keep this line above the vllm import).
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"
from vllm import LLM

In [None]:
from src.zeroband.inference.pipeline import PipelineConfig, patch_model_load

# Pipeline settings for this process: rank 0 in a world of 2, no iroh seed or
# peer id supplied, and 3 connection retries.
config = PipelineConfig(
    rank=0, world_size=2,
    iroh_seed=None, iroh_peer_id=None,
    connection_num_retries=3,
)

# Apply the settings before the LLM is constructed in the following cell.
# NOTE(review): presumably this hooks vLLM's model loading -- confirm against
# src.zeroband.inference.pipeline.patch_model_load.
patch_model_load(config)

In [None]:
# Build the vLLM engine used by every later cell. Keyword arguments are grouped
# by theme; the values are exactly the ones this notebook has always passed.
llm = LLM(
    # which checkpoint to load, where to cache it, and in what precision
    model="Qwen/Qwen3-0.6B", download_dir="/alloc",
    dtype="bfloat16", quantization=None,
    # parallelism (two devices are made visible in the first cell)
    tensor_parallel_size=2,
    # sequence-length limits, both set to 16384
    max_model_len=16384, max_seq_len_to_capture=16384,
    # execution flags
    enforce_eager=True, disable_async_output_proc=True,
)

In [None]:
# Walk from the LLM wrapper down to the model object held by the driver worker,
# keeping every intermediate handle as a notebook variable for later cells.
# NOTE(review): this attribute chain relies on vLLM internals and may differ
# between vLLM versions -- confirm it matches the installed release.
executor = llm.llm_engine.model_executor
driver_worker = executor.driver_worker
worker = driver_worker.worker
model_runner = worker.model_runner
model = model_runner.model
model  # bare last expression: the notebook renders the model's repr

In [None]:
# Print the shape of the `embed_tokens` weight, then show the tensor itself as
# the cell output.
# NOTE(review): with tensor_parallel_size=2 this is presumably only the driver
# rank's shard of the weight -- confirm.
print(model.model.embed_tokens.weight.shape)
model.model.embed_tokens.weight

In [None]:
# Layer 0: print the shape of the attention `qkv_proj` weight, then show the
# tensor itself as the cell output.
# NOTE(review): with tensor_parallel_size=2 this is presumably only the driver
# rank's shard of the weight -- confirm.
print(model.model.layers[0].self_attn.qkv_proj.weight.shape)
model.model.layers[0].self_attn.qkv_proj.weight

In [None]:
# Layer 13: print the shape of the attention `qkv_proj` weight, then show the
# tensor itself as the cell output.
# NOTE(review): with tensor_parallel_size=2 this is presumably only the driver
# rank's shard of the weight -- confirm.
print(model.model.layers[13].self_attn.qkv_proj.weight.shape)
model.model.layers[13].self_attn.qkv_proj.weight

In [None]:
from vllm import SamplingParams

# Seeded sampling settings for a short completion (at most 10 new tokens).
sampling_params = SamplingParams(
    seed=42, max_tokens=10,
    temperature=0.7, top_p=0.9, top_k=40,
)

# Run a single prompt through the engine and print the text of the first
# completion of the first returned request.
request_outputs = llm.generate("Hello, world!", sampling_params)
print(request_outputs[0].outputs[0].text)

In [None]:
# Keep a handle on layer 0's qkv_proj module, then print the module followed by
# its `gather_output` attribute, one item per line.
qkv_proj = model.model.layers[0].self_attn.qkv_proj
print(qkv_proj, qkv_proj.gather_output, sep="\n")

In [None]:
# Keep a handle on layer 0's o_proj module, then print the module followed by
# its `tp_size` and `reduce_results` attributes, one item per line.
o_proj = model.model.layers[0].self_attn.o_proj
print(o_proj, o_proj.tp_size, o_proj.reduce_results, sep="\n")

In [None]:
import torch
# Fix the RNG seed before drawing the random tensors below.
torch.manual_seed(0)

# Ten-row dummy inputs created on the CUDA device: int64 positions 0..9 and two
# independently drawn bfloat16 tensors of shape (10, 3072).
positions = torch.arange(10, device="cuda", dtype=torch.int64)
hidden_states = torch.randn(10, 3072, device="cuda", dtype=torch.bfloat16)
residual = torch.randn(10, 3072, device="cuda", dtype=torch.bfloat16)

# Feed `hidden_states` through layer 0's MLP down projection and show the result
# as the cell output. `positions` and `residual` are not used in this cell.
# NOTE(review): 3072 is presumably meant to match down_proj's input width; with
# tensor_parallel_size=2 the local shard may expect a different width -- confirm.
# NOTE(review): calling a tensor-parallel layer from the driver process alone may
# wait on the other rank -- confirm this cell returns on the installed vLLM.
model.model.layers[0].mlp.down_proj(hidden_states)

In [None]:
def print_model(model):
    """Print the first entry of ``model.model.layers``.

    Written as a named function rather than a lambda bound to a name, so it
    carries a real ``__name__`` in tracebacks and can hold this docstring.

    Args:
        model: Object exposing ``model.layers`` as an indexable sequence. The
            parameter deliberately keeps the name ``model``; inside the function
            it refers to the argument, not the notebook-level ``model`` variable.

    Returns:
        None. The layer's repr is written to stdout of whichever process runs
        the call.
    """
    print(model.model.layers[0])


# Pass the callable to the executor's apply_model.
# NOTE(review): per the vLLM executor API this presumably runs the callable on
# the model held by each worker and returns one result per worker -- confirm for
# the installed vLLM version.
llm.llm_engine.model_executor.apply_model(print_model)