Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add configurable shield rule #609

Merged
merged 3 commits into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 28 additions & 5 deletions arcjet/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import {
ArcjetTokenBucketRateLimitRule,
ArcjetFixedWindowRateLimitRule,
ArcjetSlidingWindowRateLimitRule,
ArcjetShieldRule,
} from "@arcjet/protocol";
import {
ArcjetBotTypeToProtocol,
Expand Down Expand Up @@ -304,7 +305,6 @@ export function createRemoteClient(
const decideRequest = new DecideRequest({
sdkStack,
sdkVersion,
fingerprint: context.fingerprint,
details: {
ip: details.ip,
method: details.method,
Expand Down Expand Up @@ -355,7 +355,6 @@ export function createRemoteClient(
const reportRequest = new ReportRequest({
sdkStack,
sdkVersion,
fingerprint: context.fingerprint,
details: {
ip: details.ip,
method: details.method,
Expand Down Expand Up @@ -568,9 +567,10 @@ export class ArcjetHeaders extends Headers {
}

const Priority = {
RateLimit: 1,
BotDetection: 2,
EmailValidation: 3,
Shield: 1,
RateLimit: 2,
BotDetection: 3,
EmailValidation: 4,
};

type PlainObject = { [key: string]: unknown };
Expand Down Expand Up @@ -986,6 +986,29 @@ export function detectBot(
return rules;
}

export type ShieldOptions = {
mode?: ArcjetMode;
};

export function shield(
options?: ShieldOptions,
...additionalOptions: ShieldOptions[]
): Primitive {
const rules: ArcjetShieldRule<{}>[] = [];

// Always create at least one Shield rule
for (const opt of [options ?? {}, ...additionalOptions]) {
const mode = opt.mode === "LIVE" ? "LIVE" : "DRY_RUN";
rules.push({
type: "SHIELD",
priority: Priority.Shield,
mode,
});
}

return rules;
}

export type ProtectSignupOptions<Characteristics extends string[]> = {
rateLimit?:
| SlidingWindowRateLimitOptions<Characteristics>
Expand Down
37 changes: 26 additions & 11 deletions arcjet/test/index.node.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ import arcjet, {
slidingWindow,
Primitive,
Arcjet,
shield,
} from "../index";

// Type helpers from https://github.com/sindresorhus/type-fest but adjusted for
Expand Down Expand Up @@ -283,7 +284,6 @@ describe("createRemoteClient", () => {
...details,
headers: { "user-agent": "curl/8.1.2" },
},
fingerprint,
rules: [],
sdkStack: SDKStack.SDK_STACK_NEXTJS,
sdkVersion: "__ARCJET_SDK_VERSION__",
Expand Down Expand Up @@ -338,7 +338,6 @@ describe("createRemoteClient", () => {
...details,
headers: { "user-agent": "curl/8.1.2" },
},
fingerprint,
rules: [],
sdkStack: SDKStack.SDK_STACK_UNSPECIFIED,
sdkVersion: "__ARCJET_SDK_VERSION__",
Expand Down Expand Up @@ -391,7 +390,6 @@ describe("createRemoteClient", () => {
...details,
headers: { "user-agent": "curl/8.1.2" },
},
fingerprint,
rules: [],
sdkStack: SDKStack.SDK_STACK_NODEJS,
sdkVersion: "__ARCJET_SDK_VERSION__",
Expand Down Expand Up @@ -445,7 +443,6 @@ describe("createRemoteClient", () => {
...details,
headers: { "user-agent": "curl/8.1.2" },
},
fingerprint,
rules: [],
sdkStack: SDKStack.SDK_STACK_NODEJS,
sdkVersion: "__ARCJET_SDK_VERSION__",
Expand Down Expand Up @@ -504,7 +501,6 @@ describe("createRemoteClient", () => {
...details,
headers: { "user-agent": "curl/8.1.2" },
},
fingerprint,
rules: [new Rule()],
sdkStack: SDKStack.SDK_STACK_NODEJS,
sdkVersion: "__ARCJET_SDK_VERSION__",
Expand Down Expand Up @@ -817,7 +813,6 @@ describe("createRemoteClient", () => {
new ReportRequest({
sdkStack: SDKStack.SDK_STACK_NODEJS,
sdkVersion: "__ARCJET_SDK_VERSION__",
fingerprint,
details: {
...details,
headers: { "user-agent": "curl/8.1.2" },
Expand Down Expand Up @@ -881,7 +876,6 @@ describe("createRemoteClient", () => {
expect(router.report).toHaveBeenCalledTimes(1);
expect(router.report).toHaveBeenCalledWith(
new ReportRequest({
fingerprint,
sdkStack: SDKStack.SDK_STACK_NODEJS,
sdkVersion: "__ARCJET_SDK_VERSION__",
details: {
Expand Down Expand Up @@ -949,7 +943,6 @@ describe("createRemoteClient", () => {
new ReportRequest({
sdkStack: SDKStack.SDK_STACK_NODEJS,
sdkVersion: "__ARCJET_SDK_VERSION__",
fingerprint,
details: {
...details,
headers: { "user-agent": "curl/8.1.2" },
Expand Down Expand Up @@ -1022,7 +1015,6 @@ describe("createRemoteClient", () => {
new ReportRequest({
sdkStack: SDKStack.SDK_STACK_NODEJS,
sdkVersion: "__ARCJET_SDK_VERSION__",
fingerprint,
details: {
...details,
headers: { "user-agent": "curl/8.1.2" },
Expand Down Expand Up @@ -1084,7 +1076,6 @@ describe("createRemoteClient", () => {
new ReportRequest({
sdkStack: SDKStack.SDK_STACK_NODEJS,
sdkVersion: "__ARCJET_SDK_VERSION__",
fingerprint,
details: {
...details,
headers: { "user-agent": "curl/8.1.2" },
Expand Down Expand Up @@ -1164,7 +1155,6 @@ describe("createRemoteClient", () => {
new ReportRequest({
sdkStack: SDKStack.SDK_STACK_NODEJS,
sdkVersion: "__ARCJET_SDK_VERSION__",
fingerprint,
details: {
...details,
headers: { "user-agent": "curl/8.1.2" },
Expand Down Expand Up @@ -3004,6 +2994,31 @@ describe("Primitive > validateEmail", () => {
});
});

describe("Primitive > shield", () => {
test("provides a default rule with no options specified", async () => {
const [rule] = shield();
expect(rule.type).toEqual("SHIELD");
expect(rule).toHaveProperty("mode", "DRY_RUN");
});

test("sets mode as 'DRY_RUN' if not 'LIVE' or 'DRY_RUN'", async () => {
const [rule] = shield({
// @ts-expect-error
mode: "INVALID",
});
expect(rule.type).toEqual("SHIELD");
expect(rule).toHaveProperty("mode", "DRY_RUN");
});

test("sets mode as `LIVE` if specified", async () => {
const [rule] = shield({
mode: "LIVE",
});
expect(rule.type).toEqual("SHIELD");
expect(rule).toHaveProperty("mode", "LIVE");
});
});

describe("Products > protectSignup", () => {
test("allows configuration of rateLimit, bot, and email", () => {
const rules = protectSignup({
Expand Down
6 changes: 5 additions & 1 deletion examples/nextjs-13-pages-wrap/pages/api/arcjet-edge.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// Next.js API route support: https://nextjs.org/docs/api-routes/introduction
import arcjet, { fixedWindow, withArcjet } from "@arcjet/next";
import arcjet, { fixedWindow, shield, withArcjet } from "@arcjet/next";
import { NextRequest, NextResponse } from "next/server";

export const config = {
Expand All @@ -12,6 +12,10 @@ const aj = arcjet({
// See: https://nextjs.org/docs/pages/building-your-application/configuring/environment-variables
key: process.env.ARCJET_KEY,
rules: [
// Protect against common attacks with Arcjet Shield
shield({
mode: "LIVE", // will block requests. Use "DRY_RUN" to log only
}),
// Fixed window rate limit. Arcjet also supports sliding window and token
// bucket.
fixedWindow({
Expand Down
6 changes: 5 additions & 1 deletion examples/nextjs-13-pages-wrap/pages/api/arcjet.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// Next.js API route support: https://nextjs.org/docs/api-routes/introduction
import arcjet, { fixedWindow, withArcjet } from "@arcjet/next";
import arcjet, { fixedWindow, shield, withArcjet } from "@arcjet/next";
import type { NextApiRequest, NextApiResponse } from "next";

const aj = arcjet({
Expand All @@ -8,6 +8,10 @@ const aj = arcjet({
// See: https://nextjs.org/docs/pages/building-your-application/configuring/environment-variables
key: process.env.ARCJET_KEY,
rules: [
// Protect against common attacks with Arcjet Shield
shield({
mode: "LIVE", // will block requests. Use "DRY_RUN" to log only
}),
// Fixed window rate limit. Arcjet also supports sliding window and token
// bucket.
fixedWindow({
Expand Down
9 changes: 7 additions & 2 deletions examples/nextjs-14-app-dir-rl/middleware.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import arcjet, { createMiddleware } from "@arcjet/next";
import arcjet, { createMiddleware, shield } from "@arcjet/next";

export const config = {
// matcher tells Next.js which routes to run the middleware on
Expand All @@ -10,7 +10,12 @@ const aj = arcjet({
// and set it as an environment variable rather than hard coding.
// See: https://nextjs.org/docs/app/building-your-application/configuring/environment-variables
key: process.env.ARCJET_KEY!,
rules: [],
rules: [
// Protect against common attacks with Arcjet Shield
shield({
mode: "LIVE", // will block requests. Use "DRY_RUN" to log only
}),
],
});

export default createMiddleware(aj);
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import arcjet, { validateEmail } from "@arcjet/next";
import arcjet, { shield, validateEmail } from "@arcjet/next";
import { NextResponse } from "next/server";

const aj = arcjet({
Expand All @@ -7,6 +7,10 @@ const aj = arcjet({
// See: https://nextjs.org/docs/app/building-your-application/configuring/environment-variables
key: process.env.ARCJET_KEY,
rules: [
// Protect against common attacks with Arcjet Shield
shield({
mode: "LIVE", // will block requests. Use "DRY_RUN" to log only
}),
validateEmail({
mode: "LIVE", // will block requests. Use "DRY_RUN" to log only
block: ["NO_MX_RECORDS"], // block email addresses with no MX records
Expand Down
6 changes: 5 additions & 1 deletion examples/nextjs-14-authjs-5/app/api/protected/route.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import arcjet, { tokenBucket } from "@arcjet/next";
import arcjet, { shield, tokenBucket } from "@arcjet/next";
import { auth } from "auth";

// The arcjet instance is created outside of the handler
const aj = arcjet({
key: process.env.ARCJET_KEY!, // Get your site key from https://app.arcjet.com
rules: [
// Protect against common attacks with Arcjet Shield
shield({
mode: "LIVE", // will block requests. Use "DRY_RUN" to log only
}),
// Create a token bucket rate limit. Other algorithms are supported.
tokenBucket({
mode: "LIVE", // will block requests. Use "DRY_RUN" to log only
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import arcjet, { detectBot, slidingWindow } from "@arcjet/next";
import arcjet, { detectBot, shield, slidingWindow } from "@arcjet/next";
import { handlers } from "auth";
import { NextRequest, NextResponse } from "next/server";

const aj = arcjet({
key: process.env.ARCJET_KEY,
rules: [
shield({
mode: "LIVE", // will block requests. Use "DRY_RUN" to log only
}),
slidingWindow({
mode: "LIVE", // will block requests. Use "DRY_RUN" to log only
interval: 60, // tracks requests across a 60 second sliding window
Expand Down
6 changes: 5 additions & 1 deletion examples/nextjs-14-authjs-5/middleware.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import arcjet, { tokenBucket } from "@arcjet/next";
import arcjet, { shield, tokenBucket } from "@arcjet/next";
import { auth } from "auth";

const aj = arcjet({
key: process.env.ARCJET_KEY!, // Get your site key from https://app.arcjet.com
rules: [
// Protect against common attacks with Arcjet Shield
shield({
mode: "LIVE", // will block requests. Use "DRY_RUN" to log only
}),
// Create a token bucket rate limit. Other algorithms are supported.
tokenBucket({
mode: "LIVE", // will block requests. Use "DRY_RUN" to log only
Expand Down
5 changes: 4 additions & 1 deletion examples/nextjs-14-clerk-rl/app/api/arcjet/route.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import arcjet, { ArcjetDecision, tokenBucket, detectBot } from "@arcjet/next";
import arcjet, { ArcjetDecision, tokenBucket, detectBot, shield } from "@arcjet/next";
import { NextResponse } from "next/server";
import { currentUser } from "@clerk/nextjs";

// The root Arcjet client is created outside of the handler.
const aj = arcjet({
key: process.env.ARCJET_KEY!, // Get your site key from https://app.arcjet.com
rules: [
shield({
mode: "LIVE", // will block requests. Use "DRY_RUN" to log only
}),
detectBot({
mode: "LIVE", // will block requests. Use "DRY_RUN" to log only
block: ["AUTOMATED"], // blocks all automated clients
Expand Down
11 changes: 7 additions & 4 deletions examples/nextjs-14-clerk-shield/middleware.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { authMiddleware } from "@clerk/nextjs";
import arcjet, { createMiddleware } from "@arcjet/next";
import arcjet, { createMiddleware, shield } from "@arcjet/next";

export const config = {
// Protects all routes, including api/trpc.
Expand All @@ -19,9 +19,12 @@ const aj = arcjet({
// and set it as an environment variable rather than hard coding.
// See: https://nextjs.org/docs/app/building-your-application/configuring/environment-variables
key: process.env.ARCJET_KEY!,
// No rules are required for Arcjet Shield - it runs on every request.
// You can also add other rules, such as bot protection, here.
rules: [],
rules: [
// Protect against common attacks with Arcjet Shield
shield({
mode: "LIVE", // will block requests. Use "DRY_RUN" to log only
}),
],
});

// Clerk middleware will run after the Arcjet middleware. You could also use
Expand Down
6 changes: 5 additions & 1 deletion examples/nextjs-14-decorate/app/api-app/arcjet/route.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import arcjet, { fixedWindow } from "@arcjet/next";
import arcjet, { fixedWindow, shield } from "@arcjet/next";
import { setRateLimitHeaders } from "@arcjet/decorate";
import { NextResponse } from "next/server";

Expand All @@ -8,6 +8,10 @@ const aj = arcjet({
// See: https://nextjs.org/docs/app/building-your-application/configuring/environment-variables
key: process.env.ARCJET_KEY!,
rules: [
// Protect against common attacks with Arcjet Shield
shield({
mode: "LIVE", // will block requests. Use "DRY_RUN" to log only
}),
// Fixed window rate limit. Arcjet also supports sliding window and token
// bucket.
fixedWindow({
Expand Down
9 changes: 7 additions & 2 deletions examples/nextjs-14-ip-details/pages/api/arcjet-edge.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// Next.js API route support: https://nextjs.org/docs/api-routes/introduction
import arcjet from "@arcjet/next";
import arcjet, { shield } from "@arcjet/next";
import { NextRequest, NextResponse } from "next/server";

export const config = {
Expand All @@ -11,7 +11,12 @@ const aj = arcjet({
// and set it as an environment variable rather than hard coding.
// See: https://nextjs.org/docs/pages/building-your-application/configuring/environment-variables
key: process.env.ARCJET_KEY!,
rules: [],
rules: [
// Protect against common attacks with Arcjet Shield
shield({
mode: "LIVE", // will block requests. Use "DRY_RUN" to log only
}),
],
});

export default async function handler(req: NextRequest) {
Expand Down
Loading
Loading