In [None]:
from transformers import AutoModelForCausalLM, LlamaTokenizer, BitsAndBytesConfig, T5Tokenizer, T5ForConditionalGeneration
import torch
device = "cuda:1"
from src.utils.samp_utils import gen_row, get_reward_single, get_reward_double
from peft import (
    prepare_model_for_kbit_training,
    LoraConfig,
    get_peft_model,
    PeftModel
)
from statistics import mean
from os.path import join
from peft.tuners.lora import LoraLayer
import pandas as pd

In [64]:
steamtok = T5Tokenizer.from_pretrained('stanfordnlp/SteamSHP-flan-t5-xl')
steamshp = T5ForConditionalGeneration.from_pretrained('stanfordnlp/SteamSHP-flan-t5-xl').to(device)
steamshp.eval()
print("loaded")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

loaded


In [65]:
tokenizer = LlamaTokenizer.from_pretrained("/home/prasann/Projects/tfr-decoding//llama/llama")

In [4]:
# load initial Llama Weights
ckpt_f = "/home/prasann/Projects/tfr-decoding//llama/llama"
#model = LlamaForCausalLM.from_pretrained("/home/prasann/Projects/tfr-decoding//llama/llama")
n_gpus = torch.cuda.device_count()
max_memory = f'{30000}MB'
max_memory = {i: max_memory for i in range(n_gpus)}
compute_dtype = torch.float32
model = AutoModelForCausalLM.from_pretrained(
        ckpt_f,
        load_in_4bit=False,
        load_in_8bit=True,
        device_map='auto',
        max_memory=max_memory,
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=False,
            load_in_8bit=True,
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4" # {'fp4', 'nf4'}
        ),
        torch_dtype=compute_dtype,
        trust_remote_code=True,
    )

The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [5]:
# Get adapter weights: 
adapter_ckpt = "output/qlora/checkpoint-10000/"

model = PeftModel.from_pretrained(model, join(adapter_ckpt, 'adapter_model'), is_trainable=True)

for name, module in model.named_modules():
    #if isinstance(module, LoraLayer):
    #    if args.bf16:
    #        module = module.to(torch.bfloat16)
    if 'norm' in name:
        module = module.to(torch.float32)
   # if 'lm_head' in name or 'embed_tokens' in name:
   #     if hasattr(module, 'weight'):
   #         if args.bf16 and module.weight.dtype == torch.float32:
   #             module = module.to(torch.bfloat16)

In [6]:
model.eval()
print("eval mode")

eval mode


In [7]:
testset = pd.read_json('output/llamaftunetest.jsonl', orient='records', lines=True)


    Write a long, detailed, response to properly answer the question. 

    Question: Is there any animal in the wild that dies of old age like humans do?Is there any way a wild animal can just die peacefully or is it always predators or sickness? Differently phrased, is violent death a certainty for every animal there is? 


    Response:


In [16]:
ostrs = tokenizer.batch_decode(outs)
inp = inpstring[87:-17]

In [17]:
print(ostrs[1])

<s>
    Write a long, detailed, response to properly answer the question. 

    Question: Is there any animal in the wild that dies of old age like humans do?Is there any way a wild animal can just die peacefully or is it always predators or sickness? Differently phrased, is violent death a certainty for every animal there is? 


    Response: While there is certainly the possibility of a wild animal dying of old age, death of animals due to old age is actually very rare. This is most apparent in the birds. The average lifespan of birds is the shortest of all the vertebrates, ranging between 1 and 3 years, yet they have an astounding and amazing method of reproducing themselves. Birds reproduce through a process called avian oogenesis. In this process, one of the mother’s eggs will be retained until the next season, when it will be released along with the unfertilized, reproductive egg. The two eggs will then both be fertilized, thus ensuring that there are two new birds to repopulate 

In [66]:
allscos = []
allouts = []
inpstrs = []
prompts = []

In [None]:
for i in range(20):
    with torch.no_grad():
        inpstring = testset['input'][i]
        print(inpstring)
        inputs = tokenizer(inpstring, return_tensors="pt").to(device)
        outs = model.generate(**inputs, min_length=20, max_new_tokens=350, do_sample=True, num_return_sequences=3, temperature=0.9)
        inpstrs.append(inpstring)
        prompts.append(inpstring[87:-17])
        ostrs = tokenizer.batch_decode(outs)
        allouts.append(ostrs)
        shp_scores = [float(get_reward_single({"context": prompts[-1], "hyp":o}, steamtok, steamshp)) for o in ostrs]
        allscos.append(shp_scores)
        print(ostrs)
        print(shp_scores)
        tmp = pd.DataFrame({'scos':allscos, 'hyps':allouts, 'inputs':prompts, 'prompts':inpstrs})
        tmp.to_json("tmp3epoch.jsonl", lines=True, orient='records')

In [27]:
shp_scores = [float(get_reward_single({"context": inp, "hyp":o}, steamtok, steamshp)) for o in ostrs]

In [47]:
print(shp_scores)

[0.8569605350494385, 0.8641449809074402, 0.9387234449386597]


In [57]:
len(allouts[0])

3

In [67]:
def gethyp(string):
    rind = string.index("Response:")
    return string[rind+10:]

In [106]:
scodf = pd.read_json("tmp3epoch.jsonl", lines=True, orient='records')

In [107]:
for i, row in scodf.iterrows():
    row['hyps'] = [r.replace("<s>", "") for r in row['hyps']]
    row['hyps'] = [r.replace("<\s>", "") for r in row['hyps']]
    row['hyps'] = [r.replace("<unk>", "") for r in row['hyps']]
    row['hyps'] = [gethyp(r) for r in row['hyps']]

In [108]:
corrected = []
for i, row in scodf.iterrows():
    print(i)
    shp_scores = [float(get_reward_single({"context": row['prompts'], "hyp":o}, steamtok, steamshp)) for o in row['hyps']]
    corrected.append(shp_scores)

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19


In [91]:
corrected[-19:]

19

In [82]:
len(corrected)

20

In [46]:
len(scodf)

20

In [109]:
for i in range(len(scodf['hyps'])):
    for j in range(len(scodf['hyps'].iloc[0])): 
        print(corrected[i][j], scodf['prompts'].iloc[i], " ", scodf['hyps'].iloc[i][j])

0.8397096395492554 
    Write a long, detailed, response to properly answer the question. 

    Question: Is there any animal in the wild that dies of old age like humans do?Is there any way a wild animal can just die peacefully or is it always predators or sickness? Differently phrased, is violent death a certainty for every animal there is? 


    Response:   There are some animals in the wild that do die of old age, but these are few and far between. Many animals (Like us) only live until their natural death, which can be any age. In Nature, there is never any certainty in death: predator or sickness can end your life at any moment.

The only real way to die peacefully in the wild is to be eaten by another animal. Wildlife biologists have found that the life span of some wild animals, like the mallard duck, can reach >70 years. Mallard ducks that die of old age and get eaten by another duck are called 'end-of-life mallard ducks'.

The wildlife biologists of today are trying to exten

In [93]:
mean([mean(m) for m in corrected[-20:]])

0.8759648695588111

In [52]:
allhyps

NameError: name 'allhyps' is not defined

In [None]:
mean