diff --git a/.github/workflows/containers-publish.yml b/.github/workflows/containers-publish.yml index 9cd5fcce..b178b0b2 100644 --- a/.github/workflows/containers-publish.yml +++ b/.github/workflows/containers-publish.yml @@ -1,8 +1,8 @@ name: "Containers: Publish" on: - push: - tags: ["v*"] + release: + types: [published] permissions: packages: write @@ -12,10 +12,10 @@ jobs: name: Build and Push runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v5 - name: Login to ghcr.io Docker registry - uses: docker/login-action@v2 + uses: docker/login-action@v3 with: registry: ghcr.io username: ${{ github.repository_owner }} diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml new file mode 100644 index 00000000..dcb7a4bb --- /dev/null +++ b/.github/workflows/python-app.yml @@ -0,0 +1,29 @@ +# This workflow will install Python dependencies, run tests and lint with a single version of Python +# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python + +name: Python application + +on: + push: + branches: [ "listOfMed" ] + pull_request: + branches: [ "listOfMed" ] + +permissions: + contents: read + +jobs: + build: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.x' + - name: Install the code linting and formatting tool Ruff + run: pipx install ruff + - name: Lint code with Ruff + run: ruff check --output-format=github --target-version=py39 diff --git a/README.md b/README.md index dec66516..0b48973e 100644 --- a/README.md +++ b/README.md @@ -36,15 +36,29 @@ Tools used for development: ### Running Balancer for development -Running Balancer: -- Start Docker Desktop and run `docker compose up --build` -- The email and password are set in `server/api/management/commands/createsu.py` -- Download a sample of papers to upload from [https://balancertestsite.com](https://balancertestsite.com/) +Start the Postgres, Django REST, and React services by starting Docker Desktop and running `docker compose up --build` -Running pgAdmin: +#### Postgres +- Download a sample of papers to upload from [https://balancertestsite.com](https://balancertestsite.com/) - The email and password of `pgAdmin` are specified in `balancer-main/docker-compose.yml` -- The first time you use `pgAdmin` after building the Docker containers you will need to register the server. -The `Host name/address`, `Username` and `Password` are specified in `balancer-main/docker-compose.yml` +- The first time you use `pgAdmin` after building the Docker containers you will need to register the server. + - The `Host name/address` is the Postgres server service name in the Docker Compose file + - The `Username` and `Password` are the Postgres server environment variables in the Docker Compose file +- You can use the below code snippet to query the database from a Jupyter notebook: + +``` +from sqlalchemy import create_engine +import pandas as pd + +engine = create_engine("postgresql+psycopg2://balancer:balancer@localhost:5433/balancer_dev") + +query = "SELECT * FROM api_embeddings;" + +df = pd.read_sql(query, engine) +``` + +#### Django REST +- The email and password are set in `server/api/management/commands/createsu.py` ## Local Kubernetes Deployment diff --git a/deploy/manifests/balancer/base/deployment.yml b/deploy/manifests/balancer/base/deployment.yml index 54efeedf..919d10c9 100644 --- a/deploy/manifests/balancer/base/deployment.yml +++ b/deploy/manifests/balancer/base/deployment.yml @@ -17,7 +17,7 @@ spec: app: balancer spec: containers: - - image: chrissst/balancer + - image: ghcr.io/codeforphilly/balancer-main/backend name: balancer envFrom: - configMapRef: diff --git a/deploy/manifests/balancer/overlays/dev/kustomization.yml b/deploy/manifests/balancer/overlays/dev/kustomization.yml index 46197140..92a6001b 100644 --- a/deploy/manifests/balancer/overlays/dev/kustomization.yml +++ b/deploy/manifests/balancer/overlays/dev/kustomization.yml @@ -5,7 +5,7 @@ resources: - "../../base" images: - - name: chrissst/balancer - newTag: v1 - + - name: ghcr.io/codeforphilly/balancer-main/backend + newTag: "1.0.2" + namespace: balancer diff --git a/docker-compose.yml b/docker-compose.yml index d8a8ca75..aea1993b 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,34 +1,34 @@ services: db: - build: - context: ./db - dockerfile: Dockerfile - volumes: - - postgres_data:/var/lib/postgresql/data/ - environment: - - POSTGRES_USER=balancer - - POSTGRES_PASSWORD=balancer - - POSTGRES_DB=balancer_dev - ports: - - "5433:5432" - networks: - app_net: - ipv4_address: 192.168.0.2 - pgadmin: - container_name: pgadmin4 - image: dpage/pgadmin4 + build: + context: ./db + dockerfile: Dockerfile + volumes: + - postgres_data:/var/lib/postgresql/data/ environment: - PGADMIN_DEFAULT_EMAIL: balancer-noreply@codeforphilly.org - PGADMIN_DEFAULT_PASSWORD: balancer - # PGADMIN_LISTEN_PORT = 80 - # volumes: - # - ./pgadmin-data:/var/lib/pgadmin - # # PGADMIN_LISTEN_PORT = 80 + - POSTGRES_USER=balancer + - POSTGRES_PASSWORD=balancer + - POSTGRES_DB=balancer_dev ports: - - "5050:80" + - "5433:5432" networks: app_net: - ipv4_address: 192.168.0.4 + ipv4_address: 192.168.0.2 + # pgadmin: + # container_name: pgadmin4 + # image: dpage/pgadmin4 + # environment: + # PGADMIN_DEFAULT_EMAIL: balancer-noreply@codeforphilly.org + # PGADMIN_DEFAULT_PASSWORD: balancer + # # PGADMIN_LISTEN_PORT = 80 + # # volumes: + # # - ./pgadmin-data:/var/lib/pgadmin + # # # PGADMIN_LISTEN_PORT = 80 + # ports: + # - "5050:80" + # networks: + # app_net: + # ipv4_address: 192.168.0.4 backend: image: balancer-backend build: ./server @@ -52,13 +52,13 @@ services: args: - IMAGE_NAME=balancer-frontend ports: - - "3000:3000" + - "3000:3000" environment: - - CHOKIDAR_USEPOLLING=true - # - VITE_API_BASE_URL=https://balancertestsite.com/ + - CHOKIDAR_USEPOLLING=true + # - VITE_API_BASE_URL=https://balancertestsite.com/ volumes: - - "./frontend:/usr/src/app:delegated" - - "/usr/src/app/node_modules/" + - "./frontend:/usr/src/app:delegated" + - "/usr/src/app/node_modules/" depends_on: - backend networks: @@ -72,4 +72,4 @@ networks: driver: default config: - subnet: "192.168.0.0/24" - gateway: 192.168.0.1 \ No newline at end of file + gateway: 192.168.0.1 diff --git a/evaluation/README.md b/evaluation/README.md index a1d0ad70..a32df682 100644 --- a/evaluation/README.md +++ b/evaluation/README.md @@ -1,40 +1,206 @@ - # Evaluations -## LLM Output Evaluator +## `evals`: LLM evaluations to test and improve model outputs + +### Evaluation Metrics + +Natural Language Generation Performance: + +[Extractiveness](https://huggingface.co/docs/lighteval/en/metric-list#automatic-metrics-for-generative-tasks): + +* Extractiveness Coverage: Extent to which a summary is derivative of a text +* Extractiveness Density: How well the word sequence can be described as series of extractions +* Extractiveness Compression: Word ratio between the article and the summary + +API Performance: + +* Token Usage (input/output) +* Estimated Cost in USD +* Duration (in seconds) + +### Test Data + +Generate the dataset file by connecting to a database of research papers: + +Connect to the Postgres database of your local Balancer instance: + +``` +from sqlalchemy import create_engine + +engine = create_engine("postgresql+psycopg2://balancer:balancer@localhost:5433/balancer_dev") +``` + +Connect to the Postgres database of the production Balancer instance using a SQL file: + +``` +# Add Postgres.app binaries to the PATH +echo 'export PATH="/Applications/Postgres.app/Contents/Versions/latest/bin:$PATH"' >> ~/.zshrc + +createdb +pg_restore -v -d .sql +``` -The `evals` script evaluates the outputs of Large Language Models (LLMs) and estimates the associated token usage and cost. +Generate the dataset CSV file: -It supports batch evalaution via a configuration CSV and produces a detailed metrics report in CSV format. +``` +from sqlalchemy import create_engine +import pandas as pd -### Usage +engine = create_engine("postgresql://@localhost:5432/") -This script evaluates LLM outputs using the `lighteval` library: https://huggingface.co/docs/lighteval/en/metric-list#automatic-metrics-for-generative-tasks +query = "SELECT * FROM api_embeddings;" +df = pd.read_sql(query, engine) -Ensure you have the `lighteval` library and any model SDKs (e.g., OpenAI, Anthropic) configured properly. +df['INPUT'] = df.apply(lambda row: f"ID: {row['chunk_number']} | CONTENT: {row['text']}", axis=1) +# Ensure the chunks are joined in order of chunk_number by sorting the DataFrame before grouping and joining +df = df.sort_values(by=['name', 'upload_file_id', 'chunk_number']) +df_grouped = df.groupby(['name', 'upload_file_id'])['INPUT'].apply(lambda chunks: "\n".join(chunks)).reset_index() -```bash -python evals.py --config path/to/config.csv --reference path/to/reference.csv --output path/to/results.csv +df_grouped.to_csv('', index=False) ``` -The arguments to the script are: +### Running an Evaluation + +#### Bulk Model and Prompt Experimentation -- Path to the config CSV file: Must include the columns "Model Name" and "Query" -- Path to the reference CSV file: Must include the columns "Context" and "Reference" -- Path where the evaluation resuls will be saved +Compare the results of many different prompts and models at once + +``` +import pandas as pd + +data = [ + { + "MODEL": "", + "INSTRUCTIONS": """""" + }, + { + "MODEL": "", + "INSTRUCTIONS": """""" + }, +] + +df = pd.DataFrame.from_records(data) + +df.to_csv("", index=False) +``` -The script outputs a CSV with the following columns: +#### Execute on the Command Line -* Evaluates LLM outputs for: - * Extractiveness Coverage - * Extractiveness Density - * Extractiveness Compression +Execute [using `uv` to manage dependencies](https://docs.astral.sh/uv/guides/scripts/) without manually managing enviornments: -* Computes: +```sh +uv run evals.py --experiments path/to/ --dataset path/to/ --results path/to/ +``` + +Execute without using uv run by ensuring it is executable: + +```sh +./evals.py --experiments path/to/ --dataset path/to/ --results path/to/ +``` + +### Analyzing Test Results + +``` +import pandas as pd +import matplotlib.pyplot as plt +import numpy as np + +df = pd.read_csv("") + +# Define the metrics of interest +extractiveness_cols = ['Extractiveness Coverage', 'Extractiveness Density', 'Extractiveness Compression'] +token_cols = ['Input Token Usage', 'Output Token Usage'] +other_metrics = ['Cost (USD)', 'Duration (s)'] +all_metrics = extractiveness_cols + token_cols + other_metrics + +# Metric Histograms by Model +plt.style.use('default') +fig, axes = plt.subplots(len(all_metrics), 1, figsize=(12, 4 * len(all_metrics))) + +models = df['MODEL'].unique() +colors = plt.cm.Set3(np.linspace(0, 1, len(models))) + +for i, metric in enumerate(all_metrics): + ax = axes[i] if len(all_metrics) > 1 else axes + + # Create histogram for each model + for j, model in enumerate(models): + model_data = df[df['MODEL'] == model][metric] + ax.hist(model_data, alpha=0.7, label=model, bins=min(8, len(model_data)), + color=colors[j], edgecolor='black', linewidth=0.5) + + ax.set_title(f'{metric} Distribution by Model', fontsize=14, fontweight='bold') + ax.set_xlabel(metric, fontsize=12) + ax.set_ylabel('Frequency', fontsize=12) + ax.legend(title='Model', bbox_to_anchor=(1.05, 1), loc='upper left') + ax.grid(True, alpha=0.3) + +plt.tight_layout() +plt.show() + +# Metric Statistics by Model +for metric in all_metrics: + print(f"\n{metric.upper()}:") + desc_stats = df.groupby('MODEL')[metric].agg([ + 'count', 'mean', 'std', 'min', 'median','max' + ]) + + print(desc_stats) + + +# Calculate Efficiency Metrics By model +df_analysis = df.copy() +df_analysis['Total Token Usage'] = df_analysis['Input Token Usage'] + df_analysis['Output Token Usage'] +df_analysis['Cost per Token'] = df_analysis['Cost (USD)'] / df_analysis['Total Token Usage'] +df_analysis['Tokens per Second'] = df_analysis['Total Token Usage'] / df_analysis['Duration (s)'] +df_analysis['Cost per Second'] = df_analysis['Cost (USD)'] / df_analysis['Duration (s)'] + +efficiency_metrics = ['Cost per Token', 'Tokens per Second', 'Cost per Second'] + +for metric in efficiency_metrics: + print(f"\n{metric.upper()}:") + eff_stats = df_analysis.groupby('MODEL')[metric].agg([ + 'count', 'mean', 'std', 'min', 'median', 'max' + ]) + + for col in ['mean', 'std', 'min', 'median', 'max']: + eff_stats[col] = eff_stats[col].apply(lambda x: f"{x:.3g}") + print(eff_stats) + + +``` + +### Contributing + +#### Adding Evaluation Metrics + +To add new evaluation metrics, modify the `evaluate_response()` function in `evaluation/evals.py`: + +**Update dependencies** in script header and ensure exception handling includes new metrics with `None` values. + +#### Adding New LLM Models + +To add a new LLM model for evaluation, implement a handler in `server/api/services/llm_services.py`: + +1. **Create a handler class** inheriting from `BaseModelHandler`: +2. **Register in ModelFactory** by adding to the `HANDLERS` dictionary: +3. **Use in experiments** by referencing the handler key in your experiments CSV: + +The evaluation system will automatically use your handler through the Factory Method pattern. + + +#### Running Tests + +The evaluation module includes comprehensive tests for all core functions. Run the test suite using: + +```sh +uv run test_evals.py +``` - * Token usage (input/output) - * Estimated cost in USD - * Duration (in seconds) +The tests cover: +- **Cost calculation** with various token usage and pricing scenarios +- **CSV loading** with validation and error handling +- **Response evaluation** including async operations and exception handling diff --git a/evaluation/evals.py b/evaluation/evals.py old mode 100644 new mode 100755 index f6e9bb3d..8eb7e9e6 --- a/evaluation/evals.py +++ b/evaluation/evals.py @@ -1,9 +1,20 @@ +#!/usr/bin/env -S uv run --script +# /// script +# requires-python = "==3.11.11" +# dependencies = [ +# "pandas==2.2.3", +# "lighteval==0.10.0", +# "openai==1.83.0", +# "spacy==3.8.7", +# "pip" +# +# ] +# /// + """ Evaluate LLM outputs using multiple metrics and compute associated costs """ -# TODO: Add tests on a small dummy dataset to confirm it handles errors gracefully and produces expected outputs - import sys import os @@ -12,8 +23,12 @@ import argparse import logging +import asyncio +import time import pandas as pd + +# lighteval depends on `sentencepiece` and it only has prebuilt wheels for Python 3.11 or below from lighteval.tasks.requests import Doc from lighteval.metrics.metrics_sample import Extractiveness @@ -24,130 +39,168 @@ ) -def evaluate_response( - model_name: str, query: str, context: str, reference: str -) -> pd.DataFrame: +async def evaluate_response(model: str, instructions: str, input: str) -> pd.DataFrame: + """ + Test a prompt with a set of test data by scoring each item in the data set """ - Evaluates the response of a model to a given query and context, computes extractiveness metrics, token usage, and cost - Args: - model_name (str): The name of the model to be used for evaluation. - query (str): The user query to be processed. - context (str): The context or document content to be used. - reference (str): The reference text for comparison (not used in this function, but can be used for further evaluations). + try: + handler = ModelFactory.get_handler(model) - Returns: - pd.DataFrame: A DataFrame containing the output text, extractiveness metrics, token usage, cost, and duration. - """ + generated_text, token_usage, pricing, duration = await handler.handle_request( + instructions, input + ) - handler = ModelFactory.get_handler(model_name) + doc = Doc(query="", choices=[], gold_index=0, specific={"text": input}) + extractiveness = Extractiveness().compute( + formatted_doc=doc, predictions=[generated_text] + ) - # TODO: Add error handling for unsupported models + cost_metrics = calculate_cost_metrics(token_usage, pricing) - output_text, token_usage, pricing, duration = handler.handle_request(query, context) + result = pd.DataFrame( + [ + { + "Generated Text": generated_text, + "Extractiveness Coverage": extractiveness["summarization_coverage"], + "Extractiveness Density": extractiveness["summarization_density"], + "Extractiveness Compression": extractiveness[ + "summarization_compression" + ], + "Input Token Usage": token_usage.input_tokens, + "Output Token Usage": token_usage.output_tokens, + "Cost (USD)": cost_metrics["total_cost"], + "Duration (s)": duration, + } + ] + ) - doc = Doc(query="", choices=[], gold_index=0, specific={"text": context}) - extractiveness = Extractiveness().compute( - formatted_doc=doc, predictions=[output_text] - ) + except Exception as e: + logging.error(f"Error evaluating response for model {model}: {e}") + result = pd.DataFrame( + [ + { + "Generated Text": None, + "Extractiveness Coverage": None, + "Extractiveness Density": None, + "Extractiveness Compression": None, + "Input Token Usage": None, + "Output Token Usage": None, + "Cost (USD)": None, + "Duration (s)": None, + } + ] + ) + + return result - input_cost_dollars = (pricing["input"] / 1000000) * token_usage.input_tokens - output_cost_dollars = (pricing["output"] / 1000000) * token_usage.output_tokens +def calculate_cost_metrics(token_usage: dict, pricing: dict) -> dict: + """ + Calculate cost metrics based on token usage and pricing + """ + + TOKENS_PER_MILLION = 1_000_000 + + # Pricing is in dollars per million tokens + input_cost_dollars = ( + pricing["input"] / TOKENS_PER_MILLION + ) * token_usage.input_tokens + output_cost_dollars = ( + pricing["output"] / TOKENS_PER_MILLION + ) * token_usage.output_tokens total_cost_dollars = input_cost_dollars + output_cost_dollars - return pd.DataFrame( - [ - { - "Output Text": output_text, - "Extractiveness Coverage": extractiveness["summarization_coverage"], - "Extractiveness Density": extractiveness["summarization_density"], - "Extractiveness Compression": extractiveness[ - "summarization_compression" - ], - "Input Token Usage": token_usage.input_tokens, - "Output Token Usage": token_usage.output_tokens, - "Cost (USD)": total_cost_dollars, - "Duration (s)": duration, - } - ] - ) + return { + "input_cost": input_cost_dollars, + "output_cost": output_cost_dollars, + "total_cost": total_cost_dollars, + } -if __name__ == "__main__": - # TODO: Add CLI argument to specify the metrics to be computed - parser = argparse.ArgumentParser( - description="Evaluate LLM outputs using multiple metrics and compute associated costs" - ) - parser.add_argument("--config", "-c", required=True, help="Path to config CSV file") - parser.add_argument( - "--reference", "-r", required=True, help="Path to reference CSV file" - ) - parser.add_argument("--output", "-o", required=True, help="Path to output CSV file") +def load_csv(file_path: str, required_columns: list, nrows: int = None) -> pd.DataFrame: + """ + Load a CSV file and validate that it contains the required columns - args = parser.parse_args() + Args: + file_path (str): Path to the CSV file + required_columns (list): List of required column names + nrows (int): Number of rows to read from the CSV file + Returns: + pd.DataFrame + """ - df_config = pd.read_csv(args.config) - logging.info(f"Config DataFrame shape: {df_config.shape}") - logging.info(f"Config DataFrame columns: {df_config.columns.tolist()}") + if nrows is not None: + logging.info(f"Test mode enabled: Reading first {nrows} rows of {file_path}") - # Remove the trailing whitespace from column names - df_config.columns = df_config.columns.str.strip() + df = pd.read_csv(file_path, nrows=nrows) + + # Remove trailing whitespace from column names + df.columns = df.columns.str.strip() + + # Uppercase the column names to match the expected format + df.columns = df.columns.str.upper() # Check if the required columns are present - required_columns = ["Model Name", "Query"] - if not all(col in df_config.columns for col in required_columns): + if not all(col in df.columns for col in required_columns): raise ValueError( - f"Config DataFrame must contain the following columns: {required_columns}" + f"{file_path} must contain the following columns: {required_columns}" ) - # Check if all models in the config are supported by ModelFactory - if not all( - model in ModelFactory.HANDLERS.keys() - for model in df_config["Model Name"].unique() - ): - raise ValueError( - f"Unsupported model(s) found in config: {set(df_config['Model Name'].unique()) - set(ModelFactory.HANDLERS.keys())}" - ) + return df - df_reference = pd.read_csv(args.reference) - logging.info(f"Reference DataFrame shape: {df_reference.shape}") - logging.info(f"Reference DataFrame columns: {df_reference.columns.tolist()}") - # Remove the trailing whitespace from column names - df_reference.columns = df_reference.columns.str.strip() - # Check if the required columns are present - required_columns = ["Context", "Reference"] - if not all(col in df_reference.columns for col in required_columns): - raise ValueError( - f"Reference DataFrame must contain the following columns: {required_columns}" - ) +async def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--experiments", "-e", required=True, help="Path to experiments CSV file" + ) + parser.add_argument( + "--dataset", "-d", required=True, help="Path to dataset CSV file" + ) + parser.add_argument( + "--results", "-r", required=True, help="Path to results CSV file" + ) + parser.add_argument( + "--test", "-t", type=int, help="Run evaluation on first n rows of dataset only" + ) - # Cross join the config and reference DataFrames - df_in = df_config.merge(df_reference, how="cross") + args = parser.parse_args() - # TODO: Parallelize the evaluation process for each row in df_in using concurrent.futures or similar libraries - df_evals = pd.DataFrame() - for index, row in df_in.iterrows(): - df_evals = pd.concat( - [ - df_evals, - evaluate_response( - row["Model Name"], row["Query"], row["Context"], row["Reference"] - ), - ], - axis=0, - ) + # Load the experiment DataFrame + df_experiment = load_csv( + args.experiments, required_columns=["MODEL", "INSTRUCTIONS"] + ) - logging.info(f"Processed row {index + 1}/{len(df_in)}") + # Load the dataset DataFrame + df_dataset = load_csv(args.dataset, required_columns=["INPUT"], nrows=args.test) - # Concatenate the input and evaluations DataFrames + # Bulk model and prompt experimentation: Cross join the experiment and dataset DataFrames + df_in = df_experiment.merge(df_dataset, how="cross") + + # Evaluate each row in the input DataFrame concurrently + logging.info(f"Starting evaluation of {len(df_in)} rows") + start_time = time.time() + tasks = [ + evaluate_response(row.MODEL, row.INSTRUCTIONS, row.INPUT) + for row in df_in.itertuples(index=False) + ] + results = await asyncio.gather(*tasks) + end_time = time.time() + duration = end_time - start_time + logging.info(f"Completed evaluation of {len(results)} rows in {duration} seconds") + + df_evals = pd.concat(results, axis=0, ignore_index=True) + + # Concatenate the input and evaluations DataFrames df_out = pd.concat( [df_in.reset_index(drop=True), df_evals.reset_index(drop=True)], axis=1 ) - - df_out.to_csv(args.output, index=False) - logging.info(f"Output DataFrame shape: {df_out.shape}") - logging.info(f"Results saved to {args.output}") + df_out.to_csv(args.results, index=False) + logging.info(f"Results saved to {args.results}") logging.info("Evaluation completed successfully.") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/evaluation/test_evals.py b/evaluation/test_evals.py index f41817c6..e6d0916d 100644 --- a/evaluation/test_evals.py +++ b/evaluation/test_evals.py @@ -1,53 +1,212 @@ - -from unittest.mock import patch, MagicMock +#!/usr/bin/env -S uv run --script +# /// script +# requires-python = "==3.11.11" +# dependencies = [ +# "pandas==2.2.3", +# "lighteval==0.10.0", +# "openai==1.83.0", +# "spacy==3.8.7", +# "pytest==8.3.3", +# "pytest-asyncio==0.24.0", +# "pip" +# ] +# /// import pytest import pandas as pd +from unittest.mock import Mock, patch, AsyncMock +import tempfile +import os +import sys + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +from evaluation.evals import evaluate_response, calculate_cost_metrics, load_csv + -from evals import evaluate_response +@pytest.fixture +def mock_token_usage(): + token_usage = Mock() + token_usage.input_tokens = 100 + token_usage.output_tokens = 50 + return token_usage -class MockTokenUsage: - def __init__(self, input_tokens, output_tokens): - self.input_tokens = input_tokens - self.output_tokens = output_tokens -@patch("evals.ModelFactory.get_handler") -@patch("evals.Extractiveness.compute") -def test_evaluate_response(mock_extractiveness_compute, mock_get_handler): +@pytest.fixture +def temp_csv(): + def _create_csv(content): + with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) as f: + f.write(content) + return f.name - # Mock BaseModelHandler - mock_handler = MagicMock() - mock_handler.handle_request.return_value = ( - "This is a summary.", - MockTokenUsage(input_tokens=100, output_tokens=50), - {"input": 15.0, "output": 30.0}, # $15 and $30 per 1M tokens - 1.23, # duration + return _create_csv + + +class TestCalculateCostMetrics: + @pytest.mark.parametrize( + "input_tokens,output_tokens,input_price,output_price,expected_input,expected_output,expected_total", + [ + (1000, 500, 5.0, 15.0, 0.005, 0.0075, 0.0125), + (0, 0, 5.0, 15.0, 0.0, 0.0, 0.0), + (1_000_000, 2_000_000, 10.0, 30.0, 10.0, 60.0, 70.0), + ], ) + def test_calculate_cost_metrics( + self, + input_tokens, + output_tokens, + input_price, + output_price, + expected_input, + expected_output, + expected_total, + ): + token_usage = Mock(input_tokens=input_tokens, output_tokens=output_tokens) + pricing = {"input": input_price, "output": output_price} + + result = calculate_cost_metrics(token_usage, pricing) - mock_get_handler.return_value = mock_handler + assert pytest.approx(result["input_cost"]) == expected_input + assert pytest.approx(result["output_cost"]) == expected_output + assert pytest.approx(result["total_cost"]) == expected_total - mock_extractiveness_compute.return_value = { - "summarization_coverage": 0.8, - "summarization_density": 1.5, - "summarization_compression": 2.0, - } - df = evaluate_response( - model_name="mock-model", - query="What is the summary?", - context="This is a long article about something important.", - reference="This is a reference summary.", +class TestLoadCsv: + @pytest.mark.parametrize( + "csv_content,required_columns,expected_len", + [ + ( + "model,instructions\ngpt-4,Test prompt\ngpt-3.5,Another prompt\n", + ["MODEL", "INSTRUCTIONS"], + 2, + ), + ( + " model , instructions \ngpt-4,Test prompt\n", + ["MODEL", "INSTRUCTIONS"], + 1, + ), + ("input\nTest input 1\nTest input 2\n", ["INPUT"], 2), + ], ) + def test_load_csv_valid( + self, temp_csv, csv_content, required_columns, expected_len + ): + temp_path = temp_csv(csv_content) + try: + df = load_csv(temp_path, required_columns) + assert len(df) == expected_len + assert list(df.columns) == required_columns + finally: + os.unlink(temp_path) + + @pytest.mark.parametrize( + "csv_content,required_columns", + [ + ("model,prompt\ngpt-4,Test prompt\n", ["MODEL", "INSTRUCTIONS"]), + ("wrong,columns\nval1,val2\n", ["MODEL", "INSTRUCTIONS"]), + ], + ) + def test_load_csv_missing_columns(self, temp_csv, csv_content, required_columns): + temp_path = temp_csv(csv_content) + try: + with pytest.raises(ValueError, match="must contain the following columns"): + load_csv(temp_path, required_columns) + finally: + os.unlink(temp_path) + + def test_load_csv_nonexistent_file(self): + with pytest.raises(FileNotFoundError): + load_csv("nonexistent_file.csv", ["MODEL"]) + + +class TestEvaluateResponse: + @pytest.mark.asyncio + async def test_evaluate_response_success(self, mock_token_usage): + mock_handler = AsyncMock() + mock_handler.handle_request.return_value = ( + "Generated response text", + mock_token_usage, + {"input": 5.0, "output": 15.0}, + 1.5, + ) + + mock_extractiveness = Mock() + mock_extractiveness.compute.return_value = { + "summarization_coverage": 0.8, + "summarization_density": 0.6, + "summarization_compression": 0.4, + } + + with ( + patch( + "evaluation.evals.ModelFactory.get_handler", return_value=mock_handler + ), + patch("evaluation.evals.Extractiveness", return_value=mock_extractiveness), + ): + result = await evaluate_response("gpt-4", "Test instructions", "Test input") + + assert isinstance(result, pd.DataFrame) + assert len(result) == 1 + row = result.iloc[0] + assert row["Generated Text"] == "Generated response text" + assert row["Extractiveness Coverage"] == 0.8 + assert row["Input Token Usage"] == 100 + assert row["Output Token Usage"] == 50 + assert row["Duration (s)"] == 1.5 + + @pytest.mark.parametrize( + "exception_side_effect", ["get_handler", "handle_request", "extractiveness"] + ) + @pytest.mark.asyncio + async def test_evaluate_response_exceptions( + self, mock_token_usage, exception_side_effect + ): + if exception_side_effect == "get_handler": + with patch( + "evaluation.evals.ModelFactory.get_handler", + side_effect=Exception("Test error"), + ): + result = await evaluate_response( + "invalid-model", "Test instructions", "Test input" + ) + + elif exception_side_effect == "handle_request": + mock_handler = AsyncMock() + mock_handler.handle_request.side_effect = Exception("Handler error") + with patch( + "evaluation.evals.ModelFactory.get_handler", return_value=mock_handler + ): + result = await evaluate_response( + "gpt-4", "Test instructions", "Test input" + ) + + elif exception_side_effect == "extractiveness": + mock_handler = AsyncMock() + mock_handler.handle_request.return_value = ( + "text", + mock_token_usage, + {"input": 5.0, "output": 15.0}, + 1.5, + ) + mock_extractiveness = Mock() + mock_extractiveness.compute.side_effect = Exception("Extractiveness error") + + with ( + patch( + "evaluation.evals.ModelFactory.get_handler", + return_value=mock_handler, + ), + patch( + "evaluation.evals.Extractiveness", return_value=mock_extractiveness + ), + ): + result = await evaluate_response( + "gpt-4", "Test instructions", "Test input" + ) + + assert isinstance(result, pd.DataFrame) + assert len(result) == 1 + assert pd.isna(result.iloc[0]["Generated Text"]) + - assert isinstance(df, pd.DataFrame) - assert df.shape == (1, 8) - assert df["Output Text"].iloc[0] == "This is a summary." - assert df["Extractiveness Coverage"].iloc[0] == 0.8 - assert df["Extractiveness Density"].iloc[0] == 1.5 - assert df["Extractiveness Compression"].iloc[0] == 2.0 - assert df["Input Token Usage"].iloc[0] == 100 - assert df["Output Token Usage"].iloc[0] == 50 - - expected_cost = (15.0 / 1_000_000) * 100 + (30.0 / 1_000_000) * 50 - assert pytest.approx(df["Cost (USD)"].iloc[0], rel=1e-4) == expected_cost - assert pytest.approx(df["Duration (s)"].iloc[0], rel=1e-4) == 1.23 \ No newline at end of file +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/frontend/src/App.css b/frontend/src/App.css index 1f7ba717..718ef1ad 100644 --- a/frontend/src/App.css +++ b/frontend/src/App.css @@ -130,6 +130,11 @@ This is the wording logo in the nav bar @apply bg-gradient-to-r from-blue-600 via-blue-700 to-blue-600 bg-clip-text text-transparent; } +/* Using logo-like styles outside of the header */ +.body_logo { + @apply bg-gradient-to-r from-blue-500 via-blue-700 to-blue-300 bg-clip-text font-quicksand font-bold text-transparent +} + /* Tailwind Styles */ /* @@ -170,7 +175,7 @@ The heading directly under the nav bar. } .desc1 { - @apply mx-auto mt-5 hidden max-w-[75%] text-center font-satoshi text-lg text-gray-400 sm:text-xl md:block; + @apply mx-auto mt-5 text-center font-satoshi text-lg text-gray-800 sm:text-xl md:block; font-size: 100%; letter-spacing: 0.03em; } diff --git a/frontend/src/api/apiClient.ts b/frontend/src/api/apiClient.ts index 451c1c41..26a6ab8a 100644 --- a/frontend/src/api/apiClient.ts +++ b/frontend/src/api/apiClient.ts @@ -67,6 +67,18 @@ const handleRuleExtraction = async (guid: string) => { } }; +const fetchRiskDataWithSources = async (medication: string, source: "include" | "diagnosis" | "diagnosis_depressed" = "include") => { + try { + const response = await api.post(`/v1/api/riskWithSources`, { + drug: medication, + source: source, + }); + return response.data; + } catch (error) { + console.error("Error fetching risk data: ", error); + throw error; + } +}; interface StreamCallbacks { onContent?: (content: string) => void; @@ -165,7 +177,6 @@ const handleSendDrugSummaryStream = async ( } }; - // Legacy function for backward compatibility const handleSendDrugSummaryStreamLegacy = async ( message: string, @@ -256,6 +267,19 @@ const updateConversationTitle = async ( } }; +// Assistant API functions +const sendAssistantMessage = async (message: string, previousResponseId?: string) => { + try { + const response = await api.post(`/v1/api/assistant`, { + message, + previous_response_id: previousResponseId, + }); + return response.data; + } catch (error) { + console.error("Error(s) during sendAssistantMessage: ", error); + throw error; + } +}; export { handleSubmitFeedback, @@ -268,5 +292,7 @@ export { deleteConversation, updateConversationTitle, handleSendDrugSummaryStream, - handleSendDrugSummaryStreamLegacy + handleSendDrugSummaryStreamLegacy, + fetchRiskDataWithSources, + sendAssistantMessage }; \ No newline at end of file diff --git a/frontend/src/components/ChipsInput/ChipsInput.tsx b/frontend/src/components/ChipsInput/ChipsInput.tsx index 7fd03b4d..32e4e332 100644 --- a/frontend/src/components/ChipsInput/ChipsInput.tsx +++ b/frontend/src/components/ChipsInput/ChipsInput.tsx @@ -59,7 +59,6 @@ const ChipsInput: React.FC = ({ const handleSuggestionClick = (selected: string) => { onChange([...value, selected]); - setInputFocused(false); // Close dropdown after selection setInputValue(""); }; @@ -131,7 +130,10 @@ const ChipsInput: React.FC = ({ {filteredSuggestions.map((item, idx) => (
  • handleSuggestionClick(item)} + onMouseDown={(e) => { + e.preventDefault(); + handleSuggestionClick(item) + }} className="px-3 py-2 cursor-pointer hover:bg-gray-100 text-sm" > {item} diff --git a/frontend/src/components/Footer/Footer.tsx b/frontend/src/components/Footer/Footer.tsx index eeb01935..68a22263 100644 --- a/frontend/src/components/Footer/Footer.tsx +++ b/frontend/src/components/Footer/Footer.tsx @@ -28,11 +28,11 @@ function Footer() { return ( //
    -
    +
    {/*
    */}{" "} {/* Added mt-5 and mr-5 */}
    -
    +
    Leave feedback + + Donate +
    -
    - © 2025 Balancer. All rights reserved. V1 2-04-2025 +
    +

    © 2025 Balancer. All rights reserved. V1 2-04-2025

    +
    +
    +

    + Balancer is an educational resource designed to support{" "} + —never replace— the judgment of licensed U.S. clinicians. Final prescribing decisions must always be made by the treating clinician. +

    diff --git a/frontend/src/components/Header/Chat.tsx b/frontend/src/components/Header/Chat.tsx index 54e965f8..c6315068 100644 --- a/frontend/src/components/Header/Chat.tsx +++ b/frontend/src/components/Header/Chat.tsx @@ -4,12 +4,9 @@ import "../../components/Header/chat.css"; import { useState, useEffect, useRef } from "react"; import TypingAnimation from "./components/TypingAnimation"; import ErrorMessage from "../ErrorMessage"; -import ConversationList from "./ConversationList"; -import { extractContentFromDOM } from "../../services/domExtraction"; import axios from "axios"; import { FaPlus, - FaMinus, FaTimes, FaComment, FaComments, @@ -19,20 +16,15 @@ import { FaExpandAlt, FaExpandArrowsAlt, } from "react-icons/fa"; -import { - fetchConversations, - continueConversation, - newConversation, - updateConversationTitle, - deleteConversation, -} from "../../api/apiClient"; - -interface ChatLogItem { +import { sendAssistantMessage } from "../../api/apiClient"; + +export interface ChatLogItem { is_user: boolean; content: string; timestamp: string; // EX: 2025-01-16T16:21:14.981090Z } +// Keep interface for backward compatibility with existing imports export interface Conversation { title: string; messages: ChatLogItem[]; @@ -47,14 +39,41 @@ interface ChatDropDownProps { const Chat: React.FC = ({ showChat, setShowChat }) => { const CHATBOT_NAME = "JJ"; const [inputValue, setInputValue] = useState(""); - const [chatLog, setChatLog] = useState([]); // Specify the type as ChatLogItem[] + const [currentMessages, setCurrentMessages] = useState([]); + const [currentResponseId, setCurrentResponseId] = useState< + string | undefined + >(undefined); const [isLoading, setIsLoading] = useState(false); - const [showConversationList, setShowConversationList] = useState(false); - const [conversations, setConversations] = useState([]); - const [activeConversation, setActiveConversation] = - useState(null); const [error, setError] = useState(null); + // Session storage functions for conversation management + const saveConversationToStorage = ( + messages: ChatLogItem[], + responseId?: string, + ) => { + const conversationData = { + messages, + responseId, + timestamp: new Date().toISOString(), + }; + sessionStorage.setItem( + "currentConversation", + JSON.stringify(conversationData), + ); + }; + + const loadConversationFromStorage = () => { + const stored = sessionStorage.getItem("currentConversation"); + if (stored) { + try { + return JSON.parse(stored); + } catch (error) { + console.error("Error parsing stored conversation:", error); + } + } + return null; + }; + const suggestionPrompts = [ "What are the side effects of Latuda?", "Why is cariprazine better than valproate for a pregnant patient?", @@ -63,27 +82,29 @@ const Chat: React.FC = ({ showChat, setShowChat }) => { "Risks associated with Lithium.", "What medications could cause liver issues?", ]; - const [pageContent, setPageContent] = useState(""); const chatContainerRef = useRef(null); + const [bottom, setBottom] = useState(false); + + // Load conversation from sessionStorage on component mount useEffect(() => { - const observer = new MutationObserver(() => { - const content = extractContentFromDOM(); - setPageContent(content); - }); - - observer.observe(document.body, { - childList: true, - subtree: true, - characterData: true, - }); - - const extractedContent = extractContentFromDOM(); - // console.log(extractedContent); - setPageContent(extractedContent); + const storedConversation = loadConversationFromStorage(); + if (storedConversation) { + setCurrentMessages(storedConversation.messages || []); + setCurrentResponseId(storedConversation.responseId); + } }, []); - const [bottom, setBottom] = useState(false); + // Save conversation to sessionStorage when component unmounts + useEffect(() => { + return () => { + // Only save if the user hasn't logged out + const isLoggingOut = !localStorage.getItem("access"); + if (!isLoggingOut && currentMessages.length > 0) { + saveConversationToStorage(currentMessages, currentResponseId); + } + }; + }, [currentMessages, currentResponseId]); const handleScroll = (event: React.UIEvent) => { const target = event.target as HTMLElement; @@ -96,34 +117,24 @@ const Chat: React.FC = ({ showChat, setShowChat }) => { const [expandChat, setExpandChat] = useState(false); useEffect(() => { - if (chatContainerRef.current && activeConversation) { + if (chatContainerRef.current) { const chatContainer = chatContainerRef.current; // Use setTimeout to ensure the new message has been rendered setTimeout(() => { chatContainer.scrollTop = chatContainer.scrollHeight; setBottom( chatContainer.scrollHeight - chatContainer.scrollTop === - chatContainer.clientHeight + chatContainer.clientHeight, ); }, 0); } - }, [activeConversation?.messages]); - - const loadConversations = async () => { - try { - const data = await fetchConversations(); - setConversations(data); - // setLoading(false); - } catch (error) { - console.error("Error loading conversations: ", error); - } - }; + }, [currentMessages]); const scrollToBottom = (element: HTMLElement) => element.scroll({ top: element.scrollHeight, behavior: "smooth" }); const handleScrollDown = ( - event: React.MouseEvent + event: React.MouseEvent, ) => { event.preventDefault(); const element = document.getElementById("inside_chat"); @@ -137,72 +148,52 @@ const Chat: React.FC = ({ showChat, setShowChat }) => { event: | React.FormEvent | React.MouseEvent, - suggestion?: string + suggestion?: string, ) => { event.preventDefault(); + const messageContent = (inputValue || suggestion) ?? ""; + if (!messageContent.trim()) return; + const newMessage = { - content: (inputValue || suggestion) ?? "", + content: messageContent, is_user: true, timestamp: new Date().toISOString(), }; - const newMessages = [...chatLog, newMessage]; - - setChatLog(newMessages); - - // sendMessage(newMessages); try { - let conversation = activeConversation; - let conversationCreated = false; - - // Create a new conversation if none exists - if (!conversation) { - conversation = await newConversation(); - setActiveConversation(conversation); - setShowConversationList(false); - conversationCreated = true; - } + setIsLoading(true); + setError(null); - // Update the conversation with the new user message - const updatedMessages = [...conversation.messages, newMessage]; - setActiveConversation({ - ...conversation, - title: "Asking JJ...", - messages: updatedMessages, - }); + // Add user message to current conversation + const updatedMessages = [...currentMessages, newMessage]; + setCurrentMessages(updatedMessages); - setIsLoading(true); + // Save user message immediately to prevent loss + saveConversationToStorage(updatedMessages, currentResponseId); - // Continue the conversation and update with the bot's response - const data = await continueConversation( - conversation.id, - newMessage.content, - pageContent + // Call assistant API with previous response ID for continuity + const data = await sendAssistantMessage( + messageContent, + currentResponseId, ); - // Update the ConversationList component after previous function creates a title - if (conversationCreated) loadConversations(); // Note: no 'await' so this can occur in the background - - setActiveConversation((prevConversation: any) => { - if (!prevConversation) return null; - - return { - ...prevConversation, - messages: [ - ...prevConversation.messages, - { - is_user: false, - content: data.response, - timestamp: new Date().toISOString(), - }, - ], - title: data.title, - }; - }); - setError(null); + // Create assistant response message + const assistantMessage = { + content: data.response_output_text, + is_user: false, + timestamp: new Date().toISOString(), + }; + + // Update messages and store new response ID for next message + const finalMessages = [...updatedMessages, assistantMessage]; + setCurrentMessages(finalMessages); + setCurrentResponseId(data.final_response_id); + + // Save conversation to sessionStorage + saveConversationToStorage(finalMessages, data.final_response_id); } catch (error) { - console.error("Error(s) handling conversation:", error); + console.error("Error handling message:", error); let errorMessage = "Error submitting message"; if (error instanceof Error) { errorMessage = error.message; @@ -222,25 +213,8 @@ const Chat: React.FC = ({ showChat, setShowChat }) => { } }; - const handleSelectConversation = (id: Conversation["id"]) => { - const selectedConversation = conversations.find( - (conversation: any) => conversation.id === id - ); - - if (selectedConversation) { - setActiveConversation(selectedConversation); - setShowConversationList(false); - } - }; - - const handleNewConversation = () => { - setActiveConversation(null); - setShowConversationList(false); - }; - useEffect(() => { if (showChat) { - loadConversations(); const resizeObserver = new ResizeObserver((entries) => { if (!entries || entries.length === 0) return; @@ -278,53 +252,35 @@ const Chat: React.FC = ({ showChat, setShowChat }) => { className=" mx-auto flex h-full flex-col overflow-auto rounded " >
    - - -
    - {activeConversation !== null && !showConversationList ? ( - activeConversation.title - ) : ( - <> - - Ask {CHATBOT_NAME} - - )} -
    +
    + + Ask {CHATBOT_NAME}
    + + + - {showConversationList ? ( -
    - -
    - ) : ( -
    - {activeConversation === null || - activeConversation.messages.length === 0 ? ( - <> -
    -
    Hi there, I'm {CHATBOT_NAME}!
    -

    - You can ask me all your bipolar disorder treatment - questions. -

    - - Learn more about my sources. - -
    -
    -
    - -
    Explore a medication
    -
    -
      - {suggestionPrompts.map((suggestion, index) => ( -
    • - -
    • - ))} -
    +
    + {currentMessages.length === 0 ? ( + <> +
    +
    Hi there, I'm {CHATBOT_NAME}!
    +

    + You can ask me questions about your uploaded documents. + I'll search through them to provide accurate, cited + answers. +

    + + Learn more about my sources. + +
    +
    +
    + ⚠️ IMPORTANT NOTICE +
    +

    + Balancer is NOT configured for use with Protected Health + Information (PHI) as defined under HIPAA. You must NOT + enter any patient-identifiable information including + names, addresses, dates of birth, medical record + numbers, or any other identifying information. By using + Balancer, you certify that you understand these + restrictions and will not enter any PHI. +

    +
    +
    +
    + +
    Explore a medication
    -
    -
    - -
    Refresh your memory
    -
    -
      - {refreshPrompts.map((suggestion, index) => ( -
    • - -
    • - ))} -
    +
      + {suggestionPrompts.map((suggestion, index) => ( +
    • + +
    • + ))} +
    +
    +
    +
    + +
    Refresh your memory
    - - ) : ( - activeConversation.messages - .slice() - .sort( - (a, b) => - new Date(a.timestamp).getTime() - - new Date(b.timestamp).getTime() - ) - .map((message, index) => ( -
    -
    - {message.is_user ? ( - - ) : ( - - )} -
    -
    -
    +                        {refreshPrompts.map((suggestion, index) => (
    +                          
  • +
  • -
    + {suggestion} + +
  • + ))} + + + + ) : ( + currentMessages + .slice() + .sort( + (a, b) => + new Date(a.timestamp).getTime() - + new Date(b.timestamp).getTime(), + ) + .map((message, index) => ( +
    +
    + +
    +
    +
    +                            {message.content}
    +                          
    - )) - )} - {isLoading && ( -
    -
    -
    + )) + )} + {isLoading && ( +
    +
    +
    - )} - {error && } -
    - )} +
    + )} + {error && } +
    = ({ showChat, setShowChat }) => { ) : ( -
    setShowChat(true)} - className="fixed bottom-9 left-10 h-16 w-16 inline-block cursor-pointer flex items-center justify-center rounded-full bg-blue-500 object-contain hover:cursor-pointer hover:bg-blue-300 md:bottom-20 md:right-20 no-print" - > - -
    - Any questions? Click here to to chat! -
    +
    setShowChat(true)} className="chat_button no-print"> +
    )} diff --git a/frontend/src/components/Header/FeatureMenuDropDown.tsx b/frontend/src/components/Header/FeatureMenuDropDown.tsx index a5c88356..b1bbf03e 100644 --- a/frontend/src/components/Header/FeatureMenuDropDown.tsx +++ b/frontend/src/components/Header/FeatureMenuDropDown.tsx @@ -1,21 +1,23 @@ -import { Link } from "react-router-dom"; +import { Link, useLocation } from "react-router-dom"; export const FeatureMenuDropDown = () => { + const location = useLocation(); + const currentPath = location.pathname; return (
    -
      - Manage files +
        + Manage files
        - Mange and chat with files. + Manage and chat with files
      -
        - Manage rules +
          + Manage rules
          Manage list of rules @@ -23,15 +25,15 @@ export const FeatureMenuDropDown = () => {
        -
          - Manage meds +
            + Manage meds
            Manage list of meds
          - {/* + {/*
            Manage Prompts @@ -63,4 +65,4 @@ export const FeatureMenuDropDown = () => {
    ); -}; +}; \ No newline at end of file diff --git a/frontend/src/components/Header/Header.tsx b/frontend/src/components/Header/Header.tsx index 30c704cd..f696b614 100644 --- a/frontend/src/components/Header/Header.tsx +++ b/frontend/src/components/Header/Header.tsx @@ -1,209 +1,243 @@ -import {useState, useRef, useEffect, Fragment} from "react"; +import { useState, useRef, useEffect, Fragment } from "react"; // import { useState, Fragment } from "react"; import accountLogo from "../../assets/account.svg"; -import {Link, useNavigate} from "react-router-dom"; +import { Link, useNavigate, useLocation } from "react-router-dom"; import LoginMenuDropDown from "./LoginMenuDropDown"; import "../../components/Header/header.css"; import Chat from "./Chat"; -import {FeatureMenuDropDown} from "./FeatureMenuDropDown"; +import { FeatureMenuDropDown } from "./FeatureMenuDropDown"; import MdNavBar from "./MdNavBar"; -import {connect, useDispatch} from "react-redux"; -import {RootState} from "../../services/actions/types"; -import {logout, AppDispatch} from "../../services/actions/auth"; -import {HiChevronDown} from "react-icons/hi"; -import {useGlobalContext} from "../../contexts/GlobalContext.tsx"; +import { connect, useDispatch } from "react-redux"; +import { RootState } from "../../services/actions/types"; +import { logout, AppDispatch } from "../../services/actions/auth"; +import { HiChevronDown } from "react-icons/hi"; +import { useGlobalContext } from "../../contexts/GlobalContext.tsx"; interface LoginFormProps { - isAuthenticated: boolean; - isSuperuser: boolean; + isAuthenticated: boolean; + isSuperuser: boolean; } -const Header: React.FC = ({ - isAuthenticated, - isSuperuser, - }) => { - const navigate = useNavigate(); - const [showFeaturesMenu, setShowFeaturesMenu] = useState(false); - const dropdownRef = useRef(null); - let delayTimeout: number | null = null; - const [showChat, setShowChat] = useState(false); - const [showLoginMenu, setShowLoginMenu] = useState(false); - const [redirect, setRedirect] = useState(false); - const {setShowSummary, setEnterNewPatient, triggerFormReset, setIsEditing} = useGlobalContext() - - const dispatch = useDispatch(); - - const logout_user = () => { - dispatch(logout()); - setRedirect(false); - }; - - const guestLinks = () => ( - + ); - const authLinks = () => ( - + ); + + const handleLoginMenu = () => { + setShowLoginMenu(!showLoginMenu); + }; + + const handleMouseEnter = () => { + if (delayTimeout !== null) { + clearTimeout(delayTimeout); + } + setShowFeaturesMenu(true); + }; + + const handleMouseLeave = () => { + delayTimeout = setTimeout(() => { + setShowFeaturesMenu(false); + }, 300) as unknown as number; // Adjust the delay time as needed + }; + + useEffect(() => { + return () => { + if (delayTimeout !== null) { + clearTimeout(delayTimeout); + } }; - - const handleMouseEnter = () => { - if (delayTimeout !== null) { - clearTimeout(delayTimeout); + }, [delayTimeout]); + + useEffect(() => { + window.scrollTo(0, 0); + }, [navigate]); + + const handleForm = () => { + setIsEditing(false); + triggerFormReset(); + setEnterNewPatient(true); + setShowSummary(false); + navigate("/"); + }; + + const location = useLocation(); + const currentPath = location.pathname; + + return ( +
    +
    +

    + Welcome to Balancer’s first release! Found a bug or have feedback? Let us know {" "} + + here {" "} + + or email {" "} + + balancerteam@codeforphilly.org + + . +

    +
    + + )} + + {redirect ? navigate("/") : } + + + + {isAuthenticated && ( + + )} + {/* */} + {isAuthenticated ? authLinks() : guestLinks()} +
    + + + ); }; const mapStateToProps = (state: RootState) => ({ - isAuthenticated: state.auth.isAuthenticated, - isSuperuser: state.auth.isSuperuser, + isAuthenticated: state.auth.isAuthenticated, + isSuperuser: state.auth.isSuperuser, }); const ConnectedLayout = connect(mapStateToProps)(Header); diff --git a/frontend/src/components/Header/MdNavBar.tsx b/frontend/src/components/Header/MdNavBar.tsx index 10a5c1c8..926794cf 100644 --- a/frontend/src/components/Header/MdNavBar.tsx +++ b/frontend/src/components/Header/MdNavBar.tsx @@ -1,146 +1,158 @@ -import { useState } from "react"; -import { Link } from "react-router-dom"; +import {useState} from "react"; +import {Link} from "react-router-dom"; // import { useLocation } from "react-router-dom"; import Chat from "./Chat"; // import logo from "../../assets/balancer.png"; import closeLogo from "../../assets/close.svg"; import hamburgerLogo from "../../assets/hamburger.svg"; -import { useDispatch } from "react-redux"; -import { logout, AppDispatch } from "../../services/actions/auth"; +import {useDispatch} from "react-redux"; +import {logout, AppDispatch} from "../../services/actions/auth"; interface LoginFormProps { - isAuthenticated: boolean; + isAuthenticated: boolean; + handleForm: () => void; } const MdNavBar = (props: LoginFormProps) => { - const [nav, setNav] = useState(true); - // const { pathname } = useLocation(); - const [showChat, setShowChat] = useState(false); - const { isAuthenticated } = props; - const handleNav = () => { - setNav(!nav); - }; + const [nav, setNav] = useState(true); + // const { pathname } = useLocation(); + const [showChat, setShowChat] = useState(false); + const {isAuthenticated, handleForm} = props; + const handleNav = () => { + setNav(!nav); + }; - const dispatch = useDispatch(); + const dispatch = useDispatch(); - const logout_user = () => { - dispatch(logout()); - }; + const logout_user = () => { + dispatch(logout()); + }; - return ( -
    - +
    + {nav && ( + logo + )} +
    +
    +
    Balancer -
    - {!nav && ( - logo +
    + {!nav && ( + logo + )} +
    +
    + +
      +
    • + + About Balancer + +
    • +
    • + + Help + +
    • +
    • + + Medical Suggester + +
    • +
    • + + Medications List + +
    • +
    • + + Chat + +
    • +
    • + + Leave Feedback + +
    • +
    • + + Donate + +
    • + {isAuthenticated && +
    • + + Sign Out + +
    • + } +
    +
    + {isAuthenticated && ( + )} -
    - -
      -
    • - - About Balancer - -
    • -
    • - - Help - -
    • -
    • - - Medical Suggester - -
    • -
    • - - Medications List - -
    • -
    • - - Chat - -
    • -
    • - - Leave Feedback - -
    • - {isAuthenticated && -
    • - - Sign Out - -
    • - } -
    - - {isAuthenticated && ( - - )} - - ); + ); }; export default MdNavBar; diff --git a/frontend/src/components/Header/chat.css b/frontend/src/components/Header/chat.css index 59ab9608..65e5e07f 100644 --- a/frontend/src/components/Header/chat.css +++ b/frontend/src/components/Header/chat.css @@ -1,3 +1,11 @@ +.chat_button { + @apply fixed bottom-9 left-10 h-16 w-16 cursor-pointer rounded-full bg-blue-500 hover:cursor-pointer hover:bg-blue-300 md:bottom-20 md:right-20 flex items-center justify-center animate-pulse-bounce; +} + +.chat_button_animations { + @apply flex justify-center items-center; +} + .inside_chat { @apply grow px-2 py-2 bg-neutral-100 overflow-y-auto overflow-x-hidden; } @@ -7,7 +15,7 @@ } .show_chat.windowed { - @apply h-[50%] w-[320px]; + @apply h-[75%] sm:w-[75%] md:w-[550px]; } .show_chat.full-screen { @@ -87,4 +95,15 @@ ul.chat_suggestion_list { .scroll_down { @apply z-40 absolute bottom-[90px] left-[45%] text-3xl text-gray-400 hover:text-blue-500 rounded-full border-2 border-white bg-white; +} + +@keyframes pulse-bounce { + 0%, 100% { + opacity: 1; + transform: translateY(0); + } + 50% { + opacity: .5; + transform: translateY(-25%); + } } \ No newline at end of file diff --git a/frontend/src/components/Header/header.css b/frontend/src/components/Header/header.css index f5c56227..4b0f4a2c 100644 --- a/frontend/src/components/Header/header.css +++ b/frontend/src/components/Header/header.css @@ -21,3 +21,19 @@ .arrow:hover { animation: none; } + +.header-nav-item { + @apply text-black border-transparent border-b-2 hover:border-blue-600 hover:text-blue-600 hover:border-b-2 hover:border-blue-600; +} + +.header-nav-item.header-nav-item-selected { + @apply text-blue-600; +} + +.subheader-nav-item { + @apply cursor-pointer rounded-lg p-3 transition duration-300 hover:bg-gray-100; +} + +.subheader-nav-item.subheader-nav-item-selected { + @apply text-blue-600; +} \ No newline at end of file diff --git a/frontend/src/pages/About/About.tsx b/frontend/src/pages/About/About.tsx index 4d4a14a6..b8170333 100644 --- a/frontend/src/pages/About/About.tsx +++ b/frontend/src/pages/About/About.tsx @@ -6,102 +6,68 @@ import image from "./OIP2.png"; function About() { return ( -
    - {/* Top section */} -
    -
    -
    - A tool that makes it easier to research medications for bipolar - disorder. -
    -
    - It can take two to 10 years—and three to 30 medications—for people - with bipolar disorder to find the right medication combination. - Balancer is designed to help physicians shorten this journey for - their patients. -
    -
    - about image -
    - - {/* Middle section */} -
    -
    -
    - Get accurate, helpful information on bipolar medications fast -
    -
    - Powered by innovative AI technology, Balancer is a tool that aids - in providing personalized medication recommendations for patients - with bipolar disorder in any state, including mania, depression, - hypomania and mixed. Our platform utilizes machine learning to - give you the latest, most up-to-date information on medications - and active clinical trials to treat bipolar disorder.{" "} -
    -
    - Balancer automates medication decision support by offering - tailored medication recommendations and comprehensive risk-benefit - assessments based on a patient's diagnosis, symptom severity, - treatment goals and individual characteristics.{" "} -
    -
    Our mission
    -
    -
    - Bipolar disorder affects approximately 5.7 million adult - Americans{" "} - - every year - - . Delays in the correct diagnosis and proper treatment of - bipolar disorder may result in social, occupational, and - economic burdens, as well as{" "} - - an increase in completed suicides - - . +
    + {/* Making it easier to research bipolar medications */} +
    +
    +
    +
    + Making it easier to research bipolar medications
    -
    - The team behind Balancer believes that building a searchable, - interactive and user-friendly research tool for bipolar - medications has the potential to improve the health and - well-being of people with bipolar disorder. +
    + It can take two to 10 years—and three to 30 medications—for people with bipolar disorder to find the right medication combination. Balancer is designed to help prescribers speed up that process by making research faster and more accessible.
    + about image
    - {/*
    -
    -
    44 million
    -
    Transactions every 24 hours
    -
    $119 million
    -
    Assets under holding
    -
    46,000
    -
    New users annually
    -
    -
    */}
    - + {/* How Balancer works */} +
    +
    + How Balancer works +
    +
      +
    • + Medication Suggestions (rules-based): +

      When you enter patient characteristics, Balancer suggests first-line, second-line, and third-line options. The recommendations follow a consistent framework developed from interviews with psychiatrists, psychiatry residents, nurse practitioners, and other prescribers. This part is not powered by AI.

      +
    • +
    • + Explanations & Research (AI-assisted): +

      For each suggestion, you can click to see supporting journal articles. Here, Balancer uses AI to search our database of medical research and highlight relevant sources for further reading.

      +
    • +
    +

    Together, these features help prescribers get reliable starting points quickly—without replacing professional judgment.

    +
    + {/* Important disclaimer */} +
    +
    + Important disclaimer +
    +

    Balancer is a free, open-source research tool built by volunteers at Code for Philly. It is for licensed U.S. prescribers and researchers only.

    +
      +
    • Balancer does not provide medical advice.

    • +
    • It does not determine treatment or replace clinical judgment.

    • +
    • Clinical decisions should always be based on the prescriber's expertise, knowledge of the patient, and official medical guidelines.

    • +
    +
    + {/* Our mission */} +
    +
    + Our mission +
    +

    Bipolar disorder affects approximately 5.7 million adult Americans every year. Delays in the correct diagnosis and proper treatment of bipolar disorder may result in social, occupational, and economic burdens, as well as an increase in completed suicides.

    +

    The team behind Balancer believes that building a searchable, interactive and user-friendly research tool for bipolar medications has the potential to improve the health and well-being of people with bipolar disorder.

    +
    {/* Support Us section */} -
    +
    Support Us
    -
    -
    - Balancer is a not-for-profit, civic-minded, open-source project +
    +
    + Balancer is a not-for-profit, civic-minded, open-source project sponsored by{" "} Code for Philly @@ -109,9 +75,14 @@ function About() { .
    +
    + + + -
    - + diff --git a/frontend/src/pages/AdminPortal/AdminPortal.tsx b/frontend/src/pages/AdminPortal/AdminPortal.tsx index c214270a..e0aaa5ab 100644 --- a/frontend/src/pages/AdminPortal/AdminPortal.tsx +++ b/frontend/src/pages/AdminPortal/AdminPortal.tsx @@ -1,115 +1,131 @@ //import Welcome from "../../components/Welcome/Welcome.tsx"; -import { Link } from "react-router-dom"; +import {Link} from "react-router-dom"; import pencilSVG from "../../assets/pencil.svg"; import uploadSVG from "../../assets/upload.svg"; import Layout_V2_Main from "../Layout/Layout_V2_Main"; +import React from "react"; const AdminPortal = () => { - return ( - <> - -
    -
    -
    -
    - Let's take a look inside Balancer's brain. You can manage the - brain from this portal. -
    -
    - {/*
    */} -
    - -
    -
    - - Manage Files - - - Manage the files stored in the Balancer's brain - -
    - Description of SVG -
    - - -
    -
    - - Upload PDF - - - Add to Balancer's brain - -
    - Description of SVG -
    - - -
    -
    - - Ask General Questions - - - Get answers from Balancer's brain - -
    - Description of SVG -
    - - -
    -
    - - Rules Manager for Medication Suggester - - - Manage and view the rules for the Medication Suggester - -
    - Description of SVG -
    - - -
    -
    - - Medications Database Manager - - - Manager the Medications store currently inside Balancer's - brain - -
    - Description of SVG + return ( + <> + +
    +
    +
    +
    + Let's take a look inside Balancer's brain. You can manage the + brain from this portal. +
    +
    + {/*
    */} +
    + + + +
    + + Manage Files + + + Manage the files stored in the Balancer's brain + +
    + Description of SVG +
    + + + + +
    + + Upload PDF + + + Add to Balancer's brain + +
    + Description of SVG + +
    + + + +
    + + Ask General Questions + + + Get answers from Balancer's brain + +
    + Description of SVG +
    + + + +
    + + Rules Manager + + + Manage and view the rules for the Medication Suggester + +
    + Description of SVG +
    + + + +
    + + Medications Database + + + Manage the Medications store currently inside Balancer's + brain + +
    + Description of SVG +
    + + +
    +
    - -
    -
    -
    - - - ); + + + ); }; export default AdminPortal; + + +const AdminDashboardItemWrapper = ({children}: { children: React.ReactNode }) => { + return ( +
    + {children} +
    + ) +} \ No newline at end of file diff --git a/frontend/src/pages/DrugSummary/DrugSummaryForm.tsx b/frontend/src/pages/DrugSummary/DrugSummaryForm.tsx index 46bb8647..d0a00f92 100644 --- a/frontend/src/pages/DrugSummary/DrugSummaryForm.tsx +++ b/frontend/src/pages/DrugSummary/DrugSummaryForm.tsx @@ -27,7 +27,6 @@ const DrugSummaryForm = () => { const { search } = useLocation(); const params = new URLSearchParams(search); const guid = params.get("guid") || ""; - const pageParam = params.get("page"); const textareaRef = useRef(null); useEffect(() => { @@ -49,17 +48,6 @@ const DrugSummaryForm = () => { useEffect(() => setHasPDF(!!guid), [guid]); - useEffect(() => { - if (pageParam && hasPDF) { - const page = parseInt(pageParam, 10); - if (!isNaN(page) && page > 0) { - window.dispatchEvent( - new CustomEvent("navigateToPdfPage", { detail: { pageNumber: page } }) - ); - } - } - }, [pageParam, hasPDF]); - useEffect(() => { if (!isStreaming && !isLoading && textareaRef.current) { textareaRef.current.focus(); @@ -184,7 +172,7 @@ const DrugSummaryForm = () => {
    {hasPDF && (
    - +
    )}
    diff --git a/frontend/src/pages/DrugSummary/PDFViewer.tsx b/frontend/src/pages/DrugSummary/PDFViewer.tsx index d0cd8636..39ddfbfc 100644 --- a/frontend/src/pages/DrugSummary/PDFViewer.tsx +++ b/frontend/src/pages/DrugSummary/PDFViewer.tsx @@ -1,9 +1,18 @@ -import { useState, useEffect, useMemo, useCallback, useRef } from "react"; +import { + useState, + useEffect, + useMemo, + useCallback, + useRef, + useTransition, + useDeferredValue, +} from "react"; import { Document, Page, pdfjs } from "react-pdf"; import { useLocation, useNavigate } from "react-router-dom"; import axios from "axios"; import "react-pdf/dist/esm/Page/AnnotationLayer.css"; import "react-pdf/dist/esm/Page/TextLayer.css"; +import ZoomMenu from "./ZoomMenu"; pdfjs.GlobalWorkerOptions.workerSrc = `//cdnjs.cloudflare.com/ajax/libs/pdf.js/${pdfjs.version}/pdf.worker.min.js`; @@ -15,6 +24,10 @@ const PDFViewer = () => { const [numPages, setNumPages] = useState(null); const [pageNumber, setPageNumber] = useState(1); const [scale, setScale] = useState(1.0); + const [uiScalePct, setUiScalePct] = useState(100); + const deferredScale = useDeferredValue(scale); + const [isPending, startTransition] = useTransition(); + const [loading, setLoading] = useState(true); const [error, setError] = useState(null); const [pdfData, setPdfData] = useState(null); @@ -24,81 +37,170 @@ const PDFViewer = () => { null ); - const manualScrollInProgress = useRef(false); - const PAGE_INIT_DELAY = 800; - const headerRef = useRef(null); const containerRef = useRef(null); const contentRef = useRef(null); const pageRefs = useRef>({}); - const initializationRef = useRef(false); + const prevGuidRef = useRef(null); + const isFetchingRef = useRef(false); + const isAutoScrollingRef = useRef(false); const location = useLocation(); const navigate = useNavigate(); const params = new URLSearchParams(location.search); const guid = params.get("guid"); const pageParam = params.get("page"); + const baseURL = import.meta.env.VITE_API_BASE_URL as string | undefined; - const baseURL = import.meta.env.VITE_API_BASE_URL; - const pdfUrl = useMemo( - () => (guid ? `${baseURL}/v1/api/uploadFile/${guid}` : null), - [guid, baseURL] - ); + const pdfUrl = useMemo(() => { + return guid && baseURL ? `${baseURL}/v1/api/uploadFile/${guid}` : null; + }, [guid, baseURL]); + + useEffect(() => setUiScalePct(Math.round(scale * 100)), [scale]); useEffect(() => { - pageRefs.current = {}; - setIsDocumentLoaded(false); - initializationRef.current = false; + const nextPage = pageParam ? parseInt(pageParam, 10) : 1; + const guidChanged = guid !== prevGuidRef.current; - if (pageParam) { - const page = parseInt(pageParam, 10); - if (!isNaN(page) && page > 0) setTargetPageAfterLoad(page); - } else { - setTargetPageAfterLoad(1); + if (guidChanged) { + setIsDocumentLoaded(false); + setNumPages(null); + setPdfData(null); + setTargetPageAfterLoad(null); + const validPage = !isNaN(nextPage) && nextPage > 0 ? nextPage : 1; + setPageNumber(validPage); + setTargetPageAfterLoad(validPage); + prevGuidRef.current = guid; + return; } - }, [guid, pageParam]); - const scrollToPage = useCallback( - (page: number) => { - if (page < 1 || !numPages || page > numPages) return; - const targetRef = pageRefs.current[page]; - if (!targetRef) return; - - manualScrollInProgress.current = true; - targetRef.scrollIntoView({ behavior: "smooth", block: "start" }); - - const observer = new IntersectionObserver( - (entries, obs) => { - const entry = entries[0]; - if (entry?.isIntersecting) { - manualScrollInProgress.current = false; - obs.disconnect(); - } - }, - { threshold: 0.5 } - ); - observer.observe(targetRef); + const desired = !isNaN(nextPage) && nextPage > 0 ? nextPage : 1; + if (pageNumber !== desired) { + if (isDocumentLoaded && numPages) { + const clampedPage = Math.max(1, Math.min(desired, numPages)); + setPageNumber(clampedPage); + scrollToPage(clampedPage); + } else { + setPageNumber(desired); + setTargetPageAfterLoad(desired); + } + } + + prevGuidRef.current = guid; + }, [guid, pageParam, pageNumber, isDocumentLoaded, numPages]); + + const updateCurrentPageFromScroll = useCallback(() => { + if (!numPages || !contentRef.current) return; + if (isAutoScrollingRef.current) return; + + const container = contentRef.current; + const containerRectTop = container.getBoundingClientRect().top; + const containerCenter = containerRectTop + container.clientHeight / 2; + + let bestPage = 1; + let bestDist = Infinity; + + for (let i = 1; i <= numPages; i++) { + const el = pageRefs.current[i]; + if (!el) continue; + const r = el.getBoundingClientRect(); + const pageCenter = r.top + r.height / 2; + const dist = Math.abs(pageCenter - containerCenter); + if (dist < bestDist) { + bestDist = dist; + bestPage = i; + } + } + + if (bestPage !== pageNumber) { + console.log( + `Scroll detected: updating page from ${pageNumber} to ${bestPage}` + ); + setPageNumber(bestPage); const newParams = new URLSearchParams(location.search); - newParams.set("page", String(page)); + newParams.set("page", String(bestPage)); navigate(`${location.pathname}?${newParams.toString()}`, { replace: true, }); - setPageNumber(page); + } + }, [numPages, pageNumber, location.pathname, location.search, navigate]); + + useEffect(() => { + const container = contentRef.current; + if (!container) return; + + let ticking = false; + const onScroll = () => { + if (ticking) return; + ticking = true; + requestAnimationFrame(() => { + updateCurrentPageFromScroll(); + ticking = false; + }); + }; + + container.addEventListener("scroll", onScroll, { passive: true }); + return () => container.removeEventListener("scroll", onScroll); + }, [updateCurrentPageFromScroll]); + + const scrollToPage = useCallback( + (page: number) => { + if (!numPages || page < 1 || page > numPages) return; + + const targetRef = pageRefs.current[page]; + if (!targetRef) { + setTimeout(() => scrollToPage(page), 100); + return; + } + + isAutoScrollingRef.current = true; + targetRef.scrollIntoView({ + behavior: "smooth", + block: "start", + inline: "nearest", + }); + + const newParams = new URLSearchParams(location.search); + if (newParams.get("page") !== String(page)) { + newParams.set("page", String(page)); + navigate(`${location.pathname}?${newParams.toString()}`, { + replace: true, + }); + } + + setTimeout(() => { + isAutoScrollingRef.current = false; + updateCurrentPageFromScroll(); + }, 500); }, - [numPages, navigate, location.pathname, location.search] + [ + numPages, + location.pathname, + location.search, + navigate, + updateCurrentPageFromScroll, + ] ); const goToPage = useCallback( (page: number) => { if (typeof page !== "number" || isNaN(page)) return; - if (page < 1) page = 1; - else if (numPages && page > numPages) page = numPages; - setPageNumber(page); - scrollToPage(page); + const clamped = Math.max(1, numPages ? Math.min(page, numPages) : page); + + if (!isDocumentLoaded || !numPages) { + setTargetPageAfterLoad(clamped); + setPageNumber(clamped); + return; + } + + if (clamped === pageNumber) return; + + setPageNumber(clamped); + scrollToPage(clamped); }, - [numPages, scrollToPage] + [isDocumentLoaded, numPages, pageNumber, scrollToPage] ); useEffect(() => { @@ -115,39 +217,56 @@ const PDFViewer = () => { }, [goToPage]); useEffect(() => { - if ( - isDocumentLoaded && - numPages && - targetPageAfterLoad && - Object.keys(pageRefs.current).length > 0 - ) { - const validPage = Math.min(Math.max(1, targetPageAfterLoad), numPages); - setPageNumber(validPage); + if (isDocumentLoaded && numPages && targetPageAfterLoad !== null) { + const validPage = Math.max(1, Math.min(targetPageAfterLoad, numPages)); + console.log(`Navigating to page ${validPage} after document load`); const timeoutId = setTimeout(() => { scrollToPage(validPage); - setTargetPageAfterLoad(null); - }, PAGE_INIT_DELAY); + + setTimeout(() => { + setTargetPageAfterLoad(null); + }, 600); + }, 100); return () => clearTimeout(timeoutId); } }, [isDocumentLoaded, numPages, targetPageAfterLoad, scrollToPage]); useEffect(() => { - const calculateSize = () => { - if (containerRef.current && headerRef.current && contentRef.current) { - const headerHeight = headerRef.current.offsetHeight; - const contentPadding = 32; - const availableHeight = - containerRef.current.clientHeight - headerHeight - contentPadding; - const availableWidth = contentRef.current.clientWidth - contentPadding; - - setContainerSize({ width: availableWidth, height: availableHeight }); - } + const calc = () => { + if (!containerRef.current || !headerRef.current || !contentRef.current) + return; + const headerHeight = headerRef.current.offsetHeight; + const contentPadding = 32; + const availableHeight = + containerRef.current.clientHeight - headerHeight - contentPadding; + const availableWidth = contentRef.current.clientWidth - contentPadding; + setContainerSize({ + width: Math.max(0, availableWidth), + height: Math.max(0, availableHeight), + }); + }; + + const id = requestAnimationFrame(calc); + + let ro: ResizeObserver | null = null; + if ("ResizeObserver" in window && contentRef.current) { + ro = new ResizeObserver(() => calc()); + ro.observe(contentRef.current); + } else { + const onResize = () => requestAnimationFrame(calc); + window.addEventListener("resize", onResize); + return () => { + cancelAnimationFrame(id); + window.removeEventListener("resize", onResize); + }; + } + + return () => { + cancelAnimationFrame(id); + if (ro && contentRef.current) ro.disconnect(); }; - calculateSize(); - window.addEventListener("resize", calculateSize); - return () => window.removeEventListener("resize", calculateSize); }, []); const pdfOptions = useMemo( @@ -168,8 +287,10 @@ const PDFViewer = () => { }, []); const fetchPdf = useCallback(async () => { - if (!pdfUrl) return; + if (!pdfUrl || isFetchingRef.current) return; + try { + isFetchingRef.current = true; setLoading(true); setError(null); const token = localStorage.getItem("access"); @@ -193,6 +314,7 @@ const PDFViewer = () => { setPdfData(null); } finally { setLoading(false); + isFetchingRef.current = false; } }, [pdfUrl, isPDF]); @@ -202,6 +324,7 @@ const PDFViewer = () => { const onDocumentLoadSuccess = useCallback( ({ numPages }: DocumentLoadSuccess) => { + console.log(`Document loaded with ${numPages} pages`); setNumPages(numPages); setError(null); setIsDocumentLoaded(true); @@ -212,6 +335,9 @@ const PDFViewer = () => { if (!guid) return
    No document specified.
    ; + const baseWidth = Math.max(0, (containerSize.width || 0) - 50); + const readyToRender = !!file && containerSize.width > 0; + return (
    { >
    -
    +
    + setUiScalePct(pct)} + onSelectPct={(pct) => { + setUiScalePct(pct); + startTransition(() => setScale(pct / 100)); + }} + onPageFit={() => { + const pct = 100; + setUiScalePct(pct); + startTransition(() => setScale(pct / 100)); + }} + /> +
    + +
    - - Page {pageNumber} of {numPages || "-"} - + +
    + {pageNumber} + of {numPages || "-"} +
    +
    -
    - - {Math.round(scale * 100)}% - + +
    + {isPending && ( +
    + Rendering… +
    + )}
    +
    { Retry
    - ) : ( - file && ( -
    + setError(err.message)} + options={pdfOptions} > - setError(err.message)} - options={pdfOptions} - > -
    - {Array.from({ length: numPages || 0 }, (_, index) => { - const pageNum = index + 1; - return ( -
    + {Array.from({ length: numPages || 0 }, (_, index) => { + const pageNum = index + 1; + return ( +
    { + pageRefs.current[pageNum] = el; + }} + className="mb-4 w-full" + data-page={pageNum} + > + { - if (el) pageRefs.current[pageNum] = el; - }} - className="mb-4 w-full" - > - -
    - Page {pageNum} of {numPages} -
    + pageNumber={pageNum} + width={baseWidth} + scale={deferredScale} + renderTextLayer={true} + renderAnnotationLayer={false} + className="shadow-lg" + /> +
    + Page {pageNum} of {numPages}
    - ); - })} -
    - -
    - ) +
    + ); + })} +
    + +
    + ) : ( +
    )}
    diff --git a/frontend/src/pages/DrugSummary/ZoomMenu.tsx b/frontend/src/pages/DrugSummary/ZoomMenu.tsx new file mode 100644 index 00000000..2e51b227 --- /dev/null +++ b/frontend/src/pages/DrugSummary/ZoomMenu.tsx @@ -0,0 +1,167 @@ +// ZoomMenu.tsx +import { useEffect, useMemo, useRef, useState } from "react"; + +type Props = { + valuePct: number; + onSelectPct: (pct: number) => void; + onDeferPct?: (pct: number) => void; + onPageFit?: () => void; // +}; + +const ZOOM_STEPS = [50, 75, 100, 125, 150, 200, 300, 400]; + +export default function ZoomMenu({ + valuePct, + onSelectPct, + onDeferPct, + onPageFit, +}: Props) { + const [open, setOpen] = useState(false); + const anchorRef = useRef(null); + const popRef = useRef(null); + + // nearest step helpers + const stepIndex = useMemo(() => { + let idx = 0, + bestDiff = Infinity; + ZOOM_STEPS.forEach((s, i) => { + const d = Math.abs(s - valuePct); + if (d < bestDiff) { + bestDiff = d; + idx = i; + } + }); + return idx; + }, [valuePct]); + + const dec = () => onSelectPct(ZOOM_STEPS[Math.max(0, stepIndex - 1)]); + const inc = () => + onSelectPct(ZOOM_STEPS[Math.min(ZOOM_STEPS.length - 1, stepIndex + 1)]); + + useEffect(() => { + if (!open) return; + const onDocClick = (e: MouseEvent) => { + if (!popRef.current || !anchorRef.current) return; + const t = e.target as Node; + if (!popRef.current.contains(t) && !anchorRef.current.contains(t)) + setOpen(false); + }; + const onEsc = (e: KeyboardEvent) => { + if (e.key === "Escape") setOpen(false); + }; + document.addEventListener("mousedown", onDocClick); + document.addEventListener("keydown", onEsc); + return () => { + document.removeEventListener("mousedown", onDocClick); + document.removeEventListener("keydown", onEsc); + }; + }, [open]); + + return ( +
    + {/* Trigger */} + + + {/* Popover */} + {open && ( +
    + {/* header: current % + -/+ */} +
    +
    + {valuePct}% +
    +
    + + +
    +
    + + {/* options */} +
      +
    • + +
    • + {ZOOM_STEPS.map((pct) => ( +
    • + +
    • + ))} +
    +
    + )} +
    + ); +} diff --git a/frontend/src/pages/Feedback/Feedback.tsx b/frontend/src/pages/Feedback/Feedback.tsx index d4e61e40..f181f1e9 100644 --- a/frontend/src/pages/Feedback/Feedback.tsx +++ b/frontend/src/pages/Feedback/Feedback.tsx @@ -6,11 +6,14 @@ function Feedback() {
    diff --git a/frontend/src/pages/Files/FileRow.tsx b/frontend/src/pages/Files/FileRow.tsx index e4325009..19665855 100644 --- a/frontend/src/pages/Files/FileRow.tsx +++ b/frontend/src/pages/Files/FileRow.tsx @@ -84,6 +84,18 @@ const FileRow: React.FC = ({ setIsEditing(false); }; + const formatUTCDate = (dateStr: string | null) => { + if (!dateStr) return "N/A"; + const formatter = new Intl.DateTimeFormat("en-US", { + timeZone: "UTC", + year: "numeric", + month: "numeric", + day: "numeric" + }); + const formattedDate = formatter.format(new Date(dateStr)); + return formattedDate; + } + return (
  • {isEditing ? ( @@ -187,7 +199,7 @@ const FileRow: React.FC = ({

    Publication Date:{" "} - {isEditing ? ( + {isEditing ? = ({ disabled={loading} placeholder="Publication Date" /> - ) : ( - file.publication_date - ? new Intl.DateTimeFormat("en-US", { - year: "numeric", - month: "2-digit", - day: "2-digit" - }).format(new Date(file.publication_date)) - : "N/A" - )} + : formatUTCDate(file.publication_date)}

  • diff --git a/frontend/src/pages/Files/ListOfFiles.tsx b/frontend/src/pages/Files/ListOfFiles.tsx index 1802e832..b53874bf 100644 --- a/frontend/src/pages/Files/ListOfFiles.tsx +++ b/frontend/src/pages/Files/ListOfFiles.tsx @@ -1,5 +1,5 @@ import React, { useState, useEffect } from "react"; -import axios from "axios"; +import { api } from "../../api/apiClient"; import Layout from "../Layout/Layout"; import FileRow from "./FileRow"; import Table from "../../components/Table/Table"; @@ -30,17 +30,17 @@ const ListOfFiles: React.FC<{ showTable?: boolean }> = ({ const [downloading, setDownloading] = useState(null); const [opening, setOpening] = useState(null); + const baseUrl = import.meta.env.VITE_API_BASE_URL; + useEffect(() => { const fetchFiles = async () => { try { - const baseUrl = import.meta.env.VITE_API_BASE_URL; - const response = await axios.get(`${baseUrl}/v1/api/uploadFile`, { - headers: { - Authorization: `JWT ${localStorage.getItem("access")}`, - }, - }); - if (Array.isArray(response.data)) { - setFiles(response.data); + const url = `${baseUrl}/v1/api/uploadFile`; + + const { data } = await api.get(url); + + if (Array.isArray(data)) { + setFiles(data); } } catch (error) { console.error("Error fetching files", error); @@ -50,7 +50,7 @@ const ListOfFiles: React.FC<{ showTable?: boolean }> = ({ }; fetchFiles(); - }, []); + }, [baseUrl]); const updateFileName = (guid: string, updatedFile: Partial) => { setFiles((prevFiles) => @@ -63,15 +63,9 @@ const ListOfFiles: React.FC<{ showTable?: boolean }> = ({ const handleDownload = async (guid: string, fileName: string) => { try { setDownloading(guid); - const baseUrl = import.meta.env.VITE_API_BASE_URL; - const response = await axios.get(`${baseUrl}/v1/api/uploadFile/${guid}`, { - headers: { - Authorization: `JWT ${localStorage.getItem("access")}`, - }, - responseType: "blob", - }); + const { data } = await api.get(`/v1/api/uploadFile/${guid}`, { responseType: 'blob' }); - const url = window.URL.createObjectURL(new Blob([response.data])); + const url = window.URL.createObjectURL(new Blob([data])); const link = document.createElement("a"); link.href = url; link.setAttribute("download", fileName); @@ -90,15 +84,9 @@ const ListOfFiles: React.FC<{ showTable?: boolean }> = ({ const handleOpen = async (guid: string) => { try { setOpening(guid); - const baseUrl = import.meta.env.VITE_API_BASE_URL; - const response = await axios.get(`${baseUrl}/v1/api/uploadFile/${guid}`, { - headers: { - Authorization: `JWT ${localStorage.getItem("access")}`, - }, - responseType: "arraybuffer", - }); + const { data } = await api.get(`/v1/api/uploadFile/${guid}`, { responseType: 'arraybuffer' }); - const file = new Blob([response.data], { type: 'application/pdf' }); + const file = new Blob([data], { type: 'application/pdf' }); const fileURL = window.URL.createObjectURL(file); window.open(fileURL); } catch (error) { @@ -118,17 +106,24 @@ const ListOfFiles: React.FC<{ showTable?: boolean }> = ({ { Header: 'Date Published', accessor: 'publication_date' }, { Header: '', accessor: 'file_open' }, ]; + + const formatUTCDate = (dateStr: string | null) => { + if (!dateStr) return "N/A"; + const formatter = new Intl.DateTimeFormat("en-US", { + timeZone: "UTC", + year: "numeric", + month: "numeric", + day: "numeric" + }); + const formattedDate = formatter.format(new Date(dateStr)); + return formattedDate; + } + const data = files.map((file) => ( { file_name: file?.title || file.file_name.replace(/\.[^/.]+$/, ""), publication: file?.publication || '', - publication_date: file.publication_date - ? new Intl.DateTimeFormat("en-US", { - year: "numeric", - month: "2-digit", - day: "2-digit" - }).format(new Date(file.publication_date)) - : "", + publication_date: formatUTCDate(file.publication_date), file_open:

    {data.paragraph[0]} + + balancerteam@codeforphilly.org + .

    diff --git a/frontend/src/pages/Help/Help.tsx b/frontend/src/pages/Help/Help.tsx index 53617032..b8ca3959 100644 --- a/frontend/src/pages/Help/Help.tsx +++ b/frontend/src/pages/Help/Help.tsx @@ -1,5 +1,6 @@ import { Link } from "react-router-dom"; import Layout from "../Layout/Layout"; +import Welcome from "../../components/Welcome/Welcome"; import HelpCard from "./HelpCard"; import { helpData } from "./helpData"; @@ -7,16 +8,10 @@ function Help() { return (
    -
    -

    - Help & Support -

    -
    -
    -

    - Get help and support for improving your Balancer experience. -

    -
    +
    {helpData.map((helpDataEntry, index) => { const card = ( diff --git a/frontend/src/pages/Layout/Layout_V2_Sidebar.tsx b/frontend/src/pages/Layout/Layout_V2_Sidebar.tsx index a54a5c2a..19163290 100644 --- a/frontend/src/pages/Layout/Layout_V2_Sidebar.tsx +++ b/frontend/src/pages/Layout/Layout_V2_Sidebar.tsx @@ -1,117 +1,131 @@ -import React, { useState, useEffect } from "react"; -import { Link, useNavigate, useLocation } from "react-router-dom"; -import { ChevronLeft, ChevronRight, File, Loader2 } from "lucide-react"; +import React, {useState, useEffect} from "react"; +import {Link, useNavigate, useLocation} from "react-router-dom"; +import {ChevronLeft, ChevronRight, File, Loader2} from "lucide-react"; import axios from "axios"; interface File { - id: number; - guid: string; - file_name: string; - title: string | null; + id: number; + guid: string; + file_name: string; + title: string | null; } const Sidebar: React.FC = () => { - const [sidebarCollapsed, setSidebarCollapsed] = useState(false); - const [files, setFiles] = useState([]); - const [isLoading, setIsLoading] = useState(true); - const navigate = useNavigate(); - const location = useLocation(); + const [sidebarCollapsed, setSidebarCollapsed] = useState(false); + const [files, setFiles] = useState([]); + const [isLoading, setIsLoading] = useState(true); + const navigate = useNavigate(); + const location = useLocation(); - const toggleSidebar = () => { - setSidebarCollapsed(!sidebarCollapsed); - }; + const toggleSidebar = () => { + setSidebarCollapsed(!sidebarCollapsed); + }; + + useEffect(() => { + const fetchFiles = async () => { + try { + const baseUrl = import.meta.env.VITE_API_BASE_URL; + const response = await axios.get(`${baseUrl}/v1/api/uploadFile`, { + headers: { + Authorization: `JWT ${localStorage.getItem("access")}`, + }, + }); + if (Array.isArray(response.data)) { + setFiles(response.data); + } + } catch (error) { + console.error("Error fetching files", error); + } finally { + setIsLoading(false); + } + }; + + fetchFiles(); + }, []); - useEffect(() => { - const fetchFiles = async () => { - try { - const baseUrl = import.meta.env.VITE_API_BASE_URL; - const response = await axios.get(`${baseUrl}/v1/api/uploadFile`, { - headers: { - Authorization: `JWT ${localStorage.getItem("access")}`, - }, - }); - if (Array.isArray(response.data)) { - setFiles(response.data); + const handleFileClick = (guid: string) => { + const params = new URLSearchParams(location.search); + const currentGuid = params.get("guid"); + + if (guid !== currentGuid) { + navigate(`/drugsummary?guid=${guid}&page=1`); + } else { + navigate( + `/drugsummary?guid=${guid}${params.has("page") ? `&page=${params.get("page")}` : ""}` + ); } - } catch (error) { - console.error("Error fetching files", error); - } finally { - setIsLoading(false); - } }; - fetchFiles(); - }, []); - - const handleFileClick = (guid: string) => { - const params = new URLSearchParams(location.search); - const currentGuid = params.get("guid"); + useEffect(() => { + const handleResize = () => { + if (window.innerWidth < 640) { + setSidebarCollapsed(true) + } else { + setSidebarCollapsed(false) + } + } - if (guid !== currentGuid) { - navigate(`/drugsummary?guid=${guid}&page=1`); - } else { - navigate( - `/drugsummary?guid=${guid}${params.has("page") ? `&page=${params.get("page")}` : ""}` - ); - } - }; + handleResize() + window.addEventListener('resize', handleResize) + return () => window.removeEventListener('resize', handleResize) + }, []) - return ( -
    -
    - {!sidebarCollapsed && ( - -

    - Balancer -

    - - )} - -
    - - {/* File List Section */} -
    - {isLoading ? ( -
    - -
    - ) : ( -
      - {files.map((file) => ( -
    • +
      + {!sidebarCollapsed && ( + +

      + Balancer +

      + + )} +
      + + {/* File List Section */} +
      + {isLoading ? ( +
      + +
      + ) : ( +
        + {files.map((file) => ( +
      • + -
      • - ))} -
      - )} -
      -
    - ); + )} + + + ))} + + )} +
    +
    + ); }; export default Sidebar; diff --git a/frontend/src/pages/PatientManager/NewPatientForm.tsx b/frontend/src/pages/PatientManager/NewPatientForm.tsx index 22c64a43..232ed296 100644 --- a/frontend/src/pages/PatientManager/NewPatientForm.tsx +++ b/frontend/src/pages/PatientManager/NewPatientForm.tsx @@ -1,752 +1,756 @@ -import {FormEvent, ChangeEvent, useEffect, useState} from "react"; -import {v4 as uuidv4} from "uuid"; -import {PatientInfo, Diagnosis} from "./PatientTypes"; -import {useMedications} from "../ListMeds/useMedications"; +import { FormEvent, ChangeEvent, useEffect, useState } from "react"; +import { v4 as uuidv4 } from "uuid"; +import { PatientInfo, Diagnosis } from "./PatientTypes"; +import { useMedications } from "../ListMeds/useMedications"; import ChipsInput from "../../components/ChipsInput/ChipsInput"; import Tooltip from "../../components/Tooltip"; -import {api} from "../../api/apiClient"; -import {useGlobalContext} from "../../contexts/GlobalContext.tsx"; +import { api } from "../../api/apiClient"; +import { useGlobalContext } from "../../contexts/GlobalContext.tsx"; // import ErrorMessage from "../../components/ErrorMessage"; // create new interface for refactor and to work with backend interface PatientInfoInterface { - id?: string; - state?: string; - otherDiagnosis?: string; - description?: string; - depression?: boolean; - hypomania?: boolean; - mania?: boolean; - currentMedications?: string; - priorMedications?: string; - possibleMedications?: { - first?: string; - second?: string; - third?: string; - }; - psychotic: boolean; - suicideHistory: boolean; - kidneyHistory: boolean; - liverHistory: boolean; - bloodPressureHistory: boolean; - weightGainConcern: boolean; - reproductive: boolean; - riskPregnancy: boolean; + id?: string; + state?: string; + otherDiagnosis?: string; + description?: string; + depression?: boolean; + hypomania?: boolean; + mania?: boolean; + currentMedications?: string; + priorMedications?: string; + possibleMedications?: { + first?: string; + second?: string; + third?: string; + }; + psychotic: boolean; + suicideHistory: boolean; + kidneyHistory: boolean; + liverHistory: boolean; + bloodPressureHistory: boolean; + weightGainConcern: boolean; + reproductive: boolean; + riskPregnancy: boolean; + anyPregnancy: boolean; } // TODO: refactor with Formik export interface NewPatientFormProps { - enterNewPatient: boolean; - setEnterNewPatient: React.Dispatch>; - isEditing: boolean; - setIsEditing: React.Dispatch>; - patientInfo: PatientInfo; - setPatientInfo: React.Dispatch>; - allPatientInfo: PatientInfo[]; - setAllPatientInfo: React.Dispatch>; + enterNewPatient: boolean; + setEnterNewPatient: React.Dispatch>; + isEditing: boolean; + setIsEditing: React.Dispatch>; + patientInfo: PatientInfo; + setPatientInfo: React.Dispatch>; + allPatientInfo: PatientInfo[]; + setAllPatientInfo: React.Dispatch>; } const NewPatientForm = ({ - isEditing, - setIsEditing, - setPatientInfo, - allPatientInfo, - setAllPatientInfo, - patientInfo, - enterNewPatient, - setEnterNewPatient, - }: NewPatientFormProps) => { - - - const [isPressed, setIsPressed] = useState(false); - const [isLoading, setIsLoading] = useState(false); - - const {resetFormValues, setShowSummary} = useGlobalContext(); - - useEffect(() => { - if (resetFormValues) { - handleClickNewPatient(); - } - }, [resetFormValues]); - - // const [errors, setErrors] = useState([]); - - const nullPatient = { - ID: "", - Diagnosis: Diagnosis.Manic, - OtherDiagnosis: "", - Description: "", - CurrentMedications: "", - PriorMedications: "", - Mania: "False", - Depression: "False", - Hypomania: "False", - Psychotic: "No", - Suicide: "No", - Kidney: "No", - Liver: "No", - weight_gain: "No", - blood_pressure: "No", - Reproductive: "No", - risk_pregnancy: "No", - any_pregnancy: "No", - }; - - const defaultPatientInfo: PatientInfo = isEditing - ? {...patientInfo} - : nullPatient; - - const [newPatientInfo, setNewPatientInfo] = - useState(defaultPatientInfo); - - const handleMouseDown = () => { - setIsPressed(true); - }; - - const handleMouseUp = () => { - setIsPressed(false); - }; - - useEffect(() => { - const patientInfoFromLocalStorage = JSON.parse( - // eslint-disable-next-line @typescript-eslint/ban-ts-comment - // @ts-ignore - localStorage.getItem("patientInfos") - ); - - if (patientInfoFromLocalStorage) { - setAllPatientInfo(patientInfoFromLocalStorage); - } - }, []); - - const handleSubmit = async (e: FormEvent) => { - e.preventDefault(); - - // send payload to backend using the new interface - const payload: PatientInfoInterface = { - id: newPatientInfo.ID, - state: newPatientInfo.Diagnosis, - otherDiagnosis: newPatientInfo.OtherDiagnosis, - description: newPatientInfo.Description, - depression: newPatientInfo.Depression == "True", - hypomania: newPatientInfo.Hypomania == "True", - mania: newPatientInfo.Hypomania == "True", - currentMedications: newPatientInfo.CurrentMedications, - priorMedications: newPatientInfo.PriorMedications, - psychotic: newPatientInfo.Psychotic == "Yes", - suicideHistory: newPatientInfo.Suicide == "Yes", - kidneyHistory: newPatientInfo.Kidney == "Yes", - liverHistory: newPatientInfo.Liver == "Yes", - bloodPressureHistory: newPatientInfo.blood_pressure == "Yes", - weightGainConcern: newPatientInfo.weight_gain == "Yes", - reproductive: newPatientInfo.Reproductive == "Yes", - riskPregnancy: newPatientInfo.risk_pregnancy == "Yes", - }; - - setIsLoading(true); // Start loading - - try { - const baseUrl = import.meta.env.VITE_API_BASE_URL; - const url = `${baseUrl}/v1/api/get_med_recommend`; - - const {data} = await api.post(url, payload); - - const categorizedMedications = { - first: data.first ?? [], - second: data.second ?? [], - third: data.third ?? [], - }; - - let patientID = newPatientInfo.ID; - - // Ensure ID is never blank - if (!patientID) { - const generatedGuid = uuidv4(); - patientID = generatedGuid.substring(0, 5); - } - - const updatedPatientInfo = { - ...newPatientInfo, - ID: patientID, // Assign or preserve ID - PossibleMedications: categorizedMedications, - }; - - let updatedAllPatientInfo = [...allPatientInfo]; - - // Check if patient exists and update, otherwise add as new - const existingPatientIndex = updatedAllPatientInfo.findIndex( - (patient) => patient.ID === patientID - ); - - if (existingPatientIndex !== -1) { - updatedAllPatientInfo[existingPatientIndex] = updatedPatientInfo; - } else { - updatedAllPatientInfo = [updatedPatientInfo, ...allPatientInfo]; - } - - // Update state and localStorage - setPatientInfo(updatedPatientInfo); - setAllPatientInfo(updatedAllPatientInfo); - setShowSummary(true) - localStorage.setItem( - "patientInfos", - JSON.stringify(updatedAllPatientInfo) - ); - } catch (error) { - console.error("Error occurred:", error); - } finally { - setIsEditing(false); - setEnterNewPatient(false); - setIsLoading(false); - handleClickNewPatient(); - window.scrollTo({top: 0}); - } - }; - - const handleDiagnosisChange = (e: ChangeEvent) => { - const selectedValue = e.target.value as keyof typeof Diagnosis; - - setNewPatientInfo({ - ...newPatientInfo, - Diagnosis: Diagnosis[selectedValue], - OtherDiagnosis: "", // Reset the OtherDiagnosis value - }); - }; - - const handleClickSummary = () => { - setNewPatientInfo((prevPatientInfo) => ({ - ...prevPatientInfo, - ID: "", - Diagnosis: Diagnosis.Manic, - OtherDiagnosis: "", - Description: "", - CurrentMedications: "", - PriorMedications: "", - Mania: "False", - Depression: "False", - Hypomania: "False", - Psychotic: "No", - Suicide: "No", - Kidney: "No", - Liver: "No", - weight_gain: "No", - blood_pressure: "No", - Reproductive: "No", - risk_pregnancy: "No", - })); - - setEnterNewPatient(!enterNewPatient); - setIsEditing(false); - }; - - const handleClickNewPatient = () => { - setNewPatientInfo((prevPatientInfo) => ({ - ...prevPatientInfo, - ID: "", - Diagnosis: Diagnosis.Manic, - OtherDiagnosis: "", - Description: "", - CurrentMedications: "", - PriorMedications: "", - Mania: "False", - Depression: "False", - Hypomania: "False", - Psychotic: "No", - Suicide: "No", - Kidney: "No", - Liver: "No", - weight_gain: "No", - blood_pressure: "No", - Reproductive: "No", - risk_pregnancy: "No", - })); - }; + isEditing, + setIsEditing, + setPatientInfo, + allPatientInfo, + setAllPatientInfo, + patientInfo, + enterNewPatient, + setEnterNewPatient, +}: NewPatientFormProps) => { + const [isPressed, setIsPressed] = useState(false); + const [isLoading, setIsLoading] = useState(false); + + const { resetFormValues, setShowSummary } = useGlobalContext(); + + useEffect(() => { + if (resetFormValues) { + handleClickNewPatient(); + } + }, [resetFormValues]); + + // const [errors, setErrors] = useState([]); + + const nullPatient = { + ID: "", + Diagnosis: Diagnosis.Manic, + OtherDiagnosis: "", + Description: "", + CurrentMedications: "", + PriorMedications: "", + Mania: "False", + Depression: "False", + Hypomania: "False", + Psychotic: "No", + Suicide: "No", + Kidney: "No", + Liver: "No", + weight_gain: "No", + blood_pressure: "No", + Reproductive: "No", + risk_pregnancy: "No", + any_pregnancy: "No", + PossibleMedications: { + first: [], + second: [], + third: [], + }, + }; + + const defaultPatientInfo: PatientInfo = isEditing + ? { ...patientInfo } + : nullPatient; + + const [newPatientInfo, setNewPatientInfo] = + useState(defaultPatientInfo); + + const handleMouseDown = () => { + setIsPressed(true); + }; + + const handleMouseUp = () => { + setIsPressed(false); + }; + + useEffect(() => { + const patientInfoFromLocalStorage = JSON.parse( + // eslint-disable-next-line @typescript-eslint/ban-ts-comment + // @ts-ignore + localStorage.getItem("patientInfos") + ); - // const handleCheckboxChange = ( - // e: React.ChangeEvent, - // checkboxName: string - // ) => { - // const isChecked = e.target.checked; - // setNewPatientInfo((prevInfo) => ({ - // ...prevInfo, - // [checkboxName]: isChecked ? "True" : "False", // Update for both checked and unchecked - // })); - // }; - - const handleRadioChange = ( - e: React.ChangeEvent, - radioName: string - ) => { - const selectedValue = e.target.value; - setNewPatientInfo((prevInfo) => ({ - ...prevInfo, - [radioName]: selectedValue, - })); + if (patientInfoFromLocalStorage) { + setAllPatientInfo(patientInfoFromLocalStorage); + } + }, []); + + const handleSubmit = async (e: FormEvent) => { + e.preventDefault(); + + // send payload to backend using the new interface + const payload: PatientInfoInterface = { + id: newPatientInfo.ID, + state: newPatientInfo.Diagnosis, + otherDiagnosis: newPatientInfo.OtherDiagnosis, + description: newPatientInfo.Description, + depression: newPatientInfo.Depression == "True", + hypomania: newPatientInfo.Hypomania == "True", + mania: newPatientInfo.Hypomania == "True", + currentMedications: newPatientInfo.CurrentMedications, + priorMedications: newPatientInfo.PriorMedications, + psychotic: newPatientInfo.Psychotic == "Yes", + suicideHistory: newPatientInfo.Suicide == "Yes", + kidneyHistory: newPatientInfo.Kidney == "Yes", + liverHistory: newPatientInfo.Liver == "Yes", + bloodPressureHistory: newPatientInfo.blood_pressure == "Yes", + weightGainConcern: newPatientInfo.weight_gain == "Yes", + reproductive: newPatientInfo.Reproductive == "Yes", + riskPregnancy: newPatientInfo.risk_pregnancy == "Yes", + anyPregnancy: newPatientInfo.any_pregnancy == "Yes", }; - useEffect(() => { - if (isEditing) { - setNewPatientInfo(patientInfo); - } - }, [isEditing, patientInfo]); - - const {medications} = useMedications(); - - return ( -
    - {/* {search} */} -
    -
    - {!enterNewPatient && ( -
    -
    -

    - Click To Enter New Patient -

    - -
    - - - -
    -
    + setIsLoading(true); // Start loading + + try { + const baseUrl = import.meta.env.VITE_API_BASE_URL; + const url = `${baseUrl}/v1/api/get_med_recommend`; + + const { data } = await api.post(url, payload); + + const categorizedMedications = { + first: data.first ?? [], + second: data.second ?? [], + third: data.third ?? [], + }; + + let patientID = newPatientInfo.ID; + + // Ensure ID is never blank + if (!patientID) { + const generatedGuid = uuidv4(); + patientID = generatedGuid.substring(0, 5); + } + + const updatedPatientInfo = { + ...newPatientInfo, + ID: patientID, // Assign or preserve ID + PossibleMedications: categorizedMedications, + }; + + let updatedAllPatientInfo = [...allPatientInfo]; + + // Check if patient exists and update, otherwise add as new + const existingPatientIndex = updatedAllPatientInfo.findIndex( + (patient) => patient.ID === patientID + ); + + if (existingPatientIndex !== -1) { + updatedAllPatientInfo[existingPatientIndex] = updatedPatientInfo; + } else { + updatedAllPatientInfo = [updatedPatientInfo, ...allPatientInfo]; + } + + // Update state and localStorage + setPatientInfo(updatedPatientInfo); + setAllPatientInfo(updatedAllPatientInfo); + setShowSummary(true); + localStorage.setItem( + "patientInfos", + JSON.stringify(updatedAllPatientInfo) + ); + } catch (error) { + console.error("Error occurred:", error); + } finally { + setIsEditing(false); + setEnterNewPatient(false); + setIsLoading(false); + handleClickNewPatient(); + window.scrollTo({ top: 0 }); + } + }; + + const handleDiagnosisChange = (e: ChangeEvent) => { + const selectedValue = e.target.value as keyof typeof Diagnosis; + + setNewPatientInfo({ + ...newPatientInfo, + Diagnosis: Diagnosis[selectedValue], + OtherDiagnosis: "", // Reset the OtherDiagnosis value + }); + }; + + const handleClickSummary = () => { + setNewPatientInfo((prevPatientInfo) => ({ + ...prevPatientInfo, + ID: "", + Diagnosis: Diagnosis.Manic, + OtherDiagnosis: "", + Description: "", + CurrentMedications: "", + PriorMedications: "", + Mania: "False", + Depression: "False", + Hypomania: "False", + Psychotic: "No", + Suicide: "No", + Kidney: "No", + Liver: "No", + weight_gain: "No", + blood_pressure: "No", + Reproductive: "No", + risk_pregnancy: "No", + })); + + setEnterNewPatient(!enterNewPatient); + setIsEditing(false); + }; + + const handleClickNewPatient = () => { + setNewPatientInfo((prevPatientInfo) => ({ + ...prevPatientInfo, + ID: "", + Diagnosis: Diagnosis.Manic, + OtherDiagnosis: "", + Description: "", + CurrentMedications: "", + PriorMedications: "", + Mania: "False", + Depression: "False", + Hypomania: "False", + Psychotic: "No", + Suicide: "No", + Kidney: "No", + Liver: "No", + weight_gain: "No", + blood_pressure: "No", + Reproductive: "No", + risk_pregnancy: "No", + any_pregnancy: "No", + })); + }; + + // const handleCheckboxChange = ( + // e: React.ChangeEvent, + // checkboxName: string + // ) => { + // const isChecked = e.target.checked; + // setNewPatientInfo((prevInfo) => ({ + // ...prevInfo, + // [checkboxName]: isChecked ? "True" : "False", // Update for both checked and unchecked + // })); + // }; + + const handleRadioChange = ( + e: React.ChangeEvent, + radioName: string + ) => { + const selectedValue = e.target.value; + setNewPatientInfo((prevInfo) => ({ + ...prevInfo, + [radioName]: selectedValue, + })); + }; + + useEffect(() => { + if (isEditing) { + setNewPatientInfo(patientInfo); + } + }, [isEditing, patientInfo]); + + const { medications } = useMedications(); + + return ( +
    + {/* {search} */} +
    +
    + {!enterNewPatient && ( +
    +
    +

    + Click To Enter New Patient +

    + +
    + + + +
    +
    +
    + )} + {enterNewPatient && ( +
    +
    +
    +

    + {isEditing + ? `Edit Patient ${patientInfo.ID} Details` + : "Enter Patient Details"} + {/* Details */} +

    +
    + +
    + +
    +
    + +
    +
    + +
    +
    +
    +

    + Select patient characteristics +

    + {/*
    +
    + Currently psychotic +
    + +
    +
    + handleRadioChange(e, "Psychotic")} + className="w-4 h-4 text-indigo-600 border-gray-300 focus:ring-indigo-600" + /> + + + handleRadioChange(e, "Psychotic")} + className="w-4 h-4 text-indigo-600 border-gray-300 focus:ring-indigo-600" + /> +
    - )} - {enterNewPatient && ( -
    -
    -
    -

    - {isEditing - ? `Edit Patient ${patientInfo.ID} Details` - : "Enter Patient Details"} - {/* Details */} -

    -
    - -
    - -
    -
    - -
    -
    - -
    - -
    -
    -

    - Select patient characteristics -

    -
    -
    - Currently psychotic -
    - -
    -
    - handleRadioChange(e, "Psychotic")} - className="w-4 h-4 text-indigo-600 border-gray-300 focus:ring-indigo-600" - /> - - - handleRadioChange(e, "Psychotic")} - className="w-4 h-4 text-indigo-600 border-gray-300 focus:ring-indigo-600" - /> - -
    -
    -
    -
    -
    - History of suicide attempt(s) - +
    +
    */} +
    +
    + History of suicide attempt(s) + info - -
    - -
    -
    - handleRadioChange(e, "Suicide")} - className="w-4 h-4 text-indigo-600 border-gray-300 focus:ring-indigo-600" - /> - - - handleRadioChange(e, "Suicide")} - className="w-4 h-4 text-indigo-600 border-gray-300 focus:ring-indigo-600" - /> - -
    -
    -
    -
    -
    - History or risk of kidney disease - + +
    + +
    +
    + handleRadioChange(e, "Suicide")} + className="w-4 h-4 text-indigo-600 border-gray-300 focus:ring-indigo-600" + /> + + + handleRadioChange(e, "Suicide")} + className="w-4 h-4 text-indigo-600 border-gray-300 focus:ring-indigo-600" + /> + +
    +
    +
    +
    +
    + History or risk of kidney disease + info - -
    -
    -
    - handleRadioChange(e, "Kidney")} - className="w-4 h-4 text-indigo-600 border-gray-300 focus:ring-indigo-600" - /> - - - handleRadioChange(e, "Kidney")} - className="w-4 h-4 text-indigo-600 border-gray-300 focus:ring-indigo-600" - /> - -
    -
    -
    -
    -
    - History or risk of liver disease - + +
    +
    +
    + handleRadioChange(e, "Kidney")} + className="w-4 h-4 text-indigo-600 border-gray-300 focus:ring-indigo-600" + /> + + + handleRadioChange(e, "Kidney")} + className="w-4 h-4 text-indigo-600 border-gray-300 focus:ring-indigo-600" + /> + +
    +
    +
    +
    +
    + History or risk of liver disease + info - -
    -
    -
    - handleRadioChange(e, "Liver")} - className="w-4 h-4 text-indigo-600 border-gray-300 focus:ring-indigo-600" - /> - - - handleRadioChange(e, "Liver")} - className="w-4 h-4 text-indigo-600 border-gray-300 focus:ring-indigo-600" - /> - -
    -
    -
    - -
    -
    - - History or risk of low blood pressure, or concern for - falls - + +
    +
    +
    + handleRadioChange(e, "Liver")} + className="w-4 h-4 text-indigo-600 border-gray-300 focus:ring-indigo-600" + /> + + + handleRadioChange(e, "Liver")} + className="w-4 h-4 text-indigo-600 border-gray-300 focus:ring-indigo-600" + /> + +
    +
    +
    + +
    +
    + + History or risk of low blood pressure, or concern for + falls + info - -
    - -
    -
    - handleRadioChange(e, "blood_pressure")} - className="w-4 h-4 text-indigo-600 border-gray-300 focus:ring-indigo-600" - /> - - - handleRadioChange(e, "blood_pressure")} - className="w-4 h-4 text-indigo-600 border-gray-300 focus:ring-indigo-600" - /> - -
    -
    -
    -
    -
    - Has weight gain concerns - + +
    + +
    +
    + handleRadioChange(e, "blood_pressure")} + className="w-4 h-4 text-indigo-600 border-gray-300 focus:ring-indigo-600" + /> + + + handleRadioChange(e, "blood_pressure")} + className="w-4 h-4 text-indigo-600 border-gray-300 focus:ring-indigo-600" + /> + +
    +
    +
    +
    +
    + Has weight gain concerns + info - -
    - -
    -
    - handleRadioChange(e, "weight_gain")} - className="w-4 h-4 text-indigo-600 border-gray-300 focus:ring-indigo-600" - /> - - - handleRadioChange(e, "weight_gain")} - className="w-4 h-4 text-indigo-600 border-gray-300 focus:ring-indigo-600" - /> - -
    -
    -
    -
    -
    - Wants to conceive in next 2 years - + +
    + +
    +
    + handleRadioChange(e, "weight_gain")} + className="w-4 h-4 text-indigo-600 border-gray-300 focus:ring-indigo-600" + /> + + + handleRadioChange(e, "weight_gain")} + className="w-4 h-4 text-indigo-600 border-gray-300 focus:ring-indigo-600" + /> + +
    +
    +
    +
    +
    + Wants to conceive in next 2 years + info - -
    - -
    -
    - handleRadioChange(e, "risk_pregnancy")} - className="w-4 h-4 text-indigo-600 border-gray-300 focus:ring-indigo-600" - /> - - - handleRadioChange(e, "risk_pregnancy")} - className="w-4 h-4 text-indigo-600 border-gray-300 focus:ring-indigo-600" - /> - -
    -
    -
    -
    -
    - Any possibility of becoming pregnant - + +
    + +
    +
    + handleRadioChange(e, "risk_pregnancy")} + className="w-4 h-4 text-indigo-600 border-gray-300 focus:ring-indigo-600" + /> + + + handleRadioChange(e, "risk_pregnancy")} + className="w-4 h-4 text-indigo-600 border-gray-300 focus:ring-indigo-600" + /> + +
    +
    +
    +
    +
    + Any possibility of becoming pregnant + info - -
    -
    -
    - handleRadioChange(e, "any_pregnancy")} - className="w-4 h-4 text-indigo-600 border-gray-300 focus:ring-indigo-600" - /> - - handleRadioChange(e, "any_pregnancy")} - className="w-4 h-4 text-indigo-600 border-gray-300 focus:ring-indigo-600" - /> - -
    -
    -
    -
    - - {/*
    + + +
    +
    + handleRadioChange(e, "any_pregnancy")} + className="w-4 h-4 text-indigo-600 border-gray-300 focus:ring-indigo-600" + /> + + handleRadioChange(e, "any_pregnancy")} + className="w-4 h-4 text-indigo-600 border-gray-300 focus:ring-indigo-600" + /> + +
    +
    + +
    + + {/*
    */} -
    -
    - +
    +
    + med.name)} + value={ + (newPatientInfo.PriorMedications && + newPatientInfo.PriorMedications?.split(",")) || + [] + } + placeholder="Start typing..." + label="" + onChange={(chips) => + setNewPatientInfo({ + ...newPatientInfo, + PriorMedications: chips.join(","), + }) + } + /> +
    +
    + +
    +
    + +
    +
    -
    - ); + ) : ( +

    + {isEditing ? "Edit Form" : "Submit"} +

    + )} + +
    + +
    + )} +
    +
    + + ); }; export default NewPatientForm; diff --git a/frontend/src/pages/PatientManager/PatientHistory.tsx b/frontend/src/pages/PatientManager/PatientHistory.tsx index 8dbf5e64..f8dc14a6 100644 --- a/frontend/src/pages/PatientManager/PatientHistory.tsx +++ b/frontend/src/pages/PatientManager/PatientHistory.tsx @@ -86,7 +86,7 @@ const PatientHistory = ({
    - Current State: + Current or Most recent state
    {item.Diagnosis} diff --git a/frontend/src/pages/PatientManager/PatientManager.tsx b/frontend/src/pages/PatientManager/PatientManager.tsx index a0b6e46a..f49dfa48 100644 --- a/frontend/src/pages/PatientManager/PatientManager.tsx +++ b/frontend/src/pages/PatientManager/PatientManager.tsx @@ -1,19 +1,45 @@ -import {useState} from "react"; -import {Link} from "react-router-dom"; +import { useState } from "react"; +import { Link } from "react-router-dom"; import NewPatientForm from "./NewPatientForm.tsx"; import PatientHistory from "./PatientHistory.tsx"; // eslint-disable-next-line @typescript-eslint/ban-ts-comment // @ts-ignore import PatientSummary from "./PatientSummary.tsx"; -import {Diagnosis, PatientInfo} from "./PatientTypes.ts"; -import {copy} from "../../assets/index.js"; +import { Diagnosis, PatientInfo } from "./PatientTypes.ts"; +import { copy } from "../../assets/index.js"; import Welcome from "../../components/Welcome/Welcome.tsx"; -import {useGlobalContext} from "../../contexts/GlobalContext.tsx"; +import { useGlobalContext } from "../../contexts/GlobalContext.tsx"; const PatientManager = () => { + const [patientInfo, setPatientInfo] = useState({ + ID: "", + Diagnosis: Diagnosis.Manic, + OtherDiagnosis: "", + Description: "", + CurrentMedications: "", + PriorMedications: "", + Depression: "", + Hypomania: "", + Mania: "", + Psychotic: "", + Suicide: "", + Kidney: "", + Liver: "", + blood_pressure: "", + weight_gain: "", + Reproductive: "", + risk_pregnancy: "", + any_pregnancy: "", + PossibleMedications: { + first: [], + second: [], + third: [], + }, + }); - - const [patientInfo, setPatientInfo] = useState({ + const handlePatientDeleted = (deletedId: string) => { + if (patientInfo.ID === deletedId) { + setPatientInfo({ ID: "", Diagnosis: Diagnosis.Manic, OtherDiagnosis: "", @@ -31,106 +57,86 @@ const PatientManager = () => { weight_gain: "", Reproductive: "", risk_pregnancy: "", + any_pregnancy: "", PossibleMedications: { - first: "", - second: "", - third: "", + first: [], + second: [], + third: [], }, - any_pregnancy: "" - }); - - const handlePatientDeleted = (deletedId: string) => { - if (patientInfo.ID === deletedId) { - setPatientInfo({ - ID: "", - Diagnosis: Diagnosis.Manic, - OtherDiagnosis: "", - Description: "", - CurrentMedications: "", - PriorMedications: "", - Depression: "", - Hypomania: "", - Mania: "", - Psychotic: "", - Suicide: "", - Kidney: "", - Liver: "", - blood_pressure: "", - weight_gain: "", - Reproductive: "", - risk_pregnancy: "", - any_pregnancy: "" - }); + }); - setIsPatientDeleted(true); - } - }; + setIsPatientDeleted(true); + } + }; - const [allPatientInfo, setAllPatientInfo] = useState([]); - const [isPatientDeleted, setIsPatientDeleted] = useState(false); - const { - showSummary, - setShowSummary, - enterNewPatient, - setEnterNewPatient, - isEditing, - setIsEditing - } = useGlobalContext(); + const [allPatientInfo, setAllPatientInfo] = useState([]); + const [isPatientDeleted, setIsPatientDeleted] = useState(false); + const { + showSummary, + setShowSummary, + enterNewPatient, + setEnterNewPatient, + isEditing, + setIsEditing, + } = useGlobalContext(); - // eslint-disable-next-line @typescript-eslint/ban-ts-comment - // @ts-ignore + // eslint-disable-next-line @typescript-eslint/ban-ts-comment + // @ts-ignore - // TODO: add error and loading state guards + // TODO: add error and loading state guards - const descriptionEl = ( -
    - Use our tool to get medication suggestions for bipolar disorder based on - patient characteristics.{" "} - - Read about where we get our data. - -
    - ); + const descriptionEl = ( +
    +

    Use our tool to explore medication options for bipolar disorder based on patient characteristics.

    +

    + + Read about where we get our data. + +

    +

    Balancer is an educational resource designed to support —never replace— the judgment of licensed U.S. clinicians.

    +

    Final prescribing decisions must always be made by the treating clinician.

    +
    + ); - return ( -
    - -
    - - - -
    -
    - ); + return ( +
    + +
    + + + +
    +
    + ); }; export default PatientManager; diff --git a/frontend/src/pages/PatientManager/PatientSummary.tsx b/frontend/src/pages/PatientManager/PatientSummary.tsx index 4e9c44b0..16966360 100644 --- a/frontend/src/pages/PatientManager/PatientSummary.tsx +++ b/frontend/src/pages/PatientManager/PatientSummary.tsx @@ -1,522 +1,757 @@ -import React, {useState, useEffect, useRef} from "react"; -import axios from "axios"; -import {PatientInfo} from "./PatientTypes"; +import React, { useState, useEffect, useRef } from "react"; +import { PatientInfo } from "./PatientTypes"; import Tooltip from "../../components/Tooltip"; import TypingAnimation from "../../components/Header/components/TypingAnimation.tsx"; -import {FaPencilAlt, FaPrint, FaMinus, FaRegThumbsDown} from "react-icons/fa"; +import { FaPencilAlt, FaPrint, FaMinus, FaRegThumbsDown } from "react-icons/fa"; import FeedbackForm from "../Feedback/FeedbackForm"; import Modal from "../../components/Modal/Modal"; -import {EllipsisVertical} from "lucide-react"; - +import { EllipsisVertical } from "lucide-react"; +import { fetchRiskDataWithSources } from "../../api/apiClient.ts"; interface PatientSummaryProps { - showSummary: boolean; - setShowSummary: (state: boolean) => void; - setEnterNewPatient: (isEnteringNewPatient: boolean) => void; - setIsEditing: (isEditing: boolean) => void; - patientInfo: PatientInfo; - isPatientDeleted: boolean; - setPatientInfo: React.Dispatch>; + showSummary: boolean; + setShowSummary: (state: boolean) => void; + setEnterNewPatient: (isEnteringNewPatient: boolean) => void; + setIsEditing: (isEditing: boolean) => void; + patientInfo: PatientInfo; + isPatientDeleted: boolean; + setPatientInfo: React.Dispatch>; } +type SourceItem = { + title: string | null; + publication: string | null; + text: string; + rule_type?: "INCLUDE" | "EXCLUDE" | "include" | "exclude"; + history_type?: string; + guid?: string | null; + page?: number | null; + link_url?: string | null; +}; type RiskData = { - benefits: string[]; - risks: string[]; + benefits: string[]; + risks: string[]; + source?: string; + sources?: SourceItem[]; +}; + +type MedicationWithSource = { + name: string; + source: "include" | "diagnosis"; }; +const truncate = (s = "", n = 220) => + s.length > n ? s.slice(0, n).trim() + "…" : s; + const MedicationItem = ({ - medication, - isClicked, - riskData, - loading, - onClick, - }: { - medication: string; - isClicked: boolean; - riskData: RiskData | null; - loading: boolean; - onClick: () => void; -}) => { - if (medication === "None") { - return ( -
  • -
    -
    - {medication} -
    -
    -
  • - ); - } + medication, + isClicked, + riskData, + loading, + onSourcesClick, + onBenefitsRisksClick, + activePanel, +}: { + medication: string; + source: string; + isClicked: boolean; + riskData: RiskData | null; + loading: boolean; + onSourcesClick: () => void; + onBenefitsRisksClick: () => void; + activePanel: "sources" | "benefits-risks" | null; +}) => { + if (medication === "None") { return ( -
    -
  • -
    -
    - {medication} - {loading && isClicked && ( -
    - -
    - )} -
    -
    -
    - +
  • +
    +
    + {medication} +
    +
    +
  • + ); + } + + return ( +
    +
  • +
    +
    + {medication} + {loading && isClicked && ( +
    + +
    + )} +
    +
    +
    + + Sources + +
    +
    + Benefits and risks -
    -
  • - - {isClicked && riskData && ( -
    -
    -
    -

    - Benefits: -

    -
      - {riskData.benefits.map((benefit, index) => ( -
    • - {benefit} -
    • - ))} -
    -
    -
    -

    - Risks: -

    -
      - {riskData.risks.map((risk, index) => ( -
    • - {risk} -
    • - ))} -
    -
    +
    + + + {isClicked && riskData && activePanel === "benefits-risks" && ( +
    +
    +
    +

    + Benefits: +

    +
      + {riskData.benefits.map((b, i) => ( +
    • + {b} +
    • + ))} +
    +
    +
    +

    + Risks: +

    +
      + {riskData.risks.map((r, i) => ( +
    • + {r} +
    • + ))} +
    +
    +
    +
    + )} + + {isClicked && riskData && activePanel === "sources" && ( +
    +
    +

    Sources

    +
    + + {riskData.sources?.length ? ( +
      + {riskData.sources.map((s, idx) => ( +
    • +
      + {s.title || "Untitled source"} + + {s.link_url && ( + + View PDF + + )} +
      + + {s.publication && ( +
      {s.publication}
      + )} + +

      + {truncate(s.text)} +

      + + {s.page && ( +
      + Page {s.page}
      -
    - )} + )} + + ))} + + ) : ( +

    + No sources available for this medication. +

    + )}
    - ); + )} +
    + ); }; const MedicationTier = ({ - title, - medications, - clickedMedication, - riskData, - loading, - onMedicationClick, - }: { - title: string; - medications: string[]; - clickedMedication: string | null; - riskData: RiskData | null; - loading: boolean; - onMedicationClick: (medication: string) => void; + title, + tier, + medications, + clickedMedication, + riskData, + loading, + onSourcesClick, + onBenefitsRisksClick, + activePanel, +}: { + title: string; + tier: string; + medications: MedicationWithSource[]; + clickedMedication: string | null; + riskData: RiskData | null; + loading: boolean; + onSourcesClick: (medication: MedicationWithSource) => void; + onBenefitsRisksClick: (medication: MedicationWithSource) => void; + activePanel: "sources" | "benefits-risks" | null; }) => ( - <> -
    - {title}: -
    -
      - {medications.map((medication) => ( - onMedicationClick(medication)} - /> - ))} -
    - + <> +
    + {title}: +
    + {medications.length ? ( +
      + {medications.map((medicationObj) => ( + onSourcesClick(medicationObj)} + onBenefitsRisksClick={() => onBenefitsRisksClick(medicationObj)} + activePanel={activePanel} + /> + ))} +
    + ) : ( + {`Patient's other health concerns may contraindicate typical ${tier} line options.`} + )} + ); const PatientSummary = ({ - showSummary, - setShowSummary, - setEnterNewPatient, - setIsEditing, - patientInfo, - isPatientDeleted, - }: PatientSummaryProps) => { - const [loading, setLoading] = useState(false); - const [riskData, setRiskData] = useState(null); - const [clickedMedication, setClickedMedication] = useState( - null - ); + showSummary, + setShowSummary, + setEnterNewPatient, + setIsEditing, + patientInfo, + isPatientDeleted, +}: PatientSummaryProps) => { + const [loading, setLoading] = useState(false); + const [riskData, setRiskData] = useState(null); + const [clickedMedication, setClickedMedication] = useState( + null + ); + const [activePanel, setActivePanel] = useState< + "sources" | "benefits-risks" | null + >(null); - const [isModalOpen, setIsModalOpen] = useState({status: false, id: ""}); + const [isModalOpen, setIsModalOpen] = useState({ status: false, id: "" }); - const handleOpenModal = (id: string, event: React.MouseEvent) => { - event.stopPropagation(); - setIsModalOpen({status: true, id: id}); - }; + const handleOpenModal = (id: string, event: React.MouseEvent) => { + event.stopPropagation(); + setIsModalOpen({ status: true, id: id }); + }; - const handleCloseModal = (event: React.MouseEvent) => { - event.stopPropagation(); - setIsModalOpen({status: false, id: ""}); - }; + const handleCloseModal = (event: React.MouseEvent) => { + event.stopPropagation(); + setIsModalOpen({ status: false, id: "" }); + }; - useEffect(() => { - if (isPatientDeleted) { - setShowSummary(true); - setLoading(false); - setRiskData(null); - setClickedMedication(null); - } - }, [isPatientDeleted]); - - useEffect(() => { - setRiskData(null); - setClickedMedication(null); - }, [patientInfo]); - - const handleClickSummary = () => { - setShowSummary(!showSummary); - }; + useEffect(() => { + if (isPatientDeleted) { + setShowSummary(true); + setLoading(false); + setRiskData(null); + setClickedMedication(null); + setActivePanel(null); + } + }, [isPatientDeleted, setShowSummary]); - const handleMedicationClick = async (medication: string) => { - if (clickedMedication === medication) { - setClickedMedication(null); - setRiskData(null); - return; - } - - setClickedMedication(medication); - setLoading(true); - try { - const baseUrl = import.meta.env.VITE_API_BASE_URL; - const response = await axios.post(`${baseUrl}/chatgpt/risk`, { - diagnosis: medication, - }); - setRiskData(response.data); - } catch (error) { - console.error("Error fetching data: ", error); - } finally { - setLoading(false); - } - }; + useEffect(() => { + setRiskData(null); + setClickedMedication(null); + setActivePanel(null); + }, [patientInfo]); - const handlePatientEdit = () => { - setIsEditing(true); - setEnterNewPatient(true); - handleClickSummary(); - console.log({editingPatient: patientInfo}); - }; + const handleClickSummary = () => { + setShowSummary(!showSummary); + }; + const handleSourcesClick = async (medicationObj: MedicationWithSource) => { + const { name: medication, source } = medicationObj; - const handlePatientPrint = (e: any) => { - e.preventDefault(); - window.print(); - }; + if (clickedMedication === medication && activePanel === "sources") { + setClickedMedication(null); + setActivePanel(null); + setRiskData(null); + return; + } + + setClickedMedication(medication); + setActivePanel("sources"); + setLoading(true); + + try { + // Map source based on patient's diagnosis + let apiSource: "include" | "diagnosis" | "diagnosis_depressed" = source; + if (source === "diagnosis" && patientInfo.Diagnosis === "Depressed") { + apiSource = "diagnosis_depressed"; + } + + const data = await fetchRiskDataWithSources(medication, apiSource); + console.log("Risk data received for", medication, "with source", apiSource, ":", data); + console.log("Sources array:", data.sources); + console.log("Sources length:", data.sources?.length); + setRiskData(data as RiskData); + } catch (error) { + console.error("Error fetching risk data: ", error); + setRiskData(null); + } finally { + setLoading(false); + } + }; - const [isMobileDropDownOpen, setIsMobileDropDownOpen] = useState(false) - const mobileMenuRef = useRef(null); + const handleBenefitsRisksClick = async ( + medicationObj: MedicationWithSource + ) => { + const { name: medication, source } = medicationObj; - const handleMobileDropDownMenu = () => { - setIsMobileDropDownOpen(!isMobileDropDownOpen) + if (clickedMedication === medication && activePanel === "benefits-risks") { + setClickedMedication(null); + setActivePanel(null); + setRiskData(null); + return; } - const MobileMenuItem = ({item, onClick}: { item: string, onClick: (e: React.MouseEvent) => void }) => { - const handleClick = (e: React.MouseEvent) => { - e.stopPropagation(); - onClick?.(e); - setIsMobileDropDownOpen(false) - } - return (
    - {item}
    ) + setClickedMedication(medication); + setActivePanel("benefits-risks"); + setLoading(true); + + try { + // Map source based on patient's diagnosis + let apiSource: "include" | "diagnosis" | "diagnosis_depressed" = source; + if (source === "diagnosis" && patientInfo.Diagnosis === "Depressed") { + apiSource = "diagnosis_depressed"; + } + + const data = await fetchRiskDataWithSources(medication, apiSource); + setRiskData(data as RiskData); + } catch (error) { + console.error("Error fetching risk data: ", error); + setRiskData(null); + } finally { + setLoading(false); } + }; + + const handlePatientEdit = () => { + setIsEditing(true); + setEnterNewPatient(true); + handleClickSummary(); + console.log({ editingPatient: patientInfo }); + }; + + const handlePatientPrint = (e: any) => { + e.preventDefault(); + window.print(); + }; - useEffect(() => { - const handleClickOutsideMenu = (event: MouseEvent) => { - if (mobileMenuRef.current && !mobileMenuRef.current.contains(event.target as Node)) { - setIsMobileDropDownOpen(false) - } - } - if (isMobileDropDownOpen) { - document.addEventListener('mousedown', handleClickOutsideMenu); - } - return () => { - document.removeEventListener('mousedown', handleClickOutsideMenu); - }; - }, [isMobileDropDownOpen]); - - const renderMedicationsSection = () => ( -
    -
    - Possible Medications: -
    -
    - {patientInfo.PossibleMedications && ( - <> - (null); + + const handleMobileDropDownMenu = () => { + setIsMobileDropDownOpen(!isMobileDropDownOpen); + }; + + const MobileMenuItem = ({ + item, + onClick, + }: { + item: string; + onClick: (e: React.MouseEvent) => void; + }) => { + const handleClick = (e: React.MouseEvent) => { + e.stopPropagation(); + onClick?.(e); + setIsMobileDropDownOpen(false); + }; + return ( +
    + {item} +
    + ); + }; + + useEffect(() => { + const handleClickOutsideMenu = (event: MouseEvent) => { + if ( + mobileMenuRef.current && + !mobileMenuRef.current.contains(event.target as Node) + ) { + setIsMobileDropDownOpen(false); + } + }; + if (isMobileDropDownOpen) { + document.addEventListener("mousedown", handleClickOutsideMenu); + } + return () => { + document.removeEventListener("mousedown", handleClickOutsideMenu); + }; + }, [isMobileDropDownOpen]); + const renderMedicationsSection = () => ( +
    +
    + Possible Medications: +
    +
    + {patientInfo.PossibleMedications && ( + <> + +
    + +
    +
    + +
    + + )} +
    +
    + ); + + return ( +
    +
    +
    + {patientInfo.ID && ( + <> +
    + {!showSummary && ( +
    +
    +

    + Patient Summary +

    +
    + + -
    - +
    +
    +
    + )} + {showSummary && ( +
    +
    +
    +

    + Summary +

    + {isMobileDropDownOpen ? ( +
    +
    + -
    -
    - + { + if (patientInfo.ID) { + handleOpenModal(patientInfo.ID, event); + } + }} /> +
    +
    +
    + +
    +
    + × +
    +
    - - )} - -
    - ); + ) : ( +
    +
    + +
    +
    + × +
    +
    + )} + +
    +
    +
    +

    + {" "} + {patientInfo.ID} +

    +

    + Patient details and application +

    +
    +
    +
    +
    +
    +
    + Current or Most recent state +
    +
    + {patientInfo.Diagnosis} +
    +
    +
    +
    +
    + Risk Assessment: +
    +
    +
      + {/* {patientInfo.Psychotic === "Yes" && ( +
    • + Currently psychotic +
    • + )} */} + {patientInfo.Suicide === "Yes" && ( +
    • + + Patient has a history of suicide attempts + + info + + +
    • + )} + {patientInfo.Kidney === "Yes" && ( +
    • + + Patient has a history or risk of kidney + disease + + info + + +
    • + )} + {patientInfo.Liver === "Yes" && ( +
    • + + Patient has a history or risk of liver disease + + info + + +
    • + )} + {patientInfo.blood_pressure === "Yes" && ( +
    • + + Patient has a history or risk of low blood + pressure, or concern for falls + + info + + +
    • + )} + {patientInfo.weight_gain === "Yes" && ( +
    • + + PatienthHas weight gain concerns + + info + + +
    • + )} - return ( -
      -
      -
      - {patientInfo.ID && ( - <> -
      - {!showSummary && ( -
      -
      -

      - Patient Summary -

      -
      - - - -
      -
      -
      + {patientInfo.risk_pregnancy === "Yes" && ( +
    • + + Patient wants to conceive in next 2 years + + info + + +
    • )} - {showSummary && ( -
      -
      -
      -

      - Summary -

      - {isMobileDropDownOpen ? ( -
      -
      - - - { - if (patientInfo.ID) { - handleOpenModal(patientInfo.ID, event); - } - }} - /> -
      -
      -
      - -
      -
      - × -
      -
      -
      - ) : ( -
      -
      - -
      -
      - × -
      -
      - )} - -
      -
      -
      -

      - {" "} - {patientInfo.ID} -

      -

      - Patient details and application -

      -
      -
      -
      -
      -
      -
      - Current State: -
      -
      - {patientInfo.Diagnosis} -
      -
      -
      -
      -
      - Risk Assessment: -
      -
      -
        - {/* Risk Assessment Items */} - {patientInfo.Psychotic === "Yes" && ( -
      • - Currently psychotic -
      • - )} - {patientInfo.Suicide === "Yes" && ( -
      • - - Patient has a history of suicide attempts - + {patientInfo.any_pregnancy === "Yes" && ( +
      • + + Patient has a possibility of becoming pregnant + info - -
      • - )} - {/* Add other risk assessment items similarly */} -
      -
      -
      -
      -
      -
      -
    +
    +
    +
    +
    +
    + -
    - {patientInfo.PriorMedications?.split(",").join( - ", " - )} -
    -
    -
    - {renderMedicationsSection()} -
    -
    -
    + + +
    + {patientInfo.PriorMedications?.split(",").join( + ", " )} +
    - - - - - )} +
    + {renderMedicationsSection()} + +
    +
    + )}
    - - ); + + + + + )} +
    + + ); }; export default PatientSummary; diff --git a/frontend/src/pages/PatientManager/PatientTypes.ts b/frontend/src/pages/PatientManager/PatientTypes.ts index 0216bf27..f7a9d13c 100644 --- a/frontend/src/pages/PatientManager/PatientTypes.ts +++ b/frontend/src/pages/PatientManager/PatientTypes.ts @@ -1,3 +1,8 @@ +export type MedicationWithSource = { + name: string; + source: "include" | "diagnosis"; +}; + export interface PatientInfo { ID?: string; Diagnosis?: Diagnosis; @@ -8,10 +13,10 @@ export interface PatientInfo { Mania?: string; CurrentMedications?: string; PriorMedications?: string; - PossibleMedications?: { - first?: string; - second?: string; - third?: string; + PossibleMedications: { + first: MedicationWithSource[]; + second: MedicationWithSource[]; + third: MedicationWithSource[]; }; Psychotic: string; Suicide: string; @@ -48,6 +53,5 @@ export interface NewPatientInfo { export enum Diagnosis { Manic = "Manic", Depressed = "Depressed", - Hypomanic = "Hypomanic", - Euthymic = "Euthymic", + Hypomanic = "Hypomanic" } diff --git a/frontend/src/pages/RulesManager/RulesManager.tsx b/frontend/src/pages/RulesManager/RulesManager.tsx index 8f2de0b1..be4980d4 100644 --- a/frontend/src/pages/RulesManager/RulesManager.tsx +++ b/frontend/src/pages/RulesManager/RulesManager.tsx @@ -11,15 +11,19 @@ interface Medication { risks: string; } +interface MedicationSource { + medication: Medication; + sources: any[]; +} + interface MedRule { id: number; rule_type: string; history_type: string; reason: string; label: string; - medications: Medication[]; - sources: any[]; explanation: string | null; + medication_sources: MedicationSource[]; } interface MedRulesResponse { @@ -96,8 +100,11 @@ function RulesManager() { return newSet; }); }; - - const renderMedicationDetails = (medication: Medication, rule: MedRule) => { + const renderMedicationDetails = ( + medication: Medication, + rule: MedRule, + sources: any[] + ) => { if (!medication) return null; const medKey = `${rule.id}-${medication.name}`; @@ -171,9 +178,9 @@ function RulesManager() {
    Sources:
    - {rule.sources && rule.sources.length > 0 ? ( + {sources && sources.length > 0 ? (
      - {rule.sources.map((source, index) => ( + {sources.map((source, index) => (
    • @@ -228,10 +235,14 @@ function RulesManager() { Medications:
      - {Array.isArray(rule.medications) && - rule.medications.length > 0 ? ( - rule.medications.map((med) => - renderMedicationDetails(med, rule) + {Array.isArray(rule.medication_sources) && + rule.medication_sources.length > 0 ? ( + rule.medication_sources.map((medSrc) => + renderMedicationDetails( + medSrc.medication, + rule, + medSrc.sources + ) ) ) : (

      diff --git a/frontend/src/services/actions/auth.tsx b/frontend/src/services/actions/auth.tsx index 3a29bc38..2573c223 100644 --- a/frontend/src/services/actions/auth.tsx +++ b/frontend/src/services/actions/auth.tsx @@ -169,6 +169,9 @@ export const login = }; export const logout = () => async (dispatch: AppDispatch) => { + // Clear chat conversation data on logout for security + sessionStorage.removeItem('currentConversation'); + dispatch({ type: LOGOUT, }); diff --git a/frontend/tailwind.config.js b/frontend/tailwind.config.js index 7ca214a8..bcc1e693 100644 --- a/frontend/tailwind.config.js +++ b/frontend/tailwind.config.js @@ -10,6 +10,10 @@ export default { lora: "'Lora', serif", 'quicksand': ['Quicksand', 'sans-serif'] }, + animation: { + 'pulse-bounce': 'pulse-bounce 2s infinite', // Adjust duration and iteration as needed + }, + plugins: [], }, }, plugins: [], diff --git a/server/Dockerfile b/server/Dockerfile index 4b11bb05..e410c505 100644 --- a/server/Dockerfile +++ b/server/Dockerfile @@ -1,5 +1,4 @@ -# pull official base image -FROM python:3.11.4-slim-buster +FROM python:3.11.4-slim-bullseye # set work directory WORKDIR /usr/src/server @@ -24,4 +23,4 @@ COPY . /usr/src/server RUN sed -i 's/\r$//' /usr/src/server/entrypoint.sh && chmod +x /usr/src/server/entrypoint.sh # run entrypoint.sh -ENTRYPOINT ["/usr/src/server/entrypoint.sh"] +ENTRYPOINT ["/usr/src/server/entrypoint.sh"] \ No newline at end of file diff --git a/server/Dockerfile.prod b/server/Dockerfile.prod index 6555ed6f..97b2c142 100644 --- a/server/Dockerfile.prod +++ b/server/Dockerfile.prod @@ -1,5 +1,5 @@ # pull official base image -FROM python:3.11.4-slim-buster +FROM python:3.11.4-slim-bullseye # set work directory diff --git a/server/api/admin.py b/server/api/admin.py index 930e86a4..4f1edbdd 100644 --- a/server/api/admin.py +++ b/server/api/admin.py @@ -4,8 +4,6 @@ from .models.authUser import UserAccount from .views.ai_settings.models import AI_Settings from .views.ai_promptStorage.models import AI_PromptStorage -from .views.ai_settings.models import AI_Settings -from .views.ai_promptStorage.models import AI_PromptStorage from .models.model_embeddings import Embeddings from .views.feedback.models import Feedback from .models.model_medRule import MedRule @@ -14,7 +12,6 @@ @admin.register(MedRule) class MedRuleAdmin(admin.ModelAdmin): list_display = ['rule_type', 'history_type', 'label'] - filter_horizontal = ['medications', 'sources'] search_fields = ['label', 'history_type', 'reason'] @@ -24,7 +21,7 @@ class MedicationAdmin(admin.ModelAdmin): @admin.register(Medication) -class MedicationAdmin(admin.ModelAdmin): +class MedicationAdmin(admin.ModelAdmin): # noqa: F811 list_display = ['name', 'benefits', 'risks'] diff --git a/server/api/migrations/0012_remove_medrule_medications_medrulesource_and_more.py b/server/api/migrations/0012_remove_medrule_medications_medrulesource_and_more.py new file mode 100644 index 00000000..281271ed --- /dev/null +++ b/server/api/migrations/0012_remove_medrule_medications_medrulesource_and_more.py @@ -0,0 +1,53 @@ +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('api', '0011_embeddings_publication_embeddings_publication_date_and_more'), + ] + + operations = [ + migrations.SeparateDatabaseAndState( + database_operations=[], # Don't create DB table + state_operations=[ + migrations.CreateModel( + name='MedRuleSource', + fields=[ + ('id', models.BigAutoField(auto_created=True, + primary_key=True, serialize=False, verbose_name='ID')), + ('medrule', models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, to='api.medrule')), + ('embedding', models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, to='api.embeddings')), + ('medication', models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, to='api.medication')), + ], + options={ + 'db_table': 'api_medrule_sources', + 'unique_together': {('medrule', 'embedding', 'medication')}, + }, + ), + ] + ), + migrations.SeparateDatabaseAndState( + database_operations=[], + state_operations=[ + migrations.AlterField( + model_name='medrule', + name='sources', + field=models.ManyToManyField( + blank=True, + related_name='med_rules', + through='api.MedRuleSource', + to='api.embeddings' + ), + ), + migrations.RemoveField( + model_name='medrule', + name='medications', + ), + ] + ), + ] diff --git a/server/api/migrations/0013_medrule_medications.py b/server/api/migrations/0013_medrule_medications.py new file mode 100644 index 00000000..15dea0ed --- /dev/null +++ b/server/api/migrations/0013_medrule_medications.py @@ -0,0 +1,26 @@ +# Generated by Django 4.2.3 on 2025-07-11 03:30 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('api', '0012_remove_medrule_medications_medrulesource_and_more'), + ] + + operations = [ + migrations.SeparateDatabaseAndState( + database_operations=[], + state_operations=[ + migrations.AddField( + model_name='medrule', + name='medications', + field=models.ManyToManyField( + related_name='med_rules', + to='api.medication', + ), + ), + ] + ), + ] diff --git a/server/api/migrations/0014_alter_medrule_rule_type.py b/server/api/migrations/0014_alter_medrule_rule_type.py new file mode 100644 index 00000000..7d43fcd9 --- /dev/null +++ b/server/api/migrations/0014_alter_medrule_rule_type.py @@ -0,0 +1,18 @@ +# Generated by Django 4.2.3 on 2025-10-25 16:37 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('api', '0013_medrule_medications'), + ] + + operations = [ + migrations.AlterField( + model_name='medrule', + name='rule_type', + field=models.CharField(choices=[('INCLUDE', 'Include'), ('EXCLUDE', 'Exclude')], max_length=500), + ), + ] diff --git a/server/api/models.py b/server/api/models.py index 71a83623..35e0d648 100644 --- a/server/api/models.py +++ b/server/api/models.py @@ -1,3 +1,2 @@ -from django.db import models # Create your models here. diff --git a/server/api/models/__init__.py b/server/api/models/__init__.py index 9f00b1ee..2b73b64a 100644 --- a/server/api/models/__init__.py +++ b/server/api/models/__init__.py @@ -1 +1 @@ -from .authUser import UserAccount +from .authUser import UserAccount # noqa: F401 diff --git a/server/api/models/model_embeddings.py b/server/api/models/model_embeddings.py index 2a1d4032..ed61f2fb 100644 --- a/server/api/models/model_embeddings.py +++ b/server/api/models/model_embeddings.py @@ -1,5 +1,4 @@ from django.db import models -from django.conf import settings from pgvector.django import VectorField import uuid from ..views.uploadFile.models import UploadFile diff --git a/server/api/models/model_medRule.py b/server/api/models/model_medRule.py index 9212330e..272e0bb9 100644 --- a/server/api/models/model_medRule.py +++ b/server/api/models/model_medRule.py @@ -1,28 +1,45 @@ from django.db import models from ..views.listMeds.models import Medication -from django.db.models import CASCADE from ..models.model_embeddings import Embeddings class MedRule(models.Model): - rule_type = models.CharField( - max_length=7, - choices=[('INCLUDE', 'Include'), ('EXCLUDE', 'Exclude')] - ) + RULE_TYPE_CHOICES = [ + ('INCLUDE', 'Include'), + ('EXCLUDE', 'Exclude'), + ] + rule_type = models.CharField(max_length=500, choices=RULE_TYPE_CHOICES) history_type = models.CharField(max_length=255) reason = models.TextField(blank=True, null=True) label = models.CharField(max_length=255, blank=True, null=True) - medications = models.ManyToManyField(Medication, related_name='med_rules') + explanation = models.TextField(blank=True, null=True) + medications = models.ManyToManyField( + Medication, + related_name='med_rules' + ) sources = models.ManyToManyField( Embeddings, related_name='med_rules', - blank=True + blank=True, + through='api.MedRuleSource' ) - explanation = models.TextField(blank=True, null=True) class Meta: db_table = 'api_medrule' - unique_together = ['rule_type', 'history_type'] + unique_together = [('rule_type', 'history_type')] + + def __str__(self): + return f"{self.rule_type} - {self.label or 'Unnamed'}" + + +class MedRuleSource(models.Model): + medrule = models.ForeignKey('api.MedRule', on_delete=models.CASCADE) + embedding = models.ForeignKey('api.Embeddings', on_delete=models.CASCADE) + medication = models.ForeignKey('api.Medication', on_delete=models.CASCADE) + + class Meta: + db_table = 'api_medrule_sources' + unique_together = [('medrule', 'embedding', 'medication')] def __str__(self): - return f"{self.rule_type} - {self.label}" + return f"Rule {self.medrule_id} | Embedding {self.embedding_id} | Medication {self.medication_id}" diff --git a/server/api/services/conversions_services.py b/server/api/services/conversions_services.py index d134ff49..71931f17 100644 --- a/server/api/services/conversions_services.py +++ b/server/api/services/conversions_services.py @@ -2,6 +2,23 @@ def convert_uuids(data): + """ + Recursively convert UUID objects to strings in nested data structures. + + Traverses dictionaries, lists, and other data structures to find UUID objects + and converts them to their string representation for serialization. + + Parameters + ---------- + data : any + The data structure to process (dict, list, UUID, or any other type) + + Returns + ------- + any + The data structure with all UUID objects converted to strings. + Structure and types are preserved except for UUID -> str conversion. + """ if isinstance(data, dict): return {key: convert_uuids(value) for key, value in data.items()} elif isinstance(data, list): diff --git a/server/api/services/embedding_services.py b/server/api/services/embedding_services.py index 5aacab38..6fd34d35 100644 --- a/server/api/services/embedding_services.py +++ b/server/api/services/embedding_services.py @@ -1,29 +1,63 @@ # services/embedding_services.py + +from pgvector.django import L2Distance + from .sentencetTransformer_model import TransformerModel + # Adjust import path as needed from ..models.model_embeddings import Embeddings -from pgvector.django import L2Distance -def get_closest_embeddings(user, message_data, document_name=None, guid=None, num_results=10): +def get_closest_embeddings( + user, message_data, document_name=None, guid=None, num_results=10 +): + """ + Find the closest embeddings to a given message for a specific user. + + Parameters + ---------- + user : User + The user whose uploaded documents will be searched + message_data : str + The input message to find similar embeddings for + document_name : str, optional + Filter results to a specific document name + guid : str, optional + Filter results to a specific document GUID (takes precedence over document_name) + num_results : int, default 10 + Maximum number of results to return + + Returns + ------- + list[dict] + List of dictionaries containing embedding results with keys: + - name: document name + - text: embedded text content + - page_number: page number in source document + - chunk_number: chunk number within the document + - distance: L2 distance from query embedding + - file_id: GUID of the source file + """ + # transformerModel = TransformerModel.get_instance().model embedding_message = transformerModel.encode(message_data) # Start building the query based on the message's embedding - closest_embeddings_query = Embeddings.objects.filter( - upload_file__uploaded_by=user - ).annotate( - distance=L2Distance( - 'embedding_sentence_transformers', embedding_message) - ).order_by('distance') + closest_embeddings_query = ( + Embeddings.objects.filter(upload_file__uploaded_by=user) + .annotate( + distance=L2Distance("embedding_sentence_transformers", embedding_message) + ) + .order_by("distance") + ) # Filter by GUID if provided, otherwise filter by document name if provided if guid: closest_embeddings_query = closest_embeddings_query.filter( - upload_file__guid=guid) + upload_file__guid=guid + ) elif document_name: - closest_embeddings_query = closest_embeddings_query.filter( - name=document_name) + closest_embeddings_query = closest_embeddings_query.filter(name=document_name) # Slice the results to limit to num_results closest_embeddings_query = closest_embeddings_query[:num_results] diff --git a/server/api/services/llm_services.py b/server/api/services/llm_services.py index 89eb2659..69df8172 100644 --- a/server/api/services/llm_services.py +++ b/server/api/services/llm_services.py @@ -7,193 +7,125 @@ import logging from abc import ABC, abstractmethod -import anthropic -import openai +from openai import AsyncOpenAI class BaseModelHandler(ABC): @abstractmethod - def handle_request( + async def handle_request( self, query: str, context: str ) -> tuple[str, dict[str, int], dict[str, float], float]: pass -class ClaudeHaiku35CitationsHandler(BaseModelHandler): - MODEL = "claude-3-5-haiku-20241022" - # Model Pricing: https://docs.anthropic.com/en/docs/about-claude/pricing#model-pricing - PRICING_DOLLARS_PER_MILLION_TOKENS = {"input": 0.80, "output": 4.00} +# LLM Pricing Calculator: https://www.llm-prices.com/ +# TODO: Add support for more models and their pricing + +# Anthropic Model Pricing: https://docs.anthropic.com/en/docs/about-claude/pricing#model-pricing + + +class GPT4OMiniHandler(BaseModelHandler): + MODEL = "gpt-4o-mini" + # TODO: Get the latest model pricing from OpenAI's API or documentation + # Model Pricing: https://platform.openai.com/docs/pricing + PRICING_DOLLARS_PER_MILLION_TOKENS = {"input": 0.15, "output": 0.60} def __init__(self) -> None: - self.client = anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY")) + self.client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY")) - def handle_request( + async def handle_request( self, query: str, context: str ) -> tuple[str, dict[str, int], dict[str, float], float]: """ - Handles the request to the Claude Haiku 3.5 model with citations enabled + Handles the request to the GPT-4o Mini model Args: query: The user query to be processed - context: The context or document content to be used for citations + context: The context or document content to be used """ - start_time = time.time() # TODO: Add error handling for API requests and invalid responses - message = self.client.messages.create( - model=self.MODEL, - max_tokens=1024, - messages=[ - { - "role": "user", - "content": [ - {"type": "text", "text": query}, - { - "type": "document", - "source": {"type": "content", "content": context}, - "citations": {"enabled": True}, - }, - ], - } - ], + response = await self.client.responses.create( + model=self.MODEL, instructions=query, input=context, temperature=0.0 ) duration = time.time() - start_time - # Response Structure: https://docs.anthropic.com/en/docs/build-with-claude/citations#response-structure - - text = [] - cited_text = [] - for content in message.to_dict()["content"]: - text.append(content["text"]) - if "citations" in content.keys(): - text.append( - " ".join( - [ - f"<{citation['start_block_index']} - {citation['end_block_index']}>" - for citation in content["citations"] - ] - ) - ) - cited_text.append( - " ".join( - [ - f"<{citation['start_block_index']} - {citation['end_block_index']}> {citation['cited_text']}" - for citation in content["citations"] - ] - ) - ) - - full_text = " ".join(text) - return ( - full_text, - message.usage, + response.output_text, + response.usage, self.PRICING_DOLLARS_PER_MILLION_TOKENS, duration, ) -class ClaudeHaiku3Handler(BaseModelHandler): - MODEL = "claude-3-haiku-20240307" - # Model Pricing: https://docs.anthropic.com/en/docs/about-claude/pricing#model-pricing - PRICING_DOLLARS_PER_MILLION_TOKENS = {"input": 0.25, "output": 1.25} +class GPT41NanoHandler(BaseModelHandler): + MODEL = "gpt-4.1-nano" - def __init__(self) -> None: - self.client = anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY")) + # Model Pricing: https://platform.openai.com/docs/pricing + PRICING_DOLLARS_PER_MILLION_TOKENS = {"input": 0.10, "output": 0.40} - def handle_request( - self, query: str, context: str - ) -> tuple[str, dict[str, int], dict[str, float], float]: - """ - Handles the request to the Claude Haiku 3 model with citations disabled + # GPT 4.1 Prompting Guide: https://cookbook.openai.com/examples/gpt4-1_prompting_guide - Args: - query: The user query to be processed - context: The context or document content to be used + # Long context performance can degrade as more items are required to be retrieved, + # or perform complex reasoning that requires knowledge of the state of the entire context - """ + # - start_time = time.time() - # TODO: Add error handling for API requests and invalid responses - message = self.client.messages.create( - model=self.MODEL, - max_tokens=1024, - messages=[ - { - "role": "user", - "content": [ - {"type": "text", "text": query}, - { - "type": "document", - "source": {"type": "content", "content": context}, - "citations": {"enabled": False}, - }, - ], - } - ], - ) - duration = time.time() - start_time + INSTRUCTIONS = """ + + # Role and Objective + + - You are a seasoned physician or medical professional who is developing a bipolar disorder treatment algorithim - text = [] - for content in message.to_dict()["content"]: - text.append(content["text"]) + - You are extracting bipolar medication decision points from a research paper that is chunked into multiple parts each labeled with an ID - full_text = " ".join(text) + # Instructions - return ( - full_text, - message.usage, - self.PRICING_DOLLARS_PER_MILLION_TOKENS, - duration, - ) + - Identify decision points for bipolar medications + - For each decision point you find, return a JSON object using the following format: -class GPT4OMiniHandler(BaseModelHandler): - MODEL = "gpt-4o-mini" - # Model Pricing: https://platform.openai.com/docs/pricing - PRICING_DOLLARS_PER_MILLION_TOKENS = {"input": 0.15, "output": 0.60} + { + "criterion": "", + "decision": "INCLUDE" or "EXCLUDE", + "medications": ["", "", ...], + "reason": "", + "sources": [""] + } - def __init__(self) -> None: - self.client = openai.OpenAI(api_key=os.environ.get("OPENAI_API_KEY")) - def handle_request( - self, query: str, context: str - ) -> tuple[str, dict[str, int], dict[str, float], float]: - """ - Handles the request to the GPT-4o Mini model + - Only extract bipolar medication decision points that are explicitly stated or strongly implied in the context and never rely on your own knowledge - Args: - query: The user query to be processed - context: The context or document content to be used + # Output Format - """ - start_time = time.time() - # TODO: Add error handling for API requests and invalid responses - response = self.client.responses.create( - model=self.MODEL, - instructions=query, - input=context, - ) - duration = time.time() - start_time + - Return the extracted bipolar medication decision points as a JSON array and if no decision points are found in the context return an empty array - return ( - response.output_text, - response.usage, - self.PRICING_DOLLARS_PER_MILLION_TOKENS, - duration, - ) + # Example + [ + { + "criterion": "History of suicide attempts", + "decision": "INCLUDE", + "medications": ["Lithium"], + "reason": "Lithium is the only medication on the market that has been proven to reduce suicidality in patients with bipolar disorder", + "sources": ["ID-0"] + }, + { + "criterion": "Weight gain concerns", + "decision": "EXCLUDE", + "medications": ["Quetiapine", "Aripiprazole", "Olanzapine", "Risperidone"], + "reason": "Seroquel, Risperdal, Abilify, and Zyprexa are known for causing weight gain", + "sources": ["ID-0", "ID-1", "ID-2"] + } + ] -class GPT41NanoHandler(BaseModelHandler): - MODEL = "gpt-4.1-nano" - # Model Pricing: https://platform.openai.com/docs/pricing - PRICING_DOLLARS_PER_MILLION_TOKENS = {"input": 0.10, "output": 0.40} + """ def __init__(self) -> None: - self.client = openai.OpenAI(api_key=os.environ.get("OPENAI_API_KEY")) + self.client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY")) - def handle_request( + async def handle_request( self, query: str, context: str ) -> tuple[str, dict[str, int], dict[str, float], float]: """ @@ -204,12 +136,16 @@ def handle_request( context: The context or document content to be used """ + + # If no query is provided, use the default instructions + if not query: + query = self.INSTRUCTIONS + start_time = time.time() # TODO: Add error handling for API requests and invalid responses - response = self.client.responses.create( - model=self.MODEL, - instructions=query, - input=context, + + response = await self.client.responses.create( + model=self.MODEL, instructions=query, input=context, temperature=0.0 ) duration = time.time() - start_time @@ -222,9 +158,10 @@ def handle_request( class ModelFactory: + # TODO: Define structured fields to extract from unstructured input data + # https://platform.openai.com/docs/guides/structured-outputs?api-mode=responses&example=structured-data#examples + HANDLERS = { - "CLAUDE_HAIKU_3_5_CITATIONS": ClaudeHaiku35CitationsHandler, - "CLAUDE_HAIKU_3": ClaudeHaiku3Handler, "GPT_4O_MINI": GPT4OMiniHandler, "GPT_41_NANO": GPT41NanoHandler, } diff --git a/server/api/tests.py b/server/api/tests.py index f6eac93b..baf59b4e 100644 --- a/server/api/tests.py +++ b/server/api/tests.py @@ -1,4 +1,3 @@ -from django.test import TestCase import unittest from .services.tools.tools import validate_tool_inputs, execute_tool diff --git a/server/api/views/ai_promptStorage/serializers.py b/server/api/views/ai_promptStorage/serializers.py index ebbd2d4b..0358f8a5 100644 --- a/server/api/views/ai_promptStorage/serializers.py +++ b/server/api/views/ai_promptStorage/serializers.py @@ -1,6 +1,5 @@ from rest_framework import serializers from .models import AI_PromptStorage -from django.conf import settings class AI_PromptStorageSerializer(serializers.ModelSerializer): diff --git a/server/api/views/ai_promptStorage/views.py b/server/api/views/ai_promptStorage/views.py index e49c21c1..7354feb3 100644 --- a/server/api/views/ai_promptStorage/views.py +++ b/server/api/views/ai_promptStorage/views.py @@ -1,17 +1,15 @@ from rest_framework import status -from rest_framework.decorators import api_view, permission_classes -from rest_framework.permissions import IsAuthenticated +from rest_framework.decorators import api_view from rest_framework.response import Response from .models import AI_PromptStorage from .serializers import AI_PromptStorageSerializer -from django.views.decorators.csrf import csrf_exempt @api_view(['POST']) # @permission_classes([IsAuthenticated]) def store_prompt(request): print(request.user) - data = request.data.copy() + data = request.data.copy() # noqa: F841 print(request.user) serializer = AI_PromptStorageSerializer( data=request.data, context={'request': request}) diff --git a/server/api/views/ai_settings/urls.py b/server/api/views/ai_settings/urls.py index abe3c990..2266ed6e 100644 --- a/server/api/views/ai_settings/urls.py +++ b/server/api/views/ai_settings/urls.py @@ -1,5 +1,4 @@ -from django.urls import path, include -from rest_framework.routers import DefaultRouter +from django.urls import path from api.views.ai_settings import views urlpatterns = [ diff --git a/server/api/views/assistant/urls.py b/server/api/views/assistant/urls.py new file mode 100644 index 00000000..4c68f952 --- /dev/null +++ b/server/api/views/assistant/urls.py @@ -0,0 +1,5 @@ +from django.urls import path + +from .views import Assistant + +urlpatterns = [path("v1/api/assistant", Assistant.as_view(), name="assistant")] diff --git a/server/api/views/assistant/views.py b/server/api/views/assistant/views.py new file mode 100644 index 00000000..32089c58 --- /dev/null +++ b/server/api/views/assistant/views.py @@ -0,0 +1,325 @@ +import os +import json +import logging +import time +from typing import Callable + +from rest_framework.views import APIView +from rest_framework.response import Response +from rest_framework import status +from rest_framework.permissions import IsAuthenticated +from django.utils.decorators import method_decorator +from django.views.decorators.csrf import csrf_exempt + +from openai import OpenAI + +from ...services.embedding_services import get_closest_embeddings +from ...services.conversions_services import convert_uuids + +# Configure logging +logger = logging.getLogger(__name__) + +GPT_5_NANO_PRICING_DOLLARS_PER_MILLION_TOKENS = {"input": 0.05, "output": 0.40} + + +def calculate_cost_metrics(token_usage: dict, pricing: dict) -> dict: + """ + Calculate cost metrics based on token usage and pricing + + Args: + token_usage: Dictionary containing input_tokens and output_tokens + pricing: Dictionary containing input and output pricing per million tokens + + Returns: + Dictionary containing input_cost, output_cost, and total_cost in USD + """ + TOKENS_PER_MILLION = 1_000_000 + + # Pricing is in dollars per million tokens + input_cost_dollars = (pricing["input"] / TOKENS_PER_MILLION) * token_usage.get( + "input_tokens", 0 + ) + output_cost_dollars = (pricing["output"] / TOKENS_PER_MILLION) * token_usage.get( + "output_tokens", 0 + ) + total_cost_dollars = input_cost_dollars + output_cost_dollars + + return { + "input_cost": input_cost_dollars, + "output_cost": output_cost_dollars, + "total_cost": total_cost_dollars, + } + + +# Open AI Cookbook: Handling Function Calls with Reasoning Models +# https://cookbook.openai.com/examples/reasoning_function_calls +def invoke_functions_from_response( + response, tool_mapping: dict[str, Callable] +) -> list[dict]: + """Extract all function calls from the response, look up the corresponding tool function(s) and execute them. + (This would be a good place to handle asynchroneous tool calls, or ones that take a while to execute.) + This returns a list of messages to be added to the conversation history. + + Parameters + ---------- + response : OpenAI Response + The response object from OpenAI containing output items that may include function calls + tool_mapping : dict[str, Callable] + A dictionary mapping function names (as strings) to their corresponding Python functions. + Keys should match the function names defined in the tools schema. + + Returns + ------- + list[dict] + List of function call output messages formatted for the OpenAI conversation. + Each message contains: + - type: "function_call_output" + - call_id: The unique identifier for the function call + - output: The result returned by the executed function (string or error message) + """ + intermediate_messages = [] + for response_item in response.output: + if response_item.type == "function_call": + target_tool = tool_mapping.get(response_item.name) + if target_tool: + try: + arguments = json.loads(response_item.arguments) + logger.info( + f"Invoking tool: {response_item.name} with arguments: {arguments}" + ) + tool_output = target_tool(**arguments) + logger.info(f"Tool {response_item.name} completed successfully") + except Exception as e: + msg = f"Error executing function call: {response_item.name}: {e}" + tool_output = msg + logger.error(msg, exc_info=True) + else: + msg = f"ERROR - No tool registered for function call: {response_item.name}" + tool_output = msg + logger.error(msg) + intermediate_messages.append( + { + "type": "function_call_output", + "call_id": response_item.call_id, + "output": tool_output, + } + ) + elif response_item.type == "reasoning": + logger.info(f"Reasoning step: {response_item.summary}") + return intermediate_messages + + +@method_decorator(csrf_exempt, name="dispatch") +class Assistant(APIView): + permission_classes = [IsAuthenticated] + + def post(self, request): + try: + user = request.user + + client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY")) + + TOOL_DESCRIPTION = """ + Search the user's uploaded documents for information relevant to answering their question. + Call this function when you need to find specific information from the user's documents + to provide an accurate, citation-backed response. Always search before answering questions + about document content. + """ + + TOOL_PROPERTY_DESCRIPTION = """ + A specific search query to find relevant information in the user's documents. + Use keywords, phrases, or questions related to what the user is asking about. + Be specific rather than generic - use terms that would appear in the relevant documents. + """ + + tools = [ + { + "type": "function", + "name": "search_documents", + "description": TOOL_DESCRIPTION, + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": TOOL_PROPERTY_DESCRIPTION, + } + }, + "required": ["query"], + }, + } + ] + + def search_documents(query: str, user=user) -> str: + """ + Search through user's uploaded documents using semantic similarity. + + This function performs vector similarity search against the user's document corpus + and returns formatted results with context information for the LLM to use. + + Parameters + ---------- + query : str + The search query string + user : User + The authenticated user whose documents to search + + Returns + ------- + str + Formatted search results containing document excerpts with metadata + + Raises + ------ + Exception + If embedding search fails + """ + + try: + embeddings_results = get_closest_embeddings( + user=user, message_data=query.strip() + ) + embeddings_results = convert_uuids(embeddings_results) + + if not embeddings_results: + return "No relevant documents found for your query. Please try different search terms or upload documents first." + + # Format results with clear structure and metadata + prompt_texts = [ + f"[Document {i + 1} - File: {obj['file_id']}, Name: {obj['name']}, Page: {obj['page_number']}, Chunk: {obj['chunk_number']}, Similarity: {1 - obj['distance']:.3f}]\n{obj['text']}\n[End Document {i + 1}]" + for i, obj in enumerate(embeddings_results) + ] + + return "\n\n".join(prompt_texts) + + except Exception as e: + return f"Error searching documents: {str(e)}. Please try again if the issue persists." + + INSTRUCTIONS = """ + You are an AI assistant that helps users find and understand information about bipolar disorder + from their uploaded bipolar disorder research documents using semantic search. + + SEMANTIC SEARCH STRATEGY: + - Always perform semantic search using the search_documents function when users ask questions + - Use conceptually related terms and synonyms, not just exact keyword matches + - Search for the meaning and context of the user's question, not just literal words + - Consider medical terminology, lay terms, and related conditions when searching + + FUNCTION USAGE: + - When a user asks about information that might be in their documents ALWAYS use the search_documents function first + - Perform semantic searches using concepts, symptoms, treatments, and related terms from the user's question + - Only provide answers based on information found through document searches + + RESPONSE FORMAT: + After gathering information through semantic searches, provide responses that: + 1. Answer the user's question directly using only the found information + 2. Structure responses with clear sections and paragraphs + 3. Include citations using this exact format: ***[Name {name}, Page {page_number}]*** + 4. Only cite information that directly supports your statements + + If no relevant information is found in the documents, clearly state that the information is not available in the uploaded documents. + """ + + MODEL_DEFAULTS = { + "instructions": INSTRUCTIONS, + "model": "gpt-5-nano", # 400,000 token context window + # A summary of the reasoning performed by the model. This can be useful for debugging and understanding the model's reasoning process. + "reasoning": {"effort": "low", "summary": None}, + "tools": tools, + } + + # We fetch a response and then kick off a loop to handle the response + + message = request.data.get("message", None) + previous_response_id = request.data.get("previous_response_id", None) + + # Track total duration and cost metrics + start_time = time.time() + total_token_usage = {"input_tokens": 0, "output_tokens": 0} + + if not previous_response_id: + response = client.responses.create( + input=[ + {"type": "message", "role": "user", "content": str(message)} + ], + **MODEL_DEFAULTS, + ) + else: + response = client.responses.create( + input=[ + {"type": "message", "role": "user", "content": str(message)} + ], + previous_response_id=str(previous_response_id), + **MODEL_DEFAULTS, + ) + + # Accumulate token usage from initial response + if hasattr(response, "usage"): + total_token_usage["input_tokens"] += getattr( + response.usage, "input_tokens", 0 + ) + total_token_usage["output_tokens"] += getattr( + response.usage, "output_tokens", 0 + ) + + # Open AI Cookbook: Handling Function Calls with Reasoning Models + # https://cookbook.openai.com/examples/reasoning_function_calls + while True: + # Mapping of the tool names we tell the model about and the functions that implement them + function_responses = invoke_functions_from_response( + response, tool_mapping={"search_documents": search_documents} + ) + if len(function_responses) == 0: # We're done reasoning + logger.info("Reasoning completed") + final_response_output_text = response.output_text + final_response_id = response.id + logger.info(f"Final response: {final_response_output_text}") + break + else: + logger.info("More reasoning required, continuing...") + response = client.responses.create( + input=function_responses, + previous_response_id=response.id, + **MODEL_DEFAULTS, + ) + # Accumulate token usage from reasoning iterations + if hasattr(response, "usage"): + total_token_usage["input_tokens"] += getattr( + response.usage, "input_tokens", 0 + ) + total_token_usage["output_tokens"] += getattr( + response.usage, "output_tokens", 0 + ) + + # Calculate total duration and cost metrics + total_duration = time.time() - start_time + cost_metrics = calculate_cost_metrics( + total_token_usage, GPT_5_NANO_PRICING_DOLLARS_PER_MILLION_TOKENS + ) + + # Log cost and duration metrics + logger.info( + f"Request completed: " + f"Duration: {total_duration:.2f}s, " + f"Input tokens: {total_token_usage['input_tokens']}, " + f"Output tokens: {total_token_usage['output_tokens']}, " + f"Total cost: ${cost_metrics['total_cost']:.6f}" + ) + + return Response( + { + "response_output_text": final_response_output_text, + "final_response_id": final_response_id, + }, + status=status.HTTP_200_OK, + ) + + except Exception as e: + logger.error( + f"Unexpected error in Assistant view for user {request.user.id if hasattr(request, 'user') else 'unknown'}: {e}", + exc_info=True, + ) + return Response( + {"error": "An unexpected error occurred. Please try again later."}, + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) diff --git a/server/api/views/conversations/views.py b/server/api/views/conversations/views.py index ffe60e65..d5921eaf 100644 --- a/server/api/views/conversations/views.py +++ b/server/api/views/conversations/views.py @@ -1,6 +1,4 @@ -from rest_framework.views import APIView from rest_framework.response import Response -from rest_framework.decorators import api_view from rest_framework import viewsets, status from rest_framework.decorators import action from rest_framework.permissions import IsAuthenticated @@ -9,19 +7,15 @@ from bs4 import BeautifulSoup from nltk.stem import PorterStemmer import requests -import openai +from openai import OpenAI import tiktoken import os import json import logging -from api.views.ai_settings.models import AI_Settings -from api.views.ai_promptStorage.models import AI_PromptStorage from django.views.decorators.csrf import csrf_exempt -from django.db import transaction, connection from .models import Conversation, Message -from .serializers import ConversationSerializer, MessageSerializer +from .serializers import ConversationSerializer from ...services.tools.tools import tools, execute_tool -from ...services.tools.database import get_database_info @csrf_exempt @@ -31,7 +25,7 @@ def extract_text(request: str) -> JsonResponse: Currently only uses the first 3500 tokens. """ - openai.api_key = os.environ.get("OPENAI_API_KEY") + OpenAI.api_key = os.environ.get("OPENAI_API_KEY") data = json.loads(request.body) webpage_url = data["webpage_url"] @@ -48,7 +42,7 @@ def extract_text(request: str) -> JsonResponse: tokens = get_tokens(text_contents, "cl100k_base") - ai_response = openai.ChatCompletion.create( + ai_response = OpenAI.ChatCompletion.create( model="gpt-3.5-turbo", messages=[ { @@ -70,6 +64,7 @@ def get_tokens(string: str, encoding_name: str) -> str: output_string = encoding.decode(tokens) return output_string + class OpenAIAPIException(APIException): """Custom exception for OpenAI API errors.""" status_code = status.HTTP_500_INTERNAL_SERVER_ERROR @@ -83,6 +78,7 @@ def __init__(self, detail=None, code=None): self.detail = {"error": self.default_detail} self.status_code = code or self.status_code + class ConversationViewSet(viewsets.ModelViewSet): serializer_class = ConversationSerializer permission_classes = [IsAuthenticated] @@ -141,15 +137,15 @@ def update_title(self, request, pk=None): return Response({"status": "Title updated successfully", "title": conversation.title}) def get_chatgpt_response(self, conversation, user_message, page_context=None): + client = OpenAI(api_key=os.environ["OPENAI_API_KEY"]) messages = [{ - "role": "system", + "role": "system", "content": "You are a knowledgeable assistant. Balancer is a powerful tool for selecting bipolar medication for patients. We are open-source and available for free use. Your primary role is to assist licensed clinical professionals with information related to Balancer and bipolar medication selection. If applicable, use the supplied tools to assist the professional." }] if page_context: context_message = f"If applicable, please use the following content to ask questions. If not applicable, please answer to the best of your ability: {page_context}" messages.append({"role": "system", "content": context_message}) - for msg in conversation.messages.all(): role = "user" if msg.is_user else "assistant" messages.append({"role": role, "content": msg.content}) @@ -157,7 +153,7 @@ def get_chatgpt_response(self, conversation, user_message, page_context=None): messages.append({"role": "user", "content": user_message}) try: - response = openai.ChatCompletion.create( + response = client.chat.completions.create( model="gpt-3.5-turbo", messages=messages, tools=tools, @@ -165,43 +161,45 @@ def get_chatgpt_response(self, conversation, user_message, page_context=None): ) response_message = response.choices[0].message - tool_calls = response_message.get('tool_calls', []) + tool_calls = getattr(response_message, "tool_calls", []) + + tool_calls = response_message.model_dump().get("tool_calls", []) if not tool_calls: return response_message['content'] - # Handle tool calls # Add the assistant's message with tool calls to the conversation messages.append({ "role": "assistant", - "content": response_message.get('content', ''), + "content": response_message.content or "", "tool_calls": tool_calls }) - + # Process each tool call for tool_call in tool_calls: tool_call_id = tool_call['id'] tool_function_name = tool_call['function']['name'] - tool_arguments = json.loads(tool_call['function'].get('arguments', '{}')) - + tool_arguments = json.loads( + tool_call['function'].get('arguments', '{}')) + # Execute the tool results = execute_tool(tool_function_name, tool_arguments) - + # Add the tool response message messages.append({ "role": "tool", "content": str(results), # Convert results to string "tool_call_id": tool_call_id }) - + # Final API call with tool results - final_response = openai.ChatCompletion.create( + final_response = client.chat.completions.create( model="gpt-3.5-turbo", - messages=messages - ) - return final_response.choices[0].message['content'] - except openai.error.OpenAIError as e: + messages=messages + ) + return final_response.choices[0].message.content + except OpenAI.error.OpenAIError as e: logging.error("OpenAI API Error: %s", str(e)) raise OpenAIAPIException(detail=str(e)) except Exception as e: @@ -209,12 +207,12 @@ def get_chatgpt_response(self, conversation, user_message, page_context=None): raise OpenAIAPIException(detail="An unexpected error occurred.") def generate_title(self, conversation): - # Get the first two messages messages = conversation.messages.all()[:2] context = "\n".join([msg.content for msg in messages]) prompt = f"Based on the following conversation, generate a short, descriptive title (max 6 words):\n\n{context}" - response = openai.ChatCompletion.create( + client = OpenAI(api_key=os.environ["OPENAI_API_KEY"]) + response = client.chat.completions.create( model="gpt-3.5-turbo", messages=[ {"role": "system", "content": "You are a helpful assistant that generates short, descriptive titles."}, @@ -222,4 +220,4 @@ def generate_title(self, conversation): ] ) - return response.choices[0].message['content'].strip() + return response.choices[0].message.content.strip() diff --git a/server/api/views/embeddings/embeddingsView.py b/server/api/views/embeddings/embeddingsView.py index 550caaa3..d0bdd8ca 100644 --- a/server/api/views/embeddings/embeddingsView.py +++ b/server/api/views/embeddings/embeddingsView.py @@ -3,7 +3,6 @@ from rest_framework.response import Response from rest_framework import status from django.http import StreamingHttpResponse -import os from ...services.embedding_services import get_closest_embeddings from ...services.conversions_services import convert_uuids from ...services.openai_services import openAIServices @@ -91,7 +90,7 @@ def stream_generator(): return Response({ "question": message, "llm_response": answer, - "embeddings_info": embeddings_results, + "embeddings_info": listOfEmbeddings, "sent_to_llm": prompt_text, }, status=status.HTTP_200_OK) diff --git a/server/api/views/feedback/urls.py b/server/api/views/feedback/urls.py index 41925665..2c0eab29 100644 --- a/server/api/views/feedback/urls.py +++ b/server/api/views/feedback/urls.py @@ -1,5 +1,4 @@ from django.urls import path -from api.views.feedback import views from .views import FeedbackView urlpatterns = [ diff --git a/server/api/views/feedback/views.py b/server/api/views/feedback/views.py index cf39c9bf..dcbef992 100644 --- a/server/api/views/feedback/views.py +++ b/server/api/views/feedback/views.py @@ -1,18 +1,10 @@ from rest_framework.views import APIView from rest_framework.response import Response -from rest_framework.decorators import api_view from rest_framework import status -from django.http import JsonResponse, HttpRequest -from django import forms -import requests -import json -import os -from .models import Feedback from .serializers import FeedbackSerializer # XXX: remove csrf_exempt usage before production -from django.views.decorators.csrf import csrf_exempt class FeedbackView(APIView): diff --git a/server/api/views/listMeds/views.py b/server/api/views/listMeds/views.py index 5a19dfa2..1976458e 100644 --- a/server/api/views/listMeds/views.py +++ b/server/api/views/listMeds/views.py @@ -1,17 +1,21 @@ from rest_framework import status -from rest_framework.decorators import api_view from rest_framework.response import Response from rest_framework.views import APIView + from .models import Diagnosis, Medication, Suggestion -from .serializers import DiagnosisSerializer, MedicationSerializer, SuggestionSerializer -import json -from django.views.decorators.csrf import csrf_exempt +from .serializers import MedicationSerializer # Constants for medication inclusion and exclusion -MEDS_INCLUDE = {'suicideHistory': ['Lithium']} +MEDS_INCLUDE = { + 'suicideHistory': ['Lithium'] +} + MED_EXCLUDE = { 'kidneyHistory': ['Lithium'], 'liverHistory': ['Valproate'], - 'bloodPressureHistory': ['Asenapine', 'Lurasidone', 'Olanzapine', 'Paliperidone', 'Quetiapine', 'Risperidone', 'Ziprasidone', 'Aripiprazole', 'Cariprazine'], + 'bloodPressureHistory': [ + 'Asenapine', 'Lurasidone', 'Olanzapine', 'Paliperidone', + 'Quetiapine', 'Risperidone', 'Ziprasidone', 'Aripiprazole', 'Cariprazine' + ], 'weightGainConcern': ['Quetiapine', 'Risperidone', 'Aripiprazole', 'Olanzapine'] } @@ -20,6 +24,7 @@ class GetMedication(APIView): def post(self, request): data = request.data state_query = data.get('state', '') + print(state_query) include_result = [] exclude_result = [] for condition in MEDS_INCLUDE: @@ -29,27 +34,39 @@ def post(self, request): if data.get(condition, False): # Remove any medication from include list that is in the exclude list include_result = [ - med for med in include_result if med not in MED_EXCLUDE[condition]] + med for med in include_result if med not in MED_EXCLUDE[condition] + ] exclude_result.extend(MED_EXCLUDE[condition]) - diag_query = Diagnosis.objects.filter(state=state_query) - if diag_query.count() <= 0: + try: + diagnosis = Diagnosis.objects.get(state=state_query) + except Diagnosis.DoesNotExist: return Response({'error': 'Diagnosis not found'}, status=status.HTTP_404_NOT_FOUND) - diagnosis = diag_query[0] - meds = {'first': '', 'second': '', 'third': ''} + meds = {'first': [], 'second': [], 'third': []} + + priorMeds = data.get('priorMedications', "").split(',') + exclude_result.extend([med.strip() + for med in priorMeds if med.strip()]) + included_set = set(include_result) + excluded_set = set(exclude_result) + for med in include_result: - meds['first'] += med + ", " - for i, line in enumerate(['first', 'second', 'third']): - for suggestion in Suggestion.objects.filter(diagnosis=diagnosis, tier=(i + 1)): - to_exclude = False - for med in exclude_result: - if med in suggestion.medication.name: - to_exclude = True - break - if i > 0 and suggestion.medication.name in include_result: - to_exclude = True - if not to_exclude: - meds[line] += suggestion.medication.name + ", " - meds[line] = meds[line][:-2] if meds[line] else 'None' + meds['first'].append({'name': med, 'source': 'include'}) + + for i, tier_label in enumerate(['first', 'second', 'third']): + suggestions = Suggestion.objects.filter( + diagnosis=diagnosis, tier=i+1 + ) + for suggestion in suggestions: + med_name = suggestion.medication.name + if med_name in excluded_set: + continue + if i > 0 and med_name in included_set: + continue + meds[tier_label].append({ + 'name': med_name, + 'source': 'diagnosis_' + state_query.lower() + }) + return Response(meds) @@ -69,20 +86,7 @@ def get(self, request): return Response(serializer.data) def post(self, request): - # Implement logic for adding new medications (if needed) - # If adding medications, you would check if the medication already exists before creating it - data = request.data - name = data.get('name', '') - if not name: - return Response({'error': 'Medication name is required'}, status=status.HTTP_400_BAD_REQUEST) - if Medication.objects.filter(name=name).exists(): - return Response({'error': 'Medication already exists'}, status=status.HTTP_400_BAD_REQUEST) - # Assuming Medication model has `name`, `benefits`, `risks` as fields - serializer = MedicationSerializer(data=data) - if serializer.is_valid(): - serializer.save() - return Response(serializer.data, status=status.HTTP_201_CREATED) - return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + return Response({'error': 'Use AddMedication endpoint for creating medications'}, status=status.HTTP_405_METHOD_NOT_ALLOWED) class AddMedication(APIView): @@ -95,13 +99,15 @@ def post(self, request): name = data.get('name', '').strip() benefits = data.get('benefits', '').strip() risks = data.get('risks', '').strip() - # Validate the inputs + + # Validate required fields if not name: return Response({'error': 'Medication name is required'}, status=status.HTTP_400_BAD_REQUEST) if not benefits: return Response({'error': 'Medication benefits are required'}, status=status.HTTP_400_BAD_REQUEST) if not risks: return Response({'error': 'Medication risks are required'}, status=status.HTTP_400_BAD_REQUEST) + # Check if medication already exists if Medication.objects.filter(name=name).exists(): return Response({'error': f'Medication "{name}" already exists'}, status=status.HTTP_400_BAD_REQUEST) @@ -114,25 +120,22 @@ def post(self, request): class DeleteMedication(APIView): - "API endpoint to delete medication if medication in database" + """ + API endpoint to delete medication if medication in database. + """ def delete(self, request): data = request.data name = data.get('name', '').strip() - print("ok vin") - # Validate the inputs + + # Validate required fields if not name: return Response({'error': 'Medication name is required'}, status=status.HTTP_400_BAD_REQUEST) - # Check if medication exists - if Medication.objects.filter(name=name).exists(): - # return f'Medication "{name}" exists' - # Get the medication object + + # Check if medication exists and delete + try: medication = Medication.objects.get(name=name) - # Delete the medication medication.delete() - return Response({'success': "medication exists and will now be deleted"}, status=status.HTTP_201_CREATED) - else: - return Response({'error': 'Medication does not exist'}, status=status.HTTP_400_BAD_REQUEST) - # ask user if sure to delete? - # delete med from database - # Medication.objects.filter(name=name) + return Response({'success': f'Medication "{name}" has been deleted'}, status=status.HTTP_200_OK) + except Medication.DoesNotExist: + return Response({'error': f'Medication "{name}" does not exist'}, status=status.HTTP_404_NOT_FOUND) diff --git a/server/api/views/medRules/serializers.py b/server/api/views/medRules/serializers.py index da09f65c..df5e3663 100644 --- a/server/api/views/medRules/serializers.py +++ b/server/api/views/medRules/serializers.py @@ -1,30 +1,53 @@ from rest_framework import serializers -from ...models.model_medRule import MedRule +from ...models.model_medRule import MedRule, MedRuleSource from ..listMeds.serializers import MedicationSerializer from ...models.model_embeddings import Embeddings -from ..listMeds.models import Medication + class EmbeddingsSerializer(serializers.ModelSerializer): class Meta: model = Embeddings - fields = ['guid', 'name', 'text', 'page_num', 'chunk_number'] + fields = ["guid", "name", "text", "page_num", "chunk_number"] + + +class MedicationWithSourcesSerializer(serializers.Serializer): + medication = MedicationSerializer() + sources = EmbeddingsSerializer(many=True) class MedRuleSerializer(serializers.ModelSerializer): - medications = MedicationSerializer(many=True, read_only=True) - medication_ids = serializers.PrimaryKeyRelatedField( - many=True, write_only=True, queryset=Medication.objects.all(), source='medications' - ) - sources = EmbeddingsSerializer(many=True, read_only=True) - source_ids = serializers.PrimaryKeyRelatedField( - many=True, write_only=True, queryset=Embeddings.objects.all(), source='sources' - ) + medication_sources = serializers.SerializerMethodField() class Meta: model = MedRule fields = [ - 'id', 'rule_type', 'history_type', 'reason', 'label', - 'medications', 'medication_ids', - 'sources', 'source_ids', - 'explanation' - ] \ No newline at end of file + "id", + "rule_type", + "history_type", + "reason", + "label", + "explanation", + "medication_sources", + ] + + def get_medication_sources(self, obj): + medrule_sources = MedRuleSource.objects.filter(medrule=obj).select_related( + "medication", "embedding" + ) + + med_to_sources = {} + for ms in medrule_sources: + if ms.medication.id not in med_to_sources: + med_to_sources[ms.medication.id] = { + "medication": ms.medication, + "sources": [], + } + med_to_sources[ms.medication.id]["sources"].append(ms.embedding) + + return [ + { + "medication": MedicationSerializer(data["medication"]).data, + "sources": EmbeddingsSerializer(data["sources"], many=True).data, + } + for data in med_to_sources.values() + ] diff --git a/server/api/views/medRules/views.py b/server/api/views/medRules/views.py index f8ff39e1..2fae140b 100644 --- a/server/api/views/medRules/views.py +++ b/server/api/views/medRules/views.py @@ -7,7 +7,6 @@ from ...models.model_medRule import MedRule from .serializers import MedRuleSerializer # You'll need to create this from ..listMeds.models import Medication -from ..listMeds.serializers import MedicationSerializer from ..uploadFile.models import UploadFile from ...models.model_embeddings import Embeddings diff --git a/server/api/views/risk/urls.py b/server/api/views/risk/urls.py index 30e53424..8f0c8dd4 100644 --- a/server/api/views/risk/urls.py +++ b/server/api/views/risk/urls.py @@ -1,6 +1,8 @@ from django.urls import path from api.views.risk import views +from api.views.risk.views_riskWithSources import RiskWithSourcesView urlpatterns = [ - path("chatgpt/risk", views.medication, name="risk") -] + path("chatgpt/risk", views.medication, name="risk"), + path("v1/api/riskWithSources", RiskWithSourcesView.as_view()), +] diff --git a/server/api/views/risk/views.py b/server/api/views/risk/views.py index ea1a77a8..99327a8d 100644 --- a/server/api/views/risk/views.py +++ b/server/api/views/risk/views.py @@ -14,10 +14,11 @@ def medication(request): data = json.loads(request.body) if data is not None: - diagnosis = data["diagnosis"] # the variable name is diagnosis but this variable contain the medication name + # the variable name is diagnosis but this variable contain the medication name + diagnosis = data["diagnosis"] else: return JsonResponse({"error": "Diagnosis not found. Request must include diagnosis."}) - + try: med = Medication.objects.get(name=diagnosis) benefits = [f'- {benefit}' for benefit in med.benefits.split(', ')] diff --git a/server/api/views/risk/views_riskWithSources.py b/server/api/views/risk/views_riskWithSources.py new file mode 100644 index 00000000..0be43dbb --- /dev/null +++ b/server/api/views/risk/views_riskWithSources.py @@ -0,0 +1,465 @@ +from rest_framework.views import APIView +from rest_framework.response import Response +from rest_framework import status +from api.views.listMeds.models import Medication +from api.models.model_medRule import MedRule, MedRuleSource +import openai +import os + + +class RiskWithSourcesView(APIView): + def post(self, request): + openai.api_key = os.environ.get("OPENAI_API_KEY") + + drug = request.data.get("drug") + if not drug: + return Response({"error": "Drug not found. Request must include 'drug'."}, status=status.HTTP_400_BAD_REQUEST) + source = request.data.get("source") + print(f"Requested source: {source}") + + # If source is provided, validate it + valid_sources = ["include", "diagnosis", "diagnosis_depressed", "diagnosis_manic", "diagnosis_hypomanic", "diagnosis_euthymic"] + if source and source not in valid_sources: + return Response({"error": f"Source must be one of: {', '.join(valid_sources)}."}, status=status.HTTP_400_BAD_REQUEST) + + # If no source is provided, return all sources + if not source: + try: + return self._handle_all_sources(drug) + except Exception as e: + print(f"Error in _handle_all_sources: {str(e)}") + return Response({"error": f"Failed to retrieve all sources data: {str(e)}"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + # Handle diagnosis source by linking to medrules + if source in ["diagnosis", "diagnosis_depressed", "diagnosis_manic", "diagnosis_hypomanic", "diagnosis_euthymic"]: + try: + return self._handle_diagnosis_source(drug, source) + except Exception as e: + print(f"Error in _handle_diagnosis_source: {str(e)}") + return Response({"error": f"Failed to retrieve diagnosis data: {str(e)}"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + if source == "include": + try: + return self._handle_include_source(drug) + except Exception as e: + print(f"Error in _handle_include_source: {str(e)}") + return Response({"error": f"Failed to retrieve include data: {str(e)}"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + # Handle include source (existing logic) + try: + med = Medication.objects.get(name=drug) + benefits = [f'- {b.strip()}' for b in med.benefits.split(',')] + risks = [f'- {r.strip()}' for r in med.risks.split(',')] + return Response({ + 'benefits': benefits, + 'risks': risks + }) + + except Medication.DoesNotExist: + prompt = ( + f"You are to provide a concise list of 5 key benefits and 5 key risks " + f"for the medication suggested when taking it for Bipolar. Each point should be short, " + f"clear and be kept under 10 words. Begin the benefits section with !!!benefits!!! and " + f"the risks section with !!!risk!!!. Please provide this information for the medication: {drug}." + ) + + try: + ai_response = openai.ChatCompletion.create( + model="gpt-3.5-turbo", + messages=[{"role": "system", "content": prompt}] + ) + except Exception as e: + return Response({"error": f"OpenAI request failed: {str(e)}"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + content = ai_response['choices'][0]['message']['content'] + + if '!!!benefits!!!' not in content or '!!!risks!!!' not in content: + return Response({"error": "Unexpected format in OpenAI response."}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + benefits_raw = content.split('!!!risks!!!')[0].replace( + '!!!benefits!!!', '').strip() + risks_raw = content.split('!!!risks!!!')[1].strip() + + benefits = [line.strip() + for line in benefits_raw.split('\n') if line.strip()] + risks = [line.strip() + for line in risks_raw.split('\n') if line.strip()] + + return Response({ + 'benefits': benefits, + 'risks': risks + }) + + def _handle_include_source(self, drug): + """Handle include source by looking up medrules for the medication""" + try: + # Get the medication + medication = Medication.objects.get(name=drug) + + print( + f"Found medication '{medication.name}' for '{drug}' with ID {medication.id}") + + # Find medrules that include this medication + medrule_ids = MedRuleSource.objects.filter( + medication=medication, + medrule__rule_type='INCLUDE' + ).values_list('medrule_id', flat=True).distinct() + + medrules = MedRule.objects.filter(id__in=medrule_ids) + print(f"Found {medrules.count()} medrules for {drug}") + benefits = [] + risks = [] + sources_info = [] + + # Extract benefits and sources + for medrule in medrules: + if medrule.explanation: + benefits.append(f"- {medrule.explanation}") + + # Get associated sources through MedRuleSource + medrule_sources = MedRuleSource.objects.filter( + medrule=medrule, + medication=medication + ) + print( + f"Found {medrule_sources.count()} sources for medrule {medrule.id}") + + for source_link in medrule_sources: + embedding = source_link.embedding + + # Get filename from upload_file if available + filename = None + if hasattr(embedding, 'upload_file') and embedding.upload_file: + filename = embedding.upload_file.file_name + + source_info = { + 'filename': filename, + 'title': getattr(embedding, 'title', None), + 'publication': getattr(embedding, 'publication', ''), + 'text': getattr(embedding, 'text', ''), + 'rule_type': medrule.rule_type, + 'history_type': medrule.history_type, + # Add link data for PDF navigation + 'upload_fileid': getattr(embedding, 'upload_file_id', None), + 'page': getattr(embedding, 'page_num', None), + 'link_url': self._build_pdf_link(embedding) + } + + sources_info.append(source_info) + + # Check EXCLUDE rules for risks + exclude_rules = MedRule.objects.filter( + medications=medication, + rule_type='EXCLUDE' + ) + + for rule in exclude_rules: + if rule.explanation: + risks.append(f"- {rule.explanation}") + + if not benefits and not risks: + basic_benefits = [ + f'- {b.strip()}' for b in medication.benefits.split(',')] + basic_risks = [ + f'- {r.strip()}' for r in medication.risks.split(',')] + + return Response({ + 'benefits': basic_benefits, + 'risks': basic_risks, + 'sources': 'include', + 'note': 'No specific medrule sources found, showing general medication information' + }) + + return Response({ + 'benefits': [f'- {b.strip()}' for b in medication.benefits.split(',')], + 'risks': risks if risks else [f'- {r.strip()}' for r in medication.risks.split(',')], + 'sources': sources_info, + 'medrules_found': len(medrules) + len(exclude_rules) + }) + + except Medication.DoesNotExist: + return Response({"error": f"Medication '{drug}' not found."}, status=status.HTTP_404_NOT_FOUND) + + def _handle_diagnosis_source(self, drug, source): + """Handle diagnosis source by looking up medrules for the medication""" + try: + # Get the medication + medication = Medication.objects.get(name=drug) + + print( + f"Found medication '{medication.name}' for '{drug}' with ID {medication.id}") + + # Map source parameter to history_type value + source_to_history_type = { + "diagnosis_depressed": "DIAGNOSIS_DEPRESSED", + "diagnosis_manic": "DIAGNOSIS_MANIC", + "diagnosis_hypomanic": "DIAGNOSIS_HYPOMANIC", + "diagnosis_euthymic": "DIAGNOSIS_EUTHYMIC", + "diagnosis": "DIAGNOSIS_DEPRESSED" # default to depressed for backward compatibility + } + + history_type = source_to_history_type.get(source, "DIAGNOSIS_DEPRESSED") + print(f"Using history_type: {history_type} for source: {source}") + + # Find medrules that include this medication with the specified history type + medrule_ids = MedRuleSource.objects.filter( + medication=medication, + medrule__history_type=history_type + ).values_list('medrule_id', flat=True).distinct() + + print(medrule_ids) + + medrules = MedRule.objects.filter(id__in=medrule_ids) + print(f"Found {medrules.count()} medrules for {drug}") + benefits = [] + risks = [] + sources_info = [] + + # Extract benefits and sources + for medrule in medrules: + print( + f"Medrule {medrule.id}: rule_type={medrule.rule_type}, explanation='{medrule.explanation}'") + + # Only add INCLUDE rules to benefits + if medrule.rule_type == 'INCLUDE' and medrule.explanation: + benefits.append(f"- {medrule.explanation}") + elif medrule.rule_type == 'EXCLUDE' and medrule.explanation: + # Add EXCLUDE rules to risks + risks.append(f"- {medrule.explanation}") + else: + print( + f"Medrule {medrule.id} has no explanation or is not INCLUDE/EXCLUDE, skipping") + + # Get associated sources through MedRuleSource + medrule_sources = MedRuleSource.objects.filter( + medrule=medrule, + medication=medication + ) + print( + f"Found {medrule_sources.count()} sources for medrule {medrule.id}") + + for source_link in medrule_sources: + embedding = source_link.embedding + + # Get filename from upload_file if available + filename = None + if hasattr(embedding, 'upload_file') and embedding.upload_file: + filename = embedding.upload_file.file_name + + source_info = { + 'filename': filename, + 'title': getattr(embedding, 'title', None), + 'publication': getattr(embedding, 'publication', ''), + 'text': getattr(embedding, 'text', ''), + 'rule_type': medrule.rule_type, + 'history_type': medrule.history_type, + # Add link data for PDF navigation + 'upload_fileid': getattr(embedding, 'upload_file_id', None), + 'page': getattr(embedding, 'page_num', None), + 'link_url': self._build_pdf_link(embedding) + } + + sources_info.append(source_info) + + # Check EXCLUDE rules for risks with the specified history type + exclude_rules = MedRule.objects.filter( + medications=medication, + rule_type='EXCLUDE', + history_type=history_type + ) + print(f"Found {exclude_rules.count()} exclude rules for {drug}") + + for rule in exclude_rules: + print( + f"Exclude rule {rule.id}: explanation='{rule.explanation}'") + if rule.explanation: + risks.append(f"- {rule.explanation}") + else: + print( + f"Exclude rule {rule.id} has no explanation, skipping risks") + + print( + f"Total benefits collected: {len(benefits)}, Total risks collected: {len(risks)}") + if not benefits and not risks: + basic_benefits = [ + f'- {b.strip()}' for b in medication.benefits.split(',')] + basic_risks = [ + f'- {r.strip()}' for r in medication.risks.split(',')] + + return Response({ + 'benefits': basic_benefits, + 'risks': basic_risks, + 'sources': sources_info, + 'note': 'No specific medrule sources found, showing general medication information' + }) + + return Response({ + 'benefits': benefits if benefits else [f'- {b.strip()}' for b in medication.benefits.split(',')], + 'risks': risks if risks else [f'- {r.strip()}' for r in medication.risks.split(',')], + 'sources': sources_info, + 'medrules_found': len(medrules) + len(exclude_rules) + }) + + except Medication.DoesNotExist: + return Response({"error": f"Medication '{drug}' not found."}, status=status.HTTP_404_NOT_FOUND) + + def _handle_all_sources(self, drug): + """Handle request with no source specified - return all sources""" + try: + # Get the medication + medication = Medication.objects.get(name=drug) + + print( + f"Found medication '{medication.name}' for '{drug}' with ID {medication.id}") + + # Find all medrules that are related to this medication via MedRuleSource + medrule_ids = MedRuleSource.objects.filter( + medication=medication + ).values_list('medrule_id', flat=True).distinct() + + medrules = MedRule.objects.filter(id__in=medrule_ids) + print( + f"Found {medrules.count()} total medrules for {drug} across all sources") + + benefits = [] + risks = [] + sources_info = [] + + # Extract benefits, risks, and sources from all medrules + for medrule in medrules: + print( + f"Medrule {medrule.id}: rule_type={medrule.rule_type}, history_type={medrule.history_type}") + + # Add INCLUDE rules to benefits + if medrule.rule_type == 'INCLUDE' and medrule.explanation: + benefits.append(f"- {medrule.explanation}") + # Add EXCLUDE rules to risks + elif medrule.rule_type == 'EXCLUDE' and medrule.explanation: + risks.append(f"- {medrule.explanation}") + + # Get associated sources through MedRuleSource + medrule_sources = MedRuleSource.objects.filter( + medrule=medrule, + medication=medication + ) + print( + f"Found {medrule_sources.count()} sources for medrule {medrule.id}") + + for source_link in medrule_sources: + embedding = source_link.embedding + + # Get filename from upload_file if available + filename = None + if hasattr(embedding, 'upload_file') and embedding.upload_file: + filename = embedding.upload_file.file_name + + source_info = { + 'filename': filename, + 'title': getattr(embedding, 'title', None), + 'publication': getattr(embedding, 'publication', ''), + 'text': getattr(embedding, 'text', ''), + 'rule_type': medrule.rule_type, + 'history_type': medrule.history_type, + # Add link data for PDF navigation + 'upload_fileid': getattr(embedding, 'upload_file_id', None), + 'page': getattr(embedding, 'page_num', None), + 'link_url': self._build_pdf_link(embedding) + } + + sources_info.append(source_info) + + print( + f"Total benefits collected: {len(benefits)}, Total risks collected: {len(risks)}, Total sources: {len(sources_info)}") + + # If no medrule-based benefits or risks, fall back to basic medication info + if not benefits and not risks: + basic_benefits = [ + f'- {b.strip()}' for b in medication.benefits.split(',')] + basic_risks = [ + f'- {r.strip()}' for r in medication.risks.split(',')] + + return Response({ + 'benefits': basic_benefits, + 'risks': basic_risks, + 'sources': sources_info, + 'note': 'No specific medrule sources found, showing general medication information' + }) + + return Response({ + 'benefits': benefits if benefits else [f'- {b.strip()}' for b in medication.benefits.split(',')], + 'risks': risks if risks else [f'- {r.strip()}' for r in medication.risks.split(',')], + 'sources': sources_info, + 'medrules_found': len(medrules), + 'source_type': 'all' + }) + + except Medication.DoesNotExist: + return Response({"error": f"Medication '{drug}' not found."}, status=status.HTTP_404_NOT_FOUND) + + def _build_pdf_link(self, embedding): + """Build the PDF viewer link URL by getting the document GUID from UploadFile""" + try: + # Get the upload_fileid from the embedding + upload_fileid = getattr(embedding, 'upload_file_id', None) + page = getattr(embedding, 'page_num', None) + + if not upload_fileid: + return None + + from api.views.uploadFile.models import UploadFile + + # Get the UploadFile record to get the document GUID + upload_file = UploadFile.objects.get(id=upload_fileid) + document_guid = upload_file.guid + + if document_guid: + base_url = "/drugsummary" + if page: + return f"{base_url}?guid={document_guid}&page={page}" + else: + return f"{base_url}?guid={document_guid}" + + except Exception as e: + print(f"Error building PDF link: {e}") + return None + + return None + + def _get_ai_response_for_diagnosis(self, drug): + """Get AI response with diagnosis-specific context""" + prompt = ( + f"You are providing medication information from a diagnosis/clinical perspective. " + f"Provide a concise list of 5 key benefits and 5 key risks for the medication {drug} " + f"when prescribed for Bipolar disorder, focusing on clinical evidence and diagnostic considerations. " + f"Each point should be short, clear and be kept under 10 words. " + f"Begin the benefits section with !!!benefits!!! and the risks section with !!!risk!!!." + ) + + try: + ai_response = openai.ChatCompletion.create( + model="gpt-3.5-turbo", + messages=[{"role": "system", "content": prompt}] + ) + except Exception as e: + return Response({"error": f"OpenAI request failed: {str(e)}"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + content = ai_response['choices'][0]['message']['content'] + + if '!!!benefits!!!' not in content or '!!!risks!!!' not in content: + return Response({"error": "Unexpected format in OpenAI response."}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + benefits_raw = content.split('!!!risks!!!')[0].replace( + '!!!benefits!!!', '').strip() + risks_raw = content.split('!!!risks!!!')[1].strip() + + benefits = [line.strip() + for line in benefits_raw.split('\n') if line.strip()] + risks = [line.strip() + for line in risks_raw.split('\n') if line.strip()] + + return Response({ + 'benefits': benefits, + 'risks': risks, + 'source': 'diagnosis', + 'note': 'Generated from AI with diagnosis context - medication not found in database' + }) diff --git a/server/api/views/text_extraction/views.py b/server/api/views/text_extraction/views.py index e0110a8e..e4122851 100644 --- a/server/api/views/text_extraction/views.py +++ b/server/api/views/text_extraction/views.py @@ -1,5 +1,7 @@ import os -from ...services.openai_services import openAIServices +import json +import re + from rest_framework.views import APIView from rest_framework.permissions import IsAuthenticated from rest_framework.response import Response @@ -7,8 +9,8 @@ from django.utils.decorators import method_decorator from django.views.decorators.csrf import csrf_exempt import anthropic -import json -import re + +from ...services.openai_services import openAIServices from api.models.model_embeddings import Embeddings USER_PROMPT = """ @@ -30,11 +32,23 @@ """ -# TODO: Add docstrings and type hints -def anthropic_citations(client, user_prompt, content_chunks): +def anthropic_citations(client: anthropic.Client, user_prompt: str, content_chunks: list) -> tuple: """ + Sends a message to Anthropic Citations and extract and format the response + + Parameters + ---------- + client: An instance of the Anthropic API client used to make the request + user_prompt: The user's question or instruction to be processed by the model + content_chunks: A list of text chunks that provide context for the model to use during generation + + Returns + ------- + tuple + """ + message = client.messages.create( model="claude-3-5-haiku-20241022", max_tokens=1024, @@ -92,6 +106,7 @@ def get(self, request): query = Embeddings.objects.filter(upload_file__guid=guid) + # TODO: Format into the Anthropic API"s expected input format in the anthropic_citations function chunks = [{"type": "text", "text": chunk.text} for chunk in query] texts, cited_texts = anthropic_citations(client, USER_PROMPT, chunks) diff --git a/server/api/views/uploadFile/views.py b/server/api/views/uploadFile/views.py index 43b34592..6904e061 100644 --- a/server/api/views/uploadFile/views.py +++ b/server/api/views/uploadFile/views.py @@ -3,12 +3,8 @@ from rest_framework.response import Response from rest_framework import status from rest_framework.generics import UpdateAPIView -from django.utils.decorators import method_decorator -from django.views.decorators.csrf import csrf_exempt import pdfplumber from .models import UploadFile # Import your UploadFile model -from django.core.files.base import ContentFile -import os from .serializers import UploadFileSerializer from django.http import HttpResponse from ...services.sentencetTransformer_model import TransformerModel @@ -18,9 +14,7 @@ from .title import generate_title -@method_decorator(csrf_exempt, name='dispatch') class UploadFileView(APIView): - permission_classes = [IsAuthenticated] def get(self, request, format=None): print("UploadFileView, get list") @@ -118,6 +112,7 @@ def post(self, request, format=None): Embeddings.objects.create( upload_file=new_file, name=new_file.file_name, # You may adjust the naming convention + title=title, # Set the title from the document text=chunk, chunk_number=i, page_num=page_num, # Store the page number here @@ -160,7 +155,6 @@ def delete(self, request, format=None): return Response({"message": f"Error deleting file and embeddings: {str(e)}"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) -@method_decorator(csrf_exempt, name='dispatch') class RetrieveUploadFileView(APIView): permission_classes = [IsAuthenticated] @@ -176,7 +170,6 @@ def get(self, request, guid, format=None): return Response({"message": "No file found or access denied."}, status=status.HTTP_404_NOT_FOUND) -@method_decorator(csrf_exempt, name='dispatch') class EditFileMetadataView(UpdateAPIView): permission_classes = [IsAuthenticated] serializer_class = UploadFileSerializer diff --git a/server/balancer_backend/settings.py b/server/balancer_backend/settings.py index 91a53a8b..df62d198 100644 --- a/server/balancer_backend/settings.py +++ b/server/balancer_backend/settings.py @@ -29,57 +29,56 @@ # Fetching the value from the environment and splitting to list if necessary. # Fallback to '*' if the environment variable is not set. -ALLOWED_HOSTS = os.environ.get('DJANGO_ALLOWED_HOSTS', '*').split() +ALLOWED_HOSTS = os.environ.get("DJANGO_ALLOWED_HOSTS", "*").split() # If the environment variable contains '*', the split method would create a list with an empty string. # So you need to check for this case and adjust accordingly. -if ALLOWED_HOSTS == ['*'] or ALLOWED_HOSTS == ['']: - ALLOWED_HOSTS = ['*'] +if ALLOWED_HOSTS == ["*"] or ALLOWED_HOSTS == [""]: + ALLOWED_HOSTS = ["*"] # Application definition INSTALLED_APPS = [ - 'django.contrib.admin', - 'django.contrib.auth', - 'django.contrib.contenttypes', - 'django.contrib.sessions', - 'django.contrib.messages', - 'django.contrib.staticfiles', - 'balancer_backend', - 'api', - 'corsheaders', - 'rest_framework', - 'djoser', + "django.contrib.admin", + "django.contrib.auth", + "django.contrib.contenttypes", + "django.contrib.sessions", + "django.contrib.messages", + "django.contrib.staticfiles", + "balancer_backend", + "api", + "corsheaders", + "rest_framework", + "djoser", ] MIDDLEWARE = [ - 'django.middleware.security.SecurityMiddleware', - 'django.contrib.sessions.middleware.SessionMiddleware', - 'django.middleware.common.CommonMiddleware', - 'django.middleware.csrf.CsrfViewMiddleware', - 'django.contrib.auth.middleware.AuthenticationMiddleware', - 'django.contrib.messages.middleware.MessageMiddleware', - 'django.middleware.clickjacking.XFrameOptionsMiddleware', - 'corsheaders.middleware.CorsMiddleware', - + "django.middleware.security.SecurityMiddleware", + "django.contrib.sessions.middleware.SessionMiddleware", + "django.middleware.common.CommonMiddleware", + "django.middleware.csrf.CsrfViewMiddleware", + "django.contrib.auth.middleware.AuthenticationMiddleware", + "django.contrib.messages.middleware.MessageMiddleware", + "django.middleware.clickjacking.XFrameOptionsMiddleware", + "corsheaders.middleware.CorsMiddleware", ] -ROOT_URLCONF = 'balancer_backend.urls' +ROOT_URLCONF = "balancer_backend.urls" CORS_ALLOW_ALL_ORIGINS = True TEMPLATES = [ { - 'BACKEND': 'django.template.backends.django.DjangoTemplates', - 'DIRS': [os.path.join(BASE_DIR, 'build')], - 'APP_DIRS': True, - 'OPTIONS': { - 'context_processors': [ - 'django.template.context_processors.debug', - 'django.template.context_processors.request', - 'django.contrib.auth.context_processors.auth', - 'django.contrib.messages.context_processors.messages', + "BACKEND": "django.template.backends.django.DjangoTemplates", + "DIRS": [os.path.join(BASE_DIR, "build")], + "APP_DIRS": True, + "OPTIONS": { + "context_processors": [ + "django.template.context_processors.debug", + "django.template.context_processors.request", + "django.contrib.auth.context_processors.auth", + "django.contrib.messages.context_processors.messages", ], }, }, @@ -89,7 +88,7 @@ # Change this to your desired URL LOGIN_REDIRECT_URL = os.environ.get("LOGIN_REDIRECT_URL") -WSGI_APPLICATION = 'balancer_backend.wsgi.application' +WSGI_APPLICATION = "balancer_backend.wsgi.application" # Database @@ -106,8 +105,8 @@ } } -EMAIL_BACKEND = 'django.core.mail.backends.smtp.EmailBackend' -EMAIL_HOST = 'smtp.gmail.com' +EMAIL_BACKEND = "django.core.mail.backends.smtp.EmailBackend" +EMAIL_HOST = "smtp.gmail.com" EMAIL_PORT = 587 EMAIL_HOST_USER = os.environ.get("EMAIL_HOST_USER", "") EMAIL_HOST_PASSWORD = os.environ.get("EMAIL_HOST_PASSWORD", "") @@ -119,25 +118,25 @@ AUTH_PASSWORD_VALIDATORS = [ { - 'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator', + "NAME": "django.contrib.auth.password_validation.UserAttributeSimilarityValidator", }, { - 'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator', + "NAME": "django.contrib.auth.password_validation.MinimumLengthValidator", }, { - 'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator', + "NAME": "django.contrib.auth.password_validation.CommonPasswordValidator", }, { - 'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator', + "NAME": "django.contrib.auth.password_validation.NumericPasswordValidator", }, ] # Internationalization # https://docs.djangoproject.com/en/4.2/topics/i18n/ -LANGUAGE_CODE = 'en-us' +LANGUAGE_CODE = "en-us" -TIME_ZONE = 'UTC' +TIME_ZONE = "UTC" USE_I18N = True @@ -147,64 +146,90 @@ # Static files (CSS, JavaScript, Images) # https://docs.djangoproject.com/en/4.2/howto/static-files/ -STATIC_URL = '/static/' +STATIC_URL = "/static/" STATICFILES_DIRS = [ - os.path.join(BASE_DIR, 'build/static'), + os.path.join(BASE_DIR, "build/static"), ] -STATIC_ROOT = os.path.join(BASE_DIR, 'static') +STATIC_ROOT = os.path.join(BASE_DIR, "static") AUTHENTICATION_BACKENDS = [ - 'django.contrib.auth.backends.ModelBackend', + "django.contrib.auth.backends.ModelBackend", ] REST_FRAMEWORK = { - 'DEFAULT_PERMISSION_CLASSES': [ - 'rest_framework.permissions.IsAuthenticated' - ], - 'DEFAULT_AUTHENTICATION_CLASSES': ( - 'rest_framework_simplejwt.authentication.JWTAuthentication', + "DEFAULT_PERMISSION_CLASSES": ["rest_framework.permissions.IsAuthenticated"], + "DEFAULT_AUTHENTICATION_CLASSES": ( + "rest_framework_simplejwt.authentication.JWTAuthentication", ), } SIMPLE_JWT = { - 'AUTH_HEADER_TYPES': ('JWT',), - 'ACCESS_TOKEN_LIFETIME': timedelta(minutes=60), - 'REFRESH_TOKEN_LIFETIME': timedelta(days=1), - 'TOKEN_OBTAIN_SERIALIZER': 'api.models.TokenObtainPairSerializer.MyTokenObtainPairSerializer', - 'AUTH_TOKEN_CLASSES': ( - 'rest_framework_simplejwt.tokens.AccessToken', - ) + "AUTH_HEADER_TYPES": ("JWT",), + "ACCESS_TOKEN_LIFETIME": timedelta(minutes=60), + "REFRESH_TOKEN_LIFETIME": timedelta(days=1), + "TOKEN_OBTAIN_SERIALIZER": "api.models.TokenObtainPairSerializer.MyTokenObtainPairSerializer", + "AUTH_TOKEN_CLASSES": ("rest_framework_simplejwt.tokens.AccessToken",), } DJOSER = { - 'LOGIN_FIELD': 'email', - 'USER_CREATE_PASSWORD_RETYPE': True, - 'USERNAME_CHANGED_EMAIL_CONFIRMATION': True, - 'PASSWORD_CHANGED_EMAIL_CONFIRMATION': True, - 'SEND_CONFIRMATION_EMAIL': True, - 'SET_USERNAME_RETYPE': True, - 'SET_PASSWORD_RETYPE': True, - 'PASSWORD_RESET_CONFIRM_URL': 'password/reset/confirm/{uid}/{token}', - 'USERNAME_RESET_CONFIRM_URL': 'email/reset/confirm/{uid}/{token}', - 'ACTIVATION_URL': 'activate/{uid}/{token}', - 'SEND_ACTIVATION_EMAIL': True, - 'SOCIAL_AUTH_TOKEN_STRATEGY': 'djoser.social.token.jwt.TokenStrategy', - 'SOCIAL_AUTH_ALLOWED_REDIRECT_URIS': ['http://localhost:8000/google', 'http://localhost:8000/facebook'], - 'SERIALIZERS': { - 'user_create': 'api.models.serializers.UserCreateSerializer', - 'user': 'api.models.serializers.UserCreateSerializer', - 'current_user': 'api.models.serializers.UserCreateSerializer', - 'user_delete': 'djoser.serializers.UserDeleteSerializer', - } + "LOGIN_FIELD": "email", + "USER_CREATE_PASSWORD_RETYPE": True, + "USERNAME_CHANGED_EMAIL_CONFIRMATION": True, + "PASSWORD_CHANGED_EMAIL_CONFIRMATION": True, + "SEND_CONFIRMATION_EMAIL": True, + "SET_USERNAME_RETYPE": True, + "SET_PASSWORD_RETYPE": True, + "PASSWORD_RESET_CONFIRM_URL": "password/reset/confirm/{uid}/{token}", + "USERNAME_RESET_CONFIRM_URL": "email/reset/confirm/{uid}/{token}", + "ACTIVATION_URL": "activate/{uid}/{token}", + "SEND_ACTIVATION_EMAIL": True, + "SOCIAL_AUTH_TOKEN_STRATEGY": "djoser.social.token.jwt.TokenStrategy", + "SOCIAL_AUTH_ALLOWED_REDIRECT_URIS": [ + "http://localhost:8000/google", + "http://localhost:8000/facebook", + ], + "SERIALIZERS": { + "user_create": "api.models.serializers.UserCreateSerializer", + "user": "api.models.serializers.UserCreateSerializer", + "current_user": "api.models.serializers.UserCreateSerializer", + "user_delete": "djoser.serializers.UserDeleteSerializer", + }, } # Default primary key field type # https://docs.djangoproject.com/en/4.2/ref/settings/#default-auto-field -DEFAULT_AUTO_FIELD = 'django.db.models.BigAutoField' - - -AUTH_USER_MODEL = 'api.UserAccount' +DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField" + + +AUTH_USER_MODEL = "api.UserAccount" + +# Logging configuration + +# LOGGING = { +# "version": 1, +# "disable_existing_loggers": False, +# "formatters": { +# "verbose": { +# "format": "{levelname} {asctime} {module} {process:d} {thread:d} {message}", +# "style": "{", +# }, +# "simple": { +# "format": "{levelname} {message}", +# "style": "{", +# }, +# }, +# "handlers": { +# "console": { +# "class": "logging.StreamHandler", +# "formatter": "verbose", +# }, +# }, +# "root": { +# "handlers": ["console"], +# "level": "INFO", +# }, +# } diff --git a/server/balancer_backend/urls.py b/server/balancer_backend/urls.py index 1c5bad8b..56f307e4 100644 --- a/server/balancer_backend/urls.py +++ b/server/balancer_backend/urls.py @@ -1,6 +1,8 @@ from django.contrib import admin # Import Django's admin interface module + # Import functions for URL routing and including other URL configs from django.urls import path, include, re_path + # Import TemplateView for rendering templates from django.views.generic import TemplateView import importlib # Import the importlib module for dynamic module importing @@ -10,25 +12,37 @@ # Map 'admin/' URL to the Django admin interface path("admin/", admin.site.urls), # Include Djoser's URL patterns under 'auth/' for basic auth - path('auth/', include('djoser.urls')), + path("auth/", include("djoser.urls")), # Include Djoser's JWT auth URL patterns under 'auth/' - path('auth/', include('djoser.urls.jwt')), + path("auth/", include("djoser.urls.jwt")), # Include Djoser's social auth URL patterns under 'auth/' - path('auth/', include('djoser.social.urls')), + path("auth/", include("djoser.social.urls")), ] # List of application names for which URL patterns will be dynamically added -urls = ['conversations', 'feedback', 'listMeds', 'risk', - 'uploadFile', 'ai_promptStorage', 'ai_settings', 'embeddings', 'medRules', 'text_extraction'] +urls = [ + "conversations", + "feedback", + "listMeds", + "risk", + "uploadFile", + "ai_promptStorage", + "ai_settings", + "embeddings", + "medRules", + "text_extraction", + "assistant", +] # Loop through each application name and dynamically import and add its URL patterns for url in urls: # Dynamically import the URL module for each app - url_module = importlib.import_module(f'api.views.{url}.urls') + url_module = importlib.import_module(f"api.views.{url}.urls") # Append the URL patterns from each imported module - urlpatterns += getattr(url_module, 'urlpatterns', []) + urlpatterns += getattr(url_module, "urlpatterns", []) # Add a catch-all URL pattern for handling SPA (Single Page Application) routing # Serve 'index.html' for any unmatched URL urlpatterns += [ - re_path(r'^.*$', TemplateView.as_view(template_name='index.html')),] + re_path(r"^.*$", TemplateView.as_view(template_name="index.html")), +] diff --git a/server/entrypoint.sh b/server/entrypoint.sh index 1651e065..2d2c872f 100755 --- a/server/entrypoint.sh +++ b/server/entrypoint.sh @@ -11,8 +11,8 @@ then echo "PostgreSQL started" fi -python manage.py makemigrations api --no-input -# python manage.py flush --no-input +python manage.py makemigrations api +# # python manage.py flush --no-input python manage.py migrate # create superuser for postgre admin on start up python manage.py createsu