Skip to content

Commit

Permalink
Merge branch 'main' of github.com:7eu7d7/HCP-Diffusion into main
Browse files Browse the repository at this point in the history
  • Loading branch information
narugo1992 committed Jul 20, 2023
2 parents 326999b + 76a7583 commit 751aa09
Show file tree
Hide file tree
Showing 11 changed files with 104 additions and 47 deletions.
4 changes: 3 additions & 1 deletion cfgs/train/examples/DreamArtist++.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,9 @@ data:
data_source1:
img_root: 'imgs/v15'
prompt_template: 'prompt_tuning_template/caption.txt'
caption_file: 'imgs/v15/image_captions.json'
caption_file:
_target_: hcpdiff.data.JsonCaptionLoader
path: 'imgs/v15/image_captions.json'
att_mask: null

word_names:
Expand Down
15 changes: 15 additions & 0 deletions cfgs/train/examples/add_logger_tensorboard_wandb.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
_base_: [cfgs/train/train_base.yaml]

logger:
-
_target_: hcpdiff.loggers.CLILogger
_partial_: True
out_path: 'train.log'
log_step: 20
- _target_: hcpdiff.loggers.TBLogger
_partial_: True
out_path: 'tblog/'
log_step: 5
- _target_: hcpdiff.loggers.WanDBLogger
_partial_: True
log_step: 5
2 changes: 1 addition & 1 deletion hcpdiff/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from .cond_pair_dataset import TextImageCondPairDataset
from .bucket import BaseBucket, FixedBucket, RatioBucket
from .utils import CycleData

from .caption_loader import JsonCaptionLoader, TXTCaptionLoader

class DataGroup:
def __init__(self, loader_list, loss_weights):
Expand Down
8 changes: 4 additions & 4 deletions hcpdiff/data/bucket.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def crop_resize(self, image, size, mask_interp=cv2.INTER_CUBIC):
return image

class FixedBucket(BaseBucket):
def __init__(self, target_size: Union[Tuple[int, int], int] = 512):
def __init__(self, target_size: Union[Tuple[int, int], int] = 512, **kwargs):
self.target_size = (target_size, target_size) if isinstance(target_size, int) else target_size

def build(self, bs: int, img_root_list: List[str]):
Expand Down Expand Up @@ -218,14 +218,14 @@ def __len__(self):

@classmethod
def from_ratios(cls, target_area: int = 640*640, step_size: int = 8, num_bucket: int = 10, ratio_max: float = 4,
pre_build_bucket: str = None):
pre_build_bucket: str = None, **kwargs):
arb = cls(target_area, step_size, num_bucket, pre_build_bucket=pre_build_bucket)
arb.ratio_max = ratio_max
arb._build = arb.build_buckets_from_ratios
return arb

@classmethod
def from_files(cls, target_area: int = 640*640, step_size: int = 8, num_bucket: int = 10, pre_build_bucket: str = None):
def from_files(cls, target_area: int = 640*640, step_size: int = 8, num_bucket: int = 10, pre_build_bucket: str = None, **kwargs):
arb = cls(target_area, step_size, num_bucket, pre_build_bucket=pre_build_bucket)
arb._build = arb.build_buckets_from_images
return arb
Expand Down Expand Up @@ -265,7 +265,7 @@ def crop_resize(self, image, size):
return pad_crop_fix(image, size)

@classmethod
def from_files(cls, step_size: int = 8, num_bucket: int = 10, pre_build_bucket: str = None):
def from_files(cls, step_size: int = 8, num_bucket: int = 10, pre_build_bucket: str = None, **kwargs):
arb = cls(step_size, num_bucket, pre_build_bucket=pre_build_bucket)
arb._build = arb.build_buckets_from_images
return arb
40 changes: 40 additions & 0 deletions hcpdiff/data/caption_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import json
import os
import glob
import yaml

class BaseCaptionLoader:
    """Abstract loader that maps image names to caption strings.

    Subclasses implement :meth:`load` and return a dict of
    ``{image_name: caption}``.
    """

    def __init__(self, path):
        # Path to a caption file (JSON/YAML loaders) or to a directory of
        # per-image .txt files (TXTCaptionLoader).
        self.path = path

    def load(self):
        """Return the caption mapping. Must be overridden by subclasses."""
        raise NotImplementedError()

class JsonCaptionLoader(BaseCaptionLoader):
    """Load captions from a single JSON file of ``{image_name: caption}``."""

    def load(self):
        with open(self.path, 'r', encoding='utf-8') as f:
            return json.loads(f.read())

class YamlCaptionLoader(BaseCaptionLoader):
    """Load captions from a single YAML file of ``{image_name: caption}``."""

    def load(self):
        with open(self.path, 'r', encoding='utf-8') as f:
            return yaml.load(f.read(), Loader=yaml.FullLoader)

class TXTCaptionLoader(BaseCaptionLoader):
    """Load captions from a directory of per-image ``.txt`` files.

    NOTE(review): keys are the raw file names with the ``.txt`` extension
    kept — presumably the dataset code strips extensions when matching
    images; confirm against the caller.
    """

    def load(self):
        txt_files = os.listdir(self.path)
        captions = {}
        for file in txt_files:
            with open(os.path.join(self.path, file), 'r', encoding='utf-8') as f:
                captions[file] = f.read().strip()
        return captions

def auto_caption_loader(path):
    """Pick an appropriate caption loader for ``path``.

    ``path`` may be a caption file itself (``.json``/``.yaml``/``.yml``) or
    a directory containing caption files. Bug fix: the original globbed
    ``'*.json'`` etc. in the current working directory instead of inside
    ``path``, and passed the directory (not the matched file) to the
    single-file JSON/YAML loaders, so their ``load()`` could never open it.

    :param path: caption file or directory to search.
    :return: a ready-to-use :class:`BaseCaptionLoader` instance.
    :raises FileNotFoundError: if no supported caption file is found.
    """
    if os.path.isfile(path):
        ext = os.path.splitext(path)[1].lower()
        if ext == '.json':
            return JsonCaptionLoader(path)
        elif ext in ('.yaml', '.yml'):
            return YamlCaptionLoader(path)
        raise FileNotFoundError(f'Unsupported caption file: {path}')

    json_files = glob.glob(os.path.join(path, '*.json'))
    if len(json_files)>0:
        return JsonCaptionLoader(json_files[0])

    yaml_files = glob.glob(os.path.join(path, '*.yaml'))+glob.glob(os.path.join(path, '*.yml'))
    if len(yaml_files)>0:
        return YamlCaptionLoader(yaml_files[0])

    if len(glob.glob(os.path.join(path, '*.txt')))>0:
        # TXTCaptionLoader reads the whole directory, so it gets `path` itself.
        return TXTCaptionLoader(path)

    raise FileNotFoundError('Caption file not found')
42 changes: 18 additions & 24 deletions hcpdiff/data/pair_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,11 @@
:Licence: Apache-2.0
"""

import json
import os.path
from argparse import Namespace

import cv2
import torch
import yaml
from PIL import Image
from torch.utils.data import Dataset
from tqdm.auto import tqdm
Expand All @@ -23,7 +21,7 @@
from hcpdiff.utils.img_size_tool import types_support
from hcpdiff.utils.utils import get_file_name, get_file_ext
from .bucket import BaseBucket

from .caption_loader import BaseCaptionLoader, auto_caption_loader

class TextImagePairDataset(Dataset):
"""
Expand Down Expand Up @@ -68,17 +66,13 @@ def load_image(self, path):
image = canvas
return image.convert("RGB")

def load_captions(self, caption_file):
def load_captions(self, caption_file: Union[str, BaseCaptionLoader]):
if caption_file is None:
return dict()
elif caption_file.endswith('.json'):
with open(caption_file, 'r', encoding='utf-8') as f:
return json.loads(f.read())
elif caption_file.endswith('.yaml'):
with open(caption_file, 'r', encoding='utf-8') as f:
return yaml.load(f.read(), Loader=yaml.FullLoader)
return {}
elif isinstance(caption_file, str):
return auto_caption_loader(caption_file).load()
else:
return dict()
return caption_file.load()

def load_template(self, template_file):
with open(template_file, 'r', encoding='utf-8') as f:
Expand All @@ -95,31 +89,31 @@ def cache_latents(self, vae, weight_dtype, device, show_prog=True):
data = self.load_data(path, size)
image = data['img'].unsqueeze(0).to(device, dtype=weight_dtype)
latents = vae.encode(image).latent_dist.sample().squeeze(0)
data['img'] = (latents * 0.18215).cpu()
data['img'] = (latents*0.18215).cpu()
self.latents[img_name] = data

def get_att_map(self, img_root, img_name):
if img_name not in self.source_dict[img_root].att_mask_path:
return None
att_mask = Image.open(self.source_dict[img_root].att_mask_path[img_name]).convert("L")
np_mask = np.array(att_mask).astype(float)
np_mask[np_mask <= 127 + 0.1] = (np_mask[np_mask <= 127 + 0.1] / 127.)
np_mask[np_mask > 127] = ((np_mask[np_mask > 127] - 127) / 128.) * 4 + 1
np_mask[np_mask<=127+0.1] = (np_mask[np_mask<=127+0.1]/127.)
np_mask[np_mask>127] = ((np_mask[np_mask>127]-127)/128.)*4+1
return np_mask

def load_data(self, path, size):
img_root, img_name = os.path.split(path)
image = self.load_image(path)
att_mask = self.get_att_map(img_root, get_file_name(img_name))
if att_mask is None:
data = self.bucket.crop_resize({"img": image}, size)
data = self.bucket.crop_resize({"img":image}, size)
image = self.source_dict[img_root].image_transforms(data['img']) # resize to bucket size
att_mask = torch.ones((size[1] // 8, size[0] // 8))
att_mask = torch.ones((size[1]//8, size[0]//8))
else:
data = self.bucket.crop_resize({"img": image, "mask": att_mask}, size)
data = self.bucket.crop_resize({"img":image, "mask":att_mask}, size)
image = self.source_dict[img_root].image_transforms(data['img'])
att_mask = torch.tensor(cv2.resize(att_mask, (size[0] // 8, size[1] // 8), interpolation=cv2.INTER_LINEAR))
return {'img': image, 'mask': att_mask}
att_mask = torch.tensor(cv2.resize(att_mask, (size[0]//8, size[1]//8), interpolation=cv2.INTER_LINEAR))
return {'img':image, 'mask':att_mask}

def __len__(self):
return len(self.bucket)
Expand All @@ -134,14 +128,14 @@ def __getitem__(self, index):
data = self.latents[img_name]

caption_ist = self.source_dict[img_root].caption_dict[img_name] if img_name in \
self.source_dict[img_root].caption_dict else None
self.source_dict[img_root].caption_dict else None
prompt_ist = self.source_dict[img_root].tag_transforms(
{'prompt': random.choice(self.source_dict[img_root].prompt_template), 'caption': caption_ist})['prompt']
{'prompt':random.choice(self.source_dict[img_root].prompt_template), 'caption':caption_ist})['prompt']

# tokenize Sp or (Sn, Sp)
prompt_ids = self.tokenizer(
prompt_ist, truncation=True, padding="max_length", return_tensors="pt",
max_length=self.tokenizer.model_max_length * self.tokenizer_repeats).input_ids.squeeze()
max_length=self.tokenizer.model_max_length*self.tokenizer_repeats).input_ids.squeeze()

data['prompt'] = prompt_ids

Expand Down Expand Up @@ -180,4 +174,4 @@ def collate_fn(batch):
datas['cond'] = torch.stack(datas['cond'])
datas['prompt'] = torch.stack(sn_list)

return datas
return datas
12 changes: 8 additions & 4 deletions hcpdiff/loggers/wandb_logger.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
from typing import Dict, Any

import os
import wandb
from PIL import Image

from .base_logger import BaseLogger


class WanDBLogger(BaseLogger):
def __init__(self, exp_dir, out_path, enable_log_image=True, project='hcp-diffusion'):
super().__init__(exp_dir, out_path, enable_log_image)
def __init__(self, exp_dir, out_path=None, enable_log_image=True, project='hcp-diffusion', log_step=10):
super().__init__(exp_dir, out_path, enable_log_image, log_step)
if exp_dir is not None: # exp_dir is only available in local main process
wandb.init(project=project)
wandb.init(project=project, name=os.path.basename(exp_dir))
wandb.save(os.path.join(exp_dir, 'cfg.yaml'), base_path=exp_dir)
else:
self.writer = None
self.disable()
Expand All @@ -19,7 +21,9 @@ def _info(self, info):
pass

def _log(self, datas: Dict[str, Any], step: int = 0):
wandb.log({k: v['data'] for k, v in datas.items()}, step=step)
for k, v in datas.items():
if len(v['data']) == 1:
wandb.log({k: v['data'][0]}, step=step)

def _log_image(self, imgs: Dict[str, Image.Image], step: int = 0):
wandb.log({next(iter(imgs.keys())): list(imgs.values())}, step=step)
21 changes: 11 additions & 10 deletions hcpdiff/noise/pyramid_noise.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from .noise_base import NoiseBase
import random

import torch
from torch.nn import functional as F
import random

from .noise_base import NoiseBase

class PyramidNoiseScheduler(NoiseBase):
def __init__(self, base_scheduler, level: int = 10, discount: float = 0.9, step: float = 2., mode: str = 'bilinear'):
Expand All @@ -13,20 +14,20 @@ def __init__(self, base_scheduler, level: int = 10, discount: float = 0.9, step:
self.discount = discount

def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.IntTensor,
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.IntTensor,
) -> torch.FloatTensor:
with torch.no_grad():
b, c, h, w = noise.shape
for i in range(1, self.level):
r = random.random() * 2 + self.step
wn, hn = max(1, int(w / (r ** i))), max(1, int(h / (r ** i)))
noise += F.interpolate(torch.randn(b, c, hn, wn).to(noise), (h, w), None, self.mode) * (self.discount ** i)
r = random.random()*2+self.step
wn, hn = max(1, int(w/(r**i))), max(1, int(h/(r**i)))
noise += F.interpolate(torch.randn(b, c, hn, wn).to(noise), (h, w), None, self.mode)*(self.discount**i)
if wn == 1 or hn == 1:
break
noise = noise / noise.std()
noise = noise/noise.std()
return self.base_scheduler.add_noise(original_samples, noise, timesteps)

# if __name__ == '__main__':
Expand Down
2 changes: 1 addition & 1 deletion hcpdiff/train_ac.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ def __init__(self, cfgs_raw):
if self.is_local_main_process:
self.exp_dir = os.path.join(self.cfgs.exp_dir, f'{time.strftime("%Y-%m-%d-%H-%M-%S")}')
os.makedirs(os.path.join(self.exp_dir, 'ckpts/'), exist_ok=True)
self.loggers: LoggerGroup = LoggerGroup([builder(exp_dir=self.exp_dir) for builder in self.cfgs.logger])
with open(os.path.join(self.exp_dir, 'cfg.yaml'), 'w', encoding='utf-8') as f:
f.write(OmegaConf.to_yaml(cfgs_raw))
self.loggers: LoggerGroup = LoggerGroup([builder(exp_dir=self.exp_dir) for builder in self.cfgs.logger])
else:
self.loggers: LoggerGroup = LoggerGroup([builder(exp_dir=None) for builder in self.cfgs.logger])

Expand Down
3 changes: 2 additions & 1 deletion hcpdiff/utils/net_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,8 @@ def save_emb(path, emb: torch.Tensor, replace=False):
if os.path.exists(path) and not replace:
raise FileExistsError(f'embedding "{name}" already exist.')
name = name[:name.rfind('.')]
torch.save({'emb_params':emb, 'name':name}, path)
#torch.save({'emb_params':emb, 'name':name}, path)
torch.save({'string_to_param':{'*':emb}, 'name':name}, path)

def hook_compile(model):
named_modules = {k:v for k, v in model.named_modules()}
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def get_data_files(data_dir, prefix=''):
setuptools.setup(
name="hcpdiff",
py_modules=["hcpdiff"],
version="0.5.3",
version="0.5.4",
author="Ziyi Dong",
author_email="dzy7eu7d7@gmail.com",
description="A universal Stable-Diffusion toolbox",
Expand Down

0 comments on commit 751aa09

Please sign in to comment.