In [None]:
import h5py
import dask
import dask.array as da
import numpy as np
import plotly.express as px
from pathlib import Path
from tqdm import tqdm
from PIL import Image
from src.load_scripts import load_h5_dataset
from src.visualization import plot_spectra

# Dataset root: searched recursively for *.h5 files and expected to hold y_true.png.
path = Path('data') / 'Marsikov'

In [None]:
# Disable Pillow's decompression-bomb pixel limit so very large images
# (presumably y_true.png) can be opened.
Image.MAX_IMAGE_PIXELS = None
# When slicing/reshaping would produce oversized chunks, keep them as-is instead
# of splitting them (and do not emit dask's large-chunk warning).
dask.config.set({"array.slicing.split_large_chunks": False})

In [None]:
# RGB colour of each class as painted in the ground-truth label image.
# Insertion order matters: a class's position in this dict is its integer label.
legend = dict(
    albite=(71, 213, 213),
    quartz=(85, 0, 255),
    muscovite=(251, 119, 255),
    spessartine=(190, 95, 41),
    orthoclase=(255, 32, 103),
    biotite=(255, 170, 0),
)

In [None]:
def load_h5_dataset(dataset_path: Path):
    for file_path in dataset_path.glob('**/*.h5'):
        try:
            f = h5py.File(file_path, "r")
            f = f[list(f.keys())[0]]
            f = f[list(f.keys())[0]]
            f = f['libs']
            print('    Loading dimensions...', end='', flush=True)
            dim = [max(f['metadata']['X']) + 1, max(f['metadata']['Y']) + 1]
            print(' Done!', flush=True)

            print('    Loading spectra...', end='', flush=True)
            X = da.from_array(f['data'])
            print(' Done!', flush=True)

            print('    Loading wavelengths...', end='', flush=True)
            wavelengths = da.from_array(f['calibration'])
            print(' Done!', flush=True)

            print('    Reshaping spectra...', end='', flush=True)
            X = da.reshape(X, dim + [-1])
            X[::2, :] = X[::2, ::-1]
            print(' Done!', flush=True)

        except Exception as e:
            print('\n[WARNING] Failed to load file {} with error message: {}. Skipping!'.format(file_path, e), flush=True)
            continue

        print('    Loading true labels...', end='', flush=True)
        img = np.asarray(Image.open(dataset_path / 'y_true.png'))
        flat = img.reshape(-1, img.shape[2])
        y = np.zeros(flat.shape) - 1
        for i, val in tqdm(enumerate(legend.values())):
            y[flat == val] = i
        y = y.max(axis=1).reshape(img.shape[:-1]) / (len(legend.keys()) + 1)
        print(' Done!', flush=True)

        return X, y, wavelengths, dim
    raise RuntimeError('Failed to load! No valid file found!')

In [None]:
# Lazy spectra cube, decoded ground-truth labels, wavelength calibration and grid size.
X, y, wavelengths, dim = load_h5_dataset(dataset_path=path)

In [None]:
# NOTE(review): disabled exploration snippet, kept as a bare string so it never runs.
# If enabled, it lists the 8 most frequent RGB colours of `arr` together with their
# pixel counts -- presumably used to pick the `legend` colours from the ground-truth
# image (TODO confirm). `arr` is not defined anywhere above this cell, so running
# it as code would fail on a fresh kernel.
"""
vals = np.unique(arr.reshape(-1, arr.shape[2]), axis=0, return_counts=True)

unique, counts = vals
pairs = list(zip(unique, counts))
pairs.sort(key=lambda pair: pair[1], reverse=True)
pairs[:8]
"""

In [None]:
# Rebuild the per-pixel class map from the ground-truth image. `arr` was never
# defined on a fresh kernel, so load it explicitly; convert('RGB') also copes with
# palette / RGBA PNGs.
with Image.open(path / 'y_true.png') as label_image:
    arr = np.asarray(label_image.convert('RGB'))
flat = arr.reshape(-1, arr.shape[2])
labels = np.full(flat.shape[0], -1.0)
for i, val in tqdm(enumerate(legend.values()), total=len(legend)):
    # A pixel gets class i only if all three channels match the legend colour.
    labels[(flat == val).all(axis=1)] = i
# Class index scaled by 1 / (n_classes + 1); unmatched pixels end up at -1 / (n_classes + 1).
labels = labels.reshape(arr.shape[:2]) / (len(legend) + 1)

In [None]:
np.shape(labels)

In [None]:
# Show only every `downsample`-th pixel along each axis to keep the figure light.
downsample = 4
px.imshow(labels[::downsample, ::downsample],
          title=f'Ground-truth labels (every {downsample}th pixel)')

In [None]:
# Average over both grid axes (0, 1) of the reshaped cube, leaving one value per
# position on axis 2. The +1 offset is kept from the original analysis
# (presumably so the spectrum can be log-scaled -- TODO confirm).
mean = da.mean(X + 1, axis=(0, 1)).compute()

In [None]:
# plot_spectra (from src.visualization) is given a list -- presumably so several
# spectra can be overlaid; here it holds just the mean spectrum.
plot_spectra(
    [mean],
)

In [None]:
# Total signal per grid pixel: collapse axis 2 of the cube and materialise the result.
intensities = da.sum(X, axis=2).compute()

In [None]:
# Boolean map of pixels whose summed intensity falls below the threshold --
# presumably low-signal shots; TODO confirm what 1e6 represents physically.
low_intensity_threshold = 1e6
px.imshow(intensities < low_intensity_threshold,
          title=f'Pixels with total intensity < {low_intensity_threshold:g}')

In [None]:
np.shape(intensities)