In [41]:
import numpy as np
import pandas as pd
# Use fully qualified option names: a bare "max_rows" is matched as a pattern
# and raises OptionError on pandas versions where more than one option ends in
# "max_rows" (e.g. `display.max_rows` and `styler.render.max_rows`).
pd.set_option('display.max_columns', None)
pd.set_option("display.max_rows", 200)

from haversine import haversine, Unit

from sklearn.preprocessing import StandardScaler
from sklearn.impute import SimpleImputer

from sklearn.model_selection import train_test_split
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import RandomizedSearchCV

from sklearn.ensemble import RandomForestClassifier
from catboost import CatBoostClassifier
import xgboost as xgb

from sklearn.metrics import roc_auc_score

print("Import complete")

Import complete


## Import Files

In [42]:
# Load the three competition files: train features, train labels and the
# unlabelled test features.
train = pd.read_csv(r"source/train_features.csv")
labels = pd.read_csv(r"source/train_labels.csv")
test = pd.read_csv(r"source/test_features.csv")

# report the shapes so mismatches are visible at a glance
print("train data => rows: %s, cols: %s" % (train.shape[0], train.shape[1]))
print("labels data => rows: %s, cols: %s" % (labels.shape[0], labels.shape[1]))
print("test data => rows: %s, cols: %s" % (test.shape[0], test.shape[1]))

# every training row needs exactly one label ...
assert train.shape[0] == labels.shape[0], "train features and labels differ in row count"
# ... and train/test must expose the same number of feature columns
assert(train.shape[1] == test.shape[1])

train data => rows: 59400, cols: 40
labels data => rows: 59400, cols: 2
test data => rows: 14850, cols: 40


## Pre-processing

In [43]:
# Count fully duplicated rows in each frame; `duplicated()` flags every repeat
# of an earlier row, so summing the boolean mask gives the duplicate count.
train_dup_count = train.duplicated().sum()
label_dup_count = labels.duplicated().sum()
test_dup_count = test.duplicated().sum()

print("duplicates in train dataset: %s" % train_dup_count)
print("duplicates in label dataset: %s" % label_dup_count)
print("duplicates in test dataset: %s" % test_dup_count)

# none of the three frames may contain a duplicated row
assert train_dup_count == 0
assert label_dup_count == 0
assert test_dup_count == 0

duplicates in train dataset: 0
duplicates in label dataset: 0
duplicates in test dataset: 0


In [44]:
# use the `id` column as the row index of all three frames
train, test, labels = (frame.set_index("id") for frame in (train, test, labels))

In [45]:
# Tag each row with its origin so the combined frame can be split again later.
for frame, origin in ((train, "train"), (test, "test")):
    frame["type"] = origin

# Stack train and test features (labels are NOT included) so that cleaning is
# applied to both consistently; the `id` index is kept.
data = pd.concat([train, test], ignore_index=False)

expected_rows = train.shape[0] + test.shape[0]
assert data.shape[0] == expected_rows
assert data.shape[1] == train.shape[1] and train.shape[1] == test.shape[1]

data.tail(2)

Unnamed: 0_level_0,amount_tsh,date_recorded,funder,gps_height,installer,longitude,latitude,wpt_name,num_private,basin,subvillage,region,region_code,district_code,lga,ward,population,public_meeting,recorded_by,scheme_management,scheme_name,permit,construction_year,extraction_type,extraction_type_group,extraction_type_class,management,management_group,payment,payment_type,water_quality,quality_group,quantity,quantity_group,source,source_type,source_class,waterpoint_type,waterpoint_type_group,type
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1
33492,0.0,2013-02-18,Germany,998,DWE,35.432732,-10.584159,Kwa John,0,Lake Nyasa,Namakinga B,Ruvuma,10,2,Songea Rural,Maposeni,150,True,GeoData Consultants Ltd,VWC,Mradi wa maji wa maposeni,True,2009,gravity,gravity,gravity,vwc,user-group,never pay,never pay,soft,good,insufficient,insufficient,river,river/lake,surface,communal standpipe,communal standpipe,test
68707,0.0,2013-02-13,Government Of Tanzania,481,Government,34.765054,-11.226012,Kwa Mzee Chagala,0,Lake Nyasa,Kamba,Ruvuma,10,3,Mbinga,Mbamba bay,40,True,GeoData Consultants Ltd,VWC,DANIDA,True,2008,gravity,gravity,gravity,vwc,user-group,never pay,never pay,soft,good,dry,dry,spring,spring,groundwater,communal standpipe,communal standpipe,test


The `installer` and `funder` columns contain many near-duplicate values (spelling variants of the same name).
These can be grouped together.

Reference: https://github.com/sagol/pumpitup/blob/main/oof_model.ipynb

In [46]:
# Normalise the free-text `installer` names: lowercase them, fold known
# spelling variants into one canonical label per organisation, then bucket
# placeholders and rarely seen installers into 'other'.
#
# Every step assigns its result back to the column. The previous
# `data['installer'].replace(..., inplace=True)` form mutates a temporary
# Series (chained assignment): pandas warns about it and it stops updating
# `data` once Copy-on-Write is enabled.

# (variants, canonical) pairs, applied in order. Order matters when a variant
# is listed in more than one group: the first group that lists it wins.
installer_groups = [
    (('fini water', 'fin water', 'finn water', 'finwater', 'finwate'), 'finw'),

    (('jaica co',), 'jaica'),

    ((
        'district water department', 'district water depar', 'district council',
        'district counci', 'village council orpha','kibaha town council',
        'village council', 'coun', 'village counil', 'council',
        'mbulu district council', 'counc', 'village council .oda',
        'sangea district coun', 'songea district coun', 'villege council',
        'district  council', 'quick win project /council', 'mbozi district council',
        'village  council', 'municipal council', 'tabora municipal council',
        'wb / district council'),
     'council'),

    ((
        'rc church', 'rc churc', 'rcchurch/cefa', 'irc', 'rc', 'rc ch', 'hw/rc',
        'rc church/central gover', 'kkkt church', 'pentecost church', 'roman church',
        'rc/mission', 'rc church/cefa', 'lutheran church', 'tag church',
        'free pentecoste church of tanz', 'rc c', 'church', 'rc cathoric',
        'morovian church', 'cefa/rc church', 'rc mission', 'anglican church',
        'church of disciples', 'anglikana church', 'cetral government /rc',
        'pentecostal church', 'cg/rc', 'rc missionary', 'sda church', 'methodist church', 'trc',
        'rc msufi', 'haidomu lutheran church', 'baptist church', 'rc church brother',
        'st magreth church', 'anglica church', 'global resource co', 'rc mi',
        'baptist church of tanzania', 'fpct church', 'rc njoro', 'rc .church',
        'rc mis', 'batist church', 'churc', 'dwe/anglican church','missi', 'mission',
        'ndanda missions', 'rc/mission', 'cvs miss', 'missionaries', 'hydom luthelani',
        'luthe', 'haydom lutheran hospital', 'lutheran', 'missio', 'germany missionary',
        'grail mission kiseki bar', 'missionary', 'heri mission', 'german missionsry',
        'wamissionari wa kikatoriki', 'neemia mission', 'wamisionari wa kikatoriki'),
     'church'),

    ((
        'central government', 'gove', 'central govt', 'gover', 'cipro/government',
        'governme', 'adra /government', 'isf/government', 'adra/government',
        'government /tcrs', 'village govt', 'government', 'government /community',
        'concern /government', 'goverm', 'village government', 'cental government',
        'govern', 'cebtral government', 'government /sda', 'tcrs /government',
        'tanzania government', 'centra govt', 'colonial government', 'misri government',
        'government and community', 'cetral government /rc', 'concern/government',
        'government of misri', 'lwi &central government', 'governmen', 'government/tcrs', 'government /world vision',
        'centra government'),
     'tanzanian government'),

    (('world vission', 'world division', 'word divisio','world visiin'), 'world vision'),

    (('unicrf',), 'unicef'),

    ((
        'commu', 'olgilai village community', 'adra /community', 'adra/community',
        'rwe/ community', 'killflora /community', 'communit', 'taboma/community',
        'arab community', 'adra/ community', 'sekei village community', 'rwe/community',
        'arabs community', 'village community', 'government /community',
        'dads/village community', 'killflora/ community', 'mtuwasa and community',
        'rwe /community', 'ilwilo community', 'summit for water/community',
        'igolola community', 'ngiresi village community', 'rwe community',
        'african realief committe of ku', 'twesa /community', 'shelisheli commission',
        'twesa/ community', 'marumbo community', 'government and community',
        'community bank', 'kitiangare village community', 'oldadai village community',
        'twesa/community', 'tlc/community', 'maseka community', 'islamic community',
        'district community j', 'village water commission', 'village community members',
        'tcrs/village community', 'village water committee', 'comunity'),
     'community'),

    (('danid', 'danda','danida co', 'danny', 'daniad', 'dannida', 'danids'), 'danida'),

    (('hesaws', 'huches', 'hesaw', 'hesawz', 'hesawq', 'hesewa'), 'hesawa'),

    ((
        'dwsp', 'kkkt _ konde and dwe', 'rwe/dwe', 'rwedwe', 'dwe/', 'dw', 'dwr',
        'dwe}', 'dwt', 'dwe /tassaf', 'dwe/ubalozi wa marekani', 'consultant and dwe',
        'dwe & lwi', 'ubalozi wa marekani /dwe', 'dwe&', 'dwe/tassaf', 'dw$',
        'dw e', 'tcrs/dwe', 'dw#', 'dweb', 'tcrs /dwe', 'water aid/dwe', 'dww'),
     'dwe'),

    ((
        'africa muslim', 'muslimu society(shia)', 'africa muslim agenc',
        'african muslims age', 'muslimehefen international','islamic',
        'the isla', 'islamic agency tanzania',  'islam', 'nyabibuye islamic center'),
     'muslims'),

    (('british colonial government', 'british government', 'britain'), 'british'),

    ((
        'tcrs/tlc', 'tcrs /care', 'cipro/care/tcrs', 'tcrs kibondo', 'tcrs.tlc',
        'tcrs /twesa', 'tassaf /tcrs', 'tcrs/care', 'tcrs twesa', 'rwe/tcrs',
        'tcrs/twesa', 'tassaf/ tcrs', 'tcrs/ tassaf', 'tcrs/ twesa', 'tcrs a',
        'tassaf/tcrs'),
     'tcrs'),

    ((
        'kkkt-dioces ya pare', 'kkkt leguruki', 'kkkt ndrumangeni', 'kkkt dme',
        'kkkt kilinga', 'kkkt canal', 'kkkt katiti juu', 'kkkt mareu'),
     'kkkt'),

    (('norad/',), 'norad'),

    (('tasaf/dmdd', 'dmdd/solider'), 'dmdd'),

    (('cjejow construction', 'cjej0'), 'cjejow'),

    (('china henan constuction', 'china henan contractor', 'china co.', 'chinese'), 'china'),

    ((
        'local contract', 'local technician', 'local', 'local  technician',
        'locall technician', 'local te', 'local technitian', 'local technical tec',
        'local fundi', 'local technical', 'localtechnician', 'village local contractor',
        'local l technician'),
     'local'),

    ((
        'oikos e .africa', 'oikos e.africa', 'africa amini alama',
        'africa islamic agency tanzania', 'africare', 'african development foundation',
        'oikos e. africa', 'oikos e.afrika', 'afroz ismail', 'africa', 'farm-africa',
        'oikos e africa', 'farm africa', 'africaone', 'tina/africare', 'africaone ltd',
        'african reflections foundation', 'africa m'),
     'africa'),

    # placeholders; 'nan' is what `astype(str)` turns missing values into
    (('0', 'nan', '-'), 'other'),
]

installer = data['installer'].astype(str).str.lower()
for variants, canonical in installer_groups:
    # list-like `to_replace` means exact, whole-value matches (no regex)
    installer = installer.replace(to_replace=list(variants), value=canonical)
data['installer'] = installer

# Installers seen fewer than 71 times in the combined train+test frame are too
# sparse to keep as their own level; fold them into 'other'.
df_installer_cnt = data.groupby('installer')['installer'].count()
other_list = df_installer_cnt[df_installer_cnt<71].index.tolist()
data['installer'] = data['installer'].mask(data['installer'].isin(other_list), 'other')

In [47]:
# Normalise the free-text `funder` names: lowercase them, fold known spelling
# variants into one canonical label per organisation, then bucket placeholders
# and rarely seen funders into 'other'.
#
# Every step assigns its result back to the column. The previous
# `data['funder'].replace(..., inplace=True)` form mutates a temporary Series
# (chained assignment): pandas warns about it and it stops updating `data`
# once Copy-on-Write is enabled.

# (variants, canonical) pairs, applied in order. Order matters when a variant
# is listed in more than one group: the first group that lists it wins.
funder_groups = [
    ((
        'kkkt_makwale', 'kkkt-dioces ya pare', 'world vision/ kkkt', 'kkkt church',
        'kkkt leguruki', 'kkkt ndrumangeni', 'kkkt dme', 'kkkt canal', 'kkkt usa',
        'kkkt mareu'),
     'kkkt'),

    ((
        'government of tanzania', 'norad /government', 'government/ community',
        'cipro/government', 'isf/government', 'finidagermantanzania govt',
        'government /tassaf', 'finida german tanzania govt', 'village government',
        'tcrs /government', 'village govt', 'government/ world bank',
        'danida /government', 'dhv/gove', 'concern /govern', 'vgovernment',
        'lwi & central government', 'government /sda', 'koica and tanzania government',
        'world bank/government', 'colonial government', 'misri government',
        'government and community', 'concern/governm', 'government of misri',
        'government/tassaf', 'government/school', 'government/tcrs', 'unhcr/government',
        'government /world vision', 'norad/government'),
     'government'),

    ((
        'british colonial government', 'japan government', 'china government',
        'finland government', 'belgian government', 'italy government',
        'irish government', 'egypt government', 'iran gover', 'swedish', 'finland'),
     'foreign government'),

    ((
        'rc church', 'anglican church', 'rc churc', 'rc ch', 'rcchurch/cefa',
        'irc', 'rc', 'churc', 'hw/rc', 'rc church/centr', 'pentecosta church',
        'roman church', 'rc/mission', "ju-sarang church' and bugango",
        'lutheran church', 'roman cathoric church', 'tag church ub', 'aic church',
        'free pentecoste church of tanz', 'tag church', 'fpct church', 'rc cathoric',
        'baptist church', 'morovian church', 'cefa/rcchurch', 'rc mission',
        'bukwang church saints', 'agt church', 'church of disciples', 'rc mofu',
        "gil cafe'church'", 'pentecostal church', 'bukwang church saint',
        'eung am methodist church', 'rc/dwe', 'cg/rc', 'eung-am methodist church',
        'rc missionary', 'sda church', 'methodist church', 'rc msufi',
        'haidomu lutheran church', 'nazareth church', 'st magreth church',
        'agape churc', 'rc missi', 'rc mi', 'rc njoro', 'world vision/rc church',
        'pag church', 'batist church', 'full gospel church', 'nazalet church',
        'dwe/anglican church', 'missi', 'mission', 'missionaries', 'cpps mission',
        'cvs miss', 'grail mission kiseki bar', 'shelisheli commission', 'missionary',
        'heri mission', 'german missionary', 'wamissionari wa kikatoriki',
        'rc missionary', 'germany missionary', 'missio', 'neemia mission', 'rc missi',
        'hydom luthelani', 'luthe', 'lutheran church',  'haydom lutheran hospital',
        'village council/ haydom luther', 'lutheran', 'haidomu lutheran church',
        'resolute golden pride project', 'resolute mininggolden pride',
        'germany cristians'),
     'church'),

    ((
        'olgilai village community', 'commu', 'community', 'arab community',
        'sekei village community', 'arabs community', 'village community',
        'mtuwasa and community', 'ilwilo community', 'igolola community',
        'ngiresi village community', 'marumbo community', 'village communi',
        'comune di roma', 'comunity construction fund', 'community bank',
        "oak'zion' and bugango b' commu", 'kitiangare village community',
        'oldadai village community', 'tlc/community', 'maseka community',
        'islamic community',  'tcrs/village community', 'buluga subvillage community',
        'okutu village community'),
     'community'),

    ((
        'council', 'wb / district council', 'cdtfdistrict council',
        'sangea district council', 'mheza distric counc', 'kyela council',
        'kibaha town council', 'swidish', 'mbozi district council',
        'village council/ rose kawala',  'songea municipal counci',
        'quick win project /council', 'village council', 'villege council',
        'tabora municipal council', 'kilindi district co', 'kigoma municipal council',
        'district council', 'municipal council', 'district medical',
        'sengerema district council', 'town council', 'mkinga  distric cou',
        'songea district council', 'district rural project', 'mkinga distric coun',
        'dadis'),
     'district'),

    ((
        'tcrs.tlc', 'tcrs /care', 'tcrst', 'cipro/care/tcrs', 'tcrs/care', 'tcrs kibondo'),
     'tcrs'),

    ((
        'fini water', 'finw', 'fin water', 'finn water', 'finwater'),
     'fini'),

    ((
        'islamic', 'the isla', 'islamic found', 'islamic agency tanzania',
        'islam', 'muislam', 'the islamic', 'nyabibuye islamic center', 'islamic society', 'african muslim agency',
        'muslims', 'answeer muslim grou', 'muslimu society(shia)',
        'unicef/african muslim agency', 'muslim world', 'muslimehefen international',
        'shear muslim', 'muslim society'),
     'islam'),

    (('danida', 'ms-danish', 'unhcr/danida', 'tassaf/ danida'), 'danida'),

    ((
        'hesawa', 'hesawz', 'hesaw', 'hhesawa', 'hesawwa', 'hesawza', 'hesswa',
        'hesawa and concern world wide'),
     'hesawa'),

    (('world vision/adra', 'game division', 'worldvision'), 'world vision'),

    ((
        'germany republi', 'a/co germany', 'aco/germany', 'bingo foundation germany',
        'africa project ev germany', 'tree ways german'),
     'germany'),

    # placeholders; 'nan' is what `astype(str)` turns missing values into
    (('0', 'nan', '-'), 'other'),
]

funder = data['funder'].astype(str).str.lower()
for variants, canonical in funder_groups:
    # list-like `to_replace` means exact, whole-value matches (no regex)
    funder = funder.replace(to_replace=list(variants), value=canonical)
data['funder'] = funder

# Funders seen fewer than 98 times in the combined train+test frame are too
# sparse to keep as their own level; fold them into 'other'.
df_funder_cnt = data.groupby('funder')['funder'].count()
other_list = df_funder_cnt[df_funder_cnt<98].index.tolist()
data['funder'] = data['funder'].mask(data['funder'].isin(other_list), 'other')

Sanitizing `null` values

1. For boolean columns: -> fill with median value

In [48]:
# Fill gaps in the boolean flag columns with the column median.
# The flags are cast to float first (True -> 1.0, False -> 0.0) so the median is
# taken on numeric data and the filled column has a single dtype, and the
# result is assigned back: `data[col].fillna(..., inplace=True)` only mutates a
# temporary Series (chained assignment) and is unreliable on recent pandas.
for col in ("public_meeting", "permit"):
    flags = data[col].astype("float64")
    data[col] = flags.fillna(flags.median())

2. For string/object values, replace with "other"

Reference: https://stackoverflow.com/a/60753938/10582056

In [49]:
# Text columns: missing values become the literal category 'other'.
nan_cols = ['subvillage', 'scheme_management', 'scheme_name']
imputer = SimpleImputer(missing_values=np.nan, strategy='constant', fill_value = 'other')
data[nan_cols] = imputer.fit_transform(data[nan_cols])

# Some numeric columns use `0` as a stand-in for "missing": treat 0 as NaN,
# then fill with the mean / median of the remaining (observed) values.
# Results are assigned back to the column; the chained
# `data[col].replace(..., inplace=True)` form only mutates a temporary Series
# and is unreliable on recent pandas.
mean_fill_cols = ["amount_tsh", "population", "gps_height"]
for col in mean_fill_cols:
    observed = data[col].replace(to_replace=0, value=np.nan)
    data[col] = observed.fillna(value=observed.mean())

median_fill_cols = ["district_code", "construction_year", "num_private"]
for col in median_fill_cols:
    observed = data[col].replace(to_replace=0, value=np.nan)
    data[col] = observed.fillna(value=observed.median())

# nothing may be left missing after imputation
assert (data.columns[data.isna().any()].tolist() == [])

Drop unnecessary / similar columns

Reference: https://github.com/sagol/pumpitup/blob/main/oof_model.ipynb

In [50]:
# columns judged redundant with others (or otherwise not useful) are removed
similar_cols = [
    'scheme_management', 'quantity_group', 'water_quality', 'region_code',
    'payment_type', 'extraction_type', 'waterpoint_type_group',
    'date_recorded', 'recorded_by',
]
data = data.drop(columns=similar_cols)

Data types in `data` are mixed, so convert:

1. computable `numeric` and `boolean`(=[`public_meeting`, `permit`]) types to `float64`

2. string columns to `category` for quick transformations

In [51]:
# columns treated as numbers (the two boolean flags included)
numeric_columns = [
    "amount_tsh", "gps_height", "longitude", "latitude", "population",
    "public_meeting", "permit", "district_code", "construction_year",
    "num_private",
]

# low-cardinality string columns, stored as pandas categoricals
categorical_columns = [
    'funder', 'installer', 'basin', 'region', 'lga',
    'extraction_type_group', 'extraction_type_class', 'management',
    'management_group', 'payment', 'quality_group', 'quantity', 'source',
    'source_type', 'source_class', 'waterpoint_type', 'type',
]

# one cast for the whole frame: numerics -> float64, strings -> category
dtype_plan = dict.fromkeys(numeric_columns, "float64")
dtype_plan.update(dict.fromkeys(categorical_columns, "category"))
data = data.astype(dtype_plan)

Working with `Latitude` and `Longitude`

Reference: https://stackoverflow.com/a/31398615/10582056

1. Find the haversine distance

In [52]:
# Distance (km) from every waterpoint to the mean latitude/longitude of the
# combined train+test data.
mean_lat = data["latitude"].mean()
mean_long = data["longitude"].mean()
centre = (mean_lat, mean_long)

data["haversine_distance"] = [
    haversine(point, centre, unit=Unit.KILOMETERS)
    for point in zip(data["latitude"], data["longitude"])
]

2. Convert `latitude`, `longitude` to `x_coordinate`, `y_coordinate` and `z_coordinate`

Reference: https://heartbeat.fritz.ai/working-with-geospatial-data-in-machine-learning-ad4097c7228d

In [53]:
# Project latitude/longitude onto the unit sphere (x, y, z).
# `latitude`/`longitude` are in degrees (they are passed as-is to `haversine`,
# which takes degrees), while NumPy's trigonometric functions expect radians,
# so convert first; feeding degrees straight into cos/sin produces meaningless
# coordinates.
lat_rad = np.radians(data['latitude'])
lon_rad = np.radians(data['longitude'])

data['x_coordinate'] = np.cos(lat_rad) * np.cos(lon_rad)
data['y_coordinate'] = np.cos(lat_rad) * np.sin(lon_rad)
data['z_coordinate'] = np.sin(lat_rad)

# the raw angles are no longer needed once the projection exists
data = data.drop(["latitude", "longitude"], axis=1)

Standardise the numeric columns with `StandardScaler`

In [54]:
# columns rescaled to zero mean and unit variance
standard_cols = [
    "amount_tsh", "gps_height", "num_private",
    "district_code", "population", "haversine_distance",
]

# learn the scaling statistics, then apply them to the same columns
scalar = StandardScaler().fit(data[standard_cols])
data[standard_cols] = scalar.transform(data[standard_cols])

Convert all strings to lowercase

In [55]:
# Lowercase every value of the categorical and object (string) columns,
# remembering which columns were touched in `string_cols`.
string_cols = []
for col in data.columns:
    if data[col].dtype in ("category", object):
        string_cols.append(col)
        data[col] = data[col].apply(lambda text: text.lower(), convert_dtype=False)

## Apply CatBoost

Split data to train and test

In [59]:
# number of leading training rows used for the search (small, for quick runs)
size = 100

# recover the two partitions from the combined frame; the helper column goes
train = data[data.type.eq("train")].drop("type", axis=1)
test = data[data.type.eq("test")].drop("type", axis=1)

# Pair every training row with its label through the shared `id` index instead
# of relying on both files listing rows in the same order. A missing id raises
# KeyError here rather than silently mislabelling rows.
target = labels.loc[train.index, "status_group"].to_numpy()
assert len(target) == len(train), "expected exactly one label per training row"

X_train, X_test, y_train, y_test = train_test_split(
    train.head(size),
    target[:size],
    test_size = 0.2,
    shuffle = True,
    stratify = target[:size],  # keep class proportions equal in both parts
    random_state = 42
)

In [61]:
# One summary row per feature: name, dtype, number of distinct values and the
# distinct values themselves, ordered by increasing distinct-value count.
# (`train` and `test` share the same columns.)
col_details = sorted(
    (
        (col, train[col].dtype, train[col].nunique(), list(train[col].unique()))
        for col in train.columns
    ),
    key=lambda detail: detail[2],
)

temp = pd.DataFrame(col_details, columns=["Column", "Dtype", "N_Unique", "Unique_vals"])
temp

Unnamed: 0,Column,Dtype,N_Unique,Unique_vals
0,public_meeting,float64,2,"[1.0, 0.0]"
1,permit,float64,2,"[0.0, 1.0]"
2,source_class,category,3,"[groundwater, surface, unknown]"
3,management_group,category,5,"[user-group, other, commercial, parastatal, un..."
4,quantity,category,5,"[enough, insufficient, dry, seasonal, unknown]"
5,quality_group,category,6,"[good, salty, milky, unknown, fluoride, colored]"
6,extraction_type_class,category,7,"[gravity, submersible, handpump, other, motorp..."
7,payment,category,7,"[pay annually, never pay, pay per bucket, unkn..."
8,source_type,category,7,"[spring, rainwater harvesting, dam, borehole, ..."
9,waterpoint_type,category,7,"[communal standpipe, communal standpipe multip..."


In [60]:
# Split the feature names by dtype: categoricals are passed to CatBoost as
# categorical features, plain object (string) columns as text features.
feature_dtypes = data.dtypes.loc[train.columns]
cat_features = [name for name, dtype in feature_dtypes.items() if dtype == "category"]
text_features = [name for name, dtype in feature_dtypes.items() if dtype == object]

['funder', 'installer', 'basin', 'region', 'lga', 'extraction_type_group', 'extraction_type_class', 'management', 'management_group', 'payment', 'quality_group', 'quantity', 'source', 'source_type', 'source_class', 'waterpoint_type']
['wpt_name', 'subvillage', 'ward', 'scheme_name']


In [64]:
# show the names of the columns that will be passed as categorical features
cat_features

['funder',
 'installer',
 'basin',
 'region',
 'lga',
 'extraction_type_group',
 'extraction_type_class',
 'management',
 'management_group',
 'payment',
 'quality_group',
 'quantity',
 'source',
 'source_type',
 'source_class',
 'waterpoint_type']

In [None]:
# Hyper-parameter space sampled by the randomized search.
random_grid = {
    "n_estimators":[300, 500, 800, 1000, 1300],
    # CatBoost rejects tree depths above 16, so the former candidates
    # 30/50/80/100 made every draw containing them fail.
    "max_depth": [4, 6, 8, 10],
    "learning_rate": [0.1, 0.15, 0.3, 0.5, 0.8, 1],
    # `loss_function` is deliberately not searched: 'AUC' is metric-only and
    # 'RMSE' is a regression objective, so neither is valid for
    # CatBoostClassifier. Leaving it unset lets CatBoost pick the
    # classification objective that matches the number of target classes.
    "l2_leaf_reg": [1,2,3],
    "max_ctr_complexity": [3, 5, 8],
    "od_wait": [300, 500, 700, 1000],
    "od_type": ["IncToDec", "Iter"]
}

catboost = CatBoostClassifier(
    cat_features=cat_features,
    text_features=text_features,
    task_type="GPU"
)

rscv = RandomizedSearchCV(
    estimator = catboost,
    param_distributions = random_grid,
    # 'roc_auc' only supports binary targets; the one-vs-rest variant also
    # handles multiclass targets.
    # NOTE(review): `status_group` looks multiclass - confirm the metric choice.
    scoring='roc_auc_ovr',
    n_iter=10,
    cv=3,
    verbose=2,
    random_state=42,
    # one fit at a time: parallel workers would all train on the same GPU and
    # compete for its memory
    n_jobs=1
)

print("Ready!")

In [None]:
# `y_train` is the 1-D label array returned by train_test_split; the former
# `y_train['h1n1_vaccine']` lookup referenced a column that does not exist here.
rscv.fit(X_train, y_train)

In [None]:
# report the winning parameter combination and its cross-validated score
best_params, best_score = rscv.best_params_, rscv.best_score_
print(best_params)
print(best_score)
print("Finished")