# In [None]:
import collections  # 导入collections模块，用于统计和操作容器数据，如Counter
import torch  # 导入PyTorch库，用于深度学习任务
import torch.nn as nn  # 从torch中导入神经网络模块，简化模型构建
from torch.utils.data import TensorDataset  # 导入TensorDataset，用于将Tensor数据打包成数据集
import numpy as np  # 导入NumPy库，用于高效的数值计算和数组操作
from sklearn.neighbors import NearestNeighbors  # 导入最近邻算法，用于在SMOTE中寻找相邻样本
import time  # 导入time模块，用于计时
import os  # 导入os模块，用于文件和目录操作
import torch.nn.functional as F
# Attention block
class AttentionBlock(nn.Module):
    """Spatial attention with a residual connection.

    Three 1x1 convolutions produce ReLU-activated query, key and value maps.
    The channel-wise sum of ``query * key`` is soft-maxed over all spatial
    positions, giving one weight per location; that weight scales ``value``
    and the result is added back onto the input.
    """

    def __init__(self, filters):
        """
        :param filters: number of channels of the input feature map; each 1x1
            convolution also produces this many channels.
        """
        super().__init__()
        # A 1x1 kernel with padding=0 leaves H and W untouched, i.e. it is the
        # counterpart of Keras Conv2D(filters, kernel_size=1, padding='same').
        pointwise = dict(kernel_size=1, padding=0)
        self.query_conv = nn.Conv2d(filters, filters, **pointwise)
        self.key_conv = nn.Conv2d(filters, filters, **pointwise)
        self.value_conv = nn.Conv2d(filters, filters, **pointwise)

    def forward(self, x):
        """Apply attention to ``x`` and return a tensor of the same shape.

        ``x`` must be 4-D, laid out as (batch_size, filters, H, W).
        """
        # Step 1: project the input into query / key / value maps.
        q = F.relu(self.query_conv(x))
        k = F.relu(self.key_conv(x))
        v = F.relu(self.value_conv(x))

        # Step 2: element-wise product collapsed over channels -> (N, 1, H, W).
        scores = (q * k).sum(dim=1, keepdim=True)

        # Step 3: softmax across all H*W positions; flatten the spatial dims,
        # normalise, then restore the (N, 1, H, W) layout.
        batch, _, height, width = scores.shape
        flat_weights = F.softmax(scores.view(batch, 1, -1), dim=-1)
        weights = flat_weights.view(batch, 1, height, width)

        # Step 4: weight the value map (broadcast over channels) and add the
        # residual input.
        return x + weights * v
    
# Encoder
class Encoder(nn.Module):
    """Convolutional encoder that maps an image to an ``n_z``-dim latent vector.

    ``args`` is a dict supplying ``'n_channel'`` (input channels), ``'dim_h'``
    (base channel width) and ``'n_z'`` (latent size). The spatial sizes noted
    in the comments hold for a 28x28 input, for which the conv stack ends at
    1x1 so the flattened features match the ``dim_h * 8`` inputs of ``fc``.
    """

    def __init__(self, args):
        super().__init__()
        self.n_channel = args['n_channel']
        self.dim_h = args['dim_h']
        self.n_z = args['n_z']

        width = self.dim_h

        # Stem: strided conv, activation, then attention (shape-preserving).
        # For a 28x28 input: (n_channel, 28, 28) -> (dim_h, 14, 14).
        stages = [
            nn.Conv2d(width, width, 4, 2, 1) if False else
            nn.Conv2d(self.n_channel, width, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            AttentionBlock(filters=width),
        ]

        # Three down-sampling stages, each doubling the channel count:
        #   dim_h   -> dim_h*2 : 14x14 -> 7x7
        #   dim_h*2 -> dim_h*4 : 7x7   -> 3x3
        #   dim_h*4 -> dim_h*8 : 3x3   -> 1x1
        for factor in (1, 2, 4):
            in_ch = width * factor
            out_ch = in_ch * 2
            stages.extend([
                nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True),
            ])

        self.conv = nn.Sequential(*stages)

        # Linear head: dim_h*8 flattened features -> n_z latent values.
        self.fc = nn.Linear(width * 8, self.n_z)

    def forward(self, x):
        """Encode ``x`` (batch, n_channel, H, W) into a (batch, n_z) tensor."""
        feature_map = self.conv(x)                                  # (batch, dim_h*8, 1, 1) for 28x28 input
        flattened = feature_map.view(feature_map.size(0), -1)       # (batch, dim_h*8)
        return self.fc(flattened)                                   # (batch, n_z)
# Decoder
class Decoder(nn.Module):
    """Transposed-convolution decoder turning a latent vector into an image.

    ``args`` is a dict supplying ``'n_channel'`` (output channels), ``'dim_h'``
    (base channel width) and ``'n_z'`` (latent size). A linear layer lifts the
    latent vector to ``dim_h * 8`` features, which are treated as a 1x1 map and
    up-sampled 1x1 -> 3x3 -> 7x7 -> 28x28, ending with ``tanh``.
    """

    def __init__(self, args):
        super().__init__()
        self.n_channel = args['n_channel']
        self.dim_h = args['dim_h']
        self.n_z = args['n_z']

        base = self.dim_h
        widest = base * 8

        # Linear lift: (batch, n_z) -> (batch, dim_h*8), followed by ReLU.
        self.fc = nn.Sequential(nn.Linear(self.n_z, widest), nn.ReLU())

        # Without padding a transposed conv yields (in - 1) * stride + kernel.
        upsample = [
            # (dim_h*8, 1, 1) -> (dim_h*4, 3, 3)
            nn.ConvTranspose2d(widest, base * 4, kernel_size=3, stride=2),
            nn.BatchNorm2d(base * 4),
            nn.ReLU(inplace=True),
            # (dim_h*4, 3, 3) -> (dim_h*2, 7, 7)
            nn.ConvTranspose2d(base * 4, base * 2, kernel_size=3, stride=2),
            nn.BatchNorm2d(base * 2),
            nn.ReLU(inplace=True),
            # (dim_h*2, 7, 7) -> (n_channel, 28, 28)
            nn.ConvTranspose2d(base * 2, self.n_channel, kernel_size=4, stride=4),
            nn.Tanh(),
        ]
        self.deconv = nn.Sequential(*upsample)

    def forward(self, x):
        """Decode latent ``x`` (batch, n_z) into a (batch, n_channel, 28, 28) tensor."""
        hidden = self.fc(x)                                  # (batch, dim_h*8)
        seed_map = hidden.view(-1, self.dim_h * 8, 1, 1)     # (batch, dim_h*8, 1, 1)
        return self.deconv(seed_map)                         # (batch, n_channel, 28, 28)