Merge pull request #337 from alibaba/DiffSynth
DiffSynth Update
chywang authored Sep 7, 2023
2 parents bc5fcbe + 492efa1 commit cff7463
Showing 7 changed files with 155 additions and 203 deletions.
315 changes: 131 additions & 184 deletions diffusion/DiffSynth/DiffSynth/smoother/PySynthSmoother.py
@@ -1,10 +1,8 @@
import cv2
from PIL import Image, ImageEnhance
import numpy as np
from tqdm import tqdm
import cupy as cp


remapping_kernel = cp.RawKernel(r'''
@@ -269,213 +267,161 @@ def estimate_nnf(self, source_guide, target_guide, source_style, nnf=None):
return nnf, target_style


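# LeftVideoGraph connects the n frames to the left of a query frame with a
# sparse "binary-lifting" structure: father(i) links node i into the root of
# the next larger power-of-two block containing it, and cousin_leaves(i)
# enumerates the leaves of the sibling block, so any prefix of frames is
# covered by O(log n) cached edges rather than O(n) pairwise NNF estimations.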
class LeftVideoGraph:
def __init__(self, n):
self.n = n
self.edges = {}
for i in range(n):
father = self.father(i)
if father<self.n:
self.edges[(i, father)] = None
for j in self.cousin_leaves(i):
if j<self.n:
self.edges[(i, j)] = None

def father(self, x):
y = 1
while x&y:
y <<= 1
return x|y

def cousin(self, x):
y = 1
while (y<<1)<x:
y <<= 1
if (y>>1)>(x^y):
return None
return x^y

def cousin_leaves(self, x):
y = 1
while x&y:
y <<= 1
x -= x & (y - 1)
return range(x+y, x+(y<<1))
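    # Worked example of the bit tricks above: father(4) == 5, father(5) == 7,
    # and father(7) == 15, so 4 -> 5 -> 7 -> 15 climbs one block level per
    # step; cousin_leaves(1) == range(2, 4), i.e. the leaves {2, 3} of the
    # sibling block at the same level.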

def query_middle_node(self, x, y):
for z in range(x+1, y):
if (x, z) in self.edges and (z, y) in self.edges:
return z
return None

def query(self, x):
z_list = []
z = -1
for i in range(10):
y = 1
while z + (y<<1)<=x:
y <<= 1
z += y
z_list.append(z)
if z==x:
break
return z_list

def query_edge(self, level):
edge_list = []
step = 1<<level
for x in range(step-1, self.n, step*2):
y = x + step
if y<self.n:
edge_list.append((x, y))
return edge_list


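# NNFCache precomputes, for every edge (u, v) of a LeftVideoGraph, the
# nearest-neighbor field that remaps frame u onto frame v, so repeated blends
# can reuse the same correspondences instead of re-estimating them.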
class NNFCache:
    def __init__(self, patch_matcher):
        # The original __init__ was empty while get_nnf_dict referenced
        # self.patch_matcher; accepting the matcher here is one way to make
        # the class self-consistent.
        self.patch_matcher = patch_matcher

    def get_nnf_dict(self, graph, frames_guide, frames_style):
        nnf_dict = {}
        for u, v in tqdm(graph.edges, desc="Estimating NNF"):
            nnf, _ = self.patch_matcher.estimate_nnf(
                source_guide=frames_guide[u],
                target_guide=frames_guide[v],
                source_style=frames_style[u]
            )
            nnf_dict[(u, v)] = nnf.get()  # .get() moves the CuPy array to host
        return nnf_dict
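
# VideoWithOperator wraps the guide/style frame lists behind three operations:
# data(i) yields (style_frame_i, weight=1), remap(x, i, j) transports a
# (frame, weight) pair from frame i's coordinates to frame j's via a fresh
# patch-match NNF, and blend(xs) takes the weight-averaged mean. The blending
# algorithm below touches the data only through these operations.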
class VideoWithOperator:
def __init__(self, frames_guide, frames_style, patch_size=21, threads_per_block=8, num_iter=6, gpu_id=0, guide_weight=100.0):
self.frames_guide = frames_guide
self.frames_style = frames_style
image_height, image_width, _ = frames_style[0].shape
self.patch_match_engine = PyramidPatchMatcher(
image_height, image_width, channel=3, patch_size=patch_size,
threads_per_block=threads_per_block, num_iter=num_iter,
gpu_id=gpu_id, guide_weight=guide_weight
)

def remap(self, x, i, j):
source_style, num_blend = x
nnf, target_style = self.patch_match_engine.estimate_nnf(
source_guide=self.frames_guide[i],
target_guide=self.frames_guide[j],
source_style=source_style
)
target_style = target_style.get()
return target_style, num_blend

def blend(self, x):
sum_num_blend = sum([num_blend for style, num_blend in x])
weighted_frames = [style * (num_blend / sum_num_blend) for style, num_blend in x]
mean_frame = np.stack(weighted_frames).sum(axis=0)
return mean_frame, sum_num_blend

def __call__(self, i):
return self.frames_style[i], 1

def __len__(self):
return len(self.frames_style)


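# FastBlendingAlgorithm builds a binary-lifting prefix structure over the
# operators above: blending_table[i][k] caches the blend of the 2**k frames
# ending at frame i, remapped into frame i's coordinates, so any window
# [leftbound, rightbound] can later be answered from O(log n) cached blocks.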
class FastBlendingAlgorithm:
def __init__(self, data):
self.data = data
n = len(self.data)
self.remapping_table = [[self.data(i)] for i in range(n)]
self.blending_table = [[self.data(i)] for i in range(n)]
level = 1
while (1<<level)<=n:
for i in tqdm(range((1<<level)-1, n, 1<<level), desc=f"Building remapping table (level-{level})"):
            source, target = i - (1 << (level - 1)), i
remapping_result = self.data.remap(self.blending_table[source][-1], source, target)
self.remapping_table[target].append(remapping_result)
blending_result = self.data.blend(self.remapping_table[target])
self.blending_table[target].append(blending_result)
level += 1

def tree_query(self, leftbound, rightbound):
node_list = []
node_index = rightbound
while node_index>=leftbound:
node_level = 0
while (1<<node_level)&node_index and node_index-(1<<node_level+1)+1>=leftbound:
node_level += 1
node_list.append((node_index, node_level))
node_index -= 1<<node_level
return node_list
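    # Example: tree_query(0, 10) returns [(10, 0), (9, 1), (7, 3)], covering
    # the window as the cached blocks {10}, {8, 9} and {0..7}; query() then
    # remaps each block's blended frame onto the right bound and blends them.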

def query(self, leftbound, rightbound):
node_list = self.tree_query(leftbound, rightbound)
result = []
for node_index, node_level in node_list:
node_value = self.blending_table[node_index][node_level]
if node_index!=rightbound:
node_value = self.data.remap(node_value, node_index, rightbound)
result.append(node_value)
result = self.data.blend(result)
return result


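# ImagePostProcessor applies the configured PIL enhancement passes (currently
# "contrast" and "sharpness", each with a strength rate) to a list of frames
# after blending.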
class ImagePostProcessor:
def __init__(self, postprocessing):
self.postprocessing = postprocessing

def postprocessing_contrast(self, style, rate):
style = [ImageEnhance.Contrast(i).enhance(rate) for i in style]
return style

def postprocessing_sharpness(self, style, rate):
style = [ImageEnhance.Sharpness(i).enhance(rate) for i in style]
return style

def __call__(self, images):
for name in self.postprocessing:
rate = self.postprocessing[name]
if name == "contrast":
images = self.postprocessing_contrast(images, rate)
elif name == "sharpness":
images = self.postprocessing_sharpness(images, rate)
return images


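# PySynthSmoother is the public entry point. speed="slowest" blends every
# frame against its full window directly (O(n * window) NNF estimations),
# while speed="fastest" uses FastBlendingAlgorithm for a roughly O(n log n)
# pass in each direction; ebsynth_config is forwarded to the patch matcher.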
class PySynthSmoother:
def __init__(self, speed="slowest", window_size=3, postprocessing={}, ebsynth_config={}):
self.speed = speed
self.window_size = window_size
self.postprocessor = ImagePostProcessor(postprocessing)
self.ebsynth_config = ebsynth_config
self.operating_space = "pixel"

    def reset(self, image_height, image_width):
        # Rebuild the patch-match engine for the current resolution; the
        # matcher options now live in the shared ebsynth_config (the original
        # lines read them from attributes the new __init__ no longer sets).
        self.patch_match_engine = PyramidPatchMatcher(
            image_height, image_width, channel=3, **self.ebsynth_config
        )

def prepare(self, images):
self.frames_guide = images
image_width, image_height = images[0].size
self.reset(image_height, image_width)

def PIL_to_numpy(self, frames):
return [np.array(frame).astype(np.float32)/255 for frame in frames]

def numpy_to_PIL(self, frames):
return [Image.fromarray(np.clip((frame * 255), 0, 255).astype("uint8")) for frame in frames]

def remapping_operator(self, nnf, image):
        with cp.cuda.Device(self.ebsynth_config.get("gpu_id", 0)):  # gpu_id now lives in ebsynth_config
nnf = cp.array(nnf, dtype=cp.int32)
image = cp.array(image, dtype=cp.float32)
image = self.patch_match_engine.apply_nnf_to_image(nnf, image)
image = image.get()
return image

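    # Plain pixel-wise summation over frames; normalization, where needed,
    # is the caller's responsibility.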
def blending_operator(self, frames):
frame = np.stack(frames).sum(axis=0)
return frame

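    # smooth_slowest: the direct approach. Every neighbor within window_size
    # of frame_id is remapped onto frame_id via patch matching, then averaged.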
def smooth_slowest(self, frames_guide, frames_style):
data = VideoWithOperator(frames_guide, frames_style, **self.ebsynth_config)
frames_output = []
for frame_id in tqdm(range(len(data))):
remapped_frames = [data(frame_id)]
for i in range(frame_id - self.window_size, frame_id + self.window_size + 1):
if i<0 or i>=len(frames_style) or i==frame_id:
if i<0 or i>=len(data) or i==frame_id:
continue
remapped_frame = data.remap(data(i), i, frame_id)
remapped_frames.append(remapped_frame)
blended_frame, _ = data.blend(remapped_frames)
frames_output.append(blended_frame)
return frames_output

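    # Legacy prefix builder from the graph-based implementation: for every
    # frame i it returns the preceding frames 0..i-1 remapped into frame i's
    # coordinates, assembled from the LeftVideoGraph's O(log n) cached edges.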
def remap_and_blend_left(self, frames_guide, frames_style):
n = len(frames_guide)
graph = LeftVideoGraph(n)
# Estimate NNF
nnf_dict = {}
for u, v in tqdm(graph.edges, desc="Estimating NNF"):
nnf, _ = self.patch_match_engine.estimate_nnf(
source_guide=frames_guide[u],
target_guide=frames_guide[v],
source_style=frames_style[u]
)
nnf_dict[(u, v)] = nnf.get()
# remap_table and blend_table
remap_table = [[frames_style[i]] for i in range(n)]
blend_table = [[frames_style[i]] for i in range(n)]
level = 0
while True:
edges = graph.query_edge(level)
level += 1
if len(edges)==0:
break
for u, v in edges:
nnf = nnf_dict[(u, v)]
                remapping_result = self.remapping_operator(nnf, blend_table[u][-1])
                remap_table[v].append(remapping_result)
blending_result = self.blending_operator(remap_table[v])
blend_table[v].append(blending_result)
# calculate remapping prefix sum
blending_inputs = []
for i in tqdm(range(n), desc="Remapping frames"):
blending_input = []
# sum of 0...i-1
nodes = graph.query(i)
for u in nodes:
if u==i:
if len(remap_table[u])==1:
continue
else:
                        remapping_result = self.blending_operator(remap_table[u][1:])
else:
nnf = nnf_dict[(u, i)]
                    remapping_result = self.remapping_operator(nnf, blend_table[u][-1])
                blending_input.append(remapping_result)
blending_inputs.append(blending_input)
return blending_inputs

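    # smooth_fastest: one FastBlendingAlgorithm pass over the frames in order
    # (left context) and one over the reversed frames (right context), then a
    # per-frame blend of the two window queries. Both queries already contain
    # the center frame, so it is blended in once more with weight -1 to cancel
    # the double count.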
def smooth_fastest(self, frames_guide, frames_style):
# left
data = VideoWithOperator(frames_guide, frames_style, **self.ebsynth_config)
algo = FastBlendingAlgorithm(data)
remapped_frames_l = []
for frame_id in tqdm(range(len(data)), desc="Remapping and blending (left part)"):
bound = max(frame_id - self.window_size, 0)
remapped_frames_l.append(algo.query(bound, frame_id))
# right
data = VideoWithOperator(frames_guide[::-1], frames_style[::-1], **self.ebsynth_config)
algo = FastBlendingAlgorithm(data)
remapped_frames_r = []
for frame_id in tqdm(range(len(data)), desc="Remapping and blending (right part)"):
bound = max(frame_id - self.window_size, 0)
remapped_frames_r.append(algo.query(bound, frame_id))
remapped_frames_r = remapped_frames_r[::-1]
# merge
frames_output = []
data = VideoWithOperator(frames_guide, frames_style, **self.ebsynth_config)
for frame_id in range(len(data)):
frame, _ = data(frame_id)
frame_output, _ = data.blend([
remapped_frames_l[frame_id],
(frame, -1),
remapped_frames_r[frame_id]
])
frames_output.append(frame_output)
return frames_output

def smooth(self, frames_style):
frames_guide = self.PIL_to_numpy(self.frames_guide)
frames_style = self.PIL_to_numpy(frames_style)
@@ -486,5 +432,6 @@ def smooth(self, frames_style):
else:
raise NotImplementedError()
frames_output = self.numpy_to_PIL(frames_output)
frames_output = self.postprocessor(frames_output)
return frames_output
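
# --------------------------------------------------------------------------
# Minimal usage sketch (not part of the original commit). It assumes two
# directories of aligned, equally sized PNG frames -- "guide/" holding the
# original video frames and "style/" the per-frame stylized renders -- plus
# an existing "output/" directory; all three paths are hypothetical, and the
# ebsynth_config keys simply mirror VideoWithOperator's defaults.
if __name__ == "__main__":
    import glob

    guide_frames = [Image.open(path) for path in sorted(glob.glob("guide/*.png"))]
    style_frames = [Image.open(path) for path in sorted(glob.glob("style/*.png"))]

    smoother = PySynthSmoother(
        speed="fastest",                  # or "slowest" for the direct window blend
        window_size=3,
        postprocessing={"contrast": 1.05, "sharpness": 1.1},
        ebsynth_config={"patch_size": 21, "num_iter": 6, "gpu_id": 0},
    )
    smoother.prepare(guide_frames)        # stores guide frames, builds the matcher
    smoothed_frames = smoother.smooth(style_frames)
    for index, frame in enumerate(smoothed_frames):
        frame.save(f"output/{index:04d}.png")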
