# In [15]:  (Jupyter notebook cell prompt, kept as a comment)
import cv2
import numpy as np
from sklearn.decomposition import SparseCoder


def sparse_representation(img, dictionary, n_nonzero_coefs=1):
    """Sparse-code the rows of ``img`` against ``dictionary`` with OMP.

    The image is flattened so that its last axis is the feature axis (for a
    colour image: one row per pixel, one column per channel). Each row is then
    encoded with scikit-learn's ``SparseCoder`` using orthogonal matching
    pursuit.

    Args:
        img: Array whose last axis has the same length as a dictionary atom.
        dictionary: 2-D array of shape ``(n_atoms, n_features)``.
        n_nonzero_coefs: Maximum number of non-zero coefficients OMP may use
            per row. Defaults to 1, the value that used to be hard-coded.

    Returns:
        The sparse codes, one row per flattened image row and one column per
        dictionary atom.

    Raises:
        ValueError: If ``dictionary`` is not 2-D or its atom length does not
            match the last axis of ``img``.
    """
    img = np.asarray(img)
    dictionary = np.asarray(dictionary)

    # Fail early with a readable message instead of an opaque error raised
    # from inside scikit-learn.
    if dictionary.ndim != 2:
        raise ValueError(
            f"dictionary must be 2-D (n_atoms, n_features), got shape {dictionary.shape}"
        )
    if img.ndim < 1 or img.shape[-1] != dictionary.shape[1]:
        raise ValueError(
            f"last axis of img (shape {img.shape}) must match the dictionary "
            f"atom length {dictionary.shape[1]}"
        )

    img_flat = img.reshape(-1, img.shape[-1])
    coder = SparseCoder(
        dictionary=dictionary,
        transform_algorithm="omp",
        transform_n_nonzero_coefs=n_nonzero_coefs,
    )
    return coder.transform(img_flat)


def reconstruct_image(sparse_code, dictionary, shape):
    """Rebuild an image from its sparse codes.

    Multiplies the codes by the dictionary to get the flat reconstruction,
    then reshapes the result to ``shape``.

    Args:
        sparse_code: Sparse coefficients, one row per flattened image row.
        dictionary: Dictionary atoms the codes refer to.
        shape: Target shape of the returned array.

    Returns:
        The reconstruction ``sparse_code . dictionary`` reshaped to ``shape``.
    """
    return np.reshape(np.dot(sparse_code, dictionary), shape)


def fuse_images(img1, img2, dictionary):
    """Fuse two images by averaging their sparse codes.

    Both images are sparse-coded against the same dictionary, the two code
    matrices are averaged element-wise, and the averaged codes are mapped
    back to image space.

    Args:
        img1: First image; its shape is used for the fused result.
        img2: Second image; must have the same shape as ``img1``.
        dictionary: Dictionary shared by encoding and reconstruction.

    Returns:
        The fused image, with the same shape as ``img1``.

    Raises:
        ValueError: If the two images do not have the same shape.
    """
    # Without this check, broadcast-compatible shapes (e.g. a 1x1 image)
    # would be silently broadcast when the codes are averaged.
    if img1.shape != img2.shape:
        raise ValueError(
            f"images must have the same shape, got {img1.shape} and {img2.shape}"
        )

    sparse1 = sparse_representation(img1, dictionary)
    sparse2 = sparse_representation(img2, dictionary)

    fused_sparse = (sparse1 + sparse2) / 2

    return reconstruct_image(fused_sparse, dictionary, img1.shape)


def main():
    """Fuse two hard-coded input images and write the result to disk.

    Loads both images, scales them to [0, 1], builds a random 10-atom
    dictionary over the channel axis, fuses the images via ``fuse_images``,
    shows the result and saves it as ``fused_image_sparse_main.png``.

    Raises:
        FileNotFoundError: If either input image cannot be read.
    """
    paths = ("data/fusion/052 (2).png", "data/fusion/052.png")
    images = []
    for path in paths:
        raw = cv2.imread(path)
        # cv2.imread signals a missing/unreadable file by returning None
        # rather than raising, which would otherwise surface as a TypeError
        # on the division below.
        if raw is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        images.append(raw / 255.0)
    img1, img2 = images

    dictionary = np.random.rand(10, img1.shape[-1])
    # SparseCoder assumes unit-norm atoms; normalise each row (guarding
    # against a vanishing norm) so OMP's atom selection is not biased.
    norms = np.linalg.norm(dictionary, axis=1, keepdims=True)
    dictionary = dictionary / np.maximum(norms, np.finfo(float).eps)

    fused_image = fuse_images(img1, img2, dictionary)
    # The reconstruction is not guaranteed to stay inside [0, 1]; clip so the
    # uint8 conversion below cannot wrap around.
    fused_image = np.clip(fused_image, 0.0, 1.0)

    cv2.imshow("Fused Image", fused_image)
    # NOTE(review): the waitKey/destroyAllWindows calls were already disabled
    # in the original; left disabled here so the script does not block.
    # cv2.waitKey(00)
    # cv2.destroyAllWindows()

    cv2.imwrite("fused_image_sparse_main.png", (fused_image * 255).astype(np.uint8))


# Run the fusion demo only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

# Captured notebook cell output (apparently the tail of a truncated warning):
#   return func(*args, **kwargs)
#   return func(*args, **kwargs)
