In [1]:
import glob
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np

from COCOWrapper import COCOWrapper


In [2]:

# Dataset location. Set the COCO_ROOT environment variable to point elsewhere
# instead of editing this machine-specific default path.
root = os.environ.get('COCO_ROOT', '/home/gregory/Datasets/COCO')
year = '2017'

# COCO category names whose per-class precision/recall are compared across models.
main_classes = ['snowboard', 'couch', 'tie', 'handbag', 'skis', 'remote', 'toothbrush', 'sports ball', 'knife', 'cell phone', 'fork', 'wine glass', 'skateboard', 'spoon', 'backpack', 'bench', 'frisbee']
# Category used to split the evaluation data into with/without subsets
# (presumably the spurious co-occurring context — confirm against the data prep).
spurious_class = 'person'

# Model result directories (under ./Models) to plot, and the reference model they are diffed against.
comparisons = ['random-tune', 'initial-transfer']
baseline = 'initial-tune'

# Evaluation splits, indexed by the plotting cells:
#   [0] train images containing `spurious_class`
#   [1] train images without `spurious_class`
#   [2] the val split for `spurious_class` (plotted as "Counterfactual Images")
datafiles = [
    '{}/train{}-with-{}-info.p'.format(root, year, spurious_class),
    '{}/train{}-without-{}-info.p'.format(root, year, spurious_class),
    '{}/val{}-{}-info.p'.format(root, year, spurious_class),
]

# Used only for category-name -> category-id lookups in `collect`.
coco = COCOWrapper(root = root, mode = 'val', year = year)


loading annotations into memory...
Done (t=0.31s)
creating index...
index created!


In [3]:
def collect(model, datafile, classes=None, coco_api=None, models_dir='./Models'):
    """Average per-class precision and recall for one model over all of its saved runs.

    Every ``<models_dir>/<model>/*.p`` pickle is loaded once. Each is assumed to
    hold a dict keyed by ``datafile`` whose value is indexable as
    ``(precision, recall)``, each indexable by COCO category id. This layout is
    inferred from the indexing below — TODO confirm against the training script.

    Parameters
    ----------
    model : str
        Name of the result subdirectory under ``models_dir``.
    datafile : str
        Key selecting the evaluation split inside each pickle.
    classes : list of str, optional
        Category names to report, in order. Defaults to the notebook-level
        ``main_classes``.
    coco_api : optional
        Object exposing ``get_cat_ids(name)`` that returns a list of ids; the first
        id is used. Defaults to the notebook-level ``coco``.
    models_dir : str, optional
        Root directory containing one folder per model. Defaults to ``'./Models'``.

    Returns
    -------
    tuple of np.ndarray
        ``(precision, recall)``: per-class means across runs, aligned with ``classes``.

    Raises
    ------
    FileNotFoundError
        If no ``*.p`` files exist for ``model``. Previously this silently produced
        NaNs, which scatter plots quietly drop.
    """
    if classes is None:
        classes = main_classes
    if coco_api is None:
        coco_api = coco

    # Sorted only for a deterministic load order; glob's order is unspecified.
    files = sorted(glob.glob('{}/{}/*.p'.format(models_dir, model)))
    if not files:
        raise FileNotFoundError(
            'No result pickles found for model {!r} under {}'.format(model, models_dir))

    # Load each run once. The original re-read every pickle once per class.
    # NOTE: pickle.load can execute arbitrary code; only use on trusted result files.
    runs = []
    for path in files:
        with open(path, 'rb') as f:
            runs.append(pickle.load(f)[datafile])

    p = []
    r = []
    for main_class in classes:
        index = coco_api.get_cat_ids(main_class)[0]
        p.append(np.mean([run[0][index] for run in runs]))
        r.append(np.mean([run[1][index] for run in runs]))

    return np.array(p), np.array(r)

        
def compare(comparisons, baseline, datafile, title):
    """Scatter per-class (precision, recall) changes of each comparison model vs. a baseline.

    Each point is one class from ``main_classes`` (via ``collect``). Its x/y
    coordinates are the comparison model's mean precision/recall minus the
    baseline's, on ``datafile``. Dashed lines at 0 split the plane into
    improvement/regression quadrants.

    Parameters
    ----------
    comparisons : iterable of str
        Model result directories to compare; each becomes one labelled scatter series.
    baseline : str
        Model result directory used as the reference.
    datafile : str
        Key selecting the evaluation split inside each results pickle.
    title : str
        Suffix appended to the plot title.
    """
    # Own the figure explicitly rather than drawing on pyplot's implicit "current"
    # figure, and always close it. Previously, if `collect` raised, the half-drawn
    # figure stayed open and the next call drew on top of it.
    fig, ax = plt.subplots()
    try:
        ax.axhline(0, color='black', linestyle='--')
        ax.axvline(0, color='black', linestyle='--')

        ax.set_ylabel('Change in Recall')
        ax.set_xlabel('Change in Precision')

        ax.set_title('Comparison to {}: {}'.format(baseline, title))

        p_base, r_base = collect(baseline, datafile)

        plotted_any = False
        for comparison in comparisons:
            p_comp, r_comp = collect(comparison, datafile)
            ax.scatter(p_comp - p_base, r_comp - r_base, label=comparison)
            plotted_any = True

        # legend() warns when no labelled artists exist; skip it for an empty list.
        if plotted_any:
            ax.legend()
        plt.show()
    finally:
        plt.close(fig)

    


In [4]:
# One plot per evaluation split; titles line up with the order of `datafiles`.
split_titles = [
    'Images with {}'.format(spurious_class),
    'Images without {}'.format(spurious_class),
    'Counterfactual Images',
]
for split_idx, split_title in enumerate(split_titles):
    compare(comparisons, baseline, datafiles[split_idx], split_title)


TypeError: compare() takes 2 positional arguments but 4 were given

In [None]:
# Same three splits, comparing random-init vs. initial weights in the transfer setting.
split_titles = [
    'Images with {}'.format(spurious_class),
    'Images without {}'.format(spurious_class),
    'Counterfactual Images',
]
for split_idx, split_title in enumerate(split_titles):
    compare(['random-transfer'], 'initial-transfer', datafiles[split_idx], split_title)
