In [2]:
import xgboost as xgb 
from xgboost import XGBClassifier, plot_importance
from sklearn.model_selection import train_test_split
import warnings
import pandas as pd
warnings.filterwarnings('ignore')

# Read the training and test tables from the working directory.
df = pd.read_csv('train.csv')
# The ID column is removed from the test frame as soon as it is loaded.
df_test = pd.read_csv('test.csv').drop(columns=['ID'])

# Label column, and the feature matrix (everything except ID / SUBCLASS).
df_target = df['SUBCLASS']
data = df.drop(columns=['ID', 'SUBCLASS'])
data

Unnamed: 0,A2M,AAAS,AADAT,AARS1,ABAT,ABCA1,ABCA2,ABCA3,ABCA4,ABCA5,...,ZNF292,ZNF365,ZNF639,ZNF707,ZNFX1,ZNRF4,ZPBP,ZW10,ZWINT,ZYX
0,WT,WT,WT,WT,WT,WT,WT,WT,WT,WT,...,WT,WT,WT,WT,WT,WT,WT,WT,WT,WT
1,WT,WT,WT,WT,WT,WT,WT,WT,WT,WT,...,WT,WT,WT,WT,WT,WT,WT,WT,WT,WT
2,R895R,WT,WT,WT,WT,WT,WT,WT,WT,WT,...,WT,WT,WT,WT,WT,WT,WT,WT,WT,WT
3,WT,WT,WT,WT,WT,WT,WT,WT,WT,WT,...,WT,WT,WT,WT,WT,WT,WT,WT,WT,WT
4,WT,WT,WT,WT,WT,WT,WT,WT,WT,WT,...,WT,WT,WT,WT,WT,WT,WT,WT,WT,WT
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
6196,WT,WT,WT,WT,WT,WT,WT,WT,WT,WT,...,WT,WT,WT,WT,WT,WT,WT,WT,WT,WT
6197,WT,WT,WT,WT,WT,WT,WT,WT,WT,WT,...,WT,WT,WT,WT,WT,WT,WT,WT,WT,WT
6198,WT,WT,WT,WT,WT,WT,WT,WT,WT,WT,...,WT,WT,WT,WT,WT,WT,WT,WT,T181S,WT
6199,WT,WT,WT,WT,WT,WT,WT,WT,WT,WT,...,WT,WT,WT,WT,WT,WT,WT,WT,WT,WT


In [3]:
# Distinct SUBCLASS labels found in the training frame.
disease_list = df['SUBCLASS'].unique()
disease_list

array(['KIPAN', 'SARC', 'SKCM', 'KIRC', 'GBMLGG', 'STES', 'BRCA', 'THCA',
       'LIHC', 'HNSC', 'PAAD', 'OV', 'PRAD', 'UCEC', 'LAML', 'COAD',
       'ACC', 'LGG', 'LUSC', 'LUAD', 'CESC', 'PCPG', 'THYM', 'BLCA',
       'TGCT', 'DLBC'], dtype=object)

1. 먼저 각 질병(암)별로, WT가 아닌(변이가 발생한) 샘플 수가 5개를 초과하는(코드 기준 `count > 5`) 유전자를 1차로 수집한다.
2. 암별로 샘플(transaction) 수와 변이 유전자 수의 불균형이 심하고, 암마다 유의미한 유전자의 개수도 천차만별이다.
따라서 각 암에서 함께 변이되는(공동 변이) 유전자를 추출하기 위해, 변이 개수 간 상관계수의 절댓값이 0.5 이상인 유전자 쌍에 속한 유전자만 다시 선별한다.
3. 이 유전자 집합(SET)을 기반으로 DataFrame을 새로 만들고, 각 질병별로 선정된 유전자 이외의 유전자 값은 WT와 마찬가지로 0으로 처리하여 노이즈를 제거(cleansing)한다.
4. XGBoost, Random Forest 등으로 학습한 뒤 결과를 도출해 본다.


In [4]:
from collections import defaultdict
import numpy as np
def get_gene_dictionary(frame=None, diseases=None, threshold=5):
    """Collect, for every cancer type, the genes that are mutated often enough.

    A gene counts as mutated in a sample when its cell is anything other than
    the string 'WT'. A gene is kept for a cancer type when it is mutated in
    strictly more than ``threshold`` samples of that type.

    Parameters
    ----------
    frame : pandas.DataFrame, optional
        Table holding a 'SUBCLASS' label column, optionally an 'ID' column,
        and one column per gene. Defaults to the notebook-level ``df``.
    diseases : iterable, optional
        SUBCLASS labels to process. Defaults to the notebook-level
        ``disease_list``.
    threshold : int, optional
        A gene is kept when its mutated-sample count is ``> threshold``.
        The default of 5 reproduces the original hard-coded cutoff.

    Returns
    -------
    dict
        Maps each label to a NumPy array of gene (column) names, in the
        column order of ``frame``.
    """
    if frame is None:
        frame = df
    if diseases is None:
        diseases = disease_list

    # Pick the gene columns by name instead of slicing off the first two
    # entries, so the result does not depend on where ID / SUBCLASS sit.
    gene_columns = [c for c in frame.columns if c not in ('ID', 'SUBCLASS')]
    # Loop-invariant: the mutation mask is the same for every cancer type.
    is_mutated = frame[gene_columns].ne('WT')
    labels = frame['SUBCLASS']

    main_genes = {}
    for d in diseases:
        # Number of samples of this cancer type in which each gene is mutated.
        counts = is_mutated[(labels == d).to_numpy()].sum()
        main_genes[d] = counts.index[counts > threshold].to_numpy()
    return main_genes
# Candidate genes per cancer type, keyed by SUBCLASS label.
main_genes = get_gene_dictionary()

In [4]:
# Report how many candidate genes were kept for each cancer type.
for disease, genes in main_genes.items():
    print(disease, len(genes))

KIPAN 286
SARC 31
SKCM 2112
KIRC 68
GBMLGG 147
STES 2034
BRCA 691
THCA 4
LIHC 106
HNSC 334
PAAD 22
OV 47
PRAD 16
UCEC 2006
LAML 9
COAD 936
ACC 51
LGG 16
LUSC 587
LUAD 468
CESC 175
PCPG 3
THYM 1
BLCA 209
TGCT 14
DLBC 17


In [5]:
from collections import defaultdict
# Maps each cancer label to the genes selected by extract_set(); filled by the
# loop at the end of this cell. With no default factory given, this
# defaultdict behaves like a plain dict (missing keys raise KeyError).
corr_th_genes_dic = defaultdict()

def _count_mutations(value):
    """Return 0 for 'WT', otherwise the number of space-separated tokens in the cell."""
    return 0 if value == 'WT' else len(str(value).split(' '))


def extract_set(cancer_name, threshold=0.5, genes=None, frame=None):
    """Select the candidate genes of one cancer type that co-vary with another gene.

    Every cell is converted to a mutation count (0 for 'WT', otherwise the
    number of space-separated tokens). The absolute Pearson correlation between
    the count columns is computed, and every gene that takes part in at least
    one pair with ``|corr| >= threshold`` is selected.

    Parameters
    ----------
    cancer_name : str
        Cancer label; used for the progress print and, when ``genes`` is not
        given, to look up the candidates in the notebook-level ``main_genes``.
    threshold : float, optional
        Minimum absolute correlation for a pair to count (default 0.5, the
        original hard-coded value).
    genes : sequence of str, optional
        Candidate gene columns. Defaults to ``main_genes[cancer_name]``.
    frame : pandas.DataFrame, optional
        Table with one column per gene. Defaults to the notebook-level
        ``data`` and ``df_test`` stacked together, as in the original code.
        NOTE(review): this default includes the test rows in the gene
        selection — confirm that is intended.

    Returns
    -------
    set
        Selected gene names. When no pair reaches the threshold, all candidate
        genes are returned (as a set, so the return type is always the same).

    Raises
    ------
    KeyError
        If ``genes`` is not given and ``cancer_name`` is not in ``main_genes``.
    """
    each_glist = main_genes.get(cancer_name) if genes is None else genes
    if each_glist is None:
        raise KeyError(f"no candidate genes recorded for {cancer_name!r}")
    print(cancer_name, len(each_glist))

    raw = pd.concat([data, df_test]) if frame is None else frame
    each_glist = list(each_glist)

    # Build a fresh count frame instead of assigning columns onto a .loc slice,
    # which is the chained-assignment pattern pandas warns about.
    counts = raw.loc[:, each_glist].apply(lambda col: col.map(_count_mutations))
    corr_matrix = counts.corr().abs()

    # Look at each unordered pair once (strict upper triangle). NaN
    # correlations (constant columns) compare False and are ignored.
    strong = np.triu(corr_matrix.to_numpy() >= threshold, k=1)
    rows, cols = np.nonzero(strong)
    a1 = set(corr_matrix.index[rows]) | set(corr_matrix.columns[cols])

    if len(a1) == 0:
        # Nothing correlated strongly enough: fall back to every candidate.
        return set(each_glist)
    return a1

# Build the correlation-filtered gene collection for every cancer type.
# (The former bare `len(...)` expression inside the loop had no effect and was removed.)
for d in disease_list:
    corr_th_genes_dic[d] = extract_set(d)
    

KIPAN 286
SARC 31
SKCM 2112
KIRC 68
GBMLGG 147
STES 2034
BRCA 691
THCA 4
LIHC 106
HNSC 334
PAAD 22
OV 47
PRAD 16
UCEC 2006
LAML 9
COAD 936
ACC 51
LGG 16
LUSC 587
LUAD 468
CESC 175
PCPG 3
THYM 1
BLCA 209
TGCT 14
DLBC 17


In [6]:
# Display the test frame (its ID column was dropped when the data was loaded).
df_test

Unnamed: 0,A2M,AAAS,AADAT,AARS1,ABAT,ABCA1,ABCA2,ABCA3,ABCA4,ABCA5,...,ZNF292,ZNF365,ZNF639,ZNF707,ZNFX1,ZNRF4,ZPBP,ZW10,ZWINT,ZYX
0,WT,WT,WT,WT,WT,WT,WT,WT,WT,WT,...,WT,WT,WT,WT,WT,WT,WT,WT,WT,WT
1,WT,WT,WT,WT,WT,R587Q,WT,WT,WT,WT,...,WT,WT,WT,WT,WT,I383Sfs,WT,WT,WT,WT
2,WT,WT,WT,WT,WT,WT,WT,WT,WT,WT,...,WT,WT,WT,WT,WT,WT,WT,WT,WT,WT
3,WT,WT,WT,WT,WT,WT,WT,WT,WT,WT,...,WT,WT,WT,WT,WT,WT,WT,WT,WT,WT
4,WT,WT,WT,WT,WT,WT,WT,WT,WT,WT,...,WT,WT,WT,WT,WT,WT,WT,WT,WT,WT
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2541,WT,WT,WT,WT,WT,WT,WT,WT,WT,WT,...,WT,WT,WT,WT,WT,WT,WT,WT,WT,WT
2542,WT,WT,WT,WT,WT,WT,WT,WT,WT,WT,...,WT,WT,WT,WT,WT,WT,WT,WT,WT,WT
2543,WT,WT,WT,WT,WT,L217I,P221P P251P,R5M,G606D,I248Nfs,...,S2049Vfs S1909Vfs,L232R,WT,WT,L305L,WT,N252I N251I,G679V,WT,WT
2544,WT,WT,WT,WT,WT,WT,WT,WT,R1517H,WT,...,WT,WT,WT,WT,WT,WT,WT,WT,WT,WT


In [7]:

# Overlap-based prediction: for every test sample, score each cancer type by
# the fraction of its selected genes that appear among the sample's
# top-ranked genes, then report the best-scoring cancer type.
predictions = []  # predicted label per test row, in row order
for i in range(df_test.shape[0]):
    # Rank the sample's genes by mutation count ONCE. The original redid this
    # conversion and sort for each cancer type, although it does not depend on
    # the cancer type.
    ranked_genes = (
        df_test.iloc[i]
        .apply(lambda x: 0 if x == 'WT' else len(str(x).split(' ')))
        .sort_values(ascending=False)
        .index
    )
    # NOTE(review): when a sample has fewer mutated genes than th_count, the
    # top-th_count slice is padded with WT genes (count 0) in whatever order
    # the sort leaves them — confirm that this is intended.
    # TODO: consider a different gene-count threshold per cancer type.
    percent_list = []
    for d in disease_list:
        s1 = corr_th_genes_dic[d]
        th_count = len(s1)
        if th_count == 0:
            # No selected genes for this cancer type: avoid dividing by zero.
            percent_list.append(0.0)
            continue
        seta = set(ranked_genes[:th_count])
        intersect_count = len(seta.intersection(s1))
        percent_list.append(round(intersect_count / th_count, 3))

    # argmax returns the first maximum, so ties go to the earlier cancer type.
    dindex = int(np.argmax(percent_list))
    predictions.append(disease_list[dindex])
    print(i, disease_list[dindex])
    
    
    


0 CESC
1 SKCM
2 THCA
3 LAML
4 KIPAN
5 SKCM
6 LAML
7 SKCM
8 STES
9 PRAD
10 STES
11 STES
12 LAML
13 OV
14 STES
15 STES
16 ACC
17 STES
18 STES
19 LGG
20 LAML
21 SARC
22 STES
23 STES
24 STES
25 OV
26 LAML
27 SKCM
28 STES
29 STES
30 STES
31 SKCM
32 LAML
33 PRAD
34 STES
35 STES
36 THCA
37 STES
38 STES
39 LAML
40 STES
41 LAML
42 LGG
43 SKCM
44 STES
45 STES
46 SARC
47 PRAD
48 STES
49 LAML
50 SARC
51 STES
52 STES
53 LAML
54 STES
55 STES
56 STES
57 STES
58 SARC
59 STES
60 STES
61 OV
62 STES
63 STES
64 STES
65 LAML
66 SKCM
67 STES
68 LAML
69 SKCM
70 PAAD
71 SKCM
72 OV
73 STES
74 LGG
75 STES
76 OV
77 OV
78 SARC
79 UCEC
80 PRAD
81 LGG
82 LAML
83 STES
84 LAML
85 OV
86 SKCM
87 LGG
88 PRAD
89 STES
90 STES
91 STES
92 SKCM
93 STES
94 STES
95 ACC
96 LGG
97 SKCM
98 OV
99 LAML
100 TGCT
101 LGG
102 THCA
103 LAML
104 STES
105 ACC
106 LAML
107 STES
108 SKCM
109 DLBC
110 STES
111 STES
112 STES
113 STES
114 LAML
115 STES
116 STES
117 STES
118 STES
119 SKCM
120 SKCM
121 STES
122 LAML
123 OV
124 ACC
125 LAML
126 

KeyboardInterrupt: 

In [7]:
import matplotlib.pyplot as plt
import seaborn as sns
def drawing(dname):
    """Plot a clustered heatmap of gene-gene correlations for one cancer type.

    The training rows whose SUBCLASS equals ``dname`` are restricted to the
    candidate genes stored in ``main_genes[dname]``. Each cell is converted to
    a mutation count (0 for 'WT', otherwise the number of space-separated
    tokens), and the correlation matrix of those counts is drawn with
    ``sns.clustermap``.

    Parameters
    ----------
    dname : str
        Cancer label; must be a key of the notebook-level ``main_genes``.

    Raises
    ------
    ValueError
        If fewer than two candidate genes are available for ``dname``.
    """
    genes = main_genes.get(dname)
    if genes is None or len(genes) < 2:
        # Hierarchical clustering of the correlation matrix needs at least two genes.
        raise ValueError(f"need at least two candidate genes to cluster {dname!r}")

    # Build the count frame in one step instead of assigning columns onto a
    # filtered slice (the chained-assignment pattern pandas warns about).
    ccdf = df.loc[df.SUBCLASS == dname, list(genes)].apply(
        lambda col: col.map(lambda x: 0 if x == 'WT' else len(str(x).split(' ')))
    )
    sns.clustermap(ccdf.corr(),
            annot = True,       # print the correlation value in each cell
            cmap = 'RdYlBu_r',  # red / yellow / blue colormap
            vmin = -1, vmax = 1,  # colour scale spans -1 to 1
            )

In [8]:
# Clustered correlation heatmap for BRCA; `dname` stays defined at notebook level.
dname='BRCA'
drawing(dname)

In [10]:
# Clustered correlation heatmap for SARC.
# NOTE(review): the saved output of this cell is a NameError for `total`, a name
# the drawing() definition visible in this notebook never uses — the kernel
# presumably ran a different definition; re-run on a fresh kernel to confirm.
drawing('SARC')


NameError: name 'total' is not defined

In [1]:
#drawing('SKCM')


In [None]:
# Clustered correlation heatmaps for the remaining cancer types. One loop
# replaces the former one-call-per-cell copies; the list keeps the same cancer
# types in the same order.
remaining_cancers = [
    'KIRC', 'GBMLGG', 'STES', 'THCA', 'LIHC', 'HNSC', 'PAAD', 'OV',
    'PRAD', 'UCEC', 'LAML', 'COAD', 'ACC', 'LGG', 'LUSC', 'LUAD',
    'CESC', 'PCPG', 'THYM', 'BLCA', 'TGCT', 'DLBC',
]
for cancer in remaining_cancers:
    if len(main_genes[cancer]) < 2:
        # A correlation matrix of fewer than two genes cannot be clustered.
        print(cancer, 'skipped: fewer than two candidate genes')
        continue
    drawing(cancer)
    plt.show()  # flush this figure before the next one is drawn

# Inspect the correlation-filtered gene set for PAAD.
extract_set('PAAD')

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
# NOTE(review): `ccdf` is only assigned inside drawing() in the visible cells, so
# this cell presumably relies on a variable left over in the kernel — confirm it
# still runs after Restart & Run All.
sns.clustermap(ccdf.corr(), 
            annot = True,      # print the actual value in each cell
            cmap = 'RdYlBu_r',  # red / yellow / blue colormap
            vmin = -1, vmax = 1, # colour scale spans -1 to 1
            )

In [None]:
def extract_set(cancer_name):
    each_glist = main_genes.get(cancer_name)
    raw =  pd.concat([data, df_test])
    gframe = raw.loc[:, each_glist]
    
    for c in gframe.columns:
        gframe[c] = gframe[c].apply(lambda x : 0 if x =='WT' else len(str(x).split(' ')))
    corr_matrix = gframe.corr().abs()
    sol = (corr_matrix.where(np.triu(np.ones(corr_matrix.shape), k=1).astype(bool))
            .stack()
            .sort_values(ascending=False))
    aa = sol.reset_index()
    aa.columns =['level0','level1','rate']
    #print(aa)
    set1 = aa[aa.rate >= 0.5]['level0'].drop_duplicates()
    set2 = aa[aa.rate >= 0.5]['level1'].drop_duplicates()
    
    a1 = set() 
    a1.update(set1)
    a1.update(set2)