In [None]:
# Keyword extraction with a local quantized Mistral-7B-Instruct model via KeyBERT's KeyLLM.
# Relies on AutoModelForCausalLM, AutoTokenizer, pipeline, TextGeneration, KeyLLM and
# SentenceTransformer being imported in an earlier cell.

# Set gpu_layers to the number of layers to offload to GPU. Set to 0 if no GPU acceleration is available on the system.
model = AutoModelForCausalLM.from_pretrained(
    "TheBloke/Mistral-7B-Instruct-v0.1-GGUF",
    model_file="mistral-7b-instruct-v0.1.Q4_K_M.gguf",
    model_type="mistral",
    gpu_layers=0,
    hf=True,  # presumably exposes a transformers-compatible model so it can be passed to `pipeline` below
)

# Tokenizer matching the base (non-quantized) Mistral instruct model.
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")

# Text-generation pipeline; short outputs are enough for a comma-separated keyword list.
generator = pipeline(
    model=model,
    tokenizer=tokenizer,
    task="text-generation",
    max_new_tokens=50,
    repetition_penalty=1.1,
)

# One-shot example: a complete [INST] ... [/INST] answer pair showing the expected output format.
example_prompt = """
<s>[INST]
I have the following document:
- During the construction work, a worker fell from the scaffolding and sustained a head injury and broke his arm.

Please give me only top 5 keywords that are most relevant to the injury in this document and separate them with commas.
Make sure you to only return 5 keywords and say nothing else. For example, don't say:
"Here are the keywords present in the document"
[/INST] Scaffolding, head injury, fall, construction, breaking arm </s>"""

# Per-document template; `[DOCUMENT]` is a placeholder presumably substituted by
# KeyBERT's TextGeneration with each input document.
keyword_prompt = """
[INST]

I have the following document:
- [DOCUMENT]
Please give me only top 5 keywords that are most relevant to the injury. Please separate them with commas.
Make sure you to only return 5 keywords and say nothing else. For example, don't say:
"Here are the keywords present in the document"
[/INST]
"""

# Full prompt = one-shot example followed by the per-document template.
# (The original cell referenced an undefined `prompt`, which raised NameError.)
prompt = example_prompt + keyword_prompt


documents = [
    "While operating heavy machinery in the factory, an employee's hand got caught between gears, resulting in severe lacerations. The safety guard was reportedly not in place at the time of the accident",
    "During a routine office relocation, an employee slipped on a wet floor that was recently cleaned but not marked with warning signs. The fall caused a fractured wrist and minor bruises.",
    "In a warehouse setting, an employee was struck by a falling box from a high shelf. The box contained heavy equipment, leading to a concussion and shoulder dislocation for the employee.",
    "A construction worker on a high-rise building site fell from a ladder due to its unstable placement. The fall resulted in a broken leg and several rib fractures. It was later discovered that the ladder was not properly secured.",
    "A factory worker in a large warehouse was retrieving items from a high shelf using a forklift. Due to improper positioning of the forklift, the worker fell from a significant height. The impact caused a broken leg and multiple rib fractures. An investigation revealed that the forklift had not been correctly aligned with the shelf",
]


# Wrap the generation pipeline as a KeyBERT LLM backend using the combined prompt.
llm = TextGeneration(generator, prompt=prompt)

# Extract embeddings.
# NOTE: this rebinds `model` (previously the Mistral LLM) to the embedding model; the
# name is kept in case later cells rely on it. The LLM stays reachable via `generator`.
model = SentenceTransformer('BAAI/bge-small-en-v1.5')
embeddings = model.encode(documents, convert_to_tensor=True)

# Load it in KeyLLM (built once; the original constructed it twice).
kw_model = KeyLLM(llm)

# Extract keywords. The embeddings/threshold pair is presumably used by KeyLLM to
# group highly similar documents so they share one LLM call; confirm against KeyBERT docs.
keywords = kw_model.extract_keywords(
    documents,
    embeddings=embeddings,
    threshold=0.85,
)

keywords