Skip to content

Commit

Permalink
Fix known issues
Browse files Browse the repository at this point in the history
1. Will always rescale to 99.9% when clip mode is set to "rescale"
2. Window will stop to function when opening log file on non-Windows
3. Optimize ETA algorithm
  • Loading branch information
CarlGao4 committed Nov 25, 2023
1 parent 5af7a39 commit 6391d36
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 60 deletions.
70 changes: 40 additions & 30 deletions GUI/GuiMain.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
You should have received a copy of the GNU General Public License \
along with this program. If not, see <https://www.gnu.org/licenses/>."""

__version__ = "1.0"
__version__ = "1.0.1"

import shared

Expand Down Expand Up @@ -96,14 +96,18 @@
import platform
import psutil
import random
import shlex
import sys
import threading
import time
import traceback
import webbrowser

import separator
from PySide6_modified import Action, ModifiedQLabel, ProgressDelegate

file_queue_lock = threading.Lock()


class StartingWindow(QMainWindow):
finish_sgn = Signal(float)
Expand Down Expand Up @@ -276,10 +280,10 @@ def open_log(self):
if sys.platform == "win32":
os.startfile(str(shared.logfile))
elif sys.platform == "darwin":
os.system("open " + str(shared.logfile))
os.system(shlex.join(["open", str(shared.logfile), "&"]))
else:
try:
os.system("xdg-open " + str(shared.logfile))
os.system(shlex.join(["xdg-open", str(shared.logfile), "&"]))
except:
if (
self.m.question(
Expand Down Expand Up @@ -352,7 +356,7 @@ def refreshModels(self):
self.select_combobox.addItems(self.models)
self.setEnabled(True)

@shared.thread_wrapper
@shared.thread_wrapper(daemon=True)
def loadModel(self):
global main_window

Expand Down Expand Up @@ -466,7 +470,7 @@ def __init__(self):
global main_window

super().__init__()
self.setTitle("Separating parameters")
self.setTitle("Separation parameters")

self.device_label = QLabel()
self.device_label.setText("Device:")
Expand Down Expand Up @@ -643,7 +647,7 @@ def browseLocation(self):
if p:
self.loc_input.setText(p)

@shared.thread_wrapper
@shared.thread_wrapper(daemon=True)
def save(self, file, tensor, save_func, item, finishCallback):
self.saving += 1
finishCallback(shared.FileStatus.Writing, item)
Expand All @@ -660,7 +664,10 @@ def save(self, file, tensor, save_func, item, finishCallback):
else:
file_path = pathlib.Path(file_path_str)
if self.clip_mode.currentText() == "rescale":
data = stem_data / stem_data.abs().max() * 0.999
if (peak := stem_data.abs().max()) > 0.999:
data = stem_data / peak * 0.999
else:
data = stem_data
elif self.clip_mode.currentText() == "clamp":
data = stem_data.clamp(-0.999, 0.999)
else:
Expand Down Expand Up @@ -793,20 +800,21 @@ def addFiles(self, files):
dirpath_path = pathlib.Path(dirpath)
self.addFiles([str(dirpath_path / filename) for filename in filenames])
else:
row = self.table.rowCount()
self.table.insertRow(row)
if self.show_full_path:
self.table.setItem(row, 0, QTableWidgetItem(str(file)))
else:
self.table.setItem(row, 0, QTableWidgetItem(file.name))
delegate = ProgressDelegate()
self.table.setItemDelegateForColumn(1, delegate)
self.table.setItem(row, 1, QTableWidgetItem())
self.table.item(row, 0).setToolTip(str(file))
self.table.item(row, 0).setData(Qt.ItemDataRole.UserRole, file)
self.table.item(row, 1).setData(Qt.ItemDataRole.UserRole, [shared.FileStatus.Queued])
self.table.item(row, 1).setData(ProgressDelegate.ProgressRole, 0)
self.table.item(row, 1).setData(ProgressDelegate.TextRole, "Queued")
with file_queue_lock:
row = self.table.rowCount()
self.table.insertRow(row)
if self.show_full_path:
self.table.setItem(row, 0, QTableWidgetItem(str(file)))
else:
self.table.setItem(row, 0, QTableWidgetItem(file.name))
delegate = ProgressDelegate()
self.table.setItemDelegateForColumn(1, delegate)
self.table.setItem(row, 1, QTableWidgetItem())
self.table.item(row, 0).setToolTip(str(file))
self.table.item(row, 0).setData(Qt.ItemDataRole.UserRole, file)
self.table.item(row, 1).setData(Qt.ItemDataRole.UserRole, [shared.FileStatus.Queued])
self.table.item(row, 1).setData(ProgressDelegate.ProgressRole, 0)
self.table.item(row, 1).setData(ProgressDelegate.TextRole, "Queued")

def tableHeaderClicked(self, index):
if index == 0:
Expand Down Expand Up @@ -842,6 +850,8 @@ def removeFiles(self):
shared.FileStatus.Paused,
shared.FileStatus.Queued,
shared.FileStatus.Finished,
shared.FileStatus.Cancelled,
shared.FileStatus.Failed,
]:
continue
self.table.removeRow(i)
Expand Down Expand Up @@ -877,13 +887,14 @@ def moveTop(self):
self.table.removeRow(index + 1)

def getFirstQueued(self):
self.setEnabled(False)
for i in range(self.table.rowCount()):
if self.table.item(i, 1).data(Qt.ItemDataRole.UserRole)[0] == shared.FileStatus.Queued:
self.setEnabled(True)
return i
self.setEnabled(True)
return None
with file_queue_lock:
self.setEnabled(False)
for i in range(self.table.rowCount()):
if self.table.item(i, 1).data(Qt.ItemDataRole.UserRole)[0] == shared.FileStatus.Queued:
self.setEnabled(True)
return i
self.setEnabled(True)
return None


class SeparationControl(QGroupBox):
Expand Down Expand Up @@ -993,7 +1004,6 @@ def startSeparation(self):
if "{stem}" not in main_window.save_options.loc_input.text():
main_window.showWarning.emit("Warning", '"{stem}" not included in save location. May cause overwrite.')
self.start_button.setEnabled(False)
index = main_window.file_queue.getFirstQueued()
if (index := main_window.file_queue.getFirstQueued()) is None:
self.start_button.setEnabled(True)
main_window.setStatusText.emit("No more file to separate")
Expand Down Expand Up @@ -1031,7 +1041,7 @@ def stopCurrent(self):
if shared.debug:
log = sys.stderr
else:
log = open(str(shared.logfile / log_filename), mode="at")
log = open(str(shared.logfile / log_filename), mode="at", encoding="utf-8")
sys.stderr = log
handler = logging.StreamHandler(log)
try:
Expand Down
18 changes: 10 additions & 8 deletions GUI/separator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
used_cuda = False


@shared.thread_wrapper
@shared.thread_wrapper(daemon=True)
def starter(update_status: tp.Callable[[str], None], finish: tp.Callable[[float], None]):
global torch, demucs, audio
import torch
Expand Down Expand Up @@ -229,25 +229,27 @@ def updateProgress(self, progress_dict):
current_time = time.time()
self.time_hists.append((current_time, progress))
if current_time - self.last_update_eta > 1:
self.last_update_eta = current_time
while len(self.time_hists) >= 10 and current_time - self.time_hists[0][0] > 15:
while len(self.time_hists) >= 20 and current_time - self.time_hists[0][0] > 15:
self.time_hists.pop(0)
if len(self.time_hists) >= 2:
if len(self.time_hists) >= 2 and progress != self.time_hists[0][1]:
eta = int((1 - progress) / (progress - self.time_hists[0][1]) * (current_time - self.time_hists[0][0]))
else:
eta = 1000000000
if eta >= 86400:
if eta >= 99 * 86400:
eta_str = "--:--:--:--"
elif eta >= 86400:
eta_str = "%d:" % (eta // 86400)
eta %= 86400
eta_str += time.strftime("%H:%M:%S", time.gmtime(eta))
else:
eta_str = ""
eta_str += time.strftime("%H:%M:%S", time.gmtime(eta))
eta_str = time.strftime("%H:%M:%S", time.gmtime(eta))
self.updateStatus("Separating audio: %s | ETA %s" % (self.file.name, eta_str))
self.last_update_eta = current_time

def save_callback(self, *args):
audio.save_audio(*args, self.separator.samplerate, self.updateStatus)

@shared.thread_wrapper
@shared.thread_wrapper(daemon=True)
def separate(
self,
file,
Expand Down
83 changes: 61 additions & 22 deletions GUI/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.

import __main__
import functools
import json
import logging
import os
Expand All @@ -23,6 +24,7 @@
import sys
import threading
import traceback
import urllib.request


homeDir = pathlib.Path(__main__.__file__).resolve().parent
Expand All @@ -43,6 +45,8 @@
Please remember that absolute path must start from the root dir (like "C:\\xxx" on Windows or "/xxx" on macOS and \
Linux) in case something unexpected would happen."""

update_url = "https://api.github.com/repos/CarlGao4/Demucs-GUI/releases/latest"


def HSize(size):
s = size
Expand Down Expand Up @@ -74,10 +78,14 @@ def InitializeFolder():
if settingsFile.exists():
try:
with open(str(settingsFile), mode="rt", encoding="utf8") as f:
settings = json.loads(f.read())
settings_data = f.read()
settings = json.loads(settings_data)
if type(settings) != dict:
raise TypeError
except:
print("Settings file is corrupted, reset to default", file=sys.stderr)
print("Error message:\n%s" % traceback.format_exc(), file=sys.stderr)
print("Settings file content:\n%s" % settings_data, file=sys.stderr)
settings = {}
else:
settings = {}
Expand All @@ -86,9 +94,15 @@ def InitializeFolder():
def SetSetting(attr, value):
global settings, settingsFile
logging.debug('(%s) Set setting "%s" to %s' % (traceback.extract_stack()[-2].name, attr, str(value)))
if attr in settings and settings[attr] == value:
logging.debug("Setting not changed, ignored")
return
settings[attr] = value
with open(str(settingsFile), mode="wt", encoding="utf8") as f:
f.write(json.dumps(settings, separators=(",", ":")))
try:
with open(str(settingsFile), mode="wt", encoding="utf8") as f:
f.write(json.dumps(settings, separators=(",", ":")))
except:
logging.warning("Failed to save settings:\n%s" % traceback.format_exc())


def GetSetting(attr, default=None, autoset=True):
Expand Down Expand Up @@ -118,29 +132,54 @@ def Popen(*args, **kwargs):
kwargs["creationflags"] = subprocess.CREATE_NO_WINDOW
kwargs["stdout"] = subprocess.PIPE
kwargs["stderr"] = subprocess.PIPE
kwargs["stdin"] = subprocess.PIPE
return subprocess.Popen(*args, **kwargs)


def thread_wrapper(func):
if not hasattr(thread_wrapper, "index"):
thread_wrapper.index = 0
def thread_wrapper(*args_thread, **kw_thread):
if "target" in kw_thread:
kw_thread.pop("target")
if "args" in kw_thread:
kw_thread.pop("args")
if "kwargs" in kw_thread:
kw_thread.pop("kwargs")

def thread_func_wrapper(func):
if not hasattr(thread_wrapper, "index"):
thread_wrapper.index = 0

def wrapper(*args, **kwargs):
thread_wrapper.index += 1
@functools.wraps(func)
def wrapper(*args, **kwargs):
thread_wrapper.index += 1

def run_and_log(idx=thread_wrapper.index):
logging.info(
"[%d] Thread %s (%s) starts" % (idx, func.__name__, pathlib.Path(func.__code__.co_filename).name)
)
try:
func(*args, **kwargs)
finally:
def run_and_log(idx=thread_wrapper.index):
logging.info(
"[%d] Thread %s (%s) ends" % (idx, func.__name__, pathlib.Path(func.__code__.co_filename).name)
"[%d] Thread %s (%s) starts" % (idx, func.__name__, pathlib.Path(func.__code__.co_filename).name)
)

t = threading.Thread(target=run_and_log, daemon=True)
t.start()
return t

return wrapper
try:
func(*args, **kwargs)
finally:
logging.info(
"[%d] Thread %s (%s) ends" % (idx, func.__name__, pathlib.Path(func.__code__.co_filename).name)
)

t = threading.Thread(target=run_and_log, *args_thread, **kw_thread)
t.start()
return t

return wrapper

return thread_func_wrapper


@thread_wrapper(daemon=True)
def checkUpdate(callback):
try:
logging.info("Checking for updates...")
with urllib.request.urlopen(update_url) as f:
data = json.loads(f.read())
logging.info("Latest version: %s" % data["tag_name"])
callback(data["tag_name"])
except:
logging.warning("Failed to check for updates:\n%s" % traceback.format_exc())
callback(None)

0 comments on commit 6391d36

Please sign in to comment.