In [5]:
import pickle
import numpy as np
import torch
from scipy.stats import spearmanr

In [6]:
dataset = 'cifar2'
proj_dim = 4096
method = 'trak'


In [7]:
train_index_path = f'./data/{dataset}/idx-train.pkl'

In [8]:
# Load train index
with open(train_index_path, 'rb')  as handle:
    idx_train = pickle.load(handle)
print("len(idx_train):", len(idx_train))

len(idx_train): 10000


In [9]:
# load lds subset index
mask_array_list = []
for i in range(256):
    with open(f'./data/{dataset}/lds_val/sub-idx-{i}.pkl', 'rb')  as handle:
        sub_idx_train = pickle.load(handle)
    mask_array = np.in1d(idx_train, sub_idx_train)
    mask_array_list.append(mask_array)
lds_mask_array = np.stack(mask_array_list)
print("lds_mask_array.shape:", lds_mask_array.shape)

lds_mask_array.shape: (256, 10000)


In [10]:
# load lds subset model output
loss_array_list = []
for i in range(256):
    
    for seed in [0,1,2]:
        with open(f'./saved/models/{dataset}/lds-val/index-{i}-seed-{seed}/test_CE.pkl', 'rb')  as handle:
            # -log(p/(1-p))
            loss_list = pickle.load(handle)
        margins = np.concatenate(loss_list, axis=-1)

        if (seed == 0):
            loss_array = margins
        else:
            loss_array += margins

        loss_array = loss_array/3

    loss_array_list.append(loss_array) 

lds_loss_array = np.stack(loss_array_list)
print("lds_loss_array.shape:", lds_loss_array.shape)

lds_loss_array.shape: (256, 2000)


In [11]:
print(lds_loss_array[0])


[0.00688803 0.01962721 0.51917636 ... 0.01619437 0.26889518 0.02676756]


In [13]:
# load grad
train_grad_list = []
for seed in [0,1,2]:
    train_grad_seed = np.memmap(
        f'./saved/grad/{dataset}/seed-{seed}/train-{proj_dim}.npy', 
        dtype=np.float32, 
        mode='r',
        shape=(10000, proj_dim)
    )
    train_grad_list.append(train_grad_seed)
train_grad = np.stack(train_grad_list)
print("train_grad.shape:", train_grad.shape)
train_grad = torch.from_numpy(train_grad).cuda()

train_grad.shape: (3, 10000, 4096)


In [14]:
print(train_grad[0])

tensor([[ 1.7249e+01, -1.6733e+01, -1.1416e+01,  ..., -3.8947e+01,
          6.3580e+00, -5.0446e+00],
        [ 1.4643e+00,  2.5474e+00,  2.6933e-02,  ...,  1.9473e+00,
          1.7577e+00,  3.3745e+00],
        [-2.8790e-01, -1.4160e+00, -1.0547e+00,  ..., -3.0139e+00,
          8.6601e-01, -5.2559e-01],
        ...,
        [-5.3286e-01, -2.0897e+00, -2.1248e+00,  ..., -4.1665e+00,
          1.4464e+00, -1.4421e+00],
        [ 6.0728e-01, -1.3522e+01,  8.1734e-01,  ..., -2.7139e+01,
          8.4135e-02, -8.3998e+00],
        [ 1.3433e+00,  2.6064e+00,  2.2589e-01,  ...,  2.1558e+00,
         -3.9850e-01,  2.0823e+00]], device='cuda:0')


In [15]:
test_grad_list = []
for seed in [0,1,2]:
    test_grad_seed = np.memmap(
        f'./saved/grad/{dataset}/seed-{seed}/test-{proj_dim}.npy', 
        dtype=np.float32, 
        mode='r',
        shape=(2000, proj_dim)
    )
    test_grad_list.append(test_grad_seed)
test_grad = np.stack(test_grad_list)
print("test_grad.shape:", test_grad.shape)
test_grad = torch.from_numpy(test_grad).cuda()

test_grad.shape: (3, 2000, 4096)


In [16]:
# load training set error
train_error_list = []

    
for seed in [0,1,2]:
    with open(f'./saved/models/{dataset}/origin/seed-{seed}/train_error.pkl', 'rb')  as handle:
        # 1-p
        error_list = pickle.load(handle)
    error_array = np.concatenate(error_list, axis=-1)

    if (seed == 0):
        train_error = error_array
    else:
        train_error += error_array

train_error = train_error/3
print("train_error.shape:", train_error.shape)

train_error_diag = np.diag(train_error)
print("train_error_diag.shape:", train_error_diag.shape)


train_error.shape: (10000,)
train_error_diag.shape: (10000, 10000)


In [17]:
# calculate the score
lds_list = []
lamb_list = [
        1e-2, 2e-2, 5e-2,
        1e-1, 2e-1, 5e-1,
        1e0, 2e0, 5e0,
        1e1, 2e1, 5e1,
        1e2, 2e2, 5e2,
        1e3, 2e3, 5e3, 
        1e4, 2e4, 5e4, 
        1e5, 2e5, 5e5, 
        1e6, 2e6, 5e6, 
    ]

rs_list = []
ps_list = []
best_scores = None
best_lds = -np.inf

In [18]:


for lamb in lamb_list:

    scores_list = []

    for seed in [0,1,2]:

        train_grad_seed = train_grad[seed]
        test_grad_seed = test_grad[seed]
        kernel = train_grad_seed.T@train_grad_seed
        
        kernel_ = kernel + lamb*torch.eye(kernel.shape[0]).cuda()
        kernel_ = torch.linalg.inv(kernel_)  

        if method == 'trak':
            scores_seed = test_grad_seed@((train_grad_seed@kernel_).T)
            
        scores_seed = scores_seed.cpu().numpy()
        scores_list.append(scores_seed)
    
    scores = np.stack(scores_list)
    scores = scores.mean(axis=0)
    scores = scores@train_error_diag

    margins = lds_loss_array
    infl_est_ = -scores
    preds = lds_mask_array @ infl_est_.T

    # compute lds
    rs = []
    ps = []
    for ind in range(2000):
        r, p = spearmanr(preds[:, ind], margins[:, ind])
        rs.append(r)
        ps.append(p)
    rs, ps = np.array(rs), np.array(ps)
    print(f'Correlation: {rs.mean():.3f} (avg p value {ps.mean():.6f})')
    
    rs_list.append(rs.mean())   
    ps_list.append(ps.mean())

    if rs.mean()>best_lds:
        best_scores = scores
        best_lds = rs.mean()

print(f'best_lds: {best_lds:.3f}')



Correlation: -0.003 (avg p value 0.495328)
Correlation: -0.003 (avg p value 0.495328)
Correlation: -0.003 (avg p value 0.495328)
Correlation: -0.003 (avg p value 0.495208)
Correlation: -0.003 (avg p value 0.495177)
Correlation: -0.003 (avg p value 0.495132)
Correlation: -0.003 (avg p value 0.495458)
Correlation: -0.003 (avg p value 0.495260)
Correlation: -0.003 (avg p value 0.495212)
Correlation: -0.003 (avg p value 0.495535)
Correlation: -0.003 (avg p value 0.495224)
Correlation: -0.003 (avg p value 0.495238)
Correlation: -0.003 (avg p value 0.495324)
Correlation: -0.003 (avg p value 0.495252)
Correlation: -0.003 (avg p value 0.496218)
Correlation: -0.003 (avg p value 0.497177)
Correlation: -0.003 (avg p value 0.498108)
Correlation: -0.003 (avg p value 0.500973)
Correlation: -0.002 (avg p value 0.502486)
Correlation: -0.002 (avg p value 0.502870)
Correlation: -0.002 (avg p value 0.500195)
Correlation: -0.002 (avg p value 0.499914)
Correlation: -0.002 (avg p value 0.501630)
Correlation