In [51]:
import sys
sys.path.append('/dccstor/hoo-misha-1/wilds/wilds/examples')
import os

import numpy as np

from ipywidgets import interact, interactive, fixed, interact_manual
import matplotlib.pyplot as plt

import pickle

In [52]:
plt.rcParams['figure.dpi'] = 300
plt.rcParams['savefig.dpi'] = 300

In [53]:
def load_flm():
    test_features = np.load(f'{path_base}/resnet50_test_features.npy')
    test_labels = np.load(f'{path_base}/resnet50_test_labels.npy')
    test_metadata = np.load(f'{path_base}/resnet50_test_metadata.npy')
    return test_features, test_labels, test_metadata

In [54]:
def get_cam_ind(metadata, num_cams=1, cam_id = None):
    unique_counts = np.unique(metadata[:,0],return_counts=True)
    if cam_id is None:
        top_id = unique_counts[0][np.argpartition(unique_counts[1], -num_cams)[-num_cams:]]
    else:
        top_id = cam_id
    print(f'Selecting cameras with ids {top_id}')
    ind = np.zeros(metadata.shape[0]) == 1
    for c_id in top_id:
        ind = np.logical_or(ind,metadata[:,0] == c_id)
    return ind

In [55]:
def cam_flm(num_cams=1, cam_id = None):
    features, labels, metadata = load_flm()
    cam_ind = get_cam_ind(metadata, num_cams, cam_id)
    return features[cam_ind], labels[cam_ind], metadata[cam_ind]

In [56]:
def get_dict_path(root_path):
    cam_dict_path = f'{root_path}_cam_dict.pkl'
    orig_dict_path = f'{root_path}_orig_dict.pkl'
    
    with open(cam_dict_path,'rb') as file:
        cam_dict = pickle.load(file)

    with open(orig_dict_path,'rb') as file:
        orig_dict = pickle.load(file)
    
    return cam_dict, orig_dict

## Uniform Sampling


In [79]:
cam_dict, orig_dict = get_dict_path('/dccstor/hoo-misha-1/wilds/wilds/results/iwildcam/PseudoLabel')
ba_cam_dict, ba_orig_dict = get_dict_path('/dccstor/hoo-misha-1/wilds/wilds/ba')

In [80]:
path_base = '/dccstor/hoo-misha-1/wilds/wilds/features/iwildcam/ERM'
cam_ids = list(cam_dict.keys())

In [81]:
good_inds = []
for i in range(len(cam_ids)):
    predictions = cam_dict[cam_ids[i]]
    if predictions[-1] > 0:
        good_inds.append(i)
def print_green(text, green=True, end='\n'):
    print(f'\x1b[{32 if green else 31}m{text}\x1b[0m', end=end)
def plot(cam_ind):
    cam_ind = good_inds[cam_ind]
    print(f'Selecting camera with id {cam_ids[cam_ind]}')
    predictions = cam_dict[cam_ids[cam_ind]]
    print(f'Original {orig_dict[cam_ids[cam_ind]]}')
    print(f'Max {max(predictions)}')
    metadata = np.load(f'{path_base}/resnet50_test_metadata.npy')
    unique_counts = np.unique(metadata[:,0],return_counts=True)
    ind = np.where(unique_counts[0] == cam_ids[cam_ind])
    print(f'With {unique_counts[1][ind]} data points pre-pruning')
    predictions = np.hstack((orig_dict[cam_ids[cam_ind]] , predictions))
    plt.title('Accuracy vs Shots')
    plt.ylabel("Accuracy")
    plt.xlabel("Number of Shots")
    plt.plot(range(0,len(predictions)), predictions, label='Balanced')
    plt.legend()
interact(plot, cam_ind=(0,len(good_inds)-1));

interactive(children=(IntSlider(value=10, description='cam_ind', max=21), Output()), _dom_classes=('widget-int…

In [60]:
good_inds = []
for i in range(len(cam_ids)):
    predictions = ba_cam_dict[cam_ids[i]]
    if predictions[-1] > 0:
        good_inds.append(i)
def print_green(text, green=True, end='\n'):
    print(f'\x1b[{32 if green else 31}m{text}\x1b[0m', end=end)
def plot(cam_ind):
    cam_ind = good_inds[cam_ind]
    print(f'Selecting camera with id {cam_ids[cam_ind]}')
    predictions = ba_cam_dict[cam_ids[cam_ind]]
    print(f'Original {ba_orig_dict[cam_ids[cam_ind]]}')
    print(f'Max {max(predictions)}')
    metadata = np.load(f'{path_base}/resnet50_test_metadata.npy')
    unique_counts = np.unique(metadata[:,0],return_counts=True)
    ind = np.where(unique_counts[0] == cam_ids[cam_ind])
    print(f'With {unique_counts[1][ind]} data points pre-pruning')
    predictions = np.hstack((ba_orig_dict[cam_ids[cam_ind]] , predictions))
    plt.title('Balanced Accuracy vs Shots')
    plt.xlabel("Balanced Accuracy")
    plt.ylabel("Number of Shots")
    plt.plot(range(0,len(predictions)), predictions, label='Balanced')
    plt.legend()
interact(plot, cam_ind=(0,len(good_inds)-1));

interactive(children=(IntSlider(value=10, description='cam_ind', max=21), Output()), _dom_classes=('widget-int…

## KMeans vs Uniform

In [61]:
cam_dict, orig_dict = get_dict_path('/dccstor/hoo-misha-1/wilds/wilds/results/iwildcam/PseudoLabel')
kmeans_cam_dict, kmeans_orig_dict = get_dict_path('/dccstor/hoo-misha-1/wilds/WOODS/notebooks/data/kmeans_closest_batch_classes')

In [62]:
path_base = '/dccstor/hoo-misha-1/wilds/wilds/features/iwildcam/PseudoLabel'
cam_ids = list(cam_dict.keys() & kmeans_cam_dict.keys())

In [63]:
good_inds = []
for i in range(len(cam_ids)):
    predictions = cam_dict[cam_ids[i]]
    if predictions[-1] > 0:
        good_inds.append(i)
def print_green(text, green=True, end='\n'):
    print(f'\x1b[{32 if green else 31}m{text}\x1b[0m', end=end)
def plot(cam_ind):
    cam_ind = good_inds[cam_ind]
    print(f'Selecting camera with id {cam_ids[cam_ind]}')
    predictions = cam_dict[cam_ids[cam_ind]]
    kmeans_predictions = kmeans_cam_dict[cam_ids[cam_ind]]
    print(f'Original {orig_dict[cam_ids[cam_ind]]} KMeans Original {kmeans_orig_dict[cam_ids[cam_ind]]}')
    print(f'Max {max(predictions)} KMeans Max {max(kmeans_predictions)}')
    metadata = np.load(f'{path_base}/resnet50_test_metadata.npy')
    unique_counts = np.unique(metadata[:,0],return_counts=True)
    ind = np.where(unique_counts[0] == cam_ids[cam_ind])
    print(f'With {unique_counts[1][ind]} data points pre-pruning')
    predictions = np.hstack((orig_dict[cam_ids[cam_ind]] , predictions))
    kmeans_predictions = np.hstack((kmeans_orig_dict[cam_ids[cam_ind]], kmeans_predictions))
    plt.plot(range(0,len(predictions)), predictions, label='Uniform')
    plt.plot(range(0,len(kmeans_predictions)), kmeans_predictions, label='Kmeans')
    plt.legend()
    f,l,m = cam_flm(cam_id=[cam_ids[cam_ind]])
    label_unique_counts = np.unique(l, return_counts=True)
    print(f'Total of {sum(label_unique_counts[1] > 25)} classes over cutoff')
    print('[',end='')
    for y,c in zip(label_unique_counts[0], label_unique_counts[1]):
        print_green(f'{y}:{c}:{c/sum(label_unique_counts[1]):.2f}, ', c > 25, end='')
    print(']')
interact(plot, cam_ind=(0,len(good_inds)-1));

interactive(children=(IntSlider(value=10, description='cam_ind', max=21), Output()), _dom_classes=('widget-int…

## Balanced Accuracy

In [64]:
ba_cam_dict, ba_orig_dict = get_dict_path('/dccstor/hoo-misha-1/wilds/wilds/ba')
#ba_kmeans_cam_dict, ba_kmeans_orig_dict = get_dict_path('/dccstor/hoo-misha-1/wilds/WOODS/notebooks/data/ba_kmeans_closest_classes')
ba_kmeans_cam_dict, ba_kmeans_orig_dict = get_dict_path('/dccstor/hoo-misha-1/wilds/WOODS/notebooks/data/ba_full_kmeans_closest_batch_classes')

In [65]:
good_inds = []
for i in range(len(cam_ids)):
    predictions = ba_kmeans_cam_dict[cam_ids[i]]
    if predictions[-1] > 0:
        good_inds.append(i)

def plot(cam_ind):
    cam_ind = good_inds[cam_ind]
    print(f'Selecting camera with id {cam_ids[cam_ind]}')
    predictions = ba_cam_dict[cam_ids[cam_ind]]
    kmeans_predictions = ba_kmeans_cam_dict[cam_ids[cam_ind]]
    print(f'Original {ba_orig_dict[cam_ids[cam_ind]]} KMeans Original {ba_kmeans_orig_dict[cam_ids[cam_ind]]}')
    print(f'Max {max(predictions)} KMeans Max {max(kmeans_predictions)}')
    metadata = np.load(f'{path_base}/resnet50_test_metadata.npy')
    unique_counts = np.unique(metadata[:,0],return_counts=True)
    ind = np.where(unique_counts[0] == cam_ids[cam_ind])
    print(f'With {unique_counts[1][ind]} data points pre-pruning')
    predictions = np.hstack((ba_orig_dict[cam_ids[cam_ind]] , predictions))
    kmeans_predictions = np.hstack((ba_kmeans_orig_dict[cam_ids[cam_ind]], kmeans_predictions))
    plt.plot(range(0,len(predictions)), predictions, label='Uniform')
    plt.plot(range(0,len(kmeans_predictions)), kmeans_predictions, label='Kmeans')
    plt.legend()
    f,l,m = cam_flm(cam_id=[cam_ids[cam_ind]])
    label_unique_counts = np.unique(l, return_counts=True)
    print(f'Total of {sum(label_unique_counts[1] > 25)} classes over cutoff')
    print('[',end='')
    for y,c in zip(label_unique_counts[0], label_unique_counts[1]):
        print_green(f'{y}:{c}:{c/sum(label_unique_counts[1]):.2f}, ', c > 25, end='')
    print(']')
interact(plot, cam_ind=(0,len(good_inds)-1));

interactive(children=(IntSlider(value=10, description='cam_ind', max=21), Output()), _dom_classes=('widget-int…

## Iterative Kmeans Sampling

In [66]:
argmax_cam_dict, argmax_orig_dict = get_dict_path('/dccstor/hoo-misha-1/wilds/WOODS/notebooks/kmeans_argmax_n_classes')
ba_argmax_cam_dict, ba_argmax_orig_dict = get_dict_path('/dccstor/hoo-misha-1/wilds/WOODS/notebooks/ba_kmeans_argmax_n_classes')

In [67]:
good_inds = []
for i in range(len(cam_ids)):
    if len(argmax_cam_dict[cam_ids[i]]) == 0:
        continue
    predictions = argmax_cam_dict[cam_ids[i]]
    if predictions[-1] > 0:
        good_inds.append(i)

def print_green(text, green=True, end='\n'):
    print(f'\x1b[{32 if green else 31}m{text}\x1b[0m', end=end)
def plot(cam_ind):
    cam_ind = good_inds[cam_ind]
    print(f'Selecting camera with id {cam_ids[cam_ind]}')
    predictions = argmax_cam_dict[cam_ids[cam_ind]]
    print(f'Original {argmax_orig_dict[cam_ids[cam_ind]]}')
    print(f'Max {max(predictions)}')
    metadata = np.load(f'{path_base}/resnet50_test_metadata.npy')
    unique_counts = np.unique(metadata[:,0],return_counts=True)
    ind = np.where(unique_counts[0] == cam_ids[cam_ind])
    print(f'With {unique_counts[1][ind]} data points pre-pruning')
    predictions = np.hstack((argmax_orig_dict[cam_ids[cam_ind]] , predictions))
    plt.plot(range(0,len(predictions)), predictions, label='Iterative')
    plt.legend()
interact(plot, cam_ind=(0,len(good_inds)-1));

interactive(children=(IntSlider(value=5, description='cam_ind', max=10), Output()), _dom_classes=('widget-inte…

In [68]:
good_inds = []
for i in range(len(cam_ids)):
    if len(ba_argmax_cam_dict[cam_ids[i]]) == 0:
        continue
    predictions = ba_argmax_cam_dict[cam_ids[i]]
    if predictions[-1] > 0:
        good_inds.append(i)

def print_green(text, green=True, end='\n'):
    print(f'\x1b[{32 if green else 31}m{text}\x1b[0m', end=end)
def plot(cam_ind):
    cam_ind = good_inds[cam_ind]
    print(f'Selecting camera with id {cam_ids[cam_ind]}')
    predictions = ba_argmax_cam_dict[cam_ids[cam_ind]]
    print(f'Original {ba_orig_dict[cam_ids[cam_ind]]}')
    print(f'Max {max(predictions)}')
    metadata = np.load(f'{path_base}/resnet50_test_metadata.npy')
    unique_counts = np.unique(metadata[:,0],return_counts=True)
    ind = np.where(unique_counts[0] == cam_ids[cam_ind])
    print(f'With {unique_counts[1][ind]} data points pre-pruning')
    predictions = np.hstack((ba_orig_dict[cam_ids[cam_ind]] , predictions))
    # predictions = np.hstack((ba_argmax_orig_dict[cam_ids[cam_ind]] , predictions))
    plt.plot(range(0,len(predictions)), predictions, label='Uniform')
    plt.legend()
interact(plot, cam_ind=(0,len(good_inds)-1));

interactive(children=(IntSlider(value=5, description='cam_ind', max=10), Output()), _dom_classes=('widget-inte…

## Overall Comparison

In [69]:
cam_ids = list(cam_dict.keys() & kmeans_cam_dict.keys() & argmax_cam_dict.keys())

In [70]:
good_inds = []
for i in range(len(cam_ids)):
    predictions = argmax_cam_dict[cam_ids[i]]
    if len(predictions) == 0:
        continue
    if predictions[-1] > 0:
        good_inds.append(i)
def print_green(text, green=True, end='\n'):
    print(f'\x1b[{32 if green else 31}m{text}\x1b[0m', end=end)

def plot(cam_ind):
    cam_ind = good_inds[cam_ind]
    print(f'Selecting camera with id {cam_ids[cam_ind]}')
    predictions = cam_dict[cam_ids[cam_ind]]
    kmeans_predictions = kmeans_cam_dict[cam_ids[cam_ind]]
    argmax_predictions = argmax_cam_dict[cam_ids[cam_ind]]
    print(f'Original {orig_dict[cam_ids[cam_ind]]} KMeans Original {kmeans_orig_dict[cam_ids[cam_ind]]} Iterative Original {argmax_orig_dict[cam_ids[cam_ind]]}')
    print(f'Max {max(predictions)} KMeans Max {max(kmeans_predictions)} Iterative Max {max(argmax_predictions)}')
    metadata = np.load(f'{path_base}/resnet50_test_metadata.npy')
    unique_counts = np.unique(metadata[:,0],return_counts=True)
    ind = np.where(unique_counts[0] == cam_ids[cam_ind])
    print(f'With {unique_counts[1][ind]} data points pre-pruning')
    predictions = np.hstack((orig_dict[cam_ids[cam_ind]] , predictions))
    kmeans_predictions = np.hstack((kmeans_orig_dict[cam_ids[cam_ind]], kmeans_predictions))
    argmax_predictions = np.hstack((argmax_orig_dict[cam_ids[cam_ind]], argmax_predictions))
    plt.title('Accuracy vs Shots')
    plt.ylabel("Accuracy")
    plt.xlabel("Number of Shots")
    plt.plot(range(0,len(predictions)), predictions, label='Balanced')
    plt.plot(range(0,len(kmeans_predictions)), kmeans_predictions, label='Kmeans')
    plt.plot(range(0,len(argmax_predictions)), argmax_predictions, label='Iterative')
    plt.legend()
    plt.savefig(f'images/accuracy_{cam_ids[cam_ind]}.png')
    f,l,m = cam_flm(cam_id=[cam_ids[cam_ind]])
    label_unique_counts = np.unique(l, return_counts=True)
    print(f'Total of {sum(label_unique_counts[1] > 25)} classes over cutoff')
    print('[',end='')
    for y,c in zip(label_unique_counts[0], label_unique_counts[1]):
        print_green(f'{y}:{c}:{c/sum(label_unique_counts[1]):.2f}, ', c > 25, end='')
    print(']')
interact(plot, cam_ind=(0,len(good_inds)-1));

interactive(children=(IntSlider(value=5, description='cam_ind', max=10), Output()), _dom_classes=('widget-inte…

In [72]:
good_inds = []
for i in range(len(cam_ids)):
    predictions = ba_argmax_cam_dict[cam_ids[i]]
    if len(predictions) == 0:
        continue
    if predictions[-1] > 0:
        good_inds.append(i)
def print_green(text, green=True, end='\n'):
    print(f'\x1b[{32 if green else 31}m{text}\x1b[0m', end=end)

def plot(cam_ind):
    cam_ind = good_inds[cam_ind]
    print(f'Selecting camera with id {cam_ids[cam_ind]}')
    predictions = ba_cam_dict[cam_ids[cam_ind]]
    kmeans_predictions = ba_kmeans_cam_dict[cam_ids[cam_ind]]
    argmax_predictions = ba_argmax_cam_dict[cam_ids[cam_ind]]
    print(f'Original {ba_orig_dict[cam_ids[cam_ind]]}')
    print(f'Max {max(predictions)} KMeans Max {max(kmeans_predictions)} Iterative Max {max(argmax_predictions)}')
    metadata = np.load(f'{path_base}/resnet50_test_metadata.npy')
    unique_counts = np.unique(metadata[:,0],return_counts=True)
    ind = np.where(unique_counts[0] == cam_ids[cam_ind])
    print(f'With {unique_counts[1][ind]} data points pre-pruning')
    predictions = np.hstack((ba_orig_dict[cam_ids[cam_ind]] , predictions))
    kmeans_predictions = np.hstack((ba_orig_dict[cam_ids[cam_ind]], kmeans_predictions))
    argmax_predictions = np.hstack((ba_orig_dict[cam_ids[cam_ind]], argmax_predictions))
    plt.title('Balanced Accuracy vs Shots')
    plt.ylabel("Balanced Accuracy")
    plt.xlabel("Number of Shots")
    plt.plot(range(0,len(predictions)), predictions, label='Balanced')
    plt.plot(range(0,len(kmeans_predictions)), kmeans_predictions, label='Kmeans')
    plt.plot(range(0,len(argmax_predictions)), argmax_predictions, label='Iterative')
    plt.legend()
    plt.savefig(f'images/balanced_accuracy_{cam_ids[cam_ind]}.png')
    f,l,m = cam_flm(cam_id=[cam_ids[cam_ind]])
    label_unique_counts = np.unique(l, return_counts=True)
    print(f'Total of {sum(label_unique_counts[1] > 25)} classes over cutoff')
    print('[',end='')
    for y,c in zip(label_unique_counts[0], label_unique_counts[1]):
        print_green(f'{y}:{c}:{c/sum(label_unique_counts[1]):.2f}, ', c > 25, end='')
    print(']')
interact(plot, cam_ind=(0,len(good_inds)-1));

interactive(children=(IntSlider(value=5, description='cam_ind', max=10), Output()), _dom_classes=('widget-inte…

## Comparing Different Models

In [81]:
models = ['ERM', 'PseudoLabel', 'deepCORAL', 'DANN']

In [82]:
model_dicts = {}
ba_model_dicts = {}
for model in models:
    cam_dict, orig_dict = get_dict_path(f'/dccstor/hoo-misha-1/wilds/WOODS/results/{model}')
    ba_cam_dict, ba_orig_dict = get_dict_path(f'/dccstor/hoo-misha-1/wilds/WOODS/results/{model}_ba')
    model_dicts[model] = (cam_dict, orig_dict)
    ba_model_dicts[model] = (ba_cam_dict, ba_orig_dict)


good_inds = []
for i in range(len(cam_ids)):
    predictions = model_dicts['ERM'][0][cam_ids[i]]
    if len(predictions) == 0:
        continue
    if predictions[-1] > 0:
        good_inds.append(i)    
    
def plot_4(cam_ind):
    cam_ind = good_inds[cam_ind]
    for model in model_dicts:
        cam_dict, orig_dict = model_dicts[model]
        ba_cam_dict, ba_orig_dict = ba_model_dicts[model]
        predictions = cam_dict[cam_ids[cam_ind]]
        ba_predictions = ba_cam_dict[cam_ids[cam_ind]]
        #print(f'Original {orig_dict[cam_ids[cam_ind]]} Balanced Original {ba_orig_dict[cam_ids[cam_ind]]}')
        #print(f'Max {max(predictions)} Balanced Max {max(ba_predictions)}')
        metadata = np.load(f'{path_base}/resnet50_test_metadata.npy')
        unique_counts = np.unique(metadata[:,0],return_counts=True)
        ind = np.where(unique_counts[0] == cam_ids[cam_ind])
        #print(f'With {unique_counts[1][ind]} data points pre-pruning')
        predictions = np.hstack((orig_dict[cam_ids[cam_ind]] , predictions))
        ba_predictions = np.hstack((0, ba_predictions))
        plt.subplot(2,1,1)
        plt.title('Original')
        plt.plot(range(0,len(predictions)), predictions, label=model)
        plt.legend()
        plt.subplot(2,1,2)
        plt.title('Balanced')
        plt.plot(range(0,len(ba_predictions)), ba_predictions, label=model)
        plt.legend()
interact(plot_4, cam_ind=(0,len(good_inds)));

interactive(children=(IntSlider(value=5, description='cam_ind', max=11), Output()), _dom_classes=('widget-inte…