In [67]:
import torch
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from torch import nn, optim
from torch.nn import functional as F
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import numpy as np
from tqdm import tqdm

In [68]:
# 超参数定义
num_epochs = 30
batch_size = 64
learning_rate = 8e-3
thresholds_percent = 0.8
KLD_weight = 0.6
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [69]:
# 数据增强
transform = transforms.Compose([
    transforms.RandomRotation(25),
    transforms.ToTensor(),
])
transform2 = transforms.Compose([
    transforms.ToTensor(),
])

# 加载全部数据
mnist_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
mnist_data2 = datasets.MNIST(root='./data', train=False, download=True, transform=transform2)
# 合并数据集
mnist_data = torch.utils.data.ConcatDataset([mnist_data, mnist_data2])
# 获取数据集的大小
total_size = len(mnist_data)
# 计算训练数据集的大小（80%）
train_size = int(total_size * 0.8)
# 计算测试数据集的大小（20%）
test_size = total_size - train_size
# 分割数据集
train_data, test_data = torch.utils.data.random_split(mnist_data, [train_size, test_size])

# 对于每个数字，创建一个数据加载器
digit_train_loaders = [DataLoader(Subset(train_data, 
                       [idx for idx, (_, target) in enumerate(train_data) if target == i]),
                       batch_size=batch_size, shuffle=True) for i in range(10)]

# 对于每个数字，创建一个测试数据加载器
digit_test_loaders = [DataLoader(Subset(test_data, 
                       [idx for idx, (_, target) in enumerate(test_data) if target == i]),
                       batch_size=batch_size, shuffle=True) for i in range(10)]

In [70]:
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        self.fc1 = nn.Linear(784, 500)  # 增大第一层的输出大小
        self.fc21 = nn.Linear(500, 30)  # 增大第二层的输出大小
        self.fc22 = nn.Linear(500, 30)  # 增大第二层的输出大小
        self.fc3 = nn.Linear(30, 500)   # 增大第三层的输入大小
        self.fc4 = nn.Linear(500, 784)  # 增大第四层的输入大小

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1) 

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std) 
        return mu + eps * std

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

In [71]:
# 用于存储每个数字对应的模型
vaes = [VAE().to(device) for _ in range(10)]
optimizers = [optim.Adamax(vae.parameters(), lr=learning_rate) for vae in vaes]


In [72]:
def loss_function(recon_x, x, mu, logvar):
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return (1-KLD_weight)*BCE + KLD_weight*KLD
# 训练模型
def train():
    for epoch in range(num_epochs):
        for i in range(10):
            for data in digit_train_loaders[i]:
                img, _ = data
                img = img.to(device)
                img = img.view(img.size(0), -1)
                recon_batch, mu, logvar = vaes[i](img)
                loss = loss_function(recon_batch, img, mu, logvar)
                print('number:{}:epoch [{}/{}], loss:{:.4f}'.format(i,epoch + 1, num_epochs, loss.item()))
                optimizers[i].zero_grad()
                loss.backward()
                optimizers[i].step()

# 计算重构误差
def compute_reconstruction_error(data, model):
    with torch.no_grad():
        recon_data, _, _ = model(data)
        recon_error = ((recon_data - data.view(-1, 784))**2).sum(dim=1)
        return recon_error.cpu().numpy()


In [73]:
def test():
    # 加载FashionMNIST测试数据作为OOD样本
    fashion_mnist_data = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)
    fashion_mnist_loader = DataLoader(fashion_mnist_data, batch_size=batch_size, shuffle=False)

    # 对于每个测试样本，计算其对于每个模型的重构误差
    test_errors = []
    # 计算FashionMNIST样本在每个模型上的重构误差
    fashion_errors = []
    with torch.no_grad():
        for data, _ in fashion_mnist_loader:
            img = data.to(device)
            img = img.view(img.size(0), -1)
            errors = np.stack([compute_reconstruction_error(img, vae) for vae in vaes], axis=-1)
            fashion_errors.extend(errors)
    fashion_errors = np.array(fashion_errors)
    
    # 计算阈值，设置为训练集上的thresholds_percent重构误差百分位数
    thresholds = [np.percentile([compute_reconstruction_error(data, vaes[i]).max() 
                 for data, _ in digit_train_loaders[i]], thresholds_percent) for i in range(10)]
    # 对FashionMNIST样本进行分类
    fashion_pred = np.argmin(fashion_errors, axis=1)
    # 如果重构误差大于阈值，则判断为OOD
    fashion_pred = [pred if error[pred] <= thresholds[pred] else 10 for pred, error in zip(fashion_pred, fashion_errors)]
    # 计算OOD检测的准确率，精确度，召回率和F1分数
    fashion_true = [10] * len(fashion_pred) # 10表示OOD
    accuracy = accuracy_score(fashion_true, fashion_pred)
    precision = precision_score(fashion_true, fashion_pred, average='weighted')
    recall = recall_score(fashion_true, fashion_pred, average='weighted')
    f1 = f1_score(fashion_true, fashion_pred, average='weighted')
    print(f"OOD Detection: Accuracy: {accuracy:.2f}, Precision: {precision:.2f}, Recall: {recall:.2f}, F1-score: {f1:.2f}")
    
    y_true = []
    test_errors = []
    with torch.no_grad():
        for i in range(10):
            for data, targets in digit_test_loaders[i]:
                img = data.to(device)
                img = img.view(img.size(0), -1)
                errors = np.stack([compute_reconstruction_error(img, vae) for vae in vaes], axis=-1)
                test_errors.extend(errors)
                y_true.extend(targets.numpy())
    test_errors = np.array(test_errors)

    # 对于每个测试样本，将其分类为重构误差最小的类别
    y_pred = np.argmin(test_errors, axis=1)
    
    # 计算accuracy，precision，recall和F1-score
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, average='weighted')
    recall = recall_score(y_true, y_pred, average='weighted')
    f1 = f1_score(y_true, y_pred, average='weighted')

    print(f"In MNIST:Accuracy: {accuracy:.2f}, Precision: {precision:.2f}, Recall: {recall:.2f}, F1-score: {f1:.2f}")


In [74]:
train()

number:0:epoch [1/30], loss:14029.3740
number:0:epoch [1/30], loss:10665.4492
number:0:epoch [1/30], loss:7428.0913
number:0:epoch [1/30], loss:6539.5601
number:0:epoch [1/30], loss:5979.0430
number:0:epoch [1/30], loss:5992.2544
number:0:epoch [1/30], loss:5974.5449
number:0:epoch [1/30], loss:5732.2817
number:0:epoch [1/30], loss:5729.4268
number:0:epoch [1/30], loss:5792.9897
number:0:epoch [1/30], loss:5537.8232
number:0:epoch [1/30], loss:5396.4307
number:0:epoch [1/30], loss:5596.8931
number:0:epoch [1/30], loss:5554.7681
number:0:epoch [1/30], loss:5477.0215
number:0:epoch [1/30], loss:5749.0054
number:0:epoch [1/30], loss:5723.1636
number:0:epoch [1/30], loss:5485.5854
number:0:epoch [1/30], loss:5606.6401
number:0:epoch [1/30], loss:5550.0078
number:0:epoch [1/30], loss:5366.5488
number:0:epoch [1/30], loss:5547.0986
number:0:epoch [1/30], loss:5378.3643
number:0:epoch [1/30], loss:5625.9302
number:0:epoch [1/30], loss:5378.3657
number:0:epoch [1/30], loss:5274.2681
number:0:e

In [75]:
test()

  _warn_prf(average, modifier, msg_start, len(result))


OOD Detection: Accuracy: 0.88, Precision: 1.00, Recall: 0.88, F1-score: 0.93
In MNIST:Accuracy: 0.95, Precision: 0.95, Recall: 0.95, F1-score: 0.95


In [76]:
model_dict = {f'VAE_digit_{i}': vae.state_dict() for i, vae in enumerate(vaes)} 
torch.save(model_dict, 'VAEs.pth')