In [1]:
import os
import random
import argparse
import json
from warnings import warn
from typing import List, Dict
from pathlib import Path
from functools import partial
from textwrap import wrap
from contextlib import suppress
from statistics import mean, stdev

import numpy as np
from tqdm import tqdm
import wandb
import matplotlib.pyplot as plt
from mpl_toolkits import axes_grid1
from einops import rearrange, reduce
from timm.models import create_model
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import math
from typing import Optional, List, Union

import models
from datasets import build_dataset
from train import get_args_parser, adjust_config, set_seed, set_run_name, count_params

  def vit_tiny_patch16_224(pretrained: bool = False, **kwargs):
  def vit_small_patch16_224(pretrained: bool = False, **kwargs):
  def vit_small_patch8_224(pretrained: bool = False, **kwargs):
  def vit_base_patch16_224(pretrained: bool = False, **kwargs):
  def vit_base_patch8_224(pretrained: bool = False, **kwargs):
  def vit_large_patch16_224(pretrained: bool = False, **kwargs):
  def vit_large_patch14_224(pretrained: bool = False, **kwargs):
  def vit_huge_patch14_224(pretrained: bool = False, **kwargs):
  def vit_base_patch16_224_miil(pretrained: bool = False, **kwargs):
  def vit_medium_patch16_gap_240(pretrained: bool = False, **kwargs):
  def vit_medium_patch16_gap_256(pretrained: bool = False, **kwargs):
  def vit_base_patch16_gap_224(pretrained: bool = False, **kwargs):
  def vit_huge_patch14_gap_224(pretrained: bool = False, **kwargs):
  def vit_giant_patch16_gap_224(pretrained: bool = False, **kwargs):
  def vit_base_patch16_clip_224(pretrained: bool = False, **kwargs):
  d

In [2]:
parser = argparse.ArgumentParser('DeiT training and evaluation script', parents=[get_args_parser()])
parser.add_argument('--compute_attention_average', action='store_true')
parser.add_argument('--compute_attention_cka', action='store_true')
parser.set_defaults(output_dir='results_inference')
args = parser.parse_args(args=[])

args.model = 'topk_deit_tiny_patch16_224.fb_in1k'
args.cfg = 'configs/cub_test3.4p_ft_weakaugs.yaml'
# args.cfg = 'configs/cotton_ft_weakaugs.yaml'
args.device = 'cpu'
args.keep_rate = [0.5]
# args.reduction_loc = [3, 6, 9]
args.train_trainval = True
args.input_size = 224
args.model_depth = 12
# clca
# args.ifa_head = True
# args.clc = True
# args.num_clr = 1
adjust_config(args)
# args.finetune = './results_tiny/{}_topk_deit_tiny_patch16_224.fb_in1k_61.pth'.format(args.dataset_name)
args.finetune = './results_tiny/{}_topk_deit_tiny_patch16_224.fb_in1k_{}_61.pth'.format(args.dataset_name, args.keep_rate[0])
# args.finetune = None

{'dataset_name': 'cub', 'dataset_root_path': '../../data/cub/CUB_200_2011', 'df_train': 'train.csv', 'df_trainval': 'train_val.csv', 'df_val': 'val.csv', 'df_test': 'test_100.csv', 'folder_train': 'images', 'folder_val': 'images', 'folder_test': 'images', 'df_classid_classname': 'classid_classname.csv'}
{'pretrained': True}
{'horizontal_flip': True}


In [3]:
set_seed(args.seed)

dataset_train, args.num_classes = build_dataset(is_train=True, args=args)
dataset_val, _ = build_dataset(is_train=False, args=args)

sampler_train = torch.utils.data.RandomSampler(dataset_train)
sampler_val = torch.utils.data.SequentialSampler(dataset_val)

train_loader = torch.utils.data.DataLoader(
    dataset_train, sampler=sampler_train,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
    pin_memory=args.pin_mem,
    drop_last=False,
)
test_loader = torch.utils.data.DataLoader(
    dataset_val, sampler=sampler_val,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
    pin_memory=args.pin_mem,
    drop_last=False
)

if args.finetune and args.ifa_head and args.clc:
    args.setting = 'ft_clca'
elif args.finetune and args.ifa_head:
    args.setting = 'ft_cla'
elif args.finetune:
    args.setting = 'ft_bl'
else:
    args.setting = 'fz_bl'

print(f"Creating model: {args.model}")
model = create_model(
    args.model,
    pretrained=True,
    pretrained_cfg=None,
    pretrained_cfg_overlay=None,
    num_classes=1000,
    drop_rate=args.drop,
    drop_path_rate=args.drop_path,
    drop_block_rate=None,
    img_size=args.input_size,
    args = args
)
if args.dataset_name.lower() != "imagenet":
    model.reset_classifier(args.num_classes)
if args.num_clr:
    model.add_clr(args.num_clr)
print(model)

model.to(args.device)

model.eval()

if args.finetune:
    checkpoint = torch.load(args.finetune, map_location='cpu')
    model.load_state_dict(checkpoint['model'], strict=True)

Compose(
    Resize(size=(256, 256), interpolation=bicubic, max_size=None, antialias=True)
    RandomCrop(size=(224, 224), padding=None)
    RandomHorizontalFlip(p=0.5)
    ToTensor()
    Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
)
Compose(
    Resize(size=(256, 256), interpolation=bicubic, max_size=None, antialias=True)
    CenterCrop(size=(224, 224))
    ToTensor()
    Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
)
Creating model: topk_deit_tiny_patch16_224.fb_in1k
[] []
TopK(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 192, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=192, out_features=576, bias=True)
        (q_norm): Identity()
        (k_norm): Identi

  checkpoint = torch.load(args.finetune, map_location='cpu')


In [None]:
class CLCATransformerEncoder(nn.Module):
    """
    支援 TopK Token Reduction 和 CLC 的 TransformerEncoder
    保持與 nn.TransformerEncoder 相容的介面
    """
    def __init__(
        self,
        encoder_layer: nn.TransformerEncoderLayer,
        num_layers: int,
        norm: Optional[nn.Module] = None,
        # CLCA 特定參數
        keep_rate: Optional[Union[float, List[float]]] = None,
        reduction_loc: Optional[List[int]] = None,
        use_clc: bool = False
    ):
        super().__init__()
        
        # 基本 Transformer 參數
        self.num_layers = num_layers
        self.norm = norm
        
        # CLCA 參數處理
        self.reduction_loc = reduction_loc or []
        self.use_clc = use_clc
        
        # 處理 keep_rate 格式
        if keep_rate is not None:
            if isinstance(keep_rate, float):
                # 單一值，應用到所有 reduction 位置
                self.keep_rates = [keep_rate] * len(self.reduction_loc)
            elif isinstance(keep_rate, list):
                if len(keep_rate) == 1:
                    # [0.5] -> 應用到所有 reduction 位置
                    self.keep_rates = keep_rate * len(self.reduction_loc)
                else:
                    # [0.5, 0.5, 0.5] -> 直接使用
                    assert len(keep_rate) == len(self.reduction_loc), \
                        f"keep_rate 長度 {len(keep_rate)} 必須與 reduction_loc 長度 {len(self.reduction_loc)} 相同"
                    self.keep_rates = keep_rate
        else:
            self.keep_rates = []
        
        # 建立層列表（複製 encoder_layer）
        self.layers = nn.ModuleList([
            self._copy_encoder_layer(encoder_layer) for _ in range(num_layers)
        ])
        
        # 標記哪些層需要 TopK（注意：reduction_loc 是 1-indexed）
        self.layer_configs = []
        keep_rate_idx = 0
        for i in range(num_layers):
            layer_config = {
                'index': i,
                'use_topk': False,
                'keep_rate': 1.0,
                'use_clc': use_clc
            }
            
            # 檢查是否是 reduction 層（轉換為 0-indexed）
            if (i + 1) in self.reduction_loc:
                layer_config['use_topk'] = True
                layer_config['keep_rate'] = self.keep_rates[keep_rate_idx]
                keep_rate_idx += 1
            
            self.layer_configs.append(layer_config)
        
        # 計算 CLC groups（如果啟用）
        if use_clc:
            self.clc_groups = self._compute_clc_groups()
        else:
            self.clc_groups = []
    
    def _copy_encoder_layer(self, layer: nn.TransformerEncoderLayer):
        """複製一個 encoder layer"""
        # 這裡需要深度複製層的參數
        import copy
        return copy.deepcopy(layer)
    
    def _compute_clc_groups(self):
        """
        根據 reduction_loc 計算 CLC groups
        reduction_loc=[3, 6, 9] -> groups=[[0,1,2,3], [4,5,6], [7,8,9], [10,11]]
        """
        groups = []
        start = 0
        
        for end in self.reduction_loc:
            # end 是 0-indexed，包含該層
            groups.append(list(range(start, end + 1)))
            start = end + 1
        
        # 最後一組（如果還有剩餘的層）
        if start < self.num_layers:
            groups.append(list(range(start, self.num_layers)))
        
        return groups
    
    def forward(self, src, mask=None, src_key_padding_mask=None):
        """
        Forward pass
        在 PyTorch 端保持標準行為，TopK 和 CLC 邏輯會在 HLS 轉換時處理
        """
        output = src
        
        for i, layer in enumerate(self.layers):
            output = layer(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
        
        if self.norm is not None:
            output = self.norm(output)
        
        return output
    
    def get_config(self):
        """返回 CLCA 配置，供 hls4ml 使用"""
        return {
            'num_layers': self.num_layers,
            'reduction_loc': self.reduction_loc,
            'keep_rates': self.keep_rates,
            'use_clc': self.use_clc,
            'clc_groups': self.clc_groups,
            'layer_configs': self.layer_configs
        }

class Transformer4HLS_CLCA(nn.Module):
    """
    包裝 CLCA TransformerEncoder 的完整模型
    相容原本的 Transformer4HLS 介面
    """
    def __init__(
        self,
        d_model: int,
        nhead: int,
        num_encoder_layers: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: str = "relu",
        norm_first: bool = False,
        device: Optional[str] = None,
        # CLCA 特定參數
        keep_rate: Optional[Union[float, List[float]]] = None,
        reduction_loc: Optional[List[int]] = None,
        use_clc: bool = False
    ):
        super().__init__()
        
        # 儲存所有參數
        self.d_model = d_model
        self.nhead = nhead
        self.num_encoder_layers = num_encoder_layers
        self.dim_feedforward = dim_feedforward
        self.dropout = dropout
        self.activation = activation
        self.norm_first = norm_first
        self.device = device
        
        # CLCA 參數
        self.keep_rate = keep_rate
        self.reduction_loc = reduction_loc
        self.use_clc = use_clc
        
        # 初始化 transformer
        self._init_transformer()
    
    def _init_transformer(self):
        # 創建基本的 encoder layer
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.d_model,
            nhead=self.nhead,
            dim_feedforward=self.dim_feedforward,
            dropout=self.dropout,
            activation=self.activation,
            norm_first=self.norm_first,
            device=self.device
        )
        
        # 創建 norm layer
        norm = nn.LayerNorm(self.d_model)
        
        # 根據是否使用 CLCA 功能決定使用哪種 encoder
        if self.keep_rate is not None or self.use_clc:
            # 使用 CLCA encoder
            self.transformer_encoder = CLCATransformerEncoder(
                encoder_layer=encoder_layer,
                num_layers=self.num_encoder_layers,
                norm=norm,
                keep_rate=self.keep_rate,
                reduction_loc=self.reduction_loc,
                use_clc=self.use_clc
            )
            self.is_clca = True
        else:
            # 使用標準 encoder（向後相容）
            self.transformer_encoder = nn.TransformerEncoder(
                encoder_layer=encoder_layer,
                num_layers=self.num_encoder_layers,
                norm=norm
            )
            self.is_clca = False
    
    def forward(self, src):
        return self.transformer_encoder(src)
    
    def get_clca_config(self):
        """獲取 CLCA 配置（如果有的話）"""
        if self.is_clca:
            return self.transformer_encoder.get_config()
        return None

torch.manual_seed(0)
# 1. 原本的使用方式（不啟用 CLCA）
# model4hls = Transformer4HLS_CLCA(
#     d_model=192,
#     nhead=3,
#     num_encoder_layers=12,
#     dim_feedforward=768,
#     dropout=0,
#     activation='gelu',
#     norm_first=True,
#     device='cpu'
#     # 沒有指定 keep_rate, reduction_loc, use_clc
# )

# 2. 啟用 TopK，使用單一 keep_rate
# model4hls = Transformer4HLS_CLCA(
#     d_model=192,
#     nhead=3,
#     num_encoder_layers=12,
#     dim_feedforward=768,
#     dropout=0,
#     activation='gelu',
#     norm_first=True,
#     device='cpu',
#     keep_rate=0.5,  # 單一值，應用到所有 reduction 位置
#     reduction_loc=[3, 6, 9]  # 在第 4, 7, 10 層後 reduction
# )

# 4. 啟用 TopK + CLC
model4hls = Transformer4HLS_CLCA(
    d_model=192,
    nhead=3,
    num_encoder_layers=12,
    dim_feedforward=768,
    dropout=0,
    activation='gelu',
    norm_first=True,
    device='cpu',
    keep_rate=[0.5],  # 或 0.5
    reduction_loc=[3, 6, 9],
    use_clc=True  # 啟用 CLC
)

model4hls.eval()

for i in range(args.model_depth):
    model4hls.transformer_encoder.layers[i].self_attn.in_proj_weight    = model.blocks[i].attn.qkv.weight
    model4hls.transformer_encoder.layers[i].self_attn.in_proj_bias      = model.blocks[i].attn.qkv.bias
    model4hls.transformer_encoder.layers[i].self_attn.out_proj.weight   = model.blocks[i].attn.proj.weight
    model4hls.transformer_encoder.layers[i].self_attn.out_proj.bias     = model.blocks[i].attn.proj.bias
    model4hls.transformer_encoder.layers[i].linear1.weight              = model.blocks[i].mlp.fc1.weight
    model4hls.transformer_encoder.layers[i].linear1.bias                = model.blocks[i].mlp.fc1.bias
    model4hls.transformer_encoder.layers[i].linear2.weight              = model.blocks[i].mlp.fc2.weight
    model4hls.transformer_encoder.layers[i].linear2.bias                = model.blocks[i].mlp.fc2.bias
    model4hls.transformer_encoder.layers[i].norm1.weight                = model.blocks[i].norm1.weight
    model4hls.transformer_encoder.layers[i].norm1.bias                  = model.blocks[i].norm1.bias
    model4hls.transformer_encoder.layers[i].norm2.weight                = model.blocks[i].norm2.weight
    model4hls.transformer_encoder.layers[i].norm2.bias                  = model.blocks[i].norm2.bias
model4hls.transformer_encoder.norm.weight   = model.norm.weight
model4hls.transformer_encoder.norm.bias     = model.norm.bias

# torch.save(model4hls, './model4hls_{}.pth'.format(args.input_size))

In [11]:
from torch.fx import symbolic_trace

traced_model = symbolic_trace(model4hls)

for node in traced_model.graph.nodes:
    print(node)

src
transformer_encoder_layers_0
transformer_encoder_layers_1
transformer_encoder_layers_2
transformer_encoder_layers_3
transformer_encoder_layers_4
transformer_encoder_layers_5
transformer_encoder_layers_6
transformer_encoder_layers_7
transformer_encoder_layers_8
transformer_encoder_layers_9
transformer_encoder_layers_10
transformer_encoder_layers_11
transformer_encoder_norm
output


In [12]:
for idx in range(1):
    random_tensor = torch.randn(1, 3, args.input_size, args.input_size)
    # 將 random_tensor 移動到與模型相同的設備
    random_tensor = random_tensor.to(args.device)
    model4hls.to(args.device)
    # print(random_tensor)
    with torch.no_grad():
        x = model.patch_embed(random_tensor)
        x = model._pos_embed(x)
        x = model.patch_drop(x)
        x = model.norm_pre(x)
        print('Input shape of encoders = {}'.format(x.shape))
        out = x
        out2 = x
        # out, left_token, sample_idx, compl = model.blocks[0](x)
        # out2 = model4hls.transformer_encoder.layers[0](x.permute(1, 0, 2))
        for i, blk in enumerate(model.blocks):
            # print('Processing block {}'.format(i))
            # out, left_token, sample_idx, compl = blk(out) # for evit
            out, left_token, sample_idx = blk(out) # for topk 
        out = model.norm(out)
        out2 = model4hls(out2.permute(1, 0, 2))
        out2 = out2.permute(1, 0, 2)
        print(out.shape)
        print(out2.shape)
        print(out)
        print(out2)
        difference = (out - out2).max()
        
        print('Difference between pytorch model and model4hls = {}'.format(difference))

Input shape of encoders = torch.Size([1, 197, 192])
torch.Size([1, 197, 192])
torch.Size([1, 197, 192])
tensor([[[ 0.6368, -0.9250, -2.5562,  ...,  0.0086, -1.8008, -0.8929],
         [-0.7014,  1.5679,  1.9164,  ...,  2.0516,  1.2177,  0.7111],
         [-1.0806,  1.3600,  1.1683,  ...,  2.0420,  0.7527,  0.5741],
         ...,
         [-0.0988,  1.1558,  3.5726,  ...,  3.5264,  0.6754, -0.0927],
         [-0.7743,  0.0852,  2.0394,  ...,  1.2179,  3.1379,  0.9270],
         [-0.2831,  0.5343,  2.6471,  ...,  3.0796,  1.0631,  0.4195]]])
tensor([[[ 0.6368, -0.9249, -2.5562,  ...,  0.0086, -1.8008, -0.8929],
         [-0.7015,  1.5679,  1.9164,  ...,  2.0516,  1.2178,  0.7111],
         [-1.0807,  1.3600,  1.1682,  ...,  2.0419,  0.7527,  0.5741],
         ...,
         [-0.0989,  1.1559,  3.5725,  ...,  3.5264,  0.6754, -0.0927],
         [-0.7743,  0.0852,  2.0395,  ...,  1.2180,  3.1378,  0.9270],
         [-0.2833,  0.5343,  2.6471,  ...,  3.0796,  1.0632,  0.4194]]])
Difference b