# In [2]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader
import glob
import os
import torch.nn as nn
import torch.nn.functional as F

# In [3]:
class ConvolutionalVAE1D(nn.Module):
    """One-dimensional convolutional variational autoencoder (VAE).

    The encoder is three ``Conv1d`` layers (kernel 4, stride 2, padding 1),
    each of which halves the sequence length, followed by two linear heads
    that produce the latent mean and log-variance. The decoder mirrors it: a
    linear layer expands the latent vector, then three ``ConvTranspose1d``
    layers each double the length. The final decoder layer has no
    activation, so reconstructions are unbounded.

    Args:
        input_length: Length of the input signal. Inputs passed to
            ``forward`` must have this length. If it is not a multiple of 8
            the reconstruction has ``(input_length // 8) * 8`` samples
            rather than ``input_length``.
        latent_dim: Size of the latent vector.
        base_channels: Channel count of the first encoder layer; deeper
            layers use 2x and 4x this value.
    """

    def __init__(self, input_length=256, latent_dim=32, base_channels=16):
        super().__init__()
        self.latent_dim = latent_dim
        self.base_channels = base_channels

        # Encoder: channels go 1 -> C -> 2C -> 4C while the length halves
        # at every layer.
        self.encoder = nn.Sequential(
            nn.Conv1d(1, base_channels, 4, 2, 1), nn.ReLU(),
            nn.Conv1d(base_channels, base_channels * 2, 4, 2, 1), nn.ReLU(),
            nn.Conv1d(base_channels * 2, base_channels * 4, 4, 2, 1), nn.ReLU(),
        )

        # Three halvings, e.g. 256 -> 128 -> 64 -> 32.
        self.reduced_len = input_length // 8
        flat_dim = base_channels * 4 * self.reduced_len

        # Linear heads for the posterior parameters, and the layer that maps
        # a latent vector back to the flattened feature map.
        self.fc_mu = nn.Linear(flat_dim, latent_dim)
        self.fc_logvar = nn.Linear(flat_dim, latent_dim)
        self.fc_z = nn.Linear(latent_dim, flat_dim)

        # Decoder: mirror of the encoder. The last layer deliberately has no
        # activation function.
        self.decoder = nn.Sequential(
            nn.ConvTranspose1d(base_channels * 4, base_channels * 2, 4, 2, 1), nn.ReLU(),
            nn.ConvTranspose1d(base_channels * 2, base_channels, 4, 2, 1), nn.ReLU(),
            nn.ConvTranspose1d(base_channels, 1, 4, 2, 1),
        )

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Noise is drawn on every call, regardless of train/eval mode.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def select_latent_dims(self, z, selected_dims):
        """Zero out every latent dimension except the selected ones.

        Args:
            z: Full latent tensor of shape ``[batch_size, latent_dim]``.
            selected_dims: Indices of the latent dimensions to keep.

        Returns:
            A tensor shaped like ``z`` in which the unselected columns are 0.
        """
        mask = torch.zeros_like(z)
        mask[:, selected_dims] = 1
        return z * mask

    def decode_from_selected_dims(self, z_selected):
        """Decode a (possibly masked) latent tensor into a reconstruction.

        Args:
            z_selected: Latent tensor of shape ``[batch_size, latent_dim]``.

        Returns:
            The decoder output for ``z_selected``.
        """
        x = self.fc_z(z_selected)
        # Use the length fixed at construction time instead of re-deriving it
        # from the element count, which divided by the batch size and so
        # failed on an empty batch.
        x = x.view(z_selected.size(0), self.base_channels * 4, self.reduced_len)
        return self.decoder(x)

    def forward(self, x, selected_dims=None):
        """Encode ``x``, sample a latent vector and decode it.

        Args:
            x: Input batch with a single channel and ``input_length`` samples.
            selected_dims: Optional indices of latent dimensions to keep; all
                other dimensions are zeroed before decoding.

        Returns:
            Tuple ``(x_recon, mu, logvar, z)`` where ``z`` is the latent
            tensor actually decoded (masked when ``selected_dims`` is given).
        """
        features = self.encoder(x).flatten(1)

        mu = self.fc_mu(features)
        logvar = self.fc_logvar(features)
        z = self.reparameterize(mu, logvar)

        if selected_dims is not None:
            z = self.select_latent_dims(z, selected_dims)

        x_recon = self.decode_from_selected_dims(z)
        return x_recon, mu, logvar, z

    
    
def analyze_latent_importance(model, dataloader, device, num_batches=10):
    """Score each latent dimension by how well it reconstructs the input alone.

    For each of the first ``num_batches`` batches and each latent dimension,
    the model is run with only that dimension kept (``selected_dims=[dim]``)
    and the mean-squared error against the input is accumulated. Scores are
    averaged over the batches actually processed. A lower score means the
    dimension is treated as more important.

    Args:
        model: Trained VAE exposing ``latent_dim`` and callable as
            ``model(x, selected_dims=...)``, returning the reconstruction as
            its first output.
        dataloader: Iterable yielding 3-tuples whose first element is the
            input batch.
        device: Device the inputs and the score accumulator are placed on.
        num_batches: Maximum number of batches to analyse.

    Returns:
        Tuple ``(sorted_indices, importance_scores)`` of NumPy arrays:
        dimension indices ordered from lowest to highest loss, and the mean
        loss of every dimension.

    Raises:
        ValueError: If no batch was processed (empty ``dataloader`` or
            ``num_batches <= 0``); the average would otherwise be NaN.
    """
    # The progress bar is optional: import it locally so the function works
    # on its own, and fall back to a plain loop when tqdm is not installed.
    try:
        from tqdm import tqdm as progress_bar
    except ImportError:
        progress_bar = None

    model.eval()
    importance_scores = torch.zeros(model.latent_dim, device=device)
    batch_count = 0

    print("开始分析潜变量维度重要性...")

    with torch.no_grad():
        for batch_idx, (x_batch, _, _) in enumerate(dataloader):
            if batch_idx >= num_batches:  # only the first few batches, to save time
                break

            x_batch = x_batch.to(device)

            dims = range(model.latent_dim)
            if progress_bar is not None:
                dims = progress_bar(dims, desc=f"分析批次 {batch_idx+1}/{num_batches}")

            # Reconstruction loss when each dimension is used on its own.
            for dim in dims:
                recon_dim = model(x_batch, selected_dims=[dim])[0]
                loss = F.mse_loss(recon_dim, x_batch, reduction='mean')
                importance_scores[dim] += loss.item()

            batch_count += 1

    if batch_count == 0:
        raise ValueError(
            "analyze_latent_importance processed no batches; "
            "check that the dataloader is non-empty and num_batches > 0"
        )

    # Mean loss per dimension over the processed batches.
    importance_scores /= batch_count

    # Ascending loss: the most important dimensions come first.
    sorted_indices = torch.argsort(importance_scores)

    print("\n潜变量维度重要性分析完成!")
    print(f"最重要维度: {sorted_indices[:5].cpu().numpy()}")
    print(f"最不重要维度: {sorted_indices[-5:].cpu().numpy()}")

    return sorted_indices.cpu().numpy(), importance_scores.cpu().numpy()

def select_and_visualize_dims(importance_scores, num_selected=20):
    """Pick the most important latent dimensions and plot their scores.

    Importance is measured by reconstruction loss, so the dimensions with the
    lowest scores are selected. Two panels are drawn: a bar chart of every
    dimension's score with the selected ones highlighted, and the scores in
    ascending order with a vertical line at ``num_selected``.

    Args:
        importance_scores: One score per latent dimension. Any 1-D array-like
            that ``np.asarray`` accepts (NumPy array, list, ...).
        num_selected: Number of dimensions to select.

    Returns:
        List of the selected dimension indices, ordered from lowest to
        highest loss.
    """
    # Coerce once so list inputs also support the fancy indexing below.
    scores = np.asarray(importance_scores)

    # Ascending loss: the most important dimensions come first.
    sorted_indices = np.argsort(scores)
    selected_dims = sorted_indices[:num_selected].tolist()

    print(f"选择了最重要的 {num_selected} 个维度:")
    print(f"维度索引: {selected_dims}")

    plt.figure(figsize=(12, 6))

    # Left panel: score of every dimension, selected ones marked in red.
    plt.subplot(1, 2, 1)
    plt.bar(range(len(scores)), scores)
    plt.xlabel('latent_dim')
    plt.ylabel('reconstruct loss')
    plt.title('Importance of latent_dim')
    plt.scatter(selected_dims, scores[selected_dims],
                color='red', s=50, zorder=5, label='selected dim')
    plt.legend()

    # Right panel: scores in ascending order, reusing the argsort result.
    plt.subplot(1, 2, 2)
    sorted_scores = scores[sorted_indices]
    plt.plot(range(len(sorted_scores)), sorted_scores, 'b-', linewidth=2)
    plt.xlabel('order of dim')
    plt.ylabel('reconstruct loss')
    plt.title('order of importance of latent_dim')

    # Mark the cut-off between selected and unselected dimensions.
    plt.axvline(x=num_selected, color='red', linestyle='--',
                label=f'select{num_selected}dim')
    plt.legend()

    plt.tight_layout()
    plt.show()

    return selected_dims

def compare_reconstructions(model, dataloader, device, selected_dims, num_samples=3):
    """Compare full-latent and selected-dimension reconstructions on one batch.

    Only the first batch of ``dataloader`` is used. The batch is reconstructed
    once with every latent dimension and once with only ``selected_dims``;
    both MSE losses are printed together with a quality-retention percentage,
    and up to ``num_samples`` samples are plotted (original, full
    reconstruction, selected reconstruction) after multiplying by ``scales``.

    The retention figure is ``loss_full / loss_selected * 100``: 100% means
    the selected dimensions reconstruct as well as the full latent vector,
    lower values mean quality was lost. It can exceed 100% when the selected
    reconstruction happens to have the lower loss.

    Args:
        model: Trained VAE callable as ``model(x, selected_dims=...)``,
            returning the reconstruction as its first output.
        dataloader: Iterable yielding ``(x_batch, names, scales)`` batches.
        device: Device the batch is moved to.
        selected_dims: Indices of the latent dimensions to keep.
        num_samples: Maximum number of samples to plot.

    Returns:
        None. Does nothing when ``dataloader`` yields no batch.
    """
    model.eval()

    first_batch = next(iter(dataloader), None)
    if first_batch is None:
        return
    x_batch, names, scales = first_batch

    with torch.no_grad():
        x_batch = x_batch.to(device)
        scales = scales.to(device)

        # Reconstruct with every latent dimension, then with the subset.
        recon_full = model(x_batch)[0]
        recon_selected = model(x_batch, selected_dims=selected_dims)[0]

        # Undo the normalisation for plotting.
        # NOTE(review): relies on `scales` broadcasting per sample against
        # the batch - confirm its shape against the dataset.
        x_original = x_batch * scales
        recon_full_denorm = recon_full * scales
        recon_selected_denorm = recon_selected * scales

        # Losses are measured in normalised space.
        loss_full = F.mse_loss(recon_full, x_batch).item()
        loss_selected = F.mse_loss(recon_selected, x_batch).item()

    # Guard the division: a zero selected loss means nothing was lost.
    if loss_selected > 0:
        quality_kept = loss_full / loss_selected * 100
    else:
        quality_kept = 100.0

    print(f"完整重建损失: {loss_full:.4f}")
    print(f"使用{len(selected_dims)}个维度的重建损失: {loss_selected:.4f}")
    print(f"重建质量保持: {quality_kept:.2f}%")

    # Report the model's real latent size in the legend (32 if unknown).
    total_dims = getattr(model, 'latent_dim', 32)

    # The summary box is identical for every sample, so build it once.
    textstr = '\n'.join((
        f'full reconstructions loss: {loss_full:.4f}',
        f'select{len(selected_dims)}dim to reconstruct: {loss_selected:.4f}',
        f'Maintaining the quality of reconstruction: {quality_kept:.2f}%'
    ))
    props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)

    # One figure per sample, all three curves on the same axes.
    for i in range(min(num_samples, x_batch.size(0))):
        plt.figure(figsize=(12, 6))

        original_data = x_original[i].squeeze().cpu().numpy()
        full_recon_data = recon_full_denorm[i].squeeze().cpu().numpy()
        selected_recon_data = recon_selected_denorm[i].squeeze().cpu().numpy()

        plt.plot(original_data, 'k-', linewidth=2, label='original data')
        plt.plot(full_recon_data, 'b-', linewidth=1.5,
                 label=f'full reconstruct ({total_dims})', alpha=0.8)
        plt.plot(selected_recon_data, 'r--', linewidth=1.5,
                 label=f'select {len(selected_dims)} dim', alpha=0.8)

        plt.title(f'compare_reconstructions - {names[i]}')
        plt.xlabel('fre')
        plt.ylabel('flux')
        plt.legend()
        plt.grid(True, alpha=0.3)

        plt.text(0.02, 0.98, textstr, transform=plt.gca().transAxes, fontsize=10,
                 verticalalignment='top', bbox=props)

        plt.tight_layout()
        plt.show()
