In [None]:
import os
import json
from pathlib import Path
from typing import List, Dict, Tuple
import requests
from tqdm import tqdm
import numpy as np
import pandas as pd
from time import sleep

from PIL import Image

import torch
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
import spacy
from sentence_transformers import SentenceTransformer

# --- Pipeline configuration -------------------------------------------------
# Folder scanned for input images (passed to process_images in the __main__ guard).
IMAGE_FOLDER = "/content/drive/MyDrive/test image"
# Directory that receives the embeddings .npy and the metadata .csv.
OUTPUT_DIR = "conceptnet_encodings"
# Number of related ConceptNet terms requested per concept by process_images.
TOP_N_CONCEPTNET = 10
# Checkpoint name handed to SentenceTransformer for embedding concept bags.
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
# Checkpoint name used for the captioning model, image processor and tokenizer.
CAPTION_MODEL_NAME = "nlpconnect/vit-gpt2-image-captioning"
# Use the GPU when torch reports CUDA as available, otherwise the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Batch size passed to embed_model.encode in embed_texts.
BATCH_SIZE = 8
# JSON file the ConceptNet lookup cache is loaded from and saved to.
CONCEPTNET_CACHE_FILE = "conceptnet_cache.json"
# Seconds to sleep after each uncached ConceptNet request.
REQUESTS_SLEEP = 0.1

print("Loading captioning model...")
# Captioning model, image preprocessor and tokenizer all come from the same
# checkpoint; only the model is moved to DEVICE.
caption_model = VisionEncoderDecoderModel.from_pretrained(CAPTION_MODEL_NAME).to(DEVICE)
caption_processor = ViTImageProcessor.from_pretrained(CAPTION_MODEL_NAME)
caption_tokenizer = AutoTokenizer.from_pretrained(CAPTION_MODEL_NAME)

# Fall back to the EOS token for padding when the tokenizer defines no pad token.
if caption_tokenizer.pad_token is None:
    caption_tokenizer.pad_token = caption_tokenizer.eos_token
    caption_tokenizer.pad_token_id = caption_tokenizer.eos_token_id

# Best-effort alignment of the model config with the tokenizer. A failure here
# is reported instead of being swallowed, but does not abort the script.
try:
    caption_model.config.pad_token_id = caption_tokenizer.pad_token_id
    # Compare against None explicitly: a token id of 0 is valid and must not
    # be replaced by the EOS fallback.
    _decoder_start_id = getattr(caption_tokenizer, "cls_token_id", None)
    if _decoder_start_id is None:
        _decoder_start_id = caption_tokenizer.eos_token_id
    caption_model.config.decoder_start_token_id = _decoder_start_id
except Exception as exc:
    print(f"Warning: could not align caption model config with tokenizer: {exc}")

# Generation settings forwarded to caption_model.generate by generate_caption.
GEN_KWARGS = {
    "max_length": 30,
    "num_beams": 4,
    "early_stopping": True,
    "no_repeat_ngram_size": 2,
}

def generate_caption(image: Image.Image) -> str:
    """Return the caption text generated for one PIL image.

    Non-RGB images are converted to RGB before preprocessing. Generation runs
    without gradient tracking using the module-level ``GEN_KWARGS``; the first
    generated sequence is decoded with special tokens skipped and stripped of
    surrounding whitespace.
    """
    rgb_image = image if image.mode == "RGB" else image.convert("RGB")
    encoded = caption_processor(images=rgb_image, return_tensors="pt")
    model_input = encoded.pixel_values.to(DEVICE)
    with torch.no_grad():
        generated = caption_model.generate(model_input, **GEN_KWARGS)
    text = caption_tokenizer.decode(generated[0], skip_special_tokens=True)
    return text.strip()

print("Loading spaCy model...")
# Module-level spaCy pipeline; extract_concept_words reads noun chunks,
# part-of-speech tags and lemmas from the documents it produces.
nlp = spacy.load("en_core_web_sm")

def extract_concept_words(caption: str) -> List[str]:
    """Return the unique concept strings found in *caption*, in first-seen order.

    Candidates are gathered in two passes over the spaCy document: first the
    lower-cased, stripped text of every noun chunk, then the lower-cased lemma
    of every NOUN/PROPN token (the lemmas "image", "photo" and "picture" are
    excluded in that second pass only). Candidates of length one or less are
    dropped.
    """
    ignored_lemmas = {"image", "photo", "picture"}
    doc = nlp(caption)

    chunk_texts = [chunk.text.strip().lower() for chunk in doc.noun_chunks]
    candidates = [text for text in chunk_texts if len(text) > 1]

    for token in doc:
        if token.pos_ not in {"NOUN", "PROPN"}:
            continue
        lemma = token.lemma_.lower().strip()
        if len(lemma) > 1 and lemma not in ignored_lemmas:
            candidates.append(lemma)

    # dict preserves insertion order, so this de-duplicates while keeping the
    # first occurrence of each candidate.
    return list(dict.fromkeys(candidates))

# Base URL that conceptnet_related prefixes to its "/related/c/en/<concept>" path.
CONCEPTNET_BASE = "https://api.conceptnet.io"

def load_conceptnet_cache(path: str) -> dict:
    """Load the ConceptNet lookup cache from the JSON file at *path*.

    Returns an empty dict when the file does not exist, cannot be read, is not
    valid JSON, or does not contain a JSON object at the top level. Problems
    with an existing file are reported with a printed warning rather than
    silently ignored, so a corrupted cache is noticed.
    """
    if not os.path.exists(path):
        return {}
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as exc:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"Warning: ignoring unreadable ConceptNet cache {path}: {exc}")
        return {}
    if not isinstance(data, dict):
        # Callers index and assign by concept key, so anything but a dict
        # would fail later in a confusing place.
        print(f"Warning: ignoring ConceptNet cache {path}: top-level JSON value is not an object")
        return {}
    return data

def save_conceptnet_cache(path: str, cache: dict):
    """Write *cache* to *path* as indented UTF-8 JSON, replacing the file atomically.

    The data is first written to ``path + ".tmp"`` and then moved over *path*
    with ``os.replace``, so an existing cache file is never left half-written.
    If writing or replacing fails, the temporary file is removed and the
    original exception is re-raised.
    """
    tmp = path + ".tmp"
    try:
        with open(tmp, "w", encoding="utf-8") as f:
            json.dump(cache, f, ensure_ascii=False, indent=2)
        os.replace(tmp, path)
    except BaseException:
        # Do not leave a partial temp file behind on failure or interruption.
        try:
            os.remove(tmp)
        except OSError:
            pass
        raise

# In-memory ConceptNet cache seeded from CONCEPTNET_CACHE_FILE; conceptnet_related
# reads and fills it, and process_images writes it back to disk.
cn_cache = load_conceptnet_cache(CONCEPTNET_CACHE_FILE)

# Concepts whose lookup failed during this run. They are not retried (each
# retry could cost a full request timeout) and, unlike cn_cache entries, they
# are never persisted, so a transient outage does not poison the on-disk cache.
_CN_FAILED_KEYS = set()


def conceptnet_related(concept: str, topn: int = 10) -> List[Tuple[str, float]]:
    """Return up to *topn* ConceptNet terms related to *concept*, with caching.

    The concept is stripped and lower-cased to form the cache key. On a cache
    miss the ``/related/c/en/<concept>`` endpoint is queried (spaces replaced
    by underscores, results filtered to ``/c/en``), and every returned item is
    parsed, sorted by weight descending and stored in ``cn_cache`` in full, so
    later calls with a larger *topn* are served correctly from the cache.

    Failed lookups (network/HTTP errors or an unexpected response shape) print
    a warning and return an empty list. They are remembered only for the
    current run and are not written into ``cn_cache``.

    Returns:
        A list of ``(term, weight)`` tuples, highest weight first.
    """
    key = concept.strip().lower()
    if key in cn_cache:
        # Entries loaded from JSON are lists; normalise so callers always get
        # tuples, as the return annotation promises.
        return [tuple(pair) for pair in cn_cache[key][:topn]]
    if key in _CN_FAILED_KEYS:
        return []

    concept_clean = key.replace(" ", "_")
    url = f"{CONCEPTNET_BASE}/related/c/en/{concept_clean}"
    params = {"filter": "/c/en"}
    related = None
    try:
        resp = requests.get(url, params=params, timeout=8)
        resp.raise_for_status()
        payload = resp.json()
        parsed = []
        for item in payload.get("related", []):
            # Prefer an explicit surface text; otherwise use the last path
            # segment of the item's "@id".
            surface = item.get("surfaceText") or item.get("@id", "").split("/")[-1]
            if isinstance(surface, str):
                surface = surface.replace("_", " ").lower()
                parsed.append((surface, float(item.get("weight", 0.0))))
        # Sort the complete result before any truncation.
        parsed.sort(key=lambda pair: -pair[1])
        related = parsed
    except (requests.RequestException, ValueError, TypeError, AttributeError) as exc:
        print(f"Warning: ConceptNet lookup failed for '{key}': {exc}")

    # Throttle after every real request, successful or not.
    sleep(REQUESTS_SLEEP)

    if related is None:
        _CN_FAILED_KEYS.add(key)
        return []
    cn_cache[key] = related
    return related[:topn]

print("Loading embedding model...")
# Module-level sentence-embedding model placed on DEVICE; used by embed_texts.
embed_model = SentenceTransformer(EMBEDDING_MODEL_NAME, device=DEVICE)

def build_concept_bag(concepts: List[str], cn_expand_topn: int = 8) -> str:
    """Join each concept and its ConceptNet neighbours into one "; "-separated string.

    Every concept is followed by the terms returned for it by
    ``conceptnet_related`` (at most *cn_expand_topn* each; weights are
    ignored). Duplicates, non-string entries and empty strings are dropped,
    keeping the first occurrence of each term.
    Example output: "dog; pet; canine; animal; fur; leash"
    """
    terms = []
    for concept in concepts:
        terms.append(concept)
        neighbours = conceptnet_related(concept, topn=cn_expand_topn)
        terms.extend(term for term, _weight in neighbours)

    # An insertion-ordered dict serves as the "seen" set and the ordered result.
    unique = {}
    for term in terms:
        if term in unique:
            continue
        if isinstance(term, str) and len(term) > 0:
            unique[term] = None
    return "; ".join(unique)

def embed_texts(texts: List[str]) -> np.ndarray:
    """Encode *texts* with the module-level embedding model and return the result.

    Encoding uses ``BATCH_SIZE`` as the batch size, requests NumPy output and
    disables the progress bar.
    """
    encode_options = {
        "batch_size": BATCH_SIZE,
        "convert_to_numpy": True,
        "show_progress_bar": False,
    }
    return embed_model.encode(texts, **encode_options)

def find_images(folder: str, exts=frozenset({".jpg", ".jpeg", ".png", ".bmp", ".tiff"})) -> List[Path]:
    """Return the sorted image files found recursively under *folder*.

    Args:
        folder: Directory to search, including all sub-directories.
        exts: Collection of lower-case file suffixes (with leading dot) to
            accept. Suffixes of candidate files are lower-cased before the
            comparison, so matching is case-insensitive on the file side.

    Returns:
        Sorted list of paths to regular files whose suffix is in *exts*.
        Directories are excluded even if their name ends in a listed suffix.
    """
    root = Path(folder)
    return sorted(
        candidate
        for candidate in root.rglob("*")
        if candidate.suffix.lower() in exts and candidate.is_file()
    )

def process_images(image_folder: str, output_dir: str):
    """Caption every image in *image_folder* and save concept-bag embeddings.

    For each image found by ``find_images`` the function generates a caption,
    extracts concept words, expands them through ConceptNet into a concept
    bag, and records the result. Afterwards it writes to *output_dir*
    (created if missing):

    * ``image_concept_embeddings.npy`` - one embedding row per processed image
    * ``image_concept_metadata.csv`` - filename, caption, extracted concepts
      and concept bag, in the same row order as the embeddings

    Images that cannot be opened are skipped and counted. Images whose
    captioning fails are kept with an empty caption and empty concept bag.
    The ConceptNet cache is written to ``CONCEPTNET_CACHE_FILE`` before
    embedding starts, so completed lookups survive an embedding failure.
    Nothing is written when no image could be processed.
    """
    outp = Path(output_dir)
    outp.mkdir(parents=True, exist_ok=True)
    images = find_images(image_folder)
    print(f"Found {len(images)} images in {image_folder}")

    records = []
    skipped = 0

    for img_path in tqdm(images, desc="Processing images"):
        try:
            img = Image.open(img_path)
        except Exception as e:
            print(f"Skipping {img_path}: cannot open ({e})")
            skipped += 1
            continue

        # Image.open keeps the underlying file open; the context manager
        # releases it once captioning is done so handles do not accumulate
        # over large folders.
        with img:
            try:
                caption = generate_caption(img)
            except Exception as e:
                print(f"Warning: failed to caption {img_path}: {e}")
                caption = ""

        concepts = extract_concept_words(caption) if caption else []
        concept_bag = build_concept_bag(concepts, cn_expand_topn=TOP_N_CONCEPTNET) if concepts else ""

        records.append({
            "filename": str(img_path),
            "caption": caption,
            "extracted_concepts": "|".join(concepts),
            "concept_bag": concept_bag,
        })

    if not records:
        print("No concept bags to embed.")
        return

    # Persist the network lookups first: they are the slow part and should not
    # be lost if embedding or writing the outputs fails.
    save_conceptnet_cache(CONCEPTNET_CACHE_FILE, cn_cache)

    print("Embedding concept bags...")
    # One text per record keeps embedding rows aligned with the CSV rows.
    embeddings = embed_texts([record["concept_bag"] for record in records])
    np.save(outp / "image_concept_embeddings.npy", embeddings)
    df = pd.DataFrame(records)
    df.to_csv(outp / "image_concept_metadata.csv", index=False)

    print(f"Saved embeddings ({embeddings.shape}) and metadata to {outp}. Skipped {skipped} files.")

# Script entry point: run the full pipeline on the configured folder and
# write the results to the configured output directory.
if __name__ == "__main__":
    process_images(IMAGE_FOLDER, OUTPUT_DIR)