In [1]:
import torch
from torch.utils.tensorboard import SummaryWriter



In [2]:
from torchvision import datasets, transforms

train_dataset = datasets.MNIST(root='./data', train=True, download=True,
                               transform=transforms.ToTensor())
test_dataset = datasets.MNIST(root='./data', train=False, download=True,
                              transform=transforms.ToTensor(), )

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

In [3]:
import torch.nn as nn

class LogisticRegression(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(LogisticRegression, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        
    def forward(self, x):
        x = x.view(-1, self.linear.in_features)
        outputs = torch.sigmoid(self.linear(x))
        return outputs

In [4]:
writer = SummaryWriter('runs/logistic_regression_10_mnist') 

In [5]:
model = LogisticRegression(28*28, 10)
optimizer = torch.optim.Adam(model.parameters())
criterion = torch.nn.CrossEntropyLoss()

batch_num = 0
for epoch in range(5):
    model.train() # Set the model to training mode
    for batch_data, batch_labels in train_loader:
        optimizer.zero_grad()
        batch_data = batch_data.view(-1, 28*28)
        output = model(batch_data)
        loss = criterion(output, batch_labels)
        writer.add_scalar('training loss', loss, batch_num)
        loss.backward()
        optimizer.step()
        batch_num += 1

In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

val_loss = 0
val_correct = 0
with torch.no_grad():  # Disable gradient calculation during validation
    for data, labels in test_loader:
        # Transfer data to the appropriate device (CPU or GPU)
        data, labels = data.to(device), labels.to(device)

        # Forward pass
        outputs = model(data.view(-1, 28*28))

        # Calculate loss
        loss = criterion(outputs, labels)

        # Update validation metrics (e.g., accuracy)
        val_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        val_correct += (predicted == labels).sum().item()

In [7]:
val_loss /= len(test_loader)
val_accuracy = 100 * val_correct / len(test_dataset)

print(f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%')

Validation Loss: 1.5635, Validation Accuracy: 91.06%


In [8]:
def uniform_sample_dataset(dataset, t):
    dataset_size = len(dataset)
    indices = list(range(dataset_size))
    
    # Shuffle the indices
    sampler = torch.utils.data.sampler.RandomSampler(dataset, num_samples=t)
    sampled_dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, sampler=sampler)
    return sampled_dataloader

In [9]:
sample_loader = uniform_sample_dataset(train_dataset, 10)
for batch_data, batch_labels in sample_loader:
    print(batch_data.shape, batch_labels)

torch.Size([1, 1, 28, 28]) tensor([7])
torch.Size([1, 1, 28, 28]) tensor([1])
torch.Size([1, 1, 28, 28]) tensor([4])
torch.Size([1, 1, 28, 28]) tensor([8])
torch.Size([1, 1, 28, 28]) tensor([9])
torch.Size([1, 1, 28, 28]) tensor([7])
torch.Size([1, 1, 28, 28]) tensor([6])
torch.Size([1, 1, 28, 28]) tensor([2])
torch.Size([1, 1, 28, 28]) tensor([3])
torch.Size([1, 1, 28, 28]) tensor([3])


In [10]:
def calc_criterion_first_order_derivative(data_tensor, label_tensor, criterion, model):
    # Set requires_grad to True for the data tensor to enable gradient computation
    # data_tensor.requires_grad = True
    # label_tensor.requires_grad = True
    model.zero_grad()
    output = model(data_tensor)
    loss = criterion(output, label_tensor)
    # Compute the first-order gradient
    loss.backward(create_graph=True)
    param_grads = [ p.grad.flatten() for p in model.parameters() if p.requires_grad ]
    param_grads = torch.cat(param_grads)
    # print(param_grads)
    # return torch.concat([data_tensor.grad.reshape(1, -1), label_tensor.grad.reshape(1, -1)], dim=1)
    return param_grads

In [11]:
def calc_criterion_second_order_derivative(data_tensor, label_tensor, criterion, model):
    # Set requires_grad to True for the data tensor to enable gradient computation
    # data_tensor.requires_grad = True
    # label_tensor.requires_grad = True
    # loss = criterion(output, label_tensor)
    model.zero_grad()
    output = model(data_tensor)
    loss = criterion(output, label_tensor)
    # Compute the first-order gradient
    loss.backward(create_graph=True)
    first_grads = [ p.grad.flatten() for p in model.parameters() if p.requires_grad ]
    # Compute the second-order gradient
    # first_order_derivative.requires_grad = True
    second_grads = []
    for first_grad in first_grads:
        row = []
        for p in model.parameters():
            sub_matrix = []
            # print(first_grad.shape, p.shape)
            for i in range(first_grad.shape[0]):
                sub_matrix.append(torch.autograd.grad(first_grad[i], p, create_graph=True)[0].flatten())
            sub_matrix = torch.stack(sub_matrix)
            # print(sub_matrix)
            row.append(sub_matrix)
        row = torch.cat(row, dim=1)
        # print(row)
        second_grads.append(row)
    second_grads = torch.cat(second_grads, dim=0)
    # print("sec: ", second_grads)
    # hessian_matrix = torch.autograd.functional.hessian(lambda x: criterion(model(x), label_tensor), data_tensor).reshape(data_tensor.shape[1], -1)
    # hessian_matrix = torch.autograd.functional.hessian(lambda x, y: criterion(model(x), y), (data_tensor, label_tensor))
    return second_grads
    # matrix_list = []
    # for row, var in zip(hessian_matrix, [data_tensor, label_tensor]):
    #     list_row = []
    #     for tensor in row:
    #         if len(var.shape) == 1:
    #             list_row.append(tensor.reshape(1, -1))
    #         else:
    #             list_row.append(tensor.reshape(var.shape[1], -1))
    #     matrix_list.append(list_row)
    # # Concatenate along the first dimension
    # concatenated_hessian = torch.cat([torch.cat(row, dim=1) for row in matrix_list], dim=0)

    # return concatenated_hessian

In [12]:
test_model = torch.nn.Linear(2, 1)
print(test_model.weight, test_model.bias)

Parameter containing:
tensor([[-0.1205,  0.2650]], requires_grad=True) Parameter containing:
tensor([0.2224], requires_grad=True)


In [13]:
data_tensor = torch.tensor([[1.0, 2.0]], requires_grad=True)
label_tensor = torch.tensor([1.0])

2 * (test_model(data_tensor) - label_tensor).item() * data_tensor

tensor([[-0.7360, -1.4720]], grad_fn=<MulBackward0>)

In [14]:
print(calc_criterion_first_order_derivative(data_tensor, label_tensor, torch.nn.MSELoss(), test_model))
print(calc_criterion_second_order_derivative(data_tensor, label_tensor, torch.nn.MSELoss(), test_model))


tensor([-0.7360, -1.4720, -0.7360], grad_fn=<CatBackward0>)
tensor([[2., 4., 2.],
        [4., 8., 4.],
        [2., 4., 2.]], grad_fn=<CatBackward0>)


  return F.mse_loss(input, target, reduction=self.reduction)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


In [15]:
one_train_dataloader = uniform_sample_dataset(train_dataset, 1)
for batch_data, batch_labels in one_train_dataloader:
    print(calc_criterion_first_order_derivative(batch_data, batch_labels, criterion, model).shape)
    print(calc_criterion_second_order_derivative(batch_data.view(-1, 28*28), batch_labels, criterion, model).shape)

torch.Size([7850])
torch.Size([7850, 7850])


In [16]:
def hvp(y, w, v):
    # First backprop
    first_grads = torch.autograd.grad(y, w, retain_graph=True, create_graph=True)

    first_grads = [g.flatten() for g in first_grads]
    first_grads = torch.cat(first_grads, dim=0)
    # print(first_grads[0], first_grads[1].sum())
    
    # Elementwise products
    # print(f"first_grads: {first_grads}, v: {v}, first_grads * v: {first_grads * v}")
    elemwise_products = torch.sum(first_grads * v)
    # elemwise_products = 0
    # for grad_elem, v_elem in zip(first_grads, v):
    #     grad_elem = grad_elem.flatten()
    #     elemwise_products += torch.sum(grad_elem * v_elem)

    # Second backprop
    return_grads = torch.autograd.grad(elemwise_products, w, create_graph=True)
    return_grads = [g.flatten() for g in return_grads]
    # print(return_grads[0].sum(), return_grads[1].sum())
    return_grads = torch.cat(return_grads, dim=0)

    return return_grads

In [17]:
params = [p for p in test_model.parameters()]
vector = calc_criterion_first_order_derivative(data_tensor, label_tensor, torch.nn.MSELoss(), test_model)
print(vector)
# vector = torch.tensor([-3.0305, -6.0611, -3.0305])
expected = torch.matmul(calc_criterion_second_order_derivative(data_tensor, label_tensor, torch.nn.MSELoss(), test_model), vector.T)
vector._grad_fn = None
# print(vector)
test_model.zero_grad()
actual = hvp(torch.nn.MSELoss()(test_model(data_tensor), label_tensor), params, vector)
print(expected, actual)
assert torch.equal(expected, actual)

tensor([-0.7360, -1.4720, -0.7360], grad_fn=<CatBackward0>)
tensor([ -8.8321, -17.6643,  -8.8321], grad_fn=<MvBackward0>) tensor([ -8.8321, -17.6643,  -8.8321], grad_fn=<CatBackward0>)


  expected = torch.matmul(calc_criterion_second_order_derivative(data_tensor, label_tensor, torch.nn.MSELoss(), test_model), vector.T)


In [18]:
hvp_summary_writer = SummaryWriter('runs/hvp_sum_summary') 

In [19]:
import gc

def ihvp(train_dataset, test_data, test_label, model, criterion, t, r):
    hvp_eval_avg = 0
    test_data = test_data
    vector = calc_criterion_first_order_derivative(test_data, test_label, criterion, model)
    # vector = torch.ones((28*28+1)*10)
    print(vector)
    for i in range(r):
        sampled_train_loader = uniform_sample_dataset(train_dataset, t)
        # Step 1. Initialize the evaluation of the Hessian-vector product
        hvp_eval = vector
        data_number = 0
        for data, label in sampled_train_loader:
            print(f"round={i}, data number={data_number}")
            # data, labels = data.to(device), labels.to(device)
            # Step 2. Compute the second order gradient of the loss w.r.t. the model parameters
            model.zero_grad()
            data_tensor = data.view(-1, 28*28)
            params = [p for p in model.parameters()]
            # params_vector = torch.cat(params)
            hvp_eval._grad_fn = None
            print(f"hvp_eval: {hvp_eval}")
            return_grads = hvp(criterion(model(data_tensor), label), params, hvp_eval)
            print(f"return_grads: {return_grads}")
            # print("return grads: ", return_grads.shape)
            # second_order_grad = calc_criterion_second_order_derivative(data_tensor, label, criterion, model)
            # Step 3. Compute the inner product between the gradient and the Hessian-vector product
            # print(torch.matmul((torch.eye(second_order_grad.shape[0]) - second_order_grad), hvp_eval.T).T)
            # print(torch.eye(second_order_grad.shape[0]), second_order_grad.shape, hvp_eval.shape)
            # print(torch.matmul((torch.eye(second_order_grad.shape[0]) - second_order_grad), hvp_eval.T).shape)
            # print(hvp_eval.sum())
            # product = torch.matmul((torch.eye(second_order_grad.shape[0]) - second_order_grad), hvp_eval.T).T
            # print(product)
            # for p in range(len(hvp_evals)):
            #     print(vectors[p].shape, hvp_evals[p], return_grads[p].shape)
            #     hvp_evals[p] = vectors[p] + hvp_evals[p] - return_grads[p]
            #     hvp_summary_writer.add_scalar(f'hvp_eval_sum_{i}', hvp_evals[i].sum(), data_number + i * t)
            hvp_eval = hvp_eval + vector - return_grads
            print(f"sum: {hvp_eval.sum()}")
            hvp_summary_writer.add_scalar(f'hvp_eval_sum', hvp_eval.sum(), data_number + i * t)
            data_number += 1
            gc.collect()
        hvp_eval_avg = i / (i + 1) * hvp_eval_avg + 1 / (i + 1) * hvp_eval
            
    return hvp_eval_avg
    

In [20]:
test_dataloader = uniform_sample_dataset(test_dataset, 1)

In [21]:
hvp_summary_writer = SummaryWriter('runs/hvp_sum_summary') 
for test_data, test_label in test_dataloader:
    hvp_eval = ihvp(train_dataset, test_data, test_label, model, criterion, 5000, 1)

tensor([ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  6.7742e-11,
        -2.9883e-02,  7.0085e-07], grad_fn=<CatBackward0>)
round=0, data number=0
hvp_eval: tensor([ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  6.7742e-11,
        -2.9883e-02,  7.0085e-07])
return_grads: tensor([ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  4.0013e-12,
        -2.1928e-05,  3.3511e-09], grad_fn=<CatBackward0>)
sum: -11.522053718566895
round=0, data number=1
hvp_eval: tensor([ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  1.3148e-10,
        -5.9745e-02,  1.3984e-06])
return_grads: tensor([-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -2.9845e-17,
        -3.0997e-11,  8.1086e-21], grad_fn=<CatBackward0>)
sum: -17.28265380859375
round=0, data number=2
hvp_eval: tensor([ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  1.9923e-10,
        -8.9628e-02,  2.0992e-06])
return_grads: tensor([ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  1.5726e-10,
        -1.2828e-04,  1.0043e-08], grad_fn=<CatBackward0>)
sum: -2

KeyboardInterrupt: 

In [None]:
def upweighting_loss_influence_function(train_dataset, upweighted_data, upweighted_label, test_data, test_label, model, criterion):
    # Step 1. Compute the Hessian-vector product
    hvp_eval = ihvp(train_dataset, test_data, test_label, model, criterion)
    # Step 2. Compute the influence function
    influence = torch.dot(-hvp_eval, calc_criterion_first_order_derivative(upweighted_data, upweighted_label, criterion, model))
    return influence