# smart_connections_img_resnet

Calcula embeddings de imagen para Smart Connections con ResNet50 (ImageNet) y recalcula los scores combinando texto+imagen.

Orden sugerido:
1. Ejecuta `smart_connections_embeddings.ipynb` (texto) para generar `matchings_text.npy` y actualizar `matchings_pairs.jsonl` con `score_text`.
2. Ejecuta este notebook para generar `matchings_img_resnet.npy`, combinar scores (0.7 texto / 0.3 imagen) y sobrescribir `matchings_pairs.jsonl`.


In [1]:
# smart_connections_img_resnet
from __future__ import annotations
import json, math, io
from pathlib import Path
from typing import Dict, List, Any
import requests
import numpy as np
from PIL import Image
import torch
import torchvision.transforms as T
from torchvision import models

ALPHA = 0.7  # text weight; the image weight is 1 - ALPHA

# Locate the project root. When run as a script it is two levels above this
# file; without __file__ (e.g. inside a notebook kernel) it is derived from the
# current working directory instead.
try:
    PROJECT_ROOT = Path(__file__).resolve().parents[1]
except NameError:
    PROJECT_ROOT = Path.cwd().resolve().parents[0]
    _data_found = (PROJECT_ROOT / "data").exists()
    # Climb one more level when no "data" directory is found here.
    if not _data_found and len(PROJECT_ROOT.parents) > 0:
        PROJECT_ROOT = PROJECT_ROOT.parent

DATA_DIR = PROJECT_ROOT.joinpath("data")
PROD_PATH = DATA_DIR.joinpath("matchings_products.jsonl")
PAIR_PATH = DATA_DIR.joinpath("matchings_pairs.jsonl")
TEXT_EMB = DATA_DIR.joinpath("matchings_text.npy")
IMG_EMB = DATA_DIR.joinpath("matchings_img_resnet.npy")
INDEX_PATH = DATA_DIR.joinpath("matchings_index.json")

# Fail fast when an input artefact is missing.
# NOTE(review): these are plain asserts, so they are skipped under `python -O`.
for _required_path, _hint in (
    (PROD_PATH, "Falta matchings_products.jsonl"),
    (PAIR_PATH, "Falta matchings_pairs.jsonl"),
    (TEXT_EMB, "Ejecuta primero 4_smart_connections_text.ipynb"),
    (INDEX_PATH, "Falta index de ids"),
):
    assert _required_path.exists(), _hint

# Load data
def _read_jsonl(path):
    """Parse a JSON Lines file into a list of objects, skipping blank lines.

    The file is decoded explicitly as UTF-8 (not the platform locale) and is
    split only on real newlines by iterating the file object. `str.splitlines()`
    would also break on U+2028/U+2029/U+0085, which `json.dumps(...,
    ensure_ascii=False)` leaves unescaped inside string values, cutting a
    record in half.
    """
    with open(path, encoding="utf-8") as handle:
        return [json.loads(line) for line in handle if line.strip()]


products = _read_jsonl(PROD_PATH)
pairs = _read_jsonl(PAIR_PATH)
ids = json.loads(INDEX_PATH.read_text(encoding="utf-8"))["ids"]
# Maps each id to its position in `ids`; that position is used as the row
# index into the embedding matrices.
id_to_idx = {pid: i for i, pid in enumerate(ids)}
text_emb = np.load(TEXT_EMB)
# The text embedding matrix must have exactly one row per indexed id.
assert text_emb.shape[0] == len(ids), (
    f"{TEXT_EMB.name} has {text_emb.shape[0]} rows but the index lists {len(ids)} ids"
)

# ResNet50 model: use the GPU when one is available, otherwise the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
# Replace the final fully connected layer with a pass-through so the forward
# pass returns the features that would have fed it instead of class logits.
model.fc = torch.nn.Identity()
model = model.to(device)
model.eval()

# Resize, centre-crop to 224x224, convert to a tensor and normalise per channel.
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]
_steps = [
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
]
preprocess = T.Compose(_steps)

# One 2048-wide float32 row per indexed id, initialised to zeros.
img_embs = np.zeros((len(ids), 2048), dtype=np.float32)
# A single HTTP session shared by every image download.
session = requests.Session()

def fetch_image(url: str):
    """Download `url` and decode the response body as an RGB PIL image.

    Uses the module-level `session` with a 5-second timeout.

    Returns:
        The decoded image converted to "RGB", or None when the request, the
        HTTP status check or the decoding raises any exception.
    """
    try:
        response = session.get(url, timeout=5)
        response.raise_for_status()
        payload = io.BytesIO(response.content)
        decoded = Image.open(payload).convert("RGB")
    except Exception:
        # Best-effort download: any failure is reported to the caller as None.
        return None
    return decoded

@torch.no_grad()
def embed_image(img: Image.Image) -> np.ndarray:
    """Return the L2-normalised model output for a single image.

    The image goes through the module-level `preprocess` pipeline, is given a
    leading batch dimension, and is run through `model` on `device` with
    gradient tracking disabled.

    Returns:
        The first (only) row of the model output as a NumPy array divided by
        its Euclidean norm; a zero norm is replaced by 1.0 so an all-zero
        vector is returned unchanged rather than dividing by zero.
    """
    batch = preprocess(img).unsqueeze(0).to(device)
    features = model(batch).cpu().numpy()[0]
    length = np.linalg.norm(features)
    if not length:
        length = 1.0
    return features / length

# Fill one row of `img_embs` per indexed id. Products with no image URL, or
# whose image cannot be fetched, are counted in `missing` and their row is left
# as initialised.
# NOTE(review): `products[idx]` assumes the products file is in the same order
# as `ids`; nothing here verifies that alignment - confirm against the notebook
# that builds the index.
missing = 0
for idx, pid in enumerate(ids):
    prod = products[idx]
    url = prod.get("image_url") or prod.get("image")
    img = fetch_image(url) if url else None
    if img is None:
        missing += 1
        continue
    img_embs[idx] = embed_image(img)

np.save(IMG_EMB, img_embs)
print(f"Embeddings imagen guardados en {IMG_EMB}; faltantes: {missing}")

# Recompute every pair's text, image and combined score, mutating each pair
# dict in place and collecting them in `scored`.
# NOTE(review): a product whose image row is all zeros contributes
# score_image == 0, which lowers the combined score instead of falling back to
# text only - confirm this is intended.
scored = []
for pair in pairs:
    ci = id_to_idx.get(pair["client_id"])
    cj = id_to_idx.get(pair["competitor_id"])
    if ci is None or cj is None:
        # Either id is absent from the index: both similarities default to 0.
        score_t = 0.0
        score_i = 0.0
    else:
        score_t = float(np.dot(text_emb[ci], text_emb[cj]))
        score_i = float(np.dot(img_embs[ci], img_embs[cj]))
    score = ALPHA * score_t + (1-ALPHA) * score_i
    pair.update({
        "score_text": score_t,
        "score_image": score_i,
        "score": score,
        "similarity": score * 100,
    })
    # Pairs without an explicit label are treated as label 1; a pair counts as
    # a distractor when its label is 0, unless the flag is already present.
    pair.setdefault("label", 1)
    pair.setdefault("is_distractor", pair.get("label",1)==0)
    scored.append(pair)

# Sanity check: print the mean combined score for label 1 and label 0 pairs.
pos, neg = [], []
for _entry in scored:
    _label = _entry.get("label")
    if _label == 1:
        pos.append(_entry["score"])
    elif _label == 0:
        neg.append(_entry["score"])
# max(..., 1) keeps the division defined when a group is empty (mean shown as 0).
_mean_pos = sum(pos) / max(len(pos), 1)
_mean_neg = sum(neg) / max(len(neg), 1)
print(f"Mean score label=1: {_mean_pos:.4f} | label=0: {_mean_neg:.4f}")

# Spot check: print every scored pair whose client is the first product whose
# title contains the Garmin query (case-insensitive).
# `p.get("title") or ""` also covers a title stored as JSON null, which the
# previous `p.get("title", "")` passed through as None and crashed on `.lower()`.
_garmin_query = "garmin quatix 7x solar watch"
cl = next((p for p in products if _garmin_query in (p.get("title") or "").lower()), None)
if cl:
    print("Caso Garmin:")
    for pair in scored:
        if pair["client_id"] != cl["id"]:
            continue
        print(f"pair {pair['pair_id']} label={pair['label']} score={pair['score']:.4f} (t={pair['score_text']:.4f}, i={pair['score_image']:.4f}) comp={pair['competitor_id']}")

# Overwrite the pairs file with the rescored records, one JSON object per line.
# Records are serialised with ensure_ascii=False, so the file is written
# explicitly as UTF-8; relying on the platform's locale encoding could raise
# UnicodeEncodeError or produce a file that is not valid UTF-8 JSONL.
PAIR_PATH.write_text(
    "\n".join(json.dumps(r, ensure_ascii=False) for r in scored),
    encoding="utf-8",
)
print(f"Pares actualizados en {PAIR_PATH}. Recarga la API con ?nocache=1")




Embeddings imagen guardados en /Users/marc/Documents/Projectes/tfm-product-matching/data/matchings_img_resnet.npy; faltantes: 781
Mean score label=1: 0.5799 | label=0: 0.1884
Caso Garmin:
pair 0 label=1 score=0.7837 (t=0.7750, i=0.8038) comp=831121
pair 2000 label=0 score=0.1312 (t=0.1524, i=0.0817) comp=178374
pair 1 label=1 score=0.7551 (t=0.7145, i=0.8497) comp=1143212
pair 2001 label=0 score=0.1971 (t=0.1943, i=0.2036) comp=574333
Pares actualizados en /Users/marc/Documents/Projectes/tfm-product-matching/data/matchings_pairs.jsonl. Recarga la API con ?nocache=1
