# Covariate Generation Script

**Purpose:** Extract phenotypes (age, sex), genetic principal components, genomic ancestry, and assessment centre data to generate a covariate file for association tools such as REGENIE.

**Output:** A TSV file containing `FID`, `IID`, `SEX`, `AGE`, `CENTER` (assessment centre), the derived interaction terms `AGE2`, `AGESEX` and `AGE2SEX`, `PC1`–`PC20`, and `ANCESTRY` (genomic; may be missing for some participants).

---

In [None]:
import re
import subprocess
import shlex

import dxdata
import dxpy
import pyspark

from pyspark.sql import functions as F
from pyspark.sql.types import IntegerType

# --- HELPER FUNCTIONS ---


def fields_for_id(field_id, participant):
    """Look up every column that belongs to a dataset field ID.

    Columns named ``p<id>``, ``p<id>_i<n>``, ``p<id>_a<m>`` or
    ``p<id>_i<n>_a<m>`` are returned. They are ordered by instance number
    ``n`` and then by array index ``m``; an absent suffix counts as 0.
    """
    name_regex = rf"^p{str(field_id)}(_i\d+)?(_a\d+)?$"
    matches = participant.find_fields(name_regex=name_regex)

    instance_re = re.compile(r"_i(\d+)")
    array_re = re.compile(r"_a(\d+)")

    def _suffix_number(pattern, name):
        found = pattern.search(name)
        return 0 if found is None else int(found.group(1))

    def _order(field):
        return (
            _suffix_number(instance_re, field.name),
            _suffix_number(array_re, field.name),
        )

    return sorted(matches, key=_order)


def get_primary_column(df, field_id):
    """Return the preferred column of ``df`` for ``field_id``.

    Only columns named exactly ``p<id>`` or starting with ``p<id>_`` (e.g.
    ``p54_i0``, ``p22009_a1``) are considered. Bare prefix matching would
    also pick up a different field such as ``p5400`` when looking for 54.
    Among the candidates, the shortest name wins, then the alphabetically
    first. So ``p54`` beats ``p54_i0``, ``p54_i0`` beats ``p54_i1``, and
    ``p54_i2`` beats ``p54_i10``.

    Raises:
        ValueError: if no column belongs to ``field_id``.
    """
    base = f"p{field_id}"
    candidates = [c for c in df.columns if c == base or c.startswith(base + "_")]
    if not candidates:
        raise ValueError(f"Field {field_id} not found in dataframe")
    return min(candidates, key=lambda c: (len(c), c))

In [None]:
# --- INITIALIZATION ---
# getOrCreate() reuses an already-running SparkContext, so re-running this
# cell does not fail with the "multiple SparkContexts" error.
sc = pyspark.SparkContext.getOrCreate()
spark = pyspark.sql.SparkSession(sc)
# Disable automatic broadcast joins (a threshold of -1 turns them off).
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)

# Load Dataset: a "Dataset" object in the project root whose name matches
# the glob "app*.dataset"; its "participant" entity is used below.
dispensed_dataset_id = dxpy.find_one_data_object(
    typename="Dataset", name="app*.dataset", folder="/", name_mode="glob"
)["id"]
dataset = dxdata.load_dataset(id=dispensed_dataset_id)
participant = dataset["participant"]

In [8]:
# --- CONFIGURATION ---
FIELD_IDS = {
    "AGE": "21022",
    "SEX": "22001",
    "PCS": "22009",
    "ANCESTRY": "30079",  # Genomic Ancestry (New), see https://doi.org/10.1101/2024.03.13.24303864
    "CENTER": "54",  # Assessment Centre
}
# Number of genetic PCs to keep (array indices up to and including N_PCS).
N_PCS = 20

# --- 1. RETRIEVE DATA ---
# Collect every instance/array column of each configured field.
all_fields = []
for fid in FIELD_IDS.values():
    all_fields.extend(fields_for_id(fid, participant))

field_names = ["eid"] + [f.name for f in all_fields]

print("Retrieving data...")
# coding_values="raw" ensures Categories (like Center) are returned as Integers
df_raw = participant.retrieve_fields(
    names=field_names, engine=dxdata.connect(), coding_values="raw"
)

# --- 2. DEFINE COLUMNS ---

# Identify correct column names for Core vars
col_age = get_primary_column(df_raw, FIELD_IDS["AGE"])
col_sex = get_primary_column(df_raw, FIELD_IDS["SEX"])

# Identify PC columns (array indices <= N_PCS). The full name is matched
# against an anchored regex, so a column without an "_a<m>" suffix is
# skipped rather than raising IndexError. PCs are ordered numerically
# (a2 before a10).
pc_name_re = re.compile(rf"p{FIELD_IDS['PCS']}(?:_i\d+)?_a(\d+)")
pc_number = {}
for c in df_raw.columns:
    pc_match = pc_name_re.fullmatch(c)
    if pc_match and int(pc_match.group(1)) <= N_PCS:
        pc_number[c] = int(pc_match.group(1))
col_pcs = sorted(pc_number, key=pc_number.get)
if len(col_pcs) != N_PCS:
    print(f"WARNING: expected {N_PCS} PC columns, found {len(col_pcs)}")

# Identify Extended vars (Instance 0)
col_center = get_primary_column(df_raw, FIELD_IDS["CENTER"])  # p54_i0
col_ancestry = get_primary_column(df_raw, FIELD_IDS["ANCESTRY"])  # p30079

print(f"Using Core Cols: {col_age}, {col_sex}, and {len(col_pcs)} PCs")
print(f"Using Extended Cols: {col_center}, {col_ancestry}")

# Create Friendly Name Map for PCs (e.g. p22009_a1 -> PC1)
pcs_map = {c: f"PC{pc_number[c]}" for c in col_pcs}

# --- 3. FILTERING ---
# Drop rows missing age, sex, assessment centre or any PC.
# Rows missing genomic ancestry are kept.
df_core_base = df_raw.na.drop(subset=[col_age, col_sex, col_center] + col_pcs)

print(f"Participants in Core Set: {df_core_base.count()}")

Retrieving data...
Using Core Cols: p21022, p22001, and 20 PCs
Using Extended Cols: p54_i0, p30079
Participants in Core Set: 487713


In [5]:
# Apply transformations to the base set.
# The interaction terms are computed with integer arithmetic from the
# already-cast AGE/SEX columns. This guarantees AGE2 == AGE * AGE exactly,
# instead of going through a floating-point pow() and a truncating cast.
df_processed = (
    df_core_base.withColumn("FID", F.col("eid"))
    .withColumn("IID", F.col("eid"))
    .withColumn("SEX", F.col(col_sex).cast(IntegerType()))
    .withColumn("AGE", F.col(col_age).cast(IntegerType()))
    .withColumn("CENTER", F.col(col_center).cast(IntegerType()))
    .withColumn("ANCESTRY", F.col(col_ancestry).cast(IntegerType()))
    # Standard Interactions
    .withColumn("AGE2", F.col("AGE") * F.col("AGE"))
    .withColumn("AGESEX", F.col("AGE") * F.col("SEX"))
    .withColumn("AGE2SEX", F.col("AGE2") * F.col("SEX"))
    # Rename PCs (pcs_map: raw PC column name -> PC<n>)
    .select("*", *[F.col(c).alias(pcs_map[c]) for c in col_pcs])
)

In [16]:
# Output column order: IDs, core covariates and interactions, PCs, ancestry.
id_and_core_cols = ["FID", "IID", "SEX", "AGE", "CENTER", "AGE2", "AGESEX", "AGE2SEX"]
cols_keep = id_and_core_cols + list(pcs_map.values()) + ["ANCESTRY"]

# Sanity check: report the row count. It should equal the core-set count
# printed earlier, since no rows are dropped after filtering.
print(f"Final Rows: {df_processed.count()}")

df_final = df_processed.select(*cols_keep)

# Count rows that have no genomic ancestry value.
ancestry_missing = F.when(F.col("ANCESTRY").isNull(), "ANCESTRY")
df_final.select(F.count(ancestry_missing).alias("Missing_Ancestry")).show()

# Preview one record with the identifier columns hidden.
df_final.drop("FID", "IID").show(1, vertical=True)

Final Rows: 487713
+----------------+
|Missing_Ancestry|
+----------------+
|           40073|
+----------------+

-RECORD 0--------------
 SEX      | 0          
 AGE      | 52         
 CENTER   | 11011      
 AGE2     | 2704       
 AGESEX   | 0          
 AGE2SEX  | 0          
 PC1      | -12.417    
 PC2      | 6.75787    
 PC3      | -4.42069   
 PC4      | 0.749104   
 PC5      | -1.30339   
 PC6      | 0.0162366  
 PC7      | 1.29456    
 PC8      | -1.45318   
 PC9      | -2.0664    
 PC10     | -2.42804   
 PC11     | -2.04608   
 PC12     | -0.119549  
 PC13     | -0.705609  
 PC14     | -1.9216    
 PC15     | 1.42809    
 PC16     | 3.75853    
 PC17     | 0.0576762  
 PC18     | -0.0193767 
 PC19     | 1.18116    
 PC20     | -1.58071   
 ANCESTRY | 5          
only showing top 1 row



In [None]:
# Set to your own
# Destination folder in the DNAnexus project (passed to `dx upload --path`).
DX_PROJECT_PATH = "TASR/Phenotypes/"
# Name of the TSV file that is written and uploaded.
OUTPUT_FILENAME = "Covariates.tsv"


def upload_df(df, filename: str, project_path: str, temp_dir: str = "/tmp"):
    """Write ``df`` as one tab-separated file and upload it with ``dx``.

    Steps:
      1. ``df.coalesce(1).write.csv`` writes a single-partition TSV (with a
         header row) to ``{temp_dir}/{filename}`` on Spark's default
         filesystem. Presumably this is HDFS, since it is read back with
         ``hadoop fs``; confirm for your cluster.
      2. ``hadoop fs -getmerge`` merges that output into the single file
         ``..{temp_dir}/{filename}``, relative to the working directory.
      3. ``dx upload`` copies that file into ``project_path``.

    Each command's stdout/stderr is printed. Execution stops at the first
    command that exits with a non-zero status.

    Args:
        df: Spark DataFrame to export.
        filename: Name of the output file.
        project_path: Destination folder passed to ``dx upload --path``.
        temp_dir: Scratch directory used for the intermediate files.

    Returns:
        True if every command succeeded, False otherwise. Earlier callers
        that ignore the return value are unaffected.
    """
    spark_out_path = f"{temp_dir}/{filename}"
    merged_path = f"..{temp_dir}/{filename}"

    print(f"Writing {filename}...")
    df.coalesce(1).write.csv(spark_out_path, sep="\t", header=True, mode="overwrite")

    print(f"Uploading {filename}...")

    # Build argv lists directly rather than shlex-splitting an interpolated
    # string. Paths containing spaces or quotes then stay single arguments.
    commands = [
        ["hadoop", "fs", "-getmerge", spark_out_path, merged_path],
        ["dx", "upload", merged_path, "--path", project_path],
    ]

    for argv in commands:
        cmd = " ".join(shlex.quote(arg) for arg in argv)
        print(f"Running: {cmd}")
        result = subprocess.run(argv, capture_output=True, text=True)

        if result.stdout:
            print(f"[STDOUT]\n{result.stdout}")
        if result.stderr:
            print(f"[STDERR]\n{result.stderr}")

        if result.returncode != 0:
            print(f"ERROR: Command failed with return code {result.returncode}")
            return False

    return True


# Export df_final as OUTPUT_FILENAME and upload it to DX_PROJECT_PATH.
upload_df(df_final, OUTPUT_FILENAME, DX_PROJECT_PATH)