In [None]:
import gc
import logging
import os
import random

import pandas
import torch
from simpletransformers.language_modeling import LanguageModelingModel
from torch.utils.data import Dataset, random_split
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Model
from transformers import TrainingArguments, Trainer
# Disable the Weights & Biases integration that transformers' Trainer reads from the environment.
os.environ.update(WANDB_DISABLED="true")

In [None]:
class CustomDataset(Dataset):
	"""Pre-tokenized dataset yielding ``(input_ids, attention_mask)`` tensor pairs.

	Each text is tokenized exactly once at construction time. No padding is
	requested from the tokenizer, so items may differ in length.
	"""
	_input_id: str = 'input_ids'
	_attention_mask: str = 'attention_mask'

	def __init__(self, text_list, tokenizer, max_length, truncation=False):
		# Tokenize every text once, preserving input order.
		encodings = [
			tokenizer(text, truncation=truncation, max_length=max_length)
			for text in text_list
		]
		self.input_ids = [torch.tensor(enc[self._input_id]) for enc in encodings]
		self.attention_mask = [torch.tensor(enc[self._attention_mask]) for enc in encodings]
		# Never populated here: the training collator derives labels from input_ids.
		self.labels = []

	def __len__(self):
		return len(self.input_ids)

	def __getitem__(self, index):
		input_ids = self.input_ids[index]
		attention_mask = self.attention_mask[index]
		return input_ids, attention_mask

In [None]:
class QuestionAnswer:
	"""A single question/answer pair used to build tagged training text."""

	def __init__(self, question: str, answer: str) -> None:
		# Both halves are stored verbatim; DataGenerator.create_data_line wraps them in tags.
		self.question, self.answer = question, answer

In [None]:
class DataGenerator(object):
	"""Builds tagged question/answer training strings and the matching tokenizer config.

	The tag strings are private (name-mangled) class attributes shared by both
	class methods, so the text framing and the special-token list cannot drift apart.
	"""
	__bos_token: str = '<|startoftext|>'
	__eos_token: str = '<|endoftext|>'

	# Non-Standard Tokens
	__start_of_question_token: str = '<|startofquestion|>'
	__end_of_question_token: str = '<|endofquestion|>'
	__start_of_reply_token: str = '<|startofreply|>'
	__end_of_reply_token: str = '<|endofreply|>'

	@classmethod
	def create_data_line(cls, data_line: "QuestionAnswer") -> str:
		"""Wrap one question/answer pair in BOS/EOS and question/reply tags.

		:param data_line: object exposing ``question`` and ``answer`` strings.
		:return: ``<bos><soq>question<eoq><sor>answer<eor><eos>`` as one string.
		"""
		return "".join((
			cls.__bos_token,
			cls.__start_of_question_token,
			data_line.question,
			cls.__end_of_question_token,
			cls.__start_of_reply_token,
			data_line.answer,
			cls.__end_of_reply_token,
			cls.__eos_token,
		))

	@classmethod
	def get_special_token_dict(cls) -> dict:
		"""Return a mapping intended for the tokenizer's ``add_special_tokens``.

		Keys must be tokenizer special-token attribute names (``bos_token``,
		``eos_token``, ``additional_special_tokens``); transformers rejects
		unknown keys.
		"""
		return {
			"bos_token": cls.__bos_token,
			"eos_token": cls.__eos_token,
			# One list element per tag: adjacent string literals without commas
			# would silently concatenate into a single bogus token.
			"additional_special_tokens": [
				cls.__bos_token,
				cls.__eos_token,
				cls.__start_of_question_token,
				cls.__end_of_question_token,
				cls.__start_of_reply_token,
				cls.__end_of_reply_token,
			],
		}

In [None]:
# Checkpoint size suffix appended to "gpt2": "" (base), "-medium", or "-large".
model_type = ""

In [None]:
# Single output directory holding both the model weights and the tokenizer files.
parent_directory = "/content/model_base"
model_name = "sexy-prompt-bot{}".format(model_type)
model_output_dir = "{}/{}".format(parent_directory, model_name)
tokenizer_path = model_output_dir

In [None]:
# data_lines = [
# 	QuestionAnswer(question="what is required to sign up a customer for ENS?", answer="To enroll in the encounter notification service please ensure that your organization has executed a participation agreement with CRISP and updated your organizations Notice of Privacy Practices (NPP). Once those steps are completed the CRISP team will provide a template file for you to securely submit your patient member file to the HIE. Additionally there is a checklist that will allow you to choose what notifications you want to receive and how you want to receive them. All users can use our ENS PROMPT tool at no cost."),
# 	QuestionAnswer(question="what types of organizations participate in CRISP?", answer="A Health Information Exchange, or HIE, is a way of sharing your health information among participating doctors’ offices, hospitals, care coordinators, labs, radiology centers, and other health care providers through secure, electronic means. The purpose is so that each of your participating healthcare providers can have the benefit of the most recent information available from your other participating providers when taking care of you. When you opt out of participation in the HIE, doctors and nurses will not be able to search for your health information through the HIE to use while treating you. Your physician or other treating providers will still be able to select the HIE as a way to receive your lab results, radiology reports, and other data sent directly to them that they may have previously received by fax, mail, or other electronic communications."),
# 	QuestionAnswer(question="If I don't want to share my information with CRISP can I choose to opt-out?", answer="Please be advised that opting out does not preclude any participating organization that has received or accessed personal health information via the HIE prior to such opt-out, and incorporated such personal health information into its records, from retaining such information in its records. Additionally, in accordance with the law, Public health reporting, such as the reporting of infectious diseases to public health officials, will still occur through the HIE after you decide to opt out. Controlled Dangerous Substances (CDS) information, as part of the Maryland Prescription Drug Monitoring Program, will continue to be available through the HIE to licensed providers.\n\nIf you choose to opt out of research only, your information will be available to your treating providers, but will be excluded from any data sets created for researchers."),
# 	QuestionAnswer(question="what is the value of becoming an affiliate of CRISP Shared Services?", answer="As Maryland’s official regional Health Information Exchange, the CRISP HIE network is comprised of hundreds of connected providers, consisting of hospitals, EMRs, pharmacies, payors, health departments, and health centers."),
# 	QuestionAnswer(question="what is the CRISP Shared Services?", answer="CRISP Shared Services (CSS) is a non-profit support organization that provides technology infrastructure and other core services to Health Information Exchanges (HIEs) across the US. We are different than a vendor in that each of our Member HIEs participates in the governance of the organization."),
# 	QuestionAnswer(question="what is the motivation of CRISP Shared Services?", answer="Our primary motivation is to enable and support each local jurisdiction’s Healthcare community so that it can improve health outcomes for its patients. We work with local HIE leadership to implement solutions which best serve the needs of their unique communities, even if those solutions are deployed or built by external vendors."),
# 	QuestionAnswer(question="what is the mission of CRISP Shared Services?", answer="Our mission is to assist member organizations in achieving economies of scale, pooling innovation efforts, and implementing best practices."),
# 	QuestionAnswer(question="What is CRISP?", answer="CRISP is the State Designated Health Information Exchange (HIE) for Maryland."),
# 	QuestionAnswer(question="What does CRISP do?", answer="We are a non-profit organization that facilitates the electronic transfer of clinical information between disparate health information systems."),
# 	QuestionAnswer(question="Clinical Data", answer="As clinical information is created and shared with CRISP, it is made accessible in near real-time to participating health care providers through the CRISP tools. Providers have the ability to securely look up patient information through the internet. CRISP tools retrieve clinical data from participants and display it in a view-only screen at the point of care."),
# 	QuestionAnswer(question="Where is CRISP clinical data? When Is It Made Available?", answer="CRISP clinical data is available through the CRISP Portal at no cost to clinical staff. As clinical information is created and shared with CRISP, it is made accessible in real-time to participating healthcare providers across institutional boundaries through the Clinical Information Service. The portal gives providers the ability to securely look up patient information via their browser. It retrieves clinical data from participants and displays it in a view-only screen at the point of care."),
# 	QuestionAnswer(question="How can I View Patient Data?", answer="The Patient Care Snapshot combines critical information relevant to your role in the patient’s care. It displays data from a variety of sources to provide an at-a-glance view of the patient’s clinical history. Information is presented from a compilation of care management data alongside real-time hospital encounter feeds, up-to-date demographic information, patient to care provider attribution, and clinical summaries of care from our real-time interfaces with providers across the region."),
# 	QuestionAnswer(question="How can I View Patient Data?", answer="Imaging Worklist allows users to compare images across multiple locations via the Imaging Worklist tab through CRISP’s Portal. Patient images from CRISP participating sites dating back to 01/01/2000 are available through CRISP’s Imaging Worklist as well. In addition, up to four imaging studies can be selected, viewed, and compared at the same time."),
# 	QuestionAnswer(question="How can I View Patient Data?", answer="Image Exchange is an online image-sharing service that allows CRISP users to view patient diagnostic images in one central location.\nCurrently, 50 hospitals and 9 outpatient groups across Maryland and DC are contributing images to the Image Exchange service. These images are then made available to CRISP users through our portals to facilitate greater collaboration and efficiency among healthcare providers, ultimately leading to higher quality patient care.\nThe diagnostic images are securely stored on servers located within each connected hospital’s local environment. Images taken within the last 90 days are made available to all authorized CRISP users within seconds of collection, while deeper archives of images older than 90 days are available within minutes.\nCRISP offers two ways of viewing patient clinical data and images. The first option is to directly log on to the CRISP Portal, using specific user credentials. The second option is to access images via the HIE InContext platform within each respective EMR.\nIn addition to the image access provided through Health Records and Imaging Worklist, Image Exchange participants can also request access to the following features: Transfer-to-PACS and Emergent."),
# 	QuestionAnswer(question="Is access to an Electronic Medical Record (EMR) required to participate in CRISP?", answer="An EMR is not required to access CRISP or the data we provide. CRISP data is accessible through any Internet browser via the CRISP Portal."),
# 	QuestionAnswer(question="Is there a cost associated with accessing CRISP data?", answer="There is no cost. Access to CRISP is free to ambulatory practices."),
# 	QuestionAnswer(question="How do I get started?", answer="To get started, an organization must sign a CRISP Participation agreement for single sites or an agreement for organizations with multiple sites and must update its Notice of Privacy Practices documents. If your organization has already signed a participation agreement, you will be able to skip this step during the onboarding process and simply contact your organization’s CRISP Point of Contact for access.The participation agreement is the uniform data sharing agreement signed by every organization that participates in CRISP. It ensures that everyone sharing data follows the same rules and regulations. Updating your NPP ensures that your practice does its part to inform patients about how CRISP is being used to deliver and coordinate care and informs them of the right to opt-out. Organizations must have CRISP Opt-Out Forms on-site to distribute to patients who ask for one. They must also maintain copies of the Patient Factsheets in registration areas. New users without an agreement will be prompted to upload one during the onboarding process."),
# 	QuestionAnswer(question="How do I get started?", answer="In addition, an organization must: \n-Submit LabCorp/Quest Data Release Form if it wishes to make its results available in CRISP.\n-Send a patient panel (all patients seen within the last 18 months) using DIRECT secure emailto panelupload@crispdirect.org. This patient list will be used to audit your organization’s future activity to help ensure appropriate use. Here is a sample [Patient Panel](https://crisphealth.org/wp-content/uploads/2016/03/Sample-ENS-Patient-Panel.xlsx).\n-In order to send us your panel, please email [support@crisphealth.org](panelupload@crispdirect.org) to request a DIRECT secure email address. Then, follow these [instructions](https://crisphealth.org/wp-content/uploads/2016/03/How-to-Upload-ENS-Participant-Panel.pdf)"),
# 	QuestionAnswer(question="Who Contributes Health Information?", answer="All 50 acute-care hospitals in Maryland and 7 acute care hospitals in DC are working with CRISP. To see which of these hospitals are currently providing clinical information, visit our Participating Providers page [here](https://crisphealth.org/about-crisp/connected-providers/).")
# ]

# Every non-blank line of the training file becomes one sample framed as
# "<|startoftext|>...<|endoftext|>", the same bos/eos framing that
# DataGenerator.create_data_line produces. The EOS marker teaches the model where
# a sample ends, so generation can stop before max_length.
TRAINING_FILE = 'training.txt'
BOS_TOKEN = "<|startoftext|>"
EOS_TOKEN = "<|endoftext|>"
SHUFFLE_SEED = 0  # fixed so the sample order (and hence the train/eval split) is reproducible

data_lines = []
with open(TRAINING_FILE, 'r', encoding="UTF-8") as f:
	for raw_line in f:
		text = raw_line.rstrip("\n")
		# Blank lines would otherwise turn into empty "<|startoftext|>" samples.
		if not text.strip():
			continue
		# Only add the markers the file does not already provide, to avoid duplicates.
		if not text.startswith(BOS_TOKEN):
			text = BOS_TOKEN + text
		if not text.endswith(EOS_TOKEN):
			text += EOS_TOKEN
		data_lines.append(text)

print(f":: Loaded {len(data_lines)} samples from {TRAINING_FILE}")
for preview in data_lines[:3]:
	print(preview)

# Local RNG: reproducible order without touching the global `random` state.
random.Random(SHUFFLE_SEED).shuffle(data_lines)

In [None]:
from transformers.models.gpt2.modeling_gpt2 import GPT2PreTrainedModel
# Start from the stock GPT-2 checkpoint; `model_type` selects the size suffix.
tokenizer = GPT2Tokenizer.from_pretrained(f"gpt2{model_type}")

model = GPT2LMHeadModel.from_pretrained(f"gpt2{model_type}")

# Register the bos/eos markers used to frame the training text.
special_tokens_dict = {
    "bos_token": "<|startoftext|>",
    "eos_token": "<|endoftext|>",
    "additional_special_tokens": [
        "<|endoftext|>",
        "<|startoftext|>"
    ]
}

print(tokenizer.eos_token)

num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)

print("We have added", num_added_toks, "tokens")

print(tokenizer.eos_token)

# Notice: resize_token_embeddings expects the full size of the new vocabulary,
# i.e. the length of the tokenizer, not the number of tokens just added.
model.resize_token_embeddings(len(tokenizer))

# Persist, then reload, so the rest of the notebook works from the saved artifacts.
model.save_pretrained(model_output_dir)

tokenizer.save_pretrained(tokenizer_path)

model = GPT2LMHeadModel.from_pretrained(model_output_dir)

tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_path)

# model.cuda() raises on machines without a GPU; fall back to CPU instead.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

In [None]:
# Seeded generator so the train/eval split is reproducible.
generator = torch.Generator()

generator.manual_seed(0)

print(f":: Total Number Of Samples {len(data_lines)}")

if not data_lines:
	raise ValueError("No training samples were loaded; check the training data file.")

# Longest tokenized sample, capped at the model's context window: GPT-2 has a fixed
# number of position embeddings (config.n_positions), and longer inputs fail in the
# forward pass instead of training.
longest_sample = max(len(tokenizer.encode(prompt)) for prompt in data_lines)
max_length = min(longest_sample, model.config.n_positions)

print(f":: Max Length Of Sample {longest_sample}")
if longest_sample > max_length:
	print(f":: Truncating over-long samples to the model context of {max_length} tokens")

# truncation=True only affects samples longer than max_length, i.e. those that would
# otherwise exceed the context window; shorter samples are tokenized unchanged.
dataset = CustomDataset(data_lines, tokenizer, max_length=max_length, truncation=True)

train_size = int(0.9 * len(dataset))

train_dataset, eval_dataset = random_split(dataset, [train_size, len(dataset) - train_size], generator=generator)

In [None]:
# Pass every option to the constructor so TrainingArguments.__post_init__ validates
# and derives dependent settings. Assigning attributes after construction skips that
# step, and some transformers releases disallow it entirely.
training_args = TrainingArguments(
	output_dir=model_output_dir,
	num_train_epochs=5,
	per_device_train_batch_size=1,
	per_device_eval_batch_size=1,
	logging_steps=50,
	save_steps=1000,
	weight_decay=0.0,
	logging_dir='./logs',
	# fp16 mixed precision needs a GPU; gate it so CPU-only runs can still start.
	fp16=torch.cuda.is_available(),
	auto_find_batch_size=True,
	gradient_accumulation_steps=50,
	learning_rate=1e-4,
)

In [None]:
from torch.nn.utils.rnn import pad_sequence


def causal_lm_collator(batch):
	"""Collate ``(input_ids, attention_mask)`` pairs from CustomDataset into a batch.

	Sequences are right-padded to the longest one in the batch:
	- input_ids with 0 (those positions are masked out anyway),
	- attention_mask with 0,
	- labels with -100, the index the LM loss ignores.
	For a batch of one item this is identical to stacking that item.
	"""
	input_ids = [ids for ids, _ in batch]
	attention_masks = [mask for _, mask in batch]
	return {
		'input_ids': pad_sequence(input_ids, batch_first=True, padding_value=0),
		'attention_mask': pad_sequence(attention_masks, batch_first=True, padding_value=0),
		# Labels mirror input_ids; GPT2LMHeadModel shifts them internally for next-token loss.
		'labels': pad_sequence(input_ids, batch_first=True, padding_value=-100),
	}


trainer: Trainer = Trainer(
	model=model,
	args=training_args,
	train_dataset=train_dataset,
	eval_dataset=eval_dataset,
	data_collator=causal_lm_collator,
)

In [None]:
# Fine-tune, then write the final weights into the same directory as the tokenizer files.
trainer.train()
trainer.save_model(output_dir=model_output_dir)

In [None]:
import time

# Reload the saved artifacts from disk so generation below uses exactly what was written.
model = GPT2LMHeadModel.from_pretrained(model_output_dir)
tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_path)

In [None]:
import re
# The prompt is only the BOS marker, so the model generates each whole sample itself.
question = "<|startoftext|>"

prompt = f"{question}"

num_samples = 5

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model.to(device)

# add_special_tokens=False: the prompt already spells out the BOS marker explicitly.
generation_prompt = tokenizer(prompt, add_special_tokens=False, return_tensors="pt").to(device)

inputs = generation_prompt["input_ids"]

attention_mask = generation_prompt["attention_mask"]

sample_outputs = model.generate(inputs=inputs,
                                attention_mask=attention_mask,
                                do_sample=True,
                                top_k=0,  # 0 disables top-k, leaving only nucleus (top-p) filtering
                                top_p=0.95,
                                max_length=1024,
                                num_return_sequences=num_samples,
                                repetition_penalty=1.1,
                                # GPT-2 has no pad token; make the eos fallback explicit.
                                pad_token_id=tokenizer.eos_token_id)

for sample_output in sample_outputs:
    result = tokenizer.decode(sample_output, skip_special_tokens=True)
    print(result)