Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Frame access controls #61

Merged
merged 5 commits into from
Feb 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backend/app/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@
from .repositories import *
from .settings import *
from .templates import *
from .misc import *
13 changes: 10 additions & 3 deletions backend/app/api/frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ def api_frame_get_image(id: int):
frame = Frame.query.get_or_404(id)
cache_key = f'frame:{frame.frame_host}:{frame.frame_port}:image'
url = f'http://{frame.frame_host}:{frame.frame_port}/image'
if frame.frame_access == "private":
url += "?k=" + frame.frame_access_key

try:
if request.args.get('t') == '-1':
Expand Down Expand Up @@ -74,6 +76,8 @@ def api_frame_get_state(id: int):
frame = Frame.query.get_or_404(id)
cache_key = f'frame:{frame.frame_host}:{frame.frame_port}:state'
url = f'http://{frame.frame_host}:{frame.frame_port}/state'
if (frame.frame_access == "private" or frame.frame_access == "protected" )and frame.frame_access_key is not None:
url += "?k=" + frame.frame_access_key

try:
last_state = redis.get(cache_key)
Expand All @@ -99,11 +103,14 @@ def api_frame_get_state(id: int):
def api_frame_event(id: int, event: str):
frame = Frame.query.get_or_404(id)
try:
headers = {}
if (frame.frame_access == "protected" or frame.frame_access == "private") and frame.frame_access_key is not None:
headers["Authorization"] = f'Bearer {frame.frame_access_key}'
if request.is_json:
headers = {"Content-Type": "application/json"}
headers["Content-Type"] = "application/json"
response = requests.post(f'http://{frame.frame_host}:{frame.frame_port}/event/{event}', json=request.json, headers=headers)
else:
response = requests.post(f'http://{frame.frame_host}:{frame.frame_port}/event/{event}')
response = requests.post(f'http://{frame.frame_host}:{frame.frame_port}/event/{event}', headers=headers)
if response.status_code == 200:
return "OK", 200
else:
Expand Down Expand Up @@ -164,7 +171,7 @@ def api_frame_deploy_event(id: int):
@login_required
def api_frame_update(id: int):
frame = Frame.query.get_or_404(id)
fields = ['scenes', 'name', 'frame_host', 'frame_port', 'ssh_user', 'ssh_pass', 'ssh_port', 'server_host',
fields = ['scenes', 'name', 'frame_host', 'frame_port', 'frame_access_key', 'frame_access', 'ssh_user', 'ssh_pass', 'ssh_port', 'server_host',
'server_port', 'server_api_key', 'width', 'height', 'rotate', 'color', 'interval', 'metrics_interval',
'scaling_mode', 'background_color', 'device', 'debug']
defaults = {'frame_port': 8787, 'ssh_port': 22}
Expand Down
31 changes: 31 additions & 0 deletions backend/app/api/misc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from flask import jsonify
from flask_login import login_required
from . import api


@api.route("/generate_ssh_keys", methods=["POST"])
@login_required
def generate_ssh_keys():
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives import serialization

try:
private_key = rsa.generate_private_key(
public_exponent=65537,
key_size=3072,
)
except:
return jsonify(error="Key generation error"), 500

public_key = private_key.public_key()
private_key_bytes = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption()
)
public_key_bytes = public_key.public_bytes(
encoding=serialization.Encoding.OpenSSH,
format=serialization.PublicFormat.OpenSSH
)

return jsonify({"private": private_key_bytes.decode('utf-8'), "public": public_key_bytes.decode('utf-8')})
27 changes: 0 additions & 27 deletions backend/app/api/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,30 +32,3 @@ def set_settings():
return jsonify(error="Database error"), 500

return jsonify(get_settings_dict())

@api.route("/generate_ssh_keys", methods=["POST"])
@login_required
def generate_ssh_keys():
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives import serialization

try:
private_key = rsa.generate_private_key(
public_exponent=65537,
key_size=3072,
)
except:
return jsonify(error="Key generation error"), 500

public_key = private_key.public_key()
private_key_bytes = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption()
)
public_key_bytes = public_key.public_bytes(
encoding=serialization.Encoding.OpenSSH,
format=serialization.PublicFormat.OpenSSH
)

return jsonify({"private": private_key_bytes.decode('utf-8'), "public": public_key_bytes.decode('utf-8')})
32 changes: 32 additions & 0 deletions backend/app/api/tests/test_misc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import json
from unittest.mock import patch
from app.tests.base import BaseTestCase

class TestMiscAPI(BaseTestCase):
def test_generate_ssh_keys(self):
# Test the POST /generate_ssh_keys endpoint
response = self.client.post('/api/generate_ssh_keys')
self.assertEqual(response.status_code, 200)
keys = json.loads(response.data)
self.assertIn('private', keys)
self.assertIn('public', keys)

def test_unauthorized_access(self):
self.logout()

endpoints = [
('/api/settings', 'GET', None),
('/api/settings', 'POST', {'some_setting': 'value'}),
('/api/generate_ssh_keys', 'POST', None)
]
for endpoint, method, data in endpoints:
response = self.client.open(endpoint, method=method, json=data)
self.assertEqual(response.status_code, 401)

def test_generate_ssh_keys_error_handling(self):
# Simulate an error during key generation
with patch('cryptography.hazmat.primitives.asymmetric.rsa.generate_private_key') as mock_generate:
mock_generate.side_effect = Exception("Key generation error")
response = self.client.post('/api/generate_ssh_keys')
self.assertEqual(response.status_code, 500)

18 changes: 0 additions & 18 deletions backend/app/api/tests/test_settings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import json
from unittest.mock import patch
from app.tests.base import BaseTestCase

class TestSettingsAPI(BaseTestCase):
Expand All @@ -23,30 +22,13 @@ def test_set_settings_no_payload(self):
response = self.client.post('/api/settings', json={})
self.assertEqual(response.status_code, 400)

def test_generate_ssh_keys(self):
# Test the POST /generate_ssh_keys endpoint
response = self.client.post('/api/generate_ssh_keys')
self.assertEqual(response.status_code, 200)
keys = json.loads(response.data)
self.assertIn('private', keys)
self.assertIn('public', keys)

def test_unauthorized_access(self):
self.logout()

endpoints = [
('/api/settings', 'GET', None),
('/api/settings', 'POST', {'some_setting': 'value'}),
('/api/generate_ssh_keys', 'POST', None)
]
for endpoint, method, data in endpoints:
response = self.client.open(endpoint, method=method, json=data)
self.assertEqual(response.status_code, 401)

def test_generate_ssh_keys_error_handling(self):
# Simulate an error during key generation
with patch('cryptography.hazmat.primitives.asymmetric.rsa.generate_private_key') as mock_generate:
mock_generate.side_effect = Exception("Key generation error")
response = self.client.post('/api/generate_ssh_keys')
self.assertEqual(response.status_code, 500)

15 changes: 11 additions & 4 deletions backend/app/models/frame.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import json
import uuid
import secrets
from datetime import timezone
from app import db, socketio
from typing import Optional
from sqlalchemy.dialects.sqlite import JSON

from app.models.apps import get_app_configs
from app.models.settings import get_settings_dict
from app.utils.token import secure_token


# NB! Update frontend/src/types.tsx if you change this
Expand All @@ -17,6 +17,8 @@ class Frame(db.Model):
# sending commands to frame
frame_host = db.Column(db.String(256), nullable=False)
frame_port = db.Column(db.Integer, default=8787)
frame_access_key = db.Column(db.String(256), nullable=True)
frame_access = db.Column(db.String(50), nullable=True)
ssh_user = db.Column(db.String(50), nullable=True)
ssh_pass = db.Column(db.String(50), nullable=True)
ssh_port = db.Column(db.Integer, default=22)
Expand Down Expand Up @@ -51,6 +53,8 @@ def to_dict(self):
'name': self.name,
'frame_host': self.frame_host,
'frame_port': self.frame_port,
'frame_access_key': self.frame_access_key,
'frame_access': self.frame_access,
'ssh_user': self.ssh_user,
'ssh_pass': self.ssh_pass,
'ssh_port': self.ssh_port,
Expand All @@ -73,7 +77,6 @@ def to_dict(self):
'last_log_at': self.last_log_at.replace(tzinfo=timezone.utc).isoformat() if self.last_log_at else None,
}


def new_frame(name: str, frame_host: str, server_host: str, device: Optional[str] = None, interval: Optional[float] = None) -> Frame:
if '@' in frame_host:
user_pass, frame_host = frame_host.split('@')
Expand Down Expand Up @@ -102,11 +105,13 @@ def new_frame(name: str, frame_host: str, server_host: str, device: Optional[str
name=name,
ssh_user=user,
ssh_pass=password,
frame_host=frame_host,
ssh_port=ssh_port,
frame_host=frame_host,
frame_access_key=secure_token(20),
frame_access="private",
server_host=server_host,
server_port=int(server_port),
server_api_key=secrets.token_hex(32),
server_api_key=secure_token(32),
interval=interval or 60,
status="uninitialized",
apps=[],
Expand Down Expand Up @@ -195,6 +200,8 @@ def get_frame_json(frame: Frame) -> dict:
"name": frame.name,
"frameHost": frame.frame_host or "localhost",
"framePort": frame.frame_port or 8787,
"frameAccessKey": frame.frame_access_key,
"frameAccess": frame.frame_access,
"serverHost": frame.server_host or "localhost",
"serverPort": frame.server_port or 8989,
"serverApiKey": frame.server_api_key,
Expand Down
2 changes: 1 addition & 1 deletion backend/app/models/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def process_log(frame: Frame, log: dict):
if event == 'bootup':
if frame.status != 'ready':
changes['status'] = 'ready'
for key in ['frame_port', 'width', 'height', 'color', 'interval', 'metrics_interval', 'scaling_mode', 'rotate', 'background_color']:
for key in ['width', 'height', 'color']:
if key in log and log[key] is not None and log[key] != getattr(frame, key):
changes[key] = log[key]
if 'config' in log and key in log['config'] and log['config'][key] is not None and log['config'][key] != getattr(frame, key):
Expand Down
7 changes: 7 additions & 0 deletions backend/app/utils/token.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import secrets
from base64 import urlsafe_b64encode

def secure_token(bytes: int) -> str:
token_bytes = secrets.token_bytes(bytes)
token = urlsafe_b64encode(token_bytes).decode('utf-8').replace('=', '')
return token
32 changes: 32 additions & 0 deletions backend/migrations/versions/7f2a8719a009_frame_api_key.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""frame api key

Revision ID: 7f2a8719a009
Revises: 3145a02fc973
Create Date: 2024-02-04 23:31:37.112323

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = '7f2a8719a009'
down_revision = '3145a02fc973'
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('frame', schema=None) as batch_op:
batch_op.add_column(sa.Column('frame_access_key', sa.String(length=256), nullable=True))

# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('frame', schema=None) as batch_op:
batch_op.drop_column('frame_access_key')

# ### end Alembic commands ###
42 changes: 42 additions & 0 deletions backend/migrations/versions/f8db11069084_frame_access.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""frame access

Revision ID: f8db11069084
Revises: 7f2a8719a009
Create Date: 2024-02-06 09:35:11.908314

"""
from alembic import op
import sqlalchemy as sa
from app.utils.token import secure_token


# revision identifiers, used by Alembic.
revision = 'f8db11069084'
down_revision = '7f2a8719a009'
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('frame', schema=None) as batch_op:
batch_op.add_column(sa.Column('frame_access', sa.String(length=50), nullable=True))

from app.models import Frame
from app import db
frames = Frame.query.all()
for frame in frames:
frame.frame_access_key = secure_token(20)
frame.frame_access = "private"
db.session.add(frame)
db.session.commit()
db.session.flush()
# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('frame', schema=None) as batch_op:
batch_op.drop_column('frame_access')

# ### end Alembic commands ###
32 changes: 27 additions & 5 deletions frameos/assets/web/control.html
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,38 @@
position: relative;
}
</style>
<script>
const headers = {'Content-Type': 'application/json'}
if (window.location.search) {
for (const [key, value] of window.location.search.substring(1).split("&").map(p => p.split("="))) {
if (key === 'k') {
headers['Authorization'] = 'Bearer ' + value;
}
}
}
function postRender() {
fetch('/event/render', {
method: 'POST',
headers: headers,
body: JSON.stringify({})
})
}
function postSetSceneState() {
var data={render:true,state:{/*$$fieldsSubmitHtml$$*/}};
fetch('/event/setSceneState', {
method: 'POST',
headers: headers,
body: JSON.stringify(data)
});
document.getElementById('setSceneState').value = 'Now wait a while...';
}
</script>
</head>
<body>
<h1>Frame Control</h1>
<h2>Actions:</h2>
<script>function postRender() { fetch('/event/render', {method:'POST',headers:{'Content-Type': 'application/json'},body:JSON.stringify({})}) }</script>
<form onSubmit='postRender(); return false'><input type='submit' value='Render'></form>
<h2>State</h2>
<script>function postSetSceneState() { var data={render:true,state:{/*$$fieldsSubmitHtml$$*/}};fetch('/event/setSceneState', {method:'POST',headers:{'Content-Type': 'application/json'},body:JSON.stringify(data)}); document.getElementById('setSceneState').value = 'Now wait a while...'; }</script>
<form onSubmit='postSetSceneState(); return false'>
/*$$fieldsHtml$$*/
</form>
<form onSubmit='postSetSceneState(); return false'>/*$$fieldsHtml$$*/</form>
</body>
</html>
Loading
Loading