<a href="https://colab.research.google.com/github/403errors/CancerCareAI/blob/main/CancerCareAI.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## AI-powered extraction of cancer-related patient data.

NOTE: MUST RUN ON A **T4 GPU** RUNTIME BECAUSE THE BITSANDBYTES MODULE, USED FOR LOADING THE QUANTISED MODEL, REQUIRES A GPU ⚠️

NOTE: THE MODEL MAY NOT LOAD ON THE FIRST RUN. IF YOU SEE "BITSANDBYTES"-RELATED ERRORS, FOLLOW THESE STEPS:
- RUNTIME > RESTART SESSION
- RUNTIME > RUN ALL

DONE ✅



**Please check out the [GitHub Repo](https://github.com/403errors/CancerCareAI) for more details and demo videos of the project.**

## Project Setup and Data Loading
This block handles the initial setup and data loading for the project. It performs the following key functions:

- Dependency Installation: Installs necessary Python libraries using pip.

  This includes:
  - sentence-transformers: For semantic similarity calculations.
  - rank_bm25: For implementing the BM25 ranking algorithm (keyword-based retrieval).
  - pandas: For data manipulation and creating DataFrames.
  - nltk: Natural Language Toolkit, used here for sentence tokenization.
  - bitsandbytes: For loading quantized large language models.
  - accelerate: A Hugging Face library for distributed training and inference; transformers needs it to load the model with device_map="auto".
  - optimum: Hugging Face library for optimizing models.

- Library Imports: Imports the necessary modules from the installed libraries. This includes modules for semantic search, BM25 ranking, data handling, and working with large language models.

- Data Loading (load_data_from_github function):
Fetches JSON data files from a specified GitHub repository.
  - Takes the repository URL and a list of filenames as input.
  - Uses the requests library to make HTTP GET requests to retrieve each file.
  - Uses response.raise_for_status() to raise an exception on HTTP errors (e.g., 404 Not Found).
  - Parses the JSON response using response.json().
  - Returns a dictionary where keys are filenames and values are the loaded JSON data.

- NLTK Downloads: Downloads the NLTK resources needed for sentence tokenization:
  - 'punkt': the Punkt sentence tokenizer models.
  - 'punkt_tab': the table-based Punkt resource that newer NLTK releases load for sent_tokenize in place of the older pickled 'punkt' models.

### pip installation

In [None]:
# Install necessary libraries (if not already installed in your Colab environment)
!pip install -q sentence-transformers rank_bm25 pandas nltk
!pip install -q --no-cache-dir bitsandbytes
!pip install -q accelerate optimum

### Imports

In [None]:
import re
import json
import requests
from sentence_transformers import SentenceTransformer, CrossEncoder, util
import torch
from rank_bm25 import BM25Okapi
import pandas as pd
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import nltk
nltk.download('punkt')  # Download the Punkt sentence tokenizer
nltk.download('punkt_tab')  # Newer NLTK releases load punkt_tab for sent_tokenize

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


True

### Data Loading

In [None]:
def load_data_from_github(repo_url, filenames):
    """Loads JSON data files from a GitHub repository.

    Args:
        repo_url: Base URL of the GitHub repository's data directory.
        filenames: List of filenames to load.

    Returns:
        A dictionary where keys are filenames and values are the loaded JSON data.
    """
    data = {}
    for filename in filenames:
        file_url = f"{repo_url}/{filename}"
        response = requests.get(file_url, timeout=30)  # Avoid hanging indefinitely on a stalled connection
        response.raise_for_status()  # Raise an exception for bad status codes
        data[filename] = response.json()
    return data
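For offline runs, or for testing the pipeline against local copies of the same files, a loader that returns the same `{filename: parsed_json}` shape can read from disk instead of GitHub. This is a hypothetical helper and not part of the project. `load_data_from_directory` is an illustrative name, and the sketch assumes all JSON files sit in one local folder:

```python
import json
from pathlib import Path


def load_data_from_directory(data_dir, filenames):
    """Hypothetical offline counterpart of load_data_from_github.

    Reads each JSON file from a local directory and returns the same
    {filename: parsed_json} dictionary shape.
    """
    data = {}
    for filename in filenames:
        path = Path(data_dir) / filename
        with path.open(encoding="utf-8") as f:
            data[filename] = json.load(f)  # Raises if the file is missing or not valid JSON
    return data
```

Because the return shape matches, the result can be passed to the same downstream functions (e.g., create_passages) unchanged.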

## Task 1 - Information Retrieval (Pipeline)
This block implements the information retrieval pipeline, combining keyword-based and semantic search techniques. It includes the following components:

- create_passages(patient_data):
Transforms the raw JSON data (loaded in the previous block) into a list of dictionaries.
  - Each dictionary represents a single sentence extracted from the docText field of the original data. Crucially, it uses nltk.sent_tokenize to split the text into sentences, rather than treating entire documents as passages.
  - Includes metadata: docTitle, docDate, and patient_file are preserved alongside the docText (now a single sentence). This metadata is important for presenting results and filtering.

- bm25_ranking(query, passages, tokenizer_bm25):
Implements the BM25 (Best Matching 25) ranking algorithm. BM25 is a classic information retrieval algorithm. It scores each document by how often the query terms occur in it. Each term is weighted by its inverse document frequency (terms that are rarer across the corpus count more), and the score is normalized for document length.
  - Takes a query string, a list of passages (dictionaries with a "docText" key), and a tokenizer_bm25 function as input.
  - Uses the provided tokenizer_bm25 to tokenize both the query and the passages. combined_retrieval passes a simple tokenizer that lowercases the text and splits it on whitespace.
  - Creates a BM25Okapi object from the rank_bm25 library.
  - Calculates BM25 scores for the query against each passage.
  - Returns a list of (passage, score) tuples, sorted by score in descending order.

- semantic_search(query, passages, model_name="all-MiniLM-L6-v2"):
Performs semantic search using Sentence Transformers, a library for generating dense vector representations (embeddings) of text.
  - Takes a query string, a list of passages, and an optional model_name (defaulting to "all-MiniLM-L6-v2") as input.
  - Loads the specified Sentence Transformer model. "all-MiniLM-L6-v2" is a pre-trained model known for its good balance of speed and accuracy.
  - Generates embeddings for the query and all passages using model.encode().
  - Uses util.semantic_search() to efficiently find the passages with embeddings most similar to the query embedding (using cosine similarity).
  - Returns up to 10 (passage, score) tuples (the function calls util.semantic_search with top_k=10), sorted by similarity score in descending order.

- rerank_with_crossencoder(query, passages, model_name="cross-encoder/ms-marco-MiniLM-L-6-v2"):
  - Reranks a subset of passages using a CrossEncoder model, which is generally more accurate than the Bi-Encoder used in semantic_search but computationally more expensive.
  - Takes a query, a list of passages, and an optional model_name (defaulting to "cross-encoder/ms-marco-MiniLM-L-6-v2") as input.
  - Loads the specified CrossEncoder model. "cross-encoder/ms-marco-MiniLM-L-6-v2" is a model trained on the MS MARCO passage ranking dataset.
  - CrossEncoders take a (query, passage) pair as input and directly predict a relevance score. This is different from Bi-Encoders, which generate separate embeddings for the query and passage and then calculate similarity.
  - Calculates scores using model.predict().
  - Returns a list of (passage, score) tuples, sorted by score in descending order.

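- filter_sentence(sentence):
Filters out irrelevant sentences that contain administrative details.
Takes a sentence as its only argument.
  - Discards the sentence if it matches any of the exclude_patterns regular expressions (case-insensitive). The list also contains a bare ":" pattern, so every sentence containing a colon is discarded, not just the listed headers.
  - Returns True (keep the sentence) if no pattern matches.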

- combined_retrieval(query, passages, bm25_weight=0.4, semantic_weight=0.3, crossencoder_weight=0.3):
This is the core function that orchestrates the entire retrieval pipeline.
Takes a query, a list of passages, and optional weights for each ranking method as input.
  - Calls bm25_ranking and semantic_search to get initial rankings.
  - Keeps only the top top_n passages (hard-coded to 20) from each of BM25 and semantic search. Because semantic_search returns at most 10 hits, the semantic side contributes at most 10 candidates. This is a crucial optimization step: it avoids running the computationally expensive CrossEncoder on all passages.
  - Also applies filter_sentence() to remove administrative/irrelevant sentences.
  - Combines the top passages from BM25 and semantic search, removing duplicates by using a set to track unique docText values. This ensures that the same sentence isn't included multiple times.
  - Calls rerank_with_crossencoder to rerank the combined, filtered list of passages.
  - Normalizes the scores from each method (BM25, semantic search and CrossEncoder) to a range of 0 to 1. This is important because the raw scores from different methods are not directly comparable. Min-max normalization is used.
  - Combines the normalized scores using a weighted average: combined_score = bm25_score * bm25_weight + semantic_score * semantic_weight + crossencoder_score * crossencoder_weight. The weights allow for tuning the relative importance of each ranking method.
  - Returns a list of (passage, combined_score) tuples, sorted by the combined score in descending order. The passages retain all original metadata.
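The BM25 scoring described above can be written out directly. This is a from-scratch sketch for illustration; `bm25_scores` is a hypothetical name. It uses k1 = 1.5 and b = 0.75, which match the defaults of rank_bm25's BM25Okapi. It uses a non-negative IDF variant, however, which differs from BM25Okapi's IDF handling, so absolute scores will not match rank_bm25:

```python
import math
from collections import Counter


def bm25_scores(query_tokens, corpus_tokens, k1=1.5, b=0.75):
    """Scores every tokenized document against a tokenized query with Okapi BM25.

    Uses the non-negative IDF variant log(1 + (N - n + 0.5) / (n + 0.5)),
    which differs slightly from the IDF handling in rank_bm25's BM25Okapi.
    """
    n_docs = len(corpus_tokens)
    avgdl = sum(len(doc) for doc in corpus_tokens) / n_docs
    # Document frequency: in how many documents each term appears
    df = Counter(term for doc in corpus_tokens for term in set(doc))

    scores = []
    for doc in corpus_tokens:
        tf = Counter(doc)
        length_norm = k1 * (1 - b + b * len(doc) / avgdl)
        score = 0.0
        for term in query_tokens:
            if tf[term] == 0:
                continue
            idf = math.log(1 + (n_docs - df[term] + 0.5) / (df[term] + 0.5))
            # Saturating term frequency, normalized by document length
            score += idf * tf[term] * (k1 + 1) / (tf[term] + length_norm)
        scores.append(score)
    return scores


def tokenize(text):
    """Same whitespace tokenizer that combined_retrieval uses for BM25."""
    return text.lower().split()


corpus = [
    "Patient started letrozole as adjuvant therapy.",
    "Biopsy confirmed invasive ductal carcinoma.",
    "Follow-up visit scheduled in three months.",
]
print(bm25_scores(tokenize("letrozole therapy"), [tokenize(s) for s in corpus]))
```

Only the first sentence scores above zero, and only through "letrozole". The whitespace tokenizer leaves punctuation attached, so the query term "therapy" does not match the token "therapy.".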

In [None]:
def create_passages(patient_data):
    """Creates a list of SENTENCES (passages) from the patient data."""
    passages = []
    for filename, patient_records in patient_data.items():
        for record in patient_records:
            # Split the docText into sentences using nltk.sent_tokenize
            sentences = nltk.sent_tokenize(record["docText"])
            for sentence in sentences:
                passages.append({
                    "docText": sentence,  # Now, docText is a single sentence
                    "docTitle": record["docTitle"],
                    "docDate": record["docDate"],
                    "patient_file": filename
                })
    return passages


def bm25_ranking(query, passages, tokenizer_bm25):
    """Ranks passages using BM25.

    Args:
        query: The search query (string).
        passages: A list of dictionaries; each must contain the "docText" key.
        tokenizer_bm25: A tokenizer suitable for BM25 (e.g., splitting on spaces).

    Returns:
        List of (passage, score) tuples, sorted by score (highest first).
    """
    tokenized_corpus = [tokenizer_bm25(p["docText"]) for p in passages]
    bm25_model = BM25Okapi(tokenized_corpus)
    tokenized_query = tokenizer_bm25(query)
    doc_scores = bm25_model.get_scores(tokenized_query)
    # Combine passages and scores
    passage_scores = list(zip(passages, doc_scores))
    # Sort by score (descending)
    passage_scores.sort(key=lambda x: x[1], reverse=True)
    return passage_scores


def semantic_search(query, passages, model_name="all-MiniLM-L6-v2"):
    """Performs semantic search using Sentence Transformers.

    Args:
        query: The search query.
        passages: A list of dictionaries, where each dictionary must contain at least "docText".
        model_name: The Sentence Transformer model to use.

    Returns:
        A list of up to 10 (passage, score) tuples (top_k=10), sorted by similarity score (highest first).
    """
    model = SentenceTransformer(model_name)
    corpus_embeddings = model.encode([p["docText"] for p in passages], convert_to_tensor=True)
    query_embedding = model.encode(query, convert_to_tensor=True)
    hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=10)[0] #top_k can be adjusted
    results = []
    for hit in hits:
        results.append((passages[hit['corpus_id']], hit['score']))
    return results

def rerank_with_crossencoder(query, passages, model_name="cross-encoder/ms-marco-MiniLM-L-6-v2"):
    """Reranks passages using a CrossEncoder model.

    Args:
      query: the search query.
      passages: A list of dictionaries; must contain the "docText" key.
      model_name: the CrossEncoder model to use.

    Returns:
      List of (passage, score) tuples, sorted by score (highest first).
    """
    model = CrossEncoder(model_name)
    scores = model.predict([(query, p["docText"]) for p in passages])
    passage_scores = list(zip(passages, scores))
    passage_scores.sort(key=lambda x: x[1], reverse=True)
    return passage_scores

def filter_sentence(sentence):
    """Filters out sentences that are likely administrative or irrelevant."""
    # List of patterns to exclude (case-insensitive)
    exclude_patterns = [
        r"Patient Name:",
        r"Date of Birth:",
        r"Location:",
        r"Purpose:",
        r"Follow-up:",
        r"Conclusion:",
        r"Next Steps:",
        r"Date of Visit:",
        r":", # Remove all the sentences that contain ":"
    ]
    for pattern in exclude_patterns:
        if re.search(pattern, sentence, re.IGNORECASE):
            return False  # Discard the sentence
    return True  # Keep the sentence

def combined_retrieval(query, passages, bm25_weight=0.4, semantic_weight=0.3, crossencoder_weight=0.3):
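    """Combines BM25, semantic search, and cross-encoder reranking.

    Args:
        query: The search query.
        passages: A list of passage dictionaries (each must contain "docText").
        bm25_weight: Weight of the normalized BM25 score.
        semantic_weight: Weight of the normalized semantic-search score.
        crossencoder_weight: Weight of the normalized cross-encoder score.

    Returns:
        List of (passage, combined_score) tuples.
    """
    # Simple tokenizer for BM25 (lowercase, then split on whitespace)
    tokenizer_bm25 = lambda text: text.lower().split()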

    # 1. BM25 Ranking
    bm25_results = bm25_ranking(query, passages, tokenizer_bm25)

    # 2. Semantic Search
    semantic_results = semantic_search(query, passages)

    # 3.  Filter to top N from BM25 and Semantic Search before Cross-Encoding
    top_n = 20  # Adjust as needed

    # --- Filtering before Cross-Encoding ---
    bm25_top_n = [passage for passage, _ in bm25_results[:top_n] if filter_sentence(passage["docText"])]
    semantic_top_n = [passage for passage, _ in semantic_results[:top_n] if filter_sentence(passage["docText"])]

    # Use a set to track unique docText values
    unique_doc_texts = set()
    combined_top_passages = []

    for passage in bm25_top_n + semantic_top_n:
        doc_text = passage["docText"]  # Extract the unique text identifier
        if doc_text not in unique_doc_texts:
            unique_doc_texts.add(doc_text)
            combined_top_passages.append(passage)  # Append the full passage dict

    # 4. Cross-Encoder Reranking (on the combined top passages)
    crossencoder_results = rerank_with_crossencoder(query, combined_top_passages)


    # 5. Normalize and Combine Scores (using a dictionary for easier lookup)
    def normalize_scores(results):
        if not results:
            return {}
        scores = [score for _, score in results]
        min_score = min(scores)
        max_score = max(scores)
        if max_score == min_score:  # Avoid division by zero
            return {passage["docText"]: 0.5 for passage, _ in results}  #Give them all a neutral score
        return {passage["docText"]: (score - min_score) / (max_score - min_score) for passage, score in results}

    bm25_scores = normalize_scores(bm25_results)
    semantic_scores = normalize_scores(semantic_results)
    crossencoder_scores = normalize_scores(crossencoder_results)

    # Combine (using docText as the key, since it's unique within the same query)
    combined_scores = {}
    for passage, _ in crossencoder_results:  # Iterate through crossencoder results as the base
        doc_text = passage["docText"]
        combined_score = (
            bm25_scores.get(doc_text, 0) * bm25_weight +  # Use .get() to handle missing keys
            semantic_scores.get(doc_text, 0) * semantic_weight +
            crossencoder_scores.get(doc_text, 1) * crossencoder_weight # crossencoder_weight default to 1 as it contains all.
        )
        combined_scores[doc_text] = combined_score

    # Convert back to a list of (passage, score) tuples, preserving passage data,
    # and rank by the combined score (highest first)
    final_results = [(passage, combined_scores[passage["docText"]]) for passage, _ in crossencoder_results]
    final_results.sort(key=lambda x: x[1], reverse=True)

    return final_results
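The score fusion in combined_retrieval can be reproduced in isolation. The sketch below is a simplified stand-in: `min_max_normalize` and `fuse` are hypothetical names, and the scores are made-up toy values, not pipeline output. It keys scores by sentence text and normalizes each method's scores separately. Like the pipeline, it only ranks candidates that reached the cross-encoder:

```python
def min_max_normalize(scores):
    """Rescales a {key: score} dict to [0, 1]; if all scores are equal, each maps to 0.5."""
    if not scores:
        return {}
    lo, hi = min(scores.values()), max(scores.values())
    if hi == lo:
        return {key: 0.5 for key in scores}
    return {key: (s - lo) / (hi - lo) for key, s in scores.items()}


def fuse(bm25, semantic, cross, weights=(0.4, 0.3, 0.3)):
    """Weighted sum of normalized scores; a score missing from a method counts as 0."""
    w_bm25, w_sem, w_cross = weights
    bm25_n, sem_n, cross_n = (min_max_normalize(s) for s in (bm25, semantic, cross))
    fused = {
        key: bm25_n.get(key, 0) * w_bm25 + sem_n.get(key, 0) * w_sem + cross_n[key] * w_cross
        for key in cross_n  # Only cross-encoder candidates are ranked
    }
    return sorted(fused.items(), key=lambda kv: kv[1], reverse=True)


# Toy raw scores keyed by sentence ("c" is missing from the semantic hits)
bm25 = {"a": 7.0, "b": 2.0, "c": 0.0}
semantic = {"a": 0.42, "b": 0.61}
cross = {"a": 1.3, "b": 4.5, "c": -8.0}
print(fuse(bm25, semantic, cross))
```

With these numbers "b" ranks first even though BM25 prefers "a", because "b" has the highest semantic and cross-encoder scores. Min-max normalization is relative to each list, so the weakest candidate in any list gets 0 regardless of its absolute score.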

## Task 2 - Medical Data Extraction (LLM-based Pipeline)

This is where we use the Qwen2.5-7B-Instruct-1M model to extract structured data. We create one function to generate the prompt and another to run the model and process its output.

**Explanation:**

* setup_qwen_model: This function loads the Qwen model and tokenizer, applying 4-bit quantization (via bitsandbytes) to reduce memory usage. This is the same code provided in the README, but encapsulated in a function for reusability. When CUDA is available, the model is placed on the GPU with device_map="auto". On a CPU-only runtime, the model is not loaded and None is returned in its place. The function also sets the tokenizer's pad_token to its eos_token.

* generate_prompt: This function creates the prompt that will be fed to the LLM. It includes:
    * Clear Instructions: It tells the model its role ("medical information extraction system") and what to extract.
    * Passage Context: It includes the passage_text.
    * Structured Output Format: It explicitly defines the JSON structure we want, including examples of each field. This is crucial for reliable JSON output.
    * Handling Null Values: The instructions clearly explain that if a particular data point can't be found, null should be used.

* extract_information: This function does the following:
    * Tokenization: It tokenizes the prompt using the Qwen tokenizer.
    * Inference: It calls model.generate to generate the output. We use:
        * max_new_tokens: Limits the length of the generated text.
        * do_sample=False: Uses greedy decoding (taking the most likely token at each step). This makes the output deterministic (the same input always gives the same output).
        * temperature=0.2 and top_k=5: These only affect sampling. With do_sample=False they are ignored (recent transformers versions may warn about this), so they do not change the output.
        * with torch.no_grad(): Disables gradient calculation, saving memory and speeding up inference.
        * pad_token_id=tokenizer.eos_token_id: Uses the EOS token for padding, consistent with setting the tokenizer's pad_token to eos_token.

    * Decoding: It decodes the generated output using the tokenizer.

    * JSON Extraction: It extracts the JSON part from the output. This is the most critical part. It uses a regular expression to find Markdown code fences tagged json in the model's response and keeps the last one. It then removes trailing commas before closing brackets or braces, a common JSON error. If the model emits no fenced JSON block, the function returns None.
    * Error handling: If json.loads fails, the function retries parsing after dropping leading lines one at a time (partial recovery). An outer try-except catches any other exception raised during generation or parsing, prints it, and returns None. Debug prints of the raw LLM output are present but commented out.
* Deterministic Output: Setting do_sample=False (greedy decoding) is what makes the output consistent and deterministic, which is essential for reliable data extraction; temperature and top_k do not change greedy output.
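With do_sample=False, generation takes the highest-scoring (argmax) token at each step, so sampling parameters such as temperature and top_k do not change the result. The toy pure-Python sketch below illustrates this. It is not the transformers implementation; `greedy_pick`, `sample_pick` and the logits are illustrative. Dividing logits by any positive temperature preserves their order, so the greedy pick never changes:

```python
import math
import random


def softmax(logits, temperature=1.0):
    """Converts logits to probabilities after dividing by the temperature."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # Subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def greedy_pick(logits):
    """do_sample=False: always take the highest-scoring token."""
    return max(range(len(logits)), key=lambda i: logits[i])


def sample_pick(logits, temperature, top_k, rng):
    """do_sample=True (simplified): keep the top_k tokens, apply temperature, then sample."""
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:top_k]
    probs = softmax([logits[i] for i in top], temperature)
    return rng.choices(top, weights=probs, k=1)[0]


logits = [2.0, 1.5, 0.3, -1.0]  # Toy next-token scores
print(greedy_pick(logits))  # Independent of temperature and top_k
rng = random.Random(0)  # Seeded only so the sampled picks are repeatable
print([sample_pick(logits, temperature=0.2, top_k=5, rng=rng) for _ in range(5)])
```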



In [None]:
def setup_qwen_model():
    """Sets up the Qwen model, checking for CUDA and using 4-bit quantization if available."""
    model_name = "Qwen/Qwen2.5-7B-Instruct-1M"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    if device == "cuda":
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=quantization_config,
            device_map="auto",
            use_safetensors=True,
            low_cpu_mem_usage=True,
            trust_remote_code=True
        )
    else:
        print("CPU usage requires significant RAM; quantization (GPU only) is recommended.")
        model = None  # Model will not be available

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer, device
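A back-of-the-envelope calculation shows why setup_qwen_model loads the weights in 4-bit on a T4, which has 16 GB of GPU memory. The parameter count below is an assumption (about 7.6 billion for a 7B-class Qwen2.5 model). The estimate covers the weights only. It excludes activations, the KV cache, and any layers that are kept in higher precision:

```python
def weight_memory_gb(n_params, bits_per_param):
    """Approximate memory for model weights alone (no activations or KV cache)."""
    return n_params * bits_per_param / 8 / 1e9


N_PARAMS = 7.6e9  # Assumption: roughly 7.6 billion parameters
for label, bits in [("fp16", 16), ("4-bit", 4)]:
    print(f"{label}: ~{weight_memory_gb(N_PARAMS, bits):.1f} GB")
```

Roughly 15 GB in fp16 leaves almost no headroom on a 16 GB card. About 4 GB in 4-bit leaves room for inference.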

In [None]:
def generate_prompt(passage_text):
    """Generates a structured prompt for extracting specific medical information in JSON format."""
    prompt = f"""<|im_start|>system
    You are a medical information extraction system. Extract structured data from patient EHR notes and return it as a **single JSON object** containing exactly two arrays: `diagnosis_characteristics` and `cancer_related_medications`. Do not include any other keys at the top level.

    ### **Guidelines:**
    - Output **ONLY** strictly valid JSON—no extra text, explanations, or remarks.
    - If information is absent, use `null` (not empty strings).
    - Maintain the exact JSON structure and formatting.

    #### **1. diagnosis_characteristics (array of dictionaries)**
    Each dictionary represents a primary cancer diagnosis and must include:
    - `primary_cancer_condition` (string): Type of cancer (e.g., "Breast Cancer").
    - `diagnosis_date` (string, "MM-DD-YYYY"): Earliest diagnosis date.
    - `histology` (array of strings): Tumor histological subtype(s).
    - `stage` (dictionary): TNM and group stage details:
        - `T` (string): Tumor size/extent.
        - `N` (string): Lymph node involvement.
        - `M` (string): Metastasis status.
        - `group_stage` (string): Overall stage (e.g., "Stage IIB").

    #### **2. cancer_related_medications (array of dictionaries)**
    Each dictionary represents a prescribed cancer-related medication:
    - `medication_name` (string): Name of the medication.
    - `start_date` (string, "MM-DD-YYYY"): Start date of the medication.
    - `end_date` (string, "MM-DD-YYYY" or `null` if ongoing): End date.
    - `intent` (string): Purpose of prescription.
    <|im_end|>

    <|im_start|>user

    ### **Example Input:**
    *"Patient was diagnosed with Stage IIB breast cancer, ER+, PR+, HER2-. Diagnosis date: 03-10-2024. Pathology showed invasive ductal carcinoma. Patient started on Letrozole 2.5mg daily starting 03-15-2024 as adjuvant therapy."*

    ### **Expected Output:**
    ```json
    {{
        "diagnosis_characteristics": [
            {{
                "primary_cancer_condition": "breast cancer",
                "diagnosis_date": "03-10-2024",
                "histology": ["invasive ductal carcinoma"],
                "stage": {{
                    "T": null,
                    "N": null,
                    "M": null,
                    "group_stage": "Stage IIB"
                }}
            }}
        ],
        "cancer_related_medications": [
            {{
                "medication_name": "Letrozole",
                "start_date": "03-15-2024",
                "end_date": null,
                "intent": "adjuvant therapy"
            }}
        ]
    }}
    ```


    Now, extract structured information from the following passage and output ONLY the JSON format:

    {passage_text}

    <|im_end|><|im_start|>assistant"""

    return prompt.strip()
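generate_prompt writes Qwen's ChatML role markers (<|im_start|>, <|im_end|>) by hand. The layout can be sketched as a small helper. `to_chatml` is hypothetical, and the exact whitespace is an assumption. In practice, `tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)` renders the same kind of message list with the model's own chat template and ends the prompt with the assistant header. That avoids maintaining the markers by hand:

```python
def to_chatml(messages, add_generation_prompt=True):
    """Hypothetical helper: renders [{"role": ..., "content": ...}] messages in ChatML layout."""
    parts = [f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in messages]
    if add_generation_prompt:
        parts.append("<|im_start|>assistant\n")  # The model continues from here
    return "".join(parts)


prompt = to_chatml([
    {"role": "system", "content": "You are a medical information extraction system."},
    {"role": "user", "content": "Extract structured information from the following passage."},
])
print(prompt)
```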

In [None]:
def extract_information(model, tokenizer, device, passage, max_new_tokens=1024):
    """Extracts structured JSON information from a medical passage using a Qwen model."""

    prompt = generate_prompt(passage["docText"])
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    try:
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                temperature=0.2,
                top_k=5,
                repetition_penalty=1.0,
                pad_token_id=tokenizer.eos_token_id
            )

        # Decode only the newly generated tokens: slicing off the prompt avoids
        # splitting on the word "assistant", which could also occur in the passage text
        prompt_length = inputs["input_ids"].shape[1]
        assistant_response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()

        # print(f"assistant_response: {assistant_response}")

        # Step 1: Extract JSON using a stricter regex pattern
        matches = re.findall(r'```json\s*\n?(.*?)\n?```', assistant_response, re.DOTALL)
        # print(f"Matches: {matches}")

        if not matches:
            print("No JSON block found in the text.")
            return None

        # Step 2: Get the last extracted JSON block (assuming the final one is correct)
        json_str = matches[-1].strip()

        # Step 3: Remove trailing commas (fixes common JSON errors)
        json_str = re.sub(r',\s*([\]}])', r'\1', json_str)

        # # Step 4: Debug extracted JSON before parsing
        # print("Extracted JSON String:\n", json_str)

        # Step 5: Attempt JSON parsing
        try:
            return json.loads(json_str)
        except json.JSONDecodeError as json_err:
            print(f"JSON parsing error: {json_err}")
            print("Attempting secondary cleanup...")

        # Step 6: Handle partial JSON recovery if possible
        possible_json_lines = json_str.split("\n")  # Break into lines
        for i in range(len(possible_json_lines)):
            try:
                cleaned_json_str = "\n".join(possible_json_lines[i:])
                return json.loads(cleaned_json_str)  # Return first valid JSON found
            except json.JSONDecodeError:
                continue  # Try next variation

        print("Final JSON parsing failed.")
        return None

    except Exception as e:
        print(f"Error during generation or JSON parsing: {e}")
        return None
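extract_information returns None when the response has no fenced json block. One possible fallback, shown only as a sketch and not part of the project, scans for the first "{" from which `json.JSONDecoder.raw_decode` can parse a complete object. `first_json_object` is a hypothetical name:

```python
import json


def first_json_object(text):
    """Returns the first JSON object that can be decoded from free text, or None.

    Tries json.JSONDecoder.raw_decode at each "{" until one parses.
    """
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            obj, _end = decoder.raw_decode(text, start)
            if isinstance(obj, dict):
                return obj
        except json.JSONDecodeError:
            pass  # Not a valid object at this position; try the next "{"
        start = text.find("{", start + 1)
    return None


response = 'Here is the result: {"diagnosis_characteristics": [], "cancer_related_medications": []} Done.'
print(first_json_object(response))
```

raw_decode stops at the end of the first complete value, so any prose after the object is ignored.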

In [None]:
test_passage = {
    "docDate": "04-03-2025",
    "docTitle": "Pathology Report - Prostate Biopsy Findings",
    "docText": "Patient Name: Paul Henderson\nDate of Birth: 03/12/1958\n\nDate of Report: 04/03/2025\nPathologist: Dr. William Archer\n\nSpecimens:\nTwelve core biopsy samples from the prostate, labeled according to standard sextant mapping with additional targeted cores in the right peripheral zone.\n\nMicroscopic Examination:\n1. **Right Peripheral Zone Cores**: Malignant cells consistent with adenocarcinoma of the prostate. The glands are crowded and demonstrate prominent nucleoli. Based on the Gleason grading system, the dominant pattern is 3 and the secondary pattern is 4, yielding a Gleason score of 3+4 = 7.\n2. **Other Cores (Left Peripheral, Base, Apex)**: Several cores show benign prostatic hyperplasia and chronic inflammatory changes. No malignancy identified in these regions.\n\nTumor Characteristics:\n- In the malignant cores, the extent of involvement ranges from 30% to 60% of the tissue examined.\n- Perineural invasion is noted, commonly seen in prostate cancer but does not necessarily indicate extraprostatic extension.\n- No definitive evidence of high-grade (Gleason pattern 5) disease in the submitted samples.\n\nDiagnosis:\nProstatic adenocarcinoma, Gleason 7 (3+4), primarily involving the right peripheral zone.\n\nComments:\nA Gleason score of 7 (3+4) indicates an intermediate-grade prostate cancer. Further staging assessments, including imaging and serum markers, may help determine if the disease is organ-confined. Additional data such as PSA density or genomic tests could refine risk stratification. The presence of perineural invasion can correlate with a slightly higher likelihood of extraprostatic extension, but imaging is required to confirm.\n\nRecommendation:\nCorrelate these findings with clinical and radiologic staging. Consider discussing treatment options with the patient, which may include radical prostatectomy, radiation therapy, or potentially active surveillance if certain criteria are met (though many would treat Gleason 7 more definitively).\n\nSigned:\nDr. William Archer, MD\nDepartment of Pathology, Hillside Labs\nReport Date: 04/03/2025\n\nConclusion:\nMr. Henderson’s biopsy confirms prostate cancer in the right peripheral zone with an intermediate Gleason score. Multidisciplinary evaluation with urology, radiation oncology, and possibly medical oncology will be necessary to formulate an optimal treatment plan.\n"
  }
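To see what filter_sentence would discard from a note like test_passage, its patterns can be checked against a few of the note's lines. This is a standalone sketch. `keep_sentence` mirrors filter_sentence, and the lines are split by hand rather than with nltk.sent_tokenize, so the sentence boundaries in the real pipeline may differ:

```python
import re

# Same exclusion patterns as filter_sentence in the retrieval pipeline
EXCLUDE_PATTERNS = [
    r"Patient Name:", r"Date of Birth:", r"Location:", r"Purpose:",
    r"Follow-up:", r"Conclusion:", r"Next Steps:", r"Date of Visit:", r":",
]


def keep_sentence(sentence):
    """Mirror of filter_sentence: False if any pattern matches (case-insensitive)."""
    return not any(re.search(p, sentence, re.IGNORECASE) for p in EXCLUDE_PATTERNS)


# A few lines taken from test_passage, split by hand for illustration
lines = [
    "Patient Name: Paul Henderson",
    "Diagnosis:",
    "Prostatic adenocarcinoma, Gleason 7 (3+4), primarily involving the right peripheral zone.",
    "Report Date: 04/03/2025",
]
for line in lines:
    print(keep_sentence(line), "|", line)
```

Only the diagnosis line survives. The bare ":" pattern also removes "Report Date: 04/03/2025", even though that header is not in the list. If sentence tokenization merges a header such as "Diagnosis:" with the text that follows it, the clinically relevant text in that merged sentence is discarded too.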

In [None]:
# model, tokenizer, device = setup_qwen_model()

In [None]:
# ans = extract_information(model, tokenizer, device, test_passage, max_new_tokens=1024)
# print(ans)

In [None]:
from datetime import datetime


def _parse_date(value):
    """Parses an "MM-DD-YYYY" date string; returns None if it is missing or malformed."""
    try:
        return datetime.strptime(value, "%m-%d-%Y")
    except (TypeError, ValueError):
        return None


def merge_extractions(existing_data, new_data):
    """Merges newly extracted data with existing aggregated data."""

    if new_data is None:  # Handle cases where nothing was extracted
        return existing_data

    # --- Merge Diagnosis Characteristics ---
    if isinstance(new_data.get("diagnosis_characteristics"), list):
        for new_diagnosis in new_data["diagnosis_characteristics"]:
            condition = new_diagnosis.get("primary_cancer_condition")
            if not condition:
                continue
            for existing_diagnosis in existing_data["diagnosis_characteristics"]:
                # Case-insensitive match, so "Breast Cancer" and "breast cancer" merge
                if (existing_diagnosis.get("primary_cancer_condition") or "").lower() == condition.lower():
                    # Keep the earliest diagnosis date (compared as dates, not strings)
                    new_date = _parse_date(new_diagnosis.get("diagnosis_date"))
                    old_date = _parse_date(existing_diagnosis.get("diagnosis_date"))
                    if new_date and (old_date is None or new_date < old_date):
                        existing_diagnosis["diagnosis_date"] = new_diagnosis["diagnosis_date"]

                    # Union of histology subtypes
                    if isinstance(new_diagnosis.get("histology"), list):
                        existing_histology = existing_diagnosis.get("histology") or []
                        existing_diagnosis["histology"] = list(set(existing_histology + new_diagnosis["histology"]))

                    # Overwrite stage fields only with non-null new values
                    if isinstance(new_diagnosis.get("stage"), dict):
                        if not isinstance(existing_diagnosis.get("stage"), dict):
                            existing_diagnosis["stage"] = {}
                        for key in ["T", "N", "M", "group_stage"]:
                            if new_diagnosis["stage"].get(key) is not None:
                                existing_diagnosis["stage"][key] = new_diagnosis["stage"][key]
                    break  # Stop after updating the matching condition
            else:
                # No existing entry matched: add the new diagnosis
                existing_data["diagnosis_characteristics"].append(new_diagnosis)

    # --- Merge Medications ---
    if isinstance(new_data.get("cancer_related_medications"), list):
        for new_med in new_data["cancer_related_medications"]:
            name = new_med.get("medication_name")
            if not name:  # Ensure medication name exists
                continue
            for existing_med in existing_data["cancer_related_medications"]:
                if (existing_med.get("medication_name") or "").lower() == name.lower():
                    # Keep the earliest start date and the latest end date
                    new_start = _parse_date(new_med.get("start_date"))
                    old_start = _parse_date(existing_med.get("start_date"))
                    if new_start and (old_start is None or new_start < old_start):
                        existing_med["start_date"] = new_med["start_date"]
                    new_end = _parse_date(new_med.get("end_date"))
                    old_end = _parse_date(existing_med.get("end_date"))
                    if new_end and (old_end is None or new_end > old_end):
                        existing_med["end_date"] = new_med["end_date"]

                    # Combine intents without duplicating them or concatenating None
                    new_intent = new_med.get("intent")
                    old_intent = existing_med.get("intent")
                    if new_intent and not old_intent:
                        existing_med["intent"] = new_intent
                    elif new_intent and new_intent not in old_intent:
                        existing_med["intent"] = f"{old_intent} {new_intent}"
                    break
            else:
                # No existing entry matched: add the new medication
                existing_data["cancer_related_medications"].append(new_med)

    return existing_data
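The dates handled by merge_extractions arrive as "MM-DD-YYYY" strings, the format requested in the prompt. Comparing such strings directly with < or > does not give chronological order, as this short check shows. The dates are illustrative:

```python
from datetime import datetime

a, b = "12-01-2023", "01-15-2024"  # MM-DD-YYYY


def parse(s):
    """Parses an MM-DD-YYYY string into a datetime."""
    return datetime.strptime(s, "%m-%d-%Y")


# Plain string comparison looks at the month characters first ("1" > "0")
print(a < b)  # False, although December 2023 comes before January 2024
# Parsing first compares actual calendar dates
print(parse(a) < parse(b))  # True
```

Zero-padded ISO strings ("YYYY-MM-DD") do sort chronologically as plain strings. The prompt asks for MM-DD-YYYY, so parsing before comparing is the safer choice.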

## Putting it all Together (Main Execution Block)

This block is the main execution point of the script. It orchestrates the entire process, from data loading and user interaction to information retrieval/extraction and output.

- if __name__ == "__main__" guard:
This standard Python construct ensures that the code within this block runs only when the script is executed directly (e.g., python cancercareai.py), not when it is imported as a module into another script. In a Colab or Jupyter notebook, __name__ is also "__main__", so the block runs whenever its cell is executed.
- Data Loading and Error Handling:
  - Loads the patient data from the GitHub repository using the load_data_from_github function (defined in Block 1).
  - Includes a critical error check (if patient_data is None). If data loading fails (e.g., due to network issues), the script prints an error message and exits, preventing subsequent code from crashing.
- Patient Selection (Interactive):
  - Presents a list of available patients to the user.
  - Extracts patient names from the docText field of the first record in each file using a regular expression (re.search). Handles cases where the name cannot be extracted.
  - Prompts the user to enter the name of the patient they want to process.
  - Handles cases where the entered name is not found or where multiple files match the name (allowing the user to choose a specific file).
  - Creates selected_patient_data, containing only the data for the selected patient. This improves efficiency by focusing on the relevant data.
- Mode Selection (Interactive):
  Prompts the user to choose between two modes:
  - Mode 1 (Query): Performs information retrieval based on a user-provided query (using the pipeline from Block 2).
  - Mode 2 (Medical Data Extraction): Extracts structured medical data using the LLM (using the pipeline from Block 3).
    - Before running, checks that the LLM was loaded (model is None when no GPU is available) and exits if it was not.
- Mode 1: Query (Information Retrieval):
  - Prompts the user to enter a query.
  - Calls create_passages to convert the selected patient's data into a list of sentence-level passages.
  - Calls combined_retrieval (from Block 2) to perform the combined BM25 and semantic search, retrieving the most relevant passages.
  - Prints the top 5 retrieved sentences; the score returned with each passage is not displayed.
- Mode 2: Medical Data Extraction:
  - Initializes aggregated_data with empty diagnosis_characteristics and cancer_related_medications lists. This dictionary accumulates results from multiple documents.
  - Iterates through each document (record) in the selected_patient_data. It processes entire documents, not individual sentences, for extraction.
  - For each document:
    - Creates a passage dictionary containing the docText, docTitle, docDate, and patient_file.
    - Calls extract_information (from Block 3) to extract structured data from the document using the LLM.
    - Calls merge_extractions to combine the newly extracted data with the aggregated_data. This function handles merging and deduplication of information across multiple documents.
    - Prints progress information (which document is being processed).
  - After processing all documents, prints the final aggregated_data in JSON format using json.dumps with an indent of 2 for readability.
  - Handles the case where no data was extracted, printing an informative message.
- Invalid Mode Handling:
  - If the user enters an invalid mode choice, prints an error message.
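
Patient selection relies on the pattern `Patient Name:\s*([^\n]+)`, which captures the rest of the line after the label. A minimal sketch of that step, run on a made-up docText header, follows; the sample text and the helpers `extract_patient_name` and `normalise_name` are hypothetical, not project code. `normalise_name` illustrates one optional change: comparing names case- and whitespace-insensitively, whereas the main execution block requires an exact match.

```python
import re
from typing import Optional

# Same pattern as the main execution block: text after "Patient Name:" up to the line end.
NAME_PATTERN = re.compile(r"Patient Name:\s*([^\n]+)")


def extract_patient_name(doc_text: str) -> Optional[str]:
    """Return the captured patient name, or None when the label is absent."""
    match = NAME_PATTERN.search(doc_text)
    return match.group(1).strip() if match else None


def normalise_name(name: str) -> str:
    """Collapse runs of whitespace and ignore case (an optional, assumed tweak)."""
    return " ".join(name.split()).casefold()


# Hypothetical docText header; real records may be laid out differently.
sample_doc = "Oncology Progress Note\nPatient Name: John Whitfield\n"
name = extract_patient_name(sample_doc)
print(name)                                                           # John Whitfield
print(normalise_name("  john   WHITFIELD ") == normalise_name(name))  # True
print(extract_patient_name("No name header in this note"))            # None
```

Normalising both sides before the comparison would let inputs such as "john whitfield" match, at the cost of treating names that differ only in case as the same patient.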

***Mode 2 (Medical Data Extraction) requires a GPU with at least 6 GB of VRAM***

❌ CPU only - No GPU acceleration possible

✅ T4 GPU - Supports GPU acceleration

**NOTE: IF BITSANDBYTES ISSUES PERSIST, USE CUDA 11.8 FOR GPU ACCELERATION, AS CUDA 12.x MAY NOT FULLY SUPPORT BITSANDBYTES**
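
One way to check the 6 GB requirement before choosing Mode 2 is to ask PyTorch, which the model-loading stack (transformers, bitsandbytes) is built on, how much memory the GPU has. The helper `has_enough_vram` below is a hypothetical sketch rather than part of the project. It checks total device memory, not currently free memory, and treats a missing PyTorch install or an unavailable CUDA device as CPU-only.

```python
MIN_VRAM_GB = 6  # threshold from the note above


def has_enough_vram(min_gb: float = MIN_VRAM_GB, device_index: int = 0) -> bool:
    """Return True if CUDA device `device_index` has at least `min_gb` GiB of total memory."""
    try:
        import torch  # imported lazily so the check also works where torch is absent
    except ImportError:
        return False  # no PyTorch: treat the runtime as CPU-only
    if not torch.cuda.is_available():
        return False  # CPU-only runtime, e.g. Colab without a GPU
    total_bytes = torch.cuda.get_device_properties(device_index).total_memory
    return total_bytes / 1024**3 >= min_gb  # 1 GiB = 1024**3 bytes


print("Enough VRAM for Mode 2:", has_enough_vram())
```

Checking total rather than free memory keeps the sketch simple; if other models already occupy the GPU, `torch.cuda.mem_get_info()` reports free and total bytes instead.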

In [None]:
model, tokenizer, device = setup_qwen_model()

In [None]:
if __name__ == "__main__":
    # --- Data Loading ---
    repo_url = "https://raw.githubusercontent.com/403errors/CancerCareAI/main/data"
    filenames = ["1.json", "2.json", "3.json"]
    patient_data = load_data_from_github(repo_url, filenames)
    if patient_data is None:
        print("Data loading failed. Exiting.")
        raise SystemExit(1)  # preferred over exit(), which is intended for the interactive interpreter

    # --- Patient Selection ---
    print("Available patients:")
    patient_names = {}
    for filename, records in patient_data.items():
        first_record = records[0]
        if "docText" in first_record:
            name_match = re.search(r"Patient Name:\s*([^\n]+)", first_record["docText"])
            if name_match:
                patient_name = name_match.group(1).strip()
                patient_names[filename] = patient_name
                print(f"- {patient_name} ({filename})")  # Display "Lisa Bowman (1.json)"
            else:
                print(f"- {filename} (Could not extract patient name)")
                patient_names[filename] = None
        else:
            print(f"- {filename} (Missing docText)")
            patient_names[filename] = None

    selected_patient = input("Enter the full name of the patient you want to process ie. Lisa Bowman: ").strip()
    selected_file = None

    # Find the file(s) associated with the selected patient name; several files may share a name.
    matching_files = [fname for fname, pname in patient_names.items() if pname == selected_patient]
    if not matching_files:
        print(f"Error: No patient found with the name '{selected_patient}'.")
        raise SystemExit(1)
    elif len(matching_files) > 1:
        print("Multiple files found for this patient. Please select one:")
        for i, fname in enumerate(matching_files, start=1):
            print(f"{i}. {fname}")
        # Re-prompt until the user enters a valid 1-based index.
        while True:
            choice = input("Enter the number of the file: ").strip()
            if choice.isdigit() and 1 <= int(choice) <= len(matching_files):
                selected_file = matching_files[int(choice) - 1]
                break
            print(f"Please enter a number between 1 and {len(matching_files)}.")
    else:
        selected_file = matching_files[0]

    selected_patient_data = {selected_file: patient_data[selected_file]}

    # --- Mode Selection ---
    print("\nSelect a mode:")
    print("1. Query (Information Retrieval)")
    print("2. Medical Data Extraction")
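    mode = input("Enter your choice (1 or 2): ").strip()

    # Mode 2 needs the LLM; model is None if it could not be loaded (e.g., CPU-only runtime).
    if model is None and mode == "2":
        print("Model is not available on CPU. Exiting.")
        raise SystemExit(1)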


    if mode == "1":
        # --- Query Mode (Task 1) - Sentence Level ---
        query = input("Enter your query: ")
        passages = create_passages(selected_patient_data)  # Create sentence-level passages
        retrieved_passages = combined_retrieval(query, passages)

        print("\nTop Retrieved Sentences:")
        for rank, (passage, score) in enumerate(retrieved_passages[:5], start=1):
            print(f"{rank}. {passage['docText']}")  # display the sentence (score is not shown)

    elif mode == "2":
        # --- Medical Data Extraction Mode (Task 2) ---
        aggregated_data = {
            "diagnosis_characteristics": [],
            "cancer_related_medications": []
        }

        # Iterate through DOCUMENTS (not sentences)
        total_documents = len(selected_patient_data[selected_file])
        print(f"\nExtracting data from {total_documents} documents...")
        for document_number, record in enumerate(selected_patient_data[selected_file], start=1):
            passage = {
                "docText": record["docText"],
                "docTitle": record["docTitle"],
                "docDate": record["docDate"],
                "patient_file": selected_file
            }
            extracted_data = extract_information(model, tokenizer, device, passage)

            # Merge the extracted data into the aggregated data
            aggregated_data = merge_extractions(aggregated_data, extracted_data)

            print(f"Processed record: {document_number}/{total_documents}")


        if any(aggregated_data.values()):  # False when both lists are still empty
            print("\nExtracted Data (JSON):")
            print(json.dumps(aggregated_data, indent=2))
        else:
            print("No data extracted.")

    else:
        print("Invalid mode selected.")

Available patients:
- Lisa Bowman (1.json)
- John Whitfield (2.json)
- Paul Henderson (3.json)
Enter the full name of the patient you want to process ie. Lisa Bowman: John Whitfield

Select a mode:
1. Query (Information Retrieval)
2. Medical Data Extraction
Enter your choice (1 or 2): 1
Enter your query: Has the patient undergone chemotherapy?

Top Retrieved Sentences:
1. The patient underwent similar pre-medications (dexamethasone, antiemetics, IV fluids).
2. An oncology consultation is advised to discuss chemotherapy regimens and the role of radiation if needed.
3. Appointment with Dr. Rebecca Olson (Medical Oncology) in 2 weeks to discuss adjuvant chemotherapy.
4. - Await final pathology results, which will guide the necessity of adjuvant chemotherapy and/or radiation therapy.
5. - Discuss the possibility of concurrent chemotherapy and radiation therapy if surgery is not an optimal route.
