Skip to content

Commit

Permalink
Configure read replica connection + move all queries to use it (#225)
Browse files Browse the repository at this point in the history
  • Loading branch information
SidSethi committed Jan 21, 2020
1 parent 6718d79 commit d2fa621
Show file tree
Hide file tree
Showing 12 changed files with 55 additions and 35 deletions.
1 change: 1 addition & 0 deletions discovery-provider/.env
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ audius_web3_host=docker.for.mac.localhost
audius_web3_port=8545
audius_redis_url=redis://docker.for.mac.localhost:5379/00
audius_db_url=postgresql+psycopg2://postgres:postgres@docker.for.mac.localhost:5432/audius_discovery
audius_db_url_read_replica=postgresql+psycopg2://postgres:postgres@docker.for.mac.localhost:5432/audius_discovery
audius_ipfs_host=docker.for.mac.localhost
audius_ipfs_port=6001
audius_ipfs_gateway_hosts=https://cloudflare-ipfs.com,https://ipfs.io
Expand Down
1 change: 1 addition & 0 deletions discovery-provider/.env2
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ audius_web3_host=docker.for.mac.localhost
audius_web3_port=8545
audius_redis_url=redis://docker.for.mac.localhost:5380/00
audius_db_url=postgresql+psycopg2://postgres:postgres@docker.for.mac.localhost:5433/audius_discovery
audius_db_url_read_replica=postgresql+psycopg2://postgres:postgres@docker.for.mac.localhost:5433/audius_discovery
audius_ipfs_host=docker.for.mac.localhost
audius_ipfs_port=6001
audius_ipfs_gateway_hosts=https://cloudflare-ipfs.com,https://ipfs.io
Expand Down
1 change: 1 addition & 0 deletions discovery-provider/default_config.ini
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ url = redis://localhost:5379/0

[db]
url = postgresql+psycopg2://postgres@localhost/audius_discovery
url_read_replica = postgresql+psycopg2://postgres@localhost/audius_discovery
engine_args_literal = {
'pool_size': 10,
'max_overflow': 0,
Expand Down
14 changes: 7 additions & 7 deletions discovery-provider/scripts/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ source ./scripts/utilities.sh

function cd_contracts_repo {
# Navigate to contracts repository
if [ -d "../audius-contracts" ]
if [ -d "../contracts" ]
then
echo "Audius contracts repo is present"
cd ../audius-contracts/
cd ../contracts/
else
echo "INCORRECT REPOSITORY STRUCTURE. PLEASE FOLLOW README"
exit 1
Expand All @@ -18,10 +18,10 @@ function cd_contracts_repo {

function cd_discprov_repo {
# Navigate to discovery provider repository
if [ -d "../audius-discovery-provider" ]
if [ -d "../discovery-provider" ]
then
echo "Audius discprov repo is present"
cd ../audius-discovery-provider/
cd ../discovery-provider/
else
echo "INCORRECT REPOSITORY STRUCTURE. PLEASE FOLLOW README"
exit 1
Expand All @@ -37,9 +37,9 @@ set -e
python3 scripts/lint.py

# initialize virtual environment
rm -r venv
python3 -m venv venv
source venv/bin/activate
# rm -r venv
# python3 -m venv venv
# source venv/bin/activate
pip3 install -r requirements.txt
sleep 5
set +e
Expand Down
6 changes: 3 additions & 3 deletions discovery-provider/src/queries/health_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from web3 import HTTPProvider, Web3
from src.models import Block
from src.utils import helpers
from src.utils.db_session import get_db
from src.utils.db_session import get_db_read_replica
from src.utils.config import shared_config
from src.utils.redis_constants import latest_block_redis_key, latest_block_hash_redis_key

Expand All @@ -30,7 +30,7 @@

# Returns DB block state & diff
def _get_db_block_state(latest_blocknum, latest_blockhash):
db = get_db()
db = get_db_read_replica()
with db.scoped_session() as session:
# Fetch latest block from DB
db_block_query = session.query(Block).filter(Block.is_current == True).all()
Expand All @@ -55,7 +55,7 @@ def _get_db_block_state(latest_blocknum, latest_blockhash):

# Returns number of and info on open db connections
def _get_db_conn_state():
db = get_db()
db = get_db_read_replica()
with db.scoped_session() as session:
# Query number of open DB connections
num_connections = session.execute(
Expand Down
6 changes: 3 additions & 3 deletions discovery-provider/src/queries/notifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from src.queries import response_name_constants as const
from src.queries.query_helpers import get_repost_counts, get_save_counts, get_follower_count_dict
from src.models import Block, Follow, Save, SaveType, Playlist, Track, Repost, RepostType
from src.utils.db_session import get_db
from src.utils.db_session import get_db_read_replica
from sqlalchemy import desc, func

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -47,7 +47,7 @@ def get_owner_id(session, entity_type, entity_id):

@bp.route("/notifications", methods=("GET",))
def notifications():
db = get_db()
db = get_db_read_replica()
min_block_number = request.args.get("min_block_number", type=int)
max_block_number = request.args.get("max_block_number", type=int)

Expand Down Expand Up @@ -459,7 +459,7 @@ def notifications():

@bp.route("/milestones/followers", methods=("GET",))
def milestones_followers():
db = get_db()
db = get_db_read_replica()
if "user_id" not in request.args:
return api_helpers.error_response({'msg': 'Please provider user ids'}, 500)

Expand Down
30 changes: 15 additions & 15 deletions discovery-provider/src/queries/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from src import api_helpers, exceptions
from src.models import User, Track, Repost, RepostType, Follow, Playlist, Save, SaveType
from src.utils import helpers
from src.utils.db_session import get_db
from src.utils.db_session import get_db_read_replica
from src.queries import response_name_constants
from src.queries.query_helpers import get_current_user_id, parse_sort_param, populate_user_metadata, \
populate_track_metadata, populate_playlist_metadata, get_repost_counts, get_save_counts, \
Expand All @@ -29,7 +29,7 @@
@bp.route("/users", methods=("GET",))
def get_users():
users = []
db = get_db()
db = get_db_read_replica()
with db.scoped_session() as session:
# Create initial query
base_query = session.query(User)
Expand Down Expand Up @@ -84,7 +84,7 @@ def get_users():
@bp.route("/tracks", methods=("GET",))
def get_tracks():
tracks = []
db = get_db()
db = get_db_read_replica()
with db.scoped_session() as session:
# Create initial query
base_query = session.query(Track)
Expand Down Expand Up @@ -152,7 +152,7 @@ def get_tracks_including_unlisted():
for i in identifiers:
helpers.validate_arguments(i, ["handle", "id", "url_title"])

db = get_db()
db = get_db_read_replica()
with db.scoped_session() as session:
base_query = session.query(Track)
filter_cond = []
Expand Down Expand Up @@ -185,7 +185,7 @@ def get_playlists():
current_user_id = get_current_user_id(required=False)
filter_out_private_playlists = True

db = get_db()
db = get_db_read_replica()
with db.scoped_session() as session:
try:
playlist_query = (
Expand Down Expand Up @@ -260,7 +260,7 @@ def get_playlists():
@bp.route("/feed", methods=("GET",))
def get_feed():
feed_results = []
db = get_db()
db = get_db_read_replica()

# filter should be one of ["all", "reposts", "original"]
# empty filter value results in "all"
Expand Down Expand Up @@ -500,7 +500,7 @@ def get_feed():
@bp.route("/feed/reposts/<int:user_id>", methods=("GET",))
def get_repost_feed_for_user(user_id):
feed_results = {}
db = get_db()
db = get_db_read_replica()
with db.scoped_session() as session:
# query all reposts by user
repost_query = (
Expand Down Expand Up @@ -738,7 +738,7 @@ def get_repost_feed_for_user(user_id):
@bp.route("/users/intersection/follow/<int:followee_user_id>/<int:follower_user_id>", methods=("GET",))
def get_follow_intersection_users(followee_user_id, follower_user_id):
users = []
db = get_db()
db = get_db_read_replica()
with db.scoped_session() as session:
query = (
session.query(User)
Expand Down Expand Up @@ -787,7 +787,7 @@ def get_follow_intersection_users(followee_user_id, follower_user_id):
@bp.route("/users/intersection/repost/track/<int:repost_track_id>/<int:follower_user_id>", methods=("GET",))
def get_track_repost_intersection_users(repost_track_id, follower_user_id):
users = []
db = get_db()
db = get_db_read_replica()
with db.scoped_session() as session:
# ensure track_id exists
track_entry = session.query(Track).filter(
Expand Down Expand Up @@ -833,7 +833,7 @@ def get_track_repost_intersection_users(repost_track_id, follower_user_id):
@bp.route("/users/intersection/repost/playlist/<int:repost_playlist_id>/<int:follower_user_id>", methods=("GET",))
def get_playlist_repost_intersection_users(repost_playlist_id, follower_user_id):
users = []
db = get_db()
db = get_db_read_replica()
with db.scoped_session() as session:
# ensure playlist_id exists
playlist_entry = session.query(Playlist).filter(
Expand Down Expand Up @@ -876,7 +876,7 @@ def get_playlist_repost_intersection_users(repost_playlist_id, follower_user_id)
@bp.route("/users/followers/<int:followee_user_id>", methods=("GET",))
def get_followers_for_user(followee_user_id):
users = []
db = get_db()
db = get_db_read_replica()
with db.scoped_session() as session:
# correlated subquery sqlalchemy code:
# https://groups.google.com/forum/#!topic/sqlalchemy/WLIy8jxD7qg
Expand Down Expand Up @@ -948,7 +948,7 @@ def get_followers_for_user(followee_user_id):
@bp.route("/users/followees/<int:follower_user_id>", methods=("GET",))
def get_followees_for_user(follower_user_id):
users = []
db = get_db()
db = get_db_read_replica()
with db.scoped_session() as session:
# correlated subquery sqlalchemy code:
# https://groups.google.com/forum/#!topic/sqlalchemy/WLIy8jxD7qg
Expand Down Expand Up @@ -1015,7 +1015,7 @@ def get_followees_for_user(follower_user_id):
@bp.route("/users/reposts/track/<int:repost_track_id>", methods=("GET",))
def get_reposters_for_track(repost_track_id):
user_results = []
db = get_db()
db = get_db_read_replica()
with db.scoped_session() as session:
# Ensure Track exists for provided repost_track_id.
track_entry = session.query(Track).filter(
Expand Down Expand Up @@ -1079,7 +1079,7 @@ def get_reposters_for_track(repost_track_id):
@bp.route("/users/reposts/playlist/<int:repost_playlist_id>", methods=("GET",))
def get_reposters_for_playlist(repost_playlist_id):
user_results = []
db = get_db()
db = get_db_read_replica()
with db.scoped_session() as session:
# Ensure Playlist exists for provided repost_playlist_id.
playlist_entry = session.query(Playlist).filter(
Expand Down Expand Up @@ -1284,7 +1284,7 @@ def get_saves(save_type):

save_results = []
current_user_id = get_current_user_id()
db = get_db()
db = get_db_read_replica()
with db.scoped_session() as session:
query = (
session.query(Save)
Expand Down
6 changes: 3 additions & 3 deletions discovery-provider/src/queries/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from src.models import User, Track, RepostType, Playlist, Save, SaveType, Follow
from src.utils import helpers
from src.utils.config import shared_config
from src.utils.db_session import get_db
from src.utils.db_session import get_db_read_replica
from src.queries import response_name_constants

from src.queries.query_helpers import get_current_user_id, populate_user_metadata, \
Expand Down Expand Up @@ -76,7 +76,7 @@ def search_tags():

(limit, offset) = get_pagination_vars()
like_tags_str = str.format('%{}%', search_str)
db = get_db()
db = get_db_read_replica()
with db.scoped_session() as session:
if (searchKind in [SearchKind.all, SearchKind.tracks]):
track_res = sqlalchemy.text(
Expand Down Expand Up @@ -331,7 +331,7 @@ def search(isAutocomplete):

results = {}
if searchStr:
db = get_db()
db = get_db_read_replica()
with db.scoped_session() as session:
# Set similarity threshold to be used by % operator in queries.
session.execute(sqlalchemy.text(f"select set_limit({minSearchSimilarity});"))
Expand Down
4 changes: 2 additions & 2 deletions discovery-provider/src/queries/trending.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from src import api_helpers
from src.models import User, Track, RepostType, Follow, SaveType
from src.utils.db_session import get_db
from src.utils.db_session import get_db_read_replica
from src.utils.config import shared_config
from src.queries.query_helpers import get_pagination_vars
from src.tasks.generate_trending import generate_trending, trending_cache_hits_key, \
Expand Down Expand Up @@ -46,5 +46,5 @@ def trending(time):
# Increment cache miss count
REDIS.incr(trending_cache_miss_key, 1)
# Recalculate trending values if necessary
final_resp = generate_trending(get_db(), time, genre, limit, offset)
final_resp = generate_trending(get_db_read_replica(), time, genre, limit, offset)
return api_helpers.success_response(final_resp)
5 changes: 4 additions & 1 deletion discovery-provider/src/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ def parse_flask_section(self):
for item in self.items(section_name):
self._load_item(section_name, item[0])

# Set db_read_replica url to same as db url if none provided
if ('url_read_replica' not in current_app.config['db']) or (not current_app.config['db']['url_read_replica']):
current_app.config['db']['url_read_replica'] = current_app.config['db']['url']

def _load_item(self, section_name, key):
"""Load the specified item from the [flask] section. Type is
determined by the type of the equivalent value in app.default_config
Expand All @@ -87,7 +91,6 @@ def _load_item(self, section_name, key):
current_app.config[section_name][key] = str(self.get(section_name, key))
env_config_update(current_app.config, section_name, key)


shared_config = configparser.ConfigParser()
shared_config.read(config_files)

Expand Down
14 changes: 14 additions & 0 deletions discovery-provider/src/utils/db_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,20 @@ def get_db():
return g.db


def get_db_read_replica():
    """Return the SessionManager bound to the read-replica database.

    The manager is cached on ``g`` under ``db_read_replica``: the first
    call builds it from the ``url_read_replica`` and ``engine_args_literal``
    entries of the app's ``db`` config section, and later calls that see
    the same ``g`` get that same instance back.
    """
    # Reuse the manager if one has already been attached to ``g``.
    if "db_read_replica" in g:
        return g.db_read_replica

    db_settings = current_app.config["db"]
    replica_url = db_settings["url_read_replica"]
    # Engine args are stored in config as a Python-literal string.
    engine_args = ast.literal_eval(db_settings["engine_args_literal"])

    g.db_read_replica = SessionManager(replica_url, engine_args)
    return g.db_read_replica


class SessionManager:
def __init__(self, db_url, db_engine_args):
self._engine = create_engine(db_url, **db_engine_args)
Expand Down
2 changes: 1 addition & 1 deletion discovery-provider/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"connect_args": {"options": "-c timezone=utc"},}'

TEST_CONFIG_OVERRIDE = {
"db": {"url": DB_URL, "engine_args_literal": ENGINE_ARGS_LITERAL}
"db": {"url": DB_URL, "url_read_replica": DB_URL, "engine_args_literal": ENGINE_ARGS_LITERAL}
}


Expand Down

0 comments on commit d2fa621

Please sign in to comment.