Skip to content
3 changes: 2 additions & 1 deletion scratchattach/site/comment.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Comment class"""
from __future__ import annotations

from typing import Union, Optional, assert_never, Any
from typing import Union, Optional, Any
from typing_extensions import assert_never # importing from typing caused me errors
from enum import Enum, auto

from . import user, project, studio
Expand Down
109 changes: 80 additions & 29 deletions scratchattach/site/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import re
import time
import warnings
import zlib

from typing import Optional, TypeVar, TYPE_CHECKING, overload, Any, Union
from contextlib import contextmanager
from threading import local
Expand Down Expand Up @@ -102,6 +104,8 @@ def __init__(self, **entries):

# Set attributes that Session object may get
self._user: user.User = None
self.time_created: datetime.datetime = None
self.language = "en" # default

# Update attributes from entries dict:
self.__dict__.update(entries)
Expand All @@ -119,6 +123,9 @@ def __init__(self, **entries):
"Content-Type": "application/json",
}

if self.id:
self._process_session_id()

def _update_from_dict(self, data: dict):
# Note: there are a lot more things you can get from this data dict.
# Maybe it would be a good idea to also store the dict itself?
Expand All @@ -141,13 +148,34 @@ def _update_from_dict(self, data: dict):
self.banned = data["user"]["banned"]

if self.banned:
warnings.warn(f"Warning: The account {self._username} you logged in to is BANNED. "
warnings.warn(f"Warning: The account {self.username} you logged in to is BANNED. "
f"Some features may not work properly.")
if self.has_outstanding_email_confirmation:
warnings.warn(f"Warning: The account {self._username} you logged is not email confirmed. "
warnings.warn(f"Warning: The account {self.username} you logged is not email confirmed. "
f"Some features may not work properly.")
return True

def _process_session_id(self):
assert self.id

data, self.time_created = decode_session_id(self.id)

self.username = data["username"]
self._username = self.username
if self._user:
self._user.username = self.username
else:
self._user = user.User(_session=self, username=self.username)

self._user.id = data["_auth_user_id"]
self.xtoken = data["token"]
self._headers["X-Token"] = self.xtoken

# not saving the login ip because it is a security issue, and is not very helpful

self.language = data["_language"]
# self._cookies["scratchlanguage"] = self.language

def connect_linked_user(self) -> user.User:
"""
Gets the user associated with the login / session.
Expand All @@ -166,7 +194,7 @@ def connect_linked_user(self) -> user.User:
self._user = self.connect_user(self._username)
return self._user

def get_linked_user(self) -> 'user.User':
def get_linked_user(self) -> user.User:
# backwards compatibility with v1

# To avoid inconsistencies with "connect" and "get", this function was renamed
Expand Down Expand Up @@ -1021,10 +1049,43 @@ def get_headers(self) -> dict[str, str]:
def get_cookies(self) -> dict[str, str]:
return self._cookies


# ------ #

def decode_session_id(session_id: str) -> tuple[dict[str, str], datetime.datetime]:
"""
Extract the JSON data from the main part of a session ID string
Session id is in the format:
<p1: long base64 string>:<p2: short base64 string>:<p3: medium base64 string>

p1 contains a base64-zlib compressed JSON string
p2 is a base 62 encoded timestamp
p3 might be a `synchronous signature` for the first 2 parts (might be useless for us)

The dict has these attributes:
- username
- _auth_user_id
- testcookie
- _auth_user_backend
- token
- login-ip
- _language
- django_timezone
- _auth_user_hash
"""
p1, p2, p3 = session_id.split(':')

return (
json.loads(zlib.decompress(base64.urlsafe_b64decode(p1 + "=="))),
datetime.datetime.fromtimestamp(commons.b62_decode(p2))
)


# ------ #

suppressed_login_warning = local()


@contextmanager
def suppress_login_warning():
"""
Expand All @@ -1037,6 +1098,7 @@ def suppress_login_warning():
finally:
suppressed_login_warning.suppressed -= 1


def issue_login_warning() -> None:
"""
Issue a login data warning.
Expand All @@ -1051,6 +1113,7 @@ def issue_login_warning() -> None:
exceptions.LoginDataWarning
)


def login_by_id(session_id: str, *, username: Optional[str] = None, password: Optional[str] = None, xtoken=None) -> Session:
"""
Creates a session / log in to the Scratch website with the specified session id.
Expand All @@ -1067,39 +1130,23 @@ def login_by_id(session_id: str, *, username: Optional[str] = None, password: Op
Returns:
scratchattach.session.Session: An object that represents the created login / session
"""
# Removed this from docstring since it doesn't exist:
# timeout (int): Optional, but recommended.
# Specify this when the Python environment's IP address is blocked by Scratch's API,
# but you still want to use cloud variables.

# Generate session_string (a scratchattach-specific authentication method)
issue_login_warning()
if password is not None:
session_data = dict(id=session_id, username=username, password=password)
session_string = base64.b64encode(json.dumps(session_data).encode()).decode()
else:
session_string = None
_session = Session(id=session_id, username=username, session_string=session_string, xtoken=xtoken)

try:
status = _session.update()
except Exception as e:
status = False
warnings.warn(f"Key error at key {e} when reading scratch.mit.edu/session API response")

if status is not True:
if _session.xtoken is None:
if _session.username is None:
warnings.warn("Warning: Logged in by id, but couldn't fetch XToken. "
"Make sure the provided session id is valid. "
"Setting cloud variables can still work if you provide a "
"`username='username'` keyword argument to the sa.login_by_id function")
else:
warnings.warn("Warning: Logged in by id, but couldn't fetch XToken. "
"Make sure the provided session id is valid.")
else:
warnings.warn("Warning: Logged in by id, but couldn't fetch session info. "
"This won't affect any other features.")
if xtoken is not None:
# todo: consider removing the xtoken parameter?
warnings.warn("xtoken is redundant because it is retrieved by decoding the session id.")

_session = Session(id=session_id, username=username, session_string=session_string)

# xtoken is decoded from sessid, so don't use sess.update
# but this will cause incompatibilities, warranting a change in the 2nd (semver) version number

return _session


Expand Down Expand Up @@ -1136,14 +1183,15 @@ def login(username, password, *, timeout=10) -> Session:
result = re.search('"(.*)"', request.headers["Set-Cookie"])
assert result is not None
session_id = str(result.group())
except (AssertionError, Exception):
except Exception:
raise exceptions.LoginFailure(
"Either the provided authentication data is wrong or your network is banned from Scratch.\n\nIf you're using an online IDE (like replit.com) Scratch possibly banned its IP address. In this case, try logging in with your session id: https://github.com/TimMcCool/scratchattach/wiki#logging-in")

# Create session object:
with suppress_login_warning():
return login_by_id(session_id, username=username, password=password)


def login_by_session_string(session_string: str) -> Session:
"""
Login using a session string.
Expand Down Expand Up @@ -1173,20 +1221,23 @@ def login_by_session_string(session_string: str) -> Session:
pass
raise ValueError("Couldn't log in.")


def login_by_io(file: SupportsRead[str]) -> Session:
"""
Login using a file object.
"""
with suppress_login_warning():
return login_by_session_string(file.read())


def login_by_file(file: FileDescriptorOrPath) -> Session:
"""
Login using a path to a file.
"""
with suppress_login_warning(), open(file, encoding="utf-8") as f:
return login_by_io(f)


def login_from_browser(browser: Browser = ANY):
"""
Login from a browser
Expand Down
14 changes: 14 additions & 0 deletions scratchattach/utils/commons.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""v2 ready: Common functions used by various internal modules"""
from __future__ import annotations

import string

from typing import Optional, Final, Any, TypeVar, Callable, TYPE_CHECKING, Union
from threading import Lock

Expand All @@ -9,6 +11,7 @@

from ..site import _base


headers: Final = {
"user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 "
"(KHTML, like Gecko) Chrome/75.0.3770.142 Safari/537.36",
Expand Down Expand Up @@ -241,3 +244,14 @@ def get_class_sort_mode(mode: str) -> tuple[str, str]:
descsort = "title"

return ascsort, descsort


def b62_decode(s: str):
chars = string.digits + string.ascii_uppercase + string.ascii_lowercase

ret = 0
for char in s:
ret = ret * 62 + chars.index(char)

return ret