In [1]:
import os
import torch
from argparse import ArgumentParser

from torch import nn
from torch.utils.data import ConcatDataset
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import json
import wandb

from romatch.benchmarks import MegadepthDenseBenchmark
from romatch.datasets.megadepth import MegadepthBuilder
from romatch.losses.robust_loss import RobustLosses
from romatch.benchmarks import MegaDepthPoseEstimationBenchmark, MegadepthDenseBenchmark, HpatchesHomogBenchmark

from romatch.train.train import train_k_steps
from romatch.models.matcher import *
from romatch.models.transformer import Block, TransformerDecoder, MemEffAttention
from romatch.models.encoders import *
from romatch.checkpointing import CheckPoint
import pdb
import numpy as np
import time


In [2]:
def measure_time(model, inputs, repeats=10):
    # Warm-up GPU
    for _ in range(3):
        _ = model(inputs)
        torch.cuda.synchronize()

    # Timing
    start_time = time.time()
    for _ in range(repeats):
        _ = model(inputs)
        torch.cuda.synchronize()
    end_time = time.time()

    avg_time = (end_time - start_time) / repeats
    return avg_time

In [7]:

diffusion_model = CNNandSD(
        cnn_kwargs = dict(
            pretrained=True,
            amp = True),
        amp = True,
        use_vgg = True,
    ).to('cuda')





In [3]:
dinov2_model = CNNandDinov2(
        cnn_kwargs = dict(
            pretrained=True,
            amp = True),
        amp = True,
        use_vgg = True,
    ).to('cuda')



In [4]:
B = 8
C = 3
H, W = 560, 560
inputs = torch.randn(B, C, H, W).half().cuda()  # Replace B, C, H, W with appropriate values
inputs.dtype



torch.float16

In [8]:
avg_time_diffusion = measure_time(diffusion_model.forward, inputs)
print(f"Average time for Diffusion: {avg_time_diffusion:.4f} seconds")

Average time for Diffusion: 0.4483 seconds


In [5]:
avg_time_dinov2 = measure_time(dinov2_model.forward, inputs)
print(f"Average time for DINOv2: {avg_time_dinov2:.4f} seconds")

Average time for DINOv2: 0.2274 seconds


In [16]:
import numpy as np
data_root = '/export/r24a/data/zshao/data/megadepth/megadepth_test_1500'
scene = '0022_0.1_0.3.npz'
scene = np.load(f"{data_root}/{scene}", allow_pickle=True)


In [19]:
import torch
sim = torch.randn(size=[2,4,4])

In [20]:
nn12 = torch.max(sim, dim=2)[1] 
nn21 = torch.max(sim, dim=1)[1]



In [29]:
ids1 = torch.arange(sim.shape[1], device=sim.device).unsqueeze(0).expand(sim.shape[0], -1)  # Shape: [B, N]

mutual_nn_mask = (ids1 == nn21.gather(1, nn12))

In [30]:
mutual_nn_mask

tensor([[False,  True, False,  True],
        [ True,  True,  True, False]])

In [32]:
matches = torch.stack([ids1[mutual_nn_mask], nn12[mutual_nn_mask]], dim=1)


In [None]:
matches = torch.stack([ids1[mask], nn12[mask]], dim=1).view(sim.shape[0], 2, -1)

In [28]:
ids1[mask]

IndexError: The shape of the mask [2, 4] at index 0 does not match the shape of the indexed tensor [1, 4] at index 0

In [18]:
len(scene["pair_infos"])

300

In [5]:
unet = MyUNet2DConditionModel.from_pretrained('stabilityai/stable-diffusion-2-1', subfolder="unet")
onestep_pipe = OneStepSDPipeline.from_pretrained('stabilityai/stable-diffusion-2-1', unet=unet, safety_checker=None)

In [6]:
onestep_pipe.vae.encoder

Encoder(
  (conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (down_blocks): ModuleList(
    (0): DownEncoderBlock2D(
      (resnets): ModuleList(
        (0): ResnetBlock2D(
          (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
          (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
        )
        (1): ResnetBlock2D(
          (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
          (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): Si

In [9]:
onestep_pipe.vae.decoder

Decoder(
  (conv_in): Conv2d(4, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (up_blocks): ModuleList(
    (0): UpDecoderBlock2D(
      (resnets): ModuleList(
        (0): ResnetBlock2D(
          (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
          (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
        )
        (1): ResnetBlock2D(
          (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
          (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()

In [7]:
unet

MyUNet2DConditionModel(
  (conv_in): Conv2d(4, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (time_proj): Timesteps()
  (time_embedding): TimestepEmbedding(
    (linear_1): Linear(in_features=320, out_features=1280, bias=True)
    (act): SiLU()
    (linear_2): Linear(in_features=1280, out_features=1280, bias=True)
  )
  (down_blocks): ModuleList(
    (0): CrossAttnDownBlock2D(
      (attentions): ModuleList(
        (0): Transformer2DModel(
          (norm): GroupNorm(32, 320, eps=1e-06, affine=True)
          (proj_in): Linear(in_features=320, out_features=320, bias=True)
          (transformer_blocks): ModuleList(
            (0): BasicTransformerBlock(
              (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
              (attn1): Attention(
                (to_q): Linear(in_features=320, out_features=320, bias=False)
                (to_k): Linear(in_features=320, out_features=320, bias=False)
                (to_v): Linear(in_features=320, out_featu