### Libraries and device

In [None]:
import warnings
warnings.filterwarnings("ignore")

import os
import torch
import tempfile
import shutil
from transformers import AutoImageProcessor, ResNetModel
from tqdm import tqdm
import pandas as pd

from util import preprocess_image

# Run on the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"[INFO] Using device: {device}")

### Model initialization and settings

In [None]:
# Backbone: the "microsoft/resnet-50" checkpoint from the Hugging Face Hub.
model_name = "microsoft/resnet-50"

# Load the image processor paired with the checkpoint first, then the bare
# ResNet model, which is moved onto the selected device right after loading.
processor = AutoImageProcessor.from_pretrained(model_name)
model = ResNetModel.from_pretrained(model_name)
model = model.to(device)

print("\n[INFO] ResNet model loaded.\n")


# Paths
# MiniImageNet image folder, relative to the notebook's working directory.
images_folder = os.path.join("..", "..", "miniImageNet", "images")
# Output folder for the per-image embedding files.
# The drive root needs an explicit separator: on Windows os.path.join("D:", "x")
# produces the drive-relative path "D:x", which resolves against the current
# directory of drive D: instead of D:\. On POSIX the result is unchanged.
save_path = os.path.join("D:" + os.sep, "CV_Project", "resnet_emb_intermediate")

### Extraction of training and validation image embeddings

In [None]:
# Size argument passed to preprocess_image for every image
# (presumably the target side length in pixels - confirm in util).
IMG_SIZE = 384

# Scratch space filled by the stage forward hooks during each forward pass:
# "stage_<i>" -> that stage's output.
intermediate_outputs = dict()

# Function to create a hook
def get_hook(idx):
    """Build a forward hook that records the output of stage ``idx``.

    The returned callable has the ``(module, inputs, output)`` signature used by
    ``nn.Module.register_forward_hook``. Each call stores ``output`` in the
    module-level ``intermediate_outputs`` dict under the key ``"stage_<idx>"``,
    overwriting any previous value. It returns ``None``, so the module's output
    is passed through unchanged.
    """
    key = f"stage_{idx}"

    def _record_stage_output(_module, _inputs, output):
        intermediate_outputs[key] = output

    return _record_stage_output

# Register hooks to each stage of the ResNet model
# Attach one recording hook per ResNet stage; every forward pass then fills
# intermediate_outputs with "stage_0", "stage_1", ... entries.
for stage_idx, stage_module in enumerate(model.encoder.stages):
    stage_module.register_forward_hook(get_hook(stage_idx))


# Extract image names from CSV files (train and validation)
# Collect the image file names listed in the train and validation split CSVs
# (train rows first, then validation rows, with a fresh 0..n-1 index).
csv1, csv2 = (
    os.path.join("..", "..", "miniImageNet", split_csv)
    for split_csv in ("train_stratified.csv", "validation_stratified.csv")
)
df = pd.concat([pd.read_csv(path) for path in (csv1, csv2)], ignore_index=True)
image_names = df["filename"].tolist()


# Loop through images and extract embeddings
for file_name in tqdm(image_names, desc="Processing images"):
    # extract and preprocess image
    image_path = os.path.join(images_folder, file_name)
    image = preprocess_image(image_path, IMG_SIZE)

    # Prepare inputs
    inputs = processor(image, return_tensors="pt", do_resize=False, do_center_crop=False).to(device)

    # Clear previous intermediate outputs
    intermediate_outputs.clear()

    # Forward pass
    with torch.no_grad():
        outputs = model(**inputs)

    # Convert tensors to CPU and float16
    embeddings_dict = {name: tensor.cpu().to(torch.float16) for name, tensor in intermediate_outputs.items()}

    # Save embeddings using a temporary file for atomicity
    filename = os.path.join(save_path, file_name.split(".")[0] + ".pt")
    with tempfile.NamedTemporaryFile(delete=False) as tmp:
        torch.save(embeddings_dict, tmp.name)
        tmp.flush()
        os.fsync(tmp.fileno())
    shutil.move(tmp.name, filename)