In [None]:
!pip install -r requirements.txt

In [None]:
from torch import cuda, bfloat16
import transformers

model_id="epfl-llm/meditron-7b"

device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'

# set quantization configuration to load large model with less GPU memory
# this requires the `bitsandbytes` library
bnb_config = transformers.BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=bfloat16
)

model_config = transformers.AutoConfig.from_pretrained(
    model_id,
    # token=token
)

model = transformers.AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    config=model_config,
    quantization_config=bnb_config,
    device_map='auto',
    # token=token
)

tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_id,
    # token=token
)

# enable evaluation mode to allow model inference
model.eval()

print(f"Model loaded on {device}")

In [7]:
def get_token_len(text: str):
    return len(tokenizer.encode(text))

In [8]:
from transformers import StoppingCriteria, StoppingCriteriaList
import torch

stop_list = ["\n\n", "\n\n\n", "Task:\nBelow"]
# stop_list = ['\nHuman:', '\n```\n']
stop_token_ids = [tokenizer(x, add_special_tokens=False)['input_ids'] for x in stop_list]
stop_token_ids = [torch.LongTensor(x).to(device) for x in stop_token_ids]
print(stop_token_ids)

# define custom stopping criteria object
class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stop_ids in stop_token_ids:
            if torch.eq(input_ids[0][-len(stop_ids):], stop_ids).all():
                return True
        return False

stopping_criteria = StoppingCriteriaList([StopOnTokens()])

[tensor([29871,    13,    13], device='cuda:0'), tensor([29871,    13,    13,    13], device='cuda:0'), tensor([ 9330, 29901,    13, 21140,   340], device='cuda:0')]


In [43]:
from langchain.llms import HuggingFacePipeline

generate_text = transformers.pipeline(
    model=model,
    tokenizer=tokenizer,
    return_full_text=True,  # langchain expects the full text
    task='text-generation',
    # we pass model parameters here too
    stopping_criteria=stopping_criteria,  # without this model rambles during chat
    temperature=0.1,  # 'randomness' of outputs, 0.0 is the min and 1.0 the max
    max_new_tokens=256,  # max number of tokens to generate in the output
    # repetition_penalty=1.1,  # without this output begins repeating
    do_sample=True,
    streamer = transformers.TextStreamer(tokenizer)
)
llm = HuggingFacePipeline(pipeline=generate_text)

In [1]:
from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.prompts.prompt import PromptTemplate
from prompt_examples.single_criterion_examples import examples


In [41]:
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.prompts.example_selector import SemanticSimilarityExampleSelector
from langchain_community.vectorstores import Chroma

example_selector = SemanticSimilarityExampleSelector.from_examples(
    # This is the list of examples available to select from.
    examples,
    # This is the embedding class used to produce embeddings which are used to measure semantic similarity.
    HuggingFaceEmbeddings(),
    # This is the VectorStore class that is used to store the embeddings and do a similarity search over.
    Chroma,
    # This is the number of examples to produce.
    k=3,
)

example_prompt = PromptTemplate(
    input_variables=["context", "answer"], template="Task:{context}{answer}"
)

prompt = FewShotPromptTemplate(
    example_selector=example_selector,
    example_prompt=example_prompt,
    suffix="""Task:
Below is an example of clinical trial eligibility inclusion/exclusion criteria. Your task is to identify 3 categories of data within it. The 3 categories are: 1) Disease: a disorder affecting humans, 2) Biomarker: genes, proteins, or other substances that can be tested for to reveal important details about a patient’s cancer, and 3) Prior Therapy: medications, surgeries, or procedures that a patient may be treated with. For each of the identified categories, state whether it is an inclusion or exclusion. Additionally, please summarize the criterion.
Criteria:
    {criteria}
""",
    input_variables=["criteria"],
)

## Example token count

In [30]:
total = 0
total_res = 0
minm = 2000
maxm = 0
for example in examples:
    context_len = get_token_len(example['context'])
    answer_len = get_token_len(example['answer'])
    combined_len = context_len + answer_len
    if combined_len < minm:
        minm = combined_len
    elif combined_len > maxm:
        maxm = combined_len
    total += combined_len
    total_res += answer_len
AVG_EXAMPLE_LEN = total // len(examples)
MAX_W_SIZE = 2048
AVG_RES_LEN = total_res // len(examples)
print('         avg:', AVG_EXAMPLE_LEN)
print('         min:', minm)
print('         max:', maxm)
print('avg response:', AVG_RES_LEN)

         avg: 227
         min: 188
         max: 252
avg response: 61


In [32]:
class ECDoc:
    inc: list[str]
    exc: list[str]
    template = """Inclusion Criteria
{inclusion}

Exclusion Criteria
{exclusion}
"""

    def __init__(self, inc: list[str], exc: list[str]):
        self.inc = inc
        self.exc = exc

    @property
    def size(self) -> int:
        return get_token_len(str(self))

    def __str__(self) -> list[str]:
        return self.template.format(inclusion=''.join(self.inc).rstrip(),
                                   exclusion=''.join(self.exc).rstrip())

    def split(self):
        inc_len = len(self.inc)
        exc_len = len(self.exc)
        inc_midpoint = inc_len // 2
        exc_midpoint = exc_len // 2
        # Prevent Inclusion section from splitting too small
        if inc_len <= 5:
            inc_chunk_1 = self.inc
            inc_chunk_2 = self.inc
        else:
            inc_chunk_1 = self.inc[:inc_midpoint]
            inc_chunk_2 = self.inc[inc_midpoint:]
        # Prevent Exclusion section from splitting too small
        if exc_len <= 5:
            exc_chunk_1 = self.exc
            exc_chunk_2 = self.exc
        else:
            exc_chunk_1 = self.exc[:exc_midpoint]
            exc_chunk_2 = self.exc[exc_midpoint:]
        doc_chunk_1 = ECDoc(inc=inc_chunk_1, exc=exc_chunk_1)
        doc_chunk_2 = ECDoc(inc=inc_chunk_2, exc=exc_chunk_2)
        return doc_chunk_1, doc_chunk_2


def parse_file(filename: str) -> ECDoc:
    inc = []
    exc = []
    inclusion = True
    with open(filename) as filein:
        for line in filein.readlines():
            if line.strip().startswith('Inclusion Criteria'):
                continue
            elif line.strip().startswith('Exclusion Criteria'):
                inclusion = False
                continue
            elif line == '\n':
                continue
            if inclusion:
                inc.append(line)
            else:
                exc.append(line)
    return ECDoc(inc=inc, exc=exc)


def chunk_ec(doc: ECDoc) -> list[ECDoc]:
    # Not too large
    can_fit = doc.size + AVG_EXAMPLE_LEN < MAX_W_SIZE - AVG_RES_LEN
    if can_fit:
        return [doc]

    last_pass_chunks = [doc]
    while not can_fit:
        new_chunks = []
        for chunk in last_pass_chunks:
            new_chunks.extend(chunk.split())
        new_chunk_size = max([new_chunk.size for new_chunk in new_chunks])
        can_fit = new_chunk_size + AVG_EXAMPLE_LEN < MAX_W_SIZE - AVG_RES_LEN
        last_pass_chunks = new_chunks
    return new_chunks

## Process and test trial

In [45]:
from langchain.schema import StrOutputParser
from langchain.chains import LLMChain

llm_chain = LLMChain(llm=llm, prompt=prompt)

In [None]:
n = '01'
trial_folder = f'test_results/trial{n}'
original_doc = parse_file(f'{trial_folder}/unstructured_ec.txt')

# with open(f'{trial_folder}/inclusion_output', 'a') as fileout:
for inc in original_doc.inc:
    criteria = """Inclusion Criteria
{inclusion}
""".format(inclusion=inc)
    output = llm_chain.run(criteria=criteria)

# for exc in original_doc.exc:
#     criteria = """Exclusion Criteria
# {exclusion}
# """.format(exclusion=exc)
#     output = llm_chain.run(criteria=criteria)