# Visual feature extraction: MainModel smoke tests and audio-visual fusion prototype

In [1]:
from main_model import MainModel

In [2]:
from config import resnet_cfg, resnet_input_shape, args
from config import T, C, NUM_CLASSES, N_MFCC
from detectron2.layers import ShapeSpec
import torch

# Small smoke-test sizes. NOTE: `T` and `resnet_input_shape` below intentionally
# shadow the values imported from `config` above.
B = 5  # batch size
T = 2  # time steps per sample (second axis of the visual tensor)
FRAME_SIZE = 128  # square spatial resolution fed to the visual backbone

resnet_input_shape = ShapeSpec(
    channels=C,
    height=FRAME_SIZE,
    width=FRAME_SIZE,
)

model = MainModel(
    input_shape_resnet=resnet_input_shape,
    cfg_resnet=resnet_cfg,
    args=args,
    T=T,
    N_MFCC=N_MFCC,
    num_classes=NUM_CLASSES,
)

model.eval()

# Inference only: disable autograd so no computation graph is kept for this pass.
with torch.no_grad():
    out_eval = model(
        # audio: (B, 4*T, 13). NOTE(review): 13 is presumably N_MFCC; confirm in config.
        torch.randn(B, 4 * T, 13),
        torch.randn(B, T, C, FRAME_SIZE, FRAME_SIZE),  # visual: (B, T, C, H, W)
    )

print("Eval output shape:", out_eval.shape)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Eval output shape: torch.Size([10, 25, 6])


In [3]:
model.train()

audio_input = torch.randn(B, 4 * T, 13)
visual_input = torch.randn(B, T, C, 128, 128)

# Run the model stage by stage so each intermediate can be inspected.
visual_features = model.forward_visual_encoder(visual_input)
audio_features = model.forward_audio_encoder(audio_input)
fused_features = model.forward_fusion(visual_features, audio_features)
head_output = model.forward_head(fused_features)

# Reuse the head output rather than running the detection head a second time
# (a second train-mode pass would also update any normalization/dropout state again).
out_train = head_output

print("Train output shape:", [x.shape for x in out_train])

Train output shape: [torch.Size([10, 67, 32, 32]), torch.Size([10, 67, 16, 16]), torch.Size([10, 67, 8, 8]), torch.Size([10, 67, 4, 4]), torch.Size([10, 67, 2, 2])]


In [4]:
from loss import v8DetectionLoss

In [5]:
# Build the detection criterion from the model.
# NOTE(review): presumably v8DetectionLoss reads head settings (strides, number of
# classes, ...) from `model`; confirm in loss.py.
detection_loss = v8DetectionLoss(model)

In [6]:
# Random detection targets for the training-mode head outputs.
# The head outputs have a leading dimension of B*T (10 in the recorded run: B=5, T=2),
# so targets must be indexed over B*T images; indexing only over B would leave
# half of the images without any target. One random box per image.
num_images = B * T
targets_batch = {
    "batch_idx": torch.arange(num_images),
    "cls": torch.randint(0, NUM_CLASSES, (num_images,)),
    "bboxes": torch.rand(num_images, 4),
}

In [7]:
# Shape of each training-mode head output level.
[feature_map.shape for feature_map in out_train]

[torch.Size([10, 67, 32, 32]),
 torch.Size([10, 67, 16, 16]),
 torch.Size([10, 67, 8, 8]),
 torch.Size([10, 67, 4, 4]),
 torch.Size([10, 67, 2, 2])]

In [8]:
# Evaluate the loss on the training-mode head outputs. In the recorded output the
# first returned tensor is exactly 10x the second (10 = number of images in out_train).
# NOTE(review): the second loss component (~9e2) dwarfs the others; check the bbox
# format/scale that v8DetectionLoss expects for `bboxes`.
detection_loss(out_train, targets_batch)

(tensor([1.4507e+00, 9.0152e+03, 1.4841e+01], grad_fn=<MulBackward0>),
 tensor([1.4507e-01, 9.0152e+02, 1.4841e+00]))

In [2]:
from config import resnet_cfg, resnet_input_shape, T, C, NUM_CLASSES, N_MFCC, H, W

In [None]:
from dataset import AVADataset, AVADataLoader


# Training split, built with the config values of N_MFCC, C, H, W, T.
train_dataset = AVADataset("train", N_MFCC, C, H, W, T)

# Single-process loader with one sample per batch, shuffled.
train_loader = AVADataLoader(
    train_dataset,
    batch_size=1,
    shuffle=True,
    num_workers=0,
)

In [3]:
from main_model import MainModel

# Smoke-test sizes for the experiments below.
# NOTE(review): these override the T/H/W imported from `config`, which the dataset
# cell above uses; if the config values differ, `train_dataset` and `model` will
# disagree on sequence length / frame size. Unlike the first smoke test, `args`
# is not passed to MainModel here.
T, H, W = 2, 128, 128

model = MainModel(
    input_shape_resnet=resnet_input_shape,
    cfg_resnet=resnet_cfg,
    num_classes=NUM_CLASSES,
    N_MFCC=N_MFCC,
    T=T,
)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [4]:
def count_parameters(model):
    """Return the total number of scalar parameters in `model` that require gradients.

    Frozen parameters (``requires_grad=False``) are not counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Trainable parameter count of the model built above.
count_parameters(model)

13428367

In [5]:
# Synthetic stand-in for one batch from `train_loader` (batch_size=1); the loader
# cell above has not been executed. To inspect real batches instead, iterate
# `for i, (audio, visual, targets, boxes) in enumerate(train_loader)` and look at
# the shape of each tensor for the first couple of batches.
import torch

audio = torch.randn(1, 4 * T, 13)                     # (1, 4*T, 13)
visual = torch.randn(1, T, C, H, W)                   # (1, T, C, H, W)
targets = torch.randint(0, NUM_CLASSES, (1, T, 2))    # 2 class ids per time step
boxes = torch.rand(1, T, 2, 4)                        # 4 coordinates per target

In [6]:
# Use the GPU when one is available.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)

# Adam at lr=1e-4; StepLR multiplies the learning rate by 0.1 every 10 scheduler steps.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

In [7]:
# Shapes of the synthetic batch tensors, in (audio, visual, targets, boxes) order.
tuple(tensor.shape for tensor in (audio, visual, targets, boxes))

(torch.Size([1, 8, 13]),
 torch.Size([1, 2, 3, 128, 128]),
 torch.Size([1, 2, 2]),
 torch.Size([1, 2, 2, 4]))

In [8]:
# Run each encoder on its own to inspect the output shapes.
audio_features = model.forward_audio_encoder(audio.to(device))
visual_features = model.forward_visual_encoder(visual.to(device))

# `.shape` is available on any device; no device-to-host copy is needed to read it.
audio_features.shape, [x.shape for x in visual_features.values()]

(torch.Size([1, 2, 128]),
 [torch.Size([1, 2, 128, 32, 32]),
  torch.Size([1, 2, 128, 16, 16]),
  torch.Size([1, 2, 128, 8, 8]),
  torch.Size([1, 2, 128, 4, 4]),
  torch.Size([1, 2, 128, 2, 2])])

In [9]:
# Fusion inputs: the audio embedding sequence and the first visual feature level
# (32x32, the highest resolution in the run above).
xa = audio_features
xv = list(visual_features.values())[0]

(xa.shape, xv.shape)

(torch.Size([1, 2, 128]), torch.Size([1, 2, 128, 32, 32]))

In [None]:
# Audio features are (B, T, 128).
# Each visual feature map is (B, T, 128, h_i, w_i).

# Audio-visual fusion plan:
# 1. Use a 1x1 convolution to project visual features to the audio feature dimension
# 2. Apply cross attention
# 3. Upsample fused features to the same spatial resolution as visual features
# 4. Add a residual connection to the original visual features


# 1. Project visual features to the audio feature dimension.
# Merge batch and time so Conv2d sees (B*T, 128, h, w). `flatten` also works on
# non-contiguous tensors, and writing to a new name leaves `xv` untouched, so
# re-running this cell does not try to flatten an already-flattened tensor.
xv_flat = xv.flatten(0, 1)
print("xv shape after view:", xv_flat.shape)
conv = torch.nn.Conv2d(
    in_channels=xv_flat.shape[1],
    # xa is (B, T, D): the feature dimension is the LAST axis (xa.shape[1] is T).
    out_channels=xa.shape[-1],
    kernel_size=1,
    stride=1,
    padding=0,
).to(xv_flat.device)  # keep the layer on the same device as the encoder outputs

