# SCOT

Notebook for SCOT (Semantic Correspondence as an Optimal Transport Problem).

References:

https://openaccess.thecvf.com/content_CVPR_2020/papers/Liu_Semantic_Correspondence_as_an_Optimal_Transport_Problem_CVPR_2020_paper.pdf

https://github.com/csyanbin/SCOT

## Setup and installation

In [None]:
# Mount Google Drive at /content/drive so the project folder used in the
# following cells is reachable. `google.colab` is only importable inside a
# Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

Move to SCOT folder under the project root directory.

In [None]:
# IPython magic: change the notebook's working directory.
# NOTE(review): the later `from model import ...` / `from data import ...`
# imports presumably resolve relative to this folder - confirm.
%cd '/content/drive/MyDrive/SemesterProj/SCOT/'

Install required libraries

In [None]:
# Shell escapes installing extra packages; none of them is imported directly
# in this notebook, so they are presumably needed by the SCOT modules.
# NOTE(review): `%pip install ...` installs into the running kernel's
# environment - consider it if `!pip` targets a different interpreter.
!pip install scikit-image
!pip install pandas
!pip install requests

## Inference

### Imports and helper functions

In [None]:
from numpy.compat import integer_types
import argparse
import datetime
import os
import logging
import time
import cv2
from torch.utils.data import DataLoader
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from model import scot_CAM, geometry, evaluation, util
from data import dataset, download
from PIL import Image
import numpy as np
from itertools import product

In [None]:
# Class for accessing dict items with dot (attribute) notation
class AttributeDict(dict):
    """Dictionary whose items can also be read, set and deleted as attributes.

    ``d.key`` is equivalent to ``d['key']``. A missing key raises
    ``AttributeError`` on attribute access (not ``KeyError``), so that
    ``hasattr``, ``getattr(obj, name, default)`` and any other code probing
    for optional attributes behave as they do for regular objects.
    Item access (``d['key']``) keeps the normal ``dict`` behaviour.
    """

    def __getattr__(self, name):
        # Only invoked when regular attribute lookup fails, so dict methods
        # such as ``keys`` or ``items`` are never shadowed by stored keys.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(
                "{!r} object has no attribute {!r}".format(type(self).__name__, name)
            ) from None

    # Attribute assignment always stores an item in the dictionary.
    __setattr__ = dict.__setitem__

    def __delattr__(self, name):
        try:
            del self[name]
        except KeyError:
            raise AttributeError(
                "{!r} object has no attribute {!r}".format(type(self).__name__, name)
            ) from None


def threshold_img(img_hsv, int_val=False):
    """
    Build a categorical label map from an HSV image by colour thresholding.

    A pixel is assigned to a colour class when ``cv2.inRange`` places it
    inside that class's HSV box: saturation and value between 50 and 255,
    and hue in 0-5 or 175-180 (red), 115-125 (blue) or 55-65 (green).
    Every other pixel is background.

    Args:
        img_hsv: HSV image; its first two dimensions define the output
            height and width.
        int_val: if True, return an (H, W) map with integer codes
            (0 background, 1 red, 2 blue, 3 green). Otherwise return an
            (H, W, 4) one-hot map with channels ordered
            (background, red, green, blue).

    Returns:
        The label map as a ``np.uint8`` array.
    """
    def hue_band(hue_min, hue_max):
        # 255 where the pixel lies inside the hue band, 0 elsewhere.
        return cv2.inRange(img_hsv,
                           np.array([hue_min, 50, 50]),
                           np.array([hue_max, 255, 255]))

    blue_mask = hue_band(115, 125)
    green_mask = hue_band(55, 65)
    # Red sits at both ends of the hue axis, so it is the sum of two
    # bands; the bands do not overlap.
    red_mask = hue_band(0, 5) + hue_band(175, 180)

    height, width = img_hsv.shape[0], img_hsv.shape[1]

    if int_val:
        labels = np.zeros((height, width))
        for code, mask in ((1, red_mask), (2, blue_mask), (3, green_mask)):
            labels[mask == 255] = code
    else:
        labels = np.zeros((height, width, 4))
        # Everything is background until a colour mask claims the pixel.
        labels[..., 0] = 1
        for channel, mask in ((1, red_mask), (3, blue_mask), (2, green_mask)):
            one_hot = [0, 0, 0, 0]
            one_hot[channel] = 1
            labels[mask == 255] = one_hot

    return labels.astype(np.uint8)


def search_rect_space(x, y, map_w, map_h, search_w, search_h):
    """
    List the grid cells of a rectangular window around the point (x, y).

    ``x`` is treated as the row index and ``y`` as the column index. The
    window spans ``search_h // 2`` rows and ``search_w // 2`` columns on
    each side of the point (the upper bound is exclusive) and is clipped
    to the map, i.e. rows stay in ``[0, map_h)`` and columns in
    ``[0, map_w)``.

    Returns:
        A list of ``(row, col)`` tuples in row-major order.
    """
    half_rows = search_h // 2
    half_cols = search_w // 2

    # Clip the window to the map borders.
    rows = range(max(x - half_rows, 0), min(x + half_rows, map_h))
    cols = range(max(y - half_cols, 0), min(y + half_cols, map_w))

    return [(row, col) for row in rows for col in cols]


def change_label_to_onehot(label, new_h, new_w):
    """
    Convert an integer label map into a 4-channel one-hot encoding.

    Args:
        label: array of integer class ids; ids 1, 2 and 3 select channels
            1, 2 and 3 respectively, any other value is background.
        new_h: height of the returned map.
        new_w: width of the returned map.

    Returns:
        A ``(new_h, new_w, 4)`` ``np.uint8`` array with exactly one channel
        set to 1 per pixel; channel 0 is the background.
    """
    # Row k of the identity matrix is the one-hot vector of class k.
    class_vectors = np.eye(4, dtype=np.uint8)

    onehot = np.empty((new_h, new_w, 4), dtype=np.uint8)
    onehot[...] = class_vectors[0]
    for class_id in (1, 2, 3):
        onehot[np.nonzero(label == class_id)] = class_vectors[class_id]

    return onehot

### Inference


In [None]:
def run(datapath, ann_dir, img_dir, drone_dir, benchmark, backbone, thres, alpha, hyperpixel,
        logpath, args, beamsearch=False, model=None, dataloader=None,
        weight_type='average_channel', output_path='./', side_thres=300, show=False,
        interp='nn', search_dim_w='50'):
    """Runs Semantic Correspondence as an Optimal Transport Problem.

    For every image pair yielded by the dataloader, the SCOT model produces
    a confidence matrix between the hyperpixels of the two images. That
    matrix is used to warp the label image found under
    ``dset.label_img_path`` and the warped label is saved as a palettised
    image in ``output_path`` under the same file name.

    Args:
        datapath: dataset root, forwarded to the dataset loader.
        ann_dir, img_dir, drone_dir: directory names forwarded to
            ``download.load_dataset``.
        benchmark: dataset type; anything other than ``'drone'`` triggers
            ``download.download_dataset`` first.
        backbone: backbone name forwarded to the model.
        thres: thresholding type forwarded to ``download.load_dataset``.
        alpha: stored in ``data['alpha']`` for every pair.
        hyperpixel: hyperpixel layer description; given to ``SCOT_CAM`` or,
            when ``model`` is supplied, parsed with ``util.parse_hyperpixel``.
        logpath: not used by this function.
        args: object exposing ``split``, ``cam``, ``sim``, ``exp1``,
            ``exp2``, ``eps`` and ``classmap`` as attributes.
        beamsearch: not used by this function.
        model: optional pre-built model; built here when ``None``.
        dataloader: optional pre-built ``DataLoader``; when given, its
            ``dataset`` attribute must expose ``label_img_path`` (and
            ``src_img_path`` if ``show`` is set).
        weight_type: warping strategy. Must contain ``'average'`` or
            ``'max'``; ``'average_filtered'`` / ``'max_filtered'`` restrict
            the matching to a window around each position.
        output_path: directory for the warped labels (created if missing).
        side_thres: forwarded to ``util.resize`` as ``side_thres``.
        show: if True, display an overlay of each saved label on the
            corresponding ``.jpg`` image from ``dset.src_img_path``.
        interp: ``'nn'`` resizes labels with nearest-neighbour
            interpolation; any other value resizes colour images and
            recovers the classes by HSV thresholding.
        search_dim_w: width of the search window for the filtered modes;
            an int or a numeric string.

    Raises:
        ValueError: if ``weight_type`` contains neither ``'average'`` nor
            ``'max'``.
    """
    # Fail before any expensive work: without one of these keywords no
    # warped label would ever be produced.
    if 'average' not in weight_type and 'max' not in weight_type:
        raise ValueError(
            "weight_type must contain 'average' or 'max', got {!r}".format(weight_type))

    # The signature default is the string '50'; coerce so the window
    # arithmetic below works for both ints and numeric strings.
    use_window = 'filtered' in weight_type
    search_w = int(search_dim_w) if use_window else None

    # Create output_path if does not exists
    output_path = os.path.expanduser(output_path)
    os.makedirs(output_path, mode=0o777, exist_ok=True)

    # Evaluation benchmark initialization
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if dataloader is None:
        if benchmark != 'drone':
            download.download_dataset(os.path.abspath(datapath), benchmark)
        dset = download.load_dataset(benchmark, datapath, thres, device, args.split, args.cam,
                                     ann_dir, img_dir, drone_dir)
        dataloader = DataLoader(dset, batch_size=1, num_workers=0)
    else:
        # The label/image directories are read from the dataset below, so
        # recover it from a caller-supplied dataloader.
        dset = dataloader.dataset

    # Model initialization
    if model is None:
        model = scot_CAM.SCOT_CAM(backbone, hyperpixel, benchmark, device, args.cam)
    else:
        model.hyperpixel_ids = util.parse_hyperpixel(hyperpixel)

    for idx, data in enumerate(dataloader):
        # Retrieve images and adjust their sizes to avoid large numbers of hyperpixels
        data['src_img'], data['src_kps'], data['src_intratio'] = util.resize(data['src_img'], data['src_kps'][0], side_thres=side_thres)
        data['trg_img'], data['trg_kps'], data['trg_intratio'] = util.resize(data['trg_img'], data['trg_kps'][0], side_thres=side_thres)
        data['src_mask'] = None
        data['trg_mask'] = None
        data['alpha'] = alpha

        label_name = data['label_imname'][0]
        label_path = os.path.join(dset.label_img_path, label_name)

        # Feed a pair of images to Hyperpixel Flow model
        with torch.no_grad():
            # Only the confidence matrix and the hyperpixel grid size are
            # needed here; the returned boxes are ignored.
            confidence_ts, _, _, hyper_h, hyper_w = model(data['src_img'], data['trg_img'], args.sim, args.exp1, args.exp2, args.eps, args.classmap, data['src_bbox'], data['trg_bbox'], data['src_mask'], data['trg_mask'], backbone)

            # Read satellite label and transform according to interpolation method
            if interp == 'nn':
                label = Image.open(label_path)
                orig_h, orig_w = label.size[1], label.size[0]
                label = np.array(label.resize((hyper_w, hyper_h), Image.NEAREST))
                label_int = change_label_to_onehot(label, hyper_h, hyper_w)
            else:
                # RGB -> BGR so the OpenCV colour conversion sees its native order
                label = np.array(Image.open(label_path).convert("RGB"))[:, :, ::-1]
                orig_h, orig_w = label.shape[0], label.shape[1]
                label = cv2.resize(label, (hyper_w, hyper_h), interpolation=cv2.INTER_AREA)
                label_hsv = cv2.cvtColor(label, cv2.COLOR_BGR2HSV)
                label_int = threshold_img(label_hsv)

            # Reshape label: one row of class scores per grid position
            label_h, label_w = label_int.shape[0], label_int.shape[1]
            old_label = label_int.reshape((label_h * label_w, 4))

            # Build map of 2d indexes to 1d indexes for confidence matrix.
            # NOTE(review): rows and columns of the confidence matrix are both
            # addressed through this label grid, which assumes both images
            # yield a hyper_h x hyper_w grid - confirm against the model.
            map_two_one = {(i // label_w, i % label_w): i
                           for i in range(confidence_ts.size()[0])}

            # Define windows size for filtered search space in confidence matrix
            if use_window:
                size_ratio = label_w / label_h
                search_h = int(search_w / size_ratio)

            # Create warped image by averaging
            if 'average' in weight_type:
                if 'average_filtered' in weight_type:
                    confidence_ts = confidence_ts.detach().cpu().numpy()
                    weighted_label = np.zeros((confidence_ts.shape[0], old_label.shape[1]))

                    for i in range(label_h):
                        for j in range(label_w):
                            row = map_two_one[(i, j)]
                            # Define list of points to be searched and transform to 1d coordinates
                            search_points_2d = search_rect_space(i, j, map_w=label_w, map_h=label_h, search_w=search_w, search_h=search_h)
                            points_1d_coord = [map_two_one[p] for p in search_points_2d]

                            # Normalize confidence map row of interested points, to sum to 1
                            window_conf = confidence_ts[row, points_1d_coord]
                            coef_conf = 1 / window_conf.sum()
                            row_confidence_ts = coef_conf * window_conf
                            # Compute new weighted label
                            weighted_label[row] = np.dot(row_confidence_ts[np.newaxis, ...], old_label[points_1d_coord, :])

                else:
                    # Normalize confidence map rows, to sum to 1
                    coef_conf = 1 / confidence_ts.sum(1, keepdim=True)
                    confidence_ts = (coef_conf * confidence_ts).detach().cpu().numpy()
                    # Compute new weighted label
                    weighted_label = np.dot(confidence_ts, old_label)

                # Take the class with the largest weighted score
                new_label = np.argmax(weighted_label, axis=1).reshape((label_h, label_w))

                # Change drone label format
                drone_label = change_label_to_onehot(new_label, new_label.shape[0], new_label.shape[1])

            # Create warped image by taking the maximum
            else:
                drone_label = np.zeros_like(label_int)

                if 'max_filtered' in weight_type:
                    confidence_ts = confidence_ts.detach().cpu().numpy()

                    for i in range(label_h):
                        for j in range(label_w):
                            # Define list of points to be searched and transform to 1d coordinates
                            search_points_2d = search_rect_space(i, j, map_w=label_w, map_h=label_h, search_w=search_w, search_h=search_h)
                            points_1d_coord = [map_two_one[p] for p in search_points_2d]

                            # Take the maximum from possible points and compute the new label
                            trg_idx = np.argsort(confidence_ts[map_two_one[(i, j)], points_1d_coord])[::-1][0]
                            trg_idx = points_1d_coord[trg_idx]
                            drone_label[i, j] = label_int[trg_idx // label_w, trg_idx % label_w]

                else:
                    # Take the maximum and compute new label
                    _, trg_indices = torch.max(confidence_ts, dim=1)
                    trg_indices = trg_indices.detach().cpu().numpy()
                    for i, index in enumerate(trg_indices):
                        drone_label[i // label_w, i % label_w] = label_int[index // label_w, index % label_w]

            drone_label = drone_label.astype(np.uint8)

            # Resize back to original size according to interpolation method
            if interp == 'nn':
                drone_label_int = np.zeros((drone_label.shape[0], drone_label.shape[1]))
                drone_label_int[np.where(drone_label[..., 1] == 1)] = 1
                drone_label_int[np.where(drone_label[..., 2] == 1)] = 2
                drone_label_int[np.where(drone_label[..., 3] == 1)] = 3
                drone_label = Image.fromarray(drone_label_int.astype(np.uint8))
                output_label = np.array(drone_label.resize((orig_w, orig_h), Image.NEAREST))
            else:
                # Paint the classes as pure colours so that, after the cubic
                # upsampling, they can be recovered by HSV thresholding.
                drone_label_rgb = np.zeros((drone_label.shape[0], drone_label.shape[1], 3))
                drone_label_rgb[np.where(drone_label[..., 1] == 1)] = [255, 0, 0]
                drone_label_rgb[np.where(drone_label[..., 2] == 1)] = [0, 255, 0]
                drone_label_rgb[np.where(drone_label[..., 3] == 1)] = [0, 0, 255]
                drone_label_rgb = drone_label_rgb.astype(np.uint8)

                drone_label = cv2.resize(drone_label_rgb, (orig_w, orig_h), interpolation=cv2.INTER_CUBIC)
                drone_label = drone_label[:, :, ::-1]
                drone_label_hsv = cv2.cvtColor(drone_label, cv2.COLOR_BGR2HSV)
                output_label = threshold_img(drone_label_hsv, int_val=True)

            # Convert warped label to 'P' mode, put palette and save
            seg_img = Image.fromarray(output_label).convert('P')
            seg_img.putpalette(np.array([[0, 0, 0], [255, 0, 0], [0, 0, 255], [0, 255, 0]], dtype=np.uint8))
            seg_img.save(os.path.join(output_path, label_name))

            if show:
                saved_label = np.array(Image.open(os.path.join(output_path, label_name)).convert("RGB"))
                drone_img = np.array(Image.open(os.path.join(dset.src_img_path, label_name.split('.')[0] + '.jpg')))
                overlay = cv2.addWeighted(drone_img, 0.6, saved_label, 0.7, 0)
                plt.imshow(overlay)
                # Flush the figure now; otherwise every pair is drawn onto the
                # same axes and only the last overlay is ever displayed.
                plt.show()

        print("IDX is {}".format(idx))


Configuration and running

In [None]:
# Experiment configuration. Every key is later read as an attribute
# (args.<key>), either below or inside run().
args = AttributeDict(
    # GPU id, exported through CUDA_VISIBLE_DEVICES
    gpu='0',
    # Dataset path
    datapath='../data/120m/',
    # Satellite labels directory
    ann_dir='ann_dir',
    # Satellite images directory
    img_dir='img_dir',
    # Drone images directory
    drone_dir='drone_dir',
    # Split to use: 'train' or 'test' - sequentially given for full dataset generation
    split='test',
    # Dataset type, passed to run() as `benchmark`
    dataset='drone',
    # Backbone type
    backbone='resnet101',
    # Thresholding type
    thres='img',
    # Stored by run() in data['alpha'] for every image pair
    alpha=0.05,
    # Hyperpixel layer ids, given to the model as a string
    hyperpixel='(2,4,5,18,19,20,24,32)',
    # Log path (not used by run())
    logpath='',
    # Similarity type: OT, cos, OTGeo, cosGeo (the last 2 use RHM)
    sim='OTGeo',
    # Exponential factor on initial cosine cost (default)
    exp1=1.0,
    # Exponential factor on final OT scores (default)
    exp2=0.5,
    # Epsilon for Sinkhorn Regularization (default)
    eps=0.05,
    # Whether to use(1) classmap or not(0) - we skip classmap here
    classmap=0,
    # Activation map folder, empty for end2end computation
    cam='',
    # Method to compute warped satellite labels: 'max', 'max_filtered', 'average' or 'average_filtered'
    weight_type='max',
    # Passed to util.resize as side_thres (target size of the biggest image side)
    side_thres=480,
    # Output path for warped images; placeholder, overwritten below
    output_path='',
    # Interpolation type when resizing: 'nn' selects nearest-neighbour;
    # any other value (here 'threshold_csv') selects the HSV-threshold path
    interp='threshold_csv',
    # Window width of the search space, only used by the 'filtered'
    # variants of weight_type when constraining the reconstruction
    search_dim_w=50,
)

# Restrict the visible CUDA devices to the configured GPU.
os.environ['CUDA_VISIBLE_DEVICES'] = args['gpu']

# Warped labels go to a sub-folder of the dataset named after the run settings.
warped_dirname = 'scot_warped_{}_{}'.format(args['side_thres'], args['weight_type'])
args['output_path'] = os.path.join(args['datapath'], warped_dirname)

# Call SCOT function
run(
    datapath=args.datapath,
    ann_dir=args.ann_dir,
    img_dir=args.img_dir,
    drone_dir=args.drone_dir,
    benchmark=args.dataset,
    backbone=args.backbone,
    thres=args.thres,
    alpha=args.alpha,
    hyperpixel=args.hyperpixel,
    logpath=args.logpath,
    args=args,
    beamsearch=False,
    weight_type=args.weight_type,
    output_path=args.output_path,
    side_thres=args.side_thres,
    show=True,
    interp=args.interp,
    search_dim_w=args.search_dim_w,
)