In [1]:
#| default_exp model_embedding.semantic_deduplication

# Remove semantically similar images
> Remove semantically similar images from training dataset

In [2]:
#| hide
%load_ext autoreload
%autoreload 2

In [3]:
#| export
from cv_tools.core import *
from cv_tools.imports import *


In [4]:
#| export
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from typing import List
from PIL import Image
import torchvision.transforms as transforms


In [11]:
from private_front_easy_pin_detection.pytorch_model_development import UnetManualMaxPoolOnly

In [5]:
import faiss

> Installing faiss was a huge pain; the conda command below is what finally worked:
- `conda install -c conda-forge faiss-gpu libblas=*=*mkl`

In [6]:
#| export
class ImageFeatureDataset(Dataset):
    """Dataset that loads images from disk as grayscale and returns them with their path.

    Args:
        image_paths: Paths of the images to load.
        transform: Optional callable applied to the grayscale PIL image. When
            omitted, the image is resized to ``image_size`` and converted with
            ``transforms.ToTensor()``.
        image_size: ``(height, width)`` used by the default transform only.
            Ignored when ``transform`` is given.

    Each item is a ``(image, path)`` tuple, where ``image`` is whatever the
    transform returns and ``path`` is the corresponding entry of ``image_paths``.
    """

    def __init__(self, image_paths: List[str], transform=None, image_size=(1152, 1632)):
        self.image_paths = image_paths
        self.transform = transform or transforms.Compose([
            transforms.Resize(image_size),
            # ToTensor already rescales 8-bit pixel values to [0, 1]; no manual
            # division by 255 is needed (or possible) on a PIL image.
            transforms.ToTensor(),
            #transforms.Normalize(mean=[0.485], std=[0.229])
        ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        # convert() returns a fully loaded copy, so the file handle can be
        # released right away instead of lingering until garbage collection.
        with Image.open(img_path) as raw:
            image = raw.convert('L')  # Convert to grayscale
        if self.transform:
            image = self.transform(image)
        return image, img_path

In [7]:
# Load the UNet and modify it to only use the encoder
class EmbeddingModel(torch.nn.Module):
    """Wrap a UNet so that only its encoder is used to produce embeddings.

    The encoder is expected to return five feature maps; the last one is
    globally average-pooled and flattened to one vector per sample.
    """

    def __init__(self, unet_model):
        super().__init__()
        # Keep a reference to the encoder only; the rest of the UNet is unused.
        self.encoder = unet_model.encoder

    def forward(self, x):
        # Strict five-way unpacking: only the deepest feature map is needed.
        _, _, _, _, deepest = self.encoder(x)
        # Collapse the spatial dimensions to 1x1, then drop them.
        squeezed = nn.functional.adaptive_avg_pool2d(deepest, (1, 1))
        return torch.flatten(squeezed, start_dim=1)

In [8]:
#| export
def extract_features(
    model: nn.Module,
    dataloader: DataLoader,
    device: torch.device,
    feature_dim: int = 256
) -> Tuple[np.ndarray, List[str]]:
    """Run `model` over every batch of `dataloader` and collect the outputs.

    Args:
        model: Module mapping an image batch to one feature vector per sample.
            It is switched to eval mode; it is not moved to `device` here.
        dataloader: Yields `(images, paths)` batches.
        device: Device the image batches are moved to before the forward pass.
        feature_dim: Expected length of each feature vector.

    Returns:
        `(features, paths)` where `features` is a float32 array of shape
        `(n_samples_seen, feature_dim)` and `paths[i]` belongs to `features[i]`.

    Raises:
        ValueError: If the model output width differs from `feature_dim`.
    """
    features = np.zeros((len(dataloader.dataset), feature_dim), dtype=np.float32)
    paths: List[str] = []
    # Running write position. Using `idx * dataloader.batch_size` instead breaks
    # when the loader is built from a batch_sampler (batch_size is None) or
    # yields batches of varying size.
    offset = 0

    model.eval()
    with torch.no_grad():
        for batch, batch_paths in tqdm(dataloader, desc="Extracting features"):
            batch = batch.to(device)
            batch_features = model(batch).cpu().numpy()
            # Tolerate outputs with trailing singleton/spatial dims.
            batch_features = batch_features.reshape(len(batch_features), -1)

            if batch_features.shape[1] != feature_dim:
                raise ValueError(
                    f"Model produced features of size {batch_features.shape[1]}, "
                    f"but feature_dim={feature_dim} was requested"
                )

            end = offset + len(batch_features)
            features[offset:end] = batch_features
            paths.extend(batch_paths)
            offset = end

    # Drop rows that were never filled (e.g. drop_last=True) so that features
    # and paths always stay aligned.
    return features[:offset], paths

In [9]:
def find_duplicates(
    features: np.ndarray,
    paths: List[str],
    similarity_threshold: float = 0.95,
    k: int = 50
) -> Dict[str, List[str]]:
    """Group near-duplicate images by cosine similarity using a FAISS flat index.

    Args:
        features: 2-D array with one feature vector per row. It is copied, so
            the caller's array is left untouched by the in-place normalisation.
        paths: Image path for each row of `features`.
        similarity_threshold: Two images count as duplicates when their cosine
            similarity is strictly greater than this value.
        k: Number of nearest neighbours inspected per image (clamped to the
            number of images).

    Returns:
        Mapping from a kept image path to the paths judged duplicates of it.
        Images without any duplicate do not appear in the mapping.

    Raises:
        ValueError: If `features` is not 2-D or its row count differs from
            `len(paths)`.
    """
    # FAISS needs a C-contiguous float32 matrix; np.array always copies, which
    # protects the caller's data from normalize_L2 (it works in place).
    features = np.array(features, dtype=np.float32, order="C")
    if features.ndim != 2:
        raise ValueError(f"features must be 2-D, got shape {features.shape}")
    if len(paths) != len(features):
        raise ValueError(
            f"Got {len(features)} feature rows but {len(paths)} paths"
        )
    if len(features) == 0:
        return {}

    # Normalize features so that inner product equals cosine similarity
    faiss.normalize_L2(features)

    # Create FAISS index
    index = faiss.IndexFlatIP(features.shape[1])  # Inner product for cosine similarity
    index.add(features)

    # Never ask for more neighbours than there are vectors.
    k = max(1, min(int(k), len(features)))
    similarities, indices = index.search(features, k=k)

    # Group duplicates: the first unprocessed image of a group is kept, and
    # every sufficiently similar, still-unclaimed neighbour is attached to it.
    duplicates: Dict[str, List[str]] = {}
    processed = set()

    for i in range(len(features)):
        if i in processed:
            continue

        current_duplicates = []
        for j, sim in zip(indices[i], similarities[i]):
            j = int(j)
            if j < 0:
                # FAISS pads with -1 when fewer than k neighbours exist.
                continue
            if sim > similarity_threshold and i != j and j not in processed:
                current_duplicates.append(paths[j])
                processed.add(j)

        if current_duplicates:
            duplicates[paths[i]] = current_duplicates
            processed.add(i)

    return duplicates

In [12]:
UnetManualMaxPoolOnly?

[0;31mInit signature:[0m [0mUnetManualMaxPoolOnly[0m[0;34m([0m[0min_channels[0m[0;34m,[0m [0mn_classes[0m[0;34m,[0m [0mboth_pool[0m[0;34m=[0m[0;32mFalse[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m     
Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing to nest them in
a tree structure. You can assign the submodules as regular attributes::

    import torch.nn as nn
    import torch.nn.functional as F

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their
parameters converted too when you call :meth:`to`, etc.

.. note::
    As per the example above, an ``__ini

In [10]:
def process_images(
    image_dir: str,
    output_dir: str,
    batch_size: int = 32,
    similarity_threshold: float = 0.95
) -> Tuple[Dict[str, List[str]], int]:
    """Find near-duplicate images under `image_dir` and copy them into `output_dir`.

    Args:
        image_dir: Directory searched recursively for .jpg/.jpeg/.png files.
        output_dir: Destination root; `unique/` and `duplicates/` are created
            inside it.
        batch_size: Batch size used for feature extraction.
        similarity_threshold: Cosine-similarity threshold passed to
            `find_duplicates`.

    Returns:
        `(duplicates, processed_count)`: the mapping returned by
        `find_duplicates` and the number of duplicate files copied.

    Notes:
        * Files are copied, never moved; the source directory is not modified.
        * Only images that have at least one duplicate are copied to `unique/`;
          images without duplicates are not copied anywhere.
        * Files are copied flat, so equal file names from different
          sub-directories overwrite each other in the output folders.
    """
    # Setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Sorted so that the choice of the "kept" image in each group is reproducible.
    image_paths = sorted(
        str(p)
        for p in Path(image_dir).glob("**/*")
        if p.is_file() and p.suffix.lower() in {'.jpg', '.jpeg', '.png'}
    )

    # Create output directory structure (makedirs also creates output_dir).
    unique_dir = os.path.join(output_dir, "unique")
    duplicate_dir = os.path.join(output_dir, "duplicates")
    os.makedirs(unique_dir, exist_ok=True)
    os.makedirs(duplicate_dir, exist_ok=True)

    if not image_paths:
        # Nothing to embed or compare.
        return {}, 0

    # Initialize model. n_classes is a required constructor argument; only the
    # encoder is used afterwards, so its value is presumably irrelevant here
    # -- TODO confirm against UnetManualMaxPoolOnly.
    # NOTE(review): no trained weights are loaded, so the embeddings come from
    # a randomly initialised encoder. Load a checkpoint here if available.
    encoder = UnetManualMaxPoolOnly(in_channels=1, n_classes=1, both_pool=False)
    feature_extractor = EmbeddingModel(encoder).to(device)

    # Create dataloader; pinned memory only helps (and only works) with CUDA.
    dataset = ImageFeatureDataset(image_paths)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=4,
        pin_memory=(device.type == "cuda"),
    )

    # Extract features
    features, paths = extract_features(feature_extractor, dataloader, device)

    # Find duplicates
    duplicates = find_duplicates(features, paths, similarity_threshold)

    # Copy files
    processed_count = 0
    for original, duplicate_list in duplicates.items():
        # Copy original to unique folder
        shutil.copy2(original, unique_dir)

        # Copy duplicates to duplicate folder
        for dup in duplicate_list:
            shutil.copy2(dup, duplicate_dir)
            processed_count += 1

    return duplicates, processed_count

In [4]:
#| hide
import nbdev; nbdev.nbdev_export('19_model_embedding.semantic_deduplication.ipynb')