In [1]:
import pickle
import numpy as np
import torch
from scipy.stats import spearmanr
import os

In [2]:
# Restrict this process to GPU 0; set here, ahead of the first .cuda() call
# made further down in the notebook.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [3]:
# Experiment configuration used to build every file path below.
dataset, method = "cifar10", "trak"  # data-set folder name / attribution method
proj_dim = 4096                      # second dimension of the saved gradient files

In [4]:
# Pickle file holding the training index for the selected dataset.
train_index_path = f'./data/{dataset}/idx-train.pkl'

In [None]:
# Read the pickled training index into `idx_train` and report its length.
# NOTE: pickle.load can execute arbitrary code -- only open trusted files.
with open(train_index_path, 'rb') as index_file:
    idx_train = pickle.load(index_file)
print("len(idx_train):", len(idx_train))

In [None]:
# Build one boolean membership mask per LDS subset.
# Row i of `lds_mask_array` flags which entries of `idx_train` also appear in
# the index list stored in sub-idx-{i}.pkl.
# NOTE: pickle.load can execute arbitrary code -- only open trusted files.
mask_array_list = []
for i in range(256):
    with open(f'./data/{dataset}/lds-val/sub-idx-{i}.pkl', 'rb') as handle:
        sub_idx_train = pickle.load(handle)
    # np.isin supersedes np.in1d (deprecated in NumPy 2.0). ravel() keeps the
    # flattened 1-D result that np.in1d always produced.
    mask_array = np.isin(idx_train, sub_idx_train).ravel()
    mask_array_list.append(mask_array)
lds_mask_array = np.stack(mask_array_list)
print("lds_mask_array.shape:", lds_mask_array.shape)

In [None]:
# Load the outputs of every LDS subset model, averaged over seeds.
# For each of the 256 subsets, test_CE.pkl is read for every seed, the pickled
# chunks are concatenated along the last axis, and the per-seed arrays are
# averaged with equal weight. Row i of `lds_loss_array` belongs to subset i.
# NOTE: pickle.load can execute arbitrary code -- only open trusted files.
seeds = [0, 1, 2]
loss_array_list = []
for i in range(256):

    loss_array = None
    for seed in seeds:
        with open(f'./saved/models/{dataset}/lds-val/index-{i}-seed-{seed}/test_CE.pkl', 'rb') as handle:
            # stored quantity (per the original author): -log(p/(1-p))
            loss_list = pickle.load(handle)
        margins = np.concatenate(loss_list, axis=-1)

        # Out-of-place sum, so the freshly loaded array is never mutated.
        loss_array = margins if loss_array is None else loss_array + margins

    # Divide once, after all seeds have been summed. Dividing inside the seed
    # loop (as the code did before) weighted the seeds 1/27, 1/9 and 1/3
    # instead of 1/3 each.
    loss_array = loss_array / len(seeds)

    loss_array_list.append(loss_array)

lds_loss_array = np.stack(loss_array_list)
print("lds_loss_array.shape:", lds_loss_array.shape)

In [None]:
# Sanity check: show the seed-averaged outputs stored for LDS subset 0.
print(lds_loss_array[0])


In [None]:
# Memory-map the per-seed `train-{proj_dim}.npy` gradient files (read-only,
# float32, 50000 rows each), stack them along a new leading seed axis and
# move the stacked array to the GPU.
train_grad_list = [
    np.memmap(
        f'./saved/grad/{dataset}/seed-{seed}/train-{proj_dim}.npy',
        dtype=np.float32,
        mode='r',
        shape=(50000, proj_dim),
    )
    for seed in (0, 1, 2)
]
train_grad_host = np.stack(train_grad_list)  # materialises the memmaps in RAM
print("train_grad.shape:", train_grad_host.shape)
train_grad = torch.from_numpy(train_grad_host).cuda()

In [None]:
# Sanity check: show the stacked training gradients of seed 0 (a GPU tensor).
print(train_grad[0])

In [None]:
# Memory-map the per-seed `test-{proj_dim}.npy` gradient files (read-only,
# float32, 10000 rows each), stack them along a new leading seed axis and
# move the stacked array to the GPU.
test_grad_list = [
    np.memmap(
        f'./saved/grad/{dataset}/seed-{seed}/test-{proj_dim}.npy',
        dtype=np.float32,
        mode='r',
        shape=(10000, proj_dim),
    )
    for seed in (0, 1, 2)
]
test_grad_host = np.stack(test_grad_list)  # materialises the memmaps in RAM
print("test_grad.shape:", test_grad_host.shape)
test_grad = torch.from_numpy(test_grad_host).cuda()

In [None]:
# Average the training-set error of the `origin` models over the three seeds,
# then expand the averaged vector into a diagonal matrix.
# NOTE: pickle.load can execute arbitrary code -- only open trusted files.
train_error_list = []

train_error = None
for seed in (0, 1, 2):
    with open(f'./saved/models/{dataset}/origin/seed-{seed}/train_error.pkl', 'rb') as error_file:
        # stored quantity (per the original author): 1-p
        error_chunks = pickle.load(error_file)
    seed_error = np.concatenate(error_chunks, axis=-1)

    # Running sum over seeds; the first seed simply initialises it.
    train_error = seed_error if train_error is None else train_error + seed_error

train_error = train_error / 3
print("train_error.shape:", train_error.shape)

# For a 1-D `train_error`, np.diag builds a dense square matrix with the
# errors on its diagonal (presumably 50000 x 50000 here -- memory heavy).
train_error_diag = np.diag(train_error)
print("train_error_diag.shape:", train_error_diag.shape)


In [12]:
# Search grid and accumulators for the regularisation sweep.
lds_list = []

# Values of `lamb` to sweep: {1, 2, 5} x 10^k for k = -2, ..., 6 (27 values).
# Each value is added, scaled by the identity, to the gradient kernel.
lamb_list = [
    float(f'{mantissa}e{exponent}')
    for exponent in range(-2, 7)
    for mantissa in (1, 2, 5)
]

# Per-lambda mean correlation and mean p-value, filled in by the sweep.
rs_list, ps_list = [], []

# Best result seen so far.
best_lds = -np.inf
best_scores = None

In [None]:
from pathlib import Path

# Load `mask.npy` and `val_margins.npy` from ./tmp as torch tensors.
# Both files are opened memory-mapped (read-only); the masks are cast to float.
tmp_path = './tmp'
tmp_dir = Path(tmp_path)

masks_path = tmp_dir / 'mask.npy'
masks = torch.as_tensor(np.load(masks_path, mmap_mode='r')).float()
print(masks.shape)

margins_path = tmp_dir / 'val_margins.npy'
margins = torch.as_tensor(np.load(margins_path, mmap_mode='r'))
print(margins.shape)

In [None]:
from tqdm import tqdm

# Correlation sweep against the `masks` / `margins` tensors loaded from ./tmp.
# For every `lamb`:
#   1. per seed, scores = test_grad @ (train_grad @ (G^T G + lamb * I)^-1)^T
#   2. scores are averaged over seeds
#   3. predictions `masks @ scores.T` are compared column by column with
#      `margins` using Spearman rank correlation.
# Unlike the LDS sweep in the next cell, scores are not re-weighted by the
# training error here.
if method != 'trak':
    # Previously an unsupported method left `scores_seed` undefined (NameError)
    # or silently reused the value from an earlier run.
    raise ValueError(f"unsupported attribution method: {method!r}")

seed_ids = [0, 1, 2]

# The Gram matrix G^T G does not depend on `lamb`: compute it once per seed
# instead of once per (lamb, seed) pair.
kernels = {seed: train_grad[seed].T @ train_grad[seed] for seed in seed_ids}
identity = torch.eye(
    kernels[seed_ids[0]].shape[0],
    device=kernels[seed_ids[0]].device,
    dtype=kernels[seed_ids[0]].dtype,
)

val_inds = np.arange(10000)

for lamb in lamb_list:

    scores_list = []

    for seed in seed_ids:

        train_grad_seed = train_grad[seed]
        test_grad_seed = test_grad[seed]

        # Regularised inverse of the Gram matrix for this seed.
        kernel_ = torch.linalg.inv(kernels[seed] + lamb * identity)

        scores_seed = test_grad_seed @ ((train_grad_seed @ kernel_).T)

        scores_list.append(scores_seed.cpu().numpy())

    # Average the per-seed score matrices.
    scores = np.stack(scores_list).mean(axis=0)

    preds = masks @ scores.T
    rs = []
    ps = []
    # total= gives tqdm a length, which enumerate() alone does not expose.
    for ind, j in tqdm(enumerate(val_inds), total=len(val_inds)):
        r, p = spearmanr(preds[:, ind], margins[:, j])
        rs.append(r)
        ps.append(p)
    rs, ps = np.array(rs), np.array(ps)
    print(f'Correlation: {rs.mean():.3f} (avg p value {ps.mean():.6f})')

In [None]:
# LDS sweep over `lamb_list`.
# For every `lamb`:
#   1. per seed, scores = test_grad @ (train_grad @ (G^T G + lamb * I)^-1)^T
#   2. scores are averaged over seeds and each training column is scaled by
#      its averaged training error
#   3. predictions `lds_mask_array @ (-scores).T` are compared with
#      `lds_loss_array` column by column using Spearman rank correlation.
# The mean correlation / p-value per lambda go to `rs_list` / `ps_list`; the
# best mean correlation and its score matrix go to `best_lds` / `best_scores`.
if method != 'trak':
    # Previously an unsupported method left `scores_seed` undefined (NameError)
    # or silently reused the value from an earlier run.
    raise ValueError(f"unsupported attribution method: {method!r}")

seed_ids = [0, 1, 2]
n_eval = 2000  # number of leading columns that are scored (as in the original)

# Start from clean accumulators so that re-running this cell does not append
# duplicate entries or compare against a stale best value.
rs_list = []
ps_list = []
best_scores = None
best_lds = -np.inf

# The Gram matrix G^T G does not depend on `lamb`: compute it once per seed
# instead of once per (lamb, seed) pair.
kernels = {seed: train_grad[seed].T @ train_grad[seed] for seed in seed_ids}
identity = torch.eye(
    kernels[seed_ids[0]].shape[0],
    device=kernels[seed_ids[0]].device,
    dtype=kernels[seed_ids[0]].dtype,
)

# NOTE: this rebinds the notebook-level name `margins` (also assigned by the
# ./tmp loading cell); re-run that cell before re-running the sweep above.
margins = lds_loss_array

for lamb in lamb_list:

    scores_list = []

    for seed in seed_ids:

        train_grad_seed = train_grad[seed]
        test_grad_seed = test_grad[seed]

        # Regularised inverse of the Gram matrix for this seed.
        kernel_ = torch.linalg.inv(kernels[seed] + lamb * identity)

        scores_seed = test_grad_seed @ ((train_grad_seed @ kernel_).T)

        scores_list.append(scores_seed.cpu().numpy())

    # Average the per-seed score matrices.
    scores = np.stack(scores_list).mean(axis=0)

    # Scale column j by train_error[j]. This equals
    # `scores @ np.diag(train_error)` but broadcasts instead of multiplying by
    # a dense square matrix. Relies on `train_error` being 1-D, which the
    # original matmul with `train_error_diag` already required.
    scores = scores * train_error

    infl_est_ = -scores
    preds = lds_mask_array @ infl_est_.T

    # compute lds
    rs = []
    ps = []
    for ind in range(n_eval):
        r, p = spearmanr(preds[:, ind], margins[:, ind])
        rs.append(r)
        ps.append(p)
    rs, ps = np.array(rs), np.array(ps)
    print(f'Correlation: {rs.mean():.3f} (avg p value {ps.mean():.6f})')

    rs_list.append(rs.mean())
    ps_list.append(ps.mean())

    if rs.mean() > best_lds:
        best_scores = scores
        best_lds = rs.mean()

print(f'best_lds: {best_lds:.3f}')

