Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions bot.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func start() {
}

var (
handlerFuncs = map[string]func(*handlers.Request, *webhook.Webhook) error{}
handlerFuncs = map[string]func(*handlers.Request) error{}
handlerMu sync.RWMutex
)

Expand Down Expand Up @@ -115,17 +115,17 @@ func Handler(c echo.Context) error {
return
}

// if err := command.LoadInfoAndConfig(hook.GetProjectID(), hook.GetID()); err != nil {
// logger.Error("can't load repo config", "provider", providerName, "command", command, "err", err)
// return
// }
if err := command.LoadInfoAndConfig(hook.GetProjectID(), hook.GetID()); err != nil {
logger.Error("can't load repo config", "provider", providerName, "command", command, "err", err)
return
}

if !command.ValidateSecret(hook.GetProjectID(), hook.GetSecret()) {
if !command.ValidateSecret(hook.GetSecret()) {
logger.Info("webhook secret is not valid", "projectId", hook.GetProjectID(), "provider", providerName)
return
}

if err := f(command, hook); err != nil {
if err := f(command); err != nil {
logger.Error("handlerFunc returns err", "provider", providerName, "event", hook.Event, "err", err)
return
}
Expand All @@ -135,11 +135,11 @@ func Handler(c echo.Context) error {
return nil
}

func handle(onEvent string, funcHandler func(*handlers.Request, *webhook.Webhook) error) {
func handle(onEvent string, funcHandler func(*handlers.Request) error) {
handlerMu.Lock()
defer handlerMu.Unlock()

handlerFuncs[onEvent] = func(command *handlers.Request, hook *webhook.Webhook) error {
return funcHandler(command, hook)
handlerFuncs[onEvent] = func(command *handlers.Request) error {
return funcHandler(command)
}
}
28 changes: 14 additions & 14 deletions commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,50 +16,50 @@ func init() {
handle(webhook.OnMerge, MergeEvent)
}

func UpdateBranchCmd(command *handlers.Request, hook *webhook.Webhook) error {
if err := command.UpdateFromMaster(hook.GetProjectID(), hook.GetID()); err != nil {
func UpdateBranchCmd(command *handlers.Request) error {
if err := command.UpdateFromMaster(); err != nil {
logger.Error("command.UpdateFromMaster failed", "error", err)
return command.LeaveComment(hook.GetProjectID(), hook.GetID(), "❌ i couldn't update branch from master")
return command.LeaveComment("❌ i couldn't update branch from master")
}

return nil
}

func MergeCmd(command *handlers.Request, hook *webhook.Webhook) error {
ok, text, err := command.Merge(hook.GetProjectID(), hook.GetID())
func MergeCmd(command *handlers.Request) error {
ok, text, err := command.Merge()
if err != nil {
return fmt.Errorf("command.Merge returns err: %w", err)
}

if !ok && len(text) > 0 {
return command.LeaveComment(hook.GetProjectID(), hook.GetID(), text)
return command.LeaveComment(text)
}
return err
}

func CheckCmd(command *handlers.Request, hook *webhook.Webhook) error {
ok, text, err := command.IsValid(hook.GetProjectID(), hook.GetID())
func CheckCmd(command *handlers.Request) error {
ok, text, err := command.IsValid()
if err != nil {
return fmt.Errorf("command.IsValid returns err: %w", err)
}

if !ok && len(text) > 0 {
return command.LeaveComment(hook.GetProjectID(), hook.GetID(), text)
return command.LeaveComment(text)
} else {
return command.LeaveComment(hook.GetProjectID(), hook.GetID(), "You can merge, LGTM :D")
return command.LeaveComment("You can merge, LGTM :D")
}
}

func NewMR(command *handlers.Request, hook *webhook.Webhook) error {
if err := command.Greetings(hook.GetProjectID(), hook.GetID()); err != nil {
func NewMR(command *handlers.Request) error {
if err := command.Greetings(); err != nil {
return fmt.Errorf("command.Greetings returns err: %w", err)
}

return nil
}

func MergeEvent(command *handlers.Request, hook *webhook.Webhook) error {
if err := command.DeleteStaleBranches(hook.GetProjectID(), hook.GetID()); err != nil {
func MergeEvent(command *handlers.Request) error {
if err := command.DeleteStaleBranches(); err != nil {
return fmt.Errorf("command.MergeEvent returns err: %w", err)
}

Expand Down
4 changes: 2 additions & 2 deletions commands_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
func TestHandle(t *testing.T) {
// Save original state
handlerMu.Lock()
originalHandlers := make(map[string]func(*handlers.Request, *webhook.Webhook) error)
originalHandlers := make(map[string]func(*handlers.Request) error)
for k, v := range handlerFuncs {
originalHandlers[k] = v
}
Expand All @@ -24,7 +24,7 @@ func TestHandle(t *testing.T) {
handlerMu.Unlock()
}()

testFunc := func(command *handlers.Request, hook *webhook.Webhook) error {
testFunc := func(command *handlers.Request) error {
return nil
}

Expand Down
6 changes: 5 additions & 1 deletion handlers/gitlab/gitlab.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,11 @@ func (g *GitlabProvider) GetFile(projectId int, path string) (string, error) {

func (g *GitlabProvider) GetMRInfo(projectId, mergeId int, configPath string) (*handlers.MrInfo, error) {
var err error
info := handlers.MrInfo{}
info := handlers.MrInfo{
ProjectId: projectId,
Id: mergeId,
}

info.IsValid, err = g.IsValid(projectId, mergeId)
if err != nil {
return nil, err
Expand Down
2 changes: 2 additions & 0 deletions handlers/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ func Register(name string, constructor func() RequestProvider) {
}

type MrInfo struct {
ProjectId int
Id int
Approvals map[string]struct{}
FailedPipelines int
FailedTests int
Expand Down
55 changes: 17 additions & 38 deletions handlers/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,7 @@ func (r *Request) LoadInfoAndConfig(projectId, id int) error {
return nil
}

func (r *Request) IsValid(projectId, id int) (bool, string, error) {
if err := r.LoadInfoAndConfig(projectId, id); err != nil {
return false, "", err
}

func (r *Request) IsValid() (bool, string, error) {
if !r.info.IsValid {
return false, ValidError.Error(), nil
}
Expand Down Expand Up @@ -95,15 +91,11 @@ func (r *Request) ParseConfig(content string) (*Config, error) {
return mrConfig, nil
}

func (r *Request) LeaveComment(projectId, id int, message string) error {
return r.provider.LeaveComment(projectId, id, message)
func (r *Request) LeaveComment(message string) error {
return r.provider.LeaveComment(r.info.ProjectId, r.info.Id, message)
}

func (r *Request) Greetings(projectId, id int) error {
if err := r.LoadInfoAndConfig(projectId, id); err != nil {
return err
}

func (r *Request) Greetings() error {
if !r.config.Greetings.Enabled {
return nil
}
Expand All @@ -118,41 +110,33 @@ func (r *Request) Greetings(projectId, id int) error {
return err
}

return r.LeaveComment(projectId, id, buf.String())
return r.LeaveComment(buf.String())
}

func (r *Request) DeleteStaleBranches(projectId, id int) error {
if err := r.LoadInfoAndConfig(projectId, id); err != nil {
return err
}

func (r *Request) DeleteStaleBranches() error {
if r.config.StaleBranchesDeletion.Enabled {
if err := r.cleanStaleMergeRequests(projectId); err != nil {
if err := r.cleanStaleMergeRequests(); err != nil {
return err
}

if err := r.cleanStaleBranches(projectId); err != nil {
if err := r.cleanStaleBranches(); err != nil {
return err
}
}

return nil
}

func (r *Request) Merge(projectId, id int) (bool, string, error) {
if err := r.LoadInfoAndConfig(projectId, id); err != nil {
return false, "", err
}

func (r *Request) Merge() (bool, string, error) {
if r.config.AutoMasterMerge {
err := r.provider.UpdateFromMaster(projectId, id)
err := r.provider.UpdateFromMaster(r.info.ProjectId, r.info.Id)
if err != nil {
return false, "", err
}
}

if ok, text, err := r.IsValid(projectId, id); ok {
if err := r.provider.Merge(projectId, id, fmt.Sprintf("%s\nMerged by MergeApproveBot", r.info.Title)); err != nil {
if ok, text, err := r.IsValid(); ok {
if err := r.provider.Merge(r.info.ProjectId, r.info.Id, fmt.Sprintf("%s\nMerged by MergeApproveBot", r.info.Title)); err != nil {
return false, "", err
}
return true, "", nil
Expand All @@ -161,24 +145,19 @@ func (r *Request) Merge(projectId, id int) (bool, string, error) {
}
}

func (r *Request) UpdateFromMaster(projectId, id int) error {
if err := r.LoadInfoAndConfig(projectId, id); err != nil {
return err
}

if err := r.provider.UpdateFromMaster(projectId, id); err != nil {
func (r *Request) UpdateFromMaster() error {
if err := r.provider.UpdateFromMaster(r.info.ProjectId, r.info.Id); err != nil {
return err
}
return nil
}

func (r Request) ValidateSecret(projectId int, secret string) bool {
func (r Request) ValidateSecret(secret string) bool {
const mergeBotSecret = "MERGE_BOT_SECRET"

secretVar, err := r.provider.GetVar(projectId, mergeBotSecret)
secretVar, err := r.provider.GetVar(r.info.ProjectId, mergeBotSecret)
if err != nil {
logger.Error("cound't validate secret", "err", err)

logger.Info("cound't validate secret", "err", err)
return false
}

Expand Down
21 changes: 19 additions & 2 deletions handlers/request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ func (p *testProvider) GetVar(projectId int, varName string) (string, error) {

func (p *testProvider) GetMRInfo(projectId, id int, path string) (*MrInfo, error) {
return &MrInfo{
ProjectId: projectId,
Id: id,
Title: p.title,
ConfigContent: p.config,
Approvals: p.approvals,
Expand Down Expand Up @@ -123,7 +125,13 @@ func Test_Merge(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ok, s, _ := tt.args.pr.Merge(1, 2)
// Load info and config first
err := tt.args.pr.LoadInfoAndConfig(1, 2)
if err != nil {
t.Fatalf("LoadInfoAndConfig failed: %v", err)
}

ok, s, _ := tt.args.pr.Merge()
if tt.wantErr {
assert.NotEmpty(t, s)
assert.Equal(t, false, ok)
Expand Down Expand Up @@ -245,7 +253,16 @@ func TestRequest_Greetings(t *testing.T) {
r := &Request{
provider: tt.fields.provider,
}
err := r.Greetings(tt.args.projectId, tt.args.id)

// Load info and config first (this is required for the current implementation)
err := r.LoadInfoAndConfig(tt.args.projectId, tt.args.id)
if err != nil && !tt.wantErr {
t.Fatalf("LoadInfoAndConfig failed: %v", err)
}

if err == nil {
err = r.Greetings()
}

if tt.wantErr {
assert.Error(t, err)
Expand Down
6 changes: 3 additions & 3 deletions handlers/stalebranches.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@ type StaleBranch struct {
LastUpdated time.Time
}

func (r *Request) cleanStaleBranches(projectId int) error {
func (r Request) cleanStaleBranches() error {
cleanStaleBranchesLock.Lock()
defer cleanStaleBranchesLock.Unlock()

logger.Debug("deletion of stale branches has been run")

candidates, err := r.provider.ListBranches(projectId, r.config.StaleBranchesDeletion.BatchSize)
candidates, err := r.provider.ListBranches(r.info.ProjectId, r.config.StaleBranchesDeletion.BatchSize)
if err != nil {
return fmt.Errorf("ListBranches returns error: %w", err)
}
Expand All @@ -36,7 +36,7 @@ func (r *Request) cleanStaleBranches(projectId int) error {
// branch is stale
// delete branch
logger.Debug("branch info", "name", b.Name, "createdAt", b.LastUpdated.String())
if err := r.provider.DeleteBranch(projectId, b.Name); err != nil {
if err := r.provider.DeleteBranch(r.info.ProjectId, b.Name); err != nil {
return fmt.Errorf("DeleteBranch returns error: %w", err)
}
}
Expand Down
10 changes: 5 additions & 5 deletions handlers/stalemergerequests.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@ type StaleMergeRequest struct {
LastUpdated time.Time
}

func (r *Request) cleanStaleMergeRequests(projectId int) error {
func (r Request) cleanStaleMergeRequests() error {
cleanStaleMergeRquestsLock.Lock()
defer cleanStaleMergeRquestsLock.Unlock()

days := r.config.StaleBranchesDeletion.Days
coolDays := r.config.StaleBranchesDeletion.WaitDays
now := time.Now()

candidates, err := r.provider.ListMergeRequests(projectId, r.config.StaleBranchesDeletion.BatchSize)
candidates, err := r.provider.ListMergeRequests(r.info.ProjectId, r.config.StaleBranchesDeletion.BatchSize)
if err != nil {
return fmt.Errorf("ListMergeRequests returns error: %w", err)
}
Expand All @@ -40,20 +40,20 @@ func (r *Request) cleanStaleMergeRequests(projectId int) error {
span := now.Sub(mr.LastUpdated)
if slices.Contains(mr.Labels, staleLabel) {
if span > time.Duration(time.Duration(coolDays)*24*time.Hour) {
if err := r.provider.DeleteBranch(projectId, mr.Branch); err != nil {
if err := r.provider.DeleteBranch(r.info.ProjectId, mr.Branch); err != nil {
return fmt.Errorf("DeleteBranch returns error: %w", err)
}
}
}

if span > time.Duration(time.Duration(days)*24*time.Hour) {
// mr is stale
if err := r.provider.AssignLabel(projectId, mr.Id, staleLabel, staleLabelColor); err != nil {
if err := r.provider.AssignLabel(r.info.ProjectId, mr.Id, staleLabel, staleLabelColor); err != nil {
return fmt.Errorf("AssignLabel returns error: %w", err)
}

message := fmt.Sprintf("This MR is stale because it has been open %d days with no activity. Remove stale label othewise this will be closed in %d days.", days, coolDays)
if err := r.provider.LeaveComment(projectId, mr.Id, message); err != nil {
if err := r.provider.LeaveComment(r.info.ProjectId, mr.Id, message); err != nil {
return fmt.Errorf("LeaveComment returns error: %w", err)
}

Expand Down
4 changes: 2 additions & 2 deletions integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func TestIntegrationWebhookFlow(t *testing.T) {

// Save original handlers
handlerMu.Lock()
originalHandlers := make(map[string]func(*handlers.Request, *webhook.Webhook) error)
originalHandlers := make(map[string]func(*handlers.Request) error)
for k, v := range handlerFuncs {
originalHandlers[k] = v
}
Expand All @@ -72,7 +72,7 @@ func TestIntegrationWebhookFlow(t *testing.T) {
}()

// Register a test handler
handle("!merge", func(command *handlers.Request, hook *webhook.Webhook) error {
handle("!merge", func(command *handlers.Request) error {
return nil
})

Expand Down
Loading