In [2]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
import numpy as np

from train_utils import *
from ResNet import *

## CIFAR-10 data

In [3]:
import torchvision
import torchvision.transforms as transforms


# CIFAR-10 per-channel mean/std shared by the train and test pipelines.
# NOTE(review): the checkpoints loaded later presumably expect exactly these
# values — keep them in sync with whatever the models were trained with.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Train-time pipeline: random 32x32 crop after padding each side by 4 pixels,
# random horizontal flip, then tensor conversion and normalization.
transform_train = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
    ]
)

# Eval-time pipeline: no augmentation, only tensor conversion + normalization.
transform_test = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD)]
)


def _load_cifar10(train, transform):
    """Return the CIFAR-10 train (``train=True``) or test split stored under ./data, downloading if missing."""
    return torchvision.datasets.CIFAR10(
        root='./data', train=train, download=True, transform=transform
    )


trainset = _load_cifar10(True, transform_train)
testset = _load_cifar10(False, transform_test)

# NOTE(review): test_kwargs is not defined in this notebook; presumably it is
# provided by `from train_utils import *` — confirm.
test_loader = torch.utils.data.DataLoader(testset, **test_kwargs)

Files already downloaded and verified
Files already downloaded and verified


## Load per-sample scores

In [4]:
import numpy as np

# Per-sample scores for the CIFAR-10 training set. IndexData looks a score up
# by the original dataset index, and the training indices are drawn as
# positions in `scores`, so the array must be 1-D and aligned with `trainset`.
scores = np.load('cifar_forg.npy')
assert scores.ndim == 1 and len(scores) == len(trainset), (
    f'expected 1-D scores of length {len(trainset)}, got shape {scores.shape}'
)

In [5]:
from torch.utils.data import Dataset, DataLoader

class IndexData(Dataset):
    """View of ``dataset`` restricted to ``indices`` that also yields a per-sample score.

    Item ``i`` is ``(data, target, score)``, where ``data, target`` is
    ``dataset[indices[i]]`` and ``score`` is ``scores_data[indices[i]]``.
    Scores are therefore looked up by the *original* dataset index, so
    ``scores_data`` must be aligned with the full ``dataset``.

    Args:
        dataset: indexable returning ``(data, target)`` pairs.
        scores_data: indexable of per-sample scores, aligned with ``dataset``.
        indices: positions in ``dataset`` to expose; ``None`` means all of them.
    """

    def __init__(self, dataset, scores_data, indices=None):
        if indices is None:
            indices = np.arange(len(dataset))
        self.data = dataset
        self.scores_data = scores_data
        self.indices = indices

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, index):
        source_idx = self.indices[index]
        sample, label = self.data[source_idx]
        return sample, label, self.scores_data[source_idx]

In [6]:
# Accumulators for the mean accuracy per subset size: one entry per value of
# bs_list. "const_T" / "ch_T" runs correspond to IS_CONST_T = True / False.
# Each name gets its own, independent empty list.
(
    aver_random_acc_list,
    aver_max_ind_ch_T_acc_list,
    aver_random_ind_ch_T_acc_list,
    aver_max_ind_const_T_acc_list,
    aver_random_ind_const_T_acc_list,
) = ([] for _ in range(5))

In [7]:
def train_with_special_idx(trainset, test_loader, scores, idx, is_const_t=None,
                           teacher_path='teacher.pt', student_path='resnet18.pt'):
    """Train a pretrained ResNet18 student with a frozen teacher on a subset of ``trainset``.

    Args:
        trainset: full training dataset; the subset is selected by ``idx``.
        test_loader: DataLoader passed through to ``train_with_teacher``.
        scores: per-sample scores aligned with ``trainset`` (``IndexData``
            looks them up by the original dataset index).
        idx: indices into ``trainset`` that form the training subset.
        is_const_t: flag forwarded to ``train_with_teacher``. ``None``
            (default) keeps the previous behaviour of reading the
            notebook-global ``IS_CONST_T``.
        teacher_path: checkpoint file for the teacher's weights.
        student_path: checkpoint file used to initialise the student.

    Returns:
        Whatever ``train_with_teacher`` returns; callers in this notebook
        unpack it as a ``(losses, accuracies)`` pair.
    """
    if is_const_t is None:
        is_const_t = IS_CONST_T

    teacher = ResNet18()
    # map_location lets a checkpoint saved on one device (e.g. GPU) load on
    # whatever `device` this kernel uses instead of failing in torch.load.
    teacher.load_state_dict(torch.load(teacher_path, map_location=device))
    teacher.to(device)
    freeze_model(teacher)

    model = ResNet18()
    model.load_state_dict(torch.load(student_path, map_location=device))
    # NOTE(review): presumably train_with_teacher switches the student back to
    # train mode — confirm, otherwise BatchNorm/dropout stay in eval mode.
    model.eval()
    model.to(device)

    subset = IndexData(trainset, scores, idx)
    idx_train_loader = torch.utils.data.DataLoader(subset, **train_kwargs)
    return train_with_teacher(model, teacher, idx_train_loader, test_loader, is_const_t)

In [None]:
bs_list = [200, 1000, 2000, 3000, 4000, 5000]  # training-subset sizes to evaluate

num_repeat = 1  # random subsets drawn (and trained on) per subset size
SEED = 0
rng = np.random.default_rng(SEED)  # seeded so the drawn subsets are reproducible

# Reset the accumulators here so that re-running this cell does not append a
# second batch of results and desynchronise them from bs_list.
aver_random_ind_ch_T_acc_list = []
aver_random_ind_const_T_acc_list = []

for bs in bs_list:
    print(bs)
    random_ind_ch_T_acc_list = []
    random_ind_const_T_acc_list = []
    for _ in range(num_repeat):
        # Uniform sample of `bs` distinct training-set indices; ordering the
        # population by score first would not change a uniform draw.
        random_indexes = rng.choice(len(scores), size=bs, replace=False)

        # train_with_special_idx reads the notebook-global IS_CONST_T.
        IS_CONST_T = True
        random_idx_loss_const_T, random_idx_acc_const_T = train_with_special_idx(
            trainset, test_loader, scores, random_indexes)

        IS_CONST_T = False
        random_idx_loss, random_idx_acc = train_with_special_idx(
            trainset, test_loader, scores, random_indexes)

        # Keep only the last entry of each run's accuracy history.
        random_ind_ch_T_acc_list.append(random_idx_acc[-1])
        random_ind_const_T_acc_list.append(random_idx_acc_const_T[-1])

    aver_random_ind_ch_T_acc_list.append(np.mean(random_ind_ch_T_acc_list))
    aver_random_ind_const_T_acc_list.append(np.mean(random_ind_const_T_acc_list))

200




In [None]:
import numpy as np
import matplotlib.pyplot as plt

# One point per subset size: the mean (over num_repeat runs) final accuracy.
# Plot the aver_* accumulators: the per-repeat lists from the training loop
# only hold values for the last subset size and do not line up with bs_list.
assert len(aver_random_ind_ch_T_acc_list) == len(bs_list) == len(aver_random_ind_const_T_acc_list), (
    'averaged accuracy lists do not match bs_list; re-run the accumulator and training cells'
)

fig, ax = plt.subplots()
ax.plot(bs_list, aver_random_ind_ch_T_acc_list, label='random_ind, T changed')
ax.plot(bs_list, aver_random_ind_const_T_acc_list, label='random_ind, T const')

# Title (Russian): "Accuracy vs. subset size".
ax.set_title('Зависимость accuracy от размера подвыборки')
ax.set_xlabel('bs')
ax.set_ylabel('accuracy')
ax.legend(loc='lower right')
plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Plot the max_ind results from the aver_max_ind_* accumulators (the
# max_ind_*_acc_list names previously used here are never defined). Nothing
# visible in this notebook fills those accumulators, so skip with a message
# unless they hold exactly one value per entry of bs_list.
max_ind_ready = (
    len(aver_max_ind_ch_T_acc_list) == len(bs_list)
    and len(aver_max_ind_const_T_acc_list) == len(bs_list)
)

if not max_ind_ready:
    print('max_ind results are not available (aver_max_ind_* lists do not match bs_list); skipping plot')
else:
    fig, ax = plt.subplots()
    ax.plot(bs_list, aver_max_ind_ch_T_acc_list, label='max_ind, T changed')
    ax.plot(bs_list, aver_max_ind_const_T_acc_list, label='max_ind, T const')

    # Title (Russian): "Accuracy vs. subset size".
    ax.set_title('Зависимость accuracy от размера подвыборки')
    ax.set_xlabel('bs')
    ax.set_ylabel('accuracy')
    ax.legend(loc='lower right')
    plt.show()