In [None]:
# Notebook for removing duplicate images from a webly dataset, resizing the remaining images, and saving them to a new directory.

In [None]:
# Imported here as well as in the setup cell: this cell executes first on a fresh
# kernel, and the base class must exist when the class statement runs.
from pathlib import Path

from torch.utils.data import Dataset


class WeblyNormDataset(Dataset):
    """Images from one search-term folder of a webly dataset.

    Expected layout: ``webly_root/<search>/<image files>``. Every immediate
    subdirectory of ``webly_root`` whose stem equals ``search_term``
    contributes its regular files (except ``.floyddata``).

    WARNING: ``__getitem__`` is destructive. A file that cannot be opened,
    decoded or transformed (any ``OSError``, which includes PIL's
    ``UnidentifiedImageError``) is deleted from disk.

    Args:
        webly_root: Root directory of the webly dataset (str or Path).
        search_term: Stem of the search folder to load.
        transform: Callable applied to each RGB ``PIL.Image``.
        dloader: If True, ``__getitem__`` returns only the transformed sample;
            otherwise it returns ``(sample, path)``.
    """

    def __init__(self, webly_root, search_term, transform, dloader=True):
        self.webly_root = Path(webly_root)
        self.transform = transform
        self.dloader = dloader
        self.items = []

        # Sorted so item order (and therefore which copy of a duplicate survives
        # downstream de-duplication) is reproducible across runs and filesystems.
        for search in sorted(self.webly_root.iterdir()):
            # A plain file with a matching stem would make iterdir() raise.
            if search.stem != search_term or not search.is_dir():
                continue
            for img in sorted(search.iterdir()):
                # Regular files only: a nested directory would fail to open and
                # then make unlink() in __getitem__ raise.
                if img.is_file() and img.stem != '.floyddata':
                    self.items.append(img)

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        """Load item ``idx``, convert it to RGB and apply ``transform``.

        Returns:
            ``sample`` if ``dloader`` else ``(sample, path)``. If the file cannot
            be read it is deleted and ``None`` (``dloader``) or ``(None, None)``
            (otherwise) is returned, so callers can skip it.

        Raises:
            IndexError: if ``idx`` is out of range (ends plain iteration).
        """
        path = self.items[idx]
        try:
            # The context manager closes the file handle; convert() returns a new,
            # fully loaded image that outlives it.
            with Image.open(path) as img:
                rgb = img.convert('RGB')
            sample = self.transform(rgb)
        except OSError:
            # UnidentifiedImageError subclasses OSError, so one handler covers both
            # non-image files and truncated/corrupt image data.
            path.unlink()
            print(f'Removed: {str(path)}')
            # Keep the failure value shaped like a normal return for this mode.
            return None if self.dloader else (None, None)
        if self.dloader:
            return sample
        return sample, path


        
class ListDataset(Dataset):
    """Dataset over an explicit list of image paths relative to ``root``.

    Args:
        root: Directory the entries of ``items`` are relative to (str or Path).
        items: Sequence of relative paths (str or Path).
        transform: Callable applied to each RGB ``PIL.Image``.
        dloader: If True, ``__getitem__`` returns only the transformed sample;
            otherwise it returns ``(sample, path)`` where ``path`` is the entry
            from ``items`` unchanged.
    """

    def __init__(self, root, items, transform, dloader=True):
        self.transform = transform
        # Normalised to Path so `root / path` also works when both are strings.
        self.root = Path(root)
        self.items = items
        self.dloader = dloader

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        """Load item ``idx`` as RGB and apply ``transform``."""
        path = self.items[idx]
        # The context manager closes the file handle; convert() returns a new,
        # fully loaded image that outlives it.
        with Image.open(self.root / path) as img:
            rgb = img.convert('RGB')
        sample = self.transform(rgb)
        if self.dloader:
            return sample
        return sample, path

In [None]:
import torch
from torch import nn
from torchvision.datasets import ImageFolder
from torch.utils.data import Dataset, DataLoader

from torchvision import transforms, models

from tqdm import tqdm
from pathlib import Path
from PIL import Image, UnidentifiedImageError

# Feature extractor: ImageNet-pretrained MobileNetV2 with its last child module
# dropped (for torchvision's MobileNetV2 that is the classifier head), so the model
# outputs a convolutional feature map instead of class logits.
# NOTE(review): `pretrained=True` is deprecated since torchvision 0.13 in favour of
# `weights=models.MobileNet_V2_Weights.IMAGENET1K_V1`; kept for older versions.
model = models.mobilenet_v2(pretrained=True)
model = nn.Sequential(*list(model.children())[:-1])
model.eval()  # inference mode for BatchNorm/Dropout

# Images are first squashed to degrade_size and then scaled back up to img_size,
# which caps their effective resolution at degrade_size x degrade_size.
degrade_size = 64
img_size = 128

# ImageNet channel statistics for the pretrained backbone.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Resizing only (PIL image in, PIL image out): used for the images saved to disk.
view_transform = transforms.Compose([
    transforms.Resize((degrade_size, degrade_size)),
    transforms.Resize((img_size, img_size)),
])

# Model input: exactly the same resizing as view_transform, then tensor conversion
# and normalisation, so the embedded images match the saved ones.
webly_transform = transforms.Compose([
    view_transform,
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

search_phrase = 'green_white_toothbrush'

In [None]:
# Get unique files: embed each image with the backbone, reduce the embedding to one
# scalar (its L2 norm), and treat images with identical norms as duplicates.
# NOTE: WeblyNormDataset deletes files it cannot read while iterating.
webly_root = Path('/home/justin/Desktop/webly-dataset')
dist = WeblyNormDataset(webly_root, search_phrase, webly_transform, dloader=False)
results = {}
with torch.no_grad():
    for sample, pth in tqdm(dist):
        if sample is None:  # unreadable file, already removed by the dataset
            continue
        # Items are webly_root/<search>/<file>, so this yields <search>/<file>.
        rel_path = pth.relative_to(webly_root)
        features = model(sample.unsqueeze(0))
        # Same value as torch.dist(zeros, features), without hard-coding the
        # feature-map shape (which changes with img_size).
        results[rel_path] = features.norm().item()

print(len(results))
# Keep one path per distinct norm (the last one seen wins), then invert back to
# path -> norm.
path_by_norm = {norm: path for path, norm in results.items()}
_results = {path: norm for norm, path in path_by_norm.items()}
print(len(_results))

In [None]:
# Save resized unique files.
dset = ListDataset(Path('/home/justin/Desktop/webly-dataset'), list(_results.keys()), view_transform, dloader=False)

root = Path('/home/justin/Desktop/cleaned-webly-dataset') / search_phrase
# parents=True also creates cleaned-webly-dataset/ on the first run; exist_ok keeps
# the cell re-runnable.
root.mkdir(parents=True, exist_ok=True)

for resized, rel_path in tqdm(dset):
    # view_transform only resizes, so `resized` is a PIL image; Pillow infers the
    # output format from the file extension of the name.
    resized.save(root / rel_path.name)