# Contrastive Learning for Predicting Cancer Prognosis Using Gene Expression Values

## Sample Model Training

The TrainCL4CaPro.ipynb notebook provides step-by-step instructions for training a CL4CaPro model from scratch and validating its performance.

### Prepare Dataset

Place TCGA-CDR-SupplementalTableS1.xlsx in the same directory as *ExtractData.py*.

#### Extract Data from the TCGA Table

In [None]:
! python ExtractData.py

Then place EBPlusPlusAdjustPANCAN_IlluminaHiSeq_RNASeqV2.geneExp.tsv in the same directory and run the following cells, changing the cancer type to the one you want to process.

#### Define Functions

In [None]:
import pandas as pd

def loadGen():
    """Load the pan-cancer gene-expression table and its gene identifiers.

    Returns:
        A pair ``(Gendata, gene_id_list)``. ``Gendata`` is the full
        tab-separated table as a DataFrame. ``gene_id_list`` holds the values
        of its ``gene_id`` column as a Python list.
    """
    expression_path = 'EBPlusPlusAdjustPANCAN_IlluminaHiSeq_RNASeqV2.geneExp.tsv'
    table = pd.read_csv(expression_path, sep='\t')
    ids = table['gene_id'].tolist()
    return table, ids

def loadCancer(CancerType):
    """Read the clinical table for one cancer type.

    Args:
        CancerType: cancer code used to build the file name
            ``./<CancerType>Data.csv``.

    Returns:
        The CSV contents as a DataFrame.
    """
    csv_path = ''.join(['./', CancerType, 'Data.csv'])
    return pd.read_csv(csv_path)

def AggregateInfo(data):
    """Collect barcode parts, PFI event flag and PFI time for every row.

    Args:
        data: clinical DataFrame with the columns 'bcr_patient_barcode',
            'PFI.1' and 'PFI.time.1'.

    Returns:
        A list of ``[bcr, PFI, PFItime]`` entries, one per row, where:
        - ``bcr`` is the barcode split on '-'.
        - ``PFI`` is collapsed to 0 (value parses to integer 0) or 1 (any
          other integer value, or a value that cannot be converted to int,
          e.g. NaN or a non-numeric string).
        - ``PFItime`` is copied unchanged.
    """
    Agg_Info = []
    for _, row in data.iterrows():
        bcr = row['bcr_patient_barcode'].split('-')
        PFI = row['PFI.1']  # alternative endpoint: row['DSS_cr']
        PFItime = row['PFI.time.1']  # alternative endpoint: row['DSS.time.cr']
        try:
            PFI = 1 if int(PFI) != 0 else 0
        except (TypeError, ValueError, OverflowError):
            # Unparseable / missing flags (NaN, None, '#N/A', inf) are kept
            # as events, matching the original best-effort behaviour, but
            # unrelated errors are no longer silently swallowed.
            PFI = 1
        Agg_Info.append([bcr, PFI, PFItime])
    return Agg_Info

def PreproData(Agg_Info, Gendata, CancerType):
    """Attach each patient's expression profile to its survival record.

    For each patient, the columns of ``Gendata`` are scanned for sample
    columns whose first three '-'-separated fields equal the patient's
    barcode parts. The sample code is the first two digits of the fourth
    field; codes 1-9 are treated as tumour samples (per the original
    comments).
    - The first tumour column is used.
    - Non-tumour columns seen before it are reported as normal samples.
    - Further tumour columns are reported as multi-tumour samples and ignored.

    Args:
        Agg_Info: iterable of ``[bcr, PFI, PFItime]`` entries (as built by
            AggregateInfo), where ``bcr`` is the barcode split on '-'.
        Gendata: expression table whose columns are sample barcodes plus
            non-sample columns such as 'gene_id'.
        CancerType: label stored as the fourth element of every output row.

    Returns:
        A list of rows ``[barcode, PFI, PFItime, CancerType, sample_column,
        value_1, ..., value_g]``. Missing or negative expression values are
        replaced by NaN. Patients with no tumour column are printed as
        'UnMatched' and omitted.
    """
    # np is otherwise only imported in a later notebook cell; running the
    # cells in order used to raise NameError on the first bad value.
    import numpy as np

    # Split every column name once instead of once per patient.
    parsed_columns = [(col, col.split('-')) for col in Gendata.columns.values]

    PreprocessedData = []
    for bcr, PFI, PFItime in Agg_Info:
        Agg_Gen_Info = ['-'.join(bcr), PFI, PFItime, CancerType]
        matched = False
        for col, col_info in parsed_columns:
            # Match the barcode; columns without a sample field cannot be typed.
            if col_info[0:3] != bcr[0:3] or len(col_info) < 4:
                continue
            number = int(col_info[3][0:2])
            is_tumor = 0 < number < 10
            if matched:
                if is_tumor:
                    print('Multi-tumor samples: ', col)
                continue
            if not is_tumor:
                print('Find normal sample: ', col)
                continue
            Agg_Gen_Info.append(col)
            matched = True
            for item in Gendata[col].tolist():
                # read_csv yields float NaN (not the string 'NaN') for gaps.
                if pd.isna(item) or item == 'NaN' or float(item) < 0:
                    Agg_Gen_Info.append(np.nan)
                    print('Find None or Negative')
                else:
                    Agg_Gen_Info.append(item)
        if matched:
            PreprocessedData.append(Agg_Gen_Info)
        else:
            print('UnMatched:', Agg_Gen_Info)
    return PreprocessedData

def add_header(gene_id_list):
    """Build the column header for the preprocessed expression table.

    The original implementation inserted into ``gene_id_list`` in place,
    silently mutating the caller's list; a new list is returned instead.

    Note on naming: PreproData emits rows laid out as ``[barcode, PFI,
    PFItime, CancerType, sample_column, genes...]``. The column named
    'gen_id' therefore holds the cancer type and 'type' holds the matched
    sample column name. The names are kept because later cells filter on
    ``gen_id``.

    Args:
        gene_id_list: gene identifiers, in the same order as the expression
            values in each row. Not modified.

    Returns:
        ``['bar', 'PFI', 'PFItime', 'gen_id', 'type']`` followed by the genes.
    """
    return ['bar', 'PFI', 'PFItime', 'gen_id', 'type'] + list(gene_id_list)

def saveDF(PreprocessedData, CancerType, agg_header):
    """Write preprocessed rows to ``./<CancerType>RawGeneData.txt``.

    Args:
        PreprocessedData: list of rows, one value per header column.
        CancerType: prefix for the output file name.
        agg_header: column names for the rows.

    The file is comma-separated and written without the DataFrame index.
    """
    destination = './' + CancerType + 'RawGeneData.txt'
    frame = pd.DataFrame(data=PreprocessedData, columns=agg_header)
    frame.to_csv(destination, index=None)

#### Build Data for given cancer

In [None]:
CancerTypeList = ['LGG']
CancerGroupName = 'LGG'

# Load the pan-cancer expression table once; every cancer type reuses it.
Gendata, gene_id_list = loadGen()
agg_header = add_header(gene_id_list)

TotalData = []
for eachCancer in CancerTypeList:
    print('Processing: ', eachCancer)
    data = loadCancer(eachCancer)
    Agg_Info = AggregateInfo(data)
    PreprocessedData = PreproData(Agg_Info, Gendata, eachCancer)
    # Per-cancer table: ./<cancer>RawGeneData.txt
    saveDF(PreprocessedData, eachCancer, agg_header)
    TotalData.extend(PreprocessedData)
    print('Preprocess Successfully')

# Combined table for the group: ./Total_<group>_RawGeneData.txt
saveDF(TotalData, 'Total_' + CancerGroupName + '_', agg_header)
print('Total Data Saved.')

#### Build WholeTimeSeq Dataset for Cox

In [None]:
import os

import numpy as np


def _time_bin_labels(times, n):
    """Assign each time to one of ``n`` roughly equal-frequency bins.

    Bin edges are the sorted times at positions ``int(len / n * i)`` for
    ``i`` in ``range(n)``; the edges are printed. For a time ``t``, the label
    is the index of the last edge that ``t`` reaches, i.e. a value in
    ``0..n-1`` for ordinary values.

    Args:
        times: list of PFI times (in table order).
        n: number of bins.

    Returns:
        A list of integer labels aligned with ``times``.
    """
    ordered = sorted(times)
    div_point = [ordered[int(len(ordered) / n * i)] for i in range(n)]
    print(div_point)
    labels = []
    for item in times:
        i = 0
        while item >= div_point[i] and i < (n - 1):
            i += 1
        if item >= div_point[i] and i == (n - 1):
            i += 1
        labels.append(i - 1)
    return labels


cancer_group_list = ['LGG']
os.makedirs('DataSet', exist_ok=True)  # to_csv fails if the folder is missing
for cancer_group in cancer_group_list:
    # Read the combined table written by the "Build Data for given cancer"
    # cell (saveDF(TotalData, 'Total_' + group + '_', ...)); the previous
    # name 'TotalData_<group>.txt' is never written by this notebook.
    data_get = pd.read_csv('Total_' + cancer_group + '_RawGeneData.txt')
    data_get.insert(4, 'predicted_label', 0, True)
    pfi_times = data_get['PFItime'].tolist()
    for n in [6, 8, 10, 12]:
        data_get['predicted_label'] = np.array(_time_bin_labels(pfi_times, n))
        data_get.to_csv('DataSet/CancerRNA_' + cancer_group + '_WholeTimeSeq_' + str(n) + '.txt', index=None)

#### Build WholeTimeSeq Dataset for Classifier

In [None]:
import os

import numpy as np

cancer_group = 'LGG'
cancer_type = ['LGG']  # , 'READ', 'STAD']
# Rows with PFItime < threshold get label 0; all other kept rows get label 1.
threshold = 3 * 365  # 3 years, assuming PFItime is in days — confirm
n = 2  # number of classes; only used in the output file name
os.makedirs('DataSet', exist_ok=True)  # to_csv fails if the folder is missing
for cancer_get in cancer_type:
    # Read the combined table written by the "Build Data for given cancer"
    # cell; the previous name 'TotalRawData_<group>.txt' is never written.
    data_get = pd.read_csv('Total_' + cancer_group + '_RawGeneData.txt')
    # PreproData stores the cancer type in the column named 'gen_id'.
    data_get = data_get[data_get.gen_id == cancer_get]
    # Keep rows with an event (PFI == 1) or follow-up beyond the threshold;
    # .copy() so the column assignment below does not touch a slice view.
    data_get = data_get[(data_get.PFI == 1) | (data_get.PFItime > threshold)].copy()
    data_get.insert(4, 'predicted_label', 0, True)
    # Same rule as before: time < threshold -> 0, otherwise (incl. NaN) -> 1.
    data_get['predicted_label'] = np.where(data_get['PFItime'] < threshold, 0, 1)
    data_get.to_csv('DataSet/CancerRNA_' + cancer_get + '_Risk_' + str(n) + '.txt', index=None)

### Train Contrastive Learning Model for Classifier
Before running the cell below, edit line 18 of *Auto_Train_GPU.py* to set the task to 'Task'.

In [None]:
! python Auto_Train_GPU.py

### Train Contrastive Learning Model for Cox
Before running the cell below, edit line 18 of *Auto_Train_GPU.py* to set the task to 'WholeTimeSeq'.

In [None]:
! python Auto_Train_GPU.py

### Test and Validate Classifier CL4CaPro Model

In [None]:
! python Classifier_method.py --cancer_group BRCA --seed {seed} --core 20 > PlotLog/{}.log &

### Test and Validate Cox CL4CaPro Model

In [None]:
! python Cox_methods.py --cancer_group BRCA --seed {seed} --core 20 > PlotLog/{}.log &