In [1]:
import os
import time
import h5py
from collections import Counter
from collections import OrderedDict
from collections import defaultdict
import numpy as np
import cv2
from tqdm import tnrange, tqdm_notebook

import torch
from torch.utils import data

from dupnet import create_loss_and_optimizer
from dupnet import even_split
from dupnet import save_checkpoint
from dupnet import DupCNN

%matplotlib inline
%reload_ext autoreload
%autoreload 2

ship_dir = ".data/input/"
train_768_dir = os.path.join('.data', 'train_768')
train_256_dir = os.path.join(ship_dir, 'train_256')

os.makedirs(train_256_dir, exist_ok=True)

img_ids = os.listdir(train_768_dir)
len(img_ids)

192555

In [4]:
def read_from_large(img, i, j):
    '''Cut tile (i, j) out of a BGR image on a 256-pixel grid.

    Returns the tile converted to RGB as float32, divided by 255
    (i.e. scaled to [0, 1] for a uint8 input).
    '''
    rows = slice(256 * i, 256 * (i + 1))
    cols = slice(256 * j, 256 * (j + 1))
    rgb = cv2.cvtColor(img[rows, cols, :], cv2.COLOR_BGR2RGB)
    return rgb.astype(np.float32) / 255.0

def read_from_small(img_id, i, j):
    '''Load the pre-cut tile (i, j) of ``img_id`` from ``train_256_dir``.

    The tile file is named '<base>_<i><j><ext>' where ``<base>``/``<ext>`` come
    from ``img_id`` (``os.path.splitext``, so names with extra dots work).

    Returns the tile converted to RGB as float32, divided by 255.
    Raises FileNotFoundError if the file cannot be read (``cv2.imread`` returns
    None instead of raising).
    '''
    filebase, fileext = os.path.splitext(img_id)
    tile_path = os.path.join(train_256_dir, f'{filebase}_{i}{j}{fileext}')
    tile = cv2.imread(tile_path)
    if tile is None:
        raise FileNotFoundError(f'could not read tile: {tile_path}')
    return cv2.cvtColor(tile, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0

def from_large(img_id, i, j, k, l):
    '''Benchmark path A: read the 768px image once and diff two of its tiles.

    Returns the signed sum of tile(i, j) - tile(k, l), both as RGB float32.
    '''
    full_image = cv2.imread(os.path.join(train_768_dir, img_id))
    first_tile, second_tile = (read_from_large(full_image, row, col)
                               for row, col in ((i, j), (k, l)))
    return np.sum(first_tile - second_tile)

def from_small(img_id, i, j, k, l):
    '''Benchmark path B: read two pre-cut 256px tile files and diff them.

    Returns the signed sum of tile(i, j) - tile(k, l).
    '''
    delta = read_from_small(img_id, i, j) - read_from_small(img_id, k, l)
    return np.sum(delta)

def from_both(img_id, i, j, k, l):
    '''Benchmark path C: tile (i, j) sliced from the large image, tile (k, l)
    read from its small file.

    Returns the signed sum of the difference of the two tiles.
    '''
    large_tile = read_from_large(cv2.imread(os.path.join(train_768_dir, img_id)), i, j)
    small_tile = read_from_small(img_id, k, l)
    return np.sum(large_tile - small_tile)

In [6]:
ij_pairs = [(0, 0), (0, 1), (0, 2),
            (1, 0), (1, 1), (1, 2),
            (2, 0), (2, 1), (2, 2)]

n_steps = 100

# n_steps random pairs of *distinct* grid positions, shared by all three
# benchmarks so each reader is timed on the same tile coordinates.
ijkl = []
for _ in range(n_steps):
    ij, kl = np.random.choice(len(ij_pairs), 2, replace=False)
    ijkl.append(ij_pairs[ij] + ij_pairs[kl])  # -> (i, j, k, l)


def time_reader(reader, coords, n_images):
    '''Return wall-clock seconds spent calling ``reader(img_id, i, j, k, l)``.

    A fresh random sample of ``n_images`` ids from ``img_ids`` is drawn for each
    call (presumably to limit file-cache effects between readers) and zipped
    with ``coords``. ``time.perf_counter`` is used because it is monotonic and
    high-resolution, unlike ``time.time``.
    '''
    img_id_list = np.random.choice(img_ids, n_images)
    start = time.perf_counter()
    for (i, j, k, l), img_id in zip(coords, img_id_list):
        reader(img_id, i, j, k, l)
    return time.perf_counter() - start


for reader in (from_large, from_small, from_both):
    print(time_reader(reader, ijkl, n_steps))

0.9052104949951172
1.184962511062622
1.9070210456848145


In [5]:
# Channel index -> channel letter, and the GIMP-style scale that a shift value
# ``val`` is divided by for that channel (see channel_shift).
idx_chan_map = {0: 'H', 1: 'L', 2: 'S'}
chan_gimp_scale_map = {'H': 360, 'L': 200, 'S': 100}

def to_hls(bgr):
    '''Convert a BGR image to OpenCV's full-range HLS (COLOR_BGR2HLS_FULL).'''
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2HLS_FULL)

def to_bgr(hls):
    '''Inverse of ``to_hls`` (COLOR_HLS2BGR_FULL).'''
    return cv2.cvtColor(hls, cv2.COLOR_HLS2BGR_FULL)

def channel_shift(hls, idx, val):
    '''
    Return a copy of ``hls`` with a single channel shifted; the input is not modified.

    hls values must be uint8. [0, 255]

    Parameters
    ----------
    hls : np.ndarray
        uint8 image of shape (H, W, 3), channel order H, L, S.
    idx : int
        Channel to shift: 0 = hue, 1 = lightness, 2 = saturation.
    val : number
        Shift amount; it is divided by ``chan_gimp_scale_map`` for the channel
        (360 for H, 200 for L, 100 for S).

    Hue is rotated (wrapping around 255 -> 0); lightness is pushed towards 1
    (val > 0) or towards 0 (val <= 0); saturation is multiplied by
    ``1 + val / 100``. L and S results are clipped to the valid range.

    Raises
    ------
    ValueError
        If ``idx`` is not 0, 1 or 2.
    '''
    if idx not in idx_chan_map:
        raise ValueError(f'idx must be one of {sorted(idx_chan_map)}, got {idx!r}')
    gimp_scale = chan_gimp_scale_map[idx_chan_map[idx]]
    scaled_hls = np.copy(hls)

    if idx == 0:  # hue

        # Hue is circular: wrap the integer shift into [0, 255] *before* the
        # uint8 conversion. Casting a negative float straight to uint8 is
        # undefined in NumPy (platform-dependent, warns on recent versions).
        shift = int(np.around(255. * val / gimp_scale)) % 256
        scaled_hls[:, :, idx] += np.uint8(shift)  # uint8 array addition wraps mod 256

    elif idx == 1:  # lightness

        lightness = hls[:, :, idx] * (1. / 255.)
        v2 = val / gimp_scale
        if val > 0:
            l_shifted = lightness * (1 - v2) + v2  # blend towards 1
        else:
            l_shifted = lightness * (1 + v2)       # scale towards 0
        l_shifted = np.clip(l_shifted, 0, 1)
        scaled_hls[:, :, idx] = np.around(255 * l_shifted).astype(np.uint8)

    else:  # saturation (idx == 2)

        s_shifted = hls[:, :, idx] * ((val / gimp_scale) + 1.)
        s_shifted = np.clip(s_shifted, 0, 255)
        scaled_hls[:, :, idx] = np.around(s_shifted).astype(np.uint8)

    return scaled_hls


class Dataset(data.Dataset):
    '''Characterizes a dataset for PyTorch: pairs of 256px tiles labelled 1 (same) or 0 (different).

    Every image ``img_id`` is treated as a 3x3 grid of 256x256 tiles. A tile can
    be obtained two ways: sliced out of the full image in ``train_768_dir``, or
    read from the pre-cut file ``'<base>_<i><j><ext>'`` in ``train_256_dir``.
    Each sample is

    * ``X``: float32 array of shape (6, 256, 256) -- the two RGB tiles (each
      divided by 255) stacked along the channel axis;
    * ``y``: float32 array of shape (1,) holding the label.

    With ``valid=True`` one pair per image is sampled once in ``__init__`` and
    cached, so every pass over the validation set sees identical samples.
    All randomness comes from the global ``np.random`` state.
    '''
    def __init__(self, img_ids, train_768_dir, train_256_dir, valid=False):
        'Initialization'
        self.img_ids = img_ids
        self.train_768_dir = train_768_dir
        self.train_256_dir = train_256_dir
        # (row, col) grid positions of the nine 256px tiles.
        self.ij = [(0, 0), (0, 1), (0, 2),
                   (1, 0), (1, 1), (1, 2),
                   (2, 0), (2, 1), (2, 2)]

        self.valid = valid
        if self.valid:
            # Sample one pair per image up front so validation is fixed.
            self.x_valid = np.empty((len(self.img_ids), 6, 256, 256), dtype=np.float32)
            self.y_valid = np.empty((len(self.img_ids), 1), dtype=np.float32)
            for n, img_id in enumerate(self.img_ids):
                self.x_valid[n], self.y_valid[n] = self.get_data_pair(img_id)

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.img_ids)

    def __getitem__(self, index):
        'Generates one sample of data'
        if self.valid:
            return self.x_valid[index], self.y_valid[index]
        return self.get_data_pair(self.img_ids[index])

    def color_shift(self, img):
        '''Return ``img`` (BGR uint8) with one random HLS channel shifted by +/-1..15.'''
        idx = np.random.choice(3)
        gain = np.random.choice(15) + 1
        gain = gain if np.random.random() > 0.5 else gain * -1
        hls = to_hls(img)
        hls_shifted = channel_shift(hls, idx, gain)
        return to_bgr(hls_shifted)

    @staticmethod
    def _tile_id(img_id, i, j):
        '''File name of pre-cut tile (i, j) of ``img_id``: '<base>_<i><j><ext>'.'''
        # splitext (not str.split('.')) so names containing extra dots still work.
        filebase, fileext = os.path.splitext(img_id)
        return f'{filebase}_{i}{j}{fileext}'

    @staticmethod
    def _imread(path):
        '''cv2.imread that raises FileNotFoundError instead of returning None.'''
        img = cv2.imread(path)
        if img is None:
            raise FileNotFoundError(f'could not read image: {path}')
        return img

    def slice_from_large(self, img, i, j):
        '''Return tile (i, j) of ``img`` on a 256-pixel grid (a view, not a copy).'''
        return img[i * 256:(i + 1) * 256, j * 256:(j + 1) * 256, :]

    def read_from_large(self, img_id, i, j):
        '''Read the full image ``img_id`` and return its tile (i, j) in BGR.'''
        img = self._imread(os.path.join(self.train_768_dir, img_id))
        return self.slice_from_large(img, i, j)

    def read_from_small(self, img_id, i, j):
        '''Read the pre-cut tile (i, j) of ``img_id`` in BGR.'''
        return self._imread(os.path.join(self.train_256_dir, self._tile_id(img_id, i, j)))

    def get_data_pair(self, img_id):
        '''Sample one (X, y) pair built from tiles of ``img_id``.

        Returns
        -------
        X : np.ndarray, float32, shape (6, 256, 256)
        y : np.ndarray, float32, shape (1,)
        '''
        # Notation: img_m[i,j] = tile (i, j) sliced from large image m,
        # tile_m_ij = pre-cut small tile file (i, j) of image m.
        # Columns: first tile, second tile, label, how alike they are.
        #
        # same img_id (img_id1 == img_id2), same tile (ij == kl)
            # img_m[i,j], img_m[i,j], 1, exact
            # img_m[i,j], tile_m_ij, 1, fuzzy
            # tile_m_ij, img_m[i,j], 1, fuzzy
            # tile_m_ij, tile_m_ij, 1, exact
        # same img_id (img_id1 == img_id2), diff tile (ij != kl)
            # img_m[i,j], img_m[k,l], 0, similar
            # img_m[i,j], tile_m_kl, 0, similar
            # tile_m_ij, img_m[k,l], 0, similar
            # tile_m_ij, tile_m_kl, 0, similar
        # diff img_id (img_id1 != img_id2), same tile (ij == kl)
            # img_m[i,j], img_n[i,j], 0, very different
            # img_m[i,j], tile_n_ij, 0, very different
            # tile_m_ij, img_n[i,j], 0, very different
            # tile_m_ij, tile_n_ij, 0, very different
        # diff img_id (img_id1 != img_id2), diff tile (ij != kl)
            # img_m[i,j], img_n[k,l], 0, very different
            # img_m[i,j], tile_n_kl, 0, very different
            # tile_m_ij, img_n[k,l], 0, very different
            # tile_m_ij, tile_n_kl, 0, very different
        #
        # NOTE: the sampler below draws only a subset of this catalogue. The
        # "exact" cases use a random colour shift instead of an identical copy,
        # the diff-tile case uses only the large/large variant, and the
        # diff-img_id group currently has probability 0.

        # Two distinct grid positions (replace=False guarantees ij != kl).
        ij, kl = np.random.choice(len(self.ij), 2, replace=False)
        i, j = self.ij[ij]
        k, l = self.ij[kl]

        idx0 = np.random.choice(3, p=[0.5, 0.5, 0.0])
        if idx0 == 0:  # same image, same tile -> positive pair

            idx1 = np.random.choice(4, p=[0.25, 0.25, 0.25, 0.25])
            if idx1 == 0:
                tile1 = self.read_from_large(img_id, i, j)
                tile2 = self.color_shift(tile1)
            elif idx1 == 1:
                tile1 = self.read_from_large(img_id, i, j)
                tile2 = self.read_from_small(img_id, i, j)
            elif idx1 == 2:
                tile1 = self.read_from_small(img_id, i, j)
                tile2 = self.read_from_large(img_id, i, j)
            elif idx1 == 3:
                tile1 = self.read_from_small(img_id, i, j)
                tile2 = self.color_shift(tile1)
            else:
                raise ValueError

            y = 1

        elif idx0 == 1:  # same image, different tile -> negative pair

            # Only the large/large variant is sampled for now (idx1 pinned to
            # 0); the other read combinations are kept for experimentation.
            idx1 = 0
            if idx1 == 0:
                # Read the large image once and slice both tiles from it.
                img = self._imread(os.path.join(self.train_768_dir, img_id))
                tile1 = self.slice_from_large(img, i, j)
                tile2 = self.slice_from_large(img, k, l)
            elif idx1 == 1:
                tile1 = self.read_from_large(img_id, i, j)
                tile2 = self.read_from_small(img_id, k, l)
            elif idx1 == 2:
                tile1 = self.read_from_small(img_id, i, j)
                tile2 = self.read_from_large(img_id, k, l)
            elif idx1 == 3:
                tile1 = self.read_from_small(img_id, i, j)
                tile2 = self.read_from_small(img_id, k, l)
            else:
                raise ValueError

            y = 0

        elif idx0 == 2:  # different image -> negative pair (currently p = 0)

            img_id2 = np.random.choice(self.img_ids)

            idx1 = np.random.choice(2)
            if idx1 == 0:
                tile1 = self.read_from_small(img_id, i, j)
                tile2 = self.read_from_small(img_id2, i, j)
            elif idx1 == 1:
                tile1 = self.read_from_small(img_id, i, j)
                tile2 = self.read_from_small(img_id2, k, l)
            else:
                raise ValueError

            y = 0

        else:
            raise ValueError

        # Pixel-identical tiles cannot be a negative pair (e.g. two uniform
        # regions, or img_id2 == img_id), so relabel them as duplicates.
        if np.all(tile1 == tile2) and y == 0:
            print(f'{img_id}, ({i}, {j}), ({k}, {l}) correcting...')
            y = 1

        tile1 = cv2.cvtColor(tile1, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.
        tile2 = cv2.cvtColor(tile2, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.

        # (256, 256, 6) -> channels-first (6, 256, 256) for PyTorch.
        X = np.dstack([tile1, tile2])
        X = X.transpose((2, 0, 1))
        y = np.array([y], dtype=np.float32)
        return X, y

In [6]:
# CUDA for PyTorch
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
torch.backends.cudnn.benchmark = True

input_shape = (6, 256, 256)
conv_layers = (16, 32, 64, 128, 256)
fc_layers = (128,)
output_size = 1

# Parameters
split = 97
batch_size = 128
max_epochs = 40
num_workers = 12
learning_rate = 0.00001
random_seed = 42  # fixes the train/valid split and the cached validation pairs


def seed_numpy_worker(worker_id):
    '''DataLoader ``worker_init_fn``: give each worker its own NumPy RNG stream.

    Dataset.get_data_pair draws all of its randomness from ``np.random``. Older
    PyTorch releases do not reseed NumPy in worker processes, so forked workers
    would inherit identical NumPy state and emit the same "random" pairs.
    Inside a worker ``torch.initial_seed()`` already differs per worker and per
    epoch, so it is used to derive the NumPy seed.
    '''
    np.random.seed(torch.initial_seed() % 2**32)


train_params = {'batch_size': batch_size,
                'shuffle': True,
                'num_workers': num_workers,
                'worker_init_fn': seed_numpy_worker}
valid_params = {'batch_size': batch_size,
                'shuffle': False,
                'num_workers': num_workers}

# Datasets -- permute a copy (instead of shuffling img_ids in place) so that
# re-running this cell reproduces the same split.
np.random.seed(random_seed)
torch.manual_seed(random_seed)
shuffled_ids = [img_ids[k] for k in np.random.permutation(len(img_ids))]
n_train, n_valid = even_split(len(shuffled_ids), batch_size, split)
assert n_train + n_valid <= len(shuffled_ids), 'train and valid splits overlap'
# Slice the tail by explicit start index: shuffled_ids[-0:] would be the whole list.
partition = {'train': shuffled_ids[:n_train],
             'valid': shuffled_ids[len(shuffled_ids) - n_valid:]}

# Generators
train_set = Dataset(partition['train'], train_768_dir, train_256_dir)
train_generator = data.DataLoader(train_set, **train_params)

valid_set = Dataset(partition['valid'], train_768_dir, train_256_dir, valid=True)
valid_generator = data.DataLoader(valid_set, **valid_params)
print(len(train_generator), len(valid_generator))

1459 45


In [None]:
# Pull a single training batch to sanity-check tensor shapes.
inputs, labels = next(iter(train_generator))
tuple(batch.shape for batch in (inputs, labels))

In [13]:
# Same shape check on the (pre-computed) validation loader.
inputs, labels = next(iter(valid_generator))
tuple(batch.shape for batch in (inputs, labels))

(torch.Size([128, 6, 256, 256]), torch.Size([128, 1]))

In [7]:
model = DupCNN(input_shape, output_size, conv_layers, fc_layers)
# Follow the device picked in the config cell instead of assuming CUDA.
model.to(device)

loss, optimizer = create_loss_and_optimizer(model, learning_rate)
n_batches = len(train_generator)
best_score = 0
os.makedirs('out', exist_ok=True)  # checkpoints below are written to out/


def batch_accuracy(outputs, labels, threshold=0.5):
    '''Fraction (Python float) of predictions ``outputs > threshold`` equal to ``labels``.

    NOTE(review): thresholding at 0.5 assumes DupCNN emits probabilities
    (sigmoid output) rather than logits -- confirm.
    '''
    y_pred = (outputs > threshold).type_as(labels)
    return (y_pred == labels).float().mean().item()


# Loop over epochs
for epoch in range(max_epochs):

    desc = f'Epoch {epoch + 1:>3}'

    # ---- Training ----
    model.train()
    total_train_loss = 0.0
    total_train_acc = 0.0
    t = tqdm_notebook(train_generator, desc=desc)
    for i, (inputs, labels) in enumerate(t):
        # Transfer to GPU (or CPU when CUDA is unavailable)
        inputs, labels = inputs.to(device), labels.to(device)

        # Set the parameter gradients to zero
        optimizer.zero_grad()

        # Forward pass, backward pass, optimize
        outputs = model(inputs)
        loss_size = loss(outputs, labels)
        loss_size.backward()
        optimizer.step()

        # Running averages over the batches seen so far this epoch
        total_train_loss += loss_size.item()
        total_train_acc += batch_accuracy(outputs, labels)
        t.set_postfix(loss=f'{total_train_loss / (i + 1):.6f}',
                      acc=f'{total_train_acc / (i + 1):.5f}')

    # ---- Validation ----
    # eval() puts dropout / batch-norm layers (if DupCNN has any) into inference
    # mode; otherwise validation would run in training mode and update the
    # batch-norm running statistics with validation data.
    model.eval()
    total_val_loss = 0.0
    total_val_acc = 0.0
    n_val_batches = 0
    with torch.no_grad():
        t = tqdm_notebook(valid_generator, desc=desc)
        for i, (inputs, labels) in enumerate(t):
            inputs, labels = inputs.to(device), labels.to(device)

            # Forward pass only
            val_outputs = model(inputs)
            val_loss_size = loss(val_outputs, labels)

            n_val_batches = i + 1
            total_val_loss += val_loss_size.item()
            total_val_acc += batch_accuracy(val_outputs, labels)
            t.set_postfix(loss=f'{total_val_loss / n_val_batches:.6f}',
                          acc=f'{total_val_acc / n_val_batches:.5f}')

    # Save a checkpoint whenever mean validation accuracy improves.
    val_acc = total_val_acc / max(n_val_batches, 1)
    if val_acc > best_score:
        save_checkpoint(f'out/checkpoint_{val_acc:.5f}.pth', model)
        best_score = val_acc

HBox(children=(IntProgress(value=0, max=1474), HTML(value='')))






HBox(children=(IntProgress(value=0, max=30), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1474), HTML(value='')))






HBox(children=(IntProgress(value=0, max=30), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1474), HTML(value='')))






HBox(children=(IntProgress(value=0, max=30), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1474), HTML(value='')))






HBox(children=(IntProgress(value=0, max=30), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1474), HTML(value='')))






HBox(children=(IntProgress(value=0, max=30), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1474), HTML(value='')))






HBox(children=(IntProgress(value=0, max=30), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1474), HTML(value='')))






HBox(children=(IntProgress(value=0, max=30), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1474), HTML(value='')))






HBox(children=(IntProgress(value=0, max=30), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1474), HTML(value='')))






HBox(children=(IntProgress(value=0, max=30), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1474), HTML(value='')))






HBox(children=(IntProgress(value=0, max=30), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1474), HTML(value='')))






HBox(children=(IntProgress(value=0, max=30), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1474), HTML(value='')))






HBox(children=(IntProgress(value=0, max=30), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1474), HTML(value='')))






HBox(children=(IntProgress(value=0, max=30), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1474), HTML(value='')))






HBox(children=(IntProgress(value=0, max=30), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1474), HTML(value='')))






HBox(children=(IntProgress(value=0, max=30), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1474), HTML(value='')))






HBox(children=(IntProgress(value=0, max=30), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1474), HTML(value='')))






HBox(children=(IntProgress(value=0, max=30), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1474), HTML(value='')))






HBox(children=(IntProgress(value=0, max=30), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1474), HTML(value='')))






HBox(children=(IntProgress(value=0, max=30), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1474), HTML(value='')))






HBox(children=(IntProgress(value=0, max=30), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1474), HTML(value='')))






HBox(children=(IntProgress(value=0, max=30), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1474), HTML(value='')))






HBox(children=(IntProgress(value=0, max=30), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1474), HTML(value='')))






HBox(children=(IntProgress(value=0, max=30), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1474), HTML(value='')))






HBox(children=(IntProgress(value=0, max=30), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1474), HTML(value='')))






HBox(children=(IntProgress(value=0, max=30), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1474), HTML(value='')))






HBox(children=(IntProgress(value=0, max=30), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1474), HTML(value='')))






HBox(children=(IntProgress(value=0, max=30), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1474), HTML(value='')))






HBox(children=(IntProgress(value=0, max=30), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1474), HTML(value='')))






HBox(children=(IntProgress(value=0, max=30), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1474), HTML(value='')))






HBox(children=(IntProgress(value=0, max=30), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1474), HTML(value='')))






HBox(children=(IntProgress(value=0, max=30), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1474), HTML(value='')))






HBox(children=(IntProgress(value=0, max=30), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1474), HTML(value='')))






HBox(children=(IntProgress(value=0, max=30), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1474), HTML(value='')))






HBox(children=(IntProgress(value=0, max=30), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1474), HTML(value='')))






HBox(children=(IntProgress(value=0, max=30), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1474), HTML(value='')))






HBox(children=(IntProgress(value=0, max=30), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1474), HTML(value='')))






HBox(children=(IntProgress(value=0, max=30), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1474), HTML(value='')))






HBox(children=(IntProgress(value=0, max=30), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1474), HTML(value='')))






HBox(children=(IntProgress(value=0, max=30), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1474), HTML(value='')))






HBox(children=(IntProgress(value=0, max=30), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1474), HTML(value='')))






HBox(children=(IntProgress(value=0, max=30), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1474), HTML(value='')))






HBox(children=(IntProgress(value=0, max=30), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1474), HTML(value='')))






HBox(children=(IntProgress(value=0, max=30), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1474), HTML(value='')))






HBox(children=(IntProgress(value=0, max=30), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1474), HTML(value='')))






HBox(children=(IntProgress(value=0, max=30), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1474), HTML(value='')))






HBox(children=(IntProgress(value=0, max=30), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1474), HTML(value='')))






HBox(children=(IntProgress(value=0, max=30), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1474), HTML(value='')))






HBox(children=(IntProgress(value=0, max=30), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1474), HTML(value='')))






HBox(children=(IntProgress(value=0, max=30), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1474), HTML(value='')))






HBox(children=(IntProgress(value=0, max=30), HTML(value='')))


