<a href="https://colab.research.google.com/github/403errors/CancerCareAI/blob/main/CancerCareAI.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## AI-powered extraction of cancer-related patient data.

NOTE: MUST RUN ON A **T4 GPU** RUNTIME, BECAUSE THE BITSANDBYTES MODULE USED TO LOAD THE QUANTISED MODEL REQUIRES A GPU ⚠️

## Project Setup and Data Loading
First, we set up the Colab environment, install the necessary libraries, and load the data from the GitHub repository.

**Explanation:**

* `!pip install ...`: Installs the required libraries: sentence-transformers for semantic similarity, rank_bm25 for keyword-based ranking (BM25), and pandas for data manipulation. The `-q` flag keeps the installation output quiet.


* `load_data_from_github` function: Fetches the JSON files directly from the GitHub repository using the requests library. `response.raise_for_status()` is crucial for error checking: it raises an exception for HTTP error codes (e.g. 404), so a failed download is never silently treated as data.

* `repo_url` and `filenames`: These variables store the location of the data.


* Error Handling: The try...except block around the loading call ensures that the program doesn't crash if there's a problem loading the data (e.g. a network error).

* Sample Data Structure: Printing a sample record shows the structure of the data you're working with and confirms that it loaded correctly.

In [None]:
# Install necessary libraries (if not already installed in your Colab environment)
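# bitsandbytes and accelerate are needed to load the 4-bit quantised Qwen model in Task 2
!pip install -q sentence-transformers rank_bm25 pandas bitsandbytes accelerate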


### Imports

In [None]:
import re
import json
import requests
from sentence_transformers import SentenceTransformer, CrossEncoder, util
import torch
from rank_bm25 import BM25Okapi
import pandas as pd
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

### Data Loading

In [None]:
def load_data_from_github(repo_url, filenames):
    """Loads JSON data files from a GitHub repository.

    Args:
        repo_url: Base URL of the GitHub repository's data directory.
        filenames: List of filenames to load.

    Returns:
        A dictionary where keys are filenames and values are the loaded JSON data.
    """
    data = {}
    for filename in filenames:
        file_url = f"{repo_url}/{filename}"
        response = requests.get(file_url)
        response.raise_for_status()  # Raise an exception for bad status codes
        data[filename] = response.json()
    return data

# # GitHub repository URL and filenames
# repo_url = "https://raw.githubusercontent.com/403errors/CancerCareAI/main/data"
# filenames = ["1.json", "2.json", "3.json"]

# # Load the data
# try:
#     patient_data = load_data_from_github(repo_url, filenames)
#     print("Data loaded successfully!")
# except requests.exceptions.RequestException as e:
#     print(f"Error loading data: {e}")
#     patient_data = None  # Set to None to indicate loading failure

# # Display the structure of a sample record (if loading was successful)
# if patient_data:
#     print("\nSample Data Structure (first patient, first document):")
#     print(json.dumps(patient_data["1.json"][0], indent=4))

Data loaded successfully!

Sample Data Structure (first patient, first document):
{
    "docDate": "01-05-2020",
    "docTitle": "Initial Oncology Consultation",
    "docText": "Patient Name: Lisa Bowman\nDate of Birth: 11/23/1967\n\nChief Complaint:\nMrs. Lisa Bowman presents today for an initial evaluation of a right breast lump that was identified during her annual physical exam in late December 2019. She reports feeling a palpable mass in the upper outer quadrant of her right breast approximately four weeks ago, which prompted her primary care provider to order further imaging. A diagnostic mammogram and ultrasound performed on 12/28/2019 indicated a suspicious lesion measuring roughly 2.1 cm in diameter.\n\nHistory of Present Illness:\nMrs. Bowman, a 52-year-old female, has no prior history of malignancy. She reports intermittent tenderness in the right breast near the area of the lump, but denies any nipple discharge or changes in the overlying skin. There is no significant famil

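The sample record is a flat object with three string fields: docDate, docTitle and docText. Dates such as 02-26-2020 elsewhere in the data show the MM-DD-YYYY format, which is also what the Task 2 prompt asks for. Because the extraction targets the *earliest* diagnosis and medication dates, it can help to check every record and put each patient's documents in chronological order first. The following is a minimal stdlib sketch under the assumption that every file follows the structure above; `validate_and_sort_records` is a hypothetical helper, not part of the notebook.

```python
from datetime import datetime

# Keys seen in the sample record above (assumed to be present in every record)
REQUIRED_KEYS = ("docDate", "docTitle", "docText")


def validate_and_sort_records(records):
    """Check that every record has the expected keys, then return the records
    sorted by docDate, parsed as MM-DD-YYYY."""
    for i, record in enumerate(records):
        missing = [key for key in REQUIRED_KEYS if key not in record]
        if missing:
            raise KeyError(f"Record {i} is missing keys: {missing}")
    return sorted(records, key=lambda r: datetime.strptime(r["docDate"], "%m-%d-%Y"))


# Usage with toy records (not real patient data)
toy_records = [
    {"docDate": "03-05-2020", "docTitle": "B", "docText": "..."},
    {"docDate": "01-05-2020", "docTitle": "A", "docText": "..."},
]
print([r["docTitle"] for r in validate_and_sort_records(toy_records)])  # ['A', 'B']
```

In the notebook this could be applied per file, e.g. to each `patient_data[filename]` list, before building passages.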
## Task 1 - Information Retrieval (Pipeline)

Now, let's build the information retrieval pipeline. We combine BM25 (keyword-based), Sentence Transformers (semantic), and a CrossEncoder reranker for a robust approach.


**Explanation:**

* `create_passages`: This function transforms the raw JSON data into a list of dictionaries, each representing a passage. Crucially, it includes the original docTitle, docDate, and patient_file along with the docText. This is important for later displaying results and for potential filtering in Task 2.


* `bm25_ranking`: Implements the BM25 algorithm with a simple tokenizer (lowercasing and splitting on whitespace). BM25 is good at finding passages that contain the exact query terms, even when they are not semantically similar to the query as a whole.


* `semantic_search`: Uses Sentence Transformers to find passages that are semantically similar to the query. We use the "all-MiniLM-L6-v2" model, which is a good balance of speed and accuracy. You can explore other models.

* `rerank_with_crossencoder`: This uses a CrossEncoder, which is more accurate than the Bi-Encoder used in semantic_search, but slower. It takes the query and each passage as a pair and directly predicts a relevance score. We use "cross-encoder/ms-marco-MiniLM-L-6-v2", a model trained on the MS MARCO passage ranking dataset.

* `combined_retrieval`: This is the core of the retrieval pipeline. It combines the results of BM25, semantic search, and cross-encoder reranking:
    * It first runs BM25 and semantic search.
    * It takes the top N results from each of these methods and combines them into a single list (removing duplicates).
    * It then reranks this combined list using the CrossEncoder. This is more efficient than running the CrossEncoder on all passages.
    * Finally, it normalizes the scores from each method (to a 0-1 range) and combines them using weighted averaging. This allows you to tune the importance of each component.

* Normalization: Scores from different models aren't directly comparable. Standard min-max normalization, (score - min) / (max - min), brings them to a common 0 to 1 scale; if all scores are equal, every passage gets a neutral 0.5.

* Weighted Combination: The bm25_weight, semantic_weight, and crossencoder_weight allow you to control the influence of each ranking method. You'll likely need to experiment with these weights to find the best combination for your data.

* Example Usage: Shows how to use the combined_retrieval function and print the top results.
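`bm25_ranking` delegates the scoring to rank_bm25's `BM25Okapi`. For intuition, here is a minimal pure-Python sketch of the Okapi BM25 formula with k1 = 1.5 and b = 0.75 (assumed to match `BM25Okapi`'s defaults). It uses the Lucene-style IDF, ln(1 + (N - n + 0.5) / (n + 0.5)), which stays positive for very common terms; rank_bm25 handles IDF for common terms differently, so absolute scores will not match the library's exactly. `bm25_scores` is a hypothetical helper for illustration only.

```python
import math
from collections import Counter


def bm25_scores(query_tokens, corpus_tokens, k1=1.5, b=0.75):
    """Score every tokenized document against the tokenized query with Okapi BM25."""
    n_docs = len(corpus_tokens)
    avg_len = sum(len(doc) for doc in corpus_tokens) / n_docs
    # Document frequency: in how many documents each term appears
    doc_freq = Counter(term for doc in corpus_tokens for term in set(doc))

    scores = []
    for doc in corpus_tokens:
        term_freq = Counter(doc)
        score = 0.0
        for term in query_tokens:
            if term not in term_freq:
                continue
            n = doc_freq[term]
            # Lucene-style IDF: rarer terms weigh more, never negative
            idf = math.log((n_docs - n + 0.5) / (n + 0.5) + 1)
            tf = term_freq[term]
            # k1 saturates repeated terms; b normalizes for document length
            score += idf * tf * (k1 + 1) / (tf + k1 * (1 - b + b * len(doc) / avg_len))
        scores.append(score)
    return scores


def tokenize(text):
    """Same tokenizer as combined_retrieval: lowercase, split on whitespace."""
    return text.lower().split()


docs = [
    "patient started chemotherapy with paclitaxel",
    "baseline echocardiogram was normal",
    "chemotherapy education session before chemotherapy",
]
print(bm25_scores(tokenize("chemotherapy"), [tokenize(d) for d in docs]))
```

The second document scores 0 because it never mentions the query term, and the third outscores the first because it repeats the term at the same document length, which is exactly the keyword behaviour the semantic and cross-encoder stages are meant to complement.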

In [None]:
def create_passages(patient_data):
    """
    Creates a list of passages from the patient data.  Each passage is a
    dictionary containing the document text, title, and date.
    """
    passages = []
    for filename, patient_records in patient_data.items():
        for record in patient_records:
            passages.append({
                "docText": record["docText"],
                "docTitle": record["docTitle"],
                "docDate": record["docDate"],
                "patient_file": filename  # Add the source file for later reference
            })
    return passages


def bm25_ranking(query, passages, tokenizer_bm25):
    """Ranks passages using BM25.

    Args:
        query: The search query (string).
        passages: A list of dictionaries; each must contain the "docText" key.
        tokenizer_bm25: A tokenizer suitable for BM25 (e.g., splitting on spaces).

    Returns:
        List of (passage, score) tuples, sorted by score (highest first).
    """
    tokenized_corpus = [tokenizer_bm25(p["docText"]) for p in passages]
    bm25_model = BM25Okapi(tokenized_corpus)
    tokenized_query = tokenizer_bm25(query)
    doc_scores = bm25_model.get_scores(tokenized_query)
    # Combine passages and scores
    passage_scores = list(zip(passages, doc_scores))
    # Sort by score (descending)
    passage_scores.sort(key=lambda x: x[1], reverse=True)
    return passage_scores


def semantic_search(query, passages, model_name="all-MiniLM-L6-v2", top_k=20):
    """Performs semantic search using Sentence Transformers.

    Args:
        query: The search query.
        passages: A list of dictionaries, where each dictionary must contain at least "docText".
        model_name: The Sentence Transformer model to use.
        top_k: Number of hits to return (20 matches the top_n candidate pool in combined_retrieval).

    Returns:
        A list of (passage, score) tuples, sorted by cosine similarity (highest first).
    """
    model = SentenceTransformer(model_name)
    corpus_embeddings = model.encode([p["docText"] for p in passages], convert_to_tensor=True)
    query_embedding = model.encode(query, convert_to_tensor=True)
    # util.semantic_search returns one list of hits per query; we pass a single query
    hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=top_k)[0]
    return [(passages[hit["corpus_id"]], hit["score"]) for hit in hits]

def rerank_with_crossencoder(query, passages, model_name="cross-encoder/ms-marco-MiniLM-L-6-v2"):
    """Reranks passages using a CrossEncoder model.

    Args:
      query: the search query.
      passages: A list of dictionaries; must contain the "docText" key.
      model_name: the CrossEncoder model to use.

    Returns:
      List of (passage, score) tuples, sorted by score (highest first).
    """
    model = CrossEncoder(model_name)
    scores = model.predict([(query, p["docText"]) for p in passages])
    passage_scores = list(zip(passages, scores))
    passage_scores.sort(key=lambda x: x[1], reverse=True)
    return passage_scores

def combined_retrieval(query, passages, bm25_weight=0.4, semantic_weight=0.3, crossencoder_weight=0.3):
    """Combines BM25, semantic search, and cross-encoder reranking.

    Args:
        query: The search query.
        passages: A list of passage dictionaries; each must contain "docText".
        bm25_weight: Weight of the normalized BM25 score.
        semantic_weight: Weight of the normalized semantic (bi-encoder) score.
        crossencoder_weight: Weight of the normalized cross-encoder score.

    Returns:
        List of (passage, combined_score) tuples, sorted by combined score (highest first).
    """
    # Simple tokenizer for BM25 (lowercase, split on whitespace)
    tokenizer_bm25 = lambda text: text.lower().split()

    # 1. BM25 ranking
    bm25_results = bm25_ranking(query, passages, tokenizer_bm25)

    # 2. Semantic search
    semantic_results = semantic_search(query, passages)

    # 3. Keep the top N from BM25 and semantic search before cross-encoding
    top_n = 20  # Adjust as needed
    bm25_top_n = [passage for passage, _ in bm25_results[:top_n]]
    semantic_top_n = [passage for passage, _ in semantic_results[:top_n]]

    # Merge both candidate lists, dropping duplicates (keyed on docText) while keeping order
    unique_doc_texts = set()
    combined_top_passages = []
    for passage in bm25_top_n + semantic_top_n:
        doc_text = passage["docText"]
        if doc_text not in unique_doc_texts:
            unique_doc_texts.add(doc_text)
            combined_top_passages.append(passage)  # Keep the full passage dict

    # 4. Cross-encoder reranking (only on the merged candidates)
    crossencoder_results = rerank_with_crossencoder(query, combined_top_passages)

    # 5. Normalize each method's scores to [0, 1] (min-max), keyed by docText
    def normalize_scores(results):
        if not results:
            return {}
        scores = [score for _, score in results]
        min_score = min(scores)
        max_score = max(scores)
        if max_score == min_score:  # Avoid division by zero: give every passage a neutral score
            return {passage["docText"]: 0.5 for passage, _ in results}
        return {passage["docText"]: (score - min_score) / (max_score - min_score) for passage, score in results}

    bm25_scores = normalize_scores(bm25_results)
    semantic_scores = normalize_scores(semantic_results)
    crossencoder_scores = normalize_scores(crossencoder_results)

    # Weighted combination over the reranked candidates; a passage missing from
    # a method's results contributes 0 for that method
    final_results = []
    for passage, _ in crossencoder_results:
        doc_text = passage["docText"]
        combined_score = (
            bm25_scores.get(doc_text, 0) * bm25_weight
            + semantic_scores.get(doc_text, 0) * semantic_weight
            + crossencoder_scores[doc_text] * crossencoder_weight  # every candidate has a cross-encoder score
        )
        final_results.append((passage, float(combined_score)))

    # Sort by the combined score; the cross-encoder order alone would ignore the BM25/semantic weights
    final_results.sort(key=lambda x: x[1], reverse=True)
    return final_results


# # Example Usage (assuming 'patient_data' is loaded)
# if patient_data:
#     passages = create_passages(patient_data)
#     query = "Has the patient undergone chemotherapy?"
#     retrieved_passages = combined_retrieval(query, passages)

#     print("\nTop Retrieved Passages:")
#     for passage, score in retrieved_passages[:5]:  # Display top 5
#         print(f"Score: {score:.3f}")
#         print(f"File: {passage['patient_file']}")
#         print(f"Title: {passage['docTitle']}")
#         print(f"Date: {passage['docDate']}")
#         print(f"Text: {passage['docText'][:200]}...")  # Show first 200 characters
#         print("-" * 50)


Top Retrieved Passages:
Score: 0.629
File: 1.json
Title: Chemotherapy Education and Baseline Echocardiogram
Date: 03-05-2020
Text: Patient Name: Lisa Bowman
Date of Birth: 11/23/1967

Date: 03/05/2020
Location: Outpatient Oncology Center

Chemotherapy Education Session:
Mrs. Bowman attended a one-hour educational session with the...
--------------------------------------------------
Score: 0.394
File: 1.json
Title: Medical Oncology Consultation – Adjuvant Therapy Planning
Date: 02-26-2020
Text: Patient Name: Lisa Bowman
Date of Birth: 11/23/1967

Visit Date: 02/26/2020
Attending Oncologist: Dr. Robert Chan, MD

Reason for Visit:
Mrs. Bowman presents for an initial medical oncology consultati...
--------------------------------------------------
Score: 0.465
File: 1.json
Title: Initiation of Paclitaxel and Trastuzumab – Cycle 1
Date: 06-02-2020
Text: Patient Name: Lisa Bowman
Date of Birth: 11/23/1967

Date of Infusion: 06/02/2020
Location: Outpatient Infusion Center

Clinical Backgrou
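The normalization and weighted combination described in Task 1 can be checked in isolation with toy numbers. The sketch below uses made-up raw scores on deliberately different scales (unbounded BM25 values, cosine similarities, cross-encoder logits), min-max normalizes each method, applies the default weights 0.4 / 0.3 / 0.3, and returns the candidates sorted by the fused score. A candidate missing from a method's scores contributes 0 for that method. `min_max_normalize` and `fuse` are hypothetical helpers, and plain string ids stand in for the docText keys.

```python
def min_max_normalize(scores):
    """Map {doc_id: score} to [0, 1]; if all scores are equal, each gets a neutral 0.5."""
    if not scores:
        return {}
    lo, hi = min(scores.values()), max(scores.values())
    if hi == lo:
        return {doc: 0.5 for doc in scores}
    return {doc: (s - lo) / (hi - lo) for doc, s in scores.items()}


def fuse(bm25, semantic, cross, weights=(0.4, 0.3, 0.3)):
    """Weighted sum of normalized scores over the cross-encoder candidates, highest first."""
    bm25_n, sem_n, cross_n = (min_max_normalize(s) for s in (bm25, semantic, cross))
    w_bm25, w_sem, w_cross = weights
    fused = {
        doc: w_bm25 * bm25_n.get(doc, 0) + w_sem * sem_n.get(doc, 0) + w_cross * cross_n[doc]
        for doc in cross
    }
    return sorted(fused.items(), key=lambda item: item[1], reverse=True)


# Toy raw scores (made-up numbers, not model outputs)
bm25 = {"a": 12.0, "b": 3.0, "c": 7.5}
semantic = {"a": 0.41, "b": 0.62, "c": 0.55}
cross = {"a": 2.3, "b": -1.0, "c": 4.1}
for doc, score in fuse(bm25, semantic, cross):
    print(doc, round(score, 3))
```

Here "c" ranks first although it tops only one of the three raw rankings, because it is strong on all of them; that is the point of fusing normalized scores rather than trusting any single ranker.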

## Task 2 - Medical Data Extraction (LLM-based Pipeline)

This is where we use the Qwen1.5-7B-Chat model to extract structured data. We create one function to generate the prompt and another to run the model and parse its output.

**Explanation:**

* setup_qwen_model: Loads the Qwen model and tokenizer, applying 4-bit (NF4) quantization to reduce memory usage. This is the same code provided in the README, encapsulated in a function for reusability. The model is placed on the GPU when CUDA is available; on CPU the function returns `model=None`, because the unquantised model needs more than 28 GB of RAM. It also sets `pad_token` to `eos_token`.

* generate_prompt: This function creates the prompt that will be fed to the LLM. It includes:
    * Clear Instructions: It tells the model its role ("medical information extraction expert") and what to extract.
    * Passage Context: It includes the passage_text.
    * Structured Output Format: It explicitly defines the JSON structure we want, including examples of each field. This is crucial for reliable JSON output.
    * Handling Null Values: The instructions clearly explain that if a particular data point can't be found, null should be used.

* extract_information: This function does the following:
    * Tokenization: It tokenizes the prompt using the Qwen tokenizer and moves the tensors to `device`.
    * Inference: It calls model.generate inside `with torch.no_grad()`, which disables gradient calculation, saving memory and speeding up inference. We use:
        * max_new_tokens: Limits the length of the generated text.
        * do_sample=True: Samples from the model's distribution instead of always taking the most likely token.
        * temperature=0.7: Sharpens the distribution moderately; lower values make the model less "creative" and more likely to stick to the instructions.
        * top_p=0.95 and top_k=50: Restrict sampling to the most likely tokens (nucleus and top-k sampling), cutting off the low-probability tail.
        * repetition_penalty=1.1: Discourages the model from repeating itself.
        * pad_token_id=tokenizer.eos_token_id: Sets the padding token ID explicitly, matching the tokenizer's pad_token.

    * Decoding: It decodes the generated output using the tokenizer.

    * JSON Extraction: It extracts the JSON part from the output. This is the most critical part. The greedy regex `\{.*\}` (with re.DOTALL) matches from the first `{` to the last `}`, handling cases where the model adds extra text before or after the JSON.
    * Error handling: try-except blocks catch json.JSONDecodeError (if the output isn't valid JSON) or ValueError (if no JSON object is found) and return None instead of crashing. The raw LLM output is printed for debugging purposes.
* Determinism: Because do_sample=True, repeated runs on the same passage can return different JSON. For reproducible extraction, switch to greedy decoding with do_sample=False; temperature, top_p and top_k are then ignored.

In [None]:
def setup_qwen_model():
    """Sets up the Qwen model, checking for CUDA and using 4-bit quantization if available."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")
    if device == "cuda":
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16
        )
        model = AutoModelForCausalLM.from_pretrained(
            "Qwen/Qwen1.5-7B-Chat", use_safetensors=True,
            low_cpu_mem_usage=True, quantization_config=quantization_config, device_map=device
        )
    else:
        print("CPU usage requires >28GB RAM; quantization (GPU only) is recommended.")
        model = None # Model will not be available
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-7B-Chat")
    tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer, device


def extract_information(model, tokenizer, device, passage, max_new_tokens=1024):
    """Extracts information from a passage using the Qwen model.

    Args:
        model: The loaded Qwen model.
        tokenizer: The Qwen tokenizer.
        device: The device ("cuda" or "cpu").
        passage: A dictionary containing the passage text ("docText").
        max_new_tokens: Maximum number of tokens to generate.

    Returns:
        A JSON object containing the extracted information, or None if extraction fails.
    """
    prompt = generate_prompt(passage["docText"])
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    with torch.no_grad(): # Disable gradient calculation for inference
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,  # Use sampling
            temperature=0.7, # Increased temperature slightly
            top_p=0.95,      # Top-p sampling
            top_k=50,       # top k sampling parameters
            repetition_penalty=1.1,  # Add a repetition penalty
            pad_token_id=tokenizer.eos_token_id #Set pad_token_id
        )

    # Decode only the newly generated tokens; outputs[0] also contains the prompt
    prompt_length = inputs["input_ids"].shape[1]
    decoded_output = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

    # --- JSON extraction ---
    try:
        # Greedy match from the first "{" to the last "}" (re is imported at the top of the notebook)
        json_match = re.search(r"\{.*\}", decoded_output, re.DOTALL)
        if json_match:
            return json.loads(json_match.group(0))
        raise ValueError("No JSON object found in output.")

    except (json.JSONDecodeError, ValueError) as e:
        print(f"Error extracting JSON: {e}")
        print(f"Raw LLM output:\n{decoded_output}")
        return None
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
        return None


def generate_prompt(passage_text):
    """Generates a ChatML-formatted prompt (the chat format used by Qwen models) for information extraction."""
    # The template is kept flush-left so that no indentation leaks into the prompt text
    prompt = f"""<|im_start|>system
You are a medical information extraction expert. Given the following Electronic Health Record (EHR) passage, extract the following information in JSON format. If specific information is not available in the text, use `null` for the corresponding value. Do *not* make up information. Return *only* the JSON, and nothing else.
<|im_end|>
<|im_start|>user
**Passage:**
{passage_text}

**Instructions:**

1.  **Diagnosis Characteristics:**
    *   `primary_cancer_condition`: The primary type of cancer (e.g., "Breast Cancer", "Lung Cancer").
    *   `diagnosis_date`: The earliest date the cancer was definitively diagnosed (MM-DD-YYYY).
    *   `histology`: A list of histological classifications (e.g., ["Adenocarcinoma", "Squamous Cell Carcinoma"]).
    *   `stage`: An object containing TNM staging information:
        *   `T`: Tumor size/extent (e.g., "T2").
        *   `N`: Lymph node involvement (e.g., "N1").
        *   `M`: Metastasis (e.g., "M0").
        *   `group_stage`: The overall stage grouping (e.g., "Stage IIA").

2.  **Cancer-Related Medications:**
    *   A list of objects, each representing a medication prescribed for *cancer treatment*.
    *   `medication_name`: The name of the medication (e.g., "Doxorubicin").
    *   `start_date`: The earliest date the medication was started (MM-DD-YYYY).
    *   `end_date`: The date the medication was stopped (MM-DD-YYYY).
    *   `intent`: The reason for the medication (e.g., "Adjuvant therapy", "Neoadjuvant therapy").

**Output (JSON Format):**
Return a single JSON object containing the fields above, with no code fences or extra text.
<|im_end|>
<|im_start|>assistant
"""
    return prompt.strip()
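The two halves of Task 2, wrapping text in ChatML turns and pulling a JSON object out of the reply, can be exercised without loading the model. The following stdlib-only sketch follows the `<|im_start|>` / `<|im_end|>` turn layout used by `generate_prompt` and the greedy `\{.*\}` match used in `extract_information`; `build_chatml_prompt` and `parse_json_object` are hypothetical helpers, and the model reply is made up.

```python
import json
import re


def build_chatml_prompt(system_msg, user_msg):
    """Wrap a system and a user message in ChatML turns, leaving the assistant turn open."""
    return (
        f"<|im_start|>system\n{system_msg}<|im_end|>\n"
        f"<|im_start|>user\n{user_msg}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )


def parse_json_object(llm_output):
    """Return the object between the first '{' and the last '}', or None if it is not valid JSON."""
    match = re.search(r"\{.*\}", llm_output, re.DOTALL)
    if match is None:
        return None
    try:
        return json.loads(match.group(0))
    except json.JSONDecodeError:
        return None


prompt = build_chatml_prompt("Return only JSON.", "Passage: ...")
# A made-up model reply with chatter and a code fence around the JSON
reply = 'Sure! ```json\n{"primary_cancer_condition": "Breast Cancer", "stage": null}\n``` Hope this helps.'
print(parse_json_object(reply))  # {'primary_cancer_condition': 'Breast Cancer', 'stage': None}
```

The greedy match is simple but fails when a reply contains two separate objects or stray braces in the surrounding text; scanning each `{` with `json.JSONDecoder().raw_decode` is a sturdier alternative. Inside the notebook, `tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)` can build the turns from the model's own template instead of a hand-written string.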

## Putting it all Together (Main Execution Block)

This section combines all the previous steps to create the complete pipeline.


**Explanation:**

* if __name__ == "__main__":: When the code runs as a script, this ensures the block only executes when the file is run directly (not when it's imported as a module). Inside a notebook, `__name__` is always `"__main__"`, so the cell simply runs.
* Data Loading: Loads all patient files from GitHub; if loading fails, the script reports the error and stops.
* Patient Selection: Extracts each patient's name from the "Patient Name:" line of their first document, lists the available patients, and asks which one to process (asking for a file number if several files share the same name).
* Mode Selection: Asks whether to run Query mode (Task 1) or Medical Data Extraction mode (Task 2).
* Setup Qwen Model: setup_qwen_model() loads the Qwen model and tokenizer; extraction mode stops if no GPU model is available.
* Query Mode: Calls combined_retrieval on the selected patient's passages and prints the top 5 results with their scores.
* Extraction Mode: Loops through all passages of the selected patient. For each passage:
    * Calls extract_information to get the structured data.
    * If extraction is successful, adds source information (filename, title, date) to the extracted data, which helps trace the origin of each extracted value.
    * Appends the result to all_extracted_data.
* Output: Prints all_extracted_data in two formats:
    * Pandas DataFrame: A formatted table view, making it easy to inspect the extracted data.
    * JSON: The required output format specified in the README. json.dumps(..., indent=2) produces pretty-printed JSON.
* Handles Empty Extraction: The final else block handles the case where no data was extracted, printing an informative message.

**Checking if GPU acceleration is possible**

* CPU only - No GPU acceleration possible
* T4 GPU - Supports GPU acceleration

**NOTE: CUDA 12.x DOESN'T FULLY SUPPORT BITSANDBYTES, SO IF ANY ISSUE PERSISTS, USE CUDA 11.8 FOR GPU ACCELERATION.**

In [None]:
if __name__ == "__main__":
    # --- Data Loading ---
    repo_url = "https://raw.githubusercontent.com/403errors/CancerCareAI/main/data"
    filenames = ["1.json", "2.json", "3.json"]
    try:
        patient_data = load_data_from_github(repo_url, filenames)
    except requests.exceptions.RequestException as e:
        # SystemExit ends a script; in a notebook it only stops this cell (the kernel keeps running)
        raise SystemExit(f"Data loading failed: {e}")

    # --- Patient Selection ---
    print("Available patients:")
    patient_names = {}
    for filename, records in patient_data.items():
        first_record = records[0]
        if "docText" in first_record:
            name_match = re.search(r"Patient Name:\s*([^\n]+)", first_record["docText"])
            if name_match:
                patient_name = name_match.group(1).strip()
                patient_names[filename] = patient_name
                print(f"- {patient_name} ({filename})")  # e.g. "Lisa Bowman (1.json)"
            else:
                print(f"- {filename} (Could not extract patient name)")
                patient_names[filename] = None
        else:
            print(f"- {filename} (Missing docText)")
            patient_names[filename] = None

    selected_patient = input("Enter the name of the patient you want to process: ").strip()

    # Find the file(s) for the selected patient name (case-insensitive); handle potential duplicates
    matching_files = [
        fname for fname, pname in patient_names.items()
        if pname and pname.lower() == selected_patient.lower()
    ]
    if not matching_files:
        raise SystemExit(f"Error: No patient found with the name '{selected_patient}'.")
    elif len(matching_files) > 1:
        print("Multiple files found for this patient. Please select one:")
        for i, fname in enumerate(matching_files):
            print(f"{i+1}. {fname}")
        choice = int(input("Enter the number of the file: ")) - 1  # User enters 1, 2, ...
        selected_file = matching_files[choice]
    else:
        selected_file = matching_files[0]

    selected_patient_data = {selected_file: patient_data[selected_file]}

    # --- Mode Selection ---
    print("\nSelect a mode:")
    print("1. Query (Information Retrieval)")
    print("2. Medical Data Extraction")
    mode = input("Enter your choice (1 or 2): ").strip()

    if mode == "1":
        # --- Query Mode (Task 1): no LLM needed ---
        query = input("Enter your query: ")
        passages = create_passages(selected_patient_data)
        retrieved_passages = combined_retrieval(query, passages)

        print("\nTop Retrieved Passages:")
        for passage, score in retrieved_passages[:5]:
            print(f"Score: {score:.3f}")
            print(f"File: {passage['patient_file']}")
            print(f"Title: {passage['docTitle']}")
            print(f"Date: {passage['docDate']}")
            print(f"Text: {passage['docText'][:200]}...")
            print("-" * 50)

    elif mode == "2":
        # --- Medical Data Extraction Mode (Task 2) ---
        model, tokenizer, device = setup_qwen_model()  # Load the 7B model only when it is needed
        if model is None:
            raise SystemExit("Model is not available on CPU. Exiting.")

        all_extracted_data = []
        passages = create_passages(selected_patient_data)  # Passages of the selected patient
        for passage in tqdm(passages, desc="Extracting"):  # Iterate through all passages
            extracted_data = extract_information(model, tokenizer, device, passage)
            if extracted_data:
                extracted_data["source_file"] = passage["patient_file"]
                extracted_data["source_title"] = passage["docTitle"]
                extracted_data["source_date"] = passage["docDate"]
                # No retrieval score in this mode
                all_extracted_data.append(extracted_data)

        if all_extracted_data:
            df = pd.DataFrame(all_extracted_data)
            print("\nExtracted Data (DataFrame):")
            print(df)
            print("\nExtracted Data (JSON):")
            print(json.dumps(all_extracted_data, indent=2))
        else:
            print("No data extracted.")

    else:
        print("Invalid mode selected.")

In [None]:
# !python -c "import transformers; print(transformers.__version__)"
# !python -c "import torch; print(torch.__version__)"
# !python -c "import bitsandbytes; print(bitsandbytes.__version__)"

In [None]:
# from transformers.utils import is_bitsandbytes_available
# print("Is BitsAndBytes Available for Transformers:", is_bitsandbytes_available())

In [None]:
# if __name__ == "__main__":
#     # --- Step 1: Data Loading (already done above) ---

#     if patient_data is None:  # Check if data loading failed
#         print("Data loading failed. Exiting.")
#         exit()

#     # --- Step 2: Information Retrieval ---
#     passages = create_passages(patient_data)
#     query = "Has the patient undergone chemotherapy?"  # Example query
#     retrieved_passages = combined_retrieval(query, passages)

#     # --- Step 3: Information Extraction ---
#     model, tokenizer, device = setup_qwen_model()

#     all_extracted_data = []
#     for passage, score in retrieved_passages:
#         extracted_data = extract_information(model, tokenizer, device, passage)
#         if extracted_data:
#             # Add source information to the extracted data
#             extracted_data["source_file"] = passage["patient_file"]
#             extracted_data["source_title"] = passage["docTitle"]
#             extracted_data["source_date"] = passage["docDate"]
#             extracted_data["retrieval_score"] = score

#             all_extracted_data.append(extracted_data)

#     # --- Step 4: Output and Display ---

#     # Convert to Pandas DataFrame for easier viewing and analysis
#     if all_extracted_data:
#       df = pd.DataFrame(all_extracted_data)
#       print("\nExtracted Data (DataFrame):")
#       print(df)

#       # Print as JSON
#       print("\nExtracted Data (JSON):")
#       print(json.dumps(all_extracted_data, indent=2))
#     else:
#        print("No data extracted.")