In [199]:
import cvxpy as cp
import numpy as np
from numpy.random import multivariate_normal
from sklearn.linear_model import LogisticRegression

In [299]:

def generate_data(mu_c, mu_spur, sigma_c, sigma_spur, n):
    # Generate n/2 negative examples, n/2 positive examples from core-spurious distribution.
    assert(n % 2 == 0)
    d_core, d_spur = len(mu_c), len(mu_spur)
    ys = [0] * (n // 2) + [1] * (n // 2)
    x_cores_0 = multivariate_normal(-mu_c, np.eye(d_core) * sigma_c, size=n//2)
    x_cores_1 = multivariate_normal(mu_c, np.eye(d_core) * sigma_c, size=n//2)
    x_spurs_0 = multivariate_normal(-mu_spur, np.eye(d_spur) * sigma_spur, size=n//2)
    x_spurs_1 = multivariate_normal(mu_spur, np.eye(d_spur) * sigma_spur, size=n//2)
    x0 = np.concatenate([x_cores_0, x_spurs_0], axis=1)
    x1 = np.concatenate([x_cores_1, x_spurs_1], axis=1)
    return np.concatenate([x0, x1], axis=0), np.array(ys)


def make_data(mu_c, Mu_spur, sigma_c, sigma_spur, n_per_domain_train, n_per_domain_val):
    # Generate train and val examples for each domain.
    assert(n_per_domain_train % 2 == 0 and n_per_domain_val % 2 == 0)
    train_xs, train_ys = [], []
    val_xs, val_ys = [], []
    num_domains, d_spur = Mu_spur.shape
    d_core = len(mu_c)
    for domain in range(num_domains):
        xs, ys = generate_data(mu_c, Mu_spur[domain], sigma_c, sigma_spur, n_per_domain_train)
        train_xs.append(xs)
        train_ys.append(ys)
        xs, ys = generate_data(mu_c, Mu_spur[domain], sigma_c, sigma_spur, n_per_domain_val)
        val_xs.append(xs)
        val_ys.append(ys)
    train_xs = np.concatenate(train_xs, axis=0)
    train_ys = np.concatenate(train_ys, axis=0)
    val_xs = np.concatenate(val_xs, axis=0)
    val_ys = np.concatenate(val_ys, axis=0)
    return train_xs, train_ys, val_xs, val_ys
    
def generate_mus(d_core, d_spur, D):
    mu_c = multivariate_normal(np.zeros(d_core), np.eye(d_core))
    mu_c = mu_c / np.linalg.norm(mu_c)
    Mu_spur = multivariate_normal(np.zeros(d_spur), np.eye(d_spur) * 1 / d_spur, size=D)
    return mu_c, Mu_spur


In [326]:
# Solve logistic regression.

def get_log_reg_problem(xs, ys, w0):
    n, d = xs.shape
    w = cp.Variable(d)
    lambd = cp.Parameter(nonneg=True)
    log_likelihood = cp.sum(
        cp.multiply(ys, xs @ w) - cp.logistic(xs @ w)
    )
    problem = cp.Problem(cp.Maximize(log_likelihood / n - lambd * cp.norm(w - w0, 2)))
    return problem, lambd, w

def log_reg(xs, ys, w0, reg, solver=cp.ECOS):
    # w_0 is the initialization to regularize towards, reg is the regularization strength (lambda)
    problem, lambd, w = get_log_reg_problem(xs, ys, w0)
    lambd.value = reg
    problem.solve(solver=solver)
    if problem.status != 'optimal':
        print(problem.status, reg)
    return w.value

def get_logits(w, xs):
    return xs @ w

def get_preds(w, xs):
    logits = xs @ w
    return (logits >= 0).astype(np.int32)

def get_acc(preds, ys):
    return np.mean(preds == ys)

In [354]:
# For each sigma_spur, D combination. Sample. Pre-train. Fine-tune or linear probe. Measure accuracy.
def compare_fine_tuning(sigma_c, sigma_spur, d_core, d_spur, D, n_per_domain_train, n_per_domain_val):
    mu_c, Mu_spur = generate_mus(d_core, d_spur, D)
    train_xs, train_ys, val_xs, val_ys = make_data(mu_c, Mu_spur, sigma_c, sigma_spur, n_per_domain_train, n_per_domain_val)
    id_tr_xs, id_tr_ys = train_xs[:n_per_domain_train], train_ys[:n_per_domain_train]
    id_val_xs, id_val_ys = val_xs[:n_per_domain_val], val_ys[:n_per_domain_val]
    xs = [train_xs, val_xs, id_tr_xs, id_val_xs]
    ys = [train_ys, val_ys, id_tr_ys, id_val_ys]
    # Pre-train classifier.
    w0 = log_reg(train_xs, train_ys, np.zeros(d_core + d_spur), reg=0.0)
    pretrain_ood_acc = get_acc(get_preds(w0, val_xs), val_ys)
    pretrain_id_acc = get_acc(get_preds(w0, id_val_xs), id_val_ys)
    print('pretrain_ood_acc: ', pretrain_ood_acc)
    print('pretrain_id_acc: ', pretrain_id_acc)
    # Fine-tune classifier. 
    ft_ood_accs, ft_id_accs = [], []
    prob, lambd, w_param = get_log_reg_problem(id_tr_xs, id_tr_ys, w0)
    regs = np.logspace(0, -3, 20)
    for reg in regs:
        lambd.value = reg
        prob.solve()
        if prob.status != 'optimal':
            print(prob.status, reg)
        ft_ood_accs.append(get_acc(get_preds(w_param.value, val_xs), val_ys))
        ft_id_accs.append(get_acc(get_preds(w_param.value, id_val_xs), id_val_ys))
    print(ft_ood_accs)
    print(ft_id_accs)
    # Linear-probe classifier. 
    lp_ood_accs, lp_id_accs = [], []
    logits = []
    for i in range(len(xs)):
        logits.append(get_logits(w0, xs[i]))
    prob, lambd, head_param = get_log_reg_problem(logits[2], id_tr_ys, np.array([1.0]))
    regs = np.logspace(0, -3, 20)
    for reg in regs:
        lambd.value = reg
        prob.solve()
        if prob.status != 'optimal':
            print(prob.status, reg)
        ft_ood_accs.append(get_acc(get_preds(w_param.value, val_xs), val_ys))
        ft_id_accs.append(get_acc(get_preds(w_param.value, id_val_xs), id_val_ys))
    print(ft_ood_accs)
    print(ft_id_accs)
    

In [301]:
data = make_data(np.array([2.0]), np.array([[5.0], [0.0]]), 1.0, 1.0, 10, 2)

In [355]:
sigma_c = 1.0
sigma_spur = 0.3
d_core, d_spur = 1, 10
D = 20
n_per_domain_train, n_per_domain_val = 1000, 1000
compare_fine_tuning(sigma_c, sigma_spur, d_core, d_spur, D, n_per_domain_train, n_per_domain_val)

pretrain_ood_acc:  0.8482
pretrain_id_acc:  0.798
[0.8482, 0.8482, 0.8482, 0.8483, 0.8428, 0.8312, 0.80775, 0.7823, 0.7537, 0.72395, 0.69655, 0.66915, 0.64355, 0.62255, 0.6058, 0.59385, 0.58515, 0.5778, 0.5747, 0.5731]
[0.798, 0.798, 0.798, 0.8, 0.881, 0.932, 0.963, 0.975, 0.983, 0.984, 0.986, 0.985, 0.985, 0.987, 0.988, 0.988, 0.987, 0.987, 0.985, 0.985]


In [310]:
w0 = np.zeros(d_core + d_spur)
w = log_reg(id_tr_xs, id_train_ys, w0, reg=0.0)

optimal


In [304]:
print(w)

[-1.7980729  -0.51111631  0.99871807  0.8447853   3.15868253 -1.84308014
  0.83643784 -1.5165847   1.99099672  0.49510643  1.26742756]


In [305]:
val_preds = get_preds(w, id_tr_xs)
val_acc = get_acc(val_preds, id_train_ys)
print(val_acc)

0.962


In [306]:
clf = LogisticRegression(random_state=0, C=1000.0, fit_intercept=False).fit(id_tr_xs, id_train_ys)

In [307]:
clf.coef_

array([[-1.85829487, -0.53834539,  1.03173298,  0.86021655,  3.3017722 ,
        -1.91698605,  0.8680179 , -1.57759646,  2.08355504,  0.50766487,
         1.31722613]])

In [308]:
get_acc(clf.predict(id_tr_xs), id_train_ys)

0.962