In [3]:
from transformers import AutoTokenizer, LlamaForCausalLM
import transformers
import torch
from torch import bfloat16

from tqdm import tqdm

In [4]:
model_id = "daryl149/llama-2-7b-chat-hf"

model = LlamaForCausalLM.from_pretrained(
                            model_id, 
                            torch_dtype=torch.float16, 
                            device_map='auto'
                        )

tokenizer = AutoTokenizer.from_pretrained(model_id)

In [2]:
import guidance

g_model = guidance.llms.Transformers(
            model, tokenizer=tokenizer, trust_remote_code=True,
)
guidance.llm = g_model
guidance.llms.Transformers.cache.clear()

In [5]:
from glob import glob
import json
import os

def load_json_data(data_dir):
    """
    Load multiple JSON files from the folder and merge.
    """

    files = glob(data_dir+"/*.json")
    files.sort()
    all_data = []
    for file_path in files:
        #print("Loading: ",file)
        #file_path = os.path.join(data_dir, file)
        with open(file_path, "r", encoding = "utf-8-sig") as f:
            doc = json.load(f)
        all_data.append(doc)
        #all_data += doc
    return all_data

docs = load_json_data('/Project/wikipedia/')
document = [x["title"]+":"+x["content"] for x in docs]

In [6]:
prompt = '''

{{passage}}


Assess whether given passage is related to football or not.
It is important to note that the passage may be related to football but not mention the word "football" explicitly.

{{gen 'answer' pattern='(Yes|No)' stop_regex='\\n'}}

Generate a reason why you believe given passage contains relevant information about football.

{{gen 'reason'}}
'''

In [9]:
extracted_passage = []

for i in tqdm(range(len(docs))):

    try:
        executed_program = guidance(
            prompt,
            passage = document[i],
            silent=True)

        res = executed_program()

        # json dump  
        new_doc = {
                    'id': i,
                    'title': docs[i]["title"],
                    'contents' : docs[i]['content'],
                    'url' : docs[i]['url'],
                    'label' : res['answer'],
                    'reason' : res['reason']
                    }

    except:
        # json dump  
        new_doc = {
                    'id': i,
                    'title': docs[i]["title"],
                    'contents' : docs[i]['content'],
                    'url': docs[i]['url'],
                    'label': 'Error',
                    'reason': 'Error'
                    }

    extracted_passage.append(new_doc)

  0%|          | 0/9478 [00:00<?, ?it/s]Input length of input_ids is 28675, but `max_length` is set to 2048. This can lead to unexpected behavior. You should consider increasing `max_new_tokens`.
Exception in thread Thread-9:
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/threading.py", line 932, in _bootstrap_inner
    self.run()
  File "/opt/conda/lib/python3.8/threading.py", line 870, in run
    self._target(*self._args, **self._kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/transformers/generation/utils.py", line 1437, in generate
    return self.greedy_search(
  File "/opt/conda/lib/python3.8/site-packages/transformers/generation/utils.py", line 2248, in greedy_search
    outputs = self(
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*inpu

In [None]:


# save the output to json file
with open("/Project/extracted_wikipedia/wiki.json".format(str(i)), "w", encoding='utf-8') as f:
    f.write(json.dumps(extracted_passage,
                ensure_ascii=False, indent='\t'))