###  Set up

In [1]:
# Import Utilities
from google.colab import drive
import os
import shutil
import torch
from huggingface_hub import snapshot_download, notebook_login
from transformers import BitsAndBytesConfig, AutoModelForCausalLM, AutoTokenizer

drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
%cd /content/drive/MyDrive/cultural-trends/scripts

/content/drive/MyDrive/cultural-trends/scripts


### Utils

In [4]:
#%%writefile /content/drive/MyDrive/cultural-trends/scripts/utils.py
import re
import os
import yaml
import json
import time
import requests
import numpy as np
import pandas as pd

import scipy.stats
from itertools import product
from collections import Counter

MAX_ATTEMPTS = 10

def retry_request(url, payload, headers):
    for i in range(MAX_ATTEMPTS):
        try:
            response = requests.post(url, data=json.dumps(
                payload), headers=headers, timeout=90)
            json_response = json.loads(response.content)
            if "error" in json_response:
                print(json_response)
                print(f"> Sleeping for {2 ** i}")
                time.sleep(2 ** i)
            else:
                return json_response
        except:
            print(f"> Sleeping for {2 ** i}")
            time.sleep(2 ** i)  # exponential back off
    raise TimeoutError()

def convert_to_percentages(answers, options, answer_map=None, is_scale=False):
    answers_mapped = []
    for ans in answers:
        if ans == -1: continue
        if ans not in options and answer_map is not None:
            answers_mapped += [str(answer_map[ans])]
        elif not is_scale:
            answers_mapped += [options[ans-1]]
        else:
            answers_mapped += [ans]

    # Count the occurrences of each answer
    answer_counts = Counter(answers_mapped)
    # Calculate the total number of answers
    total_answers = len(answers)
    # Calculate the percentage for each unique answer and store it in a dictionary
    percentages = {answer: (count / total_answers) * 100 for answer, count in answer_counts.items()}
    labels = list(percentages.keys())
    values = [percentages[label] if label in percentages else 0 for label in labels]
    return options, values

def parse_range(data):
    """
    Turns a dictionary with number ranges as keys into one with single numbers as keys.

    Parameters:
        data (dict): A dictionary with keys as strings like "1-3" or "5", and any values.

    Returns:
        dict: A new dictionary with numbers as keys and the same values copied over.
    """
    data_dict = {}
    for q_range in data:
        if "-" in q_range:
            q_start, q_end = tuple(map(int, q_range.split("-")))
        else:
            q_start = q_end = int(q_range)

        for q_idx in range(q_start, q_end+1):
            data_dict[q_idx] = data[q_range]

    return data_dict

def cartesian_product(lists):
    return list(product(*lists))

def read_file(path):
    with open(path, 'r', encoding="utf-8") as fin:   # Open file for reading
        data = fin.readlines()                       # Read all lines into a list
    return data

def read_raw(path):
    with open(path, 'r', encoding="utf-8") as fin:
        data = fin.read()
    return data

def read_json(path):
    with open(path, 'r', encoding="utf-8") as fin:
        data = json.load(fin)
    return data

def read_yaml(path):
    with open(path, 'r', encoding="utf-8") as fin:
        data = yaml.load(fin, Loader=yaml.FullLoader)
    return data

def write_file(path, data):
    with open(path, 'w', encoding="utf-8") as fout:
        fout.write('\n'.join(data))

def write_json(path, data):
    with open(path, 'w', encoding="utf-8") as fout:
        json.dump(data, fout, ensure_ascii=False)

def kl_divergence(p, q):
    return np.sum(np.where(p != 0, p * np.log(p / q), 0))

def jensen_shannon_distance(p, q):
    """
    method to compute the Jenson-Shannon Distance
    between two probability distributions
    """

    # convert the vectors into numpy arrays in case that they aren't
    p = np.array(p)
    q = np.array(q)

    # calculate m
    m = (p + q) / 2

    # compute Jensen Shannon Divergence
    divergence = (scipy.stats.entropy(p, m) + scipy.stats.entropy(q, m)) / 2

    # compute the Jensen Shannon Distance
    distance = np.sqrt(divergence)

    return distance

def append_row(
    data,
    **cols,
):
    for k, v in cols.items():
        data[k].append(v)


def parse_response_wvs(response: str, question_options: list):
  """
    Parse a raw model-generated response into a 1-based index for a World Values Survey question.

    Parameters:
        response (str): The raw text produced by the model.
        question_options (list[str]): The list of valid answer option strings, in order.

    Returns:
        int: A 1-based index corresponding to one of the `question_options`, or -1 if parsing fails.
  """

  # Make response case-insensitive
  response = response.lower().strip()

  # Matches a number enclosed in parentheses
  pattern = r"\(\d+\)"

  # Search for the number within response
  match = re.search(pattern, response)

  if match:

      # Extract number from match and convert it to integer
      match_text = match.group()
      stripped_text = match_text[1:-1]
      answer = int(stripped_text)

      return answer if 1 <= answer <= len(question_options) else -1

  else:
      # for option_idx, option in enumerate(question_options):
      #     if response == option.lower().strip():
      #         return option_idx+1
      # question_options = question_options[::-1]
      # for option_idx, option in enumerate(question_options):
      #     if response == option.lower().strip():
      #         return len(question_options)-option_idx


      # Match model response to one of the  valid options
      for option_idx, option in enumerate(question_options):
          if response in option.lower().strip():
              return option_idx+1

  # Handle cases where closing parenthesis is missing
  pattern = r"\(\d+"
  match = re.search(pattern, response)
  if match:
      answer = int(match.group()[1:])
      return answer if 1 <= answer <= len(question_options) else -1

  # Takes the first number of the response
  pattern = r"\d+"
  match = re.search(pattern, response)
  if match:
      answer = int(match.group())
      return answer if 1 <= answer <= len(question_options) else -1

  #  Unable to parse a valid index
  return -1

def parse_response(res: str, options: list):
    if type(res) == int:
        return res

    res = res.strip()
    pattern = r"\d+"
    match = re.search(pattern, res)
    if match:
        answer = int(match.group())
        if 1 <= answer <= len(options):
            return answer

    num_words = len(res.split())
    for i, option in enumerate(options):
        space_idx = option.index(" ")
        if res == option or \
           res == option.replace(".", "").strip() or \
           res == option[space_idx+1:].strip() or \
           res == option[space_idx+1:].strip().replace(".", "") or \
           res == ' '.join(option[:num_words]):
            return i+1

    for i in range(1, len(options)+1):
        if str(i) in res:
            return i
    return -1

def parse_question(q: dict, questions_en=None):
    index = '.'.join(str(x) for x in q['index'])
    text = q['questions'][0]

    if questions_en is not None:
        options = questions_en[index]["options"]
    else:
        options = q["options"]

    qparams = q["question_parameters"] if "question_parameters" in q else None

    return {
        'index': index,
        'text': text,
        'options': options,
        "qparams": qparams
    }

def append_data(qidx, data, questions, columnar_data):

    invalid_ans = 0
    for row in data:
        try:
            if "Error" in row:
                continue
            persona = row['persona']
            qid = '.'.join(str(x) for x in row['question']['id'])
            vid = row['question']['variant']
            responses = row['response']
            qparams = row["question"]["params"]
            key_qparam = list(qparams.keys())[0] if len(qparams) > 0 else None
            if qidx == 6 and qparams[key_qparam] in ["Corporations", "Public Companies","Local Government", "Electoral Process"]:
                continue

            for response in responses:

                question = questions[qid]
                options = question["options"]
                answer = parse_response(response, options)
                if answer == -1:
                    invalid_ans += 1
                    continue

                if key_qparam is not None:
                    qparam_idx = str(question["qparams"][key_qparam].index(qparams[key_qparam]) + 1)
                else:
                    qparam_idx = "0"

                if qidx == 10 and qparam_idx == "2":
                    # to remove the extra variant Nael added
                    continue

                # breakpoint()

                append_row(
                    columnar_data,
                    qid=qid, vid=vid, response=answer,
                    question_text=question['text'],
                    response_text=question['options'][answer-1],
                    qparam_id=qparam_idx,
                    **persona,
                )
        except:
            breakpoint()
            raise

    print('='*50)
    print(f"> {invalid_ans} Invalid Answers")
    print('='*50)
    return columnar_data, invalid_ans

def read_question(path, qidx, questions_en=None):
    questions = {}
    with open(path, 'r', encoding='utf-8') as fp:
        q_data = yaml.safe_load(fp)['dataset']
        for row in q_data:
            if row["index"][0] != qidx: continue
            q = parse_question(row, questions_en)
            questions[q['index']] = q
    return questions

def get_results_path(filesuffix, model_name, lang, version, m1):
    for v_num in range(version, 0, -1):
        if m1:
            v_num = f"{v_num}m1"
        results_path = f'results/{model_name}/{lang}/preds_{filesuffix}_v{v_num}.json'
        if os.path.exists(results_path):
            return results_path
    return None

def append_response(
    model_data:list[dict],
    row:dict,
    response_int:int,
    response_id:int,
    persona_id: int,
    q_responses:list[int]
    ):

  """
  Append a single response record to the flattened model dataset,
  optionally performing majority-voting over multiple parsed answers.

  Parameters:
    model_data (list[dict]): A new record is appended.
    row (dict): Raw entry containing keys:
            - "persona": sub-dict with demographic fields.
            - "question": sub-dict with "id" and "variant" (0-based) fields.
    response_int (int):
        The initially parsed integer response (1-based).
    response_id (int):
        Index (0-based) of this particular generation attempt within the variant.
    persona_id (int):
        Sequential index of the persona/variant block in the original dataset (used
        for tracking but not stored in the appended record).
    q_responses (list[int] | None):
        If not None, a list of multiple parsed integer responses collected for
        majority-voting.

  Returns:
      list[dict]:
          The same `model_data` list with one new appended response record.

  Raises:
      AssertionError:/
          If the final `response_int` (after voting) is not greater than zero.
  """


  # If q_response was provided, ignore response_int and do majority voting
  if q_responses is not None:
      # Count each unique response
      response_counter = Counter(q_responses)

      # Get them sorted by frecuency
      most_common_responses = response_counter.most_common()

      # First is the highest frecuency as it is sorted
      first_freq = most_common_responses[0]

      # Extract its frecuency
      max_freq = first_freq[1]

      # Collect tied winners
      max_responses = []

      # Iterate list frecuency till counts drop below max_freq
      for most_common in most_common_responses:
          if most_common[1] == max_freq:
              # Collect all tied for first
              max_responses += [most_common[0]]
          else:
              break
      # If multiple equally common answers, pick one at random
      response_int = np.random.choice(max_responses)

  assert response_int > 0

  # Append structured response to model data

  model_data += [{
      "persona.region": row["persona"]["region"],
      "persona.sex": row["persona"]["sex"],
      "persona.age": row["persona"]["age"],
      "persona.country": row["persona"]["country"],
      "persona.marital_status": row["persona"]["marital_status"],
      "persona.education": row["persona"]["education"],
      "persona.social_class": row["persona"]["social_class"],
      "question.id": row["question"]["id"],
      "question.variant": row["question"]["variant"],
      "response.id": response_id, # Which generation
      "response.answer": response_int
  }]
  return model_data


def convert_to_dataframe(
    model_data:list[dict],
    question_options:list[str],
    demographic_map: dict[str,str],
    eval_method: str = "mv_all",
    language: str = "en",
    is_scale_question: bool = False
    ):

  """ Flatten and evaluate raw model outputs into a structured DataFrame.

    Parameters:
        model_data (list[dict]):
            Raw output for a single question across all personas and variants.

        question_options (list[str]):
            The ordered list of valid option strings for this question, used
            to validate and parse model outputs.

        eval_method (str, default="mv_all")

        language (str, default="en")

        is_scale_question (bool, default=False):
            If True, collapse 10-point scale answers into 5 bins via ceil(ans/2).

    Returns:
        tuple[pd.DataFrame, int]:
            - DataFrame with columns:
                persona.region, persona.sex, persona.age, persona.country,
                persona.marital_status, persona.education, persona.social_class,
                question.id, question.variant,
                response.id, response.answer
              One row per persona-question according to `eval_method`.
            - invalid_count (int): number of generations that could not be
              parsed into a valid answer.

    Raises:
        AssertionError:
            If `eval_method` is invalid or if data ordering is wrong
            (expects variants 0–3 in sequence), or if a voted response
            is not a positive integer.

  """

  assert eval_method in {"flatten", "mv_sample", "mv_all", "first"}

  model_data_flat = []
  invalid_count = 0
  # Collect parsed ints for voting
  q_responses = []

  # Loop every raw row of model output
  for row_idx, row in enumerate(model_data):

      #if language != "en":
      #    row["persona"] = {d_text: (demographic_map[d_text][d_value] if d_text != "region" else demographic_map[d_text][d_value]) if d_text != "age" else d_value for d_text, d_value in row["persona"].items()}

      # Clear out any old responses at the start of each row.
      if eval_method == "mv_sample":
          q_responses = []

      # Reset every 4 rows (one set of variants)
      if row_idx % 4 == 0 and eval_method == "mv_all":
          q_responses = []

      # Verify data is in groups of 4 (1 persona - 4 variants)
      # assert row_idx % 4 == row["question"]["variant"]

      # Parse each generation (5 of them per row)
      for response_id, response in enumerate(row["response"]):

          # Turns the row string into 1..N or -1
          response_int = parse_response_wvs(response, question_options)

          if is_scale_question:
              response_int = int(np.ceil(response_int/2))

          if eval_method == "first":

              # If first parse response is invalid, just count and stop
              if response_int <= 0:
                  invalid_count += 1
              else:
              # Append the one and break (ignoring the other 4 generations)
                  model_data_flat = append_response(model_data_flat, row, response_int, response_id, row_idx, q_responses=None)
              break

          if eval_method == "flatten":
              # Every valid generation becomes a row
              if response_int <= 0:
                  invalid_count += 1
                  continue
              model_data_flat = append_response(model_data_flat, row, response_int, response_id, row_idx, q_responses=None)

          elif response_int > 0 and "mv" in eval_method:
              # Collect valid responses for later voting
              q_responses += [response_int]

      if eval_method == "mv_sample":
          # if no valid responses, count it
          if len(q_responses) == 0:
              # breakpoint()
              invalid_count += 1
              continue

          # Pass q_response list for voting
          model_data_flat = append_response(model_data_flat, row, -1, response_id, row_idx, q_responses)

      elif eval_method == "mv_all" and row_idx % 4 == 3: # vote in 4 variant

          if len(q_responses) == 0:
              invalid_count += 1
              continue
          # Vote over all 20 responses
          model_data_flat = append_response(model_data_flat, row, -1, response_id, row_idx, q_responses)

  return pd.DataFrame(model_data_flat), invalid_count

def create_wvs_question_map(headers:list[str], selected_questions:list[str]):
  """
    Creates a mapping from WVS question indices to their corresponding column names
    in the dataset, filtered by a list of selected question indices.

    Args:
        headers (List[str]): List of column names from the WVS dataset.
        selected_questions (List[int]): List of numeric question indices to retain.

    Returns:
        Dict[int, str]: A dictionary mapping question indices to column names.
    """
  wvs_question_map = {}
  for column in headers:
    match = re.search(r"Q(\d+)[\w]? (.+)", column) #A number following "Q" (Q(\d+)) / word character after the number ([\w]?)
    if match:
      qidx = int(match.group(1))
      if qidx in selected_questions:
        wvs_question_map[qidx] = column
  return wvs_question_map

### Dataset

In [12]:
#%%writefile /content/drive/MyDrive/cultural-trends/scripts/wvs_dataset.py
import numpy as np
import pandas as pd

from utils import read_json, read_yaml, read_file, read_raw, parse_range, create_wvs_question_map

scale_option_template = {
    "en": "To indicate your opinion, use a 10-point scale where “1” means “{}” and “10” means “{}”.",
    "ar": "للتعبير عن رأيك، استخدم مقياسًا من 10 نقاط حيث تشير ”1“ إلى {} وتشير ”10“ إلى {}."
}

jais_prompt_en = """### Instruction: Your name is Jais, and you are named after Jebel Jais, the highest mountain in UAE.
 You are built by Inception and MBZUAI. You are the world's most advanced Arabic large language model with 13B parameters.
 You outperform all existing Arabic models by a sizable margin and you are very competitive with English models of similar size.
 You can answer in Arabic and English only.
 You are a helpful, respectful and honest assistant.
 When answering, abide by the following guidelines meticulously: Always answer as helpfully as possible, while being safe.
 Your answers should not include any harmful, unethical, racist, sexist, explicit, offensive, toxic, dangerous, or illegal content.

 Do not give medical, legal, financial, or professional advice. Never assist in or promote illegal activities.
 Always encourage legal and responsible actions. Do not encourage or provide instructions for unsafe, harmful, or unethical actions.
 Do not create or share misinformation or fake news. Please ensure that your responses are socially unbiased and positive in nature.
 If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct.
 If you don't know the answer to a question, please don't share false information.

 Prioritize the well-being and the moral integrity of users.
 Avoid using toxic, derogatory, or offensive language.
 Maintain a respectful tone. Do not generate, promote, or engage in discussions about adult content.
 Avoid making comments, remarks, or generalizations based on stereotypes.
 Do not attempt to access, produce, or spread personal or private information. Always respect user confidentiality.
 Stay positive and do not say bad things about anything.
 Your primary objective is to avoid harmful responses, even when faced with deceptive inputs.
 Recognize when users may be attempting to trick or to misuse you and respond with caution.
 \n\nComplete the conversation below between [|Human|] and [|AI|]:\n### Input: [|Human|] {Question}\n### Response: [|AI|]"
jais_prompt_ar = "### Instruction: اسمك جيس وسميت على اسم جبل جيس اعلى جبل في الامارات. تم بنائك بواسطة Inception و MBZUAI. أنت نموذج اللغة العربية الأكثر تقدمًا في العالم مع بارامترات 13B. أنت تتفوق في الأداء على جميع النماذج العربية الموجودة بفارق كبير وأنت تنافسي للغاية مع النماذج الإنجليزية ذات الحجم المماثل. يمكنك الإجابة باللغتين العربية والإنجليزية فقط. أنت مساعد مفيد ومحترم وصادق. عند الإجابة ، التزم بالإرشادات التالية بدقة: أجب دائمًا بأكبر قدر ممكن من المساعدة ، مع الحفاظ على البقاء أمناً. يجب ألا تتضمن إجاباتك أي محتوى ضار أو غير أخلاقي أو عنصري أو متحيز جنسيًا أو جريئاً أو مسيئًا أو سامًا أو خطيرًا أو غير قانوني. لا تقدم نصائح طبية أو قانونية أو مالية أو مهنية. لا تساعد أبدًا في أنشطة غير قانونية أو تروج لها. دائما تشجيع الإجراءات القانونية والمسؤولة. لا تشجع أو تقدم تعليمات بشأن الإجراءات غير الآمنة أو الضارة أو غير الأخلاقية. لا تنشئ أو تشارك معلومات مضللة أو أخبار كاذبة. يرجى التأكد من أن ردودك غير متحيزة اجتماعيًا وإيجابية بطبيعتها. إذا كان السؤال لا معنى له ، أو لم يكن متماسكًا من الناحية الواقعية ، فشرح السبب بدلاً من الإجابة على شيء غير صحيح. إذا كنت لا تعرف إجابة السؤال ، فالرجاء عدم مشاركة معلومات خاطئة. إعطاء الأولوية للرفاهية والنزاهة الأخلاقية للمستخدمين. تجنب استخدام لغة سامة أو مهينة أو مسيئة. حافظ على نبرة محترمة. لا تنشئ أو تروج أو تشارك في مناقشات حول محتوى للبالغين. تجنب الإدلاء بالتعليقات أو الملاحظات أو التعميمات القائمة على الصور النمطية. لا تحاول الوصول إلى معلومات شخصية أو خاصة أو إنتاجها أو نشرها. احترم دائما سرية المستخدم. كن إيجابيا ولا تقل أشياء سيئة عن أي شيء. هدفك الأساسي هو تجنب الاجابات المؤذية ، حتى عند مواجهة مدخلات خادعة. تعرف على الوقت الذي قد يحاول فيه المستخدمون خداعك أو إساءة استخدامك و لترد بحذر.\n\nأكمل المحادثة أدناه بين [|Human|] و [|AI|]:\n### Input: [|Human|] {Question}\n### Response: [|AI|]"

"""
class WVSDataset:
    def __init__(self, filepath,
            language="en",
            country="us",
            fewshot=0,
            api=False,
            model_name=None,
            use_anthro_prompt=False,
        ):

        self.dataset = {}
        self.persona_qid = {}
        self.question_info = {}
        self.responses = {}

        self.fewshot_dataset = {}
        self.fewshot_persona_qid = {}
        self.fewshot_question_info = {}
        self.fewshot_responses = {}

        self.persona = []
        self.raw_responses = []
        self.is_api = api
        self.language = language
        self.country = country
        self.is_jais = model_name=="jais-13b-chat" if model_name is not None else False
        self.fewshot = fewshot
        self.model_name = model_name
        self.use_anthro_prompt = use_anthro_prompt

        filter_questions = [qid.strip() for qid in read_raw("../dataset/selected_questions.csv").split(",")]

        wvs_questions_path = f"../dataset/wvs_questions_dump.{language}.json"
        self.wvs_questions = {q_id: q_val for q_id, q_val in read_json(wvs_questions_path).items() if q_id in filter_questions}

        self.anthro_templ = read_yaml("../dataset/wvs_template_anthro_framework.yml")["template_values"]

        template_data = read_yaml(filepath)

        self.create_dataset(template_data)
        self.set_question(index=2)

    def set_question(self, index):
        self.current_question_index = index

    def trim_dataset(self, start_index):
        qidx = f"Q{self.current_question_index}"
        self.dataset[qidx] = self.dataset[qidx][start_index:]

        self.persona_qid[qidx] = self.persona_qid[qidx][start_index:]
        self.question_info[qidx] = self.question_info[qidx][start_index:]

    @property
    def question_ids(self):
        return list(self.wvs_questions.keys())

    def create_dataset(self, template_data):
        if self.country == "egypt":
            path = "../dataset/eg_wvs_wave7_v7_n303.csv"
        elif self.country == "us":
            #path = "../dataset/F00013339-WVS_Wave_7_United_States_CsvTxt_v5.0.csv"
            path = "../dataset/us_wvs_wave7_v7_n303.csv"

        survey_df = pd.read_csv(path, header=0, delimiter=";")

        demographic_ids = ["N_REGION_WVS Region country specific", "Q260 Sex", "Q262 Age", "Q266 Country of birth: Respondent",
                           "Q273 Marital status", "Q275R Highest educational level: Respondent (recoded into 3 groups)", "Q287 Social class (subjective)"]
        demographic_txt = ["region", "sex", "age", "country",
                           "marital_status", "education", "social_class"]

        print(f"{len(survey_df)} Personas")

        template_0 = template_data["template"][0]
        template_1 = template_data["template"][1]

        template_parameters = template_data["template_values"]

        if self.language == "en" and self.use_anthro_prompt:
            prompt_template = self.anthro_templ["prompt"]
        # elif self.language == "ar" and self.model_name == "Llama-2-13b-chat-hf":
        #     prompt_template = template_parameters["prompt_variants"][2]
        elif self.language == "ar" and self.model_name == "AceGPT-13B-chat":
            prompt_template = template_parameters["prompt_variants"][2]
        elif self.country == "us" and self.language == "ar":
            prompt_template = template_parameters["prompt_variants"][1]
        else:
            prompt_template = template_parameters["prompt_variants"][0]

        question_header = template_parameters["question_header"]
        options_header = template_parameters["options_header"]

        ar_persona_parameters = template_data["persona_parameters"]

        country_cap = "US" if self.country == "us" else "Egypt"

        selected_questions = read_file("../dataset/selected_questions.csv")[0].split(",")
        selected_questions = list(map(str.strip, selected_questions))
        selected_questions = [int(qnum[1:]) for qnum in selected_questions]

        wvs_question_map = create_wvs_question_map(survey_df.columns.tolist(), selected_questions)

        wvs_response_map = read_json("../dataset/wvs_response_map.json")

        options_dict = parse_range(read_json("../dataset/wvs_options.json"))

        if self.language != "en":
            demographic_map = {}
            en_template_data = read_yaml(f"../dataset/wvs_template.en.yml")
            for d_text in demographic_txt:
                if d_text == "age": continue
                d_text_cap = ' '.join(list(map(str.capitalize, d_text.replace("_", " ").split())))
                if d_text == "region":
                    d_values = en_template_data["persona_parameters"][d_text_cap][country_cap]
                else:
                    d_values = en_template_data["persona_parameters"][d_text_cap]

                demographic_map[d_text] = {}
                for d_val_idx, d_val in enumerate(d_values):
                    if d_text == "region":
                        demographic_map[d_text][d_val] = ar_persona_parameters[d_text_cap][country_cap][d_val_idx]
                    else:
                        demographic_map[d_text][d_val] = ar_persona_parameters[d_text_cap][d_val_idx]

        if self.language == "en":
            for _, row in survey_df.iterrows():
                if self.country == "us" and row["Q266 Country of birth: Respondent"] != "United States":
                    continue
                prompt_values = {demographic_key: row[demographic_id]
                    if demographic_key in ["age", "region", "country"]
                    else row[demographic_id].lower()
                    for demographic_key, demographic_id in zip(demographic_txt, demographic_ids)
                }

                self.raw_responses += [{qidx: row[qkey] for qidx, qkey in wvs_question_map.items()}]
                self.persona += [prompt_values]
        else:
            start_region_idx = 3 if self.country == "us" else 0
            for _, row in survey_df.iterrows():
                if self.country == "us" and row["Q266 Country of birth: Respondent"] != "United States":
                    continue
                prompt_values = {demographic_key: demographic_map[demographic_key][row[demographic_id].split(":")[-1][start_region_idx:].strip() if demographic_key == "region" else row[demographic_id]]
                    if demographic_key != "age"
                    else row[demographic_id]
                    for demographic_key, demographic_id in zip(demographic_txt, demographic_ids)
                }
                self.raw_responses += [{qidx: row[qkey] for qidx, qkey in wvs_question_map.items()}]
                self.persona += [prompt_values]

        if self.language == "en":
            for prompt_values in self.persona:
                prompt_values["region"] = prompt_values["region"].split(":")[-1].strip()
                if self.country == "us":
                    prompt_values["region"] = prompt_values["region"][2:].strip()

        for qid, qdata in self.wvs_questions.items():
            self.dataset[qid] = []
            self.persona_qid[qid] = []
            self.question_info[qid] = []
            self.responses[qid] = []

            self.fewshot_dataset[qid] = []
            self.fewshot_persona_qid[qid] = []
            self.fewshot_question_info[qid] = []
            self.fewshot_responses[qid] = []

            question_options = qdata["options"]
            for persona_idx, prompt_values in enumerate(self.persona):
                for variant_idx, question in enumerate(qdata["questions"]):
                    if variant_idx > 0: continue
                    prompt = prompt_template.format(**prompt_values)

                    if "chat" in self.model_name:
                        prompt = "[INST] <<SYS>>\n" + prompt + "\n<</SYS>>\n"

                    if "scale" in qdata and qdata["scale"] == True:
                        final_question = template_1.format(**{
                            "prompt": prompt,
                            "question_header": question_header,
                            "question": question,
                            "scale": scale_option_template[self.language].format(question_options[0], question_options[1]),
                        })
                    else:

                        final_question = template_0.format(**{
                            "prompt": prompt,
                            "question": question,
                            "options": '\n'.join(f"({option_idx+1}) {option}" for option_idx, option in enumerate(question_options)),
                            "options_header": options_header,
                            "question_header": question_header,
                        })

                    if self.use_anthro_prompt:
                        final_question = self.anthro_templ["anthro_prompt"] + '\n\n' + final_question

                    if "chat" in self.model_name:
                        final_question += " [/INST]"

                    qid_int = int(qid[1:])
                    response = self.raw_responses[persona_idx][qid_int]
                    response_map = {key: int(val) for key, val in wvs_response_map[str(qid_int)].items()}
                    response_map |= {key: val+1 for val, key in enumerate(options_dict[qid_int])}
                    response_map["No answer"] = -1

                    if persona_idx >= len(self.persona)-self.fewshot:
                        self.fewshot_responses[qid] += [response_map[response]]
                        self.fewshot_dataset[qid] += [final_question]
                        self.fewshot_persona_qid[qid] += [prompt_values]
                        self.fewshot_question_info[qid] += [{
                            "id": qid,
                            "variant": variant_idx,
                        }]
                    else:
                        self.responses[qid] += [response_map[response]]
                        self.dataset[qid] += [final_question]
                        self.persona_qid[qid] += [prompt_values]
                        self.question_info[qid] += [{
                            "id": qid,
                            "variant": variant_idx,
                        }]

    def fewshot_examples(self):
        qidx = f"Q{self.current_question_index}"
        num_question_variants = 4

        variant_indices = np.random.choice(np.arange(num_question_variants), size=self.fewshot)

        fewshots = []
        responses = []
        for idx in range(self.fewshot):
            # fewshot_question_idx = index % num_question_variants + num_question_variants * idx
            fewshot_question_idx = variant_indices[idx] + num_question_variants * idx
            response = self.fewshot_responses[qidx][fewshot_question_idx]
            fewshots += [self.fewshot_dataset[qidx][fewshot_question_idx] + f'\nAnswer: {response}']
            responses += [response]

        return '\n\n'.join(fewshots) + '\n\n', responses

    def __getitem__(self, index):
        qidx = f"Q{self.current_question_index}"
        query = self.dataset[qidx][index]

        if not self.is_api:
            if not self.is_jais:
                return query + "\nAnswer:" if self.fewshot > 0 else query
            elif self.language == "ar":
                return jais_prompt_ar.format(Question=query)
            else:
                return jais_prompt_en.format(Question=query)

        persona = self.persona_qid[qidx][index]
        qinfo = self.question_info[qidx][index]
        payload = {"role": "user", "content": f"{query}"}
        return payload, persona, qinfo

    def __len__(self):
        return len(self.dataset[f"Q{self.current_question_index}"])

if __name__ == "__main__":
    language = "en"
    country = "us"
    # model_name = "meta-llama/Llama-2-13b-chat-hf"
    # model_name = "AceGPT-13B-chat"
    model_name = "bigscience/mt0-xxl"
    # model_name = 'gpt-3.5'
    model_name = model_name.split("/")[-1]

    filepath = f"../dataset/wvs_template.{language}.yml"
    dataset = WVSDataset(filepath,
        language=language,
        country=country,
        fewshot=0,
        model_name=model_name,
        use_anthro_prompt=False,
        api=False,
    )

    print(len(dataset.question_ids))
    dataset.set_question(index=2)
    print(dataset[0])

303 Personas
31
Imagine you are a married male from Illinois, United States. 
You are 46 years of age and completed higher education level.
You consider yourself part of the lower middle class.
Answer strictly from this persona’s point of view.
Select exactly one option. Do not include any commentary.
Return only the number of the chosen option. No words or punctuation.
Question: How important are friends in life?
Options:
(1) Very important
(2) Rather important
(3) Not very important
(4) Not at all important
(5) Don't know


### Querying model

#### Trials

In [None]:
import torch
torch.cuda.empty_cache()

In [None]:
from huggingface_hub import notebook_login
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [None]:
!pip install -U bitsandbytes


Collecting bitsandbytes
  Downloading bitsandbytes-0.47.0-py3-none-manylinux_2_24_x86_64.whl.metadata (11 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch<3,>=2.2->bitsandbytes)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch<3,>=2.2->bitsandbytes)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch<3,>=2.2->bitsandbytes)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch<3,>=2.2->bitsandbytes)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch<3,>=2.2->bitsandbytes)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-c

In [None]:
# baseline 0

In [None]:
!python wvs_query_hf.py \
    --qid -1 \
    --model meta-llama/Llama-2-13b-chat-hf \
    --lang en \
    --country us \
    --max-tokens 4 --temperature 0.7 --batch-size 1 --n-gen 5

303 Personas
Language=en | Temperature=0.7 | Tokens=4 | N=5 | Batch=1 | Version=1
> Device cuda:0
> Running 31 Qs
config.json: 100% 1.16k/1.16k [00:00<00:00, 8.20MB/s]
2025-08-11 18:44:28.627982: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1754937868.862995    3280 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1754937868.926143    3280 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1754937869.409053    3280 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1754937869.409092    3280 computation_placer.cc:177] computation placer already re

In [None]:
#1

In [None]:
!python wvs_query_hf.py \
    --qid 2 \
    --model meta-llama/Llama-2-13b-chat-hf \
    --lang en \
    --country us \
    --max-tokens 16 --temperature 0.7 --batch-size 1 --n-gen 5

303 Personas
Language=en | Temperature=0.7 | Tokens=16 | N=5 | Batch=1 | Version=1
> Device cuda:0
> Running 1 Qs
2025-08-06 22:43:42.960120: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1754520222.987048   32783 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1754520222.995407   32783 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1754520223.029148   32783 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1754520223.029180   32783 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the s

In [None]:
# We change prompt and temperature

In [None]:
!python wvs_query_hf.py \
  --qid 19 \
  --model meta-llama/Llama-2-13b-chat-hf \
  --lang en \
  --country us \
  --max-tokens 4 --temperature 0.1 --batch-size 1 --n-gen 1


303 Personas
Language=en | Temperature=0.1 | Tokens=4 | N=1 | Batch=1 | Version=1
> Device cuda:0
> Running 1 Qs
2025-08-13 04:01:14.188990: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1755057674.336989    9151 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1755057674.345977    9151 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1755057674.379420    9151 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1755057674.379456    9151 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the sa

In [None]:
# Update temperature and sampling settings

In [None]:
!python wvs_query_hf.py \
    --qid 19 \
    --model meta-llama/Llama-2-13b-chat-hf \
    --lang en \
    --country us \
    --max-tokens 4 --temperature 0.0 --batch-size 1 --n-gen 1

#### Inference code

In [None]:
# %load /content/drive/MyDrive/cultural-trends/scripts/wvs_query_hf.py
%%writefile /content/drive/MyDrive/cultural-trends/scripts/wvs_query_hf.py

import os
import time
import torch
import argparse
import numpy as np

from tqdm import tqdm
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM
from transformers import BitsAndBytesConfig, AutoConfig

from wvs_dataset import WVSDataset
from utils import read_json, write_json

from transformers.utils import logging
logging.set_verbosity(50)

os.environ["TOKENIZERS_PARALLELISM"] = "false"


OFFLOAD_DIR = "/content/offload"
os.makedirs(OFFLOAD_DIR, exist_ok=True)

#bnb_config = BitsAndBytesConfig(
#   load_in_8bit=True,
#    llm_int8_enable_fp32_cpu_offload=True
#)


def generate(model, tokenizer, fewshot_cache, prompts, device, n_steps=20):
    # generation cycle with 20 steps
    step = 0
    past_key_values = fewshot_cache
    tokens = tokenizer(prompts, padding=True, return_tensors="pt").to(device)
    input_ids = tokens["input_ids"]
    output = None
    while step < n_steps:
        attention_mask = input_ids.new_ones(input_ids.shape)

        if output is not None:
            past_key_values = output["past_key_values"]

        ids = model.prepare_inputs_for_generation(input_ids,
                                                past=past_key_values,
                                                attention_mask=attention_mask,
                                                use_cache=True)

        output = model(**ids)

        # next_token = random.choice(torch.topk(output.logits[:, -1, :], top_k, dim=-1).indices[0])
        next_token = output.logits[:, -1, :].argmax(dim=-1)

        input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=-1)

        step += 1

    return input_ids


def query_hf(
    qid: str,
    *,
    model_name: str = 'bigscience/mt0-small',
    version: int = 1,
    lang: str = 'en',
    max_tokens: int = 8,
    temperature: float = 0.7,
    n_gen: int = 5,
    batch_size: int = 4,
    fewshot: int = 0,
    cuda: int = 0,
    greedy: bool = False,
    generator = None,
    tokenizer = None,
    no_persona = False,
    subset = None,
    country: str = "egypt",
):

    model_name_ = model_name.split("/")[-1]
    savedir = f"../results_wvs_2/{model_name_}/{lang}"
    if not os.path.isdir(savedir):
        os.makedirs(savedir)

    filepath = f"../dataset/wvs_template.{lang}.yml"

    dataset = WVSDataset(filepath,
        language=lang,
        country=country,
        api=False,
        model_name=model_name_,
        use_anthro_prompt=False
    )

    device = torch.device(f'cuda:{cuda}' if torch.cuda.is_available() else 'cpu')
    print(f"Language={lang} | Temperature={temperature} | Tokens={max_tokens} | N={n_gen} | Batch={batch_size} | Version={version}")
    print(f"> Device {device}")

    if qid <= 0:
        question_ids = dataset.question_ids
    else:
        question_ids = [f"Q{qid}"]

    print(f"> Running {len(question_ids)} Qs")
    model_path = model_name
    if "mt0" in model_name_:
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name, device_map="cpu", torch_dtype=torch.float16).to(device)
    else:
        # model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, trust_remote_code=True).to(device)
        if "AceGPT" in model_name:
            model_path = "/mnt/u14157_ic_nlp_001_files_nfs/nlpdata1/home/bkhmsi/models/models--FreedomIntelligence--AceGPT-13B-chat/snapshots/ab87ccbc2c4a05969957755aaabc04400bb20052"
        elif "Llama" in model_name:
            model_path = "eri00eli/llama2-13b-8bit"
            model = AutoModelForCausalLM.from_pretrained(
                model_path,
                #quantization_config=bnb_config,
                device_map="auto",
                low_cpu_mem_usage=True,
                #offload_folder=OFFLOAD_DIR,
                trust_remote_code=True,
                #local_files_only=True,
            )
            #model.eval()





       # model = AutoModelForCausalLM.from_pretrained(model_path, device_map="cpu", torch_dtype=torch.float16).to(device)

    tokenizer = AutoTokenizer.from_pretrained(model_path,
                                              trust_remote_code=True
                                              #local_files_only=True
                                              )

    if"Llama-2-13b-chat-hf" in model_name or "AceGPT-13B-chat" in model_name:
        print("> Changing padding side")
        tokenizer.padding_side = "left"

    if model_name == "gpt2" or "Sheared-LLaMA-1.3B" in model_name or "Llama-2-13b" in model_name or "AceGPT-13B-chat" in model_name:
        tokenizer.pad_token = tokenizer.eos_token

    for qid in question_ids:
        qid = int(qid[1:])
        dataset.set_question(index=qid)

        filesuffix = f"q={str(qid).zfill(2)}_lang={lang}_country={country}_temp={temperature}_maxt={max_tokens}_n={n_gen}_v{version}_fewshot={fewshot}"
        print(filesuffix)

        preds_path = os.path.join(savedir, f"preds_{filesuffix}.json")

        completions = []
        if os.path.exists(preds_path):
            completions = read_json(preds_path)

        if len(completions) >= len(dataset):
            print(f"Skipping Q{qid}")
            continue

        if len(completions) > 0:
            print(f"> Trimming Dataset from {len(completions)}")
            dataset.trim_dataset(len(completions))

        dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=2, shuffle=False)

        if fewshot > 0:

            fewshot_examples, _ = dataset.fewshot_examples()
            fewshot_tokens = tokenizer(fewshot_examples, padding=True, return_tensors="pt").to(device)
            with torch.no_grad():
                fewshot_cache = model(**fewshot_tokens, use_cache=True)["past_key_values"]

        index = 0
        print(f"> Prompting {model_name} with Q{qid}")
        for batch_idx, prompts in tqdm(enumerate(dataloader), total=len(dataloader)):

            if fewshot == 0:
                tokens = tokenizer(prompts, padding=True, return_tensors="pt").to(device)

                gen_outputs = model.generate(**tokens,
                    temperature=0.0,
                    do_sample=False,
                    num_return_sequences=n_gen,
                    max_new_tokens=max_tokens,
                )
                decoded_output = tokenizer.batch_decode(gen_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True)

                for b_i in range(0, len(decoded_output), n_gen):
                    preds = decoded_output[b_i:b_i+n_gen]
                    preds = [pred.replace(prompts[b_i//n_gen], "") for pred in preds]
                    persona = dataset.persona_qid[f"Q{qid}"][index]
                    q_info = dataset.question_info[f"Q{qid}"][index]
                    index += 1
                    completions += [{
                        "persona": persona,
                        "question": q_info,
                        "response": preds,
                    }]

            else:

                # prompts_with_fewshot = [fewshot_examples + prompt for prompt in prompts]
                # tokens_with_fewshot = tokenizer(prompts_with_fewshot, padding=True, return_tensors="pt").to(device)

                # start_time = time.time()
                # gen_outputs_wo_cache = model.generate(**tokens_with_fewshot,
                #     temperature=temperature,
                #     do_sample=(not greedy),
                #     num_return_sequences=n_gen,
                #     max_new_tokens=max_tokens,
                # )

                # decoded_output = tokenizer.batch_decode(gen_outputs_wo_cache, skip_special_tokens=True, clean_up_tokenization_spaces=True)

                # tokens = tokenizer(prompts, padding=True, return_tensors="pt").to(device)
                # tokens["input_ids"] = tokens["input_ids"][:, 1:]
                # tokens["attention_mask"] = tokens["attention_mask"][:, 1:]

                # with torch.no_grad():
                #     prompt_cache = model(**tokens_with_fewshot, past_key_values=fewshot_cache, use_cache=True)["past_key_values"]

                # tokens_with_fewshot_concat = {
                #     "input_ids": torch.cat([fewshot_tokens["input_ids"].repeat(batch_size, 1), tokens["input_ids"][:, 1:]], dim=1),
                #     "attention_mask": torch.cat([fewshot_tokens["attention_mask"].repeat(batch_size, 1), tokens["attention_mask"][:, 1:]],dim=1),
                #     # "attention_mask": torch.cat([fewshot_tokens["attention_mask"].repeat(batch_size, 1), tokens["attention_mask"][:, 1:], torch.ones(batch_size, 1).to(device)],dim=1),
                # }

                # tokens_with_fewshot_concat = {
                #     "input_ids": torch.cat([fewshot_tokens["input_ids"].repeat(batch_size, 1), tokens["input_ids"]], dim=1),
                #     "attention_mask": torch.cat([fewshot_tokens["attention_mask"].repeat(batch_size, 1), tokens["attention_mask"]],dim=1),
                #     # "attention_mask": torch.cat([fewshot_tokens["attention_mask"].repeat(batch_size, 1), tokens["attention_mask"], torch.zeros(batch_size, 1).to(device)],dim=1),
                # }

                # num_layers = len(fewshot_cache)
                # all_cache = []
                # for layer_idx in range(num_layers):
                #     all_cache += [(
                #         torch.cat([fewshot_cache[layer_idx][0].repeat(batch_size*n_gen,1,1,1), prompt_cache[layer_idx][0].repeat(n_gen,1,1,1)[:,:,:-1,:]], dim=2),
                #         torch.cat([fewshot_cache[layer_idx][1].repeat(batch_size*n_gen,1,1,1), prompt_cache[layer_idx][1].repeat(n_gen,1,1,1)[:,:,:-1,:]], dim=2),
                #     )]
                    # all_cache += [(
                    #     torch.cat([fewshot_cache[layer_idx][0].repeat(batch_size*n_gen,1,1,1), prompt_cache[layer_idx][0].repeat(n_gen,1,1,1)], dim=2),
                    #     torch.cat([fewshot_cache[layer_idx][1].repeat(batch_size*n_gen,1,1,1), prompt_cache[layer_idx][1].repeat(n_gen,1,1,1)], dim=2),
                    # )]

                # print("Fewshot Tokens + Prompt Tokens: ", fewshot_tokens["input_ids"].size(1) + tokens["input_ids"].size(1))
                # print("[Fewshot, Prompt] Tokens: ", tokens_with_fewshot_concat["input_ids"].size(1))
                # # print("(Fewshot + Prompt) Tokens: ", tokens_with_fewshot["input_ids"].size(1))
                # print("Cache Concat: ", all_cache[0][0].size())
                # print("[Fewshot, Prompt] Attention Mask: ", tokens_with_fewshot_concat["attention_mask"].size(1))
                # breakpoint()
                # # del prompt_cache

                # tokens_with_fewshot["attention_mask"] = torch.cat([tokens_with_fewshot["attention_mask"], torch.ones(batch_size,1).to(device)], dim=1)

                # # start_time = time.time()
                # gen_outputs = model.generate(**tokens_with_fewshot_concat,
                #     # input_ids=tokens_with_fewshot_concat["input_ids"],
                #     temperature=temperature,
                #     do_sample=(not greedy),
                #     num_return_sequences=n_gen,
                #     max_new_tokens=max_tokens,
                #     past_key_values=tuple(all_cache)
                # )

                gen_outputs = generate(model, tokenizer, fewshot_cache, prompts, device, n_steps=max_tokens)

                # gen_outputs = model.generate(**tokens_with_fewshot_concat,
                #     temperature=temperature,
                #     do_sample=(not greedy),
                #     num_return_sequences=n_gen,
                #     max_new_tokens=max_tokens,
                #     past_key_values=tuple(all_cache)
                # )

                decoded_output = tokenizer.batch_decode(gen_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True)

                for b_i in range(0, len(decoded_output), n_gen):
                    preds = decoded_output[b_i:b_i+n_gen]
                    preds = [pred.replace(fewshot_examples, "").replace(prompts[b_i//n_gen], "") for pred in preds]
                    persona = dataset.persona_qid[f"Q{qid}"][index]
                    q_info = dataset.question_info[f"Q{qid}"][index]
                    index += 1
                    completions += [{
                        "persona": persona,
                        "question": q_info,
                        "response": preds,
                    }]

            write_json(preds_path, completions)

if __name__ == "__main__":

    # python wvs_query_hf_generate.py --qid -1 --model FreedomIntelligence/AceGPT-13B-chat --lang en --fewshot 3 --cuda 0
    # python wvs_query_hf_generate.py --qid -1 --model princeton-nlp/Sheared-LLaMA-1.3B --max-tokens 10 --lang en --fewshot 3 --cuda 0
    # python wvs_query_hf_generate.py --qid -1 --model princeton-nlp/Sheared-LLaMA-1.3B --max-tokens 5 --lang ar --fewshot 3 --cuda 0 --n-gen 1
    # python wvs_query_hf_generate.py --qid -1 --model meta-llama/Llama-2-13b-chat-hf --max-tokens 16 --lang en --fewshot 0 --cuda 0 --n-gen 5 --batch-size 4

    # python wvs_query_hf_generate.py --qid -1 --model meta-llama/Llama-2-13b-chat-hf --max-tokens 32 --lang ar --fewshot 0 --cuda 1 --n-gen 5 --batch-size 4
    # python wvs_query_hf_generate.py --qid -1 --model FreedomIntelligence/AceGPT-13B-chat --max-tokens 5 --lang ar --fewshot 0 --cuda 0 --n-gen 5 --batch-size 4
    # python wvs_query_hf_generate.py --qid -1 --model FreedomIntelligence/AceGPT-13B-chat --max-tokens 5 --lang en --fewshot 0 --cuda 0 --n-gen 5 --batch-size 4
    # python wvs_query_hf_generate.py --qid -1 --model bigscience/mt0-xxl --max-tokens 5 --lang en --fewshot 0 --cuda 1 --n-gen 5 --batch-size 4 --country us

    # cp -r /home/bkhmsi/.cache/huggingface/hub/models--FreedomIntelligence--AceGPT-13B-chat /mnt/u14157_ic_nlp_001_files_nfs/nlpdata1/home/bkhmsi/models
    parser = argparse.ArgumentParser()

    parser.add_argument('--qid', required=True, type=int, help='question index')
    parser.add_argument('--model', default="bigscience/mt0-small", help='model to use')
    parser.add_argument('--version', default=1, help='dataset version number')
    parser.add_argument('--lang', default="en", help='language')
    parser.add_argument('--max-tokens', default=4, type=int, help='maximum number of output tokens')
    parser.add_argument('--temperature', default=0.7, type=float, help='temperature')
    parser.add_argument('--n-gen', default=5, type=int, help='number of generations')
    parser.add_argument('--batch-size', default=4, type=int, help='batch size')
    parser.add_argument('--fewshot', default=0, type=int, help='fewshot examples')
    parser.add_argument('--cuda', default=0, type=int, help='cuda device number')
    parser.add_argument('--greedy', action="store_true", help='greedy decoding')
    parser.add_argument('--country', type=str, help='country')

    args = parser.parse_args()

    if args.greedy:
        args.n_gen = 1
        args.temperature = 1.0

    qid = int(args.qid)

    query_hf(
        qid=qid,
        model_name=args.model,
        version=args.version,
        lang=args.lang,
        max_tokens=args.max_tokens,
        temperature=float(args.temperature),
        n_gen=int(args.n_gen),
        batch_size=int(args.batch_size),
        fewshot=int(args.fewshot),
        cuda=args.cuda,
        greedy=args.greedy,
        country=args.country,
    )

Overwriting /content/drive/MyDrive/cultural-trends/scripts/wvs_query_hf.py


### Majority vote

In [None]:
# %load /content/drive/MyDrive/cultural-trends/scripts/wvs_majority_vote.py
#%%writefile /content/drive/MyDrive/cultural-trends/scripts/wvs_majority_vote.py
import os
import re
import  numpy as mp
import pandas as pd
from glob import glob
from tqdm import tqdm

from utils import (
    read_json,
    read_file,
    read_yaml,
    parse_range,
    convert_to_dataframe,
    create_wvs_question_map
)

# Column names in the survey CSV and human labels
demographic_ids = [
    "N_REGION_WVS Region country specific",
    "Q260 Sex",
    "Q262 Age",
    "Q266 Country of birth: Respondent",
    "Q273 Marital status",
    "Q275R Highest educational level: Respondent (recoded into 3 groups)",
    "Q287 Social class (subjective)"
]

demographic_txt = [
    "region",
    "sex",
    "age",
    "country",
    "marital_status",
    "education",
    "social_class"
]

if __name__ == "__main__":
  # Config
  COUNTRY = "us"
  LANGS = ["en"]
  MODELS = ['Llama-2-13b-chat-hf']
  EVAL_METHOD = "mv_sample" # {"flatten", "mv_sample", "mv_all", "first"}
  SCALE_QS    = False              # include scale-questions?

  # Load selected questions
  raw = read_file("../dataset/selected_questions.csv")[0]

  # Extract integers from "Q1", "Q2"
  selected_questions = [int(q.strip()[1:]) for q in raw.split(",")]

  # (unused) to collect any fully skipped Qs
  invalid_question_indices = []

  for LANG in LANGS:
      for MODEL in MODELS:
          print(f"######### {MODEL} #########\n")

          country_cap = "US" if COUNTRY == "us" else "Egypt"

          # Build a map from Arabic labels → English if running Arabic prompts
          demographic_map = {}
          if LANG != "en":
              print("> Building Demographic Map")
              ar_persona_parameters = read_yaml("../dataset/wvs_template.ar.yml")["persona_parameters"]
              en_template_data      = read_yaml("../dataset/wvs_template.en.yml")
              for d_text in demographic_txt:
                  if d_text == "age":
                      continue  # numeric, no mapping needed
                  # create capitalized label, e.g. "Marital Status"
                  d_text_cap = ' '.join(map(str.capitalize, d_text.replace("_"," ").split()))
                  # pick the right list of values for that demographic
                  if d_text == "region":
                      d_values = en_template_data["persona_parameters"][d_text_cap][country_cap]
                  else:
                      d_values = en_template_data["persona_parameters"][d_text_cap]

                  demographic_map[d_text] = {}
                  # zip Arabic list to English list, index-wise
                  for idx, eng_val in enumerate(d_values):
                      if d_text == "region":
                          ar_val = ar_persona_parameters[d_text_cap][country_cap][idx]
                      else:
                          ar_val = ar_persona_parameters[d_text_cap][idx]
                      demographic_map[d_text][ar_val] = eng_val

          # Load and clean the ground truth WVS survey CSV for the chosen country
          if COUNTRY == "egypt":
            path ="../dataset/eg_wvs_wave7_v7_n303.csv"
          else:
            path ="../dataset/us_wvs_wave7_v7_n303.csv"
          print(" Reading survey data from:", path)

          # Load survey data using first row as header and semicolon as delimiter
          survey_df = pd.read_csv(path, header=0, delimiter=";")
          print(" Survey shape:", survey_df.shape)
          print(list(survey_df.columns))

          # Clean up the "region" column to remove prefixes like "US:"
          region_col = demographic_ids[demographic_txt.index("region")]
          #print(region_col)
          if COUNTRY == "egypt":
            survey_df[region_col] = survey_df[region_col] = survey_df[region_col].str.replace("EG: ","")
          else:
            survey_df[region_col] = survey_df[region_col].apply(lambda x: x.split(":")[-1][3:].strip())
          #print("Region sample", survey_df[region_col].unique()[:5])

          # Map questions {2: 'Q2 Important in life', ..}
          wvs_question_map = create_wvs_question_map(survey_df.columns.tolist(), selected_questions)
          print(" Mapped questions:", wvs_question_map, "\n")

          # Load allowed answers
          wvs_response_map = read_json("../dataset/wvs_response_map.json")

          # Determine where model outputs JSONs lives
          if MODEL == "gpt-3.5-turbo-1106":
            dirpath = f"../results_wvs_2_gpt/{MODEL}/{LANG}"
          else:
            dirpath = f"../results_wvs_2/{MODEL}/{LANG}"

          options_dict = parse_range(read_json("../dataset/wvs_options.json"))
          wvs_themes = parse_range(read_json("../dataset/wvs_themes.json"))
          wvs_scale_questions = parse_range(read_json("../dataset/wvs_scale_questions.json"))

          wvs_questions = read_json(f"../dataset/wvs_questions_dump.{LANG}.json")

          version_num = 3 if LANG == "en" and COUNTRY == "egypt" and MODEL in {"gpt-3.5-turbo-0613", "mt0-xxl"} else 1


           # if COUNTRY == "us":
            #     if "Llama-2" in MODEL:#å or "AceGPT" in MODEL:
            #         filepaths = sorted(glob(os.path.join(dirpath, f"*_country={COUNTRY}_*_maxt=32_n=5_v{version_num}_fewshot=0.json")))
            #     else:
            #         filepaths = sorted(glob(os.path.join(dirpath, f"*_country={COUNTRY}_*_v{version_num}*.json")))
            # else:
            #     if "Llama-2" in MODEL or "AceGPT" in MODEL:
            #         if LANG == "en" and 'Llama-2' in MODEL:
            #             filepaths = sorted(glob(os.path.join(dirpath, f"*maxt=32_n=5_v{version_num}_fewshot=0.json")))
            #         else:
            #             filepaths = sorted(glob(os.path.join(dirpath, f"*_v{version_num}_fewshot=0.json")))
            #     else:
            #         filepaths = sorted(glob(os.path.join(dirpath, f"*_v{version_num}.json")))

                # ar_filepaths = sorted(glob(os.path.join(f"../results_wvs/{MODEL}/ar", "*_v1.json")))


          # Filtering outputs models files
          if LANG == "ar" and MODEL == "AceGPT-13B-chat":
            filepaths = sorted(glob(f"{dirpath}/*_country={COUNTRY}_*_v2_*"))
          else:
            filepaths = sorted(glob(f"{dirpath}/*_country={COUNTRY}_*"))

          print("> Matching files:", filepaths)
          results = {demographic: [] for demographic in demographic_txt}
          print(f"{len(filepaths)}  files")

          all_invalids, all_num_responses  = 0, 0
          for filepath in tqdm(filepaths):
            if "country = us" in filepath and COUNTRY == "egypt":
              continue

            if "country = egypt" in filepath and COUNTRY == "us":
              continue

            if "anthro = True" in filepath:
              continue

            pattern = r'=(\d+(\.\d+)?)'
            matches = re.findall(pattern, filepath)
            values = [match[0] for match in matches]
            qidx = int(values[0])

            save_dump_path = f"../dumps_wvs_2/q={qidx}_country={COUNTRY}_lang={LANG}_model={MODEL}_eval={EVAL_METHOD}.csv"

            # if qidx == 77 and COUNTRY == "egypt":
                #     print(f"> Skipping Q77")
                #     continue

            dump_dir = os.path.dirname(save_dump_path)
            os.makedirs(dump_dir, exist_ok=True)

            if qidx not in wvs_question_map:
              print(f" Skipping Q{values[0]}")
              continue

            if not SCALE_QS and qidx in wvs_scale_questions:
              print(f" Skipping Q{values[0]} (Scale Q)")
              continue

            if SCALE_QS and qidx not in wvs_scale_questions:
              print(f" Skipping Q{values[0]} (Not Scale Q)")
              continue

            if str(qidx) not in wvs_response_map or f"Q{qidx}" not in wvs_questions:
              print(f" Skipping Q{values[0]}")
              continue

            # From dictionary getting answer choices from given question
            question_options = list(map(str.lower, wvs_questions[f"Q{qidx}"]["options"]))
            print(" Options",  question_options)

            # Load model responses (persona, question, response)
            model_data = read_json(filepath)
            print(" Load raw responses", len(model_data))

            if len(model_data) != 4800 and COUNTRY == "egypt" and not ("AceGPT" in MODEL or 'Llama-2' in MODEL):
                  filepath = filepath.replace("_v3", "_v1")
                  filepath = filepath.replace("_maxt=16", "_maxt=8")
                  model_data = read_json(filepath)

              # try:
              #     if "AceGPT" in MODEL or 'Llama-2' in MODEL:
              #         assert len(model_data) >= 1212
              #     elif COUNTRY == "egypt":
              #         assert len(model_data) == 4800
              #     elif COUNTRY == "us":
              #         assert len(model_data) == 1200
              # except:
              #     print(len(model_data))

            model_df, invalid_count = convert_to_dataframe(
                  model_data,
                  question_options,
                  demographic_map,
                  eval_method=EVAL_METHOD,
                  language=LANG
              )


            all_invalids += invalid_count
            all_num_responses += len(model_data)

            model_df.to_csv(save_dump_path)

          print(f"{MODEL} | {LANG}: {all_invalids}/{all_num_responses}")


######### Llama-2-13b-chat-hf #########

 Reading survey data from: ../dataset/us_wvs_wave7_v7_n303.csv
 Survey shape: (303, 395)
['Unnamed: 0', 'version Version of Data File', 'doi Digital Object Identifier', 'A_YEAR Year of survey', 'B_COUNTRY ISO 3166-1 numeric country code', 'B_COUNTRY_ALPHA ISO 3166-1 alpha-3 country code', 'C_COW_NUM CoW country code numeric', 'C_COW_ALPHA CoW country code alpha', 'D_INTERVIEW Interview ID', 'J_INTDATE Date of interview', 'FW_START Year/month of start-fieldwork', 'FW_END Year/month of end-fieldwork', 'K_TIME_START Start time of the interview [HH.MM]', 'K_TIME_END End time of the interview [HH.MM]', 'K_DURATION Total length of interview [minutes]', 'Q_MODE Mode of data collection', 'N_REGION_ISO Region ISO 3166-2', 'N_REGION_WVS Region country specific', 'N_TOWN Settlement name', 'G_TOWNSIZE Settlement size_8 groups', 'G_TOWNSIZE2 Settlement size_5 groups', 'H_SETTLEMENT Settlement type', 'H_URBRURAL Urban-Rural', 'O1_LONGITUDE Geographical Coordi

  0%|          | 0/31 [00:00<?, ?it/s]

 Options ['very important', 'rather important', 'not very important', 'not at all important', "don't know"]


  3%|▎         | 1/31 [00:00<00:24,  1.23it/s]

 Load raw responses 275
 Options ['agree', 'hard to say', 'disagree']


  6%|▋         | 2/31 [00:01<00:27,  1.05it/s]

 Load raw responses 275
 Options ['agree', 'hard to say', 'disagree']


 10%|▉         | 3/31 [00:02<00:24,  1.15it/s]

 Load raw responses 275
 Options ['agree', 'hard to say', 'disagree']


 13%|█▎        | 4/31 [00:03<00:22,  1.20it/s]

 Load raw responses 275
 Options ['very much', 'a good deal', 'not much', 'not at all']


 16%|█▌        | 5/31 [00:04<00:21,  1.23it/s]

 Load raw responses 275
 Options ['very much', 'a good deal', 'not much', 'not at all']


 19%|█▉        | 6/31 [00:05<00:20,  1.22it/s]

 Load raw responses 275
 Options ['freedom', 'equality']


 23%|██▎       | 7/31 [00:05<00:19,  1.22it/s]

 Load raw responses 275
 Options ['freedom', 'security']


 26%|██▌       | 8/31 [00:06<00:19,  1.19it/s]

 Load raw responses 275
 Options ['more than once a week', 'once a week', 'once a month', 'only on special holy days', 'once a year', 'less often', 'never, practically never']


 29%|██▉       | 9/31 [00:07<00:18,  1.21it/s]

 Load raw responses 275
 Options ['to make sense of life after death', 'to make sense of life in this world']


 32%|███▏      | 10/31 [00:08<00:17,  1.22it/s]

 Load raw responses 275
 Options ['very interested', 'somewhat interested', 'not very interested', 'not at all interested']


 35%|███▌      | 11/31 [00:09<00:16,  1.23it/s]

 Load raw responses 275
 Options ['important', 'not mentioned', "don't know"]


 39%|███▊      | 12/31 [00:19<01:09,  3.67s/it]

 Load raw responses 275
 Options ['have done', 'might do', 'would never do']


 42%|████▏     | 13/31 [00:20<00:50,  2.81s/it]

 Load raw responses 275
 Options ['have done', 'might do', 'would never do']


 45%|████▌     | 14/31 [00:20<00:37,  2.19s/it]

 Load raw responses 275
 Options ['important', 'not mentioned', "don't know"]


 48%|████▊     | 15/31 [00:21<00:27,  1.72s/it]

 Load raw responses 275
 Options ['always', 'usually', 'never', 'not allowed to vote']


 52%|█████▏    | 16/31 [00:22<00:21,  1.45s/it]

 Load raw responses 275
 Options ['very often', 'fairly often', 'not often', 'not at all often']


 55%|█████▍    | 17/31 [00:23<00:17,  1.26s/it]

 Load raw responses 275
 Options ['very often', 'fairly often', 'not often', 'not at all often']


 58%|█████▊    | 18/31 [00:23<00:14,  1.12s/it]

 Load raw responses 275
 Options ['very important', 'rather important', 'not very important', 'not at all important']


 61%|██████▏   | 19/31 [00:24<00:12,  1.01s/it]

 Load raw responses 275
 Options ['very good', 'fairly good', 'fairly bad', 'very bad']


 65%|██████▍   | 20/31 [00:25<00:10,  1.07it/s]

 Load raw responses 275
 Options ['very good', 'fairly good', 'fairly bad', 'very bad']


 68%|██████▊   | 21/31 [00:26<00:08,  1.12it/s]

 Load raw responses 275
 Options ['very good', 'fairly good', 'fairly bad', 'very bad']


 71%|███████   | 22/31 [00:27<00:07,  1.19it/s]

 Load raw responses 275
 Options ['the entire way our society is organized must be radically changed by revolutionary action', 'our society must be gradually improved by reforms', 'our present society must be valiantly defended against all subversive forces', "don't know"]


 74%|███████▍  | 23/31 [00:27<00:06,  1.26it/s]

 Load raw responses 275
 Options ['trust completely', 'trust somewhat', 'do not trust very much', 'do not trust at all', "don't know"]


 77%|███████▋  | 24/31 [00:29<00:08,  1.22s/it]

 Load raw responses 275
 Options ['trust completely', 'trust somewhat', 'do not trust very much', 'do not trust at all', "don't know"]


 81%|████████  | 25/31 [00:30<00:06,  1.05s/it]

 Load raw responses 275
 Options ['a great deal', 'quite a lot', 'not very much', 'none at all', "don't know"]


 84%|████████▍ | 26/31 [00:31<00:05,  1.01s/it]

 Load raw responses 275
 Options ['a great deal', 'quite a lot', 'not very much', 'none at all', "don't know"]


 87%|████████▋ | 27/31 [00:32<00:03,  1.06it/s]

 Load raw responses 275
 Options ['a great deal', 'quite a lot', 'not very much', 'none at all', "don't know"]


 90%|█████████ | 28/31 [00:33<00:02,  1.05it/s]

 Load raw responses 275
 Options ['a great deal', 'quite a lot', 'not very much', 'none at all', "don't know"]


 94%|█████████▎| 29/31 [00:34<00:01,  1.06it/s]

 Load raw responses 275
 Options ['a great deal', 'quite a lot', 'not very much', 'none at all', "don't know"]


 97%|█████████▋| 30/31 [00:34<00:00,  1.11it/s]

 Load raw responses 275
 Options ['a great deal', 'quite a lot', 'not very much', 'none at all', "don't know"]


100%|██████████| 31/31 [00:35<00:00,  1.15s/it]

 Load raw responses 275
Llama-2-13b-chat-hf | en: 2480/8525





### Compute alignment

In [14]:
# %load /content/drive/MyDrive/cultural-trends/scripts/wvs_majority_vote.py
#%%writefile /content/drive/MyDrive/cultural-trends/scripts/wvs_majority_vote.py

import os
import re
import scipy
import string
import numpy as np
import pandas as pd
from tqdm import tqdm
from glob import glob
from utils import read_json, write_json, read_file, read_yaml, parse_range, parse_response_wvs, convert_to_percentages

from utils import kl_divergence, create_wvs_question_map

demographic_ids = ["N_REGION_WVS Region country specific", "Q260 Sex", "Q262 Age", "Q273 Marital status", "Q275R Highest educational level: Respondent (recoded into 3 groups)", "Q287 Social class (subjective)"]
demographic_txt = ["region", "sex", "age", "marital_status", "education", "social_class"]

# demographic_ids = ["Q260 Sex", "Q262 Age", "Q273 Marital status", "Q275R Highest educational level: Respondent (recoded into 3 groups)", "Q287 Social class (subjective)"]
# demographic_txt = ["sex", "age", "marital_status", "education", "social_class"]

columns_by = ['persona.region', 'persona.sex', 'persona.age', 'persona.marital_status', 'persona.education', 'persona.social_class', 'question.id', 'question.variant']

# columns_by = ['persona.sex', 'persona.age', 'persona.marital_status', 'persona.education', 'persona.social_class', 'question.id', 'question.variant']
region_id = "N_REGION_WVS Region country specific"

# not_scale_questions = ["Q19", "Q21", "Q149", "Q150", "Q171", "Q175", "Q209"*, "Q210"*, "Q221"*]
not_scale_questions = ["Q19", "Q21", "Q149", "Q150", "Q171", "Q175", "Q209", "Q210", "Q221"]

from functools import reduce
from typing import Union

def dataframe_intersection(
    dataframes: list[pd.DataFrame], by: Union[list, str]
):

    # set_index = [d.set_index(by) for d in dataframes]
    # index_intersection = reduce(pd.Index.intersection, [d.index for d in set_index])
    # intersected = [df.loc[index_intersection].reset_index() for df in set_index]

    visited_personas = []
    for df in dataframes:
        visited = set()
        df.set_index(by, inplace=True)
        for index_tuple in df.index:
            visited.add(index_tuple)
        visited_personas += [visited]

    intersected = list(set.intersection(*visited_personas))

    i_dataframes = []
    for df in dataframes:
        i_dataframes += [df.loc[intersected].reset_index()]

    return i_dataframes, intersected

if __name__ == "__main__":

    LANGS = ["en"] # ar
    MODELS_COUNTRY = ["us"]
    SURVEY_COUNTRY = "us"

    # MODELS = ['AceGPT-13B-chat', 'Llama-2-13b-chat-hf', "mt0-xxl", "gpt-3.5-turbo-0613"]
    MODELS = ['AceGPT-13B-chat', 'Llama-2-13b-chat-hf', "gpt-3.5-turbo-1106", "mt0-xxl"]

    EVAL_METHOD = "mv_sample" # {"flatten", "mv_sample", "mv_all"}
    SCALE_QS = False # {False, True}

    SKIP_SAME_ANS =  False

    selected_questions = read_file("../dataset/selected_questions.csv")[0].split(",")
    # selected_questions_2 = read_file("filtered_questions_by_mae.csv")[0].split(",")

    skip_questions = [234]

    json_results = []
    q_json_results = []
    persona_json_results = []

    selected_questions = list(map(str.strip, selected_questions))
    selected_questions = [int(qnum[1:]) for qnum in selected_questions]
    wvs_themes = parse_range(read_json("../dataset/wvs_themes.json"))
    options_dict = parse_range(read_json("../dataset/wvs_options.json"))

    # invalid_questions = [19, 42, 62, 63, 78, 83, 84, 87, 88, 126, 142, 149, 150, 171, 224, 229, 234, 235, 239]

    print(f"Persona Country: {MODELS_COUNTRY}")
    print(f"Skip Same Answer: {SKIP_SAME_ANS}")
    all_results = []
    if SURVEY_COUNTRY == "egypt":
         #survey_path = "../dataset/F00013297-WVS_Wave_7_Egypt_CsvTxt_v5.0.csv"
        survey_path = "../dataset/eg_wvs_wave7_v7_n303.csv"
        other_survey_path = "../dataset/F00013339-WVS_Wave_7_United_States_CsvTxt_v5.0.csv"
    elif SURVEY_COUNTRY == "us":
        #survey_path = "../dataset/F00013339-WVS_Wave_7_United_States_CsvTxt_v5.0.csv"
        survey_path = "../dataset/us_wvs_wave7_v7_n303.csv"
        #other_survey_path = "../dataset/F00013339-WVS_Wave_7_United_States_CsvTxt_v5.0.csv"
        other_survey_path = "../dataset/us_wvs_wave7_v7_n303.csv"

    survey_df = pd.read_csv(survey_path, header=0, delimiter=";")
    other_survey_df = pd.read_csv(other_survey_path, header=0, delimiter=";")
    if SURVEY_COUNTRY == "egypt":
        survey_df[region_id] = survey_df[region_id].apply(lambda x: x.replace("EG: ", ""))
        other_survey_df[region_id] = other_survey_df[region_id].apply(lambda x: x.split(":")[-1][3:].strip())
    else:
        survey_df[region_id] = survey_df[region_id].apply(lambda x: x.split(":")[-1][3:].strip())
        other_survey_df[region_id] = other_survey_df[region_id].apply(lambda x: x.replace("EG: ", ""))

    wvs_question_map = create_wvs_question_map(survey_df.columns.tolist(), selected_questions)
    other_wvs_question_map = create_wvs_question_map(other_survey_df.columns.tolist(), selected_questions)
    wvs_response_map = read_json("../dataset/wvs_response_map.json")
    str_columns = ['persona.region', 'persona.sex', 'persona.country', 'persona.marital_status', 'persona.education', 'persona.social_class']
    # columns = ['persona.region', 'persona.sex', 'persona.age', 'persona.country', 'persona.marital_status', 'persona.education', 'persona.social_class', 'question.id', 'question.variant', 'response.id']
    # wvs_response_map_reverse = {}
    # for qid, q_response_data in wvs_response_map.items():
    #     wvs_response_map_reverse[qid] = {val: key for key, val in q_response_data.items()}

    result_config = []
    for model_country in MODELS_COUNTRY:
        for MODEL in MODELS:
            for LANG in LANGS:
                question_results = []
                result_config += [(model_country, LANG, MODEL)]
                for qidx in selected_questions:
                    if qidx in skip_questions: continue
                    # if qidx in invalid_questions: continue
                    # q=236_country=us_lang=en_model=gpt-3.5-turbo-0613_eval=mv_sample.csv

                    results_path = f"../dumps_wvs_2/q={qidx}_country={model_country}_lang={LANG}_model={MODEL}_eval={EVAL_METHOD}.csv"

                    if os.path.exists(results_path):
                        results_df = pd.read_csv(results_path).drop(columns='Unnamed: 0')
                        for col in str_columns:
                            results_df[col] = results_df[col].str.lower()

                        results_df["model"] = [MODEL]*len(results_df)
                        results_df["language"] = [LANG]*len(results_df)
                        results_df["theme"] = [wvs_themes[qidx]]*len(results_df)
                        results_df["model-country"] = [model_country]*len(results_df)
                        results_df["survey-country"] = [SURVEY_COUNTRY]*len(results_df)
                        question_results += [results_df]
                    else:
                        # breakpoint()
                        print(f"> Skipping {results_path}")
                if question_results: # Add this check
                    all_results += [pd.concat(question_results, ignore_index=True)]
                else:
                    print(f"> No results found for config: {result_config[-1]}") # Optional: Add a message

    # columns_by = ['persona.region', 'persona.sex', 'persona.age', 'persona.marital_status', 'persona.country', 'persona.education', 'persona.social_class', 'question.id', 'question.variant']
    # columns_by = ['persona.region', 'persona.sex', 'persona.age']
    print(f"Results: {len(all_results[0])}")

    for result in all_results:
        result.sort_values(by=columns_by, inplace=True)

    results, personas = dataframe_intersection(all_results, columns_by)

    result_json = []
    unique_personas = set()
    for persona_tuple in personas:
        unique_personas.add(persona_tuple[:-2])

    for result in tqdm(results):
        remove_indices = []
        visited_model_persona = set()
        for row_idx, row in result.iterrows():
            persona_tuple = tuple([str(row[col]).lower() for col in columns_by])
            if persona_tuple not in visited_model_persona:
                visited_model_persona.add(persona_tuple)
            else:
                remove_indices += [row_idx]

        for remove_idx in remove_indices[::-1]:
            result.drop(remove_idx, inplace=True)

    print(f"Dumps Intersection: {len(results[0])}")

    survey_filtered_df = []
    remove_indices = []
    visited_personas = set()
    for survey_row_idx, survey_row in survey_df.iterrows():
        survey_persona_tuple = tuple([str(survey_row[col]).lower() for col in demographic_ids])
        if survey_persona_tuple in visited_personas:
            remove_indices += [survey_row_idx]
        else:
            visited_personas.add(survey_persona_tuple)

    print(f"Survey DF: {len(survey_df)}")
    for remove_idx in remove_indices[::-1]:
        survey_df.drop(remove_idx, inplace=True)

    print(f"Survey DF: {len(survey_df)}")
    # breakpoint()
    for qidx in selected_questions:
        if qidx in skip_questions: continue

        response_map = {key: int(val) for key, val in wvs_response_map[str(qidx)].items()}
        response_map |= {key: val+1 for val, key in enumerate(options_dict[qidx])}
        response_map["No answer"] = -1

        survey_df[wvs_question_map[qidx]] = survey_df[wvs_question_map[qidx]].apply(lambda x: response_map[x])
        other_survey_df[other_wvs_question_map[qidx]] = other_survey_df[other_wvs_question_map[qidx]].apply(lambda x: response_map[x])

    survey_percentages_final = {}
    for result_idx, result in enumerate(results):
        persona_results = []
        persona_exact, persona_random = [], []
        random_accuracy = []
        mae_score = []
        for qidx in tqdm(selected_questions):
            if qidx in skip_questions: continue

            # response_map = {key: int(val) for key, val in wvs_response_map[str(qidx)].items()}
            # response_map |= {key: val+1 for val, key in enumerate(options_dict[qidx])}
            # response_map["No answer"] = -1

            # survey_df[wvs_question_map[qidx]] = survey_df[wvs_question_map[qqidx]].apply(lambda x: response_map[x])
            question_result_df = result[result["question.id"] == f"Q{qidx}"]
            for d_id in demographic_ids:
                if d_id == "Q262 Age": continue
                survey_df[d_id] = survey_df[d_id].apply(str.lower)
                other_survey_df[d_id] = other_survey_df[d_id].apply(str.lower)

            question_mae_score = []
            question_acc_score = []


            for persona in unique_personas:
                persona_row = question_result_df[(question_result_df["persona.region"] == persona[0]) & (question_result_df["persona.sex"] == persona[1]) & (question_result_df["persona.age"] == persona[2]) & (question_result_df["persona.marital_status"] == persona[3]) &
                    (question_result_df["persona.education"] == persona[4]) &
                    (question_result_df["persona.social_class"] == persona[5])
                ]

                survey_persona_row = survey_df[
                    (survey_df[demographic_ids[0]] == persona[0]) &
                    (survey_df[demographic_ids[1]] == persona[1]) &
                    (survey_df[demographic_ids[2]] == persona[2]) &
                    (survey_df[demographic_ids[3]] == persona[3]) &
                    (survey_df[demographic_ids[4]] == persona[4]) &
                    (survey_df[demographic_ids[5]] == persona[5])
                ]

                other_survey_persona_row = other_survey_df[
                    (other_survey_df[demographic_ids[1]] == persona[1]) &
                    (other_survey_df[demographic_ids[2]] == persona[2]) &
                    (other_survey_df[demographic_ids[3]] == persona[3]) &
                    (other_survey_df[demographic_ids[4]] == persona[4]) &
                    (other_survey_df[demographic_ids[5]] == persona[5])
                ]

                try:
                    survey_answer = survey_persona_row[wvs_question_map[qidx]].item()
                    other_survey_answer = other_survey_persona_row[other_wvs_question_map[qidx]].item()
                    model_answers = persona_row["response.answer"].tolist()

                    # rand_answers = persona_row[""]

                    if survey_answer == other_survey_answer and SKIP_SAME_ANS:
                        continue

                    # max_option = np.max(list(wvs_response_map_reverse[str(qidx)].keys()))
                    if survey_answer == -1 or options_dict[qidx][survey_answer-1] in ["No answer"]:
                        continue

                    num_options = len(options_dict[qidx])
                    assert 1 <= survey_answer <= num_options

                    # if "Don't know" in options_dict[qidx]:
                    #     num_options -= 1

                    # if f"Q{qidx}" in not_scale_questions:
                    #     # persona_exact += [x!=y]
                    #     continue

                    # num_options = np.sum([option_val not in ["Don't know", "No answer"] and option_idx > 0 for option_idx, option_val in wvs_response_map_reverse[str(qidx)].items()])
                    for variant_idx, model_answer in enumerate(model_answers):
                        # x = model_answer
                        # y = survey_answer


                        assert 1 <= survey_answer <= num_options
                        assert 1 <= model_answer <= num_options

                        # if options_dict[qidx][model_answer-1] == "Don't know":
                        #     continue


                        # else:
                        #     exact = 1 - abs(x - y) / (num_options-1)
                        #     exact = exact * (options_dict[qidx][model_answer-1] != "Don't know")

                        #     random = sum([1 - abs(i + 1 - y) / (num_options-1) for i in range(num_options)]) / num_options

                        #     persona_exact += [exact]
                        #     persona_random += [random]

                        if model_answer != -1:

                            random_accuracy += [1/num_options]
                            question_acc_score += [model_answer==survey_answer]

                            if  options_dict[qidx][model_answer-1] == "Don't know" or \
                                options_dict[qidx][survey_answer-1] == "Don't know" or \
                                f"Q{qidx}" in not_scale_questions:

                                question_mae_score += [model_answer==survey_answer]
                                persona_random += [1/num_options]

                            else:
                                num_options_q = num_options - 1 if "Don't know" in options_dict[qidx] else num_options
                                assert 1 <= model_answer <= num_options_q
                                mae = abs(model_answer - survey_answer) / (num_options_q-1)
                                assert 0 <= mae <= 1
                                question_mae_score += [1 - mae]

                                persona_random += [sum([1 - abs(i + 1 - survey_answer) / (num_options_q-1) for i in range(num_options_q)]) / num_options_q]


                        persona_json_results += [{
                            "question": qidx,
                            "variant": variant_idx,
                            **{d_id: persona[d_idx] for d_idx, d_id in enumerate(columns_by[:-2])},
                            "model-country": result_config[result_idx][0],
                            "survey-country": SURVEY_COUNTRY,
                            "prompting-language": result_config[result_idx][1],
                            "model": result_config[result_idx][2],
                            "mae-score": question_mae_score[-1],
                            "accuracy": question_acc_score[-1],
                            "model-answer": model_answer,
                            "survey-answer": survey_answer,
                            "rand-accuracy": random_accuracy[-1],
                            "person-random": persona_random[-1],

                        }]

                except:
                    breakpoint()



            mae_score.extend(question_mae_score)
            persona_results.extend(question_acc_score)

            q_json_results += [{
                "question": qidx,
                "config": result_config[result_idx],
                "model-country": result_config[result_idx][0],
                "survey-country": SURVEY_COUNTRY,
                "prompting-language": result_config[result_idx][1],
                "model": result_config[result_idx][2],
                "mae-score": np.mean(question_mae_score),
                "accuracy": np.mean(question_acc_score),
            }]

        print(f"{result_config[result_idx]}")

        mae_score_final = np.mean(mae_score)
        final_random = np.mean(persona_random)
        score = (mae_score_final - final_random) / (1 - final_random)

        acc_final = np.mean(persona_results)
        acc_random_final = np.mean(random_accuracy)
        nael_acc_score = (acc_final - acc_random_final) / (1 - acc_random_final)

        print(f"MAE: {mae_score_final}")
        print(f"Accuracy: {acc_final}")
        print(f"Nael MAE Score: {score}")
        print(f"Nael Acc Score: {nael_acc_score}")
        print()

        json_results += [{
            "config": result_config[result_idx],
            "model-country": result_config[result_idx][0],
            "survey-country": SURVEY_COUNTRY,
            "prompting-language": result_config[result_idx][1],
            "model": result_config[result_idx][2],
            "mae-score": mae_score_final,
            "accuracy": acc_final,
            "nael-mae-score": score,
            "nael-acc-score": nael_acc_score
        }]

    #write_json(f"results_{SURVEY_COUNTRY}.json", json_results)
    #write_json(f"q_results_{SURVEY_COUNTRY}.json", q_json_results)

    if SKIP_SAME_ANS:
        write_json(f"persona_results_{SURVEY_COUNTRY}_{MODELS_COUNTRY}_filtered.json", persona_json_results)
    else:
        write_json(f"persona_results_{SURVEY_COUNTRY}_filtered_all_3.json", persona_json_results)

        # print(f"{result_config[result_idx]}: {score} | {np.mean(final_exact)} | {len(persona_exact)}")

Persona Country: ['us']
Skip Same Answer: False
> Skipping ../dumps_wvs_2/q=2_country=us_lang=en_model=AceGPT-13B-chat_eval=mv_sample.csv
> Skipping ../dumps_wvs_2/q=19_country=us_lang=en_model=AceGPT-13B-chat_eval=mv_sample.csv
> Skipping ../dumps_wvs_2/q=21_country=us_lang=en_model=AceGPT-13B-chat_eval=mv_sample.csv
> Skipping ../dumps_wvs_2/q=42_country=us_lang=en_model=AceGPT-13B-chat_eval=mv_sample.csv
> Skipping ../dumps_wvs_2/q=62_country=us_lang=en_model=AceGPT-13B-chat_eval=mv_sample.csv
> Skipping ../dumps_wvs_2/q=63_country=us_lang=en_model=AceGPT-13B-chat_eval=mv_sample.csv
> Skipping ../dumps_wvs_2/q=77_country=us_lang=en_model=AceGPT-13B-chat_eval=mv_sample.csv
> Skipping ../dumps_wvs_2/q=78_country=us_lang=en_model=AceGPT-13B-chat_eval=mv_sample.csv
> Skipping ../dumps_wvs_2/q=83_country=us_lang=en_model=AceGPT-13B-chat_eval=mv_sample.csv
> Skipping ../dumps_wvs_2/q=84_country=us_lang=en_model=AceGPT-13B-chat_eval=mv_sample.csv
> Skipping ../dumps_wvs_2/q=87_country=us_l

100%|██████████| 1/1 [00:00<00:00,  3.51it/s]


Dumps Intersection: 5847
Survey DF: 303
Survey DF: 303


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
100%|██████████| 31/31 [00:38<00:00,  1.23s/it]


('us', 'en', 'AceGPT-13B-chat')
MAE: 0.5907943067033977
Accuracy: 0.3483126721763085
Nael MAE Score: 0.12822483450070746
Nael Acc Score: 0.10418896051094526



In [None]:
import pandas as pd

df = pd.read_json("persona_results_us_filtered_all_3.json")
df.to_csv("persona_results_us_filtered_all_3.csv", index=False)