Skip to content

Commit

Permalink
feat(pyd): file integrity check support in upydevice backend.
Browse files Browse the repository at this point in the history
Signed-off-by: Braden Mars <bradenmars@bradenmars.me>
  • Loading branch information
BradenM committed Apr 17, 2023
1 parent e58c1f2 commit a3cbf31
Showing 1 changed file with 106 additions and 26 deletions.
132 changes: 106 additions & 26 deletions micropy/pyd/backend_upydevice.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
from __future__ import annotations

import binascii
import hashlib
import io
import random
import stat
import string
import time
from functools import wraps
from pathlib import Path, PurePosixPath
from typing import AnyStr, Callable, Generator, TypeVar, Union
from typing import AnyStr, Callable, Generator, Optional, TypeVar, Union

import upydevice
from boltons import iterutils
from micropy.exceptions import PyDeviceConnectionError, PyDeviceError
from micropy.exceptions import PyDeviceConnectionError, PyDeviceError, PyDeviceFileIntegrityError
from rich import print
from typing_extensions import ParamSpec, TypeAlias
from upydevice.phantom import UOS as UPY_UOS

Expand Down Expand Up @@ -41,7 +43,16 @@ def _wrapper(self_: UPyDeviceBackend, *args: P.args, **kwargs: P.kwargs) -> T |

while retry_count < 4:
try:
if (integrity := kwargs.pop("verify_integrity", None)) is not None:
# skip integrity check on last retry as last ditch.
kwargs["verify_integrity"] = integrity and retry_count < 3
if integrity and not kwargs["verify_integrity"]:
print("Attempting again without file integrity check...")
_result = fn(self_, *args, **kwargs) # type: ignore
except PyDeviceFileIntegrityError as e:
retry_count += 1
print(e)
self_.reset()
except Exception as e:
retry_count += 1
self_.BUFFER_SIZE = BUFFER_SIZE // pow(2, retry_count + 1)
Expand Down Expand Up @@ -125,24 +136,38 @@ def iter_files(self, path: DevicePath) -> Generator[DevicePath, None, None]:
results = self._pydevice.cmd(f"list(uos.ilistdir('{path}'))", silent=True, rtn_resp=True)
if not results:
return
for name, type_, _, _ in results:
for file_result in results:
name, type_, _, _ = file_result
abs_path = PurePosixPath(path) / name
if type_ == stat.S_IFDIR:
yield from self.iter_files(abs_path)
else:
yield abs_path

def copy_dir(self, source_path: DevicePath, target_path: HostPath, **kwargs):
def copy_dir(
self,
source_path: DevicePath,
target_path: HostPath,
exclude_integrity: Optional[set[str]] = None,
**kwargs,
):
target_path = Path(str(target_path)) # type: ignore
source_path = self.resolve_path(source_path)
exclude_integrity = exclude_integrity or set()
for file_path in self.iter_files(source_path):
rel_path = PurePosixPath(file_path).relative_to(
list(PurePosixPath(file_path).parents)[-1]
)
# handles os-path conversion
file_dest = Path(target_path / rel_path)
file_dest.parent.mkdir(parents=True, exist_ok=True)
self.pull_file(file_path, HostPath(str(file_dest)), **kwargs)
integ_exclude = (
file_path in exclude_integrity or Path(file_path).name in exclude_integrity
)
integrity = kwargs.pop("verify_integrity", True) and not integ_exclude
self.pull_file(
file_path, HostPath(str(file_dest)), verify_integrity=integrity, **kwargs
)

def push_file(
self, source_path: HostPath, target_path: DevicePath, binary: bool = False, **kwargs
Expand Down Expand Up @@ -201,33 +226,88 @@ def write_file(
self._pydevice.cmd("f.close()")
self._pydevice.cmd("import gc; gc.collect()")

def _compute_chunk_size(self) -> int:
mem_free = int(
self._pydevice.cmd("import gc;_=gc.collect();gc.mem_free()", rtn_resp=True, silent=True)
)
return min(mem_free // 4, 4096)

def _compute_device_file_digest(
self,
device_path: DevicePath,
*,
chunk_size: int = 256,
content_size: Optional[int] = None,
pos: int = 0,
) -> str:
checksum_cmd = ";".join(
[
"import ubinascii, uhashlib, gc",
"f=open('{path}', 'rb')",
"sha = uhashlib.sha256()",
"__=[sha.update(f.read({chunk_size})) and gc.collect() for _ in range({pos}, {file_size}, {chunk_size})]",
"ubinascii.hexlify(sha.digest()).decode()",
]
)
if content_size is None:
content_size = self.uos.stat(str(device_path))[6]
sum_cmd = checksum_cmd.format(
path=str(device_path), chunk_size=chunk_size, file_size=content_size, pos=pos
)
return self._pydevice.cmd(sum_cmd, silent=True, rtn_resp=True)

@retry
def read_file(
self, target_path: DevicePath, *, consumer: PyDeviceConsumer | None = None
self,
target_path: DevicePath,
*,
consumer: PyDeviceConsumer = NoOpConsumer,
verify_integrity: bool = True,
) -> str:
target_path = self.resolve_path(target_path)
self._pydevice.cmd("import ubinascii", silent=True)
self._pydevice.cmd(f"f = open('{str(target_path)}', 'rb')", silent=True)
content_size = self._pydevice.cmd("f.seek(0,2)", rtn_resp=True, silent=True)
self._pydevice.cmd("f.seek(0)", silent=True)

read_chunk_cmd = (
"f=open('{path}', 'rb');_=f.seek({pos});ch=f.read({chunk_size});f.close();ch"
)

content_size = self.uos.stat(str(target_path))[6]
buffer = io.BytesIO()
if consumer:
consumer.on_start(name=f"Reading {target_path}", size=content_size // 2)
last_pos = 0
while True:
pos = self._pydevice.cmd("f.tell()", rtn_resp=True, silent=True)
if consumer:
consumer.on_update(size=(pos - last_pos) // 2)
last_pos = pos
if pos == content_size:
if consumer:
consumer.on_end()
break
next_chunk = self._pydevice.cmd(
f"f.read({self.BUFFER_SIZE})", rtn_resp=True, silent=True
)
pos = 0
chunk_size = self._compute_chunk_size()
consumer.on_start(
name=f"Reading {Path(target_path).name} (xsize: {chunk_size})", size=int(content_size)
)
hasher = hashlib.sha256(usedforsecurity=False)
while pos < content_size:
try:
cmd = read_chunk_cmd.format(path=str(target_path), pos=pos, chunk_size=chunk_size)
next_chunk = self._pydevice.cmd(cmd, rtn_resp=True, silent=True)
except Exception as e:
consumer.on_message(f"Failed to read chunk; retrying ({e})")
self.reset()
chunk_size = self._compute_chunk_size()
continue
if len(next_chunk) == 0:
consumer.on_message("Failed to read chunk (no data); retrying.")
self.reset()
continue
hasher.update(next_chunk)
buffer.write(next_chunk)
self._pydevice.cmd("f.close()")
pos += chunk_size
consumer.on_update(size=len(next_chunk))
consumer.on_end()

if verify_integrity:
device_sum = self._compute_device_file_digest(
target_path, chunk_size=chunk_size, content_size=content_size, pos=0
)
digest = hasher.hexdigest()
if device_sum != digest:
raise PyDeviceFileIntegrityError(
device_path=Path(target_path).name, device_sum=device_sum, digest=digest
)
consumer.on_message(f"Verified integrity: {Path(target_path).name}")

value = buffer.getvalue().decode()
return value

Expand Down

0 comments on commit a3cbf31

Please sign in to comment.