In [16]:
!wget https://raw.githubusercontent.com/propublica/compas-analysis/master/compas-scores-two-years.csv
import pandas as pd
date_cols = [
    'compas_screening_date', 'c_offense_date',
    'c_arrest_date', 'r_offense_date', 
    'vr_offense_date', 'screening_date',
    'v_screening_date', 'c_jail_in',
    'c_jail_out', 'dob', 'in_custody', 
    'out_custody'
]
data = pd.read_csv(
    'compas-scores-two-years.csv',
    parse_dates=date_cols
)


--2020-08-06 10:20:39--  https://raw.githubusercontent.com/propublica/compas-analysis/master/compas-scores-two-years.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.0.133, 151.101.64.133, 151.101.128.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.0.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2546489 (2.4M) [text/plain]
Saving to: ‘compas-scores-two-years.csv’


2020-08-06 10:20:40 (14.4 MB/s) - ‘compas-scores-two-years.csv’ saved [2546489/2546489]



In [17]:
import datetime
indexes = data.compas_screening_date <= pd.Timestamp(datetime.date(2014, 4, 1))
assert indexes.sum() == 6216
data = data[indexes]


In [18]:
!pip install category-encoders



In [13]:
def confusion_metrics(actual, scores, threshold):
    y_predicted = scores.apply(
        lambda x: x >= threshold
    ).values
    y_true = actual.values
    TP = (
        (y_true==y_predicted) & 
        (y_predicted==1)
    ).astype(int)
    FP = (
        (y_true!=y_predicted) &
        (y_predicted==1)
    ).astype(int)
    TN = (
        (y_true==y_predicted) &
        (y_predicted==0)
    ).astype(int)
    FN = (
        (y_true!=y_predicted) &
        (y_predicted==0)
    ).astype(int)
    return TP, FP, TN, FN


In [14]:
def calculate_impacts(data, sensitive_column='race', recid_col='is_recid', score_col='decile_score.1', threshold=5.0):
    if sensitive_column == 'race':
      norm_group = 'Caucasian'
    elif sensitive_column == 'sex':
      norm_group = 'Male'
    else:
      raise ValueError('sensitive column not implemented')
    TP, FP, TN, FN = confusion_metrics(
        actual=data[recid_col],
        scores=data[score_col],
        threshold=threshold
    )
    impact = pd.DataFrame(
        data=np.column_stack([
              FP, TN, FN, TN,
              data[sensitive_column].values, 
              data[recid_col].values,
              data[score_col].values / 10.0
             ]),
        columns=['FP', 'TP', 'FN', 'TN', 'sensitive', 'reoffend', 'score']
    ).groupby(by='sensitive').agg({
        'reoffend': 'sum', 'score': 'sum',
        'sensitive': 'count', 
        'FP': 'sum', 'TP': 'sum', 'FN': 'sum', 'TN': 'sum'
    }).rename(
        columns={'sensitive': 'N'}
    )
    impact['FPR'] = impact['FP'] / (impact['FP'] + impact['TN'])
    impact['FNR'] = impact['FN'] / (impact['FN'] + impact['TP'])
    impact['reoffend'] = impact['reoffend'] / impact['N']
    impact['score'] = impact['score'] / impact['N']
    impact['DFP'] = impact['FPR'] / impact.loc[norm_group, 'FPR']
    impact['DFN'] = impact['FNR'] / impact.loc[norm_group, 'FNR']
    return impact.drop(columns=['FP', 'TP', 'FN', 'TN'])


In [19]:
from sklearn.feature_extraction.text import CountVectorizer
from category_encoders.one_hot import OneHotEncoder
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

charge_desc = data['c_charge_desc'].apply(lambda x: x if isinstance(x, str) else '')
count_vectorizer = CountVectorizer(
    max_df=0.85, stop_words='english',
    max_features=100, decode_error='ignore'
)
charge_desc_features = count_vectorizer.fit_transform(charge_desc)

one_hot_encoder = OneHotEncoder()
charge_degree_features = one_hot_encoder.fit_transform(
    data['c_charge_degree']
)

data['race_black'] = data['race'].apply(lambda x: x == 'African-American').astype(int)
stratification = data['race_black'] + (data['is_recid']).astype(int) * 2


In [20]:
y = data['is_recid']
X = pd.DataFrame(
    data=np.column_stack(
        [data[['juv_fel_count', 'juv_misd_count',
 'juv_other_count', 'priors_count', 'days_b_screening_arrest']], 
          charge_degree_features, 
          charge_desc_features.todense()
        ]
    ),
    columns=['juv_fel_count', 'juv_misd_count', 'juv_other_count', 'priors_count', 'days_b_screening_arrest'] \
    + one_hot_encoder.get_feature_names() \
    + count_vectorizer.get_feature_names(),
    index=data.index
)
X['jailed_days'] = (data['c_jail_out'] - data['c_jail_in']).apply(lambda x: abs(x.days))
X['waiting_jail_days'] = (data['c_jail_in'] - data['c_offense_date']).apply(lambda x: abs(x.days))
X['waiting_arrest_days'] = (data['c_arrest_date'] - data['c_offense_date']).apply(lambda x: abs(x.days))
X.fillna(0, inplace=True)

columns = list(X.columns)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.33,
    random_state=42,
    stratify=stratification
)  # we stratify by black and the target


In [25]:
import jax.numpy as jnp
from jax import grad, jit, vmap, ops, lax
import numpy as onp
import numpy.random as npr
import random
from tqdm import trange
from sklearn.base import ClassifierMixin
from sklearn.preprocessing import StandardScaler


class JAXLearner(ClassifierMixin):
  def __init__(self, layer_sizes=[10, 5, 1], epochs=20, batch_size=500, lr=1e-2):
    self.params = self.construct_network(layer_sizes)
    self.perex_grads = jit(grad(self.error))
    self.epochs = epochs
    self.batch_size = batch_size
    self.lr = lr

  @staticmethod
  def construct_network(layer_sizes=[10, 5, 1]):
    '''Please make sure your final layer corresponds to targets in dimensions.
    '''
    def init_layer(n_in, n_out):
      W = npr.randn(n_in, n_out)
      b = npr.randn(n_out,)
      return W, b
      
    return list(map(init_layer, layer_sizes[:-1], layer_sizes[1:]))

  @staticmethod
  def sigmoid(X):  # or tanh
    return 1/(1+jnp.exp(-X))

  def _predict(self, inputs):
    for W, b in self.params:
      outputs = jnp.dot(inputs, W) + b
      inputs = self.sigmoid(outputs)
    return outputs

  def predict(self, inputs):
    inputs = self.standard_scaler.transform(inputs)
    return onp.asarray(self._predict(inputs))

  @staticmethod
  def mse(preds, targets, other=None):
    return jnp.sqrt(jnp.sum((preds - targets)**2))

  @staticmethod
  def penalized_mse(preds, targets, sensitive):
    err = jnp.sum((preds - targets)**2)
    err_s = jnp.sum((preds * sensitive - targets * sensitive)**2)
    penalty = jnp.clip(err_s / err, 1.0, 2.0)
    return err * penalty

  def error(self, params, inputs, targets, sensitive):
      preds = self._predict(inputs)
      return self.penalized_mse(preds, targets, sensitive)

  def fit(self, X, y, sensitive):
    self.standard_scaler = StandardScaler()
    X = self.standard_scaler.fit_transform(X)
    N = X.shape[0]
    indexes = list(range(N))
    steps_per_epoch = N // self.batch_size

    for epoch in trange(self.epochs, desc='training'):
        random.shuffle(indexes)
        index_offset = 0
        for step in trange(steps_per_epoch, desc='iteration'):
            grads = self.perex_grads(
                self.params, 
                X[indexes[index_offset:index_offset+self.batch_size], :], 
                y[indexes[index_offset:index_offset+self.batch_size]],
                sensitive[indexes[index_offset:index_offset+self.batch_size]]
            )
            # print(grads)
            self.params = [(W - self.lr * dW, b - self.lr * db)
                      for (W, b), (dW, db) in zip(self.params, grads)]
            index_offset += self.batch_size


In [26]:
sensitive_train = X_train.join(
    data, rsuffix='_right'
)['race_black']
jax_learner = JAXLearner([X.values.shape[1], 100, 1])
jax_learner.fit(
    X_train.values,
    y_train.values,
    sensitive_train.values
)


training:   0%|          | 0/20 [00:00<?, ?it/s]
iteration: 100%|██████████| 8/8 [00:00<00:00, 460.32it/s]

iteration: 100%|██████████| 8/8 [00:00<00:00, 576.52it/s]

iteration: 100%|██████████| 8/8 [00:00<00:00, 627.77it/s]

iteration: 100%|██████████| 8/8 [00:00<00:00, 650.54it/s]

iteration: 100%|██████████| 8/8 [00:00<00:00, 737.82it/s]
training:  25%|██▌       | 5/20 [00:00<00:00, 48.03it/s]
iteration: 100%|██████████| 8/8 [00:00<00:00, 638.95it/s]

iteration: 100%|██████████| 8/8 [00:00<00:00, 782.45it/s]

iteration: 100%|██████████| 8/8 [00:00<00:00, 601.05it/s]

iteration: 100%|██████████| 8/8 [00:00<00:00, 807.22it/s]

iteration: 100%|██████████| 8/8 [00:00<00:00, 628.20it/s]

iteration: 100%|██████████| 8/8 [00:00<00:00, 640.28it/s]
training:  55%|█████▌    | 11/20 [00:00<00:00, 49.09it/s]
iteration: 100%|██████████| 8/8 [00:00<00:00, 691.06it/s]

iteration: 100%|██████████| 8/8 [00:00<00:00, 643.61it/s]

iteration: 100%|██████████| 8/8 [00:00<00:00, 808.37it/s]

iteration: 1

In [28]:
X_predicted = pd.DataFrame(
    data=jax_learner.predict(
        X_test.values
    ) * 10,
    columns=['score'], 
    index=X_test.index
).join(
    data[['sex', 'race', 'is_recid']], 
    rsuffix='_right'
)

In [31]:
X_predicted

Unnamed: 0,score,sex,race,is_recid
6553,45.214291,Male,Caucasian,0
1441,-42.396706,Male,African-American,1
2306,-60.853111,Male,Caucasian,0
504,-12.313410,Male,Caucasian,0
5212,27.922726,Male,African-American,0
...,...,...,...,...
6118,46.818214,Male,Caucasian,1
607,9.088397,Female,Caucasian,0
2596,-12.438754,Female,Caucasian,1
3204,-87.487778,Male,African-American,1


In [32]:
calculate_impacts(X_predicted, score_col='score')

Unnamed: 0_level_0,reoffend,score,N,FPR,FNR,DFP,DFN
sensitive,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
African-American,0.471042,-0.948788,1036,0.385036,0.487842,0.868148,1.401687
Asian,0.222222,1.802761,9,0.571429,0.25,1.28841,0.71831
Caucasian,0.335188,-0.308411,719,0.443515,0.348039,1.0,1.0
Hispanic,0.293478,0.121934,184,0.5,0.329897,1.127358,0.947873
Native American,0.5,0.963587,2,0.0,0.0,0.0,0.0
Other,0.294118,-0.626596,102,0.430556,0.267857,0.970781,0.769618
