Skip to content

Commit

Permalink
feat: refactor to support module additions
Browse files Browse the repository at this point in the history
  • Loading branch information
ntindle committed May 17, 2024
1 parent 7ef88a2 commit ed5b719
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 51 deletions.
103 changes: 99 additions & 4 deletions codex/api_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from prisma.models import Specification
from pydantic import BaseModel, Field

from codex.common.parse_prisma import parse_prisma_schema

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -151,24 +153,65 @@ class SpecificationUpdate(prisma.models.Specification, BaseModel):
apiRouteSpecs: List[APIRouteSpecModel] = []


class DatabaseEnums(BaseModel):
name: str
description: str
values: list[str]
definition: str

def __str__(self):
return f"**Enum: {self.name}**\n\n**Values**:\n{', '.join(self.values)}\n"


class DatabaseTable(BaseModel):
name: str | None = None
description: str
definition: str # prisma model for a table

def __str__(self):
return f"**Table: {self.name}**\n\n\n\n**Definition**:\n```\n{self.definition}\n```\n"


class DatabaseSchema(BaseModel):
name: str # name of the database schema
description: str # context on what the database schema is
tables: List[DatabaseTable] # list of tables in the database schema
enums: List[DatabaseEnums]

def __str__(self):
tables_str = "\n".join(str(table) for table in self.tables)
enum_str = "\n".join(str(enum) for enum in self.enums)
return f"## {self.name}\n**Description**: {self.description}\n**Tables**:\n{tables_str}\n**Enums**:\n{enum_str}\n"


class ModuleWrapper(BaseModel):
id: str
name: str
description: str
interactions: str
apiRouteSpecs: List[APIRouteSpecModel] = []


class SpecificationResponse(BaseModel):
id: str
createdAt: datetime
name: str
context: str
apiRouteSpecs: List[APIRouteSpecModel] = []
modules: List[ModuleWrapper] = []
databaseSchema: Optional[DatabaseSchema] = None

@staticmethod
def from_specification(specification: Specification) -> "SpecificationResponse":
logger.debug(specification.model_dump_json())
routes = []
module_out = []
modules: list[prisma.models.Module] | None = (
specification.Modules if specification.Modules else None
)
if modules is None:
raise ValueError("No routes found for the specification")
for module in modules:
if module.ApiRouteSpecs:
routes = []
for route in module.ApiRouteSpecs:
routes.append(
APIRouteSpecModel(
Expand Down Expand Up @@ -215,13 +258,65 @@ def from_specification(specification: Specification) -> "SpecificationResponse":
else None,
)
)

module_out.append(
ModuleWrapper(
id=module.id,
apiRouteSpecs=routes,
name=module.name,
description=module.description,
interactions=module.interactions,
)
)
else:
module_out.append(
ModuleWrapper(
id=module.id,
name=module.name,
description=module.description,
interactions=module.interactions,
)
)
db_schema = None
if specification.DatabaseSchema:

def convert_to_table(table: prisma.models.DatabaseTable) -> DatabaseTable:
return DatabaseTable(
name=table.name or "ERROR: Unknown Table Name",
description=table.description,
definition=table.definition,
)

def convert_to_enum(table: prisma.models.DatabaseTable) -> DatabaseEnums:
return DatabaseEnums(
name=table.name or "ERROR: Unknown ENUM Name",
description=table.description,
values=parse_prisma_schema(table.definition)
.enums[table.name or "ERROR: Unknown ENUM Name"]
.values,
definition=table.definition,
)

db_schema = DatabaseSchema(
name=specification.DatabaseSchema.name or "Database Schema",
tables=[
convert_to_table(table)
for table in specification.DatabaseSchema.DatabaseTables or []
if not table.isEnum
],
enums=[
convert_to_enum(table)
for table in specification.DatabaseSchema.DatabaseTables or []
if table.isEnum
],
description=specification.DatabaseSchema.description,
)
ret_obj = SpecificationResponse(
id=specification.id,
createdAt=specification.createdAt,
name="",
context="",
apiRouteSpecs=routes,
modules=module_out,
databaseSchema=db_schema,
)

return ret_obj
Expand Down
9 changes: 2 additions & 7 deletions codex/requirements/blocks/ai_database.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging

from codex.api_model import DatabaseEnums, DatabaseSchema, DatabaseTable
from codex.common.ai_block import (
AIBlock,
Identifiers,
Expand All @@ -10,13 +11,7 @@
from codex.common.exec_external_tool import OutputType, exec_external_on_contents
from codex.common.logging_config import setup_logging
from codex.common.parse_prisma import parse_prisma_schema
from codex.requirements.model import (
DatabaseEnums,
DatabaseSchema,
DatabaseTable,
DBResponse,
PreAnswer,
)
from codex.requirements.model import DBResponse, PreAnswer

logger = logging.getLogger(__name__)

Expand Down
33 changes: 1 addition & 32 deletions codex/requirements/model.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import logging
from enum import Enum
from typing import List

from pydantic import BaseModel

from codex.api_model import DatabaseSchema
from codex.common.model import ObjectTypeModel as ObjectTypeE

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -99,37 +99,6 @@ def get_task_description(task):
raise NotImplementedError(f"Example Task {task.value} not implemented")


class DatabaseEnums(BaseModel):
name: str
description: str
values: list[str]
definition: str

def __str__(self):
return f"**Enum: {self.name}**\n\n**Values**:\n{', '.join(self.values)}\n"


class DatabaseTable(BaseModel):
name: str | None = None
description: str
definition: str # prisma model for a table

def __str__(self):
return f"**Table: {self.name}**\n\n\n\n**Definition**:\n```\n{self.definition}\n```\n"


class DatabaseSchema(BaseModel):
name: str # name of the database schema
description: str # context on what the database schema is
tables: List[DatabaseTable] # list of tables in the database schema
enums: List[DatabaseEnums]

def __str__(self):
tables_str = "\n".join(str(table) for table in self.tables)
enum_str = "\n".join(str(enum) for enum in self.enums)
return f"## {self.name}\n**Description**: {self.description}\n**Tables**:\n{tables_str}\n**Enums**:\n{enum_str}\n"


class APIEndpointWrapper(BaseModel):
request_model: ObjectTypeE
response_model: ObjectTypeE
Expand Down
15 changes: 7 additions & 8 deletions codex/tests/gen_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@
from prisma.enums import AccessLevel, HTTPVerb

from codex.api import create_app
from codex.api_model import ApplicationCreate
from codex.api_model import (
ApplicationCreate,
DatabaseEnums,
DatabaseSchema,
DatabaseTable,
)
from codex.app import db_client
from codex.common import ai_block
from codex.common.ai_block import LLMFailure
Expand All @@ -19,13 +24,7 @@
from codex.develop.database import get_compiled_code
from codex.requirements.agent import APIRouteSpec, Module, SpecHolder
from codex.requirements.database import create_specification
from codex.requirements.model import (
DatabaseEnums,
DatabaseSchema,
DatabaseTable,
DBResponse,
PreAnswer,
)
from codex.requirements.model import DBResponse, PreAnswer

is_connected = False
setup_logging()
Expand Down

0 comments on commit ed5b719

Please sign in to comment.