diff --git a/lua/CopilotChat/config.lua b/lua/CopilotChat/config.lua index e43b1838..682923b3 100644 --- a/lua/CopilotChat/config.lua +++ b/lua/CopilotChat/config.lua @@ -14,7 +14,7 @@ ---@field blend number? ---@class CopilotChat.config.Shared ----@field system_prompt nil|string|fun(source: CopilotChat.source):string +---@field system_prompt nil|string ---@field model string? ---@field tools string|table|nil ---@field resources string|table|nil diff --git a/lua/CopilotChat/config/mappings.lua b/lua/CopilotChat/config/mappings.lua index f484df24..a7cd5158 100644 --- a/lua/CopilotChat/config/mappings.lua +++ b/lua/CopilotChat/config/mappings.lua @@ -307,10 +307,10 @@ return { end local lines = {} - local config, prompt = copilot.resolve_prompt(message.content) - local system_prompt = config.system_prompt async.run(function() + local config, prompt = copilot.resolve_prompt(message.content) + local system_prompt = config.system_prompt local resolved_resources = copilot.resolve_functions(prompt, config) local selected_tools = copilot.resolve_tools(prompt, config) local selected_model = copilot.resolve_model(prompt, config) diff --git a/lua/CopilotChat/init.lua b/lua/CopilotChat/init.lua index 34f564ae..c738d4a8 100644 --- a/lua/CopilotChat/init.lua +++ b/lua/CopilotChat/init.lua @@ -1,18 +1,13 @@ local async = require('plenary.async') local log = require('plenary.log') -local functions = require('CopilotChat.functions') local client = require('CopilotChat.client') local constants = require('CopilotChat.constants') -local notify = require('CopilotChat.notify') +local prompts = require('CopilotChat.prompt') local select = require('CopilotChat.select') local utils = require('CopilotChat.utils') local curl = require('CopilotChat.utils.curl') local orderedmap = require('CopilotChat.utils.orderedmap') -local WORD = '([^%s:]+)' -local WORD_NO_INPUT = '([^%s]+)' -local WORD_WITH_INPUT_QUOTED = WORD .. ':`([^`]+)`' -local WORD_WITH_INPUT_UNQUOTED = WORD .. ':?([^%s`]*)' local BLOCK_OUTPUT_FORMAT = '```%s\n%s\n```' ---@class CopilotChat @@ -234,19 +229,25 @@ end ---@return any local function handle_error(config, cb) return function() - local ok, out = pcall(cb) + local function error_handler(err) + return { + err = utils.make_string(err), + traceback = debug.traceback(), + } + end + + local ok, out = xpcall(cb, error_handler) if ok then return out end + log.error(out.err .. '\n' .. out.traceback) - log.error(out) if config.headless then return end utils.schedule_main() - out = out or 'Unknown error' - out = utils.make_string(out) + out = out.err M.chat:add_message({ role = constants.ROLE.ASSISTANT, @@ -307,291 +308,25 @@ end ---@param config CopilotChat.config.Shared? ---@return table, string function M.resolve_tools(prompt, config) - config, prompt = M.resolve_prompt(prompt, config) - - local tools = {} - for _, tool in ipairs(functions.parse_tools(M.config.functions)) do - tools[tool.name] = tool - end - - local enabled_tools = {} - local tool_matches = utils.to_table(config.tools) - - -- Check for @tool pattern to find enabled tools - prompt = prompt:gsub('@' .. WORD, function(match) - for name, tool in pairs(M.config.functions) do - if name == match or tool.group == match then - table.insert(tool_matches, match) - return '' - end - end - return '@' .. match - end) - for _, match in ipairs(tool_matches) do - for name, tool in pairs(M.config.functions) do - if name == match or tool.group == match then - table.insert(enabled_tools, tools[name]) - end - end - end - - return enabled_tools, prompt + return prompts.resolve_tools(prompt, config) end --- Call and resolve function calls from the prompt. ---@param prompt string? ---@param config CopilotChat.config.Shared? ----@return table, table, string +---@return table, table, table, string ---@async function M.resolve_functions(prompt, config) - config, prompt = M.resolve_prompt(prompt, config) - - local tools = {} - for _, tool in ipairs(functions.parse_tools(M.config.functions)) do - tools[tool.name] = tool - end - - if config.resources then - local resources = utils.to_table(config.resources) - local lines = utils.split_lines(prompt) - for i = #resources, 1, -1 do - local resource = resources[i] - table.insert(lines, 1, '#' .. resource) - end - prompt = table.concat(lines, '\n') - end - - local resolved_resources = {} - local resolved_tools = {} - local tool_calls = {} - for _, message in ipairs(M.chat:get_messages()) do - if message.tool_calls then - for _, tool_call in ipairs(message.tool_calls) do - table.insert(tool_calls, tool_call) - end - end - end - - local resource_matches = {} - - -- Check for #word:`input` pattern - for word, input in prompt:gmatch('#' .. WORD_WITH_INPUT_QUOTED) do - local pattern = string.format('#%s:`%s`', word, input) - table.insert(resource_matches, { - pattern = pattern, - word = word, - input = input, - }) - end - - -- Check for #word:input pattern - for word, input in prompt:gmatch('#' .. WORD_WITH_INPUT_UNQUOTED) do - local pattern = utils.empty(input) and string.format('#%s', word) or string.format('#%s:%s', word, input) - table.insert(resource_matches, { - pattern = pattern, - word = word, - input = input, - }) - end - - -- Check for ##word:input pattern - for word in prompt:gmatch('##' .. WORD_NO_INPUT) do - local pattern = string.format('##%s', word) - table.insert(resource_matches, { - pattern = pattern, - word = word, - }) - end - - -- Resolve each function reference - local function expand_function(name, input) - notify.publish(notify.STATUS, 'Running function: ' .. name) - - local tool_id = nil - if not utils.empty(tool_calls) then - for _, tool_call in ipairs(tool_calls) do - if tool_call.name == name and vim.trim(tool_call.id) == vim.trim(input) then - input = utils.empty(tool_call.arguments) and {} or utils.json_decode(tool_call.arguments) - tool_id = tool_call.id - break - end - end - end - - local tool = M.config.functions[name] - if not tool then - -- Check if input matches uri - for tool_name, tool_spec in pairs(M.config.functions) do - if tool_spec.uri then - local match = functions.match_uri(name, tool_spec.uri) - if match then - name = tool_name - tool = tool_spec - input = match - break - end - end - end - end - if not tool then - return nil - end - if not tool_id and not tool.uri then - return nil - end - - local schema = tools[name] and tools[name].schema or nil - local ok, output - if config.stop_on_function_failure then - output = tool.resolve(functions.parse_input(input, schema), state.source) - ok = true - else - ok, output = pcall(tool.resolve, functions.parse_input(input, schema), state.source) - end - - local result = '' - if not ok then - result = utils.make_string(output) - else - for _, content in ipairs(output) do - if content then - local content_out = nil - if content.uri then - if - not vim.tbl_contains(resolved_resources, function(resource) - return resource.uri == content.uri - end, { predicate = true }) - then - content_out = '##' .. content.uri - table.insert(resolved_resources, content) - end - - if tool_id then - table.insert(state.sticky, '##' .. content.uri) - end - else - content_out = content.data - end - - if content_out then - if not utils.empty(result) then - result = result .. '\n' - end - result = result .. content_out - end - end - end - end - - if tool_id then - table.insert(resolved_tools, { - id = tool_id, - result = result, - }) - - return '' - end - - return result - end - - -- Resolve and process all tools - for _, match in ipairs(resource_matches) do - if not utils.empty(match.pattern) then - local out = expand_function(match.word, match.input) - if out == nil then - out = match.pattern - end - out = out:gsub('%%', '%%%%') -- Escape percent signs for gsub - prompt = prompt:gsub(vim.pesc(match.pattern), out, 1) - end - end - - return resolved_resources, resolved_tools, prompt + return prompts.resolve_functions(prompt, config) end --- Resolve the final prompt and config from prompt template. ---@param prompt string? ---@param config CopilotChat.config.Shared? ---@return CopilotChat.config.prompts.Prompt, string +---@async function M.resolve_prompt(prompt, config) - if prompt == nil then - local message = M.chat:get_message(constants.ROLE.USER) - if message then - prompt = message.content - end - end - - local prompts_to_use = list_prompts() - local depth = 0 - local MAX_DEPTH = 10 - - local function resolve(inner_config, inner_prompt) - if depth >= MAX_DEPTH then - return inner_config, inner_prompt - end - depth = depth + 1 - - inner_prompt = string.gsub(inner_prompt, '/' .. WORD, function(match) - local p = prompts_to_use[match] - if p then - local resolved_config, resolved_prompt = resolve(p, p.prompt or '') - inner_config = vim.tbl_deep_extend('force', inner_config, resolved_config) - return resolved_prompt - end - - return '/' .. match - end) - - depth = depth - 1 - return inner_config, inner_prompt - end - - local function resolve_system_prompt(system_prompt) - if type(system_prompt) == 'function' then - local ok, result = pcall(system_prompt) - if not ok then - log.warn('Failed to resolve system prompt function: ' .. result) - return nil - end - return result - end - - return system_prompt - end - - config = vim.tbl_deep_extend('force', M.config, config or {}) - config, prompt = resolve(config, prompt or '') - - if config.system_prompt then - config.system_prompt = resolve_system_prompt(config.system_prompt) - - if M.config.prompts[config.system_prompt] then - -- Name references are good for making system prompt auto sticky - config.system_prompt = M.config.prompts[config.system_prompt].system_prompt - end - - config.system_prompt = vim.trim(config.system_prompt) .. '\n' .. M.config.prompts.COPILOT_BASE.system_prompt - config.system_prompt = vim.trim(config.system_prompt) - .. '\n' - .. vim.trim(require('CopilotChat.instructions.tool_use')) - - if config.diff == 'unified' then - config.system_prompt = vim.trim(config.system_prompt) - .. '\n' - .. vim.trim(require('CopilotChat.instructions.edit_file_unified')) - else - config.system_prompt = vim.trim(config.system_prompt) - .. '\n' - .. vim.trim(require('CopilotChat.instructions.edit_file_block')) - end - - config.system_prompt = config.system_prompt:gsub('{OS_NAME}', vim.uv.os_uname().sysname) - config.system_prompt = config.system_prompt:gsub('{LANGUAGE}', config.language) - config.system_prompt = config.system_prompt:gsub('{DIR}', state.source.cwd()) - end - - return config, prompt + return prompts.resolve_prompt(prompt, config) end --- Resolve the model from the prompt. @@ -600,22 +335,7 @@ end ---@return string, string ---@async function M.resolve_model(prompt, config) - config, prompt = M.resolve_prompt(prompt, config) - - local models = vim.tbl_map(function(model) - return model.id - end, list_models()) - - local selected_model = config.model or '' - prompt = prompt:gsub('%$' .. WORD, function(match) - if vim.tbl_contains(models, match) then - selected_model = match - return '' - end - return '$' .. match - end) - - return selected_model, prompt + return prompts.resolve_model(prompt, config) end --- Get the current source buffer and window. @@ -813,30 +533,33 @@ function M.ask(prompt, config) -- After opening window we need to schedule to next cycle so everything properly resolves schedule(function() - -- Prepare chat if not config.headless then + -- Prepare chat store_sticky(prompt) M.chat:start() M.chat:append('\n') end - -- Resolve prompt references - config, prompt = M.resolve_prompt(prompt, config) - local system_prompt = config.system_prompt or '' - - -- Remove sticky prefix - prompt = table.concat( - vim.tbl_map(function(l) - return l:gsub('^>%s+', '') - end, vim.split(prompt, '\n')), - '\n' - ) - async.run(handle_error(config, function() + config, prompt = M.resolve_prompt(prompt, config) + local system_prompt = config.system_prompt or '' local selected_tools, prompt = M.resolve_tools(prompt, config) - local resolved_resources, resolved_tools, prompt = M.resolve_functions(prompt, config) + local resolved_resources, resolved_tools, resolved_stickies, prompt = M.resolve_functions(prompt, config) local selected_model, prompt = M.resolve_model(prompt, config) + -- Remove sticky prefix + prompt = table.concat( + vim.tbl_map(function(l) + return l:gsub('^>%s+', '') + end, vim.split(prompt, '\n')), + '\n' + ) + + -- Add resolved stickies to state + for _, sticky in ipairs(resolved_stickies) do + table.insert(state.sticky, sticky) + end + prompt = vim.trim(prompt) if not config.headless then @@ -916,7 +639,7 @@ function M.ask(prompt, config) end), }) - -- If there was no error and no response, it means job was cancelled + -- If there was no error and no response, it means job was canceled if ask_response == nil then return end diff --git a/lua/CopilotChat/instructions/custom_instructions.lua b/lua/CopilotChat/instructions/custom_instructions.lua new file mode 100644 index 00000000..57b1ba44 --- /dev/null +++ b/lua/CopilotChat/instructions/custom_instructions.lua @@ -0,0 +1,6 @@ +return [[ + +Custom instructions from user's `{FILENAME}`: +{CONTENT} + +]] diff --git a/lua/CopilotChat/prompt.lua b/lua/CopilotChat/prompt.lua new file mode 100644 index 00000000..b35da0f2 --- /dev/null +++ b/lua/CopilotChat/prompt.lua @@ -0,0 +1,390 @@ +local client = require('CopilotChat.client') +local constants = require('CopilotChat.constants') +local functions = require('CopilotChat.functions') +local notify = require('CopilotChat.notify') +local files = require('CopilotChat.utils.files') +local utils = require('CopilotChat.utils') + +local WORD = '([^%s:]+)' +local WORD_NO_INPUT = '([^%s]+)' +local WORD_WITH_INPUT_QUOTED = WORD .. ':`([^`]+)`' +local WORD_WITH_INPUT_UNQUOTED = WORD .. ':?([^%s`]*)' + +--- List available models. +--- @return CopilotChat.client.Model[] +local function list_models() + local models = client:models() + local result = vim.tbl_keys(models) + + table.sort(result, function(a, b) + a = models[a] + b = models[b] + if a.provider ~= b.provider then + return a.provider < b.provider + end + return a.id < b.id + end) + + return vim.tbl_map(function(id) + return models[id] + end, result) +end + +--- List available prompts. +---@return table +local function list_prompts() + local config = require('CopilotChat.config') + local prompts_to_use = {} + + for name, prompt in pairs(config.prompts) do + local val = prompt + if type(prompt) == 'string' then + val = { + prompt = prompt, + } + end + + prompts_to_use[name] = val + end + + return prompts_to_use +end + +--- Find custom instructions in the current working directory. +---@param cwd string +---@return table +local function find_custom_instructions(cwd) + local out = {} + local copilot_instructions_path = vim.fs.joinpath(cwd, '.github', 'copilot-instructions.md') + local copilot_instructions = files.read_file(copilot_instructions_path) + if copilot_instructions then + table.insert(out, { + filename = copilot_instructions_path, + content = vim.trim(copilot_instructions), + }) + end + return out +end + +local M = {} + +--- Resolve enabled tools from the prompt. +---@param prompt string? +---@param config CopilotChat.config.Shared? +---@return table, string +function M.resolve_tools(prompt, config) + config, prompt = M.resolve_prompt(prompt, config) + + local tools = {} + for _, tool in ipairs(functions.parse_tools(config.functions)) do + tools[tool.name] = tool + end + + local enabled_tools = {} + local tool_matches = utils.to_table(config.tools) + + -- Check for @tool pattern to find enabled tools + prompt = prompt:gsub('@' .. WORD, function(match) + for name, tool in pairs(config.functions) do + if name == match or tool.group == match then + table.insert(tool_matches, match) + return '' + end + end + return '@' .. match + end) + for _, match in ipairs(tool_matches) do + for name, tool in pairs(config.functions) do + if name == match or tool.group == match then + table.insert(enabled_tools, tools[name]) + end + end + end + + return enabled_tools, prompt +end + +--- Call and resolve function calls from the prompt. +---@param prompt string? +---@param config CopilotChat.config.Shared? +---@return table, table, table, string +---@async +function M.resolve_functions(prompt, config) + config, prompt = M.resolve_prompt(prompt, config) + + local chat = require('CopilotChat').chat + local source = require('CopilotChat').get_source() + + local tools = {} + for _, tool in ipairs(functions.parse_tools(config.functions)) do + tools[tool.name] = tool + end + + if config.resources then + local resources = utils.to_table(config.resources) + local lines = utils.split_lines(prompt) + for i = #resources, 1, -1 do + local resource = resources[i] + table.insert(lines, 1, '#' .. resource) + end + prompt = table.concat(lines, '\n') + end + + local resolved_resources = {} + local resolved_tools = {} + local resolved_stickies = {} + local tool_calls = {} + + utils.schedule_main() + for _, message in ipairs(chat:get_messages()) do + if message.tool_calls then + for _, tool_call in ipairs(message.tool_calls) do + table.insert(tool_calls, tool_call) + end + end + end + + local resource_matches = {} + + -- Check for #word:`input` pattern + for word, input in prompt:gmatch('#' .. WORD_WITH_INPUT_QUOTED) do + local pattern = string.format('#%s:`%s`', word, input) + table.insert(resource_matches, { + pattern = pattern, + word = word, + input = input, + }) + end + + -- Check for #word:input pattern + for word, input in prompt:gmatch('#' .. WORD_WITH_INPUT_UNQUOTED) do + local pattern = utils.empty(input) and string.format('#%s', word) or string.format('#%s:%s', word, input) + table.insert(resource_matches, { + pattern = pattern, + word = word, + input = input, + }) + end + + -- Check for ##word:input pattern + for word in prompt:gmatch('##' .. WORD_NO_INPUT) do + local pattern = string.format('##%s', word) + table.insert(resource_matches, { + pattern = pattern, + word = word, + }) + end + + -- Resolve each function reference + local function expand_function(name, input) + notify.publish(notify.STATUS, 'Running function: ' .. name) + + local tool_id = nil + if not utils.empty(tool_calls) then + for _, tool_call in ipairs(tool_calls) do + if tool_call.name == name and vim.trim(tool_call.id) == vim.trim(input) then + input = utils.empty(tool_call.arguments) and {} or utils.json_decode(tool_call.arguments) + tool_id = tool_call.id + break + end + end + end + + local tool = config.functions[name] + if not tool then + -- Check if input matches uri + for tool_name, tool_spec in pairs(config.functions) do + if tool_spec.uri then + local match = functions.match_uri(name, tool_spec.uri) + if match then + name = tool_name + tool = tool_spec + input = match + break + end + end + end + end + if not tool then + return nil + end + if not tool_id and not tool.uri then + return nil + end + + local schema = tools[name] and tools[name].schema or nil + local ok, output + if config.stop_on_function_failure then + output = tool.resolve(functions.parse_input(input, schema), source) + ok = true + else + ok, output = pcall(tool.resolve, functions.parse_input(input, schema), source) + end + + local result = '' + if not ok then + result = utils.make_string(output) + else + for _, content in ipairs(output) do + if content then + local content_out = nil + if content.uri then + if + not vim.tbl_contains(resolved_resources, function(resource) + return resource.uri == content.uri + end, { predicate = true }) + then + content_out = '##' .. content.uri + table.insert(resolved_resources, content) + end + + if tool_id then + table.insert(resolved_stickies, '##' .. content.uri) + end + else + content_out = content.data + end + + if content_out then + if not utils.empty(result) then + result = result .. '\n' + end + result = result .. content_out + end + end + end + end + + if tool_id then + table.insert(resolved_tools, { + id = tool_id, + result = result, + }) + + return '' + end + + return result + end + + -- Resolve and process all tools + for _, match in ipairs(resource_matches) do + if not utils.empty(match.pattern) then + local out = expand_function(match.word, match.input) + if out == nil then + out = match.pattern + end + out = out:gsub('%%', '%%%%') -- Escape percent signs for gsub + prompt = prompt:gsub(vim.pesc(match.pattern), out, 1) + end + end + + return resolved_resources, resolved_tools, resolved_stickies, prompt +end + +--- Resolve the final prompt and config from prompt template. +---@param prompt string? +---@param config CopilotChat.config.Shared? +---@return CopilotChat.config.prompts.Prompt, string +---@async +function M.resolve_prompt(prompt, config) + local chat = require('CopilotChat').chat + local source = require('CopilotChat').get_source() + + if prompt == nil then + utils.schedule_main() + local message = chat:get_message(constants.ROLE.USER) + if message then + prompt = message.content + end + end + + local prompts_to_use = list_prompts() + local depth = 0 + local MAX_DEPTH = 10 + + local function resolve(inner_config, inner_prompt) + if depth >= MAX_DEPTH then + return inner_config, inner_prompt + end + depth = depth + 1 + + inner_prompt = string.gsub(inner_prompt, '/' .. WORD, function(match) + local p = prompts_to_use[match] + if p then + local resolved_config, resolved_prompt = resolve(p, p.prompt or '') + inner_config = vim.tbl_deep_extend('force', inner_config, resolved_config) + return resolved_prompt + end + + return '/' .. match + end) + + depth = depth - 1 + return inner_config, inner_prompt + end + + config = vim.tbl_deep_extend('force', require('CopilotChat.config'), config or {}) + config, prompt = resolve(config, prompt or '') + + if config.system_prompt then + if config.prompts[config.system_prompt] then + -- Name references are good for making system prompt auto sticky + config.system_prompt = config.prompts[config.system_prompt].system_prompt + end + + local custom_instructions = vim.trim(require('CopilotChat.instructions.custom_instructions')) + for _, instruction in ipairs(find_custom_instructions(source.cwd())) do + config.system_prompt = vim.trim(config.system_prompt) + .. '\n' + .. custom_instructions:gsub('{FILENAME}', instruction.filename):gsub('{CONTENT}', instruction.content) + end + + config.system_prompt = vim.trim(config.system_prompt) .. '\n' .. config.prompts.COPILOT_BASE.system_prompt + config.system_prompt = vim.trim(config.system_prompt) + .. '\n' + .. vim.trim(require('CopilotChat.instructions.tool_use')) + + if config.diff == 'unified' then + config.system_prompt = vim.trim(config.system_prompt) + .. '\n' + .. vim.trim(require('CopilotChat.instructions.edit_file_unified')) + else + config.system_prompt = vim.trim(config.system_prompt) + .. '\n' + .. vim.trim(require('CopilotChat.instructions.edit_file_block')) + end + + config.system_prompt = config.system_prompt:gsub('{OS_NAME}', vim.uv.os_uname().sysname) + config.system_prompt = config.system_prompt:gsub('{LANGUAGE}', config.language) + config.system_prompt = config.system_prompt:gsub('{DIR}', source.cwd()) + end + + return config, prompt +end + +--- Resolve the model from the prompt. +---@param prompt string? +---@param config CopilotChat.config.Shared? +---@return string, string +---@async +function M.resolve_model(prompt, config) + config, prompt = M.resolve_prompt(prompt, config) + + local models = vim.tbl_map(function(model) + return model.id + end, list_models()) + + local selected_model = config.model or '' + prompt = prompt:gsub('%$' .. WORD, function(match) + if vim.tbl_contains(models, match) then + selected_model = match + return '' + end + return '$' .. match + end) + + return selected_model, prompt +end + +return M diff --git a/lua/CopilotChat/utils/diff.lua b/lua/CopilotChat/utils/diff.lua index 6ff56bac..17eb3091 100644 --- a/lua/CopilotChat/utils/diff.lua +++ b/lua/CopilotChat/utils/diff.lua @@ -54,7 +54,7 @@ local function apply_hunk(hunk, content) return patched, true end - -- Fallback: try smaller context window + -- Fallback: direct replacement local lines = vim.split(content, '\n') local insert_idx = hunk.start_old or 1 if not hunk.start_old then @@ -83,24 +83,8 @@ local function apply_hunk(hunk, content) end end - -- Define context window around insert point - local context_size = 10 local start_idx = insert_idx local end_idx = insert_idx + #hunk.old_snippet - local context_start = math.max(1, start_idx - context_size) - local context_end = math.min(#lines, end_idx + context_size) - local context_window = table.concat(vim.list_slice(lines, context_start, context_end), '\n') - - local patched_window, window_results = dmp.patch_apply(patch, context_window) - if not vim.tbl_contains(window_results, false) then - -- Patch succeeded in window, splice back - local new_lines = vim.list_slice(lines, 1, context_start - 1) - vim.list_extend(new_lines, vim.split(patched_window, '\n')) - vim.list_extend(new_lines, lines, context_end + 1, #lines) - return table.concat(new_lines, '\n'), true - end - - -- Fallback: direct replacement local new_lines = vim.list_slice(lines, 1, start_idx - 1) vim.list_extend(new_lines, hunk.new_snippet) vim.list_extend(new_lines, lines, end_idx + 1, #lines)