# In [None]:
import os
import pickle
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader, Dataset
import matplotlib.pyplot as plt
import torchvision.utils as vutils
from datetime import datetime

from scipy.ndimage import median_filter
from sklearn.metrics import roc_auc_score

# Optimisation settings.
LEARNING_RATE = 0.0005
BATCH_SIZE = 256

# How many images are taken from the clean set and from the anomaly set,
# and which key of the data file supplies the anomalous images.
CLEAN_NUM = 16000
ANOMALY_NUM = 1600
DIRTYTYPE = "plastic"

# Lambda values to sweep (weight of the latent nearest-neighbour term in the
# training loss). Other candidates tried: [0.01], [0.01, 0.02, 0.05].
LAMBDA_LIST = [0]

num_epochs = 100
# Metrics are recorded on every 5th epoch.
record_epochs = [ep for ep in range(1, num_epochs + 1) if ep % 5 == 0]

# Load the data
# NOTE(review): pickle.load can execute arbitrary code while unpickling; only
# point data_path at a file from a trusted source.
data_path = '16000clean7dirty.pkl'


with open(data_path, 'rb') as f:
    data = pickle.load(f)

# First CLEAN_NUM entries stored under the 'clean' key.
normal_images = data['clean'][:CLEAN_NUM]
# First ANOMALY_NUM entries stored under the DIRTYTYPE key (anomalous images).
anomaly_images = data[DIRTYTYPE][:ANOMALY_NUM]

# Normalisation
def normalize_image(img):
    """Min-max scale a single image and return it as float32.

    The image's own minimum is subtracted and the result is divided by the
    image's value range plus 1e-8, so a constant image maps to all zeros
    instead of dividing by zero.

    Args:
        img: NumPy array holding one image.

    Returns:
        Array of the same shape as ``img`` with dtype ``np.float32``.
    """
    lowest = img.min()
    value_range = img.max() - lowest
    scaled = (img - lowest) / (value_range + 1e-8)
    return scaled.astype(np.float32)

# Scale every image independently with normalize_image.
normal_images_norm = np.array([normalize_image(img) for img in normal_images])
anomaly_images_norm = np.array([normalize_image(img) for img in anomaly_images])

# Merge the data: clean images first, anomalous images after them.
images = np.concatenate((normal_images_norm, anomaly_images_norm), axis=0)
# Label 0 = clean, 1 = anomaly. The counts are taken from the arrays that were
# actually loaded rather than from CLEAN_NUM / ANOMALY_NUM, so labels stay
# aligned with the images even when the data file holds fewer samples than
# requested (the slices above would then silently return fewer items).
labels = np.array([0] * len(normal_images_norm) + [1] * len(anomaly_images_norm))

# Reorder the axes from (N, H, W, C) to (N, C, H, W).
images = images.transpose((0, 3, 1, 2))
print('数据形状：', images.shape)

images_tensor = torch.from_numpy(images)
labels_tensor = torch.from_numpy(labels)

# Training dataset and loader (shuffled every epoch).
train_dataset = TensorDataset(images_tensor, labels_tensor)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Evaluation dataset (iterated without shuffling); every item also carries its index.
class IndexedTensorDataset(Dataset):
    """Dataset over paired image/label tensors that also yields the sample index.

    ``dataset[idx]`` returns ``(image, label, idx)`` so that per-sample results
    can be matched back to their position in the dataset.
    """

    def __init__(self, images_tensor, labels_tensor):
        self.images_tensor, self.labels_tensor = images_tensor, labels_tensor

    def __len__(self):
        # The label tensor defines the number of samples.
        return len(self.labels_tensor)

    def __getitem__(self, idx):
        return self.images_tensor[idx], self.labels_tensor[idx], idx

# Evaluation pass over the same tensors in fixed (unshuffled) order; each
# batch additionally carries the dataset indices of its samples.
eval_dataset = IndexedTensorDataset(
    images_tensor=images_tensor,
    labels_tensor=labels_tensor,
)
eval_loader = DataLoader(dataset=eval_dataset, batch_size=BATCH_SIZE, shuffle=False)

# CAE model
class CAE(nn.Module):
    """Convolutional autoencoder with a 256-dimensional linear bottleneck.

    Encoder: three stride-2 3x3 convolutions (3 -> 16 -> 32 -> 64 channels,
    each followed by batch norm and ReLU), then two linear layers
    (64*28*28 -> 4096 -> 256).  Decoder: two linear layers back up to
    64*28*28, then three stride-2 transposed convolutions ending in a sigmoid.

    The linear layers are sized for a 64 x 28 x 28 encoder output, which is
    what the three stride-2 convolutions produce for e.g. 224 x 224 inputs.
    """

    def __init__(self):
        super().__init__()

        # Convolutional feature extractor; every stage halves the resolution.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),

            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),

            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        # Fully connected part of the encoder. The Flatten layers do nothing
        # on the already flat inputs they receive in forward(); they are kept
        # so the sub-module indices (and state_dict keys) stay the same.
        self.encoder_fc1 = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 28 * 28, 4096),
            nn.ReLU(inplace=True),
        )
        self.encoder_fc2 = nn.Sequential(
            nn.Flatten(),
            nn.Linear(4096, 256),
            nn.ReLU(inplace=True),
        )

        # Fully connected part of the decoder, mirroring the encoder.
        self.decoder_fc1 = nn.Sequential(
            nn.Linear(256, 4096),
            nn.ReLU(inplace=True),
        )
        self.decoder_fc2 = nn.Sequential(
            nn.Linear(4096, 64 * 28 * 28),
            nn.ReLU(inplace=True),
        )
        # Transposed convolutions double the resolution at every stage; the
        # final sigmoid keeps the reconstruction in [0, 1].
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(16, 3, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Encode ``x`` and reconstruct it.

        Args:
            x: Image batch whose encoder output is 64 x 28 x 28 per sample.

        Returns:
            Tuple ``(x_recon, latent2)``: the reconstruction produced by the
            decoder and the 256-dimensional bottleneck activation.
        """
        n_samples = x.size(0)

        # Encoder: conv features -> flat vector -> 4096 -> 256.
        feature_map = self.encoder(x)
        flat_features = feature_map.contiguous().view(n_samples, -1)
        latent2 = self.encoder_fc2(self.encoder_fc1(flat_features))

        # Decoder: 256 -> 4096 -> flat vector -> conv feature map -> image.
        expanded = self.decoder_fc2(self.decoder_fc1(latent2))
        x_recon = self.decoder(expanded.view(n_samples, 64, 28, 28))
        return x_recon, latent2

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Sweep over the configured lambda values. Every value gets a freshly
# initialised model, its own optimizer and its own results directory.
for LAMBDA in LAMBDA_LIST:
    print(f"\n***** 开始训练，lambda = {LAMBDA} *****")
    # Results directory, keyed by anomaly type, lambda, batch size and learning
    # rate. The learning rate comes from LEARNING_RATE (instead of a literal)
    # so the directory name cannot drift from the value actually used.
    results_dir = f"results_mse16000_{DIRTYTYPE}/{LAMBDA}_{BATCH_SIZE}_{LEARNING_RATE}"
    os.makedirs(results_dir, exist_ok=True)

    model = CAE().to(device)
    # reduction='none' keeps per-element errors so that a per-sample MSE can
    # be derived from them.
    criterion = nn.MSELoss(reduction='none')
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # Metric containers: loss_records holds one average loss per epoch; the
    # dictionaries are keyed by the recorded epoch number.
    loss_records = []
    mse_loss_records = {}
    rc_records = {}
    smoothed_rc_records = {}
    auc_records = {}

    # Fixed random reference subset of the dataset, drawn once per lambda.
    # The size is capped at the dataset size because sampling without
    # replacement raises when more samples are requested than exist.
    all_indices = np.arange(len(images_tensor))
    reference_size = min(800, len(all_indices))
    reference_indices = np.random.choice(all_indices, size=reference_size, replace=False)
    reference_images = images_tensor[reference_indices].to(device)

    for epoch in range(1, num_epochs + 1):
        model.train()
        total_loss = 0
        # Labels are not used for training; the loop variables are named so
        # that they do not overwrite the module-level `data` / `labels`.
        for img, _ in train_loader:
            img = img.to(device)

            output, latent = model(img)
            loss_recon = criterion(output, img)

            # Per-sample reconstruction error.
            mse_loss = loss_recon.mean(dim=[1, 2, 3])

            # NOTE(review): the reference batch is forwarded in train mode on
            # every step, so gradients also flow through the reference latents
            # and the BatchNorm layers see this batch as well. It is evaluated
            # even when LAMBDA == 0 to keep results unchanged - confirm that
            # this is intended.
            _, reference_latent = model(reference_images)
            distances = torch.cdist(latent, reference_latent)

            # Nearest reference latent for every sample in the batch.
            min_indices = torch.argmin(distances, dim=1)
            nearest_latent = reference_latent[min_indices]

            l2_loss = (latent - nearest_latent).pow(2).mean(dim=1)

            total_loss_batch = mse_loss + LAMBDA * l2_loss

            loss_mean = total_loss_batch.mean()

            optimizer.zero_grad()
            loss_mean.backward()
            optimizer.step()

            total_loss += loss_mean.item()

        avg_loss = total_loss / len(train_loader)
        loss_records.append(avg_loss)
        print(f"Lambda {LAMBDA}, Epoch [{epoch}/{num_epochs}], Loss: {avg_loss:.6f}")

        # Record MSE, RC, smoothed RC and AUC on the configured epochs.
        if epoch in record_epochs:
            model.eval()
            mse_loss_values = []
            rc_values = []
            indices_list = []
            labels_list = []
            latent_values = []
            anomaly_indices = []
            with torch.no_grad():
                # Labels and indices are only used for bookkeeping, so they
                # stay on the CPU; only the images go to the device.
                for img, batch_labels, indices in eval_loader:
                    img = img.to(device)

                    # Forward pass
                    output, latent = model(img)
                    latent_values.append(latent.cpu().numpy())
                    anomaly_indices.extend(indices[batch_labels == 1].numpy())
                    loss = criterion(output, img)

                    # Per-sample MSE
                    mse_loss = loss.mean(dim=[1, 2, 3])
                    mse_loss_values.extend(mse_loss.cpu().numpy())

                    # RC: correlation coefficient between each flattened,
                    # mean-centred input and its reconstruction.
                    img_flat = img.view(img.size(0), -1)
                    output_flat = output.view(output.size(0), -1)

                    img_mean = img_flat.mean(dim=1, keepdim=True)
                    output_mean = output_flat.mean(dim=1, keepdim=True)

                    img_centered = img_flat - img_mean
                    output_centered = output_flat - output_mean

                    numerator = torch.sum(img_centered * output_centered, dim=1)

                    img_norm = torch.norm(img_centered, p=2, dim=1)
                    output_norm = torch.norm(output_centered, p=2, dim=1)

                    # The small constant guards against division by zero for
                    # constant images or constant reconstructions.
                    denominator = img_norm * output_norm + 1e-8
                    rc = numerator / denominator

                    rc_values.extend(rc.cpu().numpy())
                    indices_list.extend(indices.numpy())
                    labels_list.extend(batch_labels.numpy())

            # Latents and anomaly indices are collected for ad-hoc inspection;
            # they are not written to disk unless the lines below are enabled.
            latent_values = np.concatenate(latent_values, axis=0)
            anomaly_indices = np.array(anomaly_indices)

            # np.save(f"{epoch}.npy", latent_values)
            # np.save(f"index{epoch}.npy", anomaly_indices)

            # Smooth the RC values with a median filter.
            # NOTE(review): the filter runs along dataset order (the eval
            # loader is not shuffled), so each window mixes neighbouring
            # dataset entries - confirm that this ordering is meaningful.
            rc_values_array = np.array(rc_values)
            smoothed_rc_values = median_filter(rc_values_array, size=1001, mode='reflect')

            # AUC of the smoothed RC against the labels.
            # NOTE(review): labels use 1 for anomalies while RC is a similarity
            # score; roc_auc_score treats higher scores as "more positive" -
            # confirm the intended orientation.
            auc_value = roc_auc_score(labels_list, smoothed_rc_values)
            print(f"Lambda {LAMBDA}, Epoch [{epoch}] AUC: {auc_value:.6f}")

            # Store the records for this epoch.
            mse_loss_records[epoch] = (mse_loss_values, indices_list)
            rc_records[epoch] = (rc_values, indices_list)
            smoothed_rc_records[epoch] = (smoothed_rc_values, indices_list)
            auc_records[epoch] = auc_value

    # Save all records. The dictionaries are stored by np.save as pickled
    # object arrays, so reading them back needs np.load(..., allow_pickle=True).
    np.save(os.path.join(results_dir, 'loss_records.npy'), np.array(loss_records))
    np.save(os.path.join(results_dir, 'mse_loss_records.npy'), mse_loss_records)
    np.save(os.path.join(results_dir, 'rc_records.npy'), rc_records)
    np.save(os.path.join(results_dir, 'smoothed_rc_records.npy'), smoothed_rc_records)
    np.save(os.path.join(results_dir, 'auc_records.npy'), auc_records)
    print(f"Lambda {LAMBDA} 的所有结果已保存到 {results_dir}")

    # Also write the AUC values as a human-readable text file.
    auc_file_path = os.path.join(results_dir, 'auc_values.txt')
    with open(auc_file_path, 'w') as f:
        for recorded_epoch in sorted(auc_records):
            f.write(f"Epoch {recorded_epoch}: AUC = {auc_records[recorded_epoch]:.6f}\n")
    print(f"Lambda {LAMBDA} 的AUC值已保存到 {auc_file_path}")
