/
gorm.go
125 lines (106 loc) · 3.45 KB
/
gorm.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
package logger
import (
"context"
"errors"
"fmt"
"path/filepath"
"runtime"
"strings"
"time"
"go.uber.org/zap"
"gorm.io/gorm"
gormlogger "gorm.io/gorm/logger"
)
// GormLogger 操作对象,实现 gormlogger.Interface
type GormLogger struct {
ZapLogger *zap.Logger
SlowThreshold time.Duration
}
// NewGormLogger 外部调用。实例化一个 GormLogger 对象,示例:
//
// DB, err := gorm.Open(dbConfig, &gorm.Config{
// Logger: logger.NewGormLogger(),
// })
func NewGormLogger() GormLogger {
return GormLogger{
ZapLogger: zapLogger, // 使用全局的 logger.Logger 对象
SlowThreshold: 10 * time.Second, // 慢查询阈值,单位为秒
}
}
// LogMode 实现 gormlogger.Interface 的 LogMode 方法
func (l GormLogger) LogMode(level gormlogger.LogLevel) gormlogger.Interface {
return GormLogger{
ZapLogger: l.ZapLogger,
SlowThreshold: l.SlowThreshold,
}
}
// Info 实现 gormlogger.Interface 的 Info 方法
func (l GormLogger) Info(ctx context.Context, str string, args ...interface{}) {
l.logger().Sugar().Debugf(str, args...)
}
// Warn 实现 gormlogger.Interface 的 Warn 方法
func (l GormLogger) Warn(ctx context.Context, str string, args ...interface{}) {
l.logger().Sugar().Warnf(str, args...)
}
// Error 实现 gormlogger.Interface 的 Error 方法
func (l GormLogger) Error(ctx context.Context, str string, args ...interface{}) {
l.logger().Sugar().Errorf(str, args...)
}
// Trace 实现 gormlogger.Interface 的 Trace 方法
func (l GormLogger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
// 获取运行时间
elapsed := time.Since(begin)
// 获取 SQL 请求和返回条数
sql, rows := fc()
// 通用字段
logFields := []zap.Field{
zap.String("sql", sql),
zap.String("time", MicrosecondsStr(elapsed)),
zap.Int64("rows", rows),
}
// Gorm 错误
if err != nil {
// 记录未找到的错误使用 warning 等级
if errors.Is(err, gorm.ErrRecordNotFound) {
l.logger().Warn("Database ErrRecordNotFound", logFields...)
} else {
// 其他错误使用 error 等级
logFields = append(logFields, zap.Error(err))
l.logger().Error("Database Error", logFields...)
}
}
// 慢查询日志
if l.SlowThreshold != 0 && elapsed > l.SlowThreshold {
l.logger().Warn("Database Slow Log", logFields...)
}
// 记录所有 SQL 请求
l.logger().Debug("Database Query", logFields...)
}
// logger 内用的辅助方法,确保 Zap 内置信息 Caller 的准确性(如 paginator/paginator.go:148)
func (l GormLogger) logger() *zap.Logger {
// 跳过 gorm 内置的调用
var (
gormPackage = filepath.Join("gorm.io", "gorm")
zapgormPackage = filepath.Join("moul.io", "zapgorm2")
)
// 减去一次封装,以及一次在 logger 初始化里添加 zap.AddCallerSkip(1)
clone := l.ZapLogger.WithOptions(zap.AddCallerSkip(-2))
for i := 2; i < 15; i++ {
_, file, _, ok := runtime.Caller(i)
switch {
case !ok:
case strings.HasSuffix(file, "_test.go"):
case strings.Contains(file, gormPackage):
case strings.Contains(file, zapgormPackage):
default:
// 返回一个附带跳过行号的新的 zap logger
return clone.WithOptions(zap.AddCallerSkip(i))
}
}
return l.ZapLogger
}
// MicrosecondsStr 将 time.Duration 类型(nano seconds 为单位)
// 输出为小数点后 3 位的 ms (microsecond 毫秒,千分之一秒)
func MicrosecondsStr(elapsed time.Duration) string {
return fmt.Sprintf("%.3fms", float64(elapsed.Nanoseconds())/1e6)
}