Skip to content

Commit

Permalink
feat: model/base-url settings for AI completion, bring out of experim…
Browse files Browse the repository at this point in the history
…ental (marimo-team#1049)

* feat: model/base-url settings for AI completion, bring out of experimental

* fix

* improve typings for MarimoConfig

- switch to TypedDict, total=True
- use Typing.NotRequired to denote keys that don't need to be included

* fix test

---------

Co-authored-by: Akshay Agrawal <akshay@marimo.io>
  • Loading branch information
2 people authored and Benni-Math committed Apr 16, 2024
1 parent 4710856 commit 45f0307
Show file tree
Hide file tree
Showing 16 changed files with 539 additions and 76 deletions.
21 changes: 13 additions & 8 deletions docs/guides/ai_completion.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,13 @@ This feature is currently experimental and is not enabled by default. To enable
1. You need add the following to your `~/.marimo.toml`:

```toml
[experimental]
ai = true
```

2. Add your OpenAI API key to your environment:

```bash
export OPENAI_API_KEY=your-api-key
[ai.open_ai]
# Get your API key from https://platform.openai.com/account/api-keys
api_key = "sk-..."
# Choose a model, we recommend "gpt-3.5-turbo"
model = "gpt-3.5-turbo"
# Change the base_url if you are using a different OpenAI-compatible API
base_url = "https://api.openai.com"
```

Once enabled, you can use AI completion by pressing `Ctrl/Cmd-Shift-e` in a cell. This will open an input to modify the cell using AI.
Expand All @@ -44,3 +43,9 @@ Once enabled, you can use AI completion by pressing `Ctrl/Cmd-Shift-e` in a cell
<figcaption>Use AI to modify a cell by pressing `Ctrl/Cmd-Shift-e`.</figcaption>
</figure>
</div>

### Using other AI providers

marimo supports OpenAI's GPT-3.5 API by default. If your provider is compatible with OpenAI's API, you can use it by changing the `base_url` in the configuration.

For other providers not compatible with OpenAI's API, please submit a [feature request](https://github.com/marimo-team/marimo/issues) or "thumbs up" an existing one.
59 changes: 59 additions & 0 deletions frontend/src/components/app-config/user-config-form.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import { SettingTitle, SettingDescription, SettingSubtitle } from "./common";
import { THEMES } from "@/theme/useTheme";
import { isPyodide } from "@/core/pyodide/utils";
import { PackageManagerNames } from "../../core/config/config-schema";
import { Kbd } from "../ui/kbd";

export const UserConfigForm: React.FC = () => {
const [config, setConfig] = useUserConfig();
Expand Down Expand Up @@ -335,6 +336,64 @@ export const UserConfigForm: React.FC = () => {
)}
/>
</div>
<div className="flex flex-col gap-3">
<SettingSubtitle>AI Assist</SettingSubtitle>
<p className="text-sm text-muted-secondary">
You will need to store an API key in your{" "}
<Kbd className="inline">~/.marimo.toml</Kbd> file. See the{" "}
<a
className="text-link hover:underline"
href="https://docs.marimo.io/guides/ai_completion.html"
target="_blank"
rel="noreferrer"
>
documentation
</a>{" "}
for more information.
</p>
<FormField
control={form.control}
disabled={isWasm}
name="ai.open_ai.base_url"
render={({ field }) => (
<FormItem className="mb-2">
<FormLabel>Base URL</FormLabel>
<FormControl>
<Input
data-testid="code-editor-font-size-input"
className="m-0 inline-flex"
{...field}
value={field.value}
placeholder="https://api.openai.com"
onChange={(e) => field.onChange(e.target.value)}
/>
</FormControl>
<FormMessage />
</FormItem>
)}
/>
<FormField
control={form.control}
disabled={isWasm}
name="ai.open_ai.model"
render={({ field: { value, onChange, ...field } }) => (
<FormItem className="mb-2">
<FormLabel>Model</FormLabel>
<FormControl>
<Input
data-testid="code-editor-font-size-input"
className="m-0 inline-flex"
{...field}
defaultValue={value}
placeholder="gpt-3.5-turbo"
onBlur={(e) => onChange(e.target.value)}
/>
</FormControl>
<FormMessage />
</FormItem>
)}
/>
</div>
<div className="flex flex-col gap-3">
<SettingSubtitle>GitHub Copilot</SettingSubtitle>
<FormField
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ import { saveCellConfig } from "@/core/network/requests";
import { EditorView } from "@codemirror/view";
import { useRunCell } from "../cell/useRunCells";
import { NameCellInput } from "./name-cell-input";
import { getFeatureFlag } from "@/core/config/feature-flag";
import { useSetAtom } from "jotai";
import { aiCompletionCellAtom } from "@/core/ai/state";
import { useImperativeModal } from "@/components/modal/ImperativeModal";
Expand All @@ -45,6 +44,7 @@ import {
} from "@/components/ui/dialog";
import { Label } from "@/components/ui/label";
import { MarkdownIcon, PythonIcon } from "../cell/code/icons";
import { useUserConfig } from "@/core/config/config";

export interface CellActionButtonProps
extends Pick<CellData, "name" | "config"> {
Expand Down Expand Up @@ -72,6 +72,7 @@ export function useCellActionButtons({ cell }: Props) {
const runCell = useRunCell(cell?.cellId);
const { openModal } = useImperativeModal();
const setAiCompletionCell = useSetAtom(aiCompletionCellAtom);
const [userConfig] = useUserConfig();
if (!cell) {
return [];
}
Expand Down Expand Up @@ -159,7 +160,7 @@ export function useCellActionButtons({ cell }: Props) {
{
icon: <SparklesIcon size={13} strokeWidth={1.5} />,
label: "AI completion",
hidden: !getFeatureFlag("ai"),
hidden: !userConfig.ai.open_ai?.api_key,
handle: () => {
setAiCompletionCell((current) =>
current === cellId ? null : cellId,
Expand Down
2 changes: 2 additions & 0 deletions frontend/src/core/config/__tests__/config-schema.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ test("default UserConfig - empty", () => {
const defaultConfig = UserConfigSchema.parse({});
expect(defaultConfig).toMatchInlineSnapshot(`
{
"ai": {},
"completion": {
"activate_on_typing": true,
"copilot": false,
Expand Down Expand Up @@ -58,6 +59,7 @@ test("default UserConfig - one level", () => {
});
expect(defaultConfig).toMatchInlineSnapshot(`
{
"ai": {},
"completion": {
"activate_on_typing": true,
"copilot": false,
Expand Down
13 changes: 11 additions & 2 deletions frontend/src/core/config/config-schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,19 @@ export const UserConfigSchema = z
manager: z.enum(PackageManagerNames).default("pip"),
})
.default({ manager: "pip" }),
experimental: z
ai: z
.object({
ai: z.boolean().optional(),
open_ai: z
.object({
api_key: z.string().optional(),
base_url: z.string().optional(),
model: z.string().optional(),
})
.optional(),
})
.default({}),
experimental: z
.object({})
// Pass through so that we don't remove any extra keys that the user has added.
.passthrough()
.default({}),
Expand Down
7 changes: 2 additions & 5 deletions frontend/src/core/config/feature-flag.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,10 @@ import { getUserConfig } from "./config";

// eslint-disable-next-line @typescript-eslint/no-empty-interface
export interface ExperimentalFeatures {
// None yet
ai: boolean;
// Add new feature flags here
}

const defaultValues: ExperimentalFeatures = {
ai: process.env.NODE_ENV === "development",
};
const defaultValues: ExperimentalFeatures = {};

export function getFeatureFlag<T extends keyof ExperimentalFeatures>(
feature: T,
Expand Down
1 change: 1 addition & 0 deletions frontend/src/stories/cell.stories.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ const props: CellProps = {
package_management: {
manager: "pip",
},
ai: {},
experimental: {},
},
};
Expand Down
Loading

0 comments on commit 45f0307

Please sign in to comment.