Skip to content

Commit

Permalink
fix: refresh timer countdown when messages arrive
Browse files Browse the repository at this point in the history
  • Loading branch information
YukunJ committed May 11, 2023
1 parent a034815 commit 6728bd4
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 4 deletions.
12 changes: 11 additions & 1 deletion src/core/acceptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,13 @@ void Acceptor::BaseAcceptCallback(Connection *server_conn) {
reactors_[idx]->AddConnection(std::move(client_connection));
}

void Acceptor::BaseHandleCallback(Connection *client_conn) {
int fd = client_conn->GetFd();
if (client_conn->GetLooper()) {
client_conn->GetLooper()->RefreshConnection(fd);
}
}

void Acceptor::SetCustomAcceptCallback(std::function<void(Connection *)> custom_accept_callback) {
custom_accept_callback_ = std::move(custom_accept_callback);
acceptor_conn->SetCallback([this](auto &&PH1) {
Expand All @@ -67,7 +74,10 @@ void Acceptor::SetCustomAcceptCallback(std::function<void(Connection *)> custom_
}

void Acceptor::SetCustomHandleCallback(std::function<void(Connection *)> custom_handle_callback) {
custom_handle_callback_ = std::move(custom_handle_callback);
custom_handle_callback_ = [this, callback = std::move(custom_handle_callback)](auto &&PH1) {
BaseHandleCallback(std::forward<decltype(PH1)>(PH1));
callback(std::forward<decltype(PH1)>(PH1));
};
}

auto Acceptor::GetCustomAcceptCallback() const noexcept -> std::function<void(Connection *)> {
Expand Down
14 changes: 12 additions & 2 deletions src/core/looper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
#include "log/logger.h"
namespace TURTLE_SERVER {

Looper::Looper(uint64_t timer_expiration) : poller_(std::make_unique<Poller>()), use_timer_(timer_expiration != 0), timer_expiration_(timer_expiration) {
Looper::Looper(uint64_t timer_expiration)
: poller_(std::make_unique<Poller>()), use_timer_(timer_expiration != 0), timer_expiration_(timer_expiration) {
if (use_timer_) {
poller_->AddConnection(timer_.GetTimerConnection());
}
Expand Down Expand Up @@ -52,7 +53,16 @@ void Looper::AddConnection(std::unique_ptr<Connection> new_conn) {
}
}

auto Looper::DeleteConnection(int fd) -> bool {
auto Looper::RefreshConnection(int fd) noexcept -> bool {
std::unique_lock<std::mutex> lock(mtx_);
auto it = timers_mapping_.find(fd);
if (use_timer_ && it != timers_mapping_.end()) {
return timer_.RefreshSingleTimer(it->second, timer_expiration_);
}
return false;
}

auto Looper::DeleteConnection(int fd) noexcept -> bool {
std::unique_lock<std::mutex> lock(mtx_);
auto it = connections_.find(fd);
if (it == connections_.end()) {
Expand Down
18 changes: 18 additions & 0 deletions src/core/timer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ void Timer::SingleTimer::Run() noexcept {
}
}

auto Timer::SingleTimer::GetCallback() const noexcept -> std::function<void()> { return callback_; }
/* ------------ Timer --------------- */
Timer::Timer() : timer_fd_(timerfd_create(CLOCK_MONOTONIC, TFD_NONBLOCK | TFD_CLOEXEC)) {
if (timer_fd_ < 0) {
Expand Down Expand Up @@ -112,6 +113,23 @@ auto Timer::RemoveSingleTimer(Timer::SingleTimer *single_timer) noexcept -> bool
return false;
}

auto Timer::RefreshSingleTimer(Timer::SingleTimer *single_timer, uint64_t expire_from_now) noexcept -> bool {
std::unique_lock<std::mutex> lock(mtx_);
auto it = timer_queue_.find(single_timer);
if (it == timer_queue_.end()) {
return false;
}
auto new_timer = std::make_unique<SingleTimer>(expire_from_now, it->first->GetCallback());
timer_queue_.erase(it);
timer_queue_.emplace(new_timer.get(), std::move(new_timer));
uint64_t new_next_expire = NextExpireTime();
if (new_next_expire != next_expire_) {
next_expire_ = new_next_expire;
ResetTimerFd(timer_fd_, FromNowInTimeSpec(new_next_expire));
}
return true;
}

/* internal call only, no lock */
auto Timer::NextExpireTime() const noexcept -> uint64_t {
if (timer_queue_.empty()) {
Expand Down
2 changes: 2 additions & 0 deletions src/include/core/acceptor.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class Acceptor {

void BaseAcceptCallback(Connection *server_conn);

void BaseHandleCallback(Connection *client_conn);

void SetCustomAcceptCallback(std::function<void(Connection *)> custom_accept_callback);

void SetCustomHandleCallback(std::function<void(Connection *)> custom_handle_callback);
Expand Down
4 changes: 3 additions & 1 deletion src/include/core/looper.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ class Looper {

void AddConnection(std::unique_ptr<Connection> new_conn);

auto DeleteConnection(int fd) -> bool;
auto RefreshConnection(int fd) noexcept -> bool;

auto DeleteConnection(int fd) noexcept -> bool;

void SetExit() noexcept;

Expand Down
4 changes: 4 additions & 0 deletions src/include/core/timer.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ class Timer {

void Run() noexcept;

auto GetCallback() const noexcept -> std::function<void()>;

private:
uint64_t expire_time_;
std::function<void()> callback_{nullptr};
Expand All @@ -78,6 +80,8 @@ class Timer {

auto RemoveSingleTimer(SingleTimer *single_timer) noexcept -> bool;

auto RefreshSingleTimer(SingleTimer *single_timer, uint64_t expire_from_now) noexcept -> bool;

auto NextExpireTime() const noexcept -> uint64_t;

auto TimerCount() const noexcept -> size_t;
Expand Down
13 changes: 13 additions & 0 deletions test/core/timer_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,19 @@ TEST_CASE("[core/timer]") {
REQUIRE(t.TimerCount() == 3);
}

SECTION("timer queue is able to refresh ane existing timer") {
TURTLE_SERVER::Timer t;
REQUIRE(t.NextExpireTime() == 0);
auto now = TURTLE_SERVER::NowSinceEpoch();
auto raw_timer = t.AddSingleTimer(200, nullptr);
auto next_expire = t.NextExpireTime();
REQUIRE((next_expire < (now + 210) && next_expire > (now + 190)));
std::this_thread::sleep_for(std::chrono::milliseconds(200)); // should expired by now
t.RefreshSingleTimer(raw_timer, 200);
next_expire = t.NextExpireTime();
REQUIRE((next_expire < (now + 410) && next_expire > (now + 390)));
}

SECTION("timer queue is able to remove a timer based on raw pointer") {
TURTLE_SERVER::Timer t;
t.AddSingleTimer(200, nullptr);
Expand Down

0 comments on commit 6728bd4

Please sign in to comment.