## real.json 정보 데이터 로드

In [1]:
import torch
from tqdm import tqdm
import pandas as pd
import numpy as np
from torch.utils.data import DataLoader
import logging
import time
from torchvision import transforms
from oml.datasets.base import DatasetWithLabels
from oml.models import ViTExtractor

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("EmbeddingsLogger")

# Custom dataset class with index information and image transforms
class CustomImageLabeledDataset(DatasetWithLabels):
    """Labeled image dataset whose items also carry row position, ``id`` and ``gid``.

    Every item produced by the parent class is extended with bookkeeping
    fields read from the backing dataframe, and the image is optionally run
    through ``transform``.
    """

    def __init__(self, df, dataset_root, transform=None):
        """Keep the optional ``transform`` and hand everything else to the parent.

        Args:
            df: Dataframe describing the samples; its ``id`` and ``gid``
                columns are read in ``__getitem__``.
            dataset_root: Forwarded unchanged to ``DatasetWithLabels``.
            transform: Optional callable applied to each image.
        """
        super().__init__(df=df, dataset_root=dataset_root)
        self.transform = transform

    def __getitem__(self, idx: int) -> dict:
        """Return the parent's item for ``idx`` plus ``indices``, ``id`` and ``gid``.

        When a transform is set, ``item["input_tensors"]` is replaced by the
        transform's output.
        """
        sample = super().__getitem__(idx)
        row = self.df.iloc[idx]
        sample.update(indices=idx, id=row['id'], gid=row['gid'])

        if not self.transform:
            return sample

        # The tensor is moved to the CPU, its axes reordered with
        # permute(1, 2, 0) and converted to NumPy before the transform sees it
        # (presumably channel-first -> channel-last; confirm against the
        # parent class's output layout).
        image_array = sample["input_tensors"].cpu().permute(1, 2, 0).numpy()
        sample["input_tensors"] = self.transform(image_array)
        return sample

# Preprocessing pipeline applied to every image:
# array -> PIL image -> resize to 224x224 -> tensor.
_to_pil = transforms.ToPILImage()
_resize = transforms.Resize((224, 224))
_to_tensor = transforms.ToTensor()
transform = transforms.Compose([_to_pil, _resize, _to_tensor])

# Sample metadata is read from the "static" directory. Note that
# dataset_root is rebound further down to the image directory.
dataset_root = "static"
df_train = pd.read_json(f"{dataset_root}/real.json")
logger.info(f"Dataset loaded: {df_train.shape[0]} samples")

# Prefer CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
logger.info(f"Using device: {device}")

# Pretrained extractor, moved to the selected device and put in eval mode.
try:
    model = ViTExtractor.from_pretrained("vits16_dino")
    model = model.to(device)
    model.eval()
except Exception as e:
    logger.error(f"Error loading model: {e}")
    raise

# From here on dataset_root names the image directory; the checkpoint path
# below is built from this value.
dataset_root = "/mnt/e/data/image"
try:
    train_dataset = CustomImageLabeledDataset(
        df_train,
        dataset_root=dataset_root,
        transform=transform,
    )
    # shuffle=False keeps batches in dataframe order.
    full_loader = DataLoader(train_dataset, batch_size=32, shuffle=False)
except Exception as e:
    logger.error(f"Error initializing dataset or dataloader: {e}")
    raise

# Per-batch results are accumulated in these module-level lists.
embeddings_list, labels_list, id_list, gid_list, image_paths_list = (
    [] for _ in range(5)
)

# Checkpointing: where to write, how often (seconds), and when we last did.
checkpoint_path = f"{dataset_root}/checkpoint_embeddings.npz"
checkpoint_interval = 30 * 60  # 30 minutes in seconds
last_checkpoint_time = time.time()

# Load checkpoint if available
def load_checkpoint():
    """Restore previously saved results into the module-level accumulators.

    Reads the archive at ``checkpoint_path``. If it contains all of
    ``embeddings``, ``labels``, ``ids``, ``gids`` and ``image_paths``, the
    arrays are appended to ``embeddings_list``, ``labels_list``, ``id_list``
    and ``gid_list`` and the paths extend ``image_paths_list``.

    Returns:
        int: ``len(image_paths_list)`` after a successful load, or ``0`` when
        the checkpoint is missing, unreadable or incomplete (a warning is
        logged and the accumulators are left untouched).
    """
    try:
        # The context manager closes the archive once the arrays are read.
        # Every key is read before any accumulator is modified, so a
        # checkpoint that lacks a key cannot leave the lists half-filled.
        with np.load(checkpoint_path) as checkpoint_data:
            embeddings = checkpoint_data['embeddings']
            labels = checkpoint_data['labels']
            ids = checkpoint_data['ids']
            gids = checkpoint_data['gids']
            image_paths = checkpoint_data['image_paths'].tolist()
    except Exception as e:
        logger.warning(f"No valid checkpoint found: {e}")
        return 0

    embeddings_list.append(embeddings)
    labels_list.append(labels)
    id_list.append(ids)
    gid_list.append(gids)
    image_paths_list.extend(image_paths)
    return len(image_paths_list)

# Function to save checkpoint
def save_checkpoint():
    """Write everything accumulated so far to ``checkpoint_path``.

    Concatenates the per-batch arrays in ``embeddings_list``,
    ``labels_list``, ``id_list`` and ``gid_list`` and stores them together
    with ``image_paths_list`` in a single ``.npz`` archive. Nothing is
    written when ``embeddings_list`` is empty.

    The archive is written to a temporary file next to the target and then
    moved into place, so an interruption while writing does not destroy the
    previous checkpoint.
    """
    import os  # local import: os is not among the file's visible imports

    logger.info("Saving checkpoint...")
    if not embeddings_list:
        return

    all_embeddings = np.concatenate(embeddings_list, axis=0)
    all_labels = np.concatenate(labels_list, axis=0)
    all_ids = np.concatenate(id_list, axis=0)
    all_gids = np.concatenate(gid_list, axis=0)

    # Passing an open file object keeps np.savez from appending ".npz" to
    # the temporary file name.
    tmp_path = f"{checkpoint_path}.tmp"
    with open(tmp_path, "wb") as tmp_file:
        np.savez(
            tmp_file,
            embeddings=all_embeddings,
            labels=all_labels,
            ids=all_ids,
            gids=all_gids,
            image_paths=image_paths_list,
        )
    os.replace(tmp_path, checkpoint_path)
    logger.info(f"Checkpoint saved to {checkpoint_path}")

# Resume from the last checkpoint. The batch size is taken from the loader
# instead of being repeated as a literal, so the two cannot drift apart.
processed_images = load_checkpoint()
batch_size = full_loader.batch_size
start_batch = processed_images // batch_size
# NOTE(review): if an earlier run logged a failed batch, processed_images is
# not a multiple of batch_size and this batch-aligned resume can re-process
# samples that are already in the checkpoint; confirm whether that matters.

if start_batch > 0:
    # Skipping with `continue` would still make the DataLoader read and
    # transform every skipped image. Iterating only over the remaining
    # samples yields exactly the batches from start_batch onward without
    # that work. Items keep their original dataframe positions in "indices".
    first_sample = start_batch * batch_size
    remaining_samples = torch.utils.data.Subset(
        train_dataset, range(first_sample, len(train_dataset))
    )
    inference_loader = DataLoader(remaining_samples, batch_size=batch_size, shuffle=False)
    logger.info(f"Resuming from batch {start_batch} ({processed_images} images already processed)")
else:
    inference_loader = full_loader

# Inference Loop
with torch.no_grad():
    progress = tqdm(
        inference_loader,
        total=len(full_loader),
        initial=start_batch,
        desc="Inference",
    )
    # batch_num keeps counting in terms of the full loader so log messages
    # refer to the same batch numbers with or without a resume.
    for batch_num, batch in enumerate(progress, start=start_batch):
        try:
            inputs = batch["input_tensors"].to(device)
            labels = batch["labels"]
            indices = batch["indices"]
            ids = batch["id"]
            gids = batch["gid"]

            embeddings = model(inputs)

            # Transfer tensors to CPU
            embeddings_cpu = embeddings.cpu().numpy()
            labels_cpu = labels.cpu().numpy()
            ids_cpu = np.array(ids)
            gids_cpu = np.array(gids)
            # Hand pandas a plain integer array rather than a torch tensor.
            batch_image_paths = df_train.iloc[np.asarray(indices)]['path'].values

            embeddings_list.append(embeddings_cpu)
            labels_list.append(labels_cpu)
            id_list.append(ids_cpu)
            gid_list.append(gids_cpu)
            image_paths_list.extend(batch_image_paths)

            # Check if it's time to save a checkpoint
            if time.time() - last_checkpoint_time >= checkpoint_interval:
                save_checkpoint()
                last_checkpoint_time = time.time()
        except Exception as e:
            # Best effort: a failing batch is logged (with traceback) and
            # skipped so the rest of the run can continue.
            logger.exception(f"Error processing batch {batch_num}: {e}")

# Final save after processing all batches
logger.info("Saving final embeddings...")
try:
    all_embeddings = np.concatenate(embeddings_list, axis=0)
    all_labels = np.concatenate(labels_list, axis=0)
    all_ids = np.concatenate(id_list, axis=0)
    all_gids = np.concatenate(gid_list, axis=0)
    embeddings_path = f"{dataset_root}/embeddings11.npz"
    np.savez(embeddings_path, embeddings=all_embeddings, labels=all_labels, ids=all_ids, gids=all_gids, image_paths=image_paths_list)
    logger.info(f"Final embeddings saved to {embeddings_path}")
except Exception as e:
    logger.exception(f"Error saving final embeddings: {e}")

2024-06-11 03:52:15,979 - INFO - Dataset loaded: 79918 samples
2024-06-11 03:52:16,000 - INFO - Using device: cuda


https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth
Checkpoint is already here.


Inference:   8%|▊         | 205/2498 [29:58<6:25:08, 10.08s/it]2024-06-11 04:22:24,463 - INFO - Saving checkpoint...
2024-06-11 04:22:24,738 - INFO - Checkpoint saved to /mnt/e/data/image/checkpoint_embeddings.npz
Inference:  16%|█▌        | 393/2498 [1:00:05<6:35:33, 11.27s/it]2024-06-11 04:52:31,318 - INFO - Saving checkpoint...
2024-06-11 04:52:31,718 - INFO - Checkpoint saved to /mnt/e/data/image/checkpoint_embeddings.npz
Inference:  23%|██▎       | 568/2498 [1:30:13<5:46:03, 10.76s/it]2024-06-11 05:22:40,178 - INFO - Saving checkpoint...
2024-06-11 05:22:40,581 - INFO - Checkpoint saved to /mnt/e/data/image/checkpoint_embeddings.npz
Inference:  29%|██▉       | 736/2498 [2:00:17<6:17:12, 12.84s/it]2024-06-11 05:52:46,425 - INFO - Saving checkpoint...
2024-06-11 05:52:46,961 - INFO - Checkpoint saved to /mnt/e/data/image/checkpoint_embeddings.npz
Inference:  36%|███▌      | 894/2498 [2:30:27<5:20:54, 12.00s/it]2024-06-11 06:22:55,500 - INFO - Saving checkpoint...
2024-06-11 06:22:56