diff --git a/cmd/security-agent/subcommands/runtime/command.go b/cmd/security-agent/subcommands/runtime/command.go index ae5b752996f7c..afeaf89a312d7 100644 --- a/cmd/security-agent/subcommands/runtime/command.go +++ b/cmd/security-agent/subcommands/runtime/command.go @@ -9,6 +9,8 @@ package runtime import ( + "archive/zip" + "bytes" "context" "encoding/json" "errors" @@ -17,6 +19,7 @@ import ( "os" "path" "runtime" + "strings" "time" "github.com/spf13/cobra" @@ -186,6 +189,7 @@ type downloadPolicyCliParams struct { check bool outputPath string + source string } func downloadPolicyCommands(globalParams *command.GlobalParams) []*cobra.Command { @@ -210,6 +214,7 @@ func downloadPolicyCommands(globalParams *command.GlobalParams) []*cobra.Command downloadPolicyCmd.Flags().BoolVar(&downloadPolicyArgs.check, "check", false, "Check policies after downloading") downloadPolicyCmd.Flags().StringVar(&downloadPolicyArgs.outputPath, "output-path", "", "Output path for downloaded policies") + downloadPolicyCmd.Flags().StringVar(&downloadPolicyArgs.source, "source", "all", `Specify wether should download the custom, default or all policies. allowed: "all", "default", "custom"`) return []*cobra.Command{downloadPolicyCmd} } @@ -765,7 +770,36 @@ func downloadPolicy(log log.Component, config config.Component, _ secrets.Compon if err != nil { return err } + + // Unzip the downloaded file containing both default and custom policies resBytes := []byte(res) + reader, err := zip.NewReader(bytes.NewReader(resBytes), int64(len(resBytes))) + if err != nil { + return err + } + + var defaultPolicy []byte + var customPolicies []string + + for _, file := range reader.File { + if strings.HasSuffix(file.Name, ".policy") { + pf, err := file.Open() + if err != nil { + return err + } + policyData, err := io.ReadAll(pf) + pf.Close() + if err != nil { + return err + } + + if file.Name == "default.policy" { + defaultPolicy = policyData + } else { + customPolicies = append(customPolicies, string(policyData)) + } + } + } tempDir, err := os.MkdirTemp("", "policy_check") if err != nil { @@ -773,10 +807,14 @@ func downloadPolicy(log log.Component, config config.Component, _ secrets.Compon } defer os.RemoveAll(tempDir) - tempOutputPath := path.Join(tempDir, "check.policy") - if err := os.WriteFile(tempOutputPath, resBytes, 0644); err != nil { + if err := os.WriteFile(path.Join(tempDir, "default.policy"), defaultPolicy, 0644); err != nil { return err } + for i, customPolicy := range customPolicies { + if err := os.WriteFile(path.Join(tempDir, fmt.Sprintf("custom%d.policy", i+1)), []byte(customPolicy), 0644); err != nil { + return err + } + } if downloadPolicyArgs.check { if err := checkPolicies(log, config, &checkPoliciesCliParams{dir: tempDir}); err != nil { @@ -784,7 +822,36 @@ func downloadPolicy(log log.Component, config config.Component, _ secrets.Compon } } - _, err = outputWriter.Write(resBytes) + // Extract and merge rules from custom policies + var customRules string + for _, customPolicy := range customPolicies { + customPolicyLines := strings.Split(customPolicy, "\n") + rulesIndex := -1 + for i, line := range customPolicyLines { + if strings.TrimSpace(line) == "rules:" { + rulesIndex = i + break + } + } + if rulesIndex != -1 && rulesIndex+1 < len(customPolicyLines) { + customRules += "\n" + strings.Join(customPolicyLines[rulesIndex+1:], "\n") + } + } + + // Output depending on user's specification + var outputContent string + switch downloadPolicyArgs.source { + case "all": + outputContent = string(defaultPolicy) + customRules + case "default": + outputContent = string(defaultPolicy) + case "custom": + outputContent = string(customRules) + default: + return errors.New("invalid source specified") + } + + _, err = outputWriter.Write([]byte(outputContent)) if err != nil { return err }