# Cox model with Spark — end-to-end tutorial

Эта тетрадка показывает полный цикл: обучение модели, немедленное сохранение артефактов в CSV, загрузку новой инстанции из этого файла и инференс по тренировочному датафрейму с фокусом на хвостах выживаемости для живых объектов.

Докстринги и комментарии в коде написаны на русском и поясняют, зачем нужны ключевые шаги: продление baseline, отбраковка типовых хвостов и broadcast артефактов для инференса. Если вы переходите к чтению исходников, ориентируйтесь на эти примечания.

In [1]:
import numpy as np
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

from spark_lifelines_cox.model import SparkCoxPHByType

## 1. Стартуем Spark

In [2]:
spark = SparkSession.builder.master('local[*]').appName('cox_tutorial').getOrCreate()
spark



Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


25/12/17 13:51:02 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


## 2. Синтетические данные

In [3]:
rng = np.random.default_rng(7)
pdf = pd.DataFrame({
    'type': np.where(rng.random(120) > 0.5, 'A', 'B'),
    'duration': rng.exponential(scale=6, size=120),
    'event': rng.binomial(1, 0.75, size=120),
    'x': rng.normal(size=120),
})
sdf = spark.createDataFrame(pdf)
print(f"Spark df shape: {sdf.count()} rows, schema: {sdf.dtypes}")
sdf.show(5)


  for column, series in pdf.iteritems():


[Stage 0:>                                                                                                          (0 + 3) / 3]



                                                                                                                                

Spark df shape: 120 rows, schema: [('type', 'string'), ('duration', 'double'), ('event', 'bigint'), ('x', 'double')]


+----+-------------------+-----+--------------------+
|type|           duration|event|                   x|
+----+-------------------+-----+--------------------+
|   A|0.09784204010016541|    1|  0.2967393780454473|
|   A|  4.843488321644426|    1| -0.2987735866784475|
|   A|  7.196560508319127|    1|-0.04017394160294...|
|   B|  6.892499338292188|    1| 0.20659237721603418|
|   B|  26.75843017452848|    1|-0.08396970353317602|
+----+-------------------+-----+--------------------+
only showing top 5 rows



## 3. Обучение, продление baseline и немедленное сохранение

In [4]:
model = SparkCoxPHByType(
    type_col='type',
    duration_col='duration',
    event_col='event',
    feature_cols=['x'],
    seed=42,
)

model.fit(sdf)
model.extend_baselines(max_time=25)

csv_path = '/tmp/cox_tutorial.csv'
model.save(csv_path)
print(f"Artifacts saved to: {csv_path}")

spark.read.csv(csv_path, header=True).show(4, truncate=False)
print("Trained types:", sorted(model.artifacts.keys()))
first_type = sorted(model.artifacts.keys())[0]
print(
    "Tail survival sample (last 3 points) for",
    first_type,
    model.artifacts[first_type].baseline_survival[-3:],
)


[SparkCoxPHByType] Starting fit pipeline
[SparkCoxPHByType] Casting columns to numeric types


[SparkCoxPHByType] Applying cap-sampling by type with max_rows_per_type= 500000




[Stage 5:>                                                                                                          (0 + 3) / 3]

                                                                                                                                

[Stage 9:>                                                                                                          (0 + 1) / 1]

[SparkCoxPHByType][A] Start training on 63 rows
[SparkCoxPHByType][A] Trimming trailing hazards: removed 13 points
[SparkCoxPHByType][A] Training finished


[SparkCoxPHByType][B] Start training on 57 rows
[SparkCoxPHByType][B] Trimming trailing hazards: removed 5 points
[SparkCoxPHByType][B] Training finished


                                                                                                                                

[SparkCoxPHByType] Training completed. Fitted types: 2, skipped: 0
[SparkCoxPHByType] Extending baselines to max_time=25
[SparkCoxPHByType] Saving artifacts to /tmp/cox_tutorial.csv
Artifacts saved to: /tmp/cox_tutorial.csv


+----------+-----------------------+-------------------------------------+
|type      |payload                |status                               |
+----------+-----------------------+-------------------------------------+
|__config__|"{""type_col"":""type""|""duration_col"":""duration""        |
|A         |"{""type_value"":""A"" |""beta"":{""x"":-0.17203093802722835}|
|B         |"{""type_value"":""B"" |""beta"":{""x"":-0.2812265147202562} |
+----------+-----------------------+-------------------------------------+

Trained types: ['A', 'B']
Tail survival sample (last 3 points) for A [0.12857223387976638, 0.12857223387976638, 0.09340399426697549]


## 4. Создание новой модели из CSV

In [5]:
loaded = SparkCoxPHByType.load(csv_path)
print("Loaded feature_cols:", loaded.feature_cols)
print("Loaded types:", sorted(loaded.artifacts.keys()))

config_only = SparkCoxPHByType.from_config(csv_path)
print("Config-only model seed:", config_only.seed)
print("Config-only feature cols:", config_only.feature_cols)


[SparkCoxPHByType] Loading artifacts from /tmp/cox_tutorial.csv
Loaded feature_cols: ['x']
Loaded types: ['A', 'B']
[SparkCoxPHByType] Loading configuration only from /tmp/cox_tutorial.csv
Config-only model seed: 42
Config-only feature cols: ['x']


## 5. Прогноз по тренировочному датасету с фильтрацией живых объектов

In [6]:
full_pred = loaded.predict_survival_at_t(sdf, t=10, output_col='s10_loaded')
full_pred.select('type', 'duration', 'event', 's10_loaded').show(5)

alive_sdf = sdf.filter(col('event') == 0)
print(f"Alive rows for tail prediction: {alive_sdf.count()}")

tail_pred = loaded.predict_survival_at_t(alive_sdf, t=25, output_col='s25_alive')
# хвостовая вероятность выживания после продления baseline
(tail_pred
    .select('type', 'duration', 'event', 's25_alive')
    .orderBy(col('s25_alive').desc())
    .show(5)
)


[SparkCoxPHByType] Building survival UDF for prediction
[SparkCoxPHByType] Starting prediction DataFrame transformation


+----+-------------------+-----+-------------------+
|type|           duration|event|         s10_loaded|
+----+-------------------+-----+-------------------+
|   A|0.09784204010016541|    1| 0.2979355068468683|
|   A|  4.843488321644426|    1| 0.2614517304680958|
|   A|  7.196560508319127|    1|0.27716733019341344|
|   B|  6.892499338292188|    1|0.27203315658738986|
|   B|  26.75843017452848|    1|0.24349084121407844|
+----+-------------------+-----+-------------------+
only showing top 5 rows



Alive rows for tail prediction: 32
[SparkCoxPHByType] Building survival UDF for prediction
[SparkCoxPHByType] Starting prediction DataFrame transformation


[Stage 16:>                                                                                                         (0 + 3) / 3]



+----+------------------+-----+-------------------+
|type|          duration|event|          s25_alive|
+----+------------------+-----+-------------------+
|   B|1.1441997185120298|    0|0.24921195901995766|
|   A| 8.159188001130543|    0|0.22773824666021145|
|   B|2.0535088772531407|    0|0.19995123263260595|
|   B| 21.66735633435271|    0|0.17671227532023756|
|   A|3.8066704037950903|    0|0.14561671854066843|
+----+------------------+-----+-------------------+
only showing top 5 rows





## 6. Останавливаем Spark

In [7]:
spark.stop()