In [1]:
import openai
import torch
from transformers import AutoTokenizer, GPT2LMHeadModel, GenerationConfig
from datasets import Dataset

import os
import csv
import json
import re
import random
import pandas as pd
import numpy as np

! pip install nltk
from nltk.tokenize import TreebankWordTokenizer



In [2]:
# Checkpoint location handed to `from_pretrained` for both the tokenizer and the model.
MODEL_PATH = "estMed-gpt2_fine_tuned3"

In [3]:
# Tokenizer and language-model weights are loaded from the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = GPT2LMHeadModel.from_pretrained(MODEL_PATH)

# Sampling settings reused by every generation call: up to 400 new tokens,
# top-k sampling with k=40, stopping at the model's configured EOS token id.
generation_config = GenerationConfig(
    max_new_tokens=400,
    do_sample=True,
    top_k=40,
    eos_token_id=model.config.eos_token_id,
)

In [4]:
def generate_text(beginning: str):
    """Sample a continuation of `beginning` from the module-level model.

    Args:
        beginning: Prompt text that the generated sequence starts with.

    Returns:
        The list produced by `tokenizer.batch_decode` for the generated
        sequences, with special tokens stripped.
    """
    encoded_prompt = tokenizer(beginning, return_tensors="pt")
    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        generated_ids = model.generate(
            **encoded_prompt,
            generation_config=generation_config,
            pad_token_id=tokenizer.eos_token_id,
        )
    decoded = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    return decoded

In [12]:
# Smoke test: generate one text from a hand-written prompt in the
# "<genre>, vanusegrupp <age>, <gender>, <disease>," format and print the decoded list.
test_text = generate_text("protseduur, vanusegrupp 10, naine, C50-50,")
print(test_text)

['protseduur, vanusegrupp 10, naine, C50-50, ultraheli DATE tsütostaatilise ravikuuri planeerimine ja manustamine, kuni 24 tundi DATE mammograafia']


In [14]:
disease = 'C50-50'
gender = 'naine'

# NOTE(review): this value is used as a plain string prefix below (no path separator
# is inserted), so the output files are named e.g. "test_dataprotseduur_test.csv".
# Confirm whether a directory was intended.
synth_path = 'test_data'

different_beginnings = ['protseduur', 'anamnees']

for prompt_genre in different_beginnings:

    # Each entry is the list returned by generate_text for one prompt.
    texts = []

    for i in range(1):
        # The age group is drawn at random, so outputs differ between runs
        # unless the random module is seeded.
        prompt = "{genre}, vanusegrupp {age}, {gender}, {disease},".format(
            genre=prompt_genre,
            age=str(random.randint(1, 14)),
            gender=gender,
            disease=disease,
        )
        texts.append(generate_text(prompt))

    filename = synth_path + prompt_genre.replace(" ", "_") + "_test.csv"

    # The encoding is pinned to UTF-8 so non-ASCII characters are written the same way on
    # every platform and match pandas' default decoding when the file is read back.
    with open(filename, 'w', newline='', encoding='utf-8') as csvfile:
        spamwriter = csv.writer(csvfile, delimiter='\t',
                                quotechar='|', quoting=csv.QUOTE_MINIMAL)

        # One row per generated entry; no header row is written.
        for text in texts:
            spamwriter.writerow(text)

In [17]:
# NOTE(review): this cell repeats the write step of the generation loop and relies on
# `synth_path`, `prompt_genre` and `texts` still being defined from that loop, so it only
# rewrites the file of the last genre processed. It fails on a fresh kernel if run alone.
filename = synth_path + prompt_genre.replace(" ", "_") + "_test.csv"

# The encoding is pinned to UTF-8 so non-ASCII characters are written the same way on
# every platform and match pandas' default decoding when the file is read back.
with open(filename, 'w', newline='', encoding='utf-8') as csvfile:
    spamwriter = csv.writer(csvfile, delimiter='\t',
                            quotechar='|', quoting=csv.QUOTE_MINIMAL)

    # One row per generated entry; no header row is written.
    for text in texts:
        spamwriter.writerow(text)

### Check the originality of the generated texts

In [None]:
import os
from originality import lcs
from transformers import AutoTokenizer
from tqdm import trange

In [None]:
def read_in_file(file):
    """Load a data file, choosing the reader from the file extension.

    Args:
        file: Path to a tab-separated `.csv`/`.tsv` file or a `.pkl` pickle.

    Returns:
        Whatever `pd.read_csv(file, sep='\\t')` or `pd.read_pickle(file)` returns.

    Raises:
        ValueError: If the extension is not `.csv`, `.tsv` or `.pkl`.
    """
    if file.endswith(('.csv', '.tsv')):
        # NOTE(review): read_csv treats the first row as a header by default, while the
        # CSV-writing cells in this notebook emit no header row. Confirm the files read
        # here really start with a header, otherwise the first text is lost.
        return pd.read_csv(file, sep='\t')
    if file.endswith('.pkl'):
        # Unpickling can execute arbitrary code; only load files from a trusted source.
        return pd.read_pickle(file)

    # Previously an unknown extension fell through to an UnboundLocalError on `texts`.
    raise ValueError(
        f"Unsupported file type for {file!r}; expected a .csv, .tsv or .pkl file"
    )

def calculate_lcs(target_texts_file,
                  reference_texts_file,
                  tokenizer_folder,
                  top_k,
                  save_folder,
                  save_file_name):
    """Score target texts against references with `lcs.check_originality` and save the top-k.

    For every target text, the `top_k` highest scores and the column positions of the
    corresponding references are kept and written to two `.npy` files.

    Args:
        target_texts_file: File readable by `read_in_file`; the first column holds the texts.
        reference_texts_file: File readable by `read_in_file` holding the references,
            which are assumed to be tokenized already.
        tokenizer_folder: Location passed to `AutoTokenizer.from_pretrained`.
        top_k: Number of best-scoring references kept per target text.
        save_folder: Output directory; created if it does not exist.
        save_file_name: Base name of the output files (without extension).

    Side effects:
        Writes `<save_file_name>.npy` (float32 scores) and
        `<save_file_name>_selected.npy` (int64 reference positions) into `save_folder`.
    """
    # exist_ok makes the call safe when the folder already exists.
    os.makedirs(save_folder, exist_ok=True)

    # Read in the data
    texts = read_in_file(target_texts_file)
    # Keep only what follows the first four ', '-separated fields of each text
    # (the generation prompts in this notebook have four such fields).
    target_texts = [
        ', '.join(texts.iloc[row, 0].split(', ')[4:])
        for row in range(texts.shape[0])
    ]

    references = read_in_file(reference_texts_file) # We assume that the references are already tokenized.

    # Read in the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_folder)

    # Tokenize the targets texts
    targets = tokenizer(target_texts)['input_ids']

    n_targets = len(targets)
    n_references = len(references)

    lcs_elements = np.zeros((n_targets, top_k), dtype=np.float32)
    # Positions are stored as integers: float32 cannot represent every integer above 2**24
    # and float values cannot be used for indexing.
    lcs_indices = np.zeros((n_targets, top_k), dtype=np.int64)

    batch_size = 500
    batch_size2 = 32768

    # Scratch buffer reused for every target batch; only the first `n_rows` rows are
    # written and read for a given batch.
    lcs_matrix = np.zeros((batch_size, n_references), dtype=np.float32)

    # Calculate the LCS in batches
    for i in trange(int(np.ceil(n_targets / batch_size))):
        row_start = i * batch_size
        # The last batch may be shorter than batch_size.
        n_rows = min(batch_size, n_targets - row_start)
        target_batch = targets[row_start:row_start + batch_size]

        for k in trange(int(np.ceil(n_references / batch_size2))):
            col_start = k * batch_size2
            col_stop = col_start + batch_size2
            # check_originality is assumed to return one row per target and one column per
            # reference in the batch; the slice assignment requires that shape. TODO confirm.
            lcs_matrix[:n_rows, col_start:col_stop] = lcs.check_originality(
                target_batch, references[col_start:col_stop])

        # Get the top k (argsort on negated scores sorts in descending order)
        batch_scores = lcs_matrix[:n_rows]
        top_indices = np.argsort(-batch_scores, axis=1)[:, :top_k]
        top_elements = np.take_along_axis(batch_scores, top_indices, axis=1)

        lcs_elements[row_start:row_start + n_rows, :] = top_elements
        lcs_indices[row_start:row_start + n_rows, :] = top_indices

    # Save the lcs
    # The original `'.'.join(save_file_name, 'npy')` raised TypeError: str.join takes a
    # single iterable argument.
    save_file = os.path.join(save_folder, f'{save_file_name}.npy')
    np.save(save_file, lcs_elements)

    save_file = os.path.join(save_folder, f'{save_file_name}_selected.npy')
    np.save(save_file, lcs_indices)

In [None]:
# Output folder handed to calculate_lcs as `save_folder`.
lcs_folder = "lcs_validated_ids/selected"

In [None]:
# We assume that the references in reference_texts_file are already tokenized.
# NOTE(review): `batch_nr` and `document_type` are not assigned anywhere in the visible
# notebook; they must be defined before this cell can run on a fresh kernel.
# The tokenizer is loaded from MODEL_PATH and the 50 best matches per text are kept.
calculate_lcs(f"generated_data_12okt/batch{batch_nr}/{document_type}.csv", 'data/train_tokenized.pkl', MODEL_PATH, 50, lcs_folder, f"batch{batch_nr}_{document_type}")

## Reading in the LCS-verified synthetic texts

In [18]:
# Folder containing the `*_selected.npy` index files.
# NOTE(review): return_okay_texts hardcodes this same path instead of reading this variable.
lcs_folder = "lcs_validated_ids/selected"

In [21]:
def clean_text_of_prefix(text):
    """Strip the leading four comma-separated fields from `text`.

    Everything after the fourth comma is kept, minus its first character
    (the character directly following that comma). Texts with fewer than
    four commas yield an empty string.
    """
    # maxsplit=4 leaves everything after the fourth comma untouched in the last piece.
    pieces = text.split(",", 4)
    remainder = pieces[4] if len(pieces) > 4 else ""
    return remainder[1:]

In [22]:
def return_okay_texts(batch_nr, document_type):
    """Return the selected generated texts of one batch with the prompt prefix removed.

    Args:
        batch_nr: Batch number used in the `generated_data_12okt/batch<nr>/` folder name
            and in the index file name.
        document_type: Document type used as the CSV base name (e.g. "anamnees").

    Returns:
        List of texts (first CSV column, passed through `clean_text_of_prefix`) for the
        row positions stored in
        `lcs_validated_ids/selected/batch<nr>_<document_type>_selected.npy`.
    """
    csv_path = "generated_data_12okt/batch" + str(batch_nr) + "/" + str(document_type) + ".csv"
    # newline='' is what the csv module expects; UTF-8 matches pandas' default decoding,
    # which is how calculate_lcs reads this same file.
    with open(csv_path, newline='', encoding='utf-8') as csvfile:
        lugeja = csv.reader(csvfile, delimiter='\t', quotechar='|', quoting=csv.QUOTE_MINIMAL)
        tekstid_all = list(lugeja)

    # assumes the file holds a 1-D array of row positions into the CSV (TODO confirm)
    idx = np.load("lcs_validated_ids/selected/batch" + str(batch_nr) + "_" + str(document_type) + "_selected.npy")

    # int() lets positions stored with a float dtype index the list as well.
    tekstid_filtered = [
        clean_text_of_prefix(tekstid_all[int(index)][0])
        for index in idx
    ]

    return tekstid_filtered
            

# Anamnesis texts are collected from batches 1 and 2, procedure texts from
# batches 1, 2, 4 and 5, each concatenated in batch order.
filtered_texts_anamnesis = [
    text
    for batch in (1, 2)
    for text in return_okay_texts(batch, "anamnees")
]
filtered_texts_procedures = [
    text
    for batch in (1, 2, 4, 5)
    for text in return_okay_texts(batch, "protseduur")
]

In [23]:
# Number of anamnesis texts collected.
len(filtered_texts_anamnesis)

542

In [24]:
# Number of procedure texts collected.
len(filtered_texts_procedures)

480

In [25]:
# Inspect one collected anamnesis text; the same text is used as the prompt example below.
filtered_texts_anamnesis[1]

'pöördus kontrolli. DATE vasaku rinna sektorresektsioon + snb, adjuvantravina rinna kiiritus + zoladex + tamoxifen. DATE konsiiliumi otsusega herceptin-ravi (tmx). DATE.14 mmgr, uh ja jnb. DATE vasaku rinna rekonstruktsioon tram lapiga. DATE konsiiliumi otsusega näidustatud adjuvantne keemia-, bioloogiline-, hormoontaastusravi. kuna kasvaja hormoonsõltuv, plaanis alustada ravi zoladex + tamoxifeniga. obj: rindades tihendeid ei palpeeri. NAME. palp. bilat.ii. tellitud ca15-3'

## Azure API

In [26]:
# API setup. The key-file path, endpoint and API version were removed for the public repo
# (left as empty strings) and must be filled in before this cell can run.

openai.api_type = "azure"

# The first line of the key file is used as the API key, with surrounding whitespace stripped.
with open('', 'r') as api_file:
    openai.api_key=str(api_file.readline()).strip()
    
openai.api_base = ""
openai.api_version = ""

In [27]:
# Module-level logs filled by ask_openai: each prompt sent and the reply text received for it.
prompts = []
responses = []

In [28]:
def ask_openai(prompt: str) -> str:
    """Send `prompt` as a single user message and return the reply text.

    The request is attempted once more if the first attempt raises; an error from the
    second attempt propagates to the caller. On success the reply and the prompt are
    appended to the module-level `responses` and `prompts` lists.

    Args:
        prompt: Full prompt text to send.

    Returns:
        The content of the first choice's message.
    """
    def _request():
        # Single chat-completion call; shared by the first attempt and the retry.
        return openai.ChatCompletion.create(
            deployment_id = "gpt-35-experiments-health",
            model = "gpt-35-turbo",
            messages=[{"role": "user", "content": prompt}]
        )

    try:
        response = _request()
    except Exception:
        # Retry once. `Exception` (rather than a bare except) lets KeyboardInterrupt
        # and SystemExit stop the cell instead of triggering another request.
        response = _request()

    answer = response['choices'][0]['message']['content']
    responses.append(answer)
    prompts.append(prompt)
    return answer

# ask_openai("What is your name?")

## Zero-shot prompt example

In [30]:
# Entity types the model is asked to list, and the JSON key requested for each of them.
output_data = 'drug named entity, procedure named entity, family history named entity.'
output_format = 'DRUG for drug named entity, PROCEDURE for procedure named entity, FAMILY for family history named entity.'

# Instruction part of the prompt; the clinical text is appended on a new line, in double quotes.
base_prompt_zero_shot = "In the text below, give the list of: " + output_data + " Words need to be in exactly the same format as in input text. Format the output in JSON with the following keys: " + output_format + " Text below: "
# Preview the complete prompt for one example text.
print(base_prompt_zero_shot + "\n" + "\"" + filtered_texts_anamnesis[1] + "\"")

In the text below, give the list of: drug named entity, procedure named entity, family history named entity. Words need to be in exactly the same format as in input text. Format the output in JSON with the following keys: DRUG for drug named entity, PROCEDURE for procedure named entity, FAMILY for family history named entity. Text below: 
"pöördus kontrolli. DATE vasaku rinna sektorresektsioon + snb, adjuvantravina rinna kiiritus + zoladex + tamoxifen. DATE konsiiliumi otsusega herceptin-ravi (tmx). DATE.14 mmgr, uh ja jnb. DATE vasaku rinna rekonstruktsioon tram lapiga. DATE konsiiliumi otsusega näidustatud adjuvantne keemia-, bioloogiline-, hormoontaastusravi. kuna kasvaja hormoonsõltuv, plaanis alustada ravi zoladex + tamoxifeniga. obj: rindades tihendeid ei palpeeri. NAME. palp. bilat.ii. tellitud ca15-3"


In [41]:
# Start from empty logs so that `responses` lines up index-for-index with `test_sentences`.
prompts = []
responses = []

test_sentences = [filtered_texts_anamnesis[1]]
try:
    for sent in test_sentences:
        ask_openai(base_prompt_zero_shot + "\n" + "\"" + sent + "\"")
except Exception as err:
    # Best effort: stop at the first failing request but keep the replies collected so
    # far, and report the failure instead of silently discarding it.
    print(f"Stopped after {len(responses)} of {len(test_sentences)} prompts: {err!r}")

In [42]:
# Show the raw reply texts collected by ask_openai.
responses

['{\n   "DRUG": ["zoladex", "tamoxifen", "herceptin-ravi"],\n   "PROCEDURE": ["sektorresektsioon", "snb", "kiiritus", "rinna rekonstruktsioon", "tram lapiga", "adjuvantne keemia-, bioloogiline-, hormoontaastusravi"],\n   "FAMILY": []\n}']

In [43]:
# Entity labels looked up as keys in the JSON reply; they match the keys requested in the prompt.
entities = ['DRUG', 'PROCEDURE', 'FAMILY']

In [44]:
def parse_prompt(clinical_text: str, response: str, entities):
    """Locate the entities listed in a JSON reply inside the clinical text.

    Args:
        clinical_text: Text the reply refers to.
        response: JSON object (as a string) mapping entity labels to the text
            fragments found for them.
        entities: Entity labels to look up in the reply; other keys are ignored.

    Returns:
        List of dicts with keys 'entity_type', 'start_idx', 'end_idx' and 'text', one per
        distinct case-insensitive occurrence of a fragment in `clinical_text`. Offsets
        refer to `clinical_text` itself and 'text' is the matching slice of it.

    Raises:
        json.JSONDecodeError: If `response` is not valid JSON.
    """
    list_results = []

    json_data = json.loads(response)

    for entity in entities:
        if entity not in json_data:
            continue

        findings = json_data[entity]
        if findings is None:
            continue
        if isinstance(findings, str):
            # A bare string instead of a list would otherwise be iterated per character.
            findings = [findings]

        for finding in findings:
            # Skip values that cannot be searched for: non-strings, blanks (an empty
            # pattern matches at every position) and the "none" placeholder.
            if not isinstance(finding, str):
                continue
            if not finding.strip() or finding.strip().lower() == 'none':
                continue

            # The fragment is literal text, so regex metacharacters such as '+' or '('
            # are escaped. Matching case-insensitively on the original text keeps the
            # offsets valid even where lower-casing would change the string length.
            pattern = re.compile(re.escape(finding), re.IGNORECASE)
            for match in pattern.finditer(clinical_text):
                start = match.start()
                end = match.end()
                dict_result = {'entity_type': entity, 'start_idx': start, 'end_idx': end, 'text': clinical_text[start:end]}
                if dict_result not in list_results:
                    list_results.append(dict_result)

    return list_results

# results_ = [parse_prompt(test_sentences[testid], responses[testid], entities) for testid in range(len(responses))]

## Parse to training data

In [45]:
def parsed_results_to_train(parsed_results, clinical_results, entities):
    """Turn character-span annotations into per-token tag lists.

    The text is split with `TreebankWordTokenizer().span_tokenize`; every token gets the
    entity types of all annotations whose character span overlaps the token's span, or
    `['O']` when none does.

    Args:
        parsed_results: Iterable of dicts with 'start_idx', 'end_idx' and 'entity_type'.
        clinical_results: The text that the annotation offsets refer to.
        entities: Not used by this function.

    Returns:
        List of `(token, tags)` tuples in token order.
    """
    labelled_tokens = []

    for span_start, span_end in TreebankWordTokenizer().span_tokenize(clinical_results):
        token_tags = []

        for annotation in parsed_results:
            overlap_start = max(annotation['start_idx'], span_start)
            overlap_end = min(annotation['end_idx'], span_end)
            label = annotation['entity_type']

            # A non-empty intersection of the two half-open spans means the token is
            # (at least partly) covered by the annotation.
            if overlap_start < overlap_end and label not in token_tags:
                token_tags.append(label)

        if not token_tags:
            token_tags.append('O')

        labelled_tokens.append((clinical_results[span_start:span_end], token_tags))

    return labelled_tokens


In [46]:
# Texts whose reply could be parsed, and the parsed annotations at the same positions.
successful_texts = []
parsed_answers = []

for ans_, og_text in zip(responses, test_sentences):
    try:
        # Only the part from the first '{' to the first '}' of the reply is parsed.
        parsed_answer = parse_prompt(og_text, ans_[ans_.find('{'):ans_.find('}')+1], entities)
    except Exception as err:
        # Report the text, the raw reply and the error, then continue with the next pair.
        # `Exception` (rather than a bare except) lets KeyboardInterrupt stop the cell.
        print("----------")
        print(og_text)
        print()
        print(ans_)
        print(repr(err))
        print("-------")
    else:
        parsed_answers.append(parsed_answer)
        successful_texts.append(og_text)

In [47]:
# Texts whose reply was parsed without an error.
successful_texts

['pöördus kontrolli. DATE vasaku rinna sektorresektsioon + snb, adjuvantravina rinna kiiritus + zoladex + tamoxifen. DATE konsiiliumi otsusega herceptin-ravi (tmx). DATE.14 mmgr, uh ja jnb. DATE vasaku rinna rekonstruktsioon tram lapiga. DATE konsiiliumi otsusega näidustatud adjuvantne keemia-, bioloogiline-, hormoontaastusravi. kuna kasvaja hormoonsõltuv, plaanis alustada ravi zoladex + tamoxifeniga. obj: rindades tihendeid ei palpeeri. NAME. palp. bilat.ii. tellitud ca15-3']

In [49]:
# Print the first raw reply so the JSON is shown with real line breaks.
print(responses[0])

{
   "DRUG": ["zoladex", "tamoxifen", "herceptin-ravi"],
   "PROCEDURE": ["sektorresektsioon", "snb", "kiiritus", "rinna rekonstruktsioon", "tram lapiga", "adjuvantne keemia-, bioloogiline-, hormoontaastusravi"],
   "FAMILY": []
}


In [38]:
# NOTE(review): this cell's execution count (38) is lower than that of the cells above, so
# the shown output comes from an earlier run; with the single test sentence used above,
# `parsed_answers` has one element and index 1 would raise IndexError on a fresh run.
parsed_answers[1]

[]

## Export to CSV

In [50]:
# One list of (token, tags) tuples per successfully parsed text.
# NOTE(review): nothing is written to a CSV file in this cell; it only builds the list.
results_for_csv = []

# The two lists are expected to have the same length, since they are zipped below.
print(len(parsed_answers))
print(len(successful_texts))

for text_, results_ in zip(successful_texts, parsed_answers):
    train_format_ = parsed_results_to_train(results_, text_, entities)
    results_for_csv.append(train_format_)

1
1
