In [74]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import jax.numpy as jnp
from jax import random, grad, jit
from sklearn.preprocessing import scale
from sklearn import metrics

  return f(*args, **kwds)


In [58]:
train = pd.read_csv('/Users/gabestechschulte/Downloads/customer-churn-prediction-2020/train.csv')
train.head()

Unnamed: 0,state,account_length,area_code,international_plan,voice_mail_plan,number_vmail_messages,total_day_minutes,total_day_calls,total_day_charge,total_eve_minutes,total_eve_calls,total_eve_charge,total_night_minutes,total_night_calls,total_night_charge,total_intl_minutes,total_intl_calls,total_intl_charge,number_customer_service_calls,churn
0,OH,107,area_code_415,no,yes,26,161.6,123,27.47,195.5,103,16.62,254.4,103,11.45,13.7,3,3.7,1,no
1,NJ,137,area_code_415,no,no,0,243.4,114,41.38,121.2,110,10.3,162.6,104,7.32,12.2,5,3.29,0,no
2,OH,84,area_code_408,yes,no,0,299.4,71,50.9,61.9,88,5.26,196.9,89,8.86,6.6,7,1.78,2,no
3,OK,75,area_code_415,yes,no,0,166.7,113,28.34,148.3,122,12.61,186.9,121,8.41,10.1,3,2.73,3,no
4,MA,121,area_code_510,no,yes,24,218.2,88,37.09,348.5,108,29.62,212.6,118,9.57,7.5,7,2.03,3,no


In [59]:
train.churn.value_counts()

no     3652
yes     598
Name: churn, dtype: int64

In [60]:
# Preprocessing and converting to jax array
numFeats = train.select_dtypes(include=['float64', 'int64'])
trainScale = scale(numFeats)
X = jnp.array(trainScale)

train.churn = train.churn.apply(lambda x: 1 if x == 'yes' else 0)
y = jnp.array(train.churn)

X.shape, y.shape

((4250, 15), (4250,))

In [61]:
train.churn.value_counts()

0    3652
1     598
Name: churn, dtype: int64

In [62]:
# Training and testing split
train_size = int(np.round(0.7*float(X.shape[0])))
indices = np.random.permutation(X.shape[0])
X_train, X_test = X[0: train_size], X[train_size: ] 
y_train, y_test = y[0: train_size], y[train_size: ]


In [63]:
# Random weight vector
key = random.PRNGKey(1337)
# Weight for each feature (dimension) of training data
w = random.uniform(key,
                   shape=(X_train.shape[1], ),
                   minval=-0.1,
                   maxval=0.1)

w.shape

(15,)

In [64]:
# Dot product between weights and training data
jnp.dot(X_train, w)

DeviceArray([ 0.19509281, -0.17395341, -0.24718827, ...,  0.10108735,
             -0.13742268, -0.1751465 ], dtype=float32)

### The Model

$y = g(Xw)$ where $g$ is a non-linear function that squishes our values between 0 and 1
 - $g$ = logistic function

$g(x) = \frac{1}{1+e^{-x}}$

In [65]:
def logistic_function(value):
    return 1 / (1 + jnp.exp(-value))

### Maximum Likelihood

Probability of observing $Y = y$ given my $X$ and choosen $w$ parameter

In [66]:
# First model - no learning
y_prob = logistic_function(jnp.dot(X_train, w))

# Likelihood function: P(D | H) where H = weight parameter
p_d_h = jnp.where(y_train == 1, y_prob, 1 - y_prob)
p_d_h ## how likely the outcome we observed would be if our model was correct

DeviceArray([0.4513809 , 0.543379  , 0.56148434, ..., 0.47474968,
             0.53430176, 0.54367507], dtype=float32)

### Negative Log Likelihood

Combine all of the "p_d_h" into a single likelihood of the data we observed given the model we have

$P(y | X, w) = \prod_{i=1}^nP(y_i | X, w)$

In [67]:
jnp.prod(p_d_h) ## 0 as the product of many small probabilities results in an underflow
                ## technically it is the product of ##.##^len(X_train)

# Take the log likelihood
jnp.sum(jnp.log(p_d_h))

DeviceArray(-2064.3867, dtype=float32)

In [68]:
def negative_log_likelihood(X, y, w):
    y_prob = logistic_function(jnp.dot(X, w))
    p_d_h = jnp.where(y == 1, y_prob, 1 - y_prob)
    ll = jnp.sum(jnp.log(p_d_h))
    return -ll


In [69]:
# Negative log-likelihood is now a positive convex function
negative_log_likelihood(X_train, y_train, w) 

DeviceArray(2064.3867, dtype=float32)

### Gradient Descent - Time to Learn

What we want to do now is find values for $w$ that decrease negative log-likelihood and thereby ultimately increase the probability of our labels given the weights

Using JAX to perform automatic differentiation

In [70]:
# Determining derivative of negative_log_likelihood with respect to the argument in the 2 (zero indexed) position 
# which is our $w$ parameter
d_nll_wrt_w = grad(negative_log_likelihood, argnums=2)
d_nll_wrt_w_c = jit(d_nll_wrt_w)

lr = 0.0001

for _ in range(1, 1000):
    w -= lr * d_nll_wrt_w_c(X_train, y_train, w)

print(negative_log_likelihood(X_train, y_train, w))

1963.6995


### Measuring Performance

Per the code cell below, you can see how target class is indeed imbalanced and thus simply using accuracy isn't the most productive predictive performance measure

ROC AUC is a better measure of predictive performance for our classifier

In [73]:
# Imbalanced class - just guessing 0 for everything would result in a good accuracy
1 - sum(y_train) / y_train.shape[0]

0.8568067226890757

In [78]:
# Training predictions
train_pred = logistic_function(np.dot(X_train, w))

# Computing false positive rate, true positive rate, and thresholds based off of training predictions and ground truth
fpr, tpr, thresholds = metrics.roc_curve(y_train, train_pred)
print('Training : ', metrics.auc(fpr, tpr))

# Testing predictions
test_pred = logistic_function(np.dot(X_test, w))
fpr, tpr, thresholds = metrics.roc_curve(y_test, test_pred)
print('Testing  : ', metrics.auc(fpr, tpr))

Training :  0.7671239941282322
Testing  :  0.7822745577600202


## Statistical Inference

**We are much more concerned with the properties of our model itself rather than just its predictions**
 - In statistics (rather than ML) we just don't view our features and weights as inputs to help us predict an output, but an actual model of how the world works


In [81]:
stats_df = pd.DataFrame()
stats_df['feature'] = numFeats.columns
stats_df['coef'] = w
stats_df

Unnamed: 0,feature,coef
0,account_length,-0.024477
1,number_vmail_messages,-0.177489
2,total_day_minutes,0.224164
3,total_day_calls,-0.000509
4,total_day_charge,0.113658
5,total_eve_minutes,0.032819
6,total_eve_calls,0.013183
7,total_eve_charge,0.08462
8,total_night_minutes,0.043243
9,total_night_calls,-0.010402
