Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions apps/code/src/main/db/migrations/0003_fair_whiplash.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
CREATE TABLE `auth_sessions` (
`id` integer PRIMARY KEY NOT NULL CHECK (`id` = 1),
`refresh_token_encrypted` text NOT NULL,
`cloud_region` text NOT NULL,
`selected_project_id` integer,
`scope_version` integer NOT NULL,
`created_at` text DEFAULT (CURRENT_TIMESTAMP) NOT NULL,
`updated_at` text DEFAULT (CURRENT_TIMESTAMP) NOT NULL
);
7 changes: 7 additions & 0 deletions apps/code/src/main/db/migrations/meta/_journal.json
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@
"when": 1773335630838,
"tag": "0002_massive_bishop",
"breakpoints": true
},
{
"idx": 3,
"version": "6",
"when": 1774890000000,
"tag": "0003_fair_whiplash",
"breakpoints": true
}
]
}
42 changes: 42 additions & 0 deletions apps/code/src/main/db/repositories/auth-session-repository.mock.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import type {
AuthSession,
IAuthSessionRepository,
PersistAuthSessionInput,
} from "./auth-session-repository";

export interface MockAuthSessionRepository extends IAuthSessionRepository {
_session: AuthSession | null;
}

export function createMockAuthSessionRepository(): MockAuthSessionRepository {
let session: AuthSession | null = null;

const clone = (value: AuthSession | null): AuthSession | null =>
value ? { ...value } : null;

return {
get _session() {
return clone(session);
},
set _session(value) {
session = clone(value);
},
getCurrent: () => clone(session),
saveCurrent: (input: PersistAuthSessionInput) => {
const timestamp = new Date().toISOString();
session = {
id: 1,
refreshTokenEncrypted: input.refreshTokenEncrypted,
cloudRegion: input.cloudRegion,
selectedProjectId: input.selectedProjectId,
scopeVersion: input.scopeVersion,
createdAt: session?.createdAt ?? timestamp,
updatedAt: timestamp,
};
return { ...session };
},
clearCurrent: () => {
session = null;
},
};
}
75 changes: 75 additions & 0 deletions apps/code/src/main/db/repositories/auth-session-repository.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import type { CloudRegion } from "@shared/types/oauth";
import { eq } from "drizzle-orm";
import { inject, injectable } from "inversify";
import { MAIN_TOKENS } from "../../di/tokens";
import { authSessions } from "../schema";
import type { DatabaseService } from "../service";

export type AuthSession = typeof authSessions.$inferSelect;
export type NewAuthSession = typeof authSessions.$inferInsert;

export interface PersistAuthSessionInput {
refreshTokenEncrypted: string;
cloudRegion: CloudRegion;
selectedProjectId: number | null;
scopeVersion: number;
}

export interface IAuthSessionRepository {
getCurrent(): AuthSession | null;
saveCurrent(input: PersistAuthSessionInput): AuthSession;
clearCurrent(): void;
}

const CURRENT_AUTH_SESSION_ID = 1;
const byId = eq(authSessions.id, CURRENT_AUTH_SESSION_ID);
const now = () => new Date().toISOString();

@injectable()
export class AuthSessionRepository implements IAuthSessionRepository {
constructor(
@inject(MAIN_TOKENS.DatabaseService)
private readonly databaseService: DatabaseService,
) {}

private get db() {
return this.databaseService.db;
}

getCurrent(): AuthSession | null {
return (
this.db.select().from(authSessions).where(byId).limit(1).get() ?? null
);
}

saveCurrent(input: PersistAuthSessionInput): AuthSession {
const timestamp = now();
const existing = this.getCurrent();

const row: NewAuthSession = {
id: CURRENT_AUTH_SESSION_ID,
refreshTokenEncrypted: input.refreshTokenEncrypted,
cloudRegion: input.cloudRegion,
selectedProjectId: input.selectedProjectId,
scopeVersion: input.scopeVersion,
createdAt: existing?.createdAt ?? timestamp,
updatedAt: timestamp,
};

if (existing) {
this.db.update(authSessions).set(row).where(byId).run();
} else {
this.db.insert(authSessions).values(row).run();
}

const saved = this.getCurrent();
if (!saved) {
throw new Error("Failed to persist current auth session");
}
return saved;
}

clearCurrent(): void {
this.db.delete(authSessions).where(byId).run();
}
}
12 changes: 11 additions & 1 deletion apps/code/src/main/db/schema.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { sql } from "drizzle-orm";
import { index, sqliteTable, text } from "drizzle-orm/sqlite-core";
import { index, integer, sqliteTable, text } from "drizzle-orm/sqlite-core";

const id = () =>
text()
Expand Down Expand Up @@ -76,3 +76,13 @@ export const suspensions = sqliteTable("suspensions", {
createdAt: createdAt(),
updatedAt: updatedAt(),
});

export const authSessions = sqliteTable("auth_sessions", {
id: integer().primaryKey(),
refreshTokenEncrypted: text().notNull(),
cloudRegion: text({ enum: ["us", "eu", "dev"] }).notNull(),
selectedProjectId: integer(),
scopeVersion: integer().notNull(),
createdAt: createdAt(),
updatedAt: updatedAt(),
});
4 changes: 4 additions & 0 deletions apps/code/src/main/di/container.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import "reflect-metadata";

import { Container } from "inversify";
import { ArchiveRepository } from "../db/repositories/archive-repository";
import { AuthSessionRepository } from "../db/repositories/auth-session-repository";
import { RepositoryRepository } from "../db/repositories/repository-repository";
import { SuspensionRepositoryImpl } from "../db/repositories/suspension-repository";
import { WorkspaceRepository } from "../db/repositories/workspace-repository";
Expand All @@ -10,6 +11,7 @@ import { DatabaseService } from "../db/service";
import { AgentService } from "../services/agent/service";
import { AppLifecycleService } from "../services/app-lifecycle/service";
import { ArchiveService } from "../services/archive/service";
import { AuthService } from "../services/auth/service";
import { AuthProxyService } from "../services/auth-proxy/service";
import { CloudTaskService } from "../services/cloud-task/service";
import { ConnectivityService } from "../services/connectivity/service";
Expand Down Expand Up @@ -49,12 +51,14 @@ export const container = new Container({
});

container.bind(MAIN_TOKENS.DatabaseService).to(DatabaseService);
container.bind(MAIN_TOKENS.AuthSessionRepository).to(AuthSessionRepository);
container.bind(MAIN_TOKENS.RepositoryRepository).to(RepositoryRepository);
container.bind(MAIN_TOKENS.WorkspaceRepository).to(WorkspaceRepository);
container.bind(MAIN_TOKENS.WorktreeRepository).to(WorktreeRepository);
container.bind(MAIN_TOKENS.ArchiveRepository).to(ArchiveRepository);
container.bind(MAIN_TOKENS.SuspensionRepository).to(SuspensionRepositoryImpl);
container.bind(MAIN_TOKENS.AgentService).to(AgentService);
container.bind(MAIN_TOKENS.AuthService).to(AuthService);
container.bind(MAIN_TOKENS.AuthProxyService).to(AuthProxyService);
container.bind(MAIN_TOKENS.ArchiveService).to(ArchiveService);
container.bind(MAIN_TOKENS.SuspensionService).to(SuspensionService);
Expand Down
2 changes: 2 additions & 0 deletions apps/code/src/main/di/tokens.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ export const MAIN_TOKENS = Object.freeze({

// Database
DatabaseService: Symbol.for("Main.DatabaseService"),
AuthSessionRepository: Symbol.for("Main.AuthSessionRepository"),
RepositoryRepository: Symbol.for("Main.RepositoryRepository"),
WorkspaceRepository: Symbol.for("Main.WorkspaceRepository"),
WorktreeRepository: Symbol.for("Main.WorktreeRepository"),
Expand All @@ -18,6 +19,7 @@ export const MAIN_TOKENS = Object.freeze({

// Services
AgentService: Symbol.for("Main.AgentService"),
AuthService: Symbol.for("Main.AuthService"),
AuthProxyService: Symbol.for("Main.AuthProxyService"),
ArchiveService: Symbol.for("Main.ArchiveService"),
SuspensionService: Symbol.for("Main.SuspensionService"),
Expand Down
11 changes: 8 additions & 3 deletions apps/code/src/main/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import { container } from "./di/container";
import { MAIN_TOKENS } from "./di/tokens";
import { registerMcpSandboxProtocol } from "./protocols/mcp-sandbox";
import type { AppLifecycleService } from "./services/app-lifecycle/service";
import type { AuthService } from "./services/auth/service";
import type { ExternalAppsService } from "./services/external-apps/service";
import type { NotificationService } from "./services/notification/service";
import type { OAuthService } from "./services/oauth/service";
Expand All @@ -35,15 +36,18 @@ if (!gotTheLock) {
process.exit(0);
}

function initializeServices(): void {
async function initializeServices(): Promise<void> {
container.get<DatabaseService>(MAIN_TOKENS.DatabaseService);
container.get<OAuthService>(MAIN_TOKENS.OAuthService);
const authService = container.get<AuthService>(MAIN_TOKENS.AuthService);
container.get<NotificationService>(MAIN_TOKENS.NotificationService);
container.get<UpdatesService>(MAIN_TOKENS.UpdatesService);
container.get<TaskLinkService>(MAIN_TOKENS.TaskLinkService);
container.get<ExternalAppsService>(MAIN_TOKENS.ExternalAppsService);
container.get<PosthogPluginService>(MAIN_TOKENS.PosthogPluginService);

await authService.initialize();

// Initialize workspace branch watcher for live branch rename detection
const workspaceService = container.get<WorkspaceService>(
MAIN_TOKENS.WorkspaceService,
Expand All @@ -69,7 +73,7 @@ registerDeepLinkHandlers();
// Initialize PostHog analytics
initializePostHog();

app.whenReady().then(() => {
app.whenReady().then(async () => {
const commit = __BUILD_COMMIT__ ?? "dev";
const buildDate = __BUILD_DATE__ ?? "dev";
log.info(
Expand All @@ -87,8 +91,9 @@ app.whenReady().then(() => {
ensureClaudeConfigDir();
registerMcpSandboxProtocol();
createWindow();
initializeServices();
await initializeServices();
initializeDeepLinks();
await initializeServices();
powerMonitor.on("suspend", () => {
log.info("System entering sleep");
});
Expand Down
51 changes: 51 additions & 0 deletions apps/code/src/main/services/auth/schemas.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import { z } from "zod";
import { cloudRegion, type oAuthTokenResponse } from "../oauth/schemas";

export const authStatusSchema = z.enum(["anonymous", "authenticated"]);
export type AuthStatus = z.infer<typeof authStatusSchema>;

export const authStateSchema = z.object({
status: authStatusSchema,
bootstrapComplete: z.boolean(),
cloudRegion: cloudRegion.nullable(),
projectId: z.number().nullable(),
availableProjectIds: z.array(z.number()),
availableOrgIds: z.array(z.string()),
hasCodeAccess: z.boolean().nullable(),
needsScopeReauth: z.boolean(),
});
export type AuthState = z.infer<typeof authStateSchema>;

export const loginInput = z.object({
region: cloudRegion,
});
export type LoginInput = z.infer<typeof loginInput>;

export const loginOutput = z.object({
state: authStateSchema,
});
export type LoginOutput = z.infer<typeof loginOutput>;

export const redeemInviteCodeInput = z.object({
code: z.string().min(1),
});

export const selectProjectInput = z.object({
projectId: z.number(),
});

export const validAccessTokenOutput = z.object({
accessToken: z.string(),
apiHost: z.string(),
});
export type ValidAccessTokenOutput = z.infer<typeof validAccessTokenOutput>;

export const AuthServiceEvent = {
StateChanged: "state-changed",
} as const;

export interface AuthServiceEvents {
[AuthServiceEvent.StateChanged]: AuthState;
}

export type AuthTokenResponse = z.infer<typeof oAuthTokenResponse>;
Loading
Loading