From a3841f630a006519205c42692c4b8234075c67b7 Mon Sep 17 00:00:00 2001 From: Dmitry Razdoburdin <> Date: Mon, 13 Oct 2025 02:24:42 -0700 Subject: [PATCH] fix fp32 --- plugin/sycl/common/linalg_op.h | 2 +- src/objective/multiclass_obj.cu | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/plugin/sycl/common/linalg_op.h b/plugin/sycl/common/linalg_op.h index 1439408093be..e31ad335005f 100644 --- a/plugin/sycl/common/linalg_op.h +++ b/plugin/sycl/common/linalg_op.h @@ -103,7 +103,7 @@ bool Validate(DeviceOrd device, TensorView t, Fn&& fn) { namespace linalg { template void ElementWiseKernel(Context const* ctx, TensorView t, Fn&& fn) { - if (ctx->IsSycl()) { + if (t.Device().IsSycl()) { sycl::linalg::ElementWiseKernel(t, fn); } else { ElementWiseKernelHost(t, ctx->Threads(), fn); diff --git a/src/objective/multiclass_obj.cu b/src/objective/multiclass_obj.cu index 5e3622ac0202..50dea22e8357 100644 --- a/src/objective/multiclass_obj.cu +++ b/src/objective/multiclass_obj.cu @@ -110,7 +110,7 @@ class SoftmaxMultiClassObj : public ObjFunction { << "Number of weights should be equal to number of data points."; } info.weights_.SetDevice(device); - auto weights = common::MakeOptionalWeights(this->ctx_->Device(), info.weights_); + auto weights = common::MakeOptionalWeights(device, info.weights_); preds.SetDevice(device); auto predt = linalg::MakeTensorView(this->ctx_, &preds, n_samples, n_classes);