-
Notifications
You must be signed in to change notification settings - Fork 1
/
host_repo.go
86 lines (74 loc) · 2.03 KB
/
host_repo.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
package orm
import (
"context"
"fmt"
"github.com/ScoreTrak/ScoreTrak/pkg/host"
"github.com/ScoreTrak/ScoreTrak/pkg/host/hostrepo"
"github.com/ScoreTrak/ScoreTrak/pkg/storage/orm/testutil"
"github.com/gofrs/uuid"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
// hostRepo is the gorm-backed implementation of hostrepo.Repo for hosts.
type hostRepo struct {
	db *gorm.DB
}

// NewHostRepo returns a hostrepo.Repo that reads and writes hosts through
// the supplied gorm handle.
func NewHostRepo(db *gorm.DB) hostrepo.Repo {
	repo := hostRepo{db: db}
	return &repo
}
// Delete removes the host row whose id column equals id.
//
// A database failure is returned wrapped with the host id. If the statement
// succeeds but matches no rows, a *NoRowsAffected error is returned so callers
// can tell "nothing to delete" apart from success.
func (h *hostRepo) Delete(ctx context.Context, id uuid.UUID) error {
	res := h.db.WithContext(ctx).Delete(&host.Host{}, "id = ?", id)
	switch {
	case res.Error != nil:
		return fmt.Errorf("error deleting host with id: %s, err: %w", id.String(), res.Error)
	case res.RowsAffected == 0:
		return &NoRowsAffected{"no model found"}
	default:
		return nil
	}
}
func (h *hostRepo) GetAll(ctx context.Context) ([]*host.Host, error) {
hosts := make([]*host.Host, 0)
err := h.db.WithContext(ctx).Find(&hosts).Error
if err != nil {
return nil, err
}
return hosts, nil
}
func (h *hostRepo) GetByID(ctx context.Context, id uuid.UUID) (*host.Host, error) {
hst := &host.Host{}
err := h.db.WithContext(ctx).Where("id = ?", id).First(hst).Error
if err != nil {
return nil, err
}
return hst, nil
}
// Store inserts every host in hst with a single gorm Create call.
//
// An empty or nil hst is treated as a successful no-op. gorm does not accept
// a Create with an empty slice: recent versions return ErrEmptySlice and older
// ones emit an INSERT with no VALUES. Storing nothing should not be an error.
func (h *hostRepo) Store(ctx context.Context, hst []*host.Host) error {
	if len(hst) == 0 {
		return nil
	}
	return h.db.WithContext(ctx).Create(hst).Error
}
// Upsert inserts the hosts in hst, silently skipping any row that conflicts
// with an existing one (ON CONFLICT DO NOTHING).
//
// NOTE: despite the name, conflicting rows already in the database are NOT
// updated. Only new rows are written.
//
// An empty or nil hst is treated as a successful no-op, because gorm returns
// an error when Create is given an empty slice.
func (h *hostRepo) Upsert(ctx context.Context, hst []*host.Host) error {
	if len(hst) == 0 {
		return nil
	}
	return h.db.WithContext(ctx).
		Clauses(clause.OnConflict{DoNothing: true}).
		Create(hst).Error
}
// Update writes the editable host columns from hst to the row that gorm
// identifies via Model(hst).
//
// The new values are passed as a host.Host struct, and gorm's struct form of
// Updates writes only non-zero fields. A zero-valued field is therefore left
// unchanged in the database rather than cleared.
// NOTE(review): presumably Pause/Hide/EditHost are pointer types so that
// false can still be written. Confirm in host.Host.
func (h *hostRepo) Update(ctx context.Context, hst *host.Host) error {
	changes := host.Host{
		Pause:            hst.Pause,
		Hide:             hst.Hide,
		Address:          hst.Address,
		HostGroupID:      hst.HostGroupID,
		TeamID:           hst.TeamID,
		EditHost:         hst.EditHost,
		AddressListRange: hst.AddressListRange,
	}
	return h.db.WithContext(ctx).Model(hst).Updates(changes).Error
}
// TruncateTable empties the table backing host.Host by delegating to
// testutil.TruncateTable, returning that helper's error unchanged.
// Presumably this is meant for test setup/teardown, given the testutil helper.
func (h *hostRepo) TruncateTable(ctx context.Context) (err error) {
	return testutil.TruncateTable(ctx, &host.Host{}, h.db)
}