In [None]:
import os

import numpy as np
import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm.auto import tqdm

from datasets.utils import PreprocessingDataset
from models.utils import get_model_by_name
from utils.environment import modified_environ


# Create image embeddings

In [None]:
# Parameters
BATCH_SIZE = 8
NUM_WORKERS = 4
IMAGES_EXT = ["*.gif", "*.jpg", "*.jpeg", "*.png", "*.webp"]
USE_GPU = True

# Dataset
DATASET = "UGallery"
assert DATASET in ["UGallery", "Wikimedia", "Pinterest"]

# Model
MODEL = "resnet50"
assert MODEL in ["resnet50"]
MODEL_VERSION = "imagenet"  # "imagenet" or "places365"
assert MODEL_VERSION in ["imagenet", "places365"]

# Images path: each dataset points at its pre-resized "mini ... 224-224-v2" copy.
# Author's timing notes (full-size images -> mini v2 images), with the full-size dirs:
#   UGallery:  ~35s     -> 27.6s   (/mnt/workspace/Ugallery/images)
#   Wikimedia: ~1h 10m  -> 1m 36s  (/mnt/data2/wikimedia/images/img)
#   Pinterest: ~2h      -> 1h 7m   (/mnt/data2/pinterest_iccv/images)
_DATASET_IMAGES_DIRS = {
    "UGallery": os.path.join("/", "mnt", "workspace", "Ugallery", "mini-images-224-224-v2"),
    "Wikimedia": os.path.join("/", "mnt", "data2", "wikimedia", "mini-images-224-224-v2"),
    "Pinterest": os.path.join("/", "mnt", "data2", "pinterest_iccv", "mini-images_filtered-224-224-v2"),
}
IMAGES_DIR = _DATASET_IMAGES_DIRS.get(DATASET)


In [None]:
# Paths (output): data/<DATASET>/<dataset lowercased>_embedding-<MODEL>_<MODEL_VERSION>.npy
_embedding_filename = f"{DATASET.lower()}_embedding-{MODEL}_{MODEL_VERSION}.npy"
OUTPUT_EMBEDDING_PATH = os.path.join("data", DATASET, _embedding_filename)


In [None]:
import PIL  # kept so any later cell referencing the `PIL` name keeps working
from PIL import Image, ImageFile  # import Image explicitly; `import PIL` alone does not load PIL.Image


# Needed for some images in Pinterest and Wikimedia dataset: raise Pillow's
# decompression-bomb limit (Pillow warns above MAX_IMAGE_PIXELS and raises above twice it).
Image.MAX_IMAGE_PIXELS = 3_000_000_000
# Some images are "broken" (truncated) in Wikimedia dataset; let Pillow load them anyway
ImageFile.LOAD_TRUNCATED_IMAGES = True


In [None]:
%%time
# Setting up torch device (useful if GPU available)
print("\nCreating device...")
device = torch.device("cuda:0" if torch.cuda.is_available() and USE_GPU else "cpu")
if torch.cuda.is_available() != USE_GPU:
    print((f"\nNotice: Not using GPU - "
           f"Cuda available ({torch.cuda.is_available()}) "
           f"does not match USE_GPU ({USE_GPU})"
    ))

# Downloading models for feature extraction
# (modified_environ presumably sets TORCH_HOME only for this block, so weights are cached under "." — verify)
print("\nDownloading model...")
with modified_environ(TORCH_HOME="."):
    print(f"Model: {MODEL} (pretrained on {MODEL_VERSION})")
    if MODEL_VERSION == "imagenet":
        model = get_model_by_name(MODEL).eval().to(device)
    else:
        # Places365: ResNet-50 with a 365-class head, weights from a local checkpoint
        model = torchvision.models.__dict__["resnet50"](num_classes=365)
        checkpoint = torch.load(
            os.path.join("/", "mnt", "data2", "netdissect2", "lite", "zoo", "resnet50_places365.pth.tar"),
            map_location="cpu",  # works even if the checkpoint was saved from a GPU and CUDA is unavailable
        )
        # Strip the "module." prefix (typically added when the model was saved wrapped in DataParallel)
        model.load_state_dict({
            k.replace("module.", ""): v for k, v in checkpoint["state_dict"].items()
        })
        del checkpoint
        # Drop last layer (fc): we need the pooled embedding, not the classification
        model = torch.nn.Sequential(*list(model.children())[:-1])
        # Freeze the weights (the original code set the flag on the module, not on the parameters)
        for param in model.parameters():
            param.requires_grad = False
        model = model.eval().to(device)

# Setting up transforms and dataset
print("\nSetting up transforms and dataset...")
images_transforms = transforms.Compose([
    # Resize(256) + CenterCrop(224) are not applied: already done in the mini v2 images
    transforms.ToTensor(),
    # Standard ImageNet per-channel mean/std normalization
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
image_dataset = PreprocessingDataset(
    IMAGES_DIR,
    extensions=IMAGES_EXT,
    transform=images_transforms,
)
# Fail early with a clear message instead of an IndexError later on image_dataset[0]
assert len(image_dataset) > 0, f"No images matching {IMAGES_EXT} found in {IMAGES_DIR}"
image_dataloader = DataLoader(
    image_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=device.type == "cuda",  # pinned host memory only helps host->GPU copies
)
print(f">> Images dataset: {len(image_dataset)}")

# Calculate embedding dimension size by running one dummy input (same shape as a dataset image)
with torch.no_grad():  # shape probe only: no autograd graph needed
    dummy_input = torch.ones((1, *image_dataset[0]["image"].size()), device=device)
    dummy_output = model(dummy_input)
emb_dim = dummy_output.size(1)
print(f">> Embedding dimension size: {emb_dim}")

# Feature extraction phase
print("\nFeature extraction...")
output_ids = np.empty(len(image_dataset), dtype=object)
# Accumulate embeddings in host memory: a (num_images, emb_dim) float32 buffer can reach
# several GB on large datasets and would otherwise compete with the model for GPU memory.
output_embedding = np.zeros((len(image_dataset), emb_dim), dtype=np.float32)
with torch.no_grad():
    for sample in tqdm(image_dataloader, desc="Feature extraction"):
        item_image = sample["image"].to(device, non_blocking=True)
        # Dataset indices of this batch, so each row lands at its image's position
        item_idx = np.asarray(sample["idx"])
        output_ids[item_idx] = sample["id"]
        # Flatten everything after the batch dim, e.g. (B, C, 1, 1) -> (B, C)
        output_embedding[item_idx] = model(item_image).flatten(start_dim=1).cpu().numpy()

# Fill output embedding: one row per image, column 0 = image id, column 1 = embedding vector
embedding = np.empty(shape=(len(image_dataset), 2), dtype=object)
for i in range(len(image_dataset)):
    # Assign each cell separately: np.asarray([id, vector]) builds a ragged array,
    # which NumPy >= 1.24 rejects with a ValueError.
    embedding[i, 0] = output_ids[i]
    embedding[i, 1] = output_embedding[i]
print(f">> Embedding shape: {embedding.shape}")

# Save embedding to file (object array: pickled, so loading needs np.load(..., allow_pickle=True))
print(f"\nSaving embedding to file... ({OUTPUT_EMBEDDING_PATH})")
output_dir = os.path.dirname(OUTPUT_EMBEDDING_PATH)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)  # np.save does not create missing directories
np.save(OUTPUT_EMBEDDING_PATH, embedding, allow_pickle=True)

# Free some memory: move the model off the GPU and release cached CUDA blocks.
# Gate on the device actually used, not on USE_GPU (which may be True on a CPU-only host).
if device.type == "cuda":
    print("\nCleaning GPU cache...")
    model = model.to(torch.device("cpu"))
    torch.cuda.empty_cache()

# Finished
print("\nDone")
