In [10]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import os
import pandas as pd
from tqdm import tqdm
from transformers import AutoFeatureExtractor, SwinModel

In [11]:
import os
# Directory that holds the images to embed.
image_dir = 'images/cellphone'

# Sanity check: report how many entries the image directory contains.
try:
    files = os.listdir(image_dir)
except Exception as e:
    print(f"Error accessing directory: {e}")
else:
    print(f"Number of files in directory: {len(files)}")

Number of files in directory: 128361


In [12]:
# Swin Transformer setup: choose the device, then load the pretrained
# feature extractor and backbone from the same checkpoint.
checkpoint = "microsoft/swin-tiny-patch4-window7-224"
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
model = SwinModel.from_pretrained(checkpoint)
model = model.to(device)
model.eval()  # inference mode (disables dropout etc.)
print(device)



cuda


In [13]:
# Mixed Precision (FP16) 설정
scaler = torch.amp.GradScaler("cuda")

In [14]:
# Image preprocessing: resize to 224x224, convert to a tensor, then normalize
# with the mean/std published by the loaded feature extractor.
preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=extractor.image_mean, std=extractor.image_std),
]
transform = transforms.Compose(preprocess_steps)

In [15]:
# Image dataset definition
class ImageDataset(Dataset):
    """Dataset over the ``*.jpg`` files in a directory whose names contain ``_``.

    Each item is a ``(post_id, image_tensor)`` pair. ``post_id`` is the integer
    parsed from the part of the file name before the first underscore, or
    ``-1`` when that part is not an integer. ``image_tensor`` is the RGB image
    passed through the module-level ``transform``.
    """

    def __init__(self, image_dir):
        """Collect the matching image paths under ``image_dir``.

        ``os.listdir`` returns entries in arbitrary order, so the names are
        sorted to make the index order (and therefore the order in which a
        non-shuffling DataLoader visits the images) reproducible across runs.
        """
        self.image_files = [
            os.path.join(image_dir, file_name)
            for file_name in sorted(os.listdir(image_dir))
            if file_name.endswith(".jpg") and '_' in file_name
        ]

    def __len__(self):
        """Return the number of image files found."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load image ``idx`` and return ``(post_id, image_tensor)``."""
        image_path = self.image_files[idx]
        # Open inside a context manager so the underlying file handle is
        # released as soon as the pixel data has been converted.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        image_tensor = transform(image)
        file_name = os.path.basename(image_path)

        try:
            post_id = int(file_name.split('_')[0])  # extract the post_id prefix
        except ValueError:
            print(f"Invalid file name format: {file_name}")
            post_id = -1  # sentinel for file names without a numeric prefix

        return post_id, image_tensor

In [16]:
# DataLoader setup. Batches follow dataset order (no shuffling); with
# num_workers=0 the data is loaded in the main process, i.e. no worker
# subprocesses are used.
batch_size = 32  # number of images per batch
dataset = ImageDataset(image_dir)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
)

In [17]:
# Smoke test: make sure the first five images can be opened and decoded.
sample_paths = dataset.image_files[:5]
for image_file in sample_paths:
    try:
        image = Image.open(image_file).convert("RGB")
    except Exception as e:
        print(f"Error loading image {image_file}: {e}")
    else:
        print(f"Loaded image: {image_file}")

Loaded image: images/cellphone\10000_1.jpg
Loaded image: images/cellphone\10000_2.jpg
Loaded image: images/cellphone\10000_3.jpg
Loaded image: images/cellphone\10001_1.jpg
Loaded image: images/cellphone\10001_2.jpg


In [19]:
# Path of the incremental output file; rows are appended as
# "<post_id>,<e0>,<e1>,..." with no header line.
output_file = 'swin_image_embeddings_partial.csv'

# Resume support: collect the post_ids that already have rows in the file.
# The file is written without a header, so it must be read with header=None
# (column 0 is the post_id). Non-numeric cells (e.g. a header line added by
# hand) are dropped instead of aborting the run.
processed_post_ids = []
if os.path.exists(output_file) and os.path.getsize(output_file) > 0:
    saved_rows = pd.read_csv(output_file, header=None, usecols=[0])
    saved_ids = pd.to_numeric(saved_rows[0], errors="coerce").dropna()
    processed_post_ids = saved_ids.astype(int).tolist()
# Set copy for O(1) membership tests inside the loop.
processed_lookup = set(processed_post_ids)
# NOTE(review): resuming is keyed on post_id, and several images can share
# one post_id. If a previous run stopped part-way through a post, that post's
# remaining images are skipped on resume — confirm this is acceptable.

# Container for embeddings (not populated in this cell).
image_data = {}

# Autocast is only enabled when the model actually lives on a CUDA device.
use_amp = device.type == "cuda"

# Process the images batch by batch.
for batch in tqdm(dataloader):
    post_ids, image_tensors = batch
    # The DataLoader collates the integer post_ids into a tensor; convert to
    # plain Python ints so they can be compared and formatted directly.
    batch_post_ids = [int(pid) for pid in post_ids]

    # Positions within the batch that still need an embedding: a valid file
    # name (post_id != -1) whose post_id is not already in the output file.
    keep_indices = [
        pos for pos, pid in enumerate(batch_post_ids)
        if pid != -1 and pid not in processed_lookup
    ]

    if not keep_indices:
        continue  # nothing left to do for this batch

    # Select the kept images BEFORE inference so that row i of the model
    # output corresponds to keep_indices[i].
    pending_tensors = image_tensors[keep_indices].to(device)

    # Mixed-precision inference without gradient tracking.
    with torch.no_grad(), torch.amp.autocast(device_type=device.type, enabled=use_amp):
        outputs = model(pixel_values=pending_tensors)

    # One embedding per image: mean over dim 1 of the last hidden state.
    batch_embeddings = outputs.last_hidden_state.mean(dim=1).cpu().numpy()

    # Append this batch's rows to the CSV right away so progress survives
    # an interruption.
    with open(output_file, mode='a') as f:
        for pos, embedding in zip(keep_indices, batch_embeddings):
            post_id = batch_post_ids[pos]
            embedding_str = ','.join(map(str, embedding.tolist()))  # serialize the embedding
            f.write(f"{post_id},{embedding_str}\n")

    # Release cached GPU memory after each batch.
    if use_amp:
        torch.cuda.empty_cache()

print("Batch inference, mixed precision 및 DataLoader를 사용한 임베딩 추출 완료.")

100%|██████████| 4012/4012 [1:27:27<00:00,  1.31s/it]

Batch inference, mixed precision 및 DataLoader를 사용한 임베딩 추출 완료.



