Skip to content

Commit

Permalink
Added users listing and setting user and master passwords (#24)
Browse files Browse the repository at this point in the history
* Services before refactor

* Added services settings

* Disabled create_service function

* Part of changes - faced problem with geoserver - potential issue

* Added setting self and master passwords

* [Pep8] Flake8 refactor

* - Test fixes

Co-authored-by: Alessio Fabiani <alessio.fabiani@geo-solutions.it>
  • Loading branch information
jendrusk and Alessio Fabiani committed Feb 21, 2022
1 parent c2b733c commit ff8b645
Show file tree
Hide file tree
Showing 6 changed files with 248 additions and 30 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@ lib/
lib64
doc/_build/
.idea/
.env

/src/geoserver_restconfig.egg-info/
99 changes: 74 additions & 25 deletions src/geoserver/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from geoserver.support import prepare_upload_bundle, build_url
from geoserver.layergroup import LayerGroup, UnsavedLayerGroup
from geoserver.workspace import workspace_from_index, Workspace
from geoserver.security import user_from_index
import os
import re
import base64
Expand Down Expand Up @@ -256,7 +257,7 @@ def parse_or_raise(xml):
raw_text = cached_response[1]
return parse_or_raise(raw_text)
else:
resp = self.http_request(rest_url)
resp = self.http_request(rest_url, headers={"Accept": "application/xml"})
if resp.status_code == 200:
content = resp.content
if isinstance(content, bytes):
Expand Down Expand Up @@ -1335,27 +1336,75 @@ def get_services(self, ogc_type="wms"):

return services

# global services are enabled by default, enabling services in workspaces using rest is broken in geoserver for now
# def create_service(self, ogc_type=None, workspace=None):
#
# KNOWN_TYPES = ["wms", "wfs", "wcs", "wmts"]
#
# if ogc_type is None:
# logger.error("You have to specify OGC Service Type ({types})".format(types=",".join(KNOWN_TYPES)))
# return None
#
# if ogc_type.lower() not in KNOWN_TYPES:
# logger.error("Unknown OGC Service Type (known are: {types})".format(types=",".join(KNOWN_TYPES)))
# return None
#
# if workspace is None:
# logger.info("Global services are created by default")
#
# if ogc_type.lower() == "wms":
# raise NotImplementedError()
# elif ogc_type.lower() == "wfs":
# raise NotImplementedError()
# elif ogc_type.lower() == "wcs":
# raise NotImplementedError()
# elif ogc_type.lower() == "wmts":
# raise NotImplementedError()
def get_users(self, names=None):
'''
Returns a list of users in the catalog.
If names is specified, will only return users that match.
names can either be a comma delimited string or an array.
Will return an empty list if no users are found (unlikely).
'''
if names is None:
names = []
elif isinstance(names, string_types):
names = [s.strip() for s in names.split(',') if s.strip()]

data = self.get_xml(f"{self.service_url}/security/usergroup/users/")
users = []
users.extend([user_from_index(self, node) for node in data.findall("user")])

if users and names:
return ([ws for ws in users if ws.name in names])

return users

def get_master_pwd(self):
url = f"{self.service_url}/security/masterpw.xml"
resp = self.http_request(url)
masterpwd = None
if resp.status_code == 200:
content = resp.content
if isinstance(content, bytes):
content = content.decode('UTF-8')
dom = XML(content)
masterpwd = dom.find("oldMasterPassword").text if dom.find("oldMasterPassword") is not None else None
else:
raise FailedRequestError(resp.content)

return masterpwd

def set_master_pwd(self, new_pwd):
old_pwd = self.get_master_pwd()
if old_pwd == new_pwd:
return new_pwd

headers = {"Content-Type": "application/xml"}
url = f"{self.service_url}/security/masterpw.xml"
body = ("<masterPassword>"
"<oldMasterPassword>{old_pwd}</oldMasterPassword>"
"<newMasterPassword>{new_pwd}</newMasterPassword>"
"</masterPassword>").format(old_pwd=old_pwd, new_pwd=new_pwd)
resp = self.http_request(url, method="put", data=body, headers=headers)

if resp.status_code == 200:
res = new_pwd
self.reload()
else:
raise FailedRequestError(resp.content)
return res

def set_my_pwd(self, new_pwd):
headers = {"Content-Type": "application/xml"}
url = f"{self.service_url}/security/self/password.xml"
body = ("<userPassword>"
"<newPassword>{new_pwd}</newPassword>"
"</userPassword>").format(new_pwd=new_pwd)
resp = self.http_request(url, method="put", data=body, headers=headers)

if resp.status_code == 200:
res = new_pwd
self.reload()
self.password = new_pwd
self.reload()
else:
raise FailedRequestError(resp.content)
return res
54 changes: 54 additions & 0 deletions src/geoserver/security.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# -*- coding: utf-8 -*-
#########################################################################
#
# Copyright 2019, GeoSolutions Sas.
# Jendrusk also was here
# All rights reserved.
#
# This source code is licensed under the MIT license found in the
# LICENSE.txt file in the root directory of this source tree.
#
#########################################################################
try:
from urllib.parse import urljoin
except BaseException:
from urlparse import urljoin

from geoserver.support import ResourceInfo, xml_property, write_bool


def user_from_index(catalog, node):
user_name = node.find("userName").text
return User(catalog, user_name)


class User(ResourceInfo):
resource_type = "user"

def __init__(self, catalog, user_name):
super(User, self).__init__()
self._catalog = catalog
self._user_name = user_name

@property
def catalog(self):
return self._catalog

@property
def user_name(self):
return self._user_name

@property
def href(self):
return urljoin(
f"{self.catalog.service_url}/",
f"security/usergroup/users/{self.user_name}"
)

enabled = xml_property("enabled", lambda x: x.lower() == 'true')
writers = {
'enabled': write_bool("enabled")
}

def __repr__(self):
return f"{self.user_name} @ {self.href}"
4 changes: 4 additions & 0 deletions test/catalogtests.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,10 @@ def testWorkspaces(self):
self.assertEqual("topp", self.cat.get_workspace("topp").name)
self.assertIsNone(self.cat.get_workspace("blahblah-"))

def testUsers(self):
users = self.cat.get_users()
x = 1

def testStores(self):
self.assertEqual(0, len(self.cat.get_stores(names="nonexistentstore")))
topp = self.cat.get_workspaces("topp")[0]
Expand Down
109 changes: 109 additions & 0 deletions test/securitytests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import unittest
import string
import random
import os
from .utils import DBPARAMS
from .utils import GSPARAMS
import subprocess
import re
import time
from geoserver.catalog import Catalog

if GSPARAMS['GEOSERVER_HOME']:
dest = GSPARAMS['DATA_DIR']
data = os.path.join(GSPARAMS['GEOSERVER_HOME'], 'data/release', '')
if dest:
os.system(f"rsync -v -a --delete {data} {os.path.join(dest, '')}")
else:
os.system(f'git clean -dxf -- {data}')
os.system(f"curl -XPOST --user '{GSPARAMS['GSUSER']}':'{GSPARAMS['GSPASSWORD']}' '{GSPARAMS['GSURL']}/reload'")

if GSPARAMS['GS_VERSION']:
subprocess.Popen(["rm", "-rf", f"{GSPARAMS['GS_BASE_DIR']}/gs"]).communicate()
subprocess.Popen(["mkdir", f"{GSPARAMS['GS_BASE_DIR']}/gs"]).communicate()
subprocess.Popen(
[
"wget",
"http://central.maven.org/maven2/org/eclipse/jetty/jetty-runner/9.4.5.v20170502/jetty-runner-9.4.5.v20170502.jar",
"-P", f"{GSPARAMS['GS_BASE_DIR']}/gs"
]
).communicate()

subprocess.Popen(
[
"wget",
f"https://build.geoserver.org/geoserver/{GSPARAMS['GS_VERSION']}/geoserver-{GSPARAMS['GS_VERSION']}-latest-war.zip",
"-P", f"{GSPARAMS['GS_BASE_DIR']}/gs"
]
).communicate()

subprocess.Popen(
[
"unzip",
"-o",
"-d",
f"{GSPARAMS['GS_BASE_DIR']}/gs",
f"{GSPARAMS['GS_BASE_DIR']}/gs/geoserver-{GSPARAMS['GS_VERSION']}-latest-war.zip"
]
).communicate()

FNULL = open(os.devnull, 'w')

match = re.compile(r'[^\d.]+')
geoserver_short_version = match.sub('', GSPARAMS['GS_VERSION']).strip('.')
if geoserver_short_version >= "2.15" or GSPARAMS['GS_VERSION'].lower() == 'master':
java_executable = "/usr/local/lib/jvm/openjdk11/bin/java"
else:
java_executable = "/usr/lib/jvm/java-8-openjdk-amd64/jre/bin/java"

print(f"geoserver_short_version: {geoserver_short_version}")
print(f"java_executable: {java_executable}")
proc = subprocess.Popen(
[
java_executable,
"-Xmx1024m",
"-Dorg.eclipse.jetty.server.webapp.parentLoaderPriority=true",
"-jar", f"{GSPARAMS['GS_BASE_DIR']}/gs/jetty-runner-9.4.5.v20170502.jar",
"--path", "/geoserver", f"{GSPARAMS['GS_BASE_DIR']}/gs/geoserver.war"
],
stdout=FNULL, stderr=subprocess.STDOUT
)
child_pid = proc.pid
print("Sleep (90)...")
time.sleep(40)


class SecurityTests(unittest.TestCase):
def setUp(self):
self.cat = Catalog(GSPARAMS['GSURL'], username=GSPARAMS['GSUSER'], password=GSPARAMS['GSPASSWORD'])
self.bkp_cat = Catalog(GSPARAMS['GSURL'], username=GSPARAMS['GSUSER'], password=GSPARAMS['GSPASSWORD'])
self.gs_version = self.cat.get_short_version()
self.bkp_masterpwd = self.bkp_cat.get_master_pwd()
self.bkp_my_pwd = self.cat.password

def tearDown(self) -> None:
self.bkp_cat.set_master_pwd(self.bkp_masterpwd)
self.bkp_cat.set_my_pwd(self.bkp_my_pwd)

def test_get_users(self):
users = self.cat.get_users()
self.assertGreater(len(users), 0)

def test_get_master_pwd(self):
master_pwd = self.cat.get_master_pwd()
self.assertIsNotNone(master_pwd)

def test_set_master_pwd(self):
test_pwd = ''.join(random.sample(string.ascii_lowercase, 10))
master_pwd = self.cat.set_master_pwd(new_pwd=test_pwd)
self.assertIsNotNone(master_pwd)
self.assertEqual(master_pwd, test_pwd)
new_master_pwd = self.cat.get_master_pwd()
self.assertEqual(new_master_pwd, test_pwd)

def test_set_my_pwd(self):
test_pwd = ''.join(random.sample(string.ascii_lowercase, 10))
new_pwd = self.cat.set_my_pwd(new_pwd=test_pwd)
self.assertIsNotNone(new_pwd)
self.assertEqual(new_pwd, test_pwd)
self.assertEqual(self.cat.password, test_pwd)
11 changes: 6 additions & 5 deletions test/servicestests.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,15 +155,16 @@ def test_global_wms(self):
# test enums
attrs = [k for k in self.wms_enums.keys()]
for attr in attrs:
test_str = self.wms_enums[attr][random.randint(0, len(self.wms_enums[attr]))]
attr_values = self.wms_enums[attr]
test_str = attr_values[random.randint(0, len(attr_values) - 1)]
setattr(wms_srv, attr, test_str)
self.cat.save(wms_srv)
wms_srv.refresh()
self.assertIsNone(wms_srv.dirty.get(attr), msg=f"Attribute {attr} still in dirty list")
self.assertEqual(getattr(wms_srv, attr), test_str, msg=f"Invalid value for object {attr}")

# test int
attrs = [k for k in wms_srv.writers.keys() if isinstance(getattr(wms_srv, k), int)]
attrs = [k for k in wms_srv.writers.keys() if isinstance(getattr(wms_srv, k), int) and not isinstance(getattr(wms_srv, k), bool)]
for attr in attrs:
test_int = random.randint(1, 20)
setattr(wms_srv, attr, test_int)
Expand Down Expand Up @@ -213,7 +214,7 @@ def test_global_wfs(self):
self.assertEqual(getattr(wfs_srv, attr), test_str, msg=f"Invalid value for object {attr}")

# test int
attrs = [k for k in wfs_srv.writers.keys() if isinstance(getattr(wfs_srv, k), int)]
attrs = [k for k in wfs_srv.writers.keys() if isinstance(getattr(wfs_srv, k), int) and not isinstance(getattr(wfs_srv, k), bool)]
for attr in attrs:
test_int = random.randint(1, 20)
setattr(wfs_srv, attr, test_int)
Expand Down Expand Up @@ -263,7 +264,7 @@ def test_global_wcs(self):
self.assertEqual(getattr(wcs_srv, attr), test_str, msg=f"Invalid value for object {attr}")

# test int
attrs = [k for k in wcs_srv.writers.keys() if isinstance(getattr(wcs_srv, k), int)]
attrs = [k for k in wcs_srv.writers.keys() if isinstance(getattr(wcs_srv, k), int) and not isinstance(getattr(wcs_srv, k), bool)]
for attr in attrs:
test_int = random.randint(1, 20)
setattr(wcs_srv, attr, test_int)
Expand Down Expand Up @@ -313,7 +314,7 @@ def test_global_wmts(self):
self.assertEqual(getattr(wmts_srv, attr), test_str, msg=f"Invalid value for object {attr}")

# test int
attrs = [k for k in wmts_srv.writers.keys() if isinstance(getattr(wmts_srv, k), int)]
attrs = [k for k in wmts_srv.writers.keys() if isinstance(getattr(wmts_srv, k), int) and not isinstance(getattr(wmts_srv, k), bool)]
for attr in attrs:
test_int = random.randint(1, 20)
setattr(wmts_srv, attr, test_int)
Expand Down

0 comments on commit ff8b645

Please sign in to comment.