In [1]:
import pickle
import numpy as np
import torch
from scipy.stats import spearmanr
import os

In [2]:
# Expose only GPU 0 to this process via the CUDA_VISIBLE_DEVICES variable.
os.environ.update({'CUDA_VISIBLE_DEVICES': '0'})

In [7]:
# Evaluation target; `dataset_type` selects which error/gradient files are read below.
dataset, dataset_type = 'cifar2', 'test'

# Attribution settings; all three are interpolated into the gradient file names.
method, proj_dim, selected_timesteps = 'das', 4096, 10

In [5]:
# Load train index
train_index_path = './data/idx-train.pkl'
with open(train_index_path, mode='rb') as handle:
    idx_train = pickle.load(handle)

# Report how many training indices were loaded.
print("len(idx_train):", len(idx_train))

len(idx_train): 5000


In [6]:
# load lds subset index
# For each of the 256 LDS subsets, build a boolean mask over `idx_train`
# marking which training indices belong to that subset.
mask_array_list = []
for i in range(256):
    with open(f'./data/lds-val/sub-idx-{i}.pkl', 'rb') as handle:
        sub_idx_train = pickle.load(handle)
    # np.isin replaces the deprecated np.in1d; for the 1-D `idx_train` used
    # here the two return the same mask.
    mask_array = np.isin(idx_train, sub_idx_train)
    mask_array_list.append(mask_array)

# One row per subset, one column per training index.
lds_mask_array = np.stack(mask_array_list)
print("lds_mask_array.shape:", lds_mask_array.shape)

lds_mask_array.shape: (256, 5000)


In [None]:
# load lds subset model output
# For every LDS subset model, average the stored losses over all
# (seed, e_seed) runs.
lds_model_seeds = [0, 1, 2]
lds_eval_seeds = [0, 1, 2]

loss_array_list = []
for i in range(256):

    # Running sum over all runs of subset model `i`; reset per model so a
    # stale value from the previous iteration can never leak in.
    loss_array = None

    for seed in lds_model_seeds:
        for e_seed in lds_eval_seeds:
            with open(f'./saved/errors/lds-val/model-{i}-{seed}/{dataset_type}-es{e_seed}.pkl', 'rb') as handle:
                # MSE loss
                loss_list = pickle.load(handle)
            margins = np.concatenate(loss_list, axis=-1)

            # Accumulate every run. The previous `seed == 0` test sat inside
            # the e_seed loop and overwrote the first two seed-0 runs.
            if loss_array is None:
                loss_array = margins
            else:
                loss_array = loss_array + margins

    # Divide by the real number of runs. `x/3*3` evaluates left to right
    # and is just `x`, so no averaging took place before.
    loss_array = loss_array / (len(lds_model_seeds) * len(lds_eval_seeds))

    loss_array_list.append(loss_array)

lds_loss_array = np.stack(loss_array_list)
print("lds_loss_array.shape:", lds_loss_array.shape)

# Reduce axis 1 of the stacked losses (axis meaning depends on the pickled
# layout -- TODO confirm).
lds_testset_correctness = lds_loss_array.mean(axis=1)
print("lds_testset_correctness.shape:", lds_testset_correctness.shape)

In [None]:
# load grad
# Memory-map the projected training gradients of each seed's model, then
# stack them along a new leading seed axis.
train_grad_list = [
    np.memmap(
        f'./saved/grad/model-{seed}/ddpm-{method}-train-t{selected_timesteps}-d{proj_dim}-es0.npy',
        dtype=np.float32,
        mode='r',
        shape=(5000, proj_dim),
    )
    for seed in [0, 1, 2]
]
train_grad = np.stack(train_grad_list)
print("train_grad.shape:", train_grad.shape)

# Move the stacked gradients to the GPU as a torch tensor.
train_grad = torch.from_numpy(train_grad).cuda()

In [None]:
# Memory-map the projected gradients of the evaluation split
# (`dataset_type`) for each seed's model and stack them along a seed axis.
test_grad_list = [
    np.memmap(
        f'./saved/grad/model-{seed}/ddpm-{method}-{dataset_type}-t{selected_timesteps}-d{proj_dim}-es0.npy',
        dtype=np.float32,
        mode='r',
        shape=(1000, proj_dim),
    )
    for seed in [0, 1, 2]
]
test_grad = np.stack(test_grad_list)
print("test_grad.shape:", test_grad.shape)

# Move the stacked gradients to the GPU as a torch tensor.
test_grad = torch.from_numpy(test_grad).cuda()

In [None]:
# load training set error
# Average the square root of the stored training losses over the seeds.
train_error_seeds = [0, 1, 2]
train_error = None

for seed in train_error_seeds:
    with open(f'./saved/errors/model-{seed}/train-es0.pkl', 'rb') as handle:
        # MSE loss
        error_list = pickle.load(handle)
    error_array = np.concatenate(error_list, axis=-1)
    # np.concatenate returns an ndarray, which has no `.sqrt()` method;
    # use the NumPy ufunc instead.
    error_array = np.sqrt(error_array)

    if train_error is None:
        train_error = error_array
    else:
        train_error = train_error + error_array

train_error = train_error / len(train_error_seeds)
print("train_error.shape:", train_error.shape)

# np.diag builds a diagonal matrix from a 1-D input and extracts the
# diagonal from a 2-D input; which one applies depends on the pickled
# layout -- TODO confirm against the printed shapes.
train_error_diag = np.diag(train_error)
print("train_error_diag.shape:", train_error_diag.shape)

In [None]:
# calculate the score
lds_list = []

# Regularisation strengths to sweep: 1, 2 and 5 times each power of ten
# from 1e-2 up to 1e6. Parsing the literal text keeps every value
# bit-identical to the hand-written constants (1e-2, 2e-2, ..., 5e6).
lamb_list = [
    float(f'{mantissa}e{exponent}')
    for exponent in range(-2, 7)
    for mantissa in (1, 2, 5)
]

# Per-lambda mean correlation / mean p-value, filled in by the sweep.
rs_list, ps_list = [], []

# Best result seen so far across the sweep.
best_scores, best_lds = None, -np.inf

In [None]:
def get_xtx(grads):
    """Compute the Gram matrix ``grads.T @ grads`` in row chunks.

    Args:
        grads: 2-D tensor of shape ``(n_samples, proj_dim)``.

    Returns:
        A ``(proj_dim, proj_dim)`` float16 tensor on the same device as
        ``grads``.
    """
    proj_dim = grads.shape[1]
    # Accumulate in float32 on the input's device. A float16 accumulator
    # rounds after every chunk, and a hard-coded 'cuda' device breaks for
    # inputs that live elsewhere.
    result = torch.zeros(
        proj_dim, proj_dim, dtype=torch.float32, device=grads.device
    )
    blocks = torch.split(grads, split_size_or_sections=20000, dim=0)

    for block in blocks:
        result += block.T @ block

    # Callers expect a float16 result, so cast once at the end.
    return result.to(torch.float16)


def get_xtx_inv(xtx, lambda_reg):
    """Return a rescaled inverse of the ridge-regularised Gram matrix.

    Computes ``inv(xtx + lambda_reg * I)`` in float32, then divides it by
    the mean absolute value of its entries.

    Args:
        xtx: Square ``(d, d)`` tensor (float16 when produced by ``get_xtx``).
        lambda_reg: Scalar ridge strength added to the diagonal.

    Returns:
        A ``(d, d)`` float16 tensor on the same device as ``xtx``.
    """
    # Regularise in float32. In float16, `lambda_reg * eye` overflows to inf
    # once lambda_reg exceeds ~65504, and small lambdas are rounded away
    # when added to large diagonal entries.
    xtx_fp32 = xtx.to(torch.float32)
    xtx_reg = xtx_fp32 + lambda_reg * torch.eye(
        xtx.size(0), device=xtx.device, dtype=torch.float32
    )
    xtx_inv = torch.linalg.inv(xtx_reg)

    # Rescale so the mean absolute entry is 1 before casting down to float16.
    xtx_inv /= xtx_inv.abs().mean()

    return xtx_inv.to(torch.float16)


def get_A_B(A, B, batch_size=20000):
    """Compute ``A @ B`` in row chunks of ``A`` and return it as float16.

    Args:
        A: 2-D tensor of shape ``(n, k)``.
        B: 2-D tensor of shape ``(k, m)``; its dtype may differ from ``A``'s.
        batch_size: Number of rows of ``A`` multiplied per chunk.

    Returns:
        An ``(n, m)`` float16 tensor on ``A``'s device.
    """
    # matmul does not promote dtypes, so float32 gradients times a float16
    # matrix would raise. Multiply in the promoted dtype; this is a no-op
    # when both operands already share a dtype.
    compute_dtype = torch.promote_types(A.dtype, B.dtype)
    B = B.to(compute_dtype)

    result = torch.empty(
        (A.shape[0], B.shape[1]), dtype=torch.float16, device=A.device
    )

    start = 0
    for block in torch.split(A, split_size_or_sections=batch_size, dim=0):
        end = start + block.shape[0]
        # Slice assignment casts the chunk's product down to float16.
        result[start:end] = block.to(compute_dtype) @ B
        start = end

    return result

In [None]:
from tqdm import tqdm

# Sweep the regularisation strengths, score every (test, train) pair with the
# selected attribution method, and report the mean Spearman correlation (LDS)
# between subset-summed scores and the measured subset-model losses.

# Fail fast on an unsupported method. Without this check no branch below
# assigns `scores_seed`, so the cell would hit a NameError or silently reuse
# a stale value from an earlier run.
supported_methods = ('dtrak', 'das1', 'das0')
if method not in supported_methods:
    raise ValueError(
        f'Unsupported method {method!r}; expected one of {supported_methods}'
    )

uses_train_error = method in ('das1', 'das0')
if uses_train_error:
    # `train_error_diag` is a NumPy array, which cannot be matrix-multiplied
    # with a CUDA tensor; convert it once. float32 leaves headroom for the
    # squaring below.
    # NOTE(review): assumes its leading dimension equals the number of
    # training samples -- confirm against the printed shape.
    train_error_weight = torch.as_tensor(
        train_error_diag, dtype=torch.float32, device=train_grad.device
    )

# The Gram matrix does not depend on lambda, so build it once per seed
# rather than once per (lambda, seed) pair.
kernel_by_seed = {seed: get_xtx(train_grad[seed]) for seed in [0, 1, 2]}

for lamb in lamb_list:

    scores_list = []

    for seed in [0,1,2]:

        train_grad_seed = train_grad[seed]
        test_grad_seed = test_grad[seed]
        kernel = kernel_by_seed[seed]
        kernel_inv = get_xtx_inv(kernel, lamb)

        # Shared by all methods: training gradients mapped through the
        # regularised inverse Gram matrix.
        features = get_A_B(train_grad_seed, kernel_inv)

        if method == 'das0':
            # Divide each training sample's feature row by (1 - h_ii), where
            # h_ii is that sample's diagonal entry of the hat matrix.
            hat_matrix = get_A_B(train_grad_seed, features.T)
            hat_value = torch.diag(hat_matrix)
            # hat_value has one entry per training row, so add a trailing
            # axis to broadcast across feature columns.
            features = features / (1 - hat_value).unsqueeze(1)

        scores_seed = get_A_B(test_grad_seed, features.T)

        if uses_train_error:
            scores_seed = scores_seed.float() @ train_error_weight
            scores_seed = scores_seed**2

        scores_seed = scores_seed.cpu().numpy()
        scores_list.append(scores_seed)

    # Average the scores over the seeds.
    scores = np.stack(scores_list)
    scores = scores.mean(axis=0)

    margins = lds_testset_correctness
    infl_est_ = -scores
    # Sum the negated scores over each subset's member training samples.
    preds = lds_mask_array @ infl_est_.T

    # compute lds score
    rs = []
    ps = []

    # One rank correlation per test sample (column of `preds`).
    for ind in range(preds.shape[1]):
        r, p = spearmanr(preds[:, ind], margins[:, ind])
        rs.append(r)
        ps.append(p)
    rs, ps = np.array(rs), np.array(ps)
    mean_r, mean_p = rs.mean(), ps.mean()
    print(f'Correlation: {mean_r:.3f} (avg p value {mean_p:.6f})')

    rs_list.append(mean_r)
    ps_list.append(mean_p)

    # Track the lambda that yields the highest mean correlation.
    if mean_r > best_lds:
        best_scores = scores
        best_lds = mean_r

print(f'Best score: {best_lds:.3f}')