diff --git a/packages/api/model.go b/packages/api/model.go index 7e1adf8e..41fb7f71 100644 --- a/packages/api/model.go +++ b/packages/api/model.go @@ -845,8 +845,18 @@ type PAMAccessApprovalRequestResponse struct { } `json:"request"` } +type PAMPolicyRuleConfig struct { + Patterns []string `json:"patterns"` +} + +type PAMPolicyRules struct { + CommandBlocking *PAMPolicyRuleConfig `json:"command-blocking,omitempty"` + SessionLogMasking *PAMPolicyRuleConfig `json:"session-log-masking,omitempty"` +} + type PAMSessionCredentialsResponse struct { Credentials PAMSessionCredentials `json:"credentials"` + PolicyRules *PAMPolicyRules `json:"policyRules,omitempty"` } type PAMSessionCredentials struct { diff --git a/packages/pam/compile_patterns_test.go b/packages/pam/compile_patterns_test.go new file mode 100644 index 00000000..57cc81b0 --- /dev/null +++ b/packages/pam/compile_patterns_test.go @@ -0,0 +1,54 @@ +package pam + +import ( + "testing" + + "github.com/Infisical/infisical-merge/packages/api" +) + +func TestCompilePolicyPatterns(t *testing.T) { + t.Run("nil config returns nil", func(t *testing.T) { + result := compilePolicyPatterns(nil, "sess-1", "test") + if result != nil { + t.Errorf("expected nil, got %v", result) + } + }) + + t.Run("empty patterns returns nil", func(t *testing.T) { + config := &api.PAMPolicyRuleConfig{Patterns: []string{}} + result := compilePolicyPatterns(config, "sess-1", "test") + if result != nil { + t.Errorf("expected nil, got %v", result) + } + }) + + t.Run("valid patterns all compile", func(t *testing.T) { + config := &api.PAMPolicyRuleConfig{ + Patterns: []string{`rm\s+-rf`, `shutdown`, `password\s*=\s*\S+`}, + } + result := compilePolicyPatterns(config, "sess-1", "test") + if len(result) != 3 { + t.Errorf("expected 3 compiled patterns, got %d", len(result)) + } + }) + + t.Run("invalid pattern is skipped", func(t *testing.T) { + config := &api.PAMPolicyRuleConfig{ + Patterns: []string{`rm\s+-rf`, `[invalid`, `shutdown`}, + } + result := compilePolicyPatterns(config, "sess-1", "test") + if len(result) != 2 { + t.Errorf("expected 2 compiled patterns (1 skipped), got %d", len(result)) + } + }) + + t.Run("all invalid patterns returns empty slice", func(t *testing.T) { + config := &api.PAMPolicyRuleConfig{ + Patterns: []string{`[bad`, `(unclosed`}, + } + result := compilePolicyPatterns(config, "sess-1", "test") + if len(result) != 0 { + t.Errorf("expected 0 compiled patterns, got %d", len(result)) + } + }) +} diff --git a/packages/pam/handlers/ssh/command_blocking_test.go b/packages/pam/handlers/ssh/command_blocking_test.go new file mode 100644 index 00000000..3cc7a429 --- /dev/null +++ b/packages/pam/handlers/ssh/command_blocking_test.go @@ -0,0 +1,54 @@ +package ssh + +import ( + "regexp" + "testing" +) + +func TestMatchBlockedCommand(t *testing.T) { + proxy := &SSHProxy{ + config: SSHProxyConfig{ + BlockedCommandPatterns: []*regexp.Regexp{ + regexp.MustCompile(`rm\s+-rf`), + regexp.MustCompile(`shutdown`), + regexp.MustCompile(`reboot`), + }, + }, + } + + tests := []struct { + name string + command string + blocked bool + }{ + {"blocks rm -rf", "rm -rf /", true}, + {"blocks rm -rf with extra space", "rm -rf /home", true}, + {"blocks sudo rm -rf", "sudo rm -rf /", true}, + {"blocks shutdown", "shutdown -h now", true}, + {"blocks reboot", "reboot", true}, + {"allows ls", "ls -la", false}, + {"allows rm without -rf", "rm file.txt", false}, + {"allows empty command", "", false}, + {"allows whitespace only", " ", false}, + {"allows normal commands", "cat /etc/hosts", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := proxy.matchBlockedCommand(tt.command) + if result != tt.blocked { + t.Errorf("matchBlockedCommand(%q) = %v, want %v", tt.command, result, tt.blocked) + } + }) + } +} + +func TestMatchBlockedCommandNoPatterns(t *testing.T) { + proxy := &SSHProxy{ + config: SSHProxyConfig{}, + } + + if proxy.matchBlockedCommand("rm -rf /") { + t.Error("with no patterns, should never block") + } +} diff --git a/packages/pam/handlers/ssh/proxy.go b/packages/pam/handlers/ssh/proxy.go index 552b65fc..0d2657af 100644 --- a/packages/pam/handlers/ssh/proxy.go +++ b/packages/pam/handlers/ssh/proxy.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "net" + "regexp" "strings" "sync" "time" @@ -16,23 +17,28 @@ import ( // SSHProxyConfig holds configuration for the SSH proxy type SSHProxyConfig struct { - TargetAddr string // e.g., "target-host:22" - AuthMethod string - InjectUsername string - InjectPassword string - InjectPrivateKey string - InjectCertificate string - SessionID string - SessionLogger session.SessionLogger + TargetAddr string // e.g., "target-host:22" + AuthMethod string + InjectUsername string + InjectPassword string + InjectPrivateKey string + InjectCertificate string + SessionID string + SessionLogger session.SessionLogger + BlockedCommandPatterns []*regexp.Regexp // Regex patterns for command blocking (nil = no blocking) } // SSHProxy handles proxying SSH connections with credential injection type SSHProxy struct { - config SSHProxyConfig - mutex sync.Mutex - sessionData []byte // Store session data for logging - inputBuffer []byte // Buffer for input data to batch keystrokes - inputChannelType session.TerminalChannelType // Channel type for buffered input + config SSHProxyConfig + mutex sync.Mutex + sessionData []byte // Store session data for logging + inputBuffer []byte // Buffer for input data to batch keystrokes + inputChannelType session.TerminalChannelType // Channel type for buffered input + escapeState int // 0=normal, 1=got ESC, 2=in CSI sequence + outputMutex sync.Mutex + outputBuffer []byte // Buffer for output data to enable masking across chunks + outputChannelType session.TerminalChannelType // Channel type for buffered output } // channelState holds per-channel state for tracking session type @@ -246,7 +252,12 @@ func (p *SSHProxy) handleChannel(ctx context.Context, newChannel ssh.NewChannel, // Client to Server go func() { - err := p.proxyData(clientChannel, serverChannel, "client→server", sessionID, true, chState) + var err error + if len(p.config.BlockedCommandPatterns) > 0 { + err = p.proxyClientToServerWithBlocking(clientChannel, serverChannel, clientChannel, sessionID, chState) + } else { + err = p.proxyData(clientChannel, serverChannel, "client→server", sessionID, true, chState) + } // Send EOF so the remote process exits and delivers exit-status. serverChannel.CloseWrite() //nolint:errcheck clientToServerDone <- err @@ -316,6 +327,30 @@ func (p *SSHProxy) handleChannelRequests(requests <-chan *ssh.Request, targetCha if len(req.Payload) >= 4+cmdLen { command := string(req.Payload[4 : 4+cmdLen]) + // Check exec command against blocked patterns + if p.matchBlockedCommand(command) { + log.Warn(). + Str("sessionID", sessionID). + Str("command", command). + Msg("Blocked SSH exec command") + + // Log the blocked exec to session recording + blockedEvent := session.TerminalEvent{ + Timestamp: time.Now(), + EventType: session.TerminalEventInput, + ChannelType: session.TerminalChannelExec, + Data: []byte(fmt.Sprintf("$ %s\n[BLOCKED] Command not permitted\n", command)), + } + if err := p.config.SessionLogger.LogTerminalEvent(blockedEvent); err != nil { + log.Error().Err(err).Str("sessionID", sessionID).Msg("Failed to log blocked exec command") + } + + if req.WantReply { + req.Reply(false, nil) + } + continue + } + // Determine the type of operation isSCP := strings.HasPrefix(command, "scp ") chState.mutex.Lock() @@ -452,11 +487,14 @@ func (p *SSHProxy) handleChannelRequests(requests <-chan *ssh.Request, targetCha func (p *SSHProxy) proxyData(src io.Reader, dst io.Writer, direction string, sessionID string, logInput bool, chState *channelState) error { buf := make([]byte, 32*1024) // 32KB buffer - // Flush any remaining input buffer on exit + // Flush any remaining buffers on exit defer func() { if logInput && len(p.inputBuffer) > 0 { p.flushInputBuffer(sessionID) } + if !logInput && len(p.outputBuffer) > 0 { + p.flushOutputBuffer(sessionID) + } }() for { @@ -500,21 +538,8 @@ func (p *SSHProxy) proxyData(src io.Reader, dst io.Writer, direction string, ses if logInput { p.bufferInput(buf[:n], sessionID, channelType) } else { - // For output, log immediately as before - event := session.TerminalEvent{ - Timestamp: time.Now(), - EventType: session.TerminalEventOutput, - ChannelType: channelType, - Data: make([]byte, n), - } - copy(event.Data, buf[:n]) - - if err := p.config.SessionLogger.LogTerminalEvent(event); err != nil { - log.Error().Err(err). - Str("sessionID", sessionID). - Str("eventType", string(session.TerminalEventOutput)). - Msg("Failed to log terminal event") - } + // Buffer output until newline so masking patterns can match across echo chunks + p.bufferOutput(buf[:n], sessionID, channelType) } } @@ -547,7 +572,28 @@ func (p *SSHProxy) bufferInput(data []byte, sessionID string, channelType sessio p.inputChannelType = channelType for _, b := range data { + // Skip ANSI escape sequences (e.g., cursor position reports like ESC[11;17R) + // States: 0=normal, 1=got ESC, 2=in CSI sequence + if p.escapeState == 1 { + if b == '[' { + p.escapeState = 2 // ESC[ = CSI sequence start + } else { + p.escapeState = 0 // Two-byte escape sequence (ESC + char), done + } + continue + } + if p.escapeState == 2 { + // In CSI sequence: parameter bytes (0x30-0x3F) and intermediate bytes (0x20-0x2F) + // continue until final byte (0x40-0x7E) + if b >= 0x40 && b <= 0x7E { + p.escapeState = 0 // Final byte, sequence complete + } + continue + } + switch b { + case 0x1B: // ESC - start of escape sequence + p.escapeState = 1 case 0x7F, 0x08: // DEL (backspace on most terminals) or BS if len(p.inputBuffer) > 0 { p.inputBuffer = p.inputBuffer[:len(p.inputBuffer)-1] @@ -613,6 +659,208 @@ func (p *SSHProxy) flushInputBufferUnsafe(sessionID string) { p.inputBuffer = p.inputBuffer[:0] } +// bufferOutput accumulates output data and flushes on newline or size limit. +// This allows session log masking patterns to match across character-by-character echo, +// because the regex sees a full line rather than individual bytes. +func (p *SSHProxy) bufferOutput(data []byte, sessionID string, channelType session.TerminalChannelType) { + p.outputMutex.Lock() + defer p.outputMutex.Unlock() + + p.outputChannelType = channelType + + for _, b := range data { + p.outputBuffer = append(p.outputBuffer, b) + + // Flush on newline (LF) or if buffer gets too large + if b == 0x0A || len(p.outputBuffer) >= 4096 { + p.flushOutputBufferUnsafe(sessionID) + } + } +} + +// flushOutputBuffer flushes the output buffer with locking +func (p *SSHProxy) flushOutputBuffer(sessionID string) { + p.outputMutex.Lock() + defer p.outputMutex.Unlock() + p.flushOutputBufferUnsafe(sessionID) +} + +// flushOutputBufferUnsafe flushes the output buffer without locking (caller must hold lock) +func (p *SSHProxy) flushOutputBufferUnsafe(sessionID string) { + if len(p.outputBuffer) == 0 { + return + } + + event := session.TerminalEvent{ + Timestamp: time.Now(), + EventType: session.TerminalEventOutput, + ChannelType: p.outputChannelType, + Data: make([]byte, len(p.outputBuffer)), + } + copy(event.Data, p.outputBuffer) + + if err := p.config.SessionLogger.LogTerminalEvent(event); err != nil { + log.Error().Err(err). + Str("sessionID", sessionID). + Str("eventType", string(session.TerminalEventOutput)). + Msg("Failed to log terminal event") + } + + p.outputBuffer = p.outputBuffer[:0] +} + +// matchBlockedCommand checks if a command matches any blocked pattern. +func (p *SSHProxy) matchBlockedCommand(command string) bool { + command = strings.TrimSpace(command) + if command == "" || len(p.config.BlockedCommandPatterns) == 0 { + return false + } + for _, pattern := range p.config.BlockedCommandPatterns { + if pattern.MatchString(command) { + return true + } + } + return false +} + +// proxyClientToServerWithBlocking proxies client→server data with command blocking. +// It forwards all bytes immediately except Enter (CR/LF), which is checked against +// blocked patterns before forwarding. If blocked, Enter is suppressed, a message is +// sent to the client, and Ctrl+U/Ctrl+C clear the server's input. +func (p *SSHProxy) proxyClientToServerWithBlocking(src io.Reader, dst io.Writer, clientWriter io.Writer, sessionID string, chState *channelState) error { + buf := make([]byte, 32*1024) + + log.Debug(). + Str("sessionID", sessionID). + Int("numPatterns", len(p.config.BlockedCommandPatterns)). + Msg("Command blocking active for client→server proxy") + + defer func() { + if len(p.inputBuffer) > 0 { + p.flushInputBuffer(sessionID) + } + }() + + for { + n, err := src.Read(buf) + if n > 0 { + chState.mutex.Lock() + isBinary := chState.isBinarySession + sftpParser := chState.sftpParser + channelType := chState.channelType + chState.mutex.Unlock() + + if isBinary { + // Binary SFTP/SCP session — no command blocking, log file operations + if sftpParser != nil { + operations := sftpParser.Parse(buf[:n]) + for _, op := range operations { + logMsg := FormatOperation(op) + "\n" + event := session.TerminalEvent{ + Timestamp: time.Now(), + EventType: session.TerminalEventInput, + ChannelType: session.TerminalChannelSFTP, + Data: []byte(logMsg), + } + if logErr := p.config.SessionLogger.LogTerminalEvent(event); logErr != nil { + log.Error().Err(logErr). + Str("sessionID", sessionID). + Str("operation", op.Type). + Str("path", op.Path). + Msg("Failed to log SFTP operation") + } + } + } + if written, writeErr := dst.Write(buf[:n]); writeErr != nil { + return writeErr + } else if written != n { + return io.ErrShortWrite + } + } else { + // Interactive/exec session — check for blocked commands at Enter + segStart := 0 + for i := 0; i < n; i++ { + b := buf[i] + if b == 0x0D || b == 0x0A { + // Forward and log everything before this CR/LF + if i > segStart { + segment := buf[segStart:i] + p.bufferInput(segment, sessionID, channelType) + if _, writeErr := dst.Write(segment); writeErr != nil { + return writeErr + } + } + + // Check accumulated command against blocked patterns + p.mutex.Lock() + command := string(p.inputBuffer) + p.mutex.Unlock() + + if p.matchBlockedCommand(command) { + // BLOCKED: flush the typed command to session log, then log the block + p.mutex.Lock() + p.inputBuffer = append(p.inputBuffer, b) + p.flushInputBufferUnsafe(sessionID) + p.mutex.Unlock() + + // Flush pending output buffer so the echoed command appears before the blocked message + p.flushOutputBuffer(sessionID) + + // Send error message to client (red text) + blockedMsg := "\r\n\033[31m[BLOCKED] Command not permitted\033[0m\r\n" + clientWriter.Write([]byte(blockedMsg)) + + // Log the blocked message as output so it appears in session replay + blockedEvent := session.TerminalEvent{ + Timestamp: time.Now(), + EventType: session.TerminalEventOutput, + ChannelType: channelType, + Data: []byte(blockedMsg), + } + if logErr := p.config.SessionLogger.LogTerminalEvent(blockedEvent); logErr != nil { + log.Error().Err(logErr).Str("sessionID", sessionID).Msg("Failed to log blocked command event") + } + + // Clear server's pending input and get a fresh prompt (synthetic, bypass buffers) + dst.Write([]byte{0x15}) // Ctrl+U — clear line + dst.Write([]byte{0x03}) // Ctrl+C — fresh prompt + + log.Warn(). + Str("sessionID", sessionID). + Str("command", command). + Msg("Blocked SSH command") + } else { + // Allowed — forward the CR/LF through normal path + p.bufferInput([]byte{b}, sessionID, channelType) + if _, writeErr := dst.Write([]byte{b}); writeErr != nil { + return writeErr + } + } + + segStart = i + 1 + } + } + + // Forward remaining segment after last CR/LF (or the entire chunk if no CR/LF) + if segStart < n { + segment := buf[segStart:n] + p.bufferInput(segment, sessionID, channelType) + if _, writeErr := dst.Write(segment); writeErr != nil { + return writeErr + } + } + } + } + + if err != nil { + if err == io.EOF { + return nil + } + return err + } + } +} + // extractSCPPath extracts the file path from an SCP command // SCP commands look like: scp -t /path/to/file or scp -f /path/to/file func extractSCPPath(command string) string { diff --git a/packages/pam/pam-proxy.go b/packages/pam/pam-proxy.go index b4cf2650..7b6976a6 100644 --- a/packages/pam/pam-proxy.go +++ b/packages/pam/pam-proxy.go @@ -7,8 +7,10 @@ import ( "encoding/json" "fmt" "net/url" + "regexp" "time" + "github.com/Infisical/infisical-merge/packages/api" "github.com/Infisical/infisical-merge/packages/pam/handlers" "github.com/Infisical/infisical-merge/packages/pam/handlers/kubernetes" "github.com/Infisical/infisical-merge/packages/pam/handlers/mongodb" @@ -104,6 +106,28 @@ func HandlePAMCancellation(ctx context.Context, conn *tls.Conn, pamConfig *Gatew return nil } +// compilePolicyPatterns compiles regex pattern strings, logging warnings for any that fail. +func compilePolicyPatterns(config *api.PAMPolicyRuleConfig, sessionID string, ruleType string) []*regexp.Regexp { + if config == nil || len(config.Patterns) == 0 { + return nil + } + var compiled []*regexp.Regexp + for _, pattern := range config.Patterns { + re, err := regexp.Compile(pattern) + if err != nil { + log.Warn(). + Err(err). + Str("sessionID", sessionID). + Str("ruleType", ruleType). + Str("pattern", pattern). + Msg("Failed to compile policy pattern, skipping") + continue + } + compiled = append(compiled, re) + } + return compiled +} + func HandlePAMProxy(ctx context.Context, conn *tls.Conn, pamConfig *GatewayPAMConfig, httpClient *resty.Client) error { credentials, err := pamConfig.CredentialsManager.GetPAMSessionCredentials(pamConfig.SessionId, pamConfig.ExpiryTime) if err != nil { @@ -154,7 +178,14 @@ func HandlePAMProxy(ctx context.Context, conn *tls.Conn, pamConfig *GatewayPAMCo if err != nil { return fmt.Errorf("failed to get PAM session encryption key: %w", err) } - sessionLogger, err := session.NewSessionLogger(pamConfig.SessionId, encryptionKey, pamConfig.ExpiryTime, pamConfig.ResourceType) + + // Compile session log masking patterns from policy rules + var maskingPatterns []*regexp.Regexp + if credentials.PolicyRules != nil { + maskingPatterns = compilePolicyPatterns(credentials.PolicyRules.SessionLogMasking, pamConfig.SessionId, "session-log-masking") + } + + sessionLogger, err := session.NewSessionLogger(pamConfig.SessionId, encryptionKey, pamConfig.ExpiryTime, pamConfig.ResourceType, maskingPatterns) if err != nil { return fmt.Errorf("failed to create session logger: %w", err) } @@ -271,15 +302,22 @@ func HandlePAMProxy(ctx context.Context, conn *tls.Conn, pamConfig *GatewayPAMCo Msg("Starting Redis PAM proxy") return proxy.HandleConnection(ctx, conn) case session.ResourceTypeSSH: + // Compile command blocking patterns from policy rules + var blockedCommandPatterns []*regexp.Regexp + if credentials.PolicyRules != nil { + blockedCommandPatterns = compilePolicyPatterns(credentials.PolicyRules.CommandBlocking, pamConfig.SessionId, "command-blocking") + } + sshConfig := ssh.SSHProxyConfig{ - TargetAddr: fmt.Sprintf("%s:%d", credentials.Host, credentials.Port), - AuthMethod: credentials.AuthMethod, - InjectUsername: credentials.Username, - InjectPassword: credentials.Password, - InjectPrivateKey: credentials.PrivateKey, - InjectCertificate: credentials.Certificate, - SessionID: pamConfig.SessionId, - SessionLogger: sessionLogger, + TargetAddr: fmt.Sprintf("%s:%d", credentials.Host, credentials.Port), + AuthMethod: credentials.AuthMethod, + InjectUsername: credentials.Username, + InjectPassword: credentials.Password, + InjectPrivateKey: credentials.PrivateKey, + InjectCertificate: credentials.Certificate, + SessionID: pamConfig.SessionId, + SessionLogger: sessionLogger, + BlockedCommandPatterns: blockedCommandPatterns, } proxy := ssh.NewSSHProxy(sshConfig) log.Info(). diff --git a/packages/pam/session/credentials.go b/packages/pam/session/credentials.go index 1e2901b1..13ee07a5 100644 --- a/packages/pam/session/credentials.go +++ b/packages/pam/session/credentials.go @@ -25,6 +25,7 @@ type PAMCredentials struct { SSLCertificate string Url string ServiceAccountToken string + PolicyRules *api.PAMPolicyRules } type cachedCredentials struct { @@ -100,6 +101,7 @@ func (cm *CredentialsManager) GetPAMSessionCredentials(sessionId string, expiryT SSLCertificate: response.Credentials.SSLCertificate, Url: response.Credentials.Url, ServiceAccountToken: response.Credentials.ServiceAccountToken, + PolicyRules: response.PolicyRules, } cm.cacheMutex.Lock() diff --git a/packages/pam/session/logger.go b/packages/pam/session/logger.go index 38471aee..77c3c3e3 100644 --- a/packages/pam/session/logger.go +++ b/packages/pam/session/logger.go @@ -7,6 +7,7 @@ import ( "net/http" "os" "path/filepath" + "regexp" "sync" "time" @@ -78,12 +79,13 @@ type SessionLogger interface { } type EncryptedSessionLogger struct { - sessionID string - encryptionKey string - expiresAt time.Time - file *os.File - mutex sync.Mutex - sessionStart time.Time // Track session start time for elapsed time calculation + sessionID string + encryptionKey string + expiresAt time.Time + file *os.File + mutex sync.Mutex + sessionStart time.Time // Track session start time for elapsed time calculation + maskingPatterns []*regexp.Regexp // Patterns for masking sensitive data in session logs } type RequestResponsePair struct { @@ -183,7 +185,7 @@ func CleanupSessionMutex(sessionID string) { } } -func NewSessionLogger(sessionID string, encryptionKey string, expiresAt time.Time, resourceType string) (*EncryptedSessionLogger, error) { +func NewSessionLogger(sessionID string, encryptionKey string, expiresAt time.Time, resourceType string, maskingPatterns []*regexp.Regexp) (*EncryptedSessionLogger, error) { if sessionID == "" { return nil, fmt.Errorf("session ID cannot be empty") } @@ -214,11 +216,12 @@ func NewSessionLogger(sessionID string, encryptionKey string, expiresAt time.Tim } return &EncryptedSessionLogger{ - sessionID: sessionID, - encryptionKey: encryptionKey, - expiresAt: expiresAt, - file: file, - sessionStart: time.Now(), + sessionID: sessionID, + encryptionKey: encryptionKey, + expiresAt: expiresAt, + file: file, + sessionStart: time.Now(), + maskingPatterns: maskingPatterns, }, nil } @@ -265,24 +268,51 @@ func (sl *EncryptedSessionLogger) writeEvent(productEventData func() ([]byte, er return nil } +// applyMasking replaces regex matches in byte data with [MASKED] +func (sl *EncryptedSessionLogger) applyMasking(data []byte) []byte { + if len(sl.maskingPatterns) == 0 || len(data) == 0 { + return data + } + result := data + for _, pattern := range sl.maskingPatterns { + result = pattern.ReplaceAll(result, []byte("[MASKED]")) + } + return result +} + +// applyMaskingString replaces regex matches in string data with [MASKED] +func (sl *EncryptedSessionLogger) applyMaskingString(s string) string { + if len(sl.maskingPatterns) == 0 || s == "" { + return s + } + result := s + for _, pattern := range sl.maskingPatterns { + result = pattern.ReplaceAllString(result, "[MASKED]") + } + return result +} + func (sl *EncryptedSessionLogger) LogEntry(entry SessionLogEntry) error { return sl.writeEvent(func() ([]byte, error) { + entry.Input = sl.applyMaskingString(entry.Input) + entry.Output = sl.applyMaskingString(entry.Output) return json.Marshal(entry) }) } func (sl *EncryptedSessionLogger) LogTerminalEvent(event TerminalEvent) error { return sl.writeEvent(func() ([]byte, error) { - // Calculate elapsed time if not already set if event.ElapsedTime == 0 { event.ElapsedTime = time.Since(sl.sessionStart).Seconds() } + event.Data = sl.applyMasking(event.Data) return json.Marshal(event) }) } func (sl *EncryptedSessionLogger) LogHttpEvent(event HttpEvent) error { return sl.writeEvent(func() ([]byte, error) { + event.Body = sl.applyMasking(event.Body) return json.Marshal(event) }) } diff --git a/packages/pam/session/logger_masking_test.go b/packages/pam/session/logger_masking_test.go new file mode 100644 index 00000000..bfc56ee0 --- /dev/null +++ b/packages/pam/session/logger_masking_test.go @@ -0,0 +1,82 @@ +package session + +import ( + "regexp" + "testing" +) + +func TestApplyMasking(t *testing.T) { + logger := &EncryptedSessionLogger{ + maskingPatterns: []*regexp.Regexp{ + regexp.MustCompile(`password\s*=\s*\S+`), + regexp.MustCompile(`secret_key`), + }, + } + + tests := []struct { + name string + input string + expected string + }{ + { + name: "masks password pattern", + input: "SET password = hunter2", + expected: "SET [MASKED]", + }, + { + name: "masks secret_key", + input: "export secret_key=abc123", + expected: "export [MASKED]=abc123", + }, + { + name: "masks multiple occurrences", + input: "password=foo and password=bar", + expected: "[MASKED] and [MASKED]", + }, + { + name: "no match leaves input unchanged", + input: "SELECT * FROM users", + expected: "SELECT * FROM users", + }, + { + name: "empty input", + input: "", + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := logger.applyMaskingString(tt.input) + if result != tt.expected { + t.Errorf("applyMaskingString(%q) = %q, want %q", tt.input, result, tt.expected) + } + }) + } + + // Test byte variant + t.Run("applyMasking bytes", func(t *testing.T) { + input := []byte("password = secret123") + result := logger.applyMasking(input) + expected := "[MASKED]" + if string(result) != expected { + t.Errorf("applyMasking(%q) = %q, want %q", input, result, expected) + } + }) +} + +func TestApplyMaskingNoPatterns(t *testing.T) { + logger := &EncryptedSessionLogger{} + + input := "password=secret" + result := logger.applyMaskingString(input) + if result != input { + t.Errorf("with no patterns, expected input unchanged, got %q", result) + } + + byteInput := []byte("password=secret") + byteResult := logger.applyMasking(byteInput) + if string(byteResult) != string(byteInput) { + t.Errorf("with no patterns, expected bytes unchanged") + } +}