<a href="https://colab.research.google.com/github/Maximi652/efficient-slm-architectures/blob/main/Qwen3_4B_BaseModel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive so the model and data files under /content/drive are reachable.
from google.colab import drive
drive.mount('/content/drive')

# IPython shell escape: upgrade transformers and accelerate in the notebook environment.
!pip install --upgrade transformers accelerate

In [None]:
# Imports
import torch
import json
import re
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

# Paths on the mounted Drive: local model directory, input question set,
# and the results file written at the end of the run.
MODEL_PATH    = "/content/drive/MyDrive/Colab Notebooks/Qwen3-4B"
INPUT_JSON    = "/content/drive/MyDrive/Colab Notebooks/12B_combined_golden.json"
OUTPUT_JSON   = "/content/drive/MyDrive/Colab Notebooks/base-qwen3-4b_testresults.json"

# Load tokenizer & model from the local directory
model_name = MODEL_PATH

# Load the tokenizer; its chat template is used later to format the prompts
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Use left padding to avoid decoder-only right-padding issues
tokenizer.padding_side = 'left'
# Batched tokenization needs a pad token; fall back to EOS if none is defined
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# "auto" delegates the choice of dtype and device placement to the library
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto",
    trust_remote_code=True
)
# Inference only: switch the model to evaluation mode
model.eval()

# Generation config: greedy decoding, at most 200 new tokens per prompt.
# Thinking is switched off through the chat template / system prompt, not here.
# EOS and PAD ids are pinned explicitly: a hand-built GenerationConfig does not
# necessarily inherit them from the checkpoint, and without an EOS id
# generate() keeps producing tokens until max_new_tokens on every row.
gen_conf = GenerationConfig(
    max_new_tokens=200,
    do_sample=False,
    eos_token_id=model.generation_config.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
)

# Hilfsfunktionen für Prompt-Bau und Cleaning
def build_messages(qtext, snippets, qtype, mode="exact"):
    """
    Assemble the chat messages for a single question.

    Returns a two-element list: a system message whose content is the
    "/no_think" directive, followed by a user message made of the question,
    the text of the first two snippets, and an answer instruction chosen
    from `mode` ("exact", anything else is treated as "ideal") and `qtype`.
    """
    # Instructions for the "exact" mode, keyed by question type
    exact_instructions = {
        "yesno": "Answer only 'yes' or 'no', in English, no extras.",
        "factoid": "Provide up to 5 keywords, comma-separated, in English, no commentary.",
        "list": "Provide a comma-separated list of relevant items, in English, no filler words.",
    }

    if mode == "exact":
        instruction = exact_instructions.get(qtype, "Provide a brief answer in English.")
    elif qtype == "yesno":
        instruction = "Provide one-sentence ideal answer in English starting with 'Yes,' or 'No,'."
    else:
        instruction = "Provide an ideal answer in English (one paragraph, max 200 words, full sentences)."

    # Only the first two snippets are used as context
    context = "\n".join(snippet["text"] for snippet in snippets[:2])
    prompt = f"Question: {qtext}\nContext:\n{context}\n{instruction}"

    return [
        {"role": "system", "content": "/no_think"},
        {"role": "user", "content": prompt},
    ]


def clean_exact(text: str, qtype: str):
    """
    Normalise a raw completion into the "exact" answer format.

    Args:
        text: Decoded model output for one question.
        qtype: Question type ("yesno", "factoid", "list", or anything else).

    Returns:
        "yes" or "no" for yes/no questions ("no" is the fallback whenever the
        cleaned text does not start with "yes"), a list of non-empty
        comma-separated items for factoid/list questions, and None for every
        other type.
    """
    # Keep only what follows the last </think>: anything before it is the
    # model's reasoning, not the answer. Drop a stray opening tag as well,
    # and strip *after* removing tags so leftover newlines cannot hide a
    # leading "yes".
    txt = text.rpartition('</think>')[2].replace('<think>', '').strip()
    # Remove a trailing filler token. The word boundary keeps words that
    # merely end in one of the fillers (e.g. "Tokay") intact.
    txt = re.sub(r'\s*(?:\b(?:okay|etc|usw)\.?|\.\.\.)$', '', txt, flags=re.IGNORECASE)
    if qtype == "yesno":
        return "yes" if txt.lower().startswith("yes") else "no"
    if qtype in ("factoid", "list"):
        return [item.strip() for item in txt.split(",") if item.strip()]
    return None


def clean_ideal(text: str, qtype: str, max_words: int = 200) -> str:
    """
    Normalise a raw completion into the "ideal" free-text answer.

    Args:
        text: Decoded model output for one question.
        qtype: Question type; "yesno" answers are cut to their first sentence.
        max_words: Word budget for non-yes/no answers (default 200).

    Returns:
        For "yesno": the first sentence of the cleaned text. Otherwise: the
        leading sentences whose combined word count stays within `max_words`.
        If the very first sentence already exceeds the budget it is
        hard-truncated to `max_words` words instead of returning an empty
        answer.
    """
    # Keep only what follows the last </think> (anything before it is the
    # model's reasoning) and drop a stray opening tag.
    txt = text.rpartition('</think>')[2].replace('<think>', '').strip()
    # Remove a trailing filler token; the word boundary protects words that
    # merely end in one of the fillers.
    txt = re.sub(r'\s*(?:\b(?:okay|etc|usw)\.?|\.\.\.)$', '', txt, flags=re.IGNORECASE)
    # Split after sentence-ending punctuation followed by whitespace
    sentences = re.split(r'(?<=[.!?])\s+', txt)
    if qtype == "yesno":
        return sentences[0].strip()

    # Collect whole sentences while they fit into the word budget
    kept = []
    total = 0
    for sent in sentences:
        length = len(sent.split())
        if total + length > max_words:
            break
        kept.append(sent)
        total += length

    if not kept:
        # The first sentence alone is over budget: truncate it rather than
        # discarding the whole answer.
        return " ".join(sentences[0].split()[:max_words])
    return " ".join(kept).strip()

# Load the dataset
with open(INPUT_JSON, 'r', encoding='utf-8') as f:
    questions = json.load(f)['questions']


def _generate_texts(messages_batch):
    """Render, tokenize and generate for a batch of chat message lists; return the decoded completions."""
    prompts = [
        tokenizer.apply_chat_template(
            m, tokenize=False, add_generation_prompt=True, enable_thinking=False
        )
        for m in messages_batch
    ]
    inputs = tokenizer(prompts, return_tensors='pt', padding=True, truncation=True).to(model.device)
    with torch.no_grad():
        outputs = model.generate(**inputs, generation_config=gen_conf)
    # With left padding every row has the same prompt width, so one offset
    # removes the prompt from all rows at once.
    prompt_len = inputs['input_ids'].shape[1]
    return tokenizer.batch_decode(outputs[:, prompt_len:], skip_special_tokens=True)


# Batched inference: two generation passes per batch (exact + ideal)
batch_size = 8
submission = []

for i in tqdm(range(0, len(questions), batch_size), desc='Batches'):
    batch = questions[i:i+batch_size]

    # exact answers
    msgs_ex = [build_messages(q['body'], q.get('snippets', []), q['type'], mode='exact') for q in batch]
    dec_ex = [clean_exact(text, q['type']) for text, q in zip(_generate_texts(msgs_ex), batch)]

    # ideal answers
    msgs_id = [build_messages(q['body'], q.get('snippets', []), q['type'], mode='ideal') for q in batch]
    dec_id = [clean_ideal(text, q['type']) for text, q in zip(_generate_texts(msgs_id), batch)]

    # merge gold answers and predictions into one record per question
    for q, exact_pred, ideal_pred in zip(batch, dec_ex, dec_id):
        submission.append({
            'id': q['id'],
            'type': q['type'],
            'exact_answer': q.get('exact_answer'),
            'ideal_answer': q.get('ideal_answer'),
            'exact_prediction': exact_pred,
            'ideal_prediction': ideal_pred
        })

# Save the submission
with open(OUTPUT_JSON, 'w', encoding='utf-8') as f:
    json.dump(submission, f, ensure_ascii=False, indent=2)

print('✅ Submission file created at:', OUTPUT_JSON)