In [None]:
import os
import time
from collections import namedtuple

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from torch.utils import data
from torch.optim import Adam
from torchvision import transforms
from tqdm import tnrange

from sdcdup.utils import get_datetime_now
from sdcdup.utils import get_project_root
from sdcdup.utils import even_split
from sdcdup.utils import RandomHorizontalFlip
from sdcdup.utils import RandomTransformC4
from sdcdup.utils import CSVLogger
from sdcdup.utils import ReduceLROnPlateau2
from sdcdup.utils import ImportanceSampler
# from sdcdup.data import create_dataset_from_tiles_and_truth
from sdcdup.data import create_dataset_from_tiles
# from datasets import create_dataset_from_truth
from sdcdup.data import TrainDataset as Dataset
# from datasets import ExternalDataset as Dataset
from sdcdup.models import save_checkpoint
from sdcdup.models import DupCNN

%load_ext dotenv
%dotenv
%matplotlib inline
%reload_ext autoreload
%autoreload 2

project_root = get_project_root()

# Read the data locations from the environment loaded by `%dotenv`.
# os.environ[...] raises a KeyError naming the missing variable; os.getenv
# would return None and os.path.join would then fail with an opaque TypeError.
raw_data_dir = os.environ['RAW_DATA_DIR']
processed_data_dir = os.environ['PROCESSED_DATA_DIR']

train_image_dir = os.path.join(project_root, raw_data_dir, 'train_768')
train_tile_dir = os.path.join(project_root, processed_data_dir, 'train_256')
# Cached CSV of the full pair dataset (written by the dataset cell on first run).
full_dataset_filename = os.path.join(project_root, processed_data_dir, 'full_SDC_dataset_from_tiles.csv')

In [None]:
# Datasets: load the cached pair list from CSV when it exists, otherwise build
# it from the tiles and cache it for the next run.
if os.path.exists(full_dataset_filename):
    df = pd.read_csv(full_dataset_filename)
    # One tuple per CSV row, fields in column order.
    full_dataset = list(zip(*[df[c].values.tolist() for c in df]))
else:
    full_dataset = create_dataset_from_tiles()
    # DataFrame.append was removed in pandas 2.0; construct the frame directly.
    df = pd.DataFrame(full_dataset)
    df.to_csv(full_dataset_filename, index=False)
print(len(full_dataset))

In [None]:
# Model architecture (passed straight to DupCNN).
# NOTE(review): 6 input channels presumably = two stacked RGB 256x256 tiles — confirm against DupCNN.
input_shape = (6, 256, 256)
conv_layers = (16, 32, 64, 128, 256)
fc_layers = (128,)
output_size = 1

model_basename = 'dup_model'
date_time = get_datetime_now()  # tags every artifact filename written by this run

# Parameters
trainval_split = 0.9    # fraction handed to even_split for the train/valid split
sample_rate = 0.05      # n_samples for the ImportanceSampler is ~ n_train * sample_rate
batch_size = 256
max_epochs = 200
num_workers = 18        # DataLoader worker processes per loader
learning_rate = 0.0001
# Best validation loss seen so far. +inf guarantees the first epoch always
# checkpoints, independent of the loss function's scale.
best_loss = float('inf')
max_datapoints = 200*2**15  # = 6,553,600; cap on the size of the train+valid pool
print(max_datapoints)

In [None]:
# Shuffle the pair list in place with the global (unseeded) NumPy RNG, then
# keep at most `max_datapoints` pairs as the train+valid pool.
np.random.shuffle(full_dataset)
if len(full_dataset) > max_datapoints:
    trainval_dataset = full_dataset[:max_datapoints]
else:
    trainval_dataset = full_dataset
print(len(trainval_dataset))

In [None]:
n_train, n_valid = even_split(len(trainval_dataset), batch_size, trainval_split)

# Train = first n_train pairs, valid = last n_valid pairs. Slice from an explicit
# start index: `trainval_dataset[-n_valid:]` would return the WHOLE list when
# n_valid == 0.
valid_start = len(trainval_dataset) - n_valid
partition = {'train': trainval_dataset[:n_train], 'valid': trainval_dataset[valid_start:]}

# Pairs drawn per epoch by the importance sampler, rounded down to whole batches.
n_samples = batch_size * (int(round(n_train * sample_rate)) // batch_size)
print(n_train, n_valid, n_samples)

assert n_train + n_valid <= len(trainval_dataset), 'train and valid partitions overlap'
# With zero samples the training loader yields no batches and the epoch statistics are meaningless.
assert n_samples > 0, 'n_train * sample_rate is smaller than one batch; increase sample_rate or the dataset size'

In [None]:
# Record which pairs landed in this run's training partition.
avl_path = os.path.join(project_root, 'models', f'{model_basename}.{date_time}.avl.csv')
train_partition_df = pd.DataFrame({'img_id': pd.Series(partition['train'])})
train_partition_df.to_csv(avl_path, index=False)

In [None]:
# The sampler is handed to the training DataLoader as its `batch_sampler`, so
# it must yield lists of indices into partition['train'].
sampler = ImportanceSampler(n_train, n_samples, batch_size)

# NOTE(review): with num_workers > 0 the DataLoader pulls several batches from
# the batch_sampler ahead of the training loop. If ImportanceSampler.update()
# assumes the losses it receives belong to the batch it yielded most recently,
# they will be attributed to the wrong indices — verify its implementation.
train_loader_params = dict(batch_sampler=sampler, num_workers=num_workers)
valid_loader_params = dict(batch_size=batch_size, shuffle=False, num_workers=num_workers)
loader_params = {'train': train_loader_params, 'valid': valid_loader_params}

# Training: ToTensor, then RandomTransformC4 without the identity transform.
train_transform = transforms.Compose([
    transforms.ToTensor(),
    RandomTransformC4(with_identity=False),
])
# Validation: RandomHorizontalFlip(p=1) before ToTensor.
# NOTE(review): p=1 presumably means "always flip" — confirm this is intended for validation.
valid_transform = transforms.Compose([
    RandomHorizontalFlip(p=1),
    transforms.ToTensor(),
])
image_transforms = {'train': train_transform, 'valid': valid_transform}

splits = ('train', 'valid')
image_datasets = {split: Dataset(partition[split], split, image_transforms[split]) for split in splits}

# Generators
generators = {split: data.DataLoader(image_datasets[split], **loader_params[split]) for split in splits}
print(len(generators['train']), len(generators['valid']))

In [None]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')
# Let cuDNN benchmark and cache convolution algorithms; this pays off when
# input sizes do not change between batches.
torch.backends.cudnn.benchmark = True

In [None]:
model = DupCNN(input_shape, output_size, conv_layers, fc_layers)
# `.to(device)` alone respects the CPU fallback chosen when `device` was set;
# an unconditional `model.cuda()` crashes on machines without a GPU.
model.to(device)

# BCELoss expects probabilities in [0, 1].
# NOTE(review): this presumes DupCNN ends in a sigmoid — confirm.
loss = torch.nn.BCELoss()                         # mean loss: optimised and reported
sample_loss = torch.nn.BCELoss(reduction='none')  # per-sample loss: fed to the importance sampler

optimizer = Adam(model.parameters(), lr=learning_rate)

scheduler = ReduceLROnPlateau2(optimizer, verbose=True)

In [None]:
# Column names of the per-epoch metrics; one Stats record is passed to
# logger.on_epoch_end at the end of every epoch.
header = [
    'epoch', 'time', 'lr',
    'train_loss', 'train_acc',
    'val_loss', 'val_acc',
    'train_time', 'val_time',
]
Stats = namedtuple('Stats', header)

metrics_name = '{}.{}.metrics.csv'.format(model_basename, date_time)
csv_filename = os.path.join(project_root, 'models', metrics_name)
logger = CSVLogger(csv_filename, header)

# Start Training!

In [None]:
start_time = time.time()


def _batch_accuracy(probs, targets, threshold=0.5):
    """Return the fraction of thresholded predictions that equal the targets.

    Args:
        probs: model outputs; a prediction is positive when it exceeds `threshold`.
        targets: labels with the same shape as `probs`.
        threshold: decision threshold (default 0.5).

    Returns:
        A Python float in [0, 1]. The comparison runs on whatever device the
        tensors live on, so this works on both CPU and GPU.
    """
    preds = (probs > threshold).to(targets.dtype)
    return (preds == targets).float().mean().item()


# Loop over epochs
for epoch in range(max_epochs):

    # ---- Training ----
    t0 = time.time()
    total_train_loss = 0.0
    total_train_acc = 0.0
    n_train_batches = len(generators['train'])
    model.train()
    t = tnrange(n_train_batches)
    t.set_description(f'Epoch {epoch + 1:>02d}')
    train_iterator = iter(generators['train'])
    for i in t:
        # Get the next batch (Python 3 iterator protocol: builtin next()) and move it to the device.
        inputs, labels = next(train_iterator)
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        batch_loss = loss(outputs, labels)
        batch_loss.backward()
        optimizer.step()

        # Statistics: per-sample losses go back to the importance sampler.
        detached = outputs.detach()
        sampler.update(sample_loss(detached, labels).cpu().numpy())
        total_train_loss += batch_loss.item()
        total_train_acc += _batch_accuracy(detached, labels)

        t.set_postfix(loss=f'{total_train_loss/(i + 1):.6f}', acc=f'{total_train_acc/(i + 1):.5f}')

    # max(..., 1) avoids dividing by zero if the loader yields no batches.
    train_loss = total_train_loss / max(n_train_batches, 1)
    train_acc = total_train_acc / max(n_train_batches, 1)
    train_time = time.time() - t0

    # ---- Validation ----
    t1 = time.time()
    total_val_loss = 0.0
    total_val_acc = 0.0
    n_valid_batches = len(generators['valid'])
    model.eval()
    t = tnrange(n_valid_batches)
    t.set_description('Validation')
    valid_iterator = iter(generators['valid'])
    with torch.no_grad():
        for i in t:
            inputs, labels = next(valid_iterator)
            inputs, labels = inputs.to(device), labels.to(device)

            # Forward pass only; no gradients are tracked under no_grad().
            val_outputs = model(inputs)
            total_val_loss += loss(val_outputs, labels).item()
            total_val_acc += _batch_accuracy(val_outputs, labels)

            t.set_postfix(loss=f'{total_val_loss/(i + 1):.6f}', acc=f'{total_val_acc/(i + 1):.5f}')

    val_loss = total_val_loss / max(n_valid_batches, 1)
    val_acc = total_val_acc / max(n_valid_batches, 1)
    val_time = time.time() - t1

    sampler.on_epoch_end()
    total_time = time.time() - start_time

    models_dir = os.path.join(project_root, 'models')
    run_prefix = f'{model_basename}.{date_time}'
    epoch_tag = f'{epoch + 1:02d}.{val_loss:.6f}'

    # Snapshot the sampler's ages / visits / losses for this epoch.
    df = pd.DataFrame()
    df['ages'] = pd.Series(sampler.ages)
    df['visits'] = pd.Series(sampler.visits)
    df['losses'] = pd.Series(sampler.losses)
    df.to_csv(os.path.join(models_dir, f'{run_prefix}.{epoch_tag}.avl.csv'), index=False)

    # On improvement: save an epoch-tagged checkpoint, and overwrite the two
    # fixed-name "best" checkpoints.
    if val_loss < best_loss:
        save_checkpoint(os.path.join(models_dir, f'{run_prefix}.{epoch_tag}.pth'), model)
        save_checkpoint(os.path.join(models_dir, f'{run_prefix}.best.pth'), model)
        save_checkpoint(os.path.join(models_dir, f'{model_basename}.best.pth'), model)
        best_loss = val_loss

    stats = Stats(epoch + 1, total_time, scheduler.get_lr()[0], train_loss, train_acc, val_loss, val_acc, train_time, val_time)
    logger.on_epoch_end(stats)

    scheduler.step(val_loss)

## (Optional) Post-training analysis — inspect the model and the importance sampler's statistics before shutting down the notebook.

In [None]:
from collections import Counter

from sdcdup.utils import fuzzy_diff
from sdcdup.utils import fuzzy_join
from sdcdup.utils import get_hamming_distance

In [None]:
# Histogram of the values in sampler.visits (value -> number of occurrences).
visits = Counter(sampler.visits)

# Smallest and largest loss currently tracked by the sampler.
(np.min(sampler.losses), np.max(sampler.losses))

In [None]:
# log(age) is rescaled so that its total equals the total loss, then added to
# the loss to form `weights`. Each quantity is also normalised to sum to 1.
# NOTE(review): np.log yields -inf for any age of 0 — confirm ages are always >= 1.
log_ages = np.log(sampler.ages)
scaled_ages = log_ages * (np.sum(sampler.losses) / np.sum(log_ages))
weights = scaled_ages + sampler.losses

norm_ages = sampler.ages / np.sum(sampler.ages)
norm_losses = sampler.losses / np.sum(sampler.losses)
norm_weights = weights / np.sum(weights)

df = pd.DataFrame({
    'ages': pd.Series(sampler.ages, dtype=int),
    'visits': pd.Series(sampler.visits, dtype=int),
    'losses': pd.Series(sampler.losses),
    'norm_ages': pd.Series(norm_ages),
    'norm_losses': pd.Series(norm_losses),
    'log_ages': pd.Series(log_ages),
    'scaled_ages': pd.Series(scaled_ages),
    'weights': pd.Series(weights),
    'norm_weights': pd.Series(norm_weights),
})
df.describe()

In [None]:
# Samples ordered by loss (ascending), ties broken by scaled age.
df.sort_values(by=['losses', 'scaled_ages'], ascending=True)

In [None]:
# Indices into partition['train'] of samples whose tracked loss exceeds 0.16.
bad_loss_indices = np.nonzero(sampler.losses > 0.16)[0]
sampler.losses[bad_loss_indices]

In [None]:
# Training pairs behind the high-loss samples. Reuses `bad_loss_indices` from the
# previous cell so the 0.16 threshold lives in one place and cannot drift.
bad_overlaps = [partition['train'][i] for i in bad_loss_indices]
len(bad_overlaps)

In [None]:
# NOTE(review): `sdcic` is never created in this notebook, so this cell fails on a
# fresh kernel. Fail with an explicit message until a cell constructing it is added.
if 'sdcic' not in globals():
    raise NameError(
        "`sdcic` is not defined: create the image container (it must provide "
        "img_metrics['bmh'], get_img and get_tile) before running this cell."
    )

# For each high-loss training pair print:
# - the pair tuple,
# - the fraction of the fuzzy-joined tile taken by its most common value,
# - the sample's loss,
# - the number of differing elements,
# - fuzzy_diff of the two tiles.
for i in bad_loss_indices:
    bol = partition['train'][i]

    tile1 = sdcic.get_tile(sdcic.get_img(bol[0]), bol[2])
    tile2 = sdcic.get_tile(sdcic.get_img(bol[1]), bol[3])
    tile3 = fuzzy_join(tile1, tile2)
    # np.unique flattens its input; tile3.size is the total element count.
    _, cts3 = np.unique(tile3, return_counts=True)
    dominant_frac = np.max(cts3) / tile3.size

    print(bol, f'{dominant_frac:>.4f}', f'{sampler.losses[i]:>.6f}', np.sum(tile1 != tile2), fuzzy_diff(tile1, tile2))

In [None]:
# NOTE(review): `sdcic` is never created in this notebook, so this cell fails on a
# fresh kernel. Fail with an explicit message until a cell constructing it is added.
if 'sdcic' not in globals():
    raise NameError(
        "`sdcic` is not defined: create the image container (it must provide "
        "img_metrics['bmh'], get_img and get_tile) before running this cell."
    )

# Scan every pair (in sorted order) whose 'bmh' metrics give a hamming score of
# 256 while its fifth field bol[4] is 0.
# NOTE(review): bol[4] is presumably the duplicate label — confirm.
# Print those whose fuzzy-joined tile is more than 97% a single value.
n_flagged = 0
for bol in sorted(full_dataset):
    bmh1 = sdcic.img_metrics['bmh'][bol[0]][bol[2]]
    bmh2 = sdcic.img_metrics['bmh'][bol[1]][bol[3]]
    score = get_hamming_distance(bmh1, bmh2, as_score=True)

    if score != 256 or bol[4] != 0:
        continue

    tile1 = sdcic.get_tile(sdcic.get_img(bol[0]), bol[2])
    tile2 = sdcic.get_tile(sdcic.get_img(bol[1]), bol[3])
    tile3 = fuzzy_join(tile1, tile2)
    # np.unique flattens its input; tile3.size is the total element count.
    _, cts3 = np.unique(tile3, return_counts=True)
    dominant_frac = np.max(cts3) / tile3.size

    if dominant_frac > 0.97:
        n_flagged += 1
        print(n_flagged, bol, f'{dominant_frac:>.4f}', np.sum(tile1 != tile2), fuzzy_diff(tile1, tile2))