In [4]:
import sys
import torch
sys.path.append('../')

from src.linear_nn import get_model, load_model, test, load_data
from src.model_eval import train_dataset
from src.model_eval import train_loader
from torch.cpu.amp import autocast
net, criterion, optimizer = get_model()

In [5]:

print(net)
model = load_model(net, filepath='../models/linear_trained_model.pth')
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
model.to(device)

SimpleNN(
  (fc1): Linear(in_features=784, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=10, bias=True)
)
cuda:0


SimpleNN(
  (fc1): Linear(in_features=784, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=10, bias=True)
)

In [6]:
from torch.optim.optimizer import Optimizer

class EKFACDistilled(Optimizer):
    def __init__(self, net, eps):
        self.eps = eps
        self.params = []
        self._fwd_handles = []
        self._bwd_handles = []
        self.net = net
        self.calc_act = True

        for mod in net.modules():
            mod_class = mod.__class__.__name__
            if mod_class in ['Linear']:
                handle = mod.register_forward_pre_hook(self._save_input)
                self._fwd_handles.append(handle)
                handle = mod.register_full_backward_hook(self._save_grad_output)
                self._bwd_handles.append(handle)
                params = [mod.weight]
                if mod.bias is not None:
                    params.append(mod.bias)
                d = {'params': params, 'mod': mod, 'layer_type': mod_class, 'A': [], 'S': []}
                self.params.append(d)
        super(EKFACDistilled, self).__init__(self.params, {})

    def step(self):
        for group in self.param_groups:
            mod = group['mod']
            x = self.state[mod]['x']
            gy = self.state[mod]['gy']

            # Computation of activation cov matrix for batch
            x = x.data.t()

            # Append column of ones to x if bias is not None
            if mod.bias is not None:
                ones = torch.ones_like(x[:1])
                x = torch.cat([x, ones], dim=0)
            
            if self.calc_act:
                # Calculate covariance matrix for activations (A_{l-1})
                group['A'].append(torch.mm(x, x.t()) / float(x.shape[1]))

            # Computation of psuedograd of layer output cov matrix for batch
            gy = gy.data.t()

            # Calculate covariance matrix for layer outputs (S_{l})
            group['S'].append(torch.mm(gy, gy.t()) / float(gy.shape[1]))

    def _save_input(self, mod, i):
        """Saves input of layer to compute covariance."""
        self.state[mod]['x'] = i[0]

    def _save_grad_output(self, mod, grad_input, grad_output):
        """Saves grad on output of layer to compute covariance."""
        self.state[mod]['gy'] = grad_output[0] * grad_output[0].size(0)

In [7]:
import captum._utils.common as common
from captum.influence._core.influence import DataInfluence
from torch.nn import Module
from typing import Any, Dict, List, Union
from torch import Tensor
import torch.distributions as dist
from torch.utils.data import DataLoader, Dataset
import tqdm


class EKFACInfluence(DataInfluence):
    def __init__(
        self,
        module: Module,
        layers: Union[str, List[str]],
        influence_src_dataset: Dataset,
        activation_dir: str,
        model_id: str = "",
        batch_size: int = 1,
        query_batch_size: int = 1,
        cov_batch_size: int = 1,
        **kwargs: Any,
    ) -> None:
        r"""
        Args:
            module (Module): An instance of pytorch model. This model should define all of its
                layers as attributes of the model. The output of the model must be logits for the
                classification task.
            layers (Union[str, List[str]]): A list of layer names for which the influence will
                be computed.
            influence_src_dataset (torch.utils.data.Dataset): Pytorch dataset that is used to create
                a pytorch dataloader to iterate over the dataset. This is the dataset for which we will
                be seeking for influential instances. In most cases this is the training dataset.
            activation_dir (str): Path to the directory where the activation computations will be stored.
            model_id (str): The name/version of the model for which layer activations are being computed.
                Activations will be stored and loaded under the subdirectory with this name if provided.
            batch_size (int): Batch size for the dataloader used to iterate over the influence_src_dataset.
            **kwargs: Any additional arguments that are necessary for specific implementations of the
                'DataInfluence' abstract class.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.module = module
        self.module.to(self.device)
        self.layers = [layers] if isinstance(layers, str) else layers
        self.influence_src_dataset = influence_src_dataset
        self.activation_dir = activation_dir
        self.model_id = model_id
        self.query_batch_size = query_batch_size

        self.influence_src_dataloader = DataLoader(
            self.influence_src_dataset, batch_size=batch_size, shuffle=False
        )
        self.cov_src_dataloader = DataLoader(
            self.influence_src_dataset, batch_size=cov_batch_size, shuffle=True
        )
            
    def influence(
            self,
            query_dataset: Dataset,
            topk: int = 1,
            eps: float = 1e-5,
            **kwargs: Any,
        ) -> Dict:

        influences: Dict[str, Any] = {}
        query_grads: Dict[str, List[Tensor]] = {}
        influence_src_grads: Dict[str, List[Tensor]] = {}

        query_dataloader = DataLoader(
            query_dataset, batch_size=self.query_batch_size, shuffle=False
        )

        layer_modules = [
            common._get_module_from_name(self.module, layer) for layer in self.layers
        ]

        G_list = self._compute_EKFAC_params()

        criterion = torch.nn.NLLLoss(reduction='sum')
        print(f'Cacultating query gradients on trained model')
        for layer in layer_modules:
            query_grads[layer] = []
            influence_src_grads[layer] = []

        for _, (inputs, targets) in tqdm.tqdm(enumerate(query_dataloader), total=len(query_dataloader)):
            self.module.zero_grad()
            inputs = inputs.to(self.device)
            targets = targets.to(self.device)
            outputs = self.module(inputs)

            loss = criterion(outputs, targets.view(-1))
            loss.backward()

            for layer in layer_modules:
                Qa = G_list[layer]['Qa']
                Qs = G_list[layer]['Qs']
                eigenval_diag = G_list[layer]['lambda']
                if layer.bias is not None:
                    grad_bias = layer.bias.grad
                    grad_weights = layer.weight.grad
                    grad_bias = grad_bias.reshape(-1, 1)
                    grads = torch.cat((grad_weights, grad_bias), dim=1)
                else:
                    grads = layer.weight.grad

                p1 = torch.matmul(Qs, torch.matmul(grads, torch.t(Qa)))
                p2 = torch.reciprocal(eigenval_diag+eps).reshape(p1.shape[0], -1)
                ihvp = torch.flatten(torch.matmul(torch.t(Qs), torch.matmul((p1/p2), Qa)))
                query_grads[layer].append(ihvp)

        criterion = torch.nn.CrossEntropyLoss(reduction='none')
        print(f'Cacultating training src gradients on trained model')
        for i, (inputs, targets) in tqdm.tqdm(enumerate(self.influence_src_dataloader), total=len(self.influence_src_dataloader)):
            self.module.zero_grad()
            inputs = inputs.to(self.device)
            targets = targets.to(self.device)
            outputs = self.module(inputs)
            loss = criterion(outputs, targets.view(-1))
            for single_loss in loss:
                single_loss.backward(retain_graph=True)

                for layer in layer_modules:
                    if layer.bias is not None:
                        grad_bias = layer.bias.grad
                        grad_weights = layer.weight.grad
                        grad_bias = grad_bias.reshape(-1, 1)
                        grads = torch.cat([grad_weights, grad_bias], dim=1)
                    else:
                        grads = layer.weight.grad
                    influence_src_grads[layer].append(torch.flatten(grads))

            # Calculate influences by batch to save memory
            for layer in layer_modules:
                query_grad_matrix = torch.stack(query_grads[layer], dim=0)
                influence_src_grad_matrix = torch.stack(influence_src_grads[layer], dim=0)
                tinf = torch.matmul(query_grad_matrix, torch.t(influence_src_grad_matrix))
                tinf = tinf.detach().cpu()
                if layer not in influences:
                    influences[layer] = tinf
                else:
                    influences[layer] = torch.cat((influences[layer], tinf), dim=1)
                influence_src_grads[layer] = []
                
        return influences
            

    def _compute_EKFAC_params(self, n_samples: int = 2):
        ekfac = EKFACDistilled(self.module, 1e-5)
        loss_fn = torch.nn.CrossEntropyLoss(reduction='sum')
        for _, (input, _) in tqdm.tqdm(enumerate(self.cov_src_dataloader), total=len(self.cov_src_dataloader)):
            input = input.to(self.device)
            outputs = self.module(input)
            output_probs = torch.softmax(outputs, dim=-1)
            distribution = dist.Categorical(output_probs)
            for _ in range(n_samples):
                samples = distribution.sample()
                loss = loss_fn(outputs, samples)
                loss.backward(retain_graph=True)
                ekfac.step()
                self.module.zero_grad()
                optimizer.zero_grad()
        
        G_list = {}
        # Compute average A and S
        for group in ekfac.param_groups:
            G_list[group['mod']] = {}
            with autocast():
                A = torch.stack(group['A']).mean(dim=0)
                S = torch.stack(group['S']).mean(dim=0)

                print(f'Activation cov matrix shape {A.shape}')
                print(f'Layer output cov matrix shape {S.shape}')
            
                # Compute eigenvalues and eigenvectors of A and S
                la, Qa = torch.linalg.eigh(A)
                ls, Qs = torch.linalg.eigh(S)

                eigenval_diags = torch.outer(la, ls).flatten(start_dim=0)

            G_list[group['mod']]['Qa'] = Qa
            G_list[group['mod']]['Qs'] = Qs
            G_list[group['mod']]['lambda'] = eigenval_diags
            
        return G_list

In [8]:
precond = EKFACDistilled(net, eps=0.001)
influence = EKFACInfluence(net, layers=['fc1', 'fc2'], influence_src_dataset=train_dataset, activation_dir='activations', model_id='test', batch_size=64, cov_batch_size=64)
criterion = torch.nn.CrossEntropyLoss()

for mod in net.modules():
  mod_class = mod.__class__.__name__
  print(mod_class)
  print("**********************")

SimpleNN
**********************
Linear
**********************
Linear
**********************


In [9]:
# G_list = influence._compute_EKFAC_params(n_samples=3)
# print(G_list)

In [10]:
from torch.utils.data import Subset

test_dataset = Subset(train_dataset, range(500))
influences = influence.influence(test_dataset)


100%|██████████| 750/750 [00:12<00:00, 62.08it/s]


Activation cov matrix shape torch.Size([785, 785])
Layer output cov matrix shape torch.Size([256, 256])
Activation cov matrix shape torch.Size([257, 257])
Layer output cov matrix shape torch.Size([10, 10])
Cacultating query gradients on trained model


100%|██████████| 500/500 [00:03<00:00, 144.34it/s]


Cacultating training src gradients on trained model


100%|██████████| 750/750 [05:10<00:00,  2.41it/s]


In [11]:
for layer in influences:
    print(layer)
    print(influences[layer].shape)
    print(influences[layer][0].shape)
    print(torch.argmax(influences[layer][0]))
    print(torch.argmax(influences[layer][1]))


Linear(in_features=784, out_features=256, bias=True)
torch.Size([500, 48000])
torch.Size([48000])
tensor(42815)
tensor(14462)
Linear(in_features=256, out_features=10, bias=True)
torch.Size([500, 48000])
torch.Size([48000])
tensor(20606)
tensor(32383)


In [12]:
from torchvision import datasets
store_mnist='../data'
train_dataset = datasets.MNIST(root=store_mnist, train=True, download=True)
import matplotlib.pyplot as plt

for layer in influences:
    test_influences = influences[layer].detach().clone()
    for i, influence in enumerate(test_influences):
        print(influence[:10])
        print(torch.max(influence))
        top = torch.argmax(influence)
        influence[top] = 0
        count = 0
        while top != i:
            influence[top] = 0
            count += 1
            top = torch.argmax(influence)
        print(f"top influence found in {count} steps")
    break


tensor([ 92.3462, 102.3696, 123.3151, 138.2845, 136.4028, 121.2249, 125.1984,
        132.8768, 138.6667, 141.0410])
tensor(1094.2822)
top influence found in 39560 steps
tensor([ 6.8856, 25.6878, 31.0755, 35.0516, 36.1526, 32.7021, 32.3929, 34.3370,
        34.1743, 38.2346])
tensor(309.4397)
top influence found in 39432 steps
tensor([17.4726, 23.1118, 59.5723, 61.5336, 61.4453, 52.9404, 54.4008, 55.4892,
        57.8857, 58.9593])
tensor(417.6404)
top influence found in 33857 steps
tensor([11.0698, 14.8960, 16.8871, 41.4456, 42.2462, 36.1389, 36.4174, 36.4267,
        38.9356, 48.7795])
tensor(339.2998)
top influence found in 34559 steps
tensor([-1.7399, -0.9295, -1.7136, -2.1980, 18.2154, 20.8756, 20.6055, 19.1498,
        17.8205, 19.0587])
tensor(82.7157)
top influence found in 12321 steps
tensor([-21.4126, -27.4272, -45.2204, -58.9710, -54.0380, -12.9528, -14.3238,
        -16.1053, -21.8222, -22.7923])
tensor(218.2907)


KeyboardInterrupt: 