# Scaling XGBoost with Dask and Coiled

This notebook walks through training a distributed [XGBoost](https://xgboost.readthedocs.io/en/latest/) model with [Dask](https://dask.org/), first locally on a small subset of the data and then, using Dask and [Coiled](https://coiled.io/) to scale out to the cloud, on a larger-than-memory dataset.

## About the Dataset
We'll be using the ARCOS dataset of opioid sales, as released by the Washington Post.
- Download the dataset here: https://www.washingtonpost.com/national/2019/07/18/how-download-use-dea-pain-pills-database/
- Washington Post GitHub repository: https://github.com/wpinvestigative/arcos-api/
- Column descriptions: https://github.com/wpinvestigative/arcos-api/blob/master/data/data_dictionary.csv

The dataset:
- Contains more than 178M transactions.
- Is restricted to transactions where "Measure" is "Tab". This means the DOSAGE_UNIT field is the number of pills in each transaction.
- Is restricted to oxycodone and hydrocodone.


## TO DO
1. Get the LocalCluster example working (and decide whether it's worth including)
2. One-hot encode the categorical columns
3. Justify the choice of XGBoost parameters

## 1. Importing Libraries

We'll start by importing all the libraries we'll need to run this notebook.

*Note how the objects we import from **dask_ml** resemble the familiar sklearn API.*

In [1]:
import coiled
import dask.dataframe as dd
from dask.distributed import Client, LocalCluster, performance_report
from dask_ml.preprocessing import Categorizer, OneHotEncoder
from dask_ml.model_selection import train_test_split
import xgboost as xgb

## 2. Local Distributed XGBoost Model using Dask

Next, let's create a local Dask cluster (a scheduler plus a set of worker processes) using the **LocalCluster** object.

This cluster will handle parallelism for us on our local machine.

In [None]:
# local dask cluster
cluster = LocalCluster(n_workers=8)
client = Client(cluster)
client


In [3]:
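# Numeric feature columns, plus categorical columns (DRUG_NAME is the target)
columns = ["QUANTITY", "STRENGTH", "CALC_BASE_WT_IN_GM", "DOSAGE_UNIT"]
categorical = [
    "TRANSACTION_ID", "REPORTER_BUS_ACT", "REPORTER_NAME", "REPORTER_CITY",
    "REPORTER_STATE", "REPORTER_ZIP", "BUYER_BUS_ACT", "BUYER_NAME",
    "BUYER_CITY", "BUYER_STATE", "BUYER_ZIP", "DRUG_NAME", "UNIT",
    "Product_Name", "Revised_Company_Name",
]

# Lazily read the selected columns from the public S3 bucket
data = dd.read_parquet(
    "s3://coiled-datasets/dea-opioid/arcos_washpost_comp.parquet",
    compression="lz4",
    storage_options={"anon": True},
    columns=columns + categorical,
)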

In [23]:
data_local = data.partitions[0:50]

In [24]:
# inspect the first 5 entries
data_local.head()

| | QUANTITY | STRENGTH | CALC_BASE_WT_IN_GM | DOSAGE_UNIT | TRANSACTION_ID | REPORTER_BUS_ACT | REPORTER_NAME | REPORTER_CITY | REPORTER_STATE | REPORTER_ZIP | BUYER_BUS_ACT | BUYER_NAME | BUYER_CITY | BUYER_STATE | BUYER_ZIP | DRUG_NAME | UNIT | Product_Name | Revised_Company_Name |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1.0 |  | 0.6054 | 100.0 | 64 | DISTRIBUTOR | ACE SURGICAL SUPPLY CO INC | BROCKTON | MA | 2301 | PRACTITIONER | TABRIZI, HAMID R DMD | MALDEN | MA | 2148 | HYDROCODONE |  | HYDROCODONE BIT/ACETA 10MG/500MG USP | Mallinckrodt |
| 1 | 4.0 |  | 0.12108 | 40.0 | 52 | DISTRIBUTOR | APOTHECA INC | PHOENIX | AZ | 85006 | RETAIL PHARMACY | APOTHECARY SHOP DEER VALLEY | PHOENIX | AZ | 85085 | HYDROCODONE |  | HYDROCODONE BITARTRATE & ACETA 5MG/ | Apotheca Inc. |
| 2 | 40.0 |  | 3.6324 | 1200.0 | 119 | DISTRIBUTOR | APOTHECA INC | PHOENIX | AZ | 85006 | PRACTITIONER | HOBBS, DOUGLAS DON, MD | GILBERT | AZ | 85233 | HYDROCODONE |  | HYDROCODONE BITARTRATE & ACETA 5MG/ | Apotheca Inc. |
| 3 | 20.0 |  | 2.7243 | 600.0 | 34 | DISTRIBUTOR | APOTHECA INC | PHOENIX | AZ | 85006 | PRACTITIONER | HOBBS, DOUGLAS DON, MD | GILBERT | AZ | 85233 | HYDROCODONE |  | HYDROCODONEBITARTRATE & ACETA 7.5MG | Apotheca Inc. |
| 4 | 10.0 |  | 0.9081 | 300.0 | 19 | DISTRIBUTOR | APOTHECA INC | PHOENIX | AZ | 85006 | PRACTITIONER | HOBBS, DOUGLAS DON, MD | GILBERT | AZ | 85233 | HYDROCODONE |  | HYDROCODONE BITARTRATE & ACETA 5MG/ | Apotheca Inc. |


This is looking good.

Before we can start training our XGBoost model, however, we have to complete two preprocessing steps:
1. Convert our categorical columns to integer codes (by default, XGBoost only accepts float, integer and boolean dtypes)
2. Create our train and test splits

*Note: we're using the **dask_ml** library for this, which mimics the familiar scikit-learn API*

In [25]:
# Cast categorical columns to the correct type
ce = Categorizer(columns=categorical)
data_local = ce.fit_transform(data_local)
for col in categorical:
    data_local[col] = data_local[col].cat.codes
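Why use `Categorizer` instead of converting each partition on its own? A value's integer code depends on which categories are known when the codes are assigned. The minimal pandas sketch below uses two small, made-up Series as stand-ins for two Dask partitions: encoding them independently gives the same value different codes, while encoding both against one shared category list keeps the codes consistent. Scanning all partitions to build that shared list is, roughly, what `Categorizer` does for us (and why it needs a pass over the data).

```python
import pandas as pd

# Two made-up "partitions" of the same column
part_a = pd.Series(["HYDROCODONE", "OXYCODONE"])
part_b = pd.Series(["OXYCODONE", "OXYCODONE"])

# Encoding each partition independently: codes depend on the local categories
codes_a = part_a.astype("category").cat.codes.tolist()
codes_b = part_b.astype("category").cat.codes.tolist()
print(codes_a, codes_b)  # OXYCODONE is 1 in part_a but 0 in part_b

# Encoding against one shared, global category list keeps codes consistent
shared = pd.CategoricalDtype(["HYDROCODONE", "OXYCODONE"])
shared_a = part_a.astype(shared).cat.codes.tolist()
shared_b = part_b.astype(shared).cat.codes.tolist()
print(shared_a, shared_b)  # OXYCODONE is 1 in both
```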

In [26]:
# Create the train-test split
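# Features: every column except the target. (iloc[:, :-1] would keep DRUG_NAME
# as a feature here, because it is not the last column of data_local.)
X, y = data_local.drop(columns="DRUG_NAME"), data_local["DRUG_NAME"]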
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.1, shuffle=True, random_state=2
)

Great, now we're all set to start training our XGBoost model.

First, we'll create a `DaskDMatrix` (XGBoost's distributed version of its `DMatrix` data structure) and set the model parameters.

In [None]:
# Create the XGBoost DMatrix

dtrain = xgb.dask.DaskDMatrix(client, X_train, y_train)

# Set parameters
params = {
    "max_depth": 8,
    "max_leaves": 2 ** 8,
    "gamma": 0.1,
    "eta": 0.1,
    "min_child_weight": 30,
    "objective": "binary:logistic",
    "grow_policy": "lossguide"
}
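On the parameter choices: a binary tree of depth d has at most 2^d leaves, so with `max_depth` set to 8, a `max_leaves` of `2 ** 8` (256) adds no further constraint; `max_leaves` only becomes the binding limit when it is smaller than 2^`max_depth`. The helper below is a hypothetical illustration of that arithmetic (it is not part of XGBoost) and assumes XGBoost's convention that 0 means "no limit" for both parameters.

```python
def effective_leaf_cap(max_depth, max_leaves):
    """Upper bound on leaves per tree implied by max_depth and max_leaves.

    Hypothetical helper for illustration. A value of 0 is treated as
    "no limit"; None is returned when neither parameter imposes a bound.
    """
    caps = []
    if max_depth > 0:
        caps.append(2 ** max_depth)  # a depth-d binary tree has at most 2**d leaves
    if max_leaves > 0:
        caps.append(max_leaves)
    return min(caps) if caps else None

print(effective_leaf_cap(8, 2 ** 8))  # 256: max_leaves adds no extra constraint
print(effective_leaf_cap(8, 64))      # 64: max_leaves is the binding limit
```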


Then let's go ahead and train the model.

In [None]:
%%time 
# train the model
output = xgb.dask.train(
    client, params, dtrain, num_boost_round=5,
    evals=[(dtrain, 'train')]
)

Finally, let's pull out the results:

In [29]:
# 'booster' is the trained model
booster = output['booster']  

# 'history' is a dictionary containing evaluation metrics
history = output['history']  
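To evaluate the model, `booster` can be passed to `xgb.dask.predict(client, booster, X_test)` to score the held-out rows. With the `binary:logistic` objective those predictions are probabilities, so they need a threshold before they can be compared with the 0/1 labels in `y_test`. The helper below is a hypothetical, pure-Python sketch of that last step on made-up numbers; on the real data you would do the comparison with Dask collections rather than Python lists.

```python
def accuracy_at_threshold(probabilities, labels, threshold=0.5):
    """Fraction of rows where (probability > threshold) matches the 0/1 label.

    Hypothetical helper for illustration; on the cluster the probabilities
    would come from xgb.dask.predict(client, booster, X_test).
    """
    if len(probabilities) != len(labels):
        raise ValueError("probabilities and labels must have the same length")
    predicted = [1 if p > threshold else 0 for p in probabilities]
    correct = sum(int(p == y) for p, y in zip(predicted, labels))
    return correct / len(labels)

# Made-up predictions and labels, for illustration only
probs = [0.9, 0.2, 0.6, 0.4]
truth = [1, 0, 0, 0]
print(accuracy_at_threshold(probs, truth))  # 0.75
```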

## 3. Cloud-Based Distributed XGBoost using Dask and Coiled

Let's now expand this workflow to process the entire dataset (~200GB). We'll run almost exactly the same code except for **2 changes**:
1. We'll connect Dask to a Coiled cluster in the cloud, instead of to our local CPU cores,
2. We'll work with the entire dataset, instead of the first 50 partitions.

In the section below, we've copied the cells from above so that you can run this notebook from top to bottom in one go. Alternatively, you could run the cell below (where we create the Coiled cluster) and then re-run the cells above, making sure to also adjust the cell that loads the data.

### Instantiate Coiled Cluster
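Let's create our Coiled cluster in the cloud. The cell below requests 15 workers and leaves the worker size at Coiled's default; in this run, each worker had 2 threads and about 7.5 GiB of memory. Make sure the cluster's total memory comfortably exceeds the size of the data you load; if it doesn't, add workers or use larger ones.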
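A rough way to choose the number of workers is to divide the data's in-memory size, plus some headroom for intermediate results, by the memory per worker. The helper below is a hypothetical back-of-the-envelope sketch: the 2x headroom factor is an assumption, not a Dask or Coiled recommendation, and the 50 GiB figure is made up. You can estimate the in-memory size of the loaded columns with `data.memory_usage(deep=True).sum().compute()`.

```python
import math

def workers_needed(data_gib, worker_memory_gib, headroom=2.0):
    """Smallest worker count whose total memory covers data_gib * headroom.

    Hypothetical sizing helper; headroom=2.0 is an assumed safety factor.
    """
    if worker_memory_gib <= 0:
        raise ValueError("worker_memory_gib must be positive")
    return math.ceil(data_gib * headroom / worker_memory_gib)

# Hypothetical example: 50 GiB of data on workers with ~7.5 GiB each
print(workers_needed(50, 7.5))  # 14
```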


In [2]:
cluster = coiled.Cluster(
    name="xgboost",
    software="xgboost-new",
    n_workers=15,
    shutdown_on_close=False,
    scheduler_options={'idle_timeout':'1hour'},
)


Found software environment build
Created FW rules: coiled-dask-rrpelgr71-47434-firewall
Created scheduler VM: coiled-dask-rrpelgr71-47434-scheduler (type: t3a.medium, ip: ['35.172.109.141'])


In [3]:
client = Client(cluster)
client

Connection method: Cluster object, Cluster type: coiled.Cluster
Dashboard: http://35.172.109.141:8787
Workers: 1, Total threads: 2, Total memory: 7.48 GiB


## Import Data

In [4]:
columns = [
    "QUANTITY",
    "STRENGTH",
    "CALC_BASE_WT_IN_GM",
    "DOSAGE_UNIT",
]

In [5]:
categorical = [
    "TRANSACTION_ID",
    "REPORTER_BUS_ACT",
    "REPORTER_NAME",
    "REPORTER_CITY",
    "REPORTER_STATE",
    "REPORTER_ZIP",
    "BUYER_BUS_ACT",
    "BUYER_NAME",
    "BUYER_CITY",
    "BUYER_STATE",
    "BUYER_ZIP",
    "DRUG_NAME",
    "UNIT",
    "Product_Name",
    "Revised_Company_Name",
]

### Download the Data

In [6]:
# create Dask Dataframe with parquet data from S3
data = dd.read_parquet(
    "s3://coiled-datasets/dea-opioid/arcos_washpost_comp.parquet", 
    compression="lz4",
    storage_options={"anon": True},
    columns=columns+categorical,
)

In [8]:
data.head()

| | QUANTITY | STRENGTH | CALC_BASE_WT_IN_GM | DOSAGE_UNIT | TRANSACTION_ID | REPORTER_BUS_ACT | REPORTER_NAME | REPORTER_CITY | REPORTER_STATE | REPORTER_ZIP | BUYER_BUS_ACT | BUYER_NAME | BUYER_CITY | BUYER_STATE | BUYER_ZIP | DRUG_NAME | UNIT | Product_Name | Revised_Company_Name |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1.0 |  | 0.6054 | 100.0 | 64 | DISTRIBUTOR | ACE SURGICAL SUPPLY CO INC | BROCKTON | MA | 2301 | PRACTITIONER | TABRIZI, HAMID R DMD | MALDEN | MA | 2148 | HYDROCODONE |  | HYDROCODONE BIT/ACETA 10MG/500MG USP | Mallinckrodt |
| 1 | 4.0 |  | 0.12108 | 40.0 | 52 | DISTRIBUTOR | APOTHECA INC | PHOENIX | AZ | 85006 | RETAIL PHARMACY | APOTHECARY SHOP DEER VALLEY | PHOENIX | AZ | 85085 | HYDROCODONE |  | HYDROCODONE BITARTRATE & ACETA 5MG/ | Apotheca Inc. |
| 2 | 40.0 |  | 3.6324 | 1200.0 | 119 | DISTRIBUTOR | APOTHECA INC | PHOENIX | AZ | 85006 | PRACTITIONER | HOBBS, DOUGLAS DON, MD | GILBERT | AZ | 85233 | HYDROCODONE |  | HYDROCODONE BITARTRATE & ACETA 5MG/ | Apotheca Inc. |
| 3 | 20.0 |  | 2.7243 | 600.0 | 34 | DISTRIBUTOR | APOTHECA INC | PHOENIX | AZ | 85006 | PRACTITIONER | HOBBS, DOUGLAS DON, MD | GILBERT | AZ | 85233 | HYDROCODONE |  | HYDROCODONEBITARTRATE & ACETA 7.5MG | Apotheca Inc. |
| 4 | 10.0 |  | 0.9081 | 300.0 | 19 | DISTRIBUTOR | APOTHECA INC | PHOENIX | AZ | 85006 | PRACTITIONER | HOBBS, DOUGLAS DON, MD | GILBERT | AZ | 85233 | HYDROCODONE |  | HYDROCODONE BITARTRATE & ACETA 5MG/ | Apotheca Inc. |


In [9]:
data

Dask DataFrame Structure:

| npartitions=3750 | QUANTITY | STRENGTH | CALC_BASE_WT_IN_GM | DOSAGE_UNIT | TRANSACTION_ID | REPORTER_BUS_ACT | REPORTER_NAME | REPORTER_CITY | REPORTER_STATE | REPORTER_ZIP | BUYER_BUS_ACT | BUYER_NAME | BUYER_CITY | BUYER_STATE | BUYER_ZIP | DRUG_NAME | UNIT | Product_Name | Revised_Company_Name |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| dtype | float64 | float64 | float64 | float64 | int64 | object | object | object | object | int64 | object | object | object | object | int64 | object | object | object | object |


In [17]:
# data = data.repartition(partition_size="100MB")
# data

Dask DataFrame with the same 19 columns and dtypes, npartitions=1875


In [18]:
# data = data.partitions[:100]
# data

Dask DataFrame with the same 19 columns and dtypes, npartitions=100


### Preprocessing

We ran into a problem with the full dataset: the DMatrix was not being created properly.

To narrow it down, let's strip the data down to the bare-bones columns and see if we can get it to work, then add columns back incrementally.

Let's also:
- drop all rows that contain NaNs
- remove the STRENGTH column, which has no meaningful values for our purposes (it is missing for most rows)
- make sure the target has only two classes, since the `binary:logistic` objective expects 0/1 labels
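The three clean-up steps can be sketched on a tiny, made-up pandas DataFrame; the Dask DataFrame methods used in this notebook (`drop` and `dropna`) follow the same API. Note that the notebook drops NaNs *after* integer-encoding the categorical columns: at that point a missing categorical value has already become the code -1, so `dropna` only removes rows with missing numeric values.

```python
import pandas as pd

# Made-up rows that mimic the encoded data; DRUG_NAME is the 0/1 target
df = pd.DataFrame({
    "QUANTITY": [1.0, 4.0, None, 10.0],
    "STRENGTH": [None, None, None, 0.0],
    "DOSAGE_UNIT": [100.0, 40.0, 1200.0, 300.0],
    "UNIT": [-1, -1, -1, -1],  # missing categories were encoded as -1
    "DRUG_NAME": [0, 1, 0, 1],
})

# 1. Remove the mostly empty STRENGTH column
df = df.drop("STRENGTH", axis=1)
# 2. Drop any row that still contains a missing value
df = df.dropna(how="any")
# 3. binary:logistic expects exactly two classes, encoded as 0 and 1
labels = set(df["DRUG_NAME"].unique().tolist())
assert labels == {0, 1}, f"unexpected labels: {labels}"
print(len(df), sorted(labels))  # 3 [0, 1]
```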

In [16]:
data.head()

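| | QUANTITY | STRENGTH | CALC_BASE_WT_IN_GM | DOSAGE_UNIT | TRANSACTION_ID | REPORTER_BUS_ACT | REPORTER_NAME | REPORTER_CITY | REPORTER_STATE | REPORTER_ZIP | BUYER_BUS_ACT | BUYER_NAME | BUYER_CITY | BUYER_STATE | BUYER_ZIP | DRUG_NAME | UNIT | Product_Name | Revised_Company_Name |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1.0 |  | 0.6054 | 100.0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | -1 | 0 | 0 |
| 1 | 4.0 |  | 0.12108 | 40.0 | 1 | 0 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 0 | -1 | 1 | 1 |
| 2 | 40.0 |  | 3.6324 | 1200.0 | 2 | 0 | 1 | 1 | 1 | 1 | 0 | 2 | 2 | 1 | 2 | 0 | -1 | 1 | 1 |
| 3 | 20.0 |  | 2.7243 | 600.0 | 3 | 0 | 1 | 1 | 1 | 1 | 0 | 2 | 2 | 1 | 2 | 0 | -1 | 2 | 1 |
| 4 | 10.0 |  | 0.9081 | 300.0 | 4 | 0 | 1 | 1 | 1 | 1 | 0 | 2 | 2 | 1 | 2 | 0 | -1 | 1 | 1 |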


In [19]:
# cat = data[categorical]
# cat = cat.drop('UNIT', axis=1)
# cat.head()

In [20]:
# enc = OneHotEncoder()
# enc.fit(cat)

In [10]:
%%time
# Cast categorical columns to the correct type
ce = Categorizer(columns=categorical)
data = ce.fit_transform(data)
for col in categorical:
    data[col] = data[col].cat.codes

CPU times: user 8.96 s, sys: 2.37 s, total: 11.3 s
Wall time: 1min 37s


In [11]:
data.head(10)

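| | QUANTITY | STRENGTH | CALC_BASE_WT_IN_GM | DOSAGE_UNIT | TRANSACTION_ID | REPORTER_BUS_ACT | REPORTER_NAME | REPORTER_CITY | REPORTER_STATE | REPORTER_ZIP | BUYER_BUS_ACT | BUYER_NAME | BUYER_CITY | BUYER_STATE | BUYER_ZIP | DRUG_NAME | UNIT | Product_Name | Revised_Company_Name |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1.0 |  | 0.6054 | 100.0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | -1 | 0 | 0 |
| 1 | 4.0 |  | 0.12108 | 40.0 | 1 | 0 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 0 | -1 | 1 | 1 |
| 2 | 40.0 |  | 3.6324 | 1200.0 | 2 | 0 | 1 | 1 | 1 | 1 | 0 | 2 | 2 | 1 | 2 | 0 | -1 | 1 | 1 |
| 3 | 20.0 |  | 2.7243 | 600.0 | 3 | 0 | 1 | 1 | 1 | 1 | 0 | 2 | 2 | 1 | 2 | 0 | -1 | 2 | 1 |
| 4 | 10.0 |  | 0.9081 | 300.0 | 4 | 0 | 1 | 1 | 1 | 1 | 0 | 2 | 2 | 1 | 2 | 0 | -1 | 1 | 1 |
| 5 | 20.0 |  | 1.8162 | 600.0 | 5 | 0 | 1 | 1 | 1 | 1 | 0 | 2 | 2 | 1 | 2 | 0 | -1 | 1 | 1 |
| 6 | 5.0 |  | 0.227025 | 75.0 | 6 | 0 | 1 | 1 | 1 | 1 | 0 | 3 | 3 | 1 | 3 | 0 | -1 | 1 | 1 |
| 7 | 10.0 |  | 0.45405 | 150.0 | 7 | 0 | 1 | 1 | 1 | 1 | 0 | 3 | 3 | 1 | 3 | 0 | -1 | 1 | 1 |
| 8 | 10.0 |  | 0.45405 | 150.0 | 8 | 0 | 1 | 1 | 1 | 1 | 0 | 3 | 3 | 1 | 3 | 0 | -1 | 1 | 1 |
| 9 | 10.0 |  | 0.45405 | 150.0 | 9 | 0 | 1 | 1 | 1 | 1 | 0 | 3 | 3 | 1 | 3 | 0 | -1 | 1 | 1 |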


Categorical columns are now integer-encoded.

NOTE: strictly speaking, these columns should be **one-hot encoded**, so that the model doesn't treat the integer codes as ORDINAL data.
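To see the difference, compare integer codes with one-hot columns on a made-up pandas example. With codes, a tree can split on something like `BUYER_BUS_ACT <= 1`, which groups categories only because of their arbitrary (here alphabetical) code order; one-hot encoding gives each category its own 0/1 column instead. Because that means one column per category, high-cardinality columns such as BUYER_NAME would produce very wide data. In this notebook the Dask route would be `dask_ml.preprocessing.OneHotEncoder` applied after `Categorizer` (as in the commented-out cells above); this sketch uses `pd.get_dummies` instead.

```python
import pandas as pd

# Made-up values from a categorical column
s = pd.Series(
    ["CHAIN PHARMACY", "RETAIL PHARMACY", "PRACTITIONER", "CHAIN PHARMACY"],
    name="BUYER_BUS_ACT",
)

# Integer codes: a single column whose values imply an (arbitrary) order
codes = s.astype("category").cat.codes
print(codes.tolist())  # [0, 2, 1, 0]

# One-hot: one 0/1 column per category, no implied order
onehot = pd.get_dummies(s, prefix="BUYER_BUS_ACT", dtype=int)
print(onehot.columns.tolist())
print(onehot.values.tolist())  # [[1, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]]
```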

In [12]:
# Move the target column (DRUG_NAME, 4th from last) to the end so iloc[:, :-1] selects only features
cols = data.columns.to_list()
cols_new = cols[:-4] + cols[-3:] + [cols[-4]]
data_new = data[cols_new]
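The slice arithmetic above moves `DRUG_NAME` from fourth-from-last to the end of the column list. A name-based version gives the same order without depending on where `DRUG_NAME` happens to sit; the sketch below uses a hypothetical helper, `move_to_end`, and a shortened, illustrative column list.

```python
def move_to_end(columns, target):
    """Return a new column list with `target` moved to the end (hypothetical helper)."""
    if target not in columns:
        raise KeyError(target)
    return [c for c in columns if c != target] + [target]

# Shortened, illustrative column list with DRUG_NAME 4th from last
cols = ["DOSAGE_UNIT", "BUYER_ZIP", "DRUG_NAME", "UNIT",
        "Product_Name", "Revised_Company_Name"]
print(move_to_end(cols, "DRUG_NAME"))

# On the Dask DataFrame this would be:
# data_new = data[move_to_end(data.columns.to_list(), "DRUG_NAME")]
```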

In [13]:
data_new.columns

Index(['QUANTITY', 'STRENGTH', 'CALC_BASE_WT_IN_GM', 'DOSAGE_UNIT',
       'TRANSACTION_ID', 'REPORTER_BUS_ACT', 'REPORTER_NAME', 'REPORTER_CITY',
       'REPORTER_STATE', 'REPORTER_ZIP', 'BUYER_BUS_ACT', 'BUYER_NAME',
       'BUYER_CITY', 'BUYER_STATE', 'BUYER_ZIP', 'UNIT', 'Product_Name',
       'Revised_Company_Name', 'DRUG_NAME'],
      dtype='object')

In [14]:
data_new2 = data_new.drop("STRENGTH", axis=1)

In [19]:
data_dropped = data_new2.dropna(how="any")

In [20]:
data_dropped.head()

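| | QUANTITY | CALC_BASE_WT_IN_GM | DOSAGE_UNIT | TRANSACTION_ID | REPORTER_BUS_ACT | REPORTER_NAME | REPORTER_CITY | REPORTER_STATE | REPORTER_ZIP | BUYER_BUS_ACT | BUYER_NAME | BUYER_CITY | BUYER_STATE | BUYER_ZIP | UNIT | Product_Name | Revised_Company_Name | DRUG_NAME |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1.0 | 0.6054 | 100.0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | -1 | 0 | 0 | 0 |
| 1 | 4.0 | 0.12108 | 40.0 | 1 | 0 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | -1 | 1 | 1 | 0 |
| 2 | 40.0 | 3.6324 | 1200.0 | 2 | 0 | 1 | 1 | 1 | 1 | 0 | 2 | 2 | 1 | 2 | -1 | 1 | 1 | 0 |
| 3 | 20.0 | 2.7243 | 600.0 | 3 | 0 | 1 | 1 | 1 | 1 | 0 | 2 | 2 | 1 | 2 | -1 | 2 | 1 | 0 |
| 4 | 10.0 | 0.9081 | 300.0 | 4 | 0 | 1 | 1 | 1 | 1 | 0 | 2 | 2 | 1 | 2 | -1 | 1 | 1 | 0 |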


In [21]:
# Create the train-test split
X, y = data_dropped.iloc[:, :-1], data_dropped["DRUG_NAME"]
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.1, shuffle=True, random_state=2
)

### Training Model

In [23]:
client.restart()



In [None]:
%%time
# Create the XGBoost DMatrix
dtrain = xgb.dask.DaskDMatrix(client, X_train, y_train)


In [None]:
# Set model parameters
params = {
    "max_depth": 8,
    "max_leaves": 2 ** 8,
    "gamma": 0.1,
    "eta": 0.1,
    "min_child_weight": 30,
    "objective": "binary:logistic",
    "grow_policy": "lossguide"
}

In [None]:
%%time 
# train the model 
output = xgb.dask.train(
    client, params, dtrain, num_boost_round=5,
    evals=[(dtrain, 'train')]
)

In [None]:
# 'booster' is the trained model
booster = output['booster']  

# 'history' is a dictionary containing evaluation metrics
history = output['history']  

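### Shutting down the cluster

Closing the client only disconnects it from the cluster; it does not shut the cluster down. The cluster was also created with `shutdown_on_close=False`, so it keeps running until you shut it down in Coiled or until the scheduler's one-hour `idle_timeout` expires.

In [29]:
# Disconnect the client from the cluster
client.close()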

## 4. Recap

In this notebook, we:
- trained a distributed XGBoost model on a portion of the ARCOS opioid sales dataset, using all of our machine's cores in parallel through a Dask LocalCluster,
- expanded the distributed XGBoost model to train on the entire dataset using a Coiled cluster of XX machines and XX total memory in the cloud.

We’d love to see you apply distributed XGBoost to a dataset that’s meaningful to you. If you’d like to try, swap your own dataset into this notebook and see how well it does! 

Let us know how you get on in our [Coiled Community Slack channel](https://join.slack.com/t/coiled-users/shared_invite/zt-hx1fnr7k-In~Q8ui3XkQfvQon0yN5WQ) or by tweeting at us.

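## Preprocessing Internal Only

Some of the cells below reference columns that are not in `columns + categorical` (`Measure`, `Ingredient_Name` and `dos_str`); to run them, add those columns to the list passed to `dd.read_parquet`.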

In [18]:
data.DRUG_NAME.value_counts().compute()


In [7]:
# inspect the first 5 entries
data.head()

| | QUANTITY | STRENGTH | CALC_BASE_WT_IN_GM | DOSAGE_UNIT | TRANSACTION_ID | REPORTER_BUS_ACT | REPORTER_NAME | REPORTER_CITY | REPORTER_STATE | REPORTER_ZIP | BUYER_BUS_ACT | BUYER_NAME | BUYER_CITY | BUYER_STATE | BUYER_ZIP | DRUG_NAME | UNIT | Product_Name | Revised_Company_Name |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1.0 |  | 0.6054 | 100.0 | 64 | DISTRIBUTOR | ACE SURGICAL SUPPLY CO INC | BROCKTON | MA | 2301 | PRACTITIONER | TABRIZI, HAMID R DMD | MALDEN | MA | 2148 | HYDROCODONE |  | HYDROCODONE BIT/ACETA 10MG/500MG USP | Mallinckrodt |
| 1 | 4.0 |  | 0.12108 | 40.0 | 52 | DISTRIBUTOR | APOTHECA INC | PHOENIX | AZ | 85006 | RETAIL PHARMACY | APOTHECARY SHOP DEER VALLEY | PHOENIX | AZ | 85085 | HYDROCODONE |  | HYDROCODONE BITARTRATE & ACETA 5MG/ | Apotheca Inc. |
| 2 | 40.0 |  | 3.6324 | 1200.0 | 119 | DISTRIBUTOR | APOTHECA INC | PHOENIX | AZ | 85006 | PRACTITIONER | HOBBS, DOUGLAS DON, MD | GILBERT | AZ | 85233 | HYDROCODONE |  | HYDROCODONE BITARTRATE & ACETA 5MG/ | Apotheca Inc. |
| 3 | 20.0 |  | 2.7243 | 600.0 | 34 | DISTRIBUTOR | APOTHECA INC | PHOENIX | AZ | 85006 | PRACTITIONER | HOBBS, DOUGLAS DON, MD | GILBERT | AZ | 85233 | HYDROCODONE |  | HYDROCODONEBITARTRATE & ACETA 7.5MG | Apotheca Inc. |
| 4 | 10.0 |  | 0.9081 | 300.0 | 19 | DISTRIBUTOR | APOTHECA INC | PHOENIX | AZ | 85006 | PRACTITIONER | HOBBS, DOUGLAS DON, MD | GILBERT | AZ | 85233 | HYDROCODONE |  | HYDROCODONE BITARTRATE & ACETA 5MG/ | Apotheca Inc. |


In [8]:
data.columns

Index(['QUANTITY', 'STRENGTH', 'CALC_BASE_WT_IN_GM', 'DOSAGE_UNIT',
       'TRANSACTION_ID', 'REPORTER_BUS_ACT', 'REPORTER_NAME', 'REPORTER_CITY',
       'REPORTER_STATE', 'REPORTER_ZIP', 'BUYER_BUS_ACT', 'BUYER_NAME',
       'BUYER_CITY', 'BUYER_STATE', 'BUYER_ZIP', 'DRUG_NAME', 'UNIT',
       'Product_Name', 'Revised_Company_Name'],
      dtype='object')

### Next steps
1. Identify the categorical columns
2. Preprocess them

In [7]:
data.REPORTER_BUS_ACT.value_counts().compute()

DISTRIBUTOR              178082880
MANUFACTURER                514415
REVERSE DISTRIB                730
CHEMICAL MANUFACTURER            1
Name: REPORTER_BUS_ACT, dtype: int64

In [31]:
data.REPORTER_NAME.value_counts().compute()

MCKESSON CORPORATION            36970763
CARDINAL HEALTH                 25362775
WALGREEN CO                     23047615
AMERISOURCEBERGEN DRUG CORP     15278813
WAL-MART PHARM WAREHOUSE #1      5682151
                                  ...   
BROOKSHIRE GROCERY COMPANY             1
OCEAN MEDICAL                          1
SAVOY MEDICAL SUPPLY COMPANY           1
MCKESSON SPECIALTY LOGISTICS           1
BOCA PHARMACAL LLC                     1
Name: REPORTER_NAME, Length: 494, dtype: int64

In [12]:
data.REPORTER_CITY.value_counts().compute()

ROGERS                9888365
PERRYSBURG            6102999
WOODLAND              5030005
JUPITER               4730129
KNOXVILLE             4439162
                       ...   
CUMBERLAND                  1
HILTON HEAD ISLAND          1
CORAL SPRINGS               1
CONWAY                      1
LINCOLN                     1
Name: REPORTER_CITY, Length: 418, dtype: int64

In [14]:
data.REPORTER_STATE.value_counts().head(10)

CA    19715990
OH    12204507
FL    12174208
AR    11156357
TX     9234960
IL     7994830
GA     7677366
TN     7501719
NY     7333245
IN     7261726
Name: REPORTER_STATE, dtype: int64

In [15]:
data.BUYER_BUS_ACT.value_counts().compute()

CHAIN PHARMACY         116383479
RETAIL PHARMACY         61053356
PRACTITIONER             1032103
PRACTITIONER-DW/30         59670
PRACTITIONER-DW/100        53702
PRACTITIONER-DW/275        15716
Name: BUYER_BUS_ACT, dtype: int64

In [17]:
data.QUANTITY.value_counts().compute()

1.0        97341738
2.0        35231179
3.0        15903312
4.0         8747752
5.0         5515753
             ...   
1660.0            1
679.0             1
1633.0            1
684.0             1
88880.0           1
Name: QUANTITY, Length: 1159, dtype: int64

In [22]:
data[['QUANTITY', 'UNIT', 'STRENGTH', 'DOSAGE_UNIT']].head()

| | QUANTITY | UNIT | STRENGTH | DOSAGE_UNIT |
|---|---|---|---|---|
| 0 | 1.0 |  |  | 100.0 |
| 1 | 4.0 |  |  | 40.0 |
| 2 | 40.0 |  |  | 1200.0 |
| 3 | 20.0 |  |  | 600.0 |
| 4 | 10.0 |  |  | 300.0 |


In [19]:
data.UNIT.value_counts().compute()

D    55
2    53
1    48
Name: UNIT, dtype: int64

It may be worth keeping only entries where UNIT is None (only 156 transactions have a UNIT value at all). In that case we know for sure we're measuring the …
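Keeping only the rows without a UNIT value is a one-line boolean filter. It's sketched here in pandas on made-up rows; the same `isnull()` filter works on the Dask DataFrame, e.g. `data[data.UNIT.isnull()]`.

```python
import pandas as pd

# Made-up rows: UNIT is missing (None) for most transactions
df = pd.DataFrame({
    "DOSAGE_UNIT": [100.0, 40.0, 1200.0, 300.0],
    "UNIT": [None, "D", None, "2"],
})

# Keep only rows without a UNIT value
no_unit = df[df["UNIT"].isnull()]
print(no_unit["DOSAGE_UNIT"].tolist())  # [100.0, 1200.0]
```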

In [18]:
data.STRENGTH.value_counts().compute()

0.0       42992312
1000.0     1116037
800.0         2481
600.0         2140
100.0         1284
            ...   
896.0            1
910.0            1
915.0            1
943.0            1
8000.0           1
Name: STRENGTH, Length: 350, dtype: int64

STRENGTH is one of three values: "(1) the purity of a bulk raw material, (2) the fractional portion of a standard NDC package size, or (3) the percentage by which a package exceeds a standard NDC package size."

In [24]:
data.Measure.value_counts().compute()

TAB    178598026
Name: Measure, dtype: int64

In [25]:
data.DOSAGE_UNIT.value_counts().compute()

100.0        62033577
500.0        38699261
200.0        27531578
300.0        13225804
1000.0       10362267
               ...   
10332.0             1
10192.0             1
10164.0             1
9950.0              1
3115000.0           1
Name: DOSAGE_UNIT, Length: 3114, dtype: int64

DOSAGE_UNIT is a DEA-calculated field indicating the number of pills for transactions where Measure is Tab (which is all transactions in this dataset).

In [26]:
data.CALC_BASE_WT_IN_GM.value_counts().compute()

4.540500e-01    16392694
6.054000e-01    16020445
3.027000e+00    13221388
1.513500e+00    12902288
2.270250e+00     9890539
                  ...   
4.209043e+01           1
4.249410e+01           1
4.258989e+01           1
4.290772e+01           1
3.405375e+07           1
Name: CALC_BASE_WT_IN_GM, Length: 5603, dtype: int64

CALC_BASE_WT_IN_GM is a DEA-added field indicating the active weight of the drug in the transaction, in grams.

In [27]:
data.Product_Name.value_counts().compute()

HYDROCODONE BIT. 10MG/ACETAMINOPHEN     12564663
HYDROCODONE BIT/ACETAMINOPHEN 5MG/50    10722861
OXYCODONE HCL/ACETAMINOPHEN 5MG/325M     7547523
HYDROCODONE BIT 5MG/ACETAMINOPHEN 50     6587821
HYDROCODONE BIT/ACETA 7.5MG/500MG US     6000824
                                          ...   
VICODIN HYDROCODO.BIT.5MG/AC                   1
OXYCOD 5MG HCL/ACET 325MG TAB USP              1
HYDROCODO.BIT 7.5&AC USP TAB                   1
HYDROCODONE BIT/ACETA 7.5MG/750MG US           1
VICODIN ES TAB 7.5MG HYDROCODO.BIT/7           1
Name: Product_Name, Length: 1109, dtype: int64

In [30]:
data.Revised_Company_Name.value_counts().compute()

Mallinckrodt                                            65380874
Allergan, Inc.                                          52458781
Endo Pharmaceuticals, Inc.                              26366516
Purdue Pharma LP                                        13934502
Amneal Pharmaceuticals, Inc.                             6686972
                                                          ...   
Pharmaceutical Manufacturing Research Services, Inc.         110
Coupler Enterprises                                           73
Prepackage Specialists                                        52
Rx Pak Division of McKesson Corporation                        5
United Research Laboratories, Inc.                             2
Name: Revised_Company_Name, Length: 85, dtype: int64

In [28]:
data.Ingredient_Name.value_counts().compute()

HYDROCODONE BITARTRATE HEMIPENTAHYDRATE                109664762
OXYCODONE HYDROCHLORIDE                                 68933034
POLY ETHYLENE GLYCOL OXYCODOL (MPEG 6-ALPHA-OXYCODO          230
Name: Ingredient_Name, dtype: int64
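These counts suggest a useful baseline for the classifier. Assuming `DRUG_NAME` splits the same way as `Ingredient_Name` (hydrocodone vs. the two oxycodone ingredients; the `DRUG_NAME` counts themselves were not computed above), a model that always predicts hydrocodone would already be right about 61% of the time, so the XGBoost model's accuracy should be compared against that figure rather than against 50%.

```python
# Ingredient_Name counts from the output above
counts = {
    "HYDROCODONE BITARTRATE HEMIPENTAHYDRATE": 109_664_762,
    "OXYCODONE HYDROCHLORIDE": 68_933_034,
    "POLY ETHYLENE GLYCOL OXYCODOL (MPEG 6-ALPHA-OXYCODO": 230,
}

total = sum(counts.values())
# Majority-class baseline: always predict the most common class (hydrocodone)
baseline = max(counts.values()) / total
print(total, round(baseline, 3))  # 178598026 0.614
```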

In [29]:
data.dos_str.value_counts().compute()

10.0000     57865691
5.0000      47355052
7.5000      45693133
30.0000      5948201
20.0000      5263221
15.0000      5105884
40.0000      4904995
80.0000      3869498
60.0000       955627
2.5000        818028
5.3500        328988
4.8355        174280
45.0000          560
120.0000         172
100.0000         113
200.0000          95
400.0000          22
50.0000            5
Name: dos_str, dtype: int64

dos_str is the strength of the dose, in milligrams.
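Putting these fields together: for the five sample rows shown in `data.head()`, CALC_BASE_WT_IN_GM equals DOSAGE_UNIT (pills) times the per-pill strength in milligrams, converted to grams, times 0.6054. The per-pill strengths below are read off the Product_Name values of those rows (10MG, 5MG, 5MG, 7.5MG, 5MG), and the 0.6054 factor is an assumption inferred from these rows, presumably a salt-to-base conversion for hydrocodone bitartrate; it does not come from the data dictionary.

```python
import math

# (DOSAGE_UNIT, strength in mg per pill, CALC_BASE_WT_IN_GM) for the sample rows
rows = [
    (100.0, 10.0, 0.6054),
    (40.0, 5.0, 0.12108),
    (1200.0, 5.0, 3.6324),
    (600.0, 7.5, 2.7243),
    (300.0, 5.0, 0.9081),
]

BASE_FACTOR = 0.6054  # assumption: inferred from these rows, not from the DEA docs

for pills, mg_per_pill, reported_gm in rows:
    # pills * mg per pill -> total mg; / 1000 -> grams; * factor -> base weight
    estimated_gm = pills * mg_per_pill / 1000 * BASE_FACTOR
    print(f"{estimated_gm:.5f} vs reported {reported_gm}")
    assert math.isclose(estimated_gm, reported_gm, rel_tol=1e-9)
```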