forked from qdeconinck/mp-quic
/
olia.go
92 lines (80 loc) · 2.48 KB
/
olia.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
package congestion
import (
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
// scale is the number of fractional bits used for the fixed-point arithmetic
// in CongestionWindowAfterAck: increase terms are shifted left by scale bits
// before being divided, and the accumulated counter is compared against
// (1 << scale) - 1.
const scale uint = 10
// Olia implements the olia algorithm from MPTCP.
//
// It keeps three checkpoints of the acked-byte counter (the current value and
// its value at the two most recent losses), the epsilon term as a
// numerator/denominator pair, and a fixed-point accumulator that decides when
// the congestion window moves by one packet.
type Olia struct {
	// Total number of bytes acked two losses ago (previous value of loss2,
	// shifted in by OnPacketLost).
	loss1 protocol.ByteCount
	// Total number of bytes acked at the last loss (value of loss3 when
	// OnPacketLost was last called).
	loss2 protocol.ByteCount
	// Current number of bytes acked; grows in UpdateAckedSinceLastLoss.
	loss3 protocol.ByteCount
	// Numerator of the epsilon term. CongestionWindowAfterAck treats -1 as a
	// special case; any other value is multiplied by the rate as an unsigned
	// number.
	epsilonNum int
	// Denominator of the epsilon term; starts at 1.
	epsilonDen uint32
	// Signed fixed-point accumulator (scale fractional bits) of pending
	// congestion-window change; cleared whenever the window is adjusted.
	sndCwndCnt int
	// We need to keep a reference to all paths
	// NOTE(review): no such field exists in this struct yet.
}
// NewOlia returns an Olia whose three acked-byte checkpoints all start at
// ackedBytes, so no bytes are counted between losses yet. The epsilon term
// starts at 0/1 and the fractional window counter at zero.
func NewOlia(ackedBytes protocol.ByteCount) *Olia {
	// epsilonNum and sndCwndCnt are left at their zero values.
	return &Olia{
		loss1:      ackedBytes,
		loss2:      ackedBytes,
		loss3:      ackedBytes,
		epsilonDen: 1,
	}
}
// oliaScale converts val to fixed-point form by shifting it left by scale
// bits, i.e. it multiplies val by 2^scale. Bits shifted past the top of the
// uint64 are discarded.
func oliaScale(val uint64, scale uint) uint64 {
	shifted := val << scale
	return shifted
}
// Reset clears the algorithm state: all three acked-byte checkpoints and the
// fractional window counter go back to zero, and the epsilon term to 0/1.
func (o *Olia) Reset() {
	o.loss1, o.loss2, o.loss3 = 0, 0, 0
	o.epsilonNum, o.epsilonDen = 0, 1
	o.sndCwndCnt = 0
}
// SmoothedBytesBetweenLosses returns the larger of two byte counts: the bytes
// acked since the last loss, and the bytes acked between the two most recent
// losses.
func (o *Olia) SmoothedBytesBetweenLosses() protocol.ByteCount {
	sinceLastLoss := o.loss3 - o.loss2
	betweenLastTwoLosses := o.loss2 - o.loss1
	return utils.MaxByteCount(sinceLastLoss, betweenLastTwoLosses)
}
// UpdateAckedSinceLastLoss adds ackedBytes to the running count of acked
// bytes.
func (o *Olia) UpdateAckedSinceLastLoss(ackedBytes protocol.ByteCount) {
	o.loss3 = o.loss3 + ackedBytes
}
// OnPacketLost records a loss by shifting the acked-byte checkpoints: the
// previous "last loss" value becomes the "two losses ago" value, and the
// current acked-byte count becomes the new "last loss" value.
func (o *Olia) OnPacketLost() {
	// TODO should we add so many if check? Not done here
	o.loss1, o.loss2 = o.loss2, o.loss3
}
// CongestionWindowAfterAck returns the congestion window, in packets, to use
// after an ACK has been processed.
//
// currentCongestionWindow is the window before the ACK. rate and cwndScaled
// are supplied by the caller and enter the increase term below.
// NOTE(review): cwndScaled is presumably this path's window expressed in the
// caller's fixed-point scale — confirm against the call site.
//
// Each call adds a signed fixed-point fraction (scale fractional bits) to
// o.sndCwndCnt. Once the accumulated value reaches (1 << scale) - 1 in either
// direction, the window is moved by one packet and the accumulator is
// cleared. A decrease never yields a window smaller than 1.
func (o *Olia) CongestionWindowAfterAck(currentCongestionWindow protocol.PacketNumber, rate protocol.ByteCount, cwndScaled uint64) protocol.PacketNumber {
	// Accumulator magnitude at which the window moves by one packet.
	const threshold = (1 << scale) - 1

	newCongestionWindow := currentCongestionWindow

	// If any factor is 0 the product is 0 and the divisions below would
	// panic, so fall back to a denominator of 1.
	incDen := uint64(o.epsilonDen) * uint64(currentCongestionWindow) * uint64(rate)
	if incDen == 0 {
		incDen = 1
	}

	// calculate the increasing term, scaling is used to reduce the rounding effect
	cwndTerm := uint64(o.epsilonDen) * cwndScaled * cwndScaled
	if o.epsilonNum == -1 {
		// The numerator is cwndTerm - rate; the arithmetic is unsigned, so the
		// sign is handled by choosing which way to subtract and whether the
		// result is added to or removed from the accumulator.
		if cwndTerm < uint64(rate) {
			incNum := uint64(rate) - cwndTerm
			o.sndCwndCnt -= int(oliaScale(incNum, scale) / incDen)
		} else {
			incNum := cwndTerm - uint64(rate)
			o.sndCwndCnt += int(oliaScale(incNum, scale) / incDen)
		}
	} else {
		incNum := uint64(o.epsilonNum) * uint64(rate) + cwndTerm
		o.sndCwndCnt += int(oliaScale(incNum, scale) / incDen)
	}

	if o.sndCwndCnt >= threshold {
		newCongestionWindow++
		o.sndCwndCnt = 0
	} else if o.sndCwndCnt <= -threshold {
		// Clamp at 1 without computing currentCongestionWindow - 1 first: if
		// PacketNumber is an unsigned type, that subtraction wraps around for
		// a window of 0 and would produce an enormous window instead of 1.
		if currentCongestionWindow > 1 {
			newCongestionWindow = currentCongestionWindow - 1
		} else {
			newCongestionWindow = 1
		}
		o.sndCwndCnt = 0
	}
	return newCongestionWindow
}