In [11]:
# %% [markdown]
# Smoke test: dataset -> model -> forward pass (radar-only)

# %%
import os, sys, math, random, time
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

print("PyTorch:", torch.__version__)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)


PyTorch: 2.5.1+cu121
Device: cuda


In [12]:
# %%
from core.options import Options

opt = Options()  # use your dataclass defaults

# --- make sure key knobs are aligned with your dataset & decoder ---
opt.plucker_ray = True              # if your dataset concatenates Plücker rays
opt.num_input_views = 4
opt.input_size = 448
opt.patch_size = 45             # 先不改（你说“先不动了”）
opt.upsampler_type = 'none'         # 先按默认 none

# dataset loader threads (Notebook 调小更稳)
opt.num_workers = 0
opt.batch_size = 4                 # smoke test => 1

# radar GT ranges (match your ground-truth)
opt.radar_xy_range = 50.0
opt.radar_z_value  = 1.0
opt.radar_min_pts  = 500
opt.radar_max_pts  = 2500
opt.radar_num_points = 1300         # 固定点数，方便 batch
opt.radar_rcs_low  = -50.0
opt.radar_rcs_high =  50.0
opt.radar_vmax     =  50.0


# training flags that are no longer relevant to rendering losses
opt.lambda_lpips = 0.0              # 关掉 LPIPS 以避免加载不必要的网络


In [13]:
# %%
from core.provider_ikun2 import ObjaverseDataset2
from torch.utils.data import DataLoader

ds = ObjaverseDataset2(opt=opt, training=True)
print("Dataset length:", len(ds))

loader = DataLoader(
    ds,
    batch_size=opt.batch_size,
    shuffle=False,
    num_workers=opt.num_workers,
    pin_memory=True,
    drop_last=False,
    # 如果你以后改成变长点数，可加自定义 collate_fn 做 padding
    # collate_fn=collate_radar,
)

sample = next(iter(loader))
print("Top-level keys:", list(sample.keys()))
print("input keys:", list(sample['input'].keys()))
print("radar_gt keys:", list(sample['radar_gt'].keys()))

x = sample['input']['images']
pts = sample['radar_gt']['points']
rcs = sample['radar_gt']['rcs']
vrel = sample['radar_gt']['vrel']

print("images:", tuple(x.shape), x.dtype)
print("radar points:", tuple(pts.shape), pts.dtype)
print("radar rcs:", tuple(rcs.shape), rcs.dtype)
print("radar vrel:", tuple(vrel.shape), vrel.dtype)


Dataset length: 82574


Top-level keys: ['mode', 'input', 'radar_gt']
input keys: ['images']
radar_gt keys: ['points', 'rcs', 'vrel']
images: (4, 4, 9, 448, 448) torch.float32
radar points: (4, 1300, 3) torch.float32
radar rcs: (4, 1300, 1) torch.float32
radar vrel: (4, 1300, 3) torch.float32


In [14]:
# %%
from core.mvgamba_models2 import MVGamba2

model = MVGamba2(opt).to(device)
model.train()  # or eval(); here we just want forward to work
print("Model ok.")


Model ok.


In [15]:
# %%
# move batch to device
def to_device(batch, device):
    out = {}
    for k, v in batch.items():
        if isinstance(v, dict):
            out[k] = to_device(v, device)
        elif torch.is_tensor(v):
            out[k] = v.to(device)
        else:
            out[k] = v
    return out

batch = to_device(sample, device)

with torch.no_grad():
    out = model(batch, epoch=0, step_ratio=0.0, vis=1)

print("Output keys:", list(out.keys()))
for k in ['loss','loss_cd','loss_rcs','loss_vrel','psnr']:
    if k in out:
        print(f"{k}: {float(out[k]):.6f}")

if 'pred_points' in out:
    print("pred_points:", tuple(out['pred_points'].shape))
else:
    print("No 'pred_points' returned (set vis=1 in forward to get it).")


Output keys: ['loss', 'loss_cd', 'loss_rcs', 'loss_vrel', 'psnr', 'pred_points']
loss: 44.556053
loss_cd: 39.656807
loss_rcs: 22.186756
loss_vrel: 26.805719
psnr: 0.000000
pred_points: (4, 1296, 3)
