In [None]:
import sys
import os
sys.path.append('..')
import src.svm_utils as svm_utils
import src.visualization_utils as viz_utils
import src.ds_utils as ds_utils
import torch
import numpy as np
import pandas as pd 
import seaborn as sns
import matplotlib.pyplot as plt
sns.set()

In [None]:
# Named colours picked from seaborn's "tab10" palette (indexed once, reused by the plots below).
_TAB10 = sns.color_palette("tab10")
BLUE, ORANGE, GREEN, RED, BROWN, GRAY = (_TAB10[i] for i in (0, 1, 2, 3, 5, 7))

import matplotlib.pylab as pylab

# Matplotlib defaults applied to every figure created after this cell runs.
params = {
    'legend.fontsize': 12,
    'figure.figsize': (5, 3),
    'axes.labelsize': 14,
    'axes.titlesize': 16,
    'xtick.labelsize': 14,
    'ytick.labelsize': 14,
}
pylab.rcParams.update(params)




In [None]:
# Storage roots (absolute, cluster-specific paths).
beton_root = "/mnt/cfs/projects/correlated_errors/betons"
experiment_root = "/mnt/cfs/projects/correlated_errors/experiments/spurious_cifar100/unlabeled_1_4_new_spurious_norm"

# SVM output files for this experiment.
svm_name = "svm_spurious_unlabeled_normalized"
name = os.path.join(
    experiment_root, f"svm_checkpoints/{svm_name}.pt"
)  # SVM output file
svm_model_name = os.path.join(
    experiment_root, f"svm_checkpoints/{svm_name}_model.pkl"
)  # SVM output file

# Sub-directories of the experiment used by the cells below.
model_root, loss_upweight_root, subset_root = (
    os.path.join(experiment_root, subdir)
    for subdir in ("models", "loss_vec_files", "subset_index_files")
)
model_ckpt = os.path.join(
    model_root,
    "spurious_supercifar100_unlabeled/version_0/checkpoints/checkpoint_last.pt",
)

In [None]:
# Build the processor from the SVM output file (`name`) and the model checkpoint.
# NOTE(review): get_unlabeled=True presumably makes the 'unlabeled' split available
# to the later cells -- confirm against SVMProcessor.
processor = viz_utils.SVMProcessor(
    name,
    root=beton_root,
    checkpoint_path=model_ckpt,
    get_unlabeled=True,
)

# The indices file recorded in the SVM run's args also stores which subclasses were dropped.
_svm_indices = torch.load(processor.metrics['args']['indices_file'])
classes_to_drop = _svm_indices['classes_to_drop']

In [None]:
# Show the 'classes_to_drop' entry loaded from the indices file.
classes_to_drop

In [None]:
import src.pytorch_datasets as pytorch_datasets

ds = pytorch_datasets.SuperCIFAR100(root="/mnt/nfs/home/saachij/datasets/cifar100", train=False)
class_names = np.array(ds.classes)

# Human-readable subclass names: underscores become spaces, and every subclass
# listed in classes_to_drop gets a trailing "*".
# The loop variables are deliberately NOT called `name`: that global holds the
# SVM output path used to build `processor`, and overwriting it here would break
# a re-run of the processor cell.
subclass_names = []
for subclass_idx, raw_name in enumerate(ds.subclasses):
    display_name = raw_name.replace('_', ' ')
    if subclass_idx in classes_to_drop:
        display_name += "*"
    subclass_names.append(display_name)
subclass_names = np.array(subclass_names)

# Singular-form superclass names (one per superclass); passed to the CLIP analyzer below.
singular_class_names = ['aquatic mammal', 'fish', 'flower', 'food container', 'fruit or vegetable', 'household electrical device', 'household furniture', 'insect', 'large carnivore', 'large man-made outdoor thing', 'large natural outdoor scene', 'large omnivores and herbivore', 'medium-sized mammal', 'non-insect invertebrate', 'person', 'reptile', 'small mammal', 'tree', 'standard vehicle', 'specialized vehicle']

In [None]:
# Per-subclass accuracy breakdown within each superclass
import torch
import numpy as np
def check_efficacy(split='test', is_correct=None):
    """Print a per-subclass accuracy breakdown for every superclass of ``split``.

    For each superclass the function prints one line per subclass in the form
    ``<subclass id>[*], <mean of is_correct>, <number of examples>``; the ``*``
    marks subclasses contained in the global ``classes_to_drop``. Finally the
    list of worst subclasses is printed.

    Args:
        split: name of the split; metrics are read from
            ``processor.metrics[f'{split}_metrics']``.
        is_correct: optional per-example correctness indicator aligned with the
            split. Defaults to the ``'ytrue'`` entry of the split metrics.

    Returns:
        list: for each superclass (in ``np.unique`` order), the subclass id with
        the lowest mean of ``is_correct``.
    """
    split_metrics = processor.metrics[f'{split}_metrics']
    superclass = split_metrics['classes']  # superclass id per example
    subclass = split_metrics['spuriouses']  # subclass id per example
    if is_correct is None:
        is_correct = split_metrics['ytrue']
    min_classes = []
    for c in np.unique(superclass):
        mask = superclass == c
        print(f"---{c}---")
        class_accs = []
        subc_list = np.unique(subclass[mask])
        for c2 in subc_list:
            suffix = "*" if c2 in classes_to_drop else ""
            selected = is_correct[mask & (subclass == c2)]
            acc = selected.mean()
            class_accs.append(acc)
            print(f"{c2}{suffix}, {acc:0.4f}, {len(selected)}")
        # Subclass with the lowest accuracy inside this superclass.
        min_classes.append(subc_list[np.argmin(class_accs)])
    print(min_classes)
    # Previously the result was only printed, so callers assigning the return
    # value always received None.
    return min_classes

In [None]:
# Print the per-subclass breakdown for each split and keep each call's result
# under its own name (the "val" result used to be stored in `test_masks` and was
# overwritten by the "test" call right after).
print("val")
val_masks = check_efficacy("val")
print("\ntest")
test_masks = check_efficacy("test")
print("\ntrain")
train_masks = check_efficacy("train")

In [None]:
# split = 'test'
# for f in [
#     "runs/spurious_soft/version_0/metrics.pt", 
#     "runs/spurious_overweight2/version_0/metrics.pt",
#     "runs/spurious_oracle2/version_0/metrics.pt",
# ]:
#     print(f)
#     fix = torch.load(f)
#     print(fix[split]['Accuracy'])
#     is_correct = fix[split]['preds'] == fix[split]['classes']
#     check_efficacy(split, is_correct.float())

In [None]:
# Per-example arrays for the test split, pulled out of the processor once.
split = 'test'
test_metrics = processor.metrics[f'{split}_metrics']
test_dv = test_metrics['decision_values']
test_confs = processor.run_dict[split]['confs']
test_superclass = test_metrics['classes']  # superclass id per example
test_subclass = test_metrics['spuriouses']  # subclass id per example
# True where the example's subclass is one of classes_to_drop.
# np.isin replaces np.in1d, which is deprecated in recent NumPy releases.
test_problematic = np.isin(test_subclass, classes_to_drop)
test_pred_correct = test_metrics['ypred']
test_correct = test_metrics['ytrue']

In [None]:
import sklearn.metrics as sklearn_metrics
import seaborn as sns
import matplotlib.pyplot as plt
# Two confusion matrices against "is the example in a dropped subclass":
#   1. test_correct       -> printed with the word "incorrect"
#   2. test_pred_correct  -> printed with the word "flagged"
# The two sections used to be copy-pasted; they only differ in the flag array,
# the x tick labels, and one word of the printed summary.
for correct_flags, x_ticklabels, wrong_word in [
    (test_correct, ['Incorrect', 'Correct'], 'incorrect'),
    (test_pred_correct, ['Predicted Incorrect', 'Predicted Correct'], 'flagged'),
]:
    conf_matrix = sklearn_metrics.confusion_matrix(y_true=(test_problematic == 0), y_pred=correct_flags)
    fig, ax = plt.subplots(1, 1, figsize=(5, 4))
    # Draw on the axes created above explicitly rather than relying on the
    # implicit "current axes".
    sns.heatmap(conf_matrix/conf_matrix.sum(), fmt='.2%', annot=True, cmap='Blues', ax=ax)
    ax.yaxis.set_ticklabels(['Problematic', 'Not Problematic'])
    ax.xaxis.set_ticklabels(x_ticklabels)

    is_wrong = correct_flags == 0
    num_problematic_and_wrong = (test_problematic & is_wrong).sum()
    perc_incorr_that_is_prob = num_problematic_and_wrong/is_wrong.sum()
    perc_prob_that_is_incorr = num_problematic_and_wrong/(test_problematic).sum()
    print(f"percentage {wrong_word} that is problematic {perc_incorr_that_is_prob:0.3%}")
    print(f"percentage problematic that is {wrong_word} {perc_prob_that_is_incorr:0.3%}")
    plt.show()

In [None]:
import scipy.stats as scipy_stats
import pandas as pd
def compute_entropy(arr):
    """Return the entropy of the empirical distribution of the values in ``arr``.

    The distinct values are counted, the counts are normalised to frequencies,
    and ``scipy.stats.entropy`` is applied with its default settings.
    """
    counts = np.unique(arr, return_counts=True)[1]
    frequencies = counts / counts.sum()
    return scipy_stats.entropy(frequencies)

# Subclass entropy of the K lowest-ranked test examples per superclass, ranked
# either by SVM decision value or by confidence (ascending for both), compared
# with the entropy over the whole superclass.
entropy_rows = []
for c in range(20):
    mask = test_superclass == c
    N = len(test_dv[mask])
    dv_order = np.argsort(test_dv[mask])
    conf_order = np.argsort(test_confs[mask])
    # Loop-invariant: the class subset and its overall entropy do not depend on K.
    class_subclasses = test_subclass[mask]
    class_entropy = compute_entropy(class_subclasses)

    for K in range(10, N, 5):
        entropy_rows.append([
            c,
            K,
            compute_entropy(class_subclasses[dv_order[:K]]),
            compute_entropy(class_subclasses[conf_order[:K]]),
            class_entropy,
        ])
df = pd.DataFrame(entropy_rows, columns=['class', 'Top K', 'SVM', 'Confidence', 'All'])
df = df.melt(['class', 'Top K'], var_name='Method', value_name='Entropy')
fig, ax = plt.subplots(1, 1, figsize=(5, 3))
# NOTE(review): newer seaborn releases deprecate `ci=` in favour of `errorbar=`;
# kept as-is because the installed seaborn version is not visible here.
sns.lineplot(data=df, x='Top K', y='Entropy', hue='Method', ax=ax, ci=None)
plt.show()

In [None]:
import scipy.stats as scipy_stats
import pandas as pd
def compute_fraction(arr):
    """Return ``arr.sum()`` divided by ``len(arr)``.

    For a boolean array this is the fraction of True entries.
    """
    total = arr.sum()
    return total / len(arr)


# Fraction of dropped-subclass ("problematic") examples among the K lowest-ranked
# test examples per superclass, ranked by SVM decision value or by confidence
# (ascending for both), compared with the fraction over the whole superclass.
fraction_rows = []
for c in range(20):
    mask = test_superclass == c
    N = len(test_dv[mask])
    dv_order = np.argsort(test_dv[mask])
    conf_order = np.argsort(test_confs[mask])
    # Loop-invariant: the class subset and its base rate do not depend on K.
    class_problematic = test_problematic[mask]
    base_fraction = compute_fraction(class_problematic)

    for K in range(10, N, 5):
        fraction_rows.append([
            c,
            K,
            compute_fraction(class_problematic[dv_order[:K]]),
            compute_fraction(class_problematic[conf_order[:K]]),
            base_fraction,
        ])
df = pd.DataFrame(fraction_rows, columns=['class', 'Top K Flagged', 'SVM Decision Value', 'Confidence', 'Base Population'])
df = df.melt(['class', 'Top K Flagged'], var_name='Order', value_name='Fraction Minority Subclass')
fig, ax = plt.subplots(1, 1, figsize=(8, 4))
sns.lineplot(data=df, x='Top K Flagged', y='Fraction Minority Subclass', hue='Order', ax=ax, ci=None,
             hue_order=['SVM Decision Value', 'Confidence', 'Base Population'], palette=[BLUE, RED, GRAY])
os.makedirs("figures/spurious_cifar100", exist_ok=True)
plt.savefig("figures/spurious_cifar100/frac_pop.pdf", bbox_inches='tight')
plt.show()


In [None]:
# One row per test example, sorted by subclass name so that the per-class plots
# below list subclasses in a stable order.
df = pd.DataFrame({
    'dv': test_dv,
    'confs': test_confs,
    'superclass': test_superclass,
    'superclass_name': class_names[test_superclass],
    'subclass': test_subclass,
    'subclass_name': subclass_names[test_subclass],
    'is_corrects': test_correct,
}).sort_values('subclass_name')

In [None]:
# One row of panels per superclass: decision-value violins, confidence violins,
# and a bar plot of `is_corrects` per subclass (labelled "Accuracy").
fig, ax = plt.subplots(20, 3, figsize=(15, 40))
violin_panels = [('dv', "SVM Decision Value"), ('confs', "Confidence")]
for c in range(20):
    mask = df['superclass'] == c
    class_rows = df[mask]
    for col, (value_column, y_label) in enumerate(violin_panels):
        panel = ax[c, col]
        sns.violinplot(data=class_rows, x='subclass_name', y=value_column, ax=panel)
        panel.set_xlabel(None)
        panel.set_ylabel(y_label)
    accuracy_panel = ax[c, 2]
    sns.barplot(data=class_rows, x='subclass_name', y='is_corrects', ax=accuracy_panel)
    accuracy_panel.set_xlabel(None)
    accuracy_panel.set_ylabel("Accuracy")
    # Only the right-most panel carries the superclass name as its title.
    accuracy_panel.set_title(class_names[c])
plt.tight_layout()
plt.show()

 # CLIP

In [None]:
import src.clip_utils as clip_utils

# CLIP-based analyzer built on top of the processor and the saved SVM model;
# captions use the singular superclass names.
clip_analyzer = clip_utils.ClipAnalyzer(
    processor=processor,
    svm_model_name=svm_model_name,
    class_names=singular_class_names,
    clip_config_name='CIFAR100',
    do_normalize=True,
)

In [None]:
def get_cdf(arr, K_range=None):
    """Return the mean of the first K entries of ``arr`` for every K in ``K_range``.

    Args:
        arr: 1-D array supporting slicing and ``.mean()``.
        K_range: iterable of prefix lengths. Defaults to
            ``np.arange(10, len(arr), 10)``.

    Returns:
        tuple: ``(prefix_means, K_range)`` where ``prefix_means`` is a NumPy
        array with one entry per K.
    """
    if K_range is None:
        K_range = np.arange(10, len(arr), 10)
    prefix_means = np.array([arr[:K].mean() for K in K_range])
    return prefix_means, K_range

# For every superclass, fetch the top captions in the 'pos' and 'neg' direction,
# rank the class's test images against each caption, and plot the accuracy of
# the top-K ranked images versus K.
test_class = test_superclass
saved_caption_and_most_relevant_imgs = {}
df_dict = {}
for METHOD in ['CLASSIFY']:
    all_dfs = []
    for target_class in range(20):
        print(processor.metrics['cv_scores'][target_class])

        if METHOD == 'CLOSEST':
            print("performing closest")
            result = clip_analyzer.perform_closest_to_top_K(target_class, 'all')
        else:
            print("performing classify captions on svm")
            result = clip_analyzer.get_svm_style_top_K(target_class, 'all')
        print("--------")

        cdfs = {}
        class_mask = test_class==target_class
        masked_indices = np.arange(len(test_class))[class_mask]
        # K_range = np.arange(10, len(masked_indices), 10)
        K_range=np.arange(20, 200, 5)
        # Loop-invariant per class: move the image latents to the GPU once
        # instead of once per (caption, direction) pair.
        image_latents = clip_analyzer.clip_features['test'][class_mask].cuda()
        class_accuracy = test_correct[masked_indices].mean()
        for caption_index in range(2):
            for direction in ['pos', 'neg']:
                caption_text = result[f'{direction}_captions'][caption_index]
                print(f"{direction}: {caption_text}")
                top_caption_latent = torch.tensor(result[f'{direction}_latents'][caption_index]).cuda()

                image_angles = clip_utils.order_descriptions_angle(mean_point=top_caption_latent.unsqueeze(0), query_points=image_latents)
                # Descending order of the returned values.
                image_order = np.argsort(image_angles)[::-1]
                saved_caption_and_most_relevant_imgs[(METHOD, target_class, caption_index, direction)] = (masked_indices[image_order], caption_text)
                cdfs[direction], _ = get_cdf(test_correct[masked_indices[image_order]], K_range)
                # uncomment this to display the images
#                 processor._display_images(taken_index=masked_indices[image_order], taken_scores=image_angles[image_order],
#                             taken_confs=image_angles[image_order], split="test")

            # Use a dedicated name so the notebook-level `df` built by earlier
            # cells is not clobbered by this loop.
            caption_df = pd.DataFrame()
            caption_df['K'] = K_range
            for d, v in cdfs.items():
                caption_df[d] = v
            caption_df = caption_df.melt('K', var_name='Direction', value_name='Accuracy')
            all_dfs.append(caption_df)
            sns.lineplot(data=caption_df, x='K', y='Accuracy', hue='Direction')
            # Reference line at the class's overall accuracy. axhline's
            # xmin/xmax are fractions of the axes width (0..1), not data
            # coordinates, so the defaults are used to span the full width.
            plt.axhline(y=class_accuracy, color='gray')
            plt.show()
    df_dict[METHOD] = all_dfs
    

In [None]:
# Names of the subclasses indexed by classes_to_drop (each carries the "*" suffix added above).
subclass_names[classes_to_drop]

In [None]:
# For each superclass print one dropped-subclass name followed by the two saved
# 'neg' captions.
# NOTE(review): indexing the dropped names by superclass id assumes
# classes_to_drop holds one entry per superclass, in superclass order -- confirm.
dropped_subclass_names = subclass_names[classes_to_drop]
for c in range(20):
    print("--")
    print(dropped_subclass_names[c])
    for caption_rank in (0, 1):
        saved_entry = saved_caption_and_most_relevant_imgs[('CLASSIFY', c, caption_rank, 'neg')]
        print(saved_entry[1])

In [None]:
# Show the processor's "extremes" view of the test split for each of the 20
# superclasses of the original model.
for superclass_idx in range(20):
    processor.display_extremes(superclass_idx, split='test')

## Interventions

In [None]:
# Per-example arrays for the training split; the loss-vector helpers below read
# these module-level names.
split = 'train'
orders = processor.orders['train']['SVM'][True] # most negative first
train_metrics = processor.metrics[f'{split}_metrics']
dv = train_metrics['decision_values']
superclass = train_metrics['classes']  # superclass id per training example
subclass = train_metrics['spuriouses']  # subclass id per training example
# True where the example's subclass is one of classes_to_drop.
# np.isin replaces np.in1d, which is deprecated in recent NumPy releases.
problematic = np.isin(subclass, classes_to_drop)
def visualize_loss_vec(loss_vec):
    """Print the mean weight and example count of every (superclass, subclass) group.

    ``loss_vec`` must be aligned with the module-level ``superclass`` and
    ``subclass`` arrays. One ``---<superclass>---`` header is printed per
    superclass, followed by ``<subclass>[*], <mean weight>, <count>`` lines; the
    ``*`` marks subclasses contained in ``classes_to_drop``.
    """
    for super_id in np.unique(superclass):
        in_superclass = superclass == super_id
        print(f"---{super_id}---")
        for sub_id in np.unique(subclass[in_superclass]):
            marker = "*" if sub_id in classes_to_drop else ""
            group_weights = loss_vec[in_superclass & (subclass == sub_id)]
            print(f"{sub_id}{marker}, {group_weights.mean():0.4f}, {len(group_weights)}")

In [None]:
import seaborn as sns
import torch
import numpy as np
def _report_and_save_loss_vec(overall_loss_vec, filename):
    """Print a loss-weight vector, scatter it to dataset-index space, and optionally save it.

    ``overall_loss_vec`` holds one weight per training example (aligned with the
    module-level ``superclass``). The saved tensor has length
    ``max(train_indices) + 1``; positions not listed in ``train_indices`` keep
    the sentinel value -1.
    """
    visualize_loss_vec(overall_loss_vec)

    train_indices = torch.load(processor.metrics['args']['indices_file'])['train_indices']
    # int(...) makes the size an ordinary Python int regardless of the index container type.
    big_loss_vec = torch.ones(int(train_indices.max()) + 1) * -1
    big_loss_vec[train_indices] = torch.tensor(overall_loss_vec).float()
    if filename is not None:
        torch.save(big_loss_vec, filename)


def get_automatic_loss_vec(filename=None):
    """Build "soft" per-example weights from the SVM decision values.

    Within each superclass the weight is ``(-dv - min(-dv)) / mean(...)``: the
    example with the largest decision value gets weight 0, more negative
    decision values get larger weights, and the weights average to 1.

    Args:
        filename: if given, the scattered weight tensor is saved there with
            ``torch.save``.
    """
    overall_loss_vec = np.ones(len(superclass)) * -1.0
    for c in range(processor.hparams['num_classes']):
        mask = superclass == c
        if not mask.any():
            # No training example of this class: nothing to weight
            # (and .min() on an empty selection would raise).
            continue
        dv_vals = -dv[mask]
        dv_vals = dv_vals - dv_vals.min()
        mean_shifted = dv_vals.mean()
        if mean_shifted == 0:
            # All decision values in the class are identical; dividing by the
            # zero mean would give NaN weights, so fall back to uniform weight 1.
            overall_loss_vec[mask] = 1.0
        else:
            overall_loss_vec[mask] = dv_vals / mean_shifted
    _report_and_save_loss_vec(overall_loss_vec, filename)


def get_simple_loss_vec(upweight=2, filename=None):
    """Weight examples with a non-positive SVM decision value by ``upweight``, all others by 1.

    Examples whose superclass id is outside ``range(num_classes)`` keep the
    sentinel weight -1.

    Args:
        upweight: weight assigned where ``dv <= 0``.
        filename: if given, the scattered weight tensor is saved there.
    """
    overall_loss_vec = np.ones(len(superclass)) * -1.0
    for c in range(processor.hparams['num_classes']):
        mask = superclass == c
        overall_loss_vec[mask & (dv <= 0)] = upweight
        overall_loss_vec[mask & (dv > 0)] = 1
    _report_and_save_loss_vec(overall_loss_vec, filename)


def get_oracle_loss_vec(upweight=2, filename=None, balance=False):
    """Weight examples of the dropped subclasses by ``upweight``, all others by 1.

    Uses the module-level ``problematic`` mask (ground-truth subclass
    membership), hence "oracle".

    Args:
        upweight: weight assigned to problematic examples.
        filename: if given, the scattered weight tensor is saved there.
        balance: accepted for backward compatibility but not implemented; a
            warning is emitted when it is truthy instead of silently ignoring it.
    """
    if balance:
        import warnings
        warnings.warn("get_oracle_loss_vec: `balance` is not implemented and is ignored")
    overall_loss_vec = np.ones(len(superclass)) * -1.0
    overall_loss_vec[problematic] = upweight
    overall_loss_vec[~problematic] = 1
    _report_and_save_loss_vec(overall_loss_vec, filename)

In [None]:
# torch.save does not create missing directories, so make sure the target exists.
os.makedirs(loss_upweight_root, exist_ok=True)

filename=os.path.join(loss_upweight_root, "soft.pt")
get_automatic_loss_vec(filename=filename) # get soft upweighting
print("----------------------")
filename=os.path.join(loss_upweight_root, "overweight_2.pt")
get_simple_loss_vec(upweight=2, filename=filename)
print("----------------------")
filename=os.path.join(loss_upweight_root, "oracle_2.pt")
get_oracle_loss_vec(upweight=2, filename=filename) # oracle upweight to 2
# "Balanced" variants are not generated: get_simple_loss_vec has no `balance`
# argument and get_oracle_loss_vec does not use the one it accepts.

# Subset Intervention

In [None]:
# Per-example arrays for the unlabeled split, pulled out of the processor once.
split = 'unlabeled'
unlabeled_metrics = processor.metrics[f'{split}_metrics']
unlabeled_dv = unlabeled_metrics['decision_values']
unlabeled_confs = processor.run_dict[split]['confs']
unlabeled_superclass = unlabeled_metrics['classes']  # superclass id per example
unlabeled_subclass = unlabeled_metrics['spuriouses']  # subclass id per example
# True where the example's subclass is one of classes_to_drop.
# np.isin replaces np.in1d, which is deprecated in recent NumPy releases.
unlabeled_problematic = np.isin(unlabeled_subclass, classes_to_drop)
unlabeled_pred_correct = unlabeled_metrics['ypred']
unlabeled_correct = unlabeled_metrics['ytrue']

In [None]:
# Fraction of unlabeled examples whose subclass is in classes_to_drop.
unlabeled_problematic.mean()

In [None]:
import scipy.stats as scipy_stats
import pandas as pd
# The "random" baseline subset is written to disk further down, so draw it from
# a seeded generator to make the notebook reproducible across kernel restarts.
SUBSET_SEED = 0
rand_perm = np.random.default_rng(SUBSET_SEED).permutation(len(unlabeled_dv))

K_range = np.arange(25, 125, 25)

# Rank the unlabeled examples of every superclass once; the ranking does not
# depend on K, only the cut-off does. Orders hold indices into the unlabeled
# split, ascending by decision value / confidence / random key.
dv_orders = []
conf_orders = []
random_orders = []
for c in range(20):
    mask = unlabeled_superclass == c
    masked_indices = np.arange(len(mask))[mask]
    dv_orders.append(masked_indices[np.argsort(unlabeled_dv[mask])])
    conf_orders.append(masked_indices[np.argsort(unlabeled_confs[mask])])
    random_orders.append(masked_indices[np.argsort(rand_perm[mask])])

# One index array per K: the first K examples of every superclass, concatenated.
all_dv_inds = [np.concatenate([order[:K] for order in dv_orders]) for K in K_range]
all_conf_inds = [np.concatenate([order[:K] for order in conf_orders]) for K in K_range]
all_random_inds = [np.concatenate([order[:K] for order in random_orders]) for K in K_range]

In [None]:
# Fraction of dropped-subclass examples among the selected unlabeled examples,
# per selection method and K.
num_k = len(K_range)
df = pd.DataFrame({
    'K': K_range,
    'SVM': [unlabeled_problematic[inds].mean() for inds in all_dv_inds[:num_k]],
    'Confidence': [unlabeled_problematic[inds].mean() for inds in all_conf_inds[:num_k]],
    'Random': [unlabeled_problematic[inds].mean() for inds in all_random_inds[:num_k]],
})
df = df.melt('K', value_name='Unlabeled Fraction Problematic', var_name='Method')
display(df)
sns.lineplot(data=df, x='K', y='Unlabeled Fraction Problematic', hue='Method')

In [None]:
# For every selection method and every K, write an indices file whose training
# set is the original training set plus the selected unlabeled examples.
indices_file = torch.load(processor.metrics['args']['indices_file'])
u_indices = indices_file['unlabeled_indices']
# The loop variable is deliberately not called `name`: that global holds the SVM
# output path used to build `processor`.
for order_name, order in [("dv", all_dv_inds), ('confs', all_conf_inds), ('random', all_random_inds)]:
    subset_dir = os.path.join(subset_root, order_name)
    os.makedirs(subset_dir, exist_ok=True)
    # File i corresponds to the i-th entry of K_range.
    for i, selected in enumerate(order):
        subset_indices_dict = {
            'val_indices': indices_file['val_indices'],
            'train_indices': torch.cat([indices_file['train_indices'], u_indices[selected]]),
            'classes_to_drop': indices_file['classes_to_drop']
        }
        torch.save(subset_indices_dict, os.path.join(subset_dir, f'{i}.pt'))


# Load intervention files

In [None]:
K_range = np.arange(25, 125, 25)

name_map = {
    'dv': 'SVM Decision Value',
    'confs': 'Confidence',
    'baseline': 'Base Population',
    'random': 'Random'
}
# Legend order and colours shared by both figures. Derived from name_map so the
# labels cannot drift apart from the data (the three-panel figure used to ask
# for 'SVM Score', which matches no row and dropped the SVM curve).
hue_order = [name_map[key] for key in ('dv', 'confs', 'random', 'baseline')]
hue_palette = [BLUE, RED, ORANGE, GRAY]

# All-True mask over the test set (placeholder for restricting the evaluation).
mask = np.ones(len(test_problematic)) == 1
rows = []
for t in ['dv', 'confs', 'baseline', 'random']:
    for i in range(len(K_range)):
        for v in range(5):  # model versions version_0 .. version_4
            if t == 'baseline':
                # The original model: identical for every K and version, so it
                # shows up as a flat reference line.
                is_corrects = torch.tensor(test_correct)
            else:
                path = os.path.join(model_root, f"spurious_supercifar100_subset_{t}_{i}/version_{v}/metrics.pt")
                out = torch.load(path)
                is_corrects = (out['test']['preds'] == out['test']['classes'])
            flagged_acc = is_corrects[(test_pred_correct == 0) & mask].float().mean().item()
            prob_acc = is_corrects[test_problematic & mask].float().mean().item()
            acc = is_corrects[mask].float().mean().item()
            rows.append([name_map[t], K_range[i], flagged_acc, prob_acc, acc])
df = pd.DataFrame(rows, columns=['Order', 'K', 'Flagged', "Problematic", 'Accuracy'])
display(df)
fig, ax = plt.subplots(1, 3, figsize=(15, 4))
for i, w in enumerate(['Flagged', "Problematic", 'Accuracy']):
    sns.lineplot(data=df, x='K', y=w, hue='Order', ax=ax[i], markers=True,
                 hue_order=hue_order,
                 palette=hue_palette
                )
    handles, labels = ax[i].get_legend_handles_labels()
    ax[i].legend(handles[:4], labels[:4])
plt.tight_layout()
plt.show()


fig, ax = plt.subplots(1, 1, figsize=(8,4))
sns.lineplot(data=df, x='K', y='Problematic', hue='Order', ax=ax, markers=True,
             hue_order=hue_order,
             palette=hue_palette
            )
ax.set_xticks(K_range)
ax.set_xlabel('K Added Images')
ax.set_ylabel('Accuracy on Minority Subclass')
# Lay out first, then save, and make sure the output directory exists.
plt.tight_layout()
os.makedirs('figures/spurious_cifar100', exist_ok=True)
plt.savefig('figures/spurious_cifar100/intervention.pdf', bbox_inches='tight')
plt.show()

In [None]:
# Debug check: first four legend labels left over from the last panel of the plotting loop above.
labels[:4]

In [None]:
# `train_spuriouses` is only created by a cell further down the notebook, so on
# a fresh kernel (Restart & Run All) this cell used to fail with a NameError.
# Load it here when it is missing.
try:
    train_spuriouses
except NameError:
    _train_hparams, _train_labels, train_spuriouses = ds_utils.get_all_beton_labels(
        "dataset_configs/supercifar100.yaml", 'train', beton_root, include_spurious=True)

# The subset file depends only on the ordering name, not on the model version:
# load it once per name instead of once per (version, name) pair.
num_problematic_in_train = {}
for n in ['dv', 'confs']:
    i_file = torch.load(os.path.join(subset_root, n, "2.pt"))
    num_problematic_in_train[n] = np.isin(train_spuriouses[i_file['train_indices']], classes_to_drop).sum()

for i in np.arange(7):
    for n in ['dv', 'confs']:
        print(n)
        out = torch.load(os.path.join(model_root, f"add_30_{n}/version_{i}/metrics.pt"))
        print(out['args']['indices_file'])
        print(num_problematic_in_train[n])
        is_corrects = out['test']['preds'] == out['test']['classes']
        # Accuracy on dropped-subclass examples, on the rest, and overall.
        print(is_corrects[test_problematic].float().mean().item())
        print(is_corrects[~test_problematic].float().mean().item())
        print(is_corrects.float().mean().item())

In [None]:
ds = pytorch_datasets.SuperCIFAR100(root="/mnt/nfs/home/saachij/datasets/cifar100", train=True)
config = "dataset_configs/supercifar100.yaml"
# Read the training labels and subclass ("spurious") ids from the beton files.
# Uses the notebook-level `beton_root` instead of repeating the same literal path.
hparams, train_labels, train_spuriouses = ds_utils.get_all_beton_labels(
    config, 'train', beton_root, include_spurious=True)


In [None]:
# Fraction of training examples whose subclass is in classes_to_drop
# (np.isin replaces np.in1d, which is deprecated in recent NumPy releases).
np.isin(train_spuriouses, classes_to_drop).mean()

In [None]:
# Debug check: show the current K_range.
K_range

In [None]:
# Debug check: show the current K_range (same expression as the preceding cell).
K_range