In [3]:
from __future__ import print_function
import numpy as np
import pandas as pd

import sys
sys.path.append("../")
from lifelines.estimation import SBGSurvival

import matplotlib.pyplot as plt
%matplotlib inline

In [6]:
# Ground-truth coefficients used by the simulation. For each of the two
# parameters (alpha and beta) there is a bias term, one coefficient per
# category level, and one coefficient each for the count and numerical
# features. compute_alpha / compute_beta exponentiate these terms.
params = {
    'alpha': {
        'bias': 0.2,
        'categ': {
            'cat_a': 0.09736,
            'cat_b': -0.368,
            'cat_c': -1e-2,
        },
        'count': 0.065,
        'numer': 0.292,
    },
    'beta': {
        'bias': 0.87,
        'categ': {
            'cat_a': -0.12,
            'cat_b': -0.4425,
            'cat_c': 0.69,
        },
        'count': -0.148,
        'numer': 0.021,
    },
}


def get_age(alpha, beta, max_age=10):
    """
    Simulate the lifetime of a single sample.

    A per-sample churn probability is drawn once from a Beta(alpha, beta)
    distribution. Starting at age 1, the sample then faces one churn trial
    per period until it either churns or reaches ``max_age``.

    :param alpha: first shape parameter passed to ``np.random.beta``
    :param beta: second shape parameter passed to ``np.random.beta``
    :param max_age: age at which the simulation stops if no churn happened
    :return: tuple ``(age, alive)`` where ``alive`` is 0 if the sample churned
        at ``age`` and 1 if it was still alive when the loop stopped
    """
    # One churn probability per sample, fixed for its whole life.
    churn_prob = np.random.beta(alpha, beta)

    periods_lived = 1
    while periods_lived < max_age:
        churned = np.random.random() <= churn_prob
        if churned:
            # Churned during this period: report the age reached so far.
            return periods_lived, 0
        periods_lived += 1

    # Never churned within the horizon.
    return periods_lived, 1

def compute_alpha(row):
    """
    Compute the true alpha parameter for one sample.

    Alpha is a product of exponentiated terms: the bias, the coefficient of
    the row's category, the count and numerical contributions, and a small
    multiplicative noise term.

    :param row: mapping with 'category', 'counts' and 'numerical' entries
    :return: the alpha value for this row
    """
    coefs = params['alpha']

    bias_factor = np.exp(coefs['bias'])
    category_factor = np.exp(coefs['categ'][row['category']])
    feature_factor = np.exp(coefs['count'] * row['counts'] +
                            coefs['numer'] * row['numerical'])
    # Multiplicative noise: exp of a small Gaussian perturbation.
    noise_factor = np.exp(2e-2 * np.random.randn())

    return bias_factor * category_factor * feature_factor * noise_factor

def compute_beta(row):
    """
    Compute the true beta parameter for one sample.

    Beta is a product of exponentiated terms: the bias, the coefficient of
    the row's category, the count and numerical contributions, and a small
    multiplicative noise term.

    :param row: mapping with 'category', 'counts' and 'numerical' entries
    :return: the beta value for this row
    """
    coefs = params['beta']

    bias_factor = np.exp(coefs['bias'])
    category_factor = np.exp(coefs['categ'][row['category']])
    feature_factor = np.exp(coefs['count'] * row['counts'] +
                            coefs['numer'] * row['numerical'])
    # Multiplicative noise: exp of a small Gaussian perturbation.
    noise_factor = np.exp(1e-2 * np.random.randn())

    return bias_factor * category_factor * feature_factor * noise_factor

def compute_age(row):
    """
    Simulate the observed age and alive flag for one sample.

    The observation horizon is drawn at random from 8, 9, 10 or 11 periods
    (probabilities 0.3, 0.3, 0.2, 0.2), then the life is simulated with
    ``get_age`` using the row's true alpha and beta.

    :param row: mapping with 'alpha_true' and 'beta_true' entries
    :return: tuple ``(age, alive)`` as returned by ``get_age``
    """
    horizons = [8, 9, 10, 11]
    weights = [0.3, 0.3, 0.2, 0.2]
    drawn = np.random.choice(horizons, 1, p=weights)
    return get_age(row['alpha_true'], row['beta_true'], max_age=drawn[0])


def simulate_data(size=10000, max_age=10):
    """
    Simulate a dataset of samples with covariates, true alpha/beta parameters
    and simulated lifetimes, split into a train half and a test half.

    Each sample gets:
      * ``id``: running index from 0 to ``size - 1``
      * ``category``: one of 'cat_a', 'cat_b', 'cat_c' (p = 0.47, 0.36, 0.17)
      * ``counts``: Poisson draw with rate 0.25
      * ``numerical``: Gaussian draw with mean 1 and standard deviation 0.5
      * ``alpha_true`` / ``beta_true``: from ``compute_alpha`` / ``compute_beta``
      * ``age`` / ``alive``: from ``compute_age``

    :param size: total number of samples to simulate
    :param max_age: currently unused; ``compute_age`` draws its own observation
        horizon per sample. Kept so existing calls keep working.
    :return: dict with keys 'train' (first half of the samples), 'test'
        (second half, index reset to start at 0) and 'params' (the
        module-level coefficient dict)
    """
    data = pd.DataFrame()
    data['id'] = np.arange(size)

    # add categories
    data['category'] = np.random.choice(['cat_a', 'cat_b', 'cat_c'],
                                        size,
                                        p=[0.47, 0.36, 0.17])

    # transform category type
    data['category'] = data['category'].astype('category')

    # Add counts feature
    data['counts'] = np.random.poisson(lam=0.25, size=size)

    # add numerical gaussian feature
    data['numerical'] = 0.5 * np.random.randn(size) + 1

    # Add true alpha and beta params. The column names must be
    # 'alpha_true' / 'beta_true' because compute_age looks them up by
    # exactly those keys.
    data['alpha_true'] = data.apply(compute_alpha, axis=1)
    data['beta_true'] = data.apply(compute_beta, axis=1)

    # Simulate age: each element of `sim` is an (age, alive) tuple
    sim = data.apply(compute_age, axis=1)

    # Update age values
    data['age'] = [s[0] for s in sim]

    # for simplicity we assume all come from same cohort, so it is easy to set
    # alive value
    data['alive'] = [s[1] for s in sim]

    # split in half; copy the train slice so later edits to it do not act on
    # a slice of `data`
    half = size // 2
    tr = data.iloc[:half].copy()
    te = data.iloc[half:].reset_index(drop=True)

    return {'train': tr, 'test': te, 'params': params}

In [7]:
# Run the simulation; the returned dict (train/test frames plus the params)
# is echoed as this cell's output.
simulate_data(size=10000, max_age=10)

{'params': {'alpha': {'bias': 0.2,
   'categ': {'cat_a': 0.09736, 'cat_b': -0.368, 'cat_c': -0.01},
   'count': 0.065,
   'numer': 0.292},
  'beta': {'bias': 0.87,
   'categ': {'cat_a': -0.12, 'cat_b': -0.4425, 'cat_c': 0.69},
   'count': -0.148,
   'numer': 0.021}},
 'test':         id category  counts  numerical  alpha_true  beta_true  age  alive
 0     5000    cat_a       1   1.476608    2.185783   1.903584    1      0
 1     5001    cat_a       1   1.211998    2.011327   1.887865    1      0
 2     5002    cat_a       1   1.717785    2.365148   1.890759    1      0
 3     5003    cat_a       1   0.370316    1.553102   1.858793    2      0
 4     5004    cat_a       0   0.785087    1.694099   2.134720    3      0
 5     5005    cat_b       0   1.419801    1.273844   1.584132    1      0
 6     5006    cat_c       0   1.637797    1.914314   4.888847    5      0
 7     5007    cat_a       0   1.606394    2.171258   2.176195   10      1
 8     5008    cat_b       0   1.368448    1.2408