Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 35 additions & 2 deletions e2e/oauth/google-auth-callback.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ type ApiMocks = {
type ApiMockOptions = {
beforeConnectGoogleResponse?: Promise<void>;
beforeLoginOrSignupResponse?: Promise<void>;
connectGoogleResponse?: {
body?: unknown;
status: number;
};
};

const createDeferred = () => {
Expand Down Expand Up @@ -158,9 +162,9 @@ const prepareGoogleAuthCallbackPage = async (
});
await options.beforeConnectGoogleResponse;
return route.fulfill({
status: 200,
status: options.connectGoogleResponse?.status ?? 200,
contentType: "application/json",
body: JSON.stringify({}),
body: JSON.stringify(options.connectGoogleResponse?.body ?? {}),
});
}

Expand Down Expand Up @@ -284,6 +288,35 @@ test.describe("Google auth callback", () => {
expectGoogleAuthRequestBody(apiMocks.loginOrSignup[0], state);
});

test("recovers through Google sign-in when Google connect rejects an expired Compass session", async ({
page,
}) => {
const state = "connect-calendar-session-expired-state";
const apiMocks = await prepareGoogleAuthCallbackPage(page, {
connectGoogleResponse: {
status: 401,
body: { message: "unauthorized" },
},
});

await writeGoogleAuthorizationIntent({
intent: "connectCalendar",
page,
returnPath: "/week",
state,
});
await setActiveCompassSession(page);

await page.goto(getCallbackUrl(state));

await expect(page).toHaveURL(/\/week$/);
expect(apiMocks.connectGoogle).toHaveLength(1);
expect(apiMocks.loginOrSignup).toHaveLength(1);
expect(apiMocks.loginOrSignup[0]?.headers.rid).toBe("thirdparty");
expectGoogleAuthRequestBody(apiMocks.connectGoogle[0], state);
expectGoogleAuthRequestBody(apiMocks.loginOrSignup[0], state);
});

test("rejects callbacks that are missing required Google Calendar scopes", async ({
page,
}) => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@ import {
type GoogleAuthCodeRequest,
GoogleConnectErrorResponseSchema,
} from "@core/types/auth.types";
import {
type ApiError,
type ApiMethodConfig,
} from "@web/common/apis/api.types";
import { type ApiError } from "@web/common/apis/api.types";
import { ROOT_ROUTES } from "@web/common/constants/routes";
import {
GOOGLE_AUTH_SCOPES_REQUIRED,
Expand All @@ -28,10 +25,7 @@ type CompleteAuthentication = (input: {
}) => Promise<void>;

export type GoogleAuthorizationAuthAdapter = {
connectGoogle(
data: GoogleAuthCodeRequest,
config?: ApiMethodConfig,
): Promise<unknown>;
connectGoogle(data: GoogleAuthCodeRequest): Promise<unknown>;
loginOrSignup(data: GoogleAuthCodeRequest): Promise<{
user: { emails?: string[] };
}>;
Expand Down Expand Up @@ -149,7 +143,7 @@ export async function completeGoogleAuthorization({
await completeGoogleSignIn();
} else {
try {
await authApi.connectGoogle(payload, { skipSessionRecovery: true });
await authApi.connectGoogle(payload);
await refreshUserMetadata();
requestEventFetch?.();
} catch (error) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import { Status } from "@core/errors/status.codes";
import { type GoogleAuthCodeRequest } from "@core/types/auth.types";
import { BaseApi } from "@web/common/apis/base/base.api";
import { session } from "@web/common/classes/Session";
import { GoogleAuthCallbackApi } from "./google-auth-callback.api";
import {
afterEach,
beforeEach,
describe,
expect,
it,
mock,
spyOn,
} from "bun:test";

const originalFetch = globalThis.fetch;

const payload: GoogleAuthCodeRequest = {
clientType: "web",
redirectURIInfo: {
redirectURIOnProviderDashboard:
"http://localhost:9080/auth/google/callback",
redirectURIQueryParams: {
code: "auth-code",
scope: "https://www.googleapis.com/auth/calendar",
state: "state-1",
},
},
thirdPartyId: "google",
};

describe("GoogleAuthCallbackApi", () => {
beforeEach(() => {
BaseApi.defaults.adapter = undefined;
});

afterEach(() => {
globalThis.fetch = originalFetch;
BaseApi.defaults.adapter = undefined;
});

it("lets the callback flow handle an expired connect session locally", async () => {
const signOutSpy = spyOn(session, "signOut").mockResolvedValue(undefined);
globalThis.fetch = mock(async () =>
Promise.resolve(
new Response(JSON.stringify({ message: "unauthorised" }), {
status: Status.UNAUTHORIZED,
}),
),
) as unknown as typeof fetch;

await expect(
GoogleAuthCallbackApi.connectGoogle(payload),
).rejects.toMatchObject({
name: "ApiError",
response: {
status: Status.UNAUTHORIZED,
},
});

expect(signOutSpy).not.toHaveBeenCalled();
expect(globalThis.fetch).toHaveBeenCalledWith(
expect.stringContaining("/auth/google/connect"),
expect.objectContaining({
credentials: "include",
method: "POST",
}),
);

signOutSpy.mockRestore();
});

it("uses shared session handling for non-recoverable connect session errors", async () => {
window.history.pushState({}, "", "/day");
const signOutSpy = spyOn(session, "signOut").mockResolvedValue(undefined);
globalThis.fetch = mock(async () =>
Promise.resolve(
new Response(JSON.stringify({ message: "not found" }), {
status: Status.NOT_FOUND,
}),
),
) as unknown as typeof fetch;

await expect(
GoogleAuthCallbackApi.connectGoogle(payload),
).rejects.toMatchObject({
name: "ApiError",
response: {
status: Status.NOT_FOUND,
},
});

expect(signOutSpy).toHaveBeenCalledTimes(1);

signOutSpy.mockRestore();
});
});
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import { GOOGLE_REVOKED } from "@core/constants/sse.constants";
import { Status } from "@core/errors/status.codes";
import {
type GoogleAuthCodeRequest,
type GoogleConnectResponse,
} from "@core/types/auth.types";
import { type ApiError, type ApiResponse } from "@web/common/apis/api.types";
import { AuthApi } from "@web/common/apis/auth.api";
import { sendApiRequestWithoutSharedErrorRecovery } from "@web/common/apis/base/base.api";
import {
getApiErrorCode,
handleErrorResponse,
isApiError,
} from "@web/common/apis/util/api.util";
import { type GoogleAuthorizationAuthAdapter } from "./complete-google-authorization";

const isRecoverableConnectSessionError = (error: ApiError): boolean => {
return (
error.response?.status === Status.UNAUTHORIZED &&
getApiErrorCode(error) !== GOOGLE_REVOKED
);
};

export const GoogleAuthCallbackApi = {
async connectGoogle(
data: GoogleAuthCodeRequest,
): Promise<GoogleConnectResponse> {
try {
const response =
await sendApiRequestWithoutSharedErrorRecovery<GoogleConnectResponse>(
"POST",
"/auth/google/connect",
data,
);

return response.data;
} catch (error) {
if (!isApiError(error)) {
throw error;
}

if (isRecoverableConnectSessionError(error)) {
throw error;
}

await handleErrorResponse<ApiResponse<GoogleConnectResponse>>(error);
throw error;
}
},
loginOrSignup: AuthApi.loginOrSignup,
} satisfies GoogleAuthorizationAuthAdapter;
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ describe("completeGoogleAuthorization", () => {
).resolves.toEqual({ status: "completed", returnPath: "/day" });

expect(deps.authApi.connectGoogle).toHaveBeenCalledTimes(1);
expect(deps.authApi.connectGoogle.mock.calls[0]).toHaveLength(1);
expect(deps.refreshUserMetadata).toHaveBeenCalledTimes(1);
expect(deps.requestEventFetch).toHaveBeenCalledTimes(1);
expect(deps.completeAuthentication).not.toHaveBeenCalled();
Expand Down Expand Up @@ -136,6 +137,7 @@ describe("completeGoogleAuthorization", () => {
).resolves.toEqual({ status: "completed", returnPath: "/day" });

expect(deps.authApi.connectGoogle).toHaveBeenCalledTimes(1);
expect(deps.authApi.connectGoogle.mock.calls[0]).toHaveLength(1);
expect(deps.authApi.loginOrSignup).toHaveBeenCalledTimes(1);
expect(deps.completeAuthentication).toHaveBeenCalledWith({
email: "user@example.com",
Expand Down
6 changes: 1 addition & 5 deletions packages/web/src/common/apis/api.types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ export interface ApiError extends Error {
export interface ApiRequestConfig {
headers?: HeadersInit;
method?: string;
skipSessionRecovery?: boolean;
url?: string;
}

Expand All @@ -24,10 +23,7 @@ export interface ApiResponse<T> {
statusText: string;
}

export type ApiMethodConfig = Pick<
ApiRequestConfig,
"headers" | "skipSessionRecovery"
>;
export type ApiMethodConfig = Pick<ApiRequestConfig, "headers">;

export type SignoutStatus =
| Status.UNAUTHORIZED
Expand Down
3 changes: 0 additions & 3 deletions packages/web/src/common/apis/auth.api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ import {
type GoogleConnectResponse,
type Result_Auth_Compass,
} from "@core/types/auth.types";
import { type ApiMethodConfig } from "@web/common/apis/api.types";
import { BaseApi } from "@web/common/apis/base/base.api";

const AuthApi = {
Expand All @@ -21,12 +20,10 @@ const AuthApi = {

async connectGoogle(
data: GoogleAuthCodeRequest,
config?: ApiMethodConfig,
): Promise<GoogleConnectResponse> {
const response = await BaseApi.post<GoogleConnectResponse>(
`/auth/google/connect`,
data,
config,
);

return response.data;
Expand Down
25 changes: 23 additions & 2 deletions packages/web/src/common/apis/base/base.api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,32 @@ const request = async <T>(
url: string,
body?: unknown,
config: ApiMethodConfig = {},
): Promise<ApiResponse<T>> => {
try {
return await sendApiRequestWithoutSharedErrorRecovery<T>(
method,
url,
body,
config,
);
} catch (error) {
if (isApiError(error)) {
return handleErrorResponse<ApiResponse<T>>(error);
}

throw error;
}
};

export const sendApiRequestWithoutSharedErrorRecovery = async <T>(
method: string,
url: string,
body?: unknown,
config: ApiMethodConfig = {},
): Promise<ApiResponse<T>> => {
const requestConfig = {
headers: config.headers,
method,
skipSessionRecovery: config.skipSessionRecovery,
url,
} satisfies ApiRequestConfig;

Expand Down Expand Up @@ -74,7 +95,7 @@ const request = async <T>(
}

if (isApiError(error)) {
return handleErrorResponse<ApiResponse<T>>(error);
throw error;
}

throw createApiError(requestConfig);
Expand Down
Loading