# Survival SVM: tuning FastSurvivalSVM on the veterans' lung cancer data

In [None]:
pip install scikit-survival

In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import pandas
import seaborn as sns
from sklearn import set_config
from sklearn.model_selection import ShuffleSplit, GridSearchCV

from sksurv.datasets import load_veterans_lung_cancer
from sksurv.column import encode_categorical
from sksurv.metrics import concordance_index_censored
from sksurv.svm import FastSurvivalSVM

set_config(display="text")  # displays text representation of estimators
sns.set_style("whitegrid")  # use seaborn's "whitegrid" style for the plots below

In [None]:
# Load the veterans' lung cancer data set: `data_x` is the raw feature table
# and `y` the survival outcome (its "Status" and "Survival_in_days" fields are
# read by the scoring code further down).
data_x, y = load_veterans_lung_cancer()
# `x` is the categorical-encoded form of `data_x`; it is what every model
# below is fitted on.
x = encode_categorical(data_x)

In [None]:
# Base survival SVM. `alpha` is not set here; it is tuned by the grid search
# below. `random_state=0` fixes the seed so repeated runs are comparable.
estimator = FastSurvivalSVM(max_iter=1000, tol=1e-5, random_state=0)

In [None]:
def score_survival_model(model, X, y):
    """Score a fitted survival model by its concordance index.

    The ``(estimator, X, y)`` signature allows this function to be passed
    directly as the ``scoring`` callable of a scikit-learn search.

    Parameters
    ----------
    model : fitted estimator
        Object exposing ``predict(X)``.
    X : array-like
        Samples to predict on.
    y : structured array
        Must contain the fields ``"Status"`` and ``"Survival_in_days"``.

    Returns
    -------
    The first element of the tuple returned by
    ``concordance_index_censored``.
    """
    estimate = model.predict(X)
    event_indicator = y["Status"]
    event_time = y["Survival_in_days"]
    # Only the leading entry of the returned tuple is needed as the score.
    cindex, *_ = concordance_index_censored(event_indicator, event_time, estimate)
    return cindex

In [None]:
# Candidate values for `alpha`: 2^-12, 2^-10, ..., 2^12 (13 values).
param_grid = {"alpha": 2.0 ** np.arange(-12, 13, 2)}
# 100 random train/test splits with test_size=0.5; seeded for reproducibility.
cv = ShuffleSplit(n_splits=100, test_size=0.5, random_state=0)
# refit=False: the search only scores the candidates; the best configuration
# is applied to `estimator` and fitted explicitly in a later cell.
gcv = GridSearchCV(estimator, param_grid, scoring=score_survival_model, n_jobs=1, refit=False, cv=cv)

In [None]:
import warnings

# Ignore UserWarning (and its subclasses), presumably to keep the output of
# the 13 x 100 = 1300 fits below readable.
# NOTE(review): this filter is process-wide and stays active for every later
# cell as well - confirm that hiding warnings from the final fits is intended.
warnings.filterwarnings("ignore", category=UserWarning)
# Run the search; the result of fit() is bound back to `gcv`.
gcv = gcv.fit(x, y)

In [None]:
# Display the best score found by the search (rounded to 3 decimals) together
# with the parameter setting that achieved it.
round(gcv.best_score_, 3), gcv.best_params_

In [None]:
def plot_performance(gcv):
    """Draw a box plot of the per-split test scores for every ``alpha`` tried.

    Parameters
    ----------
    gcv : fitted search object (e.g. ``GridSearchCV``)
        Must expose ``n_splits_`` and ``cv_results_``; the latter needs a
        ``"params"`` list whose entries contain ``"alpha"`` and one
        ``"split{j}_test_score"`` array per split.

    Returns
    -------
    None
        The plot is drawn on a newly created 11 x 6 inch matplotlib figure.
    """
    # `n_splits_` is set on the fitted search regardless of how `cv` was
    # specified, whereas `gcv.cv.n_splits` only exists when `cv` is a splitter
    # object (it fails for an integer, None, or an iterable of splits).
    n_splits = gcv.n_splits_
    results = gcv.cv_results_

    # One label per candidate, in grid order; also fixes the box order on the x axis.
    order = [f'{params["alpha"]:.5f}' for params in results["params"]]

    # Long-format table: one row per (candidate, split) pair.
    records = [
        {"alpha": name, "test_score": results[f"split{j}_test_score"][i], "split": j}
        for i, name in enumerate(order)
        for j in range(n_splits)
    ]
    df = pandas.DataFrame(records, columns=["alpha", "test_score", "split"])

    _, ax = plt.subplots(figsize=(11, 6))
    sns.boxplot(x="alpha", y="test_score", data=df, order=order, ax=ax)
    # Rotate the labels of the axes we drew on, rather than relying on
    # pyplot's notion of the "current" axes.
    for label in ax.get_xticklabels():
        label.set_rotation("vertical")

In [None]:
# Show how the test scores are distributed across the splits for each alpha.
plot_performance(gcv)

In [None]:
# The search was created with refit=False, so apply the best parameters to the
# base estimator here and fit it on the complete data set.
estimator.set_params(**gcv.best_params_)
estimator.fit(x, y)

In [None]:
# Print the tuned model's predictions for the first two samples next to the
# recorded outcomes of those same samples.
pred = estimator.predict(x.iloc[:2])
print(np.round(pred, 3))
print(y[:2])

In [None]:
# Second model with rank_ratio=0.0, fitted on the complete data set.
# NOTE(review): rank_ratio=0.0 presumably switches the objective to pure
# regression - confirm against the FastSurvivalSVM documentation.
ref_estimator = FastSurvivalSVM(rank_ratio=0.0, max_iter=1000, tol=1e-5, random_state=0)
ref_estimator.fit(x, y)

# Concordance index on the same data the model was fitted on; the predictions
# are negated before being passed as the estimate.
cindex = concordance_index_censored(
    y["Status"],
    y["Survival_in_days"],
    -ref_estimator.predict(x),  # flip sign to obtain risk scores
)
print(round(cindex[0], 3))

In [None]:
# Print the reference model's (un-negated) predictions for the first two samples.
pred = ref_estimator.predict(x.iloc[:2])
print(np.round(pred, 3))