# Statistician Agent with Strands
In this notebook, we create the statistician agent using the open-source Strands Agents framework.

#### Install Strands agents and required dependencies

In [None]:
%pip install strands-agents strands-agents-tools boto3 requests --quiet

#### Check the installed boto3 version
The boto3 version printed below must be **1.39.3** or higher.

In [None]:
%pip show boto3
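
Reading the version off the `%pip show` output by eye is easy to get wrong, because comparing version strings directly sorts `1.9.10` after `1.39.3`. The sketch below compares numeric tuples instead. The helper names `parse_version` and `version_at_least` are illustrative and not part of the notebook.

```python
import itertools
from importlib.metadata import PackageNotFoundError, version
from typing import Tuple

REQUIRED_BOTO3 = "1.39.3"  # minimum version stated in this notebook


def parse_version(text: str) -> Tuple[int, ...]:
    """Turn '1.39.3' into (1, 39, 3), keeping only the leading digits of each part."""
    parts = []
    for piece in text.split("."):
        digits = "".join(itertools.takewhile(str.isdigit, piece))
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def version_at_least(installed: str, required: str) -> bool:
    """Compare versions numerically rather than as strings."""
    return parse_version(installed) >= parse_version(required)


try:
    installed_boto3 = version("boto3")
    print(f"boto3 {installed_boto3} meets requirement: "
          f"{version_at_least(installed_boto3, REQUIRED_BOTO3)}")
except PackageNotFoundError:
    print("boto3 is not installed in this kernel")
```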

#### Import required libraries

In [None]:
from utils.magic_helper import register_cell_magic
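
The cells that follow start with `%%write_and_run`, a custom cell magic registered by the import above. Its implementation lives in `utils/magic_helper.py` and is not shown in this notebook. Judging from its name and the `-a` flag used later, it presumably saves the cell body to the named file (appending with `-a`) and then executes it. The plain-Python sketch below illustrates that assumed behavior only; the real helper may differ.

```python
import os
import tempfile
from typing import Any, Dict


def write_and_run(path: str, source: str, append: bool, namespace: Dict[str, Any]) -> None:
    """Save `source` to `path` (appending when `append` is True), then execute it."""
    mode = "a" if append else "w"
    with open(path, mode, encoding="utf-8") as handle:
        handle.write(source.rstrip("\n") + "\n")
    # Run the same source in the shared namespace, as a notebook cell would.
    exec(compile(source, path, "exec"), namespace)


# Usage: two "cells" accumulate into one script while sharing a namespace.
workdir = tempfile.mkdtemp()
script = os.path.join(workdir, "demo_agent.py")
ns: Dict[str, Any] = {}
write_and_run(script, "x = 2", append=False, namespace=ns)
write_and_run(script, "y = x * 21", append=True, namespace=ns)

with open(script, encoding="utf-8") as handle:
    script_text = handle.read()
```

Under this reading, the cells in this notebook gradually build `statistician_agent.py`, which can then be imported or run outside the notebook.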

In [None]:
%%write_and_run statistician_agent.py

import boto3
import json
import uuid
import requests
from typing import Dict, Any, List
from strands import Agent, tool
from strands.models import BedrockModel

# Get AWS account information
sts_client = boto3.client('sts')
account_id = sts_client.get_caller_identity()['Account']
region = boto3.Session().region_name
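
`boto3.Session().region_name` returns `None` when no default region is configured, and the Lambda ARNs built later would then contain the literal text `None`. The sketch below shows one way to fail early instead. The function `resolve_region` and the fallback to the `AWS_REGION` environment variable are assumptions added for illustration.

```python
import os
from typing import Mapping, Optional


def resolve_region(session_region: Optional[str],
                   env: Mapping[str, str] = os.environ) -> str:
    """Return a usable region name or raise a clear error."""
    # Prefer the region boto3 already resolved, then the AWS_REGION variable.
    region_name = session_region or env.get("AWS_REGION")
    if not region_name:
        raise RuntimeError(
            "No AWS region configured; set one with `aws configure` "
            "or export AWS_REGION before running the notebook."
        )
    return region_name


# Usage with explicit inputs (no AWS call is made here).
demo_region = resolve_region(None, env={"AWS_REGION": "eu-west-1"})
```

In the notebook, the call would be `region = resolve_region(boto3.Session().region_name)`.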

# Prerequisites

This notebook assumes that you have deployed the CloudFormation stack located at https://github.com/aws-samples/amazon-bedrock-agents-cancer-biomarker-discovery to your AWS account in workshop mode.

#### Setup Lambda function configuration
We'll reuse the existing Lambda functions for statistical operations.

In [None]:
%%write_and_run -a statistician_agent.py

# Lambda function configurations (reusing existing infrastructure)
bar_chart_lambda_function_name = "MatPlotBarChartLambda"  # Change if different in your account
scientific_plot_lambda_function_name = "ScientificPlotLambda"  # Change if different in your account

bar_chart_lambda_function_arn = f"arn:aws:lambda:{region}:{account_id}:function:{bar_chart_lambda_function_name}"
scientific_plot_lambda_function_arn = f"arn:aws:lambda:{region}:{account_id}:function:{scientific_plot_lambda_function_name}"

# Initialize AWS clients
lambda_client = boto3.client('lambda', region_name=region)
bedrock_client = boto3.client('bedrock-runtime', region_name=region)

print(f"Bar Chart Lambda ARN: {bar_chart_lambda_function_arn}")
print(f"Scientific Plot Lambda ARN: {scientific_plot_lambda_function_arn}")
print(f"Region: {region}")
print(f"Account ID: {account_id}")
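
Because the function names above may differ in your account, it can help to confirm that both functions exist before creating the agent. The sketch below wraps the Lambda client's `get_function` call. The helper `check_lambda_functions` and the fake client used for the demonstration are illustrative assumptions; the fake client stands in for AWS so the example runs without credentials.

```python
from typing import Any, Dict, Iterable


def check_lambda_functions(client: Any, function_names: Iterable[str]) -> Dict[str, bool]:
    """Return a name -> reachable mapping for the given Lambda functions."""
    status: Dict[str, bool] = {}
    for name in function_names:
        try:
            client.get_function(FunctionName=name)
            status[name] = True
        except Exception as exc:  # boto3 raises ClientError subclasses here
            print(f"{name}: not reachable ({exc})")
            status[name] = False
    return status


class _FakeLambdaClient:
    """Stand-in for boto3's Lambda client so the sketch runs offline."""

    def __init__(self, existing):
        self._existing = set(existing)

    def get_function(self, FunctionName):
        if FunctionName not in self._existing:
            raise LookupError(f"Function not found: {FunctionName}")
        return {"Configuration": {"FunctionName": FunctionName}}


demo_status = check_lambda_functions(
    _FakeLambdaClient({"MatPlotBarChartLambda"}),
    ["MatPlotBarChartLambda", "ScientificPlotLambda"],
)
```

In the notebook, pass the real client and names: `check_lambda_functions(lambda_client, [bar_chart_lambda_function_name, scientific_plot_lambda_function_name])`.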

# Strands Agent Creation
In this section, we create the statistician agent using the Strands framework.

#### Define agent configuration and instructions
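
The instructions defined in the next cell ask the agent to encode survival status as 0 (Alive) or 1 (Dead) and to split patients into baseline and condition groups by an expression threshold. The sketch below shows that preprocessing in plain Python. The function name `split_by_threshold` is hypothetical, and treating values above the threshold as the condition group is an assumed convention; swap the comparison if your study defines the groups the other way round.

```python
from typing import Dict, List, Sequence

STATUS_TO_EVENT = {"alive": 0, "dead": 1}  # encoding required by the instructions


def split_by_threshold(statuses: Sequence[str],
                       durations: Sequence[float],
                       expressions: Sequence[float],
                       threshold: float) -> Dict[str, List]:
    """Group patients into baseline/condition lists keyed like the plot tool's arguments."""
    if not (len(statuses) == len(durations) == len(expressions)):
        raise ValueError("statuses, durations and expressions must have equal length")
    groups: Dict[str, List] = {
        "duration_baseline": [], "event_baseline": [],
        "duration_condition": [], "event_condition": [],
    }
    for status, duration, expression in zip(statuses, durations, expressions):
        event = STATUS_TO_EVENT[status.strip().lower()]
        # Assumption: expression above the threshold means "condition".
        side = "condition" if expression > threshold else "baseline"
        groups[f"duration_{side}"].append(float(duration))
        groups[f"event_{side}"].append(event)
    return groups


demo_groups = split_by_threshold(
    statuses=["Alive", "Dead", "Dead", "Alive"],
    durations=[365, 180, 200, 600],
    expressions=[2.1, 9.4, 7.7, 1.3],
    threshold=5.0,
)
```

The four resulting lists line up with the `duration_*` and `event_*` arguments of the `plot_kaplan_meier` tool defined later in the notebook.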

In [None]:
%%write_and_run -a statistician_agent.py

statistician_agent_name = 'Statistician-strands'
statistician_agent_description = "scientific analysis for survival analysis using Strands framework"
statistician_agent_instruction = """You are a medical research assistant AI specialized in survival analysis with biomarkers. 
Your primary job is to interpret user queries, run scientific analysis tasks, and provide relevant medical insights 
with available visualization tools. Use only the appropriate tools as required by the specific question. 
Follow these instructions carefully: 

1. If the user query requires a Kaplan-Meier chart: 
   a. Map survival status as 0 for Alive and 1 for Dead for the event parameter. 
   b. Use survival duration as the duration parameter. 
   c. Use the plot_kaplan_meier tool, passing baseline and condition groups split by the expression value threshold provided by the user. 

2. If a survival regression analysis is needed: 
   a. You need access to all records, with survival status as the first column, then survival duration, and then the required biomarkers. 
   b. Use the fit_survival_regression tool to identify the best-performing biomarker based on the p-value summary. 
   c. Ask for S3 data location if not provided, do not assume S3 bucket names or object names. 

3. When you need to create a bar chart or plot: 
   a. Always pass x_values and y_values in Array type to the function. 
   If the user says x values are apple,egg and y values are 3,4 or as [apple,egg] and [3,4] pass their value as 
   ['apple', 'egg'] and [3,4] 

4. When providing your response: 
   a. Start with a brief summary of your understanding of the user's query. 
   b. Explain the steps you're taking to address the query. 
   Ask for clarifications from the user if required. 
   c. If you generate any charts or perform statistical analyses, 
   explain their significance in the context of the user's query. 
   d. Conclude with a concise summary of the findings and their potential implications for medical research. 
   e. Make sure to explain any medical or statistical concepts in a clear, accessible manner.
""" 

#### Define tools for Strands agent
These tools invoke the existing Lambda functions to perform chart and survival analysis operations for the agent.
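
The Lambda functions were written for Bedrock Agents, so each call must be wrapped in an event whose `parameters` entry is a list of `name`/`type`/`value` records. The helper in the next cell spells out one event per operation. The sketch below shows the same shape built generically. The function names are illustrative, and encoding every list with `json.dumps` is an assumption: the notebook's own helper uses `str()` for some arrays, and which encoding the Lambda accepts depends on its parsing code, which is not shown here.

```python
import json
from typing import Any, Dict, List


def to_agent_parameters(payload: Dict[str, Any]) -> List[Dict[str, str]]:
    """Convert a plain dict into Bedrock-Agent-style parameter records."""
    parameters = []
    for name, value in payload.items():
        if isinstance(value, (list, tuple)):
            # Arrays travel as strings; JSON encoding is assumed here.
            parameters.append({"name": name, "type": "array",
                               "value": json.dumps(list(value))})
        else:
            parameters.append({"name": name, "type": "string", "value": str(value)})
    return parameters


def build_agent_event(action_group: str, function: str,
                      payload: Dict[str, Any]) -> Dict[str, Any]:
    """Assemble the event envelope used throughout this notebook."""
    return {
        "agent": "strands-agent",
        "actionGroup": action_group,
        "function": function,
        "messageVersion": "1.0",
        "parameters": to_agent_parameters(payload),
        "sessionAttributes": {},
        "promptSessionAttributes": {},
    }


demo_event = build_agent_event(
    "matplotbarchart", "bar_chart",
    {"title": "Demo", "x_values": ["apple", "egg"], "y_values": [3, 4]},
)
```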

In [None]:
%%write_and_run -a statistician_agent.py

def invoke_lambda_function(function_arn: str, operation: str, payload: Dict[str, Any] = None) -> Dict[str, Any]:
    """
    Helper function to invoke the existing Lambda functions with Bedrock Agent compatible event structure
    """
    if payload is None:
        payload = {}
    
    # Prepare the event payload to match what the Lambda function expects from Bedrock Agents
    if operation == 'bar_chart':
        event = {
            'agent': 'strands-agent',
            'actionGroup': 'matplotbarchart',
            'function': 'bar_chart',
            'messageVersion': '1.0',
            'parameters': [
                {
                    'name': 'title',
                    'type': 'string',
                    'value': payload.get('title', '')
                },
                {
                    'name': 'x_label',
                    'type': 'string',
                    'value': payload.get('x_label', '')
                },
                {
                    'name': 'x_values',
                    'type': 'array',
                    'value': json.dumps(payload.get('x_values', []))  # Convert to JSON string
                },
                {
                    'name': 'y_label',
                    'type': 'string',
                    'value': payload.get('y_label', '')
                },
                {
                    'name': 'y_values',
                    'type': 'array',
                    'value': str(payload.get('y_values', []))  # Convert to string representation
                }
            ],
            'sessionAttributes': {},
            'promptSessionAttributes': {}
        }
    elif operation == 'plot_kaplan_meier':
        event = {
            'agent': 'strands-agent',
            'actionGroup': 'scientificAnalysisActionGroup',
            'function': 'plot_kaplan_meier',
            'messageVersion': '1.0',
            'parameters': [
                {
                    'name': 'biomarker_name',
                    'type': 'string',
                    'value': payload.get('biomarker_name', '')
                },
                {
                    'name': 'duration_baseline',
                    'type': 'array',
                    'value': str(payload.get('duration_baseline', []))  # Convert to string
                },
                {
                    'name': 'duration_condition',
                    'type': 'array',
                    'value': str(payload.get('duration_condition', []))  # Convert to string
                },
                {
                    'name': 'event_baseline',
                    'type': 'array',
                    'value': str(payload.get('event_baseline', []))  # Convert to string
                },
                {
                    'name': 'event_condition',
                    'type': 'array',
                    'value': str(payload.get('event_condition', []))  # Convert to string
                }
            ],
            'sessionAttributes': {},
            'promptSessionAttributes': {}
        }
    elif operation == 'fit_survival_regression':
        event = {
            'agent': 'strands-agent',
            'actionGroup': 'scientificAnalysisActionGroup',
            'function': 'fit_survival_regression',
            'messageVersion': '1.0',
            'parameters': [
                {
                    'name': 'bucket',
                    'type': 'string',
                    'value': payload.get('bucket', '')
                },
                {
                    'name': 'key',
                    'type': 'string',
                    'value': payload.get('key', '')
                }
            ],
            'sessionAttributes': {},
            'promptSessionAttributes': {}
        }
    else:
        raise ValueError(f"Unknown operation: {operation}")
    
    try:
        response = lambda_client.invoke(
            FunctionName=function_arn,
            InvocationType='RequestResponse',
            Payload=json.dumps(event)
        )
        
        result = json.loads(response['Payload'].read())
        
        # Extract the actual result from the response
        if isinstance(result, dict) and 'response' in result:
            return result['response']
        
        return result
        
    except Exception as e:
        return {"error": str(e)}

# Define the tools using Strands @tool decorator
@tool
def create_bar_chart(title: str, x_label: str, x_values: List[str], y_label: str, y_values: List[float]) -> str:
    """
    Create a bar chart with the specified parameters.
    
    Args:
        title (str): Title of the bar chart
        x_label (str): Label for the x-axis
        x_values (List[str]): Values for the x-axis (categories)
        y_label (str): Label for the y-axis
        y_values (List[float]): Values for the y-axis (numerical data)
    
    Returns:
        str: Result of the bar chart creation
    """
    payload = {
        'title': title,
        'x_label': x_label,
        'x_values': x_values,
        'y_label': y_label,
        'y_values': y_values
    }
    
    print(f"\nBar Chart Input: {json.dumps(payload, indent=2)}\n")
    result = invoke_lambda_function(bar_chart_lambda_function_arn, 'bar_chart', payload)
    print(f"\nBar Chart Output: {json.dumps(result, indent=2)}\n")
    return json.dumps(result, indent=2)

@tool
def plot_kaplan_meier(biomarker_name: str, duration_baseline: List[float], duration_condition: List[float], 
                     event_baseline: List[int], event_condition: List[int]) -> str:
    """
    Create a Kaplan-Meier survival plot for comparing two groups.
    
    Args:
        biomarker_name (str): Name of the biomarker being analyzed
        duration_baseline (List[float]): Survival duration in days for baseline group
        duration_condition (List[float]): Survival duration in days for condition group
        event_baseline (List[int]): Survival events for baseline (0=alive, 1=dead)
        event_condition (List[int]): Survival events for condition (0=alive, 1=dead)
    
    Returns:
        str: Result of the Kaplan-Meier plot creation
    """
    payload = {
        'biomarker_name': biomarker_name,
        'duration_baseline': duration_baseline,
        'duration_condition': duration_condition,
        'event_baseline': event_baseline,
        'event_condition': event_condition
    }
    
    print(f"\nKaplan-Meier Input: {json.dumps(payload, indent=2)}\n")
    result = invoke_lambda_function(scientific_plot_lambda_function_arn, 'plot_kaplan_meier', payload)
    print(f"\nKaplan-Meier Output: {json.dumps(result, indent=2)}\n")
    return json.dumps(result, indent=2)

@tool
def fit_survival_regression(bucket: str, key: str) -> str:
    """
    Fit a survival regression model using data from an S3 object.
    
    Args:
        bucket (str): S3 bucket where the data is stored
        key (str): JSON file name in the S3 bucket containing the data for model fitting
    
    Returns:
        str: Results of the survival regression analysis
    """
    payload = {
        'bucket': bucket,
        'key': key
    }
    
    print(f"\nSurvival Regression Input: {json.dumps(payload, indent=2)}\n")
    result = invoke_lambda_function(scientific_plot_lambda_function_arn, 'fit_survival_regression', payload)
    print(f"\nSurvival Regression Output: {json.dumps(result, indent=2)}\n")
    return json.dumps(result, indent=2)

# Create list of tools
statistician_tools = [create_bar_chart, plot_kaplan_meier, fit_survival_regression]
print(f"Created {len(statistician_tools)} tools for the Strands agent")
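
`invoke_lambda_function` catches client-side exceptions, but a Lambda that fails while running still returns a normal `invoke` response, with a `FunctionError` field set and the error details in the payload. The sketch below shows one way to surface that case and to pull the text body out of a Bedrock-Agent-style reply. The helper names are illustrative, and the nested `functionResponse` layout is an assumption about how these particular Lambdas respond.

```python
import io
import json
from typing import Any, Dict


def read_invoke_response(response: Dict[str, Any]) -> Dict[str, Any]:
    """Decode a Lambda invoke response, turning function errors into an error dict."""
    body = json.loads(response["Payload"].read())
    if response.get("FunctionError"):
        message = body.get("errorMessage") if isinstance(body, dict) else str(body)
        return {"error": message or "Lambda function failed"}
    return body


def extract_text_body(result: Dict[str, Any]) -> str:
    """Return the TEXT body of an agent-style reply, or the raw JSON as a fallback."""
    try:
        return result["response"]["functionResponse"]["responseBody"]["TEXT"]["body"]
    except (KeyError, TypeError):
        return json.dumps(result)


# Usage with in-memory payloads standing in for boto3's streaming body.
ok_reply = {"response": {"functionResponse": {
    "responseBody": {"TEXT": {"body": "chart saved"}}}}}
ok_response = {"StatusCode": 200,
               "Payload": io.BytesIO(json.dumps(ok_reply).encode())}
failed_response = {"StatusCode": 200, "FunctionError": "Unhandled",
                   "Payload": io.BytesIO(b'{"errorMessage": "boom"}')}

demo_ok_text = extract_text_body(read_invoke_response(ok_response))
demo_failure = read_invoke_response(failed_response)
```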

#### Setup AWS Bedrock provider for Strands

In [None]:
%%write_and_run -a statistician_agent.py

# Create Bedrock model for Strands
model = BedrockModel(
    model_id="anthropic.claude-3-5-sonnet-20241022-v2:0",
    region_name=region,
    temperature=0.1,
    streaming=False
)

#### Create the Strands agent

In [None]:
# Create the Strands agent
try:
    statistician_agent = Agent(
        model=model,
        tools=statistician_tools,
        system_prompt=statistician_agent_instruction
    )
    
    print(f"Successfully created Strands agent: {statistician_agent_name}")
    print(f"Agent has {len(statistician_tools)} tools available")
    
except Exception as e:
    print(f"Error creating agent: {e}")
    raise

#### Test the Strands agent

In [None]:
# Test the agent with a bar chart creation
test_query = """Create me a bar chart for the top 5 gene biomarkers (e.g., TP53, BRCA1, EGFR, KRAS, MYC)
with respect to their prognostic significance in chemotherapy-treated patients.
The Y-axis should represent –log10(p-value) from a Cox proportional hazards model assessing association with overall survival. 
Y-axis values are: 8.3, 6.7, 5.9, 4.2, 3.8
"""

print(f"Testing agent with query: {test_query}")
print("=" * 50)

try:
    # Run the agent
    statistician_agent(test_query)
except Exception as e:
    print(f"Error during agent execution: {e}")
    import traceback
    traceback.print_exc()
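
The test query plots –log10(p-value), a transform that turns very small p-values into large bar heights, so a taller bar means stronger evidence of association. A short sketch of the conversion in both directions follows; the helper names are illustrative.

```python
import math
from typing import List, Sequence


def neg_log10(p_values: Sequence[float]) -> List[float]:
    """Convert p-values in (0, 1] to -log10 scores."""
    for p in p_values:
        if not 0.0 < p <= 1.0:
            raise ValueError(f"p-value out of range: {p}")
    return [-math.log10(p) for p in p_values]


def to_p_value(score: float) -> float:
    """Invert the transform: a score of 3 corresponds to p = 0.001."""
    return 10.0 ** (-score)


demo_scores = neg_log10([0.05, 1e-3, 1.0])
top_biomarker_p = to_p_value(8.3)  # the first Y-axis value in the test query
```

Under this transform, the conventional 0.05 significance threshold sits at a height of about 1.3.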

#### Accessing the generated charts

1. Go to the S3 bucket whose name includes the string 'env-build'.
2. Navigate to the **/Graphs** folder.
3. Select the most recent chart files and click **Download**.
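
The same steps can be scripted with boto3 instead of the console. The sketch below finds the bucket and lists the newest chart keys using the S3 client's `list_buckets` and `list_objects_v2` calls. The helper names and the `Graphs/` key prefix are assumptions, and the fake client stands in for `boto3.client("s3")` so the example runs without credentials.

```python
from datetime import datetime, timezone
from typing import Any, List, Optional


def find_bucket(s3_client: Any, needle: str = "env-build") -> Optional[str]:
    """Return the first bucket whose name contains `needle`, or None."""
    for bucket in s3_client.list_buckets().get("Buckets", []):
        if needle in bucket["Name"]:
            return bucket["Name"]
    return None


def newest_chart_keys(s3_client: Any, bucket: str,
                      prefix: str = "Graphs/", limit: int = 5) -> List[str]:
    """List the most recently modified keys under `prefix` (first page only)."""
    response = s3_client.list_objects_v2(Bucket=bucket, Prefix=prefix)
    objects = sorted(response.get("Contents", []),
                     key=lambda obj: obj["LastModified"], reverse=True)
    return [obj["Key"] for obj in objects[:limit]]


class _FakeS3Client:
    """Offline stand-in that mimics the two S3 calls used above."""

    def list_buckets(self):
        return {"Buckets": [{"Name": "logs"}, {"Name": "demo-env-build-123"}]}

    def list_objects_v2(self, Bucket, Prefix):
        return {"Contents": [
            {"Key": "Graphs/old.png",
             "LastModified": datetime(2025, 1, 1, tzinfo=timezone.utc)},
            {"Key": "Graphs/new.png",
             "LastModified": datetime(2025, 6, 1, tzinfo=timezone.utc)},
        ]}


demo_s3 = _FakeS3Client()
demo_bucket = find_bucket(demo_s3)
demo_keys = newest_chart_keys(demo_s3, demo_bucket, limit=1)
```

With a real client, each returned key can then be fetched with `s3_client.download_file(bucket, key, local_filename)`.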

#### Advanced usage examples

In [None]:
# Example of more complex statistical analyses
advanced_queries = [
    """Create a Kaplan-Meier survival plot comparing high vs low expression of EGFR biomarker.
    Use the following sample data:
    - Biomarker name: EGFR
    - High expression group (baseline): survival durations in days [365, 420, 500, 180, 600, 720, 300, 450, 650, 380] 
      and survival events [0, 1, 0, 1, 0, 0, 1, 0, 0, 1] where 0=alive, 1=dead
    - Low expression group (condition): survival durations in days [200, 150, 320, 280, 400, 180, 250, 350, 300, 220]
      and survival events [1, 1, 0, 1, 0, 1, 1, 0, 1, 1] where 0=alive, 1=dead""",
    
    """Generate a bar chart showing average survival duration across different cancer stages.
    Create a bar chart with:
    - Title: "Average Survival Duration by Cancer Stage"
    - X-axis label: "Cancer Stage" 
    - X-axis values: ["Stage I", "Stage II", "Stage III", "Stage IV"]
    - Y-axis label: "Average Survival (months)"
    - Y-axis values: [36.5, 28.2, 18.7, 8.4]""",
    
    """Create a Kaplan-Meier plot for TP53 mutation status comparing mutant vs wild-type patients.
    Use this sample data:
    - Biomarker name: TP53_mutation
    - Wild-type group (baseline): survival durations [450, 380, 520, 290, 600, 340, 480, 390, 550, 420]
      and events [0, 1, 0, 1, 0, 1, 0, 0, 0, 1]
    - Mutant group (condition): survival durations [180, 220, 160, 240, 200, 190, 210, 170, 230, 195]
      and events [1, 1, 1, 0, 1, 1, 0, 1, 1, 1]""",
    
    """Generate a bar chart showing biomarker expression levels in smokers vs non-smokers.
    Create a chart with:
    - Title: "Average Biomarker Expression: Smokers vs Non-smokers"
    - X-axis label: "Patient Group"
    - X-axis values: ["Current Smokers", "Former Smokers", "Never Smokers"]
    - Y-axis label: "Average LRIG1 Expression Level"
    - Y-axis values: [15.2, 22.8, 28.4]"""
]

def test_advanced_query(query: str):
    """
    Test an advanced query with the agent
    """
    print(f"\nTesting query: {query}")
    print("-" * 40)
    
    try:
        statistician_agent(query)
    except Exception as e:
        print(f"Error: {e}")

for query in advanced_queries[:1]:  # Test first query only
    test_advanced_query(query)
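
To sanity-check a plot returned by the agent, the survival curve can be recomputed locally. The sketch below implements the standard Kaplan-Meier product-limit estimate in plain Python and applies it to the low-expression EGFR sample data from the first query. The Lambda's actual implementation is not shown in this notebook, so this is an independent reference, not a copy of it.

```python
from typing import List, Sequence, Tuple


def kaplan_meier(durations: Sequence[float],
                 events: Sequence[int]) -> List[Tuple[float, float]]:
    """Return (time, survival probability) pairs at each observed death time."""
    if len(durations) != len(events):
        raise ValueError("durations and events must have equal length")
    at_risk = len(durations)
    survival = 1.0
    curve: List[Tuple[float, float]] = []
    for time in sorted(set(durations)):
        deaths = sum(1 for d, e in zip(durations, events) if d == time and e == 1)
        leaving = sum(1 for d in durations if d == time)  # deaths plus censored
        if deaths:
            # Product-limit step: multiply by the fraction surviving this time.
            survival *= 1.0 - deaths / at_risk
            curve.append((time, survival))
        at_risk -= leaving
    return curve


# Low-expression EGFR group from the first advanced query (0=alive, 1=dead).
low_durations = [200, 150, 320, 280, 400, 180, 250, 350, 300, 220]
low_events = [1, 1, 0, 1, 0, 1, 1, 0, 1, 1]
low_curve = kaplan_meier(low_durations, low_events)
```

Censored patients (event 0) never lower the curve directly; they only shrink the at-risk count used at later death times.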

#### Session management and conversation continuity

In [None]:
# Demonstrate conversation continuity
def interactive_session():
    """
    Simple interactive session with the agent
    """
    print("Interactive Statistical Analysis Session")
    print("Type 'quit' to exit")
    print("=" * 50)
    
    while True:
        user_input = input("\nYour question: ")
        
        if user_input.lower() in ['quit', 'exit', 'q']:
            print("Session ended.")
            break
            
        try:
            statistician_agent(user_input)
        except Exception as e:
            print(f"Error: {e}")

interactive_session()

## Summary

This notebook demonstrated how to integrate a Strands agent with existing Lambda functions.

### Benefits of Strands Approach:
- **Open source** - No vendor lock-in
- **Cost effective** - Only pay for the underlying LLM calls
- **Flexible** - Easy to customize and extend

### Tools Available:
- `create_bar_chart`: Creates bar charts for statistical visualizations
- `plot_kaplan_meier`: Generates Kaplan-Meier survival plots
- `fit_survival_regression`: Performs survival regression analysis on S3 data

### Chart Access:
Generated charts are automatically saved to your S3 bucket in the `/Graphs` folder and can be downloaded from the AWS Console.

The agent keeps the medical research and statistical analysis capabilities of the existing Lambda-backed tools while providing a more flexible and cost-effective solution for survival analysis and biomarker visualization.

### Usage Notes:
- Ensure your Lambda functions (MatPlotBarChartLambda and ScientificPlotLambda) are deployed and accessible
- The agent expects specific data formats for survival analysis (0/1 for events, duration in days)
- S3 bucket permissions should allow the Lambda functions to write chart files
- For survival regression analysis, provide the correct S3 bucket and key for your data files
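
The data-format note above can be enforced before any tool call is made. The sketch below checks that durations and events are aligned, non-negative, and coded as 0/1. The function `validate_survival_inputs` is illustrative and not part of the notebook's tools.

```python
from typing import List, Sequence


def validate_survival_inputs(durations: Sequence[float],
                             events: Sequence[int]) -> List[str]:
    """Return a list of problems; an empty list means the inputs look valid."""
    problems: List[str] = []
    if len(durations) != len(events):
        problems.append(
            f"length mismatch: {len(durations)} durations vs {len(events)} events")
    if not durations:
        problems.append("no observations supplied")
    for index, duration in enumerate(durations):
        if duration < 0:
            problems.append(f"duration at index {index} is negative: {duration}")
    for index, event in enumerate(events):
        if event not in (0, 1):
            problems.append(f"event at index {index} must be 0 or 1, got {event!r}")
    return problems


demo_ok = validate_survival_inputs([365, 420, 500], [0, 1, 0])
demo_bad = validate_survival_inputs([365, -5], [0, 2, 1])
```

Calling this at the top of a tool such as `plot_kaplan_meier` would let the agent report malformed input to the user without invoking the Lambda.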