# BaselinePipeline tutorial

Минимальный пример полного цикла: обучение, продление, сохранение и инференс бейзлайнов.

In [1]:
from pyspark.sql import SparkSession
from pyspark.ml.linalg import Vectors
import numpy as np
from spark_lifelines_cox.model import BaselinePipeline, BaselinePipelineConfig


In [2]:
spark = SparkSession.builder.master('local[*]').appName('baseline-notebook').getOrCreate()
rng = np.random.default_rng(123)
rows = []
for model_key in ['demo_a', 'demo_b']:
    for _ in range(25):
        duration = int(rng.integers(1, 9))
        event = int(rng.binomial(1, 0.65))
        x = Vectors.dense([float(rng.normal()), float(rng.normal())])
        rows.append((model_key, duration, event, x))
sdf = spark.createDataFrame(rows, ['model_key', 'duration', 'event', 'x'])
sdf.show(5, truncate=False)




Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


25/12/17 14:42:38 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


[Stage 0:>                                                                                                          (0 + 0) / 1]

[Stage 0:>                                                                                                          (0 + 1) / 1]



+---------+--------+-----+------------------------------------------+
|model_key|duration|event|x                                         |
+---------+--------+-----+------------------------------------------+
|demo_a   |1       |1    |[1.2879252612892487,0.1939744191326132]   |
|demo_a   |6       |1    |[0.5771037912572513,-0.6364636463709805]  |
|demo_a   |4       |0    |[-0.32238911615896015,0.09716731867045719]|
|demo_a   |3       |1    |[1.1921661041016585,-0.6710896751741096]  |
|demo_a   |4       |1    |[1.5320330796287964,-0.6599694137918207]  |
+---------+--------+-----+------------------------------------------+
only showing top 5 rows



In [3]:
config = BaselinePipelineConfig(max_baseline_length=36, tail_cycle=12, sample_fraction=1.0)
pipeline = BaselinePipeline(config)
pipeline.fit(sdf)
print('Trained models:', list(pipeline.models.keys()))
print('Baseline length:', len(next(iter(pipeline.models.values())).baseline_survival))


[Stage 1:>                                                                                                          (0 + 3) / 3]

                                                                                                                                



                                                                                                                                



                                                                                                                                

Trained models: ['demo_a', 'demo_b']
Baseline length: 37


In [4]:
save_path = '/tmp/baseline_csv_notebook'
pipeline.save(save_path)
restored = BaselinePipeline.load(save_path)
print('Restored max_baseline_length:', restored.config.max_baseline_length)


[Stage 6:>                                                                                                          (0 + 1) / 1]

                                                                                                                                

Restored max_baseline_length: 36


In [5]:
with_baseline = restored.infer_baseline(sdf, output_col='baseline')
with_baseline.select('model_key', 'duration', 'baseline').show(3, truncate=False)




+---------+--------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|model_key|duration|baseline                                                                                                                                                                          

In [6]:
adjusted = restored.adjust_for_lived(with_baseline, duration_col='duration', baseline_col='baseline', output_col='tail')
adjusted.select('model_key', 'duration', 'tail').show(3, truncate=False)


[Stage 10:>                                                                                                         (0 + 1) / 1]

+---------+--------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|model_key|duration|tail                                                                                                                                                                                                  

                                                                                                                                

In [7]:
spark.stop()