In [1]:
from argparse import Namespace
from collections import OrderedDict
import copy

import torch
from torch import nn
from torch.ao.quantization import default_qconfig
from torch.ao.quantization.fuse_modules import fuse_modules
from torch.ao.quantization.qconfig import float_qparams_weight_only_qconfig
from model.layers import Embedding

from model.DeepVSLNet import DeepVSLNet
from model.QuantizedDeepVSLNet import QuantizedDeepVSLNet
from model.layers import Conv1DReLU, Conv1D

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def fuse_sequential_block(
    block: nn.Sequential,
    layers_to_fuse: list[str],
    inplace: bool = False
) -> nn.Sequential:
    """
    Fuse the named child layers of a Sequential block into one module.

    Delegates to ``fuse_modules`` from ``torch.ao.quantization`` (the name this
    notebook already imports) rather than the legacy ``torch.quantization``
    alias of the same function.

    Args:
        block: Container whose children are fused.
        layers_to_fuse: Names of the children to fuse, in order
            (for example ``['1', '2']``).
        inplace: When True ``block`` itself is modified; otherwise
            ``fuse_modules`` works on a copy and ``block`` is left untouched.

    Returns:
        The fused container (``block`` itself when ``inplace`` is True).
    """
    return fuse_modules(block, layers_to_fuse, inplace=inplace)

def fuse_modulelist_blocks(
    blocks: nn.ModuleList,
    fuse_map: list[str],
    inplace: bool = False
) -> nn.ModuleList:
    """
    Apply one fixed fusion pattern to every nn.Sequential in a ModuleList.

    Copying is left to ``fuse_sequential_block``: it forwards ``inplace`` to
    ``fuse_modules``, which already works on a copy when ``inplace`` is False.
    Copying each block here as well would duplicate it twice for the same
    result.

    Args:
        blocks: The Sequential blocks to fuse.
        fuse_map: Child names to fuse inside each block (same for all blocks).
        inplace: When True the given blocks are fused directly; otherwise each
            block is fused on a copy and the inputs stay untouched.

    Returns:
        A new nn.ModuleList holding the fused blocks. With ``inplace=True`` its
        entries are the original block objects; otherwise they are fused copies.
    """
    fused = [
        fuse_sequential_block(block, fuse_map, inplace=inplace)
        for block in blocks
    ]
    return nn.ModuleList(fused)

def fuse_depthwise_separable_conv_block(conv_block, inplace=False):
    """
    Fuse children '1' and '2' of every Sequential in a conv block.

    The block's ``depthwise_separable_conv`` ModuleList is replaced with the
    list returned by ``fuse_modulelist_blocks``. According to the original
    author's note, children '1' and '2' are the pointwise conv and its ReLU.

    Args:
        conv_block: Module exposing a ``depthwise_separable_conv`` ModuleList.
        inplace: When True the given block is modified directly; otherwise a
            deep copy is made first and the input is left as it was.

    Returns:
        The block that was modified (the input itself when ``inplace`` is True).
    """
    if inplace:
        target = conv_block
    else:
        target = copy.deepcopy(conv_block)
    target.depthwise_separable_conv = fuse_modulelist_blocks(
        target.depthwise_separable_conv,
        fuse_map=['1', '2'],
        inplace=inplace,
    )
    return target

def fuse_feature_encoder(feature_encoder, inplace=False):
    """
    Fuse the conv block of a feature encoder.

    Only the ``conv_block`` attribute is replaced (with the result of
    ``fuse_depthwise_separable_conv_block``); nothing else is touched here.

    Args:
        feature_encoder: Module exposing a ``conv_block`` attribute.
        inplace: When True the given encoder is modified directly; otherwise a
            deep copy is made first and the input is left as it was.

    Returns:
        The encoder that was modified (the input itself when ``inplace`` is True).
    """
    if inplace:
        target = feature_encoder
    else:
        target = copy.deepcopy(feature_encoder)
    fused_conv = fuse_depthwise_separable_conv_block(target.conv_block, inplace=inplace)
    target.conv_block = fused_conv
    return target

def fuse_conv1d_relu_in_sequential(seq: nn.Sequential) -> nn.Sequential:
    """
    Build a new Sequential in which each Conv1D directly followed by an
    nn.ReLU is replaced by a single ``Conv1DReLU`` wrapping that Conv1D.

    All other children are carried over unchanged and in order. The result is
    always a new nn.Sequential built positionally from the collected modules.

    Args:
        seq: Container to scan.

    Returns:
        The rebuilt nn.Sequential.
    """
    children = list(seq)
    merged = []
    skip_next = False
    for position, current in enumerate(children):
        if skip_next:
            # This ReLU was already absorbed into the preceding Conv1DReLU.
            skip_next = False
            continue
        has_follower = position + 1 < len(children)
        if (
            isinstance(current, Conv1D)
            and has_follower
            and isinstance(children[position + 1], nn.ReLU)
        ):
            merged.append(Conv1DReLU(current))
            skip_next = True
        else:
            merged.append(current)
    return nn.Sequential(*merged)

def fuse_predictor_head(predictor_head: nn.Sequential, inplace=False) -> nn.Sequential:
    """
    Rebuild a predictor head with its Conv1D + ReLU pairs merged.

    Args:
        predictor_head: Sequential head to process.
        inplace: When True the head's own child modules are reused in the
            result; otherwise the head is deep-copied first.

    Returns:
        The nn.Sequential produced by ``fuse_conv1d_relu_in_sequential``,
        which is a new container in both modes.
    """
    if inplace:
        source = predictor_head
    else:
        source = copy.deepcopy(predictor_head)
    return fuse_conv1d_relu_in_sequential(source)

def fuse_conditioned_predictor(conditioned_predictor, inplace=False):
    """
    Fuse the encoder and both heads of a conditioned predictor.

    ``encoder`` goes through ``fuse_feature_encoder``; ``start_block`` and
    ``end_block`` each go through ``fuse_predictor_head``, in that order.

    Args:
        conditioned_predictor: Module exposing ``encoder``, ``start_block``
            and ``end_block`` attributes.
        inplace: When True the given predictor is modified directly; otherwise
            a deep copy is made first and the input is left as it was.

    Returns:
        The predictor that was modified (the input itself when ``inplace`` is
        True).
    """
    if inplace:
        target = conditioned_predictor
    else:
        target = copy.deepcopy(conditioned_predictor)
    target.encoder = fuse_feature_encoder(target.encoder, inplace=inplace)
    for head_name in ("start_block", "end_block"):
        fused_head = fuse_predictor_head(getattr(target, head_name), inplace=inplace)
        setattr(target, head_name, fused_head)
    return target

def fuse_model(model, inplace=False):
    """
    Entry point: fuse a whole model.

    ``feature_encoder`` is fused first, then ``predictor``; each attribute is
    replaced by the module returned from its helper.

    Args:
        model: Module exposing ``feature_encoder`` and ``predictor`` attributes.
        inplace: When True the given model is modified directly; otherwise a
            deep copy is made first and the input is left as it was.

    Returns:
        The model that was modified (the input itself when ``inplace`` is True).
    """
    if inplace:
        target = model
    else:
        target = copy.deepcopy(model)
    fused_encoder = fuse_feature_encoder(target.feature_encoder, inplace=inplace)
    target.feature_encoder = fused_encoder
    fused_predictor = fuse_conditioned_predictor(target.predictor, inplace=inplace)
    target.predictor = fused_predictor
    return target

In [3]:
# Settings bundle handed to the model constructors as `configs`.
configs = Namespace(**{
    "video_feature_dim": 256,
    "dim": 256,
    "film_mode": "inside_encoder:multi",
    "drop_rate": 0,
    "word_size": 300,
    "char_size": 1000,
    "word_dim": 300,
    "char_dim": 50,
    "word_vectors": None,
    "num_heads": 8,
    "max_pos_len": 128,
    "predictor": "glove",
})

In [4]:
from torch.ao.quantization.qconfig import QConfig
from torch.ao.quantization.observer import MinMaxObserver
from torch.ao.quantization.observer import default_observer

In [5]:
# Instantiate QuantizedDeepVSLNet from `configs`, passing word_vectors=None.
# NOTE(review): a later cell calls QuantizedDeepVSLNet(fused_model) with a model
# as its only argument and rebinds `float_model` to a DeepVSLNet — confirm which
# constructor signature is current and whether this cell is still needed.
float_model = QuantizedDeepVSLNet(configs=configs, word_vectors=None)

In [10]:
# Remove the two film_generator entries from the mapping `a`; the `None`
# default turns each pop into a no-op when the key is already absent.
# NOTE(review): in the cells visible here `a` is first assigned by the
# torch.load cell further down (execution count 8), so on a fresh kernel run
# top to bottom this cell would hit a NameError — move it below that cell.
a.pop("linear_modulation.film_generator.weight", None)
a.pop("linear_modulation.film_generator.bias", None)

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 

In [11]:
# Look the bias entry up again. After the pop cell has run the key is gone, so
# this raises KeyError (as the recorded output shows).
# NOTE(review): looks like a leftover check — consider removing it or using
# `"linear_modulation.film_generator.bias" in a` so Run All does not stop here.
a["linear_modulation.film_generator.bias"]

KeyError: 'linear_modulation.film_generator.bias'

In [8]:
# Read the checkpoint file into `a`, mapping its tensors to the CPU.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (see the `weights_only` option in the torch.load documentation).
a = torch.load("deepvslnet_11649.t7", map_location='cpu')
# Disabled earlier attempt, kept for reference (the keys are dropped with
# a.pop(...) in a separate cell instead):
# a.remove("linear_modulation.film_generator.weight", "linear_modulation.film_generator.bias")

In [None]:
# 1. Build the float model from `configs` and switch it to eval mode.
# NOTE(review): no load_state_dict call appears in the visible cells, so the
# checkpoint read into `a` never reaches this model — confirm whether fusion
# and quantization are really meant to run on freshly initialised weights.
float_model = DeepVSLNet(configs=configs, word_vectors=None)
float_model.eval()

# fuse_model is called with its default inplace=False, so it works on a deep
# copy: `float_model` is left as it is and `fused_model` is the fused copy.
fused_model = fuse_model(float_model)

In [None]:
# Global qconfig applied to every module that is not an embedding.
# Activations are observed as quint8: the built-in quantized linear/conv
# kernels take unsigned 8-bit activations together with qint8 weights, so a
# qint8 activation observer produces tensors those kernels reject.
# NOTE(review): `default_observer` is PyTorch's default *activation* observer;
# `default_weight_observer` is the usual choice for weights — confirm this
# weight observer is intentional.
qconfig_global = QConfig(
    activation=MinMaxObserver.with_args(dtype=torch.quint8),
    weight=default_observer.with_args(dtype=torch.qint8)
)

# Weight-only float-qparams qconfig; presumably meant for embedding layers
# (it is passed to assign_qconfig as `qconfig_emb`).
qconfig_emb = float_qparams_weight_only_qconfig


def assign_qconfig(model, qconfig_global, qconfig_emb):
    """
    Set a ``qconfig`` attribute on every module of ``model`` (the model itself
    included, since ``named_modules`` yields it too).

    * ``torch.nn.Embedding`` instances get ``qconfig = None`` and are printed
      together with their qualified name.
    * Any other module receives ``qconfig_global``, but only when it has no
      ``qconfig`` attribute yet or the existing one is None.

    Args:
        model: Root module to walk.
        qconfig_global: Value assigned to non-embedding modules.
        qconfig_emb: Accepted for interface compatibility; this function does
            not read it.

    Returns:
        None. Modules are modified in place.
    """
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Embedding):
            module.qconfig = None
            print(name, module)
            continue
        # An already assigned, non-None qconfig is left alone.
        if getattr(module, 'qconfig', None) is None:
            module.qconfig = qconfig_global

# Tag every submodule of the fused model: embeddings get qconfig None, all
# other modules get `qconfig_global` unless they already carry a qconfig.
assign_qconfig(fused_model, qconfig_global, qconfig_emb)

In [None]:
# 2: insert observers
# Wrap the fused model, then let `prepare` modify the wrapper in place
# (inplace=True), so `quant_ready_model` itself is the prepared model.
# NOTE(review): qconfigs were assigned on `fused_model` before it was wrapped,
# so the wrapper and any modules it creates itself were not visited by
# assign_qconfig — confirm QuantizedDeepVSLNet sets their qconfig on its own.
# NOTE(review): an earlier cell calls QuantizedDeepVSLNet(configs=..., word_vectors=...)
# instead of passing a model — confirm which signature is current.
quant_ready_model = QuantizedDeepVSLNet(fused_model)

torch.ao.quantization.prepare(quant_ready_model, inplace=True)

In [None]:
# 3: calibration
# NOTE(review): the calibration call below is commented out and no other
# forward pass is visible, so the observers have presumably seen no data when
# `convert` runs — the resulting quantization parameters would then not reflect
# real activations. Re-enable calibration before trusting the converted model.
# run_static_quantization_calibration(
#     quant_ready_model, calibration_loader, num_calibration_batches
# )

# 4: convert to quantized
# inplace=False: `convert` returns a converted copy and leaves
# `quant_ready_model` as the prepared model.
quantized_model = torch.ao.quantization.convert(quant_ready_model, inplace=False)

In [None]:
# Write the converted model's state_dict to 'quantized_state_dict.pth' in the
# current working directory (an existing file of that name is overwritten).
torch.save(quantized_model.state_dict(), 'quantized_state_dict.pth')

In [None]:
# Read the saved state_dict back from disk.
# NOTE(review): this rebinds `a`, which earlier cells use for the float
# checkpoint — a distinct name would keep the two apart on re-runs.
a = torch.load('quantized_state_dict.pth')

In [None]:
# Show the dtype of one stored entry — presumably to check whether this weight
# was saved in a quantized dtype.
a["video_affine.linear.conv1d.weight"].dtype