# Cox model with Spark — end-to-end tutorial

Эта тетрадка показывает минимальный рабочий цикл: обучение, продление baseline, инференс, сохранение в единый CSV и создание модели только из конфигурации.

In [1]:
import numpy as np
import pandas as pd
from pyspark.sql import SparkSession

from spark_lifelines_cox.model import SparkCoxPHByType

## 1. Стартуем Spark

In [2]:
spark = SparkSession.builder.master('local[*]').appName('cox_tutorial').getOrCreate()
spark



Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


25/12/17 12:13:19 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


## 2. Синтетические данные

In [3]:
rng = np.random.default_rng(7)
pdf = pd.DataFrame({
    'type': np.where(rng.random(120) > 0.5, 'A', 'B'),
    'duration': rng.exponential(scale=6, size=120),
    'event': rng.binomial(1, 0.75, size=120),
    'x': rng.normal(size=120),
})
sdf = spark.createDataFrame(pdf)
sdf.show(5)

  for column, series in pdf.iteritems():


[Stage 0:>                                                                                                          (0 + 1) / 1]

+----+-------------------+-----+--------------------+
|type|           duration|event|                   x|
+----+-------------------+-----+--------------------+
|   A|0.09784204010016541|    1|  0.2967393780454473|
|   A|  4.843488321644426|    1| -0.2987735866784475|
|   A|  7.196560508319127|    1|-0.04017394160294...|
|   B|  6.892499338292188|    1| 0.20659237721603418|
|   B|  26.75843017452848|    1|-0.08396970353317602|
+----+-------------------+-----+--------------------+
only showing top 5 rows



                                                                                                                                

## 3. Обучение и продление baseline

In [4]:
model = SparkCoxPHByType(
    type_col='type',
    duration_col='duration',
    event_col='event',
    feature_cols=['x'],
    seed=42,
)
model.fit(sdf)
model.extend_baselines(max_time=25)
model.artifacts

[SparkCoxPHByType] Starting fit pipeline
[SparkCoxPHByType] Casting columns to numeric types


[SparkCoxPHByType] Applying cap-sampling by type with max_rows_per_type= 500000






                                                                                                                                

[Stage 6:>                                                                                                          (0 + 1) / 1]

[SparkCoxPHByType][A] Start training on 63 rows
[SparkCoxPHByType][A] Trimming trailing hazards: removed 13 points
[SparkCoxPHByType][A] Training finished
[SparkCoxPHByType][B] Start training on 57 rows


[SparkCoxPHByType] Training completed. Fitted types: 2, skipped: 0
[SparkCoxPHByType] Extending baselines to max_time=25


[SparkCoxPHByType][B] Trimming trailing hazards: removed 5 points
[SparkCoxPHByType][B] Training finished
                                                                                                                                

{'A': TypeArtifacts(type_value='A', beta={'x': -0.17203093802722835}, mean_={'x': -0.20381435735001205}, baseline_survival=[1.0, 0.7487158764901778, 0.7007297058019774, 0.6139382135452179, 0.595255199804578, 0.49949025889663584, 0.46061099374514286, 0.40020118113866915, 0.33613803171349604, 0.3133087681088635, 0.26719704610433026, 0.2174604546538099, 0.2174604546538099, 0.1579787054082975, 0.1579787054082975, 0.1579787054082975, 0.1579787054082975, 0.1579787054082975, 0.1579787054082975, 0.1579787054082975, 0.1579787054082975, 0.1579787054082975, 0.1579787054082975, 0.12857223387976638, 0.12857223387976638, 0.09340399426697549], baseline_ratio=[1.0, 0.7487158764901778, 0.9359087042295009, 0.87614126882572, 0.9695685765628532, 0.8391195222832464, 0.9221621153586128, 0.8688485220135701, 0.8399226378020725, 0.9320836637013129, 0.8528233911777692, 0.813858004137141, 1.0, 0.7264709607077506, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.813858004137141, 1.0, 0.7264709607077506], sample_siz

## 4. Прогноз выживаемости

In [5]:
pred = model.predict_survival_at_t(sdf, t=10, output_col='s10')
pred.select('type', 's10').show(5)

[SparkCoxPHByType] Building survival UDF for prediction
[SparkCoxPHByType] Starting prediction DataFrame transformation


+----+-------------------+
|type|                s10|
+----+-------------------+
|   A| 0.2979355068468683|
|   A| 0.2614517304680958|
|   A|0.27716733019341344|
|   B|0.27203315658738986|
|   B|0.24349084121407844|
+----+-------------------+
only showing top 5 rows



## 5. Сохранение, загрузка и создание модели из конфига

In [6]:
csv_path = '/tmp/cox_tutorial.csv'
model.save(csv_path)

loaded = SparkCoxPHByType.load(csv_path)
re_pred = loaded.predict_survival_at_t(sdf, t=10, output_col='s10_loaded')
re_pred.select('type', 's10_loaded').show(3)

config_only = SparkCoxPHByType.from_config(csv_path)
config_only.feature_cols

[SparkCoxPHByType] Saving artifacts to /tmp/cox_tutorial.csv
[SparkCoxPHByType] Loading artifacts from /tmp/cox_tutorial.csv
[SparkCoxPHByType] Building survival UDF for prediction
[SparkCoxPHByType] Starting prediction DataFrame transformation


[Stage 8:>                                                                                                          (0 + 1) / 1]

+----+-------------------+
|type|         s10_loaded|
+----+-------------------+
|   A| 0.2979355068468683|
|   A| 0.2614517304680958|
|   A|0.27716733019341344|
+----+-------------------+
only showing top 3 rows

[SparkCoxPHByType] Loading configuration only from /tmp/cox_tutorial.csv


                                                                                                                                

['x']

## 6. Останавливаем Spark

In [7]:
spark.stop()