### 허깅페이스 SentenceTransformer 준비

In [1]:
from sentence_transformers import SentenceTransformer
from PIL import Image
import matplotlib.pyplot as plt
import os
from torch import Tensor
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model = SentenceTransformer('clip-ViT-B-32')

### 크롤링한 메뉴 이미지 준비

In [3]:
menu_data_path = './datasets/차알/menu_차알'
label_list = os.listdir(menu_data_path)
for i, _ in enumerate(label_list):
    if _ == 'menu_images.csv':
        label_list.remove(_)
        continue
    label_list[i] = _.replace('.jpg','')
label_list

['트레이 짜장',
 '차알 순두부',
 '깐풍가지새우',
 '로제샹궈',
 '제네럴 쏘 치킨',
 '마라샹궈',
 '유린기 샐러드',
 '차알 볶음밥',
 '마라 짬뽕',
 '차우멘',
 '공심채 볶음',
 '마파두부',
 '차알 순두부 핫팟',
 '오렌지 치킨',
 '몽골리안 비프',
 '차알 탕수육']

In [4]:
menu_original_images = []

# 불러오는 이미지를 미리 보고 싶으시면 참조된 코드들을 해제하여
# 한번에 최대 18개까지 미리 볼 수 있다

# plt.figure(figsize=(16, 5))

for i, filename in enumerate([filename for filename in os.listdir(menu_data_path) if filename.endswith(".png") or filename.endswith(".jpg")]):
    name = os.path.splitext(filename)[0]
    image = Image.open(os.path.join(menu_data_path, filename)).convert("RGB")

    # plt.subplot(3, 6, i + 1)
    # plt.imshow(image)
    # plt.xticks([])
    # plt.yticks([])

    menu_original_images.append(image)

# CLIP으로 feature extraction
menu_encoded_images = model.encode(menu_original_images, batch_size=128, convert_to_tensor=True, show_progress_bar=True)


plt.tight_layout()

Batches: 100%|██████████| 1/1 [00:03<00:00,  3.31s/it]


<Figure size 640x480 with 0 Axes>

### 크롤링한 리뷰 이미지 준비

In [5]:
# 예시로 유린기 샐러드만 모아놓은 폴더를 지정했지만
# 자유롭게 경로 변경하여 생성 가능
test_data_path = './datasets/차알/crops/유린기 샐러드'

test_original_images = []

## 참조를 해제하면 불러오는 이미지들을 18개까지 미리 볼 수 있다

# plt.figure(figsize=(16, 5))
# start = 0
# end = 17
for i, filename in enumerate([filename for filename in os.listdir(test_data_path) if filename.endswith(".png") or filename.endswith(".jpg")]):
    # if i < start:    
    #   continue
    name = os.path.splitext(filename)[0]
    image = Image.open(os.path.join(test_data_path, filename)).convert("RGB")

    # plt.subplot(3, 6, i-start + 1) 
    # plt.imshow(image)
    # plt.xticks([])
    # plt.yticks([])

    test_original_images.append(image)

    # if i == end:
    #   break

# CLIP으로 feature extraction
test_encoded_images = model.encode(test_original_images, batch_size=128, convert_to_tensor=True, show_progress_bar=True)

# plt.tight_layout()

    

Batches: 100%|██████████| 1/1 [00:00<00:00,  4.17it/s]


### cosine 유사도 계산

In [6]:
def cos_sim(a: Tensor, b: Tensor):
    """
    Computes the cosine similarity cos_sim(a[i], b[j]) for all i and j.
    :return: Matrix with res[i][j]  = cos_sim(a[i], b[j])
    """
    if not isinstance(a, torch.Tensor):
        a = torch.tensor(a)

    if not isinstance(b, torch.Tensor):
        b = torch.tensor(b)

    if len(a.shape) == 1:
        a = a.unsqueeze(0)

    if len(b.shape) == 1:
        b = b.unsqueeze(0)

    a_norm = torch.nn.functional.normalize(a, p=2, dim=1)
    b_norm = torch.nn.functional.normalize(b, p=2, dim=1)
    return torch.mm(a_norm, b_norm.transpose(0, 1))

In [7]:
import queue

query_chunk_size = 5000
corpus_chunk_size = 100000
max_pairs = 500000
top_k = 1

pairs = queue.PriorityQueue()
min_score = -1
num_added = 0

score_function = cos_sim

for corpus_start_idx in range(0, len(menu_encoded_images), corpus_chunk_size):
    for query_start_idx in range(0, len(test_encoded_images), query_chunk_size):
        scores = score_function(test_encoded_images[query_start_idx:query_start_idx+query_chunk_size], menu_encoded_images[corpus_start_idx:corpus_start_idx+corpus_chunk_size])
        
        scores_top_k_values, scores_top_k_idx = torch.topk(scores, min(top_k, len(scores[0])), dim=1, largest=True, sorted=False)
        scores_top_k_values = scores_top_k_values.cpu().tolist()
        scores_top_k_idx = scores_top_k_idx.cpu().tolist()

        for query_itr in range(len(scores)):
            for top_k_idx, corpus_itr in enumerate(scores_top_k_idx[query_itr]):
                i = query_start_idx + query_itr
                j = corpus_start_idx + corpus_itr

                if i != j and scores_top_k_values[query_itr][top_k_idx] > min_score:
                    pairs.put((scores_top_k_values[query_itr][top_k_idx], i, j))
                    num_added += 1

                    if num_added >= max_pairs:
                        entry = pairs.get()
                        min_score = entry[0]

added_pairs = set() 
pairs_list = []
while not pairs.empty():
    score, i, j = pairs.get()
    sorted_i, sorted_j = sorted([i, j])

    if sorted_i != sorted_j and (sorted_i, sorted_j) not in added_pairs:
        added_pairs.add((sorted_i, sorted_j))
        pairs_list.append([score, i, label_list[j]])


pairs_list = sorted(pairs_list, key=lambda x: x[0], reverse=True)

scores_top_k_idx

[[6],
 [6],
 [6],
 [6],
 [6],
 [6],
 [6],
 [6],
 [0],
 [6],
 [6],
 [6],
 [6],
 [6],
 [6],
 [6],
 [6],
 [6],
 [6],
 [6],
 [6],
 [6],
 [6],
 [6],
 [6],
 [6],
 [6],
 [6],
 [6],
 [6],
 [0],
 [6],
 [6]]

### 준비한 리뷰 이미지 순서대로 결과 추론
출력 예시 : "메뉴이름"["확률"]

In [8]:
for i, j in zip(scores_top_k_idx, scores_top_k_values):
    print(label_list[i[0]]+str(j))

유린기 샐러드[0.8497395515441895]
유린기 샐러드[0.8475852012634277]
유린기 샐러드[0.7529000639915466]
유린기 샐러드[0.8216028213500977]
유린기 샐러드[0.8347015380859375]
유린기 샐러드[0.7983312606811523]
유린기 샐러드[0.8235374689102173]
유린기 샐러드[0.797694206237793]
트레이 짜장[0.7688065767288208]
유린기 샐러드[0.8366624116897583]
유린기 샐러드[0.8510055541992188]
유린기 샐러드[0.8327979445457458]
유린기 샐러드[0.828330934047699]
유린기 샐러드[0.8112948536872864]
유린기 샐러드[0.8266451358795166]
유린기 샐러드[0.778620719909668]
유린기 샐러드[0.8583508133888245]
유린기 샐러드[0.84061199426651]
유린기 샐러드[0.8388550281524658]
유린기 샐러드[0.8216028213500977]
유린기 샐러드[0.8248535990715027]
유린기 샐러드[0.7326924204826355]
유린기 샐러드[0.8160733580589294]
유린기 샐러드[0.7902628779411316]
유린기 샐러드[0.7529000639915466]
유린기 샐러드[0.8813169002532959]
유린기 샐러드[0.8281384110450745]
유린기 샐러드[0.8236373662948608]
유린기 샐러드[0.788112223148346]
유린기 샐러드[0.781449556350708]
트레이 짜장[0.8545013070106506]
유린기 샐러드[0.797694206237793]
유린기 샐러드[0.8236373662948608]


In [20]:
# 추가적으로 변형한 코드라 사용할 필요 없음
#####################


test_data_dir_path = './datasets/차알/crops/'
target_list = os.listdir(test_data_dir_path)
result = [[]]
score_function = cos_sim

for _, menu_name in enumerate(target_list):
    
    result.append([menu_name])

    test_data_path = test_data_dir_path + menu_name

    test_original_images = []

    for i, filename in enumerate([filename for filename in os.listdir(test_data_path) if filename.endswith(".png") or filename.endswith(".jpg")]):

        name = os.path.splitext(filename)[0]
        image = Image.open(os.path.join(test_data_path, filename)).convert("RGB")

        test_original_images.append(image)

    test_encoded_images = model.encode(test_original_images, batch_size=128, convert_to_tensor=True, show_progress_bar=True)
    result[_+1].append(len(test_encoded_images))

    query_chunk_size = 5000
    corpus_chunk_size = 100000
    max_pairs = 500000
    top_k = 1

    pairs = queue.PriorityQueue()
    min_score = -1
    num_added = 0

    for corpus_start_idx in range(0, len(menu_encoded_images), corpus_chunk_size):
        for query_start_idx in range(0, len(test_encoded_images), query_chunk_size):
            scores = score_function(test_encoded_images[query_start_idx:query_start_idx+query_chunk_size], menu_encoded_images[corpus_start_idx:corpus_start_idx+corpus_chunk_size])
            
            scores_top_k_values, scores_top_k_idx = torch.topk(scores, min(top_k, len(scores[0])), dim=1, largest=True, sorted=False)
            scores_top_k_values = scores_top_k_values.cpu().tolist()
            scores_top_k_idx = scores_top_k_idx.cpu().tolist()
    
    count = 0
    if menu_name == 'etc':
        for j in scores_top_k_values:
            if j[0] < 0.71:
                count += 1
        result[_+1].append(count)
        continue

    answer = label_list.index(menu_name)
    for j in scores_top_k_idx:
        if j[0] == answer:
            count += 1
    result[_+1].append(count)

Batches: 100%|██████████| 1/1 [00:00<00:00, 20.22it/s]


ValueError: '마라샹궈' is not in list

In [85]:
# 돌릴 필요 없음
#####################
result

[[],
 ['마라샹궈', 5, 0],
 ['마파두부', 3, 3],
 ['차알 볶음밥', 37, 32],
 ['차우멘', 20, 11],
 ['차알 순두부', 10, 4],
 ['깐풍가지새우', 28, 13],
 ['로제샹궈', 21, 19],
 ['마라 짬뽕', 52, 29],
 ['제네럴 쏘 치킨', 4, 4],
 ['차알 탕수육', 12, 6],
 ['오렌지 치킨', 4, 3],
 ['유린기 샐러드', 33, 31],
 ['차알 순두부 핫팟', 7, 3],
 ['etc', 16, 6],
 ['몽골리안 비프', 5, 5],
 ['트레이 짜장', 39, 31]]