In [1]:
import torch
from torch import nn
import torchvision.transforms as transforms
from PIL import Image
import cv2
import numpy as np
import os
import os.path as osp

# 假设你的模型定义在 model.py 文件中
from model import BiSeNet

# 如果你有自定义的日志配置
from logger import setup_logger
from tqdm import tqdm

### png 转 jpg
将 PNG 图像转换为 JPG，作为训练图像。

In [2]:
from PIL import Image
import os
from tqdm import tqdm

def convert_png_to_jpg_with_progress(source_folder, target_folder):
    """
    Convert all PNG images in a folder to JPG format with a progress bar.

    Every regular file directly inside ``source_folder`` whose name ends in
    ``.png`` (matched case-insensitively) is re-encoded and written to
    ``target_folder/<basename>.jpg``. Sub-directories are not searched.
    Files are processed in sorted name order so runs are deterministic.

    Parameters:
    - source_folder: Folder containing PNG images.
    - target_folder: Folder where JPG images will be saved. It is created
      (including missing parents) when it does not exist yet.

    Returns:
    - None. The function is used for its side effect of writing JPG files.
    """
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(target_folder, exist_ok=True)

    # Collect PNG files only: match the extension regardless of case and skip
    # directories that merely happen to be named "*.png".
    png_files = sorted(
        name
        for name in os.listdir(source_folder)
        if name.lower().endswith(".png")
        and os.path.isfile(os.path.join(source_folder, name))
    )

    for filename in tqdm(png_files, desc="Converting PNG to JPG"):
        basename, _ = os.path.splitext(filename)
        source_path = os.path.join(source_folder, filename)
        target_path = os.path.join(target_folder, basename + ".jpg")

        # The context manager closes the source file as soon as it is saved.
        with Image.open(source_path) as img:
            # JPEG cannot store an alpha channel or a palette, so the image
            # is converted to plain RGB first (any alpha is discarded).
            img.convert('RGB').save(target_path, "JPEG")



In [8]:
# Example usage: convert the PNG images of one identity into JPG images.
# `idname` selects the per-identity sub-folder used in both paths below.
idname = 'justin'
source_folder = f'/root/autodl-tmp/FlashAvatar-code/metrical-tracker/output/{idname}/input'  # source folder containing the PNG images
target_folder = f'/root/autodl-tmp/FlashAvatar-code/dataset/{idname}/imgs'  # destination folder for the JPG images

convert_png_to_jpg_with_progress(source_folder, target_folder)

Converting PNG to JPG: 100%|██████████| 256/256 [00:02<00:00, 106.66it/s]


### 图像分割
准备 mask 图像（segmentation）

In [4]:
from logger import setup_logger
from model import BiSeNet
import torch
import os
import os.path as osp
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import cv2
from tqdm import tqdm

def evaluate(dspth='./data', respth_alpha='./alpha', respth_parsing='./parsing', cp='model_final_diss.pth'):
    """
    Segment every JPG image in a folder with BiSeNet and save binary masks.

    Each image is resized to 512x512 and run through the network, and the
    per-pixel argmax label map is turned into three 0/255 masks. The masks
    are written at the 512x512 network resolution, not the original size:

    - ``<respth_alpha>/<name>_segment.jpg``: pixels whose label is in 1..17.
    - ``<respth_parsing>/<name>_neckhead.png``: the same mask with label 16
      (treated as clothes by this notebook) cleared.
    - ``<respth_parsing>/<name>_mouth.png``: pixels whose label is 11, 12
      or 13 (treated as the mouth region by this notebook).

    Parameters:
    - dspth: Folder containing the input ``.jpg`` / ``.jpeg`` images
      (extension matched case-insensitively).
    - respth_alpha: Output folder for the ``_segment.jpg`` masks.
    - respth_parsing: Output folder for the ``_neckhead.png`` and
      ``_mouth.png`` masks.
    - cp: Checkpoint file name, looked up under ``res/cp`` relative to the
      current working directory.

    Raises:
    - IOError: if OpenCV reports that a mask could not be written.
    """
    os.makedirs(respth_alpha, exist_ok=True)
    os.makedirs(respth_parsing, exist_ok=True)

    # Use the GPU when one is present, otherwise fall back to the CPU so the
    # function still works on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    n_classes = 19  # the network predicts 19 classes in total
    net = BiSeNet(n_classes=n_classes)
    net.to(device)
    save_pth = osp.join('res/cp', cp)
    # map_location lets a checkpoint saved on a GPU load on any device.
    net.load_state_dict(torch.load(save_pth, map_location=device))
    net.eval()

    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    # Loop-invariant label sets, built once instead of once per image.
    segment_labels = np.arange(1, 18)  # labels 1..17
    mouth_labels = [11, 12, 13]
    clothes_label = 16

    # Sorted for a deterministic processing order.
    image_names = sorted(
        name for name in os.listdir(dspth)
        if name.lower().endswith(('.jpg', '.jpeg'))
    )

    with torch.no_grad():
        for image_path in tqdm(image_names, desc='Processing'):
            # Close the file promptly, and force 3-channel RGB so grayscale
            # or CMYK JPEGs do not break the 3-channel Normalize above.
            with Image.open(osp.join(dspth, image_path)) as img:
                image = img.convert('RGB').resize((512, 512), Image.BILINEAR)

            batch = to_tensor(image).unsqueeze(0).to(device)
            out = net(batch)[0]
            # Take the argmax on the device so only the label map, not the
            # full logit tensor, is copied to host memory.
            parsing = out.squeeze(0).argmax(0).cpu().numpy()

            # Mask of every pixel labelled 1..17.
            mask_segment = np.isin(parsing, segment_labels).astype(np.uint8)

            # Neck + head mask: the mask above minus the clothes label.
            mask_neck_head = mask_segment.copy()
            mask_neck_head[parsing == clothes_label] = 0

            # Mouth mask.
            mask_mouth = np.isin(parsing, mouth_labels).astype(np.uint8)

            stem = os.path.splitext(image_path)[0]
            outputs = (
                (osp.join(respth_alpha, stem + '_segment.jpg'), mask_segment),
                (osp.join(respth_parsing, stem + '_neckhead.png'), mask_neck_head),
                (osp.join(respth_parsing, stem + '_mouth.png'), mask_mouth),
            )
            for out_path, mask in outputs:
                # cv2.imwrite signals failure through its return value
                # only, so surface it instead of silently dropping a mask.
                if not cv2.imwrite(out_path, mask * 255):
                    raise IOError(f"Failed to write mask image: {out_path}")



In [9]:
# `idname` selects the per-identity dataset sub-folder used in every path below.
idname = 'justin'

input_dir = f'/root/autodl-tmp/FlashAvatar-code/dataset/{idname}/imgs'  # input directory with the JPG images
output_dir_alpha = f'/root/autodl-tmp/FlashAvatar-code/dataset/{idname}/alpha' # output directory for the "_segment" masks
output_dir_parsing = f'/root/autodl-tmp/FlashAvatar-code/dataset/{idname}/parsing' # output directory for the remaining masks (neck+head, mouth)

# Run the segmentation with the '79999_iter.pth' checkpoint.
evaluate(dspth=input_dir, respth_alpha=output_dir_alpha, respth_parsing=output_dir_parsing, cp='79999_iter.pth')

Processing: 100%|██████████| 256/256 [00:09<00:00, 27.40it/s]
