In [1]:
import re
import torch
import pandas as pd
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import os

os.environ["HF_HOME"] = "/data/gus"            # controls all HF caches
os.environ["HF_HUB_CACHE"] = "/data/gus"   # models, tokenizers, Hub files
os.environ["HF_DATASETS_CACHE"] = "/data/gus"  # datasets/Arrow files

print("HF_HUB_CACHE:", os.getenv("HF_HUB_CACHE"))
print("HF_DATASETS_CACHE:", os.getenv("HF_DATASETS_CACHE"))
print("HF_HOME:", os.getenv("HF_HOME"))

HF_HUB_CACHE: /data/gus
HF_DATASETS_CACHE: /data/gus
HF_HOME: /data/gus


In [3]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

from transformers import AutoTokenizer, AutoModelForCausalLM

In [5]:
MODEL_PATH = "Qwen/Qwen2.5-7B-Instruct" 

# โหลด tokenizer พร้อม trust_remote_code=True (จำเป็นสำหรับ Qwen)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)

# โหลดโมเดล
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, trust_remote_code=True, device_map="auto", torch_dtype=torch.float16)

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/3.86G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/3.86G [00:00<?, ?B/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/3.95G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/3.56G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/243 [00:00<?, ?B/s]

In [None]:

# def inference(messages: list[dict[str, str]], model, tokenizer) -> str:
#     # สำหรับ Qwen, apply_chat_template ยังใช้ได้เหมือนเดิม
#     text = tokenizer.apply_chat_template(
#         messages,
#         tokenize=False,
#         add_generation_prompt=True,
#     )
#     model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

#     generated_ids = model.generate(
#         model_inputs.input_ids,
#         max_new_tokens=768,
#         do_sample=False,
#         temperature=None,
#         top_p=None,
#         top_k=None,
#     )
#     generated_ids = [
#         output_ids[len(input_ids):]
#         for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
#     ]

#     response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

#     match = re.search(r"\b([ABCDE]|Rise|Fall)\b", response)
#     if match:
#         return match.group(0)
#     else:
#         print("Not Match: ", response)
#         return "No valid answer found in the response."

In [9]:
# Inference Function
def inference(messages: list[dict[str, str]], model, tokenizer) -> str:
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    generated_ids = model.generate(
        model_inputs.input_ids,
        max_new_tokens=768,
        do_sample=False,
        temperature=None,
        top_p=None,
        top_k=None,
    )
    generated_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]

    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    
    # Extract only valid answer (A, B, D, E, Rise, Fall)
    
    match = re.search(r"\b([ABCDE]|Rise|Fall)\b", response)
    if match:
        return match.group(0)
    else:
        print("Not Match: ", response)
        return "No valid answer found in the response."
    


In [10]:
# โหลดไฟล์คำถาม
questions = pd.read_csv("/data/data/week10/test.csv")

SYSTEM_PROMPT = """
You are a highly knowledgeable and structured AI financial analyst.

Your role is to carefully analyze a wide range of financial questions, which may include:
- Choosing the most appropriate answer from multiple choices (A, B, C, D, Fall, Rise)
- Making predictions about financial trends based on provided data
- Explaining key economic terms and their impact on the financial question

You must always:
1. Read and understand the question carefully.
2. Identify any key financial terms or data.
3. Think step-by-step and logically evaluate all available information.
4. If needed, define key concepts briefly.
5. Arrive at the most appropriate answer using your reasoning.

**Respond using the exact format below (no extra text):**
Answer: <A/B/C/D/Fall/Rise>

Only respond with the final answer in that exact format.
"""


In [11]:

# Inference Loop
results = []

for row in tqdm(questions.to_dict(orient="records")):
    question = row['query']
    
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": question},
    ]
    
    answer = inference(messages, model, tokenizer)
    results.append({"id": row['id'], "answer": answer})

# Results DataFrame
answer_df = pd.DataFrame(results)
print(answer_df.head())


  5%|▌         | 26/499 [00:03<01:01,  7.74it/s]

Not Match:  Down


 10%|▉         | 49/499 [00:06<00:48,  9.34it/s]

Not Match:  Down


 29%|██▉       | 147/499 [00:18<00:45,  7.66it/s]

Not Match:  Answer: 下降


 56%|█████▌    | 279/499 [00:34<00:33,  6.63it/s]

Not Match:  Answer: ลดลง


 72%|███████▏  | 357/499 [00:45<00:21,  6.75it/s]

Not Match:  Answer: ลดลง


100%|██████████| 499/499 [01:02<00:00,  7.92it/s]

                                     id answer
0  36deab86-cfd3-48b5-9bea-a36c1b0e63a8      C
1  2b5bbd26-45e8-4768-ab8a-b5dc1d153ab7      B
2  8a722080-bc16-49db-89c9-100cd61cd28a      A
3  75316e95-88f4-4fef-83b9-dde0aa52889a      B
4  bcca13bc-2675-4645-82cc-7e4c412ed294   Rise





In [12]:
answer_df

Unnamed: 0,id,answer
0,36deab86-cfd3-48b5-9bea-a36c1b0e63a8,C
1,2b5bbd26-45e8-4768-ab8a-b5dc1d153ab7,B
2,8a722080-bc16-49db-89c9-100cd61cd28a,A
3,75316e95-88f4-4fef-83b9-dde0aa52889a,B
4,bcca13bc-2675-4645-82cc-7e4c412ed294,Rise
...,...,...
494,c9dd262e-405c-4078-baae-262aa48ddcc8,A
495,73c720b5-1101-4790-af52-3366823e1d32,B
496,357db18f-d872-416e-a07f-753099853d9c,D
497,2d8b1419-1c46-4e83-892a-081fb417de38,Rise


In [13]:
answer_df.to_csv("/data/gus/week8/submission/Test_for_base_gwen2.5_7b_inst.csv", index=False)

In [14]:
answer_df

Unnamed: 0,id,answer
0,36deab86-cfd3-48b5-9bea-a36c1b0e63a8,C
1,2b5bbd26-45e8-4768-ab8a-b5dc1d153ab7,B
2,8a722080-bc16-49db-89c9-100cd61cd28a,A
3,75316e95-88f4-4fef-83b9-dde0aa52889a,B
4,bcca13bc-2675-4645-82cc-7e4c412ed294,Rise
...,...,...
494,c9dd262e-405c-4078-baae-262aa48ddcc8,A
495,73c720b5-1101-4790-af52-3366823e1d32,B
496,357db18f-d872-416e-a07f-753099853d9c,D
497,2d8b1419-1c46-4e83-892a-081fb417de38,Rise
