In [1]:
import os
import time
from collections import Counter
from collections import OrderedDict
from collections import defaultdict
from collections import namedtuple

from tqdm import tnrange, tqdm_notebook
import numpy as np
import cv2

import torch
from torch.utils import data
from torch.optim.lr_scheduler import StepLR

from utils import channel_shift
from utils import to_hls
from utils import to_bgr
from utils import read_image_duplicate_tiles
from utils import even_split
from dupnet import create_loss_and_optimizer
from dupnet import save_checkpoint
from dupnet import DupCNN

%matplotlib inline
%reload_ext autoreload
%autoreload 2

# Directory layout used by the rest of the notebook.
# NOTE(review): the 768px sources are read from "data/train_768" while the
# 256px tiles are read from "data/input/train_256" -- confirm that this
# asymmetry is intentional.
ship_dir = "data/input/"
train_768_dir = os.path.join("data", "train_768")
train_256_dir = os.path.join(ship_dir, "train_256")
image_duplicate_tiles_file = os.path.join("data", "image_duplicate_tiles.txt")

# Make sure the tile directory exists before any cell tries to read from it.
os.makedirs(train_256_dir, exist_ok=True)

# os.listdir returns entries in arbitrary, platform-dependent order; sort them
# so that img_ids (and anything derived from its order) is stable across runs.
img_ids = sorted(os.listdir(train_768_dir))
len(img_ids)

192555

## Speed Test

In [2]:
def read_from_large(img, i, j):
    """Cut tile (i, j) out of an already-decoded image and normalise it.

    Args:
        img: Decoded image array indexed as [row, col, channel]; treated as
            BGR by the colour conversion below.
        i: Tile row index (each tile is 256 pixels tall).
        j: Tile column index (each tile is 256 pixels wide).

    Returns:
        The tile as a float32 RGB array whose values are the input values
        divided by 255.
    """
    rows = slice(i * 256, (i + 1) * 256)
    cols = slice(j * 256, (j + 1) * 256)
    bgr_tile = img[rows, cols, :]
    rgb_tile = cv2.cvtColor(bgr_tile, cv2.COLOR_BGR2RGB)
    return rgb_tile.astype(np.float32) / 255.0

def read_from_small(img_id, i, j):
    """Load the pre-cut tile (i, j) of ``img_id`` from ``train_256_dir``.

    The tile file name is built by inserting ``_{i}{j}`` between the base name
    and the extension of ``img_id`` (e.g. ``abc.jpg`` -> ``abc_01.jpg``).

    Args:
        img_id: File name of the source image, including its extension.
        i: Tile row index.
        j: Tile column index.

    Returns:
        The tile as a float32 RGB array whose values are the decoded values
        divided by 255.

    Raises:
        FileNotFoundError: If the tile file is missing or cannot be decoded.
    """
    # splitext only splits on the last dot, so names containing extra dots
    # (or none at all) no longer break the unpacking.
    filebase, fileext = os.path.splitext(img_id)
    tile_path = os.path.join(train_256_dir, f'{filebase}_{i}{j}{fileext}')
    tile = cv2.imread(tile_path)
    if tile is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f'Could not read tile image: {tile_path}')
    tile = cv2.cvtColor(tile, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    return tile

def from_large(img_id, i, j, k, l):
    """Build a tile pair by decoding the full image once and slicing twice.

    Args:
        img_id: File name of the image inside ``train_768_dir``.
        i, j: Row/column index of the first tile.
        k, l: Row/column index of the second tile.

    Returns:
        Both normalised tiles stacked along the channel axis.
    """
    full_img = cv2.imread(os.path.join(train_768_dir, img_id))
    tiles = [read_from_large(full_img, row, col)
             for row, col in ((i, j), (k, l))]
    return np.dstack(tiles)

def from_small(img_id, i, j, k, l):
    """Build a tile pair by loading both tiles from their pre-cut files.

    Args:
        img_id: File name of the source image the tiles were cut from.
        i, j: Row/column index of the first tile.
        k, l: Row/column index of the second tile.

    Returns:
        Both normalised tiles stacked along the channel axis.
    """
    tiles = [read_from_small(img_id, row, col)
             for row, col in ((i, j), (k, l))]
    return np.dstack(tiles)

def from_both(img_id, i, j, k, l):
    """Build a tile pair with one tile from the full image and one pre-cut.

    Args:
        img_id: File name of the image inside ``train_768_dir``.
        i, j: Row/column index of the tile sliced from the full image.
        k, l: Row/column index of the tile loaded from its pre-cut file.

    Returns:
        Both normalised tiles stacked along the channel axis.
    """
    full_img = cv2.imread(os.path.join(train_768_dir, img_id))
    sliced_tile = read_from_large(full_img, i, j)
    stored_tile = read_from_small(img_id, k, l)
    return np.dstack((sliced_tile, stored_tile))

n_steps = 500

# All (row, col) positions of the 3x3 tile grid, in row-major order.
ij_pairs = [(row, col) for row in range(3) for col in range(3)]

# Draw n_steps random pairs of distinct tile positions, flattened to
# (i, j, k, l). The same position pairs are reused for every reader below.
ijkl = []
for _step in range(n_steps):
    ij, kl = np.random.choice(len(ij_pairs), 2, replace=False)
    ijkl.append(ij_pairs[ij] + ij_pairs[kl])


def time_reader(reader, label):
    """Time ``reader`` over ``n_steps`` randomly chosen images and print it.

    A fresh random sample of image ids is drawn on every call, so each reader
    is timed on its own set of files.

    Args:
        reader: Callable taking ``(img_id, i, j, k, l)``.
        label: Text printed in front of the elapsed seconds.
    """
    img_id_list = np.random.choice(img_ids, n_steps)
    # perf_counter is a monotonic, high-resolution clock meant for
    # measuring short durations.
    t0 = time.perf_counter()
    for (i, j, k, l), img_id in zip(ijkl, img_id_list):
        _ = reader(img_id, i, j, k, l)
    print(label, time.perf_counter() - t0)


time_reader(from_large, 'from_large:')
time_reader(from_small, 'from_small:')
time_reader(from_both, 'from_both: ')

from_large: 3.5697109699249268
from_small: 1.8689689636230469
from_both:  4.278532028198242


In [5]:
# Lookup used by Dataset.get_random_mapping as dup_tiles[img_id][tile_index];
# two tile indices of an image compare equal there when the tiles are an
# exact match.
dup_tiles = read_image_duplicate_tiles(image_duplicate_tiles_file)

# Maps a randomly drawn channel index to its HLS channel letter.
idx_chan_map = dict(enumerate('HLS'))

# One sampled augmentation recipe:
#   idx0 / idx1 - pair type and variant within that type
#   ij / kl     - indices of the two tile positions
#   chan / gain - HLS channel letter and signed shift amount
ImgAugs = namedtuple('ImgAugs', ['idx0', 'idx1', 'ij', 'kl', 'chan', 'gain'])

class Dataset(data.Dataset):
    """PyTorch dataset of tile pairs labelled duplicate (1) or not (0).

    Each sample is built from one image id: two tiles are read (from the full
    image and/or from the pre-cut tile files), optionally colour-shifted, and
    stacked along the channel axis. In validation mode the random recipe for
    every image is drawn once at construction time and then reused.
    """

    def __init__(self, img_ids, train_768_dir, train_256_dir,
                 in_shape=(6, 256, 256),
                 out_shape=(1,),
                 valid=False):
        """Store configuration and, for validation, pre-draw all recipes.

        Args:
            img_ids: Sequence of image file names (with extension).
            train_768_dir: Directory holding the full-size images.
            train_256_dir: Directory holding the pre-cut tile files.
            in_shape: Stored on the instance; not otherwise used by this class.
            out_shape: Stored on the instance; not otherwise used by this class.
            valid: When True, fix one random recipe per image id up front.
        """
        self.img_ids = img_ids
        self.train_768_dir = train_768_dir
        self.train_256_dir = train_256_dir
        # (row, col) positions of the 3x3 tile grid, addressed by flat index.
        self.ij = [(0, 0), (0, 1), (0, 2),
                   (1, 0), (1, 1), (1, 2),
                   (2, 0), (2, 1), (2, 2)]

        self.in_shape = in_shape
        self.out_shape = out_shape
        # Maximum absolute shift that may be drawn for each HLS channel.
        self.hls_limits = {'H': 5, 'L': 15, 'S': 15}

        self.valid = valid
        if self.valid:
            # Same image order (and therefore same random-call order) as a
            # plain loop over img_ids.
            self.img_augs = {img_id: self.get_random_mapping(img_id)
                             for img_id in self.img_ids}

    def __len__(self):
        """Return the number of image ids in the dataset."""
        return len(self.img_ids)

    def __getitem__(self, index):
        """Return the ``(X, y)`` sample built for the image at ``index``."""
        img_id = self.img_ids[index]
        if self.valid:
            img_aug = self.img_augs[img_id]
        else:
            img_aug = self.get_random_mapping(img_id)
        X, y = self.get_data_pair(img_id, img_aug)
        return X, y

    # Example of the correction message printed by get_data_pair when two
    # distinct tile positions turn out to be pixel-identical:
    # 001bfb70a.jpg, 1, 3, (1, 0), (2, 0)

    def get_random_mapping(self, img_id):
        """Draw a random ``ImgAugs`` recipe for ``img_id``.

        Fields of the returned recipe:
            idx0: Pair type -- 0 duplicate pair, 1 two tiles of the same
                image, 2 tiles of two different images (currently drawn with
                probability 0.0). Forced to 0 when ``dup_tiles`` marks the
                two chosen tiles as an exact match.
            idx1: Variant (0-3) within the pair type.
            ij, kl: Two distinct flat indices into ``self.ij``.
            chan: HLS channel letter to shift.
            gain: Signed shift; magnitude in 1..``hls_limits[chan]``.
        """
        ij, kl = np.random.choice(len(self.ij), 2, replace=False)
        if dup_tiles[img_id][ij] == dup_tiles[img_id][kl]:
            # Tiles are exact match
            idx0 = 0
        else:
            idx0 = np.random.choice(3, p=[0.5, 0.5, 0.0])
        idx1 = np.random.choice(4, p=[0.25, 0.25, 0.25, 0.25])
        hls_idx = np.random.choice(3)
        hls_chan = idx_chan_map[hls_idx]
        hls_gain = np.random.choice(self.hls_limits[hls_chan]) + 1
        hls_gain = hls_gain if np.random.random() > 0.5 else -1 * hls_gain

        return ImgAugs(idx0, idx1, ij, kl, hls_chan, hls_gain)

    def color_shift(self, img, chan, gain):
        """Shift one HLS channel of ``img`` by ``gain`` and return BGR.

        Delegates to the ``to_hls`` / ``channel_shift`` / ``to_bgr`` helpers;
        their exact clipping/wrapping behaviour is not visible here.
        """
        hls = to_hls(img)
        hls_shifted = channel_shift(hls, chan, gain)
        return to_bgr(hls_shifted)

    def get_tile(self, img, i, j, sz=256):
        """Return the ``sz`` x ``sz`` tile at grid position (i, j) of ``img``."""
        return img[i * sz:(i + 1) * sz, j * sz:(j + 1) * sz, :]

    @staticmethod
    def _imread(path):
        """Read an image with OpenCV, raising instead of returning None."""
        img = cv2.imread(path)
        if img is None:
            # cv2.imread signals failure by returning None; fail here with the
            # offending path rather than later with an opaque error.
            raise FileNotFoundError(f'Could not read image: {path}')
        return img

    def _tile_path(self, img_id, i, j):
        """Return the path of the pre-cut tile (i, j) belonging to ``img_id``."""
        # splitext splits on the last dot only, so names with extra dots (or
        # without an extension) are handled.
        filebase, fileext = os.path.splitext(img_id)
        return os.path.join(self.train_256_dir, f'{filebase}_{i}{j}{fileext}')

    def read_from_large(self, img_id, i, j):
        """Decode the full image of ``img_id`` and return its (i, j) tile.

        Raises:
            FileNotFoundError: If the image cannot be read.
        """
        img = self._imread(os.path.join(self.train_768_dir, img_id))
        return self.get_tile(img, i, j)

    def read_from_small(self, img_id, i, j):
        """Load the pre-cut tile (i, j) of ``img_id`` from ``train_256_dir``.

        Raises:
            FileNotFoundError: If the tile file cannot be read.
        """
        return self._imread(self._tile_path(img_id, i, j))

    def get_data_pair(self, img_id, img_aug):
        """Build the input array and label for one image and recipe.

        Args:
            img_id: File name of the (first) source image.
            img_aug: ``ImgAugs`` recipe as produced by ``get_random_mapping``.

        Returns:
            Tuple ``(X, y)``: ``X`` is a float32 array with both RGB tiles
            stacked on the channel axis and moved channel-first; ``y`` is a
            float32 array of shape (1,) holding 1 for a duplicate pair, else 0.

        Raises:
            ValueError: If ``idx0`` or ``idx1`` is outside the handled range.
            FileNotFoundError: If a required image or tile cannot be read.
        """

        # same img_id (img_id1 == img_id2), same tile (ij == kl)
            # img_m[i,j], img_m[i,j], 1, exact
            # img_m[i,j], tile_m_ij, 1, fuzzy
            # tile_m_ij, img_m[i,j], 1, fuzzy
            # tile_m_ij, tile_m_ij, 1, exact
        # same img_id (img_id1 == img_id2), diff tile (ij != kl)
            # img_m[i,j], img_m[k,l], 0, similar but different
            # img_m[i,j], tile_m_kl, 0, similar but different
            # tile_m_ij, img_m[k,l], 0, similar but different
            # tile_m_ij, tile_m_kl, 0, similar but different
        # diff img_id (img_id1 != img_id2), same tile (ij == kl)
            # img_m[i,j], img_n[i,j], 0, very different
            # img_m[i,j], tile_n_ij, 0, very different
            # tile_m_ij, img_n[i,j], 0, very different
            # tile_m_ij, tile_n_ij, 0, very different
        # diff img_id (img_id1 != img_id2), diff tile (ij != kl)
            # img_m[i,j], img_n[k,l], 0, very different
            # img_m[i,j], tile_n_kl, 0, very different
            # tile_m_ij, img_n[k,l], 0, very different
            # tile_m_ij, tile_n_kl, 0, very different

        idx0, idx1, ij, kl, chan, gain = img_aug
        i, j = self.ij[ij]
        k, l = self.ij[kl]

        if idx0 == 0:
            # Duplicate pair: the same tile position twice, differing only by
            # a colour shift or by which file it was read from.

            if idx1 == 0:
                tile1 = self.read_from_large(img_id, i, j)
                tile2 = self.color_shift(tile1, chan, gain)
            elif idx1 == 1:
                tile1 = self.read_from_large(img_id, i, j)
                tile2 = self.read_from_small(img_id, i, j)
            elif idx1 == 2:
                tile1 = self.read_from_small(img_id, i, j)
                tile2 = self.read_from_large(img_id, i, j)
            elif idx1 == 3:
                tile1 = self.read_from_small(img_id, i, j)
                tile2 = self.color_shift(tile1, chan, gain)
            else:
                raise ValueError

            y = 1

        elif idx0 == 1:
            # Two different tile positions of the same image.

            # These 4 have pretty much the same effect.
            # The last one is the fastest.
#             idx1 = np.random.choice(4, p=[0.1, 0.1, 0.1, 1.7])
            idx1 = 3
            if idx1 == 0:
                img = self._imread(os.path.join(self.train_768_dir, img_id))
                tile1 = self.get_tile(img, i, j)
                tile2 = self.get_tile(img, k, l)
            elif idx1 == 1:
                tile1 = self.read_from_large(img_id, i, j)
                tile2 = self.read_from_small(img_id, k, l)
            elif idx1 == 2:
                tile1 = self.read_from_small(img_id, i, j)
                tile2 = self.read_from_large(img_id, k, l)
            elif idx1 == 3:
                tile1 = self.read_from_small(img_id, i, j)
                tile2 = self.read_from_small(img_id, k, l)
            else:
                raise ValueError

            y = 0

        elif idx0 == 2:
            # Tiles from two images; the second image is drawn at random here,
            # so this branch is not reproducible even in validation mode.

            img_id2 = np.random.choice(self.img_ids)

            idx1 = np.random.choice(2)
            if idx1 == 0:
                tile1 = self.read_from_small(img_id, i, j)
                tile2 = self.read_from_small(img_id2, i, j)
            elif idx1 == 1:
                tile1 = self.read_from_small(img_id, i, j)
                tile2 = self.read_from_small(img_id2, k, l)
            else:
                raise ValueError

            y = 0

        else:
            raise ValueError

        # A pair meant to be "different" can still be pixel-identical;
        # relabel it as a duplicate so the target is not contradictory.
        if y == 0 and np.all(tile1 == tile2):
            print(f'{img_id}, {idx0}, {idx1}, ({ij} {kl}), (({i}, {j}), ({k}, {l})) correcting...')
            y = 1

        tile1 = cv2.cvtColor(tile1, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.
        tile2 = cv2.cvtColor(tile2, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.

        X = np.dstack([tile1, tile2])
#         X = tile1 - tile2
        # Move channels first: (H, W, C) -> (C, H, W).
        X = X.transpose((2, 0, 1))
        y = np.array([y], dtype=np.float32)
        return X, y

In [6]:
# CUDA for PyTorch
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
torch.backends.cudnn.benchmark = True

# Model architecture (passed to DupCNN in the training cell).
input_shape = (6, 256, 256)
conv_layers = (16, 32, 64, 128, 256)
fc_layers = (128,)
output_size = 1

# Parameters
split = 80
batch_size = 256
max_epochs = 50
num_workers = 12
learning_rate = 0.0001


def seed_worker(worker_id):
    """Give each DataLoader worker its own NumPy random stream.

    The dataset draws its augmentations from ``np.random``. Inside a worker,
    ``torch.initial_seed()`` is distinct per worker, so deriving the NumPy
    seed from it keeps workers from replaying identical random numbers on
    PyTorch versions that do not seed NumPy in workers themselves.
    """
    np.random.seed(torch.initial_seed() % 2 ** 32)


train_params = {'batch_size': batch_size,
                'shuffle': True,
                'num_workers': num_workers,
                'worker_init_fn': seed_worker}

valid_params = {'batch_size': batch_size,
                'shuffle': False,
                'num_workers': num_workers,
                'worker_init_fn': seed_worker}
# Datasets
# NOTE: shuffles img_ids in place and unseeded, so re-running this cell
# produces a different train/validation split.
np.random.shuffle(img_ids)
n_train, n_valid = even_split(len(img_ids), batch_size, split)
# Slice from an explicit start index: img_ids[-n_valid:] would return the
# whole list when n_valid is 0.
partition = {'train': img_ids[:n_train],
             'valid': img_ids[len(img_ids) - n_valid:]}

# Generators
train_set = Dataset(partition['train'], train_768_dir, train_256_dir)
train_generator = data.DataLoader(train_set, **train_params)

valid_set = Dataset(partition['valid'], train_768_dir, train_256_dir, valid=True)
valid_generator = data.DataLoader(valid_set, **valid_params)
print(len(train_generator), len(valid_generator))

601 150


In [7]:
model = DupCNN(input_shape, output_size, conv_layers, fc_layers)
# Use the device selected earlier so the cell also runs without CUDA.
model.to(device)

loss, optimizer = create_loss_and_optimizer(model, learning_rate)
scheduler = StepLR(optimizer, step_size=10, gamma=0.05)

n_batches = len(train_generator)
n_val_batches = len(valid_generator)
best_loss = 9999.0

# Checkpoints are written below "out"; make sure it exists.
os.makedirs("out", exist_ok=True)

# Loop over epochs
for epoch in range(max_epochs):

    # Training
    total_train_loss = 0
    total_train_acc = 0
    model.train()
    t = tnrange(n_batches)
    train_iterator = iter(train_generator)
    for i in t:
        t.set_description(f'Epoch {epoch + 1:>3}')
        # Built-in next() works on every PyTorch version; the iterator's
        # .next() method is not available on newer releases.
        inputs, labels = next(train_iterator)
        # Transfer to GPU
        inputs, labels = inputs.to(device), labels.to(device)

        # Set the parameter gradients to zero
        optimizer.zero_grad()
        # Forward pass, backward pass, optimize
        outputs = model(inputs)
        train_loss = loss(outputs, labels)
        train_loss.backward()
        optimizer.step()
        # Running statistics (plain floats, so no tensors are kept alive)
        total_train_loss += train_loss.item()

        # Thresholding at 0.5 assumes the model emits probabilities --
        # TODO confirm against DupCNN / create_loss_and_optimizer.
        y_pred = (outputs > 0.5).type_as(labels)
        total_train_acc += (labels == y_pred).float().mean().item()

        loss_str = f'{total_train_loss/(i + 1):.6f}'
        acc_str = f'{total_train_acc/(i + 1):.5f}'
        t.set_postfix(loss=loss_str, acc=acc_str)

    # Validation
    total_val_loss = 0
    total_val_acc = 0
    # Switch layers such as dropout/batch-norm (if any) to inference mode;
    # model.train() at the top of the loop switches back.
    model.eval()
    t = tnrange(n_val_batches)
    valid_iterator = iter(valid_generator)
    with torch.no_grad():
        for i in t:
            t.set_description('Validation')
            inputs, labels = next(valid_iterator)
            # Transfer to GPU
            inputs, labels = inputs.to(device), labels.to(device)

            # Forward pass
            val_outputs = model(inputs)
            val_loss = loss(val_outputs, labels)
            total_val_loss += val_loss.item()

            y_pred = (val_outputs > 0.5).type_as(labels)
            total_val_acc += (labels == y_pred).float().mean().item()

            loss_str = f'{total_val_loss/(i + 1):.6f}'
            acc_str = f'{total_val_acc/(i + 1):.5f}'
            t.set_postfix(loss=loss_str, acc=acc_str)

    # Average over the actual number of validation batches instead of relying
    # on the loop index left over from the last iteration.
    mean_val_loss = total_val_loss / max(n_val_batches, 1)
    if mean_val_loss < best_loss:
        save_checkpoint(os.path.join("out", f"dup_model.{epoch + 1:03d}-{mean_val_loss:.6f}.pth"), model)
        # Written only on improvement, so "last" mirrors the best checkpoint
        # rather than the most recent epoch.
        save_checkpoint(os.path.join("out", "dup_model.last.pth"), model)
#             save_checkpoint(''.join(['out/checkpoint_', acc_str, '.pth']), model)
        best_loss = mean_val_loss

    # Advance the LR schedule once per epoch, after the optimizer has stepped.
    scheduler.step()

HBox(children=(IntProgress(value=0, max=601), HTML(value='')))

001bfb70a.jpg, 1, 3, (1, 0), (2, 0) correcting...



HBox(children=(IntProgress(value=0, max=150), HTML(value='')))




HBox(children=(IntProgress(value=0, max=601), HTML(value='')))




HBox(children=(IntProgress(value=0, max=150), HTML(value='')))




HBox(children=(IntProgress(value=0, max=601), HTML(value='')))

fc3bf6f21.jpg, 1, 3, (2, 0), (1, 0) correcting...



HBox(children=(IntProgress(value=0, max=150), HTML(value='')))




HBox(children=(IntProgress(value=0, max=601), HTML(value='')))




HBox(children=(IntProgress(value=0, max=150), HTML(value='')))




HBox(children=(IntProgress(value=0, max=601), HTML(value='')))

04bb76f99.jpg, 1, 3, (1, 0), (2, 0) correcting...



HBox(children=(IntProgress(value=0, max=150), HTML(value='')))




HBox(children=(IntProgress(value=0, max=601), HTML(value='')))




HBox(children=(IntProgress(value=0, max=150), HTML(value='')))




HBox(children=(IntProgress(value=0, max=601), HTML(value='')))




HBox(children=(IntProgress(value=0, max=150), HTML(value='')))




HBox(children=(IntProgress(value=0, max=601), HTML(value='')))




HBox(children=(IntProgress(value=0, max=150), HTML(value='')))




HBox(children=(IntProgress(value=0, max=601), HTML(value='')))




HBox(children=(IntProgress(value=0, max=150), HTML(value='')))




HBox(children=(IntProgress(value=0, max=601), HTML(value='')))




HBox(children=(IntProgress(value=0, max=150), HTML(value='')))




HBox(children=(IntProgress(value=0, max=601), HTML(value='')))

04bb76f99.jpg, 1, 3, (2, 0), (1, 0) correcting...



HBox(children=(IntProgress(value=0, max=150), HTML(value='')))




HBox(children=(IntProgress(value=0, max=601), HTML(value='')))




HBox(children=(IntProgress(value=0, max=150), HTML(value='')))




HBox(children=(IntProgress(value=0, max=601), HTML(value='')))




HBox(children=(IntProgress(value=0, max=150), HTML(value='')))




HBox(children=(IntProgress(value=0, max=601), HTML(value='')))




HBox(children=(IntProgress(value=0, max=150), HTML(value='')))




HBox(children=(IntProgress(value=0, max=601), HTML(value='')))




HBox(children=(IntProgress(value=0, max=150), HTML(value='')))




HBox(children=(IntProgress(value=0, max=601), HTML(value='')))




HBox(children=(IntProgress(value=0, max=150), HTML(value='')))




HBox(children=(IntProgress(value=0, max=601), HTML(value='')))




HBox(children=(IntProgress(value=0, max=150), HTML(value='')))




HBox(children=(IntProgress(value=0, max=601), HTML(value='')))




HBox(children=(IntProgress(value=0, max=150), HTML(value='')))




HBox(children=(IntProgress(value=0, max=601), HTML(value='')))




HBox(children=(IntProgress(value=0, max=150), HTML(value='')))




HBox(children=(IntProgress(value=0, max=601), HTML(value='')))




HBox(children=(IntProgress(value=0, max=150), HTML(value='')))




HBox(children=(IntProgress(value=0, max=601), HTML(value='')))




HBox(children=(IntProgress(value=0, max=150), HTML(value='')))




HBox(children=(IntProgress(value=0, max=601), HTML(value='')))




HBox(children=(IntProgress(value=0, max=150), HTML(value='')))




HBox(children=(IntProgress(value=0, max=601), HTML(value='')))




HBox(children=(IntProgress(value=0, max=150), HTML(value='')))




HBox(children=(IntProgress(value=0, max=601), HTML(value='')))




HBox(children=(IntProgress(value=0, max=150), HTML(value='')))




HBox(children=(IntProgress(value=0, max=601), HTML(value='')))




HBox(children=(IntProgress(value=0, max=150), HTML(value='')))




HBox(children=(IntProgress(value=0, max=601), HTML(value='')))

ce341a4be.jpg, 1, 3, (0, 0), (0, 1) correcting...



HBox(children=(IntProgress(value=0, max=150), HTML(value='')))




HBox(children=(IntProgress(value=0, max=601), HTML(value='')))




HBox(children=(IntProgress(value=0, max=150), HTML(value='')))




HBox(children=(IntProgress(value=0, max=601), HTML(value='')))




HBox(children=(IntProgress(value=0, max=150), HTML(value='')))




HBox(children=(IntProgress(value=0, max=601), HTML(value='')))




HBox(children=(IntProgress(value=0, max=150), HTML(value='')))




HBox(children=(IntProgress(value=0, max=601), HTML(value='')))




HBox(children=(IntProgress(value=0, max=150), HTML(value='')))




HBox(children=(IntProgress(value=0, max=601), HTML(value='')))




HBox(children=(IntProgress(value=0, max=150), HTML(value='')))




HBox(children=(IntProgress(value=0, max=601), HTML(value='')))

c1776befd.jpg, 1, 3, (0, 0), (1, 0) correcting...



HBox(children=(IntProgress(value=0, max=150), HTML(value='')))




HBox(children=(IntProgress(value=0, max=601), HTML(value='')))




HBox(children=(IntProgress(value=0, max=150), HTML(value='')))




HBox(children=(IntProgress(value=0, max=601), HTML(value='')))

fc3bf6f21.jpg, 1, 3, (2, 0), (0, 0) correcting...



HBox(children=(IntProgress(value=0, max=150), HTML(value='')))




HBox(children=(IntProgress(value=0, max=601), HTML(value='')))




HBox(children=(IntProgress(value=0, max=150), HTML(value='')))




HBox(children=(IntProgress(value=0, max=601), HTML(value='')))

66a5f04a7.jpg, 1, 3, (0, 0), (1, 0) correcting...



HBox(children=(IntProgress(value=0, max=150), HTML(value='')))




HBox(children=(IntProgress(value=0, max=601), HTML(value='')))

04bb76f99.jpg, 1, 3, (1, 0), (2, 0) correcting...



HBox(children=(IntProgress(value=0, max=150), HTML(value='')))




HBox(children=(IntProgress(value=0, max=601), HTML(value='')))




HBox(children=(IntProgress(value=0, max=150), HTML(value='')))




HBox(children=(IntProgress(value=0, max=601), HTML(value='')))




HBox(children=(IntProgress(value=0, max=150), HTML(value='')))




HBox(children=(IntProgress(value=0, max=601), HTML(value='')))




HBox(children=(IntProgress(value=0, max=150), HTML(value='')))




HBox(children=(IntProgress(value=0, max=601), HTML(value='')))




HBox(children=(IntProgress(value=0, max=150), HTML(value='')))




HBox(children=(IntProgress(value=0, max=601), HTML(value='')))




HBox(children=(IntProgress(value=0, max=150), HTML(value='')))




HBox(children=(IntProgress(value=0, max=601), HTML(value='')))




HBox(children=(IntProgress(value=0, max=150), HTML(value='')))




HBox(children=(IntProgress(value=0, max=601), HTML(value='')))

c1776befd.jpg, 1, 3, (2, 0), (0, 0) correcting...



HBox(children=(IntProgress(value=0, max=150), HTML(value='')))




HBox(children=(IntProgress(value=0, max=601), HTML(value='')))

001bfb70a.jpg, 1, 3, (1, 0), (2, 0) correcting...



HBox(children=(IntProgress(value=0, max=150), HTML(value='')))




HBox(children=(IntProgress(value=0, max=601), HTML(value='')))




HBox(children=(IntProgress(value=0, max=150), HTML(value='')))




HBox(children=(IntProgress(value=0, max=601), HTML(value='')))




HBox(children=(IntProgress(value=0, max=150), HTML(value='')))




HBox(children=(IntProgress(value=0, max=601), HTML(value='')))

ce341a4be.jpg, 1, 3, (0, 1), (0, 0) correcting...



HBox(children=(IntProgress(value=0, max=150), HTML(value='')))




HBox(children=(IntProgress(value=0, max=601), HTML(value='')))




HBox(children=(IntProgress(value=0, max=150), HTML(value='')))




HBox(children=(IntProgress(value=0, max=601), HTML(value='')))




HBox(children=(IntProgress(value=0, max=150), HTML(value='')))


