In [1]:
import os
from pathlib import Path

import pandas as pd
import torch
from torch import nn
from torch.utils.data import DataLoader
from hydra.utils import instantiate
from omegaconf import OmegaConf

from opr.testing import test
from opr.utils import parse_device
from opr.modules import Concat

from src.models import LateFusionModel


In [5]:
# Evaluation settings.
DEVICE = parse_device("cuda")
BATCH_SIZE = 32

# Dataset to evaluate; must be one of the keys of DATASET_DIR below.
DATASET = "nclt"

# Root folder holding the preprocessed datasets. Set the MSSPLACE_DATASETS_ROOT
# environment variable to point elsewhere instead of editing absolute paths here.
DATASETS_ROOT = Path(os.environ.get("MSSPLACE_DATASETS_ROOT", "/home/docker_mssplace/Datasets"))

# Kept as plain strings: these values are assigned into the dataset config.
DATASET_DIR = {
    "oxford": str(DATASETS_ROOT / "pnvlad_oxford_robotcar_full"),
    "nclt": str(DATASETS_ROOT / "NCLT_preprocessed"),
}

# Fail fast on a typo in DATASET rather than with a KeyError mid-evaluation.
assert DATASET in DATASET_DIR, f"Unknown DATASET {DATASET!r}; expected one of {sorted(DATASET_DIR)}"

# Modality combinations to evaluate. Each name selects the dataset config
# configs/dataset/<DATASET>/<name>.yaml; uncomment entries to evaluate more of them.
MODALITIES_LIST = [
    # "all_camera_lidar",
    # "all_camera_semantic_lidar",
    # "all_camera_semantic_text_lidar",
    # "all_camera_semantic_text",
    # "all_camera_semantic",
    # "all_camera_text_lidar",
    # "all_camera_text",
    "camera1_lidar",
]

def load_model(checkpoint_name: str, checkpoint_dir: Path = Path("checkpoints")) -> nn.Module:
    """Rebuild a model from a training checkpoint and switch it to eval mode.

    The checkpoint is read as a dict holding ``checkpoint["config"]["model"]``
    (a model config passed to Hydra's ``instantiate``) and
    ``checkpoint["model_state_dict"]`` (the trained weights).

    Args:
        checkpoint_name: Checkpoint file name, resolved relative to ``checkpoint_dir``.
        checkpoint_dir: Directory holding the checkpoints; defaults to ``./checkpoints``.

    Returns:
        The instantiated model with the checkpoint weights loaded, in eval mode.
        It is not moved to ``DEVICE`` here.
    """
    # map_location="cpu" lets checkpoints saved on any GPU load on this machine;
    # the fused model is moved to DEVICE as a whole after construction.
    checkpoint = torch.load(Path(checkpoint_dir) / checkpoint_name, map_location="cpu")
    model_cfg = OmegaConf.create(checkpoint["config"]["model"])
    model = instantiate(model_cfg)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    return model


def load_dataloader(dataset_name: str, modalities: str) -> DataLoader:
    """Build the test-split DataLoader for one dataset / modality combination.

    Loads ``configs/dataset/<dataset_name>/<modalities>.yaml``, overrides its
    ``dataset_root`` with ``DATASET_DIR[dataset_name]`` and instantiates the
    dataset with ``subset="test"``.

    Args:
        dataset_name: Key of ``DATASET_DIR`` (e.g. ``"nclt"``).
        modalities: Name of the modality config file, without the ``.yaml`` suffix.

    Returns:
        An unshuffled DataLoader of ``BATCH_SIZE`` batches using the dataset's own ``collate_fn``.
    """
    config_path = Path("configs") / "dataset" / dataset_name / f"{modalities}.yaml"
    dataset_cfg = OmegaConf.load(config_path)
    dataset_cfg.dataset_root = DATASET_DIR[dataset_name]
    dataset = instantiate(dataset_cfg, subset="test")
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,  # fixed order keeps evaluation reproducible
        num_workers=0,
        drop_last=False,  # evaluate every test sample, including a partial last batch
        collate_fn=dataset.collate_fn,
    )


# One row per evaluated modality combination; `metrics_df` below only holds the last one.
all_metrics_rows = []

for MODALITIES in MODALITIES_LIST:
    experiment_name = f"{DATASET}_{MODALITIES}"

    # Unimodal checkpoints that are fused by concatenation.
    IMAGE_MODEL = f"{DATASET}_camera1_exp.pth" if "camera" in MODALITIES else None
    LIDAR_MODEL = f"{DATASET}_lidar_exp.pth" if "lidar" in MODALITIES else None
    # Semantic and text branches are currently disabled. To re-enable, use:
    #   SEMANTIC_MODEL = f"{DATASET}_camera1_add.pth" if "semantic" in MODALITIES else None
    #   TEXT_MODEL = f"{DATASET}_text{4 if DATASET == 'oxford' else 5}_clip-base-mlp-add.pth" if "text" in MODALITIES else None
    SEMANTIC_MODEL = None
    TEXT_MODEL = None

    dataloader = load_dataloader(DATASET, MODALITIES)
    image_model = load_model(IMAGE_MODEL) if IMAGE_MODEL else None
    semantic_model = load_model(SEMANTIC_MODEL) if SEMANTIC_MODEL else None
    text_model = load_model(TEXT_MODEL) if TEXT_MODEL else None
    lidar_model = load_model(LIDAR_MODEL) if LIDAR_MODEL else None

    concat_model = LateFusionModel(
        image_model,
        semantic_model,
        text_model,
        lidar_model,
        fusion_module=Concat(),
    )
    # Moves every submodule (loaded on CPU above) to DEVICE.
    concat_model = concat_model.to(DEVICE).eval()

    print(f"Experiment: {experiment_name}")

    mean_recall_at_n, mean_recall_at_one_percent, mean_top1_distance = test(
        concat_model, dataloader
    )
    # mean_recall_at_n is read 0-based: position k-1 is reported as R@k.
    metrics = {
        "dataset": DATASET,
        "modality": MODALITIES,
        "exp_name": "concat",
        "R@1": mean_recall_at_n[0],
        "R@3": mean_recall_at_n[2],
        "R@5": mean_recall_at_n[4],
        "R@10": mean_recall_at_n[9],
        "R@1%": mean_recall_at_one_percent,
        "mean_top1_distance": mean_top1_distance,
    }
    all_metrics_rows.append(metrics)
    metrics_df = pd.DataFrame(metrics, index=[0])

    metrics_path = Path(f"{experiment_name}_metrics.csv")
    if metrics_path.exists():
        # Never overwrite results of an earlier run; delete the file to re-record.
        print(f"Metrics csv exists: {metrics_path} (not overwritten).")
    else:
        metrics_df.to_csv(metrics_path, index=False)

# Summary of all combinations evaluated in this run (empty if MODALITIES_LIST is empty).
all_metrics_df = pd.DataFrame(all_metrics_rows)


Expreriment: nclt_camera1_lidar


Calculating test set descriptors:   0%|          | 0/86 [00:00<?, ?it/s]

                                                                                 

In [None]:
# Display the metrics row of the last modality combination evaluated by the loop above.
metrics_df
