<h1 align="center"> Channel Adaptive Vision Transformer: How to Use </h1>

This notebook is a step-by-step guide on how to use the Channel Adaptive Vision Transformer (ChAdaViT) model for image classification. ChAdaViT is a vision transformer that accepts images with different numbers of channels as input and projects them into the same embedding space. This is particularly useful when working with multi-channel images, such as medical microscopy images or geospatial images with multiple modalities.
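To make the idea concrete, here is a minimal sketch of channel-wise patch embedding. It assumes, based on the `token_learner` printed further down (`Conv2d(1, 192, kernel_size=(16, 16), stride=(16, 16))`), that every channel is patch-embedded separately by one shared single-channel convolution. `embed_channels` is a hypothetical helper, not part of the ChAdaViT code, and it leaves out any positional or channel embeddings as well as the transformer blocks.

```python
import torch
import torch.nn as nn

PATCH, DIM = 16, 192
# One shared projection for single-channel inputs (hypothetical stand-in for ChAdaViT's token learner)
proj = nn.Conv2d(1, DIM, kernel_size=PATCH, stride=PATCH)

def embed_channels(img: torch.Tensor) -> torch.Tensor:
    """Turn a (C, H, W) image into (C * num_patches, DIM) tokens, whatever C is."""
    c = img.shape[0]
    x = img.unsqueeze(1)               # (C, 1, H, W): each channel becomes its own 1-channel image
    x = proj(x)                        # (C, DIM, H/PATCH, W/PATCH)
    x = x.flatten(2).transpose(1, 2)   # (C, num_patches, DIM)
    return x.reshape(c * x.shape[1], DIM)

with torch.no_grad():
    tokens_3 = embed_channels(torch.randn(3, 224, 224))
    tokens_7 = embed_channels(torch.randn(7, 224, 224))
print(tokens_3.shape, tokens_7.shape)  # torch.Size([588, 192]) torch.Size([1372, 192])
```

Because the projection only ever sees one channel at a time, its weights do not depend on the number of channels; only the sequence length grows, by (224 / 16)² = 196 tokens per channel for a 224×224 image.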

In [27]:
import torch
import torch.nn as nn
import numpy as np
import hashlib

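from src.backbones.vit.chada_vit import ChAdaViT
# This runs the __init__.py files of src, backbones and vit in turn; since src/__init__.py already
# imports the latter two, they are not executed a second time (Python caches modules in sys.modules).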

In [28]:
# Check device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cpu


## Download weights
You can download the model weights from this URL: https://drive.google.com/file/d/1SUfUwerHJlf0vo9mdgM0mRn9TNZkaqXl/view?usp=drive_link   
Make sure to save the file in the same directory as this notebook, with read permissions set correctly.

Enter the path of the weights:

In [29]:
CKPT_PATH = "weights.ckpt"

You can check the hash of the downloaded file here:

In [30]:
def check_hash(file_path, expected_hash):
    md5 = hashlib.md5()
    with open(file_path, "rb") as f:
        while chunk := f.read(4096):
            md5.update(chunk)
    return md5.hexdigest() == expected_hash

In [31]:
check_hash(CKPT_PATH, "e8a24ac58b8e34bdce10e0024d507f2e")

True
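`check_hash` reads the file in 4096-byte chunks, so even a large checkpoint never has to fit in memory at once. Below is a self-contained sketch of the same streaming approach. `md5_of_file` is a hypothetical name. On a temporary file of random bytes, the sketch checks that chunked hashing gives the same digest as hashing all the bytes in one call.

```python
import hashlib
import os
import tempfile

def md5_of_file(path: str, chunk_size: int = 4096) -> str:
    """Stream the file in fixed-size chunks so large checkpoints never sit fully in memory."""
    md5 = hashlib.md5()
    with open(path, "rb") as f:
        while chunk := f.read(chunk_size):  # empty bytes at EOF ends the loop
            md5.update(chunk)
    return md5.hexdigest()

payload = os.urandom(10_000)  # not a multiple of 4096, so the last chunk is partial
with tempfile.NamedTemporaryFile(delete=False) as tmp:
    tmp.write(payload)
    tmp_path = tmp.name

streamed = md5_of_file(tmp_path)
one_shot = hashlib.md5(payload).hexdigest()
os.remove(tmp_path)
print(streamed == one_shot)  # True
```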

## Params

In [32]:
# Params
PATCH_SIZE = 16
EMBED_DIM = 192
RETURN_ALL_TOKENS = False
MAX_NUMBER_CHANNELS = 10

## Load State Dict

In [33]:
model = ChAdaViT(
    patch_size=PATCH_SIZE,
    embed_dim=EMBED_DIM,
    return_all_tokens=RETURN_ALL_TOKENS,
    max_number_channels=MAX_NUMBER_CHANNELS,
)

In [36]:
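assert (
    CKPT_PATH.endswith(".ckpt")
    or CKPT_PATH.endswith(".pth")
    or CKPT_PATH.endswith(".pt")
)  # make sure the checkpoint file has an expected extension
# Note: since PyTorch 2.6, torch.load defaults to weights_only=True; if loading this checkpoint
# fails for that reason, pass weights_only=False (only for files from a trusted source).
state = torch.load(CKPT_PATH, map_location="cpu")["state_dict"]
# Extract the pretrained model parameters from the checkpoint into `state`.
# Rename "encoder.xxx" / "backbone.xxx" keys to "xxx" (strip the prefix) so they match the parameter
# names of `model`; all other keys are deleted. The renamed key is reused in the second step,
# otherwise "encoder.xxx" keys would end up as "backbone.xxx" and never match.
for k in list(state.keys()):
    new_k = k.replace("encoder.", "backbone.")
    if "backbone." in new_k:
        state[new_k.replace("backbone.", "")] = state[k]
    del state[k]
# `state` also holds the parameters of the DINO teacher model (keys starting with "momentum_"), which were
# obtained from the student's parameters by EMA during training. Only the student is needed here, so
# strict=False loads the parameters that `model` actually has and skips the teacher's.
model.load_state_dict(state, strict=False)
model.to(device)
model.eval()
print(model)
# TODO: Is only the "Moyen" (medium) model available, with no "petite" or "grand" variant? If so, changing the
# token dimension or the number of attention heads is not possible, since no pretrained weights match that architecture.
# TODO: Do the pretrained weights only fit 10 channels, i.e. can the model not be built for just 2 channels?
# Answer: yes, it cannot, presumably because the weights are sized for MAX_NUMBER_CHANNELS = 10. Images with
# fewer channels (e.g. 2) can still be fed to this model, as the random-data test below does.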

ChAdaViT(
  (token_learner): TokenLearner(
    (proj): Conv2d(1, 192, kernel_size=(16, 16), stride=(16, 16))
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): ModuleList(
    (0-11): 12 x TransformerEncoderLayer(
      (self_attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=192, out_features=192, bias=True)
      )
      (linear1): Linear(in_features=192, out_features=2048, bias=True)
      (dropout): Dropout(p=0.0, inplace=False)
      (linear2): Linear(in_features=2048, out_features=192, bias=True)
      (norm1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
      (dropout1): Dropout(p=0.0, inplace=False)
      (dropout2): Dropout(p=0.0, inplace=False)
    )
  )
  (norm): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
  (head): Identity()
)
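To see what the key renaming does without the real checkpoint, here is a sketch on a toy dict. The key names (`encoder.`, `backbone.`, `momentum_backbone.`, `head.` prefixes) are assumptions about the checkpoint's layout, and the integers stand in for weight tensors. `remap_keys` is a hypothetical helper that maps `encoder.*` keys to `backbone.*` before stripping the prefix.

```python
def remap_keys(state: dict) -> dict:
    """Keep only the student backbone's weights and strip their prefix (a sketch, not the repository's code)."""
    out = {}
    for k, v in state.items():
        k = k.replace("encoder.", "backbone.")  # treat encoder.* exactly like backbone.*
        if k.startswith("backbone."):
            out[k[len("backbone."):]] = v        # e.g. backbone.blocks.0.x -> blocks.0.x
    return out                                   # momentum_* (EMA teacher) and other keys are dropped

# Hypothetical checkpoint keys
toy_state = {
    "backbone.blocks.0.linear1.weight": 1,
    "encoder.norm.weight": 2,
    "momentum_backbone.blocks.0.linear1.weight": 3,
    "head.mlp.weight": 4,
}
print(remap_keys(toy_state))  # {'blocks.0.linear1.weight': 1, 'norm.weight': 2}
```

`startswith` is stricter than the substring test in the notebook cell: it drops the `momentum_` keys outright instead of renaming them to unused `momentum_*` names. With `strict=False`, it is also worth keeping the value returned by `model.load_state_dict`. Its `missing_keys` and `unexpected_keys` fields show which parameters were actually matched.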


## Generate Random Images (Optional)
If you are here, you probably want to test the model with your own images :)  
Either way, you can use the following code to generate random images with different numbers of channels and quickly check that the model works as expected.

In [42]:
def generate_data(num_images: int, max_num_channels=MAX_NUMBER_CHANNELS):
    imgs = []
    labels = []
    for _ in range(num_images):  # one iteration per image
        num_channels = np.random.randint(1, max_num_channels + 1)  # random channel count in [1, max_num_channels]
        imgs.append(torch.randn(num_channels, 224, 224))  # random pixel values, shape (C, 224, 224)
        # Dummy label of shape (1,); torch.randint's upper bound is exclusive, so randint(0, 1, ...) always yields 0
        labels.append(torch.randint(0, 1, (1,)))
    data = list(zip(imgs, labels))  # pair each (C, H, W) image with its (1,) label: a list of (img, label) tuples
    return data

In [43]:
data = generate_data(num_images=10, max_num_channels=MAX_NUMBER_CHANNELS)  # 10 images with varying channel counts, at most 10 channels each
imgs, labels = zip(*data)
# zip(*data) unpacks the list of tuples into zip((img1, label1), (img2, label2), ...), which groups all
# first elements and all second elements together: imgs = (img1, img2, ...), labels = (label1, label2, ...)
distribution = {}
for img in imgs:
    num_channels = img.shape[0]
    distribution[num_channels] = distribution.get(num_channels, 0) + 1
print(
    f"Number of generated images: {len(imgs)} \n Distribution of number of channels: {distribution}"
)

Number of generated images: 10 
 Distribution of number of channels: {9: 1, 2: 1, 10: 2, 4: 1, 1: 1, 7: 1, 6: 2, 3: 1}
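The unpacking and counting in the cell above can be reproduced with the standard library alone. The following is a small sketch on toy stand-ins: strings replace the image tensors, and the channel counts are the ones printed for `collated_batch[2]` further down. `collections.Counter` replaces the manual `dict.get` loop.

```python
from collections import Counter

# Toy (image, label) pairs: strings stand in for the (C, H, W) tensors
toy_data = [("img_a", 0), ("img_b", 1), ("img_c", 0)]
toy_imgs, toy_labels = zip(*toy_data)  # same as zip(("img_a", 0), ("img_b", 1), ("img_c", 0))
print(toy_imgs)    # ('img_a', 'img_b', 'img_c')
print(toy_labels)  # (0, 1, 0)

# Counter keeps first-seen order, so dict() gives the same layout as the manual loop
channel_counts = [9, 2, 10, 4, 1, 7, 6, 3, 10, 6]
toy_distribution = dict(Counter(channel_counts))
print(toy_distribution)  # {9: 1, 2: 1, 10: 2, 4: 1, 1: 1, 7: 1, 6: 2, 3: 1}
```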


## Prepare Data

One of the key elements of the ChAdaViT model is its ability to adapt to different numbers of channels. In this section, we prepare the data to be fed into the model with a custom collate function. It splits every image into single-channel images and stacks all of them along the batch dimension. It also records how many channels each image had. The same function can be used with your own dataset, e.g. as the `collate_fn` of a PyTorch `DataLoader`.

In [44]:
# Split a batch of images into one stack of single-channel images, a tensor of labels,
# and a mapping giving the number of channels of each image
def collate_images(batch: list):
    """
    Collate a batch of images into a list of channels, a list of labels and a mapping of the number of channels per image.

    Args:
        batch (list): A list of tuples of (img, label), where img is a (C, H, W) tensor and label a (1,) tensor

    Return:
        channels_list (torch.Tensor): A tensor of shape (total_num_channels, 1, height, width)
        labels_list (torch.Tensor): A tensor of shape (batch_size, )
        num_channels_list (list): A list of the number of channels per image
    """
    num_channels_list = []  # number of channels of each image, length batch_size
    channels_list = []
    labels_list = []  # label of each image, length batch_size

    # Iterate over the list of images and extract the channels
    for image, label in batch:
        labels_list.append(label)
        num_channels = image.shape[0]
        num_channels_list.append(num_channels)

        for channel in range(num_channels):
            # Take one channel, shape (H, W), and add a leading dimension -> (1, H, W)
            channel_image = image[channel, :, :].unsqueeze(0)
            channels_list.append(channel_image)

    # Concatenate along dim 0 -> (total_num_channels, H, W), where total_num_channels is the sum of the
    # channel counts of all images in the batch, then insert a size-1 channel dimension at dim 1.
    # Each channel is thus treated as its own single-channel image: the patch embedding runs once per
    # channel, instead of once per multi-channel image as in a standard ViT. The result plays the role
    # of the usual (B, C, H, W) input, except that each sample now spans several rows (one per channel).
    channels_list = torch.cat(channels_list, dim=0).unsqueeze(1)  # Shape: (total_num_channels, 1, height, width)
    batched_labels = torch.tensor(labels_list)  # list of (1,) label tensors -> (batch_size,) tensor

    return channels_list, batched_labels, num_channels_list
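The nested loop above is equivalent to concatenating the whole images along dimension 0, since a (C, H, W) image already consists of C stacked (H, W) channels. Below is a sketch of that shortcut, plus the inverse operation that `num_channels_list` makes possible. `collate_channels` is a hypothetical name. Unlike the cell above, it concatenates labels with `torch.cat` instead of `torch.tensor`.

```python
import torch

def collate_channels(batch):
    """Loop-free variant of the per-channel collate: (C_i, H, W) images -> (sum of C_i, 1, H, W)."""
    imgs, labels = zip(*batch)
    num_channels = [img.shape[0] for img in imgs]
    stacked = torch.cat(imgs, dim=0).unsqueeze(1)  # every channel becomes its own 1-channel image
    return stacked, torch.cat(labels), num_channels

# Three toy images with 3, 1 and 5 channels and dummy (1,)-shaped labels
toy_batch = [(torch.randn(c, 32, 32), torch.zeros(1, dtype=torch.long)) for c in (3, 1, 5)]
x_toy, y_toy, counts = collate_channels(toy_batch)
print(x_toy.shape, y_toy.shape, counts)  # torch.Size([9, 1, 32, 32]) torch.Size([3]) [3, 1, 5]

# torch.split with the recorded counts undoes the flattening, image by image
per_image = torch.split(x_toy.squeeze(1), counts, dim=0)
print([t.shape[0] for t in per_image])             # [3, 1, 5]
print(torch.equal(per_image[2], toy_batch[2][0]))  # True
```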

In [45]:
collated_batch = collate_images(data)  # data is a list of (img, label) tuples: img is a (C, H, W) tensor, label a (1,) tensor

In [46]:
collated_batch[2]

[9, 2, 10, 4, 1, 7, 6, 3, 10, 6]

## Extract Features

In [47]:
@torch.no_grad()
def extract_features(
    model: nn.Module,
    batch: tuple,
    mixed_channels: bool,
    return_all_tokens: bool,
):
    """
    Forwards a collated batch of images and extracts the features from the backbone.

    Args:
        model (nn.Module): The model to forward the images through.
        batch (tuple): The output of `collate_images`: (X, targets, list_num_channels), where X has shape
            (total_num_channels, 1, height, width) and list_num_channels gives the number of channels per image.
        mixed_channels (bool): Whether the images have mixed numbers of channels or not.
        return_all_tokens (bool): Whether to return all tokens or not.

    Returns:
        feats (torch.Tensor): The extracted features, one row per image.
    """
    model.eval()

    # Overwrite model "mixed_channels" parameter for evaluation on "normal" datasets with uniform channels size
    # (ChAdaViT's constructor in chada_vit.py does not define this attribute; it is assigned from outside here)
    model.mixed_channels = mixed_channels

    X, targets, list_num_channels = batch  # channels_list, batched_labels, num_channels_list from collate_images
    # X = X.to(device, non_blocking=True)  # jiabang's change: my DataLoader does not set pin_memory=True, so non_blocking=True is dropped
    X = X.to(device)
    targets = targets.to(device)  # jiabang's change: targets also go to the device for supervised fine-tuning (they are not returned here)
    # index is usually 0. The original note reads it as the position where this batch starts; since
    # list_num_channels is wrapped in a one-element list, it presumably selects that single entry.
    feats = model(x=X, index=0, list_num_channels=[list_num_channels])
    if not mixed_channels:
        if return_all_tokens:
            # Group the token embeddings per image
            chunks = feats.view(sum(list_num_channels), -1, feats.shape[-1])
            chunks = torch.split(chunks, list_num_channels, dim=0)
            # Stack the chunks along a new batch dimension
            feats = torch.stack(chunks, dim=0)
        # Assuming tensor is of shape (batch_size, num_tokens, backbone_output_dim)
        # Keep the batch dimension and flatten the rest into one vector per image, e.g. (B, C, H, W) -> (B, C*H*W);
        # without return_all_tokens this gives (B, 1 * backbone_output_dim)
        feats = feats.flatten(start_dim=1)

    return feats
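The `return_all_tokens` reshaping presumably sits under `if not mixed_channels` because `torch.stack` needs every image to contribute a chunk of the same size. Here is a toy-sized sketch of the same reshape. It assumes the backbone returns `sum(list_num_channels) * T` token rows of width D; `B`, `C`, `T` and `D` are made-up sizes.

```python
import torch

B, C, T, D = 2, 3, 4, 8                # toy sizes: 2 images, 3 channels each, 4 tokens per channel, width 8
list_num_channels = [C] * B            # uniform channel counts: the `not mixed_channels` case
feats = torch.randn(B * C * T, D)      # stand-in for the backbone output (assumed: one row per token)

chunks = feats.view(sum(list_num_channels), -1, feats.shape[-1])  # (B*C, T, D)
chunks = torch.split(chunks, list_num_channels, dim=0)            # B tensors of shape (C, T, D)
per_image = torch.stack(chunks, dim=0)                            # (B, C, T, D)
flat = per_image.flatten(start_dim=1)                             # (B, C*T*D)
print(per_image.shape, flat.shape)  # torch.Size([2, 3, 4, 8]) torch.Size([2, 96])

# With mixed channel counts the chunks differ in size and torch.stack rejects them
mixed_chunks = torch.split(torch.randn(5, T, D), [2, 3], dim=0)
stack_failed = False
try:
    torch.stack(mixed_chunks, dim=0)
except RuntimeError:
    stack_failed = True
print("stack failed:", stack_failed)  # stack failed: True
```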

In [48]:
extracted_features = extract_features(
    model=model,
    batch=collated_batch,
    mixed_channels=True,  # whether images within one batch can have different channel counts (jiabang's alert: usually False in my case)
    return_all_tokens=RETURN_ALL_TOKENS,  # all tokens are rarely needed; only the cls_token is used here
)

In [49]:
assert extracted_features.shape[0] == len(
    collated_batch[2]
)  # num_embeddings == num_images, even with different numbers of channels: the batch size is preserved
print(
    f"{extracted_features.shape[0]} embeddings of dim {extracted_features.shape[1]} were extracted."
)

10 embeddings of dim 192 were extracted.
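The notebook's stated goal is image classification, and the comments mention supervised fine-tuning, so a natural next step is to train a classifier on these embeddings. Below is a minimal linear-probe sketch. It uses random stand-in features and labels and a hypothetical `NUM_CLASSES`. It keeps the backbone frozen, which is simpler than full fine-tuning.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
FEAT_DIM, NUM_CLASSES = 192, 3             # NUM_CLASSES is a made-up value for illustration
toy_feats = torch.randn(10, FEAT_DIM)      # stand-in for `extracted_features` (10 images x 192 dims)
toy_targets = torch.randint(0, NUM_CLASSES, (10,))

probe = nn.Linear(FEAT_DIM, NUM_CLASSES)   # a single linear layer on top of the frozen embeddings
opt = torch.optim.Adam(probe.parameters(), lr=1e-2)
loss_fn = nn.CrossEntropyLoss()

with torch.no_grad():
    initial_loss = loss_fn(probe(toy_feats), toy_targets).item()
for _ in range(100):
    opt.zero_grad()
    loss = loss_fn(probe(toy_feats), toy_targets)
    loss.backward()
    opt.step()
with torch.no_grad():
    final_loss = loss_fn(probe(toy_feats), toy_targets).item()
print(f"training loss: {initial_loss:.3f} -> {final_loss:.3f}")
```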
