In [None]:
import h5py
import dask
import dask.array as da
import numpy as np
import plotly.express as px
from pathlib import Path
from tqdm import tqdm
from PIL import Image
from src.load_scripts import load_h5_dataset
from src.visualization import plot_spectra

# Dataset root directory. It is passed to load_h5_dataset, whose definition in this
# notebook searches it recursively for *.h5 files and reads 'y_true.png' from it.
path = Path('data/Marsikov')

In [None]:
# Dask: do not split large output chunks when slicing arrays.
dask.config.set({"array.slicing.split_large_chunks": False})
# Pillow: remove the pixel-count limit so very large images can be opened.
# NOTE(review): this also disables Pillow's decompression-bomb check, so only
# open image files from a trusted source.
Image.MAX_IMAGE_PIXELS = None

In [None]:
# Preview the label image, keeping every 15th pixel in both directions so the
# interactive figure stays small.
# NOTE(review): `img` was never defined at notebook scope (it only exists as a
# local inside load_h5_dataset), so this cell failed on a fresh kernel. It is
# assumed here to be the same 'y_true.png' that load_h5_dataset reads — confirm.
img = np.asarray(Image.open(path / 'y_true.png'))
px.imshow(img[::15, ::15, :])

In [None]:
# Mineral name -> RGB colour used for that mineral in the label image.
# Insertion order matters: load_h5_dataset enumerates the values starting at 1,
# so the position of an entry here is its class index in the label map.
legend = dict(
    albite=(71, 213, 213),
    quartz=(85, 0, 255),
    muscovite=(251, 119, 255),
    spessartine=(190, 95, 41),
    orthoclase=(255, 32, 103),
    biotite=(255, 170, 0),
)

In [None]:
def load_h5_dataset(dataset_path: Path):
    """Load the first readable ``.h5`` file under ``dataset_path`` and its label image.

    Note: this definition shadows the ``load_h5_dataset`` imported from
    ``src.load_scripts`` at the top of the notebook.

    The directory is searched recursively for ``*.h5`` files. Inside a file the
    data is reached through the first key of the top level, the first key of
    the next level, and then the ``'libs'`` entry. Files that fail to load are
    reported and skipped; the first file that loads successfully is used.

    The label map is built from ``dataset_path / 'y_true.png'`` using the
    module-level ``legend`` mapping (colour -> 1-based position in ``legend``).

    Args:
        dataset_path: Directory containing the ``.h5`` file(s) and ``y_true.png``.

    Returns:
        A tuple ``(X, y, wavelengths, dim)`` where

        * ``X`` is a dask array wrapping the ``'data'`` dataset, reshaped to
          ``dim + [-1]``; every second row along axis 0 (indices 0, 2, 4, ...)
          is reversed along axis 1 (presumably to undo a serpentine scan
          pattern — confirm).
        * ``y`` is a float numpy array shaped like the label image without its
          channel axis; pixels matching no legend colour stay 0.
        * ``wavelengths`` is a dask array wrapping the ``'calibration'`` dataset.
        * ``dim`` is ``[max(X) + 1, max(Y) + 1]`` taken from the file metadata.

    Raises:
        RuntimeError: If no ``.h5`` file could be loaded.

    The HDF5 file is left open on success on purpose: ``X`` and ``wavelengths``
    are lazy dask arrays backed by the open h5py datasets.
    """
    for file_path in dataset_path.glob('**/*.h5'):
        h5_file = None
        try:
            h5_file = h5py.File(file_path, "r")
            f = h5_file[list(h5_file.keys())[0]]
            f = f[list(f.keys())[0]]
            f = f['libs']
            print('    Loading dimensions...', end='', flush=True)
            # Reduce with numpy instead of the builtin max(), which would
            # iterate the HDF5 data element by element; int() gives plain
            # Python ints so the "+ 1" cannot overflow a narrow integer dtype.
            dim = [
                int(np.asarray(f['metadata']['X']).max()) + 1,
                int(np.asarray(f['metadata']['Y']).max()) + 1,
            ]
            print(' Done!', flush=True)

            print('    Loading spectra...', end='', flush=True)
            X = da.from_array(f['data'])
            print(' Done!', flush=True)

            print('    Loading wavelengths...', end='', flush=True)
            wavelengths = da.from_array(f['calibration'])
            print(' Done!', flush=True)

            print('    Reshaping spectra...', end='', flush=True)
            X = da.reshape(X, dim + [-1])
            # Flip every second row along axis 1.
            X[::2, :] = X[::2, ::-1]
            print(' Done!', flush=True)

        except Exception as e:
            # Nothing from this file is returned, so release its handle
            # instead of leaking it before moving on to the next candidate.
            if h5_file is not None:
                h5_file.close()
            print('\n[WARNING] Failed to load file {} with error message: {}. Skipping!'.format(file_path, e), flush=True)
            continue

        print('    Loading true labels...', end='', flush=True)
        # Force three channels so the comparison against the RGB legend tuples
        # also works for RGBA, palette or greyscale PNGs.
        with Image.open(dataset_path / 'y_true.png') as label_image:
            img = np.asarray(label_image.convert('RGB'))
        flat = img.reshape(-1, img.shape[2])
        y = np.zeros(flat.shape[0])
        for i, val in tqdm(enumerate(legend.values(), start=1), total=len(legend)):
            # A pixel belongs to class i when all of its channels equal the legend colour.
            y[(flat == val).all(axis=1)] = i
        y = y.reshape(img.shape[:-1])
        print(' Done!', flush=True)

        return X, y, wavelengths, dim
    raise RuntimeError('Failed to load! No valid file found!')

In [None]:
# Load the data cube, label map, wavelength axis and map dimensions from `path`.
# NOTE(review): `load_h5_dataset` is both imported from src.load_scripts and
# redefined in this notebook; when cells run top to bottom the notebook's own
# definition is the one called here.
X, y, wavelengths, dim = load_h5_dataset(path)

In [None]:
# Collapse axis 2 (the axis left after the two map dimensions) into one summed
# value per pixel and materialise the lazy dask result in memory.
intensities = da.sum(X, axis=2).compute()

In [None]:
import cv2
import numpy as np

# Build 8-bit versions of both maps.
# Scaling is (value - min) / (max - min) * 255: multiplying by 256 would push
# the maximum to 256, which does not fit into uint8, and dividing by the range
# (rather than the maximum) uses the full 0-255 interval. A constant map has a
# zero range, so it is divided by 1 instead.
y_range = y.max() - y.min()
image1 = np.array((y - y.min()) / (y_range if y_range else 1) * 255, dtype='uint8')  # Higher-resolution image
intensity_range = intensities.max() - intensities.min()
image2 = np.array(
    (intensities - intensities.min()) / (intensity_range if intensity_range else 1) * 255,
    dtype='uint8'
)  # Lower-resolution image

# `X` and `y` are intentionally kept: a later cell reads `y` again, and
# deleting them here would also make this cell fail when it is re-run.

# Detect SIFT keypoints and descriptors in both images
sift = cv2.SIFT_create()
keypoints1, descriptors1 = sift.detectAndCompute(image1, None)
keypoints2, descriptors2 = sift.detectAndCompute(image2, None)
if descriptors1 is None or descriptors2 is None:
    raise RuntimeError('SIFT found no keypoints in at least one of the two images.')

# Match descriptors using BFMatcher (two nearest neighbours per descriptor)
bf = cv2.BFMatcher()
matches = bf.knnMatch(descriptors1, descriptors2, k=2)

# Apply the ratio test: keep a match only when it is clearly better than the
# runner-up. Entries with fewer than two neighbours cannot be tested and are skipped.
good_matches = [
    pair[0]
    for pair in matches
    if len(pair) == 2 and pair[0].distance < 0.75 * pair[1].distance
]
if len(good_matches) < 4:
    raise RuntimeError(
        'Only {} good matches found; at least 4 are needed to estimate a homography.'.format(len(good_matches))
    )

# Find the transformation using the matched keypoints.
# src_pts live in image1 and dst_pts in image2, so M maps image1 -> image2 coordinates.
src_pts = np.float32([keypoints1[m.queryIdx].pt for m in good_matches]).reshape(-1, 1, 2)
dst_pts = np.float32([keypoints2[m.trainIdx].pt for m in good_matches]).reshape(-1, 1, 2)
M, _ = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)
if M is None:
    raise RuntimeError('RANSAC could not estimate a homography from the matched keypoints.')

# Warp the lower-resolution image onto the grid of the higher-resolution one.
# M maps output (image1) coordinates to input (image2) coordinates, which is
# exactly what WARP_INVERSE_MAP expects; without the flag OpenCV would invert M
# first and warp image2 in the wrong direction.
aligned_image = cv2.warpPerspective(
    image2,
    M,
    (image1.shape[1], image1.shape[0]),
    flags=cv2.INTER_LINEAR | cv2.WARP_INVERSE_MAP
)

In [None]:
# Rebuild the 8-bit label and intensity images (same normalisation as in the
# SIFT alignment cell).
# Scaling is (value - min) / (max - min) * 255: a factor of 256 would map the
# maximum to 256, which does not fit into uint8. A constant map has a zero
# range, so it is divided by 1 instead.
# NOTE(review): `y` must still be defined when this cell runs; re-run the
# loading cell first if it has been deleted.
y_range = y.max() - y.min()
image1 = np.array((y - y.min()) / (y_range if y_range else 1) * 255, dtype='uint8')  # Higher-resolution image
intensity_range = intensities.max() - intensities.min()
image2 = np.array(
    (intensities - intensities.min()) / (intensity_range if intensity_range else 1) * 255,
    dtype='uint8'
)  # Lower-resolution image

In [None]:
# Show the 8-bit label image (image1) and intensity image (image2) side by side.
# NOTE(review): `plt` is not imported in any visible cell — presumably
# matplotlib.pyplot; confirm and import it in the first cell.
fig,ax = plt.subplots(ncols=2)
ax[0].imshow(image1)
ax[1].imshow(image2)


In [None]:
# Resample `libs_map` on a grid with 1.5x as many points along every axis.
# NOTE(review): `libs_map`, `map_coordinates` and `quantile_process_map` are not
# defined or imported in the visible cells — `map_coordinates` is presumably
# scipy.ndimage.map_coordinates; confirm where all three come from.

# One vector of (fractional) sample positions per axis, spanning the full
# index range 0 .. length - 1 with int(length * 1.5) points.
new_dims = [
    np.linspace(0, axis_length - 1, int(axis_length * 1.5))
    for axis_length in libs_map.shape
]

# Expand the per-axis positions into full coordinate grids (matrix indexing)
# and sample the map at those positions.
coords = np.meshgrid(*new_dims, indexing='ij')
upscaled_libs_map = map_coordinates(libs_map, coords)

# Post-process the resampled map with the project helper.
upscaled_libs_map = quantile_process_map(
    input_map=upscaled_libs_map,
    rotate=True,
    transpose=False,
)

In [None]:
# ECC alignment of the LIBS image to the template image.
# NOTE(review): `libs_match`, `template` and `plt` are not defined or imported
# in the visible cells — confirm where they come from.

# Output size of the warp, read from libs_match as (rows, cols).
# NOTE(review): the original comment said "size of image1"; the correlation
# checks below require both images to have the same number of pixels anyway.
sz = libs_match.shape

# Work on copies of both inputs.
libs_image = libs_match.copy()
icp_image = template.copy()

# Motion model to estimate.
warp_mode = cv2.MOTION_AFFINE
is_homography = warp_mode == cv2.MOTION_HOMOGRAPHY

# Start from the identity: 3x3 for a homography, 2x3 for every other model.
warp_matrix = np.eye(3 if is_homography else 2, 3, dtype=np.float32)

# Stop after this many iterations ...
number_of_iterations = 5000

# ... or once the correlation coefficient changes by less than this between
# two iterations, whichever happens first.
termination_eps = 1e-10

criteria = (
    cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT,
    number_of_iterations,
    termination_eps,
)

# Run ECC with icp_image as the template and libs_image as the input image.
# It returns the final correlation coefficient and the estimated warp.
cc, warp_matrix = cv2.findTransformECC(icp_image, libs_image, warp_matrix, warp_mode, criteria)

# warpPerspective handles the 3x3 homography; warpAffine handles the 2x3
# matrices of the translation, Euclidean and affine models.
warp_fn = cv2.warpPerspective if is_homography else cv2.warpAffine
libs_image_aligned = warp_fn(
    libs_image,
    warp_matrix,
    (sz[1], sz[0]),
    flags=cv2.INTER_LINEAR + cv2.WARP_INVERSE_MAP,
)

# Show template, original input and aligned input next to each other.
fig, ax = plt.subplots(ncols=3)
for axis, panel in zip(ax, (icp_image, libs_image, libs_image_aligned)):
    axis.imshow(panel)

# Pearson correlation with the template, before and after alignment.
template_flat = icp_image.reshape(-1)
for candidate in (libs_image, libs_image_aligned):
    print(np.corrcoef(template_flat, candidate.reshape(-1))[0, 1])