From baf0da4a46957285dbda378a2e0bc99e6343cf39 Mon Sep 17 00:00:00 2001
From: xunzhang
Date: Wed, 21 Sep 2016 20:10:00 +0800
Subject: [PATCH] HAWQ-1024. Rollback if hawq register failed in process.

---
 tools/bin/hawqregister | 101 +++++++++++++++++++++++++++++++----------
 1 file changed, 76 insertions(+), 25 deletions(-)

diff --git a/tools/bin/hawqregister b/tools/bin/hawqregister
index fa23a1ac98..bef1460612 100755
--- a/tools/bin/hawqregister
+++ b/tools/bin/hawqregister
@@ -112,6 +112,42 @@ def register_yaml_dict_check(D):
             logger.error('Wrong configuration yaml file format: "%s" attribute does not exist.\n See example in "hawq register --help".' % 'AO_FileLocations.%s' % attr)
             sys.exit(1)
 
+
+class FailureHandler(object):
+    def __init__(self, conn):
+        self.operations = []
+        self.conn = conn
+
+    def commit(self, cmd):
+        self.operations.append(cmd)
+
+    def assemble_SQL(self, cmd):
+        return 'DROP TABLE %s' % cmd[cmd.find('table')+6:cmd.find('(')]
+
+    def assemble_hdfscmd(self, cmd):
+        lst = cmd.strip().split()
+        return ' '.join(lst[:-2] + [lst[-1], lst[-2]])
+
+    def rollback(self):
+        for (typ, cmd) in reversed(self.operations):
+            if typ == 'SQL':
+                sql = self.assemble_SQL(cmd)
+                try:
+                    self.conn.query(sql)
+                except pg.DatabaseError as e:
+                    logger.error('Rollback failure: %s.' % sql)
+                    print e
+                    sys.exit(1)
+            if typ == 'HDFSCMD':
+                hdfscmd = self.assemble_hdfscmd(cmd)
+                sys.stdout.write('Rollback hdfscmd: "%s"\n' % hdfscmd)
+                result = local_ssh(hdfscmd, logger)
+                if result != 0:
+                    logger.error('Failed to roll back: %s.' % hdfscmd)
+                    sys.exit(1)
+        logger.info('HAWQ register rollback succeeded.')
+
+
 class GpRegisterAccessor(object):
     def __init__(self, conn):
         self.conn = conn
@@ -135,21 +171,21 @@ class GpRegisterAccessor(object):
 
     def do_create_table(self, src_table_name, tablename, schema_info, fmt, distrbution_policy, file_locations, bucket_number, partitionby, partitions_constraint, partitions_name):
         if self.get_table_existed(tablename):
-            return False
+            return False, ''
         schema = ','.join([k['name'] + ' ' + k['type'] for k in schema_info])
         partlist = ""
         for index in range(len(partitions_constraint)):
-            if index > 0:
-                partlist += ", "
-            partition_refine_name = partitions_name[index]
-            splitter = src_table_name.split(".")[-1] + '_1_prt_'
-            partition_refine_name = partition_refine_name.split(splitter)[-1]
-            #in some case, constraint contains "partition XXX" but in other case, it doesn't contain. we need to treat them separately.
-            if partitions_constraint[index].strip().startswith("DEFAULT PARTITION") or partitions_constraint[index].strip().startswith("PARTITION") or (len(partition_refine_name) > 0 and partition_refine_name[0].isdigit()):
-                partlist = partlist + " " + partitions_constraint[index]
-            else:
-                partlist = partlist + "PARTITION " + partition_refine_name + " " + partitions_constraint[index]
-
+            if index > 0:
+                partlist += ", "
+            partition_refine_name = partitions_name[index]
+            splitter = src_table_name.split(".")[-1] + '_1_prt_'
+            partition_refine_name = partition_refine_name.split(splitter)[-1]
+            # In some cases the constraint already contains "PARTITION XXX", and in others it does not; treat the two forms separately.
+            if partitions_constraint[index].strip().startswith("DEFAULT PARTITION") or partitions_constraint[index].strip().startswith("PARTITION") or (len(partition_refine_name) > 0 and partition_refine_name[0].isdigit()):
+                partlist = partlist + " " + partitions_constraint[index]
+            else:
+                partlist = partlist + "PARTITION " + partition_refine_name + " " + partitions_constraint[index]
+
         fmt = 'ROW' if fmt == 'AO' else fmt
         if fmt == 'ROW':
             if partitionby is None:
@@ -165,9 +201,8 @@ class GpRegisterAccessor(object):
             else:
                 query = ('create table %s(%s) with (appendonly=true, orientation=%s, compresstype=%s, compresslevel=%s, pagesize=%s, rowgroupsize=%s, bucketnum=%s) %s %s (%s);'
                          % (tablename, schema, fmt, file_locations['CompressionType'], file_locations['CompressionLevel'], file_locations['PageSize'], file_locations['RowGroupSize'], bucket_number, distrbution_policy, partitionby, partlist))
-        print query
         self.conn.query(query)
-        return True
+        return True, query
 
     def check_hash_type(self, tablename):
         qry = """select attrnums from gp_distribution_policy, pg_class where pg_class.relname = '%s' and pg_class.oid = gp_distribution_policy.localoid;""" % tablename
@@ -241,7 +276,7 @@ class GpRegisterAccessor(object):
 
 
 class HawqRegister(object):
-    def __init__(self, options, table, utility_conn, conn):
+    def __init__(self, options, table, utility_conn, conn, failure_handler):
         self.yml = options.yml_config
         self.filepath = options.filepath
         self.database = options.database
@@ -249,6 +284,7 @@ class HawqRegister(object):
         self.filesize = options.filesize
         self.accessor = GpRegisterAccessor(conn)
         self.utility_accessor = GpRegisterAccessor(utility_conn)
+        self.failure_handler = failure_handler
         self.mode = self._init_mode(options.force, options.repair)
         self._init()
 
@@ -288,8 +324,14 @@ class HawqRegister(object):
             sys.exit(1)
 
         def create_table():
-            return self.accessor.do_create_table(self.src_table_name, self.tablename, self.schema, self.file_format, self.distribution_policy, self.file_locations, self.bucket_number,
-                                                 self.partitionby, self.partitions_constraint, self.partitions_name)
+            try:
+                (ret, query) = self.accessor.do_create_table(self.src_table_name, self.tablename, self.schema, self.file_format, self.distribution_policy, self.file_locations, self.bucket_number,
+                                                             self.partitionby, self.partitions_constraint, self.partitions_name)
+            except pg.DatabaseError as e:
+                print e
+                sys.exit(1)
+            self.failure_handler.commit(('SQL', query))
+            return ret
 
         def get_seg_name():
             return self.utility_accessor.get_seg_name(self.tablename, self.database, self.file_format)
@@ -325,7 +367,7 @@ class HawqRegister(object):
             if self.bucket_number != get_bucket_number():
                 logger.error('Bucket number of %s is not consistent with previous bucket number.' % self.tablename)
                 sys.exit(1)
-        
+
         def set_yml_dataa(file_format, files, sizes, tablename, schema, distribution_policy, file_locations,\
                           bucket_number, partitionby, partitions_constraint, partitions_name, partitions_compression_level,\
                           partitions_compression_type, partitions_checksum, partitions_filepaths, partitions_filesizes, encoding):
@@ -339,11 +381,11 @@ class HawqRegister(object):
            self.bucket_number = bucket_number
            self.partitionby = partitionby
            self.partitions_constraint = partitions_constraint
-           self.partitions_name = partitions_name 
+           self.partitions_name = partitions_name
            self.partitions_compression_level = partitions_compression_level
            self.partitions_compression_type = partitions_compression_type
            self.partitions_checksum = partitions_checksum
-           self.partitions_filepaths = partitions_filepaths 
+           self.partitions_filepaths = partitions_filepaths
            self.partitions_filesizes = partitions_filesizes
            self.encoding = encoding
 
@@ -360,7 +402,7 @@ class HawqRegister(object):
         partitions_compression_level = []
         partitions_compression_type = []
         files, sizes = [], []
-        
+
         if params['FileFormat'].lower() == 'parquet':
             partitionby = params.get('Parquet_FileLocations').get('PartitionBy')
             if params.get('Parquet_FileLocations').get('Partitions') and len(params['Parquet_FileLocations']['Partitions']):
@@ -379,7 +421,7 @@ class HawqRegister(object):
             encoding = params['Encoding']
             set_yml_dataa('Parquet', files, sizes, params['TableName'], params['Parquet_Schema'], params['Distribution_Policy'], params['Parquet_FileLocations'], params['Bucketnum'], partitionby,\
                           partitions_constraint, partitions_name, partitions_compression_level, partitions_compression_type, partitions_checksum, partitions_filepaths, partitions_filesizes, encoding)
-        
+
         else: #AO format
             partitionby = params.get('AO_FileLocations').get('PartitionBy')
             if params.get('AO_FileLocations').get('Partitions') and len(params['AO_FileLocations']['Partitions']):
@@ -398,7 +440,7 @@ class HawqRegister(object):
             encoding = params['Encoding']
             set_yml_dataa('AO', files, sizes, params['TableName'], params['AO_Schema'], params['Distribution_Policy'], params['AO_FileLocations'], params['Bucketnum'], partitionby, partitions_constraint,\
                           partitions_name, partitions_compression_level, partitions_compression_type, partitions_checksum, partitions_filepaths, partitions_filesizes, encoding)
-        
+
         def check_file_not_folder():
             for fn in self.files:
                 hdfscmd = 'hdfs dfs -test -f %s' % fn
@@ -576,7 +618,9 @@ class HawqRegister(object):
             result = local_ssh(hdfscmd, logger)
             if result != 0:
                 logger.error('Fail to move %s to %s' % (srcfile, dstfile))
+                self.failure_handler.rollback()
                 sys.exit(1)
+            self.failure_handler.commit(('HDFSCMD', hdfscmd))
 
     def _delete_files_in_hdfs(self):
         for fn in self.files_delete:
@@ -634,7 +678,12 @@ class HawqRegister(object):
         for i, eof in enumerate(self.sizes_update):
             query += "update pg_aoseg.%s set eof = '%s', tupcount = '%s', varblockcount = '%s', eofuncompressed = '%s' where segno = '%s';" % (self.seg_name, eof, -1, -1, -1, segno_lst[i])
         query += "end transaction;"
-        return self.utility_accessor.update_catalog(query)
+        try:
+            self.utility_accessor.update_catalog(query)
+        except pg.DatabaseError as e:
+            print e
+            self.failure_handler.rollback()
+            sys.exit(1)
 
     def _delete_metadata(self):
         query = "set allow_system_table_mods='dml';"
@@ -693,8 +742,10 @@ def main(options, args):
     except pg.InternalError:
         logger.error('Fail to connect to database, this script can only be run when database is up.')
         return 1
 
+
+    failure_handler = FailureHandler(conn)
     # register
-    ins = HawqRegister(options, args[0], utility_conn, conn)
+    ins = HawqRegister(options, args[0], utility_conn, conn, failure_handler)
     ins.register()
     conn.close()
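
Note on the pattern (commentary, not part of the patch): FailureHandler is a
small undo log. Each forward step, the CREATE TABLE statement and every
"hdfs dfs -mv", is recorded with commit() only after it succeeds, and
rollback() derives the inverse of each recorded step and replays the inverses
in reverse (LIFO) order, so later steps are unwound before the steps they
depended on. A minimal standalone sketch of the two inversions, in Python 2
to match the tool; the sample table name and HDFS paths below are made up for
illustration:

    def assemble_SQL(cmd):
        # Invert a lowercase 'create table <name>(...)' statement by taking
        # the text between 'table ' and the first '(' as <name>.
        return 'DROP TABLE %s' % cmd[cmd.find('table') + 6:cmd.find('(')]

    def assemble_hdfscmd(cmd):
        # Invert an 'hdfs dfs -mv SRC DST' command by swapping its last two
        # arguments, yielding 'hdfs dfs -mv DST SRC'.
        lst = cmd.strip().split()
        return ' '.join(lst[:-2] + [lst[-1], lst[-2]])

    operations = [('SQL', 'create table public.t(i int) with (appendonly=true)'),
                  ('HDFSCMD', 'hdfs dfs -mv /tmp/part0 /hawq_default/t/part0')]
    for typ, cmd in reversed(operations):
        if typ == 'SQL':
            print assemble_SQL(cmd)      # DROP TABLE public.t
        elif typ == 'HDFSCMD':
            print assemble_hdfscmd(cmd)  # hdfs dfs -mv /hawq_default/t/part0 /tmp/part0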