diff --git a/lib/log/api.go b/lib/log/api.go index 9b9a827a..68e7588b 100644 --- a/lib/log/api.go +++ b/lib/log/api.go @@ -26,3 +26,9 @@ type DebugLogLevelGetter interface { type DebugLogLevelSetter interface { SetLevel(maxLevel int16) } + +type FullDebugLogger interface { + DebugLogger + DebugLogLevelGetter + DebugLogLevelSetter +} diff --git a/lib/log/serverlogger/api.go b/lib/log/serverlogger/api.go index 615bc022..208d8e8d 100644 --- a/lib/log/serverlogger/api.go +++ b/lib/log/serverlogger/api.go @@ -5,10 +5,10 @@ import ( "fmt" "io" "log" - "os" + "sync" - "github.com/Symantec/Dominator/lib/log/debuglogger" "github.com/Symantec/Dominator/lib/logbuf" + "github.com/Symantec/Dominator/lib/srpc" ) var ( @@ -19,8 +19,18 @@ var ( ) type Logger struct { - *debuglogger.Logger + accessChecker func(authInfo *srpc.AuthInformation) bool circularBuffer *logbuf.LogBuffer + flags int + level int16 + maxLevel int16 + mutex sync.Mutex // Lock everything below. + streamers map[*streamerType]struct{} +} + +type streamerType struct { + debugLevel int16 + output chan<- []byte } // New will create a Logger which has an internal log buffer (see the @@ -28,7 +38,7 @@ type Logger struct { // By default, the max debug level is -1, meaning all debug logs are dropped // (ignored). // The name of the new logger is given by name. This name is used to remotely -// identify the logger for RPC methods such as Logger.SetDebugLevel. The first +// identify the logger for SRPC methods such as Logger.SetDebugLevel. The first // or primary logger should be created with name "" (the empty string). func New(name string) *Logger { flags := log.LstdFlags @@ -60,19 +70,42 @@ func NewWithOptions(name string, options logbuf.Options, flags int) *Logger { return newLogger(name, options, flags) } +// Debug will call the Print method if level is less than or equal to the max +// debug level for the Logger. +func (l *Logger) Debug(level uint8, v ...interface{}) { + l.debug(int16(level), v...) +} + +// Debugf will call the Printf method if level is less than or equal to the max +// debug level for the Logger. +func (l *Logger) Debugf(level uint8, format string, v ...interface{}) { + l.debugf(int16(level), format, v...) +} + +// Debugln will call the Println method if level is less than or equal to the +// max debug level for the Logger. +func (l *Logger) Debugln(level uint8, v ...interface{}) { + l.debugln(int16(level), v...) +} + +// GetLevel gets the current maximum debug level. +func (l *Logger) GetLevel() int16 { + return l.level +} + +// Fatal is equivalent to Print() followed by a call to os.Exit(1). func (l *Logger) Fatal(v ...interface{}) { - msg := fmt.Sprint(v...) - l.Print(msg) - l.circularBuffer.Flush() - os.Exit(1) + l.fatals(fmt.Sprint(v...)) } +// Fatalf is equivalent to Printf() followed by a call to os.Exit(1). func (l *Logger) Fatalf(format string, v ...interface{}) { - l.Fatal(fmt.Sprintf(format, v...)) + l.fatals(fmt.Sprintf(format, v...)) } +// Fatalln is equivalent to Println() followed by a call to os.Exit(1). func (l *Logger) Fatalln(v ...interface{}) { - l.Fatal(fmt.Sprintln(v...)) + l.fatals(fmt.Sprintln(v...)) } // Flush flushes the open log file (if one is open). This should only be called @@ -82,6 +115,52 @@ func (l *Logger) Flush() error { return l.circularBuffer.Flush() } +// Panic is equivalent to Print() followed by a call to panic(). +func (l *Logger) Panic(v ...interface{}) { + l.panics(fmt.Sprint(v...)) +} + +// Panicf is equivalent to Printf() followed by a call to panic(). +func (l *Logger) Panicf(format string, v ...interface{}) { + l.panics(fmt.Sprintf(format, v...)) +} + +// Panicln is equivalent to Println() followed by a call to panic(). +func (l *Logger) Panicln(v ...interface{}) { + l.panics(fmt.Sprintln(v...)) +} + +// Print prints to the logger. Arguments are handled in the manner of fmt.Print. +func (l *Logger) Print(v ...interface{}) { + l.prints(fmt.Sprint(v...)) +} + +// Printf prints to the logger. Arguments are handled in the manner of +// fmt.Printf. +func (l *Logger) Printf(format string, v ...interface{}) { + l.prints(fmt.Sprintf(format, v...)) +} + +// Println prints to the logger. Arguments are handled in the manner of +// fmt.Println. +func (l *Logger) Println(v ...interface{}) { + l.prints(fmt.Sprintln(v...)) +} + +// SetAccessChecker sets the function that is called when SRPC methods are +// called for the Logger. This allows the application to control which users or +// groups are permitted to remotely control the Logger. +func (l *Logger) SetAccessChecker( + accessChecker func(authInfo *srpc.AuthInformation) bool) { + l.accessChecker = accessChecker +} + +// SetLevel sets the maximum debug level. A negative level will cause all debug +// messages to be dropped. +func (l *Logger) SetLevel(maxLevel int16) { + l.setLevel(maxLevel) +} + // WriteHtml will write the contents of the internal log buffer to writer, with // appropriate HTML markups. func (l *Logger) WriteHtml(writer io.Writer) { diff --git a/lib/log/serverlogger/impl.go b/lib/log/serverlogger/impl.go index 51e89cc1..3f79dca9 100644 --- a/lib/log/serverlogger/impl.go +++ b/lib/log/serverlogger/impl.go @@ -2,27 +2,50 @@ package serverlogger import ( "errors" + "fmt" "log" + "os" "strings" "sync" + "time" - "github.com/Symantec/Dominator/lib/log/debuglogger" + liblog "github.com/Symantec/Dominator/lib/log" "github.com/Symantec/Dominator/lib/logbuf" "github.com/Symantec/Dominator/lib/srpc" + "github.com/Symantec/Dominator/lib/srpc/serverutil" proto "github.com/Symantec/Dominator/proto/logger" ) type loggerMapT struct { + *serverutil.PerUserMethodLimiter sync.Mutex loggerMap map[string]*Logger } -var loggerMap *loggerMapT = &loggerMapT{loggerMap: make(map[string]*Logger)} +type grabWriter struct { + data []byte +} + +var loggerMap *loggerMapT = &loggerMapT{ + loggerMap: make(map[string]*Logger), + PerUserMethodLimiter: serverutil.NewPerUserMethodLimiter( + map[string]uint{ + "Debug": 1, + "Print": 1, + "SetDebugLevel": 1, + "Watch": 1, + }), +} func init() { srpc.RegisterName("Logger", loggerMap) } +func (w *grabWriter) Write(p []byte) (int, error) { + w.data = p + return len(p), nil +} + func newLogger(name string, options logbuf.Options, flags int) *Logger { loggerMap.Lock() defer loggerMap.Unlock() @@ -30,25 +53,205 @@ func newLogger(name string, options logbuf.Options, flags int) *Logger { panic("logger already exists: " + name) } circularBuffer := logbuf.NewWithOptions(options) - debugLogger := debuglogger.New(log.New(circularBuffer, "", flags)) - if *initialLogDebugLevel >= 0 { - debugLogger.SetLevel(int16(*initialLogDebugLevel)) - } logger := &Logger{ - Logger: debugLogger, circularBuffer: circularBuffer, + flags: flags, + level: int16(*initialLogDebugLevel), + streamers: make(map[*streamerType]struct{}), } + if logger.level < -1 { + logger.level = -1 + } + logger.maxLevel = logger.level + // Ensure this satisfies the published interface. + var debugLogger liblog.FullDebugLogger + debugLogger = logger + _ = debugLogger loggerMap.loggerMap[name] = logger return logger } -func (t *loggerMapT) Debug(conn *srpc.Conn, - request proto.DebugRequest, - reply *proto.DebugResponse) error { +func (l *Logger) checkAuth(authInfo *srpc.AuthInformation) error { + if authInfo.HaveMethodAccess { + return nil + } + if accessChecker := l.accessChecker; accessChecker == nil { + return errors.New("no access to resource") + } else if accessChecker(authInfo) { + return nil + } else { + return errors.New("no access to resource") + } +} + +func (l *Logger) debug(level int16, v ...interface{}) { + if l.maxLevel >= level { + l.log(level, fmt.Sprint(v...), false) + } +} + +func (l *Logger) debugf(level int16, format string, v ...interface{}) { + if l.maxLevel >= level { + l.log(level, fmt.Sprintf(format, v...), false) + } +} + +func (l *Logger) debugln(level int16, v ...interface{}) { + if l.maxLevel >= level { + l.log(level, fmt.Sprintln(v...), false) + } +} + +func (l *Logger) fatals(msg string) { + l.log(-1, msg, true) + os.Exit(1) +} + +func (l *Logger) log(level int16, msg string, dying bool) { + buffer := &grabWriter{} + rawLogger := log.New(buffer, "", l.flags) + rawLogger.Output(4, msg) + if l.level >= level { + l.circularBuffer.Write(buffer.data) + } + recalculateLevels := false + l.mutex.Lock() + defer l.mutex.Unlock() + for streamer := range l.streamers { + if streamer.debugLevel >= level { + select { + case streamer.output <- buffer.data: + default: + delete(l.streamers, streamer) + close(streamer.output) + recalculateLevels = true + } + } + } + if dying { + for streamer := range l.streamers { + delete(l.streamers, streamer) + close(streamer.output) + } + l.circularBuffer.Flush() + time.Sleep(time.Millisecond * 10) + } else if recalculateLevels { + l.updateMaxLevel() + } +} + +func (l *Logger) panics(msg string) { + l.log(-1, msg, true) + panic(msg) +} + +func (l *Logger) prints(msg string) { + l.log(-1, msg, false) +} + +func (l *Logger) setLevel(maxLevel int16) { + if maxLevel < -1 { + maxLevel = -1 + } + l.level = maxLevel + l.mutex.Lock() + l.updateMaxLevel() + l.mutex.Unlock() +} + +func (l *Logger) updateMaxLevel() { + maxLevel := l.level + for streamer := range l.streamers { + if streamer.debugLevel > maxLevel { + maxLevel = streamer.debugLevel + } + } + l.maxLevel = maxLevel +} + +func (l *Logger) watch(conn *srpc.Conn, request proto.WatchRequest) { + channel := make(chan []byte, 256) + if request.DebugLevel < -1 { + request.DebugLevel = -1 + } + streamer := &streamerType{ + debugLevel: request.DebugLevel, + output: channel, + } + l.mutex.Lock() + l.streamers[streamer] = struct{}{} + l.updateMaxLevel() + l.mutex.Unlock() + keepGoing := true + if request.DumpBuffer { + if err := l.circularBuffer.Dump(conn, "", "", false); err != nil { + keepGoing = false + } + } + timer := time.NewTimer(time.Millisecond * 100) + flushPending := true + closeNotifier := conn.GetCloseNotifier() + for keepGoing { + select { + case <-closeNotifier: + keepGoing = false + break + case data, ok := <-channel: + if !ok { + keepGoing = false + break + } + if _, err := conn.Write(data); err != nil { + keepGoing = false + break + } + if !flushPending { + timer.Reset(time.Millisecond * 100) + flushPending = true + } + case <-timer.C: + if conn.Flush() != nil { + keepGoing = false + break + } + flushPending = false + } + } + if flushPending { + conn.Flush() + } + l.mutex.Lock() + delete(l.streamers, streamer) + l.updateMaxLevel() + l.mutex.Unlock() + // Drain the channel. + for { + select { + case <-channel: + default: + return + } + } +} + +func (t *loggerMapT) getLogger(name string, + authInfo *srpc.AuthInformation) (*Logger, error) { loggerMap.Lock() defer loggerMap.Unlock() - if logger, ok := loggerMap.loggerMap[request.Name]; !ok { - return errors.New("unknown logger: " + request.Name) + if logger, ok := loggerMap.loggerMap[name]; !ok { + return nil, errors.New("unknown logger: " + name) + } else if err := logger.checkAuth(authInfo); err != nil { + return nil, err + } else { + return logger, nil + } +} + +func (t *loggerMapT) Debug(conn *srpc.Conn, + request proto.DebugRequest, reply *proto.DebugResponse) error { + authInfo := conn.GetAuthInformation() + if logger, err := t.getLogger(request.Name, authInfo); err != nil { + return err } else { logger.Debugf(request.Level, "Logger.Debug(%d): %s\n", request.Level, strings.Join(request.Args, " ")) @@ -59,10 +262,9 @@ func (t *loggerMapT) Debug(conn *srpc.Conn, func (t *loggerMapT) Print(conn *srpc.Conn, request proto.PrintRequest, reply *proto.PrintResponse) error { - loggerMap.Lock() - defer loggerMap.Unlock() - if logger, ok := loggerMap.loggerMap[request.Name]; !ok { - return errors.New("unknown logger: " + request.Name) + authInfo := conn.GetAuthInformation() + if logger, err := t.getLogger(request.Name, authInfo); err != nil { + return err } else { logger.Println("Logger.Print():", strings.Join(request.Args, " ")) return nil @@ -72,13 +274,33 @@ func (t *loggerMapT) Print(conn *srpc.Conn, func (t *loggerMapT) SetDebugLevel(conn *srpc.Conn, request proto.SetDebugLevelRequest, reply *proto.SetDebugLevelResponse) error { - loggerMap.Lock() - defer loggerMap.Unlock() - if logger, ok := loggerMap.loggerMap[request.Name]; !ok { - return errors.New("unknown logger: " + request.Name) + authInfo := conn.GetAuthInformation() + if logger, err := t.getLogger(request.Name, authInfo); err != nil { + return err } else { logger.Printf("Logger.SetDebugLevel(%d)\n", request.Level) logger.SetLevel(request.Level) return nil } } + +func (t *loggerMapT) Watch(conn *srpc.Conn, decoder srpc.Decoder, + encoder srpc.Encoder) error { + var request proto.WatchRequest + if err := decoder.Decode(&request); err != nil { + return err + } + authInfo := conn.GetAuthInformation() + if logger, err := t.getLogger(request.Name, authInfo); err != nil { + return encoder.Encode(proto.WatchResponse{Error: err.Error()}) + } else { + if err := encoder.Encode(proto.WatchResponse{}); err != nil { + return err + } + if err := conn.Flush(); err != nil { + return err + } + logger.watch(conn, request) + return srpc.ErrorCloseClient + } +} diff --git a/proto/logger/messages.go b/proto/logger/messages.go index c8e56966..07a2e84e 100644 --- a/proto/logger/messages.go +++ b/proto/logger/messages.go @@ -21,3 +21,15 @@ type SetDebugLevelRequest struct { } type SetDebugLevelResponse struct{} + +type WatchRequest struct { + DebugLevel int16 + DumpBuffer bool + ExcludeRegex string // Empty: nothing excluded. Processed after includes. + IncludeRegex string // Empty: everything included. + Name string +} + +type WatchResponse struct { + Error string +} // Log data are streamed afterwards.