In [3]:
import numpy as np
from sklearn.svm import SVC
import plotly.graph_objects as go
from sklearn.datasets import load_digits
from sklearn.model_selection import validation_curve

In [4]:
# Load the full digits dataset, then keep only the samples labelled 1 or 2 so
# the problem becomes binary classification (1 vs 2). The unfiltered arrays
# stay available under their own names instead of being overwritten.
X_all, y_all = load_digits(return_X_y=True)
is_one_or_two = (y_all == 1) | (y_all == 2)
X, y = X_all[is_one_or_two], y_all[is_one_or_two]

In [6]:
from sklearn.ensemble import RandomForestClassifier

In [8]:
# Candidate SVC kernels for a possible kernel sweep. This uses its own name so it
# does not collide with `param_range`, which holds the numeric C grid swept below.
kernel_range = ['linear', 'poly', 'rbf', 'sigmoid']

In [48]:
# Grid of values for SVC's `C` hyperparameter: 1001 points evenly spaced in
# log10 from 10**-2 to 10**0. Each value is fitted once per CV fold in the
# validation_curve cell, so this grid size drives that cell's runtime.
param_range = np.power(10.0, np.linspace(-2, 0, 1001))

In [49]:
# Cross-validate a default SVC at every C in `param_range`, scoring accuracy.
# n_jobs=-1 runs the fits on all available cores. Each result has one row per
# C value and one column per CV fold (scikit-learn's documented return shape).
train_scores, test_scores = validation_curve(
    SVC(), X, y,
    param_name="C", param_range=param_range,
    scoring="accuracy", n_jobs=-1,
)

In [50]:
# Collapse the fold axis (axis 1): mean and population standard deviation
# of the score at each C value, for training and cross-validation scores.
train_scores_mean, train_scores_std = train_scores.mean(axis=1), train_scores.std(axis=1)
test_scores_mean, test_scores_std = test_scores.mean(axis=1), test_scores.std(axis=1)

In [51]:
def plot_validation_curve(
    x: np.ndarray,
    ys: list[np.ndarray],
    yerros: list[np.ndarray],
    names: list[str],
    colors: list[str],
    log_x: bool = True,
    title: str = "",
    xaxis_title: str = r"$\alpha$",
) -> go.Figure:
    r"""Plot one or more curves against a shared x grid, each with a shaded error band.

    Parameters
    ----------
    x : array-like
        Values shared by every curve (e.g. the swept hyperparameter).
    ys : list of array-like
        One array of curve values per series, each aligned with ``x``.
    yerros : list of array-like
        Half-width of the shaded band for the matching entry of ``ys``
        (the band spans ``y - yerr`` to ``y + yerr``).
    names : list of str
        Legend label for each series.
    colors : list of str
        Plotly colour for each series' line and band.
    log_x : bool, default True
        If true, use a logarithmic x-axis.
    title : str, default ""
        Figure title.
    xaxis_title : str, default r"$\alpha$"
        x-axis label. Pass the name of the hyperparameter actually being swept.

    Returns
    -------
    plotly.graph_objects.Figure

    Raises
    ------
    ValueError
        If ``ys``, ``yerros``, ``names`` and ``colors`` do not all have the
        same length (``zip`` would otherwise silently drop series).
    """
    if len({len(ys), len(yerros), len(names), len(colors)}) > 1:
        raise ValueError(
            "ys, yerros, names and colors must have the same length; got "
            f"{len(ys)}, {len(yerros)}, {len(names)} and {len(colors)}"
        )

    x = np.asarray(x)
    # The band is a closed polygon: walk x forwards along the upper edge, then
    # backwards along the lower edge, so fill="toself" shades the area between.
    x_band = np.concatenate([x, x[::-1]])

    fig = go.Figure()
    for y_mean, y_err, name, color in zip(ys, yerros, names, colors):
        y_mean = np.asarray(y_mean)
        y_err = np.asarray(y_err)
        y_upper = y_mean + y_err
        y_lower = y_mean - y_err

        fig.add_trace(
            go.Scatter(
                x=x,
                y=y_mean,
                name=name,
                line_color=color,
                legendgroup=name,
            )
        )

        # Shaded error band. It shares the line's legend group, so clicking the
        # legend entry hides or shows the line and its band together.
        fig.add_trace(
            go.Scatter(
                x=x_band,
                y=np.concatenate([y_upper, y_lower[::-1]]),
                fill="toself",
                fillcolor=color,
                line=dict(color=color),
                hoverinfo="skip",
                showlegend=False,
                legendgroup=name,
                opacity=0.2,
            )
        )

    if log_x:
        fig.update_xaxes(type="log")

    fig.update_layout(title=title, xaxis_title=xaxis_title, yaxis_title="Score")

    return fig


In [52]:
fig = plot_validation_curve(
    param_range,
    [train_scores_mean, test_scores_mean],
    [train_scores_std, test_scores_std],
    ["Training score", "Cross-validation score"],
    ["orange", "navy"],
    log_x=True,
    title="SVC validation curve on digits (1 vs 2)",
)
# The helper's default x-axis label is alpha, but this sweep varies SVC's C.
fig.update_layout(xaxis_title="C")
fig