In [1]:
# Pytorch
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import resnet50

# Others
import glob
import cv2
import numpy as np
from tqdm.notebook import tqdm
from PIL import Image

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [2]:
# ResNet-50 backbone whose classifier is swapped for a 4-output linear head.
model = resnet50()
model.fc = nn.Linear(in_features=2048, out_features=4)

# Deserialize the checkpoint onto the CPU first so loading works regardless of
# the device it was saved from; the weights are moved to `device` afterwards.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
checkpoint = torch.load(
    '/home/vdd/MIPT/v2/checkpoints/epoch-205_loss_0.00018.pt',
    map_location='cpu',
)
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(device)

In [3]:
class SATDataset(Dataset):
    """Dataset that loads images from disk and returns normalized RGB tensors.

    Each item is read with OpenCV, converted from BGR to RGB, turned into a
    tensor with ``transforms.ToTensor`` and normalized per channel with the
    ImageNet mean/std commonly used with torchvision models.

    Args:
        paths: Sequence of image file paths; item ``i`` is loaded from
            ``paths[i]``.
    """

    def __init__(self, paths):
        self.paths = paths
        self.to_tensor = transforms.ToTensor()
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def __len__(self):
        """Return the number of image paths."""
        return len(self.paths)

    def __getitem__(self, index: int):
        """Load, convert and normalize the image at ``index``.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file.
        """
        img_path = self.paths[index]
        roi = cv2.imread(img_path)
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than deep inside cvtColor.
        if roi is None:
            raise FileNotFoundError(f'cv2.imread could not read image: {img_path}')
        roi = cv2.cvtColor(roi, cv2.COLOR_BGR2RGB)
        roi = self.to_tensor(roi)
        roi = self.normalize(roi)
        return roi

# Collect the test images and order them by the integer in their file name
# (e.g. ".../12.png" -> 12) so that 2.png sorts before 10.png.
img_paths = glob.glob('/home/vdd/main/vdd/sat_model/test/*.png')
img_paths.sort(key=lambda path: int(path.split('/')[-1].split('.')[0]))

BATCH_SIZE = 2
NUM_WORKERS = 4

# shuffle stays off: predictions are later zipped with img_paths by position.
dataset = SATDataset(img_paths)
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

In [4]:
# Coarse prediction: run the CNN over every test image and decode its four
# outputs into a centre (outputs 0-1) and a heading in whole degrees
# (outputs 2-3, combined with arctan2).
centers = []
angles = []

model.eval()
with torch.no_grad():
    for batch in tqdm(dataloader):
        batch = batch.to(device)
        # Call the module itself (not .forward) so registered hooks would run;
        # no .detach() is needed inside torch.no_grad().
        result = model(batch).cpu().numpy()
        # NOTE(review): 10496 is presumably the side length of the satellite
        # map in pixels, with the network emitting normalized coordinates -- confirm.
        centers.append(result[:, :2] * 10496)
        heading_deg = np.rad2deg(np.arctan2(result[:, 2], result[:, 3]))
        # Round first, then wrap: wrapping before rounding maps values in
        # [359.5, 360) to 360 instead of 0.
        angles.append(np.round(heading_deg).astype(int) % 360)

centers = np.concatenate(centers)
angles = np.concatenate(angles)

  0%|          | 0/200 [00:00<?, ?it/s]

In [5]:
import kornia as K
import kornia.feature as KF
from kornia_moons.feature import *

# Place the matcher on the same device as everything else in the notebook;
# a hard-coded .cuda() crashes on CPU-only machines where `device` is "cpu".
matcher = KF.LoFTR(pretrained='outdoor').to(device).eval()

def find_rigid_alignment(A, B, allow_reflection=True):
    """
    Least-squares alignment of two point sets with known correspondences.

    See: https://en.wikipedia.org/wiki/Kabsch_algorithm
    2-D or 3-D registration. Both clouds are centred on their means, the
    orthogonal matrix is estimated from the SVD of their cross-covariance,
    and the translation is recovered afterwards so that
    ``(R.mm(A.T)).T + t`` approximates ``B``.

        Args:
        -    A: Torch tensor of shape (N,D) -- Point Cloud to Align (source)
        -    B: Torch tensor of shape (N,D) -- Reference Point Cloud (target)
        -    allow_reflection: when True (default, the historical behaviour)
             the returned R is the best orthogonal matrix and may have
             determinant -1 (a reflection). When False, R is constrained to
             a proper rotation (determinant +1).
        Returns:
        -    R: optimal rotation (or rotation + reflection), shape (D,D)
        -    t: optimal translation, shape (D,)
        Raises:
        -    ValueError: if A and B are not 2-D tensors of identical shape.

    Test on rotation + translation and on rotation + translation + reflection
        >>> A = torch.tensor([[1., 1.], [2., 2.], [1.5, 3.]], dtype=torch.float)
        >>> R0 = torch.tensor([[np.cos(60), -np.sin(60)], [np.sin(60), np.cos(60)]], dtype=torch.float)
        >>> B = (R0.mm(A.T)).T
        >>> t0 = torch.tensor([3., 3.])
        >>> B += t0
        >>> R, t = find_rigid_alignment(A, B)
        >>> A_aligned = (R.mm(A.T)).T + t
        >>> rmsd = torch.sqrt(((A_aligned - B)**2).sum(axis=1).mean())
        >>> bool(rmsd < 1e-5)
        True
        >>> B *= torch.tensor([-1., 1.])
        >>> R, t = find_rigid_alignment(A, B)
        >>> A_aligned = (R.mm(A.T)).T + t
        >>> rmsd = torch.sqrt(((A_aligned - B)**2).sum(axis=1).mean())
        >>> bool(rmsd < 1e-5)
        True
        >>> R, t = find_rigid_alignment(A, B, allow_reflection=False)
        >>> bool(torch.linalg.det(R) > 0)
        True
    """
    if A.dim() != 2 or A.shape != B.shape:
        raise ValueError(
            f'A and B must both have shape (N, D); got {tuple(A.shape)} and {tuple(B.shape)}'
        )

    a_mean = A.mean(dim=0)
    b_mean = B.mean(dim=0)
    A_c = A - a_mean
    B_c = B - b_mean
    # Cross-covariance matrix
    H = A_c.T.mm(B_c)
    # torch.svd is deprecated; torch.linalg.svd returns V^H instead of V.
    U, _, Vh = torch.linalg.svd(H)
    V = Vh.T
    # Orthogonal matrix (rotation, possibly combined with a reflection)
    R = V.mm(U.T)
    if not allow_reflection and torch.linalg.det(R) < 0:
        # Flip the axis of the smallest singular value to get the closest
        # proper rotation (singular values are returned in descending order).
        V = torch.cat([V[:, :-1], -V[:, -1:]], dim=1)
        R = V.mm(U.T)
    # Translation vector
    t = b_mean - R.mv(a_mean)
    return R, t

def extract_region(img, center, size, angle):
    """Cut a rotated ``size`` x ``size`` patch out of ``img`` around ``center``.

    A square window whose half-side equals half the diagonal of the requested
    patch is cropped first, so that after rotating it the central
    ``size`` x ``size`` crop contains only real image content.

    Args:
        img: Image array indexed as ``img[y, x]``.
        center: ``(x, y)`` integer pixel coordinates of the patch centre.
        size: Side length of the returned square patch, in pixels.
        angle: Rotation passed to ``cv2.getRotationMatrix2D`` (degrees).

    Returns:
        The rotated patch, ``size`` pixels on each side.

    Raises:
        AssertionError: if the enlarged window does not fit inside ``img``
            on any of the four sides.
    """
    # Half the diagonal of the size x size square, rounded up.
    radius = int(np.ceil(np.sqrt(size ** 2 * 2) / 2))
    cx, cy = center
    img_h, img_w = img.shape[:2]
    # Check all four borders: a slice running past the bottom/right edge is
    # silently truncated by NumPy and would yield a wrong-sized, off-centre patch.
    fits = (
        cx >= radius and cy >= radius
        and cx + radius <= img_w and cy + radius <= img_h
    )
    if not fits:
        raise AssertionError('center is too close to the border')
    roi = img[cy - radius:cy + radius, cx - radius:cx + radius]

    # Rotate this region about its own centre, keeping its size.
    h, w = roi.shape[:2]
    M = cv2.getRotationMatrix2D((w // 2, h // 2), float(angle), 1.0)
    roi = cv2.warpAffine(roi, M, (w, h))

    # Center crop roi
    start_y = (h - size) // 2
    start_x = (w - size) // 2
    roi = roi[start_y:start_y + size, start_x:start_x + size]

    return roi

# Load the reference satellite map as a single-channel (grayscale) image.
SAT_MAP_PATH = '/home/vdd/main/vdd/sat_model/original.tiff'
sat_map = cv2.imread(SAT_MAP_PATH, cv2.IMREAD_GRAYSCALE)
# cv2.imread returns None on failure instead of raising. Stop here: otherwise
# every later refinement step fails on the missing map and is skipped.
if sat_map is None:
    raise FileNotFoundError(f'cv2.imread could not read satellite map: {SAT_MAP_PATH}')

In [14]:
# Refinement: for each coarse prediction, cut the predicted patch out of the
# satellite map, match it against the query image with LoFTR, and estimate the
# residual rotation/translation between the two from the inlier matches.
# Frames that cannot be refined keep a zero correction.
delta_rots = []
delta_trans = []
failed_refinements = []  # (image path, error) for frames whose refinement raised

# Build the converter once instead of twice per image.
to_tensor = transforms.ToTensor()

for center, angle, tgt_img_path in tqdm(zip(centers, angles, img_paths), total=len(img_paths)):
    delta_angle = 0
    translation = np.array([0, 0])
    # Only refine when the patch lies well inside the map.
    # NOTE(review): 726 / 9770 presumably equal the crop radius and
    # map size minus that radius -- confirm against the map dimensions.
    if (center > 726).all() and (center < 9770).all():
        try:
            sat_img = extract_region(sat_map, np.round(center).astype(int), 1024, angle)
            sat_img = to_tensor(sat_img).to(device).unsqueeze(dim=0)
            tgt_img = cv2.imread(tgt_img_path, 0)
            tgt_img = to_tensor(tgt_img).to(device).unsqueeze(dim=0)

            input_dict = {"image0": sat_img, "image1": tgt_img}
            with torch.no_grad():
                correspondences = matcher(input_dict)

            kpts0 = correspondences['keypoints0']
            kpts1 = correspondences['keypoints1']
            mkpts0 = kpts0.cpu().numpy()
            mkpts1 = kpts1.cpu().numpy()

            # Geometric verification; the estimated matrix itself is unused,
            # only the inlier mask is kept.
            _fund_mat, inliers = cv2.findFundamentalMat(mkpts0, mkpts1, cv2.USAC_MAGSAC, 0.5, 0.999, 5000)

            # The mask is None when estimation fails (e.g. too few matches).
            if inliers is not None:
                inlier_mask = inliers.ravel() > 0
                if inlier_mask.sum() > 200:
                    keep = torch.from_numpy(inlier_mask).to(kpts0.device)
                    R, t = find_rigid_alignment(kpts0[keep], kpts1[keep])
                    # NOTE(review): for R = [[cos, -sin], [sin, cos]] this is
                    # atan2(-sin, cos), i.e. the negated rotation angle --
                    # presumably intentional for the image (y-down) convention; confirm.
                    delta_angle = torch.rad2deg(torch.atan2(R[0][1], R[0][0])).detach().cpu().item()
                    translation = t.detach().cpu().numpy()
        except Exception as err:
            # Best effort: keep the coarse prediction for this frame, but
            # record the failure instead of hiding it. KeyboardInterrupt is
            # no longer swallowed, so the cell can be interrupted.
            failed_refinements.append((tgt_img_path, repr(err)))

    delta_rots.append(delta_angle)
    delta_trans.append(translation)

if failed_refinements:
    print(f'{len(failed_refinements)} frame(s) kept the coarse prediction after an error; '
          f'first: {failed_refinements[0]}')

0it [00:00, ?it/s]

  mkpts0_c = torch.stack([i_ids % data['hw0_c'][1], i_ids // data['hw0_c'][1]], dim=1) * scale0
  mkpts1_c = torch.stack([j_ids % data['hw1_c'][1], j_ids // data['hw1_c'][1]], dim=1) * scale1


In [16]:
def rotate(p, origin=(0, 0), degrees=0):
    """Rotate point(s) ``p`` about ``origin`` by ``degrees``.

    Uses the standard 2-D rotation matrix ``[[cos, -sin], [sin, cos]]``.

    Args:
        p: A single ``(x, y)`` point or an array of points, one per row.
        origin: Pivot of the rotation.
        degrees: Rotation angle in degrees.

    Returns:
        The rotated point(s), with singleton dimensions squeezed out.
    """
    theta = np.deg2rad(degrees)
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    rotation = np.array([[cos_t, -sin_t],
                         [sin_t, cos_t]])

    # Work with column vectors: shift to the pivot, rotate, shift back.
    pivot_cols = np.atleast_2d(origin).T
    point_cols = np.atleast_2d(p).T
    rotated_cols = rotation @ (point_cols - pivot_cols) + pivot_cols
    return np.squeeze(rotated_cols.T)

# Assemble the final prediction per image: the four corners of a
# (2 * HALF_SIZE)-pixel square around the predicted centre, shifted by the
# refinement translation and rotated about the centre by the coarse angle,
# plus the refined heading.
results = []
HALF_SIZE = 512
# Corner offsets in (x, y) order: left-top, right-top, left-bottom, right-bottom.
CORNER_OFFSETS = HALF_SIZE * np.array([[-1, -1], [1, -1], [-1, 1], [1, 1]])

for center, angle, trans, rot in zip(centers, angles, delta_trans, delta_rots):
    bbox = center + CORNER_OFFSETS - trans
    bbox = rotate(bbox, center, angle)
    bbox = np.round(bbox).astype(int)
    # Wrap the refined heading back into [0, 360), the same range the coarse
    # angles use; subtracting the correction can otherwise give e.g. -3 or 362.
    refined_angle = int(np.round(angle.item() - rot)) % 360
    results.append({
        'left_top': bbox[0].tolist(),
        'right_top': bbox[1].tolist(),
        'left_bottom': bbox[2].tolist(),
        'right_bottom': bbox[3].tolist(),
        'angle': refined_angle
    })

In [17]:
import json
import os

# Write one JSON file per image, named after the image's file stem.
# Create the output directory first: open(..., 'w') does not create missing
# parent directories and would fail with FileNotFoundError on a fresh checkout.
os.makedirs('pred', exist_ok=True)
for pred, path in zip(results, img_paths):
    name = path.split('/')[-1].split('.')[0]
    with open(f'pred/{name}.json', 'w') as f:
        json.dump(pred, f)