# Survival Prediction of Lung Cancer Patients Based on Different Feature Selection Methods Using CNN-Cox Models

## Import the Libraries

In [989]:
pip install lifelines



In [990]:
pip install bioservices



In [991]:
import pandas as pd
import numpy as np
import random
import tensorflow as tf
import requests
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import StratifiedKFold,train_test_split
from lifelines.utils import concordance_index
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense,Dropout,Input,Flatten,concatenate
from bioservices import KEGG

## Prepare the dataset

### Clinical data

#### Load the dataset

In [992]:
clinical_data = pd.read_csv('/content/luad_clinical_patient.txt', sep='\t')

#### Select the relevant columns for analysis

In [993]:
clinical_data = clinical_data[['Patient Identifier', 'Overall Survival Status', 'Overall Survival (Months)']].loc[4:]
clinical_data = clinical_data.rename(columns = {'Overall Survival Status' : 'OS_STATUS', 'Overall Survival (Months)' : 'OS_MONTHS'})

print(clinical_data.head(10))

   Patient Identifier   OS_STATUS OS_MONTHS
4        TCGA-05-4244    0:LIVING         0
5        TCGA-05-4245    0:LIVING     23.98
6        TCGA-05-4249    0:LIVING     50.03
7        TCGA-05-4250  1:DECEASED      3.98
8        TCGA-05-4382    0:LIVING     19.94
9        TCGA-05-4384    0:LIVING     13.99
10       TCGA-05-4389    0:LIVING     44.97
11       TCGA-05-4390    0:LIVING     36.99
12       TCGA-05-4395  1:DECEASED         0
13       TCGA-05-4396  1:DECEASED      9.95


#### Convert OS_MONTHS values into binary format (1 for deceased, 0 for living)

In [994]:
clinical_data["OS_STATUS"] = clinical_data["OS_STATUS"].apply(lambda x : 1 if x.find('1:') != -1 else 0)

print(clinical_data.head(10))

   Patient Identifier  OS_STATUS OS_MONTHS
4        TCGA-05-4244          0         0
5        TCGA-05-4245          0     23.98
6        TCGA-05-4249          0     50.03
7        TCGA-05-4250          1      3.98
8        TCGA-05-4382          0     19.94
9        TCGA-05-4384          0     13.99
10       TCGA-05-4389          0     44.97
11       TCGA-05-4390          0     36.99
12       TCGA-05-4395          1         0
13       TCGA-05-4396          1      9.95


#### Check for unnecessary data

In [995]:
print(clinical_data['OS_STATUS'].value_counts())
print(clinical_data['OS_MONTHS'].value_counts().head(10))

OS_STATUS
0    334
1    188
Name: count, dtype: int64
OS_MONTHS
[Not Available]    9
25.99              4
0                  4
15.64              3
14.29              3
12.65              3
18.66              3
20.04              3
29.43              3
18.99              3
Name: count, dtype: int64


#### Drop rows where OS_MONTHS is 0 or '[Not Available]'

In [996]:
clinical_data = clinical_data[
    (clinical_data['OS_MONTHS'] != '0') &
    (clinical_data['OS_MONTHS'] != '[Not Available]')
]

print(clinical_data.head(10))

   Patient Identifier  OS_STATUS OS_MONTHS
5        TCGA-05-4245          0     23.98
6        TCGA-05-4249          0     50.03
7        TCGA-05-4250          1      3.98
8        TCGA-05-4382          0     19.94
9        TCGA-05-4384          0     13.99
10       TCGA-05-4389          0     44.97
11       TCGA-05-4390          0     36.99
13       TCGA-05-4396          1      9.95
14       TCGA-05-4397          1     24.01
15       TCGA-05-4398          0     47.01


#### Reset the index to start from 0

In [997]:
clinical_data = clinical_data.reset_index(drop=True)

print(clinical_data.head(10))

  Patient Identifier  OS_STATUS OS_MONTHS
0       TCGA-05-4245          0     23.98
1       TCGA-05-4249          0     50.03
2       TCGA-05-4250          1      3.98
3       TCGA-05-4382          0     19.94
4       TCGA-05-4384          0     13.99
5       TCGA-05-4389          0     44.97
6       TCGA-05-4390          0     36.99
7       TCGA-05-4396          1      9.95
8       TCGA-05-4397          1     24.01
9       TCGA-05-4398          0     47.01


### Gene data

#### Load the dataset

In [998]:
gene_data = pd.read_csv('/content/luad_gene_data.txt', sep='\t').T

print(gene_data)

                        0          1          2        3       4       5      \
Hugo_Symbol      LOC100130426   UBE2Q2P3   UBE2Q2P3  HMGB1P1  TIMM23   MOXD2   
Entrez_Gene_Id      100130426  100133144  100134869    10357   10431  136542   
TCGA-05-4244-01       -2.2883      0.038     0.0691  -1.9057 -0.0395     NaN   
TCGA-05-4249-01       -2.2883    -0.3514     0.1971   -0.295  0.1945     NaN   
TCGA-05-4250-01       -2.2883    -0.3435    -0.7239  -1.9091  0.7761     NaN   
...                       ...        ...        ...      ...     ...     ...   
TCGA-NJ-A55O-01       -2.2883     0.5729     1.0176  -0.0218  0.0408     NaN   
TCGA-NJ-A55R-01       -2.2883    -0.1679    -0.0462  -0.8099 -0.3206     NaN   
TCGA-NJ-A7XG-01       -2.2883     1.8645     2.7613  -0.4522 -0.6611     NaN   
TCGA-O1-A52J-01       -2.2883     0.4532      1.087  -1.3473  0.7679     NaN   
TCGA-S2-AA1A-01       -2.2883     0.9225    -0.2293  -0.7319  -1.161     NaN   

                     6         7       

#### Change the column name as gene symbol

In [999]:
gene_data.columns = gene_data.iloc[0]
gene_data = gene_data.drop(['Hugo_Symbol', 'Entrez_Gene_Id'])
gene_data = gene_data.astype(float)

print(gene_data)

Hugo_Symbol      LOC100130426  UBE2Q2P3  UBE2Q2P3  HMGB1P1  TIMM23  MOXD2  \
TCGA-05-4244-01       -2.2883    0.0380    0.0691  -1.9057 -0.0395    NaN   
TCGA-05-4249-01       -2.2883   -0.3514    0.1971  -0.2950  0.1945    NaN   
TCGA-05-4250-01       -2.2883   -0.3435   -0.7239  -1.9091  0.7761    NaN   
TCGA-05-4382-01       -2.2883    0.1873   -0.4402  -0.5333 -0.1787    NaN   
TCGA-05-4384-01       -2.2883   -1.2251   -1.3555  -0.8895 -1.1778    NaN   
...                       ...       ...       ...      ...     ...    ...   
TCGA-NJ-A55O-01       -2.2883    0.5729    1.0176  -0.0218  0.0408    NaN   
TCGA-NJ-A55R-01       -2.2883   -0.1679   -0.0462  -0.8099 -0.3206    NaN   
TCGA-NJ-A7XG-01       -2.2883    1.8645    2.7613  -0.4522 -0.6611    NaN   
TCGA-O1-A52J-01       -2.2883    0.4532    1.0870  -1.3473  0.7679    NaN   
TCGA-S2-AA1A-01       -2.2883    0.9225   -0.2293  -0.7319 -1.1610    NaN   

Hugo_Symbol      LOC155060  RNU12-2P    SSX9  LOC317712  ...    ZXDA    ZXD

#### Drop columns with any NaN values

In [1000]:
gene_data = gene_data.dropna(axis=1)

print(gene_data)

Hugo_Symbol      LOC100130426  UBE2Q2P3  UBE2Q2P3  HMGB1P1  TIMM23  LOC155060  \
TCGA-05-4244-01       -2.2883    0.0380    0.0691  -1.9057 -0.0395     1.0624   
TCGA-05-4249-01       -2.2883   -0.3514    0.1971  -0.2950  0.1945    -0.0690   
TCGA-05-4250-01       -2.2883   -0.3435   -0.7239  -1.9091  0.7761    -1.4074   
TCGA-05-4382-01       -2.2883    0.1873   -0.4402  -0.5333 -0.1787     0.5870   
TCGA-05-4384-01       -2.2883   -1.2251   -1.3555  -0.8895 -1.1778     0.7614   
...                       ...       ...       ...      ...     ...        ...   
TCGA-NJ-A55O-01       -2.2883    0.5729    1.0176  -0.0218  0.0408     0.6719   
TCGA-NJ-A55R-01       -2.2883   -0.1679   -0.0462  -0.8099 -0.3206     1.1540   
TCGA-NJ-A7XG-01       -2.2883    1.8645    2.7613  -0.4522 -0.6611     1.0220   
TCGA-O1-A52J-01       -2.2883    0.4532    1.0870  -1.3473  0.7679     0.1428   
TCGA-S2-AA1A-01       -2.2883    0.9225   -0.2293  -0.7319 -1.1610     1.1789   

Hugo_Symbol      RNU12-2P  

#### Remove the last '-01' part in index

In [1001]:
gene_data.index = gene_data.index.astype(str)
gene_data.index = gene_data.index.str.replace(r'-\d+$', '', regex=True)

print(gene_data)

Hugo_Symbol   LOC100130426  UBE2Q2P3  UBE2Q2P3  HMGB1P1  TIMM23  LOC155060  \
TCGA-05-4244       -2.2883    0.0380    0.0691  -1.9057 -0.0395     1.0624   
TCGA-05-4249       -2.2883   -0.3514    0.1971  -0.2950  0.1945    -0.0690   
TCGA-05-4250       -2.2883   -0.3435   -0.7239  -1.9091  0.7761    -1.4074   
TCGA-05-4382       -2.2883    0.1873   -0.4402  -0.5333 -0.1787     0.5870   
TCGA-05-4384       -2.2883   -1.2251   -1.3555  -0.8895 -1.1778     0.7614   
...                    ...       ...       ...      ...     ...        ...   
TCGA-NJ-A55O       -2.2883    0.5729    1.0176  -0.0218  0.0408     0.6719   
TCGA-NJ-A55R       -2.2883   -0.1679   -0.0462  -0.8099 -0.3206     1.1540   
TCGA-NJ-A7XG       -2.2883    1.8645    2.7613  -0.4522 -0.6611     1.0220   
TCGA-O1-A52J       -2.2883    0.4532    1.0870  -1.3473  0.7679     0.1428   
TCGA-S2-AA1A       -2.2883    0.9225   -0.2293  -0.7319 -1.1610     1.1789   

Hugo_Symbol   RNU12-2P    SSX9   EZHIP  EFCAB8  ...    ZXDA    

#### Check and control the duplicated genes

In [1002]:
duplicated_genes = gene_data.columns[gene_data.columns.duplicated()]

print(duplicated_genes)
print(len(duplicated_genes))

Index(['UBE2Q2P3', 'CC2D2B', 'CCDC7', 'CYorf15B', 'C1orf84', 'LINC00875',
       'ELMOD1', 'NBPF16', 'NEBL', 'NKAIN3', 'C5orf23', 'PALM2AKAP2',
       'PLEKHG7', 'QSOX1', 'SH3D20', 'SNAP47', 'NCRNA00185'],
      dtype='object', name='Hugo_Symbol')
17


In [1003]:
for gene in duplicated_genes:
  var_gene = np.var(gene_data[gene])

  if var_gene.values[0] < var_gene.values[1]:
    gene_data[gene] = gene_data[gene].iloc[:, 1]
  else:
    gene_data[gene] = gene_data[gene].iloc[:, 0]

gene_data = gene_data.iloc[:, ~gene_data.columns.duplicated()]
print(gene_data)

  return var(axis=axis, dtype=dtype, out=out, ddof=ddof, **kwargs)


Hugo_Symbol   LOC100130426  UBE2Q2P3  HMGB1P1  TIMM23  LOC155060  RNU12-2P  \
TCGA-05-4244       -2.2883    0.0380  -1.9057 -0.0395     1.0624    0.5387   
TCGA-05-4249       -2.2883   -0.3514  -0.2950  0.1945    -0.0690    1.4599   
TCGA-05-4250       -2.2883   -0.3435  -1.9091  0.7761    -1.4074   -2.1796   
TCGA-05-4382       -2.2883    0.1873  -0.5333 -0.1787     0.5870   -0.6958   
TCGA-05-4384       -2.2883   -1.2251  -0.8895 -1.1778     0.7614   -0.3706   
...                    ...       ...      ...     ...        ...       ...   
TCGA-NJ-A55O       -2.2883    0.5729  -0.0218  0.0408     0.6719    0.7543   
TCGA-NJ-A55R       -2.2883   -0.1679  -0.8099 -0.3206     1.1540   -0.1200   
TCGA-NJ-A7XG       -2.2883    1.8645  -0.4522 -0.6611     1.0220   -2.1796   
TCGA-O1-A52J       -2.2883    0.4532  -1.3473  0.7679     0.1428    2.9555   
TCGA-S2-AA1A       -2.2883    0.9225  -0.7319 -1.1610     1.1789   -0.2677   

Hugo_Symbol     SSX9   EZHIP  EFCAB8  SRP14P1  ...    ZXDA    Z

## Feature Selection

### 1. Variance-based Univariate Selection(Top 500)

#### Calculate variance by column

In [1004]:
variances = gene_data.var(axis=0)
print(variances)

Hugo_Symbol
LOC100130426    0.245146
UBE2Q2P3        1.248637
HMGB1P1         1.001940
TIMM23          1.001939
LOC155060       1.001941
                  ...   
ZYX             1.001936
FLJ10821        1.001935
ZZZ3            1.001940
TPTEP1          1.001939
AKR1C6P         0.416599
Length: 20096, dtype: float64


In [1005]:
top_100_variance = variances.sort_values(ascending=False).head(100)
print(top_100_variance)

Hugo_Symbol
OR6K2        57.466105
OR4C45       19.551354
KRTAP24-1    15.872888
OR7C3         5.544010
ZNHIT2        4.114460
               ...    
FLJ32662      1.539155
MLANA         1.537218
DPPA2P3       1.536937
SUMO4         1.536818
PRNT          1.536028
Length: 100, dtype: float64


#### Drop other columns

In [1006]:
gene_data_100 = gene_data[top_100_variance.index]
print(gene_data_100)

Hugo_Symbol      OR6K2   OR4C45  KRTAP24-1    OR7C3  ZNHIT2   DUX4L2  \
TCGA-05-4244 -121.9952 -71.1536   -64.1101 -30.9537 -0.0887 -31.5977   
TCGA-05-4249 -121.9952 -71.1536   -64.1101 -30.9537 -0.7377 -31.5977   
TCGA-05-4250 -121.9952 -71.1536   -64.1101 -30.9537 -0.1711 -31.5977   
TCGA-05-4382 -121.9952 -71.1536   -64.1101 -30.9537 -0.6707 -31.5977   
TCGA-05-4384 -121.9952 -71.1536   -64.1101 -30.9537  0.0768 -31.5977   
...                ...      ...        ...      ...     ...      ...   
TCGA-NJ-A55O -121.9952 -71.1536   -64.1101 -30.9537  0.2712 -31.5977   
TCGA-NJ-A55R -121.9952 -71.1536   -64.1101 -30.9537  0.1254 -31.5977   
TCGA-NJ-A7XG -121.9952 -71.1536   -64.1101 -30.9537  0.8624 -31.5977   
TCGA-O1-A52J -121.9952 -71.1536   -64.1101 -30.9537  0.4563 -31.5977   
TCGA-S2-AA1A -121.9952 -71.1536   -64.1101 -30.9537  0.5036 -31.5977   

Hugo_Symbol   TMEM189-UBE2V1  KRTAP20-2  OR12D3  ZFP91-CNTF  ...   GPR52  \
TCGA-05-4244         -3.5024   -23.5867  -27.29     -3.2558

#### Normalize the gene dataset (Z-score)

In [1007]:
numeric_cols = gene_data_100.columns[1:]
gene_data_100.loc[:, numeric_cols] = StandardScaler().fit_transform(gene_data_100[numeric_cols])
print(gene_data_100)

Hugo_Symbol      OR6K2    OR4C45  KRTAP24-1     OR7C3    ZNHIT2    DUX4L2  \
TCGA-05-4244 -121.9952 -0.062312   -0.06231 -0.076357  0.211212 -0.062286   
TCGA-05-4249 -121.9952 -0.062312   -0.06231 -0.076357 -0.109052 -0.062286   
TCGA-05-4250 -121.9952 -0.062312   -0.06231 -0.076357  0.170550 -0.062286   
TCGA-05-4382 -121.9952 -0.062312   -0.06231 -0.076357 -0.075990 -0.062286   
TCGA-05-4384 -121.9952 -0.062312   -0.06231 -0.076357  0.292882 -0.062286   
...                ...       ...        ...       ...       ...       ...   
TCGA-NJ-A55O -121.9952 -0.062312   -0.06231 -0.076357  0.388813 -0.062286   
TCGA-NJ-A55R -121.9952 -0.062312   -0.06231 -0.076357  0.316865 -0.062286   
TCGA-NJ-A7XG -121.9952 -0.062312   -0.06231 -0.076357  0.680555 -0.062286   
TCGA-O1-A52J -121.9952 -0.062312   -0.06231 -0.076357  0.480155 -0.062286   
TCGA-S2-AA1A -121.9952 -0.062312   -0.06231 -0.076357  0.503497 -0.062286   

Hugo_Symbol   TMEM189-UBE2V1  KRTAP20-2    OR12D3  ZFP91-CNTF  ...     GPR5

### 2. KEGG pathway

### 3. Reactome pathway

### 4. Cascaded Wx

## Merge the clinical data and the gene data

In [1008]:
merged_data = pd.merge(gene_data_100,
                       clinical_data,
                       left_on=gene_data_100.index,
                       right_on='Patient Identifier',
                       how='inner')

print(merged_data)

        OR6K2    OR4C45  KRTAP24-1     OR7C3    ZNHIT2    DUX4L2  \
0   -121.9952 -0.062312   -0.06231 -0.076357 -0.109052 -0.062286   
1   -121.9952 -0.062312   -0.06231 -0.076357  0.170550 -0.062286   
2   -121.9952 -0.062312   -0.06231 -0.076357 -0.075990 -0.062286   
3   -121.9952 -0.062312   -0.06231 -0.076357  0.292882 -0.062286   
4   -121.9952 -0.062312   -0.06231 -0.076357 -0.669737 -0.062286   
..        ...       ...        ...       ...       ...       ...   
499 -121.9952 -0.062312   -0.06231 -0.076357  0.388813 -0.062286   
500 -121.9952 -0.062312   -0.06231 -0.076357  0.316865 -0.062286   
501 -121.9952 -0.062312   -0.06231 -0.076357  0.680555 -0.062286   
502 -121.9952 -0.062312   -0.06231 -0.076357  0.480155 -0.062286   
503 -121.9952 -0.062312   -0.06231 -0.076357  0.503497 -0.062286   

     TMEM189-UBE2V1  KRTAP20-2    OR12D3  ZFP91-CNTF  ...    BTBD18   C9orf50  \
0         -0.737549  -0.076329 -0.062276    0.913967  ... -0.401240 -0.581424   
1         -0.737549  

## Model(CNN-Cox)

In [1009]:
def traindcnncoxmodel(merged_data, cancer_name, conv1, conv1_size, dense, input_shape, save_path, le, wi):


    # 생존 상태 및 생존 시간
    E = np.array(merged_data["OS_STATUS"])  # 생존 상태 (0: censored, 1: event)
    Y = np.array(merged_data["OS_MONTHS"])  # 생존 시간

    # 유전자 발현 데이터 (특성 선택)
    gene_columns = [col for col in merged_data.columns if col not in ["OS_STATUS", "OS_MONTHS", "Patient Identifier"]]
    X = np.array(merged_data[gene_columns]).astype('float64')
    Y = Y.astype('float64')

    # 표준화
    scaler = StandardScaler().fit(X)
    X = scaler.transform(X)

    score_tst_list = []
    score_dev_list = []

    kf_outer = StratifiedKFold(n_splits=5, random_state=1, shuffle=True)

    for i, (outer_train_idx, outer_test_idx) in enumerate(kf_outer.split(X, E)):
        fold = i + 1
        ci_tst_list = []
        ci_dev_list = []

        x_tst, c_tst, s_tst = X[outer_test_idx], E[outer_test_idx], Y[outer_test_idx]
        x_trn_full, c_trn_full, s_trn_full = X[outer_train_idx], E[outer_train_idx], Y[outer_train_idx]

        x_trn, x_dev, c_trn, c_dev, s_trn, s_dev = train_test_split(
            x_trn_full, c_trn_full, s_trn_full, test_size=0.1, stratify=c_trn_full, random_state=1
        )

        # 데이터 정렬
        sort_idx = np.argsort(s_trn)[::-1]
        x_trn, s_trn, c_trn = x_trn[sort_idx], s_trn[sort_idx], c_trn[sort_idx]

        # 데이터 reshape
        x_trn = x_trn.reshape(-1, le, wi, 1)
        x_dev = x_dev.reshape(-1, le, wi, 1)
        x_tst = x_tst.reshape(-1, le, wi, 1)

        # 모델 경로 설정
        modelpath = save_path + f"{cancer_name}_fold_{fold}_repeat_{i+1}_{le*wi}.weights.h5"
        checkpoint = MyCallback(modelpath, (x_trn, c_trn, s_trn, x_dev, c_dev, s_dev), fold=fold)

        # 모델 정의 및 컴파일
        model = dcnncox(conv1, conv1_size, dense, input_shape=(le, wi, 1))
        model.compile(loss="mean_squared_error", optimizer="adam")

        # 모델 훈련
        print(f"\n==================== Training Fold {fold} ====================")
        history = model.fit(x_trn, s_trn, batch_size=len(x_trn), epochs=10000, verbose=0, callbacks=[checkpoint], shuffle=False)

        # 모델 가중치 불러오기
        model.load_weights(modelpath)

        # 예측 및 C-index 계산 (Validation)
        hr_pred_dev = np.exp(model.predict(x_dev, batch_size=1, verbose=0))
        ci_dev = concordance_index(s_dev, -hr_pred_dev, c_dev)

        # 예측 및 C-index 계산 (Test)
        hr_pred_tst = np.exp(model.predict(x_tst, batch_size=1, verbose=0))
        ci_tst = concordance_index(s_tst, -hr_pred_tst, c_tst)

        ci_dev_list.append(ci_dev)
        ci_tst_list.append(ci_tst)

        score_dev_list.append(ci_dev_list)
        score_tst_list.append(ci_tst_list)

        print(f'C-index (fold {fold}): {np.mean(ci_tst_list):.4f} +/- {np.std(ci_tst_list):.4f}')

    print(f'\n{cancer_name} - Mean C-index: {np.mean(score_tst_list):.4f} +/- {np.std(score_tst_list):.4f}')

    return score_tst_list


In [1010]:
def dcnncox(conv1=128, conv1_size=(1, 10), dense=64, input_shape=(10, 10, 1)):
    input_img = Input(input_shape)

    tower = Conv2D(conv1, conv1_size, activation='relu')(input_img)
    tower1 = MaxPooling2D(1, 2)(tower)
    tower2 = Flatten()(tower1)

    out = Dense(dense, activation='relu')(tower2)
    last_layer = Dense(1, kernel_initializer='zeros', bias_initializer='zeros')(out)

    model = Model(inputs=[input_img], outputs=last_layer)
    return model

In [1011]:
#loss
def nll(E, NUM_E):
    def loss(y_true, y_pred):
        hazard_ratio = tf.squeeze(tf.exp(y_pred))
        log_risk = tf.math.log(tf.math.cumsum(hazard_ratio))
        uncensored_likelihood = tf.subtract(tf.squeeze(y_pred),log_risk)
        censored_likelihood = uncensored_likelihood * E
        neg_likelihood = -tf.reduce_sum(censored_likelihood) / NUM_E
        return neg_likelihood

    return loss

In [1012]:
def avgcindex(Cindex,cancer_types,numbers):
    cisum=[]
    for i in range(7):
        cancer_name = cancer_types[i]
        number = numbers[i]
        print(cancer_name,np.mean(Cindex[i]))
        cisum.append(np.mean(Cindex[i])*number)

    avgci= np.array(cisum).sum()/np.array(numbers).sum()
    return avgci

In [1013]:
def setup_seed(seed):
    random.seed(seed)  # 为python set random seed
    np.random.seed(seed)  # 为numpy set random seed
    tf.random.set_seed(seed)  # tf cpu fix seed

In [1014]:
class MyCallback(ModelCheckpoint):
    def __init__(self, filepath, data, fold=None, params=None, real_save=True, patience=20, max_epoch=1000):
        super(MyCallback, self).__init__(filepath, save_weights_only=True)
        self.patience = patience
        self.max_epoch = max_epoch
        self.fold = fold
        self.params = params if params else {}

        self.x_trn, self.c_trn, self.s_trn, self.x_dev, self.c_dev, self.s_dev = data

        self.best_cindex = 0
        self.best_epoch = 0
        self.real_save = real_save

    def print_status(self):
        print("\n==================== Fold %d Training Result ====================" % self.fold)
        print(f"- Best Epoch: {self.best_epoch}")
        print(f"- Best C-index: {self.best_cindex:.4f}")
        print("===============================================================")

    def on_train_begin(self, logs=None):
        print(f"\n===== Start Training Fold {self.fold} =====")
        print(f"Train samples: {len(self.x_trn)}, Dev samples: {len(self.x_dev)}")
        print(f"Hyperparameters: {self.params}")
        print("===================================================")

    def on_epoch_end(self, epoch, logs=None):
        pred_dev = -np.exp(self.model.predict(self.x_dev, batch_size=1, verbose=0))
        cindex_dev = concordance_index(self.s_dev, pred_dev, self.c_dev)

        if cindex_dev > self.best_cindex:
            self.best_cindex = cindex_dev
            self.best_epoch = epoch
            if self.real_save:
                if self.save_weights_only:
                    self.model.save_weights(self.filepath, overwrite=True)
                else:
                    self.model.save(self.filepath, overwrite=True)
        elif epoch - self.best_epoch > self.patience:
            self.model.stop_training = True
            print(f"Early stopping at epoch {epoch} (no improvement for {self.patience} epochs)")

        if epoch >= self.max_epoch:
            self.model.stop_training = True
            print(f"Stopping at max epoch {self.max_epoch}")

    def on_train_end(self, logs=None):
        print("[Training Ended]")
        self.print_status()


In [1015]:
# 예시로 LUAD 사용
cancer_name = 'LUAD'

# 모델 파라미터 설정
conv1 = 64
dense = 16
conv1_size = (1, 10)
input_shape = (10, 10, 1)
le, wi = 10, 10

save_path="/content/"

# 모델 훈련 및 C-index 계산
score_tst_list = traindcnncoxmodel(merged_data, cancer_name, conv1, conv1_size, dense, input_shape, save_path, le, wi)

# C-index 출력
print(f"Test C-index for {cancer_name}: {np.mean(score_tst_list):.4f} +/- {np.std(score_tst_list):.4f}")



===== Start Training Fold 1 =====
Train samples: 362, Dev samples: 41
Hyperparameters: {'verbose': 0, 'epochs': 10000, 'steps': 1}


Expected: ['keras_tensor_354']
Received: inputs=Tensor(shape=(362, 10, 10, 1))
Expected: ['keras_tensor_354']
Received: inputs=Tensor(shape=(1, 10, 10, 1))


Early stopping at epoch 53 (no improvement for 20 epochs)
[Training Ended]

- Best Epoch: 32
- Best C-index: 0.4880
C-index (fold 1): 0.5303 +/- 0.0000


===== Start Training Fold 2 =====
Train samples: 362, Dev samples: 41
Hyperparameters: {'verbose': 0, 'epochs': 10000, 'steps': 1}


Expected: ['keras_tensor_360']
Received: inputs=Tensor(shape=(362, 10, 10, 1))
Expected: ['keras_tensor_360']
Received: inputs=Tensor(shape=(1, 10, 10, 1))


Early stopping at epoch 44 (no improvement for 20 epochs)
[Training Ended]

- Best Epoch: 23
- Best C-index: 0.6006
C-index (fold 2): 0.4993 +/- 0.0000


===== Start Training Fold 3 =====
Train samples: 362, Dev samples: 41
Hyperparameters: {'verbose': 0, 'epochs': 10000, 'steps': 1}


Expected: ['keras_tensor_366']
Received: inputs=Tensor(shape=(362, 10, 10, 1))
Expected: ['keras_tensor_366']
Received: inputs=Tensor(shape=(1, 10, 10, 1))


Early stopping at epoch 24 (no improvement for 20 epochs)
[Training Ended]

- Best Epoch: 3
- Best C-index: 0.6052
C-index (fold 3): 0.3863 +/- 0.0000


===== Start Training Fold 4 =====
Train samples: 362, Dev samples: 41
Hyperparameters: {'verbose': 0, 'epochs': 10000, 'steps': 1}


Expected: ['keras_tensor_372']
Received: inputs=Tensor(shape=(362, 10, 10, 1))
Expected: ['keras_tensor_372']
Received: inputs=Tensor(shape=(1, 10, 10, 1))


Early stopping at epoch 82 (no improvement for 20 epochs)
[Training Ended]

- Best Epoch: 61
- Best C-index: 0.6554
C-index (fold 4): 0.4899 +/- 0.0000


===== Start Training Fold 5 =====
Train samples: 363, Dev samples: 41
Hyperparameters: {'verbose': 0, 'epochs': 10000, 'steps': 1}


Expected: ['keras_tensor_378']
Received: inputs=Tensor(shape=(363, 10, 10, 1))
Expected: ['keras_tensor_378']
Received: inputs=Tensor(shape=(1, 10, 10, 1))


Early stopping at epoch 21 (no improvement for 20 epochs)
[Training Ended]

- Best Epoch: 0
- Best C-index: 0.6309
C-index (fold 5): 0.5410 +/- 0.0000

LUAD - Mean C-index: 0.4894 +/- 0.0549
Test C-index for LUAD: 0.4894 +/- 0.0549


### Result