Skip to content

Commit

Permalink
Use cp_height for checkpoints.
Browse files Browse the repository at this point in the history
This is a rough port of abf12b6 by
rt121212121 from Electron-Cash.
  • Loading branch information
JeremyRand committed Jul 5, 2019
1 parent cc9ad3a commit c3399e0
Show file tree
Hide file tree
Showing 7 changed files with 208 additions and 4,067 deletions.
111 changes: 70 additions & 41 deletions electrum/blockchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,10 +186,6 @@ def func_wrapper(self, *args, **kwargs):
return func(self, *args, **kwargs)
return func_wrapper

@property
def checkpoints(self):
return constants.net.CHECKPOINTS

def get_max_child(self) -> Optional[int]:
children = self.get_direct_children()
return max([x.forkpoint for x in children]) if children else None
Expand Down Expand Up @@ -281,24 +277,28 @@ def update_size(self) -> None:
self._size = os.path.getsize(p)//HEADER_SIZE if os.path.exists(p) else 0

@classmethod
def verify_header(cls, header: dict, prev_hash: str, target: int, expected_header_hash: str=None) -> None:
def verify_header(cls, header: dict, prev_hash: str, target: int, expected_header_hash: str=None, proof_was_provided: bool=False) -> None:
_hash = hash_header(header)
if expected_header_hash and expected_header_hash != _hash:
raise Exception("hash mismatches with expected: {} vs {}".format(expected_header_hash, _hash))
if prev_hash != header.get('prev_block_hash'):
raise Exception("prev hash mismatch: %s vs %s" % (prev_hash, header.get('prev_block_hash')))
if constants.net.TESTNET:
return
bits = cls.target_to_bits(target)
if bits != header.get('bits'):
raise Exception("bits mismatch: %s vs %s" % (bits, header.get('bits')))
block_hash_as_num = int.from_bytes(bfh(_hash), byteorder='big')
if block_hash_as_num > target:
raise Exception(f"insufficient proof of work: {block_hash_as_num} vs target {target}")

# We do not need to check the block difficulty if the chain of linked header hashes was proven correct against our checkpoint.
if not proof_was_provided:
bits = cls.target_to_bits(target)
if bits != header.get('bits'):
raise Exception("bits mismatch: %s vs %s" % (bits, header.get('bits')))
block_hash_as_num = int.from_bytes(bfh(_hash), byteorder='big')
if block_hash_as_num > target:
raise Exception(f"insufficient proof of work: {block_hash_as_num} vs target {target}")

def verify_chunk(self, index: int, data: bytes) -> None:
num = len(data) // HEADER_SIZE
start_height = index * 2016
chunk = HeaderChunk(start_height, data)
prev_hash = self.get_hash(start_height - 1)
target = self.get_target(index-1)
for i in range(num):
Expand All @@ -307,8 +307,8 @@ def verify_chunk(self, index: int, data: bytes) -> None:
expected_header_hash = self.get_hash(height)
except MissingHeader:
expected_header_hash = None
raw_header = data[i*HEADER_SIZE : (i+1)*HEADER_SIZE]
header = deserialize_header(raw_header, index*2016 + i)
raw_header = chunk.get_header_at_index(i)
header = deserialize_header(raw_header, start_height + i)
self.verify_header(header, prev_hash, target, expected_header_hash)
prev_hash = hash_header(header)

Expand All @@ -328,7 +328,7 @@ def path(self):
@with_lock
def save_chunk(self, index: int, chunk: bytes):
assert index >= 0, index
chunk_within_checkpoint_region = index < len(self.checkpoints)
chunk_within_checkpoint_region = index * 2016 < constants.net.max_checkpoint()
# chunks in checkpoint region are the responsibility of the 'main chain'
if chunk_within_checkpoint_region and self.parent is not None:
main_chain = get_best_chain()
Expand Down Expand Up @@ -470,19 +470,10 @@ def header_at_tip(self) -> Optional[dict]:
return self.read_header(height)

def get_hash(self, height: int) -> str:
def is_height_checkpoint():
within_cp_range = height <= constants.net.max_checkpoint()
at_chunk_boundary = (height+1) % 2016 == 0
return within_cp_range and at_chunk_boundary

if height == -1:
return '0000000000000000000000000000000000000000000000000000000000000000'
elif height == 0:
return constants.net.GENESIS
elif is_height_checkpoint():
index = height // 2016
h, t = self.checkpoints[index]
return h
else:
header = self.read_header(height)
if header is None:
Expand All @@ -495,9 +486,6 @@ def get_target(self, index: int) -> int:
return 0
if index == -1:
return MAX_TARGET
if index < len(self.checkpoints):
h, t = self.checkpoints[index]
return t
# new target
first = self.read_header(index * 2016)
last = self.read_header(index * 2016 + 2015)
Expand Down Expand Up @@ -569,9 +557,11 @@ def get_chainwork(self, height=None) -> int:
work_in_last_partial_chunk = (height % 2016 + 1) * work_in_single_header
return running_total + work_in_last_partial_chunk

def can_connect(self, header: dict, check_height: bool=True) -> bool:
def can_connect(self, header: dict, check_height: bool=True, proof_was_provided: bool=False) -> bool:
if header is None:
return False
if proof_was_provided:
return True
height = header['block_height']
if check_height and self.height() != height - 1:
return False
Expand All @@ -593,27 +583,18 @@ def can_connect(self, header: dict, check_height: bool=True) -> bool:
return False
return True

def connect_chunk(self, idx: int, hexdata: str) -> bool:
def connect_chunk(self, idx: int, hexdata: str, proof_was_provided: bool=False) -> bool:
assert idx >= 0, idx
try:
data = bfh(hexdata)
self.verify_chunk(idx, data)
if not proof_was_provided:
self.verify_chunk(idx, data)
self.save_chunk(idx, data)
return True
except BaseException as e:
self.logger.info(f'verify_chunk idx {idx} failed: {repr(e)}')
return False

def get_checkpoints(self):
# for each chunk, store the hash of the last block and the target after the chunk
cp = []
n = self.height() // 2016
for index in range(n):
h = self.get_hash((index+1) * 2016 -1)
target = self.get_target(index)
cp.append((h, target))
return cp


def check_header(header: dict) -> Optional[Blockchain]:
if type(header) is not dict:
Expand All @@ -625,9 +606,57 @@ def check_header(header: dict) -> Optional[Blockchain]:
return None


def can_connect(header: dict) -> Optional[Blockchain]:
def can_connect(header: dict, proof_was_provided: bool=False) -> Optional[Blockchain]:
with blockchains_lock: chains = list(blockchains.values())
for b in chains:
if b.can_connect(header):
if b.can_connect(header, proof_was_provided=proof_was_provided):
return b
return None

def verify_proven_chunk(chunk_base_height, chunk_data):
chunk = HeaderChunk(chunk_base_height, chunk_data)

header_count = len(chunk_data) // HEADER_SIZE
prev_header = None
prev_header_hash = None
for i in range(header_count):
raw_header = chunk.get_header_at_index(i)
header = deserialize_header(raw_header, chunk_base_height + i)
# Check the chain of hashes for all headers preceding the proven one.
this_header_hash = hash_header(header)
if i > 0:
if prev_header_hash != header.get('prev_block_hash'):
raise Exception("prev hash mismatch: %s vs %s" % (prev_header_hash, header.get('prev_block_hash')))
prev_header_hash = this_header_hash

# Copied from electrumx
def root_from_proof(hash, branch, index):
hash_func = sha256d
for elt in branch:
if index & 1:
hash = hash_func(elt + hash)
else:
hash = hash_func(hash + elt)
index >>= 1
if index:
raise ValueError('index out of range for branch')
return hash

class HeaderChunk:
def __init__(self, base_height, data):
self.base_height = base_height
self.data = data

def __repr__(self):
return "HeaderChunk(base_height={}, data_count={})".format(self.base_height, len(self.data))

def contains_height(self, height):
header_count = len(self.data) // HEADER_SIZE
return height >= self.base_height and height < self.base_height + header_count

def get_header_at_height(self, height):
return self.get_header_at_index(height - self.base_height)

def get_header_at_index(self, index):
header_offset = index * HEADER_SIZE
return self.data[header_offset:header_offset + HEADER_SIZE]
Loading

0 comments on commit c3399e0

Please sign in to comment.