In [None]:
# Attempt to rank ARID selective-search crops with scikit-learn's One-Class SVM, fitted on MobileNetV2 features of webly-dataset images.

In [None]:
import torch
from torch import nn
from torchvision.datasets import ImageFolder
from torch.utils.data import Dataset
from torchvision import transforms, models

from tqdm import tqdm
from pathlib import Path
from PIL import Image

# Feature extractor: pretrained MobileNetV2 with its classifier head removed,
# i.e. only the convolutional `features` stage, put in eval mode.
model = nn.Sequential(models.mobilenet_v2(pretrained=True).features)
model.eval()

img_size = 256

# Standard ImageNet normalisation statistics used by torchvision's pretrained models.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]


def _make_transform():
    """Resize to an `img_size` square, convert to a tensor and normalise with `mean`/`std`."""
    steps = [
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
    return transforms.Compose(steps)


# Web images and ARID test crops go through the same pipeline so that their
# features are comparable.
webly_transform = _make_transform()
test_transform = _make_transform()

In [None]:
class WeblyNormDataset(Dataset):
    """Images for one search term of a webly dataset, transformed for the model.

    Directory layout read by ``__init__``::

        webly_root/<engine>/<search_term>/images/<image files>

    Args:
        webly_root: Root directory of the webly dataset (str or Path).
        search_term: Stem of the per-engine search directory to collect.
        transform: Callable applied to each image after conversion to RGB.
        dloader: If True, ``__getitem__`` returns only the transformed sample;
            otherwise it returns ``(sample, path)``.
    """

    def __init__(self, webly_root, search_term, transform, dloader=True):
        self.webly_root = Path(webly_root)
        self.transform = transform
        self.items = []
        self.dloader = dloader

        # Sorted so the item order (and anything derived from it downstream,
        # e.g. which duplicate is kept) is reproducible across file systems.
        for engine in sorted(self.webly_root.iterdir()):
            # Skip the '.floyddata' metadata entry and any stray non-directory.
            if engine.stem == '.floyddata' or not engine.is_dir():
                continue
            for search in sorted(engine.iterdir()):
                if search.stem == search_term and search.is_dir():
                    self.items.extend(sorted((search / 'images').iterdir()))

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        # IndexError past the end is what lets plain `for x in dataset` stop.
        path = self.items[idx]
        # The context manager closes the file; convert() loads the pixels first.
        with Image.open(path) as img:
            sample = self.transform(img.convert('RGB'))
        return sample if self.dloader else (sample, path)


In [None]:
# Remove duplicates: images whose feature maps have exactly the same L2 norm
# are treated as duplicates, and only one of them is kept for SVM training.
webly_root = Path('/home/justin/Desktop/webly-dataset/')
dist = WeblyNormDataset(webly_root, 'yellow_lemon_fruit', webly_transform, dloader=False)
results = {}          # path relative to webly_root -> L2 norm of its feature map
feature_vectors = {}  # path relative to webly_root -> flattened feature map
with torch.no_grad():
    for i, pth in tqdm(dist):
        rel_path = pth.relative_to(webly_root)
        o = model(i.unsqueeze(0))
        # Equivalent to the distance from an all-zero tensor, without
        # hard-coding the feature-map shape (which depends on img_size).
        results[rel_path] = o.norm().item()
        feature_vectors[rel_path] = o.flatten()

print(len(results))
# One path per distinct norm (the last one seen), in first-seen order.
path_by_norm = {norm: path for path, norm in results.items()}
_results = {path: norm for norm, path in path_by_norm.items()}
print(len(_results))
assert _results, 'no webly images found for the search term'

# One-Class SVM training matrix: one row per de-duplicated image's features
# (previously `list(results)`, i.e. the path keys rather than features).
f = torch.stack([feature_vectors[p] for p in _results]).numpy()

In [None]:
from sklearn.svm import OneClassSVM

In [None]:
# One-Class SVM with an RBF kernel; every other hyper-parameter is left at the
# library default. NOTE(review): `nu` bounds the fraction of training points
# treated as outliers — consider tuning it for this data.
clf = OneClassSVM(kernel='rbf')

In [None]:
# Fit on the webly training features `f`, then label each training row
# (+1 inlier / -1 outlier) — the same two steps fit_predict performs.
clf.fit(f)
preds = clf.predict(f)

In [None]:
from arid import arid
from PIL import ImageDraw, Image
from IPython.display import display
wps = arid.get_wps("/home/justin/Desktop/arid-dataset")

import json
from pathlib import Path
import skimage.io

wp = wps[134]  # the single ARID entry inspected in this prototype

quality = 'single'
title = wp.get_title()
with open(f'/home/justin/Desktop/thesis/ss/ss-{quality}-{title}.json') as json_file:
    ss_data = json.load(json_file)


def _score_box(img, box):
    """Score one (x1, y1, x2, y2) box of `img` with the fitted One-Class SVM.

    Returns an annotation dict with 'id', 'coords' (the four box corners),
    'score' (the length-1 array returned by `clf.score_samples`) and 'colormap'.
    """
    x1, y1, x2, y2 = box[:4]
    crop = img.crop((x1, y1, x2, y2))
    with torch.no_grad():  # inference only: don't build an autograd graph per crop
        features = model(test_transform(crop).unsqueeze(0))
    score = clf.score_samples(features.reshape(1, -1).numpy())
    return {
        'id': 'test',
        'coords': [(x1, y1), (x2, y1), (x2, y2), (x1, y2)],
        'score': score,
        'colormap': 'YlGn',
    }


img_paths = wp.rgb_image_paths()
for img_path in img_paths:
    # Convert to RGB like WeblyNormDataset does, so the 3-channel Normalize in
    # test_transform applies and crop features match the training features.
    with Image.open(img_path) as src:
        img = src.convert('RGB')
    new_img_path = arid.annotation_path(img_path, 'selective-search')
    img_key = Path(img_path).stem

    annotations = [_score_box(img, box) for box in ss_data[img_key]['boxes']]

    # TODO: compare against ground truth from
    # wp.get_annotations(img_path.stem)['annotations'] (x, y, width, height).

    # Prototype: only the first image is scored; the next cell reads
    # `annotations`, `img` and `new_img_path` from this iteration.
    break


In [None]:
# Keep only boxes whose One-Class SVM score exceeds the threshold, unwrapping
# the length-1 score array into a scalar before passing them to annotate_img.
SCORE_THRESHOLD = 394  # cut-off on the raw clf.score_samples value; TODO justify
b = [
    {
        'id': dz['id'],
        'coords': dz['coords'],
        'colormap': dz['colormap'],
        'score': dz['score'][0],
    }
    for dz in annotations
    if dz['score'][0] > SCORE_THRESHOLD
]

img2 = img.copy()
arid.annotate_img(img2, new_img_path, b, save=False)
display(img2)
