# Generate our Dataset

We need to filter out the data that is irrelevant to us, as well as any admissions that exceed the limits we have set for this experiment.

We will start by removing notes flagged as errors, and then any admission with corrections to its discharge summary, i.e. any admission that has more than one discharge summary.


In [None]:
import pandas as pd

In [None]:
# Load the full note-events table into memory.
data = pd.read_csv(filepath_or_buffer="./data/NOTEEVENTS.csv")

In [None]:
# Report how many notes carry each ISERROR value, then keep only the rows
# where ISERROR is missing (i.e. not flagged) and drop the flag column.
errors = data["ISERROR"].value_counts()
print(errors)

data = data.loc[data["ISERROR"].isna()].drop(columns=["ISERROR"])
data

In [None]:
# Admissions with exactly one discharge summary. keep=False drops *every*
# row of a duplicated HADM_ID, so admissions with two or more summaries are
# removed entirely instead of being deduplicated.
summaries = data.loc[data["CATEGORY"] == "Discharge summary"].drop_duplicates(
    subset="HADM_ID", keep=False
)
admissions = summaries["HADM_ID"].unique()
len(admissions)

In [None]:
# Keep every note (not only the summaries) of the retained admissions and
# persist this intermediate result.
data = data.loc[data["HADM_ID"].isin(admissions)]
data.to_csv(path_or_buf="./data/single-discharge-all.csv", index=False)

Now, we will handle long text sequences. We will cut out any admission whose notes (excluding the discharge summary itself) add up to more than 7942 tokens, which is 250 tokens less than 8192, the context window for Mistral.

Due to memory limitations, this should be run separately from the first few steps, or you risk running out of memory again.


In [None]:
# Reload the intermediate CSV so this half can run without the raw table.
data = pd.read_csv(filepath_or_buffer="./data/single-discharge-all.csv")
data

In [None]:
# Distinct admission ids present in the reloaded data, in order of appearance.
admissions = pd.unique(data["HADM_ID"])
len(admissions)

In [1]:
# Load the Gemma tokenizer; it is used below only to count how many tokens
# each admission's notes take up.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="google/gemma-7b-it"
)

In [None]:
# Flag every admission whose non-summary notes, joined one per line, tokenize
# to more than TOKEN_LIMIT tokens, then drop those admissions.
exceeding = []
total = 0
total_exceeding = 0

TOKEN_LIMIT = 7942

# Group the non-summary notes by admission once, instead of re-scanning the
# whole frame for every admission (previously O(admissions * rows)).
# groupby keeps the original row order inside each group.
non_summary = data[data["CATEGORY"] != "Discharge summary"]
notes_by_admission = non_summary.groupby("HADM_ID")["TEXT"].apply(list)
del non_summary  # only the grouped lists are needed; keep peak memory down

for admission in admissions:
    # Admissions with no non-summary notes get an empty list, as before.
    notes = notes_by_admission.get(admission, [])

    # Same text as appending `note + "\n"` per note, without quadratic `+=`.
    text = "".join(note + "\n" for note in notes)

    tokens = tokenizer(text, return_tensors="pt").input_ids[0]

    if len(tokens) > TOKEN_LIMIT:
        exceeding.append(admission)
        total_exceeding += 1

    total += 1

    if total % 1000 == 0:
        print(f"Processed {total} admissions, {total_exceeding} exceeding")

print(f"Total admissions: {total}, exceeding: {total_exceeding}")
# Set membership keeps this filter linear instead of O(n * len(exceeding)).
exceeding_set = set(exceeding)
admissions = [admission for admission in admissions if admission not in exceeding_set]
len(admissions)

In [None]:
# Keep only the admissions under the token limit and save them.
data = data.loc[data["HADM_ID"].isin(admissions)]
data.to_csv(path_or_buf="./data/single-discharge-8k.csv", index=False)

We can now remove all the examples that have no notes associated with them besides the discharge summary. These would be of no help to our program.


In [None]:
# Drop admissions whose only note is the discharge summary: there would be
# nothing to summarize for them.
data = pd.read_csv("./data/single-discharge-8k.csv")

admissions = data["HADM_ID"].unique()

# Admissions that have at least one note besides their discharge summary,
# computed once rather than filtering the whole frame per admission.
with_notes = set(data.loc[data["CATEGORY"] != "Discharge summary", "HADM_ID"])

# Admissions with no non-summary notes. (The previous loop also did a stray
# `total += 1` on an unrelated counter from another cell, which raised
# NameError when this cell was run in a fresh kernel.)
empty = [admission for admission in admissions if admission not in with_notes]

print(f"Amount of admissions before: {len(admissions)}")
empty_set = set(empty)  # O(1) membership for the filter below
admissions = [admission for admission in admissions if admission not in empty_set]
print(f"Amount of admissions after: {len(admissions)}")

In [None]:
# Keep admissions that still have at least one non-summary note and
# overwrite the 8k file with the result.
data = data.loc[data["HADM_ID"].isin(admissions)]
data.to_csv(path_or_buf="./data/single-discharge-8k.csv", index=False)

The final step is to separate this dataset into a training and a testing set. Of the roughly 30k remaining admissions, 1k will be reserved for testing and the rest will be used for training.


In [None]:
import random

# Shuffle the admission ids with a fixed seed (seeding the global RNG) and
# hold out the first 1000 as the test split; the rest is training data.
data = pd.read_csv("./data/single-discharge-8k.csv")
admissions = data["HADM_ID"].unique()

random.seed(42)
random.shuffle(admissions)  # shuffles in place

test, train = admissions[:1000], admissions[1000:]
len(train), len(test)

In [None]:
# Materialise both splits from the chosen admission ids, then save them.
train_data = data.loc[data["HADM_ID"].isin(train)]
test_data = data.loc[data["HADM_ID"].isin(test)]

train_data.to_csv(path_or_buf="./data/single-discharge-8k-train.csv", index=False)
test_data.to_csv(path_or_buf="./data/single-discharge-8k-test.csv", index=False)

As an extra step for the one-shot approaches, we will retrieve one sample from the training set to serve as our guide.


In [None]:
sample = 101648.0  # obtained via random.choice()

# Look the sample up in the training split only, so the one-shot example can
# never be a test admission (which would leak test data into the prompt).
sample_notes = train_data[train_data["HADM_ID"] == sample]
if sample_notes.empty:
    raise ValueError(
        f"Admission {sample} is not in the training split; "
        "pick a new sample from `train`."
    )
sample_notes

In [None]:
# Text of the sample's discharge summary (the first one, if several exist).
summary = sample_notes.loc[
    sample_notes["CATEGORY"] == "Discharge summary", "TEXT"
].iloc[0]
print(summary)

In [None]:
# All of the sample's other notes, concatenated one per line.
notes = "\n".join(
    sample_notes.loc[sample_notes["CATEGORY"] != "Discharge summary", "TEXT"].tolist()
)
print(notes)

## Update

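After experimenting with the original 8k dataset, it proved to be too much for our hardware to handle. We were therefore forced to reduce the maximum size even further, so that each sample fits in memory. We set the maximum size to 7600 tokens so as not to cut out too much data. Unlike the earlier filter, this limit applies to the complete formatted prompt: the instructions, the notes and, for training examples, the summary.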


In [None]:
from transformers import AutoTokenizer

# Reload the formatted train/test splits.
# NOTE(review): these CSVs are produced outside this section; confirm which
# step writes them.
data = pd.read_csv("./data/single-discharge-8k-train-formatted.csv")
data_test = pd.read_csv("./data/single-discharge-8k-test-formatted.csv")

# Gemma 1.1 tokenizer with left padding and BOS/EOS insertion requested, so
# the token counts below presumably include those special tokens (TODO confirm).
tokenizer = AutoTokenizer.from_pretrained(
    "google/gemma-1.1-7b-it",
    add_bos_token=True,
    add_eos_token=True,
    padding_side="left",
)

In [None]:
DEFAULT_SYSTEM_PROMPT = """
You are an expert clinical assistant. You will receive a collection of clinical notes. You will summarize them in the style of a discharge summary.
""".strip()


def generate_training_prompt_gemma(
    notes: str, summary: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT
) -> str:
    """Build one supervised training example using Gemma-style turn markers.

    The user turn contains the system prompt, the stripped notes under
    ``### Input:`` and an empty ``### Summary:`` heading. The model turn
    contains the reference summary. The result has no trailing newline.
    """
    lines = [
        f"<start_of_turn>user {system_prompt}",
        "",
        "### Input:",
        "",
        notes.strip(),
        "",
        "### Summary:",
        "",
        "<end_of_turn>",
        "<start_of_turn>model",
        f"{summary}",
        "<end_of_turn>",
    ]
    return "\n".join(lines)


# One full training prompt (notes + reference summary) per training row.
formatted_data = data.apply(
    lambda train_row: generate_training_prompt_gemma(
        notes=train_row["notes"],
        summary=train_row["summary"],
        system_prompt=DEFAULT_SYSTEM_PROMPT,
    ),
    axis=1,
)

In [None]:
# Collect every formatted training prompt that tokenizes to more than 7600
# tokens, reporting progress every 1000 entries.
big = []
total = 0
for total, entry in enumerate(formatted_data, start=1):
    n_tokens = len(tokenizer(entry, return_tensors="pt")["input_ids"][0])
    if n_tokens > 7600:
        big.append(entry)

    if not total % 1000:
        print(f"Processed {total} entries, {len(big)} big")

In [None]:
def generate_testing_prompt_gemma(
    notes: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT
) -> str:
    """Build an inference prompt: the training layout without the answer.

    The user turn contains the system prompt, the stripped notes under
    ``### Input:`` and an empty ``### Summary:`` heading. The string ends
    right after the opening ``<start_of_turn>model`` marker, with no
    trailing newline.
    """
    lines = [
        f"<start_of_turn>user {system_prompt}",
        "",
        "### Input:",
        "",
        notes.strip(),
        "",
        "### Summary:",
        "",
        "<end_of_turn>",
        "<start_of_turn>model",
    ]
    return "\n".join(lines)


# One inference prompt (no reference summary) per test row.
formatted_data_test = data_test.apply(
    lambda test_row: generate_testing_prompt_gemma(
        notes=test_row["notes"], system_prompt=DEFAULT_SYSTEM_PROMPT
    ),
    axis=1,
)

In [None]:
# Collect every formatted test prompt that tokenizes to more than 7600
# tokens, reporting progress every 1000 entries.
big_test = []
total = 0
for total, entry in enumerate(formatted_data_test, start=1):
    n_tokens = len(tokenizer(entry, return_tensors="pt")["input_ids"][0])
    if n_tokens > 7600:
        big_test.append(entry)

    if not total % 1000:
        print(f"Processed {total} entries, {len(big_test)} big")

In [None]:
# Filter the data in order to remove the big entries.
# Rebuild each row's prompt and drop the row when that prompt was flagged
# above. Sets give O(1) lookups of the (long) flagged prompts, and a boolean
# mask replaces the row-by-row `drop(..., inplace=True)` calls inside
# `iterrows`, each of which copied the whole frame.
big_prompts = set(big)
train_is_big = pd.Series(
    [
        generate_training_prompt_gemma(row_notes, row_summary, DEFAULT_SYSTEM_PROMPT)
        in big_prompts
        for row_notes, row_summary in zip(data["notes"], data["summary"])
    ],
    index=data.index,
    dtype=bool,
)
data = data.loc[~train_is_big]

big_test_prompts = set(big_test)
test_is_big = pd.Series(
    [
        generate_testing_prompt_gemma(row_notes, DEFAULT_SYSTEM_PROMPT)
        in big_test_prompts
        for row_notes in data_test["notes"]
    ],
    index=data_test.index,
    dtype=bool,
)
data_test = data_test.loc[~test_is_big]

len(data), len(data_test)

In [None]:
# Persist the size-filtered splits under the 7.6k name.
for frame, path in (
    (data, "./data/single-discharge-7.6k-train-formatted.csv"),
    (data_test, "./data/single-discharge-7.6k-test-formatted.csv"),
):
    frame.to_csv(path, index=False)