In [1]:
import os
import torch
import json
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from torch.autograd import grad
from models.vit import ViT
import utils
from PIL import Image
import my_dataset
from pathlib import Path
from torch.utils.data import Dataset
from tqdm import tqdm

In [2]:
size = 32
num_classes = 10
root = "./weights/model-"
ration = 0.3
weight_pth = 80

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
with open("../cifar10/classification.json", 'r', encoding='UTF-8') as f:
    classification = json.loads(f.read())

In [5]:
model = ViT(
    image_size = size,
    patch_size = 4,
    num_classes = num_classes,
    dim = 512,
    depth = 6,
    heads = 8,
    mlp_dim = 512,
    dropout = 0.1,
    emb_dropout = 0.1
).to(device)

In [6]:
weight_path = utils.get_weight_path(root, weight_pth)

In [7]:
model.load_state_dict(torch.load(weight_path, map_location=device))
model.eval()

ViT(
  (to_patch_embedding): Sequential(
    (0): Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=4, p2=4)
    (1): Linear(in_features=48, out_features=512, bias=True)
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (transformer): Transformer(
    (layers): ModuleList(
      (0): ModuleList(
        (0): PreNorm(
          (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (fn): Attention(
            (attend): Softmax(dim=-1)
            (to_qkv): Linear(in_features=512, out_features=1536, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=512, out_features=512, bias=True)
              (1): Dropout(p=0.1, inplace=False)
            )
          )
        )
        (1): PreNorm(
          (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (fn): FeedForward(
            (net): Sequential(
              (0): Linear(in_features=512, out_features=512, bias=True)
              (1): GELU(approximate=none)
    

In [8]:
test_data = utils.read_file("../cifar10/test_data.txt")[:128]
train_data = utils.read_file("../cifar10/train_data.txt")

In [9]:
data_transform = {
        "train": transforms.Compose([
                                    transforms.RandomCrop(32, padding=4),
                                    transforms.Resize(32),
                                    transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
        "test": transforms.Compose([transforms.ToTensor(),
                                   transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])}
test_dataset = my_dataset.MyDataSet_CIFAR_Tracin(images_path=test_data,
                        transform=data_transform["test"])
train_dataset = my_dataset.MyDataSet_CIFAR_Tracin(images_path=train_data,
                        transform=data_transform["train"])

In [10]:
train_loader = torch.utils.data.DataLoader(train_dataset,
                                            batch_size = 1,
                                            collate_fn=train_dataset.collate_fn)

val_loader = torch.utils.data.DataLoader(test_dataset,
                                            batch_size = 128,
                                            shuffle=True,
                                            pin_memory=True,
                                            collate_fn=test_dataset.collate_fn)

In [11]:
loss_function = torch.nn.CrossEntropyLoss()
score_list = []
path_list = []

In [12]:
for step, data in enumerate(val_loader):
    images, labels, path = data
    result_train = model(images.to(device))
    loss_train = loss_function(result_train, labels.to(device))
    grad_z_test = grad(loss_train, model.parameters())
    grad_z_test = utils.get_gradient(grad_z_test, model)

In [13]:
train_loader = tqdm(train_loader)
for step, data in enumerate(train_loader):
    images, labels, path = data
    result = model(images.to(device))
    loss_test = loss_function(result, labels.to(device))
    grad_z_train = grad(loss_test, model.parameters())
    grad_z_train = utils.get_gradient(grad_z_train, model)
    score = utils.tracin_get(grad_z_test, grad_z_train)

    score_list.append(float(score))
    path_list.append(path)

100%|██████████| 35000/35000 [08:55<00:00, 65.32it/s]


In [14]:
path_lists = []
score_lists = []
for index in utils.get_sorted_index(score_list, ration):
    path_lists.append(path_list[index])
    score_lists.append(score_list[index])

In [15]:
score_lists

[62.346473693847656,
 55.578590393066406,
 49.598262786865234,
 48.45658493041992,
 47.651119232177734,
 47.41008758544922,
 47.248756408691406,
 47.06611633300781,
 47.03180694580078,
 43.83526611328125,
 42.88144302368164,
 42.845314025878906,
 42.44669723510742,
 41.78238296508789,
 41.76268005371094,
 40.4931526184082,
 40.263790130615234,
 40.099815368652344,
 39.86412048339844,
 39.50143814086914,
 38.440486907958984,
 37.879878997802734,
 37.590110778808594,
 36.8826789855957,
 36.39841079711914,
 36.18220520019531,
 36.16845703125,
 35.90224838256836,
 35.715003967285156,
 35.48558807373047,
 35.29536437988281,
 35.15685272216797,
 35.142372131347656,
 34.97336959838867,
 34.96332931518555,
 34.107872009277344,
 33.93324279785156,
 33.89389419555664,
 33.16020965576172,
 33.06019592285156,
 32.93285369873047,
 32.709877014160156,
 32.4927978515625,
 32.40046691894531,
 31.983518600463867,
 31.96714210510254,
 31.91322898864746,
 31.802387237548828,
 31.796987533569336,
 31.1945

In [16]:
txt_path = "./tracin_file/checkpoint" +  str(weight_pth) + "_" + str(ration) + ".txt"

In [18]:
# with open(txt_path, 'w') as file:
#         file.truncate(0)
#         for data in path_lists:
#             file.write(str(data[0]) + '\n')
#         print("[Success] Write {0} lines of data in file {1}".format(len(path_lists), txt_path))

[Success] Write 10500 lines of data in file ./tracin_file/checkpoint80_0.3.txt
