In [None]:
import os
from huggingface_hub import notebook_login

# Authenticate this notebook session with the Hugging Face Hub (interactive login prompt).
notebook_login()

In [None]:
import openai
import faiss
import pandas as pd
import os
from sentence_transformers import SentenceTransformer
from sklearn.feature_extraction.text import CountVectorizer
import numpy as np
from groq import Groq


In [None]:
# Sentence-embedding model; handle_user_query passes it to the retriever to encode questions.
model = SentenceTransformer(
    "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
)

# Pre-built FAISS indexes and the story table they refer to.
# The retriever reads the 'Title' and 'Context' columns of `df` by row position.
# NOTE(review): the '/####' paths are redacted placeholders — point them at the real files.
index_title = faiss.read_index("/####")
index_context = faiss.read_index("/####")
df = pd.read_csv("/####")

# Groq chat-completions client; the API key comes from the GROQ_API_KEY environment variable.
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))

In [None]:
################################################################################
#
#                             Story Reader
#
################################################################################

def extract_keywords(question, top_n=5):
    """
    Extract the core keywords from a question after dropping filler words.

    Every non-alphanumeric character is replaced with a space, the custom
    stop words below are removed, and at most ``top_n`` remaining tokens are
    kept (``CountVectorizer(max_features=top_n)``).

    Args:
        question: Raw user question.
        top_n: Maximum number of keywords to keep.

    Returns:
        list[str]: Keywords in the order given by
        ``CountVectorizer.get_feature_names_out()``. Empty when no token
        survives cleaning and stop-word removal (e.g. an empty question or a
        question made only of stop words).
    """
    # Custom stop words: request verbs and generic nouns that carry no topic.
    custom_stop_words = [
        "들려줘", "이야기", "이야기를", "장군", "알려줘", "주세요", "관련된", "어떤", "동화", "동화를", "해줘", "줘", "들려", "읽어", "읽어줘", "읽어주세요"
    ]

    # Replace special characters with spaces and trim the ends.
    question_cleaned = ''.join(char if char.isalnum() else ' ' for char in question).strip()

    vectorizer = CountVectorizer(
        max_features=top_n,
        stop_words=custom_stop_words,  # stop words are passed as an explicit list
        token_pattern=r'\b\w+\b'
    )
    try:
        vectorizer.fit([question_cleaned])
    except ValueError as exc:
        # scikit-learn raises "empty vocabulary; ..." when nothing is left to
        # count. That is a normal outcome here (no keywords), not an error.
        if "empty vocabulary" not in str(exc):
            raise
        return []

    return list(vectorizer.get_feature_names_out())

def retrieve_and_debug_with_filter(
    question, index_title, index_context, df, model, top_k=5, title_weight=0.7, keyword=None
):
    """
    Retrieve candidate stories for a question and optionally filter them by keyword.

    The question is embedded with ``model.encode`` and searched against both
    indexes. Every row returned by the title index earns ``title_weight`` and
    every row returned by the context index earns ``1 - title_weight``; rows
    found by both accumulate both amounts. The ``top_k`` best-scoring rows are
    then filtered by ``keyword``.

    Args:
        question: User question text.
        index_title: Index exposing ``search(embedding, k) -> (distances, ids)``.
        index_context: Second index with the same ``search`` interface.
        df: DataFrame with 'Title' and 'Context' columns, addressed by row position.
        model: Object exposing ``encode(list_of_texts)``.
        top_k: Number of hits requested from each index and kept after ranking.
        title_weight: Score granted to a title hit (context hits get the remainder).
        keyword: ``None``/empty for no filtering, a string, or a list of strings.
            A row passes when any keyword is a substring of its title or context.

    Returns:
        list[tuple]: ``(row_index, score, title, context)`` tuples, best score
        first. When filtering removes every candidate, the unfiltered
        candidates are returned instead.
    """
    question_embedding = model.encode([question])

    _, title_indices = index_title.search(question_embedding, top_k)
    _, context_indices = index_context.search(question_embedding, top_k)

    combined_scores = {}
    for rank in range(top_k):
        weighted_hits = (
            (title_indices[0][rank], title_weight),
            (context_indices[0][rank], 1 - title_weight),
        )
        for raw_idx, weight in weighted_hits:
            row_idx = int(raw_idx)
            # FAISS pads the result with -1 when fewer than top_k vectors are
            # available. df.iloc[-1] would silently resolve that to the LAST
            # row, so padding entries must be dropped here.
            if row_idx < 0:
                continue
            combined_scores[row_idx] = combined_scores.get(row_idx, 0) + weight

    # sorted() is stable, so ties keep the order in which rows were first seen.
    ranked = sorted(combined_scores.items(), key=lambda item: item[1], reverse=True)

    candidates = []
    for row_idx, score in ranked[:top_k]:
        row = df.iloc[row_idx]
        candidates.append((row_idx, score, row['Title'], row['Context']))

    # No keyword given: nothing to filter.
    if keyword is None or len(keyword) == 0:
        return candidates

    def _as_text(value):
        # Empty CSV cells load as NaN (a float); treat them as empty text so
        # the substring test below cannot raise TypeError.
        return value if isinstance(value, str) else ""

    def _is_relevant(title, context):
        title_text, context_text = _as_text(title), _as_text(context)
        if isinstance(keyword, list):
            return any(
                word in title_text or word in context_text
                for word in keyword
                if isinstance(word, str)
            )
        if isinstance(keyword, str):
            return keyword in title_text or keyword in context_text
        # Unsupported keyword type: nothing matches (falls back below).
        return False

    filtered_results = [hit for hit in candidates if _is_relevant(hit[2], hit[3])]

    # If the keyword filter rejected everything, fall back to the top-scoring rows.
    return filtered_results if filtered_results else candidates



def generate_answer_with_full_context(results):
    """
    Ask the LLM to narrate a retrieved story in a child-friendly spoken style.

    Args:
        results: Iterable of ``(idx, score, title, context)`` tuples.

    Returns:
        A fixed "not found" message (in Korean) when ``results`` is empty or
        falsy; otherwise the stripped LLM narration of the first result.
        Later results are never sent to the LLM because the loop returns on
        its first iteration.
    """
    if not results:
        return "관련된 동화를 찾을 수 없습니다. 다른 질문을 해보세요!"

    prompt_template = """
        You are taking on the role of a kindergarten teacher who tells stories to children in an engaging and entertaining way.
        Ensure the stories are narrated in simple, conversational language that children can easily understand and enjoy.
        Always use correct grammar, and employ imaginative expressions to captivate the children’s attention.
        If you find that the requested story and the provided data are unrelated, create a new story that matches the request and narrate it accordingly.
        When discussing historical events or biographies, ensure the timeline is accurate and avoid including incorrect information.
        End each story with a valuable lesson for children.
        Ensure the narrative flows naturally and avoids awkward transitions.
        Beware of repetitive responses.
        You must always provide responses in Korean.
    """

    for _, _, story_title, story_text in results:
        # The shared instructions are prepended to both the system and the user message.
        system_message = {
            "role": "system",
            "content": prompt_template + "You are a kindergarten teacher. Narrate the following story in a friendly and engaging way for children.",
        }
        user_message = {
            "role": "user",
            "content": prompt_template + f"Title: {story_title}\n\nContext:\n{story_text}\n\nPlease narrate this story in a simple and conversational style for children.",
        }

        completion = client.chat.completions.create(
            model="llama-3.3-70b-versatile",
            messages=[system_message, user_message],
            max_tokens=1500,
            temperature=0.2,
        )
        # Return as soon as the first result has been narrated.
        return completion.choices[0].message.content.strip()

def handle_user_query(question):
    """
    Main entry point of the story reader: answer a user question with a narrated story.

    Extracts keywords from the question, retrieves matching stories using the
    module-level indexes, DataFrame and embedding model, and returns the
    narration produced from the retrieved results.
    """
    search_hits = retrieve_and_debug_with_filter(
        question=question,
        index_title=index_title,
        index_context=index_context,
        df=df,
        model=model,
        keyword=extract_keywords(question),
    )
    return generate_answer_with_full_context(search_hits)

In [None]:
################################################################################
##################      Music + illustration generation      ###################
################################################################################

import os
import warnings
import torch
import torchaudio
import uvicorn
import threading
import operator
import numpy as np
import scipy.io.wavfile as wav
import soundfile as sf
from typing import TypedDict, List, Optional, Annotated, Any

# FastAPI 및 관련 모듈
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
from pydantic import BaseModel

# LangChain 및 LangGraph 관련 모듈
from langgraph.graph import StateGraph
from langgraph.constants import START, END
from langgraph.graph.message import add_messages
from langgraph.checkpoint.memory import MemorySaver
from langchain.prompts import PromptTemplate
from langchain.memory import ConversationBufferMemory
from langchain_core.runnables import RunnableConfig
from langchain_groq import ChatGroq

# Diffusers 관련 모듈 (이미지 생성)
from diffusers import StableDiffusion3Pipeline

# SciKit-Learn (유사도 분석)
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

# ngrok 및 코랩 환경을 위한 모듈
import nest_asyncio
from pyngrok import ngrok

# Suppress warnings of category UserWarning for the rest of the session.
warnings.filterwarnings("ignore", category=UserWarning)

# Graph state definition
def _keep_latest_non_empty(current, update):
    """Reducer: take ``update`` unless it is empty, in which case keep ``current``.

    Every node in this graph returns the *whole* state, so LangGraph passes each
    field through its reducer on every step. With ``operator.add`` a string
    field is concatenated with itself each time (``user_input`` doubles after
    every node). This reducer is idempotent for unchanged values and still
    lets parallel branches merge safely: a branch that leaves a field empty
    can never erase the value written by its sibling branch.
    """
    return update if update else current


class GraphState(TypedDict):
    """Shared state passed between the LangGraph nodes.

    Each field carries a reducer (second ``Annotated`` argument) that decides
    how a value returned by a node is merged with the value already stored.
    """
    user_input: Annotated[str, _keep_latest_non_empty]  # raw text sent by the user
    story_content: Annotated[str, _keep_latest_non_empty]  # narrated story text
    atmosphere: Annotated[str, _keep_latest_non_empty]  # scene description used as generation prompt
    music_path: Annotated[str, _keep_latest_non_empty]  # path of the selected/generated music file
    assist_response: Annotated[str, _keep_latest_non_empty]  # assistant answer
    classification: Annotated[str, _keep_latest_non_empty]  # 'story' | 'normal' | 'follow' | 'stop'
    illustration_path: Annotated[str, _keep_latest_non_empty]  # path of the generated illustration file
    follow_flag: Annotated[bool, lambda a, b: b]  # last write wins
    stop_flag: Annotated[bool, lambda a, b: b]  # last write wins
    matched_mood: Annotated[str, _keep_latest_non_empty]  # predefined mood name or "custom"

# Module-level conversation buffer: assistant_service loads its "chat_history"
# before each reply and saves the new exchange afterwards.
memory = ConversationBufferMemory(memory_key="chat_history")

# Story Reader node
def story_reader_func(state: GraphState) -> GraphState:
    """Run the story reader on ``state["user_input"]`` and store the narration.

    The narration is written to ``state["story_content"]``; the same state
    object is returned.
    """
    state["story_content"] = handle_user_query(state["user_input"])
    return state

def assistant_service(state: GraphState) -> GraphState:
    """Answer a general (non-story) request with the chat LLM.

    Builds a prompt from the stored conversation history plus the current
    ``state["user_input"]``, calls the Groq-hosted model, records the exchange
    in the module-level ``memory`` and writes the answer to
    ``state["assist_response"]``.

    Args:
        state: Graph state; reads ``user_input``.

    Returns:
        The same state object with ``assist_response`` set.
    """
    # Current user input
    user_input = state["user_input"]

    # Conversation history accumulated so far
    conversation_history = memory.load_memory_variables({})["chat_history"]

    # Prompt template
    prompt_template = """
        You are an AI assistant who provides helpful and accurate answers to user questions.
        You must always provide responses in Korean.
        Don't include emojis in your responses.
        If the conversation involves children, make sure to use simple and easy-to-understand words. Avoid using difficult or complex vocabulary.
        Here is the conversation history:
        {chat_history}

        User Input: {user_query}

        Response:
    """
    prompt = PromptTemplate(
        input_variables=["chat_history", "user_query"],
        template=prompt_template
    )

    # LLM call. `predict` is deprecated in LangChain; `invoke` is the
    # replacement and returns a message whose `.content` holds the text.
    llm = ChatGroq(model="llama-3.3-70b-versatile", temperature=0.2)
    response = llm.invoke(
        prompt.format(
            chat_history=conversation_history,
            user_query=user_input
        )
    ).content.strip()

    # Record this exchange so the next call sees it in the history.
    memory.save_context({"user_input": user_input}, {"response": response})
    state["assist_response"] = response

    return state


def analyze_atmosphere_func(state: GraphState) -> GraphState:
    """Ask an LLM to turn the story into a scene description for image generation.

    Reads ``state["story_content"]`` and stores the stripped LLM answer in
    ``state["atmosphere"]``.

    Returns:
        The same state object with ``atmosphere`` set.
    """
    llm = ChatGroq(model='gemma2-9b-it', temperature=0.7)
    prompt = f"""
    You are an illustration expert specializing in creating captivating and visually enchanting illustrations for children's fairy tales.
    Your task is to craft a detailed and imaginative prompt for an AI drawing model to generate a scene based on the following story context:

    Story Context:
    {state['story_content']}

    The illustration should bring the story to life by focusing on these aspects:
    - Highlight a key moment or turning point in the story that is emotionally impactful or visually stunning.
    - Provide a detailed description of the protagonist, including clothing style, facial expressions, posture, and distinguishing features that reflect their personality or role in the story.
    - Enrich the environment by describing the setting, lighting, colors, and background elements that enhance the scene's mood and atmosphere.
    - Incorporate magical or fantastical elements seamlessly into the scene, such as glowing objects, sparkling effects, or whimsical creatures, to emphasize the fairy-tale theme.
    - Ensure the scene evokes the intended emotion or mood, whether it’s whimsical, adventurous, mysterious, or heartwarming.
    - Maintain a balanced composition between the protagonist, environment, and supporting details to create a harmonious and visually appealing illustration.

    Example Format:
    "[Scene Description: Include protagonist details, setting, lighting, colors, magical elements, and overall mood.]"
    """

    # `predict` is deprecated in LangChain; `invoke(...).content` is the replacement.
    response = llm.invoke(prompt).content.strip()
    state["atmosphere"] = response
    return state

def analyze_music_mood_func(state: GraphState) -> GraphState:
    """Pick the background-music mood that best fits the story.

    Asks the LLM to choose one of 15 predefined moods for
    ``state["story_content"]``. The answer is normalised (surrounding quotes,
    backticks, asterisks and periods removed, case ignored) before it is
    compared with the list, so answers such as ``"Joyful and Playful."`` are
    still recognised. Anything else is stored as ``"custom"``.

    Returns:
        The same state object with ``matched_mood`` set to the canonical mood
        name or ``"custom"``.
    """
    llm = ChatGroq(model='llama-3.3-70b-versatile', temperature=0.7)

    # The 15 predefined moods
    predefined_moods = [
        "Magical and Whimsical", "Peaceful and Serene", "Mysterious and Suspenseful",
        "Joyful and Playful", "Adventurous and Exciting", "Melancholic and Reflective",
        "Magical Darkness", "Royal and Majestic", "Romantic and Dreamy",
        "Tense and Dangerous", "Mystical and Ethereal", "Cheerful and Festive",
        "Dark and Foreboding", "Bright and Hopeful", "Action-Packed and Heroic"
    ]

    # Build the prompt
    prompt = f"""
    You are tasked with determining the most appropriate mood for a background music track based on the following story context.

    Here are 15 predefined moods:
    {', '.join(predefined_moods)}

    Story Context:
    {state['story_content']}

    Please respond with the most fitting mood from the list above. Provide only the mood name.
    """

    # LLM call (`predict` is deprecated in LangChain; `invoke` replaces it).
    response = llm.invoke(prompt).content.strip()

    # Tolerate cosmetic decoration around the mood name and differences in case.
    normalized = response.strip("\"'`*.").strip().casefold()
    canonical_by_key = {mood.casefold(): mood for mood in predefined_moods}

    # Unknown answers are handled as a custom mood.
    state["matched_mood"] = canonical_by_key.get(normalized, "custom")

    print(f"분석된 음악 분위기: {state['matched_mood']}")
    return state
import os

# Pre-generated music files, keyed by the exact mood names the mood analysis
# node can produce. Keys must match those names character for character,
# otherwise the lookup in generate_music_func falls through to its default.
predefined_music_files = {
    "Magical and Whimsical": "magical_and_whimsical.wav",
    "Peaceful and Serene": "peaceful_and_serene.wav",
    "Mysterious and Suspenseful": "mysterious_and_suspenseful.wav",
    "Joyful and Playful": "joyful_and_playful.wav",
    "Adventurous and Exciting": "adventurous_and_exciting.wav",
    "Melancholic and Reflective": "melancholic_and_reflective.wav",
    "Magical Darkness": "magical_darkness.wav",
    "Royal and Majestic": "royal_and_majestic.wav",
    "Romantic and Dreamy": "romantic_and_dreamy.wav",
    "Tense and Dangerous": "tense_and_dangerous.wav",
    "Mystical and Ethereal": "mystical_and_ethereal.wav",  # was "Mystical and Ethereal " (trailing space)
    "Cheerful and Festive": "cheerful_and_festive.wav",
    "Dark and Foreboding": "dark_and_foreboding.wav",  # was "Dark and Forebodingl" (typo)
    "Bright and Hopeful": "bright_and_hopeful.wav",
    "Action-Packed and Heroic": "action-packed_and_heroic.wav"
}

def generate_music_func(state: GraphState) -> GraphState:
    """Select or generate the background music for the story.

    If ``state["matched_mood"]`` names a predefined mood whose file exists on
    disk, that file is used. Otherwise (mood is ``"custom"``, unknown, or the
    file is missing) a new 30-second clip is generated with MusicGen from
    ``state["atmosphere"]`` and saved as ``background_music_custom.wav``.

    Returns:
        The same state object with ``music_path`` set.
    """
    matched_mood = state.get("matched_mood", "custom")

    if matched_mood != "custom":
        # Use the pre-generated track when it is available.
        music_path = predefined_music_files.get(matched_mood, "default_music.wav")
        if os.path.exists(music_path):
            state["music_path"] = music_path
            print(f"미리 생성된 음악 선택: {music_path}")
            return state

    # Generate new music.
    # MusicGen was referenced here without ever being imported (NameError).
    # It is provided by the `audiocraft` package; importing it lazily keeps
    # that heavy dependency off the pre-generated path.
    from audiocraft.models import MusicGen

    music_model = MusicGen.get_pretrained("small")
    music_model.set_generation_params(duration=30)

    prompt = state["atmosphere"]
    print(f"새로운 음악 생성 중... 분위기: {prompt}")
    # generate() expects a list of descriptions; a bare string would be
    # iterated character by character, one generation per character.
    waveform = music_model.generate([prompt])
    music_path = "background_music_custom.wav"
    torchaudio.save(music_path, waveform[0].cpu(), sample_rate=music_model.sample_rate)

    state["music_path"] = music_path
    print(f"새로운 배경음악 생성 완료: {music_path}")
    return state


# Illustration generation node
def generate_illustration_func(state: GraphState) -> GraphState:
    """Render an illustration for the story with Stable Diffusion 3.5.

    Loads the pipeline onto CUDA, builds an image prompt around
    ``state["atmosphere"]``, saves the first generated image as
    ``story_illustration.png`` and records that path in
    ``state["illustration_path"]``.

    Returns:
        The same state object with ``illustration_path`` set.
    """
    pipe = StableDiffusion3Pipeline.from_pretrained(
        "stabilityai/stable-diffusion-3.5-large",
        torch_dtype=torch.float16,
    ).to("cuda")

    prompt = f"""
      Create a captivating and visually enchanting illustration for a children's fairy tale.
      The scene should reflect the mood: {state['atmosphere']}, while bringing the story to life.

      Guidelines for the illustration:
      - Art Style: Use a whimsical and vibrant art style with soft pastel colors, smooth, hand-drawn lines, and a magical atmosphere.
      - Characters: Depict expressive and engaging characters with exaggerated features that clearly convey emotions (e.g., joy, wonder, excitement). Focus on the protagonist’s unique traits, such as their outfit, posture, and facial expressions.
      - Environment: Design a richly detailed and immersive setting that enhances the mood. Include elements like glowing objects, fantastical landscapes, or playful surroundings that match the fairy-tale theme.
      - Composition: Ensure a visually striking and balanced composition, with the protagonist as the focal point, harmonizing with the environment and supporting elements.
      - Magical Details: Integrate magical or fantastical elements seamlessly, such as sparkling trails, glowing effects, whimsical creatures, or otherworldly features.
      - Emotional Impact: Capture the essence of a pivotal or heartwarming moment from the story, making the illustration emotionally impactful and memorable.
    """
    print(f"삽화 생성 중... 프롬프트: {prompt}")

    generation = pipe(
        prompt,
        num_inference_steps=40,
        guidance_scale=4.5,
    )

    illustration_path = "story_illustration.png"
    generation.images[0].save(illustration_path)
    print(f"삽화 생성 완료: {illustration_path}")

    state["illustration_path"] = illustration_path
    return state

def classify_input(user_input: str) -> str:
    """Classify a user utterance as 'story', 'normal', 'follow' or 'stop'.

    The LLM answer is normalised (surrounding quotes, backticks, asterisks and
    periods removed, lower-cased) before validation, because the prompt itself
    shows the labels in quotes and models often echo them that way.

    Args:
        user_input: Raw text sent by the user.

    Returns:
        One of ``'story'``, ``'normal'``, ``'follow'``, ``'stop'``.

    Raises:
        ValueError: If the LLM answer is not one of the four labels.
    """
    prompt = f"""
    You are an evaluator who classifies the type of user question.
    The user question is as follows: {user_input}

    Classification criteria:
    - 'story': If the input involves reading a storybook or storytelling.
    - 'normal': For general conversation or assistance.
    - 'follow': If the input involves following, mimicking, or synchronizing with the user.
    - 'stop' : If the input includes phrases like "멈춰", "그만", "정지", "쫓아오지마", or any similar expression indicating a command to stop or cease an action.

    Respond **only** with one of the following words (without explanation):
    - 'story'
    - 'normal'
    - 'follow'
    - 'stop'

    Your response must be one of these exact words.
    """

    llm = ChatGroq(model='gemma2-9b-it', temperature=0.2)
    # The prompt is already interpolated by the f-string above. Calling
    # .format() on it again would raise on any user input containing '{' or
    # '}'. (`predict` is also deprecated in LangChain; `invoke` replaces it.)
    raw_response = llm.invoke(prompt).content.strip()
    response = raw_response.strip("\"'`*.").strip().lower()

    if response not in ["story", "normal", "follow", "stop"]:
      print(f"[ERROR] Invalid classification response: {raw_response}")
      raise ValueError(f"Unexpected classification value from LLM: {raw_response}")

    print(f"[DEBUG] classify_input 결과: {response}")
    return response

def decision_node(state: GraphState) -> GraphState:
    """Classify ``state["user_input"]`` and store the label for routing.

    Raises:
        ValueError: If the label is not 'story', 'normal', 'follow' or 'stop'.

    Returns:
        The same state object with ``classification`` overwritten.
    """
    label = classify_input(state["user_input"])
    if label in ("story", "normal", "follow", "stop"):
        # Overwrite whatever classification was stored before.
        state["classification"] = label.strip()
        print(f"[DEBUG] 분류 결과: {state['classification']}")
        return state

    raise ValueError(f"Unexpected classification value: {label}")

def follow_func(state: GraphState) -> GraphState:
    """Switch follow mode on: ``follow_flag`` becomes True and ``stop_flag`` False."""
    state.update(follow_flag=True, stop_flag=False)
    print("follow_on")
    return state

def stop_func(state: GraphState) -> GraphState:
    """Switch follow mode off: ``stop_flag`` becomes True and ``follow_flag`` False."""
    state.update(stop_flag=True, follow_flag=False)
    print("Stop!")
    return state

# Build the LangGraph workflow
lang_graph = StateGraph(state_schema=GraphState)

# Register the nodes (same order as before)
for _node_name, _node_fn in (
    ("decision", decision_node),
    ("follow_func", follow_func),
    ("stop_func", stop_func),
    ("story_reader", story_reader_func),
    ("analyze_atmosphere", analyze_atmosphere_func),
    ("generate_illustration", generate_illustration_func),
    ("assistant_service", assistant_service),
    ("analyze_music_mood", analyze_music_mood_func),
    ("generate_music", generate_music_func),
):
    lang_graph.add_node(_node_name, _node_fn)

# Entry point, then route on the classification written by the decision node
lang_graph.add_edge(START, "decision")
lang_graph.add_conditional_edges(
    source="decision",
    path=lambda state: state["classification"],
    path_map={
        "story": "story_reader",
        "normal": "assistant_service",
        "follow": "follow_func",
        "stop" : "stop_func",
    }
)

# The story branch fans out into an illustration path and a music path;
# every branch ends at END.
for _edge_source, _edge_target in (
    ("story_reader", "analyze_atmosphere"),
    ("story_reader", "analyze_music_mood"),
    ("analyze_music_mood", "generate_music"),
    ("analyze_atmosphere", "generate_illustration"),
    ("generate_music", END),
    ("generate_illustration", END),
    ("assistant_service", END),
    ("follow_func", END),
    ("stop_func", END),
):
    lang_graph.add_edge(_edge_source, _edge_target)

# Compile the graph
compiled_graph = lang_graph.compile()

# Create the FastAPI application; the endpoints below are registered on it.
app = FastAPI()

# Request model: JSON body accepted by POST /process.
class UserRequest(BaseModel):
    # Raw text sent by the user; it becomes the graph's initial `user_input`.
    user_input: str

# Response model returned by POST /process.
# process_request fills only the fields relevant to the detected request type;
# the others keep the defaults declared here.
class ServerResponse(BaseModel):
    classification: Optional[str] = None
    assist_response: Optional[str] = None  # set for 'normal' requests
    story_output: Optional[str] = None  # set for 'story' requests
    music_path: Optional[str] = None  # set for 'story' requests
    illustration_path: Optional[str] = None  # set for 'story' requests
    follow_flag: Optional[bool] = False  # set for 'follow' requests
    stop_flag: Optional[bool] = False  # set for 'stop' requests

# Main processing endpoint
@app.post("/process", response_model=ServerResponse)
def process_request(request: UserRequest):
    """Run the compiled graph for one user utterance and shape the HTTP response.

    Args:
        request: Body carrying the user's text in ``user_input``.

    Returns:
        ServerResponse: ``classification`` is always filled in; the remaining
        fields depend on it ('story' -> story/music/illustration,
        'normal' -> assist_response, 'follow' -> follow_flag,
        'stop' -> stop_flag).

    Raises:
        HTTPException: 500 for any failure while running the graph, including
            an unexpected classification value.
    """
    try:
        # Initial state: every GraphState field starts empty/False except the input text.
        initial_state: GraphState = {
          "user_input": request.user_input,
          "story_content": "",
          "atmosphere": "",
          "music_path": "",
          "assist_response": "",
          "classification": "",
          "illustration_path": "",
          "follow_flag": False,
          "stop_flag": False,
          "matched_mood": ""
        }

        # Run the compiled graph
        final_state = compiled_graph.invoke(initial_state)

        # Validate once; the branches below rely on this normalised value.
        classification = (final_state.get("classification") or "").strip()

        if classification == "story":
            return ServerResponse(
                classification=classification,
                story_output=final_state.get("story_content", "No story content available."),
                music_path=final_state.get("music_path", "No music available."),
                illustration_path=final_state.get("illustration_path", "No illustration available.")
            )
        if classification == "normal":
            return ServerResponse(
                classification=classification,
                assist_response=final_state.get("assist_response", "No assist response available.")
            )
        if classification == "follow":
            return ServerResponse(
                classification=classification,
                follow_flag=final_state.get("follow_flag", False)
            )
        if classification == "stop":
            return ServerResponse(
                classification=classification,
                stop_flag=final_state.get("stop_flag", False)
            )

        raise ValueError(f"Unexpected classification value: {classification}")
    except Exception as e:
        print(f"Error during processing: {e}")
        raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(e)}") from e


@app.get("/music/{filename}")
def get_music_file(filename: str):
    """Serve a music file from the working directory.

    ``filename`` comes straight from the URL, so only a bare ``.wav`` file
    name is accepted. Without that check the endpoint would hand out any file
    in the working directory (data, notebooks, ...) to anyone with the URL.

    Raises:
        HTTPException: 404 if the name is not a plain ``.wav`` file name or
            the file does not exist.
    """
    is_plain_name = os.path.basename(filename) == filename
    if not is_plain_name or not filename.lower().endswith(".wav"):
        raise HTTPException(status_code=404, detail="File not found")

    file_path = os.path.join("./", filename)  # files are stored in the working directory
    # isfile (not exists): a directory must never be passed to FileResponse.
    if not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(path=file_path, media_type="audio/wav", filename=filename)

@app.get("/illustration/{filename}")
def get_illustration_file(filename: str):
    """Serve an illustration file from the working directory.

    ``filename`` comes straight from the URL, so only a bare ``.png`` file
    name is accepted. Without that check the endpoint would hand out any file
    in the working directory to anyone with the URL.

    Raises:
        HTTPException: 404 if the name is not a plain ``.png`` file name or
            the file does not exist.
    """
    is_plain_name = os.path.basename(filename) == filename
    if not is_plain_name or not filename.lower().endswith(".png"):
        raise HTTPException(status_code=404, detail="File not found")

    file_path = os.path.join("./", filename)  # files are stored in the working directory
    # isfile (not exists): a directory must never be passed to FileResponse.
    if not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(path=file_path, media_type="image/png", filename=filename)

# Main block: expose the API through ngrok and start the server (Colab-style environment)
if __name__ == "__main__":
    import nest_asyncio
    from pyngrok import ngrok

    # Read the ngrok token from the environment instead of embedding it in
    # the notebook, the same way the Groq key is read from GROQ_API_KEY.
    ngrok_token = os.environ.get("NGROK_AUTH_TOKEN")
    if ngrok_token:
        ngrok.set_auth_token(ngrok_token)
    else:
        print("NGROK_AUTH_TOKEN is not set; using whatever token ngrok already has configured.")

    # Allow uvicorn to start inside the notebook's already-running event loop,
    # then open the public tunnel to the local port.
    nest_asyncio.apply()
    public_url = ngrok.connect(8000)
    print("Public URL:", public_url.public_url)

    # Start the FastAPI server
    uvicorn.run(app, host="0.0.0.0", port=8000)
