# Do this first

Download images from Google Drive: https://drive.google.com/file/d/1pBO02iLgToBSCOyMJ58zWHQf4ZRkP5AY/view?usp=sharing

In [1]:
import torch
from torch import nn, tensor
from transformers import AutoFeatureExtractor, SegformerForSemanticSegmentation
import matplotlib.pyplot as plt
from torchvision.transforms import Resize
import torchvision.transforms as T
from torchvision.ops import masks_to_boxes
from torchvision.transforms.functional import crop
import time
import glob

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
image_paths = []
for image in glob.glob("./photos/**/*.jpg", recursive=True):
    image_paths.append(image)

In [2]:
from milvus import default_server
from pymilvus import utility, connections

In [4]:
default_server.start()



    __  _________ _   ____  ______
   /  |/  /  _/ /| | / / / / / __/
  / /|_/ // // /_| |/ / /_/ /\ \
 /_/  /_/___/____/___/\____/___/ {Lite}

 Welcome to use Milvus!

 Version:   v2.2.10-lite
 Process:   82293
 Started:   2023-09-28 08:53:28
 Config:    /Users/yujiantang/.milvus.io/milvus-server/2.2.10/configs/milvus.yaml
 Logs:      /Users/yujiantang/.milvus.io/milvus-server/2.2.10/logs

 Ctrl+C to exit ...


In [5]:
connections.connect(host="127.0.0.1", port=default_server.listen_port)

In [6]:
DIMENSION = 2048
BATCH_SIZE = 128
COLLECTION_NAME = "fashion"

In [7]:
# run this before importing the resnet50 model if you run into an SSL certificate URLError
import ssl
ssl._create_default_https_context = ssl._create_unverified_context

In [8]:
# Load the embedding model with the last layer removed
embeddings_model = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_resnet50', pretrained=True)
embeddings_model = torch.nn.Sequential(*(list(embeddings_model.children())[:-1]))
embeddings_model.eval()
extractor = AutoFeatureExtractor.from_pretrained("mattmdjaga/segformer_b2_clothes")
segmentation_model = SegformerForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes")


Using cache found in /Users/yujiantang/.cache/torch/hub/NVIDIA_DeepLearningExamples_torchhub


In [9]:
from pymilvus import FieldSchema, CollectionSchema, Collection, DataType

fields = [
    FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
    FieldSchema(name='filepath', dtype=DataType.VARCHAR, max_length=200),
    FieldSchema(name="name", dtype=DataType.VARCHAR, max_length=200),
    FieldSchema(name="seg_id", dtype=DataType.INT64),
    FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION)
]

schema = CollectionSchema(fields=fields, enable_dynamic_field=True)
collection = Collection(name=COLLECTION_NAME, schema=schema)

In [10]:
index_params = {
    "index_type": "IVF_FLAT",
    "metric_type": "L2",
    "params": {"nlist": 128},
}
collection.create_index(field_name="embedding", index_params=index_params)
collection.load()

In [11]:
from PIL import Image
import sys

In [12]:
wanted = [1, 3, 4, 5, 6, 7, 8, 9, 10, 16, 17]

In [13]:
def get_segmentation(image):
    inputs = extractor(images=image, return_tensors="pt")

    outputs = segmentation_model(**inputs)
    logits = outputs.logits.cpu()

    upsampled_logits = nn.functional.interpolate(
        logits,
        size=image.size[::-1],
        mode="bilinear",
        align_corners=False,
    )

    pred_seg = upsampled_logits.argmax(dim=1)[0]
    return pred_seg 

# returns two lists masks (tensor) and obj_ids (int)
# "mattmdjaga/segformer_b2_clothes" from hugging face
def get_masks(segmentation):
    obj_ids = torch.unique(segmentation)
    obj_ids = obj_ids[1:]
    wanted_ids = [x.item() for x in obj_ids if x in wanted]
    print(obj_ids)
    print(wanted_ids)
    wanted_ids = torch.Tensor(wanted_ids)
    print(wanted_ids)
    masks = segmentation == wanted_ids[:, None, None]
    return masks, obj_ids

def crop_images(masks, obj_ids, img):
    boxes = masks_to_boxes(masks)
    crop_boxes = []
    for box in boxes:
        crop_box = tensor([box[0], box[1], box[2]-box[0], box[3]-box[1]])
        crop_boxes.append(crop_box)
    
    preprocess = T.Compose([
        T.Resize(size=(256, 256)),
        T.ToTensor()
    ])
    
    cropped_images = []
    seg_ids = []
    for i in range(len(crop_boxes)):
        crop_box = crop_boxes[i]
        cropped = crop(img, crop_box[1].item(), crop_box[0].item(), crop_box[3].item(), crop_box[2].item())
        cropped_images.append(preprocess(cropped))
        seg_ids.append(obj_ids[i].item())
    with torch.no_grad():
        embeddings = embeddings_model(torch.stack(cropped_images)).squeeze().tolist()
    return embeddings, boxes.tolist(), seg_ids


In [14]:
for path in image_paths:
    image = Image.open(path)
    path_split = path.split("/")
    name = " ".join(path_split[2].split("_"))
    segmentation = get_segmentation(image)
    masks, ids = get_masks(segmentation)
    embeddings, crop_corners, seg_ids = crop_images(masks, ids, image)
    inserts = [{"embedding": embeddings[x],
                "seg_id": seg_ids[x],
                "name": name,
                "filepath": path,
                "crop_corner": crop_corners[x]} for x in range(len(embeddings))]
    collection.insert(inserts)
                
collection.flush()

tensor([ 2,  3,  4,  6,  8,  9, 10, 11, 12, 13, 14, 15, 16])
[3, 4, 6, 8, 9, 10, 16]
tensor([ 3.,  4.,  6.,  8.,  9., 10., 16.])
tensor([ 2,  3,  4,  6,  9, 10, 11, 14, 15, 16])
[3, 4, 6, 9, 10, 16]
tensor([ 3.,  4.,  6.,  9., 10., 16.])
tensor([ 2,  3,  4,  6,  9, 10, 11, 15, 16])
[3, 4, 6, 9, 10, 16]
tensor([ 3.,  4.,  6.,  9., 10., 16.])
tensor([ 2,  3,  4,  6, 10, 11, 14, 15, 16])
[3, 4, 6, 10, 16]
tensor([ 3.,  4.,  6., 10., 16.])
tensor([ 2,  3,  4,  7,  9, 10, 11, 12, 13, 14, 15])
[3, 4, 7, 9, 10]
tensor([ 3.,  4.,  7.,  9., 10.])
tensor([ 2,  3,  4,  6,  8,  9, 11, 12, 13, 14, 15, 16])
[3, 4, 6, 8, 9, 16]
tensor([ 3.,  4.,  6.,  8.,  9., 16.])
tensor([ 2,  3,  4,  6,  9, 10, 11, 12, 13, 14, 15, 16])
[3, 4, 6, 9, 10, 16]
tensor([ 3.,  4.,  6.,  9., 10., 16.])
tensor([ 2,  4,  6,  9, 10, 11, 14, 15, 16])
[4, 6, 9, 10, 16]
tensor([ 4.,  6.,  9., 10., 16.])
tensor([ 2,  3,  4,  6,  9, 10, 11, 14, 15, 16, 17])
[3, 4, 6, 9, 10, 16, 17]
tensor([ 3.,  4.,  6.,  9., 10., 16., 17.])
tens

### Querying the Vector DB

#### transform the input image

In [16]:
from pprint import pprint
from PIL import ImageDraw
from collections import Counter
import matplotlib.patches as patches

# Results display option, may need changing depending on system
%matplotlib auto

LIMIT = 5  # How many closes matches per article of clothing to analyze
CLOSEST = 3 # How many closest images to display. CLOSEST <= Limit

search_paths = ["./photos/Drake/Drake_6.jpg", "./photos/Rich_Brian/Rich_Brian_8.jpg"] # Images to search for

def get_cmap(n, name='hsv'):
    '''Returns a function that maps each index in 0, 1, ..., n-1 to a distinct 
    RGB color; the keyword argument name must be a standard mpl colormap name.
    Sourced from https://stackoverflow.com/questions/14720331/how-to-generate-random-colors-in-matplotlib'''
    return plt.cm.get_cmap(name, n)

# Create the result subplots
f, axarr = plt.subplots(max(len(search_paths), 2), CLOSEST + 1) 

for search_i, path in enumerate(search_paths):

    # Generate crops and embeddings for all items found
    image = Image.open(path)
    segmentation = get_segmentation(image)
    masks, ids = get_masks(segmentation)
    embeddings, crop_corners, _ = crop_images(masks, ids, image)

    # Generate color map
    cmap = get_cmap(len(crop_corners))

    # Display the first box with image being searched for
    axarr[search_i][0].imshow(image)
    axarr[search_i][0].set_title('Search Image')
    axarr[search_i][0].axis('off')
    for i, (x0, y0, x1, y1) in enumerate(crop_corners):
        rect = patches.Rectangle((x0, y0), x1-x0, y1-y0, linewidth=1, edgecolor=cmap(i), facecolor='none')
        axarr[search_i][0].add_patch(rect)
    print(embeddings)
    print(type(embeddings))
    # Search the database for all the crops
    start = time.time()
    res = collection.search(embeddings, 
       anns_field='embedding', 
       param={"metric_type": "L2",
              "params": {"nprobe": 10}, "offset": 0}, 
       limit=LIMIT, 
       output_fields=['filepath', 'crop_corner'],
       expr="seg_id=2")
    finish = time.time()

    print("Total Search Time: ", finish - start)

    # Summarize the top unique results and weight them based on position in results
    filepaths = []
    for hits in res:
        seen = set()
        for i, hit in enumerate(hits):
            if hit.entity.get("filepath") not in seen:
                seen.add(hit.entity.get("filepath"))
                filepaths.extend([hit.entity.get("filepath") for _ in range(len(hits) - i)])
    
    # Find the most commonly ranked result image
    counts = Counter(filepaths)
    most_common = [path for path, _ in counts.most_common(CLOSEST)]
    
    # For each image, extract the corresponding item found that correlates to search images
    matches = {}
    for i, hits in enumerate(res):
        matches[i] = {}
        tracker = set(most_common)
        for hit in hits:
            if hit.entity.get("filepath") in tracker:
                matches[i][hit.entity.get("filepath")] = hit.entity.get("crop_corner")
                tracker.remove( hit.entity.get("filepath"))
   
    # Display the most common images in results
    for res_i, res_path in enumerate(most_common):
        # Display each of the images next to search image
        image = Image.open(res_path)
        axarr[search_i][res_i+1].imshow(image)
        axarr[search_i][res_i+1].set_title(" ".join(res_path.split("/")[2].split("_")))
        axarr[search_i][res_i+1].axis('off')
        # Add boudning boxes for all matched items
        for key, value in matches.items():
            if res_path in value:
                x0, y0, x1, y1 = value[res_path]
                rect = patches.Rectangle((x0, y0), x1-x0, y1-y0, linewidth=1, edgecolor=cmap(key), facecolor='none')
                axarr[search_i][res_i+1].add_patch(rect)

Using matplotlib backend: MacOSX
tensor([ 2,  4,  6,  9, 10, 11, 12, 14, 16])
[4, 6, 9, 10, 16]
tensor([ 4.,  6.,  9., 10., 16.])


  return plt.cm.get_cmap(name, n)
RPC error: [search], <MilvusException: (code=1, message=failed to create query plan: cannot parse expression: seg_id = 4, error: line 1:7 token recognition error at: '= ')>, <Time:{'RPC start': '2023-09-29 08:28:24.573058', 'RPC error': '2023-09-29 08:28:24.579046'}>


[[0.20469701290130615, 0.038344189524650574, 0.09948088973760605, 0.020724598318338394, 0.15796852111816406, 0.017043832689523697, 0.343045175075531, 0.09837844222784042, 0.0583903007209301, 0.0, 0.6000745296478271, 0.009584173560142517, 0.03654799982905388, 0.12621265649795532, 0.03129909187555313, 0.02734660729765892, 0.007125388830900192, 0.07709243893623352, 0.07596631348133087, 0.02861732989549637, 0.220659077167511, 0.09153531491756439, 1.8977293968200684, 0.019687756896018982, 0.44103488326072693, 0.029935970902442932, 0.18328814208507538, 0.0912245586514473, 0.41969725489616394, 0.011468730866909027, 0.05022737383842468, 0.4268845021724701, 0.007519081234931946, 1.6234434843063354, 0.02128111943602562, 0.14525340497493744, 0.02235114946961403, 0.024397112429142, 0.4956035017967224, 0.1521209180355072, 0.3023081421852112, 0.03825685381889343, 0.045805856585502625, 1.0749483108520508, 0.023616760969161987, 0.014190636575222015, 0.8511552810668945, 0.05003640055656433, 0.935900509

MilvusException: <MilvusException: (code=1, message=failed to create query plan: cannot parse expression: seg_id = 4, error: line 1:7 token recognition error at: '= ')>

2023-09-29 08:28:25.894 Python[82194:2835835] *** Assertion failure in +[NSEvent otherEventWithType:location:modifierFlags:timestamp:windowNumber:context:subtype:data1:data2:], NSEvent.m:647
2023-09-29 08:28:25.916 Python[82194:2835835] *** Assertion failure in +[NSEvent otherEventWithType:location:modifierFlags:timestamp:windowNumber:context:subtype:data1:data2:], NSEvent.m:647
2023-09-29 08:28:25.930 Python[82194:2835835] *** Assertion failure in +[NSEvent otherEventWithType:location:modifierFlags:timestamp:windowNumber:context:subtype:data1:data2:], NSEvent.m:647
2023-09-29 08:28:25.943 Python[82194:2835835] *** Assertion failure in +[NSEvent otherEventWithType:location:modifierFlags:timestamp:windowNumber:context:subtype:data1:data2:], NSEvent.m:647
2023-09-29 08:28:25.956 Python[82194:2835835] *** Assertion failure in +[NSEvent otherEventWithType:location:modifierFlags:timestamp:windowNumber:context:subtype:data1:data2:], NSEvent.m:647
2023-09-29 08:28:25.968 Python[82194:2835835]

: 

In [16]:
f.savefig("fashion23_1.png")

In [5]:
if utility.has_collection(COLLECTION_NAME):
    utility.drop_collection(COLLECTION_NAME)

NameError: name 'COLLECTION_NAME' is not defined

In [6]:
default_server.stop()
default_server.cleanup()