Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes to run on CPU and MPS #36

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions kandinsky2/kandinsky2_1_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ def __init__(
):
self.config = config
self.device = device
if device != "cuda":
self.config["model_config"]["use_fp16"] = False
self.use_fp16 = self.config["model_config"]["use_fp16"]
self.task_type = task_type
self.clip_image_size = config["clip_image_size"]
Expand All @@ -54,7 +56,7 @@ def __init__(
clip_mean,
clip_std,
)
self.prior.load_state_dict(torch.load(prior_path), strict=False)
self.prior.load_state_dict(torch.load(prior_path, map_location='cpu'), strict=False)
if self.use_fp16:
self.prior = self.prior.half()
self.text_encoder = TextEncoder(**self.config["text_enc_params"])
Expand Down Expand Up @@ -88,7 +90,7 @@ def __init__(

self.config["model_config"]["cache_text_emb"] = True
self.model = create_model(**self.config["model_config"])
self.model.load_state_dict(torch.load(model_path))
self.model.load_state_dict(torch.load(model_path, map_location='cpu'))
if self.use_fp16:
self.model.convert_to_fp16()
self.image_encoder = self.image_encoder.half()
Expand Down Expand Up @@ -261,12 +263,14 @@ def denoised_fun(x):
model=model_fn,
old_diffusion=diffusion,
schedule="linear",
device=self.device,
)
elif sampler == "plms_sampler":
sampler = PLMSSampler(
model=model_fn,
old_diffusion=diffusion,
schedule="linear",
device=self.device,
)
else:
raise ValueError("Only ddim_sampler and plms_sampler is available")
Expand Down
2 changes: 1 addition & 1 deletion kandinsky2/model/gaussian_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -822,7 +822,7 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
dimension equal to the length of timesteps.
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
"""
res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
res = th.from_numpy(arr).to(dtype=th.float32).to(device=timesteps.device)[timesteps]
while len(res.shape) < len(broadcast_shape):
res = res[..., None]
return res.expand(broadcast_shape)
30 changes: 15 additions & 15 deletions kandinsky2/model/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,17 +66,18 @@ def extract_into_tensor(a, t, x_shape):


class DDIMSampler(object):
def __init__(self, model, old_diffusion, schedule="linear", device="cuda", **kwargs):
    """Build a DDIM sampler around an existing diffusion.

    Args:
        model: denoising model invoked during sampling.
        old_diffusion: pre-built diffusion object supplying the beta schedule.
        schedule: name of the beta schedule (the callers pass "linear").
        device: torch device string that sampler buffers are placed on;
            defaults to "cuda" for backward compatibility.
        **kwargs: ignored; accepted for call-site flexibility.
    """
    super().__init__()
    # Total number of DDPM training timesteps (fixed by the pretrained model).
    self.ddpm_num_timesteps = 1000
    self.model, self.old_diffusion = model, old_diffusion
    self.schedule, self.device = schedule, device

def register_buffer(self, name, attr):
    """Attach `attr` to the sampler as attribute `name`.

    Tensors that are not already on `self.device` are cast to float32 and
    moved there in a single transfer (float32 presumably to keep buffers
    usable on CPU/MPS backends that lack fp16 support — matches the PR
    intent; non-tensor values are stored unchanged).
    """
    # isinstance (not type ==) is the idiomatic check and also accepts
    # torch.Tensor subclasses such as Parameter.
    if isinstance(attr, torch.Tensor) and attr.device != torch.device(self.device):
        # One .to() call performs the dtype cast and device move together.
        attr = attr.to(device=torch.device(self.device), dtype=torch.float32)
    setattr(self, name, attr)

def make_schedule(
Expand All @@ -98,7 +99,7 @@ def make_schedule(
assert (
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
), "alphas have to be defined for each timestep"
to_torch = lambda x: x.clone().detach().to(torch.float32).to("cuda")
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.device)

self.register_buffer(
"betas", to_torch(torch.from_numpy(self.old_diffusion.betas))
Expand Down Expand Up @@ -223,10 +224,9 @@ def ddim_sampling(
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
):
device = "cuda"
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
img = torch.randn(shape, device=self.device)
else:
img = x_T

Expand Down Expand Up @@ -258,7 +258,7 @@ def ddim_sampling(

for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)
ts = torch.full((b,), step, device=self.device, dtype=torch.long)

outs = self.p_sample_ddim(
img,
Expand Down Expand Up @@ -332,17 +332,18 @@ def p_sample_ddim(


class PLMSSampler(object):
def __init__(self, model, old_diffusion, schedule="linear", device="cuda", **kwargs):
    """Build a PLMS sampler around an existing diffusion.

    Args:
        model: denoising model invoked during sampling.
        old_diffusion: pre-built diffusion object supplying the beta schedule.
        schedule: name of the beta schedule (the callers pass "linear").
        device: torch device string that sampler buffers are placed on;
            defaults to "cuda" for backward compatibility.
        **kwargs: ignored; accepted for call-site flexibility.
    """
    super().__init__()
    # Total number of DDPM training timesteps (fixed by the pretrained model).
    self.ddpm_num_timesteps = 1000
    self.model, self.old_diffusion = model, old_diffusion
    self.schedule, self.device = schedule, device

def register_buffer(self, name, attr):
    """Attach `attr` to the sampler as attribute `name`.

    Tensors that are not already on `self.device` are cast to float32 and
    moved there in a single transfer (float32 presumably to keep buffers
    usable on CPU/MPS backends that lack fp16 support — matches the PR
    intent; non-tensor values are stored unchanged).
    """
    # isinstance (not type ==) is the idiomatic check and also accepts
    # torch.Tensor subclasses such as Parameter.
    if isinstance(attr, torch.Tensor) and attr.device != torch.device(self.device):
        # One .to() call performs the dtype cast and device move together.
        attr = attr.to(device=torch.device(self.device), dtype=torch.float32)
    setattr(self, name, attr)

def make_schedule(
Expand All @@ -366,7 +367,7 @@ def make_schedule(
assert (
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
), "alphas have to be defined for each timestep"
to_torch = lambda x: x.clone().detach().to(torch.float32).to("cuda")
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.device)

self.register_buffer(
"betas", to_torch(torch.from_numpy(self.old_diffusion.betas))
Expand Down Expand Up @@ -492,10 +493,9 @@ def plms_sampling(
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
):
device = "cuda"
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
img = torch.randn(shape, device=self.device)
else:
img = x_T

Expand Down Expand Up @@ -529,11 +529,11 @@ def plms_sampling(

for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)
ts = torch.full((b,), step, device=self.device, dtype=torch.long)
ts_next = torch.full(
(b,),
time_range[min(i + 1, len(time_range) - 1)],
device=device,
device=self.device,
dtype=torch.long,
)

Expand Down
2 changes: 1 addition & 1 deletion kandinsky2/model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
dimension equal to the length of timesteps.
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
"""
res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
res = torch.from_numpy(arr).to(dtype=torch.float32).to(device=timesteps.device)[timesteps]
while len(res.shape) < len(broadcast_shape):
res = res[..., None]
return res.expand(broadcast_shape)
Expand Down