In [None]:
from model import *

In [None]:
from langchain.llms import HuggingFacePipeline

# Generation settings forwarded through the transformers pipeline on every call.
generation_kwargs = dict(
    do_sample=True,
    temperature=0.1,  # sampling 'randomness'; lower values give more deterministic outputs
    max_new_tokens=256,  # upper bound on the number of tokens generated per call
    repetition_penalty=1.1,  # without this output begins repeating
    # stopping_criteria=stopping_criteria,  # without this model rambles during chat
    # streamer=transformers.TextStreamer(tokenizer),
)

generate_text = transformers.pipeline(
    task='text-generation',
    model=model,
    tokenizer=tokenizer,
    return_full_text=True,  # langchain expects the full text
    **generation_kwargs,
)

# Wrap the raw pipeline so it can be used as a LangChain LLM.
llm = HuggingFacePipeline(pipeline=generate_text)

In [None]:
# Reload edited local modules automatically before each cell runs, so changes to the
# prompt examples module are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Few-shot prompt template and its examples for single-criterion extraction.
from prompt_examples.single_criterion_examples import prompt, examples

In [None]:
# Preview the fully rendered prompt for one sample eligibility criterion.
example_criterion = (
    "Patient must have undergone complete surgical resection of their stage "
    "IIA, IIB, IIIA or IIIB non-squamous or squamous b NSCLC per American "
    "Joint Committee on Cancer (AJCC) 8th edition and have had negative "
    "margins. N3 disease is not allowed."
)
print(prompt.format(criterion=example_criterion))

In [None]:
# Print token-length statistics (avg/min/max prompt, avg response) for the few-shot `examples`.
# NOTE(review): the *_LEN names reach this namespace only through the star import, which runs
# BEFORE globalize_token_metrics(). If that function assigns them as globals inside the
# token_counting module, the names bound here would be stale or undefined. Confirm how
# globalize_token_metrics publishes its results (e.g. into the caller's globals/builtins).
from token_counting import *
globalize_token_metrics(examples)
print('  avg prompt:', AVG_PROMPT_LEN)
print('  min prompt:', MIN_PROMPT_LEN)
print('  max prompt:', MAX_PROMPT_LEN)
print('avg response:', AVG_RES_LEN)

In [None]:
from chunking import parse_file_with_pipes

In [None]:
import sys
import textwrap
import time
from pathlib import Path

import langchain
from langchain.chains import LLMChain
from loguru import logger
from tqdm import tqdm


langchain.debug = False
langchain.verbose = False

# Trial to run: the input criteria and all outputs live in test_results_final/trial<n>/.
n = '10'
folder = f'test_results_final/trial{n}'
folderp = Path(folder)
logfile = folderp / "outputs.log"
# logfile = folderp / "outputs_amended.log"
# NOTE(review): both the loguru sink and the FileCallbackHandler write to `logfile`, and every
# re-run of this cell adds one more loguru sink (duplicated log lines). Restart the kernel or
# call logger.remove() before re-running.
logger.add(logfile, colorize=False, enqueue=True)
handler = langchain.callbacks.FileCallbackHandler(logfile)

llm_chain = LLMChain(llm=llm, prompt=prompt, callbacks=[handler], verbose=False)

criterions = parse_file_with_pipes(folderp / 'ec_with_pipes.txt')
start = time.time()
invoke_times = []
# For each criterion (numbered from 01), write the raw model text to NN_output.txt and
# run metadata to NN_stats.yaml.
for idx, criterion in enumerate(tqdm(criterions, file=sys.stdout), start=1):
    invoke_start = time.time()
    results = llm_chain.invoke(input={'criterion': criterion.value})
    invoke_times.append(time.time() - invoke_start)
    with open(folderp / f'{idx:02}_output.txt', 'w', encoding='utf-8') as fileout:
        fileout.write(results['text'])
    # Indent EVERY line of the criterion. Indenting only the first one (as a plain f-string
    # would) lets a multi-line criterion escape the `|` block scalar and breaks the YAML.
    original_text_block = textwrap.indent(criterion.value, '    ')
    with open(folderp / f'{idx:02}_stats.yaml', 'w', encoding='utf-8') as fileout:
        fileout.write(f"""elapsed_time: {int(time.time() - invoke_start)}s
was_input_captured: {criterion.value in results['text']}
original_text: |
{original_text_block}
inclusion: {criterion.inclusion}
""")
end = time.time()

In [None]:
# Write a run-level summary: total wall-clock minutes, and mean seconds per chain invocation.
# Both are truncated to whole numbers like the per-criterion `elapsed_time`. Floor-dividing
# floats would write values such as "3.0min".
total_minutes = int((end - start) // 60)
# An empty run (no criteria parsed) would otherwise raise ZeroDivisionError here.
avg_invoke = f'{int(sum(invoke_times) / len(invoke_times))}s' if invoke_times else 'n/a'
with open(folderp / 'stats.yaml', 'w', encoding='utf-8') as fileout:
    fileout.write(f"""total_time: {total_minutes}min
avg_invoke_time: {avg_invoke}
""")

In [None]:
import yaml
import glob
from pathlib import Path

# Report every criterion whose original text was not found verbatim in the model output
# (`was_input_captured` in its NN_stats.yaml).
# The criterion number is parsed from the file name rather than taken from the position in
# the sorted listing. That keeps the report correct when numbers are missing or reach 100+,
# because plain string sorting puts "100_" before "10_".
numbered_stats = sorted(
    (int(Path(path).name.split('_', 1)[0]), path)
    for path in glob.glob(str(folderp / '*_stats.yaml'))
)
for idx, file in numbered_stats:
    with open(file, encoding='utf-8') as filein:
        try:
            stats = yaml.safe_load(filein)
        except Exception:
            # Name the file that failed to parse, then re-raise with the original traceback.
            print(file)
            raise
    if not stats['was_input_captured']:
        print(f'Input failed {idx:02}')

In [None]:
# Post processing
# Turn each raw NN_output.txt into NN_output_cleaned.yaml:
# - drop blank lines, and stop at the first line starting with "Criterion:";
# - strip a leading tab and/or 4-space indent from field lines;
# - rewrite "Original Text" and "Computable Rule" as YAML block scalars (`|`).
from pathlib import Path

STOP_WORD = 'Criterion:'
BLOCK_INDENT = '    '


def _value_after(line, key):
    """Return the text following `key:` in `line`, minus the spaces right after the colon.

    Unlike ``line.split(': ', maxsplit=1)`` this does not raise ValueError when nothing follows
    the colon (e.g. ``'Computable Rule:\\n'``). Dropping the extra spaces keeps the block
    scalar's first line from being indented deeper than its continuation lines, which would be
    invalid YAML.
    """
    return line.partition(key + ':')[2].lstrip(' ')


def _clean_output_lines(lines_in):
    """Return the cleaned lines of one raw model output, following the rules listed above."""
    cleaned = []
    in_original_text = False  # True while copying continuation lines of "Original Text"
    for line in lines_in:
        if line.startswith(STOP_WORD):
            break
        if line == '\n':
            continue
        if 'Original Text:' in line:
            in_original_text = True
            cleaned.append('Original Text: |\n')
            cleaned.append(BLOCK_INDENT + _value_after(line, 'Original Text'))
        elif 'Disease/Condition:' in line:
            # A Disease/Condition line ends the Original Text block.
            in_original_text = False
            cleaned.append(line.removeprefix('\t').removeprefix('    '))
        elif in_original_text:
            cleaned.append(BLOCK_INDENT + line)
        elif 'Computable Rule:' in line:
            cleaned.append('Computable Rule: |\n')
            cleaned.append(BLOCK_INDENT + _value_after(line, 'Computable Rule'))
        else:
            cleaned.append(line.removeprefix('\t').removeprefix('    '))
    return cleaned


# Criterion number comes from the file name, not the listing position ("100_" sorts before "10_").
numbered_outputs = sorted(
    (int(Path(path).name.split('_', 1)[0]), path)
    for path in glob.glob(str(folderp / '*_output.txt'))
)
for idx, file in numbered_outputs:
    with open(file, encoding='utf-8') as filein:
        output_reassembled = ''.join(_clean_output_lines(filein))
    if 'Computable Rule' not in output_reassembled:
        print('Malformed output:', idx)
    with open(folderp / f'{idx:02}_output_cleaned.yaml', 'w', encoding='utf-8') as fileout:
        fileout.write(output_reassembled)

In [None]:
import pandas as pd
from pathlib import Path

# Collect every cleaned model output, paired with the stats file of the same criterion
# number, into results.csv.
RESULT_COLUMNS = ['Criterion Text In', 'Criterion Text Out', 'Inclusion/Exclusion', 'Disease', 'Biomarker', 'Procedure', 'Drug', 'Criterion Rule']
# Literal strings the model emits for "no value"; written to the CSV as empty cells.
NONE_PLACEHOLDERS = ('none', 'None')


def _none_if_placeholder(value):
    """Return None for the model's literal 'none'/'None' answers, otherwise `value` unchanged.

    Done per value instead of via ``df.replace('none', None)``: on some pandas versions an
    explicit ``value=None`` makes ``replace`` fall back to ``method='pad'`` (forward-fill).
    """
    return None if value in NONE_PLACEHOLDERS else value


# Pair files by the number in their name, not by position in a string-sorted listing.
numbered_outputs = sorted(
    (int(Path(path).name.split('_', 1)[0]), path)
    for path in glob.glob(str(folderp / '*_output_cleaned.yaml'))
)
records = []
for idx, file in numbered_outputs:
    with open(folderp / f'{idx:02}_stats.yaml', encoding='utf-8') as filein:
        stats = yaml.safe_load(filein)
    with open(file, encoding='utf-8') as filein:
        print(file)  # printed before parsing so a YAML error points at the offending file
        output = yaml.safe_load(filein)
    if not isinstance(output, dict):
        # Empty or non-mapping output: keep the row, leave the extracted fields empty.
        print('Unparseable output, fields left empty:', file)
        output = {}
    text_out = output.get('Original Text')
    row = {
        'Criterion Text In': stats['original_text'],
        'Criterion Text Out': text_out.strip() if isinstance(text_out, str) else text_out,
        'Inclusion/Exclusion': 'Inclusion' if stats['inclusion'] else 'Exclusion',
        # .get(): a malformed output missing a field yields an empty cell instead of a KeyError.
        'Disease': output.get('Disease/Condition'),
        'Biomarker': output.get('Biomarker'),
        'Procedure': output.get('Procedure'),
        'Drug': output.get('Drug'),
        'Criterion Rule': output.get('Computable Rule'),
    }
    records.append({column: _none_if_placeholder(value) for column, value in row.items()})

# Build the frame once; DataFrame.append was removed in pandas 2.0 and was quadratic anyway.
df = pd.DataFrame(records, columns=RESULT_COLUMNS)
df.to_csv(folderp / 'results.csv', index=False)

In [None]:
# Print the contents of every *_flag.txt file in the trial folder, one titled section per file.
for flag_file in sorted(glob.glob(str(folderp / '*_flag.txt'))):
    with open(flag_file, encoding='utf-8') as flag_in:
        flag_text = flag_in.read()
    print(f'==== {flag_file} ====')
    print(flag_text, end='\n\n')