In [1]:
import matplotlib.pyplot as plt
import numpy as np
import os
import sys
from matplotlib import colors as mcolors
import torch
import random
import time
from sklearn.manifold import TSNE
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from torchvision import datasets, models, transforms
from sklearn.decomposition import PCA, FastICA
import itertools

sys.path.append('../src')
import tforms
import feature_extraction.feature_extraction_utils as futils
from feature_extraction.Network_Latents_Wrapper import NetworkLatents
import classifier as clf

import novelty_ODD.novelty_detector as novel
import novelty_ODD.novelty_eval as novelval 

np.random.seed(42)



# Choose one in-distribution (ID) dataset and one out-of-distribution (OD) dataset

In [2]:


def get_dataset(task_subset, dset_name, dir_data, dir_save, subsample_num, train=True):
    """Build the task dataset for one of the supported benchmarks.

    Args:
        task_subset: class subset; passed, wrapped in a one-element list, to the
            dataset module's ``*Experiments`` helper.
        dset_name: ``'cifar10'`` or ``'svhn'``.
        dir_data: root data directory. It must end with a path separator because
            the dataset folder name is appended by plain string concatenation.
        dir_save: directory forwarded to the ``*Experiments`` helper.
        subsample_num: when > 0, ``dataset.select_random_subset(subsample_num)``
            is called; any other value keeps the dataset as built.
        train: forwarded to both the ``*Experiments`` helper and the task class.

    Returns:
        Tuple ``(dataset, target_ind, homog_ind, num_labels)``; for both
        supported datasets this is ``(dataset, -1, -2, 10)``.

    Raises:
        ValueError: if ``dset_name`` is not a supported dataset name.
    """
    # The dataset modules are imported lazily so only the requested one is needed.
    # NOTE(review): the *_train() transform is used even when train=False --
    # confirm that this is intended for the test split.
    if dset_name == 'cifar10':
        import cifar10_dataset as dset
        tform_apply = tforms.cifar_train()
        make_experiments, task_cls = dset.cifar10Experiments, dset.cifar10Task
    elif dset_name == 'svhn':
        import svhn_dataset as dset
        tform_apply = tforms.svhn_train()
        make_experiments, task_cls = dset.svhnExperiments, dset.svhnTask
    else:
        # Previously an unknown name fell through to the return statement and
        # failed there with an UnboundLocalError.
        raise ValueError(
            "Unsupported dset_name %r; expected 'cifar10' or 'svhn'" % (dset_name,))

    num_labels = 10
    target_ind = -1
    homog_ind = -2

    # Both datasets live in a sub-folder named after the dataset.
    dset_root = dir_data + dset_name + '/'
    task_filepaths = make_experiments([task_subset], dset_root, dir_save, 'debug',
                                      train=train, scenario='nc', shuffle=False)
    # Only one task subset was requested, so only the first task list is used.
    dataset = task_cls(dset_root, tasklist=task_filepaths[0], transform=tform_apply, train=train)

    if subsample_num > 0:
        dataset.select_random_subset(subsample_num)

    return dataset, target_ind, homog_ind, num_labels



def get_features(dataset, network_inner, num_labels, target_ind, extractor_name, device=0,
                 feat_name='base.8', homog_ind=-2, batch_size=50, num_workers=4):
    """Run the feature extractor over ``dataset`` and return one layer's features.

    Args:
        dataset: dataset wrapped in a shuffled ``torch.utils.data.DataLoader``.
        network_inner: extractor object forwarded to ``futils.extract_features``.
        num_labels: unused; kept so existing call sites keep working.
        target_ind: forwarded to ``futils.extract_features``.
        extractor_name: unused; kept so existing call sites keep working.
        device: forwarded to ``futils.extract_features``.
        feat_name: key of the layer to read from the first element of the
            extraction result (default ``'base.8'``).
        homog_ind: forwarded to ``futils.extract_features`` (default ``-2``).
        batch_size: loader batch size (default ``50``).
        num_workers: loader worker count (default ``4``).

    Returns:
        Tuple ``(X, Y)``: ``X`` is ``current_features[0][feat_name]`` and ``Y`` is
        ``current_features[-2]`` of the extraction result.
    """
    # shuffle=True means the sample order differs between runs unless the torch
    # RNG is seeded; X and Y both come from the same extraction pass.
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=num_workers)
    start = time.time()
    print('feat extraction begin')
    current_features = futils.extract_features(
        network_inner, loader,
        target_ind=target_ind, homog_ind=homog_ind, device=device,
        use_raw_images=False, raw_image_transform=None)

    print('feat extraction done', time.time() - start)

    X = current_features[0][feat_name]
    # NOTE(review): the literal -2 is independent of target_ind/homog_ind;
    # presumably this slot holds the labels -- confirm against extract_features.
    Y = current_features[-2]

    return X, Y



# Get ID and OD features

In [3]:
# --- Experiment configuration -------------------------------------------------
ID_dset = 'cifar10'
OD_dset = 'svhn'
# All ten classes of each dataset; shrink either list (e.g. [0, 1] or [3]) to
# run on a smaller class subset.
ID_subset = list(range(10))
OD_subset = list(range(10))
# get_dataset only subsamples when the value is > 0, so -1 keeps everything.
subsample_num_ID = -1
subsample_num_OD = -1
extractor_name = 'resnet50_contrastive'
# The defaults are the original machine-specific locations; set the environment
# variables to run elsewhere. os.path.join(..., '') guarantees the trailing
# separator that get_dataset relies on when it concatenates the dataset folder.
dir_data = os.path.join(
    os.environ.get('PROJINTEL_DATA_DIR', '/home/amandari/CodeDev/ProjIntel/data/'), '')
dir_save = os.path.join(
    os.environ.get('PROJINTEL_SAVE_DIR', '/home/amandari/CodeDev/ProjIntel/sandbox/random_files/'), '')
device = 0

surrogate_num_labels = 1

# --- Feature extractor ----------------------------------------------------------
network = clf.Resnet(surrogate_num_labels, resnet_arch=extractor_name, FC_layers=[],
                     resnet_base=-1, multihead_type='single', base_freeze=True,
                     pretrained_weights=None)
network = network.to(device)
# Wrap the network so that activations of layer 'base.8' are captured.
network_inner = NetworkLatents(network, ['base.8'], pool_factors={'base.8': -1})

# --- Datasets -------------------------------------------------------------------
print('Get ID/OD datasets')
dset_ID, target_ind_ID, _, num_labels_ID = get_dataset(
    ID_subset, ID_dset, dir_data, dir_save, subsample_num_ID, train=True)
dset_ID_test, _, _, _ = get_dataset(
    ID_subset, ID_dset, dir_data, dir_save, subsample_num_ID, train=False)
dset_OD, target_ind_OD, _, num_labels_OD = get_dataset(
    OD_subset, OD_dset, dir_data, dir_save, subsample_num_OD, train=True)

# --- Features -------------------------------------------------------------------
# Each result is the (X, Y) pair returned by get_features.
print('get features ID/OD')
Feats_ID = get_features(dset_ID, network_inner, num_labels_ID, target_ind_ID,
                        extractor_name, device=device)

Feats_OD = get_features(dset_OD, network_inner, num_labels_OD, target_ind_OD,
                        extractor_name, device=device)

load contrastive backbone
Will fetch activations from:
base.8, average pooled by -1
Get ID/OD datasets
get features ID/OD
feat extraction begin
feat extraction done 38.71662473678589
feat extraction begin
feat extraction done 36.22305989265442


In [4]:
# Display the in-distribution class subset configured above.
ID_subset

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

# Fit the novelty detector on ID data

In [5]:
# Detector configuration, extended with the ID target index and the compute device.
detector_params = {
    'pca_level': 0.995,
    'score_type': 'pca',
    'n_components': None,
    'n_percent_comp': 0.2,
    'target_ind': target_ind_ID,
    'device': device,
}

novelty_detector = novel.NoveltyDetector().create_detector(
    type='dfm',
    params=detector_params,
)

# Feats_ID is the (X, Y) pair from get_features; X is transposed before fitting.
novelty_detector.fit_total(Feats_ID[0].T, Feats_ID[1])


n_comp var 0.995
end fit 10.899869680404663


# Get score on OD data

In [6]:
# Inspect the fitted PCA models held by the detector (the output below shows
# one PCA(n_components=0.995) entry for each key 0-9).
novelty_detector.pca_mats

{0: PCA(n_components=0.995),
 1: PCA(n_components=0.995),
 2: PCA(n_components=0.995),
 3: PCA(n_components=0.995),
 4: PCA(n_components=0.995),
 5: PCA(n_components=0.995),
 6: PCA(n_components=0.995),
 7: PCA(n_components=0.995),
 8: PCA(n_components=0.995),
 9: PCA(n_components=0.995)}

In [7]:
# Build shuffled loaders over the OD set and the held-out ID test set.
OD_loader = torch.utils.data.DataLoader(
    dset_OD,
    batch_size=100,
    shuffle=True,
    num_workers=4,
)
ID_loaders_test = [
    torch.utils.data.DataLoader(
        dset_ID_test,
        batch_size=100,
        shuffle=True,
        num_workers=4,
    ),
]

print('Get scores for novelty detector')
noveltyResults = novelval.save_novelty_results(1, 'debug', dir_save)
# Scoring settings: which layer to read, which extractor to run, and which
# target index to use for the OD data.
params_score = {
    'layer': 'base.8',
    'feature_extractor': network_inner,
    'base_apply_score': True,
    'target_ind': target_ind_OD,
}
# The OD loader is passed first and the ID test loaders second; presumably these
# are the "new" and "old" data reported in the printed results -- confirm
# against novelval.evaluate_dfm.
novelval.evaluate_dfm(
    1,
    novelty_detector,
    noveltyResults,
    params_score,
    OD_loader,
    ID_loaders_test,
)

Get scores for novelty detector
accuracy new 0.11327275314590067
accuracy old 0.6386
auroc 0.9763680256550995
New task 1, AUROC = 0.976, AUPR = 0.995, AUPR_NORM = 0.995
new_scores [51.12514114 45.98828888 20.08731461 ... 31.18333817 10.36089897
 12.85968399]
old_scores [4.18988514 3.1307559  7.6462245  ... 4.37493992 1.59661889 2.68742704]
DFM Results -  Auroc 0.976, Aupr 0.995, Aupr_norm 0.995
Average Accuracy per class old 0.6386


In [8]:
# Display the AUROC values recorded on the results object (one entry in the
# output below, matching the single evaluate_dfm call above).
noveltyResults.auroc

[0.9763680256550995]