In [18]:
import numpy as np
import torch
import timm

In [19]:
def uniform_element_selection(wt, s_shape):
    """Pick a student-shaped sub-tensor out of a teacher tensor by uniform striding.

    Each dimension is reduced independently: ``s_shape[dim]`` indices are chosen
    along that dimension of the teacher. When the teacher size is an exact
    multiple of the student size a constant integer stride is used; otherwise
    the indices are evenly spaced positions between the first and last element,
    rounded to the nearest integer.

    Args:
        wt: Teacher tensor to select from.
        s_shape: Target (student) shape. May be a ``torch.Size``, tuple or list;
            it must have one entry per dimension of ``wt`` and no entry may
            exceed the corresponding teacher size.

    Returns:
        A new tensor of shape ``s_shape`` holding the selected elements, with
        the dtype and device of ``wt``.

    Raises:
        AssertionError: If the number of dimensions differs, if a student
            dimension is larger than the teacher's, or if the result does not
            end up with the requested shape.
    """
    # Normalise once so lists, tuples and torch.Size all compare equal below.
    s_shape = tuple(s_shape)
    assert wt.dim() == len(s_shape), "Tensors have different number of dimensions"
    ws = wt.clone()
    for dim, (t_size, s_size) in enumerate(zip(wt.shape, s_shape)):
        # The teacher must be at least as large as the student on every dimension.
        assert t_size >= s_size, "Teacher's dimension should not be smaller than student's dimension"
        if s_size == 0:
            # Nothing to select; also avoids a modulo-by-zero in the check below.
            indices = torch.empty(0, dtype=torch.long, device=wt.device)
        elif t_size % s_size == 0:
            step = t_size // s_size
            indices = torch.arange(s_size, device=wt.device) * step
        else:
            indices = torch.round(torch.linspace(0, t_size - 1, s_size, device=wt.device)).long()
        # Indices are built on wt's device because index_select requires the
        # index tensor and the source tensor to live on the same device.
        ws = torch.index_select(ws, dim, indices)
    assert tuple(ws.shape) == s_shape, "Selected tensor does not match the requested student shape"
    return ws

In [20]:
# ViT-T weight selection from ImageNet-21K pretrained ViT-S
# The project-local student import runs first so that a wrong working directory
# or missing module fails before the pretrained teacher is created.
from models.vision_transformer import vit_tiny
teacher = timm.create_model('vit_small_patch16_224_in21k', pretrained=True)
teacher_weights = teacher.state_dict()
student = vit_tiny()
student_weights = student.state_dict()
# ConvNeXt-F weight selection from ImageNet-21K pretrained ConvNeXt-T
# Uncomment below for ConvNeXt
# from models.convnext import convnext_femto
# teacher = timm.create_model('convnext_tiny_in22k', pretrained=True)
# teacher_weights = teacher.state_dict()
# student = convnext_femto()
# student_weights = student.state_dict()

In [21]:
# Build the student's initial weights by sub-sampling the teacher tensor stored
# under the same state-dict key.
weight_selection = {}
for name, student_param in student_weights.items():
    # The classification head is left out by default. Drop this filter if the
    # target dataset is the same as the teacher's.
    if "head" not in name:
        # Only keys that exist in the student are looked up, so first-N layer
        # selection is implicitly applied here.
        weight_selection[name] = uniform_element_selection(teacher_weights[name], student_param.shape)

In [22]:
# Persist the selected weights under a top-level 'model' key.
torch.save(obj={'model': weight_selection}, f="weight_selection.pth")