Skip to content

Commit

Permalink
Merge pull request #173 from Codium-ai/tr/caching
Browse files Browse the repository at this point in the history
Optimization of PR Diff Processing
  • Loading branch information
mrT23 committed Aug 5, 2023
2 parents bd07a0c + 749ae1b commit bd86266
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 62 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
## 2023-08-03

### Optimized
- Optimized PR diff processing by introducing caching for diff files, reducing the number of API calls.
- Refactored `load_large_diff` function to generate a patch only when necessary.
- Fixed a bug in the GitLab provider where the new file was not retrieved correctly.

## 2023-08-02

### Enhanced
Expand Down
9 changes: 1 addition & 8 deletions pr_agent/algo/pr_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbers, extend_patch, handle_patch_deletions
from pr_agent.algo.language_handler import sort_files_by_main_languages
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import load_large_diff
from pr_agent.config_loader import get_settings
from pr_agent.git_providers.git_provider import GitProvider

Expand Down Expand Up @@ -46,7 +45,7 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: s
PATCH_EXTRA_LINES = 0

try:
diff_files = list(git_provider.get_diff_files())
diff_files = git_provider.get_diff_files()
except RateLimitExceededException as e:
logging.error(f"Rate limit exceeded for git provider API. original message {e}")
raise
Expand Down Expand Up @@ -98,12 +97,7 @@ def pr_generate_extended_diff(pr_languages: list, token_handler: TokenHandler,
for lang in pr_languages:
for file in lang['files']:
original_file_content_str = file.base_file
new_file_content_str = file.head_file
patch = file.patch

# handle the case of large patch, that initially was not loaded
patch = load_large_diff(file, new_file_content_str, original_file_content_str, patch)

if not patch:
continue

Expand Down Expand Up @@ -161,7 +155,6 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo
original_file_content_str = file.base_file
new_file_content_str = file.head_file
patch = file.patch
patch = load_large_diff(file, new_file_content_str, original_file_content_str, patch)
if not patch:
continue

Expand Down
28 changes: 10 additions & 18 deletions pr_agent/algo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,38 +195,30 @@ def convert_str_to_datetime(date_str):
return datetime.strptime(date_str, datetime_format)


def load_large_diff(file, new_file_content_str: str, original_file_content_str: str, patch: str) -> str:
def load_large_diff(filename, new_file_content_str: str, original_file_content_str: str) -> str:
"""
Generate a patch for a modified file by comparing the original content of the file with the new content provided as
input.
Args:
file: The file object for which the patch needs to be generated.
new_file_content_str: The new content of the file as a string.
original_file_content_str: The original content of the file as a string.
patch: An optional patch string that can be provided as input.
Returns:
The generated or provided patch string.
Raises:
None.
Additional Information:
- If 'patch' is not provided as input, the function generates a patch using the 'difflib' library and returns it
as output.
- If the 'settings.config.verbosity_level' is greater than or equal to 2, a warning message is logged indicating
that the file was modified but no patch was found, and a patch is manually created.
"""
if not patch: # to Do - also add condition for file extension
try:
diff = difflib.unified_diff(original_file_content_str.splitlines(keepends=True),
new_file_content_str.splitlines(keepends=True))
if get_settings().config.verbosity_level >= 2:
logging.warning(f"File was modified, but no patch was found. Manually creating patch: {file.filename}.")
patch = ''.join(diff)
except Exception:
pass
patch = ""
try:
diff = difflib.unified_diff(original_file_content_str.splitlines(keepends=True),
new_file_content_str.splitlines(keepends=True))
if get_settings().config.verbosity_level >= 2:
logging.warning(f"File was modified, but no patch was found. Manually creating patch: {filename}.")
patch = ''.join(diff)
except Exception:
pass
return patch


Expand Down
57 changes: 37 additions & 20 deletions pr_agent/git_providers/github_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __init__(self, pr_url: Optional[str] = None, incremental=IncrementalPR(False
self.pr = None
self.github_user_id = None
self.diff_files = None
self.git_files = None
self.incremental = incremental
if pr_url:
self.set_pr(pr_url)
Expand Down Expand Up @@ -81,40 +82,56 @@ def get_previous_review(self):
def get_files(self):
if self.incremental.is_incremental and self.file_set:
return self.file_set.values()
return self.pr.get_files()
if not self.git_files:
# bring files from GitHub only once
self.git_files = self.pr.get_files()
return self.git_files

@retry(exceptions=RateLimitExceeded,
tries=get_settings().github.ratelimit_retries, delay=2, backoff=2, jitter=(1, 3))
def get_diff_files(self) -> list[FilePatchInfo]:
"""
Retrieves the list of files that have been modified, added, deleted, or renamed in a pull request in GitHub,
along with their content and patch information.
Returns:
diff_files (List[FilePatchInfo]): List of FilePatchInfo objects representing the modified, added, deleted,
or renamed files in the merge request.
"""
try:
if self.diff_files:
return self.diff_files

files = self.get_files()
diff_files = []

for file in files:
if is_valid_file(file.filename):
new_file_content_str = self._get_pr_file_content(file, self.pr.head.sha)
patch = file.patch
if self.incremental.is_incremental and self.file_set:
original_file_content_str = self._get_pr_file_content(file,
self.incremental.last_seen_commit_sha)
patch = load_large_diff(file,
new_file_content_str,
original_file_content_str,
None)
self.file_set[file.filename] = patch
else:
original_file_content_str = self._get_pr_file_content(file, self.pr.base.sha)

diff_files.append(
FilePatchInfo(original_file_content_str, new_file_content_str, patch, file.filename))
if not is_valid_file(file.filename):
continue

new_file_content_str = self._get_pr_file_content(file, self.pr.head.sha) # communication with GitHub
patch = file.patch

if self.incremental.is_incremental and self.file_set:
original_file_content_str = self._get_pr_file_content(file, self.incremental.last_seen_commit_sha)
patch = load_large_diff(file.filename, new_file_content_str, original_file_content_str)
self.file_set[file.filename] = patch
else:
original_file_content_str = self._get_pr_file_content(file, self.pr.base.sha)
if not patch:
patch = load_large_diff(file.filename, new_file_content_str, original_file_content_str)

diff_files.append(FilePatchInfo(original_file_content_str, new_file_content_str, patch, file.filename))

self.diff_files = diff_files
return diff_files

except GithubException.RateLimitExceededException as e:
logging.error(f"Rate limit exceeded for GitHub API. Original message: {e}")
raise RateLimitExceeded("Rate limit exceeded for GitHub API.") from e

def publish_description(self, pr_title: str, pr_body: str):
self.pr.edit(title=pr_title, body=pr_body)
# self.pr.create_issue_comment(pr_comment)

def publish_comment(self, pr_comment: str, is_temporary: bool = False):
if is_temporary and not get_settings().config.publish_output_progress:
Expand All @@ -132,9 +149,9 @@ def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in
self.publish_inline_comments([self.create_inline_comment(body, relevant_file, relevant_line_in_file)])

def create_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str):
self.diff_files = self.diff_files if self.diff_files else self.get_diff_files()
diff_files = self.get_diff_files()
position = -1
for file in self.diff_files:
for file in diff_files:
if file.filename.strip() == relevant_file:
patch = file.patch
patch_lines = patch.splitlines()
Expand Down
57 changes: 42 additions & 15 deletions pr_agent/git_providers/gitlab_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from gitlab import GitlabGetError

from ..algo.language_handler import is_valid_file
from ..algo.utils import load_large_diff
from ..config_loader import get_settings
from .git_provider import EDIT_TYPE, FilePatchInfo, GitProvider

Expand All @@ -30,6 +31,7 @@ def __init__(self, merge_request_url: Optional[str] = None, incremental: Optiona
self.id_mr = None
self.mr = None
self.diff_files = None
self.git_files = None
self.temp_comments = []
self._set_merge_request(merge_request_url)
self.RE_HUNK_HEADER = re.compile(
Expand Down Expand Up @@ -65,19 +67,27 @@ def _get_pr_file_content(self, file_path: str, branch: str) -> str:
return ''

def get_diff_files(self) -> list[FilePatchInfo]:
"""
Retrieves the list of files that have been modified, added, deleted, or renamed in a pull request in GitLab,
along with their content and patch information.
Returns:
diff_files (List[FilePatchInfo]): List of FilePatchInfo objects representing the modified, added, deleted,
or renamed files in the merge request.
"""

if self.diff_files:
return self.diff_files

diffs = self.mr.changes()['changes']
diff_files = []
for diff in diffs:
if is_valid_file(diff['new_path']):
original_file_content_str = self._get_pr_file_content(diff['old_path'], self.mr.target_branch)
new_file_content_str = self._get_pr_file_content(diff['new_path'], self.mr.source_branch)
edit_type = EDIT_TYPE.MODIFIED
if diff['new_file']:
edit_type = EDIT_TYPE.ADDED
elif diff['deleted_file']:
edit_type = EDIT_TYPE.DELETED
elif diff['renamed_file']:
edit_type = EDIT_TYPE.RENAMED
# original_file_content_str = self._get_pr_file_content(diff['old_path'], self.mr.target_branch)
# new_file_content_str = self._get_pr_file_content(diff['new_path'], self.mr.source_branch)
original_file_content_str = self._get_pr_file_content(diff['old_path'], self.mr.diff_refs['base_sha'])
new_file_content_str = self._get_pr_file_content(diff['new_path'], self.mr.diff_refs['head_sha'])

try:
if isinstance(original_file_content_str, bytes):
original_file_content_str = bytes.decode(original_file_content_str, 'utf-8')
Expand All @@ -86,15 +96,33 @@ def get_diff_files(self) -> list[FilePatchInfo]:
except UnicodeDecodeError:
logging.warning(
f"Cannot decode file {diff['old_path']} or {diff['new_path']} in merge request {self.id_mr}")

edit_type = EDIT_TYPE.MODIFIED
if diff['new_file']:
edit_type = EDIT_TYPE.ADDED
elif diff['deleted_file']:
edit_type = EDIT_TYPE.DELETED
elif diff['renamed_file']:
edit_type = EDIT_TYPE.RENAMED

filename = diff['new_path']
patch = diff['diff']
if not patch:
patch = load_large_diff(filename, new_file_content_str, original_file_content_str)

diff_files.append(
FilePatchInfo(original_file_content_str, new_file_content_str, diff['diff'], diff['new_path'],
FilePatchInfo(original_file_content_str, new_file_content_str,
patch=patch,
filename=filename,
edit_type=edit_type,
old_filename=None if diff['old_path'] == diff['new_path'] else diff['old_path']))
self.diff_files = diff_files
return diff_files

def get_files(self):
return [change['new_path'] for change in self.mr.changes()['changes']]
if not self.git_files:
self.git_files = [change['new_path'] for change in self.mr.changes()['changes']]
return self.git_files

def publish_description(self, pr_title: str, pr_body: str):
try:
Expand All @@ -110,7 +138,6 @@ def publish_comment(self, mr_comment: str, is_temporary: bool = False):
self.temp_comments.append(comment)

def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str):
self.diff_files = self.diff_files if self.diff_files else self.get_diff_files()
edit_type, found, source_line_no, target_file, target_line_no = self.search_line(relevant_file,
relevant_line_in_file)
self.send_inline_comment(body, edit_type, found, relevant_file, relevant_line_in_file, source_line_no,
Expand Down Expand Up @@ -151,9 +178,9 @@ def publish_code_suggestions(self, code_suggestions: list):
relevant_lines_start = suggestion['relevant_lines_start']
relevant_lines_end = suggestion['relevant_lines_end']

self.diff_files = self.diff_files if self.diff_files else self.get_diff_files()
diff_files = self.get_diff_files()
target_file = None
for file in self.diff_files:
for file in diff_files:
if file.filename == relevant_file:
if file.filename == relevant_file:
target_file = file
Expand All @@ -180,7 +207,7 @@ def search_line(self, relevant_file, relevant_line_in_file):
target_file = None

edit_type = self.get_edit_type(relevant_line_in_file)
for file in self.diff_files:
for file in self.get_diff_files():
if file.filename == relevant_file:
edit_type, found, source_line_no, target_file, target_line_no = self.find_in_file(file,
relevant_line_in_file)
Expand Down
2 changes: 1 addition & 1 deletion pr_agent/tools/pr_code_suggestions.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,12 @@ async def run(self):

async def _prepare_prediction(self, model: str):
logging.info('Getting PR diff...')
# we are using extended hunk with line numbers for code suggestions
self.patches_diff = get_pr_diff(self.git_provider,
self.token_handler,
model,
add_line_numbers_to_hunks=True,
disable_extra_lines=True)

logging.info('Getting AI prediction...')
self.prediction = await self._get_prediction(model)

Expand Down

0 comments on commit bd86266

Please sign in to comment.