In [11]:
%%writefile environment.yml
# Conda environment definition written out by this notebook.
# NOTE(review): the scripts written by this notebook import `streamlit`, which is
# not listed below — confirm the deployment target provides it.
name: TorchEnv
channels:
  # Channel that serves the pytorch / torchvision / cpuonly packages below.
  - pytorch
dependencies:
  - numpy
  - pandas
  - pillow
  - boto3
  - transformers
  - pytorch
  - torchvision
  # Selects the CPU-only PyTorch build (the scripts run on device 'cpu').
  - cpuonly 
  # Packages installed with pip instead of conda.
  - pip:
    - matplotlib
    - pinecone-client
    

Overwriting environment.yml


In [1]:
%%writefile Welcome.py
import streamlit as st
# Render a single greeting line on this page.
st.write('Welcome :)')

Writing Welcome.py


In [8]:
%%writefile pages/CLIP-visual-search.py
import streamlit as st
import os
import glob
import random
import shutil
import numpy as np
from PIL import Image
#import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch
import boto3
import io
from concurrent.futures import ThreadPoolExecutor
from transformers import CLIPProcessor, CLIPVisionModel

# Page heading (rendered as a level-2 markdown header).
st.write('## CLIP Visual Search')
# Device string used by CLIP_search when moving the processed query to a device.
device = 'cpu'
@st.cache_data
def read_file():
    """Load the gallery file list and the saved image feature object.

    Returns:
        tuple: ``(selected_files, image_encodings)``. ``selected_files`` holds
        every line of ``src/image_paths.txt`` with surrounding whitespace
        stripped; ``image_encodings`` is whatever ``src/image_features.pt``
        contains, loaded onto the CPU.
    """
    with open('src/image_paths.txt', 'r') as paths_file:
        selected_files = [raw_line.strip() for raw_line in paths_file]

    cpu = torch.device('cpu')
    image_encodings = torch.load('src/image_features.pt', map_location=cpu)
    return selected_files, image_encodings

@st.cache_resource
def initialize():
    """Load the CLIP vision model and its processor.

    Both are loaded from the ``openai/clip-vit-base-patch32`` checkpoint.

    Returns:
        tuple: ``(model, processor)``.
    """
    checkpoint = 'openai/clip-vit-base-patch32'
    vision_model = CLIPVisionModel.from_pretrained(checkpoint)
    clip_processor = CLIPProcessor.from_pretrained(checkpoint)
    return vision_model, clip_processor

@st.cache_resource
def initialize_s3():
    """Create a boto3 S3 client configured from Streamlit secrets.

    The AWS credential and region environment variables are set from
    ``st.secrets`` before the client is created.

    Returns:
        The client returned by ``boto3.client('s3')``.
    """
    # (environment variable, secret name) pairs, applied in this order.
    env_from_secret = (
        ('AWS_ACCESS_KEY_ID', 'key_id'),
        ('AWS_SECRET_ACCESS_KEY', 'key_secret'),
        ('AWS_DEFAULT_REGION', 'region'),
    )
    for env_name, secret_name in env_from_secret:
        os.environ[env_name] = st.secrets[secret_name]
    return boto3.client('s3')

def CLIP_search(image, top_k):
    """Rank the gallery encodings by cosine similarity to a query image.

    Uses the module-level ``processor``, ``model``, ``image_encodings`` and
    ``device``.

    Args:
        image: Query image handed to ``processor(images=...)``.
        top_k: Requested number of results. Values larger than the number of
            gallery encodings are clamped, because ``torch.topk`` raises when
            ``k`` exceeds the size of the dimension being searched.

    Returns:
        tuple: ``(indices, similarities)`` as NumPy arrays, ordered from most
        to least similar.
    """
    image_input = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        image_encoding = model(image_input['pixel_values'])
    # assumes image_encodings is (num_images, dim) so that dim=1 compares
    # feature vectors — TODO confirm against how the .pt file was produced
    similarity = F.cosine_similarity(image_encoding['pooler_output'], image_encodings, dim=1)
    k = min(int(top_k), similarity.shape[0])
    topk_values, topk_indices = torch.topk(similarity, k=k)
    # .cpu() is a no-op on CPU tensors and keeps .numpy() valid on other devices.
    return topk_indices.cpu().numpy(), topk_values.cpu().numpy()



def read_s3(file_paths):
    """Download objects from the S3 bucket and open each one as an image.

    Args:
        file_paths: Iterable of object keys inside the ``visualsearch7374``
            bucket.

    Returns:
        list: One ``Image.open`` result per key, in the same order as
        ``file_paths``.
    """
    bucket_name = 'visualsearch7374'
    s3 = initialize_s3()

    def read_image(file_path):
        # Fetch one object and return its raw bytes.
        return s3.get_object(Bucket=bucket_name, Key=file_path)['Body'].read()

    # Fetch all objects concurrently; map() keeps the input order.
    with ThreadPoolExecutor() as pool:
        image_contents = list(pool.map(read_image, file_paths))

    # Wrap every payload in an in-memory buffer and open it as an image.
    image_outputs = []
    for payload in image_contents:
        image_outputs.append(Image.open(io.BytesIO(payload)))
    return image_outputs

# Load the gallery file list / encodings and the CLIP model / processor.
selected_files, image_encodings = read_file()
model, processor = initialize()

# GUI
_, col1, _ = st.columns([1, 8, 1])
_, col2, _ = st.columns([1, 8, 1])
with col1:
    upload_method = st.radio("Select a way", ("From examples", "From local"))
    form = st.form(key='image-form')
    if upload_method == "From examples":
        image_input = form.selectbox('Select the image here:',
                                     ['src/test1.jpg', 'src/test2.jpg', 'src/test3.jpg'])
    else:
        # Falsy until the user actually uploads a file.
        image_input = form.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

    topK = form.number_input('result number',
                             min_value=0,
                             max_value=20,
                             value=10,
                             help='Number of images in searching results')
    submit = form.form_submit_button('Submit')
with col2:
    # Bind `image` up front so a Submit without any image cannot raise NameError.
    image = None
    if image_input:
        image = Image.open(image_input)
        st.image(image, width=400, caption='Uploaded image', use_column_width=False)
    if submit:
        if image is None:
            st.warning('Please select or upload an image before submitting.')
        else:
            topk_indices, topk_values = CLIP_search(image, topK)
            search_outputs = [selected_files[x] for x in topk_indices]
            #st.write(search_outputs) # not show image right now
            output_images = read_s3(search_outputs)
            # Lay the results out row-major across n columns.
            n = 5
            cols = st.columns(n)
            for position, (result_image, score) in enumerate(zip(output_images, topk_values)):
                with cols[position % n]:
                    st.image(result_image, caption=f'Similarity : {score:.2%}')
        

Overwriting pages/CLIP-visual-search.py


In [9]:
%%writefile pages/squeezenet-search.py
import streamlit as st
import torch
import numpy as np
from PIL import Image
import pinecone
from torchvision.transforms import (
    Compose, 
    Resize, 
    CenterCrop, 
    ToTensor, 
    Normalize
)

#Environment set up
pinekey = st.secrets['pinekey']
# Resize/crop to 224x224, convert to a tensor and normalise three colour channels.
preprocess = Compose([
    Resize(256),
    CenterCrop(224),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
INDEX_NAME = 'image-search-clothes'
# NOTE(review): this name was referenced but never defined. 1000 is the output
# size of torchvision's SqueezeNet classifiers — confirm it matches the
# dimension the existing index was created with.
INDEX_DIMENSION = 1000


@st.cache_resource
def load_model():
    """Load the embedding model used for queries, in eval mode.

    NOTE(review): the original script called an undefined ``model``. SqueezeNet
    1.1 with ImageNet weights is assumed from the page title — confirm it is
    the same network that produced the vectors stored in the index.
    """
    from torchvision.models import squeezenet1_1
    return squeezenet1_1(weights="IMAGENET1K_V1").eval()


st.write('## squeezenet Visual Search')
st.write('# Upload Image')

#PineCone Initialize
# authenticate with Pinecone API, keys available at your project at https://app.pinecone.io
pinecone.init(
    pinekey,
    environment="us-west4-gcp"  # find next to API key in console
)

if INDEX_NAME not in pinecone.list_indexes():
    pinecone.create_index(name=INDEX_NAME, dimension=INDEX_DIMENSION)
index = pinecone.Index(INDEX_NAME)

#Upload image
# NOTE(review): this uploader's value is never used; the form below has its own
# uploader. Kept so the page layout is unchanged — consider removing it.
uploaded_file = st.file_uploader("Choose an image file", type=['jpg', 'png', 'jpeg'])

_, col1, _ = st.columns([1, 8, 1])
with col1:
    upload_method = st.radio("Select a way", ("From examples", "From local"))
    form = st.form(key='image-form')
    if upload_method == "From examples":
        image_input = form.selectbox('Select the image here:',
                                     ['src/test1.jpg', 'src/test2.jpg', 'src/test3.jpg'])
    else:
        image_input = form.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

    topK = form.number_input('result number',
                             min_value=0,
                             max_value=20,
                             value=10,
                             help='Number of images in searching results')
    submit = form.form_submit_button('Submit')

    # Only query once the form is submitted; previously the query ran on every
    # rerun and failed because `image` and `model` were never defined.
    if submit:
        if not image_input:
            st.warning('Please select or upload an image before submitting.')
        elif topK < 1:
            st.warning('Please request at least one result.')
        else:
            # Normalize expects three channels, so force RGB (PNG uploads may carry alpha).
            image = Image.open(image_input).convert('RGB')
            model = load_model()

            #Pinepone response & process
            with torch.no_grad():
                query_embedding = model(preprocess(image).unsqueeze(0)).tolist()
            # Honour the "result number" widget instead of a hard-coded top_k.
            response = index.query(query_embedding, top_k=int(topK), include_metadata=True)

            #Process the image id and connecting to S3
            # presumably ids look like "<prefix>.<image id>"; verify against the
            # code that populated the index
            top_similar_imageId = [match['id'].split('.')[1] for match in response['matches']]
            print(top_similar_imageId)
    


Overwriting pages/squeezenet-search.py


In [40]:
# 5k images
# Read one image key per line, dropping surrounding whitespace.
with open('src/image_paths.txt', 'r') as paths_file:
    selected_files = [raw_line.strip() for raw_line in paths_file]
selected_files

['images/MEN-Shirts_Polos-id_00002750-01_3_back.jpg',
 'images/WOMEN-Blouses_Shirts-id_00000890-03_4_full.jpg',
 'images/MEN-Jackets_Vests-id_00005597-01_7_additional.jpg',
 'images/MEN-Tees_Tanks-id_00007301-04_4_full.jpg',
 'images/WOMEN-Tees_Tanks-id_00002512-03_1_front.jpg',
 'images/WOMEN-Jackets_Coats-id_00003012-04_2_side.jpg',
 'images/WOMEN-Blouses_Shirts-id_00007218-01_2_side.jpg',
 'images/WOMEN-Sweaters-id_00002437-02_1_front.jpg',
 'images/WOMEN-Skirts-id_00006463-05_4_full.jpg',
 'images/WOMEN-Graphic_Tees-id_00003821-02_1_front.jpg',
 'images/WOMEN-Tees_Tanks-id_00006259-02_2_side.jpg',
 'images/WOMEN-Dresses-id_00003582-05_1_front.jpg',
 'images/MEN-Sweaters-id_00003555-01_1_front.jpg',
 'images/WOMEN-Tees_Tanks-id_00003523-14_4_full.jpg',
 'images/MEN-Tees_Tanks-id_00005529-01_1_front.jpg',
 'images/WOMEN-Dresses-id_00002177-03_4_full.jpg',
 'images/WOMEN-Sweatshirts_Hoodies-id_00003587-01_7_additional.jpg',
 'images/WOMEN-Blouses_Shirts-id_00004604-04_2_side.jpg',
 'i

In [31]:
# 3 test images
import glob
import os

imagePath = 'src'
# Build the pattern with os.path.join: the hard-coded '\*.jpg' only matched on
# Windows and relied on an invalid escape sequence. Sorted for a stable order.
files = sorted(glob.glob(os.path.join(imagePath, '*.jpg')))
files

['src\\test1.jpg', 'src\\test2.jpg', 'src\\test3.jpg']

In [36]:
from PIL import Image
image = Image.open('src/test1.jpg') 