# In [None]:
from itertools import islice, product

import numpy as np
import torch
import torch.fft

# generate_mask is a placeholder: for now it returns an all-ones mask; a real mask will be used later.

def generate_mask(size, distance=None, wavelength=None, **kwargs):
    """Return a trainable all-ones mask of shape ``size``.

    Placeholder implementation: ``distance``, ``wavelength`` and any extra
    keyword arguments are accepted for a future physical mask model but are
    currently ignored.

    Args:
        size: shape of the mask (an int or a sequence of ints).

    Returns:
        A float32 tensor filled with ones, created with ``requires_grad=True``
        so it is a leaf tensor that an optimizer could later update.
    """
    # TODO: build the mask from distance / wavelength once the model is decided.
    return torch.ones(size, dtype=torch.float32, requires_grad=True)

def apply_mask(input_image, mask):
    """Filter ``input_image`` in the frequency domain with ``mask``.

    The image is Fourier-transformed over dimensions 0 and 1, multiplied
    element-wise by ``mask``, and transformed back. Only the real part of the
    inverse transform is returned.

    Args:
        input_image: tensor with at least two dimensions.
        mask: tensor broadcastable against the spectrum of ``input_image``.

    Returns:
        Real-valued tensor with the same shape as ``input_image``.
    """
    # NOTE(review): the transform runs over dims 0 and 1; for a [C, H, W]
    # input this would mix channels with rows — confirm the intended layout.
    spatial_dims = (0, 1)
    spectrum = torch.fft.fft2(input_image, dim=spatial_dims)
    filtered = torch.fft.ifft2(spectrum * mask, dim=spatial_dims)
    # The imaginary part is intentionally dropped.
    return filtered.real

def generate_paths(array_size, num_channels):
    """Return the first ``num_channels`` paths through a ``height x width`` array.

    A path is a tuple of length ``height`` whose ``k``-th entry is the column
    index (``0 <= col < width``) chosen in row ``k``. Paths are produced in
    ``itertools.product`` (lexicographic) order, so all returned paths are
    distinct and each channel gets a unique one.

    Args:
        array_size: ``(height, width)`` of the array.
        num_channels: number of paths to return; must be non-negative.

    Returns:
        List of ``num_channels`` distinct path tuples.

    Raises:
        ValueError: if ``num_channels`` is negative or exceeds the number of
            possible paths.
    """
    height, width = array_size

    # A negative count would otherwise be silently treated as a slice offset.
    if num_channels < 0:
        raise ValueError("Invalid input: Number of channels must be non-negative.")

    # Take only the paths we need. Materializing every path would cost
    # width ** height tuples, which explodes even for modest arrays.
    channel_paths = list(islice(product(range(width), repeat=height), num_channels))

    # If the product ran out early, there are fewer paths than channels.
    if len(channel_paths) < num_channels:
        raise ValueError("Invalid input: Number of channels exceeds possible paths.")

    return channel_paths

def generate_metalens_array(size, input_image):
    """Build a ``size[0] x size[1]`` grid of mask-filtered copies of ``input_image``.

    Each grid cell gets its own mask from ``generate_mask`` (currently all
    ones, shaped like ``input_image``). That mask is applied to ``input_image``
    with ``apply_mask``.

    Args:
        size: sequence whose first two entries are the grid's rows and columns.
        input_image: tensor to filter; its shape also sets each mask's shape.

    Returns:
        Nested list ``grid[i][j]`` holding the filtered image for cell (i, j).
    """
    # NOTE(review): MetalensWrapper.apply calls each entry of its array as a
    # function, but this returns tensors — confirm which representation is
    # intended before wiring the two together.
    rows, cols = size[0], size[1]
    return [
        # A fresh mask per cell, so each metalens can later be made unique.
        [apply_mask(input_image, generate_mask(input_image.shape)) for _ in range(cols)]
        for _ in range(rows)
    ]


class MetalensWrapper:
    """Route each channel of a tensor through a chain of metalens callables.

    ``metalens_array`` is a 2D nested list indexed ``[row][col]``. Each entry is
    called on a single channel tensor, and its results are later combined with
    ``torch.stack``, so each must return a tensor.
    """

    def __init__(self, metalens_array):
        # 2D list of metalens functions, looked up as metalens_array[row][col].
        self.metalens_array = metalens_array

    def _run_path(self, channel, path):
        """Apply, row by row, the metalens that ``path`` selects in each row."""
        out = channel
        for row, col in enumerate(path):
            out = self.metalens_array[row][col](out)
        return out

    def apply(self, input_tensor, paths):
        """
        Process the input tensor through the metalens array based on paths.
        :param input_tensor: A tensor of shape [C, H, W] where C is the number of channels.
        :param paths: A sequence indexed by channel; paths[i] lists, for each row, the
            column of the metalens applied to channel i. Needs at least C entries.
        :return: Tensor stacking the processed channels along a new dimension 0.
        """
        results = [
            self._run_path(channel, paths[idx])
            for idx, channel in enumerate(input_tensor)
        ]
        # Stacking requires every processed channel to have the same shape.
        return torch.stack(results, dim=0)

    
# Overall, these loops are only used to generate the paths, the metalens array and the output tensor; they will not be used during training.