diff --git a/src/bthread/key.cpp b/src/bthread/key.cpp index 8f850f7f63..433e1e1409 100644 --- a/src/bthread/key.cpp +++ b/src/bthread/key.cpp @@ -22,6 +22,7 @@ #include #include "butil/macros.h" #include "butil/atomicops.h" +#include "butil/thread_key.h" #include "bvar/passive_status.h" #include "bthread/errno.h" // EAGAIN #include "bthread/task_group.h" // TaskGroup @@ -204,14 +205,56 @@ class BAIDU_CACHELINE_ALIGNMENT KeyTable { SubKeyTable* _subs[KEY_1STLEVEL_SIZE]; }; +struct KeyTableList { + KeyTableList() { + keytable = NULL; + } + ~KeyTableList() { + bthread::TaskGroup* g = bthread::tls_task_group; + bthread::KeyTable* old_kt = bthread::tls_bls.keytable; + while (keytable) { + bthread::KeyTable* kt = keytable; + keytable = kt->next; + bthread::tls_bls.keytable = kt; + if (g) { + g->current_task()->local_storage.keytable = kt; + } + delete kt; + if (old_kt == kt) { + old_kt = NULL; + } + g = bthread::tls_task_group; + } + bthread::tls_bls.keytable = old_kt; + if(g) { + g->current_task()->local_storage.keytable = old_kt; + } + } + KeyTable* keytable; +}; + static KeyTable* borrow_keytable(bthread_keytable_pool_t* pool) { - if (pool != NULL && pool->free_keytables) { - BAIDU_SCOPED_LOCK(pool->mutex); - KeyTable* p = (KeyTable*)pool->free_keytables; - if (p) { - pool->free_keytables = p->next; + if (pool != NULL && (pool->list || pool->free_keytables)) { + KeyTable* p; + pthread_rwlock_rdlock(&pool->rwlock); + auto list = (butil::ThreadLocal*)pool->list; + if (list && list->get()->keytable) { + p = list->get()->keytable; + list->get()->keytable = p->next; + pthread_rwlock_unlock(&pool->rwlock); return p; } + pthread_rwlock_unlock(&pool->rwlock); + if (pool->free_keytables) { + pthread_rwlock_wrlock(&pool->rwlock); + p = (KeyTable*)pool->free_keytables; + if (p) { + pool->free_keytables = p->next; + pthread_rwlock_unlock(&pool->rwlock); + return p; + } + pthread_rwlock_unlock(&pool->rwlock); + } } return NULL; } @@ -226,14 +269,16 @@ void return_keytable(bthread_keytable_pool_t* pool, KeyTable* kt) { delete kt; return; } - std::unique_lock mu(pool->mutex); + pthread_rwlock_rdlock(&pool->rwlock); if (pool->destroyed) { - mu.unlock(); + pthread_rwlock_unlock(&pool->rwlock); delete kt; return; } - kt->next = (KeyTable*)pool->free_keytables; - pool->free_keytables = kt; + auto list = (butil::ThreadLocal*)pool->list; + kt->next = list->get()->keytable; + list->get()->keytable = kt; + pthread_rwlock_unlock(&pool->rwlock); } static void cleanup_pthread(void* arg) { @@ -279,7 +324,8 @@ int bthread_keytable_pool_init(bthread_keytable_pool_t* pool) { LOG(ERROR) << "Param[pool] is NULL"; return EINVAL; } - pthread_mutex_init(&pool->mutex, NULL); + pthread_rwlock_init(&pool->rwlock, NULL); + pool->list = new butil::ThreadLocal(); pool->free_keytables = NULL; pool->destroyed = 0; return 0; @@ -291,16 +337,16 @@ int bthread_keytable_pool_destroy(bthread_keytable_pool_t* pool) { return EINVAL; } bthread::KeyTable* saved_free_keytables = NULL; - { - BAIDU_SCOPED_LOCK(pool->mutex); - if (pool->free_keytables) { - saved_free_keytables = (bthread::KeyTable*)pool->free_keytables; - pool->free_keytables = NULL; - } - pool->destroyed = 1; - } + pthread_rwlock_wrlock(&pool->rwlock); + pool->destroyed = 1; + delete (butil::ThreadLocal*)pool->list; + saved_free_keytables = (bthread::KeyTable*)pool->free_keytables; + pool->list = NULL; + pool->free_keytables = NULL; + pthread_rwlock_unlock(&pool->rwlock); + // Cheat get/setspecific and destroy the keytables. - bthread::TaskGroup* const g = bthread::tls_task_group; + bthread::TaskGroup* g = bthread::tls_task_group; bthread::KeyTable* old_kt = bthread::tls_bls.keytable; while (saved_free_keytables) { bthread::KeyTable* kt = saved_free_keytables; @@ -310,9 +356,7 @@ int bthread_keytable_pool_destroy(bthread_keytable_pool_t* pool) { g->current_task()->local_storage.keytable = kt; } delete kt; - if (old_kt == kt) { - old_kt = NULL; - } + g = bthread::tls_task_group; } bthread::tls_bls.keytable = old_kt; if (g) { @@ -330,11 +374,12 @@ int bthread_keytable_pool_getstat(bthread_keytable_pool_t* pool, LOG(ERROR) << "Param[pool] or Param[stat] is NULL"; return EINVAL; } - std::unique_lock mu(pool->mutex); + pthread_rwlock_rdlock(&pool->rwlock); size_t count = 0; bthread::KeyTable* p = (bthread::KeyTable*)pool->free_keytables; for (; p; p = p->next, ++count) {} stat->nfree = count; + pthread_rwlock_unlock(&pool->rwlock); return 0; } @@ -365,14 +410,15 @@ void bthread_keytable_pool_reserve(bthread_keytable_pool_t* pool, kt->set_data(key, data); } // else append kt w/o data. - std::unique_lock mu(pool->mutex); + pthread_rwlock_wrlock(&pool->rwlock); if (pool->destroyed) { - mu.unlock(); + pthread_rwlock_unlock(&pool->rwlock); delete kt; break; } kt->next = (bthread::KeyTable*)pool->free_keytables; pool->free_keytables = kt; + pthread_rwlock_unlock(&pool->rwlock); if (data == NULL) { break; } diff --git a/src/bthread/types.h b/src/bthread/types.h index d91b85aab3..4b4f0565f5 100644 --- a/src/bthread/types.h +++ b/src/bthread/types.h @@ -84,7 +84,8 @@ inline std::ostream& operator<<(std::ostream& os, bthread_key_t key) { #endif // __cplusplus typedef struct { - pthread_mutex_t mutex; + pthread_rwlock_t rwlock; + void* list; void* free_keytables; int destroyed; } bthread_keytable_pool_t; diff --git a/test/bthread_key_unittest.cpp b/test/bthread_key_unittest.cpp index 7b6bd8d807..c01ae7fe29 100644 --- a/test/bthread_key_unittest.cpp +++ b/test/bthread_key_unittest.cpp @@ -339,15 +339,18 @@ TEST(KeyTest, set_tls_before_creating_any_bthread) { struct PoolData { bthread_key_t key; - PoolData* expected_data; + PoolData* data; int seq; int end_seq; }; +bool use_same_keytable = false; + static void pool_thread_impl(PoolData* data) { - ASSERT_EQ(data->expected_data, (PoolData*)bthread_getspecific(data->key)); if (NULL == bthread_getspecific(data->key)) { ASSERT_EQ(0, bthread_setspecific(data->key, data)); + } else { + use_same_keytable = true; } }; @@ -385,19 +388,21 @@ TEST(KeyTest, using_pool) { ASSERT_EQ(0, bthread_start_urgent(&bth, &attr, pool_thread, &bth_data)); ASSERT_EQ(0, bthread_join(bth, NULL)); ASSERT_EQ(0, bth_data.seq); - ASSERT_EQ(1, bthread_keytable_pool_size(&pool)); - PoolData bth2_data = { key, &bth_data, 0, 3 }; + PoolData bth2_data = { key, NULL, 0, 3 }; bthread_t bth2; ASSERT_EQ(0, bthread_start_urgent(&bth2, &attr2, pool_thread, &bth2_data)); ASSERT_EQ(0, bthread_join(bth2, NULL)); ASSERT_EQ(0, bth2_data.seq); - ASSERT_EQ(1, bthread_keytable_pool_size(&pool)); ASSERT_EQ(0, bthread_keytable_pool_destroy(&pool)); - - EXPECT_EQ(bth_data.end_seq, bth_data.seq); - EXPECT_EQ(0, bth2_data.seq); + if (use_same_keytable) { + EXPECT_EQ(bth_data.end_seq, bth_data.seq); + EXPECT_EQ(0, bth2_data.seq); + } else { + EXPECT_EQ(bth_data.end_seq, bth_data.seq); + EXPECT_EQ(bth_data.end_seq, bth2_data.seq); + } ASSERT_EQ(0, bthread_key_delete(key)); } @@ -415,7 +420,7 @@ static void lid_dtor(void* tls) { static void lid_worker_impl(bthread_key_t key) { ASSERT_EQ(NULL, bthread_getspecific(key)); - ASSERT_EQ(0, bthread_setspecific(key, (void*)seq.fetch_add(1))); + ASSERT_EQ(0, bthread_setspecific(key, (void*)lid_seq.fetch_add(1))); } static void* lid_worker(void* arg) {