In [21]:
import os
import numpy as np
from pathlib import Path
from PIL import Image
from tqdm import tqdm

import torch
import torchvision.models as models
import torchvision.transforms as transforms

In [22]:
# 📁 Folders — Flickr8k images are read from `image_folder`; features are written
# as one .npy file per image into the two output directories below.
image_folder = Path("../data/raw/Flicker8k_Dataset")

global_out, spatial_out = (
    Path("../data/processed/features_resnet_global"),
    Path("../data/processed/features_resnet_spatial"),
)
# Create both output directories up front (no-op when they already exist).
for out_dir in (global_out, spatial_out):
    out_dir.mkdir(parents=True, exist_ok=True)

In [23]:
# ⚙️ Device — run on the GPU when PyTorch reports CUDA is available, else on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [24]:
# 📦 Pretrained ResNet50 with its final fully-connected (fc) layer removed.
# NOTE(review): `resnet` is not used by any cell visible here — the extraction cell
# further down loads its own ResNet50 (`resnet_base`) — so this cell puts a second
# copy of the weights on `device`. Consider deleting it once nothing else needs it.
_resnet50_layers = list(models.resnet50(weights=models.ResNet50_Weights.DEFAULT).children())
resnet = torch.nn.Sequential(*_resnet50_layers[:-1]).to(device).eval()  # drop the fc head
del _resnet50_layers  # don't keep the discarded fc layer alive in the kernel

In [25]:
# 🔄 Preprocessing — resize every image to exactly 224×224 (aspect ratio is not
# preserved), convert it to a float tensor, then normalize each RGB channel with
# the ImageNet mean/std below.
# NOTE(review): the reference transforms published for ResNet50_Weights.DEFAULT
# resize the short side and center-crop to 224 instead of squashing — confirm the
# plain resize is intended.
_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform = transforms.Compose(_preprocess_steps)

In [26]:
# 🔍 Feature extractors — one pretrained ResNet50 viewed two ways. Both Sequentials
# wrap the *same* layer objects taken from `resnet_base`, so the weights are shared:
#   resnet_spatial: all children except the last two
#   resnet_global:  all children except the last one (i.e. resnet_spatial + one layer)
# Assumes torchvision's ResNet child order ends with (..., layer4, avgpool, fc).
resnet_base = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
_backbone_layers = list(resnet_base.children())
resnet_global = torch.nn.Sequential(*_backbone_layers[:-1]).to(device).eval()
resnet_spatial = torch.nn.Sequential(*_backbone_layers[:-2]).to(device).eval()

def extract_features(image_path):
    """Extract global and spatial ResNet50 features for a single image.

    The backbone runs only once per image: the feature map produced by
    ``resnet_spatial`` is reused, and the one extra layer that ``resnet_global``
    appends (``resnet_global[-1]``) is applied to it. That is the same computation
    as calling ``resnet_global`` separately, at half the cost.

    Args:
        image_path: path (str or ``Path``) to an image file readable by PIL.

    Returns:
        ``(global_feat, spatial_feat)`` as NumPy arrays:
          * ``global_feat``  — shape ``(C,)``; (2048,) for ResNet50.
          * ``spatial_feat`` — shape ``(H*W, C)``, one row per spatial cell in
            row-major (H, W) order; (49, 2048) for a 224×224 input.
    """
    # Context manager closes the underlying file handle deterministically.
    with Image.open(image_path) as img:
        batch = transform(img.convert("RGB")).unsqueeze(0).to(device)  # (1, 3, H_in, W_in)

    with torch.inference_mode():
        # 🔸 Spatial feature map from the truncated backbone: (1, C, H, W).
        feature_map = resnet_spatial(batch)

        # 🔹 Global features: resnet_global == resnet_spatial's layers + one trailing
        # layer, so applying only that layer reproduces resnet_global(batch).
        pooled = resnet_global[-1](feature_map)               # (1, C, 1, 1)
        global_feat = torch.flatten(pooled, start_dim=1)[0]   # (C,)

        # (C, H, W) -> (H, W, C) -> (H*W, C). `reshape` (not `view`) because
        # permute returns a non-contiguous tensor; channel count is read from the map.
        n_channels = feature_map.shape[1]
        spatial_feat = feature_map[0].permute(1, 2, 0).reshape(-1, n_channels)

    return global_feat.cpu().numpy(), spatial_feat.cpu().numpy()

In [27]:
# 📸 Subsample — number of images to process; set N_IMAGES = None for the full dataset.
N_IMAGES = 10
# Sort first: Path.glob yields files in an arbitrary, filesystem-dependent order,
# so without sorting the "first N" images would differ between machines/runs.
image_list = sorted(image_folder.glob("*.jpg"))[:N_IMAGES]

In [28]:
# 🚀 Processing loop — for every selected image, save its global and spatial
# features as `<image stem>.npy` in the matching output directory.
for img_path in tqdm(image_list, desc="🧠 Extraction (2 formats)"):
    image_id = img_path.stem
    global_feat, spatial_feat = extract_features(img_path)
    for out_dir, feat in ((global_out, global_feat), (spatial_out, spatial_feat)):
        np.save(out_dir / f"{image_id}.npy", feat)

🧠 Extraction (2 formats): 100%|██████████| 10/10 [00:00<00:00, 14.84it/s]
