diff --git a/docs/middleware.md b/docs/middleware.md index 923379ff9f2..f7c8c0bb4ff 100644 --- a/docs/middleware.md +++ b/docs/middleware.md @@ -67,13 +67,17 @@ Model middleware is supported for the following model functions. Don't confuse model middleware and document middleware: model middleware hooks into *static* functions on a `Model` class, document middleware hooks into *methods* on a `Model` class. In model middleware functions, `this` refers to the model. +* [bulkWrite](api/model.html#model_Model-bulkWrite) +* [createCollection](api/model.html#model_Model-createCollection) * [insertMany](api/model.html#model_Model-insertMany) Here are the possible strings that can be passed to `pre()` * aggregate +* bulkWrite * count * countDocuments +* createCollection * deleteOne * deleteMany * estimatedDocumentCount diff --git a/lib/model.js b/lib/model.js index 0ced344223a..e4754aa1e91 100644 --- a/lib/model.js +++ b/lib/model.js @@ -1415,6 +1415,18 @@ Model.createCollection = async function createCollection(options) { throw new MongooseError('Model.createCollection() no longer accepts a callback'); } + const shouldSkip = await new Promise((resolve, reject) => { + this.hooks.execPre('createCollection', this, [options], (err) => { + if (err != null) { + if (err instanceof Kareem.skipWrappedFunction) { + return resolve(true); + } + return reject(err); + } + resolve(); + }); + }); + const collectionOptions = this && this.schema && this.schema.options && @@ -1468,13 +1480,32 @@ Model.createCollection = async function createCollection(options) { } try { - await this.db.createCollection(this.$__collection.collectionName, options); + if (!shouldSkip) { + await this.db.createCollection(this.$__collection.collectionName, options); + } } catch (err) { - if (err != null && (err.name !== 'MongoServerError' || err.code !== 48)) { - throw err; + await new Promise((resolve, reject) => { + const _opts = { error: err }; + this.hooks.execPost('createCollection', this, [null], _opts, (err) => { + if (err != null) { + return reject(err); + } + resolve(); + }); + }); } } + + await new Promise((resolve, reject) => { + this.hooks.execPost('createCollection', this, [this.$__collection], (err) => { + if (err != null) { + return reject(err); + } + resolve(); + }); + }); + return this.$__collection; }; @@ -3428,44 +3459,62 @@ Model.bulkWrite = async function bulkWrite(ops, options) { throw new MongooseError('Model.bulkWrite() no longer accepts a callback'); } options = options || {}; + + const shouldSkip = await new Promise((resolve, reject) => { + this.hooks.execPre('bulkWrite', this, [ops, options], (err) => { + if (err != null) { + if (err instanceof Kareem.skipWrappedFunction) { + return resolve(err); + } + return reject(err); + } + resolve(); + }); + }); + + if (shouldSkip) { + return shouldSkip.args[0]; + } + const ordered = options.ordered == null ? true : options.ordered; + if (ops.length === 0) { + return getDefaultBulkwriteResult(); + } + const validations = ops.map(op => castBulkWrite(this, op, options)); - return new Promise((resolve, reject) => { - if (ordered) { + let res = null; + if (ordered) { + await new Promise((resolve, reject) => { each(validations, (fn, cb) => fn(cb), error => { if (error) { return reject(error); } - if (ops.length === 0) { - return resolve(getDefaultBulkwriteResult()); - } - - try { - this.$__collection.bulkWrite(ops, options, (error, res) => { - if (error) { - return reject(error); - } - - resolve(res); - }); - } catch (err) { - return reject(err); - } + resolve(); }); + }); - return; + try { + res = await this.$__collection.bulkWrite(ops, options); + } catch (error) { + await new Promise((resolve, reject) => { + const _opts = { error: error }; + this.hooks.execPost('bulkWrite', this, [null], _opts, (err) => { + if (err != null) { + return reject(err); + } + resolve(); + }); + }); } - + } else { let remaining = validations.length; let validOps = []; let validationErrors = []; const results = []; - if (remaining === 0) { - completeUnorderedValidation.call(this); - } else { + await new Promise((resolve) => { for (let i = 0; i < validations.length; ++i) { validations[i]((err) => { if (err == null) { @@ -3475,56 +3524,74 @@ Model.bulkWrite = async function bulkWrite(ops, options) { results[i] = err; } if (--remaining <= 0) { - completeUnorderedValidation.call(this); + resolve(); } }); } - } + }); validationErrors = validationErrors. sort((v1, v2) => v1.index - v2.index). map(v => v.error); - function completeUnorderedValidation() { - const validOpIndexes = validOps; - validOps = validOps.sort().map(index => ops[index]); + const validOpIndexes = validOps; + validOps = validOps.sort().map(index => ops[index]); - if (validOps.length === 0) { - return resolve(getDefaultBulkwriteResult()); - } + if (validOps.length === 0) { + return getDefaultBulkwriteResult(); + } - this.$__collection.bulkWrite(validOps, options, (error, res) => { - if (error) { - if (validationErrors.length > 0) { - error.mongoose = error.mongoose || {}; - error.mongoose.validationErrors = validationErrors; - } + let error; + [res, error] = await this.$__collection.bulkWrite(validOps, options). + then(res => ([res, null])). + catch(err => ([null, err])); - return reject(error); - } + if (error) { + if (validationErrors.length > 0) { + error.mongoose = error.mongoose || {}; + error.mongoose.validationErrors = validationErrors; + } - for (let i = 0; i < validOpIndexes.length; ++i) { - results[validOpIndexes[i]] = null; - } - if (validationErrors.length > 0) { - if (options.throwOnValidationError) { - return reject(new MongooseBulkWriteError( - validationErrors, - results, - res, - 'bulkWrite' - )); - } else { - res.mongoose = res.mongoose || {}; - res.mongoose.validationErrors = validationErrors; - res.mongoose.results = results; + await new Promise((resolve, reject) => { + const _opts = { error: error }; + this.hooks.execPost('bulkWrite', this, [null], _opts, (err) => { + if (err != null) { + return reject(err); } - } - - resolve(res); + resolve(); + }); }); } + + for (let i = 0; i < validOpIndexes.length; ++i) { + results[validOpIndexes[i]] = null; + } + if (validationErrors.length > 0) { + if (options.throwOnValidationError) { + throw new MongooseBulkWriteError( + validationErrors, + results, + res, + 'bulkWrite' + ); + } else { + res.mongoose = res.mongoose || {}; + res.mongoose.validationErrors = validationErrors; + res.mongoose.results = results; + } + } + } + + await new Promise((resolve, reject) => { + this.hooks.execPost('bulkWrite', this, [res], (err) => { + if (err != null) { + return reject(err); + } + resolve(); + }); }); + + return res; }; /** diff --git a/test/model.middleware.test.js b/test/model.middleware.test.js index 7747167bb4b..92ca5224dee 100644 --- a/test/model.middleware.test.js +++ b/test/model.middleware.test.js @@ -457,4 +457,137 @@ describe('model middleware', function() { assert.equal(preCalled, 1); assert.equal(postCalled, 1); }); + + describe('createCollection middleware', function() { + it('calls createCollection hooks', async function() { + const schema = new Schema({ name: String }, { autoCreate: true }); + + const pre = []; + const post = []; + schema.pre('createCollection', function() { + pre.push(this); + }); + schema.post('createCollection', function() { + post.push(this); + }); + + const Test = db.model('Test', schema); + await Test.init(); + assert.equal(pre.length, 1); + assert.equal(pre[0], Test); + assert.equal(post.length, 1); + assert.equal(post[0], Test); + }); + + it('allows skipping createCollection from hooks', async function() { + const schema = new Schema({ name: String }, { autoCreate: true }); + + schema.pre('createCollection', function(next) { + next(mongoose.skipMiddlewareFunction()); + }); + + const Test = db.model('CreateCollectionHookTest', schema); + await Test.init(); + const collections = await db.listCollections(); + assert.equal(collections.length, 0); + }); + }); + + describe('bulkWrite middleware', function() { + it('calls bulkWrite hooks', async function() { + const schema = new Schema({ name: String }); + + const pre = []; + const post = []; + schema.pre('bulkWrite', function(next, ops) { + pre.push(ops); + next(); + }); + schema.post('bulkWrite', function(res) { + post.push(res); + }); + + const Test = db.model('Test', schema); + await Test.bulkWrite([{ + updateOne: { + filter: { name: 'foo' }, + update: { $set: { name: 'bar' } } + } + }]); + assert.equal(pre.length, 1); + assert.deepStrictEqual(pre[0], [{ + updateOne: { + filter: { name: 'foo' }, + update: { $set: { name: 'bar' } } + } + }]); + assert.equal(post.length, 1); + assert.equal(post[0].constructor.name, 'BulkWriteResult'); + }); + + it('allows updating ops', async function() { + const schema = new Schema({ name: String, prop: String }); + + schema.pre('bulkWrite', function(next, ops) { + ops[0].updateOne.filter.name = 'baz'; + next(); + }); + + const Test = db.model('Test', schema); + const { _id } = await Test.create({ name: 'baz' }); + await Test.bulkWrite([{ + updateOne: { + filter: { name: 'foo' }, + update: { $set: { prop: 'test prop value' } } + } + }]); + const { prop } = await Test.findById(_id).orFail(); + assert.equal(prop, 'test prop value'); + }); + + it('supports error handlers', async function() { + const schema = new Schema({ name: String, prop: String }); + + const errors = []; + schema.post('bulkWrite', function(err, res, next) { + errors.push(err); + next(); + }); + + const Test = db.model('Test', schema); + const { _id } = await Test.create({ name: 'baz' }); + await assert.rejects( + Test.bulkWrite([{ + insertOne: { + document: { + _id + } + } + }]), + /duplicate key error/ + ); + assert.equal(errors.length, 1); + assert.equal(errors[0].name, 'MongoBulkWriteError'); + assert.ok(errors[0].message.includes('duplicate key error'), errors[0].message); + }); + + it('supports skipping wrapped function', async function() { + const schema = new Schema({ name: String, prop: String }); + + schema.pre('bulkWrite', function(next) { + next(mongoose.skipMiddlewareFunction('skipMiddlewareFunction test')); + }); + + const Test = db.model('Test', schema); + const { _id } = await Test.create({ name: 'baz' }); + const res = await Test.bulkWrite([{ + insertOne: { + document: { + _id + } + } + }]); + assert.strictEqual(res, 'skipMiddlewareFunction test'); + }); + }); }); diff --git a/test/types/middleware.test.ts b/test/types/middleware.test.ts index d06530e896f..31e210eb26d 100644 --- a/test/types/middleware.test.ts +++ b/test/types/middleware.test.ts @@ -1,5 +1,6 @@ import { Schema, model, Model, Document, SaveOptions, Query, Aggregate, HydratedDocument, PreSaveMiddlewareFunction, ModifyResult } from 'mongoose'; import { expectError, expectType, expectNotType, expectAssignable } from 'tsd'; +import { AnyBulkWriteOperation, CreateCollectionOptions } from 'mongodb'; const preMiddlewareFn: PreSaveMiddlewareFunction = function(next, opts) { this.$markValid('name'); @@ -90,6 +91,14 @@ schema.pre>('insertMany', function(next, docs: Array) { next(); }); +schema.pre>('bulkWrite', function(next, ops: Array>) { + next(); +}); + +schema.pre>('createCollection', function(next, opts?: CreateCollectionOptions) { + next(); +}); + schema.pre>('estimatedDocumentCount', function(next) {}); schema.post>('estimatedDocumentCount', function(count, next) { expectType(count); diff --git a/types/index.d.ts b/types/index.d.ts index 987d35c8117..6a5a926c654 100644 --- a/types/index.d.ts +++ b/types/index.d.ts @@ -373,8 +373,8 @@ declare module 'mongoose' { // method aggregate and insertMany with ErrorHandlingMiddlewareFunction post>(method: 'aggregate' | RegExp, fn: ErrorHandlingMiddlewareFunction>): this; post>(method: 'aggregate' | RegExp, options: SchemaPostOptions, fn: ErrorHandlingMiddlewareFunction>): this; - post(method: 'insertMany' | RegExp, fn: ErrorHandlingMiddlewareFunction): this; - post(method: 'insertMany' | RegExp, options: SchemaPostOptions, fn: ErrorHandlingMiddlewareFunction): this; + post(method: 'bulkWrite' | 'createCollection' | 'insertMany' | RegExp, fn: ErrorHandlingMiddlewareFunction): this; + post(method: 'bulkWrite' | 'createCollection' | 'insertMany' | RegExp, options: SchemaPostOptions, fn: ErrorHandlingMiddlewareFunction): this; /** Defines a pre hook for the model. */ // this = never since it never happens @@ -429,6 +429,44 @@ declare module 'mongoose' { options?: InsertManyOptions & { lean?: boolean } ) => void | Promise ): this; + /* method bulkWrite */ + pre( + method: 'bulkWrite' | RegExp, + fn: ( + this: T, + next: (err?: CallbackError) => void, + ops: Array & MongooseBulkWritePerWriteOptions>, + options?: mongodb.BulkWriteOptions & MongooseBulkWriteOptions + ) => void | Promise + ): this; + pre( + method: 'bulkWrite' | RegExp, + options: SchemaPreOptions, + fn: ( + this: T, + next: (err?: CallbackError) => void, + ops: Array & MongooseBulkWritePerWriteOptions>, + options?: mongodb.BulkWriteOptions & MongooseBulkWriteOptions + ) => void | Promise + ): this; + /* method createCollection */ + pre( + method: 'createCollection' | RegExp, + fn: ( + this: T, + next: (err?: CallbackError) => void, + options?: mongodb.CreateCollectionOptions & Pick + ) => void | Promise + ): this; + pre( + method: 'createCollection' | RegExp, + options: SchemaPreOptions, + fn: ( + this: T, + next: (err?: CallbackError) => void, + options?: mongodb.CreateCollectionOptions & Pick + ) => void | Promise + ): this; /** Object of currently defined query helpers on this schema. */ query: TQueryHelpers;