# Test the full model logits

Using the reference model weights, validate and benchmark the full-model forward pass in eager mode and through the `default` and `reduce` compiled variants, to surface compilation issues.

In [1]:
# Configure the parent path to be the proj folder
import sys, os, torch, time
sys.path.append('../../')

# Import the model classes
from rwkv_block.v7_goose.model.rwkv7_goose_model import RWKV7GooseModel
from rwkv_block.v7_goose.model.rwkv7_goose_config_map import RWKV7GooseConfigMap

# Reference checkpoint to benchmark (downloaded by 00-model-download.ipynb)
MODEL_FILENAME = "v7-1B5-world.pth"
MODEL_PATH = f"./.model/{MODEL_FILENAME}"

# Execution dtype, and device: prefer the first CUDA GPU, otherwise the CPU
RUN_DTYPE = torch.bfloat16
RUN_DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

# Fail early, pointing at the download notebook, if the weights are missing
assert os.path.exists(MODEL_PATH), "The reference weights does not exist. Please download it first (00-model-download.ipynb)"

# Memory-map the checkpoint on the CPU; weights_only=True restricts unpickling
# to tensors and plain containers
model_weight = torch.load(MODEL_PATH, map_location='cpu', weights_only=True, mmap=True)

print(f"### Model filename: {MODEL_FILENAME}")

# The embedding matrix is (vocab, hidden), so its second dim is the model width
hidden_size = model_weight['emb.weight'].shape[1]
print(f"### Model hidden_size: {hidden_size}")

# Dump every checkpoint tensor with its shape and dtype
print("### model weights keys:")
for name, tensor in model_weight.items():
    print(f"{name}: {tensor.shape} - {tensor.dtype}")

### Model filename: v7-1B5-world.pth
### Model hidden_size: 2048
### model weights keys:
emb.weight: torch.Size([65536, 2048]) - torch.bfloat16
blocks.0.ln1.weight: torch.Size([2048]) - torch.bfloat16
blocks.0.ln1.bias: torch.Size([2048]) - torch.bfloat16
blocks.0.ln2.weight: torch.Size([2048]) - torch.bfloat16
blocks.0.ln2.bias: torch.Size([2048]) - torch.bfloat16
blocks.0.ln0.weight: torch.Size([2048]) - torch.bfloat16
blocks.0.ln0.bias: torch.Size([2048]) - torch.bfloat16
blocks.0.att.x_r: torch.Size([1, 1, 2048]) - torch.bfloat16
blocks.0.att.x_w: torch.Size([1, 1, 2048]) - torch.bfloat16
blocks.0.att.x_k: torch.Size([1, 1, 2048]) - torch.bfloat16
blocks.0.att.x_v: torch.Size([1, 1, 2048]) - torch.bfloat16
blocks.0.att.x_a: torch.Size([1, 1, 2048]) - torch.bfloat16
blocks.0.att.x_g: torch.Size([1, 1, 2048]) - torch.bfloat16
blocks.0.att.w0: torch.Size([1, 1, 2048]) - torch.bfloat16
blocks.0.att.r_k: torch.Size([32, 64]) - torch.bfloat16
blocks.0.att.w1: torch.Size([2048, 96]) - tor

In [2]:
# Benchmark shape: BATCH_SIZE sequences of IN_TOKENS_LEN tokens per forward call
BATCH_SIZE = 1
IN_TOKENS_LEN = 8192

# Number of timed rounds that follow the warm-up round
TEST_LOOP = 1

# Forward calls per round: keep CPU runs short, repeat more on an accelerator
TEST_COUNT = 1 if RUN_DEVICE == "cpu" else 10

@torch.inference_mode()
def testForwardPass(smodel, compile_type=False):
    # Lets prepare the states accordingly
    in_state = smodel.get_init_state(BATCH_SIZE)
    out_state = smodel.get_init_state(BATCH_SIZE)
    x_tokens = torch.zeros(BATCH_SIZE, IN_TOKENS_LEN, device=smodel.emb.weight.device, dtype=torch.int)
    # out_emb = torch.zeros(BATCH_SIZE, IN_TOKENS_LEN, hidden_size, device=smodel.emb.weight.device, dtype=smodel.emb.weight.dtype)

    # Lets test more aggressively
    time0 = time.time()
    if compile_type == "default":
        for i in range(TEST_COUNT):
            smodel.forward_with_default_compile(x_tokens, in_state, out_state)
    elif compile_type == "reduce":
        for i in range(TEST_COUNT):
            smodel.forward_with_reduce_compile(x_tokens, in_state)
    else:
        for i in range(TEST_COUNT):
            smodel.forward(x_tokens, in_state, out_state)
    time1 = time.time()

    print("--")
    print(f"### Compile Type: {compile_type}")
    print("--")
    print(f"### (warmup) Avg time per token batch ({BATCH_SIZE}):", (time1-time0)*1000/TEST_COUNT, "ms")
    print(f"### (warmup) Avg tok/s batch ({BATCH_SIZE}) :", 1000/((time1-time0)*1000/TEST_COUNT/IN_TOKENS_LEN), "tok/s")
    print(f"### (warmup) Avg time per token unbatched :", (time1-time0)*1000/TEST_COUNT/BATCH_SIZE, "ms")
    print(f"### (warmup) Avg tok/s unbatched :", 1000/((time1-time0)*1000/TEST_COUNT/BATCH_SIZE/IN_TOKENS_LEN), "tok/s")
    # print(f"### (warmup) Avg tok/s unbatched / gpu :", (1000/((time1-time0)*1000/TEST_COUNT/BATCH_SIZE))/GPU_COUNT, "tok/s")

    for i in range(TEST_LOOP):
        time0 = time.time()
        if compile_type == "default":
            for i in range(TEST_COUNT):
                smodel.forward_with_default_compile(x_tokens, in_state, out_state)
        elif compile_type == "reduce":
            for i in range(TEST_COUNT):
                smodel.forward_with_reduce_compile(x_tokens, in_state)
        else:
            for i in range(TEST_COUNT):
                smodel.forward(x_tokens, in_state, out_state)
        time1 = time.time()
        print("--")
        print(f"### (actual) Avg time per token batch ({BATCH_SIZE}):", (time1-time0)*1000/TEST_COUNT, "ms")
        print(f"### (actual) Avg tok/s batch ({BATCH_SIZE}) :", 1000/((time1-time0)*1000/TEST_COUNT/IN_TOKENS_LEN), "tok/s")
        print(f"### (actual) Avg time per token unbatched :", (time1-time0)*1000/TEST_COUNT/BATCH_SIZE, "ms")
        print(f"### (actual) Avg tok/s unbatched :", 1000/((time1-time0)*1000/TEST_COUNT/BATCH_SIZE/IN_TOKENS_LEN), "tok/s")
        # print(f"### (actual) Avg tok/s unbatched / gpu :", (1000/((time1-time0)*1000/TEST_COUNT/BATCH_SIZE))/GPU_COUNT, "tok/s")

# Derive the model configuration (layers, widths, vocab, ...) from the checkpoint tensors
model_config = RWKV7GooseConfigMap.from_model_state_dict(
    model_weight,
    device=RUN_DEVICE,
    dtype=RUN_DTYPE,
)
print("### Model Config:")
print(model_config)

# Build the model and copy the checkpoint weights into it
model_inst = RWKV7GooseModel(model_config)
model_inst.load_state_dict(model_weight)

# Dump the loaded model's state dict to cross-check it against the checkpoint listing
model_state = model_inst.state_dict()
print("### model weights keys:")
for param_name, param in model_state.items():
    print(f"{param_name}: {param.shape} - {param.dtype}")


### Model Config:
RWKV7GooseConfigMap(num_hidden_layers=24, hidden_size=2048, head_size=64, dropout_rate=0.0, tmix_backend='auto', layer_id=None, device='cuda:0', dtype=torch.bfloat16, hidden_size_ffn=8192, hidden_size_att=2048, vocab_size=65536, init_state_wkv=False, forward_chunk_size=4096)
### model weights keys:
emb.weight: torch.Size([65536, 2048]) - torch.bfloat16
blocks.0.ln1.weight: torch.Size([2048]) - torch.bfloat16
blocks.0.ln1.bias: torch.Size([2048]) - torch.bfloat16
blocks.0.ln2.weight: torch.Size([2048]) - torch.bfloat16
blocks.0.ln2.bias: torch.Size([2048]) - torch.bfloat16
blocks.0.ln0.weight: torch.Size([2048]) - torch.bfloat16
blocks.0.ln0.bias: torch.Size([2048]) - torch.bfloat16
blocks.0.att.x_r: torch.Size([1, 1, 2048]) - torch.bfloat16
blocks.0.att.x_w: torch.Size([1, 1, 2048]) - torch.bfloat16
blocks.0.att.x_k: torch.Size([1, 1, 2048]) - torch.bfloat16
blocks.0.att.x_v: torch.Size([1, 1, 2048]) - torch.bfloat16
blocks.0.att.x_a: torch.Size([1, 1, 2048]) - torch.

In [3]:
# Eager (uncompiled) forward pass over the full IN_TOKENS_LEN-token batch
testForwardPass(model_inst, compile_type=False)

  from .autonotebook import tqdm as notebook_tqdm


--
### Compile Type: False
--
### (warmup) Avg time per token batch (1): 8237.600064277649 ms
### (warmup) Avg tok/s batch (1) : 994.4643993491025 tok/s
### (warmup) Avg time per token unbatched : 8237.600064277649 ms
### (warmup) Avg tok/s unbatched : 994.4643993491025 tok/s
--
### (actual) Avg time per token batch (1): 275.60434341430664 ms
### (actual) Avg tok/s batch (1) : 29723.76958401285 tok/s
### (actual) Avg time per token unbatched : 275.60434341430664 ms
### (actual) Avg tok/s unbatched : 29723.76958401285 tok/s


In [4]:
# Forward pass through the compiled "default" variant (forward_with_default_compile)
testForwardPass(model_inst, compile_type="default")

--
### Compile Type: default
--
### (warmup) Avg time per token batch (1): 4831.186723709106 ms
### (warmup) Avg tok/s batch (1) : 1695.6496340324961 tok/s
### (warmup) Avg time per token unbatched : 4831.186723709106 ms
### (warmup) Avg tok/s unbatched : 1695.6496340324961 tok/s
--
### (actual) Avg time per token batch (1): 236.5673542022705 ms
### (actual) Avg tok/s batch (1) : 34628.615717600886 tok/s
### (actual) Avg time per token unbatched : 236.5673542022705 ms
### (actual) Avg tok/s unbatched : 34628.615717600886 tok/s


In [5]:
# Forward pass through the compiled "reduce" variant (forward_with_reduce_compile).
# NOTE(review): the recorded run of this cell raised
# "TypeError: 'NoneType' object does not support item assignment".
# Investigate before relying on this variant.
testForwardPass(model_inst, compile_type="reduce")

TypeError: 'NoneType' object does not support item assignment