# Inference Notebook Template

Country Level 3 (L3) means the model is given:

* the continent as a cue, and
* an explicit list of candidate countries to choose from.

In [None]:
import os
# Ask PyTorch's CUDA caching allocator to use expandable segments, which reduces
# fragmentation-related out-of-memory errors when input sizes vary between calls.
# This must happen before CUDA is initialised; this cell runs before torch is imported.
os.environ.update({"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True"})

🔨 TO BE MODIFIED 🔨

In [None]:
# TO BE MODIFIED

!pip install transformers==4.48.3 tokenizers==0.21.0

# 📚 Helper: Save a results object (list or dict) to JSON

In [None]:
import json, os

def save_results(data: "dict | list",
                 model_name: str,
                 variant: str,
                 task: str,
                 task_level: str,
                 prompt_level: str,
                 run_count: str,
                 output_dir: str = "/kaggle/working/results"):
    """Write ``data`` as indented JSON to ``<output_dir>/<model_name>/RQ1_..._<run_count>.json``.

    The filename is fully determined by the arguments. Repeated calls with the same
    arguments therefore overwrite the same file, which is how the inference loop
    checkpoints progress after every image.

    The JSON is first written to a sibling ``.tmp`` file and then moved into place
    with ``os.replace``. An interrupted write (kernel restart, session timeout,
    unserialisable value) therefore never leaves a truncated checkpoint behind;
    the previous good file stays intact.

    Args:
        data: Any JSON-serialisable object. The inference loop passes a list of result dicts.
        model_name: Used both as a sub-directory name and inside the filename.
        variant, task, task_level, prompt_level, run_count: Filename components.
        output_dir: Root directory for all result files.

    Returns:
        The path of the written JSON file.
    """
    # Ensure nested directories are created
    model_dir = os.path.join(output_dir, model_name)
    os.makedirs(model_dir, exist_ok=True)

    fname = f"RQ1_{model_name}_{variant}_{task}_{task_level}_{prompt_level}_{run_count}.json"  # Fixed name
    out_path = os.path.join(model_dir, fname)
    tmp_path = out_path + ".tmp"

    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=4)
        # Atomic on POSIX and Windows: readers see either the old or the new file.
        os.replace(tmp_path, out_path)
    except BaseException:
        # Do not leave a half-written temp file around; re-raise for the caller.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

    print(f"✅ Saved results to {out_path}")
    return out_path


🔨 TO BE MODIFIED 🔨

In [None]:
# TO BE MODIFIED

# --- Run parameters: recorded in every result row and in the output filename ---
MODEL_NAME = "Centurio Qwen"
VARIANT = "8B"
PROMPT_LEVEL = "prompt1"
RUN_COUNT = "r1"

# Root folder that is walked recursively to find the input images.
folder_path = "/kaggle/input/streetartdata/StreetArtProject/RQ1"

# --- FIXED (NO MODIFICATION NEEDED): task identifiers for this notebook ---
TASK = "country"
TASK_LEVEL = "L3"

In [None]:
import os

def get_all_file_paths(root_dir, extensions=None):
    """Recursively collect the paths of all files under ``root_dir``.

    Args:
        root_dir: Directory to walk. Sub-directories are visited with an explicit
            stack instead of recursion.
        extensions: Optional suffix or iterable of suffixes (e.g. ``(".jpg", ".png")``),
            matched case-insensitively against the file name. ``None`` (the default)
            keeps every file.

    Returns:
        A sorted list of file paths. ``os.scandir`` yields entries in an arbitrary,
        filesystem-dependent order. Sorting keeps the processing order identical
        across runs, which makes runs comparable and resumable.
    """
    if isinstance(extensions, str):
        extensions = (extensions,)
    wanted = None if extensions is None else tuple(ext.lower() for ext in extensions)

    file_paths = []
    stack = [root_dir]

    while stack:
        current_dir = stack.pop()
        with os.scandir(current_dir) as it:
            for entry in it:
                if entry.is_file():
                    if wanted is None or entry.name.lower().endswith(wanted):
                        file_paths.append(entry.path)
                elif entry.is_dir():
                    stack.append(entry.path)

    return sorted(file_paths)

# Every file under the input root (recursively); the inference loop iterates over this list.
image_paths = get_all_file_paths(folder_path)
print("Found", len(image_paths), "files.")

In [None]:
# image_paths

# Variant: Centurio Qwen

https://huggingface.co/WueNLP/centurio_qwen

## 1️⃣ Prompt

🔨 TO BE MODIFIED 🔨

In [None]:
# TO BE MODIFIED

# Per-image prompt. The inference loop fills the placeholders via str.format:
#   {given_continent} - continent token parsed from the image filename
#   {country_list}    - comma-separated candidate countries for that continent
# NOTE(review): the last sentence repeats itself ("would it improve or harm ...,
# would urban life improve or harm?"). It is kept verbatim here so results remain
# comparable with other 'prompt1' runs. Fix it under a new PROMPT_LEVEL instead.
prompt_template = (
    "\n"
    "This art is from continent **{given_continent}**. Which country is it most likely from? "
    "*(Only choose one from [{country_list}])*\n"
    "\n"
    "Classify it with one word: art, vandalism, protest, decoration, advertisement, or other.\n"
    "\n"
    "Finally, If this particular mural were banned, would it improve or harm urban life in "
    "*this location*, would urban life improve or harm? Briefly justify.\n"
)

## 2️⃣ Load Processor and Model

In [None]:
# from huggingface_hub import notebook_login
# notebook_login()

🔨 TO BE MODIFIED 🔨

In [None]:
# TO BE MODIFIED

from transformers import AutoModelForCausalLM, AutoProcessor
from glob import glob
from PIL import Image
import torch
import os

# Hugging Face repository of the Centurio (Qwen) checkpoint.
repo_id = "WueNLP/centurio_qwen"

# NOTE: trust_remote_code=True executes Python code shipped in the model repo;
# only use it with repositories you trust.
processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)

# Weights are loaded in bfloat16; device_map="auto" lets transformers/accelerate
# decide device placement (infer_img later moves inputs to model.device).
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)


## 3️⃣ Inference

In [None]:
# import json
#
# # 1. Load your JSON file
# with open('/content/drive/MyDrive/StreetArtProject/results/Qwen2.5-VL/3B-Instruct_RQ1_continent_prompt1.json', 'r') as f:
#     data = json.load(f)
#
# # 1. Build a set of annotated paths from your JSON
# json_paths = { entry['image_path'] for entry in data }
#
# # 2. Filter your existing list
# paths_not_in_json = [p for p in image_paths if p not in json_paths]
#
# # Now `paths_not_in_json` contains only those files missing from your JSON annotations.
# print(f"{len(paths_not_in_json)} paths aren’t in the JSON.")

In [None]:
# image_paths = paths_not_in_json
# print(len(image_paths))

In [None]:
import tempfile

🔨 TO BE MODIFIED 🔨

In [None]:
# TO BE MODIFIED

def infer_img(image_path: str, prompt: str, max_new_tokens: int = 1024, temperature: float = 0.3):
    """Run the loaded Centurio model on one image plus a text prompt.

    Uses the module-level ``processor`` and ``model``.

    Args:
        image_path: Path to the image file; it is converted to RGB.
        prompt: User prompt. If it lacks ``<image_placeholder>``, one is prepended
            so the image has a slot in the user turn.
        max_new_tokens: Generation length cap (default matches the original hard-coded value).
        temperature: Passed through to ``model.generate`` (default matches the original).

    Returns:
        The decoded text produced after the prompt, stripped of surrounding whitespace.

    Raises:
        FileNotFoundError: If ``image_path`` does not exist.
    """
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found: {image_path}")

    # Context manager releases the file handle immediately; convert() returns an
    # independent in-memory copy, so the image stays usable after the file is closed.
    with Image.open(image_path) as src:
        image = src.convert("RGB")

    if "<image_placeholder>" not in prompt:
        prompt = "<image_placeholder>\n" + prompt

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt}
    ]

    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    inputs = processor(
        text=[text],
        images=[image],
        return_tensors="pt"
    )

    # Move every tensor to wherever device_map placed the model (no hard-coded
    # "cuda"). Floating tensors (presumably the image pixel values) are cast to
    # bfloat16 to match the weights. Integer tensors such as input_ids keep their dtype.
    for k, v in list(inputs.items()):
        if torch.is_tensor(v):
            v = v.to(model.device)
            if v.is_floating_point():
                v = v.to(torch.bfloat16)
            inputs[k] = v

    # NOTE(review): temperature only has an effect when sampling is enabled
    # (do_sample=True here or in the model's generation_config). Confirm which
    # applies for this checkpoint.
    with torch.inference_mode():
        generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens, temperature=temperature)

    # generate() returns prompt + continuation; drop the prompt tokens before decoding.
    trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)]
    output = processor.batch_decode(trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

    return output.strip()


## 4️⃣ Package & Save to JSON

In [None]:
import os, re
from copy import deepcopy

# Tokens that can occupy the third filename slot but are not city names.
_TIER_TOKEN_RE = re.compile(r'^(?:tier|t)?\d+$', re.IGNORECASE)
_FILENAME_MARKER_RE = re.compile(r'^(?:img|image|photo|scan|picture|pic)\d*$', re.IGNORECASE)
_GROUP_TAGS = {"TOP", "MIDDLE", "LAST"}


def _looks_like_city(token):
    """Return True if a filename token is plausibly a city name rather than an id, marker or tag."""
    return (bool(re.search(r'[A-Za-z]', token))
            and not re.search(r'\d', token)
            and not _TIER_TOKEN_RE.match(token)
            and not _FILENAME_MARKER_RE.match(token)
            and token.upper() not in _GROUP_TAGS)


def extract_location_from_path(path):
    """Parse location information encoded in an image filename.

    Expected filename shapes (underscore-separated; a dash joins the words of a
    multi-word name, e.g. ``South-Africa``, ``North-America``):
      .../dir/Continent_Country_filename.ext
      .../dir/Continent_Country_City_filename.ext

    Args:
        path: Path to the image file.

    Returns:
        ``(continent, country, city_or_None, continent_dir)``.
        The first three values are the raw filename tokens.
        ``continent_dir`` is the nearest ancestor directory whose name equals the
        continent token. If no such ancestor exists, it is the (non-existent) path
        ``<file's dir>/<continent>``.

    Raises:
        ValueError: If ``path`` is empty or the filename has fewer than two tokens.
    """
    if not path:
        raise ValueError("Empty path provided")

    basename = os.path.basename(path)
    stem, _ext = os.path.splitext(basename)

    # Split on underscores only. This is the same convention get_countries() uses;
    # splitting on dashes too turned "North-America_..." into continent "North"
    # and "Africa_South-Africa_..." into country "South". Empty tokens from
    # doubled or leading underscores are dropped.
    tokens = [t for t in stem.split('_') if t]

    if len(tokens) < 2:
        raise ValueError(f"Filename does not contain continent and country tokens: {basename!r}")

    continent, country = tokens[0], tokens[1]
    # The third token is a city only if it looks like a name (letters, no digits,
    # not a tier/image marker or a TOP/MIDDLE/LAST group tag).
    city = tokens[2] if len(tokens) >= 3 and _looks_like_city(tokens[2]) else None

    # Walk up from the file's directory looking for a folder named after the continent.
    abs_parent = os.path.abspath(os.path.dirname(path))
    continent_dir = None
    cur = abs_parent
    while True:
        if os.path.basename(cur) == continent:
            continent_dir = cur
            break
        parent = os.path.dirname(cur)
        if parent == cur:  # reached the filesystem root
            break
        cur = parent

    if continent_dir is None:
        # abs_parent is already absolute, so a plain join yields an absolute path.
        continent_dir = os.path.join(abs_parent, continent)

    return continent, country, city, continent_dir



# --- Data: 9 candidate countries per continent, in three groups of 3: TOP, MIDDLE, LAST ---
# List order matters. When an image's own country is missing from its continent's
# list, get_countries() overwrites the last slot of the group named by the
# filename tag: index 2 (TOP), 5 (MIDDLE) or 8 (LAST, also used when there is no tag).
# TODO: document what TOP/MIDDLE/LAST rank by (not stated in this notebook).
# NOTE(review): some entries (e.g. Gibraltar, Falkland Islands, Anguilla, Montserrat,
# British Virgin Islands) are territories rather than sovereign states. Confirm this is intended.
# Continent keys are matched case/punctuation-insensitively through _norm() in get_countries().
CONTINENT_COUNTRIES = {
    "Africa": [
        "Nigeria", "South Africa", "Egypt",    # TOP (3)
        "Seychelles", "Comoros", "Cape Verde", # MIDDLE (3)
        "Mozambique", "Rwanda", "Botswana"     # LAST (3)
    ],
    "Asia": [
        "China", "India", "Indonesia",                   # TOP (3)
        "United Arab Emirates", "Singapore", "Israel",   # MIDDLE (3)
        "Timor-Leste", "Maldives", "Brunei"              # LAST (3)
    ],
    "Europe": [
        "Russia", "Germany", "United Kingdom",           # TOP (3)
        "Denmark", "Lithuania", "Slovakia",              # MIDDLE (3)
        "Monaco", "Gibraltar", "San Marino"              # LAST (3)
    ],
    "North America": [
        "United States", "Mexico", "Canada",                     # TOP (3)
        "Saint Lucia", "Guyana", "Bahamas",                      # MIDDLE (3)
        "British Virgin Islands", "Anguilla", "Montserrat"       # LAST (3)
    ],
    "South America": [
        "Brazil", "Argentina", "Colombia",               # TOP (3)
        "Venezuela", "Chile", "Ecuador",                 # MIDDLE (3)
        "Paraguay", "Uruguay", "Falkland Islands"        # LAST (3)
    ],
}

# --- Normalization helpers ---
def _norm(s: str) -> str:
    """Build a matching key: lowercase, keep only ASCII letters and digits.

    ``None`` or an empty string yields ``""``. This lets filename tokens such as
    "south-africa" or "SouthAfrica" compare equal to canonical names like "South Africa".
    """
    keep = "abcdefghijklmnopqrstuvwxyz0123456789"
    lowered = (s or '').lower()
    return "".join(ch for ch in lowered if ch in keep)

def get_countries(continent: str, image_path: str = None):
    """Candidate-country list for a continent, guaranteed to contain the image's own country.

    Args:
        continent: Continent name, matched against CONTINENT_COUNTRIES keys via
            ``_norm`` (so "africa", "North-America" and "NorthAmerica" all work).
        image_path: Optional path whose filename looks like
            ``Continent_Country_..._[TOP|MIDDLE|LAST].ext``.
            The second underscore-separated token is the country; dashes inside
            it become spaces and the result is title-cased. If that country is
            not already in the list, it replaces the entry at index 2 (TOP),
            5 (MIDDLE) or 8 (LAST). Index 5 is used when there is no tag.

    Returns:
        ``(countries, info)``. ``countries`` is a fresh copy of the continent's list
        with any replacement applied. ``info`` has the keys used_continent,
        filename_country, tag, matched, added, removed and index.

    Raises:
        ValueError: If ``continent`` matches no known continent.
    """
    wanted_continent = _norm(continent)
    cont_key = next((k for k in CONTINENT_COUNTRIES if _norm(k) == wanted_continent), None)
    if not cont_key:
        raise ValueError(f"Unknown continent: {continent!r}")

    countries = deepcopy(CONTINENT_COUNTRIES[cont_key])
    info = dict(used_continent=cont_key, filename_country=None, tag=None,
                matched=None, added=None, removed=None, index=None)

    if not image_path:
        return countries, info

    stem = os.path.splitext(os.path.basename(image_path))[0]
    parts = stem.split('_')
    if len(parts) < 2:
        # Not enough tokens to read a country from the filename.
        return countries, info

    last_token = parts[-1].upper()
    tag = last_token if last_token in ("TOP", "MIDDLE", "LAST") else None
    display_country = parts[1].replace('-', ' ').title()
    info["filename_country"] = display_country
    info["tag"] = tag

    wanted_country = _norm(display_country)
    existing = next((c for c in countries if _norm(c) == wanted_country), None)
    if existing is not None:
        info["matched"] = existing
        return countries, info

    # Country missing: overwrite the last slot of the tagged group (clamped to the list).
    slot = {"TOP": 2, "MIDDLE": 5, "LAST": 8}.get(tag, 5)
    slot = max(0, min(slot, len(countries) - 1))
    removed = countries.pop(slot)
    countries.insert(slot, display_country)

    info.update(added=display_country, removed=removed, index=slot)
    return countries, info


In [None]:
import re
import json
import os
import torch

def strip_code_fence(s: str) -> str:
    """Remove a surrounding Markdown code fence and trim whitespace.

    Strips a leading triple backtick (with an optional, case-insensitive ``json``
    language tag) and a trailing triple backtick. Other language tags are left in
    place. Non-string input is returned unchanged.
    """
    if not isinstance(s, str):
        return s
    without_open = re.sub(r"^```(?:\s*json)?\s*", "", s.strip(), flags=re.I)
    without_close = re.sub(r"```$", "", without_open)
    return without_close.strip()

# Leftmost "{...}" or "[...]" span, extended greedily to the last matching closer.
# `[\s\S]` matches any character, newlines included, exactly like the previous
# `(?:.|\s)`. That alternation's branches overlap on spaces, which caused
# exponential backtracking on long, whitespace-heavy model replies containing an
# unclosed "{" or "[".
_JSON_SPAN_RE = re.compile(r"(\{[\s\S]*\}|\[[\s\S]*\])")


def try_parse_json_from_string(s: str):
    """Try to parse JSON from a model reply.

    First parses the whole reply (after stripping a Markdown code fence). If that
    fails, it parses the span from the first "{" or "[" to the last matching closer.

    Args:
        s: Model output. Bytes are decoded as UTF-8; other non-string values are rejected.

    Returns:
        ``(parsed_obj, None)`` on success. On failure, ``(None, error_message)``;
        the function never raises for bad input.
    """
    if isinstance(s, (bytes, bytearray)):
        s = s.decode("utf-8", errors="replace")
    if not isinstance(s, str):
        return None, f"expected a string, got {type(s).__name__}"

    cleaned = strip_code_fence(s)
    try:
        return json.loads(cleaned), None
    except json.JSONDecodeError:
        pass

    # Fallback: the reply may wrap JSON in prose.
    m = _JSON_SPAN_RE.search(cleaned)
    if m is None:
        return None, "no JSON found or JSON invalid"
    try:
        return json.loads(m.group(1)), None
    except json.JSONDecodeError as e:
        return None, f"JSON decode failed for extracted substring: {e}"

def normalize(parsed):
    """Give parsed JSON a stable top-level shape.

    - A one-element list holding a dict is unwrapped to that dict.
    - Any other list is wrapped as ``{"json_list": parsed}``.
    - Every non-list value is returned unchanged.
    """
    if not isinstance(parsed, list):
        return parsed
    single_dict = len(parsed) == 1 and isinstance(parsed[0], dict)
    return parsed[0] if single_dict else {"json_list": parsed}

def _infer_with_oom_retry(image_path, prompt, retry_size=(448, 448)):
    """Run infer_img on one image; on CUDA OOM, retry once on a downscaled copy.

    Returns the raw model output, or ``None`` if inference failed. Errors are
    printed rather than raised, so one bad image does not stop the batch.
    """
    try:
        raw_output = infer_img(image_path, prompt)
        print(" → infer_img() succeeded (full-res).")
        return raw_output
    except RuntimeError as e:
        # torch.cuda.OutOfMemoryError subclasses RuntimeError; match on the message.
        if "cuda out of memory" not in str(e).lower():
            print(f" ❌ RuntimeError on {image_path}: {e}")
            return None
        # OOM: fall through to the retry below.
    except Exception as e:
        print(f" ❌ Error processing {image_path}: {e}")
        return None

    # Only reached after an OOM. The retry deliberately runs *outside* the except
    # clause: while an exception is being handled, its traceback keeps the failed
    # forward pass's tensors alive. empty_cache() could not release that memory, and
    # the retry would be more likely to run out of memory again.
    torch.cuda.empty_cache()
    width, height = retry_size
    print(f" ⚠️ OOM on full-res for {image_path}. Resizing to {width}x{height} and retrying…")
    tmp_path = None
    try:
        with Image.open(image_path) as src:
            small = src.convert("RGB").resize((width, height))
        # delete=False so infer_img can reopen the file by name; removed in `finally`.
        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
            tmp_path = tmp.name
        small.save(tmp_path, format="JPEG")
        raw_output = infer_img(tmp_path, prompt)
        print(" → infer_img() succeeded (resized).")
        return raw_output
    except Exception as e2:
        print(f" ❌ Failed even after resize: {e2}")
        return None
    finally:
        if tmp_path is not None and os.path.exists(tmp_path):
            os.remove(tmp_path)


def _to_output_value(raw_output):
    """Turn the raw model reply into the single value stored under "output".

    - ``None``: no output; stays ``None``.
    - dict/list: normalized with ``normalize``.
    - str: parsed JSON (normalized) if the string contains valid JSON, else the
      string unchanged.
    - anything else: kept as-is.
    """
    if raw_output is None:
        print(" → No output from model.")
        return None
    if isinstance(raw_output, (dict, list)):
        print(f" → Model returned {type(raw_output).__name__}; stored structured output.")
        return normalize(raw_output)
    if isinstance(raw_output, str):
        parsed, perr = try_parse_json_from_string(raw_output)
        if parsed is not None:
            print(" → Parsed JSON from model string; stored structured output.")
            return normalize(parsed)
        # keep the original string (no duplication)
        print(f" → Could not parse JSON from model string: {perr!s}. Keeping raw string as output.")
        return raw_output
    # for bytes, numbers, etc., keep as-is
    print(f" → Model returned type {type(raw_output).__name__}; keeping as output.")
    return raw_output


results = []

for i, image_path in enumerate(image_paths):
    print(f"\nProcessing {i+1}/{len(image_paths)}: {image_path}\n")

    # extract location components (skip file if extraction fails)
    try:
        continent, country, city, continent_dir = extract_location_from_path(image_path)
    except Exception as e:
        print(f" ⚠️ Skipping {image_path!r}: {e}")
        continue

    # Candidate countries for the continent; the file's own country is swapped in if missing.
    try:
        countries, info = get_countries(continent, image_path)
        country_list = ", ".join(countries)
    except Exception as e:
        print(f" ⚠️ Could not list countries for continent {continent!r}: {e}")
        country_list = ""

    # The template always renders the brackets. When country_list is empty (unknown
    # continent), the prompt therefore reads "... one from [])*"; such rows can be
    # identified from the stored prompt.
    prompt = prompt_template.format(
        given_continent=continent,
        country_list=country_list
    )
    print(f" Prompt: {prompt}\n")

    raw_output = _infer_with_oom_retry(image_path, prompt)
    output_value = _to_output_value(raw_output)

    result = {
        "image_path":   image_path,
        "model":        MODEL_NAME,
        "variant":      VARIANT,
        "task":         TASK,
        "task_level":   TASK_LEVEL,
        "prompt_level": PROMPT_LEVEL,
        "run_count":    RUN_COUNT,
        "prompt":       prompt,
        "output":       output_value
    }
    results.append(result)
    print(f"\nOutput:\n{output_value}\n")

    # Checkpoint after every image so progress survives a crash or timeout.
    # NOTE(review): the whole list is rewritten each time, and a re-run with the same
    # RUN_COUNT overwrites the earlier file (including after filtering image_paths to resume).
    try:
        save_results(results, MODEL_NAME, VARIANT, TASK, TASK_LEVEL, PROMPT_LEVEL, RUN_COUNT)
    except Exception as e:
        print(f" ❌ Failed to save results: {e}")

    # always clear GPU cache
    torch.cuda.empty_cache()
    print(f"Total files done: {i+1}")
    print("\n######################################################\n")
