In [1]:
import argparse
import logging
import os
from typing import Tuple

import torch
import torch.nn.functional as F
from torch import nn
from torchvision import models
from torchvision import transforms
import torchvision


#from examples.cifar.pipeline import get_cifar10_dataset
from kronfluence.analyzer import Analyzer, prepare_model
from kronfluence.arguments import FactorArguments
from kronfluence.task import Task
from kronfluence.utils.dataset import DataLoaderKwargs

BATCH_TYPE = Tuple[torch.Tensor, torch.Tensor]

In [2]:
transform = transforms.Compose([
transforms.Resize(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])
])
#采用自带的Cifar100
trainset = torchvision.datasets.CIFAR100(root='./data_CIFAR100', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
 
testset = torchvision.datasets.CIFAR100(root='./data_CIFAR100', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)


Files already downloaded and verified
Files already downloaded and verified


In [3]:
def get_dataset_sample_ids_per_class(class_id, num_samples, test_dataset,
                                     start_index=0):
    sample_list = []
    img_count = 0
    for i in range(len(test_dataset)):
        _, t = test_dataset[i]
        if class_id == t:
            img_count += 1
            if (img_count > start_index) and \
                    (img_count <= start_index + num_samples):
                sample_list.append(i)
            elif img_count > start_index + num_samples:
                break
    return sample_list

def get_dataset_sample_ids(num_samples, test_dataset, num_classes=10,
                           start_index=0):

    sample_dict = {}
    sample_list = []
    if start_index > len(test_dataset) / num_classes:
        logging.warn(f"The variable test_start_index={start_index} is "
                     f"larger than the number of available samples per class.")
    for i in range(num_classes):
        sample_dict[str(i)] = get_dataset_sample_ids_per_class(
            i, num_samples, test_dataset, start_index)
        sample_list[len(sample_list):len(sample_list)] = sample_dict[str(i)]
    return sample_dict, sample_list

sample_dict, sample_list = get_dataset_sample_ids(10, testset, num_classes=10,
                           start_index=0)

all_values = []
for value in sample_dict.values():
    all_values.extend(value)

print(all_values)
print(len(all_values))

[9, 113, 226, 235, 377, 469, 484, 614, 623, 655, 132, 267, 572, 598, 661, 674, 682, 713, 725, 795, 54, 216, 325, 414, 434, 463, 504, 542, 573, 887, 396, 413, 493, 596, 633, 737, 965, 1028, 1055, 1176, 50, 140, 159, 172, 248, 478, 607, 754, 1105, 1312, 334, 441, 459, 657, 1142, 1206, 1219, 1272, 1274, 1314, 51, 83, 200, 204, 258, 268, 381, 404, 468, 535, 64, 198, 303, 329, 347, 376, 458, 586, 648, 686, 27, 28, 116, 161, 319, 417, 476, 532, 554, 644, 52, 114, 170, 340, 455, 575, 594, 864, 932, 997]
100


In [4]:
corrupt_percentage = None
dataset_dir = "./data_CIFAR100"
checkpoint_dir = "./"
query_batch_size = 100
factor_strategy = "ekfac"

In [5]:
class ClassificationTask(Task):
    def compute_train_loss(
        self,
        batch: BATCH_TYPE,
        model: nn.Module,
        sample: bool = False,
    ) -> torch.Tensor:
        inputs, labels = batch
        logits = model(inputs)
        if not sample:
            return F.cross_entropy(logits, labels, reduction="sum")
        with torch.no_grad():
            probs = torch.nn.functional.softmax(logits, dim=-1)
            sampled_labels = torch.multinomial(
                probs,
                num_samples=1,
            ).flatten()
        return F.cross_entropy(logits, sampled_labels.detach(), reduction="sum")

    def compute_measurement(
        self,
        batch: BATCH_TYPE,
        model: nn.Module,
    ) -> torch.Tensor:
        # Copied from: https://github.com/MadryLab/trak/blob/main/trak/modelout_functions.py.
        inputs, labels = batch
        logits = model(inputs)

        bindex = torch.arange(logits.shape[0]).to(device=logits.device, non_blocking=False)
        logits_correct = logits[bindex, labels]

        cloned_logits = logits.clone()
        cloned_logits[bindex, labels] = torch.tensor(-torch.inf, device=logits.device, dtype=logits.dtype)

        margins = logits_correct - cloned_logits.logsumexp(dim=-1)
        return -margins.sum()


In [6]:
model=models.mobilenet_v2(pretrained=False)

model.classifier = nn.Sequential( 
    #重新定义特征层，根据需要可以添加自己想要的Linear层
    nn.Dropout(p=0.2, inplace=False),
    nn.Linear(in_features=1280, out_features=100),  #多加几层都没关系
    #nn.LogSoftmax(dim=1)
)

student_save_path='./Dis_resnet50(T)_mobilebetv2(S)_cifar100_epoch10_withif_1.pkl'
model.load_state_dict(torch.load(student_save_path))



<All keys matched successfully>

In [7]:
# Define task and prepare model.
task = ClassificationTask()
model = prepare_model(model, task)

analyzer = Analyzer(
    analysis_name="cifar100",
    model=model,
    task=task,
)
# Configure parameters for DataLoader.
dataloader_kwargs = DataLoaderKwargs(num_workers=1)
analyzer.set_dataloader_kwargs(dataloader_kwargs)

In [8]:

# Compute influence factors.
factor_args = FactorArguments(strategy=factor_strategy)
analyzer.fit_all_factors(
    factors_name=factor_strategy,
    dataset=trainset,
    per_device_batch_size=None,
    factor_args=factor_args,
    overwrite_output_dir=False,
)

In [9]:

# Compute pairwise scores.
analyzer.compute_pairwise_scores(
    scores_name=factor_strategy,
    factors_name=factor_strategy,
    query_dataset=testset,
    #query_indices=list(range(2000)),
    query_indices=all_values,
    train_dataset=trainset,
    per_device_query_batch_size=query_batch_size,
    overwrite_output_dir=False,
)

{'all_modules': tensor([[-8.9190e+02, -8.7845e-04,  7.8836e+05,  ..., -1.4713e+04,
           2.4061e+04, -9.9291e+01],
         [-3.4495e+02,  3.0169e-04,  3.1101e+05,  ...,  5.6217e+03,
           1.2414e+04,  9.1736e+01],
         [ 3.7183e+03, -1.8812e-03,  5.0622e+05,  ...,  2.0623e+04,
           2.0881e+04,  5.2695e+01],
         ...,
         [-5.9469e+02,  7.2639e-04, -2.7661e+04,  ...,  9.9768e+03,
           2.7745e+04, -1.0639e+02],
         [ 3.7473e+03, -1.0781e-03, -2.6129e+04,  ...,  4.4032e+03,
          -4.5703e+04,  1.4691e+02],
         [-6.3104e+03,  5.0711e-03,  7.4715e+03,  ..., -8.9762e+02,
          -1.2028e+04, -1.2531e+03]])}

In [10]:
scores = analyzer.load_pairwise_scores(factor_strategy)["all_modules"]
scores.shape

torch.Size([100, 50000])

In [12]:
import numpy as np
score_test_mean_1 = [] #每个training data将对应的100个test data原始分数的平均分数

for i in range(scores.shape[1]):
    score_test_mean_1.append(np.mean(scores[:, i]))

print(len(score_test_mean_1)) #长度为50000的list

TypeError: mean() received an invalid combination of arguments - got (dtype=NoneType, out=NoneType, axis=NoneType, ), but expected one of:
 * (*, torch.dtype dtype)
 * (tuple of ints dim, bool keepdim, *, torch.dtype dtype)
 * (tuple of names dim, bool keepdim, *, torch.dtype dtype)
