-
Notifications
You must be signed in to change notification settings - Fork 444
/
session.py
132 lines (102 loc) · 5.22 KB
/
session.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
from __future__ import annotations
import logging
import os
import sys
from asyncio import Event, create_task, gather, get_event_loop
from pathlib import Path
from typing import Dict, List, Optional, Type, TypeVar
from tribler.core.components.component import Component, ComponentError, ComponentStartupException, \
MultipleComponentsFound
from tribler.core.config.tribler_config import TriblerConfig
from tribler.core.utilities.crypto_patcher import patch_crypto_be_discovery
from tribler.core.utilities.install_dir import get_lib_path
from tribler.core.utilities.network_utils import default_network_utils
from tribler.core.utilities.notifier import Notifier
from tribler.core.utilities.simpledefs import STATEDIR_CHANNELS_DIR, STATEDIR_DB_DIR
class SessionError(Exception):
    """Exception type for errors related to a :class:`Session`.

    It adds nothing beyond :class:`Exception`. The visible code in this module never raises it.
    """
class Session:
    """Owns a set of components and drives their start/stop lifecycle.

    Components are registered by class. ``start_components`` starts them all concurrently and
    ``shutdown`` stops them all concurrently. The session can also be used as an async context
    manager: entering starts the components and exiting shuts them down.
    """

    # First exception reported through ``set_startup_exception``; later reports are ignored.
    _startup_exception: Optional[Exception] = None

    def __init__(self, config: TriblerConfig = None, components: List[Component] = (),
                 shutdown_event: Event = None, notifier: Notifier = None, failfast: bool = True):
        """Create a session and register the given components.

        Args:
            config: Tribler configuration. A default ``TriblerConfig()`` is created when omitted.
            components: Component instances. Each is registered under its own class.
            shutdown_event: Event that is set when a startup failure requires Tribler to stop.
                A fresh ``Event`` is created when omitted.
            notifier: Notifier instance. When omitted, one is created for the current event loop.
            failfast: When False, ``start_components`` gathers component start-ups with
                ``return_exceptions=True``, so one component's exception does not propagate out of it.

        Raises:
            ComponentError: if two components of the same class are passed.
        """
        self.exit_code = None
        self.failfast = failfast
        self.logger = logging.getLogger(self.__class__.__name__)
        self.config: TriblerConfig = config or TriblerConfig()
        self.shutdown_event: Event = shutdown_event or Event()
        self.notifier: Notifier = notifier or Notifier(loop=get_event_loop())
        self.components: Dict[Type[Component], Component] = {}
        for component in components:
            self.register(component.__class__, component)

        # Reserve various (possibly) fixed ports to prevent
        # components from occupying those accidentally.
        # Read from self.config: the ``config`` argument may be None when the default config is used.
        reserve_ports([self.config.libtorrent.port,
                       self.config.api.http_port,
                       self.config.api.https_port,
                       self.config.ipv8.port])

    async def __aenter__(self):
        await self.start_components()
        return self

    async def __aexit__(self, *_):
        await self.shutdown()

    def get_instance(self, comp_cls: Type[T]) -> Optional[T]:
        """Return the component registered for ``comp_cls``, or for a unique subclass of it.

        Returns:
            The exact match if ``comp_cls`` itself is registered. Otherwise, the single
            registered component whose class is a subclass of ``comp_cls``. Otherwise None.

        Raises:
            MultipleComponentsFound: if there is no exact match and several registered
                classes are subclasses of ``comp_cls``.
        """
        # Check membership, not the component's truthiness: a registered component must
        # win even if it happens to evaluate as falsy.
        if comp_cls in self.components:
            return self.components[comp_cls]

        # Fall back to a subclass match.
        candidates = {c for c in self.components if issubclass(c, comp_cls)}
        if not candidates:
            return None
        if len(candidates) >= 2:
            raise MultipleComponentsFound(comp_cls, candidates)
        return self.components[candidates.pop()]

    def register(self, comp_cls: Type[Component], component: Component):
        """Register ``component`` under ``comp_cls`` and attach this session to it.

        Raises:
            ComponentError: if ``comp_cls`` is already registered in this session.
        """
        if comp_cls in self.components:
            raise ComponentError(f'Component class {comp_cls.__name__} is already registered in session {self}')
        self.components[comp_cls] = component
        component.session = self

    async def start_components(self):
        """Prepare the environment and start all registered components concurrently.

        This creates the state directory structure, patches crypto backend discovery and, on macOS,
        points SSL_CERT_FILE at the bundled root certificates. If a startup exception was recorded
        while starting, it is re-raised in a separate task.
        """
        self.logger.info('Start components...')
        self.logger.info(f'State directory: "{self.config.state_dir}"')
        create_state_directory_structure(self.config.state_dir)
        patch_crypto_be_discovery()
        # On Mac, we bundle the root certificate for the SSL validation since Twisted is not using the root
        # certificates provided by the system trust store.
        if sys.platform == 'darwin':
            os.environ['SSL_CERT_FILE'] = str(get_lib_path() / 'root_certs_mac.pem')

        coros = [comp.start() for comp in self.components.values()]
        await gather(*coros, return_exceptions=not self.failfast)

        if self._startup_exception:
            self._reraise_startup_exception_in_separate_task()

    def _reraise_startup_exception_in_separate_task(self):
        """Schedule a task on the current event loop that re-raises the recorded startup exception.

        If the exception is a ComponentStartupException whose component has
        ``tribler_should_stop_on_component_error`` set, the task first sets ``exit_code`` to 1
        and sets ``shutdown_event``.
        """
        self.logger.info('Reraise startup exception in separate task')

        async def exception_reraiser():
            self.logger.info('Exception reraiser')
            e = self._startup_exception
            if isinstance(e, ComponentStartupException) and e.component.tribler_should_stop_on_component_error:
                self.logger.info('Shutdown with exit code 1')
                self.exit_code = 1
                self.shutdown_event.set()

            # the exception should be intercepted by event loop exception handler
            self.logger.info(f'Reraise startup exception: {self._startup_exception}')
            raise self._startup_exception

        # The task is deliberately not stored. The design relies on the loop's exception handler
        # reporting the exception once the finished task is released.
        get_event_loop().create_task(exception_reraiser())

    def set_startup_exception(self, exc: Exception):
        """Record ``exc`` as the startup exception unless one has already been recorded."""
        if not self._startup_exception:
            self._startup_exception = exc

    async def shutdown(self):
        """Stop all registered components concurrently and wait for all of them to finish."""
        self.logger.info("Stopping components")
        await gather(*[create_task(component.stop()) for component in self.components.values()])
        self.logger.info("All components are stopped")
# Type variable bound to Component. It lets ``Session.get_instance(comp_cls: Type[T]) -> Optional[T]``
# declare that it returns an instance of the class it was asked for. It is defined after Session;
# this works because ``from __future__ import annotations`` makes annotations lazy.
T = TypeVar('T', bound='Component')
def create_state_directory_structure(state_dir: Path):
    """Ensure the state directory and its database and channels subdirectories exist.

    Missing parents of ``state_dir`` are created as well. Directories that already
    exist are left untouched.
    """
    state_dir.mkdir(parents=True, exist_ok=True)
    for sub_dir_name in (STATEDIR_DB_DIR, STATEDIR_CHANNELS_DIR):
        sub_dir = state_dir / sub_dir_name
        sub_dir.mkdir(exist_ok=True)
def reserve_ports(ports_list: List[Optional[int]]):
    """Pass every configured port in ``ports_list`` to ``default_network_utils.remember``.

    The caller uses this to keep fixed ports from being taken accidentally by other components.

    Args:
        ports_list: Port numbers. ``None`` entries (ports that are not configured) are skipped.
    """
    for port in ports_list:
        if port is not None:
            default_network_utils.remember(port)