In [1]:
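import os
import warnings

import matplotlib.pyplot as plt
import torch

# Project modules; the star imports also provide load_MNIST, Discriminator and the training/visualization helpers
from utils.dataset_MINST import *
from utils.OpenGAN_arc import *
from utils.visualization import *
from utils.extr_fea import *
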
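# SimHei is used for Chinese text in plots; DejaVu Sans (bundled with Matplotlib) is the fallback
# when SimHei is not installed, which avoids the flood of "findfont" warnings
plt.rcParams["font.sans-serif"] = ["SimHei", "DejaVu Sans"]
plt.rcParams["axes.unicode_minus"] = False
warnings.filterwarnings('ignore', category=RuntimeWarning)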

In [2]:
config = {
    "dataset": "MNIST",  # only the MNIST dataset is used
    "closed_classes": [0, 1, 2, 3, 4],  # closed-set classes (K=5)
    "open_classes": [5, 6, 7, 8, 9],  # open-set classes
    "feature_dim": 512,  # ResNet18 feature dimension
    "batch_size": 128,
    "lr_classifier": 0.01,  # classifier learning rate
    "lr_gan": 1e-4,  # GAN learning rate
    "classifier_epochs": 10,  # classifier training epochs
    "gan_epochs": 50,  # GAN training epochs
    "lambda_g": 0.2,  # weight of the generated samples
    "val_ratio": 0.1,  # validation split ratio
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "save_dir": "./mnist_results"  # directory for saved results
}
# Create the output directory
os.makedirs(config["save_dir"], exist_ok=True)
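The `closed_classes`, `open_classes` and `val_ratio` entries define how MNIST is partitioned, but the split itself happens inside `load_MNIST`, whose code is not shown here. The following is a minimal NumPy sketch of one way to do it, using a hypothetical `split_closed_open` helper (not the `utils.dataset_MINST` implementation). The log reports 27537 + 3059 = 30596 closed-class training images, and truncating with `int(30596 * 0.1)` gives exactly the 3059 validation samples shown there; drawing an equally sized open-set subset matches the "3059 + 3059" validation counts.

```python
import numpy as np


def split_closed_open(labels, closed_classes, val_ratio, seed=0):
    """Hypothetical helper: return index arrays (train_closed, val_closed, val_open).

    The closed-class indices are shuffled and split by val_ratio; an open-set
    validation subset of the same size is drawn from the remaining classes.
    """
    labels = np.asarray(labels)
    rng = np.random.default_rng(seed)
    is_closed = np.isin(labels, closed_classes)
    # Shuffle closed-class indices before splitting off the validation part
    closed_idx = rng.permutation(np.flatnonzero(is_closed))
    open_idx = np.flatnonzero(~is_closed)
    # int() truncates, e.g. int(30596 * 0.1) == 3059
    n_val = int(len(closed_idx) * val_ratio)
    val_closed, train_closed = closed_idx[:n_val], closed_idx[n_val:]
    # Draw as many open-set samples as there are closed validation samples
    val_open = rng.choice(open_idx, size=n_val, replace=False)
    return train_closed, val_closed, val_open


# Toy labels: 100 samples of each digit 0-9
labels = np.repeat(np.arange(10), 100)
tr, va_closed, va_open = split_closed_open(labels, [0, 1, 2, 3, 4], 0.1)
print(len(tr), len(va_closed), len(va_open))  # 450 50 50
```

Only the open-set *validation* subset is drawn here; the open classes never enter the training indices, which is the property the closed-set classifier and OpenGAN discriminator rely on.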

In [3]:
if __name__ == "__main__":
    # Load the data
    train_loader, val_loader, test_loader = load_MNIST(config)

    # Train the closed-set classifier
    classifier = train_classifierI(train_loader, val_loader, config)

    # Train OpenGAN
    discriminator = train_opengan(classifier, train_loader, val_loader, config)

    # Load the best saved discriminator; map_location lets a GPU-trained checkpoint load on CPU
    best_discriminator = Discriminator().to(config["device"])
    best_discriminator.load_state_dict(
        torch.load(os.path.join(config["save_dir"], "best_discriminator.pth"),
                   map_location=config["device"]))

    # Visualize the results
    visualize_results(classifier, best_discriminator, test_loader["closed"], test_loader["open"], config)

    # Final evaluation on the test set
    test_auroc = evaluate_opengan(best_discriminator, classifier, test_loader["closed"], test_loader["open"], config)
    print(f"\nFinal test-set AUROC: {test_auroc:.4f}")

Loading MNIST dataset...
Training set: 27537 closed-set samples
Validation set: 3059 closed-set + 3059 open-set samples
Test set: 5139 closed-set + 4861 open-set samples

Training closed-set classifier...


Classifier Epoch 1: 100%|██████████| 216/216 [00:09<00:00, 23.70it/s]


Epoch 1: train loss=0.2293, val accuracy=0.7457


Classifier Epoch 2: 100%|██████████| 216/216 [00:08<00:00, 24.21it/s]


Epoch 2: train loss=0.0378, val accuracy=0.7797


Classifier Epoch 3: 100%|██████████| 216/216 [00:08<00:00, 24.02it/s]


Epoch 3: train loss=0.0258, val accuracy=0.9984


Classifier Epoch 4: 100%|██████████| 216/216 [00:08<00:00, 24.24it/s]


Epoch 4: train loss=0.0158, val accuracy=0.9882


Classifier Epoch 5: 100%|██████████| 216/216 [00:08<00:00, 24.23it/s]


Epoch 5: train loss=0.0172, val accuracy=0.9964


Classifier Epoch 6: 100%|██████████| 216/216 [00:08<00:00, 24.55it/s]


Epoch 6: train loss=0.0131, val accuracy=0.9944


Classifier Epoch 7: 100%|██████████| 216/216 [00:08<00:00, 24.01it/s]


Epoch 7: train loss=0.0123, val accuracy=0.9908


Classifier Epoch 8: 100%|██████████| 216/216 [00:08<00:00, 24.05it/s]


Epoch 8: train loss=0.0319, val accuracy=0.9827


Classifier Epoch 9: 100%|██████████| 216/216 [00:08<00:00, 24.27it/s]


Epoch 9: train loss=0.0103, val accuracy=0.9958


Classifier Epoch 10: 100%|██████████| 216/216 [00:08<00:00, 24.32it/s]

Epoch 10: train loss=0.0075, val accuracy=0.9964


Training OpenGAN...


OpenGAN Epoch 1: 100%|██████████| 216/216 [00:06<00:00, 33.77it/s]


Epoch 1: D_loss=0.2655, G_loss=0.8875, Val AUROC=0.2562


OpenGAN Epoch 2: 100%|██████████| 216/216 [00:06<00:00, 34.29it/s]


Epoch 2: D_loss=-0.0783, G_loss=1.7131, Val AUROC=0.4776


OpenGAN Epoch 3: 100%|██████████| 216/216 [00:06<00:00, 35.00it/s]


Epoch 3: D_loss=-0.3289, G_loss=2.4285, Val AUROC=0.5314


OpenGAN Epoch 4: 100%|██████████| 216/216 [00:06<00:00, 35.20it/s]


Epoch 4: D_loss=-0.5455, G_loss=3.1995, Val AUROC=0.7005


OpenGAN Epoch 5: 100%|██████████| 216/216 [00:06<00:00, 35.09it/s]


Epoch 5: D_loss=-0.7506, G_loss=4.0185, Val AUROC=0.6251


OpenGAN Epoch 6: 100%|██████████| 216/216 [00:06<00:00, 34.10it/s]


Epoch 6: D_loss=-0.9371, G_loss=4.8435, Val AUROC=0.7667


OpenGAN Epoch 7: 100%|██████████| 216/216 [00:06<00:00, 34.13it/s]


Epoch 7: D_loss=-1.1255, G_loss=5.7449, Val AUROC=0.8771


OpenGAN Epoch 8: 100%|██████████| 216/216 [00:06<00:00, 34.38it/s]


Epoch 8: D_loss=-1.3249, G_loss=6.7025, Val AUROC=0.9150


OpenGAN Epoch 9: 100%|██████████| 216/216 [00:06<00:00, 34.27it/s]


Epoch 9: D_loss=-1.5341, G_loss=7.7279, Val AUROC=0.8501


OpenGAN Epoch 10: 100%|██████████| 216/216 [00:06<00:00, 34.23it/s]


Epoch 10: D_loss=-1.7434, G_loss=8.7746, Val AUROC=0.8621


OpenGAN Epoch 11: 100%|██████████| 216/216 [00:06<00:00, 34.12it/s]


Epoch 11: D_loss=-1.9251, G_loss=9.7008, Val AUROC=0.8032


OpenGAN Epoch 12: 100%|██████████| 216/216 [00:06<00:00, 34.44it/s]


Epoch 12: D_loss=-2.0824, G_loss=10.4893, Val AUROC=0.7839


OpenGAN Epoch 13: 100%|██████████| 216/216 [00:06<00:00, 34.18it/s]


Epoch 13: D_loss=-2.2314, G_loss=11.2343, Val AUROC=0.7677


OpenGAN Epoch 14: 100%|██████████| 216/216 [00:06<00:00, 34.25it/s]


Epoch 14: D_loss=-2.3327, G_loss=11.7471, Val AUROC=0.7612


OpenGAN Epoch 15: 100%|██████████| 216/216 [00:06<00:00, 34.24it/s]


Epoch 15: D_loss=-2.2570, G_loss=11.4037, Val AUROC=0.7964


OpenGAN Epoch 16: 100%|██████████| 216/216 [00:06<00:00, 34.25it/s]


Epoch 16: D_loss=-2.2894, G_loss=11.5614, Val AUROC=0.8964


OpenGAN Epoch 17: 100%|██████████| 216/216 [00:06<00:00, 34.21it/s]


Epoch 17: D_loss=-2.3503, G_loss=11.8742, Val AUROC=0.9041


OpenGAN Epoch 18: 100%|██████████| 216/216 [00:06<00:00, 34.40it/s]


Epoch 18: D_loss=-2.4060, G_loss=12.1601, Val AUROC=0.8925


OpenGAN Epoch 19: 100%|██████████| 216/216 [00:06<00:00, 33.97it/s]


Epoch 19: D_loss=-2.4642, G_loss=12.4523, Val AUROC=0.8586


OpenGAN Epoch 20: 100%|██████████| 216/216 [00:06<00:00, 34.90it/s]


Epoch 20: D_loss=-2.5109, G_loss=12.7170, Val AUROC=0.8286


OpenGAN Epoch 21: 100%|██████████| 216/216 [00:06<00:00, 35.14it/s]


Epoch 21: D_loss=-2.5327, G_loss=12.8266, Val AUROC=0.8328


OpenGAN Epoch 22: 100%|██████████| 216/216 [00:06<00:00, 34.98it/s]


Epoch 22: D_loss=-2.5706, G_loss=13.0447, Val AUROC=0.8859


OpenGAN Epoch 23: 100%|██████████| 216/216 [00:06<00:00, 33.82it/s]


Epoch 23: D_loss=-2.6153, G_loss=13.2367, Val AUROC=0.8918


OpenGAN Epoch 24: 100%|██████████| 216/216 [00:06<00:00, 33.24it/s]


Epoch 24: D_loss=-2.6036, G_loss=13.2230, Val AUROC=0.8706


OpenGAN Epoch 25: 100%|██████████| 216/216 [00:06<00:00, 33.76it/s]


Epoch 25: D_loss=-2.7195, G_loss=13.7608, Val AUROC=0.8296


OpenGAN Epoch 26: 100%|██████████| 216/216 [00:06<00:00, 34.89it/s]


Epoch 26: D_loss=-2.8556, G_loss=14.3899, Val AUROC=0.9100


OpenGAN Epoch 27: 100%|██████████| 216/216 [00:06<00:00, 34.92it/s]


Epoch 27: D_loss=-2.9233, G_loss=14.7634, Val AUROC=0.7440


OpenGAN Epoch 28: 100%|██████████| 216/216 [00:06<00:00, 34.65it/s]


Epoch 28: D_loss=-2.9977, G_loss=15.1565, Val AUROC=0.8128


OpenGAN Epoch 29: 100%|██████████| 216/216 [00:06<00:00, 32.99it/s]


Epoch 29: D_loss=-3.0618, G_loss=15.4098, Val AUROC=0.7283


OpenGAN Epoch 30: 100%|██████████| 216/216 [00:06<00:00, 33.02it/s]


Epoch 30: D_loss=-3.1155, G_loss=15.6716, Val AUROC=0.6732


OpenGAN Epoch 31: 100%|██████████| 216/216 [00:06<00:00, 34.92it/s]


Epoch 31: D_loss=-3.1756, G_loss=15.9325, Val AUROC=0.5597


OpenGAN Epoch 32: 100%|██████████| 216/216 [00:06<00:00, 34.86it/s]


Epoch 32: D_loss=-3.2214, G_loss=16.1142, Val AUROC=0.5203


OpenGAN Epoch 33: 100%|██████████| 216/216 [00:06<00:00, 34.84it/s]


Epoch 33: D_loss=-3.2236, G_loss=16.1181, Val AUROC=0.5605


OpenGAN Epoch 34: 100%|██████████| 216/216 [00:06<00:00, 33.89it/s]


Epoch 34: D_loss=-3.2236, G_loss=16.1181, Val AUROC=0.5773


OpenGAN Epoch 35: 100%|██████████| 216/216 [00:06<00:00, 33.29it/s]


Epoch 35: D_loss=-3.2230, G_loss=16.1169, Val AUROC=0.3711


OpenGAN Epoch 36: 100%|██████████| 216/216 [00:06<00:00, 33.75it/s]


Epoch 36: D_loss=-3.2232, G_loss=16.1174, Val AUROC=0.5641


OpenGAN Epoch 37: 100%|██████████| 216/216 [00:06<00:00, 34.99it/s]


Epoch 37: D_loss=-3.2236, G_loss=16.1181, Val AUROC=0.5675


OpenGAN Epoch 38: 100%|██████████| 216/216 [00:06<00:00, 34.98it/s]


Epoch 38: D_loss=-3.2236, G_loss=16.1181, Val AUROC=0.5708


OpenGAN Epoch 39: 100%|██████████| 216/216 [00:06<00:00, 34.91it/s]


Epoch 39: D_loss=-3.2236, G_loss=16.1181, Val AUROC=0.5572


OpenGAN Epoch 40: 100%|██████████| 216/216 [00:06<00:00, 33.41it/s]


Epoch 40: D_loss=-3.2230, G_loss=16.1170, Val AUROC=0.5361


OpenGAN Epoch 41: 100%|██████████| 216/216 [00:06<00:00, 33.50it/s]


Epoch 41: D_loss=-3.2236, G_loss=16.1181, Val AUROC=0.4940


OpenGAN Epoch 42: 100%|██████████| 216/216 [00:06<00:00, 34.44it/s]


Epoch 42: D_loss=-3.2235, G_loss=16.1179, Val AUROC=0.5143


OpenGAN Epoch 43: 100%|██████████| 216/216 [00:06<00:00, 34.60it/s]


Epoch 43: D_loss=-3.2236, G_loss=16.1181, Val AUROC=0.5116


OpenGAN Epoch 44: 100%|██████████| 216/216 [00:06<00:00, 34.65it/s]


Epoch 44: D_loss=-3.2236, G_loss=16.1181, Val AUROC=0.5186


OpenGAN Epoch 45: 100%|██████████| 216/216 [00:06<00:00, 34.63it/s]


Epoch 45: D_loss=-3.2236, G_loss=16.1181, Val AUROC=0.5217


OpenGAN Epoch 46: 100%|██████████| 216/216 [00:06<00:00, 33.38it/s]


Epoch 46: D_loss=-3.2236, G_loss=16.1181, Val AUROC=0.5179


OpenGAN Epoch 47: 100%|██████████| 216/216 [00:06<00:00, 33.12it/s]


Epoch 47: D_loss=-3.2236, G_loss=16.1181, Val AUROC=0.5178


OpenGAN Epoch 48: 100%|██████████| 216/216 [00:06<00:00, 34.90it/s]


Epoch 48: D_loss=-3.2236, G_loss=16.1181, Val AUROC=0.5213


OpenGAN Epoch 49: 100%|██████████| 216/216 [00:06<00:00, 34.90it/s]


Epoch 49: D_loss=-3.2236, G_loss=16.1180, Val AUROC=0.5069


OpenGAN Epoch 50: 100%|██████████| 216/216 [00:06<00:00, 34.91it/s]

Epoch 50: D_loss=-3.2233, G_loss=16.1166, Val AUROC=0.5000


Best validation AUROC: 0.9150
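The OpenGAN log shows validation AUROC peaking at 0.9150 in epoch 8. From epoch 33 onward `D_loss` and `G_loss` stay pinned near -3.2236 and 16.1181 while validation AUROC drifts toward 0.5, which suggests the adversarial game has stalled. Because the notebook reloads `best_discriminator.pth` rather than the last-epoch model, the stalled epochs presumably do not affect the final result, but they cost compute. The sketch below replays the logged AUROC values through a patience-based early-stopping rule; `select_best_epoch` and the patience of 10 are hypothetical and not part of `train_opengan`.

```python
# Validation AUROC per OpenGAN epoch, copied from the training log above (epochs 1-50)
val_auroc_log = [
    0.2562, 0.4776, 0.5314, 0.7005, 0.6251, 0.7667, 0.8771, 0.9150, 0.8501, 0.8621,
    0.8032, 0.7839, 0.7677, 0.7612, 0.7964, 0.8964, 0.9041, 0.8925, 0.8586, 0.8286,
    0.8328, 0.8859, 0.8918, 0.8706, 0.8296, 0.9100, 0.7440, 0.8128, 0.7283, 0.6732,
    0.5597, 0.5203, 0.5605, 0.5773, 0.3711, 0.5641, 0.5675, 0.5708, 0.5572, 0.5361,
    0.4940, 0.5143, 0.5116, 0.5186, 0.5217, 0.5179, 0.5178, 0.5213, 0.5069, 0.5000,
]


def select_best_epoch(val_aurocs, patience=None):
    """Return (best_epoch, best_auroc, last_epoch_run); epochs are 1-indexed.

    With patience=None every epoch is run. Otherwise training stops once
    `patience` consecutive epochs pass without a new best AUROC.
    """
    best_epoch, best_auroc, since_best = 0, float("-inf"), 0
    for epoch, auroc in enumerate(val_aurocs, start=1):
        if auroc > best_auroc:
            # New best: this is where the checkpoint would be saved
            best_epoch, best_auroc, since_best = epoch, auroc, 0
        else:
            since_best += 1
            if patience is not None and since_best >= patience:
                return best_epoch, best_auroc, epoch
    return best_epoch, best_auroc, len(val_aurocs)


print(select_best_epoch(val_auroc_log))               # (8, 0.915, 50)
print(select_best_epoch(val_auroc_log, patience=10))  # (8, 0.915, 18)
```

With a patience of 10 the same epoch-8 checkpoint would be kept after 18 epochs instead of 50. Early stopping only trims wasted epochs; it does not address why the losses saturate.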

Visualizing results...
Test-set AUROC: 0.9184
Visualizing feature space...


Visualizing decision boundary...


Final test-set AUROC: 0.9184
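`evaluate_opengan` reports a single AUROC for separating closed-set from open-set test images; its implementation lives in `utils` and is not shown. As an illustration, AUROC equals the probability that a randomly chosen closed-set sample scores higher than a randomly chosen open-set sample, with ties counted as one half. The NumPy sketch below computes it that way, assuming (hypothetically) that the discriminator score is higher for closed-set inputs; `open_set_auroc` is an illustrative name.

```python
import numpy as np


def open_set_auroc(closed_scores, open_scores):
    """AUROC = P(closed score > open score) + 0.5 * P(tie).

    Assumes higher scores mean "more likely closed-set".
    """
    closed = np.asarray(closed_scores, dtype=float)
    open_sorted = np.sort(np.asarray(open_scores, dtype=float))
    if closed.size == 0 or open_sorted.size == 0:
        raise ValueError("both score arrays must be non-empty")
    # For each closed score, count open scores strictly below it and below-or-equal to it
    below = np.searchsorted(open_sorted, closed, side="left")
    below_or_equal = np.searchsorted(open_sorted, closed, side="right")
    # Strict wins count 1, ties count 0.5
    wins = below.sum() + 0.5 * (below_or_equal - below).sum()
    return float(wins / (closed.size * open_sorted.size))


# 7.5 of 9 closed/open pairs are ranked correctly (one tie at 0.4)
print(open_set_auroc([0.9, 0.8, 0.4], [0.1, 0.4, 0.7]))
```

Sorting the open-set scores once keeps this at O((n + m) log m) instead of comparing all 5139 × 4861 test pairs. A discriminator whose scores no longer separate the two sets lands at 0.5, the value logged at epoch 50, and values below 0.5 (such as 0.2562 at epoch 1) mean the ranking is inverted. With closed-set samples labelled 1, `sklearn.metrics.roc_auc_score` should return the same number.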
