In [None]:
pip install -U langchain-google-genai langchain-community langchain_core langgraph pydantic

Collecting langchain-google-genai
  Downloading langchain_google_genai-2.1.6-py3-none-any.whl.metadata (7.0 kB)
Collecting langchain-community
  Downloading langchain_community-0.3.27-py3-none-any.whl.metadata (2.9 kB)
Collecting langchain_core
  Downloading langchain_core-0.3.68-py3-none-any.whl.metadata (5.8 kB)
Collecting langgraph
  Downloading langgraph-0.5.1-py3-none-any.whl.metadata (6.7 kB)
Collecting filetype<2.0.0,>=1.2.0 (from langchain-google-genai)
  Downloading filetype-1.2.0-py2.py3-none-any.whl.metadata (6.5 kB)
Collecting google-ai-generativelanguage<0.7.0,>=0.6.18 (from langchain-google-genai)
  Downloading google_ai_generativelanguage-0.6.18-py3-none-any.whl.metadata (9.8 kB)
Collecting dataclasses-json<0.7,>=0.5.7 (from langchain-community)
  Downloading dataclasses_json-0.6.7-py3-none-any.whl.metadata (25 kB)
Collecting pydantic-settings<3.0.0,>=2.4.0 (from langchain-community)
  Downloading pydantic_settings-2.10.1-py3-none-any.whl.metadata (3.4 kB)
Collecting htt

In [None]:
import os
import random
import json
from typing import List, Dict, Any, Optional, Literal, TypedDict
from pydantic import BaseModel, Field, conint, confloat, ValidationError
from enum import Enum
from datetime import datetime, timedelta

# LangChain specific imports
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.tools import tool
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_retrieval_chain
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
from langchain_core.documents import Document # Moved here for consistency

# LangGraph imports
from langgraph.graph import StateGraph, END

# Ensure your GOOGLE_API_KEY is correctly set.
GOOGLE_API_KEY = "AIzaSyC4q-ry8oPTjBHDP1suYrtB2PX52MXREwg"

LANGCHAIN_API_KEY = "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" # Optional, for Langsmith tracing
LANGCHAIN_TRACING_V2 = "true" # Enable Langsmith for tracing!
LANGCHAIN_PROJECT = "Strategic Portfolio AI"

# --- 1. Enums and Pydantic Models for Data and Input/Output ---

# ENUMS for fixed, categorical choices (UI-driven)
class AssetType(str, Enum):
    MOLECULE = "Molecule"
    PLATFORM = "Platform"
    DEVICE = "Device"

class StrategicObjectiveEnum(str, Enum):
    PRUNING = "Pruning the Portfolio"
    DIVERSIFICATION = "Diversification"
    FILLING_PIPELINE = "Filling the Pipeline"
    AUGMENTING_PIPELINE = "Augmenting the Pipeline / New Indication Space"
    CUSTOM_QUERY = "Custom Query"

# Pydantic Models for Data (flexible string types for dynamic fields)
class Molecule(BaseModel):
    id: str
    name: str
    ndc_code: Optional[str] = None
    asset_type: AssetType # Still an Enum
    development_stage: str # Dynamic string
    therapeutic_area: str # Dynamic string
    indication: str
    mechanism_of_action: str # Dynamic string
    route_of_administration: str # Dynamic string
    modality: str # Dynamic string
    patent_expiry_year: int
    current_roi: Optional[float] = None
    projected_peak_sales_M: Optional[confloat(ge=0)] = None
    internal_risk_score: confloat(ge=0, le=1)
    efficacy_profile: str
    safety_profile: str
    is_for_sale: bool = False
    company_id: str
    company_name: str

class Company(BaseModel):
    id: str
    name: str
    type: str # Dynamic string
    development_stage: str # Dynamic string
    headquarters: str
    financial_status: str # Dynamic string
    partner_status: Optional[str] = None
    territory: str

class ExternalTrialData(BaseModel):
    trial_id: str
    molecule_id: str
    indication: str
    development_stage: str # Dynamic string
    results_summary: str
    safety_profile: str
    efficacy_profile: str
    target_biology: str
    pathway_overlap: List[str] = Field(default_factory=list)
    patient_population_characteristics: str

class MarketBenchmark(BaseModel):
    therapeutic_area: str # Dynamic string
    development_stage: str # Dynamic string
    avg_roi: float
    avg_peak_sales_M: confloat(ge=0)
    avg_time_in_stage_months: conint(gt=0)
    success_rate: confloat(ge=0, le=1)

# Frontend Input Model (Adapted to new data types for filters)
class AssetPhenotypeFilters(BaseModel):
    stages_of_development: Optional[List[str]] = None
    therapeutic_area: Optional[List[str]] = None
    indication: Optional[List[str]] = None
    mechanism_of_action: Optional[List[str]] = None
    route_of_administration: Optional[List[str]] = None
    modality: Optional[List[str]] = None
    patent_expiry_yr: Optional[conint(gt=1900, le=2100)] = None

class CompanyDetailsFilters(BaseModel):
    company_type: Optional[List[str]] = None
    development_stage: Optional[List[str]] = None
    headquarters: Optional[List[str]] = None
    financial_status: Optional[List[str]] = None
    partner_status: Optional[List[str]] = None
    territory: Optional[List[str]] = None

class PeakSalesFilters(BaseModel):
    one_yr_sales_potential_M: Optional[confloat(ge=0)] = None
    five_yr_sales_potential_M: Optional[confloat(ge=0)] = None
    peak_sales_M: Optional[confloat(ge=0)] = None

class FrontendFilters(BaseModel):
    for_sale: Literal["all", "for_sale", "for_purchase", "sold"] = "all"
    asset_type: Optional[AssetType] = None
    deal_value_min: Optional[confloat(ge=0)] = None
    deal_value_max: Optional[confloat(ge=0)] = None
    asset_phenotype: Optional[AssetPhenotypeFilters] = Field(default_factory=AssetPhenotypeFilters)
    company_details: Optional[CompanyDetailsFilters] = Field(default_factory=CompanyDetailsFilters)
    peak_sales: Optional[PeakSalesFilters] = Field(default_factory=PeakSalesFilters)
    search_query: Optional[str] = None

class FrontendInput(BaseModel):
    filters: FrontendFilters = Field(default_factory=FrontendFilters)
    strategic_objective: StrategicObjectiveEnum
    custom_query_text: Optional[str] = None
    current_company_id: str # This ID comes with the user input


# Output Structures for Recommendations
class PruningRecommendation(BaseModel):
    action_type: Literal["Deprioritize Asset"]
    molecule_name: str
    molecule_id: str
    justification: str
    reason_criteria: List[str]
    risk_score: confloat(ge=0, le=1)
    opportunity_cost_estimate: str
    impact_on_portfolio: Dict[str, Any]

class DiversificationRecommendation(BaseModel):
    action_type: Literal["Acquire Asset", "Invest in Research Area"]
    molecule_name: Optional[str] = None
    molecule_id: Optional[str] = None
    reason_for_diversification: str
    strategic_fit_score: confloat(ge=0, le=1)
    target_disease_area: str # Dynamic string
    proposed_moa: str # Dynamic string
    proposed_modality: str # Dynamic string

class FillingPipelineRecommendation(BaseModel):
    action_type: Literal["Acquire Asset", "Initiate Internal Project"]
    molecule_name: str
    molecule_id: str
    reason_for_suggestion: str
    suggested_role: str
    development_stage_fit: str # Dynamic string

class AugmentingPipelineRecommendation(BaseModel):
    action_type: Literal["Initiate New Indication Trial"]
    molecule_name: str
    molecule_id: str
    new_indication: str
    justification: str
    evidence_strength: Literal["Low", "Moderate", "High", "Very High"]
    market_potential_estimate: str

class StrategicOutput(BaseModel):
    strategic_objective_addressed: StrategicObjectiveEnum
    recommendations: List[Any] # Can be a list of any of the above recommendation types
    strategic_summary: str
    overall_impact_on_portfolio: Dict[str, Any]
    recommended_portfolio_adjustments: Dict[str, Any]
    suggested_ideal_portfolio_characteristics: Dict[str, Any]

# --- Initialize Global LLM and Embeddings ---
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0.2, google_api_key=GOOGLE_API_KEY)
embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001", google_api_key=GOOGLE_API_KEY)


# --- 2. Data Management Layer ---

class DataManager:
    def __init__(self, data_file_path: str, embeddings_model: GoogleGenerativeAIEmbeddings):
        self.data_file_path = data_file_path
        self.embeddings = embeddings_model
        self._load_all_data() # Load all data once at initialization

        # Initialize these as None; they will be populated when set_current_company is called
        self.internal_portfolio_vs: Optional[Chroma] = None
        self.competitive_landscape_vs: Optional[Chroma] = None
        self.external_molecule_database_vs: Optional[Chroma] = None
        self.external_trial_data_vs: Optional[Chroma] = None
        self.literature_vs: Optional[Chroma] = None

        self._initialize_static_vector_stores()


    def _initialize_static_vector_stores(self):
        """Initializes vector stores that do not depend on `your_company_id`."""
        # Initialize ChromaDB vector stores that don't depend on self.your_company_id
        # The external_molecule_database_vs and external_trial_data_vs can be initialized with all data
        self.external_molecule_database_vs = Chroma.from_documents(
            documents=[doc for m in self.all_molecules.values() for doc in self._molecule_to_documents(m)] if self.all_molecules else [],
            embedding=self.embeddings,
            collection_name="external_molecule_database_v3"
        )
        self.external_trial_data_vs = Chroma.from_documents(
            documents=[doc for t in self.external_trial_data.values() for doc in self._trial_to_documents(t)] if self.external_trial_data else [],
            embedding=self.embeddings,
            collection_name="external_trial_data_v3"
        )
        self.literature_vs = Chroma.from_texts(
            texts=[s["text"] for s in self.literature_snippets] if self.literature_snippets else [],
            metadatas=[{"source": s["source"]} for s in self.literature_snippets] if self.literature_snippets else [],
            embedding=self.embeddings,
            collection_name="literature_snippets_v3"
        )


    def _load_all_data(self):
        """Loads and parses all data from the single JSON file."""
        try:
            with open(self.data_file_path, 'r') as f:
                data = json.load(f)

            self.all_companies: Dict[str, Company] = {c["id"]: Company(**c) for c in data.get("companies", [])}
            self.all_molecules: Dict[str, Molecule] = {m["id"]: Molecule(**m) for m in data.get("molecules", [])}
            self.external_trial_data: Dict[str, ExternalTrialData] = {t["trial_id"]: ExternalTrialData(**t) for t in data.get("external_trial_data", [])}
            self.market_benchmarks: List[MarketBenchmark] = [MarketBenchmark(**bm) for bm in data.get("market_benchmarks", [])]
            self.literature_snippets: List[Dict[str,str]] = data.get("literature_snippets", [])

            print(f"Loaded {len(self.all_companies)} companies.")
            print(f"Loaded {len(self.all_molecules)} molecules.")
            print(f"Loaded {len(self.external_trial_data)} external trial data entries.")
            print(f"Loaded {len(self.market_benchmarks)} market benchmarks.")
            print(f"Loaded {len(self.literature_snippets)} literature snippets.")

            # These will be set dynamically based on user input
            self.your_company_id: Optional[str] = None
            self.your_company_profile: Optional[Company] = None
            self.internal_molecules: Dict[str, Molecule] = {}

        except FileNotFoundError:
            raise FileNotFoundError(f"Data file not found at: {self.data_file_path}")
        except ValidationError as e:
            print(f"Pydantic validation error when loading data: {e.errors()}")
            raise
        except json.JSONDecodeError:
            raise ValueError(f"Error decoding JSON from {self.data_file_path}. Ensure it's valid JSON.")
        except KeyError as e:
            raise KeyError(f"Missing expected key in JSON data: {e}. Check your JSON structure.")

    def set_current_company(self, company_id: str):
        """Sets the current company for the session and initializes company-specific vector stores."""
        self.your_company_id = company_id
        self.your_company_profile = self.all_companies.get(company_id)
        self.internal_molecules = {
            mol.id: mol for mol in self.all_molecules.values() if mol.company_id == company_id
        }
        if not self.your_company_profile:
            print(f"Warning: Current company ID '{company_id}' not found in loaded data. Internal portfolio will be empty.")

        # Initialize internal_portfolio_vs and competitive_landscape_vs here
        self.internal_portfolio_vs = Chroma.from_documents(
            documents=[doc for m in self.internal_molecules.values() for doc in self._molecule_to_documents(m)] if self.internal_molecules else [],
            embedding=self.embeddings,
            collection_name="internal_portfolio_v3"
        )
        self.competitive_landscape_vs = Chroma.from_documents(
            documents=[doc for m in self.all_molecules.values() if m.company_id != self.your_company_id for doc in self._molecule_to_documents(m)] if self.all_molecules else [],
            embedding=self.embeddings,
            collection_name="competitive_landscape_v3"
        )


    def _molecule_to_documents(self, molecule: Molecule) -> List[Document]:
        """Converts a Molecule Pydantic model to LangChain Document(s)."""
        content = (
            f"Molecule Name: {molecule.name}\n"
            f"ID: {molecule.id}\n"
            f"Type: {molecule.asset_type.value}\n"
            f"Stage: {molecule.development_stage}\n"
            f"Therapeutic Area: {molecule.therapeutic_area}\n"
            f"Indication: {molecule.indication}\n"
            f"Mechanism of Action: {molecule.mechanism_of_action}\n"
            f"Modality: {molecule.modality}\n"
            f"Patent Expiry: {molecule.patent_expiry_year}\n"
            f"Projected Peak Sales: ${molecule.projected_peak_sales_M or 'N/A'}M\n"
            f"Internal Risk Score: {molecule.internal_risk_score}\n"
            f"Efficacy: {molecule.efficacy_profile}\n"
            f"Safety: {molecule.safety_profile}\n"
            f"For Sale: {molecule.is_for_sale}\n"
            f"Company: {molecule.company_name} ({molecule.company_id})"
        )
        metadata = molecule.model_dump()
        metadata["asset_type"] = metadata["asset_type"].value # Ensure enum is string
        return [Document(page_content=content, metadata=metadata)]

    def _trial_to_documents(self, trial: ExternalTrialData) -> List[Document]:
        """Converts an ExternalTrialData Pydantic model to LangChain Document(s)."""
        content = (
            f"Trial ID: {trial.trial_id}\n"
            f"Molecule ID: {trial.molecule_id}\n"
            f"Indication: {trial.indication}\n"
            f"Stage: {trial.development_stage}\n"
            f"Results: {trial.results_summary}\n"
            f"Safety: {trial.safety_profile}\n"
            f"Efficacy: {trial.efficacy_profile}\n"
            f"Target Biology: {trial.target_biology}\n"
            f"Pathway Overlap: {', '.join(trial.pathway_overlap)}\n" # FIX: Join list into string
            f"Patient Population: {trial.patient_population_characteristics}"
        )
        metadata = trial.model_dump()
        # FIX: Convert pathway_overlap list to a comma-separated string for metadata
        metadata["pathway_overlap"] = ", ".join(trial.pathway_overlap)
        return [Document(page_content=content, metadata=metadata)]

    def get_molecule_by_id(self, molecule_id: str) -> Optional[Molecule]:
        return self.all_molecules.get(molecule_id)

    def get_company_by_id(self, company_id: str) -> Optional[Company]:
        return self.all_companies.get(company_id)

    def get_market_benchmark(self, ta: str, ds: str) -> Optional[MarketBenchmark]:
        for bm in self.market_benchmarks:
            if bm.therapeutic_area.lower() == ta.lower() and bm.development_stage.lower() == ds.lower():
                return bm
        return None

    def get_internal_portfolio(self) -> List[Molecule]:
        return list(self.internal_molecules.values())

    def get_all_molecules(self) -> List[Molecule]:
        return list(self.all_molecules.values())

    def get_filtered_molecules(self, molecules: List[Molecule], filters: FrontendFilters) -> List[Molecule]:
        """
        Splitting this function into several smaller parts:
        - _apply_asset_filters
        - _apply_deal_value_filters
        - _apply_phenotype_filters
        - _apply_company_filters
        """
        filtered_list = []
        for mol in molecules:
            # Apply 'for_sale' filter
            if filters.for_sale == "for_sale" and not mol.is_for_sale:
                continue
            if filters.for_sale == "for_purchase" and (not mol.is_for_sale or mol.company_id == self.your_company_id):
                continue
            # "sold" is a historical state, not a current filter on available assets

            # Apply asset_type filter
            if filters.asset_type and mol.asset_type != filters.asset_type:
                continue

            # Apply deal value filters (based on projected peak sales)
            if filters.deal_value_min is not None and (mol.projected_peak_sales_M is None or mol.projected_peak_sales_M < filters.deal_value_min):
                continue
            if filters.deal_value_max is not None and (mol.projected_peak_sales_M is None or mol.projected_peak_sales_M > filters.deal_value_max):
                continue
            if filters.peak_sales:
                if filters.peak_sales.peak_sales_M is not None and (mol.projected_peak_sales_M is None or mol.projected_peak_sales_M < filters.peak_sales.peak_sales_M):
                    continue
                # 1yr/5yr sales are not in Molecule model, skip for now.

            # Apply AssetPhenotypeFilters
            if filters.asset_phenotype:
                if filters.asset_phenotype.stages_of_development and mol.development_stage not in filters.asset_phenotype.stages_of_development:
                    continue
                if filters.asset_phenotype.therapeutic_area and mol.therapeutic_area not in filters.asset_phenotype.therapeutic_area:
                    continue
                if filters.asset_phenotype.indication and mol.indication not in filters.asset_phenotype.indication:
                    continue
                if filters.asset_phenotype.mechanism_of_action and mol.mechanism_of_action not in filters.asset_phenotype.mechanism_of_action:
                    continue
                if filters.asset_phenotype.route_of_administration and mol.route_of_administration not in filters.asset_phenotype.route_of_administration:
                    continue
                if filters.asset_phenotype.modality and mol.modality not in filters.asset_phenotype.modality:
                    continue
                if filters.asset_phenotype.patent_expiry_yr is not None and mol.patent_expiry_year < filters.asset_phenotype.patent_expiry_yr:
                    continue

            # Apply CompanyDetailsFilters (requires looking up company info)
            if filters.company_details:
                company_of_mol = self.get_company_by_id(mol.company_id)
                if company_of_mol:
                    if filters.company_details.company_type and company_of_mol.type not in filters.company_details.company_type:
                        continue
                    if filters.company_details.development_stage and company_of_mol.development_stage not in filters.company_details.development_stage:
                        continue
                    if filters.company_details.headquarters and company_of_mol.headquarters not in filters.company_details.headquarters:
                        continue
                    if filters.company_details.financial_status and company_of_mol.financial_status not in filters.company_details.financial_status:
                        continue
                    if filters.company_details.partner_status and company_of_mol.partner_status not in filters.company_details.partner_status:
                        continue
                    if filters.company_details.territory and company_of_mol.territory not in filters.company_details.territory:
                        continue
                else: # If company data for molecule is missing or doesn't exist, it doesn't match
                    continue

            filtered_list.append(mol)
        return filtered_list

# Instantiate DataManager (assuming 'company_data.json' exists in the same directory)
data_manager = DataManager('/content/company.json', embeddings)

Loaded 6 companies.
Loaded 11 molecules.
Loaded 3 external trial data entries.
Loaded 4 market benchmarks.
Loaded 5 literature snippets.


In [None]:
# --- 3. LangChain Tools for Strategic Objectives ---
# These tools use the data_manager and LLM to perform specific analyses.

class StrategicTools:
    def __init__(self, llm: ChatGoogleGenerativeAI, data_manager: DataManager):
        self.llm = llm
        self.data_manager = data_manager

    @tool
    def analyze_pruning_portfolio(self, filters: FrontendFilters, current_company_id: str) -> List[PruningRecommendation]:
        """
        Analyzes the internal portfolio to identify assets for deprioritization or pruning.
        Considers factors like low ROI, high risk, expiring patents, competition, and strategic misalignment.
        Returns a list of PruningRecommendation objects.
        """
        self.data_manager.set_current_company(current_company_id) # Ensure current company context is set
        your_molecules = self.data_manager.get_filtered_molecules(self.data_manager.get_internal_portfolio(), filters)

        # Get competitive landscape, potentially applying the same filters for relevance
        all_other_molecules = [mol for mol in self.data_manager.get_all_molecules() if mol.company_id != current_company_id]
        competitive_molecules = self.data_manager.get_filtered_molecules(all_other_molecules, filters)

        internal_portfolio_context = "\n".join([mol.model_dump_json() for mol in your_molecules])
        competitive_context = "\n".join([mol.model_dump_json() for mol in competitive_molecules])
        market_benchmarks_context = "\n".join([bm.model_dump_json() for bm in self.data_manager.market_benchmarks])

        pruning_prompt = ChatPromptTemplate.from_messages([
            ("system", """You are an expert pharmaceutical portfolio manager. Your goal is to identify molecules in the internal portfolio that should be deprioritized or pruned.
             Consider the following criteria for pruning:
             - Low projected peak sales or current ROI.
             - High internal risk score.
             - Soon-to-expire patents (e.g., within 5 years).
             - Presence of superior competitive alternatives.
             - Misalignment with current strategic objectives (though the objective here is pruning).
             - Limited market potential or high development costs for potential return.

             Analyze the provided internal portfolio data and competitive landscape to make informed recommendations.
             Output your recommendations strictly as a JSON list of PruningRecommendation objects. Each object must conform to the PruningRecommendation Pydantic schema."""),
            ("human", f"""Current Internal Portfolio (filtered for your company {current_company_id}):
             <internal_portfolio>
             {internal_portfolio_context}
             </internal_portfolio>

             Competitive Landscape (filtered):
             <competitive_landscape>
             {competitive_context}
             </competitive_landscape>

             Market Benchmarks:
             <market_benchmarks>
             {market_benchmarks_context}
             </market_benchmarks>

             My company ID is: {current_company_id}

             Identify molecules for pruning and provide detailed justifications based on the above data.
             Output a JSON list of PruningRecommendation objects.""")
        ])

        # Using a custom parser for list of Pydantic models as structured_output doesn't directly support List[Model]
        # and we want to ensure robust parsing.
        class PruningList(BaseModel):
            recommendations: List[PruningRecommendation]

        chain = pruning_prompt | self.llm.with_structured_output(PruningList)

        try:
            raw_output = chain.invoke({}) # No explicit input dict needed for chain when using .invoke directly with prompt
            return raw_output.recommendations
        except ValidationError as e:
            print(f"Validation error in pruning tool output: {e.errors()}")
            return []
        except Exception as e:
            print(f"Error in pruning tool: {e}")
            return []


    @tool
    def analyze_diversification(self, filters: FrontendFilters, current_company_id: str) -> List[DiversificationRecommendation]:
        """
        Identifies new therapeutic areas, MOAs, or modalities for portfolio diversification.
        Considers current portfolio gaps, emerging trends, and high-potential external assets.
        Returns a list of DiversificationRecommendation objects.
        """
        self.data_manager.set_current_company(current_company_id)
        your_molecules = self.data_manager.get_internal_portfolio()

        all_external_molecules = [mol for mol in self.data_manager.get_all_molecules() if mol.company_id != current_company_id]
        external_molecules_filtered = self.data_manager.get_filtered_molecules(all_external_molecules, filters)

        literature_snippets = self.data_manager.literature_vs.similarity_search(
            filters.search_query if filters.search_query else "emerging therapeutic areas OR novel mechanisms of action OR new modalities", k=5
        )

        internal_portfolio_context = "\n".join([mol.model_dump_json() for mol in your_molecules])
        external_molecules_context = "\n".join([mol.model_dump_json() for mol in external_molecules_filtered])
        literature_context = "\n".join([doc.page_content for doc in literature_snippets])

        diversification_prompt = ChatPromptTemplate.from_messages([
            ("system", """You are an expert pharmaceutical portfolio strategist. Your objective is to recommend new areas for portfolio diversification.
             Consider the current internal portfolio, the competitive landscape (external assets), and recent scientific literature.
             Look for:
             - Gaps in the current portfolio (e.g., unaddressed therapeutic areas, MOAs, or modalities).
             - Emerging research trends or novel technologies.
             - Promising external assets that align with diversification goals.
             - Areas with high unmet medical need.

             Output your recommendations strictly as a JSON list of DiversificationRecommendation objects. Each object must conform to the DiversificationRecommendation Pydantic schema."""),
            ("human", f"""My current internal portfolio:
             <internal_portfolio>
             {internal_portfolio_context}
             </internal_portfolio>

             Relevant external assets and competitive landscape (filtered):
             <external_molecules>
             {external_molecules_context}
             </external_molecules>

             Recent scientific literature and trends:
             <literature>
             {literature_context}
             </literature>

             My company ID is: {current_company_id}

             Based on this information, suggest strategies for portfolio diversification.
             Output a JSON list of DiversificationRecommendation objects.""")
        ])

        class DiversificationList(BaseModel):
            recommendations: List[DiversificationRecommendation]

        chain = diversification_prompt | self.llm.with_structured_output(DiversificationList)

        try:
            raw_output = chain.invoke({})
            return raw_output.recommendations
        except ValidationError as e:
            print(f"Validation error in diversification tool output: {e.errors()}")
            return []
        except Exception as e:
            print(f"Error in diversification tool: {e}")
            return []


    @tool
    def analyze_filling_pipeline(self, filters: FrontendFilters, current_company_id: str) -> List[FillingPipelineRecommendation]:
        """
        Identifies potential assets or research areas to fill pipeline gaps,
        focusing on specific development stages or therapeutic areas.
        Returns a list of FillingPipelineRecommendation objects.
        """
        self.data_manager.set_current_company(current_company_id)
        your_molecules = self.data_manager.get_internal_portfolio()

        all_external_molecules = [mol for mol in self.data_manager.get_all_molecules() if mol.company_id != current_company_id]
        external_molecules_for_purchase = self.data_manager.get_filtered_molecules(all_external_molecules, filters)

        internal_portfolio_context = "\n".join([mol.model_dump_json() for mol in your_molecules])
        external_molecules_context = "\n".join([mol.model_dump_json() for mol in external_molecules_for_purchase])

        filling_prompt = ChatPromptTemplate.from_messages([
            ("system", """You are an expert pharmaceutical pipeline developer. Your task is to identify and recommend assets or research areas to fill specific gaps in the current internal pipeline.
             Focus on opportunities that align with current development stages, therapeutic areas, or modalities that require strengthening.
             Prioritize external assets that are 'for sale' and show a good strategic fit.
             Output your recommendations strictly as a JSON list of FillingPipelineRecommendation objects. Each object must conform to the FillingPipelineRecommendation Pydantic schema."""),
            ("human", f"""My current internal portfolio:
             <internal_portfolio>
             {internal_portfolio_context}
             </internal_portfolio>

             Available external molecules (filtered by user criteria and 'for_sale' status):
             <external_molecules>
             {external_molecules_context}
             </external_molecules>

             My company ID is: {current_company_id}

             Based on this, recommend assets or internal projects to fill pipeline gaps.
             Output a JSON list of FillingPipelineRecommendation objects.""")
        ])

        class FillingList(BaseModel):
            recommendations: List[FillingPipelineRecommendation]

        chain = filling_prompt | self.llm.with_structured_output(FillingList)

        try:
            raw_output = chain.invoke({})
            return raw_output.recommendations
        except ValidationError as e:
            print(f"Validation error in filling pipeline tool output: {e.errors()}")
            return []
        except Exception as e:
            print(f"Error in filling pipeline tool: {e}")
            return []


    @tool
    def analyze_augmenting_pipeline(self, filters: FrontendFilters, current_company_id: str) -> List[AugmentingPipelineRecommendation]:
        """
        Suggests new indications or expanded use for existing internal pipeline assets.
        Leverages external trial data, scientific literature, and mechanistic overlaps.
        Returns a list of AugmentingPipelineRecommendation objects.
        """
        self.data_manager.set_current_company(current_company_id)
        your_molecules = self.data_manager.get_filtered_molecules(self.data_manager.get_internal_portfolio(), filters)

        all_trials = list(self.data_manager.external_trial_data.values())
        # Filter trials relevant to the query/molecules. This is a simplified approach,
        # in reality, you might need more sophisticated trial filtering based on keywords etc.
        relevant_trials = []
        for trial in all_trials:
            # Simple keyword matching for demonstration
            if filters.search_query and filters.search_query.lower() in trial.indication.lower():
                relevant_trials.append(trial)
            # Or if the trial's molecule ID is in our filtered internal molecules
            if trial.molecule_id in [mol.id for mol in your_molecules]:
                relevant_trials.append(trial)

        literature_snippets = self.data_manager.literature_vs.similarity_search(
            filters.search_query if filters.search_query else "drug repurposing OR new indications OR mechanistic insights", k=5
        )

        internal_portfolio_context = "\n".join([mol.model_dump_json() for mol in your_molecules])
        external_trial_context = "\n".join([trial.model_dump_json() for trial in relevant_trials])
        literature_context = "\n".join([doc.page_content for doc in literature_snippets])

        augmenting_prompt = ChatPromptTemplate.from_messages([
            ("system", """You are an expert in clinical development and drug repurposing. Your goal is to identify existing internal pipeline assets that could be augmented with new indications or expanded uses.
             Consider the following:
             - Mechanistic overlap with other diseases.
             - Promising clinical trial data from similar compounds or targets.
             - Relevant scientific literature suggesting new applications.
             - Unmet medical needs in related disease areas.

             Output your recommendations strictly as a JSON list of AugmentingPipelineRecommendation objects. Each object must conform to the AugmentingPipelineRecommendation Pydantic schema."""),
            ("human", f"""My current internal portfolio (filtered for your company {current_company_id}):
             <internal_portfolio>
             {internal_portfolio_context}
             </internal_portfolio>

             Relevant external trial data:
             <external_trial_data>
             {external_trial_context}
             </external_trial_data>

             Recent scientific literature:
             <literature>
             {literature_context}
             </literature>

             My company ID is: {current_company_id}

             Suggest new indications for existing pipeline assets.
             Output a JSON list of AugmentingPipelineRecommendation objects.""")
        ])

        class AugmentingList(BaseModel):
            recommendations: List[AugmentingPipelineRecommendation]

        chain = augmenting_prompt | self.llm.with_structured_output(AugmentingList)

        try:
            raw_output = chain.invoke({})
            return raw_output.recommendations
        except ValidationError as e:
            print(f"Validation error in augmenting tool output: {e.errors()}")
            return []
        except Exception as e:
            print(f"Error in augmenting tool: {e}")
            return []

# Instantiate the tools globally accessible
strategic_tools = StrategicTools(llm, data_manager)
tools = [
    strategic_tools.analyze_pruning_portfolio,
    strategic_tools.analyze_diversification,
    strategic_tools.analyze_filling_pipeline,
    strategic_tools.analyze_augmenting_pipeline,
]


# --- 4. LangGraph State and Workflow Definition ---

class AgentState(TypedDict):
    """
    Represents the state of our graph.
    """
    messages: List[BaseMessage]
    user_input: FrontendInput
    tool_output: Optional[List[Any]] # Can be a list of any recommendation type
    strategic_output: Optional[StrategicOutput]
    next_action: Literal["call_tool", "generate_final_response", "error"]


# Define the nodes (functions) that will operate on the state

def call_strategic_tool_node(state: AgentState) -> AgentState:
    """
    Calls the appropriate strategic tool based on the user's objective.
    """
    user_input = state["user_input"]
    strategic_objective = user_input.strategic_objective
    filters = user_input.filters
    current_company_id = user_input.current_company_id

    tool_map = {
        StrategicObjectiveEnum.PRUNING: strategic_tools.analyze_pruning_portfolio,
        StrategicObjectiveEnum.DIVERSIFICATION: strategic_tools.analyze_diversification,
        StrategicObjectiveEnum.FILLING_PIPELINE: strategic_tools.analyze_filling_pipeline,
        StrategicObjectiveEnum.AUGMENTING_PIPELINE: strategic_tools.analyze_augmenting_pipeline,
        # CUSTOM_QUERY handled by decision logic below if direct tool not applicable
    }

    tool_function = tool_map.get(strategic_objective)

    if strategic_objective == StrategicObjectiveEnum.CUSTOM_QUERY:
        # For custom query, we might just use the LLM to provide a general response
        # or it could dynamically pick a tool if the query implies one.
        # For this simplified flow, let's have it generate a general response.
        print(f"Custom Query detected. Generating general response.")
        return {**state, "next_action": "generate_final_response", "tool_output": []} # No tool output for custom query for now

    if tool_function:
        print(f"Calling tool: {tool_function.__name__} for objective: {strategic_objective.value}")
        try:
            result = tool_function(filters=filters, current_company_id=current_company_id)
            print(f"Tool {tool_function.__name__} returned {len(result)} recommendations.")
            return {**state, "tool_output": result, "next_action": "generate_final_response"}
        except Exception as e:
            print(f"Error executing tool {tool_function.__name__}: {e}")
            return {**state, "next_action": "error", "messages": state["messages"] + [AIMessage(content=f"Error executing strategic analysis: {e}")]}
    else:
        print(f"No specific tool found for objective: {strategic_objective.value}.")
        # Fallback for unexpected objectives or if tool_map isn't exhaustive
        return {**state, "next_action": "error", "messages": state["messages"] + [AIMessage(content=f"Internal error: Strategic objective '{strategic_objective.value}' not mapped to a tool.")]}


def generate_final_response_node(state: AgentState) -> AgentState:
    """
    Generates the final structured StrategicOutput based on tool output or custom query.
    """
    user_input = state["user_input"]
    tool_output = state["tool_output"]
    strategic_objective = user_input.strategic_objective

    if strategic_objective == StrategicObjectiveEnum.CUSTOM_QUERY:
        # Handle custom query directly with LLM to provide an answer
        custom_query_prompt = ChatPromptTemplate.from_messages([
            ("system", """You are an expert pharmaceutical industry analyst. The user has a custom query.
             Provide a comprehensive and insightful response based on the available data, even if it's general.
             If the query is too broad or requires information not in the data, state that gracefully.
             The output should still adhere to the StrategicOutput Pydantic model, even if 'recommendations' is empty.
             Focus on the 'strategic_summary' for the main answer.
             """),
            ("human", f"""User's Custom Query: {user_input.custom_query_text or user_input.filters.search_query}
             Relevant filters applied: {user_input.filters.model_dump_json(exclude_none=True)}
             Current Company ID: {user_input.current_company_id}

             Provide a strategic analysis for this custom query.""")
        ])

        # We need to explicitly make the LLM output the StrategicOutput model
        final_response_chain = custom_query_prompt | llm.with_structured_output(StrategicOutput)

        try:
            # We don't have tool_output for custom query, so pass empty recommendations for schema adherence
            strategic_output_obj = final_response_chain.invoke({})
            # Ensure recommendations are an empty list if not explicitly generated
            if not strategic_output_obj.recommendations:
                 strategic_output_obj.recommendations = []
            print(f"Generated final strategic output for custom query.")
            return {**state, "strategic_output": strategic_output_obj, "next_action": "finished"}

        except ValidationError as e:
            print(f"Validation error for custom query response: {e.errors()}")
            return {**state, "next_action": "error", "messages": state["messages"] + [AIMessage(content=f"Error structuring custom query response: {e}")]}
        except Exception as e:
            print(f"Error generating custom query response: {e}")
            return {**state, "next_action": "error", "messages": state["messages"] + [AIMessage(content=f"Error processing custom query: {e}")]}

    # For all other strategic objectives, synthesize from tool_output
    else:
        if not tool_output:
            summary_content = f"No specific recommendations could be generated for {strategic_objective.value} based on the provided filters and data."
            recommendations_list = []
        else:
            summary_content = f"Recommendations for '{strategic_objective.value}' objective:\n"
            recommendations_list = tool_output
            for rec in tool_output:
                summary_content += f"- {rec.model_dump_json(indent=2)}\n"

        summary_prompt = ChatPromptTemplate.from_messages([
            ("system", """You are an executive assistant for a pharmaceutical company. Your task is to synthesize the strategic recommendations
             provided by the analytical tools into a comprehensive and actionable StrategicOutput report.

             Summarize the key findings, provide an overall strategic summary, detail any recommended portfolio adjustments,
             and suggest characteristics of an ideal future portfolio if relevant.

             Output the response strictly in the JSON format defined by the StrategicOutput Pydantic model.
             If no specific recommendations were generated by the tool, explain why and provide a general strategic insight."""),
            ("human", f"""User's Strategic Objective: {strategic_objective.value}
             User's Filters: {user_input.filters.model_dump_json(exclude_none=True)}

             Analysis Results from Tools (raw recommendations):
             <tool_results>
             {json.dumps([rec.model_dump() for rec in recommendations_list], indent=2)}
             </tool_results>

             Please generate the final StrategicOutput report.""")
        ])

        final_response_chain = summary_prompt | llm.with_structured_output(StrategicOutput)

        try:
            strategic_output_obj = final_response_chain.invoke({})
            # Ensure recommendations are properly set, even if the LLM sometimes puts them in summary
            strategic_output_obj.recommendations = recommendations_list
            print(f"Generated final strategic output for objective: {strategic_objective.value}")
            return {**state, "strategic_output": strategic_output_obj, "next_action": "finished"}
        except ValidationError as e:
            print(f"Validation error generating final response: {e.errors()}")
            return {**state, "next_action": "error", "messages": state["messages"] + [AIMessage(content=f"Error structuring final report: {e}")]}
        except Exception as e:
            print(f"Error generating final response: {e}")
            return {**state, "next_action": "error", "messages": state["messages"] + [AIMessage(content=f"Error generating final strategic report: {e}")]}


def handle_error_node(state: AgentState) -> AgentState:
    """Handles errors encountered during graph execution."""
    print(f"Error state reached. Messages: {state['messages']}")
    error_message = state['messages'][-1].content if state['messages'] else 'Unknown error'
    return {**state, "strategic_output": StrategicOutput(
        strategic_objective_addressed=state["user_input"].strategic_objective,
        recommendations=[],
        strategic_summary=f"An error occurred during analysis for '{state['user_input'].strategic_objective.value}': {error_message}",
        overall_impact_on_portfolio={},
        recommended_portfolio_adjustments={},
        suggested_ideal_portfolio_characteristics={}
    ), "next_action": "finished"}


# --- 5. Build the LangGraph Workflow ---

def build_strategic_pipeline_graph():
    workflow = StateGraph(AgentState)

    # Add nodes for tool calling and response generation
    workflow.add_node("call_tool", call_strategic_tool_node)
    workflow.add_node("generate_response", generate_final_response_node)
    workflow.add_node("error_handler", handle_error_node)

    # Set the entry point
    workflow.set_entry_point("call_tool")

    # Define the edges (transitions)
    workflow.add_conditional_edges(
        "call_tool",
        lambda state: state["next_action"],
        {
            "generate_final_response": "generate_response",
            "error": "error_handler"
        }
    )
    workflow.add_conditional_edges(
        "generate_response",
        lambda state: "finished" if state.get("strategic_output") else "error",
        {
            "finished": END,
            "error": "error_handler"
        }
    )
    workflow.add_edge("error_handler", END) # End if error is handled

    return workflow.compile()

# Compile the graph
strategic_pipeline = build_strategic_pipeline_graph()


# --- 6. Main Execution Function ---

async def run_strategic_analysis(user_input_json: Dict[str, Any]) -> Dict[str, Any]:
    """
    Main function to run the strategic analysis pipeline.
    Takes user input as a dictionary, validates it, and runs the LangGraph.
    """
    try:
        user_input = FrontendInput(**user_input_json)
    except ValidationError as e:
        print(f"Input validation error: {e.errors()}")
        return {
            "error": "Invalid input format",
            "details": e.errors()
        }

    # Set the current company context in the DataManager BEFORE running the pipeline
    data_manager.set_current_company(user_input.current_company_id)

    print(f"\n--- Starting Analysis for Objective: {user_input.strategic_objective.value} (Company: {user_input.current_company_id}) ---")

    initial_state = AgentState(
        messages=[HumanMessage(content=f"Analyze portfolio for: {user_input.strategic_objective.value} with filters: {user_input.filters.model_dump_json(exclude_none=True)}")],
        user_input=user_input,
        tool_output=None,
        strategic_output=None,
        next_action="call_tool"
    )

    final_state_output = None
    async for state in strategic_pipeline.astream(initial_state):
        if "__end__" in state:
            final_state_output = state["__end__"]
            break
        # Print intermediate states if debugging
        # print(f"Intermediate State: {state}")

    if final_state_output and "strategic_output" in final_state_output:
        print("--- Analysis Complete ---")
        return final_state_output["strategic_output"].model_dump()
    else:
        print("--- Analysis FAILED or No Output ---")
        return {
            "error": "Analysis failed to produce a valid strategic output.",
            "final_state": final_state_output # For debugging
        }



In [None]:
pip install chromadb

Collecting chromadb
  Downloading chromadb-1.0.15-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.0 kB)
Collecting pybase64>=1.4.1 (from chromadb)
  Downloading pybase64-1.4.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (8.4 kB)
Collecting posthog<6.0.0,>=2.4.0 (from chromadb)
  Downloading posthog-5.4.0-py3-none-any.whl.metadata (5.7 kB)
Collecting onnxruntime>=1.14.1 (from chromadb)
  Downloading onnxruntime-1.22.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (4.5 kB)
Collecting opentelemetry-api>=1.2.0 (from chromadb)
  Downloading opentelemetry_api-1.34.1-py3-none-any.whl.metadata (1.5 kB)
Collecting opentelemetry-exporter-otlp-proto-grpc>=1.2.0 (from chromadb)
  Downloading opentelemetry_exporter_otlp_proto_grpc-1.34.1-py3-none-any.whl.metadata (2.4 kB)
Collecting opentelemetry-sdk>=1.2.0 (from chromadb)
  Downloading opentelemetry_sdk-1.34.1-py3-none-any.whl.metadata (1.6 k

In [None]:
# --- 7. Example Usage ---
if __name__ == "__main__":
    async def main():
        # Example 1: Pruning the Portfolio for InnovatePharma Inc.
        pruning_input_json = {
            "filters": {
                "for_sale": "all",
                "asset_type": "Molecule",
                "asset_phenotype": {
                    "stages_of_development": ["Phase 2", "Phase 3", "Marketed"],
                    "therapeutic_area": ["Metabolic", "Oncology", "Dermatology"]
                }
            },
            "strategic_objective": "Pruning the Portfolio",
            # CORRECTED: Using the correct company ID from your JSON
            "current_company_id": "YOUR_COMPANY_001"
        }

        print("\n--- Running Pruning Analysis ---")
        try:
            pruning_output = await run_strategic_analysis(pruning_input_json)
            print("\nPruning Output:")
            # Use json.dumps to pretty print the output
            print(json.dumps(pruning_output, indent=2))
        except Exception as e:
            print(f"An error occurred during strategic analysis: {e}")

    await main() # Directly await the main function in an interactive environment


--- Running Pruning Analysis ---

--- Starting Analysis for Objective: Pruning the Portfolio (Company: YOUR_COMPANY_001) ---
An error occurred during strategic analysis: 'StructuredTool' object has no attribute '__name__'


In [None]:
# --- 7. Example Usage ---
if __name__ == "__main__":
    # import asyncio

    async def main():
        # Example 1: Pruning the Portfolio for InnovatePharma Inc.
        pruning_input_json = {
            "filters": {
                "for_sale": "all",
                "asset_type": "Molecule",
                "asset_phenotype": {
                    "stages_of_development": ["Phase 2", "Phase 3", "Marketed"],
                    "therapeutic_area": ["Metabolic", "Oncology", "Dermatology"]
                }
            },
            "strategic_objective": "Pruning the Portfolio",
            "current_company_id": "YOUR_COMPANY_001"
        }

        print("\n--- Running Pruning Analysis ---")
        pruning_output = await run_strategic_analysis(pruning_input_json)
        print("\nPruning Output:")
        print(json.dumps(pruning_output, indent=2))

        # # Example 2: Diversification for InnovatePharma Inc.
        # diversification_input_json = {
        #     "filters": {
        #         "for_sale": "for_purchase",
        #         "asset_type": "Molecule",
        #         "asset_phenotype": {
        #             "therapeutic_area": ["Rare Diseases", "Neuroscience"],
        #             "modality": ["Gene Therapy", "Cell Therapy"]
        #         },
        #         "search_query": "novel gene therapy for neurodegenerative disorders with high unmet need"
        #     },
        #     "strategic_objective": "Diversification",
        #     "current_company_id": "YOUR_COMPANY_001"
        # }
        # print("\n--- Running Diversification Analysis ---")
        # diversification_output = await run_strategic_analysis(diversification_input_json)
        # print("\nDiversification Output:")
        # print(json.dumps(diversification_output, indent=2))

        # # Example 3: Filling the Pipeline (looking for Preclinical Oncology assets) for InnovatePharma Inc.
        # filling_input_json = {
        #     "filters": {
        #         "for_sale": "for_purchase",
        #         "asset_type": "Molecule",
        #         "asset_phenotype": {
        #             "stages_of_development": ["Preclinical", "Phase 1"],
        #             "therapeutic_area": ["Oncology"]
        #         }
        #     },
        #     "strategic_objective": "Filling the Pipeline",
        #     "current_company_id": "YOUR_COMPANY_001"
        # }
        # print("\n--- Running Filling Pipeline Analysis ---")
        # filling_output = await run_strategic_analysis(filling_input_json)
        # print("\nFilling Pipeline Output:")
        # print(json.dumps(filling_output, indent=2))

        # # Example 4: Augmenting the Pipeline (looking for new indications for existing cardio assets) for InnovatePharma Inc.
        # augmenting_input_json = {
        #     "filters": {
        #         "asset_type": "Molecule",
        #         "asset_phenotype": {
        #             "therapeutic_area": ["Cardiovascular"],
        #             "stages_of_development": ["Phase 2", "Phase 3", "Marketed"]
        #         },
        #         "search_query": "new applications for enzyme replacement therapy in heart conditions"
        #     },
        #     "strategic_objective": "Augmenting the Pipeline / New Indication Space",
        #     "current_company_id": "YOUR_COMPANY_001"
        # }
        # print("\n--- Running Augmenting Pipeline Analysis ---")
        # augmenting_output = await run_strategic_analysis(augmenting_input_json)
        # print("\nAugmenting Pipeline Output:")
        # print(json.dumps(augmenting_output, indent=2))

        # # Example 5: Custom Query
        # custom_query_input_json = {
        #     "filters": {
        #         "search_query": "What are the key trends in metabolic disease drug development for small molecules?"
        #     },
        #     "strategic_objective": "Custom Query",
        #     "custom_query_text": "Provide insights on the key trends in metabolic disease drug development, focusing on small molecules and potential future targets.",
        #     "current_company_id": "YOUR_COMPANY_001"
        # }
        # print("\n--- Running Custom Query Analysis ---")
        # custom_query_output = await run_strategic_analysis(custom_query_input_json)
        # print("\nCustom Query Output:")
        # print(json.dumps(custom_query_output, indent=2))

    asyncio.run(main())

RuntimeError: asyncio.run() cannot be called from a running event loop

In [None]:
import os
import random
import json
from typing import List, Dict, Any, Optional, Literal, TypedDict
from pydantic import BaseModel, Field, conint, confloat, ValidationError
from enum import Enum
from datetime import datetime, timedelta
from typing import Union

# LangChain specific imports
from langchain_core.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate
from langchain_core.tools import tool, StructuredTool # Import StructuredTool
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_retrieval_chain
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
from langchain_core.documents import Document # Moved here for consistency

# LangGraph imports
from langgraph.graph import StateGraph, END
from langgraph.prebuilt import ToolNode # Import ToolNode from langgraph.prebuilt

# Ensure your GOOGLE_API_KEY is correctly set.
GOOGLE_API_KEY = "AIzaSyC4q-ry8oPTjBHDP1suYrtB2PX52MXREwg"

LANGCHAIN_API_KEY = "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" # Optional, for Langsmith tracing
LANGCHAIN_TRACING_V2 = "true" # Enable Langsmith for tracing!
LANGCHAIN_PROJECT = "Strategic Portfolio AI"

# --- 1. Enums and Pydantic Models for Data and Input/Output ---

# ENUMS for fixed, categorical choices (UI-driven)
class AssetType(str, Enum):
    MOLECULE = "Molecule"
    PLATFORM = "Platform"
    DEVICE = "Device"

class StrategicObjectiveEnum(str, Enum):
    PRUNING = "Pruning the Portfolio"
    DIVERSIFICATION = "Diversification"
    FILLING_PIPELINE = "Filling the Pipeline"
    AUGMENTING = "Augmenting the Pipeline / New Indication Space" # Changed from AUGMENTING_PIPELINE to AUGMENTING for brevity
    CUSTOM_QUERY = "Custom Query"

# Pydantic Models for Data (flexible string types for dynamic fields)
class Molecule(BaseModel):
    id: str
    name: str
    ndc_code: Optional[str] = None
    asset_type: AssetType
    development_stage: str
    therapeutic_area: str
    indication: str
    mechanism_of_action: str
    route_of_administration: str
    modality: str
    patent_expiry_year: int
    current_roi: Optional[float] = None
    projected_peak_sales_M: Optional[confloat(ge=0)] = None
    internal_risk_score: confloat(ge=0, le=1)
    efficacy_profile: str
    safety_profile: str
    is_for_sale: bool = False
    company_id: str
    company_name: str

class Company(BaseModel):
    id: str
    name: str
    type: str
    development_stage: str
    headquarters: str
    financial_status: str
    partner_status: Optional[str] = None
    territory: str

class ExternalTrialData(BaseModel):
    trial_id: str
    molecule_id: str
    indication: str
    development_stage: str
    results_summary: str
    safety_profile: str
    efficacy_profile: str
    target_biology: str
    pathway_overlap: List[str] = Field(default_factory=list)
    patient_population_characteristics: str

class MarketBenchmark(BaseModel):
    therapeutic_area: str
    development_stage: str
    avg_roi: float
    avg_peak_sales_M: confloat(ge=0)
    avg_time_in_stage_months: conint(gt=0)
    success_rate: confloat(ge=0, le=1)

# Frontend Input Model (Adapted to new data types for filters)
class AssetPhenotypeFilters(BaseModel):
    stages_of_development: Optional[List[str]] = None
    therapeutic_area: Optional[List[str]] = None
    indication: Optional[List[str]] = None
    mechanism_of_action: Optional[List[str]] = None
    route_of_administration: Optional[List[str]] = None
    modality: Optional[List[str]] = None
    patent_expiry_yr: Optional[conint(gt=1900, le=2100)] = None

class CompanyDetailsFilters(BaseModel):
    company_type: Optional[List[str]] = None
    development_stage: Optional[List[str]] = None
    headquarters: Optional[List[str]] = None
    financial_status: Optional[List[str]] = None
    partner_status: Optional[List[str]] = None
    territory: Optional[List[str]] = None

class PeakSalesFilters(BaseModel):
    one_yr_sales_potential_M: Optional[confloat(ge=0)] = None
    five_yr_sales_potential_M: Optional[confloat(ge=0)] = None
    peak_sales_M: Optional[confloat(ge=0)] = None

class FrontendFilters(BaseModel):
    for_sale: Literal["all", "for_sale", "for_purchase", "sold"] = "all"
    asset_type: Optional[AssetType] = None
    deal_value_min: Optional[confloat(ge=0)] = None
    deal_value_max: Optional[confloat(ge=0)] = None
    asset_phenotype: Optional[AssetPhenotypeFilters] = Field(default_factory=AssetPhenotypeFilters)
    company_details: Optional[CompanyDetailsFilters] = Field(default_factory=CompanyDetailsFilters)
    peak_sales: Optional[PeakSalesFilters] = Field(default_factory=PeakSalesFilters)
    search_query: Optional[str] = None

class FrontendInput(BaseModel):
    filters: FrontendFilters = Field(default_factory=FrontendFilters)
    strategic_objective: StrategicObjectiveEnum
    custom_query_text: Optional[str] = None
    current_company_id: str # This ID comes with the user input


# Output Structures for Recommendations
class PruningRecommendation(BaseModel):
    action_type: Literal["Deprioritize Asset"]
    molecule_name: str
    molecule_id: str
    justification: str
    reason_criteria: List[str]
    risk_score: confloat(ge=0, le=1)
    opportunity_cost_estimate: str
    impact_on_portfolio: Dict[str, Any]

class DiversificationRecommendation(BaseModel):
    action_type: Literal["Acquire Asset", "Invest in Research Area"]
    molecule_name: Optional[str] = None
    molecule_id: Optional[str] = None
    reason_for_diversification: str
    strategic_fit_score: confloat(ge=0, le=1)
    target_disease_area: str
    proposed_moa: str
    proposed_modality: str

class FillingPipelineRecommendation(BaseModel):
    action_type: Literal["Acquire Asset", "Initiate Internal Project"]
    molecule_name: str
    molecule_id: str
    reason_for_suggestion: str
    suggested_role: str
    development_stage_fit: str

class AugmentingPipelineRecommendation(BaseModel):
    action_type: Literal["Initiate New Indication Trial"]
    molecule_name: str
    molecule_id: str
    new_indication: str
    justification: str
    evidence_strength: Literal["Low", "Moderate", "High", "Very High"]
    market_potential_estimate: str

class StrategicOutput(BaseModel):
    strategic_objective_addressed: StrategicObjectiveEnum
    # *** CRITICAL CHANGE HERE ***
    recommendations: List[
        Union[
            PruningRecommendation,
            DiversificationRecommendation,
            FillingPipelineRecommendation,
            AugmentingPipelineRecommendation
        ]
    ] = Field(default_factory=list) # Added default_factory for robustness
    strategic_summary: str
    overall_impact_on_portfolio: Dict[str, Any]
    recommended_portfolio_adjustments: Dict[str, Any]
    suggested_ideal_portfolio_characteristics: Dict[str, Any]

# Generate JSON schemas for the output models
strategic_output_schema = StrategicOutput.model_json_schema()
pruning_recommendation_schema = PruningRecommendation.model_json_schema()

# --- Initialize Global LLM and Embeddings ---
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0.2, google_api_key=GOOGLE_API_KEY)
embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001", google_api_key=GOOGLE_API_KEY)


# --- 2. Data Management Layer ---

class DataManager:
    def __init__(self, data_file_path: str, embeddings_model: GoogleGenerativeAIEmbeddings):
        self.data_file_path = data_file_path
        self.embeddings = embeddings_model
        self._load_all_data() # Load all data once at initialization

        # Initialize these as None; they will be populated when set_current_company is called
        self.internal_portfolio_vs: Optional[Chroma] = None
        self.competitive_landscape_vs: Optional[Chroma] = None
        self.external_molecule_database_vs: Optional[Chroma] = None
        self.external_trial_data_vs: Optional[Chroma] = None
        self.literature_vs: Optional[Chroma] = None

        self._initialize_static_vector_stores()


    def _initialize_static_vector_stores(self):
        """Initializes vector stores that do not depend on `your_company_id`."""
        # Initialize ChromaDB vector stores that don't depend on self.your_company_id
        # The external_molecule_database_vs and external_trial_data_vs can be initialized with all data
        self.external_molecule_database_vs = Chroma.from_documents(
            documents=[doc for m in self.all_molecules.values() for doc in self._molecule_to_documents(m)] if self.all_molecules else [],
            embedding=self.embeddings,
            collection_name="external_molecule_database_v3"
        )
        self.external_trial_data_vs = Chroma.from_documents(
            documents=[doc for t in self.external_trial_data.values() for doc in self._trial_to_documents(t)] if self.external_trial_data else [],
            embedding=self.embeddings,
            collection_name="external_trial_data_v3"
        )
        self.literature_vs = Chroma.from_texts(
            texts=[s["text"] for s in self.literature_snippets] if self.literature_snippets else [],
            metadatas=[{"source": s["source"]} for s in self.literature_snippets] if self.literature_snippets else [],
            embedding=self.embeddings,
            collection_name="literature_snippets_v3"
        )


    def _load_all_data(self):
        """Loads and parses all data from the single JSON file."""
        try:
            with open(self.data_file_path, 'r') as f:
                data = json.load(f)

            self.all_companies: Dict[str, Company] = {c["id"]: Company(**c) for c in data.get("companies", [])}
            self.all_molecules: Dict[str, Molecule] = {m["id"]: Molecule(**m) for m in data.get("molecules", [])}
            self.external_trial_data: Dict[str, ExternalTrialData] = {t["trial_id"]: ExternalTrialData(**t) for t in data.get("external_trial_data", [])}
            self.market_benchmarks: List[MarketBenchmark] = [MarketBenchmark(**bm) for bm in data.get("market_benchmarks", [])]
            self.literature_snippets: List[Dict[str,str]] = data.get("literature_snippets", [])

            print(f"Loaded {len(self.all_companies)} companies.")
            print(f"Loaded {len(self.all_molecules)} molecules.")
            print(f"Loaded {len(self.external_trial_data)} external trial data entries.")
            print(f"Loaded {len(self.market_benchmarks)} market benchmarks.")
            print(f"Loaded {len(self.literature_snippets)} literature snippets.")

            # These will be set dynamically based on user input
            self.your_company_id: Optional[str] = None
            self.your_company_profile: Optional[Company] = None
            self.internal_molecules: Dict[str, Molecule] = {}

        except FileNotFoundError:
            raise FileNotFoundError(f"Data file not found at: {self.data_file_path}")
        except ValidationError as e:
            print(f"Pydantic validation error when loading data: {e.errors()}")
            raise
        except json.JSONDecodeError:
            raise ValueError(f"Error decoding JSON from {self.data_file_path}. Ensure it's valid JSON.")
        except KeyError as e:
            raise KeyError(f"Missing expected key in JSON data: {e}. Check your JSON structure.")

    def set_current_company(self, company_id: str):
        """Sets the current company for the session and initializes company-specific vector stores."""
        self.your_company_id = company_id
        self.your_company_profile = self.all_companies.get(company_id)
        self.internal_molecules = {
            mol.id: mol for mol in self.all_molecules.values() if mol.company_id == company_id
        }
        if not self.your_company_profile:
            print(f"Warning: Current company ID '{company_id}' not found in loaded data. Internal portfolio will be empty.")

        # Initialize internal_portfolio_vs and competitive_landscape_vs here
        self.internal_portfolio_vs = Chroma.from_documents(
            documents=[doc for m in self.internal_molecules.values() for doc in self._molecule_to_documents(m)] if self.internal_molecules else [],
            embedding=self.embeddings,
            collection_name="internal_portfolio_v3"
        )
        self.competitive_landscape_vs = Chroma.from_documents(
            documents=[doc for m in self.all_molecules.values() if m.company_id != self.your_company_id for doc in self._molecule_to_documents(m)] if self.all_molecules else [],
            embedding=self.embeddings,
            collection_name="competitive_landscape_v3"
        )


    def _molecule_to_documents(self, molecule: Molecule) -> List[Document]:
        """Converts a Molecule Pydantic model to LangChain Document(s)."""
        content = (
            f"Molecule Name: {molecule.name}\n"
            f"ID: {molecule.id}\n"
            f"Type: {molecule.asset_type.value}\n"
            f"Stage: {molecule.development_stage}\n"
            f"Therapeutic Area: {molecule.therapeutic_area}\n"
            f"Indication: {molecule.indication}\n"
            f"Mechanism of Action: {molecule.mechanism_of_action}\n"
            f"Modality: {molecule.modality}\n"
            f"Patent Expiry: {molecule.patent_expiry_year}\n"
            f"Projected Peak Sales: ${molecule.projected_peak_sales_M or 'N/A'}M\n"
            f"Internal Risk Score: {molecule.internal_risk_score}\n"
            f"Efficacy: {molecule.efficacy_profile}\n"
            f"Safety: {molecule.safety_profile}\n"
            f"For Sale: {molecule.is_for_sale}\n"
            f"Company: {molecule.company_name} ({molecule.company_id})"
        )
        metadata = molecule.model_dump()
        metadata["asset_type"] = metadata["asset_type"].value # Ensure enum is string
        return [Document(page_content=content, metadata=metadata)]

    def _trial_to_documents(self, trial: ExternalTrialData) -> List[Document]:
        """Converts an ExternalTrialData Pydantic model to LangChain Document(s)."""
        content = (
            f"Trial ID: {trial.trial_id}\n"
            f"Molecule ID: {trial.molecule_id}\n"
            f"Indication: {trial.indication}\n"
            f"Stage: {trial.development_stage}\n"
            f"Results: {trial.results_summary}\n"
            f"Safety: {trial.safety_profile}\n"
            f"Efficacy: {trial.efficacy_profile}\n"
            f"Target Biology: {trial.target_biology}\n"
            f"Pathway Overlap: {', '.join(trial.pathway_overlap)}\n" # FIX: Join list into string
            f"Patient Population: {trial.patient_population_characteristics}"
        )
        metadata = trial.model_dump()
        # FIX: Convert pathway_overlap list to a comma-separated string for metadata
        metadata["pathway_overlap"] = ", ".join(trial.pathway_overlap)
        return [Document(page_content=content, metadata=metadata)]

    def get_molecule_by_id(self, molecule_id: str) -> Optional[Molecule]:
        return self.all_molecules.get(molecule_id)

    def get_company_by_id(self, company_id: str) -> Optional[Company]:
        return self.all_companies.get(company_id)

    def get_market_benchmark(self, ta: str, ds: str) -> Optional[MarketBenchmark]:
        for bm in self.market_benchmarks:
            if bm.therapeutic_area.lower() == ta.lower() and bm.development_stage.lower() == ds.lower():
                return bm
        return None

    def get_internal_portfolio(self) -> List[Molecule]:
        return list(self.internal_molecules.values())

    def get_all_molecules(self) -> List[Molecule]:
        return list(self.all_molecules.values())

    def get_filtered_molecules(self, molecules: List[Molecule], filters: FrontendFilters) -> List[Molecule]:
        """
        Splitting this function into several smaller parts:
        - _apply_asset_filters
        - _apply_deal_value_filters
        - _apply_phenotype_filters
        - _apply_company_filters
        """
        filtered_list = []
        for mol in molecules:
            # Apply 'for_sale' filter
            if filters.for_sale == "for_sale" and not mol.is_for_sale:
                continue
            if filters.for_sale == "for_purchase" and (not mol.is_for_sale or mol.company_id == self.your_company_id):
                continue
            # "sold" is a historical state, not a current filter on available assets

            # Apply asset_type filter
            if filters.asset_type and mol.asset_type != filters.asset_type:
                continue

            # Apply deal value filters (based on projected peak sales)
            if filters.deal_value_min is not None and (mol.projected_peak_sales_M is None or mol.projected_peak_sales_M < filters.deal_value_min):
                continue
            if filters.deal_value_max is not None and (mol.projected_peak_sales_M is None or mol.projected_peak_sales_M > filters.deal_value_max):
                continue
            if filters.peak_sales:
                if filters.peak_sales.peak_sales_M is not None and (mol.projected_peak_sales_M is None or mol.projected_peak_sales_M < filters.peak_sales.peak_sales_M):
                    continue
                # 1yr/5yr sales are not in Molecule model, skip for now.

            # Apply AssetPhenotypeFilters
            if filters.asset_phenotype:
                if filters.asset_phenotype.stages_of_development and mol.development_stage not in filters.asset_phenotype.stages_of_development:
                    continue
                if filters.asset_phenotype.therapeutic_area and mol.therapeutic_area not in filters.asset_phenotype.therapeutic_area:
                    continue
                if filters.asset_phenotype.indication and mol.indication not in filters.asset_phenotype.indication:
                    continue
                if filters.asset_phenotype.mechanism_of_action and mol.mechanism_of_action not in filters.asset_phenotype.mechanism_of_action:
                    continue
                if filters.asset_phenotype.route_of_administration and mol.route_of_administration not in filters.asset_phenotype.route_of_administration:
                    continue
                if filters.asset_phenotype.modality and mol.modality not in filters.asset_phenotype.modality:
                    continue
                if filters.asset_phenotype.patent_expiry_yr is not None and mol.patent_expiry_year < filters.asset_phenotype.patent_expiry_yr:
                    continue

            # Apply CompanyDetailsFilters (requires looking up company info)
            if filters.company_details:
                company_of_mol = self.get_company_by_id(mol.company_id)
                if company_of_mol:
                    if filters.company_details.company_type and company_of_mol.type not in filters.company_details.company_type:
                        continue
                    if filters.company_details.development_stage and company_of_mol.development_stage not in filters.company_details.development_stage:
                        continue
                    if filters.company_details.headquarters and company_of_mol.headquarters not in filters.company_details.headquarters:
                        continue
                    if filters.company_details.financial_status and company_of_mol.financial_status not in filters.company_details.financial_status:
                        continue
                    if filters.company_details.partner_status and company_of_mol.partner_status not in filters.company_details.partner_status:
                        continue
                    if filters.company_details.territory and company_of_mol.territory not in filters.company_details.territory:
                        continue
                else: # If company data for molecule is missing or doesn't exist, it doesn't match
                    continue

            filtered_list.append(mol)
        return filtered_list

# Instantiate DataManager (assuming 'company.json' exists in the same directory)
data_manager = DataManager('/content/company.json', embeddings)


# --- 3. Tool Definitions ---

@tool
def get_company_details(company_id: str) -> Company:
    """
    Retrieves detailed information about a specific company by its ID.
    Useful for understanding the profile of a competitor or partner.
    """
    company = data_manager.get_company_by_id(company_id)
    if not company:
        raise ValueError(f"Company with ID {company_id} not found.")
    return company

@tool
def get_molecule_details(molecule_id: str) -> Molecule:
    """
    Retrieves detailed information about a specific molecule by its ID.
    Useful for understanding the characteristics of an asset in the portfolio or competitive landscape.
    """
    molecule = data_manager.get_molecule_by_id(molecule_id)
    if not molecule:
        raise ValueError(f"Molecule with ID {molecule_id} not found.")
    return molecule

@tool
def search_internal_portfolio(query: Optional[str] = None, filters: Optional[FrontendFilters] = None) -> List[Molecule]:
    """
    Searches the current company's internal molecule portfolio using a natural language query and/or structured filters.
    Provides detailed information about internal assets that match the criteria.
    If no query or filters are provided, returns all internal molecules.
    """
    if not data_manager.internal_portfolio_vs:
        return [] # No internal portfolio set up

    # If a query is provided, use semantic search
    if query:
        retriever = data_manager.internal_portfolio_vs.as_retriever(search_kwargs={"k": 5})
        docs = retriever.invoke(query)
        # Convert documents back to Molecule models
        molecules = []
        for doc in docs:
            mol_data = doc.metadata
            try:
                # Need to convert asset_type string back to AssetType Enum for Pydantic
                mol_data['asset_type'] = AssetType(mol_data['asset_type'])
                molecules.append(Molecule(**mol_data))
            except ValidationError as e:
                print(f"Validation error converting document metadata to Molecule: {e}")
                continue
        # Apply structured filters to the semantically retrieved results
        if filters:
            molecules = data_manager.get_filtered_molecules(molecules, filters)
        return molecules
    elif filters:
        # If only filters are provided, get all internal molecules and then filter
        all_internal_mols = data_manager.get_internal_portfolio()
        return data_manager.get_filtered_molecules(all_internal_mols, filters)
    else:
        # If neither query nor filters, return all internal molecules
        return data_manager.get_internal_portfolio()

@tool
def search_competitive_landscape(query: Optional[str] = None, filters: Optional[FrontendFilters] = None) -> List[Molecule]:
    """
    Searches the competitive landscape (molecules owned by other companies) using a natural language query and/or structured filters.
    Provides detailed information about competitor assets that match the criteria.
    If no query or filters are provided, returns all competitive molecules.
    """
    if not data_manager.competitive_landscape_vs:
        return [] # No competitive landscape set up (shouldn't happen if data loaded)

    if query:
        retriever = data_manager.competitive_landscape_vs.as_retriever(search_kwargs={"k": 5})
        docs = retriever.invoke(query)
        molecules = []
        for doc in docs:
            mol_data = doc.metadata
            try:
                mol_data['asset_type'] = AssetType(mol_data['asset_type'])
                molecules.append(Molecule(**mol_data))
            except ValidationError as e:
                print(f"Validation error converting document metadata to Molecule: {e}")
                continue
        if filters:
            molecules = data_manager.get_filtered_molecules(molecules, filters)
        return molecules
    elif filters:
        all_competitive_mols = [
            mol for mol in data_manager.get_all_molecules()
            if mol.company_id != data_manager.your_company_id
        ]
        return data_manager.get_filtered_molecules(all_competitive_mols, filters)
    else:
        return [mol for mol in data_manager.get_all_molecules() if mol.company_id != data_manager.your_company_id]

@tool
def search_external_trial_data(query: str, molecule_id: Optional[str] = None) -> List[ExternalTrialData]:
    """
    Searches external clinical trial data for specific molecules or general therapeutic areas.
    Useful for assessing efficacy, safety, and patient populations from external studies.
    Always requires a query. Can optionally filter by molecule_id.
    """
    if not data_manager.external_trial_data_vs:
        return []

    retriever = data_manager.external_trial_data_vs.as_retriever(search_kwargs={"k": 5})
    docs = retriever.invoke(query)
    trial_data_list = []
    for doc in docs:
        trial_data = doc.metadata
        # Reconstruct the pathway_overlap list from the comma-separated string
        if "pathway_overlap" in trial_data and isinstance(trial_data["pathway_overlap"], str):
            trial_data["pathway_overlap"] = [p.strip() for p in trial_data["pathway_overlap"].split(',')]
        try:
            # Ensure all fields expected by ExternalTrialData are present or handled
            trial_data_list.append(ExternalTrialData(**trial_data))
        except ValidationError as e:
            print(f"Validation error converting document metadata to ExternalTrialData: {e}")
            continue

    if molecule_id:
        return [td for td in trial_data_list if td.molecule_id == molecule_id]
    return trial_data_list

@tool
def get_market_benchmark_data(therapeutic_area: str, development_stage: str) -> Optional[MarketBenchmark]:
    """
    Retrieves market benchmark data (e.g., average ROI, peak sales, success rates) for a given therapeutic area and development stage.
    Useful for evaluating the potential of assets against industry standards.
    """
    return data_manager.get_market_benchmark(therapeutic_area, development_stage)

@tool
def search_literature(query: str) -> List[Dict[str, str]]:
    """
    Searches the scientific literature snippets for relevant information.
    Useful for gathering background, novel insights, or supporting evidence for strategic decisions.
    """
    if not data_manager.literature_vs:
        return []
    retriever = data_manager.literature_vs.as_retriever(search_kwargs={"k": 5})
    docs = retriever.invoke(query)
    return [{"text": doc.page_content, "source": doc.metadata.get("source", "N/A")} for doc in docs]

# Add this tool definition within your existing tool definitions
@tool
def submit_strategic_analysis(analysis: StrategicOutput) -> str:
    """
    Use this tool to submit the final strategic portfolio analysis and recommendations.
    The 'analysis' argument MUST be a complete StrategicOutput object, strictly
    adhering to its schema, including all nested recommendation types (e.g., PruningRecommendation).
    This is the final step for outputting the strategic analysis.
    """
    # This tool doesn't "do" anything other than act as a structured output gate.
    # The actual data will be captured by the caller.
    return "Strategic analysis submitted successfully."


# IMPORTANT: Collect all tools into a list to be passed to bind_tools later
tools = [
    get_company_details,
    get_molecule_details,
    search_internal_portfolio,
    search_competitive_landscape,
    search_external_trial_data,
    get_market_benchmark_data,
    submit_strategic_analysis,
    search_literature
]


# --- 4. Agent State ---
class AgentState(TypedDict):
    """
    Represents the state of our agent.
    - `input`: The original `FrontendInput` provided by the user.
    - `current_company_id`: The ID of the user's company, extracted from input.
    - `chat_history`: A list of messages between the user and the agent.
    - `strategic_objective`: The strategic objective enum.
    - `custom_query_text`: Free text query for "Custom Query" objective.
    - `filters`: Structured filters for asset and company details.
    - `scratchpad`: A string to store intermediate thoughts, tool outputs, or reflections for the LLM.
    - `recommendations`: A list to store the final recommendations.
    - `final_output`: The Pydantic model for the final structured output.
    """
    input: FrontendInput
    current_company_id: str
    chat_history: List[BaseMessage]
    strategic_objective: StrategicObjectiveEnum
    custom_query_text: Optional[str]
    filters: FrontendFilters
    scratchpad: str
    recommendations: List[Any]
    final_output: Optional[StrategicOutput]


# --- 5. Graph Nodes ---

# Define the LLM with tool binding
# You should bind the tools to the LLM that will be used for deciding which tool to call.
# This ensures the LLM knows about the tools and their schemas.
# This `llm_with_tools` will be used in the agent's decision-making process.
llm_with_tools = llm.bind_tools(tools)


def call_tool_orchestrator(state: AgentState):
    """
    This node will use the LLM to decide which tool to call based on the strategic objective
    and the current state (filters, custom_query_text, etc.).
    It then returns the tool call as a string that the ToolNode can execute.
    """
    print("---CALLING TOOL ORCHESTRATOR (LLM DECIDING)---")
    # Define a prompt for the LLM to choose and format a tool call
    # Now using placeholders for dynamic content
    tool_prompt = ChatPromptTemplate.from_messages([
        SystemMessagePromptTemplate.from_template("""You are an expert strategic analyst. Based on the user's strategic objective,
        identify the most relevant tool to call and its parameters.
        You have access to the following tools: {tool_names}.

        Strategic Objective: {objective}
        Current Company ID: {current_company_id}
        Current Company Profile: {current_company_profile}
        User Filters: {filters_str}
        Custom Query Text: {custom_query_text}

        Your response MUST be a single, valid tool call. Do NOT output any other text or explanation.
        Only the tool call. For example: `tool_name(param='value', another_param=123)`.
        If no tools are directly applicable for a direct data retrieval step,
        consider which tool might provide initial relevant context for the objective.
        For Pruning, prioritize `search_internal_portfolio`.
        For Diversification/Filling/Augmenting, consider `search_competitive_landscape`, `search_external_trial_data`, or `search_literature`.
        """),
        HumanMessagePromptTemplate.from_template("What tool should I call to get the necessary data for the strategic objective?")
    ])

    try:
        # Prepare inputs for the prompt
        prompt_inputs = {
            "tool_names": ', '.join([t.name for t in tools]),
            "objective": state['strategic_objective'].value,
            "current_company_id": state['current_company_id'],
            "current_company_profile": json.dumps(data_manager.your_company_profile.model_dump(), indent=2) if data_manager.your_company_profile else "N/A",
            "filters_str": json.dumps(state['filters'].model_dump(), indent=2),
            "custom_query_text": state['custom_query_text'] or "N/A"
        }

        # Render the ChatPromptTemplate into a list of messages
        messages_to_llm = tool_prompt.invoke(prompt_inputs).messages

        tool_selection = llm_with_tools.invoke(messages_to_llm)

        # Extract the tool call from the AI message
        if not tool_selection.tool_calls:
            print("LLM did not return any tool calls. Returning to generate analysis directly.")
            return {"scratchpad": state['scratchpad'] + "\nLLM chose not to call a tool.",
                    "chat_history": state['chat_history'] + [AIMessage(content="LLM chose not to call a tool.")],
                    "tool_to_execute": None # Indicate no tool to execute
                   }

        tool_call = tool_selection.tool_calls[0] # Assuming one tool call is made
        tool_name = tool_call['name']
        tool_args = tool_call['args']

        tool_call_string = f"{tool_name}({', '.join([f'{k}={repr(v)}' for k,v in tool_args.items()])})"
        print(f"LLM decided to call: {tool_call_string}")

        return {"scratchpad": state['scratchpad'] + f"\nLLM chose tool: {tool_call_string}",
                "chat_history": state['chat_history'] + [AIMessage(content=f"LLM chose tool: {tool_call_string}")],
                "tool_to_execute": tool_call
               }

    except Exception as e:
        error_msg = f"Error in tool orchestration: {e}"
        print(error_msg)
        return {"scratchpad": state['scratchpad'] + f"\nError: {error_msg}",
                "chat_history": state['chat_history'] + [AIMessage(content=f"Error in tool orchestration: {error_msg}")],
                "tool_to_execute": None
               }


def generate_strategic_analysis(state: AgentState):
    """
    Generates the strategic analysis and recommendations by instructing the LLM
    to call the 'submit_strategic_analysis' tool with the complete StrategicOutput.
    """
    print("---GENERATING STRATEGIC ANALYSIS---")

    # Prompt the LLM to call the 'submit_strategic_analysis' tool
    # Include clear instructions and context
    prompt_template = ChatPromptTemplate.from_messages([
        SystemMessagePromptTemplate.from_template("""You are a Strategic Portfolio Manager AI. Your final task is to compile
        the strategic analysis and recommendations into a structured format.
        You have gathered all necessary information from the scratchpad.
        Your final output MUST be a call to the `submit_strategic_analysis` tool.

        The `submit_strategic_analysis` tool expects a single argument, 'analysis',
        which must be a JSON object strictly conforming to the StrategicOutput Pydantic model.
        Pay extreme attention to the schema of `StrategicOutput` and `PruningRecommendation`
        (and other recommendation types if applicable).

        Strategic Objective: {objective}
        Current company ID: {current_company_id}
        Current company profile: {current_company_profile}
        User-provided filters: {filters_str}
        Custom query text (if applicable): {custom_query_text}

        ---StrategicOutput JSON Schema---
        {strategic_output_schema}
        ---End StrategicOutput JSON Schema---

        ---PruningRecommendation JSON Schema---
        {pruning_recommendation_schema}
        ---End PruningRecommendation JSON Schema---

        Based on the data and your analysis in the scratchpad, construct the full
        `StrategicOutput` object and call the `submit_strategic_analysis` tool.
        Ensure all required fields are populated accurately and comprehensively.
        Focus on providing actionable recommendations for the '{objective}' objective.

        Example of a partial call structure (DO NOT use verbatim, populate with actual data):
        submit_strategic_analysis(analysis={{"strategic_objective_addressed": "Pruning the Portfolio", "recommendations": [{{"action_type": "Deprioritize Asset", "molecule_name": "...", "molecule_id": "...", "justification": "...", "reason_criteria": ["...", "..."], "risk_score": 0.X, "opportunity_cost_estimate": "...", "impact_on_portfolio": {{"...": "..."}}}}], "strategic_summary": "...", "overall_impact_on_portfolio": {{"...": "..."}}, "recommended_portfolio_adjustments": {{"...": "..."}}, "suggested_ideal_portfolio_characteristics": {{"...": "..."}}}})

        Remember to provide concrete recommendations for pruning assets from YOUR_COMPANY_001.
        Identify specific molecules to deprioritize, justify why, and detail the expected impact.
        """),
        HumanMessagePromptTemplate.from_template("Here is the gathered information and my thoughts so far:\n\n{scratchpad}\n\nNow, call the `submit_strategic_analysis` tool with the complete analysis.")
    ])

    try:
        prompt_inputs = {
            "objective": state['strategic_objective'].value,
            "current_company_id": state['current_company_id'],
            "current_company_profile": json.dumps(data_manager.your_company_profile.model_dump(), indent=2) if data_manager.your_company_profile else "N/A",
            "filters_str": json.dumps(state['filters'].model_dump(), indent=2),
            "custom_query_text": state['custom_query_text'] or "N/A",
            "scratchpad": state['scratchpad'],
            "strategic_output_schema": json.dumps(strategic_output_schema, indent=2),
            "pruning_recommendation_schema": json.dumps(pruning_recommendation_schema, indent=2)
        }

        # Bind all tools so the LLM knows about `submit_strategic_analysis`
        llm_with_all_tools = llm.bind_tools(tools)

        # Invoke the chain. The LLM will respond with a tool call.
        ai_message = llm_with_all_tools.invoke(prompt_template.invoke(prompt_inputs).messages)

        if not ai_message.tool_calls:
            raise ValueError("LLM did not make a tool call for strategic analysis. This is unexpected.")

        # Find the submit_strategic_analysis tool call
        submit_call = next((tc for tc in ai_message.tool_calls if tc['name'] == 'submit_strategic_analysis'), None)

        if not submit_call:
            raise ValueError("LLM did not call the 'submit_strategic_analysis' tool.")

        # The 'analysis' argument of the tool call should be the StrategicOutput object
        # LangChain's tool binding ensures these arguments are already parsed into Python types
        # if the LLM successfully generated valid JSON arguments.
        strategic_analysis_data = submit_call['args']['analysis']

        # Manually validate if necessary, though the tool binding should handle basic type checks
        validated_output = StrategicOutput(**strategic_analysis_data)

        return {"final_output": validated_output, "chat_history": state['chat_history'] + [AIMessage(content=f"Final Analysis Submitted Successfully.")]}

    except Exception as e:
        error_msg = f"Error generating strategic analysis or processing tool call: {e}\nRaw LLM message: {ai_message if 'ai_message' in locals() else 'N/A'}"
        print(error_msg)
        return {"scratchpad": state['scratchpad'] + f"\nError: {error_msg}", "final_output": None, "chat_history": state['chat_history'] + [AIMessage(content=f"Error: {error_msg}")]}

# --- 6. LangGraph Setup ---

# Define the graph
workflow = StateGraph(AgentState)

# Define nodes
workflow.add_node("call_tool_orchestrator", call_tool_orchestrator)
# Use LangGraph's prebuilt ToolNode to execute the tools
tool_executor = ToolNode(tools) # Pass the list of StructuredTools here
workflow.add_node("call_tool_executor", tool_executor)
workflow.add_node("generate_strategic_analysis", generate_strategic_analysis)

# Define edges
workflow.set_entry_point("call_tool_orchestrator")

# If `call_tool_orchestrator` successfully determines a tool to call, go to the executor
# Add a conditional edge here from tool orchestrator
def route_tool_orchestrator(state: AgentState) -> Literal["call_tool_executor", "generate_strategic_analysis"]:
    """
    Determines the next step after tool orchestration.
    If a tool was chosen, execute it. Otherwise, proceed to analysis.
    """
    if state.get('tool_to_execute'):
        return "call_tool_executor"
    else:
        return "generate_strategic_analysis"

workflow.add_conditional_edges(
    "call_tool_orchestrator",
    route_tool_orchestrator
)


# After the tool is executed, go to the strategic analysis generation
workflow.add_edge("call_tool_executor", "generate_strategic_analysis")

# After generating the analysis, the process ends
workflow.add_edge("generate_strategic_analysis", END)

# Compile the graph
app = workflow.compile()


async def run_strategic_analysis(input_json: Dict[str, Any]) -> Dict[str, Any]:
    """
    Main function to run the strategic analysis pipeline.
    Expects input_json conforming to FrontendInput schema.
    """
    try:
        frontend_input = FrontendInput(**input_json)
    except ValidationError as e:
        raise ValueError(f"Invalid input data: {e.errors()}")

    # Initialize data manager with the current company
    data_manager.set_current_company(frontend_input.current_company_id)
    print(f"--- Starting Analysis for Objective: {frontend_input.strategic_objective.value} (Company: {frontend_input.current_company_id}) ---")

    initial_state: AgentState = {
        "input": frontend_input,
        "current_company_id": frontend_input.current_company_id,
        "chat_history": [HumanMessage(content=f"Initiating strategic analysis for objective: {frontend_input.strategic_objective.value}")],
        "strategic_objective": frontend_input.strategic_objective,
        "custom_query_text": frontend_input.custom_query_text,
        "filters": frontend_input.filters,
        "scratchpad": "Starting analysis...",
        "recommendations": [],
        "final_output": None
    }

    # Run the graph
    final_state = {}
    async for state in app.astream(initial_state):
        if "__end__" not in state:
            # print(state) # Print intermediate states for debugging if needed
            # For ToolNode output, we need to extract relevant part for scratchpad
            if 'tool_output' in state: # ToolNode places output in 'tool_output' key by default
                tool_result = state['tool_output']
                tool_output_str = ""
                if isinstance(tool_result, list):
                    # Attempt to handle list of Pydantic models by dumping them
                    try:
                        tool_output_str = json.dumps([item.model_dump() if hasattr(item, 'model_dump') else item for item in tool_result], indent=2)
                    except TypeError: # Fallback for non-serializable items
                        tool_output_str = str(tool_result)
                elif isinstance(tool_result, BaseModel): # Handle single Pydantic model
                    tool_output_str = json.dumps(tool_result.model_dump(), indent=2)
                elif isinstance(tool_result, dict):
                    tool_output_str = json.dumps(tool_result, indent=2)
                else:
                    tool_output_str = str(tool_result)

                state['scratchpad'] = (
                    state.get('scratchpad', '') + f"\nTool Output ({state.get('tool_executed', 'UnknownTool')}): {tool_output_str}"
                )
                state['chat_history'].append(AIMessage(content=f"Tool Output: {tool_output_str}"))
            final_state.update(state) # Accumulate state
        else:
            final_state = state["__end__"] # Get the final state

    # The final_output will be in final_state['final_output']
    return final_state.get("final_output", {}).model_dump() if final_state.get("final_output") else {"error": "Analysis did not produce a valid output."}


# --- 7. Example Usage ---
if __name__ == "__main__":
    async def main():
        # Example 1: Pruning the Portfolio for InnovatePharma Inc.
        pruning_input_json = {
            "filters": {
                "for_sale": "all",
                "asset_type": "Molecule",
                "asset_phenotype": {
                    "stages_of_development": ["Phase 2", "Phase 3", "Marketed"],
                    "therapeutic_area": ["Metabolic", "Oncology", "Dermatology"]
                }
            },
            "strategic_objective": "Pruning the Portfolio",
            # CORRECTED: Using the correct company ID from your JSON
            "current_company_id": "YOUR_COMPANY_001"
        }

        print("\n--- Running Pruning Analysis ---")
        try:
            pruning_output = await run_strategic_analysis(pruning_input_json)
            print("\nPruning Output:")
            # Use json.dumps to pretty print the output
            print(json.dumps(pruning_output, indent=2))
        except Exception as e:
            print(f"An error occurred during strategic analysis: {e}")

    await main()

Loaded 6 companies.
Loaded 11 molecules.
Loaded 3 external trial data entries.
Loaded 4 market benchmarks.
Loaded 5 literature snippets.

--- Running Pruning Analysis ---
--- Starting Analysis for Objective: Pruning the Portfolio (Company: YOUR_COMPANY_001) ---
---CALLING TOOL ORCHESTRATOR (LLM DECIDING)---
LLM decided to call: search_internal_portfolio(filters={'for_sale': 'all', 'company_details': {'headquarters': None, 'development_stage': None, 'territory': None, 'company_type': None, 'financial_status': None, 'partner_status': None}, 'search_query': None, 'deal_value_max': None, 'peak_sales': {'one_yr_sales_potential_M': None, 'peak_sales_M': None, 'five_yr_sales_potential_M': None}, 'asset_phenotype': {'route_of_administration': None, 'stages_of_development': ['Phase 2', 'Phase 3', 'Marketed'], 'modality': None, 'therapeutic_area': ['Metabolic', 'Oncology', 'Dermatology'], 'patent_expiry_yr': None, 'indication': None, 'mechanism_of_action': None}, 'deal_value_min': None, 'asset_t

In [None]:
import os
import random
import json
from typing import List, Dict, Any, Optional, Literal, TypedDict
from pydantic import BaseModel, Field, conint, confloat, ValidationError
from enum import Enum
from datetime import datetime, timedelta
from typing import Union

# LangChain specific imports
from langchain_core.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate
from langchain_core.tools import tool
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, ToolMessage
from langchain_core.documents import Document

# LangGraph imports
from langgraph.graph import StateGraph, END
from langgraph.prebuilt import ToolNode

# --- API Keys and Configuration ---
# Ensure your GOOGLE_API_KEY is correctly set.
# It's recommended to use environment variables for security.
# os.environ["GOOGLE_API_KEY"] = "YOUR_GOOGLE_API_KEY"
GOOGLE_API_KEY = "AIzaSyC4q-ry8oPTjBHDP1suYrtB2PX52MXREwg" # Replace with your key

# Optional, for Langsmith tracing
# os.environ["LANGCHAIN_API_KEY"] = "YOUR_LANGCHAIN_API_KEY"
# os.environ["LANGCHAIN_TRACING_V2"] = "true"
# os.environ["LANGCHAIN_PROJECT"] = "Strategic Portfolio AI"

# --- 1. Enums and Pydantic Models ---
class AssetType(str, Enum): MOLECULE = "Molecule"
class StrategicObjectiveEnum(str, Enum): PRUNING = "Pruning the Portfolio"
class Molecule(BaseModel): id: str; name: str; asset_type: AssetType; development_stage: str; therapeutic_area: str; indication: str; mechanism_of_action: str; route_of_administration: str; modality: str; patent_expiry_year: int; projected_peak_sales_M: Optional[confloat(ge=0)] = None; internal_risk_score: confloat(ge=0, le=1); efficacy_profile: str; safety_profile: str; company_id: str; company_name: str
class AssetPhenotypeFilters(BaseModel): stages_of_development: Optional[List[str]] = None
class FrontendFilters(BaseModel): asset_phenotype: Optional[AssetPhenotypeFilters] = Field(default_factory=AssetPhenotypeFilters)
class FrontendInput(BaseModel): filters: FrontendFilters = Field(default_factory=FrontendFilters); strategic_objective: StrategicObjectiveEnum; current_company_id: str
class PruningRecommendation(BaseModel): action_type: Literal["Deprioritize Asset"]; molecule_name: str; molecule_id: str; justification: str; reason_criteria: List[str]; risk_score: confloat(ge=0, le=1); opportunity_cost_estimate: str; impact_on_portfolio: Dict[str, Any]
class StrategicOutput(BaseModel): strategic_objective_addressed: StrategicObjectiveEnum; recommendations: List[PruningRecommendation] = Field(default_factory=list); strategic_summary: str; overall_impact_on_portfolio: Dict[str, Any]; recommended_portfolio_adjustments: Dict[str, Any]; suggested_ideal_portfolio_characteristics: Dict[str, Any]

# --- Initialize Global LLM ---
llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", temperature=0.1, google_api_key=GOOGLE_API_KEY)

# --- Data Management Layer ---
class DataManager:
    # ... (This class is correct and remains unchanged) ...
    def __init__(self, data_file_path: str): self._load_all_data(data_file_path)
    def _load_all_data(self, data_file_path):
        try:
            with open(data_file_path, 'r') as f: data = json.load(f)
            self.all_molecules: Dict[str, Molecule] = {m["id"]: Molecule(**m) for m in data.get("molecules", [])}
            print(f"Loaded {len(self.all_molecules)} molecules.")
            self.your_company_id: Optional[str] = None; self.internal_molecules: Dict[str, Molecule] = {}
        except Exception as e: raise e
    def set_current_company(self, company_id: str): self.your_company_id = company_id; self.internal_molecules = {mol.id: mol for mol in self.all_molecules.values() if mol.company_id == company_id}
    def get_internal_portfolio(self) -> List[Molecule]: return list(self.internal_molecules.values())
    def get_filtered_molecules(self, molecules: List[Molecule], filters: FrontendFilters) -> List[Molecule]:
        if not filters.asset_phenotype or not filters.asset_phenotype.stages_of_development: return molecules
        return [m for m in molecules if m.development_stage in filters.asset_phenotype.stages_of_development]

if not os.path.exists("company.json"):
    dummy_data = {"molecules": [{"id": "MOL_001", "name": "Innovatinib", "asset_type": "Molecule", "development_stage": "Phase 2", "therapeutic_area": "Oncology", "indication": "NSCLC", "mechanism_of_action": "EGFR inhibitor", "route_of_administration": "Oral", "modality": "Small Molecule", "patent_expiry_year": 2035, "projected_peak_sales_M": 1200, "internal_risk_score": 0.4, "efficacy_profile": "Promising", "safety_profile": "Manageable", "company_id": "YOUR_COMPANY_001", "company_name": "InnovatePharma Inc."}, {"id": "MOL_002", "name": "Dermacure", "asset_type": "Molecule", "development_stage": "Phase 1", "therapeutic_area": "Dermatology", "indication": "Psoriasis", "mechanism_of_action": "JAK inhibitor", "route_of_administration": "Topical", "modality": "Small Molecule", "patent_expiry_year": 2038, "projected_peak_sales_M": 800, "internal_risk_score": 0.6, "efficacy_profile": "Early but positive", "safety_profile": "Good", "company_id": "YOUR_COMPANY_001", "company_name": "InnovatePharma Inc."}]}
    with open("company.json", "w") as f: json.dump(dummy_data, f, indent=2)

data_manager = DataManager('company.json')

# --- Tool Definitions ---
# Note: We are removing the submit_strategic_analysis tool as the final node will handle output.
@tool
def search_internal_portfolio(filters: FrontendFilters) -> List[Molecule]:
    """Searches the current company's internal molecule portfolio using structured filters."""
    print(f"Tool 'search_internal_portfolio' called with filters: {filters.model_dump_json(indent=2)}")
    all_internal = data_manager.get_internal_portfolio()
    return data_manager.get_filtered_molecules(all_internal, filters)

tools = [search_internal_portfolio]
llm_with_tools = llm.bind_tools(tools)

# --- Agent State and Graph Definition ---
class AgentState(TypedDict):
    input: FrontendInput
    chat_history: List[BaseMessage]
    scratchpad: str
    final_output: Optional[StrategicOutput]

# Node 1: The "Orchestrator" - decides which data-gathering tool to call
def call_tool_orchestrator(state: AgentState):
    """This node uses the LLM to decide which data-gathering tool to call."""
    print("---NODE: call_tool_orchestrator---")
    messages = state['chat_history']
    prompt = ChatPromptTemplate.from_messages([
        SystemMessage(content="You are an expert strategic analyst. Your current task is to gather data by calling the most relevant tool."),
        HumanMessage(content=f"Based on the strategic objective '{state['input'].strategic_objective.value}', what tool should I call first? The user has provided these filters: {state['input'].filters.model_dump_json(indent=2)}")
    ])
    chain = prompt | llm_with_tools
    try:
        ai_message = chain.invoke({})
        if not ai_message.tool_calls:
            print("Orchestrator chose not to call a tool. Proceeding to final analysis.")
        return {"chat_history": messages + [ai_message]}
    except Exception as e:
        print(f"Error in tool orchestration: {e}")
        return {"chat_history": messages + [AIMessage(content=f"Error: {e}")]}

# Node 2: The "Doer" - executes the tool call
tool_executor = ToolExecutor(tools)
def execute_tools_node(state: AgentState):
    """A wrapper node that executes tools."""
    print("---NODE: execute_tools_node---")
    last_message = state['chat_history'][-1]
    tool_call = last_message.tool_calls[0]
    action = ToolInvocation(tool=tool_call["name"], tool_input=tool_call["args"])
    response = tool_executor.invoke([action]) # response is a list of ToolMessage
    return {"chat_history": state['chat_history'] + response}


# Node 3: The "Processor" - formats tool output for the scratchpad
def process_tool_results(state: AgentState):
    """Processes the output from the tool node and updates the scratchpad."""
    print("---NODE: process_tool_results---")
    last_message = state['chat_history'][-1]
    if not isinstance(last_message, ToolMessage): return {}

    try:
        # The content of the ToolMessage is now a list of Pydantic objects.
        tool_output_list = last_message.content
        if isinstance(tool_output_list, list):
            serializable_list = [item.model_dump() for item in tool_output_list]
            tool_output_str = json.dumps(serializable_list, indent=2)
        else:
            tool_output_str = str(tool_output_list)
    except Exception as e:
        print(f"Could not parse tool output: {e}")
        tool_output_str = str(last_message.content)

    updated_scratchpad = state['scratchpad'] + f"\n\n### Data from `{last_message.name}`:\n```json\n{tool_output_str}\n```\n"
    return {"scratchpad": updated_scratchpad}


# Node 4: The "Generator" - creates the final JSON output
def generate_strategic_analysis(state: AgentState):
    """Generates the final structured JSON output based on the scratchpad."""
    print("---NODE: generate_strategic_analysis---")
    parser = PydanticOutputParser(pydantic_object=StrategicOutput)

    prompt_template = ChatPromptTemplate.from_messages([
        SystemMessage(content=f"""You are a world-class strategic portfolio analyst. Your task is to synthesize the provided data into a final, structured JSON report. The user's strategic objective is '{state['input'].strategic_objective.value}'. You must strictly follow the provided JSON format instructions."""),
        HumanMessage(content="Here is the data you have gathered in the scratchpad:\n{scratchpad}\n\nNow, generate the final JSON report based on this data.\n{format_instructions}")
    ])

    chain = prompt_template | llm | parser

    try:
        output = chain.invoke({
            "scratchpad": state['scratchpad'],
            "format_instructions": parser.get_format_instructions()
        })
        return {"final_output": output}
    except Exception as e:
        print(f"Error during final analysis generation: {e}")
        return {"final_output": None}


# --- Define and Compile the Graph ---
workflow = StateGraph(AgentState)
workflow.add_node("orchestrator", call_tool_orchestrator)
workflow.add_node("executor", execute_tools_node)
workflow.add_node("processor", process_tool_results)
workflow.add_node("generator", generate_strategic_analysis)

def route_from_orchestrator(state: AgentState):
    last_message = state['chat_history'][-1]
    if isinstance(last_message, AIMessage) and last_message.tool_calls:
        return "executor" # If a tool was chosen, execute it
    return "generator" # Otherwise, go straight to the final report

workflow.set_entry_point("orchestrator")
workflow.add_conditional_edges("orchestrator", route_from_orchestrator)
workflow.add_edge("executor", "processor")
workflow.add_edge("processor", "generator")
workflow.add_edge("generator", END)
app = workflow.compile()

# --- Runner Function ---
async def run_strategic_analysis(input_json: Dict[str, Any]) -> Dict[str, Any]:
    try:
        frontend_input = FrontendInput(**input_json)
    except ValidationError as e:
        return {"error": f"Invalid input data: {e.errors()}"}

    data_manager.set_current_company(frontend_input.current_company_id)
    print(f"--- Starting Analysis for Objective: {frontend_input.strategic_objective.value} (Company: {frontend_input.current_company_id}) ---")

    initial_state = {
        "input": frontend_input,
        "chat_history": [],
        "scratchpad": "## Initial Analysis State\n- Objective: " + frontend_input.strategic_objective.value,
        "final_output": None
    }

    final_state = {}
    async for output in app.astream(initial_state):
        if END in output:
            final_state = output[END]
            break

    output_obj = final_state.get("final_output")
    if output_obj and isinstance(output_obj, StrategicOutput):
        return output_obj.model_dump()
    else:
        print("\n--- ANALYSIS FAILED ---")
        return {"error": "Analysis failed to produce a valid StrategicOutput. See server logs."}


# --- Example Usage ---
if __name__ == "__main__":
    import asyncio
    async def main():
        pruning_input_json = {
            "filters": {
                "asset_phenotype": {
                    "stages_of_development": ["Phase 1", "Phase 2"]
                }
            },
            "strategic_objective": "Pruning the Portfolio",
            "current_company_id": "YOUR_COMPANY_001"
        }
        print("\n--- Running Pruning Analysis ---")
        try:
            pruning_output = await run_strategic_analysis(pruning_input_json)
            print("\n<<< FINAL STRATEGIC OUTPUT >>>")
            print(json.dumps(pruning_output, indent=2))
        except Exception as e:
            print(f"An unhandled error occurred during the analysis run: {e}")

    await main()

Loaded 11 molecules.


NameError: name 'ToolExecutor' is not defined

In [None]:
import os
import random
import json
from typing import List, Dict, Any, Optional, Literal, TypedDict
from pydantic import BaseModel, Field, conint, confloat, ValidationError
from enum import Enum
from datetime import datetime, timedelta
from typing import Union

# LangChain specific imports
from langchain_core.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate
from langchain_core.tools import tool
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, ToolMessage, SystemMessage
from langchain_core.documents import Document

# LangGraph imports
from langgraph.graph import StateGraph, END
from langgraph.prebuilt import ToolNode

# --- API Keys and Configuration ---
# It's highly recommended to use environment variables for security.
# os.environ["GOOGLE_API_KEY"] = "YOUR_GOOGLE_API_KEY"
GOOGLE_API_KEY = "AIzaSyC4q-ry8oPTjBHDP1suYrtB2PX52MXREwg" # Replace with your key

# Optional, for Langsmith tracing
# os.environ["LANGCHAIN_API_KEY"] = "YOUR_LANGCHAIN_API_KEY"
# os.environ["LANGCHAIN_TRACING_V2"] = "true"
# os.environ["LANGCHAIN_PROJECT"] = "Strategic Portfolio AI"


class AssetType(str, Enum): MOLECULE = "Molecule"
class StrategicObjectiveEnum(str, Enum): PRUNING = "Pruning the Portfolio"
class Molecule(BaseModel): id: str; name: str; asset_type: AssetType; development_stage: str; therapeutic_area: str; indication: str; mechanism_of_action: str; route_of_administration: str; modality: str; patent_expiry_year: int; projected_peak_sales_M: Optional[confloat(ge=0)] = None; internal_risk_score: confloat(ge=0, le=1); efficacy_profile: str; safety_profile: str; company_id: str; company_name: str
class AssetPhenotypeFilters(BaseModel): stages_of_development: Optional[List[str]] = None
class FrontendFilters(BaseModel): asset_phenotype: Optional[AssetPhenotypeFilters] = Field(default_factory=AssetPhenotypeFilters)
class FrontendInput(BaseModel): filters: FrontendFilters = Field(default_factory=FrontendFilters); strategic_objective: StrategicObjectiveEnum; current_company_id: str
class PruningRecommendation(BaseModel): action_type: Literal["Deprioritize Asset"]; molecule_name: str; molecule_id: str; justification: str; reason_criteria: List[str]; risk_score: confloat(ge=0, le=1); opportunity_cost_estimate: str; impact_on_portfolio: Dict[str, Any]
class StrategicOutput(BaseModel): strategic_objective_addressed: StrategicObjectiveEnum; recommendations: List[PruningRecommendation] = Field(default_factory=list); strategic_summary: str; overall_impact_on_portfolio: Dict[str, Any]; recommended_portfolio_adjustments: Dict[str, Any]; suggested_ideal_portfolio_characteristics: Dict[str, Any]

# --- Initialize Global LLM ---
llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", temperature=0.1, google_api_key=GOOGLE_API_KEY)

# --- Data Management Layer ---
class DataManager:
    def __init__(self, data_file_path: str): self._load_all_data(data_file_path)
    def _load_all_data(self, data_file_path):
        try:
            with open(data_file_path, 'r') as f: data = json.load(f)
            self.all_molecules: Dict[str, Molecule] = {m["id"]: Molecule(**m) for m in data.get("molecules", [])}
            print(f"Loaded {len(self.all_molecules)} molecules.")
            self.your_company_id: Optional[str] = None; self.internal_molecules: Dict[str, Molecule] = {}
        except Exception as e: raise e
    def set_current_company(self, company_id: str): self.your_company_id = company_id; self.internal_molecules = {mol.id: mol for mol in self.all_molecules.values() if mol.company_id == company_id}
    def get_internal_portfolio(self) -> List[Molecule]: return list(self.internal_molecules.values())
    def get_filtered_molecules(self, molecules: List[Molecule], filters: FrontendFilters) -> List[Molecule]:
        if not filters.asset_phenotype or not filters.asset_phenotype.stages_of_development: return molecules
        return [m for m in molecules if m.development_stage in filters.asset_phenotype.stages_of_development]

if not os.path.exists("company.json"):
    dummy_data = {"molecules": [{"id": "MOL_001", "name": "Innovatinib", "asset_type": "Molecule", "development_stage": "Phase 2", "therapeutic_area": "Oncology", "indication": "NSCLC", "mechanism_of_action": "EGFR inhibitor", "route_of_administration": "Oral", "modality": "Small Molecule", "patent_expiry_year": 2035, "projected_peak_sales_M": 1200, "internal_risk_score": 0.4, "efficacy_profile": "Promising", "safety_profile": "Manageable", "company_id": "YOUR_COMPANY_001", "company_name": "InnovatePharma Inc."}, {"id": "MOL_002", "name": "Dermacure", "asset_type": "Molecule", "development_stage": "Phase 1", "therapeutic_area": "Dermatology", "indication": "Psoriasis", "mechanism_of_action": "JAK inhibitor", "route_of_administration": "Topical", "modality": "Small Molecule", "patent_expiry_year": 2038, "projected_peak_sales_M": 800, "internal_risk_score": 0.6, "efficacy_profile": "Early but positive", "safety_profile": "Good", "company_id": "YOUR_COMPANY_001", "company_name": "InnovatePharma Inc."}]}
    with open("company.json", "w") as f: json.dump(dummy_data, f, indent=2)

data_manager = DataManager('company.json')

# --- Tool Definitions ---
@tool
def search_internal_portfolio(filters: FrontendFilters) -> List[Molecule]:
    """Searches the current company's internal molecule portfolio using structured filters."""
    print(f"Tool 'search_internal_portfolio' called with filters: {filters.model_dump_json(indent=2)}")
    all_internal = data_manager.get_internal_portfolio()
    return data_manager.get_filtered_molecules(all_internal, filters)

@tool
def submit_strategic_analysis(analysis: StrategicOutput) -> str:
    """Use this tool to submit the final strategic portfolio analysis and recommendations. This is the final step."""
    print("Tool 'submit_strategic_analysis' called.")
    return "Strategic analysis submitted successfully. The process is complete."

tools = [search_internal_portfolio, submit_strategic_analysis]
llm_with_tools = llm.bind_tools(tools)

# --- Agent State and Graph Definition ---
class AgentState(TypedDict):
    input: FrontendInput
    chat_history: List[BaseMessage]
    final_output: Optional[StrategicOutput]

# Node 1: The "thinker"
def agent_node(state: AgentState):
    """Invokes the LLM to decide the next action or to finish."""
    print("---AGENT: Thinking...---")
    response = llm_with_tools.invoke(state["chat_history"])
    return {"chat_history": state["chat_history"] + [response]}

# Node 2: The "doer" - a WRAPPER around ToolNode
tool_node_instance = ToolNode(tools)
def execute_tools_node(state: AgentState) -> Dict[str, List[BaseMessage]]:
    """A wrapper node that executes tools by invoking the ToolNode with the chat history."""
    print("---TOOLS: Executing...---")
    messages = state['chat_history']
    tool_invocation_result = tool_node_instance.invoke(messages)
    if not isinstance(tool_invocation_result, list):
        tool_invocation_result = [tool_invocation_result]
    return {"chat_history": state["chat_history"] + tool_invocation_result}

# Conditional Edge: The Router
def should_continue(state: AgentState) -> Literal["execute_tools_node", END]:
    """Checks the last message in the state to decide the next step."""
    print("---ROUTER: Checking state---")
    last_message = state["chat_history"][-1]
    if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
        print("-> LLM did not call a tool. Ending.")
        return END
    if last_message.tool_calls[0]["name"] == "submit_strategic_analysis":
        try:
            final_data = last_message.tool_calls[0]['args']['analysis']
            state['final_output'] = StrategicOutput(**final_data)
            print("-> 'submit_strategic_analysis' called. Ending.")
        except (ValidationError, KeyError) as e:
            print(f"-> ERROR: Could not parse final analysis. {e}. Ending.")
        return END
    print("-> Tool call detected. Continuing to executor.")
    return "execute_tools_node"

# --- Define and Compile the Graph ---
workflow = StateGraph(AgentState)
workflow.add_node("agent", agent_node)
workflow.add_node("execute_tools_node", execute_tools_node)
workflow.set_entry_point("agent")
workflow.add_conditional_edges("agent", should_continue, {"execute_tools_node": "execute_tools_node", END: END})
workflow.add_edge("execute_tools_node", "agent")
app = workflow.compile()

# --- Runner Function ---
async def run_strategic_analysis(input_json: Dict[str, Any]) -> Dict[str, Any]:
    try:
        frontend_input = FrontendInput(**input_json)
    except ValidationError as e:
        return {"error": f"Invalid input data: {e.errors()}"}

    data_manager.set_current_company(frontend_input.current_company_id)
    print(f"--- Starting Analysis for Objective: {frontend_input.strategic_objective.value} (Company: {frontend_input.current_company_id}) ---")

    # <<< FIX: Use a persistent SystemMessage for core instructions >>>
    system_prompt = f"""You are an expert strategic portfolio analyst. Your job is to execute a multi-step analysis using your available tools.

    Your overall goal is to address the user's strategic objective: "{frontend_input.strategic_objective.value}".

    Follow this process:
    1.  Your first step is ALWAYS to call the `search_internal_portfolio` tool to get the data you need to analyze.
    2.  After you receive the portfolio data from the tool, you MUST analyze it.
    3.  Based on your analysis, your final step is to call the `submit_strategic_analysis` tool with your complete findings. Do not respond with simple text; you must finish by calling this tool."""

    # The HumanMessage now only provides the specific data for the task
    human_prompt = f"""Please begin the analysis. Here are the filters to use for the `search_internal_portfolio` tool:
    {frontend_input.filters.model_dump_json(indent=2)}
    """

    initial_state = {
        "input": frontend_input,
        "chat_history": [
            SystemMessage(content=system_prompt),
            HumanMessage(content=human_prompt)
        ],
        "final_output": None
    }

    final_state = {}
    async for output in app.astream(initial_state):
        if END in output:
            final_state = output[END]
            break

    output_obj = final_state.get("final_output")

    if output_obj and isinstance(output_obj, StrategicOutput):
        return output_obj.model_dump()
    else:
        print("\n--- ANALYSIS FAILED ---")
        if final_state and 'chat_history' in final_state:
            print("\n--- Final Chat History ---")
            for msg in final_state['chat_history']:
                print("-------------")
                print(type(msg))
                print(msg.content)
            print("------------------------")
        return {"error": "Analysis failed to produce a valid StrategicOutput. See server logs."}


# --- Example Usage ---
if __name__ == "__main__":
    import asyncio
    async def main():
        pruning_input_json = {"filters": {"for_sale": "all", "asset_type": "Molecule", "asset_phenotype": {"stages_of_development": ["Phase 2", "Phase 3", "Marketed"], "therapeutic_area": ["Metabolic", "Oncology", "Dermatology"]}}, "strategic_objective": "Pruning the Portfolio", "current_company_id": "YOUR_COMPANY_001"}

        print("\n--- Running Pruning Analysis ---")
        try:
            pruning_output = await run_strategic_analysis(pruning_input_json)
            print("\n<<< FINAL STRATEGIC OUTPUT >>>")
            print(json.dumps(pruning_output, indent=2))
        except Exception as e:
            print(f"An unhandled error occurred during the analysis run: {e}")
    await main()

Loaded 11 molecules.

--- Running Pruning Analysis ---
--- Starting Analysis for Objective: Pruning the Portfolio (Company: YOUR_COMPANY_001) ---
---AGENT: Thinking...---
---ROUTER: Checking state---
-> Tool call detected. Continuing to executor.
---TOOLS: Executing...---
---AGENT: Thinking...---
---ROUTER: Checking state---
-> Tool call detected. Continuing to executor.
---TOOLS: Executing...---
Tool 'search_internal_portfolio' called with filters: {
  "asset_phenotype": {
    "stages_of_development": [
      "Phase 2",
      "Phase 3",
      "Marketed"
    ]
  }
}
---AGENT: Thinking...---
---ROUTER: Checking state---
-> ERROR: Could not parse final analysis. 5 validation errors for StrategicOutput
recommendations.0.action_type
  Input should be 'Deprioritize Asset' [type=literal_error, input_value='Discontinue', input_type=str]
    For further information visit https://errors.pydantic.dev/2.11/v/literal_error
recommendations.0.impact_on_portfolio
  Input should be a valid dictionary 