# Question 4: GenAI Clinical Data Assistant

## Objective
Develop a Generative AI Assistant that translates natural language questions into structured Pandas queries for clinical trial adverse events data.

## Approach
1. **Schema Definition**: Describe the relevant dataset columns in a schema string that is passed to the LLM
2. **ClinicalTrialDataAgent Class**: Use OpenAI to parse questions into structured JSON (`target_column`, `filter_value`)
3. **Execution**: Apply Pandas filters and return count of unique subjects (USUBJID) and list of matching IDs
4. **Bonus**: Interactive Gradio web interface with visualization
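The execution step (item 3) can be sketched on toy data before touching the real dataset. This is only an illustration: the `toy_ae` rows and the helper name `count_subjects` are made up, and `regex=False` is an assumption added here so that filter values are matched as literal text.

```python
import pandas as pd

# Toy rows invented for illustration; the notebook itself uses adae.csv.
toy_ae = pd.DataFrame({
    "USUBJID": ["S1", "S1", "S2", "S3"],
    "AETERM": ["HEADACHE", "NAUSEA", "HEADACHE", "RASH (LOCAL)"],
    "AESEV": ["MILD", "MODERATE", "SEVERE", "MILD"],
})

def count_subjects(df, target_column, filter_value):
    """Case-insensitive literal substring match, then unique subject IDs."""
    mask = df[target_column].str.upper().str.contains(
        filter_value.upper(), na=False, regex=False
    )
    subjects = df.loc[mask, "USUBJID"].unique().tolist()
    return {"count": len(subjects), "subjects": subjects}

# Two rows match HEADACHE, belonging to two different subjects.
print(count_subjects(toy_ae, "AETERM", "headache"))
# Parentheses are matched literally because regex=False.
print(count_subjects(toy_ae, "AETERM", "(local)"))
```

Counting unique `USUBJID` values rather than rows matters because one subject can report the same event several times.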

## Setup and Dependencies

In [6]:
# Install required packages
!pip install openai pandas gradio matplotlib -q

# Import libraries
import pandas as pd
import json
from openai import OpenAI
import matplotlib.pyplot as plt

## Load Data

Load the ADAE (Adverse Events) dataset exported from `pharmaverseadam::adae` in R.

In [10]:
# Load adverse events data
ae = pd.read_csv("adae.csv")

# Display basic info
print(f"Dataset shape: {ae.shape[0]} rows x {ae.shape[1]} columns")
print(f"\nKey columns for this exercise:")
print(f"- USUBJID: Unique subject identifier")
print(f"- AETERM: Adverse event term (e.g., HEADACHE)")
print(f"- AESOC: System Organ Class (e.g., CARDIAC DISORDERS)")
print(f"- AESEV: Severity (MILD, MODERATE, SEVERE)")

ae.head()

Dataset shape: 1191 rows x 107 columns

Key columns for this exercise:
- USUBJID: Unique subject identifier
- AETERM: Adverse event term (e.g., HEADACHE)
- AESOC: System Organ Class (e.g., CARDIAC DISORDERS)
- AESEV: Severity (MILD, MODERATE, SEVERE)


Unnamed: 0,STUDYID,DOMAIN,USUBJID,AESEQ,AESPID,AETERM,AELLT,AELLTCD,AEDECOD,AEPTCD,...,DTHCGR1,LSTALVDT,SAFFL,RACEGR1,AGEGR1,REGION1,LDDTHGR1,DTH30FL,DTHA30FL,DTHB30FL
0,CDISCPILOT01,AE,01-701-1015,1,E07,APPLICATION SITE ERYTHEMA,APPLICATION SITE REDNESS,,APPLICATION SITE ERYTHEMA,,...,,2014-07-02,Y,White,18-64,,,,,
1,CDISCPILOT01,AE,01-701-1015,2,E08,APPLICATION SITE PRURITUS,APPLICATION SITE ITCHING,,APPLICATION SITE PRURITUS,,...,,2014-07-02,Y,White,18-64,,,,,
2,CDISCPILOT01,AE,01-701-1015,3,E06,DIARRHOEA,DIARRHEA,,DIARRHOEA,,...,,2014-07-02,Y,White,18-64,,,,,
3,CDISCPILOT01,AE,01-701-1023,2,E09,ERYTHEMA,LOCALIZED ERYTHEMA,,ERYTHEMA,,...,,2012-09-02,Y,White,18-64,,,,,
4,CDISCPILOT01,AE,01-701-1023,1,E08,ERYTHEMA,ERYTHEMA,,ERYTHEMA,,...,,2012-09-02,Y,White,18-64,,,,,


## Step 1: Schema Definition

Define a schema string that describes the relevant columns to the LLM. This helps the model understand the dataset structure and map user questions to the correct columns.
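A minimal sketch of one way to keep such a schema maintainable: hold the column descriptions in a dictionary, render the prompt text from it, and check the described columns against the loaded data. The names `COLUMN_DESCRIPTIONS`, `build_schema`, and `missing_columns` are hypothetical and not part of the notebook, and only four columns are listed to keep the example short.

```python
# Hypothetical helper: one dict drives both the prompt text and the column check,
# so the two cannot drift apart.
COLUMN_DESCRIPTIONS = {
    "USUBJID": "Unique subject identifier",
    "AETERM": 'Reported term for the adverse event (e.g., "HEADACHE", "NAUSEA")',
    "AESOC": 'System Organ Class (e.g., "CARDIAC DISORDERS")',
    "AESEV": 'Severity/Intensity of the AE ("MILD", "MODERATE", "SEVERE")',
}

def build_schema(descriptions):
    """Render the column dictionary as a bullet list for the prompt."""
    lines = ["This is a clinical trial Adverse Events (AE) dataset. Key columns:"]
    lines += [f"- {name}: {text}" for name, text in descriptions.items()]
    return "\n".join(lines)

def missing_columns(descriptions, available_columns):
    """Return described columns that are absent from the loaded data."""
    return sorted(set(descriptions) - set(available_columns))

schema_text = build_schema(COLUMN_DESCRIPTIONS)
print(schema_text)

# In the notebook this would be called with ae.columns; a plain list is used here.
print(missing_columns(COLUMN_DESCRIPTIONS, ["USUBJID", "AETERM", "AESEV"]))
```

Running the column check right after loading the CSV would surface a renamed or missing column before any LLM call is made.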

In [11]:
SCHEMA = """
This is a clinical trial Adverse Events (AE) dataset. Key columns:
- USUBJID: Unique subject identifier
- AETERM: Reported term for the adverse event (e.g., "HEADACHE", "NAUSEA")
- AESOC: System Organ Class (e.g., "CARDIAC DISORDERS", "SKIN AND SUBCUTANEOUS TISSUE DISORDERS")
- AESEV: Severity/Intensity of the AE ("MILD", "MODERATE", "SEVERE")
- AESER: Serious adverse event flag ("Y" or "N")
- AEREL: Causality/Relationship to treatment
- TRTEMFL: Treatment-emergent flag ("Y" or "N")
- ACTARM: Actual treatment arm ("Placebo", "Xanomeline High Dose", "Xanomeline Low Dose")
"""

## Step 2: Initialize OpenAI Client

Set up the OpenAI API client. Replace the API key with your own.

In [12]:
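# Initialize OpenAI client
# Prefer the OPENAI_API_KEY environment variable; otherwise replace the placeholder with your own key
import os
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY", "your-api-key"))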

## Step 3: ClinicalTrialDataAgent Class

This class implements the core functionality:
1. `parse_question()`: Uses LLM to convert natural language to structured JSON
2. `execute_query()`: Applies Pandas filter based on LLM output
3. `ask()`: Main method that combines parsing and execution

In [26]:
class ClinicalTrialDataAgent:
    def __init__(self, client, schema):
        self.client = client
        self.schema = schema

    def parse_question(self, question):
        """
        Use LLM to parse user question into Structured JSON Output.

        Returns:
            dict: {"target_column": "COLUMN_NAME", "filter_value": "VALUE"}
        """
        prompt = f"""
You are a clinical data assistant. Given a user's question about adverse events data,
extract the target column and filter value.

{self.schema}

User question: {question}

Respond ONLY with a valid JSON object in this exact format:
{{"target_column": "COLUMN_NAME", "filter_value": "VALUE"}}

Examples:
- "Show me severe adverse events" -> {{"target_column": "AESEV", "filter_value": "SEVERE"}}
- "Find patients with headache" -> {{"target_column": "AETERM", "filter_value": "HEADACHE"}}
- "Get cardiac disorders" -> {{"target_column": "AESOC", "filter_value": "CARDIAC DISORDERS"}}
"""

        response = self.client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": prompt}],
            temperature=0
        )

        result = response.choices[0].message.content.strip()
        return json.loads(result)

    def execute_query(self, parsed_output, df):
        """
        Apply Pandas filter based on LLM output.

        Returns:
            dict: count of unique subjects (USUBJID) and list of matching IDs
        """
        target_column = parsed_output["target_column"]
        filter_value = parsed_output["filter_value"]

        # Filter dataframe (case-insensitive partial match; regex=False treats the value as literal text)
        filtered_df = df[df[target_column].str.upper().str.contains(filter_value.upper(), na=False, regex=False)]

        # Get unique subjects
        unique_subjects = filtered_df["USUBJID"].unique().tolist()
        count = len(unique_subjects)

        return {
            "count": count,
            "subjects": unique_subjects,
            "filter_applied": f"{target_column} contains '{filter_value}'",
            "filtered_df": filtered_df
        }

    def ask(self, question, df):
        """
        Main method: parse question and execute query.

        Args:
            question: Natural language question from user
            df: Pandas DataFrame to query

        Returns:
            dict: Query results with count and subject IDs
        """
        print(f"\n{'='*70}")
        print(f"Question: {question}")
        print('='*70)

        # Step 1: Parse question with LLM
        parsed = self.parse_question(question)
        print(f"\nParsed JSON Output:")
        print(f'  {{"target_column": "{parsed["target_column"]}", "filter_value": "{parsed["filter_value"]}"}}')

        # Step 2: Execute query
        result = self.execute_query(parsed, df)
        print(f"\nFilter Applied: {result['filter_applied']}")
        print(f"\nCount of Unique Subjects: {result['count']}")
        print(f"\nList of Matching IDs:\n  {', '.join(result['subjects'])}")

        return result
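`parse_question()` passes the raw model reply straight to `json.loads`, which fails if the model wraps the JSON in a code fence or adds a sentence around it, and nothing checks that the returned column is one the schema describes. The following is a sketch of a more defensive parser, not the notebook's implementation: `parse_llm_reply` and `ALLOWED_COLUMNS` are hypothetical names, and the whitelist is assumed to be the columns listed in `SCHEMA`.

```python
import json
import re

# Assumed whitelist: the columns described in SCHEMA.
ALLOWED_COLUMNS = {
    "USUBJID", "AETERM", "AESOC", "AESEV", "AESER", "AEREL", "TRTEMFL", "ACTARM",
}

def parse_llm_reply(raw_reply, allowed_columns=ALLOWED_COLUMNS):
    """Extract and validate {"target_column", "filter_value"} from a model reply."""
    # Grab everything from the first "{" to the last "}" so that replies wrapped
    # in a code fence or surrounded by prose still parse.
    match = re.search(r"\{.*\}", raw_reply, flags=re.DOTALL)
    if match is None:
        raise ValueError(f"No JSON object found in reply: {raw_reply!r}")
    parsed = json.loads(match.group(0))

    missing = {"target_column", "filter_value"} - parsed.keys()
    if missing:
        raise ValueError(f"Reply is missing keys: {sorted(missing)}")

    column = str(parsed["target_column"]).upper()
    if column not in allowed_columns:
        raise ValueError(f"Column {column!r} is not in the schema")

    return {"target_column": column, "filter_value": str(parsed["filter_value"])}

# Simulated reply wrapped in a markdown code fence (fence built programmatically).
fence = "`" * 3
reply = f'{fence}json\n{{"target_column": "aesev", "filter_value": "SEVERE"}}\n{fence}'
print(parse_llm_reply(reply))
```

Rejecting unknown columns up front turns a confusing `KeyError` inside Pandas into a clear message that can be shown to the user.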

## Step 4: Test Script

Initialize the agent and run the first of the three example queries required by the assessment.

In [27]:
# Initialize agent
agent = ClinicalTrialDataAgent(client, SCHEMA)

# Test Query 1: Query by Severity
result1 = agent.ask("Give me the subjects who had Adverse events of Moderate severity", ae)


Question: Give me the subjects who had Adverse events of Moderate severity

Parsed JSON Output:
  {"target_column": "AESEV", "filter_value": "MODERATE"}

Filter Applied: AESEV contains 'MODERATE'

Count of Unique Subjects: 136

List of Matching IDs:
  01-701-1023, 01-701-1047, 01-701-1097, 01-701-1111, 01-701-1115, 01-701-1133, 01-701-1146, 01-701-1148, 01-701-1180, 01-701-1181, 01-701-1188, 01-701-1192, 01-701-1211, 01-701-1239, 01-701-1294, 01-701-1302, 01-701-1317, 01-701-1341, 01-701-1360, 01-701-1383, 01-701-1444, 01-702-1082, 01-703-1076, 01-703-1086, 01-703-1100, 01-703-1119, 01-703-1182, 01-703-1210, 01-703-1258, 01-703-1295, 01-703-1299, 01-703-1403, 01-703-1439, 01-704-1008, 01-704-1025, 01-704-1065, 01-704-1120, 01-704-1266, 01-704-1332, 01-705-1059, 01-705-1186, 01-705-1199, 01-705-1281, 01-705-1292, 01-705-1310, 01-706-1041, 01-706-1049, 01-706-1384, 01-707-1206, 01-708-1032, 01-708-1347, 01-708-1428, 01-709-1081, 01-709-1099, 01-709-1102, 01-709-1168, 01-709-1217, 01-709

## Step 5: Test Queries

Run 3 example queries to demonstrate the agent's capabilities:
1. Query by severity
2. Query by system organ class
3. Query by treatment arm

In [29]:
# Test Query 2: Query by System Organ Class
result2 = agent.ask("Show me all skin and subcutaneous tissue disorders", ae)


Question: Show me all skin and subcutaneous tissue disorders

Parsed JSON Output:
  {"target_column": "AESOC", "filter_value": "SKIN AND SUBCUTANEOUS TISSUE DISORDERS"}

Filter Applied: AESOC contains 'SKIN AND SUBCUTANEOUS TISSUE DISORDERS'

Count of Unique Subjects: 105

List of Matching IDs:
  01-701-1023, 01-701-1097, 01-701-1130, 01-701-1148, 01-701-1188, 01-701-1192, 01-701-1239, 01-701-1275, 01-701-1294, 01-701-1302, 01-701-1387, 01-702-1082, 01-703-1076, 01-703-1119, 01-704-1008, 01-704-1009, 01-704-1017, 01-704-1065, 01-704-1074, 01-704-1093, 01-704-1114, 01-704-1120, 01-704-1135, 01-704-1266, 01-704-1323, 01-704-1325, 01-704-1332, 01-705-1031, 01-705-1059, 01-705-1280, 01-705-1281, 01-705-1292, 01-705-1310, 01-706-1041, 01-706-1384, 01-708-1158, 01-708-1347, 01-709-1029, 01-709-1081, 01-709-1088, 01-709-1168, 01-709-1217, 01-709-1306, 01-709-1309, 01-710-1006, 01-710-1021, 01-710-1045, 01-710-1053, 01-710-1070, 01-710-1077, 01-710-1137, 01-710-1154, 01-710-1249, 01-710-1264,

In [30]:
# Test Query 3: Query by Treatment Arm
result3 = agent.ask("Find subjects in the placebo group", ae)


Question: Find subjects in the placebo group

Parsed JSON Output:
  {"target_column": "ACTARM", "filter_value": "Placebo"}

Filter Applied: ACTARM contains 'Placebo'

Count of Unique Subjects: 69

List of Matching IDs:
  01-701-1015, 01-701-1023, 01-701-1047, 01-701-1130, 01-701-1153, 01-701-1203, 01-701-1363, 01-701-1387, 01-701-1392, 01-701-1415, 01-703-1042, 01-703-1100, 01-703-1210, 01-703-1299, 01-704-1010, 01-704-1164, 01-704-1351, 01-704-1435, 01-704-1445, 01-705-1059, 01-705-1186, 01-705-1349, 01-706-1041, 01-707-1206, 01-708-1087, 01-708-1158, 01-708-1253, 01-708-1286, 01-708-1296, 01-708-1316, 01-709-1088, 01-709-1259, 01-709-1301, 01-709-1306, 01-709-1312, 01-709-1339, 01-710-1027, 01-710-1060, 01-710-1077, 01-710-1083, 01-710-1183, 01-710-1264, 01-710-1271, 01-710-1314, 01-710-1315, 01-710-1368, 01-711-1036, 01-713-1179, 01-713-1256, 01-713-1269, 01-714-1035, 01-714-1375, 01-715-1207, 01-716-1024, 01-716-1026, 01-716-1044, 01-716-1108, 01-716-1160, 01-716-1308, 01-716-1441
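The three queries above call the live OpenAI API, so they need a key and network access. A possible way to exercise the parsing path offline is a stub client that returns a canned reply. This is a sketch under assumptions: `FakeOpenAIClient` is a hypothetical class that mimics only the `client.chat.completions.create(...)` call and the `response.choices[0].message.content` attribute path used in `parse_question()`, and the standalone `parse_question` function below mirrors that call shape with the prompt reduced to the bare question.

```python
import json
from types import SimpleNamespace

class FakeOpenAIClient:
    """Hypothetical stand-in that mimics only client.chat.completions.create()."""

    def __init__(self, canned_reply):
        self._canned_reply = canned_reply
        self.calls = []  # keyword arguments of every create() call, for inspection
        self.chat = SimpleNamespace(
            completions=SimpleNamespace(create=self._create)
        )

    def _create(self, **kwargs):
        self.calls.append(kwargs)
        message = SimpleNamespace(content=self._canned_reply)
        return SimpleNamespace(choices=[SimpleNamespace(message=message)])

def parse_question(client, question):
    """Same call shape as ClinicalTrialDataAgent.parse_question, minus the full prompt."""
    response = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": question}],
        temperature=0,
    )
    return json.loads(response.choices[0].message.content.strip())

fake = FakeOpenAIClient('{"target_column": "ACTARM", "filter_value": "Placebo"}')
parsed = parse_question(fake, "Find subjects in the placebo group")
print(parsed)
print(len(fake.calls), fake.calls[0]["temperature"])
```

Because the agent class only touches `client.chat.completions.create`, the same stub could be passed as `ClinicalTrialDataAgent(fake, SCHEMA)` to test `execute_query()` and `ask()` without an API key.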

## Bonus: Interactive Gradio Web Interface

An interactive web interface that allows users to ask natural language questions and view results with visualizations.

In [32]:
import gradio as gr

def create_plot(filtered_df):
    """Generate visualization based on query results"""
    if len(filtered_df) == 0:
        fig, ax = plt.subplots(figsize=(8, 5))
        ax.text(0.5, 0.5, "No data found", ha='center', va='center', fontsize=16)
        ax.axis('off')
        return fig

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    # Plot 1: Severity Distribution
    severity_counts = filtered_df["AESEV"].value_counts()
    colors = {"MILD": "#F8766D", "MODERATE": "#00BA38", "SEVERE": "#619CFF"}
    severity_counts.plot(
        kind="bar",
        ax=axes[0],
        color=[colors.get(x, "gray") for x in severity_counts.index]
    )
    axes[0].set_title("AE Severity Distribution")
    axes[0].set_xlabel("Severity")
    axes[0].set_ylabel("Count")
    axes[0].tick_params(axis='x', rotation=0)

    # Plot 2: Treatment Arm Distribution
    arm_counts = filtered_df["ACTARM"].value_counts()
    arm_counts.plot(kind="bar", ax=axes[1], color="steelblue")
    axes[1].set_title("Distribution by Treatment Arm")
    axes[1].set_xlabel("Treatment Arm")
    axes[1].set_ylabel("Count")
    axes[1].tick_params(axis='x', rotation=15)

    plt.tight_layout()
    return fig

def process_query(question):
    """Process user query and return results with visualization"""
    if not question.strip():
        return "Please enter a question.", None

    try:
        # Step 1: Parse question with LLM
        parsed = agent.parse_question(question)

        # Step 2: Execute query
        result = agent.execute_query(parsed, ae)

        # Format text output
        text_output = f"""
**Parsed JSON Output:**
```json
{{"target_column": "{parsed['target_column']}", "filter_value": "{parsed['filter_value']}"}}
```

**Filter Applied:** {result['filter_applied']}

**Count of Unique Subjects:** {result['count']}

**List of Matching IDs:**
{', '.join(result['subjects'])}
"""

        # Step 3: Create plot
        fig = create_plot(result['filtered_df'])

        return text_output, fig

    except Exception as e:
        return f"Error: {str(e)}", None

# Create Gradio interface
demo = gr.Interface(
    fn=process_query,
    inputs=gr.Textbox(
        label="Ask a question about adverse events",
        placeholder="e.g., Give me subjects who had adverse events of moderate severity",
        lines=2
    ),
    outputs=[
        gr.Markdown(label="Results"),
        gr.Plot(label="Visualization")
    ],
    title="🏥 Clinical Trial Data Assistant",
    description="Ask natural language questions about adverse events data. The AI will parse your question into a structured JSON query (target_column + filter_value) and return the count of unique subjects and their IDs.",
    examples=[
        ["Give me the subjects who had Adverse events of Moderate severity"],
        ["Find patients with cardiac disorders"],
        ["Show me patients who experienced headache"]
    ],
    allow_flagging="never"
)

print("Launching Gradio interface...")
demo.launch()



Launching Gradio interface...
It looks like you are running Gradio on a hosted Jupyter notebook, which requires `share=True`. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://febb7359cea9348cdc.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


