In [11]:
import torch
from types import SimpleNamespace
import os, sys

from transformers import logging as hf_logging
hf_logging.set_verbosity_error()   # or .set_verbosity_warning()

# Go one level up: src/
# Resolved against the kernel's current working directory, so this assumes the
# kernel was started in the notebook's own folder.
project_root = os.path.abspath("..")
# Only append once: re-running this cell must not keep growing sys.path.
if project_root not in sys.path:
    sys.path.append(project_root)
from models.docfusion_model import DocFusionModel

In [12]:
def make_dummy_batch(
    B=2,
    T=16,
    H=224,
    W=224,
    vocab_size=30522,
    num_labels=16,
    device="cpu",
    generator=None,
    min_box_size=1e-3,
):
    """Build a random batch for smoke-testing the document model.

    Args:
        B: batch size.
        T: number of tokens per sample.
        H, W: spatial height and width of the dummy images.
        vocab_size: token ids are drawn uniformly from [0, vocab_size).
        num_labels: not used by this helper (no labels are generated); kept so
            callers can pass the same config they give the model.
        device: device on which every tensor is created.
        generator: optional ``torch.Generator`` for reproducible batches. It
            must live on ``device`` (torch requires the generator and output
            device to match). ``None`` uses the global RNG, as before.
        min_box_size: lower bound on each box's width/height, so boxes are
            never degenerate (x2 > x1 and y2 > y1 strictly).

    Returns:
        SimpleNamespace with attributes
            input_ids      (B, T)       int64 token ids
            attention_mask (B, T)       all ones, default float dtype
            images         (B, 3, H, W) standard-normal values
            token_boxes    (B, T, 4)    (x1, y1, x2, y2) with coords in [0, 1]
    """
    # Random draws happen in the same order as the original helper
    # (ids -> images -> corners -> sizes), so a seeded global RNG still
    # produces the same batch when ``generator`` is None.
    input_ids = torch.randint(
        0, vocab_size, (B, T), device=device, generator=generator
    )
    attention_mask = torch.ones(B, T, device=device)

    images = torch.randn(B, 3, H, W, device=device, generator=generator)

    # Top-left corner in [0, 0.8) and size in [min_box_size, 0.2): since
    # torch.rand may return exactly 0, the size is clamped from below so the
    # "x2 > x1, y2 > y1" invariant actually holds. The upper clamp is a
    # safety net keeping coordinates inside [0, 1].
    x1y1 = torch.rand(B, T, 2, device=device, generator=generator) * 0.8
    wh = (torch.rand(B, T, 2, device=device, generator=generator) * 0.2).clamp(
        min=min_box_size
    )
    x2y2 = (x1y1 + wh).clamp(max=1.0)
    token_boxes = torch.cat([x1y1, x2y2], dim=-1)

    return SimpleNamespace(
        input_ids=input_ids,
        attention_mask=attention_mask,
        images=images,
        token_boxes=token_boxes,
    )

In [13]:
# Forward-pass smoke test: build each fusion variant, push one dummy batch
# through it, and check the logits are (B, T, num_labels) and contain no NaN/inf.
device = "cuda" if torch.cuda.is_available() else "cpu"
num_labels = 16
torch.manual_seed(0)  # reproducible weights and dummy batches across re-runs

for mode in ["3.1", "3.3"]:
    print(f"\nTesting mode={mode}")
    model = DocFusionModel(
        mode=mode,
        num_fusion_layers=2,   # only used for 3.3
        num_labels=num_labels,
    ).to(device)
    # Inference mode for any dropout / norm layers, so the check is deterministic.
    model.eval()
    batch = make_dummy_batch(num_labels=num_labels, device=device)

    with torch.no_grad():
        logits = model(batch)

    print("logits.shape:", logits.shape)
    B, T = batch.input_ids.shape
    expected_shape = (B, T, num_labels)
    assert tuple(logits.shape) == expected_shape, (
        f"mode={mode}: expected logits shape {expected_shape}, got {tuple(logits.shape)}"
    )

    # basic NaN / inf check
    assert torch.isfinite(logits).all(), f"Found NaN/inf in logits (mode={mode})"



Testing mode=3.1
logits.shape: torch.Size([2, 16, 16])

Testing mode=3.3
logits.shape: torch.Size([2, 16, 16])


In [21]:
# Gradient-flow check: one backward pass through a fresh mode-3.3 model on a
# dummy batch, then print the mean |grad| of the first `max_params_to_show`
# trainable parameters ("NO GRAD" means the parameter received no gradient).
# Everything used below is defined here or in the setup/helper cells above, so
# the cell no longer depends on `B`, `T`, `num_labels` or `batch` leaking from
# an earlier kernel session.
num_labels = 16
max_params_to_show = 20

model = DocFusionModel(mode="3.3", num_fusion_layers=1, num_labels=num_labels).to(device)
model.train()

batch = make_dummy_batch(num_labels=num_labels, device=device)
B, T = batch.input_ids.shape

logits = model(batch)                          # (B,T,num_labels)
assert tuple(logits.shape) == (B, T, num_labels), tuple(logits.shape)
labels = torch.randint(0, num_labels, (B, T), device=device)

# reshape (not view) also works if the model returns a non-contiguous tensor
loss = torch.nn.functional.cross_entropy(
    logits.reshape(-1, num_labels),
    labels.reshape(-1),
)

loss.backward()

trainable_params = [
    (name, p) for name, p in model.named_parameters() if p.requires_grad
]
for name, p in trainable_params[:max_params_to_show]:
    if p.grad is None:
        print(name, "NO GRAD")
    else:
        print(name, "grad mean:", p.grad.abs().mean().item())




text_encoder.layout_gate grad mean: 0.0
text_encoder.layout_encoder.layout_scale grad mean: 0.01526162400841713
text_encoder.layout_encoder.coord_embed.weight grad mean: 0.0
text_encoder.layout_encoder.mlp.0.weight grad mean: 0.0
text_encoder.layout_encoder.mlp.0.bias grad mean: 0.0
text_encoder.layout_encoder.mlp.2.weight grad mean: 0.0
text_encoder.layout_encoder.mlp.2.bias grad mean: 0.0
text_encoder.layout_encoder.norm.weight grad mean: 0.0
text_encoder.layout_encoder.norm.bias grad mean: 0.0
text_encoder.norm.weight grad mean: 0.0016844046767801046
text_encoder.norm.bias grad mean: 0.0025796485133469105
region_pooler.scorer.weight grad mean: 0.008219465613365173
region_pooler.scorer.bias grad mean: 0.0
region_pooler.projection.weight NO GRAD
region_pooler.projection.bias NO GRAD
region_proj.weight grad mean: 0.0009427659679204226
region_proj.bias grad mean: 5.6880744523368776e-05
fusion_stem.ln_in.weight grad mean: 0.00026219047140330076
fusion_stem.ln_in.bias grad mean: 0.0003745