In [1]:
import os
import shutil
import uuid
from collections import defaultdict
from typing import List, Tuple, Dict, NoReturn, DefaultDict, Set

import faiss
from faiss import IndexBinaryFlat, IndexFlatL2
from matplotlib import pyplot as plt
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.models import resnet50, ResNet

In [2]:
class ImageDs(Dataset):
    """Base dataset that keeps a list of image paths and provides letterboxing helpers.

    Subclasses are expected to supply ``__len__`` and ``__getitem__``.
    """

    def __init__(self, filenames: List[str]):
        self._filenames: List[str] = filenames

    @staticmethod
    def add_pad(img, shape):
        """Centre ``img`` on a canvas of size ``shape`` filled with its top-left pixel value.

        :param img: image array indexed as (rows, cols[, channels]).
        :param shape: tuple ``(rows, cols)`` of the canvas; the channel axis of
            ``img`` (if any) is appended to it.
        :return: the padded canvas with ``img`` pasted in the middle.
        """
        fill_value = img[0][0]
        canvas = fill_value * np.ones(shape + img.shape[2:3], dtype=np.uint8)
        src_rows, src_cols = img.shape[0], img.shape[1]
        # Split the spare space evenly; int() truncates any half pixel.
        top = int((canvas.shape[0] - src_rows) / 2)
        left = int((canvas.shape[1] - src_cols) / 2)
        canvas[top:top + src_rows, left:left + src_cols] = img
        return canvas

    @staticmethod
    def resize(img, shape):
        """Scale ``img`` uniformly so that it fits inside ``shape`` (aspect ratio kept).

        :param img: image array indexed as (rows, cols[, channels]).
        :param shape: target ``(rows, cols)`` bounding box.
        :return: the rescaled image, or ``img`` itself when the scale factor is exactly 1.
        """
        row_ratio = shape[0] * 1.0 / img.shape[0]
        col_ratio = shape[1] * 1.0 / img.shape[1]
        factor = min(row_ratio, col_ratio)
        if factor == 1:
            return img
        return cv2.resize(img, dsize=None, fx=factor, fy=factor, interpolation=cv2.INTER_LINEAR)

In [3]:
class TestDs(ImageDs):
    """Dataset that loads each file as a letterboxed RGB float tensor.

    The constructor is inherited from ``ImageDs`` and takes the list of file paths.
    """

    # Side length (rows, cols) every image is resized and padded to.
    _TARGET_SHAPE: Tuple[int, int] = (224, 224)

    def __len__(self) -> int:
        """Return the number of file paths held by the dataset."""
        return len(self._filenames)

    def __getitem__(self, idx) -> torch.Tensor:
        """Load, letterbox and normalise the image at position ``idx``.

        :param idx: position in the list of file paths.
        :return: float tensor in channel-first layout (the array is permuted
            with ``(2, 0, 1)``) with values divided by 255.
        :raises ValueError: if OpenCV cannot read the file.
        """
        filename: str = self._filenames[idx]
        img = cv2.imread(filename)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising;
            # fail here with the offending path rather than deep inside resize().
            raise ValueError(f"cv2.imread could not read an image from {filename!r}")
        img = self.resize(img, self._TARGET_SHAPE)
        img = self.add_pad(img, self._TARGET_SHAPE)
        # OpenCV decodes to BGR channel order; convert to RGB.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        return torch.tensor(img, dtype=torch.float).permute(2, 0, 1) / 255.

In [6]:
class ImageHasher:
    """Groups similar images of one directory into freshly created sub-folders.

    Pipeline (see ``run_job``): every regular file in the directory is fed
    through a pretrained ResNet-50, the output vectors are stored in a faiss
    L2 index, images whose distance is below the threshold are linked into
    buckets, and every bucket with more than one file is moved into its own
    UUID-named folder inside the directory.
    """

    def __init__(self, path_to_dir: str, hash_size: int, image_distance_threshold: int):
        """
        :param path_to_dir: directory whose images are compared and regrouped.
        :param hash_size: stored on the instance; not read by any method of this class.
        :param image_distance_threshold: two images are linked when the distance
            reported by the faiss index is strictly below this value.
        """
        # Fall back to the CPU so the job also runs on machines without CUDA.
        self._device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        if self._device.type == 'cuda':
            torch.cuda.empty_cache()
        self._model: ResNet = resnet50(pretrained=True)
        self._hash_size: int = hash_size
        self._index: IndexFlatL2 = self._load_index()
        self._distance_threshold: int = image_distance_threshold
        self._path_to_dir: str = path_to_dir

    def run_job(self) -> DefaultDict[int, Set[str]]:
        """Run the whole pipeline on ``path_to_dir``.

        :return: the buckets that were computed (bucket id -> set of original
            file paths), including single-image buckets that were left in place.
        """
        # Keep regular files only: sub-folders (e.g. the ones created by an
        # earlier run) cannot be decoded as images. Sorting makes the
        # processing order, and hence the bucket ids, reproducible.
        candidates = (os.path.join(self._path_to_dir, name)
                      for name in sorted(os.listdir(self._path_to_dir)))
        file_names: List[str] = [path for path in candidates if os.path.isfile(path)]
        if not file_names:
            return defaultdict(set)
        vectors = self._calculate_vectors(file_names)
        hashes = self._calculate_hashes(vectors)
        buckets = self._select_sets(hashes, file_names)
        self._perform_compaction(buckets)
        return buckets

    def _calculate_vectors(self, file_names: List[str]):
        """Run the model over all files and return one CPU output tensor per batch."""
        vectors = []
        self._prepare_model()
        batch_loader: DataLoader = self._get_loader(file_names)
        # no_grad: pure inference, so no autograd graph must be kept alive.
        with torch.no_grad():
            for batch in batch_loader:
                # Move results off the device right away so device memory does
                # not grow with the number of batches.
                vectors.append(self._model(batch.to(self._device)).cpu())
        return vectors

    def _calculate_hashes(self, vectors):
        """Add every output row to the index and return the rows as ``(1, d)`` float32 arrays.

        :param vectors: iterable of 2-D tensors (one row per image).
        :return: list with one ``(1, d)`` float32 array per image, in input order.
        """
        hashes = []
        for batch in vectors:
            rows = np.ascontiguousarray(batch.detach().cpu().numpy(), dtype=np.float32)
            if rows.shape[0] == 0:
                continue
            self._index.add(rows)
            hashes.extend(rows[i:i + 1] for i in range(rows.shape[0]))
        return hashes

    def _select_sets(self, hashes, photos_names: List[str]) -> DefaultDict[int, Set[str]]:
        """Link images whose index distance is below the threshold into buckets.

        Images are visited in order. Each not-yet-assigned neighbour of the
        current image joins the bucket the current image belongs to, so chains
        (A~B, B~C) end up in a single bucket.

        :param hashes: one ``(1, d)`` query array per image; they are assumed to
            be the vectors most recently added to the index, in the same order
            as ``photos_names``.
        :param photos_names: file path of each image.
        :return: mapping bucket id (position of the bucket's first image) ->
            set of file paths.
        """
        # position of an image -> id of the bucket it was assigned to
        used_images: Dict[int, int] = {}
        batches: DefaultDict[int, Set[str]] = defaultdict(set)
        # Index ids of this run start after any vectors the index already held.
        offset: int = self._index.ntotal - len(hashes)
        for position, img_hash in enumerate(hashes):
            # Always attach to the bucket root, never to an intermediate image,
            # otherwise a chain of neighbours is split across several buckets.
            root: int = used_images.get(position, position)
            fresh: List[int] = []
            for index_id, distance in self._check_duplicate(img_hash):
                local = index_id - offset
                # Skips faiss' -1 padding and vectors that are not part of this run.
                if local < 0 or local >= len(photos_names):
                    continue
                if local in used_images or not distance < self._distance_threshold:
                    continue
                fresh.append(local)
            batches[root].update(photos_names[local] for local in fresh)
            used_images.update((local, root) for local in fresh)
        return batches

    def _perform_compaction(self, buckets: DefaultDict[int, Set[str]]):
        """Move every bucket holding more than one file into a new UUID-named sub-folder."""
        for batch in buckets.values():
            if len(batch) > 1:
                new_folder = os.path.join(self._path_to_dir, uuid.uuid4().hex)
                os.mkdir(new_folder)
                for file_path in batch:
                    # basename handles the platform's path separator.
                    file_name = os.path.basename(file_path)
                    shutil.move(file_path, os.path.join(new_folder, file_name))

    def _prepare_model(self) -> None:
        """Freeze the model, switch it to inference mode and move it to the device."""
        for param in self._model.parameters():
            param.requires_grad = False
        # eval(): batch-norm layers must use their stored statistics, otherwise
        # the output for an image depends on the other images in its batch.
        self._model.eval()
        self._model.to(self._device)

    def _check_duplicate(self, img_hash: np.ndarray) -> List[Tuple[int, float]]:
        """Return ``(index id, distance)`` for every vector in the index, as ordered by faiss.

        :param img_hash: query array of shape ``(1, d)``.
        """
        # Ask for as many neighbours as the index holds; this avoids listing
        # the directory once per query.
        D, I = self._index.search(img_hash, max(1, self._index.ntotal))
        return [(int(index_id), float(distance)) for index_id, distance in zip(I[0], D[0])]

    @staticmethod
    def _get_loader(file_names: List[str]):
        """Build the batch loader (25 images per batch, loaded in the main process)."""
        return DataLoader(TestDs(file_names), batch_size=25, num_workers=0)

    def _load_index(self, filename: str = 'faiss_index') -> 'IndexFlatL2':
        """Load the index stored at ``f'{filename}_{d}'`` or create an empty L2 index.

        :param filename: prefix of the index file; the vector size is appended.
        :return: the loaded index when it can be read and has the expected
            vector size, otherwise a new empty ``IndexFlatL2``.
        """
        # Vector size the index is built for; must equal the model output width.
        d: int = 1000
        try:
            # read_index (not read_index_binary): the vectors are float32 and
            # the fallback below is a float L2 index.
            index = faiss.read_index(f'{filename}_{d}')
        except RuntimeError:
            return faiss.IndexFlatL2(d)
        if index.d != d:
            return faiss.IndexFlatL2(d)
        return index

In [7]:
# Build the hasher for the test image folder and run the grouping job.
hasher: ImageHasher = ImageHasher(
    path_to_dir="images/NeuralTest/",
    hash_size=16,
    image_distance_threshold=2700,
)
v = hasher.run_job()

(1, 1000)
<class 'numpy.ndarray'>
(1, 1000)
<class 'numpy.ndarray'>
(1, 1000)
<class 'numpy.ndarray'>
(1, 1000)
<class 'numpy.ndarray'>
(1, 1000)
<class 'numpy.ndarray'>
(1, 1000)
<class 'numpy.ndarray'>
(1, 1000)
<class 'numpy.ndarray'>
(1, 1000)
<class 'numpy.ndarray'>
(1, 1000)
<class 'numpy.ndarray'>
(1, 1000)
<class 'numpy.ndarray'>
(1, 1000)
<class 'numpy.ndarray'>
(1, 1000)
<class 'numpy.ndarray'>
(1, 1000)
<class 'numpy.ndarray'>
(1, 1000)
<class 'numpy.ndarray'>
(1, 1000)
<class 'numpy.ndarray'>
(1, 1000)
<class 'numpy.ndarray'>
(1, 1000)
<class 'numpy.ndarray'>
(1, 1000)
<class 'numpy.ndarray'>
(1, 1000)
<class 'numpy.ndarray'>
(1, 1000)
<class 'numpy.ndarray'>
(1, 1000)
<class 'numpy.ndarray'>
(1, 1000)
<class 'numpy.ndarray'>
