Skip to content

Commit

Permalink
change cfg loading to fix cpu mode
Browse files Browse the repository at this point in the history
  • Loading branch information
Zarxrax committed Mar 4, 2024
1 parent ec55980 commit cd53635
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 18 deletions.
4 changes: 4 additions & 0 deletions cutie/config/config.py
@@ -0,0 +1,4 @@
from hydra import compose, initialize

initialize(version_base='1.3.2', config_path="", job_name="gui")
global_config = compose(config_name="gui_config")
7 changes: 3 additions & 4 deletions cutie_roto.py
Expand Up @@ -13,11 +13,11 @@
sys.exit("Please execute \"install_pytorch.bat\" to install Pytorch, then try again.")

from omegaconf import open_dict
from hydra import compose, initialize
import logging
from PySide6.QtWidgets import QApplication, QDialog
import qdarktheme

from cutie.config.config import global_config
from gui.main_controller import MainController
from gui.launcher_gui import Launcher_Dialog

Expand Down Expand Up @@ -54,9 +54,8 @@ def get_arguments():
if __name__ in "__main__":
log = logging.getLogger()

# getting hydra's config without using its decorator
initialize(version_base='1.3.2', config_path="cutie/config", job_name="gui")
cfg = compose(config_name="gui_config")
# get the config
cfg = global_config

# input arguments
args = get_arguments()
Expand Down
26 changes: 12 additions & 14 deletions gui/interactive_utils.py
Expand Up @@ -6,6 +6,7 @@
import torch
import torch.nn.functional as F
from cutie.utils.palette import davis_palette
from cutie.config.config import global_config


def image_to_torch(frame: np.ndarray, device: str = 'cuda'):
Expand All @@ -26,22 +27,19 @@ def index_numpy_to_one_hot_torch(mask: np.ndarray, num_classes: int):
return F.one_hot(mask, num_classes=num_classes).permute(2, 0, 1).float()



# Some constants for visualization
"""
try:
if torch.cuda.is_available():
device = torch.device("cuda")
elif torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cpu")
except:
# set torch device for interactice segmentation
cfg = global_config
if cfg.force_cpu:
device = torch.device("cpu")
"""
# get existing device instead of detecting again
device = torch.cuda.current_device()
elif torch.cuda.is_available():
device = torch.device("cuda")
elif torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cpu")
#print(f'Using click device: {device}')

# Some constants for visualization
color_map_np = np.frombuffer(davis_palette, dtype=np.uint8).reshape(-1, 3).copy()
# scales for better visualization
color_map_np = (color_map_np.astype(np.float32) * 1.5).clip(0, 255).astype(np.uint8)
Expand Down

0 comments on commit cd53635

Please sign in to comment.