In [None]:
import argparse
import os
import random
from collections import deque
from glob import glob

import cv2
import gradio as gr
import numpy as np
import pandas as pd
import torch
import torch as th
import torch.distributed as dist
from einops import rearrange
from inference import create_argparser
from PIL import Image
from pretrained_diffusion import dist_util
from pretrained_diffusion.glide_util import sample
from pretrained_diffusion.image_datasets_sketch import get_tensor
from pretrained_diffusion.script_util import (add_dict_to_argparser,
                                              args_to_dict,
                                              create_model_and_diffusion,
                                              model_and_diffusion_defaults)
from torchvision.utils import make_grid
from tqdm import tqdm

In [None]:
def fix_seed(seed=23333):
    """Seed Python, NumPy and PyTorch RNGs for reproducible sampling.

    When a ``torch.distributed`` process group is initialised, the process
    rank is added to ``seed`` so every rank draws a different but
    reproducible random stream.

    Args:
        seed: base seed (default ``23333``).
    """
    # is_initialized() is only meaningful when the distributed package exists.
    if dist.is_available() and dist.is_initialized():
        seed = seed + dist.get_rank()

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)           # CPU (and default-device) RNG
    torch.cuda.manual_seed_all(seed)  # RNG of every GPU; silently ignored without CUDA

    # Disable cuDNN autotuning and force deterministic convolution algorithms.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # Hash randomisation of the running interpreter is fixed at start-up, so
    # this only affects Python subprocesses launched after this call.
    os.environ['PYTHONHASHSEED'] = str(seed)

In [None]:
def load_image(image_path):
    """Read the image at ``image_path`` and return it as an RGB array.

    Returns:
        ``np.ndarray`` of shape ``(height, width, 3)`` built from the image
        converted to RGB mode.
    """
    # The context manager releases the file handle as soon as the pixels have
    # been copied out (Pillow opens files lazily).
    with Image.open(image_path) as img:
        return np.array(img.convert('RGB'))

def _white_on_both_sides(line):
    """Return a bool mask that is True at ``i`` iff ``line`` has a True both before and after ``i``.

    ``line`` is a 1-D boolean array. The first and last positions are always
    False because one of their sides is empty.
    """
    before = np.zeros_like(line)
    before[1:] = np.logical_or.accumulate(line)[:-1]
    after = np.zeros_like(line)
    after[:-1] = np.logical_or.accumulate(line[::-1])[::-1][1:]
    return before & after


def _border_flood_seeds(white):
    """Yield the ``(row, col)`` border pixels that ``mark_paths`` flood-fills from.

    Left/right columns are scanned, then top/bottom rows. A border pixel is
    skipped when its border line has white pixels on both sides of it.
    ``white`` is a 2-D boolean mask of pure-white pixels.
    """
    rows, cols = white.shape
    for c in (0, cols - 1):
        for r in np.flatnonzero(~_white_on_both_sides(white[:, c])):
            yield int(r), c
    for r in (0, rows - 1):
        for c in np.flatnonzero(~_white_on_both_sides(white[r, :])):
            yield r, int(c)


def mark_paths(image, label):
    """Recolour a traced-contour label image into a two-colour region mask.

    Pure-white pixels (``[255, 255, 255]``) act as walls. Non-white pixels that
    are 4-connected, through non-white pixels, to a flood-fill seed on the
    image border become red ``[255, 0, 0]``. Every other pixel, including all
    white pixels, becomes the label colour: ``'RI'`` -> blue ``[0, 0, 255]``,
    ``'RO'`` -> green ``[0, 255, 0]``.

    A border pixel is used as a seed unless white pixels lie on both sides of
    it along its image edge. Presumably this stops an edge segment lying
    between two contour crossings from being treated as outside
    (TODO confirm intent).

    Args:
        image: RGB array of shape ``(rows, cols, 3)``; modified in place.
        label: ``'RI'`` or ``'RO'``.

    Returns:
        ``image`` itself, after recolouring.

    Raises:
        ValueError: if ``label`` is not ``'RI'`` or ``'RO'``. This is checked
            before ``image`` is touched.
    """
    label_colors = {'RI': [0, 0, 255], 'RO': [0, 255, 0]}
    if label not in label_colors:
        raise ValueError(f'No label: {label}')
    mark_color = label_colors[label]

    rows, cols, _ = image.shape
    # Whiteness never changes during the fill (only non-white pixels are
    # recoloured), so compute it once instead of re-comparing pixels.
    white = np.all(image == 255, axis=-1)
    visited = np.zeros((rows, cols), dtype=bool)

    def bfs(x, y):
        # Breadth-first flood fill over non-white pixels. Pixels are marked
        # when enqueued, so each one enters the queue at most once.
        if visited[x, y] or white[x, y]:
            return
        visited[x, y] = True
        queue = deque([(x, y)])
        while queue:
            cx, cy = queue.popleft()
            for nx, ny in ((cx - 1, cy), (cx + 1, cy), (cx, cy - 1), (cx, cy + 1)):
                if 0 <= nx < rows and 0 <= ny < cols and not (visited[nx, ny] or white[nx, ny]):
                    visited[nx, ny] = True
                    queue.append((nx, ny))

    for x, y in _border_flood_seeds(white):
        bfs(x, y)

    # Reached pixels are red; everything else (including white walls) gets
    # the label colour.
    image[visited] = [255, 0, 0]
    image[~visited] = mark_color
    return image

def process_image(image_path, output_path, label):
    """Turn the label image at ``image_path`` into a region mask and save it.

    The image is loaded with ``load_image``, recoloured by ``mark_paths`` for
    ``label``, and written to ``output_path``. PIL picks the file format from
    the extension of ``output_path``.
    """
    marked = mark_paths(load_image(image_path), label)
    Image.fromarray(marked).save(output_path)

In [None]:
# Sampling configuration.
sample_c = 1.4        # guidance scale passed to the base-model sampler (args.sample_c)
num_samples = 1       # images generated per label; also the sampling batch size
sample_step = 1000    # respacing of the base-model diffusion sampler
mode = 'mask'         # 'mask' -> colour-mask ('coco') models, 'sketch' -> edge ('coco-edge') models

parser, parser_up = create_argparser()
# parse_known_args tolerates unrecognised command-line arguments (e.g. the
# ones the Jupyter kernel is launched with) instead of erroring out.
args = parser.parse_known_args()[0]
args_up = parser_up.parse_known_args()[0]
dist_util.setup_dist()
# Seed every RNG so sampling is reproducible. This runs after setup_dist so
# that fix_seed can offset the seed by the distributed rank.
fix_seed()

if mode == 'sketch':
    args.mode = 'coco-edge'
    args_up.mode = 'coco-edge'
    args.model_path = './ckpt/base_edge.pt'
    args.sr_model_path = './ckpt/upsample_edge.pt'
elif mode == 'mask':
    args.mode = 'coco'
    args_up.mode = 'coco'
    # Earlier base-model checkpoints, kept for reference:
    #   './logs/coco-mask/coco-64-stage1/checkpoints/ema_0.9999_020000.pt'
    #   './logs/coco-mask/coco-64-stage1-cont/checkpoints/ema_0.9999_015000.pt'
    args.model_path = './logs/coco-mask/coco-64-stage2-decoder/checkpoints/ema_0.9999_015000.pt'
    args.sr_model_path = './ckpt/upsample_mask.pt'
else:
    raise ValueError(f"Unknown mode: {mode!r} (expected 'mask' or 'sketch')")

args.sample_c = sample_c
args.num_samples = num_samples

# NOTE(review): the two calls pass different argument types (0. vs True) to
# model_and_diffusion_defaults; presumably an "is upsampler" flag. Confirm.
options = args_to_dict(args, model_and_diffusion_defaults(0.).keys())
model, diffusion = create_model_and_diffusion(**options)

options_up = args_to_dict(args_up, model_and_diffusion_defaults(True).keys())
model_up, diffusion_up = create_model_and_diffusion(**options_up)

if args.model_path:
    print('loading model', args.model_path)
    model_ckpt = dist_util.load_state_dict(args.model_path, map_location="cpu")
    model.load_state_dict(model_ckpt, strict=True)

if args.sr_model_path:
    print('loading model', args.sr_model_path)
    model_ckpt2 = dist_util.load_state_dict(args.sr_model_path, map_location="cpu")
    model_up.load_state_dict(model_ckpt2, strict=True)

model.to(dist_util.dev())
model_up.to(dist_util.dev())
model.eval()
model_up.eval();  # trailing ';' suppresses the long module repr in the cell output


In [None]:
# Encode the sampling settings in the output folder name.
output_name = 'output-{}SampleStep{}'.format(sample_c, sample_step)
# Competition test set ("影像資料生成競賽" = image data generation competition).
test_data_dir = '../../../影像資料生成競賽/test_dataset'

In [None]:
label_dir = os.path.join(test_data_dir, 'label_img')
output_dir = os.path.join(test_data_dir, output_name)
output_resize_dir = os.path.join(test_data_dir, f'{output_name}-resized')
os.makedirs(output_dir, exist_ok=True)
os.makedirs(output_resize_dir, exist_ok=True)

# The generation loop writes a '<name>_mask<ext>' file next to every label in
# label_dir. Exclude those so that re-running the notebook does not treat
# them as additional labels.
label_img_paths = sorted(
    path for path in glob(os.path.join(label_dir, '*'))
    if not os.path.splitext(path)[0].endswith('_mask')
)
len(label_img_paths)

In [None]:
# Generation loop. For every label image:
#   1. recolour it with process_image and write '<name>_mask<ext>' next to it;
#   2. build the conditioning tensor for the selected mode;
#   3. sample 64x64 images with the base model, then 256x256 with the upsampler;
#   4. save the result, plus a copy resized (nearest-neighbour) to the label size.
# Labels whose resized output already exists are skipped, so an interrupted
# run can be resumed.
for label_img_path in tqdm(label_img_paths, total=len(label_img_paths)):
    img_name = os.path.basename(label_img_path)
    # Second '_'-separated token of the file name; mark_paths accepts 'RI' or 'RO'.
    img_class = img_name.split('_')[1]

    # Swap only the file extension. str.replace on the whole path would also
    # rewrite any '.png' appearing in a directory name.
    out_name = os.path.splitext(img_name)[0] + '.jpg'
    output_path = os.path.join(output_dir, out_name)
    output_resize_path = os.path.join(output_resize_dir, out_name)

    if os.path.exists(output_resize_path):
        continue

    name, ext = os.path.splitext(label_img_path)
    preprocessed_img_path = f'{name}_mask{ext}'
    process_image(label_img_path, preprocessed_img_path, img_class)

    # Copy the mask into memory so the file handle is closed right away.
    with Image.open(preprocessed_img_path) as mask_file:
        image = mask_file.copy()

    ########### conditioning
    if args.mode == 'coco':
        # Colour mask; nearest-neighbour resize so no new colours are blended in.
        label_pil = image.convert("RGB").resize((256, 256), Image.NEAREST)
        label_tensor = get_tensor()(label_pil)
    elif args.mode == 'coco-edge':
        # Distance transform of the inverted grey image, keeping one channel.
        label_pil = image.convert("L").resize((256, 256), Image.NEAREST)
        im_dist = cv2.distanceTransform(255 - np.array(label_pil), cv2.DIST_L1, 3)
        im_dist = np.clip(im_dist, 0, 255).astype(np.uint8)
        im_dist = Image.fromarray(im_dist).convert("RGB")
        label_tensor = get_tensor()(im_dist)[:1]
    else:
        raise ValueError(f'Unsupported mode: {args.mode}')

    # The same conditioning is repeated for every sample in the batch.
    data_dict = {"ref": label_tensor.unsqueeze(0).repeat(args.num_samples, 1, 1, 1)}

    ########### sampling
    sampled_imgs = []
    grid_imgs = []
    img_id = 0
    while img_id < args.num_samples:
        model_kwargs = data_dict
        with th.no_grad():
            samples_lr = sample(
                glide_model=model,
                glide_options=options,
                side_x=64,
                side_y=64,
                prompt=model_kwargs,
                batch_size=args.num_samples,
                guidance_scale=args.sample_c,
                device=dist_util.dev(),
                prediction_respacing=str(sample_step),
                upsample_enabled=False,
                upsample_temp=0.997,
                mode=args.mode,
            )
            samples_lr = samples_lr.clamp(-1, 1)

            # Quantise the low-res samples to 8-bit levels (as if saved and
            # reloaded) before conditioning the upsampler on them.
            tmp = (127.5 * (samples_lr + 1.0)).int()
            model_kwargs['low_res'] = tmp / 127.5 - 1.

            samples_hr = sample(
                glide_model=model_up,
                glide_options=options_up,
                side_x=256,
                side_y=256,
                prompt=model_kwargs,
                batch_size=args.num_samples,
                guidance_scale=1,
                device=dist_util.dev(),
                prediction_respacing="fast27",
                upsample_enabled=True,
                upsample_temp=0.997,
                mode=args.mode,
            )
            # Clamp like the low-res samples. The uint8 conversions below
            # would otherwise wrap values outside [-1, 1].
            samples_hr = samples_hr.clamp(-1, 1)

            for hr in samples_hr:
                hr = 255. * rearrange((hr.cpu().numpy() + 1.0) * 0.5, 'c h w -> h w c')
                sampled_imgs.append(Image.fromarray(hr.astype(np.uint8)))
                img_id += 1

            grid_imgs.append(samples_hr)

    ########### save
    grid = torch.stack(grid_imgs, 0)
    grid = rearrange(grid, 'n b c h w -> (n b) c h w')
    grid = make_grid(grid, nrow=2)
    # [-1, 1] CHW tensor -> [0, 255] HWC array
    grid = 255. * rearrange((grid + 1.0) * 0.5, 'c h w -> h w c').cpu().numpy()

    output_img = Image.fromarray(grid.astype(np.uint8))
    output_img.save(output_path)
    output_img = output_img.resize(image.size, Image.NEAREST)
    output_img.save(output_resize_path)