Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 37 additions & 10 deletions cbrain_cli/cli_utils.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,45 @@
import functools
import json
import re
import sys
import urllib.error

# import importlib.metadata
from cbrain_cli.config import CREDENTIALS_FILE
from cbrain_cli.config import ACTIVE_SESSION_KEY, CREDENTIALS_FILE

# Session name priority: --session flag > _active_session in cbrain.json > "default"
session_name = "default"
session_specified = False
for i, arg in enumerate(sys.argv):
if arg == "--session" and i + 1 < len(sys.argv):
session_name = sys.argv[i + 1]
session_specified = True
elif arg.startswith("--session="):
session_name = arg.split("=", 1)[1]
session_specified = True

try:
# MARK: Credentials.
with open(CREDENTIALS_FILE) as f:
credentials = json.load(f)
try:
with open(CREDENTIALS_FILE) as f:
all_credentials = json.load(f)
except FileNotFoundError:
all_credentials = {}

if not session_specified:
session_name = all_credentials.get(ACTIVE_SESSION_KEY, "default") or "default"

all_credentials.pop(ACTIVE_SESSION_KEY, None)

credentials = all_credentials.get(session_name, {})

# Get credentials.
cbrain_url = credentials.get("cbrain_url")
api_token = credentials.get("api_token")
user_id = credentials.get("user_id")
cbrain_timestamp = credentials.get("timestamp")
except FileNotFoundError:
cbrain_url = None
api_token = None
user_id = None
cbrain_timestamp = None
except Exception:
all_credentials = {}
cbrain_url = api_token = user_id = cbrain_timestamp = None


def is_authenticated():
Expand Down Expand Up @@ -85,7 +104,7 @@ def handle_connection_error(error):
if error.code == 401:
print(f"{status_description}: {error.reason}")
print("Error: Access denied. Please log in using authorized credentials.")
elif error.code == 404 or error.code == 422 or error.code == 500:
elif error.code in (400, 404, 422, 500):
# Try to extract specific error message from response
try:
# Check if the error response has already been read
Expand All @@ -107,6 +126,14 @@ def handle_connection_error(error):
or error_data.get("notice")
or str(error_data)
)
# Check if this looks like a password change redirect
if "change_password" in error_msg:
print(
f"{status_description}: Account requires "
"a password change. "
"Please log into the web portal."
)
return
print(f"{status_description}: {error_msg}")
return
except json.JSONDecodeError:
Expand Down
4 changes: 4 additions & 0 deletions cbrain_cli/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
SESSION_FILE_DIR.mkdir(parents=True, exist_ok=True)
CREDENTIALS_FILE = SESSION_FILE_DIR / SESSION_FILE_NAME

# Key used inside credentials.json to track the currently active session.
# Prefixed with "_" so it is clearly not a session name.
ACTIVE_SESSION_KEY = "_active_session"

# HTTP headers.
DEFAULT_HEADERS = {
"Content-Type": "application/x-www-form-urlencoded",
Expand Down
36 changes: 18 additions & 18 deletions cbrain_cli/data/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import urllib.request

from cbrain_cli.cli_utils import api_token, cbrain_url
from cbrain_cli.config import CREDENTIALS_FILE, auth_headers
from cbrain_cli.config import auth_headers
from cbrain_cli.sessions import save_credentials


def switch_project(args):
Expand All @@ -20,6 +21,8 @@ def switch_project(args):
dict or None
Dictionary containing project details if successful, None otherwise
"""
from cbrain_cli.cli_utils import all_credentials, session_name

# Get the group ID from the group_id argument
group_id = getattr(args, "group_id", None)
if not group_id:
Expand Down Expand Up @@ -56,15 +59,12 @@ def switch_project(args):
group_data = json.loads(group_data_text)

# Step 3: Update credentials file with current group_id
if CREDENTIALS_FILE.exists():
with open(CREDENTIALS_FILE) as f:
credentials = json.load(f)

credentials["current_group_id"] = group_id
credentials["current_group_name"] = group_data.get("name", "Unknown")

with open(CREDENTIALS_FILE, "w") as f:
json.dump(credentials, f, indent=2)
if session_name in all_credentials:
all_credentials[session_name]["current_group_id"] = group_id
all_credentials[session_name]["current_group_name"] = group_data.get(
"name", "Unknown"
)
save_credentials(all_credentials)

return group_data

Expand All @@ -83,6 +83,8 @@ def show_project(args):
dict or None
Dictionary containing project details if successful, None if no project set
"""
from cbrain_cli.cli_utils import all_credentials, session_name

# Check if a specific project ID was provided
project_id = getattr(args, "project_id", None)

Expand All @@ -105,10 +107,8 @@ def show_project(args):
raise
else:
# Show current project from credentials
with open(CREDENTIALS_FILE) as f:
credentials = json.load(f)

current_group_id = credentials.get("current_group_id")
session_creds = all_credentials.get(session_name, {})
current_group_id = session_creds.get("current_group_id")
if not current_group_id:
return None

Expand All @@ -128,10 +128,9 @@ def show_project(args):
if e.code == 404:
print(f"Error: Current project (ID {current_group_id}) no longer exists")
# Clear the invalid group_id from credentials
credentials.pop("current_group_id", None)
credentials.pop("current_group_name", None)
with open(CREDENTIALS_FILE, "w") as f:
json.dump(credentials, f, indent=2)
session_creds.pop("current_group_id", None)
session_creds.pop("current_group_name", None)
save_credentials(all_credentials)
return None
else:
raise
Expand Down Expand Up @@ -164,3 +163,4 @@ def list_projects(args):
projects_data = json.loads(data)

return projects_data

53 changes: 48 additions & 5 deletions cbrain_cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
handle_tool_list,
handle_tool_show,
)
from cbrain_cli.sessions import create_session, logout_session
from cbrain_cli.sessions import create_session, list_sessions, logout_session, switch_session
from cbrain_cli.users import whoami_user


Expand All @@ -60,6 +60,12 @@ def main():
action="store_true",
help="Output in JSONL format (one JSON object per line)",
)
parser.add_argument(
"--session",
type=str,
default="default",
help="Session name to use for multiple configurations (default: default)",
)

subparsers = parser.add_subparsers(dest="command", help="Available commands")

Expand All @@ -70,17 +76,47 @@ def main():
# MARK: Session commands (top-level)
# Create new session.
login_parser = subparsers.add_parser("login", help="Login to CBRAIN")
login_parser.add_argument("--session", type=str, help="Session name to use")
login_parser.add_argument("-u", "--username", type=str, help="CBRAIN username")
login_parser.add_argument("-p", "--password", type=str, help="CBRAIN password")
login_parser.add_argument("-s", "--server", type=str, help="CBRAIN server URL")
login_parser.set_defaults(func=handle_errors(create_session))

# Logout session.
logout_parser = subparsers.add_parser("logout", help="Logout from CBRAIN")
logout_parser.add_argument(
"--session", type=str, help="Session name to logout (default: all sessions)"
)
logout_parser.set_defaults(func=handle_errors(logout_session))

# Show current session.
whoami_parser = subparsers.add_parser("whoami", help="Show current session")
whoami_parser.add_argument("--session", type=str, help="Session name to show")
whoami_parser.add_argument("-v", "--version", action="store_true", help="Show version")
whoami_parser.set_defaults(func=handle_errors(whoami_user))

# Switch active session.
switch_session_parser = subparsers.add_parser(
"switch_session",
help="Switch the default session (e.g. cbrain switch_session prod)",
)
switch_session_parser.add_argument(
"session_target",
type=str,
help="Name of the session to make the default",
)
switch_session_parser.set_defaults(func=handle_errors(switch_session))

# Session management sub-commands.
session_parser = subparsers.add_parser("session", help="Session management")
session_subparsers = session_parser.add_subparsers(
dest="action", help="Session actions"
)
session_list_parser = session_subparsers.add_parser(
"list", help="List all saved sessions"
)
session_list_parser.set_defaults(func=handle_errors(list_sessions))

# MARK: Model-based commands
# File commands
file_parser = subparsers.add_parser("file", help="File operations")
Expand Down Expand Up @@ -404,22 +440,29 @@ def main():
parser.print_help()
return

# Handle session commands (no authentication needed for login, version, and whoami).
# Handle public commands (no authentication needed).
if args.command == "login":
return handle_errors(create_session)(args)
elif args.command == "logout":
return handle_errors(logout_session)(args)
elif args.command == "version":
return handle_errors(version_info)(args)
elif args.command == "whoami":
return handle_errors(whoami_user)(args)
elif args.command == "switch_session":
return handle_errors(switch_session)(args)
elif args.command == "session":
if not getattr(args, "action", None):
session_parser.print_help()
return 1
return args.func(args)

# All other commands require authentication.
if not is_authenticated():
return 1

# Handle authenticated commands.
if args.command == "logout":
return handle_errors(logout_session)(args)
elif args.command in [
if args.command in [
"file",
"dataprovider",
"project",
Expand Down
Loading