In [7]:
!pip install --upgrade --quiet langchain-community paramiko pydantic

In [8]:
import paramiko
import io
from kaggle_secrets import UserSecretsClient

# Read once at notebook start: it is the default key comment for
# generate_rsa_key and the passphrase used for the private key below.
ssh_comment = (
    UserSecretsClient()
    .get_secret("SSH_COMMENT")
)

def generate_rsa_key(bits=4096, comment=ssh_comment, passphrase=None):
    """Generate a new RSA key pair.

    Args:
        bits: key size in bits.
        comment: text appended to the OpenSSH public-key line. It is part of
            the *public* key, so it should not double as a secret. When it is
            empty or None the comment field is omitted.
        passphrase: if given, the private-key PEM is encrypted with it;
            otherwise the PEM is written unencrypted.

    Returns:
        A ``(private_key_pem, public_key_line)`` tuple of strings.
    """
    key = paramiko.RSAKey.generate(bits)
    # Record the comment on the key object too.
    key.comment = comment

    pem_buffer = io.StringIO()
    key.write_private_key(pem_buffer, password=passphrase)
    private_key_str = pem_buffer.getvalue()

    # "<type> <base64> [comment]" -- skip the comment field when there is
    # none, instead of emitting a trailing space or the literal text "None".
    fields = [key.get_name(), key.get_base64()]
    if comment:
        fields.append(comment)
    public_key_str = " ".join(fields)

    return private_key_str, public_key_str
    
# Load the SSH key pair from Kaggle secrets. If either secret is missing,
# generate a fresh pair and print it so it can be saved as secrets.
try:
    public_key = UserSecretsClient().get_secret("SSH_PUBLIC_KEY")
    encrypted_private_key = UserSecretsClient().get_secret("SSH_PRIVATE_KEY_ENCRYPTED")
except Exception as exc:  # not a bare except: KeyboardInterrupt/SystemExit still propagate
    print(f"SSH key secrets not available ({type(exc).__name__}); generating a new key pair.")
    # NOTE(review): the passphrase is ssh_comment, which generate_rsa_key also
    # uses as the comment of the public-key line printed below. Anyone who sees
    # the public key can therefore decrypt the private key; consider a
    # dedicated passphrase secret.
    encrypted_private_key, public_key = generate_rsa_key(passphrase=ssh_comment)
    # Escape newlines so the PEM fits in a single-line secret value;
    # SSHConnection converts the escaped sequences back into newlines.
    encrypted_private_key = encrypted_private_key.replace("\n", "\\n")
    print("\nEncrypted private key:\n", encrypted_private_key)
    print("\nPublic key:\n", public_key)
    # A freshly generated key is not authorized on any server yet, so SSH
    # logins will fail with an authentication error until it is installed.
    print(
        "\nAdd the public key to ~/.ssh/authorized_keys on the SSH host(s), store both"
        " values as the SSH_PUBLIC_KEY / SSH_PRIVATE_KEY_ENCRYPTED secrets, and clear"
        " this cell's output before sharing the notebook."
    )

# Required connection settings.
ssh_hostname = UserSecretsClient().get_secret("SSH_HOSTNAME")
ssh_username = UserSecretsClient().get_secret("SSH_USERNAME")

# Optional jump-host settings. When they are missing, define them as None so
# later code can test for them instead of failing with a NameError.
try:
    ssh_jump_gateway = UserSecretsClient().get_secret("SSH_JUMP_GATEWAY")
    ssh_jump_dest = UserSecretsClient().get_secret("SSH_JUMP_DEST")
except Exception:
    ssh_jump_gateway = ssh_jump_dest = None

In [9]:
from langchain.agents import Tool, AgentExecutor
from langchain.memory import ConversationBufferMemory
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage, AIMessage
from typing import List, Dict, Optional, AsyncGenerator
import asyncio
import paramiko
import json
from pydantic import BaseModel, Field
from datetime import datetime
import re

class OutputChunk(BaseModel):
    """A single slice of command output together with the patterns found in it.

    Fields:
        content: raw text of the slice.
        timestamp: when the slice was produced.
        type: where the text came from -- "stdout", "stderr" or "system".
        requires_attention: set by the producer when the slice looks like it
            needs a reaction (for example an error or a prompt).
        pattern_matches: pattern name -> text that matched it.
    """
    content: str
    timestamp: datetime
    type: str = Field(default="stdout")
    requires_attention: bool = Field(default=False)
    pattern_matches: Dict[str, str] = Field(default_factory=dict)

class PatternMatcher:
    """Finds noteworthy patterns (errors, prompts, progress, completion) in output.

    Every pattern is compiled case-insensitively. ``analyze_chunk`` reports,
    for each pattern that matches, the text of its first match.
    """

    # Default pattern set. The keys become the names reported in
    # OutputChunk.pattern_matches; StreamProcessor flags chunks with an
    # 'error' or 'prompt' match as requiring attention.
    DEFAULT_PATTERNS: Dict[str, str] = {
        'error': r'error|exception|failed|fatal',
        # sudo asks "[sudo] password for <user>:", so allow other text between
        # "password" and the colon (on the same line).
        'prompt': r'\[y/N\]|\[Y/n\]|password[^\n:]*:|continue\?',
        'progress': r'\d+%|\d+/\d+',
        # (?m): a completion word may sit on any line of a multi-line chunk,
        # not only the last one. \b keeps "already" from matching "ready".
        'completion': r'(?m)\b(done|completed|finished|ready)\b.*$',
    }

    def __init__(self, patterns: Optional[Dict[str, str]] = None):
        """
        Args:
            patterns: optional mapping of pattern name -> regex source that
                replaces DEFAULT_PATTERNS. The mapping is copied.
        """
        source = self.DEFAULT_PATTERNS if patterns is None else patterns
        self.patterns = dict(source)
        self.compiled_patterns = {
            name: re.compile(regex, re.IGNORECASE)
            for name, regex in self.patterns.items()
        }

    def analyze_chunk(self, text: str) -> Dict[str, str]:
        """Return ``{pattern_name: first_matching_text}`` for every pattern found in ``text``."""
        matches = {}
        for pattern_name, pattern in self.compiled_patterns.items():
            if found := pattern.search(text):
                matches[pattern_name] = found.group(0)
        return matches

class StreamProcessor:
    """Re-slices a stream of text fragments into OutputChunk objects.

    Incoming text is buffered until at least ``min_chunk_size`` characters are
    pending or the pending text matches one of the matcher's patterns; it is
    then emitted in slices of at most ``chunk_size`` characters. Whatever is
    still buffered when the input stream ends is flushed rather than dropped.
    """
    def __init__(self, pattern_matcher: PatternMatcher,
                 chunk_size: int = 1024, min_chunk_size: int = 100):
        """
        Args:
            pattern_matcher: used to tag each chunk and to decide when a
                small buffer is worth emitting early.
            chunk_size: maximum characters per emitted chunk (must be >= 1).
            min_chunk_size: below this many buffered characters, wait for
                more text unless a pattern already matches.

        Raises:
            ValueError: if ``chunk_size`` is smaller than 1.
        """
        if chunk_size < 1:
            raise ValueError("chunk_size must be at least 1")
        self.pattern_matcher = pattern_matcher
        self.buffer = ""
        self.chunk_size = chunk_size
        self.min_chunk_size = min_chunk_size

    def should_chunk(self, text: str) -> bool:
        """True when ``text`` is long enough or contains any known pattern."""
        if len(text) >= self.chunk_size:
            return True

        patterns = self.pattern_matcher.analyze_chunk(text)
        return bool(patterns)  # chunk early if any important pattern shows up

    def _pop(self, size: int) -> str:
        """Remove and return the first ``size`` characters of the buffer."""
        text, self.buffer = self.buffer[:size], self.buffer[size:]
        return text

    def _make_chunk(self, text: str) -> OutputChunk:
        """Wrap ``text`` in an OutputChunk tagged with its pattern matches."""
        patterns = self.pattern_matcher.analyze_chunk(text)
        return OutputChunk(
            content=text,
            timestamp=datetime.now(),
            requires_attention=bool({'error', 'prompt'} & patterns.keys()),
            pattern_matches=patterns,
        )

    async def process_stream(self, stream: AsyncGenerator[str, None]) -> AsyncGenerator[OutputChunk, None]:
        """Consume ``stream`` and yield OutputChunks as described on the class."""
        async for data in stream:
            self.buffer += data

            while self.buffer:
                if len(self.buffer) < self.min_chunk_size and not self.should_chunk(self.buffer):
                    break  # too little, unremarkable text: wait for more
                yield self._make_chunk(self._pop(self.chunk_size))

        # The stream ended: emit the remaining tail (e.g. the last short
        # lines of output) instead of silently discarding it.
        while self.buffer:
            yield self._make_chunk(self._pop(self.chunk_size))

class SSHConnection:
    """SSH session to the target host, optionally tunnelled through a jump host.

    Settings come from notebook globals:
      - ``ssh_hostname`` and ``ssh_username``;
      - ``encrypted_private_key``, a PEM with escaped newlines;
      - ``ssh_comment``, the key passphrase.

    The first connection goes to ``ssh_hostname``. If ``ssh_jump_dest`` is set,
    that connection is used as a jump host and a second session is opened to
    ``ssh_jump_dest`` over a direct-tcpip channel. Otherwise commands run on
    ``ssh_hostname`` itself. Use ``close()`` (or ``with``) to release both.
    """

    def __init__(self):
        key_string = encrypted_private_key.replace("\\n", "\n")
        private_key = paramiko.RSAKey.from_private_key(
            io.StringIO(key_string),
            password=ssh_comment)

        # Channel of the command currently running in execute_interactive.
        self.active_channel = None

        # NOTE(review): AutoAddPolicy accepts any unknown host key, so a
        # man-in-the-middle on first contact would go unnoticed.
        self.jump_client = paramiko.SSHClient()
        self.jump_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        try:
            self.jump_client.connect(
                hostname=ssh_hostname,
                username=ssh_username,
                pkey=private_key,
                # Authenticate only with the key from the secrets, never with
                # an SSH agent or keys that happen to exist in ~/.ssh.
                allow_agent=False,
                look_for_keys=False,
            )
            self.client = self._connect_target(private_key)
        except Exception:
            # Don't leave a half-open jump connection behind.
            self.jump_client.close()
            raise

    def _connect_target(self, private_key) -> paramiko.SSHClient:
        """Return the client commands run on: a tunnelled session to
        ``ssh_jump_dest`` when configured, otherwise the first connection."""
        jump_dest = globals().get("ssh_jump_dest")
        if not jump_dest:
            return self.jump_client

        jump_gateway = globals().get("ssh_jump_gateway")
        # Originator address reported to the jump host for the forwarded channel.
        origin = (jump_gateway, 22) if jump_gateway else ("127.0.0.1", 0)
        channel = self.jump_client.get_transport().open_channel(
            "direct-tcpip", (jump_dest, 22), origin)

        client = paramiko.SSHClient()
        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        try:
            client.connect(
                jump_dest,
                username=ssh_username,
                pkey=private_key,
                sock=channel,
                allow_agent=False,
                look_for_keys=False,
            )
        except Exception:
            client.close()
            raise
        return client

    async def stream_output(self, shell) -> AsyncGenerator[str, None]:
        """Poll ``shell`` every 0.1 s and yield its decoded output.

        Stops when the channel reports EOF, is closed, or the remote process
        has exited and all buffered output has been read. Bytes are decoded
        incrementally so a UTF-8 character split across reads is not corrupted;
        invalid bytes become U+FFFD instead of raising.
        """
        import codecs

        decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
        while True:
            if shell.recv_ready():
                raw = shell.recv(4096)
                if not raw:  # empty read means the channel hit EOF
                    break
                text = decoder.decode(raw)
                if text:
                    yield text
            elif shell.exit_status_ready() or shell.closed or shell.eof_received:
                break  # nothing buffered and nothing more will arrive
            await asyncio.sleep(0.1)

        tail = decoder.decode(b"", final=True)
        if tail:
            yield tail

    async def execute_interactive(self, command: str) -> AsyncGenerator[OutputChunk, None]:
        """Run ``command`` on a pseudo-terminal and yield its output as chunks.

        While the generator is running, ``send_input`` writes to the command's
        stdin (e.g. to answer a [Y/n] prompt). The stream ends when the command
        exits, and the channel is closed when the generator finishes or is closed.
        """
        channel = self.client.get_transport().open_session()
        # Request a pty so interactive programs (sudo, apt prompts) see a
        # terminal. Unlike invoke_shell, exec_command ends with the command,
        # so the output stream terminates.
        channel.get_pty()
        channel.exec_command(command)
        self.active_channel = channel
        processor = StreamProcessor(PatternMatcher())
        try:
            async for chunk in processor.process_stream(self.stream_output(channel)):
                yield chunk
        finally:
            channel.close()
            if self.active_channel is channel:
                self.active_channel = None

    def send_input(self, text: str) -> None:
        """Type ``text`` followed by a newline into the running command.

        Raises:
            RuntimeError: if no command started by execute_interactive is running.
        """
        if self.active_channel is None or self.active_channel.closed:
            raise RuntimeError("No interactive command is running")
        self.active_channel.sendall(text + "\n")

    def close(self) -> None:
        """Close the target session and, if separate, the jump-host session."""
        self.client.close()
        if self.jump_client is not self.client:
            self.jump_client.close()

    def __enter__(self) -> "SSHConnection":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

class LinuxCommandTool(Tool):
    """LangChain Tool that runs Linux commands over an SSHConnection.

    Unlike a regular LangChain tool, ``run`` is an async generator yielding
    OutputChunk objects instead of returning a string.
    """

    # Tool is a pydantic model, and pydantic rejects assignment to undeclared
    # attributes (the old ``self.ssh = ...`` before ``super().__init__``).
    # Declare the connection as a field and pass it to the initializer.
    ssh: Optional[SSHConnection] = None

    def __init__(self, ssh_connection: SSHConnection):
        super().__init__(
            name="linux_command",
            description="Execute Linux commands and handle interactive output",
            func=self.run,
            ssh=ssh_connection,
        )

    async def run(self, command: str) -> AsyncGenerator[OutputChunk, None]:
        """Yield the output chunks of ``command`` as produced by the SSH connection."""
        async for chunk in self.ssh.execute_interactive(command):
            yield chunk

class AgentResponse(BaseModel):
    """Decision the LLM returns after looking at recent command output.

    Fields:
        action: expected to be "continue", "interact", "alert" or "complete"
            (not validated).
        response: text to send back, guidance, or a summary, depending on action.
        reasoning: the model's explanation of its decision.
        priority: urgency of the response; 0 unless the model says otherwise.
    """
    action: str
    response: str
    reasoning: str
    priority: int = Field(default=0)

class StreamingAgent:
    """Runs a command through a streaming tool and asks an LLM how to react.

    Output chunks are accumulated in ``current_context``. The LLM is consulted
    immediately when a chunk requires attention, and otherwise once five
    chunks have accumulated. Decisions other than "continue" are yielded by
    :meth:`run`.
    """

    def __init__(self, llm, tools: List[Tool]):
        # Local import: only this class needs the windowed memory.
        from langchain.memory import ConversationBufferWindowMemory

        self.llm = llm
        self.tools = tools
        # ConversationBufferMemory has no `max_history` option; the windowed
        # memory keeps only the last k input/output exchanges.
        self.memory = ConversationBufferWindowMemory(k=10)
        self.current_context: List[OutputChunk] = []

    def create_prompt(self, chunks: List[OutputChunk]) -> str:
        """Build the LLM prompt from the (at most) five most recent chunks."""
        recent = chunks[-5:]
        # Send the actual text, not pydantic reprs of the chunk objects.
        output_text = "".join(chunk.content for chunk in recent)
        patterns = [chunk.pattern_matches for chunk in recent]
        return f"""Analyze this command output stream and determine appropriate action:

Recent output:
{output_text}

Patterns detected:
{patterns}

Based on this output:
1. If you see a prompt/question, provide the appropriate response
2. If you detect an error, provide guidance
3. If the command is progressing normally, return "continue"
4. If the command has completed, provide a summary

For "interact", "response" must be exactly the text to type into the command.

Respond with a single JSON object and nothing else:
{{
    "action": "continue" | "interact" | "alert" | "complete",
    "response": "your response or next action",
    "reasoning": "your analysis",
    "priority": <integer 0-10, urgency of the response>
}}"""

    async def process_chunks(self, chunks: List[OutputChunk]) -> Optional[AgentResponse]:
        """Consult the LLM if ``chunks`` need attention or enough output has accumulated.

        Returns the parsed AgentResponse, or None when no consultation happened.
        """
        # Quick pattern-based triage first.
        if any(chunk.requires_attention for chunk in chunks):
            # Show the LLM the output leading up to the prompt/error, not just
            # the chunk that triggered it.
            response = await self._consult(self._with_context(chunks))
            self.current_context.clear()
            return response

        # For normal output, batch a few chunks before consulting the LLM.
        if len(self.current_context) >= 5:
            response = await self._consult(self.current_context)
            self.current_context.clear()
            return response

        return None

    def _with_context(self, chunks: List[OutputChunk]) -> List[OutputChunk]:
        """Accumulated context followed by any of ``chunks`` not already in it."""
        context = list(self.current_context)
        context.extend(
            chunk for chunk in chunks
            if not any(chunk is seen for seen in self.current_context)
        )
        return context

    async def _consult(self, chunks: List[OutputChunk]) -> AgentResponse:
        """Ask the LLM about ``chunks`` and parse its reply."""
        raw = await self.llm.apredict(self.create_prompt(chunks))
        return self._parse_response(raw)

    @staticmethod
    def _parse_response(raw: str) -> AgentResponse:
        """Parse the LLM reply into an AgentResponse.

        Tolerates code fences or prose around the JSON object. If the reply
        cannot be parsed, returns an "alert" carrying the raw text instead of
        raising, so one malformed reply does not abort the whole stream.
        """
        start, end = raw.find("{"), raw.rfind("}")
        try:
            if start == -1 or end < start:
                raise ValueError("no JSON object in reply")
            return AgentResponse(**json.loads(raw[start:end + 1]))
        except (ValueError, TypeError) as exc:  # JSON and pydantic errors are ValueErrors
            return AgentResponse(
                action="alert",
                response=raw,
                reasoning=f"Could not parse LLM reply as an AgentResponse: {exc}",
            )

    @staticmethod
    def _send_to_command(tool, text: str) -> bool:
        """Write ``text`` to the running command's stdin, if the tool supports it.

        Looks for a callable ``send_input`` on the tool or on its ``ssh``
        attribute. Returns False when neither provides one.
        """
        for target in (tool, getattr(tool, "ssh", None)):
            send = getattr(target, "send_input", None)
            if callable(send):
                send(text)
                return True
        return False

    async def run(self, command: str):
        """Run ``command`` with the first tool and yield non-"continue" decisions."""
        tool = self.tools[0]  # assumes the first tool is the Linux command tool

        async for chunk in tool.run(command):
            self.current_context.append(chunk)

            response = await self.process_chunks([chunk])
            if response is None or response.action == "continue":
                continue

            if response.action == "interact":
                # Answer the *running* command. The old `await tool.run(...)`
                # awaited an async generator (a TypeError) and would have
                # started a new command instead. If the tool offers no way to
                # write to stdin, the decision is only reported to the caller.
                # NOTE(review): this types LLM-chosen text into a live command,
                # possibly under sudo; consider requiring human confirmation.
                self._send_to_command(tool, response.response)

            self.memory.save_context(
                {"input": command},
                {"output": response.reasoning}
            )

            yield response

# Example usage
async def main(command: str = "sudo apt upgrade"):
    """Run ``command`` over SSH under the streaming agent and print its decisions.

    Args:
        command: shell command to execute on the remote host.
    """
    ssh = SSHConnection()
    try:
        tools = [LinuxCommandTool(ssh)]
        llm = ChatOpenAI(temperature=0, openai_api_key=UserSecretsClient().get_secret("OPENAI_API_KEY"))
        agent = StreamingAgent(llm, tools)

        async for response in agent.run(command):
            if response.action != "continue":
                print(f"Agent response: {response.model_dump()}")
    finally:
        # Always release the SSH sessions (target first, then the jump host),
        # even if the agent raised. Closing an already-closed client is harmless.
        for client in (getattr(ssh, "client", None), getattr(ssh, "jump_client", None)):
            if client is not None:
                client.close()

if __name__ == "__main__":
    # Top-level `await` relies on IPython/Jupyter autoawait. The notebook
    # kernel already runs an event loop, so asyncio.run() cannot be used here.
    # Outside IPython this line is a SyntaxError.
    await main()

AuthenticationException: Authentication failed.