In [1]:
import sys
sys.path.append('/dccstor/hoo-misha-1/wilds/WOODS')
import os

import numpy as np

import ipywidgets as widgets
from ipywidgets import interact, interactive, fixed, interact_manual
import matplotlib.pyplot as plt

import pickle

from scripts.utils import *

In [2]:
plt.rcParams['figure.dpi'] = 100
plt.rcParams['savefig.dpi'] = 100

In [29]:
dataset = 'breeds'
if dataset == 'iwildcam':
    model_name = 'resnet50'
elif dataset == 'breeds':
    model_name = 'resnet50'
elif dataset == 'camelyon17':
    model_name = 'densenet121'
elif dataset == 'cifar100':
    model_name = 'resnet50'

In [30]:
path_base = f'/dccstor/hoo-misha-1/wilds/wilds/features/{dataset}'
set_path_base(path_base)
set_model_name(model_name)

In [31]:
sampling_algorithms_dict = {'balanced':'balanced_sample_ind', 'full':'full_kmeans_sample_ind', 'class':'class_kmeans_sample_ind', 'iterative':'iterative_kmeans_sample_ind', 'weighted':'weighted_iterative_kmeans_sample_ind', 'typiclust':'typiclust_sampled_ind'}

In [32]:
def get_dict_path(root_path):
    cam_dict_path = f'{root_path}_cam_dict.pkl'
    orig_dict_path = f'{root_path}_orig_dict.pkl'
    
    with open(cam_dict_path,'rb') as file:
        cam_dict = pickle.load(file)

    with open(orig_dict_path,'rb') as file:
        orig_dict = pickle.load(file)
    
    return cam_dict, orig_dict

In [33]:
os.listdir(f'/dccstor/hoo-misha-1/wilds/wilds/pretrained/{dataset}')

['breeds_deepCORAL.pth',
 '.ipynb_checkpoints',
 'breeds_wassersteindeepCORAL.pth',
 'breeds_DANN.pth',
 'breeds_ERM.pth']

## Balanced Sampling


In [59]:
cam_dicts = {}
orig_dicts = {}
ba_cam_dicts = {}
ba_orig_dicts = {}
models = []
sampling_algorithms = set()

for model in os.listdir(f'/dccstor/hoo-misha-1/wilds/wilds/pretrained/{dataset}'):
    if dataset == 'iwildcam':
        model = model[9:-4]
    elif dataset == 'breeds':
        model = model[7:-4]
    elif dataset == 'camelyon17':
        model = model[11:-4]
    elif dataset == 'cifar100':
        model = model[9:-4]
    models.append(model)
    cam_dicts[model] = {}
    orig_dicts[model] = {}
    ba_cam_dicts[model] = {}
    ba_orig_dicts[model] = {}
    #sampling_algorithms_dict = {'balanced':'balanced', 'full':'full', 'class':'class', 'iterative_pc:False_typ:False_w:False_d:False_phi:euclidean_lambda:1_rng:0':'no pc', 'iterative_pc:True_typ:True_w:False_d:False_phi:euclidean_lambda:1_rng:0':'typiclust', 'iterative_pc:True_typ:False_w:False_d:False_phi:euclidean_lambda:1_rng:0':'iterative', 'iterative_pc:True_typ:False_w:True_d:False_phi:euclidean_lambda:1_rng:0':'weighted', 'iterative_pc:True_typ:False_w:True_d:True_phi:euclidean_lambda:1_rng:0':'dense'}
    sampling_algorithms_dict = {'balanced':'balanced','iterative_False_True':'iterative_False_True','typiclust':'typiclust', 'full':'full', 'class':'class','iterative':'iterative', 'weighted':'weighted', 'dense':'dense'}
    for sampling_algorithm in sampling_algorithms_dict.keys():
        try:
            cam_dict, orig_dict = get_dict_path(f'/dccstor/hoo-misha-1/wilds/WOODS/results/{dataset}/{model}/{model}_{sampling_algorithm}')
            cam_dicts[model][sampling_algorithms_dict[sampling_algorithm]] = cam_dict
            orig_dicts[model][sampling_algorithms_dict[sampling_algorithm]] = orig_dict
            ba_cam_dict, ba_orig_dict = get_dict_path(f'/dccstor/hoo-misha-1/wilds/WOODS/results/{dataset}/{model}/{model}_{sampling_algorithm}_ba')
            ba_cam_dicts[model][sampling_algorithms_dict[sampling_algorithm]] = ba_cam_dict
            ba_orig_dicts[model][sampling_algorithms_dict[sampling_algorithm]] = ba_orig_dict
        except:
            print(f'No {model} model for {sampling_algorithm}')
            continue
        sampling_algorithms.add(sampling_algorithms_dict[sampling_algorithm])

No checkpo model for balanced
No checkpo model for iterative_False_True
No checkpo model for typiclust
No checkpo model for full
No checkpo model for class
No checkpo model for iterative
No checkpo model for weighted
No checkpo model for dense
No wassersteindeepCORAL model for iterative_False_True
No DANN model for iterative_False_True
No ERM model for balanced
No ERM model for iterative_False_True
No ERM model for typiclust
No ERM model for full
No ERM model for class
No ERM model for iterative
No ERM model for weighted
No ERM model for dense


In [60]:
# models.remove('IRM')

In [61]:
def print_green(text, green=True, end='\n'):
    print(f'\x1b[{32 if green else 31}m{text}\x1b[0m', end=end)

In [62]:
cam_dicts[models[0]].keys()

dict_keys(['balanced', 'iterative_False_True', 'typiclust', 'full', 'class', 'iterative', 'weighted', 'dense'])

In [63]:
good_domains = cam_dicts[models[0]]['balanced'].keys()
# good_domains = None
# all_domains = set()
# for model, sampling_dict in cam_dicts.items():
#     for sampling_algorithm, domain_dict in sampling_dict.items():
#             if sampling_algorithm =='weighted':
#                 continue
#             good_domain_subset = set()
#             print(model, sampling_algorithm, len(domain_dict.keys()))
#             for domain, predictions in domain_dict.items():
#                 all_domains.add(domain)
#                 #if min(predictions) > 0:
#                     #print(model, min(predictions), domain)
#                 if len(predictions) != 0:
#                     good_domain_subset.add(domain)
#             if good_domains is None:
#                 good_domains = good_domain_subset
#             elif len(good_domain_subset) == 0:
#                 continue
#             else:
#                 good_domains &= good_domain_subset
            
good_domains = sorted(list(good_domains))
#all_domains = sorted(list(all_domains))

In [64]:
print(good_domains)

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]


In [65]:

model_widget = widgets.SelectMultiple(
    options=models,
    value=['ERM'],
    # rows=10,
    description='Model:',
    disabled=False
)

domain_widget = widgets.Dropdown(
    options=good_domains,
    value=good_domains[0],
    description='Domains:',
    disabled=False,
)

sampling_algorithms = list(sampling_algorithms)
sampling_algorithms.sort()
sampling_widget = widgets.SelectMultiple(
    options=sampling_algorithms, 
    #rows=10,
    description='Sampling Algorithms',
    disabled=False
)

def f(model_options,domain_option, sampling_option):
    for model_option in model_options:
        metadata = np.load(f'{path_base}/{model_option}/{model_name}_test_metadata.npy')
        for sampling_algorithm in sampling_option:
            predictions = cam_dicts[model_option][sampling_algorithm][domain_option]
            unique_counts = np.unique(metadata[:,0],return_counts=True)
            ind = np.where(unique_counts[0] == domain_option)
            predictions = np.hstack((orig_dicts[model_option][sampling_algorithm][domain_option] , predictions))
            if len(predictions) > 201:
                predictions = predictions[:201]
            plt.title('Accuracy vs Shots')
            plt.ylabel("Accuracy")
            plt.xlabel("Number of Shots")
            plt.title('Original')
            plt.plot(range(0,len(predictions)), predictions, label=f'{model_option} {sampling_algorithm}')
            plt.legend()
    
out = widgets.interactive_output(f, {'model_options':model_widget, 'domain_option':domain_widget, 'sampling_option': sampling_widget})
widgets.HBox([widgets.VBox([model_widget,domain_widget, sampling_widget]), out])

HBox(children=(VBox(children=(SelectMultiple(description='Model:', index=(4,), options=('deepCORAL', 'checkpo'…

iterative: no typicality initial centers, with preclustering, no typicality for future selections
weighted: no typicality initial centers, with preclustering, no typicality for future selections
iterative_False_True: typicality initial centers, no preclustering, no typicality for future selections
dense: no typicality initial centers, with preclustering, typicality for future selections

In [None]:
[24,49, 58, -59, 73, -95, 101, 120, 125]

In [None]:
model_widget_err = widgets.SelectMultiple(
    options=models,
    value=['ERM'],
    # rows=10,
    description='Model:',
    disabled=False
)

domain_widget_err = widgets.Dropdown(
    options=all_domains,
    value=all_domains[0],
    description='Domains:',
    disabled=False,
)

sampling_algorithms = list(sampling_algorithms)
sampling_widget_err = widgets.SelectMultiple(
    options=sampling_algorithms, 
    #rows=10,
    description='Sampling Algorithms',
    disabled=False
)

def g(model_options, sampling_options):
    for model_option in model_options:
        valid_predictions = []
        try:
            for sampling_algorithm in sampling_options:
                #print(sampling_algorithm)
                domain_shot_predictions = [None]*24
                for i in range(24):
                    domain_shot_predictions[i] = []
                for domain_option in cam_dicts[model_option][sampling_algorithm]:
                    #print(domain_option)
                    if len(cam_dicts[model_option][sampling_algorithm][domain_option]) == 24:
                        valid_predictions.append(cam_dicts[model_option][sampling_algorithm][domain_option])
                        for num_shot in range(24):
                            prediction = cam_dicts[model_option][sampling_algorithm][domain_option][num_shot]
                            
                            if prediction != -1:
                                domain_shot_predictions[num_shot].append(prediction)
                means = []
                stds = []
                for num_shot in range(24):   
                    num_shot_predictions = np.vstack(domain_shot_predictions[num_shot])
                    num_shot_means = num_shot_predictions.mean(axis=0)[0]
                    num_shot_stds = num_shot_predictions.std(axis=0, ddof=1)[0]
                    means.append(num_shot_means)
                    stds.append(num_shot_stds)
                means = np.array(means)
                stds = np.array(stds)
                #print(means)
                plt.plot(range(0,len(means)), means, label=f'{model_option} {sampling_algorithm}')
                plt.fill_between(range(0,len(means)), means - stds, means + stds, alpha=0.1)
        except Exception as e:
            print(e)
    if len(model_options) != 0:
        plt.legend()

out_err = widgets.interactive_output(g, {'model_options':model_widget_err, 'sampling_options':sampling_widget_err})
widgets.HBox([widgets.VBox([model_widget_err, sampling_widget_err]), out_err])

In [None]:



def plot(cam_ind):
    cam_ind = good_inds[cam_ind]
    print(f'Selecting camera with id {cam_ids[cam_ind]}')
    predictions = cam_dict[cam_ids[cam_ind]]
    print(f'Original {orig_dict[cam_ids[cam_ind]]}')
    print(f'Max {max(predictions)}')
    metadata = np.load(f'{path_base}/resnet50_test_metadata.npy')
    unique_counts = np.unique(metadata[:,0],return_counts=True)
    ind = np.where(unique_counts[0] == cam_ids[cam_ind])
    print(f'With {unique_counts[1][ind]} data points pre-pruning')
    predictions = np.hstack((orig_dict[cam_ids[cam_ind]] , predictions))
    plt.title('Accuracy vs Shots')
    plt.ylabel("Accuracy")
    plt.xlabel("Number of Shots")
    plt.plot(range(0,len(predictions)), predictions, label='Balanced')
    plt.legend()
widget = interact(plot, cam_ind=(0,len(good_inds)-1);

In [None]:
good_inds = []
for i in range(len(cam_ids)):
    predictions = ba_cam_dict[cam_ids[i]]
    if predictions[-1] > 0:
        good_inds.append(i)
def print_green(text, green=True, end='\n'):
    print(f'\x1b[{32 if green else 31}m{text}\x1b[0m', end=end)
def plot(cam_ind):
    cam_ind = good_inds[cam_ind]
    print(f'Selecting camera with id {cam_ids[cam_ind]}')
    predictions = ba_cam_dict[cam_ids[cam_ind]]
    print(f'Original {ba_orig_dict[cam_ids[cam_ind]]}')
    print(f'Max {max(predictions)}')
    metadata = np.load(f'{path_base}/resnet50_test_metadata.npy')
    unique_counts = np.unique(metadata[:,0],return_counts=True)
    ind = np.where(unique_counts[0] == cam_ids[cam_ind])
    print(f'With {unique_counts[1][ind]} data points pre-pruning')
    predictions = np.hstack((ba_orig_dict[cam_ids[cam_ind]] , predictions))
    plt.title('Balanced Accuracy vs Shots')
    plt.xlabel("Balanced Accuracy")
    plt.ylabel("Number of Shots")
    plt.plot(range(0,len(predictions)), predictions, label='Balanced')
    plt.legend()
interact(plot, cam_ind=(0,len(good_inds)-1));

## KMeans vs Uniform

In [None]:
cam_dict, orig_dict = get_dict_path('/dccstor/hoo-misha-1/wilds/wilds/results/iwildcam/PseudoLabel')
kmeans_cam_dict, kmeans_orig_dict = get_dict_path('/dccstor/hoo-misha-1/wilds/WOODS/notebooks/data/kmeans_closest_batch_classes')

In [None]:
path_base = '/dccstor/hoo-misha-1/wilds/wilds/features/iwildcam/PseudoLabel'
cam_ids = list(cam_dict.keys() & kmeans_cam_dict.keys())

In [None]:
good_inds = []
for i in range(len(cam_ids)):
    predictions = cam_dict[cam_ids[i]]
    if predictions[-1] > 0:
        good_inds.append(i)
def print_green(text, green=True, end='\n'):
    print(f'\x1b[{32 if green else 31}m{text}\x1b[0m', end=end)
def plot(cam_ind):
    cam_ind = good_inds[cam_ind]
    print(f'Selecting camera with id {cam_ids[cam_ind]}')
    predictions = cam_dict[cam_ids[cam_ind]]
    kmeans_predictions = kmeans_cam_dict[cam_ids[cam_ind]]
    print(f'Original {orig_dict[cam_ids[cam_ind]]} KMeans Original {kmeans_orig_dict[cam_ids[cam_ind]]}')
    print(f'Max {max(predictions)} KMeans Max {max(kmeans_predictions)}')
    metadata = np.load(f'{path_base}/resnet50_test_metadata.npy')
    unique_counts = np.unique(metadata[:,0],return_counts=True)
    ind = np.where(unique_counts[0] == cam_ids[cam_ind])
    print(f'With {unique_counts[1][ind]} data points pre-pruning')
    predictions = np.hstack((orig_dict[cam_ids[cam_ind]] , predictions))
    kmeans_predictions = np.hstack((kmeans_orig_dict[cam_ids[cam_ind]], kmeans_predictions))
    plt.plot(range(0,len(predictions)), predictions, label='Uniform')
    plt.plot(range(0,len(kmeans_predictions)), kmeans_predictions, label='Kmeans')
    plt.legend()
    f,l,m = cam_flm(cam_id=[cam_ids[cam_ind]])
    label_unique_counts = np.unique(l, return_counts=True)
    print(f'Total of {sum(label_unique_counts[1] > 25)} classes over cutoff')
    print('[',end='')
    for y,c in zip(label_unique_counts[0], label_unique_counts[1]):
        print_green(f'{y}:{c}:{c/sum(label_unique_counts[1]):.2f}, ', c > 25, end='')
    print(']')
interact(plot, cam_ind=(0,len(good_inds)-1));

## Balanced Accuracy

In [None]:
ba_cam_dict, ba_orig_dict = get_dict_path('/dccstor/hoo-misha-1/wilds/wilds/ba')
#ba_kmeans_cam_dict, ba_kmeans_orig_dict = get_dict_path('/dccstor/hoo-misha-1/wilds/WOODS/notebooks/data/ba_kmeans_closest_classes')
ba_kmeans_cam_dict, ba_kmeans_orig_dict = get_dict_path('/dccstor/hoo-misha-1/wilds/WOODS/notebooks/data/ba_full_kmeans_closest_batch_classes')

In [None]:
good_inds = []
for i in range(len(cam_ids)):
    predictions = ba_kmeans_cam_dict[cam_ids[i]]
    if predictions[-1] > 0:
        good_inds.append(i)

def plot(cam_ind):
    cam_ind = good_inds[cam_ind]
    print(f'Selecting camera with id {cam_ids[cam_ind]}')
    predictions = ba_cam_dict[cam_ids[cam_ind]]
    kmeans_predictions = ba_kmeans_cam_dict[cam_ids[cam_ind]]
    print(f'Original {ba_orig_dict[cam_ids[cam_ind]]} KMeans Original {ba_kmeans_orig_dict[cam_ids[cam_ind]]}')
    print(f'Max {max(predictions)} KMeans Max {max(kmeans_predictions)}')
    metadata = np.load(f'{path_base}/resnet50_test_metadata.npy')
    unique_counts = np.unique(metadata[:,0],return_counts=True)
    ind = np.where(unique_counts[0] == cam_ids[cam_ind])
    print(f'With {unique_counts[1][ind]} data points pre-pruning')
    predictions = np.hstack((ba_orig_dict[cam_ids[cam_ind]] , predictions))
    kmeans_predictions = np.hstack((ba_kmeans_orig_dict[cam_ids[cam_ind]], kmeans_predictions))
    plt.plot(range(0,len(predictions)), predictions, label='Uniform')
    plt.plot(range(0,len(kmeans_predictions)), kmeans_predictions, label='Kmeans')
    plt.legend()
    f,l,m = cam_flm(cam_id=[cam_ids[cam_ind]])
    label_unique_counts = np.unique(l, return_counts=True)
    print(f'Total of {sum(label_unique_counts[1] > 25)} classes over cutoff')
    print('[',end='')
    for y,c in zip(label_unique_counts[0], label_unique_counts[1]):
        print_green(f'{y}:{c}:{c/sum(label_unique_counts[1]):.2f}, ', c > 25, end='')
    print(']')
interact(plot, cam_ind=(0,len(good_inds)-1));

## Iterative Kmeans Sampling

In [None]:
argmax_cam_dict, argmax_orig_dict = get_dict_path('/dccstor/hoo-misha-1/wilds/WOODS/notebooks/kmeans_argmax_n_classes')
ba_argmax_cam_dict, ba_argmax_orig_dict = get_dict_path('/dccstor/hoo-misha-1/wilds/WOODS/notebooks/ba_kmeans_argmax_n_classes')

In [None]:
good_inds = []
for i in range(len(cam_ids)):
    if len(argmax_cam_dict[cam_ids[i]]) == 0:
        continue
    predictions = argmax_cam_dict[cam_ids[i]]
    if predictions[-1] > 0:
        good_inds.append(i)

def print_green(text, green=True, end='\n'):
    print(f'\x1b[{32 if green else 31}m{text}\x1b[0m', end=end)
def plot(cam_ind):
    cam_ind = good_inds[cam_ind]
    print(f'Selecting camera with id {cam_ids[cam_ind]}')
    predictions = argmax_cam_dict[cam_ids[cam_ind]]
    print(f'Original {argmax_orig_dict[cam_ids[cam_ind]]}')
    print(f'Max {max(predictions)}')
    metadata = np.load(f'{path_base}/resnet50_test_metadata.npy')
    unique_counts = np.unique(metadata[:,0],return_counts=True)
    ind = np.where(unique_counts[0] == cam_ids[cam_ind])
    print(f'With {unique_counts[1][ind]} data points pre-pruning')
    predictions = np.hstack((argmax_orig_dict[cam_ids[cam_ind]] , predictions))
    plt.plot(range(0,len(predictions)), predictions, label='Iterative')
    plt.legend()
interact(plot, cam_ind=(0,len(good_inds)-1));

In [None]:
good_inds = []
for i in range(len(cam_ids)):
    if len(ba_argmax_cam_dict[cam_ids[i]]) == 0:
        continue
    predictions = ba_argmax_cam_dict[cam_ids[i]]
    if predictions[-1] > 0:
        good_inds.append(i)

def print_green(text, green=True, end='\n'):
    print(f'\x1b[{32 if green else 31}m{text}\x1b[0m', end=end)
def plot(cam_ind):
    cam_ind = good_inds[cam_ind]
    print(f'Selecting camera with id {cam_ids[cam_ind]}')
    predictions = ba_argmax_cam_dict[cam_ids[cam_ind]]
    print(f'Original {ba_orig_dict[cam_ids[cam_ind]]}')
    print(f'Max {max(predictions)}')
    metadata = np.load(f'{path_base}/resnet50_test_metadata.npy')
    unique_counts = np.unique(metadata[:,0],return_counts=True)
    ind = np.where(unique_counts[0] == cam_ids[cam_ind])
    print(f'With {unique_counts[1][ind]} data points pre-pruning')
    predictions = np.hstack((ba_orig_dict[cam_ids[cam_ind]] , predictions))
    # predictions = np.hstack((ba_argmax_orig_dict[cam_ids[cam_ind]] , predictions))
    plt.plot(range(0,len(predictions)), predictions, label='Uniform')
    plt.legend()
interact(plot, cam_ind=(0,len(good_inds)-1));

## Overall Comparison

In [None]:
cam_ids = list(cam_dict.keys() & kmeans_cam_dict.keys() & argmax_cam_dict.keys())

In [None]:
good_inds = []
for i in range(len(cam_ids)):
    predictions = argmax_cam_dict[cam_ids[i]]
    if len(predictions) == 0:
        continue
    if predictions[-1] > 0:
        good_inds.append(i)
def print_green(text, green=True, end='\n'):
    print(f'\x1b[{32 if green else 31}m{text}\x1b[0m', end=end)

def plot(cam_ind):
    cam_ind = good_inds[cam_ind]
    print(f'Selecting camera with id {cam_ids[cam_ind]}')
    predictions = cam_dict[cam_ids[cam_ind]]
    kmeans_predictions = kmeans_cam_dict[cam_ids[cam_ind]]
    argmax_predictions = argmax_cam_dict[cam_ids[cam_ind]]
    print(f'Original {orig_dict[cam_ids[cam_ind]]} KMeans Original {kmeans_orig_dict[cam_ids[cam_ind]]} Iterative Original {argmax_orig_dict[cam_ids[cam_ind]]}')
    print(f'Max {max(predictions)} KMeans Max {max(kmeans_predictions)} Iterative Max {max(argmax_predictions)}')
    metadata = np.load(f'{path_base}/resnet50_test_metadata.npy')
    unique_counts = np.unique(metadata[:,0],return_counts=True)
    ind = np.where(unique_counts[0] == cam_ids[cam_ind])
    print(f'With {unique_counts[1][ind]} data points pre-pruning')
    predictions = np.hstack((orig_dict[cam_ids[cam_ind]] , predictions))
    kmeans_predictions = np.hstack((kmeans_orig_dict[cam_ids[cam_ind]], kmeans_predictions))
    argmax_predictions = np.hstack((argmax_orig_dict[cam_ids[cam_ind]], argmax_predictions))
    plt.title('Accuracy vs Shots')
    plt.ylabel("Accuracy")
    plt.xlabel("Number of Shots")
    plt.plot(range(0,len(predictions)), predictions, label='Balanced')
    plt.plot(range(0,len(kmeans_predictions)), kmeans_predictions, label='Kmeans')
    plt.plot(range(0,len(argmax_predictions)), argmax_predictions, label='Iterative')
    plt.legend()
    plt.savefig(f'images/accuracy_{cam_ids[cam_ind]}.png')
    f,l,m = cam_flm(cam_id=[cam_ids[cam_ind]])
    label_unique_counts = np.unique(l, return_counts=True)
    print(f'Total of {sum(label_unique_counts[1] > 25)} classes over cutoff')
    print('[',end='')
    for y,c in zip(label_unique_counts[0], label_unique_counts[1]):
        print_green(f'{y}:{c}:{c/sum(label_unique_counts[1]):.2f}, ', c > 25, end='')
    print(']')
interact(plot, cam_ind=(0,len(good_inds)-1));

In [None]:
good_inds = []
for i in range(len(cam_ids)):
    predictions = ba_argmax_cam_dict[cam_ids[i]]
    if len(predictions) == 0:
        continue
    if predictions[-1] > 0:
        good_inds.append(i)
def print_green(text, green=True, end='\n'):
    print(f'\x1b[{32 if green else 31}m{text}\x1b[0m', end=end)

def plot(cam_ind):
    cam_ind = good_inds[cam_ind]
    print(f'Selecting camera with id {cam_ids[cam_ind]}')
    predictions = ba_cam_dict[cam_ids[cam_ind]]
    kmeans_predictions = ba_kmeans_cam_dict[cam_ids[cam_ind]]
    argmax_predictions = ba_argmax_cam_dict[cam_ids[cam_ind]]
    print(f'Original {ba_orig_dict[cam_ids[cam_ind]]}')
    print(f'Max {max(predictions)} KMeans Max {max(kmeans_predictions)} Iterative Max {max(argmax_predictions)}')
    metadata = np.load(f'{path_base}/resnet50_test_metadata.npy')
    unique_counts = np.unique(metadata[:,0],return_counts=True)
    ind = np.where(unique_counts[0] == cam_ids[cam_ind])
    print(f'With {unique_counts[1][ind]} data points pre-pruning')
    predictions = np.hstack((ba_orig_dict[cam_ids[cam_ind]] , predictions))
    kmeans_predictions = np.hstack((ba_orig_dict[cam_ids[cam_ind]], kmeans_predictions))
    argmax_predictions = np.hstack((ba_orig_dict[cam_ids[cam_ind]], argmax_predictions))
    plt.title('Balanced Accuracy vs Shots')
    plt.ylabel("Balanced Accuracy")
    plt.xlabel("Number of Shots")
    plt.plot(range(0,len(predictions)), predictions, label='Balanced')
    plt.plot(range(0,len(kmeans_predictions)), kmeans_predictions, label='Kmeans')
    plt.plot(range(0,len(argmax_predictions)), argmax_predictions, label='Iterative')
    plt.legend()
    plt.savefig(f'images/balanced_accuracy_{cam_ids[cam_ind]}.png')
    f,l,m = cam_flm(cam_id=[cam_ids[cam_ind]])
    label_unique_counts = np.unique(l, return_counts=True)
    print(f'Total of {sum(label_unique_counts[1] > 25)} classes over cutoff')
    print('[',end='')
    for y,c in zip(label_unique_counts[0], label_unique_counts[1]):
        print_green(f'{y}:{c}:{c/sum(label_unique_counts[1]):.2f}, ', c > 25, end='')
    print(']')
interact(plot, cam_ind=(0,len(good_inds)-1));

## Comparing Different Models

In [None]:
models = ['ERM', 'PseudoLabel', 'deepCORAL', 'DANN']

In [None]:
model_dicts = {}
ba_model_dicts = {}
for model in models:
    cam_dict, orig_dict = get_dict_path(f'/dccstor/hoo-misha-1/wilds/WOODS/results/iwildcam/{model}/{model}_iterative')
    ba_cam_dict, ba_orig_dict = get_dict_path(f'/dccstor/hoo-misha-1/wilds/WOODS/results/iwildcam/{model}/{model}_balanced')
    model_dicts[model] = (cam_dict, orig_dict)
    ba_model_dicts[model] = (ba_cam_dict, ba_orig_dict)


good_inds = []
for i in range(len(cam_ids)):
    predictions = model_dicts['ERM'][0][cam_ids[i]]
    if len(predictions) == 0:
        continue
    if predictions[-1] > 0:
        good_inds.append(i)    
    
def plot_4(cam_ind):
    cam_ind = good_inds[cam_ind]
    for model in model_dicts:
        cam_dict, orig_dict = model_dicts[model]
        ba_cam_dict, ba_orig_dict = ba_model_dicts[model]
        predictions = cam_dict[cam_ids[cam_ind]]
        ba_predictions = ba_cam_dict[cam_ids[cam_ind]]
        #print(f'Original {orig_dict[cam_ids[cam_ind]]} Balanced Original {ba_orig_dict[cam_ids[cam_ind]]}')
        #print(f'Max {max(predictions)} Balanced Max {max(ba_predictions)}')
        metadata = np.load(f'{path_base}/resnet50_test_metadata.npy')
        unique_counts = np.unique(metadata[:,0],return_counts=True)
        ind = np.where(unique_counts[0] == cam_ids[cam_ind])
        #print(f'With {unique_counts[1][ind]} data points pre-pruning')
        predictions = np.hstack((orig_dict[cam_ids[cam_ind]] , predictions))
        ba_predictions = np.hstack((0, ba_predictions))
        plt.subplot(2,1,1)
        plt.title('Original')
        plt.plot(range(0,len(predictions)), predictions, label=model)
        plt.legend()
        plt.subplot(2,1,2)
        plt.title('Balanced')
        plt.plot(range(0,len(ba_predictions)), ba_predictions, label=model)
        plt.legend()
interact(plot_4, cam_ind=(0,len(good_inds)));