In [1]:
from torchvision.models import resnet50, ResNet50_Weights
from torchvision import transforms
import torch
from torch.utils.data import DataLoader

from main import ImagesManager
from data_manager import CustomDataset

  from .autonotebook import tqdm as notebook_tqdm


## Model and data loader

In [2]:
def load_model():
    """Build an ImageNet-pretrained ResNet-50 that returns features instead of logits.

    The final classification layer (``model.fc``) is replaced by an identity
    mapping. The forward pass therefore returns the tensor that would have been
    fed to ``fc``: the flattened, average-pooled backbone features.

    Returns:
        torch.nn.Module: the modified ``resnet50``, already in eval mode.
        It is not moved to any device here.
    """
    model = resnet50(weights=ResNet50_Weights.DEFAULT)
    # eval() makes BatchNorm use its running statistics, so the output for a
    # given image does not depend on the rest of the batch.
    model.eval()
    # torch.nn.Identity returns its input unchanged. It is the built-in
    # equivalent of the hand-written pass-through module and avoids defining
    # a new class on every call.
    model.fc = torch.nn.Identity()
    return model

In [3]:
def create_data_loader(path, model_input_shape, batch_size=32, shuffle=True):
    """Build a DataLoader over the dataset at ``path``.

    The transform handed to ``CustomDataset`` converts an image to a tensor
    and resizes it to a square ``model_input_shape`` x ``model_input_shape``.
    It presumably runs once per sample inside ``CustomDataset``; confirm there.

    Args:
        path: dataset location, passed unchanged to ``CustomDataset``.
        model_input_shape: side length (int) of the square resize target.
        batch_size: samples per batch (default 32, unchanged from before).
        shuffle: reshuffle the samples every epoch (default True, unchanged).
            NOTE(review): shuffling makes the visiting order differ between
            runs. If the downstream duplicate separation depends on order,
            pass ``shuffle=False``. Confirm in ImagesManager.

    Returns:
        torch.utils.data.DataLoader: loader over the ``CustomDataset``.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        # Both sides come from the parameter. Previously the width silently
        # read the notebook-global ``model_input_size`` instead.
        transforms.Resize((model_input_shape, model_input_shape)),
        # NOTE(review): ResNet50_Weights.DEFAULT.transforms() also applies
        # ImageNet mean/std normalization. Adding transforms.Normalize here
        # would change the embeddings and probably the similarity threshold.
    ])
    dataset = CustomDataset(path, transform)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

## Embed images and separate non-unique ones

In [4]:
# Configuration: the square side length images are resized to, the folder the
# images are read from, and the output folder passed to seperate_non_uniq.
model_input_size = 224
source_dir = "uncleaned_data"
target_dir = "cleaned_data"

# Feature extractor: ResNet-50 with its classification layer replaced.
model = load_model()

In [6]:
# Run on the GPU when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Batched loader over the source folder, then the manager that combines
# loader, model and device.
dataloader = create_data_loader(source_dir, model_input_size)
img_manager = ImagesManager(dataloader, model, DEVICE)

In [7]:
# Cut-off passed to seperate_non_uniq. Its exact meaning (e.g. the minimum
# similarity for two images to count as non-unique) is defined in
# ImagesManager. TODO confirm there.
SIMILARITY_THRESHOLD = 0.7

# These steps presumably build on state stored in img_manager by the previous
# call, so they are run in this fixed order.
img_manager.images_to_vec_data()
img_manager.vecs_to_tensor()
img_manager.count_cross_similarity()
# The spellings "seperate" / "treshold" belong to the project API. Keep them.
img_manager.seperate_non_uniq(target_dir, treshold=SIMILARITY_THRESHOLD)

100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:04<00:00,  1.02s/it]
100%|█████████████████████████████████████████████████████████████████████████████| 111/111 [00:00<00:00, 18499.14it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 112/112 [00:00<00:00, 266.03it/s]
112it [00:00, 333.33it/s]
