In [17]:
import os
import numpy as np
from sklearn.neural_network import MLPClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import balanced_accuracy_score, f1_score
from tqdm import tqdm
import matplotlib.pyplot as plt
from ipywidgets import interact, interactive, fixed, interact_manual

In [2]:
rng = np.random.default_rng(2022)

In [4]:
path_base = '/dccstor/hoo-misha-1/wilds/wilds/features/iwildcam/PseudoLabel'
os.system('ls /dccstor/hoo-misha-1/wilds/wilds/features/iwildcam/PseudoLabel')

resnet50_id_test_features.npy
resnet50_id_test_labels.npy
resnet50_id_test_metadata.npy
resnet50_id_val_features.npy
resnet50_id_val_labels.npy
resnet50_id_val_metadata.npy
resnet50_test_features.npy
resnet50_test_labels.npy
resnet50_test_metadata.npy
resnet50_train_features.npy
resnet50_train_labels.npy
resnet50_train_metadata.npy
resnet50_val_features.npy
resnet50_val_labels.npy
resnet50_val_metadata.npy


0

In [5]:
def load_flm():
    test_features = np.load(f'{path_base}/resnet50_test_features.npy')
    test_labels = np.load(f'{path_base}/resnet50_test_labels.npy')
    test_metadata = np.load(f'{path_base}/resnet50_test_metadata.npy')
    return test_features, test_labels, test_metadata

In [6]:
def prune_cam_id(cutoff=50):
    metadata = np.load(f'{path_base}/resnet50_test_metadata.npy')
    unique_counts = np.unique(metadata[:,0],return_counts=True)
    return unique_counts[0][unique_counts[1] > cutoff]

In [7]:
def get_cam_ind(metadata, num_cams=1, cam_id = None):
    unique_counts = np.unique(metadata[:,0],return_counts=True)
    if cam_id is None:
        top_id = unique_counts[0][np.argpartition(unique_counts[1], -num_cams)[-num_cams:]]
    else:
        top_id = cam_id
    print(f'Selecting cameras with ids {top_id}')
    ind = np.zeros(metadata.shape[0]) == 1
    for c_id in top_id:
        ind = np.logical_or(ind,metadata[:,0] == c_id)
    return ind

In [8]:
def cam_flm(num_cams=1, cam_id = None):
    features, labels, metadata = load_flm()
    cam_ind = get_cam_ind(metadata, num_cams, cam_id)
    return features[cam_ind], labels[cam_ind], metadata[cam_ind]

In [9]:
def prune_flm(features, labels, metadata, cutoff=25):
    unique_counts = np.unique(labels,return_counts=True)
    print(f'|   | Total number of classes {len(unique_counts[0])}')
    prune_classes = unique_counts[0][unique_counts[1] < cutoff]
    prune_ind = []
    for clss in prune_classes:
        prune_ind.append((labels == clss).nonzero()[0])
    print(f'|   |   | Pruning {len(prune_classes)} classes with {len(np.concatenate(prune_ind))} data points')
    if len(prune_ind) == 0:
        return features, labels, metadata
    prune_ind = np.concatenate(prune_ind)
    pruned_ind = np.ones(labels.shape[0]) == 1
    pruned_ind[prune_ind] = False
    return features[pruned_ind], labels[pruned_ind], metadata[pruned_ind]

In [10]:
def balanced_sample_ind(labels, batch = 5):
    unique_classes = np.unique(labels)
    #print(unique_classes)
    ret_ind = None
    for clss in unique_classes:
        class_ind = np.where(labels == clss)[0]
        #print(clss, class_ind)
        rand_ind = rng.choice(class_ind,batch)
        if ret_ind is None:
            ret_ind = rand_ind
        else:
            ret_ind = np.concatenate((ret_ind, rand_ind))
    return ret_ind

In [22]:
def get_prediction_accuracy(num_cams=1, largest=True, cam_id = None, cutoff = 25, batch = 5):
    f,l,m = cam_flm(num_cams, cam_id)
    f,l,m = prune_flm(f,l,m, cutoff)
    sampled_ind = balanced_sample_ind(l,batch)
    nonsampled_ind = np.ones(l.shape[0]) == 1
    nonsampled_ind[sampled_ind] = False
    try:
        clf = LogisticRegression(random_state=0,max_iter=2000).fit(f[sampled_ind], l[sampled_ind])
        predictions = clf.predict(f[nonsampled_ind])
    except Exception as e:
        #print(e)
        return -1
    
    return balanced_accuracy_score(l[nonsampled_ind], predictions)#, f1_score(l[nonsampled_ind], predictions,average='macro')

In [23]:
def get_original_accuracy(num_cams=1, largest=True, cam_id = None, cutoff = 25):
    f,l,m = cam_flm(num_cams, cam_id)
    f,l,m = prune_flm(f,l,m, cutoff)
    weight = np.load('pseudo_classifier_weight.npy')
    bias = np.load('pseudo_classifier_bias.npy')
    pred_logits = f @ weight.T + bias
    pred = np.argmax(pred_logits,axis=1)
    return np.sum(pred == l)/len(pred)

In [None]:
cam_ids = prune_cam_id(800)
#print(f'Total {len(cam_ids)} to check')
cam_dict = {}
orig_dict = {}
cutoff = 25
for cam_id in tqdm(cam_ids):
    print(f'| Cam ID {cam_id}')
    cam_dict[cam_id] = []
    orig_dict[cam_id] = get_original_accuracy(cam_id=[cam_id], cutoff=cutoff)
    print(f'|   | {orig_dict[cam_id]}')
    for batch in range(1,cutoff):
        print(f'|   | {batch}')
        prediction_acc = 0
        for i in range(3):
            prediction_acc += get_prediction_accuracy(cam_id = [cam_id], cutoff=cutoff, batch=batch)
        prediction_acc /= 3
        print(f'|   | {prediction_acc}')
        cam_dict[cam_id].append(prediction_acc)

  0%|                                                                                                                                                                                                                                  | 0/11 [00:00<?, ?it/s]

| Cam ID 73
Selecting cameras with ids [73]
|   | Total number of classes 12
|   |   | Pruning 7 classes with 51 data points
|   | 0.3537234042553192
|   | 1
Selecting cameras with ids [73]
|   | Total number of classes 12
|   |   | Pruning 7 classes with 51 data points
Selecting cameras with ids [73]
|   | Total number of classes 12
|   |   | Pruning 7 classes with 51 data points
Selecting cameras with ids [73]
|   | Total number of classes 12
|   |   | Pruning 7 classes with 51 data points
|   | 0.3055385984181029
|   | 2
Selecting cameras with ids [73]
|   | Total number of classes 12
|   |   | Pruning 7 classes with 51 data points
Selecting cameras with ids [73]
|   | Total number of classes 12
|   |   | Pruning 7 classes with 51 data points
Selecting cameras with ids [73]
|   | Total number of classes 12
|   |   | Pruning 7 classes with 51 data points
|   | 0.35317205969941917
|   | 3
Selecting cameras with ids [73]
|   | Total number of classes 12
|   |   | Pruning 7 classes with

  9%|███████████████████▋                                                                                                                                                                                                     | 1/11 [04:11<41:58, 251.85s/it]

|   | 0.5273947035094705
| Cam ID 95
Selecting cameras with ids [95]
|   | Total number of classes 6
|   |   | Pruning 3 classes with 36 data points
|   | 0.0
|   | 1
Selecting cameras with ids [95]
|   | Total number of classes 6
|   |   | Pruning 3 classes with 36 data points
Selecting cameras with ids [95]
|   | Total number of classes 6
|   |   | Pruning 3 classes with 36 data points
Selecting cameras with ids [95]
|   | Total number of classes 6
|   |   | Pruning 3 classes with 36 data points
|   | 0.4267199951970389
|   | 2
Selecting cameras with ids [95]
|   | Total number of classes 6
|   |   | Pruning 3 classes with 36 data points
Selecting cameras with ids [95]
|   | Total number of classes 6
|   |   | Pruning 3 classes with 36 data points
Selecting cameras with ids [95]
|   | Total number of classes 6
|   |   | Pruning 3 classes with 36 data points
|   | 0.48605976718742766
|   | 3
Selecting cameras with ids [95]
|   | Total number of classes 6
|   |   | Pruning 3 classes wi

 18%|███████████████████████████████████████▍                                                                                                                                                                                 | 2/11 [05:14<21:06, 140.71s/it]

|   | 0.7112745897773745
| Cam ID 101
Selecting cameras with ids [101]
|   | Total number of classes 13
|   |   | Pruning 4 classes with 50 data points
|   | 0.01818181818181818
|   | 1
Selecting cameras with ids [101]
|   | Total number of classes 13
|   |   | Pruning 4 classes with 50 data points
Selecting cameras with ids [101]
|   | Total number of classes 13
|   |   | Pruning 4 classes with 50 data points
Selecting cameras with ids [101]
|   | Total number of classes 13
|   |   | Pruning 4 classes with 50 data points
|   | 0.21372075095188459
|   | 2
Selecting cameras with ids [101]
|   | Total number of classes 13
|   |   | Pruning 4 classes with 50 data points
Selecting cameras with ids [101]
|   | Total number of classes 13
|   |   | Pruning 4 classes with 50 data points
Selecting cameras with ids [101]
|   | Total number of classes 13
|   |   | Pruning 4 classes with 50 data points
|   | 0.2981001784279724
|   | 3
Selecting cameras with ids [101]
|   | Total number of classes 

 27%|███████████████████████████████████████████████████████████▏                                                                                                                                                             | 3/11 [15:16<46:48, 351.01s/it]

|   | 0.596594569573378
| Cam ID 120
Selecting cameras with ids [120]
|   | Total number of classes 21
|   |   | Pruning 4 classes with 50 data points
|   | 0.05218904030153668
|   | 1
Selecting cameras with ids [120]
|   | Total number of classes 21
|   |   | Pruning 4 classes with 50 data points
Selecting cameras with ids [120]
|   | Total number of classes 21
|   |   | Pruning 4 classes with 50 data points
Selecting cameras with ids [120]
|   | Total number of classes 21
|   |   | Pruning 4 classes with 50 data points
|   | 0.1421450477948111
|   | 2
Selecting cameras with ids [120]
|   | Total number of classes 21
|   |   | Pruning 4 classes with 50 data points
Selecting cameras with ids [120]
|   | Total number of classes 21
|   |   | Pruning 4 classes with 50 data points
Selecting cameras with ids [120]
|   | Total number of classes 21
|   |   | Pruning 4 classes with 50 data points
|   | 0.17730825373319245
|   | 3
Selecting cameras with ids [120]
|   | Total number of classes 2

 36%|██████████████████████████████████████████████████████████████████████████████▏                                                                                                                                        | 4/11 [37:12<1:25:25, 732.17s/it]

|   | 0.4538802169789178
| Cam ID 187
Selecting cameras with ids [187]
|   | Total number of classes 2
|   |   | Pruning 1 classes with 3 data points
|   | 0.9980255363959458
|   | 1
Selecting cameras with ids [187]
|   | Total number of classes 2
|   |   | Pruning 1 classes with 3 data points
Selecting cameras with ids [187]
|   | Total number of classes 2
|   |   | Pruning 1 classes with 3 data points
Selecting cameras with ids [187]
|   | Total number of classes 2
|   |   | Pruning 1 classes with 3 data points
|   | -1.0
|   | 2
Selecting cameras with ids [187]
|   | Total number of classes 2
|   |   | Pruning 1 classes with 3 data points
Selecting cameras with ids [187]
|   | Total number of classes 2
|   |   | Pruning 1 classes with 3 data points
Selecting cameras with ids [187]
|   | Total number of classes 2
|   |   | Pruning 1 classes with 3 data points
|   | -1.0
|   | 3
Selecting cameras with ids [187]
|   | Total number of classes 2
|   |   | Pruning 1 classes with 3 data po

 45%|██████████████████████████████████████████████████████████████████████████████████████████████████▋                                                                                                                      | 5/11 [37:35<47:39, 476.55s/it]

Selecting cameras with ids [187]
|   | Total number of classes 2
|   |   | Pruning 1 classes with 3 data points
|   | -1.0
| Cam ID 188
Selecting cameras with ids [188]
|   | Total number of classes 3
|   |   | Pruning 2 classes with 4 data points
|   | 0.9997090485888857
|   | 1
Selecting cameras with ids [188]
|   | Total number of classes 3
|   |   | Pruning 2 classes with 4 data points
Selecting cameras with ids [188]
|   | Total number of classes 3
|   |   | Pruning 2 classes with 4 data points
Selecting cameras with ids [188]
|   | Total number of classes 3
|   |   | Pruning 2 classes with 4 data points
|   | -1.0
|   | 2
Selecting cameras with ids [188]
|   | Total number of classes 3
|   |   | Pruning 2 classes with 4 data points
Selecting cameras with ids [188]
|   | Total number of classes 3
|   |   | Pruning 2 classes with 4 data points
Selecting cameras with ids [188]
|   | Total number of classes 3
|   |   | Pruning 2 classes with 4 data points
|   | -1.0
|   | 3
Selecting

 55%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                                                                                                  | 6/11 [38:03<26:59, 323.94s/it]

Selecting cameras with ids [188]
|   | Total number of classes 3
|   |   | Pruning 2 classes with 4 data points
|   | -1.0
| Cam ID 270
Selecting cameras with ids [270]
|   | Total number of classes 13
|   |   | Pruning 4 classes with 63 data points
|   | 0.015327102803738318
|   | 1
Selecting cameras with ids [270]
|   | Total number of classes 13
|   |   | Pruning 4 classes with 63 data points
Selecting cameras with ids [270]
|   | Total number of classes 13
|   |   | Pruning 4 classes with 63 data points
Selecting cameras with ids [270]
|   | Total number of classes 13
|   |   | Pruning 4 classes with 63 data points
|   | 0.21497441319850555
|   | 2
Selecting cameras with ids [270]
|   | Total number of classes 13
|   |   | Pruning 4 classes with 63 data points
Selecting cameras with ids [270]
|   | Total number of classes 13
|   |   | Pruning 4 classes with 63 data points
Selecting cameras with ids [270]
|   | Total number of classes 13
|   |   | Pruning 4 classes with 63 data poin

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


|   | 0.5343080682062187
|   | 17
Selecting cameras with ids [270]
|   | Total number of classes 13
|   |   | Pruning 4 classes with 63 data points
Selecting cameras with ids [270]
|   | Total number of classes 13
|   |   | Pruning 4 classes with 63 data points
Selecting cameras with ids [270]
|   | Total number of classes 13
|   |   | Pruning 4 classes with 63 data points
|   | 0.5264865629158028
|   | 18
Selecting cameras with ids [270]
|   | Total number of classes 13
|   |   | Pruning 4 classes with 63 data points
Selecting cameras with ids [270]
|   | Total number of classes 13
|   |   | Pruning 4 classes with 63 data points
Selecting cameras with ids [270]
|   | Total number of classes 13
|   |   | Pruning 4 classes with 63 data points
|   | 0.5356803273516672
|   | 19
Selecting cameras with ids [270]
|   | Total number of classes 13
|   |   | Pruning 4 classes with 63 data points
Selecting cameras with ids [270]
|   | Total number of classes 13
|   |   | Pruning 4 classes with 6

 64%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                                                               | 7/11 [51:06<31:35, 473.94s/it]

|   | 0.5532292796972517
| Cam ID 287
Selecting cameras with ids [287]
|   | Total number of classes 16
|   |   | Pruning 5 classes with 44 data points
|   | 0.04101326899879373
|   | 1
Selecting cameras with ids [287]
|   | Total number of classes 16
|   |   | Pruning 5 classes with 44 data points
Selecting cameras with ids [287]
|   | Total number of classes 16
|   |   | Pruning 5 classes with 44 data points
Selecting cameras with ids [287]
|   | Total number of classes 16
|   |   | Pruning 5 classes with 44 data points
|   | 0.17302909568496325
|   | 2
Selecting cameras with ids [287]
|   | Total number of classes 16
|   |   | Pruning 5 classes with 44 data points
Selecting cameras with ids [287]
|   | Total number of classes 16
|   |   | Pruning 5 classes with 44 data points
Selecting cameras with ids [287]
|   | Total number of classes 16
|   |   | Pruning 5 classes with 44 data points
|   | 0.21270410319944708
|   | 3
Selecting cameras with ids [287]
|   | Total number of classes

 73%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                                                          | 8/11 [1:04:03<28:31, 570.57s/it]

|   | 0.5005451420530768
| Cam ID 288
Selecting cameras with ids [288]
|   | Total number of classes 14
|   |   | Pruning 3 classes with 24 data points
|   | 0.011688311688311689
|   | 1
Selecting cameras with ids [288]
|   | Total number of classes 14
|   |   | Pruning 3 classes with 24 data points
Selecting cameras with ids [288]
|   | Total number of classes 14
|   |   | Pruning 3 classes with 24 data points
Selecting cameras with ids [288]
|   | Total number of classes 14
|   |   | Pruning 3 classes with 24 data points
|   | 0.18870439991646568
|   | 2
Selecting cameras with ids [288]
|   | Total number of classes 14
|   |   | Pruning 3 classes with 24 data points
Selecting cameras with ids [288]
|   | Total number of classes 14
|   |   | Pruning 3 classes with 24 data points
Selecting cameras with ids [288]
|   | Total number of classes 14
|   |   | Pruning 3 classes with 24 data points
|   | 0.25538355445333794
|   | 3
Selecting cameras with ids [288]
|   | Total number of classe

In [None]:
from ipywidgets import interact, interactive, fixed, interact_manual

def plot(cam_ind):
    predictions = cam_dict[cam_ids[cam_ind]]
    plt.plot(range(1,len(predictions)+2), predictions)
    
interact(plot, cam_ind=(0,len(cam_ids));

In [None]:
from ipywidgets import interact, interactive, fixed, interact_manual
import matplotlib.pyplot as plt

good_inds = []
for i in range(len(cam_ids)):
    predictions = cam_dict[cam_ids[i]]
    if predictions[-1] > 0:
        good_inds.append(i)
        
def plot_2(cam_ind):
    cam_ind = good_inds[cam_ind]
    print(f'Camera id {cam_ids[cam_ind]}')
    predictions = cam_dict[cam_ids[cam_ind]]
    print(f'Original {orig_dict[cam_ids[cam_ind]]}')
    print(f'Max {max(predictions)}')
    metadata = np.load(f'{path_base}/resnet50_test_metadata.npy')
    unique_counts = np.unique(metadata[:,0],return_counts=True)
    ind = np.where(unique_counts[0] == cam_ids[cam_ind])
    print(f'With {unique_counts[1][ind]} data points pre-pruning')
    predictions = [0] + orig_dict[cam_ids[cam_ind]]
    plt.plot(range(0,len(predictions)), predictions)


interact(plot_2, cam_ind=(0,len(good_inds)))

In [None]:
from sklearn.manifold import TSNE

f = np.load(f'{path_base}/resnet50_test_features.npy')
embedded = TSNE(n_components=2, learning_rate='auto', init='random').fit_transform(f)

In [None]:
f.shape

In [None]:
embedded.shape

In [None]:
import matplotlib.colors as mcolors

f,l,m = load_flm()
colors = mcolors.CSS4_COLORS
color_names = list(colors.keys())

c = 0
for i in np.unique(m[:,0]):
    ind = m[:,0] == i
    plt.plot(embedded[:,0][ind], embedded[:,1][ind],c=color_names[c],marker='.',linestyle = 'None')
    c += 10
    c %= len(color_names)

In [None]:
def plot_3(i):
    c = 10#np.random.randint(len(color_names))
    ind = m[:,0] == i
    plt.plot(embedded[:,0][ind], embedded[:,1][ind],c=color_names[c],marker='o',linestyle = 'None')

interact(plot_3, i=(0,len(np.unique(m[:,0]))));

In [None]:
f,l,m = load_flm()
colors = mcolors.CSS4_COLORS
color_names = list(colors.keys())

c = 0
for i in np.unique(m[:,0]):
    ind = m[:,0] == i
    for j in np.unique(l[ind]):
        ind_2 = l[ind] == j
        plt.plot(embedded[:,0][ind][ind_2], embedded[:,1][ind][ind_2],c=color_names[c],marker='.',linestyle = 'None')
        c += 10
        c %= len(color_names)

In [None]:
def plot_3(i):
    c = 10#np.random.randint(len(color_names))
    ind = m[:,0] == i
    for j in np.unique(l[ind]):
        ind_2 = l[ind] == j
        plt.plot(embedded[:,0][ind][ind_2], embedded[:,1][ind][ind_2],c=color_names[c],marker='.',linestyle = 'None')
        c += 10
        c %= len(color_names)

interact(plot_3, i=(0,len(np.unique(m[:,0]))));