In [None]:
import os
import time
import random
from collections import namedtuple

from tqdm import tnrange
import numpy as np
import pandas as pd
import cv2
import matplotlib.pyplot as plt
from parse import parse

import torch
from torch.utils import data
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchvision import transforms

from sdcdup.utils import get_datetime_now
from sdcdup.utils import channel_shift
from sdcdup.utils import to_hls
from sdcdup.utils import to_bgr
from sdcdup.utils import create_dataset_from_tiles
from sdcdup.utils import create_dataset_from_tiles_and_truth
from sdcdup.utils import even_split
from sdcdup.utils import CSVLogger

from sdcdup.features.image_features import SDCImageContainer
# from datasets import create_dataset_from_truth
from sdcdup.models.dupnet import save_checkpoint
from sdcdup.models.dupnet import DupCNN

%matplotlib inline
%reload_ext autoreload
%autoreload 2

# SENDTOENV
# Image directories; image ids are appended to these with os.path.join.
train_768_dir = 'data/raw/train_768/'
train_256_dir = 'data/processed/train_256/'

# Pickled per-image grids handed to SDCImageContainer.preprocess_image_properties.
image_md5hash_grids_file = 'data/interim/image_md5hash_grids.pkl'
image_bm0hash_grids_file = 'data/interim/image_bm0hash_grids.pkl'
image_cm0hash_grids_file = 'data/interim/image_cm0hash_grids.pkl'
image_greycop_grids_file = 'data/interim/image_greycop_grids.pkl'
image_entropy_grids_file = 'data/interim/image_entropy_grids.pkl'
image_issolid_grids_file = 'data/interim/image_issolid_grids.pkl'

# Duplicate-truth file path (a single-argument os.path.join returns its argument unchanged).
duplicate_truth_file = 'data/processed/duplicate_truth.txt'

In [None]:
# SENDTOMODULE
class RandomHorizontalFlip:
    """Mirror a numpy image left-to-right with probability ``p``.

    Args:
        p (float): chance that a call flips its input. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img):
        """Flip ``img`` around its vertical axis, or hand it back untouched.

        Args:
            img (numpy.ndarray): image to (maybe) flip.

        Returns:
            numpy.ndarray: ``cv2.flip(img, 1)`` when the random draw is below
            ``p``; otherwise the very same object that was passed in.
        """
        should_flip = np.random.random() < self.p
        return cv2.flip(img, 1) if should_flip else img

    def __repr__(self):
        return f'{self.__class__.__name__}(p={self.p})'


class RandomTransformC4:
    """Rotate a tensor by a random multiple of 90 degrees in the plane of dims 1 and 2.

    Args:
        with_identity (bool): when True, a rotation of 0 degrees (no change) is
            one of the equally likely choices; when False only 90, 180 and 270
            degree rotations are drawn.
    """

    def __init__(self, with_identity=True):
        self.with_identity = with_identity
        # Number of quarter turns that may be drawn by __call__.
        first_turn = 0 if with_identity else 1
        self.n90s = tuple(range(first_turn, 4))

    def __call__(self, img):
        """Apply one randomly chosen quarter-turn rotation.

        Args:
            img (torch.Tensor): tensor whose dims 1 and 2 span the rotation plane.

        Returns:
            torch.Tensor: ``img`` rotated by ``k * 90`` degrees, ``k`` drawn from ``self.n90s``.
        """
        quarter_turns = random.choice(self.n90s)
        return torch.rot90(img, quarter_turns, (1, 2))

    def __repr__(self):
        return f'{self.__class__.__name__}(with_identity={self.with_identity})'


In [None]:
# Maps the integer drawn in Dataset.get_random_augmentation to an HLS channel name.
idx_chan_map = dict(enumerate('HLS'))

# SENDTOMODULE
# ImgAugs = namedtuple('ImgAugs', 'idx3 chan gain')
# One augmentation recipe per sample. The field order is the order in which
# Dataset.get_data_pair unpacks the tuple.
ImgAugs = namedtuple('ImgAugs', [
    'flip_img_order',
    'first_from_large',
    'second_from_large',
    'second_augment_hls',
    'hls_chan',
    'hls_gain',
    'flip_stacking_order',
])
class Dataset(data.Dataset):
    """Serve pairs of image tiles, stacked channel-wise, with a duplicate label.

    Every entry of ``img_overlaps`` is a 5-tuple
    ``(img1_id, img2_id, idx1, idx2, is_dup)``. ``idx1`` and ``idx2`` index
    ``self.ij``, the row-major (row, col) cells of a 3x3 tile grid. A tile is
    either sliced out of the full image found under ``train_768_dir`` or read
    as a pre-cut file from ``train_256_dir``.

    Args:
        img_overlaps (sequence): overlap records as described above.
        train_or_valid (str): ``'valid'`` freezes one random augmentation per
            record at construction time; any other value draws a fresh
            augmentation on every ``__getitem__`` call.
        image_transform (callable): applied to the stacked float32 array
            before it is returned.
        in_shape (tuple): stored on the instance; not otherwise used here.
        out_shape (tuple): stored on the instance; not otherwise used here.
    """
    def __init__(self, img_overlaps, train_or_valid, image_transform,
                 in_shape=(6, 256, 256),
                 out_shape=(1,)):

        """Initialization"""
        self.img_overlaps = img_overlaps
        # TODO: handle case if train_or_valid == 'test'
        self.valid = train_or_valid == 'valid'
        self.image_transform = image_transform
        # Row-major (row, col) position of each of the nine tiles.
        self.ij = ((0, 0), (0, 1), (0, 2),
                   (1, 0), (1, 1), (1, 2),
                   (2, 0), (2, 1), (2, 2))

        self.in_shape = in_shape
        self.out_shape = out_shape
        # Largest absolute shift that may be drawn for each HLS channel.
        self.hls_limits = {'H': 10, 'L': 20, 'S': 20}
        if self.valid:
            # Validation samples must look the same every epoch.
            self.img_augs = [self.get_random_augmentation() for _ in self.img_overlaps]

    def __len__(self):
        """Denotes the total number of samples"""
        return len(self.img_overlaps)

    def __getitem__(self, index):
        """Generates one sample of data"""
        if self.valid:
            img_aug = self.img_augs[index]
        else:
            img_aug = self.get_random_augmentation()
        return self.get_data_pair(self.img_overlaps[index], img_aug)  # X, y

    def get_random_augmentation(self):
        """Draw one random ``ImgAugs`` recipe.

        Returns:
            ImgAugs: five independent coin flips plus an HLS channel and a
            signed, non-zero gain within ``self.hls_limits`` for that channel.
        """

        # So, we aren't always biasing the second image with hls shifting...
        flip_img_order = np.random.random() > 0.5
        # The first tile will always come from either a slice of the image or from the saved slice.
        first_from_large = np.random.random() > 0.5
        second_from_large = np.random.random() > 0.5
        second_augment_hls = np.random.random() > 0.25
        flip_stacking_order = np.random.random() > 0.5

        hls_idx = np.random.choice(3)
        hls_chan = idx_chan_map[hls_idx]
        # np.random.choice(n) yields 0..n-1, so the magnitude is 1..limit.
        hls_gain = np.random.choice(self.hls_limits[hls_chan]) + 1
        hls_gain = hls_gain if np.random.random() > 0.5 else -1 * hls_gain

        return ImgAugs(flip_img_order, first_from_large, second_from_large, second_augment_hls, hls_chan, hls_gain, flip_stacking_order)

    def color_shift(self, img, chan, gain):
        """Return ``img`` after shifting HLS channel ``chan`` by ``gain`` (BGR in, BGR out)."""
        hls = to_hls(img)
        hls_shifted = channel_shift(hls, chan, gain)
        return to_bgr(hls_shifted)

    def get_tile(self, img, idx, sz=256):
        """Slice the ``sz`` x ``sz`` tile at grid cell ``self.ij[idx]`` out of ``img``."""
        i, j = self.ij[idx]
        return img[i * sz:(i + 1) * sz, j * sz:(j + 1) * sz, :]

    def _imread(self, path):
        """Read an image with OpenCV, failing loudly when it cannot be loaded.

        Raises:
            FileNotFoundError: if ``cv2.imread`` returns ``None`` for ``path``.
        """
        img = cv2.imread(path)
        if img is None:
            # cv2.imread reports a missing/unreadable file by returning None,
            # which would otherwise surface later as an unrelated error.
            raise FileNotFoundError(f'Could not read image: {path}')
        return img

    def read_from_large(self, img_id, idx):
        """Read the full image ``img_id`` from ``train_768_dir`` and return tile ``idx``."""
        img = self._imread(os.path.join(train_768_dir, img_id))
        return self.get_tile(img, idx)

    def read_from_small(self, img_id, idx):
        """Read the pre-cut tile ``<base>_<idx><ext>`` of ``img_id`` from ``train_256_dir``."""
        # splitext only splits at the last dot, so ids with extra dots still work.
        filebase, fileext = os.path.splitext(img_id)
        tile_id = f'{filebase}_{idx}{fileext}'
        return self._imread(os.path.join(train_256_dir, tile_id))

    def get_data_pair(self, img_overlap, img_aug):
        """Build one ``(X, y)`` sample from an overlap record and an augmentation.

        Args:
            img_overlap (tuple): ``(img1_id, img2_id, idx1, idx2, is_dup)``.
            img_aug (ImgAugs): augmentation recipe to apply.

        Returns:
            tuple: ``X`` is ``image_transform`` applied to the two RGB tiles
            (float32, scaled by 1/255) stacked along the channel axis; ``y`` is
            a float32 array holding ``is_dup``.
        """

        # diff img_id (img1_id != img2_id), random tile from overlap, where is_dup == 1 (from duplicate_truth.txt)
            # img1_[i,j], img2_[k,l], 1, exact or fuzzy
            # img1_[i,j], tile2_kl, 1, exact or fuzzy
            # tile1_ij, img2_[k,l], 1, exact or fuzzy
            # tile1_ij, tile2_kl, 1, exact or fuzzy

        # same img_id (img1_id == img2_id), same tile (ij == kl)
            # img1_[i,j], img1_[i,j], 1, exact
            # img1_[i,j], tile1_ij, 1, fuzzy
            # tile1_ij, img1_[i,j], 1, fuzzy
            # tile1_ij, tile1_ij, 1, exact

        # same img_id (img1_id == img2_id), diff tile (ij != kl)
            # img1_[i,j], img1_[k,l], 0, similar but different
            # img1_[i,j], tile1_kl, 0, similar but different
            # tile1_ij, img1_[k,l], 0, similar but different
            # tile1_ij, tile1_kl, 0, similar but different

        # diff img_id (img1_id != img2_id), same tile (ij == kl)
            # img1_[i,j], img2_[i,j], 0, very different
            # img1_[i,j], tile2_ij, 0, very different
            # tile1_ij, img2_[i,j], 0, very different
            # tile1_ij, tile2_ij, 0, very different

        # diff img_id (img1_id != img2_id), diff tile (ij != kl)
            # img1_[i,j], img2_[k,l], 0, very different
            # img1_[i,j], tile2_kl, 0, very different
            # tile1_ij, img2_[k,l], 0, very different
            # tile1_ij, tile2_kl, 0, very different

        # use image_md5hash_grids.pkl for equal image id pairs (img1_id == img2_id)
        #--------------------------------------------------------------------
        # ij == kl? | tile1? | tile2? | shift? | is_dup?
        #--------------------------------------------------------------------
        #   yes     |  768   |  768   |   yes  |    yes agro color shift
        #   yes     |  768   |  768   |    no  |    yes
        #   yes     |  768   |  256   |    no  |    yes
        #   yes     |  256   |  768   |    no  |    yes
        #   yes     |  256   |  256   |   yes  |    yes agro color shift
        #   yes     |  256   |  256   |    no  |    yes
        #    no     |  768   |  768   |   yes  |     no
        #    no     |  768   |  768   |    no  |     no
        #    no     |  256   |  256   |   yes  |     no
        #    no     |  256   |  256   |    no  |     no

        # use duplicate_truth.txt for unequal image id pairs (img1_id != img2_id)
        # NOTE: Be sure to use the overlap_map when comparing ij and kl
        #--------------------------------------------------------------------
        # ij == kl? | tile1? | tile2? | shift? | is_dup?
        #--------------------------------------------------------------------
        #   yes     |  768   |  768   |   yes  |    yes small color shift
        #   yes     |  768   |  768   |    no  |    yes
        #   yes     |  768   |  256   |    no  |    yes
        #   yes     |  256   |  768   |    no  |    yes
        #   yes     |  256   |  256   |   yes  |    yes
        #   yes     |  256   |  256   |    no  |    yes
        #    no     |  768   |  768   |   yes  |     no
        #    no     |  768   |  768   |    no  |     no
        #    no     |  256   |  256   |   yes  |     no
        #    no     |  256   |  256   |    no  |     no

        flip_img_order, first_from_large, second_from_large, aug_hls, chan, gain, flip_stacking_order = img_aug
        if flip_img_order:
            img2_id, img1_id, idx2, idx1, is_dup = img_overlap
        else:
            img1_id, img2_id, idx1, idx2, is_dup = img_overlap

        read1 = self.read_from_large if first_from_large else self.read_from_small
        read2 = self.read_from_large if second_from_large else self.read_from_small
        same_image = img1_id == img2_id

        if same_image and is_dup:  # img1_id == img2_id and idx1 == idx2
            tile1 = read1(img1_id, idx1)
            if aug_hls:
                tile2 = self.color_shift(tile1, chan, gain)
            else:
                tile2 = read2(img2_id, idx2)
        elif same_image and first_from_large and second_from_large:
            # Both tiles come from the same full image: decode it only once.
            img = self._imread(os.path.join(train_768_dir, img1_id))
            tile1 = self.get_tile(img, idx1)
            tile2 = self.get_tile(img, idx2)
        else:  # different images, or same image with at least one pre-cut tile
            tile1 = read1(img1_id, idx1)
            tile2 = read2(img2_id, idx2)

#         if is_dup == 0 and sdcic.tile_md5hash_grids[img1_id][idx1] == sdcic.tile_md5hash_grids[img2_id][idx2]:
#             i, j = self.ij[idx1]
#             k, l = self.ij[idx2]
#             print(f'algo={idx3}; {img1_id} {idx1} -> ({i},{j}); {img2_id}, {idx2} -> ({k},{l}); correcting... {is_dup} -> 1')
#             is_dup = 1

        tile1 = cv2.cvtColor(tile1, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.
        tile2 = cv2.cvtColor(tile2, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.

        X = np.dstack([tile2, tile1]) if flip_stacking_order else np.dstack([tile1, tile2])
        X = self.image_transform(X)
        y = np.array([is_dup], dtype=np.float32)
        return X, y

In [None]:
# Maps the integer drawn in Dataset.get_random_augmentation to an HLS channel name.
idx_chan_map = dict(enumerate('HLS'))

# SENDTOMODULE
# Augmentation recipe: tile-source selector (idx3) plus an HLS channel and gain.
ImgAugs = namedtuple('ImgAugs', ['idx3', 'chan', 'gain'])
class Dataset(data.Dataset):
    """Serve pairs of image tiles, stacked channel-wise, with a duplicate label.

    Every entry of ``img_overlaps`` is a 5-tuple
    ``(img1_id, img2_id, idx1, idx2, is_dup)`` in which the image ids are file
    paths containing an ``/images_768/`` directory and a file name of the form
    ``rNNN_cNNN.jpg``. ``idx1``/``idx2`` index ``self.ij``, the row-major
    (row, col) cells of a 3x3 tile grid.

    NOTE(review): an earlier cell of this notebook defines another ``Dataset``
    with different path handling. On a top-to-bottom run this later definition
    replaces it - confirm which one is intended for the data being loaded.

    Args:
        img_overlaps (sequence): overlap records as described above.
        train_or_valid (str): ``'valid'`` freezes the random augmentation and
            the channel stacking order of every record at construction time;
            any other value draws fresh ones on every ``__getitem__`` call.
        image_transform (callable): applied to the stacked float32 array
            before it is returned.
        in_shape (tuple): stored on the instance; not otherwise used here.
        out_shape (tuple): stored on the instance; not otherwise used here.
    """
    def __init__(self, img_overlaps, train_or_valid, image_transform,
                 in_shape=(6, 256, 256),
                 out_shape=(1,)):

        """Initialization"""
        self.img_overlaps = img_overlaps
        # TODO: handle case if train_or_valid == 'test'
        self.valid = train_or_valid == 'valid'
        self.image_transform = image_transform
        # Row-major (row, col) position of each of the nine tiles.
        self.ij = ((0, 0), (0, 1), (0, 2),
                   (1, 0), (1, 1), (1, 2),
                   (2, 0), (2, 1), (2, 2))

        self.in_shape = in_shape
        self.out_shape = out_shape
        # Largest absolute shift that may be drawn for each HLS channel.
        self.hls_limits = {'H': 5, 'L': 10, 'S': 10}
        if self.valid:
            # Validation samples must look the same every epoch, so both the
            # augmentation and the stacking order are drawn once, up front.
            self.img_augs = [self.get_random_augmentation() for _ in self.img_overlaps]
            self.stack_flips = [np.random.random() >= 0.5 for _ in self.img_overlaps]

    def __len__(self):
        """Denotes the total number of samples"""
        return len(self.img_overlaps)

    def __getitem__(self, index):
        """Generates one sample of data"""
        if self.valid:
            img_aug = self.img_augs[index]
            flip_stacking_order = self.stack_flips[index]
        else:
            img_aug = self.get_random_augmentation()
            flip_stacking_order = None  # drawn at random inside get_data_pair
        X, y = self.get_data_pair(self.img_overlaps[index], img_aug, flip_stacking_order)
        return X, y

    def get_random_augmentation(self):
        """Draw one random ``ImgAugs`` recipe.

        Returns:
            ImgAugs: ``idx3`` in 0..3 drawn with probabilities
            (0.3, 0.2, 0.2, 0.3), an HLS channel, and a signed, non-zero gain
            within ``self.hls_limits`` for that channel.
        """

        p = [0.3, 0.2, 0.2, 0.3]
        idx3 = np.random.choice(4, p=p)

        hls_idx = np.random.choice(3)
        hls_chan = idx_chan_map[hls_idx]
        # np.random.choice(n) yields 0..n-1, so the magnitude is 1..limit.
        hls_gain = np.random.choice(self.hls_limits[hls_chan]) + 1
        hls_gain = hls_gain if np.random.random() > 0.5 else -1 * hls_gain

        return ImgAugs(idx3, hls_chan, hls_gain)

    def color_shift(self, img, chan, gain):
        """Return ``img`` after shifting HLS channel ``chan`` by ``gain`` (BGR in, BGR out)."""
        hls = to_hls(img)
        hls_shifted = channel_shift(hls, chan, gain)
        return to_bgr(hls_shifted)

    def get_tile(self, img, idx, sz=256):
        """Slice the ``sz`` x ``sz`` tile at grid cell ``self.ij[idx]`` out of ``img``."""
        i, j = self.ij[idx]
        return img[i * sz:(i + 1) * sz, j * sz:(j + 1) * sz, :]

    def _imread(self, path):
        """Read an image with OpenCV, failing loudly when it cannot be loaded.

        Raises:
            FileNotFoundError: if ``cv2.imread`` returns ``None`` for ``path``.
        """
        img = cv2.imread(path)
        if img is None:
            # cv2.imread reports a missing/unreadable file by returning None,
            # which would otherwise surface later as an unrelated error.
            raise FileNotFoundError(f'Could not read image: {path}')
        return img

    def read_from_large(self, img_id, idx):
        """Read the full image at path ``img_id`` and return tile ``idx`` of it."""
        img = self._imread(img_id)
        return self.get_tile(img, idx)

    def read_from_small(self, img_id, idx):
        """Read the pre-cut tile that corresponds to cell ``idx`` of image ``img_id``.

        The tile lives in the sibling ``images_256`` directory and is named
        after the image's row/column numbers offset by the cell position.

        Raises:
            ValueError: if the file name does not match ``rNNN_cNNN.jpg``.
        """
        # Split on the last occurrence only, so the parent path may contain anything.
        dup_truth_path, img_filename = img_id.rsplit('/images_768/', 1)
        parsed = parse('r{:3d}_c{:3d}.jpg', img_filename)
        if parsed is None:
            raise ValueError(f'Unexpected image file name: {img_filename}')
        row, col = parsed
        i, j = self.ij[idx]
        tile_id = os.path.join(dup_truth_path, 'images_256', f'r{row + i:03d}_c{col + j:03d}.jpg')
        return self._imread(tile_id)

    def get_data_pair(self, img_overlap, img_aug, flip_stacking_order=None):
        """Build one ``(X, y)`` sample from an overlap record and an augmentation.

        ``idx3`` selects where the two tiles come from:

        * 0 - both from full images (one decode when it is the same image); for
          a same-image duplicate the second tile is a colour-shifted copy of
          the first.
        * 1 - first from the full image, second from the pre-cut tile.
        * 2 - first from the pre-cut tile, second from the full image.
        * 3 - first from the pre-cut tile; the second is a colour-shifted copy
          of it for duplicates, otherwise the other pre-cut tile.

        Args:
            img_overlap (tuple): ``(img1_id, img2_id, idx1, idx2, is_dup)``.
            img_aug (ImgAugs): ``(idx3, chan, gain)``.
            flip_stacking_order (bool, optional): True stacks ``tile2`` before
                ``tile1``. ``None`` (the default) draws the order at random.

        Returns:
            tuple: ``X`` is ``image_transform`` applied to the two RGB tiles
            (float32, scaled by 1/255) stacked along the channel axis; ``y`` is
            a float32 array holding ``is_dup``.

        Raises:
            ValueError: if ``idx3`` is not one of 0, 1, 2, 3.
        """

        img1_id, img2_id, idx1, idx2, is_dup = img_overlap
        idx3, chan, gain = img_aug
        same_image = img1_id == img2_id

        if idx3 == 0:
            if same_image and is_dup:
                tile1 = self.read_from_large(img1_id, idx1)
                tile2 = self.color_shift(tile1, chan, gain)
            elif same_image:
                # Two different tiles of one image: decode the file only once.
                img = self._imread(img1_id)
                tile1 = self.get_tile(img, idx1)
                tile2 = self.get_tile(img, idx2)
            else:
                tile1 = self.read_from_large(img1_id, idx1)
                tile2 = self.read_from_large(img2_id, idx2)
        elif idx3 == 1:
            tile1 = self.read_from_large(img1_id, idx1)
            tile2 = self.read_from_small(img2_id, idx2)
        elif idx3 == 2:
            tile1 = self.read_from_small(img1_id, idx1)
            tile2 = self.read_from_large(img2_id, idx2)
        elif idx3 == 3:
            tile1 = self.read_from_small(img1_id, idx1)
            if is_dup:
                # These end up being the same tile.
                tile2 = self.color_shift(tile1, chan, gain)
            else:
                tile2 = self.read_from_small(img2_id, idx2)
        else:
            raise ValueError(f'idx3 must be 0, 1, 2 or 3, got {idx3!r}')

#         i, j = self.ij[idx1]
#         k, l = self.ij[idx2]
#         print(f'same_image, is_dup, idx3: {same_image*1}, {is_dup}, {idx3}\n{img1_id} {idx1} -> ({i},{j})\n{img2_id} {idx2} -> ({k},{l})\n')
        tile1 = cv2.cvtColor(tile1, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.
        tile2 = cv2.cvtColor(tile2, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.

        if flip_stacking_order is None:
            flip_stacking_order = np.random.random() >= 0.5
        X = np.dstack([tile2, tile1]) if flip_stacking_order else np.dstack([tile1, tile2])
        X = self.image_transform(X)
        y = np.array([is_dup], dtype=np.float32)
        return X, y

In [None]:
# Grid files, in the positional order preprocess_image_properties receives them.
grid_files = (
    image_md5hash_grids_file,
    image_bm0hash_grids_file,
    image_cm0hash_grids_file,
    image_greycop_grids_file,
    image_entropy_grids_file,
    image_issolid_grids_file,
)

sdcic = SDCImageContainer()
sdcic.preprocess_image_properties(*grid_files)

In [None]:
# Datasets
# Build the records that are later split into the train/valid partitions.
# The commented-out lines are alternative builders.
# full_dataset = create_dataset_from_tiles_and_truth(sdcic)
full_dataset = create_dataset_from_tiles(sdcic)
# full_dataset = create_dataset_from_truth()
print(len(full_dataset))

In [None]:
# Persist the generated records so later sessions can reload them from CSV.
# DataFrame.append was removed in pandas 2.0; building the frame directly from
# the list of records yields the same integer-named columns.
df = pd.DataFrame(full_dataset)
df.to_csv('data/processed/full_SDC_dataset_from_tiles.csv', index=False)
print(len(full_dataset))

In [None]:
# Reload the cached records: one tuple per CSV row, values as plain Python scalars.
df = pd.read_csv('data/processed/full_SDC_dataset_from_tiles.csv')
full_dataset = list(zip(*(df[name].tolist() for name in df.columns)))
print(len(full_dataset))

In [None]:
# torch._six is a private compatibility module that newer PyTorch releases no
# longer ship; on Python 3 the integer classes it exposed are simply `int`.
try:
    from torch._six import int_classes as _int_classes
except ImportError:
    _int_classes = int

# SENDTOMODULE
class SubsetSampler(data.Sampler):
    r"""Yield a fixed sequence of indices in exactly the order it was given.

    Arguments:
        indices (sequence): the indices to iterate over
    """

    def __init__(self, indices):
        self.indices = indices

    def __len__(self):
        """Number of indices held by the sampler."""
        return len(self.indices)

    def __iter__(self):
        """Fresh iterator over the stored indices (no shuffling)."""
        return iter(self.indices)


class ImportanceSampler(data.Sampler):
    r"""Batch sampler that keeps revisiting the highest-scoring records.

    The first epoch uses ``num_samples`` records drawn uniformly without
    replacement. After every epoch, ``on_epoch_end`` scores each record as the
    sum of its normalised age (gradient steps since it was last used), its
    normalised number of missed epochs and its normalised last loss. The
    ``num_samples`` records with the largest score form the next epoch, in
    shuffled order. Iterating yields lists of ``batch_size`` indices, so the
    object is meant to be passed as ``batch_sampler``.

    Arguments:
        num_records (int): Total number of samples in the dataset.
        num_samples (int): Number of samples to draw from the dataset per epoch.
        batch_size (int): Size of mini-batch; must divide ``num_samples``.

    Raises:
        ValueError: if an argument is not a positive integer, if
            ``num_samples`` exceeds ``num_records`` or is smaller than
            ``batch_size``, or if ``batch_size`` does not divide ``num_samples``.
    """

    def __init__(self, num_records, num_samples, batch_size):

        for name, value in (('num_records', num_records),
                            ('num_samples', num_samples),
                            ('batch_size', batch_size)):
            if not isinstance(value, _int_classes) or isinstance(value, bool) or value <= 0:
                raise ValueError('{0} should be a positive integeral value, but got {0}={1}'.format(name, value))
        # The former chained comparison (a < b < c) could never reject
        # num_samples > num_records; test the two bounds separately.
        if num_samples > num_records or num_samples < batch_size:
            raise ValueError(
                f'num_samples ({num_samples}) must not exceed num_records ({num_records}) '
                f'and must be at least batch_size ({batch_size})')
        if num_samples % batch_size != 0:
            raise ValueError(f'batch_size ({batch_size}) must divide num_samples ({num_samples}) evenly.')

        self.num_steps = 0
        self.num_epochs = 0
        self.num_records = num_records
        self.num_samples = num_samples
        self.num_batches = num_samples // batch_size
        self.batch_size = batch_size
        self.drop_last = True

        # Per-record bookkeeping, indexed by dataset position.
        self.ages = np.zeros(num_records, dtype=int)
        self.visits = np.zeros(num_records, dtype=int)
#         self.losses = np.zeros(num_records) - np.log(0.5)  # dup or non-dup
        self.losses = np.ones(num_records)

        # Losses of the current epoch, in sampling order; -1 marks "not reported yet".
        self.epoch_losses = np.ones(num_samples) * -1.0
        self._epoch_ages = None

        self.indices = np.random.choice(self.num_records, self.num_samples, replace=False)
        self.sampler = SubsetSampler(self.indices)

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size

    @property
    def epoch_ages(self):
        """Age, in gradient steps, of each sampled position at the end of an epoch."""
        if self._epoch_ages is None:
            # plus 1 since we're always lagging behind by 1 gradient step.
            x = np.arange(self.num_batches)[::-1] + 1
            self._epoch_ages = np.repeat(x, self.batch_size)
            assert len(self._epoch_ages) == self.num_samples
        return self._epoch_ages

    @staticmethod
    def _normalized(values):
        """Scale ``values`` to sum to 1; an all-zero input stays all zero instead of becoming NaN."""
        total = np.sum(values)
        if total == 0:
            return np.zeros(len(values))
        return values / total

    def update(self, batch_losses):
        """Record the per-sample losses of the batch that was just trained on.

        Must be called once per batch, in iteration order. ``batch_losses`` is
        indexed as ``batch_losses[:, 0]``, i.e. one row per sample.
        """
        idx = self.num_steps * self.batch_size
        self.epoch_losses[idx:idx + self.batch_size] = batch_losses[:, 0]
        self.num_steps += 1

    def on_epoch_end(self):
        """Use losses, visits and ages to update weights for samples"""

        # A negative entry means update() was not called for every batch.
        assert np.min(self.epoch_losses) >= 0, np.min(self.epoch_losses)
        # age all records by the number of batches seen this epoch.
        self.ages += self.num_batches
        # only update the sampled records since their ages got reset.
        self.ages[self.indices] = self.epoch_ages
        # increment visits for samples by one.
        self.visits[self.indices] += 1
        # update losses
        self.losses[self.indices] = self.epoch_losses
        self.num_epochs += 1

        # normalize
#         log_ages = np.log(self.ages)
        norm_ages = self._normalized(self.ages)

        # non_visits is all zero when every record is used in every epoch.
        non_visits = self.num_epochs - self.visits
        norm_visits = self._normalized(non_visits)

        norm_losses = self._normalized(self.losses)

        weights = norm_ages + norm_visits + norm_losses
#         weights = log_ages * (np.sum(self.losses) / np.sum(log_ages)) + self.losses

#         norm_weights = weights / np.sum(weights)
#         self.indices = np.random.choice(self.num_records, self.num_samples, replace=False, p=self.norm_weights)
        # Deterministic top-k by weight, then shuffled so batches are mixed.
        self.indices = np.argsort(weights)[::-1][:self.num_samples]
        np.random.shuffle(self.indices)

        self.sampler = SubsetSampler(self.indices)
        self.num_steps = 0
        # Reset the sentinel explicitly; negating the old values would leave a
        # loss of exactly 0 indistinguishable from "reported".
        self.epoch_losses.fill(-1.0)

In [None]:
# Model architecture
input_shape = (6, 256, 256)
conv_layers = (16, 32, 64, 128, 256)
fc_layers = (128,)
output_size = 1

# Prefix and timestamp used in every file written under models/.
model_basename = 'dup_model'
date_time = get_datetime_now()

# Parameters
seed = 42                  # fixes the shuffle/split, the sampler's first draw and the validation augmentations
trainval_split = 0.9       # split ratio handed to even_split
sample_rate = 0.05         # fraction of the training records visited per epoch
batch_size = 256
max_epochs = 200
num_workers = 18           # DataLoader worker processes
learning_rate = 0.0001
best_loss = 9999.0         # running minimum of the validation loss
max_datapoints = 200*2**15
# lr_step_size = int(10 / sample_rate)
# lr_gamma = 0.1
print(max_datapoints)

# Seed every RNG used by this notebook so that a re-run reproduces the same
# train/valid split and sampling decisions.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
# Shuffle the records in place, then cap the amount of data used for train + valid.
np.random.shuffle(full_dataset)
if max_datapoints < len(full_dataset):
    trainval_dataset = full_dataset[:max_datapoints]
else:
    trainval_dataset = full_dataset
print(len(trainval_dataset))

In [None]:
n_train, n_valid = even_split(len(trainval_dataset), batch_size, trainval_split)

# Validation records are the last n_valid entries. An explicit start index is
# used because `seq[-0:]` would return the whole list when n_valid is 0.
valid_start = max(len(trainval_dataset) - n_valid, 0)
partition = {'train': trainval_dataset[:n_train], 'valid': trainval_dataset[valid_start:]}

# Records visited per epoch: sample_rate of the training set, rounded down to
# a whole number of batches.
n_samples = batch_size * (int(round(n_train * sample_rate)) // batch_size)
print(n_train, n_valid, n_samples)

In [None]:
# Store the training records next to the model files, in sampler index order.
avl_path = os.path.join('models', f'{model_basename}.{date_time}.avl.csv')
df = pd.DataFrame({'img_id': pd.Series(partition['train'])})
df.to_csv(avl_path, index=False)

In [None]:
def seed_worker(worker_id):
    """Seed NumPy and ``random`` inside a DataLoader worker process.

    The datasets and transforms above draw their augmentations from
    ``np.random`` and ``random``. Worker processes can start with identical
    copies of the parent's generator state and then repeat each other's
    augmentations, so each worker derives its own seed from the per-worker
    PyTorch seed.
    """
    worker_seed = torch.initial_seed() % 2 ** 32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


sampler = ImportanceSampler(n_train, n_samples, batch_size)

loader_params = {
    'train': {'batch_sampler': sampler, 'num_workers': num_workers, 'worker_init_fn': seed_worker},
#     'train': {'batch_size': batch_size, 'shuffle': True, 'num_workers': num_workers},
    'valid': {'batch_size': batch_size, 'shuffle': False, 'num_workers': num_workers, 'worker_init_fn': seed_worker}}

image_transforms = {
    # Training: random 90/180/270 degree rotation of the stacked tensor.
    'train': transforms.Compose([
#         RandomHorizontalFlip(),
        transforms.ToTensor(),
        RandomTransformC4(with_identity=False),
    ]),
    # Validation: p=1 makes the horizontal flip unconditional.
    'valid': transforms.Compose([
        RandomHorizontalFlip(p=1),
        transforms.ToTensor(),
#         RandomTransformC4(with_identity=False),
    ]),
}

image_datasets = {x: Dataset(partition[x], x, image_transforms[x]) for x in ['train', 'valid']}

# Generators
generators = {x: data.DataLoader(image_datasets[x], **loader_params[x]) for x in ['train', 'valid']}
print(len(generators['train']), len(generators['valid']))

In [None]:
# Prefer the GPU when one is visible; otherwise stay on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Enable cuDNN's benchmark mode.
torch.backends.cudnn.benchmark = True

In [None]:
model = DupCNN(input_shape, output_size, conv_layers, fc_layers)
# Move the model to the device selected earlier. An unconditional
# `model.cuda()` would raise on a CPU-only machine and defeat that fallback.
model = model.to(device)

# loss = torch.nn.MSELoss()
loss = torch.nn.BCELoss()
# loss = torch.nn.BCEWithLogitsLoss()
# Same criterion without reduction: one loss per sample for the ImportanceSampler.
sample_loss = torch.nn.BCELoss(reduction='none')

optimizer = Adam(model.parameters(), lr=learning_rate)

In [None]:
# SENDTOMODULE
class ReduceLROnPlateau2(ReduceLROnPlateau):
    """``ReduceLROnPlateau`` that also reports the optimizer's current learning rates.

    Construction is inherited unchanged from the parent class.
    """

    def get_lr(self):
        """Return the current ``lr`` of every optimizer param group, in order."""
        return [group['lr'] for group in self.optimizer.param_groups]

# scheduler = StepLR(optimizer, step_size=lr_step_size, gamma=lr_gamma)
# Plateau-based schedule; the training loop steps it with the validation loss
# once per epoch.
scheduler = ReduceLROnPlateau2(optimizer, verbose=True)

In [None]:
# Column names of the per-epoch metrics file; Stats mirrors them field for field.
header = "epoch time lr train_loss train_acc val_loss val_acc train_time val_time".split()
Stats = namedtuple('Stats', header)

csv_filename = os.path.join('models', f'{model_basename}.{date_time}.metrics.csv')
logger = CSVLogger(csv_filename, header)

# Start Training!

In [None]:
start_time = time.time()

# Loop over epochs
for epoch in range(max_epochs):

#     scheduler.step()

    # Training
    t0 = time.time()
    total_train_loss = 0
    total_train_acc = 0
    model.train()
    n_train_batches = len(generators['train'])
    t = tnrange(n_train_batches)
    train_iterator = iter(generators['train'])
    for i in t:
        t.set_description(f'Epoch {epoch + 1:>02d}')
        # Get next batch and push to the selected device. The builtin next()
        # works on every iterator; the `.next()` method is a Python 2 leftover
        # that newer DataLoader iterators no longer provide.
        inputs, labels = next(train_iterator)
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        train_loss = loss(outputs, labels)
        train_loss.backward()
        optimizer.step()

        # Statistics (no gradients needed from here on)
        with torch.no_grad():
            # Per-sample losses of this batch feed the importance sampler.
            sampler.update(sample_loss(outputs, labels).cpu().numpy())
            # Threshold at 0.5 and compare on whatever device the batch lives on,
            # instead of hard-coding CUDA tensor types.
            y_pred = (outputs > 0.5).type_as(labels)
            batch_acc = (y_pred == labels).float().mean().item()
        total_train_loss += train_loss.item()
        total_train_acc += batch_acc

        loss_str = f'{total_train_loss/(i + 1):.6f}'
        acc_str = f'{total_train_acc/(i + 1):.5f}'
        t.set_postfix(loss=loss_str, acc=acc_str)

    # Average over the batch count rather than the leaked loop variable, which
    # is undefined (or stale) when the loader yields no batches.
    train_loss = total_train_loss / max(n_train_batches, 1)
    train_acc = total_train_acc / max(n_train_batches, 1)
    train_time = time.time() - t0

    # Validation
    t1 = time.time()
    total_val_loss = 0
    total_val_acc = 0
    n_valid_batches = len(generators['valid'])
    t = tnrange(n_valid_batches)
    valid_iterator = iter(generators['valid'])
    model.eval()
    with torch.no_grad():
        for i in t:
            t.set_description(f'Validation')
            # Get next batch and push to the selected device
            inputs, labels = next(valid_iterator)
            inputs, labels = inputs.to(device), labels.to(device)

            # Forward pass
            val_outputs = model(inputs)
            val_loss = loss(val_outputs, labels)

            total_val_loss += val_loss.item()
            y_pred = (val_outputs > 0.5).type_as(labels)
            total_val_acc += (y_pred == labels).float().mean().item()

            loss_str = f'{total_val_loss/(i + 1):.6f}'
            acc_str = f'{total_val_acc/(i + 1):.5f}'
            t.set_postfix(loss=loss_str, acc=acc_str)

    val_loss = total_val_loss / max(n_valid_batches, 1)
    val_acc = total_val_acc / max(n_valid_batches, 1)
    val_time = time.time() - t1

    # Re-rank the training records for the next epoch.
    sampler.on_epoch_end()
    total_time = time.time() - start_time

    # Snapshot of the sampler state (ages / visits / losses) after this epoch.
    df = pd.DataFrame()
    df['ages'] = pd.Series(sampler.ages)
    df['visits'] = pd.Series(sampler.visits)
    df['losses'] = pd.Series(sampler.losses)
    df.to_csv(os.path.join('models', f'{model_basename}.{date_time}.{epoch + 1:02d}.{val_loss:.6f}.avl.csv'), index=False)

    if val_loss < best_loss:
        # Keep a per-epoch checkpoint plus two "best" aliases.
        save_checkpoint(os.path.join('models', f'{model_basename}.{date_time}.{epoch + 1:02d}.{val_loss:.6f}.pth'), model)
        save_checkpoint(os.path.join('models', f'{model_basename}.{date_time}.best.pth'), model)
        save_checkpoint(os.path.join('models', f'{model_basename}.best.pth'), model)
        best_loss = val_loss

    # Log the learning rate that was in effect during this epoch, then let the
    # scheduler react to the validation loss.
    stats = Stats(epoch+1, total_time, scheduler.get_lr()[0], train_loss, train_acc, val_loss, val_acc, train_time, val_time)
    logger.on_epoch_end(stats)

    scheduler.step(val_loss)

## Immediate post-processing, in case we want to inspect the model before shutting down the notebook.

In [None]:
from collections import Counter

from sdcdup.utils import fuzzy_diff
from sdcdup.utils import fuzzy_join
from sdcdup.utils import get_hamming_distance

In [None]:
# How many records were visited 0, 1, 2, ... times.
visits = Counter(sampler.visits)

# Range of the most recent per-record losses.
np.min(sampler.losses), np.max(sampler.losses)

In [None]:
# Candidate weighting schemes computed from the final sampler state.
norm_ages = sampler.ages / np.sum(sampler.ages)
norm_losses = sampler.losses / np.sum(sampler.losses)
log_ages = np.log(sampler.ages)
# Rescale the log-ages so that they sum to the same total as the losses.
scaled_ages = log_ages * (np.sum(sampler.losses) / np.sum(log_ages))
weights = scaled_ages + sampler.losses
norm_weights = weights / np.sum(weights)

# One row per training record, one column per quantity.
df = pd.DataFrame({
    'ages': pd.Series(sampler.ages, dtype=int),
    'visits': pd.Series(sampler.visits, dtype=int),
    'losses': pd.Series(sampler.losses),
    'norm_ages': pd.Series(norm_ages),
    'norm_losses': pd.Series(norm_losses),
    'log_ages': pd.Series(log_ages),
    'scaled_ages': pd.Series(scaled_ages),
    'weights': pd.Series(weights),
    'norm_weights': pd.Series(norm_weights),
})
df.describe()

In [None]:
# Show the records ordered by loss, ties broken by scaled age (both ascending).
df.sort_values(by=['losses', 'scaled_ages'])

In [None]:
# Positions (within the training partition) of records whose last loss exceeds 0.16.
high_loss_mask = sampler.losses > 0.16
bad_loss_indices = np.nonzero(high_loss_mask)[0]
sampler.losses[bad_loss_indices]

In [None]:
# Training records whose last loss exceeds 0.16.
bad_overlaps = [partition['train'][i] for i in np.where(sampler.losses > 0.16)[0]]
len(bad_overlaps)

In [None]:
# For every high-loss record, print how uniform the joined tile is and how much
# the two tiles differ.
for i in bad_loss_indices:
    bol = partition['train'][i]
    img1_id, img2_id, idx1, idx2 = bol[0], bol[1], bol[2], bol[3]

    bmh1 = sdcic.tile_bm0hash_grids[img1_id][idx1]
    bmh2 = sdcic.tile_bm0hash_grids[img2_id][idx2]
    score = get_hamming_distance(bmh1, bmh2, as_score=True)

    tile1 = sdcic.get_tile(sdcic.get_img(img1_id), idx1)
    tile2 = sdcic.get_tile(sdcic.get_img(img2_id), idx2)
    tile3 = fuzzy_join(tile1, tile2)
    pix3, cts3 = np.unique(tile3.flatten(), return_counts=True)

    # Share of the joined tile taken by its most frequent value.
    dominant_share = np.max(cts3 / (256*256*3))
    print(bol, f'{dominant_share:>.4f}', f'{sampler.losses[i]:>.6f}', np.sum(tile1 != tile2), fuzzy_diff(tile1, tile2))

    # Per-channel most frequent value and its count (printing is disabled).
    for chan in range(3):
        pix1, cts1 = np.unique(tile1[:, :, chan].flatten(), return_counts=True)
        pix2, cts2 = np.unique(tile2[:, :, chan].flatten(), return_counts=True)
        pix3, cts3 = np.unique(tile3[:, :, chan].flatten(), return_counts=True)

        max_idx1, max_idx2, max_idx3 = np.argmax(cts1), np.argmax(cts2), np.argmax(cts3)

        max_pix1, max_cts1 = pix1[max_idx1], cts1[max_idx1]
#         print(f'{max_pix1:>3}', max_cts1, f'{max_cts1/65536:.4f}')
        max_pix2, max_cts2 = pix2[max_idx2], cts2[max_idx2]
#         print(f'{max_pix2:>3}', max_cts2, f'{max_cts2/65536:.4f}')
        max_pix3, max_cts3 = pix3[max_idx3], cts3[max_idx3]
#         print(f'{max_pix3:>3}', max_cts3, f'{max_cts3/65536:.4f}')

In [None]:
# Set to True to also print per-channel dominant-pixel statistics for every
# candidate pair. That section used to sit behind an unconditional `continue`
# and could never run; the flag keeps the default output unchanged.
show_channel_stats = False

ii = 0
for bol in sorted(full_dataset):
    bmh1 = sdcic.tile_bm0hash_grids[bol[0]][bol[2]]
    bmh2 = sdcic.tile_bm0hash_grids[bol[1]][bol[3]]
    score = get_hamming_distance(bmh1, bmh2, as_score=True)

    # Only pairs with a score of 256 that are labelled as non-duplicates.
    if not (score == 256 and bol[4] == 0):
        continue

    tile1 = sdcic.get_tile(sdcic.get_img(bol[0]), bol[2])
    tile2 = sdcic.get_tile(sdcic.get_img(bol[1]), bol[3])
    tile3 = fuzzy_join(tile1, tile2)
    pix3, cts3 = np.unique(tile3.flatten(), return_counts=True)

    # Share of the joined tile taken by its single most frequent value.
    dominant_share = np.max(cts3 / (256*256*3))
    if dominant_share > 0.97:
        ii += 1
        print(ii, bol, f'{dominant_share:>.4f}', np.sum(tile1 != tile2), fuzzy_diff(tile1, tile2))

    if not show_channel_stats:
        continue

    for chan in range(3):
        pix1, cts1 = np.unique(tile1[:, :, chan].flatten(), return_counts=True)
        pix2, cts2 = np.unique(tile2[:, :, chan].flatten(), return_counts=True)
        pix3, cts3 = np.unique(tile3[:, :, chan].flatten(), return_counts=True)

        max_idx1 = np.argmax(cts1)
        max_idx2 = np.argmax(cts2)
        max_idx3 = np.argmax(cts3)

        max_pix1 = pix1[max_idx1]
        max_pix2 = pix2[max_idx2]
        max_pix3 = pix3[max_idx3]

        max_cts1 = cts1[max_idx1]
        max_cts2 = cts2[max_idx2]
        max_cts3 = cts3[max_idx3]

        # Skip channels in which every tile is at least 95% one value.
        if min([max_cts1, max_cts2, max_cts3])/65536 >= 0.95:
            continue

        ii += 1
        print(ii, bol)
        print(f'{max_pix1:>3}', max_cts1, f'{max_cts1/65536:.4f}')
        print(f'{max_pix2:>3}', max_cts2, f'{max_cts2/65536:.4f}')
        print(f'{max_pix3:>3}', max_cts3, f'{max_cts3/65536:.4f}')