diff --git a/src/vr/__init__.py b/src/vr/__init__.py index a35d9183..42a86ea1 100644 --- a/src/vr/__init__.py +++ b/src/vr/__init__.py @@ -182,9 +182,8 @@ def base64encode(value): ## Release-based updates ## -cwd = os.getcwd() -createNewTables(DB_URI) -print() +createNewTables(app) + ## Cronjob-like tasks section ## def train_model_every_six_hours(): scheduler = BackgroundScheduler() diff --git a/src/vr/db_models/updates.py b/src/vr/db_models/updates.py index d6028a0f..93106286 100644 --- a/src/vr/db_models/updates.py +++ b/src/vr/db_models/updates.py @@ -1,56 +1,92 @@ -from flask_sqlalchemy import SQLAlchemy -from flask import Flask +import mysql.connector +import sqlite3 +import os -def createNewTables(db_uri): - mock_app = Flask(__name__) - # Example database URI, replace it with your actual database URI - mock_app.config['SQLALCHEMY_DATABASE_URI'] = db_uri - mock_app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False - db = SQLAlchemy(mock_app) +def get_client(app): + if app.config['RUNTIME_ENV'] == 'test': + cur_path = os.getcwd() + if 'www' in cur_path and 'html' in cur_path: + db_uri = '/var/www/html/src/instance/database.db' + else: + db_uri = 'instance/database.db' + db = sqlite3.connect(db_uri) + cur = db.cursor() + return cur, db + else: + db_uri = app.config['SQLALCHEMY_DATABASE_URI'] + main_part = db_uri.split('://')[1] + un = main_part.split(':', 1)[0] + db_name = main_part.rsplit('/', 1)[1] + host_and_port = main_part.rsplit('@', 1)[1].replace(f"/{db_name}", '') + host = host_and_port.split(':')[0] + port = int(host_and_port.split(':')[1]) + pw = main_part.split(':', 1)[1].replace(f"@{host}", '').replace(f"/{db_name}", '').replace(f":{port}", "") + db = mysql.connector.connect(host=host, database=db_name, user=un, password=pw, port=port) + cur = db.cursor() + return cur, db - class AppConfig(db.Model): - __tablename__ = 'AppConfig' - __table_args__ = {'extend_existing': True} - id = db.Column(db.Integer, primary_key=True) - first_access = db.Column(db.Boolean, nullable=False, default=True) - settings_initialized = db.Column(db.Boolean, nullable=False, default=False) - APP_EXT_URL = db.Column(db.String(200)) - AUTH_TYPE = db.Column(db.String(200)) - AZAD_AUTHORITY = db.Column(db.String(200)) - AZAD_CLIENT_ID = db.Column(db.String(200)) - AZAD_CLIENT_SECRET = db.Column(db.String(200)) - AZURE_KEYVAULT_NAME = db.Column(db.String(200)) - ENV = db.Column(db.String(200)) - INSECURE_OAUTH = db.Column(db.String(200)) - JENKINS_HOST = db.Column(db.String(200)) - JENKINS_KEY = db.Column(db.String(200)) - JENKINS_PROJECT = db.Column(db.String(200)) - JENKINS_STAGING_PROJECT = db.Column(db.String(200)) - JENKINS_TOKEN = db.Column(db.String(200)) - JENKINS_USER = db.Column(db.String(200)) - LDAP_BASE_DN = db.Column(db.String(200)) - LDAP_BIND_USER_DN = db.Column(db.String(200)) - LDAP_BIND_USER_PASSWORD = db.Column(db.String(200)) - LDAP_GROUP_DN = db.Column(db.String(200)) - LDAP_HOST = db.Column(db.String(200)) - LDAP_PORT = db.Column(db.String(200)) - LDAP_USER_DN = db.Column(db.String(200)) - LDAP_USER_LOGIN_ATTR = db.Column(db.String(200)) - LDAP_USER_RDN_ATTR = db.Column(db.String(200)) - PROD_DB_URI = db.Column(db.String(200)) - SMTP_ADMIN_EMAIL = db.Column(db.String(200)) - SMTP_HOST = db.Column(db.String(200)) - SMTP_PASSWORD = db.Column(db.String(200)) - SMTP_USER = db.Column(db.String(200)) - SNOW_CLIENT_ID = db.Column(db.String(200)) - SNOW_CLIENT_SECRET = db.Column(db.String(200)) - SNOW_INSTANCE_NAME = db.Column(db.String(200)) - SNOW_PASSWORD = db.Column(db.String(200)) - SNOW_USERNAME = db.Column(db.String(200)) - VERSION = db.Column(db.String(200)) - JENKINS_ENABLED = db.Column(db.String(200)) - SNOW_ENABLED = db.Column(db.String(200)) - with mock_app.app_context(): - db.create_all() +def createNewTables(app): + cur, db = get_client(app) + if app.config['RUNTIME_ENV'] == 'test': + sql = "PRAGMA table_info('AppConfig')" + else: + sql = "SELECT column_name FROM information_schema.columns WHERE table_schema = 'vulnremediator' AND table_name = 'AppConfig'" + cur.execute(sql) + rows = cur.fetchall() + fields = [] + for i in rows: + fields.append(i[1]) + new_fields = [ + {"name": "APP_EXT_URL", "type": "VARCHAR", "char_num": 200}, + {"name": "AUTH_TYPE", "type": "VARCHAR", "char_num": 200}, + {"name": "AZAD_AUTHORITY", "type": "VARCHAR", "char_num": 200}, + {"name": "AZAD_CLIENT_ID", "type": "VARCHAR", "char_num": 200}, + {"name": "AZAD_CLIENT_SECRET", "type": "VARCHAR", "char_num": 200}, + {"name": "AZURE_KEYVAULT_NAME", "type": "VARCHAR", "char_num": 200}, + {"name": "ENV", "type": "VARCHAR", "char_num": 200}, + {"name": "INSECURE_OAUTH", "type": "VARCHAR", "char_num": 200}, + {"name": "JENKINS_HOST", "type": "VARCHAR", "char_num": 200}, + {"name": "JENKINS_KEY", "type": "VARCHAR", "char_num": 200}, + {"name": "JENKINS_PROJECT", "type": "VARCHAR", "char_num": 200}, + {"name": "JENKINS_STAGING_PROJECT", "type": "VARCHAR", "char_num": 200}, + {"name": "JENKINS_TOKEN", "type": "VARCHAR", "char_num": 200}, + {"name": "JENKINS_USER", "type": "VARCHAR", "char_num": 200}, + {"name": "LDAP_BASE_DN", "type": "VARCHAR", "char_num": 200}, + {"name": "LDAP_BIND_USER_DN", "type": "VARCHAR", "char_num": 200}, + {"name": "LDAP_BIND_USER_PASSWORD", "type": "VARCHAR", "char_num": 200}, + {"name": "LDAP_GROUP_DN", "type": "VARCHAR", "char_num": 200}, + {"name": "LDAP_HOST", "type": "VARCHAR", "char_num": 200}, + {"name": "LDAP_PORT", "type": "VARCHAR", "char_num": 200}, + {"name": "LDAP_USER_DN", "type": "VARCHAR", "char_num": 200}, + {"name": "LDAP_USER_LOGIN_ATTR", "type": "VARCHAR", "char_num": 200}, + {"name": "LDAP_USER_RDN_ATTR", "type": "VARCHAR", "char_num": 200}, + {"name": "PROD_DB_URI", "type": "VARCHAR", "char_num": 200}, + {"name": "SMTP_ADMIN_EMAIL", "type": "VARCHAR", "char_num": 200}, + {"name": "SMTP_HOST", "type": "VARCHAR", "char_num": 200}, + {"name": "SMTP_PASSWORD", "type": "VARCHAR", "char_num": 200}, + {"name": "SMTP_USER", "type": "VARCHAR", "char_num": 200}, + {"name": "SNOW_CLIENT_ID", "type": "VARCHAR", "char_num": 200}, + {"name": "SNOW_CLIENT_SECRET", "type": "VARCHAR", "char_num": 200}, + {"name": "SNOW_INSTANCE_NAME", "type": "VARCHAR", "char_num": 200}, + {"name": "SNOW_PASSWORD", "type": "VARCHAR", "char_num": 200}, + {"name": "SNOW_USERNAME", "type": "VARCHAR", "char_num": 200}, + {"name": "VERSION", "type": "VARCHAR", "char_num": 200}, + {"name": "JENKINS_ENABLED", "type": "VARCHAR", "char_num": 200}, + {"name": "SNOW_ENABLED", "type": "VARCHAR", "char_num": 200} + ] + + for i in new_fields: + if i['name'] not in fields: + if app.config['RUNTIME_ENV'] == 'test': + if i['type'] == 'VARCHAR': + var_stmt = f"VARCHAR({i['char_num']})" + sql = "ALTER TABLE AppConfig ADD COLUMN" + i['name'] + var_stmt + else: + if i['type'] == 'VARCHAR': + var_stmt = "TEXT" + sql = "ALTER TABLE AppConfig ADD COLUMN" + i['name'] + var_stmt + cur.execute(sql) + db.commit() +