# Visual features extraction

In [14]:
from detectron2.modeling.backbone.fpn import build_resnet_fpn_backbone

In [None]:
from detectron2.config import CfgNode as CN
from detectron2.layers import ShapeSpec


resnet_cfg = {
    "MODEL": {
        "FPN": {
            "IN_FEATURES": ["res2", "res3", "res4", "res5"],
            "OUT_CHANNELS": 128,
            "NORM": "BN",
            "FUSE_TYPE": "sum",
        },
        "BACKBONE": {
            "FREEZE_AT": 2,
        },
        "RESNETS": {
                "OUT_FEATURES": ["res2", "res3", "res4", "res5"],
                "DEPTH": 18,
                "NUM_GROUPS": 1,
                "WIDTH_PER_GROUP": 1,
                "STEM_OUT_CHANNELS": 1,
                "RES2_OUT_CHANNELS": 64,
                "STRIDE_IN_1X1": True,
                "RES5_DILATION": 1,
                "DEFORM_ON_PER_STAGE": [],
                "DEFORM_MODULATED": False,
                "DEFORM_NUM_GROUPS": [],
                "NORM": "BN"
            }
    }
}

# Convert the dictionary to a config object
resnet_cfg = CN(resnet_cfg)


B = 1
T = 5
C = 3
W = 512
H = 512

input_shape_resnet = ShapeSpec(
    channels=C,
    height=H,
    width=W,
)

In [None]:
fpn = build_resnet_fpn_backbone(resnet_cfg, input_shape_resnet)

In [118]:
# Print the number of parameters in the FPN
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Number of parameters in FPN: {count_parameters(fpn)}")

Number of parameters in FPN: 11733760


In [None]:
import torch


x = torch.randn(B*T, C, H, W)

In [120]:
res = fpn(x)

In [112]:
res.keys()

dict_keys(['p2', 'p3', 'p4', 'p5', 'p6'])

In [113]:
# Reshape everything to B, C, T, H, W
for k, v in res.items():
    res[k] = v.view(B, T, *v.shape[1:]).permute(0, 2, 1, 3, 4)

In [114]:
res['p2'].shape, res['p3'].shape, res['p4'].shape, res['p5'].shape

(torch.Size([1, 128, 5, 128, 128]),
 torch.Size([1, 128, 5, 64, 64]),
 torch.Size([1, 128, 5, 32, 32]),
 torch.Size([1, 128, 5, 16, 16]))

In [14]:
import torch
import torch.nn as nn

from detectron2.modeling.backbone.fpn import build_resnet_fpn_backbone

from audio_encoder import AudioEncoder

from tpavi import TPAVI


class MainModel(torch.nn.Module):
    def __init__(self,
                 input_shape_resnet,
                 resnet_cfg,
                 T
                 ):
        super(MainModel, self).__init__()
        
        self.T = T
        self.W = input_shape_resnet.width
        self.H = input_shape_resnet.height
        self.C = input_shape_resnet.channels
        
        # Initialize the visual encoder (ResNet18 + FPN model)
        self.visual_encoder = build_resnet_fpn_backbone(resnet_cfg, input_shape_resnet)

        # Initialize the audio encoder
        self.audio_encoder = AudioEncoder()
        self.dim_audio = 128  # Constant defined in the AudioEncoder block

        # Initialize the fusion modules (TPAVI) for each feature map
        # ouput of the visual encoder
        self.fusion_modules = nn.ModuleList()
        for feature_map in self.visual_encoder.output_shape().values():
            self.fusion_modules.append(
                TPAVI(
                    C=feature_map.channels,
                    T=self.T,
                    dim_audio=self.dim_audio,
                ))

    def forward_audio_encoder(self, audio):
        """
        Forward pass of the audio encoder.

        Args:
            audio (torch.Tensor): Audio input of shape (B, 4T, N_MFCC).

        Returns:
            torch.Tensor: Output feature tensor of shape (B, T, 128).
        """
        audio = audio.unsqueeze(1).transpose(-1, -2)
        return self.audio_encoder(audio)

    def forward_visual_encoder(self, video):
        """
        Forward pass of the visual encoder.

        Args:
            video (torch.Tensor): Video input of shape (B, T, C, W, H).

        Returns:
            dict: Dictionary containing feature maps from the visual encoder.
        """
        # Change shape to (B*T, C, W, H)
        print(video.shape)
        video = video.view(
            video.size(0)*video.size(1), *video.size()[2:])
        print(video.shape)
        return self.visual_encoder(video)


    def forward(self, audio, video):
        """
        Forward pass of the model.

        Args:
            audio (torch.Tensor): Audio input of shape (B, 4T, N_MFCC).
            video (torch.Tensor): Video input of shape (B, T, C, W, H).
        """
        visual_features = self.forward_visual_encoder(video)
        audio_features = self.forward_audio_encoder(audio)
        print(audio_features.shape)
        return visual_features, audio_features


In [13]:
audio = torch.randn(B, 4*T, 13)
audio = audio.unsqueeze(1).transpose(-1, -2)
print(audio.shape)
model.audio_encoder(audio).shape

torch.Size([1, 1, 13, 20])


torch.Size([1, 5, 128])

In [16]:
from config import cfg, input_shape
from utils import T, C
from detectron2.layers import ShapeSpec

B = 1
T = 5

input_shape = ShapeSpec(
    channels=C,
    height=256,
    width=256,
)

model = MainModel(
    input_shape_resnet=input_shape,
    resnet_cfg=cfg,
    T=T
)

visual_features, audio_features = model(
    audio=torch.randn(B, 4*T, 13),
    video=torch.randn(B, T, C, 256, 256),
)

torch.Size([1, 5, 3, 256, 256])
torch.Size([5, 3, 256, 256])
torch.Size([1, 5, 128])
