In [14]:
from argparse import Namespace
from collections import OrderedDict
import copy

import torch
from torch import nn
from torch.ao.quantization import default_qconfig
from torch.ao.quantization.fuse_modules import fuse_modules


from model.DeepVSLNet import DeepVSLNet
from model.QuantizedDeepVSLNet import QuantizedDeepVSLNet
from model.layers import Conv1DReLU, Conv1D

In [15]:
def fuse_sequential_block(
    block: nn.Sequential,
    layers_to_fuse: list[str],
    inplace: bool = False
) -> nn.Sequential:
    """
    Fuse the named layers of a Sequential block into a single module.

    Args:
        block: The container whose children are fused.
        layers_to_fuse: Child names, in order, forming one fusable sequence
            (for example ``['1', '2']``).
        inplace: If True, ``block`` itself is modified and returned. If False
            (default), ``fuse_modules`` works on a deep copy and ``block`` is
            left untouched.

    Returns:
        The fused block: ``block`` itself when ``inplace`` is True, otherwise
        a fused copy.
    """
    # Use the `torch.ao.quantization` entry point imported at the top of the
    # file instead of the legacy `torch.quantization` alias.
    return fuse_modules(block, layers_to_fuse, inplace=inplace)

def fuse_modulelist_blocks(
    blocks: nn.ModuleList,
    fuse_map: list[str],
    inplace: bool = False
) -> nn.ModuleList:
    """
    Apply the same fusion pattern to every nn.Sequential in a ModuleList.

    Args:
        blocks: ModuleList whose entries are each fused with ``fuse_map``.
        fuse_map: Child names to fuse inside every entry (e.g. ``['1', '2']``).
        inplace: If True, each entry is fused in place. If False (default),
            each entry is fused on a copy and the originals are left untouched.

    Returns:
        A new nn.ModuleList holding the fused entries. With ``inplace=True``
        these are the same objects as the entries of ``blocks``.
    """
    # No explicit deepcopy here: with inplace=False the underlying
    # fuse_modules call already deep-copies the block it is given, so copying
    # first would duplicate every block twice.
    fused_blocks = [
        fuse_sequential_block(block, fuse_map, inplace=inplace)
        for block in blocks
    ]
    return nn.ModuleList(fused_blocks)

def fuse_depthwise_separable_conv_block(conv_block, inplace=False):
    """
    Fuse children '1' and '2' of every block in
    ``conv_block.depthwise_separable_conv``.

    The original author describes these two children as the pointwise conv
    followed by its ReLU.

    Args:
        conv_block: Module exposing a ``depthwise_separable_conv`` ModuleList
            of nn.Sequential blocks.
        inplace: If True, ``conv_block`` is modified and returned. If False
            (default), a deep copy is fused and ``conv_block`` is untouched.

    Returns:
        The fused conv block (``conv_block`` itself when ``inplace`` is True).
    """
    fused_conv_block = conv_block if inplace else copy.deepcopy(conv_block)
    fuse_pattern = ['1', '2']  # pointwise conv + ReLU
    # `fused_conv_block` is already either the caller's object (inplace=True)
    # or a private deep copy, so its children can always be fused in place;
    # forwarding inplace=False would deep-copy every block a second time.
    fused_conv_block.depthwise_separable_conv = fuse_modulelist_blocks(
        fused_conv_block.depthwise_separable_conv,
        fuse_map=fuse_pattern,
        inplace=True
    )
    return fused_conv_block

def fuse_feature_encoder(feature_encoder, inplace=False):
    """
    Fuse the ``conv_block`` of a feature encoder.

    Only ``feature_encoder.conv_block`` is rewritten; every other attribute of
    the encoder is left as it is.

    Args:
        feature_encoder: Module exposing a ``conv_block`` attribute.
        inplace: If True, ``feature_encoder`` is modified and returned. If
            False (default), a deep copy is fused and the argument is untouched.

    Returns:
        The fused encoder (``feature_encoder`` itself when ``inplace`` is True).
    """
    encoder = feature_encoder if inplace else copy.deepcopy(feature_encoder)
    # `encoder` is already safe to mutate at this point, so fuse its child in
    # place rather than copying the conv block again.
    encoder.conv_block = fuse_depthwise_separable_conv_block(encoder.conv_block, inplace=True)
    return encoder

def fuse_conv1d_relu_in_sequential(seq: nn.Sequential) -> nn.Sequential:
    """
    Rebuild ``seq`` with every (Conv1D, nn.ReLU) pair collapsed into one module.

    Each Conv1D that is immediately followed by an nn.ReLU is replaced by
    ``Conv1DReLU(conv)`` and the ReLU is dropped. All other modules are carried
    over unchanged and in their original order.

    Returns:
        A new nn.Sequential. Its entries are the same module objects as in
        ``seq``, apart from the newly created Conv1DReLU wrappers.
    """
    modules = list(seq)
    rebuilt = []
    skip_next = False
    for idx, current in enumerate(modules):
        if skip_next:
            # This module is the ReLU absorbed by the previous Conv1D.
            skip_next = False
            continue
        follower = modules[idx + 1] if idx + 1 < len(modules) else None
        if isinstance(current, Conv1D) and isinstance(follower, nn.ReLU):
            rebuilt.append(Conv1DReLU(current))
            skip_next = True
        else:
            rebuilt.append(current)
    return nn.Sequential(*rebuilt)

def fuse_predictor_head(predictor_head: nn.Sequential, inplace=False) -> nn.Sequential:
    """
    Fuse the Conv1D + ReLU pairs of a predictor head.

    Args:
        predictor_head: The head to fuse.
        inplace: If False (default), the head is deep-copied before its modules
            are reused. If True, the head's own modules are reused.

    Returns:
        A new nn.Sequential built by ``fuse_conv1d_relu_in_sequential``. A new
        container is returned even when ``inplace`` is True.
    """
    if inplace:
        head = predictor_head
    else:
        head = copy.deepcopy(predictor_head)
    return fuse_conv1d_relu_in_sequential(head)

def fuse_conditioned_predictor(conditioned_predictor, inplace=False):
    """
    Fuse the encoder and the start/end heads of a conditioned predictor.

    Args:
        conditioned_predictor: Module exposing ``encoder``, ``start_block`` and
            ``end_block`` attributes.
        inplace: If True, ``conditioned_predictor`` is modified and returned.
            If False (default), a deep copy is fused and the argument is
            untouched.

    Returns:
        The fused predictor (the argument itself when ``inplace`` is True).
    """
    predictor = conditioned_predictor if inplace else copy.deepcopy(conditioned_predictor)
    # `predictor` is already safe to mutate here (the caller's object or a
    # private copy), so children are fused in place instead of being
    # deep-copied again at every nesting level.
    predictor.encoder = fuse_feature_encoder(predictor.encoder, inplace=True)
    predictor.start_block = fuse_predictor_head(predictor.start_block, inplace=True)
    predictor.end_block = fuse_predictor_head(predictor.end_block, inplace=True)
    return predictor

def fuse_model(model, inplace=False):
    """
    Top-level model fusion function.

    Fuses ``model.feature_encoder`` and ``model.predictor``; other attributes
    of the model are not modified.

    Args:
        model: Module exposing ``feature_encoder`` and ``predictor`` attributes.
        inplace: If True, ``model`` is modified and returned. If False
            (default), ``model`` is deep-copied once and the copy is fused, so
            the caller's model stays untouched.

    Returns:
        The fused model (``model`` itself when ``inplace`` is True).
    """
    fused_model = model if inplace else copy.deepcopy(model)
    # The model is copied at most once, above. From here on the sub-modules
    # belong to `fused_model`, so they are fused in place; passing
    # inplace=False down would re-copy each sub-tree at every level.
    # NOTE(review): assumes `feature_encoder` and `predictor.encoder` are
    # distinct module objects - confirm against the model definition.
    fused_model.feature_encoder = fuse_feature_encoder(fused_model.feature_encoder, inplace=True)
    fused_model.predictor = fuse_conditioned_predictor(fused_model.predictor, inplace=True)
    return fused_model

In [20]:
# Hyperparameters handed to the DeepVSLNet constructor below.
configs = Namespace(
    video_feature_dim=256,
    dim=256,
    film_mode="inside_encoder:multi",
    drop_rate=0,
    word_size=300,
    char_size=1000,
    word_dim=300,
    char_dim=50,
    word_vectors=None,
    num_heads=8,
    max_pos_len=128,
    predictor="glove",
)

# 1. Build the float model and switch it to evaluation mode.
# NOTE(review): no checkpoint is loaded in this cell, so unless the
# constructor restores weights itself the model is freshly initialised.
model = DeepVSLNet(configs=configs, word_vectors=None)
model.eval()

# 2. Apply fusion. inplace=False returns a fused copy and leaves `model` as is.
fused = fuse_model(model, inplace=False)

# 3. Wrap the fused model for quantization and attach the default qconfig
#    to the wrapper.
quant_model = QuantizedDeepVSLNet(fused)
quant_model.qconfig = torch.ao.quantization.default_qconfig

# 4. Prepare: attach observers to `quant_model` in place. The call is the last
#    expression of the cell, so the prepared module's repr is the cell output.
#    Calibration and conversion are not part of this cell.
torch.quantization.prepare(quant_model, inplace=True)

QuantizedDeepVSLNet(
  (query_quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (video_quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (dequant): DeQuantStub()
  (video_affine): VisualProjection(
    (drop): Dropout(p=0, inplace=False)
    (linear): Conv1D(
      (conv1d): Conv1d(
        256, 256, kernel_size=(1,), stride=(1,)
        (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
      )
    )
  )
  (cq_attention): CQAttention(
    (dropout): Dropout(p=0, inplace=False)
    (cqa_linear): Conv1D(
      (conv1d): Conv1d(
        1024, 256, kernel_size=(1,), stride=(1,)
        (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
      )
    )
  )
  (cq_concat): CQConcatenate(
    (weighted_pool): WeightedPool()
    (conv1d): Conv1D(
      (conv1d): Conv1d(
        512, 256, kernel_size=(1,), stride=(1,)
        (activation_post_process): MinMaxObserver