In [12]:
!pip install -qU langchain-community langchain-text-splitters langchain-huggingface

In [13]:
!pip install -qU chromadb

# LOAD THE DOCUMENTS

In [22]:
import os
import re
from typing import List, Dict
from langchain.schema import Document
from langchain.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings

class NutritionalDataProcessor:
    def __init__(self, data_dir: str = ".", collection_name: str = "nutrition_data"):
        """
        Initialize the processor with a data directory and collection name
        """
        self.data_dir = data_dir
        self.collection_name = collection_name
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            length_function=len
        )
        self.embedding_function = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

    def load_data_from_directory(self) -> List[str]:
        """
        Load all text files from the specified directory
        """
        text_data = []
        for filename in os.listdir(self.data_dir):
            if filename.endswith('.txt') or filename.endswith('.md'):
                with open(os.path.join(self.data_dir, filename), 'r', encoding='utf-8') as file:
                    text_data.append(f"\n\n=== FILE: {filename} ===\n{file.read()}")
        print(f"Loaded {len(text_data)} files from {self.data_dir}")
        return text_data

    def _split_into_sections(self, text: str) -> Dict[str, str]:
        """
        Split text into sections based on predefined patterns
        """
        sections = {}
        section_patterns = [
            (r"Dietary Reference Intakes for Minerals", "DRI_Minerals"),
            (r"Dietary Reference Intakes for Vitamins and Macronutrients", "DRI_Vitamins_Macros"),
            (r"Estimated Average Requirements \(EAR\)", "EAR"),
            (r"Tolerable Upper Intake Levels \(TUL\)", "TUL")
        ]
        current_section = "Introduction"
        current_content = []
        lines = text.split('\n')
        for line in lines:
            section_found = False
            for pattern, section_name in section_patterns:
                if re.search(pattern, line, re.IGNORECASE):
                    if current_content:
                        sections[current_section] = '\n'.join(current_content)
                    current_section = section_name
                    current_content = [line]
                    section_found = True
                    break
            if not section_found:
                current_content.append(line)
        if current_content:
            sections[current_section] = '\n'.join(current_content)
        print(f"Detected sections: {list(sections.keys())}")
        return sections

    def _is_demographic_group(self, line: str) -> bool:
        """
        Check if a line represents a demographic group header
        """
        demographic_pattern = r"^(?:#+\s+)?(?:Infants|Children|Males|Females|Pregnancy|Lactation)(?:,)?\s*(?:\d+[-–]\d+(?:\+)?\s*(?:y|mo|Years)|>\s*\d+\s*(?:y|Years))"
        print(f"Checking line: '{line.strip()}', Match: {bool(re.match(demographic_pattern, line.strip(), re.IGNORECASE))}")
        return bool(re.match(demographic_pattern, line.strip(), re.IGNORECASE))

    def _process_dri_section(self, content: str, section_type: str) -> List[Document]:
        """
        Process a DRI section into documents with demographic metadata
        """
        documents = []
        current_group = None
        current_nutrients = []
        seen_groups = set()  # Track processed groups to avoid duplicates
        lines = content.split('\n')
        for line in lines:
            line = line.strip()
            if not line:
                continue
            if self._is_demographic_group(line):
                if current_group and current_nutrients and current_group not in seen_groups:
                    doc = self._create_demographic_document(current_group, current_nutrients, section_type)
                    documents.append(doc)
                    seen_groups.add(current_group)
                    print(f"Created document for group: {current_group}, Section: {section_type}")
                current_group = line.strip().lstrip('#').strip()
                current_nutrients = []
            elif ':' in line and any(unit in line for unit in ['mg/d', 'μg/d', 'µg/d', 'g/d', 'L/d']):
                print(f"Parsed nutrient line: {line}")
                current_nutrients.append(line)
        if current_group and current_nutrients and current_group not in seen_groups:
            doc = self._create_demographic_document(current_group, current_nutrients, section_type)
            documents.append(doc)
            seen_groups.add(current_group)
            print(f"Created document for group: {current_group}, Section: {section_type}")
        return documents

    def _extract_metadata(self, group: str, section_type: str, nutrients: List[str]) -> Dict:
        """
        Extract metadata for a document
        """
        nutrient_count = len([n for n in nutrients if ':' in n])
        return {
            "demographic_group": group,
            "section_type": section_type,
            "nutrient_count": nutrient_count,
            "document_type": "general"
        }

    def _create_demographic_document(self, group: str, nutrients: List[str], section_type: str) -> Document:
        """
        Create a LangChain Document for a demographic group
        """
        content = f"Demographic Group: {group}\n"
        content += f"Reference Type: {section_type}\n\n"
        content += "Nutritional Requirements:\n"
        for nutrient in nutrients:
            content += f"- {nutrient}\n"
        metadata = self._extract_metadata(group, section_type, nutrients)
        return Document(page_content=content, metadata=metadata)

    def parse_nutritional_data(self, texts: List[str]) -> List[Document]:
        """
        Parse nutritional data into documents
        """
        all_documents = []
        for text in texts:
            sections = self._split_into_sections(text)
            for section_name, content in sections.items():
                if section_name in ["DRI_Minerals", "DRI_Vitamins_Macros", "EAR", "TUL"]:
                    documents = self._process_dri_section(content, section_name)
                    all_documents.extend(documents)
                else:
                    chunks = self.text_splitter.split_text(content)
                    for i, chunk in enumerate(chunks):
                        metadata = {
                            "section": section_name,
                            "chunk_id": i,
                            "document_type": "general"
                        }
                        all_documents.append(Document(page_content=chunk, metadata=metadata))
        print(f"Created {len(all_documents)} documents")
        return all_documents

    def create_vectorstore(self, documents: List[Document], collection_name: str = None) -> Chroma:
        """
        Create a Chroma vector store from documents
        """
        collection_name = collection_name or self.collection_name
        vectorstore = Chroma.from_documents(
            documents=documents,
            embedding=self.embedding_function,
            collection_name=collection_name,
        )
        print(f"Created vector store with {len(documents)} documents in collection '{collection_name}'")
        return vectorstore

In [23]:
processor = NutritionalDataProcessor(data_dir=".")
text_data = processor.load_data_from_directory()
documents = processor.parse_nutritional_data(text_data)
vectorstore = processor.create_vectorstore(documents, collection_name="nutrition_data")

Loaded 1 files from .
Detected sections: ['Introduction', 'DRI_Minerals', 'DRI_Vitamins_Macros', 'EAR', 'TUL']
Checking line: '# Dietary Reference Intakes for Minerals', Match: False
Checking line: '## Infants 0-6 months', Match: True
Checking line: '- **Calcium**: 200 mg/d', Match: False
Parsed nutrient line: - **Calcium**: 200 mg/d
Checking line: '- **Chromium**: 0.2 μg/d', Match: False
Parsed nutrient line: - **Chromium**: 0.2 μg/d
Checking line: '- **Copper**: 200 μg/d', Match: False
Parsed nutrient line: - **Copper**: 200 μg/d
Checking line: '- **Fluoride**: 0.01 mg/d', Match: False
Parsed nutrient line: - **Fluoride**: 0.01 mg/d
Checking line: '- **Iodine**: 110 μg/d', Match: False
Parsed nutrient line: - **Iodine**: 110 μg/d
Checking line: '- **Iron**: 0.27 mg/d', Match: False
Parsed nutrient line: - **Iron**: 0.27 mg/d
Checking line: '- **Magnesium**: 30 mg/d', Match: False
Parsed nutrient line: - **Magnesium**: 30 mg/d
Checking line: '- **Manganese**: 0.003 mg/d', Match: False

  return forward_call(*args, **kwargs)


Created vector store with 80 documents in collection 'nutrition_data'


# MODEL IMPLEMENTATION

In [24]:
!pip install -qU langchain-perplexity

In [25]:
from langchain_perplexity import ChatPerplexity
from google.colab import userdata

pplx_api_key = userdata.get('PERPLEX_API_KEY')

llm_1 = ChatPerplexity(pplx_api_key=pplx_api_key, model="sonar", temperature=0.0)

In [26]:
from typing import List, Dict, Any
import re
from langchain.prompts import ChatPromptTemplate
from langchain.schema import Document
from langchain_perplexity import ChatPerplexity

class PerplexityRAGQuery:
    def __init__(self, vectorstore, perplexity_api_key: str):
        """
        Initialize with vector store and Perplexity API key
        """
        self.vectorstore = vectorstore
        self.client = ChatPerplexity(pplx_api_key=perplexity_api_key, model="sonar", temperature=0.0)

        self.prompt_template = ChatPromptTemplate.from_template(
            """
            You are a nutrition assistant. Use only the provided context to construct a nutrient requirement list.

            Instructions:
            1. Use ONLY the context provided, which is filtered for the specific demographic group.
            2. Output ONLY a clean list of nutrients and their values, one per line, in the format: "Nutrient Name: Value [TUL : TUL Value]"
            3. Do not include any bullet points, dashes, asterisks, or markdown formatting.
            4. Do not include any introductions, conclusions, explanations, or additional text.
            5. Do not include section headers or document references.
            6. If no relevant data is found, return "Data not found".
            7. Make sure the values do not exceed specified TULs.
            8. The lives of patients are in your hands. False information leads to serious health impairment or death.

            Context:
            {context}

            User Query:
            {question}

            Output format example:
            Calcium: 1000 mg/d [TUL : 2500 mg/d]
            Iron: 8 mg/d [TUL : 45 mg/d]
            Vitamin C: 90 mg/d [TUL : 2000 mg/d]
            """
        )

    def extract_demographic_from_query(self, query: str) -> tuple[int, str]:
        """
        Extract age and sex from the query using regex
        """
        age_pattern = r'(\d{1,3})\s*(?:year(?:s)?(?:\s|-)?old)?'
        age_match = re.search(age_pattern, query, re.IGNORECASE)
        age = int(age_match.group(1)) if age_match else None

        sex_pattern = r'\b(male|man|female|woman)\b'
        sex_match = re.search(sex_pattern, query, re.IGNORECASE)
        sex = sex_match.group(1).lower() if sex_match else None
        sex = 'male' if sex in ['man', 'male'] else 'female' if sex in ['woman', 'female'] else None

        return age, sex

    def get_target_headers(self, age: int, sex: str) -> List[str]:
        """
        Get standardized target headers for the demographic
        """
        if age < 1:
            return ["Infants 0-6 months", "Infants 7-12 months"]
        elif 1 <= age <= 3:
            return ["Children 1-3 years"]
        elif 4 <= age <= 8:
            return ["Children 4-8 years"]
        elif 9 <= age <= 13:
            return [f"{sex.capitalize()}s 9-13 years"]
        elif 14 <= age <= 18:
            return [f"{sex.capitalize()}s 14-18 years"]
        elif 19 <= age <= 30:
            headers = [f"{sex.capitalize()}s 19-30 years"]
            headers.append(f"{sex.capitalize()}s 19-70+ years")
            return headers
        elif 31 <= age <= 50:
            headers = [f"{sex.capitalize()}s 31-50 years"]
            headers.append(f"{sex.capitalize()}s 19-70+ years")
            return headers
        elif 51 <= age <= 70:
            headers = [f"{sex.capitalize()}s 51-70 years"]
            headers.append(f"{sex.capitalize()}s 19-70+ years")
            return headers
        else:  # > 70
            headers = [f"{sex.capitalize()}s > 70 years"]
            headers.append(f"{sex.capitalize()}s 19-70+ years")
            return headers

    def retrieve_relevant_documents(self, query: str, age: int = None, sex: str = None, k: int = 10) -> List[Document]:
        """
        Retrieve documents for the demographic from all relevant sections
        """
        if age is None or sex is None:
            age, sex = self.extract_demographic_from_query(query)

        if not age or not sex:
            return []

        target_headers = self.get_target_headers(age, sex)
        sections = ["DRI_Minerals", "DRI_Vitamins_Macros", "EAR", "TUL"]

        all_documents = []

        for section in sections:
            for header in target_headers:
                filters = {
                    "$and": [
                        {"demographic_group": {"$eq": header}},
                        {"section_type": {"$eq": section}}
                    ]
                }

                try:
                    docs = self.vectorstore.similarity_search(
                        query,
                        k=k,
                        filter=filters
                    )

                    validated_docs = [
                        doc for doc in docs
                        if (doc.metadata.get("demographic_group") == header and
                            doc.metadata.get("section_type") == section)
                    ]

                    all_documents.extend(validated_docs)

                    if validated_docs:
                        break

                except Exception:
                    continue

        unique_docs = []
        seen_content = set()
        for doc in all_documents:
            content_hash = hash(doc.page_content)
            if content_hash not in seen_content:
                unique_docs.append(doc)
                seen_content.add(content_hash)

        return unique_docs

    def format_context(self, documents: List[Document]) -> str:
        """
        Format retrieved documents into context string
        """
        context_parts = []
        for i, doc in enumerate(documents):
            context_part = f"Document {i+1} [{doc.metadata.get('demographic_group', 'Unknown')} - {doc.metadata.get('section_type', 'Unknown')}]\n{doc.page_content}\n"
            context_parts.append(context_part)
        return "\n".join(context_parts)

    def query_perplexity(self, prompt: str, model: str = "sonar") -> str:
        """
        Query Perplexity API with the formatted prompt
        """
        try:
            response = self.client.invoke(prompt)
            return response.content
        except Exception:
            return ""

    def answer_query(self, question: str, age: int = None, sex: str = None, model: str = "sonar") -> str:
        """
        Complete RAG pipeline: retrieve, format, and query
        """
        if age is None or sex is None:
            age, sex = self.extract_demographic_from_query(question)

        if not age or not sex:
            return f"The ideal nutritional intake for a {age or 'unknown'} year old {sex or 'unknown'} is as follows:\n"

        documents = self.retrieve_relevant_documents(question, age, sex)

        if not documents:
            return f"The ideal nutritional intake for a {age} year old {sex} is as follows:\n"

        context = self.format_context(documents)

        formatted_prompt = self.prompt_template.format(
            context=context,
            question=question
        )

        nutrient_list = self.query_perplexity(formatted_prompt, model)

        return f"The ideal nutritional intake for a {age} year old {sex} is as follows:\n{nutrient_list}"


def test_nutrition_query(processor, vectorstore, perplexity_api_key: str):
    """
    Test function with the specified query
    """
    rag_system = PerplexityRAGQuery(vectorstore, perplexity_api_key)
    age = int(input("Enter the user's age :"))
    sex = input("Enter the user's sex (male/female) :")
    test_query = f"What are the ideal nutritional requirements for a {age} year old {sex}?"
    result = rag_system.answer_query(question=test_query)
    print(result)
    return result

In [28]:
result = test_nutrition_query(processor, vectorstore, pplx_api_key)

Enter the user's age :27
Enter the user's sex (male/female) :female


  return forward_call(*args, **kwargs)


The ideal nutritional intake for a 27 year old female is as follows:
Calcium: 1000 mg/d [TUL : 2500 mg/d]  
Chromium: 25 μg/d [TUL : Data not found]  
Copper: 900 μg/d [TUL : 10000 μg/d]  
Fluoride: 3 mg/d [TUL : 10 mg/d]  
Iodine: 150 μg/d [TUL : 1100 μg/d]  
Iron: 18 mg/d [TUL : 45 mg/d]  
Magnesium: 310 mg/d [TUL : 350 mg/d]  
Manganese: 1.8 mg/d [TUL : 11 mg/d]  
Molybdenum: 45 μg/d [TUL : 2000 μg/d]  
Phosphorus: 700 mg/d [TUL : 4000 mg/d]  
Selenium: 55 μg/d [TUL : 400 μg/d]  
Zinc: 8 mg/d [TUL : 40 mg/d]  
Potassium: 2600 mg/d [TUL : Data not found]  
Sodium: 1500 mg/d [TUL : Data not found]  
Chloride: 2.3 g/d [TUL : 3.6 g/d]  
Vitamin A: 700 μg/d [TUL : 3000 μg/d]  
Vitamin C: 75 mg/d [TUL : 2000 mg/d]  
Vitamin D: 15 μg/d [TUL : 100 μg/d]  
Vitamin E: 15 mg/d [TUL : 1000 mg/d]  
Vitamin K: 90 μg/d [TUL : Data not found]  
Thiamin: 1.1 mg/d [TUL : Data not found]  
Riboflavin: 1.1 mg/d [TUL : Data not found]  
Niacin: 14 mg/d [TUL : 35 mg/d]  
Vitamin B6: 1.3 mg/d [TUL : 100 m

In [20]:
from langchain.memory import ConversationSummaryBufferMemory
from langchain.chains import ConversationalRetrievalChain, LLMChain, StuffDocumentsChain
from langchain.prompts import ChatPromptTemplate, PromptTemplate
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from langchain_huggingface import HuggingFacePipeline
from langchain_core.runnables import RunnablePassthrough, RunnableLambda



medical_prompt = ChatPromptTemplate.from_template(
"""
You are a trusted medical nutrition advisor AI. Your task is to adjust the baseline nutrient intake values for a user based on their medical conditions, medications, and physical activity level.

--- Baseline Nutritional Values ---
{baseline_values}

--- User Profile ---
• Health Conditions: {conditions}
• Activity Level: {activity_level}
• Medications: {medications}

--- Instructions ---
1. For each nutrient:
   - Modify the baseline value if it's medically warranted by any of the user's conditions, medications, or activity level.
   - Justify each change with clinical reasoning and cite whether it's condition-driven, medication-related, or activity-based.

2. In cases where **multiple conditions have conflicting nutrient recommendations** (e.g., Condition X increases need for Nutrient A, while Condition Y requires lowering Nutrient A):
   - **Acknowledge the conflict explicitly.**
   - Suggest a medically reasonable compromise or prioritize based on risk (e.g., "due to higher risk of toxicity in Condition Y, we will lean toward the lower bound").
   - Flag nutrients where close monitoring is essential.

3. If any adjusted value exceeds the Tolerable Upper Limit (TUL), clearly **flag this as a clinical risk** and explain why it might still be justified (or not).

4. Do NOT adjust nutrients that have no medical justification.

--- Output Format ---
Return the final adjusted nutrient values as a structured list or table. Include:
- Nutrient name
- Adjusted value
- Reason for adjustment
- Risk flags (if any)

Be clear, concise, and medically sound.
"""
)

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/72.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/567 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/14.5G [00:00<?, ?B/s]

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/IPython/core/interactiveshell.py", line 3553, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipython-input-20-907928174.py", line 9, in <cell line: 0>
    model = AutoModelForCausalLM.from_pretrained("BioMistral/BioMistral-7B", torch_dtype="auto", device_map="auto",  offload_folder="./offload_biomistral")
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/transformers/models/auto/auto_factory.py", line 600, in from_pretrained
    return model_class.from_pretrained(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/transformers/modeling_utils.py", line 315, in _wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/trans

TypeError: object of type 'NoneType' has no len()