In [None]:
import h5py
import dask
import dask.array as da
import numpy as np
import plotly.express as px
import cv2
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm import tqdm
from PIL import Image
from src.load_scripts import load_h5_dataset
from src.visualization import plot_spectra

# Dataset directory: it is searched recursively for .h5 files and is also
# expected to contain the label image 'y_true.png'.
path = Path('data') / 'Marsikov'

In [None]:
# Turn off dask's "array.slicing.split_large_chunks" option.
# NOTE(review): presumably this keeps slicing from re-chunking the spectra cube
# (and silences the related warning) -- confirm against the dask configuration docs.
dask.config.set({"array.slicing.split_large_chunks": False})
# Remove Pillow's pixel-count limit.
# NOTE(review): presumably needed so the large 'y_true.png' label image can be opened
# without Pillow's decompression-bomb check rejecting it -- confirm.
Image.MAX_IMAGE_PIXELS = None

In [None]:
# Mineral name -> RGB colour used for that mineral in the label image.
# Insertion order matters: load_h5_dataset assigns class ids 1, 2, 3, ... in
# exactly this order (0 is left for pixels matching none of the colours).
legend = dict(
    albite=(71, 213, 213),
    quartz=(85, 0, 255),
    muscovite=(251, 119, 255),
    spessartine=(190, 95, 41),
    orthoclase=(255, 32, 103),
    biotite=(255, 170, 0),
)

In [None]:
def load_h5_dataset(dataset_path: Path):
    """Load the first readable ``.h5`` file under ``dataset_path`` together with its label image.

    The directory is searched recursively for ``*.h5`` files. For each candidate the
    function descends through the first key of the two outermost HDF5 levels into the
    ``'libs'`` group and reads ``metadata`` (``X``/``Y``), ``data`` and
    ``calibration``. Files that fail at any of these steps are reported and skipped.

    NOTE(review): this definition shadows the ``load_h5_dataset`` imported from
    ``src.load_scripts`` at the top of the notebook -- confirm which one is intended.

    Args:
        dataset_path: Directory containing the ``.h5`` file(s) and ``y_true.png``.

    Returns:
        A tuple ``(X, y, wavelengths, dim)`` where

        * ``X`` is a lazy dask array of ``f['data']`` reshaped to ``dim + [-1]``, with
          every even-indexed slice along the first axis reversed along the second axis
          (presumably undoing a serpentine scan order -- confirm),
        * ``y`` is a float NumPy array with the label image's height/width; each pixel
          holds the 1-based position of its colour in ``legend`` or 0 if no colour matches,
        * ``wavelengths`` is a lazy dask array of ``f['calibration']``,
        * ``dim`` is ``[max(X) + 1, max(Y) + 1]`` as Python ints.

    Raises:
        RuntimeError: If no ``.h5`` file could be loaded.
    """
    for file_path in dataset_path.glob('**/*.h5'):
        h5_file = None
        try:
            h5_file = h5py.File(file_path, "r")
            f = h5_file[list(h5_file.keys())[0]]
            f = f[list(f.keys())[0]]
            f = f['libs']
            print('    Loading dimensions...', end='', flush=True)
            # Read each coordinate column in one go and reduce with NumPy; the
            # built-in max() would walk the HDF5 data element by element.
            dim = [
                int(np.asarray(f['metadata']['X']).max()) + 1,
                int(np.asarray(f['metadata']['Y']).max()) + 1,
            ]
            print(' Done!', flush=True)

            print('    Loading spectra...', end='', flush=True)
            X = da.from_array(f['data'])
            print(' Done!', flush=True)

            print('    Loading wavelengths...', end='', flush=True)
            wavelengths = da.from_array(f['calibration'])
            print(' Done!', flush=True)

            print('    Reshaping spectra...', end='', flush=True)
            X = da.reshape(X, dim + [-1])
            # Reverse the second axis of every even-indexed row of the first axis.
            X[::2, :] = X[::2, ::-1]
            print(' Done!', flush=True)

        except Exception as e:
            # Release the handle of a file we are giving up on. On success the file
            # must stay open because the dask arrays read from it lazily.
            if h5_file is not None:
                h5_file.close()
            print('\n[WARNING] Failed to load file {} with error message: {}. Skipping!'.format(file_path, e), flush=True)
            continue

        print('    Loading true labels...', end='', flush=True)
        # Force three channels so the comparison against the RGB tuples in `legend`
        # also works for RGBA or palette PNGs.
        with Image.open(dataset_path / 'y_true.png') as label_image:
            img = np.asarray(label_image.convert('RGB'))
        flat = img.reshape(-1, img.shape[2])
        y = np.zeros(flat.shape[0])
        # Wrapping the values (not the enumerate) lets tqdm know the total.
        for i, val in enumerate(tqdm(legend.values()), start=1):
            y[(flat == val).all(axis=1)] = i
        y = y.reshape(img.shape[:-1])
        print(' Done!', flush=True)

        return X, y, wavelengths, dim
    raise RuntimeError('Failed to load! No valid file found!')

In [None]:
# Load the lazy spectra cube (X), the integer-coded label map (y), the wavelength
# axis and the scan dimensions from the dataset directory.
X, y, wavelengths, dim = load_h5_dataset(path)

In [None]:
# Sum every spectrum over its last axis to get one total-intensity value per
# pixel; .compute() materialises the lazy dask result as an in-memory array.
intensities = X.sum(axis=2).compute()

In [None]:
from skimage.measure import block_reduce

# Downsample the label map by reducing each 15x15 block to its maximum value.
# NOTE(review): with np.max the highest class id in a block always wins, and the
# factor 15 is hard-coded -- confirm both match the label/LIBS resolution ratio.
down = block_reduce(y, (15, 15), np.max)

In [None]:
def _scale_to_uint8(values):
    """Linearly rescale ``values`` so its minimum maps to 0 and its maximum to 255.

    Scaling by ``255 / (max - min)`` keeps the result inside the uint8 range; the
    previous ``/ max * 256`` form pushed the maximum to 256, which does not fit in
    uint8, and did not use the full range when the minimum was non-zero.

    Args:
        values: Array-like of numbers.

    Returns:
        A ``uint8`` NumPy array of the same shape. A constant input yields all zeros.
    """
    values = np.asarray(values, dtype=np.float64)
    lowest = values.min()
    span = values.max() - lowest
    if span == 0:
        # Constant image: avoid dividing by zero.
        return np.zeros(values.shape, dtype=np.uint8)
    return np.rint((values - lowest) / span * 255).astype(np.uint8)


# 8-bit versions of the downsampled label map and of the total-intensity map.
labels = _scale_to_uint8(down)
libs = _scale_to_uint8(intensities)

In [None]:
# Template-matching scores to compare. The entries stay in 'cv2.<NAME>' form because
# they double as figure titles. TM_SQDIFF / TM_SQDIFF_NORMED can be enabled as well:
# the loop below uses the minimum response for those two and the maximum otherwise.
methods = [
    'cv2.TM_CCOEFF', 'cv2.TM_CCOEFF_NORMED', 'cv2.TM_CCORR',
    'cv2.TM_CCORR_NORMED', #'cv2.TM_SQDIFF', #'cv2.TM_SQDIFF_NORMED'
]

# Search for the LIBS intensity map inside the (downsampled) label image.
image = labels
# NOTE(review): only the first 980 columns of `libs` are used; 980 is hard-coded.
template = libs[:, :980]

# One [row_start, row_stop, col_start, col_stop] entry per method, in `methods` order.
libs_subsets = []
w, h = template.shape[::-1]
for meth in methods:
    img = image.copy()
    # Resolve the constant by attribute lookup on cv2 instead of eval().
    method = getattr(cv2, meth.rpartition('.')[2])
    # Apply template matching
    res = cv2.matchTemplate(img, template, method)
    min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(res)
    # If the method is TM_SQDIFF or TM_SQDIFF_NORMED, take the minimum
    if method in (cv2.TM_SQDIFF, cv2.TM_SQDIFF_NORMED):
        top_left = min_loc
    else:
        top_left = max_loc
    # Computed for every method; previously it was only set in the else-branch, so
    # the SQDIFF methods would have reused a stale value or raised NameError.
    bottom_right = (top_left[0] + w, top_left[1] + h)

    libs_subsets.append(
        [top_left[1], bottom_right[1], top_left[0], bottom_right[0]]
    )

    # Draw the best match onto the copy and show it next to the response map.
    cv2.rectangle(img, top_left, bottom_right, 255, 2)
    plt.subplot(121)
    plt.imshow(res, cmap=None)
    plt.title('Matching Result')
    plt.xticks([])
    plt.yticks([])
    plt.subplot(122)
    plt.imshow(img, cmap=None)
    plt.title('Detected Point')
    plt.xticks([])
    plt.yticks([])
    plt.suptitle(meth)
    plt.show()

In [None]:
# Crop the label image to the window found by the third template-matching
# method (index 2 of `libs_subsets`), and take the matching LIBS region.
row_start, row_stop, col_start, col_stop = libs_subsets[2]
labels_cropped = labels[row_start:row_stop, col_start:col_stop]
libs_cropped = libs[:, :980]

sz = libs_cropped.shape

# Motion model used by ECC
warp_mode = cv2.MOTION_AFFINE
is_homography = warp_mode == cv2.MOTION_HOMOGRAPHY

# Identity start value: 3x3 for a homography, 2x3 for every other motion model
warp_matrix = np.eye(3 if is_homography else 2, 3, dtype=np.float32)

# Termination criteria: iteration cap and minimum increment of the
# correlation coefficient between two iterations
number_of_iterations = 5000
termination_eps = 1e-10
criteria = (
    cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT,
    number_of_iterations,
    termination_eps,
)

# Run ECC; it returns the final correlation coefficient and the estimated warp.
cc, warp_matrix = cv2.findTransformECC(
    libs_cropped, labels_cropped, warp_matrix, warp_mode, criteria
)

# warpPerspective handles a homography, warpAffine the remaining motion models.
# Note that, despite its name, `libs_image_aligned` holds the warped *label* crop.
warp = cv2.warpPerspective if is_homography else cv2.warpAffine
libs_image_aligned = warp(
    labels_cropped,
    warp_matrix,
    (sz[1], sz[0]),
    flags=cv2.INTER_LINEAR + cv2.WARP_INVERSE_MAP,
)

# Show the LIBS crop, the label crop and the warped label crop side by side
fig, ax = plt.subplots(ncols=3)
for axis, picture in zip(ax, (libs_cropped, labels_cropped, libs_image_aligned)):
    axis.imshow(picture)

# Pearson correlation with the LIBS crop: first before, then after the warp
for candidate in (labels_cropped, libs_image_aligned):
    print(np.corrcoef(libs_cropped.ravel(), candidate.ravel())[0, 1])

In [None]:
# Larger view of the LIBS intensity crop used as the alignment reference.
plt.imshow(libs_cropped)

In [None]:
# Larger view of the warped label crop (the ECC result stored in `libs_image_aligned`).
plt.imshow(libs_image_aligned)