In [1]:
import os
import warnings

import torch
from torch.utils.data import DataLoader

from utils.dataset_TinyImage import load_data
from utils.OpenGAN_arc import Discriminator
from utils import train_openganII
from utils import visualize_results
from utils import evaluate_opengan
# NOTE(review): star import; presumably this is where train_classifierII comes
# from. Prefer importing the needed names explicitly. Kept last so it shadows
# the same names it shadowed before.
from utils.extr_fea import *

warnings.filterwarnings('ignore')

In [2]:
# Experiment configuration: every tunable value for the run lives here.
config = {
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "batch_size": 128,
    "num_closed_classes": 200,  # number of TinyImageNet classes
    "image_size": 64,  # uniform image size
    "feature_dim": 512,  # ResNet18 feature dimension
    "seed": 42,  # global torch RNG seed, for reproducible runs

    # Closed-set classifier training parameters
    "classifier_lr": 0.001,
    "classifier_epochs": 1,
    "classifier_momentum": 0.9,
    "classifier_weight_decay": 5e-4,

    # OpenGAN parameters
    "opengan_lr": 1e-4,
    "opengan_epochs": 1,
    "lambda_g": 0.5,
    "lambda_gp": 10.0,  # gradient-penalty coefficient
    "d_train_ratio": 5,  # discriminator updates per generator update

    # Dataset / output paths
    "tinyimagenet_path": "./data/tiny-imagenet-200",
    "output_dir": "./results_setup2",
    "save_dir": "./results_setup2",

    # Open-set datasets
    "open_datasets": ["CIFAR10", "SVHN", "MNIST"],
}

# Seed torch's default generator (all devices) so DataLoader shuffling and
# weight initialisation repeat across Restart & Run All.
torch.manual_seed(config["seed"])

# Create every directory the run may write into. Both keys are configured, and
# they currently point to the same place; the set deduplicates them.
for _out_dir in {config["output_dir"], config["save_dir"]}:
    os.makedirs(_out_dir, exist_ok=True)

In [3]:

if __name__ == '__main__':
    # Open-set dataset used for OpenGAN training, validation and final testing.
    # Defaults to "CIFAR10" (previously hard-coded). Override it with
    # config["open_train_dataset"].
    open_name = config.get("open_train_dataset", "CIFAR10")

    def _make_loader(dataset, shuffle):
        """Wrap `dataset` in a DataLoader that uses the configured batch size."""
        return DataLoader(dataset, batch_size=config["batch_size"], shuffle=shuffle)

    # Load the TinyImageNet (closed-set) data together with the open-set data.
    data_dict = load_data(config)
    closed_train_loader = _make_loader(data_dict["closed"]["train"], shuffle=True)
    closed_val_loader = _make_loader(data_dict["closed"]["val"], shuffle=False)

    # Train the closed-set classifier.
    # NOTE(review): train_classifierII is not imported explicitly; presumably it
    # comes from `from utils.extr_fea import *`. Consider an explicit import.
    classifier = train_classifierII(closed_train_loader, closed_val_loader, config)

    # Build the open-set loaders.
    open_splits = data_dict["open"][open_name]
    open_train_loader = _make_loader(open_splits["train"], shuffle=True)
    val_data = {
        "closed": closed_val_loader,
        # NOTE(review): the open-set *test* split is used for validation here AND
        # for the final AUROC below. If validation drives checkpoint selection,
        # the reported test AUROC is optimistically biased. Consider a held-out
        # open split, or another open dataset, for validation.
        "open": _make_loader(open_splits["test"], shuffle=False),
    }

    # Train OpenGAN. The returned discriminator is kept in `discriminator`.
    # Evaluation below uses the checkpoint reloaded from disk instead, which is
    # presumably written by train_openganII.
    discriminator = train_openganII(classifier, open_train_loader, val_data, config)

    checkpoint_path = os.path.join(config["output_dir"], "best_discriminator.pth")
    best_discriminator = Discriminator().to(config["device"])
    # map_location lets a checkpoint written on a GPU load when
    # config["device"] resolves to "cpu", and vice versa.
    best_discriminator.load_state_dict(
        torch.load(checkpoint_path, map_location=config["device"])
    )
    # Switch the reloaded model to inference mode. This matters for layers such
    # as dropout / batch-norm, if Discriminator contains any.
    best_discriminator.eval()

    closed_test_set = data_dict["closed"]["test"]
    open_test_set = open_splits["test"]
    visualize_results(
        classifier, best_discriminator, closed_test_set, open_test_set, config
    )

    # Final test-set evaluation.
    test_auroc = evaluate_opengan(
        best_discriminator, classifier, closed_test_set, open_test_set, config
    )
    print(f"\n最终测试集AUROC: {test_auroc:.4f}")

加载TinyImageNet数据集...
Files already downloaded and verified
Files already downloaded and verified
Using downloaded and verified file: ./data\train_32x32.mat
Using downloaded and verified file: ./data\test_32x32.mat

训练闭集分类器...


分类器 Epoch 1/1:  16%|█▌        | 99/625 [02:12<11:42,  1.34s/it] 


KeyboardInterrupt: 