In [0]:
from openai import AzureOpenAI
from PIL import Image
import base64, json, re, os
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer, util
from io import BytesIO

# ✅ AzureOpenAI client 설정
# Azure OpenAI client used by extract_tags_from_image for the vision chat
# completion. Credentials are read from the environment so that secrets never
# have to be stored in the notebook source; when a variable is unset the value
# falls back to the empty string that used to be hard-coded here.
client = AzureOpenAI(
    api_key=os.environ.get("AZURE_OPENAI_API_KEY", ""),
    api_version="2024-05-01-preview",
    azure_endpoint=os.environ.get("AZURE_OPENAI_ENDPOINT", ""),
)

# ✅ 이미지 검증 및 처리
def validate_image_data(image_data):
    """Decode raw image bytes and re-encode them as an RGB JPEG.

    Args:
        image_data: Bytes of an image file in any format PIL can open.

    Returns:
        The JPEG-encoded bytes (quality 90), or ``None`` when the data cannot
        be opened, converted or saved. The failure reason is printed.
    """
    try:
        picture = Image.open(BytesIO(image_data))
        # JPEG output is always written from an RGB image, so any other mode
        # is converted first.
        rgb_picture = picture if picture.mode == "RGB" else picture.convert("RGB")
        buffer = BytesIO()
        rgb_picture.save(buffer, format="JPEG", quality=90)
    except Exception as e:
        print(f"[검증 실패] 이미지 오류: {e}")
        return None
    else:
        return buffer.getvalue()

# ✅ 이미지 기반 태그 추출
# Tags returned whenever the model call or the response parsing fails.
_FALLBACK_TAGS = {
    "상황": "휴식", "감성": "편안한", "시간대": "오후",
    "스타일": "잔잔한", "날씨": "맑은날", "계절": "봄",
}


def _parse_tag_response(raw):
    """Return the JSON object embedded in the model reply as a dict.

    Only the text between the first ``{`` and the last ``}`` is parsed, so
    Markdown code fences (with any language tag) and surrounding prose are
    ignored instead of breaking the parse.

    Raises:
        ValueError: If the reply is empty, contains no JSON object, holds
            malformed JSON (``json.JSONDecodeError`` is a ``ValueError``), or
            the parsed value is not a dict.
    """
    if not raw:
        raise ValueError("모델 응답이 비어 있습니다.")
    start = raw.find("{")
    end = raw.rfind("}")
    if start == -1 or end < start:
        raise ValueError("모델 응답에서 JSON 객체를 찾을 수 없습니다.")
    tags = json.loads(raw[start:end + 1])
    if not isinstance(tags, dict):
        raise ValueError("모델 응답이 JSON 객체가 아닙니다.")
    return tags


def extract_tags_from_image(b64_image):
    """Ask the vision model for music-recommendation tags describing an image.

    Args:
        b64_image: Base64-encoded image; it is sent to the model as a
            ``data:image/jpeg;base64,...`` URL.

    Returns:
        A dict parsed from the model's JSON reply. The prompt's example uses
        the keys 상황, 감성, 시간대, 스타일, 날씨 and 계절, but the reply is
        returned as-is without checking for them. On any failure (malformed
        base64, API error, unparsable reply) the reason is printed and a fresh
        copy of the default tag set is returned instead of raising.
    """
    prompt = """
    이 이미지를 보고 음악 추천에 필요한 아래 태그들을 각각 하나씩 정확히 추출해줘.
    {"상황":"산책","감성":"기분전환","시간대":"아침","스타일":"시원한","날씨":"맑은날","계절":"봄"}
    """

    try:
        # Sanity check only: malformed base64 raises here, before an API call
        # is spent on it. The decoded bytes themselves are not needed.
        base64.b64decode(b64_image)

        response = client.chat.completions.create(
            model="team2-gpt",
            messages=[
                {"role": "system", "content": "너는 이미지를 보고 음악 분위기를 분류해주는 AI야."},
                {"role": "user", "content": [
                    {"type": "text", "text": prompt},
                    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64_image}"}},
                ]},
            ],
            temperature=0.2,
            max_tokens=500,
        )
        return _parse_tag_response(response.choices[0].message.content)

    except Exception as e:
        # Best effort by design: the caller always gets a usable tag dict.
        print(f"[태그 추출 실패] 기본값 반환: {e}")
        return dict(_FALLBACK_TAGS)

# ✅ 모델 및 데이터 로딩
try:
    # Sentence encoder for the tag text; runs on the GPU when torch can see
    # one, otherwise on the CPU.
    model = SentenceTransformer(
        "sentence-transformers/all-MiniLM-L6-v2",
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    # Song catalogue with precomputed embeddings. final_recommend reads the
    # "장르", "곡명" and "embedding" columns plus the tag columns it returns.
    # NOTE(review): unpickling can execute arbitrary code, so this file must
    # only ever come from a trusted location.
    df = pd.read_pickle("songs_with_embeddings.pkl")
    print(f"✅ 데이터셋 로드 완료 ({len(df)}개 곡)")
except Exception as e:
    # Chain explicitly so the underlying failure stays attached as __cause__.
    raise RuntimeError(f"모델 또는 데이터 로딩 실패: {e}") from e

# ✅ 음악 추천
def final_recommend(user_tags_dict, selected_genres, total_recommend=5):
    """Recommend songs whose embeddings best match the given tags.

    The best-scoring song of every selected genre is always included. The
    remaining slots, up to ``total_recommend``, are filled with the next best
    songs across all selected genres, skipping titles already chosen. If
    there are more matching genres than ``total_recommend``, the per-genre
    picks are all kept, so the result can be longer than ``total_recommend``.

    Args:
        user_tags_dict: Mapping of tag category to tag value. Falsy values are
            ignored; the rest are joined into one query sentence.
        selected_genres: One to three genre names, compared against the
            "장르" column of the module-level ``df``. Repeated names are
            treated as a single genre.
        total_recommend: Target number of recommended songs.

    Returns:
        A DataFrame with the columns 곡명, 가수, similarity, 장르, 상황태그,
        감성태그, 시간대태그, 스타일태그, 날씨태그 and 계절태그: per-genre
        picks first (in genre order), then the fill-ups by descending
        similarity.

    Raises:
        ValueError: If no genre or more than three genres are given, or if
            every tag value is empty.
    """
    if not selected_genres:
        raise ValueError("최소 1개의 장르를 선택해야 합니다.")
    if len(selected_genres) > 3:
        raise ValueError("장르는 최대 3개까지 선택 가능합니다.")

    # str() keeps a stray non-string tag value (e.g. a number in the model's
    # JSON) from crashing the join.
    tag_text = " ".join(str(v) for v in user_tags_dict.values() if v)
    if not tag_text:
        raise ValueError("입력된 태그가 없습니다.")

    # Drop repeated genres (order preserved) so one song is not picked twice.
    genres = list(dict.fromkeys(selected_genres))

    # The encoder may live on the GPU while the stored embeddings come from a
    # pickle. Compare on the CPU and align every stored embedding to the
    # query's dtype, because util.cos_sim is not relied on to do that.
    user_emb = model.encode(tag_text, convert_to_tensor=True).cpu()

    def _similarity(embedding):
        candidate = torch.as_tensor(
            embedding, dtype=user_emb.dtype, device=user_emb.device
        )
        return util.cos_sim(user_emb, candidate)[0][0].item()

    # Score every candidate row exactly once; the per-genre picks and the
    # fill-ups are both taken from this frame.
    scored = df[df["장르"].isin(genres)].copy()
    scored["similarity"] = scored["embedding"].apply(_similarity)

    top_songs, seen_titles = [], []
    for genre in genres:
        best = (
            scored[scored["장르"] == genre]
            .sort_values(by="similarity", ascending=False)
            .head(1)
        )
        if best.empty:
            continue
        top_songs.append(best)
        seen_titles.extend(best["곡명"].tolist())

    remain = scored[~scored["곡명"].isin(seen_titles)]
    slots_left = max(0, total_recommend - len(top_songs))
    extra = remain.sort_values(by="similarity", ascending=False).head(slots_left)

    final_df = pd.concat(top_songs + [extra], ignore_index=True)
    return final_df[["곡명", "가수", "similarity", "장르", "상황태그", "감성태그", "시간대태그", "스타일태그", "날씨태그", "계절태그"]]

# ✅ 메인 실행
# Notebook entry point: read the widget inputs, tag the image, recommend
# songs, and hand the result (or a structured error) back as a JSON string.
try:
    if 'dbutils' not in globals():
        raise EnvironmentError("Databricks 환경에서 실행해주세요.")

    img_path = dbutils.widgets.get("image_path")
    genres_str = dbutils.widgets.get("genres")

    # The "genres" widget carries a JSON-encoded list of genre names.
    genres = json.loads(genres_str)
    # The widget path is prefixed with /dbfs to read it as a local file.
    local_path = f"/dbfs{img_path}"
    if not os.path.exists(local_path):
        raise FileNotFoundError(f"파일이 존재하지 않습니다: {local_path}")

    with open(local_path, "rb") as f:
        raw_image = f.read()

    validated = validate_image_data(raw_image)
    if not validated:
        raise ValueError("이미지 데이터가 유효하지 않습니다.")

    b64_image = base64.b64encode(validated).decode("utf-8")
    tags = extract_tags_from_image(b64_image)
    recs = final_recommend(tags, genres)

    result = {
        "tags": tags,
        "recommendations": recs.to_dict(orient="records")
    }
    # Serialise inside the try so that a non-serialisable value is reported
    # through the error payload as well.
    payload = json.dumps(result, ensure_ascii=False)

except Exception as e:
    import traceback
    err = {
        "error": str(e),
        "error_details": traceback.format_exc(),
        "tags": {},
        "recommendations": []
    }
    payload = json.dumps(err, ensure_ascii=False)

# Exit exactly once, outside the try block, so that whatever mechanism dbutils
# uses to stop the notebook can never be intercepted by the handler above.
# Outside Databricks there is no dbutils to call; print the payload instead so
# the error message is visible rather than masked by a NameError.
if 'dbutils' in globals():
    dbutils.notebook.exit(payload)
else:
    print(payload)