diff --git a/fs/io_uring.c b/fs/io_uring.c index 4a5eb9e856f0d3..4ad0d17dc92ddf 100644 --- a/fs/io_uring.c +++ b/fs/io_uring.c @@ -10323,26 +10323,46 @@ static int io_unregister_iowq_aff(struct io_ring_ctx *ctx) static int io_register_iowq_max_workers(struct io_ring_ctx *ctx, void __user *arg) { - struct io_uring_task *tctx = current->io_uring; + struct io_uring_task *tctx = NULL; + struct io_sq_data *sqd = NULL; __u32 new_count[2]; int i, ret; - if (!tctx || !tctx->io_wq) - return -EINVAL; if (copy_from_user(new_count, arg, sizeof(new_count))) return -EFAULT; for (i = 0; i < ARRAY_SIZE(new_count); i++) if (new_count[i] > INT_MAX) return -EINVAL; + if (ctx->flags & IORING_SETUP_SQPOLL) { + sqd = ctx->sq_data; + if (sqd) { + mutex_lock(&sqd->lock); + tctx = sqd->thread->io_uring; + } + } else { + tctx = current->io_uring; + } + + ret = -EINVAL; + if (!tctx || !tctx->io_wq) + goto err; + ret = io_wq_max_workers(tctx->io_wq, new_count); if (ret) - return ret; + goto err; + + if (sqd) + mutex_unlock(&sqd->lock); if (copy_to_user(arg, new_count, sizeof(new_count))) return -EFAULT; return 0; +err: + if (sqd) + mutex_unlock(&sqd->lock); + return ret; } static bool io_register_op_must_quiesce(int op)