diff --git a/README.md b/README.md index 8cdcf43f..edd3a546 100644 --- a/README.md +++ b/README.md @@ -194,6 +194,15 @@ Default: `10000` (10 seconds) aicommits config set timeout=20000 # 20s ``` +#### max-length +The maximum character length of the generated commit message. + +Default: `50` + +```sh +aicommits config set max-length=100 +``` + ## How it works This CLI tool runs `git diff` to grab all your latest code changes, sends them to OpenAI's GPT-3, then returns the AI generated commit message. diff --git a/src/commands/aicommits.ts b/src/commands/aicommits.ts index 0445be75..ab937caa 100644 --- a/src/commands/aicommits.ts +++ b/src/commands/aicommits.ts @@ -37,9 +37,8 @@ export default async ( throw new KnownError('No staged changes found. Stage your changes manually, or automatically stage all changes with the `--all` flag.'); } - detectingFiles.stop(`${getDetectedMessage(staged.files)}:\n${ - staged.files.map(file => ` ${file}`).join('\n') - }`); + detectingFiles.stop(`${getDetectedMessage(staged.files)}:\n${staged.files.map(file => ` ${file}`).join('\n') + }`); const { env } = process; const config = await getConfig({ @@ -58,6 +57,7 @@ export default async ( config.locale, staged.diff, config.generate, + config['max-length'], config.timeout, config.proxy, ); diff --git a/src/commands/prepare-commit-msg-hook.ts b/src/commands/prepare-commit-msg-hook.ts index 52b7cb2c..dcf9b5b7 100644 --- a/src/commands/prepare-commit-msg-hook.ts +++ b/src/commands/prepare-commit-msg-hook.ts @@ -45,6 +45,7 @@ export default () => (async () => { config.locale, staged!.diff, config.generate, + config['max-length'], config.timeout, config.proxy, ); diff --git a/src/utils/config.ts b/src/utils/config.ts index 192e5849..f09c8229 100644 --- a/src/utils/config.ts +++ b/src/utils/config.ts @@ -36,7 +36,6 @@ const configParsers = { parseAssert('locale', locale, 'Cannot be empty'); parseAssert('locale', /^[a-z-]+$/i.test(locale), 'Must be a valid locale (letters and dashes/underscores). You can consult the list of codes in: https://wikipedia.org/wiki/List_of_ISO_639-1_codes'); - return locale; }, generate(count?: string) { @@ -78,6 +77,18 @@ const configParsers = { const parsed = Number(timeout); parseAssert('timeout', parsed >= 500, 'Must be greater than 500ms'); + return parsed; + }, + 'max-length'(maxLength?: string) { + if (!maxLength) { + return 50; + } + + parseAssert('max-length', /^\d+$/.test(maxLength), 'Must be an integer'); + + const parsed = Number(maxLength); + parseAssert('max-length', parsed >= 20, 'Must be greater than 20 characters'); + return parsed; }, } as const; diff --git a/src/utils/openai.ts b/src/utils/openai.ts index 34a252ad..8c5c4491 100644 --- a/src/utils/openai.ts +++ b/src/utils/openai.ts @@ -1,7 +1,11 @@ import https from 'https'; import type { ClientRequest, IncomingMessage } from 'http'; import type { CreateChatCompletionRequest, CreateChatCompletionResponse } from 'openai'; -import { type TiktokenModel } from '@dqbd/tiktoken'; +import { + TiktokenModel, + // eslint-disable-next-line camelcase + encoding_for_model, +} from '@dqbd/tiktoken'; import createHttpsProxyAgent from 'https-proxy-agent'; import { KnownError } from './error.js'; @@ -100,7 +104,24 @@ const sanitizeMessage = (message: string) => message.trim().replace(/[\n\r]/g, ' const deduplicateMessages = (array: string[]) => Array.from(new Set(array)); -const getPrompt = (locale: string, diff: string) => `Write a git commit message in present tense for the following diff without prefacing it with anything. Do not be needlessly verbose and make sure the answer is concise and to the point. The response must be in the language ${locale}:\n${diff}`; +const getPrompt = (locale: string, diff: string, length: number) => `Write a git commit message in present tense for the following diff without prefacing it with anything. Do not be needlessly verbose and make sure the answer is concise and to the point. The response must be no longer than ${length} characters. The response must be in the language ${locale}:\n${diff}`; + +const generateStringFromLength = (length: number) => { + let result = ''; + const highestTokenChar = 'z'; + for (let i = 0; i < length; i += 1) { + result += highestTokenChar; + } + return result; +}; + +const getTokens = (prompt: string, model: TiktokenModel) => { + const encoder = encoding_for_model(model); + const tokens = encoder.encode(prompt).length; + // Free the encoder to avoid possible memory leaks. + encoder.free(); + return tokens; +}; export const generateCommitMessage = async ( apiKey: string, @@ -108,10 +129,17 @@ export const generateCommitMessage = async ( locale: string, diff: string, completions: number, + length: number, timeout: number, proxy?: string, ) => { - const prompt = getPrompt(locale, diff); + const prompt = getPrompt(locale, diff, length); + + // Padded by 5 for more room for the completion. + const stringFromLength = generateStringFromLength(length + 5); + + // The token limit is shared between the prompt and the completion. + const maxTokens = getTokens(stringFromLength + prompt, model); try { const completion = await createChatCompletion( @@ -126,7 +154,7 @@ export const generateCommitMessage = async ( top_p: 1, frequency_penalty: 0, presence_penalty: 0, - max_tokens: 200, + max_tokens: maxTokens, stream: false, n: completions, }, diff --git a/tests/specs/cli/commits.ts b/tests/specs/cli/commits.ts index 8b2c4f35..78a37101 100644 --- a/tests/specs/cli/commits.ts +++ b/tests/specs/cli/commits.ts @@ -52,6 +52,35 @@ export default testSuite(({ describe }) => { const { stdout: commitMessage } = await git('log', ['--oneline']); console.log('Committed with:', commitMessage); + expect(commitMessage.length <= 50).toBe(true); + + await fixture.rm(); + }); + + test('Generated commit message must be under 20 characters', async () => { + const { fixture, aicommits } = await createFixture({ + ...files, + '.aicommits': `${files['.aicommits']}\nmax-length=20`, + }); + + const git = await createGit(fixture.path); + + await git('add', ['data.json']); + + const committing = aicommits(); + committing.stdout!.on('data', (buffer: Buffer) => { + const stdout = buffer.toString(); + if (stdout.match('└')) { + committing.stdin!.write('y'); + committing.stdin!.end(); + } + }); + + await committing; + + const { stdout: commitMessage } = await git('log', ['--pretty=format:%s']); + console.log('20 Committed with:', commitMessage, commitMessage.length); + expect(commitMessage.length <= 20).toBe(true); await fixture.rm(); }); @@ -84,6 +113,7 @@ export default testSuite(({ describe }) => { const { stdout: commitMessage } = await git('log', ['-n1', '--oneline']); console.log('Committed with:', commitMessage); + expect(commitMessage.length <= 50).toBe(true); await fixture.rm(); }); @@ -123,6 +153,7 @@ export default testSuite(({ describe }) => { const { stdout: commitMessage } = await git('log', ['--oneline']); console.log('Committed with:', commitMessage); + expect(commitMessage.length <= 50).toBe(true); await fixture.rm(); }); @@ -157,6 +188,7 @@ export default testSuite(({ describe }) => { const { stdout: commitMessage } = await git('log', ['--oneline']); console.log('Committed with:', commitMessage); expect(commitMessage).toMatch(japanesePattern); + expect(commitMessage.length <= 50).toBe(true); await fixture.rm(); }); @@ -217,6 +249,7 @@ export default testSuite(({ describe }) => { const { stdout: commitMessage } = await git('log', ['--oneline']); console.log('Committed with:', commitMessage); + expect(commitMessage.length <= 50).toBe(true); await fixture.rm(); }); @@ -248,6 +281,7 @@ export default testSuite(({ describe }) => { const { stdout: commitMessage } = await git('log', ['--oneline']); console.log('Committed with:', commitMessage); + expect(commitMessage.length <= 50).toBe(true); await fixture.rm(); }); diff --git a/tests/specs/config.ts b/tests/specs/config.ts index 5f0abf65..43f566c3 100644 --- a/tests/specs/config.ts +++ b/tests/specs/config.ts @@ -25,6 +25,18 @@ export default testSuite(({ describe }) => { expect(stderr).toMatch('Invalid config property OPENAI_KEY: Must start with "sk-"'); }); + await test('set config file', async () => { + await aicommits(['config', 'set', openAiToken]); + + const configFile = await fs.readFile(configPath, 'utf8'); + expect(configFile).toMatch(openAiToken); + }); + + await test('get config file', async () => { + const { stdout } = await aicommits(['config', 'get', 'OPENAI_KEY']); + expect(stdout).toBe(openAiToken); + }); + await test('reading unknown config', async () => { await fs.appendFile(configPath, 'UNKNOWN=1'); @@ -57,6 +69,38 @@ export default testSuite(({ describe }) => { }); }); + await describe('max-length', ({ test }) => { + test('must be an integer', async () => { + const { stderr } = await aicommits(['config', 'set', 'max-length=abc'], { + reject: false, + }); + + expect(stderr).toMatch('Must be an integer'); + }); + + test('must be at least 20 characters', async () => { + const { stderr } = await aicommits(['config', 'set', 'max-length=10'], { + reject: false, + }); + + expect(stderr).toMatch(/must be greater than 20 characters/i); + }); + + test('updates config', async () => { + const defaultConfig = await aicommits(['config', 'get', 'max-length']); + expect(defaultConfig.stdout).toBe('max-length=50'); + + const maxLength = 'max-length=60'; + await aicommits(['config', 'set', maxLength]); + + const configFile = await fs.readFile(configPath, 'utf8'); + expect(configFile).toMatch(maxLength); + + const get = await aicommits(['config', 'get', 'max-length']); + expect(get.stdout).toBe(maxLength); + }); + }); + await test('set config file', async () => { await aicommits(['config', 'set', openAiToken]); @@ -66,7 +110,6 @@ export default testSuite(({ describe }) => { await test('get config file', async () => { const { stdout } = await aicommits(['config', 'get', 'OPENAI_KEY']); - expect(stdout).toBe(openAiToken); });