Skip to content

Commit

Permalink
Fix python 2.7 and Windows support
Browse files Browse the repository at this point in the history
  • Loading branch information
allegroai committed Oct 10, 2019
1 parent c1bcce9 commit e0e6d91
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 50 deletions.
7 changes: 5 additions & 2 deletions trains/backend_interface/task/task.py
Expand Up @@ -2,6 +2,7 @@
import collections
import itertools
import logging
import os
from enum import Enum
from threading import Thread
from multiprocessing import RLock
Expand Down Expand Up @@ -189,9 +190,11 @@ def check_package_update():
latest_version = CheckPackageUpdates.check_new_package_available(only_once=True)
if latest_version:
if not latest_version[1]:
sep = os.linesep
self.get_logger().report_text(
'TRAINS new package available: UPGRADE to v{} is recommended!'.format(
latest_version[0]),
'TRAINS new package available: UPGRADE to v{} is recommended! '
'{}'.format(
latest_version[0], sep.join(latest_version[2])),
)
else:
self.get_logger().report_text(
Expand Down
22 changes: 19 additions & 3 deletions trains/backend_interface/util.py
@@ -1,7 +1,23 @@
import getpass
import re
from _socket import gethostname
from datetime import datetime, timezone
from datetime import datetime
try:
from datetime import timezone
utc_timezone = timezone.utc
except ImportError:
from datetime import tzinfo, timedelta

class UTC(tzinfo):
def utcoffset(self, dt):
return timedelta(0)

def tzname(self, dt):
return "UTC"

def dst(self, dt):
return timedelta(0)
utc_timezone = UTC()

from ..backend_api.services import projects
from ..debugging.log import get_logger
Expand All @@ -26,8 +42,8 @@ def get_or_create_project(session, project_name, description=None):


# Hack for supporting windows
def get_epoch_beginning_of_time(tzinfo=None):
return datetime(1970, 1, 1, tzinfo=tzinfo if tzinfo else timezone.utc)
def get_epoch_beginning_of_time(timezone_info=None):
return datetime(1970, 1, 1).replace(tzinfo=timezone_info if timezone_info else utc_timezone)


def get_single_result(entity, query, results, log=None, show_results=10, raise_on_error=True, sort_by_date=True):
Expand Down
70 changes: 41 additions & 29 deletions trains/task.py
Expand Up @@ -41,8 +41,6 @@
from .utilities.seed import make_deterministic
from .utilities.dicts import ReadOnlyDict

NotSet = object()


class Task(_Task):
"""
Expand All @@ -67,6 +65,8 @@ class Task(_Task):

TaskTypes = _Task.TaskTypes

NotSet = object()

__create_protection = object()
__main_task = None
__exit_hook = None
Expand Down Expand Up @@ -566,37 +566,15 @@ def connect(self, mutable):

raise Exception('Unsupported mutable type %s: no connect function found' % type(mutable).__name__)

def get_logger(self, flush_period=NotSet):
# type: (Optional[float]) -> Logger
def get_logger(self):
# type: () -> Logger
"""
get a logger object for reporting based on the task
:param flush_period: The period of the logger flush.
If None of any other False value, will not flush periodically.
If a logger was created before, this will be the new period and
the old one will be discarded.
get a logger object for reporting, for this task context.
All reports (metrics, text etc.) related to this task are accessible in the web UI
:return: Logger object
"""
if not self._logger:
# force update of base logger to this current task (this is the main logger task)
self._setup_log(replace_existing=self.is_main_task())
# Get a logger object
self._logger = Logger(private_task=self)
# make sure we set our reported to async mode
# we make sure we flush it in self._at_exit
self.reporter.async_enable = True
# if we just created the logger, set default flush period
if not flush_period or flush_period is NotSet:
flush_period = DevWorker.report_period

if isinstance(flush_period, (int, float)):
flush_period = int(abs(flush_period))

if flush_period is None or isinstance(flush_period, int):
self._logger.set_flush_period(flush_period)

return self._logger
return self._get_logger()

def mark_started(self):
"""
Expand Down Expand Up @@ -819,6 +797,40 @@ def set_credentials(cls, host=None, key=None, secret=None):
if secret:
Session.default_secret = secret

def _get_logger(self, flush_period=NotSet):
# type: (Optional[float]) -> Logger
"""
get a logger object for reporting based on the task
:param flush_period: The period of the logger flush.
If None of any other False value, will not flush periodically.
If a logger was created before, this will be the new period and
the old one will be discarded.
:return: Logger object
"""
pass

if not self._logger:
# force update of base logger to this current task (this is the main logger task)
self._setup_log(replace_existing=self.is_main_task())
# Get a logger object
self._logger = Logger(private_task=self)
# make sure we set our reported to async mode
# we make sure we flush it in self._at_exit
self.reporter.async_enable = True
# if we just created the logger, set default flush period
if not flush_period or flush_period is self.NotSet:
flush_period = DevWorker.report_period

if isinstance(flush_period, (int, float)):
flush_period = int(abs(flush_period))

if flush_period is None or isinstance(flush_period, int):
self._logger.set_flush_period(flush_period)

return self._logger

def _connect_output_model(self, model):
assert isinstance(model, OutputModel)
model.connect(self)
Expand Down
31 changes: 15 additions & 16 deletions trains/utilities/check_updates.py
Expand Up @@ -316,31 +316,30 @@ def check_new_package_available(cls, only_once=False):
try:
from ..version import __version__
cls._package_version_checked = True
# Sending the request only for statistics
update_statistics = threading.Thread(target=CheckPackageUpdates.get_version_from_updates_server,
args=(__version__,))
update_statistics.daemon = True
update_statistics.start()

releases = requests.get('https://pypi.python.org/pypi/trains/json', timeout=3.0).json()['releases'].keys()

releases = [Version(r) for r in releases]
latest_version = sorted(releases)
cur_version = Version(__version__)
if not cur_version.is_devrelease and not cur_version.is_prerelease:
latest_version = [r for r in latest_version if not r.is_devrelease and not r.is_prerelease]
update_server_releases = requests.get('https://updates.trainsai.io/updates',
data=json.dumps({"versions": {"trains": str(cur_version)}}),
timeout=3.0)
if update_server_releases.ok:
update_server_releases = update_server_releases.json()
else:
return None
trains_answer = update_server_releases.get("trains", {})
latest_version = Version(trains_answer.get("version"))

if cur_version >= latest_version[-1]:
if cur_version >= latest_version:
return None
not_patch_upgrade = latest_version[-1].release[:2] != cur_version.release[:2]
return str(latest_version[-1]), not_patch_upgrade
not_patch_upgrade = latest_version.release[:2] == cur_version.release[:2]
return str(latest_version), not_patch_upgrade, trains_answer.get("description").split("\r\n")
except Exception:
return None

@staticmethod
def get_version_from_updates_server(cur_version):
try:
_ = requests.get('https://updates.trainsai.io/updates',
params=json.dumps({'versions': {'trains': str(cur_version)}}), timeout=1.0)
data=json.dumps({"versions": {"trains": str(cur_version)}}),
timeout=1.0)
return
except Exception:
pass

0 comments on commit e0e6d91

Please sign in to comment.