In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import functools
import copy
from model.Component_Attention_Module import ComponentAttentionModule

from model.VectorQuantizer import  VectorQuantizer
class Generator(nn.Module):
    """
    Generator
    """

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(Generator,self).__init__()

        # codebook使用VQ-VAE进行编码
        num_embeddings = 8192 # 嵌入向量数量，过多容易过拟合，过少容易欠拟合
        embedding_dim = 512*8*8 # 512*1*1
        commitment_cost = 0.25
        self.vq = VectorQuantizer(num_embeddings, embedding_dim, commitment_cost)

        # MHA多头注意力机制，输入input和vq生成的Query
        self.MHA = ComponentAttentionModule()

        # construct unet structure
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True,vq=self.vq,HMA=self.MHA)  # add the innermost layer
        for i in range(num_downs - 5):          # add intermediate layers with ngf * 8 filters
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout,nodown=True)
        # gradually reduce the number of filters from ngf * 8 to ngf
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        for i in range(num_downs - 5):          # add intermediate layers with ngf * 8 filters
            unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout,nodown=True)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        for i in range(num_downs - 5):          # add intermediate layers with ngf * 8 filters
            unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout,nodown=True)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        for i in range(num_downs - 5):          # add intermediate layers with ngf * 8 filters
            unet_block = UnetSkipConnectionBlock(ngf, ngf, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout,nodown=True)
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)  # add the outermost layer

    def forward(self, input_s):
        """Standard forward"""
        return self.model(input_s)

class UnetSkipConnectionBlock(nn.Module):
    """Defines the Unet submodule with skip connection.
        X -------------------identity----------------------
        |-- downsampling -- |submodule| -- upsampling --|
    """

    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False,HMA = None,vq = None,nodown = None):
        """Construct a Unet submodule with skip connections.

        Parameters:
            outer_nc (int) -- the number of filters in the outer conv layer
            inner_nc (int) -- the number of filters in the inner conv layer
            input_nc (int) -- the number of channels in input images/features
            submodule (UnetSkipConnectionBlock) -- previously defined submodules
            outermost (bool)    -- if this module is the outermost module
            innermost (bool)    -- if this module is the innermost module
            norm_layer          -- normalization layer
            use_dropout (bool)  -- if use dropout layers.
        """
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        self.innermost = innermost
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)

        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]

            self.down_s = nn.Sequential(*[copy.deepcopy(layer) for layer in down])
            self.up = nn.Sequential(*up)
            self.submodule  = submodule
        elif innermost:
             # 原有的下采样层
            down = [downrelu, downconv]
            self.down_s = nn.Sequential(*[copy.deepcopy(layer) for layer in down])       
            # 定义上采样层
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            up = [uprelu, upconv, upnorm]
            self.up = nn.Sequential(*up)
            #生成codebook中最相近的vq
            self.vq = vq
            # MHA多头注意力机制，输入input和vq生成的Query
            self.MHA = HMA

        elif nodown:
            nodownconv = nn.Conv2d(input_nc, inner_nc, kernel_size=3,
                     stride=1, padding=1, bias=use_bias)
            noupconv = nn.Conv2d(inner_nc * 2, outer_nc, kernel_size=3,
                     stride=1, padding=1, bias=use_bias)
            
            down = [downrelu, nodownconv, downnorm]
            up = [uprelu, noupconv, upnorm]
            self.down_s = nn.Sequential(*[copy.deepcopy(layer) for layer in down])
            self.up = nn.Sequential(*up)
            self.submodule  = submodule

        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]
            self.down_s = nn.Sequential(*[copy.deepcopy(layer) for layer in down])
            self.up = nn.Sequential(*up)
            self.submodule  = submodule


    def forward(self, style):
        if self.outermost:
            down_s = self.down_s(style)
            submoduled_s,vq_loss_s= self.submodule(down_s)
            up_s = self.up(submoduled_s)
            return up_s,vq_loss_s
        else:   # add skip connections
            if self.innermost:
                # 在最内层先进行下采样
                down_s = self.down_s(style)            
                # 然后并行地执行MHA和VQ操作               
                query_s ,vq_loss_s  = self.vq(down_s)
                MHA_s = self.MHA(down_s,query_s)
                # 执行上采样
                up_s = self.up(MHA_s)
                return torch.cat([style, up_s], 1),vq_loss_s
            else:
                down_s = self.down_s(style)
                submoduled_s,vq_loss_s = self.submodule(down_s)
                up_s = self.up(submoduled_s)
                # 对于非最内层，添加skip连接和子模块            
                return torch.cat([style, up_s], 1) ,vq_loss_s

In [2]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
class CustomDataset(Dataset):
    def __init__(self, root, transform=None):
        self.image_paths = root
        self.imgs = self.read_file(self.image_paths)
        self.transform = transform

    def read_file(self, path):
        """从文件夹中读取数据"""
        files_list = os.listdir(path)
        file_path_list = [os.path.join(path, img) for img in files_list]
        file_path_list.sort()  # 如果你需要特定的顺序则保留这一行
        return file_path_list

    def __len__(self):
        return len(self.imgs)  # 返回图片列表的长度

    def __getitem__(self, index):
        image_path = self.imgs[index]  # 使用图片列表中的路径
        image = Image.open(image_path).convert('RGB')
        
        if self.transform:
            image = self.transform(image)
        
        return image


In [3]:
# 定义一个转换操作，比如转换成张量并且归一化
transform = transforms.Compose([transforms.Resize((256, 256)),
                                          transforms.ToTensor()])

train_data_root = "./data/CUHK/trainB"

# 创建数据集
dataset = CustomDataset(root=train_data_root, transform=transform)

# 创建数据加载器
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
dataiter = iter(dataloader)
images = next(dataiter)

# 打印图像和标签的形状
print('Images shape:', images.shape)

Images shape: torch.Size([32, 3, 256, 256])


In [4]:
from torch import optim
model = Generator(input_nc=3, output_nc=3,num_downs=8)


learning_rate = 2e-4

# 提取 model.vq 的参数
vq_params = set(model.vq.parameters())
# 提取除 model.vq 之外的所有参数
other_params = filter(lambda p: p not in vq_params, model.parameters())

optimizer_vq = optim.Adam(vq_params, lr=learning_rate, amsgrad=False)
optimizer = optim.Adam(other_params, lr=learning_rate, amsgrad=False)
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model,device_ids=[0, 1])
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)


DataParallel(
  (module): Generator(
    (vq): VectorQuantizer(
      (_embedding): Embedding(8192, 32768)
    )
    (MHA): ComponentAttentionModule(
      (layers): ModuleList(
        (0-7): 8 x MultiHeadAttention(
          (linears_key): Linear(in_features=512, out_features=512, bias=False)
          (linears_value): Linear(in_features=512, out_features=512, bias=False)
          (linears_query): Linear(in_features=512, out_features=512, bias=False)
          (multihead_concat_fc): Linear(in_features=512, out_features=512, bias=False)
          (layer_norm): LayerNorm((512,), eps=1e-06, elementwise_affine=True)
        )
      )
    )
    (model): UnetSkipConnectionBlock(
      (down_s): Sequential(
        (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      )
      (up): Sequential(
        (0): ReLU(inplace=True)
        (1): ConvTranspose2d(128, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (2): Tanh()
      )
      (submodule)

In [5]:
# 在训练循环中使用数据加载器
for i in range(50000):
    data = next(iter(dataloader)).cuda()
    # 在这里进行你的训练...
    # print(data.shape)
    output_image,loss_vq = model(data)

    if loss_vq.dim() > 0:  # 检查loss_vq_G是否为标量
            loss_vq = loss_vq.mean()
    optimizer_vq.zero_grad()
    loss_vq.backward(retain_graph=True)
    optimizer_vq.step()

    loss_rec = F.mse_loss(data,output_image)
    if loss_rec.dim() > 0:  # 检查loss_rec_G是否为标量
            loss_rec = loss_rec.mean()
    optimizer.zero_grad()
    loss_rec.backward()
    optimizer.step()
    
    if(i%100==0):
        print(f" VQ_loss : {loss_vq}; rec_loss : {loss_rec}")
    if(i%1000==0):
        # 假设 'model' 是你的神经网络模型实例，'i' 是当前的训练轮数
        torch.save(model.state_dict(), f'./model_weight/2/model_weights_epoch_{i}.pth')



 VQ_loss : 0.19587081670761108; rec_loss : 1.1529916524887085
 VQ_loss : 5.637734413146973; rec_loss : 0.015834370627999306
 VQ_loss : 5.242614269256592; rec_loss : 0.009924331679940224
 VQ_loss : 13.98647689819336; rec_loss : 0.0019414968555793166
 VQ_loss : 24.64476776123047; rec_loss : 0.0014034243067726493
 VQ_loss : 22.556791305541992; rec_loss : 0.0009998355526477098
 VQ_loss : 22.458839416503906; rec_loss : 0.0007405895739793777
 VQ_loss : 22.465723037719727; rec_loss : 0.0006783397402614355
 VQ_loss : 22.56629180908203; rec_loss : 0.00060084875440225
 VQ_loss : 21.150936126708984; rec_loss : 0.0007073736051097512
 VQ_loss : 22.227703094482422; rec_loss : 0.0004895029123872519
 VQ_loss : 20.388389587402344; rec_loss : 0.0004844893701374531
 VQ_loss : 20.69454574584961; rec_loss : 0.00042007374577224255
 VQ_loss : 20.317289352416992; rec_loss : 0.00038547543226741254
 VQ_loss : 20.094646453857422; rec_loss : 0.00038067481364123523
 VQ_loss : 20.127744674682617; rec_loss : 0.00033

: 

In [None]:
layer_outputs = {}  # 用于存储每层的输出

def hook_fn(module, input, output):
    layer_name = str(module)
    layer_outputs[layer_name] = output

for name, layer in model.named_children():
    layer.register_forward_hook(hook_fn)

# 假设 input_image 是您的输入图像
model(input_image)

import matplotlib.pyplot as plt
import numpy as np

def save_layer_output(output, layer_name, save_dir):
    # 将输出转换为可视化的形式
    # 注意：这里的转换方式可能需要根据您的具体情况调整
    output = output.detach().cpu().numpy()
    if output.ndim == 4:  # 对于卷积层的输出
        # 取第一个样本的第一个特征映射
        output_img = output[0, 0]
    else:  # 对于全连接层等的输出
        # 将输出转换为一个方形图像
        side_length = int(np.ceil(np.sqrt(output.size)))
        output_img = np.reshape(output, (side_length, side_length))

    # 保存输出图像
    plt.imsave(f'{save_dir}/{layer_name}.png', output_img, cmap='gray')

# 遍历并保存每层的输出
for layer_name, output in layer_outputs.items():
    save_layer_output(output, layer_name, 'output_images')

