# Setup

Gonna just use the Taylor Swift Images in this one

In [10]:
import os
from dotenv import load_dotenv

load_dotenv()

ZILLIZ_URI = os.getenv("ZILLIZ_URI")
ZILLIZ_TOKEN = os.getenv("ZILLIZ_TOKEN")

In [11]:
from pymilvus import utility, connections

In [12]:
connections.connect(uri=ZILLIZ_URI, token=ZILLIZ_TOKEN)

## Getting the Model

In [13]:
import torch
from torch import nn, tensor
from transformers import AutoFeatureExtractor, SegformerForSemanticSegmentation
import matplotlib.pyplot as plt
from torchvision.transforms import Resize
import torchvision.transforms as T
from torchvision.ops import masks_to_boxes
from torchvision.transforms.functional import crop
import time
import glob

In [14]:
# run this before importing the resnet50 model if you run into an SSL certificate URLError
import ssl
ssl._create_default_https_context = ssl._create_unverified_context

# Load the embedding model with the last layer removed
embeddings_model = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_resnet50', pretrained=True)
embeddings_model = torch.nn.Sequential(*(list(embeddings_model.children())[:-1]))
embeddings_model.eval()
extractor = AutoFeatureExtractor.from_pretrained("mattmdjaga/segformer_b2_clothes")
segmentation_model = SegformerForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes")


Using cache found in /Users/yujiantang/.cache/torch/hub/NVIDIA_DeepLearningExamples_torchhub


# Segment

In [16]:
image_paths = []
for image in glob.glob("./photos/Taylor_Swift/*.jpg", recursive=True):
    image_paths.append(image)

In [18]:
DIMENSION = 2048
BATCH_SIZE = 128
COLLECTION_NAME = "TSwizzleFashionComparison"

In [19]:
from pymilvus import FieldSchema, CollectionSchema, Collection, DataType

fields = [
    FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
    FieldSchema(name='filepath', dtype=DataType.VARCHAR, max_length=200),
    FieldSchema(name="name", dtype=DataType.VARCHAR, max_length=200),
    FieldSchema(name="seg_id", dtype=DataType.INT64),
    FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION)
]

schema = CollectionSchema(fields=fields, enable_dynamic_field=True)
collection = Collection(name=COLLECTION_NAME, schema=schema)

In [20]:
index_params = {
    "index_type": "IVF_FLAT",
    "metric_type": "L2",
    "params": {"nlist": 5},
}
collection.create_index(field_name="embedding", index_params=index_params)
collection.load()

In [21]:
from PIL import Image
import sys

## Functions

In [22]:
id2label= {
    0: "Background",
    1: "Hat",
    2: "Hair",
    3: "Sunglasses",
    4: "Upper-clothes",
    5: "Skirt",
    6: "Pants",
    7: "Dress",
    8: "Belt",
    9: "Left-shoe",
    10: "Right-shoe",
    11: "Face",
    12: "Left-leg",
    13: "Right-leg",
    14: "Left-arm",
    15: "Right-arm",
    16: "Bag",
    17: "Scarf"
  }
# want to extract keys: 1, 3, 4, 5, 6, 7, 8, 9, 10, 16, 17

In [23]:
wanted = [1, 3, 4, 5, 6, 7, 8, 9, 10, 16, 17]

In [24]:
def get_segmentation(image):
    inputs = extractor(images=image, return_tensors="pt")

    outputs = segmentation_model(**inputs)
    logits = outputs.logits.cpu()

    upsampled_logits = nn.functional.interpolate(
        logits,
        size=image.size[::-1],
        mode="bilinear",
        align_corners=False,
    )

    pred_seg = upsampled_logits.argmax(dim=1)[0]
    return pred_seg 

# returns two lists masks (tensor) and obj_ids (int)
# "mattmdjaga/segformer_b2_clothes" from hugging face
def get_masks(segmentation):
    obj_ids = torch.unique(segmentation)
    obj_ids = obj_ids[1:]
    wanted_ids = [x.item() for x in obj_ids if x in wanted]
    # print(obj_ids)
    # print(wanted_ids)
    wanted_ids = torch.Tensor(wanted_ids)
    # print(wanted_ids)
    masks = segmentation == wanted_ids[:, None, None]
    return masks, obj_ids

def crop_images(masks, obj_ids, img):
    boxes = masks_to_boxes(masks)
    crop_boxes = []
    for box in boxes:
        crop_box = tensor([box[0], box[1], box[2]-box[0], box[3]-box[1]])
        crop_boxes.append(crop_box)
    
    preprocess = T.Compose([
        T.Resize(size=(256, 256)),
        T.ToTensor()
    ])
    
    cropped_images = []
    seg_ids = []
    for i in range(len(crop_boxes)):
        crop_box = crop_boxes[i]
        cropped = crop(img, crop_box[1].item(), crop_box[0].item(), crop_box[3].item(), crop_box[2].item())
        cropped_images.append(preprocess(cropped))
        seg_ids.append(obj_ids[i].item())
    with torch.no_grad():
        embeddings = embeddings_model(torch.stack(cropped_images)).squeeze().tolist()
    return embeddings, boxes.tolist(), seg_ids


# Upload

In [25]:
for path in image_paths:
    image = Image.open(path)
    path_split = path.split("/")
    name = " ".join(path_split[2].split("_"))
    segmentation = get_segmentation(image)
    masks, ids = get_masks(segmentation)
    embeddings, crop_corners, seg_ids = crop_images(masks, ids, image)
    inserts = [{"embedding": embeddings[x],
                "seg_id": seg_ids[x],
                "name": name,
                "filepath": path,
                "crop_corner": crop_corners[x]} for x in range(len(embeddings))]
    collection.insert(inserts)
                
collection.flush()

IndexError: list index out of range

In [17]:
if utility.has_collection(COLLECTION_NAME):
    utility.drop_collection(COLLECTION_NAME)