# In [None]:
# Standard library imports
import sys
import time
import gc
import random
import os

# Scientific computing and numerical operations
import numpy as np
import cupy as cp
import cupyx.scipy.linalg as cpx_la
from scipy.fft import fftn, ifftn, get_workers
import torch

import pandas as pd
import numba as nb
#from numba import njit

# Parallel processing
import joblib
from joblib import Parallel, delayed
import dask
#import dask.array as da
#from dask.distributed import Client, LocalCluster

# Spatial analysis tools
from scipy.spatial import KDTree
from scipy.ndimage import binary_erosion, binary_dilation

# Visualization and plotting
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.figure import Figure
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.path import Path
import plotly.graph_objects as go

# Interactive widgets and display
from ipywidgets import (
    interact, widgets, interactive, 
    fixed, HBox, VBox, Output, Layout
)
from IPython.display import display, clear_output

# GUI framework
from PyQt6.QtCore import Qt
from PyQt6.QtWidgets import (
    QApplication, QMainWindow, QWidget,
    QPushButton, QSlider, QSpinBox, 
    QComboBox, QVBoxLayout, QHBoxLayout,
    QLabel, QFileDialog
)
from PyQt6.QtGui import QIcon

# System monitoring
import psutil

# Configuration
# NOTE(review): pyplot is already imported above, so this call switches the
# backend after the fact rather than selecting it up front.
matplotlib.use('QtAgg')  # Select the Qt-based interactive matplotlib backend
plt.close('all')  # Close any figures that are still open
# NOTE(review): `import cupy` runs before this assignment. CuPy presumably reads
# CUPY_ACCELERATORS when it is imported, so this may have no effect here -
# confirm, and move it above the cupy import if so.
os.environ['CUPY_ACCELERATORS'] = 'cub,cutensor'  # Request the CUB and cuTENSOR accelerators

def free_gpu_memory(func):
    """Decorator that frees CuPy's default memory pool after ``func`` runs.

    The wrapped function is called with the original arguments and its return
    value is passed through unchanged.  The pool's unused blocks are released
    afterwards, including when ``func`` raises, so a failed call does not
    leave GPU memory cached.

    Args:
        func: Callable to wrap.

    Returns:
        The wrapping callable; it keeps ``func``'s name and docstring.
    """
    import functools  # local import so the block does not depend on file-level imports

    @functools.wraps(func)
    def wrapper_func(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        finally:
            # Public accessor for the same pool the private
            # ``cp._default_memory_pool`` attribute referred to.
            cp.get_default_memory_pool().free_all_blocks()

    return wrapper_func


# Base Shape class
class Shape:
    """Base record shared by every geometric primitive.

    Attributes set here:
        id: Identifier string; ``to_tree_item`` expects the form ``<type>_<number>``.
        n: Refractive index assigned to the shape.
        group: Name of the group the shape belongs to, or ``None``.
        priority: Ordering key used when shapes are sorted.
        value: Placeholder for a computed per-shape value (e.g. alpha).
        center: ``[x, y, z]`` position, ``[0, 0, 0]`` until a subclass sets it.
    """

    def __init__(self, shape_id, n, group=None, priority=0):
        self.id, self.n = shape_id, n
        self.group, self.priority = group, priority
        # Filled in later with the calculated value (e.g. alpha).
        self.value = None
        # Subclasses overwrite this with the real position.
        self.center = [0, 0, 0]

    def to_tree_item(self):
        """Return a display label such as ``"Sphere 3"`` for the id ``sphere_3``."""
        kind = type(self).__name__.capitalize()
        number = self.id.split('_')[1]
        return f"{kind} {number}"


# Sphere shape
class Sphere(Shape):
    """Sphere defined by a center point and a radius."""

    def __init__(self, shape_id, center, radius, n, group=None, priority=0):
        super().__init__(shape_id, n, group=group, priority=priority)
        # center is [x, y, z]
        self.center, self.radius = center, radius

# Cylinder shape
class Cylinder(Shape):
    """Cylinder defined by center, radius, height and the axis it runs along.

    Unlike the other shapes, ``n`` has a default here (1.5) and comes after
    ``axis`` in the signature.
    """

    def __init__(self, shape_id, center, radius, height, axis='z', n=1.5, group=None, priority=0):
        super().__init__(shape_id, n, group=group, priority=priority)
        self.center = center
        self.radius, self.height = radius, height
        # Axis label; code elsewhere in this file looks it up among 'x', 'y', 'z'.
        self.axis = axis

# Ellipsoid shape
class Ellipsoid(Shape):
    """Ellipsoid defined by a center point and three semi-axis lengths."""

    def __init__(self, shape_id, center, semi_axes, n, axis='z', group=None, priority=0):
        super().__init__(shape_id, n, group=group, priority=priority)
        # semi_axes is [semi_axes_x, semi_axes_y, semi_axes_z]
        self.center, self.semi_axes = center, semi_axes
        self.axis = axis

# Rectangle shape
class Rectangle(Shape):
    """Axis-aligned box defined by a center point and full edge lengths."""

    def __init__(self, shape_id, center, dimensions, n, group=None, priority=0):
        super().__init__(shape_id, n, group=group, priority=priority)
        # dimensions is [dimensions_x, dimensions_y, dimensions_z]
        self.center, self.dimensions = center, dimensions

# Prism shape
class Prism(Shape):
    """Prism with a regular polygon base.

    The base is described by the circumscribed ``radius`` and the number of
    ``sides``; ``height`` is the extent along ``axis``.
    """

    def __init__(self, shape_id, center, radius, sides, height, n, axis='z', group=None, priority=0):
        super().__init__(shape_id, n, group=group, priority=priority)
        self.center = center
        self.radius, self.sides, self.height = radius, sides, height
        self.axis = axis

# ShapeManager class
class ShapeManager:
    """Registry that creates, groups, removes and post-processes shapes.

    Attributes:
        shapes: ``{shape_id: shape_object}`` in insertion order.
        shape_counters: ``{shape_type: last number handed out}``.
        available_ids: IDs of shapes that have been removed.
        available_groups: Mixed-use set kept for backward compatibility.
            ``group_shapes`` stores group *names* (str) in it and
            ``remove_group`` stores freed group *numbers* (int).  Only the
            integer entries are ever recycled by ``get_next_group``.
        next_group: Next unused number for auto-generated ``group_<n>`` names.
    """

    def __init__(self):
        self.shapes = {}  # {shape_id: shape_object}
        self.shape_counters = {}  # {shape_type: counter}
        self.available_ids = set()
        self.available_groups = set()
        self.next_group = 1

    # ------------------------------------------------------------------ helpers
    @staticmethod
    def _shape_class_for(shape_type):
        """Return the shape class for ``shape_type`` or raise ``ValueError``."""
        shape_classes = {
            'sphere': Sphere,
            'cylinder': Cylinder,
            'ellipsoid': Ellipsoid,
            'rectangle': Rectangle,
            'prism': Prism
        }
        shape_class = shape_classes.get(shape_type)
        if not shape_class:
            raise ValueError(f"Unknown shape type: {shape_type}")
        return shape_class

    @staticmethod
    def _group_number(group_name):
        """Return N for an auto-style name ``group_N``, otherwise ``None``."""
        prefix, _, suffix = str(group_name).partition('_')
        if prefix == 'group' and suffix.isdigit():
            return int(suffix)
        return None

    # ---------------------------------------------------------------- id/groups
    def get_next_id(self, shape_type):
        """Advance the counter for ``shape_type`` and return ``<type>_<n>``."""
        self.shape_counters[shape_type] = self.shape_counters.get(shape_type, 0) + 1
        return f"{shape_type}_{self.shape_counters[shape_type]}"

    def get_next_group(self):
        """Return an unused ``group_<n>`` name, reusing freed numbers first."""
        # Only integers are freed numbers; string entries are group names
        # registered by group_shapes and must not be consumed here.
        reusable = [g for g in self.available_groups if isinstance(g, int)]
        if reusable:
            group_num = min(reusable)
            self.available_groups.remove(group_num)
        else:
            group_num = self.next_group
            self.next_group += 1
        return f"group_{group_num}"

    # ------------------------------------------------------------------- shapes
    def add_shape(self, shape_type, **params):
        """Create one shape and register it.

        Args:
            shape_type: 'sphere', 'cylinder', 'ellipsoid', 'rectangle' or 'prism'.
            **params: Constructor arguments for the shape class.  An optional
                ``force_id`` selects the shape ID instead of generating one.

        Returns:
            The ID under which the shape was stored.

        Raises:
            ValueError: If ``shape_type`` is unknown.
        """
        shape_class = self._shape_class_for(shape_type)
        # force_id is not a constructor argument, so it must be removed before
        # the remaining params are forwarded; the counter is only advanced
        # when an ID actually has to be generated.
        force_id = params.pop('force_id', None)
        shape_id = force_id if force_id is not None else self.get_next_id(shape_type)
        shape = shape_class(shape_id, **params)
        self.shapes[shape_id] = shape
        return shape_id

    def remove_shape(self, shape_id):
        """Delete a shape and remember its ID in ``available_ids``.

        Raises:
            ValueError: If ``shape_id`` is not registered.
        """
        if shape_id not in self.shapes:
            raise ValueError(f"Shape ID {shape_id} not found")
        del self.shapes[shape_id]
        self.available_ids.add(shape_id)

    def remove_group(self, group_name):
        """Remove every shape in ``group_name`` and free the group name."""
        shapes_to_remove = [shape_id for shape_id, shape in self.shapes.items() if shape.group == group_name]

        for shape_id in shapes_to_remove:
            self.remove_shape(shape_id)

        self.available_groups.discard(group_name)

        # Recycle the number of auto-generated names only.  The counter is not
        # rewound: doing both would hand the same name out twice, and numbers
        # at or above the counter will be reached by the counter anyway.
        group_number = self._group_number(group_name)
        if group_number is not None and group_number < self.next_group:
            self.available_groups.add(group_number)

    def group_shapes(self, shape_ids, group_name):
        """Assign ``group_name`` to the given shapes and register the name.

        Raises:
            ValueError: If the name is already registered or any ID is unknown.
                No shape is modified in either case.
        """
        if group_name in self.available_groups:
            raise ValueError(f"Group {group_name} already exists")

        shape_ids = list(shape_ids)
        # Validate first so a bad ID cannot leave a half-applied grouping.
        for shape_id in shape_ids:
            if shape_id not in self.shapes:
                raise ValueError(f"Shape ID {shape_id} not found")

        for shape_id in shape_ids:
            self.shapes[shape_id].group = group_name

        self.available_groups.add(group_name)

    def ungroup_shapes(self, group_name):
        """Clear the group of every member shape and unregister the name.

        Raises:
            ValueError: If the name was not registered through ``group_shapes``.
        """
        if group_name not in self.available_groups:
            raise ValueError(f"Group {group_name} doesn't exist")

        for shape in self.shapes.values():
            if shape.group == group_name:
                shape.group = None

        self.available_groups.remove(group_name)

    # Other methods like connect_shapes, disconnect_shapes can be added here

    # ----------------------------------------------------------------- builders
    def add_lattice(self, shape_type, shape_params, lattice_type='square',
                    spacing=100, size=(3,3,3),
                    x_offset=0, y_offset=0, z_offset=0, force_group=None, priority=0):
        """Add copies of one shape on a regular lattice, all in one group.

        Args:
            shape_type: Shape keyword accepted by ``add_shape``.
            shape_params: Constructor arguments shared by every copy; its
                ``center`` (default ``[0, 0, 0]``) is the lattice origin.
            lattice_type: 'square'; any other value selects the hexagonal layout.
            spacing: Single number or ``(dx, dy, dz)``.
            size: ``(nx, ny, nz)`` number of lattice sites per axis.
            x_offset, y_offset, z_offset: Extra shift added to every site.
            force_group: Group name to use instead of an auto-generated one.
            priority: Priority given to every copy unless ``shape_params``
                already contains a ``priority`` entry.

        Raises:
            ValueError: For a malformed ``spacing`` or unknown ``shape_type``.
        """
        if isinstance(spacing, (int, float)):
            spacing = (spacing, spacing, spacing)
        elif len(spacing) != 3:
            raise ValueError("Spacing must be single number or (dx,dy,dz) tuple")
        # Resolve the class before a group name is consumed.
        shape_class = self._shape_class_for(shape_type)

        base_pos = shape_params.get('center', [0, 0, 0])
        group_name = force_group if force_group else self.get_next_group()

        coords = self._generate_lattice_coords(
            lattice_type, spacing, size, base_pos, x_offset, y_offset, z_offset
        )

        for coord in coords:
            params = shape_params.copy()
            params['center'] = coord.tolist()
            params['group'] = group_name
            params.setdefault('priority', priority)
            shape_id = self.get_next_id(shape_type)
            self.shapes[shape_id] = shape_class(shape_id, **params)

    def _generate_lattice_coords(self, lattice_type, spacing, size,
                                 base_pos, x_offset, y_offset, z_offset):
        """Return an ``(nx*ny*nz, 3)`` array of lattice site coordinates.

        In the hexagonal layout rows are spaced ``0.866 * dy`` apart in y and
        every odd row is shifted by ``dx / 2`` in x.
        """
        nx, ny, nz = size
        dx, dy, dz = spacing
        base_x, base_y, base_z = base_pos

        x = np.arange(nx) * dx + base_x + x_offset
        z = np.arange(nz) * dz + base_z + z_offset

        if lattice_type.lower() == 'square':
            y = np.arange(ny) * dy + base_y + y_offset
            x_grid, y_grid, z_grid = np.meshgrid(x, y, z, indexing='ij')
        else:  # hexagonal
            y = np.arange(ny) * dy * 0.866 + base_y + y_offset
            x_grid, y_grid, z_grid = np.meshgrid(x, y, z, indexing='ij')
            # Shift every alternate y row.  Not done in place: with integer
            # spacing the grid is an int array and an in-place float add
            # raises a casting error.
            row_shift = (np.arange(ny) % 2) * dx / 2
            x_grid = x_grid + row_shift[np.newaxis, :, np.newaxis]

        coords = np.stack([x_grid, y_grid, z_grid], axis=-1).reshape(-1, 3)
        return coords

    def add_grating(self, base_length, base_width, base_height,
                    grate_length, grate_width, grate_height,
                    period, num_grates, n_material=1.5, force_group=None, priority=0):
        """Add a base slab with ``num_grates`` bars on top, all in one group.

        The slab is centered at ``[0, 0, base_height/2]``.  The bars sit on
        top of it, spaced ``period`` apart along y and centered around y = 0.
        """
        group_name = force_group if force_group else self.get_next_group()

        # Base platform first, then the bars.
        pieces = [([0, 0, base_height/2], [base_length, base_width, base_height])]
        start_y = -period*(num_grates-1)/2
        for i in range(num_grates):
            pieces.append((
                [0, start_y + i*period, base_height + grate_height/2],
                [grate_length, grate_width, grate_height],
            ))

        for center, dimensions in pieces:
            shape_id = self.get_next_id('rectangle')
            self.shapes[shape_id] = Rectangle(
                shape_id=shape_id,
                center=center,
                dimensions=dimensions,
                n=n_material,
                group=group_name,
                priority=priority
            )

    # --------------------------------------------------------------- processing
    def process_shapes(self, k, E_direction, E_polarization, lattice_spacing, adjust_all=True):
        """Compute each shape's alpha and snap shape centers to the lattice.

        Args:
            k: Wave number passed to ``calculate_alpha``.
            E_direction: Incident direction passed to ``calculate_alpha``.
            E_polarization: Polarization direction passed to ``calculate_alpha``.
            lattice_spacing: Lattice spacing.
            adjust_all: If true (or only one shape exists) every center is
                rounded up to a lattice multiple on its own.  Otherwise the
                largest shape is snapped and all other shapes receive the same
                translation.

        Prints a message and returns without changes if no shapes exist.
        """
        if not self.shapes:
            print("No shapes exist to process.")
            return

        # Sort shapes by priority (higher priority comes last)
        sorted_shapes = sorted(self.shapes.values(), key=lambda s: s.priority)

        # One alpha per distinct refractive index
        alpha_map = {}
        for n in {shape.n for shape in sorted_shapes}:
            alpha_value = calculate_alpha(n, lattice_spacing, k, E_direction, E_polarization, method='LDR')
            print(f"n = {n:.3f}: α = {alpha_value:.3e}")
            alpha_map[n] = alpha_value

        for shape in sorted_shapes:
            shape.value = alpha_map[shape.n]

        centers = np.array([shape.center for shape in sorted_shapes])
        if len(sorted_shapes) == 1 or adjust_all:
            # Use ceil instead of round to always go up
            adjusted_centers = lattice_spacing * np.ceil(centers / lattice_spacing)
        else:
            # Find the largest shape; size-0 shapes are never selected.
            max_size = 0
            largest_shape = None
            for shape in sorted_shapes:
                size = getattr(shape, 'radius', 0)
                if hasattr(shape, 'dimensions'):
                    size = max(shape.dimensions)
                elif hasattr(shape, 'semi_axes'):
                    size = max(shape.semi_axes)
                if size > max_size:
                    max_size = size
                    largest_shape = shape
            if largest_shape is not None:
                base_center = np.array(largest_shape.center)
                adjustment = lattice_spacing * np.ceil(base_center / lattice_spacing) - base_center
                adjusted_centers = centers + adjustment
            else:
                adjusted_centers = centers  # No adjustment

        for shape, new_center in zip(sorted_shapes, adjusted_centers):
            shape.center = new_center.tolist()

    def calculate_grid_extents(self, lattice_spacing, max_x=1e9, max_y=1e9, max_z=1e9):
        """Return ``(nx, ny, nz)`` grid cell counts that cover all shapes.

        Extents are measured from the origin to the largest positive corner of
        any shape; the running maximum starts at zero, so negative coordinates
        do not enlarge the grid.  Each axis is capped by ``max_x/y/z``.

        Raises:
            ValueError: If a shape of an unsupported class is present.
        """
        if not self.shapes:
            return 0, 0, 0

        max_coords = np.zeros(3)

        for shape in self.shapes.values():
            center = np.array(shape.center)

            if isinstance(shape, Sphere):
                extent = center + shape.radius
            elif isinstance(shape, Ellipsoid):
                extent = center + np.array(shape.semi_axes)
            elif isinstance(shape, Rectangle):
                extent = center + np.array(shape.dimensions) / 2
            elif isinstance(shape, (Cylinder, Prism)):
                axis_index = {'x': 0, 'y': 1, 'z': 2}[shape.axis]
                extent = center.copy()
                extent[axis_index] += shape.height / 2
                extent[(axis_index + 1) % 3] += shape.radius
                extent[(axis_index + 2) % 3] += shape.radius
            else:
                raise ValueError(f"Unsupported shape type: {type(shape)}")

            max_coords = np.maximum(max_coords, extent)

        limits = [max_x, max_y, max_z]
        max_coords = np.minimum(max_coords, limits)

        nx, ny, nz = (
            min(int(np.ceil(coord / lattice_spacing)), int(np.ceil(limit / lattice_spacing)))
            for coord, limit in zip(max_coords, limits)
        )

        return nx, ny, nz



def calculate_alpha(refractive_index, lattice_spacing, wave_number, incident_direction=None, polarization_direction=None, method='LDR'):
    """Return the dipole polarizability for one lattice cell.

    Args:
        refractive_index: Refractive index (may be complex).
        lattice_spacing: Lattice spacing ``d``.
        wave_number: Wave number ``k``.
        incident_direction: Propagation direction (3-vector); LDR only.
        polarization_direction: Polarization direction (3-vector); LDR only.
        method: 'CM' (Clausius-Mossotti), 'RR' (CM with the radiative
            reaction correction) or 'LDR' (lattice dispersion relation).

    Returns:
        Complex polarizability.

    Raises:
        ValueError: If ``method`` is unknown, or if LDR is requested without
            both directions or with a zero-length direction.
    """
    n = np.complex128(refractive_index)
    d = np.float64(lattice_spacing)
    k = np.float64(wave_number)

    # Base CM polarizability
    eps = n**2
    alpha_cm = (3 * d**3 / (4 * np.pi)) * (eps - 1) / (eps + 2)

    if method == 'CM':
        return alpha_cm

    # Radiative-reaction term, used by both RR and LDR
    rr_term = (2/3) * 1j * (k*d)**3

    if method == 'RR':
        return alpha_cm / (1 + (alpha_cm/d**3) * (-rr_term))

    if method == 'LDR':
        if incident_direction is None or polarization_direction is None:
            raise ValueError("incident_direction and polarization_direction required for LDR method")

        i_dir = np.asarray(incident_direction, dtype=np.float64)
        p_dir = np.asarray(polarization_direction, dtype=np.float64)
        i_norm = np.linalg.norm(i_dir)
        p_norm = np.linalg.norm(p_dir)
        if i_norm == 0 or p_norm == 0:
            raise ValueError("incident_direction and polarization_direction must be non-zero vectors")

        b1 = -1.891531
        b2 = 0.1648469
        b3 = -1.7700004

        # S = sum_j (a_j * e_j)^2 over the components of the two unit
        # vectors.  This is NOT the squared dot product, which is zero for
        # every transverse wave and would silently drop the b3 term.
        S = np.sum(((i_dir / i_norm) * (p_dir / p_norm)) ** 2)
        ldr_term = (b1 + n**2*b2 + n**2*b3*S) * (k*d)**2

        return alpha_cm / (1 + (alpha_cm/d**3) * (ldr_term - rr_term))

    raise ValueError("method must be 'CM', 'RR', or 'LDR'")

@nb.jit(nopython=True, parallel=True, cache=True)
def sphere(grid_points, center_x, center_y, center_z, radius, epsilon):
    """Flag the grid points that lie inside or on a sphere.

    Args:
        grid_points: Array indexed as ``[i, 0..2]`` holding x, y, z per point.
        center_x, center_y, center_z: Sphere center.
        radius: Sphere radius.
        epsilon: Tolerance for points counted as lying on the surface.

    Returns:
        Boolean array with one entry per grid point.
    """
    n_points = len(grid_points)
    result = np.zeros(n_points, dtype=np.bool_)
    # prange is required for parallel=True to split this loop across threads;
    # every iteration writes only its own result[i].
    for i in nb.prange(n_points):
        dx = grid_points[i,0] - center_x
        dy = grid_points[i,1] - center_y
        dz = grid_points[i,2] - center_z
        distance = np.sqrt(dx*dx + dy*dy + dz*dz)
        result[i] = (distance < radius) or (abs(distance - radius) < epsilon)
    return result

@nb.jit(nopython=True, parallel=True, cache=True)
def ellipsoid(grid_points, center_x, center_y, center_z, semi_axes_x, semi_axes_y, semi_axes_z, epsilon):
    """Flag the grid points that lie inside or on an axis-aligned ellipsoid.

    Args:
        grid_points: Array indexed as ``[i, 0..2]`` holding x, y, z per point.
        center_x, center_y, center_z: Ellipsoid center.
        semi_axes_x, semi_axes_y, semi_axes_z: Semi-axis lengths along x, y, z.
        epsilon: Tolerance added to the normalized radius of 1.

    Returns:
        Boolean array with one entry per grid point.
    """
    n_points = len(grid_points)
    result = np.zeros(n_points, dtype=np.bool_)
    # prange is required for parallel=True to split this loop across threads;
    # every iteration writes only its own result[i].
    for i in nb.prange(n_points):
        # Scale each offset by its semi-axis so the surface maps to radius 1.
        dx = (grid_points[i,0] - center_x) / semi_axes_x
        dy = (grid_points[i,1] - center_y) / semi_axes_y
        dz = (grid_points[i,2] - center_z) / semi_axes_z
        distance = np.sqrt(dx*dx + dy*dy + dz*dz)
        result[i] = distance <= 1.0 + epsilon
    return result

@nb.jit(nopython=True, parallel=True, cache=True)
def rectangle(grid_points, center_x, center_y, center_z, dimensions_x, dimensions_y, dimensions_z, epsilon):
    """Flag the grid points that lie inside or on an axis-aligned box.

    Args:
        grid_points: Array indexed as ``[i, 0..2]`` holding x, y, z per point.
        center_x, center_y, center_z: Box center.
        dimensions_x, dimensions_y, dimensions_z: Full edge lengths.
        epsilon: Tolerance for points counted as lying on a face.

    Returns:
        Boolean array with one entry per grid point.
    """
    n_points = len(grid_points)
    result = np.zeros(n_points, dtype=np.bool_)
    half_x, half_y, half_z = dimensions_x / 2, dimensions_y / 2, dimensions_z / 2
    # prange is required for parallel=True to split this loop across threads;
    # every iteration writes only its own result[i].
    for i in nb.prange(n_points):
        dx = abs(grid_points[i,0] - center_x)
        dy = abs(grid_points[i,1] - center_y)
        dz = abs(grid_points[i,2] - center_z)
        inside = (dx <= half_x) and (dy <= half_y) and (dz <= half_z)
        # A point within epsilon of one face counts if it is inside the box
        # in the other two directions.
        on_edge = (abs(dx - half_x) < epsilon and dy <= half_y and dz <= half_z) or \
                 (abs(dy - half_y) < epsilon and dx <= half_x and dz <= half_z) or \
                 (abs(dz - half_z) < epsilon and dx <= half_x and dy <= half_y)
        result[i] = inside or on_edge
    return result

@nb.jit(nopython=True, parallel=True, cache=True)
def cylinder(grid_points, center_x, center_y, center_z, radius, height, axis, epsilon):
    """Flag the grid points that lie inside or on a centered cylinder.

    Args:
        grid_points: Array indexed as ``[i, 0..2]`` holding x, y, z per point.
        center_x, center_y, center_z: Center of the cylinder (mid-height).
        radius: Cylinder radius.
        height: Full length along ``axis``.
        axis: 'z' or 'y'; any other value is treated as the x axis.
        epsilon: Tolerance for points counted as lying on the surface.

    Returns:
        Boolean array with one entry per grid point.
    """
    # Pick the axial coordinate and the two transverse coordinates once.  The
    # radial distance must use only the transverse pair and the height test
    # only the axial one; the earlier per-branch dx/dy/dz naming got this
    # right for axis='z' only.
    if axis == 'z':
        axial, u, v = 2, 0, 1
        c_axial, c_u, c_v = center_z, center_x, center_y
    elif axis == 'y':
        axial, u, v = 1, 0, 2
        c_axial, c_u, c_v = center_y, center_x, center_z
    else:  # x axis
        axial, u, v = 0, 1, 2
        c_axial, c_u, c_v = center_x, center_y, center_z

    half_height = height / 2
    n_points = len(grid_points)
    result = np.zeros(n_points, dtype=np.bool_)
    # prange is required for parallel=True to split this loop across threads;
    # every iteration writes only its own result[i].
    for i in nb.prange(n_points):
        du = grid_points[i, u] - c_u
        dv = grid_points[i, v] - c_v
        axial_dist = abs(grid_points[i, axial] - c_axial)

        radius_dist = np.sqrt(du*du + dv*dv)

        inside_radius = radius_dist <= radius
        on_radius = abs(radius_dist - radius) < epsilon
        within_height = axial_dist <= half_height
        on_cap = abs(axial_dist - half_height) < epsilon

        result[i] = (inside_radius and within_height) or \
                   (on_radius and within_height) or \
                   (inside_radius and on_cap)
    return result

@nb.jit(nopython=True, parallel=True, cache=True)
def prism(grid_points, center_x, center_y, center_z, radius, height, sides, axis, epsilon):
    """Flag the grid points that lie inside or on a regular-polygon prism.

    The base polygon has ``int(sides)`` vertices on a circle of ``radius``
    around the center, in the plane perpendicular to ``axis``; the first
    vertex is at angle 0.  The prism extends ``height / 2`` to either side of
    the center along ``axis``.

    Args:
        grid_points: Array indexed as ``[i, 0..2]`` holding x, y, z per point.
        center_x, center_y, center_z: Prism center.
        radius: Circumscribed radius of the base polygon.
        height: Full length along ``axis``.
        sides: Number of polygon sides.
        axis: 'z' or 'y'; any other value is treated as the x axis.
        epsilon: Tolerance for points counted as lying on an edge or cap.

    Returns:
        Boolean array with one entry per grid point.
    """
    # Resolve the axis once, outside the per-point loop.
    if axis == 'z':
        axial, u, v = 2, 0, 1
        c_axial, c_u, c_v = center_z, center_x, center_y
    elif axis == 'y':
        axial, u, v = 1, 0, 2
        c_axial, c_u, c_v = center_y, center_x, center_z
    else:  # x axis
        axial, u, v = 0, 1, 2
        c_axial, c_u, c_v = center_x, center_y, center_z

    n_vertices = int(sides)
    vertices = np.zeros((n_vertices, 2))
    for s in range(n_vertices):
        angle = 2 * np.pi * s / sides
        vertices[s,0] = c_u + radius * np.cos(angle)
        vertices[s,1] = c_v + radius * np.sin(angle)

    half_height = height / 2
    n_points = len(grid_points)
    result = np.zeros(n_points, dtype=np.bool_)

    # prange is required for parallel=True to split this loop across threads;
    # every iteration writes only its own result[i] and only reads vertices.
    for i in nb.prange(n_points):
        px = grid_points[i, u]
        py = grid_points[i, v]

        height_dist = abs(grid_points[i, axial] - c_axial)
        within_height = height_dist <= half_height
        on_cap = abs(height_dist - half_height) < epsilon

        inside = False
        on_edge = False

        for j in range(n_vertices):
            j2 = (j + 1) % n_vertices
            xi = vertices[j,0]
            yi = vertices[j,1]
            xj = vertices[j2,0]
            yj = vertices[j2,1]

            ex = xj - xi
            ey = yj - yi
            length = np.sqrt(ex*ex + ey*ey)

            if length > epsilon:
                # Distance from the edge line and position along the edge.
                ux = ex / length
                uy = ey / length
                vx = px - xi
                vy = py - yi
                proj = vx*ux + vy*uy
                dist = abs(vx*uy - vy*ux)
                if dist < epsilon and proj >= -epsilon and proj <= length + epsilon:
                    on_edge = True
                    break

            # Ray casting: toggle on every edge crossed by a ray towards +u.
            # The first test is false when yi == yj, so ey is never zero in
            # the division.
            if ((yi > py) != (yj > py)) and \
               (px < ex * (py - yi) / ey + xi):
                inside = not inside

        result[i] = (inside or on_edge) and (within_height or on_cap)

    return result

def create_shape_data(shape):
    """Flatten a shape object into a dict of plain floats plus its axis.

    Attributes the shape does not have are reported as ``0`` (or ``[0, 0, 0]``
    for the vector-valued ones) and ``''`` for ``axis``.  Only ``center`` is
    required.
    """
    center = shape.center
    semi_axes = getattr(shape, 'semi_axes', [0,0,0])
    dimensions = getattr(shape, 'dimensions', [0,0,0])

    shape_data = {
        'center_x': float(center[0]),
        'center_y': float(center[1]),
        'center_z': float(center[2]),
    }
    for scalar_name in ('radius', 'height', 'sides'):
        shape_data[scalar_name] = float(getattr(shape, scalar_name, 0))
    shape_data['semi_axes_x'] = float(semi_axes[0])
    shape_data['semi_axes_y'] = float(semi_axes[1])
    shape_data['semi_axes_z'] = float(semi_axes[2])
    shape_data['dimensions_x'] = float(dimensions[0])
    shape_data['dimensions_y'] = float(dimensions[1])
    shape_data['dimensions_z'] = float(dimensions[2])
    shape_data['axis'] = getattr(shape, 'axis', '')
    return shape_data


# Dispatch table: shape class -> adapter that unpacks the flat dict produced by
# create_shape_data and calls the matching point-in-shape kernel.  Every
# adapter takes the grid positionally and the dict entries as keywords, and
# passes a fixed surface tolerance of 1e-7.
shape_calculators = {
    Sphere: lambda grid, **d: sphere(
        grid,
        d['center_x'], d['center_y'], d['center_z'],
        d['radius'],
        1e-7,
    ),
    Ellipsoid: lambda grid, **d: ellipsoid(
        grid,
        d['center_x'], d['center_y'], d['center_z'],
        d['semi_axes_x'], d['semi_axes_y'], d['semi_axes_z'],
        1e-7,
    ),
    Rectangle: lambda grid, **d: rectangle(
        grid,
        d['center_x'], d['center_y'], d['center_z'],
        d['dimensions_x'], d['dimensions_y'], d['dimensions_z'],
        1e-7,
    ),
    Cylinder: lambda grid, **d: cylinder(
        grid,
        d['center_x'], d['center_y'], d['center_z'],
        d['radius'], d['height'], d['axis'],
        1e-7,
    ),
    Prism: lambda grid, **d: prism(
        grid,
        d['center_x'], d['center_y'], d['center_z'],
        d['radius'], d['height'], d['sides'], d['axis'],
        1e-7,
    ),
}

def mark_shapes(grid_shape, lattice_spacing, shapes, voxel_center=True):
    """Rasterize shapes onto a voxel grid, storing each shape's value.

    Shapes are painted in ascending ``priority`` (ties keep the order in which
    they were given), so a higher-priority shape overwrites the voxels of a
    lower-priority one.  Only a window around each shape is tested.

    Args:
        grid_shape: ``(nx, ny, nz)`` number of voxels per axis.
        lattice_spacing: Voxel edge length.
        shapes: Iterable of shape objects; each ``value`` must already be set.
        voxel_center: If true, voxels are sampled at their centers
            (index + 0.5); otherwise at their lower corner.

    Returns:
        ``complex128`` array of shape ``grid_shape``; voxels not covered by any
        shape are 0.  A shape whose value is 0 leaves existing voxels unchanged.
    """
    nx, ny, nz = grid_shape
    offset = 0.5 if voxel_center else 0
    value_array = np.zeros(grid_shape, dtype=np.complex128)
    epsilon = 1e-7
    buffer = 1
    grid_limits = np.array([nx, ny, nz])

    # sorted() is stable, so shapes with equal priority keep the caller's order.
    ordered_shapes = sorted(shapes, key=lambda s: getattr(s, 'priority', 0))

    for shape in ordered_shapes:
        shape_data = create_shape_data(shape)
        center = np.array([shape_data['center_x'], shape_data['center_y'], shape_data['center_z']], dtype=np.float64)

        # Largest half-extent of the shape in any direction.
        max_dist = max(
            getattr(shape, 'radius', 0),
            getattr(shape, 'height', 0)/2,
            max(getattr(shape, 'semi_axes', [0,0,0])),
            max(getattr(shape, 'dimensions', [0,0,0]))/2
        )

        max_grid_dist = int(float(max_dist + epsilon) / lattice_spacing) + buffer
        center_idx = (center / lattice_spacing).astype(int)

        # Index window around the shape, clipped to the grid.
        lo = np.maximum(0, center_idx - max_grid_dist)
        hi = np.minimum(grid_limits, center_idx + max_grid_dist + 1)
        if np.any(hi <= lo):
            # The shape lies entirely outside the grid.  Without this guard
            # the reshape below would be asked for negative dimensions.
            continue

        axes = [(np.arange(a, b) + offset) * lattice_spacing for a, b in zip(lo, hi)]
        X, Y, Z = np.meshgrid(*axes, indexing='ij')
        local_grid = np.stack([X.ravel(), Y.ravel(), Z.ravel()], axis=1)

        calculator = shape_calculators[type(shape)]
        inside = calculator(local_grid, **shape_data).reshape(X.shape)

        # Basic slicing returns a view, so this writes into value_array.
        window = value_array[lo[0]:hi[0], lo[1]:hi[1], lo[2]:hi[2]]
        if shape.value != 0:
            window[inside] = shape.value

    return value_array

def mark_shapes_gpu(grid_shape, lattice_spacing, shapes, voxel_center=True):
    """GPU (CuPy) variant of ``mark_shapes`` that tests the whole grid at once.

    It reads the same shape attributes as the CPU kernels (``center``,
    ``radius``, ``semi_axes``, ``dimensions``, ``height``, ``sides``, ``axis``)
    and applies the same inside / on-surface rules.  Shapes are painted in
    ascending ``priority``, so higher-priority shapes overwrite lower ones.

    Args:
        grid_shape: ``(nx, ny, nz)`` number of voxels per axis.
        lattice_spacing: Voxel edge length.
        shapes: Iterable of shape objects; each ``value`` must already be set.
        voxel_center: If true, voxels are sampled at index + 0.5.

    Returns:
        NumPy ``complex128`` array of shape ``grid_shape``.

    Raises:
        ValueError: If a shape of an unsupported class is encountered.
    """
    nx, ny, nz = grid_shape
    offset = 0.5 if voxel_center else 0

    x = (cp.arange(nx) + offset) * lattice_spacing
    y = (cp.arange(ny) + offset) * lattice_spacing
    z = (cp.arange(nz) + offset) * lattice_spacing
    X, Y, Z = cp.meshgrid(x, y, z, indexing='ij')
    grid_points = cp.stack([X.ravel(), Y.ravel(), Z.ravel()], axis=1)

    # complex128 to match the CPU implementation.
    value_array = cp.zeros(nx * ny * nz, dtype=cp.complex128)
    epsilon = 1e-7
    axis_lookup = {'x': 0, 'y': 1, 'z': 2}

    for shape in sorted(shapes, key=lambda s: getattr(s, 'priority', 0)):
        center = cp.asarray(shape.center, dtype=cp.float64)

        if isinstance(shape, Sphere):
            displacement = grid_points - center
            distances = cp.sqrt((displacement ** 2).sum(axis=1))
            inside = (distances < shape.radius) | (cp.abs(distances - shape.radius) < epsilon)

        elif isinstance(shape, Ellipsoid):
            semi_axes = cp.asarray(shape.semi_axes, dtype=cp.float64)
            scaled_points = (grid_points - center) / semi_axes
            distances = cp.sqrt((scaled_points ** 2).sum(axis=1))
            inside = distances <= 1.0 + epsilon

        elif isinstance(shape, Rectangle):
            half = cp.asarray(shape.dimensions, dtype=cp.float64) / 2
            offsets = cp.abs(grid_points - center)
            within = offsets <= half
            near_face = cp.abs(offsets - half) < epsilon
            inside = within.all(axis=1)
            for dim in range(3):
                others = [d for d in range(3) if d != dim]
                # On a face: within epsilon of it and inside in the other two directions.
                inside |= near_face[:, dim] & within[:, others].all(axis=1)

        elif isinstance(shape, (Cylinder, Prism)):
            axis_index = axis_lookup[shape.axis]
            u, v = [i for i in range(3) if i != axis_index]
            px = grid_points[:, u]
            py = grid_points[:, v]

            # The shape is centered on shape.center along its axis.
            half_height = shape.height / 2
            axial_dist = cp.abs(grid_points[:, axis_index] - shape.center[axis_index])
            within_height = axial_dist <= half_height
            on_cap = cp.abs(axial_dist - half_height) < epsilon

            if isinstance(shape, Cylinder):
                du = px - shape.center[u]
                dv = py - shape.center[v]
                radial = cp.sqrt(du * du + dv * dv)
                inside_radius = radial <= shape.radius
                on_radius = cp.abs(radial - shape.radius) < epsilon
                inside = ((inside_radius | on_radius) & within_height) | (inside_radius & on_cap)
            else:
                # Regular polygon vertices, computed on the host (only a few values).
                n_vertices = int(shape.sides)
                angles = 2 * np.pi * np.arange(n_vertices) / shape.sides
                verts_x = shape.center[u] + shape.radius * np.cos(angles)
                verts_y = shape.center[v] + shape.radius * np.sin(angles)

                inside_base = cp.zeros(len(px), dtype=bool)
                edge_points = cp.zeros(len(px), dtype=bool)

                for j in range(n_vertices):
                    j2 = (j + 1) % n_vertices
                    xi, yi = float(verts_x[j]), float(verts_y[j])
                    xj, yj = float(verts_x[j2]), float(verts_y[j2])
                    ex = xj - xi
                    ey = yj - yi
                    length = float(np.hypot(ex, ey))

                    if length > epsilon:
                        ux = ex / length
                        uy = ey / length
                        rx = px - xi
                        ry = py - yi
                        proj = rx * ux + ry * uy
                        dist = cp.abs(rx * uy - ry * ux)
                        edge_points |= (dist < epsilon) & (proj >= -epsilon) & (proj <= length + epsilon)

                    # Ray casting.  A horizontal edge can never be crossed, and
                    # skipping it avoids dividing by zero.
                    if yi != yj:
                        crosses = ((yi > py) != (yj > py)) & (px < ex * (py - yi) / ey + xi)
                        inside_base ^= crosses

                inside = (inside_base | edge_points) & (within_height | on_cap)

        else:
            raise ValueError(f"Unsupported shape type: {type(shape)}")

        value_array[inside] = shape.value

    cp.cuda.Stream.null.synchronize()
    return cp.asnumpy(value_array.reshape(grid_shape))


def create_alpha_array(shape_manager, lattice_spacing, voxel_center=True):
    """Build the trimmed voxel array of shape values for all managed shapes.

    The grid size comes from ``shape_manager.calculate_grid_extents``.  The
    shapes are rasterized with ``mark_shapes`` and the all-zero outer faces are
    removed with ``trim_zero_faces``.  The grid shape and the non-zero count of
    every x-slice are printed.

    Returns:
        ``(value_array, value_array.shape)``.
    """
    nx, ny, nz = shape_manager.calculate_grid_extents(lattice_spacing)
    grid_shape = (nx, ny, nz)
    print('grid shape', grid_shape, sep='\n')

    all_shapes = shape_manager.shapes.values()
    value_array = trim_zero_faces(
        mark_shapes(grid_shape, lattice_spacing, all_shapes, voxel_center=voxel_center)
    )
    #value_array = optimize_array_size(value_array)

    # Report the number of non-zero elements in each x-slice
    for index, x_slice in enumerate(value_array):
        print(f"Slice {index}: {np.count_nonzero(x_slice)} non-zero elements")

    return value_array, value_array.shape

def is_good_fft_size(n):
    """Return True if ``n`` is positive and has no prime factor above 19.

    Every factor of 2, 3, 5, 7, 11, 13, 17 and 19 is divided out; the size is
    good when nothing but 1 remains.
    """
    if n <= 0:
        return False

    remaining = n
    for prime in (2, 3, 5, 7, 11, 13, 17, 19):
        while not remaining % prime:
            remaining = remaining // prime
    return remaining == 1

def find_next_even_good_size(n):
    """
    Return the smallest size that is
    - no smaller than n + 2,
    - even, and
    - accepted by is_good_fft_size (prime factors limited to
      [2,3,5,7,11,13,17,19]).
    """
    # Start two above the input, bumped to the next even number if odd.
    candidate = n + 2
    if candidate % 2 == 1:
        candidate += 1

    # Walk the even numbers until one qualifies.
    while True:
        if is_good_fft_size(candidate):
            return candidate
        candidate += 2

def optimize_array_size(array):
    """
    Zero-pad a 3D NumPy array up to FFT-friendly extents.

    Each axis is grown to the value returned by ``find_next_even_good_size``
    (even, and at least 2 larger than the current extent). The extra cells are
    split across both ends of the axis, with the odd cell (if any) going to
    the trailing end, so no data is cut off.
    """
    nx, ny, nz = array.shape

    pad_width = []
    for extent in (nx, ny, nz):
        surplus = find_next_even_good_size(extent) - extent
        leading = surplus // 2
        pad_width.append((leading, surplus - leading))

    return np.pad(
        array,
        pad_width=tuple(pad_width),
        mode='constant',
        constant_values=0,
    )

def trim_zero_faces(array):
    """
    Strip all-zero boundary planes from a 3D array.

    Planes are peeled off the low and high end of each axis for as long as
    every entry in the plane equals zero. The result is a slice of ``array``;
    if nothing survives along the first axis, an empty 1-D ``np.array([])``
    is returned instead.
    """
    # NOTE: an identical definition of this function appears again later in
    # this file; that later definition is the one left bound after loading.

    def _plane_is_zero(axis, index):
        # Select the single plane `index` along `axis`, keeping the other two axes whole.
        selector = [slice(None)] * 3
        selector[axis] = index
        return np.all(array[tuple(selector)] == 0)

    limits = []
    for axis in range(3):
        low, high = 0, array.shape[axis] - 1
        while low <= high and _plane_is_zero(axis, low):
            low += 1
        while high >= low and _plane_is_zero(axis, high):
            high -= 1
        limits.append((low, high))

    first_low, first_high = limits[0]
    if first_low > first_high:
        return np.array([])
    return array[tuple(slice(low, high + 1) for low, high in limits)]




def trim_zero_faces(array):
    """
    Crop a 3D array to the bounding box of its non-zero entries.

    Leading and trailing planes that contain only zeros are dropped along each
    of the three axes. The result is a slice of ``array``; when the array has
    no non-zero entry at all, an empty 1-D ``np.array([])`` is returned.
    """
    # NOTE: this repeats an identical definition found earlier in this file.
    occupied = array != 0
    if not occupied.any():
        return np.array([])

    crop = []
    for axis in range(3):
        # Collapse every other axis to find which planes hold a non-zero entry.
        other_axes = tuple(a for a in range(occupied.ndim) if a != axis)
        hits = np.flatnonzero(occupied.any(axis=other_axes))
        crop.append(slice(hits[0], hits[-1] + 1))

    return array[tuple(crop)]

def is_good_fft_size(n):
    """Tell whether ``n`` is a positive size built ONLY from the primes 2,3,5,7,11,13,17,19."""
    # NOTE: this repeats an identical definition found earlier in this file.
    if n > 0:
        for prime in (2, 3, 5, 7, 11, 13, 17, 19):
            # Divide out this prime until it no longer divides n.
            while not n % prime:
                n //= prime
        return n == 1
    return False





# Single precision kernel
#
# Writes one tensor component of a pairwise interaction term onto a doubled
# (2*nx, 2*ny, 2*nz) grid. The flat index i is decoded into (ix, iy, iz) and
# scaled by lattice_spacing; an index >= n along an axis has
# 2*n*lattice_spacing subtracted, so the upper half of each axis holds
# negative displacements.
#
# For a displacement of length r with unit vector d, the stored value is
#     exp(i*k*r)/r * ( k^2 * (d_a*d_b - delta_ab)
#                      + (3*d_a*d_b - delta_ab) * (i*k*r - 1) / r^2 )
# where `component` selects (a, b): 0=xx, 1=xy, 2=xz, 3=yy, 4=yz, 5=zz.
# NOTE(review): the kernel name suggests this is the dipole Green's tensor -
# confirm the sign/prefactor convention against the solver that consumes it.
#
# Special cases: the origin (ix == iy == iz == 0) is written as 0 before 1/r
# is evaluated, and every entry with ix == nx, iy == ny or iz == nz is
# multiplied by 0.
#
# `interaction_matrix` is a `raw` argument, so it is indexed manually with i.
greens_kernel_single = cp.ElementwiseKernel(
    'int32 nx, int32 ny, int32 nz, float32 k, float32 lattice_spacing, int32 component',
    'raw complex64 interaction_matrix',
    '''
    int ix = i / (2 * ny * 2 * nz);
    int iy = (i / (2 * nz)) % (2 * ny);
    int iz = i % (2 * nz);
    float positions[3] = {
        ix * lattice_spacing,
        iy * lattice_spacing,
        iz * lattice_spacing
    };
    if (ix >= nx) positions[0] -= 2 * nx * lattice_spacing;
    if (iy >= ny) positions[1] -= 2 * ny * lattice_spacing;
    if (iz >= nz) positions[2] -= 2 * nz * lattice_spacing;
    
    if (ix == 0 && iy == 0 && iz == 0) {
        interaction_matrix[i] = complex<float>(0, 0);
        return;
    }
    
    const float r_squared = positions[0] * positions[0] +
                          positions[1] * positions[1] +
                          positions[2] * positions[2];
    const float r = sqrt(r_squared);
    const float inv_r = 1.0f / r;
    const float directions[3] = {
        positions[0] * inv_r,
        positions[1] * inv_r,
        positions[2] * inv_r
    };
    
    float normalized_component;
    if (component == 0) normalized_component = directions[0] * directions[0];      // xx
    else if (component == 1) normalized_component = directions[0] * directions[1]; // xy
    else if (component == 2) normalized_component = directions[0] * directions[2]; // xz
    else if (component == 3) normalized_component = directions[1] * directions[1]; // yy
    else if (component == 4) normalized_component = directions[1] * directions[2]; // yz
    else normalized_component = directions[2] * directions[2];                     // zz
    
    const complex<float> exp_term = exp(complex<float>(0, k * r)) * inv_r;
    const complex<float> term2_factor = (complex<float>(0, k * r) - complex<float>(1, 0)) / r_squared;
    const bool is_diagonal = (component == 0 || component == 3 || component == 5);
    float multiplier = 1.0f;
    if (ix == nx || iy == ny || iz == nz) multiplier = 0.0f;
    
    const float term1 = k * k * (normalized_component - (is_diagonal ? 1.0f : 0.0f));
    const complex<float> term2 = (3.0f * normalized_component - (is_diagonal ? 1.0f : 0.0f)) * term2_factor;
    interaction_matrix[i] = multiplier * exp_term * (term1 + term2);
    ''',
    'dipole_greens_single'
)

# Double precision kernel
#
# Same computation as the single-precision Green's kernel in this file, using
# float64 inputs and a complex128 output. The flat index i is decoded into
# (ix, iy, iz) on a doubled (2*nx, 2*ny, 2*nz) grid; indices >= n along an
# axis are shifted by -2*n*lattice_spacing to give negative displacements.
#
# For a displacement of length r with unit vector d, the stored value is
#     exp(i*k*r)/r * ( k^2 * (d_a*d_b - delta_ab)
#                      + (3*d_a*d_b - delta_ab) * (i*k*r - 1) / r^2 )
# where `component` selects (a, b): 0=xx, 1=xy, 2=xz, 3=yy, 4=yz, 5=zz.
#
# The origin is written as 0 before 1/r is evaluated, and entries with
# ix == nx, iy == ny or iz == nz are multiplied by 0.
greens_kernel_double = cp.ElementwiseKernel(
    'int32 nx, int32 ny, int32 nz, float64 k, float64 lattice_spacing, int32 component',
    'raw complex128 interaction_matrix',
    '''
    int ix = i / (2 * ny * 2 * nz);
    int iy = (i / (2 * nz)) % (2 * ny);
    int iz = i % (2 * nz);
    double positions[3] = {
        ix * lattice_spacing,
        iy * lattice_spacing,
        iz * lattice_spacing
    };
    if (ix >= nx) positions[0] -= 2 * nx * lattice_spacing;
    if (iy >= ny) positions[1] -= 2 * ny * lattice_spacing;
    if (iz >= nz) positions[2] -= 2 * nz * lattice_spacing;
    
    if (ix == 0 && iy == 0 && iz == 0) {
        interaction_matrix[i] = complex<double>(0, 0);
        return;
    }
    
    const double r_squared = positions[0] * positions[0] +
                           positions[1] * positions[1] +
                           positions[2] * positions[2];
    const double r = sqrt(r_squared);
    const double inv_r = 1.0 / r;
    const double directions[3] = {
        positions[0] * inv_r,
        positions[1] * inv_r,
        positions[2] * inv_r
    };
    
    double normalized_component;
    if (component == 0) normalized_component = directions[0] * directions[0];      // xx
    else if (component == 1) normalized_component = directions[0] * directions[1]; // xy
    else if (component == 2) normalized_component = directions[0] * directions[2]; // xz
    else if (component == 3) normalized_component = directions[1] * directions[1]; // yy
    else if (component == 4) normalized_component = directions[1] * directions[2]; // yz
    else normalized_component = directions[2] * directions[2];                     // zz
    
    const complex<double> exp_term = exp(complex<double>(0, k * r)) * inv_r;
    const complex<double> term2_factor = (complex<double>(0, k * r) - complex<double>(1, 0)) / r_squared;
    const bool is_diagonal = (component == 0 || component == 3 || component == 5);
    double multiplier = 1.0;
    if (ix == nx || iy == ny || iz == nz) multiplier = 0.0;
    
    const double term1 = k * k * (normalized_component - (is_diagonal ? 1.0 : 0.0));
    const complex<double> term2 = (3.0 * normalized_component - (is_diagonal ? 1.0 : 0.0)) * term2_factor;
    interaction_matrix[i] = multiplier * exp_term * (term1 + term2);
    ''',
    'dipole_greens_double'
)

@free_gpu_memory
def generate_interaction_row(nx, ny, nz, k, lattice_spacing, reduced=True, double_precision=False):
    """
    Build the six interaction components on a doubled grid and FFT them.

    Each component (xx, xy, xz, yy, yz, zz) is evaluated by the Green's kernel
    on a (2*nx, 2*ny, 2*nz) grid, negated, and transformed with a 3D FFT.

    Args:
        nx, ny, nz: Grid extents before doubling.
        k: Wavenumber passed to the kernel.
        lattice_spacing: Grid spacing passed to the kernel.
        reduced: When true, only the leading (nx+1, ny+1, nz+1) corner of each
            spectrum is kept; otherwise the full doubled grid is returned.
        double_precision: Selects the complex128/float64 kernel instead of the
            complex64/float32 one.

    Returns:
        A CuPy array whose last axis has length 6 (one entry per component).
    """
    # Pick the kernel and scalar type that match the requested precision.
    if double_precision:
        dtype, kernel, real_type = cp.complex128, greens_kernel_double, float
    else:
        dtype, kernel, real_type = cp.complex64, greens_kernel_single, np.float32
    k = real_type(k)
    lattice_spacing = real_type(lattice_spacing)

    # Real-space values on the doubled grid, one trailing slot per component.
    doubled = (2 * nx, 2 * ny, 2 * nz)
    result = cp.zeros(doubled + (6,), dtype=dtype)
    n_points = 8 * nx * ny * nz
    for component in range(6):
        kernel(nx, ny, nz, k, lattice_spacing, component,
               result[..., component], size=n_points)

    # Spatial extent of the spectrum that is kept.
    kept = (nx + 1, ny + 1, nz + 1) if reduced else doubled
    corner = tuple(slice(0, extent) for extent in kept)

    interaction_matrix_fft = cp.zeros(kept + (6,), dtype=dtype)
    for component in range(6):
        interaction_matrix_fft[..., component] = cp.fft.fftn(
            -result[..., component], axes=(0, 1, 2)
        )[corner]

    cp.get_default_memory_pool().free_all_blocks()
    return interaction_matrix_fft




# Single precision preconditioner kernel.
#
# Evaluates one tensor component of the same interaction expression as the
# Green's kernels in this file, but on a plain (nx, ny, nz) grid: the flat
# index i is decoded into (ix, iy, iz), scaled by lattice_spacing, and used
# directly (no wrap-around to negative displacements and no zeroed planes).
#
# For a displacement of length r with unit vector d, the stored value is
#     exp(i*k*r)/r * ( k^2 * (d_a*d_b - delta_ab)
#                      + (3*d_a*d_b - delta_ab) * (i*k*r - 1) / r^2 )
# where `component` selects (a, b): 0=xx, 1=xy, 2=xz, 3=yy, 4=yz, 5=zz.
# The origin (ix == iy == iz == 0) is written as 0 before 1/r is evaluated.
#
# `precon` is a `raw` argument, so it is indexed manually with i.
precon_kernel_single = cp.ElementwiseKernel(
    'int32 nx, int32 ny, int32 nz, float32 k, float32 lattice_spacing, int32 component',
    'raw complex64 precon',
    '''
    int ix = i / (ny * nz);
    int iy = (i / nz) % ny;
    int iz = i % nz;

    float positions[3] = {
        ix * lattice_spacing,
        iy * lattice_spacing,
        iz * lattice_spacing
    };

    if (ix == 0 && iy == 0 && iz == 0) {
        precon[i] = complex<float>(0, 0);
        return;
    }

    const float r_squared = positions[0] * positions[0] + 
                          positions[1] * positions[1] + 
                          positions[2] * positions[2];
    const float r = sqrtf(r_squared);
    const float inv_r = 1.0f / r;

    const float directions[3] = {
        positions[0] * inv_r,
        positions[1] * inv_r,
        positions[2] * inv_r
    };

    float normalized_component;
    if (component == 0) normalized_component = directions[0] * directions[0];      // xx
    else if (component == 1) normalized_component = directions[0] * directions[1]; // xy
    else if (component == 2) normalized_component = directions[0] * directions[2]; // xz
    else if (component == 3) normalized_component = directions[1] * directions[1]; // yy
    else if (component == 4) normalized_component = directions[1] * directions[2]; // yz
    else normalized_component = directions[2] * directions[2];                     // zz

    const complex<float> exp_term = exp(complex<float>(0, k * r)) * inv_r;
    const complex<float> term2_factor = (complex<float>(0, k * r) - complex<float>(1, 0)) / r_squared;
    const bool is_diagonal = (component == 0 || component == 3 || component == 5);
    
    const float term1 = k * k * (normalized_component - (is_diagonal ? 1.0f : 0.0f));
    const complex<float> term2 = (3.0f * normalized_component - (is_diagonal ? 1.0f : 0.0f)) * term2_factor;
    
    precon[i] = exp_term * (term1 + term2);
    ''',
    'precon_gen_single'
)

# Double precision kernel
#
# float64/complex128 version of the preconditioner kernel. The flat index i is
# decoded into (ix, iy, iz) on an (nx, ny, nz) grid and scaled by
# lattice_spacing; displacements are used as-is (no wrap-around, no zeroed
# planes).
#
# For a displacement of length r with unit vector d, the stored value is
#     exp(i*k*r)/r * ( k^2 * (d_a*d_b - delta_ab)
#                      + (3*d_a*d_b - delta_ab) * (i*k*r - 1) / r^2 )
# where `component` selects (a, b): 0=xx, 1=xy, 2=xz, 3=yy, 4=yz, 5=zz.
# The origin (ix == iy == iz == 0) is written as 0 before 1/r is evaluated.
precon_kernel_double = cp.ElementwiseKernel(
    'int32 nx, int32 ny, int32 nz, float64 k, float64 lattice_spacing, int32 component',
    'raw complex128 precon',
    '''
    int ix = i / (ny * nz);
    int iy = (i / nz) % ny;
    int iz = i % nz;

    double positions[3] = {
        ix * lattice_spacing,
        iy * lattice_spacing,
        iz * lattice_spacing
    };

    if (ix == 0 && iy == 0 && iz == 0) {
        precon[i] = complex<double>(0, 0);
        return;
    }

    const double r_squared = positions[0] * positions[0] + 
                           positions[1] * positions[1] + 
                           positions[2] * positions[2];
    const double r = sqrt(r_squared);
    const double inv_r = 1.0 / r;

    const double directions[3] = {
        positions[0] * inv_r,
        positions[1] * inv_r,
        positions[2] * inv_r
    };

    double normalized_component;
    if (component == 0) normalized_component = directions[0] * directions[0];      // xx
    else if (component == 1) normalized_component = directions[0] * directions[1]; // xy
    else if (component == 2) normalized_component = directions[0] * directions[2]; // xz
    else if (component == 3) normalized_component = directions[1] * directions[1]; // yy
    else if (component == 4) normalized_component = directions[1] * directions[2]; // yz
    else normalized_component = directions[2] * directions[2];                     // zz

    const complex<double> exp_term = exp(complex<double>(0, k * r)) * inv_r;
    const complex<double> term2_factor = (complex<double>(0, k * r) - complex<double>(1, 0)) / r_squared;
    const bool is_diagonal = (component == 0 || component == 3 || component == 5);
    
    const double term1 = k * k * (normalized_component - (is_diagonal ? 1.0 : 0.0));
    const complex<double> term2 = (3.0 * normalized_component - (is_diagonal ? 1.0 : 0.0)) * term2_factor;
    
    precon[i] = exp_term * (term1 + term2);
    ''',
    'precon_gen_double'
)

@free_gpu_memory
def pre_preconditioner(nx, ny, nz, k, lattice_spacing, x_expansion, y_expansion, z_expansion, double_precision=False):
    """
    Evaluate the six preconditioner components on the expanded grid.

    Args:
        nx, ny, nz: Original grid extents; only used in the diagnostic print.
        k: Wavenumber passed to the kernel.
        lattice_spacing: Grid spacing passed to the kernel.
        x_expansion, y_expansion, z_expansion: Extents of the grid that is
            actually allocated and filled.
        double_precision: Selects the complex128/float64 kernel instead of the
            complex64/float32 one.

    Returns:
        A CuPy array of shape (x_expansion, y_expansion, z_expansion, 6).
    """
    print(f"Original shape: ({nx}, {ny}, {nz})")
    print(f"Padded shape: ({x_expansion}, {y_expansion}, {z_expansion})")

    # Pick the kernel and scalar type that match the requested precision.
    if double_precision:
        dtype, kernel, real_type = cp.complex128, precon_kernel_double, float
    else:
        dtype, kernel, real_type = cp.complex64, precon_kernel_single, np.float32
    k = real_type(k)
    lattice_spacing = real_type(lattice_spacing)

    precon = cp.zeros((x_expansion, y_expansion, z_expansion, 6), dtype=dtype)
    n_sites = x_expansion * y_expansion * z_expansion

    for component in range(6):
        # Hand the kernel a strided 1-D view holding just this component.
        component_view = precon.reshape(-1, 6)[:, component]
        kernel(x_expansion, y_expansion, z_expansion, k, lattice_spacing, component,
               component_view,
               size=n_sites)
        cp.get_default_memory_pool().free_all_blocks()

    return precon



# In-place inversion of packed symmetric 3x3 complex matrices (single precision).
#
# Thread idx (for idx < n) reads six consecutive complex<float> values
# a, b, c, d, e, f starting at data[6*idx] and treats them as the matrix
#     [a b c]
#     [b d e]
#     [c e f]
# It forms the determinant a(df - e^2) - b(bf - ce) + c(be - cd), multiplies
# the six unique cofactors by 1/det, and writes them back to the same six
# slots. There is no check for a zero determinant.
invert_3x3_kernel = cp.RawKernel(r'''
#include <cupy/complex.cuh>
extern "C" __global__
void invert_3x3_kernel(complex<float>* data, int n) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        int matrix_idx = idx * 6;
        complex<float> a = data[matrix_idx];
        complex<float> b = data[matrix_idx+1];
        complex<float> c = data[matrix_idx+2];
        complex<float> d = data[matrix_idx+3];
        complex<float> e = data[matrix_idx+4];
        complex<float> f = data[matrix_idx+5];

        complex<float> det = a * (d * f - e * e) - b * (b * f - c * e) + c * (b * e - c * d);
        complex<float> invdet = 1.0f / det;

        complex<float> temp0 = (d * f - e * e) * invdet;
        complex<float> temp1 = (c * e - b * f) * invdet;
        complex<float> temp2 = (b * e - c * d) * invdet;
        complex<float> temp3 = (a * f - c * c) * invdet;
        complex<float> temp4 = (b * c - a * e) * invdet;
        complex<float> temp5 = (a * d - b * b) * invdet;

        data[matrix_idx]   = temp0;
        data[matrix_idx+1] = temp1;
        data[matrix_idx+2] = temp2;
        data[matrix_idx+3] = temp3;
        data[matrix_idx+4] = temp4;
        data[matrix_idx+5] = temp5;
    }
}
''', 'invert_3x3_kernel')

def invert_3x3_matrices_gpu(data):
    """
    Invert packed symmetric 3x3 matrices in place (single precision).

    ``data`` is handed to ``invert_3x3_kernel``, which reads and overwrites
    consecutive groups of six ``complex64`` values (one group per matrix).

    Args:
        data: Array of ``complex64`` values whose size is a multiple of 6.
            NOTE(review): the kernel walks the raw buffer linearly, so the
            array is presumably expected to be C-contiguous - confirm at the
            call sites.

    Raises:
        TypeError: If ``data`` is not ``complex64``; the kernel would otherwise
            reinterpret the buffer and silently produce garbage.
    """
    if data.dtype != np.complex64:
        raise TypeError(
            f"invert_3x3_matrices_gpu expects complex64 data, got {data.dtype}"
        )

    n = data.size // 6
    if n == 0:
        # Nothing to invert; a zero-block launch is not a valid configuration.
        return

    threads_per_block = 256
    blocks = (n + threads_per_block - 1) // threads_per_block
    # The kernel declares `int n`; pass an explicit 32-bit value because CuPy
    # would otherwise marshal a Python int as a 64-bit integer.
    invert_3x3_kernel((blocks,), (threads_per_block,), (data, np.int32(n)))

# In-place inversion of packed symmetric 3x3 complex matrices (double precision).
#
# Thread idx (for idx < n) reads six consecutive complex<double> values
# a, b, c, d, e, f starting at data[6*idx] and treats them as the matrix
#     [a b c]
#     [b d e]
#     [c e f]
# It forms the determinant a(df - e^2) - b(bf - ce) + c(be - cd), multiplies
# the six unique cofactors by 1/det, and writes them back to the same six
# slots. There is no check for a zero determinant.
invert_3x3_kernel_double = cp.RawKernel(r'''
#include <cupy/complex.cuh>
extern "C" __global__
void invert_3x3_kernel_double(complex<double>* data, int n) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        int matrix_idx = idx * 6;
        complex<double> a = data[matrix_idx];
        complex<double> b = data[matrix_idx+1];
        complex<double> c = data[matrix_idx+2];
        complex<double> d = data[matrix_idx+3];
        complex<double> e = data[matrix_idx+4];
        complex<double> f = data[matrix_idx+5];

        complex<double> det = a * (d * f - e * e) - b * (b * f - c * e) + c * (b * e - c * d);
        complex<double> invdet = 1.0 / det;

        complex<double> temp0 = (d * f - e * e) * invdet;
        complex<double> temp1 = (c * e - b * f) * invdet;
        complex<double> temp2 = (b * e - c * d) * invdet;
        complex<double> temp3 = (a * f - c * c) * invdet;
        complex<double> temp4 = (b * c - a * e) * invdet;
        complex<double> temp5 = (a * d - b * b) * invdet;

        data[matrix_idx]   = temp0;
        data[matrix_idx+1] = temp1;
        data[matrix_idx+2] = temp2;
        data[matrix_idx+3] = temp3;
        data[matrix_idx+4] = temp4;
        data[matrix_idx+5] = temp5;
    }
}
''', 'invert_3x3_kernel_double')

def invert_3x3_matrices_gpu_double(data):
    """
    Invert packed symmetric 3x3 matrices in place (double precision).

    ``data`` is handed to ``invert_3x3_kernel_double``, which reads and
    overwrites consecutive groups of six ``complex128`` values (one group per
    matrix).

    Args:
        data: Array of ``complex128`` values whose size is a multiple of 6.
            NOTE(review): the kernel walks the raw buffer linearly, so the
            array is presumably expected to be C-contiguous - confirm at the
            call sites.

    Raises:
        TypeError: If ``data`` is not ``complex128``; the kernel would
            otherwise reinterpret the buffer and silently produce garbage.
    """
    if data.dtype != np.complex128:
        raise TypeError(
            f"invert_3x3_matrices_gpu_double expects complex128 data, got {data.dtype}"
        )

    n = data.size // 6
    if n == 0:
        # Nothing to invert; a zero-block launch is not a valid configuration.
        return

    threads_per_block = 256
    blocks = (n + threads_per_block - 1) // threads_per_block
    # The kernel declares `int n`; pass an explicit 32-bit value because CuPy
    # would otherwise marshal a Python int as a 64-bit integer.
    invert_3x3_kernel_double((blocks,), (threads_per_block,), (data, np.int32(n)))

# X-axis averaging kernel (single precision), operating in place on `data`.
#
# One thread handles one (y, z, component) triple, decoded from idx as
# component = idx % nv, z = (idx / nv) % nz, y = idx / (nv * nz); threads with
# idx >= ny*nz*nv do nothing. Element offsets are
# x*stride_x + y*stride_y + z*stride_z + component, so the strides are counted
# in elements and the component axis is taken to have stride 1.
#
# For x = 1 .. (nx-1)/2 (integer division) the entries at x and nx-x are
# combined as
#     averaged = (nx-x)/nx * data[x] + sym_x[component] * x/nx * data[nx-x]
# `averaged` is stored at x; at nx-x it is stored negated when
# sym_x[component] < 0 and unchanged otherwise.
# When nx is even, the unpaired entry at nx/2 is set to 0 for components with
# sym_x[component] < 0 and left untouched otherwise. x = 0 is never modified.
x_average_kernel = cp.RawKernel(r'''
#include <cupy/complex.cuh>

extern "C" __global__
void x_average(complex<float>* data, const float* sym_x, 
               const int nx, const int ny, const int nz, const int nv,
               const int stride_x, const int stride_y, const int stride_z) {
    int idx = blockDim.x * blockIdx.x + threadIdx.x;
    int total_points = ny * nz * nv;
    
    if (idx < total_points) {
        int component = idx % nv;
        int z = (idx / nv) % nz;
        int y = (idx / (nv * nz));
        
        int points_to_average = (nx - 1) / 2;
        
        for (int x = 1; x <= points_to_average; x++) {
            float wx = float(x) / float(nx);
            float wnx = float(nx - x) / float(nx);
            
            int forward_idx = x * stride_x + y * stride_y + z * stride_z + component;
            int reverse_idx = (nx-x) * stride_x + y * stride_y + z * stride_z + component;
            
            complex<float> forward_val = data[forward_idx];
            complex<float> reverse_val = data[reverse_idx];
            complex<float> averaged = wnx * forward_val + sym_x[component] * wx * reverse_val;
            
            data[forward_idx] = averaged;
            data[reverse_idx] = sym_x[component] < 0 ? -averaged : averaged;
        }
        
        if (nx % 2 == 0) {
            int middle_x = nx / 2;
            int middle_idx = middle_x * stride_x + y * stride_y + z * stride_z + component;
            if (sym_x[component] < 0) {
                data[middle_idx] = 0;
            }
        }
    }
}
''', 'x_average')

# Y-axis averaging kernel (single precision), operating in place on `data`.
#
# One thread handles one (x, z, component) triple, decoded from idx as
# component = idx % nv, z = (idx / nv) % nz, x = idx / (nv * nz); threads with
# idx >= nx*nz*nv do nothing. Element offsets are
# x*stride_x + y*stride_y + z*stride_z + component (strides in elements).
#
# For y = 1 .. (ny-1)/2 (integer division) the entries at y and ny-y are
# combined as
#     averaged = (ny-y)/ny * data[y] + sym_y[component] * y/ny * data[ny-y]
# `averaged` is stored at y; at ny-y it is stored negated when
# sym_y[component] < 0 and unchanged otherwise.
# When ny is even, the unpaired entry at ny/2 is set to 0 for components with
# sym_y[component] < 0. y = 0 is never modified.
y_average_kernel = cp.RawKernel(r'''
#include <cupy/complex.cuh>

extern "C" __global__
void y_average(complex<float>* data, const float* sym_y, 
               const int nx, const int ny, const int nz, const int nv,
               const int stride_x, const int stride_y, const int stride_z) {
    int idx = blockDim.x * blockIdx.x + threadIdx.x;
    int total_points = nx * nz * nv;
    
    if (idx < total_points) {
        int component = idx % nv;
        int z = (idx / nv) % nz;
        int x = (idx / (nv * nz));
        
        int points_to_average = (ny - 1) / 2;
        
        for (int y = 1; y <= points_to_average; y++) {
            float wy = float(y) / float(ny);
            float wny = float(ny - y) / float(ny);
            
            int forward_idx = x * stride_x + y * stride_y + z * stride_z + component;
            int reverse_idx = x * stride_x + (ny-y) * stride_y + z * stride_z + component;
            
            complex<float> forward_val = data[forward_idx];
            complex<float> reverse_val = data[reverse_idx];
            complex<float> averaged = wny * forward_val + sym_y[component] * wy * reverse_val;
            
            data[forward_idx] = averaged;
            data[reverse_idx] = sym_y[component] < 0 ? -averaged : averaged;
        }
        
        if (ny % 2 == 0) {
            int middle_y = ny / 2;
            int middle_idx = x * stride_x + middle_y * stride_y + z * stride_z + component;
            if (sym_y[component] < 0) {
                data[middle_idx] = 0;
            }
        }
    }
}
''', 'y_average')

# Z-axis averaging kernel (single precision), operating in place on `data`.
#
# One thread handles one (x, y, component) triple, decoded from idx as
# component = idx % nv, y = (idx / nv) % ny, x = idx / (nv * ny); threads with
# idx >= nx*ny*nv do nothing. Element offsets are
# x*stride_x + y*stride_y + z*stride_z + component (strides in elements).
#
# For z = 1 .. (nz-1)/2 (integer division) the entries at z and nz-z are
# combined as
#     averaged = (nz-z)/nz * data[z] + sym_z[component] * z/nz * data[nz-z]
# `averaged` is stored at z; at nz-z it is stored negated when
# sym_z[component] < 0 and unchanged otherwise.
# When nz is even, the unpaired entry at nz/2 is set to 0 for components with
# sym_z[component] < 0. z = 0 is never modified.
z_average_kernel = cp.RawKernel(r'''
#include <cupy/complex.cuh>

extern "C" __global__
void z_average(complex<float>* data, const float* sym_z, 
               const int nx, const int ny, const int nz, const int nv,
               const int stride_x, const int stride_y, const int stride_z) {
    int idx = blockDim.x * blockIdx.x + threadIdx.x;
    int total_points = nx * ny * nv;
    
    if (idx < total_points) {
        int component = idx % nv;
        int y = (idx / nv) % ny;
        int x = (idx / (nv * ny));
        
        int points_to_average = (nz - 1) / 2;
        
        for (int z = 1; z <= points_to_average; z++) {
            float wz = float(z) / float(nz);
            float wnz = float(nz - z) / float(nz);
            
            int forward_idx = x * stride_x + y * stride_y + z * stride_z + component;
            int reverse_idx = x * stride_x + y * stride_y + (nz-z) * stride_z + component;
            
            complex<float> forward_val = data[forward_idx];
            complex<float> reverse_val = data[reverse_idx];
            complex<float> averaged = wnz * forward_val + sym_z[component] * wz * reverse_val;
            
            data[forward_idx] = averaged;
            data[reverse_idx] = sym_z[component] < 0 ? -averaged : averaged;
        }
        
        if (nz % 2 == 0) {
            int middle_z = nz / 2;
            int middle_idx = x * stride_x + y * stride_y + middle_z * stride_z + component;
            if (sym_z[component] < 0) {
                data[middle_idx] = 0;
            }
        }
    }
}
''', 'z_average')

# X-axis averaging kernel (double precision), operating in place on `data`.
#
# One thread handles one (y, z, component) triple, decoded from idx as
# component = idx % nv, z = (idx / nv) % nz, y = idx / (nv * nz); threads with
# idx >= ny*nz*nv do nothing. Element offsets are
# x*stride_x + y*stride_y + z*stride_z + component (strides in elements).
#
# For x = 1 .. (nx-1)/2 (integer division) the entries at x and nx-x are
# combined as
#     averaged = (nx-x)/nx * data[x] + sym_x[component] * x/nx * data[nx-x]
# `averaged` is stored at x; at nx-x it is stored negated when
# sym_x[component] < 0 and unchanged otherwise.
# When nx is even, the unpaired entry at nx/2 is set to 0 for components with
# sym_x[component] < 0. x = 0 is never modified.
x_average_kernel_double = cp.RawKernel(r'''
#include <cupy/complex.cuh>

extern "C" __global__
void x_average_double(complex<double>* data, const double* sym_x, 
               const int nx, const int ny, const int nz, const int nv,
               const int stride_x, const int stride_y, const int stride_z) {
    int idx = blockDim.x * blockIdx.x + threadIdx.x;
    int total_points = ny * nz * nv;
    
    if (idx < total_points) {
        int component = idx % nv;
        int z = (idx / nv) % nz;
        int y = (idx / (nv * nz));
        
        int points_to_average = (nx - 1) / 2;
        
        for (int x = 1; x <= points_to_average; x++) {
            double wx = double(x) / double(nx);
            double wnx = double(nx - x) / double(nx);
            
            int forward_idx = x * stride_x + y * stride_y + z * stride_z + component;
            int reverse_idx = (nx-x) * stride_x + y * stride_y + z * stride_z + component;
            
            complex<double> forward_val = data[forward_idx];
            complex<double> reverse_val = data[reverse_idx];
            complex<double> averaged = wnx * forward_val + sym_x[component] * wx * reverse_val;
            
            data[forward_idx] = averaged;
            data[reverse_idx] = sym_x[component] < 0 ? -averaged : averaged;
        }
        
        if (nx % 2 == 0) {
            int middle_x = nx / 2;
            int middle_idx = middle_x * stride_x + y * stride_y + z * stride_z + component;
            if (sym_x[component] < 0) {
                data[middle_idx] = 0;
            }
        }
    }
}
''', 'x_average_double')

# Y-axis averaging kernel (double precision), operating in place on `data`.
#
# One thread handles one (x, z, component) triple, decoded from idx as
# component = idx % nv, z = (idx / nv) % nz, x = idx / (nv * nz); threads with
# idx >= nx*nz*nv do nothing. Element offsets are
# x*stride_x + y*stride_y + z*stride_z + component (strides in elements).
#
# For y = 1 .. (ny-1)/2 (integer division) the entries at y and ny-y are
# combined as
#     averaged = (ny-y)/ny * data[y] + sym_y[component] * y/ny * data[ny-y]
# `averaged` is stored at y; at ny-y it is stored negated when
# sym_y[component] < 0 and unchanged otherwise.
# When ny is even, the unpaired entry at ny/2 is set to 0 for components with
# sym_y[component] < 0. y = 0 is never modified.
y_average_kernel_double = cp.RawKernel(r'''
#include <cupy/complex.cuh>

extern "C" __global__
void y_average_double(complex<double>* data, const double* sym_y, 
               const int nx, const int ny, const int nz, const int nv,
               const int stride_x, const int stride_y, const int stride_z) {
    int idx = blockDim.x * blockIdx.x + threadIdx.x;
    int total_points = nx * nz * nv;
    
    if (idx < total_points) {
        int component = idx % nv;
        int z = (idx / nv) % nz;
        int x = (idx / (nv * nz));
        
        int points_to_average = (ny - 1) / 2;
        
        for (int y = 1; y <= points_to_average; y++) {
            double wy = double(y) / double(ny);
            double wny = double(ny - y) / double(ny);
            
            int forward_idx = x * stride_x + y * stride_y + z * stride_z + component;
            int reverse_idx = x * stride_x + (ny-y) * stride_y + z * stride_z + component;
            
            complex<double> forward_val = data[forward_idx];
            complex<double> reverse_val = data[reverse_idx];
            complex<double> averaged = wny * forward_val + sym_y[component] * wy * reverse_val;
            
            data[forward_idx] = averaged;
            data[reverse_idx] = sym_y[component] < 0 ? -averaged : averaged;
        }
        
        if (ny % 2 == 0) {
            int middle_y = ny / 2;
            int middle_idx = x * stride_x + middle_y * stride_y + z * stride_z + component;
            if (sym_y[component] < 0) {
                data[middle_idx] = 0;
            }
        }
    }
}
''', 'y_average_double')

# Z-axis averaging kernel (double precision), operating in place on `data`.
#
# One thread handles one (x, y, component) triple, decoded from idx as
# component = idx % nv, y = (idx / nv) % ny, x = idx / (nv * ny); threads with
# idx >= nx*ny*nv do nothing. Element offsets are
# x*stride_x + y*stride_y + z*stride_z + component (strides in elements).
#
# For z = 1 .. (nz-1)/2 (integer division) the entries at z and nz-z are
# combined as
#     averaged = (nz-z)/nz * data[z] + sym_z[component] * z/nz * data[nz-z]
# `averaged` is stored at z; at nz-z it is stored negated when
# sym_z[component] < 0 and unchanged otherwise.
# When nz is even, the unpaired entry at nz/2 is set to 0 for components with
# sym_z[component] < 0. z = 0 is never modified.
z_average_kernel_double = cp.RawKernel(r'''
#include <cupy/complex.cuh>

extern "C" __global__
void z_average_double(complex<double>* data, const double* sym_z, 
               const int nx, const int ny, const int nz, const int nv,
               const int stride_x, const int stride_y, const int stride_z) {
    int idx = blockDim.x * blockIdx.x + threadIdx.x;
    int total_points = nx * ny * nv;
    
    if (idx < total_points) {
        int component = idx % nv;
        int y = (idx / nv) % ny;
        int x = (idx / (nv * ny));
        
        int points_to_average = (nz - 1) / 2;
        
        for (int z = 1; z <= points_to_average; z++) {
            double wz = double(z) / double(nz);
            double wnz = double(nz - z) / double(nz);
            
            int forward_idx = x * stride_x + y * stride_y + z * stride_z + component;
            int reverse_idx = x * stride_x + y * stride_y + (nz-z) * stride_z + component;
            
            complex<double> forward_val = data[forward_idx];
            complex<double> reverse_val = data[reverse_idx];
            complex<double> averaged = wnz * forward_val + sym_z[component] * wz * reverse_val;
            
            data[forward_idx] = averaged;
            data[reverse_idx] = sym_z[component] < 0 ? -averaged : averaged;
        }
        
        if (nz % 2 == 0) {
            int middle_z = nz / 2;
            int middle_idx = x * stride_x + y * stride_y + middle_z * stride_z + component;
            if (sym_z[component] < 0) {
                data[middle_idx] = 0;
            }
        }
    }
}
''', 'z_average_double')


# X-axis averaging kernel for a 4-index layout (single precision), in place.
#
# One thread handles one (y, z1, z2) triple, decoded from idx as
# z2 = idx % nz3_2, z1 = (idx / nz3_2) % nz3, y = idx / (nz3 * nz3_2); threads
# with idx >= ny*nz3*nz3_2 do nothing. Element offsets are
# x*stride_x + y*stride_y + z1*stride_z1 + z2*stride_z2 (strides in elements).
# The symmetry factor is looked up as sym_x[z1/3] (integer division).
# NOTE(review): the names nz3 / nz3_2 suggest the two trailing axes are built
# from blocks of 3 - confirm the layout against the caller.
#
# For x = 1 .. (nx-1)/2 (integer division) the entries at x and nx-x are
# combined as
#     averaged = (nx-x)/nx * data[x] + sym_x[z1/3] * x/nx * data[nx-x]
# `averaged` is stored at x; at nx-x it is stored negated when
# sym_x[z1/3] < 0 and unchanged otherwise.
# When nx is even, the unpaired entry at nx/2 is set to 0 where
# sym_x[z1/3] < 0. x = 0 is never modified.
x_average_2d_kernel = cp.RawKernel(r'''
#include <cupy/complex.cuh>

extern "C" __global__
void x_average_2d(complex<float>* data, const float* sym_x, 
               const int nx, const int ny, const int nz3, const int nz3_2,
               const int stride_x, const int stride_y, const int stride_z1, const int stride_z2) {
    int idx = blockDim.x * blockIdx.x + threadIdx.x;
    int total_points = ny * nz3 * nz3_2;
    
    if (idx < total_points) {
        int z2 = idx % nz3_2;
        int z1 = (idx / nz3_2) % nz3;
        int y = idx / (nz3 * nz3_2);
        
        int points_to_average = (nx - 1) / 2;
        
        for (int x = 1; x <= points_to_average; x++) {
            float wx = float(x) / float(nx);
            float wnx = float(nx - x) / float(nx);
            
            int forward_idx = x * stride_x + y * stride_y + z1 * stride_z1 + z2 * stride_z2;
            int reverse_idx = (nx-x) * stride_x + y * stride_y + z1 * stride_z1 + z2 * stride_z2;
            
            complex<float> forward_val = data[forward_idx];
            complex<float> reverse_val = data[reverse_idx];
            complex<float> averaged = wnx * forward_val + sym_x[z1/3] * wx * reverse_val;
            
            data[forward_idx] = averaged;
            data[reverse_idx] = sym_x[z1/3] < 0 ? -averaged : averaged;
        }
        
        if (nx % 2 == 0) {
            int middle_x = nx / 2;
            int middle_idx = middle_x * stride_x + y * stride_y + z1 * stride_z1 + z2 * stride_z2;
            if (sym_x[z1/3] < 0) {
                data[middle_idx] = 0;
            }
        }
    }
}
''', 'x_average_2d')

# Y-axis averaging kernel for a 4-index layout (single precision), in place.
#
# One thread handles one (x, z1, z2) triple, decoded from idx as
# z2 = idx % nz3_2, z1 = (idx / nz3_2) % nz3, x = idx / (nz3 * nz3_2); threads
# with idx >= nx*nz3*nz3_2 do nothing. Element offsets are
# x*stride_x + y*stride_y + z1*stride_z1 + z2*stride_z2 (strides in elements).
# The symmetry factor is looked up as sym_y[z1/3] (integer division).
# NOTE(review): the names nz3 / nz3_2 suggest the two trailing axes are built
# from blocks of 3 - confirm the layout against the caller.
#
# For y = 1 .. (ny-1)/2 (integer division) the entries at y and ny-y are
# combined as
#     averaged = (ny-y)/ny * data[y] + sym_y[z1/3] * y/ny * data[ny-y]
# `averaged` is stored at y; at ny-y it is stored negated when
# sym_y[z1/3] < 0 and unchanged otherwise.
# When ny is even, the unpaired entry at ny/2 is set to 0 where
# sym_y[z1/3] < 0. y = 0 is never modified.
y_average_2d_kernel = cp.RawKernel(r'''
#include <cupy/complex.cuh>

extern "C" __global__
void y_average_2d(complex<float>* data, const float* sym_y, 
               const int nx, const int ny, const int nz3, const int nz3_2,
               const int stride_x, const int stride_y, const int stride_z1, const int stride_z2) {
    int idx = blockDim.x * blockIdx.x + threadIdx.x;
    int total_points = nx * nz3 * nz3_2;
    
    if (idx < total_points) {
        int z2 = idx % nz3_2;
        int z1 = (idx / nz3_2) % nz3;
        int x = idx / (nz3 * nz3_2);
        
        int points_to_average = (ny - 1) / 2;
        
        for (int y = 1; y <= points_to_average; y++) {
            float wy = float(y) / float(ny);
            float wny = float(ny - y) / float(ny);
            
            int forward_idx = x * stride_x + y * stride_y + z1 * stride_z1 + z2 * stride_z2;
            int reverse_idx = x * stride_x + (ny-y) * stride_y + z1 * stride_z1 + z2 * stride_z2;
            
            complex<float> forward_val = data[forward_idx];
            complex<float> reverse_val = data[reverse_idx];
            complex<float> averaged = wny * forward_val + sym_y[z1/3] * wy * reverse_val;
            
            data[forward_idx] = averaged;
            data[reverse_idx] = sym_y[z1/3] < 0 ? -averaged : averaged;
        }
        
        if (ny % 2 == 0) {
            int middle_y = ny / 2;
            int middle_idx = x * stride_x + middle_y * stride_y + z1 * stride_z1 + z2 * stride_z2;
            if (sym_y[z1/3] < 0) {
                data[middle_idx] = 0;
            }
        }
    }
}
''', 'y_average_2d')

# X-axis averaging kernel for a 4-index layout (double precision), in place.
#
# One thread handles one (y, z1, z2) triple, decoded from idx as
# z2 = idx % nz3_2, z1 = (idx / nz3_2) % nz3, y = idx / (nz3 * nz3_2); threads
# with idx >= ny*nz3*nz3_2 do nothing. Element offsets are
# x*stride_x + y*stride_y + z1*stride_z1 + z2*stride_z2 (strides in elements).
# The symmetry factor is looked up as sym_x[z1/3] (integer division).
# NOTE(review): the names nz3 / nz3_2 suggest the two trailing axes are built
# from blocks of 3 - confirm the layout against the caller.
#
# For x = 1 .. (nx-1)/2 (integer division) the entries at x and nx-x are
# combined as
#     averaged = (nx-x)/nx * data[x] + sym_x[z1/3] * x/nx * data[nx-x]
# `averaged` is stored at x; at nx-x it is stored negated when
# sym_x[z1/3] < 0 and unchanged otherwise.
# When nx is even, the unpaired entry at nx/2 is set to 0 where
# sym_x[z1/3] < 0. x = 0 is never modified.
x_average_2d_kernel_double = cp.RawKernel(r'''
#include <cupy/complex.cuh>

extern "C" __global__
void x_average_2d_double(complex<double>* data, const double* sym_x, 
               const int nx, const int ny, const int nz3, const int nz3_2,
               const int stride_x, const int stride_y, const int stride_z1, const int stride_z2) {
    int idx = blockDim.x * blockIdx.x + threadIdx.x;
    int total_points = ny * nz3 * nz3_2;
    
    if (idx < total_points) {
        int z2 = idx % nz3_2;
        int z1 = (idx / nz3_2) % nz3;
        int y = idx / (nz3 * nz3_2);
        
        int points_to_average = (nx - 1) / 2;
        
        for (int x = 1; x <= points_to_average; x++) {
            double wx = double(x) / double(nx);
            double wnx = double(nx - x) / double(nx);
            
            int forward_idx = x * stride_x + y * stride_y + z1 * stride_z1 + z2 * stride_z2;
            int reverse_idx = (nx-x) * stride_x + y * stride_y + z1 * stride_z1 + z2 * stride_z2;
            
            complex<double> forward_val = data[forward_idx];
            complex<double> reverse_val = data[reverse_idx];
            complex<double> averaged = wnx * forward_val + sym_x[z1/3] * wx * reverse_val;
            
            data[forward_idx] = averaged;
            data[reverse_idx] = sym_x[z1/3] < 0 ? -averaged : averaged;
        }
        
        if (nx % 2 == 0) {
            int middle_x = nx / 2;
            int middle_idx = middle_x * stride_x + y * stride_y + z1 * stride_z1 + z2 * stride_z2;
            if (sym_x[z1/3] < 0) {
                data[middle_idx] = 0;
            }
        }
    }
}
''', 'x_average_2d_double')

# Y-axis averaging kernel for a 4-index layout (double precision), in place.
#
# One thread handles one (x, z1, z2) triple, decoded from idx as
# z2 = idx % nz3_2, z1 = (idx / nz3_2) % nz3, x = idx / (nz3 * nz3_2); threads
# with idx >= nx*nz3*nz3_2 do nothing. Element offsets are
# x*stride_x + y*stride_y + z1*stride_z1 + z2*stride_z2 (strides in elements).
# The symmetry factor is looked up as sym_y[z1/3] (integer division).
# NOTE(review): the names nz3 / nz3_2 suggest the two trailing axes are built
# from blocks of 3 - confirm the layout against the caller.
#
# For y = 1 .. (ny-1)/2 (integer division) the entries at y and ny-y are
# combined as
#     averaged = (ny-y)/ny * data[y] + sym_y[z1/3] * y/ny * data[ny-y]
# `averaged` is stored at y; at ny-y it is stored negated when
# sym_y[z1/3] < 0 and unchanged otherwise.
# When ny is even, the unpaired entry at ny/2 is set to 0 where
# sym_y[z1/3] < 0. y = 0 is never modified.
y_average_2d_kernel_double = cp.RawKernel(r'''
#include <cupy/complex.cuh>

extern "C" __global__
void y_average_2d_double(complex<double>* data, const double* sym_y, 
               const int nx, const int ny, const int nz3, const int nz3_2,
               const int stride_x, const int stride_y, const int stride_z1, const int stride_z2) {
    int idx = blockDim.x * blockIdx.x + threadIdx.x;
    int total_points = nx * nz3 * nz3_2;
    
    if (idx < total_points) {
        int z2 = idx % nz3_2;
        int z1 = (idx / nz3_2) % nz3;
        int x = idx / (nz3 * nz3_2);
        
        int points_to_average = (ny - 1) / 2;
        
        for (int y = 1; y <= points_to_average; y++) {
            double wy = double(y) / double(ny);
            double wny = double(ny - y) / double(ny);
            
            int forward_idx = x * stride_x + y * stride_y + z1 * stride_z1 + z2 * stride_z2;
            int reverse_idx = x * stride_x + (ny-y) * stride_y + z1 * stride_z1 + z2 * stride_z2;
            
            complex<double> forward_val = data[forward_idx];
            complex<double> reverse_val = data[reverse_idx];
            complex<double> averaged = wny * forward_val + sym_y[z1/3] * wy * reverse_val;
            
            data[forward_idx] = averaged;
            data[reverse_idx] = sym_y[z1/3] < 0 ? -averaged : averaged;
        }
        
        if (ny % 2 == 0) {
            int middle_y = ny / 2;
            int middle_idx = x * stride_x + middle_y * stride_y + z1 * stride_z1 + z2 * stride_z2;
            if (sym_y[z1/3] < 0) {
                data[middle_idx] = 0;
            }
        }
    }
}
''', 'y_average_2d_double')


def _launch_axis_average(kernel, data, sym, axis, threads_per_block, itemsize):
    """Run one of the x/y/z averaging kernels in place on the 4-D array ``data``.

    One thread is requested for every combination of the two spatial axes
    other than ``axis`` and the trailing component axis. The kernel receives
    the current extents of ``data`` and its first three strides converted from
    bytes to elements by dividing by ``itemsize``.
    """
    nx, ny, nz, nv = data.shape
    n_threads = nv
    for ax, extent in enumerate((nx, ny, nz)):
        if ax != axis:
            n_threads *= extent
    blocks_per_grid = (n_threads + threads_per_block - 1) // threads_per_block

    stride_x, stride_y, stride_z = (s // itemsize for s in data.strides[:3])
    kernel((blocks_per_grid,), (threads_per_block,),
           (data, sym, nx, ny, nz, nv,
            stride_x, stride_y, stride_z))


def _fft_components_inplace(data, inverse=False):
    """FFT every component ``data[..., i]`` over axes (0, 1, 2), in place.

    The CuPy array is shared with PyTorch through DLPack, each component is
    overwritten with its forward (or inverse) 3-D transform, and a CuPy view of
    the same tensor is returned.
    """
    tensor = torch.from_dlpack(data)
    transform = torch.fft.ifftn if inverse else torch.fft.fftn
    for comp in range(tensor.shape[3]):
        tensor[:, :, :, comp] = transform(tensor[:, :, :, comp], dim=(0, 1, 2))
    return cp.from_dlpack(tensor)


@free_gpu_memory
def circulant_approximation(original_data, alpha_array, refract_mult, reduced=True, double_precision=False):
    """Build the FFT-domain preconditioner from a 4-D interaction array.

    Steps, in order:
      1. add ``refract_mult * mean(non-zero alpha_array)`` to components
         0, 3 and 5 of ``original_data[0, 0, 0]``;
      2. run the x, y and z averaging kernels over the whole array;
      3. forward FFT of every component, ``invert_kernel`` on the result
         (presumably a per-point 3x3 inversion, judging by the kernel name -
         its definition is outside this block), then inverse FFT;
      4. while an axis is at least twice as long as the matching axis of
         ``alpha_array``, keep its first half and re-run that axis' kernel;
      5. forward FFT of every component again.

    Args:
        original_data: 4-D CuPy array ``(nx, ny, nz, nv)``. It is MODIFIED IN
            PLACE (steps 1-3 write into the caller's buffer). Its dtype is
            expected to match the selected precision, because strides are
            converted with that precision's item size.
        alpha_array: 3-D CuPy array; supplies the mean used in step 1 and the
            target extents used in step 4.
        refract_mult: scalar multiplier applied to the mean in step 1.
        reduced: when true, return only the leading
            ``(n // 2 + 1)`` entries along each spatial axis.
        double_precision: select complex128/float64 kernels instead of
            complex64/float32.

    Returns:
        CuPy array of the selected complex dtype (a fresh copy).
    """
    pool = cp.get_default_memory_pool()
    pool.free_all_blocks()

    # Precision-dependent dtype and kernels.
    if double_precision:
        dtype, sym_dtype = cp.complex128, cp.float64
        x_kernel = x_average_kernel_double
        y_kernel = y_average_kernel_double
        z_kernel = z_average_kernel_double
        invert_kernel = invert_3x3_matrices_gpu_double
    else:
        dtype, sym_dtype = cp.complex64, cp.float32
        x_kernel = x_average_kernel
        y_kernel = y_average_kernel
        z_kernel = z_average_kernel
        invert_kernel = invert_3x3_matrices_gpu
    # Bytes per element (16 for complex128, 8 for complex64); used to turn
    # byte strides into element strides for the kernels.
    itemsize = np.dtype(dtype).itemsize

    # Unpacking doubles as a check that the input is 4-D.
    _, _, _, nv = original_data.shape

    # Step 1: shift the three listed components at the origin.
    raw_mean = cp.mean(alpha_array[alpha_array != 0])
    alpha_avg = refract_mult * raw_mean.astype(dtype)
    print(f"Mean before multiplication: {raw_mean}")
    print(f"Mean after multiplication with {refract_mult}: {alpha_avg}")

    # Per-component sign tables handed to the averaging kernels.
    sym_x = cp.array([1, -1, -1, 1, 1, 1], dtype=sym_dtype)
    sym_y = cp.array([1, -1, 1, 1, -1, 1], dtype=sym_dtype)
    sym_z = cp.array([1, 1, -1, 1, -1, 1], dtype=sym_dtype)

    original_data[0, 0, 0, cp.array([0, 3, 5])] += alpha_avg
    del alpha_avg, raw_mean
    pool.free_all_blocks()

    threads_per_block = 256
    axis_plan = ((0, x_kernel, sym_x), (1, y_kernel, sym_y), (2, z_kernel, sym_z))

    # Step 2: averaging along x, then y, then z.
    for axis, kernel, sym in axis_plan:
        _launch_axis_average(kernel, original_data, sym, axis, threads_per_block, itemsize)
        pool.free_all_blocks()

    # Step 3: FFT -> inversion -> inverse FFT.
    original_data = _fft_components_inplace(original_data)
    pool.free_all_blocks()
    torch.cuda.empty_cache()

    invert_kernel(original_data)
    pool.free_all_blocks()
    torch.cuda.empty_cache()

    original_data = _fft_components_inplace(original_data, inverse=True)
    pool.free_all_blocks()
    torch.cuda.empty_cache()

    # Step 4: halve each axis until it is shorter than twice the alpha extent.
    alpha_nx, alpha_ny, alpha_nz = alpha_array.shape
    targets = (alpha_nx, alpha_ny, alpha_nz)

    while True:
        reduction_performed = False

        for axis, kernel, sym in axis_plan:
            if original_data.shape[axis] >= 2 * targets[axis]:
                half = original_data.shape[axis] // 2
                keep = [slice(None)] * 3
                keep[axis] = slice(half)
                # Slicing yields a view; the kernel is given the view's own
                # strides, so no copy is needed here.
                original_data = original_data[tuple(keep)]
                _launch_axis_average(kernel, original_data, sym, axis, threads_per_block, itemsize)
                reduction_performed = True

        # Free memory after each round of reductions.
        pool.free_all_blocks()

        if not reduction_performed:
            break

    # Step 5: final forward FFT.
    original_data = _fft_components_inplace(original_data)
    pool.free_all_blocks()
    torch.cuda.empty_cache()

    # Drop the small device arrays before the last pool sweep.
    del sym_x, sym_y, sym_z, axis_plan
    torch.cuda.empty_cache()
    pool.free_all_blocks()
    gc.collect()

    if not reduced:
        print(f"Full size: {original_data.shape}")
        return original_data.astype(dtype)

    current_nx, current_ny, current_nz = original_data.shape[:3]
    reduced_size = (current_nx//2+1, current_ny//2+1, current_nz//2+1)
    print(f"Full size: {original_data.shape}")
    print(f"Reduced size will be: {reduced_size}")
    # astype() copies, so the large buffer can be released afterwards.
    result = original_data[:reduced_size[0], :reduced_size[1], :reduced_size[2]].astype(dtype)
    del original_data
    pool.free_all_blocks()
    torch.cuda.empty_cache()
    return result

        
@free_gpu_memory
def circulant_approximationbig(original_data, alpha_array, refract_mult, reduced=True, double_precision=False):
    """Build the preconditioner as inverted (3*nz x 3*nz) blocks per (x, y) point.

    Steps, in order:
      1. copy ``original_data`` (the caller's array is not modified) and add
         ``refract_mult * max(alpha_array)`` to components 0, 3 and 5 at
         index ``[0, 0, 0]``;
      2. run the x and y averaging kernels, then a 2-D FFT over axes (0, 1)
         of every component;
      3. assemble, for each (x, y), a 3x3 grid of nz x nz blocks indexed by
         the z lag ``|i - j|``; components 2 and 4 change sign above the
         diagonal, the others do not;
      4. negate and invert every block matrix;
      5. if an in-plane axis is at least twice the matching ``alpha_array``
         axis: inverse 2-D FFT, repeatedly fold that axis onto its first
         half with linear weights, then forward 2-D FFT.

    Args:
        original_data: 4-D array ``(nx, ny, nz, nv)``.
        alpha_array: 3-D CuPy array; supplies the maximum used in step 1 and
            the x/y target extents used in step 5.
        refract_mult: scalar multiplier applied to the maximum in step 1.
        reduced: accepted for signature compatibility with
            ``circulant_approximation``; it is never read here.
        double_precision: select complex128/float64 kernels instead of
            complex64/float32.

    Returns:
        CuPy array of shape ``(nx', ny', 3*nz, 3*nz)``.
    """
    print('shape')
    print(alpha_array.shape)

    # Precision-dependent dtype, kernels and bytes-per-element.
    if double_precision:
        dtype = cp.complex128
        x_kernel = x_average_kernel_double
        y_kernel = y_average_kernel_double
        sym_dtype = cp.float64
        stride_div = 16
    else:
        dtype = cp.complex64
        x_kernel = x_average_kernel
        y_kernel = y_average_kernel
        sym_dtype = cp.float32
        stride_div = 8

    pool = cp.get_default_memory_pool()

    # One allocation that both copies and casts. The previous
    # cp.copy(...).astype(...) allocated the full array twice.
    data = cp.array(original_data, dtype=dtype)
    nx, ny, nz, nv = data.shape
    # Only the in-plane extents are used; unpacking also checks for 3-D input.
    alpha_nx, alpha_ny, _ = alpha_array.shape

    alpha_avg = refract_mult * cp.max(alpha_array)
    print(alpha_avg)

    # Per-component sign tables handed to the averaging kernels.
    sym_x = cp.array([1, -1, -1, 1, 1, 1], dtype=sym_dtype)
    sym_y = cp.array([1, -1, 1, 1, -1, 1], dtype=sym_dtype)

    data[0, 0, 0, cp.array([0, 3, 5])] += alpha_avg
    del alpha_avg
    pool.free_all_blocks()

    # Kernels index in elements, so convert byte strides.
    stride_x, stride_y, stride_z = (s // stride_div for s in data.strides[:3])
    threads_per_block = 256

    # X averaging
    blocks_per_grid = (ny * nz * nv + threads_per_block - 1) // threads_per_block
    x_kernel((blocks_per_grid,), (threads_per_block,),
             (data, sym_x, nx, ny, nz, nv,
              stride_x, stride_y, stride_z))
    pool.free_all_blocks()

    # Y averaging
    blocks_per_grid = (nx * nz * nv + threads_per_block - 1) // threads_per_block
    y_kernel((blocks_per_grid,), (threads_per_block,),
             (data, sym_y, nx, ny, nz, nv,
              stride_x, stride_y, stride_z))
    pool.free_all_blocks()

    # 2-D FFT of every component through PyTorch (memory shared via DLPack).
    data = torch.from_dlpack(data)
    for comp in range(nv):
        data[:, :, :, comp] = torch.fft.fft2(data[:, :, :, comp], dim=(0, 1))
    data = cp.from_dlpack(data)
    pool.free_all_blocks()
    torch.cuda.empty_cache()

    # The dtype is already `dtype` after the in-place FFT; copy=False avoids
    # duplicating the whole array for a no-op cast.
    data = data.astype(dtype, copy=False)
    pool.free_all_blocks()

    # Block layout: entry (row, col) names the component stored in that block.
    component_layout = ((0, 1, 2), (1, 3, 4), (2, 4, 5))
    same_sign_components = (0, 1, 3, 5)
    big_toeplitz_matrices = cp.zeros((nx, ny, 3*nz, 3*nz), dtype=dtype)

    # z-lag lookup tables, computed once for all nine blocks.
    idx_i = cp.repeat(cp.arange(nz), nz).reshape(nz, nz)
    idx_j = idx_i.T
    lag = idx_i - idx_j
    abs_lag = cp.abs(lag)
    on_or_below_diagonal = idx_i >= idx_j

    for row, row_components in enumerate(component_layout):
        for col, comp in enumerate(row_components):
            block = big_toeplitz_matrices[:, :, row*nz:(row+1)*nz, col*nz:(col+1)*nz]
            if comp in same_sign_components:
                block[...] = data[:, :, abs_lag, comp]
            else:
                # Negative lags wrap around as indices, but those entries are
                # discarded by the mask, exactly as in the two-branch form.
                block[...] = cp.where(
                    on_or_below_diagonal,
                    data[:, :, lag, comp],
                    -data[:, :, -lag, comp]
                )
            pool.free_all_blocks()

    del data, idx_i, idx_j, lag, abs_lag, on_or_below_diagonal, block

    big_toeplitz_matrices *= -1
    pool.free_all_blocks()

    big_toeplitz_matrices = cp.linalg.inv(
        big_toeplitz_matrices.reshape(-1, 3*nz, 3*nz)
    ).reshape(nx, ny, 3*nz, 3*nz)
    pool.free_all_blocks()

    # Dimension reduction section
    current_nx, current_ny = nx, ny

    if current_nx >= 2 * alpha_nx or current_ny >= 2 * alpha_ny:
        big_toeplitz_matrices = torch.from_dlpack(big_toeplitz_matrices)
        big_toeplitz_matrices = torch.fft.ifft2(big_toeplitz_matrices, dim=(0, 1))
        big_toeplitz_matrices = cp.from_dlpack(big_toeplitz_matrices)
        pool.free_all_blocks()
        torch.cuda.empty_cache()

        while current_nx >= 2 * alpha_nx:
            nx2 = current_nx//2
            wx = cp.arange(1, nx2, dtype=dtype)[:, None, None, None] / nx2
            # Partner slice starts one entry later when the length is odd.
            start = nx2 + 2 if current_nx % 2 else nx2 + 1
            big_toeplitz_matrices[1:nx2] = (
                (1 - wx) * big_toeplitz_matrices[1:nx2]
                + wx * big_toeplitz_matrices[start:start + nx2 - 1]
            )
            del wx

            big_toeplitz_matrices = big_toeplitz_matrices[:nx2]
            current_nx = nx2
            pool.free_all_blocks()

        while current_ny >= 2 * alpha_ny:
            ny2 = current_ny//2
            wy = cp.arange(1, ny2, dtype=dtype)[None, :, None, None] / ny2
            start = ny2 + 2 if current_ny % 2 else ny2 + 1
            big_toeplitz_matrices[:, 1:ny2] = (
                (1 - wy) * big_toeplitz_matrices[:, 1:ny2]
                + wy * big_toeplitz_matrices[:, start:start + ny2 - 1]
            )
            del wy

            big_toeplitz_matrices = big_toeplitz_matrices[:, :ny2]
            current_ny = ny2
            pool.free_all_blocks()

        # FFT2 back using PyTorch
        big_toeplitz_matrices = torch.from_dlpack(big_toeplitz_matrices)
        big_toeplitz_matrices = torch.fft.fft2(big_toeplitz_matrices, dim=(0, 1))
        big_toeplitz_matrices = cp.from_dlpack(big_toeplitz_matrices)
        pool.free_all_blocks()
        torch.cuda.empty_cache()

    # Drop the small device arrays before the last pool sweep.
    del sym_x, sym_y
    pool.free_all_blocks()
    torch.cuda.empty_cache()
    gc.collect()
    return big_toeplitz_matrices

@free_gpu_memory
def prepare_interaction_matrices(grid_size, k, refract_mult, x_expansion, y_expansion, z_expansion,
                               alpha_array, lattice_spacing, reduced=True, is_2d=False,
                               cutoff=False, double_precision=False):
    """Prepare the arrays the iterative solver needs.

    Converts ``alpha_array`` to a CuPy array of the chosen precision,
    optionally zeroes entries with magnitude below 1.0 (``cutoff``), builds
    its element-wise inverse (0 where alpha is 0), builds and times the
    preconditioner, and generates the interaction row.

    Note that both preconditioner builders are always called with
    ``reduced=True``; the ``reduced`` argument of this function is forwarded
    to ``generate_interaction_row`` only.

    Returns:
        ``(inv_alpha, mask, interaction_matrix, preconditioner)`` where
        ``mask`` is the boolean array ``alpha_array != 0``.
    """
    nx, ny, nz = grid_size

    dtype = cp.complex64
    if double_precision:
        dtype = cp.complex128

    # Host arrays are uploaded; anything else is assumed to offer .astype().
    alpha_array = (
        cp.asarray(alpha_array, dtype=dtype)
        if isinstance(alpha_array, np.ndarray)
        else alpha_array.astype(dtype)
    )

    if cutoff:
        alpha_array = cp.where(cp.abs(alpha_array) < 1.0, 0.0, alpha_array)

    print('different alphas')
    print(cp.unique(alpha_array).get())
    mask = alpha_array != 0
    # The division is evaluated everywhere; the mask keeps 0 where alpha is 0.
    inv_alpha = cp.where(mask, 1.0 / alpha_array, 0)
    print('different inverse alphas')        
    print(cp.unique(inv_alpha).get())

    torch.cuda.empty_cache()
    cp.get_default_memory_pool().free_all_blocks()

    # Time the preconditioner build with CUDA events.
    start_event = cp.cuda.Event()
    end_event = cp.cuda.Event()
    start_event.record()

    preconditioner = pre_preconditioner(nx, ny, nz, k, lattice_spacing,
                                      x_expansion, y_expansion, z_expansion,
                                      double_precision=double_precision)
    build_preconditioner = circulant_approximationbig if is_2d else circulant_approximation
    preconditioner = build_preconditioner(
        preconditioner,
        inv_alpha,
        refract_mult,
        reduced=True,
        double_precision=double_precision
    )

    end_event.record()
    end_event.synchronize()
    time_taken = cp.cuda.get_elapsed_time(start_event, end_event)
    print(f"Time to build preconditioner: {time_taken:.2f} ms")

    torch.cuda.empty_cache()
    cp.get_default_memory_pool().free_all_blocks()

    interaction_matrix = generate_interaction_row(
        grid_size[0], grid_size[1], grid_size[2],
        k, lattice_spacing,
        reduced=reduced,
        double_precision=double_precision
    )

    # Release what is no longer needed before the final pool sweep.
    del alpha_array, start_event, end_event, time_taken
    torch.cuda.empty_cache()
    cp.get_default_memory_pool().free_all_blocks()
    gc.collect()
    return inv_alpha, mask, interaction_matrix, preconditioner




In [None]:
import numpy as np
# Return cached, unused GPU blocks to the device before building a new scene.
cp.get_default_memory_pool().free_all_blocks()
# Initialize shape manager
sm = ShapeManager()

# Radius of the sphere added further down in this cell; it is also used to
# derive lattice_spacing. (The value is 500 - an earlier comment here
# described a 100nm sphere, which did not match the code.)
#sphereR=.5*1250/(2*np.pi)
sphereR=500
'''
for i in range(20):
    x = i * 1650  # Center-to-center distance
    sm.add_shape('sphere', 
                center=[x, 800, 800],
                radius=800,
                n=1.5)
'''
# Active geometry: a single sphere whose centre coordinate equals its radius
# on every axis, with a complex refractive index.
sm.add_shape('sphere', center=[sphereR, sphereR,sphereR], radius=sphereR, n=0.022+1.4256j)
# Alternative geometries kept for reference (all disabled):
#sm.add_shape('sphere', center=[sphereR, sphereR,sphereR], radius=sphereR, n=0.23294+8.5243j)
#sm.add_shape('sphere', center=[300, 300, 300], radius=300, n=.5+2j)
#sm.add_shape('sphere', center=[2000, 2000, 2000], radius=2000, n=2.3)
#sm.add_shape('sphere', center=[40, 40, 40], radius=40, n=2.7)  # Shell
#sm.add_shape('sphere', center=[40, 40, 40], radius=25, n=0.234 + 3.07j)  # Core
#sm.add_shape('sphere', center=[2000, 2000,2000], radius=2000, n=2.0)
#sm.add_shape('cylinder', center=[1600, 1600,1600], radius=1600, height = 1000, n=2)
#sm.add_shape('rectangle', center=[1, 1, 1], dimensions=[1, 1, 1], n=.0001+.8j)
#sm.add_shape('sphere', center=[2500, 2500, 2500], radius=400, n=2)
#sm.add_shape('rectangle', center=[2000, 2000, 3500], dimensions=[1000, 1000, 1000], n=1.5)
#sm.add_shape('rectangle', center=[1000, 1000, 1000], dimensions=[2000, 2000, 2000], n=1.5)
#sm.add_shape('cylinder', center=[500, 2000, 2000], radius=500, height=1000, n=1.1)
#sm.add_shape('cylinder', center=[3500, 2000, 2000], radius=500, height=1000, n=1.1)
#sm.add_shape('cylinder', center=[2000, 500, 2000], radius=500, height=1000, n=1.1)
#sm.add_shape('cylinder', center=[2000, 3500, 2000], radius=500, height=500, n=1.1)
#sm.add_shape('sphere', center=[2600, 2600, 2600], radius=500, n=2)
# The center should be at least 1591 units in x and y, and 79.55 units in z
# prismR and height are only referenced by the disabled prism call below.
prismR = 3000
height = 300
#sm.add_shape('prism', center=[prismR, prismR, height/2], radius=prismR, sides=6, height=height, n=1.8, axis='z')
'''
# Large sphere
large_radius = 1500
large_volume = (4/3) * np.pi * large_radius**3

# Calculate small sphere radius for 10% total volume
target_small_total_volume = 0.1 * large_volume
single_small_volume = target_small_total_volume / 2000
small_radius = (3 * single_small_volume / (4 * np.pi))**(1/3)

print(f"Small sphere radius: {small_radius:.2f}")

# Add large sphere
#sm.add_shape('sphere', center=[1500, 1500, 1500], radius=large_radius, n=1.001)

# Add random small spheres inside large sphere
random.seed(42)
for _ in range(2000):
    # Random coordinates within safe bounds
    x = random.uniform(small_radius, 3000-small_radius)
    y = random.uniform(small_radius, 3000-small_radius)
    z = random.uniform(small_radius, 3000-small_radius)
    
    # Check if point is within large sphere (with buffer)
    distance = ((x-1500)**2 + (y-1500)**2 + (z-1500)**2)**0.5
    if distance <= large_radius - small_radius:  # Ensure small sphere doesn't intersect edge
        sm.add_shape('sphere', center=[x, y, z], radius=small_radius, n=.3+4j)

'''

'''
sm.add_lattice(
    shape_type='sphere',
    shape_params={
        'radius': 50,
        'n': 2.5
    },
    lattice_type='square',
    spacing=110,  # 50 (radius) + 50 (radius) + 10 (gap) = 110
    size=(30, 30, 1),  # 20x20 grid, only one layer in z-direction
    x_offset=0,
    y_offset=0,
    z_offset=150  # 200 (slab height) + 50 (sphere radius)
)

# Calculate the dimensions of the slab
slab_width = 30 * 110  # 20 spheres * 110 spacing
slab_height = 100

# Add the slab underneath
sm.add_shape(
    'rectangle',
    center=[slab_width/2, slab_width/2, slab_height/2],  # Center of the slab
    dimensions=[slab_width, slab_width, slab_height],
    n=1.1
)
'''
# Setup simulation parameters

# Illumination: wavelength, wavenumber, propagation direction and polarization.
wavelength = 1000
#wavelength = 2 * np.pi*1e20
k = 2 * np.pi / wavelength
E_direction = np.array([1, 0, 0])
E_polarization = np.array([0, 1, 0])
# k was computed from the same expression, so this is the same wave vector.
k_vec = k * E_direction

# Largest |n| over all shapes (only needed by the disabled spacing rule below).
max_refractive_index = max(abs(entry.n) for entry in sm.shapes.values())
#lattice_spacing = wavelength / (10* np.abs(max_refractive_index))
lattice_spacing = sphereR / 22
#lattice_spacing = 1/(64)
print('lattice_spacing', lattice_spacing, sep='\n')
print('dipoles per wavelength', wavelength / lattice_spacing, sep='\n')
fft_factor = 0
refract_mult = 1

# Discretize the shapes and build the polarizability array.
sm.process_shapes(k, E_direction, E_polarization, lattice_spacing)
#for shape_id, shape in sm.shapes.items():
#    if isinstance(shape, Sphere):
#        print(f"Processed sphere (ID: {shape_id}) center: {shape.center}")
alpha_array, grid_size = create_alpha_array(sm, lattice_spacing)
print(f"Total size: {alpha_array.size}, Non-zero elements: {np.count_nonzero(alpha_array)}")
print('grid size', grid_size, sep='\n')
'''
import plotly.graph_objects as go
import numpy as np

try:
    import cupy as cp
    if isinstance(alpha_array, cp.ndarray):
        alpha_array = cp.asnumpy(alpha_array)
except ImportError:
    pass

nonzero_indices = np.nonzero(alpha_array)
x, y, z = nonzero_indices[0], nonzero_indices[1], nonzero_indices[2]
values = np.abs(alpha_array[nonzero_indices])

# Find the maximum range across all dimensions
max_range = max(x.max(), y.max(), z.max())

arrow_scale = max_range * 0.2  # 20% of the plot size
arrow_origin = [0, 0, 0]  # Starting point for the arrows

# Create the main scatter plot
fig = go.Figure(data=[
    # Original scatter points
    go.Scatter3d(
        x=x, y=y, z=z,
        mode='markers',
        marker=dict(size=5, color=values, colorscale='Viridis'),
        name='Dipoles'
    ),
    # Arrow for E-field direction (red)
    go.Scatter3d(
        x=[arrow_origin[0], arrow_origin[0] + E_direction[0] * arrow_scale],
        y=[arrow_origin[1], arrow_origin[1] + E_direction[1] * arrow_scale],
        z=[arrow_origin[2], arrow_origin[2] + E_direction[2] * arrow_scale],
        mode='lines+text',
        line=dict(color='red', width=10),
        text=['', 'k'],
        name='k direction'
    ),
    # Arrow for E-field polarization (blue)
    go.Scatter3d(
        x=[arrow_origin[0], arrow_origin[0] + E_polarization[0] * arrow_scale],
        y=[arrow_origin[1], arrow_origin[1] + E_polarization[1] * arrow_scale],
        z=[arrow_origin[2], arrow_origin[2] + E_polarization[2] * arrow_scale],
        mode='lines+text',
        line=dict(color='blue', width=10),
        text=['', 'E'],
        name='E polarization'
    )
])

fig.update_layout(
    scene=dict(
        xaxis=dict(range=[0, max_range]),
        yaxis=dict(range=[0, max_range]),
        zaxis=dict(range=[0, max_range]),
        aspectmode='cube'  # Forces a perfect cube
    ),
    width=1000,  # Width in pixels
    height=1000,  # Height in pixels
    showlegend=True
)
fig.show()
'''

# Generate matrices
'''
inv_alpha, mask, interaction_matrix, preconditioner = prepare_interaction_matrices(
    grid_size, 
    k, 
    refract_mult,  # refract_mult 
    fft_factor,
    alpha_array,
    lattice_spacing,
    tensor_type="greens",
    use_cpu=False
)

if np.isinf(preconditioner).any() or np.isnan(preconditioner).any():
    print("Preconditioner contains inf or nan values")
else:
    print("Preconditioner does not contain inf or nan values")
    # Calculate incident E field
'''
# Incident plane wave sampled on the dipole grid:
# E_inc[..., i] = E_polarization[i] * exp(1j * k_vec . r)
x, y, z = (np.arange(0, d) * lattice_spacing for d in grid_size)
X, Y, Z = np.meshgrid(x, y, z, indexing='ij')

phase = np.exp(1j * (k_vec[0]*X + k_vec[1]*Y + k_vec[2]*Z))

E_inc = np.zeros((*grid_size, 3), dtype=np.complex128)
for i in range(3):
    # Write each Cartesian component straight into its slot of E_inc.
    np.multiply(E_polarization[i], phase, out=E_inc[..., i])




In [None]:
import importlib
import Gpbicgstab
import Gpbicgstabhybrid
import Gpbicgstabdouble
import Gpbicgstabmulti
# Reload every solver module so that edits made on disk since the kernel
# started are picked up. Gpbicgstabmulti used to be imported without a
# reload, so changes to it were silently ignored until a kernel restart.
importlib.reload(Gpbicgstab)
importlib.reload(Gpbicgstabhybrid)
importlib.reload(Gpbicgstabdouble)
importlib.reload(Gpbicgstabmulti)
# Bind the entry points only after reloading, so they refer to the fresh code.
from Gpbicgstab import solve_gpbicgstab
from Gpbicgstabhybrid import solve_gpbicgstabhybrid
from Gpbicgstabdouble import solve_gpbicgstab_double
from Gpbicgstabmulti import solve_gpbicgstabmulti

# Start from an empty GPU cache.
torch.cuda.empty_cache()
cp.get_default_memory_pool().free_all_blocks()

# Expansion sizes handed to the preconditioner builder (same on every axis).
x_expansion = y_expansion = z_expansion = 44*3

# Multiplier forwarded to the preconditioner; earlier trial values kept below.
#refract_mult = (0.742928349616824+0.8431778272716605j)
#refract_mult = .75+.75j
refract_mult = 2+.2j
#refract_mult = (2.5898431362220435+13.140622967068339j)
#refract_mult = (-9.077168464910171+0.6399585176031903j)
print(f" Refract mult:{refract_mult}")

# Run switches.
is_2d = False
cutoff = False
precon = True

# Build the inverse polarizabilities, mask, interaction row and preconditioner.
inv_alpha, mask, interaction_matrix, preconditioner = prepare_interaction_matrices(
    grid_size, k, refract_mult,
    x_expansion, y_expansion, z_expansion,
    alpha_array, lattice_spacing,
    reduced=True,
    is_2d=is_2d,
    cutoff=cutoff,
    double_precision=False,  # single precision
)
interaction_matrix = interaction_matrix.astype(cp.complex64)

torch.cuda.empty_cache()
cp.get_default_memory_pool().free_all_blocks()

# Solver settings.
ratio = 100
max_iter = 5000

# Wall-clock and CUDA-event timers started just before the solve.
cpu_start = time.perf_counter()
start_event = cp.cuda.Event()
end_event = cp.cuda.Event()
start_event.record()

result = solve_gpbicgstab(
    grid_size, inv_alpha, interaction_matrix,
    preconditioner, E_inc, mask, ratio,
    max_iter, is_2d=is_2d, precon=precon,
)
print(f"\nInitial Run:")
print(f"Iterations: {result[0]}, Lowest norm: {result[2]:.2e}, Mean of p_current: {cp.mean(result[1]):.2e}")
'''
previous_result = result[1].copy()

# Loop for multiple runs of just the multi version
for i in range(20):
    print(f"\nRun {i+1}:")
    
    # Check if previous result is same as input
    if i>0 and cp.allclose(previous_result, result[1]):
        print("No change in results detected - scaling input by .999")
        result_input = result[1] * np.random.uniform(0.999, 1.001)
    else:
        result_input = result[1]
        
    temp_result = solve_gpbicgstabmulti(grid_size, inv_alpha, interaction_matrix, 
                                       preconditioner, E_inc, mask, ratio, 
                                       max_iter, result_input)
    
    # Store current result before updating
    previous_result = result[1].copy()
    result = temp_result
    
    print(f"Iterations: {result[0]}, Lowest norm: {result[2]:.2e}, Mean of p_current: {cp.mean((result[1])):.2e}")
    
    # Check break condition
    if 0 < result[0] < max_iter:
        print(f"\nBreaking loop du e to convergence at run {i+1}")
        break
'''
# Stop both timers exactly once. Previously the end event was recorded a
# second time after the first report had been printed, so the second pair of
# numbers included the printing/synchronisation overhead and cpu_time was
# silently overwritten.
end_event.record()
end_event.synchronize()
cpu_end = time.perf_counter()

gpu_time = cp.cuda.get_elapsed_time(start_event, end_event)  # milliseconds
cpu_time = cpu_end - cpu_start                               # seconds
cuda_time = gpu_time / 1000                                  # seconds

print(f"\nGPU time: {gpu_time:.2f} ms")
print(f"CPU time: {cpu_time:.2f} s")

torch.cuda.empty_cache()
cp.get_default_memory_pool().free_all_blocks()
# Same measurement at higher precision; both report formats are kept so
# existing logs keep their shape.
print(f"CPU time: {cpu_time:.4f} seconds")
print(f"CUDA event time: {cuda_time:.4f} seconds")


In [None]:
import numpy as np
import plotly.graph_objects as go
from ipywidgets import interact, RadioButtons, IntSlider, VBox, HBox, Checkbox
import ipywidgets as widgets
import cupy as cp
from importlib import reload
import ExcitedField
ExcitedField = reload(ExcitedField)
from ExcitedField import E_field_solver

# Field produced by the solved dipoles (result[1]) on the grid.
# NOTE(review): the solver is defined in ExcitedField; the result is treated
# as a CuPy array with the vector components on the last axis - confirm there.
E_field = E_field_solver(result[1], interaction_matrix, grid_size)
E_magnitude = cp.sqrt((cp.abs(E_field)**2).sum(axis=-1))

# Make sure the incident field lives on the GPU before adding it.
E_inc_gpu = cp.asarray(E_inc) if isinstance(E_inc, np.ndarray) else E_inc

# Total field = field from the dipoles + incident field.
E_total = E_field + E_inc_gpu
E_total_magnitude = cp.sqrt((cp.abs(E_total)**2).sum(axis=-1))

# Host-side float32 copies used by the plotting callbacks.
E_magnitude_cpu = cp.asnumpy(E_magnitude).astype(np.float32)
E_total_magnitude_cpu = cp.asnumpy(E_total_magnitude).astype(np.float32)

nx, ny, nz = E_magnitude_cpu.shape[:3]

# Create widgets
# Which pair of axes the displayed slice spans.
plane_selector = RadioButtons(
    options=['XY', 'YZ', 'XZ'],
    description='Plane:',
    disabled=False
)

# Index of the slice along the axis normal to the selected plane; the upper
# bound is adjusted by update_slider_max() whenever the plane changes.
slice_slider = IntSlider(
    value=0,
    min=0,
    max=nz - 1,  # Initial max for XY plane
    step=1,
    description='Slice:',
    continuous_update=True
)

# Switches between E_magnitude_cpu (unchecked) and E_total_magnitude_cpu.
show_total = Checkbox(
    value=False,
    description='Show Total Field',
    disabled=False
)

# Colour scale applied to the slice surface.
colorscale_selector = RadioButtons(
    options=['inferno', 'magma', 'viridis', 'plasma'],
    value='inferno',
    description='Colormap:',
    disabled=False
)

fig = go.FigureWidget()

# Add bounding box
# Trace 0: a single grey polyline through corners of the nx x ny x nz box.
fig.add_trace(go.Scatter3d(
    x=[0, nx, nx, 0, 0, 0, nx, nx, 0, 0, nx, nx, 0, 0, nx, nx],
    y=[0, 0, ny, ny, 0, 0, 0, ny, ny, 0, 0, ny, ny, 0, 0, ny],
    z=[0, 0, 0, 0, 0, nz, nz, nz, nz, 0, 0, 0, 0, nz, nz, nz],
    mode='lines',
    line=dict(color='gray', width=1),
    showlegend=False,
    hoverinfo='skip'
))

# Add initial surface plot
# Trace 1: empty placeholder surface; update_plot() fills it in and relies on
# it being fig.data[1].
fig.add_trace(go.Surface(
    x=np.array([[]]),
    y=np.array([[]]),
    z=np.array([[]]),
    colorscale='inferno',
    showscale=True,
    colorbar=dict(title='|E|')
))

def update_slider_max(plane):
    """Set the slice slider's upper bound for the axis normal to ``plane``.

    'XY' slices run along z, 'YZ' along x, and any other value (i.e. 'XZ')
    along y. The current slider value is pulled back if it would exceed the
    new maximum.
    """
    extent_along_normal = {'XY': nz, 'YZ': nx}.get(plane, ny)
    slice_slider.max = extent_along_normal - 1

    # Keep the selected slice inside the new range.
    if slice_slider.max < slice_slider.value:
        slice_slider.value = slice_slider.max

def on_plane_change(change):
    """Observer for the plane selector: resize the slider range, then redraw."""
    selected = change.new
    update_slider_max(selected)
    update_plot()

def update_plot(*args):
    """Redraw the slice surface (``fig.data[1]``) from the current widget state.

    ``*args`` is ignored, so the function can be used both as a widget
    observer and as a direct call.
    """
    idx = slice_slider.value
    total_selected = show_total.value
    volume = E_total_magnitude_cpu if total_selected else E_magnitude_cpu
    plane = plane_selector.value

    # Build index grids for the two in-plane axes; the third coordinate is
    # constant and equal to the selected slice index.
    if plane == 'XY':
        x, y = np.mgrid[0:nx, 0:ny]
        z = np.full_like(x, idx)
        colors = volume[:, :, idx]
    elif plane == 'YZ':
        y, z = np.mgrid[0:ny, 0:nz]
        x = np.full_like(y, idx)
        colors = volume[idx, :, :]
    else:  # XZ
        x, z = np.mgrid[0:nx, 0:nz]
        y = np.full_like(x, idx)
        colors = volume[:, idx, :]

    bar_title = '|E| Total' if total_selected else '|E|'
    fig.data[1].update(
        x=x,
        y=y,
        z=z,
        surfacecolor=colors,
        colorscale=colorscale_selector.value,
        colorbar=dict(title=bar_title),
    )

# Scene and camera configuration for the 3-D view
scene_settings = {
    'aspectmode': 'data',
    'xaxis_title': 'X',
    'yaxis_title': 'Y',
    'zaxis_title': 'Z',
    'camera': {
        'up': {'x': 0, 'y': 0, 'z': 1},
        'center': {'x': 0, 'y': 0, 'z': 0},
        'eye': {'x': 1.5, 'y': 1.5, 'z': 1.5},
    },
}
fig.update_layout(scene=scene_settings, width=800, height=800)

# The plane selector gets its own handler because it also resizes the slider;
# every other control simply triggers a redraw.
plane_selector.observe(on_plane_change, 'value')
for control in (slice_slider, show_total, colorscale_selector):
    control.observe(update_plot, 'value')

# Draw the first slice
update_plot()

# Show the controls followed by the figure
controls = VBox([plane_selector, slice_slider, show_total, colorscale_selector])
display(controls, fig)


In [None]:
import numpy as np
import plotly.graph_objects as go
from ipywidgets import interact, RadioButtons, IntSlider, VBox, HBox, Checkbox
import ipywidgets as widgets
import cupy as cp
from importlib import reload
import ExcitedField
ExcitedField = reload(ExcitedField)
from ExcitedField import E_field_solver

# Excited field from the solver and its magnitude (norm over the last axis)
E_field = E_field_solver(result[1], interaction_matrix, grid_size)  # reported to return cupy complex64
E_magnitude = cp.sqrt(cp.sum(cp.abs(E_field)**2, axis=-1))

# Make sure the incident field is a CuPy array before it is added to E_field
E_inc_gpu = cp.asarray(E_inc) if isinstance(E_inc, np.ndarray) else E_inc

# Total field = excited field + incident field
E_total = E_field + E_inc_gpu
E_total_magnitude = cp.sqrt(cp.sum(cp.abs(E_total)**2, axis=-1))

# Host-side float32 copies read by the plotting callbacks
E_magnitude_cpu = cp.asnumpy(E_magnitude).astype(np.float32)
E_total_magnitude_cpu = cp.asnumpy(E_total_magnitude).astype(np.float32)

nx, ny, nz = E_magnitude_cpu.shape[:3]

# Interactive controls
plane_selector = RadioButtons(
    description='Plane:',
    options=['XY', 'YZ', 'XZ'],
    disabled=False,
)

# The slider starts sized for the XY plane, i.e. it indexes the z axis
slice_slider = IntSlider(
    description='Slice:',
    min=0,
    max=nz - 1,
    value=0,
    step=1,
    continuous_update=True,
)

show_total = Checkbox(
    description='Show Total Field',
    value=False,
    disabled=False,
)

colorscale_selector = RadioButtons(
    description='Colormap:',
    options=['inferno', 'magma', 'viridis', 'plasma'],
    value='inferno',
    disabled=False,
)

fig = go.FigureWidget()

# Add bounding box.
# The 12 cube edges are drawn as one continuous polyline: the bottom face
# first, then the top face, dipping down and back up each remaining vertical
# edge on the way round. Every consecutive pair of points differs in exactly
# one coordinate, so no diagonal is drawn across a face.
fig.add_trace(go.Scatter3d(
    x=[0, nx, nx, 0, 0, 0, nx, nx, nx, nx, nx, nx, 0, 0, 0, 0],
    y=[0, 0, ny, ny, 0, 0, 0, 0, 0, ny, ny, ny, ny, ny, ny, 0],
    z=[0, 0, 0, 0, 0, nz, nz, 0, nz, nz, 0, nz, nz, 0, nz, nz],
    mode='lines',
    line=dict(color='gray', width=1),
    showlegend=False,
    hoverinfo='skip'
))

# Add an empty placeholder surface as trace 1; update_plot() fills in its
# coordinates and colours.
fig.add_trace(go.Surface(
    x=np.array([[]]),
    y=np.array([[]]),
    z=np.array([[]]),
    colorscale='inferno',
    showscale=True,
    colorbar=dict(title='|E|')
))

def update_slider_max(plane):
    """Resize the slice slider to the axis that is normal to *plane*.

    'XY' slices run along z (``nz``) and 'YZ' slices along x (``nx``); any
    other value is treated as 'XZ' and runs along y (``ny``). If the current
    slider value lies above the new maximum it is pulled back to that maximum.
    """
    last_index = {'XY': nz - 1, 'YZ': nx - 1}.get(plane, ny - 1)
    slice_slider.max = last_index

    # Keep the selected slice inside the new range
    if slice_slider.max < slice_slider.value:
        slice_slider.value = slice_slider.max

def on_plane_change(change):
    """Observer for the plane selector: resize the slider range, then redraw.

    ``change.new`` is the newly selected plane name.
    """
    selected_plane = change.new
    update_slider_max(selected_plane)
    update_plot()

def update_plot(*args):
    """Redraw the slice surface (``fig.data[1]``) from the current widget state.

    ``*args`` is ignored, so the function can be used both as a widget
    observer and as a direct call.
    """
    idx = slice_slider.value
    total_selected = show_total.value
    volume = E_total_magnitude_cpu if total_selected else E_magnitude_cpu
    plane = plane_selector.value

    # Build index grids for the two in-plane axes; the third coordinate is
    # constant and equal to the selected slice index.
    if plane == 'XY':
        x, y = np.mgrid[0:nx, 0:ny]
        z = np.full_like(x, idx)
        colors = volume[:, :, idx]
    elif plane == 'YZ':
        y, z = np.mgrid[0:ny, 0:nz]
        x = np.full_like(y, idx)
        colors = volume[idx, :, :]
    else:  # XZ
        x, z = np.mgrid[0:nx, 0:nz]
        y = np.full_like(x, idx)
        colors = volume[:, idx, :]

    bar_title = '|E| Total' if total_selected else '|E|'
    fig.data[1].update(
        x=x,
        y=y,
        z=z,
        surfacecolor=colors,
        colorscale=colorscale_selector.value,
        colorbar=dict(title=bar_title),
    )

# Scene and camera configuration for the 3-D view
scene_settings = {
    'aspectmode': 'data',
    'xaxis_title': 'X',
    'yaxis_title': 'Y',
    'zaxis_title': 'Z',
    'camera': {
        'up': {'x': 0, 'y': 0, 'z': 1},
        'center': {'x': 0, 'y': 0, 'z': 0},
        'eye': {'x': 1.5, 'y': 1.5, 'z': 1.5},
    },
}
fig.update_layout(scene=scene_settings, width=800, height=800)

# The plane selector gets its own handler because it also resizes the slider;
# every other control simply triggers a redraw.
plane_selector.observe(on_plane_change, 'value')
for control in (slice_slider, show_total, colorscale_selector):
    control.observe(update_plot, 'value')

# Draw the first slice
update_plot()

# Show the controls followed by the figure
controls = VBox([plane_selector, slice_slider, show_total, colorscale_selector])
display(controls, fig)


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import imageio
import os

def _save_slice_gif(slices, gif_name, title_fmt, xlabel, ylabel, extent,
                    frame_dir='temp_frames'):
    """Render 2-D slices to PNG frames and assemble them into an animated GIF.

    Args:
        slices: Sequence of 2-D arrays; each one becomes one frame and is
            transposed before being passed to ``plt.imshow``.
        gif_name: Path of the GIF to write.
        title_fmt: ``str.format`` template that receives the slice index and
            the last slice index.
        xlabel: Label of the horizontal axis.
        ylabel: Label of the vertical axis.
        extent: ``[left, right, bottom, top]`` forwarded to ``plt.imshow``.
        frame_dir: Existing directory used for the temporary PNG frames.

    Returns:
        ``gif_name``, so callers can report what was actually written.
    """
    last_index = len(slices) - 1
    frames = []
    for idx, plane in enumerate(slices):
        plt.figure(figsize=(8, 8))
        plt.imshow(plane.T,
                   cmap='inferno',
                   extent=extent,
                   origin='lower')
        plt.colorbar(label='|E| Total')
        plt.title(title_fmt.format(idx, last_index))
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        frame_path = os.path.join(frame_dir, f'frame_{idx:03d}.png')
        plt.savefig(frame_path)
        plt.close()
        frames.append(imageio.imread(frame_path))
        # The frame is held in memory now, so the PNG can go immediately
        os.remove(frame_path)

    # NOTE(review): newer imageio releases may interpret `duration` in
    # milliseconds rather than seconds — confirm against the installed version.
    imageio.mimsave(gif_name, frames, duration=0.1)
    return gif_name


# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs
os.makedirs('temp_frames', exist_ok=True)

saved_gifs = [
    # XY Plane (varying z)
    _save_slice_gif(
        [E_total_magnitude_cpu[:, :, k] for k in range(nz)],
        'e_field_animation_xycube124p.gif',
        'XY-plane, Z-slice: {}/{}', 'X', 'Y', [0, nx, 0, ny]),
    # YZ Plane (varying x)
    _save_slice_gif(
        [E_total_magnitude_cpu[k, :, :] for k in range(nx)],
        'e_field_animation_yzcube124p.gif',
        'YZ-plane, X-slice: {}/{}', 'Y', 'Z', [0, ny, 0, nz]),
    # XZ Plane (varying y)
    _save_slice_gif(
        [E_total_magnitude_cpu[:, k, :] for k in range(ny)],
        'e_field_animation_xzcube124p.gif',
        'XZ-plane, Y-slice: {}/{}', 'X', 'Z', [0, nx, 0, nz]),
]

os.rmdir('temp_frames')

# Report the names that were actually written (the old message listed
# file names that this cell never produced).
print("GIFs have been saved as '{}', '{}', and '{}'".format(*saved_gifs))


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import imageio
import os

# Ensure the temporary frame directory exists
temp_dir = 'temp_frames'
os.makedirs(temp_dir, exist_ok=True)

dpi_value = 300       # Increase DPI for higher-resolution images
figsize_value = (8,8) # You can adjust if you want larger figures
interp_style = 'bicubic'  # Alternatives: 'nearest', 'bilinear', etc.


def _export_plane_gif(slices, gif_name, title_fmt, xlabel, ylabel, extent):
    """Render 2-D slices at ``dpi_value`` and assemble them into a GIF.

    Args:
        slices: Sequence of 2-D arrays; each one becomes one frame and is
            transposed before being passed to ``plt.imshow``.
        gif_name: Path of the GIF to write.
        title_fmt: ``str.format`` template that receives the slice index and
            the last slice index.
        xlabel: Label of the horizontal axis.
        ylabel: Label of the vertical axis.
        extent: ``[left, right, bottom, top]`` forwarded to ``plt.imshow``.

    Returns:
        ``gif_name``, so callers can report what was actually written.
    """
    last_index = len(slices) - 1
    frames = []
    for idx, plane in enumerate(slices):
        plt.figure(figsize=figsize_value)
        plt.imshow(plane.T,
                   cmap='inferno',
                   extent=extent,
                   origin='lower',
                   interpolation=interp_style)
        plt.colorbar(label='|E| Total')
        plt.title(title_fmt.format(idx, last_index))
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)

        # Save the figure with increased DPI, read it back, then drop the PNG
        frame_path = os.path.join(temp_dir, f'frame_{idx:03d}.png')
        plt.savefig(frame_path, dpi=dpi_value)
        plt.close()
        frames.append(imageio.imread(frame_path))
        os.remove(frame_path)

    # NOTE(review): newer imageio releases may interpret `duration` in
    # milliseconds rather than seconds — confirm against the installed version.
    imageio.mimsave(gif_name, frames, duration=0.1)
    return gif_name


# NOTE(review): the XY file is tagged "124" while YZ/XZ are tagged "114";
# the names are kept as they were — confirm which tag is intended.
saved_gifs = [
    # --- Generate GIF for XY Plane (varying z) ---
    _export_plane_gif(
        [E_total_magnitude_cpu[:, :, k] for k in range(nz)],
        'e_field_animation_xycube124.gif',
        'XY-plane, Z-slice: {}/{}', 'X', 'Y', [0, nx, 0, ny]),
    # --- Generate GIF for YZ Plane (varying x) ---
    _export_plane_gif(
        [E_total_magnitude_cpu[k, :, :] for k in range(nx)],
        'e_field_animation_yzcube114.gif',
        'YZ-plane, X-slice: {}/{}', 'Y', 'Z', [0, ny, 0, nz]),
    # --- Generate GIF for XZ Plane (varying y) ---
    _export_plane_gif(
        [E_total_magnitude_cpu[:, k, :] for k in range(ny)],
        'e_field_animation_xzcube114.gif',
        'XZ-plane, Y-slice: {}/{}', 'X', 'Z', [0, nx, 0, nz]),
]

# Clean up by removing the temporary directory
os.rmdir(temp_dir)

# Report the names that were actually written (the old message listed
# "...180D.gif" files that this cell never produced).
print("GIFs have been saved as:")
for gif_name in saved_gifs:
    print(f"   - {gif_name}")

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# One row of three heatmap panels, one per slicing plane
fig = make_subplots(
    rows=1, cols=3,
    subplot_titles=('XY Plane', 'YZ Plane', 'XZ Plane'),
    specs=[[{'type': 'heatmap'} for _ in range(3)]]
)


def _plane_heatmaps(ix, iy, iz):
    """Build the XY, YZ and XZ heatmap traces for the given slice indices.

    The traces are returned in subplot order; each slice is transposed and
    gets its own colorbar at a fixed horizontal position.
    """
    panels = (
        (E_total_magnitude_cpu[:, :, iz], 0.3),    # XY plane (z varies)
        (E_total_magnitude_cpu[ix, :, :], 0.65),   # YZ plane (x varies)
        (E_total_magnitude_cpu[:, iy, :], 1.0),    # XZ plane (y varies)
    )
    return [
        go.Heatmap(
            z=plane.T,
            colorscale='inferno',
            showscale=True,
            colorbar=dict(x=bar_x),
        )
        for plane, bar_x in panels
    ]


# One animation frame per slice index; axes shorter than the longest one
# stay on their last slice.
max_slices = max(nz, nx, ny)
frames = [
    dict(
        data=_plane_heatmaps(min(i, nx-1), min(i, ny-1), min(i, nz-1)),
        name=str(i),
        traces=[0, 1, 2]
    )
    for i in range(max_slices)
]

# Initial traces: slice 0 of every plane, one per subplot column
for col, trace in enumerate(_plane_heatmaps(0, 0, 0), start=1):
    fig.add_trace(trace, row=1, col=col)

# Play / Pause buttons
play_button = {
    "args": [None, {
        "frame": {"duration": 50},
        "fromcurrent": True,
        "mode": "immediate",
    }],
    "label": "Play",
    "method": "animate"
}
pause_button = {
    "args": [[None], {
        "frame": {"duration": 0},
        "mode": "immediate",
        "transition": {"duration": 0}
    }],
    "label": "Pause",
    "method": "animate"
}

# One slider step per frame, addressed by the frame name
slider_steps = [
    {
        "args": [[str(i)], {
            "frame": {"duration": 0},
            "mode": "immediate",
            "transition": {"duration": 0}
        }],
        "label": str(i),
        "method": "animate"
    }
    for i in range(max_slices)
]

fig.update_layout(
    title_text="E-field Magnitude in Different Planes",
    showlegend=False,
    width=1500,
    height=500,
    updatemenus=[{
        "buttons": [play_button, pause_button],
        "direction": "left",
        "pad": {"r": 10, "t": 87},
        "showactive": True,
        "type": "buttons",
        "x": 0.1,
        "xanchor": "right",
        "y": 0,
        "yanchor": "top"
    }],
    sliders=[{
        "active": 0,
        "yanchor": "top",
        "xanchor": "left",
        "currentvalue": {
            "font": {"size": 20},
            "prefix": "Slice: ",
            "visible": True,
            "xanchor": "right"
        },
        "pad": {"b": 10, "t": 50},
        "len": 0.9,
        "x": 0.1,
        "y": 0,
        "steps": slider_steps
    }]
)

# Axis labels per panel, in subplot order
for col, (x_label, y_label) in enumerate((("X", "Y"), ("Y", "Z"), ("X", "Z")), start=1):
    fig.update_xaxes(title_text=x_label, row=1, col=col)
    fig.update_yaxes(title_text=y_label, row=1, col=col)

# Attach the animation frames
fig.frames = frames

# Save as HTML
fig.write_html("E_field_animations124.html")

# Display in notebook
#fig.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

def _cell_half_width(values):
    """Return half the mean spacing of sorted unique *values*.

    Falls back to 0.5 for fewer than two values. For unit-spaced values this
    is 0.5, i.e. the padding that was previously hard-coded.
    """
    if len(values) < 2:
        return 0.5
    return (values[-1] - values[0]) / (len(values) - 1) / 2


# Group the results once instead of re-filtering for every figure
expansions = [2, 4, 8]
results_by_expansion = {
    exp: [r for r in results if r['expansion'] == exp] for exp in expansions
}

# Create figure with one 3-D scatter subplot per expansion
fig = plt.figure(figsize=(20, 6))

for idx, exp in enumerate(expansions):
    exp_results = results_by_expansion[exp]

    ax = fig.add_subplot(1, len(expansions), idx+1, projection='3d')
    ax.set_xlabel('Real Part')
    ax.set_ylabel('Imaginary Part')
    ax.set_zlabel('Norm')
    ax.set_title(f'Expansion {exp}x64')

    # An expansion without results would make min() below raise ValueError
    if not exp_results:
        ax.text2D(0.05, 0.95, 'No results', transform=ax.transAxes)
        continue

    # Create arrays for plotting
    real_vals = [r['real'] for r in exp_results]
    imag_vals = [r['imag'] for r in exp_results]
    norm_vals = [r['norm'] for r in exp_results]

    # Scatter coloured by norm, with a colorbar tied to this subplot
    scatter = ax.scatter(real_vals, imag_vals, norm_vals,
                         c=norm_vals, cmap='viridis',
                         s=100, alpha=0.6)
    fig.colorbar(scatter, ax=ax, label='Norm')

    # Add text for best (lowest-norm) result
    best_result = min(exp_results, key=lambda x: x['norm'])
    ax.text2D(0.05, 0.95,
              f'Best Result:\nReal: {best_result["real"]}\nImag: {best_result["imag"]}\nNorm: {best_result["norm"]:.2e}',
              transform=ax.transAxes)

plt.tight_layout()
plt.show()

# Create 2D heatmaps
fig, axs = plt.subplots(1, len(expansions), figsize=(20, 6))

for idx, exp in enumerate(expansions):
    exp_results = results_by_expansion[exp]

    # Labels and title
    axs[idx].set_xlabel('Imaginary Part')
    axs[idx].set_ylabel('Real Part')
    axs[idx].set_title(f'Expansion {exp}x64')

    if not exp_results:
        axs[idx].text(0.5, 0.5, 'No results', ha='center', va='center',
                      transform=axs[idx].transAxes)
        continue

    # Create 2D grid; NaN marks (real, imag) pairs that were never evaluated
    # so they are not mistaken for a norm of exactly 0
    real_vals = np.unique([r['real'] for r in exp_results])
    imag_vals = np.unique([r['imag'] for r in exp_results])
    norm_grid = np.full((len(real_vals), len(imag_vals)), np.nan)

    # Fill grid with norm values (a later duplicate overwrites an earlier one)
    row_of = {value: k for k, value in enumerate(real_vals.tolist())}
    col_of = {value: k for k, value in enumerate(imag_vals.tolist())}
    for r in exp_results:
        norm_grid[row_of[r['real']], col_of[r['imag']]] = r['norm']

    # Pad the extent by half a cell so each cell is centred on its value
    pad_real = _cell_half_width(real_vals)
    pad_imag = _cell_half_width(imag_vals)
    im = axs[idx].imshow(norm_grid, origin='lower',
                         extent=[imag_vals[0] - pad_imag, imag_vals[-1] + pad_imag,
                                 real_vals[0] - pad_real, real_vals[-1] + pad_real],
                         aspect='auto', cmap='viridis')

    # Add colorbar
    plt.colorbar(im, ax=axs[idx])

    # Mark best result
    best_result = min(exp_results, key=lambda x: x['norm'])
    axs[idx].plot(best_result['imag'], best_result['real'], 'r*',
                  markersize=15, label=f'Best (norm={best_result["norm"]:.2e})')
    axs[idx].legend()

plt.tight_layout()
plt.show()

# Plot each expansion's norms sorted ascending, on a log axis
fig, ax = plt.subplots(figsize=(12, 6))

for exp in expansions:
    sorted_norms = sorted(r['norm'] for r in results_by_expansion[exp])
    ax.plot(sorted_norms, label=f'Expansion {exp}x64')

ax.set_yscale('log')
ax.set_xlabel('Result Index')
ax.set_ylabel('Norm (log scale)')
ax.set_title('Convergence Comparison')
ax.grid(True)
ax.legend()

plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Organize data for plotting
test_types = ['Unpreconditioned', 'Base 2D', 'Modified Base 2D', 'Expanded 2D']
colors = ['blue', 'red', 'green', 'purple']
markers = ['o', 's', '^', 'D']


def _plot_iterations_vs_radius(log_scale=False):
    """Plot iterations against radius for every test type in ``results``.

    ``results`` is read as an iterable of ``(test_type, radius, iterations)``
    tuples; a negative iteration count marks a non-converged run. Converged
    runs are drawn as lines; non-converged runs are drawn as 'x' markers
    along the top of the axes.

    Args:
        log_scale: When True the y axis is logarithmic and the labels say so.
    """
    plt.figure(figsize=(10, 6))
    failed_series = []  # (radii, color) of non-converged runs per test type

    for test_type, color, marker in zip(test_types, colors, markers):
        # Extract data for this test type (may be empty)
        series = [(r, i) for t, r, i in results if t == test_type]

        # Non-converged runs (negative iterations) are left out of the line
        radii_ok = [r for r, i in series if i >= 0]
        iterations_ok = [i for r, i in series if i >= 0]
        plt.plot(radii_ok, iterations_ok, color=color, marker=marker,
                 label=test_type, linestyle='-', markersize=8)

        nan_radii = [r for r, i in series if i < 0]
        if nan_radii:
            failed_series.append((nan_radii, color))

    if log_scale:
        plt.yscale('log')

    # Read the upper limit once, after all lines are drawn, so every
    # non-convergence marker sits at the same height
    marker_height = plt.ylim()[1]
    for nan_radii, color in failed_series:
        plt.scatter(nan_radii, [marker_height]*len(nan_radii),
                    color=color, marker='x', s=100, alpha=0.5)

    plt.xlabel('Radius (nm)')
    if log_scale:
        plt.ylabel('Iterations (log scale)')
        plt.title('Convergence Iterations vs Radius (Log Scale)')
    else:
        plt.ylabel('Iterations')
        plt.title('Convergence Iterations vs Radius')
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.legend()

    if any(i < 0 for _, _, i in results):
        plt.figtext(0.5, 0.01, 'Note: × markers indicate non-convergence (NAN)',
                    ha='center', va='center', fontsize=10)

    plt.tight_layout()
    plt.show()


# First plot - Linear scale
_plot_iterations_vs_radius(log_scale=False)

# Second plot - Log scale
_plot_iterations_vs_radius(log_scale=True)

# Print summary of non-convergence cases
print("\nNon-convergence cases (NANs):")
print("-" * 50)
for test_type, radius, iterations in results:
    if iterations < 0:
        print(f"{test_type}: Radius = {radius}nm")

In [None]:
import time
import numpy as np
import cupy as cp
import plotly.graph_objects as go
import ipywidgets as widgets
from ipywidgets import VBox, HBox
from ExcitedField import E_field_solver  # (Assuming this is your custom import)

# ---------------------------------------------------------
# 1. Find surface points (sphere or erode)
# ---------------------------------------------------------
def _exposed_voxels(volume):
    """Return a boolean mask of non-zero voxels with a zero 6-neighbour.

    Positions outside the grid count as zero, so non-zero voxels on the grid
    boundary are always exposed.
    """
    nx, ny, nz = volume.shape
    # Pad once; every neighbour lookup is then a plain shifted slice
    padded = cp.pad(volume, ((1, 1), (1, 1), (1, 1)), mode='constant', constant_values=0)
    occupied = volume != 0
    exposed = cp.zeros_like(volume, dtype=bool)
    for dx, dy, dz in [(0, 0, 1), (0, 0, -1), (0, 1, 0), (0, -1, 0), (1, 0, 0), (-1, 0, 0)]:
        neighbour = padded[1+dx:nx+1+dx, 1+dy:ny+1+dy, 1+dz:nz+1+dz]
        exposed |= occupied & (neighbour == 0)
    return exposed


def find_surface_points(inv_alpha, depth=0, use_sphere=True):
    """Return the indices of one surface layer of the non-zero region.

    Args:
        inv_alpha: 3-D array; non-zero entries mark the object. A NumPy array
            is converted with ``cp.asarray`` first. The input is not modified.
        depth: Number of layers below the outer surface.
        use_sphere: If True, select the non-zero voxels whose distance from
            the centroid of the non-zero voxels is within 1.0 of
            ``max_radius - depth``. If False, peel ``depth`` layers off with a
            6-neighbour erosion and return the surface of what remains.

    Returns:
        Tuple of three index arrays (x, y, z) as produced by ``cp.where``;
        the arrays are empty when no voxel qualifies.
    """
    if isinstance(inv_alpha, np.ndarray):
        inv_alpha = cp.asarray(inv_alpha)

    if use_sphere:
        # Spherical method
        occupied = inv_alpha != 0
        x, y, z = cp.where(occupied)
        if x.size == 0:
            # No object voxels: the mean/max below are undefined for empty input
            return cp.where(occupied)

        cx, cy, cz = cp.mean(x), cp.mean(y), cp.mean(z)
        max_radius = cp.max(cp.sqrt((x - cx)**2 + (y - cy)**2 + (z - cz)**2))
        current_radius = max_radius - depth

        # Distance of every grid point from the centroid (broadcast open grids)
        xx, yy, zz = cp.ogrid[:inv_alpha.shape[0], :inv_alpha.shape[1], :inv_alpha.shape[2]]
        distances = cp.sqrt((xx - cx)**2 + (yy - cy)**2 + (zz - cz)**2)

        tolerance = 1.0
        shell_mask = (cp.abs(distances - current_radius) < tolerance) & occupied
        return cp.where(shell_mask)

    # 6-neighbour erosion on a copy: strip one exposed layer per depth step
    eroded = cp.copy(inv_alpha)
    for _ in range(depth):
        eroded[_exposed_voxels(eroded)] = 0

    # Surface of the eroded shape
    return cp.where(_exposed_voxels(eroded))


# ---------------------------------------------------------
# 2. Precompute all layers (for a given method and field)
# ---------------------------------------------------------
def precalculate_layers(inv_alpha, use_sphere, max_depth, field_values):
    """
    Build the surface layer for every depth from 0 to max_depth (inclusive).

    Each layer is a dict with the host-side index arrays 'x', 'y', 'z'
    returned by find_surface_points and 'values', the entries of
    field_values at those indices. Progress and elapsed time are printed.
    """
    print(f"Precalculating layers for {'sphere' if use_sphere else 'erode'} method...")
    t_start = time.time()

    def build_layer(depth):
        # Surface indices come back from find_surface_points and are copied
        # to the host so they can index field_values directly
        gx, gy, gz = find_surface_points(inv_alpha, depth, use_sphere=use_sphere)
        xs, ys, zs = cp.asnumpy(gx), cp.asnumpy(gy), cp.asnumpy(gz)
        return {
            'x': xs,
            'y': ys,
            'z': zs,
            'values': field_values[xs, ys, zs],
        }

    layers = [build_layer(depth) for depth in range(max_depth + 1)]

    print(f"Precalculation took {time.time() - t_start:.2f} seconds")
    return layers

# ---------------------------------------------------------
# 3. Example E-field calculations (replace with your own code)
# ---------------------------------------------------------
# Suppose 'result', 'interaction_matrix', 'grid_size', etc. exist in your context.
# We mimic your approach, retrieving E_field and combining with E_inc.

# NOTE(review): `inv_alpha` is used below but is never assigned in this cell; it presumably comes from an earlier cell (possibly result[1]) — confirm.


# Excited field from the solver and its magnitude (norm over the last axis)
E_field = E_field_solver(result[1], interaction_matrix, grid_size)
E_magnitude = cp.sqrt(cp.sum(cp.abs(E_field)**2, axis=-1))

# Make sure the incident field is a CuPy array before it is added to E_field
E_inc_gpu = cp.asarray(E_inc) if isinstance(E_inc, np.ndarray) else E_inc

# Total field = excited field + incident field
E_total = E_field + E_inc_gpu
E_total_magnitude = cp.sqrt(cp.sum(cp.abs(E_total)**2, axis=-1))

# Host-side float32 copies of both magnitudes for plotting
E_magnitude_cpu = cp.asnumpy(E_magnitude).astype(np.float32)
E_total_magnitude_cpu = cp.asnumpy(E_total_magnitude).astype(np.float32)

# ---------------------------------------------------------
# 4. Prepare interactivity
# ---------------------------------------------------------
# Deepest layer offered by the slider: half of the smallest grid dimension
max_depth = min(inv_alpha.shape) // 2

# continuous_update=False: the value only changes once the drag is released
depth_slider = widgets.IntSlider(
    description='Depth:',
    min=0,
    max=max_depth,
    value=0,
    step=1,
    continuous_update=False,
)

method_toggle = widgets.ToggleButtons(
    description='Method:',
    options=['Sphere', 'Erode'],
    disabled=False,
)

field_toggle = widgets.ToggleButtons(
    description='Field:',
    options=['Excited', 'Total'],
    disabled=False,
)

# Figure widget that will hold the surface-point scatter trace
fig = go.FigureWidget()

# Precomputed layers plus the method/field combination they were built for
cache = dict(layers=None, method=None, field=None)

# ---------------------------------------------------------
# 5. The update function: read from cache or recompute, then update the Plotly trace
# ---------------------------------------------------------
def update_plot(change):
    """Widget observer: show the layer selected by the current widget state.

    The layers are rebuilt with precalculate_layers only when the method or
    field toggle differs from what ``cache`` holds; a depth change just picks
    another precomputed layer. ``change`` is not inspected.
    """
    method = method_toggle.value
    field_name = field_toggle.value
    total_selected = field_name == 'Total'
    field_data = E_total_magnitude_cpu if total_selected else E_magnitude_cpu

    # Rebuild the layers only if the cached combination is stale
    if (cache['method'], cache['field']) != (method, field_name):
        print("Method or field changed, recalculating layers...")

        cache['layers'] = precalculate_layers(
            inv_alpha, method == 'Sphere', max_depth, field_data
        )
        cache['method'] = method
        cache['field'] = field_name

        print("Recalculation complete")

    layer = cache['layers'][depth_slider.value]

    marker_style = dict(
        color=layer['values'],
        # Fully opaque markers
        opacity=1.0,
        colorscale='inferno',
        showscale=True,
        colorbar=dict(title='|E| (Total)' if total_selected else '|E|'),
        size=3,
        symbol='circle',
    )

    # Push the new points and colours into the existing scatter trace
    fig.data[0].update(
        x=layer['x'],
        y=layer['y'],
        z=layer['z'],
        marker=marker_style,
    )

# ---------------------------------------------------------
# 6. Initial figure setup
# ---------------------------------------------------------
print("Performing initial calculation...")

# Seed the cache with the 'Sphere' / 'Excited' combination
cache.update(
    layers=precalculate_layers(inv_alpha, True, max_depth, E_magnitude_cpu),
    method='Sphere',
    field='Excited',
)

print("Initial calculation complete")

# The outermost layer (depth 0) is shown first
first_layer = cache['layers'][0]

initial_marker = {
    'color': first_layer['values'],
    'opacity': 1.0,  # No partial transparency
    'colorscale': 'inferno',
    'showscale': True,
    'colorbar': {'title': '|E|'},
    'size': 3,
    'symbol': 'circle',
}

# Scatter trace that update_plot later modifies in place
fig.add_trace(go.Scatter3d(
    x=first_layer['x'],
    y=first_layer['y'],
    z=first_layer['z'],
    mode='markers',
    marker=initial_marker,
    name='Surface',
))

# Scene, camera and figure size
scene_settings = {
    'aspectmode': 'data',
    'xaxis_title': 'X',
    'yaxis_title': 'Y',
    'zaxis_title': 'Z',
    'camera': {
        'up': {'x': 0, 'y': 0, 'z': 1},
        'center': {'x': 0, 'y': 0, 'z': 0},
        'eye': {'x': 1.5, 'y': 1.5, 'z': 1.5},
    },
    'dragmode': 'orbit',
    # perspective or orthographic can be tested:
    'camera_projection_type': 'perspective',
}
fig.update_layout(
    scene=scene_settings,
    width=800,
    height=800,
    title='Surface Points with E-field Magnitude',
)

# ---------------------------------------------------------
# 7. Hook up widget callbacks
# ---------------------------------------------------------
for control in (depth_slider, method_toggle, field_toggle):
    control.observe(update_plot, names='value')

# Toggles side by side, depth slider underneath
controls = VBox([
    HBox([method_toggle, field_toggle]),
    depth_slider,
])

# Show the final UI (controls + figure)
display(VBox([controls, fig]))


In [None]:
import numpy as np
import plotly.graph_objects as go
from ipywidgets import interact, RadioButtons, IntSlider, VBox, HBox, Checkbox
import ipywidgets as widgets
import cupy as cp
from importlib import reload
import ExcitedField
ExcitedField = reload(ExcitedField)
from ExcitedField import E_field_solver

# Excited field from the solver and its magnitude (norm over the last axis)
E_field = E_field_solver(result[1], interaction_matrix, grid_size)  # reported to return cupy complex64
E_magnitude = cp.sqrt(cp.sum(cp.abs(E_field)**2, axis=-1))

# Make sure the incident field is a CuPy array before it is added to E_field
E_inc_gpu = cp.asarray(E_inc) if isinstance(E_inc, np.ndarray) else E_inc

# Total field = excited field + incident field
E_total = E_field + E_inc_gpu
E_total_magnitude = cp.sqrt(cp.sum(cp.abs(E_total)**2, axis=-1))

# Host-side float32 copies read by the plotting callbacks
E_magnitude_cpu = cp.asnumpy(E_magnitude).astype(np.float32)
E_total_magnitude_cpu = cp.asnumpy(E_total_magnitude).astype(np.float32)

nx, ny, nz = E_magnitude_cpu.shape[:3]

# Interactive controls
plane_selector = RadioButtons(
    description='Plane:',
    options=['XY', 'YZ', 'XZ'],
    disabled=False,
)

# The slider starts sized for the XY plane, i.e. it indexes the z axis
slice_slider = IntSlider(
    description='Slice:',
    min=0,
    max=nz - 1,
    value=0,
    step=1,
    continuous_update=True,
)

show_total = Checkbox(
    description='Show Total Field',
    value=False,
    disabled=False,
)

colorscale_selector = RadioButtons(
    description='Colormap:',
    options=['inferno', 'magma', 'viridis', 'plasma'],
    value='inferno',
    disabled=False,
)

fig = go.FigureWidget()

# Add bounding box.
# The 12 cube edges are drawn as one continuous polyline: the bottom face
# first, then the top face, dipping down and back up each remaining vertical
# edge on the way round. Every consecutive pair of points differs in exactly
# one coordinate, so no diagonal is drawn across a face.
fig.add_trace(go.Scatter3d(
    x=[0, nx, nx, 0, 0, 0, nx, nx, nx, nx, nx, nx, 0, 0, 0, 0],
    y=[0, 0, ny, ny, 0, 0, 0, 0, 0, ny, ny, ny, ny, ny, ny, 0],
    z=[0, 0, 0, 0, 0, nz, nz, 0, nz, nz, 0, nz, nz, 0, nz, nz],
    mode='lines',
    line=dict(color='gray', width=1),
    showlegend=False,
    hoverinfo='skip'
))

# Add an empty placeholder surface as trace 1; update_plot() fills in its
# coordinates and colours.
fig.add_trace(go.Surface(
    x=np.array([[]]),
    y=np.array([[]]),
    z=np.array([[]]),
    colorscale='inferno',
    showscale=True,
    colorbar=dict(title='|E|')
))

def update_slider_max(plane):
    """Bound the slice slider by the number of slices available in `plane`.

    'XY' slices are stacked along z, 'YZ' slices along x, and any other value
    is treated as 'XZ' (stacked along y). If the current slider value lies
    beyond the new bound it is pulled back to that bound.
    """
    # Number of slices along the axis perpendicular to the chosen plane.
    slice_count = {'XY': nz, 'YZ': nx}.get(plane, ny)
    slice_slider.max = slice_count - 1

    # Keep the selected slice inside the new range.
    upper = slice_slider.max
    if slice_slider.value > upper:
        slice_slider.value = upper

def on_plane_change(change):
    """Observer for the plane selector: re-bound the slider, then redraw.

    `change` is the change notification passed by the widget; only its new
    value (the selected plane name) is used.
    """
    selected_plane = change['new']
    update_slider_max(selected_plane)
    update_plot()

def update_plot(*args):
    """Redraw the slice surface from the current state of the controls.

    Any positional arguments are accepted and ignored, so the function can be
    registered as a widget observer and also called directly.
    """
    orientation = plane_selector.value
    idx = slice_slider.value
    total_selected = show_total.value
    volume = E_total_magnitude_cpu if total_selected else E_magnitude_cpu

    # Integer grid coordinates along each axis.
    x_ticks, y_ticks, z_ticks = np.arange(nx), np.arange(ny), np.arange(nz)

    # Two coordinates vary across the slice; the third is constant at `idx`.
    if orientation == 'XY':
        xs, ys = np.meshgrid(x_ticks, y_ticks, indexing='ij')
        zs = np.full_like(xs, idx)
        colors = volume[:, :, idx]
    elif orientation == 'YZ':
        ys, zs = np.meshgrid(y_ticks, z_ticks, indexing='ij')
        xs = np.full_like(ys, idx)
        colors = volume[idx, :, :]
    else:  # XZ
        xs, zs = np.meshgrid(x_ticks, z_ticks, indexing='ij')
        ys = np.full_like(xs, idx)
        colors = volume[:, idx, :]

    bar_title = '|E| Total' if total_selected else '|E|'

    # fig.data[1] is the surface trace (fig.data[0] is the bounding box).
    fig.data[1].update(
        x=xs,
        y=ys,
        z=zs,
        surfacecolor=colors,
        colorscale=colorscale_selector.value,
        colorbar=dict(title=bar_title),
    )

# Scene setup: labelled axes, aspect ratio taken from the data ranges, z as the
# up direction, and the camera placed at (1.5, 1.5, 1.5) looking at the centre.
fig.update_layout(
    width=800,
    height=800,
    scene={
        'aspectmode': 'data',
        'xaxis_title': 'X',
        'yaxis_title': 'Y',
        'zaxis_title': 'Z',
        'camera': {
            'up': {'x': 0, 'y': 0, 'z': 1},
            'center': {'x': 0, 'y': 0, 'z': 0},
            'eye': {'x': 1.5, 'y': 1.5, 'z': 1.5},
        },
    },
)

# Wire the controls: a plane change also re-bounds the slider before redrawing,
# every other control simply triggers a redraw.
plane_selector.observe(on_plane_change, names='value')
slice_slider.observe(update_plot, names='value')
show_total.observe(update_plot, names='value')
colorscale_selector.observe(update_plot, names='value')

# Draw the first slice so the figure is not empty when shown.
update_plot()

# Show the stacked controls followed by the figure.
controls = VBox(children=[plane_selector, slice_slider, show_total, colorscale_selector])
display(controls, fig)


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import imageio
import os

def _save_slice_gif(volume, axis, dims, gif_path, frame_dir='temp_frames'):
    """Render every slice of `volume` along `axis` as a heatmap and save one GIF.

    Args:
        volume: 3-D array of field magnitudes indexed as [x, y, z].
        axis: Index of the axis that is stepped through (0 = X, 1 = Y, 2 = Z).
        dims: The (nx, ny, nz) extents; they give the slice count and the
            image extents.
        gif_path: Path of the animated GIF to write.
        frame_dir: Existing directory that receives the intermediate PNG
            frames; each frame is deleted again once it has been read back.
    """
    axis_names = 'XYZ'
    # The two axes left after slicing. The slice is transposed before plotting,
    # so the first of them runs horizontally and the second vertically.
    horiz, vert = (i for i in range(3) if i != axis)
    plane = axis_names[horiz] + axis_names[vert]
    n_slices = dims[axis]

    frames = []
    for idx in range(n_slices):
        plt.figure(figsize=(8, 8))
        plt.imshow(np.take(volume, idx, axis=axis).T,
                   cmap='inferno',
                   extent=[0, dims[horiz], 0, dims[vert]],
                   origin='lower')
        plt.colorbar(label='|E| Total')
        plt.title(f'{plane}-plane, {axis_names[axis]}-slice: {idx}/{n_slices - 1}')
        plt.xlabel(axis_names[horiz])
        plt.ylabel(axis_names[vert])
        frame_path = os.path.join(frame_dir, f'frame_{idx:03d}.png')
        plt.savefig(frame_path)
        plt.close()
        frames.append(imageio.imread(frame_path))
        os.remove(frame_path)

    # NOTE(review): newer imageio releases appear to interpret GIF `duration`
    # in milliseconds rather than seconds - confirm against the installed version.
    imageio.mimsave(gif_path, frames, duration=0.1)


# Scratch directory for the PNG frames (fine if it is already there).
os.makedirs('temp_frames', exist_ok=True)

# (sliced axis, output file), in the order the animations are produced.
gif_jobs = [
    (2, 'e_field_animation_xy.gif'),  # XY plane, varying z
    (0, 'e_field_animation_yz.gif'),  # YZ plane, varying x
    (1, 'e_field_animation_xz.gif'),  # XZ plane, varying y
]
for slice_axis, gif_name in gif_jobs:
    _save_slice_gif(E_total_magnitude_cpu, slice_axis, (nx, ny, nz), gif_name)

os.rmdir('temp_frames')

# Build the message from the names actually written; the previous message
# referred to '..._xycube.gif' style names that this cell never creates.
saved = [gif_name for _, gif_name in gif_jobs]
print(f"GIFs have been saved as '{saved[0]}', '{saved[1]}', and '{saved[2]}'")


In [None]:
import numpy as np
import plotly.graph_objects as go
from ipywidgets import interact, RadioButtons, IntSlider, VBox, HBox, Checkbox
import ipywidgets as widgets
import cupy as cp
from importlib import reload
import ExcitedField
ExcitedField = reload(ExcitedField)
from ExcitedField import E_field_solver

# Solve for the field on the grid. Per the original author's note the solver
# returns a CuPy complex64 array.
E_field = E_field_solver(result[1], interaction_matrix, grid_size)

# Move a NumPy incident field onto the GPU; anything else is used unchanged.
E_inc_gpu = cp.asarray(E_inc) if isinstance(E_inc, np.ndarray) else E_inc

# Total field = solved field + incident field.
E_total = E_field + E_inc_gpu

# Per-point magnitude: square root of the summed squared moduli over the last axis.
E_magnitude, E_total_magnitude = (
    cp.sqrt(cp.sum(cp.abs(field)**2, axis=-1)) for field in (E_field, E_total)
)

# Host-side float32 NumPy copies used by the plotting code.
E_magnitude_cpu, E_total_magnitude_cpu = (
    cp.asnumpy(magnitude).astype(np.float32)
    for magnitude in (E_magnitude, E_total_magnitude)
)

# Extents of the first three axes of the magnitude volume.
nx, ny, nz = E_magnitude_cpu.shape[:3]

# --- Controls for the slice viewer ---

# Which pair of axes spans the displayed slice.
plane_selector = RadioButtons(
    description='Plane:',
    options=['XY', 'YZ', 'XZ'],
    disabled=False,
)

# Index of the displayed slice. One fixed range, bounded by the largest grid
# extent, is shared by all three planes.
# NOTE: for a non-cubic grid this range is longer than the shorter axes.
slice_slider = IntSlider(
    description='Slice:',
    min=0,
    max=max(nx, ny, nz) - 1,
    value=0,
    step=1,
    continuous_update=True,
)

# Unchecked: show the solved field magnitude. Checked: show the total field magnitude.
show_total = Checkbox(
    description='Show Total Field',
    value=False,
    disabled=False,
)

# Colormap applied to the slice surface.
colorscale_selector = RadioButtons(
    description='Colormap:',
    options=['inferno', 'magma', 'viridis', 'plasma'],
    value='inferno',
    disabled=False,
)

fig = go.FigureWidget()

# Trace 0: wireframe of the grid's bounding box, spanning 0..nx, 0..ny, 0..nz.
# All twelve edges are drawn in a single trace:
#   * points 0-4   : bottom rectangle (z = 0), closed back at the origin
#   * points 4-5   : vertical edge at (0, 0)
#   * points 5-9   : top rectangle (z = nz), closed back at (0, 0, nz)
#   * after each None (a gap in the line): one of the remaining vertical edges
# The previous single polyline jumped diagonally across the x = 0 face and left
# out the top edge along y plus three of the vertical edges.
fig.add_trace(go.Scatter3d(
    x=[0, nx, nx, 0, 0, 0, nx, nx, 0, 0, None, nx, nx, None, nx, nx, None, 0, 0],
    y=[0, 0, ny, ny, 0, 0, 0, ny, ny, 0, None, 0, 0, None, ny, ny, None, ny, ny],
    z=[0, 0, 0, 0, 0, nz, nz, nz, nz, nz, None, 0, nz, None, 0, nz, None, 0, nz],
    mode='lines',
    line=dict(color='gray', width=1),
    connectgaps=False,  # keep the None entries as breaks between segments
    showlegend=False,
    hoverinfo='skip'
))

# Trace 1: the slice surface. It starts empty and is filled in by update_plot(),
# which addresses it as fig.data[1].
fig.add_trace(go.Surface(
    x=np.array([[]]),
    y=np.array([[]]),
    z=np.array([[]]),
    colorscale='inferno',
    showscale=True,
    colorbar=dict(title='|E|')
))

def update_plot(*args):
    """Redraw the slice surface from the current state of the controls.

    Any positional arguments are accepted and ignored, so the function can be
    registered as a widget observer and also called directly.

    The slider shares one range across all planes, so its value can be larger
    than the number of slices along the axis perpendicular to the selected
    plane. The index is therefore clamped to the last valid slice of that axis
    instead of raising an IndexError inside the widget callback.
    """
    plane = plane_selector.value
    use_total = show_total.value
    data = E_total_magnitude_cpu if use_total else E_magnitude_cpu

    # Number of slices along the axis perpendicular to the selected plane.
    if plane == 'XY':
        slice_count = nz
    elif plane == 'YZ':
        slice_count = nx
    else:  # XZ
        slice_count = ny
    slice_idx = min(slice_slider.value, slice_count - 1)

    # Two coordinates vary across the slice; the third is constant at slice_idx.
    if plane == 'XY':
        x, y = np.meshgrid(np.arange(nx), np.arange(ny), indexing='ij')
        z = np.full_like(x, slice_idx)
        surfacecolor = data[:, :, slice_idx]
    elif plane == 'YZ':
        y, z = np.meshgrid(np.arange(ny), np.arange(nz), indexing='ij')
        x = np.full_like(y, slice_idx)
        surfacecolor = data[slice_idx, :, :]
    else:  # XZ
        x, z = np.meshgrid(np.arange(nx), np.arange(nz), indexing='ij')
        y = np.full_like(x, slice_idx)
        surfacecolor = data[:, slice_idx, :]

    # fig.data[1] is the surface trace (fig.data[0] is the bounding box).
    fig.data[1].update(
        x=x,
        y=y,
        z=z,
        surfacecolor=surfacecolor,
        colorscale=colorscale_selector.value,
        colorbar=dict(title='|E| Total' if use_total else '|E|')
    )

# Scene setup: labelled axes, aspect ratio taken from the data ranges, z as the
# up direction, and the camera placed at (1.5, 1.5, 1.5) looking at the centre.
fig.update_layout(
    width=800,
    height=800,
    scene={
        'aspectmode': 'data',
        'xaxis_title': 'X',
        'yaxis_title': 'Y',
        'zaxis_title': 'Z',
        'camera': {
            'up': {'x': 0, 'y': 0, 'z': 1},
            'center': {'x': 0, 'y': 0, 'z': 0},
            'eye': {'x': 1.5, 'y': 1.5, 'z': 1.5},
        },
    },
)

# Every control triggers a redraw when its value changes.
plane_selector.observe(update_plot, names='value')
slice_slider.observe(update_plot, names='value')
show_total.observe(update_plot, names='value')
colorscale_selector.observe(update_plot, names='value')

# Draw the first slice so the figure is not empty when shown.
update_plot()

# Show the stacked controls followed by the figure.
controls = VBox(children=[plane_selector, slice_slider, show_total, colorscale_selector])
display(controls, fig)


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import imageio
import os

def _save_slice_gif(volume, axis, dims, gif_path, frame_dir='temp_frames'):
    """Render every slice of `volume` along `axis` as a heatmap and save one GIF.

    Args:
        volume: 3-D array of field magnitudes indexed as [x, y, z].
        axis: Index of the axis that is stepped through (0 = X, 1 = Y, 2 = Z).
        dims: The (nx, ny, nz) extents; they give the slice count and the
            image extents.
        gif_path: Path of the animated GIF to write.
        frame_dir: Existing directory that receives the intermediate PNG
            frames; each frame is deleted again once it has been read back.
    """
    axis_names = 'XYZ'
    # The two axes left after slicing. The slice is transposed before plotting,
    # so the first of them runs horizontally and the second vertically.
    horiz, vert = (i for i in range(3) if i != axis)
    plane = axis_names[horiz] + axis_names[vert]
    n_slices = dims[axis]

    frames = []
    for idx in range(n_slices):
        plt.figure(figsize=(8, 8))
        plt.imshow(np.take(volume, idx, axis=axis).T,
                   cmap='inferno',
                   extent=[0, dims[horiz], 0, dims[vert]],
                   origin='lower')
        plt.colorbar(label='|E| Total')
        plt.title(f'{plane}-plane, {axis_names[axis]}-slice: {idx}/{n_slices - 1}')
        plt.xlabel(axis_names[horiz])
        plt.ylabel(axis_names[vert])
        frame_path = os.path.join(frame_dir, f'frame_{idx:03d}.png')
        plt.savefig(frame_path)
        plt.close()
        frames.append(imageio.imread(frame_path))
        os.remove(frame_path)

    # NOTE(review): newer imageio releases appear to interpret GIF `duration`
    # in milliseconds rather than seconds - confirm against the installed version.
    imageio.mimsave(gif_path, frames, duration=0.1)


# Scratch directory for the PNG frames (fine if it is already there).
os.makedirs('temp_frames', exist_ok=True)

# (sliced axis, output file), in the order the animations are produced.
gif_jobs = [
    (2, 'e_field_animation_xy.gif'),  # XY plane, varying z
    (0, 'e_field_animation_yz.gif'),  # YZ plane, varying x
    (1, 'e_field_animation_xz.gif'),  # XZ plane, varying y
]
for slice_axis, gif_name in gif_jobs:
    _save_slice_gif(E_total_magnitude_cpu, slice_axis, (nx, ny, nz), gif_name)

os.rmdir('temp_frames')

# The message is built from the names actually written so the two cannot drift apart.
saved = [gif_name for _, gif_name in gif_jobs]
print(f"GIFs have been saved as '{saved[0]}', '{saved[1]}', and '{saved[2]}'")


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import imageio
import os

# Relies on E_magnitude_cpu and E_total_magnitude_cpu from an earlier cell.
nx, ny, nz = E_magnitude_cpu.shape[:3]

# Scratch directory for the per-slice PNG frames.
if not os.path.exists('temp_frames'):
    os.makedirs('temp_frames')

# One PNG per z index.
frame_paths = [f'temp_frames/frame_{z:03d}.png' for z in range(nz)]

# Render each XY slice of the total-field magnitude, write it to disk and read
# it back as an image array for the GIF writer.
frames = []
for z, frame_path in enumerate(frame_paths):
    plt.figure(figsize=(8, 8))

    # Transposed so that x runs horizontally and y vertically.
    plt.imshow(
        E_total_magnitude_cpu[:, :, z].T,
        origin='lower',
        extent=[0, nx, 0, ny],
        cmap='inferno',
    )
    plt.colorbar(label='|E| Total')
    plt.gca().set(title=f'Z-slice: {z}/{nz-1}', xlabel='X', ylabel='Y')

    plt.savefig(frame_path)
    plt.close()

    frames.append(imageio.imread(frame_path))

# Assemble the animation (duration=0.1 per frame, as passed to imageio).
imageio.mimsave('e_field_animation2.gif', frames, duration=0.1)

# Remove the intermediate frames and the now-empty scratch directory.
for frame_path in frame_paths:
    os.remove(frame_path)
os.rmdir('temp_frames')

print("GIF has been saved as 'e_field_animation2.gif'")
