diff --git a/CHANGELOG.md b/CHANGELOG.md index d08048f..bc845c8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,25 @@ All notable changes to ExaFS will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [1.3.0] - 2026-04-20 + +### Changed +- Migrated entire codebase to SQLAlchemy 2.0 Query API — replaced all legacy `session.query()` / `Model.query` patterns with `session.execute(select())`, `session.scalars()`, and `db.paginate()` across models, services, and views +- Moved DB queries out of views into model classmethods following separation of concerns (`ApiKey`, `Flowspec4/6`, `RTBH`, `Community`, `Action`, `Log`, `User`, `Role`, `Organization`, `ASPath`, etc.) +- New `get_org_rule_stats()` utility replaces inline admin view queries + +### Fixed +- API response discrepancy: rules returned via API now apply the same data transformation as the UI, fixing inconsistency between `/api/v3/` responses and dashboard display + +### Added +- Extensive new test coverage for previously untested DB-touching code paths (`test_auth.py`, `test_model_utils.py`, `test_services_base.py`, `test_admin_models.py`, additions to `test_flowapp.py`, `test_messages.py`, `test_models.py`, `test_whitelist_service.py`) +- Tests for all new model classmethods (added before implementation, TDD-style) +- ExaBGP 5.x support via `EXABGP_MAJOR_VERSION` config option (default: `4`) + - TCP flags formatted as `tcp-flags [ syn ack ];` (lowercase, bracketed) when version is 5 + - Fragment conditions use updated `!is-fragment` syntax for version 5 + - `IPV4_FRAGMENT_V5` constants dict added for version 5 fragment mappings +- Unit tests for ExaBGP message formatting helpers (`tests/test_messages.py`) + ## [1.2.2] - 2026-02-19 ### Changed @@ -304,6 +323,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Route Distinguisher for VRF now supported - See config example and update your `config.py` +[1.3.0]: https://github.com/CESNET/exafs/compare/v1.2.2...v1.3.0 [1.2.2]: https://github.com/CESNET/exafs/compare/v1.2.1...v1.2.2 [1.2.1]: https://github.com/CESNET/exafs/compare/v1.2.0...v1.2.1 [1.2.0]: https://github.com/CESNET/exafs/compare/v1.1.9...v1.2.0 diff --git a/CLAUDE.md b/CLAUDE.md index 5e4549a..24a988b 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -1,837 +1,6 @@ # CLAUDE.md - AI Assistant Guide for ExaFS -## Project Overview +## Plan Mode -**ExaFS** is a Flask-based web application for managing ExaBGP (Border Gateway Protocol) rules to prevent DDoS and other malicious cyber attacks. It provides a user interface and REST API for creating, validating, and executing BGP FlowSpec and RTBH (Remotely Triggered Black Hole) rules. - -- **Current Version**: 1.1.9 -- **License**: MIT -- **Primary Language**: Python (3.9+) -- **Framework**: Flask -- **Organization**: CESNET (Czech national e-infrastructure for science, research and education) -- **PyPI Package**: `exafs` -- **Total Lines of Code**: ~11,500 lines - -## Key Features - -1. **User Authorization**: Role-based access control for BGP rule management -2. **Validation System**: Syntax and access rights validation before rule storage -3. **Rule Types**: IPv4 FlowSpec, IPv6 FlowSpec, and RTBH rules -4. **Authentication Methods**: SSO (Shibboleth), HTTP Header Auth, or Local Auth -5. **REST API**: Swagger-documented API with token-based authentication -6. **Rule Persistence**: Database storage with automatic restoration after system reboot -7. **Whitelist System**: Automated rule creation from whitelists - -## Repository Structure - -``` -exafs/ -├── flowapp/ # Main application package -│ ├── __init__.py # Flask app factory -│ ├── __about__.py # Version and metadata -│ ├── auth.py # Authentication decorators -│ ├── constants.py # Application constants -│ ├── flowspec.py # FlowSpec rule translation logic -│ ├── validators.py # Form and data validators -│ ├── messages.py # User-facing messages -│ ├── output.py # ExaBGP output formatting -│ ├── instance_config.py # Default dashboard configuration -│ │ -│ ├── models/ # SQLAlchemy database models -│ │ ├── __init__.py # Model exports -│ │ ├── base.py # Base models and relationships -│ │ ├── user.py # User and Role models -│ │ ├── organization.py # Organization model -│ │ ├── api.py # API key models -│ │ ├── community.py # BGP Community models -│ │ ├── log.py # Audit log model -│ │ ├── utils.py # Model utility functions -│ │ └── rules/ # Rule-specific models -│ │ ├── flowspec.py # IPv4/IPv6 FlowSpec models -│ │ ├── rtbh.py # RTBH model -│ │ ├── whitelist.py # Whitelist models -│ │ └── base.py # Base rule models -│ │ -│ ├── forms/ # WTForms form definitions -│ │ ├── base.py # Base form classes -│ │ ├── user.py # User management forms -│ │ ├── organization.py # Organization forms -│ │ ├── api.py # API key forms -│ │ ├── choices.py # Form choice generators -│ │ └── rules/ # Rule-specific forms -│ │ ├── ipv4.py # IPv4 FlowSpec forms -│ │ ├── ipv6.py # IPv6 FlowSpec forms -│ │ ├── rtbh.py # RTBH forms -│ │ └── whitelist.py # Whitelist forms -│ │ -│ ├── views/ # Flask Blueprints (routes) -│ │ ├── __init__.py # Blueprint registration -│ │ ├── dashboard.py # Main dashboard views -│ │ ├── rules.py # Rule CRUD operations -│ │ ├── whitelist.py # Whitelist management -│ │ ├── admin.py # Admin operations -│ │ ├── api_keys.py # API key management -│ │ ├── api_v1.py # API v1 endpoints (deprecated) -│ │ ├── api_v2.py # API v2 endpoints -│ │ ├── api_v3.py # API v3 endpoints (current) -│ │ └── api_common.py # Common API utilities -│ │ -│ ├── services/ # Business logic layer -│ │ ├── base.py # Base service classes -│ │ ├── rule_service.py # Rule creation/modification logic -│ │ ├── whitelist_service.py # Whitelist rule generation -│ │ └── whitelist_common.py # Whitelist utilities -│ │ -│ ├── utils/ # Utility functions -│ │ └── [various utilities] -│ │ -│ ├── templates/ # Jinja2 templates -│ │ ├── layouts/ # Base layouts -│ │ ├── pages/ # Page templates -│ │ ├── forms/ # Form templates -│ │ └── errors/ # Error pages -│ │ -│ ├── static/ # Static assets -│ │ ├── js/ # JavaScript files -│ │ └── swagger.yml # Swagger API specification -│ │ -│ └── tests/ # Test suite -│ ├── test_flowapp.py # Basic app tests -│ ├── test_models.py # Model tests -│ ├── test_forms.py # Form validation tests -│ ├── test_validators.py # Validator tests -│ ├── test_flowspec.py # FlowSpec translation tests -│ ├── test_api_*.py # API endpoint tests -│ ├── test_rule_service.py # Service layer tests -│ └── test_whitelist_*.py # Whitelist tests -│ -├── docs/ # Documentation -│ ├── INSTALL.md # Installation guide -│ ├── API.md # API documentation -│ ├── AUTH.md # Authentication setup -│ ├── DB_*.md # Database guides -│ └── guarda-service/ # Rule restoration service docs -│ -├── config.example.py # Configuration template -├── instance_config_override.example.py # Dashboard override template -├── run.example.py # Application run script template -├── scripts/ -│ ├── db-init.py # Database initialization (runs flask db upgrade) -│ ├── create-admin.py # Interactive first admin user setup -│ └── migrate_v0x_to_v1.py # Optional v0.x to v1.0+ migration helper -├── pyproject.toml # Project metadata and dependencies -├── setup.cfg # Setup configuration -├── CHANGELOG.md # Version history -└── README.md # Project documentation -``` - -## Technology Stack - -### Core Dependencies -- **Flask** (>=2.0.2) - Web framework -- **Flask-SQLAlchemy** (>=2.2) - ORM -- **Flask-Migrate** (>=3.0.0) - Database migrations -- **Flask-WTF** (>=1.0.0) - Form handling with CSRF protection -- **Flask-SSO** (>=0.4.0) - Shibboleth authentication -- **Flask-Session** - Server-side sessions -- **PyJWT** (>=2.4.0) - JWT token authentication for API -- **PyMySQL** (>=1.0.0) - MySQL database driver -- **Flasgger** - Swagger API documentation -- **Pika** (>=1.3.0) - RabbitMQ client -- **Loguru** - Logging -- **Babel** (>=2.7.0) - Internationalization - -### Development Dependencies -- **pytest** (>=7.0.0) - Testing framework -- **flake8** - Code linting - -### Supported Python Versions -- Python 3.9, 3.10, 3.11, 3.12, 3.13 - -### Database -- **Primary**: MariaDB/MySQL -- **Supported**: Any SQLAlchemy-compatible database - -## Application Architecture - -### MVC Pattern -The application follows a structured MVC pattern: - -1. **Models** (`flowapp/models/`) - SQLAlchemy ORM models -2. **Views** (`flowapp/views/`) - Flask Blueprints handling routes -3. **Forms** (`flowapp/forms/`) - WTForms for validation -4. **Services** (`flowapp/services/`) - Business logic layer - -### Key Design Patterns - -1. **Factory Pattern**: App creation via `create_app()` function -2. **Blueprint Pattern**: Modular route organization -3. **Service Layer**: Business logic separated from views -4. **Validator Pattern**: Custom validators for BGP rule syntax -5. **Repository Pattern**: Model utility functions for data access - -### Authentication Flow - -1. **SSO Auth** (Production): Shibboleth authentication via Flask-SSO -2. **Header Auth**: External authentication via HTTP headers -3. **Local Auth** (Development): Direct UUID-based authentication -4. **API Auth**: JWT token-based authentication - -## Database Models - -### Core Models - -#### User Management -- **User**: User accounts with UUID -- **Role**: User roles (admin, user, api_only, etc.) -- **Organization**: Network organizations -- **user_role**: Many-to-many relationship table -- **user_organization**: Many-to-many relationship table - -#### Rule Models -- **Flowspec4**: IPv4 FlowSpec rules -- **Flowspec6**: IPv6 FlowSpec rules -- **RTBH**: Remotely Triggered Black Hole rules -- **Rstate**: Rule states (active, expired, etc.) -- **Action**: BGP actions (rate-limit, discard, etc.) - -#### Supporting Models -- **Community**: BGP communities -- **ASPath**: AS path configurations -- **Whitelist**: Automated rule templates -- **RuleWhitelistCache**: Whitelist-generated rule tracking -- **Log**: Audit logging -- **ApiKey**: API authentication tokens -- **MachineApiKey**: Machine-to-machine API keys - -### Important Model Relationships - -```python -# Users belong to multiple organizations -User.organization -> Organization (many-to-many via user_organization) - -# Users have multiple roles -User.roles -> Role (many-to-many via user_role) - -# Rules belong to creators and organizations -Flowspec4.creator_id -> User.id -Flowspec4.organization_id -> Organization.id - -# Rules can originate from whitelists -Flowspec4.whitelist_id -> Whitelist.id (nullable) -``` - -## Configuration - -### Configuration Hierarchy - -1. **Base Config** (`config.py:Config`) - Default settings -2. **Environment Config** (Production/Development/Testing classes) -3. **Instance Config** (`flowapp/instance_config.py`) - Dashboard configuration -4. **Instance Override** (`instance/config_override.py`) - Local overrides - -### Critical Configuration Options - -```python -# Authentication method (choose one) -SSO_AUTH = False # Shibboleth SSO -HEADER_AUTH = False # HTTP header auth -LOCAL_AUTH = False # Local development auth - -# Database connection -SQLALCHEMY_DATABASE_URI = "mysql+pymysql://user:pass@host/db" - -# ExaBGP API method -EXA_API = "RABBIT" # or "HTTP" -EXA_API_RABBIT_HOST = "hostname" -EXA_API_RABBIT_PORT = 5672 - -# Security -JWT_SECRET = "random-secret" -SECRET_KEY = "random-secret" - -# BGP Configuration -USE_RD = True # Route Distinguisher -RD_STRING = "7654:3210" -RD_LABEL = "label" - -# Rule limits -FLOWSPEC4_MAX_RULES = 9000 -FLOWSPEC6_MAX_RULES = 9000 -RTBH_MAX_RULES = 100000 - -# Rule expiration -EXPIRATION_THRESHOLD = 30 # days -``` - -## Development Workflow - -### Initial Setup - -```bash -# Clone repository -git clone https://github.com/CESNET/exafs.git -cd exafs - -# Create virtual environment -python3.9 -m venv venv -source venv/bin/activate - -# Install dependencies -pip install -e .[dev] - -# Copy configuration templates -cp config.example.py config.py -cp run.example.py run.py - -# Edit config.py with database credentials and settings - -# Initialize database (runs flask db upgrade) -python scripts/db-init.py - -# Create the first admin user and organization -python scripts/create-admin.py - -# Run tests -pytest - -# Run development server -python run.py -``` - -### Database Migrations - -Migration files are tracked in `migrations/versions/` and committed to git. - -```bash -# Create a new migration after model changes -flask db migrate -m "Description of changes" - -# Apply migrations -flask db upgrade - -# Rollback migration -flask db downgrade - -# For existing databases adopting migrations for the first time -flask db stamp 001_baseline -``` - -### Running Tests - -```bash -# Run all tests -pytest - -# Run specific test file -pytest flowapp/tests/test_models.py - -# Run with coverage -pytest --cov=flowapp - -# Run specific test -pytest flowapp/tests/test_models.py::test_user_creation -v -``` - -## Code Conventions - -### Python Style -- Follow PEP 8 guidelines -- Maximum line length: 127 characters -- Use flake8 for linting -- Docstrings for complex functions - -### Import Organization -```python -# Standard library imports -import os -from datetime import datetime - -# Third-party imports -from flask import Flask, render_template -from sqlalchemy import Column, Integer - -# Local imports -from flowapp.models import User -from .validators import validate_ipv4 -``` - -### Naming Conventions -- **Classes**: PascalCase (e.g., `Flowspec4`, `RuleService`) -- **Functions/Methods**: snake_case (e.g., `create_rule`, `validate_form`) -- **Constants**: UPPER_SNAKE_CASE (e.g., `MAX_PORT`, `IPV4_PROTOCOL`) -- **Private Methods**: Leading underscore (e.g., `_validate_internal`) - -### Model Conventions -- All models inherit from `db.Model` -- Use `__tablename__` explicitly -- Include `__repr__` for debugging -- Use type hints where helpful - -### Form Conventions -- Forms inherit from `FlaskForm` or custom base classes -- Validators defined in field constructors -- Custom validators in `flowapp/validators.py` -- Form choices generated dynamically in `forms/choices.py` - -### View (Blueprint) Conventions -- One blueprint per functional area -- Use `@auth_required` decorator for protected routes -- Return tuples for status codes: `return render_template(...), 404` -- JSON responses use `jsonify()` - -### Service Layer Conventions -- Business logic goes in services, not views -- Services work with models and return results -- Raise exceptions for error cases -- Use database transactions appropriately - -## Testing Patterns - -### Test Structure -```python -def test_feature_name(client): - """Test description""" - # Arrange - set up test data - - # Act - perform the action - response = client.get('/endpoint') - - # Assert - verify results - assert response.status_code == 200 -``` - -### Common Test Fixtures -- `client`: Unauthenticated Flask test client -- `auth_client`: Authenticated Flask test client -- `app`: Flask application instance -- Database is reset between tests - -### Test File Naming -- Prefix with `test_`: `test_models.py`, `test_api_v3.py` -- Group related tests in classes -- Use descriptive test names - -## Important Modules and Their Purposes - -### `flowapp/flowspec.py` -Translates human-readable rule strings to ExaBGP command format. - -**Key Functions:** -- `translate_sequence()`: Convert port/packet sequences to ExaBGP format -- `to_exabgp_string()`: Translate form strings to FlowSpec values -- `check_limit()`: Validate values are within acceptable ranges - -### `flowapp/validators.py` -Custom validators for forms and data. - -**Key Validators:** -- IP address validation (IPv4/IPv6) -- Network prefix validation -- Port and packet length validation -- BGP-specific syntax validation - -### `flowapp/output.py` -Generates ExaBGP commands from rule models. - -**Key Functions:** -- Format announce/withdraw messages -- Build complete BGP commands -- Handle different rule types - -### `flowapp/auth.py` -Authentication and authorization. - -**Key Decorators:** -- `@auth_required`: Require authentication -- `@role_required(role)`: Require specific role -- Functions to check rule modification permissions - -### `flowapp/services/rule_service.py` -Core business logic for rule management. - -**Key Functions:** -- `create_rule()`: Create new rules with validation -- `update_rule()`: Modify existing rules -- `delete_rule()`: Remove rules -- `reactivate_rule()`: Restore expired rules -- Rule state management - -### `flowapp/models/utils.py` -Helper functions for working with models. - -**Key Functions:** -- `get_user_nets()`: Get networks user can access -- `check_rule_limit()`: Verify rule count limits -- `get_user_rules_ids()`: Get rule IDs for user -- `insert_users()`: Bulk user creation - -## API Structure - -### API Versions -- **API v1** (`/api/v1/*`) - Deprecated, minimal maintenance -- **API v2** (`/api/v2/*`) - Legacy, still supported -- **API v3** (`/api/v3/*`) - Current, recommended version - -### API Authentication -All API endpoints require JWT token authentication via `Authorization` header: -``` -Authorization: Bearer -``` - -### API Documentation -- **Local**: `/apidocs/` (Swagger UI when app is running) -- **Spec**: `flowapp/static/swagger.yml` -- **External**: [Apiary](https://exafs.docs.apiary.io/) - -### Key API Endpoints (v3) -- `GET /api/v3/rules/{type}` - List rules -- `POST /api/v3/rules/{type}` - Create rule -- `PUT /api/v3/rules/{type}/{id}` - Update rule -- `DELETE /api/v3/rules/{type}/{id}` - Delete rule -- `GET /api/v3/whitelist` - List whitelists -- `POST /rules/announce_all` - Re-announce all rules (localhost only) - -## Common Development Tasks - -### Adding a New Rule Field - -1. **Update Model** (`flowapp/models/rules/*.py`) - ```python - new_field = db.Column(db.String(100), nullable=True) - ``` - -2. **Create Migration** - ```bash - flask db migrate -m "Add new_field to Flowspec4" - flask db upgrade - ``` - -3. **Update Form** (`flowapp/forms/rules/*.py`) - ```python - new_field = StringField('New Field', validators=[Optional()]) - ``` - -4. **Update Service Logic** (`flowapp/services/rule_service.py`) - - Add field handling in create/update methods - -5. **Update Output** (`flowapp/output.py`) - - Include field in ExaBGP command generation if needed - -6. **Add Tests** (`flowapp/tests/test_*.py`) - - Test new field validation and functionality - -### Adding a New API Endpoint - -1. **Add Route** (`flowapp/views/api_v3.py`) - ```python - @api_v3.route('/endpoint', methods=['GET']) - @jwt_required - def new_endpoint(): - # Implementation - return jsonify(data), 200 - ``` - -2. **Update Swagger** (`flowapp/static/swagger.yml`) - - Add endpoint documentation - -3. **Add Tests** (`flowapp/tests/test_api_v3.py`) - ```python - def test_new_endpoint(auth_client): - response = auth_client.get('/api/v3/endpoint') - assert response.status_code == 200 - ``` - -### Adding a New Validator - -1. **Create Validator** (`flowapp/validators.py`) - ```python - def validate_something(form, field): - if not is_valid(field.data): - raise ValidationError('Invalid value') - ``` - -2. **Use in Form** (`flowapp/forms/*.py`) - ```python - field_name = StringField('Label', validators=[validate_something]) - ``` - -3. **Add Tests** (`flowapp/tests/test_validators.py`) - -## Common Pitfalls and Solutions - -### Database Session Issues -**Problem**: DetachedInstanceError or stale data -**Solution**: -- Always use `db.session.commit()` after modifications -- Refresh objects after commit: `db.session.refresh(obj)` -- Use `db.session.merge()` for detached objects - -### CSRF Token Issues -**Problem**: Form submissions failing with CSRF errors -**Solution**: -- Ensure `{{ form.csrf_token }}` in templates -- API endpoints should be exempt: `csrf.exempt` decorator -- Check session configuration - -### Authentication Not Working -**Problem**: Users not authenticated -**Solution**: -- Verify auth method in config (SSO_AUTH/HEADER_AUTH/LOCAL_AUTH) -- Check session configuration -- Verify user exists in database with correct UUID - -### Rule Limits Exceeded -**Problem**: Cannot create more rules -**Solution**: -- Check `FLOWSPEC4_MAX_RULES`, `FLOWSPEC6_MAX_RULES`, `RTBH_MAX_RULES` in config -- Use `check_rule_limit()` and `check_global_rule_limit()` before creation -- Clean up expired rules - -### ExaBGP Communication Issues -**Problem**: Rules not being sent to ExaBGP -**Solution**: -- Verify `EXA_API` setting (RABBIT or HTTP) -- Check RabbitMQ connection settings -- Verify ExaBGP process is running -- Check logs for connection errors - -## File Modification Guidelines - -### When Modifying Models -1. ✅ **Always create a migration** after model changes -2. ✅ Update `models/__init__.py` exports if adding new models -3. ✅ Add/update `__repr__` methods for debugging -4. ✅ Update corresponding forms if fields change -5. ✅ Add tests for new fields/relationships - -### When Modifying Forms -1. ✅ Add appropriate validators -2. ✅ Update corresponding templates -3. ✅ Update service layer to handle new fields -4. ✅ Add form validation tests -5. ⚠️ Don't put business logic in forms - use services - -### When Modifying Views -1. ✅ Keep views thin - delegate to services -2. ✅ Use appropriate decorators (`@auth_required`, etc.) -3. ✅ Return consistent response formats -4. ✅ Add route tests -5. ⚠️ Don't query models directly - use service layer - -### When Modifying Services -1. ✅ Use database transactions appropriately -2. ✅ Raise descriptive exceptions -3. ✅ Log important operations -4. ✅ Add comprehensive tests -5. ⚠️ Don't import from views - services should be independent - -### When Modifying Tests -1. ✅ Follow AAA pattern (Arrange, Act, Assert) -2. ✅ Use descriptive test names -3. ✅ Clean up test data -4. ✅ Test both success and failure cases -5. ✅ Run full test suite before committing - -## Security Considerations - -### Authentication -- Never bypass authentication checks -- Use `@auth_required` decorator on all protected routes -- Validate JWT tokens properly in API endpoints -- Store secrets in environment variables or secure config - -### Authorization -- Always check user permissions before rule modifications -- Verify user has access to organization/network -- Use `check_user_can_modify_rule()` helper -- Don't trust client-side validation alone - -### Input Validation -- Validate all user inputs -- Use WTForms validators -- Sanitize data before database operations -- Validate BGP syntax before execution - -### CSRF Protection -- Keep CSRF protection enabled -- Include CSRF tokens in all forms -- Exempt API endpoints appropriately - -### SQL Injection Prevention -- Use SQLAlchemy ORM (parameterized queries) -- Never use raw SQL with user input -- Don't use string formatting for queries - -## Logging and Debugging - -### Logging -- Application uses Loguru for logging -- Logs location: `/var/log/exafs/` (production) -- Configure in `flowapp/utils/configure_logging()` - -### Debug Mode -- Set `DEBUG = True` in config for development -- Shows detailed error pages -- Enables Flask debugger -- **Never enable in production** - -### Common Debug Techniques -```python -# Print to logs -from loguru import logger -logger.info(f"Processing rule: {rule_id}") -logger.error(f"Failed to create rule: {error}") - -# Debug in templates -{{ variable|pprint }} - -# Database query debugging -from flask import current_app -current_app.config['SQLALCHEMY_ECHO'] = True -``` - -## Deployment Considerations - -### Production Checklist -- [ ] Set `DEBUG = False` -- [ ] Use strong `SECRET_KEY` and `JWT_SECRET` -- [ ] Configure production database -- [ ] Set up proper authentication (SSO or Header Auth) -- [ ] Configure HTTPS -- [ ] Set up reverse proxy -- [ ] Configure proper logging -- [ ] Set up database backups -- [ ] Configure ExaBGP connection -- [ ] Set appropriate rule limits -- [ ] Enable session security settings - -### Recommended Stack -- **Web Server**: Apache with mod_proxy_uwsgi -- **WSGI Server**: uWSGI -- **Process Manager**: Supervisord -- **Database**: MariaDB -- **Auth**: Shibboleth SSO -- **Message Queue**: RabbitMQ (for ExaBGP communication) - -### Docker Deployment -- Base image: `jirivrany/exafs-base` (Docker Hub) -- See `docs/DockerImage.md` for details -- Use Docker Compose for multi-container setup -- Reference: [ExaFS Ansible Deploy](https://github.com/CESNET/ExaFS-deploy) - -## CI/CD Pipeline - -### GitHub Actions -- **Workflow**: `.github/workflows/python-app.yml` -- **Triggers**: Push/PR to `master` or `develop` branches -- **Matrix**: Python 3.9, 3.10, 3.11, 3.12, 3.13 -- **Steps**: - 1. Setup Python environment - 2. Set timezone (Europe/Prague) - 3. Install dependencies - 4. Run flake8 linting - 5. Run pytest test suite - -### Running CI Locally -```bash -# Lint -flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics -flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - -# Test -pytest -``` - -## Git Workflow - -### Branch Strategy -- `master` - Production releases -- `develop` - Development branch -- `claude/*` - AI assistant feature branches -- Feature branches for specific work - -### Commit Messages -- Use descriptive commit messages -- Reference issue numbers: `fixes #74` -- Follow conventional commits when possible - -### Version Releases -- Update version in `flowapp/__about__.py` -- Update `CHANGELOG.md` -- Create release tag -- Publish to PyPI - -## Additional Resources - -### Documentation Files -- `docs/INSTALL.md` - Detailed installation guide -- `docs/AUTH.md` - Authentication setup -- `docs/API.md` - API documentation reference -- `docs/DB_MIGRATIONS.md` - Database migration guide -- `docs/DB_BACKUP.md` - Database backup procedures -- `docs/DB_LOCAL.md` - Local database setup -- `docs/DockerImage.md` - Docker deployment -- `docs/guarda-service/` - Rule restoration service - -### External Links -- [GitHub Repository](https://github.com/CESNET/exafs) -- [PyPI Package](https://pypi.org/project/exafs/) -- [Docker Hub](https://hub.docker.com/r/jirivrany/exafs-base) -- [ExaBGP](https://github.com/Exa-Networks/exabgp) - BGP engine -- [ExaBGP Process Package](https://pypi.org/project/exabgp-process/) - ExaBGP connector - -### Related Projects -- **ExaFS Ansible Deploy**: Automated deployment with Ansible -- **ExaBGP Process**: Separate package for ExaBGP communication -- **Guarda Service**: Rule restoration monitor - -## Quick Reference Commands - -```bash -# Development -python run.py # Run development server -pytest # Run tests -pytest -v # Verbose test output -pytest --cov=flowapp # Run with coverage -flask db migrate -m "message" # Create migration -flask db upgrade # Apply migrations -flake8 . # Lint code - -# Database (source install) -python scripts/db-init.py # Initialize database (runs migrations) -python scripts/db-init.py --reset # Drop all tables and recreate (dev only) -python scripts/create-admin.py # Create first admin user interactively - -# Database (PyPI install — run from directory containing config.py) -exafs-db-init # Initialize database (runs migrations) -exafs-db-init --reset # Drop all tables and recreate (dev only) -exafs-create-admin # Create first admin user interactively - -flask db stamp 001_baseline # Mark existing DB as baseline -flask db current # Show current migration -flask db history # Show migration history - -# Production (via supervisord) -supervisorctl start exafs # Start application -supervisorctl stop exafs # Stop application -supervisorctl restart exafs # Restart application -supervisorctl status # Check status -``` - -## Summary for AI Assistants - -When working with this codebase: - -1. **Always run tests** after making changes: `pytest` -2. **Create migrations** for model changes: `flask db migrate` — commit migration files to git -3. **Follow the service layer pattern** - business logic goes in services, not views -4. **Use existing validators** in `flowapp/validators.py` for validation -5. **Check authentication** - most routes need `@auth_required` decorator -6. **Respect rule limits** - use `check_rule_limit()` before creating rules -7. **Update Swagger docs** when adding API endpoints -8. **Follow existing patterns** - look at similar code for examples -9. **Test both web UI and API** when modifying rule functionality -10. **Consider ExaBGP output** - changes to rules may affect BGP command generation - -The codebase is well-structured with clear separation of concerns. Follow the existing patterns and conventions for consistency. +- Make the plan extremely concise. Sacrifice grammar for the sake of concision. +- At the end of each plan, give me a list of unresolved questions to answer, if any. \ No newline at end of file diff --git a/README.md b/README.md index 9b09b57..8805716 100644 --- a/README.md +++ b/README.md @@ -65,4 +65,14 @@ It may also be necessary to monitor ExaBGP and re-announce rules after a restart ### API The REST API is documented using Swagger (OpenAPI). After installing and running the application, the API documentation is available locally at the /apidocs/ endpoint. This interactive documentation provides details about all available endpoints, request and response formats, and supported operations, making it easier to integrate and test the API. +## ExaBGP version + +ExaFS supports both ExaBGP 4.x and 5.x. Set `EXABGP_MAJOR_VERSION` in your `config.py` to select the version (default is `4`): + +```python +EXABGP_MAJOR_VERSION = 5 +``` + +When set to `5`, ExaFS adjusts the generated BGP flow route messages to match ExaBGP 5.x syntax (e.g. TCP flags in lowercase bracket notation). + ## [Change log](./CHANGELOG.md) \ No newline at end of file diff --git a/config.example.py b/config.example.py index 1682dd4..467c346 100644 --- a/config.example.py +++ b/config.example.py @@ -3,6 +3,9 @@ class Config: Default config options """ + # config ExaBGP major version + EXABGP_MAJOR_VERSION = 4 + # Locale for Babel BABEL_DEFAULT_LOCALE = "en_US_POSIX" # Limits diff --git a/docs/dev-guide.md b/docs/dev-guide.md new file mode 100644 index 0000000..2dace51 --- /dev/null +++ b/docs/dev-guide.md @@ -0,0 +1,143 @@ +# ExaFS Developer Guide + +Detailed conventions and step-by-step task guides. Linked from `CLAUDE.md`. + +## Common Development Tasks + +### Adding a New Rule Field + +1. **Update Model** (`flowapp/models/rules/*.py`) + ```python + new_field = db.Column(db.String(100), nullable=True) + ``` +2. **Create and apply migration** + ```bash + flask db migrate -m "Add new_field to Flowspec4" + flask db upgrade + ``` +3. **Update Form** (`flowapp/forms/rules/*.py`) + ```python + new_field = StringField('New Field', validators=[Optional()]) + ``` +4. **Update Service** (`flowapp/services/rule_service.py`) — add field handling in create/update +5. **Update Output** (`flowapp/output.py`) — include field in ExaBGP command if needed +6. **Add Tests** — test validation and full round-trip + +### Adding a New API Endpoint + +1. **Add Route** (`flowapp/views/api_v3.py`) + ```python + @api_v3.route('/endpoint', methods=['GET']) + @jwt_required + def new_endpoint(): + return jsonify(data), 200 + ``` +2. **Update Swagger** (`flowapp/static/swagger.yml`) +3. **Add Tests** (`tests/test_api_v3.py`) + ```python + def test_new_endpoint(auth_client): + response = auth_client.get('/api/v3/endpoint') + assert response.status_code == 200 + ``` + +### Adding a New Validator + +1. **Create in** `flowapp/validators.py` + ```python + def validate_something(form, field): + if not is_valid(field.data): + raise ValidationError('Invalid value') + ``` +2. **Use in form field** validators list +3. **Add tests** in `tests/test_validators.py` + +## Code Conventions + +### Python Style +- PEP 8, max line 127 chars, flake8 for linting +- Docstrings for complex functions + +### Naming +- **Classes**: PascalCase (`Flowspec4`, `RuleService`) +- **Functions/Methods**: snake_case (`create_rule`, `validate_form`) +- **Constants**: UPPER_SNAKE_CASE (`MAX_PORT`, `IPV4_PROTOCOL`) +- **Private**: leading underscore (`_validate_internal`) + +### Imports +```python +# Standard library +import os +from datetime import datetime + +# Third-party +from flask import Flask, render_template + +# Local +from flowapp.models import User +from .validators import validate_ipv4 +``` + +### Models +- Inherit from `db.Model`, set `__tablename__` explicitly +- Include `__repr__` for debugging +- Update `models/__init__.py` exports when adding new models + +### Forms +- Inherit from `FlaskForm` or project base classes +- Business logic stays in services, not forms +- Form choices generated dynamically in `forms/choices.py` + +### Views (Blueprints) +- One blueprint per functional area, keep views thin +- Use `@auth_required` on protected routes +- Status codes via return tuple: `return render_template(...), 404` +- JSON: `jsonify()` +- Don't query models directly — go through services + +### Services +- Own all business logic +- Use DB transactions appropriately +- Raise descriptive exceptions, log important operations +- Don't import from views — services are independent + +## Testing Conventions + +### Structure (AAA pattern) +```python +def test_feature_name(client): + # Arrange + # Act + response = client.get('/endpoint') + # Assert + assert response.status_code == 200 +``` + +### Fixtures +- `client` — unauthenticated Flask test client +- `auth_client` — authenticated test client +- `app` — Flask application instance +- DB is reset between tests + +### File Naming +- Prefix `test_`: `test_models.py`, `test_api_v3.py` +- Test both success and failure cases + +## Security Checklist + +- Never bypass `@auth_required` +- Call `check_user_can_modify_rule()` before any rule modification +- Call `check_rule_limit()` / `check_global_rule_limit()` before creating rules +- Use WTForms validators + server-side BGP syntax validation +- Use SQLAlchemy ORM — never raw SQL with user input +- CSRF tokens in all templates; exempt API endpoints explicitly +- Secrets in config/env, never hardcoded + +## Deployment Notes + +- **Web server**: Apache + mod_proxy_uwsgi +- **WSGI**: uWSGI, **Process manager**: Supervisord +- **Auth (prod)**: Shibboleth SSO +- **ExaBGP comms**: RabbitMQ (preferred) or HTTP +- **Docker base image**: `jirivrany/exafs-base`; see `docs/DockerImage.md` +- **Ansible deploy**: [ExaFS-deploy](https://github.com/CESNET/ExaFS-deploy) +- Production: `DEBUG = False`, strong `SECRET_KEY` and `JWT_SECRET`, HTTPS diff --git a/flowapp/__about__.py b/flowapp/__about__.py index d5b95ac..d814a8a 100755 --- a/flowapp/__about__.py +++ b/flowapp/__about__.py @@ -1,4 +1,4 @@ -__version__ = "1.2.2" +__version__ = "1.3.0" __title__ = "ExaFS" __description__ = "Tool for creation, validation, and execution of ExaBGP messages." __author__ = "CESNET / Jiri Vrany, Petr Adamec, Josef Verich, Jakub Man" diff --git a/flowapp/__init__.py b/flowapp/__init__.py index fb78661..4b1aef6 100644 --- a/flowapp/__init__.py +++ b/flowapp/__init__.py @@ -2,6 +2,7 @@ import os from flask import Flask, redirect, render_template, session, url_for, flash +from sqlalchemy import select from flask_sso import SSO from flask_sqlalchemy import SQLAlchemy @@ -128,7 +129,7 @@ def index(): @auth_required def select_org(org_id=None): uuid = session.get("user_uuid") - user = db.session.query(models.User).filter_by(uuid=uuid).first() + user = db.session.execute(select(models.User).filter_by(uuid=uuid)).scalar_one_or_none() if user is None: return render_template("errors/404.html"), 404 diff --git a/flowapp/auth.py b/flowapp/auth.py index f484386..4278594 100644 --- a/flowapp/auth.py +++ b/flowapp/auth.py @@ -1,6 +1,7 @@ from functools import wraps from typing import List, Optional from flask import current_app, redirect, request, url_for, session, abort +from sqlalchemy import select from flowapp import __version__, db, validators from flowapp.models import Flowspec4, Flowspec6, RTBH, Whitelist, get_user_nets @@ -153,35 +154,35 @@ def get_user_allowed_rule_ids(rule_type: str, user_id: int, user_role_ids: List[ # Admin users can modify any rules if 3 in user_role_ids: if rule_type == "ipv4": - return [r.id for r in db.session.query(Flowspec4.id).all()] + return list(db.session.scalars(select(Flowspec4.id))) elif rule_type == "ipv6": - return [r.id for r in db.session.query(Flowspec6.id).all()] + return list(db.session.scalars(select(Flowspec6.id))) elif rule_type == "rtbh": - return [r.id for r in db.session.query(RTBH.id).all()] + return list(db.session.scalars(select(RTBH.id))) elif rule_type == "whitelist": - return [r.id for r in db.session.query(Whitelist.id).all()] + return list(db.session.scalars(select(Whitelist.id))) return [] # Regular users - filter by network ranges net_ranges = get_user_nets(user_id) if rule_type == "ipv4": - rules = db.session.query(Flowspec4).all() + rules = db.session.scalars(select(Flowspec4)).all() filtered_rules = validators.filter_rules_in_network(net_ranges, rules) return [r.id for r in filtered_rules] elif rule_type == "ipv6": - rules = db.session.query(Flowspec6).all() + rules = db.session.scalars(select(Flowspec6)).all() filtered_rules = validators.filter_rules_in_network(net_ranges, rules) return [r.id for r in filtered_rules] elif rule_type == "rtbh": - rules = db.session.query(RTBH).all() + rules = db.session.scalars(select(RTBH)).all() filtered_rules = validators.filter_rtbh_rules(net_ranges, rules) return [r.id for r in filtered_rules] elif rule_type == "whitelist": - rules = db.session.query(Whitelist).all() + rules = db.session.scalars(select(Whitelist)).all() filtered_rules = validators.filter_rules_in_network(net_ranges, rules) return [r.id for r in filtered_rules] diff --git a/flowapp/constants.py b/flowapp/constants.py index b37c8a0..30ec9e6 100644 --- a/flowapp/constants.py +++ b/flowapp/constants.py @@ -46,6 +46,14 @@ "last": "last-fragment", } +IPV4_FRAGMENT_V5 = { + "dont": "dont-fragment", + "first": "first-fragment", + "is": "is-fragment", + "last": "last-fragment", + "not": "!is-fragment", +} + COMP_FUNCS = {"active": ge, "expired": lt, "all": None} TCP_FLAGS = [ diff --git a/flowapp/messages.py b/flowapp/messages.py index 45ab4f3..1c7679d 100644 --- a/flowapp/messages.py +++ b/flowapp/messages.py @@ -6,29 +6,56 @@ MAX_PACKET, IPV4_PROTOCOL, IPV6_NEXT_HEADER, + IPV4_FRAGMENT_V5, ) from flowapp.flowspec import translate_sequence as trps from flask import current_app +from sqlalchemy import select from flowapp.models import ASPath from flowapp import db +def format_tcp_flags(flagstring, version=4): + """ + Format tcp-flags string for ExaBGP message. + v4: tcp-flags SYN ACK; + v5: tcp-flags [ syn ack ]; + """ + if version == 5: + flags_lower = flagstring.lower() + return "tcp-flags [ {} ];".format(flags_lower) + return "tcp-flags {};".format(flagstring) + + +def format_fragment(fragment_string, version=4): + """ + Format fragment string for ExaBGP message. + v5 uses IPV4_FRAGMENT_V5 which includes !is-fragment for 'not'. + """ + if version == 5: + parts = fragment_string.split() + translated = " ".join(IPV4_FRAGMENT_V5.get(p, p) for p in parts) + return "fragment [ {} ];".format(translated) + return "fragment [ {} ];".format(fragment_string) + + def create_ipv4(rule, message_type=ANNOUNCE): """ create ExaBpg text message for IPv4 rule @param rule models.Flowspec4 @return string message """ + exabgp_version = current_app.config.get("EXABGP_MAJOR_VERSION", 4) + protocol = "" if rule.protocol and rule.protocol != "all": protocol = "protocol ={};".format(IPV4_PROTOCOL[rule.protocol]) flagstring = rule.flags.replace(";", " ") if rule.flags else "" - - flags = "tcp-flags {};".format(flagstring) if rule.flags and rule.protocol == "tcp" else "" + flags = format_tcp_flags(flagstring, exabgp_version) if rule.flags and rule.protocol == "tcp" else "" fragment_string = rule.fragment.replace(";", " ") if rule.fragment else "" - fragment = "fragment [ {} ];".format(fragment_string) if rule.fragment else "" + fragment = format_fragment(fragment_string, exabgp_version) if rule.fragment else "" spec = { "protocol": protocol, @@ -47,11 +74,13 @@ def create_ipv6(rule, message_type=ANNOUNCE): @return string message :param message_type: """ + exabgp_version = current_app.config.get("EXABGP_MAJOR_VERSION", 4) + protocol = "" if rule.next_header and rule.next_header != "all": protocol = "next-header ={};".format(IPV6_NEXT_HEADER[rule.next_header]) - flagstring = rule.flags.replace(";", " ") - flags = "tcp-flags {};".format(flagstring) if rule.flags and rule.next_header == "tcp" else "" + flagstring = rule.flags.replace(";", " ") if rule.flags else "" + flags = format_tcp_flags(flagstring, exabgp_version) if rule.flags and rule.next_header == "tcp" else "" spec = {"protocol": protocol, "mask": IPV6_DEFMASK, "flags": flags} @@ -110,7 +139,7 @@ def create_rtbh(rule, message_type=ANNOUNCE): as_path_string = "" if rule.community.as_path: - match = db.session.query(ASPath).filter(ASPath.prefix == source).first() + match = db.session.execute(select(ASPath).filter_by(prefix=source)).scalar_one_or_none() as_path_string = f"as-path [ {match.as_path} ]" if match else "" return "{neighbor}{action} route {source} next-hop {nexthop} {as_path} {community} {large_community} {extended_community}{rd_string}".format( diff --git a/flowapp/models/api.py b/flowapp/models/api.py index ed50685..5ae9a79 100644 --- a/flowapp/models/api.py +++ b/flowapp/models/api.py @@ -1,4 +1,6 @@ from datetime import datetime +from typing import Optional +from sqlalchemy import select from .base import db @@ -20,6 +22,14 @@ def is_expired(self): else: return self.expires < datetime.now() + @classmethod + def get_by_key(cls, key: str) -> Optional["ApiKey"]: + return db.session.scalars(select(cls).filter_by(key=key)).first() + + @classmethod + def get_by_user_id(cls, user_id: int) -> list: + return db.session.scalars(select(cls).filter_by(user_id=user_id)).all() + class MachineApiKey(db.Model): id = db.Column(db.Integer, primary_key=True) @@ -38,3 +48,11 @@ def is_expired(self): return False # Non-expiring key else: return self.expires < datetime.now() + + @classmethod + def get_by_key(cls, key: str) -> Optional["MachineApiKey"]: + return db.session.scalars(select(cls).filter_by(key=key)).first() + + @classmethod + def get_all(cls) -> list: + return db.session.scalars(select(cls)).all() diff --git a/flowapp/models/community.py b/flowapp/models/community.py index 5df4310..a0682fd 100644 --- a/flowapp/models/community.py +++ b/flowapp/models/community.py @@ -1,4 +1,5 @@ -from sqlalchemy import event +from typing import List, Optional +from sqlalchemy import event, select from .base import db @@ -24,9 +25,13 @@ def __init__(self, name, comm, larcomm, extcomm, description, as_path, role_id): self.as_path = as_path self.role_id = role_id + @classmethod + def get_all(cls) -> List["Community"]: + return db.session.scalars(select(cls)).all() + @classmethod def get_whitelistable_communities(cls, id_list): - return cls.query.filter(cls.id.in_(id_list)).all() + return db.session.scalars(select(cls).filter(cls.id.in_(id_list))).all() def __repr__(self): return f"" @@ -42,7 +47,13 @@ class ASPath(db.Model): prefix = db.Column(db.String(120), unique=True) as_path = db.Column(db.String(250)) - # Methods and initializer + @classmethod + def get_all(cls) -> List["ASPath"]: + return db.session.scalars(select(cls)).all() + + @classmethod + def get_by_prefix(cls, prefix: str) -> Optional["ASPath"]: + return db.session.scalars(select(cls).filter_by(prefix=prefix)).first() # Note: seed data is also defined in migrations/versions/001_baseline.py - keep in sync diff --git a/flowapp/models/log.py b/flowapp/models/log.py index 16cf5b1..2c0ddea 100644 --- a/flowapp/models/log.py +++ b/flowapp/models/log.py @@ -1,5 +1,6 @@ from datetime import datetime, timedelta +from sqlalchemy import select from flowapp.constants import RuleTypes from .base import db @@ -23,6 +24,17 @@ def __init__(self, time, task, user_id, rule_type, rule_id, author): self.user_id = user_id self.author = author + @classmethod + def get_recent_paginated(cls, page: int, per_page: int = 20, weeks: int = 1): + since = datetime.now() - timedelta(weeks=weeks) + return db.paginate( + select(cls).filter(cls.time > since).order_by(cls.time.desc()), + page=page, + per_page=per_page, + max_per_page=None, + error_out=False, + ) + @classmethod def delete_old(cls, days: int = 30): """Delete logs older than :param days from the database""" diff --git a/flowapp/models/organization.py b/flowapp/models/organization.py index 67db8c5..3c41e6b 100644 --- a/flowapp/models/organization.py +++ b/flowapp/models/organization.py @@ -1,4 +1,5 @@ -from sqlalchemy import event +from typing import List, Optional +from sqlalchemy import event, select from .base import db @@ -27,6 +28,14 @@ def get_users(self): # self.user is the backref from the user_organization relationship return self.user + @classmethod + def get_all_ordered(cls) -> List["Organization"]: + return db.session.scalars(select(cls).order_by(cls.name)).all() + + @classmethod + def get_by_name(cls, name: str) -> Optional["Organization"]: + return db.session.scalars(select(cls).filter_by(name=name)).first() + # Event listeners for Organization # Note: seed data is also defined in migrations/versions/001_baseline.py - keep in sync diff --git a/flowapp/models/rules/base.py b/flowapp/models/rules/base.py index cbe889d..7822aad 100644 --- a/flowapp/models/rules/base.py +++ b/flowapp/models/rules/base.py @@ -1,4 +1,5 @@ -from sqlalchemy import event +from typing import List +from sqlalchemy import event, select from ..base import db @@ -30,6 +31,14 @@ def __init__(self, name, command, description, role_id=2): self.description = description self.role_id = role_id + @classmethod + def get_all_ordered(cls) -> List["Action"]: + return db.session.scalars(select(cls).order_by(cls.name)).all() + + @classmethod + def get_all(cls) -> List["Action"]: + return db.session.scalars(select(cls)).all() + # Event listeners for Rstate # Note: seed data is also defined in migrations/versions/001_baseline.py - keep in sync diff --git a/flowapp/models/rules/flowspec.py b/flowapp/models/rules/flowspec.py index 9823aba..6c4dd91 100644 --- a/flowapp/models/rules/flowspec.py +++ b/flowapp/models/rules/flowspec.py @@ -1,6 +1,8 @@ import json from datetime import datetime +from typing import List, Optional from flowapp import utils +from sqlalchemy import func, select from ..base import db @@ -163,6 +165,17 @@ def json(self, prefered_format="yearfirst"): """ return json.dumps(self.to_dict()) + @classmethod + def get_all_ordered(cls) -> List["Flowspec4"]: + return db.session.scalars(select(cls).order_by(cls.expires.desc())).all() + + @classmethod + def count_active(cls, org_id: Optional[int] = None) -> int: + q = select(func.count()).select_from(cls).filter_by(rstate_id=1) + if org_id is not None: + q = select(func.count()).select_from(cls).filter_by(rstate_id=1, org_id=org_id) + return db.session.scalar(q) + class Flowspec6(db.Model): id = db.Column(db.Integer, primary_key=True) @@ -292,3 +305,14 @@ def json(self, prefered_format="yearfirst"): :returns: json """ return json.dumps(self.to_dict()) + + @classmethod + def get_all_ordered(cls) -> List["Flowspec6"]: + return db.session.scalars(select(cls).order_by(cls.expires.desc())).all() + + @classmethod + def count_active(cls, org_id: Optional[int] = None) -> int: + q = select(func.count()).select_from(cls).filter_by(rstate_id=1) + if org_id is not None: + q = select(func.count()).select_from(cls).filter_by(rstate_id=1, org_id=org_id) + return db.session.scalar(q) diff --git a/flowapp/models/rules/rtbh.py b/flowapp/models/rules/rtbh.py index 943fef0..47e526b 100644 --- a/flowapp/models/rules/rtbh.py +++ b/flowapp/models/rules/rtbh.py @@ -1,6 +1,8 @@ import json from datetime import datetime +from typing import List, Optional from flowapp import utils +from sqlalchemy import func, select from ..base import db @@ -151,3 +153,14 @@ def __str__(self): def get_author(self): return f"{self.user.email} / {self.org}" + + @classmethod + def get_all_ordered(cls) -> List["RTBH"]: + return db.session.scalars(select(cls).order_by(cls.expires.desc())).all() + + @classmethod + def count_active(cls, org_id: Optional[int] = None) -> int: + q = select(func.count()).select_from(cls).filter_by(rstate_id=1) + if org_id is not None: + q = select(func.count()).select_from(cls).filter_by(rstate_id=1, org_id=org_id) + return db.session.scalar(q) diff --git a/flowapp/models/rules/whitelist.py b/flowapp/models/rules/whitelist.py index 83c131f..fc29959 100644 --- a/flowapp/models/rules/whitelist.py +++ b/flowapp/models/rules/whitelist.py @@ -1,6 +1,7 @@ from flowapp import utils from ..base import db from datetime import datetime +from sqlalchemy import func, select from flowapp.constants import RuleTypes, RuleOrigin @@ -119,7 +120,7 @@ def get_by_whitelist_id(cls, whitelist_id: int): Returns: list: All RuleWhitelistCache objects with the specified whitelist_id """ - return cls.query.filter_by(whitelist_id=whitelist_id).all() + return db.session.scalars(select(cls).filter_by(whitelist_id=whitelist_id)).all() @classmethod def clean_by_whitelist_id(cls, whitelist_id: int): @@ -132,7 +133,9 @@ def clean_by_whitelist_id(cls, whitelist_id: int): Returns: int: Number of rows deleted """ - deleted = cls.query.filter_by(whitelist_id=whitelist_id).delete() + deleted = db.session.execute( + db.delete(cls).where(cls.whitelist_id == whitelist_id) + ).rowcount db.session.commit() return deleted @@ -147,10 +150,20 @@ def delete_by_rule_id(cls, rule_id: int): Returns: int: Number of rows deleted """ - deleted = cls.query.filter_by(rid=rule_id).delete() + deleted = db.session.execute( + db.delete(cls).where(cls.rid == rule_id) + ).rowcount db.session.commit() return deleted + @classmethod + def get_by_rule_ids(cls, rule_ids: list, rule_type: "RuleTypes") -> list: + if not rule_ids or not rule_type: + return [] + return db.session.scalars( + select(cls).filter(cls.rid.in_(rule_ids), cls.rtype == rule_type.value) + ).all() + @classmethod def count_by_rule(cls, rule_id: int, rule_type: RuleTypes): """ @@ -163,7 +176,9 @@ def count_by_rule(cls, rule_id: int, rule_type: RuleTypes): Returns: int: Number of cache entries """ - return cls.query.filter_by(rid=rule_id, rtype=rule_type.value).count() + return db.session.scalar( + select(func.count()).select_from(cls).filter_by(rid=rule_id, rtype=rule_type.value) + ) def __repr__(self): return f"" diff --git a/flowapp/models/user.py b/flowapp/models/user.py index 78a028d..d7db334 100644 --- a/flowapp/models/user.py +++ b/flowapp/models/user.py @@ -1,4 +1,5 @@ -from sqlalchemy import event +from typing import List, Optional +from sqlalchemy import event, select from .base import db, user_role, user_organization from .organization import Organization @@ -46,17 +47,29 @@ def update(self, form): self.organization.remove(org) for role_id in form.role_ids.data: - my_role = db.session.query(Role).filter_by(id=role_id).first() + my_role = db.session.execute(select(Role).filter_by(id=role_id)).scalar_one() if my_role not in self.role: self.role.append(my_role) for org_id in form.org_ids.data: - my_org = db.session.query(Organization).filter_by(id=org_id).first() + my_org = db.session.execute(select(Organization).filter_by(id=org_id)).scalar_one() if my_org not in self.organization: self.organization.append(my_org) db.session.commit() + @classmethod + def get_all(cls) -> List["User"]: + return db.session.scalars(select(cls)).all() + + @classmethod + def get_all_ordered(cls) -> List["User"]: + return db.session.scalars(select(cls).order_by(cls.name)).all() + + @classmethod + def get_by_uuid(cls, uuid: str) -> Optional["User"]: + return db.session.scalars(select(cls).filter_by(uuid=uuid)).first() + class Role(db.Model): id = db.Column(db.Integer, primary_key=True) @@ -70,6 +83,10 @@ def __init__(self, name, description): def __repr__(self): return self.name + @classmethod + def get_all_ordered(cls) -> List["Role"]: + return db.session.scalars(select(cls).order_by(cls.name)).all() + # Event listeners for Role # Note: seed data is also defined in migrations/versions/001_baseline.py - keep in sync diff --git a/flowapp/models/utils.py b/flowapp/models/utils.py index 8270173..2a19fe2 100644 --- a/flowapp/models/utils.py +++ b/flowapp/models/utils.py @@ -4,6 +4,7 @@ from flowapp import utils from flowapp.constants import RuleTypes from flask import current_app +from sqlalchemy import func, select from flowapp.models.rules.whitelist import Whitelist from .base import db @@ -25,20 +26,20 @@ def check_rule_limit(org_id: int, rule_type: RuleTypes) -> bool: flowspec4_limit = current_app.config.get("FLOWSPEC4_MAX_RULES", 9000) flowspec6_limit = current_app.config.get("FLOWSPEC6_MAX_RULES", 9000) rtbh_limit = current_app.config.get("RTBH_MAX_RULES", 100000) - fs4 = db.session.query(Flowspec4).filter_by(rstate_id=1).count() - fs6 = db.session.query(Flowspec6).filter_by(rstate_id=1).count() - rtbh = db.session.query(RTBH).filter_by(rstate_id=1).count() + fs4 = db.session.scalar(select(func.count()).select_from(Flowspec4).filter_by(rstate_id=1)) + fs6 = db.session.scalar(select(func.count()).select_from(Flowspec6).filter_by(rstate_id=1)) + rtbh = db.session.scalar(select(func.count()).select_from(RTBH).filter_by(rstate_id=1)) # check the organization limits - org = Organization.query.filter_by(id=org_id).first() + org = db.session.execute(select(Organization).filter_by(id=org_id)).scalar_one() if rule_type == RuleTypes.IPv4 and org.limit_flowspec4 > 0: - count = db.session.query(Flowspec4).filter_by(org_id=org_id, rstate_id=1).count() + count = db.session.scalar(select(func.count()).select_from(Flowspec4).filter_by(org_id=org_id, rstate_id=1)) return count >= org.limit_flowspec4 or fs4 >= flowspec4_limit if rule_type == RuleTypes.IPv6 and org.limit_flowspec6 > 0: - count = db.session.query(Flowspec6).filter_by(org_id=org_id, rstate_id=1).count() + count = db.session.scalar(select(func.count()).select_from(Flowspec6).filter_by(org_id=org_id, rstate_id=1)) return count >= org.limit_flowspec6 or fs6 >= flowspec6_limit if rule_type == RuleTypes.RTBH and org.limit_rtbh > 0: - count = db.session.query(RTBH).filter_by(org_id=org_id, rstate_id=1).count() + count = db.session.scalar(select(func.count()).select_from(RTBH).filter_by(org_id=org_id, rstate_id=1)) return count >= org.limit_rtbh or rtbh >= rtbh_limit return False @@ -48,9 +49,9 @@ def check_global_rule_limit(rule_type: RuleTypes) -> bool: flowspec4_limit = current_app.config.get("FLOWSPEC4_MAX_RULES", 9000) flowspec6_limit = current_app.config.get("FLOWSPEC6_MAX_RULES", 9000) rtbh_limit = current_app.config.get("RTBH_MAX_RULES", 100000) - fs4 = db.session.query(Flowspec4).filter_by(rstate_id=1).count() - fs6 = db.session.query(Flowspec6).filter_by(rstate_id=1).count() - rtbh = db.session.query(RTBH).filter_by(rstate_id=1).count() + fs4 = db.session.scalar(select(func.count()).select_from(Flowspec4).filter_by(rstate_id=1)) + fs6 = db.session.scalar(select(func.count()).select_from(Flowspec6).filter_by(rstate_id=1)) + rtbh = db.session.scalar(select(func.count()).select_from(RTBH).filter_by(rstate_id=1)) # check the global limits if the organization limits are not set @@ -64,20 +65,59 @@ def check_global_rule_limit(rule_type: RuleTypes) -> bool: return False +def get_org_rule_stats() -> dict: + """ + Return per-org and global active rule counts for the admin organizations view. + """ + orgs = db.session.scalars( + select(Organization).options(db.joinedload(Organization.rtbh)) + ).unique().all() + + rtbh_counts = dict( + db.session.execute( + select(RTBH.org_id, func.count(RTBH.id)) + .filter(RTBH.rstate_id == 1) + .group_by(RTBH.org_id) + ).all() + ) + flowspec4_counts = dict( + db.session.execute( + select(Flowspec4.org_id, func.count(Flowspec4.id)) + .filter(Flowspec4.rstate_id == 1) + .group_by(Flowspec4.org_id) + ).all() + ) + flowspec6_counts = dict( + db.session.execute( + select(Flowspec6.org_id, func.count(Flowspec6.id)) + .filter(Flowspec6.rstate_id == 1) + .group_by(Flowspec6.org_id) + ).all() + ) + + return { + "orgs": orgs, + "rtbh_counts": rtbh_counts, + "flowspec4_counts": flowspec4_counts, + "flowspec6_counts": flowspec6_counts, + "rtbh_all_count": RTBH.count_active(), + "flowspec4_all_count": Flowspec4.count_active(), + "flowspec6_all_count": Flowspec6.count_active(), + } + + def get_whitelist_model_if_exists(form_data): """ Check if the record in database exist ip, mask should match expires, rstate_id, user_id, org_id, created, comment can be different """ - record = ( - db.session.query(Whitelist) - .filter( + record = db.session.execute( + select(Whitelist).filter( Whitelist.ip == form_data["ip"], Whitelist.mask == form_data["mask"], ) - .first() - ) + ).scalar_one_or_none() if record: return record @@ -91,9 +131,8 @@ def get_ipv4_model_if_exists(form_data, rstate_id=1): Source and destination addresses, protocol, flags, action and packet_len should match Other fields can be different """ - record = ( - db.session.query(Flowspec4) - .filter( + record = db.session.scalars( + select(Flowspec4).filter( Flowspec4.source == form_data["source"], Flowspec4.source_mask == form_data["source_mask"], Flowspec4.source_port == form_data["source_port"], @@ -106,8 +145,7 @@ def get_ipv4_model_if_exists(form_data, rstate_id=1): Flowspec4.action_id == form_data["action"], Flowspec4.rstate_id == rstate_id, ) - .first() - ) + ).first() if record: return record @@ -119,9 +157,8 @@ def get_ipv6_model_if_exists(form_data, rstate_id=1): """ Check if the record in database exist """ - record = ( - db.session.query(Flowspec6) - .filter( + record = db.session.scalars( + select(Flowspec6).filter( Flowspec6.source == form_data["source"], Flowspec6.source_mask == form_data["source_mask"], Flowspec6.source_port == form_data["source_port"], @@ -134,8 +171,7 @@ def get_ipv6_model_if_exists(form_data, rstate_id=1): Flowspec6.action_id == form_data["action"], Flowspec6.rstate_id == rstate_id, ) - .first() - ) + ).first() if record: return record @@ -150,17 +186,15 @@ def get_rtbh_model_if_exists(form_data): Rule can be in any state and have different expires, user_id, org_id, created, comment """ - record = ( - db.session.query(RTBH) - .filter( + record = db.session.scalars( + select(RTBH).filter( RTBH.ipv4 == form_data["ipv4"], RTBH.ipv4_mask == form_data["ipv4_mask"], RTBH.ipv6 == form_data["ipv6"], RTBH.ipv6_mask == form_data["ipv6_mask"], RTBH.community_id == form_data["community"], ) - .first() - ) + ).first() if record: return record @@ -173,8 +207,8 @@ def insert_users(users): inser list of users {name: string, role_id: integer} to db """ for user in users: - r = Role.query.filter_by(id=user["role_id"]).first() - o = Organization.query.filter_by(id=user["org_id"]).first() + r = db.session.execute(select(Role).filter_by(id=user["role_id"])).scalar_one() + o = db.session.execute(select(Organization).filter_by(id=user["org_id"])).scalar_one() u = User(uuid=user["name"]) u.role.append(r) u.organization.append(o) @@ -204,16 +238,16 @@ def insert_user( :return: None """ u = User(uuid=uuid, name=name, phone=phone, comment=comment, email=email) + db.session.add(u) for role_id in role_ids: - r = Role.query.filter_by(id=role_id).first() + r = db.session.execute(select(Role).filter_by(id=role_id)).scalar_one() u.role.append(r) for org_id in org_ids: - o = Organization.query.filter_by(id=org_id).first() + o = db.session.execute(select(Organization).filter_by(id=org_id)).scalar_one() u.organization.append(o) - db.session.add(u) db.session.commit() @@ -221,7 +255,7 @@ def get_user_nets(user_id): """ Return list of network ranges for all user organization """ - user = db.session.query(User).filter_by(id=user_id).first() + user = db.session.execute(select(User).filter_by(id=user_id)).scalar_one() orgs = user.organization result = [] for org in orgs: @@ -234,7 +268,7 @@ def get_user_orgs_choices(user_id): """ Return list of orgs as choices for form """ - user = db.session.query(User).filter_by(id=user_id).first() + user = db.session.execute(select(User).filter_by(id=user_id)).scalar_one() orgs = user.organization return [(g.id, g.name) for g in orgs] @@ -246,9 +280,9 @@ def get_user_actions(user_roles): """ max_role = max(user_roles) if max_role == 3: - actions = db.session.query(Action).order_by("id").all() + actions = db.session.scalars(select(Action).order_by(Action.id)).all() else: - actions = db.session.query(Action).filter_by(role_id=max_role).order_by("id").all() + actions = db.session.scalars(select(Action).filter_by(role_id=max_role).order_by(Action.id)).all() result = [(g.id, g.name) for g in actions] return result @@ -259,9 +293,9 @@ def get_user_communities(user_roles): """ max_role = max(user_roles) if max_role == 3: - communities = db.session.query(Community).order_by("id") + communities = db.session.scalars(select(Community).order_by(Community.id)) else: - communities = db.session.query(Community).filter_by(role_id=max_role).order_by("id") + communities = db.session.scalars(select(Community).filter_by(role_id=max_role).order_by(Community.id)) return [(g.id, g.name) for g in communities] @@ -274,7 +308,7 @@ def get_existing_action(name=None, command=None): :param command: string action command :return: action id """ - action = Action.query.filter((Action.name == name) | (Action.command == command)).first() + action = db.session.scalars(select(Action).filter((Action.name == name) | (Action.command == command)).limit(1)).first() return action.id if hasattr(action, "id") else None @@ -286,7 +320,7 @@ def get_existing_community(name=None): :param command: string action command :return: action id """ - community = Community.query.filter(Community.name == name).first() + community = db.session.execute(select(Community).filter(Community.name == name)).scalar_one_or_none() return community.id if hasattr(community, "id") else None @@ -299,7 +333,7 @@ def _get_flowspec4_rules(rule_state, sort="expires", order="desc", page=1, per_p sorter = getattr(Flowspec4, sort, Flowspec4.id) sorting = getattr(sorter, order) - query = db.session.query(Flowspec4) + query = select(Flowspec4) if comp_func: query = query.filter(comp_func(Flowspec4.expires, today)) @@ -307,10 +341,10 @@ def _get_flowspec4_rules(rule_state, sort="expires", order="desc", page=1, per_p query = query.order_by(sorting()) if paginate: - pagination = query.paginate(page=page, per_page=per_page, error_out=False, max_per_page=500) + pagination = db.paginate(query, page=page, per_page=per_page, error_out=False, max_per_page=500) return pagination.items, pagination else: - return query.all() + return db.session.scalars(query).all() def _get_flowspec6_rules(rule_state, sort="expires", order="desc", page=1, per_page=50, paginate=False): @@ -322,7 +356,7 @@ def _get_flowspec6_rules(rule_state, sort="expires", order="desc", page=1, per_p sorter = getattr(Flowspec6, sort, Flowspec6.id) sorting = getattr(sorter, order) - query = db.session.query(Flowspec6) + query = select(Flowspec6) if comp_func: query = query.filter(comp_func(Flowspec6.expires, today)) @@ -330,10 +364,10 @@ def _get_flowspec6_rules(rule_state, sort="expires", order="desc", page=1, per_p query = query.order_by(sorting()) if paginate: - pagination = query.paginate(page=page, per_page=per_page, error_out=False, max_per_page=500) + pagination = db.paginate(query, page=page, per_page=per_page, error_out=False, max_per_page=500) return pagination.items, pagination else: - return query.all() + return db.session.scalars(query).all() def _get_rtbh_rules(rule_state, sort="expires", order="desc", page=1, per_page=50, paginate=False): @@ -345,7 +379,7 @@ def _get_rtbh_rules(rule_state, sort="expires", order="desc", page=1, per_page=5 sorter = getattr(RTBH, sort, RTBH.id) sorting = getattr(sorter, order) - query = db.session.query(RTBH) + query = select(RTBH) if comp_func: query = query.filter(comp_func(RTBH.expires, today)) @@ -353,10 +387,10 @@ def _get_rtbh_rules(rule_state, sort="expires", order="desc", page=1, per_page=5 query = query.order_by(sorting()) if paginate: - pagination = query.paginate(page=page, per_page=per_page, error_out=False, max_per_page=500) + pagination = db.paginate(query, page=page, per_page=per_page, error_out=False, max_per_page=500) return pagination.items, pagination else: - return query.all() + return db.session.scalars(query).all() def _get_whitelist_rules(rule_state, sort="expires", order="desc", page=1, per_page=50, paginate=False): @@ -368,7 +402,7 @@ def _get_whitelist_rules(rule_state, sort="expires", order="desc", page=1, per_p sorter = getattr(Whitelist, sort, Whitelist.id) sorting = getattr(sorter, order) - query = db.session.query(Whitelist) + query = select(Whitelist) if comp_func: query = query.filter(comp_func(Whitelist.expires, today)) @@ -376,10 +410,10 @@ def _get_whitelist_rules(rule_state, sort="expires", order="desc", page=1, per_p query = query.order_by(sorting()) if paginate: - pagination = query.paginate(page=page, per_page=per_page, error_out=False, max_per_page=500) + pagination = db.paginate(query, page=page, per_page=per_page, error_out=False, max_per_page=500) return pagination.items, pagination else: - return query.all() + return db.session.scalars(query).all() # Facade function - keeps backward compatibility and config-based routing @@ -430,13 +464,10 @@ def get_user_rules_ids(user_id, rule_type): """ if rule_type == "ipv4": - rules4 = db.session.query(Flowspec4.id).filter_by(user_id=user_id).all() - return [int(x[0]) for x in rules4] + return list(db.session.scalars(select(Flowspec4.id).filter_by(user_id=user_id))) if rule_type == "ipv6": - rules6 = db.session.query(Flowspec6.id).order_by(Flowspec6.expires.desc()).all() - return [int(x[0]) for x in rules6] + return list(db.session.scalars(select(Flowspec6.id).filter_by(user_id=user_id).order_by(Flowspec6.expires.desc()))) if rule_type == "rtbh": - rules_rtbh = db.session.query(RTBH.id).filter_by(user_id=user_id).all() - return [int(x[0]) for x in rules_rtbh] + return list(db.session.scalars(select(RTBH.id).filter_by(user_id=user_id))) diff --git a/flowapp/services/base.py b/flowapp/services/base.py index 77adeb0..7fd4074 100644 --- a/flowapp/services/base.py +++ b/flowapp/services/base.py @@ -1,5 +1,6 @@ from datetime import datetime from operator import ge, lt +from sqlalchemy import select from flowapp import constants, db, messages from flowapp.constants import ANNOUNCE, WITHDRAW from flowapp.models import RTBH, Flowspec4, Flowspec6 @@ -44,27 +45,24 @@ def announce_all_routes(action=constants.ANNOUNCE): today = datetime.now() comp_func = ge if action == constants.ANNOUNCE else lt - rules4 = ( - db.session.query(Flowspec4) + rules4 = db.session.scalars( + select(Flowspec4) .filter(Flowspec4.rstate_id == 1) .filter(comp_func(Flowspec4.expires, today)) .order_by(Flowspec4.expires.desc()) - .all() - ) - rules6 = ( - db.session.query(Flowspec6) + ).all() + rules6 = db.session.scalars( + select(Flowspec6) .filter(Flowspec6.rstate_id == 1) .filter(comp_func(Flowspec6.expires, today)) .order_by(Flowspec6.expires.desc()) - .all() - ) - rules_rtbh = ( - db.session.query(RTBH) + ).all() + rules_rtbh = db.session.scalars( + select(RTBH) .filter(RTBH.rstate_id == 1) .filter(comp_func(RTBH.expires, today)) .order_by(RTBH.expires.desc()) - .all() - ) + ).all() messages_v4 = [messages.create_ipv4(rule, action) for rule in rules4] messages_v6 = [messages.create_ipv6(rule, action) for rule in rules6] diff --git a/flowapp/services/rule_service.py b/flowapp/services/rule_service.py index fa4bfec..0c19af8 100644 --- a/flowapp/services/rule_service.py +++ b/flowapp/services/rule_service.py @@ -10,6 +10,7 @@ from typing import Dict, List, Tuple, Union from flask import current_app +from sqlalchemy import select from flowapp import db, messages from flowapp.constants import WITHDRAW, RuleOrigin, RuleTypes, ANNOUNCE @@ -332,7 +333,7 @@ def check_rtbh_whitelisted(model: RTBH, user_id: int, flashes: List[str], author allowed_communities = current_app.config["ALLOWED_COMMUNITIES"] if model.community_id in allowed_communities: # get all not expired whitelists - whitelists = db.session.query(Whitelist).filter(Whitelist.expires > datetime.now()).all() + whitelists = db.session.scalars(select(Whitelist).filter(Whitelist.expires > datetime.now())).all() wl_cache = map_whitelists_to_strings(whitelists) results = check_rule_against_whitelists(str(model), wl_cache.keys()) # check rule against whitelists @@ -556,14 +557,14 @@ def delete_expired_rules() -> Dict[str, int]: for rule_type, (model_class, rule_enum) in model_map.items(): # Get IDs of rules to delete - expired_rule_ids = [ - r.id - for r in db.session.query(model_class.id) - .filter( - model_class.expires < deletion_date, model_class.rstate_id.in_([2, 3]) # withdrawn or deleted state + expired_rule_ids = list( + db.session.scalars( + select(model_class.id).filter( + model_class.expires < deletion_date, + model_class.rstate_id.in_([2, 3]), # withdrawn or deleted state + ) ) - .all() - ] + ) if not expired_rule_ids: current_app.logger.info(f"No expired {model_class.__name__} rules to delete") @@ -580,9 +581,9 @@ def delete_expired_rules() -> Dict[str, int]: ) # Bulk delete the rules - deleted = ( - db.session.query(model_class).filter(model_class.id.in_(expired_rule_ids)).delete(synchronize_session=False) - ) + deleted = db.session.execute( + db.delete(model_class).where(model_class.id.in_(expired_rule_ids)) + ).rowcount deletion_counts[rule_type] = deleted deletion_counts["total"] += deleted diff --git a/flowapp/services/whitelist_service.py b/flowapp/services/whitelist_service.py index bf281af..d5f5dcf 100644 --- a/flowapp/services/whitelist_service.py +++ b/flowapp/services/whitelist_service.py @@ -9,6 +9,7 @@ from typing import Dict, Tuple, List import sqlalchemy +from sqlalchemy import select from flowapp import db from flowapp.constants import RuleOrigin, RuleTypes @@ -66,7 +67,9 @@ def create_or_update_whitelist( # check RTBH rules against whitelist allowed_communities = current_app.config["ALLOWED_COMMUNITIES"] # filter out RTBH rules that are not active or whitelisted and not in allowed communities - all_rtbh_rules = RTBH.query.filter(RTBH.rstate_id.in_([1, 4]), RTBH.community_id.in_(allowed_communities)).all() + all_rtbh_rules = db.session.scalars( + select(RTBH).filter(RTBH.rstate_id.in_([1, 4]), RTBH.community_id.in_(allowed_communities)) + ).all() rtbh_rules_map = map_rtbh_rules_to_strings(all_rtbh_rules) result = check_whitelist_against_rules(rtbh_rules_map, str(model)) current_app.logger.info(f"Found {len(result)} matching RTBH rules for whitelist {model}") @@ -127,7 +130,7 @@ def delete_expired_whitelists() -> List[str]: Returns: List of messages for the user """ - expired_whitelists = Whitelist.query.filter(Whitelist.expires < db.func.now()).all() + expired_whitelists = db.session.scalars(select(Whitelist).filter(Whitelist.expires < db.func.now())).all() flashes = [] for model in expired_whitelists: flashes.extend(delete_whitelist(model.id)) diff --git a/flowapp/templates/forms/ipv4_rule.html b/flowapp/templates/forms/ipv4_rule.html index 6cc9021..1b85fc0 100644 --- a/flowapp/templates/forms/ipv4_rule.html +++ b/flowapp/templates/forms/ipv4_rule.html @@ -92,4 +92,16 @@

{{ title or 'New'}} IPv4 rule

+ {% endblock %} \ No newline at end of file diff --git a/flowapp/templates/forms/ipv6_rule.html b/flowapp/templates/forms/ipv6_rule.html index 62198b3..f624556 100644 --- a/flowapp/templates/forms/ipv6_rule.html +++ b/flowapp/templates/forms/ipv6_rule.html @@ -85,4 +85,16 @@

{{ title or 'New'}} IPv6 rule

+ {% endblock %} \ No newline at end of file diff --git a/flowapp/views/admin.py b/flowapp/views/admin.py index d00c6d8..296ab70 100644 --- a/flowapp/views/admin.py +++ b/flowapp/views/admin.py @@ -1,9 +1,7 @@ import csv from io import StringIO -from datetime import datetime, timedelta import secrets -from sqlalchemy import func from flask import Blueprint, render_template, redirect, flash, request, session, url_for, current_app import sqlalchemy from sqlalchemy.exc import IntegrityError, OperationalError @@ -21,10 +19,8 @@ Community, get_existing_community, Log, - Flowspec4, - Flowspec6, - RTBH, ) +from ..models.utils import get_org_rule_stats from ..auth import auth_required, admin_required from flowapp import db @@ -37,16 +33,10 @@ @admin_required def log(page): """ - Displays logs for last two days + Displays paginated logs """ per_page = 20 - now = datetime.now() - week_ago = now - timedelta(weeks=1) - logs = ( - Log.query.order_by(Log.time.desc()) - .filter(Log.time > week_ago) - .paginate(page=page, per_page=per_page, max_per_page=None, error_out=False) - ) + logs = Log.get_recent_paginated(page=page, per_page=per_page) return render_template("pages/logs.html", logs=logs) @@ -57,7 +47,7 @@ def machine_keys(): """ Display all machine keys, from all admins """ - keys = db.session.query(MachineApiKey).all() + keys = MachineApiKey.get_all() return render_template("pages/machine_api_key.html", keys=keys) @@ -72,7 +62,7 @@ def add_machine_key(): """ generated = secrets.token_hex(24) form = MachineApiKeyForm(request.form, key=generated) - form.user.choices = [(g.id, f"{g.name} ({g.uuid})") for g in db.session.query(User).order_by("name")] + form.user.choices = [(g.id, f"{g.name} ({g.uuid})") for g in User.get_all_ordered()] if request.method == "POST" and form.validate(): target_user = db.session.get(User, form.user.data) @@ -128,12 +118,12 @@ def delete_machine_key(key_id): @admin_required def user(): form = UserForm(request.form) - form.role_ids.choices = [(g.id, g.name) for g in db.session.query(Role).order_by("name")] - form.org_ids.choices = [(g.id, g.name) for g in db.session.query(Organization).order_by("name")] + form.role_ids.choices = [(g.id, g.name) for g in Role.get_all_ordered()] + form.org_ids.choices = [(g.id, g.name) for g in Organization.get_all_ordered()] if request.method == "POST" and form.validate(): # test if user is unique - exist = db.session.query(User).filter_by(uuid=form.uuid.data).first() + exist = User.get_by_uuid(form.uuid.data) if not exist: insert_user( uuid=form.uuid.data, @@ -164,8 +154,8 @@ def user(): def edit_user(user_id): user = db.session.get(User, user_id) form = UserForm(request.form, obj=user) - form.role_ids.choices = [(g.id, g.name) for g in db.session.query(Role).order_by("name")] - form.org_ids.choices = [(g.id, g.name) for g in db.session.query(Organization).order_by("name")] + form.role_ids.choices = [(g.id, g.name) for g in Role.get_all_ordered()] + form.org_ids.choices = [(g.id, g.name) for g in Organization.get_all_ordered()] if request.method == "POST" and form.validate(): user.update(form) @@ -217,7 +207,7 @@ def delete_user(user_id): @auth_required @admin_required def users(): - users = User.query.all() + users = User.get_all() return render_template("pages/users.html", users=users) @@ -235,9 +225,9 @@ def bulk_import_users(): @admin_required def bulk_import_users_save(): form = BulkUserForm(request.form) - roles = [role.id for role in db.session.query(Role).all()] - orgs = [org.id for org in db.session.query(Organization).all()] - uuids = [user.uuid for user in db.session.query(User).all()] + roles = [role.id for role in Role.get_all_ordered()] + orgs = [org.id for org in Organization.get_all_ordered()] + uuids = [user.uuid for user in User.get_all()] form.roles = roles form.organizations = orgs form.uuids = uuids @@ -288,46 +278,17 @@ def bulk_import_users_save(): @auth_required @admin_required def organizations(): - # Query all organizations and eager load RTBH relationships - orgs = db.session.query(Organization).options(db.joinedload(Organization.rtbh)).all() - - # Get RTBH counts with rstate_id=1 for all organizations in one query - rtbh_counts_query = ( - db.session.query(RTBH.org_id, func.count(RTBH.id)).filter(RTBH.rstate_id == 1).group_by(RTBH.org_id).all() - ) - - flowspec4_count_query = ( - db.session.query(Flowspec4.org_id, func.count(Flowspec4.id)) - .filter(Flowspec4.rstate_id == 1) - .group_by(Flowspec4.org_id) - .all() - ) - - flowspec6_count_query = ( - db.session.query(Flowspec6.org_id, func.count(Flowspec6.id)) - .filter(Flowspec6.rstate_id == 1) - .group_by(Flowspec6.org_id) - .all() - ) - - flowspec4_all_count = db.session.query(Flowspec4).filter(Flowspec4.rstate_id == 1).count() - flowspec6_all_count = db.session.query(Flowspec6).filter(Flowspec6.rstate_id == 1).count() - rtbh_all_count = db.session.query(RTBH).filter(RTBH.rstate_id == 1).count() - - # Convert query result to a dictionary {org_id: count} - rtbh_counts = {org_id: count for org_id, count in rtbh_counts_query} - flowspec4_counts = {org_id: count for org_id, count in flowspec4_count_query} - flowspec6_counts = {org_id: count for org_id, count in flowspec6_count_query} + stats = get_org_rule_stats() return render_template( "pages/orgs.html", - orgs=orgs, - rtbh_counts=rtbh_counts, - flowspec4_counts=flowspec4_counts, - flowspec6_counts=flowspec6_counts, - rtbh_all_count=rtbh_all_count, - flowspec4_all_count=flowspec4_all_count, - flowspec6_all_count=flowspec6_all_count, + orgs=stats["orgs"], + rtbh_counts=stats["rtbh_counts"], + flowspec4_counts=stats["flowspec4_counts"], + flowspec6_counts=stats["flowspec6_counts"], + rtbh_all_count=stats["rtbh_all_count"], + flowspec4_all_count=stats["flowspec4_all_count"], + flowspec6_all_count=stats["flowspec6_all_count"], flowspec4_limit=current_app.config.get("FLOWSPEC4_MAX_RULES", 9000), flowspec6_limit=current_app.config.get("FLOWSPEC6_MAX_RULES", 9000), rtbh_limit=current_app.config.get("RTBH_MAX_RULES", 100000), @@ -342,7 +303,7 @@ def organization(): if request.method == "POST" and form.validate(): # test if user is unique - exist = db.session.query(Organization).filter_by(name=form.name.data).first() + exist = Organization.get_by_name(form.name.data) if not exist: org = Organization( name=form.name.data, @@ -417,7 +378,7 @@ def delete_organization(org_id): @auth_required @admin_required def as_paths(): - mpaths = db.session.query(ASPath).all() + mpaths = ASPath.get_all() return render_template("pages/as_paths.html", paths=mpaths) @@ -429,7 +390,7 @@ def as_path(): if request.method == "POST" and form.validate(): # test if user is unique - exist = db.session.query(ASPath).filter_by(prefix=form.prefix.data).first() + exist = ASPath.get_by_prefix(form.prefix.data) if not exist: pth = ASPath(prefix=form.prefix.data, as_path=form.as_path.data) db.session.add(pth) @@ -492,7 +453,7 @@ def delete_as_path(path_id): @auth_required @admin_required def actions(): - actions = db.session.query(Action).all() + actions = Action.get_all() return render_template("pages/actions.html", actions=actions) @@ -579,7 +540,7 @@ def delete_action(action_id): @auth_required @admin_required def communities(): - communities = db.session.query(Community).all() + communities = Community.get_all() return render_template("pages/communities.html", communities=communities) diff --git a/flowapp/views/api_common.py b/flowapp/views/api_common.py index 2aaef73..35d2075 100644 --- a/flowapp/views/api_common.py +++ b/flowapp/views/api_common.py @@ -65,10 +65,10 @@ def authorize(user_key: str) -> Tuple[Response, int]: """ jwt_key: Optional[str] = current_app.config.get("JWT_SECRET") # try normal user key first - model: Optional[Union[ApiKey, MachineApiKey]] = db.session.query(ApiKey).filter_by(key=user_key).first() + model: Optional[Union[ApiKey, MachineApiKey]] = ApiKey.get_by_key(user_key) # if not found try machine key if not model: - model = db.session.query(MachineApiKey).filter_by(key=user_key).first() + model = MachineApiKey.get_by_key(user_key) # if key is not found return 403 if not model: return jsonify({"message": "auth token is invalid"}), 403 @@ -124,9 +124,9 @@ def index(current_user: Dict[str, Any], key_map: Dict[str, str]) -> Response: prefered_tf: str = request.args.get(TIME_FORMAT_ARG) if request.args.get(TIME_FORMAT_ARG) else "" net_ranges: List[str] = get_user_nets(current_user["id"]) - rules4: List[Flowspec4] = db.session.query(Flowspec4).order_by(Flowspec4.expires.desc()).all() - rules6: List[Flowspec6] = db.session.query(Flowspec6).order_by(Flowspec6.expires.desc()).all() - rules_rtbh: List[RTBH] = db.session.query(RTBH).order_by(RTBH.expires.desc()).all() + rules4: List[Flowspec4] = Flowspec4.get_all_ordered() + rules6: List[Flowspec6] = Flowspec6.get_all_ordered() + rules_rtbh: List[RTBH] = RTBH.get_all_ordered() # admin can see and edit any rules if 3 in current_user["role_ids"]: @@ -138,9 +138,9 @@ def index(current_user: Dict[str, Any], key_map: Dict[str, str]) -> Response: return jsonify(payload) # filter out the rules for normal users else: - rules4 = validators.filter_rules_in_network(net_ranges, rules4) - rules6 = validators.filter_rules_in_network(net_ranges, rules6) - rules_rtbh = validators.filter_rtbh_rules(net_ranges, rules_rtbh) + rules4, rules4_outside_nets = validators.split_rules_for_user(net_ranges, rules4) + rules6, rules6_outside_nets = validators.split_rules_for_user(net_ranges, rules6) + rules_rtbh, rules_rtbh_outside_nets = validators.split_rtbh_rules_for_user(net_ranges, rules_rtbh) user_actions: List[Tuple[int, str]] = get_user_actions(current_user["role_ids"]) user_actions_ids: List[int] = [act[0] for act in user_actions] @@ -156,9 +156,10 @@ def index(current_user: Dict[str, Any], key_map: Dict[str, str]) -> Response: payload = { key_map["ipv4_rules"]: [rule.to_dict(prefered_tf) for rule in rules4_editable], key_map["ipv6_rules"]: [rule.to_dict(prefered_tf) for rule in rules6_editable], - key_map["ipv4_rules_readonly"]: [rule.to_dict(prefered_tf) for rule in rules4_visible], - key_map["ipv6_rules_readonly"]: [rule.to_dict(prefered_tf) for rule in rules6_visible], + key_map["ipv4_rules_readonly"]: [rule.to_dict(prefered_tf) for rule in rules4_visible + rules4_outside_nets], + key_map["ipv6_rules_readonly"]: [rule.to_dict(prefered_tf) for rule in rules6_visible + rules6_outside_nets], key_map["rtbh_rules"]: [rule.to_dict(prefered_tf) for rule in rules_rtbh], + key_map["rtbh_rules_readonly"]: [rule.to_dict(prefered_tf) for rule in rules_rtbh_outside_nets], } return jsonify(payload) @@ -231,11 +232,11 @@ def create_ipv4(current_user: Dict[str, Any]) -> Tuple[Response, int]: :return: json response """ if check_global_rule_limit(RuleTypes.IPv4): - count: int = db.session.query(Flowspec4).filter_by(rstate_id=1).count() + count: int = Flowspec4.count_active() return global_limit_reached(count=count, rule_type=RuleTypes.IPv4) if check_rule_limit(current_user["org_id"], RuleTypes.IPv4): - count = db.session.query(Flowspec4).filter_by(rstate_id=1, org_id=current_user["org_id"]).count() + count = Flowspec4.count_active(org_id=current_user["org_id"]) return limit_reached(count=count, rule_type=RuleTypes.IPv4, org_id=current_user["org_id"]) net_ranges: List[str] = get_user_nets(current_user["id"]) @@ -274,11 +275,11 @@ def create_ipv6(current_user: Dict[str, Any]) -> Tuple[Response, int]: :return: """ if check_global_rule_limit(RuleTypes.IPv6): - count: int = db.session.query(Flowspec6).filter_by(rstate_id=1).count() + count: int = Flowspec6.count_active() return global_limit_reached(count=count, rule_type=RuleTypes.IPv6) if check_rule_limit(current_user["org_id"], RuleTypes.IPv6): - count = db.session.query(Flowspec6).filter_by(rstate_id=1, org_id=current_user["org_id"]).count() + count = Flowspec6.count_active(org_id=current_user["org_id"]) return limit_reached(count=count, rule_type=RuleTypes.IPv6, org_id=current_user["org_id"]) net_ranges: List[str] = get_user_nets(current_user["id"]) @@ -311,14 +312,14 @@ def create_rtbh(current_user: Dict[str, Any]) -> Tuple[Response, int]: Create new RTBH rule """ if check_global_rule_limit(RuleTypes.RTBH): - count: int = db.session.query(RTBH).filter_by(rstate_id=1).count() + count: int = RTBH.count_active() return global_limit_reached(count=count, rule_type=RuleTypes.RTBH) if check_rule_limit(current_user["org_id"], RuleTypes.RTBH): - count = db.session.query(RTBH).filter_by(rstate_id=1, org_id=current_user["org_id"]).count() + count = RTBH.count_active(org_id=current_user["org_id"]) return limit_reached(count=count, rule_type=RuleTypes.RTBH, org_id=current_user["org_id"]) - all_com: List[Community] = db.session.query(Community).all() + all_com: List[Community] = Community.get_all() if not all_com: insert_initial_communities() @@ -378,7 +379,7 @@ def rtbh_rule_get(current_user: Dict[str, Any], rule_id: int) -> Tuple[Response, :param rule_id: :return: """ - model: Optional[RTBH] = db.session.query(RTBH).get(rule_id) + model: Optional[RTBH] = db.session.get(RTBH, rule_id) return get_rule(current_user, model, rule_id) diff --git a/flowapp/views/api_keys.py b/flowapp/views/api_keys.py index 7eae8c3..b20970f 100644 --- a/flowapp/views/api_keys.py +++ b/flowapp/views/api_keys.py @@ -31,7 +31,7 @@ def all(): :return: page with keys """ jwt_key = current_app.config.get("JWT_SECRET") - keys = db.session.query(ApiKey).filter_by(user_id=session["user_id"]).all() + keys = ApiKey.get_by_user_id(session["user_id"]) payload = {"keys": [key.id for key in keys]} encoded = jwt.encode(payload, jwt_key, algorithm="HS256") diff --git a/flowapp/views/dashboard.py b/flowapp/views/dashboard.py index c013f67..6b239ff 100644 --- a/flowapp/views/dashboard.py +++ b/flowapp/views/dashboard.py @@ -705,8 +705,8 @@ def enrich_rules_with_whitelist_info(rules, rule_type): from flowapp.models.rules.whitelist import RuleWhitelistCache from flowapp.constants import RuleTypes, RuleOrigin - # Map rule type string to enum value - rule_type_map = {"ipv4": RuleTypes.IPv4.value, "ipv6": RuleTypes.IPv6.value, "rtbh": RuleTypes.RTBH.value} + # Map rule type string to enum + rule_type_map = {"ipv4": RuleTypes.IPv4, "ipv6": RuleTypes.IPv6, "rtbh": RuleTypes.RTBH} # Get all rule IDs rule_ids = [rule.id for rule in rules] @@ -716,9 +716,7 @@ def enrich_rules_with_whitelist_info(rules, rule_type): return rules, set() # Query the cache for these rule IDs - cache_entries = RuleWhitelistCache.query.filter( - RuleWhitelistCache.rid.in_(rule_ids), RuleWhitelistCache.rtype == rule_type_map.get(rule_type) - ).all() + cache_entries = RuleWhitelistCache.get_by_rule_ids(rule_ids, rule_type_map.get(rule_type)) # Create a set of rule IDs that were created by a whitelist whitelist_rule_ids = {entry.rid for entry in cache_entries if entry.rorigin == RuleOrigin.WHITELIST.value} diff --git a/flowapp/views/rules.py b/flowapp/views/rules.py index 7318159..6ec8d5d 100644 --- a/flowapp/views/rules.py +++ b/flowapp/views/rules.py @@ -81,7 +81,7 @@ def reactivate_rule(rule_type, rule_id): form.net_ranges = get_user_nets(session["user_id"]) if rule_type > 2: - form.action.choices = [(g.id, g.name) for g in db.session.query(Action).order_by("name")] + form.action.choices = [(g.id, g.name) for g in Action.get_all_ordered()] form.action.data = model.action_id if rule_type == RuleTypes.RTBH.value: @@ -337,7 +337,7 @@ def group_delete(): f"{session['user_email']} / {session['user_org']}", ) - db.session.query(model_name).filter(model_name.id.in_(to_delete)).delete(synchronize_session=False) + db.session.execute(db.delete(model_name).where(model_name.id.in_(to_delete))) db.session.commit() flash(f"Rules {to_delete} deleted", "alert-success") @@ -391,7 +391,7 @@ def group_update(): form = form_name(request.form) form.net_ranges = get_user_nets(session["user_id"]) if rule_type_int > 2: - form.action.choices = [(g.id, g.name) for g in db.session.query(Action).order_by("name")] + form.action.choices = [(g.id, g.name) for g in Action.get_all_ordered()] if rule_type_int == 1: form.community.choices = get_user_communities(session["user_role_ids"]) @@ -610,7 +610,7 @@ def rtbh_rule(): if check_rule_limit(session["user_org_id"], RuleTypes.RTBH): return redirect(url_for("rules.limit_reached", rule_type=RuleTypes.RTBH)) - all_com = db.session.query(Community).all() + all_com = Community.get_all() if not all_com: insert_initial_communities() @@ -654,9 +654,9 @@ def rtbh_rule(): @auth_required def limit_reached(rule_type): rule_type = constants.RULE_NAMES_DICT[int(rule_type)] - count_4 = db.session.query(Flowspec4).filter_by(rstate_id=1, org_id=session["user_org_id"]).count() - count_6 = db.session.query(Flowspec6).filter_by(rstate_id=1, org_id=session["user_org_id"]).count() - count_rtbh = db.session.query(RTBH).filter_by(rstate_id=1, org_id=session["user_org_id"]).count() + count_4 = Flowspec4.count_active(org_id=session["user_org_id"]) + count_6 = Flowspec6.count_active(org_id=session["user_org_id"]) + count_rtbh = RTBH.count_active(org_id=session["user_org_id"]) org = db.session.get(Organization, session["user_org_id"]) return render_template( "pages/limit_reached.html", @@ -673,9 +673,9 @@ def limit_reached(rule_type): @auth_required def global_limit_reached(rule_type): rule_type = constants.RULE_NAMES_DICT[int(rule_type)] - count_4 = db.session.query(Flowspec4).filter_by(rstate_id=1).count() - count_6 = db.session.query(Flowspec6).filter_by(rstate_id=1).count() - count_rtbh = db.session.query(RTBH).filter_by(rstate_id=1).count() + count_4 = Flowspec4.count_active() + count_6 = Flowspec6.count_active() + count_rtbh = RTBH.count_active() Limit = namedtuple("Limit", ["limit_flowspec4", "limit_flowspec6", "limit_rtbh"]) limit = Limit( @@ -699,14 +699,14 @@ def global_limit_reached(rule_type): @auth_required @admin_required def export(): - rules4 = db.session.query(Flowspec4).order_by(Flowspec4.expires.desc()).all() - rules6 = db.session.query(Flowspec6).order_by(Flowspec6.expires.desc()).all() + rules4 = Flowspec4.get_all_ordered() + rules6 = Flowspec6.get_all_ordered() rules = {4: rules4, 6: rules6} - actions = db.session.query(Action).all() + actions = Action.get_all() actions = {action.id: action for action in actions} - rules_rtbh = db.session.query(RTBH).order_by(RTBH.expires.desc()).all() + rules_rtbh = RTBH.get_all_ordered() announce_all_routes() diff --git a/tests/conftest.py b/tests/conftest.py index 3a988ef..089fdbb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,7 @@ import os import json import pytest -from sqlalchemy import create_engine +from sqlalchemy import create_engine, select from sqlalchemy.orm import sessionmaker from flowapp import create_app @@ -117,7 +117,7 @@ def db(app, request): print("#: inserting users") flowapp.models.insert_users(users) - org = _db.session.query(Organization).filter_by(id=1).first() + org = _db.session.execute(select(Organization).filter_by(id=1)).scalar_one() # Update the organization address range to include our test networks org.arange = "147.230.0.0/16\n2001:718:1c01::/48\n192.168.0.0/16\n10.0.0.0/8" _db.session.commit() @@ -224,9 +224,30 @@ def reset_org_limits(db, app): yield # Allow test execution with app.app_context(): - org = db.session.query(Organization).filter_by(id=1).first() + org = db.session.execute(select(Organization).filter_by(id=1)).scalar_one_or_none() if org: org.limit_flowspec4 = 0 org.limit_flowspec6 = 0 org.limit_rtbh = 0 db.session.commit() + + +@pytest.fixture(scope="session") +def normal_user_jwt_token(client, app, db, request): + """ + JWT token for a normal (non-admin) user belonging to org 1. + The user's org covers 147.230.0.0/16 and 2001:718:1c01::/48, + so rules on e.g. 200.200.200.0/24 or 2002::/16 are outside their nets. + """ + normal_key = "normal-user-testkey" + with app.app_context(): + flowapp.models.insert_users([{"name": "normal.user@cesnet.cz", "role_id": 2, "org_id": 1}]) + user = db.session.execute(select(flowapp.models.User).filter_by(uuid="normal.user@cesnet.cz")).scalar_one() + model = flowapp.models.ApiKey(machine="127.0.0.1", key=normal_key, user_id=user.id, org_id=1) + db.session.add(model) + db.session.commit() + + url = "/api/v3/auth" + token = client.get(url, headers={"x-api-key": normal_key}) + data = json.loads(token.data) + return data["token"] diff --git a/tests/test_admin_models.py b/tests/test_admin_models.py new file mode 100644 index 0000000..a23e278 --- /dev/null +++ b/tests/test_admin_models.py @@ -0,0 +1,150 @@ +""" +Tests for model classmethods used in admin.py views. +Covers: MachineApiKey, User, Role, Organization, ASPath, Log, and get_org_rule_stats. +""" +from datetime import datetime, timedelta + +import flowapp.models as models +from flowapp.models import ( + MachineApiKey, + User, + Role, + Organization, +) +from flowapp.models.community import ASPath +from flowapp.models.log import Log + + +# --- MachineApiKey --- + +def test_machine_api_key_get_all(db): + key = MachineApiKey(machine="10.0.0.1", key="mkey-getall-1", user_id=1, org_id=1) + db.session.add(key) + db.session.commit() + result = MachineApiKey.get_all() + assert any(k.key == "mkey-getall-1" for k in result) + + +# --- User --- + +def test_user_get_all(db): + result = User.get_all() + assert len(result) >= 1 + + +def test_user_get_all_ordered(db): + result = User.get_all_ordered() + names = [u.name for u in result if u.name] + assert names == sorted(names) + + +def test_user_get_by_uuid_found(db): + user = User(uuid="admin-get-by-uuid@test.cz", name="Test", email="admin@test.cz") + db.session.add(user) + db.session.commit() + result = User.get_by_uuid("admin-get-by-uuid@test.cz") + assert result is not None + assert result.uuid == "admin-get-by-uuid@test.cz" + + +def test_user_get_by_uuid_not_found(db): + result = User.get_by_uuid("does-not-exist@nowhere.cz") + assert result is None + + +# --- Role --- + +def test_role_get_all_ordered(db): + result = Role.get_all_ordered() + assert len(result) >= 1 + names = [r.name for r in result] + assert names == sorted(names) + + +# --- Organization --- + +def test_organization_get_all_ordered(db): + result = Organization.get_all_ordered() + assert len(result) >= 1 + names = [o.name for o in result] + assert names == sorted(names) + + +def test_organization_get_by_name_found(db): + result = Organization.get_by_name("TU Liberec") + assert result is not None + + +def test_organization_get_by_name_not_found(db): + result = Organization.get_by_name("Nonexistent Org XYZ") + assert result is None + + +# --- ASPath --- + +def test_aspath_get_all(db): + path = ASPath(prefix="192.0.2.0/24", as_path="64496 64497") + db.session.add(path) + db.session.commit() + result = ASPath.get_all() + assert any(p.prefix == "192.0.2.0/24" for p in result) + + +def test_aspath_get_by_prefix_found(db): + path = ASPath(prefix="198.51.100.0/24", as_path="64496") + db.session.add(path) + db.session.commit() + result = ASPath.get_by_prefix("198.51.100.0/24") + assert result is not None + assert result.prefix == "198.51.100.0/24" + + +def test_aspath_get_by_prefix_not_found(db): + result = ASPath.get_by_prefix("203.0.113.0/24") + assert result is None + + +# --- Log --- + +def test_log_get_recent_paginated(app, db): + log = Log( + time=datetime.now(), + task="test task", + user_id=1, + rule_type=1, + rule_id=1, + author="test@test.cz", + ) + db.session.add(log) + db.session.commit() + pagination = Log.get_recent_paginated(page=1, per_page=20) + assert pagination.total >= 1 + + +# --- get_org_rule_stats --- + +def test_get_org_rule_stats_returns_expected_keys(app, db): + from flowapp.models.utils import get_org_rule_stats + stats = get_org_rule_stats() + assert "orgs" in stats + assert "rtbh_counts" in stats + assert "flowspec4_counts" in stats + assert "flowspec6_counts" in stats + assert "rtbh_all_count" in stats + assert "flowspec4_all_count" in stats + assert "flowspec6_all_count" in stats + + +def test_get_org_rule_stats_counts_are_integers(app, db): + from flowapp.models.utils import get_org_rule_stats + stats = get_org_rule_stats() + assert isinstance(stats["rtbh_all_count"], int) + assert isinstance(stats["flowspec4_all_count"], int) + assert isinstance(stats["flowspec6_all_count"], int) + + +def test_get_org_rule_stats_dicts_keyed_by_org_id(app, db): + from flowapp.models.utils import get_org_rule_stats + stats = get_org_rule_stats() + for d in [stats["rtbh_counts"], stats["flowspec4_counts"], stats["flowspec6_counts"]]: + assert isinstance(d, dict) diff --git a/tests/test_api_v3.py b/tests/test_api_v3.py index 0548532..6c42969 100644 --- a/tests/test_api_v3.py +++ b/tests/test_api_v3.py @@ -1,6 +1,6 @@ import json - +from sqlalchemy import func, select from flowapp.models import Flowspec4, Organization V_PREFIX = "/api/v3" @@ -458,12 +458,12 @@ def test_create_v4rule_lmit(client, db, app, jwt_token): test that limit checkt for v4 works """ with app.app_context(): - org = db.session.query(Organization).filter_by(id=1).first() + org = db.session.execute(select(Organization).filter_by(id=1)).scalar_one() org.limit_flowspec4 = 2 db.session.commit() # count - count = db.session.query(Flowspec4).count() + count = db.session.scalar(select(func.count()).select_from(Flowspec4)) print("COUNT", count) sources = ["147.230.42.17", "147.230.42.118"] @@ -492,7 +492,7 @@ def test_create_v6rule_lmit(client, db, app, jwt_token): test that limit check for v6 works """ with app.app_context(): - org = db.session.query(Organization).filter_by(id=1).first() + org = db.session.execute(select(Organization).filter_by(id=1)).scalar_one() org.limit_flowspec6 = 3 db.session.commit() @@ -522,7 +522,7 @@ def test_create_rtbh_lmit(client, db, app, jwt_token): test that limit check for v6 works """ with app.app_context(): - org = db.session.query(Organization).filter_by(id=1).first() + org = db.session.execute(select(Organization).filter_by(id=1)).scalar_one() org.limit_rtbh = 1 db.session.commit() @@ -547,10 +547,10 @@ def test_update_existing_v4rule_with_timestamp_limit(client, db, app, jwt_token) """ with app.app_context(): # count - count = db.session.query(Flowspec4).filter_by(org_id=1, rstate_id=1).count() + count = db.session.scalar(select(func.count()).select_from(Flowspec4).filter_by(org_id=1, rstate_id=1)) print("COUNT in update", count) - org = db.session.query(Organization).filter_by(id=1).first() + org = db.session.execute(select(Organization).filter_by(id=1)).scalar_one() org.limit_flowspec4 = count db.session.commit() @@ -582,7 +582,7 @@ def test_overall_limit(client, db, app, jwt_token): with app.app_context(): # count - org = db.session.query(Organization).filter_by(id=1).first() + org = db.session.execute(select(Organization).filter_by(id=1)).scalar_one() org.limit_flowspec4 = 20 db.session.commit() @@ -609,3 +609,107 @@ def test_overall_limit(client, db, app, jwt_token): data = json.loads(req.data) assert data["message"] assert data["message"].startswith("System limit") + + +def test_normal_user_sees_foreign_ipv4_rule_as_readonly(client, app, db, normal_user_jwt_token): + """ + An IPv4 rule on a subnet outside the normal user's org ranges must appear in + flowspec_ipv4_ro (not flowspec_ipv4_rw) when fetched by the normal user. + """ + from datetime import datetime + from flowapp.models import Flowspec4 + + with app.app_context(): + rule = Flowspec4( + source="200.200.200.1", + source_mask=32, + source_port="", + destination=None, + destination_mask=None, + destination_port="", + protocol="tcp", + flags="", + packet_len="", + fragment="", + expires=datetime(2050, 10, 15, 14, 46), + user_id=1, + org_id=1, + action_id=2, + ) + db.session.add(rule) + db.session.commit() + + req = client.get(f"{V_PREFIX}/rules", headers={"x-access-token": normal_user_jwt_token}) + assert req.status_code == 200 + data = json.loads(req.data) + sources_rw = [r["source"] for r in data["flowspec_ipv4_rw"]] + sources_ro = [r["source"] for r in data["flowspec_ipv4_ro"]] + assert "200.200.200.1" not in sources_rw + assert "200.200.200.1" in sources_ro + + +def test_normal_user_sees_foreign_ipv6_rule_as_readonly(client, app, db, normal_user_jwt_token): + """ + An IPv6 rule on a subnet outside the normal user's org ranges must appear in + flowspec_ipv6_ro (not flowspec_ipv6_rw) when fetched by the normal user. + """ + from datetime import datetime + from flowapp.models import Flowspec6 + + with app.app_context(): + rule = Flowspec6( + source="2002:db8::1", + source_mask=128, + source_port="", + destination=None, + destination_mask=None, + destination_port="", + next_header="tcp", + flags="", + packet_len="", + expires=datetime(2050, 10, 15, 14, 46), + user_id=1, + org_id=1, + action_id=2, + ) + db.session.add(rule) + db.session.commit() + + req = client.get(f"{V_PREFIX}/rules", headers={"x-access-token": normal_user_jwt_token}) + assert req.status_code == 200 + data = json.loads(req.data) + sources_rw = [r["source"] for r in data["flowspec_ipv6_rw"]] + sources_ro = [r["source"] for r in data["flowspec_ipv6_ro"]] + assert "2002:db8::1" not in sources_rw + assert "2002:db8::1" in sources_ro + + +def test_normal_user_sees_foreign_rtbh_rule_as_readonly(client, app, db, normal_user_jwt_token): + """ + An RTBH rule on a subnet outside the normal user's org ranges must appear in + rtbh_any_ro (not rtbh_any_rw) when fetched by the normal user. + """ + from datetime import datetime + from flowapp.models import RTBH + + with app.app_context(): + rule = RTBH( + ipv4="200.200.200.2", + ipv4_mask=32, + ipv6=None, + ipv6_mask=None, + community_id=1, + expires=datetime(2050, 10, 25, 14, 46), + user_id=1, + org_id=1, + ) + db.session.add(rule) + db.session.commit() + + req = client.get(f"{V_PREFIX}/rules", headers={"x-access-token": normal_user_jwt_token}) + assert req.status_code == 200 + data = json.loads(req.data) + rtbh_rw = [r["ipv4"] for r in data["rtbh_any_rw"]] + rtbh_ro = [r["ipv4"] for r in data["rtbh_any_ro"]] + assert "200.200.200.2" not in rtbh_rw + assert "200.200.200.2" in rtbh_ro diff --git a/tests/test_api_whitelist_integration.py b/tests/test_api_whitelist_integration.py index 0391a76..7ba33a6 100644 --- a/tests/test_api_whitelist_integration.py +++ b/tests/test_api_whitelist_integration.py @@ -9,6 +9,7 @@ import pytest from datetime import datetime, timedelta +from sqlalchemy import func, select from flowapp.constants import RuleTypes, RuleOrigin from flowapp.models import RTBH, RuleWhitelistCache, Organization from flowapp.services import whitelist_service @@ -50,7 +51,7 @@ def test_create_rtbh_equal_to_whitelist(client, app, db, jwt_token, whitelist_da # Create the whitelist directly using the service with app.app_context(): # Create user and organization if needed for the whitelist - org = db.session.query(Organization).first() + org = db.session.execute(select(Organization)).scalars().first() # Create the whitelist whitelist_model, _ = whitelist_service.create_or_update_whitelist( @@ -81,14 +82,14 @@ def test_create_rtbh_equal_to_whitelist(client, app, db, jwt_token, whitelist_da # Now verify the rule was created but marked as whitelisted with app.app_context(): - rtbh_rule = db.session.query(RTBH).filter_by(id=rule_id).first() + rtbh_rule = db.session.execute(select(RTBH).filter_by(id=rule_id)).scalar_one() assert rtbh_rule is not None assert rtbh_rule.rstate_id == 4 # 4 = whitelisted state # Verify a cache entry was created - cache_entry = RuleWhitelistCache.query.filter_by( + cache_entry = db.session.execute(select(RuleWhitelistCache).filter_by( rid=rule_id, rtype=RuleTypes.RTBH.value, whitelist_id=whitelist_model.id - ).first() + )).scalars().first() assert cache_entry is not None assert cache_entry.rorigin == RuleOrigin.USER.value @@ -110,7 +111,7 @@ def test_create_rtbh_supernet_of_whitelist(client, app, db, jwt_token, whitelist # Create the whitelist directly using the service with app.app_context(): - org = db.session.query(Organization).first() + org = db.session.execute(select(Organization)).scalars().first() whitelist_model, _ = whitelist_service.create_or_update_whitelist( form_data=whitelist_data, user_id=1, org_id=org.id, user_email="test@example.com", org_name=org.name ) @@ -135,37 +136,35 @@ def test_create_rtbh_supernet_of_whitelist(client, app, db, jwt_token, whitelist # Now verify the rule was created and marked as whitelisted with app.app_context(): # Check the original rule - rtbh_rule = db.session.query(RTBH).filter_by(id=rule_id).first() + rtbh_rule = db.session.execute(select(RTBH).filter_by(id=rule_id)).scalar_one() assert rtbh_rule is not None assert rtbh_rule.rstate_id == 4 # 4 = whitelisted state # Check if a new subnet rule was created for the non-whitelisted part - subnet_rule = ( - db.session.query(RTBH) - .filter( + subnet_rule = db.session.execute( + select(RTBH).filter( RTBH.ipv4 == "192.168.1.0", RTBH.ipv4_mask == 25, # This would be the other half not covered by the whitelist ) - .first() - ) + ).scalars().first() assert subnet_rule is not None assert subnet_rule.rstate_id == 1 # Active status # Verify cache entries # Main rule should be cached as a USER rule - user_cache = RuleWhitelistCache.query.filter_by( + user_cache = db.session.execute(select(RuleWhitelistCache).filter_by( rid=rule_id, rtype=RuleTypes.RTBH.value, whitelist_id=whitelist_model.id, rorigin=RuleOrigin.USER.value - ).first() + )).scalars().first() assert user_cache is not None # Subnet rule should be cached as a WHITELIST rule - whitelist_cache = RuleWhitelistCache.query.filter_by( + whitelist_cache = db.session.execute(select(RuleWhitelistCache).filter_by( rid=subnet_rule.id, rtype=RuleTypes.RTBH.value, whitelist_id=whitelist_model.id, rorigin=RuleOrigin.WHITELIST.value, - ).first() + )).scalars().first() assert whitelist_cache is not None @@ -184,9 +183,9 @@ def test_create_rtbh_subnet_of_whitelist(client, app, db, jwt_token, whitelist_d # Create the whitelist directly using the service with app.app_context(): - all_rtbh_rules_before = db.session.query(RTBH).count() + all_rtbh_rules_before = db.session.scalar(select(func.count()).select_from(RTBH)) - org = db.session.query(Organization).first() + org = db.session.execute(select(Organization)).scalars().first() whitelist_model, _ = whitelist_service.create_or_update_whitelist( form_data=whitelist_data, user_id=1, org_id=org.id, user_email="test@example.com", org_name=org.name ) @@ -208,20 +207,20 @@ def test_create_rtbh_subnet_of_whitelist(client, app, db, jwt_token, whitelist_d # Now verify the rule was created but marked as whitelisted with app.app_context(): - rtbh_rule = db.session.query(RTBH).filter_by(id=rule_id).first() + rtbh_rule = db.session.execute(select(RTBH).filter_by(id=rule_id)).scalar_one() assert rtbh_rule is not None assert rtbh_rule.rstate_id == 4 # 4 = whitelisted state # Verify a cache entry was created - cache_entry = RuleWhitelistCache.query.filter_by( + cache_entry = db.session.execute(select(RuleWhitelistCache).filter_by( rid=rule_id, rtype=RuleTypes.RTBH.value, whitelist_id=whitelist_model.id - ).first() + )).scalars().first() assert cache_entry is not None assert cache_entry.rorigin == RuleOrigin.USER.value # Verify no additional rules were created - all_rtbh_rules = db.session.query(RTBH).count() + all_rtbh_rules = db.session.scalar(select(func.count()).select_from(RTBH)) assert all_rtbh_rules - all_rtbh_rules_before == 1 @@ -240,7 +239,7 @@ def test_create_rtbh_no_relation_to_whitelist(client, app, db, jwt_token, whitel # Create the whitelist directly using the service with app.app_context(): - org = db.session.query(Organization).first() + org = db.session.execute(select(Organization)).scalars().first() whitelist_model, _ = whitelist_service.create_or_update_whitelist( form_data=whitelist_data, user_id=1, org_id=org.id, user_email="test@example.com", org_name=org.name ) @@ -263,11 +262,13 @@ def test_create_rtbh_no_relation_to_whitelist(client, app, db, jwt_token, whitel # Now verify the rule was created with active state with app.app_context(): - rtbh_rule = db.session.query(RTBH).filter_by(id=rule_id).first() + rtbh_rule = db.session.execute(select(RTBH).filter_by(id=rule_id)).scalar_one() assert rtbh_rule is not None assert rtbh_rule.rstate_id == 1 # 1 = active state # Verify no cache entry was created - cache_entry = RuleWhitelistCache.query.filter_by(rid=rule_id, rtype=RuleTypes.RTBH.value).first() + cache_entry = db.session.execute(select(RuleWhitelistCache).filter_by( + rid=rule_id, rtype=RuleTypes.RTBH.value + )).scalars().first() assert cache_entry is None diff --git a/tests/test_auth.py b/tests/test_auth.py new file mode 100644 index 0000000..570e40e --- /dev/null +++ b/tests/test_auth.py @@ -0,0 +1,139 @@ +""" +Tests for flowapp.auth helper functions. +""" +from datetime import datetime, timedelta + +import flowapp.models as models +from flowapp.auth import get_user_allowed_rule_ids + + +def _make_ipv4(db, source="147.230.1.1", user_id=1, org_id=1): + rule = models.Flowspec4( + source=source, + source_mask=32, + source_port="", + destination="", + destination_mask=None, + destination_port="", + protocol="tcp", + flags="", + packet_len="", + fragment="", + action_id=1, + expires=datetime.now() + timedelta(days=1), + user_id=user_id, + org_id=org_id, + rstate_id=1, + ) + db.session.add(rule) + db.session.commit() + return rule + + +def _make_ipv6(db, source="2001:718:1c01::1", user_id=1, org_id=1): + rule = models.Flowspec6( + source=source, + source_mask=128, + source_port="", + destination="", + destination_mask=None, + destination_port="", + next_header="tcp", + flags="", + packet_len="", + action_id=1, + expires=datetime.now() + timedelta(days=1), + user_id=user_id, + org_id=org_id, + rstate_id=1, + ) + db.session.add(rule) + db.session.commit() + return rule + + +def _make_rtbh(db, ipv4="147.230.1.2", user_id=1, org_id=1): + rule = models.RTBH( + ipv4=ipv4, + ipv4_mask=32, + ipv6="", + ipv6_mask=0, + community_id=1, + expires=datetime.now() + timedelta(days=1), + user_id=user_id, + org_id=org_id, + rstate_id=1, + ) + db.session.add(rule) + db.session.commit() + return rule + + +def _make_whitelist(db, ip="147.230.0.0", user_id=1, org_id=1): + rule = models.Whitelist( + ip=ip, + mask=16, + expires=datetime.now() + timedelta(days=1), + user_id=user_id, + org_id=org_id, + ) + db.session.add(rule) + db.session.commit() + return rule + + +class TestGetUserAllowedRuleIdsAdmin: + """Admin (role_id=3) should get all rule IDs regardless of network.""" + + def test_ipv4_admin(self, app, db): + rule = _make_ipv4(db) + ids = get_user_allowed_rule_ids("ipv4", user_id=1, user_role_ids=[3]) + assert rule.id in ids + + def test_ipv6_admin(self, app, db): + rule = _make_ipv6(db) + ids = get_user_allowed_rule_ids("ipv6", user_id=1, user_role_ids=[3]) + assert rule.id in ids + + def test_rtbh_admin(self, app, db): + rule = _make_rtbh(db) + ids = get_user_allowed_rule_ids("rtbh", user_id=1, user_role_ids=[3]) + assert rule.id in ids + + def test_whitelist_admin(self, app, db): + rule = _make_whitelist(db) + ids = get_user_allowed_rule_ids("whitelist", user_id=1, user_role_ids=[3]) + assert rule.id in ids + + def test_unknown_type_returns_empty(self, app, db): + ids = get_user_allowed_rule_ids("unknown", user_id=1, user_role_ids=[3]) + assert ids == [] + + +class TestGetUserAllowedRuleIdsRegularUser: + """Regular user (role_id=2) should only see rules within their org's network ranges.""" + + def test_ipv4_in_range(self, app, db): + # 147.230.0.0/16 is in org 1's arange (set in conftest) + rule = _make_ipv4(db, source="147.230.1.1") + ids = get_user_allowed_rule_ids("ipv4", user_id=1, user_role_ids=[2]) + assert rule.id in ids + + def test_ipv4_outside_range(self, app, db): + rule = _make_ipv4(db, source="1.2.3.4") + ids = get_user_allowed_rule_ids("ipv4", user_id=1, user_role_ids=[2]) + assert rule.id not in ids + + def test_ipv6_in_range(self, app, db): + rule = _make_ipv6(db, source="2001:718:1c01::1") + ids = get_user_allowed_rule_ids("ipv6", user_id=1, user_role_ids=[2]) + assert rule.id in ids + + def test_rtbh_in_range(self, app, db): + rule = _make_rtbh(db, ipv4="147.230.1.3") + ids = get_user_allowed_rule_ids("rtbh", user_id=1, user_role_ids=[2]) + assert rule.id in ids + + def test_unknown_type_returns_empty(self, app, db): + ids = get_user_allowed_rule_ids("unknown", user_id=1, user_role_ids=[2]) + assert ids == [] diff --git a/tests/test_flowapp.py b/tests/test_flowapp.py index f71e4ae..d962757 100644 --- a/tests/test_flowapp.py +++ b/tests/test_flowapp.py @@ -12,3 +12,58 @@ def test_dashboard(auth_client): # Check that the request is successful and renders the correct template assert response.status_code == 200 # Expecting a 200 OK if the user is authenticated + + +def test_select_org_renders_modal(auth_client): + response = auth_client.get("/select_org") + assert response.status_code == 200 + + +def test_select_org_with_valid_org(auth_client): + # org_id=1 is the org the test user belongs to (set up in conftest) + response = auth_client.get("/select_org/1") + assert response.status_code == 302 + assert response.headers["Location"] == "/" + + +def test_select_org_with_invalid_org(auth_client): + # org_id=999 does not exist — user has no access, should redirect to index + response = auth_client.get("/select_org/999") + assert response.status_code == 302 + assert response.headers["Location"] == "/" + + +def test_enrich_rules_with_whitelist_info_empty(app): + from flowapp.views.dashboard import enrich_rules_with_whitelist_info + rules, ids = enrich_rules_with_whitelist_info([], "ipv4") + assert rules == [] + assert ids == set() + + +def test_enrich_rules_with_whitelist_info_no_cache(app, db): + from datetime import datetime, timedelta + from flowapp.views.dashboard import enrich_rules_with_whitelist_info + import flowapp.models as models + + rule = models.Flowspec4( + source="10.5.0.1", + source_mask=32, + source_port="", + destination="", + destination_mask=None, + destination_port="", + protocol="tcp", + flags="", + packet_len="", + fragment="", + action_id=1, + expires=datetime.now() + timedelta(days=1), + user_id=1, + org_id=1, + ) + db.session.add(rule) + db.session.commit() + + rules, whitelist_ids = enrich_rules_with_whitelist_info([rule], "ipv4") + assert rule in rules + assert rule.id not in whitelist_ids diff --git a/tests/test_messages.py b/tests/test_messages.py new file mode 100644 index 0000000..a9e7345 --- /dev/null +++ b/tests/test_messages.py @@ -0,0 +1,166 @@ +import flowapp.models as models +from flowapp.constants import ANNOUNCE +from flowapp.messages import format_tcp_flags, format_fragment, create_rtbh + + +# --- format_tcp_flags --- + + +def test_tcp_flags_v4_single(): + assert format_tcp_flags("SYN", 4) == "tcp-flags SYN;" + + +def test_tcp_flags_v4_multiple(): + assert format_tcp_flags("SYN ACK FIN", 4) == "tcp-flags SYN ACK FIN;" + + +def test_tcp_flags_v4_default(): + # default version is 4 + assert format_tcp_flags("SYN ACK") == "tcp-flags SYN ACK;" + + +def test_tcp_flags_v5_single(): + assert format_tcp_flags("SYN", 5) == "tcp-flags [ syn ];" + + +def test_tcp_flags_v5_multiple(): + assert format_tcp_flags("SYN ACK FIN", 5) == "tcp-flags [ syn ack fin ];" + + +def test_tcp_flags_v5_already_lowercase(): + assert format_tcp_flags("syn ack", 5) == "tcp-flags [ syn ack ];" + + +# --- format_fragment --- + + +def test_fragment_v4_single(): + assert format_fragment("is-fragment", 4) == "fragment [ is-fragment ];" + + +def test_fragment_v4_multiple(): + assert format_fragment("is-fragment dont-fragment", 4) == "fragment [ is-fragment dont-fragment ];" + + +def test_fragment_v4_default(): + assert format_fragment("dont-fragment") == "fragment [ dont-fragment ];" + + +def test_fragment_v5_is_fragment(): + assert format_fragment("is-fragment", 5) == "fragment [ is-fragment ];" + + +def test_fragment_v5_dont_fragment(): + assert format_fragment("dont-fragment", 5) == "fragment [ dont-fragment ];" + + +def test_fragment_v5_not_a_fragment(): + # "not" key maps to "!is-fragment" in IPV4_FRAGMENT_V5 + assert format_fragment("not", 5) == "fragment [ !is-fragment ];" + + +def test_fragment_v5_unknown_passthrough(): + # unknown values pass through unchanged + assert format_fragment("first-fragment", 5) == "fragment [ first-fragment ];" + + +class TestCreateRtbh: + def test_as_path_match_found(self, app, db): + """create_rtbh includes as-path string when ASPath record matches the source IP.""" + # Create a community with as_path enabled + community = models.Community( + name="test-aspath-comm", + comm="65535:65283", + larcomm="", + extcomm="", + description="", + as_path=True, + role_id=2, + ) + db.session.add(community) + + # Create an ASPath record matching the RTBH source + aspath = models.ASPath() + aspath.prefix = "147.230.1.99/32" + aspath.as_path = "64512 64513" + db.session.add(aspath) + db.session.commit() + + rule = models.RTBH( + ipv4="147.230.1.99", + ipv4_mask=32, + ipv6="", + ipv6_mask=0, + community_id=community.id, + expires=None, + user_id=1, + org_id=1, + rstate_id=1, + ) + db.session.add(rule) + db.session.commit() + + msg = create_rtbh(rule, ANNOUNCE) + assert "as-path [ 64512 64513 ]" in msg + + def test_as_path_no_match(self, app, db): + """create_rtbh omits as-path string when no ASPath record matches.""" + community = models.Community( + name="test-aspath-comm-nomatch", + comm="65535:65283", + larcomm="", + extcomm="", + description="", + as_path=True, + role_id=2, + ) + db.session.add(community) + db.session.commit() + + rule = models.RTBH( + ipv4="10.0.0.1", + ipv4_mask=32, + ipv6="", + ipv6_mask=0, + community_id=community.id, + expires=None, + user_id=1, + org_id=1, + rstate_id=1, + ) + db.session.add(rule) + db.session.commit() + + msg = create_rtbh(rule, ANNOUNCE) + assert "as-path" not in msg + + def test_no_as_path_community(self, app, db): + """create_rtbh skips DB query entirely when community.as_path is False.""" + community = models.Community( + name="test-no-aspath-comm", + comm="65535:65283", + larcomm="", + extcomm="", + description="", + as_path=False, + role_id=2, + ) + db.session.add(community) + db.session.commit() + + rule = models.RTBH( + ipv4="147.230.2.1", + ipv4_mask=32, + ipv6="", + ipv6_mask=0, + community_id=community.id, + expires=None, + user_id=1, + org_id=1, + rstate_id=1, + ) + db.session.add(rule) + db.session.commit() + + msg = create_rtbh(rule, ANNOUNCE) + assert "as-path" not in msg diff --git a/tests/test_model_utils.py b/tests/test_model_utils.py new file mode 100644 index 0000000..9171b35 --- /dev/null +++ b/tests/test_model_utils.py @@ -0,0 +1,254 @@ +""" +Tests for flowapp.models.utils — DB-touching utility functions. +""" +from datetime import datetime, timedelta + +import pytest +import flowapp.models as models +from flowapp.constants import RuleTypes +from flowapp.models.utils import ( + check_rule_limit, + check_global_rule_limit, + get_ip_rules, + get_user_rules_ids, + get_user_actions, + get_user_communities, + get_existing_action, + get_existing_community, + get_user_nets, + get_user_orgs_choices, + insert_user, +) + + +# ── helpers ────────────────────────────────────────────────────────────────── + + +def _make_ipv4(db, source="147.230.10.1", rstate_id=1, user_id=1, org_id=1): + rule = models.Flowspec4( + source=source, + source_mask=32, + source_port="", + destination="", + destination_mask=None, + destination_port="", + protocol="tcp", + flags="", + packet_len="", + fragment="", + action_id=1, + expires=datetime.now() + timedelta(days=1), + user_id=user_id, + org_id=org_id, + rstate_id=rstate_id, + ) + db.session.add(rule) + db.session.commit() + return rule + + +def _make_rtbh(db, ipv4="147.230.10.2", rstate_id=1, user_id=1, org_id=1): + rule = models.RTBH( + ipv4=ipv4, + ipv4_mask=32, + ipv6="", + ipv6_mask=0, + community_id=1, + expires=datetime.now() + timedelta(days=1), + user_id=user_id, + org_id=org_id, + rstate_id=rstate_id, + ) + db.session.add(rule) + db.session.commit() + return rule + + +# ── check_rule_limit ───────────────────────────────────────────────────────── + + +def test_check_rule_limit_no_limit_set(app, db): + """Returns False when org has no per-org limit set (limit=0).""" + result = check_rule_limit(1, RuleTypes.IPv4) + assert result is False + + +def test_check_rule_limit_under_limit(app, db): + from sqlalchemy import select + from flowapp.models.organization import Organization + + org = db.session.execute(select(Organization).filter_by(id=1)).scalar_one() + org.limit_flowspec4 = 100 + db.session.commit() + + original_max = app.config.get("FLOWSPEC4_MAX_RULES", 9000) + app.config["FLOWSPEC4_MAX_RULES"] = 9000 + try: + result = check_rule_limit(1, RuleTypes.IPv4) + assert result is False + finally: + org.limit_flowspec4 = 0 + db.session.commit() + app.config["FLOWSPEC4_MAX_RULES"] = original_max + + +def test_check_rule_limit_at_limit(app, db): + from sqlalchemy import select + from flowapp.models.organization import Organization + + rule = _make_ipv4(db) + org = db.session.execute(select(Organization).filter_by(id=1)).scalar_one() + org.limit_flowspec4 = 1 + db.session.commit() + + result = check_rule_limit(1, RuleTypes.IPv4) + assert result is True + + org.limit_flowspec4 = 0 + db.session.commit() + + +# ── check_global_rule_limit ────────────────────────────────────────────────── + + +def test_check_global_rule_limit_not_reached(app, db): + original = app.config.get("FLOWSPEC4_MAX_RULES", 9000) + app.config["FLOWSPEC4_MAX_RULES"] = 9000 + try: + result = check_global_rule_limit(RuleTypes.IPv4) + assert result is False + finally: + app.config["FLOWSPEC4_MAX_RULES"] = original + + +def test_check_global_rule_limit_reached(app, db): + original = app.config.get("FLOWSPEC4_MAX_RULES", 9000) + app.config["FLOWSPEC4_MAX_RULES"] = 0 + try: + result = check_global_rule_limit(RuleTypes.IPv4) + assert result is True + finally: + app.config["FLOWSPEC4_MAX_RULES"] = original + + +# ── get_ip_rules ───────────────────────────────────────────────────────────── + + +def test_get_ip_rules_ipv4_returns_list(app, db): + _make_ipv4(db) + rules = get_ip_rules("ipv4", "active") + assert isinstance(rules, list) + assert all(isinstance(r, models.Flowspec4) for r in rules) + + +def test_get_ip_rules_rtbh_returns_list(app, db): + _make_rtbh(db) + rules = get_ip_rules("rtbh", "active") + assert isinstance(rules, list) + assert all(isinstance(r, models.RTBH) for r in rules) + + +def test_get_ip_rules_unknown_type_returns_empty(app, db): + result = get_ip_rules("unknown", "active") + assert result == [] + + +def test_get_ip_rules_paginate(app, db): + _make_ipv4(db) + result = get_ip_rules("ipv4", "active", paginate=True) + assert isinstance(result, tuple) + items, pagination = result + assert isinstance(items, list) + + +# ── get_user_rules_ids ─────────────────────────────────────────────────────── + + +def test_get_user_rules_ids_ipv4(app, db): + rule = _make_ipv4(db, user_id=1) + ids = get_user_rules_ids(1, "ipv4") + assert rule.id in ids + + +def test_get_user_rules_ids_rtbh(app, db): + rule = _make_rtbh(db, user_id=1) + ids = get_user_rules_ids(1, "rtbh") + assert rule.id in ids + + +# ── get_user_actions / get_user_communities ─────────────────────────────────── + + +def test_get_user_actions_admin(app, db): + choices = get_user_actions([3]) + assert len(choices) > 0 + assert all(isinstance(c, tuple) and len(c) == 2 for c in choices) + + +def test_get_user_actions_regular(app, db): + choices = get_user_actions([2]) + assert isinstance(choices, list) + + +def test_get_user_communities_admin(app, db): + choices = get_user_communities([3]) + assert len(choices) > 0 + assert all(isinstance(c, tuple) and len(c) == 2 for c in choices) + + +def test_get_user_communities_regular(app, db): + choices = get_user_communities([2]) + assert isinstance(choices, list) + + +# ── get_existing_action / get_existing_community ───────────────────────────── + + +def test_get_existing_action_found(app, db): + result = get_existing_action(name="Discard") + assert result is not None + + +def test_get_existing_action_not_found(app, db): + result = get_existing_action(name="nonexistent_action_xyz") + assert result is None + + +def test_get_existing_community_found(app, db): + result = get_existing_community(name="65535:65283") + assert result is not None + + +def test_get_existing_community_not_found(app, db): + result = get_existing_community(name="nonexistent_community_xyz") + assert result is None + + +# ── get_user_nets / get_user_orgs_choices ───────────────────────────────────── + + +def test_get_user_nets(app, db): + nets = get_user_nets(1) + assert isinstance(nets, list) + assert len(nets) > 0 + assert any("147.230" in n for n in nets) + + +def test_get_user_orgs_choices(app, db): + choices = get_user_orgs_choices(1) + assert isinstance(choices, list) + assert len(choices) > 0 + assert all(isinstance(c, tuple) and len(c) == 2 for c in choices) + + +# ── insert_user ─────────────────────────────────────────────────────────────── + + +def test_insert_user(app, db): + insert_user(uuid="new.test.user@test.cz", role_ids=[2], org_ids=[1]) + from sqlalchemy import select + from flowapp.models.user import User + + user = db.session.execute(select(User).filter_by(uuid="new.test.user@test.cz")).scalar_one_or_none() + assert user is not None + assert any(r.id == 2 for r in user.role) diff --git a/tests/test_models.py b/tests/test_models.py index b352a5f..ba732be 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -6,11 +6,11 @@ ApiKey, MachineApiKey, Rstate, - Community, - Action, Flowspec6, Whitelist, ) +from flowapp.models.community import Community +from flowapp.models.rules.base import Action import flowapp.models as models @@ -496,3 +496,188 @@ def test_whitelist_to_dict(db): assert dict_timestamp["mask"] == 32 assert dict_timestamp["comment"] == "Test whitelist" assert dict_timestamp["user"] == "test-user-whitelist" + + +class TestRuleWhitelistCache: + def _make_whitelist(self, db): + wl = Whitelist( + ip="10.0.0.0", + mask=8, + expires=datetime.now() + timedelta(days=1), + user_id=1, + org_id=1, + ) + db.session.add(wl) + db.session.commit() + return wl + + def _make_cache(self, db, wl_id, rule_id=42): + from flowapp.constants import RuleTypes, RuleOrigin + cache = models.RuleWhitelistCache( + rid=rule_id, + rtype=RuleTypes.RTBH, + whitelist_id=wl_id, + rorigin=RuleOrigin.USER, + ) + db.session.add(cache) + db.session.commit() + return cache + + def test_get_by_whitelist_id(self, db): + wl = self._make_whitelist(db) + cache = self._make_cache(db, wl.id, rule_id=100) + result = models.RuleWhitelistCache.get_by_whitelist_id(wl.id) + assert any(c.id == cache.id for c in result) + + def test_count_by_rule(self, db): + from flowapp.constants import RuleTypes + wl = self._make_whitelist(db) + self._make_cache(db, wl.id, rule_id=200) + count = models.RuleWhitelistCache.count_by_rule(200, RuleTypes.RTBH) + assert count == 1 + + def test_delete_by_rule_id(self, db): + wl = self._make_whitelist(db) + self._make_cache(db, wl.id, rule_id=300) + deleted = models.RuleWhitelistCache.delete_by_rule_id(300) + assert deleted >= 1 + from flowapp.constants import RuleTypes + assert models.RuleWhitelistCache.count_by_rule(300, RuleTypes.RTBH) == 0 + + def test_clean_by_whitelist_id(self, db): + wl = self._make_whitelist(db) + self._make_cache(db, wl.id, rule_id=400) + deleted = models.RuleWhitelistCache.clean_by_whitelist_id(wl.id) + assert deleted >= 1 + result = models.RuleWhitelistCache.get_by_whitelist_id(wl.id) + assert result == [] + + +def test_get_whitelistable_communities(db): + # Communities are seeded at DB creation — id=1 exists (65535:65283) + result = Community.get_whitelistable_communities([1]) + assert len(result) == 1 + assert result[0].id == 1 + + +def test_get_whitelistable_communities_empty_list(db): + result = Community.get_whitelistable_communities([]) + assert result == [] + + +def test_get_whitelistable_communities_nonexistent(db): + result = Community.get_whitelistable_communities([99999]) + assert result == [] + + +def test_action_get_all_ordered_returns_seeded(db): + result = Action.get_all_ordered() + assert len(result) >= 1 + names = [a.name for a in result] + assert names == sorted(names) + + +def test_action_get_all_returns_seeded(db): + result = Action.get_all() + assert len(result) >= 1 + + +def test_apikey_get_by_user_id_returns_keys(db): + key = ApiKey(machine="127.0.0.1", key="testkey-uid-1", user_id=1, org_id=1) + db.session.add(key) + db.session.commit() + result = ApiKey.get_by_user_id(1) + assert any(k.key == "testkey-uid-1" for k in result) + + +def test_apikey_get_by_user_id_empty(db): + result = ApiKey.get_by_user_id(99999) + assert result == [] + + +def test_community_get_all_returns_seeded(db): + result = Community.get_all() + assert len(result) >= 1 + + +class TestRuleWhitelistCacheGetByRuleIds: + def _make_whitelist(self, db): + wl = Whitelist( + ip="10.1.0.0", + mask=16, + expires=datetime.now() + timedelta(days=1), + user_id=1, + org_id=1, + ) + db.session.add(wl) + db.session.commit() + return wl + + def test_get_by_rule_ids_returns_matching(self, db): + from flowapp.constants import RuleTypes, RuleOrigin + wl = self._make_whitelist(db) + cache = models.RuleWhitelistCache( + rid=501, rtype=RuleTypes.IPv4, whitelist_id=wl.id, rorigin=RuleOrigin.WHITELIST + ) + db.session.add(cache) + db.session.commit() + result = models.RuleWhitelistCache.get_by_rule_ids([501], RuleTypes.IPv4) + assert any(c.rid == 501 for c in result) + + def test_get_by_rule_ids_excludes_other_type(self, db): + from flowapp.constants import RuleTypes, RuleOrigin + wl = self._make_whitelist(db) + cache = models.RuleWhitelistCache( + rid=502, rtype=RuleTypes.RTBH, whitelist_id=wl.id, rorigin=RuleOrigin.WHITELIST + ) + db.session.add(cache) + db.session.commit() + result = models.RuleWhitelistCache.get_by_rule_ids([502], RuleTypes.IPv4) + assert not any(c.rid == 502 for c in result) + + def test_get_by_rule_ids_empty_list(self, db): + from flowapp.constants import RuleTypes + result = models.RuleWhitelistCache.get_by_rule_ids([], RuleTypes.IPv4) + assert result == [] + + +class TestUserUpdate: + def _make_form(self, uuid, role_ids, org_ids): + class Field: + def __init__(self, value): + self.data = value + + class Form: + pass + + f = Form() + f.uuid = Field(uuid) + f.name = Field("Test User") + f.email = Field("test@example.com") + f.phone = Field("123") + f.comment = Field("comment") + f.role_ids = Field(role_ids) + f.org_ids = Field(org_ids) + return f + + def test_update_changes_role(self, db): + user = models.User(uuid="update.test.user@test.cz") + db.session.add(user) + db.session.commit() + + form = self._make_form("update.test.user@test.cz", role_ids=[2], org_ids=[1]) + user.update(form) + + role_ids = [r.id for r in user.role] + assert 2 in role_ids + + def test_update_changes_org(self, db): + user = models.User(uuid="update.test.org.user@test.cz") + db.session.add(user) + db.session.commit() + + form = self._make_form("update.test.org.user@test.cz", role_ids=[2], org_ids=[1]) + user.update(form) + + org_ids = [o.id for o in user.organization] + assert 1 in org_ids diff --git a/tests/test_services_base.py b/tests/test_services_base.py new file mode 100644 index 0000000..cf16f38 --- /dev/null +++ b/tests/test_services_base.py @@ -0,0 +1,102 @@ +""" +Tests for flowapp.services.base and flowapp.services.rule_service (delete_expired_rules) +""" +from datetime import datetime, timedelta +from unittest.mock import patch + +import flowapp.models as models +from flowapp.constants import ANNOUNCE, WITHDRAW +from flowapp.services.base import announce_all_routes +from flowapp.services.rule_service import delete_expired_rules + + +def _make_ipv4(db, source="147.230.20.1", expires_delta=1, rstate_id=1): + rule = models.Flowspec4( + source=source, + source_mask=32, + source_port="", + destination="", + destination_mask=None, + destination_port="", + protocol="tcp", + flags="", + packet_len="", + fragment="", + action_id=1, + expires=datetime.now() + timedelta(days=expires_delta), + user_id=1, + org_id=1, + rstate_id=rstate_id, + ) + db.session.add(rule) + db.session.commit() + return rule + + +def _make_rtbh(db, ipv4="147.230.20.2", expires_delta=1, rstate_id=1): + rule = models.RTBH( + ipv4=ipv4, + ipv4_mask=32, + ipv6="", + ipv6_mask=0, + community_id=1, + expires=datetime.now() + timedelta(days=expires_delta), + user_id=1, + org_id=1, + rstate_id=rstate_id, + ) + db.session.add(rule) + db.session.commit() + return rule + + +@patch("flowapp.services.base.announce_route") +def test_announce_all_routes_calls_announce(mock_announce, app, db): + _make_ipv4(db) + _make_rtbh(db) + announce_all_routes(ANNOUNCE) + assert mock_announce.called + + +@patch("flowapp.services.base.announce_route") +@patch("flowapp.services.base.messages") +def test_announce_all_routes_withdraw_sets_rstate(mock_messages, mock_announce, app, db): + mock_messages.create_ipv4.return_value = "mock" + mock_messages.create_ipv6.return_value = "mock" + mock_messages.create_rtbh.return_value = "mock" + rule = _make_ipv4(db, expires_delta=-1) # expired rule + announce_all_routes(WITHDRAW) + db.session.refresh(rule) + assert rule.rstate_id == 2 + + +@patch("flowapp.services.base.announce_route") +def test_announce_all_routes_skips_inactive(mock_announce, app, db): + # Announce with only a withdrawn rule — confirm the rule's source never appears in commands + rule = _make_ipv4(db, source="147.230.20.99", rstate_id=2) + announce_all_routes(ANNOUNCE) + commands = [call.args[0].command for call in mock_announce.call_args_list] + assert not any("147.230.20.99" in cmd for cmd in commands) + + +def test_delete_expired_rules_removes_old_withdrawn(app, db): + # Create a rule that is withdrawn and old enough to be deleted + rule = _make_ipv4(db, source="147.230.20.50", expires_delta=-40, rstate_id=2) + rule_id = rule.id + + counts = delete_expired_rules() + + assert counts["ipv4"] >= 1 + assert counts["total"] >= 1 + # Rule should no longer exist + assert db.session.get(models.Flowspec4, rule_id) is None + + +def test_delete_expired_rules_keeps_active(app, db): + # Active rules (rstate_id=1) should never be deleted regardless of age + rule = _make_ipv4(db, source="147.230.20.51", expires_delta=-40, rstate_id=1) + rule_id = rule.id + + delete_expired_rules() + + assert db.session.get(models.Flowspec4, rule_id) is not None diff --git a/tests/test_whitelist_service.py b/tests/test_whitelist_service.py index 460a172..a11cf10 100644 --- a/tests/test_whitelist_service.py +++ b/tests/test_whitelist_service.py @@ -459,3 +459,41 @@ def test_supernet_relation(self, app): # Verify the correct model was returned assert result == whitelist_model + + +class TestDeleteExpiredWhitelists: + def test_deletes_expired_whitelist(self, app, db): + from flowapp.services.whitelist_service import delete_expired_whitelists + + expired = Whitelist( + ip="172.16.0.0", + mask=12, + expires=datetime.now() - timedelta(days=1), + user_id=1, + org_id=1, + ) + db.session.add(expired) + db.session.commit() + wl_id = expired.id + + delete_expired_whitelists() + + assert db.session.get(Whitelist, wl_id) is None + + def test_keeps_active_whitelist(self, app, db): + from flowapp.services.whitelist_service import delete_expired_whitelists + + active = Whitelist( + ip="172.17.0.0", + mask=16, + expires=datetime.now() + timedelta(days=1), + user_id=1, + org_id=1, + ) + db.session.add(active) + db.session.commit() + wl_id = active.id + + delete_expired_whitelists() + + assert db.session.get(Whitelist, wl_id) is not None diff --git a/tests/test_zzz_api_rtbh_expired_bug.py b/tests/test_zzz_api_rtbh_expired_bug.py index 6fd36a2..8a6be84 100644 --- a/tests/test_zzz_api_rtbh_expired_bug.py +++ b/tests/test_zzz_api_rtbh_expired_bug.py @@ -1,6 +1,7 @@ import json from datetime import datetime, timedelta +from sqlalchemy import delete, func, select from flowapp.models import RTBH from flowapp.models.rules.whitelist import Whitelist @@ -42,7 +43,7 @@ def test_create_rtbh_after_expired_rule_exists(client, app, db, jwt_token): # Verify the first rule is in withdrawn state with app.app_context(): - expired_rule = db.session.query(RTBH).filter_by(id=rule_id_1).first() + expired_rule = db.session.execute(select(RTBH).filter_by(id=rule_id_1)).scalar_one() assert expired_rule is not None assert expired_rule.rstate_id == 2, "Expired rule should be in withdrawn state (rstate_id=2)" assert expired_rule.ipv4 == "192.168.100.50" @@ -74,9 +75,9 @@ def test_create_rtbh_after_expired_rule_exists(client, app, db, jwt_token): # OR if a new rule is created, it has the wrong state # Check if it's the same rule (updated) or a new rule - total_rules = db.session.query(RTBH).filter_by(ipv4="192.168.100.50", ipv4_mask=32).count() + total_rules = db.session.scalar(select(func.count()).select_from(RTBH).filter_by(ipv4="192.168.100.50", ipv4_mask=32)) - new_rule = db.session.query(RTBH).filter_by(id=rule_id_2).first() + new_rule = db.session.execute(select(RTBH).filter_by(id=rule_id_2)).scalar_one() assert new_rule is not None print("\n--- Bug Verification ---") @@ -153,7 +154,7 @@ def test_create_rtbh_after_expired_rule_different_mask(client, app, db, jwt_toke # Verify the new rule is active (this should work because IP+mask don't match) with app.app_context(): - new_rule = db.session.query(RTBH).filter_by(id=data2["rule"]["id"]).first() + new_rule = db.session.execute(select(RTBH).filter_by(id=data2["rule"]["id"])).scalar_one() assert new_rule is not None assert new_rule.rstate_id == 1, "New rule with different mask should be active" print("✓ Different mask creates new active rule correctly") @@ -188,7 +189,7 @@ def test_create_rtbh_after_active_rule_exists(client, app, db, jwt_token): # Verify the first rule is active with app.app_context(): - first_rule = db.session.query(RTBH).filter_by(id=rule_id_1).first() + first_rule = db.session.execute(select(RTBH).filter_by(id=rule_id_1)).scalar_one() assert first_rule.rstate_id == 1, "First rule should be active" # Step 2: Update the same rule with a new expiration @@ -211,7 +212,7 @@ def test_create_rtbh_after_active_rule_exists(client, app, db, jwt_token): # Verify it maintains active state with app.app_context(): - updated_rule = db.session.query(RTBH).filter_by(id=data2["rule"]["id"]).first() + updated_rule = db.session.execute(select(RTBH).filter_by(id=data2["rule"]["id"])).scalar_one() assert updated_rule is not None assert updated_rule.rstate_id == 1, "Updated rule should remain active" print("✓ Updating active rule maintains active state correctly") @@ -224,8 +225,8 @@ def cleanup_before_stack(app, db): Cleanup function to remove all RTBH rules created during tests. """ with app.app_context(): - db.session.query(RTBH).delete() - db.session.query(Whitelist).delete() + db.session.execute(delete(RTBH)) + db.session.execute(delete(Whitelist)) db.session.commit()