Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(#984): manage super user workspaces #1268

Merged
merged 14 commits into from Mar 22, 2022
20 changes: 17 additions & 3 deletions frontend/components/commons/header/user/user.vue
Expand Up @@ -27,9 +27,9 @@
{{ user.username }}<span>Private Workspace</span>
</p>
</a>
<p v-if="user.workspaces">Team workspaces</p>
<p v-if="userWorkspaces">Team workspaces</p>
<a
v-for="workspace in user.workspaces"
v-for="workspace in userWorkspaces"
:key="workspace"
href="#"
class="user__workspace"
Expand All @@ -53,7 +53,11 @@

<script>
import { mapActions } from "vuex";
import { setWorkspace, currentWorkspace } from "@/models/Workspace";
import {
setWorkspace,
currentWorkspace,
NO_WORKSPACE,
} from "@/models/Workspace";
export default {
data: () => {
return {
Expand All @@ -65,6 +69,16 @@ export default {
user() {
return this.$auth.user;
},
userWorkspaces() {
return (this.user.workspaces || [])
.map((ws) => {
if (ws === "") {
return NO_WORKSPACE;
}
return ws;
})
.filter((ws) => ws !== this.user.username);
},
currentWorkspace() {
return currentWorkspace(this.$route);
},
Expand Down
8 changes: 2 additions & 6 deletions frontend/database/modules/datasets.js
Expand Up @@ -17,7 +17,7 @@
import { ObservationDataset, USER_DATA_METADATA_KEY } from "@/models/Dataset";
import { DatasetViewSettings, Pagination } from "@/models/DatasetViewSettings";
import { AnnotationProgress } from "@/models/AnnotationProgress";
import { currentWorkspace, defaultWorkspace } from "@/models/Workspace";
import { currentWorkspace, NO_WORKSPACE } from "@/models/Workspace";
import { Base64 } from "js-base64";

const isObject = (obj) => obj && typeof obj === "object";
Expand Down Expand Up @@ -526,9 +526,6 @@ const actions = {

async deleteDataset(_, { workspace, name }) {
var url = `/datasets/${name}`;
if (workspace !== defaultWorkspace($nuxt.$auth.user)) {
url += `?workspace=${workspace}`;
}
const deleteResults = await ObservationDataset.api().delete(url, {
delete: [workspace, name],
});
Expand All @@ -547,9 +544,8 @@ const actions = {
return await ObservationDataset.api().get("/datasets/", {
persistBy: "create",
dataTransformer: ({ data }) => {
const owner = defaultWorkspace($nuxt.$auth.user);
return data.map((datasource) => {
datasource.owner = datasource.owner || owner;
datasource.owner = datasource.owner || NO_WORKSPACE;
return datasource;
});
},
Expand Down
9 changes: 8 additions & 1 deletion frontend/models/Workspace.js
Expand Up @@ -30,4 +30,11 @@ function setWorkspace(router, workspace) {
router.push(workspaceHome(workspace));
}

export { defaultWorkspace, currentWorkspace, setWorkspace, workspaceHome };
const NO_WORKSPACE = "-";
export {
defaultWorkspace,
currentWorkspace,
setWorkspace,
workspaceHome,
NO_WORKSPACE,
};
8 changes: 5 additions & 3 deletions frontend/plugins/vuex-orm-axios.js
Expand Up @@ -19,7 +19,7 @@ import { Model } from "@vuex-orm/core";
import { ExpiredAuthSessionError } from "@nuxtjs/auth-next/dist/runtime";
import { Notification } from "@/models/Notifications";

import { currentWorkspace, defaultWorkspace } from "@/models/Workspace";
import { currentWorkspace, NO_WORKSPACE } from "@/models/Workspace";

export default ({ $axios, app }) => {
Model.setAxios($axios);
Expand All @@ -31,8 +31,10 @@ export default ({ $axios, app }) => {
return config;
}

const ws = currentWorkspace(app.context.route);
if (ws && ws !== defaultWorkspace(currentUser)) {
let ws = currentWorkspace(app.context.route);
if (ws === NO_WORKSPACE) {
config.headers["X-Rubrix-Workspace"] = "";
} else if (ws) {
config.headers["X-Rubrix-Workspace"] = ws;
}
return config;
Expand Down
7 changes: 3 additions & 4 deletions src/rubrix/client/api.py
Expand Up @@ -111,7 +111,7 @@ def __init__(
)
self._user: User = whoami(client=self._client)

if workspace:
if workspace is not None:
self.set_workspace(workspace)

def set_workspace(self, workspace: str):
Expand All @@ -125,9 +125,8 @@ def set_workspace(self, workspace: str):

if workspace != self.get_workspace():
if workspace == self._user.username:
self._client.headers.pop(RUBRIX_WORKSPACE_HEADER_NAME, None)
return
if (
self._client.headers.pop(RUBRIX_WORKSPACE_HEADER_NAME, workspace)
elif (
self._user.workspaces is not None
and workspace not in self._user.workspaces
):
Expand Down
9 changes: 8 additions & 1 deletion src/rubrix/server/commons/api.py
Expand Up @@ -3,6 +3,7 @@
from fastapi import Header, Query

from rubrix._constants import RUBRIX_WORKSPACE_HEADER_NAME
from rubrix.server.security.model import WORKSPACE_NAME_PATTERN


@dataclass
Expand All @@ -20,4 +21,10 @@ class CommonTaskQueryParams:
@property
def workspace(self) -> str:
"""Return read workspace. Query param prior to header param"""
return self.__workspace_param__ or self.__workspace_header__
workspace = self.__workspace_param__ or self.__workspace_header__
if workspace:
assert WORKSPACE_NAME_PATTERN.match(workspace), (
"Wrong workspace format. "
f"Workspace must match pattern {WORKSPACE_NAME_PATTERN.pattern}"
)
return workspace
6 changes: 3 additions & 3 deletions src/rubrix/server/commons/errors/base_errors.py
@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional, Type
from typing import Any, Dict, Optional, Type, Union

import pydantic
from starlette import status
Expand Down Expand Up @@ -151,9 +151,9 @@ class EntityNotFoundError(RubrixServerError):

HTTP_STATUS = status.HTTP_404_NOT_FOUND

def __init__(self, name: str, type: Type):
def __init__(self, name: str, type: Union[Type, str]):
self.name = name
self.type = type.__name__
self.type = type if isinstance(type, str) else type.__name__


class ClosedDatasetError(BadRequestError):
Expand Down
13 changes: 13 additions & 0 deletions src/rubrix/server/commons/es_wrapper.py
Expand Up @@ -23,6 +23,8 @@
from rubrix.logging import LoggingMixin
from rubrix.server.commons.errors import InvalidTextSearchError

from . import es_helpers

try:
import ujson as json
except ModuleNotFoundError:
Expand Down Expand Up @@ -544,6 +546,17 @@ def get_cluster_info(self) -> Dict[str, Any]:
except OpenSearchException as ex:
return {"error": ex}

def aggregate(self, index: str, aggregation: Dict[str, Any]) -> Dict[str, Any]:
"""Apply an aggregation over the index returning ONLY the agg results"""
aggregation_name = "aggregation"
results = self.search(
index=index, size=0, query={"aggs": {aggregation_name: aggregation}}
)

return es_helpers.parse_aggregations(results["aggregations"]).get(
aggregation_name
)


_instance = None # The singleton instance

Expand Down
4 changes: 2 additions & 2 deletions src/rubrix/server/datasets/api.py
Expand Up @@ -53,11 +53,11 @@ def list_datasets(

Returns
-------
A list of datasets visibles by current user
A list of datasets visible by current user
"""
return service.list(
user=current_user,
workspaces=[ds_params.workspace],
workspaces=[ds_params.workspace] if ds_params.workspace is not None else None,
)


Expand Down
35 changes: 33 additions & 2 deletions src/rubrix/server/datasets/dao.py
Expand Up @@ -30,6 +30,8 @@

BaseDatasetDB = TypeVar("BaseDatasetDB", bound=DatasetDB)

NO_WORKSPACE = ""


class DatasetsDAO:
"""Datasets DAO"""
Expand Down Expand Up @@ -97,9 +99,24 @@ def list_datasets(
filters = []
dataset_type = DatasetDB
if owner_list:
filters.append({"terms": {"owner.keyword": owner_list}})
owners_filter = es_helpers.filters.terms_filter("owner.keyword", owner_list)
if NO_WORKSPACE in owner_list:
filters.append(
es_helpers.filters.boolean_filter(
minimum_should_match=1, # OR Condition
should_filters=[
es_helpers.filters.boolean_filter(
must_not_query=es_helpers.filters.exists_field("owner")
),
owners_filter,
],
)
)
else:
filters.append(owners_filter)

if task:
filters.append({"term": {"task.keyword": task}})
filters.append(es_helpers.filters.term_filter("task.keyword", task))
dataset_type = TaskFactory.get_task_dataset(task)

docs = self._es.list_documents(
Expand Down Expand Up @@ -301,3 +318,17 @@ def close(self, dataset: DatasetDB):
def open(self, dataset: DatasetDB):
"""Make available a dataset"""
self._es.open_index(dataset_records_index(dataset.id))

def get_all_workspaces(self) -> List[str]:
"""Get all datasets (Only for super users)"""

workspaces_dict = self._es.aggregate(
index=DATASETS_INDEX_NAME,
aggregation=es_helpers.aggregations.terms_aggregation(
"owner.keyword",
missing=NO_WORKSPACE,
size=500, # TODO: A max number of workspaces env var could be leveraged by this.
),
)

return [k for k in workspaces_dict]
18 changes: 16 additions & 2 deletions src/rubrix/server/datasets/service.py
Expand Up @@ -18,6 +18,7 @@

from fastapi import Depends

from rubrix.server.commons import es_helpers
from rubrix.server.commons.errors import (
EntityAlreadyExistsError,
EntityNotFoundError,
Expand Down Expand Up @@ -108,14 +109,20 @@ def update(
return self.__dao__.update_dataset(updated)

def list(
self, user: User, workspaces: List[str], task: Optional[TaskType] = None
self,
user: User,
workspaces: Optional[List[str]],
task: Optional[TaskType] = None,
) -> List[Dataset]:
owners = user.check_workspaces(workspaces)

datasets = []
for task_config in TaskFactory.get_all_configs():
datasets.extend(
self.__dao__.list_datasets(owner_list=owners, task=task_config.task)
self.__dao__.list_datasets(
owner_list=owners,
task=task_config.task,
)
)
return datasets

Expand Down Expand Up @@ -173,3 +180,10 @@ def copy_dataset(
)

return copy_dataset

def all_workspaces(self) -> List[str]:
"""Retrieve all dataset workspaces"""

workspaces = self.__dao__.get_all_workspaces()
# include the non-workspace workspace?
return workspaces
20 changes: 12 additions & 8 deletions src/rubrix/server/security/auth_provider/local/provider.py
Expand Up @@ -14,21 +14,24 @@
# limitations under the License.

from datetime import datetime, timedelta
from typing import Optional

from fastapi import APIRouter, Depends
from fastapi.security import (
OAuth2PasswordBearer,
OAuth2PasswordRequestForm,
SecurityScopes,
)
from jose import JWTError, jwt

from rubrix.server.commons.errors import InactiveUserError, UnauthorizedError
from rubrix.server.security.auth_provider.base import (
AuthProvider,
api_key_header,
)
from rubrix.server.commons.es_wrapper import create_es_wrapper
from rubrix.server.datasets.dao import DatasetsDAO
from rubrix.server.datasets.service import DatasetsService
from rubrix.server.security.auth_provider.base import AuthProvider, api_key_header
from rubrix.server.security.auth_provider.local.users.service import UsersService
from rubrix.server.security.model import Token, User
from typing import Optional
from rubrix.server.tasks.commons.dao.dao import DatasetRecordsDAO

from .settings import Settings, settings

Expand Down Expand Up @@ -171,10 +174,11 @@ async def _find_user_by_api_key(self, api_key) -> User:

def create_local_auth_provider():
from .users.dao import create_users_dao
from .users.service import create_users_service

settings = Settings()
users_dao = create_users_dao()
users_service = create_users_service(users_dao)

users_service = UsersService.get_instance(
users=create_users_dao(),
)

return LocalAuthProvider(users=users_service, settings=settings)
8 changes: 7 additions & 1 deletion src/rubrix/server/security/auth_provider/local/users/dao.py
Expand Up @@ -13,9 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Iterable, Optional

import yaml

from rubrix.server.security.auth_provider.local.settings import settings
from typing import Dict, Optional

from .model import UserInDB

Expand Down Expand Up @@ -49,6 +51,10 @@ async def get_user_by_api_key(self, api_key: str) -> Optional[UserInDB]:
if api_key == user.api_key:
return user

def all_users(self) -> Iterable[UserInDB]:
for user in self.__users__.values():
yield user


_instance: Optional[UsersDAO] = None

Expand Down