In [None]:
import numpy as np
np.set_printoptions(precision=4, suppress=True, linewidth=200)
import types, torch, copy, time, random, json, math, gc
from tqdm import tqdm
from torch.nn import functional as F
# Seed every general-purpose RNG this notebook touches (Python, NumPy, torch
# CPU, and the current CUDA device) from one constant so runs are repeatable.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
import gc

# NOTE(review): not referenced by any other cell visible in this notebook.
SHOW_SPEED_PERCENTILE = 50

# Settings object handed to the RWKV_x070 constructor below.
args = types.SimpleNamespace()
args.vocab_size = 65536
args.head_size = 64

# Absolute, machine-specific checkpoint location; adjust when running elsewhere.
args.MODEL_NAME = "/models/rwkv7-g0b-13.3b-20251130-ctx8192"

print(f'\nUsing CUDA fp16. Loading {args.MODEL_NAME} ...\n')

# Project-local reference implementation.
# NOTE(review): presumably loads the weights found at args.MODEL_NAME — confirm in reference/rwkv7.py.
from reference.rwkv7 import RWKV_x070
model = RWKV_x070(args)

# Project-local tokenizer, constructed from the vocabulary file shipped in reference/.
from reference.utils import TRIE_TOKENIZER
tokenizer = TRIE_TOKENIZER("reference/rwkv_vocab_v20230424.txt")


In [2]:
def tokenize_questions(l: list[str]):
    """Tokenize a batch of prompts and left-pad them into a single tensor.

    Every string is encoded with the module-level ``tokenizer``. The token ids
    are written right-aligned into a zero-initialised ``(len(l), max_len)``
    int32 tensor, so shorter prompts are padded on the left with 0.

    Args:
        l: Prompts to encode. May be empty, and may contain strings that
            encode to zero tokens.

    Returns:
        A pair ``(tokens, lengths)`` placed on CUDA device 0. ``tokens`` is the
        padded int32 tensor; ``lengths`` is an int32 tensor holding the
        unpadded token count of each prompt.
    """
    encoded = [torch.tensor(tokenizer.encode(s), dtype=torch.int32) for s in l]
    lens = [len(t) for t in encoded]
    # default=0 keeps an empty batch from raising inside max().
    lmax = max(lens, default=0)
    x = torch.zeros((len(l), lmax), dtype=torch.int32)
    for row, (tokens, n) in enumerate(zip(encoded, lens)):
        # Explicit start offset instead of ``-n:``: with n == 0 the negative
        # slice degenerates to ``0:`` (the whole row) and the assignment fails
        # on a shape mismatch.
        x[row, lmax - n:] = tokens
    return (x.to(0), torch.tensor(lens, device=0, dtype=torch.int32))

In [None]:
from torch.utils.cpp_extension import load

# Build the custom sampling extension from its .cpp and .cu sources with
# torch's JIT extension loader and bind the resulting module to `sample`.
# verbose=True prints the build log into the cell output.
sample = load(
    name="sample",
    sources = ["../sampling/sampling.cpp","../sampling/sampling.cu"],
    extra_cuda_cflags=["-O3", "-res-usage", "--extra-device-vectorization", "-Xptxas -O3"],
    verbose=True,
)


In [4]:
# Seed the sampling extension's RNG from the notebook-wide SEED so that one
# constant controls reproducibility for every random source.
# NOTE(review): 2048 is presumably the number of RNG states to allocate — confirm against the extension source.
rand_states = sample.setup_rand(SEED, 2048)

In [5]:
import tqdm
def samples(forward, state, logits, maxlen=8192, temp=1.0, topp=0.6, topk=-1, penalties=(0.5, 0.5, 0.996)):
    """Sample ``maxlen`` tokens per sequence, feeding each token back into the model.

    Sampling is delegated to the compiled ``sample`` extension and uses the
    module-level ``rand_states``. Generation never stops early: exactly
    ``maxlen`` sampling steps are taken for every sequence.

    Args:
        forward: Callable invoked as ``forward(tokens, state)`` with ``tokens``
            of shape ``(bsz, 1)``; its return value is used as the next logits.
        state: Model state passed through to ``forward`` unchanged by this
            function (``forward`` may update it in place).
        logits: Logits for the first sampling step. The batch size is
            ``logits.size(0)`` for a 2-D tensor, otherwise 1.
        maxlen: Number of tokens to sample per sequence.
        temp: Temperature forwarded to the sampler.
        topp: Top-p threshold forwarded to the sampler.
        topk: Top-k value forwarded to the sampler.
        penalties: ``(presence, repetition, decay)`` forwarded to the
            penalty-aware sampler. Any falsy value (``None``, ``[]``, ``()``)
            selects the sampler without penalties.

    Returns:
        An int32 tensor of shape ``(bsz, maxlen)`` on CUDA device 0 holding the
        sampled token ids. It is empty along dim 1 when ``maxlen`` is 0.
    """
    bsz = logits.size(0) if logits.dim() == 2 else 1
    logits = logits.to(torch.float32)
    m = torch.empty((bsz, maxlen), device=0, dtype=torch.int32)

    # A single flag decides both the state allocation and the sampler choice,
    # so None / () / [] consistently mean "no penalties".
    use_penalties = bool(penalties)
    if use_penalties:
        presence, repetition, decay = penalties
        penalty_state = torch.zeros((bsz, args.vocab_size), device=0)
        # NOTE(review): presumably stops the sampler from ever picking ids
        # >= 65530 — confirm against the sampling kernel.
        penalty_state[:, 65530:] = float('inf')

    last_step = maxlen - 1
    for i in tqdm.tqdm(range(maxlen)):
        if use_penalties:
            s = sample.batch_sampling_repetition_temperature_topk_topp(
                logits, penalty_state, rand_states, presence, repetition, decay, temp, topk, topp)
        else:
            s = sample.batch_sampling_temperature_topk_topp(logits, rand_states, temp, topk, topp)
        m[:, i] = s
        # No forward pass after the final token: its logits would be unused.
        # Skipping it here, rather than returning mid-loop, lets the progress
        # bar run to completion.
        if i < last_step:
            logits = (forward(s.reshape(bsz, 1), state)).to(torch.float32)
    return m
    

In [6]:

def post(ans: bytes):
    """Turn raw generated bytes into the final answer text.

    The bytes are decoded as UTF-8 with undecodable bytes dropped. The text is
    then cut at the first blank line or ``<|endoftext|>`` marker, whichever
    comes first. If a ``</think>`` tag remains, everything up to and including
    its first occurrence is discarded. The result is stripped of surrounding
    whitespace.
    """
    text = ans.decode('utf-8', errors='ignore')

    # Earliest stop marker wins; find() reports a missing marker as -1.
    stops = [pos for pos in (text.find('\n\n'), text.find("<|endoftext|>")) if pos >= 0]
    if stops:
        text = text[:min(stops)]

    closing_tag = '</think>'
    tag_at = text.find(closing_tag)
    if tag_at >= 0:
        text = text[tag_at + len(closing_tag):]
    return text.strip()



In [7]:
def run_questions(q: list, bsz:int=20, maxlen: int=2048):
    """Generate one post-processed answer per prompt, processing in batches.

    Prompts are ordered by descending ``(character length, text, index)`` so
    that each batch holds prompts of similar length, then handled ``bsz`` at a
    time: tokenize, run the prompt through the model from a zero state, sample
    ``maxlen`` tokens, decode, and clean up with ``post``.

    Args:
        q: Prompt strings.
        bsz: Maximum number of prompts per batch.
        maxlen: Number of tokens sampled for every prompt.

    Returns:
        A list of answer strings aligned with ``q`` (answer ``k`` belongs to
        ``q[k]``).
    """
    def settle():
        # Two rounds of device synchronisation followed by a Python GC pass,
        # issued between the GPU-heavy stages.
        for _ in range(2):
            torch.cuda.synchronize()
            gc.collect()

    total = len(q)
    ordered = sorted(((len(q[k]), q[k], k) for k in range(total)), reverse=True)
    results = [None] * total

    for start in range(0, total, bsz):
        _, batch_prompts, batch_ids = zip(*ordered[start:start + bsz])
        tokens, token_lens = tokenize_questions(batch_prompts)
        state = model.generate_zero_state(len(batch_prompts))
        settle()
        logits = model.forward_seq_batch_right(tokens, state, token_lens)
        settle()
        generated = samples(model.forward_seq_batch_1, state, logits, maxlen=maxlen)
        settle()
        # Scatter answers back to the positions the prompts had in `q`.
        for row, original_idx in enumerate(batch_ids):
            results[original_idx] = post(tokenizer.decodeBytes(generated[row].tolist()))
    return results

In [8]:
def chat_format(s):
    """Wrap a question in the single-turn ``User: ... Assistant:`` prompt template."""
    prompt = f"User: {s}\n\nAssistant:"
    return prompt

In [None]:
import json
# Build one chat prompt per record of the JSON Lines dataset (one JSON object
# per line, each with a "question" field). Whitespace-only lines are skipped
# so a stray blank line does not make json.loads raise.
with open("../AlignBench/data/data_newest_release.jsonl", "r", encoding="utf8") as f:
    q = [chat_format(json.loads(line)["question"]) for line in f if line.strip()]

In [10]:
# Answer every prompt, sampling 6144 tokens per prompt with the default batch
# size; one progress bar is printed per batch.
ans = run_questions(q, maxlen=6144)

100%|█████████▉| 6143/6144 [02:41<00:00, 38.09it/s]
100%|█████████▉| 6143/6144 [02:38<00:00, 38.77it/s]


In [11]:
import pandas as pd
# Persist the answers, one per row in the same order as `q`; neither a header
# row nor an index column is written.
df = pd.DataFrame(ans)
df.to_csv("output_13b.csv", index=False, header=False)