# Assignment 5, Question 4: Data Exploration

**Points: 15**

In this notebook, you'll explore the clinical trial dataset using pandas selection and filtering techniques.

You'll use utility functions from `q3_data_utils` where helpful, but also demonstrate direct pandas operations.

## Setup

In [8]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Import utilities from Q3
from q3_data_utils import load_data, detect_missing, filter_data

# Load the data
df = load_data('data/clinical_trial_raw.csv')
print(f"Loaded {len(df)} patients with {len(df.columns)} variables")

# Prewritten visualization functions for exploration
def plot_value_counts(series, title, figsize=(10, 6)):
    """
    Create a bar chart of value counts.
    
    Args:
        series: pandas Series with value counts
        title: Chart title
        figsize: Figure size tuple
    """
    plt.figure(figsize=figsize)
    series.plot(kind="bar")
    plt.title(title)
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()

def plot_crosstab(crosstab_data, title, figsize=(10, 6)):
    """
    Create a heatmap of crosstab data.
    
    Args:
        crosstab_data: pandas DataFrame from pd.crosstab()
        title: Chart title
        figsize: Figure size tuple
    """
    plt.figure(figsize=figsize)
    plt.imshow(crosstab_data.values, cmap="Blues", aspect="auto")
    plt.colorbar()
    plt.title(title)
    plt.xticks(range(len(crosstab_data.columns)), crosstab_data.columns, rotation=45)
    plt.yticks(range(len(crosstab_data.index)), crosstab_data.index)
    plt.tight_layout()
    plt.show()

Loaded 10000 patients with 18 variables


## Part 1: Basic Exploration (3 points)

Display:
1. Dataset shape
2. Column names and types
3. First 10 rows
4. Summary statistics (`.describe()`)

In [9]:
# TODO: Display dataset info
print("1. Dataset shape:")
print(f" Rows: {df.shape[0]}, Columns: {df.shape[1]}")
print()

print("2. Column names and types:")
print(df.dtypes)
print()

print("3. First 10 rows:")
print(df.head(10))
print()

print("4. Summary statistics:")
print(df.describe())
print()

1. Dataset shape:
 Rows: 10000, Columns: 18

2. Column names and types:
patient_id             object
age                     int64
sex                    object
bmi                   float64
enrollment_date        object
systolic_bp           float64
diastolic_bp          float64
cholesterol_total     float64
cholesterol_hdl       float64
cholesterol_ldl       float64
glucose_fasting       float64
site                   object
intervention_group     object
follow_up_months        int64
adverse_events          int64
outcome_cvd            object
adherence_pct         float64
dropout                object
dtype: object

3. First 10 rows:
  patient_id  age         sex   bmi enrollment_date  systolic_bp  \
0     P00001   80           F  29.3      2022-05-01        123.0   
1     P00002   80    Female     NaN      2022-01-06        139.0   
2     P00003   82      Female  -1.0      2023-11-04        123.0   
3     P00004   95      Female  25.4      2022-08-15        116.0   
4     P00005   

## Part 2: Column Selection (3 points)

Demonstrate different selection methods:

1. Select only numeric columns using `.select_dtypes()`
2. Select specific columns by name
3. Select a subset of rows and columns using `.loc[]`

In [13]:
# TODO: Select numeric columns
print("1. Numeric columns only:")
numeric = df.select_dtypes(include=['number'])
print(numeric.head())
print(f"Shape: {numeric.shape}")
print()


1. Numeric columns only:
   age   bmi  systolic_bp  diastolic_bp  cholesterol_total  cholesterol_hdl  \
0   80  29.3        123.0          80.0              120.0             55.0   
1   80   NaN        139.0          81.0              206.0             58.0   
2   82  -1.0        123.0          86.0              172.0             56.0   
3   95  25.4        116.0          77.0              200.0             56.0   
4   95   NaN         97.0          71.0              185.0             78.0   

   cholesterol_ldl  glucose_fasting  follow_up_months  adverse_events  \
0             41.0            118.0                20               0   
1            107.0             79.0                24               0   
2             82.0             77.0                 2               0   
3            104.0            115.0                17               0   
4             75.0            113.0                 9               0   

   adherence_pct  
0           24.0  
1           77.0  
2   

In [None]:
# TODO: Select specific columns
# example: select patient_id, age, glucose_fasting, cholesterol_total
print("2. Specific columns by name:")
selected_columns = df[['patient_id', 'age', 'glucose_fasting', 'cholesterol_total']]
print(selected_columns.head())
print(f"Shape: {selected_columns.shape}")
print()

2. Specific columns by name:
  patient_id  age  glucose_fasting  cholesterol_total
0     P00001   80            118.0              120.0
1     P00002   80             79.0              206.0
2     P00003   82             77.0              172.0
3     P00004   95            115.0              200.0
4     P00005   95            113.0              185.0
Shape: (10000, 4)



In [None]:
# TODO: Use .loc[] to select subset
# Example: rows where age > 30; columns patient_id, age, cholesterol_total
print("3. Subset of rows and columns:")
subset = df.loc[df['age'] > 30, ['patient_id', 'age', 'cholesterol_total']]
print(subset.head())
print()

3. Subset of rows and columns:
  patient_id  age  cholesterol_total
0     P00001   80              120.0
1     P00002   80              206.0
2     P00003   82              172.0
3     P00004   95              200.0
4     P00005   95              185.0



## Part 3: Filtering (4 points)

Filter the data to answer these questions:

1. How many patients are over 65 years old?
2. How many patients have systolic BP > 140?
3. Find patients who are both over 65 AND have systolic BP > 140
4. Find patients from Site A or Site B using `.isin()`
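
The four questions above map onto boolean masks in plain pandas. The sketch below runs on a small made-up DataFrame (the `toy` name and its values are illustrative, not taken from the trial data) and deliberately avoids `filter_data`, whose exact signature is defined in Q3. It also assumes site labels are spelled exactly `'Site A'` and `'Site B'`; the raw data may need cleaning first (the `sex` column, for example, mixes `F` and `Female`).

```python
import pandas as pd

# Toy data for illustration only; not the clinical trial dataset.
toy = pd.DataFrame({
    "age": [70, 50, 66, 80, 65],
    "systolic_bp": [150.0, 145.0, 120.0, None, 160.0],
    "site": ["Site A", "Site B", "Site C", "Site A", "Site D"],
})

# 1. Strict comparison: a patient aged exactly 65 is excluded.
over_65 = toy[toy["age"] > 65]

# 2. Comparisons against NaN are False, so rows with missing BP drop out.
high_bp = toy[toy["systolic_bp"] > 140]

# 3. Combine masks with &; each condition needs its own parentheses.
both = toy[(toy["age"] > 65) & (toy["systolic_bp"] > 140)]

# 4. .isin() keeps rows whose value matches any entry in the list.
site_ab = toy[toy["site"].isin(["Site A", "Site B"])]

print(len(over_65), len(high_bp), len(both), len(site_ab))
```

Missing values never satisfy a `>` comparison, so the counts describe patients with a recorded value above the threshold, not all patients who might be above it.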

In [None]:
# TODO: Filter and count patients over 65
# 1. Use the filter_data utility from Q3
# 2. Create a filter for age > 65
# 3. Apply the filter and count the results
# print(f"Patients over 65: {len(patients_over_65)}")


In [None]:
# TODO: Filter for high BP
# 1. Use the filter_data utility from Q3
# 2. Create a filter for systolic_bp > 140
# 3. Apply the filter and count the results
# print(f"Patients with high BP: {len(high_bp)}")


In [None]:
# TODO: Multiple conditions with &
# 1. Use filter_data for multiple conditions
# 2. Create a list of filters, one per condition:
# [
#     {'column': 'age', 'condition': 'greater_than', 'value': 65},
#     {'column': 'systolic_bp', 'condition': 'greater_than', 'value': 140}
# ]
# 3. Apply the filters and count the results
# print(f"Patients over 65 AND high BP: {len(both_conditions)}")
#
# Alternative: use in_range for an age range
# 4. Create a filter for the age range 65-100
# 5. Apply the filter and count the results


In [None]:
# TODO: Filter by site using .isin()
# 1. Use the filter_data utility from Q3
# 2. Create a filter for Site A or Site B
# 3. Apply the filter and count the results
# print(f"Patients from Site A or B: {len(site_ab)}")


## Part 4: Value Counts and Grouping (5 points)

1. Get value counts for the 'site' column
2. Get value counts for the 'intervention_group' column
3. Create a crosstab of site vs intervention_group
4. Calculate mean age by site
5. Save the site value counts to `output/q4_site_counts.csv`
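
As a rough guide to the steps above, here is a minimal sketch on a made-up DataFrame (the `toy` name and its values are illustrative, and the group labels `Control` and `Treatment` are assumptions, not the dataset's actual categories). It writes the CSV to a temporary directory so it runs anywhere; the assignment itself asks for `output/q4_site_counts.csv`.

```python
import os
import tempfile

import pandas as pd

# Toy data for illustration only; not the clinical trial dataset.
toy = pd.DataFrame({
    "site": ["Site A", "Site A", "Site B", "Site B", "Site B"],
    "intervention_group": ["Control", "Treatment", "Control", "Control", "Treatment"],
    "age": [60, 70, 50, 55, 75],
})

# Steps 1-2: frequency of each category, sorted by count (descending).
site_counts = toy["site"].value_counts()
group_counts = toy["intervention_group"].value_counts()

# Step 3: rows are sites, columns are groups, cells are patient counts.
site_by_group = pd.crosstab(toy["site"], toy["intervention_group"])

# Step 4: mean of one column within each group.
mean_age_by_site = toy.groupby("site")["age"].mean()

# Step 5: save the counts; a temp directory stands in for output/ here.
out_path = os.path.join(tempfile.mkdtemp(), "q4_site_counts.csv")
site_counts.to_csv(out_path)

print(site_counts)
print(site_by_group)
print(mean_age_by_site)
```

By default `value_counts()` skips missing values, so passing `dropna=False` is worth considering if the real columns contain NaN.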

In [None]:
# TODO: Value counts and analysis


In [None]:
# TODO: Save output
# site_counts.to_csv('output/q4_site_counts.csv')


## Summary

Write 2-3 sentences about what you learned from exploring this dataset.

**Your summary here:**

TODO: Write your observations
