Skip to content

Commit

Permalink
fix race in config reloader
Browse files Browse the repository at this point in the history
nothing dangerous, just confusing log messages if an
admin hammers the reload button 100+ times per second,
or another linux process rapidly sends SIGUSR1
  • Loading branch information
9001 committed Feb 28, 2024
1 parent 8413ed6 commit 096de50
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 24 deletions.
17 changes: 10 additions & 7 deletions copyparty/svchub.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __init__(
self.stopping = False
self.stopped = False
self.reload_req = False
self.reloading = False
self.reloading = 0
self.stop_cond = threading.Condition()
self.nsigs = 3
self.retcode = 0
Expand Down Expand Up @@ -662,21 +662,24 @@ def start_zeroconf(self) -> None:
self.log("root", "ssdp startup failed;\n" + min_ex(), 3)

def reload(self) -> str:
if self.reloading:
return "cannot reload; already in progress"
with self.up2k.mutex:
if self.reloading:
return "cannot reload; already in progress"
self.reloading = 1

self.reloading = True
Daemon(self._reload, "reloading")
return "reload initiated"

def _reload(self) -> None:
self.log("root", "reload scheduled")
with self.up2k.mutex:
if self.reloading != 1:
return
self.reloading = 2
self.log("root", "reloading config")
self.asrv.reload()
self.up2k.reload()
self.broker.reload()

self.reloading = False
self.reloading = 0

def stop_thr(self) -> None:
while not self.stop_req:
Expand Down
53 changes: 36 additions & 17 deletions copyparty/up2k.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,10 +200,10 @@ def init_vols(self) -> None:
Daemon(self.deferred_init, "up2k-deferred-init")

def reload(self) -> None:
self.gid += 1
self.log("reload #{} initiated".format(self.gid))
"""mutex me"""
self.log("reload #{} scheduled".format(self.gid + 1))
all_vols = self.asrv.vfs.all_vols
self.rescan(all_vols, list(all_vols.keys()), True, False)
self._rescan(all_vols, list(all_vols.keys()), True, False)

def deferred_init(self) -> None:
all_vols = self.asrv.vfs.all_vols
Expand Down Expand Up @@ -232,7 +232,7 @@ def deferred_init(self) -> None:
for n in range(max(1, self.args.mtag_mt)):
Daemon(self._tagger, "tagger-{}".format(n))

Daemon(self._run_all_mtp, "up2k-mtp-init")
Daemon(self._run_all_mtp, "up2k-mtp-init", (self.gid,))

def log(self, msg: str, c: Union[int, str] = 0) -> None:
if self.pp:
Expand Down Expand Up @@ -337,14 +337,21 @@ def _get_volsize(self, ptop: str) -> tuple[int, int]:
def rescan(
self, all_vols: dict[str, VFS], scan_vols: list[str], wait: bool, fscan: bool
) -> str:
with self.mutex:
return self._rescan(all_vols, scan_vols, wait, fscan)

def _rescan(
self, all_vols: dict[str, VFS], scan_vols: list[str], wait: bool, fscan: bool
) -> str:
"""mutex me"""
if not wait and self.pp:
return "cannot initiate; scan is already in progress"

args = (all_vols, scan_vols, fscan)
self.gid += 1
Daemon(
self.init_indexes,
"up2k-rescan-{}".format(scan_vols[0] if scan_vols else "all"),
args,
(all_vols, scan_vols, fscan, self.gid),
)
return ""

Expand Down Expand Up @@ -575,19 +582,32 @@ def _expr_idx_filter(self, flags: dict[str, Any]) -> tuple[bool, dict[str, Any]]
return True, ret

def init_indexes(
self, all_vols: dict[str, VFS], scan_vols: list[str], fscan: bool
self, all_vols: dict[str, VFS], scan_vols: list[str], fscan: bool, gid: int = 0
) -> bool:
gid = self.gid
while self.pp and gid == self.gid:
time.sleep(0.1)
if not gid:
with self.mutex:
gid = self.gid

if gid != self.gid:
return False
nspin = 0
while True:
nspin += 1
if nspin > 1:
time.sleep(0.1)

with self.mutex:
if gid != self.gid:
return False

if self.pp:
continue

self.pp = ProgressPrinter(self.log, self.args)

break

if gid:
self.log("reload #{} running".format(self.gid))
self.log("reload #%d running" % (gid,))

self.pp = ProgressPrinter(self.log, self.args)
vols = list(all_vols.values())
t0 = time.time()
have_e2d = False
Expand Down Expand Up @@ -775,7 +795,7 @@ def init_indexes(
if self.mtag:
t = "online (running mtp)"
if scan_vols:
thr = Daemon(self._run_all_mtp, "up2k-mtp-scan", r=False)
thr = Daemon(self._run_all_mtp, "up2k-mtp-scan", (gid,), r=False)
else:
self.pp = None
t = "online, idle"
Expand Down Expand Up @@ -1809,8 +1829,7 @@ def _flush_mpool(self, wcur: "sqlite3.Cursor") -> list[str]:
self.pending_tags = []
return ret

def _run_all_mtp(self) -> None:
gid = self.gid
def _run_all_mtp(self, gid: int) -> None:
t0 = time.time()
for ptop, flags in self.flags.items():
if "mtp" in flags:
Expand Down

0 comments on commit 096de50

Please sign in to comment.