In [1]:
!pip install -U transformers accelerate bitsandbytes opencv-python -q

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import json
import warnings
import sys
import time
warnings.filterwarnings('ignore')

from tqdm import tqdm
from pathlib import Path
from collections import Counter, defaultdict
from PIL import Image

In [4]:
PROJECT_ROOT = Path('/content/drive/Othercomputers/my_notebook/lion_final_pro_multimodal-anomaly-report-generation')

In [None]:
import os
import sys
import json
import time
import cv2
import numpy as np
from pathlib import Path
from datetime import datetime
from tqdm import tqdm

# 1. 경로 설정
PROJECT_ROOT = Path('/content/drive/Othercomputers/my_notebook/lion_final_pro_multimodal-anomaly-report-generation')
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

try:
    from scripts.eval_llm_baseline import get_llm_client
except ImportError:
    print("Error: 'scripts.eval_llm_baseline' 경로를 찾을 수 없습니다.")
    sys.exit(1)

# --- 유틸리티 함수 영역 ---

def get_product_name_from_path(img_path, data_root):
    """경로에서 제품명 추출"""
    try:
        rel_path = img_path.relative_to(data_root)
        parts = rel_path.parts
        if len(parts) > 1:
            return parts[1].replace('_', ' ').title()
    except: pass
    return "Unknown Product"

def get_similar_good_samples(target_img_path, data_root, n=1):
    """유사 정상 이미지 자동 검색 (RAG)"""
    try:
        rel_parts = target_img_path.relative_to(data_root).parts
        category, product = rel_parts[0], rel_parts[1]
        good_samples_dir = data_root / category / product / "train" / "good"

        if not good_samples_dir.exists(): return []

        target_img = cv2.imread(str(target_img_path), cv2.IMREAD_GRAYSCALE)
        target_hist = cv2.calcHist([target_img], [0], None, [256], [0, 256])
        cv2.normalize(target_hist, target_hist, 0, 1, cv2.NORM_MINMAX)

        similarities = []
        for sample_path in list(good_samples_dir.glob("*.png"))[:20]:
            sample_img = cv2.imread(str(sample_path), cv2.IMREAD_GRAYSCALE)
            if sample_img is None: continue
            sample_hist = cv2.calcHist([sample_img], [0], None, [256], [0, 256])
            cv2.normalize(sample_hist, sample_hist, 0, 1, cv2.NORM_MINMAX)
            score = cv2.compareHist(target_hist, sample_hist, cv2.HISTCMP_CORREL)
            similarities.append((str(sample_path), score))

        similarities.sort(key=lambda x: x[1], reverse=True)
        return [s[0] for s in similarities[:n]]
    except: return []

def generate_visual_files(img_full_path, data_root, bboxes):
    """시각화 파일(Heatmap, Overlay) 생성"""
    img = cv2.imread(str(img_full_path))
    if img is None: return "error", "error"

    output_root = PROJECT_ROOT / "llm_output"
    rel_path = img_full_path.relative_to(data_root)
    target_sub_path = Path(*[p for p in rel_path.parts if p not in ['test', 'train']])

    h_save_path = output_root / "anomaly_heatmap" / target_sub_path
    o_save_path = output_root / "overlay" / target_sub_path
    h_save_path.parent.mkdir(parents=True, exist_ok=True)
    o_save_path.parent.mkdir(parents=True, exist_ok=True)

    h, w = img.shape[:2]
    mask = np.zeros((h, w), dtype=np.float32)
    for box in bboxes:
        try:
            x1, y1, x2, y2 = [int(c * w / 1000) if max(box) <= 1000 else int(c) for c in box]
            cv2.rectangle(mask, (x1, y1), (x2, y2), 1.0, -1)
        except: continue

    heatmap = cv2.applyColorMap(np.uint8(255 * cv2.GaussianBlur(mask, (51, 51), 0)), cv2.COLORMAP_JET)
    cv2.imwrite(str(h_save_path), heatmap)
    overlay = img.copy()
    red_mask = np.zeros_like(img); red_mask[mask > 0] = [0, 0, 255]
    cv2.addWeighted(red_mask, 0.4, overlay, 0.6, 0, overlay)
    cv2.imwrite(str(o_save_path), overlay)

    return str(h_save_path), str(o_save_path)

# --- 메인 실행 엔진 ---

def run_batch_inference(target_dir_rel="GoodsAD", model_name="llava"):
    client = get_llm_client(model_name)
    data_root = PROJECT_ROOT / "dataset" / "MMAD"
    jsonl_path = PROJECT_ROOT / "llm_output" / "live_reports.jsonl"
    jsonl_path.parent.mkdir(parents=True, exist_ok=True)

    test_images = [img for img in list(data_root.rglob("*.png")) + list(data_root.rglob("*.jpg")) if 'test' in img.parts]
    print(f"[{model_name}] 분석 시작: 총 {len(test_images)}개")

    with open(jsonl_path, "a", encoding="utf-8") as f_jsonl:
        for img_path in tqdm(test_images, desc="Progress"):
            product_info = get_product_name_from_path(img_path, data_root)
            auto_few_shots = get_similar_good_samples(img_path, data_root, n=1)

            start_t = time.time()
            prompt = f"Analyze this {product_info} image and return strictly in JSON format with keys: is_anomaly, anomaly_type, description, danger_score, ad_score, bboxes."

            try:
                # 1. 추론 및 응답 획득
                raw_res = client.generate_answers([str(img_path)], prompt, few_shot_paths=auto_few_shots)
                answer_text = raw_res[0] if isinstance(raw_res, list) else str(raw_res)

                # 2. JSON 추출 및 이중 파싱 대응
                start_idx, end_idx = answer_text.find('{'), answer_text.rfind('}')
                if start_idx == -1: continue

                json_str = answer_text[start_idx:end_idx+1]
                data = json.loads(json_str)
                if isinstance(data, str): data = json.loads(data) # 'str' object has no attribute 'keys' 방지
                if not isinstance(data, dict): continue

                # 3. 시각화 및 리포트 작성
                h_path, o_path = generate_visual_files(img_path, data_root, data.get("bboxes", []))
                report = {
                    "original_path": str(img_path),
                    "anomaly_heatmap": h_path,
                    "overlay": o_path,
                    "inference_time": f"{time.time() - start_t:.2f}s",
                    "product_info": product_info,
                    "anomaly_info": data.get("anomaly_type", "N/A"),
                    "danger_score": data.get("danger_score", 0),
                    "ad_score": data.get("ad_score", 0.0),
                    "date": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                }

                # 4. 실시간 기록
                f_jsonl.write(json.dumps(report, ensure_ascii=False) + "\n")
                f_jsonl.flush()

            except Exception as e:
                print(f"\n[Error] {img_path.name}: {e}")
                continue

    return jsonl_path

if __name__ == "__main__":
    run_batch_inference()

[llava] 분석 시작: 총 6280개


Progress:   0%|          | 1/6280 [00:23<40:42:16, 23.34s/it]


[Error] 002.png: 'str' object has no attribute 'keys'


Progress:   0%|          | 2/6280 [00:25<19:16:02, 11.05s/it]


[Error] 003.png: 'str' object has no attribute 'keys'


Progress:   0%|          | 3/6280 [00:28<12:41:52,  7.28s/it]


[Error] 001.png: 'str' object has no attribute 'keys'


Progress:   0%|          | 4/6280 [00:31<9:38:05,  5.53s/it] 


[Error] 000.png: 'str' object has no attribute 'keys'


Progress:   0%|          | 5/6280 [00:34<7:50:52,  4.50s/it]


[Error] 004.png: 'str' object has no attribute 'keys'


Progress:   0%|          | 6/6280 [00:36<6:44:05,  3.86s/it]


[Error] 005.png: 'str' object has no attribute 'keys'


Progress:   0%|          | 7/6280 [00:39<6:03:33,  3.48s/it]


[Error] 007.png: 'str' object has no attribute 'keys'


Progress:   0%|          | 8/6280 [00:42<5:42:16,  3.27s/it]


[Error] 006.png: 'str' object has no attribute 'keys'


Progress:   0%|          | 9/6280 [00:46<6:01:51,  3.46s/it]


[Error] 008.png: 'str' object has no attribute 'keys'


Progress:   0%|          | 10/6280 [00:49<6:04:25,  3.49s/it]


[Error] 009.png: 'str' object has no attribute 'keys'


Progress:   0%|          | 11/6280 [00:52<5:35:19,  3.21s/it]


[Error] 002.png: 'str' object has no attribute 'keys'


Progress:   0%|          | 12/6280 [00:55<5:37:16,  3.23s/it]


[Error] 003.png: 'str' object has no attribute 'keys'


Progress:   0%|          | 13/6280 [00:58<5:33:05,  3.19s/it]


[Error] 001.png: 'str' object has no attribute 'keys'


Progress:   0%|          | 14/6280 [01:00<5:02:02,  2.89s/it]


[Error] 000.png: 'str' object has no attribute 'keys'
