In [92]:
import torch
import torchvision.transforms as transforms
from torchvision.models import resnet50
from PIL import Image
import numpy as np
import os
from sklearn.metrics.pairwise import cosine_similarity

In [93]:
# 환경 설정
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# YOLOv5 모델 로드
yolo_model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
yolo_model.to(device).eval()

# ResNet 모델 로드 및 특징 추출 레이어 설정
resnet_model = resnet50(pretrained=True)
resnet_model = torch.nn.Sequential(*list(resnet_model.children())[:-1])  # 마지막 FC layer 제거
resnet_model.to(device).eval()

# 변환 함수 정의
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def extract_features(image, model):
    image = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        features = model(image).cpu().numpy().flatten()
    return features

Using cache found in C:\Users\mkmy7/.cache\torch\hub\ultralytics_yolov5_master
YOLOv5  2024-5-27 Python-3.11.9 torch-2.3.0 CUDA:0 (NVIDIA GeForce GTX 1050, 2048MiB)

Fusing layers... 
YOLOv5s summary: 213 layers, 7225885 parameters, 0 gradients, 16.4 GFLOPs
Adding AutoShape... 


In [94]:
# 이미지 로드 함수 추가
def load_image(image_path):
    try:
        image = Image.open(image_path).convert('RGB')
        return image
    except IOError:
        return None

# 이미지에서 사람 탐지 및 상의/하의 추출
def detect_and_extract(image_path):
    image = load_image(image_path)
    if image is None:
        print(f"이미지 로드 실패: {image_path}")
        return []

    image_np = np.array(image)
    results = yolo_model(image_np)
    
    clothes_images = []
    
    for detection in results.xyxy[0]:
        x1, y1, x2, y2, conf, cls = detection
        if cls == 0:  # 사람 클래스
            person_image = image.crop((int(x1), int(y1), int(x2), int(y2)))
            height = int(y2) - int(y1)
            top_image = person_image.crop((0, 0, person_image.width, height // 2))
            bottom_image = person_image.crop((0, height // 2, person_image.width, height))
            clothes_images.append((top_image, 'top'))
            clothes_images.append((bottom_image, 'bottom'))
    
    return clothes_images

In [95]:
# 디렉토리에서 옷 이미지 불러오기 및 특징 추출
def load_clothes_images(directory):
    clothes_features = []
    clothes_paths = []
    
    files = os.listdir(directory)
    print(f"디렉토리 내 파일 목록: {files}")
    
    for filename in files:
        if filename.endswith(('.jpg', '.jpeg', '.png')):
            image_path = os.path.join(directory, filename)
            image = load_image(image_path)
            if image is None:
                print(f"이미지 로드 실패: {image_path}")
                continue
            features = extract_features(image, resnet_model)
            if features is None:
                print(f"특징 추출 실패: {image_path}")
                continue
            clothes_features.append(features)
            clothes_paths.append(image_path)
    
    return clothes_features, clothes_paths

In [96]:
# 유사도 계산 및 top 3 유사한 옷 찾기
def find_similar_clothes(clothes_images, clothes_features, clothes_paths):
    similar_clothes = []

    for clothes_image, clothes_type in clothes_images:
        query_features = extract_features(clothes_image, resnet_model)
        
        if query_features is None:
            print(f"특징 추출 실패: {clothes_image}")
            continue
        
        query_features = query_features.reshape(1, -1)
        
        if len(clothes_features) == 0:
            raise ValueError("clothes_features 배열이 비어 있습니다. 옷 이미지 디렉토리를 확인하세요.")
        
        similarities = cosine_similarity(query_features, clothes_features)[0]
        top_indices = similarities.argsort()[-3:][::-1]

        for idx in top_indices:
            similar_clothes.append((clothes_paths[idx], similarities[idx], clothes_type))

    return similar_clothes

In [97]:
# 예제 실행 코드
image_path = 'D:/minkwan/무신사 크롤링/coordikitty-ML-DL/착용 이미지/image_1.JPG'
clothes_directory = 'D:/minkwan/무신사 크롤링/coordikitty-ML-DL/비슷한 이미지 선별 모음'

clothes_images = detect_and_extract(image_path)
clothes_features, clothes_paths = load_clothes_images(clothes_directory)

if len(clothes_features) == 0:
    print("옷 이미지 디렉토리에서 이미지를 로드하지 못했습니다. 경로를 확인하세요.")

similar_clothes = find_similar_clothes(clothes_images, clothes_features, clothes_paths)

for path, similarity, clothes_type in similar_clothes:
    print(f"유사한 {clothes_type}: {path} (유사도: {similarity})")

디렉토리 내 파일 목록: ['image_1.JPG', 'image_10.JPG', 'image_2.JPG', 'image_3.JPG', 'image_4.JPG', 'image_5.JPG', 'image_6.JPG', 'image_7.JPG', 'image_8.JPG', 'image_9.JPG']
옷 이미지 디렉토리에서 이미지를 로드하지 못했습니다. 경로를 확인하세요.


TypeError: pic should be Tensor or ndarray. Got <class 'PIL.Image.Image'>.