# Multi-Task Clinical Prediction Model

## Data Collection & Loading

In [10]:
import pandas as pd

In [11]:
# Create empty list to collect data
all_cycles = []

# Define which files to load
cycles = [
    ('H', '2013-2014'),
    ('I', '2015-2016'),
    ('J', '2017-2018'),
    ('L', '2021-2023')
]

# Loop through each cycle
for letter, year in cycles:
    demo = pd.read_sas(f'data/DEMO_{letter}.xpt')
    bio = pd.read_sas(f'data/BIOPRO_{letter}.xpt')
    data = demo.merge(bio, on='SEQN')
    data['cycle'] = year
    all_cycles.append(data)
    print(f"{year}: {len(data)} patients")

# Stack them all
all_data = pd.concat(all_cycles)
print(f"\nTotal: {len(all_data)} patients")

2013-2014: 6979 patients
2015-2016: 6744 patients
2017-2018: 6401 patients
2021-2023: 7199 patients

Total: 27323 patients


In [12]:
# Show first 50 column names
print(all_data.columns[:50].tolist())

['SEQN', 'SDDSRVYR', 'RIDSTATR', 'RIAGENDR', 'RIDAGEYR', 'RIDAGEMN', 'RIDRETH1', 'RIDRETH3', 'RIDEXMON', 'RIDEXAGM', 'DMQMILIZ', 'DMQADFC', 'DMDBORN4', 'DMDCITZN', 'DMDYRSUS', 'DMDEDUC3', 'DMDEDUC2', 'DMDMARTL', 'RIDEXPRG', 'SIALANG', 'SIAPROXY', 'SIAINTRP', 'FIALANG', 'FIAPROXY', 'FIAINTRP', 'MIALANG', 'MIAPROXY', 'MIAINTRP', 'AIALANGA', 'DMDHHSIZ', 'DMDFMSIZ', 'DMDHHSZA', 'DMDHHSZB', 'DMDHHSZE', 'DMDHRGND', 'DMDHRAGE', 'DMDHRBR4', 'DMDHREDU', 'DMDHRMAR', 'DMDHSEDU', 'WTINT2YR', 'WTMEC2YR', 'SDMVPSU', 'SDMVSTRA', 'INDHHIN2', 'INDFMIN2', 'INDFMPIR', 'LBXSAL', 'LBDSALSI', 'LBXSAPSI']


In [13]:
# Search for ALT column
alt_cols = [col for col in all_data.columns if 'ALT' in col or 'SAT' in col or 'SGPT' in col]
print("Found these ALT-related columns:")
print(alt_cols)

Found these ALT-related columns:
['LBXSATSI', 'LBDSATLC']


In [14]:
# Look at ALT data
print(f"Total patients: {len(all_data)}")
print(f"Patients WITH ALT data: {all_data['LBXSATSI'].notna().sum()}")
print(f"Patients MISSING ALT: {all_data['LBXSATSI'].isna().sum()}")

Total patients: 27323
Patients WITH ALT data: 25030
Patients MISSING ALT: 2293


In [15]:
rename_map = {
    'SEQN': 'Participant_ID',
    'RIAGENDR': 'Gender',
    'RIDAGEYR': 'Age',
    'LBXSATSI': 'ALT_Enzyme_U_L',
    'LBDSATLC': 'ALT_Comment_Code',
    'cycle': 'Survey_Cycle'
}

all_data = all_data.rename(columns=rename_map)
print("Renamed columns:")
print(all_data[list(rename_map.values())].head())

Renamed columns:
   Participant_ID  Gender   Age  ALT_Enzyme_U_L  ALT_Comment_Code Survey_Cycle
0         73557.0     1.0  69.0            16.0               NaN    2013-2014
1         73558.0     1.0  54.0            29.0               NaN    2013-2014
2         73559.0     1.0  72.0            16.0               NaN    2013-2014
3         73561.0     2.0  73.0            28.0               NaN    2013-2014
4         73562.0     1.0  56.0            16.0               NaN    2013-2014


In [16]:
# Save to CSV
all_data.to_csv('data/liver_head_data.csv', index=False)

In [17]:
print("\n--- BASIC EDA ---")

# 1. Basic shape
print(f"Total rows: {all_data.shape[0]}")
print(f"Total columns: {all_data.shape[1]}")

# 2. Quick info
print("\nData types and missing values:")
print(all_data.info())

# 3. Missing data summary (top 10 most missing)
missing_summary = all_data.isna().sum().sort_values(ascending=False).head(10)
print("\nTop 10 columns with most missing values:")
print(missing_summary)

# 4. Descriptive statistics (numeric)
print("\nDescriptive statistics for numeric columns:")
print(all_data.describe().T.head(10))



--- BASIC EDA ---
Total rows: 27323
Total columns: 95

Data types and missing values:
<class 'pandas.core.frame.DataFrame'>
Index: 27323 entries, 0 to 7198
Data columns (total 95 columns):
 #   Column            Non-Null Count  Dtype  
---  ------            --------------  -----  
 0   Participant_ID    27323 non-null  float64
 1   SDDSRVYR          27323 non-null  float64
 2   RIDSTATR          27323 non-null  float64
 3   Gender            27323 non-null  float64
 4   Age               27323 non-null  float64
 5   RIDAGEMN          0 non-null      float64
 6   RIDRETH1          27323 non-null  float64
 7   RIDRETH3          27323 non-null  float64
 8   RIDEXMON          27323 non-null  float64
 9   RIDEXAGM          4895 non-null   float64
 10  DMQMILIZ          24097 non-null  float64
 11  DMQADFC           1557 non-null   float64
 12  DMDBORN4          27323 non-null  float64
 13  DMDCITZN          20115 non-null  float64
 14  DMDYRSUS          5536 non-null   float64
 15  DMDEDU

In [19]:
features_to_use = [
    'Gender', 'Age', 'RIDRETH1',
    'BMXBMI', 'BMXWAIST', 'BMXWT', 'BMXHT',
    'LBXSAPSI', 'LBXGTSI', 'LBXGLU', 'LBXTR', 'LBXTC', 'LBXHDL',
    'SMQ020', 'ALQ101', 'ALQ120Q',
    'LBXSCR', 'URXUMA', 'URXUCR',
    'Survey_Cycle'
]

target = 'ALT_Enzyme_U_L'

# Keep only features that actually exist
existing_features = [f for f in features_to_use if f in all_data.columns]
print("Using these features:", existing_features)

X = all_data[existing_features]
y = all_data['ALT_Enzyme_U_L']

Using these features: ['Gender', 'Age', 'RIDRETH1', 'LBXSAPSI', 'LBXSCR', 'Survey_Cycle']
