In [None]:
import json

# Location of the MedQA US test split that the next lines read line by line.
file_path = (
    r"C:\Users\LEGION\Documents\GIT\ParsBench-biomedical\MedQA"
    r"\data_clean_2\data_clean\questions\US\test.jsonl"
)

def read_jsonl(file_path):
    """
    Read a JSON Lines file into a list of decoded objects.

    Args:
        file_path: Path to a UTF-8 encoded file with one JSON value per line.

    Returns:
        A list with one decoded value per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    data = []
    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            stripped = line.strip()
            # Blank lines (e.g. a stray empty line at the end) carry no record;
            # json.loads("") would raise on them.
            if not stripped:
                continue
            data.append(json.loads(stripped))
    return data

# Load the whole file into memory, then echo a small sample as a sanity check.
jsonl_data = read_jsonl(file_path)

preview = jsonl_data[:5]
for item in preview:
    print(item)

In [None]:
# Every distinct top-level key across all records, in order of first appearance.
# dict.fromkeys keeps insertion order and drops duplicates in one pass.
qa_object_keys = list(dict.fromkeys(
    key for qa_object in jsonl_data for key in qa_object
))

print(qa_object_keys)

# Meta-info

## LLM-based text classification

In [None]:
#% pip install pandas pydantic openai tenacity openpyxl tiktoken langchain langchain-community
import json
import os
import logging
from typing import Dict, List
from pydantic import BaseModel, Field
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.output_parsers import PydanticOutputParser
from langchain.schema import HumanMessage
from langchain.callbacks import get_openai_callback
from tenacity import retry, stop_after_attempt, wait_random_exponential
import tiktoken

# Constants
NUMBER_OF_RETRY = 2  # total attempts per LLM call (passed to tenacity's stop_after_attempt)
MODEL_NAME = "gpt-4o"  # chat model, also used to pick the tiktoken encoding
SAVE_INTERVAL = 5  # rewrite the output file after this many newly labeled questions

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# System prompt prepended to every classification request
SYSTEM_PROMPT = (
    "You are an expert in medical education and classification. Your task is to categorize medical questions accurately "
    "according to their relevant systems, disciplines, specialties, and subspecialties, as well as classify them according "
    "to Bloom's taxonomy. Ensure your responses are concise, accurate, and align with the descriptions provided."
)

# Environment variable that holds the OpenAI API key for this project.
OPENAI_API_KEY_ENV = "OPENAI_API_KEY_Nar_Amir_MedQA"
openai_api_key = os.getenv(OPENAI_API_KEY_ENV)
if openai_api_key is None:
    # Name the variable that is actually read; the previous message pointed at
    # OPENAI_API_KEY, which this code never consults.
    raise ValueError(f"{OPENAI_API_KEY_ENV} environment variable is not set or invalid.")

# Initialize the OpenAI chat model using LangChain
chat = ChatOpenAI(model_name=MODEL_NAME, openai_api_key=openai_api_key)

class QuestionData(BaseModel):
    """One raw MedQA question as read from the input JSONL file.

    Field order determines the key order of ``.dict()``, which is splatted into
    ``LabeledQuestionData`` when labels are added.
    """
    question: str  # question stem text
    answer: str  # text of the correct answer
    options: Dict[str, str]  # option key -> option text
    meta_info: str  # dataset-provided meta label (NOTE(review): presumably the exam step — confirm)
    answer_idx: str  # presumably the key in `options` of the correct answer — confirm against the dataset

# Output schema for the Bloom's-taxonomy classification call.
# NOTE: deliberately no class docstring — PydanticOutputParser builds its format
# instructions from this model's JSON schema, so a docstring would become a
# "description" entry and change the prompt sent to the LLM. The Field description
# below IS part of the prompt; edit it only to change model instructions.
class BloomTaxonomy(BaseModel):
    bloom_taxonomy: str = Field(
        description=(
            "Classify the cognitive level required to answer the question according to Bloom's Taxonomy. "
            "Select one of the following levels based on the cognitive process involved:\n"
            
            "'Remembering': Focuses on recalling specific medical facts or basic concepts. "
            "Example: 'What is the mechanism of action of beta-blockers?'\n"
            
            "'Understanding': Involves explaining medical concepts or interpreting clinical data. "
            "Example: 'How would you explain the pathophysiology of type 2 diabetes?'\n"
            
            "'Applying': Requires using medical knowledge in clinical scenarios or solving clinical problems. "
            "Example: 'What is the appropriate initial treatment for a patient presenting with acute myocardial infarction?'\n"
            
            "'Analyzing': Involves breaking down clinical cases or understanding relationships between medical concepts. "
            "Example: 'How would you differentiate between the clinical features of Crohn's disease and ulcerative colitis?'\n"
            
            "'Evaluating': Focuses on making clinical judgments or assessing the validity of diagnostic or treatment options. "
            "Example: 'Which of the following criteria best supports the diagnosis of metabolic syndrome in a patient?'\n"
            
            "'Creating': Involves generating new ideas or designing solutions to complex medical problems. "
            "Example: 'What innovative treatment strategy would you propose for a patient with refractory hypertension?'"
        )
    )

# Output schema for the topic/discipline/speciality labeling call.
# NOTE: no class docstring on purpose — the parser turns this model's JSON schema
# into format instructions, so a docstring would be injected into the prompt.
# The Field descriptions are prompt text sent to the LLM.
# Field names match LabeledQuestionData so they can be copied across one-to-one.
class QuestionLabels(BaseModel):
    meta_info_TopicSystem: str = Field(
        description=(
            "The primary system or topic classification of the question within the context of USMLE topics. "
            "This should identify the broader category of medical knowledge that the question falls under. "
            "Examples include 'Cardiovascular System', 'Reproductive & Endocrine Systems', 'Gastrointestinal System'."
        )
    )
    meta_info_TopicDiscipline: str = Field(
        description=(
            "The specific discipline or subject area that this question relates to. This should be a focused category "
            "that falls under the broader system. For Step 1 questions, examples include 'Pharmacology', 'Genetics', "
            "'Histology & Cell Biology'. For Step 2 questions, this could relate to 'Diagnosis', 'Investigation', 'Treatment'."
        )
    )
    meta_info_Speciality: str = Field(
        description=(
            "The major class or field of study to which the question belongs. This should categorize the question "
            "at a high level, distinguishing between fields such as 'Medicine', 'Dentistry', 'Pharmacology', etc."
        )
    )
    meta_info_SubSpeciality: str = Field(
        description=(
            "The specific subspecialty within the major class that is relevant to the question. This should be a more "
            "granular classification such as 'Gastroenterology', 'Orthodontics', 'Surgery', etc."
        )
    )

class LabeledQuestionData(BaseModel):
    """A QuestionData record extended with LLM-generated labels and token counts.

    This is the record shape written (via ``.dict()``) to the labeled output JSONL.
    Field order determines the key order of each saved JSON object.
    """
    # Original question fields, copied unchanged from QuestionData.
    question: str
    answer: str
    options: Dict[str, str]
    meta_info: str
    answer_idx: str
    # Labels produced by the QuestionLabels request.
    meta_info_TopicSystem: str
    meta_info_TopicDiscipline: str
    meta_info_Speciality: str
    meta_info_SubSpeciality: str
    # Bloom level produced by the BloomTaxonomy request.
    meta_info_BloomTaxonomy: str
    # tiktoken lengths: "question_Tokenlength" plus one "option_<key>_Tokenlength" per option.
    meta_info_TokenLength: Dict[str, int] = Field(
        description="A dictionary containing the token length for the question and each option."
    )

@retry(
    wait=wait_random_exponential(min=1, max=60),
    stop=stop_after_attempt(NUMBER_OF_RETRY),
    # Surface the last underlying error rather than tenacity.RetryError, so the
    # caller's error log and saved error entry name the real cause.
    reraise=True,
)
def classify_bloom_taxonomy(data: QuestionData) -> BloomTaxonomy:
    """
    Classify the given question according to Bloom's taxonomy.

    Sends SYSTEM_PROMPT plus the question, its answer and its options to the
    module-level ``chat`` model, and parses the reply into ``BloomTaxonomy``.
    The cost of each attempt is logged.

    Args:
        data: The question to classify.

    Returns:
        The parsed ``BloomTaxonomy`` object.

    Raises:
        Whatever the final attempt raised (e.g. an API error or an output-parsing
        error) once NUMBER_OF_RETRY attempts have failed.
    """
    bloom_parser = PydanticOutputParser(pydantic_object=BloomTaxonomy)

    bloom_prompt = PromptTemplate(
        template=(
            f"{SYSTEM_PROMPT}\n\n"
            "Classify the following medical question according to Bloom's taxonomy.\n{format_instructions}\n"
            "Question: {question}\nAnswer: {answer}\nOptions: {options}"
        ),
        input_variables=["question", "answer", "options"],
        partial_variables={"format_instructions": bloom_parser.get_format_instructions()},
    )

    query = bloom_prompt.format(
        question=data.question,
        answer=data.answer,
        options=", ".join([f"{k}: {v}" for k, v in data.options.items()])
    )

    with get_openai_callback() as cb:
        response = chat([HumanMessage(content=query)])
        logging.info(f"Cost for Bloom Taxonomy classification: ${cb.total_cost:.6f}")
        return bloom_parser.parse(response.content)

@retry(
    wait=wait_random_exponential(min=1, max=60),
    stop=stop_after_attempt(NUMBER_OF_RETRY),
    # Surface the last underlying error rather than tenacity.RetryError, so the
    # caller's error log and saved error entry name the real cause.
    reraise=True,
)
def generate_labels(data: QuestionData) -> QuestionLabels:
    """
    Generate topic and discipline labels for the given question.

    Sends SYSTEM_PROMPT plus the question, its answer and its options to the
    module-level ``chat`` model, and parses the reply into ``QuestionLabels``
    (system, discipline, speciality, subspeciality). The cost of each attempt is logged.

    Args:
        data: The question to label.

    Returns:
        The parsed ``QuestionLabels`` object.

    Raises:
        Whatever the final attempt raised (e.g. an API error or an output-parsing
        error) once NUMBER_OF_RETRY attempts have failed.
    """
    label_parser = PydanticOutputParser(pydantic_object=QuestionLabels)

    label_prompt = PromptTemplate(
        template=(
            f"{SYSTEM_PROMPT}\n\n"
            "Provide the best classification for the following categories for this medical question.\n"
            "{format_instructions}\nQuestion: {question}\nAnswer: {answer}\nOptions: {options}"
        ),
        input_variables=["question", "answer", "options"],
        partial_variables={"format_instructions": label_parser.get_format_instructions()},
    )

    query = label_prompt.format(
        question=data.question,
        answer=data.answer,
        options=", ".join([f"{k}: {v}" for k, v in data.options.items()])
    )

    with get_openai_callback() as cb:
        response = chat([HumanMessage(content=query)])
        logging.info(f"Cost for generating labels: ${cb.total_cost:.6f}")
        return label_parser.parse(response.content)

def count_tokens(data: QuestionData) -> Dict[str, int]:
    """
    Measure the tiktoken length of the question stem and of every answer option.

    Uses the encoding tiktoken associates with MODEL_NAME. The result holds
    ``"question_Tokenlength"`` first, then one ``"option_<key>_Tokenlength"``
    entry per option, in the options' own order.
    """
    encoder = tiktoken.encoding_for_model(MODEL_NAME)

    def n_tokens(text: str) -> int:
        return len(encoder.encode(text))

    lengths = {"question_Tokenlength": n_tokens(data.question)}
    for option_key, option_text in data.options.items():
        lengths[f"option_{option_key}_Tokenlength"] = n_tokens(option_text)
    return lengths

def add_labels_classification_and_tokens(data: QuestionData) -> LabeledQuestionData:
    """
    Build the fully labeled record for one question.

    Runs the topic-label request, then the Bloom-level request, then counts
    tokens, and merges all results with the original question fields.
    """
    # Same call order as before: topic labels, Bloom level, token counts.
    label_fields = generate_labels(data).dict()  # exactly the four meta_info_* label fields
    bloom_level = classify_bloom_taxonomy(data).bloom_taxonomy
    token_lengths = count_tokens(data)

    return LabeledQuestionData(
        **data.dict(),
        **label_fields,
        meta_info_BloomTaxonomy=bloom_level,
        meta_info_TokenLength=token_lengths,
    )

def read_existing_data(output_path: str) -> List[Dict]:
    """
    Read previously saved records from a JSONL output file.

    Args:
        output_path: Path of the UTF-8 JSONL file written by ``save_labeled_data``.

    Returns:
        One decoded object per non-blank line, in file order. Returns an empty
        list when the file does not exist yet.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    if not os.path.exists(output_path):
        return []

    records = []
    with open(output_path, 'r', encoding='utf-8') as file:
        for line in file:
            stripped = line.strip()
            # Tolerate blank lines (e.g. one left after a manual edit); json.loads("") would raise.
            if not stripped:
                continue
            records.append(json.loads(stripped))
    return records

def save_labeled_data(labeled_data: List[Dict], output_path: str):
    """
    Write all records to ``output_path`` as UTF-8 JSON Lines, replacing its contents.

    The whole accumulated list is rewritten on every call. To avoid truncating
    earlier results if the process is interrupted mid-write, the data goes to a
    sibling ``.tmp`` file first and is then moved over the target with
    ``os.replace`` (atomic on POSIX; a single rename call on Windows).

    Args:
        labeled_data: JSON-serializable records, one per output line.
        output_path: Destination file path.
    """
    tmp_path = output_path + ".tmp"
    try:
        with open(tmp_path, 'w', encoding='utf-8') as file:
            for item in labeled_data:
                json.dump(item, file, ensure_ascii=False)
                file.write('\n')
        os.replace(tmp_path, output_path)
    except BaseException:
        # Do not leave a half-written temp file behind; the previous output stays intact.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    logging.info(f"Saved {len(labeled_data)} labeled items to {output_path}")

def _recorded_question(item: Dict):
    """Return the question text stored in a saved record (labeled or error entry), else None."""
    if 'question' in item:
        return item['question']
    # Error entries written by process_labeling keep the input under "original_data".
    original = item.get('original_data')
    if isinstance(original, dict):
        return original.get('question')
    return None


def process_labeling(input_path: str, output_path: str, N: int, continue_from_existing: bool = True):
    """
    Label up to ``N`` new questions from ``input_path`` and save all records to ``output_path``.

    Questions whose text already appears in the loaded output are skipped, so an
    interrupted run can be resumed. This covers both labeled records and error
    entries (``{"error", "original_data"}``); delete an error line from the output
    to have that question retried.

    Args:
        input_path: UTF-8 JSONL file of raw questions (``QuestionData`` fields).
        output_path: JSONL file that accumulates records; rewritten in full on each save.
        N: Maximum number of questions to label successfully in this call. Failed
            questions are recorded as error entries but do not count toward ``N``.
        continue_from_existing: Load and keep the records already in ``output_path``.
    """
    # Load existing data if continuing
    labeled_data = read_existing_data(output_path) if continue_from_existing else []

    # Question texts already handled; error entries (no top-level 'question') are
    # included via their original_data instead of raising KeyError.
    processed_ids = {
        question for question in map(_recorded_question, labeled_data) if question is not None
    }

    # Ensure the output directory exists (dirname is '' for a bare filename,
    # and os.makedirs('') would raise).
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    newly_processed_count = 0
    total_processed_count = len(labeled_data)
    unsaved_changes = False

    with open(input_path, 'r', encoding='utf-8') as file:
        for line_number, line in enumerate(file, start=1):
            if newly_processed_count >= N:
                break
            raw = line.strip()
            if not raw:
                continue  # blank line: nothing to label
            json_obj = None  # never report a previous line's data if json.loads fails
            try:
                json_obj = json.loads(raw)
                question_data = QuestionData(**json_obj)

                # Skip if this question has already been processed
                if question_data.question in processed_ids:
                    continue

                logging.info(f"Started Labeling of #{total_processed_count + 1}: {question_data}")
                labeled_question_data = add_labels_classification_and_tokens(question_data)
                logging.info(f"Finished Labeling of #{total_processed_count + 1}: {labeled_question_data}")

                labeled_data.append(labeled_question_data.dict())
                processed_ids.add(question_data.question)
                newly_processed_count += 1
                total_processed_count += 1
                unsaved_changes = True

                # Save data every SAVE_INTERVAL runs or on the last run
                if newly_processed_count % SAVE_INTERVAL == 0 or newly_processed_count == N:
                    save_labeled_data(labeled_data, output_path)
                    unsaved_changes = False

            except json.JSONDecodeError:
                logging.error(f"Error decoding JSON on line {line_number}")
                logging.error(f"Original data: {raw}")
            except Exception as e:
                logging.error(f"Error processing line {line_number}: {str(e)}")
                logging.error(f"Original data: {json_obj}")

                # Include the original data in the labeled_data
                error_entry = {
                    "error": str(e),
                    "original_data": json_obj
                }
                labeled_data.append(error_entry)
                failed_question = _recorded_question(error_entry)
                if failed_question is not None:
                    processed_ids.add(failed_question)
                total_processed_count += 1

                # Save data after each error
                save_labeled_data(labeled_data, output_path)
                unsaved_changes = False

    # If the input ran out before N items, the last batch may not have reached a
    # save point; flush it so that work is not lost.
    if unsaved_changes:
        save_labeled_data(labeled_data, output_path)

    logging.info(f"Labeling, Bloom taxonomy classification, and token counting completed. "
                 f"Newly processed {newly_processed_count} items. "
                 f"Total processed {total_processed_count} items. "
                 f"Data saved to {output_path}")
    
    
# Command-line variant of the call below (kept for reference, unused in the notebook):
# if __name__ == "__main__":
#     import argparse
#
#     parser = argparse.ArgumentParser(description="Process and label medical questions.")
#     parser.add_argument("input_path", help="Path to the input JSONL file")
#     parser.add_argument("output_path", help="Path to the output JSONL file")
#     parser.add_argument("--n", type=int, default=2, help="Number of questions to process")
#     parser.add_argument("--continue_from_existing", action="store_true", help="Continue from existing output file")
#
#     args = parser.parse_args()
#
#     process_labeling(args.input_path, args.output_path, args.n, args.continue_from_existing)

# Notebook run: label the US test split, resuming from any existing output file.
input_path = r"C:\Users\LEGION\Documents\GIT\ParsBench-biomedical\MedQA\data_clean_2\data_clean\questions\US\test.jsonl"
output_path = r"C:\Users\LEGION\Documents\GIT\ParsBench-biomedical\MedQA\US-test_labeled_with_bloom_and_tokens.jsonl"
n = 1  # number of new questions to label in this run
continue_from_existing = True
process_labeling(
    input_path=input_path,
    output_path=output_path,
    N=n,
    continue_from_existing=continue_from_existing,
)

## Unify Labels

### Generating Unified Terms for meta_info

In [None]:
import json
from collections import defaultdict

def extract_categories_and_terms(jsonl_file):
    """
    Collect the distinct values of every ``meta_info_*`` field in a labeled JSONL file.

    ``meta_info_TokenLength`` is excluded (its values are dicts, not labels).

    Args:
        jsonl_file: Path of the UTF-8 JSONL file written by the labeling step.

    Returns:
        A dict mapping category suffix (e.g. ``"TopicSystem"``) to a sorted list of
        the distinct values seen for it.
    """
    categories = defaultdict(set)

    # The labeled file is written as UTF-8 with ensure_ascii=False; the platform
    # default encoding (e.g. cp1252 on Windows) would mis-decode non-ASCII text.
    with open(jsonl_file, 'r', encoding='utf-8') as file:
        for line in file:
            if not line.strip():
                continue  # skip blank lines instead of failing in json.loads
            record = json.loads(line)
            for key, value in record.items():
                if key.startswith('meta_info_') and key != 'meta_info_TokenLength':
                    category = key.split('_')[-1]
                    categories[category].add(value)

    # Convert sets to sorted lists for consistent output
    return {k: sorted(v) for k, v in categories.items()}

# Usage: list every distinct meta_info_* value per category of the labeled file.
jsonl_file = r"C:\Users\LEGION\Documents\GIT\ParsBench-biomedical\MedQA\US-test_labeled_with_bloom_and_tokens.jsonl"
extracted_categories = extract_categories_and_terms(jsonl_file)

pretty = json.dumps(extracted_categories, indent=2)
print(pretty)

In [None]:
import json
from collections import defaultdict
import os
from typing import List, Dict
from pydantic import BaseModel, Field
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain.chains import LLMChain
from langchain.output_parsers import PydanticOutputParser
from langchain.cache import SQLiteCache
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed

# Same format as the labeling cell; basicConfig does nothing if the root logger
# already has handlers (e.g. from an earlier cell).
_log_format = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(level=logging.INFO, format=_log_format)

# Route LangChain LLM calls through an on-disk SQLite cache (.langchain.db,
# relative to the working directory), so identical prompts can be answered from it.
from langchain.globals import set_llm_cache
_llm_cache = SQLiteCache(database_path=".langchain.db")
set_llm_cache(_llm_cache)

# Pydantic schemas for the unified-term responses.
# NOTE: these are handed to PydanticOutputParser, whose format instructions are
# generated from the JSON schema, so class docstrings would be injected into the
# prompt. Document with comments only; Field descriptions ARE prompt text.

# One consolidated USMLE system label.
class TopicSystemTerm(BaseModel):
    term: str = Field(description="The unified USMLE system-based term")
    reasoning: str = Field(description="Explanation for this unification")

# One speciality label together with the audience category it was assigned to.
class SpecialityTerm(BaseModel):
    term: str = Field(description="The unified speciality term")
    category: str = Field(description="Category of the speciality (e.g., Medical Practitioner, Dental Student)")
    reasoning: str = Field(description="Explanation for this classification")

# One subspeciality label plus the clinical rotation it belongs to.
class SubSpecialityTerm(BaseModel):
    term: str = Field(description="The unified subspeciality term")
    rotation: str = Field(description="The medical rotation this subspeciality is associated with")
    reasoning: str = Field(description="Explanation for this unification and rotation assignment")

# One consolidated discipline label.
class TopicDisciplineTerm(BaseModel):
    term: str = Field(description="The unified topic discipline term")
    reasoning: str = Field(description="Explanation for this unification")

# Top-level response wrappers, one per category; each parses to {"unified_terms": [...]}.
class UnifiedTermsTopicSystem(BaseModel):
    unified_terms: List[TopicSystemTerm] = Field(description="List of unified TopicSystem terms")

class UnifiedTermsSpeciality(BaseModel):
    unified_terms: List[SpecialityTerm] = Field(description="List of unified Speciality terms")

class UnifiedTermsSubSpeciality(BaseModel):
    unified_terms: List[SubSpecialityTerm] = Field(description="List of unified SubSpeciality terms")

class UnifiedTermsTopicDiscipline(BaseModel):
    unified_terms: List[TopicDisciplineTerm] = Field(description="List of unified TopicDiscipline terms")

def extract_categories_and_terms(jsonl_file: str) -> Dict[str, List[str]]:
    """
    Collect the distinct label values per ``meta_info_*`` category of a labeled JSONL file.

    ``meta_info_TokenLength`` (dict values) and ``meta_info_BloomTaxonomy`` (a fixed
    set of levels that needs no unification) are excluded.

    Args:
        jsonl_file: Path of the UTF-8 JSONL file written by the labeling step.

    Returns:
        A dict mapping category suffix (e.g. ``"TopicSystem"``) to a sorted list of
        the distinct values seen for it.
    """
    categories = defaultdict(set)

    # The labeled file is UTF-8 (written with ensure_ascii=False); don't rely on
    # the platform default encoding, which is cp1252 on Windows.
    with open(jsonl_file, 'r', encoding='utf-8') as file:
        for line in file:
            if not line.strip():
                continue  # skip blank lines instead of failing in json.loads
            record = json.loads(line)
            for key, value in record.items():
                if key.startswith('meta_info_') and key not in ('meta_info_TokenLength', 'meta_info_BloomTaxonomy'):
                    category = key.split('_')[-1]
                    categories[category].add(value)

    return {k: sorted(v) for k, v in categories.items()}

def get_prompt_template(category: str) -> str:
    """
    Return the LLM prompt template used to unify the labels of ``category``.

    Templates for "TopicSystem", "Speciality", "SubSpeciality" and "TopicDiscipline"
    list suggested target vocabularies; any other category gets a generic
    consolidation prompt. Every template contains the ``{terms}`` and
    ``{format_instructions}`` placeholders; the generic one also has ``{category}``.

    NOTE: the returned text (including its leading indentation) is sent verbatim
    to the model, so whitespace edits change the prompt.
    """
    templates = {
        "TopicSystem": """
        Given the following list of terms for the category 'TopicSystem', create a unified list based on USMLE system classifications. Provide granular system-based labels that align with USMLE categories. Explain your reasoning for each unification.

        Potential USMLE system classifications include:
        - General Principles
        - Behavioral Health
        - Biostatistics & Epidemiology
        - Biochemistry
        - Cardiovascular System
        - Endocrine System
        - Gastrointestinal System
        - Hematologic System
        - Immune System
        - Musculoskeletal System & Skin
        - Nervous System & Special Senses
        - Renal & Urinary System
        - Reproductive System
        - Respiratory System

        Terms: {terms}

        {format_instructions}
        """,
        "Speciality": """
        Given the following list of terms for the category 'Speciality', classify each term into one of these categories: Medical Practitioner, Dental Student, Pharmacy Student, or General Biomedical Student. If a term doesn't fit these categories, choose the most appropriate one or create a new category if necessary. Explain your reasoning for each classification.

        Potential categories include:
        - Medical Practitioner
        - Dental Student
        - Pharmacy Student
        - General Biomedical Student
        - Nursing Student
        - Allied Health Professional
        - Public Health Professional

        Terms: {terms}

        {format_instructions}
        """,
        "SubSpeciality": """
        Given the following list of terms for the category 'SubSpeciality', create a unified list of granular medical subspecialties. Consider which rotation a medical student would encounter this topic in. Be as specific as possible while maintaining clinical relevance. Explain your reasoning for each unification.

        Potential subspecialities and rotations include:
        - Internal Medicine (e.g., Cardiology, Gastroenterology, Pulmonology)
        - Surgery (e.g., General Surgery, Orthopedics, Neurosurgery)
        - Pediatrics
        - Obstetrics and Gynecology
        - Psychiatry
        - Neurology
        - Emergency Medicine
        - Radiology
        - Anesthesiology
        - Pathology
        - Family Medicine
        - Dermatology
        - Ophthalmology
        - Otolaryngology
        - Urology
        - Physical Medicine and Rehabilitation

        Terms: {terms}

        {format_instructions}
        """,
        "TopicDiscipline": """
        Given the following list of terms for the category 'TopicDiscipline', create a unified list where similar terms are consolidated into a single, representative term. Provide your reasoning for each consolidation.

        Potential topic disciplines include:
        - Anatomy
        - Physiology
        - Pathology
        - Pharmacology
        - Microbiology
        - Immunology
        - Genetics
        - Biochemistry
        - Epidemiology
        - Biostatistics
        - Medical Ethics
        - Clinical Skills
        - Diagnosis
        - Treatment
        - Prevention

        Terms: {terms}

        {format_instructions}
        """
    }
    
    # Fallback for categories without a dedicated template; '{category}' here is a
    # template variable filled in at run time, not an f-string substitution.
    return templates.get(category, """
        Given the following list of terms for the category '{category}', create a unified list where similar terms are consolidated into a single, representative term. Provide your reasoning for each consolidation.

        Terms: {terms}

        {format_instructions}
        """)

def get_unified_terms(category: str, terms: List[str]) -> Dict:
    """
    Ask the LLM to consolidate one category's raw labels into a unified vocabulary.

    Args:
        category: Category name (e.g. "TopicSystem"). It selects the prompt template
            and the Pydantic schema that the response is parsed into. Categories
            without a dedicated schema fall back to the TopicDiscipline schema.
        terms: Distinct raw labels observed for this category.

    Returns:
        The parsed response as a dict of the form ``{"unified_terms": [...]}``, or
        ``{"unified_terms": []}`` if none of the attempts produced parseable output.
    """
    llm = ChatOpenAI(model="gpt-4", temperature=0.7)

    schema_by_category = {
        "TopicSystem": UnifiedTermsTopicSystem,
        "Speciality": UnifiedTermsSpeciality,
        "SubSpeciality": UnifiedTermsSubSpeciality,
    }
    parser = PydanticOutputParser(
        pydantic_object=schema_by_category.get(category, UnifiedTermsTopicDiscipline)
    )

    prompt = ChatPromptTemplate(
        messages=[
            HumanMessagePromptTemplate.from_template(get_prompt_template(category))
        ],
        input_variables=["category", "terms"],
        partial_variables={"format_instructions": parser.get_format_instructions()}
    )

    chain = LLMChain(llm=llm, prompt=prompt)
    uncached_chain = None

    max_attempts = 3
    for attempt in range(max_attempts):
        # This notebook installs a global SQLite LLM cache via set_llm_cache. The
        # model's reply is cached even when it later fails to parse, so re-sending
        # the identical prompt through a cache-enabled model would replay the same
        # bad reply on every retry. Retries therefore use a model with caching off.
        if attempt == 0:
            active_chain = chain
        else:
            if uncached_chain is None:
                uncached_chain = LLMChain(
                    llm=ChatOpenAI(model="gpt-4", temperature=0.7, cache=False),
                    prompt=prompt,
                )
            active_chain = uncached_chain
        try:
            response = active_chain.run(category=category, terms=", ".join(terms))
            parsed_response = parser.parse(response)
            return parsed_response.dict()
        except Exception as e:
            logging.warning(f"Attempt {attempt + 1} for {category}: Error: {str(e)}. Retrying...")

    logging.error(f"Failed to get valid output for {category} after {max_attempts} attempts. Returning empty result.")
    return {"unified_terms": []}

def process_category(category: str, terms: List[str]) -> Dict[str, Dict]:
    """Unify one category's terms and return the result keyed by the category name."""
    logging.info(f"Processing category: {category}")
    unified = get_unified_terms(category, terms)
    return {category: unified}

def main():
    """
    Unify the labels of every extracted meta_info category and save them.

    Categories are processed concurrently in a thread pool. A category whose
    worker raises is logged and left out of the result. The merged result is
    written to ``unified_terms.json`` in the working directory.
    """
    jsonl_file = r"C:\Users\LEGION\Documents\GIT\ParsBench-biomedical\MedQA\US-test_labeled_with_bloom_and_tokens.jsonl"
    categories = extract_categories_and_terms(jsonl_file)

    unified_terms = {}
    with ThreadPoolExecutor() as executor:
        pending = {
            executor.submit(process_category, category, terms): category
            for category, terms in categories.items()
        }
        for done in as_completed(pending):
            category = pending[done]
            try:
                unified_terms.update(done.result())
            except Exception as exc:
                logging.error(f"{category} generated an exception: {exc}")

    with open('unified_terms.json', 'w') as f:
        json.dump(unified_terms, f, indent=2)

    logging.info("Unified terms have been saved to unified_terms.json")


if __name__ == "__main__":
    main()

### Create map_terms (original label -> unified term)

In [None]:
import json
import os
from typing import Dict, List
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.chains import LLMChain

# Unified vocabulary produced by the previous step, keyed by category name.
unified_terms_path = r"C:\Users\LEGION\Documents\GIT\ParsBench-biomedical\unified_terms.json"
with open(unified_terms_path, "r") as f:
    unified_terms = json.load(f)

# Low temperature to keep the label-to-term mapping fairly deterministic.
llm = ChatOpenAI(temperature=0.2, model="gpt-4o")

def get_mapping(category: str, term: str) -> str:
    """
    Ask the LLM which unified term of ``category`` best matches the raw label ``term``.

    Uses the module-level ``llm`` and ``unified_terms``. The reply is returned with
    surrounding whitespace removed and is not validated against the vocabulary.
    """
    template_text = (
        "Given the following unified terms for the category '{category}':\n"
        "{unified_terms}\n\n"
        "What is the most appropriate unified term for '{term}'? "
        "Respond with only the unified term, nothing else."
    )
    mapping_chain = LLMChain(llm=llm, prompt=ChatPromptTemplate.from_template(template_text))

    candidates = json.dumps(unified_terms[category], indent=2)
    reply = mapping_chain.run(category=category, unified_terms=candidates, term=term)
    return reply.strip()

# Categories whose free-text labels are mapped onto the unified vocabulary.
categories_to_map = ["TopicSystem", "TopicDiscipline", "Speciality", "SubSpeciality"]
labeled_path = r"C:\Users\LEGION\Documents\GIT\ParsBench-biomedical\MedQA\US-test_labeled_with_bloom_and_tokens.jsonl"

# Collect every distinct original label per category in ONE pass over the file
# (previously the file was re-read once per category). The file is UTF-8, written
# with ensure_ascii=False, so don't rely on the platform default encoding.
terms_by_category = {category: set() for category in categories_to_map}
with open(labeled_path, "r", encoding="utf-8") as f:
    for line in f:
        if not line.strip():
            continue  # skip blank lines instead of failing in json.loads
        data = json.loads(line)
        for category in categories_to_map:
            term = data.get(f"meta_info_{category}")
            if term:  # error entries and empty labels have nothing to map
                terms_by_category[category].add(term)

# Map each original label to a unified term. Sorting gives a stable call order
# and a stable key order in map_terms.json across runs.
map_terms = {}
for category in categories_to_map:
    map_terms[category] = {
        term: get_mapping(category, term) for term in sorted(terms_by_category[category])
    }

# Save the mapping
with open("map_terms.json", "w") as f:
    json.dump(map_terms, f, indent=2)

## Change values in the JSONL file using map_terms

In [None]:
import json
from pprint import pprint

# Load the manually revised original-label -> unified-term mapping.
with open(r"C:\Users\LEGION\Documents\GIT\ParsBench-biomedical\map_terms_revised.json", "r") as f:
    map_terms = json.load(f)

# The distinct unified terms each category maps onto.
unique_terms = {category: set(mapping.values()) for category, mapping in map_terms.items()}

# Show them, alphabetically, per category.
for category, terms in unique_terms.items():
    print(f"\nUnique terms for {category}:")
    pprint(sorted(terms))

In [None]:
# Rewrite the labeled JSONL with every label replaced by its unified term.
input_file = r"C:\Users\LEGION\Documents\GIT\ParsBench-biomedical\MedQA\US-test_labeled_with_bloom_and_tokens.jsonl"
output_file = r"C:\Users\LEGION\Documents\GIT\ParsBench-biomedical\MedQA\US-test_labeled_with_bloom_and_tokens_unifiedterms.jsonl"

# The input is UTF-8 written with ensure_ascii=False; read and write explicitly as
# UTF-8 (the Windows default, cp1252, would mis-decode non-ASCII characters) and keep
# non-ASCII text unescaped in the output, consistent with the labeling step.
with open(input_file, "r", encoding="utf-8") as in_f, open(output_file, "w", encoding="utf-8") as out_f:
    for line in in_f:
        if not line.strip():
            continue  # skip blank lines instead of failing in json.loads
        data = json.loads(line)
        for category in ["TopicSystem", "TopicDiscipline", "Speciality", "SubSpeciality"]:
            key = f"meta_info_{category}"
            if key in data:
                # Labels (or whole categories) missing from the mapping are kept unchanged.
                data[key] = map_terms.get(category, {}).get(data[key], data[key])
        json.dump(data, out_f, ensure_ascii=False)
        out_f.write("\n")

# Report the file actually written (the old message named a non-existent "_updated" file).
print(f"Processing complete. Updated JSONL file saved as {output_file}")

# Translate

In [None]:
import json
import os
from typing import Dict, List
from pydantic import BaseModel
from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate
from langchain.callbacks import get_openai_callback
from tenacity import retry, stop_after_attempt, wait_random_exponential

# Environment variable that holds the OpenAI API key for this project.
OPENAI_API_KEY_ENV = "OPENAI_API_KEY_Nar_Amir_MedQA"
openai_api_key = os.getenv(OPENAI_API_KEY_ENV)
if openai_api_key is None:
    # Name the variable that is actually read; the previous message pointed at
    # OPENAI_API_KEY, which this code never consults.
    raise ValueError(f"{OPENAI_API_KEY_ENV} environment variable is not set or invalid.")

# Chat model used by translate_to_persian.
chat = ChatOpenAI(model_name="gpt-4o", openai_api_key=openai_api_key)

# Shape of one input record for translation, validated in process_translation via
# LabeledQuestionData(**json_obj): the original question plus its labels.
class LabeledQuestionData(BaseModel):
    question: str
    answer: str
    options: Dict[str, str]  # option key -> option text
    meta_info: str
    answer_idx: str
    meta_info_TopicSystem: str
    meta_info_TopicDiscipline: str
    meta_info_Speciality: str
    meta_info_SubSpeciality: str
    meta_info_BloomTaxonomy: str
    meta_info_TokenLength: Dict[str, int]  # "question_Tokenlength" / "option_<key>_Tokenlength" -> count

# Output record: the labeled question with Persian translations added next to each
# English field, plus token/cost bookkeeping.
class TranslatedLabeledQuestionData(BaseModel):
    question: str
    question_persian: str
    answer: str
    answer_persian: str
    options: Dict[str, str]
    options_persian: Dict[str, str]  # presumably same keys as `options` — confirm in process_translation
    meta_info: str
    answer_idx: str
    meta_info_TopicSystem: str
    meta_info_TopicDiscipline: str
    meta_info_Speciality: str
    meta_info_SubSpeciality: str
    meta_info_BloomTaxonomy: str
    meta_info_TokenLength: Dict[str, int]
    # Presumably running totals: process_translation resumes its cumulative counters
    # from the last saved record's total_tokens / total_cost.
    total_tokens: int
    total_cost: float

from typing import Tuple


@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def translate_to_persian(text: str) -> Tuple[str, int, float]:
    """
    Translate ``text`` from English to Persian with the module-level ``chat`` model.

    Args:
        text: English text to translate.

    Returns:
        A ``(translation, total_tokens, total_cost)`` tuple. The translation is
        stripped of surrounding whitespace; the token count and cost come from
        ``get_openai_callback`` for this single request.

    Raises:
        Errors from the request are printed and re-raised. ``retry`` re-invokes
        the function for up to 6 attempts in total, after which tenacity raises
        its ``RetryError``.
    """
    try:
        system_prompt = SystemMessagePromptTemplate.from_template(
            "You are an English to Persian translator with deep knowledge of biomedical terms."
        )
        human_prompt = HumanMessagePromptTemplate.from_template(
            "Translate the following text to Persian: {text}"
        )
        chat_prompt = ChatPromptTemplate.from_messages([system_prompt, human_prompt])

        messages = chat_prompt.format_messages(text=text)
        with get_openai_callback() as cb:
            response = chat(messages)
            token_usage = cb.total_tokens
            cost = cb.total_cost

        return response.content.strip(), token_usage, cost
    except Exception as e:
        print(f"Error during translation: {e}")
        raise

def process_translation(input_path: str, output_path: str, N: int, continue_translation: bool = False, checkpoint_interval: int = 5):
    """Translate up to ``N`` labeled questions to Persian, appending results to a JSONL file.

    Each processed line of ``input_path`` is validated as ``LabeledQuestionData``.
    Its question, answer and every option are translated with
    ``translate_to_persian``, and the result is stored as a
    ``TranslatedLabeledQuestionData`` record. ``total_tokens`` and ``total_cost``
    on a record are running totals for the whole output file, not per-item figures.

    Args:
        input_path: JSONL file of labeled questions (one JSON object per line).
        output_path: JSONL file that translated records are appended to. Its
            directory is created if needed.
        N: Maximum number of input lines to process in this run.
        continue_translation: If True and ``output_path`` exists, resume at
            input line ``len(existing records)`` and carry on the running
            token/cost totals from the last existing record.
        checkpoint_interval: Newly translated records are buffered and appended
            to ``output_path`` whenever this many are pending, plus once at the end.

    NOTE(review): the resume position is the number of records already
    written. An input line that was skipped because of an error therefore
    shifts later runs back by one: an already-translated line is translated
    again, and the failed line is never retried. Also, with
    ``continue_translation=False`` an existing ``output_path`` is appended to,
    not overwritten.
    """
    translated_data = []
    pending = []  # translated records not yet written to output_path
    cumulative_tokens = 0
    cumulative_cost = 0.0
    start_index = 0

    # Ensure the output directory exists. dirname is '' for a bare file name,
    # and os.makedirs('') would raise.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # If continuing translation, read existing data and set the start index
    if continue_translation and os.path.exists(output_path):
        with open(output_path, 'r', encoding='utf-8') as file:
            # Skip blank lines (e.g. a trailing newline) instead of failing to parse them.
            translated_data = [json.loads(line) for line in file if line.strip()]
        start_index = len(translated_data)
        if translated_data:
            # Totals are cumulative, so the last record holds the running sum.
            cumulative_tokens = translated_data[-1]['total_tokens']
            cumulative_cost = translated_data[-1]['total_cost']
        print(f"Continuing translation from index {start_index}")
    else:
        print("Starting new translation")

    def save_checkpoint(data, mode='a'):
        """Write ``data`` to ``output_path`` as JSONL lines (appending by default)."""
        with open(output_path, mode, encoding='utf-8') as file:
            for item in data:
                json.dump(item, file, ensure_ascii=False)
                file.write('\n')
        print(f"Checkpoint saved. Total items: {len(translated_data)}")

    with open(input_path, 'r', encoding='utf-8') as file:
        for i, line in enumerate(file):
            if i < start_index:
                continue
            if i >= start_index + N:
                break
            try:
                json_obj = json.loads(line.strip())
                labeled_data = LabeledQuestionData(**json_obj)
                print(f"Started Translation of #{i+1}: {labeled_data}")

                # Translate the question, answer, and options to Persian
                translated_question, tokens_q, cost_q = translate_to_persian(labeled_data.question)
                cumulative_tokens += tokens_q
                cumulative_cost += cost_q

                translated_answer, tokens_a, cost_a = translate_to_persian(labeled_data.answer)
                cumulative_tokens += tokens_a
                cumulative_cost += cost_a

                translated_options = {}
                for key, value in labeled_data.options.items():
                    translated_opt, tokens_opt, cost_opt = translate_to_persian(value)
                    translated_options[key] = translated_opt
                    cumulative_tokens += tokens_opt
                    cumulative_cost += cost_opt

                translated_labeled_data = TranslatedLabeledQuestionData(
                    question=labeled_data.question,
                    question_persian=translated_question,
                    answer=labeled_data.answer,
                    answer_persian=translated_answer,
                    options=labeled_data.options,
                    options_persian=translated_options,
                    meta_info=labeled_data.meta_info,
                    answer_idx=labeled_data.answer_idx,
                    meta_info_TopicSystem=labeled_data.meta_info_TopicSystem,
                    meta_info_TopicDiscipline=labeled_data.meta_info_TopicDiscipline,
                    meta_info_Speciality=labeled_data.meta_info_Speciality,
                    meta_info_SubSpeciality=labeled_data.meta_info_SubSpeciality,
                    meta_info_BloomTaxonomy=labeled_data.meta_info_BloomTaxonomy,
                    meta_info_TokenLength=labeled_data.meta_info_TokenLength,
                    total_tokens=cumulative_tokens,
                    total_cost=cumulative_cost
                )

                print(f"Finished Translation of #{i+1}: {translated_labeled_data}")
                record = translated_labeled_data.dict()
                translated_data.append(record)
                pending.append(record)

                # Flush based on how many records are actually unsaved, not on
                # the input-line index or the total list length. Those counts
                # include resumed or skipped lines and caused duplicate or
                # missing writes.
                if len(pending) >= checkpoint_interval:
                    save_checkpoint(pending)
                    pending.clear()

            except Exception as e:
                print(f"Skipping line {i + 1} due to error: {e}")

    # Save any remaining buffered data
    if pending:
        save_checkpoint(pending)
        pending.clear()

    print(f"Translation completed. Translated {len(translated_data) - start_index} new items.")
    print(f"Total items in output file: {len(translated_data)}")
    print(f"Total tokens used: {cumulative_tokens}")
    print(f"Total cost: ${cumulative_cost:.6f}")
    print(f"Data saved to {output_path}")

# Example usage: translate the unified-terms labeled test set, resuming after
# whatever is already in the output file.
labeled_file_path = r"C:\Users\LEGION\Documents\GIT\ParsBench-biomedical\MedQA\US-test_labeled_with_bloom_and_tokens_unifiedterms.jsonl"
output_path = r"C:\Users\LEGION\Documents\GIT\ParsBench-biomedical\MedQA\US-test_labeled_with_bloom_and_tokens_unifiedterms-translated.jsonl"

N = 194                      # how many input objects to translate in this run
continue_translation = True  # resume from a previous run instead of starting at line 0
checkpoint_interval = 5      # write to disk after every 5 translations

process_translation(
    input_path=labeled_file_path,
    output_path=output_path,
    N=N,
    continue_translation=continue_translation,
    checkpoint_interval=checkpoint_interval,
)

# Figure Draft (before revision)

In [None]:
import json
import os
from collections import Counter
import matplotlib.pyplot as plt
import seaborn as sns

# File path
file_path = r"C:\Users\LEGION\Documents\GIT\ParsBench-biomedical\MedQA\US-test_labeled_with_bloom_and_tokens_unifiedterms - beforeReview.jsonl"

# Output directory for the figures; create it so savefig doesn't fail on a fresh checkout.
FIGURES_DIR = "Figures"
os.makedirs(FIGURES_DIR, exist_ok=True)

# Initialize counters (one value-frequency Counter per label field)
counters = {
    "meta_info_TopicSystem": Counter(),
    "meta_info_TopicDiscipline": Counter(),
    "meta_info_Speciality": Counter(),
    "meta_info_SubSpeciality": Counter(),
    "meta_info": Counter()
}

# Process the file
with open(file_path, 'r', encoding='utf-8') as file:
    for line in file:
        if not line.strip():
            continue  # tolerate blank lines such as a trailing newline
        data = json.loads(line)
        for key, counter in counters.items():
            if key in data:
                counter[data[key]] += 1

# Print the counts
for key, counter in counters.items():
    print(f"\n{key} counts:")
    for value, count in counter.most_common():
        print(f"{value}: {count}")

# Set up a color palette
colors = sns.color_palette("husl", 11)  # 11 colors for 10 top items + 'Other'

# Create enhanced donut charts for each category
for category, counts in counters.items():
    if not counts:
        # zip(*[]) below would fail to unpack; nothing to plot anyway.
        print(f"\nSkipping {category}: no values found.")
        continue

    # Sort items by count (descending) and keep the top 10; fold the rest into 'Other'
    items = counts.most_common()
    top_items = items[:10]
    other = sum(count for _, count in items[10:])

    if other > 0:
        top_items.append(('Other', other))

    labels, sizes = zip(*top_items)

    fig, ax = plt.subplots(figsize=(14, 10))

    # Create the pie chart
    wedges, texts, autotexts = ax.pie(sizes, labels=labels, autopct='%1.1f%%',
                                      pctdistance=0.85, wedgeprops=dict(width=0.5),
                                      textprops={'fontsize': 14},
                                      colors=colors[:len(top_items)])  # only as many colors as slices

    # Create a circle at the center to turn it into a donut chart
    centre_circle = plt.Circle((0, 0), 0.70, fc='white')
    ax.add_artist(centre_circle)

    # Equal aspect ratio ensures that pie is drawn as a circle
    ax.axis('equal')

    ax.set_title(f"Distribution of Terms in {category}", fontsize=20, fontweight='bold')

    # Legend to the right of the chart
    ax.legend(wedges, labels,
              title="Terms",
              loc="center left",
              bbox_to_anchor=(1, 0, 0.5, 1),
              fontsize=12)

    # Adjust layout and save
    plt.tight_layout()
    plt.savefig(os.path.join(FIGURES_DIR, f"pie_chart_{category}.png"), dpi=300, bbox_inches='tight')
    plt.close(fig)

print(f"\nEnhanced pie charts have been saved as PNG files in the '{FIGURES_DIR}' directory.")

In [None]:
import json
import os
from collections import Counter, defaultdict
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

# Count (TopicSystem, TopicDiscipline) pairs. `heatmap_data` was previously
# used without being defined. It is built from `file_path`, the same file the
# pie-chart cell above used for `counters`, so the cell values match the
# row/column labels taken from those counters.
heatmap_data = defaultdict(Counter)  # topic system -> Counter of topic disciplines
with open(file_path, 'r', encoding='utf-8') as file:
    for line in file:
        if not line.strip():
            continue
        record = json.loads(line)
        if 'meta_info_TopicSystem' in record and 'meta_info_TopicDiscipline' in record:
            heatmap_data[record['meta_info_TopicSystem']][record['meta_info_TopicDiscipline']] += 1

# Create the heatmap
topic_systems = list(counters['meta_info_TopicSystem'].keys())
topic_disciplines = list(counters['meta_info_TopicDiscipline'].keys())

heatmap_matrix = np.zeros((len(topic_disciplines), len(topic_systems)))

for i, discipline in enumerate(topic_disciplines):
    for j, system in enumerate(topic_systems):
        heatmap_matrix[i, j] = heatmap_data[system][discipline]  # Counter yields 0 for unseen pairs

# Create a DataFrame for the heatmap
df_heatmap = pd.DataFrame(heatmap_matrix, index=topic_disciplines, columns=topic_systems)

# Set up the matplotlib figure
plt.figure(figsize=(20, 16))

# Create the heatmap
sns.heatmap(df_heatmap, annot=True, fmt='g', cmap='RdYlGn', cbar_kws={'label': 'Count'},
            annot_kws={"size": 12})

plt.title('Heatmap of Topic System vs Topic Discipline', fontsize=28)
plt.xlabel('Topic System', fontsize=22)
plt.ylabel('Topic Discipline', fontsize=22)

# Rotate the x-axis labels for better readability and increase font size
plt.xticks(rotation=45, ha='right', fontsize=14)
plt.yticks(fontsize=14)

# Adjust layout and save (create the output directory if it is missing)
os.makedirs('Figures', exist_ok=True)
plt.tight_layout()
plt.savefig('Figures/system_discipline_heatmap.png', dpi=300, bbox_inches='tight')
plt.close()

print("\nHeatmap with increased font sizes has been saved as 'Figures/system_discipline_heatmap.png'.")


# Prepare Chunks

In [None]:
# Unique (mapped) values for each label category
import json
from pprint import pprint

# Load the term-mapping JSON: category -> mapping whose values are the mapped terms.
# JSON is UTF-8 by specification; without an explicit encoding Windows would
# decode with the locale code page and can garble or reject non-ASCII terms.
with open(r"C:\Users\LEGION\Documents\GIT\ParsBench-biomedical\map_terms_revised.json", "r", encoding="utf-8") as f:
    map_terms = json.load(f)

# Extract unique values for each category
unique_terms = {category: set(terms.values()) for category, terms in map_terms.items()}

# Print the unique terms for each category
for category, terms in unique_terms.items():
    print(f"\nUnique terms for {category}:")
    pprint(sorted(terms))

In [None]:
import json
import os
from collections import defaultdict
import random
import pandas as pd

# File paths
input_file = r"C:\Users\LEGION\Documents\GIT\ParsBench-biomedical\MedQA\US-test_labeled_with_bloom_and_tokens_unifiedterms-translated.jsonl"
output_folder = "Raw_Chunks"
summary_file = "chunk_summary.xlsx"

# Maximum number of questions written to each chunk file.
CHUNK_SIZE = 15
# Seed for shuffling the Step 1 non-clinical pool. Without it, every re-run
# overwrites Raw_Chunks with a different split. Set to None for a fresh random
# order on each run.
SHUFFLE_SEED = 42

# Subspecialities that get their own chunk series; everything else goes to "Others".
CHUNKED_SUBSPECIALITIES = {
    "Gastroenterology", "Pediatrics", "Obstetrics & Gynecology", "Endocrinology", "Neurology",
    "Infectious Diseases", "Hematology", "Cardiology", "Pulmonology", "Emergency Medicine",
    "Psychiatry", "Nephrology", "Rheumatology", "Surgery", "Internal Medicine", "Dermatology",
}

# Create output folder if it doesn't exist
os.makedirs(output_folder, exist_ok=True)

# Read and process the input file (skip blank lines such as a trailing newline)
data = []
with open(input_file, 'r', encoding='utf-8') as f:
    for line in f:
        if line.strip():
            data.append(json.loads(line))

# Initialize chunk groups
step1_nonclinical = []
subspeciality_chunks = defaultdict(list)
others = []

# Route each question: Step 1 non-clinical first, then known subspecialities, then the rest
for item in data:
    if item['meta_info'] == 'step1' and item['meta_info_TopicDiscipline'] != 'Clinical Diagnosis and Management':
        step1_nonclinical.append(item)
    elif item['meta_info_SubSpeciality'] in CHUNKED_SUBSPECIALITIES:
        subspeciality_chunks[item['meta_info_SubSpeciality']].append(item)
    else:
        others.append(item)

# Function to save chunk as JSONL
def save_chunk(chunk, filename):
    with open(os.path.join(output_folder, filename), 'w', encoding='utf-8') as f:
        for item in chunk:
            json.dump(item, f, ensure_ascii=False)
            f.write('\n')

# Split `data` into files of at most `chunk_size` items named "<base_name>_<n>.jsonl";
# return (chunk name, item count, base name) rows for the summary sheet.
def create_chunks(data, base_name, chunk_size=CHUNK_SIZE):
    chunks = []
    for start in range(0, len(data), chunk_size):
        chunk = data[start:start + chunk_size]
        chunk_name = f"{base_name}_{len(chunks)+1}"
        save_chunk(chunk, f"{chunk_name}.jsonl")
        chunks.append((chunk_name, len(chunk), base_name))
    return chunks

# Process and save chunks
summary = []

# Step1_nonClinical chunks (reproducible order via a dedicated RNG; global random state untouched)
random.Random(SHUFFLE_SEED).shuffle(step1_nonclinical)
summary.extend(create_chunks(step1_nonclinical, "Step1_nonClinical"))

# Subspeciality chunks
for subspeciality, items in subspeciality_chunks.items():
    summary.extend(create_chunks(items, subspeciality))

# Others chunk
summary.extend(create_chunks(others, "Others"))

# Create summary DataFrame and save to Excel
summary_df = pd.DataFrame(summary, columns=['Chunk Name', 'Number of Instances', 'Applied Filter'])
summary_df.to_excel(summary_file, index=False)

print(f"Processing complete. Chunks saved in '{output_folder}' folder.")
print(f"Summary saved as '{summary_file}'.")

# Archived

In [None]:

import json
import os
import logging
from typing import Dict, List
from pydantic import BaseModel
from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate
from langchain.callbacks import get_openai_callback
from tenacity import retry, stop_after_attempt, wait_random_exponential

# Constants
NUMBER_OF_RETRY = 2  # passed to stop_after_attempt, i.e. total attempts including the first
MODEL_NAME = "gpt-4o"
SAVE_INTERVAL = 5  # rewrite the output file after every this many new translations

# Set up logging (basicConfig does nothing if the root logger already has
# handlers, e.g. from an earlier cell in the same kernel)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Set your OpenAI API key
openai_api_key = os.getenv("OPENAI_API_KEY_Nar_Amir_MedQA")
# `not` also rejects an empty string (variable set but blank).
if not openai_api_key:
    raise ValueError("OPENAI_API_KEY_Nar_Amir_MedQA environment variable is not set or is empty.")

# Initialize the OpenAI chat model using LangChain
chat = ChatOpenAI(model_name=MODEL_NAME, openai_api_key=openai_api_key)

class LabeledQuestionData(BaseModel):
    """Schema for one line of the labeled MedQA JSONL input.

    ``process_translation`` validates each input line with
    ``LabeledQuestionData(**json_obj)``; a validation failure is logged and
    recorded as an error entry in the output.
    """
    question: str  # also used as the de-duplication key when resuming
    answer: str
    options: Dict[str, str]  # option key -> option text
    meta_info: str  # NOTE(review): appears to be the exam-step tag (e.g. 'step1') — confirm
    answer_idx: str  # presumably the key in `options` of the correct answer — confirm
    meta_info_TopicSystem: str
    meta_info_TopicDiscipline: str
    meta_info_Speciality: str
    meta_info_SubSpeciality: str
    meta_info_BloomTaxonomy: str
    meta_info_TokenLength: Dict[str, int]  # presumably token counts per text field — confirm

class TranslatedLabeledQuestionData(BaseModel):
    """A labeled MedQA question together with its Persian translation.

    ``process_translation`` appends ``.dict()`` of each instance to the list it
    saves as JSONL, so the field order below is the key order of each record.
    """
    question: str
    question_persian: str  # Persian translation of `question`
    answer: str
    answer_persian: str  # Persian translation of `answer`
    options: Dict[str, str]
    options_persian: Dict[str, str]  # same keys as `options`, each value translated
    meta_info: str
    answer_idx: str
    meta_info_TopicSystem: str
    meta_info_TopicDiscipline: str
    meta_info_Speciality: str
    meta_info_SubSpeciality: str
    meta_info_BloomTaxonomy: str
    meta_info_TokenLength: Dict[str, int]
    total_tokens: int  # running token total up to and including this record, not per item
    total_cost: float  # running cost total up to and including this record, not per item

@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(NUMBER_OF_RETRY))
def translate_to_persian(text: str) -> (str, int, float):
    """Translate English *text* to Persian via the shared ``chat`` model.

    Returns ``(translation, tokens, cost)``: the stripped reply and the token
    count / cost that the OpenAI callback recorded for this one request.
    Failures are logged and re-raised so ``@retry`` (``NUMBER_OF_RETRY``
    attempts, randomized exponential backoff capped at 60 s) can retry them.
    """
    try:
        prompt = ChatPromptTemplate.from_messages([
            SystemMessagePromptTemplate.from_template(
                "You are an English to Persian translator with deep knowledge of biomedical terms."
            ),
            HumanMessagePromptTemplate.from_template(
                "Translate the following text to Persian: {text}"
            ),
        ])
        rendered = prompt.format_messages(text=text)

        with get_openai_callback() as usage:
            reply = chat(rendered)
            tokens_used, spent = usage.total_tokens, usage.total_cost

        return reply.content.strip(), tokens_used, spent
    except Exception as err:
        logging.error(f"Error during translation: {err}")
        raise

def read_existing_data(output_path: str) -> List[Dict]:
    """
    Read existing data from the output file.

    Returns one parsed JSON object per non-blank line of ``output_path``, or an
    empty list if the file does not exist. Blank lines (e.g. a trailing newline
    or a manual edit) are skipped instead of raising ``JSONDecodeError``.
    """
    if not os.path.exists(output_path):
        return []

    with open(output_path, 'r', encoding='utf-8') as file:
        return [json.loads(line) for line in file if line.strip()]

def save_translated_data(translated_data: List[Dict], output_path: str):
    """
    Save translated data to the output file.

    The whole list is rewritten on every call, one JSON object per line
    (non-ASCII kept as-is). It is written to ``<output_path>.tmp`` first and then
    moved over ``output_path`` with ``os.replace``. An interruption mid-write
    therefore leaves the previous checkpoint intact instead of a truncated file.
    """
    tmp_path = f"{output_path}.tmp"
    with open(tmp_path, 'w', encoding='utf-8') as file:
        for item in translated_data:
            json.dump(item, file, ensure_ascii=False)
            file.write('\n')
    os.replace(tmp_path, output_path)
    logging.info(f"Saved {len(translated_data)} translated items to {output_path}")

def process_translation(input_path: str, output_path: str, N: int, continue_from_existing: bool = True):
    """Translate up to ``N`` not-yet-translated questions to Persian.

    Questions are matched by their ``question`` text. With
    ``continue_from_existing`` the records already in ``output_path`` are
    loaded; their questions are skipped and their running token/cost totals
    are carried on. Successful translations are stored as
    ``TranslatedLabeledQuestionData`` dicts. A line that fails is stored as
    ``{"error": ..., "original_data": ...}`` and, because it has no
    ``question`` key, is retried on the next resumed run. The full list is
    rewritten to ``output_path`` every ``SAVE_INTERVAL`` new translations,
    after every error, and once more at the end if anything is unsaved.
    """
    # Load existing data if continuing
    translated_data = read_existing_data(output_path) if continue_from_existing else []

    # Only successful records have a 'question' key. Error entries are
    # excluded (indexing them raised KeyError before) so they get retried.
    processed_ids = {item['question'] for item in translated_data if 'question' in item}

    # Ensure the output directory exists; dirname is '' for a bare file name.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    newly_processed_count = 0
    total_processed_count = len(translated_data)
    # Each record's total_tokens / total_cost is a running total, so the
    # resume value is the largest one, not the sum over records.
    cumulative_tokens = max((item.get('total_tokens', 0) for item in translated_data), default=0)
    cumulative_cost = max((item.get('total_cost', 0.0) for item in translated_data), default=0.0)
    unsaved_changes = False

    # Read new data and process
    with open(input_path, 'r', encoding='utf-8') as file:
        for line_number, line in enumerate(file, start=1):
            if newly_processed_count >= N:
                break
            if not line.strip():
                continue  # blank lines are not records
            json_obj = None
            try:
                json_obj = json.loads(line)
                labeled_data = LabeledQuestionData(**json_obj)

                # Skip if this question has already been processed
                if labeled_data.question in processed_ids:
                    continue

                logging.info(f"Started Translation of #{total_processed_count + 1}: {labeled_data}")

                # Translate the question, answer, and options to Persian
                translated_question, tokens_q, cost_q = translate_to_persian(labeled_data.question)
                cumulative_tokens += tokens_q
                cumulative_cost += cost_q

                translated_answer, tokens_a, cost_a = translate_to_persian(labeled_data.answer)
                cumulative_tokens += tokens_a
                cumulative_cost += cost_a

                translated_options = {}
                for key, value in labeled_data.options.items():
                    translated_opt, tokens_opt, cost_opt = translate_to_persian(value)
                    translated_options[key] = translated_opt
                    cumulative_tokens += tokens_opt
                    cumulative_cost += cost_opt

                translated_labeled_data = TranslatedLabeledQuestionData(
                    question=labeled_data.question,
                    question_persian=translated_question,
                    answer=labeled_data.answer,
                    answer_persian=translated_answer,
                    options=labeled_data.options,
                    options_persian=translated_options,
                    meta_info=labeled_data.meta_info,
                    answer_idx=labeled_data.answer_idx,
                    meta_info_TopicSystem=labeled_data.meta_info_TopicSystem,
                    meta_info_TopicDiscipline=labeled_data.meta_info_TopicDiscipline,
                    meta_info_Speciality=labeled_data.meta_info_Speciality,
                    meta_info_SubSpeciality=labeled_data.meta_info_SubSpeciality,
                    meta_info_BloomTaxonomy=labeled_data.meta_info_BloomTaxonomy,
                    meta_info_TokenLength=labeled_data.meta_info_TokenLength,
                    total_tokens=cumulative_tokens,
                    total_cost=cumulative_cost
                )

                logging.info(f"Finished Translation of #{total_processed_count + 1}: {translated_labeled_data}")
                translated_data.append(translated_labeled_data.dict())
                processed_ids.add(labeled_data.question)
                newly_processed_count += 1
                total_processed_count += 1
                unsaved_changes = True

                # Save data every SAVE_INTERVAL runs or on the last run
                if newly_processed_count % SAVE_INTERVAL == 0 or newly_processed_count == N:
                    save_translated_data(translated_data, output_path)
                    unsaved_changes = False

            except json.JSONDecodeError:
                logging.error(f"Error decoding JSON on line {line_number}")
            except Exception as e:
                logging.error(f"Error processing line {line_number}: {str(e)}")

                # Include the original data in the translated_data
                error_entry = {
                    "error": str(e),
                    "original_data": json_obj
                }
                translated_data.append(error_entry)
                total_processed_count += 1

                # Save data after each error
                save_translated_data(translated_data, output_path)
                unsaved_changes = False

    # If the input ran out before hitting SAVE_INTERVAL or N, the last
    # translations would otherwise never reach disk.
    if unsaved_changes:
        save_translated_data(translated_data, output_path)

    logging.info(f"Translation completed. Newly processed {newly_processed_count} items. "
                 f"Total processed {total_processed_count} items. "
                 f"Total tokens used: {cumulative_tokens}")
    logging.info(f"Total cost: ${cumulative_cost:.6f}")
    logging.info(f"Data saved to {output_path}")

# Example usage: translate a couple of new questions from the labeled test set,
# skipping any that are already present in the output file.
input_path = r"C:\Users\LEGION\Documents\GIT\ParsBench-biomedical\MedQA\US-test_labeled_with_bloom_and_tokens.jsonl"
output_path = r"C:\Users\LEGION\Documents\GIT\ParsBench-biomedical\MedQA\US-test_translated.jsonl"

N = 2                          # number of questions to translate in this run
continue_from_existing = True  # load the existing output and resume from it

process_translation(
    input_path=input_path,
    output_path=output_path,
    N=N,
    continue_from_existing=continue_from_existing,
)

