
Commit b8f4392
auto download/select modules
hlky committed Jul 2, 2023
1 parent 96fe570 commit b8f4392
Showing 6 changed files with 187 additions and 81 deletions.
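How the auto download/select works in this commit: instead of the user picking a compiled module file from a dropdown, the loader filters a catalogue of prebuilt AITemplate modules by OS, CUDA architecture, batch size, resolution, and model type, then downloads the best match by its sha256/url entry and caches it under that hash. A minimal sketch of the flow, assuming only the filter_modules/load_module signatures visible in the diff below (get_unet_module is a hypothetical wrapper, not part of the commit):

import os
import torch

# Platform and GPU architecture keys, as computed in AITemplate.py below.
AIT_OS = "windows" if os.name == "nt" else "linux"
major, minor = torch.cuda.get_device_capability()
if (major, minor) == (7, 5):
    AIT_CUDA = "sm75"
elif (major, minor) == (7, 0):
    AIT_CUDA = "sm70"
elif major >= 8:
    AIT_CUDA = "sm80"
else:
    raise ValueError(f"Unsupported CUDA version {major}.{minor}")

def get_unet_module(ait, noise, sd="v1", control=False):
    """Select and (down)load the best-matching prebuilt UNet module."""
    batch_size = noise.shape[0]
    resolution = max(noise.shape[2], noise.shape[3]) * 8  # latent -> pixel space
    model_type = "control_unet" if control else "unet"
    module = ait.loader.filter_modules(
        AIT_OS, sd, AIT_CUDA, batch_size, resolution, model_type
    )[0]  # first match wins
    if module["sha256"] not in ait.unet:  # module cache keyed by content hash
        ait.unet[module["sha256"]] = ait.loader.load_module(
            module["sha256"], module["url"]
        )
    return ait.unet[module["sha256"]]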
1 change: 0 additions & 1 deletion .gitignore
@@ -1,6 +1,5 @@
tmp/
ait_tmp/
modules/
*.xz
*.png
test*.py
200 changes: 126 additions & 74 deletions AITemplate/AITemplate.py
@@ -45,7 +45,19 @@ def cleanup_temp_library(prefix="ait", extension=".so"):
current_loaded_model = None
vram_state = None

AITemplate = AIT()
modules_path = str(modules_dir).replace("\\", "/")
AITemplate = AIT(modules_path)
AIT_OS = "windows" if os.name == "nt" else "linux"
cuda = torch.cuda.get_device_capability()
if cuda[0] == 7 and cuda[1] == 5:
AIT_CUDA = "sm75"
elif cuda[0] == 7 and cuda[1] == 0:
AIT_CUDA = "sm70"
elif cuda[0] >= 8:
AIT_CUDA = "sm80"
else:
raise ValueError(f"Unsupported CUDA version {cuda[0]}.{cuda[1]}")
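# Background, for reference: torch.cuda.get_device_capability() returns the
# GPU's (major, minor) compute capability, e.g. (7, 0) for Volta/V100,
# (7, 5) for Turing cards such as the RTX 2080, and (8, 6) for Ampere cards
# such as the RTX 3090; every 8.x-or-newer device is served by the "sm80"
# builds here.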


def get_full_path(folder_name, filename):
global folder_names_and_paths
@@ -129,7 +141,7 @@ def get_filename_list(folder_name)
def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
use_aitemplate = isinstance(model, tuple)
if use_aitemplate:
model, keep_loaded, aitemplate_path = model
model, keep_loaded = model
device = comfy.model_management.get_torch_device()
latent_image = latent["samples"]

@@ -150,7 +162,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
previewer = latent_preview.get_previewer(device, model.model.latent_format)

if use_aitemplate:
model = model, keep_loaded, aitemplate_path
model = model, keep_loaded

pbar = comfy.utils.ProgressBar(steps)
def callback(step, x0, x, total_steps):
@@ -181,6 +193,14 @@ def maximum_batch_area():

comfy.model_management.maximum_batch_area = maximum_batch_area

def load_additional_models(positive, negative):
"""loads additional models in positive and negative conditioning"""
control_nets = comfy.sample.get_models_from_cond(positive, "control") + comfy.sample.get_models_from_cond(negative, "control")
gligen = comfy.sample.get_models_from_cond(positive, "gligen") + comfy.sample.get_models_from_cond(negative, "gligen")
gligen = [x[1] for x in gligen]
models = control_nets + gligen
return models
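# For reference: each conditioning entry is a [tensor, options] pair, and
# get_models_from_cond collects the models stored under the "control" and
# "gligen" keys of those options dicts. This local copy appears to mirror
# comfy.sample.load_additional_models minus the GPU transfer, leaving the
# collected models in place until sampling needs them.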


def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
global current_loaded_model
@@ -189,8 +209,10 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
global vram_st
use_aitemplate = isinstance(model, tuple)
if use_aitemplate:
model, keep_loaded, aitemplate_path = model
device = comfy.model_management.get_torch_device()
model, keep_loaded = model
device = torch.device("cpu")
else:
device = comfy.model_management.get_torch_device()

has_loaded = False
if use_aitemplate:
@@ -210,42 +232,44 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
if "control" in x:
control = True
break
if context_dim == 1024 and "v2" not in aitemplate_path:
raise Exception("You are trying to use a SD2.x model with a SD1.x AITemplate. Please use a SD2.x AITemplate instead.")
elif context_dim == 768 and "v1" not in aitemplate_path:
raise Exception("You are trying to use a SD1.x model with a SD2.x AITemplate. Please use a SD1.x AITemplate instead.")
if context_dim != 1024 and context_dim != 768:
raise Exception(f"Unsupported context dimension: {context_dim}. Currently only SD1.x and SD2.x are supported.")
sd = "v1"
if context_dim == 1024:
sd = "v2"
batch_size = noise.shape[0]
resolution = max(noise.shape[2], noise.shape[3]) * 8
model_type = "unet"
if control:
if "control_unet" not in aitemplate_path:
raise Exception("You are trying to use ControlNet with a regular UNet module. Please use a control_unet instead.")
else:
if "control_unet" in aitemplate_path:
raise Exception("You are trying to use a regular UNet module with a control_unet. Please use a regular UNet module instead.")
if "unet" not in AITemplate.modules or keep_loaded == "disable":
AITemplate.modules["unet"] = AITemplate.loader.load(aitemplate_path)
model_type = "control_unet"
module = AITemplate.loader.filter_modules(AIT_OS, sd, AIT_CUDA, batch_size, resolution, model_type)[0]
if keep_loaded == "disable":
if len(AITemplate.unet.keys()) > 0:
to_delete = list(AITemplate.unet.keys())
for x in to_delete:
del AITemplate.unet[x]
if module['sha256'] not in AITemplate.unet:
AITemplate.unet[module['sha256']] = AITemplate.loader.load_module(module['sha256'], module['url'])
has_loaded = True

if noise_mask is not None:
noise_mask = comfy.sample.prepare_mask(noise_mask, noise.shape, device)

if use_aitemplate:
apply_aitemplate_weights = has_loaded or current_loaded_model != model or ("unet" in AITemplate.modules and vram_st != comfy.model_management.VRAMState.DISABLED)
if vram_state is None:
vram_state = vram_st
vram_st = comfy.model_management.VRAMState.DISABLED
apply_aitemplate_weights = has_loaded or current_loaded_model != model or keep_loaded == "disable"
try:
model.patch_model()
except Exception as e:
model.unpatch_model()
raise e
else:
if vram_state is not None:
vram_st = vram_state
comfy.model_management.load_model_gpu(model)
comfy.model_management.load_model_gpu(model)
real_model = model.model

if use_aitemplate:
current_loaded_model = model
real_model.alphas_cumprod = real_model.alphas_cumprod.float()
if apply_aitemplate_weights:
AITemplate.modules["unet"] = AITemplate.loader.apply_unet(
aitemplate_module=AITemplate.modules["unet"],
AITemplate.unet[module['sha256']] = AITemplate.loader.apply_unet(
aitemplate_module=AITemplate.unet[module['sha256']],
unet=AITemplate.loader.compvis_unet(real_model.state_dict()),
in_channels=real_model.diffusion_model.in_channels,
conv_in_key="conv_in_weight",
@@ -257,11 +281,11 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
positive_copy = comfy.sample.broadcast_cond(positive, noise.shape[0], device)
negative_copy = comfy.sample.broadcast_cond(negative, noise.shape[0], device)

models = comfy.sample.load_additional_models(positive, negative)
models = load_additional_models(positive, negative)

sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
if use_aitemplate:
model_wrapper = AITemplateModelWrapper(AITemplate.modules["unet"], real_model.alphas_cumprod)
model_wrapper = AITemplateModelWrapper(AITemplate.unet[module['sha256']], real_model.alphas_cumprod)
sampler.model_denoise = comfy.samplers.CFGNoisePredictor(model_wrapper)
if real_model.parameterization == "v":
sampler.model_wrap = comfy.samplers.CompVisVDenoiser(sampler.model_denoise, quantize=True)
@@ -276,8 +300,12 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
comfy.sample.cleanup_additional_models(models)

if use_aitemplate and keep_loaded == "disable":
AITemplate.modules.pop("unet")
del AITemplate.unet[module['sha256']]
del sampler
controlnet_keys = list(AITemplate.controlnet.keys())
for x in controlnet_keys:
del AITemplate.controlnet[x]
AITemplate.control_net = None
torch.cuda.empty_cache()
current_loaded_model = None
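# Net effect of keep_loaded == "disable": the cached UNet module, any cached
# ControlNet modules, and the sampler are dropped after sampling and
# torch.cuda.empty_cache() releases the VRAM; the next run pays for module
# re-selection and for re-applying the model weights.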

@@ -290,7 +318,7 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
class ControlNet:
def __init__(self, control_model, global_average_pooling=False, device=None):
global AITemplate
if "controlnet" in AITemplate.modules:
if AITemplate.control_net is not None:
self.aitemplate = True
else:
self.aitemplate = None
@@ -308,10 +336,23 @@ def aitemplate_controlnet(
self, latent_model_input, timesteps, encoder_hidden_states, controlnet_cond
):
global AITemplate
batch = latent_model_input.shape[0] / 2
resolution = max(latent_model_input.shape[2], latent_model_input.shape[3]) * 8
control_net_module = None
if len(AITemplate.controlnet.keys()) == 0:
module = AITemplate.loader.filter_modules(AIT_OS, "v1", AIT_CUDA, batch, resolution, "controlnet")[0]
AITemplate.controlnet[module['sha256']] = AITemplate.loader.load_module(module['sha256'], module['url'])
AITemplate.controlnet[module['sha256']] = AITemplate.loader.apply_controlnet(
aitemplate_module=AITemplate.controlnet[module['sha256']],
controlnet=AITemplate.loader.compvis_controlnet(self.control_model.state_dict())
)
control_net_module = module['sha256']
else:
control_net_module = list(AITemplate.controlnet.keys())[0]
if self.aitemplate is None:
raise RuntimeError("No aitemplate loaded")
return controlnet_inference(
exe_module=AITemplate.modules["controlnet"],
exe_module=AITemplate.controlnet[control_net_module],
latent_model_input=latent_model_input,
timesteps=timesteps,
encoder_hidden_states=encoder_hidden_states,
@@ -407,19 +448,15 @@ class AITemplateLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"aitemplate_module": (filter_files_contains(get_filename_list("aitemplate"), set(["unet"])), ),
"keep_loaded": (["enable", "disable"], ),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "load_aitemplate"

CATEGORY = "loaders"

def load_aitemplate(self, model, aitemplate_module, keep_loaded):
global AITemplate
aitemplate_path = get_full_path("aitemplate", aitemplate_module)
AITemplate.modules["unet"] = AITemplate.loader.load(aitemplate_path)
return ((model,keep_loaded,aitemplate_path),)
def load_aitemplate(self, model, keep_loaded):
return ((model,keep_loaded),)



@@ -429,7 +466,6 @@ def INPUT_TYPES(s):
return {"required": {
"pixels": ("IMAGE", ),
"vae": ("VAE", ),
"aitemplate_module": (filter_files_contains(get_filename_list("aitemplate"), set(["vae_encode"])), ),
"keep_loaded": (["enable", "disable"], ),
}}
RETURN_TYPES = ("LATENT",)
@@ -447,24 +483,31 @@ def vae_encode_crop_pixels(pixels):
pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :]
return pixels

def encode(self, vae, pixels, aitemplate_module, keep_loaded):
def encode(self, vae, pixels, keep_loaded):
global AITemplate
if "vae_encode" not in AITemplate.modules:
aitemplate_path = get_full_path("aitemplate", aitemplate_module)
AITemplate.modules["vae_encode"] = AITemplate.loader.load(aitemplate_path)
AITemplate.modules["vae_encode"] = AITemplate.loader.apply_vae(
aitemplate_module=AITemplate.modules["vae_encode"],
resolution = max(pixels.shape[1], pixels.shape[2])
model_type = "vae_encode"
if keep_loaded == "disable":
if len(AITemplate.vae.keys()) > 0:
to_delete = list(AITemplate.vae.keys())
for key in to_delete:
del AITemplate.vae[key]
module = AITemplate.loader.filter_modules(AIT_OS, "v1", AIT_CUDA, 1, resolution, model_type)[0]
if module["sha256"] not in AITemplate.vae:
AITemplate.vae[module["sha256"]] = AITemplate.loader.load_module(module["sha256"], module["url"])
AITemplate.vae[module["sha256"]] = AITemplate.loader.apply_vae(
aitemplate_module=AITemplate.vae[module["sha256"]],
vae=AITemplate.loader.compvis_vae(vae.first_stage_model.state_dict()),
encoder=True,
)
pixels = self.vae_encode_crop_pixels(pixels)
pixels = pixels[:,:,:,:3]
pixels = pixels.movedim(-1, 1)
pixels = 2. * pixels - 1.
samples = vae_inference(AITemplate.modules["vae_encode"], pixels, encoder=True)
samples = vae_inference(AITemplate.vae[module["sha256"]], pixels, encoder=True)
samples = samples.cpu()
if keep_loaded == "disable":
AITemplate.modules.pop("vae_encode")
del AITemplate.vae[module["sha256"]]
torch.cuda.empty_cache()
return ({"samples":samples}, )
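# For reference: pixels arrive as (batch, height, width, channels), so the
# resolution computed above is already in pixel space (no *8 rescale as in
# the latent-space UNet path), and the encoder module is looked up with
# batch size 1 and model_type "vae_encode".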

@@ -478,21 +521,27 @@ def INPUT_TYPES(s):
"vae": ("VAE", ),
"mask": ("MASK", ),
"grow_mask_by": ("INT", {"default": 6, "min": 0, "max": 64, "step": 1}),
"aitemplate_module": (filter_files_contains(get_filename_list("aitemplate"), set(["vae_encode"])), ),
"keep_loaded": (["enable", "disable"], ),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "encode"

CATEGORY = "latent/inpaint"

def encode(self, vae, pixels, mask, aitemplate_module, keep_loaded, grow_mask_by=6):
def encode(self, vae, pixels, mask, keep_loaded, grow_mask_by=6):
global AITemplate
if "vae_encode" not in AITemplate.modules:
aitemplate_path = get_full_path("aitemplate", aitemplate_module)
AITemplate.modules["vae_encode"] = AITemplate.loader.load(aitemplate_path)
AITemplate.modules["vae_encode"] = AITemplate.loader.apply_vae(
aitemplate_module=AITemplate.modules["vae_encode"],
resolution = max(pixels.shape[1], pixels.shape[2])
model_type = "vae_encode"
if keep_loaded == "disable":
if len(AITemplate.vae.keys()) > 0:
to_delete = list(AITemplate.vae.keys())
for key in to_delete:
del AITemplate.vae[key]
module = AITemplate.loader.filter_modules(AIT_OS, "v1", AIT_CUDA, 1, resolution, model_type)[0]
if module["sha256"] not in AITemplate.vae:
AITemplate.vae[module["sha256"]] = AITemplate.loader.load_module(module["sha256"], module["url"])
AITemplate.vae[module["sha256"]] = AITemplate.loader.apply_vae(
aitemplate_module=AITemplate.vae[module["sha256"]],
vae=AITemplate.loader.compvis_vae(vae.first_stage_model.state_dict()),
encoder=True,
)
@@ -524,10 +573,10 @@ def encode(self, vae, pixels, mask, aitemplate_module, keep_loaded, grow_mask_by
pixels = pixels[:,:,:,:3]
pixels = pixels.movedim(-1, 1)
pixels = 2. * pixels - 1.
samples = vae_inference(AITemplate.modules["vae_encode"], pixels, encoder=True)
samples = vae_inference(AITemplate.vae[module["sha256"]], pixels, encoder=True)
samples = samples.cpu()
if keep_loaded == "disable":
AITemplate.modules.pop("vae_encode")
del AITemplate.vae[module["sha256"]]
torch.cuda.empty_cache()
return ({"samples":samples, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, )

@@ -538,7 +587,6 @@ def INPUT_TYPES(s):
return {"required":
{
"vae": ("VAE",),
"aitemplate_module": (filter_files_contains(get_filename_list("aitemplate"), set(["vae_64"])), ),
"keep_loaded": (["enable", "disable"], ),
"samples": ("LATENT", ), "vae": ("VAE", )
}
@@ -548,18 +596,25 @@ def INPUT_TYPES(s):

CATEGORY = "latent"

def decode(self, vae, aitemplate_module, keep_loaded, samples):
def decode(self, vae, keep_loaded, samples):
global AITemplate
if "vae" not in AITemplate.modules:
aitemplate_path = get_full_path("aitemplate", aitemplate_module)
AITemplate.modules["vae"] = AITemplate.loader.load(aitemplate_path)
AITemplate.modules["vae"] = AITemplate.loader.apply_vae(
aitemplate_module=AITemplate.modules["vae"],
resolution = max(samples["samples"].shape[2], samples["samples"].shape[3]) * 8
model_type = "vae"
module = AITemplate.loader.filter_modules(AIT_OS, "v1", AIT_CUDA, 1, resolution, model_type)[0]
if keep_loaded == "disable":
if len(AITemplate.vae.keys()) > 0:
to_delete = list(AITemplate.vae.keys())
for key in to_delete:
del AITemplate.vae[key]
if module["sha256"] not in AITemplate.vae:
AITemplate.vae[module["sha256"]] = AITemplate.loader.load_module(module["sha256"], module["url"])
AITemplate.vae[module["sha256"]] = AITemplate.loader.apply_vae(
aitemplate_module=AITemplate.vae[module["sha256"]],
vae=AITemplate.loader.compvis_vae(vae.first_stage_model.state_dict()),
)
output = (torch.clamp((vae_inference(AITemplate.modules["vae"], samples["samples"]) + 1.0) / 2.0, min=0.0, max=1.0).cpu().movedim(1,-1), )
output = (torch.clamp((vae_inference(AITemplate.vae[module["sha256"]], samples["samples"]) + 1.0) / 2.0, min=0.0, max=1.0).cpu().movedim(1,-1), )
if keep_loaded == "disable":
AITemplate.modules.pop("vae")
del AITemplate.vae[module["sha256"]]
torch.cuda.empty_cache()
return output
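# For reference: here resolution is the largest latent side times 8 (latents
# are 1/8 the pixel resolution), and model_type "vae" selects decoder builds
# while the encode nodes above use "vae_encode".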

@@ -568,24 +623,21 @@ class AITemplateControlNetLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "control_net": ("CONTROL_NET",),
"aitemplate_module": (filter_files_contains(get_filename_list("aitemplate"), set(["controlnet"])), ),
"keep_loaded": (["enable", "disable"], )
}}
RETURN_TYPES = ("CONTROL_NET",)
FUNCTION = "load_aitemplate_controlnet"

CATEGORY = "loaders"

def load_aitemplate_controlnet(self, control_net, aitemplate_module):
def load_aitemplate_controlnet(self, control_net, keep_loaded):
global AITemplate
aitemplate_path = get_full_path("aitemplate", aitemplate_module)
AITemplate.modules["controlnet"] = AITemplate.loader.load(aitemplate_path)
AITemplate.modules["controlnet"] = AITemplate.loader.apply_controlnet(
aitemplate_module=AITemplate.modules["controlnet"],
controlnet=AITemplate.loader.compvis_controlnet(control_net.control_model.state_dict())
)
AITemplate.control_net = keep_loaded
control_net.control_model = control_net.control_model.to("cpu")
control_net.device = torch.device("cpu")
torch.cuda.empty_cache()
return (control_net,)
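# For reference: this loader no longer takes a module file. It stores
# keep_loaded on AITemplate.control_net (the flag ControlNet.__init__ checks)
# and parks the torch control model on the CPU; the weights are applied to an
# auto-selected AITemplate module later, inside aitemplate_controlnet().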


class AITemplateEmptyLatentImage:
def __init__(self, device="cpu"):
self.device = device
(Diffs for the remaining 4 changed files are not shown.)
