# In [9]:
import os
import shutil
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from PIL import Image
import pandas as pd
from itertools import combinations
import random
from transformers import CLIPProcessor, CLIPModel

# Configuration
# Dataset root; the image folder below and train.csv are resolved relative to it.
DATA_DIR = r"D:\Project\PJT_10\shopee-product-matching"
IMG_DIR = os.path.join(DATA_DIR, "train_images")
# Root folder that receives the copied image pairs and per-pair info files.
SAVE_DIR = "./inference_results"
# Checkpoint passed to torch.load / load_state_dict further down.
# NOTE(review): the path mixes "/" and "\"; the backslash is only a path
# separator on Windows, so this literal is Windows-specific.
MODEL_PATH = r"./saved_models\clip_pair_best_epoch2_20250716_025548.pth"
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create the output directories (one for correct, one for incorrect predictions)
CORRECT_DIR = os.path.join(SAVE_DIR, "correct")
INCORRECT_DIR = os.path.join(SAVE_DIR, "incorrect")
os.makedirs(CORRECT_DIR, exist_ok=True)
os.makedirs(INCORRECT_DIR, exist_ok=True)

# Load the data
df = pd.read_csv(os.path.join(DATA_DIR, "train.csv")).reset_index(drop=True)
from sklearn.preprocessing import LabelEncoder
label_encoder = LabelEncoder()
# Map every label_group value to a contiguous integer id.
df["label_encoded"] = label_encoder.fit_transform(df["label_group"])

# Carve out the test set with GroupShuffleSplit
# Splitting by group keeps all rows of one label group on the same side.
# NOTE(review): presumably this reproduces the split used at training time
# (same random_state and sizes) — confirm against the training notebook.
from sklearn.model_selection import GroupShuffleSplit
# First split: hold out 40% of the groups as a temporary pool.
gss = GroupShuffleSplit(n_splits=1, test_size=0.4, random_state=42)
train_idx, temp_idx = next(gss.split(df, groups=df["label_encoded"]))
temp_df = df.iloc[temp_idx].reset_index(drop=True)
# Second split: halve the pool into validation and test groups.
gss2 = GroupShuffleSplit(n_splits=1, test_size=0.5, random_state=42)
val_idx, test_idx = next(gss2.split(temp_df, groups=temp_df["label_encoded"]))
# Only the test half is used below; the index is reset so that pair indices
# are 0..len(test_df)-1 labels usable with test_df.loc.
test_df = temp_df.iloc[test_idx].reset_index(drop=True)

# create_pairs helper (reused from the existing training code)
def create_pairs(df, max_neg_per_pos=2, seed=None):
    """Build labelled index pairs for pairwise matching.

    Positive pairs (label 1) are every 2-combination of rows that share a
    ``label_encoded`` value. Negative pairs (label 0) are randomly drawn
    ordered pairs of rows whose ``label_encoded`` values differ.

    Args:
        df: DataFrame with a ``label_encoded`` column and a unique index.
        max_neg_per_pos: Number of negative pairs requested per positive pair.
        seed: Optional seed for reproducible negative sampling. With the
            default ``None`` the module-level ``random`` state is used.

    Returns:
        A list of ``(index_a, index_b, label)`` tuples: all positives first
        (groups in order of first appearance), followed by the negatives.
        The number of negatives is capped at the number of distinct ordered
        negative pairs that actually exist, so the function always
        terminates (e.g. when every row belongs to the same group).
    """
    rng = random if seed is None else random.Random(seed)

    all_indices = df.index.tolist()
    label_of = dict(zip(all_indices, df["label_encoded"].tolist()))

    # Group the row indices in a single pass instead of filtering the frame
    # once per label; insertion order matches order of first appearance.
    members = {}
    for idx in all_indices:
        members.setdefault(label_of[idx], []).append(idx)

    pairs = [
        (i, j, 1)
        for group in members.values()
        for i, j in combinations(group, 2)
    ]
    pos_count = len(pairs)

    # Ordered pairs of distinct rows minus the ordered same-group pairs.
    row_count = len(all_indices)
    possible_neg = row_count * (row_count - 1) - sum(
        len(group) * (len(group) - 1) for group in members.values()
    )
    # Never ask for more negatives than exist; otherwise sampling cannot end.
    neg_needed = min(pos_count * max_neg_per_pos, possible_neg)

    neg_pairs = set()
    if neg_needed > 0:
        if neg_needed * 2 >= possible_neg:
            # Dense request: rejection sampling would mostly hit duplicates,
            # so enumerate every candidate and take a random subset.
            candidates = [
                (i, j)
                for i in all_indices
                for j in all_indices
                if label_of[i] != label_of[j]
            ]
            rng.shuffle(candidates)
            for pair in candidates:
                if len(neg_pairs) >= neg_needed:
                    break
                neg_pairs.add(pair)
        else:
            # Sparse request: plain rejection sampling is cheap.
            while len(neg_pairs) < neg_needed:
                i, j = rng.sample(all_indices, 2)
                if label_of[i] != label_of[j]:
                    neg_pairs.add((i, j))

    pairs.extend((i, j, 0) for i, j in neg_pairs)
    return pairs

# Dataset class (reused from the existing training code)
class ShopeePairDataset(Dataset):
    """Dataset yielding one (item, item, label) pair per index.

    Each sample is a dict holding the two RGB images, the two titles, the
    pair label and the two DataFrame index labels the pair was built from.
    """

    def __init__(self, df, pairs, img_dir, processor):
        # df: frame with "image" and "title" columns, addressed via .loc
        # pairs: sequence of (index_a, index_b, label) tuples
        self.df, self.pairs = df, pairs
        self.img_dir, self.processor = img_dir, processor

    def __len__(self):
        return len(self.pairs)

    def _load_rgb(self, file_name):
        # Open the image stored under img_dir and convert it to RGB.
        return Image.open(os.path.join(self.img_dir, file_name)).convert("RGB")

    def __getitem__(self, idx):
        first_idx, second_idx, target = self.pairs[idx]
        first_row = self.df.loc[first_idx]
        second_row = self.df.loc[second_idx]
        return {
            "image1": self._load_rgb(first_row["image"]),
            "text1": first_row["title"],
            "image2": self._load_rgb(second_row["image"]),
            "text2": second_row["title"],
            "label": target,
            "idx1": first_idx,
            "idx2": second_idx,
        }

# Collate function (reused from the existing training code)
def collate_fn(batch):
    """Turn a list of pair samples into one batch of model inputs.

    Both sides of every pair are encoded with the module-level ``processor``
    (text and image together, padded and truncated, returned as PyTorch
    tensors). Labels become a float tensor; the row indices of each pair are
    passed through as plain lists.
    """

    def pluck(field):
        # Collect one field across all samples, preserving batch order.
        return [sample[field] for sample in batch]

    def encode(text_field, image_field):
        return processor(
            text=pluck(text_field),
            images=pluck(image_field),
            return_tensors="pt",
            padding=True,
            truncation=True,
        )

    targets = torch.tensor(pluck("label"), dtype=torch.float)
    first_ids, second_ids = pluck("idx1"), pluck("idx2")

    first = encode("text1", "image1")
    second = encode("text2", "image2")

    collated = {
        "input_ids1": first["input_ids"],
        "attention_mask1": first["attention_mask"],
        "pixel_values1": first["pixel_values"],
        "input_ids2": second["input_ids"],
        "attention_mask2": second["attention_mask"],
        "pixel_values2": second["pixel_values"],
    }
    collated["label"] = targets
    collated["idx1"] = first_ids
    collated["idx2"] = second_ids
    return collated

# Model definition (Siamese)
class CLIPSiameseModel(nn.Module):
    """Pair classifier on top of a shared CLIP backbone.

    Each item is embedded as the concatenation of its CLIP image and text
    features; the two item embeddings are concatenated and passed through a
    small MLP that emits one logit per pair.
    """

    def __init__(self, model_name):
        super().__init__()
        self.clip = CLIPModel.from_pretrained(model_name)
        embed_dim = self.clip.config.projection_dim
        # Input width is 4x the projection size: (image + text) for both items.
        self.classifier = nn.Sequential(
            nn.Linear(4 * embed_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 1),
        )

    def _embed(self, input_ids, attention_mask, pixel_values):
        # Encode one item: text features first, then image features,
        # concatenated as [image, text] along the feature axis.
        text_vec = self.clip.get_text_features(input_ids=input_ids, attention_mask=attention_mask)
        image_vec = self.clip.get_image_features(pixel_values=pixel_values)
        return torch.cat((image_vec, text_vec), dim=1)

    def forward(self, input_ids1, attention_mask1, pixel_values1,
                      input_ids2, attention_mask2, pixel_values2):
        left = self._embed(input_ids1, attention_mask1, pixel_values1)
        right = self._embed(input_ids2, attention_mask2, pixel_values2)
        logits = self.classifier(torch.cat((left, right), dim=1))
        # Drop the trailing singleton dimension so there is one logit per pair.
        return logits.squeeze(1)

# Load the model
# `processor` must be defined at module level: collate_fn reads it as a global.
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model = CLIPSiameseModel("openai/clip-vit-base-patch32").to(DEVICE)
# Overwrite the freshly initialised weights with the fine-tuned checkpoint.
# NOTE(review): torch.load unpickles the file; consider weights_only=True if
# the checkpoint could come from an untrusted source.
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
# Switch to evaluation mode before inference.
model.eval()

# Prepare the dataset and data loader
# Negative pairs are drawn with the `random` module and no seed is set in this
# cell, so the evaluated pair set can differ between runs.
test_pairs = create_pairs(test_df)
test_dataset = ShopeePairDataset(test_df, test_pairs, IMG_DIR, processor)
# shuffle=False keeps batches in the order of test_pairs.
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)

# Run inference and save the per-pair results.
# For every test pair the two images are copied into
# <SAVE_DIR>/<correct|incorrect>/<label-group folder>/ together with a small
# text file holding the titles, the true label, the prediction and its
# probability. Accuracy over all pairs is printed at the end.
correct = 0
total = 0
with torch.no_grad():
    for batch in tqdm(test_loader, desc="Inference"):
        # Only tensor entries go to the device; labels and indices are handled separately.
        inputs = {k: v.to(DEVICE) for k, v in batch.items() if k not in ["label", "idx1", "idx2"]}
        labels = batch["label"].to(DEVICE)
        idx1s = batch["idx1"]
        idx2s = batch["idx2"]

        outputs = model(**inputs)
        probs = torch.sigmoid(outputs)
        preds = (probs >= 0.5).float()

        correct += (preds == labels).sum().item()
        total += labels.size(0)

        # Pull the batch results to plain Python values once, instead of
        # forcing a device synchronisation for every pair inside the loop.
        pred_values = preds.tolist()
        label_values = labels.tolist()
        prob_values = probs.tolist()

        # Save the images and title information
        for pred, label, prob, idx1, idx2 in zip(pred_values, label_values, prob_values, idx1s, idx2s):
            is_correct = pred == label
            base_dir = CORRECT_DIR if is_correct else INCORRECT_DIR
            row1 = test_df.loc[idx1]
            row2 = test_df.loc[idx2]

            # Positive pairs share one label_group folder; negative pairs get
            # a folder named after both groups.
            label_group = str(row1["label_group"]) if label == 1 else f"{row1['label_group']}_{row2['label_group']}"
            save_dir = os.path.join(base_dir, label_group)
            os.makedirs(save_dir, exist_ok=True)

            pair_id = f"pair_{idx1}_{idx2}"

            # Copy both images, prefixing the file names with the pair id
            img1_path = os.path.join(IMG_DIR, row1["image"])
            img2_path = os.path.join(IMG_DIR, row2["image"])
            img1_new = os.path.join(save_dir, f"{pair_id}_img1_{row1['image']}")
            img2_new = os.path.join(save_dir, f"{pair_id}_img2_{row2['image']}")
            shutil.copy(img1_path, img1_new)
            shutil.copy(img2_path, img2_new)

            # Write the title information in a single call
            txt_path = os.path.join(save_dir, f"{pair_id}_info.txt")
            with open(txt_path, "w", encoding="utf-8") as f:
                f.write(
                    f"Image 1: {row1['image']}\n"
                    f"Title 1: {row1['title']}\n"
                    f"Image 2: {row2['image']}\n"
                    f"Title 2: {row2['title']}\n"
                    f"True Label: {int(label)}\n"
                    f"Predicted Label: {int(pred)}\n"
                    f"Probability: {prob:.4f}\n"
                )

# Print the results
if total:
    accuracy = correct / total
    print(f"\nTest Accuracy: {accuracy:.4f}")
else:
    # No pairs were produced (e.g. no label group with two or more rows);
    # avoid a ZeroDivisionError and report it explicitly.
    accuracy = float("nan")
    print("\nNo test pairs were evaluated; accuracy is undefined.")
print(f"Results saved to {SAVE_DIR}")

# Output:
# Inference: 100%|████████████████████████████████████████████████████████████| 1741/1741 [22:40<00:00,  1.28it/s]
#
# Test Accuracy: 0.9018
# Results saved to ./inference_results



