In [184]:
import torch
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from torch import nn, optim
from torch.nn import functional as F
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import numpy as np
from tqdm import tqdm

In [185]:
# 超参数定义
num_epochs = 10
batch_size = 64
learning_rate = 8e-3
thresholds_percent = 0.8
KLD_weight = 0.4
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [186]:
# 数据增强
transform = transforms.Compose([
    transforms.RandomRotation(25),
    transforms.ToTensor(),
])
transform2 = transforms.Compose([
    transforms.ToTensor(),
])

# 加载全部数据
mnist_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
mnist_data2 = datasets.MNIST(root='./data', train=False, download=True, transform=transform2)
# 合并数据集
mnist_data = torch.utils.data.ConcatDataset([mnist_data, mnist_data2])
# 获取数据集的大小
total_size = len(mnist_data)
# 计算训练数据集的大小（80%）
train_size = int(total_size * 0.8)
# 计算测试数据集的大小（20%）
test_size = total_size - train_size
# 分割数据集
train_data, test_data = torch.utils.data.random_split(mnist_data, [train_size, test_size])

# 对于每个数字，创建一个数据加载器
digit_train_loaders = [DataLoader(Subset(train_data, 
                       [idx for idx, (_, target) in enumerate(train_data) if target == i]),
                       batch_size=batch_size, shuffle=True) for i in range(10)]

# 对于每个数字，创建一个测试数据加载器
digit_test_loaders = [DataLoader(Subset(test_data, 
                       [idx for idx, (_, target) in enumerate(test_data) if target == i]),
                       batch_size=batch_size, shuffle=True) for i in range(10)]

In [187]:
class VAE_CNN(nn.Module):
    def __init__(self):
        super(VAE_CNN, self).__init__()

        # Encoder
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.fc_mu = nn.Linear(64 * 7 * 7, 20) 
        self.fc_logvar = nn.Linear(64 * 7 * 7, 20) 

        # Decoder
        self.fc_decode = nn.Linear(20, 64 * 7 * 7)
        self.deconv1 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.deconv2 = nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.deconv3 = nn.ConvTranspose2d(16, 1, kernel_size=3, stride=1, padding=1)

    def encode(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = x.view(x.shape[0], -1)
        return self.fc_mu(x), self.fc_logvar(x)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        z = F.relu(self.fc_decode(z))
        z = z.view(z.shape[0], 64, 7, 7)
        z = F.relu(self.deconv1(z))
        z = F.relu(self.deconv2(z))
        return torch.sigmoid(self.deconv3(z))

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

In [188]:
# 用于存储每个数字对应的模型
vaes = [VAE_CNN().to(device) for _ in range(10)]
optimizers = [optim.Adamax(vae.parameters(), lr=learning_rate) for vae in vaes]


In [189]:
def loss_function(recon_x, x, mu, logvar): # 重构损失 + KL散度
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 1, 28, 28), reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return (1-KLD_weight)*BCE + KLD_weight*KLD


In [190]:
def train():
    for epoch in range(num_epochs):
        for i in range(10):
            for data in digit_train_loaders[i]:
                img, _ = data
                img = img.to(device)
                recon_batch, mu, logvar = vaes[i](img)
                loss = loss_function(recon_batch, img, mu, logvar)
                print('number:{}:epoch [{}/{}], loss:{:.4f}'.format(i,epoch + 1, num_epochs, loss.item()))
                optimizers[i].zero_grad()
                loss.backward()
                optimizers[i].step()

def compute_reconstruction_error(data, model): # 计算重构误差
    with torch.no_grad():
        recon_data, _, _ = model(data)
        recon_error = ((recon_data - data)**2).sum(dim=(1,2,3))
        return recon_error.cpu().numpy()


In [191]:
def test():
    # 加载FashionMNIST测试数据作为OOD样本
    fashion_mnist_data = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)
    fashion_mnist_loader = DataLoader(fashion_mnist_data, batch_size=batch_size, shuffle=False)

    # 对于每个测试样本，计算其对于每个模型的重构误差
    test_errors = []
    # 计算FashionMNIST样本在每个模型上的重构误差
    fashion_errors = []
    with torch.no_grad():
        for data, _ in fashion_mnist_loader:
            img = data.to(device)
            errors = np.stack([compute_reconstruction_error(img, vae) for vae in vaes], axis=-1)
            fashion_errors.extend(errors)
    fashion_errors = np.array(fashion_errors)

    # 计算阈值，设置为训练集上的thresholds_percent重构误差百分位数
    thresholds = [np.percentile([compute_reconstruction_error(data, vaes[i]).max()
                 for data, _ in digit_train_loaders[i]], thresholds_percent) for i in range(10)]
    # 对FashionMNIST样本进行分类
    fashion_pred = np.argmin(fashion_errors, axis=1)
    # 如果重构误差大于阈值，则判断为OOD
    fashion_pred = [pred if error[pred] <= thresholds[pred] else 10 for pred, error in zip(fashion_pred, fashion_errors)]
    # 计算OOD检测的准确率，精确度，召回率和F1分数
    fashion_true = [10] * len(fashion_pred)  # 10表示OOD
    accuracy = accuracy_score(fashion_true, fashion_pred)
    precision = precision_score(fashion_true, fashion_pred, average='weighted')
    recall = recall_score(fashion_true, fashion_pred, average='weighted')
    f1 = f1_score(fashion_true, fashion_pred, average='weighted')
    print(f"OOD Detection: Accuracy: {accuracy:.2f}, Precision: {precision:.2f}, Recall: {recall:.2f}, F1-score: {f1:.2f}")

    y_true = []
    test_errors = []
    with torch.no_grad():
        for i in range(10):
            for data, targets in digit_test_loaders[i]:
                img = data.to(device)
                errors = np.stack([compute_reconstruction_error(img, vae) for vae in vaes], axis=-1)
                test_errors.extend(errors)
                y_true.extend(targets.numpy())
    test_errors = np.array(test_errors)

    # 对于每个测试样本，将其分类为重构误差最小的类别
    y_pred = np.argmin(test_errors, axis=1)

    # 计算accuracy，precision，recall和F1-score
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, average='weighted')
    recall = recall_score(y_true, y_pred, average='weighted')
    f1 = f1_score(y_true, y_pred, average='weighted')

    print(f"In MNIST:Accuracy: {accuracy:.2f}, Precision: {precision:.2f}, Recall: {recall:.2f}, F1-score: {f1:.2f}")

In [192]:
train()

number:0:epoch [1/10], loss:21616.8809
number:0:epoch [1/10], loss:17094.8398
number:0:epoch [1/10], loss:45504.1523
number:0:epoch [1/10], loss:16052.3398
number:0:epoch [1/10], loss:16168.5703
number:0:epoch [1/10], loss:17261.8477
number:0:epoch [1/10], loss:17868.9355
number:0:epoch [1/10], loss:17990.3438
number:0:epoch [1/10], loss:18083.0215
number:0:epoch [1/10], loss:17643.5664
number:0:epoch [1/10], loss:17596.5020
number:0:epoch [1/10], loss:17082.6191
number:0:epoch [1/10], loss:16619.9590
number:0:epoch [1/10], loss:16057.3506
number:0:epoch [1/10], loss:15616.2285
number:0:epoch [1/10], loss:14888.9131
number:0:epoch [1/10], loss:14842.8906
number:0:epoch [1/10], loss:14300.3477
number:0:epoch [1/10], loss:13802.8613
number:0:epoch [1/10], loss:13725.4512
number:0:epoch [1/10], loss:13196.3369
number:0:epoch [1/10], loss:14290.7588
number:0:epoch [1/10], loss:13371.7227
number:0:epoch [1/10], loss:13375.4648
number:0:epoch [1/10], loss:13874.8691
number:0:epoch [1/10], lo

In [193]:
test()

  _warn_prf(average, modifier, msg_start, len(result))


OOD Detection: Accuracy: 0.95, Precision: 1.00, Recall: 0.95, F1-score: 0.97
In MNIST:Accuracy: 0.90, Precision: 0.91, Recall: 0.90, F1-score: 0.90


In [194]:
model_dict = {f'VAE_digit_{i}': vae.state_dict() for i, vae in enumerate(vaes)}
torch.save(model_dict, 'VAEs.pth')