In [1]:
import os
import gc
import glob
import random
import itertools
import numpy as np
import pandas as pd

from sklearn.model_selection import KFold
from sklearn.decomposition import PCA
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold

import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import RandomSampler, SequentialSampler
from torch.optim.lr_scheduler import CosineAnnealingLR

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

import warnings
warnings.filterwarnings('ignore')

In [2]:
# Show which PyTorch Lightning version is installed in this kernel.
pl.__version__

'1.0.3'

In [3]:
class cfg:
    """Hyper-parameters shared by the data module, the network and the trainer."""

    # --- reproducibility / cross-validation ---
    seed = 69420
    fold = 0  # index of the CV fold held out for validation

    # --- PCA component counts (defaults of DataModule.add_PCA) ---
    g_comp = 29
    c_comp = 4

    # --- network ---
    # One (num_embeddings, embedding_dim) pair per categorical column,
    # in the order cp_type, cp_time, cp_dose.
    emb_dims = [(2, 15), (3, 20), (2, 15)]
    hidden_size = 2048
    dropout_rate = 0.4

    # --- optimisation ---
    batch_size = 128
    lr = 0.001
    epoch = 15  # max training epochs; also the cosine schedule length

In [4]:
def seed_everything(seed):
    """Seed every random number generator the notebook relies on.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices),
    exports ``PYTHONHASHSEED`` and switches cuDNN to deterministic mode.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    # Only affects subprocesses / future interpreters; the current process has
    # already fixed its hash seed at start-up.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current one (no-op without CUDA).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN auto-tuner may pick different kernels from run to run, which
    # defeats the deterministic flag above, so it has to be disabled.
    torch.backends.cudnn.benchmark = False

In [5]:
# Seed all RNGs with the configured seed before any data or model code runs.
seed_everything(cfg.seed)

In [6]:
class MoADataset(Dataset):
    """Row-wise dataset over a feature frame with categorical and continuous columns.

    Each item is a triple ``(continuous, categorical, third)`` where ``third``
    is the float target vector for train/validation data, or the row's
    ``sig_id`` when ``phase == 'test'``.

    Args:
        df: DataFrame holding the feature columns, plus the target columns
            (train/validation) or a ``sig_id`` column (test).
        feature_cols: Names of all model input columns. The three ``cp_*``
            columns are treated as categorical, every other one as continuous.
        target_cols: Names of the label columns. Only read when
            ``phase != 'test'``.
        phase: ``'test'`` selects id-returning mode; any other value returns
            targets.
    """

    def __init__(self, df, feature_cols, target_cols, phase='train'):
        self.df = df
        self.cat_cols = ['cp_type', 'cp_time', 'cp_dose']
        self.cont_cols = [c for c in feature_cols if c not in self.cat_cols]
        self.target_cols = target_cols
        self.phase = phase

        # Materialise NumPy views once so __getitem__ is plain array indexing.
        self.cont_features = self.df[self.cont_cols].values
        self.cat_features = self.df[self.cat_cols].values

        if self.phase != 'test':
            self.targets = self.df[self.target_cols].values
            self.sig_ids = None
        else:
            # Targets are never read in test phase, so do not require the
            # label columns to exist in an unlabeled frame.
            self.targets = None
            self.sig_ids = self.df['sig_id'].values

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        cont_f = torch.tensor(self.cont_features[idx, :], dtype=torch.float)
        cat_f = torch.tensor(self.cat_features[idx, :], dtype=torch.long)

        if self.phase != 'test':
            target = torch.tensor(self.targets[idx, :], dtype=torch.float)
            return cont_f, cat_f, target

        return cont_f, cat_f, self.sig_ids[idx]

In [7]:
SEED = 69420

# Expected feature-table columns: the id, the three cp_* columns, then the
# numbered gene columns g-0..g-771 and cell columns c-0..c-99.
TRAIN_COLUMNS = (
    ["sig_id", "cp_type", "cp_time", "cp_dose"]
    + [f"g-{gene_idx}" for gene_idx in range(772)]
    + [f"c-{cell_idx}" for cell_idx in range(100)]
)
# Hard-coded names of the label columns.
# NOTE(review): not referenced elsewhere in the visible notebook; DataModule derives its target columns from the targets CSV header instead.
TARGET_COLUMNS    = ['5-alpha_reductase_inhibitor', '11-beta-hsd1_inhibitor', 'acat_inhibitor', 'acetylcholine_receptor_agonist', 'acetylcholine_receptor_antagonist', 'acetylcholinesterase_inhibitor', 'adenosine_receptor_agonist', 'adenosine_receptor_antagonist', 'adenylyl_cyclase_activator', 'adrenergic_receptor_agonist', 'adrenergic_receptor_antagonist', 'akt_inhibitor', 'aldehyde_dehydrogenase_inhibitor', 'alk_inhibitor', 'ampk_activator', 'analgesic', 'androgen_receptor_agonist', 'androgen_receptor_antagonist', 'anesthetic_-_local', 'angiogenesis_inhibitor', 'angiotensin_receptor_antagonist', 'anti-inflammatory', 'antiarrhythmic', 'antibiotic', 'anticonvulsant', 'antifungal', 'antihistamine', 'antimalarial', 'antioxidant', 'antiprotozoal', 'antiviral', 'apoptosis_stimulant', 'aromatase_inhibitor', 'atm_kinase_inhibitor', 'atp-sensitive_potassium_channel_antagonist', 'atp_synthase_inhibitor', 'atpase_inhibitor', 'atr_kinase_inhibitor', 'aurora_kinase_inhibitor', 'autotaxin_inhibitor', 'bacterial_30s_ribosomal_subunit_inhibitor', 'bacterial_50s_ribosomal_subunit_inhibitor', 'bacterial_antifolate', 'bacterial_cell_wall_synthesis_inhibitor', 'bacterial_dna_gyrase_inhibitor', 'bacterial_dna_inhibitor', 'bacterial_membrane_integrity_inhibitor', 'bcl_inhibitor', 'bcr-abl_inhibitor', 'benzodiazepine_receptor_agonist', 'beta_amyloid_inhibitor', 'bromodomain_inhibitor', 'btk_inhibitor', 'calcineurin_inhibitor', 'calcium_channel_blocker', 'cannabinoid_receptor_agonist', 'cannabinoid_receptor_antagonist', 'carbonic_anhydrase_inhibitor', 'casein_kinase_inhibitor', 'caspase_activator', 'catechol_o_methyltransferase_inhibitor', 'cc_chemokine_receptor_antagonist', 'cck_receptor_antagonist', 'cdk_inhibitor', 'chelating_agent', 'chk_inhibitor', 'chloride_channel_blocker', 'cholesterol_inhibitor', 'cholinergic_receptor_antagonist', 'coagulation_factor_inhibitor', 'corticosteroid_agonist', 'cyclooxygenase_inhibitor', 'cytochrome_p450_inhibitor', 
'dihydrofolate_reductase_inhibitor', 'dipeptidyl_peptidase_inhibitor', 'diuretic', 'dna_alkylating_agent', 'dna_inhibitor', 'dopamine_receptor_agonist', 'dopamine_receptor_antagonist', 'egfr_inhibitor', 'elastase_inhibitor', 'erbb2_inhibitor', 'estrogen_receptor_agonist', 'estrogen_receptor_antagonist', 'faah_inhibitor', 'farnesyltransferase_inhibitor', 'fatty_acid_receptor_agonist', 'fgfr_inhibitor', 'flt3_inhibitor', 'focal_adhesion_kinase_inhibitor', 'free_radical_scavenger', 'fungal_squalene_epoxidase_inhibitor', 'gaba_receptor_agonist', 'gaba_receptor_antagonist', 'gamma_secretase_inhibitor', 'glucocorticoid_receptor_agonist', 'glutamate_inhibitor', 'glutamate_receptor_agonist', 'glutamate_receptor_antagonist', 'gonadotropin_receptor_agonist', 'gsk_inhibitor', 'hcv_inhibitor', 'hdac_inhibitor', 'histamine_receptor_agonist', 'histamine_receptor_antagonist', 'histone_lysine_demethylase_inhibitor', 'histone_lysine_methyltransferase_inhibitor', 'hiv_inhibitor', 'hmgcr_inhibitor', 'hsp_inhibitor', 'igf-1_inhibitor', 'ikk_inhibitor', 'imidazoline_receptor_agonist', 'immunosuppressant', 'insulin_secretagogue', 'insulin_sensitizer', 'integrin_inhibitor', 'jak_inhibitor', 'kit_inhibitor', 'laxative', 'leukotriene_inhibitor', 'leukotriene_receptor_antagonist', 'lipase_inhibitor', 'lipoxygenase_inhibitor', 'lxr_agonist', 'mdm_inhibitor', 'mek_inhibitor', 'membrane_integrity_inhibitor', 'mineralocorticoid_receptor_antagonist', 'monoacylglycerol_lipase_inhibitor', 'monoamine_oxidase_inhibitor', 'monopolar_spindle_1_kinase_inhibitor', 'mtor_inhibitor', 'mucolytic_agent', 'neuropeptide_receptor_antagonist', 'nfkb_inhibitor', 'nicotinic_receptor_agonist', 'nitric_oxide_donor', 'nitric_oxide_production_inhibitor', 'nitric_oxide_synthase_inhibitor', 'norepinephrine_reuptake_inhibitor', 'nrf2_activator', 'opioid_receptor_agonist', 'opioid_receptor_antagonist', 'orexin_receptor_antagonist', 'p38_mapk_inhibitor', 'p-glycoprotein_inhibitor', 'parp_inhibitor', 'pdgfr_inhibitor', 
'pdk_inhibitor', 'phosphodiesterase_inhibitor', 'phospholipase_inhibitor', 'pi3k_inhibitor', 'pkc_inhibitor', 'potassium_channel_activator', 'potassium_channel_antagonist', 'ppar_receptor_agonist', 'ppar_receptor_antagonist', 'progesterone_receptor_agonist', 'progesterone_receptor_antagonist', 'prostaglandin_inhibitor', 'prostanoid_receptor_antagonist', 'proteasome_inhibitor', 'protein_kinase_inhibitor', 'protein_phosphatase_inhibitor', 'protein_synthesis_inhibitor', 'protein_tyrosine_kinase_inhibitor', 'radiopaque_medium', 'raf_inhibitor', 'ras_gtpase_inhibitor', 'retinoid_receptor_agonist', 'retinoid_receptor_antagonist', 'rho_associated_kinase_inhibitor', 'ribonucleoside_reductase_inhibitor', 'rna_polymerase_inhibitor', 'serotonin_receptor_agonist', 'serotonin_receptor_antagonist', 'serotonin_reuptake_inhibitor', 'sigma_receptor_agonist', 'sigma_receptor_antagonist', 'smoothened_receptor_antagonist', 'sodium_channel_inhibitor', 'sphingosine_receptor_agonist', 'src_inhibitor', 'steroid', 'syk_inhibitor', 'tachykinin_antagonist', 'tgf-beta_receptor_inhibitor', 'thrombin_inhibitor', 'thymidylate_synthase_inhibitor', 'tlr_agonist', 'tlr_antagonist', 'tnf_inhibitor', 'topoisomerase_inhibitor', 'transient_receptor_potential_channel_antagonist', 'tropomyosin_receptor_kinase_inhibitor', 'trpv_agonist', 'trpv_antagonist', 'tubulin_inhibitor', 'tyrosine_kinase_inhibitor', 'ubiquitin_specific_protease_inhibitor', 'vegfr_inhibitor', 'vitamin_b', 'vitamin_d_receptor_agonist', 'wnt_inhibitor']
# Feature-name groups picked out of TRAIN_COLUMNS by prefix.
GENES = [name for name in TRAIN_COLUMNS if name.startswith('g-')]
CELLS = [name for name in TRAIN_COLUMNS if name.startswith('c-')]

# PCA sizes consumed by DataModule._PCA.
GENE_PCA_COMP = 29
CELL_PCA_COMP = 4

# General switches. VERBOSE gates the progress prints in DataModule.
# NOTE(review): the other names in this group are not read by the visible cells.
VARIANCE_THRESHOLD = 0.8
FOLDS = 7
VERBOSE = True
INPUT_SIZE = None
OUTPUT_SIZE = None
SAVE_FOLDS = True
USE_SAVED_FOLDS = True

# Training hyper-parameters.
# NOTE(review): the visible training code reads `cfg` rather than these.
BATCH_SIZE = 128
EARLY_STOPPING = 10
LEARNING_RATE = 0.01
WEIGHT_DECAY = 1e-5
SAVE_TOP_K = 1
MAX_EPOCHS = 2000

# Input / output file locations, all relative to PATH.
PATH = "../data/"
TRAIN_F = os.path.join(PATH, "train_features.csv")
TRAIN_T = os.path.join(PATH, "train_targets_scored.csv")
TRAIN_T_NS = os.path.join(PATH, "train_targets_nonscored.csv")
TEST_F = os.path.join(PATH, "test_features.csv")
SAMPLE_SUBMISSION = os.path.join(PATH, "sample_submission.csv")
PRE_FEATURES_CSV = os.path.join(PATH, "preprocessed_train_features.csv")
PRE_TARGETS_CSV = os.path.join(PATH, "preprocessed_train_targets.csv")
PRE_TEST_CSV = os.path.join(PATH, "preprocessed_test_features.csv")
SUBMISSION = os.path.join(PATH, "submission.csv")

# Re-seed every RNG with the module-level seed.
seed_everything(SEED)

In [8]:
class DataModule(pl.LightningDataModule):
    """Lightning data module: loads the CSVs, engineers features, serves fold loaders.

    ``prepare_data`` reads the train/test CSVs named by the module-level
    ``TRAIN_F`` / ``TRAIN_T`` / ``TEST_F`` constants, integer-encodes the
    ``cp_*`` columns, appends PCA features and stores everything in
    ``self.df``. ``setup`` then assigns CV folds and builds the datasets.

    Args:
        data_dir: Stored on the instance.
            NOTE(review): never read; file paths come from the module-level
            constants instead.
        cfg: Config object; ``fold``, ``batch_size`` and ``seed`` are read.
        cv: Splitter exposing ``split(X, y)`` used to assign validation folds.
    """

    def __init__(self, data_dir, cfg, cv):
        super(DataModule, self).__init__()
        self.cfg = cfg
        self.data_dir = data_dir
        self.cv = cv

    def prepare_data(self):
        """Load raw CSVs, engineer features and build the combined frame ``self.df``.

        Sets ``self.df`` (train rows followed by test rows, flagged by
        ``is_train``), ``self.target_cols``, ``self.feature_cols`` and
        ``self.test_id``.
        """
        train_feature = pd.read_csv(TRAIN_F)
        train_target = pd.read_csv(TRAIN_T)
        test = pd.read_csv(TEST_F)

        train_feature, train_target, test = self._mapping_and_filter(train_feature, train_target, test)
        train_feature, train_target, test = self._PCA(train_feature, train_target, test)

        train = pd.merge(train_target, train_feature, on='sig_id')
        self.target_cols = [c for c in train_target.columns if c != 'sig_id']

        test['is_train'] = 0
        train['is_train'] = 1
        # Test rows carry no target columns, so those become NaN after concat.
        self.df = pd.concat([train, test], axis=0, ignore_index=True)
        self.test_id = test['sig_id'].values

        # Everything that is neither a label nor bookkeeping is a model input.
        self.feature_cols = [c for c in self.df.columns if c not in self.target_cols + ['sig_id', 'is_train', 'fold']]

        del train, train_target, train_feature, test
        gc.collect()

    def _mapping_and_filter(self, f, t, test_f, drop_cp=True):
        """Integer-encode the ``cp_*`` columns and optionally drop control rows.

        Args:
            f: Train feature frame (its ``cp_*`` columns are overwritten in place).
            t: Train target frame, row-aligned with ``f``.
            test_f: Test feature frame (its ``cp_*`` columns are overwritten in place).
            drop_cp: When true, keep only train rows whose ``cp_type`` is
                ``'trt_cp'`` (encoded 0) in both ``f`` and ``t``. Test rows
                are never filtered.

        Returns:
            Tuple ``(f, t, test_f)`` after encoding / filtering.
        """
        cp_type = {'trt_cp': 0, 'ctl_vehicle': 1}
        cp_dose = {'D1': 0, 'D2': 1}
        cp_time = {24: 0, 48: 1, 72:2}
        for df in [f, test_f]:
            df['cp_type'] = df['cp_type'].map(cp_type)
            df['cp_dose'] = df['cp_dose'].map(cp_dose)
            df['cp_time'] = df['cp_time'].map(cp_time)
        if drop_cp:
            # Filter the targets first: the mask must come from the unfiltered f.
            t = t[f['cp_type'] == 0].reset_index(drop = True)
            f = f[f['cp_type'] == 0].reset_index(drop = True)

        if VERBOSE:
            print("Features Mapped to Integers.")

        return f, t, test_f

    def setup(self, stage=None):
        """Assign CV folds and build the train / validation / test datasets.

        Args:
            stage: Accepted for Lightning compatibility; not used.
        """
        # Split Train, Test
        df = self.df[self.df['is_train'] == 1].reset_index(drop=True)
        test = self.df[self.df['is_train'] == 0].reset_index(drop=True)
        self.test_id = test['sig_id'].values

        # Split Train, Validation: label each row with the fold it validates in.
        df['fold'] = -1
        for i, (trn_idx, val_idx) in enumerate(self.cv.split(df, df[self.target_cols])):
            df.loc[val_idx, 'fold'] = i
        fold = self.cfg.fold
        train = df[df['fold'] != fold].reset_index(drop=True)
        val = df[df['fold'] == fold].reset_index(drop=True)

        self.train_dataset = MoADataset(train, self.feature_cols, self.target_cols, phase='train')
        self.val_dataset = MoADataset(val, self.feature_cols, self.target_cols, phase='train')
        self.test_dataset = MoADataset(test, self.feature_cols, self.target_cols, phase='test')

        del df, test, train, val
        gc.collect()

    def train_dataloader(self):
        """Return a randomly sampled loader over the training fold."""
        return DataLoader(self.train_dataset,
                          batch_size=self.cfg.batch_size,
                          pin_memory=True,
                          sampler=RandomSampler(self.train_dataset), drop_last=False)

    def val_dataloader(self):
        """Return a sequential loader over the validation fold."""
        return DataLoader(self.val_dataset,
                          batch_size=self.cfg.batch_size,
                          pin_memory=True,
                          sampler=SequentialSampler(self.val_dataset), drop_last=False)

    def test_dataloader(self):
        """Return an unshuffled loader over the test rows."""
        return DataLoader(self.test_dataset,
                          batch_size=self.cfg.batch_size,
                          pin_memory=False,
                          shuffle=False, drop_last=False)

    def Encode(self, df):
        """Integer-encode the ``cp_*`` columns of a single frame.

        NOTE(review): not called by the visible pipeline, which uses
        ``_mapping_and_filter`` instead.

        Args:
            df: Frame with raw ``cp_type`` / ``cp_time`` / ``cp_dose`` values
                (modified in place).

        Returns:
            The same frame with the three columns mapped and cast to ``int``.
        """
        cp_type_encoder = {
            'trt_cp': 0,
            'ctl_vehicle': 1
        }

        cp_time_encoder = {
            24: 0,
            48: 1,
            72: 2,
        }

        cp_dose_encoder = {
            'D1': 0,
            'D2': 1
        }

        df['cp_type'] = df['cp_type'].map(cp_type_encoder)
        df['cp_time'] = df['cp_time'].map(cp_time_encoder)
        df['cp_dose'] = df['cp_dose'].map(cp_dose_encoder)

        for c in ['cp_type', 'cp_time', 'cp_dose']:
            df[c] = df[c].astype(int)

        return df

    def add_PCA(self, df, g_comp=29, c_comp=4):
        """Append PCA features computed on a single combined frame.

        NOTE(review): not called by the visible pipeline, which uses ``_PCA``
        instead.

        Args:
            df: Frame containing the gene (``g-``) and cell (``c-``) columns.
            g_comp: Number of PCA components for the gene columns.
            c_comp: Number of PCA components for the cell columns.

        Returns:
            ``df`` with ``pca_g_*`` and ``pca_c_*`` columns appended.
        """
        # g-features
        g_cols = [c for c in df.columns if 'g-' in c]
        temp = PCA(n_components=g_comp, random_state=self.cfg.seed).fit_transform(df[g_cols])
        temp = pd.DataFrame(temp, columns=[f'pca_g_{i + 1}' for i in range(g_comp)])
        df = pd.concat([df, temp], axis=1)

        # c-features
        c_cols = [c for c in df.columns if 'c-' in c]
        temp = PCA(n_components=c_comp, random_state=self.cfg.seed).fit_transform(df[c_cols])
        temp = pd.DataFrame(temp, columns=[f'pca_c_{i + 1}' for i in range(c_comp)])
        df = pd.concat([df, temp], axis=1)

        del temp

        return df

    def _PCA(self, f, t, test_f, drop_original=False):
        """Append gene and cell PCA features to the train and test feature frames.

        PCA is fitted on the train and test rows stacked together, then the
        components are split back and concatenated column-wise.

        Args:
            f: Train feature frame.
            t: Train target frame (returned unchanged).
            test_f: Test feature frame.
            drop_original: When true, remove the raw ``GENES`` and ``CELLS``
                columns from both feature frames after the PCA columns are
                added.

        Returns:
            Tuple ``(f, t, test_f)``.
        """
        def create_pca(train, test, features, kind, n_components):
            # Fit one PCA on train+test so both share the same projection.
            train_ = train[features].copy()
            test_ = test[features].copy()
            data = pd.concat([train_, test_], axis = 0)
            pca = PCA(n_components = n_components,  random_state = SEED)
            data = pca.fit_transform(data)
            columns = [f'pca_{kind}_{i + 1}' for i in range(n_components)]
            data = pd.DataFrame(data, columns = columns)
            # Split the stacked components back by position.
            train_ = data.iloc[:train.shape[0]]
            test_ = data.iloc[train.shape[0]:].reset_index(drop = True)
            train = pd.concat([train, train_], axis = 1)
            test = pd.concat([test, test_], axis = 1)
            return train, test

        f, test_f = create_pca(f, test_f, GENES, kind = 'g', n_components = GENE_PCA_COMP)
        f, test_f = create_pca(f, test_f, CELLS, kind = 'c', n_components = CELL_PCA_COMP)
        if drop_original:
            # Drop by column name from the feature frames. The previous
            # `f.drop(GENES)` / `t.drop(CELLS)` dropped row labels, and the
            # second one targeted the label frame.
            original_cols = GENES + CELLS
            f = f.drop(columns=original_cols).reset_index(drop=True)
            test_f = test_f.drop(columns=original_cols).reset_index(drop=True)

        if VERBOSE:
            print("PCA Performed.")

        return f, t, test_f

In [9]:
# # Just to update INPUT_SIZE and OUTPUT_SIZE
# dataset = DataModule("", cfg, MultilabelStratifiedKFold(n_splits=FOLDS))
# dataset.prepare_data()
# dataset.setup()

In [10]:
# dataset.df

In [11]:
# ftt

In [12]:
# f = pd.read_csv(TRAIN_F)
# f.shape

In [13]:
# dataset.df.shape

In [14]:
# f = pd.read_csv(PRE_FEATURES_CSV)
# t = pd.read_csv(PRE_TARGETS_CSV)
# test_f = pd.read_csv(PRE_TEST_CSV)

In [15]:
# ft = pd.merge(t, f, on='sig_id')

In [16]:
# ftt = pd.concat([ft, test_f], axis=0, ignore_index=True)

In [17]:
# ftt.columns == dataset.df.columns

In [18]:
class LightningSystem(pl.LightningModule):
    """Lightning wrapper that trains ``net`` with BCE-with-logits and writes test predictions.

    Args:
        net: Module called as ``net(cont_f, cat_f)`` and returning raw scores.
        cfg: Config object; ``lr`` and ``epoch`` are read by the optimizer setup.
        target_cols: Column names used for the prediction CSV.
    """

    def __init__(self, net, cfg, target_cols):
        super(LightningSystem, self).__init__()
        self.net = net
        self.cfg = cfg
        self.target_cols = target_cols
        self.criterion = nn.BCEWithLogitsLoss()
        self.best_loss = 1e+9

    def configure_optimizers(self):
        """Return AdamW with a cosine schedule annealed over ``cfg.epoch`` steps."""
        self.optimizer = optim.AdamW(self.parameters(), lr=self.cfg.lr, weight_decay=2e-5)
        self.scheduler = CosineAnnealingLR(self.optimizer, T_max=self.cfg.epoch, eta_min=0)

        return [self.optimizer], [self.scheduler]

    def forward(self, cont_f, cat_f):
        """Delegate to the wrapped network."""
        return self.net(cont_f, cat_f)

    def step(self, batch):
        """Run one forward pass on a ``(cont_f, cat_f, label)`` batch.

        Returns:
            Tuple ``(loss, label)``.
        """
        cont_f, cat_f, label = batch
        out = self.forward(cont_f, cat_f)
        loss = self.criterion(out, label)

        return loss, label

    def training_step(self, batch, batch_idx):
        """Return the training loss (and labels) for one batch."""
        # No `.item()` here: the value was never used and forced a
        # device-to-host sync on every step.
        loss, label = self.step(batch)

        return {'loss': loss, 'labels': label}

    def validation_step(self, batch, batch_idx):
        """Return the validation loss (and detached labels) for one batch."""
        loss, label = self.step(batch)

        return {'val_loss': loss, 'labels': label.detach()}

    def validation_epoch_end(self, outputs):
        """Average the per-batch validation losses.

        Returns:
            ``{'avg_val_loss': tensor}``, the key the checkpoint and
            early-stopping callbacks in ``main`` monitor.
        """
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()

        return {'avg_val_loss': avg_loss}

    def test_step(self, batch, batch_idx):
        """Return sigmoid probabilities and ids for one ``(cont_f, cat_f, ids)`` batch."""
        cont_f, cat_f, ids = batch
        out = self.forward(cont_f, cat_f)
        # Sigmoid output is a probability; the raw scores in `out` are the logits.
        probs = torch.sigmoid(out)

        return {'pred': probs, 'id': ids}

    def test_epoch_end(self, outputs):
        """Collect all test predictions and write them to ``submission.csv``.

        The CSV has ``sig_id`` as first column followed by one column per
        entry of ``self.target_cols``. It is written to the current working
        directory.
        """
        preds = torch.cat([x['pred'] for x in outputs]).detach().cpu().numpy()
        res = pd.DataFrame(preds, columns=self.target_cols)

        # Flatten the per-batch id sequences into one list, in batch order.
        ids = list(itertools.chain.from_iterable(list(x['id']) for x in outputs))

        res.insert(0, 'sig_id', ids)

        res.to_csv('submission.csv', index=False)

        return {}

In [19]:
class LinearReluBnDropout(nn.Module):
    """Weight-normalised Linear -> ReLU -> BatchNorm1d -> Dropout block.

    Args:
        in_features: Width of the input vectors.
        out_features: Width of the output vectors.
        dropout_rate: Drop probability of the trailing dropout layer.
    """

    def __init__(self, in_features, out_features, dropout_rate):
        super().__init__()

        # Layer order (and therefore the state-dict keys block.0 .. block.3)
        # is part of the checkpoint format; keep it stable.
        layers = [
            nn.utils.weight_norm(nn.Linear(in_features, out_features)),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(out_features),
            nn.Dropout(dropout_rate),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the four layers in sequence and return the result."""
        return self.block(x)


class TablarNet(nn.Module):
    """MLP over continuous features concatenated with embedded categorical features.

    Args:
        emb_dims: Iterable of ``(num_embeddings, embedding_dim)`` pairs, one
            per categorical column, in column order.
        cfg: Config object; ``dropout_rate`` and ``hidden_size`` are read.
        in_cont_features: Number of continuous input columns.
        out_features: Number of output scores per row.
    """

    def __init__(self, emb_dims, cfg, in_cont_features=875, out_features=206):
        super().__init__()

        self.embedding_layer = nn.ModuleList(
            [nn.Embedding(num_categories, width) for num_categories, width in emb_dims]
        )
        self.dropout = nn.Dropout(cfg.dropout_rate, inplace=True)

        self.first_bn_layer = nn.Sequential(
            nn.BatchNorm1d(in_cont_features),
            nn.Dropout(cfg.dropout_rate),
        )

        # The first dense layer sees the continuous columns plus every embedding.
        embedded_width = sum(width for _, width in emb_dims)
        first_in_feature = in_cont_features + embedded_width

        hidden = cfg.hidden_size
        self.block = nn.Sequential(
            LinearReluBnDropout(in_features=first_in_feature,
                                out_features=hidden,
                                dropout_rate=cfg.dropout_rate),
            LinearReluBnDropout(in_features=hidden,
                                out_features=hidden,
                                dropout_rate=cfg.dropout_rate),
        )

        self.last = nn.Linear(hidden, out_features)

    def forward(self, cont_f, cat_f):
        """Return raw scores for a batch.

        Args:
            cont_f: Continuous features, passed to ``BatchNorm1d(in_cont_features)``.
            cat_f: Integer category codes; column ``i`` is fed to embedding ``i``.
        """
        embedded = []
        for col_idx, embedding in enumerate(self.embedding_layer):
            embedded.append(embedding(cat_f[:, col_idx]))
        cat_x = self.dropout(torch.cat(embedded, 1))

        cont_x = self.first_bn_layer(cont_f)

        hidden = self.block(torch.cat([cont_x, cat_x], 1))
        return self.last(hidden)

In [20]:
def main():
    """Train the network on one CV fold and write test predictions.

    Builds the data module, network and Lightning system from ``cfg``, trains
    with checkpointing and early stopping on ``avg_val_loss``, then runs the
    test loop with the best checkpoint (which writes ``submission.csv``).
    """
    # Set data dir
    # NOTE(review): DataModule stores this but reads files from the
    # module-level path constants instead.
    data_dir = '../input/lish-moa'
    # CV
    cv = MultilabelStratifiedKFold(n_splits=4)
    # Random Seed
    seed_everything(cfg.seed)

    # Lightning Data Module  ####################################################
    datamodule = DataModule(data_dir, cfg, cv)
    datamodule.prepare_data()
    target_cols = datamodule.target_cols
    feature_cols = datamodule.feature_cols

    # Model  ####################################################################
    # Continuous input width = all feature columns minus the categorical ones,
    # which are embedded separately (one embedding spec per categorical column).
    in_features = len(feature_cols) - len(cfg.emb_dims)
    net = TablarNet(cfg.emb_dims, cfg, in_cont_features=in_features)

    # Lightning Module  #########################################################
    model = LightningSystem(net, cfg, target_cols)

    checkpoint_callback = ModelCheckpoint(
        save_top_k=1,
        verbose=True,
        monitor='avg_val_loss',
        mode='min'
    )

    early_stop_callback = EarlyStopping(
        monitor='avg_val_loss',
        min_delta=0.00,
        patience=3,
        verbose=False,
        mode='min'
    )

    trainer = Trainer(
        logger=False,
        max_epochs=cfg.epoch,
        checkpoint_callback=checkpoint_callback,
        callbacks=[early_stop_callback],
        # Fall back to CPU instead of crashing when no GPU is present.
        gpus=1 if torch.cuda.is_available() else None
    )

    # Train & Test  ############################################################
    # Train
    trainer.fit(model, datamodule=datamodule)

    # Test
    # Omitting `model` makes Lightning load the best checkpoint
    # (ckpt_path='best' by default). Passing the model would score the
    # last-epoch weights and ignore the ModelCheckpoint callback.
    trainer.test(datamodule=datamodule)

In [21]:
# Run the full pipeline: data preparation, training on one fold, test prediction.
main()

Features Mapped to Integers.
PCA Performed.


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type              | Params
------------------------------------------------
0 | net       | TablarNet         | 6 M   
1 | criterion | BCEWithLogitsLoss | 0     


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 0: avg_val_loss reached 0.03040 (best 0.03040), saving model to D:\Kevin\Machine Learning\MoA Prediction\notebooks\checkpoints\epoch=0.ckpt as top 1


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 1: avg_val_loss reached 0.02047 (best 0.02047), saving model to D:\Kevin\Machine Learning\MoA Prediction\notebooks\checkpoints\epoch=1.ckpt as top 1


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 2: avg_val_loss reached 0.01908 (best 0.01908), saving model to D:\Kevin\Machine Learning\MoA Prediction\notebooks\checkpoints\epoch=2.ckpt as top 1


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 3: avg_val_loss reached 0.01877 (best 0.01877), saving model to D:\Kevin\Machine Learning\MoA Prediction\notebooks\checkpoints\epoch=3.ckpt as top 1


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 4: avg_val_loss reached 0.01837 (best 0.01837), saving model to D:\Kevin\Machine Learning\MoA Prediction\notebooks\checkpoints\epoch=4.ckpt as top 1


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 5: avg_val_loss reached 0.01793 (best 0.01793), saving model to D:\Kevin\Machine Learning\MoA Prediction\notebooks\checkpoints\epoch=5.ckpt as top 1


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 6: avg_val_loss reached 0.01784 (best 0.01784), saving model to D:\Kevin\Machine Learning\MoA Prediction\notebooks\checkpoints\epoch=6.ckpt as top 1


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 7: avg_val_loss reached 0.01745 (best 0.01745), saving model to D:\Kevin\Machine Learning\MoA Prediction\notebooks\checkpoints\epoch=7.ckpt as top 1


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 8: avg_val_loss reached 0.01725 (best 0.01725), saving model to D:\Kevin\Machine Learning\MoA Prediction\notebooks\checkpoints\epoch=8.ckpt as top 1


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 9: avg_val_loss reached 0.01716 (best 0.01716), saving model to D:\Kevin\Machine Learning\MoA Prediction\notebooks\checkpoints\epoch=9-v0.ckpt as top 1


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 10: avg_val_loss reached 0.01710 (best 0.01710), saving model to D:\Kevin\Machine Learning\MoA Prediction\notebooks\checkpoints\epoch=10.ckpt as top 1


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 11: avg_val_loss reached 0.01696 (best 0.01696), saving model to D:\Kevin\Machine Learning\MoA Prediction\notebooks\checkpoints\epoch=11.ckpt as top 1


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 12: avg_val_loss reached 0.01693 (best 0.01693), saving model to D:\Kevin\Machine Learning\MoA Prediction\notebooks\checkpoints\epoch=12.ckpt as top 1


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 13: avg_val_loss was not in top 1


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 14: avg_val_loss reached 0.01689 (best 0.01689), saving model to D:\Kevin\Machine Learning\MoA Prediction\notebooks\checkpoints\epoch=14.ckpt as top 1





HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------

