In [1]:
import sys
sys.path.append('D:\\Compute Science\\Machine Learning\\论文\\项目\\FairSPL\\venv_torch')
sys.path.append('D:\\Compute Science\\Machine Learning\\论文\\项目\\FairSPL\\venv_torch\\lib\\site-packages')

In [2]:
import torch
from torch import nn
import numpy as np

In [3]:
# 定义一个简单的模型
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(4, 3)
        self.fc2 = nn.Linear(3, 2)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)

        return x
    
loss_fn = nn.MSELoss()

In [4]:
# 输出模型的参数
model = MLP()
for param in model.parameters():
    print(param.size())

torch.Size([3, 4])
torch.Size([3])
torch.Size([2, 3])
torch.Size([2])


In [5]:
# 定义数据
data = torch.tensor([1,2,3,4], dtype=torch.float)
label = torch.tensor([5,6], dtype=torch.float)
pred = model(data)
loss = loss_fn(pred, label)

计算雅克比矩阵（一阶导数）

In [6]:
def grad(model, y):
    """ 计算一阶导数.
    Returns: 
        grads, grads[i]: dy / dx_i
    """
    grads = torch.autograd.grad(loss, model.parameters(), retain_graph=True, create_graph=True)
    # for grad in grads:
    #     print(grad.size()) # 可以发现一共 4 个 Tensor，分别为损失函数对四个参数 Tensor（两层，每层都有权重和偏置）的梯度。
        
    grads = torch.cat([x.flatten() for x in grads], dim=0)
    return grads

grads = grad(model, loss)
print(grads.size())

torch.Size([23])


计算 Hessian 矩阵（二阶导数）

In [7]:
# 如果直接传入 model.parameters()，会报错，目前不知道原因
def hess(model, y, grads=None):
    """ 计算二阶导数.
    Returns: 
        he, he[i,j]: d^2y / (dx_i dx_j)
    """
    if grads is None:
        grads = grad(model, y)
        
    total_params = sum(p.numel() for p in model.parameters())
    he = torch.zeros(total_params, total_params)
    
    for i, g in enumerate(grads):
        second_order_grad = grad(model, g)
        he[i, :] = second_order_grad

    return he

he = hess(model, loss)
print(he.size())

torch.Size([23, 23])


In [8]:
# # 计算 hessian 矩阵
# grad = torch.autograd.grad(outputs=loss, inputs=model.parameters(), create_graph=True)
# grad = torch.cat([x.flatten() for x in grad], dim=0)
# total_params = sum(p.numel() for p in model.parameters())
# he2 = torch.zeros(total_params, total_params)
    
# for i, g in enumerate(grad):
#     second_order_grad = torch.autograd.grad(outputs=g, inputs=model.parameters(), retain_graph=True)
#     second_order_grad = torch.cat([x.flatten() for x in second_order_grad], dim=0)
#     he2[i, :] = second_order_grad

实现影响函数（一次性返回所有点的影响函数值）
$$
\mathcal{I}_{\text {up,loss }}\left(z, z_{\text {test }}\right) =-\nabla_{\theta} L\left(z_{\text {test }}, \hat{\theta}\right)^{\top} H_{\hat{\theta}}^{-1} \nabla_{\theta} L(z, \hat{\theta})
$$

In [4]:
# 数据准备
train_z = torch.tensor([[1,2,3,4],[5,6,7,8]], dtype=torch.float)
train_t = torch.tensor([[9,10],[11,12]], dtype=torch.float)

train_set = [(z, t) for z, t in zip(train_z, train_t)]

test_z = torch.tensor([5,6,7,7], dtype=torch.float)
test_t = torch.tensor([7,8], dtype=torch.float)

model = MLP()

In [5]:
# Step1: 计算V第三项
def grad_z2(model, z, t):
    """ Calculates the gradient z. One grad_z should be computed for each
    training sample.
    
    Arguments:
        z: torch tensor, training data points
            e.g. an image sample (batch_size, 3, 256, 256)
        t: torch tensor, training data labels
        model: torch NN, model used to evaluate the dataset
    Returns:
        grad_z: list of torch tensor, containing the gradients
            from model parameters to loss
    """
    model.eval()
    y = model(z)
    loss = loss_fn(y, t)
    grads = torch.autograd.grad(loss, model.parameters(), create_graph=True)
    grads = torch.cat([x.flatten() for x in grads], dim=0)

    return grads

grad = grad_z(model, train_z, train_t)
grad

tensor([  5.8607,   7.7541,   9.6475,  11.5409,  -3.7036,  -5.1888,  -6.6740,
         -8.1592,  22.7662,  30.3566,  37.9470,  45.5374,   1.8934,  -1.4852,
          7.5904,  19.9988, -25.1778,  22.7012,  26.1847, -33.6491,  29.9575,
         -7.0601,  -8.8066], grad_fn=<CatBackward0>)

In [11]:
# 验证是否是每一个样本的loss的梯度的和的均值等于在batch上loss的计算结果
grad1, grad2 = [grad_z(model, z, t) for (z, t) in train_set]
grad.allclose((grad1+grad2)/2)

True

In [15]:
# Step2: 计算前两项，也就是 s_test=v^T H^{-1}
# 辅助函数：快速计算 Hv，其中 H 是 Hessian 矩阵
def hvp(y, model, v): # 计算 y 对 w 的二阶导 H，返回 Hv
    """Multiply the Hessians of y and w by v.
    Uses a backprop-like approach to compute the product between the Hessian
    and another vector efficiently, which even works for large Hessians.
    Example: if: y = 0.5 * w^T A x then hvp(y, w, v) returns and expression
    which evaluates to the same values as (A + A.t) v.

    Arguments:
        y: scalar/tensor, for example the output of the loss function
        w: list of torch tensors, tensors over which the Hessian
            should be constructed
        v: list of torch tensors, same shape as w,
            will be multiplied with the Hessian

    Returns:
        return_grads: list of torch tensors, contains product of Hessian and v.

    """
    # First backprop
    first_grads = torch.autograd.grad(y, model.parameters(), retain_graph=True, create_graph=True)

    # Elementwise products
    elemwise_products = 0
    for grad_elem, v_elem in zip(first_grads, v):
        elemwise_products += torch.sum(grad_elem * v_elem)

    # Second backprop
    return_grads = torch.autograd.grad(elemwise_products, model.parameters(), create_graph=True)

    return return_grads


def s_test(model, test_z, test_t, train_loader, damp=0.01, scale=25.0,
       recursion_depth=5000):
    """s_test can be precomputed for each test point of interest, and then
    multiplied with grad_z to get the desired value for each training point.
    Here, strochastic estimation is used to calculate s_test. s_test is the
    Inverse Hessian Vector Product.

    Arguments:
        test_z: torch tensor, test data points, such as test images
        test_t: torch tensor, contains all test data labels
        model: torch NN, model used to evaluate the dataset
        train_loader: torch Dataloader, can load the training dataset
        damp: float, dampening factor
        scale: float, scaling factor
        recursion_depth: int, number of iterations aka recursion depth
            should be enough so that the value stabilises.

    Returns:
        h_estimate: list of torch tensors, s_test
    """
    v = grad_z(model, test_z, test_t)
    h_estimate = v.clone().detach()
    for i in range(recursion_depth):
        for z, t in train_loader:
            y = model(z)
            loss = loss_fn(y, t)
            hv = hvp(loss, model, h_estimate) 
            h_estimate = [
                _v + (1 - damp) * _h_e - _hv / scale
                for _v, _h_e, _hv in zip(v, h_estimate, hv)]
            break
    return h_estimate

In [16]:
s_test(model, test_z, test_t, train_set, recursion_depth=1)

[tensor([[16.5433, 21.2170, 25.8908, 30.5645],
         [16.4665, 21.0635, 25.6605, 30.2575],
         [15.5208, 19.1720, 22.8233, 26.4746]], grad_fn=<SubBackward0>),
 tensor([18.9172, 18.8404, 17.8947], grad_fn=<SubBackward0>),
 tensor([[37.2072, 36.8563, 37.2114],
         [33.7050, 44.2962, 33.5768]], grad_fn=<SubBackward0>),
 tensor([16.5154, 19.6926], grad_fn=<SubBackward0>)]

In [17]:
0.7720*0.4651

0.3590572

In [None]:
# Step3: 合并所有的结果 计算每一个样本点对test_z的影响值
def calc_influence_function(model, train_z, train_t, test_z, test_t):
    """Calculates the influence function

    Arguments:
        grad_z_vecs: list of torch tensor, containing the gradients
            from model parameters to loss
        e_s_test: list of torch tensor, contains s_test vectors

    Returns:
        influence: list of float, influences of all training data samples
            for one test sample
        harmful: list of float, influences sorted by harmfulness
        helpful: list of float, influences sorted by helpfulness.
    """
    train_dataset_size = len(train_z)
    influences = []
    
    grad_z_vecs = [grad_z(model, z, t) for z, t in zip(train_z, train_t)] 
    e_s_test = s_test(model, test_z, test_t, zip(train_z, train_t))
    
#     grad_z_vecs = torch.cat([x.flatten() for x in grad_z_vecs], dim=0)
#     e_s_test = torch.cat([x.flatten() for x in e_s_test], dim=0)
    
    # 对第 i 个样本
    for i in range(train_dataset_size):
        influence_i = -sum(
            [
                torch.sum(k * j).data.cpu().numpy()
                for k, j in zip(grad_z_vecs[i], e_s_test)
            ]) / train_dataset_size
        influences.append(influence_i)

    harmful = np.argsort(influences)
    helpful = harmful[::-1]

    return influences, harmful.tolist(), helpful.tolist()

In [None]:
calc_influence_function(model, train_x, train_label, test_z, test_t)

In [6]:
# Step1: 计算V第三项
def grad_z2(model, z, t):
    """ Calculates the gradient z. One grad_z should be computed for each
    training sample.
    
    Arguments:
        z: torch tensor, training data points
            e.g. an image sample (batch_size, 3, 256, 256)
        t: torch tensor, training data labels
        model: torch NN, model used to evaluate the dataset
    Returns:
        grad_z: list of torch tensor, containing the gradients
            from model parameters to loss
    """
    model.eval()
    y = model(z)
    loss = loss_fn(y, t)
    grads = torch.autograd.grad(loss, model.parameters(), create_graph=True)
    grads = torch.cat([x.flatten() for x in grads], dim=0)

    return grads

grad = grad_z2(model, train_z, train_t)
grad

tensor([  5.8607,   7.7541,   9.6475,  11.5409,  -3.7036,  -5.1888,  -6.6740,
         -8.1592,  22.7662,  30.3566,  37.9470,  45.5374,   1.8934,  -1.4852,
          7.5904,  19.9988, -25.1778,  22.7012,  26.1847, -33.6491,  29.9575,
         -7.0601,  -8.8066], grad_fn=<CatBackward0>)

In [None]:
# Step1: 计算V第三项
def grad_z2(model, z, t):
    """ Calculates the gradient z. One grad_z should be computed for each
    training sample.
    
    Arguments:
        z: torch tensor, training data points
            e.g. an image sample (batch_size, 3, 256, 256)
        t: torch tensor, training data labels
        model: torch NN, model used to evaluate the dataset
    Returns:
        grad_z: list of torch tensor, containing the gradients
            from model parameters to loss
    """
    model.eval()
    y = model(z)
    loss = loss_fn(y, t)
    grads = torch.autograd.grad(loss, model.parameters(), create_graph=True)
    grads = torch.cat([x.flatten() for x in grads], dim=0)

    return grads

# grad = grad_z2(model, train_z, train_t)
# grad

In [None]:
# Step2: 计算前两项，也就是 s_test=v^T H^{-1}
# 辅助函数：快速计算 Hv，其中 H 是 Hessian 矩阵
def hvp2(y, model, v): # 计算 y 对 w 的二阶导 H，返回 Hv
    """Multiply the Hessians of y and w by v.
    Uses a backprop-like approach to compute the product between the Hessian
    and another vector efficiently, which even works for large Hessians.
    Example: if: y = 0.5 * w^T A x then hvp(y, w, v) returns and expression
    which evaluates to the same values as (A + A.t) v.

    Arguments:
        y: scalar/tensor, for example the output of the loss function
        w: list of torch tensors, tensors over which the Hessian
            should be constructed
        v: list of torch tensors, same shape as w,
            will be multiplied with the Hessian

    Returns:
        return_grads: list of torch tensors, contains product of Hessian and v.

    """
    # First backprop
    first_grads = torch.autograd.grad(y, model.parameters(), retain_graph=True, create_graph=True)
    first_grads = torch.cat([x.flatten() for x in first_grads], dim=0)
    # Elementwise products
    elemwise_products = 0
    for grad_elem, v_elem in zip(first_grads, v):
        elemwise_products += torch.sum(grad_elem * v_elem)
    
    # Second backprop
    return_grads = torch.autograd.grad(elemwise_products, model.parameters(), create_graph=True)

    return return_grads


def s_test2(model, test_z, test_t, train_loader, damp=0.01, scale=25.0,
       recursion_depth=5000):
    """s_test can be precomputed for each test point of interest, and then
    multiplied with grad_z to get the desired value for each training point.
    Here, strochastic estimation is used to calculate s_test. s_test is the
    Inverse Hessian Vector Product.

    Arguments:
        test_z: torch tensor, test data points, such as test images
        test_t: torch tensor, contains all test data labels
        model: torch NN, model used to evaluate the dataset
        train_loader: torch Dataloader, can load the training dataset
        damp: float, dampening factor
        scale: float, scaling factor
        recursion_depth: int, number of iterations aka recursion depth
            should be enough so that the value stabilises.

    Returns:
        h_estimate: list of torch tensors, s_test
    """
    v = grad_z2(model, test_z, test_t)
    h_estimate = v.clone().detach()
    for i in range(recursion_depth):
        for z, t in train_loader:
            y = model(z)
            loss = loss_fn(y, t)
            hv = hvp2(loss, model, h_estimate) 
            h_estimate = [
                _v + (1 - damp) * _h_e - _hv / scale
                for _v, _h_e, _hv in zip(v, h_estimate, hv)]
            break
    return h_estimate

In [None]:
# Step3: 合并所有的结果 计算每一个样本点对test_z的影响值
def calc_influence_function(model, train_set, test_z, test_t):
    """Calculates the influence function

    Arguments:
        grad_z_vecs: list of torch tensor, containing the gradients
            from model parameters to loss
        e_s_test: list of torch tensor, contains s_test vectors

    Returns:
        influence: list of float, influences of all training data samples
            for one test sample
        harmful: list of float, influences sorted by harmfulness
        helpful: list of float, influences sorted by helpfulness.
    """
    train_dataset_size = len(train_set)
    influences = []
    
    grad_z_vecs = [grad_z2(model, z, t) for z, t in train_set] 
    e_s_test = s_test2(model, test_z, test_t, train_set, recursion_depth=1)
    
#     grad_z_vecs = torch.cat([x.flatten() for x in grad_z_vecs], dim=0)
#     e_s_test = torch.cat([x.flatten() for x in e_s_test], dim=0)
    
    # 对第 i 个样本
    for i in range(train_dataset_size):
        influence_i = -sum(
            [
                torch.sum(k * j).data.cpu().numpy()
                for k, j in zip(grad_z_vecs[i], e_s_test)
            ]) / train_dataset_size
        influences.append(influence_i)

    harmful = np.argsort(influences)
    helpful = harmful[::-1]

    return influences, harmful.tolist(), helpful.tolist()