In [1]:
import matplotlib.pyplot as plt
import numpy as np
from ksvd import ApproximateKSVD
# FIXME(review): the installed `ksvd` package has no `KSVD` (see the ImportError
# below); the PyPI `ksvd` package exports `ApproximateKSVD`. Delete the next line.
from ksvd import KSVD
from sklearn.decomposition import DictionaryLearning
from sklearn.feature_extraction.image import extract_patches_2d, reconstruct_from_patches_2d
from sklearn.linear_model import orthogonal_mp
from sklearn.metrics import mean_squared_error

ImportError: cannot import name 'KSVD' from 'ksvd' (d:\CODE\IDE\Anaconda\envs\punk\Lib\site-packages\ksvd\__init__.py)

## Step 1: Load and Prepare the Data

In [None]:
def generate_synthetic_data(shape=(64, 64), rng=None):
    """Create a pair of synthetic grayscale images for the demo.

    Both images are uniform noise in [0, 1); the first additionally gets a
    saturated 5x5 square at rows/cols 30:35 standing in for a house number.
    Replace this with real image loading in practice.

    Parameters
    ----------
    shape : tuple of int, default (64, 64)
        Shape of each generated image.
    rng : None, int, or numpy.random.Generator, optional
        Source of randomness. ``None`` draws from NumPy's global (legacy)
        random state exactly as ``np.random.rand`` would; an int seed or a
        ``Generator`` makes the output reproducible.

    Returns
    -------
    tuple of ndarray
        ``(house_design_with_number, house_design_without_number)``.
    """
    if rng is None:
        # Same stream as np.random.rand(*shape), so seeded legacy code is unaffected.
        draw = np.random.random_sample
    else:
        draw = np.random.default_rng(rng).random

    house_design_with_number = draw(shape)  # Example image with numbers
    house_design_without_number = draw(shape)  # Example image without numbers

    # Simulate adding a "number" to the image (a small patch with high intensity).
    # Slicing clips silently if `shape` is smaller than 35 pixels along an axis.
    house_design_with_number[30:35, 30:35] = 1.0

    return house_design_with_number, house_design_without_number

# Load synthetic data
img_with_number, img_without_number = generate_synthetic_data()

# Display the two training images side by side via the explicit Axes API.
fig, axes = plt.subplots(1, 2)
for ax, image, title in zip(
    axes,
    (img_with_number, img_without_number),
    ('With Number', 'Without Number'),
):
    ax.imshow(image, cmap='gray')
    ax.set_title(title)
plt.show()


## Step 2: Extract and Flatten Image Patches

In [None]:
patch_size = (8, 8)

# Cut every overlapping patch_size window out of each training image, then
# flatten each window into one row: arrays of shape (n_patches, h * w).
patches_with_number, patches_without_number = (
    windows.reshape(len(windows), -1)
    for windows in (
        extract_patches_2d(img_with_number, patch_size),
        extract_patches_2d(img_without_number, patch_size),
    )
)


## Step 3: Dictionary Learning Using K-SVD

In [None]:
n_components = 100  # Number of dictionary atoms

# Fit a separate K-SVD model per class: refitting one shared estimator would
# leave only the last-fitted dictionary in it. ApproximateKSVD.fit expects one
# flattened patch per row, i.e. (n_samples, n_features) -- no transpose -- and
# returns the fitted estimator rather than a tuple.
# NOTE(review): initialization may draw from NumPy's global RNG; seed it in a
# config cell if the dictionaries must be reproducible.
ksvd_with_number = ApproximateKSVD(n_components=n_components, max_iter=10).fit(patches_with_number)
ksvd_without_number = ApproximateKSVD(n_components=n_components, max_iter=10).fit(patches_without_number)

# components_ stores one atom per row; transpose so atoms are columns and a
# flattened patch is approximated as D @ code.
D_with_number = ksvd_with_number.components_.T
D_without_number = ksvd_without_number.components_.T

assert D_with_number.shape == (patches_with_number.shape[1], n_components)
assert D_without_number.shape == (patches_without_number.shape[1], n_components)

# D_with_number and D_without_number are the learned dictionaries


## Step 4: Classification Based on Reconstruction Error

In [None]:
def classify_image(test_image, D_with_number, D_without_number, patch_size, n_nonzero_coefs=None):
    """Label ``test_image`` by which dictionary reconstructs its patches better.

    Every overlapping ``patch_size`` patch is sparse-coded against *each*
    dictionary separately with Orthogonal Matching Pursuit, and the summed
    per-patch mean squared reconstruction errors are compared.

    Parameters
    ----------
    test_image : 2-D ndarray
        Grayscale image to classify.
    D_with_number, D_without_number : ndarray of shape (n_features, n_atoms)
        Dictionaries with one atom per column, where
        ``n_features == patch_size[0] * patch_size[1]``. OMP assumes the
        atoms (columns) have unit norm.
    patch_size : tuple of int
        ``(height, width)`` of the patches the dictionaries were learned on.
    n_nonzero_coefs : int, optional
        Maximum number of atoms per patch code. ``None`` uses scikit-learn's
        ``orthogonal_mp`` default (10% of the number of atoms).

    Returns
    -------
    str
        ``'With Number'`` if the with-number dictionary gives the strictly
        smaller total error, otherwise ``'Without Number'``.
    """
    patches = extract_patches_2d(test_image, patch_size)
    # One flattened patch per column: shape (n_features, n_patches).
    signals = patches.reshape(patches.shape[0], -1).T

    def total_reconstruction_error(dictionary):
        # Encode with *this* dictionary. Reusing one shared encoder for both
        # classes would make the comparison meaningless.
        codes = orthogonal_mp(
            dictionary, signals, n_nonzero_coefs=n_nonzero_coefs, precompute=True
        ).reshape(dictionary.shape[1], -1)  # orthogonal_mp squeezes singleton axes
        residual = signals - dictionary @ codes
        # Mean squared error of each patch, summed over all patches.
        return np.mean(residual ** 2, axis=0).sum()

    error_with_number = total_reconstruction_error(D_with_number)
    error_without_number = total_reconstruction_error(D_without_number)

    return 'With Number' if error_with_number < error_without_number else 'Without Number'

# Simulate an unseen image that contains a number, then label it.
new_image = generate_synthetic_data()[0]
classification_result = classify_image(
    test_image=new_image,
    D_with_number=D_with_number,
    D_without_number=D_without_number,
    patch_size=patch_size,
)
print("Classification Result:", classification_result)


## Step 5: Remove Numbers from the Image

In [None]:
def remove_numbers_from_image(test_image, D_without_number, patch_size, n_nonzero_coefs=None):
    """Re-synthesize ``test_image`` using only the "without number" dictionary.

    Each overlapping patch is sparse-coded against ``D_without_number`` with
    Orthogonal Matching Pursuit and replaced by its reconstruction. The
    patches are then averaged back into a full image. The intent is that
    content the dictionary cannot represent (the number) is suppressed.

    Parameters
    ----------
    test_image : 2-D ndarray
        Grayscale image to clean.
    D_without_number : ndarray of shape (n_features, n_atoms)
        Dictionary with one atom per column, where
        ``n_features == patch_size[0] * patch_size[1]``. OMP assumes the
        atoms (columns) have unit norm.
    patch_size : tuple of int
        ``(height, width)`` of the patches the dictionary was learned on.
    n_nonzero_coefs : int, optional
        Maximum number of atoms per patch code. ``None`` uses scikit-learn's
        ``orthogonal_mp`` default (10% of the number of atoms).

    Returns
    -------
    ndarray
        Reconstructed image with shape ``test_image.shape``.
    """
    patches = extract_patches_2d(test_image, patch_size)
    # One flattened patch per column: shape (n_features, n_patches).
    signals = patches.reshape(patches.shape[0], -1).T

    # Encode with the dictionary passed in, not a hidden global estimator.
    codes = orthogonal_mp(
        D_without_number, signals, n_nonzero_coefs=n_nonzero_coefs, precompute=True
    ).reshape(D_without_number.shape[1], -1)  # orthogonal_mp squeezes singleton axes

    # Back to one (height, width) patch per row for reconstruct_from_patches_2d.
    denoised_patches = (D_without_number @ codes).T.reshape(-1, *patch_size)

    # Overlapping pixels are averaged across the patches that cover them.
    return reconstruct_from_patches_2d(denoised_patches, test_image.shape)

# Remove numbers from the new image
denoised_image = remove_numbers_from_image(new_image, D_without_number, patch_size)

# Show the input next to its reconstruction so the effect can be judged.
fig, (ax_input, ax_denoised) = plt.subplots(1, 2)
ax_input.imshow(new_image, cmap='gray')
ax_input.set_title('Input')
ax_denoised.imshow(denoised_image, cmap='gray')
ax_denoised.set_title('Denoised Image (Number Removed)')
plt.show()