diff --git a/src/infuse_iot/diff.py b/src/infuse_iot/diff.py index 6629dd6..e8082f0 100644 --- a/src/infuse_iot/diff.py +++ b/src/infuse_iot/diff.py @@ -5,11 +5,9 @@ import ctypes import binascii -from collections import defaultdict, OrderedDict +from collections import defaultdict from typing import List, Dict, Tuple -from functools import cmp_to_key - class ValidationError(Exception): """Generic patch validation exception""" @@ -26,11 +24,10 @@ class OpCode(enum.IntEnum): WRITE_LEN_U12 = 5 << 4 WRITE_LEN_U20 = 6 << 4 WRITE_LEN_U32 = 7 << 4 - WRITE_CACHED = 8 << 4 - ADDR_SHIFT_S8 = 9 << 4 - ADDR_SHIFT_S16 = 10 << 4 - ADDR_SET_U32 = 11 << 4 - PATCH = 12 << 4 + ADDR_SHIFT_S8 = 8 << 4 + ADDR_SHIFT_S16 = 9 << 4 + ADDR_SET_U32 = 10 << 4 + PATCH = 11 << 4 OPCODE_MASK = 0xF0 DATA_MASK = 0x0F @@ -62,7 +59,6 @@ def from_bytes( b: bytes, offset: int, original_offset: int, - write_cache: List[bytes], ): """Reconstruct class from bytes""" opcode = OpCode.from_byte(b[offset]) @@ -72,24 +68,22 @@ def from_bytes( or opcode == OpCode.COPY_LEN_U20 or opcode == OpCode.COPY_LEN_U32 ): - return CopyInstr.from_bytes(b, offset, original_offset, write_cache) + return CopyInstr.from_bytes(b, offset, original_offset) if ( opcode == OpCode.WRITE_LEN_U4 or opcode == OpCode.WRITE_LEN_U12 or opcode == OpCode.WRITE_LEN_U20 or opcode == OpCode.WRITE_LEN_U32 ): - return WriteInstr.from_bytes(b, offset, original_offset, write_cache) - if opcode == OpCode.WRITE_CACHED: - return WriteCachedInstr.from_bytes(b, offset, original_offset, write_cache) + return WriteInstr.from_bytes(b, offset, original_offset) if ( opcode == OpCode.ADDR_SHIFT_S8 or opcode == OpCode.ADDR_SHIFT_S16 or opcode == OpCode.ADDR_SET_U32 ): - return SetAddrInstr.from_bytes(b, offset, original_offset, write_cache) + return SetAddrInstr.from_bytes(b, offset, original_offset) if opcode == OpCode.PATCH: - return PatchInstr.from_bytes(b, offset, original_offset, write_cache) + return PatchInstr.from_bytes(b, offset, original_offset) raise NotImplementedError @@ -119,12 +113,16 @@ class SetAddrU32(ctypes.LittleEndianStructure): ] _pack_ = 1 - def __init__(self, old_addr, new_addr): + def __init__(self, old_addr, new_addr, cls_override=None): self.old = old_addr self.new = new_addr self.shift = self.new - self.old + self._cls_override = cls_override def ctypes_class(self): + if self._cls_override is not None: + return self._cls_override + if -128 <= self.shift <= 127: return self.ShiftAddrS8 elif -32768 <= self.shift <= 32767: @@ -133,9 +131,7 @@ def ctypes_class(self): return self.SetAddrU32 @classmethod - def from_bytes( - cls, b: bytes, offset: int, original_offset: int, _write_cache: List[bytes] - ): + def from_bytes(cls, b: bytes, offset: int, original_offset: int): opcode = b[offset] if opcode == OpCode.ADDR_SHIFT_S8: s = cls.ShiftAddrS8.from_buffer_copy(b, offset) @@ -151,17 +147,15 @@ def from_bytes( return c, ctypes.sizeof(s), c.new def __bytes__(self): - if -128 <= self.shift <= 127: - instr_cls = self.ShiftAddrS8 + instr = self.ctypes_class() + if instr == self.ShiftAddrS8: val = self.shift - elif -32768 <= self.shift <= 32767: - instr_cls = self.ShiftAddrS16 + elif instr == self.ShiftAddrS16: val = self.shift else: - instr_cls = self.SetAddrU32 val = self.new - return bytes(instr_cls(instr_cls.op.value, val)) + return bytes(instr(instr.op.value, val)) def __str__(self): if -32768 <= self.shift <= 32767: @@ -216,13 +210,16 @@ class CopyU32(ctypes.LittleEndianStructure): ] _pack_ = 1 - def __init__(self, length: int, original_offset: int = -1): + def __init__(self, length: int, original_offset: int = -1, cls_override=None): assert length != 0 self.length = length # Used in construction to simplify optimisations self.original_offset = original_offset + self._cls_override = cls_override def ctypes_class(self): + if self._cls_override is not None: + return self._cls_override if self.length < 16: return self.CopyU4 elif self.length < 4096: @@ -233,9 +230,7 @@ def ctypes_class(self): return self.CopyU32 @classmethod - def from_bytes( - cls, b: bytes, offset: int, original_offset: int, _write_cache: List[bytes] - ): + def from_bytes(cls, b: bytes, offset: int, original_offset: int): opcode = OpCode.from_byte(b[offset]) if opcode == OpCode.COPY_LEN_U4: s = cls.CopyU4.from_buffer_copy(b, offset) @@ -252,13 +247,13 @@ def from_bytes( def __bytes__(self): instr = self.ctypes_class() - if self.length < 16: + if instr == self.CopyU4: return bytes(instr(instr.op.value | self.length)) - elif self.length < 4096: + elif instr == self.CopyU12: top = self.length >> 8 bottom = self.length & 0xFF return bytes(instr(instr.op.value | top, bottom)) - elif self.length < 1048576: + elif instr == self.CopyU20: top = self.length >> 16 bottom = self.length & 0xFFFF return bytes(instr(instr.op.value | top, bottom)) @@ -314,11 +309,14 @@ class WriteU32(ctypes.LittleEndianStructure): ] _pack_ = 1 - def __init__(self, data): + def __init__(self, data, cls_override=None): assert len(data) != 0 self.data = data + self._cls_override = cls_override def ctypes_class(self): + if self._cls_override is not None: + return self._cls_override if len(self.data) < 16: return self.WriteU4 elif len(self.data) < 4096: @@ -329,9 +327,7 @@ def ctypes_class(self): return self.WriteU32 @classmethod - def from_bytes( - cls, b: bytes, offset: int, original_offset: int, _write_cache: List[bytes] - ): + def from_bytes(cls, b: bytes, offset: int, original_offset: int): opcode = OpCode.from_byte(b[offset]) if opcode == OpCode.WRITE_LEN_U4: s = cls.WriteU4.from_buffer_copy(b, offset) @@ -353,13 +349,13 @@ def from_bytes( def __bytes__(self): instr = self.ctypes_class() - if len(self.data) < 16: + if instr == self.WriteU4: return bytes(instr(instr.op.value | len(self.data))) + self.data - elif len(self.data) < 4096: + elif instr == self.WriteU12: top = len(self.data) >> 8 bottom = len(self.data) & 0xFF return bytes(instr(instr.op.value | top, bottom)) + self.data - elif len(self.data) < 1048576: + elif instr == self.WriteU20: top = len(self.data) >> 16 bottom = len(self.data) & 0xFFFF return bytes(instr(instr.op.value | top, bottom)) + self.data @@ -376,47 +372,6 @@ def __len__(self): return ctypes.sizeof(self.ctypes_class()) + len(self.data) -class WriteCachedInstr(Instr): - class WriteCached(ctypes.LittleEndianStructure): - op = OpCode.WRITE_CACHED - _fields_ = [ - ("opcode", ctypes.c_uint8), - ] - _pack_ = 1 - - def __init__(self, idx, write_len): - self.idx = idx - self.write_len = write_len - - def ctypes_class(self): - return self.WriteCached - - @classmethod - def from_bytes( - cls, b: bytes, offset: int, original_offset: int, write_cache: List[bytes] - ): - """Reconstruct class from bytes""" - instr = cls.WriteCached.from_buffer_copy(b, offset) - idx = OpCode.data(instr.opcode) - write_len = len(write_cache[idx]) - return ( - cls(idx, write_len), - ctypes.sizeof(instr), - original_offset + write_len, - ) - - def __bytes__(self): - instr = self.ctypes_class() - op = instr.op.value | self.idx - return bytes(instr(op)) - - def __str__(self): - return f"WRITE: Cache index {self.idx} ({self.write_len} bytes)" - - def __len__(self): - return ctypes.sizeof(self.ctypes_class()) - - class PatchInstr(Instr): class PatchData(ctypes.LittleEndianStructure): op = OpCode.PATCH @@ -430,9 +385,7 @@ def ctypes_class(self): return self.PatchData @classmethod - def from_bytes( - cls, b: bytes, offset: int, original_offset: int, _write_cache: List[bytes] - ): + def from_bytes(cls, b: bytes, offset: int, original_offset: int): assert b[offset] == OpCode.PATCH operations = [] length = 1 @@ -511,6 +464,9 @@ def __len__(self): class diff: class PatchHeader(ctypes.LittleEndianStructure): + VERSION_MAJOR = 1 + VERSION_MINOR = 0 + class ArrayValidation(ctypes.LittleEndianStructure): _fields_ = [ ("length", ctypes.c_uint32), @@ -522,17 +478,18 @@ class ArrayValidation(ctypes.LittleEndianStructure): cache_size = 128 _fields_ = [ ("magic", ctypes.c_uint32), + ("version_major", ctypes.c_uint8), + ("version_minor", ctypes.c_uint8), ("original_file", ArrayValidation), ("constructed_file", ArrayValidation), ("patch_file", ArrayValidation), - ("write_cache", 128 * ctypes.c_uint8), ("header_crc", ctypes.c_uint32), ] _pack_ = 1 @classmethod def _naive_diff(cls, old: bytes, new: bytes, hash_len: int = 8): - """Construct basic runs Merge runs of COPY and WRITE into PATCH""" + """Construct basic runs of WRITE, COPY, and SET_ADDR instructions""" instr = [] old_offset = 0 new_offset = 0 @@ -614,51 +571,6 @@ def _naive_diff(cls, old: bytes, new: bytes, hash_len: int = 8): return instr - @classmethod - def _common_writes(cls, instructions: List[Instr]) -> OrderedDict: - common = OrderedDict() - - write_chunks = defaultdict(int) - for instr in instructions: - if isinstance(instr, WriteInstr): - if len(instr.data) < 8: - continue - write_chunks[instr.data] += 1 - - for val, cnt in write_chunks.items(): - if cnt > 2: - common[val] = (cnt - 1) * len(val) - - by_savings = sorted(common.items(), key=lambda x: x[1], reverse=True) - allocated = 0 - - # This allocation scheme is not necessarily the most efficient. - # A 100 byte chunk that saves 1000 bytes would be chosen over - # two 50 byte chunks that save 800 bytes each - cached = [] - for write_bytes, _ in by_savings: - if len(cached) > 16: - break - if (1 + len(write_bytes) + allocated) > cls.PatchHeader.cache_size: - continue - cached.append(write_bytes) - allocated += 1 + len(write_bytes) - - # Replace the writes that are of common values - out_instr = [] - for instr in instructions: - if isinstance(instr, WriteInstr): - if instr.data in cached: - out_instr.append( - WriteCachedInstr(cached.index(instr.data), len(instr.data)) - ) - else: - out_instr.append(instr) - else: - out_instr.append(instr) - - return cached, out_instr - @classmethod def _cleanup_jumps(cls, old: bytes, instructions: List[Instr]) -> List[Instr]: """Find locations that jumped backwards just to jump forward to original location""" @@ -750,116 +662,100 @@ def finalise(): return merged @classmethod - def _merge_crack(cls, old: bytes, instructions: List[Instr]) -> List[Instr]: - """Crack a WRITE operation in a PATCH into a [WRITE,COPY,WRITE] if COPY is at least 2 bytes""" - - for instr in instructions: - if not isinstance(instr, PatchInstr): - continue - - old_offset = 0 - updated_ops = [] - while len(instr.operations) > 0: - if len(instr.operations) == 1: - updated_ops.append(instr.operations.pop()) - continue - - copy_op = instr.operations.pop(0) - write_op = instr.operations.pop(0) - assert isinstance(copy_op, CopyInstr) - assert isinstance(write_op, WriteInstr) - assert copy_op.original_offset != -1 - - old_offset = copy_op.original_offset + copy_op.length - updated_ops.append(copy_op) + def _write_crack(cls, old: bytes, instructions: List[Instr]) -> List[Instr]: + """Crack a WRITE operation into a [WRITE,COPY,WRITE] if COPY is at least 2 bytes""" - if len(write_op.data) < 4: - # Too small to crack - updated_ops.append(write_op) - continue + cracked = [] + old_offset = 0 - split = [0] - for idx, b in enumerate(write_op.data): - if old[old_offset + idx] != b: - if len(split) % 2: - # Already on a WRITE - split[-1] += 1 - else: - # On a COPY, swap to a WRITE - split.append(1) - continue + while len(instructions): + instr = instructions.pop(0) + if isinstance(instr, CopyInstr): + old_offset = instr.original_offset + instr.length + cracked.append(instr) + continue + elif isinstance(instr, SetAddrInstr): + old_offset = instr.new + cracked.append(instr) + continue + assert isinstance(instr, WriteInstr) + + split = [0] + for idx, b in enumerate(instr.data): + if old_offset + idx >= len(old): + # Add remainder of write to last split + split[-1] += len(instr.data) - idx + break + if old[old_offset + idx] != b: if len(split) % 2: - # On a WRITE, switch to a COPY - split.append(1) - else: - # Already on a COPY + # Already on a WRITE split[-1] += 1 - - # Total data count should remain the same - assert sum(split) == len(write_op.data) - - if len(split) % 2 == 0: - # Ended on a copy - copy_len = split.pop() - if len(instr.operations) > 0: - # Push the match into the next instruction if possible - assert isinstance(instr.operations[0], CopyInstr) - instr.operations[0].length += copy_len - instr.operations[0].original_offset -= copy_len else: - # Merge the copy back into the previous write - split[-1] += copy_len + # On a COPY, swap to a WRITE + split.append(1) + continue - # Should now have N*[WRITE, COPY] + [WRITE] - assert len(split) % 2 == 1 + if len(split) % 2: + # On a WRITE, switch to a COPY + split.append(1) + else: + # Already on a COPY + split[-1] += 1 + + # Total data count should remain the same + assert sum(split) == len(instr.data) + + if len(split) % 2 == 0: + # Ended on a copy + copy_len = split.pop() + if len(instructions) > 0 and isinstance(instructions[0], CopyInstr): + # Push the match into the next instruction if possible + instructions[0].length += copy_len + instructions[0].original_offset -= copy_len + else: + # Merge the copy back into the previous write + split[-1] += copy_len - # Construct the [WRITE, COPY] pairs - offset = 0 - while len(split) > 1: - write_len = split.pop(0) - copy_len = split.pop(0) + # Should now have N*[WRITE, COPY] + [WRITE] + assert len(split) % 2 == 1 - # If the copy was only 1 byte, roll it back - if copy_len == 1: - split[0] += write_len + copy_len - else: - updated_ops.append( - WriteInstr(write_op.data[offset : offset + write_len]) - ) - offset += write_len - updated_ops.append(CopyInstr(copy_len, old_offset + offset)) - offset += copy_len - - # Append the final WRITE - write_len = split.pop() - updated_ops.append( - WriteInstr(write_op.data[offset : offset + write_len]) - ) + # Construct the [WRITE, COPY] pairs + offset = 0 + while len(split) > 1: + write_len = split.pop(0) + copy_len = split.pop(0) + + # If the copy was only 1 byte, roll it back + if copy_len == 1: + split[0] += write_len + copy_len + else: + cracked.append(WriteInstr(instr.data[offset : offset + write_len])) + offset += write_len + cracked.append(CopyInstr(copy_len, old_offset + offset)) + offset += copy_len - # Update the PATCH operations - instr.operations = updated_ops + # Append the final WRITE + write_len = split.pop() + cracked.append(WriteInstr(instr.data[offset : offset + write_len])) - return instructions + return cracked @classmethod def _gen_patch_instr(cls, bin_orig: bytes, bin_new: bytes) -> List[Instr]: best_patch = None - best_write_cache = None best_patch_len = 2**32 # Find best diff across range for i in range(4, 8): instr = cls._naive_diff(bin_orig, bin_new, i) instr = cls._cleanup_jumps(bin_orig, instr) - write_cache, instr = cls._common_writes(instr) + instr = cls._write_crack(bin_orig, instr) instr = cls._merge_operations(instr) - instr = cls._merge_crack(bin_orig, instr) patch_len = sum([len(i) for i in instr]) if patch_len < best_patch_len: best_patch = instr - best_write_cache = write_cache best_patch_len = patch_len metadata = { @@ -873,21 +769,14 @@ def _gen_patch_instr(cls, bin_orig: bytes, bin_new: bytes) -> List[Instr]: }, } - return metadata, best_write_cache, best_patch + return metadata, best_patch @classmethod - def _gen_patch_header( - cls, patch_metadata: Dict, write_cache: List[bytes], patch_data: bytes - ): - cache_bin = b"" - for entry in write_cache: - cache_bin += len(entry).to_bytes(1, "little") + entry - cache_bin += (cls.PatchHeader.cache_size - len(cache_bin)) * b"\x00" - assert len(cache_bin) == cls.PatchHeader.cache_size - c = (ctypes.c_uint8 * cls.PatchHeader.cache_size).from_buffer_copy(cache_bin) - + def _gen_patch_header(cls, patch_metadata: Dict, patch_data: bytes): hdr = cls.PatchHeader( cls.PatchHeader.magic_value, + cls.PatchHeader.VERSION_MAJOR, + cls.PatchHeader.VERSION_MINOR, cls.PatchHeader.ArrayValidation( patch_metadata["original"]["len"], patch_metadata["original"]["crc"], @@ -900,7 +789,6 @@ def _gen_patch_header( len(patch_data), binascii.crc32(patch_data), ), - c, 0, ) hdr_no_crc = bytes(hdr) @@ -948,24 +836,17 @@ def _patch_load(cls, patch_binary: bytes): f"Patch data CRC does not match patch information ({binascii.crc32(data):08x} != {hdr.patch_file.crc:08x})" ) - cache = [] - cache_bin = bytes(hdr.write_cache) - while len(cache_bin) and cache_bin[0] != 0: - l = cache_bin[0] - cache.append(cache_bin[1 : 1 + l]) - cache_bin = cache_bin[1 + l :] - instructions = [] patch_offset = 0 original_offset = 0 while patch_offset < len(data): instr, length, original_offset = Instr.from_bytes( - data, patch_offset, original_offset, cache + data, patch_offset, original_offset ) patch_offset += length instructions.append(instr) - return metadata, cache, instructions + return metadata, instructions @classmethod def generate( @@ -974,9 +855,9 @@ def generate( bin_new: bytes, verbose: bool, ) -> bytes: - meta, cache, instructions = diff._gen_patch_instr(bin_original, bin_new) + meta, instructions = diff._gen_patch_instr(bin_original, bin_new) patch_data = diff._gen_patch_data(instructions) - patch_header = diff._gen_patch_header(meta, cache, patch_data) + patch_header = diff._gen_patch_header(meta, patch_data) bin_patch = patch_header + patch_data ratio = 100 * len(bin_patch) / meta["new"]["len"] @@ -1003,13 +884,73 @@ def generate( # Return complete file return bin_patch + @classmethod + def validation( + cls, bin_original: bytes, invalid_length: bool, invalid_crc: bool + ) -> bytes: + assert len(bin_original) > 1024 + + # Manually construct an instruction set that runs all instructions + instructions = [] + instructions.append( + WriteInstr(bin_original[:8], cls_override=WriteInstr.WriteU4) + ) + instructions.append( + WriteInstr(bin_original[8:16], cls_override=WriteInstr.WriteU12) + ) + instructions.append(SetAddrInstr(16, 8, cls_override=SetAddrInstr.ShiftAddrS8)) + instructions.append( + WriteInstr(bin_original[16:128], cls_override=WriteInstr.WriteU20) + ) + instructions.append( + SetAddrInstr(120, 200, cls_override=SetAddrInstr.ShiftAddrS16) + ) + instructions.append( + WriteInstr(bin_original[128:256], cls_override=WriteInstr.WriteU32) + ) + instructions.append( + SetAddrInstr(328, 256, cls_override=SetAddrInstr.SetAddrU32) + ) + instructions.append(CopyInstr(8, cls_override=CopyInstr.CopyU4)) + instructions.append(CopyInstr(8, cls_override=CopyInstr.CopyU12)) + instructions.append(CopyInstr(128 - 16, cls_override=CopyInstr.CopyU20)) + instructions.append(CopyInstr(128, cls_override=CopyInstr.CopyU32)) + instructions.append( + PatchInstr( + [ + CopyInstr(15), + WriteInstr(bin_original[512 + 15 : 512 + 16]), + CopyInstr(14), + WriteInstr(bin_original[512 + 30 : 512 + 32]), + ] + ) + ) + instructions.append(CopyInstr(len(bin_original) - 544)) + + meta, _ = diff._gen_patch_instr(bin_original, bin_original) + if invalid_length: + meta["new"]["len"] -= 1 + if invalid_crc: + meta["new"]["crc"] -= 1 + + patch_data = diff._gen_patch_data(instructions) + patch_header = diff._gen_patch_header(meta, patch_data) + bin_patch = patch_header + patch_data + + # Validate that file can be reconstructed + if not invalid_length and not invalid_crc: + patched = cls.patch(bin_original, bin_patch) + assert bin_original == patched + + return bin_patch + @classmethod def patch( cls, bin_original: bytes, bin_patch: bytes, ) -> bytes: - meta, cache, instructions = diff._patch_load(bin_patch) + meta, instructions = diff._patch_load(bin_patch) patched = b"" orig_offset = 0 @@ -1029,9 +970,6 @@ def patch( elif isinstance(instr, WriteInstr): patched += instr.data orig_offset += len(instr.data) - elif isinstance(instr, WriteCachedInstr): - patched += cache[instr.idx] - orig_offset += len(cache[instr.idx]) elif isinstance(instr, SetAddrInstr): orig_offset = instr.new elif isinstance(instr, PatchInstr): @@ -1064,19 +1002,14 @@ def dump( cls, bin_patch: bytes, ): - meta, cache, instructions = diff._patch_load(bin_patch) + meta, instructions = diff._patch_load(bin_patch) total_write_bytes = 0 print(f"Original File: {meta['original']['len']:6d} bytes") print(f" New File: {meta['new']['len']:6d} bytes") print( - f" Patch File: {meta['patch']['len']:6d} bytes ({len(instructions):5d} instructions)" + f" Patch File: {len(bin_patch)} bytes ({len(instructions):5d} instructions)" ) - print("") - print("Write Cache:") - for idx, entry in enumerate(cache): - total_write_bytes += len(entry) - print(f"\t{idx:2d}: {entry.hex()}") class_count = defaultdict(int) for instr in instructions: @@ -1089,9 +1022,9 @@ def dump( total_write_bytes += len(op.data) print("") - print("Patch Data Split") + print("Total WRITE data:") print( - f"\tWRITE data: {total_write_bytes} bytes ({100*total_write_bytes/len(bin_patch):.2f}%)" + f"\t{total_write_bytes} bytes ({100*total_write_bytes/len(bin_patch):.2f}%)" ) print("") @@ -1120,6 +1053,21 @@ def dump( ) generate_args.add_argument("patch", help="Output patch file name") + # Generate validation patch file + validation_args = subparser.add_parser( + "validation", help="Generate a patch file for validating appliers" + ) + validation_args.add_argument( + "--invalid-length", action="store_true", help="Incorrect output file length" + ) + validation_args.add_argument( + "--invalid-crc", action="store_true", help="Incorrect output file CRC" + ) + validation_args.add_argument( + "input_file", help="File to use as base image and desired output" + ) + validation_args.add_argument("patch", help="Output patch file name") + # Apply patch file patch_args = subparser.add_parser("patch", help="Apply a patch file") patch_args.add_argument("original", help="Original file to use as base image") @@ -1146,6 +1094,13 @@ def dump( ) with open(args.patch, "wb") as f_output: f_output.write(patch) + elif args.command == "validation": + with open(args.input_file, "rb") as f_input: + patch = diff.validation( + f_input.read(-1), args.invalid_length, args.invalid_crc + ) + with open(args.patch, "wb") as f_output: + f_output.write(patch) elif args.command == "patch": with open(args.original, "rb") as f_orig: with open(args.patch, "rb") as f_patch: