diff --git a/lua/CopilotChat/utils/diff.lua b/lua/CopilotChat/utils/diff.lua
index 51cf1887..a67518c9 100644
--- a/lua/CopilotChat/utils/diff.lua
+++ b/lua/CopilotChat/utils/diff.lua
@@ -16,9 +16,10 @@ local function parse_hunks(diff_text)
       local start_old, len_old, start_new, len_new = line:match('@@%s%-(%d+),?(%d*)%s%+(%d+),?(%d*)%s@@')
       current_hunk = {
         start_old = tonumber(start_old),
-        len_old = tonumber(len_old) or 1,
+        -- A missing count ('' when omitted, nil when the header did not parse) means 1; an explicit 0 is kept
+        len_old = (len_old == nil or len_old == '') and 1 or tonumber(len_old),
         start_new = tonumber(start_new),
-        len_new = tonumber(len_new) or 1,
+        len_new = (len_new == nil or len_new == '') and 1 or tonumber(len_new),
         old_snippet = {},
         new_snippet = {},
       }
@@ -90,6 +91,22 @@ local function apply_hunk(hunk, content)
   local lines = vim.split(content, '\n')
   local start_idx = hunk.start_old
 
+  -- Handle insertions (len_old == 0)
+  if hunk.len_old == 0 then
+    -- A pure insertion has no old lines to match, so the header position is
+    -- used as-is: start_old = 0 inserts before the first line and
+    -- start_old = n inserts after line n, i.e. always before line start_old + 1.
+    -- Clamp to the end of the buffer so an out-of-range header appends instead
+    -- of making list_slice iterate far past the last line.
+    local insert_at = math.min((start_idx or 0) + 1, #lines + 1)
+    local new_lines = vim.list_slice(lines, 1, insert_at - 1)
+    vim.list_extend(new_lines, hunk.new_snippet)
+    vim.list_extend(new_lines, lines, insert_at, #lines)
+    -- Insertions are always applied cleanly if we reach this point
+    return table.concat(new_lines, '\n'), true
+  end
+
+  -- Handle replacements and deletions (len_old > 0)
   -- If we have a start line hint, try to find best match within +/- 2 lines
   if start_idx and start_idx > 0 and start_idx <= #lines then
     local match_idx = find_best_match(lines, hunk.old_snippet, start_idx, 2)
diff --git a/tests/diff_spec.lua b/tests/diff_spec.lua
index 41ed262c..21807810 100644
--- a/tests/diff_spec.lua
+++ b/tests/diff_spec.lua
@@ -304,4 +304,312 @@ describe('CopilotChat.utils.diff', function()
       '}',
     }, result)
   end)
+
+  it('allows adding at very start with zero original lines', function()
+    local diff_text = [[
+--- a/foo.txt
++++ b/foo.txt
+@@ -0,0 +1,2 @@
++first
++second
+]]
+    local original = { 'x', 'y' }
+    local original_content = table.concat(original, '\n')
+    local result, applied = diff.apply_unified_diff(diff_text, original_content)
+    assert.is_true(applied)
+    assert.are.same({ 'first', 'second', 'x', 'y' }, result)
+  end)
+
+  it('handles insertion at end without context', function()
+    local diff_text = [[
+--- a/foo.txt
++++ b/foo.txt
+@@ -3,0 +4,2 @@
++new1
++new2
+]]
+    local original = { 'a', 'b', 'c' }
+    local original_content = table.concat(original, '\n')
+    local result, applied = diff.apply_unified_diff(diff_text, original_content)
+    assert.is_true(applied)
+    assert.are.same({ 'a', 'b', 'c', 'new1', 'new2' }, result)
+  end)
+
+  it('supports multiple adjacent hunks modifying contiguous lines', function()
+    local diff_text = [[
+--- a/foo.txt
++++ b/foo.txt
+@@ -1,1 +1,1 @@
+-a
++x
+@@ -2,1 +2,1 @@
+-b
++y
+]]
+    local original = { 'a', 'b', 'c' }
+    local original_content = table.concat(original, '\n')
+    local result, applied = diff.apply_unified_diff(diff_text, original_content)
+    assert.is_true(applied)
+    assert.are.same({ 'x', 'y', 'c' }, result)
+  end)
+
+  it('handles diff with trailing newline missing in original', function()
+    local diff_text = [[
+--- a/foo.txt
++++ b/foo.txt
+@@ -1,1 +1,1 @@
+-old
++new
+]]
+    local original_content = 'old' -- no trailing newline
+    local result, applied = diff.apply_unified_diff(diff_text, original_content)
+    assert.is_true(applied)
+    assert.are.same({ 'new' }, result)
+  end)
+
+  it('handles diff ending without newline on addition lines', function()
+    local diff_text = [[
+--- a/foo.txt
++++ b/foo.txt
+@@ -1,1 +1,2 @@
+ old
++new]]
+    local original = { 'old' }
+    local original_content = table.concat(original, '\n')
+    local result, applied = diff.apply_unified_diff(diff_text, original_content)
+    assert.is_true(applied)
+    assert.are.same({ 'old', 'new' }, result)
+  end)
+
+  it('handles hunks with zero-context lines around changes', function()
+    local diff_text = [[
+--- a/foo.txt
++++ b/foo.txt
+@@ -2,0 +3,1 @@
++added
+]]
+    local original = { 'a', 'b', 'c' }
+    local original_content = table.concat(original, '\n')
+    local result, applied = diff.apply_unified_diff(diff_text, original_content)
+    assert.is_true(applied)
+    assert.are.same({ 'a', 'b', 'added', 'c' }, result)
+  end)
+
+  it('handles insertion of identical-to-context line', function()
+    local diff_text = [[
+--- a/foo.txt
++++ b/foo.txt
+@@ -1,1 +1,2 @@
+ context
++context
+]]
+    local original = { 'context', 'other' }
+    local result, applied = diff.apply_unified_diff(diff_text, table.concat(original, '\n'))
+    assert.is_true(applied)
+    assert.are.same({ 'context', 'context', 'other' }, result)
+  end)
+
+  it('tolerates hunk with wrong header lengths', function()
+    local diff_text = [[
+--- a/foo.txt
++++ b/foo.txt
+@@ -1,3 +1,3 @@
+ context
+-old
++new
+]]
+    local original = { 'context' }
+    local result = diff.apply_unified_diff(diff_text, table.concat(original, '\n'))
+    -- Fuzzy matching may still apply despite wrong header lengths; only require a result
+    assert.is_not_nil(result)
+  end)
+
+  it('handles CRLF original with unix diff', function()
+    local diff_text = [[
+--- a/foo.txt
++++ b/foo.txt
+@@ -1,1 +1,1 @@
+-old
++new
+]]
+    local original_content = 'old\r\n'
+    local result, applied = diff.apply_unified_diff(diff_text, original_content)
+    assert.is_true(applied)
+    assert.is_not_nil(result)
+    assert.is_true(#result >= 1)
+  end)
+
+  it('handles large insertion with no context', function()
+    local lines = {}
+    for i = 1, 10 do
+      table.insert(lines, '+line' .. i)
+    end
+    local diff_text = '--- a/foo.txt\n+++ b/foo.txt\n@@ -4,0 +5,10 @@\n' .. table.concat(lines, '\n') .. '\n'
+    local original = { 'a', 'b', 'c', 'd', 'e' }
+    local result, applied = diff.apply_unified_diff(diff_text, table.concat(original, '\n'))
+    assert.is_true(applied)
+    local expected = { 'a', 'b', 'c', 'd' }
+    for i = 1, 10 do
+      table.insert(expected, 'line' .. i)
+    end
+    table.insert(expected, 'e')
+    assert.are.same(expected, result)
+  end)
+
+  it('tolerates mismatched deletion ranges', function()
+    local diff_text = [[
+--- a/foo.txt
++++ b/foo.txt
+@@ -1,3 +0,0 @@
+-old1
+-old2
+-old3
+]]
+    local original = { 'single' }
+    local result = diff.apply_unified_diff(diff_text, table.concat(original, '\n'))
+    -- Fuzzy matching may apply the deletion despite mismatch; only require a result
+    assert.is_not_nil(result)
+  end)
+
+  it('handles mixed operations in one hunk', function()
+    local diff_text = [[
+--- a/foo.txt
++++ b/foo.txt
+@@ -1,5 +1,4 @@
+ context1
+-old
+ unchanged
+-old2
++new2
+ context2
+]]
+    local original = { 'context1', 'old', 'unchanged', 'old2', 'context2' }
+    local result, applied = diff.apply_unified_diff(diff_text, table.concat(original, '\n'))
+    assert.is_true(applied)
+    assert.are.same({ 'context1', 'unchanged', 'new2', 'context2' }, result)
+  end)
+
+  it('handles leading tabs/spaces inside context lines', function()
+    local diff_text = [[
+--- a/x
++++ b/x
+@@ -1,2 +1,2 @@
+ 	indented
+-old
++new
+]]
+    local original = { '\tindented', 'old' }
+    local result, applied = diff.apply_unified_diff(diff_text, table.concat(original, '\n'))
+    assert.is_true(applied)
+    assert.are.same({ '\tindented', 'new' }, result)
+  end)
+
+  it('respects diff markers even if content begins with + or -', function()
+    local diff_text = [[
+--- a/x
++++ b/x
+@@ -1,2 +1,2 @@
+-+literalplus
+--literalminus
+++literalplus
+++literalminus
+]]
+    local original = { '+literalplus', '-literalminus' }
+    local result, applied = diff.apply_unified_diff(diff_text, table.concat(original, '\n'))
+    assert.is_true(applied)
+    assert.are.same({ '+literalplus', '+literalminus' }, result)
+  end)
+
+  it('applies diff despite slight context mismatch with fuzzy matching', function()
+    local diff_text = [[
+--- a/foo.txt
++++ b/foo.txt
+@@ -1,3 +1,3 @@
+ slightly different context
+-old
++new
+]]
+    local original = { 'context', 'old', 'other' }
+    local result = diff.apply_unified_diff(diff_text, table.concat(original, '\n'))
+    -- Fuzzy matching will replace context lines that don't match
+    assert.are.same({ 'slightly different context', 'new', 'other' }, result)
+  end)
+
+  it('applies even when context is completely wrong due to fuzzy matching', function()
+    local diff_text = [[
+--- a/foo.txt
++++ b/foo.txt
+@@ -1,3 +1,3 @@
+ totally wrong line
+ another wrong line
+-old
++new
+]]
+    local original = { 'context1', 'context2', 'old' }
+    local result = diff.apply_unified_diff(diff_text, table.concat(original, '\n'))
+    -- Fuzzy matching will replace all old_snippet lines (including wrong context) with new_snippet
+    assert.are.same({ 'totally wrong line', 'another wrong line', 'new' }, result)
+  end)
+
+  it('applies with partial context match', function()
+    local diff_text = [[
+--- a/foo.txt
++++ b/foo.txt
+@@ -2,3 +2,3 @@
+ matching
+-old
++new
+]]
+    local original = { 'first', 'matching', 'old', 'last' }
+    local result, applied = diff.apply_unified_diff(diff_text, table.concat(original, '\n'))
+    assert.is_true(applied)
+    assert.are.same({ 'first', 'matching', 'new', 'last' }, result)
+  end)
+
+  it('handles context with extra lines not in original', function()
+    local diff_text = [[
+--- a/foo.txt
++++ b/foo.txt
+@@ -1,5 +1,5 @@
+ context1
+ context2
+ context3
+-old
++new
+]]
+    local original = { 'context1', 'old' }
+    local result = diff.apply_unified_diff(diff_text, table.concat(original, '\n'))
+    -- Should fail or apply with fuzzy matching; either way a result must come back
+    assert.is_not_nil(result)
+  end)
+
+  it('tolerates deletion target that does not exist', function()
+    local diff_text = [[
+--- a/foo.txt
++++ b/foo.txt
+@@ -1,2 +1,1 @@
+ context
+-nonexistent
+]]
+    local original = { 'context', 'actual' }
+    local result = diff.apply_unified_diff(diff_text, table.concat(original, '\n'))
+    -- Fuzzy matching might still apply or fail; only require a result
+    assert.is_not_nil(result)
+  end)
+
+  it('tolerates context lines in different order', function()
+    local diff_text = [[
+--- a/foo.txt
++++ b/foo.txt
+@@ -1,3 +1,3 @@
+ line2
+ line1
+-old
++new
+]]
+    local original = { 'line1', 'line2', 'old' }
+    local result = diff.apply_unified_diff(diff_text, table.concat(original, '\n'))
+    -- Fuzzy matching may or may not apply with reordered context; only require a result
+    assert.is_not_nil(result)
+  end)
 end)