In [1]:
from argparse import Namespace
from torch.nn import ModuleDict
from model.DeepVSLNet_cbkd import TeacherVSLNetCBDK, build_optimizer_and_scheduler
from utils.cbkd_helpers import prune_block4, prune_block3, prune_block2
from utils.cbkd_config import CBKDConfig

from copy import deepcopy

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def count_parameters(teacher, student):
    """Print trainable-parameter totals for two modules and the relative reduction.

    Args:
        teacher: Baseline module; must expose ``parameters()``.
        student: Module compared against the baseline; must expose ``parameters()``.

    Returns:
        None. The three summary lines are written to stdout.
    """
    def num_params(model):
        # Only parameters with requires_grad=True contribute to the total.
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    teacher_total = num_params(teacher)
    student_total = num_params(student)

    print(f"Teacher total parameters: {teacher_total:,}")
    print(f"Student total parameters: {student_total:,}")

    if teacher_total == 0:
        # A parameter-free or fully frozen teacher would otherwise trigger a
        # ZeroDivisionError in the ratio below.
        print("Parameter reduction: n/a (teacher has no trainable parameters)")
        return

    reduction = 100 * (1 - student_total / teacher_total)
    print(f"Parameter reduction: {reduction:.2f}%")

In [3]:
# Settings passed to the teacher network constructor, collected in an
# attribute-style Namespace.
configs = Namespace(
    **{
        # feature / hidden sizes
        "video_feature_dim": 256,
        "dim": 256,
        "film_mode": "inside_encoder:multi",
        "drop_rate": 0,
        # text embedding settings
        "word_size": 300,
        "char_size": 1000,
        "word_dim": 300,
        "char_dim": 50,
        "word_vectors": None,
        # attention / positional settings
        "num_heads": 8,
        "max_pos_len": 128,
        "predictor": "glove",
    }
)

# Default CBKD configuration; the pruning cells read their keep ratios from it.
cbkd_config = CBKDConfig()

# Build the teacher, then take an independent deep copy as the student.
model = TeacherVSLNetCBDK(configs=configs, word_vectors=None)
student = deepcopy(model)

In [4]:
# block1 is passed as both teacher and student, so this reports its size
# with a 0% reduction.
count_parameters(teacher=model.block1, student=model.block1)

Teacher total parameters: 323,548
Student total parameters: 323,548
Parameter reduction: 0.00%


In [5]:
# Prune a deep copy of block 2 so model.block2 remains available as the baseline.
student_2 = deepcopy(model.block2)
# Re-wrap the pruned encoder under the same "feature_encoder" key it was read from.
pruned_block2 = ModuleDict(
    {
        "feature_encoder": prune_block2(
            student_2["feature_encoder"],
            keep_ratio_ds=cbkd_config.keep_ratio_block2_ds,
            keep_ratio_attn=cbkd_config.keep_ratio_block2_attn,
        )
    }
)
count_parameters(teacher=model.block2, student=pruned_block2)

Teacher total parameters: 964,096
Student total parameters: 723,564
Parameter reduction: 24.95%


In [6]:
# Prune a deep copy of block 3 so model.block3 remains available as the baseline.
student_3 = deepcopy(model.block3)
pruned_block3 = prune_block3(
    teacher_block3=student_3,
    keep_ratio_cqa=cbkd_config.keep_ratio_block3_cqa,
    keep_ratio_concat=cbkd_config.keep_ratio_block3_concat,
)
count_parameters(teacher=model.block3, student=pruned_block3)

Teacher total parameters: 395,009
Student total parameters: 198,613
Parameter reduction: 49.72%


In [7]:
# Prune a deep copy of block 4 so model.block4 remains available as the baseline.
student_4 = deepcopy(model.block4)
pruned_block4 = prune_block4(
    teacher_block4=student_4,
    keep_ratio_enc=cbkd_config.keep_ratio_block4_enc,
    keep_ratio_pred=cbkd_config.keep_ratio_block4_pred,
)
count_parameters(teacher=model.block4, student=pruned_block4)

Teacher total parameters: 1,228,290
Student total parameters: 502,436
Parameter reduction: 59.09%


In [8]:
# Display the repr of the pruned block 4 to inspect its resulting layer structure.
pruned_block4

ModuleDict(
  (predictor): ConditionedPredictor(
    (encoder): FeatureEncoder(
      (pos_embedding): PositionalEmbedding(
        (position_embeddings): Embedding(128, 256)
      )
      (conv_block): PrunedDSWrapper(
        (down): Linear(in_features=256, out_features=12, bias=True)
        (inner): DepthwiseSeparableConvBlock(
          (depthwise_separable_conv): ModuleList(
            (0-3): 4 x Sequential(
              (0): Conv1d(12, 12, kernel_size=(7,), stride=(1,), padding=(3,), groups=12, bias=False)
              (1): Conv1d(12, 12, kernel_size=(1,), stride=(1,))
              (2): ReLU()
            )
          )
          (layer_norms): ModuleList(
            (0-3): 4 x LayerNorm((12,), eps=1e-06, elementwise_affine=True)
          )
          (dropout): Dropout(p=0, inplace=False)
        )
        (up): Linear(in_features=12, out_features=256, bias=True)
      )
      (attention_block): MultiHeadAttentionBlock(
        (dropout): Dropout(p=0, inplace=False)
       