Skip to content

Commit 8fd8cf1

Browse files
authored
backend/assistants_web: Move Default Agent usage to backend (#842)
* Add router for default agent * Working app with default agent * Lint * Working app and tests * add back global exc handler * prettify coral * PR review
1 parent ccf41ea commit 8fd8cf1

File tree

34 files changed

+235
-223
lines changed

34 files changed

+235
-223
lines changed

src/backend/config/default_agent.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import datetime
2+
3+
from backend.config.deployments import ModelDeploymentName
4+
from backend.config.tools import Tool
5+
from backend.schemas.agent import AgentPublic
6+
7+
DEFAULT_AGENT_ID = "default"
8+
DEFAULT_DEPLOYMENT = ModelDeploymentName.CoherePlatform
9+
DEFAULT_MODEL = "command-r-plus"
10+
11+
def get_default_agent() -> AgentPublic:
12+
return AgentPublic(
13+
id=DEFAULT_AGENT_ID,
14+
name='Command R+',
15+
description='Ask questions and get answers based on your files.',
16+
created_at=datetime.datetime.now(),
17+
updated_at=datetime.datetime.now(),
18+
preamble="",
19+
version=1,
20+
temperature=0.3,
21+
tools=[
22+
Tool.Read_File.value.ID,
23+
Tool.Search_File.value.ID,
24+
Tool.Python_Interpreter.value.ID,
25+
Tool.Hybrid_Web_Search.value.ID,
26+
],
27+
tools_metadata=[],
28+
deployment=DEFAULT_DEPLOYMENT,
29+
model=DEFAULT_MODEL,
30+
user_id='',
31+
organization_id=None,
32+
is_private=False,
33+
)

src/backend/main.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,6 @@ def create_app():
9191

9292
app = create_app()
9393

94-
9594
@app.exception_handler(Exception)
9695
async def validation_exception_handler(request: Request, exc: Exception):
9796
ctx = get_context(request)
@@ -115,7 +114,6 @@ async def validation_exception_handler(request: Request, exc: Exception):
115114
},
116115
)
117116

118-
119117
@app.on_event("startup")
120118
async def startup_event():
121119
"""

src/backend/routers/agent.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from fastapi import File as RequestFile
66
from fastapi import UploadFile as FastAPIUploadFile
77

8+
from backend.config.default_agent import DEFAULT_AGENT_ID, get_default_agent
89
from backend.config.routers import RouterName
910
from backend.crud import agent as agent_crud
1011
from backend.crud import agent_tool_metadata as agent_tool_metadata_crud
@@ -71,9 +72,9 @@
7172
],
7273
)
7374
async def create_agent(
74-
session: DBSessionDep,
75-
agent: CreateAgentRequest,
76-
ctx: Context = Depends(get_context),
75+
session: DBSessionDep,
76+
agent: CreateAgentRequest,
77+
ctx: Context = Depends(get_context),
7778
) -> AgentPublic:
7879
"""
7980
Create an agent.
@@ -127,13 +128,13 @@ async def create_agent(
127128

128129
@router.get("", response_model=list[AgentPublic])
129130
async def list_agents(
130-
*,
131-
offset: int = 0,
132-
limit: int = 100,
133-
session: DBSessionDep,
134-
visibility: AgentVisibility = AgentVisibility.ALL,
135-
organization_id: Optional[str] = None,
136-
ctx: Context = Depends(get_context),
131+
*,
132+
offset: int = 0,
133+
limit: int = 100,
134+
session: DBSessionDep,
135+
visibility: AgentVisibility = AgentVisibility.ALL,
136+
organization_id: Optional[str] = None,
137+
ctx: Context = Depends(get_context),
137138
) -> list[AgentPublic]:
138139
"""
139140
List all agents.
@@ -163,6 +164,8 @@ async def list_agents(
163164
visibility=visibility,
164165
organization_id=organization_id,
165166
)
167+
# Tradeoff: This appends the default Agent regardless of pagination
168+
agents.append(get_default_agent())
166169
return agents
167170
except Exception as e:
168171
logger.exception(event=e)
@@ -171,8 +174,8 @@ async def list_agents(
171174

172175
@router.get("/{agent_id}", response_model=AgentPublic)
173176
async def get_agent_by_id(
174-
agent_id: str, session: DBSessionDep, ctx: Context = Depends(get_context)
175-
) -> Agent:
177+
agent_id: str, session: DBSessionDep, ctx: Context = Depends(get_context)
178+
) -> AgentPublic:
176179
"""
177180
Args:
178181
agent_id (str): Agent ID.
@@ -189,7 +192,11 @@ async def get_agent_by_id(
189192
agent = None
190193

191194
try:
192-
agent = agent_crud.get_agent_by_id(session, agent_id, user_id)
195+
# Intentionally not adding Default Agent to DB so it's more flexible
196+
if agent_id == DEFAULT_AGENT_ID:
197+
agent = get_default_agent()
198+
else:
199+
agent = agent_crud.get_agent_by_id(session, agent_id, user_id)
193200
except Exception as e:
194201
raise HTTPException(status_code=500, detail=str(e))
195202

src/backend/tests/unit/routers/test_agent.py

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from fastapi.testclient import TestClient
55
from sqlalchemy.orm import Session
66

7+
from backend.config.default_agent import DEFAULT_AGENT_ID
78
from backend.config.deployments import ModelDeploymentName
89
from backend.config.tools import Tool
910
from backend.crud import deployment as deployment_crud
@@ -17,6 +18,9 @@
1718
and os.environ.get("COHERE_API_KEY") != ""
1819
)
1920

21+
def filter_default_agent(agents: list) -> list:
22+
return [agent for agent in agents if agent.get("id") != DEFAULT_AGENT_ID]
23+
2024
def test_create_agent_missing_name(
2125
session_client: TestClient, session: Session, user
2226
) -> None:
@@ -159,35 +163,34 @@ def test_create_existing_agent(
159163
assert response.json() == {"detail": "Agent test agent already exists."}
160164

161165

162-
def test_list_agents_empty(session_client: TestClient, session: Session) -> None:
163-
# Delete default agent
164-
session.query(Agent).delete()
166+
def test_list_agents_empty_returns_default_agent(session_client: TestClient, session: Session) -> None:
165167
response = session_client.get("/v1/agents", headers={"User-Id": "123"})
166168
assert response.status_code == 200
167169
response_agents = response.json()
168-
assert len(response_agents) == 0
170+
# Returns default agent
171+
assert len(response_agents) == 1
169172

170173

171174
def test_list_agents(session_client: TestClient, session: Session, user) -> None:
172-
session.query(Agent).delete()
173-
for _ in range(3):
175+
num_agents = 3
176+
for _ in range(num_agents):
174177
_ = get_factory("Agent", session).create(user=user)
175178

176179
response = session_client.get("/v1/agents", headers={"User-Id": user.id})
177180
assert response.status_code == 200
178-
response_agents = response.json()
179-
assert len(response_agents) == 3
181+
response_agents = filter_default_agent(response.json())
182+
assert len(response_agents) == num_agents
180183

181184

182185
def test_list_organization_agents(
183186
session_client: TestClient,
184187
session: Session,
185188
user,
186189
) -> None:
187-
session.query(Agent).delete()
190+
num_agents = 3
188191
organization = get_factory("Organization", session).create()
189192
organization1 = get_factory("Organization", session).create()
190-
for i in range(3):
193+
for i in range(num_agents):
191194
_ = get_factory("Agent", session).create(
192195
user=user,
193196
organization_id=organization.id,
@@ -201,9 +204,9 @@ def test_list_organization_agents(
201204
"/v1/agents", headers={"User-Id": user.id, "Organization-Id": organization.id}
202205
)
203206
assert response.status_code == 200
204-
response_agents = response.json()
207+
response_agents = filter_default_agent(response.json())
205208
agents = sorted(response_agents, key=lambda x: x["name"])
206-
for i in range(3):
209+
for i in range(num_agents):
207210
assert agents[i]["name"] == f"agent-{i}-{organization.id}"
208211

209212

@@ -212,10 +215,10 @@ def test_list_organization_agents_query_param(
212215
session: Session,
213216
user,
214217
) -> None:
215-
session.query(Agent).delete()
218+
num_agents = 3
216219
organization = get_factory("Organization", session).create()
217220
organization1 = get_factory("Organization", session).create()
218-
for i in range(3):
221+
for i in range(num_agents):
219222
_ = get_factory("Agent", session).create(
220223
user=user, organization_id=organization.id
221224
)
@@ -230,9 +233,9 @@ def test_list_organization_agents_query_param(
230233
headers={"User-Id": user.id, "Organization-Id": organization.id},
231234
)
232235
assert response.status_code == 200
233-
response_agents = response.json()
236+
response_agents = filter_default_agent(response.json())
234237
agents = sorted(response_agents, key=lambda x: x["name"])
235-
for i in range(3):
238+
for i in range(num_agents):
236239
assert agents[i]["name"] == f"agent-{i}-{organization1.id}"
237240

238241

@@ -263,7 +266,7 @@ def test_list_private_agents(
263266
)
264267

265268
assert response.status_code == 200
266-
response_agents = response.json()
269+
response_agents = filter_default_agent(response.json())
267270

268271
# Only the agents created by user should be returned
269272
assert len(response_agents) == 3
@@ -282,7 +285,7 @@ def test_list_public_agents(session_client: TestClient, session: Session, user)
282285
)
283286

284287
assert response.status_code == 200
285-
response_agents = response.json()
288+
response_agents = filter_default_agent(response.json())
286289

287290
# Only the agents created by user should be returned
288291
assert len(response_agents) == 2
@@ -319,14 +322,14 @@ def test_list_agents_with_pagination(
319322
"/v1/agents?limit=3&offset=2", headers={"User-Id": user.id}
320323
)
321324
assert response.status_code == 200
322-
response_agents = response.json()
325+
response_agents = filter_default_agent(response.json())
323326
assert len(response_agents) == 3
324327

325328
response = session_client.get(
326329
"/v1/agents?limit=2&offset=4", headers={"User-Id": user.id}
327330
)
328331
assert response.status_code == 200
329-
response_agents = response.json()
332+
response_agents = filter_default_agent(response.json())
330333
assert len(response_agents) == 1
331334

332335

src/interfaces/assistants_web/src/app/(main)/(chat)/c/[conversationId]/page.tsx

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import { HydrationBoundary, QueryClient, dehydrate } from '@tanstack/react-query
22
import { NextPage } from 'next';
33

44
import Chat from '@/app/(main)/(chat)/Chat';
5-
import { BASE_AGENT } from '@/constants';
5+
import { DEFAULT_AGENT_ID } from '@/constants';
66
import { getCohereServerClient } from '@/server/cohereServerClient';
77

88
type Props = {
@@ -23,8 +23,8 @@ const Page: NextPage<Props> = async ({ params }) => {
2323
cohereServerClient.getConversation({ conversationId: params.conversationId }),
2424
}),
2525
queryClient.prefetchQuery({
26-
queryKey: ['agent', null],
27-
queryFn: () => BASE_AGENT,
26+
queryKey: ['agent', DEFAULT_AGENT_ID],
27+
queryFn: () => cohereServerClient.getDefaultAgent(),
2828
}),
2929
]);
3030

src/interfaces/assistants_web/src/app/(main)/(chat)/page.tsx

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,16 @@ import { HydrationBoundary, QueryClient, dehydrate } from '@tanstack/react-query
22
import { NextPage } from 'next';
33

44
import Chat from '@/app/(main)/(chat)/Chat';
5-
import { BASE_AGENT } from '@/constants';
5+
import { DEFAULT_AGENT_ID } from '@/constants';
6+
import { getCohereServerClient } from '@/server/cohereServerClient';
67

78
const Page: NextPage = async () => {
89
const queryClient = new QueryClient();
10+
const cohereServerClient = getCohereServerClient();
911

1012
await queryClient.prefetchQuery({
11-
queryKey: ['agent', null],
12-
queryFn: () => BASE_AGENT,
13+
queryKey: ['agent', DEFAULT_AGENT_ID],
14+
queryFn: () => cohereServerClient.getDefaultAgent(),
1315
});
1416

1517
return (

src/interfaces/assistants_web/src/app/(main)/discover/DiscoverAgentCard.tsx

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import { DeleteAgent } from '@/components/Modals/DeleteAgent';
77
import { CoralLogo, KebabMenu, Text } from '@/components/UI';
88
import { useContextStore } from '@/context';
99
import { useBrandedColors, useSession } from '@/hooks';
10-
import { checkIsBaseAgent, cn } from '@/utils';
10+
import { checkIsDefaultAgent, cn } from '@/utils';
1111

1212
type Props = {
1313
agent?: AgentPublic;
@@ -17,11 +17,11 @@ type Props = {
1717
* @description renders a card for an agent with the agent's name, description
1818
*/
1919
export const DiscoverAgentCard: React.FC<Props> = ({ agent }) => {
20-
const isBaseAgent = checkIsBaseAgent(agent);
20+
const isDefaultAgent = checkIsDefaultAgent(agent);
2121
const { bg, contrastText, contrastFill } = useBrandedColors(agent?.id);
2222
const session = useSession();
2323
const isCreator = agent?.user_id === session.userId;
24-
const createdBy = isBaseAgent ? 'COHERE' : isCreator ? 'YOU' : 'TEAM';
24+
const createdBy = isDefaultAgent ? 'COHERE' : isCreator ? 'YOU' : 'TEAM';
2525

2626
const { open, close } = useContextStore();
2727

@@ -36,7 +36,7 @@ export const DiscoverAgentCard: React.FC<Props> = ({ agent }) => {
3636
return (
3737
<Link
3838
className="flex overflow-x-hidden rounded-lg border border-volcanic-800 bg-volcanic-950 p-4 transition-colors duration-300 hover:bg-marble-950 dark:border-volcanic-300 dark:bg-volcanic-150 dark:hover:bg-volcanic-100"
39-
href={isBaseAgent ? '/' : `/a/${agent?.id}`}
39+
href={isDefaultAgent ? '/' : `/a/${agent?.id}`}
4040
>
4141
<div className="flex h-full flex-grow flex-col items-start gap-y-2 overflow-x-hidden">
4242
<div className="flex w-full items-center gap-x-2">
@@ -46,7 +46,7 @@ export const DiscoverAgentCard: React.FC<Props> = ({ agent }) => {
4646
bg
4747
)}
4848
>
49-
{isBaseAgent ? (
49+
{isDefaultAgent ? (
5050
<CoralLogo className={contrastFill} />
5151
) : (
5252
<Text className={cn('uppercase', contrastText)} styleAs="p-lg">

src/interfaces/assistants_web/src/app/(main)/layout.tsx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import { redirect } from 'next/navigation';
66
import { Swipeable } from '@/components/Global';
77
import { HotKeys } from '@/components/HotKeys';
88
import { SideNavPanel } from '@/components/SideNavPanel';
9-
import { COOKIE_KEYS, DEFAULT_AGENT_TOOLS } from '@/constants';
9+
import { BACKGROUND_TOOLS, COOKIE_KEYS } from '@/constants';
1010
import { getCohereServerClient } from '@/server/cohereServerClient';
1111

1212
const MainLayout: NextPage<React.PropsWithChildren> = async ({ children }) => {
@@ -35,7 +35,7 @@ const MainLayout: NextPage<React.PropsWithChildren> = async ({ children }) => {
3535
queryKey: ['tools'],
3636
queryFn: async () => {
3737
const tools = await cohereServerClient.listTools({});
38-
return tools.filter((tool) => !DEFAULT_AGENT_TOOLS.includes(tool.name ?? ''));
38+
return tools.filter((tool) => !BACKGROUND_TOOLS.includes(tool.name ?? ''));
3939
},
4040
}),
4141
queryClient.prefetchQuery({

src/interfaces/assistants_web/src/app/(main)/new/CreateAgent.tsx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ import { AgentSettingsFields, AgentSettingsForm } from '@/components/AgentSettin
99
import { MobileHeader } from '@/components/Global';
1010
import { Button, Icon, Text } from '@/components/UI';
1111
import {
12+
BACKGROUND_TOOLS,
1213
DEFAULT_AGENT_MODEL,
13-
DEFAULT_AGENT_TOOLS,
1414
DEFAULT_PREAMBLE,
1515
DEPLOYMENT_COHERE_PLATFORM,
1616
} from '@/constants';
@@ -23,7 +23,7 @@ const DEFAULT_FIELD_VALUES = {
2323
preamble: DEFAULT_PREAMBLE,
2424
deployment: DEPLOYMENT_COHERE_PLATFORM,
2525
model: DEFAULT_AGENT_MODEL,
26-
tools: DEFAULT_AGENT_TOOLS,
26+
tools: BACKGROUND_TOOLS,
2727
is_private: false,
2828
};
2929
/**

src/interfaces/assistants_web/src/cohere-client/client.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import {
1717
UpdateConversationRequest,
1818
UpdateDeploymentEnv,
1919
} from '@/cohere-client';
20+
import { DEFAULT_AGENT_ID } from '@/constants';
2021

2122
import { mapToChatRequest } from './mappings';
2223

@@ -285,6 +286,10 @@ export class CohereClient {
285286
// this.cohereService.default.oidcAuthorizeV1OidcAuthGet();
286287
}
287288

289+
public getDefaultAgent() {
290+
return this.cohereService.default.getAgentByIdV1AgentsAgentIdGet({ agentId: DEFAULT_AGENT_ID });
291+
}
292+
288293
public getAgent(agentId: string) {
289294
return this.cohereService.default.getAgentByIdV1AgentsAgentIdGet({ agentId });
290295
}

0 commit comments

Comments
 (0)