-
Notifications
You must be signed in to change notification settings - Fork 219
/
cli.py
366 lines (308 loc) · 19 KB
/
cli.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
import os
import string
import re
import argparse
import json
import time
import hashlib
import snowflake.connector
import warnings
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.asymmetric import dsa
from cryptography.hazmat.primitives import serialization
# Set a few global variables here
# Version string printed at start-up and embedded in the Snowflake QUERY_TAG session parameter
_snowchange_version = '2.8.0'
# Default database, schema and table name of the change history table.
# get_change_history_table_details() starts from these values and applies any override on top.
_metadata_database_name = 'METADATA'
_metadata_schema_name = 'SNOWCHANGE'
_metadata_table_name = 'CHANGE_HISTORY'
# Define the Jinja expression template class
# snowchange uses Jinja style variable references of the form "{{ variablename }}"
# See https://jinja.palletsprojects.com/en/2.11.x/templates/
# Variable names follow Python variable naming conventions
class JinjaExpressionTemplate(string.Template):
    """string.Template subclass whose placeholders look like "{{ name }}".

    Only plain identifiers are recognised (a letter or underscore followed by
    letters, digits or underscores), with exactly one space after "{{" and one
    before "}}". This is not a Jinja engine: filters, expressions and
    whitespace variations are not supported.
    """
    # Text emitted by string.Template for the escape sequence ("{{ {{")
    delimiter = '{{ '
    # string.Template compiles this with re.VERBOSE, so the line breaks inside
    # the pattern are ignored and a literal space has to be spelled "[ ]".
    # The four group names (escaped, named, braced, invalid) are required by
    # string.Template; "braced" repeats "named" because this syntax has a
    # single placeholder form. The empty "invalid" group matches whenever
    # "{{ " is not followed by a valid reference.
    pattern = r'''
\{\{[ ](?:
(?P<escaped>\{\{)|
(?P<named>[_A-Za-z][_A-Za-z0-9]*)[ ]\}\}|
(?P<braced>[_A-Za-z][_A-Za-z0-9]*)[ ]\}\}|
(?P<invalid>)
)
'''
def snowchange(root_folder, snowflake_account, snowflake_user, snowflake_role, snowflake_warehouse, snowflake_database, change_history_table_override, vars, create_change_history_table, autocommit, verbose, dry_run):
    """Apply the pending change scripts found under root_folder.

    Versioned scripts (V...) newer than the version returned by
    fetch_change_history() are applied in version order, followed by every
    repeatable script (R...). In dry-run mode the same messages are printed
    but nothing is created or applied.

    Raises:
        ValueError: when no credentials are present in the environment, when
            root_folder is not a directory, or when the change history table
            is missing and create_change_history_table is false.
    """
    if dry_run:
        print("Running in dry-run mode")

    # Password authentication will take priority
    environment = os.environ
    # SNOWSQL_PWD is still accepted for now, but it is deprecated
    password_defined = "SNOWFLAKE_PASSWORD" in environment or "SNOWSQL_PWD" in environment
    private_key_defined = "SNOWFLAKE_PRIVATE_KEY_PATH" in environment and "SNOWFLAKE_PRIVATE_KEY_PASSPHRASE" in environment
    if not (password_defined or private_key_defined):
        raise ValueError("Missing environment variable(s). SNOWFLAKE_PASSWORD must be defined for password authentication. SNOWFLAKE_PRIVATE_KEY_PATH and SNOWFLAKE_PRIVATE_KEY_PASSPHRASE must be defined for private key authentication.")

    root_folder = os.path.abspath(root_folder)
    if not os.path.isdir(root_folder):
        raise ValueError("Invalid root folder: %s" % root_folder)

    # Start-up banner describing the effective configuration
    banner = [
        ("snowchange version: %s", _snowchange_version),
        ("Using root folder %s", root_folder),
        ("Using variables %s", vars),
        ("Using Snowflake account %s", snowflake_account),
        ("Using default role %s", snowflake_role),
        ("Using default warehouse %s", snowflake_warehouse),
        ("Using default database %s", snowflake_database),
    ]
    for message, value in banner:
        print(message % value)

    # Set default Snowflake session parameters
    snowflake_session_parameters = {
        "QUERY_TAG": "snowchange %s" % _snowchange_version
    }

    # TODO: Is there a better way to do this without setting environment variables?
    # execute_snowflake_query() reads its connection settings back from these.
    environment["SNOWFLAKE_ACCOUNT"] = snowflake_account
    environment["SNOWFLAKE_USER"] = snowflake_user
    environment["SNOWFLAKE_ROLE"] = snowflake_role
    environment["SNOWFLAKE_WAREHOUSE"] = snowflake_warehouse
    environment["SNOWFLAKE_AUTHENTICATOR"] = 'snowflake'

    scripts_skipped = 0
    scripts_applied = 0

    # Deal with the change history table (create if specified)
    change_history_table = get_change_history_table_details(change_history_table_override)
    table_name_parts = (change_history_table['database_name'], change_history_table['schema_name'], change_history_table['table_name'])
    change_history_metadata = fetch_change_history_metadata(change_history_table, snowflake_session_parameters, autocommit, verbose)
    if change_history_metadata:
        print("Using change history table %s.%s.%s (last altered %s)" % (table_name_parts + (change_history_metadata['last_altered'],)))
    elif create_change_history_table:
        # Create the change history table (and containing objects) if it doesn't exist.
        if not dry_run:
            create_change_history_table_if_missing(change_history_table, snowflake_session_parameters, autocommit, verbose)
        print("Created change history table %s.%s.%s" % table_name_parts)
    else:
        raise ValueError("Unable to find change history table %s.%s.%s" % table_name_parts)

    # Find the max published version. A dry run can only query the history
    # table when it already exists, because it was not created above.
    max_published_version = ''
    change_history = None
    if not dry_run or change_history_metadata:
        change_history = fetch_change_history(change_history_table, snowflake_session_parameters, autocommit, verbose)
    if change_history:
        max_published_version = change_history[0]
    print("Max applied change script version: %s" % ('None' if max_published_version == '' else max_published_version))

    # Find all scripts in the root folder (recursively) and sort them correctly
    all_scripts = get_all_scripts_recursively(root_folder, verbose)
    # Sort scripts such that versioned scripts get applied first and then the repeatable ones.
    versioned_names = sorted_alphanumeric([name for name in all_scripts if name[0] == 'V'])
    repeatable_names = sorted_alphanumeric([name for name in all_scripts if name[0] == 'R'])

    # Loop through each script in order and apply any required changes
    for script_name in versioned_names + repeatable_names:
        script = all_scripts[script_name]

        # Apply a versioned-change script only if the version is newer than the most recent change in the database
        # Apply any other scripts, i.e. repeatable scripts, irrespective of the most recent change in the database
        already_applied = script_name[0] == 'V' and get_alphanum_key(script['script_version']) <= get_alphanum_key(max_published_version)
        if already_applied:
            if verbose:
                print("Skipping change script %s because it's older than the most recently applied change (%s)" % (script['script_name'], max_published_version))
            scripts_skipped += 1
            continue

        print("Applying change script %s" % script['script_name'])
        if not dry_run:
            apply_change_script(script, vars, snowflake_database, change_history_table, snowflake_session_parameters, autocommit, verbose)
        scripts_applied += 1

    print("Successfully applied %d change scripts (skipping %d)" % (scripts_applied, scripts_skipped))
    print("Completed successfully")
# This function returns a list containing the parts of the key, split around each run of digits
# Each digit run is converted to an integer and the text parts are lower-cased strings
# This enables correct ("natural") sorting in python when the lists are compared
# e.g. get_alphanum_key('1.2.2') results in ['', 1, '.', 2, '.', 2, '']
def get_alphanum_key(key):
    """Return a natural-sort key for *key*: alternating text and integer parts."""
    # The capturing group makes re.split keep the digit runs in the result
    pieces = re.split('([0-9]+)', key)
    return [int(piece) if piece.isdigit() else piece.lower() for piece in pieces]
def sorted_alphanumeric(data):
    """Return a new list with the items of *data* in natural order.

    Ordering uses get_alphanum_key, so digit runs compare numerically
    ('1.9' sorts before '1.10') and text compares case-insensitively.
    """
    ordered = list(data)
    ordered.sort(key=get_alphanum_key)
    return ordered
def get_all_scripts_recursively(root_directory, verbose):
    """Collect every change script below root_directory.

    A file is a versioned script when its name looks like
    "V<version>__<description>.sql" and a repeatable script when it looks like
    "R__<description>.sql" (the extension may be "sql" or "SQL"). All other
    files are ignored.

    Args:
        root_directory: folder that is walked recursively.
        verbose: when true, print how each file was classified.

    Returns:
        A dict keyed by file name. Each value is a dict with the keys
        script_name, script_full_path, script_type ('V' or 'R'),
        script_version (None for repeatable scripts) and script_description.

    Raises:
        ValueError: when two scripts share the same file name, or when two
            versioned scripts share the same version.
    """
    # Compile the naming patterns once instead of re-evaluating them per file
    versioned_pattern = re.compile(r'^([V])(.+)__(.+)\.(?:sql|SQL)$')
    repeatable_pattern = re.compile(r'^([R])__(.+)\.(?:sql|SQL)$')

    all_files = dict()
    all_versions = set()

    # Walk the entire directory structure recursively
    for (directory_path, directory_names, file_names) in os.walk(root_directory):
        for file_name in file_names:
            file_full_path = os.path.join(directory_path, file_name)
            candidate_name = file_name.strip()

            # Set script type depending on which file naming format is matched.
            # A name starts with either 'V' or 'R', so it can match at most one.
            versioned_match = versioned_pattern.search(candidate_name)
            if versioned_match is not None:
                script_type = 'V'
                script_version = versioned_match.group(2)
                raw_description = versioned_match.group(3)
                if verbose:
                    print("Versioned file " + file_full_path)
            else:
                repeatable_match = repeatable_pattern.search(candidate_name)
                if repeatable_match is None:
                    if verbose:
                        print("Ignoring non-change file " + file_full_path)
                    continue
                script_type = 'R'
                script_version = None
                raw_description = repeatable_match.group(2)
                if verbose:
                    print("Repeatable file " + file_full_path)

            # The result is keyed by file name, so a second file with the same
            # name in another folder would silently replace the first one and
            # which of the two gets applied would depend on the walk order.
            if file_name in all_files:
                raise ValueError("The script name %s exists more than once (first instance %s, second instance %s)" % (file_name, all_files[file_name]['script_full_path'], file_full_path))

            # Throw an error if the same version exists more than once
            if script_type == 'V':
                if script_version in all_versions:
                    raise ValueError("The script version %s exists more than once (second instance %s)" % (script_version, file_full_path))
                all_versions.add(script_version)

            # Add this script to our dictionary (as nested dictionary)
            all_files[file_name] = {
                'script_name': file_name,
                'script_full_path': file_full_path,
                'script_type': script_type,
                'script_version': script_version,
                'script_description': raw_description.replace('_', ' ').capitalize(),
            }

    return all_files
def execute_snowflake_query(snowflake_database, query, snowflake_session_parameters, autocommit, verbose):
    """Open a Snowflake connection, run *query* and close the connection.

    Connection settings are read from the SNOWFLAKE_USER, SNOWFLAKE_ACCOUNT,
    SNOWFLAKE_ROLE, SNOWFLAKE_WAREHOUSE and SNOWFLAKE_AUTHENTICATOR environment
    variables. A non-empty SNOWFLAKE_PASSWORD (or the deprecated SNOWSQL_PWD)
    selects password authentication; otherwise SNOWFLAKE_PRIVATE_KEY_PATH and
    SNOWFLAKE_PRIVATE_KEY_PASSPHRASE select private key authentication.

    Args:
        snowflake_database: database passed to the connection.
        query: SQL text handed to the connection's execute_string().
        snowflake_session_parameters: session parameters for the connection.
        autocommit: when false, autocommit is switched off and the statements
            are committed on success or rolled back on failure.
        verbose: when true, print the authentication method and the SQL text.

    Returns:
        Whatever execute_string() returns for *query*.

    Raises:
        ValueError: when neither a password nor a private key is configured.
    """
    # Password authentication is the default
    snowflake_password = os.getenv("SNOWFLAKE_PASSWORD")
    if not snowflake_password:
        # Check legacy/deprecated env variable
        snowflake_password = os.getenv("SNOWSQL_PWD")
        if snowflake_password:
            warnings.warn("The SNOWSQL_PWD environment variable is deprecated and will be removed in a later version of snowchange. Please use SNOWFLAKE_PASSWORD instead.", DeprecationWarning)

    if snowflake_password:
        if verbose:
            print("Proceeding with password authentication")
        credentials = {"password": snowflake_password}
    # If no password, try private key authentication
    elif os.getenv("SNOWFLAKE_PRIVATE_KEY_PATH") and os.getenv("SNOWFLAKE_PRIVATE_KEY_PASSPHRASE"):
        if verbose:
            print("Proceeding with private key authentication")
        with open(os.environ["SNOWFLAKE_PRIVATE_KEY_PATH"], "rb") as key:
            p_key = serialization.load_pem_private_key(
                key.read(),
                password=os.environ['SNOWFLAKE_PRIVATE_KEY_PASSPHRASE'].encode(),
                backend=default_backend()
            )
        # The key is handed to the connector as unencrypted DER-encoded PKCS8 bytes
        pkb = p_key.private_bytes(
            encoding=serialization.Encoding.DER,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption())
        credentials = {"private_key": pkb}
    else:
        raise ValueError("Unable to find connection credentials for private key or password authentication")

    # Both authentication methods share every other connection argument
    con = snowflake.connector.connect(
        user=os.environ["SNOWFLAKE_USER"],
        account=os.environ["SNOWFLAKE_ACCOUNT"],
        role=os.environ["SNOWFLAKE_ROLE"],
        warehouse=os.environ["SNOWFLAKE_WAREHOUSE"],
        database=snowflake_database,
        authenticator=os.environ["SNOWFLAKE_AUTHENTICATOR"],
        session_parameters=snowflake_session_parameters,
        **credentials
    )

    # Everything after a successful connect sits inside try/finally so the
    # connection is closed even when switching autocommit off fails.
    try:
        if not autocommit:
            con.autocommit(False)

        if verbose:
            print("SQL query: %s" % query)

        try:
            res = con.execute_string(query)
            if not autocommit:
                con.commit()
            # NOTE(review): the result is returned after the connection has
            # been closed and callers iterate it afterwards - confirm the
            # connector keeps the rows available at that point.
            return res
        except Exception:
            if not autocommit:
                con.rollback()
            raise
    finally:
        con.close()
def get_change_history_table_details(change_history_table_override):
    """Resolve the database, schema and table name of the change history table.

    Starts from the module defaults and, when an override is given, replaces
    the rightmost parts: "TABLE", "SCHEMA.TABLE" or "DATABASE.SCHEMA.TABLE".
    Every name is upper-cased.

    Returns:
        A dict with the keys database_name, schema_name and table_name.

    Raises:
        ValueError: when the override has more than three dot-separated parts.
    """
    # Start with the global defaults
    details = {
        'database_name': _metadata_database_name.upper(),
        'schema_name': _metadata_schema_name.upper(),
        'table_name': _metadata_table_name.upper(),
    }
    if change_history_table_override is None:
        return details

    # Then override the defaults. The name could be in one, two or three part notation.
    name_parts = change_history_table_override.strip().split('.')
    if len(name_parts) > 3:
        raise ValueError("Invalid change history table name: %s" % change_history_table_override)

    # Pair the parts from the right: the last part is always the table name,
    # the one before it the schema and the one before that the database.
    for detail_key, name_part in zip(('table_name', 'schema_name', 'database_name'), reversed(name_parts)):
        details[detail_key] = name_part.upper()

    return details
def fetch_change_history_metadata(change_history_table, snowflake_session_parameters, autocommit, verbose):
    """Look up the change history table in INFORMATION_SCHEMA.TABLES.

    Returns:
        A dict with the keys created and last_altered, or an empty dict when
        the query returns no rows. If several rows come back, the values of
        the last row are kept.
    """
    database_name = change_history_table['database_name']

    # This should only ever return 0 or 1 rows
    # NOTE(review): ILIKE treats "_" and "%" in the names as wildcards, so a
    # name such as CHANGE_HISTORY could match more than one table - confirm.
    lookup_query = "SELECT CREATED, LAST_ALTERED FROM {0}.INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA ILIKE '{1}' AND TABLE_NAME ILIKE '{2}'".format(database_name, change_history_table['schema_name'], change_history_table['table_name'])
    cursors = execute_snowflake_query(database_name, lookup_query, snowflake_session_parameters, autocommit, verbose)

    # Fold the returned rows into a single dictionary
    metadata = {}
    for cursor in cursors:
        for row in cursor:
            metadata.update(created=row[0], last_altered=row[1])

    return metadata
def create_change_history_table_if_missing(change_history_table, snowflake_session_parameters, autocommit, verbose):
    """Create the change history schema and table unless they already exist.

    Both statements run against the database named in change_history_table;
    the database itself is not created here.
    """
    database_name = change_history_table['database_name']

    def run(statement):
        execute_snowflake_query(database_name, statement, snowflake_session_parameters, autocommit, verbose)

    # Create the schema if it doesn't exist
    run("CREATE SCHEMA IF NOT EXISTS {0}".format(change_history_table['schema_name']))

    # Then create the change history table if it doesn't exist
    run("CREATE TABLE IF NOT EXISTS {0}.{1} (VERSION VARCHAR, DESCRIPTION VARCHAR, SCRIPT VARCHAR, SCRIPT_TYPE VARCHAR, CHECKSUM VARCHAR, EXECUTION_TIME NUMBER, STATUS VARCHAR, INSTALLED_BY VARCHAR, INSTALLED_ON TIMESTAMP_LTZ)".format(change_history_table['schema_name'], change_history_table['table_name']))
def fetch_change_history(change_history_table, snowflake_session_parameters, autocommit, verbose):
    """Return the version of the most recently installed versioned script.

    The query orders the 'V' rows by INSTALLED_ON descending and keeps one,
    so the result is a list holding at most one version string (empty when
    no versioned script has been recorded yet).
    """
    history_query = "SELECT VERSION FROM {0}.{1} WHERE SCRIPT_TYPE = 'V' ORDER BY INSTALLED_ON DESC LIMIT 1".format(change_history_table['schema_name'], change_history_table['table_name'])
    cursors = execute_snowflake_query(change_history_table['database_name'], history_query, snowflake_session_parameters, autocommit, verbose)

    # Flatten the first column of every returned row into a list
    return [row[0] for cursor in cursors for row in cursor]
def apply_change_script(script, vars, default_database, change_history_table, snowflake_session_parameters, autocommit, verbose):
    """Run one change script and record it in the change history table.

    Args:
        script: script dict as built by get_all_scripts_recursively().
        vars: mapping used to replace "{{ name }}" references in the script.
        default_database: database used for the connection running the script.
        change_history_table: dict with database_name, schema_name, table_name.
        snowflake_session_parameters: base session parameters; a copy with the
            script name appended to QUERY_TAG is used to run the script.
        autocommit: passed through to execute_snowflake_query().
        verbose: passed through to execute_snowflake_query().

    The script and the history INSERT are sent through two separate
    execute_snowflake_query() calls.
    """
    def sql_literal(value):
        # Render value as a single-quoted SQL string constant. Backslashes and
        # single quotes are escaped so that a script name, description or user
        # containing them cannot break (or alter) the INSERT statement.
        return "'%s'" % str(value).replace("\\", "\\\\").replace("'", "''")

    # First read the contents of the script
    with open(script['script_full_path'], 'r') as content_file:
        content = content_file.read().strip()
        # Drop one trailing semicolon, if present
        content = content[:-1] if content.endswith(';') else content

    # Define a few other change related variables
    # The checksum covers the script text before variable replacement
    checksum = hashlib.sha224(content.encode('utf-8')).hexdigest()
    execution_time = 0
    status = 'Success'

    # Replace any variables used in the script content
    content = replace_variables_references(content, vars, verbose)

    # Execute the contents of the script
    if len(content) > 0:
        start = time.time()
        # Work on a copy so the caller's session parameters keep their QUERY_TAG
        session_parameters = snowflake_session_parameters.copy()
        session_parameters["QUERY_TAG"] += ";%s" % script['script_name']
        execute_snowflake_query(default_database, content, session_parameters, autocommit, verbose)
        end = time.time()
        execution_time = round(end - start)

    # Finally record this change in the change history table
    # NOTE: script_version is None for repeatable scripts and is stored as the text 'None'
    query = "INSERT INTO {0}.{1} (VERSION, DESCRIPTION, SCRIPT, SCRIPT_TYPE, CHECKSUM, EXECUTION_TIME, STATUS, INSTALLED_BY, INSTALLED_ON) values ({2},{3},{4},{5},{6},{7},{8},{9},CURRENT_TIMESTAMP);".format(
        change_history_table['schema_name'],
        change_history_table['table_name'],
        sql_literal(script['script_version']),
        sql_literal(script['script_description']),
        sql_literal(script['script_name']),
        sql_literal(script['script_type']),
        sql_literal(checksum),
        execution_time,
        sql_literal(status),
        sql_literal(os.environ["SNOWFLAKE_USER"]))
    execute_snowflake_query(change_history_table['database_name'], query, snowflake_session_parameters, autocommit, verbose)
# This method will throw an error if there are any leftover variables in the change script
# Since a leftover variable in the script isn't valid SQL, and will fail when run it's
# better to throw an error here and have the user fix the problem ahead of time.
def replace_variables_references(content, vars, verbose):
    """Replace every "{{ name }}" reference in *content* with its value.

    Args:
        content: text of the change script.
        vars: mapping of variable name to value, or None when no variables
            were supplied.
        verbose: accepted for signature compatibility; not used here.

    Returns:
        The content with all variable references substituted.

    Raises:
        KeyError: when the script references a variable that is not in vars.
        ValueError: when "{{ " is not followed by a valid variable reference.
    """
    # vars is None when no variables were given; substituting against None
    # would fail with an unhelpful TypeError on the first reference, so fall
    # back to an empty mapping and let the missing name surface as a KeyError.
    mapping = {} if vars is None else vars
    template = JinjaExpressionTemplate(content)
    return template.substitute(mapping)
def main():
    """Parse the command line arguments and run snowchange with them."""
    arg_parser = argparse.ArgumentParser(
        prog='snowchange',
        description='Apply schema changes to a Snowflake account. Full readme at https://github.com/Snowflake-Labs/snowchange',
        formatter_class=argparse.RawTextHelpFormatter,
    )
    add = arg_parser.add_argument

    # Location of the change scripts
    add('-f', '--root-folder', type=str, default=".", help='The root folder for the database change scripts', required=False)

    # Snowflake connection settings
    add('-a', '--snowflake-account', type=str, help='The name of the snowflake account (e.g. xy12345.east-us-2.azure)', required=True)
    add('-u', '--snowflake-user', type=str, help='The name of the snowflake user', required=True)
    add('-r', '--snowflake-role', type=str, help='The name of the default role to use', required=True)
    add('-w', '--snowflake-warehouse', type=str, help='The name of the default warehouse to use. Can be overridden in the change scripts.', required=True)
    add('-d', '--snowflake-database', type=str, help='The name of the default database to use. Can be overridden in the change scripts.', required=False)

    # Change history table and script variables
    add('-c', '--change-history-table', type=str, help='Used to override the default name of the change history table (the default is METADATA.SNOWCHANGE.CHANGE_HISTORY)', required=False)
    add('--vars', type=json.loads, help='Define values for the variables to replaced in change scripts, given in JSON format (e.g. {"variable1": "value1", "variable2": "value2"})', required=False)
    add('--create-change-history-table', action='store_true', help='Create the change history schema and table, if they do not exist (the default is False)', required=False)

    # Execution switches
    add('-ac', '--autocommit', action='store_true', help='Enable autocommit feature for DML commands (the default is False)', required=False)
    add('-v', '--verbose', action='store_true', help='Display verbose debugging details during execution (the default is False)', required=False)
    add('--dry-run', action='store_true', help='Run snowchange in dry run mode (the default is False)', required=False)

    cli_args = arg_parser.parse_args()

    snowchange(
        root_folder=cli_args.root_folder,
        snowflake_account=cli_args.snowflake_account,
        snowflake_user=cli_args.snowflake_user,
        snowflake_role=cli_args.snowflake_role,
        snowflake_warehouse=cli_args.snowflake_warehouse,
        snowflake_database=cli_args.snowflake_database,
        change_history_table_override=cli_args.change_history_table,
        vars=cli_args.vars,
        create_change_history_table=cli_args.create_change_history_table,
        autocommit=cli_args.autocommit,
        verbose=cli_args.verbose,
        dry_run=cli_args.dry_run,
    )
# Run the command line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()