Skip to content

Commit

Permalink
Type safe deinstrumenting
Browse files Browse the repository at this point in the history
  • Loading branch information
DimitarPetrov committed Sep 6, 2020
1 parent 94bcdf2 commit b60eec9
Show file tree
Hide file tree
Showing 5 changed files with 282 additions and 78 deletions.
141 changes: 137 additions & 4 deletions tracing/deinstrument.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"go/token"
"io"
"os"
"reflect"
)

type codeDeinstrumenter struct {
Expand Down Expand Up @@ -62,10 +63,13 @@ func (cd *codeDeinstrumenter) DeinstrumentFile(fset *token.FileSet, file *ast.Fi
if len(t.Body.List) >= instrumentationStmtsCount {
firstStmntDecorations := t.Body.List[0].Decorations().Start.All()
secondStmntDecorations := t.Body.List[instrumentationStmtsCount-1].Decorations().End.All()
if len(firstStmntDecorations) > 0 && firstStmntDecorations[0] == "/* prinTracer */" &&
len(secondStmntDecorations) > 0 && secondStmntDecorations[0] == "/* prinTracer */" {
t.Body.List = t.Body.List[instrumentationStmtsCount:]
t.Body.List[0].Decorations().Before = dst.None
if len(firstStmntDecorations) > 0 && firstStmntDecorations[0] == printracerCommentWatermark &&
len(secondStmntDecorations) > 0 && secondStmntDecorations[0] == printracerCommentWatermark {

if checkInstrumentationStatementsIntegrity(t) {
t.Body.List = t.Body.List[instrumentationStmtsCount:]
t.Body.List[0].Decorations().Before = dst.None
}
}
}
}
Expand All @@ -74,3 +78,132 @@ func (cd *codeDeinstrumenter) DeinstrumentFile(fset *token.FileSet, file *ast.Fi

return decorator.Fprint(out, f)
}

func checkInstrumentationStatementsIntegrity(f *dst.FuncDecl) bool {
stmts := f.Body.List
instrumentationStmts := buildInstrumentationStmts(f)

for i := 0; i < instrumentationStmtsCount; i++ {
if !equalStmt(stmts[i], instrumentationStmts[i]) {
return false
}
}
return true
}

func equalStmt(stmt1, stmt2 dst.Stmt) bool {
switch t := stmt1.(type) {
case *dst.AssignStmt:
instStmt, ok := stmt2.(*dst.AssignStmt)
if !ok {
return false
}
if !(equalExprSlice(t.Lhs, instStmt.Lhs) && equalExprSlice(t.Rhs, instStmt.Rhs) && reflect.DeepEqual(t.Tok, instStmt.Tok)) {
return false
}
return true
case *dst.IfStmt:
instStmt, ok := stmt2.(*dst.IfStmt)
if !ok {
return false
}
if !(equalStmt(t.Init, instStmt.Init) && equalExpr(t.Cond, instStmt.Cond) && equalStmt(t.Body, instStmt.Body) && equalStmt(t.Else, instStmt.Else)) {
return false
}
return true
case *dst.ExprStmt:
instStmt, ok := stmt2.(*dst.ExprStmt)
if !ok {
return false
}
if !(equalExpr(t.X, instStmt.X)) {
return false
}
return true
case *dst.DeferStmt:
instStmt, ok := stmt2.(*dst.DeferStmt)
if !ok {
return false
}
if !(equalExpr(t.Call, instStmt.Call)) {
return false
}
return true
case *dst.BlockStmt:
instStmt, ok := stmt2.(*dst.BlockStmt)
if !ok {
return false
}
if len(t.List) != len(instStmt.List) || t.RbraceHasNoPos != instStmt.RbraceHasNoPos {
return false
}
for i, stmt1 := range t.List {
if !equalStmt(stmt1, instStmt.List[i]) {
return false
}
}
return true
}
return reflect.DeepEqual(stmt1, stmt2)
}

func equalExprSlice(exprSlice1, exprSlice2 []dst.Expr) bool {
if len(exprSlice1) != len(exprSlice2) {
return false
}
for i, expr1 := range exprSlice1 {
if !equalExpr(expr1, exprSlice2[i]) {
return false
}
}
return true
}

func equalExpr(expr1, expr2 dst.Expr) bool {
switch t := expr1.(type) {
case *dst.Ident:
instExpr, ok := expr2.(*dst.Ident)
if !ok {
instExpr, ok := expr2.(*dst.BasicLit)
if !ok {
return false
}
return t.Name == instExpr.Value
}
return t.Name == instExpr.Name && t.Path == instExpr.Path
case *dst.CallExpr:
instExpr, ok := expr2.(*dst.CallExpr)
if !ok {
return false
}
if !(equalExprSlice(t.Args, instExpr.Args) && equalExpr(t.Fun, instExpr.Fun)) {
return false
}
return true
case *dst.SelectorExpr:
instExpr, ok := expr2.(*dst.SelectorExpr)
if !ok {
return false
}
if !(equalExpr(t.X, instExpr.X) && equalExpr(t.Sel, instExpr.Sel)) {
return false
}
return true
case *dst.SliceExpr:
instExpr, ok := expr2.(*dst.SliceExpr)
if !ok {
return false
}
if !(t.Slice3 == instExpr.Slice3 && equalExpr(t.X, instExpr.X) && equalExpr(t.High, instExpr.High) && equalExpr(t.Low, instExpr.Low) && equalExpr(t.Max, instExpr.Max)) {
return false
}
return true
case *dst.BasicLit:
instExpr, ok := expr2.(*dst.BasicLit)
if !ok {
return false
}
return t.Value == instExpr.Value && t.Kind == instExpr.Kind
}
return reflect.DeepEqual(expr1, expr2)
}
3 changes: 2 additions & 1 deletion tracing/deinstrument_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ func TestDeinstrumentFile(t *testing.T) {
{Name: "DeinstrumentFileWithoutFmtImport", InputCode: resultCodeWithImportsWithoutFmt, OutputCode: codeWithImportsWithoutFmt},
{Name: "DeinstrumentFileWithoutFunctions", InputCode: resultCodeWithoutFunction, OutputCode: codeWithoutFunction},
{Name: "DeinstrumentFileWithoutPreviousInstrumentation", InputCode: codeWithMultipleImports, OutputCode: codeWithMultipleImports},
{Name: "DeinstrumentFileDoesNotChangeManuallyEditedFunctions", InputCode: editedResultCodeWithoutImports, OutputCode: editedResultCodeWithoutImports},
}

for _, test := range tests {
Expand All @@ -47,7 +48,7 @@ func TestDeinstrumentFile(t *testing.T) {
}

if buff2.String() != test.OutputCode {
t.Error("Assertion failed!")
t.Errorf("Assertion failed! Expected %s god %s", test.OutputCode, buff2.String())
}
})
}
Expand Down
102 changes: 30 additions & 72 deletions tracing/instrument.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,35 @@ import (
)

const funcNameVarName = "funcName"
const funcPCVarName = "funcPC"

const callerFuncNameVarName = "caller"
const defaultCallerName = "unknown"
const callerFuncPCVarName = "callerPC"

const callIDVarName = "callID"
const instrumentationStmtsCount = 9

const printracerCommentWatermark = "/* prinTracer */"

const instrumentationStmtsCount = 9 // Acts like a contract of how many statements instrumentation adds and deinstrumentation removes.

func buildInstrumentationStmts(f *dst.FuncDecl) [instrumentationStmtsCount]dst.Stmt {
return [instrumentationStmtsCount]dst.Stmt{
newAssignStmt(funcNameVarName, f.Name.Name),
newAssignStmt(callerFuncNameVarName, defaultCallerName),
newGetFuncNameIfStatement("0", funcPCVarName, funcNameVarName),
newGetFuncNameIfStatement("1", callerFuncPCVarName, callerFuncNameVarName),
newMakeByteSliceStmt(),
newRandReadStmt(),
newParseUUIDFromByteSliceStmt(callIDVarName),
&dst.ExprStmt{
X: newPrintExprWithArgs(buildEnteringFunctionArgs(f)),
},
&dst.DeferStmt{
Call: newPrintExprWithArgs(buildExitFunctionArgs()),
},
}
}

type codeInstrumenter struct {
}
Expand Down Expand Up @@ -70,81 +96,13 @@ func (ci *codeInstrumenter) InstrumentFile(fset *token.FileSet, file *ast.File,
dst.Inspect(f, func(n dst.Node) bool {
switch t := n.(type) {
case *dst.FuncDecl:
var enteringStringFormat = "Function %s called by %s"
var exitingStringFormat = "Exiting function %s called by %s; callID=%s"

args := []dst.Expr{
&dst.BasicLit{
Kind: token.STRING,
Value: funcNameVarName,
},
&dst.BasicLit{
Kind: token.STRING,
Value: callerFuncNameVarName,
},
}

if len(t.Type.Params.List) > 0 {
enteringStringFormat += " with args"

for _, param := range t.Type.Params.List {
enteringStringFormat += " (%v)"
args = append(args, &dst.BasicLit{
Kind: token.STRING,
Value: param.Names[0].Name,
})
}
}
args = append(args, &dst.BasicLit{
Kind: token.STRING,
Value: callIDVarName,
})
args = append([]dst.Expr{
&dst.BasicLit{
Kind: token.STRING,
Value: `"` + enteringStringFormat + `; callID=%s\n"`,
},
}, args...)

instrumentationStmts := [instrumentationStmtsCount]dst.Stmt{
newAssignStmt(funcNameVarName, t.Name.Name),
newAssignStmt(callerFuncNameVarName, "unknown"),
newGetFuncNameIfStatement("0", "funcPC", funcNameVarName),
newGetFuncNameIfStatement("1", "callerPC", callerFuncNameVarName),
newMakeByteSliceStmt(),
newRandReadStmt(),
newParseUUIDFromByteSliceStmt(callIDVarName),
&dst.ExprStmt{
X: newPrintExprWithArgs(args),
},
&dst.DeferStmt{
Call: newPrintExprWithArgs([]dst.Expr{
&dst.BasicLit{
Kind: token.STRING,
Value: `"` + exitingStringFormat + `\n"`,
},
&dst.BasicLit{
Kind: token.STRING,
Value: funcNameVarName,
},
&dst.BasicLit{
Kind: token.STRING,
Value: callerFuncNameVarName,
},
&dst.BasicLit{
Kind: token.STRING,
Value: callIDVarName,
},
}),
},
}

instrumentationStmts := buildInstrumentationStmts(t)
t.Body.List = append(instrumentationStmts[:], t.Body.List...)

t.Body.List[0].Decorations().Before = dst.EmptyLine
t.Body.List[0].Decorations().Start.Append("/* prinTracer */")
t.Body.List[0].Decorations().Start.Append(printracerCommentWatermark)
t.Body.List[instrumentationStmtsCount-1].Decorations().After = dst.EmptyLine
t.Body.List[instrumentationStmtsCount-1].Decorations().End.Append("/* prinTracer */")
t.Body.List[instrumentationStmtsCount-1].Decorations().End.Append(printracerCommentWatermark)
}
return true
})
Expand Down
54 changes: 53 additions & 1 deletion tracing/instrument_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,58 @@ func main() {
}
`

const editedResultCodeWithoutImports = `package a
import (
"fmt"
"rand"
"runtime"
)
func test(i int, b bool) int {
/* prinTracer */
funcName := "test2"
caller := "unknown2"
if funcPC, _, _, ok := runtime.Caller(0); ok {
funcName = runtime.FuncForPC(funcPC).Name()
}
if callerPC, _, _, ok := runtime.Caller(1); ok {
caller = runtime.FuncForPC(callerPC).Name()
}
fmt.Println("test")
idBytes := make([]byte, 16)
_, _ = rand.Read(idBytes)
callID := fmt.Sprintf("%x-%x-%x-%x-%x", idBytes[0:4], idBytes[4:6], idBytes[6:8], idBytes[8:10], idBytes[10:])
fmt.Printf("Function %s called by %s with args (%v) (%v); callID=%s\n", funcName, caller, i, b, callID)
defer fmt.Printf("Exiting function %s called by %s; callID=%s\n", funcName, caller, callID) /* prinTracer */
if b {
return i
}
return 0
}
func main() {
funcName := "main"
caller := "unknown"
if funcPC, _, _, ok := runtime.Caller(0); ok {
funcName = runtime.FuncForPC(funcPC).Name()
}
if callerPC, _, _, ok := runtime.Caller(1); ok {
caller = runtime.FuncForPC(callerPC).Name()
}
idBytes := make([]byte, 16)
_, _ = rand.Read(idBytes)
callID := fmt.Sprintf("%x-%x-%x-%x-%x", idBytes[0:4], idBytes[4:6], idBytes[6:8], idBytes[8:10], idBytes[10:])
fmt.Printf("Function %s called by %s; callID=%s\n", funcName, caller, callID)
defer fmt.Printf("Exiting function %s called by %s; callID=%s\n", funcName, caller, callID)
i := test(2, false)
}
`

const codeWithFmtImport = `package a
import (
Expand Down Expand Up @@ -340,7 +392,7 @@ func TestInstrumentFile(t *testing.T) {
}

if buff.String() != test.OutputCode {
t.Error("Assertion failed!")
t.Errorf("Assertion failed! Expected %s got %s", test.OutputCode, buff.String())
}
})
}
Expand Down
Loading

0 comments on commit b60eec9

Please sign in to comment.