In [1]:
import torch
from torch import nn, tensor
from transformers import AutoFeatureExtractor, SegformerForSemanticSegmentation
import matplotlib.pyplot as plt
from torchvision.transforms import Resize
import torchvision.transforms as T
from torchvision.ops import masks_to_boxes
from torchvision.transforms.functional import crop
import time
import os
import glob

### Functions for Working with Image Data

### Putting the Images into the Vector DB

#### Looping through all the Images

In [2]:
image_paths = []
for image in glob.glob("./photos/**/*.jpg", recursive=True):
    image_paths.append(image)

In [3]:
from milvus import default_server
from pymilvus import utility, connections

In [4]:
default_server.start()



    __  _________ _   ____  ______
   /  |/  /  _/ /| | / / / / / __/
  / /|_/ // // /_| |/ / /_/ /\ \
 /_/  /_/___/____/___/\____/___/ {Lite}

 Welcome to use Milvus!

 Version:   v2.2.10-lite
 Process:   50053
 Started:   2023-06-22 17:38:11
 Config:    /Users/filiphaltmayer/.milvus.io/milvus-server/2.2.10/configs/milvus.yaml
 Logs:      /Users/filiphaltmayer/.milvus.io/milvus-server/2.2.10/logs

 Ctrl+C to exit ...


In [5]:
time.sleep(5)
connections.connect(host="127.0.0.1", port=default_server.listen_port)

In [6]:
DIMENSION = 2048
BATCH_SIZE = 128
COLLECTION_NAME = "fashion"

In [7]:
# run this before importing th resnet50 model if you run into an SSL certificate URLError
import ssl
ssl._create_default_https_context = ssl._create_unverified_context

In [8]:
import torch
# Load the embedding model with the last layer removed
embeddings_model = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_resnet50', pretrained=True)
embeddings_model = torch.nn.Sequential(*(list(embeddings_model.children())[:-1]))
embeddings_model.eval()
extractor = AutoFeatureExtractor.from_pretrained("mattmdjaga/segformer_b2_clothes")
segmentation_model = SegformerForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes")


Using cache found in /Users/filiphaltmayer/.cache/torch/hub/NVIDIA_DeepLearningExamples_torchhub
2023-06-22 17:38:20.336051: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [9]:
from pymilvus import FieldSchema, CollectionSchema, Collection, DataType

fields = [
    FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
    FieldSchema(name='filepath', dtype=DataType.VARCHAR, max_length=200),
    FieldSchema(name="name", dtype=DataType.VARCHAR, max_length=200),
    FieldSchema(name="seg_id", dtype=DataType.INT64),
    FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION)
]

schema = CollectionSchema(fields=fields, enable_dynamic_field=True)
collection = Collection(name=COLLECTION_NAME, schema=schema)

In [10]:
index_params = {
    "index_type": "IVF_FLAT",
    "metric_type": "L2",
    "params": {"nlist": 128},
}
collection.create_index(field_name="embedding", index_params=index_params)
collection.load()

In [11]:
from PIL import Image
import sys

In [12]:
def get_segmentation(image):
    inputs = extractor(images=image, return_tensors="pt")

    outputs = segmentation_model(**inputs)
    logits = outputs.logits.cpu()

    upsampled_logits = nn.functional.interpolate(
        logits,
        size=image.size[::-1],
        mode="bilinear",
        align_corners=False,
    )

    pred_seg = upsampled_logits.argmax(dim=1)[0]
    return pred_seg 

# returns two lists masks (tensor) and obj_ids (int)
# "mattmdjaga/segformer_b2_clothes" from hugging face
def get_masks(segmentation):
    obj_ids = torch.unique(segmentation)
    obj_ids = obj_ids[1:]
    masks = segmentation == obj_ids[:, None, None]
    return masks, obj_ids

def crop_images(masks, obj_ids, img):
    boxes = masks_to_boxes(masks)
    crop_boxes = []
    for box in boxes:
        crop_box = tensor([box[0], box[1], box[2]-box[0], box[3]-box[1]])
        crop_boxes.append(crop_box)
    
    preprocess = T.Compose([
        T.Resize(size=(256, 256)),
        T.ToTensor()
    ])
    preprocess_bounded = T.Compose([T.ToTensor()])
    
    cropped_images = []
    seg_ids = []
    for i in range(len(crop_boxes)):
        crop_box = crop_boxes[i]
        cropped = crop(img, crop_box[1].item(), crop_box[0].item(), crop_box[3].item(), crop_box[2].item())
        cropped_images.append(preprocess(cropped))
        seg_ids.append(obj_ids[i].item())
    with torch.no_grad():
        embeddings = embeddings_model(torch.stack(cropped_images)).squeeze().tolist()
    return embeddings, boxes.tolist(), seg_ids


In [13]:
for path in image_paths[:3]:
    image = Image.open(path)
    path_split = path.split("/")
    name = " ".join(path_split[2].split("_"))
    segmentation = get_segmentation(image)
    masks, ids = get_masks(segmentation)
    embeddings, crop_corners, seg_ids = crop_images(masks, ids, image)
    inserts = [{"embedding": embeddings[x], 
                "seg_id": seg_ids[x],
                "name": name,
                "filepath": path,
                "crop_corner": crop_corners[x]} for x in range(len(embeddings))]
    collection.insert(inserts)
                
collection.flush()

### Querying the Vector DB

#### transform the input image

In [26]:
from pprint import pprint
from PIL import ImageDraw
from collections import Counter
import matplotlib.patches as patches

%matplotlib auto

def get_cmap(n, name='hsv'):
    '''Returns a function that maps each index in 0, 1, ..., n-1 to a distinct 
    RGB color; the keyword argument name must be a standard mpl colormap name.'''
    return plt.cm.get_cmap(name, n)
# cmap = get_cmap(len(data))
# for i, (X, Y) in enumerate(data):
#    scatter(X, Y, c=cmap(i))

search_paths = ["./photos/Taylor_Swift/Taylor_Swift_3.jpg"]

for path in search_paths:
    image = Image.open(path)
    path_split = path.split("/")
    name = " ".join(path_split[2].split("_"))
    segmentation = get_segmentation(image)
    masks, ids = get_masks(segmentation)
    embeddings, crop_corners, _ = crop_images(masks, ids, image)
    start = time.time()
    print(len(embeddings))
    res = collection.search(embeddings, 
       anns_field='embedding', 
       param={"metric_type": "L2",
              "params": {"nprobe": 10}, "offset": 1}, 
       limit=5, 
       output_fields=['filepath', 'crop_corner'])
    finish = time.time()

    print(finish - start)

    filepaths = []
    for hits in res:
        seen = set()
        for i, hit in enumerate(hits):
            if hit.entity.get("filepath") not in seen:
                seen.add(hit.entity.get("filepath"))
                filepaths.extend([hit.entity.get("filepath") for _ in range(len(hits) - i)])
    
    counts = Counter(filepaths)
    most_common = [path for path, _ in counts.most_common(2)]
    
    matches = {}
    for i, hits in enumerate(res):
        matches[i] = {"search": crop_corners[i], "res": {}}
        tracker = set(most_common)
        for hit in hits:
            if hit.entity.get("filepath") in tracker:
                matches[i]["res"][hit.entity.get("filepath")] = hit.entity.get("crop_corner")
                tracker.remove( hit.entity.get("filepath"))
                if len(tracker) == 0:
                    continue
    
    # pprint(matches)
    cmap = get_cmap(len(crop_corners))
    # plt.axis('off')
    #subplot(r,c) provide the no. of rows and columns
    f, axarr = plt.subplots(1,3) 
    axarr[0].imshow(image)
    axarr[0].set_title("lol")
    axarr[0].axis('off')
    for i, (x0, y0, x1, y1) in enumerate(crop_corners):
        rect = patches.Rectangle((x0, y0), x1-x0, y1-y0, linewidth=1, edgecolor=cmap(i), facecolor='none')
        axarr[0].add_patch(rect)
    for i, x in enumerate(most_common):
        # use the created array to output your multiple images. In this case I have stacked 4 images vertically
        image = Image.open(x)
        axarr[i+1].imshow(image)
        axarr[i+1].set_title("lol")
        axarr[i+1].axis('off')
        for key, value in matches.items():
            if x in value["res"]:
                x0, y0, x1, y1 = value["res"][x]
                rect = patches.Rectangle((x0, y0), x1-x0, y1-y0, linewidth=1, edgecolor=cmap(key), facecolor='none')
                axarr[i+1].add_patch(rect)
        # for key, val in matches.items():
            
        


Using matplotlib backend: MacOSX
11
0.032022953033447266


  return plt.cm.get_cmap(name, n)


In [15]:
for index, result in enumerate(res):
    print(index)
    print(result)

0
["id: 442362938519978091, distance: 388.6242370605469, entity: {'filepath': './photos/Kendall_Jenner/Kendall_Jenner_7.jpg', 'crop_corner': [478.0, 118.0, 643.0, 302.0]}", "id: 442362938519978100, distance: 453.6371765136719, entity: {'filepath': './photos/Kendall_Jenner/Kendall_Jenner_6.jpg', 'crop_corner': [156.0, 222.0, 474.0, 531.0]}", "id: 442362938519978115, distance: 456.8822021484375, entity: {'filepath': './photos/Kendall_Jenner/Kendall_Jenner_4.jpg', 'crop_corner': [706.0, 183.0, 851.0, 499.0]}", "id: 442362938519978101, distance: 469.0942687988281, entity: {'filepath': './photos/Kendall_Jenner/Kendall_Jenner_6.jpg', 'crop_corner': [213.0, 438.0, 456.0, 993.0]}", "id: 442362938519978104, distance: 509.27178955078125, entity: {'filepath': './photos/Kendall_Jenner/Kendall_Jenner_6.jpg', 'crop_corner': [295.0, 113.0, 408.0, 461.0]}"]
1
["id: 442362938519978089, distance: 474.35711669921875, entity: {'filepath': './photos/Kendall_Jenner/Kendall_Jenner_7.jpg', 'crop_corner': [592

In [16]:
plt.imshow(Image.open(data_batch[2][0]))

NameError: name 'data_batch' is not defined

In [None]:
plt.imshow(Image.open(res[0][0].entity.filepath))

In [None]:
if utility.has_collection(COLLECTION_NAME):
    utility.drop_collection(COLLECTION_NAME)

In [None]:
default_server.stop()
default_server.cleanup()