In [1]:
import torch
import torch.nn as nn
from wav2lip import Wav2Lip
from conv import *

In [2]:
fuse_counter = 0

def fuse_conv_bn_eval(conv, bn):
    """
    Fuse a Conv2d (or ConvTranspose2d) and BatchNorm2d into a single layer.
    Both modules must be in eval() mode.
    """
    assert isinstance(conv, (nn.Conv2d, nn.ConvTranspose2d)), "Only Conv2d or ConvTranspose2d is supported!"
    assert isinstance(bn, nn.BatchNorm2d), "Only BatchNorm2d is supported!"

    # Выбираем правильный класс
    fused_cls = type(conv)

    # Подготовка общих аргументов
    common_kwargs = dict(
        in_channels=conv.in_channels,
        out_channels=conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        groups=conv.groups,
        bias=True,
    )

    # Для ConvTranspose2d нужно добавить output_padding
    if isinstance(conv, nn.ConvTranspose2d):
        common_kwargs["output_padding"] = conv.output_padding

    # padding_mode есть только у Conv2d
    if isinstance(conv, nn.Conv2d):
        common_kwargs["padding_mode"] = conv.padding_mode

    # Новый слой
    fused_conv = fused_cls(**common_kwargs)

    # Параметры
    w_conv = conv.weight.clone()
    if conv.bias is not None:
        b_conv = conv.bias.clone()
    else:
        b_conv = torch.zeros(conv.out_channels, device=w_conv.device)

    w_bn = bn.weight
    b_bn = bn.bias
    running_mean = bn.running_mean
    running_var = bn.running_var
    eps = bn.eps

    scale = w_bn / torch.sqrt(running_var + eps)

    if isinstance(conv, nn.Conv2d):
        # scale по out_channels
        w_conv_fused = w_conv * scale.reshape([-1, 1, 1, 1])
        b_conv_fused = b_bn + (b_conv - running_mean) * scale
    elif isinstance(conv, nn.ConvTranspose2d):
        # scale по in_channels
        w_conv_fused = w_conv * scale.reshape([1, -1, 1, 1])
        b_conv_fused = b_bn + (b_conv - running_mean) * scale

    # Обновляем веса
    fused_conv.weight.data.copy_(w_conv_fused)
    fused_conv.bias.data.copy_(b_conv_fused)

    return fused_conv


def fuse_model_eval(model):
    global fuse_counter
    """
    Рекурсивно проходит по модулям модели.
    Там, где видит Conv2d + BatchNorm2d, заменяет их на эквивалентный Conv2d.
    Работает только в eval() режиме.
    """
    for name, module in model.named_children():
        # Если в модуле есть атрибут conv_block
        if hasattr(module, "conv_block") and isinstance(module.conv_block, nn.Sequential):
            modules = list(module.conv_block.children())
            new_modules = []
            skip = False
            for i in range(len(modules)):
                if skip:
                    skip = False
                    continue
                if (
                    (isinstance(modules[i], nn.Conv2d) or isinstance(modules[i], nn.ConvTranspose2d))
                    and i + 1 < len(modules)
                    and isinstance(modules[i + 1], nn.BatchNorm2d)
                ):
                    # Fuse!
                    fused_conv = fuse_conv_bn_eval(modules[i], modules[i + 1])
                    new_modules.append(fused_conv)
                    skip = True  # Пропускаем BatchNorm2d
                    fuse_counter += 1
                else:
                    new_modules.append(modules[i])
            # Обновляем conv_block
            module.conv_block = nn.Sequential(*new_modules)
        
        # Рекурсивно вызываем для всех детей
        fuse_model_eval(module)

In [3]:

def one_step(model, mel, face, device):
    with torch.no_grad():
        torch.cuda.reset_peak_memory_stats(device)
        out = model(mel, face)
        print("Peak memory (MB):", torch.cuda.max_memory_allocated(device) / 1024**2)

def test_fusing(do_fuse, device, batch_size):
    global fuse_counter
    model = Wav2Lip()
    mel = torch.rand(batch_size, 1, 80, 16)
    face = torch.rand(batch_size, 6, 96, 96)
    model.eval()

    if do_fuse:
        fuse_model_eval(model)
        print("Fusing done. Number of fused layers: {}".format(fuse_counter))
    model = model.to(device)
    mel = mel.to(device)
    face = face.to(device)
    one_step(model, mel, face, device)

def test_time_for_fusing(do_fuse, device, batch_size):
    import time
    import numpy as np
    mean_time = []
    global fuse_counter
    model = Wav2Lip()
    mel = torch.rand(batch_size, 1, 80, 16)
    face = torch.rand(batch_size, 6, 96, 96)
    model.eval()

    if do_fuse:
        fuse_model_eval(model)
        print("Fusing done. Number of fused layers: {}".format(fuse_counter))
    model = model.to(device)
    mel = mel.to(device)
    face = face.to(device)

    with torch.no_grad():
        for _ in range(100):
            start = time.time()
            out = model(mel, face)
            mean_time.append(time.time() - start)
    print("Total time: {}, mean time: {}".format(np.sum(mean_time), np.mean(mean_time)))

In [None]:

# test_fusing(True, "cuda", 1)
# test_time_for_fusing(True, "cuda", 1)

# test_fusing(False, "cuda", 1)
test_time_for_fusing(False, "cuda", 1)

# Memory:
#    FUSE: Peak memory (MB): 161.12353515625 (Number of fused layers: 50)
# NO FUSE: Peak memory (MB): 161.34326171875

# Speed:
#    FUSE: Total time: 0.9142956733703613, mean time: 0.009142956733703612
# NO FUSE: Total time: 1.1565561294555664, mean time: 0.011565561294555665

Total time: 1.1565561294555664, mean time: 0.011565561294555665
