In [1]:
import os
import random
import warnings

import numpy as np
import torch

from utils.dataset_MNIST import *
from utils.OpenGAN_arc import *
from utils.visualization import *
from utils.extr_fea import *



# Plot styling: use the SimHei font so Chinese labels render, and draw minus
# signs with an ASCII hyphen instead of the Unicode minus glyph (which that
# font may not provide).
plt.rcParams.update({
    "font.sans-serif": ["SimHei"],
    "axes.unicode_minus": False,
})
# Ignore every RuntimeWarning for the rest of the kernel session.
# NOTE(review): this filter is process-wide and also hides numerical warnings
# (overflow, invalid value, divide by zero) that could flag training problems;
# consider narrowing it to the specific warning it was meant to silence.
warnings.filterwarnings(action="ignore", category=RuntimeWarning)

In [2]:
# Experiment configuration: every tunable value for this notebook lives here.
config = {
    "dataset": "MNIST",  # only the MNIST dataset is used
    "closed_classes": [0, 1, 2, 3, 4],  # closed-set (known) classes, K=5
    "open_classes": [5, 6, 7, 8, 9],  # open-set (unknown) classes
    "feature_dim": 512,  # ResNet18 feature dimension
    "batch_size": 128,
    "lr_classifier": 0.01,  # classifier learning rate
    "lr_gan": 1e-4,  # GAN learning rate
    "classifier_epochs": 10,  # classifier training epochs
    "gan_epochs": 50,  # GAN training epochs
    "lambda_g": 0.2,  # weight of generated samples
    "val_ratio": 0.1,  # fraction of data held out for validation
    "seed": 42,  # seed for the Python / NumPy / PyTorch global RNGs
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "save_dir": "./mnist_results"  # directory for results and checkpoints
}

# Sanity check: the closed-set and open-set label sets must not overlap.
assert not set(config["closed_classes"]) & set(config["open_classes"]), \
    "closed_classes and open_classes overlap"

# Seed all global RNGs so stochastic steps are reproducible across
# Restart & Run All (assumes the utils helpers draw from these global RNGs —
# TODO confirm). Bit-exact reproducibility on CUDA additionally depends on
# cuDNN determinism settings.
random.seed(config["seed"])
np.random.seed(config["seed"])
torch.manual_seed(config["seed"])  # seeds the CPU and all CUDA devices

# Create the output directory (no-op if it already exists).
os.makedirs(config["save_dir"], exist_ok=True)

In [3]:
if __name__ == "__main__":
    # NOTE(review): inside Jupyter `__name__` is always "__main__", so this
    # guard only matters if the notebook is exported and imported as a module.

    # Load data; test_loader is indexed below by "closed" and "open".
    train_loader, val_loader, test_loader = load_MNIST(config)

    # Train the closed-set classifier.
    classifier = train_classifierI(train_loader, val_loader, config)

    # Train OpenGAN. The returned model is not used for evaluation below; the
    # best-validation-AUROC checkpoint is reloaded from save_dir instead.
    # NOTE(review): in the recorded run both losses freeze from about epoch 29
    # (D_loss=-3.2236, G_loss=16.1181). 16.1181 ~= -ln(1e-7) and
    # -3.2236 ~= -lambda_g * 16.1181, which suggests D(G(z)) saturates at 0
    # and hits a 1e-7 epsilon. Check the loss definition in train_opengan.
    discriminator = train_opengan(classifier, train_loader, val_loader, config)

    # Reload the best discriminator checkpoint (presumably written by
    # train_opengan). map_location lets a checkpoint saved on one device
    # (e.g. CUDA) be loaded on the currently configured device.
    best_ckpt_path = os.path.join(config["save_dir"], "best_discriminator.pth")
    best_discriminator = Discriminator().to(config["device"])
    best_discriminator.load_state_dict(
        torch.load(best_ckpt_path, map_location=config["device"]))
    # A freshly constructed module starts in training mode; switch to
    # inference mode so any dropout/batch-norm layers behave deterministically.
    best_discriminator.eval()

    # Visualize results on the test split.
    visualize_results(classifier, best_discriminator, test_loader["closed"], test_loader["open"], config)

    # Final test-set evaluation.
    test_auroc = evaluate_opengan(best_discriminator, classifier, test_loader["closed"], test_loader["open"], config)
    print(f"\n最终测试集AUROC: {test_auroc:.4f}")

加载MNIST数据集...
训练集: 27537闭集样本
验证集: 3059闭集 + 3059开集样本
测试集: 5139闭集 + 4861开集样本

训练闭集分类器...


分类器 Epoch 1: 100%|██████████| 216/216 [00:43<00:00,  4.94it/s]


Epoch 1: 训练损失=0.2222, 验证准确率=0.9582


分类器 Epoch 2: 100%|██████████| 216/216 [00:38<00:00,  5.63it/s]


Epoch 2: 训练损失=0.0383, 验证准确率=0.9928


分类器 Epoch 3: 100%|██████████| 216/216 [00:38<00:00,  5.62it/s]


Epoch 3: 训练损失=0.0236, 验证准确率=0.9954


分类器 Epoch 4: 100%|██████████| 216/216 [00:38<00:00,  5.61it/s]


Epoch 4: 训练损失=0.0152, 验证准确率=0.9980


分类器 Epoch 5: 100%|██████████| 216/216 [00:38<00:00,  5.61it/s]


Epoch 5: 训练损失=0.0167, 验证准确率=0.9958


分类器 Epoch 6: 100%|██████████| 216/216 [00:38<00:00,  5.59it/s]


Epoch 6: 训练损失=0.0130, 验证准确率=0.9951


分类器 Epoch 7: 100%|██████████| 216/216 [00:38<00:00,  5.59it/s]


Epoch 7: 训练损失=0.0154, 验证准确率=0.9951


分类器 Epoch 8: 100%|██████████| 216/216 [00:38<00:00,  5.60it/s]


Epoch 8: 训练损失=0.0080, 验证准确率=0.9958


分类器 Epoch 9: 100%|██████████| 216/216 [00:38<00:00,  5.60it/s]


Epoch 9: 训练损失=0.0091, 验证准确率=0.9974


分类器 Epoch 10: 100%|██████████| 216/216 [00:38<00:00,  5.59it/s]


Epoch 10: 训练损失=0.0104, 验证准确率=0.9935

训练OpenGAN...


OpenGAN Epoch 1: 100%|██████████| 216/216 [00:18<00:00, 11.75it/s]


Epoch 1: D_loss=0.2586, G_loss=0.8797, Val AUROC=0.2524


OpenGAN Epoch 2: 100%|██████████| 216/216 [00:18<00:00, 11.76it/s]


Epoch 2: D_loss=-0.0655, G_loss=1.6318, Val AUROC=0.6221


OpenGAN Epoch 3: 100%|██████████| 216/216 [00:18<00:00, 11.68it/s]


Epoch 3: D_loss=-0.3231, G_loss=2.3968, Val AUROC=0.7589


OpenGAN Epoch 4: 100%|██████████| 216/216 [00:18<00:00, 11.75it/s]


Epoch 4: D_loss=-0.5392, G_loss=3.1724, Val AUROC=0.8412


OpenGAN Epoch 5: 100%|██████████| 216/216 [00:18<00:00, 11.71it/s]


Epoch 5: D_loss=-0.7481, G_loss=4.0120, Val AUROC=0.8181


OpenGAN Epoch 6: 100%|██████████| 216/216 [00:18<00:00, 11.69it/s]


Epoch 6: D_loss=-0.9405, G_loss=4.8579, Val AUROC=0.9118


OpenGAN Epoch 7: 100%|██████████| 216/216 [00:18<00:00, 11.72it/s]


Epoch 7: D_loss=-1.1358, G_loss=5.7778, Val AUROC=0.8344


OpenGAN Epoch 8: 100%|██████████| 216/216 [00:18<00:00, 11.72it/s]


Epoch 8: D_loss=-1.3379, G_loss=6.7622, Val AUROC=0.8526


OpenGAN Epoch 9: 100%|██████████| 216/216 [00:18<00:00, 11.78it/s]


Epoch 9: D_loss=-1.5463, G_loss=7.7924, Val AUROC=0.7367


OpenGAN Epoch 10: 100%|██████████| 216/216 [00:18<00:00, 11.78it/s]


Epoch 10: D_loss=-1.7539, G_loss=8.8348, Val AUROC=0.8524


OpenGAN Epoch 11: 100%|██████████| 216/216 [00:18<00:00, 11.76it/s]


Epoch 11: D_loss=-1.9325, G_loss=9.7352, Val AUROC=0.7824


OpenGAN Epoch 12: 100%|██████████| 216/216 [00:18<00:00, 11.74it/s]


Epoch 12: D_loss=-2.0984, G_loss=10.5532, Val AUROC=0.7913


OpenGAN Epoch 13: 100%|██████████| 216/216 [00:18<00:00, 11.76it/s]


Epoch 13: D_loss=-2.2432, G_loss=11.2748, Val AUROC=0.7677


OpenGAN Epoch 14: 100%|██████████| 216/216 [00:18<00:00, 11.70it/s]


Epoch 14: D_loss=-2.3549, G_loss=11.8412, Val AUROC=0.6409


OpenGAN Epoch 15: 100%|██████████| 216/216 [00:18<00:00, 11.70it/s]


Epoch 15: D_loss=-2.3660, G_loss=11.9375, Val AUROC=0.7507


OpenGAN Epoch 16: 100%|██████████| 216/216 [00:18<00:00, 11.61it/s]


Epoch 16: D_loss=-2.4740, G_loss=12.4521, Val AUROC=0.9290


OpenGAN Epoch 17: 100%|██████████| 216/216 [00:18<00:00, 11.69it/s]


Epoch 17: D_loss=-2.5324, G_loss=12.7371, Val AUROC=0.8082


OpenGAN Epoch 18: 100%|██████████| 216/216 [00:18<00:00, 11.70it/s]


Epoch 18: D_loss=-2.5733, G_loss=12.9643, Val AUROC=0.8301


OpenGAN Epoch 19: 100%|██████████| 216/216 [00:18<00:00, 11.67it/s]


Epoch 19: D_loss=-2.5347, G_loss=12.7625, Val AUROC=0.8417


OpenGAN Epoch 20: 100%|██████████| 216/216 [00:18<00:00, 11.76it/s]


Epoch 20: D_loss=-2.5683, G_loss=12.9606, Val AUROC=0.7798


OpenGAN Epoch 21: 100%|██████████| 216/216 [00:18<00:00, 11.74it/s]


Epoch 21: D_loss=-2.6392, G_loss=13.3254, Val AUROC=0.8737


OpenGAN Epoch 22: 100%|██████████| 216/216 [00:18<00:00, 11.72it/s]


Epoch 22: D_loss=-2.7180, G_loss=13.7205, Val AUROC=0.8586


OpenGAN Epoch 23: 100%|██████████| 216/216 [00:18<00:00, 11.72it/s]


Epoch 23: D_loss=-2.7963, G_loss=14.1223, Val AUROC=0.7956


OpenGAN Epoch 24: 100%|██████████| 216/216 [00:18<00:00, 11.73it/s]


Epoch 24: D_loss=-2.8541, G_loss=14.4450, Val AUROC=0.8209


OpenGAN Epoch 25: 100%|██████████| 216/216 [00:18<00:00, 11.72it/s]


Epoch 25: D_loss=-2.9164, G_loss=14.7537, Val AUROC=0.8434


OpenGAN Epoch 26: 100%|██████████| 216/216 [00:18<00:00, 11.72it/s]


Epoch 26: D_loss=-3.0139, G_loss=15.2027, Val AUROC=0.8458


OpenGAN Epoch 27: 100%|██████████| 216/216 [00:18<00:00, 11.68it/s]


Epoch 27: D_loss=-3.0570, G_loss=15.3882, Val AUROC=0.7608


OpenGAN Epoch 28: 100%|██████████| 216/216 [00:18<00:00, 11.73it/s]


Epoch 28: D_loss=-3.1758, G_loss=15.9516, Val AUROC=0.6754


OpenGAN Epoch 29: 100%|██████████| 216/216 [00:18<00:00, 11.68it/s]


Epoch 29: D_loss=-3.2211, G_loss=16.1127, Val AUROC=0.6973


OpenGAN Epoch 30: 100%|██████████| 216/216 [00:18<00:00, 11.69it/s]


Epoch 30: D_loss=-3.2235, G_loss=16.1180, Val AUROC=0.6702


OpenGAN Epoch 31: 100%|██████████| 216/216 [00:18<00:00, 11.71it/s]


Epoch 31: D_loss=-3.2235, G_loss=16.1180, Val AUROC=0.6944


OpenGAN Epoch 32: 100%|██████████| 216/216 [00:18<00:00, 11.69it/s]


Epoch 32: D_loss=-3.2236, G_loss=16.1180, Val AUROC=0.7587


OpenGAN Epoch 33: 100%|██████████| 216/216 [00:18<00:00, 11.47it/s]


Epoch 33: D_loss=-3.2235, G_loss=16.1178, Val AUROC=0.7562


OpenGAN Epoch 34: 100%|██████████| 216/216 [00:18<00:00, 11.55it/s]


Epoch 34: D_loss=-3.2235, G_loss=16.1180, Val AUROC=0.7562


OpenGAN Epoch 35: 100%|██████████| 216/216 [00:18<00:00, 11.56it/s]


Epoch 35: D_loss=-3.2236, G_loss=16.1173, Val AUROC=0.7972


OpenGAN Epoch 36: 100%|██████████| 216/216 [00:18<00:00, 11.46it/s]


Epoch 36: D_loss=-3.2229, G_loss=16.1172, Val AUROC=0.7547


OpenGAN Epoch 37: 100%|██████████| 216/216 [00:18<00:00, 11.50it/s]


Epoch 37: D_loss=-3.2236, G_loss=16.1181, Val AUROC=0.7501


OpenGAN Epoch 38: 100%|██████████| 216/216 [00:18<00:00, 11.56it/s]


Epoch 38: D_loss=-3.2235, G_loss=16.1180, Val AUROC=0.7051


OpenGAN Epoch 39: 100%|██████████| 216/216 [00:18<00:00, 11.50it/s]


Epoch 39: D_loss=-3.2236, G_loss=16.1181, Val AUROC=0.6883


OpenGAN Epoch 40: 100%|██████████| 216/216 [00:18<00:00, 11.63it/s]


Epoch 40: D_loss=-3.2236, G_loss=16.1181, Val AUROC=0.6780


OpenGAN Epoch 41: 100%|██████████| 216/216 [00:18<00:00, 11.42it/s]


Epoch 41: D_loss=-3.2235, G_loss=16.1181, Val AUROC=0.6779


OpenGAN Epoch 42: 100%|██████████| 216/216 [00:18<00:00, 11.53it/s]


Epoch 42: D_loss=-3.2236, G_loss=16.1180, Val AUROC=0.6566


OpenGAN Epoch 43: 100%|██████████| 216/216 [00:18<00:00, 11.66it/s]


Epoch 43: D_loss=-3.2236, G_loss=16.1180, Val AUROC=0.6409


OpenGAN Epoch 44: 100%|██████████| 216/216 [00:18<00:00, 11.49it/s]


Epoch 44: D_loss=-3.2236, G_loss=16.1181, Val AUROC=0.6355


OpenGAN Epoch 45: 100%|██████████| 216/216 [00:18<00:00, 11.56it/s]


Epoch 45: D_loss=-3.2236, G_loss=16.1181, Val AUROC=0.6388


OpenGAN Epoch 46: 100%|██████████| 216/216 [00:18<00:00, 11.39it/s]


Epoch 46: D_loss=-3.2236, G_loss=16.1181, Val AUROC=0.6347


OpenGAN Epoch 47: 100%|██████████| 216/216 [00:20<00:00, 10.48it/s]


Epoch 47: D_loss=-3.2236, G_loss=16.1181, Val AUROC=0.6211


OpenGAN Epoch 48: 100%|██████████| 216/216 [00:18<00:00, 11.54it/s]


Epoch 48: D_loss=-3.2236, G_loss=16.1181, Val AUROC=0.6634


OpenGAN Epoch 49: 100%|██████████| 216/216 [00:18<00:00, 11.39it/s]


Epoch 49: D_loss=-3.2236, G_loss=16.1181, Val AUROC=0.6353


OpenGAN Epoch 50: 100%|██████████| 216/216 [00:19<00:00, 10.94it/s]


Epoch 50: D_loss=-3.2236, G_loss=16.1181, Val AUROC=0.6378
最佳验证AUROC: 0.9290

可视化结果...
测试集AUROC: 0.9375
可视化特征空间...
可视化决策边界...


  plt.scatter(features_2d[labels == 1, 0], features_2d[labels == 1, 1],



最终测试集AUROC: 0.9375
