In [1]:
import os
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import gradio as gr
import functools
import socket


# 设备配置
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"使用设备: {device}")


# 1. 模型定义（严格按照原始pix2pix架构）
class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, 
                 norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
            
        if input_nc is None:
            input_nc = outer_nc
            
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)

        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]

            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:
            return torch.cat([x, self.model(x)], 1)


class UnetGenerator(nn.Module):
    """严格按照原始pix2pix的U-Net结构实现"""
    def __init__(self, input_nc=3, output_nc=3, num_downs=8, ngf=64,
                 norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(UnetGenerator, self).__init__()
        
        # 构建U-Net结构
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, 
                                            submodule=None, norm_layer=norm_layer, innermost=True)
        
        # 中间层
        for i in range(num_downs - 5):
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, 
                                                submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
        
        # 逐渐减少通道数
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        
        # 最外层
        unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, 
                                           submodule=unet_block, outermost=True, norm_layer=norm_layer)
        
        self.model = unet_block

    def forward(self, input):
        return self.model(input)


# 2. 辅助函数
def find_available_port(start_port=7868, end_port=7968):
    for port in range(start_port, end_port + 1):
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                s.bind(('0.0.0.0', port))
                return port
        except OSError:
            continue
    raise OSError(f"无法在范围 {start_port}-{end_port} 内找到可用端口")


def tensor2im(input_image, imtype=np.uint8):
    """增强版：处理各种格式的图像转换"""
    try:
        if isinstance(input_image, torch.Tensor):
            # 处理PyTorch张量
            image_tensor = input_image.data.cpu().float().numpy()
            
            # 调试信息
            print(f"张量形状: {image_tensor.shape}")
            
            # 移除batch维度
            if image_tensor.ndim == 4:
                image_tensor = image_tensor[0]
                
            # 转换通道维度：(C, H, W) -> (H, W, C)
            if image_tensor.shape[0] in [1, 3]:
                image_tensor = np.transpose(image_tensor, (1, 2, 0))
                
            # 处理单通道图像：复制为三通道
            if image_tensor.shape[-1] == 1:
                image_tensor = np.repeat(image_tensor, 3, axis=-1)
                
            # 归一化到[0, 255]
            image_tensor = (image_tensor + 1) / 2.0 * 255.0
            image_tensor = np.clip(image_tensor, 0, 255)
            
            return image_tensor.astype(imtype)
        
        # 处理numpy数组
        if isinstance(input_image, np.ndarray):
            # 调试信息
            print(f"numpy数组形状: {input_image.shape}")
            
            # 处理单通道图像
            if input_image.ndim == 2:
                return np.repeat(input_image[:, :, np.newaxis], 3, axis=2).astype(imtype)
            if input_image.ndim == 3 and input_image.shape[-1] == 1:
                return np.repeat(input_image, 3, axis=2).astype(imtype)
            return input_image.astype(imtype)
            
        # 处理PIL图像
        img = input_image.convert('RGB')
        return np.array(img).astype(imtype)
        
    except Exception as e:
        print(f"张量转换错误: {e}")
        # 打印详细的错误堆栈
        import traceback
        traceback.print_exc()
        return np.zeros((256, 256, 3), dtype=imtype)


def preprocess_image(image, size=(256, 256)):
    """确保所有输入图像都是RGB格式，形状正确"""
    if image is None:
        return None
        
    # 转换为PIL图像
    if isinstance(image, np.ndarray):
        print(f"预处理前numpy数组形状: {image.shape}")
        # 处理单通道numpy数组
        if image.ndim == 2 or (image.ndim == 3 and image.shape[-1] == 1):
            image = Image.fromarray(image.squeeze(), mode='L').convert('RGB')
        else:
            image = Image.fromarray(image)
    
    # 确保是RGB格式
    if image.mode != 'RGB':
        print(f"转换图像模式从 {image.mode} 到 RGB")
        image = image.convert('RGB')
    
    # 调整大小并居中裁剪
    transform = transforms.Compose([
        transforms.Resize(size, Image.BICUBIC),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    
    # 应用转换并返回
    tensor = transform(image)
    print(f"预处理后张量形状: {tensor.shape}")
    return tensor.unsqueeze(0)


# 3. 模型加载（改进权重匹配和调试信息）
def load_model(model_path):
    """严格按照原始pix2pix的U-Net结构加载模型"""
    try:
        # 创建模型（严格匹配原始pix2pix架构）
        model = UnetGenerator(
            input_nc=3,
            output_nc=3,
            num_downs=8,
            ngf=64,  # 关键：与训练时ngf=64匹配
            norm_layer=functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True),
            use_dropout=False
        )
        
        # 打印模型结构用于调试
        print("模型结构:")
        print(model)
        
        # 加载权重
        print(f"从 {model_path} 加载模型权重...")
        checkpoint = torch.load(model_path, map_location=device, weights_only=True)
        
        # 尝试不同的权重键名
        possible_keys = ['generator_state_dict', 'netG', 'state_dict', 'model']
        gen_weights = None
        
        for key in possible_keys:
            if key in checkpoint:
                gen_weights = checkpoint[key]
                print(f"从检查点中找到权重键: {key}")
                break
        
        # 如果没找到特定键，尝试直接使用检查点
        if gen_weights is None:
            gen_weights = checkpoint
            print("使用整个检查点作为权重")
        
        # 清理权重键名（移除可能的前缀）
        cleaned_weights = {}
        prefixes = ['module.', 'netG.', 'model.']
        
        for k, v in gen_weights.items():
            # 尝试移除常见前缀
            cleaned_k = k
            for prefix in prefixes:
                if cleaned_k.startswith(prefix):
                    cleaned_k = cleaned_k[len(prefix):]
                    break
            cleaned_weights[cleaned_k] = v
        
        # 非严格加载（仅匹配可兼容的层）
        model_dict = model.state_dict()
        matched_weights = {}
        unmatched_weights = {}
        
        for k, v in cleaned_weights.items():
            if k in model_dict and v.shape == model_dict[k].shape:
                matched_weights[k] = v
            else:
                unmatched_weights[k] = v
        
        print(f"匹配的权重: {len(matched_weights)}/{len(model_dict)}")
        print(f"未匹配的权重: {len(unmatched_weights)}")
        
        # 打印未匹配的权重用于调试
        if len(unmatched_weights) > 0:
            print("未匹配的权重:")
            for k in unmatched_weights:
                if k in model_dict:
                    print(f"  {k}: 形状不匹配 - 权重形状 {unmatched_weights[k].shape}, 模型期望 {model_dict[k].shape}")
                else:
                    print(f"  {k}: 模型中不存在此键")
        
        # 更新模型权重
        model_dict.update(matched_weights)
        model.load_state_dict(model_dict, strict=False)
        
        model = model.to(device)
        model.eval()
        print(f"成功加载模型（匹配 {len(matched_weights)}/{len(model_dict)} 层）")
        return model
    except Exception as e:
        print(f"加载模型失败: {e}")
        # 打印详细的错误堆栈
        import traceback
        traceback.print_exc()
        return None


# 4. 生成与界面逻辑
def generate_image(input_image, model_path, dataset=None, dataset_index=None):
    try:
        print("\n===== 开始生成图像 =====")
        
        # 加载模型
        model = load_model(model_path)
        if model is None:
            return None, None, None, "模型加载失败（结构与权重不兼容）"
        
        # 处理输入
        if input_image is not None:
            # 打印输入图像信息用于调试
            if isinstance(input_image, np.ndarray):
                print(f"输入图像类型: numpy数组, 形状: {input_image.shape}, 数据类型: {input_image.dtype}")
            elif isinstance(input_image, Image.Image):
                print(f"输入图像类型: PIL图像, 模式: {input_image.mode}, 大小: {input_image.size}")
            
            input_tensor = preprocess_image(input_image)
            print(f"输入张量形状: {input_tensor.shape}")
        else:
            if dataset is None or dataset_index is None:
                return None, None, None, "请上传图像或选择示例"
            data = dataset[int(dataset_index)]
            input_tensor = data['A'].unsqueeze(0)
            print(f"从数据集中获取的输入张量形状: {input_tensor.shape}")
        
        # 生成图像
        with torch.no_grad():
            print("开始模型推理...")
            output_tensor = model(input_tensor.to(device))
            print(f"输出张量形状: {output_tensor.shape}")
        
        # 转换为图像
        input_img = Image.fromarray(tensor2im(input_tensor))
        gen_img = Image.fromarray(tensor2im(output_tensor))
        
        # 获取真实图像
        real_img = Image.new('RGB', (256, 256), color='gray')
        if dataset and dataset_index is not None:
            data = dataset[int(dataset_index)]
            real_img = Image.fromarray(tensor2im(data['B']))
        
        print("===== 图像生成成功 =====")
        return input_img, gen_img, real_img, f"生成成功（匹配 {len(model.state_dict())//2}/{len(model.state_dict())} 层）"
    
    except Exception as e:
        print(f"生成图像时出错: {e}")
        # 打印详细的错误堆栈
        import traceback
        traceback.print_exc()
        return None, None, None, f"错误: {str(e)}"


def create_interface(model_path, dataset=None):
    with gr.Blocks(title="草图转真实图像") as demo:
        gr.Markdown("# 🎨 草图转建筑图像")
        gr.Markdown("左：输入草图 | 中：生成图像 | 右：真实图像")
        
        with gr.Row():
            with gr.Column(scale=1):
                input_img = gr.Image(label="输入草图", type="pil")
                if dataset and len(dataset) > 0:
                    idx_slider = gr.Slider(0, min(50, len(dataset)-1), 0, step=1, label="示例索引")
                else:
                    idx_slider = None
                gen_btn = gr.Button("生成图像", variant="primary")
                status = gr.Textbox(label="状态", value="就绪")
            
            with gr.Column(scale=3):
                with gr.Row():
                    input_out = gr.Image(label="输入")
                    gen_out = gr.Image(label="生成")
                    real_out = gr.Image(label="真实")
        
        # 绑定事件
        if idx_slider is not None:
            gen_btn.click(
                fn=lambda img, idx: generate_image(img, model_path, dataset, idx),
                inputs=[input_img, idx_slider],
                outputs=[input_out, gen_out, real_out, status]
            )
        else:
            gen_btn.click(
                fn=lambda img: generate_image(img, model_path),
                inputs=[input_img],
                outputs=[input_out, gen_out, real_out, status]
            )
    
    return demo


# 5. 主函数
if __name__ == "__main__":
    class AlignedDataset:
        def __init__(self, root, phase='val'):
            self.root = os.path.join(root, phase)
            self.images = [f for f in os.listdir(self.root) if f.endswith(('.png', '.jpg'))]
        
        def __getitem__(self, idx):
            img = Image.open(os.path.join(self.root, self.images[idx])).convert('RGB')
            w, h = img.size
            w2 = w // 2
            return {
                'A': transforms.ToTensor()(img.crop((w2, 0, w, h))),  # 右侧草图
                'B': transforms.ToTensor()(img.crop((0, 0, w2, h)))   # 左侧真实
            }
        
        def __len__(self):
            return len(self.images)
    
    model_path = "checkpoints/pix2pix_epoch_100.pth"  # 替换为你的模型路径
    dataset = AlignedDataset(root="datasets/facades", phase="val")
    
    demo = create_interface(model_path, dataset)
    port = find_available_port()
    demo.launch(server_name="0.0.0.0", server_port=port)

使用设备: cuda


FileNotFoundError: [WinError 3] 系统找不到指定的路径。: 'datasets/facades\\val'