In [5]:
import torch
import torch.nn as nn
import torch.optim as optim

import json
import torchvision
import torchvision.transforms as tf

import argparse

import datasets.utils as dataHelper

from networks import openSetClassifier

from utils import progress_bar

import os
import numpy as np
# ---- Experiment configuration ----
dataset = 'MNIST'	# key into datasets/config.json
trial = 0			# trial index passed to the data loaders (presumably selects the class split)
alpha = 10			# magnitude of the initial one-hot class anchors
lbda = 0.1			# weight of the anchor term in the CAC loss
name = "myTest"		# tag embedded in checkpoint file names

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Best-so-far validation metrics and starting epoch (useful when resuming and finetuning).
best_acc = 0
best_cac = 10000
best_anchor = 10000
start_epoch = 0

# ---- Data ----
print('==> Preparing data..')
with open('datasets/config.json') as config_file:
	cfg = json.load(config_file)[dataset]

trainloader, valloader, _, mapping = dataHelper.get_train_loaders(dataset, trial, cfg)

# ---- Network ----
print('==> Building network..')
net = openSetClassifier.openSetClassifier(
	cfg['num_known_classes'],
	cfg['im_channels'],
	cfg['im_size'],
	dropout = cfg['dropout'],
)

# Initial anchors: one row per known class, equal to alpha times that class's one-hot vector.
anchors = alpha * torch.eye(cfg['num_known_classes'])
net.set_anchors(anchors)

net = net.to(device)

==> Preparing data..
==> Building network..


In [6]:

# Put the network in training mode and build the SGD optimiser from the open-set
# training config (first learning-rate entry, fixed momentum of 0.9).
net.train()
optimizer = optim.SGD(
	net.parameters(),
	lr = cfg['openset_training']['learning_rate'][0],
	momentum = 0.9,
	weight_decay = cfg['openset_training']['weight_decay'],
)

def CACLoss(distances, gt, anchor_weight = None):
	'''Returns CAC loss, as well as the Anchor and Tuplet loss components separately for visualisation.

	Args:
		distances: (batch, num_classes) tensor; column j is treated as the sample's
			distance to class j (the training loop predicts with distances.min(1)).
		gt: (batch,) integer known-class labels in [0, num_classes), on the same
			device as `distances`.
		anchor_weight: weight of the anchor term. Defaults to the global `lbda`.

	Returns:
		(total, anchor, tuplet) scalar tensors, where
			anchor = mean distance to the ground-truth class,
			tuplet = mean over samples of log(1 + sum_{j != gt} exp(d_gt - d_j)),
			total  = anchor_weight * anchor + tuplet.
	'''
	weight = lbda if anchor_weight is None else anchor_weight
	num_samples, num_classes = distances.size(0), distances.size(1)

	gt_col = gt.view(-1, 1)
	true = torch.gather(distances, 1, gt_col).view(-1)

	# Mask that is False only at each row's ground-truth column. It is built on
	# distances.device (not hard-coded to CUDA) so CPU training also works.
	class_ids = torch.arange(num_classes, device = distances.device).unsqueeze(0)
	non_gt_mask = class_ids != gt_col
	# Boolean indexing returns the kept entries row by row in ascending column
	# order, i.e. each row's non-ground-truth distances.
	others = distances[non_gt_mask].view(num_samples, num_classes - 1)

	anchor = torch.mean(true)

	tuplet = torch.exp(-others + true.unsqueeze(1))
	tuplet = torch.mean(torch.log(1 + torch.sum(tuplet, dim = 1)))

	total = weight * anchor + tuplet

	return total, anchor, tuplet


# Training
def train(epoch):
	'''Run one epoch of CAC training over `trainloader`.

	For every batch, labels are remapped from original dataset labels to
	known-class indices via `mapping`. The CAC loss is computed on the
	network's distance output (outputs[1]) and an SGD step is taken. Running
	loss and accuracy (prediction = column with the smallest distance) are
	reported through progress_bar.

	Args:
		epoch: epoch number, used only for the printed header.
	'''
	print('\nEpoch: %d' % epoch)
	net.train()
	train_loss = 0
	correctDist = 0
	total = 0

	for batch_idx, (inputs, targets) in enumerate(trainloader):
		inputs = inputs.to(device)
		# Convert from original dataset label to known class label while targets are
		# still on the CPU (as val() does). Iterating element-wise over a CUDA tensor
		# would likely force a device-to-host sync for every label.
		targets = torch.Tensor([mapping[x] for x in targets]).long().to(device)

		optimizer.zero_grad()

		outputs = net(inputs)
		cacLoss, anchorLoss, tupletLoss = CACLoss(outputs[1], targets)

		cacLoss.backward()
		optimizer.step()

		train_loss += cacLoss.item()

		_, predicted = outputs[1].min(1)

		total += targets.size(0)
		correctDist += predicted.eq(targets).sum().item()

		progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
			% (train_loss/(batch_idx+1), 100.*correctDist/total, correctDist, total))

def val(epoch):
	'''Evaluate on `valloader` and checkpoint the best models so far.

	Computes the batch-averaged CAC and anchor losses and the top-1 accuracy
	(prediction = column of outputs[1] with the smallest distance). A separate
	checkpoint is written whenever:
	- the anchor loss reaches a new minimum,
	- the CAC loss reaches a new minimum, or
	- the accuracy reaches a new maximum.
	Ties also trigger a save.

	Args:
		epoch: current epoch number; stored in the checkpoint.

	Side effects:
		Updates the globals best_acc, best_anchor and best_cac and writes .pth
		files under networks/weights/ (via _save_checkpoint).
	'''
	global best_acc
	global best_anchor
	global best_cac
	net.eval()
	anchor_loss = 0
	cac_loss = 0
	correct = 0
	total = 0
	with torch.no_grad():
		for batch_idx, (inputs, targets) in enumerate(valloader):
			inputs = inputs.to(device)
			#convert from original dataset label to known class label
			targets = torch.Tensor([mapping[x] for x in targets]).long().to(device)

			outputs = net(inputs)

			cacLoss, anchorLoss, tupletLoss = CACLoss(outputs[1], targets)

			# Accumulate Python floats so the running sums (and the best_* globals
			# they are later stored in) are plain numbers, not device tensors.
			anchor_loss += anchorLoss.item()
			cac_loss += cacLoss.item()

			_, predicted = outputs[1].min(1)

			total += targets.size(0)
			correct += predicted.eq(targets).sum().item()

			progress_bar(batch_idx, len(valloader), 'Acc: %.3f%% (%d/%d)'
				% (100.*correct/total, correct, total))

	anchor_loss /= len(valloader)
	cac_loss /= len(valloader)
	acc = 100.*correct/total

	# Save checkpoint.
	state = {
		'net': net.state_dict(),
		'acc': acc,
		'epoch': epoch,
	}
	# makedirs also creates missing parents (e.g. networks/weights) and tolerates existing dirs.
	os.makedirs('networks/weights/{}'.format(dataset), exist_ok = True)
	if dataset == 'CIFAR+10':
		os.makedirs('networks/weights/CIFAR+50', exist_ok = True)
	save_name = '{}_{}_{}CACclassifier'.format(dataset, trial, name)

	if anchor_loss <= best_anchor:
		print('Saving..')
		_save_checkpoint(state, save_name, 'AnchorLoss')
		best_anchor = anchor_loss

	if cac_loss <= best_cac:
		print('Saving..')
		_save_checkpoint(state, save_name, 'CACLoss')
		best_cac = cac_loss

	if acc >= best_acc:
		print('Saving..')
		_save_checkpoint(state, save_name, 'Accuracy')
		best_acc = acc


def _save_checkpoint(state, save_name, suffix):
	'''Write `state` to networks/weights/<dataset>/<save_name><suffix>.pth.

	When the global `dataset` is 'CIFAR+10', the same state is also written to
	networks/weights/CIFAR+50/ with '+10' replaced by '+50' in the file name.
	(Presumably this lets CIFAR+50 evaluation reuse the CIFAR+10 network; TODO confirm.)
	The +50 name is kept in a separate local, so the caller's save_name is never
	modified and later checkpoints in the CIFAR+10 directory keep their correct name.
	'''
	torch.save(state, 'networks/weights/{}/{}{}.pth'.format(dataset, save_name, suffix))
	if dataset == 'CIFAR+10':
		plus50_name = save_name.replace('+10', '+50')
		torch.save(state, 'networks/weights/CIFAR+50/{}{}.pth'.format(plus50_name, suffix))
	


In [7]:
# Train for the first stage's epoch budget (offset by start_epoch when resuming).
# Validate, and possibly checkpoint, after every epoch.
max_epoch = start_epoch + cfg['openset_training']['max_epoch'][0]
for epoch in range(start_epoch, max_epoch):
	train(epoch)
	val(epoch)


Epoch: 0
Saving..
Saving..
Saving..

Epoch: 1
Saving..
Saving..
Saving..

Epoch: 2
Saving..
Saving..
Saving..

Epoch: 3

Epoch: 4
Saving..
Saving..

Epoch: 5
Saving..
Saving..

Epoch: 6
Saving..

Epoch: 7
Saving..
Saving..
Saving..

Epoch: 8
Saving..

Epoch: 9
Saving..
Saving..
Saving..

Epoch: 10
Saving..
Saving..

Epoch: 11
Saving..
Saving..
Saving..

Epoch: 12
Saving..
Saving..

Epoch: 13
Saving..
Saving..

Epoch: 14
Saving..

Epoch: 15

Epoch: 16
Saving..
Saving..
Saving..

Epoch: 17

Epoch: 18

Epoch: 19
Saving..
Saving..
Saving..

Epoch: 20

KeyboardInterrupt: 

In [9]:
"""
	Evaluate average performance for our proposed CAC open-set classifier on a given dataset.

	Dimity Miller, 2020
"""


import argparse
import json

import torchvision
import torchvision.transforms as tf
import torch
import torch.nn as nn

from networks import openSetClassifier
import datasets.utils as dataHelper
from utils import find_anchor_means, gather_outputs

import metrics
import scipy.stats as st
import numpy as np

dataset = 'MNIST'
num_trials = 1
start_trial = 0
name = 'myTest'

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# The dataset config does not depend on the trial, so read it once.
with open('datasets/config.json') as config_file:
	cfg = json.load(config_file)[dataset]

all_accuracy = []
all_auroc = []

for trial_num in range(start_trial, start_trial+num_trials):
	print('==> Preparing data for trial {}..'.format(trial_num))

	#Create dataloaders for evaluation
	knownloader, unknownloader, mapping = dataHelper.get_eval_loaders(dataset, trial_num, cfg)

	print('==> Building open set network for trial {}..'.format(trial_num))
	net = openSetClassifier.openSetClassifier(cfg['num_known_classes'], cfg['im_channels'], cfg['im_size'], dropout = cfg['dropout'])
	# map_location remaps tensors saved from a CUDA run onto the current device,
	# so evaluation also works on a CPU-only machine.
	checkpoint = torch.load('networks/weights/{}/{}_{}_{}CACclassifierAnchorLoss.pth'.format(dataset, dataset, trial_num, name), map_location = device)

	net = net.to(device)
	net_dict = net.state_dict()
	# Keep only checkpoint entries the current network defines. A checkpoint without
	# an 'anchors' entry is assumed to store them under 'means' (presumably an older naming).
	pretrained_dict = {k: v for k, v in checkpoint['net'].items() if k in net_dict}
	if 'anchors' not in pretrained_dict.keys():
		pretrained_dict['anchors'] = checkpoint['net']['means']
	net.load_state_dict(pretrained_dict)
	net.eval()

	#find mean anchors for each class
	anchor_means = find_anchor_means(net, mapping, dataset, trial_num, cfg, only_correct = True)
	net.set_anchors(torch.Tensor(anchor_means))

	print('==> Evaluating open set network accuracy for trial {}..'.format(trial_num))
	x, y = gather_outputs(net, mapping, knownloader, data_idx = 1, calculate_scores = True)
	accuracy = metrics.accuracy(x, y)
	all_accuracy += [accuracy]

	print('==> Evaluating open set network AUROC for trial {}..'.format(trial_num))
	# A second identical gather_outputs pass over knownloader (network in eval mode)
	# would reproduce these known-class outputs, so reuse them. This assumes the eval
	# loader applies no random augmentation and metrics.auroc is order-independent.
	xK, yK = x, y
	xU, yU = gather_outputs(net, mapping, unknownloader, data_idx = 1, calculate_scores = True, unknown = True)

	auroc = metrics.auroc(xK, xU)
	all_auroc += [auroc]

mean_auroc = np.mean(all_auroc)
mean_acc = np.mean(all_accuracy)

print('Raw Top-1 Accuracy: {}'.format(all_accuracy))
print('Raw AUROC: {}'.format(all_auroc))
print('Average Top-1 Accuracy: {}'.format(mean_acc))
print('Average AUROC: {}'.format(mean_auroc))

==> Preparing data for trial 0..
==> Building open set network for trial 0..
==> Evaluating open set network accuracy for trial 0..
==> Evaluating open set network AUROC for trial 0..
Raw Top-1 Accuracy: [0.9985074626865672]
Raw AUROC: [0.9926922482465924]
Average Top-1 Accuracy: 0.9985074626865672
Average AUROC: 0.9926922482465924
