Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix run tunnel helper #6292

Merged
merged 3 commits into from Aug 25, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
232 changes: 132 additions & 100 deletions src/tribler-core/run_tunnel_helper.py
Expand Up @@ -2,6 +2,7 @@
This script enables you to start a tunnel helper headless.
"""
import argparse
import asyncio
import logging
import os
import re
Expand All @@ -14,46 +15,87 @@
from ipv8.taskmanager import TaskManager

from tribler_common.simpledefs import NTFY

from tribler_core.components.interfaces.bandwidth_accounting import BandwidthAccountingComponent
from tribler_core.components.interfaces.ipv8 import Ipv8Component
from tribler_core.components.interfaces.masterkey import MasterKeyComponent
from tribler_core.components.interfaces.resource_monitor import ResourceMonitorComponent
from tribler_core.components.interfaces.restapi import RESTComponent
from tribler_core.components.interfaces.socks_configurator import SocksServersComponent
from tribler_core.components.interfaces.tunnels import TunnelsComponent
from tribler_core.components.interfaces.upgrade import UpgradeComponent
from tribler_core.config.tribler_config import TriblerConfig
from tribler_core.start_core import Session
from tribler_core.components.base import Session
from tribler_core.utilities.osutils import get_root_state_directory
from tribler_core.utilities.path_util import Path

logger = logging.getLogger(__name__)


class PortAction(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
if not 0 < values < 2**16:
raise argparse.ArgumentError(self, "Invalid port number")
setattr(namespace, self.dest, values)

interfaces = [
MasterKeyComponent,
UpgradeComponent,
RESTComponent,
Ipv8Component,
ResourceMonitorComponent,
BandwidthAccountingComponent,
SocksServersComponent,
TunnelsComponent,
]

class IPAction(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
try:
inet_aton(values)
except:
raise argparse.ArgumentError(self, "Invalid IPv4 address")
setattr(namespace, self.dest, values)

def components_gen(config: TriblerConfig):
for interface in interfaces:
implementation = interface.make_implementation(config, True)
yield implementation

class IPPortAction(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
parsed = re.match(r"^([\d\.]+)\:(\d+)$", values)
if not parsed:
raise argparse.ArgumentError("Invalid address:port")

ip, port = parsed.group(1), int(parsed.group(2))
try:
inet_aton(ip)
except:
raise argparse.ArgumentError("Invalid server address")
def make_config(options) -> TriblerConfig:
# Determine ipv8 port
ipv8_port = options.ipv8_port
if ipv8_port == -1:
if "HELPER_INDEX" in os.environ and "HELPER_BASE" in os.environ:
base_port = int(os.environ["HELPER_BASE"])
ipv8_port = base_port + int(os.environ["HELPER_INDEX"]) * 5
else:
raise ValueError('ipv8_port option is not set, and HELPER_BASE/HELPER_INDEX env vars are not defined')

statedir = Path(os.path.join(get_root_state_directory(), "tunnel-%d") % ipv8_port)
config = TriblerConfig.load(file=statedir / 'triblerd.conf', state_dir=statedir)
config.tunnel_community.random_slots = options.random_slots
config.tunnel_community.competing_slots = options.competing_slots
config.torrent_checking.enabled = False
config.ipv8.enabled = True
config.libtorrent.enabled = False
config.ipv8.port = ipv8_port
config.ipv8.address = options.ipv8_address
config.dht.enabled = True
config.tunnel_community.exitnode_enabled = bool(options.exit)
config.popularity_community.enabled = False
config.tunnel_community.testnet = bool(options.testnet)
config.chant.enabled = False
config.bootstrap.enabled = False

if not options.no_rest_api:
https = bool(options.cert_file)
config.api.https_enabled = https
config.api.http_enabled = not https
config.api.key = options.api_key

api_port = options.restapi
if "HELPER_INDEX" in os.environ and "HELPER_BASE" in os.environ:
api_port = int(os.environ["HELPER_BASE"]) + 10000 + int(os.environ["HELPER_INDEX"])
if https:
config.api.https_port = api_port
config.api.put_path_as_relative('https_certfile', options.cert_file, config.state_dir)
else:
config.api.http_port = api_port
else:
config.api.https_enabled = False
config.api.http_enabled = False

if not (0 < port < 65535):
raise argparse.ArgumentError("Invalid server port")
setattr(namespace, self.dest, values)
if options.ipv8_bootstrap_override is not None:
config.ipv8.bootstrap_override = options.ipv8_bootstrap_override
return config


class TunnelHelperService(TaskManager):
Expand All @@ -73,15 +115,13 @@ def on_circuit_reject(self, reject_time, balance):
def tribler_started(self):
async def signal_handler(sig):
print(f"Received shut down signal {sig}") # noqa: T001
if not self._stopping:
self._stopping = True
await self.session.shutdown()
get_event_loop().stop()
await self.stop()

signal.signal(signal.SIGINT, lambda sig, _: ensure_future(signal_handler(sig)))
signal.signal(signal.SIGTERM, lambda sig, _: ensure_future(signal_handler(sig)))

self.register_task("bootstrap", self.session.tunnel_community.bootstrap, interval=30)
tunnel_community = TunnelsComponent.imp().community
self.register_task("bootstrap", tunnel_community.bootstrap, interval=30)

# Remove all logging handlers
root_logger = logging.getLogger()
Expand All @@ -90,85 +130,85 @@ async def signal_handler(sig):
root_logger.removeHandler(handler)
logging.getLogger().setLevel(logging.ERROR)

ipv8 = Ipv8Component.imp().ipv8
new_strategies = []
with self.session.ipv8.overlay_lock:
for strategy, target_peers in self.session.ipv8.strategies:
if strategy.overlay == self.session.tunnel_community:
with ipv8.overlay_lock:
for strategy, target_peers in ipv8.strategies:
if strategy.overlay == tunnel_community:
new_strategies.append((strategy, -1))
else:
new_strategies.append((strategy, target_peers))
self.session.ipv8.strategies = new_strategies
ipv8.strategies = new_strategies

def circuit_removed(self, circuit, additional_info):
self.session.ipv8.network.remove_by_address(circuit.peer.address)
ipv8 = Ipv8Component.imp().ipv8
ipv8.network.remove_by_address(circuit.peer.address)
if self.log_circuits:
with open(os.path.join(self.session.config.state_dir, "circuits.log"), 'a') as out_file:
duration = time.time() - circuit.creation_time
out_file.write("%d,%f,%d,%d,%s\n" % (circuit.circuit_id, duration, circuit.bytes_up, circuit.bytes_down,
additional_info))

async def start(self, options):
# Determine ipv8 port
ipv8_port = options.ipv8_port
if ipv8_port == -1 and "HELPER_INDEX" in os.environ and "HELPER_BASE" in os.environ:
base_port = int(os.environ["HELPER_BASE"])
ipv8_port = base_port + int(os.environ["HELPER_INDEX"]) * 5

statedir = Path(os.path.join(get_root_state_directory(), "tunnel-%d") % ipv8_port)
config = TriblerConfig.load(file=statedir / 'triblerd.conf', state_dir=statedir)
config.tunnel_community.socks5_listen_ports = []
config.tunnel_community.random_slots = options.random_slots
config.tunnel_community.competing_slots = options.competing_slots
config.torrent_checking.enabled = False
config.ipv8.enabled = True
config.libtorrent.enabled = False
config.ipv8.port = ipv8_port
config.ipv8.address = options.ipv8_address
config.dht.enabled = True
config.tunnel_community.exitnode_enabled = bool(options.exit)
config.popularity_community.enabled = False
config.tunnel_community.testnet = bool(options.testnet)
config.chant.enabled = False
config.bootstrap.enabled = False

if not options.no_rest_api:
https = bool(options.cert_file)
config.api.https_enabled = https
config.api.http_enabled = not https
config.api.key = options.api_key

api_port = options.restapi
if "HELPER_INDEX" in os.environ and "HELPER_BASE" in os.environ:
api_port = int(os.environ["HELPER_BASE"]) + 10000 + int(os.environ["HELPER_INDEX"])
if https:
config.api.https_port = api_port
config.api.put_path_as_relative('https_certfile', options.cert_file, config.state_dir)
else:
config.api.http_port = api_port
else:
config.api.https_enabled = False
config.api.http_enabled = False

if options.ipv8_bootstrap_override is not None:
config.ipv8.bootstrap_override = options.ipv8_bootstrap_override

self.session = Session(config)
config = make_config(options)
components = list(components_gen(config))
session = self.session = Session(config, components)
session.set_as_default()

self.log_circuits = options.log_circuits
self.session.notifier.add_observer(NTFY.TUNNEL_REMOVE, self.circuit_removed)
session.notifier.add_observer(NTFY.TUNNEL_REMOVE, self.circuit_removed)

await self.session.start()
await session.start()

if options.log_rejects:
# We set this after Tribler has started since the tunnel_community won't be available otherwise
self.session.tunnel_community.reject_callback = self.on_circuit_reject
with session:
if options.log_rejects:
tunnels_component = TunnelsComponent.imp()
tunnels_community = tunnels_component.community
# We set this after Tribler has started since the tunnel_community won't be available otherwise
tunnels_community.reject_callback = self.on_circuit_reject

self.tribler_started()

async def stop(self):
await self.shutdown_task_manager()
if self.session:
return self.session.shutdown()
if not self._stopping:
self._stopping = True
self.session.shutdown_event.set()
await self.shutdown_task_manager()
await self.session.shutdown()
get_event_loop().stop()


class PortAction(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
if not 0 < values < 2**16:
raise argparse.ArgumentError(self, "Invalid port number")
setattr(namespace, self.dest, values)


class IPAction(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
try:
inet_aton(values)
Copy link
Contributor

@qstokkink qstokkink Aug 25, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please consider using ipaddress.IPv4Address() instead of inet_aton(). This allows you to only catch AddressValueError exceptions (instead of general exceptions) and gives you a message that shows exactly what is wrong with the given IP address.

EDIT: this should probably not be part of this PR though.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, will made another PR with this fix

except:
raise argparse.ArgumentError(self, "Invalid IPv4 address")
setattr(namespace, self.dest, values)


class IPPortAction(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
parsed = re.match(r"^([\d\.]+)\:(\d+)$", values)
if not parsed:
raise argparse.ArgumentError("Invalid address:port")

ip, port = parsed.group(1), int(parsed.group(2))
try:
inet_aton(ip)
except:
raise argparse.ArgumentError("Invalid server address")

if not (0 < port < 65535):
raise argparse.ArgumentError("Invalid server port")
setattr(namespace, self.dest, values)


def main(argv):
Expand Down Expand Up @@ -198,14 +238,6 @@ def main(argv):
coro = service.start(args)
ensure_future(coro)

if sys.platform == 'win32':
# Unfortunately, this is needed on Windows for Ctrl+C to work consistently.
# Should no longer be needed in Python 3.8.
async def wakeup():
while True:
await sleep(1)
ensure_future(wakeup())

loop.run_forever()


Expand Down