In [None]:
import torch
from lib.dataset import BlendedMVSDataModule, TartanairDataModule
from lib.metrics import compute_metrics
from torchvision import transforms as T
import pandas as pd
import matplotlib.pyplot as plt
from ramdepth import Model
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

In [None]:
# parameters
# Prefer the first CUDA device, but fall back to the CPU so the notebook
# still runs (slowly) on machines without a usable GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Name of the benchmark; used below both to pick the data module and as the
# `pretrained` argument of the model. Supported values: "blended", "tartanair".
dataset = "tartanair"

In [None]:
# load the data
# Per-channel statistics used to standardise RGB inputs (the usual ImageNet
# mean / std values).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Images: converted to a tensor, then normalised channel-wise.
to_tensor = T.ToTensor()
normalize = T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
img_process = T.Compose([to_tensor, normalize])

# Ground-truth depth: only converted to a tensor, no normalisation applied.
depth_process = T.ToTensor()

def basic_transform(ex):
    """Convert the images and ground-truth entries of a sample to tensors.

    Entries whose key starts with ``"image"`` go through ``img_process``;
    entries whose key starts with ``"gt"`` go through ``depth_process``.
    Every other entry is left as it is.

    The dictionary is updated in place and the same object is returned.
    """
    # Iterate over a snapshot of the keys: values are reassigned in the loop.
    for name in list(ex):
        raw = ex[name]
        if name.startswith("image"):
            ex[name] = img_process(raw)
            continue
        if name.startswith("gt"):
            ex[name] = depth_process(raw)
    return ex

# Map each supported dataset name to its data module class and pick the one
# selected in the parameters cell (an unknown name raises KeyError).
DATAMODULES = {
    "blended": BlendedMVSDataModule,
    "tartanair": TartanairDataModule,
}
DMClass = DATAMODULES[dataset]

# Build the data module with 4 previous frames per sample, one sample per
# batch, and `basic_transform` as evaluation-time transform.
dm = DMClass(load_prevs=4, batch_size=1, eval_transform=basic_transform)

# Prepare the data, set up the "test" stage and grab its dataloader.
dm.prepare_data()
dm.setup("test")
dl = dm.test_dataloader()

In [None]:
# Build the network on the selected device. The dataset name is passed as
# `pretrained` -- presumably it selects the checkpoint trained on that
# dataset; confirm against `ramdepth.Model`.
model = Model(pretrained=dataset, device=device)

In [None]:
def prepare_input(ex):
    """Assemble the model keyword arguments from a dataloader batch.

    Source views are discovered from the keys containing ``_prev``; the
    trailing integer of such a key is the 0-based view index
    (``image_prev_0``, ``image_prev_1``, ...).

    Args:
        ex: batch dictionary. Reads ``image``, ``position``, ``intrinsics``
            and, for every source view ``i``, ``image_prev_{i}``,
            ``position_prev_{i}`` and ``intrinsics_prev_{i}``.

    Returns:
        A dict with:
          - ``target``: ``ex["image"]``, unchanged;
          - ``sources``: the source images stacked along dim 2;
          - ``poses``: ``position @ inv(position_prev_i)`` for every source
            view, stacked along dim 1;
          - ``intrinsics``: the target intrinsics followed by the source
            intrinsics, stacked along dim 1.

    Raises:
        ValueError: if the batch has no ``*_prev_<i>`` key at all.
    """
    prev_indices = [int(k.split("_")[-1]) for k in ex.keys() if "_prev" in k]
    if not prev_indices:
        raise ValueError(
            "batch has no '*_prev_<i>' keys: at least one source view is required"
        )
    # View indices are 0-based, so the number of views is the highest index
    # plus one. Using the bare maximum silently dropped the last source view
    # (and produced an empty stack when only one view was available).
    n_src = max(prev_indices) + 1
    view_ids = range(n_src)

    target = ex["image"]
    sources = torch.stack([ex[f"image_prev_{i}"] for i in view_ids], 2)
    # Pose of each source view expressed relative to the target view.
    poses = torch.stack(
        [
            ex["position"] @ torch.linalg.inv(ex[f"position_prev_{i}"])
            for i in view_ids
        ],
        1,
    )
    # Target intrinsics come first, then one entry per source view.
    intrinsics = torch.stack(
        [ex["intrinsics"]] + [ex[f"intrinsics_prev_{i}"] for i in view_ids],
        1,
    )
    return {
        "target": target,
        "sources": sources,
        "poses": poses,
        "intrinsics": intrinsics,
    }

# Run the model over the whole test dataloader, collecting one dictionary of
# scalar metrics per batch.
records = []
# Inference only: disabling autograd avoids building (and keeping alive) a
# computation graph for every batch.
with torch.no_grad():
    for ex in tqdm(dl):
        inp = {k: v.to(device) for k, v in prepare_input(ex).items()}
        # Move the prediction back to the CPU, where the ground truth lives.
        depth = model(**inp).cpu()
        records.append(
            {k: v.item() for k, v in compute_metrics(depth, ex["gt"]).items()}
        )

# Average every metric over all the evaluated batches.
metrics = pd.DataFrame(records).mean(axis=0)

print(f"== metrics for {dataset} ==")
print(metrics)