In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys
import numpy as np 
from sklearn.metrics import accuracy_score
sys.path.append('../src')
import metainfo
from metainfo import img_dataset_list, txt_dataset_list, img_zs_model_list, txt_zs_model_list
from data_utils import load_emb, load_label_encoding, load_label, get_n_classes, get_class_balance
from data_utils import get_fewshot_samples, get_label_shifted_samples, compute_tv
from inference import zs_predict, ot_predict, pm_predict, get_pred_proba, ot_posthoc, pm_posthoc
from label_shift_utils import get_uniformly_resampled_indices, get_dirichlet_marginal, get_resampled_indices, BBSE
from sklearn.linear_model import LogisticRegression

train_emb_path = metainfo.train_emb_path
test_emb_path = metainfo.test_emb_path

model_name = 'RN50' # must be in img_fewshot_dataset_list
dataset_name = 'CIFAR10'
balanced_train = True
n_samples_per_class = 10
seed = 0

# Load data
train_emb = load_emb(dataset_name=dataset_name, model_name=model_name, base_path=train_emb_path)
train_labels = load_label(dataset_name=dataset_name, model_name=model_name, base_path=train_emb_path)

test_emb = load_emb(dataset_name=dataset_name, model_name=model_name, base_path=test_emb_path)
test_labels = load_label(dataset_name=dataset_name, model_name=model_name, base_path=test_emb_path)

label_encodings = load_label_encoding(dataset_name=dataset_name, model_name=model_name, base_path=test_emb_path)
n_classes = get_n_classes(test_labels)

# Sample data for experiment (and to simulate label shift)
train_emb_sampled, train_labels_sampled = get_fewshot_samples(train_emb, train_labels,
                                                               n_classes, n_samples_per_class, seed)
sample_train_cb = get_class_balance(train_labels_sampled, n_classes)
sample_test_cb = get_class_balance(test_labels, n_classes)

tv_naive_cb = compute_tv(sample_train_cb, sample_test_cb)

# Zeroshot prediction
zs_pred = zs_predict(test_emb, label_encodings)
zs_acc = accuracy_score(test_labels, zs_pred)

# OT prediction with naive sample cb
ot_naive_pred = ot_predict(test_emb, label_encodings, n_classes, sample_train_cb)
ot_naive_acc =  accuracy_score(test_labels, ot_naive_pred)

# BBSE Zeroshot
zs_pred_proba_train = get_pred_proba(train_emb_sampled, label_encodings)
zs_pred_proba_test = get_pred_proba(test_emb, label_encodings)
zs_bbse_cb = BBSE(zs_pred_proba_train, train_labels_sampled, zs_pred_proba_test, n_classes)
tv_zs_bbse_cb = compute_tv(zs_bbse_cb, sample_test_cb)

# BBSE + OT
ot_bbse_pred = ot_predict(test_emb, label_encodings, n_classes, zs_bbse_cb)
ot_bbse_acc = accuracy_score(test_labels, ot_bbse_pred)

# BBSE + Reweight
zs_bbse_weight = zs_bbse_cb / sample_train_cb
zs_bbse_pred = (zs_pred_proba_test * zs_bbse_weight).argmax(axis=1)
zs_bbse_acc = accuracy_score(test_labels, zs_bbse_pred)

# BBSE + Prior Matching
pm_bbse_pred = pm_predict(test_emb, label_encodings, n_classes, zs_bbse_cb)
pm_bbse_acc = accuracy_score(test_labels, pm_bbse_pred)

# Linear probing
lm = LogisticRegression(multi_class='multinomial', max_iter=10000)
lm.fit(train_emb_sampled, train_labels_sampled)

lp_pred_proba_train = lm.predict_proba(train_emb_sampled)
lp_pred_proba_test = lm.predict_proba(test_emb)
lp_pred= lp_pred_proba_test.argmax(axis=1)
lp_acc = accuracy_score(test_labels, lp_pred)

# Linear probing BBSE
lp_bbse_cb = BBSE(lp_pred_proba_train, train_labels_sampled, lp_pred_proba_test, n_classes)
tv_lp_bbse_cb = compute_tv(sample_test_cb, lp_bbse_cb)


# Linear probing + Rewieght with LP BBSE
lp_bbse_weight = lp_bbse_cb / sample_train_cb
lp_bbse_pred= (lp_pred_proba_test * lp_bbse_weight).argmax(axis=1)
lp_bbse_acc = accuracy_score(test_labels, lp_bbse_pred)

# Linear probing + Reweight with Prior Matching
lp_bbse_pm_pred = pm_posthoc(pred_proba=lp_pred_proba_test,
                             n_classes=n_classes, class_balance=zs_bbse_cb)
lp_bbse_pm_acc = accuracy_score(test_labels, lp_bbse_pm_pred)

# Linear probing + OT with LP BBSE
lp_bbse_ot_pred = ot_posthoc(lp_pred_proba_test,
                             n_classes=n_classes, class_balance=lp_bbse_cb)
lp_bbse_ot_acc = accuracy_score(test_labels, lp_bbse_ot_pred)

result = {'dataset_name': dataset_name,
        'model_name': model_name,
        'n_samples_per_class': n_samples_per_class,
        'seed': seed,
        'tv_naive_cb': tv_naive_cb,
        'tv_zs_bbse_cb': tv_zs_bbse_cb,
        'tv_lp_bbse_cb': tv_lp_bbse_cb,
        'zs_acc': zs_acc,
        'zs_bbse_acc': zs_bbse_acc,
        'pm_bbse_acc': pm_bbse_acc,
        'ot_naive_acc': ot_naive_acc,
        'ot_bbse_acc': ot_bbse_acc,
        'lp_acc': lp_acc,
        'lp_bbse_acc': lp_bbse_acc,
        'lp_bbse_pm_acc': lp_bbse_pm_acc,
        'lp_bbse_ot_acc': lp_bbse_ot_acc
        }

result



{'dataset_name': 'CIFAR10',
 'model_name': 'RN50',
 'n_samples_per_class': 10,
 'seed': 0,
 'tv_naive_cb': 0.0,
 'tv_zs_bbse_cb': 0.09564712393850108,
 'tv_lp_bbse_cb': 0.102685641067285,
 'zs_acc': 0.6651,
 'zs_bbse_acc': 0.1387,
 'pm_bbse_acc': 0.1957,
 'ot_naive_acc': 0.7678,
 'ot_bbse_acc': 0.7484,
 'lp_acc': 0.686,
 'lp_bbse_acc': 0.4791,
 'lp_bbse_pm_acc': 0.319,
 'lp_bbse_ot_acc': 0.6818}