In [1]:
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

from pathlib import Path
import cv2

import matplotlib.pyplot as plt
from mono.mono_head import InferDAM
from stereo.stereo_head import InferCREStereo

import torchvision.models as models

xFormers not available
xFormers not available


In [2]:
# gt_dir='/home/william/extdisk/data/realsense-D455_depth_image/sample/depth'
# raft_dir='/home/william/extdisk/data/realsense-D455_depth_image/sample/raft'
# gms_dir='/home/william/extdisk/data/realsense-D455_depth_image/sample/gmstereo'
# cres_dir='/home/william/extdisk/data/realsense-D455_depth_image/sample/crestero'
# cmp_dir='/home/william/extdisk/data/realsense-D455_depth_image/sample/research'
# left_dir='/home/william/extdisk/data/realsense-D455_depth_image/sample/infra1'
# right_dir='/home/william/extdisk/data/realsense-D455_depth_image/sample/infra2'
# mono_dir='/home/william/extdisk/data/realsense-D455_depth_image/sample/damv2'
# focal = 390.81134033203125
# baseline = 94.994
# cx = 320.83
# cy = 245.31
# # os.makedirs(cmp_dir, exist_ok=True)
# # os.makedirs(mono_dir, exist_ok=True)

In [3]:
# f_gts = sorted(Path('/home/william/extdisk/data/realsense-D455_depth_image/sample/depth').glob('*.png'))
# f_rafts = sorted(Path("/home/william/extdisk/data/realsense-D455_depth_image/sample/raft").glob('*.npy'))
# f_gms = sorted(Path("/home/william/extdisk/data/realsense-D455_depth_image/sample/gmstereo").glob('*_disp.pfm'))
# f_cres = sorted(Path("/home/william/extdisk/data/realsense-D455_depth_image/sample/crestereo").glob('*.npy'))
# f_left = sorted(Path(left_dir).glob('*.png'))
# f_right = sorted(Path(right_dir).glob('*.png'))
# f_mono = sorted(Path(mono_dir).glob('*.npy'))


In [4]:
import re

def read_pfm(file):
    file = open(file, 'rb')

    color = None
    width = None
    height = None
    scale = None
    endian = None

    header = file.readline().rstrip()
    if header.decode("ascii") == 'PF':
        color = True
    elif header.decode("ascii") == 'Pf':
        color = False
    else:
        raise Exception('Not a PFM file.')

    dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode("ascii"))
    if dim_match:
        width, height = list(map(int, dim_match.groups()))
    else:
        raise Exception('Malformed PFM header.')

    scale = float(file.readline().decode("ascii").rstrip())
    if scale < 0:  # little-endian
        endian = '<'
        scale = -scale
    else:
        endian = '>'  # big-endian

    data = np.fromfile(file, endian + 'f')
    shape = (height, width, 3) if color else (height, width)

    data = np.reshape(data, shape)
    data = np.flipud(data)
    return data*scale


In [5]:
from mono.depth_anything_v2.dinov2 import DINOv2

In [98]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import math

class DinoV2FeatureExtractor(nn.Module):
    def __init__(self, dino_model):
        super().__init__()
        self.dino_model = dino_model
        for param in self.parameters():
            param.requires_grad = False
    def forward(self, x):
        f = self.dino_model.get_intermediate_layers(x, n=1)[0]  
        b, n, d = f.shape
        # Remove CLS token
        p = f[:, 1:, :]  
        num_patches = p.shape[1]  # N - 1
        s = int(math.sqrt(num_patches))
        if s * s != num_patches:
            raise ValueError(
                f"Number of patch tokens ({num_patches}) is not a perfect square."
                f" Image size or patch size might be incompatible."
            )
        # s = int(math.sqrt(n - 1))
        p = p.reshape(b, s, s, d).permute(0, 3, 1, 2)
        return p

class DisparityRefinementNet(nn.Module):
    def __init__(self, in_channels=2, feat_channels=384, out_channels=1):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels + feat_channels, 64, 3, 1, 1),
            nn.ReLU(True),
            nn.Conv2d(64, 64, 3, 1, 1),
            nn.ReLU(True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 64, 3, 1, 1),
            nn.ReLU(True),
            nn.Conv2d(64, out_channels, 3, 1, 1)
        )
    def forward(self, dm, ds, dino_feat):
        b, _, h, w = dm.shape
        c = torch.cat([dm, ds], dim=1)
        u = F.interpolate(dino_feat, size=(h, w), mode='bilinear', align_corners=False)
        x = torch.cat([c, u], dim=1)
        x = self.conv1(x)
        x = self.conv2(x)
        return x

def warp_right_to_left(right_img, disp):
    b, c, h, w = right_img.shape
    y = torch.linspace(0, h - 1, h, device=right_img.device).view(1, h, 1).expand(b, h, w)
    x = torch.linspace(0, w - 1, w, device=right_img.device).view(1, 1, w).expand(b, h, w)
    x_warp = x - disp[:, 0, :, :]
    grid_x = 2.0 * x_warp / (w - 1) - 1.0
    grid_y = 2.0 * y / (h - 1) - 1.0
    grid = torch.stack([grid_x, grid_y], dim=3)
    return F.grid_sample(right_img, grid, mode='bilinear', padding_mode='border', align_corners=False)


In [99]:
from torch.utils.data import Dataset, DataLoader

from torchvision import transforms
import ast
import re
from PIL import Image

class Middlebury2014(Dataset):
    def __init__(self, data_list_path):
        with open(data_list_path, "r") as f:
            self.data_list = [line.strip() for line in f if line.strip()]
        self.data_list.sort()
        self.l_name = "im0.png"
        self.r_name = "im1.png"
        self.l_disp_name = "disp0.pfm"
        self.r_disp_name = "disp1.pfm"

        self.img_h = 518
        self.img_w = 518

        self.transform = transforms.Compose([
            transforms.Resize((self.img_h, self.img_w)),
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        scene_dir = self.data_list[index]
        left_img_path = os.path.join(scene_dir, self.l_name)
        right_img_path = os.path.join(scene_dir, self.r_name)
        left_disp_path = os.path.join(scene_dir, self.l_disp_name)
        right_disp_path = os.path.join(scene_dir, self.r_disp_name)

        left_img_bgr = cv2.imread(left_img_path, cv2.IMREAD_ANYCOLOR)
        right_img_bgr = cv2.imread(right_img_path, cv2.IMREAD_ANYCOLOR)

        left_img_rgb = cv2.cvtColor(left_img_bgr, cv2.COLOR_BGR2RGB)
        right_img_rgb = cv2.cvtColor(right_img_bgr, cv2.COLOR_BGR2RGB)

        left_disp = read_pfm(left_disp_path)
        right_disp = read_pfm(right_disp_path)

        left_img_pil = Image.fromarray(left_img_rgb)
        right_img_pil = Image.fromarray(right_img_rgb)

        left_img_tensor = self.transform(left_img_pil)
        right_img_tensor = self.transform(right_img_pil)

        left_disp_tensor = torch.from_numpy(left_disp).unsqueeze(0).float()
        right_disp_tensor = torch.from_numpy(right_disp).unsqueeze(0).float()

        sample = {
            "left_image": left_img_tensor,   # (3, 518, 518)
            "right_image": right_img_tensor, # (3, 518, 518)
            "left_disp": left_disp_tensor,   # (1, H, W)
            "right_disp": right_disp_tensor  # (1, H, W)
        }

        return sample


In [100]:
middlebury_data_list_path="/home/william/extdisk/data/middlebury/middlebury2014/middlebury2014_dataset.txt"
m2014_dataset = Middlebury2014(middlebury_data_list_path)
dino_model = DINOv2('vits')
mono_model = InferDAM()
mono_model.initialize(model_path='/home/william/extdisk/checkpoints/depth-anything/depth_anything_v2_vits.pth', encoder='vits')

stereo_model = InferCREStereo()
stereo_model.initialize(model_path='/home/william/extdisk/checkpoints/CREStereo/crestereo_eth3d.pth')

In [101]:
def train_one_epoch(dino_extractor, mono_net, stereo_net, refine_net, dataloader, optimizer, alpha=0.1, device="cuda"):
    dino_extractor.eval()
    refine_net.train()
    s = 0
    for batch in dataloader:
        l = batch["left_image"].to(device)
        r = batch["right_image"].to(device)
        with torch.no_grad():
            print(f"l shape is {l.shape}")
            fl = dino_extractor(l)
            dm = mono_net.predict(l)
            ds = stereo_net.predict(l, r)
            gt = ds
        pred = refine_net(dm, ds, fl)
        loss_gt = F.l1_loss(pred, gt)
        warped_r = warp_right_to_left(r, pred)
        loss_photo = F.l1_loss(l, warped_r)
        loss = loss_gt + alpha * loss_photo
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        s += loss.item()
    return s / len(dataloader)


In [102]:

dino_extractor = DinoV2FeatureExtractor(dino_model).cuda()
refine_net = DisparityRefinementNet(in_channels=2, feat_channels=768, out_channels=1).cuda()
dataset = m2014_dataset
dataloader = DataLoader(dataset, batch_size=1, shuffle=True, pin_memory=True)
optimizer = optim.AdamW(refine_net.parameters(), lr=1e-4, weight_decay=0.01)

for e in range(10):
    l = train_one_epoch(dino_extractor, mono_model, stereo_model, refine_net, dataloader, optimizer, alpha=0.1, device="cuda")
    print(e, l)

l shape is torch.Size([1, 3, 504, 504])


ValueError: Number of patch tokens (1295) is not a perfect square. Image size or patch size might be incompatible.