Skip to content

Commit

Permalink
Feature: redis support as database backend
Browse files Browse the repository at this point in the history
  • Loading branch information
TrueBrain committed Jul 10, 2021
1 parent 6539aea commit d6bc667
Show file tree
Hide file tree
Showing 8 changed files with 197 additions and 24 deletions.
4 changes: 3 additions & 1 deletion master_server/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from openttd_helpers.sentry_helper import click_sentry

from .database.dynamodb import click_database_dynamodb
from .database.redis import click_database_redis
from .openttd.udp import click_proxy_protocol

log = logging.getLogger(__name__)
Expand All @@ -27,11 +28,12 @@
)
@click.option(
"--db",
type=click.Choice(["dynamodb"], case_sensitive=False),
type=click.Choice(["dynamodb", "redis"], case_sensitive=False),
required=True,
callback=click_helper.import_module("master_server.database", "Database"),
)
@click_database_dynamodb
@click_database_redis
@click_proxy_protocol
def main(bind, msu_port, web_port, app, db):
database = db()
Expand Down
26 changes: 13 additions & 13 deletions master_server/application/master_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ async def check_stale_servers(self):
log.exception("Exception during check on stale servers")
return

def _get_next_session_key(self):
async def _get_next_session_key(self):
# |63 56 48 40 32 24 16 8 0|
# |--------|--------|--------|--------|--------|--------|--------|--------|
# Version 1 | unused | port | ip |
Expand All @@ -128,24 +128,24 @@ def _get_next_session_key(self):
# and token.
session_key = (int(time.time()) << 24) | (self._session_counter << 8)

self.database.store_session_key_token(session_key, token)
await self.database.store_session_key_token(session_key, token)
return session_key, token

def receive_PACKET_UDP_SERVER_REGISTER(self, source, port, session_key):
async def receive_PACKET_UDP_SERVER_REGISTER(self, source, port, session_key):
# session_key of None means it was version 1.
if session_key is None:
# To be able to use session-keys as an unique ID, also
# generate a session-key for version 1, but based on
# static information of the server.
session_key = int(source.ip) | (port << 32)
self.database.store_session_key_token(session_key, 0)
await self.database.store_session_key_token(session_key, 0)
elif session_key == 0:
# Session-keys were introduced in version 2.
# This session-key tracks the same server over multiple IPs.
# On first contact with the Master Server, a session-key is send
# back to the server. This session-key is reused for any further
# announcement, also on other IPs.
session_key, token = self._get_next_session_key()
session_key, token = await self._get_next_session_key()
source.protocol.send_PACKET_UDP_MASTER_SESSION_KEY(source.addr, session_key | token)

# We don't query the server for now; we first let the server
Expand All @@ -158,15 +158,15 @@ def receive_PACKET_UDP_SERVER_REGISTER(self, source, port, session_key):
token = session_key & 0xFF
session_key = (session_key >> 8) << 8

if not self.database.check_session_key_token(session_key, token):
if not await self.database.check_session_key_token(session_key, token):
log.info("Invalid session-key token from %s:%d; transmitting new session-key", source.ip, source.port)

# TODO -- If an IP has this wrong for more than 3 times, it is
# time to put that IP on a ban-list for a bit of time.

# Send the server a new session-key, as clearly he got a bit
# confused.
session_key, token = self._get_next_session_key()
session_key, token = await self._get_next_session_key()
source.protocol.send_PACKET_UDP_MASTER_SESSION_KEY(source.addr, session_key | token)
return

Expand All @@ -187,7 +187,7 @@ def receive_PACKET_UDP_SERVER_REGISTER(self, source, port, session_key):
# server UDP port).
self.query_server(source.ip, port, source.protocol, user_data=(session_key, source.addr))

def receive_PACKET_UDP_SERVER_RESPONSE(self, source, **info):
async def receive_PACKET_UDP_SERVER_RESPONSE(self, source, **info):
response = self.query_server_response(source.ip, source.port)
if response is None:
return
Expand All @@ -204,19 +204,19 @@ def receive_PACKET_UDP_SERVER_RESPONSE(self, source, **info):
break
else:
# This server can now be marked as online.
if not self.database.server_online(session_key, source.ip, source.port, info):
if not await self.database.server_online(session_key, source.ip, source.port, info):
return

# Inform the server that he is now registered.
source.protocol.send_PACKET_UDP_MASTER_ACK_REGISTER(register_addr)

def receive_PACKET_UDP_SERVER_UNREGISTER(self, source, port):
self.database.server_offline(source.ip, port)
async def receive_PACKET_UDP_SERVER_UNREGISTER(self, source, port):
await self.database.server_offline(source.ip, port)

def receive_PACKET_UDP_CLIENT_GET_LIST(self, source, slt):
async def receive_PACKET_UDP_CLIENT_GET_LIST(self, source, slt):
# Fetching all the servers is pretty expensive, so rate limit how often we do this.
if self._servers_cache[slt] is None or time.time() > self._servers_cache[slt]["expire"]:
servers = self.database.get_server_list_for_client(slt == SLTType.SLT_IPv6)
servers = await self.database.get_server_list_for_client(slt == SLTType.SLT_IPv6)

self._servers_cache[slt] = {
"servers": servers,
Expand Down
4 changes: 2 additions & 2 deletions master_server/application/web_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ async def healthz_handler(request):
async def server_list(request):
if request.app.server_list_cache is None or time.time() > request.app.server_list_cache["expire"]:
request.app.server_list_cache = {
"servers": request.app.database.get_server_list_for_web(),
"servers": await request.app.database.get_server_list_for_web(),
"expire": time.time() + TIME_SERVER_LIST_CACHE,
}

Expand All @@ -57,7 +57,7 @@ async def server_entry(request):
or time.time() > request.app.server_entry_cache[server_id]["expire"]
):
request.app.server_entry_cache[server_id] = {
"server": request.app.database.get_server_info_for_web(server_id),
"server": await request.app.database.get_server_info_for_web(server_id),
"expire": time.time() + TIME_SERVER_ENTRY_CACHE,
}

Expand Down
14 changes: 7 additions & 7 deletions master_server/database/dynamodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,18 +112,18 @@ def __init__(self):
if not model.exists():
model.create_table(wait=True)

def check_session_key_token(self, session_key, token):
async def check_session_key_token(self, session_key, token):
try:
server = Server.get(session_key)
except Server.DoesNotExist:
return False
return server.token == token

def store_session_key_token(self, session_key, token):
async def store_session_key_token(self, session_key, token):
server = Server(session_key, token=token, ttl=timedelta(seconds=TTL))
server.save()

def server_online(self, session_key, server_ip, server_port, info):
async def server_online(self, session_key, server_ip, server_port, info):
server_id = _get_server_id(server_ip, server_port)
self._update_ip_port(server_id, session_key)

Expand Down Expand Up @@ -164,7 +164,7 @@ def server_online(self, session_key, server_ip, server_port, info):

return True

def server_offline(self, server_ip, server_port):
async def server_offline(self, server_ip, server_port):
server_id = _get_server_id(server_ip, server_port)

# Lookup the session-key based on the ip/port.
Expand All @@ -189,7 +189,7 @@ def server_offline(self, server_ip, server_port):
]
)

def get_server_list_for_client(self, ipv6_list):
async def get_server_list_for_client(self, ipv6_list):
server_list = []

for ip_port in IpPort.online_view.query(True):
Expand All @@ -205,7 +205,7 @@ def get_server_list_for_client(self, ipv6_list):

return server_list

def get_server_info_for_web(self, server_id):
async def get_server_info_for_web(self, server_id):
try:
ip_port = IpPort.get(server_id)
except IpPort.DoesNotExist:
Expand All @@ -218,7 +218,7 @@ def get_server_info_for_web(self, server_id):

return _convert_server_to_dict(server)

def get_server_list_for_web(self):
async def get_server_list_for_web(self):
return [_convert_server_to_dict(server) for server in Server.online_view.query(True)]

def check_stale_servers(self):
Expand Down
162 changes: 162 additions & 0 deletions master_server/database/redis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
import aioredis
import click
import hashlib
import ipaddress
import json
import logging

from openttd_helpers import click_helper

from .interface import DatabaseInterface

log = logging.getLogger(__name__)

_redis_url = None

# Servers should announce every 15 minutes, so if we haven't seen a server
# after 20 minutes, we can assume it is no longer running.
TTL = 60 * 20


def md5sum(value):
return hashlib.md5(value.encode()).digest().hex()


def _get_server_id(server_ip, server_port):
if isinstance(server_ip, ipaddress.IPv6Address):
return md5sum(f"[{server_ip}]:{server_port}")
else:
return md5sum(f"{server_ip}:{server_port}")


class Database(DatabaseInterface):
def __init__(self):
self._redis = aioredis.from_url(_redis_url, decode_responses=True)

async def check_session_key_token(self, session_key, token):
ms_token = await self._redis.get(f"ms-session-key:{session_key}")
if ms_token is None:
return False

if ms_token != str(token):
return False

await self._redis.expire(f"ms-session-key:{session_key}", TTL)
return True

async def store_session_key_token(self, session_key, token):
await self._redis.set(f"ms-session-key:{session_key}", token, ex=TTL)

async def server_online(self, session_key, server_ip, server_port, info):
# Don't accept servers with empty revision or name.
if info["openttd_version"] == "" or info["name"] == "":
return False

# server_offline() doesn't get the session_key, so we need a reverse lookup.
await self._redis.set(f"ms-session-id:{server_ip}:{server_port}", session_key, ex=TTL)

info["game_type"] = 1 # Public

# Update the information of this server.
info_str = json.dumps(info)
await self._redis.set(f"gc-server:{session_key}", info_str, ex=TTL)
await self._redis.xadd("gc-stream", {"gc-id": -1, "update": session_key, "info": info_str}, approximate=1000)

# Track this IP based on the session_key.
if isinstance(server_ip, ipaddress.IPv6Address):
type = "ipv6"
else:
type = "ipv4"
server_str = json.dumps({"ip": str(server_ip), "port": server_port})
if await self._redis.set(f"gc-direct-{type}:{session_key}", server_str, ex=TTL) > 0:
await self._redis.xadd(
"gc-stream", {"gc-id": -1, f"new-direct-{type}": session_key, "server": server_str}, approximate=1000
)

return True

async def server_offline(self, server_ip, server_port):
# Find the session-key of this ip:port combination.
session_key = await self._redis.get(f"ms-session-id:{server_ip}:{server_port}")
if session_key is None:
return
await self._redis.delete(f"ms-session-id:{server_ip}:{server_port}")

await self._redis.delete(f"gc-direct-ipv4:{session_key}")
await self._redis.delete(f"gc-direct-ipv6:{session_key}")

# Delete this server.
if await self._redis.delete(f"gc-server:{session_key}") > 0:
await self._redis.xadd("gc-stream", {"gc-id": -1, "delete": session_key}, approximate=1000)

async def get_server_list_for_client(self, ipv6_list):
if ipv6_list:
type = "ipv6"
ipcls = ipaddress.IPv6Address
else:
type = "ipv4"
ipcls = ipaddress.IPv4Address

server_list = []
direct_ips = await self._redis.keys(f"gc-direct-{type}:*")
for direct_ip_key in direct_ips:
direct_ip_str = await self._redis.get(direct_ip_key)
direct_ip = json.loads(direct_ip_str)
direct_ip["ip"] = ipcls(direct_ip["ip"])
server_list.append(direct_ip)

return server_list

async def get_server_info_for_web(self, server_id):
info_str = await self._redis.get(f"gc-server:{server_id}")
entry = {
"info": json.loads(info_str),
"online": True,
}

direct_ipv4_str = await self._redis.get(f"gc-direct-ipv4:{server_id}")
if direct_ipv4_str:
direct_ipv4 = json.loads(direct_ipv4_str)
entry["ipv4"] = {
"ip": direct_ipv4["ip"],
"port": direct_ipv4["port"],
"server_id": _get_server_id(direct_ipv4["ip"], direct_ipv4["port"]),
}

direct_ipv6_str = await self._redis.get(f"gc-direct-ipv6:{server_id}")
if direct_ipv6_str:
direct_ipv6 = json.loads(direct_ipv6_str)
entry["ipv6"] = {
"ip": direct_ipv6["ip"],
"port": direct_ipv6["port"],
"server_id": direct_ipv6(direct_ipv4["ip"], direct_ipv6["port"]),
}

return entry

async def get_server_list_for_web(self):
server_list = []

servers = await self._redis.keys("gc-server:*")
for server_key in servers:
_, _, server_id = server_key.partition(":")
entry = self.get_server_info_for_web(server_id)
server_list.append(entry)

return server_list

def check_stale_servers(self):
# Redis takes care of this for us.
pass


@click_helper.extend
@click.option(
"--redis-url",
help="URL of the redis server.",
default="redis://localhost",
)
def click_database_redis(redis_url):
global _redis_url

_redis_url = redis_url
8 changes: 7 additions & 1 deletion master_server/openttd/udp.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,12 @@ def _detect_source_ip_port(self, socket_addr, data):

return source, data

async def guard(self, coro):
try:
await coro
except Exception:
log.exception("Error while processing packet")

def datagram_received(self, data, socket_addr):
try:
source, data = self._detect_source_ip_port(socket_addr, data)
Expand All @@ -90,7 +96,7 @@ def datagram_received(self, data, socket_addr):
log.info("Dropping invalid packet from %r: %r", socket_addr, err)
return

getattr(self._callback, f"receive_{type.name}")(source, **kwargs)
asyncio.create_task(self.guard(getattr(self._callback, f"receive_{type.name}")(source, **kwargs)))

def error_received(self, exc):
print("error on socket: ", exc)
Expand Down
1 change: 1 addition & 0 deletions requirements.base
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
aiohttp
aioredis
click
sentry-sdk
openttd-helpers
Expand Down
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
aiohttp==3.7.4.post0
aioredis==2.0.0a1
async-timeout==3.0.1
attrs==21.2.0
botocore==1.20.103
certifi==2021.5.30
chardet==4.0.0
click==8.0.1
docutils==0.17.1
hiredis==2.0.0
idna==3.2
jmespath==0.10.0
multidict==5.1.0
Expand Down

0 comments on commit d6bc667

Please sign in to comment.