## Import `dxdata` package and initialize Spark engine
### Docs at: https://github.com/dnanexus/OpenBio/blob/master/dxdata/getting_started_with_dxdata.ipynb

In [1]:
import dxdata
import os

# Initialize dxdata engine
engine = dxdata.connect(dialect="hive+pyspark")

## Connect to the dataset

Next, we can set a `DATASET_ID` variable, which takes a value: `[projectID]:[dataset ID]`
We use it to define the `dataset` with `dxdata.load_dataset` function.

**projectID** and **dataset ID** values are unique to your project.
Notebook example **101** explains how to get them.

In [18]:
project = os.getenv('DX_PROJECT_CONTEXT_ID')
record = os.popen("dx find data --type Dataset --delimiter ',' | awk -F ',' '{print $5}'").read().rstrip()
DATASET_ID = project + ":" + record
dataset = dxdata.load_dataset(id=DATASET_ID)

pheno = dataset['participant']

In [3]:
project

'project-J3YgK2QJ50vb4pG71P1gk3gQ'

In [17]:
os.popen("dx find data --type Dataset --delimiter ',' | awk -F ',' '{print $5}'").read()

'record-J3Yzq5QJ405q1YQKv2Bq1GZj\n'

## Extract field names from ukb_field_mapping.json

Now we'll load the field mapping JSON file and extract all field names in the format required for data extraction.

In [5]:
!ls /mnt/project

 A102_Explore-participant-data_Python.ipynb    'basics field names.json'
 Bulk					        data_extraction.ipynb
'Respiratory patients'			        data_participant.csv
'Showcase metadata'			       'field names.json'
'Untitled Workflow - 10___15___2025 11:24 PM'   ukb_field_mapping.json
 app240523_20251009163455		        ukb_respiratory_pipeline.py
 app240523_20251009163455.dataset


In [None]:
import json

# Load the hierarchical field mapping JSON file
# Update this path to match your actual file location
mapping_file = '/mnt/project/ukb_field_mapping_new.json'

with open(mapping_file, 'r') as f:
    field_mapping = json.load(f)

def extract_field_ids_recursive(data, field_ids=None, category_name=None):
    """
    Recursively extract field IDs from the hierarchical JSON structure.
    
    A field entry is identified by having 'name' and 'value_type' keys.
    Category levels are dictionaries that don't have these keys.
    
    Args:
        data: Dictionary (can be nested categories or field entries)
        field_ids: List to accumulate field IDs
    
    Returns:
        List of field IDs
    """
    if field_ids is None:
        field_ids = []
    
    if isinstance(data, dict):
        # Check if this is a field entry (has 'name' and 'value_type')
        if 'name' in data and 'value_type' in data:
            # This is a field entry, not a category - skip it
            return field_ids
        
        # Otherwise, iterate through the dictionary
        for key, value in data.items():
            if isinstance(value, dict):
                # Check if this value is a field entry
                if 'name' in value and 'value_type' in value:
                    # This key is a field ID
                    field_ids.append(key)
                else:
                    # This is a nested category, recurse into it
                    extract_field_ids_recursive(value, field_ids)
    
    return field_ids

def get_category_fields(field_mapping, category_path=None):
    """
    Extract field IDs from either all categories or a specific category.
    
    Args:
        field_mapping: The complete hierarchical field mapping dictionary
        category_path: Optional list of category names to navigate to specific category
                      e.g., ['Lifestyle and environment', 'Smoking'] 
                      If None, extracts all fields from entire structure
    
    Returns:
        List of field names (prefixed with 'p')
    """
    # If no category path specified, extract all fields
    if category_path is None:
        field_ids = extract_field_ids_recursive(field_mapping)
    else:
        # Navigate to the specific category
        current_level = field_mapping
        
        for category in category_path:
            if category in current_level:
                current_level = current_level[category]
            else:
                print(f"Warning: Category '{category}' not found in the structure")
                print(f"Available categories at this level: {list(current_level.keys())}")
                return []
        
        # Extract fields only from this category
        field_ids = extract_field_ids_recursive(current_level)
    
    # Convert to the format used in UK Biobank (prefixed with 'p')
    field_names = [f'p{field_id}' for field_id in field_ids]
    
    # Remove duplicates and sort
    field_names = sorted(list(set(field_names)))
    
    return field_names

def print_categories(field_mapping, indent=0):
    """
    Helper function to print the category structure for easy navigation.
    
    Args:
        field_mapping: The hierarchical field mapping dictionary
        indent: Current indentation level (for display)
    """
    if isinstance(field_mapping, dict):
        # Skip if this is a field entry
        if 'name' in field_mapping and 'value_type' in field_mapping:
            return
        
        for key, value in field_mapping.items():
            if isinstance(value, dict) and not ('name' in value and 'value_type' in value):
                # This is a category
                print("  " * indent + f"- {key}")
                print_categories(value, indent + 1)

# OPTION 1: Extract ALL fields (default behavior)
print("=" * 70)
print("EXTRACTING ALL FIELDS FROM ALL CATEGORIES")
print("=" * 70)
field_names = get_category_fields(field_mapping, category_path=None)

print(f"Total unique fields to extract: {len(field_names)}")
print(f"\nFirst 10 field names:")
for field in field_names[:10]:
    print(f"  {field}")

# OPTION 2: Extract fields from a SPECIFIC CATEGORY
# Uncomment and modify the category_path below to extract only one category

# print("\n" + "=" * 70)
# print("AVAILABLE CATEGORIES")
# print("=" * 70)
# print_categories(field_mapping)
# 
# print("\n" + "=" * 70)
# print("EXTRACTING FIELDS FROM SPECIFIC CATEGORY")
# print("=" * 70)
# # Example: Extract only respiratory-related fields
# # Modify the category_path list to navigate to your desired category
# category_path = ['Health and medical history', 'Respiratory']  # Adjust as needed
# field_names = get_category_fields(field_mapping, category_path=category_path)
# 
# print(f"Category: {' -> '.join(category_path)}")
# print(f"Total unique fields to extract: {len(field_names)}")
# print(f"\nField names:")
# for field in field_names:
#     print(f"  {field}")


Total unique fields to extract: 86

Field names:
p1558
p1787
p20002
p20116
p20126
p20205
p20208
p20252
p20255
p20256
p20257
p20258
p2090
p2100
p21000
p21001
p22506
p23104
p2316
p2335
p24003
p24004
p24005
p24006
p24016
p2867
p2887
p30000
p30010
p30020
p30030
p30040
p30050
p30060
p30070
p30080
p30090
p30100
p30110
p30120
p30130
p30140
p30150
p30600
p30610
p30620
p3063
p30630
p3064
p30640
p30650
p30660
p30670
p30680
p30690
p30700
p30710
p30720
p30730
p30740
p30750
p30760
p30770
p30780
p30790
p30800
p30810
p31
p34
p3786
p40001
p40002
p41202
p41204
p41270
p4559
p4598
p4631
p4717
p50
p52
p5375
p5663
p6138
p6152
p738


In [28]:
pattern = ".*FEV.*"
fields_height = list(pheno.find_fields(title_regex=pattern))

for _ in fields_height:
    print("[" + _.name + "]\t" + _.title + " (" + _.linkout + ")")
    

[p3063_i0_a0]	Forced expiratory volume in 1-second (FEV1) | Instance 0 | Array 0 (http://biobank.ctsu.ox.ac.uk/crystal/field.cgi?id=3063)
[p3063_i0_a1]	Forced expiratory volume in 1-second (FEV1) | Instance 0 | Array 1 (http://biobank.ctsu.ox.ac.uk/crystal/field.cgi?id=3063)
[p3063_i0_a2]	Forced expiratory volume in 1-second (FEV1) | Instance 0 | Array 2 (http://biobank.ctsu.ox.ac.uk/crystal/field.cgi?id=3063)
[p3063_i1_a0]	Forced expiratory volume in 1-second (FEV1) | Instance 1 | Array 0 (http://biobank.ctsu.ox.ac.uk/crystal/field.cgi?id=3063)
[p3063_i1_a1]	Forced expiratory volume in 1-second (FEV1) | Instance 1 | Array 1 (http://biobank.ctsu.ox.ac.uk/crystal/field.cgi?id=3063)
[p3063_i1_a2]	Forced expiratory volume in 1-second (FEV1) | Instance 1 | Array 2 (http://biobank.ctsu.ox.ac.uk/crystal/field.cgi?id=3063)
[p3063_i2_a0]	Forced expiratory volume in 1-second (FEV1) | Instance 2 | Array 0 (http://biobank.ctsu.ox.ac.uk/crystal/field.cgi?id=3063)
[p3063_i2_a1]	Forced expiratory vo

## Verify fields exist in the dataset and get their titles

Let's verify that these fields exist in the UK Biobank dataset and retrieve their full titles.


In [30]:
# Verify fields and get their information (including ALL instances)
# Use fuzzy matching to find all instances of each field (e.g., p50, p50_i0, p50_i1, etc.)

verified_fields = []
missing_fields = []

for field_name in field_names:
    try:
        # Try exact match first (for fields without instances)
        try:
            field = pheno.find_field(name=field_name)
            verified_fields.append({
                'name': field.name,
                'title': field.title,
                'linkout': field.linkout
            })
            print(f"✓ [{field.name}]\t{field.title}")
        except:
            # If exact match fails, search for all instances using regex
            # Pattern matches: p50$ (exact) OR p50_i0, p50_i1, p50_i2, etc.
            pattern = f"^{field_name}$|^{field_name}_i\\d+$|^{field_name}_i\\d+_a\\d+$"
            instance_fields = list(pheno.find_fields(name_regex=pattern))
            
            if instance_fields:
                for field in instance_fields:
                    verified_fields.append({
                        'name': field.name,
                        'title': field.title,
                        'linkout': field.linkout
                    })
                    print(f"✓ [{field.name}]\t{field.title}")
                print(f"  → Found {len(instance_fields)} instances for {field_name}")
            else:
                raise Exception(f"No instances found")
                
    except Exception as e:
        missing_fields.append(field_name)
        print(f"✗ [{field_name}]\tNot found: {str(e)}")

print(f"\n\nSummary:")
print(f"Found: {len(verified_fields)} fields (including all instances)")
print(f"Missing: {len(missing_fields)} base fields")
if missing_fields:
    print(f"\nMissing fields: {missing_fields}")


✓ [p1558_i0]	Alcohol intake frequency. | Instance 0
✓ [p1558_i1]	Alcohol intake frequency. | Instance 1
✓ [p1558_i2]	Alcohol intake frequency. | Instance 2
✓ [p1558_i3]	Alcohol intake frequency. | Instance 3
  → Found 4 instances for p1558
✓ [p1787_i0]	Maternal smoking around birth | Instance 0
✓ [p1787_i1]	Maternal smoking around birth | Instance 1
✓ [p1787_i2]	Maternal smoking around birth | Instance 2
  → Found 3 instances for p1787
✓ [p20002_i0_a0]	Non-cancer illness code, self-reported | Instance 0 | Array 0
✓ [p20002_i0_a1]	Non-cancer illness code, self-reported | Instance 0 | Array 1
✓ [p20002_i0_a2]	Non-cancer illness code, self-reported | Instance 0 | Array 2
✓ [p20002_i0_a3]	Non-cancer illness code, self-reported | Instance 0 | Array 3
✓ [p20002_i0_a4]	Non-cancer illness code, self-reported | Instance 0 | Array 4
✓ [p20002_i0_a5]	Non-cancer illness code, self-reported | Instance 0 | Array 5
✓ [p20002_i0_a6]	Non-cancer illness code, self-reported | Instance 0 | Array 6
✓ [p200

## Save field names to JSON file

Save the verified field names to a JSON file in the same format as the original field names.json file.


In [7]:
# It is read-only file system

# # Save field names to JSON file
# output_file = '/mnt/project/extracted_field_names.json'

# # Save as list (one field per line like the original)
# with open(output_file, 'w') as f:
#     json.dump([f['name'] for f in verified_fields], f, indent=0)

# print(f"Field names saved to: {output_file}")
# print(f"Total fields saved: {len(verified_fields)}")


## Extract Data in Batches and Save

This cell performs batch extraction to handle large numbers of fields efficiently.

**IMPORTANT**: Uses **Direct Spark Save** to avoid Spark driver memory errors.

**Features:**
- Processes fields in batches (default: 100 fields per batch)
- Saves directly from Spark DataFrame to CSV (avoids `toPandas()` memory issues)
- Saves each batch as a separate CSV file (`ukb_batch_01_of_XX.csv`, etc.)
- Automatically loads and merges all batches into a single pandas DataFrame
- Includes memory cleanup after processing

**Configuration:**
- `BATCH_SIZE`: Number of fields per batch (default: 100, reduce if still getting errors)
- `USE_SPARK_SAVE`: Whether to save directly from Spark (recommended: `True`)

**Why Direct Spark Save?**
When converting large Spark DataFrames to pandas using `toPandas()`, you may encounter:
```
Py4JJavaError: Total size of serialized results is bigger than spark.driver.maxResultSize (1024.0 MiB)
```

By setting `USE_SPARK_SAVE=True`, data is saved directly from Spark to CSV files, bypassing this limitation.

**Benefits:**
- Prevents Spark driver memory errors
- Progress is saved incrementally (if extraction fails, you don't lose everything)
- Each batch file can be used independently if needed
- Can process any number of fields without hitting memory limits


## Retrieve data from the table

The following code selects the `participant` table.
Then we can define which field we are interested in using the `find_field` function.

There are three main ways to identify the field of interest:

- With `name` argument: here we give field ID. We can construct field ID used by the `dxdata` package from the field ID defined by UKB Showcase. The numeric showcase ID is translated to the Spark DB column name by adding the letter `p` at the beginning: e.g. *Standing height* showcase id is `50`, so Spark ID would be `p50`. Usually, fields have multiple instances. In such case, we add the `_i` suffix followed by instance number, e.g. *Standing height | Instance 0* will be `p50_i0`
- With `title` argument: here we define the field by full title, followed by ` | Instance` suffix, e.g. `Age at recruitment` or `Standing height | Instance 0`
- With `title_regex` argument: here we define the field by [regular expression](https://docs.python.org/3/howto/regex.html) matching the part of the title. We can use a keyword here, e.g. `.*height.*` will return all columns with the word *height* in the title.

## Optional: Reload and Merge Saved Batch Files

If you've already run the batch extraction and saved the intermediate batch files, you can use this cell to reload and merge them without re-extracting from UK Biobank.

**Usage:**
```python
# Uncomment to reload all batch files matching the pattern
df = reload_and_merge_batches()

# Or specify a custom pattern
df = reload_and_merge_batches(batch_pattern='my_custom_batch_*.csv')
```

This is useful when:
- You need to restart your notebook session
- You want to work with the data without re-running the extraction
- You're experimenting with different processing approaches


In [None]:
# Get field objects for extraction
field_objects = []
for field_info in verified_fields:
    try:
        field = pheno.find_field(name=field_info['name'])
        field_objects.append(field)
    except Exception as e:
        print(f"Error loading field {field_info['name']}: {e}")

# Add participant ID - this will be included in EVERY batch
eid_field = pheno.find_field(name="eid")

# Ensure we have ICD-10 diagnosis fields for filtering
# These fields contain the diagnosis codes
icd10_fields = ['p41202', 'p41204', 'p41270']
icd10_field_objects = []
for icd_field in icd10_fields:
    try:
        # Check if already in field_objects
        if not any(f.name == icd_field for f in field_objects):
            field = pheno.find_field(name=icd_field)
            icd10_field_objects.append(field)
            print(f"Added ICD-10 diagnosis field: {icd_field}")
    except Exception as e:
        print(f"Warning: Could not add {icd_field}: {e}")

# Append ICD-10 fields to field_objects
field_objects.extend(icd10_field_objects)

print(f"\nTotal fields to extract: {len(field_objects)} (excluding eid)")

# ============================================================
# BATCH PROCESSING CONFIGURATION
# ============================================================
BATCH_SIZE = 100  # Reduced to 100 fields per batch to avoid Spark memory errors
USE_SPARK_SAVE = True  # Save directly from Spark (avoids driver maxResultSize error)

# Split fields into batches
total_batches = (len(field_objects) + BATCH_SIZE - 1) // BATCH_SIZE
print(f"Processing in {total_batches} batch(es) of up to {BATCH_SIZE} fields each")
print(f"Save method: {'Direct Spark save (recommended)' if USE_SPARK_SAVE else 'Pandas conversion'}")

batch_files = []
import os
import shutil

for batch_idx in range(total_batches):
    start_idx = batch_idx * BATCH_SIZE
    end_idx = min((batch_idx + 1) * BATCH_SIZE, len(field_objects))
    
    # Get fields for this batch
    batch_fields = [eid_field] + field_objects[start_idx:end_idx]
    
    print(f"\n{'='*70}")
    print(f"BATCH {batch_idx + 1}/{total_batches}")
    print(f"Processing fields {start_idx + 1} to {end_idx} ({len(batch_fields)-1} fields + eid)")
    print(f"{'='*70}")
    
    # Retrieve data for this batch (returns PySpark DataFrame)
    print(f"Extracting data from UK Biobank...")
    spark_df_batch = pheno.retrieve_fields(fields=batch_fields, engine=engine)
    
    # Get row count using PySpark method
    row_count = spark_df_batch.count()
    col_count = len(spark_df_batch.columns)
    print(f"Extracted: {row_count} participants, {col_count} fields")
    
    # Save batch
    batch_filename = f'ukb_batch_{batch_idx+1:02d}_of_{total_batches:02d}.csv'
    
    if USE_SPARK_SAVE:
        # Save directly from Spark to avoid memory issues
        print(f"Saving directly from Spark DataFrame (avoids memory errors)...")
        
        # Save as CSV using Spark (creates a directory with part files)
        temp_dir = f'temp_batch_{batch_idx+1:02d}'
        spark_df_batch.coalesce(1).write.mode('overwrite').option('header', 'true').csv(temp_dir)
        
        # Move the single CSV file from the temp directory to the final location
        # Find the part file (usually part-00000*.csv)
        part_files = [f for f in os.listdir(temp_dir) if f.startswith('part-') and f.endswith('.csv')]
        if part_files:
            source_file = os.path.join(temp_dir, part_files[0])
            shutil.move(source_file, batch_filename)
            file_size = os.path.getsize(batch_filename) / (1024**2)  # Size in MB
            print(f"✓ Saved batch to: {batch_filename} ({file_size:.2f} MB)")
        else:
            print(f"Warning: No part file found in {temp_dir}")
        
        # Clean up temp directory
        shutil.rmtree(temp_dir, ignore_errors=True)
        
    else:
        # Alternative: Convert to pandas first (may fail with large batches)
        try:
            print(f"Converting to pandas DataFrame...")
            df_batch = spark_df_batch.toPandas()
            print(f"Batch shape: {df_batch.shape}")
            
            df_batch.to_csv(batch_filename, index=False)
            file_size = os.path.getsize(batch_filename) / (1024**2)  # Size in MB
            print(f"✓ Saved batch to: {batch_filename} ({file_size:.2f} MB)")
        except Exception as e:
            print(f"ERROR: Failed to convert batch {batch_idx+1} to pandas: {e}")
            print(f"Consider reducing BATCH_SIZE or setting USE_SPARK_SAVE=True")
            raise
    
    batch_files.append(batch_filename)
    print(f"✓ Batch {batch_idx + 1} completed")

print(f"\n{'='*70}")
print(f"ALL BATCHES SAVED")
print(f"{'='*70}")
print(f"Saved {len(batch_files)} batch files:")
total_size = 0
for bf in batch_files:
    if os.path.exists(bf):
        file_size = os.path.getsize(bf) / (1024**2)  # Size in MB
        total_size += file_size
        print(f"  ✓ {bf} ({file_size:.2f} MB)")
print(f"Total size: {total_size:.2f} MB")

print(f"\n{'='*70}")
print(f"LOADING AND MERGING ALL BATCHES")
print(f"{'='*70}")
print(f"Note: Loading saved CSV files into pandas for merging...")

# Load all batch CSVs into pandas DataFrames
import pandas as pd
batch_dataframes = []

for i, batch_file in enumerate(batch_files, start=1):
    print(f"Loading batch {i}/{len(batch_files)}: {batch_file}")
    df_batch = pd.read_csv(batch_file)
    print(f"  Shape: {df_batch.shape}")
    batch_dataframes.append(df_batch)

# Merge all batches on 'eid' (participant ID)
print(f"\nMerging batches...")
df = batch_dataframes[0]
for i, batch_df in enumerate(batch_dataframes[1:], start=2):
    print(f"  Merging batch {i}...")
    # Drop 'eid' from subsequent batches before merging to avoid duplication
    batch_df_no_eid = batch_df.drop(columns=['eid'])
    df = df.join(batch_df_no_eid)

print(f"\n✓ All batches merged successfully!")
print(f"Full dataset shape: {df.shape}")
print(f"Total participants: {df.shape[0]}")
print(f"Total fields: {df.shape[1]} (including eid)")

# Display first few rows
print("\nFirst few rows:")
print(df.head())

# Clean up memory
del batch_dataframes
import gc
gc.collect()

print(f"\n✓ Memory cleaned up")


Extracting data for 399 fields...
Data extracted: 501936 participants, 399 fields

Converting to pandas DataFrame...
Full dataset shape: (501936, 399)

First few rows:
       eid  p1558_i0  p1558_i1  p1558_i2  p1558_i3  p1787_i0  p1787_i1  \
0  1000020       2.0       NaN       NaN       NaN       0.0       NaN   
1  1000053       2.0       NaN       NaN       NaN       0.0       NaN   
2  1000171       1.0       NaN       NaN       NaN      -1.0       NaN   
3  1000186       5.0       NaN       NaN       NaN       0.0       NaN   
4  1000199       5.0       NaN       NaN       NaN       1.0       NaN   

   p1787_i2  p20002_i0_a0  p20002_i0_a1  ...  p6138_i2  p6138_i3  p6152_i0  \
0       NaN        1371.0        1473.0  ...      None      None      [-7]   
1       NaN        1065.0        1458.0  ...      None      None      [-7]   
2       NaN           NaN           NaN  ...      None      None      [-7]   
3       NaN        1309.0        1265.0  ...      None      None      [-7]

In [21]:
df.head()

Unnamed: 0,eid,p1558_i0,p1558_i1,p1558_i2,p1558_i3,p1787_i0,p1787_i1,p1787_i2,p20116_i0,p20116_i1,...,p6138_i2,p6138_i3,p6152_i0,p6152_i1,p6152_i2,p6152_i3,p738_i0,p738_i1,p738_i2,p738_i3
0,1000020,2.0,,,,0.0,,,0.0,,...,,,[-7],,,,2.0,,,
1,1000053,2.0,,,,0.0,,,0.0,,...,,,[-7],,,,3.0,,,
2,1000171,1.0,,,,-1.0,,,1.0,,...,,,[-7],,,,4.0,,,
3,1000186,5.0,,,,0.0,,,0.0,,...,,,[-7],,,,2.0,,,
4,1000199,5.0,,,,1.0,,,0.0,,...,,,[9],,,,1.0,,,


## Filter for Respiratory Disease Cohort

We will filter participants to include only those with respiratory disease diagnoses based on ICD-10 codes:
- **J09-J98**: Diseases of the respiratory system
- **I26-I27**: Pulmonary heart disease and diseases of pulmonary circulation

This filtering will be applied after data extraction using the diagnosis fields (p41202 "Diagnoses - main ICD10", p41204 "Diagnoses - secondary ICD10", p41270 "Diagnoses ICD10").


In [33]:
import re

def has_respiratory_diagnosis(row):
    """
    Check if a participant has respiratory disease diagnosis based on ICD-10 codes:
    - J09-J98: Diseases of the respiratory system
    - I26-I27: Pulmonary heart disease and diseases of pulmonary circulation
    
    Note: ICD-10 codes may be stored without decimal points (e.g., J181 = J18.1)
    """
    # Collect all diagnosis codes from available ICD-10 fields
    all_diagnoses = []
    
    for field in ['p41202', 'p41204', 'p41270']:
        if field in row.index and row[field] is not None:
            # Handle both list and single string values
            if isinstance(row[field], list):
                all_diagnoses.extend(row[field])
            else:
                all_diagnoses.append(str(row[field]))
    
    # Check each diagnosis code
    for code in all_diagnoses:
        if code and isinstance(code, str):
            # Remove any whitespace and convert to uppercase
            code = code.strip().upper()
            
            # Check for J09-J98 (respiratory diseases)
            # ICD-10 format: Letter + 2-3 digits (category) + optional subcategory
            # Examples: J18.1 (stored as J181), J96.90 (stored as J9690)
            if code.startswith('J'):
                # Extract the category code (first 2 digits after 'J')
                # For J codes, category is always 2 digits: J00-J99
                match = re.match(r'J(\d{2})', code)
                if match:
                    category = int(match.group(1))
                    if 9 <= category <= 98:
                        return True
            
            # Check for I26-I27 (pulmonary heart disease)
            # For I codes, category is also 2 digits: I00-I99
            elif code.startswith('I'):
                match = re.match(r'I(\d{2})', code)
                if match:
                    category = int(match.group(1))
                    if 26 <= category <= 27:
                        return True
    
    return False

# Apply filter to get respiratory disease cohort
print("Filtering for respiratory disease patients...")
df['has_respiratory_disease'] = df.apply(has_respiratory_diagnosis, axis=1)

respiratory_df = df[df['has_respiratory_disease'] == True].copy()
respiratory_df = respiratory_df.drop(columns=['has_respiratory_disease'])

print(f"\nCohort filtering results:")
print(f"Original dataset: {len(df)} participants")
print(f"Respiratory disease cohort: {len(respiratory_df)} participants ({len(respiratory_df)/len(df)*100:.2f}%)")

# Show some diagnosis examples
print("\nExample diagnoses from filtered cohort (first 5 participants):")
for idx, row in respiratory_df.head().iterrows():
    diagnoses = []
    for field in ['p41202', 'p41204', 'p41270']:
        if field in row.index and row[field] is not None:
            if isinstance(row[field], list):
                diagnoses.extend(row[field])
            else:
                diagnoses.append(str(row[field]))
    
    # Filter for respiratory codes only using correct parsing
    resp_codes = []
    for d in diagnoses:
        if d and isinstance(d, str):
            d_clean = d.strip().upper()
            # Check J09-J98
            if d_clean.startswith('J'):
                match = re.match(r'J(\d{2})', d_clean)
                if match and 9 <= int(match.group(1)) <= 98:
                    resp_codes.append(d)
            # Check I26-I27
            elif d_clean.startswith('I'):
                match = re.match(r'I(\d{2})', d_clean)
                if match and 26 <= int(match.group(1)) <= 27:
                    resp_codes.append(d)
    
    print(f"  Participant {row['eid']}: {', '.join(resp_codes[:10])}")  # Show first 10 codes

# Save filtered cohort to CSV
output_csv = 'ukb_respiratory_cohort.csv'
respiratory_df.to_csv(output_csv, index=False)
print(f"\nFiltered respiratory cohort saved to: {output_csv}")

# Also save full dataset (unfiltered) for comparison
output_csv_full = 'ukb_full_data.csv'
df.to_csv(output_csv_full, index=False)
print(f"Full dataset saved to: {output_csv_full}")


Filtering for respiratory disease patients...

Cohort filtering results:
Original dataset: 501936 participants
Respiratory disease cohort: 133842 participants (26.67%)

Example diagnoses from filtered cohort (first 5 participants):
  Participant 1000496: J181, J90, J981, J181, J90, J981
  Participant 1000517: J151, J181, J189, J90, J439, J440, J9690, J151, J181, J189
  Participant 1001018: I269, I269, J439, J981, J984, I269, J439, J981, J984
  Participant 1001080: J189, J189, J189
  Participant 1001128: J13, J459, J13, J459

Filtered respiratory cohort saved to: ukb_respiratory_cohort.csv
Full dataset saved to: ukb_full_data.csv


## Analyze ICD-10 Code Distribution in Respiratory Cohort

Let's examine the distribution of respiratory ICD-10 codes to understand the types of respiratory diseases in our cohort.


In [None]:
from collections import Counter

# Collect all respiratory ICD-10 codes from the cohort
all_resp_codes = []

for idx, row in respiratory_df.iterrows():
    for field in ['p41202', 'p41204', 'p41270']:
        if field in row.index and row[field] is not None:
            if isinstance(row[field], list):
                codes = row[field]
            else:
                codes = [str(row[field])]
            
            # Filter for respiratory codes (J09-J98, I26-I27)
            for code in codes:
                if code and isinstance(code, str):
                    code = code.strip().upper()
                    # Check if it's a respiratory code
                    # Use correct ICD-10 parsing: extract first 2 digits (category)
                    if code.startswith('J'):
                        match = re.match(r'J(\d{2})', code)
                        if match:
                            category = int(match.group(1))
                            if 9 <= category <= 98:
                                all_resp_codes.append(code)
                    elif code.startswith('I'):
                        match = re.match(r'I(\d{2})', code)
                        if match:
                            category = int(match.group(1))
                            if 26 <= category <= 27:
                                all_resp_codes.append(code)

# Count frequency of each code
code_counts = Counter(all_resp_codes)

print(f"Total respiratory diagnoses recorded: {len(all_resp_codes)}")
print(f"Unique respiratory diagnosis codes: {len(code_counts)}")
print(f"\nTop 20 most common respiratory diagnoses:")
print("-" * 60)

for code, count in code_counts.most_common(20):
    percentage = (count / len(respiratory_df)) * 100
    print(f"{code:8s} : {count:6d} cases ({percentage:5.2f}% of cohort)")

# Breakdown by major category
j_codes = {k: v for k, v in code_counts.items() if k.startswith('J')}
i_codes = {k: v for k, v in code_counts.items() if k.startswith('I')}

print(f"\n\nBreakdown by ICD-10 category:")
print(f"J codes (Respiratory system diseases): {sum(j_codes.values())} diagnoses across {len(j_codes)} unique codes")
print(f"I26-I27 codes (Pulmonary heart disease): {sum(i_codes.values())} diagnoses across {len(i_codes)} unique codes")

# Save code statistics to file
import json
code_stats = {
    'total_participants': len(respiratory_df),
    'total_diagnoses': len(all_resp_codes),
    'unique_codes': len(code_counts),
    'top_20_codes': [{'code': code, 'count': count, 'percentage': round((count/len(respiratory_df))*100, 2)} 
                     for code, count in code_counts.most_common(20)],
    'category_breakdown': {
        'J_codes': {'count': sum(j_codes.values()), 'unique': len(j_codes)},
        'I_codes': {'count': sum(i_codes.values()), 'unique': len(i_codes)}
    }
}

stats_file = 'icd10_code_statistics.json'
with open(stats_file, 'w') as f:
    json.dump(code_stats, f, indent=2)
    
print(f"\n\nCode statistics saved to: {stats_file}")


Total respiratory diagnoses recorded: 590935
Unique respiratory diagnosis codes: 211

Top 20 most common respiratory diagnoses:
------------------------------------------------------------
J459     : 101651 cases (75.95% of cohort)
J449     :  43193 cases (32.27% of cohort)
J181     :  42408 cases (31.69% of cohort)
J22      :  40329 cases (30.13% of cohort)
J90      :  38496 cases (28.76% of cohort)
J189     :  29217 cases (21.83% of cohort)
I269     :  20890 cases (15.61% of cohort)
J981     :  20739 cases (15.50% of cohort)
J440     :  18262 cases (13.64% of cohort)
J47      :  14351 cases (10.72% of cohort)
J342     :  11746 cases ( 8.78% of cohort)
J439     :  11191 cases ( 8.36% of cohort)
J301     :   9337 cases ( 6.98% of cohort)
J984     :   8852 cases ( 6.61% of cohort)
J348     :   8633 cases ( 6.45% of cohort)
J690     :   8080 cases ( 6.04% of cohort)
J841     :   7902 cases ( 5.90% of cohort)
J128     :   6951 cases ( 5.19% of cohort)
J9690    :   6903 cases ( 5.16% of co