In [19]:
import subprocess
import re
import os
import requests
import sys

class GitInteraction:
    def __init__(self, repo_path):
        self.repo_path = repo_path

    def get_file_at_commit(self, commit_hash, file_path):
        """Get the contents of a file at a specific commit."""
        try:
            command = ["git", "show", f"{commit_hash}:{file_path}"]
            result = subprocess.run(
                command,
                cwd=self.repo_path,
                text=True,
                capture_output=True,
                check=True,
                encoding='utf-8',
                errors='ignore'
            )
            return result.stdout
        except subprocess.CalledProcessError as e:
            print(f"Error getting file at commit: {commit_hash} using command: {command}")
            print(e.output)
            return None

    def get_patch_of_commit(self, commit_hash):
        """Fetch the patch of a specific commit from the GitHub URL."""
        url = f"https://github.com/mozilla/gecko-dev/commit/{commit_hash}.patch"
        try:
            response = requests.get(url)
            response.raise_for_status()
            patch_text = response.text
            return patch_text
        except requests.RequestException as e:
            print(f"Error fetching patch from URL: {url}")
            print(e)
            return None

    def fetch_pre_fix_vulnerable_code(self, commit_hash, file_path):
        """Fetch vulnerable code segments from the commit prior to the fixing commit."""
        parent_commit_hash = f"{commit_hash}^"
        return self.get_file_at_commit(parent_commit_hash, file_path)

    def fetch_fixed_code(self, commit_hash, file_path):
        """Fetch patched code segments from the commit."""
        return self.get_file_at_commit(commit_hash, file_path)    
    
    def extract_function_signatures(self, code):
        """Extract function signatures from the code."""
        pattern = r'\b(?:(?:static|struct\s+\w+\s*\*?)\s+)*\w+\s+\**\w+\s*\([^)]*\)\s*\{'
        matches = re.findall(pattern, code, re.MULTILINE)
        function_signatures = [match.strip() for match in matches]
        return function_signatures
    
    def extract_files_and_functions_info(self, patch_text):
        """Extract the file paths and function names that contain added or deleted lines from a diff."""
        function_pattern = re.compile(r'^@@.*?@@\s*(\w[\w\s\*]*)\(')
        file_path_pattern = re.compile(r'^diff --git a/(.*?) b/')

        files_info = {}
        current_function = None
        current_file_path = None
        current_added_block = []
        current_deleted_block = []
        lines = patch_text.split('\n')

        for line in lines:
            file_match = file_path_pattern.search(line)
            if file_match:
                current_file_path = file_match.group(1).strip()
                if current_file_path not in files_info:
                    files_info[current_file_path] = {'functions': {}}
                current_function = None  # reset when encountering a new file path
                continue

            match = function_pattern.search(line)
            if match:
                current_function = match.group(1).strip()
                if current_function not in files_info[current_file_path]['functions']:
                    files_info[current_file_path]['functions'][current_function] = {'added': [], 'deleted': []}
                # Clear the current blocks when a new function is found
                current_added_block = []
                current_deleted_block = []
            else:
                if current_file_path:
                    if line.startswith('+') and not line.startswith('+++'):
                        if current_deleted_block:
                            if current_function:
                                files_info[current_file_path]['functions'][current_function]['deleted'].append('\n'.join(current_deleted_block))
                            else:
                                files_info[current_file_path].setdefault('deleted', []).append('\n'.join(current_deleted_block))
                            current_deleted_block = []
                        current_added_block.append(line[1:].strip())
                    elif line.startswith('-') and not line.startswith('---'):
                        if current_added_block:
                            if current_function:
                                files_info[current_file_path]['functions'][current_function]['added'].append('\n'.join(current_added_block))
                            else:
                                files_info[current_file_path].setdefault('added', []).append('\n'.join(current_added_block))
                            current_added_block = []
                        current_deleted_block.append(line[1:].strip())
                    else:
                        if current_added_block:
                            if current_function:
                                files_info[current_file_path]['functions'][current_function]['added'].append('\n'.join(current_added_block))
                            else:
                                files_info[current_file_path].setdefault('added', []).append('\n'.join(current_added_block))
                            current_added_block = []
                        if current_deleted_block:
                            if current_function:
                                files_info[current_file_path]['functions'][current_function]['deleted'].append('\n'.join(current_deleted_block))
                            else:
                                files_info[current_file_path].setdefault('deleted', []).append('\n'.join(current_deleted_block))
                            current_deleted_block = []

        # Append any remaining blocks
        if current_added_block:
            if current_function:
                files_info[current_file_path]['functions'][current_function]['added'].append('\n'.join(current_added_block))
            else:
                files_info[current_file_path].setdefault('added', []).append('\n'.join(current_added_block))
        if current_deleted_block:
            if current_function:
                files_info[current_file_path]['functions'][current_function]['deleted'].append('\n'.join(current_deleted_block))
            else:
                files_info[current_file_path].setdefault('deleted', []).append('\n'.join(current_deleted_block))

        # Clean up empty strings
        for file_path, changes in files_info.items():
            if 'added' in changes:
                changes['added'] = list(filter(None, changes['added']))
            if 'deleted' in changes:
                changes['deleted'] = list(filter(None, changes['deleted']))
            for function_name, function_changes in list(changes['functions'].items()):
                function_changes['added'] = list(filter(None, function_changes['added']))
                function_changes['deleted'] = list(filter(None, function_changes['deleted']))
                if not function_name:
                    del changes['functions'][function_name]

        return files_info

    def extract_function(self, code, function_name):
        """Extract the entire function (vulnerable or patched) by its name."""
        if not isinstance(code, str):
            return None

        # Pattern to match the function start
        function_start_pattern = re.compile(r'\b{}\b\s*\([^{{}}]*\)\s*{{'.format(re.escape(function_name)), re.DOTALL)
        match = function_start_pattern.search(code)
        if not match:
            return None

        start_index = match.start()
        brace_stack = []
        inside_function = False
        end_index = start_index

        for i in range(start_index, len(code)):
            if code[i] == '{':
                brace_stack.append('{')
                inside_function = True
            elif code[i] == '}':
                if brace_stack:
                    brace_stack.pop()
                    if not brace_stack:
                        end_index = i + 1
                        break

        if not inside_function or brace_stack:
            return None

        return code[start_index:end_index]

    def is_change_within_function(self, function, changes):
        """Check if any change block is present within the given function code."""
        function_lines = function.split('\n')
        change_blocks = changes['added'] + changes['deleted']

        for change in change_blocks:
            change_lines = [line.strip() for line in change.split('\n') if line.strip()]
            if not change_lines:
                continue

            for i in range(len(function_lines) - len(change_lines) + 1):
                if all(change_lines[j] == function_lines[i + j].strip() for j in range(len(change_lines))):
                    return True
        return False

    def parase_patch_header(self, patch_text):
        """Parse the patch header to extract the number of files changed, added, and deleted lines."""
        added_lines = 0
        deleted_lines = 0
        files_changed = set()

        file_pattern = re.compile(r'^diff --git a/(.*?) b/(.*?)$', re.MULTILINE)
        matches = file_pattern.findall(patch_text)
        for match in matches:
            files_changed.add(match[0])

        sections = re.split(r'(?m)^diff --git', patch_text)
        for section in sections[1:]:
            lines = section.split('\n')
            for line in lines:
                if line.startswith('+') and not line.startswith('+++'):
                    added_lines += 1
                elif line.startswith('-') and not line.startswith('---'):
                    deleted_lines += 1

        return len(files_changed), added_lines, deleted_lines

    def extract_commit_description(self, commit_hash):
        """Extract the commit description using git log."""
        try:
            result = subprocess.run(
                ['git', '-C', self.repo_path, 'log', '--format=%B', '-n', '1', commit_hash],
                stdout=subprocess.PIPE,
                text=True,
                encoding='utf-8'
            )
            return result.stdout.strip()
        except subprocess.CalledProcessError as e:
            print(f"Error extracting description for commit {commit_hash}")
            print(e.output)
            return None

    def build_code_blocks(self, files_info, commit_hash):
        """Build the vulnerable and patched code blocks from the patch info."""
        vulnerable_code_block = ""
        patched_code_block = ""

        # Process file-level changes
        for file_path, file_changes in files_info.items():
            file_header_printed_vulnerable = False
            file_header_printed_patched = False

            # Process function-level changes
            functions_to_modify = []
            for function_name, changes in file_changes['functions'].items():
                if not function_name:
                    continue
                vulnerable_code = self.fetch_pre_fix_vulnerable_code(commit_hash, file_path)
                patched_code = self.fetch_fixed_code(commit_hash, file_path)

                vulnerable_function = self.extract_function(vulnerable_code, function_name)
                patched_function = self.extract_function(patched_code, function_name)

                # Check if the change appears within the function body
                if vulnerable_function and patched_function:
                    if (self.is_change_within_function(vulnerable_function, changes) or
                        self.is_change_within_function(patched_function, changes)):
                        if not file_header_printed_vulnerable:
                            vulnerable_code_block += f"// File path: {file_path}\n"
                            file_header_printed_vulnerable = True
                        if not file_header_printed_patched:
                            patched_code_block += f"// File path: {file_path}\n"
                            file_header_printed_patched = True
                        vulnerable_code_block += f"{vulnerable_function}\n"
                        patched_code_block += f"{patched_function}\n"
                    else:
                        # Look for function signatures in the added/deleted lines
                        pattern = r'\b([a-zA-Z_][a-zA-Z0-9_\* ]*\s+[a-zA-Z_][a-zA-Z0-9_]*)\s*\([^)]*\)'
                        added_function_signatures = re.findall(pattern, '\n'.join(changes['added']), re.MULTILINE)
                        deleted_function_signatures = re.findall(pattern, '\n'.join(changes['deleted']), re.MULTILINE)
                        if added_function_signatures or deleted_function_signatures:
                            new_function_name = added_function_signatures[0] if added_function_signatures else deleted_function_signatures[0]
                            functions_to_modify.append((function_name, new_function_name))
                        else:
                            functions_to_modify.append((function_name, ""))
                else:
                    if changes.get('added'):
                        sigs = self.extract_function_signatures('\n'.join(changes['added']))
                        if sigs:
                            new_function_name = sigs[0]
                            functions_to_modify.append((function_name, new_function_name))
                            patched_function = self.extract_function(patched_code, new_function_name)
                        else:
                            patched_function = '\n'.join(changes['added'])
                            functions_to_modify.append((function_name, ""))
                        if not file_header_printed_patched:
                            patched_code_block += f"// File path: {file_path}\n"
                            file_header_printed_patched = True
                        patched_code_block += f"{patched_function}\n"

                    if changes.get('deleted'):
                        sigs = self.extract_function_signatures('\n'.join(changes['deleted']))
                        if sigs:
                            new_function_name = sigs[0]
                            functions_to_modify.append((function_name, new_function_name))
                            vulnerable_function = self.extract_function(vulnerable_code, new_function_name)
                        else:
                            vulnerable_function = '\n'.join(changes['deleted'])
                            functions_to_modify.append((function_name, ""))
                        if not file_header_printed_vulnerable:
                            vulnerable_code_block += f"// File path: {file_path}\n"
                            file_header_printed_vulnerable = True
                        vulnerable_code_block += f"{vulnerable_function}\n"

            # Process any functions that were marked for modification
            functions_to_modify = list(set(functions_to_modify))
            for function_name, new_function_name in functions_to_modify:
                if not function_name:
                    continue
                if new_function_name in file_changes['functions']:
                    combine_add = file_changes['functions'][function_name]['added'] + file_changes['functions'][new_function_name]['added']
                    combine_del = file_changes['functions'][function_name]['deleted'] + file_changes['functions'][new_function_name]['deleted']
                    file_changes['functions'][new_function_name] = {'added': combine_add, 'deleted': combine_del}
                    del file_changes['functions'][function_name]
                else:
                    original_value = file_changes['functions'][function_name]
                    del file_changes['functions'][function_name]
                    file_changes['functions'][new_function_name] = original_value

                if not new_function_name:
                    continue
                if new_function_name in vulnerable_code_block or new_function_name in patched_code_block:
                    continue
                vulnerable_code = self.fetch_pre_fix_vulnerable_code(commit_hash, file_path)
                patched_code = self.fetch_fixed_code(commit_hash, file_path)
                vulnerable_function = self.extract_function(vulnerable_code, new_function_name)
                patched_function = self.extract_function(patched_code, new_function_name)
                if vulnerable_function or patched_function:
                    if not file_header_printed_vulnerable:
                        vulnerable_code_block += f"// File path: {file_path}\n"
                        file_header_printed_vulnerable = True
                    if not file_header_printed_patched:
                        patched_code_block += f"// File path: {file_path}\n"
                        file_header_printed_patched = True
                    vulnerable_code_block += f"{vulnerable_function}\n"
                    patched_code_block += f"{patched_function}\n"

            # Handle file-level added/deleted lines if present
            if file_changes.get('added'):
                if not file_header_printed_patched:
                    patched_code_block += f"// File path: {file_path}\n"
                    file_header_printed_patched = True
                patched_code_block += f"{''.join(file_changes['added'])}\n"

            if file_changes.get('deleted'):
                if not file_header_printed_vulnerable:
                    vulnerable_code_block += f"// File path: {file_path}\n"
                    file_header_printed_vulnerable = True
                vulnerable_code_block += f"{''.join(file_changes['deleted'])}\n"

        return vulnerable_code_block, patched_code_block

    def num_functions_changed(self, vulnerable_code_block, patched_code_block):
        """Calculate the number of functions changed between the two code blocks."""
        vulnerable_functions = self.extract_function_signatures(vulnerable_code_block)
        patched_functions = self.extract_function_signatures(patched_code_block)
        unique_functions = set(vulnerable_functions + patched_functions)
        return len(unique_functions)

def main():

    repo_path = "/home/azibaeir/Research/Benchmarking/gecko-dev"  # e.g., "/home/user/gecko-dev"
    commit_hash = "d3fc632669c98bc8a94c820be75455ca4b446cf7"

    git_interaction = GitInteraction(repo_path)
    patch_text = git_interaction.get_patch_of_commit(commit_hash)
    if not patch_text:
        print("Failed to retrieve patch.")
        sys.exit(1)

    files_info = git_interaction.extract_files_and_functions_info(patch_text)
    vulnerable_code_block, patched_code_block = git_interaction.build_code_blocks(files_info, commit_hash)

    print("Vulnerable code block:")
    print(vulnerable_code_block)
    print("\nPatched code block:")
    print(patched_code_block)

if __name__ == "__main__":
    main()


Vulnerable code block:
// File path: netwerk/base/src/nsBaseChannel.cpp
if (NS_FAILED(mStatus)) {


Patched code block:
// File path: netwerk/base/src/nsBaseChannel.cpp
PRBool doNotify = PR_TRUE;else
doNotify = PR_FALSE;if (doNotify) {



## pre&curr-commit file

In [22]:

# fetch pre-commit and current version of a file
#!/usr/bin/env python3
import os
import subprocess
import sys

def get_changed_files(repo_path, commit_hash):
    """
    Use git diff-tree to list all files changed in the given commit.
    """
    cmd = ["git", "diff-tree", "--no-commit-id", "--name-only", "-r", commit_hash]
    try:
        result = subprocess.run(
            cmd, cwd=repo_path, text=True, capture_output=True, check=True
        )
        # Each line is a file path.
        files = result.stdout.splitlines()
        return files
    except subprocess.CalledProcessError as e:
        print(f"[ERROR] Failed to get changed files: {e}")
        sys.exit(1)

def get_precommit_file(repo_path, commit_hash, file_path):
    """
    Retrieve the content of the file at the parent of the commit (i.e. the precommit version).
    """
    # Note: using f"{commit_hash}^" to indicate the parent commit.
    cmd = ["git", "show", f"{commit_hash}^:{file_path}"]
    try:
        result = subprocess.run(
            cmd, cwd=repo_path, text=True, capture_output=True, check=True
        )
        return result.stdout
    except subprocess.CalledProcessError as e:
        print(f"[ERROR] Failed to get precommit version of {file_path}: {e}")
        return None

def get_current_file(repo_path, commit_hash, file_path):
    """
    Retrieve the content of the file at the commit (i.e. the current commit version).
    """
    cmd = ["git", "show", f"{commit_hash}:{file_path}"]
    try:
        result = subprocess.run(
            cmd, cwd=repo_path, text=True, capture_output=True, check=True
        )
        return result.stdout
    except subprocess.CalledProcessError as e:
        print(f"[ERROR] Failed to get current version of {file_path}: {e}")
        return None

def main():
    # --- Configuration ---
    # Update this path to your local repository clone.
    repo_path = "/home/azibaeir/Research/Benchmarking/gecko-dev"  # e.g., "/home/user/gecko-dev"
    commit_hash = "b59073dc8fae65cd9dc81c0137b0f7a9911873e2"

    if not os.path.isdir(repo_path):
        print(f"[ERROR] Repository path not found: {repo_path}")
        sys.exit(1)

    # Get list of changed files for the commit.
    changed_files = get_changed_files(repo_path, commit_hash)
    if not changed_files:
        print("No changed files found for this commit.")
        return

    # For each changed file, fetch and print both the precommit and current versions.
    for file_path in changed_files:
        print(f"========== Precommit version of file: {file_path} ==========")
        pre_content = get_precommit_file(repo_path, commit_hash, file_path)
        if pre_content is not None:
            # Save precommit content to a file
            pre_filename = f"{commit_hash}_pre_{os.path.basename(file_path)}"
            with open(pre_filename, 'w', encoding='utf-8') as f:
                f.write(pre_content)
            # print(pre_content)
        else:
            print(f"[WARNING] Could not retrieve precommit content for {file_path}")
        
        print(f"========== Current version of file: {file_path} ==========")
        curr_content = get_current_file(repo_path, commit_hash, file_path)
        if curr_content is not None:
            # Save current commit content to a file
            curr_filename = f"{commit_hash}_curr_{os.path.basename(file_path)}"
            with open(curr_filename, 'w', encoding='utf-8') as f:
                f.write(curr_content)
            # print(curr_content)
        else:
            print(f"[WARNING] Could not retrieve current content for {file_path}")

if __name__ == "__main__":
    main()
#  pre&curr-commit file



## all of the code except code blocks

In [1]:

#!/usr/bin/env python3
import os
import sys
import subprocess
import re
import requests
import platform
import clang.cindex

# -------------------------------------------------------------------
# Configure libclang path based on OS
if platform.system() == "Darwin":
    clang.cindex.Config.set_library_file("/Applications/Xcode.app/Contents/Frameworks/libclang.dylib")
elif platform.system() == "Linux":
    possible_paths = [
        "/usr/lib/llvm-11/lib/libclang.so",
        "/usr/lib/libclang.so",
        "/usr/lib/llvm/lib/libclang.so"
    ]
    for path in possible_paths:
        if os.path.exists(path):
            clang.cindex.Config.set_library_file(path)
            break
# -------------------------------------------------------------------
# Function extraction routines using text-based and Clang-based methods

def find_function(source_code, function_name, class_name=None, filename=None):
    """
    Find a function definition in the given source code using two methods:
      1. Text-based parsing (suitable for many C files)
      2. Clang-based parsing (better for C++ files)
    
    The function first attempts to locate the function using text parsing.
    If that fails, it falls back to a Clang-based approach.
    
    Parameters:
        source_code (str): The source code to search.
        function_name (str): The name of the function to find.
        class_name (str, optional): If searching for a class member, provide the class name.
        filename (str, optional): Filename hint for Clang parsing (default is "temp.cpp").
    
    Returns:
        str or None: The full function definition if found; otherwise, None.
    """
    # --- Text-based parsing ---
    lines = source_code.split('\n')
    for i, line in enumerate(lines):
        if function_name in line and '(' in line and not line.strip().startswith('//'):
            words = line.strip().split()
            if function_name in words or f"{function_name}(" in line:
                print(f"[DEBUG] Found potential function definition (text): {line.strip()}")
                brace_count = 0
                start_line = i
                found_opening = False
                # Try to adjust for multi-line definitions
                while start_line > 0 and not lines[start_line - 1].strip().endswith(';'):
                    start_line -= 1
                    if lines[start_line].strip().startswith('/*') or lines[start_line].strip().startswith('*'):
                        continue
                    if lines[start_line].strip():
                        break
                function_lines = []
                for j in range(start_line, len(lines)):
                    current_line = lines[j]
                    function_lines.append(current_line)
                    for char in current_line:
                        if char == '{':
                            found_opening = True
                            brace_count += 1
                        elif char == '}':
                            brace_count -= 1
                    if found_opening and brace_count == 0:
                        return '\n'.join(function_lines)
    
    # --- Clang-based parsing ---
    import clang.cindex
    index = clang.cindex.Index.create()
    args = [
        "-x", "c++",
        "--std=c++11",
        "-fparse-all-comments",
        "-I/usr/include",
        "-I/usr/local/include",
        "-I.",
        "-DMOZILLA_INTERNAL_API",
        "-DNDEBUG",
        "-DTRIMMED"
    ]
    try:
        tu = index.parse(
            filename or "temp.cpp",
            args=args,
            unsaved_files=[(filename or "temp.cpp", source_code)],
            options=clang.cindex.TranslationUnit.PARSE_SKIP_FUNCTION_BODIES
        )
    except Exception as e:
        print(f"[ERROR] Clang failed to parse {filename}: {e}")
        return None

    if not tu:
        print("[ERROR] Failed to create translation unit")
        return None

    for diag in tu.diagnostics:
        if diag.severity >= clang.cindex.Diagnostic.Warning:
            severity = {2: "Warning", 3: "Error", 4: "Fatal"}.get(diag.severity, "Unknown")
            print(f"[{severity}] {diag.spelling}")

    # Use a simple line search in the source as a fallback after Clang parsing
    for line in source_code.split('\n'):
        if class_name:
            search_pattern = f"{class_name}::{function_name}"
        else:
            words = line.split()
            if function_name in words and '(' in line:
                search_pattern = function_name
            else:
                continue

        if search_pattern in line:
            print(f"[DEBUG] Found potential function definition (clang): {line.strip()}")
            start_idx = source_code.find(line)
            if start_idx != -1:
                brace_count = 0
                end_idx = start_idx
                found_opening = False
                for i in range(start_idx, len(source_code)):
                    if source_code[i] == '{':
                        found_opening = True
                        brace_count += 1
                    elif source_code[i] == '}':
                        brace_count -= 1
                        if found_opening and brace_count == 0:
                            end_idx = i + 1
                            break
                if end_idx > start_idx:
                    return source_code[start_idx:end_idx]
    return None

# -------------------------------------------------------------------
# Extract function names from a patch/diff file.

def extract_functions_from_patch(patch_content):
    """
    Extract all function names from a patch/diff file where changes (+ or -) occurred.
    Returns a list of tuples (function_name, class_name).
    """
    lines = patch_content.split('\n')
    current_function = None
    current_class = None
    in_function = False
    functions = []
    
    for i, line in enumerate(lines):
        stripped_line = line.strip()
        
        # Handle @@ context lines
        if line.startswith('@@'):
            in_function = False
            current_function = None
            current_class = None
            
            if '@@ ' in line:
                context_part = line.split('@@ ')[-1].strip()
                if '::' in context_part:
                    parts = context_part.split('::')
                    current_class = parts[0].strip().replace('@', '').strip()
                    current_function = parts[1].split('(')[0].strip()
                    in_function = True
                else:
                    parts = context_part.split('(')
                    if len(parts) > 1:
                        func_parts = parts[0].split()
                        if func_parts:
                            current_function = func_parts[-1].strip()
                            current_class = None
                            in_function = True

        if not in_function:
            if '::' in stripped_line and '(' in stripped_line and not stripped_line.startswith(('+', '-', '//', '@')):
                parts = stripped_line.split('::')
                if len(parts) == 2:
                    current_class = parts[0].strip()
                    current_function = parts[1].split('(')[0].strip()
                    in_function = True
            elif '(' in stripped_line and not stripped_line.startswith(('+', '-', '//', '@', '}')):
                parts = stripped_line.split('(')[0].strip().split()
                if parts and not parts[0] in ['if', 'while', 'for', 'switch', 'return']:
                    current_function = parts[-1]
                    current_class = None
                    in_function = True
        
        if in_function and line.startswith(('+', '-')) and not line.startswith(('+++ ', '--- ')):
            if current_class and '@' in current_class:
                current_class = current_class.split('@')[-1].strip()
            current = (current_function, current_class)
            if current not in functions:
                functions.append(current)
    print(f"functions: {functions}")
    return functions

# -------------------------------------------------------------------
# Git interaction routines

def get_patch(repo_path, commit_hash):
    """
    Retrieve the full patch (diff) for the given commit.
    """
    cmd = ["git", "show", commit_hash]
    try:
        result = subprocess.run(cmd, cwd=repo_path, text=True, capture_output=True, check=True)
        return result.stdout
    except subprocess.CalledProcessError as e:
        print(f"[ERROR] Failed to get patch for commit {commit_hash}: {e}")
        return None

def get_changed_files(repo_path, commit_hash):
    """
    Use git diff-tree to list all files changed in the given commit.
    """
    cmd = ["git", "diff-tree", "--no-commit-id", "--name-only", "-r", commit_hash]
    try:
        result = subprocess.run(cmd, cwd=repo_path, text=True, capture_output=True, check=True)
        return result.stdout.splitlines()
    except subprocess.CalledProcessError as e:
        print(f"[ERROR] Failed to get changed files: {e}")
        sys.exit(1)

def get_precommit_file(repo_path, commit_hash, file_path):
    """
    Retrieve the content of the file at the parent of the commit (pre-commit version).
    """
    cmd = ["git", "show", f"{commit_hash}^:{file_path}"]
    try:
        result = subprocess.run(cmd, cwd=repo_path, text=True, capture_output=True, check=True)
        return result.stdout
    except subprocess.CalledProcessError as e:
        print(f"[ERROR] Failed to get precommit version of {file_path}: {e}")
        return None

def get_current_file(repo_path, commit_hash, file_path):
    """
    Retrieve the content of the file at the commit (current version).
    """
    cmd = ["git", "show", f"{commit_hash}:{file_path}"]
    try:
        result = subprocess.run(cmd, cwd=repo_path, text=True, capture_output=True, check=True)
        return result.stdout
    except subprocess.CalledProcessError as e:
        print(f"[ERROR] Failed to get current version of {file_path}: {e}")
        return None

# -------------------------------------------------------------------
# Main entry point

def main():
    # --- Configuration ---
    repo_path = "/home/azibaeir/Research/Benchmarking/gecko-dev"  # Update as needed.
    commit_hash = "d3fc632669c98bc8a94c820be75455ca4b446cf7"
    # commit_hash = "b59073dc8fae65cd9dc81c0137b0f7a9911873e2"
    # commit_hash = "75b14a0e97e07f63ad55f41b7d978aeba31d711e"
    # commit_hash = "e29d8ab4e4e47c0f84ecd43c9d100791d265f71c"

    if not os.path.isdir(repo_path):
        print(f"[ERROR] Repository path not found: {repo_path}")
        sys.exit(1)

    # Retrieve patch content and extract function names from the patch.
    patch_content = get_patch(repo_path, commit_hash)
    if patch_content is None:
        print("[ERROR] Could not retrieve patch content.")
        return

    print("========== Extracting functions from patch ==========")
    functions_in_patch = extract_functions_from_patch(patch_content)
    if functions_in_patch:
        for func, cls in functions_in_patch:
            if cls:
                print(f"Function: {func}, Class: {cls}")
            else:
                print(f"Function: {func}")
    else:
        print("No functions extracted from patch.")
    print("=====================================================\n")

    # Get list of changed files for this commit.
    changed_files = get_changed_files(repo_path, commit_hash)
    if not changed_files:
        print("No changed files found for this commit.")
        return

    # Prepare accumulators for combined code blocks.
    combined_vulnerable_code = []
    combined_patched_code = []

    # For each changed file, fetch pre-commit and current versions,
    # then extract the complete function definitions for each changed function.
    for file_path in changed_files:
        print(f"Processing file: {file_path}")
        pre_content = get_precommit_file(repo_path, commit_hash, file_path)
        curr_content = get_current_file(repo_path, commit_hash, file_path)
        
        if pre_content is None or curr_content is None:
            print(f"[WARNING] Skipping file {file_path} due to missing content.")
            continue

        # Optionally, save file contents
        pre_filename = f"{commit_hash}_pre_{os.path.basename(file_path)}"
        with open(pre_filename, 'w', encoding='utf-8') as f:
            f.write(pre_content)
        curr_filename = f"{commit_hash}_curr_{os.path.basename(file_path)}"
        with open(curr_filename, 'w', encoding='utf-8') as f:
            f.write(curr_content)

        file_vulnerable_blocks = []
        file_patched_blocks = []

        # For each function extracted from the patch, check if it appears in the file.
        for func_name, class_name in functions_in_patch:
            # Check if the function name is present in at least one version.
            if func_name not in pre_content and func_name not in curr_content:
                continue

            print(f"  Attempting extraction for function '{func_name}'", end="")
            if class_name:
                print(f" (Class: {class_name})")
            else:
                print()

            vulnerable_func = find_function(pre_content, func_name, class_name, filename=file_path)
            patched_func = find_function(curr_content, func_name, class_name, filename=file_path)

            if vulnerable_func:
                file_vulnerable_blocks.append(vulnerable_func)
            else:
                print(f"    [INFO] Vulnerable version of {func_name} not found in {file_path}")

            if patched_func:
                file_patched_blocks.append(patched_func)
            else:
                print(f"    [INFO] Patched version of {func_name} not found in {file_path}")

        # Combine file-level code: if functions were extracted, prepend a file header.
        if file_vulnerable_blocks:
            combined_vulnerable_code.append(f"// File: {file_path}\n" + "\n\n".join(file_vulnerable_blocks))
        else:
            # If no functions, include the whole file as vulnerable code.
            combined_vulnerable_code.append(f"// File: {file_path}\n" + pre_content)

        if file_patched_blocks:
            combined_patched_code.append(f"// File: {file_path}\n" + "\n\n".join(file_patched_blocks))
        else:
            combined_patched_code.append(f"// File: {file_path}\n" + curr_content)

    # Build the final combined code blocks.
    vulnerable_code_block = "\n\n".join(combined_vulnerable_code)
    patched_code_block = "\n\n".join(combined_patched_code)

    print("========== Combined Vulnerable Code Block ==========")
    print(vulnerable_code_block)
    print("=====================================================\n")
    print("========== Combined Patched Code Block ==========")
    print(patched_code_block)
    print("=====================================================")

if __name__ == "__main__":
    main()


# all of the code except code blocks


functions: [('HandleAsyncRedirect', 'nsBaseChannel')]
Function: HandleAsyncRedirect, Class: nsBaseChannel

Processing file: netwerk/base/src/nsBaseChannel.cpp
  Attempting extraction for function 'HandleAsyncRedirect' (Class: nsBaseChannel)
[DEBUG] Found potential function definition (text): nsBaseChannel::HandleAsyncRedirect(nsIChannel* newChannel)
[DEBUG] Found potential function definition (text): nsBaseChannel::HandleAsyncRedirect(nsIChannel* newChannel)
// File: netwerk/base/src/nsBaseChannel.cpp
void
nsBaseChannel::HandleAsyncRedirect(nsIChannel* newChannel)
{
  NS_ASSERTION(!mPump, "Shouldn't have gotten here");
  if (NS_SUCCEEDED(mStatus)) {
      nsresult rv = Redirect(newChannel, nsIChannelEventSink::REDIRECT_INTERNAL,
                             PR_TRUE);
      if (NS_FAILED(rv))
          Cancel(rv);
  }

  mWaitingOnAsyncRedirect = PR_FALSE;

  if (NS_FAILED(mStatus)) {
    // Notify our consumer ourselves
    mListener->OnStartRequest(this, mListenerContext);
    mListen

In [3]:
#!/usr/bin/env python3
import os
import sys
import subprocess
import re
import sqlite3
import platform
import clang.cindex

# -------------------------------------------------------------------
# Configure libclang path based on OS
if platform.system() == "Darwin":
    clang.cindex.Config.set_library_file("/Applications/Xcode.app/Contents/Frameworks/libclang.dylib")
elif platform.system() == "Linux":
    possible_paths = [
        "/usr/lib/llvm-11/lib/libclang.so",
        "/usr/lib/libclang.so",
        "/usr/lib/llvm/lib/libclang.so"
    ]
    for path in possible_paths:
        if os.path.exists(path):
            clang.cindex.Config.set_library_file(path)
            break
# -------------------------------------------------------------------
# Function extraction routines using text-based and Clang-based methods

def find_function(source_code, function_name, class_name=None, filename=None):
    """
    Find a function definition in the given source code using text- and Clang-based parsing.
    Returns the complete function definition if found; otherwise, None.
    """
    lines = source_code.split('\n')
    for i, line in enumerate(lines):
        if function_name in line and '(' in line and not line.strip().startswith('//'):
            words = line.strip().split()
            if function_name in words or f"{function_name}(" in line:
                brace_count = 0
                start_line = i
                found_opening = False
                while start_line > 0 and not lines[start_line - 1].strip().endswith(';'):
                    start_line -= 1
                    if lines[start_line].strip().startswith('/*') or lines[start_line].strip().startswith('*'):
                        continue
                    if lines[start_line].strip():
                        break
                function_lines = []
                for j in range(start_line, len(lines)):
                    current_line = lines[j]
                    function_lines.append(current_line)
                    for char in current_line:
                        if char == '{':
                            found_opening = True
                            brace_count += 1
                        elif char == '}':
                            brace_count -= 1
                    if found_opening and brace_count == 0:
                        return '\n'.join(function_lines)
    # Fallback: Clang-based parsing
    import clang.cindex
    index = clang.cindex.Index.create()
    args = [
        "-x", "c++",
        "--std=c++11",
        "-fparse-all-comments",
        "-I/usr/include",
        "-I/usr/local/include",
        "-I.",
        "-DMOZILLA_INTERNAL_API",
        "-DNDEBUG",
        "-DTRIMMED"
    ]
    try:
        tu = index.parse(filename or "temp.cpp",
                         args=args,
                         unsaved_files=[(filename or "temp.cpp", source_code)],
                         options=clang.cindex.TranslationUnit.PARSE_SKIP_FUNCTION_BODIES)
    except Exception as e:
        print(f"[ERROR] Clang failed to parse {filename}: {e}")
        return None
    if not tu:
        print("[ERROR] Failed to create translation unit")
        return None
    for diag in tu.diagnostics:
        if diag.severity >= clang.cindex.Diagnostic.Warning:
            severity = {2: "Warning", 3: "Error", 4: "Fatal"}.get(diag.severity, "Unknown")
            print(f"[{severity}] {diag.spelling}")
    for line in source_code.split('\n'):
        if class_name:
            search_pattern = f"{class_name}::{function_name}"
        else:
            words = line.split()
            if function_name in words and '(' in line:
                search_pattern = function_name
            else:
                continue
        if search_pattern in line:
            start_idx = source_code.find(line)
            if start_idx != -1:
                brace_count = 0
                end_idx = start_idx
                found_opening = False
                for i in range(start_idx, len(source_code)):
                    if source_code[i] == '{':
                        found_opening = True
                        brace_count += 1
                    elif source_code[i] == '}':
                        brace_count -= 1
                        if found_opening and brace_count == 0:
                            end_idx = i + 1
                            break
                if end_idx > start_idx:
                    return source_code[start_idx:end_idx]
    return None

def extract_functions_from_patch(patch_content):
    """
    Extract all function names from a patch/diff file where changes occurred.
    Returns a list of tuples (function_name, class_name).
    """
    lines = patch_content.split('\n')
    current_function = None
    current_class = None
    in_function = False
    functions = []
    
    for i, line in enumerate(lines):
        stripped_line = line.strip()
        if line.startswith('@@'):
            in_function = False
            current_function = None
            current_class = None
            if '@@ ' in line:
                context_part = line.split('@@ ')[-1].strip()
                if '::' in context_part:
                    parts = context_part.split('::')
                    current_class = parts[0].strip().replace('@', '').strip()
                    current_function = parts[1].split('(')[0].strip()
                    in_function = True
                else:
                    parts = context_part.split('(')
                    if len(parts) > 1:
                        func_parts = parts[0].split()
                        if func_parts:
                            current_function = func_parts[-1].strip()
                            current_class = None
                            in_function = True
        if not in_function:
            if '::' in stripped_line and '(' in stripped_line and not stripped_line.startswith(('+', '-', '//', '@')):
                parts = stripped_line.split('::')
                if len(parts) == 2:
                    current_class = parts[0].strip()
                    current_function = parts[1].split('(')[0].strip()
                    in_function = True
            elif '(' in stripped_line and not stripped_line.startswith(('+', '-', '//', '@', '}')):
                parts = stripped_line.split('(')[0].strip().split()
                if parts and not parts[0] in ['if', 'while', 'for', 'switch', 'return']:
                    current_function = parts[-1]
                    current_class = None
                    in_function = True
        if in_function and line.startswith(('+', '-')) and not line.startswith(('+++ ', '--- ')):
            if current_class and '@' in current_class:
                current_class = current_class.split('@')[-1].strip()
            current = (current_function, current_class)
            if current not in functions:
                functions.append(current)
    return functions

def get_patch(repo_path, commit_hash):
    """
    Retrieve the full patch (diff) for the given commit.
    """
    cmd = ["git", "show", commit_hash]
    try:
        result = subprocess.run(cmd, cwd=repo_path, text=True, capture_output=True, check=True, errors='replace')
        return result.stdout
    except subprocess.CalledProcessError as e:
        print(f"[ERROR] Failed to get patch for commit {commit_hash}: {e}")
        return None

def get_changed_files(repo_path, commit_hash):
    """
    Use git diff-tree to list all files changed in the given commit.
    """
    cmd = ["git", "diff-tree", "--no-commit-id", "--name-only", "-r", commit_hash]
    try:
        result = subprocess.run(cmd, cwd=repo_path, text=True, capture_output=True, check=True, errors='replace')
        return result.stdout.splitlines()
    except subprocess.CalledProcessError as e:
        print(f"[ERROR] Failed to get changed files: {e}")
        sys.exit(1)

def get_precommit_file(repo_path, commit_hash, file_path):
    """
    Retrieve the content of the file at the parent of the commit (pre-commit version).
    """
    cmd = ["git", "show", f"{commit_hash}^:{file_path}"]
    try:
        result = subprocess.run(cmd, cwd=repo_path, text=True, capture_output=True, check=True, errors='replace')
        return result.stdout
    except subprocess.CalledProcessError as e:
        print(f"[ERROR] Failed to get precommit version of {file_path}: {e}")
        return None

def get_current_file(repo_path, commit_hash, file_path):
    """
    Retrieve the content of the file at the commit (current version).
    """
    cmd = ["git", "show", f"{commit_hash}:{file_path}"]
    try:
        result = subprocess.run(cmd, cwd=repo_path, text=True, capture_output=True, check=True, errors='replace')
        return result.stdout
    except subprocess.CalledProcessError as e:
        print(f"[ERROR] Failed to get current version of {file_path}: {e}")
        return None

# -------------------------------------------------------------------
# Main entry point for updating code blocks in the database

def main():
    repo_path = "/home/azibaeir/Research/Benchmarking/gecko-dev"
    commit_db_path = "/home/azibaeir/Research/Benchmarking/project/vulnerability_dataset/database/database.sqlite"
    
    conn = sqlite3.connect(commit_db_path)
    cursor = conn.cursor()
    cursor.execute("UPDATE vulnerabilities SET VULNERABLE_CODE_BLOCK = '', PATCHED_CODE_BLOCK = '' WHERE PROJECT = 'mozilla'")
    conn.commit()
    
    cursor.execute("SELECT COMMIT_HASH FROM vulnerabilities WHERE PROJECT = 'mozilla'")
    commit_hashes = [row[0] for row in cursor.fetchall()]
    print(f"Processing {len(commit_hashes)} commit(s) in mozilla.")
    
    for commit_hash in commit_hashes:
        print(f"\nProcessing commit: {commit_hash}")
        try:
            patch_content = get_patch(repo_path, commit_hash)
            if patch_content is None:
                print(f"[ERROR] Could not retrieve patch for commit {commit_hash}")
                continue

            functions_in_patch = extract_functions_from_patch(patch_content)
            changed_files = get_changed_files(repo_path, commit_hash)
            if not changed_files:
                print(f"[ERROR] No changed files for commit {commit_hash}")
                continue

            combined_vulnerable_blocks = []
            combined_patched_blocks = []

            for file_path in changed_files:
                print(f"  Processing file: {file_path}")
                pre_content = get_precommit_file(repo_path, commit_hash, file_path)
                curr_content = get_current_file(repo_path, commit_hash, file_path)
                if pre_content is None or curr_content is None:
                    print(f"    [ERROR] Missing content for {file_path}, skipping.")
                    continue

                # Save file contents to temporary files.
                pre_filename = f"{commit_hash}_pre_{os.path.basename(file_path)}"
                with open(pre_filename, 'w', encoding='utf-8') as f:
                    f.write(pre_content)
                curr_filename = f"{commit_hash}_curr_{os.path.basename(file_path)}"
                with open(curr_filename, 'w', encoding='utf-8') as f:
                    f.write(curr_content)
                
                file_vulnerable_blocks = []
                file_patched_blocks = []
                source_extensions = (".c", ".cpp", ".h", ".hpp", ".cc", ".cxx")
                if file_path.lower().endswith(source_extensions):
                    for func_name, class_name in functions_in_patch:
                        if func_name not in pre_content and func_name not in curr_content:
                            continue
                        vuln_func = find_function(pre_content, func_name, class_name, filename=file_path)
                        patch_func = find_function(curr_content, func_name, class_name, filename=file_path)
                        if vuln_func:
                            file_vulnerable_blocks.append(vuln_func)
                        else:
                            print(f"      [ERROR] Vulnerable version not found for {func_name} in {file_path}")
                        if patch_func:
                            file_patched_blocks.append(patch_func)
                        else:
                            print(f"      [ERROR] Patched version not found for {func_name} in {file_path}")
                else:
                    file_vulnerable_blocks.append(pre_content)
                    file_patched_blocks.append(curr_content)
                
                combined_vulnerable_blocks.append(f"// File: {file_path}\n" + "\n\n".join(file_vulnerable_blocks))
                combined_patched_blocks.append(f"// File: {file_path}\n" + "\n\n".join(file_patched_blocks))
                
                # Remove temporary files.
                os.remove(pre_filename)
                os.remove(curr_filename)
            
            vulnerable_code_block = "\n\n".join(combined_vulnerable_blocks)
            patched_code_block = "\n\n".join(combined_patched_blocks)
            
            cursor.execute("""
                UPDATE vulnerabilities
                SET VULNERABLE_CODE_BLOCK = ?,
                    PATCHED_CODE_BLOCK = ?
                WHERE COMMIT_HASH = ? AND PROJECT = 'mozilla'
            """, (vulnerable_code_block, patched_code_block, commit_hash))
            conn.commit()
            print(f"  Updated commit {commit_hash} with new code blocks.")
        except Exception as e:
            print(f"[ERROR] Processing commit {commit_hash} failed: {e}")
    
    cursor.close()
    conn.close()
    print("Processing complete.")

if __name__ == "__main__":
    main()


Processing 321 commit(s) in mozilla.

Processing commit: d3fc632669c98bc8a94c820be75455ca4b446cf7
  Processing file: netwerk/base/src/nsBaseChannel.cpp
  Updated commit d3fc632669c98bc8a94c820be75455ca4b446cf7 with new code blocks.

Processing commit: d35623d3e5c126c824408e8b8e5c4f28877792eb
  Processing file: content/base/public/nsContentCID.h
  Processing file: content/html/document/src/Makefile.in
  Processing file: content/html/document/src/nsHTMLFragmentContentSink.cpp
  Processing file: content/xml/document/src/nsXMLFragmentContentSink.cpp
  Processing file: editor/libeditor/html/nsHTMLDataTransfer.cpp
  Processing file: editor/libeditor/html/tests/Makefile.in
  Processing file: editor/libeditor/html/tests/test_bug520189.html
[ERROR] Failed to get precommit version of editor/libeditor/html/tests/test_bug520189.html: Command '['git', 'show', 'd35623d3e5c126c824408e8b8e5c4f28877792eb^:editor/libeditor/html/tests/test_bug520189.html']' returned non-zero exit status 128.
    [ERROR] 

## all of the code

In [18]:
#!/usr/bin/env python3
import subprocess
import re
import os
import requests
import sys
import platform
import clang.cindex

# Configure libclang path based on OS
if platform.system() == "Darwin":
    clang.cindex.Config.set_library_file("/Applications/Xcode.app/Contents/Frameworks/libclang.dylib")
elif platform.system() == "Linux":
    possible_paths = [
        "/usr/lib/llvm-11/lib/libclang.so",
        "/usr/lib/libclang.so",
        "/usr/lib/llvm/lib/libclang.so"
    ]
    for path in possible_paths:
        if os.path.exists(path):
            clang.cindex.Config.set_library_file(path)
            break

class GitInteraction:
    def __init__(self, repo_path):
        self.repo_path = repo_path

    def get_patch(self, commit_hash):
        """
        Retrieve the full patch (diff) for the given commit.
        """
        url = f"https://github.com/mozilla/gecko-dev/commit/{commit_hash}.patch"
        try:
            response = requests.get(url)
            response.raise_for_status()
            return response.text
        except requests.RequestException as e:
            print(f"Error fetching patch from URL: {url}")
            print(e)
            return None

    @staticmethod
    def get_changed_files(repo_path, commit_hash):
        """
        Use git diff-tree to list all files changed in the given commit.
        """
        cmd = ["git", "diff-tree", "--no-commit-id", "--name-only", "-r", commit_hash]
        try:
            result = subprocess.run(cmd, cwd=repo_path, text=True, capture_output=True, check=True)
            return result.stdout.splitlines()
        except subprocess.CalledProcessError as e:
            print(f"[ERROR] Failed to get changed files: {e}")
            sys.exit(1)

    def get_current_file(self, commit_hash, file_path):
        """
        Retrieve the content of the file at the commit (current version).
        """
        cmd = ["git", "show", f"{commit_hash}:{file_path}"]
        try:
            result = subprocess.run(cmd, cwd=self.repo_path, text=True, capture_output=True, check=True)
            return result.stdout
        except subprocess.CalledProcessError as e:
            print(f"[ERROR] Failed to get current version of {file_path}: {e}")
            return None

    def get_precommit_file(self, commit_hash, file_path):
        """
        Retrieve the content of the file at the parent of the commit (pre-commit version).
        """
        cmd = ["git", "show", f"{commit_hash}^:{file_path}"]
        try:
            result = subprocess.run(cmd, cwd=self.repo_path, text=True, capture_output=True, check=True)
            return result.stdout
        except subprocess.CalledProcessError as e:
            print(f"[ERROR] Failed to get precommit version of {file_path}: {e}")
            return None

    def find_function(self, source_code, function_name, class_name=None, filename=None):
        """
        Find a function definition in the given source code using two methods:
        1. Text-based parsing (suitable for many C files)
        2. Clang-based parsing (better for C++ files)
        
        The function first attempts to locate the function using text parsing.
        If that fails, it falls back to a Clang-based approach.
        
        Parameters:
            source_code (str): The source code to search.
            function_name (str): The name of the function to find.
            class_name (str, optional): If searching for a class member, provide the class name.
            filename (str, optional): Filename hint for Clang parsing (default is "temp.cpp").
        
        Returns:
            str or None: The full function definition if found; otherwise, None.
        """
        # --- Text-based parsing ---
        lines = source_code.split('\n')
        for i, line in enumerate(lines):
            if function_name in line and '(' in line and not line.strip().startswith('//'):
                words = line.strip().split()
                if function_name in words or f"{function_name}(" in line:
                    print(f"[DEBUG] Found potential function definition (text): {line.strip()}")
                    brace_count = 0
                    start_line = i
                    found_opening = False
                    # Try to adjust for multi-line definitions
                    while start_line > 0 and not lines[start_line - 1].strip().endswith(';'):
                        start_line -= 1
                        if lines[start_line].strip().startswith('/*') or lines[start_line].strip().startswith('*'):
                            continue
                        if lines[start_line].strip():
                            break
                    function_lines = []
                    for j in range(start_line, len(lines)):
                        current_line = lines[j]
                        function_lines.append(current_line)
                        for char in current_line:
                            if char == '{':
                                found_opening = True
                                brace_count += 1
                            elif char == '}':
                                brace_count -= 1
                        if found_opening and brace_count == 0:
                            return '\n'.join(function_lines)
        
        # --- Clang-based parsing ---
        import clang.cindex
        index = clang.cindex.Index.create()
        args = [
            "-x", "c++",
            "--std=c++11",
            "-fparse-all-comments",
            "-I/usr/include",
            "-I/usr/local/include",
            "-I.",
            "-DMOZILLA_INTERNAL_API",
            "-DNDEBUG",
            "-DTRIMMED"
        ]
        try:
            tu = index.parse(
                filename or "temp.cpp",
                args=args,
                unsaved_files=[(filename or "temp.cpp", source_code)],
                options=clang.cindex.TranslationUnit.PARSE_SKIP_FUNCTION_BODIES
            )
        except Exception as e:
            print(f"[ERROR] Clang failed to parse {filename}: {e}")
            return None

        if not tu:
            print("[ERROR] Failed to create translation unit")
            return None

        for diag in tu.diagnostics:
            if diag.severity >= clang.cindex.Diagnostic.Warning:
                severity = {2: "Warning", 3: "Error", 4: "Fatal"}.get(diag.severity, "Unknown")
                print(f"[{severity}] {diag.spelling}")

        # Use a simple line search in the source as a fallback after Clang parsing
        for line in source_code.split('\n'):
            if class_name:
                search_pattern = f"{class_name}::{function_name}"
            else:
                words = line.split()
                if function_name in words and '(' in line:
                    search_pattern = function_name
                else:
                    continue

            if search_pattern in line:
                print(f"[DEBUG] Found potential function definition (clang): {line.strip()}")
                start_idx = source_code.find(line)
                if start_idx != -1:
                    brace_count = 0
                    end_idx = start_idx
                    found_opening = False
                    for i in range(start_idx, len(source_code)):
                        if source_code[i] == '{':
                            found_opening = True
                            brace_count += 1
                        elif source_code[i] == '}':
                            brace_count -= 1
                            if found_opening and brace_count == 0:
                                end_idx = i + 1
                                break
                    if end_idx > start_idx:
                        return source_code[start_idx:end_idx]
        return None

    # def extract_functions(patch_content):
    #     """
    #     Extract all function names from a patch/diff file where changes (+ or -) occurred.
    #     Returns a list of tuples (function_name, class_name).
    #     """
    #     lines = patch_content.split('\n')
    #     current_function = None
    #     current_class = None
    #     in_function = False
    #     functions = []
        
    #     for i, line in enumerate(lines):
    #         stripped_line = line.strip()
            
    #         # Handle @@ context lines
    #         if line.startswith('@@'):
    #             in_function = False
    #             current_function = None
    #             current_class = None
                
    #             if '@@ ' in line:
    #                 context_part = line.split('@@ ')[-1].strip()
    #                 if '::' in context_part:
    #                     parts = context_part.split('::')
    #                     current_class = parts[0].strip().replace('@', '').strip()
    #                     current_function = parts[1].split('(')[0].strip()
    #                     in_function = True
    #                 else:
    #                     parts = context_part.split('(')
    #                     if len(parts) > 1:
    #                         func_parts = parts[0].split()
    #                         if func_parts:
    #                             current_function = func_parts[-1].strip()
    #                             current_class = None
    #                             in_function = True

    #         if not in_function:
    #             if '::' in stripped_line and '(' in stripped_line and not stripped_line.startswith(('+', '-', '//', '@')):
    #                 parts = stripped_line.split('::')
    #                 if len(parts) == 2:
    #                     current_class = parts[0].strip()
    #                     current_function = parts[1].split('(')[0].strip()
    #                     in_function = True
    #             elif '(' in stripped_line and not stripped_line.startswith(('+', '-', '//', '@', '}')):
    #                 parts = stripped_line.split('(')[0].strip().split()
    #                 if parts and not parts[0] in ['if', 'while', 'for', 'switch', 'return']:
    #                     current_function = parts[-1]
    #                     current_class = None
    #                     in_function = True
            
    #         if in_function and line.startswith(('+', '-')) and not line.startswith(('+++ ', '--- ')):
    #             if current_class and '@' in current_class:
    #                 current_class = current_class.split('@')[-1].strip()
    #             current = (current_function, current_class)
    #             if current not in functions:
    #                 functions.append(current)
    #     print(f"functions: {functions}")          
    #     return functions

    def extract_functions_from_patch(self, patch_content):
        """
        Extract file paths and function information from a diff.
        For each file, returns a dictionary with keys:
          - "functions": a dict mapping function names to a dict with keys:
              - "class": the class name (if any)
              - "added": list of added lines within that function
              - "deleted": list of deleted lines within that function
          - "added": file-level added lines (if any)
          - "deleted": file-level deleted lines (if any)
        """
        file_path_pattern = re.compile(r'^diff --git a/(.*?) b/')
        files_info = {}
        current_file_path = None
        current_function = None
        current_class = None
        current_added_block = []
        current_deleted_block = []
        lines = patch_content.split('\n')

        for line in lines:
            # Check for a file header line.
            file_match = file_path_pattern.search(line)
            if file_match:
                # Flush pending blocks before switching files.
                if current_file_path is not None:
                    if current_added_block:
                        if current_function:
                            files_info[current_file_path]["functions"][current_function]["added"].append('\n'.join(current_added_block))
                        else:
                            files_info[current_file_path].setdefault("added", []).append('\n'.join(current_added_block))
                        current_added_block = []
                    if current_deleted_block:
                        if current_function:
                            files_info[current_file_path]["functions"][current_function]["deleted"].append('\n'.join(current_deleted_block))
                        else:
                            files_info[current_file_path].setdefault("deleted", []).append('\n'.join(current_deleted_block))
                        current_deleted_block = []
                current_file_path = file_match.group(1).strip()
                if current_file_path not in files_info:
                    files_info[current_file_path] = {"functions": {}, "added": [], "deleted": []}
                current_function = None  # Reset function context on new file
                current_class = None
                continue

            # Check for hunk header lines (starting with @@).
            if line.startswith('@@'):
                # Flush pending blocks from previous function.
                if current_file_path is not None and current_added_block:
                    if current_function:
                        files_info[current_file_path]["functions"][current_function]["added"].append('\n'.join(current_added_block))
                    else:
                        files_info[current_file_path].setdefault("added", []).append('\n'.join(current_added_block))
                    current_added_block = []
                if current_file_path is not None and current_deleted_block:
                    if current_function:
                        files_info[current_file_path]["functions"][current_function]["deleted"].append('\n'.join(current_deleted_block))
                    else:
                        files_info[current_file_path].setdefault("deleted", []).append('\n'.join(current_deleted_block))
                    current_deleted_block = []
                # --- Use logic from extract_functions to update current_function and current_class ---
                # Reset context:
                current_function = None
                current_class = None
                in_function = False
                if "@@ " in line:
                    context_part = line.split("@@ ")[-1].strip()
                    if "::" in context_part:
                        parts = context_part.split("::")
                        current_class = parts[0].strip().replace("@", "").strip()
                        current_function = parts[1].split("(")[0].strip()
                        in_function = True
                    else:
                        parts = context_part.split("(")
                        if len(parts) > 1:
                            func_parts = parts[0].split()
                            if func_parts:
                                current_function = func_parts[-1].strip()
                                current_class = None
                                in_function = True
                # If a function was extracted, add its entry.
                if in_function and current_file_path:
                    if current_function not in files_info[current_file_path]["functions"]:
                        files_info[current_file_path]["functions"][current_function] = {
                            "class": current_class,
                            "added": [],
                            "deleted": []
                        }
                continue

            # Process added/deleted lines.
            if current_file_path:
                if line.startswith('+') and not line.startswith('+++'):
                    if current_deleted_block:
                        if current_function:
                            files_info[current_file_path]["functions"][current_function]["deleted"].append('\n'.join(current_deleted_block))
                        else:
                            files_info[current_file_path].setdefault("deleted", []).append('\n'.join(current_deleted_block))
                        current_deleted_block = []
                    current_added_block.append(line[1:].strip())
                elif line.startswith('-') and not line.startswith('---'):
                    if current_added_block:
                        if current_function:
                            files_info[current_file_path]["functions"][current_function]["added"].append('\n'.join(current_added_block))
                        else:
                            files_info[current_file_path].setdefault("added", []).append('\n'.join(current_added_block))
                        current_added_block = []
                    current_deleted_block.append(line[1:].strip())
                else:
                    if current_added_block:
                        if current_function:
                            files_info[current_file_path]["functions"][current_function]["added"].append('\n'.join(current_added_block))
                        else:
                            files_info[current_file_path].setdefault("added", []).append('\n'.join(current_added_block))
                        current_added_block = []
                    if current_deleted_block:
                        if current_function:
                            files_info[current_file_path]["functions"][current_function]["deleted"].append('\n'.join(current_deleted_block))
                        else:
                            files_info[current_file_path].setdefault("deleted", []).append('\n'.join(current_deleted_block))
                        current_deleted_block = []

        # Flush any remaining blocks.
        if current_file_path:
            if current_added_block:
                if current_function:
                    files_info[current_file_path]["functions"][current_function]["added"].append('\n'.join(current_added_block))
                else:
                    files_info[current_file_path].setdefault("added", []).append('\n'.join(current_added_block))
            if current_deleted_block:
                if current_function:
                    files_info[current_file_path]["functions"][current_function]["deleted"].append('\n'.join(current_deleted_block))
                else:
                    files_info[current_file_path].setdefault("deleted", []).append('\n'.join(current_deleted_block))

        # Cleanup: remove empty strings.
        for file_path, changes in files_info.items():
            if "added" in changes:
                changes["added"] = list(filter(None, changes["added"]))
            if "deleted" in changes:
                changes["deleted"] = list(filter(None, changes["deleted"]))
            for func_name, func_changes in changes["functions"].items():
                func_changes["added"] = list(filter(None, func_changes["added"]))
                func_changes["deleted"] = list(filter(None, func_changes["deleted"]))
        return files_info



    def is_change_within_function(self, vulnerable_function, patched_function, changes_deleted, changes_added):
        """
        Check if any of the deleted changes appear in the vulnerable function
        or if any of the added changes appear in the patched function.
        """
        vulnerable_lines = vulnerable_function.splitlines()
        patched_lines = patched_function.splitlines()
        
        # Check if deleted changes exist in the vulnerable function
        for change in changes_deleted:
            change = change.strip()  # Ensure no leading/trailing spaces
            for line in vulnerable_lines:
                if change in line.strip():
                    return True

        # Check if added changes exist in the patched function
        for change in changes_added:
            change = change.strip()
            for line in patched_lines:
                if change in line.strip():
                    return True

        return False

    def parase_patch_header(self, patch_text):
        """Parse the patch header to extract the number of files changed, added, and deleted lines."""
        added_lines = 0
        deleted_lines = 0
        files_changed = set()

        file_pattern = re.compile(r'^diff --git a/(.*?) b/(.*?)$', re.MULTILINE)
        matches = file_pattern.findall(patch_text)
        for match in matches:
            files_changed.add(match[0])

        sections = re.split(r'(?m)^diff --git', patch_text)
        for section in sections[1:]:
            lines = section.split('\n')
            for line in lines:
                if line.startswith('+') and not line.startswith('+++'):
                    added_lines += 1
                elif line.startswith('-') and not line.startswith('---'):
                    deleted_lines += 1

        return len(files_changed), added_lines, deleted_lines

    def extract_commit_description(self, commit_hash):
        """Extract the commit description using git log."""
        try:
            result = subprocess.run(
                ['git', '-C', self.repo_path, 'log', '--format=%B', '-n', '1', commit_hash],
                stdout=subprocess.PIPE,
                text=True,
                encoding='utf-8'
            )
            return result.stdout.strip()
        except subprocess.CalledProcessError as e:
            print(f"Error extracting description for commit {commit_hash}")
            print(e.output)
            return None

    def build_code_blocks(self, files_info, commit_hash):
        """Build the vulnerable and patched code blocks from the patch info."""
        vulnerable_code_block = ""
        patched_code_block = ""

        # Process file-level changes
        for file_path, file_changes in files_info.items():
            file_header_printed_vulnerable = False
            file_header_printed_patched = False

            # Process function-level changes
            functions_to_modify = []
            for function_name, changes in file_changes['functions'].items():
                if not function_name:
                    continue

                vulnerable_code = self.get_precommit_file(commit_hash, file_path)
                patched_code = self.get_current_file(commit_hash, file_path)

                vulnerable_function = self.extract_function(vulnerable_code, function_name)
                patched_function = self.extract_function(patched_code, function_name)

                # Use our new find_function method to extract the entire function if possible
                if vulnerable_function and patched_function:
                    if self.is_change_within_function(vulnerable_function, patched_function,
                                                      changes['deleted'], changes['added']):
                        if not file_header_printed_vulnerable:
                            vulnerable_code_block += f"// File path: {file_path}\n"
                            file_header_printed_vulnerable = True
                        if not file_header_printed_patched:
                            patched_code_block += f"// File path: {file_path}\n"
                            file_header_printed_patched = True
                        vulnerable_code_block += f"{vulnerable_function}\n"
                        patched_code_block += f"{patched_function}\n"
                    else:
                        # Look for function signatures in the added/deleted lines
                        pattern = r'\b([a-zA-Z_][a-zA-Z0-9_\* ]*\s+[a-zA-Z_][a-zA-Z0-9_]*)\s*\([^)]*\)'
                        added_function_signatures = re.findall(pattern, '\n'.join(changes['added']), re.MULTILINE)
                        deleted_function_signatures = re.findall(pattern, '\n'.join(changes['deleted']), re.MULTILINE)
                        if added_function_signatures or deleted_function_signatures:
                            new_function_name = (added_function_signatures[0] 
                                                 if added_function_signatures 
                                                 else deleted_function_signatures[0])
                            functions_to_modify.append((function_name, new_function_name))
                        else:
                            functions_to_modify.append((function_name, ""))
                else:
                    if changes.get('added'):
                        sigs = self.extract_function_signatures('\n'.join(changes['added']))
                        if sigs:
                            new_function_name = sigs[0]
                            functions_to_modify.append((function_name, new_function_name))
                            patched_function = self.find_function(patched_code, new_function_name)
                        else:
                            patched_function = '\n'.join(changes['added'])
                            functions_to_modify.append((function_name, ""))
                        if not file_header_printed_patched:
                            patched_code_block += f"// File path: {file_path}\n"
                            file_header_printed_patched = True
                        patched_code_block += f"{patched_function}\n"

                    if changes.get('deleted'):
                        sigs = self.extract_function_signatures('\n'.join(changes['deleted']))
                        if sigs:
                            new_function_name = sigs[0]
                            functions_to_modify.append((function_name, new_function_name))
                            vulnerable_function = self.find_function(vulnerable_code, new_function_name)
                        else:
                            vulnerable_function = '\n'.join(changes['deleted'])
                            functions_to_modify.append((function_name, ""))
                        if not file_header_printed_vulnerable:
                            vulnerable_code_block += f"// File path: {file_path}\n"
                            file_header_printed_vulnerable = True
                        vulnerable_code_block += f"{vulnerable_function}\n"

            # Process any functions that were marked for modification
            functions_to_modify = list(set(functions_to_modify))
            for function_name, new_function_name in functions_to_modify:
                if not function_name:
                    continue
                if new_function_name in file_changes['functions']:
                    combine_add = file_changes['functions'][function_name]['added'] + \
                                  file_changes['functions'][new_function_name]['added']
                    combine_del = file_changes['functions'][function_name]['deleted'] + \
                                  file_changes['functions'][new_function_name]['deleted']
                    file_changes['functions'][new_function_name] = {'added': combine_add, 'deleted': combine_del}
                    del file_changes['functions'][function_name]
                else:
                    original_value = file_changes['functions'][function_name]
                    del file_changes['functions'][function_name]
                    file_changes['functions'][new_function_name] = original_value

                if not new_function_name:
                    continue
                if new_function_name in vulnerable_code_block or new_function_name in patched_code_block:
                    continue
                vulnerable_code = self.get_precommit_file(commit_hash, file_path)
                patched_code = self.get_current_file(commit_hash, file_path)
                vulnerable_function = self.find_function(vulnerable_code, new_function_name)
                patched_function = self.find_function(patched_code, new_function_name)
                if vulnerable_function or patched_function:
                    if not file_header_printed_vulnerable:
                        vulnerable_code_block += f"// File path: {file_path}\n"
                        file_header_printed_vulnerable = True
                    if not file_header_printed_patched:
                        patched_code_block += f"// File path: {file_path}\n"
                        file_header_printed_patched = True
                    vulnerable_code_block += f"{vulnerable_function}\n"
                    patched_code_block += f"{patched_function}\n"

            # Handle file-level added/deleted lines if present
            if file_changes.get('added'):
                if not file_header_printed_patched:
                    patched_code_block += f"// File path: {file_path}\n"
                    file_header_printed_patched = True
                patched_code_block += f"{''.join(file_changes['added'])}\n"

            if file_changes.get('deleted'):
                if not file_header_printed_vulnerable:
                    vulnerable_code_block += f"// File path: {file_path}\n"
                    file_header_printed_vulnerable = True
                vulnerable_code_block += f"{''.join(file_changes['deleted'])}\n"

        return vulnerable_code_block, patched_code_block

def main():
    repo_path = r"/home/azibaeir/Research/Benchmarking/gecko-dev"
    commit_hash = "d3fc632669c98bc8a94c820be75455ca4b446cf7"

    git_interaction = GitInteraction(repo_path)
    
    # Retrieve the patch text from the remote URL.
    patch_text = git_interaction.get_patch(commit_hash)
    if not patch_text:
        print("Failed to retrieve patch.")
        sys.exit(1)
    
    # Get the list of changed files.
    changed_files = GitInteraction.get_changed_files(repo_path, commit_hash)
    print(f"Changed files: {changed_files}")
    
    # Extract functions (and file-level changes) from the patch.
    functions_in_patch = git_interaction.extract_functions_from_patch(patch_text)
    print(f"Functions in patch: {functions_in_patch}")
    # Build the files_info dictionary expected by build_code_blocks.
    files_info = {}
    for file_path in changed_files:
        files_info[file_path] = {
            "functions": {},
            "added": [],
            "deleted": []
        }
        if file_path in functions_in_patch:
            # Populate function-level changes from the patch extraction
            for func_name, func_data in functions_in_patch[file_path]["functions"].items():
                files_info[file_path]["functions"][func_name] = {"added": [], "deleted": []}
            # Also include file-level added/deleted lines if present
            files_info[file_path]["added"] = functions_in_patch[file_path].get("added", [])
            files_info[file_path]["deleted"] = functions_in_patch[file_path].get("deleted", [])
     
    # Build the vulnerable and patched code blocks using the constructed files_info.
    vulnerable_code_block, patched_code_block = git_interaction.build_code_blocks(files_info, commit_hash)

    print("Vulnerable code block:")
    print(vulnerable_code_block)
    print("\nPatched code block:")
    print(patched_code_block)

if __name__ == "__main__":
    main()


Changed files: ['netwerk/base/src/nsBaseChannel.cpp']
Functions in patch: {'netwerk/base/src/nsBaseChannel.cpp': {'functions': {}, 'added': ['PRBool doNotify = PR_TRUE;', 'else\ndoNotify = PR_FALSE;', 'if (doNotify) {'], 'deleted': ['if (NS_FAILED(mStatus)) {']}}
Vulnerable code block:
// File path: netwerk/base/src/nsBaseChannel.cpp
if (NS_FAILED(mStatus)) {


Patched code block:
// File path: netwerk/base/src/nsBaseChannel.cpp
PRBool doNotify = PR_TRUE;else
doNotify = PR_FALSE;if (doNotify) {



## extract function

In [30]:
#!/usr/bin/env python3
import os
import sys
import clang.cindex
import platform

# Configure libclang path based on OS
if platform.system() == "Darwin":
    clang.cindex.Config.set_library_file("/Applications/Xcode.app/Contents/Frameworks/libclang.dylib")
elif platform.system() == "Linux":
    possible_paths = [
        "/usr/lib/llvm-11/lib/libclang.so",
        "/usr/lib/libclang.so",
        "/usr/lib/llvm/lib/libclang.so"
    ]
    for path in possible_paths:
        if os.path.exists(path):
            clang.cindex.Config.set_library_file(path)
            break

def find_function_by_text(source_code, function_name):
    """
    Find a function using text parsing, more suitable for C files.
    """
    lines = source_code.split('\n')
    for i, line in enumerate(lines):
        # Look for the function definition
        if function_name in line and '(' in line and not line.strip().startswith('//'):
            # Verify it's a function definition
            words = line.strip().split()
            if function_name in words or f"{function_name}(" in line:
                print(f"Found potential function definition (text): {line.strip()}")
                
                # Find the opening brace
                brace_count = 0
                start_line = i
                found_opening = False
                
                # Look backwards for any function header lines
                while start_line > 0 and not lines[start_line-1].strip().endswith(';'):
                    start_line -= 1
                    if lines[start_line].strip().startswith('/*') or lines[start_line].strip().startswith('*'):
                        continue
                    if lines[start_line].strip():
                        break
                
                # Collect the function
                function_lines = []
                for j in range(start_line, len(lines)):
                    current_line = lines[j]
                    function_lines.append(current_line)
                    
                    # Count braces
                    for char in current_line:
                        if char == '{':
                            found_opening = True
                            brace_count += 1
                        elif char == '}':
                            brace_count -= 1
                    
                    # Check if we've found the end of the function
                    if found_opening and brace_count == 0:
                        return '\n'.join(function_lines)
    
    return None

def find_function_by_clang(source_code, function_name, class_name=None, filename=None):
    """
    Find a function using Clang parsing, better for C++ files.
    """
    index = clang.cindex.Index.create()
    
    # Base compilation arguments
    args = [
        "-x", "c++",
        "--std=c++11",
        "-fparse-all-comments",
        "-I/usr/include",
        "-I/usr/local/include",
        "-I.",
        "-DMOZILLA_INTERNAL_API",
        "-DNDEBUG",
        "-DTRIMMED"
    ]

    try:
        tu = index.parse(
            filename or "temp.cpp",
            args=args,
            unsaved_files=[(filename or "temp.cpp", source_code)],
            options=clang.cindex.TranslationUnit.PARSE_SKIP_FUNCTION_BODIES
        )
    except Exception as e:
        print(f"[ERROR] Clang failed to parse {filename}: {e}")
        return None

    if not tu:
        print("[ERROR] Failed to create translation unit")
        return None

    # Print diagnostics at severity level Warning and above
    for diag in tu.diagnostics:
        if diag.severity >= clang.cindex.Diagnostic.Warning:
            severity = {
                2: "Warning",
                3: "Error",
                4: "Fatal"
            }.get(diag.severity, "Unknown")
            print(f"[{severity}] {diag.spelling}")

    # First try to find the function with string matching
    for line in source_code.split('\n'):
        # Handle class member function
        if class_name:
            search_pattern = f"{class_name}::{function_name}"
        else:
            # For standalone functions, look for the function name at a word boundary
            words = line.split()
            if function_name in words and '(' in line:
                search_pattern = function_name
            else:
                continue
        
        if search_pattern in line:
            print(f"Found potential function definition (clang): {line.strip()}")
            start_idx = source_code.find(line)
            if start_idx != -1:
                # Try to capture the entire function
                brace_count = 0
                end_idx = start_idx
                found_opening = False
                
                for i in range(start_idx, len(source_code)):
                    if source_code[i] == '{':
                        found_opening = True
                        brace_count += 1
                    elif source_code[i] == '}':
                        brace_count -= 1
                        if found_opening and brace_count == 0:
                            end_idx = i + 1
                            break
                
                if end_idx > start_idx:
                    return source_code[start_idx:end_idx]
    
    return None

def find_function(source_code, function_name, class_name=None, filename=None):
    """
    Try both Clang and text-based parsing to find the function.
    """
    # Try text-based parsing first
    func_def = find_function_by_text(source_code, function_name)
    if func_def:
        print("Found function using text-based parsing")
        return func_def

    # If text-based parsing fails, try Clang
    func_def = find_function_by_clang(source_code, function_name, class_name, filename)
    if func_def:
        print("Found function using Clang parsing")
        return func_def

    return None

def main():
    test_cases = [

        {
            "filename": "b59073dc8fae65cd9dc81c0137b0f7a9911873e2_curr_nsJSEnvironment.cpp",
            "function_name": "ScriptErrorEvent",
            "class_name": "",
        }

    ]
    
    test_case = test_cases[0]
    
    try:
        with open(test_case["filename"], "r", encoding="utf-8") as f:
            source_code = f.read()
    except Exception as e:
        print(f"[ERROR] Could not read file {test_case['filename']}: {e}")
        sys.exit(1)

    func_def = find_function(
        source_code,
        test_case["function_name"],
        test_case.get("class_name"),
        test_case["filename"]
    )
    
    if func_def:
        print("\n========== Found Function ==========\n")
        print(func_def)
        print("\n==================================\n")
    else:
        function_desc = (f"{test_case.get('class_name')}::{test_case['function_name']}" 
                        if test_case.get('class_name') 
                        else test_case['function_name'])
        print(f"[INFO] Could not find function {function_desc} in {test_case['filename']}")
        print("\nFirst few lines of the file:")
        print("\n".join(source_code.split("\n")[:10]))

if __name__ == "__main__":
    main()

Found potential function definition (text): ScriptErrorEvent(nsIScriptGlobalObject* aScriptGlobal,
Found function using text-based parsing


public:
  ScriptErrorEvent(nsIScriptGlobalObject* aScriptGlobal,
                   PRUint32 aLineNr, PRUint32 aColumn, PRUint32 aFlags,
                   const nsAString& aErrorMsg,
                   const nsAString& aFileName,
                   const nsAString& aSourceLine,
                   PRBool aDispatchEvent)
  : mScriptGlobal(aScriptGlobal), mLineNr(aLineNr), mColumn(aColumn),
    mFlags(aFlags), mErrorMsg(aErrorMsg), mFileName(aFileName),
    mSourceLine(aSourceLine), mDispatchEvent(aDispatchEvent) {}




# extract function name

In [29]:
def extract_functions_from_patch(patch_content):
    """
    Extract all function names from a patch/diff file where changes (+ or -) occurred.
    Returns a list of tuples (function_name, class_name).
    """
    lines = patch_content.split('\n')
    current_function = None
    current_class = None
    in_function = False
    functions = []
    
    for i, line in enumerate(lines):
        stripped_line = line.strip()
        
        # Handle @@ context lines
        if line.startswith('@@'):
            # Reset function context at new diff chunk
            in_function = False
            current_function = None
            current_class = None

            if '@@ ' in line:
                context_part = line.split('@@ ')[-1].strip()
                # New: if context starts with "class ", extract the class declaration
                if context_part.startswith("class "):
                    parts = context_part.split()
                    if len(parts) >= 2:
                        current_function = parts[1].strip()  # Extract the class name as function name
                        current_class = None
                        in_function = True
                elif '::' in context_part:
                    parts = context_part.split('::')
                    current_class = parts[0].strip().replace('@', '').strip()
                    current_function = parts[1].split('(')[0].strip()
                    in_function = True
                else:
                    parts = context_part.split('(')
                    if len(parts) > 1:
                        func_parts = parts[0].split()
                        if func_parts:
                            current_function = func_parts[-1].strip()
                            current_class = None
                            in_function = True

        # Look for function declaration in code
        if not in_function:
            if '::' in stripped_line and '(' in stripped_line and not stripped_line.startswith(('+', '-', '//', '@')):
                # C++ member function
                parts = stripped_line.split('::')
                if len(parts) == 2:
                    current_class = parts[0].strip()
                    current_function = parts[1].split('(')[0].strip()
                    in_function = True
            elif '(' in stripped_line and not stripped_line.startswith(('+', '-', '//', '@', '}')):
                # C-style function
                parts = stripped_line.split('(')[0].strip().split()
                if parts and not parts[0] in ['if', 'while', 'for', 'switch', 'return']:
                    current_function = parts[-1]
                    current_class = None
                    in_function = True
        
        # If we're in a function and find a change line, add to our list if not already present
        if in_function and line.startswith(('+', '-')) and not line.startswith(('+++ ', '--- ')):
            # Clean up class name if it still contains @@ markers
            if current_class and '@' in current_class:
                current_class = current_class.split('@')[-1].strip()
            
            # Create tuple of current function context
            current = (current_function, current_class)
            
            # Only add if not already in our list
            if current not in functions:
                functions.append(current)
                
    return functions

# Test cases
test_patches = [
    """"""
]

# Test both cases
for i, patch in enumerate(test_patches, 1):
    functions = extract_functions_from_patch(patch)
    print(f"\nTest case {i}:")
    if functions:
        for func_name, class_name in functions:
            if class_name:
                print(f"Function name: {func_name}, Class name: {class_name}")
            else:
                print(f"Function name: {func_name}")
    else:
        print("No functions found with changes")

# Test cases
test_patches = [
    # Test case 1: Function with direct declaration
    """diff --git a/netwerk/base/src/nsBaseChannel.cpp b/netwerk/base/src/nsBaseChannel.cpp
@@ -253,16 +253,19 @@ void
nsBaseChannel::HandleAsyncRedirect(nsIChannel* newChannel)
{
  NS_ASSERTION(!mPump, "Shouldn't have gotten here");
+  PRBool doNotify = PR_TRUE;""",

    # Test case 2: Function in @@ context line
    """diff --git a/editor/libeditor/base/nsEditor.cpp b/editor/libeditor/base/nsEditor.cpp
@@ -3397,7 +3397,7 @@ nsEditor::FindNode(nsINode *aCurrentNode,
    return nullptr;
  }

-  nsIContent* candidate =
+  nsCOMPtr<nsIContent> candidate =""",

    # Test case 3: Function with multiple changes

"""From 7b1c513be6bdbc0552bea0c1d312507337f4e5cd Mon Sep 17 00:00:00 2001
From: Ehsan Akhgari <ehsan@mozilla.com>
Date: Tue, 20 Jul 2010 09:04:14 -0400
Subject: [PATCH] Bug 580151 - Part 1: Move the increment up in case the call
 to nsIEditor::GetSelection fails and we bail out early; r=roc

--HG--
extra : rebase_source : 249cf74c6a1700b230d946793819ff6611ebbb99
---
 editor/libeditor/text/nsTextEditRules.cpp | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/editor/libeditor/text/nsTextEditRules.cpp b/editor/libeditor/text/nsTextEditRules.cpp
index ee8e77d9717bb..56e2b655bc6db 100644
--- a/editor/libeditor/text/nsTextEditRules.cpp
+++ b/editor/libeditor/text/nsTextEditRules.cpp
@@ -221,6 +221,12 @@ nsTextEditRules::BeforeEdit(PRInt32 action, nsIEditor::EDirection aDirection)
   
   nsAutoLockRulesSniffing lockIt(this);
   mDidExplicitlySetInterline = PR_FALSE;
+  if (!mActionNesting)
+  {
+    // let rules remember the top level action
+    mTheAction = action;
+  }
+  mActionNesting++;
   
   // get the selection and cache the position before editing
   nsCOMPtr<nsISelection> selection;
@@ -230,12 +236,6 @@ nsTextEditRules::BeforeEdit(PRInt32 action, nsIEditor::EDirection aDirection)
   selection->GetAnchorNode(getter_AddRefs(mCachedSelectionNode));
   selection->GetAnchorOffset(&mCachedSelectionOffset);
 
-  if (!mActionNesting)
-  {
-    // let rules remember the top level action
-    mTheAction = action;
-  }
-  mActionNesting++;
   return NS_OK;
 }
 """,
 
 #  Test case 4: linux
 """From 8c34e2d63231d4bf4852bac8521883944d770fe3 Mon Sep 17 00:00:00 2001
From: Jens Axboe <jens.axboe@oracle.com>
Date: Tue, 17 Oct 2006 19:43:22 +0200
Subject: [PATCH] [PATCH] Remove SUID when splicing into an inode

Originally from Mark Fasheh <mark.fasheh@oracle.com>

generic_file_splice_write() does not remove S_ISUID or S_ISGID. This is
inconsistent with the way we generally write to files.

Signed-off-by: Mark Fasheh <mark.fasheh@oracle.com>
Signed-off-by: Jens Axboe <jens.axboe@oracle.com>
---
 fs/splice.c | 19 +++++++++++++++----
 1 file changed, 15 insertions(+), 4 deletions(-)

diff --git a/fs/splice.c b/fs/splice.c
index 68e20e65c6e114..49fb9f12993884 100644
--- a/fs/splice.c
+++ b/fs/splice.c
@@ -845,6 +845,10 @@ generic_file_splice_write_nolock(struct pipe_inode_info *pipe, struct file *out,
 	ssize_t ret;
 	int err;
 
+	err = remove_suid(out->f_dentry);
+	if (unlikely(err))
+		return err;
+
 	ret = __splice_from_pipe(pipe, out, ppos, len, flags, pipe_to_file);
 	if (ret > 0) {
 		*ppos += ret;
@@ -883,12 +887,21 @@ generic_file_splice_write(struct pipe_inode_info *pipe, struct file *out,
 			  loff_t *ppos, size_t len, unsigned int flags)
 {
 	struct address_space *mapping = out->f_mapping;
+	struct inode *inode = mapping->host;
 	ssize_t ret;
+	int err;
+
+	err = should_remove_suid(out->f_dentry);
+	if (unlikely(err)) {
+		mutex_lock(&inode->i_mutex);
+		err = __remove_suid(out->f_dentry, err);
+		mutex_unlock(&inode->i_mutex);
+		if (err)
+			return err;
+	}
 
 	ret = splice_from_pipe(pipe, out, ppos, len, flags, pipe_to_file);
 	if (ret > 0) {
-		struct inode *inode = mapping->host;
-
 		*ppos += ret;
 
 		/*
@@ -896,8 +909,6 @@ generic_file_splice_write(struct pipe_inode_info *pipe, struct file *out,
 		 * sync it.
 		 */
 		if (unlikely((out->f_flags & O_SYNC) || IS_SYNC(inode))) {
-			int err;
-
 			mutex_lock(&inode->i_mutex);
 			err = generic_osync_inode(inode, mapping,
 						  OSYNC_METADATA|OSYNC_DATA);

""",
# test case 5: linux
"""From 47d439e9fb8a81a90022cfa785bf1c36c4e2aff6 Mon Sep 17 00:00:00 2001
From: Eric Paris <eparis@redhat.com>
Date: Fri, 7 Aug 2009 14:53:57 -0400
Subject: [PATCH] security: define round_hint_to_min in !CONFIG_SECURITY

Fix the header files to define round_hint_to_min() and to define
mmap_min_addr_handler() in the !CONFIG_SECURITY case.

Built and tested with !CONFIG_SECURITY

Signed-off-by: Eric Paris <eparis@redhat.com>
Signed-off-by: James Morris <jmorris@namei.org>
---
 include/linux/security.h | 30 +++++++++++++++---------------
 1 file changed, 15 insertions(+), 15 deletions(-)

diff --git a/include/linux/security.h b/include/linux/security.h
index 7b431155e392ee..57ead99d259361 100644
--- a/include/linux/security.h
+++ b/include/linux/security.h
@@ -121,6 +121,21 @@ struct request_sock;
 #define LSM_UNSAFE_PTRACE	2
 #define LSM_UNSAFE_PTRACE_CAP	4
 
+/*
+ * If a hint addr is less than mmap_min_addr change hint to be as
+ * low as possible but still greater than mmap_min_addr
+ */
+static inline unsigned long round_hint_to_min(unsigned long hint)
+{
+	hint &= PAGE_MASK;
+	if (((void *)hint != NULL) &&
+	    (hint < mmap_min_addr))
+		return PAGE_ALIGN(mmap_min_addr);
+	return hint;
+}
+extern int mmap_min_addr_handler(struct ctl_table *table, int write, struct file *filp,
+				 void __user *buffer, size_t *lenp, loff_t *ppos);
+
 #ifdef CONFIG_SECURITY
 
 struct security_mnt_opts {
@@ -149,21 +164,6 @@ static inline void security_free_mnt_opts(struct security_mnt_opts *opts)
 	opts->num_mnt_opts = 0;
 }
 
-/*
- * If a hint addr is less than mmap_min_addr change hint to be as
- * low as possible but still greater than mmap_min_addr
- */
-static inline unsigned long round_hint_to_min(unsigned long hint)
-{
-	hint &= PAGE_MASK;
-	if (((void *)hint != NULL) &&
-	    (hint < mmap_min_addr))
-		return PAGE_ALIGN(mmap_min_addr);
-	return hint;
-}
-
-extern int mmap_min_addr_handler(struct ctl_table *table, int write, struct file *filp,
-				 void __user *buffer, size_t *lenp, loff_t *ppos);
 /**
  * struct security_operations - main security structure
  *
""",
"""From b59073dc8fae65cd9dc81c0137b0f7a9911873e2 Mon Sep 17 00:00:00 2001
From: Boris Zbarsky <bzbarsky@mit.edu>
Date: Tue, 8 Jun 2010 15:58:26 -0400
Subject: [PATCH] Bug 568564.  Suppress the script filename for cross-origin
 onerror events.  r=jst

---
 content/base/test/test_bug461735.html | 2 +-
 dom/base/nsJSEnvironment.cpp          | 4 ++++
 2 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/content/base/test/test_bug461735.html b/content/base/test/test_bug461735.html
index 61cb4cd647663..f1956998f9120 100644
--- a/content/base/test/test_bug461735.html
+++ b/content/base/test/test_bug461735.html
@@ -19,7 +19,7 @@
 <script type="application/javascript">
 window.onerror = function(message, uri, line) {
   is(message, "Script error.", "Should have empty error message");
-  is(uri, "http://example.com/tests/content/base/test/bug461735-post-redirect.js", "Unexpected error location URI");
+  is(uri, "", "Should have empty error location URI");
   is(line, 0, "Shouldn't have a line here");
 }
 </script>
diff --git a/dom/base/nsJSEnvironment.cpp b/dom/base/nsJSEnvironment.cpp
index 42537280dee3e..54180a3f076d5 100644
--- a/dom/base/nsJSEnvironment.cpp
+++ b/dom/base/nsJSEnvironment.cpp
@@ -475,6 +475,10 @@ class ScriptErrorEvent : public nsRunnable
             NS_WARNING("Not same origin error!");
             errorevent.errorMsg = xoriginMsg.get();
             errorevent.lineNr = 0;
+            // FIXME: once the principal of the script is not tied to
+            // the filename, we can stop using the post-redirect
+            // filename if we want and remove this line.
+            errorevent.fileName = nsnull;
           }
 
           nsEventDispatcher::Dispatch(win, presContext, &errorevent, nsnull,"""
 
]

# Test both cases
for i, patch in enumerate(test_patches, 1):
    functions = extract_functions_from_patch(patch)
    print(f"\nTest case {i}:")
    if functions:
        for func_name, class_name in functions:
            if class_name:
                print(f"Function name: {func_name}, Class name: {class_name}")
            else:
                print(f"Function name: {func_name}")
    else:
        print("No functions found with changes")


Test case 1:
No functions found with changes

Test case 1:
Function name: HandleAsyncRedirect, Class name: nsBaseChannel

Test case 2:
Function name: FindNode, Class name: nsEditor

Test case 3:
Function name: BeforeEdit, Class name: nsTextEditRules

Test case 4:
Function name: generic_file_splice_write
Function name: generic_file_splice_write_nolock

Test case 5:
Function name: round_hint_to_min
Function name: security_free_mnt_opts

Test case 6:
Function name: function
Function name: ScriptErrorEvent


test line inside function

In [None]:
def is_change_within_function(vulnerable_function, patched_function, changes_deleted, changes_added):
    vulnerable_lines = vulnerable_function.splitlines()
    patched_lines = patched_function.splitlines()
    
    # Check if deleted changes exist in the vulnerable function
    for change in changes_deleted:
        change = change.strip()  # Ensure no leading/trailing spaces
        for line in vulnerable_lines:
            if change in line.strip():  # Check if change exists in vulnerable function
                return True

    # Check if added changes exist in the patched function
    for change in changes_added:
        change = change.strip()
        for line in patched_lines:
            if change in line.strip():  # Check if change exists in patched function
                return True

    return False

# Example usage
vulnerable_function = """
void
nsBaseChannel::HandleAsyncRedirect(nsIChannel* newChannel)
{
  NS_ASSERTION(!mPump, "Shouldn't have gotten here");
  if (NS_SUCCEEDED(mStatus)) {
      nsresult rv = Redirect(newChannel, nsIChannelEventSink::REDIRECT_INTERNAL,
                             PR_TRUE);
      if (NS_FAILED(rv))
          Cancel(rv);
  }

  mWaitingOnAsyncRedirect = PR_FALSE;

  if (NS_FAILED(mStatus)) {
    // Notify our consumer ourselves
    mListener->OnStartRequest(this, mListenerContext);
    mListener->OnStopRequest(this, mListenerContext, mStatus);
    mListener = nsnull;
    mListenerContext = nsnull;
  }

  if (mLoadGroup)
    mLoadGroup->RemoveRequest(this, nsnull, mStatus);

  // Drop notification callbacks to prevent cycles.
  mCallbacks = nsnull;
  CallbacksChanged();
}
"""

patched_function = """void
nsBaseChannel::HandleAsyncRedirect(nsIChannel* newChannel)
{
  NS_ASSERTION(!mPump, "Shouldn't have gotten here");
  PRBool doNotify = PR_TRUE;
  if (NS_SUCCEEDED(mStatus)) {
      nsresult rv = Redirect(newChannel, nsIChannelEventSink::REDIRECT_INTERNAL,
                             PR_TRUE);
      if (NS_FAILED(rv))
          Cancel(rv);
      else
          doNotify = PR_FALSE;
  }

  mWaitingOnAsyncRedirect = PR_FALSE;

  if (doNotify) {
    // Notify our consumer ourselves
    mListener->OnStartRequest(this, mListenerContext);
    mListener->OnStopRequest(this, mListenerContext, mStatus);
    mListener = nsnull;
    mListenerContext = nsnull;
  }

  if (mLoadGroup)
    mLoadGroup->RemoveRequest(this, nsnull, mStatus);

  // Drop notification callbacks to prevent cycles.
  mCallbacks = nsnull;
  CallbacksChanged();
}"""

changes_deleted = ["if (NS_FAILED(mStatus)) {"]
changes_added = ["PRBool doNotify = PR_TRUE;", "else \n doNotify = PR_FALSE;", "if (doNotify) {"]

print(is_change_within_function(vulnerable_function, patched_function, changes_deleted, changes_added))


True
