In [0]:
import fifeforspark
from fifeforspark.utils import create_example_data2
from fifeforspark.processors import PanelDataProcessor
from fifeforspark.lgb_modelers import LGBSurvivalModeler


In [0]:
# Number of test intervals: given to the processor config here and reused as
# `n_intervals` when the model is built.
test_intervals = 4

# Create the example data using the helper imported from fifeforspark.utils.
data = create_example_data2()

# Pass the data through the PanelDataProcessor. Setting test intervals to 4
# creates the ability to test the model; shuffle parts is set to 20 to reduce
# overhead, since the dataset isn't that large.
processor = PanelDataProcessor(
    data=data,
    config={'test_intervals': test_intervals},
    shuffle_parts=20,
)
processor.build_processed_data()

In [0]:
# Build the model on the processed data held by the processor.
# Parameters passed into the modeler are passed on to lightgbm as well.
# `n_intervals` reuses the same `test_intervals` value given to the processor.
modeler = LGBSurvivalModeler(data=processor.data)
modeler.build_model(n_intervals=test_intervals)


In [0]:
# Look up the smallest value of `_period` in the modeler's data.
min_val = (
    modeler.data
    .select(modeler.data['_period'])
    .agg({'_period': 'min'})
    .first()[0]
)

# Build a one-column boolean frame named "subset" that flags the rows whose
# `_period` equals that minimum. This step is unnecessary, as it's equivalent to
# the default subset; we do it to show the ability to pass in a subset.
evaluation_subset = (
    modeler.data
    .select((modeler.data['_period'] == min_val).alias("subset"))
    .select('subset')
)

# Predict the survival probabilities for the selected rows and show them.
predictions = modeler.predict(subset=evaluation_subset)
predictions.show()

In [0]:
# To see how the model did, compute evaluation metrics on the same subset that
# was used for the predictions, and display them as the cell's output.
evaluation = modeler.evaluate(subset=evaluation_subset)
evaluation

Unnamed: 0_level_0,AUROC,Predicted Share,Actual Share,True Positives,False Negatives,False Positives,True Negatives,Other Metrics:
Lead Length,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
1,0.979218,0.90415,0.90762,3362,116,84,270,
2,0.977895,0.81285,0.818372,2991,145,110,586,
3,0.974619,0.719887,0.727557,2616,172,148,896,
4,0.972966,0.634334,0.631002,2246,172,166,1248,


In [0]:
# Finally, forecast out from the last period of data.
# The bare `forecasts` expression on the last line makes the result the cell's output.
forecasts = modeler.forecast()
forecasts