Skip to content
Permalink
Branch: master
Find file Copy path
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
246 lines (198 sloc) 7.42 KB
import network
import socket
import machine
import ujson
import uhashlib
import ubinascii
import gc
import pycom
import os
import machine
# Try to get version number
try:
from OTA_VERSION import VERSION
except ImportError:
VERSION = '1.0.0'
class OTA():
# The following two methods need to be implemented in a subclass for the
# specific transport mechanism e.g. WiFi
def connect(self):
raise NotImplementedError()
def get_data(self, req, dest_path=None, hash=False):
raise NotImplementedError()
# OTA methods
def get_current_version(self):
return VERSION
def get_update_manifest(self):
req = "manifest.json?current_ver={}".format(self.get_current_version())
manifest_data = self.get_data(req).decode()
manifest = ujson.loads(manifest_data)
gc.collect()
return manifest
def update(self):
manifest = self.get_update_manifest()
if manifest is None:
print("Already on the latest version")
return
# Download new files and verify hashes
for f in manifest['new'] + manifest['update']:
# Upto 5 retries
for _ in range(5):
try:
self.get_file(f)
break
except Exception as e:
print(e)
print("Error downloading `{}` retrying...".format(f['URL']))
else:
raise Exception("Failed to download `{}`".format(f['URL']))
# Backup old files
# only once all files have been successfully downloaded
for f in manifest['update']:
self.backup_file(f)
# Rename new files to proper name
for f in manifest['new'] + manifest['update']:
new_path = "{}.new".format(f['dst_path'])
dest_path = "{}".format(f['dst_path'])
os.rename(new_path, dest_path)
# `Delete` files no longer required
# This actually makes a backup of the files incase we need to roll back
for f in manifest['delete']:
self.delete_file(f)
# Flash firmware
if "firmware" in manifest:
self.write_firmware(manifest['firmware'])
# Save version number
try:
self.backup_file({"dst_path": "/flash/OTA_VERSION.py"})
except OSError:
pass # There isnt a previous file to backup
with open("/flash/OTA_VERSION.py", 'w') as fp:
fp.write("VERSION = '{}'".format(manifest['version']))
from OTA_VERSION import VERSION
# Reboot the device to run the new decode
machine.reset()
def get_file(self, f):
new_path = "{}.new".format(f['dst_path'])
# If a .new file exists from a previously failed update delete it
try:
os.remove(new_path)
except OSError:
pass # The file didnt exist
# Download new file with a .new extension to not overwrite the existing
# file until the hash is verified.
hash = self.get_data(f['URL'].split("/", 3)[-1],
dest_path=new_path,
hash=True)
# Hash mismatch
if hash != f['hash']:
print(hash, f['hash'])
msg = "Downloaded file's hash does not match expected hash"
raise Exception(msg)
def backup_file(self, f):
bak_path = "{}.bak".format(f['dst_path'])
dest_path = "{}".format(f['dst_path'])
# Delete previous backup if it exists
try:
os.remove(bak_path)
except OSError:
pass # There isnt a previous backup
# Backup current file
os.rename(dest_path, bak_path)
def delete_file(self, f):
bak_path = "/{}.bak_del".format(f)
dest_path = "/{}".format(f)
# Delete previous delete backup if it exists
try:
os.remove(bak_path)
except OSError:
pass # There isnt a previous delete backup
# Backup current file
os.rename(dest_path, bak_path)
def write_firmware(self, f):
hash = self.get_data(f['URL'].split("/", 3)[-1],
hash=True,
firmware=True)
# TODO: Add verification when released in future firmware
class WiFiOTA(OTA):
def __init__(self, ssid, password, ip, port):
self.SSID = ssid
self.password = password
self.ip = ip
self.port = int(port)
def connect(self):
self.wlan = network.WLAN(mode=network.WLAN.STA)
if not self.wlan.isconnected() or self.wlan.ssid() != self.SSID:
for net in self.wlan.scan():
if net.ssid == self.SSID:
self.wlan.connect(self.SSID, auth=(network.WLAN.WPA2,
self.password))
while not self.wlan.isconnected():
machine.idle() # save power while waiting
break
else:
raise Exception("Cannot find network '{}'".format(self.SSID))
else:
# Already connected to the correct WiFi
pass
def _http_get(self, path, host):
req_fmt = 'GET /{} HTTP/1.0\r\nHost: {}\r\n\r\n'
req = bytes(req_fmt.format(path, host), 'utf8')
return req
def get_data(self, req, dest_path=None, hash=False, firmware=False):
# Connect to server
print("Requesting: {}".format(req))
s = socket.socket(socket.AF_INET,
socket.SOCK_STREAM,
socket.IPPROTO_TCP)
s.connect((self.ip, self.port))
# Request File
s.sendall(self._http_get(req, "{}:{}".format(self.ip, self.port)))
try:
content = bytearray()
fp = None
if dest_path is not None:
if firmware:
raise Exception("Cannot write firmware to a file")
fp = open(dest_path, 'wb')
if firmware:
pycom.ota_start()
h = uhashlib.sha1()
# Get data from server
result = s.recv(100)
start_writing = False
while (len(result) > 0):
# Ignore the HTTP headers
if not start_writing:
if "\r\n\r\n" in result:
start_writing = True
result = result.decode().split("\r\n\r\n")[1].encode()
if start_writing:
if firmware:
pycom.ota_write(result)
elif fp is None:
content.extend(result)
else:
fp.write(result)
if hash:
h.update(result)
result = s.recv(100)
s.close()
if fp is not None:
fp.close()
if firmware:
pycom.ota_finish()
except Exception as e:
# Since only one hash operation is allowed at Once
# ensure we close it if there is an error
if h is not None:
h.digest()
raise e
hash_val = ubinascii.hexlify(h.digest()).decode()
if dest_path is None:
if hash:
return (bytes(content), hash_val)
else:
return bytes(content)
elif hash:
return hash_val
You can’t perform that action at this time.