Skip to content
This repository has been archived by the owner on Aug 16, 2023. It is now read-only.

Commit

Permalink
Remove basicsr's need
Browse files Browse the repository at this point in the history
  • Loading branch information
Fannovel16 committed Apr 20, 2023
1 parent 31ecac7 commit 6b07310
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 11 deletions.
2 changes: 1 addition & 1 deletion v11/hed_v11/__init__.py
Expand Up @@ -59,7 +59,7 @@ def __init__(self):
remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/ControlNetHED.pth"
modelpath = os.path.join(annotator_ckpts_path, "ControlNetHED.pth")
if not os.path.exists(modelpath):
from basicsr.utils.download_util import load_file_from_url
from comfy_controlnet_preprocessors.util import load_file_from_url
load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
self.netNetwork = ControlNetHED_Apache2().float().to(model_management.get_torch_device()).eval()
self.netNetwork.load_state_dict(torch.load(modelpath))
Expand Down
2 changes: 1 addition & 1 deletion v11/normalbae/__init__.py
Expand Up @@ -12,14 +12,14 @@
from comfy_controlnet_preprocessors.util import annotator_ckpts_path
import torchvision.transforms as transforms
import model_management
from comfy_controlnet_preprocessors.util import load_file_from_url


class NormalBaeDetector:
def __init__(self):
remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/scannet.pt"
modelpath = os.path.join(annotator_ckpts_path, "scannet.pt")
if not os.path.exists(modelpath):
from basicsr.utils.download_util import load_file_from_url
load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
args = types.SimpleNamespace()
args.mode = 'client'
Expand Down
4 changes: 1 addition & 3 deletions v11/openpose_v11/__init__.py
Expand Up @@ -14,6 +14,7 @@
from .hand import Hand
from .face import Face
from comfy_controlnet_preprocessors import annotator_ckpts_path
from comfy_controlnet_preprocessors.util import load_file_from_url


body_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/body_pose_model.pth"
Expand Down Expand Up @@ -48,15 +49,12 @@ def __init__(self):
face_modelpath = os.path.join(annotator_ckpts_path, "facenet.pth")

if not os.path.exists(body_modelpath):
from basicsr.utils.download_util import load_file_from_url
load_file_from_url(body_model_path, model_dir=annotator_ckpts_path)

if not os.path.exists(hand_modelpath):
from basicsr.utils.download_util import load_file_from_url
load_file_from_url(hand_model_path, model_dir=annotator_ckpts_path)

if not os.path.exists(face_modelpath):
from basicsr.utils.download_util import load_file_from_url
load_file_from_url(face_model_path, model_dir=annotator_ckpts_path)

self.body_estimation = Body(body_modelpath)
Expand Down
3 changes: 1 addition & 2 deletions v11/pidinet_v11/__init__.py
Expand Up @@ -6,15 +6,14 @@
import numpy as np
from einops import rearrange
from .model import pidinet
from comfy_controlnet_preprocessors.util import annotator_ckpts_path, safe_step
from comfy_controlnet_preprocessors.util import annotator_ckpts_path, safe_step, load_file_from_url


class PidiNetDetector:
def __init__(self):
remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/table5_pidinet.pth"
modelpath = os.path.join(annotator_ckpts_path, "table5_pidinet.pth")
if not os.path.exists(modelpath):
from basicsr.utils.download_util import load_file_from_url
load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
self.netNetwork = pidinet()
self.netNetwork.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(modelpath)['state_dict'].items()})
Expand Down
30 changes: 29 additions & 1 deletion v11/pidinet_v11/model.py
Expand Up @@ -10,7 +10,35 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from basicsr.utils import img2tensor

def img2tensor(imgs, bgr2rgb=True, float32=True):
"""Numpy array to tensor.
Args:
imgs (list[ndarray] | ndarray): Input images.
bgr2rgb (bool): Whether to change bgr to rgb.
float32 (bool): Whether to change to float32.
Returns:
list[tensor] | tensor: Tensor images. If returned results only have
one element, just return tensor.
"""

def _totensor(img, bgr2rgb, float32):
if img.shape[2] == 3 and bgr2rgb:
if img.dtype == 'float64':
img = img.astype('float32')
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = torch.from_numpy(img.transpose(2, 0, 1))
if float32:
img = img.float()
return img

if isinstance(imgs, list):
return [_totensor(img, bgr2rgb, float32) for img in imgs]
else:
return _totensor(imgs, bgr2rgb, float32)


nets = {
'baseline': {
Expand Down
4 changes: 1 addition & 3 deletions v11/zoe/__init__.py
Expand Up @@ -9,15 +9,13 @@
from einops import rearrange
from .zoedepth.models.zoedepth.zoedepth_v1 import ZoeDepth
from .zoedepth.utils.config import get_config
from comfy_controlnet_preprocessors.util import annotator_ckpts_path

from comfy_controlnet_preprocessors.util import annotator_ckpts_path, load_file_from_url

class ZoeDetector:
def __init__(self):
remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/ZoeD_M12_N.pt"
modelpath = os.path.join(annotator_ckpts_path, "ZoeD_M12_N.pt")
if not os.path.exists(modelpath):
from basicsr.utils.download_util import load_file_from_url
load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
conf = get_config("zoedepth", "infer")
model = ZoeDepth.build_from_config(conf)
Expand Down

0 comments on commit 6b07310

Please sign in to comment.