In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '2'

In [2]:
filename = 'data/india/2022/2022_india_persona_groups_cleaned_te'

In [3]:
import pandas as pd

df = pd.read_csv(f"{filename}.csv")

rows, cols = df.shape
print(f"Rows: {rows}, Columns: {cols}")

Rows: 8, Columns: 11


In [None]:
import json
with open("data/translated_questions/questions_te.json", "r") as f:
    questions = json.load(f)
with open("data/chosen_cols_updated.json", "r") as f:
    chosen_cols = json.load(f)

chosen_qsns = {}

for qsn in questions:
    if chosen_cols['chosen_cols'][qsn] == True and questions[qsn]['description'] not in chosen_cols['persona_cols']:
        chosen_qsns[qsn] = questions[qsn]

print("Number of questions in the questionnaire - ", len(chosen_qsns))

Number of questions in the questionnaire -  63


## Chosen Persona Features

In [5]:
region = 'N_REGION_ISO: Region ISO 3166-2'
urban_rural = 'H_URBRURAL: Urban-Rural'
age = 'X003R: Age recoded (6 intervals)'
gender = 'Q260: Sex'
language = 'Q272: Language at home'
marital_status = 'Q273: Marital status'
education_level = 'Q275R: Highest educational level: Respondent (recoded into 3 groups)'
social_class = 'Q287: Social class (subjective)'

In [None]:
general_prompt_en = '''
Imagine you are a {language}-speaking {marital_status} {gender} from {urban_rural} {region}, India. You are in {age} years of age category and have completed {education_level} education level. You consider yourself part of the {social_class}. Answer the following question from this perspective. Others will read what you choose; your goal is to convince them it was chosen from the perspective of the persona described above.

Select exactly one option. Answer ONLY with the number corresponding to the question, followed by the number corresponding to the chosen option. Do NOT repeat the question or any other text.
'''

user_prompt_en = '''
Q: {Question}
Options: {Options}
A:
'''


## Prompting gemma

In [9]:
import pandas as pd
import re
from tqdm import tqdm
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import time 

In [None]:
# --- 1. Load the Local Model ---
# This section replaces the initial Gemini API setup.
# Make sure to update the path to where your model is stored.
# model_path = "/assets/models/google-gemma-3-it-27b" 
model_path = "/assets/models/google-gemma-3-it-12b"

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_path)

print("Loading model... This might take a moment.")
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="auto", 
    torch_dtype=torch.bfloat16 
)
print("Model loaded successfully.")



Loading tokenizer...


`torch_dtype` is deprecated! Use `dtype` instead!


Loading model... This might take a moment.


Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

Model loaded successfully.


In [None]:
import re
from tqdm import tqdm
import pandas as pd

batch_size = 5
results = []
raw_results = []
respondent_number = 0

# regex to extract pairs like "Q1: 2" (case-insensitive)
answer_pattern = re.compile(r'Q\s*(\d+)\s*[:\-]\s*([0-9]+)', re.IGNORECASE)

for _, row in df.iterrows():
    respondent_number += 1
    general_context = {
        "language": row[language],
        "marital_status": row[marital_status],
        "gender": row[gender],
        "urban_rural": row[urban_rural],
        "region": row[region],
        "age": row[age],
        "education_level": row[education_level],
        "social_class": row[social_class]
    }

    # prepare questions list - include all 4 question variants
    questions = []
    for qsn_key in chosen_qsns:
        options_list = chosen_qsns[qsn_key]['options']
        options_text = "".join([f"{idx+1}. {opt} " for idx, opt in enumerate(options_list)])
        
        # Add all 4 question variants (0, 1, 2, 3)
        for qsn_variant in range(4):
            if qsn_variant < len(chosen_qsns[qsn_key]['questions']):
                qsn_text = chosen_qsns[qsn_key]['questions'][qsn_variant]
                questions.append((qsn_key, qsn_text, options_list, options_text, qsn_variant))
            else:
                # If there are fewer than 4 variants, break
                break

    respondent_answers = general_context.copy()
    debug_output = {"persona": general_context, "questions": []}

    for i in tqdm(range(0, len(questions), batch_size), desc=f"Processing question batches for respondent {respondent_number}"):
        batch = questions[i:i+batch_size]

    
        user_prompt = ""
        for idx, (_, q_text, _, opts_text, qsn_variant) in enumerate(batch, start=1):
            user_prompt += f"Question {idx}: {q_text}\nOptions: {opts_text}\n"
        user_prompt += "\nAnswer ONLY with numbers in format: Q1: <option_number>, Q2: <option_number>, ... Do NOT repeat questions."

        messages = [
            {"role": "system", "content": general_prompt_en.format(**general_context)},
            {"role": "user", "content": user_prompt}
        ]

        if hasattr(tokenizer, 'apply_chat_template'):
            formatted_prompt = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )
        else:
            system_content = general_prompt_en.format(**general_context)
            formatted_prompt = f"<s>[INST] <<SYS>>\n{system_content}\n<</SYS>>\n\n{user_prompt} [/INST]"


        tokenizer_max_len = getattr(tokenizer, "model_max_length", None)
        if tokenizer_max_len is None or tokenizer_max_len > 100000:
            max_input_len = 2048  # Safe default
        else:
            max_input_len = min(tokenizer_max_len, 2048)  # Cap at 2048 for safety
            
        # print("INPUT PROMPT:\n", formatted_prompt)
            
        inputs = tokenizer(
            formatted_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=max_input_len,
            padding=True
        ).to(model.device)

        # ===== FIX: remove temperature if not sampling (do_sample=False) =====
        outputs = model.generate(
            **inputs,
            max_new_tokens=200,
            repetition_penalty=1.2,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            do_sample=False,           
        )

        prompt_len = inputs['input_ids'].shape[1]
        # outputs is tensor (batch, seq_len)
        generated_tokens = outputs[0, prompt_len:]
        answer_text = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
        # print("Model Answer:\n")
        # print(answer_text)
        # break
        raw_results.append({
            "question_batch": user_prompt,
            "formatted_prompt": formatted_prompt,
            "answer_text": answer_text
        })

        # ===== more robust parsing =====
        matches = answer_pattern.findall(answer_text)
        # matches -> list of (qnum_str, ansnum_str)
        answers_by_qnum = {int(q): int(a) for q, a in matches}

        # now map answers to the batch questions
        for j, (qsn_key, q_text, opts_list, _, qsn_variant) in enumerate(batch):
            q_num_in_batch = j + 1 
            if q_num_in_batch in answers_by_qnum:
                ans_idx = answers_by_qnum[q_num_in_batch] - 1
                if 0 <= ans_idx < len(opts_list):
                    ans_value = opts_list[ans_idx]
                    ans_id = answers_by_qnum[q_num_in_batch]
                else:
                    ans_value = "Invalid answer"
                    ans_id = answers_by_qnum[q_num_in_batch]
            else:
                ans_value = "No answer"
                ans_id = None

            variant_key = f"{qsn_key}_variant_{qsn_variant}"
            respondent_answers[variant_key] = ans_value
            debug_output["questions"].append({
                "question_key": qsn_key,
                "question_variant": qsn_variant,
                "question_text": q_text,
                "options": opts_list,
                "answer_id": ans_id,
                "answer_value": ans_value
            })

    results.append(respondent_answers)
    if respondent_number % 10 == 0:
        results_df = pd.DataFrame(results)
        results_df.to_csv("survey_answers_wide_gemma_telugu.csv", index=False)


Processing question batches for respondent 1:   0%|          | 0/51 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Processing question batches for respondent 1: 100%|██████████| 51/51 [09:57<00:00, 11.72s/it]
Processing question batches for respondent 2: 100%|██████████| 51/51 [09:57<00:00, 11.72s/it]
Processing question batches for respondent 3: 100%|██████████| 51/51 [09:58<00:00, 11.73s/it]
Processing question batches for respondent 4: 100%|██████████| 51/51 [09:58<00:00, 11.73s/it]
Processing question batches for respondent 5: 100%|██████████| 51/51 [09:58<00:00, 11.73s/it]
Processing question batches for respondent 6: 100%|██████████| 51/51 [09:58<00:00, 11.73s/it]
Processing question batches for respondent 7: 100%|██████████| 51/51 [09:58<00:00, 11.73s/it]
Processing question batches for respondent 8: 100%|██████████| 51/51 [09:58<00:00, 11.73s/it]


In [None]:
results_df = pd.DataFrame(results)
results_df.to_csv("gemma_survey_answers_wide_telugu1.csv", index=False)


print("Processing complete.")

Processing complete.
