In [None]:
import torch
from transformers import AutoImageProcessor, AutoModel
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from glob import glob
from PIL import Image
from tqdm import tqdm

In [None]:
# %%
import os
import platform

# Machine-specific locations of the dataset folder and the repository checkout.
# Only macOS (Darwin) and Linux are configured; any other OS fails fast here
# instead of surfacing later as a NameError on DATA_PATH / ROOT_PATH.
_system = platform.system()
if _system == 'Darwin':
    DATA_PATH = "/Users/maltegenschow/Documents/Uni/Thesis/Data.nosync"
    ROOT_PATH = "/Users/maltegenschow/Documents/Uni/Thesis/Thesis"
elif _system == 'Linux':
    DATA_PATH = "/pfs/work7/workspace/scratch/tu_zxmav84-thesis/Data.nosync"
    ROOT_PATH = "/pfs/work7/workspace/scratch/tu_zxmav84-thesis/Thesis"
else:
    raise RuntimeError(
        f"Unsupported platform {_system!r}: add DATA_PATH and ROOT_PATH for this machine"
    )

current_wd = os.getcwd()

In [None]:
def set_device():
    """Pick the best available torch device and report it.

    Preference order is CUDA, then Apple MPS, then CPU.

    Returns:
        str: one of ``'cuda'``, ``'mps'`` or ``'cpu'``.
    """
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        # Older torch builds have no ``torch.backends.mps``; probe for it instead of
        # wrapping everything in a bare ``except`` (which would also swallow
        # KeyboardInterrupt and genuine errors).
        mps_backend = getattr(torch.backends, 'mps', None)
        if mps_backend is not None and mps_backend.is_available():
            device = 'mps'
        else:
            device = 'cpu'
    print(f"Using {device} as device")
    return device

device = set_device()

### DINOv2 custom processor and model

In [None]:
# Setup DinoV2 Custom Processor to ensure gradient flow in later training
transform_pipeline = transforms.Compose([
    #transforms.Resize(256),  # Resize so the shortest side is 256
    #transforms.CenterCrop((224, 224)),  # Center crop to 224x224
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize
])
def dino_processor(input):
    if isinstance(input, str):
        img = Image.open(input).convert('RGB')
        img = transforms.ToTensor()(img.resize([512,512]))
        img = img.unsqueeze(0)

        processed_img = transform_pipeline(img)
    elif isinstance(input, torch.Tensor):
        processed_img = transform_pipeline(input)
    else:
        raise ValueError("Input must be either a string or a torch.Tensor")
    return processed_img


In [None]:
def setup_dino_model(device, model_name="facebook/dinov2-base"):
    """Load a pretrained DINOv2 backbone and move it to ``device``.

    Args:
        device: Target device, e.g. the string returned by ``set_device``.
        model_name: Hugging Face checkpoint id; defaults to the base model
            this notebook has always used.

    Returns:
        The model from ``AutoModel.from_pretrained``, in eval mode, on ``device``.
    """
    model = AutoModel.from_pretrained(model_name)
    # Embedding extraction must be deterministic (no dropout); make that explicit.
    model.eval()
    return model.to(device)

dino_model = setup_dino_model(device)

### Function to extract embeddings

In [None]:
def _sku_from_path(image_path):
    """Return the SKU for an image file: its file name up to the first '.'."""
    return os.path.basename(image_path).split('.')[0]


def calculate_embeddings(input_images_path, save_path, pattern="*.jpg"):
    """Compute DINOv2 pooled embeddings for every image in a folder and save them.

    Relies on the notebook-level ``dino_processor``, ``dino_model`` and ``device``.

    Args:
        input_images_path: Directory containing the images (trailing slash optional).
        save_path: Destination file for ``torch.save``. The saved object is a dict
            mapping SKU (file name up to the first '.') to the model's
            ``pooler_output`` tensor, moved to CPU.
        pattern: Glob pattern selecting image files inside the directory.

    Returns:
        None. Does nothing if ``save_path`` already exists. Saves nothing if no
        files match, so a wrong folder is not cached as an empty result.
    """
    if os.path.exists(save_path):
        print(f"Embeddings for {input_images_path} already calculated")
        return

    # Sorting makes the processing / dict insertion order reproducible.
    input_images = sorted(glob(os.path.join(input_images_path, pattern)))
    if not input_images:
        print(f"No images matching {pattern} found in {input_images_path}; nothing saved")
        return

    embeddings = {}
    for image_path in tqdm(input_images):
        pixel_values = dino_processor(image_path).to(device)
        with torch.no_grad():
            output = dino_model(pixel_values)
        embeddings[_sku_from_path(image_path)] = output['pooler_output'].cpu()

    # torch.save does not create missing parent directories.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    torch.save(embeddings, save_path)


### Calculate embeddings for all image sets

In [None]:
# Every image set to embed, as (source image folder, output embeddings file).
# Output file names are kept exactly as before (including "embedings") since
# other notebooks presumably load them by name.
embedding_jobs = [
    # Real images
    (f"{DATA_PATH}/Zalando_Germany_Dataset/dresses/images/e4e_images/all/",
     f"{DATA_PATH}/Models/Assessor/DinoV2/Embeddings/real_images_embedings.pt"),
    # e4e 00003
    (f"{DATA_PATH}/Generated_Images/e4e/00003_snapshot_920/",
     f"{DATA_PATH}/Models/Assessor/DinoV2/Embeddings/e4e_00003_snapshot_920.pt"),
    # e4e 00005
    (f"{DATA_PATH}/Generated_Images/e4e/00005_snapshot_1200/",
     f"{DATA_PATH}/Models/Assessor/DinoV2/Embeddings/e4e_00005_snapshot_1200.pt"),
    # PTI
    (f"{DATA_PATH}/Generated_Images/PTI/",
     f"{DATA_PATH}/Models/Assessor/DinoV2/Embeddings/PTI.pt"),
    # Restyle
    (f"{DATA_PATH}/Generated_Images/restyle/inference_results/4/",
     f"{DATA_PATH}/Models/Assessor/DinoV2/Embeddings/restyle.pt"),
]

for input_images_path, save_path in embedding_jobs:
    calculate_embeddings(input_images_path, save_path)