/
users.go
265 lines (219 loc) · 7.04 KB
/
users.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
// Copyright 2022 E99p1ant. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package db
import (
	"context"
	"crypto/subtle"

	"github.com/pkg/errors"
	"github.com/wuhan005/gadget"
	"gorm.io/gorm"

	"github.com/NekoWheel/NekoBox/internal/conf"
)
var (
	// Users is the package-level UsersStore.
	// NOTE(review): it is nil until assigned elsewhere (presumably with
	// NewUsersStore during start-up) — confirm.
	Users UsersStore

	// Compile-time assertion that *users implements UsersStore.
	_ UsersStore = (*users)(nil)
)
// UsersStore is the persistence interface for user accounts.
type UsersStore interface {
	// Create inserts a new user built from opts, rejecting an email or
	// domain that is already in use.
	Create(ctx context.Context, opts CreateUserOptions) error

	// GetByID returns the user with the given ID, or ErrUserNotExists.
	GetByID(ctx context.Context, id uint) (*User, error)
	// GetByEmail returns the user registered with email, or ErrUserNotExists.
	GetByEmail(ctx context.Context, email string) (*User, error)
	// GetByDomain returns the user owning the custom domain, or ErrUserNotExists.
	GetByDomain(ctx context.Context, domain string) (*User, error)

	// Update writes the profile fields in opts to user id.
	Update(ctx context.Context, id uint, opts UpdateUserOptions) error
	// UpdateHarassmentSetting sets the harassment setting of user id to typ.
	UpdateHarassmentSetting(ctx context.Context, id uint, typ HarassmentSettingType) error

	// Authenticate returns the user matching email and password, or
	// ErrBadCredential when they do not match.
	Authenticate(ctx context.Context, email, password string) (*User, error)
	// ChangePassword sets a new password after verifying oldPassword.
	ChangePassword(ctx context.Context, id uint, oldPassword, newPassword string) error
	// UpdatePassword sets a new password without verifying the old one.
	UpdatePassword(ctx context.Context, id uint, newPassword string) error

	// Deactivate deletes the user with the given ID.
	Deactivate(ctx context.Context, id uint) error
}
// NewUsersStore returns a UsersStore that runs its queries through db.
func NewUsersStore(db *gorm.DB) UsersStore {
	store := &users{DB: db}
	return store
}
// users is the GORM-backed implementation of UsersStore. It embeds *gorm.DB
// so query-builder methods (WithContext, Where, ...) are called directly on
// the receiver.
type users struct {
	*gorm.DB
}
// User is a user account row.
//
// The embedded gorm.Model provides the ID field used by the store methods.
// Both it and Password are excluded from JSON output. Password holds the hash
// written by EncodePassword, not the plaintext. Notify and HarassmentSetting
// are checked against the NotifyType / HarassmentSettingType constants when
// they are updated.
type User struct {
	gorm.Model        `json:"-"`
	Name              string                `json:"name"`
	Password          string                `json:"-"`
	Email             string                `json:"email"`
	Avatar            string                `json:"avatar"`
	Domain            string                `json:"domain"`
	Background        string                `json:"background"`
	Intro             string                `json:"intro"`
	Notify            NotifyType            `json:"notify"`
	HarassmentSetting HarassmentSettingType `json:"harassment_setting"`
}
// NotifyType is a user's notification preference.
type NotifyType string

// The NotifyType values accepted by Update.
const (
	// NotifyTypeEmail selects email notification; Create uses it as the default.
	NotifyTypeEmail NotifyType = "email"
	// NotifyTypeNone presumably opts out of notifications — confirm with callers.
	NotifyTypeNone NotifyType = "none"
)
// HarassmentSettingType is a user's anti-harassment setting.
type HarassmentSettingType string

// The HarassmentSettingType values accepted by UpdateHarassmentSetting.
const (
	// HarassmentSettingNone applies no restriction.
	HarassmentSettingNone HarassmentSettingType = "none"
	// HarassmentSettingTypeRegisterOnly presumably limits interaction to
	// registered users — confirm with callers.
	HarassmentSettingTypeRegisterOnly HarassmentSettingType = "register_only"
)
// EncodePassword replaces the plaintext in u.Password with the value of
// gadget.HmacSha1 computed with conf.Server.Salt. Because it overwrites the
// field with a hash of its current contents, it must be called exactly once on
// a plaintext value.
//
// NOTE(review): one keyed SHA-1 with a server-wide salt is cheap to brute
// force. A slow, per-user-salted KDF (bcrypt/scrypt/argon2) would be safer,
// but switching requires migrating existing hashes.
func (u *User) EncodePassword() {
	plain := u.Password
	u.Password = gadget.HmacSha1(plain, conf.Server.Salt)
}
// Authenticate reports whether password matches the stored hash in
// u.Password. The password is hashed the same way EncodePassword does, with
// gadget.HmacSha1 and conf.Server.Salt.
//
// The digests are compared with subtle.ConstantTimeCompare, so the time
// taken does not depend on how many leading bytes match.
func (u *User) Authenticate(password string) bool {
	hashed := gadget.HmacSha1(password, conf.Server.Salt)
	return subtle.ConstantTimeCompare([]byte(u.Password), []byte(hashed)) == 1
}
// CreateUserOptions carries the values Create copies into a new User row.
// Password is plaintext; Create hashes it with EncodePassword. Email and
// Domain must not already belong to another user, which validate checks.
type CreateUserOptions struct {
	Name       string
	Password   string
	Email      string
	Avatar     string
	Domain     string
	Background string
	Intro      string
}
// Sentinel errors returned by the users store. Their messages are in Chinese,
// presumably for display to end users; compare with errors.Is, not by text.
var (
	// ErrUserNotExists ("account does not exist") is returned when no user
	// matches a lookup.
	ErrUserNotExists = errors.New("账号不存在")
	// ErrBadCredential ("wrong email or password") is returned when a login
	// or an old-password check fails.
	ErrBadCredential = errors.New("邮箱或密码错误")
	// ErrDuplicateEmail ("this email is already registered") is returned by
	// Create.
	ErrDuplicateEmail = errors.New("这个邮箱已经注册过账号了!")
	// ErrDuplicateDomain ("custom domain already taken, pick another") is
	// returned by Create.
	ErrDuplicateDomain = errors.New("个性域名重复了,换一个吧~")
)
// Create registers a new user built from opts.
//
// validate runs first and returns ErrDuplicateEmail or ErrDuplicateDomain on
// a clash. The password is then hashed with EncodePassword, email
// notifications are enabled by default, and the row is inserted.
//
// NOTE(review): HarassmentSetting is not set here, so new rows presumably get
// "" rather than HarassmentSettingNone. Confirm that readers treat "" as
// "none".
func (db *users) Create(ctx context.Context, opts CreateUserOptions) error {
	if err := db.validate(ctx, opts); err != nil {
		return err
	}

	user := User{
		Name:       opts.Name,
		Email:      opts.Email,
		Password:   opts.Password,
		Domain:     opts.Domain,
		Avatar:     opts.Avatar,
		Background: opts.Background,
		Intro:      opts.Intro,
		Notify:     NotifyTypeEmail,
	}
	user.EncodePassword() // hash the plaintext before it reaches the database

	result := db.WithContext(ctx).Create(&user)
	if result.Error != nil {
		return errors.Wrap(result.Error, "create user")
	}
	return nil
}
func (db *users) getBy(ctx context.Context, where string, args ...interface{}) (*User, error) {
var user User
if err := db.WithContext(ctx).Where(where, args...).First(&user).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotExists
}
return nil, errors.Wrap(err, "get user")
}
return &user, nil
}
// GetByID returns the user with the given ID, or ErrUserNotExists if there is
// none.
func (db *users) GetByID(ctx context.Context, id uint) (*User, error) {
	user, err := db.getBy(ctx, "id = ?", id)
	if err != nil {
		return nil, err
	}
	return user, nil
}
// GetByEmail returns the user registered with email, or ErrUserNotExists if
// there is none.
func (db *users) GetByEmail(ctx context.Context, email string) (*User, error) {
	user, err := db.getBy(ctx, "email = ?", email)
	if err != nil {
		return nil, err
	}
	return user, nil
}
// GetByDomain returns the user owning the custom domain, or ErrUserNotExists
// if there is none.
func (db *users) GetByDomain(ctx context.Context, domain string) (*User, error) {
	user, err := db.getBy(ctx, "domain = ?", domain)
	if err != nil {
		return nil, err
	}
	return user, nil
}
// UpdateUserOptions holds the profile fields Update writes to a user.
// Notify must be NotifyTypeEmail or NotifyTypeNone; Update rejects any other
// value.
type UpdateUserOptions struct {
	Name       string
	Avatar     string
	Background string
	Intro      string
	Notify     NotifyType
}
// Update writes the profile fields in opts to user id.
//
// opts.Notify must be NotifyTypeEmail or NotifyTypeNone. It is checked before
// any database access, so invalid input fails without a query. An unknown id
// yields an error wrapping ErrUserNotExists.
//
// NOTE(review): GORM's Updates with a struct argument skips zero-value fields.
// An empty Name/Avatar/Background/Intro therefore presumably leaves the stored
// value unchanged rather than clearing it. Confirm that is intended.
func (db *users) Update(ctx context.Context, id uint, opts UpdateUserOptions) error {
	switch opts.Notify {
	case NotifyTypeEmail, NotifyTypeNone:
	default:
		return errors.Errorf("unexpected notify type: %q", opts.Notify)
	}

	if _, err := db.GetByID(ctx, id); err != nil {
		return errors.Wrap(err, "get user by id")
	}

	if err := db.WithContext(ctx).Where("id = ?", id).Updates(&User{
		Name:       opts.Name,
		Avatar:     opts.Avatar,
		Background: opts.Background,
		Intro:      opts.Intro,
		Notify:     opts.Notify,
	}).Error; err != nil {
		return errors.Wrap(err, "update user")
	}
	return nil
}
// UpdateHarassmentSetting stores typ as the harassment setting of user id.
//
// typ must be HarassmentSettingNone or HarassmentSettingTypeRegisterOnly;
// anything else is rejected before touching the database. The user's
// existence is then confirmed, as Update does. An unknown id therefore yields
// an error wrapping ErrUserNotExists instead of a silent zero-row update.
func (db *users) UpdateHarassmentSetting(ctx context.Context, id uint, typ HarassmentSettingType) error {
	switch typ {
	case HarassmentSettingNone, HarassmentSettingTypeRegisterOnly:
	default:
		return errors.Errorf("unexpected harassment setting type: %q", typ)
	}

	if _, err := db.GetByID(ctx, id); err != nil {
		return errors.Wrap(err, "get user by id")
	}

	if err := db.WithContext(ctx).Where("id = ?", id).Updates(&User{
		HarassmentSetting: typ,
	}).Error; err != nil {
		return errors.Wrap(err, "update harassment setting")
	}
	return nil
}
// Authenticate looks up the user registered with email and checks password
// against the stored hash, returning the user on success.
//
// An unknown email and a wrong password both yield ErrBadCredential, so
// callers cannot tell which one was wrong. Any other lookup failure (e.g. a
// database error) is returned wrapped rather than disguised as a credential
// problem.
func (db *users) Authenticate(ctx context.Context, email, password string) (*User, error) {
	u, err := db.GetByEmail(ctx, email)
	if err != nil {
		if errors.Is(err, ErrUserNotExists) {
			return nil, ErrBadCredential
		}
		return nil, errors.Wrap(err, "get user by email")
	}
	if !u.Authenticate(password) {
		return nil, ErrBadCredential
	}
	return u, nil
}
// ChangePassword replaces the password of user id with newPassword. It does
// so only after oldPassword has been verified against the stored hash.
//
// It returns ErrBadCredential if oldPassword does not match, and an error
// wrapping ErrUserNotExists if the user does not exist. newPassword is
// plaintext; it is hashed with EncodePassword before being stored.
func (db *users) ChangePassword(ctx context.Context, id uint, oldPassword, newPassword string) error {
	user, err := db.GetByID(ctx, id)
	if err != nil {
		return errors.Wrap(err, "get user by id")
	}
	if ok := user.Authenticate(oldPassword); !ok {
		return ErrBadCredential
	}

	user.Password = newPassword
	user.EncodePassword() // hash in place before persisting

	result := db.WithContext(ctx).Model(&User{}).Where("id = ?", user.ID).Update("password", user.Password)
	if result.Error != nil {
		return errors.Wrap(result.Error, "change password")
	}
	return nil
}
// UpdatePassword sets the password of user id to newPassword without checking
// the current one (unlike ChangePassword). newPassword is plaintext; it is
// hashed with EncodePassword before being stored.
//
// An unknown id yields an error wrapping ErrUserNotExists.
func (db *users) UpdatePassword(ctx context.Context, id uint, newPassword string) error {
	u, err := db.GetByID(ctx, id)
	if err != nil {
		return errors.Wrap(err, "get user by id")
	}
	u.Password = newPassword
	u.EncodePassword()
	if err := db.WithContext(ctx).Model(&User{}).Where("id = ?", u.ID).Update("password", u.Password).Error; err != nil {
		// Labelled distinctly from ChangePassword's "change password" so the
		// two password paths can be told apart in error chains.
		return errors.Wrap(err, "update password")
	}
	return nil
}
// Deactivate deletes the user with the given ID. An unknown id yields an
// error wrapping ErrUserNotExists.
//
// NOTE(review): User embeds gorm.Model, whose DeletedAt field makes GORM
// presumably perform a soft delete (setting deleted_at) rather than removing
// the row. Confirm this is the intended deactivation semantics.
func (db *users) Deactivate(ctx context.Context, id uint) error {
	u, err := db.GetByID(ctx, id)
	if err != nil {
		return errors.Wrap(err, "get user by id")
	}
	// Delete's first argument is the value/model to delete, not a condition
	// string, so the condition goes in Where and the model in Delete.
	if err := db.WithContext(ctx).Where("id = ?", u.ID).Delete(&User{}).Error; err != nil {
		return errors.Wrap(err, "delete user")
	}
	return nil
}
// validate checks that the email and domain in opts are not already used by
// an existing user. The email is checked first.
//
// It returns ErrDuplicateEmail or ErrDuplicateDomain on a clash, or a wrapped
// error if a lookup fails for any reason other than "no such user". Errors
// are matched with errors.Is, so wrapped not-found errors are still
// recognised.
//
// NOTE(review): this is check-then-insert. Two concurrent registrations with
// the same email or domain can both pass unless the database enforces
// uniqueness. Confirm unique indexes exist on these columns.
func (db *users) validate(ctx context.Context, opts CreateUserOptions) error {
	_, err := db.GetByEmail(ctx, opts.Email)
	switch {
	case err == nil:
		return ErrDuplicateEmail
	case !errors.Is(err, ErrUserNotExists):
		return errors.Wrap(err, "validate email")
	}

	_, err = db.GetByDomain(ctx, opts.Domain)
	switch {
	case err == nil:
		return ErrDuplicateDomain
	case !errors.Is(err, ErrUserNotExists):
		return errors.Wrap(err, "validate domain")
	}
	return nil
}