In [2]:
## Imports

In [2]:
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split, GridSearchCV
import lightgbm as lgb
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from sklearn.multioutput import MultiOutputRegressor

In [4]:
## Load data

In [3]:
df = pd.read_csv("../data/datasets/2021-12-03 04:24:21.588262.csv", header=None)

In [4]:
df.fillna(0, inplace=True)

In [5]:
df.shape

(400, 1044)

In [6]:
## Extract data

In [7]:
ANTIBIOTIC_LIST = ['Amikacin', 'Ampicillin', 'Ampicillin/Sulbactam', 'Aztreonam', 'Cefazolin', 'Cefepime', 'Cefoxitin', 'Ceftazidime', 'Ceftriaxone', 'Cefuroxime sodium',
                   'Ciprofloxacin', 'Gentamicin', 'Imipenem', 'Levofloxacin', 'Meropenem', 'Nitrofurantoin', 'Piperacillin/Tazobactam', 'Tetracycline', 'Tobramycin', 'Trimethoprim/Sulfamethoxazole']


In [8]:
params = {'colsample_bytree': 0.7, 'max_depth': 15, 'min_split_gain': 0.4, 'n_estimators': 400, 'num_leaves': 50, 'reg_alpha': 1.3, 'reg_lambda': 1.1, 'subsample': 0.9, 'subsample_freq': 20}

In [9]:
param_grid = {
    'n_estimators': [400, 700, 1000],
    'colsample_bytree': [0.7, 0.8],
    'max_depth': [15,20,25],
    'num_leaves': [50, 100, 200],
    'reg_alpha': [1.1, 1.2, 1.3],
    'reg_lambda': [1.1, 1.2, 1.3],
    'min_split_gain': [0.3, 0.4],
    'subsample': [0.7, 0.8, 0.9],
    'subsample_freq': [20]
}


In [15]:
results

[{'name': 'Trimethoprim/Sulfamethoxazole',
  'mse': 1.868773725365026,
  'mae': 1.057192978100121,
  'r2': -0.002060757596588081},
 {'name': 'Tobramycin',
  'mse': 15.215157400642825,
  'mae': 3.007496818438285,
  'r2': 0.5842064246987198},
 {'name': 'Tetracycline',
  'mse': 26.03370034548165,
  'mae': 4.1748590022730045,
  'r2': 0.27839511203709644},
 {'name': 'Piperacillin/Tazobactam',
  'mse': 1379.42821654932,
  'mae': 25.761185860026043,
  'r2': 0.3987277021032969},
 {'name': 'Nitrofurantoin',
  'mse': 4567.236728463197,
  'mae': 59.16233814006517,
  'r2': -0.21643907070235668},
 {'name': 'Meropenem',
  'mse': 7.928262027605476,
  'mae': 1.6625317433362445,
  'r2': 0.8107763179918606},
 {'name': 'Levofloxacin',
  'mse': 1.684862474602032,
  'mae': 0.7655098104592416,
  'r2': 0.686034219754267},
 {'name': 'Imipenem',
  'mse': 8.827109845722472,
  'mae': 1.9295695015915981,
  'r2': 0.7702731704621748},
 {'name': 'Gentamicin',
  'mse': 27.045976420154844,
  'mae': 4.129745419912086,


In [27]:
results

[{'name': 'Trimethoprim/Sulfamethoxazole',
  'mse': 1.649661309474234,
  'mae': 0.9704840049518665,
  'r2': 0.25550617463064706},
 {'name': 'Tobramycin',
  'mse': 14.995343593655518,
  'mae': 3.033213251237812,
  'r2': 0.5966101974447044},
 {'name': 'Tetracycline',
  'mse': 27.608943733859793,
  'mae': 4.469331724307332,
  'r2': 0.23661320408071296},
 {'name': 'Piperacillin/Tazobactam',
  'mse': 1601.9756286357558,
  'mae': 28.033284746935845,
  'r2': 0.354134031424433},
 {'name': 'Nitrofurantoin',
  'mse': 4241.319527372664,
  'mae': 54.857381076878966,
  'r2': -0.13323679394572596},
 {'name': 'Meropenem',
  'mse': 164.56558366089928,
  'mae': 4.275839356255748,
  'r2': 0.42860150573305933},
 {'name': 'Levofloxacin',
  'mse': 3.1130366149806292,
  'mae': 0.9550121192998478,
  'r2': 0.5294961243574345},
 {'name': 'Imipenem',
  'mse': 57.72420401667318,
  'mae': 3.5438028001574526,
  'r2': 0.6450761177817961},
 {'name': 'Gentamicin',
  'mse': 148.1299285821507,
  'mae': 6.18383684181730

In [16]:
results = []

for index, value in enumerate(ANTIBIOTIC_LIST[::-1]):
    index_name = int((len(df.columns) - 1) - index)

    y = df[index_name]
    x = df.drop(index_name, axis=1)
    
    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size = 0.3, random_state = 42, shuffle=False)

    model = lgb.LGBMRegressor(**params)
    model.fit(x_train, y_train)
    
    y_pred=model.predict(x_test)
    
    mse, mae, r2 = get_metrics(y_test, y_pred)
    results.append({
        "name": value, 
         "mse": mse, 
         "mae": mae, 
         "r2": r2
    })

In [24]:
results = []

for index, value in enumerate(ANTIBIOTIC_LIST[::-1]):
    index_name = int((len(df.columns) - 1) - index)

    y = df[index_name]
    x = df.drop(index_name, axis=1)
    
    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size = 0.3, random_state = 42, shuffle=False)
    gs, pred = algorithm_pipeline(x_train, x_test, y_train, y_test, model, 
                                 param_grid, cv=5)

    params = gs.best_params_
    
    model = lgb.LGBMRegressor(**params)
    model.fit(x_train, y_train)
    
    y_pred=model.predict(x_test)
    
    mse, mae, r2 = get_metrics(y_test, y_pred)
    results.append({
        "name": value, 
         "mse": mse, 
         "mae": mae, 
         "r2": r2
    })

Fitting 5 folds for each of 2916 candidates, totalling 14580 fits


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 12 concurrent workers.
[Parallel(n_jobs=-1)]: Done  17 tasks      | elapsed:   13.6s
[Parallel(n_jobs=-1)]: Done 138 tasks      | elapsed:  1.1min
[Parallel(n_jobs=-1)]: Done 341 tasks      | elapsed:  2.6min
[Parallel(n_jobs=-1)]: Done 624 tasks      | elapsed:  5.1min
[Parallel(n_jobs=-1)]: Done 989 tasks      | elapsed:  9.1min
[Parallel(n_jobs=-1)]: Done 1434 tasks      | elapsed: 13.2min
[Parallel(n_jobs=-1)]: Done 1961 tasks      | elapsed: 17.7min
[Parallel(n_jobs=-1)]: Done 2568 tasks      | elapsed: 23.9min
[Parallel(n_jobs=-1)]: Done 3257 tasks      | elapsed: 30.0min
[Parallel(n_jobs=-1)]: Done 4026 tasks      | elapsed: 37.2min
[Parallel(n_jobs=-1)]: Done 4877 tasks      | elapsed: 46.1min
[Parallel(n_jobs=-1)]: Done 5808 tasks      | elapsed: 54.4min
[Parallel(n_jobs=-1)]: Done 6821 tasks      | elapsed: 63.4min
[Parallel(n_jobs=-1)]: Done 7914 tasks      | elapsed: 74.2min
[Parallel(n_jobs=-1)]: Done 9089 tasks      | 

Fitting 5 folds for each of 2916 candidates, totalling 14580 fits


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 12 concurrent workers.
[Parallel(n_jobs=-1)]: Done  17 tasks      | elapsed:   28.5s
[Parallel(n_jobs=-1)]: Done 138 tasks      | elapsed:  2.8min
[Parallel(n_jobs=-1)]: Done 341 tasks      | elapsed:  6.8min
[Parallel(n_jobs=-1)]: Done 624 tasks      | elapsed: 13.2min
[Parallel(n_jobs=-1)]: Done 989 tasks      | elapsed: 22.2min
[Parallel(n_jobs=-1)]: Done 1434 tasks      | elapsed: 31.9min
[Parallel(n_jobs=-1)]: Done 1961 tasks      | elapsed: 42.6min
[Parallel(n_jobs=-1)]: Done 2568 tasks      | elapsed: 56.2min
[Parallel(n_jobs=-1)]: Done 3257 tasks      | elapsed: 71.6min
[Parallel(n_jobs=-1)]: Done 4026 tasks      | elapsed: 88.8min
[Parallel(n_jobs=-1)]: Done 4877 tasks      | elapsed: 108.0min
[Parallel(n_jobs=-1)]: Done 5808 tasks      | elapsed: 128.8min
[Parallel(n_jobs=-1)]: Done 6821 tasks      | elapsed: 150.4min
[Parallel(n_jobs=-1)]: Done 7914 tasks      | elapsed: 175.6min
[Parallel(n_jobs=-1)]: Done 9089 tasks    

Fitting 5 folds for each of 2916 candidates, totalling 14580 fits


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 12 concurrent workers.
[Parallel(n_jobs=-1)]: Done  17 tasks      | elapsed:   28.0s
[Parallel(n_jobs=-1)]: Done 138 tasks      | elapsed:  2.8min
[Parallel(n_jobs=-1)]: Done 341 tasks      | elapsed:  6.7min
[Parallel(n_jobs=-1)]: Done 624 tasks      | elapsed: 13.1min
[Parallel(n_jobs=-1)]: Done 989 tasks      | elapsed: 22.6min
[Parallel(n_jobs=-1)]: Done 1434 tasks      | elapsed: 32.3min
[Parallel(n_jobs=-1)]: Done 1961 tasks      | elapsed: 43.1min
[Parallel(n_jobs=-1)]: Done 2568 tasks      | elapsed: 56.7min
[Parallel(n_jobs=-1)]: Done 3257 tasks      | elapsed: 72.0min
[Parallel(n_jobs=-1)]: Done 4026 tasks      | elapsed: 89.0min
[Parallel(n_jobs=-1)]: Done 4877 tasks      | elapsed: 108.3min
[Parallel(n_jobs=-1)]: Done 5808 tasks      | elapsed: 129.1min
[Parallel(n_jobs=-1)]: Done 6821 tasks      | elapsed: 150.9min
[Parallel(n_jobs=-1)]: Done 7914 tasks      | elapsed: 176.5min
[Parallel(n_jobs=-1)]: Done 9089 tasks    

Fitting 5 folds for each of 2916 candidates, totalling 14580 fits


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 12 concurrent workers.
[Parallel(n_jobs=-1)]: Done  17 tasks      | elapsed:   35.5s
[Parallel(n_jobs=-1)]: Done 138 tasks      | elapsed:  3.7min
[Parallel(n_jobs=-1)]: Done 341 tasks      | elapsed:  9.3min
[Parallel(n_jobs=-1)]: Done 624 tasks      | elapsed: 21.0min
[Parallel(n_jobs=-1)]: Done 989 tasks      | elapsed: 40.4min
[Parallel(n_jobs=-1)]: Done 1434 tasks      | elapsed: 59.9min
[Parallel(n_jobs=-1)]: Done 1961 tasks      | elapsed: 80.3min
[Parallel(n_jobs=-1)]: Done 2568 tasks      | elapsed: 110.8min
[Parallel(n_jobs=-1)]: Done 3257 tasks      | elapsed: 137.2min
[Parallel(n_jobs=-1)]: Done 4026 tasks      | elapsed: 170.8min
[Parallel(n_jobs=-1)]: Done 4877 tasks      | elapsed: 214.1min
[Parallel(n_jobs=-1)]: Done 5808 tasks      | elapsed: 250.9min
[Parallel(n_jobs=-1)]: Done 6821 tasks      | elapsed: 293.1min
[Parallel(n_jobs=-1)]: Done 7914 tasks      | elapsed: 343.2min
[Parallel(n_jobs=-1)]: Done 9089 tasks 

Fitting 5 folds for each of 2916 candidates, totalling 14580 fits


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 12 concurrent workers.
[Parallel(n_jobs=-1)]: Done  17 tasks      | elapsed:   34.4s
[Parallel(n_jobs=-1)]: Done 138 tasks      | elapsed:  3.8min
[Parallel(n_jobs=-1)]: Done 341 tasks      | elapsed:  9.3min
[Parallel(n_jobs=-1)]: Done 624 tasks      | elapsed: 21.3min
[Parallel(n_jobs=-1)]: Done 989 tasks      | elapsed: 41.4min
[Parallel(n_jobs=-1)]: Done 1434 tasks      | elapsed: 61.5min
[Parallel(n_jobs=-1)]: Done 1961 tasks      | elapsed: 82.4min
[Parallel(n_jobs=-1)]: Done 2568 tasks      | elapsed: 114.4min
[Parallel(n_jobs=-1)]: Done 3257 tasks      | elapsed: 141.8min
[Parallel(n_jobs=-1)]: Done 4026 tasks      | elapsed: 176.7min
[Parallel(n_jobs=-1)]: Done 4877 tasks      | elapsed: 222.4min
[Parallel(n_jobs=-1)]: Done 5808 tasks      | elapsed: 260.9min
[Parallel(n_jobs=-1)]: Done 6821 tasks      | elapsed: 305.0min
[Parallel(n_jobs=-1)]: Done 7914 tasks      | elapsed: 357.4min
[Parallel(n_jobs=-1)]: Done 9089 tasks 

Fitting 5 folds for each of 2916 candidates, totalling 14580 fits


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 12 concurrent workers.
[Parallel(n_jobs=-1)]: Done  17 tasks      | elapsed:   33.1s
[Parallel(n_jobs=-1)]: Done 138 tasks      | elapsed:  3.4min
[Parallel(n_jobs=-1)]: Done 341 tasks      | elapsed:  8.4min
[Parallel(n_jobs=-1)]: Done 624 tasks      | elapsed: 18.3min
[Parallel(n_jobs=-1)]: Done 989 tasks      | elapsed: 34.0min
[Parallel(n_jobs=-1)]: Done 1434 tasks      | elapsed: 49.9min
[Parallel(n_jobs=-1)]: Done 1961 tasks      | elapsed: 66.8min
[Parallel(n_jobs=-1)]: Done 2568 tasks      | elapsed: 91.0min
[Parallel(n_jobs=-1)]: Done 3257 tasks      | elapsed: 114.1min
[Parallel(n_jobs=-1)]: Done 4026 tasks      | elapsed: 142.0min
[Parallel(n_jobs=-1)]: Done 4877 tasks      | elapsed: 176.8min
[Parallel(n_jobs=-1)]: Done 5808 tasks      | elapsed: 208.5min
[Parallel(n_jobs=-1)]: Done 6821 tasks      | elapsed: 244.0min
[Parallel(n_jobs=-1)]: Done 7914 tasks      | elapsed: 285.8min
[Parallel(n_jobs=-1)]: Done 9089 tasks  

Fitting 5 folds for each of 2916 candidates, totalling 14580 fits


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 12 concurrent workers.
[Parallel(n_jobs=-1)]: Done  17 tasks      | elapsed:   17.8s
[Parallel(n_jobs=-1)]: Done 138 tasks      | elapsed:  1.8min
[Parallel(n_jobs=-1)]: Done 341 tasks      | elapsed:  4.4min
[Parallel(n_jobs=-1)]: Done 624 tasks      | elapsed:  8.5min
[Parallel(n_jobs=-1)]: Done 989 tasks      | elapsed: 14.5min
[Parallel(n_jobs=-1)]: Done 1434 tasks      | elapsed: 21.0min
[Parallel(n_jobs=-1)]: Done 1961 tasks      | elapsed: 27.9min
[Parallel(n_jobs=-1)]: Done 2568 tasks      | elapsed: 37.3min
[Parallel(n_jobs=-1)]: Done 3257 tasks      | elapsed: 47.2min
[Parallel(n_jobs=-1)]: Done 4026 tasks      | elapsed: 58.3min
[Parallel(n_jobs=-1)]: Done 4877 tasks      | elapsed: 71.1min
[Parallel(n_jobs=-1)]: Done 5808 tasks      | elapsed: 84.6min
[Parallel(n_jobs=-1)]: Done 6821 tasks      | elapsed: 98.8min
[Parallel(n_jobs=-1)]: Done 7914 tasks      | elapsed: 115.5min
[Parallel(n_jobs=-1)]: Done 9089 tasks      |

Fitting 5 folds for each of 2916 candidates, totalling 14580 fits


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 12 concurrent workers.
[Parallel(n_jobs=-1)]: Done  17 tasks      | elapsed:   37.9s
[Parallel(n_jobs=-1)]: Done 138 tasks      | elapsed:  3.8min
[Parallel(n_jobs=-1)]: Done 341 tasks      | elapsed:  9.3min
[Parallel(n_jobs=-1)]: Done 624 tasks      | elapsed: 19.9min
[Parallel(n_jobs=-1)]: Done 989 tasks      | elapsed: 36.6min
[Parallel(n_jobs=-1)]: Done 1434 tasks      | elapsed: 53.4min
[Parallel(n_jobs=-1)]: Done 1961 tasks      | elapsed: 70.7min
[Parallel(n_jobs=-1)]: Done 2568 tasks      | elapsed: 96.1min


KeyboardInterrupt: 

In [None]:
results

In [25]:
results

[{'name': 'Trimethoprim/Sulfamethoxazole',
  'mse': 1.6035968628585067,
  'mae': 0.9688595267177702,
  'r2': 0.2762951061995128},
 {'name': 'Tobramycin',
  'mse': 16.216575600400137,
  'mae': 3.241377727233006,
  'r2': 0.5637578299748887},
 {'name': 'Tetracycline',
  'mse': 26.804366338221712,
  'mae': 4.350389215033006,
  'r2': 0.258859754548052},
 {'name': 'Piperacillin/Tazobactam',
  'mse': 1655.284933169597,
  'mae': 28.730939099968463,
  'r2': 0.3326414038267458},
 {'name': 'Nitrofurantoin',
  'mse': 4421.230745678358,
  'mae': 55.62925679769783,
  'r2': -0.1813072142269141},
 {'name': 'Meropenem',
  'mse': 113.88934476907873,
  'mae': 4.704567288579886,
  'r2': 0.6045576561853002},
 {'name': 'Levofloxacin',
  'mse': 3.7115763161360467,
  'mae': 1.2231464221190094,
  'r2': 0.4390329258957233}]

In [17]:
results

[{'name': 'Trimethoprim/Sulfamethoxazole',
  'mse': 1.7584694678021393,
  'mae': 0.9879802338533317,
  'r2': 0.057085818791628884},
 {'name': 'Tobramycin',
  'mse': 16.377358839917637,
  'mae': 3.276315632765871,
  'r2': 0.5524462608744565},
 {'name': 'Tetracycline',
  'mse': 27.184940718828734,
  'mae': 4.298610717606818,
  'r2': 0.24648490835482673},
 {'name': 'Piperacillin/Tazobactam',
  'mse': 1694.9050853334707,
  'mae': 30.102860811654285,
  'r2': 0.26121601461468524},
 {'name': 'Nitrofurantoin',
  'mse': 4643.972060507573,
  'mae': 58.52501738409949,
  'r2': -0.23687677988006883},
 {'name': 'Meropenem',
  'mse': 9.0197628780036,
  'mae': 1.8484351266623074,
  'r2': 0.7847254875440017},
 {'name': 'Levofloxacin',
  'mse': 1.8648682458297026,
  'mae': 0.8330624626218273,
  'r2': 0.6524910355097608},
 {'name': 'Imipenem',
  'mse': 8.31666080514956,
  'mae': 1.9535202083396532,
  'r2': 0.7835576816760308},
 {'name': 'Gentamicin',
  'mse': 27.219312560268932,
  'mae': 4.18953975541019

In [19]:
# Partial
params = {'colsample_bytree': 0.7, 'max_depth': 15, 'min_split_gain': 0.4, 'n_estimators': 400, 'num_leaves': 50, 'reg_alpha': 1.3, 'reg_lambda': 1.1, 'subsample': 0.9, 'subsample_freq': 20}
results

[{'name': 'Trimethoprim/Sulfamethoxazole',
  'mse': 1.4888562477280876,
  'mae': 0.9222309309940206,
  'r2': 0.32807766240852265},
 {'name': 'Tobramycin',
  'mse': 15.637690174701797,
  'mae': 3.185663422368499,
  'r2': 0.5793304293032157},
 {'name': 'Tetracycline',
  'mse': 27.049059370147248,
  'mae': 4.401220306684193,
  'r2': 0.2520939966318474},
 {'name': 'Piperacillin/Tazobactam',
  'mse': 1563.2156703887251,
  'mae': 27.46008719005732,
  'r2': 0.36976082219932505},
 {'name': 'Nitrofurantoin',
  'mse': 4277.002862218243,
  'mae': 54.45024852990884,
  'r2': -0.14277101265212577},
 {'name': 'Meropenem',
  'mse': 119.82624968033356,
  'mae': 4.687416408132143,
  'r2': 0.5839437559308752},
 {'name': 'Levofloxacin',
  'mse': 3.9427071107942444,
  'mae': 1.157667570098017,
  'r2': 0.404099853106374},
 {'name': 'Imipenem',
  'mse': 72.77335843999312,
  'mae': 4.359093897333441,
  'r2': 0.5525446675346315},
 {'name': 'Gentamicin',
  'mse': 139.35617400725576,
  'mae': 6.025936350578807,


In [8]:
## Trainning model

In [7]:
params = {'colsample_bytree': 0.7, 'max_depth': 15, 'min_split_gain': 0.4, 'n_estimators': 400, 'num_leaves': 50, 'reg_alpha': 1.3, 'reg_lambda': 1.1, 'subsample': 0.9, 'subsample_freq': 20}

In [19]:
multi_regressor = MultiOutputRegressor(lgb.LGBMRegressor(**params))

In [35]:
multi_regressor.fit(x_train, y_train)

MultiOutputRegressor(estimator=LGBMRegressor(colsample_bytree=0.7, max_depth=15,
                                             min_split_gain=0.4,
                                             n_estimators=400, num_leaves=50,
                                             reg_alpha=1.3, reg_lambda=1.1,
                                             subsample=0.9, subsample_freq=20))

In [36]:
multi_regressor.score(x_test, y_test)

-0.1748435943601686

In [44]:
y_pred=multi_regressor.predict(x_test)

In [45]:
mse, mae, r2 = get_metrics(y_test, y_pred)

MSE:  465.1994537234406
MAE:  9.968363863916919
R2:  -0.1748435943601686


In [14]:
## Utils

In [13]:
def get_y_index(antibiotic_list: list, data):
    y_list = []
    index_names = []


    for index, value in enumerate(antibiotic_list):
        index_name = int((len(df.columns) - 1) - index)

        y = df[index_name]

        if not len(y_list):
            y_list = np.zeros((len(y),len(antibiotic_list)))

        y_list[:,index] = y

        index_names.append(index_name)
        
    return y_list, index_names


In [12]:
def get_metrics(expected, predicted):
    mse = mean_squared_error(expected, predicted, squared=True)
    mae = mean_absolute_error(expected, predicted)
    r2 = r2_score(expected, predicted)

    return mse, mae, r2

In [11]:
def algorithm_pipeline(X_train_data, X_test_data, y_train_data, y_test_data, 
                       model, param_grid, cv=10, scoring_fit='neg_mean_squared_error',
                       do_probabilities = False):
    gs = GridSearchCV(
        estimator=model,
        param_grid=param_grid, 
        cv=cv, 
        n_jobs=-1, 
        scoring=scoring_fit,
        verbose=2
    )
    
    fitted_model = gs.fit(X_train_data, y_train_data)
    
    if do_probabilities:
      pred = fitted_model.predict_proba(X_test_data)
    else:
      pred = fitted_model.predict(X_test_data)
    
    return fitted_model, pred