In [None]:
import os
import time
import random
from collections import namedtuple

from tqdm import tnrange
import numpy as np
import cv2

import torch
from torch.utils import data
from torch.optim.lr_scheduler import StepLR

from utils import channel_shift
from utils import to_hls
from utils import to_bgr
from utils import read_duplicate_truth
from utils import read_image_duplicate_tiles
from utils import even_split
from utils import create_dataset_from_tiles_and_truth
from dupnet import create_loss_and_optimizer
from dupnet import save_checkpoint
from dupnet import DupCNN

%matplotlib inline
%reload_ext autoreload
%autoreload 2

# Directory that holds the per-tile image cache (created on first run if absent).
ship_dir = "data/input/"
train_256_dir = os.path.join(ship_dir, "train_256")
os.makedirs(train_256_dir, exist_ok=True)

# Full images and the two duplicate-annotation text files live directly under data/.
train_768_dir = os.path.join("data", "train_768")
duplicate_truth_file = os.path.join("data", "duplicate_truth.txt")
image_duplicate_tiles_file = os.path.join("data", "image_duplicate_tiles.txt")

# Every entry of the full-image directory is treated as an image id;
# the cell displays how many there are.
img_ids = os.listdir(train_768_dir)
len(img_ids)

## Speed Test

In [None]:
def read_from_large(img, ij):
    """Cut tile number ``ij`` out of an already-loaded BGR image.

    Args:
        img: image array as returned by ``cv2.imread`` (rows, cols, channels).
        ij: index into the notebook-level ``ij_pairs`` table, which gives the
            (row, col) grid position of the tile.

    Returns:
        The 256x256 tile converted to RGB, as float32 scaled by 1/255.
    """
    row, col = ij_pairs[ij]
    top, left = row * 256, col * 256
    bgr_tile = img[top:top + 256, left:left + 256, :]
    rgb_tile = cv2.cvtColor(bgr_tile, cv2.COLOR_BGR2RGB)
    return rgb_tile.astype(np.float32) / 255.0

def read_from_small(img_id, ij):
    """Load one pre-cut tile file from ``train_256_dir``.

    The tile file name is the image id with ``_<ij>`` inserted before the
    extension, e.g. ``abc.jpg`` and ``ij=4`` -> ``abc_4.jpg``.

    Args:
        img_id: file name of the full image the tile belongs to.
        ij: tile number used as the file-name suffix.

    Returns:
        The tile converted to RGB, as float32 scaled by 1/255.

    Raises:
        FileNotFoundError: if OpenCV cannot read the tile file.
    """
    # splitext only splits on the last dot, so ids containing extra dots
    # no longer fail the two-value unpacking that str.split('.') required.
    filebase, fileext = os.path.splitext(img_id)
    tile_path = os.path.join(train_256_dir, f'{filebase}_{ij}{fileext}')
    tile = cv2.imread(tile_path)
    if tile is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f'could not read tile: {tile_path}')
    tile = cv2.cvtColor(tile, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    return tile

def from_large(img_id, ij, kl):
    """Read the full image once and stack tiles ``ij`` and ``kl`` along the channel axis."""
    full_image = cv2.imread(os.path.join(train_768_dir, img_id))
    return np.dstack([read_from_large(full_image, tile_idx) for tile_idx in (ij, kl)])

def from_small(img_id, ij, kl):
    """Read tiles ``ij`` and ``kl`` from their own tile files and stack them along the channel axis."""
    return np.dstack([read_from_small(img_id, tile_idx) for tile_idx in (ij, kl)])

def from_both(img_id, ij, kl):
    """Stack tile ``ij`` cut from the full image with tile ``kl`` read from its tile file."""
    full_image_path = os.path.join(train_768_dir, img_id)
    large_tile = read_from_large(cv2.imread(full_image_path), ij)
    small_tile = read_from_small(img_id, kl)
    return np.dstack((large_tile, small_tile))

# Number of (image, tile pair) reads timed per strategy.
n_steps = 500

# (row, col) grid position of each tile, indexed by tile number.
# read_from_large looks this table up by name at call time.
ij_pairs = ((0, 0), (0, 1), (0, 2),
            (1, 0), (1, 1), (1, 2),
            (2, 0), (2, 1), (2, 2))

# One pair of distinct tile numbers per step; the same pairs are reused for every strategy.
ijkl = []
for _ in range(n_steps):
    ij, kl = np.random.choice(len(ij_pairs), 2, replace=False)
    ijkl.append((ij, kl))


def _time_tile_reader(label, reader):
    """Time ``reader(img_id, ij, kl)`` over ``n_steps`` random images and print the elapsed seconds.

    A fresh random sample of image ids is drawn for every call, exactly as the
    three hand-written loops did before.
    """
    img_id_list = np.random.choice(img_ids, n_steps)
    # perf_counter is monotonic and high resolution; time.time() can jump with clock adjustments.
    t0 = time.perf_counter()
    for (ij, kl), img_id in zip(ijkl, img_id_list):
        reader(img_id, ij, kl)
    print(label, time.perf_counter() - t0)


_time_tile_reader('from_large:', from_large)
_time_tile_reader('from_small:', from_small)
_time_tile_reader('from_both: ', from_both)

In [None]:
class RandomHorizontalFlip:
    """Mirror an image left-to-right with probability ``p``.

    Args:
        p (float): chance that a given call flips its input. Defaults to 0.5.
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img):
        """Return ``img`` flipped horizontally, or ``img`` itself when no flip is drawn.

        Args:
            img: image array accepted by ``cv2.flip``.

        Returns:
            The flipped copy, or the very same object that was passed in.
        """
        should_flip = np.random.random() < self.p
        # flipCode 1 mirrors around the vertical axis (left <-> right).
        return cv2.flip(img, 1) if should_flip else img

    def __repr__(self):
        return f'{type(self).__name__}(p={self.p})'


class RandomTransformC4:
    """Rotate a tensor by a random multiple of 90 degrees in the plane of dims 1 and 2.

    Args:
        with_identity (bool): when True a rotation of 0 degrees (no change) is
            one of the possible outcomes; when False only 90/180/270 are drawn.
    """

    def __init__(self, with_identity=True):
        self.with_identity = with_identity
        # Allowed numbers of quarter turns; 0 is only offered when the identity is enabled.
        first_turn = 0 if self.with_identity else 1
        self.n90s = tuple(range(first_turn, 4))

    def __call__(self, img):
        """Apply one uniformly chosen rotation from ``self.n90s``.

        Args:
            img: tensor with at least three dimensions; dims 1 and 2 are rotated.

        Returns:
            The tensor rotated by the drawn number of quarter turns.
        """
        quarter_turns = random.choice(self.n90s)
        return torch.rot90(img, quarter_turns, (1, 2))

    def __repr__(self):
        return f'{type(self).__name__}(with_identity={self.with_identity})'


In [None]:
# Maps a randomly drawn channel index to the HLS channel letter handed to the colour shift.
idx_chan_map = dict(enumerate('HLS'))

# One sampled augmentation recipe for an image pair:
#   idx1, idx2 - tile numbers for the first / second tile
#   idx3       - which tile-source strategy to use (0..3)
#   is_dup     - label, 1 for duplicate and 0 otherwise
#   chan, gain - HLS channel letter and signed shift amount
ImgAugs = namedtuple('ImgAugs', ['idx1', 'idx2', 'idx3', 'is_dup', 'chan', 'gain'])

class Dataset(data.Dataset):
    """Pairs of image tiles labelled duplicate (1) or not duplicate (0).

    Every sample is built from one ``img_overlap`` key ``(img1_id, img2_id, tag)``
    and one ``ImgAugs`` recipe.  The recipe is re-drawn on every access for
    training and drawn once in ``__init__`` when ``valid=True``.

    Where the two tiles come from (``idx3`` is drawn with p = 0.4/0.1/0.1/0.4):

        same image, is_dup=1
            idx3=0  tile cut from the full image  vs  colour-shifted copy of it
            idx3=1  tile cut from the full image  vs  tile file
            idx3=2  tile file                     vs  tile cut from the full image
            idx3=3  tile file                     vs  colour-shifted copy of it
        same image, is_dup=0   (second tile number re-drawn at random)
            idx3=0  both tiles cut from a single read of the full image
            idx3=1  full image vs tile file
            idx3=2  tile file  vs full image
            idx3=3  tile file  vs tile file
        different images, is_dup=1
            idx3=0  full image vs full image
            idx3=1  full image vs tile file
            idx3=2  tile file  vs full image
            idx3=3  tile file  vs tile file
        different images, is_dup=0
            always full image vs full image (``get_random_mapping`` never
            produces this combination, it is only reachable with a hand-made recipe)

    Args:
        img_overlaps: sequence of hashable ``(img1_id, img2_id, tag)`` keys.
        train_768_dir: directory holding the full images.
        train_256_dir: directory holding the per-tile image files.
        in_shape: stored on the instance; not otherwise used by this class.
        out_shape: stored on the instance; not otherwise used by this class.
        valid: when True, one fixed recipe per key is sampled up front so that
            every pass over the dataset returns the same samples.
        overlap_index_maps: optional mapping from key to a list of
            ``(idx1, idx2)`` tile-number pairs.  When omitted, the notebook-level
            ``img_overlap_index_maps`` is looked up at call time, which is what
            this class always did before the argument existed.
    """

    def __init__(self, img_overlaps, train_768_dir, train_256_dir,
                 in_shape=(6, 256, 256),
                 out_shape=(1,),
                 valid=False,
                 overlap_index_maps=None):
        self.img_overlaps = img_overlaps
        self.train_768_dir = train_768_dir
        self.train_256_dir = train_256_dir
        self.overlap_index_maps = overlap_index_maps
        # (row, col) grid position of each tile, indexed by tile number.
        self.ij = ((0, 0), (0, 1), (0, 2),
                   (1, 0), (1, 1), (1, 2),
                   (2, 0), (2, 1), (2, 2))

        self.in_shape = in_shape
        self.out_shape = out_shape
        # Largest absolute colour-shift gain that can be drawn for each HLS channel.
        self.hls_limits = {'H': 5, 'L': 10, 'S': 10}

        self.valid = valid
        if self.valid:
            # Freeze one recipe per key so validation samples do not change between epochs.
            self.img_augs = {img_overlap: self.get_random_mapping(img_overlap)
                             for img_overlap in self.img_overlaps}

    def __len__(self):
        """Number of ``img_overlap`` keys, i.e. samples per epoch."""
        return len(self.img_overlaps)

    def __getitem__(self, index):
        """Return ``(X, y)`` for the key at ``index`` (see ``get_data_pair``)."""
        img_overlap = self.img_overlaps[index]
        if self.valid:
            img_aug = self.img_augs[img_overlap]
        else:
            img_aug = self.get_random_mapping(img_overlap)
        return self.get_data_pair(img_overlap, img_aug)

    def _overlap_pairs(self, img_overlap):
        """Return the list of (idx1, idx2) tile-number pairs registered for ``img_overlap``."""
        maps = self.overlap_index_maps
        if maps is None:
            # Backward-compatible fallback: the mapping built in a later notebook cell.
            maps = img_overlap_index_maps
        return maps[img_overlap]

    def get_random_mapping(self, img_overlap):
        """Draw a random ``ImgAugs`` recipe for one ``img_overlap`` key.

        Pairs of different images are always labelled duplicate; pairs made of
        the same image twice are labelled duplicate with probability 0.5, and
        otherwise get a freshly drawn second tile number.
        """
        img1_id, img2_id, img1_overlap_tag = img_overlap
        # List of 2-tuples of overlapping tile numbers,
        # e.g. [(1, 0), (2, 1), (4, 3), (5, 4), (7, 6), (8, 7)]
        overlap_index_pairs = self._overlap_pairs(img_overlap)
        # Pick one of those pairs at random, e.g. position 3 -> (5, 4)
        overlap_index = np.random.choice(len(overlap_index_pairs))
        idx1, idx2 = overlap_index_pairs[overlap_index]

        if img1_id != img2_id:
            is_dup = 1
        else:
            is_dup = int(np.random.random() > 0.5)
        if not is_dup:
            # Draw two distinct candidates so there is a spare if the first equals the current idx2.
            # NOTE(review): the comparison is against idx2, not idx1 - equivalent only if
            # same-image pairs always have idx1 == idx2; confirm against the mapping builder.
            candidates = np.random.choice(len(self.ij), 2, replace=False)
            idx2 = candidates[0] if idx2 != candidates[0] else candidates[1]

        idx3 = np.random.choice(4, p=[0.4, 0.1, 0.1, 0.4])

        hls_idx = np.random.choice(3)
        hls_chan = idx_chan_map[hls_idx]
        # Magnitude in 1..limit, then a random sign.
        hls_gain = np.random.choice(self.hls_limits[hls_chan]) + 1
        hls_gain = hls_gain if np.random.random() > 0.5 else -1 * hls_gain

        return ImgAugs(idx1, idx2, idx3, is_dup, hls_chan, hls_gain)

    def color_shift(self, img, chan, gain):
        """Return ``img`` with HLS channel ``chan`` shifted by ``gain`` (BGR -> HLS -> shift -> BGR)."""
        hls = to_hls(img)
        hls_shifted = channel_shift(hls, chan, gain)
        return to_bgr(hls_shifted)

    def get_tile(self, img, ij, sz=256):
        """Slice the ``sz`` x ``sz`` tile with tile number ``ij`` out of ``img``."""
        i, j = self.ij[ij]
        return img[i * sz:(i + 1) * sz, j * sz:(j + 1) * sz, :]

    @staticmethod
    def _imread(path):
        """``cv2.imread`` that raises FileNotFoundError instead of silently returning None."""
        img = cv2.imread(path)
        if img is None:
            raise FileNotFoundError(f'could not read image: {path}')
        return img

    def _tile_path(self, img_id, ij):
        """Path of the tile file for ``img_id`` / tile ``ij``: ``<base>_<ij><ext>`` in ``train_256_dir``."""
        # splitext splits on the last dot only, so ids with extra dots are handled.
        filebase, fileext = os.path.splitext(img_id)
        return os.path.join(self.train_256_dir, f'{filebase}_{ij}{fileext}')

    def read_from_large(self, img_id, ij):
        """Read the full image ``img_id`` and return tile ``ij`` cut from it (BGR, uint8 as read)."""
        img = self._imread(os.path.join(self.train_768_dir, img_id))
        return self.get_tile(img, ij)

    def read_from_small(self, img_id, ij):
        """Read tile ``ij`` of ``img_id`` from its own tile file (BGR, as read by OpenCV)."""
        return self._imread(self._tile_path(img_id, ij))

    def _read_tile(self, img_id, ij, from_large):
        """Read one tile either from the full image or from its tile file."""
        if from_large:
            return self.read_from_large(img_id, ij)
        return self.read_from_small(img_id, ij)

    def _load_tile_pair(self, img_overlap, img_aug):
        """Load the two raw BGR tiles selected by ``img_aug`` (see the class docstring table).

        Raises:
            ValueError: if ``idx3`` is not one of 0, 1, 2, 3 in a case that uses it.
        """
        img1_id, img2_id, _ = img_overlap
        idx1, idx2, idx3, is_dup, chan, gain = img_aug
        same_image = img1_id == img2_id

        if not same_image and not is_dup:
            # idx3 is ignored for this combination.
            return (self.read_from_large(img1_id, idx1),
                    self.read_from_large(img2_id, idx2))

        if idx3 not in (0, 1, 2, 3):
            raise ValueError(f'unknown tile-source strategy idx3={idx3!r}')

        if same_image and is_dup and idx3 in (0, 3):
            # Duplicate made by colour-shifting the first tile itself.
            tile1 = self._read_tile(img1_id, idx1, from_large=(idx3 == 0))
            return tile1, self.color_shift(tile1, chan, gain)

        if same_image and idx3 == 0:
            # Non-duplicate from the same image: read the full image once for both tiles.
            img = self._imread(os.path.join(self.train_768_dir, img1_id))
            return self.get_tile(img, idx1), self.get_tile(img, idx2)

        # Remaining cases differ only in which source each tile is read from.
        tile1 = self._read_tile(img1_id, idx1, from_large=idx3 in (0, 1))
        tile2 = self._read_tile(img2_id, idx2, from_large=idx3 in (0, 2))
        return tile1, tile2

    def get_data_pair(self, img_overlap, img_aug):
        """Build one network input and label from a key and an augmentation recipe.

        Returns:
            X: float32 array of shape (6, H, W) - the two RGB tiles scaled by
               1/255 and stacked along the channel axis, channels first.
            y: float32 array of shape (1,) holding the label.
        """
        img1_id, img2_id, img1_overlap_tag = img_overlap
        idx1, idx2, idx3, is_dup, chan, gain = img_aug

        tile1, tile2 = self._load_tile_pair(img_overlap, img_aug)

        # A pair meant as a negative can still be pixel-identical; relabel it as a duplicate.
        if is_dup == 0 and np.all(tile1 == tile2):
            i, j = self.ij[idx1]
            k, l = self.ij[idx2]
            print(f'algo={idx3}; {img1_id} {idx1} -> ({i},{j}); {img2_id}, {idx2} -> ({k},{l}); {img1_overlap_tag}; correcting... {is_dup} -> 1')
            is_dup = 1

        tile1 = cv2.cvtColor(tile1, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.
        tile2 = cv2.cvtColor(tile2, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.

        X = np.dstack([tile1, tile2])
        X = X.transpose((2, 0, 1))  # HWC -> CHW
        y = np.array([is_dup], dtype=np.float32)
        return X, y

In [None]:
# CUDA for PyTorch
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
torch.backends.cudnn.benchmark=True

# Network layout handed to DupCNN in the training cell.
input_shape = (6, 256, 256)
conv_layers = (16, 32, 64, 128, 256)
fc_layers = (128,)
output_size = 1

# Parameters
split = 80
batch_size = 256
max_epochs = 30
num_workers = 12
learning_rate = 0.0001


def seed_worker(worker_id):
    """Seed NumPy and ``random`` inside a DataLoader worker from the worker's torch seed.

    The training dataset draws its augmentations from NumPy's global RNG.  On
    PyTorch versions that do not seed NumPy per worker, forked workers all
    inherit the same NumPy state and repeat the same random draws.
    torch.initial_seed() is already distinct per worker, so derive the other
    seeds from it.
    """
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


train_params = {'batch_size': batch_size,
                'shuffle': True,
                'num_workers': num_workers,
                'worker_init_fn': seed_worker}

# The validation dataset samples its augmentations once at construction time,
# so its workers do not draw random numbers and need no per-worker seeding.
valid_params = {'batch_size': batch_size,
                'shuffle': False,
                'num_workers': num_workers}

# Datasets
dup_tiles = read_image_duplicate_tiles(image_duplicate_tiles_file)
dup_truth = read_duplicate_truth(duplicate_truth_file)
img_overlap_index_maps = create_dataset_from_tiles_and_truth(dup_tiles, dup_truth)
img_overlap_index_keys = list(img_overlap_index_maps)
np.random.shuffle(img_overlap_index_keys)
n_train, n_valid = even_split(len(img_overlap_index_keys), batch_size, split)
# Take the last n_valid keys.  Written with an explicit start index because
# keys[-0:] is the whole list, which would silently validate on the training data.
# NOTE(review): this assumes even_split returns n_train + n_valid <= len(keys); confirm.
valid_start = len(img_overlap_index_keys) - n_valid
partition = {'train': img_overlap_index_keys[:n_train], 'valid': img_overlap_index_keys[valid_start:]}
print(n_train, n_valid)


# Generators
train_set = Dataset(partition['train'], train_768_dir, train_256_dir)
train_generator = data.DataLoader(train_set, **train_params)

valid_set = Dataset(partition['valid'], train_768_dir, train_256_dir, valid=True)
valid_generator = data.DataLoader(valid_set, **valid_params)
print(len(train_generator), len(valid_generator))

In [None]:
model = DupCNN(input_shape, output_size, conv_layers, fc_layers)
# Follow `device` rather than calling .cuda() unconditionally, so the cell also runs without a GPU.
model.to(device)

loss, optimizer = create_loss_and_optimizer(model, learning_rate)
scheduler = StepLR(optimizer, step_size=10, gamma=0.05)

n_batches = len(train_generator)
n_valid_batches = len(valid_generator)
best_loss = 9999.0

# Checkpoints are written under out/; make sure it exists before the first epoch finishes.
os.makedirs("out", exist_ok=True)

# Loop over epochs
for epoch in range(max_epochs):

    start_time = time.time()  # not read anywhere in this cell

    # Training
    total_train_loss = 0
    total_train_acc = 0
    model.train()
    t = tnrange(n_batches)
    train_iterator = iter(train_generator)
    for i in t:
        t.set_description(f'Epoch {epoch + 1:>3}')
        # Built-in next() works on every PyTorch version; the iterator's .next() method does not.
        inputs, labels = next(train_iterator)
        # Transfer to the selected device
        inputs, labels = inputs.to(device), labels.to(device)

        # Set the parameter gradients to zero
        optimizer.zero_grad()
        # Forward pass, backward pass, optimize
        outputs = model(inputs)
        train_loss = loss(outputs, labels)
        train_loss.backward()
        optimizer.step()

        # Running statistics, accumulated as plain Python floats
        total_train_loss += train_loss.item()

        # Threshold at 0.5 - presumably the model emits probabilities; confirm in DupCNN.
        y_pred = (outputs > 0.5).type_as(labels)
        total_train_acc += (labels == y_pred).float().mean().item()

        loss_str = f'{total_train_loss/(i + 1):.6f}'
        acc_str = f'{total_train_acc/(i + 1):.5f}'
        t.set_postfix(loss=loss_str, acc=acc_str)

    # Validation
    total_val_loss = 0
    total_val_acc = 0
    # Put the model in evaluation mode for validation;
    # model.train() at the top of the loop switches it back for the next epoch.
    model.eval()
    t = tnrange(n_valid_batches)
    valid_iterator = iter(valid_generator)
    with torch.no_grad():
        for i in t:
            t.set_description(f'Validation')
            inputs, labels = next(valid_iterator)
            # Transfer to the selected device
            inputs, labels = inputs.to(device), labels.to(device)

            # Forward pass
            val_outputs = model(inputs)
            batch_val_loss = loss(val_outputs, labels)
            total_val_loss += batch_val_loss.item()

            y_pred = (val_outputs > 0.5).type_as(labels)
            total_val_acc += (labels == y_pred).float().mean().item()

            loss_str = f'{total_val_loss/(i + 1):.6f}'
            acc_str = f'{total_val_acc/(i + 1):.5f}'
            t.set_postfix(loss=loss_str, acc=acc_str)

    # Only checkpoint when there was something to validate on; the average no longer
    # depends on a loop index that may be left over from the training loop.
    if n_valid_batches > 0:
        val_loss = total_val_loss / n_valid_batches
        if val_loss < best_loss:
            save_checkpoint(os.path.join("out", f"dup_model.{epoch + 1:03d}-{val_loss:.6f}.pth"), model)
            save_checkpoint(os.path.join("out", "dup_model.last.pth"), model)
            best_loss = val_loss

    # Advance the learning-rate schedule once per epoch, after this epoch's optimizer steps.
    # Stepping before any optimizer.step() shifts the schedule by one epoch on current PyTorch.
    scheduler.step()