In [1]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn import svm

In [25]:
def status_map(status):
    if 'no_relapse' in status or 'NoRelapse' in status:
        return 0
    else:
        return 1

class BcData:
    def __init__(self):
        self.data = pd.read_csv("data/data_good.csv")
        self.total = pd.read_csv("data/Total_old.csv", names=["gsm", "status"])
        self._drop_grey()
        self._groupby_gene()

    # Drop grey columns
    def _drop_grey(self):
        status_list =['relapse', 'no_relapse', 'test1relapse',
                  'test1no_relapse', 'test2relapse',
                  'test2no_relapse', 'NewTest1_Relapse',
                  'NewTest1_NoRelapse', 'NewTest2_Relapse', 'NewTest2_NoRelapse']
        self.gsm_series = self.total[self.total.status.isin(status_list)].gsm
        new_cols = pd.Series(["GeneSymbol"]).append(self.gsm_series)

        self.total = self.total[self.total.gsm.isin(self.gsm_series)]
        self.data = self.data.filter(items=new_cols)

    # Group rows by gene leaving max median row
    def _groupby_gene(self):
        # Get max median in a group
        def _get_max_median(group):
            idx = group.median(axis=1).idxmax()
            return group.loc[idx]

        groups = self.data.groupby("GeneSymbol", as_index=False, sort=False)
        grouped_result = list(map(lambda group: _get_max_median(group[1]), groups))
        res = pd.concat(grouped_result, axis=1)

        # Set header
        header = res.iloc[0]
        res = res[1:]
        res.columns = header
        self.data = res

    def _log_table(self):
        self.data = np.log2(self.data)

    # Drop rows with quantile less than threshold (values = {7, 8, 9})
    def filter_percentile(self, quantile=1, threshold=9):
        q = self.data.quantile(q=quantile, axis=1)
        index = q[q >= threshold].index.values
        return self.data.loc[index, :].T

    # Drop rows with max/min diff less than threshold (values = {1.5, 2})
    def filter_diff_percentile(self, qmax=1, qmin=0, threshold=2):
        threshold = np.log2(threshold)
        max = self.data.quantile(q=qmax, axis=1)
        min = self.data.quantile(q=qmin, axis=1)
        index = max[max - min >= threshold].index.values
        return self.data.loc[index, :].T

    def get_status(self):
        return self.total.status.map(status_map)

In [26]:
df = BcData()
df.data

GeneSymbol,STAT1,GAPDH,ACTB,PRPF8,CAPNS1,RPL35,RPL28,EIF4G2,EIF3D,PARK7,...,LOC100507009,OR7E47P,EGOT,LOC100510224,ZNF324B,OR7E156P,ALS2CL,C4orf34,TBX10,KCNE4
GSM441628,1111.95,9226.08,8039.97,279.784,1154.88,3020.27,6000.5,1928.14,796.193,2511.83,...,72.4975,108.918,58.2354,51.859,63.8791,92.5358,135.858,18.695,47.5438,32.1067
GSM441629,700.323,6382.34,7094.68,892.762,1089.71,2167.08,5237.38,2114.91,653.19,2274.08,...,108.74,80.7739,50.4372,46.3919,74.5425,86.0676,122.775,22.8111,43.9503,356.43
GSM441643,327.376,10450,8810.01,614.688,2224.19,1927.17,7321.02,1744.14,593.776,2109.05,...,74.9073,104.881,85.2122,76.1531,74.9073,94.5928,231.767,20.1587,56.3981,244.8
GSM441644,1563.17,13756.8,9381.77,290.467,1463.44,2789.48,8052.49,1336.71,512.864,2053.7,...,77.3274,105.174,127.252,35.2412,74.6465,89.0994,179.185,19.5876,57.5737,50.3291
GSM441657,1711.59,11279.5,9996.92,413.91,837.345,3093.49,6462.12,2098.84,759.985,2008.22,...,82.6427,92.155,53.4107,35.8899,62.2448,63.3539,127.02,20.0265,92.4918,126.315
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
GSM79256,494.902,5912.75,10617,589.767,1585.45,2923.42,5432.68,1394.79,987.019,1675.44,...,128.886,125.469,91.8698,61.9296,76.1495,124.912,161.39,25.7579,41.5986,56.9028
GSM79307,397.771,6204.99,11992.1,382.866,1486.3,2187.62,5028.52,1407.66,610.979,2260.06,...,131.074,143.668,322.957,48.3264,70.3816,116.641,183.829,20.3404,42.0447,90.1375
GSM79194,300.353,6762.16,12650.9,684.849,990.406,1547.63,3996.28,741.996,759.633,1543.88,...,226.837,155.933,80.7446,59.1357,90.7672,124.325,279.324,22.9909,55.3962,163.158
GSM79179,415.944,10580.1,13686.7,348.245,1066.72,2480.38,4148.88,1776.27,907.129,1494.29,...,164.136,121.529,73.4877,33.2208,72.909,92.2668,177.218,21.6785,46.3305,179.001


In [4]:
grouped = df.data.groupby("GeneSymbol", as_index=False, sort=False)
grouped = list(map(lambda df: df[1], grouped))
# newd.loc["STAT1"]

In [5]:
newd = df.data.groupby('GeneSymbol', as_index=False, sort=False).max()
newd

Unnamed: 0,GeneSymbol,GSM441628,GSM441629,GSM441643,GSM441644,GSM441657,GSM441663,GSM441672,GSM441677,GSM441689,...,GSM79316,GSM79301,GSM79303,GSM79278,GSM79158,GSM79256,GSM79307,GSM79194,GSM79179,GSM79182
0,STAT1,1111.9500,700.3230,327.3760,1563.1700,1711.5900,1777.1600,1491.6300,998.5040,788.1790,...,730.1850,564.9710,1394.5000,388.5310,1279.2800,494.9020,397.7710,300.3530,415.9440,652.7210
1,GAPDH,9226.0800,6382.3400,10450.0000,13756.8000,11279.5000,7235.6500,10087.0000,8516.9200,7371.1900,...,9041.8200,6329.7600,12994.3000,5294.5400,6743.9100,5912.7500,6204.9900,6762.1600,10580.1000,6796.6000
2,ACTB,8317.7900,7094.6800,10224.3000,9990.2600,9996.9200,7357.9500,8495.9900,8868.9400,9317.1900,...,14421.5000,11754.6000,15244.0000,10808.4000,15063.6000,10617.0000,11992.1000,12650.9000,13686.7000,9801.7600
3,PRPF8,279.7840,892.7620,614.6880,290.4670,413.9100,368.8420,769.1610,404.6150,483.8170,...,615.7220,635.5830,395.4080,399.7820,671.8000,589.7670,382.8660,684.8490,348.2450,756.2910
4,CAPNS1,1154.8800,1089.7100,2224.1900,1463.4400,837.3450,1941.2700,1295.1400,1624.8000,1435.7800,...,1056.5800,1438.1500,1118.9900,1050.1900,1427.5400,1585.4500,1486.3000,990.4060,1066.7200,1220.2900
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
10502,OR7E156P,92.5358,86.0676,94.5928,89.0994,63.3539,95.3777,88.8058,111.1780,81.8165,...,113.3670,88.7571,199.3210,106.5780,108.0690,124.9120,116.6410,124.3250,92.2668,114.4590
10503,ALS2CL,135.8580,122.7750,231.7670,179.1850,127.0200,120.7980,143.7650,170.3350,111.4620,...,197.9460,158.3730,202.4980,232.8110,193.9110,161.3900,183.8290,279.3240,177.2180,214.1660
10504,C4orf34,18.6950,22.8111,20.1587,19.5876,20.0265,21.2956,19.1105,18.5949,21.2486,...,28.9671,20.8904,17.6882,21.1995,17.1490,25.7579,20.3404,22.9909,21.6785,20.2061
10505,TBX10,47.5438,43.9503,56.3981,57.5737,92.4918,46.7302,49.7785,40.8507,44.4658,...,42.9858,40.1002,35.8531,42.1160,40.3773,41.5986,42.0447,55.3962,46.3305,41.0901


In [None]:
# X = df.filter_percentile(quantile=1, threshold=9)
X = df.filter_diff_percentile(qmax=0.75, qmin=0.25, threshold=1.8)
y = df.get_status()
print("Number of features: {}".format(len(X.columns)))

In [None]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

# Set dual = True if number of features > number of examples and vice versa
clf = svm.LinearSVC(penalty='l1', dual=False, C=0.1, max_iter=10000)
# clf = svm.SVC(kernel='linear', C=1)

scores = cross_val_score(clf, X, y, cv=5)
scores


In [100]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

# Set dual = True if number of features > number of examples and vice versa
clf = svm.LinearSVC(penalty='l1', dual=False, C=0.1, max_iter=10000)
# clf = svm.SVC(kernel='linear', C=1)

scores = cross_val_score(clf, X, y, cv=5)
scores

array([0.76666667, 0.73333333, 0.73333333, 0.72483221, 0.73154362])

In [9]:

# def _get_max_median(df):
#     result = df.median(axis=1)
#     return result
gmed = grouped[0].median(axis=1)

In [7]:
grouped[0]


Unnamed: 0,GeneSymbol,GSM441628,GSM441629,GSM441643,GSM441644,GSM441657,GSM441663,GSM441672,GSM441677,GSM441689,...,GSM79316,GSM79301,GSM79303,GSM79278,GSM79158,GSM79256,GSM79307,GSM79194,GSM79179,GSM79182
0,STAT1,271.703,143.876,96.2013,211.158,245.954,227.158,369.52,289.518,137.64,...,155.06,139.564,173.008,183.115,159.965,122.855,134.612,157.545,129.474,124.1
1,STAT1,148.662,81.2455,58.1097,135.255,169.862,148.2,190.434,144.622,88.7613,...,85.0733,86.2349,104.746,60.2855,95.5286,58.8385,66.8661,74.5329,68.6508,69.4676
2,STAT1,562.069,349.958,191.297,879.009,771.731,1017.08,985.208,730.354,435.268,...,485.335,335.513,686.761,287.55,754.057,264.548,246.724,195.216,245.225,440.042
783,STAT1,1111.95,700.323,327.376,1563.17,1711.59,1777.16,1491.63,998.504,788.179,...,730.185,564.971,1394.5,388.531,1279.28,494.902,397.771,300.353,415.944,652.721
7876,STAT1,287.471,231.615,126.211,324.371,380.588,375.278,766.178,424.821,194.283,...,244.619,163.391,535.471,359.704,550.767,191.282,170.408,161.369,212.865,239.264
