# Dataset Exploration

This notebook is split into 4 main parts, corresponding to the 4 main stages of development:

1. General Data Exploration - The initial part was dedicated to gaining an overall understanding of the data and how it is structured.
2. Common Headers - The second part was written to extract the most common headers, so zero-shot models have a structure to work with.
3. Additional Constraints - As the dataset was causing out-of-memory errors, I had to reduce it to a smaller size; this part was written to find a reasonable cutoff point.
4. Final Data Exploration - After reducing the dataset, I conduct a more in-depth exploration of the data.


### Part 1 - General Data Exploration


In [None]:
import pandas as pd

In [None]:
# MIMIC-III clinical notes table, one row per note.
data = pd.read_csv(filepath_or_buffer="./data/NOTEEVENTS.csv")

In [None]:
# Number of notes per CATEGORY as a two-column frame (CATEGORY, count).
categories = data.groupby("CATEGORY").size().reset_index(name="count")
categories

In [None]:
# Keep discharge summaries whose DESCRIPTION is "Report" (presumably excluding
# addenda) and that are short enough to fit in a prompt. The limit is in
# CHARACTERS: 16,000 characters is roughly 4k tokens at ~4 characters/token,
# i.e. approximately GPT-3.5's original context window.
MAX_SUMMARY_CHARS = 16000
is_report = (data["CATEGORY"] == "Discharge summary") & (
    data["DESCRIPTION"] == "Report"
)
discharge_summaries = data[is_report]
discharge_summaries = discharge_summaries[
    discharge_summaries["TEXT"].str.len() < MAX_SUMMARY_CHARS
]

summary_lengths = discharge_summaries["TEXT"].str.len()
summary_lengths.hist(bins=100)

# Show one representative note. Select the note whose length is CLOSEST to the
# target instead of exactly equal to it: an exact-length match may not exist,
# in which case `.iloc[0]` would raise IndexError.
TARGET_SAMPLE_CHARS = 5000
closest_idx = (summary_lengths - TARGET_SAMPLE_CHARS).abs().idxmin()
sample = discharge_summaries.loc[[closest_idx]]

# Using print to format the output
print(sample.iloc[0]["TEXT"])

In [None]:
# Inspect every note for a single patient. The patient is fixed rather than
# drawn with `data.sample()`, so the outputs below are reproducible. The
# previous random draw was immediately overwritten and only cost a pass over
# the full frame.
random_patient = 99082

notes = data[data["SUBJECT_ID"] == random_patient]
notes

In [None]:
# Print the first discharge summary among the selected patient's notes.
is_discharge = notes["CATEGORY"] == "Discharge summary"
summary = notes.loc[is_discharge]
first_summary_text = summary["TEXT"].iloc[0]
print(first_summary_text)

### Part 2 - Common Headers


In [None]:
# Figure out what the most common headings are in the discharge summaries.
# Reload from single-discharge-8k.csv and keep only its discharge summaries.
data = pd.read_csv(filepath_or_buffer="./data/single-discharge-8k.csv")
data = data.loc[data["CATEGORY"].eq("Discharge summary")]

In [None]:
import re

from collections import Counter

# Count candidate section headings: text at the start of a line up to the
# FIRST ":" that is followed by whitespace. The lazy `.+?` matters. A greedy
# `.+` runs to the LAST such colon on the line, so a line like
# "Admission Date: [..] Discharge Date: [..]" would be recorded as one
# note-specific "heading" instead of "admission date:".
regex = re.compile(r"^.+?:\s", re.MULTILINE)

headings = Counter()
for text in data["TEXT"]:
    for match in regex.findall(text.lower()):
        # Drop the whitespace after the colon so "heading: " and "heading:\n"
        # are counted together.
        headings[match.rstrip()] += 1

headings, len(headings)

In [None]:
# Sort by the most common headings and show the top 20.
sorted_headings = sorted(headings.items(), key=lambda item: item[1], reverse=True)

# We eliminate the first because it is standard to all discharge summaries,
# so the top 20 remaining headings are positions 1..20 inclusive.
# (A `[1:20]` slice would return only 19 of them.)
TOP_N_HEADINGS = 20
sorted_headings[1 : 1 + TOP_N_HEADINGS]

### Part 3 - Additional Constraints


In [None]:
# Formatted test split; its "notes" column is measured below.
data = pd.read_csv(filepath_or_buffer="./data/single-discharge-8k-test-formatted.csv")

In [None]:
# Histogram of note sizes (in characters) across the formatted test split.
note_char_lengths = data["notes"].apply(len)
note_char_lengths.hist(bins=100)

In [None]:
from transformers import AutoTokenizer

# Only the tokenizer is loaded here; it is used to measure prompt lengths.
tokenizer_options = dict(
    padding_side="left",
    add_eos_token=True,
    add_bos_token=True,
)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-1.1-7b-it", **tokenizer_options)

In [None]:
DEFAULT_SYSTEM_PROMPT = (
    "You are an expert clinical assistant. "
    "You will receive a collection of clinical notes. "
    "You will summarize them in the style of a discharge summary."
)


def generate_testing_prompt_gemma(
    notes: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT
) -> str:
    """Wrap clinical notes in a Gemma-style chat prompt for inference.

    The system prompt and the notes form a single user turn. It is followed by
    an opened model turn for the model to continue.

    Args:
        notes: Raw note text; leading/trailing whitespace is stripped.
        system_prompt: Instruction placed at the start of the user turn.

    Returns:
        The prompt string. It ends in "<start_of_turn>model" with no trailing
        newline.
    """
    # NOTE(review): Gemma's documented chat template puts a newline (not a
    # space) after "user" and a newline after "model". Kept as-is in case this
    # must match the format used for fine-tuning; confirm before changing.
    prompt_lines = [
        f"<start_of_turn>user {system_prompt}",
        "",
        "### Input:",
        "",
        notes.strip(),
        "",
        "<end_of_turn>",
        "<start_of_turn>model",
    ]
    return "\n".join(prompt_lines)


# Build each test prompt, then split it into Gemma tokens to measure its length.
# NOTE(review): `tokenizer.tokenize` presumably does not add the BOS/EOS
# special tokens, so these counts may be a token or two below the encoded
# length; confirm.
prompts = data["notes"].map(generate_testing_prompt_gemma)
tokens = prompts.map(tokenizer.tokenize)

In [None]:
# Token count of the longest prompt: the sequence length that must fit.
token_counts = tokens.map(len)
biggest = token_counts.idxmax()
len(tokens.loc[biggest])

### Part 4 - Final Data Exploration


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import nltk

sns.set_theme(style="whitegrid")
%matplotlib inline

#### Note size and distribution across the dataset and per admission


In [None]:
# Load all the notes in each format (original and ours).
data_original = pd.read_csv("./data/NOTEEVENTS.csv")
data_train = pd.read_csv("./data/single-discharge-7.6k-train-formatted.csv")
data_test = pd.read_csv("./data/single-discharge-7.6k-test-formatted.csv")
# ignore_index: both splits are read with a default 0-based index, and
# duplicate index labels make later index-aligned operations ambiguous.
data_together = pd.concat([data_train, data_test], ignore_index=True)

del data_train, data_test

# Add our dataset but before aggregation. Named `all_notes` rather than `all`
# so the Python builtin `all()` is not shadowed for the rest of the notebook.
all_notes = pd.read_csv("./data/single-discharge-all.csv")

# Get only the notes from all_notes whose admission is in our dataset.
admissions = data_together["admission"].unique().tolist()

# .copy() makes this an independent frame, so the column assignments in later
# cells do not write into a slice of `all_notes` (SettingWithCopyWarning).
all_admissions = all_notes[all_notes["HADM_ID"].isin(admissions)].copy()

In [None]:
# Word counts of the aggregated notes and of the target summaries, one row per
# row of data_together.
data_together_len = pd.DataFrame(
    {
        column: data_together[column].map(nltk.word_tokenize).map(len)
        for column in ("notes", "summary")
    }
)

In [None]:
# Per-note word ("len") and sentence ("sen") counts for our notes before
# aggregation. `.assign` returns a new frame, so this never writes into a
# slice of another frame (no SettingWithCopyWarning), and re-running the cell
# is idempotent.
all_admissions = all_admissions.assign(
    len=all_admissions["TEXT"].map(nltk.word_tokenize).map(len),
    sen=all_admissions["TEXT"].map(nltk.sent_tokenize).map(len),
)

In [None]:
# Repeat for MIMIC-III dataset: per-note word ("len") and sentence ("sen")
# counts, added as columns in place.
note_text = data_original["TEXT"]
data_original["len"] = note_text.map(nltk.word_tokenize).map(len)
data_original["sen"] = note_text.map(nltk.sent_tokenize).map(len)

In [None]:
# Split MIMIC-III notes into discharge summaries and all other categories.
is_discharge_summary = data_original["CATEGORY"].eq("Discharge summary")
data_original_notes = data_original[~is_discharge_summary]
data_original_summary = data_original[is_discharge_summary]

In [None]:
# Word-count distributions of notes and summaries side by side, in 100-word bins.
ax = sns.histplot(data=data_together_len, binwidth=100, multiple="dodge")
ax.set(ylabel="Amount of Documents", xlabel="")
# NOTE(review): without a `visible` argument `Axes.grid` TOGGLES the grid, so
# under the "whitegrid" theme this likely turns the y-axis gridlines off.
ax.grid(axis="y")

In [None]:
# Repeat for MIMIC-III dataset. The two Series come from complementary subsets
# of data_original, so their indexes are disjoint and the frame is NaN-padded;
# histplot presumably drops the missing values.
data_original_len = pd.DataFrame(
    dict(notes=data_original_notes["len"], summary=data_original_summary["len"])
)
ax = sns.histplot(data=data_original_len, multiple="dodge")
ax.set(ylabel="Amount of Documents", xlabel="")
# NOTE(review): this toggles the x-axis grid, while the plot for our dataset
# toggles the y-axis grid; confirm which is intended.
ax.grid(axis="x")

In [None]:
# Percentage of notes per category in our dataset, most frequent first.
category_counts = all_admissions["CATEGORY"].value_counts()
order = category_counts.sort_values(ascending=False).index
ax = sns.countplot(data=all_admissions, x="CATEGORY", order=order, stat="percent")
ax.tick_params(axis="x", labelrotation=90)
ax.set(ylabel="", xlabel="")

In [None]:
# Repeat for MIMIC-III dataset: share of notes per category, most frequent first.
category_counts = data_original["CATEGORY"].value_counts()
order = category_counts.sort_values(ascending=False).index
ax = sns.countplot(data=data_original, x="CATEGORY", order=order, stat="percent")
ax.tick_params(axis="x", labelrotation=90)
ax.set(ylabel="", xlabel="")

In [None]:
# Distribution of how many notes of each category one admission has (ours).
notes_per_category_all = (
    all_admissions.groupby(["CATEGORY", "HADM_ID"])
    .size()
    .rename("note_count")
    .reset_index()
)
ax = sns.boxplot(data=notes_per_category_all, x="CATEGORY", y="note_count")
ax.tick_params(axis="x", labelrotation=90)
ax.set(xlabel="", ylabel="")

In [None]:
# Repeat for MIMIC-III dataset, with a log-scaled count axis.
notes_per_category_original = (
    data_original.groupby(["CATEGORY", "HADM_ID"])
    .size()
    .rename("note_count")
    .reset_index()
)
ax = sns.boxplot(
    data=notes_per_category_original, x="CATEGORY", y="note_count", log_scale=True
)
ax.tick_params(axis="x", labelrotation=90)
ax.set(xlabel="", ylabel="")

In [None]:
# Mean words per note for each category in our dataset; bars are labelled
# with the rounded mean.
avg_len_together = all_admissions.groupby("CATEGORY", as_index=False)["len"].mean()
ax = sns.barplot(data=avg_len_together, x="CATEGORY", y="len")
ax.tick_params(axis="x", labelrotation=90)
ax.bar_label(ax.containers[0], fmt="%.0f", fontsize=10)
ax.set(ylabel="", xlabel="")

In [None]:
# Repeat for MIMIC-III dataset. barplot averages "len" per category itself.
# `order` lists categories alphabetically so the bars line up with the plot
# for our dataset, whose groupby output is sorted. Without it, barplot uses
# the order in which categories first appear in the raw table.
plt.xticks(rotation=90)
ax = sns.barplot(
    data=data_original,
    x="CATEGORY",
    y="len",
    order=sorted(data_original["CATEGORY"].dropna().unique()),
    errorbar=None,
)
ax.bar_label(ax.containers[0], fontsize=10, fmt="%.0f")
ax.set(ylabel="", xlabel="")

In [None]:
# Mean sentences per note for each category in our dataset.
avg_sen_together = all_admissions.groupby("CATEGORY", as_index=False)["sen"].mean()
ax = sns.barplot(data=avg_sen_together, x="CATEGORY", y="sen")
ax.tick_params(axis="x", labelrotation=90)
ax.bar_label(ax.containers[0], fmt="%.0f", fontsize=10)
ax.set(ylabel="", xlabel="")

In [None]:
# Repeat for MIMIC-III dataset. Categories are ordered alphabetically to match
# the groupby-sorted plot for our dataset; otherwise barplot would use the
# order in which they first appear in the raw table.
plt.xticks(rotation=90)
ax = sns.barplot(
    data=data_original,
    x="CATEGORY",
    y="sen",
    order=sorted(data_original["CATEGORY"].dropna().unique()),
    errorbar=None,
)
ax.bar_label(ax.containers[0], fontsize=10, fmt="%.0f")
ax.set(ylabel="", xlabel="")

#### Determine most common words in the datasets


In [None]:
# Get collection frequency of words in the notes
import string

# NLTK resources needed by `process_text`: stop-word list and WordNet lemmas.
for corpus_name in ("wordnet", "stopwords"):
    nltk.download(corpus_name)

stop_words = set(nltk.corpus.stopwords.words("english"))
lemmatizer = nltk.stem.WordNetLemmatizer()


def process_text(text):
    """Normalise free text for word-frequency counting.

    Lower-cases the text and deletes every character in `string.punctuation`.
    It then drops words found in the module-level `stop_words` set and
    lemmatises the remaining words with the module-level `lemmatizer`.

    Args:
        text: Raw note or summary text.

    Returns:
        The processed words joined by single spaces ("" if none remain).
    """
    without_punctuation = text.lower().translate(
        str.maketrans("", "", string.punctuation)
    )
    content_words = (
        word for word in without_punctuation.split() if word not in stop_words
    )
    return " ".join(lemmatizer.lemmatize(word) for word in content_words)


def calculate_cf(data):
    """Collection frequency: total occurrences of each word across all entries.

    Args:
        data: Iterable of pre-processed strings (here, `process_text` output).

    Returns:
        dict mapping word -> occurrence count, in order of first appearance.
    """
    from collections import Counter

    # `str.split()` with no separator splits on runs of whitespace and yields
    # nothing for empty or blank strings. No "" pseudo-word is ever counted,
    # unlike `split(" ")` on doubled/leading/trailing spaces, and empty entries
    # need no special case.
    counts = Counter()
    for entry in data:
        counts.update(entry.split())
    return dict(counts)


# Normalise every note / summary, then count words over each whole collection.
notes_words = data_together["notes"].apply(process_text)
summary_words = data_together["summary"].apply(process_text)

# Re-insert the counts most-frequent first (dicts keep insertion order).
notes_cf = dict(
    sorted(calculate_cf(notes_words).items(), key=lambda item: item[1], reverse=True)
)
summary_cf = dict(
    sorted(calculate_cf(summary_words).items(), key=lambda item: item[1], reverse=True)
)

# Ten most frequent words in each collection.
list(notes_cf.items())[:10], list(summary_cf.items())[:10]

In [None]:
from sklearn.feature_extraction.text import CountVectorizer

# Document frequency: the number of notes / summaries in which each word
# appears at least once (binary=True counts a word once per document).
# CountVectorizer applies its own tokenisation (by default only tokens of 2+
# word characters), so its vocabulary can differ slightly from the
# collection-frequency counts above.
def document_frequency(documents):
    """Return a Series mapping word -> number of documents containing it, descending."""
    vectorizer = CountVectorizer(binary=True)
    matrix = vectorizer.fit_transform(documents)
    counts = matrix.sum(axis=0).A1
    frequency = pd.Series(counts, index=vectorizer.get_feature_names_out())
    return frequency.sort_values(ascending=False)


# The Series are passed directly. Reassigning `notes_words = notes_words.tolist()`
# made this cell fail with AttributeError when re-run (a list has no `.tolist`).
notes_frequency = document_frequency(notes_words)
summary_frequency = document_frequency(summary_words)

notes_frequency[:10], summary_frequency[:10]

#### Dates used in the notes (not useful, since the dates are randomized)


In [None]:
# Year of each note's CHARTDATE (the text before the first "-"), then notes per
# year. `.assign` builds a new frame instead of writing a column into what may
# be a slice of another frame (SettingWithCopyWarning); `.str` is vectorised.
all_admissions = all_admissions.assign(
    year=all_admissions["CHARTDATE"].str.split("-").str[0]
)
count_per_year = all_admissions.groupby("year").size().reset_index(name="count")

In [None]:
# Notes per (shifted) year, as a wide line plot with vertical year labels.
fig, ax = plt.subplots(figsize=(26, 6))
ax.tick_params(axis="x", labelrotation=90)
sns.lineplot(data=count_per_year, x="year", y="count", marker="o", ax=ax)