<a href="https://colab.research.google.com/github/JinLeeGG/Survival-Prediction-Model-for-AML-using-Gene-Expression-Data-from-TCGA/blob/main/3.%20Machine_Learning_Modeling.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Building a Machine Learning Model to Predict Patient Survival Status from RNA Expression Data

- This project develops a machine learning model that predicts a patient's survival status (Alive/Deceased) from high-dimensional RNA-seq gene expression data.

## Data Preparation and Integration:
- The data was prepared during preprocessing and has two parts:

  - Clinical Data: Contains patient-specific information, including survival status (Status) and observation period (Observation Period).

  - RNA Expression Data: A large matrix containing expression levels for over 20,000 genes for each patient.

- These datasets were cleaned and merged into a single DataFrame (merged_df), using the patient barcode (bcr_patient_barcode) as the common identifier.
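
The merge step described above can be sketched roughly as follows. The two small frames and their values are made up for illustration; only the column names bcr_patient_barcode, Status, and Observation Period come from this notebook, and the actual preprocessing may differ (for example in the join type).

```python
import pandas as pd

# Toy clinical table (values are illustrative, not real TCGA data)
clinical_df = pd.DataFrame({
    "bcr_patient_barcode": ["P1", "P2", "P3"],
    "Status": [1, 0, 1],
    "Observation Period": [792.0, 2861.0, None],
})

# Toy expression matrix: one row per patient, one column per gene
expression_df = pd.DataFrame({
    "bcr_patient_barcode": ["P1", "P2", "P4"],
    "GENE_A|1": [10.5, 3.2, 7.7],
    "GENE_B|2": [0.0, 1.0, 2.0],
})

# An inner join (an assumption here) keeps only patients present in both tables
merged_demo = clinical_df.merge(expression_df, on="bcr_patient_barcode", how="inner")
print(merged_demo.shape)
```

Patients that appear in only one table (P3 and P4 in this toy example) are dropped by the inner join.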

## Model selection:
- Logistic Regression
- Random Forest




In [20]:
# import libraries
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Load Dataset

In [21]:
merged_df = pd.read_csv('/content/drive/MyDrive/Acute Myeloid Leukemia (TCGA, PanCancer Atlas)/Processed_Data/merged_df.csv')
merged_df.dropna(subset=['Observation Period'], inplace=True)
merged_df

Unnamed: 0,bcr_patient_barcode,Status,Observation Period,A1BG-AS|503538,A1BG|1,A1CF|29974,A2LD1|87769,A2ML1|144568,A2M|2,A4GALT|53947,...,ZWINT|11130,ZXDA|7789,ZXDB|158586,ZXDC|79364,ZYG11B|79699,ZYX|7791,ZZEF1|23140,ZZZ3|26009,psiTPTE22|387590,tAKR|389932
0,TCGA-AB-2803,1,792.0,792.14,1139.18,0.00,194.50,24.36,982.14,24.98,...,555.04,67.00,795.76,3093.76,1114.18,9613.40,5332.46,2452.22,33.00,7.78
1,TCGA-AB-2805,1,576.0,429.64,403.44,0.00,227.10,33.66,193.26,5.00,...,1360.22,41.88,912.38,5481.82,3834.64,18642.30,12197.30,3494.92,19.22,26.34
2,TCGA-AB-2806,1,944.0,891.18,1004.70,0.00,179.84,45.82,129.92,55.72,...,1623.44,231.76,2251.04,6184.50,1696.52,18565.60,12208.00,6213.06,332.12,1.00
3,TCGA-AB-2807,1,180.0,1095.44,1121.68,1.00,111.06,11.08,884.28,272.40,...,1575.48,283.66,1559.34,2978.60,1990.44,7733.44,9364.42,2986.18,51.88,22.28
4,TCGA-AB-2808,0,2861.0,570.74,531.26,0.00,123.08,21.64,757.42,537.62,...,2168.70,106.86,1111.84,3922.22,2723.36,10197.40,8040.82,3697.18,47.70,5.92
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
174,TCGA-AB-3007,0,1581.0,1561.60,1503.12,0.00,321.06,18.34,201.94,776.14,...,1480.26,288.32,1284.20,6217.06,1926.22,6308.06,12576.40,3929.76,30.68,24.38
175,TCGA-AB-3008,1,822.0,1052.62,824.12,0.00,113.20,73.24,2400.86,520.90,...,3793.34,349.88,1349.90,5590.80,2401.34,16703.00,12810.40,3548.82,155.78,0.00
176,TCGA-AB-3009,1,576.0,489.64,514.78,0.00,506.60,229.24,648.56,38.02,...,1133.66,212.18,1078.90,5370.92,2515.52,23951.40,12152.40,3942.60,1493.54,1.00
177,TCGA-AB-3011,0,1885.0,899.50,736.42,0.00,93.70,20.36,162.48,60.20,...,1464.88,87.22,699.92,5071.14,1535.64,9142.90,9567.60,3060.22,47.78,1.04


### Define target, features

In [22]:
# Define the target variable
y = merged_df['Status']

# Define features by dropping the non-gene and target columns
X = merged_df.drop(columns=['bcr_patient_barcode', 'Status', 'Observation Period'])

# Display the shapes
print("Features shape (X):", X.shape)
print("Target shape (y):", y.shape)

Features shape (X): (167, 20319)
Target shape (y): (167,)


### Split the data (train/test)

In [23]:
from sklearn.model_selection import train_test_split

# Split the data, keeping the proportion of classes the same (stratify)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=2025, stratify=y
)
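
To see what stratify does, the sketch below splits synthetic labels of the same size as this dataset and compares class proportions before and after the split. The 60/107 class counts are an assumption chosen for illustration, not values read from merged_df.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Synthetic labels: 167 samples with an assumed 60/107 class imbalance
y_demo = np.array([0] * 60 + [1] * 107)
X_demo = np.arange(len(y_demo)).reshape(-1, 1)

X_tr, X_te, y_tr, y_te = train_test_split(
    X_demo, y_demo, test_size=0.2, random_state=2025, stratify=y_demo
)

# With stratification, all three rates should be close to each other
print("Overall positive rate:", round(y_demo.mean(), 3))
print("Train positive rate:  ", round(y_tr.mean(), 3))
print("Test positive rate:   ", round(y_te.mean(), 3))
print("Test size:", len(y_te))
```

With 167 samples and test_size=0.2, the test set has 34 samples, which matches the support totals in the reports further down.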

# **Logistic Regression Model**

In [24]:
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, accuracy_score

# Initialize and train the model
log_reg = LogisticRegression(random_state=2025, max_iter=1000)
log_reg.fit(X_train, y_train)

# Make predictions
y_pred_lr = log_reg.predict(X_test)

# Evaluate the model
print("--- Logistic Regression Results ---")
print(f"Accuracy: {accuracy_score(y_test, y_pred_lr):.2f}")
print(classification_report(y_test, y_pred_lr))

--- Logistic Regression Results ---
Accuracy: 0.56
              precision    recall  f1-score   support

           0       0.41      0.58      0.48        12
           1       0.71      0.55      0.62        22

    accuracy                           0.56        34
   macro avg       0.56      0.56      0.55        34
weighted avg       0.60      0.56      0.57        34
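
The per-class figures in the report above follow directly from a confusion matrix. The counts in this sketch are inferred from the rounded precision, recall, and support values, so treat them as a reconstruction rather than notebook output.

```python
import numpy as np

# Confusion matrix inferred from the rounded report above (an assumption):
# rows = true class (0, 1), columns = predicted class (0, 1)
cm = np.array([[7, 5],
               [10, 12]])

support = cm.sum(axis=1)      # true samples per class
predicted = cm.sum(axis=0)    # predicted samples per class
correct = np.diag(cm)         # correctly classified samples per class

recall = correct / support
precision = correct / predicted
f1 = 2 * precision * recall / (precision + recall)
accuracy = correct.sum() / cm.sum()

print("precision:", precision.round(2))
print("recall:   ", recall.round(2))
print("f1:       ", f1.round(2))
print("accuracy: ", round(accuracy, 2))
```

Under these counts, 19 of 34 test patients are classified correctly.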



# **Random Forest Model**

In [25]:
from sklearn.ensemble import RandomForestClassifier

# Initialize and train the model
rand_forest = RandomForestClassifier(random_state=2025, n_estimators=100)
rand_forest.fit(X_train, y_train)

# Make predictions
y_pred_rf = rand_forest.predict(X_test)

# Evaluate the model
print("--- Random Forest Results ---")
print(f"Accuracy: {accuracy_score(y_test, y_pred_rf):.2f}")
print(classification_report(y_test, y_pred_rf))

--- Random Forest Results ---
Accuracy: 0.65
              precision    recall  f1-score   support

           0       0.50      0.33      0.40        12
           1       0.69      0.82      0.75        22

    accuracy                           0.65        34
   macro avg       0.60      0.58      0.57        34
weighted avg       0.62      0.65      0.63        34
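
Because the classes are imbalanced, it can help to compare accuracy against a majority-class baseline. The sketch below uses labels with the same test-set class counts as the report above (12 vs 22); the training counts (48 vs 85) are an assumption, and only the majority class matters for this baseline.

```python
import numpy as np
from sklearn.dummy import DummyClassifier
from sklearn.metrics import accuracy_score

# Test labels mirror the supports reported above; training counts are assumed
y_test_demo = np.array([0] * 12 + [1] * 22)
y_train_demo = np.array([0] * 48 + [1] * 85)

# Features are irrelevant to this baseline, so zeros are enough
X_train_demo = np.zeros((len(y_train_demo), 1))
X_test_demo = np.zeros((len(y_test_demo), 1))

# Always predicts the most frequent training class
baseline = DummyClassifier(strategy="most_frequent")
baseline.fit(X_train_demo, y_train_demo)

baseline_acc = accuracy_score(y_test_demo, baseline.predict(X_test_demo))
print(f"Majority-class baseline accuracy: {baseline_acc:.3f}")
```

With 22 of 34 test patients in class 1, this baseline scores 22/34, or about 0.65, so the Random Forest accuracy above is on par with always predicting the majority class.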



## **Filtering Genes with the Log-rank Test**

In [26]:
!pip install lifelines


In [27]:
# Import necessary libraries
from lifelines.statistics import logrank_test
import pandas as pd
from tqdm import tqdm # A library to visualize loop progress

# --- 1. Data Preparation ---
# 'merged_df' is the DataFrame containing both clinical and gene expression data.

# 1. Prepare feature (X) data: Exclude non-gene expression columns from merged_df.
X_full = merged_df.drop(columns=['bcr_patient_barcode', 'Status', 'Observation Period'])

# 2. Specify the DataFrame containing survival information (time and event).
df_survival_final = merged_df

# Initialize a list to store the results.
results = []

# Use tqdm to monitor the progress of the loop over all genes.
for gene in tqdm(X_full.columns, desc="Analyzing all genes (Log-rank)"):
    try:
        # Get the expression data for the current gene
        gene_expression = X_full[gene]

        # Divide patients into high/low expression groups based on the median.
        median_expression = gene_expression.median()

        # If the median is 0 (at least half of the values are 0), every patient
        # would land in the "high" group, so skip this gene and move to the next.
        if median_expression == 0:
            continue

        # Create group filters
        high_group_filter = gene_expression >= median_expression
        low_group_filter = gene_expression < median_expression

        # Perform the log-rank test between the high- and low-expression groups
        result = logrank_test(
            durations_A=df_survival_final.loc[high_group_filter, 'Observation Period'],
            durations_B=df_survival_final.loc[low_group_filter, 'Observation Period'],
            event_observed_A=df_survival_final.loc[high_group_filter, 'Status'],
            event_observed_B=df_survival_final.loc[low_group_filter, 'Status']
        )

        # Append the result (gene name and p-value) to the list.
        results.append({'gene': gene, 'p_value': result.p_value})

    except Exception as e:
        # Handle any exceptions during analysis to prevent the loop from crashing.
        # print(f"Error analyzing {gene}: {e}") # Uncomment to debug specific errors
        continue

# Convert the results list to a DataFrame.
results_df = pd.DataFrame(results)

# Filter for significant results (p-value <= 0.05) and sort by p-value.
significant_genes_logrank = results_df[results_df['p_value'] <= 0.05].sort_values(by='p_value')

# Print the final results
print("\n--- Top 20 Genes Significantly Associated with Survival ---")
print(significant_genes_logrank.head(20))


Analyzing all genes (Log-rank): 100%|██████████| 20319/20319 [06:46<00:00, 49.97it/s]


--- Top 20 Genes Significantly Associated with Survival ---
                   gene       p_value
10432       PARP3|10039  1.985124e-08
5479          FIBP|9158  2.350450e-07
10931      PLA2G4A|5321  1.101993e-06
15238     TOMM40L|84134  1.322534e-06
8331       LPCAT3|10162  1.405197e-06
15357      TREML2|79865  2.198326e-06
8630       MAP4K1|11184  2.424977e-06
2755          CCND3|896  2.529626e-06
1061      ATP13A2|23400  2.575356e-06
12172     RHOBTB2|23221  2.832890e-06
9349      MYBPHL|343263  2.916382e-06
3265         CLCN5|1184  3.481599e-06
4004        DCTN2|10540  3.768060e-06
14449       SYTL4|94121  3.930102e-06
1283        BCKDK|10295  7.396359e-06
4201       DIRC3|729582  8.489778e-06
12489      RPS6KA1|6195  8.825879e-06
12212       RINL|126432  9.030160e-06
10587        PDE3B|5140  9.804775e-06
1502   C10orf128|170371  1.325526e-05





In [28]:
print(len(significant_genes_logrank)) # number of genes retained by the filter (p-value <= 0.05)

2825
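
The filter above applies an unadjusted threshold of 0.05 to one test per gene. With on the order of 20,000 tests, up to roughly 1,000 genes could pass by chance alone even if none were associated with survival. One common adjustment is the Benjamini-Hochberg procedure; the sketch below is a minimal NumPy version run on synthetic p-values, not on results_df, and the function name benjamini_hochberg is hypothetical.

```python
import numpy as np

def benjamini_hochberg(p_values, alpha=0.05):
    """Return a boolean mask of hypotheses rejected at FDR level alpha."""
    p = np.asarray(p_values, dtype=float)
    m = len(p)
    order = np.argsort(p)
    ranked = p[order]
    # Find the largest rank k such that p_(k) <= (k / m) * alpha
    thresholds = alpha * np.arange(1, m + 1) / m
    below = ranked <= thresholds
    reject = np.zeros(m, dtype=bool)
    if below.any():
        k = np.max(np.nonzero(below)[0])
        # Reject every hypothesis ranked at or below k
        reject[order[: k + 1]] = True
    return reject

# Synthetic p-values: 50 strong signals mixed with 1,000 uniform nulls
rng = np.random.default_rng(2025)
p_demo = np.concatenate([rng.uniform(0, 1e-4, 50), rng.uniform(0, 1, 1000)])

raw_hits = int((p_demo <= 0.05).sum())
fdr_hits = int(benjamini_hochberg(p_demo).sum())
print("Raw p <= 0.05:", raw_hits)
print("BH-adjusted:  ", fdr_hits)
```

Applied to this notebook, the same mask could be computed from results_df['p_value'] and would likely keep fewer genes than the raw threshold.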


# Random Forest Model (Log-rank-Selected Genes)

In [29]:
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, accuracy_score

# --- 1. Data Preparation ---

# Get the list of significant gene names
selected_genes = significant_genes_logrank['gene'].tolist()

# Create the feature set (X) using the selected significant genes.
X = merged_df[selected_genes]

# Specify the target variable (y) as the survival status ('Status').
y = merged_df['Status']


# --- 2. Train and Test Data Split ---

# Split the data into training (80%) and testing (20%) sets.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=2025, stratify=y
)


# --- 3. Model Training and Evaluation ---

# Initialize and train the Random Forest model.
rf_model = RandomForestClassifier(n_estimators=100, random_state=2025, class_weight='balanced')
rf_model.fit(X_train, y_train)

# Make predictions on the test data.
y_pred = rf_model.predict(X_test)

# Print the final model performance.
print("--- Random Forest Model Performance (Based on Significant Genes) ---")
print(f"Accuracy: {accuracy_score(y_test, y_pred):.3f}")
print(classification_report(y_test, y_pred))

--- Random Forest Model Performance (Based on Significant Genes) ---
Accuracy: 0.676
              precision    recall  f1-score   support

           0       0.55      0.50      0.52        12
           1       0.74      0.77      0.76        22

    accuracy                           0.68        34
   macro avg       0.64      0.64      0.64        34
weighted avg       0.67      0.68      0.67        34
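
Note that the log-rank filter was computed on all patients in merged_df before the train/test split, so the test patients influenced which genes were selected. This may make the scores above optimistic. One way to avoid that is to fit the selection step on training data only. The sketch below shows the pattern with scikit-learn's SelectKBest and f_classif as a stand-in scorer (it is not the log-rank test) on synthetic noise; k=50 and the data shape are arbitrary assumptions.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_selection import SelectKBest, f_classif
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.pipeline import Pipeline

# Synthetic stand-in for the expression matrix: pure noise, no real signal
rng = np.random.default_rng(2025)
X_demo = rng.normal(size=(167, 2000))
y_demo = np.array([0] * 60 + [1] * 107)

# Selection lives inside the pipeline, so each CV fold selects features
# from its own training portion only
pipe = Pipeline([
    ("select", SelectKBest(score_func=f_classif, k=50)),
    ("model", RandomForestClassifier(n_estimators=100, random_state=2025,
                                     class_weight="balanced")),
])

cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=2025)
scores = cross_val_score(pipe, X_demo, y_demo, cv=cv, scoring="accuracy")
print("Fold accuracies:", scores.round(2))
print("Mean accuracy:  ", round(scores.mean(), 3))
```

Cross-validation is used here instead of a single split because a 34-patient test set gives a noisy estimate; the same pipeline object could also be fitted on X_train and scored on X_test.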



# Logistic Regression Model (Log-rank-Selected Genes)

In [30]:
# Initialize and train the model on the log-rank-selected genes
# (reuses X_train/X_test from the previous cell)
log_reg = LogisticRegression(random_state=2025, max_iter=1000)
log_reg.fit(X_train, y_train)

# Make predictions
y_pred_lr = log_reg.predict(X_test)

# Evaluate the model
print("--- Logistic Regression Results ---")
print(f"Accuracy: {accuracy_score(y_test, y_pred_lr):.2f}")
print(classification_report(y_test, y_pred_lr))

--- Logistic Regression Results ---
Accuracy: 0.47
              precision    recall  f1-score   support

           0       0.33      0.50      0.40        12
           1       0.62      0.45      0.53        22

    accuracy                           0.47        34
   macro avg       0.48      0.48      0.46        34
weighted avg       0.52      0.47      0.48        34
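
Logistic regression is sensitive to feature scale, and the expression values in merged_df range from 0 to tens of thousands. A possible variation, which this notebook did not run, is to standardize features inside a pipeline before fitting. The sketch below uses synthetic skewed data; whether scaling would help on the real data is untested.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Synthetic, strongly skewed features standing in for expression values
rng = np.random.default_rng(2025)
X_demo = rng.lognormal(mean=5.0, sigma=2.0, size=(167, 200))
y_demo = np.array([0] * 60 + [1] * 107)

X_tr, X_te, y_tr, y_te = train_test_split(
    X_demo, y_demo, test_size=0.2, random_state=2025, stratify=y_demo
)

# The scaler is fitted on the training split only, then applied to the test split
scaled_log_reg = make_pipeline(
    StandardScaler(),
    LogisticRegression(random_state=2025, max_iter=1000),
)
scaled_log_reg.fit(X_tr, y_tr)

test_acc = scaled_log_reg.score(X_te, y_te)
print("Test accuracy:", round(test_acc, 3))
```

Since the synthetic labels are unrelated to the features, the accuracy printed here says nothing about the real task; the point is only the pipeline structure.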

