In [1]:
# `a` is keyed by strings, so the integer lookup a[0] raises KeyError.
# Catch it and report it instead of letting the cell abort, so the notebook
# still survives "Restart Kernel -> Run All".
a = {'a': 10, 'b': 20, 'c': 30}
try:
    print(type(a[0]))
except KeyError as exc:
    # str(KeyError(0)) is "0", so this prints "KeyError: 0".
    print(f"KeyError: {exc}")

KeyError: 0

In [2]:
import torch
import torch.nn as nn

def _fold_to_canvas(feat, H_target, W_target):
    """Reinterpret a 4-D tensor [b, c, h, w] as [b, c_new, H_target, W_target].

    This is a plain row-major reshape of each sample's c*h*w elements; it is
    NOT a true pixel-shuffle / depth-to-space permutation. The element count
    per sample is preserved, so c_new = c*h*w / (H_target*W_target).

    Raises:
        ValueError: if c*h*w is not a multiple of H_target*W_target.
    """
    b, c, h, w = feat.shape
    # Integer arithmetic keeps the divisibility check exact (no float rounding).
    c_new, remainder = divmod(c * h * w, H_target * W_target)
    if remainder != 0:
        raise ValueError(f"Feature shape ({c},{h},{w}) cannot be reshaped to ({H_target},{W_target})")
    # reshape (not view) so non-contiguous inputs, e.g. channels_last, also work.
    return feat.reshape(b, c_new, H_target, W_target)


def Get_feature_all(feature_hidden_1, feature_skip_list_1):
    """
    Fold the bottleneck feature and every skip feature onto one common spatial
    size and concatenate them along the channel dimension.

    The common ("canvas") size is 8x the spatial size of ``feature_hidden_1``.
    Each tensor is folded with a flat reshape that preserves the per-sample
    element count, so it can be undone by slicing the channels and reshaping
    back to the original shape.

    Args:
        feature_hidden_1: 4-D tensor [B, C, H_s, W_s], e.g. [B, 512, 8, 8].
        feature_skip_list_1: sequence of 4-D tensors (first layer first),
            e.g. 12 tensors from (128, 64, 64) down to (512, 8, 8).
            The sequence is only read; it is NOT emptied by this call.
            ``None`` is treated as "no skip features".

    Returns:
        Tensor [B, Total_Channels, 8*H_s, 8*W_s]. Channel order is: hidden
        first, then the skip features from the LAST list element to the first.

    Raises:
        ValueError: if a tensor's C*H*W is not a multiple of the canvas area.
    """
    _, _, H_s, W_s = feature_hidden_1.shape

    # Canvas size is fixed at 8x the hidden feature's spatial size.
    H_target = H_s * 8
    W_target = W_s * 8

    # Snapshot the skips so the caller's list is left untouched; a decoder
    # that later pops from the same list keeps working.
    skips = list(feature_skip_list_1) if feature_skip_list_1 is not None else []

    # Hidden goes first, then skips in reverse (last layer -> first layer).
    processed_features = [_fold_to_canvas(feature_hidden_1, H_target, W_target)]
    processed_features.extend(
        _fold_to_canvas(feat, H_target, W_target) for feat in reversed(skips)
    )

    # All entries now share the spatial size; join along channels.
    return torch.cat(processed_features, dim=1)

# --- Demo / smoke test ---
if __name__ == '__main__':
    # Synthetic inputs
    B = 2
    feature_hidden = torch.randn(B, 512, 8, 8)

    # Skip-list layout as (repeat, channels, height, width), first layer first:
    #   Layer 0-2:   (128, 64, 64)
    #   Layer 3:     (128, 32, 32)
    #   Layer 4-5:   (256, 32, 32)
    #   Layer 6:     (256, 16, 16)
    #   Layer 7-8:   (384, 16, 16)
    #   Layer 9:     (384, 8, 8)
    #   Layer 10-11: (512, 8, 8)
    skip_layout = [
        (3, 128, 64, 64),
        (1, 128, 32, 32),
        (2, 256, 32, 32),
        (1, 256, 16, 16),
        (2, 384, 16, 16),
        (1, 384, 8, 8),
        (2, 512, 8, 8),
    ]
    skip_list = [
        torch.randn(B, channels, height, width)
        for repeat, channels, height, width in skip_layout
        for _ in range(repeat)
    ]

    # Run the packing function and report the fused shape
    out = Get_feature_all(feature_hidden, skip_list)
    print(f"Final Output Shape: {out.shape}")
    # Expected channel budget on the 64x64 canvas:
    # Hidden(8) + L11(8) + L10(8) + L9(6) + L8(24) + L7(24) + L6(16) +
    # L5(64) + L4(64) + L3(32) + L2(128) + L1(128) + L0(128) = 638

Final Output Shape: torch.Size([2, 638, 64, 64])


In [3]:
import torch

# Default original shapes (C, H, W) used when the caller supplies none.
_DEFAULT_HIDDEN_SHAPE = (512, 8, 8)
_DEFAULT_SKIP_SHAPES = (
    (128, 64, 64),  # Layer 0
    (128, 64, 64),  # Layer 1
    (128, 64, 64),  # Layer 2
    (128, 32, 32),  # Layer 3
    (256, 32, 32),  # Layer 4
    (256, 32, 32),  # Layer 5
    (256, 16, 16),  # Layer 6
    (384, 16, 16),  # Layer 7
    (384, 16, 16),  # Layer 8
    (384, 8, 8),    # Layer 9
    (512, 8, 8),    # Layer 10
    (512, 8, 8),    # Layer 11
)


def Split_feature_all(feature_hidden_fusion, hidden_shape=None, skip_shapes=None):
    """
    Split a fused feature map back into the hidden feature and the skip list.

    The input is expected to be packed along the channel dimension as:
    hidden first, then the skip features from the LAST layer to the first.
    Each packed entry occupies C*H*W / (H_canvas*W_canvas) channels and is
    restored with a flat reshape to its original (C, H, W).

    Args:
        feature_hidden_fusion: 4-D tensor [B, Total_C, H_canvas, W_canvas],
            e.g. [B, 638, 64, 64] for the default shapes.
        hidden_shape: original (C, H, W) of the hidden feature.
            Defaults to (512, 8, 8).
        skip_shapes: original (C, H, W) of each skip feature, first layer
            first. Defaults to the 12-layer table from (128, 64, 64) down to
            (512, 8, 8).

    Returns:
        feature_hidden: tensor [B, *hidden_shape].
        feature_skip_list: list of tensors, first layer first, one per entry
            of ``skip_shapes``.

    Raises:
        ValueError: if a shape's C*H*W is not a multiple of the canvas area,
            or if the input has fewer channels than the shapes require.
            Extra trailing channels are ignored.
    """
    B, Total_C, H, W = feature_hidden_fusion.shape
    canvas_area = H * W

    if hidden_shape is None:
        hidden_shape = _DEFAULT_HIDDEN_SHAPE
    if skip_shapes is None:
        skip_shapes = _DEFAULT_SKIP_SHAPES

    # Packing order along channels: hidden, then skips last -> first.
    packed_shapes = [tuple(hidden_shape)]
    packed_shapes.extend(tuple(shape) for shape in reversed(list(skip_shapes)))

    # Channel width each entry occupies on the canvas (exact integer math).
    widths = []
    for (org_c, org_h, org_w) in packed_shapes:
        width, remainder = divmod(org_c * org_h * org_w, canvas_area)
        if remainder != 0:
            raise ValueError(
                f"Feature shape ({org_c},{org_h},{org_w}) does not fit a ({H},{W}) canvas"
            )
        widths.append(width)

    required_channels = sum(widths)
    if Total_C < required_channels:
        raise ValueError(
            f"Input has {Total_C} channels but {required_channels} are required"
        )

    restored = []
    current_channel_ptr = 0
    for (org_c, org_h, org_w), width in zip(packed_shapes, widths):
        feat_slice = feature_hidden_fusion[:, current_channel_ptr:current_channel_ptr + width, :, :]
        # reshape (not view): works for non-contiguous input such as
        # channels_last; it returns a view when possible and a copy otherwise.
        restored.append(feat_slice.reshape(B, org_c, org_h, org_w))
        current_channel_ptr += width

    feature_hidden = restored[0]
    # restored[1:] is [last layer, ..., first layer]; flip to first-layer-first.
    feature_skip_list = restored[1:][::-1]

    return feature_hidden, feature_skip_list

# --- Demo / smoke test ---
if __name__ == '__main__':
    # Synthetic fused input [B, 638, 64, 64]
    B = 2
    feature_fusion = torch.randn(B, 638, 64, 64)

    h, skips = Split_feature_all(feature_fusion)

    # Expected: hidden [2, 512, 8, 8] and 12 skip tensors
    print(f"Hidden shape: {h.shape}")
    print(f"Skip list length: {len(skips)}")
    # Expected: Layer 0 [2, 128, 64, 64], Layer 9 [2, 384, 8, 8],
    # Layer 11 [2, 512, 8, 8]
    for layer_idx in (0, 9, 11):
        print(f"Layer {layer_idx} shape: {skips[layer_idx].shape}")

Hidden shape: torch.Size([2, 512, 8, 8])
Skip list length: 12
Layer 0 shape: torch.Size([2, 128, 64, 64])
Layer 9 shape: torch.Size([2, 384, 8, 8])
Layer 11 shape: torch.Size([2, 512, 8, 8])


In [1]:
# Write a fixed greeting to stdout.
print("hello world")

hello world


In [2]:
import yaml

# Nested mapping stored under the 'name' key of the settings below.
_name_section = {
    'first': 'zhangsan',
    'second': 'lisi',
}

# Settings dictionary to serialize.
data = {
    "batch_size": 8,
    "lr": 0.0001,
    "use_ddp": True,
    'name': _name_section,
}

# Serialize the settings as YAML into config.yaml (overwrites any existing file).
with open("config.yaml", "w") as f:
    yaml.dump(data, f)

In [2]:
from models.modules.UNet_arch import UNet
import torch 
def _print_stats(label, tensor):
    """Print the shape and the min ~ max value range of ``tensor`` under ``label``."""
    print(f'{label}.shape: {tensor.shape}')
    print(f'{label}.range: {tensor.min().item()} ~ {tensor.max().item()}')


# Build the model and a random input image batch.
model = UNet(in_ch=3, out_ch=3, ch=8, ch_mult=[4, 8, 8, 16], embed_dim=8)
x = torch.rand(1, 3, 512, 512)

pretrained_path = 'pretrained/AutoEncoder.pth'
# map_location='cpu': the model and input live on the CPU here, and without it
# a checkpoint that stores CUDA tensors cannot be loaded on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints,
# and consider weights_only=True if the installed torch version supports it.
state_dict = torch.load(pretrained_path, map_location='cpu')
model.load_state_dict(state_dict, strict=True)
model.eval()

# Inference only: no gradients are needed for inspecting activations.
with torch.no_grad():
    # encode() returns the encoded tensor plus an iterable of intermediate tensors.
    h, hs = model.encode(x)
    _print_stats('h', h)
    for i, hi in enumerate(hs):
        _print_stats(f'hs[{i}]', hi)

    # Full forward pass.
    out = model(x)
    _print_stats('out', out)
    


h.shape: torch.Size([1, 8, 64, 64])
h.range: -0.13313032686710358 ~ 0.8203734755516052
hs[0].shape: torch.Size([1, 8, 512, 512])
hs[0].range: -0.8120931386947632 ~ 1.6001704931259155
hs[1].shape: torch.Size([1, 8, 512, 512])
hs[1].range: -0.7877449989318848 ~ 1.5025765895843506
hs[2].shape: torch.Size([1, 8, 512, 512])
hs[2].range: -0.8144247531890869 ~ 1.4652830362319946
hs[3].shape: torch.Size([1, 32, 256, 256])
hs[3].range: -0.9149216413497925 ~ 1.4275686740875244
hs[4].shape: torch.Size([1, 32, 256, 256])
hs[4].range: -1.0509828329086304 ~ 1.6208029985427856
hs[5].shape: torch.Size([1, 64, 128, 128])
hs[5].range: -1.751322865486145 ~ 1.5689940452575684
hs[6].shape: torch.Size([1, 64, 128, 128])
hs[6].range: -1.8444328308105469 ~ 1.565863847732544
hs[7].shape: torch.Size([1, 64, 64, 64])
hs[7].range: -2.1619269847869873 ~ 2.4993879795074463
hs[8].shape: torch.Size([1, 64, 64, 64])
hs[8].range: -1.8443394899368286 ~ 1.8846741914749146
out.shape: torch.Size([1, 3, 512, 512])
out.range

In [None]:

# TODO: the design here almost certainly has problems and needs to be revisited.