# Data Preparation Notebook

## Background

[HemOnc.org](https://hemonc.org/wiki/Main_Page) is the largest freely available medical wiki of interventions, regimens, and general information relevant to the fields of hematology and oncology. It is designed for easy use and intended for healthcare professionals.

For data professionals, the HemOnc team has released its [ontology](https://hemonc.org/wiki/Ontology), which is freely available for academic and non-commercial use via the [HemOnc Dataverse](https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/FPO4HB).

## About this Notebook

This notebook processes the HemOnc ontology described above to produce a training dataset for fine-tuning an LLM to answer questions about oncology drugs. The methodology is based on [GLaM: Fine-Tuning Large Language Models for Domain Knowledge Graph Alignment via Neighborhood Partitioning and Generative Subgraph Encoding](https://arxiv.org/pdf/2402.06764).

### Requirements to run this notebook

This Notebook requires the following files from [HemOnc Dataverse](https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/FPO4HB):
1) concept_stage.tab
2) concept_relationship_stage.tab

The notebook also expects those files to be saved as .csv and renamed as follows:
1) concept_stage.tab -> concept.csv
2) concept_relationship_stage.tab -> concept_relationship.csv
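
The conversion and renaming could be scripted rather than done by hand. The following is a minimal sketch, assuming the Dataverse `.tab` files are plain tab-delimited text with a header row (worth checking against the downloaded files); `tab_to_csv` is a hypothetical helper name, not part of this notebook.

```python
import pandas as pd


def tab_to_csv(tab_path, csv_path):
    """Read a tab-delimited file and write it back out as a comma-separated file."""
    # Assumption: the .tab export is tab-separated text with a header row.
    # dtype=str keeps codes such as "00123" from being parsed as numbers.
    table = pd.read_csv(tab_path, sep="\t", dtype=str)
    table.to_csv(csv_path, index=False)
    return table


# Hypothetical usage, matching the paths the notebook loads later:
# tab_to_csv("concept_stage.tab", "../../data/hemonc/concept.csv")
# tab_to_csv("concept_relationship_stage.tab", "../../data/hemonc/concept_relationship.csv")
```

Reading every column as a string avoids silently changing concept codes during the round trip; the notebook casts the code columns to `str` later anyway.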

## Import packages

In [1]:
import pandas as pd
from collections import defaultdict
import anthropic
import pickle

## Load Files containing the tables

In [2]:
concept_relationship = pd.read_csv("../../data/hemonc/concept_relationship.csv") 
concept = pd.read_csv("../../data/hemonc/concept.csv")

## Data Preprocessing & Cleaning

### Input Data

1) concept.csv -> Table containing one row per HemOnc concept. A HemOnc concept is a discrete, unitary piece of information, e.g., a drug, a regimen, or a diagnosis code.
2) concept_relationship.csv -> Table containing one row per relationship, either between two HemOnc concepts or between a HemOnc concept and an external vocabulary code (NDC or other codes).

### Processing & Cleaning description

The concept relationship table contains two types of relationships:
- Relationships between HemOnc concepts.
- Relationships between HemOnc concepts and NDC or other external vocabulary codes. In these cases, the external code does not refer back to any HemOnc concept.
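
To make the distinction concrete, here is a small sketch on toy data that flags each relationship by whether its target code exists in the concept table. The codes, names, and relationship labels are made up for illustration and are not taken from the HemOnc files.

```python
import pandas as pd

# Toy stand-ins for the two tables; all values are invented for illustration.
toy_concept = pd.DataFrame({
    "concept_code": ["1", "2"],
    "concept_name": ["Drug A", "Regimen B"],
})
toy_relationship = pd.DataFrame({
    "concept_code_1": ["1", "1", "2"],
    "relationship_id": ["Is part of", "Maps to", "Has component"],
    "concept_code_2": ["2", "0069-0001", "1"],
})

# A relationship is internal when its target code is itself a row of the concept table.
known_codes = set(toy_concept["concept_code"])
toy_relationship["target_is_hemonc_concept"] = toy_relationship["concept_code_2"].isin(known_codes)

internal = toy_relationship[toy_relationship["target_is_hemonc_concept"]]
external = toy_relationship[~toy_relationship["target_is_hemonc_concept"]]
print(len(internal), "internal,", len(external), "external")
```

In this toy table the two internal rows are reciprocal (1 points to 2 and 2 points back to 1), while the external row ends at a code with no outgoing relationships of its own.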

Together, the relationships form a directed graph that is largely tree-like. Because relationships can be reciprocal (A points to B and B points back to A), the graph can contain cycles, so any traversal has to keep track of the nodes it has already visited.

The goal of the data processing and cleaning is to traverse this knowledge graph and produce "graph embeddings". Here, a graph embedding is a piece of text that encodes part of the knowledge graph, not a numeric vector.

The graph embeddings will then be fed to an LLM to produce Q&A training data pairs.

To generate graph embeddings we will do the following:
1) Extract Neighborhood Subgraphs: For each node (concept) we will extract its k-hop neighborhood to capture the local structure and relationships.
   - We use $k=2$ in this notebook
2) Partition Large Subgraphs: If a subgraph exceeds a predefined node limit (N_max), we will partition it into smaller, manageable subgraphs to ensure they fit within the LLM's context window.
   - We use $N_{max}=100$ in this notebook
3) Generate Graph Embeddings: We will translate each subgraph into the textual representation described earlier.
4) Generate Q&A Pairs: For each graph embedding, we will create question and answer pairs to later train our model.
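
Step 2 can be implemented in several ways. The sketch below shows one possible approach, a greedy split of an edge list into chunks that each contain at most `n_max` distinct nodes. It is an illustration under that assumption, not the GLaM paper's partitioning method, and `partition_edges` is a hypothetical helper name.

```python
def partition_edges(edges, n_max):
    """Greedily split (source, relation, target) edges into chunks of at most n_max distinct nodes."""
    if n_max < 2:
        raise ValueError("n_max must be at least 2 so that a single edge fits in a chunk")
    partitions, current_edges, current_nodes = [], [], set()
    for source, relation, target in edges:
        new_nodes = {source, target} - current_nodes
        # Start a new chunk when adding this edge would exceed the node budget.
        if current_edges and len(current_nodes) + len(new_nodes) > n_max:
            partitions.append(current_edges)
            current_edges, current_nodes = [], set()
        current_edges.append((source, relation, target))
        current_nodes.update((source, target))
    if current_edges:
        partitions.append(current_edges)
    return partitions


toy_edges = [("A", "r", "B"), ("A", "r", "C"), ("A", "r", "D"), ("B", "r", "E")]
chunks = partition_edges(toy_edges, n_max=3)
for chunk in chunks:
    print(chunk)
```

Unlike truncation, this keeps every edge: each one lands in exactly one chunk, at the cost of producing several smaller subgraphs per node.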

#### Clean-up Knowledge tree

In [3]:
# Ensure consistent data types for merging
concept['concept_code'] = concept['concept_code'].astype(str)
concept_relationship['concept_code_1'] = concept_relationship['concept_code_1'].astype(str)
concept_relationship['concept_code_2'] = concept_relationship['concept_code_2'].astype(str)

In [4]:
# Filter out invalid relationships
valid_relationships = concept_relationship[concept_relationship['invalid_reason'].isna()]

In [5]:
# Merge to add concept names and classes for concept_code_1
relationships = valid_relationships.merge(
    concept[['concept_code', 'concept_name', 'concept_class_id']],
    left_on='concept_code_1',
    right_on='concept_code',
    how='left'
).rename(columns={
    'concept_name': 'concept_name_1',
    'concept_class_id': 'concept_class_id_1'
}).drop(columns=['concept_code'])

# Merge to add concept names and classes for concept_code_2
relationships = relationships.merge(
    concept[['concept_code', 'concept_name', 'concept_class_id']],
    left_on='concept_code_2',
    right_on='concept_code',
    how='left'
).rename(columns={
    'concept_name': 'concept_name_2',
    'concept_class_id': 'concept_class_id_2'
}).drop(columns=['concept_code'])

# Fill missing concept names with concept codes
relationships['concept_name_2'] = relationships['concept_name_2'].combine_first(relationships['concept_code_2'])

# Select relevant columns for analysis
relationships = relationships[[
    'concept_code_1', 'vocabulary_id_1', 'concept_name_1', 'concept_class_id_1',
    'relationship_id', 'concept_code_2', 'vocabulary_id_2', 'concept_name_2', 'concept_class_id_2'
]]

# Ensure consistency in concept code data types
relationships['concept_code_1'] = relationships['concept_code_1'].astype(str)

# Exclude specific concept classes
exclude_classes = [
    'ReferenceDOI', 'PubMedCentralURL', 'Study Group', 
    'Duration', 'Author', 'Study', 'Reference'
]
relationships = relationships[~relationships['concept_class_id_1'].isin(exclude_classes)]

# Some HemOnc concepts point to themselves.
# This is likely due to a curation error in the ontology.
# We remove such rows.
index_of_rows_pointing_to_themselves = relationships[relationships['concept_code_1']==relationships['concept_code_2']].index
relationships = relationships.drop(index_of_rows_pointing_to_themselves)

# Drop duplicates
relationships = relationships.drop_duplicates()

#### Extract Neighborhood Subgraphs & Partition Large Subgraphs for Drug Concepts

In [6]:
def generate_subgraphs(relationships, k=2, N_max=100, target_nodes=None):
    """
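    Generate subgraphs with a neighborhood of k hops and limit the size to N_max nodes.
    
    Parameters:
    - relationships: pd.DataFrame, the relationships table.
    - k: int, number of hops to consider.
    - N_max: int, maximum number of nodes in the subgraph.
    - target_nodes: iterable of concept codes to build subgraphs for.
      If None, every unique concept_code_1 in the table is used.
    
    Returns:
    - subgraphs: dict, mapping each node to its subgraph DataFrame
      (columns: concept_code_1, relationship_id, concept_code_2).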
    """
    # Build adjacency list for the graph
    adjacency_list = defaultdict(set)
    for _, row in relationships.iterrows():
        adjacency_list[row['concept_code_1']].add((row['concept_code_2'], row['relationship_id']))

    # If no target_nodes provided, use all unique nodes in the table
    if target_nodes is None:
        target_nodes = relationships['concept_code_1'].unique()

    # Generate subgraphs
    subgraphs = {}
    
    for node in target_nodes:
        if node not in adjacency_list:
            continue  # Skip if the node is not in the adjacency list

        # Perform BFS to gather neighbors up to k hops
        visited = set()
        queue = [(node, 0)]  # (current_node, current_depth)
        subgraph_edges = []
        
        while queue:
            current_node, depth = queue.pop(0)
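            # Do not expand nodes at depth k: their outgoing edges would reach
            # nodes k + 1 hops away from the starting node
            if depth >= k or current_node in visited:
                continue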
            
            visited.add(current_node)
            for neighbor, relation in adjacency_list[current_node]:
                subgraph_edges.append((current_node, relation, neighbor))
                if neighbor not in visited:
                    queue.append((neighbor, depth + 1))
        
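        # Limit subgraph size: if more than N_max distinct target nodes were collected,
        # keep only the first N_max edges in traversal order. Note that this truncates
        # the subgraph rather than partitioning it into several smaller subgraphs.
        unique_nodes = {edge[2] for edge in subgraph_edges}  # Unique target nodes (concept_code_2)
        if len(unique_nodes) > N_max:
            subgraph_edges = subgraph_edges[:N_max]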
        
        # Store subgraph as DataFrame for convenience
        subgraphs[node] = pd.DataFrame(subgraph_edges, columns=['concept_code_1', 'relationship_id', 'concept_code_2'])
    
    return subgraphs

In [7]:
subgraphs = generate_subgraphs(relationships)

In [None]:
example_node = list(subgraphs.keys())[0]
print(f"Subgraph for node {example_node}:\n", subgraphs[example_node])
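
To see how a k-hop traversal behaves on a graph with a reciprocal edge, the following self-contained sketch runs a breadth-first search on a four-edge toy graph. It is a compact variant written for illustration, using `collections.deque` for the queue; `k_hop_edges` and the toy edge labels are hypothetical, and it is not a drop-in replacement for `generate_subgraphs`.

```python
from collections import defaultdict, deque


def k_hop_edges(edges, start, k):
    """Return the edges reachable from `start` by following at most k directed hops."""
    adjacency = defaultdict(list)
    for source, relation, target in edges:
        adjacency[source].append((relation, target))

    visited = {start}
    queue = deque([(start, 0)])
    collected = []
    while queue:
        node, depth = queue.popleft()
        if depth == k:
            continue  # Nodes k hops away are kept as targets but not expanded further
        for relation, target in adjacency[node]:
            collected.append((node, relation, target))
            if target not in visited:
                visited.add(target)
                queue.append((target, depth + 1))
    return collected


toy_edges = [
    ("A", "r1", "B"),
    ("B", "r2", "A"),  # reciprocal edge, which creates a cycle
    ("B", "r3", "C"),
    ("C", "r4", "D"),
]
result = k_hop_edges(toy_edges, "A", k=2)
print(result)
```

With `k=2` the edge from C to D is left out because D is three hops from A, and the `visited` set keeps the A/B cycle from being walked more than once.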

#### Generate Graph Embeddings

In [10]:
def encode_subgraph_to_text(subgraphs, concept):
    """
    Encode subgraphs into the format shown in the GLAM paper.
    
    Parameters:
    - subgraphs: dict, mapping each node to its subgraph DataFrame.
    - concept: pd.DataFrame, the hemonc concept table.

    Returns:
    - glam_encoded_texts: dict, mapping each node to its GLAM-formatted text representation.
    """
    glam_encoded_texts = {}

    # Create a dictionary to map concept codes to their names and attributes
    concept_details_df = concept[['concept_code', 'concept_name', 'concept_class_id']]
    concept_details = concept_details_df.set_index('concept_code').to_dict('index')

    for node, subgraph in subgraphs.items():
        if node not in concept_details:
            continue  # Skip nodes not in the concept details

        # Store concept name of node
        node_concept_name = concept_details[node]['concept_name']

        # Group targets by (source concept, relationship) so that multi-hop edges
        # are attributed to their actual source rather than to the root node
        relationship_groups = {}
        for _, row in subgraph.iterrows():
            relationship = row['relationship_id']
            source_name = concept_details.get(row['concept_code_1'], {}).get('concept_name', row['concept_code_1'])
            target_name = concept_details.get(row['concept_code_2'], {}).get('concept_name', row['concept_code_2'])

            group_key = (source_name, relationship)
            if group_key not in relationship_groups:
                relationship_groups[group_key] = []
            relationship_groups[group_key].append(target_name)
        
        # Generate one sentence per (source concept, relationship) pair
        grouped_sentences = []
        for (source_name, relationship), targets in relationship_groups.items():
            target_list = ",".join(sorted(set(targets)))  # Combine and deduplicate targets
            grouped_sentences.append(f"{source_name}, [{relationship}], {target_list}.")
        
        # Combine sentences into GLAM summary
        adjacency_list_summary = " ".join(grouped_sentences)
        glam_encoded_texts[node] = f"{adjacency_list_summary}"
    
    return glam_encoded_texts

In [11]:
# Encode subgraphs into text
encoded_texts = encode_subgraph_to_text(subgraphs, concept)

In [None]:
# Print example GLAM-formatted text for a node
example_node = list(encoded_texts.keys())[0]
print(f"GLAM Encoded Text for node {example_node}:\n", encoded_texts[example_node])
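
The text format is easier to follow on a tiny example. The sketch below encodes a three-edge toy subgraph into the same "name, [relationship], targets." pattern; the names, codes, and relationship labels are invented, and grouping with `DataFrame.groupby` is a choice made for this illustration rather than the function's own implementation.

```python
import pandas as pd

# Invented names and codes; "0069-0001" stands in for an external code with no concept name.
toy_names = {"1": "Drug A", "2": "Brand X", "3": "Brand Y"}
toy_subgraph = pd.DataFrame(
    [("1", "Has brand name", "2"), ("1", "Has brand name", "3"), ("1", "Maps to", "0069-0001")],
    columns=["concept_code_1", "relationship_id", "concept_code_2"],
)

sentences = []
# sort=False keeps the groups in order of first appearance.
for (source, relationship), group in toy_subgraph.groupby(
    ["concept_code_1", "relationship_id"], sort=False
):
    # Fall back to the raw code when a target has no known name.
    targets = sorted({toy_names.get(code, code) for code in group["concept_code_2"]})
    sentences.append(f"{toy_names.get(source, source)}, [{relationship}], {','.join(targets)}.")

encoded = " ".join(sentences)
print(encoded)
```

Each sentence lists every target of one relationship, so a node with many brand names yields one long sentence instead of one sentence per edge.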

#### Use LLM to convert encodings into more coherent representations using summarization

In [12]:
def save_to_pickle(data, file_path):
    """
    Save data to a pickle file.

    Parameters:
    - data: The data to be saved.
    - file_path: The path to the pickle file (e.g., '*.pkl').
    """
    with open(file_path, 'wb') as file:
        pickle.dump(data, file)
    print(f"Data successfully saved to {file_path}")
    
def summarize_encodings(
        encoded_texts, 
        anthropic_client, 
        system_prompt="You are a medical oncology journal editor.",
        summarization_prompt="Given a sentence you respond with a concise and accurate rewritten version. Ensure human-readable names, reduce redundancy, and include synonyms or expanded terms where appropriate.",
        save_progress=True
    ):
    """
    Summarize GLAM-formatted adjacency list encodings using an LLM.

    Parameters:
    - encoded_texts: dict, mapping each node to its GLAM-formatted adjacency list representation.
    - anthropic_client: object, anthropic client.
    - system_prompt: str, the system prompt defining the agent role.
    - summarization_prompt: str, the prompt to guide the agent for summarization.
    - save_progress: bool, if True the results collected so far are written to
      '../../data/intermediate/summarized_encodings.pkl' after every successful call.

    Returns:
    - summarized_encodings: dict, mapping each node to its summarized representation
      (None for nodes whose request failed).
    """
    summarized_encodings = {}

    for node, encoded_text in encoded_texts.items():
        print(f"Processing node: {node}...")
        # Construct the full prompt for the LLM
        input_text = f"{summarization_prompt}\n---\nSentence: {encoded_text}"

        try:
            response = anthropic_client.messages.create(
                model="claude-3-haiku-20240307", #cheapest model. For this task we can probably also use Llama models.
                max_tokens=2048,
                system=system_prompt,
                messages=[
                    {"role": "user", "content": input_text}
                ]
            )
            summarized_encodings[node] = response.content[0].text

            if save_progress:
                save_to_pickle(summarized_encodings,'../../data/intermediate/summarized_encodings.pkl')

        except Exception as e:
            print(f"Error processing node {node}: {e}")
            summarized_encodings[node] = None  # Handle errors gracefully

    return summarized_encodings

In [13]:
# Define Anthropic client
anthropic_client = anthropic.Anthropic()

In [None]:
# Summarize encoded texts
summarized_texts = summarize_encodings(encoded_texts, anthropic_client)

In [None]:
# Print the summarized encoding for an example node
example_node = list(summarized_texts.keys())[0]
print(f"Summarized encoding for node {example_node}:\n", summarized_texts[example_node])
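
Because the summarization loop can take a long time and individual requests can fail, it may be useful to resume from the progress file instead of starting over. The following is a minimal sketch of that idea; `load_checkpoint` and `pending_nodes` are hypothetical helpers that do not exist in this notebook.

```python
import os
import pickle


def load_checkpoint(file_path):
    """Return the previously saved dict, or an empty dict when no checkpoint exists."""
    if not os.path.exists(file_path):
        return {}
    with open(file_path, "rb") as file:
        return pickle.load(file)


def pending_nodes(encoded_texts, checkpoint):
    """Select the nodes that have no summary yet, including those stored as None after an error."""
    return {
        node: text
        for node, text in encoded_texts.items()
        if checkpoint.get(node) is None
    }


# Hypothetical usage with the objects defined in this notebook:
# checkpoint = load_checkpoint('../../data/intermediate/summarized_encodings.pkl')
# todo = pending_nodes(encoded_texts, checkpoint)
# checkpoint.update(summarize_encodings(todo, anthropic_client, save_progress=False))
```

Note that `summarize_encodings` starts from an empty dict and, with `save_progress=True`, writes only the newly processed nodes to the same path. When resuming, it is therefore safer to pass `save_progress=False` and save the merged dict afterwards.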

## Save results

In [None]:
save_to_pickle(summarized_texts,'../../data/intermediate/summarized_texts.pkl')