In [1]:
!pip install --upgrade --quiet langchain-community langchain-openai paramiko pydantic

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.5/2.5 MB[0m [31m47.4 MB/s[0m eta [36m0:00:00[0m:00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m55.3/55.3 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m227.3/227.3 kB[0m [31m11.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m278.6/278.6 kB[0m [31m11.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m37.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m414.3/414.3 kB[0m [31m21.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m472.3/472.3 kB[0m [31m21.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m856.7/856.7 kB[0m [31m33.6 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependenc

In [3]:
import paramiko
import io
from kaggle_secrets import UserSecretsClient

ssh_comment = UserSecretsClient().get_secret("SSH_COMMENT")

public_key = UserSecretsClient().get_secret("SSH_PUBLIC_KEY")
encrypted_private_key = UserSecretsClient().get_secret("SSH_PRIVATE_KEY_ENCRYPTED")
ssh_hostname = UserSecretsClient().get_secret("SSH_HOSTNAME")
ssh_username = UserSecretsClient().get_secret("SSH_USERNAME")
#ssh_jump_gateway = UserSecretsClient().get_secret("SSH_JUMP_GATEWAY")
#ssh_jump_dest = UserSecretsClient().get_secret("SSH_JUMP_DEST")

In [16]:
from langchain_core.tools import BaseTool
from langchain.agents import AgentExecutor, LLMSingleActionAgent
from langchain.memory import ConversationBufferMemory
from langchain.chat_models import ChatOpenAI
from langchain.prompts import StringPromptTemplate, PromptTemplate
from typing import List, Dict, Optional
import asyncio
import paramiko
import json
from typing import ClassVar
from pydantic import BaseModel, Field, PrivateAttr

class SSHConnection:
    def __init__(self):
        self.client = paramiko.SSHClient()
        self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy())

        key_string = encrypted_private_key.replace("\\n", "\n")
        private_key = paramiko.RSAKey.from_private_key(
            io.StringIO(key_string),
            password = ssh_comment)
        self.client.connect(
            hostname = ssh_hostname,
            username = ssh_username,
            pkey = private_key
        )

    async def execute_interactive(self, command: str, timeout: int = 30) -> asyncio.Queue:
        """
        Execute an interactive command and return a queue for streaming output
        """
        output_queue = asyncio.Queue()
        
        # Start interactive shell session
        shell = self.client.invoke_shell()
        
        async def read_output():
            while True:
                if shell.recv_ready():
                    output = shell.recv(4096).decode('utf-8')
                    await output_queue.put(output)
                await asyncio.sleep(0.1)
                
        # Start background task to read output
        asyncio.create_task(read_output())
        
        # Send command
        shell.send(command + '\n')
        
        return output_queue

class LinuxCommandTool(BaseTool):
    name: str = "linux_command"
    description: str = "Execute Linux commands and handle interactive output"

    _ssh: SSHConnection = PrivateAttr()
    
    def __init__(self, ssh_connection: SSHConnection):
        super().__init__()
        self._ssh = ssh_connection
    
    async def _run(self, command: str) -> str:
        output_queue = await self.ssh.execute_interactive(command)
        
        # Collect output for a reasonable time
        full_output = []
        try:
            while True:
                output = await asyncio.wait_for(output_queue.get(), timeout=1.0)
                full_output.append(output)
                
                # Check if command has completed (can be customized based on needs)
                if output.strip().endswith('$ '):  # Basic prompt detection
                    break
        except asyncio.TimeoutError:
            pass
            
        return ''.join(full_output)

class CommandState(BaseModel):
    """Track the state of command execution"""
    command: str = Field(description="The command being executed")
    output: str = Field(description="Current accumulated output")
    status: str = Field(description="Current status: running/completed/error")
    next_action: Optional[str] = Field(description="Next action to take based on output")

class InteractiveAgentPrompt(StringPromptTemplate):
    input_variables: List[str] = Field(template="""You are an expert Linux system administrator.
    Current command state: {current_state}
    
    Based on the command output, determine the next action:
    1. If the command is complete, analyze the output and summarize the results
    2. If the command requires interaction, provide the next input
    3. If there's an error, suggest how to resolve it
    
    Previous actions: {memory}
    
    Your response should be in JSON format:
    {{"action": "complete/interact/error",
      "response": "your analysis or next input",
      "reasoning": "your reasoning"}}
    """)
    
    def format(self, **kwargs) -> str:
        return self.template.format(**kwargs)

class InteractiveAgent:
    def __init__(self, llm, tools: List[BaseTool]):
        self.llm = llm
        self.tools = tools
        self.memory = ConversationBufferMemory()
        self.prompt = PromptTemplate.from_template("""You are an expert Linux system administrator.
    Current command state: {current_state}
    
    Based on the command output, determine the next action:
    1. If the command is complete, analyze the output and summarize the results
    2. If the command requires interaction, provide the next input
    3. If there's an error, suggest how to resolve it
    
    Previous actions: {memory}
    
    Your response should be in JSON format:
    {{"action": "complete/interact/error",
      "response": "your analysis or next input",
      "reasoning": "your reasoning"}}
    """)
        
    async def run(self, command: str):
        state = CommandState(
            command=command,
            output="",
            status="running",
            next_action=None
        )
        
        while state.status == "running":
            # Get next action from LLM
            llm_response = await self.llm.apredict(
                self.prompt.format(
                    current_state=state.dict(),
                    memory=self.memory.buffer
                )
            )
            
            action_data = json.loads(llm_response)
            
            if action_data["action"] == "interact":
                # Send next input to command
                for tool in self.tools:
                    if tool.name == "linux_command":
                        new_output = await tool.run(action_data["response"])
                        state.output += new_output
                        
            elif action_data["action"] == "complete":
                state.status = "completed"
                return action_data["response"]
                
            elif action_data["action"] == "error":
                state.status = "error"
                return action_data["response"]
            
            # Update memory
            self.memory.save_context(
                {"input": state.command},
                {"output": action_data["reasoning"]}
            )

# Example usage
async def main():
    # Initialize SSH connection
    ssh = SSHConnection()
    
    # Create tools
    tools = [LinuxCommandTool(ssh)]
    
    # Initialize LLM
    llm = ChatOpenAI(temperature=0, openai_api_key=UserSecretsClient().get_secret("OPENAI_API_KEY"))
    
    # Create agent
    agent = InteractiveAgent(llm, tools)
    
    # Run interactive command
    result = await agent.run("sudo apt update")
    print(f"Final result: {result}")

if __name__ == "__main__":
    await main()

<ipython-input-16-081f5c4253ff>:140: PydanticDeprecatedSince20: The `dict` method is deprecated; use `model_dump` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.11/migration/
  current_state=state.dict(),
  llm_response = await self.llm.apredict(


Final result: The 'sudo apt update' command has completed successfully.
