In [None]:
## Test BertEncoderLayer and VLFuse inputs and outputs

In [25]:
import os
import sys

import torch

# Make the project root (the parent of the current working directory, i.e. the
# notebook is assumed to be launched one level below the root) importable so
# that `model.glip` resolves. The path is normalised with abspath so the
# membership check also recognises the root when it is already on sys.path in
# canonical absolute form (e.g. via PYTHONPATH), instead of adding ".../x/..".
notebook_path = os.getcwd()
project_path = os.path.abspath(os.path.join(notebook_path, '..'))
if project_path not in sys.path:
    sys.path.append(project_path)

from model.glip import BertEncoderLayer, VLFuse

In [8]:
# Seed the global torch RNG so the random inputs below (and any random weight
# initialisation inside the layers) are identical on every run; otherwise the
# statistics printed later in the notebook cannot be reproduced.
SEED = 0
torch.manual_seed(SEED)

batch_size = 2
seq_len = 10
hidden_size = 768

# Language stream: random features of shape (batch_size, seq_len, hidden_size)
# with an all-ones mask of shape (batch_size, seq_len).
x = torch.randn(batch_size, seq_len, hidden_size)
mask = torch.ones(batch_size, seq_len)

encLayer = BertEncoderLayer(hidden_size=hidden_size)

# One visual level of shape (batch_size, 256, 10, 10) -- presumably (B, C, H, W),
# TODO confirm against model.glip -- plus a matching all-ones mask.
# NOTE: `input` shadows the Python builtin; the name is kept because later
# cells refer to it.
input = {
    "visual": [torch.randn(batch_size, 256, 10, 10)],
    "visual_masks": [torch.ones(batch_size, 10, 10)],
    "lang": {
        # clone so that an in-place update inside the layer cannot modify `x`
        "hidden": x.clone(),
        "masks": mask
    }
}

# Forward pass only: these checks never call backward, so skip building an
# autograd graph.
# NOTE(review): if BertEncoderLayer contains dropout, calling encLayer.eval()
# before this would make `y` deterministic across calls -- confirm.
with torch.no_grad():
    y = encLayer(input)
    

In [32]:
# Report the shapes fed to the layers: one entry per visual level, then the
# language hidden states.
print("\nInput Shapes:")
level_shapes = [feat.shape for feat in input['visual']]
print(f"Visual feature: {level_shapes}")
lang_shape = input['lang']['hidden'].shape
print(f"Language feature: {lang_shape}")
    


Input Shapes:
Visual feature: [torch.Size([2, 256, 10, 10])]
Language feature: torch.Size([2, 10, 768])


In [14]:
# The encoder layer is expected to return the same top-level structure it was
# given; assert it rather than relying on eyeballing the displayed keys.
assert set(y.keys()) == set(input.keys()), f"unexpected output keys: {sorted(y.keys())}"
y.keys()

dict_keys(['visual', 'visual_masks', 'lang'])

In [18]:
# The test expects BertEncoderLayer to leave the visual stream untouched:
# every visual level and every visual mask must come back unchanged. All
# levels are checked (not only level 0), and the level counts are compared
# first because zip() would otherwise silently ignore extra/missing levels.
assert len(y['visual']) == len(input['visual']), "number of visual levels changed"
assert len(y['visual_masks']) == len(input['visual_masks']), "number of visual masks changed"
for level, (out_feat, in_feat) in enumerate(zip(y['visual'], input['visual'])):
    assert torch.allclose(out_feat, in_feat), f"visual level {level} was modified"
for level, (out_mask, in_mask) in enumerate(zip(y['visual_masks'], input['visual_masks'])):
    assert torch.allclose(out_mask, in_mask), f"visual mask {level} was modified"

In [24]:
# The language hidden states must keep the shape they were given
# ((batch_size, seq_len, hidden_size) as built above); assert it, then display it.
assert y['lang']['hidden'].shape == input['lang']['hidden'].shape, "language hidden-state shape changed"
y['lang']['hidden'].shape

torch.Size([2, 10, 768])

In [23]:
# The language features must actually be transformed by the layer; a
# pass-through would leave them equal to the (cloned) input features.
lang_before = input['lang']['hidden']
lang_after = y['lang']['hidden']
assert not torch.allclose(lang_after, lang_before)

In [30]:

# Test VLFuse module.
# NOTE(review): 256 equals the visual channel count used above; confirm whether
# VLFuse's hidden_dim must match it or is an independent embedding size.
hidden_dim = 256
vlfuse = VLFuse(hidden_dim=hidden_dim)


def _clone_nested(obj):
    """Return a copy of ``obj`` where every tensor inside nested dicts/lists is cloned.

    Non-tensor leaves are returned as-is.
    """
    if isinstance(obj, torch.Tensor):
        return obj.clone()
    if isinstance(obj, dict):
        return {key: _clone_nested(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [_clone_nested(value) for value in obj]
    return obj


# Forward pass on a cloned copy: later cells measure how far `outputs` moved
# away from `input`, so any in-place update inside VLFuse must not be able to
# alter that baseline. No backward pass is run, so skip autograd bookkeeping.
with torch.no_grad():
    outputs = vlfuse(_clone_nested(input))

VLFuse Test Results:

Input Shapes:
Visual feature: [torch.Size([2, 256, 10, 10])]
Language feature: torch.Size([2, 10, 768])


In [36]:
# Report the shapes produced by VLFuse, in the same layout as the input summary.
print("Output shapes")
visual_out_shapes = [level.shape for level in outputs['visual']]
print(f"Visual feature: {visual_out_shapes}")
lang_out_shape = outputs['lang']['hidden'].shape
print(f"Language feature: {lang_out_shape}")
    

Output shapes
Visual feature: [torch.Size([2, 256, 10, 10])]
Language feature: torch.Size([2, 10, 768])


In [37]:
# Mean absolute element-wise difference between VLFuse outputs and inputs:
# one value per visual level, plus one for the language hidden states.
# zip() would silently drop levels if VLFuse returned a different number of
# visual levels, so check the counts first.
assert len(outputs['visual']) == len(input['visual']), (
    f"VLFuse returned {len(outputs['visual'])} visual levels, expected {len(input['visual'])}"
)
visual_changes = [(out - inp).abs().mean().item() for out, inp in zip(outputs['visual'], input['visual'])]
lang_change = (outputs['lang']['hidden'] - input['lang']['hidden']).abs().mean().item()

In [38]:
# The change values are computed as (out - in).abs().mean(), i.e. mean absolute
# differences over all elements -- not L1 norms, which would be sums -- so the
# header is labelled accordingly.
print("\nFeature Changes (mean absolute difference):")
for i, change in enumerate(visual_changes):
    print(f"Visual level {i} change: {change:.3f}")
print(f"Language feature change: {lang_change:.3f}")


Feature Changes (L1 norm):
Visual level 0 change: 0.062
Language feature change: 0.015
