In [34]:
import os
from hydra import initialize, initialize_config_module, initialize_config_dir, compose
from omegaconf import OmegaConf


from __future__ import annotations

import os
from datetime import datetime
from pathlib import Path

import hydra
import omegaconf
import torch
import torch.multiprocessing as mp
from hydra.utils import instantiate
from loguru import logger
from omegaconf import DictConfig, OmegaConf
from torch.distributed import destroy_process_group, init_process_group


from torch.nn.functional import interpolate
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.lr_scheduler import LambdaLR
from tqdm import tqdm

from evals.datasets.builder import build_loader
from evals.utils.losses import DepthLoss
from evals.utils.metrics import evaluate_depth, match_scale_and_shift
from evals.utils.optim import cosine_decay_linear_warmup
from evals.utils.models import get_latest_checkpoint
import wandb 

import cv2
import numpy as np

In [None]:
"""
MIT License

Copyright (c) 2024 Mohamed El Banani

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
from __future__ import annotations

import os
from datetime import datetime
from pathlib import Path

import hydra
import omegaconf
import torch
import torch.multiprocessing as mp
from hydra.utils import instantiate
from loguru import logger
from omegaconf import DictConfig, OmegaConf
from torch.distributed import destroy_process_group, init_process_group


from torch.nn.functional import interpolate
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.lr_scheduler import LambdaLR
from tqdm import tqdm

from evals.datasets.builder import build_loader
from evals.utils.losses import DepthLoss
from evals.utils.metrics import evaluate_depth, match_scale_and_shift
from evals.utils.optim import cosine_decay_linear_warmup
from evals.utils.models import get_latest_checkpoint
import wandb 


def ddp_setup(rank: int, world_size: int, port: int):
    """Join this process to a single-node NCCL process group and pin its GPU.

    Args:
        rank: Unique identifier of this process; also used as its CUDA device index.
        world_size: Total number of processes in the group.
        port: TCP port on localhost used for the rendezvous.
    """
    rendezvous = {"MASTER_ADDR": "localhost", "MASTER_PORT": str(port)}
    os.environ.update(rendezvous)
    init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def validate(
    model, probe, loader, loss_fn, verbose=True, scale_invariant=False, aggregate=True, cfg = None, log_interval = 20
):
    """Run ``model`` + ``probe`` over ``loader`` and collect depth loss and metrics.

    Args:
        model: Backbone mapping a batch of images to features.
        probe: Depth head mapping backbone features to a depth prediction.
        loader: Iterable of dicts with ``"image"`` and ``"depth"`` tensors; dim 0
            of each is squeezed (if it has size 1) after moving to CUDA.
        loss_fn: Callable ``loss_fn(pred, target)`` returning a scalar tensor.
        verbose: Wrap ``loader`` in a tqdm progress bar.
        scale_invariant: Forwarded to ``evaluate_depth``.
        aggregate: If True each metric is reduced to its mean; otherwise the
            per-batch values concatenated along dim 0 are returned.
        cfg: Unused; kept for call-site compatibility.
        log_interval: Every ``log_interval``-th batch, the first 8 predictions and
            targets are logged to wandb (only when a wandb run is active).

    Returns:
        ``(mean loss per batch, {metric name: tensor})``.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    total_loss = 0.0
    num_batches = 0
    metrics = {}
    with torch.inference_mode():
        pbar = tqdm(loader, desc="Evaluation") if verbose else loader
        for i, batch in enumerate(pbar):
            images = batch["image"].cuda().squeeze(0)
            target = batch["depth"].cuda().squeeze(0)

            feat = model(images)
            # inference_mode already disables autograd, so no .detach() is needed
            pred = probe(feat)
            pred = interpolate(pred, size=target.shape[-2:], mode="bilinear")

            # wandb.log raises when no run is active in this process, so guard it;
            # log both images in one call so they share a single wandb step.
            if i % log_interval == 0 and wandb.run is not None:
                wandb.log(
                    {
                        "val_prediction": wandb.Image(pred[:8].cpu()),
                        "val_target": wandb.Image(target[:8].cpu()),
                    }
                )

            loss = loss_fn(pred, target)
            total_loss += loss.item()
            num_batches += 1

            batch_metrics = evaluate_depth(pred, target, scale_invariant=scale_invariant)
            for key, value in batch_metrics.items():
                metrics.setdefault(key, []).append(value)

    if num_batches == 0:
        # previously this surfaced as a ZeroDivisionError / TypeError
        raise ValueError("validate() received a loader that yielded no batches")

    # aggregate: average loss per processed batch, metrics concatenated over batches
    total_loss = total_loss / num_batches
    for key, values in metrics.items():
        stacked = torch.cat(values, dim=0)
        metrics[key] = stacked.mean() if aggregate else stacked

    return total_loss, metrics


def evaluate_model(rank, world_size, cfg):
    """Evaluate a trained depth probe on the test split and log the results.

    The probe weights *and* the run configuration are read from the checkpoint
    resolved from ``cfg.checkpoint_path`` (via ``get_latest_checkpoint``). The
    config stored in the checkpoint replaces ``cfg`` for everything below. Only
    rank 0 evaluates, writes logs/results and saves a copy of the probe.

    Args:
        rank: Index of this process; also used as its CUDA device.
        world_size: Number of processes; a process group is created when > 1.
        cfg: Hydra/OmegaConf config that must provide ``checkpoint_path``.

    Raises:
        ValueError: if ``cfg.checkpoint_path`` is missing/empty or the probe
            weights cannot be loaded from the checkpoint.
    """
    # Evaluation needs trained probe weights; fail before any process-group / data work.
    if not ("checkpoint_path" in cfg and cfg.checkpoint_path):
        raise ValueError(
            "failed to load probe weights from checkpoint: cfg.checkpoint_path is not set"
        )
    checkpoint_path = get_latest_checkpoint(cfg.checkpoint_path)
    print("loading checkpoint from ", checkpoint_path)
    # NOTE(review): the checkpoint pickles an OmegaConf config; torch versions whose
    # torch.load defaults to weights_only=True may reject it — confirm for your version.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    # From here on, the configuration the probe was trained with is authoritative.
    cfg = checkpoint["cfg"]

    if world_size > 1:
        ddp_setup(rank, world_size, cfg.system.port)

    # ===== GET DATA LOADERS =====
    # In this function the train loader only supplies dataset metadata (max_depth,
    # name, a sample image size); evaluation runs on the single-GPU test loader.
    train_loader = build_loader(cfg.dataset, "train", 1, world_size)
    test_loader = build_loader(cfg.dataset, "valid", 1, 1)
    # NOTE(review): result is discarded; presumably a smoke test that the dataset loads.
    train_loader.dataset.__getitem__(0)

    # ===== Get models =====
    model = instantiate(cfg.backbone)
    probe = instantiate(
        cfg.probe, feat_dim=model.feat_dim, max_depth=train_loader.dataset.max_depth
    )

    # ===== Experiment name / job info =====
    timestamp = datetime.now().strftime("%d%m%Y-%H%M")
    train_dset = train_loader.dataset.name
    test_dset = test_loader.dataset.name
    model_info = [
        f"{model.checkpoint_name:40s}",
        f"{model.patch_size:2d}",
        f"{str(model.layer):5s}",
        f"{model.output:10s}",
    ]
    probe_info = [f"{probe.name:25s}"]
    batch_size = cfg.batch_size * cfg.system.num_gpus
    train_info = [
        f"{cfg.optimizer.n_epochs:3d}",
        f"{cfg.optimizer.warmup_epochs:4.2f}",
        f"{str(cfg.optimizer.probe_lr):>10s}",
        f"{str(cfg.optimizer.model_lr):>10s}",
        f"{batch_size:4d}",
        f"{train_dset:10s}",
        f"{test_dset:10s}",
    ]
    exp_name = "_".join([timestamp] + model_info + probe_info + train_info)
    exp_name = f"{exp_name}_{cfg.note}" if cfg.note != "" else exp_name
    # the fixed-width padding above is meant for the results log, not the directory name
    exp_name = exp_name.replace(" ", "")

    # ===== SETUP LOGGING =====
    if rank == 0:
        exp_path = Path(__file__).parent / f"depth_exps/{exp_name}"
        exp_path.mkdir(parents=True, exist_ok=True)
        logger.add(exp_path / "training.log")
        logger.info(f"Config: \n {OmegaConf.to_yaml(cfg)}")

    # ===== Load trained probe weights =====
    try:
        probe.load_state_dict(checkpoint["probe"], strict=True)
    except Exception as err:  # e.g. KeyError (no "probe" entry) or RuntimeError (mismatch)
        raise ValueError("failed to load probe weights from checkpoint") from err
    print('successfully loaded probe weights from checkpoint')

    # move to cuda; eval() so dropout / batch-norm use inference behaviour
    model = model.to(rank).eval()
    probe = probe.to(rank).eval()

    # very hacky ... SAM / ViT-MAE backbones get their positional embeddings
    # resized to the dataset's image size
    model_name = model.checkpoint_name
    if "sam" in model_name or "vit-mae" in model_name:
        h, w = train_loader.dataset.__getitem__(0)["image"].shape[-2:]
        model.resize_pos_embed(image_size=(h, w))

    # No DDP wrapper, optimizer or LR scheduler here: nothing is optimized, and a DDP
    # forward on rank 0 alone could wait on collectives the other ranks never join.

    if "loss" in cfg:
        loss_fn = instantiate(cfg.loss)
    else:
        # default loss before adding it to the config
        loss_fn = DepthLoss()

    if rank == 0:
        # wandb.log raises without an active run in this process (presumably the case
        # in mp.spawn workers, which do not inherit the parent's run).
        use_wandb = wandb.run is not None
        if not use_wandb:
            logger.warning("No active wandb run in this process; skipping wandb logging")

        logger.info(f"Evaluating on test split of {test_dset}")

        # scale-aware evaluation
        test_sa_loss, test_sa_metrics = validate(model, probe, test_loader, loss_fn, cfg=cfg)
        logger.info(f"Scale-Aware Final test loss       | {test_sa_loss:.4f}")
        for metric, value in test_sa_metrics.items():
            logger.info(f"Final test SA {metric:10s} | {value:.4f}")
        if use_wandb:
            wandb.log(
                {"test_loss": test_sa_loss,
                 **{"test_sa_" + metric: value for metric, value in test_sa_metrics.items()}}
            )

        # scale-invariant evaluation
        test_si_loss, test_si_metrics = validate(
            model, probe, test_loader, loss_fn, scale_invariant=True, cfg=cfg
        )
        logger.info(f"Scale-Invariant Final test loss       | {test_si_loss:.4f}")
        for metric, value in test_si_metrics.items():
            logger.info(f"Final test SI {metric:10s} | {value:.4f}")
        if use_wandb:
            wandb.log(
                {"test_si_loss": test_si_loss,
                 **{"test_si_" + metric: value for metric, value in test_si_metrics.items()}}
            )
        results_si = ", ".join(f"{value:.4f}" for value in test_si_metrics.values())

        # append one line per run to the shared results log (scale-invariant metrics only)
        # TODO: also append the scale-aware results
        exp_info = ", ".join(model_info + probe_info + train_info)
        log = f"{timestamp}, {exp_info}, {results_si} \n"
        with open(f"depth_results_{test_dset}.log", "a") as f:
            f.write(log)

        # save a copy of the (unwrapped) probe next to the evaluation logs
        ckpt_path = exp_path / "ckpt.pth"
        checkpoint = {
            "cfg": cfg,
            "probe": probe.state_dict(),
        }
        torch.save(checkpoint, ckpt_path)
        logger.info(f"Saved checkpoint at {ckpt_path}")

    if world_size > 1:
        destroy_process_group()


@hydra.main(config_name="depth_video_training", config_path="./configs", version_base=None)
def main(cfg: DictConfig):
    """Hydra entry point: start a wandb run, then evaluate on one or more GPUs.

    Args:
        cfg: Config composed by hydra from ``./configs/depth_video_training``.
    """
    # Resolve eagerly so missing values fail here, before any GPU work.
    wandb_config = omegaconf.OmegaConf.to_container(
        cfg, resolve=True, throw_on_missing=True
    )
    # Pass the config to wandb.init: a value assigned to ``wandb.config`` before
    # init is replaced by the new run's own config object and never recorded.
    wandb.init(project="depth_video_probe", name=cfg.note, config=wandb_config)
    world_size = cfg.system.num_gpus
    if world_size > 1:
        # NOTE(review): spawned workers presumably do not inherit this wandb run
        mp.spawn(evaluate_model, args=(world_size, cfg), nprocs=world_size)
    else:
        evaluate_model(0, world_size, cfg)


# Run only as a script. Inside Jupyter/IPython ``__name__`` is also "__main__", but
# hydra would then parse the kernel's command-line arguments (and ``__file__`` is
# undefined), so skip the launch when an IPython shell is present.
if __name__ == "__main__" and "get_ipython" not in globals():
    main()
    

In [4]:
# Trained depth probe on the image (DIFT / stable-diffusion-2-1) backbone
IMAGE_PROBE_CKPT = "/work/ececis_research/peace/dino-diffusion/probe3d/depth_exps/21042024-1543_stable-diffusion-2-1_noise-1_16_0-1-2-3_dense_bindepth_dpt_k3_10_1.50_0.0005_0.0_16_NYUv2_NYUv2_dift_video_probe_04_21/ckpt_4_20000.pth"

# Compose the hydra config with checkpoint_path pointing at that file
with initialize(version_base=None, config_path="./configs"):
    cfg = compose(
        config_name="depth_video_training",
        overrides=[f'+checkpoint_path="{IMAGE_PROBE_CKPT}"'],
    )


In [5]:
# Sanity check: the checkpoint path injected by the hydra override above
cfg["checkpoint_path"]

'/work/ececis_research/peace/dino-diffusion/probe3d/depth_exps/21042024-1543_stable-diffusion-2-1_noise-1_16_0-1-2-3_dense_bindepth_dpt_k3_10_1.50_0.0005_0.0_16_NYUv2_NYUv2_dift_video_probe_04_21/ckpt_4_20000.pth'

In [11]:
# Load the image-backbone probe checkpoint; its stored cfg replaces the composed one.
if not ("checkpoint_path" in cfg and cfg.checkpoint_path):
    # Without this guard a `checkpoint` left over from another cell would be reused silently.
    raise ValueError("cfg.checkpoint_path is not set; re-run the config-composition cell first")
checkpoint_path = get_latest_checkpoint(cfg.checkpoint_path)
print("loading checkpoint from ", checkpoint_path)
checkpoint = torch.load(checkpoint_path, map_location="cpu")
cfg = checkpoint["cfg"]

# ===== GET DATA LOADERS =====
# validate and test on single gpu
train_loader = build_loader(cfg.dataset, "train", 1, 1)
test_loader = build_loader(cfg.dataset, "valid", 1, 1)
# NOTE(review): result is discarded; presumably a smoke test that the dataset loads.
train_loader.dataset.__getitem__(0)

# ===== Get models =====
model_image = instantiate(cfg.backbone)
probe_image = instantiate(
    cfg.probe, feat_dim=model_image.feat_dim, max_depth=train_loader.dataset.max_depth
)

try:
    probe_image.load_state_dict(checkpoint["probe"], strict=True)
except Exception as err:  # keep the real cause (missing key / shape mismatch) attached
    raise ValueError('failed to load probe weights from checkpoint') from err
print('successfully loaded probe weights from checkpoint')

# move to cuda; eval() so dropout / batch-norm use inference behaviour
model_image = model_image.to("cuda").eval()
probe_image = probe_image.to("cuda").eval()



image_mean None
number of scenes: 348
NYU-GeoNet train: 20227 sets of frames.
image_mean None
number of scenes: 348
NYU-GeoNet valid: 5510 sets of frames.


  snorm = torch.tensor(snorm).float()


Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

successfully loaded probe weights from checkpoint


In [13]:
# Trained depth probe on the video (DynamiCrafter) backbone, NYU; a run directory
VIDEO_PROBE_CKPT = "/work/ececis_research/peace/dino-diffusion/probe3d/depth_exps/21042024-0525_dynamicrafter_512_interp_v1_noise-1_16_0-4-8-11_dense_bindepth_dpt_k3_10_1.50_0.0005_0.0_1_NYUv2_NYUv2_dynamicrafter_video_probe_step_1_04_21"

# Compose the hydra config with checkpoint_path pointing at that run
with initialize(version_base=None, config_path="./configs"):
    cfg = compose(
        config_name="depth_video_training",
        overrides=[f'+checkpoint_path="{VIDEO_PROBE_CKPT}"'],
    )


In [14]:
# Load the video-backbone probe checkpoint; its stored cfg replaces the composed one.
if not ("checkpoint_path" in cfg and cfg.checkpoint_path):
    # Without this guard the image probe's `checkpoint` from an earlier cell would be reused.
    raise ValueError("cfg.checkpoint_path is not set; re-run the config-composition cell first")
checkpoint_path = get_latest_checkpoint(cfg.checkpoint_path)
print("loading checkpoint from ", checkpoint_path)
checkpoint = torch.load(checkpoint_path, map_location="cpu")
cfg = checkpoint["cfg"]

# ===== GET DATA LOADERS =====
# validate and test on single gpu
train_loader = build_loader(cfg.dataset, "train", 1, 1)
test_loader = build_loader(cfg.dataset, "valid", 1, 1)
# NOTE(review): result is discarded; presumably a smoke test that the dataset loads.
train_loader.dataset.__getitem__(0)

# ===== Get models =====
model_video = instantiate(cfg.backbone)
probe_video = instantiate(
    cfg.probe, feat_dim=model_video.feat_dim, max_depth=train_loader.dataset.max_depth
)

try:
    probe_video.load_state_dict(checkpoint["probe"], strict=True)
except Exception as err:  # keep the real cause (missing key / shape mismatch) attached
    raise ValueError('failed to load probe weights from checkpoint') from err
print('successfully loaded probe weights from checkpoint')

# move to cuda; eval() so dropout / batch-norm use inference behaviour
model_video = model_video.to("cuda").eval()
probe_video = probe_video.to("cuda").eval()



loading checkpoint from  /work/ececis_research/peace/dino-diffusion/probe3d/depth_exps/21042024-0525_dynamicrafter_512_interp_v1_noise-1_16_0-4-8-11_dense_bindepth_dpt_k3_10_1.50_0.0005_0.0_1_NYUv2_NYUv2_dynamicrafter_video_probe_step_1_04_21/ckpt_4_20000.pth
image_mean None
number of scenes: 348
NYU-GeoNet train: 20227 sets of frames.
image_mean None
number of scenes: 348
NYU-GeoNet valid: 5510 sets of frames.


  snorm = torch.tensor(snorm).float()


is_single_image False
config file used /work/ececis_research/peace/DynamiCrafter/configs/inference_512_v1.0.yaml
DiffusionWrapper instantiating from {'target': 'lvdm.modules.networks.openaimodel3d.UNetModel', 'params': {'in_channels': 8, 'out_channels': 4, 'model_channels': 320, 'attention_resolutions': [4, 2, 1], 'num_res_blocks': 2, 'channel_mult': [1, 2, 4, 4], 'dropout': 0.1, 'num_head_channels': 64, 'transformer_depth': 1, 'context_dim': 1024, 'use_linear': True, 'use_checkpoint': False, 'temporal_conv': True, 'temporal_attention': True, 'temporal_selfatt_only': True, 'use_relative_position': False, 'use_causal_attention': False, 'temporal_length': 16, 'addition_attention': True, 'image_cross_attention': True, 'default_fs': 24, 'fs_condition': True}}
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
Instantiating LatentVisualDiffusion
embedding model config: {'target': 'lvdm.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2', 'params': {'freeze': True}}
using custo

In [68]:

def _depth_to_uint8(depth):
    """Scale a non-negative depth map by its own maximum into a uint8 image in [0, 255].

    An all-zero (or empty) map is returned as zeros instead of dividing by zero.
    """
    depth = np.asarray(depth, dtype=np.float32)
    peak = float(depth.max()) if depth.size else 0.0
    if peak <= 0.0:
        return np.zeros(depth.shape, dtype=np.uint8)
    return np.clip(depth / peak * 255.0, 0.0, 255.0).astype(np.uint8)


def make_video_from_depths(model_image, model_video, probe_image, probe_video, loader, fps=2, verbose=True, scale_invariant=False, aggregate=True, cfg=None, log_interval=20, out_dir="image-video-comparisons", max_videos=20):
    """Write side-by-side videos comparing image- vs. video-backbone depth predictions.

    Each batch of ``loader`` (dim 0 squeezed) becomes one MJPG ``.avi`` file,
    ``<out_dir>/depth_comparison_<n>.avi`` with ``n`` starting at 1. Every frame has
    four panels: input image | target depth | image-model prediction |
    video-model prediction. Each depth panel is normalised by its own maximum.

    Args:
        model_image, probe_image: Image backbone and its depth probe.
        model_video, probe_video: Video backbone and its depth probe.
        loader: Iterable of dicts with ``"image"`` and ``"depth"`` tensors.
        fps: Frame rate of the written videos.
        verbose: Wrap ``loader`` in a tqdm progress bar.
        scale_invariant, aggregate, cfg, log_interval: Unused; kept for signature
            compatibility with ``validate``.
        out_dir: Output directory, created if missing.
        max_videos: Maximum number of videos (batches) to write.
    """
    if max_videos < 1:
        return
    out_path = Path(out_dir)
    out_path.mkdir(parents=True, exist_ok=True)  # VideoWriter fails silently otherwise
    fourcc = cv2.VideoWriter_fourcc(*"MJPG")
    labels = ("Original", "Target", "Pred Image", "Pred Video")

    with torch.inference_mode():
        pbar = tqdm(loader, desc="Evaluation") if verbose else loader
        for i, batch in enumerate(pbar):
            video_path = out_path / f"depth_comparison_{i+1}.avi"
            images = batch["image"].cuda().squeeze(0)
            targets = batch["depth"].cuda().squeeze(0)

            print("Getting image features")
            pred_image = probe_image(model_image(images))
            pred_image = interpolate(pred_image, size=targets.shape[-2:], mode="bilinear")

            print("Getting video features")
            pred_video = probe_video(model_video(images))
            pred_video = interpolate(pred_video, size=targets.shape[-2:], mode="bilinear")

            video_writer = None
            try:
                for j in tqdm(range(images.size(0)), desc="Writing video outputs"):
                    # NOTE(review): assumes images are RGB (C, H, W) in [0, 1]; OpenCV writes BGR uint8.
                    image_rgb = np.clip(images[j].cpu().permute(1, 2, 0).numpy() * 255, 0, 255).astype(np.uint8)
                    image_bgr = cv2.cvtColor(np.ascontiguousarray(image_rgb), cv2.COLOR_RGB2BGR)
                    # [0] drops the singleton channel of each (1, H, W) depth map
                    depth_panels = [
                        cv2.cvtColor(_depth_to_uint8(depth[j].cpu().numpy()[0]), cv2.COLOR_GRAY2BGR)
                        for depth in (targets, pred_image, pred_video)
                    ]
                    frame = np.hstack([image_bgr] + depth_panels)

                    panel_width = image_bgr.shape[1]
                    for k, label in enumerate(labels):
                        cv2.putText(frame, label, (k * panel_width + 10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2, cv2.LINE_AA)

                    if video_writer is None:
                        # Size the writer from the real frame: OpenCV may silently drop
                        # frames whose size differs from the one it was opened with.
                        frame_h, frame_w = frame.shape[:2]
                        video_writer = cv2.VideoWriter(str(video_path), fourcc, fps, (frame_w, frame_h), True)
                    video_writer.write(frame)
            finally:
                if video_writer is not None:
                    video_writer.release()

            print(f"Video saved to '{video_path}'")
            if i + 1 >= max_videos:
                print("All done!")
                break

In [69]:
# Render comparison videos for the test split at 6 frames per second
make_video_from_depths(
    model_image,
    model_video,
    probe_image,
    probe_video,
    loader=test_loader,
    fps=6,
)

  snorm = torch.tensor(snorm).float()
  snorm = torch.tensor(snorm).float()


Getting image features


[rank: 0] Seed set to 123


Getting video features


DDIM Sampler:   2%|█▉                                                                                             | 1/50 [00:00<00:20,  2.36it/s]
Writing video outputs: 100%|█████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:00<00:00, 37.44it/s]
Evaluation:   0%|                                                                                                       | 0/5510 [00:04<?, ?it/s]
