# Assignment 5, Question 4: Data Exploration

**Points: 15**

In this notebook, you'll explore the clinical trial dataset using pandas selection and filtering techniques.

You'll use utility functions from `q3_data_utils` where helpful, but also demonstrate direct pandas operations.

## Setup

In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Import utilities from Q3
from q3_data_utils import load_data, detect_missing, filter_data

# Load the data
df = load_data('data/clinical_trial_raw.csv')
print(f"Loaded {len(df)} patients with {len(df.columns)} variables")

# Prewritten visualization functions for exploration
def plot_value_counts(series, title, figsize=(10, 6)):
    """
    Create a bar chart of value counts.
    
    Args:
        series: pandas Series with value counts
        title: Chart title
        figsize: Figure size tuple
    """
    plt.figure(figsize=figsize)
    series.plot(kind='bar')
    plt.title(title)
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()

def plot_crosstab(crosstab_data, title, figsize=(10, 6)):
    """
    Create a heatmap of crosstab data.
    
    Args:
        crosstab_data: pandas DataFrame from pd.crosstab()
        title: Chart title
        figsize: Figure size tuple
    """
    plt.figure(figsize=figsize)
    plt.imshow(crosstab_data.values, cmap='Blues', aspect='auto')
    plt.colorbar()
    plt.title(title)
    plt.xticks(range(len(crosstab_data.columns)), crosstab_data.columns, rotation=45)
    plt.yticks(range(len(crosstab_data.index)), crosstab_data.index)
    plt.tight_layout()
    plt.show()

Loaded 10000 patients with 18 variables


## Part 1: Basic Exploration (3 points)

Display:
1. Dataset shape
2. Column names and types
3. First 10 rows
4. Summary statistics (.describe())

In [2]:
# Display dataset info
print(f"Dataset shape:\n{df.shape}")
print(f"Column names and types:\n{df.dtypes}")
print(f"First ten rows:\n{df.head(10)}")
print(f"Summary statistics:\n{df.describe()}")

Dataset shape:
(10000, 18)
Column names and types:
patient_id             object
age                     int64
sex                    object
bmi                   float64
enrollment_date        object
systolic_bp           float64
diastolic_bp          float64
cholesterol_total     float64
cholesterol_hdl       float64
cholesterol_ldl       float64
glucose_fasting       float64
site                   object
intervention_group     object
follow_up_months        int64
adverse_events          int64
outcome_cvd            object
adherence_pct         float64
dropout                object
dtype: object
First ten rows:
  patient_id  age         sex   bmi enrollment_date  systolic_bp  \
0     P00001   80           F  29.3      2022-05-01        123.0   
1     P00002   80    Female     NaN      2022-01-06        139.0   
2     P00003   82      Female  -1.0      2023-11-04        123.0   
3     P00004   95      Female  25.4      2022-08-15        116.0   
4     P00005   95           M   NaN    
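
The preview above already hints at two kinds of gaps in `bmi`: true `NaN` values and a `-1.0` entry. The sketch below shows one way to count both with direct pandas operations. It uses a small toy frame modeled on the preview rather than the real file, and it assumes that a negative BMI is a placeholder for a missing measurement; that assumption is not confirmed by the dataset documentation.

```python
import numpy as np
import pandas as pd

# Toy rows modeled on the preview above; not the real dataset.
toy = pd.DataFrame({
    "patient_id": ["P00001", "P00002", "P00003", "P00004"],
    "bmi": [29.3, np.nan, -1.0, 25.4],
})

# Count the values pandas already treats as missing.
missing_before = toy["bmi"].isna().sum()

# Assumption: a negative BMI is a placeholder, not a measurement,
# so replace it with NaN before counting again.
bmi_clean = toy["bmi"].mask(toy["bmi"] < 0)
missing_after = bmi_clean.isna().sum()

print(f"Missing BMI before: {missing_before}, after masking negatives: {missing_after}")
```

Because `.describe()` includes placeholder values in the mean and minimum, masking them first would likely give more trustworthy summary statistics.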

## Part 2: Column Selection (3 points)

Demonstrate different selection methods:

1. Select only numeric columns using `.select_dtypes()`
2. Select specific columns by name
3. Select a subset of rows and columns using `.loc[]`

In [3]:
# Select numeric columns
print(f"Select numeric columns:\n{df.select_dtypes(include = ['number'])}")

Select numeric columns:
      age   bmi  systolic_bp  diastolic_bp  cholesterol_total  \
0      80  29.3        123.0          80.0              120.0   
1      80   NaN        139.0          81.0              206.0   
2      82  -1.0        123.0          86.0              172.0   
3      95  25.4        116.0          77.0              200.0   
4      95   NaN         97.0          71.0              185.0   
...   ...   ...          ...           ...                ...   
9995   72  23.2        122.0          73.0              182.0   
9996  100  28.9        124.0          78.0              157.0   
9997   78  23.8        110.0          63.0              154.0   
9998   86  27.0        139.0          98.0              196.0   
9999   67  29.4        134.0          83.0              197.0   

      cholesterol_hdl  cholesterol_ldl  glucose_fasting  follow_up_months  \
0                55.0             41.0            118.0                20   
1                58.0            107.0   

In [4]:
# Select specific columns
print(f"Select specific columns by name. Columns chosen: (1) age (2) dropout\n{df[['age', 'dropout']]}")

Select specific columns by name. Columns chosen: (1) age (2) dropout
      age dropout
0      80      No
1      80      No
2      82      No
3      95      No
4      95     Yes
...   ...     ...
9995   72      No
9996  100      No
9997   78      No
9998   86      No
9999   67      No

[10000 rows x 2 columns]


In [5]:
# Use .loc[] to select subset
print(f"Selecting a subset of rows and columns using .loc[]. Rows: 2-4 (Index) | Columns: age, dropout\n{df.loc[2:4, ['age', 'dropout']]}")

Selecting a subset of rows and columns using .loc[]. Rows: 2-4 (Index) | Columns: age, dropout
   age dropout
2   82      No
3   95      No
4   95     Yes
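
Note that `.loc[2:4]` returned three rows: label-based slices include the end label. A position-based slice with `.iloc` excludes the end position. The sketch below contrasts the two on a toy frame built from the first five rows shown above.

```python
import pandas as pd

# Toy frame copied from the first five rows of the output above.
toy = pd.DataFrame({
    "age": [80, 80, 82, 95, 95],
    "dropout": ["No", "No", "No", "No", "Yes"],
})

# Label-based: rows labeled 2, 3, and 4 (end label included).
by_label = toy.loc[2:4, ["age", "dropout"]]

# Position-based: rows at positions 2 and 3 (end position excluded).
by_position = toy.iloc[2:4, [0, 1]]

print(len(by_label), len(by_position))
```

With the default integer index the labels and positions coincide, so the only visible difference is the extra row returned by `.loc`.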


## Part 3: Filtering (4 points)

Filter the data to answer these questions:

1. How many patients are over 65 years old?
2. How many patients have systolic BP > 140?
3. Find patients who are both over 65 AND have systolic BP > 140
4. Find patients from Site A or Site B using `.isin()`

In [6]:
# Use filter_data to filter for age > 65
filter_age = [{'column' : 'age', 'condition' : 'greater_than', 'value' : 65}]

# Apply filter and count the results
patients_over_65 = filter_data(df, filter_age)
print(f"Patients over 65: {len(patients_over_65)}")

Patients over 65: 8326


In [7]:
# Use filter_data to filter for systolic_bp > 140
filter_systolicbp = [{'column' : 'systolic_bp', 'condition' : 'greater_than', 'value' : 140}]

# Apply the filter and count the results
high_bp = filter_data(df, filter_systolicbp)
print(f"Patients with high BP: {len(high_bp)}")

Patients with high BP: 538


In [8]:
# Combine both conditions in a single filter list
filter_age_bp = [
    {'column': 'age', 'condition': 'greater_than', 'value': 65},
    {'column': 'systolic_bp', 'condition': 'greater_than', 'value': 140}
]

# Apply the filter and count the results
both_conditions = filter_data(df, filter_age_bp)
print(f"Patients over 65 AND high BP: {len(both_conditions)}")

Patients over 65 AND high BP: 464


In [9]:
# Filter with the 'in_list' condition, which relies on .isin()
filter_site = [{'column' : 'site', 'condition' : 'in_list', 'value' : ['Site A', 'Site B']}]

# Apply the filter and count the results
site_ab = filter_data(df, filter_site)
print(f"Patients from Site A or B: {len(site_ab)}")

Patients from Site A or B: 1397
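
The same four questions can also be answered with direct pandas operations. The sketch below uses boolean masks and `.isin()` on a toy frame (not the real data); it is a direct-pandas analogue of the filters above, not a description of how `filter_data` is implemented.

```python
import numpy as np
import pandas as pd

# Toy frame with the same column names as the dataset; values are made up.
toy = pd.DataFrame({
    "age": [80, 60, 72, 95, 67],
    "systolic_bp": [123.0, 150.0, 145.0, np.nan, 141.0],
    "site": ["Site A", "site a", "Site B", "Site C", "  SITE B"],
})

# Each comparison yields a boolean Series; NaN compares as False.
over_65 = toy["age"] > 65
high_bp = toy["systolic_bp"] > 140

# Combine masks with & (element-wise AND).
both = toy[over_65 & high_bp]

# .isin() matches labels exactly, so 'site a' and '  SITE B' are excluded.
site_ab = toy[toy["site"].isin(["Site A", "Site B"])]

print(over_65.sum(), high_bp.sum(), len(both), len(site_ab))
```

The exact matching of `.isin()` is consistent with the result above: 1,397 equals the 661 rows labeled `Site A` plus the 736 labeled `Site B` in the value counts in Part 4, so variants such as `SITE A` or `site b` are not counted.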


## Part 4: Value Counts and Grouping (5 points)

1. Get value counts for the 'site' column
2. Get value counts for the 'intervention_group' column  
3. Create a crosstab of site vs intervention_group
4. Calculate mean age by site
5. Save the site value counts to `output/q4_site_counts.csv`

In [10]:
# Get value counts for site column
site_counts = df['site'].value_counts()
print(f"Value counts from the 'site' column:\n{site_counts}")

# Get value counts for intervention_group column
intervention_counts = df['intervention_group'].value_counts()
print(f"Value counts from the 'intervention_group' column:\n{intervention_counts}")

# Create a crosstab of site vs intervention group
crosstab = df.pivot_table(
    index = 'site',
    columns = 'intervention_group',
    aggfunc = 'size',
    fill_value = 0
)
print(f"Crosstabulation of the 'site' vs 'intervention_group' columns:\n{crosstab}")

# Calculate mean age by site
mean_age_by_site = df.groupby('site')['age'].mean()
print(f"Mean age by site:\n{mean_age_by_site}")

Value counts from the 'site' column:
site
site b         742
Site B         736
SITE B         703
SITE A         684
Site  A        681
Site A         661
Site C         658
site a         651
site c         615
SITE C         605
Site D         362
site d         349
Site_D         332
Site E         319
SITE D         313
SITE E         295
site e         294
  SITE B        94
  site b        90
  Site B        88
  Site C        83
  SITE A        74
  site a        74
  Site  A       67
  Site A        64
  site c        57
  SITE C        55
  Site E        42
  site d        41
  SITE D        41
  site e        36
  Site D        32
  SITE E        31
  Site_D        31
Name: count, dtype: int64
Value counts from the 'intervention_group' column:
intervention_group
Contrl              802
TREATMENT B         761
Treatment  B        760
Control             751
treatment b         750
control             734
Treatment B         730
CONTROL             715
TreatmentA          635


In [11]:
# Save output
site_counts.to_csv('output/q4_site_counts.csv')
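
`to_csv` does not create missing directories, so the save can fail if the target folder does not exist yet. The sketch below shows one way to guard against that and to read the file back as a check. It writes to a temporary directory and uses toy counts so it can run anywhere; in the notebook the folder would simply be `output`.

```python
import tempfile
from pathlib import Path

import pandas as pd

# Toy counts standing in for the real site_counts Series.
site_counts = pd.Series(["Site A", "Site B", "Site A"], name="site").value_counts()

with tempfile.TemporaryDirectory() as tmp:
    # Create the folder first; exist_ok avoids an error on reruns.
    out_dir = Path(tmp) / "output"
    out_dir.mkdir(parents=True, exist_ok=True)

    out_path = out_dir / "q4_site_counts.csv"
    site_counts.to_csv(out_path)

    # Read the file back to confirm labels and counts survived the round trip.
    round_trip = pd.read_csv(out_path, index_col=0)

print(round_trip)
```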


## Summary

Write 2-3 sentences about what you learned from exploring this dataset.

**Your summary here:**

This dataset is composed primarily of older adults: 8,326 of 10,000 patients (~83%) are over 65, yet only 538 (~5%) have a systolic BP above 140. The dataset is also messy. The value counts for 'site' and 'intervention_group' reveal inconsistent capitalization, stray whitespace, underscores, and misspellings (for example, 'Contrl'), which directly distort per-site counts: the Site A or B filter returned only the 1,397 rows whose labels match exactly. The data therefore needs cleaning before the results can be used or interpreted.
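
As a rough illustration of the cleaning step, the sketch below normalizes label variants like those seen in the value counts and then rebuilds the counts and the crosstab. The helper `normalize_label` is hypothetical (it is not part of `q3_data_utils`), and the rows are toy examples. Misspellings such as `Contrl` or `TreatmentA` would still need an explicit mapping, which this sketch does not attempt.

```python
import pandas as pd

# Toy rows using label variants that appear in the value counts above.
toy = pd.DataFrame({
    "site": ["Site A", "  SITE A", "site a", "Site  A", "Site_D", "site d"],
    "intervention_group": [
        "Control", "control", "CONTROL",
        "Treatment  B", "treatment b", "TREATMENT B",
    ],
})

def normalize_label(labels: pd.Series) -> pd.Series:
    """Hypothetical helper: trim, unify separators, collapse spaces, title-case."""
    return (
        labels.str.strip()
        .str.replace("_", " ", regex=False)
        .str.replace(r"\s+", " ", regex=True)
        .str.title()
    )

clean = toy.assign(
    site=normalize_label(toy["site"]),
    intervention_group=normalize_label(toy["intervention_group"]),
)

# Counts and crosstab now group all variants of the same label together.
site_counts = clean["site"].value_counts()
table = pd.crosstab(clean["site"], clean["intervention_group"])

print(site_counts)
print(table)
```

After normalization the six toy rows collapse to two sites and two intervention groups, which is the kind of consolidation the real columns would need before per-site statistics are meaningful.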
