In [1]:
!pip install -q iterative-stratification

In [2]:
# imports

import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import KFold
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold

In [3]:
# load data
# Read the three input CSVs in order: training features, scored training
# targets, test features.
train, train_target, test = (
    pd.read_csv(csv_path)
    for csv_path in (
        '/kaggle/input/lish-moa/train_features.csv',
        '/kaggle/input/lish-moa/train_targets_scored.csv',
        '/kaggle/input/lish-moa/test_features.csv',
    )
)

In [4]:
# feature engineering

# From https://www.kaggle.com/carlmcbrideellis/moa-setting-ctl-vehicle-0-improves-score
# Zero every column whose name contains '-' on the rows whose cp_type contains
# 'ctl_vehicle'. `.loc` is required here: the row selector is a boolean mask
# and the column selector is a collection of labels, while `.at` only supports
# a single scalar (row label, column label) pair.
train.loc[train['cp_type'].str.contains('ctl_vehicle'), train.filter(regex='-.*').columns] = 0.0
test.loc[test['cp_type'].str.contains('ctl_vehicle'), test.filter(regex='-.*').columns] = 0.0

# One hot encoding
# Encode train and test together so both end up with the same dummy columns,
# then split them back apart by position.
train_size = train.shape[0]
traintest = pd.concat([train, test])
# Replaces each listed column with '<column>_<value>' indicator columns
# appended after the remaining columns.
traintest = pd.get_dummies(traintest, columns=['cp_type', 'cp_time', 'cp_dose'])
# Explicit copies: the frames are assigned into below, and writing into a bare
# slice of `traintest` triggers SettingWithCopyWarning.
train = traintest[:train_size].copy()
test  = traintest[train_size:].copy()
del traintest

# normalisation
# Standardise only the columns whose name contains 'g-'. The scaler is fitted
# on train and the same fitted transform is applied to test.
g_columns = [ c for c in train.columns if 'g-' in c ]
scaler = StandardScaler()
train[g_columns] = scaler.fit_transform(train[g_columns])
test[g_columns] = scaler.transform(test[g_columns])
print("{0}:{1}".format(train.shape, test.shape))

# top 500 features + sig_id (this list was computed separately using RFECV)
# Every entry is a 'g-*' or 'c-*' column name, except the final 'sig_id',
# which is kept so each row can still be identified after column selection.
features = ['g-1', 'g-2', 'g-3', 'g-4', 'g-5', 'g-6', 'g-7', 'g-8', 'g-9', 'g-10', 'g-11', 'g-12', 'g-13', 'g-14', 'g-15', 'g-16', 'g-17', 'g-18', 'g-19', 
            'g-20', 'g-21', 'g-22', 'g-23', 'g-24', 'g-25', 'g-26', 'g-27', 'g-28', 'g-29', 'g-30', 'g-31', 'g-32', 'g-33', 'g-34', 'g-35', 'g-36', 'g-37', 
            'g-38', 'g-39', 'g-40', 'g-41', 'g-42', 'g-43', 'g-44', 'g-45', 'g-46', 'g-47', 'g-48', 'g-49', 'g-50', 'g-51', 'g-52', 'g-53', 'g-54', 'g-55', 
            'g-56', 'g-57', 'g-58', 'g-59', 'g-60', 'g-61', 'g-62', 'g-63', 'g-64', 'g-65', 'g-66', 'g-67', 'g-68', 'g-69', 'g-70', 'g-71', 'g-72', 'g-73', 
            'g-74', 'g-75', 'g-76', 'g-77', 'g-78', 'g-79', 'g-80', 'g-81', 'g-82', 'g-83', 'g-84', 'g-85', 'g-86', 'g-87', 'g-88', 'g-89', 'g-90', 'g-91', 
            'g-92', 'g-93', 'g-94', 'g-95', 'g-96', 'g-97', 'g-98', 'g-99', 'g-100', 'g-101', 'g-102', 'g-103', 'g-104', 'g-105', 'g-106', 'g-107', 'g-108', 
            'g-109', 'g-110', 'g-111', 'g-112', 'g-113', 'g-114', 'g-115', 'g-116', 'g-117', 'g-118', 'g-119', 'g-120', 'g-121', 'g-122', 'g-123', 'g-124', 
            'g-125', 'g-126', 'g-127', 'g-128', 'g-129', 'g-130', 'g-132', 'g-137', 'g-139', 'g-140', 'g-144', 'g-151', 'g-152', 'g-154', 'g-155', 'g-158', 
            'g-161', 'g-167', 'g-168', 'g-169', 'g-171', 'g-172', 'g-174', 'g-177', 'g-179', 'g-180', 'g-183', 'g-187', 'g-189', 'g-190', 'g-192', 'g-196', 
            'g-201', 'g-202', 'g-203', 'g-204', 'g-209', 'g-213', 'g-215', 'g-217', 'g-218', 'g-219', 'g-221', 'g-229', 'g-231', 'g-238', 'g-239', 'g-242', 
            'g-244', 'g-245', 'g-247', 'g-251', 'g-254', 'g-260', 'g-262', 'g-264', 'g-267', 'g-268', 'g-275', 'g-279', 'g-281', 'g-282', 'g-284', 'g-285', 
            'g-286', 'g-287', 'g-288', 'g-289', 'g-290', 'g-291', 'g-292', 'g-293', 'g-294', 'g-295', 'g-296', 'g-299', 'g-302', 'g-303', 'g-306', 'g-308', 
            'g-309', 'g-310', 'g-311', 'g-313', 'g-315', 'g-316', 'g-317', 'g-318', 'g-319', 'g-320', 'g-321', 'g-322', 'g-323', 'g-325', 'g-326', 'g-328', 
            'g-331', 'g-335', 'g-336', 'g-337', 'g-338', 'g-339', 'g-342', 'g-348', 'g-349', 'g-350', 'g-351', 'g-352', 'g-353', 'g-355', 'g-356', 'g-357', 
            'g-358', 'g-359', 'g-364', 'g-365', 'g-370', 'g-371', 'g-372', 'g-373', 'g-374', 'g-375', 'g-376', 'g-377', 'g-378', 'g-379', 'g-380', 'g-382', 
            'g-387', 'g-390', 'g-393', 'g-401', 'g-403', 'g-411', 'g-416', 'g-417', 'g-419', 'g-420', 'g-421', 'g-425', 'g-427', 'g-428', 'g-429', 'g-431', 
            'g-433', 'g-435', 'g-441', 'g-442', 'g-445', 'g-447', 'g-448', 'g-457', 'g-460', 'g-464', 'g-467', 'g-473', 'g-475', 'g-476', 'g-479', 'g-480', 
            'g-481', 'g-482', 'g-484', 'g-485', 'g-486', 'g-489', 'g-490', 'g-492', 'g-500', 'g-501', 'g-504', 'g-508', 'g-516', 'g-521', 'g-524', 'g-526', 
            'g-527', 'g-528', 'g-529', 'g-530', 'g-531', 'g-535', 'g-537', 'g-542', 'g-548', 'g-554', 'g-557', 'g-558', 'g-562', 'g-565', 'g-571', 'g-572', 
            'g-574', 'g-577', 'g-580', 'g-581', 'g-582', 'g-583', 'g-584', 'g-585', 'g-586', 'g-587', 'g-588', 'g-589', 'g-590', 'g-591', 'g-592', 'g-593', 
            'g-594', 'g-595', 'g-596', 'g-597', 'g-598', 'g-599', 'g-601', 'g-602', 'g-603', 'g-604', 'g-605', 'g-606', 'g-607', 'g-609', 'g-610', 'g-611', 
            'g-612', 'g-613', 'g-614', 'g-615', 'g-617', 'g-618', 'g-619', 'g-620', 'g-623', 'g-624', 'g-626', 'g-627', 'g-628', 'g-630', 'g-632', 'g-633', 
            'g-634', 'g-635', 'g-636', 'g-637', 'g-638', 'g-639', 'g-643', 'g-645', 'g-650', 'g-651', 'g-652', 'g-655', 'g-656', 'g-662', 'g-663', 'g-665', 
            'g-667', 'g-675', 'g-677', 'g-680', 'g-684', 'g-685', 'g-690', 'g-701', 'g-702', 'g-703', 'g-704', 'g-713', 'g-714', 'g-721', 'g-725', 'g-726', 
            'g-727', 'g-734', 'g-735', 'g-737', 'g-739', 'g-740', 'g-741', 'g-742', 'g-744', 'g-751', 'g-755', 'g-758', 'g-759', 'g-760', 'g-761', 'g-762', 
            'g-763', 'g-764', 'g-765', 'g-767', 'g-769', 'c-2', 'c-3', 'c-4', 'c-5', 'c-6', 'c-7', 'c-8', 'c-9', 'c-10', 'c-11', 'c-12', 'c-13', 'c-14', 
            'c-15', 'c-16', 'c-17', 'c-18', 'c-19', 'c-20', 'c-21', 'c-22', 'c-23', 'c-24', 'c-27', 'c-30', 'c-31', 'c-32', 'c-35', 'c-37', 'c-39', 'c-40', 
            'c-41', 'c-43', 'c-45', 'c-49', 'c-50', 'c-51', 'c-52', 'c-53', 'c-54', 'c-55', 'c-56', 'c-57', 'c-58', 'c-59', 'c-60', 'c-61', 'c-62', 'c-63', 
            'c-64', 'c-65', 'c-66', 'c-67', 'c-68', 'c-69', 'c-70', 'c-71', 'c-72', 'c-73', 'c-74', 'c-75', 'c-76', 'c-77', 'c-78', 'c-79', 'c-80', 'c-81', 
            'c-82', 'c-83', 'c-84', 'c-85', 'c-86', 'c-87', 'c-88', 'c-89', 'c-90', 'c-91', 'c-92', 'c-93', 'c-94', 'c-95', 'c-96', 'c-98', 'sig_id']

# Keep only the selected columns. The explicit copies make `train` and `test`
# independent frames, so columns added to them afterwards are written to the
# frame itself and do not trigger SettingWithCopyWarning.
train = train[features].copy()
test = test[features].copy()
print("{0}:{1}".format(train.shape, test.shape))

(23814, 880):(3982, 880)
(23814, 501):(3982, 501)


In [5]:
# kfold
# One seed drives both the row shuffle and the fold split, so the fold ids are
# reproducible across runs.
FOLD_SEED = 42

# Shuffle features and targets with ONE shared permutation. Fold ids are
# assigned by row position to both frames below, so the two must stay
# row-aligned; shuffling only `train` would give the same sample different
# fold ids in the two frames.
# NOTE(review): this presumes `train` and `train_target` start out in the same
# row order, as the position-based fold assignment already required.
assert len(train) == len(train_target), "train and train_target must have the same number of rows"
row_order = np.random.RandomState(FOLD_SEED).permutation(len(train))
train = train.iloc[row_order].reset_index(drop=True)
train_target = train_target.iloc[row_order].reset_index(drop=True)

# Collect the validation-fold id of every row position (-1 = unassigned).
kf = KFold(n_splits=5, shuffle=True, random_state=FOLD_SEED)
kfold_of_row = np.full(len(train), -1)
for f, (_, v_) in enumerate(kf.split(X=train)):
    kfold_of_row[v_] = f

# Attach the same fold ids to the features and to a copy of the targets.
train["kfold"] = kfold_of_row
train_target_folds = train_target.copy()
train_target_folds["kfold"] = kfold_of_row

In [6]:
# multilabel stratified kfold
# Fixed seed so the stratified fold ids are reproducible across runs.
STRAT_FOLD_SEED = 42

# Build the label matrix in the CURRENT row order of `train`, matched on
# sig_id rather than on row position, so stratification uses each row's own
# labels even if `train` has been shuffled. Moving sig_id into the index also
# keeps the id string out of the label matrix handed to the splitter.
# NOTE(review): assumes `train_target` carries a unique `sig_id` column.
target_by_id = train_target.set_index("sig_id")
assert target_by_id.index.is_unique, "sig_id must be unique in train_target"
labels = target_by_id.loc[train["sig_id"]]

# Collect the validation-fold id of every row position of `train` (-1 = unassigned).
mskf = MultilabelStratifiedKFold(n_splits=5, shuffle=True, random_state=STRAT_FOLD_SEED)
skfold_of_row = np.full(len(train), -1)
for f, (_, v_) in enumerate(mskf.split(X=train, y=labels)):
    skfold_of_row[v_] = f
train["skfold"] = skfold_of_row

# Propagate the fold ids to the target frame by sig_id, so a sample always has
# the same skfold value in both saved files regardless of row order.
skfold_by_id = pd.Series(skfold_of_row, index=train["sig_id"].to_numpy())
# astype(int) fails loudly if any target row has no matching sig_id in `train`.
train_target_folds["skfold"] = train_target_folds["sig_id"].map(skfold_by_id).astype(int)

train.to_csv("cleaned_train_folds.csv", index=False)
train_target_folds.to_csv("train_target_folds.csv", index=False)



In [7]:
# Rows per multilabel-stratified fold id.
train["skfold"].value_counts()

4    4763
3    4763
2    4763
0    4763
1    4762
Name: skfold, dtype: int64

In [8]:
# Rows per plain k-fold id.
train["kfold"].value_counts()

3    4763
2    4763
1    4763
0    4763
4    4762
Name: kfold, dtype: int64