In [2]:
# | default_exp tools.fs_write

In [None]:
# | export

import difflib
import logging
import os
import re
from enum import Enum
from pathlib import Path
from typing import Optional, List, Any
from typing import Dict, Any

from pydantic import BaseModel, Field, ValidationInfo, field_validator, model_validator

from agentic.tools.base import BaseTool, ToolMetadata, ToolCategory

# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
# NOTE(review): this configures the *root* logger at import time, which
# affects every application importing this module; consider moving it to an
# entry point / notebook cell instead.
logging.basicConfig(level="INFO")

class WriteCommand(str, Enum):
    """Top-level mode of an fs_write request.

    Members compare equal to their plain string values because the enum
    mixes in ``str``.
    """

    CREATE = "create"  # write file_text to the path, replacing any content
    EDIT = "edit"  # modify an existing file via an EditOperationType

class EditOperationType(str, Enum):
    """Sub-operation applied when the command is ``edit``.

    Members compare equal to their plain string values because the enum
    mixes in ``str``.
    """

    REPLACE = "replace"  # substitute old_str with new_str
    INSERT = "insert"  # insert file_text at insert_line (1-based)
    APPEND = "append"  # add file_text after the existing content
    PREPEND = "prepend"  # add file_text before the existing content
    DELETE_LINES = "delete_lines"  # remove start_line..end_line (1-based, inclusive)

class FsWriteOperation(BaseModel):
    """One fs_write request: a ``create`` or an ``edit`` of a single file.

    ``path`` is checked by a field validator. The cross-field rules (which
    optional fields each command / edit sub-operation requires, and the bounds
    on line numbers) are checked by a model validator that runs after every
    field has been parsed, so they also apply to fields the caller omitted.
    """

    command: WriteCommand = Field(..., description="Operation mode: 'create' to create/overwrite a file, 'edit' for modifying a file (replace, insert, append, prepend, delete_lines).")
    path: str = Field(..., description="File path, e.g., '/project/src/app.py'. Must be a file, not a directory.")
    file_text: Optional[str] = Field(None, description="Content for create, insert, append, prepend operations.")
    operation_type: Optional[EditOperationType] = Field(None, description="Sub-operation for edit mode: replace, insert, append, prepend, delete_lines.")
    old_str: Optional[str] = Field(None, description="Text to replace in edit/replace (regex or substring).")
    new_str: Optional[str] = Field(None, description="Replacement text in edit/replace.")
    insert_line: Optional[int] = Field(None, description="Line number (1-based) for edit/insert.")
    start_line: Optional[int] = Field(None, description="Start line (1-based) for edit/delete_lines.")
    end_line: Optional[int] = Field(None, description="End line (1-based) for edit/delete_lines.")
    regex_mode: Optional[bool] = Field(True, description="True (default) treats old_str as regex in edit/replace; False for substring matching.")
    respect_gitignore: Optional[bool] = Field(True, description="Check if file is ignored by .gitignore (default True; warns but allows edit if False).")
    blocklist_pattern: str = Field("*__pycache__/*|*.lock|*.o|*.pyc|*.class", description="Patterns to block, e.g., cache/lock files.")
    show_diff: Optional[bool] = Field(True, description="Show unified diff in Git style before writing.")
    trusted: Optional[bool] = Field(False, description="Bypass confirmation prompts if True.")
    summary: Optional[str] = Field(None, description="Brief description of the change.")

    @field_validator("path")
    @classmethod
    def validate_path(cls, value: str) -> str:
        """Require ``path`` to name a non-directory whose parent directory exists.

        Returns:
            The path normalised through ``pathlib.Path``.

        Raises:
            ValueError: as ``"Invalid path '<value>': <reason>"``.
        """
        try:
            if not value:
                # An empty path cannot name a file; say so directly rather than
                # resolving it to the (directory) cwd and rejecting that.
                raise ValueError("path must not be empty")
            path_obj = Path(value)
            if path_obj.is_dir():
                raise ValueError(f"Path {value} is a directory; must be a file")
            if not path_obj.parent.exists():
                raise ValueError(f"Parent directory of {value} does not exist")
            return str(path_obj)
        except Exception as e:
            raise ValueError(f"Invalid path '{value}': {str(e)}") from e

    @model_validator(mode="after")
    def validate_fields(self) -> "FsWriteOperation":
        """Enforce the per-command required fields and line-number constraints.

        This is a model-level ``after`` validator rather than a field
        validator: pydantic does not run field validators for omitted fields
        that fall back to their defaults, so "required" checks on optional
        fields would otherwise never fire. It also sees ``start_line`` and
        ``end_line`` together, which a per-field validator cannot.

        Raises:
            ValueError: if a required field is missing or a line number is out
                of range (pydantic reports it as a ``ValidationError``).
        """
        if self.command == WriteCommand.CREATE:
            if self.file_text is None:
                raise ValueError("file_text is required for create command")
            return self

        if self.command != WriteCommand.EDIT:
            return self

        op = self.operation_type
        if op is None:
            raise ValueError("operation_type is required for edit command")

        if op == EditOperationType.REPLACE:
            for name in ("old_str", "new_str"):
                if getattr(self, name) is None:
                    raise ValueError(f"{name} is required for edit/replace")
        elif op == EditOperationType.INSERT:
            for name in ("file_text", "insert_line"):
                if getattr(self, name) is None:
                    raise ValueError(f"{name} is required for edit/insert")
            if self.insert_line < 1:
                raise ValueError("insert_line must be >= 1")
        elif op in (EditOperationType.APPEND, EditOperationType.PREPEND):
            if self.file_text is None:
                raise ValueError("file_text is required for edit/append or edit/prepend")
        elif op == EditOperationType.DELETE_LINES:
            for name in ("start_line", "end_line"):
                value = getattr(self, name)
                if value is None:
                    raise ValueError(f"{name} is required for edit/delete_lines")
                if value < 1:
                    raise ValueError(f"{name} must be >= 1")
            if self.end_line < self.start_line:
                raise ValueError("end_line must be >= start_line")
        return self

class FsWriteParams(BaseModel):
    """Batch wrapper holding an ordered list of ``FsWriteOperation`` requests."""

    operations: List[FsWriteOperation] = Field(
        ...,
        description="List of write operations to perform",
    )



class FsWriteTool(BaseTool):
    """Create or edit a single UTF-8 text file per call.

    ``command="create"`` writes ``file_text`` to ``path``, replacing any
    existing text content. ``command="edit"`` applies one ``operation_type``
    sub-operation (``replace``, ``insert``, ``append``, ``prepend``,
    ``delete_lines``) to the existing content. Public entry points return a
    ``{"success", "data", "error", ...}`` dict instead of raising, so results
    can be chained by an agent.

    NOTE(review): ``respect_gitignore`` and ``blocklist_pattern`` are declared
    on ``FsWriteOperation`` but are not read anywhere in this class.
    """

    def __init__(self):
        metadata = ToolMetadata(
            name="fs_write",
            description="Advanced filesystem writing with Git integration and safety checks",
            category=ToolCategory.FILESYSTEM,
        )
        super().__init__(metadata)

    def get_parameters_schema(self) -> Dict[str, Any]:
        """Return the JSON Schema of the parameters ``execute`` accepts."""
        return {
            "type": "object",
            "properties": {
                "command": {
                    "type": "string",
                    "enum": ["create", "edit"],
                    "description": "Operation: create (new file) or edit (modify existing file)"
                },
                "path": {"type": "string", "description": "File path"},
                "file_text": {"type": "string", "description": "Content for create, insert, append and prepend operations"},
                "operation_type": {
                    "type": "string",
                    "enum": ["replace", "insert", "append", "prepend", "delete_lines"],
                    "description": "Edit operation type (required when command=edit)"
                },
                "old_str": {"type": "string", "description": "Text to replace (for replace operation)"},
                "new_str": {"type": "string", "description": "Replacement text (for replace operation)"},
                "regex_mode": {"type": "boolean", "description": "Treat old_str as a regular expression and new_str as a re.sub template (default: true); false for literal substring replacement"},
                "insert_line": {"type": "integer", "description": "Line number to insert at (1-based)"},
                "start_line": {"type": "integer", "description": "Start line for delete_lines"},
                "end_line": {"type": "integer", "description": "End line for delete_lines"},
                "show_diff": {"type": "boolean", "description": "Log a unified diff of the change (default: true)"},
                "trusted": {"type": "boolean", "description": "Skip confirmation prompt (default: false)"},
                "auto_approve": {"type": "boolean", "description": "Alias of trusted: skip confirmation prompt (default: false)"},
                "summary": {"type": "string", "description": "Brief description of the change"}
            },
            "required": ["command", "path"]
        }

    def _generate_diff(self, original: str, new: str, filepath: str) -> str:
        """Render an ANSI-coloured, line-numbered unified diff of two texts.

        Args:
            original: Text before the change.
            new: Text after the change.
            filepath: Path shown in the ``a/`` / ``b/`` header lines.

        Returns:
            The coloured diff followed by a ``+added -removed`` summary line,
            ``"No changes in <filepath>"`` if the texts are identical, or an
            error message if rendering fails.
        """
        try:
            diff_lines = list(difflib.unified_diff(
                original.splitlines(keepends=True),
                new.splitlines(keepends=True),
                fromfile=f"a/{filepath}",
                tofile=f"b/{filepath}",
                lineterm=""
            ))
            if not diff_lines:
                return f"No changes in {filepath}"

            hunk_header = re.compile(r"^@@ -(\d+)(?:,\d+)? \+(\d+)(?:,\d+)? @@")
            rendered = []
            old_no = new_no = 0
            added = removed = 0
            in_hunk = False
            for raw in diff_lines:
                # Content lines keep their own terminator (keepends=True); strip
                # it because the rendered lines are joined with "\n" below.
                line = raw.rstrip("\r\n")
                header = hunk_header.match(line)
                if header:
                    # Each hunk restarts numbering at the positions it names.
                    in_hunk = True
                    old_no, new_no = int(header.group(1)), int(header.group(2))
                    rendered.append(f"\033[36m{line}\033[0m")  # Cyan
                elif not in_hunk:
                    # The "--- a/..." / "+++ b/..." file header precedes the
                    # first hunk; content lines that merely start with "--" or
                    # "++" are handled below.
                    rendered.append(f"\033[1m{line}\033[0m")  # Bold
                elif line.startswith("-"):
                    rendered.append(f"\033[31m- {old_no:6}: {line[1:]}\033[0m")  # Red
                    old_no += 1
                    removed += 1
                elif line.startswith("+"):
                    rendered.append(f"\033[32m+ {new_no:6}: {line[1:]}\033[0m")  # Green
                    new_no += 1
                    added += 1
                else:
                    rendered.append(f"  {old_no:6}: {line[1:]}")
                    old_no += 1
                    new_no += 1

            rendered.append(f"\n\033[32mSummary: +{added} -{removed} lines\033[0m")
            return "\n".join(rendered)
        except Exception as e:
            logger.error(f"Diff generation failed: {e}")
            return f"Error generating diff for {filepath}: {e}"

    @staticmethod
    def _validate_params(command: Any, operation_type: Any, params: Dict[str, Any]) -> Optional[str]:
        """Check that *params* carry everything *command*/*operation_type* need.

        Returns:
            A message describing the first problem found, or ``None`` when the
            parameters are usable.
        """
        if command not in ("create", "edit"):
            return f"Invalid command: {command}"
        if command == "create":
            # An empty string is a legitimate (empty) file; only absence is an error.
            if params.get("file_text") is None:
                return "file_text is required for create command"
            return None

        # Fields each edit sub-operation cannot run without.
        required = {
            "replace": ("old_str", "new_str"),
            "insert": ("file_text", "insert_line"),
            "append": ("file_text",),
            "prepend": ("file_text",),
            "delete_lines": ("start_line", "end_line"),
        }
        if not isinstance(operation_type, str) or operation_type not in required:
            return f"Invalid operation_type: {operation_type}"
        needed = required[operation_type]
        for name in needed:
            if params.get(name) is None:
                return f"{name} is required for edit/{operation_type}"
        if operation_type == "replace" and params["old_str"] == "":
            # An empty pattern matches between every character.
            return "old_str must not be empty for edit/replace"
        for name in ("insert_line", "start_line", "end_line"):
            if name in needed:
                value = params[name]
                if not isinstance(value, int) or isinstance(value, bool) or value < 1:
                    return f"{name} must be an integer >= 1"
        if operation_type == "delete_lines" and params["end_line"] < params["start_line"]:
            return "end_line must be >= start_line"
        return None

    def _apply_operation(self, file_path: str, params: Dict[str, Any]) -> Dict:
        """Compute the new content of *file_path* for one operation (no write).

        Args:
            file_path: Target file.
            params: Raw call parameters (``command``, ``operation_type`` and the
                fields that operation needs).

        Returns:
            ``{"content", "result", "original_content"}`` on success, or
            ``{"error": <message>}`` on any validation, read or apply failure.
        """
        command = params.get("command")
        operation_type = params.get("operation_type")
        # Accept WriteCommand / EditOperationType members as well as strings.
        if isinstance(command, Enum):
            command = command.value
        if isinstance(operation_type, Enum):
            operation_type = operation_type.value

        error = self._validate_params(command, operation_type, params)
        if error:
            return {"error": error}

        try:
            path_obj = Path(file_path)
            if path_obj.is_dir():
                return {"error": f"Path {file_path} is a directory; must be a file"}
            if not path_obj.parent.exists():
                return {"error": f"Parent directory of {file_path} does not exist"}
            # A NUL byte in the first KiB is treated as a binary-file marker.
            with open(file_path, 'rb') as f:
                if b'\0' in f.read(1024):
                    return {"error": f"Cannot edit {file_path}: Binary file"}
        except OSError as e:
            # A missing/unreadable file only matters when its content is needed.
            if command != "create":
                return {"error": f"Failed to read {file_path}: {str(e)}"}
        except Exception as e:
            return {"error": f"Unexpected error checking file: {str(e)}"}

        original_content = ""
        if command == "edit":
            try:
                # Strict decoding: dropping undecodable bytes and then writing
                # the text back would silently corrupt non-UTF-8 files.
                with open(file_path, 'r', encoding='utf-8') as f:
                    original_content = f.read()
            except (OSError, UnicodeDecodeError) as e:
                return {"error": f"Failed to read {file_path}: {str(e)}"}

        content = original_content
        result = {
            "operation": command,
            "status": None,
            "error": None,
            "summary": params.get("summary")
        }
        try:
            if command == "create":
                content = params["file_text"]
                result["status"] = "created"
            elif operation_type == "replace":
                old_str, new_str = params["old_str"], params["new_str"]
                logger.debug(f"Applying replace: pattern='{old_str}', file={file_path}")
                if params.get("regex_mode", True):
                    # new_str is a re.sub template: backslash escapes and group
                    # references such as \1 in it are expanded.
                    content = re.sub(old_str, new_str, content, flags=re.MULTILINE)
                else:
                    content = content.replace(old_str, new_str)
                if content == original_content:
                    result["status"] = "no changes"
                    result["error"] = f"No matches found for pattern '{old_str}'"
                else:
                    result["status"] = "replace"
            elif operation_type == "insert":
                lines = content.splitlines(keepends=True)
                # Positions past the end are clamped to "append after last line".
                insert_idx = min(params["insert_line"] - 1, len(lines))
                text = params["file_text"]
                if text:
                    if insert_idx < len(lines):
                        # Keep the inserted text from merging with the line it
                        # is inserted before.
                        if not text.endswith(("\n", "\r")):
                            text += "\n"
                    elif lines and not lines[-1].endswith(("\n", "\r")):
                        # Inserting after a final line that has no terminator.
                        lines[-1] += "\n"
                lines.insert(insert_idx, text)
                content = ''.join(lines)
                result["status"] = "insert"
            elif operation_type == "append":
                content += params["file_text"]
                result["status"] = "append"
            elif operation_type == "prepend":
                content = params["file_text"] + content
                result["status"] = "prepend"
            elif operation_type == "delete_lines":
                lines = content.splitlines(keepends=True)
                start_idx = params["start_line"] - 1
                end_idx = min(len(lines), params["end_line"])
                if start_idx >= len(lines):
                    raise ValueError("Invalid line range for delete")
                content = ''.join(lines[:start_idx] + lines[end_idx:])
                result["status"] = "delete_lines"
            return {"content": content, "result": result, "original_content": original_content}
        except re.error as e:
            return {"error": f"Invalid regex pattern '{params.get('old_str', '')}': {str(e)}"}
        except ValueError as e:
            return {"error": f"Validation error during operation: {str(e)}"}
        except Exception as e:
            logger.error(f"Operation failed for {file_path}: {e}")
            return {"error": f"Unexpected error applying operation: {str(e)}"}

    def execute(self, **kwargs) -> Dict[str, Any]:
        """Run one fs_write call.

        Parameters may be given directly as keyword arguments or wrapped in a
        single ``params`` dict keyword. Unexpected exceptions become an
        ``ExecuteError`` result instead of propagating.
        """
        try:
            params = kwargs['params'] if isinstance(kwargs.get('params'), dict) else kwargs
            return self._execute_internal(params)
        except Exception as e:
            logger.error(f"Execute failure: {e}")
            return {"success": False, "data": [], "error": {"type": "ExecuteError", "message": str(e)}}

    def _execute_internal(self, params: Dict[str, Any]) -> Dict:
        """Validate and apply one operation, then write it unless nothing changed.

        Returns:
            The tool result dict; ``data`` holds one per-file record with
            ``path``, ``status``, ``size`` (UTF-8 byte length) and
            ``operation`` (the per-operation result from ``_apply_operation``).
        """
        if not isinstance(params, dict):
            return {
                "success": False,
                "data": [],
                "error": {"type": "InvalidInput", "message": "Params must be a dictionary"}
            }
        if "command" not in params or "path" not in params:
            return {
                "success": False,
                "data": [],
                "error": {"type": "InvalidInput", "message": "command and path are required"}
            }
        file_path = params["path"]
        result = {"command": params["command"], "path": file_path, "data": None, "error": None}

        apply_result = self._apply_operation(file_path, params)
        if "error" in apply_result:
            result["error"] = {"type": "EditError", "message": apply_result["error"]}
            return {"success": False, "data": [result], "error": result["error"]}

        new_content = apply_result["content"]
        original_content = apply_result["original_content"]
        operation = apply_result["result"]
        if params.get("show_diff", True):
            logger.info(self._generate_diff(original_content, new_content, file_path))

        if operation["status"] == "no changes" or operation.get("error"):
            result["data"] = {
                "path": file_path,
                "status": operation["status"],
                "size": len(original_content.encode('utf-8')),
                "operation": operation
            }
            return {
                "success": True,
                "data": [result],
                "error": None,
                "metadata": {
                    "processed_files": 0,
                    "description": "No changes applied; no files written."
                }
            }

        # "auto_approve" is accepted as an alias of "trusted".
        if not (params.get("trusted") or params.get("auto_approve")):
            logger.warning("Trusted is False; in production, this would prompt for confirmation. Assuming approval for refactoring demo.")
        try:
            Path(file_path).parent.mkdir(parents=True, exist_ok=True)
            with open(file_path, 'w', encoding='utf-8') as f:
                f.write(new_content)
            result["data"] = {
                "path": file_path,
                "status": "edited",
                "size": len(new_content.encode('utf-8')),
                "operation": operation
            }
        except OSError as e:
            result["error"] = {"type": "WriteError", "message": f"Failed to write {file_path}: {str(e)}"}
            logger.error(result["error"]["message"])
        except Exception as e:
            result["error"] = {"type": "UnexpectedError", "message": str(e)}
            logger.error(str(e))
        return {
            "success": result["error"] is None,
            "data": [result],
            "error": result["error"],
            "metadata": {
                "processed_files": 1 if result["error"] is None else 0,
                "description": "Results ready for chaining: Use 'data[].path' for execute_bash (e.g., git add)."
            }
        }
