# 对齐网络

In [9]:
import torch
import torch.nn as nn
class AlignSubNet(nn.Module):
    """Align text/audio/video sequences to a common length by average pooling.

    The module has no learnable parameters. Each modality is resampled
    independently along dim 1 (the sequence axis) to ``dst_len``:

    * length already ``dst_len`` -> returned unchanged (any rank);
    * any other length -> right-padded with copies of its last step up to
      ``pool_size * dst_len`` (``pool_size = ceil(raw_len / dst_len)``), then
      averaged over contiguous windows of ``pool_size`` steps. For inputs
      shorter than ``dst_len`` this reduces to pure last-step padding.
    """

    def __init__(self, dst_len):
        """
        Args:
            dst_len: target sequence length; must be a positive int.

        Raises:
            ValueError: if ``dst_len`` is not positive.
        """
        super().__init__()
        if dst_len <= 0:
            raise ValueError(f"dst_len must be positive, got {dst_len}")
        self.dst_len = dst_len

    def get_seq_len(self):
        """Return the target sequence length every output is aligned to."""
        return self.dst_len

    def _align(self, x):
        """Resample ``x`` along dim 1 to ``self.dst_len`` steps.

        Assumes ``x`` is (batch, seq_len, feat_dim) whenever its length differs
        from ``dst_len`` -- TODO confirm for all callers. A tensor whose dim 1
        already equals ``dst_len`` is returned as-is, whatever its rank.
        """
        raw_len = x.size(1)
        if raw_len == self.dst_len:
            return x
        # Ceil division: how many input steps are averaged into one output step.
        pool_size = -(-raw_len // self.dst_len)
        pad_len = pool_size * self.dst_len - raw_len
        # Pad by repeating the last time step so the length divides evenly.
        pad_x = x[:, -1:, :].expand(x.size(0), pad_len, x.size(-1))
        x = torch.cat([x, pad_x], dim=1)
        # Group *contiguous* windows as (B, dst_len, pool_size, F) and average
        # the window axis. Viewing as (B, pool_size, dst_len, F) instead would
        # average steps that are dst_len apart, not neighbouring steps.
        x = x.view(x.size(0), self.dst_len, pool_size, x.size(-1))
        return x.mean(dim=2)

    def __avg_pool(self, text_x, audio_x, video_x):
        """Align each modality independently; returns (text, audio, video)."""
        return tuple(self._align(x) for x in (text_x, audio_x, video_x))

    def forward(self, text_x, audio_x, video_x):
        """Return the three modalities aligned to ``dst_len`` along dim 1.

        The inputs are returned untouched only when all of them already have
        length ``dst_len``. Equal-but-different lengths are still resampled, so
        outputs always agree with ``get_seq_len()``.
        """
        lengths = (text_x.size(1), audio_x.size(1), video_x.size(1))
        if all(n == self.dst_len for n in lengths):
            return text_x, audio_x, video_x
        return self.__avg_pool(text_x, audio_x, video_x)

import torch

# Smoke test: align audio and video to the text length (24 steps).
# NOTE(review): text is 2-D (batch, seq_len) here; it passes through untouched
# only because its length already equals dst_len.
text = torch.randn(2, 24)
audio = torch.randn(2, 33, 33)
video = torch.randn(2, 10, 768)
align_subnet = AlignSubNet(text.size(1))
t, a, v = align_subnet(text, audio, video)
# Only a cell's last expression is displayed, so print the shapes explicitly.
print(t.shape, a.shape, v.shape)
assert t.size(1) == a.size(1) == v.size(1) == align_subnet.get_seq_len()
list(align_subnet.named_modules())

[('', AlignSubNet())]