In [None]:
import torch
import transformers
from torch import cuda, bfloat16


# Use the current CUDA device when one is available; otherwise run on the CPU.
if cuda.is_available():
    device = f'cuda:{cuda.current_device()}'
else:
    device = 'cpu'
device

In [None]:
# Fine-tuned Llama 3 8B Instruct checkpoint on the Hugging Face Hub.
model_name = 'Kastanie99/Meta-Llama-3-8B-Instruct-Haoran-MT-07052024'

# Tokenizer: the EOS token doubles as the padding token; pad on the right.
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name, use_fast=False)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Model: switch to inference mode, then move it to the selected device.
# NOTE(review): no `torch_dtype` is passed, so weights presumably load in the default
# float32; `bfloat16` is imported in the first cell but unused. Consider
# `torch_dtype=bfloat16` if GPU memory is tight (confirm outputs stay acceptable).
model = transformers.AutoModelForCausalLM.from_pretrained(model_name)
model.eval()
model.to(device)

print(f"Model loaded on {device}")

In [None]:
def _tokens_without_bos(text):
    """Tokenize `text` and return its tokens without the first one.

    The dropped first token is presumably the BOS token that the tokenizer
    prepends automatically. TODO confirm for this checkpoint.
    """
    encoded_ids = tokenizer(text)["input_ids"]
    return tokenizer.convert_ids_to_tokens(encoded_ids)[1:]


___inst, ___start_of_, ___eot, ___end_of = (
    _tokens_without_bos(marker)
    for marker in ("<|begin_of_text|>", "<|start_header_id|>", "<|eot_id|>", "<|end_of_text|>")
)

# Each entry is one stop sequence (a list of tokens). StopOnTokens halts generation
# once the output ends with any of them.
_stop_token_sequences = [___inst, ___start_of_, [tokenizer.eos_token], ___end_of, ___eot, ['```']]

stop_token_ids = [
    torch.LongTensor(tokenizer.convert_tokens_to_ids(tokens)).to(device)
    for tokens in _stop_token_sequences
]

In [None]:
from transformers import StoppingCriteria, StoppingCriteriaList


class StopOnTokens(StoppingCriteria):
    """Stop generation as soon as the sequence ends with any configured stop sequence.

    Args:
        stop_sequences: optional iterable of 1-D LongTensors of token ids. When omitted,
            the notebook-level ``stop_token_ids`` list is read at call time. This matches
            what the original zero-argument ``StopOnTokens()`` did.
    """

    def __init__(self, stop_sequences=None):
        super().__init__()
        self.stop_sequences = stop_sequences

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Only the first row of the batch is inspected. The pipeline in this notebook
        # is called with a single prompt at a time.
        sequence = input_ids[0]
        candidates = stop_token_ids if self.stop_sequences is None else self.stop_sequences
        for stop_ids in candidates:
            length = len(stop_ids)
            # Skip empty stop sequences and ones longer than the current input. Otherwise
            # `sequence[-length:]` is broadcast against `stop_ids`, which can raise or
            # report a spurious match (an empty comparison is vacuously `.all()`).
            if length == 0 or length > sequence.shape[-1]:
                continue
            if torch.eq(sequence[-length:], stop_ids).all():
                return True
        return False


stopping_criteria = StoppingCriteriaList([StopOnTokens()])

In [None]:
# stopping_criteria(
#     torch.LongTensor([tokenizer.convert_tokens_to_ids(_)]).to(device),
#     torch.FloatTensor([0.0])
# )

In [None]:
# Decoding settings forwarded to `generate` by the text-generation pipeline.
# NOTE(review): `do_sample` is not set here. Whether temperature/top_p/top_k take effect
# presumably depends on the checkpoint's generation_config (confirm). If sampling is
# enabled, seed it so the saved predictions are reproducible.
GENERATION_KWARGS = dict(
    stopping_criteria=stopping_criteria,
    temperature=0.1,
    top_p=0.15,
    top_k=0,
    max_new_tokens=512,
    repetition_penalty=1.3,
)

pipe = transformers.pipeline(
    task='text-generation',
    model=model,
    tokenizer=tokenizer,
    device=device,
    return_full_text=True,  # Set it to True when combining with LangChain
    **GENERATION_KWARGS,
)

In [None]:
import datasets

# Evaluation split stored on disk in Hugging Face `datasets` format (absolute cluster path).
EVALUATION_SET_PATH = "/pfs/data5/home/st/st_us-051500/st_st180358/llama3_training/my_llama3_combined_dataset_test_07052024"
evaluation_set = datasets.load_from_disk(EVALUATION_SET_PATH)
evaluation_set

In [None]:
# def extract_fields(example):

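#     query_start = example['text'].find('<s>[INST]') + 10  # index of '<s>[INST]', advanced past the 9-char marker and one more character
#     query_end = example['text'].find('[/INST]')  # index where the '[/INST]' marker begins
#     query_part = example['text'][query_start:query_end]  # slice out the text between the two indices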

#     completion_start = example['text'].find('[/INST]') + 8
#     completion_part = example['text'][completion_start:]

#     input_end = example['text'].find('[/INST]')
#     input_for_llama = example['text'][:input_end] + '[/INST]\n'
    
#     # Split the required content out of the text
#     # parts = example['text'].split('"\n\n')
#     # query_part = parts[0].split('"')[-1].strip()
#     # completion_part = parts[1].split('Rewritten requirement:\n"')[-1].strip().rstrip('"')
    
#     # Return the new fields
#     return {'query': query_part, 'completion': completion_part, 'input_for_llama': input_for_llama}

# # Apply the map function
# updated_dataset = evaluation_set.map(extract_fields)

In [None]:
# Inspect one evaluation record before running generation.
# The previous version read `updated_dataset`, which is only created further down
# (after prediction) or by the commented-out `extract_fields` cell. It therefore raised
# NameError under Restart & Run All, and it also used `input_for_llama` although the
# generation loop reads `input_for_llama3`. Printing every column of the loaded
# evaluation set avoids guessing the exact schema.
SEPARATOR = "``````````````````````"
sample_record = evaluation_set[1]
for column_name, column_value in sample_record.items():
    print(f"[{column_name}]")
    print(column_value)
    print(SEPARATOR)

In [None]:
from tqdm import tqdm

# Markers the fine-tuned model may emit around its answer. They are removed as exact
# substrings. The previous version passed them to `str.strip`, which treats its argument
# as a *set of characters*, so it also deleted legitimate leading/trailing letters
# (e.g. `.strip("/end_of_")` turns "... shall send" into "... shall s").
# Longer markers are listed before their prefixes ("```java" before "```").
COMPLETION_ARTIFACTS = ("[/INST]", "```java", "/end_of_", "[/user]", "[/Inst", "```", "#", "*")


def clean_completion(generated_text, artifacts=COMPLETION_ARTIFACTS):
    """Extract the assistant answer from a pipeline ``generated_text`` and tidy it.

    Keeps only the text after the last ``<|end_header_id|>``, or the whole string if
    the marker is absent. With ``return_full_text=True`` the prompt, which presumably
    ends with the assistant header, is echoed back before the answer. The function then
    repeatedly strips surrounding whitespace and any exact marker in `artifacts` from
    both ends until nothing more can be removed.

    Args:
        generated_text: the ``generated_text`` string returned by the pipeline.
        artifacts: exact substrings to peel off either end of the answer.

    Returns:
        The cleaned answer string (possibly empty).
    """
    markers = [marker for marker in artifacts if marker]  # an empty marker would never shrink the text
    text = generated_text.split('<|end_header_id|>')[-1]
    while True:
        before = text
        text = text.strip()
        for marker in markers:
            if text.startswith(marker):
                text = text[len(marker):]
            if text.endswith(marker):
                text = text[:-len(marker)]
        # Every change shortens `text`, so this reaches a fixed point.
        if text == before:
            return text


completion = []

for input_for_llama3 in tqdm(evaluation_set['input_for_llama3']):
    result = pipe(f"{input_for_llama3}")
    completion.append(clean_completion(result[0]['generated_text']))

In [None]:
# Attach one cleaned prediction per evaluation example as a new column.
PREDICTION_COLUMN = "llama3_8B_Instruct_preds"
updated_dataset = evaluation_set.add_column(PREDICTION_COLUMN, completion)

In [None]:
# Persist the evaluation set together with its prediction column.
PREDICTIONS_OUTPUT_PATH = "/pfs/data5/home/st/st_us-051500/st_st180358/llama3_training/my_llama3_after_prediction_08052024"
updated_dataset.save_to_disk(PREDICTIONS_OUTPUT_PATH)

In [None]:
# Smoke test: run the pipeline on one hand-written prompt for a single requirement.
# NOTE(review): the literal starts with a newline before `<|begin_of_text|>`, and the
# tokenizer may prepend its own BOS as well (the stop-token cell presumably drops that
# BOS with `[1:]`). Confirm this matches the prompt format used for fine-tuning.
result = pipe('''
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
 You are a professional requirements engineer who helps users detect Transformational Effects of the given requirement and rewrite to eliminate them. 
 <|eot_id|><|start_header_id|>user<|end_header_id|> 
 Detect and rewrite the given requirement to eliminate Transformational Effects:
"NPAC SMS shall notify the Old and New Service Provider when a Subscription Version is set to conflict at the time of Subscription Version creation for an Inter-Service Provider or port." 
<|eot_id|><|start_header_id|>assistant<|end_header_id|>
''')

In [None]:
# Show the full text of the first generation (prompt + answer, since the pipeline
# was built with return_full_text=True).
generated_text = result[0]['generated_text']
print(generated_text)