In [None]:
from pathlib import Path

import cv2
import numpy as np
import torch
from accelerate import PartialState
from IPython.display import display
from PIL import Image

from attn_map_utils import register_cross_attention_hook
from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, EulerAncestralDiscreteScheduler
from diffusers.utils import load_image
from hi_sam.text_segmentation import make_text_segmentation_args
from text_diffuser.generate_mask_only import gen_mask_only
from text_diffuser.pipeline_text_diffuser_sd15 import StableDiffusionPipeline
from text_diffuser.t_diffusers.unet_2d_condition import UNet2DConditionModel

# Input image: 3.jpg from the TMDBEval500 dataset on the Hugging Face Hub.
# Local alternative:
# input_image = Image.open("text_diffuser/assets/test01.jpeg").convert("RGB").resize((512,512))
hf_dataset_base_url = "https://huggingface.co/datasets/GoGiants1/TMDBEval500/resolve/main/TMDBEval500/images/"
input_image = load_image(hf_dataset_base_url + '3.jpg')

# CFG scale; the generation cell tags its output filename with this value.
guidance_scale = 7

# Task: change the text in the input image near `coordinates`.
# `sample_text` is the new text (passed to gen_mask_only for mask generation);
# `prompt` is the text prompt given to the diffusion pipeline.
sample_text = "MLVU Project"
prompt = "a tiger and a lion, talk together"

# Query points: the intent is to replace the text of the detected mask closest
# to each point. NOTE(review): presumably pixel (x, y) in the input image —
# confirm against gen_mask_only.
coordinates = [[256, 256]]
# Build both Hi-SAM argument sets from one table, in order:
#   arg_textseg — ViT-L, sam_tss_l_hiertext.pth, hier_det off
#                 (presumably text-stroke segmentation, per the checkpoint name)
#   arg_maskgen — ViT-H, word_detection_totaltext.pth, hier_det on
# Each receives the loaded image's `.size` as its `input_size`.
arg_textseg, arg_maskgen = (
    make_text_segmentation_args(
        model_type=backbone,
        checkpoint_path=weights,
        input_size=input_image.size,
        hier_det=use_hier_det,
    )
    for backbone, weights, use_hier_det in (
        ('vit_l', 'sam_tss_l_hiertext.pth', False),
        ('vit_h', 'word_detection_totaltext.pth', True),
    )
)

# Run mask generation for `sample_text` at `coordinates`; the result (`out`)
# is consumed by the binarization step and by the generation cell.
out = gen_mask_only(
    input_image,
    sample_text=sample_text,
    coordinates=coordinates,
    arg_textseg=arg_textseg,
    arg_maskgen=arg_maskgen,
)


# Binarize the generated mask at 512x512. Nearest-neighbour resizing introduces
# no interpolated values. Two cut-offs on the same grayscale image then give:
#   binary_tss  — pixels > 50  become 255, else 0
#   binary_bbox — pixels > 200 become 255, else 0
# NOTE(review): `out` is treated as BGR here, but the generation cell converts
# it with cv2.COLOR_RGB2BGR; confirm which channel order gen_mask_only returns.
img = cv2.resize(out, (512, 512), interpolation=cv2.INTER_NEAREST)
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

_, binary_tss = cv2.threshold(gray, 50, 255, cv2.THRESH_BINARY)
_, binary_bbox = cv2.threshold(gray, 200, 255, cv2.THRESH_BINARY)

binary_tss_pil, binary_bbox_pil = (
    Image.fromarray(mask, 'L') for mask in (binary_tss, binary_bbox)
)



In [None]:
# TextDiffuser UNet weights; the other Stable Diffusion v1.5 components come from
# the base repo below, with the VAE swapped for stabilityai/sd-vae-ft-mse.
td_ckpt = "GoGiants1/td-unet15"

unet = UNet2DConditionModel.from_pretrained(
    td_ckpt,
    subfolder="unet",
)
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", use_safetensors=True)

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    unet=unet,
    vae=vae,
    safety_checker=None,
    torch_dtype=torch.float32,
)

# Load the IP-Adapter Plus weights twice (two adapters); the generation step must
# pass one ip_adapter_image per loaded adapter. A single float scale applies to both.
pipe.load_ip_adapter(
    "h94/IP-Adapter",
    subfolder=["models", "models"],
    weight_name=[
        "ip-adapter-plus_sd15.safetensors",
        "ip-adapter-plus_sd15.safetensors",
    ],
)
pipe.set_ip_adapter_scale(0.25)

# Enable CPU offload only after every component is attached, so anything that
# load_ip_adapter registers (e.g. an image encoder) is covered by the offload
# hooks rather than left on a different device from the inputs.
# Alternatives: pipe.to("cuda"), or pipe.to(PartialState().device) for multi-process runs.
# Optional attention-map debugging: pipe.unet = register_cross_attention_hook(pipe.unet)
pipe.enable_model_cpu_offload()


# --- Generation: change the text in the input image at `coordinates` ---

# Mask handed to the pipeline: the full-resolution `out` from mask generation,
# channel-swapped with cv2.COLOR_RGB2BGR.
text_mask_image = cv2.cvtColor(np.array(out), cv2.COLOR_RGB2BGR)

# Swap in a DDPM scheduler built from the current scheduler's config.
# Alternative: DPMSolverMultistepScheduler(use_karras_sigmas=True, algorithm_type="sde-dpmsolver++").
pipe.scheduler = DDPMScheduler.from_config(pipe.scheduler.config)
generator = torch.Generator(device="cpu").manual_seed(42)  # fixed seed for reproducible samples

output = pipe(
    prompt=prompt,
    input_image=input_image,
    text_mask_image=text_mask_image,
    ip_adapter_image=[input_image, input_image],  # one image per loaded IP-Adapter
    width=512,
    height=512,
    # Use the configured scale so the `output_cfg_*` filename below matches the
    # value actually used for sampling.
    guidance_scale=guidance_scale,
    generator=generator,
).images[0]

# Results go to experiments/td-15-style-transfer/<prompt>/; create the directory
# so the saves below do not fail with FileNotFoundError on a fresh checkout.
output_dir = Path("experiments/td-15-style-transfer") / prompt
output_dir.mkdir(parents=True, exist_ok=True)

output.save(output_dir / f"output_cfg_{guidance_scale}.png", "PNG")
input_image.save(output_dir / "input.png", "PNG")
# Save the 512x512 binarized mask. Don't assign the result: Image.save returns
# None and would clobber the `text_mask_image` array above.
binary_tss_pil.save(output_dir / "text_mask.png", "PNG")
# img = make_image_grid([input_image, output], rows=1, cols=2)
