# Breast cancer case

In [1]:
import numpy as np
import pandas as pd
import feyn

from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

from _crossvalidation import crossvalidation_as_framework

### Load and prep data

In [2]:
data = pd.read_csv('../data/brca_data_w_meta.csv')

In [3]:
# Define the target variable
target = "vital.status"

In [4]:
stypes = {}
for f in data.columns:
    if data[f].dtype == 'object':
        stypes[f] = 'c'

### Cross-validation of the framework

In [5]:
results = crossvalidation_as_framework(data,
                                     target,
                                     kind = "classification",
                                     stypes = stypes,
                                     n_epochs = 500,
                                     criterion='bic',
                                     max_complexity = 10)

In [12]:
results.groupby("fold").first().roc_auc_val.mean(), results.roc_auc_val.mean()

(0.6183480214152156, 0.6313666034560834)

In [17]:
for model in results.groupby("fold").first().model_structure:
    print(model)

logreg((cnCLDN19 + muGPR98)*(rsCIDEC*rsPIK3C2G + rsSLC7A4))
logreg((rsALOX15 + rsIGSF1)*(rsALOX15B + rsAPOB + rsTRPV6))
logreg(tanh(rsKLK8*(muPDZD2 + rsTMPRSS4 + exp(-cnSLC13A2**2 - rsLGALS12**2))))
logreg(cnTGFBR3*rsPOF1B + rsFGFBP1*(rsSLC6A15 + tanh(rsCIDEC)))
logreg(tanh(cnABCC6 + rsMAGEA3 + rsSLC28A3*(rsCACNG4 + rsGLYAT)))


In [18]:
results100 = crossvalidation_as_framework(data,
                                     target,
                                     kind = "classification",
                                     stypes = stypes,
                                     n_epochs = 100,
                                     criterion='bic',
                                     max_complexity = 10)

In [20]:
results.groupby("fold").first().roc_auc_val.mean(), results.roc_auc_val.mean()

(0.6183480214152156, 0.6313666034560834)

In [19]:
for model in results100.groupby("fold").first().model_structure:
    print(model)

logreg(rsPTCHD1*(cnADAMTS16 + muDMD + rsFABP4 + rsSOX2))
logreg(rsAPOB*(cnPRODH*(cnCOL14A1 + ppPI3K.p110.alpha) + rsTAT))
logreg(cnCLCA2 + rsSLC13A2*(cnSLC30A8 + rsAPOB + rsPAX7))
logreg(rsTCN1*(ppp27 + rsDEFB132*rsPTPRZ1 + rsSOX2))
logreg(cnTNFRSF11B + ppMSH6**2 + rsCYP4Z2P + rsPCOLCE2 + rsPLA2G2D)
