In [None]:
!pip install -q -U transformers accelerate
!pip install -q sentence-transformers faiss-cpu pandas tqdm
import io
import re
from typing import Optional, List, Dict, Tuple

import pandas as pd
import torch
import faiss
import numpy as np
from tqdm.auto import tqdm
import huggingface_hub
from transformers import AutoModelForCausalLM, AutoTokenizer
from sentence_transformers import SentenceTransformer

class QuantitativeStockAnalyzer:
    """Rule-based helper that turns price history and tweets into short sentences.

    The sentences are injected into the stock-prediction prompt; no model is
    involved here, only a moving-average crossover and keyword counting.
    """

    # Keywords are matched at the start of a word (see ``_count_keyword_hits``),
    # so "beat" still matches "beats" but "lows" no longer matches "follows".
    POSITIVE_KEYWORDS = ('bullish', 'gains', 'upgraded', 'highs', 'beat', 'positive', 'good', 'soaring', 'growth', 'rebound', 'buying', 'breakout', 'support', 'boosted', 'reiterates')
    NEGATIVE_KEYWORDS = ('bearish', 'loss', 'downgraded', 'lows', 'miss', 'negative', 'flaw', 'risk', 'sells', 'cuts', 'weakness', 'decline', 'drop', 'warning', 'pressure')

    # Column names assumed when the CSV context has no header row; truncated
    # to however many columns the first row actually has.
    _DEFAULT_HEADERS = ('date', 'open', 'high', 'low', 'close', 'adj-close', 'inc-5', 'inc-10', 'inc-15', 'inc-20', 'inc-25', 'inc-30')

    def analyze(self, context: str, tweets_str: str) -> Tuple[str, str]:
        """Summarise price history and tweet sentiment as prompt-ready sentences.

        Args:
            context: CSV price rows, with or without a ``date,open,high,low,close``
                header line.
            tweets_str: Free-form tweet text scanned for fixed keyword lists.

        Returns:
            ``(tech_summary, sentiment_summary)``. The first is a moving-average
            trend sentence, or a notice that parsing failed or there is too
            little data. The second is ``"Net sentiment from tweets is <label>."``.
        """
        tech_summary = self._technical_summary(context)
        sentiment_label = self._sentiment_label(tweets_str)
        return tech_summary, f"Net sentiment from tweets is {sentiment_label}."

    @classmethod
    def _technical_summary(cls, context: str) -> str:
        """Compare the latest 5- and 10-row moving averages of the close column."""
        try:
            text = context.strip()
            first_line = text.split('\n')[0]
            if 'date,open,high,low,close' in first_line.lower():
                df = pd.read_csv(io.StringIO(text))
            else:
                names = list(cls._DEFAULT_HEADERS[:len(first_line.split(','))])
                df = pd.read_csv(io.StringIO(text), header=None, names=names)
            # Keep the cleaned Series separate: assigning a dropna()'d Series
            # back into the frame re-aligns on the index and restores the NaNs,
            # which would then poison the rolling means.
            close = pd.to_numeric(df['close'], errors='coerce').dropna()
        except Exception:  # malformed/empty context: degrade, don't fail the row
            return "Technical data parsing failed."
        if len(close) < 10:
            return "Insufficient data for moving average."
        ma5 = close.rolling(window=5).mean().iloc[-1]
        ma10 = close.rolling(window=10).mean().iloc[-1]
        if ma5 > ma10:
            return "Bullish short-term trend (5-day MA > 10-day MA)."
        return "Bearish short-term trend (5-day MA < 10-day MA)."

    @staticmethod
    def _count_keyword_hits(lowered_text: str, keywords) -> int:
        """Count how many distinct keywords occur in *lowered_text* at a word start."""
        # The lookbehind (rather than \b) still allows a match right after
        # non-Latin characters such as Thai script.
        return sum(
            1 for word in keywords
            if re.search(r'(?<![a-z])' + re.escape(word), lowered_text)
        )

    @classmethod
    def _sentiment_label(cls, tweets_str: str) -> str:
        """Map (positive hits - negative hits) onto a five-level label."""
        lowered = tweets_str.lower()
        net_sentiment = (cls._count_keyword_hits(lowered, cls.POSITIVE_KEYWORDS)
                         - cls._count_keyword_hits(lowered, cls.NEGATIVE_KEYWORDS))
        if net_sentiment > 1:
            return "Strongly Positive"
        if net_sentiment > 0:
            return "Slightly Positive"
        if net_sentiment < -1:
            return "Strongly Negative"
        if net_sentiment < 0:
            return "Slightly Negative"
        return "Neutral"

class ComplianceAgentRAG:
    """Retrieve the regulatory principles most similar to a query.

    A fixed list of principle texts is embedded once with the
    ``all-MiniLM-L6-v2`` SentenceTransformer and stored in an exact (flat)
    FAISS L2 index; retrieval embeds the query and returns the nearest texts.
    """

    def __init__(self):
        self.principles = [
            "Principle: Fiduciary Duty. Action: An advisor must prioritize the client's interests. Breach: Recommending a high-commission fund unsuitable for a client's risk profile.",
            "Principle: Suitability. Action: Recommendations must match the client's risk tolerance, goals, and situation.",
            "Principle: Full Disclosure. Action: All risks, fees, and conflicts of interest must be clearly explained.",
            "Principle: AML/CFT. Action: Report suspicious transactions (e.g., from high-risk jurisdictions, unclear source of funds) to AMLO and refuse onboarding if CDD fails.",
            "Principle: Data Privacy (PDPA). Action: Must obtain explicit, informed consent before sharing client data with third parties.",
            "Principle: Market Integrity. Action: Prohibit creating artificial volume or price movement.",
            "Principle: Cybersecurity. Action: Critical vulnerabilities must be patched immediately. Incidents must be reported to the BOT.",
        ]
        self.sbert_model = SentenceTransformer('all-MiniLM-L6-v2', device='cuda' if torch.cuda.is_available() else 'cpu')
        self.index = self._build_index()

    def _embed(self, texts: List[str]) -> np.ndarray:
        """Encode *texts* into the C-contiguous float32 matrix FAISS requires."""
        vectors = self.sbert_model.encode(texts, convert_to_numpy=True)
        return np.ascontiguousarray(vectors, dtype=np.float32)

    def _build_index(self):
        """Embed every principle and load the vectors into a flat L2 index."""
        embeddings = self._embed(self.principles)
        index = faiss.IndexFlatL2(embeddings.shape[1])
        index.add(embeddings)
        return index

    def retrieve_relevant_principles(self, query: str, k: int = 3) -> List[str]:
        """Return up to *k* principle texts closest to *query*, nearest first.

        ``k`` is clamped to the number of indexed principles. A non-positive
        ``k`` yields an empty list.
        """
        k = min(k, self.index.ntotal)
        if k <= 0:
            return []
        _, indices = self.index.search(self._embed([query]), k)
        # FAISS pads missing neighbours with -1. Without this filter they would
        # index the list from the end and duplicate the last principle.
        return [self.principles[i] for i in indices[0] if i >= 0]

def classify_query(query: str) -> str:
    """Pick the prompt route for *query* from fixed English/Thai marker phrases.

    Routes are checked in order and the first one with a matching marker wins.
    Returns "stock_prediction", "ethical_scenario", or "general_mcq" when
    nothing matches.
    """
    lowered = query.lower()
    routes = (
        ("stock_prediction", ('project whether the closing price', 'ราคาปิดของ')),
        ("ethical_scenario", ('most appropriate action', 'แนวทางที่เหมาะสมที่สุด', 'best aligns with the ethical')),
    )
    for label, markers in routes:
        if any(marker in lowered for marker in markers):
            return label
    return "general_mcq"

def _as_clause(text) -> str:
    """Return *text* without trailing whitespace/periods so a template can add its own '.'."""
    return str(text).rstrip().rstrip('.')


def get_surgical_prompt(prompt_type: str, **kwargs) -> Dict[str, str]:
    """Build the system/user message pair for one query.

    Args:
        prompt_type: "stock_prediction" or "ethical_scenario"; any other value
            falls back to the general multiple-choice template.
        **kwargs: ``query`` is always required. Stock prompts also need
            ``tech_analysis`` and ``sentiment_analysis``; ethical prompts need
            ``principles`` (a list of str). ``issue`` / ``concept`` optionally
            fill the first reasoning step of the ethical / general templates.

    Returns:
        ``{"system": ..., "user": ...}``.
    """
    system_prompt = "You are Typhoon, an elite financial analyst with a CFA charter, renowned for your accuracy and ethical judgment. You will reason step-by-step before providing a final, concise answer."
    query = kwargs['query']
    # --- REINFORCED PROMPT FOR STOCK PREDICTION ---
    if prompt_type == "stock_prediction":
        # The analyzer summaries are full sentences ending in '.', and the
        # template ends each step with '.' too; strip theirs to avoid "..".
        tech = _as_clause(kwargs['tech_analysis'])
        sentiment = _as_clause(kwargs['sentiment_analysis'])
        user_prompt = f"""**Task:** Predict the stock price movement.\n\n**Chain of Thought:**\n1. **Technical Analysis:** The trend is: {tech}.\n2. **Sentiment Analysis:** The market mood is: {sentiment}.\n3. **Synthesis & Conclusion:** Weigh the technicals against the sentiment to make a definitive prediction.\n\n**Full Context:**\n{query}\n\nAfter your reasoning, you MUST conclude with a single line containing ONLY the final answer: "Final Answer: [Rise/Fall]"."""
    elif prompt_type == "ethical_scenario":
        issue = _as_clause(kwargs.get('issue', 'regulatory compliance'))
        principles_str = "\n- ".join(kwargs['principles'])
        user_prompt = f"""**Task:** Identify the MOST appropriate action based on Thai SEC/BOT regulations.\n\n**Chain of Thought:**\n1. **Core Issue:** The central conflict here is about {issue}.\n2. **Guiding Principles:** The most relevant principles are:\n- {principles_str}\n3. **Option Evaluation:** Assess each option (A, B, C, D) against these principles.\n4. **Conclusion:** Select the single best option.\n\n**Scenario:**\n{query}\n\nAfter your reasoning, conclude with the line "Final Answer: [letter]"."""
    else:  # general_mcq (also the fallback for unknown prompt types)
        concept = _as_clause(kwargs.get('concept', 'a financial principle'))
        user_prompt = f"""**Task:** Answer the following multiple-choice question correctly.\n\n**Chain of Thought:**\n1. **Question Analysis:** The question tests the concept of {concept}.\n2. **Option Evaluation:** Assess options A, B, C, D, and E.\n3. **Conclusion:** State the correct option.\n\n**Question:**\n{query}\n\nAfter your reasoning, conclude with the line "Final Answer: [letter]"."""
    return {"system": system_prompt, "user": user_prompt}

# Hugging Face model id of the causal LM (and its tokenizer) that
# run_typhoon_agent() loads with from_pretrained().
MODEL_ID: str = "KBTG-Labs/THaLLE-0.1-7B-fa"

def inference(messages: List[Dict[str, str]], model, tokenizer, max_new_tokens: int = 350) -> str:
    """Greedy-decode a chat completion and return only the newly generated text.

    Args:
        messages: Chat turns (``{"role": ..., "content": ...}``) rendered with
            the tokenizer's chat template plus a generation prompt.
        model: A causal LM exposing ``generate`` and ``device``.
        tokenizer: The tokenizer matching ``model``.
        max_new_tokens: Generation budget. The default keeps the original 350;
            long chain-of-thought answers may need more to reach the
            "Final Answer:" line.

    Returns:
        The decoded continuation with special tokens stripped (the prompt is
        not included).
    """
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    # Pass the attention mask explicitly. With input_ids alone, transformers
    # has to infer it (and warns). Only these two keys are forwarded, because
    # some tokenizers also emit token_type_ids, which many causal LMs reject.
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids=model_inputs["input_ids"],
            attention_mask=model_inputs.get("attention_mask"),
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )
    # generate() returns prompt + continuation; drop the prompt tokens.
    prompt_length = model_inputs["input_ids"].shape[1]
    return tokenizer.batch_decode(output_ids[:, prompt_length:], skip_special_tokens=True)[0]

def extract_answer(response_text: str, query_type: str) -> Optional[str]:
    """
    Extracts the final answer from the model's response with prioritization.
    """
    # Priority 1: Look for the explicit "Final Answer:" pattern from our CoT prompt.
    final_answer_match = re.search(r"Final Answer:\s*(Rise|Fall|[A-E])", response_text, re.IGNORECASE)
    if final_answer_match:
        ans = final_answer_match.group(1).capitalize()
        # Capitalize() makes 'rise' -> 'Rise', 'a' -> 'A'
        if ans in ["Rise", "Fall"]:
            return ans
        return ans.upper()

    # Priority 2 (for stock predictions): If CoT fails, search the whole text for the keywords.
    # This is a strong fallback for Rise/Fall questions.
    if query_type == "stock_prediction":
        if 'rise' in response_text.lower() or 'ขึ้น' in response_text: return 'Rise'
        if 'fall' in response_text.lower() or 'ลง' in response_text: return 'Fall'

    # Priority 3 (for MCQs): Find the last mentioned capital letter A-E. This often is the model's conclusion.
    mcq_matches = re.findall(r'\b([A-E])\b', response_text)
    if mcq_matches:
        return mcq_matches[-1]

    return None # Return None if no valid answer can be parsed

def _fallback_answer(query_type: str) -> str:
    """Answer to use when the model's output cannot be parsed or the row fails.

    Stock questions are graded on Rise/Fall, so the letter "A" is always wrong
    for them. "Rise" is at least a valid label there.
    """
    return "Rise" if query_type == "stock_prediction" else "A"


def _build_prompt(query: str, query_type: str, stock_agent, compliance_agent) -> Dict[str, str]:
    """Assemble the system/user prompt pair for one query of the given type."""
    if query_type == "stock_prediction":
        # Price CSV follows "Context:" up to a blank line before the first
        # dated tweet; tweets start at the first "YYYY-MM-DD:" marker.
        context_match = re.search(r'[Cc]ontext:(.*?)(?:\n\n\d{4}|\Z)', query, re.DOTALL)
        tweets_match = re.search(r'\d{4}-\d{2}-\d{2}:.*', query, re.DOTALL)
        context = context_match.group(1).strip() if context_match else ""
        tweets = tweets_match.group(0).strip() if tweets_match else ""
        tech_summary, sentiment_summary = stock_agent.analyze(context, tweets)
        return get_surgical_prompt('stock_prediction', query=query, tech_analysis=tech_summary, sentiment_analysis=sentiment_summary)
    if query_type == "ethical_scenario":
        relevant_principles = compliance_agent.retrieve_relevant_principles(query)
        return get_surgical_prompt('ethical_scenario', query=query, principles=relevant_principles)
    return get_surgical_prompt('general_mcq', query=query)


def run_typhoon_agent(input_path: str, output_path: str):
    """Answer every row of *input_path* and write ``id,answer`` to *output_path*.

    Each query is classified and given a type-specific prompt: quantitative
    summaries for stock questions, retrieved principles for ethics questions.
    The LLM answers it and the answer is parsed out. A row that fails at any
    stage still gets a fallback answer, so the output keeps one line per
    input row.
    """
    df = pd.read_csv(input_path)

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16, device_map="auto")

    stock_agent = QuantitativeStockAnalyzer()
    compliance_agent = ComplianceAgentRAG()

    predictions = []
    for i, row in tqdm(df.iterrows(), total=len(df), desc="Typhoon Agent Processing"):
        query = row["query"]
        # Defaulted before the try so the error path can still choose a
        # fallback of the right kind even if classification itself fails.
        query_type = "general_mcq"

        try:
            query_type = classify_query(query)
            prompt_data = _build_prompt(query, query_type, stock_agent, compliance_agent)
            messages = [{"role": "system", "content": prompt_data["system"]}, {"role": "user", "content": prompt_data["user"]}]

            response = inference(messages, model, tokenizer)
            prediction = extract_answer(response, query_type)

            if prediction is None:
                prediction = _fallback_answer(query_type)
                print(f"Warning: Could not extract answer for row {i+1}. Defaulting to '{prediction}'. Raw: '{response[:100]}'")

        except Exception as e:
            print(f"ERROR on row {i+1}: {e}")
            prediction = _fallback_answer(query_type)

        predictions.append(prediction)

    df["answer"] = predictions
    df[["id", "answer"]].to_csv(output_path, index=False)
    print(f"\n✅ Predictions saved to: {output_path}")

if __name__ == "__main__":
    # Script/notebook entry point. The input path looks like a Google Colab
    # upload location (/content/...); adjust both paths when running elsewhere.
    run_typhoon_agent(
        '/content/combine_testsub.csv',
        "submission.csv",
    )