From a474729b6540248efb3535353f8f8f692f8045d1 Mon Sep 17 00:00:00 2001 From: Serhii Khalymon Date: Fri, 23 Aug 2019 15:45:54 +0300 Subject: [PATCH 1/3] Feature long request (#38) * feat: add to_rfc1123_datetime helper function; create datetime_utils * feat: add long_requests with check response * refactor: simplify request handling * test: update server_connector tests * refactor: use requests codes instead of numbers * refactor: simplify _long_request_loop mechanism * refactor: change `location` to capitalcase * test: add test for long_request # Conflicts: # eyes_core/applitools/core/server_connector.py --- .../applitools/common/config/configuration.py | 5 +- .../applitools/common/utils/__init__.py | 5 + .../applitools/common/utils/datetime_utils.py | 63 +++++ .../applitools/common/utils/general_utils.py | 22 -- eyes_core/applitools/core/server_connector.py | 239 +++++++++--------- tests/unit/conftest.py | 12 +- tests/unit/eyes_core/test_server_connector.py | 16 ++ 7 files changed, 206 insertions(+), 156 deletions(-) create mode 100644 eyes_common/applitools/common/utils/datetime_utils.py diff --git a/eyes_common/applitools/common/config/configuration.py b/eyes_common/applitools/common/config/configuration.py index 20073a6c0..3a1dfe58a 100644 --- a/eyes_common/applitools/common/config/configuration.py +++ b/eyes_common/applitools/common/config/configuration.py @@ -9,7 +9,7 @@ from applitools.common.geometry import RectangleSize from applitools.common.match import ImageMatchSettings, MatchLevel from applitools.common.server import FailureReports, SessionType -from applitools.common.utils import argument_guard, general_utils +from applitools.common.utils import UTC, argument_guard from applitools.common.utils.json_utils import JsonInclude __all__ = ("BatchInfo", "Configuration") @@ -30,8 +30,7 @@ class BatchInfo(object): metadata={JsonInclude.THIS: True}, ) # type: Optional[Text] started_at = attr.ib( - factory=lambda: datetime.now(general_utils.UTC), - metadata={JsonInclude.THIS: True}, + factory=lambda: datetime.now(UTC), metadata={JsonInclude.THIS: True} ) # type: Union[datetime, Text] sequence_name = attr.ib( init=False, diff --git a/eyes_common/applitools/common/utils/__init__.py b/eyes_common/applitools/common/utils/__init__.py index 0338593c2..49e563bb8 100644 --- a/eyes_common/applitools/common/utils/__init__.py +++ b/eyes_common/applitools/common/utils/__init__.py @@ -11,6 +11,11 @@ urlsplit, urlunsplit, ) +from .datetime_utils import ( # type: ignore # noqa + UTC, + current_time_in_rfc1123, + to_rfc1123_datetime, +) from .general_utils import cached_property # noqa __all__ = compat.__all__ + ("image_utils", "argument_guard") # noqa diff --git a/eyes_common/applitools/common/utils/datetime_utils.py b/eyes_common/applitools/common/utils/datetime_utils.py new file mode 100644 index 000000000..c74701c68 --- /dev/null +++ b/eyes_common/applitools/common/utils/datetime_utils.py @@ -0,0 +1,63 @@ +from datetime import datetime, timedelta, tzinfo +from typing import Text + +__all__ = ("UTC", "to_rfc1123_datetime", "current_time_in_rfc1123") + + +class _UtcTz(tzinfo): + """ + A UTC timezone class which is tzinfo compliant. + """ + + _ZERO = timedelta(0) + + def utcoffset(self, dt): + return _UtcTz._ZERO + + def tzname(self, dt): + return "UTC" + + def dst(self, dt): + return _UtcTz._ZERO + + +UTC = _UtcTz() + + +def to_rfc1123_datetime(dt): + # type: (datetime) -> Text + """Return a string representation of a date according to RFC 1123 + (HTTP/1.1). + + The supplied date must be in UTC. + + """ + weekday = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"][dt.weekday()] + month = [ + "Jan", + "Feb", + "Mar", + "Apr", + "May", + "Jun", + "Jul", + "Aug", + "Sep", + "Oct", + "Nov", + "Dec", + ][dt.month - 1] + return "%s, %02d %s %04d %02d:%02d:%02d GMT" % ( + weekday, + dt.day, + month, + dt.year, + dt.hour, + dt.minute, + dt.second, + ) + + +def current_time_in_rfc1123(): + # type: () -> Text + return to_rfc1123_datetime(datetime.now(UTC)) diff --git a/eyes_common/applitools/common/utils/general_utils.py b/eyes_common/applitools/common/utils/general_utils.py index 74fdf5f9a..27d659dde 100644 --- a/eyes_common/applitools/common/utils/general_utils.py +++ b/eyes_common/applitools/common/utils/general_utils.py @@ -4,7 +4,6 @@ import itertools import time import typing -from datetime import timedelta, tzinfo from applitools.common import logger @@ -21,27 +20,6 @@ T = typing.TypeVar("T") -class _UtcTz(tzinfo): - """ - A UTC timezone class which is tzinfo compliant. - """ - - _ZERO = timedelta(0) - - def utcoffset(self, dt): - return _UtcTz._ZERO - - def tzname(self, dt): - return "UTC" - - def dst(self, dt): - return _UtcTz._ZERO - - -# Constant representing UTC -UTC = _UtcTz() - - def use_default_if_none_factory(default_obj, obj): def default(attr_name): val = getattr(obj, attr_name) diff --git a/eyes_core/applitools/core/server_connector.py b/eyes_core/applitools/core/server_connector.py index 01cf3c097..03bcd3cb9 100644 --- a/eyes_core/applitools/core/server_connector.py +++ b/eyes_core/applitools/core/server_connector.py @@ -1,10 +1,12 @@ from __future__ import absolute_import import json +import math import time import typing from struct import pack +import attr import requests from requests import Response from requests.packages import urllib3 # noqa @@ -17,6 +19,7 @@ from applitools.common.test_results import TestResults from applitools.common.utils import ( argument_guard, + datetime_utils, gzip_compress, image_utils, json_utils, @@ -46,28 +49,27 @@ __all__ = ("ServerConnector",) +@attr.s class _RequestCommunicator(object): - LONG_REQUEST_DELAY_SEC = 2 - MAX_LONG_REQUEST_DELAY_SEC = 10 - LONG_REQUEST_DELAY_MULTIPLICATIVE_INCREASE_FACTOR = 1.5 - - def __init__(self, timeout_sec, headers, api_key, endpoint_uri): - # type: (int, Dict, Text, Text) -> None - self.timeout_sec = timeout_sec - self.headers = headers.copy() - self.api_key = api_key - self.endpoint_uri = endpoint_uri - - def request(self, method, url_resource, **kwargs): + LONG_REQUEST_DELAY_SEC = 2 # type: int + MAX_LONG_REQUEST_DELAY_SEC = 10 # type: int + LONG_REQUEST_DELAY_MULTIPLICATIVE_INCREASE_FACTOR = 1.5 # type: float + + headers = attr.ib() # type: Dict + timeout_sec = attr.ib(default=None) # type: int + api_key = attr.ib(default=None) # type: Text + server_url = attr.ib(default=None) # type: Text + + def request(self, method, url_resource, use_api_key=True, **kwargs): if url_resource is not None: # makes URL relative url_resource = url_resource.lstrip("/") - url_resource = urljoin(self.endpoint_uri, url_resource) + url_resource = urljoin(self.server_url, url_resource) params = {} - if self.api_key: + if use_api_key: params["apiKey"] = self.api_key params.update(kwargs.get("params", {})) - headers = kwargs.get("headers", self.headers) + headers = kwargs.get("headers", self.headers).copy() timeout = kwargs.get("timeout", self.timeout_sec) response = method( url_resource, @@ -84,78 +86,51 @@ def request(self, method, url_resource, **kwargs): return response def long_request(self, method, url_resource, **kwargs): - headers = kwargs["headers"].copy() - headers["Eyes-Expect"] = "202-accepted" - for delay in self.request_delay(): - # Sending the current time of the request (in RFC 1123 format) - headers["Eyes-Date"] = time.strftime( - "%a, %d %b %Y %H:%M:%S GMT", time.gmtime() + headers = kwargs.get("headers", self.headers).copy() + headers["Eyes-Expect"] = "202+location" + headers["Eyes-Date"] = datetime_utils.current_time_in_rfc1123() + kwargs["headers"] = headers + response = self.request(method, url_resource, **kwargs) + return self._long_request_check_status(response) + + def _long_request_check_status(self, response): + if response.status_code == requests.codes.ok: + # request ends successful + return response + elif response.status_code == requests.codes.accepted: + # long request here; calling received url to know that request was processed + url = response.headers["Location"] + response = self._long_request_loop(url) + return self._long_request_check_status(response) + elif response.status_code == requests.codes.created: + # delete url that was used before + url = response.headers["Location"] + return self.request( + requests.delete, + url, + headers={"Eyes-Date": datetime_utils.current_time_in_rfc1123()}, ) - kwargs["headers"] = headers - response = self.request(method, url_resource, **kwargs) - if response.status_code != 202: - return response - logger.debug("Still running... Retrying in {}s".format(delay)) + elif response.status_code == requests.codes.gone: + raise EyesError("The server task has gone.") else: - raise requests.Timeout("Couldn't process request") - - @staticmethod - def request_delay( - first_delay=LONG_REQUEST_DELAY_SEC, - step_factor=LONG_REQUEST_DELAY_MULTIPLICATIVE_INCREASE_FACTOR, - max_delay=MAX_LONG_REQUEST_DELAY_SEC, - ): - delay = _RequestCommunicator.LONG_REQUEST_DELAY_SEC # type: Num - while True: - yield delay - time.sleep(first_delay) - delay = delay * step_factor - if delay > max_delay: - raise StopIteration - - -class _Request(object): - """ - Class for fetching data from - """ - - def __init__(self, com): - self._com = com # type: _RequestCommunicator - - def post(self, url_resource=None, long_query=False, **kwargs): - # type: (str, bool, **Any) -> requests.Response - func = self._com.long_request if long_query else self._com.request - return func(requests.post, url_resource, **kwargs) - - def put(self, url_resource=None, long_query=False, **kwargs): - # type: (str, bool, **Any) -> requests.Response - func = self._com.long_request if long_query else self._com.request - return func(requests.put, url_resource, **kwargs) - - def get(self, url_resource=None, long_query=False, **kwargs): - # type: (str, bool, **Any) -> requests.Response - func = self._com.long_request if long_query else self._com.request - return func(requests.get, url_resource, **kwargs) - - def delete(self, url_resource=None, long_query=False, **kwargs): - # type: (str, bool, **Any) -> requests.Response - func = self._com.long_request if long_query else self._com.request - return func(requests.delete, url_resource, **kwargs) + raise EyesError("Unknown error during long request: {}".format(response)) + def _long_request_loop(self, url, delay=LONG_REQUEST_DELAY_SEC): + delay = min( + self.MAX_LONG_REQUEST_DELAY_SEC, + math.floor(delay * self.LONG_REQUEST_DELAY_MULTIPLICATIVE_INCREASE_FACTOR), + ) + logger.debug("Still running... Retrying in {}s".format(delay)) -def create_request_factory(headers): - class RequestFactory(object): - def __init__(self): - self._com = None - - def create(self, api_key, server_url, timeout_sec): - # server_url could be updated - self._com = _RequestCommunicator( - timeout_sec, headers, api_key, endpoint_uri=server_url - ) - return _Request(self._com) - - return RequestFactory() + time.sleep(delay) + response = self.request( + requests.get, + url, + headers={"Eyes-Date": datetime_utils.current_time_in_rfc1123()}, + ) + if response.status_code != requests.codes.ok: + return response + return self._long_request_loop(url, delay) def prepare_match_data(match_data): @@ -192,12 +167,7 @@ class ServerConnector(object): RENDER_STATUS = "/render-status" RENDER = "/render" - api_key = None - timeout_sec = None - server_url = None # type: Optional[Text] _is_session_started = False - _request = None # type: Optional[_Request] - _render_request = None # type: Optional[_Request] def __init__(self): # type: () -> None @@ -207,31 +177,41 @@ def __init__(self): :param server_url: The url of the Applitools server. """ self._render_info = None # type: Optional[RenderingInfo] - self._request_factory = create_request_factory( - headers=ServerConnector.DEFAULT_HEADERS - ) + self._com = _RequestCommunicator(headers=ServerConnector.DEFAULT_HEADERS) - def _validate_api_key(self): - if self.api_key is None: + def update_config(self, conf): + if conf.api_key is None: raise EyesError( "API key not set! Log in to https://applitools.com to obtain your" " API Key and use 'api_key' to set it." ) + self._com.server_url = conf.server_url + self._com.api_key = conf.api_key + self._com.timeout_sec = conf.timeout / 1000.0 - def update_config(self, conf): - self.server_url = conf.server_url - self.api_key = conf.api_key - self._validate_api_key() - self.timeout_sec = conf.timeout / 1000.0 - - self._request = self._request_factory.create( - server_url=self.server_url, - api_key=self.api_key, - timeout_sec=self.timeout_sec, - ) - self._render_request = self._request_factory.create( - server_url=self.server_url, api_key=None, timeout_sec=self.timeout_sec - ) + @property + def server_url(self): + return self._com.server_url + + @server_url.setter + def server_url(self, value): + self._com.server_url = value + + @property + def api_key(self): + return self._com.api_key + + @api_key.setter + def api_key(self, value): + self._com.api_key = value + + @property + def timeout(self): + return self._com.timeout_sec * 1000 # ms + + @timeout.setter + def timeout(self, value): + self._com.timeout_sec = value / 1000.0 @property def is_session_started(self): @@ -250,7 +230,9 @@ def start_session(self, session_start_info): """ logger.debug("start_session called.") data = json_utils.to_json(session_start_info) - response = self._request.post(url_resource=self.API_SESSIONS_RUNNING, data=data) + response = self._com.request( + requests.post, url_resource=self.API_SESSIONS_RUNNING, data=data + ) running_session = json_utils.attr_from_response(response, RunningSession) running_session.is_new_session = response.status_code == requests.codes.created self._is_session_started = True @@ -272,9 +254,9 @@ def stop_session(self, running_session, is_aborted, save): raise EyesError("Session not started") params = {"aborted": is_aborted, "updateBaseline": save} - response = self._request.delete( + response = self._com.long_request( + requests.delete, url_resource=urljoin(self.API_SESSIONS_RUNNING, running_session.id), - long_query=True, params=params, headers=ServerConnector.DEFAULT_HEADERS, ) @@ -308,7 +290,8 @@ def match_window(self, running_session, match_data): headers = ServerConnector.DEFAULT_HEADERS.copy() headers["Content-Type"] = "application/octet-stream" # TODO: allow to send images as base64 - response = self._request.post( + response = self._com.long_request( + requests.post, url_resource=urljoin(self.API_SESSIONS_RUNNING, running_session.id), data=data, headers=headers, @@ -331,7 +314,8 @@ def post_dom_snapshot(self, dom_json): headers["Content-Type"] = "application/octet-stream" dom_bytes = gzip_compress(dom_json.encode("utf-8")) - response = self._request.post( + response = self._com.request( + requests.post, url_resource=urljoin(self.API_SESSIONS_RUNNING, "data"), data=dom_bytes, headers=headers, @@ -346,7 +330,9 @@ def render_info(self): logger.debug("render_info() called.") headers = ServerConnector.DEFAULT_HEADERS.copy() headers["Content-Type"] = "application/json" - response = self._request.get(self.RENDER_INFO_PATH, headers=headers) + response = self._com.request( + requests.get, self.RENDER_INFO_PATH, headers=headers + ) if not response.ok: raise EyesError( "Cannot get render info: \n Status: {}, Content: {}".format( @@ -369,8 +355,10 @@ def render(self, *render_requests): headers["X-Auth-Token"] = self._render_info.access_token data = json_utils.to_json(render_requests) - response = self._render_request.post(url, headers=headers, data=data) - if response.ok or response.status_code == 404: + response = self._com.request( + requests.post, url, use_api_key=False, headers=headers, data=data + ) + if response.ok or response.status_code == requests.codes.not_found: return json_utils.attr_from_response(response, RunningRender) raise EyesError( "ServerConnector.render - unexpected status ({})\n\tcontent{}".format( @@ -379,7 +367,7 @@ def render(self, *render_requests): ) def render_put_resource(self, running_render, resource): - # type: (RunningRender, VGResource) -> bool + # type: (RunningRender, VGResource) -> Text argument_guard.not_none(running_render) argument_guard.not_none(resource) if self._render_info is None: @@ -398,8 +386,10 @@ def render_put_resource(self, running_render, resource): url = urljoin( self._render_info.service_url, self.RESOURCES_SHA_256 + resource.hash ) - response = self._render_request.put( + response = self._com.request( + requests.put, url, + use_api_key=False, headers=headers, data=content, params={"render-id": running_render.render_id}, @@ -422,14 +412,15 @@ def download_resource(self, url): headers["Accept-Encoding"] = "identity" response = requests.get( - url, headers=headers, timeout=self.timeout_sec, verify=False + url, headers=headers, timeout=self._com.timeout_sec, verify=False ) - if response.status_code == 406: - response = requests.get(url, timeout=self.timeout_sec, verify=False) + if response.status_code == requests.codes.not_acceptable: + response = requests.get(url, timeout=self._com.timeout_sec, verify=False) response.raise_for_status() return response def render_status_by_id(self, render_id): + # type: (Text) -> List[RenderStatusResults] argument_guard.not_none(render_id) if self._render_info is None: raise EyesError("render_info must be fetched first") @@ -438,8 +429,12 @@ def render_status_by_id(self, render_id): headers["Content-Type"] = "application/json" headers["X-Auth-Token"] = self._render_info.access_token url = urljoin(self._render_info.service_url, self.RENDER_STATUS) - response = self._render_request.post( - url, headers=headers, data=json.dumps([render_id]) + response = self._com.request( + requests.post, + url, + use_api_key=False, + headers=headers, + data=json.dumps([render_id]), ) if not response.ok: raise EyesError( diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index e4125558c..f85725aee 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -7,15 +7,15 @@ from selenium.webdriver.remote.webdriver import WebDriver from applitools.common import ( + AppEnvironment, AppOutput, + BatchInfo, Configuration, ImageMatchSettings, + MatchLevel, RunningSession, SessionStartInfo, SessionType, - BatchInfo, - AppEnvironment, - MatchLevel, ) from applitools.common.utils.json_utils import attr_from_json from applitools.core import EyesBase, ServerConnector @@ -72,13 +72,7 @@ def configured_connector(custom_eyes_server): @pytest.fixture(scope="function") def started_connector(configured_connector): - configured_connector._request = configured_connector._request_factory.create( - server_url=configured_connector.server_url, - api_key=configured_connector.api_key, - timeout_sec=configured_connector.timeout_sec, - ) configured_connector._is_session_started = True - return configured_connector diff --git a/tests/unit/eyes_core/test_server_connector.py b/tests/unit/eyes_core/test_server_connector.py index 10e3914b4..dfb1a8ef3 100644 --- a/tests/unit/eyes_core/test_server_connector.py +++ b/tests/unit/eyes_core/test_server_connector.py @@ -3,6 +3,7 @@ from typing import Any import pytest +import requests from mock import patch from applitools.common import ( @@ -36,6 +37,8 @@ RUNNING_SESSION_URL = urljoin(CUSTOM_EYES_SERVER, API_SESSIONS_RUNNING) RUNNING_SESSION_DATA_URL = urljoin(RUNNING_SESSION_URL, "data") RENDER_INFO_PATH_URL = urljoin(CUSTOM_EYES_SERVER, RENDER_INFO_PATH) +LONG_REQUEST_URL = urljoin(CUSTOM_EYES_SERVER, "/one") +LONG_REQUEST_RESPONSE_URL = urljoin(CUSTOM_EYES_SERVER, "/second") RENDER_INFO_URL = "https://render-wus.applitools.com" RENDER_INFO_AT = "Some Token" @@ -95,6 +98,8 @@ def mocked_requests_delete(*args, **kwargs): url = args[0] if url == urljoin(RUNNING_SESSION_URL, RUNNING_SESSION_DATA_RESPONSE_ID): return MockResponse(STOP_SESSION_DATA, 200) + elif url == LONG_REQUEST_RESPONSE_URL: + return MockResponse({}, 200) return MockResponse(None, 404) @@ -103,6 +108,10 @@ def mocked_requests_get(*args, **kwargs): url = args[0] if url == RENDER_INFO_PATH_URL: return MockResponse(RENDERING_INFO_DATA, 200) + if url == LONG_REQUEST_URL: + return MockResponse(None, 202, {"Location": LONG_REQUEST_RESPONSE_URL}) + if url == LONG_REQUEST_RESPONSE_URL: + return MockResponse(None, 201, {"Location": LONG_REQUEST_RESPONSE_URL}) return MockResponse(None, 404) @@ -316,6 +325,13 @@ def test_request_with_changed_values(configured_connector): assert new_server_url in mocked_post.call_args[0][0] +def test_long_request(configured_connector): + with patch("requests.get", side_effect=mocked_requests_get): + with patch("requests.delete", side_effect=mocked_requests_delete): + r = configured_connector._com.long_request(requests.get, LONG_REQUEST_URL) + assert r.status_code == 200 + + def test_get_rendering_info(started_connector): with patch("requests.get", side_effect=mocked_requests_get): render_info = started_connector.render_info() From 353954bf9952e6d67e4718cde368d5acc4ea5020 Mon Sep 17 00:00:00 2001 From: Serhii Khalymon Date: Thu, 29 Aug 2019 15:30:34 +0300 Subject: [PATCH 2/3] test: fix travis config --- .travis.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 1672ccac1..7e2fc0659 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,10 +1,11 @@ language: python +services: + - xvfb install: - pip install -U tox before_script: # Run GUI apps in headless mode - export DISPLAY=:99.0 - - sh -e /etc/init.d/xvfb start - sleep 10 # give webdriver some time to start - export APPLITOOLS_BATCH_ID=`uuidgen -t` script: From 0e28f0cb4550f028a85cd9ef23ec48d0ca3fa2ea Mon Sep 17 00:00:00 2001 From: Serhii Khalymon Date: Thu, 29 Aug 2019 18:28:41 +0300 Subject: [PATCH 3/3] fix-ie-fullscreenshot (#42) * refactor: EyesWebdriverScreenshot class simplification * test: add test_ie_viewport_screenshot * fix: ie always return full page screenshot If FPS cut it to VP size --- .../capture/eyes_webdriver_screenshot.py | 107 +++--------------- .../selenium/capture/screenshot_utils.py | 85 ++++++++++++++ .../selenium/eyes_selenium_utils.py | 64 ++++++++++- .../applitools/selenium/selenium_eyes.py | 27 +++-- .../selenium/test_specific_cases.py | 21 ++++ 5 files changed, 196 insertions(+), 108 deletions(-) create mode 100644 eyes_selenium/applitools/selenium/capture/screenshot_utils.py diff --git a/eyes_selenium/applitools/selenium/capture/eyes_webdriver_screenshot.py b/eyes_selenium/applitools/selenium/capture/eyes_webdriver_screenshot.py index 35b9fcd45..23dea2dbc 100644 --- a/eyes_selenium/applitools/selenium/capture/eyes_webdriver_screenshot.py +++ b/eyes_selenium/applitools/selenium/capture/eyes_webdriver_screenshot.py @@ -1,9 +1,10 @@ from __future__ import absolute_import import typing -from enum import Enum import attr +from selenium.common.exceptions import WebDriverException + from applitools.common import ( CoordinatesType, CoordinatesTypeConversionError, @@ -16,25 +17,21 @@ ) from applitools.common.utils import argument_guard, image_utils from applitools.core.capture import EyesScreenshot, EyesScreenshotFactory -from applitools.selenium import eyes_selenium_utils -from applitools.selenium.positioning import ( - ScrollPositionProvider, - SeleniumPositionProvider, +from applitools.selenium.capture.screenshot_utils import ( + ScreenshotType, + calc_frame_location_in_screenshot, + update_screenshot_type, ) -from selenium.common.exceptions import WebDriverException +from applitools.selenium.eyes_selenium_utils import get_updated_scroll_position +from applitools.selenium.positioning import SeleniumPositionProvider if typing.TYPE_CHECKING: - from typing import Optional, Union + from typing import Union from PIL import Image from applitools.selenium.webdriver import EyesWebDriver from applitools.selenium.frames import FrameChain -class ScreenshotType(Enum): - VIEWPORT = "VIEWPORT" - ENTIRE_FRAME = "ENTIRE_FRAME" - - @attr.s class EyesWebDriverScreenshot(EyesScreenshot): @@ -88,8 +85,8 @@ def from_screenshot(cls, driver, image, screenshot_region): def __attrs_post_init__(self): # type: () -> None self._frame_chain = self._driver.frame_chain.clone() - self._screenshot_type = self.update_screenshot_type( - self._screenshot_type, self._image + self._screenshot_type = update_screenshot_type( + self._screenshot_type, self._image, self._driver ) cur_frame_position_provider = self._driver.eyes.current_frame_position_provider if cur_frame_position_provider: @@ -98,7 +95,7 @@ def __attrs_post_init__(self): position_provider = self._driver.eyes.position_provider if self._current_frame_scroll_position is None: - self._current_frame_scroll_position = self.get_updated_scroll_position( + self._current_frame_scroll_position = get_updated_scroll_position( position_provider ) self._frame_location_in_screenshot = self.get_updated_frame_location_in_screenshot( @@ -123,42 +120,13 @@ def _validate_frame_window(self): def get_updated_frame_location_in_screenshot(self, frame_location_in_screenshot): # type: (Point) -> Point if self.frame_chain.size > 0: - frame_location_in_screenshot = self.calc_frame_location_in_screenshot( + frame_location_in_screenshot = calc_frame_location_in_screenshot( self._driver, self._frame_chain, self._screenshot_type ) elif not frame_location_in_screenshot: frame_location_in_screenshot = Point.zero() return frame_location_in_screenshot - def get_updated_scroll_position(self, position_provider): - # type: (SeleniumPositionProvider) -> Point - try: - sp = position_provider.get_current_position() - if not sp: - sp = Point.zero() - except WebDriverException: - sp = Point.zero() - - return sp - - def update_screenshot_type(self, screenshot_type, image): - # type: ( Optional[ScreenshotType], Image) -> ScreenshotType - if screenshot_type is None: - viewport_size = self._driver.eyes.viewport_size - scale_viewport = self._driver.eyes.stitch_content - - if scale_viewport: - pixel_ratio = self._driver.eyes.device_pixel_ratio - viewport_size = viewport_size.scale(pixel_ratio) - if ( - image.width <= viewport_size["width"] - and image.height <= viewport_size["height"] - ): - screenshot_type = ScreenshotType.VIEWPORT - else: - screenshot_type = ScreenshotType.ENTIRE_FRAME - return screenshot_type - @property def image(self): # type: () -> Union[Image, Image] @@ -180,55 +148,6 @@ def get_frame_size(self, position_provider): frame_size = self._driver.get_default_content_viewport_size() return frame_size - @staticmethod - def _get_default_content_scroll_position(driver): - # type: (EyesWebDriver) -> Point - scroll_root_element = eyes_selenium_utils.current_frame_scroll_root_element( - driver - ) - return ScrollPositionProvider.get_current_position_static( - driver, scroll_root_element - ) - - @staticmethod - def get_default_content_scroll_position(current_frames, driver): - if current_frames.size == 0: - scroll_position = EyesWebDriverScreenshot._get_default_content_scroll_position( - driver - ) - else: - current_fc = driver.eyes._original_frame_chain - with driver.switch_to.frames_and_back(current_fc): - scroll_position = EyesWebDriverScreenshot._get_default_content_scroll_position( - driver - ) - return scroll_position - - @staticmethod - def calc_frame_location_in_screenshot(driver, frame_chain, screenshot_type): - window_scroll = EyesWebDriverScreenshot.get_default_content_scroll_position( - frame_chain, driver - ) - logger.info("Getting first frame...") - first_frame = frame_chain[0] - location_in_screenshot = Point(first_frame.location.x, first_frame.location.y) - # We only need to consider the scroll of the default content if the screenshot is a - # viewport screenshot. If this is a full page screenshot, the frame location will not - # change anyway. - if screenshot_type == ScreenshotType.VIEWPORT: - location_in_screenshot = location_in_screenshot.offset( - -window_scroll.x, -window_scroll.y - ) - - # For inner frames we must calculate the scroll - inner_frames = frame_chain[1:] - for frame in inner_frames: - location_in_screenshot = location_in_screenshot.offset( - frame.location.x - frame.parent_scroll_position.x, - frame.location.y - frame.parent_scroll_position.y, - ) - return location_in_screenshot - @property def frame_chain(self): # type: () -> FrameChain diff --git a/eyes_selenium/applitools/selenium/capture/screenshot_utils.py b/eyes_selenium/applitools/selenium/capture/screenshot_utils.py new file mode 100644 index 000000000..898518787 --- /dev/null +++ b/eyes_selenium/applitools/selenium/capture/screenshot_utils.py @@ -0,0 +1,85 @@ +from enum import Enum +from typing import TYPE_CHECKING, Optional + +from applitools.common import Point, Region, logger +from applitools.common.utils import image_utils +from applitools.selenium import eyes_selenium_utils +from applitools.selenium.eyes_selenium_utils import ( + get_cur_position_provider, + get_updated_scroll_position, +) + +if TYPE_CHECKING: + from PIL import Image + from applitools.selenium.webdriver import EyesWebDriver + + +class ScreenshotType(Enum): + VIEWPORT = "VIEWPORT" + ENTIRE_FRAME = "ENTIRE_FRAME" + + +def update_screenshot_type(screenshot_type, image, driver): + # type: ( Optional[ScreenshotType], Image, EyesWebDriver) -> ScreenshotType + if screenshot_type is None: + viewport_size = driver.eyes.viewport_size + scale_viewport = driver.eyes.stitch_content + + if scale_viewport: + pixel_ratio = driver.eyes.device_pixel_ratio + viewport_size = viewport_size.scale(pixel_ratio) + if ( + image.width <= viewport_size["width"] + and image.height <= viewport_size["height"] + ): + screenshot_type = ScreenshotType.VIEWPORT + else: + screenshot_type = ScreenshotType.ENTIRE_FRAME + return screenshot_type + + +def cut_to_viewport_size_if_required(driver, image): + # type: (EyesWebDriver, Image) -> Image + # Some browsers return always full page screenshot (IE). + # So we cut such images to viewport size + position_provider = get_cur_position_provider(driver) + curr_frame_scroll = get_updated_scroll_position(position_provider) + screenshot_type = update_screenshot_type(None, image, driver) + if screenshot_type != ScreenshotType.VIEWPORT: + viewport_size = driver.eyes.viewport_size + image = image_utils.crop_image( + image, + region_to_crop=Region( + top=curr_frame_scroll.x, + left=0, + height=viewport_size["height"], + width=viewport_size["width"], + ), + ) + return image + + +def calc_frame_location_in_screenshot(driver, frame_chain, screenshot_type): + window_scroll = eyes_selenium_utils.get_default_content_scroll_position( + frame_chain, driver + ) + logger.info("Getting first frame...") + first_frame = frame_chain[0] + location_in_screenshot = Point(first_frame.location.x, first_frame.location.y) + # We only need to consider the scroll of the default content if the screenshot is a + # viewport screenshot. If this is a full page screenshot, the frame location will + # not + # change anyway. + if screenshot_type == ScreenshotType.VIEWPORT: + location_in_screenshot = location_in_screenshot.offset( + -window_scroll.x, -window_scroll.y + ) + + # For inner frames we must calculate the scroll + inner_frames = frame_chain[1:] + for frame in inner_frames: + location_in_screenshot = location_in_screenshot.offset( + frame.location.x - frame.parent_scroll_position.x, + frame.location.y - frame.parent_scroll_position.y, + ) + return location_in_screenshot diff --git a/eyes_selenium/applitools/selenium/eyes_selenium_utils.py b/eyes_selenium/applitools/selenium/eyes_selenium_utils.py index c15988daa..6e12b5d58 100644 --- a/eyes_selenium/applitools/selenium/eyes_selenium_utils.py +++ b/eyes_selenium/applitools/selenium/eyes_selenium_utils.py @@ -12,10 +12,13 @@ from applitools.common import Point, RectangleSize, logger if tp.TYPE_CHECKING: - from typing import Text, Optional, Any, Iterator, Union + from typing import Text, Optional, Any, Union, Iterator + from applitools.selenium.frames import FrameChain + from applitools.selenium.positioning import SeleniumPositionProvider from applitools.selenium.webdriver import EyesWebDriver from applitools.selenium.webelement import EyesWebElement from applitools.selenium.fluent import SeleniumCheckSettings, FrameLocator + from applitools.common.utils.custom_types import ( AnyWebDriver, ViewPort, @@ -35,6 +38,8 @@ "hide_scrollbars", "set_overflow", "parse_location_string", + "get_cur_position_provider", + "get_updated_scroll_position", ) _NATIVE_APP = "NATIVE_APP" @@ -440,7 +445,16 @@ def parse_location_string(position): xy = position.split(";") if len(xy) != 2: raise WebDriverException("Could not get scroll position!") - return Point(float(xy[0]), float(xy[1])) + return Point(round(float(xy[0])), round(float(xy[1]))) + + +def get_current_position(driver, element): + # type: (AnyWebDriver, AnyWebElement) -> Point + element = get_underlying_webelement(element) + xy = driver.execute_script( + "return arguments[0].scrollLeft+';'+arguments[0].scrollTop;", element + ) + return parse_location_string(xy) def scroll_root_element_from(driver, container=None): @@ -469,13 +483,53 @@ def root_html(): return scroll_root_element -def current_frame_scroll_root_element(driver): - # type: (EyesWebDriver) -> EyesWebElement +def current_frame_scroll_root_element(driver, scroll_root_element=None): + # type: (EyesWebDriver, Optional[AnyWebElement]) -> EyesWebElement fc = driver.frame_chain.clone() cur_frame = fc.peek root_element = None if cur_frame: root_element = cur_frame.scroll_root_element if root_element is None and not driver.is_mobile_app: - root_element = driver.find_element_by_tag_name("html") + if scroll_root_element: + root_element = scroll_root_element + else: + root_element = driver.find_element_by_tag_name("html") return root_element + + +def get_cur_position_provider(driver): + # type: (EyesWebDriver) -> SeleniumPositionProvider + cur_frame_position_provider = driver.eyes.current_frame_position_provider + if cur_frame_position_provider: + return cur_frame_position_provider + else: + return driver.eyes.position_provider + + +def get_updated_scroll_position(position_provider): + # type: (SeleniumPositionProvider) -> Point + try: + sp = position_provider.get_current_position() + if not sp: + sp = Point.zero() + except WebDriverException: + sp = Point.zero() + + return sp + + +def get_default_content_scroll_position(current_frames, driver): + # type: (FrameChain, EyesWebDriver) -> Point + if current_frames.size == 0: + + scroll_position = get_current_position( + driver, current_frame_scroll_root_element(driver) + ) + else: + current_fc = driver.eyes.original_frame_chain + with driver.switch_to.frames_and_back(current_fc): + scroll_position = get_current_position( + driver, current_frame_scroll_root_element(driver) + ) + return scroll_position diff --git a/eyes_selenium/applitools/selenium/selenium_eyes.py b/eyes_selenium/applitools/selenium/selenium_eyes.py index 3d2986ce8..1142f12bf 100644 --- a/eyes_selenium/applitools/selenium/selenium_eyes.py +++ b/eyes_selenium/applitools/selenium/selenium_eyes.py @@ -3,7 +3,6 @@ from time import sleep from selenium.common.exceptions import WebDriverException -from selenium.webdriver.remote.webdriver import WebDriver as RemoteWebDriver from applitools.common import ( AppEnvironment, @@ -23,11 +22,11 @@ FixedScaleProvider, ImageProvider, MouseTrigger, + NullCutProvider, NullScaleProvider, PositionProvider, RegionProvider, TextTrigger, - NullCutProvider, ) from applitools.selenium.capture.eyes_webdriver_screenshot import ( EyesWebDriverScreenshotFactory, @@ -36,6 +35,9 @@ FullPageCaptureAlgorithm, ) from applitools.selenium.capture.image_providers import get_image_provider +from applitools.selenium.capture.screenshot_utils import ( + cut_to_viewport_size_if_required, +) from applitools.selenium.region_compensation import ( RegionPositionCompensation, get_region_position_compensation, @@ -95,7 +97,7 @@ class SeleniumEyes(EyesBase): _user_agent = None # type: Optional[UserAgent] _image_provider = None # type: Optional[ImageProvider] _region_position_compensation = None # type: Optional[RegionPositionCompensation] - + _is_check_region = None # type: Optional[bool] current_frame_position_provider = None # type: Optional[PositionProvider] @staticmethod @@ -753,24 +755,28 @@ def _viewport_screenshot(self, scale_provider): # type: (ScaleProvider) -> EyesWebDriverScreenshot logger.info("Viewport screenshot requested") self._ensure_element_visible(self._target_element) - sleep(self.configuration.wait_before_screenshots / 1000.0) + image = self._get_scaled_cropped_image(scale_provider) + if not self._is_check_region and not self._driver.is_mobile_app: + # Some browsers return always full page screenshot (IE). + # So we cut such images to viewport size + image = cut_to_viewport_size_if_required(self.driver, image) + return EyesWebDriverScreenshot.create_viewport(self._driver, image) + + def _get_scaled_cropped_image(self, scale_provider): image = self._image_provider.get_image() self._debug_screenshot_provider.save(image, "original") - scale_provider.update_scale_ratio(image.width) pixel_ratio = 1 / scale_provider.scale_ratio if pixel_ratio != 1.0: - logger.info("Scalling") + logger.info("Scaling") image = image_utils.scale_image(image, 1.0 / pixel_ratio) self._debug_screenshot_provider.save(image, "scaled") - if not isinstance(self.cut_provider, NullCutProvider): logger.info("Cutting") image = self.cut_provider.cut(image) self._debug_screenshot_provider.save(image, "cutted") - - return EyesWebDriverScreenshot.create_viewport(self._driver, image) + return image def _get_viewport_scroll_bounds(self): switch_to = self.driver.switch_to @@ -889,6 +895,8 @@ def _check_element(self, name, check_settings): return result def _check_region(self, name, check_settings): + self._is_check_region = True + def get_region(): location = self._target_element.location size = self._target_element.size @@ -903,4 +911,5 @@ def get_region(): result = self._check_window_base( RegionProvider(get_region), name, False, check_settings ) + self._is_check_region = False return result diff --git a/tests/functional/eyes_selenium/selenium/test_specific_cases.py b/tests/functional/eyes_selenium/selenium/test_specific_cases.py index 55f0f2c56..e7e85ade4 100644 --- a/tests/functional/eyes_selenium/selenium/test_specific_cases.py +++ b/tests/functional/eyes_selenium/selenium/test_specific_cases.py @@ -1,3 +1,5 @@ +import os + import pytest from applitools.selenium import Region, Target @@ -59,3 +61,22 @@ def test_abort_eyes(eyes, driver): eyes.open(driver, "Python VisualGrid", "TestAbortSeleniumEyes") eyes.check_window() eyes.abort() + + +def test_ie_viewport_screenshot(eyes, webdriver_module): + sauce_url = "https://{username}:{password}@ondemand.saucelabs.com:443/wd/hub".format( + username=os.getenv("SAUCE_USERNAME", None), + password=os.getenv("SAUCE_ACCESS_KEY", None), + ) + driver = webdriver_module.Remote( + command_executor=sauce_url, + desired_capabilities={ + "browserName": "internet explorer", + "platform": "Windows 10", + "version": "11.285", + }, + ) + driver.get("http://applitools.github.io/demo/TestPages/FramesTestPage") + eyes.open(driver, "Python SDK", "TestIEViewportScreenshot") + eyes.check_window() + eyes.close()