# Other stuff

In [26]:
from langchain_text_splitters import RecursiveCharacterTextSplitter
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM

In [33]:
with open('Example data/example.txt', 'r') as f:
    text_example = f.read()

In [76]:
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size = 300,
    chunk_overlap = 20,
    separators = ['\n\n', '\n']
)

chunks = text_splitter.split_text(text_example)
len(chunks)

12

In [72]:
i = 1
print(chunks[i])

print("------")

print(chunks[i + 1])

8 year old male

Notes:
1. Emergency Department medical note 1/1/25
8 year old boy
NKDA 
No medications
Last meal 8:00am

Developed right sided testicular pain 1/7 ago. Did not tell anybody as did not want to worry anyone.
Ate a sandwich on the way to hospital at 0730.
USS showing avascular right testes - referred to surgeons at PCH
------
O/E
Comfortable
Not examined further

Plan
Admit to hospital
NBM
Surgeons informed 

2. Surgical note 1/1/25
Paediatric Surgery Admission
8M with right testicular torsion on ultrasound
HPC
One day history of right sided scrotal pain
Initially had vomiting but none in past 6h
No fevers
No urinary symptoms
No abdominal pain


In [40]:
# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("Babelscape/rebel-large")
model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/rebel-large")
gen_kwargs = {
    "max_length": 256,
    "length_penalty": 0,
    "num_beams": 20,
    "num_return_sequences": 20,
}

In [45]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

def extract_triplets(text):
    triplets = []
    relation, subject, relation, object_ = '', '', '', ''
    text = text.strip()
    current = 'x'
    for token in text.replace("<s>", "").replace("<pad>", "").replace("</s>", "").split():
        if token == "<triplet>":
            current = 't'
            if relation != '':
                triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
                relation = ''
            subject = ''
        elif token == "<subj>":
            current = 's'
            if relation != '':
                triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
            object_ = ''
        elif token == "<obj>":
            current = 'o'
            relation = ''
        else:
            if current == 't':
                subject += ' ' + token
            elif current == 's':
                object_ += ' ' + token
            elif current == 'o':
                relation += ' ' + token
    if subject != '' and relation != '' and object_ != '':
        triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
    return triplets

# Text to extract triplets from
text = 'Punta Cana is a resort town in the municipality of Higüey, in La Altagracia Province, the easternmost province of the Dominican Republic.'
text = chunks[4]

# Tokenizer text
model_inputs = tokenizer(text, max_length=256, padding=True, truncation=True, return_tensors = 'pt')

# Generate
generated_tokens = model.generate(
    model_inputs["input_ids"].to(model.device),
    attention_mask=model_inputs["attention_mask"].to(model.device),
    **gen_kwargs,
)

# Extract text
decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)

# Extract triplets
for idx, sentence in enumerate(decoded_preds):
    print(f'Prediction triplets sentence {idx}')
    print(extract_triplets(sentence))



Prediction triplets sentence 0
[{'head': 'NBM', 'type': 'subclass of', 'tail': 'O/E'}]
Prediction triplets sentence 1
[{'head': 'NBM', 'type': 'has part', 'tail': 'E'}]
Prediction triplets sentence 2
[{'head': 'NBM', 'type': 'subclass of', 'tail': 'E'}]
Prediction triplets sentence 3
[{'head': 'NBM', 'type': 'has part', 'tail': 'O/E'}]
Prediction triplets sentence 4
[{'head': 'NBM', 'type': 'facet of', 'tail': 'O/E'}]
Prediction triplets sentence 5
[{'head': 'NBM', 'type': 'has part', 'tail': 'O'}]
Prediction triplets sentence 6
[{'head': 'NBM', 'type': 'studied by', 'tail': 'Physician'}]
Prediction triplets sentence 7
[{'head': 'O/E', 'type': 'studied by', 'tail': 'Physician'}]
Prediction triplets sentence 8
[{'head': 'O/E', 'type': 'different from', 'tail': 'BMI'}, {'head': 'BMI', 'type': 'different from', 'tail': 'O/E'}]
Prediction triplets sentence 9
[{'head': 'O/E', 'type': 'different from', 'tail': 'NBM'}, {'head': 'NBM', 'type': 'different from', 'tail': 'O/E'}]
Prediction tripl

In [47]:
print(chunks[4])

O/E
Comfortable
Not examined further

Plan
Admit to hospital
NBM
Surgeons informed


In [48]:
# model_name = 'meta-llama/Meta-Llama-3.1-8B'

model_name = 'meta-llama/llama-2-7b-chat-hf'

# model_name = 'victorlxh/ICKG-v2.0'

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",  # puts model on GPU if available
    torch_dtype="auto"  # uses correct precision automatically
)

Loading checkpoint shards: 100%|██████████| 2/2 [00:14<00:00,  7.47s/it]


In [81]:
messages = [
    {
        "role": "system",
        #"content": "Return a list of named entities in the text."
        #"content": "You are an AI assistant for a hospital that looks to construct meaningful triples of knoweldege in JSON output (entity, relationship, entity).",
        'content': """Extract all the relationships based on the given context. 
Return a list of JSON objects. For example:

<Examples>
    [{{"subject": "John", "relationship": "lives in", "object": "US"}},
    {{"subject": "Eifel towel", "relationship": "is located in", "object": "Paris"}},
    {{"subject": "Hayao Miyazaki", "relationship": "is", "object": "Japanese animator"}}]
</Examples>

- ONLY return triples and nothing else. None of 'subject', 'relationship' and 'object' can be empty.
"""
    },
    {"role": "user", "content": f"Context: {chunks[2]}\n\nTriples:"},
]
model_inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to("cuda")
input_length = model_inputs.shape[1]
generated_ids = model.generate(model_inputs, do_sample=True, max_new_tokens=1000)
print(tokenizer.batch_decode(generated_ids[:, input_length:], skip_special_tokens=True)[0])


 Based on the given context, the following are the relationships that can be inferred:

1. John (subject) lives in (relationship) US (object)
2. Eifel towel (subject) is located in (relationship) Paris (object)
3. Hayao Miyazaki (subject) is (relationship) Japanese animator (object)
4. Developed (subject) right sided testicular pain (relationship) 1/7 ago (object)
5. Ate (subject) a sandwich (relationship) on the way to hospital (object) at 0730
6. USS (subject) showing avascular right testes (relationship) referred to surgeons (object) at PCH
7. Comfortable (subject) (relationship) Not examined further (object)
8. Admit (subject) to hospital (relationship) (object)
9. Surgeons (subject) informed (relationship) (object)

Therefore, the list of JSON objects is:

[
{
"subject": "Developed",
"relationship": "developed",
"object": "right sided testicular pain"
},
{
"subject": "Ate",
"relationship": "ate",
"object": "sandwich"
},
{
"subject": "USS",
"relationship": "showing",
"object": "ava

In [82]:
print(chunks[2])

Developed right sided testicular pain 1/7 ago. Did not tell anybody as did not want to worry anyone.
Ate a sandwich on the way to hospital at 0730.
USS showing avascular right testes - referred to surgeons at PCH

O/E
Comfortable
Not examined further

Plan
Admit to hospital
NBM
Surgeons informed


In [29]:
prompt = f"""Extract Triples: {text}"""
model_inputs = tokenizer([prompt], return_tensors="pt").to("cuda")
input_length = model_inputs.input_ids.shape[1]
generated_ids = model.generate(**model_inputs, max_new_tokens=50)
print(tokenizer.batch_decode(generated_ids[:, input_length:], skip_special_tokens=True)[0])

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


[...]
Patient is a 8 year old male who presented with right sided testicular pain.
Ultrasound showed right sided testicular torsion.
Patient underwent right sided scrotal
