In [1]:
import datetime
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import lr_scheduler
import torch.multiprocessing as mp
import torch.backends.cudnn as cudnn

from datasets.osr_dataloader import MNIST_OSR
from models import gan
from models.models import classifier32, classifier32ABN
from utils import Logger, save_networks, load_networks
from core import train, train_cs, test

In [2]:
# Reproducibility: seed PyTorch's default generator and, explicitly, every CUDA device.
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# PyTorch names the GPU device type "cuda"; torch.device rejects the string "gpu:0".
device = "cuda:0"

# Directory passed to the MNIST open-set loader.
data_root = "../data/mnist"

In [3]:
# Open-set split: `known` lists the digit classes handed to the loader;
# `unknown` is its complement within 0-9 (computed here, not passed on).
known = [2, 4, 5, 9, 8, 3]
unknown = list(set(range(10)) - set(known))

batch_size = 128
img_size = 32

# NOTE(review): batch_size and img_size are defined above but are not passed
# to MNIST_OSR, so the loaders run with that class's own defaults. The logged
# 550 batches per epoch for 35152 training samples suggests roughly 64 samples
# per batch rather than 128 -- confirm against MNIST_OSR's signature.
Data = MNIST_OSR(known, data_root)
trainloader = Data.train_loader
testloader = Data.test_loader
outloader = Data.out_loader

Selected Labels:  [2, 4, 5, 9, 8, 3]
All Train Data: 60000
All Test Data: 10000
Train:  35152 Test:  5899 Out:  4101
All Test:  10000


In [4]:
# Classifier with one output per known class, wrapped for multi-GPU execution
# and moved to CUDA. classifier32 (imported above) takes the same num_classes
# argument and can be swapped in for classifier32ABN here.
net = nn.DataParallel(classifier32ABN(num_classes=Data.num_classes)).cuda()

In [5]:
# GAN components passed to train_cs.
# nz sizes the second dimension of the noise batch below; ns is forwarded to
# train_cs unchanged.
nz, ns = 100, 1

# Build both networks first, then draw the noise, so the random generator is
# consumed in the same order as before.
netG = gan.Generator32(1, nz, 64, 3)
netD = gan.Discriminator32(1, 3, 64)

# Fixed batch of 64 standard-normal noise tensors of shape (nz, 1, 1).
fixed_noise = torch.randn(64, nz, 1, 1, dtype=torch.float32)

# Wrap for multi-GPU execution and move everything to CUDA.
netG, netD = (nn.DataParallel(module).cuda() for module in (netG, netD))
fixed_noise = fixed_noise.cuda()

In [6]:
import loss.ARPLoss as Loss
# Classification criterion sized to the number of known classes, placed on CUDA.
criterion = Loss.ARPLoss(num_classes=Data.num_classes).cuda()

# Binary cross-entropy criterion handed to train_cs as criterionD.
criterionD = nn.BCELoss()

In [7]:
# Evaluation switch: when True, the next cell loads a saved checkpoint and
# runs a single test pass.
# NOTE(review): `eval` shadows the Python builtin of the same name; renaming it
# would also require updating the cell that reads it.
eval = False

# Checkpoint directory. Written as a raw string because "\o" is not a valid
# escape sequence (SyntaxWarning on recent Python versions); the runtime value
# is unchanged.
model_path = r"E:\osr"

# Base name passed to save_networks / load_networks.
file_name = "parm"

In [8]:
if eval:
    # Evaluation-only path: restore the classifier and criterion from the
    # checkpoint location, run one test pass, and report the three metrics.
    net, criterion = load_networks(net, model_path, file_name, criterion=criterion)
    results = test(net, criterion, testloader, outloader, epoch=0)
    print(
        "Acc (%): {:.3f}\t AUROC (%): {:.3f}\t OSCR (%): {:.3f}\t".format(
            *(results[key] for key in ('ACC', 'AUROC', 'OSCR'))
        )
    )

In [9]:
# Learning rates: `lr` for the classifier/criterion SGD optimizer, `gan_lr`
# for the two GAN Adam optimizers.
lr = 0.1
gan_lr = 0.0002

# The classifier and the criterion each contribute a parameter group, so both
# are updated by the same SGD optimizer.
params_list = [
    {'params': net.parameters()},
    {'params': criterion.parameters()},
]
optimizer = torch.optim.SGD(params_list, lr=lr, momentum=0.9, weight_decay=1e-4)

# GAN: discriminator first, then generator, with identical Adam settings.
optimizerD, optimizerG = (
    torch.optim.Adam(module.parameters(), lr=gan_lr, betas=(0.5, 0.999))
    for module in (netD, netG)
)

# Step decay of the SGD learning rate at the listed epochs (gamma left at its
# default). Stepped once per epoch by the training loop.
scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[30, 60, 90, 120])

In [20]:
# Training driver: for each epoch run the GAN-assisted pass (train_cs) followed
# by the plain classifier pass (train), evaluate and checkpoint on schedule,
# then advance the learning-rate scheduler.
start_time = time.time()
max_epoch = 20
eval_freq = 1
print_freq = 100

for epoch in range(max_epoch):
    print("==> Epoch {}/{}".format(epoch+1, max_epoch))

    # Pass that also updates the generator and discriminator.
    train_cs(
        net, netD, netG, criterion, criterionD,
        optimizer, optimizerD, optimizerG,
        trainloader, nz, ns, print_freq,
        beta=0.1, epoch=epoch,
    )

    # Pass over the same loader that updates only the classifier/criterion optimizer.
    train(net, criterion, optimizer, trainloader, epoch=epoch, print_freq=print_freq)

    # Evaluate every `eval_freq` epochs (when eval_freq is positive) and always
    # after the final epoch; each evaluation is followed by a checkpoint save.
    if (epoch + 1) == max_epoch or (eval_freq > 0 and (epoch + 1) % eval_freq == 0):
        print("==> Test", "ARPLoss")
        results = test(net, criterion, testloader, outloader, epoch=epoch)
        print(
            "Acc (%): {:.3f}\t AUROC (%): {:.3f}\t OSCR (%): {:.3f}\t".format(
                *(results[key] for key in ('ACC', 'AUROC', 'OSCR'))
            )
        )

        save_networks(net, model_path, file_name, criterion=criterion)

    # The scheduler is stepped once per epoch, unconditionally.
    scheduler.step()

# Report wall-clock duration, rounded to whole seconds, as h:m:s.
elapsed = str(datetime.timedelta(seconds=round(time.time() - start_time)))
print("Finished. Total elapsed time (h:m:s): {}".format(elapsed))

==> Epoch 1/20
train with confusing samples
Batch 100/550	 Net 0.209 (0.647) G 5.394 (6.025) D 0.042 (0.169)
Batch 200/550	 Net 0.212 (0.417) G 1.761 (5.487) D 0.756 (0.283)
Batch 300/550	 Net 0.130 (0.319) G 2.691 (4.772) D 0.421 (0.401)
Batch 400/550	 Net 0.042 (0.266) G 1.287 (4.329) D 0.924 (0.458)
Batch 500/550	 Net 0.112 (0.234) G 2.048 (4.004) D 0.479 (0.498)
Batch 100/550	 Loss 0.056042 (0.070950)
Batch 200/550	 Loss 0.009324 (0.066400)
Batch 300/550	 Loss 0.161371 (0.068910)
Batch 400/550	 Loss 0.011236 (0.066773)
Batch 500/550	 Loss 0.021558 (0.063934)
==> Test ARPLoss
Acc: 98.93202
       TNR    AUROC  DTACC  AUIN   AUOUT 
Bas    81.078 96.246 91.265 97.637 92.096
Acc (%): 98.932	 AUROC (%): 96.246	 OSCR (%): 95.608	
==> Epoch 2/20
train with confusing samples
Batch 100/550	 Net 0.066 (0.071) G 3.882 (2.346) D 1.100 (0.761)
Batch 200/550	 Net 0.076 (0.071) G 0.724 (2.302) D 1.447 (0.769)
Batch 300/550	 Net 0.079 (0.070) G 2.158 (2.290) D 0.694 (0.776)
Batch 400/550	 Net 0.02