# Assignment 5, Question 4: Data Exploration

**Points: 15**

In this notebook, you'll explore the clinical trial dataset using pandas selection and filtering techniques.

You'll use utility functions from `q3_data_utils` where helpful, but also demonstrate direct pandas operations.

## Setup

In [None]:
# Rewritten Demo: quick exploration using q3_data_utils
from q3_data_utils import load_data, clean_data, detect_missing, fill_missing, transform_types, create_bins, summarize_by_group, filter_data 
import pandas as pd
import os

# Make sure the artifact directory exists before anything is written into it.
os.makedirs('output', exist_ok=True)
DATA_FILE = 'data/clinical_trial_raw.csv'

# Load the raw trial data through the shared q3_data_utils loader.
df = load_data(DATA_FILE)
print(f'Loaded {len(df)} rows, {len(df.columns)} columns')

# 1. Clean the data, then report missing values on the cleaned frame.
# NOTE(review): clean_data lives in q3_data_utils and is not visible here; this
# cell assumes it standardizes 'site' and 'intervention_group' - confirm.
df_clean = clean_data(df)
missing = detect_missing(df_clean)
print('Missing values per column:\n', missing.head(10))

# 2. Fill BMI with the median and convert column types.
df_filled = fill_missing(df_clean, 'bmi', strategy='median')
df_typed = transform_types(df_filled, {'enrollment_date': 'datetime', 'age': 'numeric'})

# 3. Bin ages and summarize age/BMI by site.
df_binned = create_bins(
    df_typed,
    'age',
    bins=[0, 18, 35, 50, 65, 100],
    labels=['<18', '18-34', '35-49', '50-64', '65+'],
)
summary = summarize_by_group(df_binned, 'site', agg_dict={'age': 'mean', 'bmi': 'mean'})
print(summary.head())

# 4. Save outputs.
# NOTE(review): if summarize_by_group returns the group key as the index,
# index=False drops the site labels from this file - confirm against the helper.
summary.to_csv('output/q4_site_summary.csv', index=False)

# Name the index explicitly so the CSV header is always "site,patient_count".
# Without this, the first header cell depends on the pandas version: older
# releases leave the value_counts() index unnamed and write an empty label.
(
    df_typed['site']
    .value_counts()
    .rename_axis('site')
    .to_csv('output/q4_site_counts.csv', header=['patient_count'])
)
print('Wrote output/q4_site_summary.csv and output/q4_site_counts.csv')

In [None]:
# Cell 2: Final Cleaning and Site Distribution
import os
import pandas as pd
import numpy as np


OUTPUT_DIR = 'output'
os.makedirs(OUTPUT_DIR, exist_ok=True)
print("2. Finalizing Site Cleaning, Generating Distribution, & Saving CSV")

# --- Robust Site Standardization ---
# Stringify only the non-missing entries. A blanket astype(str) turns NaN into
# the literal text 'nan', which title-casing would then report as a 'Nan' site.
site_raw = df['site']
site_text = site_raw.astype(str).where(site_raw.notna())

# 1. Underscores become spaces, everything that is not a letter or whitespace
#    (digits, punctuation) is removed, and the result is Title Cased
#    (handles Site_D, Site B94, etc.).
# 2. Any run of whitespace collapses to one space and the ends are trimmed.
# The earlier extra replace of 'Site A ' is gone on purpose: after the trim it
# could never match a trailing space, and it would only have glued words
# together if the text appeared mid-string.
df['site'] = (
    site_text
    .str.replace('_', ' ', regex=False)
    .str.replace(r'[^A-Za-z\s]', '', regex=True)
    .str.title()
    .str.replace(r'\s+', ' ', regex=True)
    .str.strip()
)

# Value counts on the standardized column (missing sites are excluded by default).
site_counts_series = df['site'].value_counts()

# Convert the Series to a two-column DataFrame for CSV saving.
site_counts_df = site_counts_series.reset_index()
site_counts_df.columns = ['site', 'patient_count']

# Save to the required output file.
output_path = os.path.join(OUTPUT_DIR, 'q4_site_counts.csv')
site_counts_df.to_csv(output_path, index=False)

print(f"Site counts saved to {output_path}. Should show exactly 5 sites.")
display(site_counts_df)

## Part 1: Basic Exploration (3 points)

Display:
1. Dataset shape
2. Column names and types
3. First 10 rows
4. Summary statistics (.describe())

In [None]:
# Part 1: shape, schema, a sample of rows, and descriptive statistics.
frame_shape = df.shape
print('Dataset shape:', frame_shape)

print('\nColumn names and dtypes:')
column_types = df.dtypes
print(column_types)

print('\nFirst 10 rows:')
first_rows = df.head(10)
display(first_rows)

# Transposed so each column of the dataset becomes one row of the summary.
print('\nSummary statistics (numeric columns):')
numeric_summary = df.describe(include=[np.number])
display(numeric_summary.T)

print('\nSummary statistics (all columns):')
full_summary = df.describe(include="all")
display(full_summary.T)


## Part 2: Column Selection (3 points)

Demonstrate different selection methods:

1. Select only numeric columns using `.select_dtypes()`
2. Select specific columns by name
3. Select a subset of rows and columns using `.loc[]`

In [None]:
# Part 2.1: keep only the numeric columns via select_dtypes.
numeric_cols = df.select_dtypes(include=[np.number])
numeric_names = list(numeric_cols.columns)

print('Numeric columns (count):', len(numeric_names))
print(numeric_names)
display(numeric_cols.head())
print('Numeric-only dataframe shape:', numeric_cols.shape)

In [None]:
# Part 2.2: select specific columns by name, skipping any the frame lacks.
cols = ['patient_id', 'age', 'bmi', 'site']
cols_found = []
for name in cols:
    if name in df.columns:
        cols_found.append(name)

print('Requested columns found:', cols_found)
display(df[cols_found].head())

In [None]:
# Part 2.3: rows labelled 0 through 9 (label slicing with .loc is inclusive)
# restricted to the requested columns that exist in the frame.
cols = ['patient_id', 'age', 'site']
cols_available = [name for name in cols if name in df.columns]
print('Using columns for .loc():', cols_available)

row_labels = slice(0, 9)
subset = df.loc[row_labels, cols_available]
display(subset)

## Part 3: Filtering (4 points)

Filter the data to answer these questions:

1. How many patients are over 65 years old?
2. How many patients have systolic BP > 140?
3. Find patients who are both over 65 AND have systolic BP > 140
4. Find patients from Site A or Site B using `.isin()`

In [None]:
# Part 3.1: count patients older than 65 with the shared filter helper.
# NOTE(review): this filters `df` as left by the earlier cells; the setup cell
# converts 'age' to numeric only on df_typed - confirm 'age' is numeric here.
over_65_rule = {'column': 'age', 'condition': 'greater_than', 'value': 65}
patients_over_65 = filter_data(df, [over_65_rule])

print(f"Patients over 65: {len(patients_over_65)}")
display(patients_over_65.head(20))

In [None]:
# Part 3.2: count patients whose systolic blood pressure exceeds 140.
high_bp_rule = {'column': 'systolic_bp', 'condition': 'greater_than', 'value': 140}
high_bp = filter_data(df, [high_bp_rule])

print(f"Patients with systolic BP > 140: {len(high_bp)}")
display(high_bp.head())


In [None]:
# Part 3.3: both conditions in one filter_data call (age > 65 and systolic_bp > 140).
age_rule = {'column': 'age', 'condition': 'greater_than', 'value': 65}
bp_rule = {'column': 'systolic_bp', 'condition': 'greater_than', 'value': 140}
both_conditions = filter_data(df, [age_rule, bp_rule])

print(f"Patients over 65 AND systolic BP > 140: {len(both_conditions)}")
display(both_conditions.head())

# Alternative: the 'in_range' condition with bounds 65 and 100.
# NOTE(review): whether the bounds are inclusive is decided inside filter_data - confirm.
range_rule = {'column': 'age', 'condition': 'in_range', 'value': [65, 100]}
age_range = filter_data(df, [range_rule])
print(f"Patients aged 65-100: {len(age_range)}")


In [None]:
# Part 3.4: membership filter on the site column with .isin().
# Prefer a dedicated 'site_clean' column when the frame has one.
if 'site_clean' in df.columns:
    site_col = 'site_clean'
else:
    site_col = 'site'

site_values = ['Site A', 'Site B']
is_site_a_or_b = df[site_col].isin(site_values)
site_ab_isin = df[is_site_a_or_b]

print(f"Patients from Site A or Site B (using .isin on cleaned values): {len(site_ab_isin)}")
display(site_ab_isin.head(20))

## Part 4: Value Counts and Grouping (5 points)

1. Get value counts for the 'site' column
2. Get value counts for the 'intervention_group' column  
3. Create a crosstab of site vs intervention_group
4. Calculate mean age by site
5. Save the site value counts to `output/q4_site_counts.csv`

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os

# --- Setup ---
# NOTE(review): this cell reads `df` as the earlier cells left it. Whether
# 'intervention_group' is already standardized on that frame depends on
# q3_data_utils, which is not visible here - confirm before trusting the counts.
OUTPUT_DIR = 'output'
os.makedirs(OUTPUT_DIR, exist_ok=True)

# 1. Site value counts as a two-column frame (site, patient_count).
site_counts = df['site'].value_counts().reset_index()
site_counts.columns = ['site', 'patient_count']

print("\n1. Site value counts:")
display(site_counts)

# Bar chart of patients per site, saved as a PNG and then closed.
fig, ax = plt.subplots(figsize=(9, 6))
sns.barplot(
    data=site_counts,
    x='site',
    y='patient_count',
    palette='viridis',
    ax=ax,
)
ax.set_title('Distribution of Patients by Site')
ax.set_xlabel('Clinical Site')
ax.set_ylabel('Patient Count')
plt.setp(ax.get_xticklabels(), rotation=0)
fig.tight_layout()
fig.savefig(os.path.join(OUTPUT_DIR, 'q4_site_counts_bar.png'))
plt.close(fig)


# 2. Intervention group value counts as a two-column frame.
interv_counts = df['intervention_group'].value_counts().reset_index()
interv_counts.columns = ['intervention_group', 'patient_count']

print("\n2. Intervention group value counts:")
display(interv_counts)

# Bar chart of patients per intervention group, saved as a PNG and then closed.
fig, ax = plt.subplots(figsize=(9, 6))
sns.barplot(
    data=interv_counts,
    x='intervention_group',
    y='patient_count',
    palette='plasma',
    ax=ax,
)
ax.set_title('Distribution of Patients by Intervention Group')
ax.set_xlabel('Intervention Group')
ax.set_ylabel('Patient Count')
plt.setp(ax.get_xticklabels(), rotation=0)
fig.tight_layout()
fig.savefig(os.path.join(OUTPUT_DIR, 'q4_intervention_counts_bar.png'))
plt.close(fig)


# 3. Cross-tabulate site (rows) against intervention group (columns).
site_intervention_crosstab = pd.crosstab(df['site'], df['intervention_group'])
print("\n3. Crosstab of site vs intervention group:")
display(site_intervention_crosstab)

# Annotated heatmap of the crosstab counts, saved as a PNG and then closed.
heatmap_options = dict(
    annot=True,
    fmt='d',
    cmap='Blues',
    cbar_kws={'label': 'Patient Count'},
)
fig, ax = plt.subplots(figsize=(10, 7))
sns.heatmap(site_intervention_crosstab, ax=ax, **heatmap_options)
ax.set_title('Site vs Intervention Group Distribution')
ax.set_xlabel('Intervention Group')
ax.set_ylabel('Clinical Site')
fig.tight_layout()
fig.savefig(os.path.join(OUTPUT_DIR, 'q4_site_intervention_heatmap.png'))
plt.close(fig)


# 4. Mean age per site, rounded to one decimal place.
# NOTE(review): requires a numeric 'age' column on `df` - confirm.
age_by_site = df.groupby('site')['age']
mean_age_by_site = age_by_site.mean().round(1)
print("\n4. Mean age by site:")
display(mean_age_by_site)

# Optional bar plot of the per-site means, saved as a PNG and then closed.
fig, ax = plt.subplots(figsize=(10, 6))
mean_age_by_site.plot(kind='bar', ax=ax)
ax.set_title('Mean Age by Site')
ax.set_ylabel('Age (years)')
plt.setp(ax.get_xticklabels(), rotation=0)
fig.tight_layout()
fig.savefig(os.path.join(OUTPUT_DIR, 'q4_mean_age_by_site_bar.png'))
plt.close(fig)


# 5. Save the site value counts to output/q4_site_counts.csv
# Name the index explicitly so the header row is always "site,patient_count".
# Without this, the first header cell depends on the pandas version: older
# releases leave the value_counts() index unnamed and write an empty label.
(
    df['site']
    .value_counts()
    .rename_axis('site')
    .to_csv(os.path.join(OUTPUT_DIR, 'q4_site_counts.csv'), header=['patient_count'])
)
print("\n5. Site value counts saved to output/q4_site_counts.csv")

In [None]:
# 5. Save to CSV
# site_counts carries a plain 0..n-1 index from reset_index(). Writing it with
# the default index=True would add an unnamed leading column to the file
# (read back as 'Unnamed: 0'), so only the two data columns are written.
output_file = 'output/q4_site_counts.csv'
site_counts.to_csv(output_file, index=False)
print(f"Saved site value counts to {output_file}")

# Read the file back to confirm what was written: columns site, patient_count.
print("\nPreview of saved CSV content:")
saved_counts = pd.read_csv(output_file)
display(saved_counts.head())

## Summary

Write 2-3 sentences about what you learned from exploring this dataset.

**Your summary here:**
Exploration of the raw dataset revealed several key areas requiring cleanup, including the presence of missing values, particularly in lab measurements, and the need for standardization across categorical columns like 'site' and 'intervention_group' due to inconsistent spelling and formatting. 
Furthermore, initial distribution plots indicated potential outliers in the age and BMI fields, confirming that data cleaning and transformation steps are necessary before any reliable statistical aggregation or modeling can be performed.

