# Less-Referenced Mosaic Testing

This notebook provides a cross-section of the Less-Referenced Mosaic creation process.


# Setup


## Imports


In [None]:
import copy
import glob
import inspect
import os
import shutil

In [None]:
import cv2
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.utils import check_random_state
import yaml

In [None]:
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import seaborn as sns
# Set the seaborn plotting style used by all figures in this notebook.
sns.set_style('white')

In [None]:
from night_horizons import utils, preprocess, reference, mosaic, raster, pipelines, features

## Settings


In [None]:
# Load the shared configuration. config.yml is a trusted local file, so
# FullLoader is acceptable; switch to yaml.safe_load if that ever changes.
with open('./config.yml', "r", encoding='UTF-8') as file:
    settings = yaml.load(file, Loader=yaml.FullLoader)

In [None]:
# Test-specific settings. These override any matching keys from config.yml.
local_settings = {
    # If True, the raw images are mosaicked in addition to the held-out referenced images.
    'include_raw_images': True,
    # Joined onto settings['data_dir'] when the settings are parsed.
    'mosaic_filepath': 'mosaics/less_referenced_test.tiff',
    'random_state': 1682142,
    # Number of manually-georeferenced training images to draw, per camera number.
    'train_size': {
        0: 0,
        1: 1,
        2: 0,
    },

    # Set to True to render the diagnostic plots below.
    'show_images': False,

    # Number of images processed in the full-process run.
    'n_loops': 6,
    # Also used to build the expected checkpoint filename.
    'checkpoint_freq': 4,
    # NOTE(review): not passed to the mosaic constructor in this notebook; confirm it is still used.
    'save_bad_images': True,
    'bad_images_dir': '../test_data/feature_matching/',

    'image_joiners_options': {
        'defaults': {
            'feature_detector': 'AKAZE',
            'feature_matcher': 'BFMatcher',
            'log_keys': ['dst_kp', 'dst_pts', 'src_kp', 'src_pts', 'mask', 'abs_det_M'],
            'debug_mode': True,
        },
        'variations': [
            {'n_matches_used': 100, },
        ],
    },

    # Multiplied by each image's spatial error to pad its pixel search region.
    'padding': 2.,
    'use_approximate_georeferencing': True,
    # This set of choices assumes we have really good starting positions.
    # This is useful for debugging.
    # 'padding': 0.1,
    # 'use_approximate_georeferencing': False,

    # The fraction of non-nan georeferencings required to claim success.
    # We only require a bare minimum here. Performance beyond approximate
    # functionality should be evaluated elsewhere
    'acceptance_fraction': 0.8,

    # These values will be logged and checked for consistency.
    'log_keys': ['x_off', 'y_off', 'x_size', 'y_size', 'dst_img', 'dst_kp', 'src_kp', 'abs_det_M', 'return_code'],
}
settings.update(local_settings)

## Parse Settings


In [None]:
# Resolve the output mosaic path relative to the data directory.
settings['mosaic_filepath'] = os.path.join(settings['data_dir'], settings['mosaic_filepath'])

In [None]:
# Resolve every configured relative path against the data directory.
for key, relpath in settings['paths_relative_to_data_dir'].items():
    settings[key] = os.path.join(settings['data_dir'], relpath)

In [None]:
# NOTE(review): this RandomState is not used below; the split passes settings['random_state'] directly.
random_state = check_random_state(settings['random_state'])

In [None]:
palette = sns.color_palette(settings['color_palette'])

In [None]:
crs = settings['crs']

In [None]:
# Keyword arguments shared by every LessReferencedMosaic built in this notebook.
# NOTE(review): one ImageJoinerQueue instance is shared by all of those mosaics.
# If it keeps state between runs (its joiners expose a `log`), that state is
# shared too; confirm this is intended.
constructor_kwargs = dict(
    image_joiner=features.ImageJoinerQueue(**settings['image_joiners_options']),
    filepath=settings['mosaic_filepath'],
    padding=settings['padding'],
    file_exists='error',
    log_keys=settings['log_keys'],
    crs=crs,
    debug_mode=True,
    bad_images_dir=settings['bad_images_dir'],
    checkpoint_freq=settings['checkpoint_freq'],
)

In [None]:
# Remove any existing files
# Every output file the full run is expected to produce, keyed by role. Each
# path is derived from the mosaic path by replacing '.tiff' in it.
base_fp = settings['mosaic_filepath']
test_fps = dict(
    mosaic=base_fp,
    y_pred=base_fp.replace('.tiff', '_y_pred.csv'),
    settings=base_fp.replace('.tiff', '_settings.yaml'),
    log=base_fp.replace('.tiff', '_log.csv'),
    checkpoint=base_fp.replace('.tiff', f"_i{settings['checkpoint_freq']:06d}.tiff"),
)


def clear_files():
    """Delete whichever of the expected output files in ``test_fps`` exist."""
    for output_fp in test_fps.values():
        if not os.path.isfile(output_fp):
            continue
        os.remove(output_fp)


clear_files()

# Prepare Data

The first part is to prepare the data (AKA extract/transform/load).


## Get filepaths


In [None]:
# Get the referenced filepaths, divided according to camera number
# (keys 0-2; the camera number is the digit after the final underscore).
# NOTE(review): the '.' in '.tif' is an unescaped regex wildcard; harmless here.
referenced_fps = {i: utils.discover_data(settings['referenced_images_dir'], ['tif', 'tiff'], pattern=r'Geo\s\d+_' + f'{i}.tif') for i in range(3)}

In [None]:
# Filepaths of the raw images (presumably not georeferenced).
raw_fps = utils.discover_data(settings['images_dir'], ['tif', 'tiff', 'raw'])

## Train-Test Split

We split the data into training data (data that is georeferenced) and test data (data that is not georeferenced, or for which we don't use the georeferencing information when we're building the models).

We set the train size to a small number, because ideally the user only needs to georeference a couple of images manually.


In [None]:
# Get the training sample for each camera.
# A camera whose train size is 0 sends all of its images to the test set.
# Otherwise train_test_split draws `train_size` images of it for training.
referenced_fps_train = []
referenced_fps_test = []
for camera_num, n_train in settings['train_size'].items():
    camera_fps = referenced_fps[camera_num]

    if n_train == 0:
        # No training images for this camera
        referenced_fps_test.append(camera_fps)
    else:
        cam_fps_train, cam_fps_test = train_test_split(
            camera_fps,
            train_size=n_train,
            random_state=settings['random_state'],
            shuffle=True,
        )
        referenced_fps_train.append(cam_fps_train)
        referenced_fps_test.append(cam_fps_test)


In [None]:
# Merge the per-camera pieces; ignore_index gives each result a fresh 0..n-1 index.
# NOTE(review): train and test therefore both start at index 0 and overlap;
# they are kept in separate objects below.
referenced_fps_train = pd.concat(referenced_fps_train, ignore_index=True)
referenced_fps_test = pd.concat(referenced_fps_test, ignore_index=True)

## Combine Referenced and Raw


In [None]:
# Adjust the index so we don't have duplicates: shift the raw indices past the
# 0..n-1 indices of referenced_fps_test. Assumes raw_fps also starts at 0 and
# that referenced_fps_test is one-dimensional, so .size equals its length.
raw_fps.index += referenced_fps_test.size

In [None]:
# Actual combination
# fps holds the images to mosaic: the held-out referenced images plus,
# optionally, the raw images.
fps_train = referenced_fps_train
fps_test = referenced_fps_test
if settings['include_raw_images']:
    fps = pd.concat([referenced_fps_test, raw_fps])
else:
    fps = referenced_fps_test

In [None]:
# Expected number of training files
assert len(pd.unique(referenced_fps_train.index)) == np.sum(list(settings['train_size'].values()))
# Consistent indices for test set (no duplicate index values)
assert len(pd.unique(referenced_fps_test.index)) == len(referenced_fps_test.index)

## Preprocessing


### y values

We get the y values first because the X values are produced by a model fitted to them.


In [None]:
# Preprocessor that extracts the geo-transforms (y values) from referenced GeoTIFFs.
preprocessing_pipeline_y = preprocess.GeoTIFFPreprocesser(crs=crs)

In [None]:
# Get the geo-transforms used for training
# y_test holds the manual georeferencing of the held-out images, for evaluation.
# NOTE(review): fit_transform on the test set assumes the preprocessor is stateless; confirm.
y_train = preprocessing_pipeline_y.fit_transform(fps_train)
y_test = preprocessing_pipeline_y.fit_transform(fps_test)

### X values

We use the sensor (high-altitude balloon) positions to provide approximate georeferencing, which will be useful for saving computational time when building the unreferenced mosaic.


In [None]:
# This is the pipeline for approximate georeferencing
preprocessing_steps = pipelines.PreprocessingPipelines.nitelite_preprocessing_steps(
    crs=crs,
    use_approximate_georeferencing=settings['use_approximate_georeferencing'],
)
preprocessing_pipeline = Pipeline(preprocessing_steps)
# Display the pipeline as this cell's output.
preprocessing_pipeline

In [None]:
# Fit the pipeline
# sklearn's Pipeline routes each `metadata__<param>` keyword to the fit of the
# step named 'metadata'.
preprocessing_pipeline = preprocessing_pipeline.fit(
    fps_train,
    y_train,
    metadata__img_log_fp=settings['img_log_fp'],
    metadata__imu_log_fp=settings['imu_log_fp'],
    metadata__gps_log_fp=settings['gps_log_fp'],
)

In [None]:
# Get out the X values we'll use for the mosaic
X_train = preprocessing_pipeline.transform(fps_train)
X = preprocessing_pipeline.transform(fps)

In [None]:
# Check the camera numbers
# Each camera must contribute exactly settings['train_size'][camera] training images.
for camera_num in range(3):
    assert (X_train['camera_num'] == camera_num).sum() == settings['train_size'][camera_num], 'Camera numbers are not as expected'

In [None]:
# Check the order
# Three checks:
#   1. The 'order' column must be consecutive (0..len(X)-1).
#   2. Every camera 1 image must come before every camera 0 image.
#   3. Camera 1 images must appear in non-decreasing distance to the center.
np.testing.assert_allclose(X['order'], np.arange(len(X)))
# The failure message describes what went wrong, i.e. the violation of the expected ordering.
assert X.loc[X['camera_num'] == 1, 'order'].max() < X.loc[X['camera_num'] == 0, 'order'].min(), 'Some camera 0 images come before camera 1 images.'
assert (np.diff(X.loc[X['camera_num'] == 1, 'd_to_center']) < 0).sum() == 0, 'Some smaller distances appear out of order.'

### Check consistency


In [None]:
# We don't want to drop more than a few files when we're working with the referenced dataset
if not settings['include_raw_images']:
    assert y_test.index.size - X.index.size < 3, 'Too many files dropped.'

In [None]:
# Number of candidate filepaths before filtering (displayed as cell output).
fps.index.size

In [None]:
# Drop the y values and filepaths that were filtered out
# (X can have fewer rows than fps if preprocessing dropped files).
fps = fps.loc[X.index]
X_test = X.loc[X.index.isin(y_test.index)]
y_test = y_test.loc[X_test.index]

In [None]:
# Check that our test Xs and ys align
n_bad = (y_test['filepath'] != X_test['filepath']).sum()
assert n_bad == 0, f'{n_bad} wrong filepaths'

### Look at Order


In [None]:
# Let's take a look.
# Scatter the approximate image centers, colored by camera number.
if settings['show_images']:
    sp = sns.scatterplot(
        data=X,
        x='x_center',
        y='y_center',
        hue='camera_num',
    )
    sp.set_aspect('equal')

# Pieces of the Mosaic


### Initialization


#### Test

Check that initialization works, first with a mosaic that only uses the training data.


In [None]:
# A mosaic built only from the training data, with extra dataset padding around it.
small_less_reffed_mosaic = mosaic.LessReferencedMosaic(
    dataset_padding=100.,
    **constructor_kwargs
)

In [None]:
# Fit using only the training georeferencing.
small_less_reffed_mosaic.fit(
    X=y_train,
    approx_y=y_train,
)

In [None]:
# The full mosaic image that's saved
# Transpose the bands-first array to bands-last and keep the first three bands.
dataset = small_less_reffed_mosaic.open_dataset()
mosaic_img = dataset.ReadAsArray().transpose(1, 2, 0)
mosaic_image = raster.ReferencedImage(
    mosaic_img[:, :, :3],
    [small_less_reffed_mosaic.x_min_, small_less_reffed_mosaic.x_max_],
    [small_less_reffed_mosaic.y_min_, small_less_reffed_mosaic.y_max_]
)

In [None]:
# The actual image used to make it
original_image = raster.ReferencedImage.open(y_train.iloc[0]['filepath'])

In [None]:
if settings['show_images']:
    # Compare the mosaic to the actual
    mosaic_image.show(crs='cartesian', img='semitransparent_img')

    fig = plt.gcf()
    ax = plt.gca()

    original_image.show(crs='cartesian', img='semitransparent_img', ax=ax)   

In [None]:
# Check the centers
# The mosaic should be centered on the training image.
mosaic_center = np.array(mosaic_image.cart_bounds).mean(axis=1)
original_center = np.array(original_image.cart_bounds).mean(axis=1)
d_between_centers = np.linalg.norm(mosaic_center - original_center)
np.testing.assert_allclose(d_between_centers, 0.)

In [None]:
# Check the widths
# Each mosaic dimension should be the image's plus dataset_padding on both sides.
mosaic_width, mosaic_height = np.diff(mosaic_image.cart_bounds, axis=1).flatten()
original_width, original_height = np.diff(original_image.cart_bounds, axis=1).flatten()
np.testing.assert_allclose(mosaic_width, original_width + 2. * small_less_reffed_mosaic.dataset_padding)
np.testing.assert_allclose(mosaic_height, original_height + 2. * small_less_reffed_mosaic.dataset_padding)

In [None]:
# Delete the temporary initialization
# Flush and release the dataset handle before removing its files.
dataset.FlushCache()
dataset = None
clear_files()

#### Actual full initialization and fit


In [None]:
# The mosaic used for the step-by-step walkthrough below.
less_reffed_mosaic = mosaic.LessReferencedMosaic(**constructor_kwargs)

In [None]:
# This creates the dataset and adds the referenced mosaic.
# approx_y=X supplies the approximate georeferencing of the images to be added.
less_reffed_mosaic.fit(
    X=y_train,
    approx_y=X,
)

#### Validate settings saving


In [None]:
# Open the file
# BaseLoader keeps every value as a string; only key presence is checked below.
with open(less_reffed_mosaic.settings_filepath_, 'r', encoding='UTF-8') as file:
    saved_settings = yaml.load(file, Loader=yaml.BaseLoader)

In [None]:
# Check that all the values exist
# Every __init__ parameter except self must appear in the saved settings.
fullargspec = inspect.getfullargspec(mosaic.LessReferencedMosaic.__init__)
expected_args = fullargspec.args
for key in expected_args:
    if key == 'self':
        continue
    assert key in saved_settings, f'attr {key} not found in settings'

## Convert geotransforms to pixel offsets and counts


In [None]:
# Pixel search regions for every image to mosaic, padded by
# `padding` times each image's spatial error.
(
    X['x_off'], X['y_off'],
    X['x_size'], X['y_size']
) = less_reffed_mosaic.physical_to_pixel(
    X['x_min'], X['x_max'],
    X['y_min'], X['y_max'],
    padding = less_reffed_mosaic.padding * X['spatial_error']
)

In [None]:
# Pixel regions of the training images (no padding argument passed).
(
    y_train['x_off'], y_train['y_off'],
    y_train['x_size'], y_train['y_size']
) = less_reffed_mosaic.physical_to_pixel(
    y_train['x_min'], y_train['x_max'],
    y_train['y_min'], y_train['y_max'],
)

## First Image

We'll test the first loop in greater detail than the others.


In [None]:
# Select the first image to incorporate and open the mosaic dataset.
i = 0
row = X.iloc[i]
dataset = less_reffed_mosaic.open_dataset()

In [None]:
# Current mosaic contents, with bands moved to the last axis.
mosaic_img = dataset.ReadAsArray().transpose(1, 2, 0)

In [None]:
# Use the first configured image joiner for the manual steps below.
image_joiner = less_reffed_mosaic.image_joiner.image_joiners[0]

### Search Region in the Context of the Full Mosaic


In [None]:
# Expected bounds (the padded pixel search region for this image)
x_off = row['x_off']
y_off = row['y_off']
x_size = row['x_size']
y_size = row['y_size']

In [None]:
# The same region without padding, for comparison.
(
x_off_nopad, y_off_nopad,
x_size_nopad, y_size_nopad,
) = less_reffed_mosaic.physical_to_pixel(
    row['x_min'], row['x_max'],
    row['y_min'], row['y_max'],
)

In [None]:
if settings['show_images']:
    fig = plt.figure(figsize=(20,10))
    ax = plt.gca()

    # Current mosaic
    ax.imshow(mosaic_img)

    # The first image location
    rect = patches.Rectangle(
        (x_off, y_off),
        x_size,
        y_size,
        linewidth = 3,
        facecolor = 'none',
        edgecolor = palette[0],
        label='with padding',
    )
    ax.add_patch(rect)

    # The non-padded first image location
    rect = patches.Rectangle(
        (x_off_nopad, y_off_nopad),
        x_size_nopad,
        y_size_nopad,
        linewidth = 3,
        facecolor = 'none',
        edgecolor = palette[1],
        label='no padding',
    )
    ax.add_patch(rect)

    ax.set_aspect('equal')

    ax.legend()

### Search Region Image


In [None]:
# The existing mosaic at this location (crop of the padded search region)
dst_img = less_reffed_mosaic.get_image(dataset, x_off, y_off, x_size, y_size)

This is plotted below with matched features.


In [None]:
# At this time we expect all data added to the mosaic to be within the bounds of the search region, if we're using approximate georeferencing
if settings['use_approximate_georeferencing']:
    assert dst_img.sum() == mosaic_img.sum()

In [None]:
# Pixel region of the training image, reused for before/after comparisons.
row_train = y_train.iloc[0]
# Here's a zoomed in version, so we know what we're looking at
zoom_dst_img = less_reffed_mosaic.get_image(
    dataset,
    row_train['x_off'], row_train['y_off'],
    row_train['x_size'], row_train['y_size']
)
if settings['show_images']:
    plt.imshow(zoom_dst_img)

### Search Region KeyPoints

We get these for later.


In [None]:
# Get the features from the original mosaic
dst_kp, dst_des = image_joiner.detect_and_compute(dst_img)

In [None]:
# Transform the dst keypoints to mosaic frame
# Keypoint coordinates are relative to the crop, so shift them by the crop offset.
dst_pts = cv2.KeyPoint_convert(dst_kp)
dsframe_dst_pts = dst_pts + np.array([x_off, y_off])
dsframe_dst_des = copy.copy(dst_des)

In [None]:
if settings['show_images']:
    # Look at the image and its keypoints
    raster.Image(dst_img).show()

    fig = plt.gcf()
    ax = plt.gca()

    ax.scatter(
        dst_pts[:,0],
        dst_pts[:,1],
        color='none',
        edgecolor='w',
        linewidth=3,
        s=150,
    )

### New Image


In [None]:
# Load the image being added, using the mosaic's dtype.
src_img = utils.load_image(
    row['filepath'],
    dtype=less_reffed_mosaic.dtype,
)

In [None]:
# Features of the new image, in its own pixel frame.
src_kp, src_des = image_joiner.detect_and_compute(src_img)
src_pts = cv2.KeyPoint_convert(src_kp)

In [None]:
if settings['show_images']:
    # Look at the image and its keypoints
    raster.Image(src_img).show()

    fig = plt.gcf()
    ax = plt.gca()

    ax.scatter(
        src_pts[:,0],
        src_pts[:,1],
        color='none',
        edgecolor='w',
        linewidth=3,
        s=150,
    )

### Feature Matching


In [None]:
# Get and validate the transform predicted from feature matching
M = image_joiner.find_homography(src_kp, src_des, dst_kp, dst_des)

In [None]:
# Inspect relationship
# The joiner's log holds the matched point arrays plus a mask marking the kept
# matches; select the kept pairs as (n, 2) arrays.
mask = image_joiner.log['mask'].astype(bool)
valid_src_pts = image_joiner.log['src_pts'][mask].reshape((mask.sum(), 2))
valid_dst_pts = image_joiner.log['dst_pts'][mask].reshape((mask.sum(), 2))

In [None]:
if settings['show_images']:
    # Side-by-side view of the search region (left) and the new image (right),
    # with a line joining each pair of kept (masked) matches.
    subplot_mosaic = [['dst_img', 'src_img']]
    fig = plt.figure(figsize=(20,10))
    ax_dict = fig.subplot_mosaic(subplot_mosaic)

    ax = ax_dict['dst_img']
    ax.imshow(dst_img)

    ax = ax_dict['src_img']
    ax.imshow(src_img)

    # Iterate over the point pairs directly instead of with an index named `i`.
    # Such an index would overwrite the notebook-level image counter `i`, which
    # is later passed to update_log via locals().
    for valid_dst_pt, valid_src_pt in zip(valid_dst_pts, valid_src_pts):

        con = patches.ConnectionPatch(
            xyA=valid_dst_pt,
            xyB=valid_src_pt,
            coordsA='data',
            coordsB='data',
            axesA=ax_dict['dst_img'],
            axesB=ax_dict['src_img'],
            color=palette[1],
            linewidth=3,
        )
        ax.add_artist(con)

In [None]:
# Uncomment to save this image pair as feature-matching test data.
# raster.Image(src_img).save('../test_data/feature_matching/src_0.tiff')
# raster.Image(dst_img).save('../test_data/feature_matching/dst_0.tiff')

In [None]:
# Validate the fitted homography (presumably raises if it is unacceptable; confirm).
image_joiner.validate_homography(M)

### Warp the Source Image


In [None]:
# Warp the image being fit
# The result is in the search-region crop's pixel frame.
warped_img = image_joiner.warp(src_img, dst_img, M)

In [None]:
if settings['show_images']:
    raster.Image(warped_img[:, :, :3]).show(img='semitransparent_img')

In [None]:
# The warped image should have the same dimensions as the dst img
assert warped_img.shape[:2] == dst_img.shape[:2]

### Blend the images


In [None]:
# Combine the warped new image with the existing mosaic crop.
blended_img = image_joiner.blend(
    src_img=warped_img,
    dst_img=dst_img,
)

In [None]:
# Show
if settings['show_images']:
    raster.Image(blended_img[:, :, :3]).show(img='semitransparent_img')

### Save and look at the mosaic


In [None]:
# Write the blended crop back into the dataset at the search-region offset.
less_reffed_mosaic.save_image(dataset, blended_img, x_off, y_off)

In [None]:
# Get the region of just the first image for comparison from before
zoom_dst_img_after = less_reffed_mosaic.get_image(
    dataset,
    row_train['x_off'], row_train['y_off'],
    row_train['x_size'], row_train['y_size'],
)

In [None]:
# More content should have been added
assert zoom_dst_img_after.sum() > zoom_dst_img.sum()

In [None]:
# View
if settings['show_images']:
    subplot_mosaic = [['before', 'after']]
    fig = plt.figure(figsize=(20,10))
    ax_dict = fig.subplot_mosaic(subplot_mosaic)

    ax = ax_dict['before']
    raster.Image(zoom_dst_img[:, :, :3]).show(img='semitransparent_img', ax=ax)

    ax = ax_dict['after']
    raster.Image(zoom_dst_img_after[:, :, :3]).show(img='semitransparent_img', ax=ax)

### Warp the Keypoints


In [None]:
# Transform to local frame and then the full mosaic frame
# M maps source pixels into the search-region crop; adding the crop offset
# gives mosaic-frame coordinates.
src_pts = cv2.KeyPoint_convert(src_kp)
global_src_pts = cv2.perspectiveTransform(src_pts.reshape(-1, 1, 2), M).reshape(-1, 2)
global_src_pts += np.array([x_off, y_off])

In [None]:
# Store the transformed points for the next loop
dsframe_dst_pts = np.append(dsframe_dst_pts, global_src_pts, axis=0)
dsframe_dst_des = np.append(dsframe_dst_des, src_des, axis=0)

In [None]:
if settings['show_images']:
    fig = plt.figure()
    ax = plt.gca()

    sns.scatterplot(
        x=dsframe_dst_pts[:,0],
        y=dsframe_dst_pts[:,1],
        ax = ax,
    )

    rect = patches.Rectangle(
        (x_off, y_off),
        x_size,
        y_size,
        linewidth = 3,
        facecolor = 'none',
        edgecolor = palette[0],
    )
    ax.add_patch(rect)

    ax.set_xlim(0, less_reffed_mosaic.dataset_.RasterXSize)
    ax.set_ylim(less_reffed_mosaic.dataset_.RasterYSize, 0)
    ax.set_aspect('equal')

In [None]:
# Automated check that everything's in bounds
# Every accumulated keypoint must lie inside the (inclusive) search region.
not_in_bounds = ~(
    (x_off <= dsframe_dst_pts[:,0] )
    & (dsframe_dst_pts[:,0] <= x_off + x_size)
    & (y_off <= dsframe_dst_pts[:,1] )
    & (dsframe_dst_pts[:,1] <= y_off + y_size)
)
assert not_in_bounds.sum() == 0

### Check the georeferencing


In [None]:
# Call the fn
# Pixel bounds of the warped image, shifted from crop frame to mosaic frame.
warped_x_off, warped_y_off, warped_x_size, warped_y_size = image_joiner.warp_bounds(src_img, M)
warped_x_off += x_off
warped_y_off += y_off

In [None]:
# Convert to physical
warped_x_min, warped_x_max, warped_y_min, warped_y_max = less_reffed_mosaic.pixel_to_physical(
    warped_x_off, warped_y_off, warped_x_size, warped_y_size)

In [None]:
warped_center = np.array([
    0.5 * (warped_x_min + warped_x_max),
    0.5 * (warped_y_min + warped_y_max),
])

In [None]:
# Compare to recorded
# Only possible when every image has manual georeferencing (no raw images).
if not settings['include_raw_images']:

    # Get the recorded bounds
    recorded_x_min, recorded_x_max, recorded_y_min, recorded_y_max = y_test.loc[row.name, ['x_min', 'x_max', 'y_min', 'y_max']]
    
    # Get the center
    recorded_center = np.array([
        0.5 * (recorded_x_min + recorded_x_max),
        0.5 * (recorded_y_min + recorded_y_max),
    ])
    
    # Check the centers (tolerance in the same physical units as the bounds)
    assert np.linalg.norm(warped_center - recorded_center) < 500.

In [None]:
if settings['show_images']:
    fig = plt.figure(figsize=(20,10))
    ax = plt.gca()

    # The warped image location
    width = warped_x_max - warped_x_min
    height = warped_y_max - warped_y_min
    rect = patches.Rectangle(
        (warped_x_min, warped_y_min),
        width,
        height,
        linewidth = 3,
        facecolor = 'none',
        edgecolor = palette[0],
    )
    ax.add_patch(rect)
    ax.scatter(
        *warped_center,
        s=100,
        color=palette[0],
    )

    # The actual image location
    if not settings['include_raw_images']:
        rect = patches.Rectangle(
            (recorded_x_min, recorded_y_min),
            recorded_x_max - recorded_x_min,
            recorded_y_max - recorded_y_min,
            linewidth = 3,
            facecolor = 'none',
            edgecolor = palette[1],
        )
        ax.add_patch(rect)
        ax.scatter(
            *recorded_center,
            s=100,
            color=palette[1],
        )

    padding_for_this_plot = 0.1 * width
    ax.set_xlim(warped_x_min - padding_for_this_plot, warped_x_max + padding_for_this_plot)
    ax.set_ylim(warped_y_min - padding_for_this_plot, warped_y_max + padding_for_this_plot)

    ax.set_aspect('equal')

### Log Values


In [None]:
# Store for later comparison
# The whole notebook namespace is handed to update_log (presumably it keeps the
# entries named in log_keys; confirm). The joiner's own log is merged on top.
less_reffed_mosaic.update_log(locals())
less_reffed_mosaic.log.update(image_joiner.log)
less_reffed_mosaic.log['return_code'] = 'success'
# After one image, no logged value should be a list yet.
for log_key in less_reffed_mosaic.log_keys:
    assert not isinstance(less_reffed_mosaic.log[log_key], list)

In [None]:
# Snapshot every float-valued attribute whose name ends in an underscore
# (the fitted attributes, e.g. x_min_), to compare against later runs.
fit_values = {}
for attr_name in less_reffed_mosaic.__dir__():
    if attr_name[-1] != '_':
        continue
    attr_value = getattr(less_reffed_mosaic, attr_name)
    if isinstance(attr_value, float):
        fit_values[attr_name] = attr_value

## Next Image


In [None]:
# The second image to incorporate.
i = 1
row1 = X.iloc[i]

### Preview keypoint selection


In [None]:
# Padded pixel search region for the second image.
x_off1 = row1['x_off']
y_off1 = row1['y_off']
x_size1 = row1['x_size']
y_size1 = row1['y_size']

In [None]:
# Flag which accumulated mosaic-frame keypoints fall inside that region
# (presumably a boolean mask, given how it is used below).
in_bounds1 = less_reffed_mosaic.check_bounds(
    dsframe_dst_pts,
    x_off1, y_off1, x_size1, y_size1,
)

In [None]:
# At least one existing keypoint must lie in the second image's search zone.
# The message reports the image being checked (row1), not the first image.
assert in_bounds1.sum() > 0, \
    f'No image data in the search zone for index {row1.name}'

In [None]:
# Keypoints and descriptors available for matching the second image.
dst_pts1 = dsframe_dst_pts[in_bounds1]
dst_des1 = dsframe_dst_des[in_bounds1]

In [None]:
# At this point in the loops, *all* the points should be in bounds, if we're doing approximate georeferencing
if settings['use_approximate_georeferencing']:
    assert (~in_bounds1).sum() == 0

### Call the typical function


In [None]:
# In 'recompute' mode the accumulated keypoints are not passed on
# (presumably features are recomputed from the mosaic instead; confirm).
if less_reffed_mosaic.feature_mode == 'recompute':
    dsframe_dst_pts = None
    dsframe_dst_des = None

In [None]:
# DEBUG
# src_image = raster.Image.open(row1['filepath'])
# image_joiner.validate_brightness(src_image.img_int)
# src_image.show()

In [None]:
# Incorporate the second image with the library's own per-image routine.
return_code, results1 = less_reffed_mosaic.incorporate_image(
    dataset,
    row1,
    dsframe_dst_pts,
    dsframe_dst_des,
)

In [None]:
assert return_code == 'success', 'Image was not successfully combined.'

In [None]:
# Re-crop the training-image region to compare before/after the second image.
zoom_dst_img_after2 = less_reffed_mosaic.get_image(
    dataset,
    row_train['x_off'], row_train['y_off'],
    row_train['x_size'], row_train['y_size'],
)

In [None]:
if settings['show_images']:
    subplot_mosaic = [['before', 'after']]
    fig = plt.figure(figsize=(20,10))
    ax_dict = fig.subplot_mosaic(subplot_mosaic)

    ax = ax_dict['before']
    raster.Image(zoom_dst_img_after[:, :, :3]).show(img='semitransparent_img', ax=ax)

    ax = ax_dict['after']
    raster.Image(zoom_dst_img_after2[:, :, :3]).show(img='semitransparent_img', ax=ax)

In [None]:
# Clear out any existing files.
# Flush and release the dataset handle first, then delete the files. This is
# the same order as the cleanup after the initialization test. Removing a file
# that is still open can fail on some platforms (e.g. Windows), and flushing
# after deletion writes to a file that no longer exists.
dataset.FlushCache()
dataset = None
clear_files()

# Full Process

Now we check that the full process runs on a subset of the images.


In [None]:
# Reload the mosaic module so edits to its source are picked up.
# NOTE: objects created earlier (e.g. less_reffed_mosaic) keep the old class.
import importlib
importlib.reload(mosaic)

## Run


In [None]:
# Start index and number of images for the full run.
i = 0
n_loops = settings['n_loops']

In [None]:
# The actual calls
# Build, fit, and predict on the first n_loops images, starting from the first.
lr_mosaic = mosaic.LessReferencedMosaic(**constructor_kwargs)

lr_mosaic.fit(
    X=y_train,
    approx_y=X,
)

y_pred = lr_mosaic.predict(
    X.iloc[:n_loops],
    i_start = 0,
)

## Check output


In [None]:
# Check for output files
for key, fp in test_fps.items():
    print(f'Checking for {key} ({fp})...')
    assert os.path.isfile(fp), f'File {key} ({fp}) not found.'

In [None]:
# Check for consistency for fit values
# Fitted float attributes must match those of the step-by-step mosaic.
for key, value in fit_values.items():
    np.testing.assert_allclose(value, getattr(lr_mosaic, key))

In [None]:
# Check for consistency for calculated values
# Compare the first two log entries (the two images processed step by step above).
# Fallbacks when the entries can't be compared as one array:
#   - compare element by element;
#   - convert cv2.KeyPoint objects to coordinates before comparing;
#   - use exact equality for non-numeric values.
for log_key in lr_mosaic.log_keys:

    actual = lr_mosaic.log[log_key][:2]
    expected = less_reffed_mosaic.log[log_key][:2]
    
    print(f'Checking {log_key}...')
    
    try:
        np.testing.assert_allclose(actual, expected)
    except ValueError:
        for j, actual_j in enumerate(actual):
            try:
                np.testing.assert_allclose(actual_j, expected[j])
            except TypeError:
                actual_j = cv2.KeyPoint_convert(actual_j)
                expected_j = cv2.KeyPoint_convert(expected[j])
                np.testing.assert_allclose(actual_j, expected_j)   
    except np.exceptions.DTypePromotionError:
        for j, actual_j in enumerate(actual):
            assert actual_j == expected[j]

In [None]:
# Exactly one return code should be logged per processed image.
# The message covers both too many and too few codes.
return_codes = pd.Series(lr_mosaic.log['return_code'])
assert len(return_codes) == n_loops, f'Number of return codes does not match number of loops. {n_loops} loops, {len(return_codes)} return codes'

In [None]:
# Check how many were successful.
# acceptance_fraction is the fraction *required* to claim success, so reaching
# it exactly counts as passing.
n_good = (return_codes == 'success').sum()
n_bad = n_loops - n_good
success_rate = n_good / n_loops
assert success_rate >= settings['acceptance_fraction'], \
    (
        f"{n_bad} failures, success rate of {settings['acceptance_fraction']} not met "
        f"(actual: {success_rate:.3f}). "
        f'Return codes are...\n{return_codes}'
    )
    )

In [None]:
# Compare against the manual georeferencing of any held-out referenced images that were predicted.
if y_pred.index.isin(y_test.index).sum() > 0:

    # Merge into a comparison dataframe.
    # Only the corner columns are used. Selecting them up front means
    # non-numeric columns that y_test carries (e.g. 'filepath') are never
    # subtracted.
    corner_cols = ['x_min', 'y_max']
    y_pred_for_eval = y_pred.reindex(index=y_test.index, columns=corner_cols)

    # Estimate the consistency with the manual geotransforms: the distance
    # between the (x_min, y_max) corners. Images that were not predicted, or
    # failed, give NaN, which the `> 300.` test below does not count as egregious.
    y_err = y_test[corner_cols] - y_pred_for_eval
    err = np.sqrt(y_err['x_min']**2. + y_err['y_max']**2.)

    # Check how bad the errors are
    n_egregious = (err > 300.).sum()
    assert n_egregious == 0, f'Found {n_egregious} egregious errors.'

    if settings['show_images']:
        # Visualize the errors
        fig = plt.figure()
        ax = plt.gca()

        sns.scatterplot(
            x=np.arange(y_err.index.size),
            y=err,
            hue=np.arange(len(y_err)),
            ax=ax,
        )

        ax.set_ylim(0, ax.get_ylim()[1])

# Restarting From Checkpoint


In [None]:
# The actual calls
# A fresh mosaic that accepts the existing output files (file_exists='pass')
# instead of raising on them.
lr_mosaic = mosaic.LessReferencedMosaic(**constructor_kwargs)
# TODO: We want to make it the default for users to use the checkpoints.
# Requiring a specific choice of file_exists hinders that.
lr_mosaic.set_params(file_exists='pass')

lr_mosaic.fit(
    X=y_train,
    approx_y=X,
)



In [None]:
# DEBUG
import re
# Determine what to look for, for checkpoint files.
# Two precautions:
#   - The literal base name and extension are regex-escaped, so e.g. the '.'
#     is not a wildcard.
#   - The whole filename must match, so sidecar files such as
#     '<name>_i000004.tiff.aux.xml' are not taken for checkpoints.
checkpoint_dir, filename = os.path.split(lr_mosaic.filepath_)
base, ext = os.path.splitext(filename)
i_tag = r'_i(\d+)'
checkpoint_pattern = re.escape(base) + i_tag + re.escape(ext)
pattern = re.compile(checkpoint_pattern)

# Look for checkpoint files, keeping the one with the highest iteration number.
i_start = -1
j_filename = None
possible_files = os.listdir(checkpoint_dir)
for j, filename in enumerate(possible_files):
    match = pattern.fullmatch(filename)
    if not match:
        continue

    number = int(match.group(1))
    if number > i_start:
        i_start = number
        j_filename = j

# Fail here with a clear message rather than with a TypeError when indexing by None later.
assert j_filename is not None, \
    f'No checkpoint file matching {checkpoint_pattern!r} found in {checkpoint_dir!r}'

In [None]:
# Copy over dataset
# Restore the latest checkpoint found above as the main mosaic file.
checkpoint_filename = possible_files[j_filename]
checkpoint_filepath = os.path.join(checkpoint_dir, checkpoint_filename)
shutil.copy(checkpoint_filepath, lr_mosaic.filepath_)

In [None]:
# Predict again without an explicit i_start (presumably resuming from the checkpoint; confirm).
y_pred = lr_mosaic.predict(
    X.iloc[:n_loops],
)

## Check output


In [None]:
# Check for output files
for key, fp in test_fps.items():
    print(f'Checking for {key} ({fp})...')
    assert os.path.isfile(fp), f'File {key} ({fp}) not found.'

In [None]:
# Check for consistency for fit values
# Fitted float attributes must match those of the step-by-step mosaic.
for key, value in fit_values.items():
    np.testing.assert_allclose(value, getattr(lr_mosaic, key))

In [None]:
# Check for consistency for calculated values
# Compare the first two log entries (the two images processed step by step above).
# Fallbacks when the entries can't be compared as one array:
#   - compare element by element;
#   - convert cv2.KeyPoint objects to coordinates before comparing;
#   - use exact equality for non-numeric values.
for log_key in lr_mosaic.log_keys:

    actual = lr_mosaic.log[log_key][:2]
    expected = less_reffed_mosaic.log[log_key][:2]
    
    print(f'Checking {log_key}...')
    
    try:
        np.testing.assert_allclose(actual, expected)
    except ValueError:
        for j, actual_j in enumerate(actual):
            try:
                np.testing.assert_allclose(actual_j, expected[j])
            except TypeError:
                actual_j = cv2.KeyPoint_convert(actual_j)
                expected_j = cv2.KeyPoint_convert(expected[j])
                np.testing.assert_allclose(actual_j, expected_j)   
    except np.exceptions.DTypePromotionError:
        for j, actual_j in enumerate(actual):
            assert actual_j == expected[j]

In [None]:
# Exactly one return code should be logged per processed image.
# The message covers both too many and too few codes.
return_codes = pd.Series(lr_mosaic.log['return_code'])
assert len(return_codes) == n_loops, f'Number of return codes does not match number of loops. {n_loops} loops, {len(return_codes)} return codes'

# Cleanup


In [None]:
# Remove the expected output files produced by this test.
clear_files()

# DEBUG


## Check georeferencing (DEBUG)

The checks below are of limited value, because the warped x_off, y_off, x_size, and y_size describe the image's actual footprint only unreliably.
Only some quantities (e.g. the image center) are reasonably trustworthy.


In [None]:
# NOTE(review): stale debug code. `info` is not defined anywhere in this
# notebook, and `get_image` is called elsewhere with `dataset` as its first
# argument. Update both before re-enabling.
# # Get the warped bounds
# (
#     warped_x_min, warped_x_max,
#     warped_y_min, warped_y_max,
# ) = less_reffed_mosaic.pixel_to_physical(
#     info['x_off'], info['y_off'],
#     info['x_size'], info['y_size'],
# )

In [None]:
# # Get the centers
# warped_center = np.array([
#     0.5 * (warped_x_min + warped_x_max),
#     0.5 * (warped_y_min + warped_y_max),
# ])

In [None]:
# if not settings['include_raw_images']:

#     # Get the recorded bounds
#     recorded_x_min, recorded_x_max, recorded_y_min, recorded_y_max = y_test.loc[row.name, ['x_min', 'x_max', 'y_min', 'y_max']]
    
#     recorded_center = np.array([
#         0.5 * (recorded_x_min + recorded_x_max),
#         0.5 * (recorded_y_min + recorded_y_max),
#     ])
    
#     # Check the centers
#     d_between_centers = np.linalg.norm(warped_center - recorded_center)
#     assert d_between_centers < 100.

#     src_image = raster.ReferencedImage.open(row['filepath'])

In [None]:
# blended_img = less_reffed_mosaic.get_image(row['x_off'], row['y_off'], row['x_size'], row['y_size'])
# blended_image = raster.ReferencedImage(blended_img[:, :, :3], [warped_x_min, warped_x_max], [warped_y_min, warped_y_max])

In [None]:
# fig = plt.figure(figsize=(20,10))
# ax = plt.gca()

# blended_image.show(crs='cartesian', img='semitransparent_img', ax=ax)

# if not settings['include_raw_images']:
#     src_image.show(crs='cartesian', img='semitransparent_img', ax=ax)

# # The warped image location
# width = warped_x_max - warped_x_min
# height = warped_y_max - warped_y_min
# rect = patches.Rectangle(
#     (warped_x_min, warped_y_min),
#     width,
#     height,
#     linewidth = 3,
#     facecolor = 'none',
#     edgecolor = palette[0],
# )
# ax.add_patch(rect)
# ax.scatter(
#     *warped_center,
#     s=100,
#     color=palette[0],
#     label='found',
# )

# # The actual image location
# if not settings['include_raw_images']:
#     rect = patches.Rectangle(
#         (recorded_x_min, recorded_y_min),
#         recorded_x_max - recorded_x_min,
#         recorded_y_max - recorded_y_min,
#         linewidth = 3,
#         facecolor = 'none',
#         edgecolor = palette[1],
#     )
#     ax.add_patch(rect)
#     ax.scatter(
#         *recorded_center,
#         s=100,
#         color=palette[1],
#         label='recorded',
#     )

# padding_for_this_plot = 0.1 * width
# ax.set_xlim(warped_x_min - padding_for_this_plot, warped_x_max + padding_for_this_plot)
# ax.set_ylim(warped_y_min - padding_for_this_plot, warped_y_max + padding_for_this_plot)

# ax.set_aspect('equal')
# ax.legend()