In [1]:
import torch

In [2]:
print(torch.__version__)

2.0.0


In [None]:
from transformers import AutoFeatureExtractor, DeiTForImageClassificationWithTeacher
from PIL import Image
import requests
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)
feature_extractor = AutoFeatureExtractor.from_pretrained('facebook/deit-tiny-distilled-patch16-224')
model = DeiTForImageClassificationWithTeacher.from_pretrained('facebook/deit-tiny-distilled-patch16-224')
inputs = feature_extractor(images=image, return_tensors="pt")
outputs = model(**inputs)
logits = outputs.logits
# model predicts one of the 1000 ImageNet classes
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])


In [None]:
from urllib.request import urlopen
from PIL import Image
import timm

img = Image.open(urlopen(
    'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
))

model = timm.create_model('deit3_small_patch16_224.fb_in1k', pretrained=True)
model = model.eval()

# get model specific transforms (normalization, resize)
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

output = model(transforms(img).unsqueeze(0))  # unsqueeze single image into batch of 1

top5_probabilities, top5_class_indices = torch.topk(output.softmax(dim=1) * 100, k=5)


In [None]:
import timm
model = timm.create_model('deit3_small_patch16_224.fb_in1k', pretrained=True)

from qat.export.utils import replace_module_by_name, fetch_module_by_name
from operations.amm_linear import AMMLinear
import torch
from einops import rearrange
# timm.data

In [None]:
from torchinfo import summary
summary(model, (1, 3, 224, 224))

In [None]:

subvec_len=32
k=16
ncodebooks = { # NAME OF VARIABLE IS NOT GOOD,,,,,,,,   this is just C that is number of codebooks
            "attn.qkv": 384 // subvec_len,
            "attn.proj": 384 // subvec_len,
            "mlp.fc1": 384 // subvec_len,
            "mlp.fc2": 1536 // subvec_len
        }
for i in range(12):
    for name in ncodebooks:
        print(i, name)
        layer = model.blocks[i]
        module = fetch_module_by_name(layer, name)
        amm_linear = AMMLinear(
            ncodebooks[name],
            module.in_features,
            module.out_features,
            module.bias is not None,
            k=k
        )
        amm_linear.inverse_temperature_logit.data.copy_(
            torch.tensor(10)
        )
        print(amm_linear.weight.data.shape)
        print(module.weight.data.shape)
        weight = rearrange(module.weight.data, 'o i -> i o')
        weight = rearrange(weight, '(c v) o -> c v o', c=ncodebooks[name], v=subvec_len)
        amm_linear.weight.data.copy_(weight.data)
        amm_linear.bias.data.copy_(module.bias.data)
        replace_module_by_name(layer, name, amm_linear)

# For feature extractor

In [1]:
import timm
import torch
# model = timm.create_model('deit3_small_patch16_224.fb_in1k', pretrained=True) # deit_small_patch16_224
model = timm.create_model("deit3_base_patch16_224", pretrained=True)


  from .autonotebook import tqdm as notebook_tqdm


In [5]:
model.blocks[0].attn.num_heads

12

In [5]:
all_deit_models = timm.list_models('*deit*')
import pprint
pprint.pprint(all_deit_models)

['deit3_base_patch16_224',
 'deit3_base_patch16_384',
 'deit3_huge_patch14_224',
 'deit3_large_patch16_224',
 'deit3_large_patch16_384',
 'deit3_medium_patch16_224',
 'deit3_small_patch16_224',
 'deit3_small_patch16_384',
 'deit_base_distilled_patch16_224',
 'deit_base_distilled_patch16_384',
 'deit_base_patch16_224',
 'deit_base_patch16_384',
 'deit_small_distilled_patch16_224',
 'deit_small_patch16_224',
 'deit_tiny_distilled_patch16_224',
 'deit_tiny_patch16_224']


In [None]:
from torchinfo import summary
summary(model, (1, 3, 224, 224))

In [2]:
collected_inputs = []
def hook_fn(block_idx, module_name):
    def actual_hook(module, input, output):
        B, N, C = input[0].shape
        print(f"Block {block_idx}, Module {module_name}: Input shape is {B, N, C}")
        collected_inputs.append({
            "block_idx": block_idx,
            "module_name": module_name,
            "input": input[0]  # input 是一个 只有一個元素的tuple
        })
    return actual_hook

In [3]:
# 為每個 block 內的 Linear 層添加 hook
for i, block in enumerate(model.blocks):
    block.attn.qkv.register_forward_hook(hook_fn(i, 'qkv'))
    block.attn.proj.register_forward_hook(hook_fn(i, 'proj'))
    block.mlp.fc1.register_forward_hook(hook_fn(i, 'fc1'))
    block.mlp.fc2.register_forward_hook(hook_fn(i, 'fc2'))

In [4]:
x = torch.randn(1, 3, 224, 224)

# 此時模型的前向傳播將觸發 hook，並印出每個 Linear 層的輸出形狀
outputs = model(x)

Block 0, Module qkv: Input shape is (1, 197, 384)
Block 0, Module proj: Input shape is (1, 197, 384)
Block 0, Module fc1: Input shape is (1, 197, 384)
Block 0, Module fc2: Input shape is (1, 197, 1536)
Block 1, Module qkv: Input shape is (1, 197, 384)
Block 1, Module proj: Input shape is (1, 197, 384)
Block 1, Module fc1: Input shape is (1, 197, 384)
Block 1, Module fc2: Input shape is (1, 197, 1536)
Block 2, Module qkv: Input shape is (1, 197, 384)
Block 2, Module proj: Input shape is (1, 197, 384)
Block 2, Module fc1: Input shape is (1, 197, 384)
Block 2, Module fc2: Input shape is (1, 197, 1536)
Block 3, Module qkv: Input shape is (1, 197, 384)
Block 3, Module proj: Input shape is (1, 197, 384)
Block 3, Module fc1: Input shape is (1, 197, 384)
Block 3, Module fc2: Input shape is (1, 197, 1536)
Block 4, Module qkv: Input shape is (1, 197, 384)
Block 4, Module proj: Input shape is (1, 197, 384)
Block 4, Module fc1: Input shape is (1, 197, 384)
Block 4, Module fc2: Input shape is (1, 1

In [6]:
len(collected_inputs)

48

In [8]:
collected_inputs[0]

{'block_idx': 0,
 'module_name': 'qkv',
 'input': tensor([[[ 7.9277e-01, -7.8314e-03, -1.5774e-04,  ..., -1.4133e-02,
           -3.2972e-02, -2.9692e-03],
          [-2.7694e-01,  1.6581e+00, -1.9204e-04,  ...,  2.7576e+00,
            2.2464e-02, -2.3143e-02],
          [-4.0685e-01,  1.4816e+00, -1.7134e-04,  ..., -2.3928e+00,
           -2.7203e-02, -5.5268e-03],
          ...,
          [-5.4404e-01,  3.4143e+00, -1.5433e-04,  ...,  4.9144e+00,
            5.4121e-02,  1.7036e-02],
          [-6.1456e-01,  1.8346e+00, -1.9132e-04,  ...,  7.9029e-01,
            1.0784e-01,  2.8620e-02],
          [-3.7956e-01, -2.8397e+00, -1.6308e-04,  ...,  2.5991e+00,
           -3.4529e-02,  2.1639e-02]]], grad_fn=<NativeLayerNormBackward0>)}

In [9]:
import torch
from networks.LUTDeiT import *
model = create_deit()
model.load_state_dict(torch.load("base.pt"))

<All keys matched successfully>

In [14]:
import numpy as np

# 線性函數: f(x) = 2x
def linear_func(x):
    return 2 * x

# Softmax 函數
def softmax(x):
    exp_x = np.exp(x)
    return exp_x / np.sum(exp_x)

# 定義兩個向量 a 和 b，以及一個標量 c
a = np.array([1.0, 2.0])
b = np.array([2.0, 1.0])
c = 0.5

# 檢查線性函數是否滿足線性性質
print("Linear function:")
print("f(c * a + b) =", linear_func(c * a + b))
print("c * f(a) + f(b) =", c * linear_func(a) + linear_func(b))

# 檢查 Softmax 是否滿足線性性質
print("\nSoftmax function:")
print("f(c * a + b) =", softmax(c * a + b))
print("c * f(a) + f(b) =", c * softmax(a) + softmax(b))


Linear function:
f(c * a + b) = [5. 4.]
c * f(a) + f(b) = [5. 4.]

Softmax function:
f(c * a + b) = [0.62245933 0.37754067]
c * f(a) + f(b) = [0.86552929 0.63447071]


In [4]:
# Standard library imports
from argparse import ArgumentParser
import os
# Third-party imports
import torch
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms

def load_data(batchSize, num_workers):
    batch_size = batchSize
    traindir = os.path.join("/work/u1887834/imagenet/", 'train')
    valdir = os.path.join("/work/u1887834/imagenet/", 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(), 
            transforms.ToTensor(),
            normalize,
        ]))

    val_dataset = datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))
   
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=True, sampler=None)

    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=True, sampler=None)
    return train_loader, val_loader

In [5]:
train_loader, _ = load_data(2, 4)

In [10]:
for index, (x, y) in enumerate(train_loader):
    print(x.shape, y.shape)

torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size

KeyboardInterrupt: 