
Add predict method for stats::kmeans() #154

@kadyb


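Is there any reason why there is no `predict()` method for `stats::kmeans()`? Prediction for k-means models is currently available in other packages, e.g. `clue::cl_predict()` or `ClusterR::predict_KMeans()`. I think a built-in method would be a significant convenience for users. A minimal implementation that assigns each new observation to its nearest cluster center could look like this: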

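```r
# Assign each row of `newdata` to the nearest cluster center
# (squared Euclidean distance); signature matches the S3 generic predict(object, ...)
predict.kmeans = function(object, newdata, ...) {
  # convert first so data frames and matrices are handled the same way
  newdata = as.matrix(newdata)
  centers = object$centers
  if (ncol(newdata) != ncol(centers)) {
    stop("newdata must have the same number of columns as the original data")
  }

  vec = integer(nrow(newdata))
  for (i in seq_len(nrow(newdata))) {
    # squared distances from row i to every center; keep the closest
    vec[i] = which.min(colSums((t(centers) - newdata[i, ])^2))
  }
  return(vec)
}
```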

```r
data = iris[, 1:4]
mdl = kmeans(data, centers = 3)
pr = predict(mdl, data)
identical(mdl$cluster, pr)
#> [1] TRUE
```
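The loop in `predict.kmeans` calls `which.min()` once per row, which may be slow for very large `newdata`. A minimal sketch of one alternative, assuming only base R: `predict_kmeans_vec()` is a hypothetical name used for illustration, not an existing function. It loops over the k centers instead of the n rows, builds an n x k matrix of squared Euclidean distances, and picks the closest center per row with `max.col()`. Using `ties.method = "first"` breaks ties the same way `which.min()` does.

```r
# Hypothetical helper (illustrative name, not part of stats): assigns each
# row of `newdata` to the nearest k-means center by squared Euclidean
# distance, looping over the k centers instead of the n rows.
predict_kmeans_vec = function(object, newdata) {
  newdata = as.matrix(newdata)
  centers = object$centers
  if (ncol(newdata) != ncol(centers)) {
    stop("newdata must have the same number of columns as the original data")
  }
  n = nrow(newdata)
  # d[i, k] = squared distance between row i and center k (n x k matrix);
  # matrix(..., nrow = n) keeps the shape when n == 1
  d = matrix(
    vapply(seq_len(nrow(centers)),
           function(k) rowSums(sweep(newdata, 2, centers[k, ])^2),
           numeric(n)),
    nrow = n
  )
  # column index of the smallest distance in each row
  max.col(-d, ties.method = "first")
}

# Usage: cluster labels are arbitrary, so no fixed output is shown
mdl = kmeans(iris[, 1:4], centers = 3)
predict_kmeans_vec(mdl, iris[c(1, 51, 101), 1:4])
```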
