# Init

In [1]:
import torch
import timm, tome
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
imagenet_data = datasets.CIFAR100('/tmp', train = False, download = True,
                                  transform=transforms.Compose([
                                      transforms.Resize(256),
                                      transforms.CenterCrop(224),
                                      transforms.ToTensor(),
                                      transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                                  ]))
subset = torch.utils.data.Subset(imagenet_data, [i for i in range(10000)])

Files already downloaded and verified


In [3]:
data_loader = DataLoader(subset, batch_size=64, shuffle=False)

In [4]:
# Hàm đánh giá mô hình
def evaluate(model, data_loader, flag):
    total = 0
    correct = 0
    count = 0
    with torch.no_grad():
        for images, labels in tqdm(data_loader):
            images, labels = images.cuda(), labels.cuda()
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            count = count + 1
            if flag and count == 1:
                return 'debug'
    return 100 * correct / total

# Default Model

In [5]:
model = timm.create_model("deit_base_patch16_224.fb_in1k", pretrained=True, num_classes= 100)
model.load_state_dict(torch.load('checkpoints/deit_base_patch16_224.fb_in1k_cifar100.bin'))

<All keys matched successfully>

In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()

print('Done')

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (norm): Identity(

In [33]:
evaluate(model, data_loader, 0)

100%|██████████| 10000/10000 [02:21<00:00, 70.44it/s]


83.81

# Tome Model

In [7]:
tome.patch.timm(model)
model.r = [0, 0, 5, 14, 20, 13, 18, 16, 17, 44, 15, 32]


In [8]:
evaluate(model, data_loader, 0)

100%|██████████| 157/157 [00:56<00:00,  2.78it/s]


82.94

# New Tome (using x)

In [9]:
import tome_x_attn

In [10]:
model = timm.create_model("deit_base_patch16_224.fb_in1k", pretrained=True, num_classes= 100)
model.load_state_dict(torch.load('checkpoints/deit_base_patch16_224.fb_in1k_cifar100.bin'))

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (norm): Identity(

In [41]:
tome_x_attn.patch.timm(model, local_ratio = 1) #[0, 0, 6, 12, 24, 11, 18, 12, 18, 46, 12, 33] - 83.08
model.r = [0, 0, 6, 12, 24, 11, 18, 12, 18, 46, 13, 33]
evaluate(model, data_loader, 0)

100%|██████████| 157/157 [00:44<00:00,  3.51it/s]


83.08

In [8]:
tome_x_attn.patch.timm(model, local_ratio = 2)
model.r = 8
evaluate(model, data_loader, 0)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
100%|██████████| 2000/2000 [01:35<00:00, 20.91it/s]


83.25

In [10]:
tome_x_attn.patch.timm(model, local_ratio = 2)
model.r = 16
evaluate(model, data_loader, 0)

100%|██████████| 2000/2000 [01:12<00:00, 27.56it/s]


80.5

In [17]:
tome_x_attn.patch.timm(model, local_ratio = 2)
model.r = [0, 0, 7, 6, 13, 12, 19, 18, 24, 22, 20, 18] 
evaluate(model, data_loader, 0)

100%|██████████| 2000/2000 [00:55<00:00, 36.34it/s]


83.65

In [17]:
tome_x_attn.patch.timm(model, local_ratio = 1)
model.r = [0, 0, 5, 14, 20, 13, 18, 16, 17, 44, 15, 32]
evaluate(model, data_loader, 0)

 22%|██▏       | 2232/10000 [01:00<03:30, 36.97it/s]


KeyboardInterrupt: 

In [10]:
from torchpq.clustering import MultiKMeans
import torch
from time import time
start_log_time = time()
n_data = 182
n_kmeans = 1
d_vector = 512
x = torch.randn(n_kmeans, d_vector, n_data, device="cuda:0")
kmeans = MultiKMeans(n_clusters=n_data - 8, distance="euclidean")
labels = kmeans.fit(x)
print("Cluster time:", (time() - start_log_time) * 1000, 'ms')

Cluster time: 4.414796829223633 ms


In [4]:
labels

tensor([[ 30,  41,  38,  49, 143, 157,  33, 119,  61,  92, 121,  16,  89,  36,
         102,  71, 135, 160, 164,  47, 163,  26,  55, 137, 169,  84,  37,  42,
         118,  35,  88,  76,  18,   8, 174, 156,  90,  77, 177, 161,  57,  14,
         122, 147, 168,  67,  34,  60, 130,  98, 152, 120,  79, 150, 133,  43,
          83,  24, 159, 179, 116,  52,  95,  13,  45,   4,  72,   0, 182, 129,
         186,  93,  68, 178,   7,  78, 175,  15,   0, 154,  25, 153,  29,  66,
         129,   5, 113,   2,  23, 189, 183, 109,  17,  51, 105, 180,   9, 117,
          80, 124,  53, 158, 145, 136,  97, 138,  39,  58,  54, 151, 167,  65,
         128,  48,  85, 127, 131,  46, 104,  21,  73,  27, 102,  81,  82, 110,
         148,  69, 106, 117,  22, 126,  32,  19,  99, 108, 142,  64,  28,  86,
         133,   3, 114, 103, 185,  75, 107,  56, 173, 123,  62, 172, 188, 171,
         115,  91, 125, 134,  70, 140, 101, 162,  74,  20, 149,  44, 166, 141,
         139,  12, 132,  87,  59,  78, 170,  94, 187

In [None]:
import torch
from torch_cluster import nearest
from torch_cluster import graclus_cluster
import time
metric = torch.load('metric.pt')

In [None]:
metric.shape

In [None]:
st_time = time.time()
metric = metric.squeeze(0)
with torch.no_grad():
    metric = metric / metric.norm(dim=-1, keepdim=True)
    a, b = metric[..., ::2, :], metric[..., 1::2, :]
scores = a @ b.transpose(-1, -2)
indices = torch.nonzero(scores, as_tuple=False)

# Trích xuất row và col từ indices
row = indices[:, 0]
col = indices[:, 1]

# Lấy trọng số tương ứng từ scores
weight = scores[row, col]
cluster = graclus_cluster(row, col, weight)
print('Time process: ', time.time() - st_time)

In [None]:
metric = metric[1].unsqueeze(0)

In [None]:
st_time = time.time()
r = 25
with torch.no_grad():
    metric = metric / metric.norm(dim=-1, keepdim=True)
    even_indices_a = torch.arange(0, 2*r + 3, 2)
    a = metric[:, even_indices_a, :]
    mask = torch.ones(metric.shape[1], dtype=bool)
    mask[even_indices_a] = False
    b = metric[:, mask, :]

print(a.shape)
print(b.shape)

reshaped_a = a.reshape(a.shape[0] * a.shape[1], a.shape[2])
reshaped_b = b.reshape(b.shape[0] * b.shape[1], b.shape[2])

batch_a = []
batch_b = []
for i in range(a.shape[0]):
    batch_a = batch_a + [i] * a.shape[1]
    batch_b = batch_b + [i] * b.shape[1]

batch_x = torch.tensor(batch_a).cuda()
batch_y = torch.tensor(batch_b).cuda()
cluster = nearest(reshaped_a, reshaped_b, batch_x, batch_y)
scr_idx = torch.vstack([torch.arange(2, a.shape[1]) for _ in range(a.shape[0])]).unsqueeze(-1)
dst_idx = torch.remainder(cluster, b.shape[1]).reshape(b.shape[0], a.shape[1])[:, 2:].unsqueeze(-1)
print('Time process: ', time.time() - st_time)

In [None]:
torch.remainder(cluster, b.shape[1]).reshape(b.shape[0], a.shape[1])

In [None]:
cluster.reshape(b.shape[0], a.shape[1])

In [None]:
cluster

# Kmeans

In [None]:
import kmeans

In [None]:
model = timm.create_model("deit_base_patch16_224.fb_in1k", pretrained=True, num_classes= 100)
model.load_state_dict(torch.load('/content/drive/MyDrive/ViT-pytorch/checkpoints/deit_base_patch16_224.fb_in1k_1.bin'))

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

<All keys matched successfully>

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (norm): Identity(

In [None]:
kmeans.patch.timm(model)
model.r = 16

In [None]:
evaluate(model, data_loader, 1)

  0%|          | 0/10000 [00:01<?, ?it/s]

X:  torch.Size([1, 197, 768])
Metric: torch.Size([1, 197, 64])
Attn scores:  torch.Size([1, 197, 197])





RuntimeError: ignored