Skip to content

Commit

Permalink
fix(passportjs): Fix authorization middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
Romakita committed Jan 18, 2020
1 parent 06cd817 commit 2f0f9c8
Show file tree
Hide file tree
Showing 9 changed files with 158 additions and 51 deletions.
9 changes: 7 additions & 2 deletions packages/passport/src/decorators/authenticate.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import {UseAuth} from "@tsed/common";
import {AuthenticateMiddleware} from "../middlewares/AuthenticateMiddleware";
import {PassportMiddleware} from "../middlewares/PassportMiddleware";

export function Authenticate(protocol: string | string[], options: any = {}): Function {
return UseAuth(AuthenticateMiddleware, {protocol, security: options.security, options});
return UseAuth(PassportMiddleware, {
protocol,
method: "authorize",
security: options.security,
options
});
}
11 changes: 8 additions & 3 deletions packages/passport/src/decorators/authorize.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import {UseAuth} from "@tsed/common";
import {AuthorizeMiddleware} from "../middlewares/AuthorizeMiddleware";
import {PassportMiddleware} from "../middlewares/PassportMiddleware";

export function Authorize(protocol: string, options: any = {}): Function {
return UseAuth(AuthorizeMiddleware, {protocol, security: options.security, options});
export function Authorize(protocol: string | string[], options: any = {}): Function {
return UseAuth(PassportMiddleware, {
protocol,
method: "authorize",
security: options.security,
options
});
}
3 changes: 1 addition & 2 deletions packages/passport/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,4 @@ export * from "./services/ProtocolsService";
export * from "./services/PassportSerializerService";

// Middlewares
export * from "./middlewares/AuthenticateMiddleware";
export * from "./middlewares/AuthorizeMiddleware";
export * from "./middlewares/PassportMiddleware";
20 changes: 0 additions & 20 deletions packages/passport/src/middlewares/AuthenticateMiddleware.ts

This file was deleted.

23 changes: 0 additions & 23 deletions packages/passport/src/middlewares/AuthorizeMiddleware.ts

This file was deleted.

23 changes: 23 additions & 0 deletions packages/passport/src/middlewares/PassportMiddleware.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import {EndpointInfo, Inject, Middleware, Req} from "@tsed/common";
import * as Passport from "passport";
import {Unauthorized} from "ts-httpexceptions";
import {ProtocolsService} from "../services/ProtocolsService";
import {getProtocolsFromRequest} from "../utils/getProtocolsFromRequest";

@Middleware()
export class PassportMiddleware {
@Inject(ProtocolsService)
protocolsService: ProtocolsService;

use(@Req() request: Req, @EndpointInfo() endpoint: EndpointInfo) {
const {options, protocol, method} = endpoint.store.get(PassportMiddleware);
const protocols = getProtocolsFromRequest(request, protocol, this.protocolsService.getProtocolsNames());

if (protocols.length === 0) {
throw new Unauthorized("Not authorized");
}

// @ts-ignore
return Passport[method](protocols.length === 1 ? protocol[0] : protocols, options);
}
}
2 changes: 1 addition & 1 deletion packages/passport/src/services/ProtocolsService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ export class ProtocolsService {
return Array.from(this.injector.getProviders(PROVIDER_TYPE_PROTOCOL));
}

public getProviderNames(): string[] {
public getProtocolsNames(): string[] {
return this.getProtocols().map(provider => this.getOptions(provider).name);
}

Expand Down
35 changes: 35 additions & 0 deletions packages/passport/src/utils/getProtocolsFromRequest.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
function getProtocolName(req: any) {
const {query = {}, params = {}, body = {}} = req;

return params.protocol || query.protocol || body.protocol;
}

const add = (protocols: string[], protocol: string): string[] => {
if (!protocol || protocols.includes(protocol)) {
return protocols;
}

return protocols.concat(protocol);
};

export function getProtocolsFromRequest(req: any, protocol: string | string[], defaultProtocols: string[]): string[] {
let protocols: string[] = [].concat(protocol as never);

if (protocols.includes("*")) {
return defaultProtocols;
}

protocols = protocols.reduce((protocols: string[], protocol: string) => {
if (protocol === ":protocol") {
return add(protocols, getProtocolName(req));
}

if (protocol === getProtocolName(req)) {
return add(protocols, protocol);
}

return protocols;
}, [] as string[]);

return protocols;
}
83 changes: 83 additions & 0 deletions packages/passport/test/utils/getProtocolsFromRequest.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import {expect} from "chai";
import {getProtocolsFromRequest} from "../../src/utils/getProtocolsFromRequest";

describe("getProtocolsFromRequest", () => {
it("should allow all protocol (from default protocols)", () => {
const defaultProtocols = ["default"];
const req = {};
const result = getProtocolsFromRequest(req, "*", defaultProtocols);

expect(result).to.deep.equal(["default"]);
});

it("should get protocol from request (params)", () => {
const defaultProtocols = ["default"];
const req = {
params: {
protocol: "basic"
}
};
const result = getProtocolsFromRequest(req, ":protocol", defaultProtocols);

expect(result).to.deep.equal(["basic"]);
});

it("should not get protocol from request", () => {
const defaultProtocols = ["default"];
const req = {
params: {}
};
const result = getProtocolsFromRequest(req, ":protocol", defaultProtocols);

expect(result).to.deep.equal([]);
});

it("should get protocol from request (query)", () => {
const defaultProtocols = ["default"];
const req = {
query: {
protocol: "basic"
}
};
const result = getProtocolsFromRequest(req, ":protocol", defaultProtocols);

expect(result).to.deep.equal(["basic"]);
});

it("should get protocol from request (body)", () => {
const defaultProtocols = ["default"];
const req = {
body: {
protocol: "basic"
}
};
const result = getProtocolsFromRequest(req, ":protocol", defaultProtocols);

expect(result).to.deep.equal(["basic"]);
});


it("should return basic protocol", () => {
const defaultProtocols = ["default"];
const req = {
params: {
protocol: "basic"
}
};
const result = getProtocolsFromRequest(req, "basic", defaultProtocols);

expect(result).to.deep.equal(["basic"]);
});

it("should not return protocol when protocol doesn\'t match", () => {
const defaultProtocols = ["default"];
const req = {
params: {
protocol: "basic"
}
};
const result = getProtocolsFromRequest(req, "other", defaultProtocols);

expect(result).to.deep.equal([]);
});
});

0 comments on commit 2f0f9c8

Please sign in to comment.