In [None]:
!pip install lm-eval

In [2]:
from prwkv.rwkvtokenizer import RWKVTokenizer
from prwkv.rwkvrnnmodel import RWKVRNN4NeoForCausalLM

In [None]:
from lm_eval.models.gpt2 import GPT2LM
from lm_eval import tasks, evaluator
import math
# os.environ["CUDA_VISIBLE_DEVICES"] = '7' # CHANGE ME!
import torch
from torch.nn import functional as F

# Token ids prepended to every request's token ids when TEST_MODEL == 'rwkv'.
# Alternatives kept for experimentation:
#   [187]      -> '\n'
#   [187, 187] -> '\n\n'
RWKV_PAD = [0]  # <|endoftext|>

# NOTE(review): RUN_TABLE and RUN_MODEL_NAME are not referenced by the visible
# code; they appear to be leftovers from an earlier checkpoint sweep.
RUN_TABLE = [1652]  # part of model file name
RUN_MODEL_NAME = '/mnt/ssd-1/BlinkDL_dont_delete/B/TRAIN_100M/out/all-'

# lm-eval task names evaluated by LMEvaluationRunner.main.
# Single-task alternatives: ['hellaswag'], ['piqa'].
eval_tasks = [
    'lambada',
    'hellaswag',
    'piqa',
]

# 'rwkv' makes EvalHarnessAdapter prepend RWKV_PAD to each request's token ids.
TEST_MODEL = 'rwkv'

USE_CUDA = True
if USE_CUDA:
    RUN_DEVICE = 'cuda'
else:
    RUN_DEVICE = 'cpu'
# NOTE(review): RUN_DEVICE is not referenced elsewhere in the visible code and
# no device argument is passed when the model is loaded; an older comment said
# "Set RUN_DEVICE in src/model.py too" -- confirm whether it still has any effect.

# NOTE(review): RWKV_SLOW_MODE is not referenced in the visible code.
RWKV_SLOW_MODE = True

from tqdm import tqdm
import torch
import torch.nn.functional as F

class EvalHarnessAdapter(GPT2LM):
    """Expose an RWKV RNN model to the lm-eval harness via the GPT2LM interface.

    Only log-likelihood scoring is supported; ``greedy_until`` raises
    ``NotImplementedError``.

    ``GPT2LM.__init__`` is intentionally not called: this adapter sets up only
    the attributes its own methods use.
    """

    def __init__(self, tokenizer=None, rwkv_rnn=None, rwkv_gpt=None):
        self.tokenizer = tokenizer
        # Per-sequence caches keyed by str(token ids), so identical token
        # sequences are only run through the model once.
        self.logitBuf = {}
        self.correctBuf = {}
        self.rwkv_rnn = rwkv_rnn
        # NOTE(review): stored for API compatibility; no scoring path uses it.
        self.rwkv_gpt = rwkv_gpt

    def greedy_until(self, requests):
        raise NotImplementedError()

    def _score_tokens(self, src, q_len):
        """Score every token of ``src`` at index >= ``q_len`` with the RNN.

        Args:
            src: full token id list (padding + context + continuation).
            q_len: number of leading tokens that are context only (not scored).

        Returns:
            ``(log_likelihood, correct)``: the summed log-probability of the
            scored tokens, and whether each of them was the model's top-1
            prediction.

        Raises:
            NotImplementedError: if no RNN model was supplied.
        """
        if self.rwkv_rnn is None:
            # Previously this silently scored every request as (0, True).
            raise NotImplementedError(
                "EvalHarnessAdapter requires rwkv_rnn; scoring with rwkv_gpt is not implemented"
            )
        rnn = self.rwkv_rnn
        logit = 0.0
        correct = True
        with torch.no_grad():
            # Called once per sequence, presumably to reset recurrent state.
            rnn.clear()
            for i in range(1, len(src)):
                # The model is given the prefix src[:i]; its output is
                # compared against src[i] as the next-token prediction.
                out = rnn.run(src[:i])
                if i < q_len:
                    continue  # still inside the context; nothing to score
                oo = torch.as_tensor(out)
                if oo.argmax().item() != src[i]:
                    correct = False
                # log_softmax avoids log(0) (ValueError) when the target's
                # probability underflows after a plain softmax; .float()
                # upcasts half-precision outputs.
                logit += F.log_softmax(oo.float(), dim=-1)[src[i]].item()
        return logit, correct

    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
        """Return one ``(log_likelihood, is_greedy)`` tuple per request.

        Args:
            requests: sequence of items where ``item[1]`` is the context token
                id list and ``item[2]`` the continuation token id list
                (``item[0]``, a text pair, is not used here).
            disable_tqdm: accepted for interface compatibility; unused.
        """
        res = []
        n_requests = len(requests)
        for n, request in enumerate(requests):
            context_tokens, continuation_tokens = request[1], request[2]
            src = context_tokens + continuation_tokens
            q_len = len(context_tokens)
            if TEST_MODEL == 'rwkv':
                # Padding tokens count as context and are never scored.
                src = RWKV_PAD + src
                q_len += len(RWKV_PAD)

            key = str(src)
            if key in self.logitBuf:
                logit = self.logitBuf[key]
                correct = self.correctBuf[key]
            else:
                logit, correct = self._score_tokens(src, q_len)
                self.logitBuf[key] = logit
                self.correctBuf[key] = correct

            res.append((logit, correct))

            if n % 100 == 0:
                print(f'{n//100}/{n_requests//100}', end = ' ', flush=True)
        return res

    @torch.no_grad()
    def run_eval(self, eval_tasks=None, num_fewshot=0, bootstrap_iters=2):
        """Evaluate this adapter on ``eval_tasks`` with lm-eval's evaluator.

        Args:
            eval_tasks: list of lm-eval task names.
            num_fewshot: number of few-shot examples per prompt.
            bootstrap_iters: bootstrap iterations for standard-error estimates.

        Returns:
            The results object returned by ``evaluator.evaluate``.

        Raises:
            ValueError: if ``eval_tasks`` is None.
        """
        if eval_tasks is None:
            raise ValueError("eval_tasks must be a list of lm-eval task names")
        return evaluator.evaluate(
            lm=self,
            task_dict=tasks.get_task_dict(eval_tasks),
            provide_description=False,
            num_fewshot=num_fewshot,
            limit=None,
            bootstrap_iters=bootstrap_iters,
        )

class LMEvaluationRunner:
    """Load an RWKV-4 checkpoint and run the lm-eval harness tasks on it."""

    # Checkpoint previously hard-coded in main(); kept as the default.
    DEFAULT_MODEL_PATH = "/Users/michaelchung/Code/Production-RWKV/RWKV-4-Pile-430M-20220808-8066"

    @staticmethod
    def main(*, model_path=DEFAULT_MODEL_PATH, number_of_layers=24,
             embedding_dimension=1024, context_length=1024, tasks_to_run=None,
             bootstrap_iters=10000):
        """Run the evaluation and print and return the results.

        Declared ``@staticmethod`` so both ``LMEvaluationRunner.main()`` and
        ``LMEvaluationRunner().main()`` work (the original had no ``self``).
        The default model dimensions are the values previously hard-coded for
        ``DEFAULT_MODEL_PATH``; change them together with ``model_path``.

        Args:
            model_path: checkpoint passed to ``RWKVRNN4NeoForCausalLM.from_pretrained``.
            number_of_layers: passed through to ``from_pretrained``.
            embedding_dimension: passed through to ``from_pretrained``.
            context_length: passed through to ``from_pretrained``.
            tasks_to_run: lm-eval task names; defaults to the module-level
                ``eval_tasks`` list.
            bootstrap_iters: bootstrap iterations for standard errors.

        Returns:
            The results returned by ``EvalHarnessAdapter.run_eval``.
        """
        tokenizer = RWKVTokenizer.default()
        model = RWKVRNN4NeoForCausalLM.from_pretrained(
            model_path,
            number_of_layers=number_of_layers,
            embedding_dimension=embedding_dimension,
            context_length=context_length,
        )

        print("Running evaluation harness...")
        adapter = EvalHarnessAdapter(tokenizer=tokenizer, rwkv_rnn=model)
        results = adapter.run_eval(
            eval_tasks=eval_tasks if tasks_to_run is None else tasks_to_run,
            bootstrap_iters=bootstrap_iters,
        )
        print(results)
        return results