Skip to content

Commit

Permalink
[db] Enbancement add traffic fully transactional
Browse files Browse the repository at this point in the history
Co-Authored-By: Alireza Ahmadi <alireza7@gmail.com>
  • Loading branch information
MHSanaei and alireza0 committed Aug 26, 2023
1 parent 75df8a0 commit 1277285
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 116 deletions.
37 changes: 0 additions & 37 deletions web/job/check_inbound_job.go

This file was deleted.

8 changes: 3 additions & 5 deletions web/job/xray_traffic_job.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,12 @@ func (j *XrayTrafficJob) Run() {
logger.Warning("get xray traffic failed:", err)
return
}
err = j.inboundService.AddTraffic(traffics)
err, needRestart := j.inboundService.AddTraffic(traffics, clientTraffics)
if err != nil {
logger.Warning("add traffic failed:", err)
}

err = j.inboundService.AddClientTraffic(clientTraffics)
if err != nil {
logger.Warning("add client traffic failed:", err)
if needRestart {
j.xrayService.SetToNeedRestart()
}

}
130 changes: 60 additions & 70 deletions web/service/inbound.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,38 +194,6 @@ func (s *InboundService) AddInbound(inbound *model.Inbound) (*model.Inbound, boo
return inbound, needRestart, err
}

func (s *InboundService) AddInbounds(inbounds []*model.Inbound) error {
for _, inbound := range inbounds {
exist, err := s.checkPortExist(inbound.Port, 0)
if err != nil {
return err
}
if exist {
return common.NewError("Port already exists:", inbound.Port)
}
}

db := database.GetDB()
tx := db.Begin()
var err error
defer func() {
if err == nil {
tx.Commit()
} else {
tx.Rollback()
}
}()

for _, inbound := range inbounds {
err = tx.Save(inbound).Error
if err != nil {
return err
}
}

return nil
}

func (s *InboundService) DelInbound(id int) (bool, error) {
db := database.GetDB()

Expand Down Expand Up @@ -687,35 +655,8 @@ func (s *InboundService) UpdateInboundClient(data *model.Inbound, clientId strin
return needRestart, tx.Save(oldInbound).Error
}

func (s *InboundService) AddTraffic(traffics []*xray.Traffic) error {
if len(traffics) == 0 {
return nil
}
// Update traffics in a single transaction
err := database.GetDB().Transaction(func(tx *gorm.DB) error {
for _, traffic := range traffics {
if traffic.IsInbound {
update := tx.Model(&model.Inbound{}).Where("tag = ?", traffic.Tag).
Updates(map[string]interface{}{
"up": gorm.Expr("up + ?", traffic.Up),
"down": gorm.Expr("down + ?", traffic.Down),
})
if update.Error != nil {
return update.Error
}
}
}
return nil
})

return err
}

func (s *InboundService) AddClientTraffic(traffics []*xray.ClientTraffic) (err error) {
if len(traffics) == 0 {
return nil
}

func (s *InboundService) AddTraffic(inboundTraffics []*xray.Traffic, clientTraffics []*xray.ClientTraffic) (error, bool) {
var err error
db := database.GetDB()
tx := db.Begin()

Expand All @@ -726,13 +667,64 @@ func (s *InboundService) AddClientTraffic(traffics []*xray.ClientTraffic) (err e
tx.Commit()
}
}()
err = s.addInboundTraffic(tx, inboundTraffics)
if err != nil {
return err, false
}
err = s.addClientTraffic(tx, clientTraffics)
if err != nil {
return err, false
}

needRestart1, count, err := s.disableInvalidClients(tx)
if err != nil {
logger.Warning("Error in disabling invalid clients:", err)
} else if count > 0 {
logger.Debugf("%v clients disabled", count)
}

needRestart2, count, err := s.disableInvalidInbounds(tx)
if err != nil {
logger.Warning("Error in disabling invalid inbounds:", err)
} else if count > 0 {
logger.Debugf("%v inbounds disabled", count)
}
return nil, (needRestart1 || needRestart2)
}

func (s *InboundService) addInboundTraffic(tx *gorm.DB, traffics []*xray.Traffic) error {
if len(traffics) == 0 {
return nil
}

var err error

for _, traffic := range traffics {
if traffic.IsInbound {
err = tx.Model(&model.Inbound{}).Where("tag = ?", traffic.Tag).
Updates(map[string]interface{}{
"up": gorm.Expr("up + ?", traffic.Up),
"down": gorm.Expr("down + ?", traffic.Down),
}).Error
if err != nil {
return err
}
}
}
return nil
}

func (s *InboundService) addClientTraffic(tx *gorm.DB, traffics []*xray.ClientTraffic) (err error) {
if len(traffics) == 0 {
return nil
}

emails := make([]string, 0, len(traffics))
for _, traffic := range traffics {
emails = append(emails, traffic.Email)
}
dbClientTraffics := make([]*xray.ClientTraffic, 0, len(traffics))
err = db.Model(xray.ClientTraffic{}).Where("email IN (?)", emails).Find(&dbClientTraffics).Error
err = tx.Model(xray.ClientTraffic{}).Where("email IN (?)", emails).Find(&dbClientTraffics).Error
if err != nil {
return err
}
Expand Down Expand Up @@ -817,14 +809,13 @@ func (s *InboundService) adjustTraffics(tx *gorm.DB, dbClientTraffics []*xray.Cl
return dbClientTraffics, nil
}

func (s *InboundService) DisableInvalidInbounds() (bool, int64, error) {
db := database.GetDB()
func (s *InboundService) disableInvalidInbounds(tx *gorm.DB) (bool, int64, error) {
now := time.Now().Unix() * 1000
needRestart := false

if p != nil {
var tags []string
err := db.Table("inbounds").
err := tx.Table("inbounds").
Select("inbounds.tag").
Where("((total > 0 and up + down >= total) or (expiry_time > 0 and expiry_time <= ?)) and enable = ?", now, true).
Scan(&tags).Error
Expand All @@ -844,16 +835,15 @@ func (s *InboundService) DisableInvalidInbounds() (bool, int64, error) {
s.xrayApi.Close()
}

result := db.Model(model.Inbound{}).
result := tx.Model(model.Inbound{}).
Where("((total > 0 and up + down >= total) or (expiry_time > 0 and expiry_time <= ?)) and enable = ?", now, true).
Update("enable", false)
err := result.Error
count := result.RowsAffected
return needRestart, count, err
}

func (s *InboundService) DisableInvalidClients() (bool, int64, error) {
db := database.GetDB()
func (s *InboundService) disableInvalidClients(tx *gorm.DB) (bool, int64, error) {
now := time.Now().Unix() * 1000
needRestart := false

Expand All @@ -863,7 +853,7 @@ func (s *InboundService) DisableInvalidClients() (bool, int64, error) {
Email string
}

err := db.Table("inbounds").
err := tx.Table("inbounds").
Select("inbounds.tag, client_traffics.email").
Joins("JOIN client_traffics ON inbounds.id = client_traffics.inbound_id").
Where("((client_traffics.total > 0 AND client_traffics.up + client_traffics.down >= client_traffics.total) OR (client_traffics.expiry_time > 0 AND client_traffics.expiry_time <= ?)) AND client_traffics.enable = ?", now, true).
Expand All @@ -883,7 +873,7 @@ func (s *InboundService) DisableInvalidClients() (bool, int64, error) {
}
s.xrayApi.Close()
}
result := db.Model(xray.ClientTraffic{}).
result := tx.Model(xray.ClientTraffic{}).
Where("((total > 0 and up + down >= total) or (expiry_time > 0 and expiry_time <= ?)) and enable = ?", now, true).
Update("enable", false)
err := result.Error
Expand Down
2 changes: 1 addition & 1 deletion web/service/xray.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func (s *XrayService) GetXrayConfig() (*xray.Config, error) {
return nil, err
}

s.inboundService.DisableInvalidClients()
s.inboundService.AddTraffic(nil, nil)

inbounds, err := s.inboundService.GetAllInbounds()
if err != nil {
Expand Down
3 changes: 0 additions & 3 deletions web/web.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,9 +247,6 @@ func (s *Server) startTask() {
s.cron.AddJob("@every 10s", job.NewXrayTrafficJob())
}()

// Check the inbound traffic every 30 seconds that the traffic exceeds and expires
s.cron.AddJob("@every 30s", job.NewCheckInboundJob())

// check client ips from log file every 10 sec
s.cron.AddJob("@every 10s", job.NewCheckClientIpJob())

Expand Down

0 comments on commit 1277285

Please sign in to comment.