In [9]:
from pathlib import Path

from google.protobuf.json_format import MessageToJson
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

# --- Configuration: everything a reader might want to tune -------------------
N_SAMPLES = 50   # number of synthetic points
N_FEATURES = 3   # dimensionality of each point
N_CENTERS = 5    # number of ground-truth blobs (also used as k for k-means)
SEED = 42        # shared seed for data generation and k-means initialisation

# --- Synthetic data ----------------------------------------------------------
# Tight Gaussian blobs (cluster_std=0.5) so the clustering result is easy to eyeball.
# return_centers=True also yields the true blob centres for later comparison.
X, y, centers = make_blobs(
    n_samples=N_SAMPLES,
    n_features=N_FEATURES,
    centers=N_CENTERS,
    cluster_std=0.5,
    shuffle=True,
    random_state=SEED,
    return_centers=True,
)
assert X.shape == (N_SAMPLES, N_FEATURES), "unexpected data shape"

num_features = X.shape[1]
num_clusters = centers.shape[0]

# --- Fit k-means ------------------------------------------------------------
# n_init is passed explicitly so the result does not depend on the library default.
km = KMeans(
    n_clusters=num_clusters,
    init="k-means++",
    n_init=10,
    max_iter=500,
    tol=1e-04,
    random_state=SEED,
)
# Equivalent to km.fit(X) followed by km.predict(X), in a single call.
y_km = km.fit_predict(X)

# --- Export to ONNX ----------------------------------------------------------
out_onnx = "kmeans.onnx"

# Single float tensor input named "float_input"; the first (batch) dimension is
# left as None so the exported model accepts any number of rows.
initial_type = [("float_input", FloatTensorType([None, num_features]))]
onnx_model = convert_sklearn(km, initial_types=initial_type)

# Record the source estimator in the model's metadata so consumers can identify it.
meta = onnx_model.metadata_props.add()
meta.key = "sklearn_model"
meta.value = "KMeans"

with open(out_onnx, "wb") as f:
    f.write(onnx_model.SerializeToString())

In [23]:
# Write a human-readable JSON rendering of the ONNX protobuf next to the binary
# model. The path reuses the ONNX file's base name ("kmeans.onnx" -> "kmeans.json"),
# so the two artifacts stay paired if out_onnx is changed.
out_json = Path(out_onnx).with_suffix(".json")
with open(out_json, "w", encoding="utf-8") as f:
    f.write(MessageToJson(onnx_model))

In [24]:
# Fitted centroids of the k-means model: one row per cluster, one column per feature.
km.cluster_centers_

array([[ 3.97554154, -9.76482951,  9.55782456],
       [-2.46683784,  9.02993178,  4.4263314 ],
       [ 1.79195845, -6.75723319, -6.78561596],
       [-8.471819  ,  7.12113806,  1.92468434],
       [ 6.84669769, -5.70588812, -6.15262835]])

'{\n  "irVersion": "10",\n  "producerName": "skl2onnx",\n  "producerVersion": "1.18.0",\n  "domain": "ai.onnx",\n  "modelVersion": "0",\n  "docString": "",\n  "graph": {\n    "node": [\n      {\n        "input": [\n          "float_input",\n          "Re_ReduceSumSquarecst"\n        ],\n        "output": [\n          "Re_reduced0"\n        ],\n        "name": "Re_ReduceSumSquare",\n        "opType": "ReduceSumSquare",\n        "attribute": [\n          {\n            "name": "keepdims",\n            "i": "1",\n            "type": "INT"\n          }\n        ],\n        "domain": ""\n      },\n      {\n        "input": [\n          "Re_reduced0",\n          "Mu_Mulcst"\n        ],\n        "output": [\n          "Mu_C0"\n        ],\n        "name": "Mu_Mul",\n        "opType": "Mul",\n        "domain": ""\n      },\n      {\n        "input": [\n          "float_input",\n          "Ge_Gemmcst",\n          "Mu_C0"\n        ],\n        "output": [\n          "Ge_Y0"\n        ],\n        "n