# In [None]:
import numpy as np
from PIL import Image
import random
import matplotlib.pyplot as plt

class ImageSecretSharing:
    """Split an image into Shamir secret shares and rebuild it from them.

    Every pixel channel value becomes the constant term of its own random
    polynomial of degree at most ``threshold - 1`` over the integers modulo
    ``prime``. Share ``i`` (1-based) stores each polynomial evaluated at
    ``x = i``; the constant terms are recovered by Lagrange interpolation at
    ``x = 0``.

    With the default ``prime=251`` every share value lies in 0..250 and so
    fits an 8-bit image, but channel values 251-255 do not exist in that
    field. They are clamped to ``prime - 1`` before sharing, which makes the
    reconstruction lossy for those values only (instead of wrapping them
    around to 0-4).
    """

    # OS-backed random source. The random coefficients are the only thing that
    # hides the image, so a predictable generator must not be used for them.
    _rng = random.SystemRandom()

    def __init__(self, prime=251):
        """Store the field modulus.

        Args:
            prime: Prime modulus of the field. Share images are written as
                8-bit data, so a modulus above 256 produces share values that
                are clipped to 255 when the share images are built.
        """
        self.prime = prime

    def _evaluate_polynomial(self, coefficients, x):
        """Evaluate a polynomial at ``x`` modulo ``self.prime`` (Horner's rule).

        Args:
            coefficients: Coefficients ordered from the constant term upward.
                Each entry may be an integer or a NumPy array; arrays are
                evaluated element-wise.
            x: Point at which to evaluate.

        Returns:
            The value(s) of the polynomial at ``x``, reduced modulo the prime.
        """
        result = 0
        for coefficient in reversed(coefficients):
            result = (result * x + coefficient) % self.prime
        return result

    def _generate_polynomial(self, secret, threshold):
        """Build a random polynomial whose constant term is ``secret``.

        The ``threshold - 1`` higher coefficients are drawn uniformly from
        ``0 .. prime - 1``. Zero has to be a possible draw: if it were
        excluded, a single share would already rule out one candidate value
        for the secret.

        Returns:
            List of ``threshold`` coefficients, constant term first.
        """
        coefficients = [secret]
        coefficients.extend(
            self._rng.randrange(self.prime) for _ in range(threshold - 1)
        )
        return coefficients

    def _lagrange_interpolation(self, x_values, y_values, x):
        """Interpolate the points ``(x_values[i], y_values[i])`` at ``x``.

        All arithmetic is done modulo ``self.prime``. The ``x_values`` must be
        distinct modulo the prime, otherwise a denominator is zero and the
        result is meaningless.

        Returns:
            Value at ``x`` of the polynomial through the given points.
        """
        k = len(x_values)
        result = 0

        for i in range(k):
            numerator = 1
            denominator = 1
            for j in range(k):
                if i != j:
                    numerator = (numerator * (x - x_values[j])) % self.prime
                    denominator = (denominator * (x_values[i] - x_values[j])) % self.prime

            # Modular inverse via Fermat's little theorem (requires a prime modulus)
            inverse_denominator = pow(denominator, self.prime - 2, self.prime)

            term = (y_values[i] * numerator * inverse_denominator) % self.prime
            result = (result + term) % self.prime

        return result

    def _share_array(self, image_array, n_shares, threshold):
        """Split an array of pixel values into ``n_shares`` share arrays.

        Works for any array shape, so grayscale and multi-channel images take
        the same path.

        Args:
            image_array: Array of non-negative pixel values.
            n_shares: Number of share arrays to produce.
            threshold: Number of shares needed to recover the values.

        Returns:
            List of ``n_shares`` int32 arrays shaped like ``image_array``;
            entry ``i`` holds the polynomials evaluated at ``x = i + 1``.

        Raises:
            ValueError: If ``threshold < 1``, ``n_shares < threshold`` or
                ``n_shares >= prime``.
        """
        if threshold < 1:
            raise ValueError("threshold must be at least 1")
        if n_shares < threshold:
            raise ValueError("n_shares must be at least threshold")
        if n_shares >= self.prime:
            # x-coordinates 1..n_shares must be distinct and non-zero modulo
            # the prime; x == 0 (mod prime) would hand out the secret itself.
            raise ValueError("n_shares must be smaller than the prime")

        pixels = np.asarray(image_array)
        # Values >= prime cannot be represented in the field; clamp them
        # rather than letting e.g. 255 wrap around to 4.
        secret_values = np.clip(pixels.astype(np.int64), 0, self.prime - 1)

        # One independent polynomial per pixel channel value.
        polynomials = np.array(
            [
                self._generate_polynomial(int(value), threshold)
                for value in secret_values.ravel()
            ],
            dtype=np.int64,
        ).reshape(secret_values.size, threshold)

        # Regroup by degree so each share is a single element-wise evaluation.
        coefficient_planes = [
            polynomials[:, degree].reshape(secret_values.shape)
            for degree in range(threshold)
        ]

        return [
            np.asarray(
                self._evaluate_polynomial(coefficient_planes, x_coord)
            ).astype(np.int32)
            for x_coord in range(1, n_shares + 1)
        ]

    def _reconstruct_array(self, share_arrays, share_indices):
        """Recover the shared values from share arrays and their x-coordinates.

        Interpolation at a fixed point is linear in the share values, so the
        Lagrange weight of each share is computed once and applied to the
        whole array.

        Args:
            share_arrays: Equal-shaped arrays of share values.
            share_indices: x-coordinate of each share, in the same order.

        Returns:
            int64 array of recovered values in ``0 .. prime - 1``.

        Raises:
            ValueError: If no shares are given, the number of indices differs
                from the number of shares, indices repeat modulo the prime, or
                the share shapes differ.
        """
        share_arrays = [np.asarray(share, dtype=np.int64) for share in share_arrays]
        share_indices = list(share_indices)
        count = len(share_arrays)

        if count == 0:
            raise ValueError("at least one share is required")
        if len(share_indices) != count:
            raise ValueError("share_indices must contain one index per share")
        if len({int(index) % self.prime for index in share_indices}) != count:
            raise ValueError("share_indices must be distinct modulo the prime")

        shape = share_arrays[0].shape
        if any(share.shape != shape for share in share_arrays):
            raise ValueError("all shares must have the same dimensions")

        recovered = np.zeros(shape, dtype=np.int64)
        for position, share in enumerate(share_arrays):
            # Interpolating a one-hot y-vector at 0 yields this share's weight.
            one_hot = [0] * count
            one_hot[position] = 1
            weight = self._lagrange_interpolation(share_indices, one_hot, 0)
            recovered = (recovered + weight * share) % self.prime
        return recovered

    def share_image(self, image_path, n_shares, threshold):
        """
        Split an image into n shares where any threshold number of shares can
        reconstruct the image, but fewer than threshold shares reveal nothing.

        Channel values of ``prime`` or more are clamped to ``prime - 1``
        before sharing (251-255 become 250 with the default prime).

        Args:
            image_path: Path to the input image file
            n_shares: Number of shares to generate
            threshold: Minimum number of shares required for reconstruction

        Returns:
            List of 8-bit share images; entry ``i`` has share index ``i + 1``

        Raises:
            ValueError: If ``threshold < 1``, ``n_shares < threshold`` or
                ``n_shares >= prime``.
        """
        # Copy the pixels out and close the file handle straight away.
        with Image.open(image_path) as original_image:
            image_array = np.array(original_image)

        share_arrays = self._share_array(image_array, n_shares, threshold)

        # Share values are below the prime; the clip is a no-op for prime <= 256.
        return [
            Image.fromarray(np.clip(share_array, 0, 255).astype(np.uint8))
            for share_array in share_arrays
        ]

    def reconstruct_image(self, shares, share_indices=None):
        """
        Reconstruct the original image from a set of shares.

        Args:
            shares: List of share images
            share_indices: Indices of the shares if they're not in order 1,2,...

        Returns:
            Reconstructed image

        Raises:
            ValueError: If no shares are given, the number of indices differs
                from the number of shares, indices repeat, or the shares do
                not all have the same dimensions.
        """
        if share_indices is None:
            share_indices = list(range(1, len(shares) + 1))

        share_arrays = [np.array(share, dtype=np.int32) for share in shares]
        reconstructed = self._reconstruct_array(share_arrays, share_indices)

        # Clip values to the valid uint8 range before converting to an image
        reconstructed = np.clip(reconstructed, 0, 255).astype(np.uint8)
        return Image.fromarray(reconstructed)

# Example usage: share an image, then rebuild it from two and from all shares
if __name__ == "__main__":

    def _display(picture, caption):
        """Show one image in its own 5x5 inch figure with the axes hidden."""
        plt.figure(figsize=(5, 5))
        plt.imshow(picture)
        plt.title(caption)
        plt.axis('off')
        plt.show()

    # Sharing system with the default field modulus
    secret_sharing = ImageSecretSharing()

    # Demo parameters: input file, total share count, shares needed to rebuild
    image_path = r"C:\ASSignments\6th SEMESTER\MINI_PROJECT\P2\img1.webp"  # Replace with your image path
    n_shares = 3
    threshold = 2

    # Create the shares, then write each one to disk and display it
    print(f"Generating {n_shares} shares with threshold {threshold}...")
    shares = secret_sharing.share_image(image_path, n_shares, threshold)

    for i, share in enumerate(shares):
        share.save(f"share_{i+1}.png")
        _display(share, f"Share {i+1}")

    print("Reconstructing with different share combinations:")

    # Rebuild from exactly `threshold` shares: the first and the second
    min_indices = [1, 2]
    min_shares = [shares[0], shares[1]]
    reconstructed_min = secret_sharing.reconstruct_image(min_shares, min_indices)
    reconstructed_min.save("reconstructed_min.png")
    _display(reconstructed_min, "Reconstructed with minimum shares (1, 2)")

    # Rebuild from every share (indices default to 1, 2, ...)
    reconstructed_all = secret_sharing.reconstruct_image(shares)
    reconstructed_all.save("reconstructed_all.png")
    _display(reconstructed_all, "Reconstructed with all shares")

    # Display the untouched input for a visual comparison
    original = Image.open(image_path)
    _display(original, "Original Image")

    print("Process completed!")