<a href="https://colab.research.google.com/github/Joykw1/NLP_RAG_project/blob/main/Code/Multipassage_retrieval.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch transformers bitsandbytes accelerate outlines datasets

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

checkpoint = "Qwen/Qwen2.5-1.5B-Instruct"

# 8-bit weight quantization (bitsandbytes) keeps the model's VRAM footprint
# small, since only a limited amount of GPU memory is available here.
bnb_config = BitsAndBytesConfig(load_in_8bit=True)

# Tokenizer and weights both come from the same Hugging Face Hub checkpoint.
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    device_map="cuda",               # put the whole model on the GPU
    quantization_config=bnb_config,  # load the weights in 8-bit
)

In [None]:
from datasets import load_dataset
# Load the Russian, English and German XQuAD configurations (in that order).
# Later cells align examples across these datasets by their 'id' field.
ru_dataset, en_dataset, de_dataset = (
    load_dataset("xquad", name)
    for name in ("xquad.ru", "xquad.en", "xquad.de")
)

In [None]:
# Language code -> loaded XQuAD dataset. The processing loop further down
# looks up both the base dataset and the other-language datasets here.
dataset_dict = dict(ru=ru_dataset, en=en_dataset, de=de_dataset)

In [None]:
# Show the English dataset as this cell's output (the notebook renders the
# value of a cell's last bare expression) for a quick look at its structure.
en_dataset

In [None]:
import pandas as pd
import torch
import re
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import json
from tqdm.notebook import tqdm
import os

# Define the configurations. Each entry sets:
#   question_lang -- language of the question passed to the model
#   context_langs -- languages (in prompt order) of the passages shown to the model
#   answer_lang   -- language the answer is expected in
#   base_dataset  -- language of the dataset the processing loop iterates over
CONFIGS = {
    # Original setup: Russian question, all contexts
    "ru_all": {
        "question_lang": "ru",
        "context_langs": ["ru", "en", "de"],
        "answer_lang": "ru",
        "base_dataset": "ru"
    },
    # Setup 1: Russian question, only English and German contexts
    "ru_en_de": {
        "question_lang": "ru",
        "context_langs": ["en", "de"],
        "answer_lang": "ru",
        "base_dataset": "ru"
    },
    # Setup 2: English question, Russian and German contexts
    "en_ru_de": {
        "question_lang": "en",
        "context_langs": ["ru", "de"],
        "answer_lang": "en",
        "base_dataset": "en"
    },
    # Setup 3: German question, all contexts
    "de_all": {
        "question_lang": "de",
        "context_langs": ["ru", "de", "en"],
        "answer_lang": "de",
        "base_dataset": "de"
    },
    # Setup 4: German question, English and Russian contexts
    "de_ru_en": {
        "question_lang": "de",
        "context_langs": ["ru", "en"],
        "answer_lang": "de",
        "base_dataset": "de"
    },
    # Setup 5: English question, all contexts
    "en_all": {
        "question_lang": "en",
        "context_langs": ["de", "ru", "en"],
        "answer_lang": "en",
        "base_dataset": "en"
    }

}

# Choose which configuration to run.
# Options: 'ru_all', 'ru_en_de', 'en_ru_de', 'de_all', 'de_ru_en', 'en_all'
config_name = 'de_ru_en'  # Change this to run different configurations
if config_name not in CONFIGS:
    # Fail early with the valid choices instead of a bare KeyError.
    raise KeyError(
        f"Unknown config_name {config_name!r}; choose one of: {', '.join(CONFIGS)}"
    )
config = CONFIGS[config_name]

print(f"Running with configuration: {config_name}")
print(f"Question language: {config['question_lang']}")
print(f"Context languages: {config['context_langs']}")
print(f"Answer language: {config['answer_lang']}")

# Parameters
batch_size = 10        # How often to save to CSV
max_samples = None     # Set to a number to limit processing (e.g., 10 for testing)


# Create output file paths (one pair per configuration)
output_path = f"multilingual_qa_results_{config_name}.jsonl"
output_csv = f"multilingual_qa_results_{config_name}.csv"

# Collect the ids already written to the JSONL file so a re-run resumes
# where the previous run stopped instead of starting over.
existing_ids = set()
if os.path.exists(output_path):
    with open(output_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                # e.g. a half-written last line from an interrupted run
                continue
            if isinstance(data, dict) and data.get('id') is not None:
                existing_ids.add(data['id'])
    print(f"Found {len(existing_ids)} existing processed examples")

# Function to create a chat template based on the configuration
def create_chat_input(question, contexts, config):
    """Build the multilingual question-answering prompt for one example.

    Args:
        question: The question text, written in ``config["question_lang"]``.
        contexts: Mapping from language code to that language's passage. It
            must contain every code listed in ``config["context_langs"]``.
        config: A configuration dict such as an entry of ``CONFIGS``. Reads
            ``question_lang`` and ``context_langs``, plus ``answer_lang`` if
            present (otherwise the answer language defaults to
            ``question_lang``).

    Returns:
        The prompt string. It contains one ``"<Language> Context: ..."``
        section per context language (in ``context_langs`` order), then the
        question, then the answer-format instructions.

    Raises:
        KeyError: If a language code is not one of ``ru``/``en``/``de``, or a
            required context is missing from ``contexts``.
    """
    # Map language codes to full names
    lang_names = {"ru": "Russian", "en": "English", "de": "German"}

    question_lang_name = lang_names[config["question_lang"]]
    # The model is asked to answer in the configured answer language. The
    # config records it separately from the question language, so honour it
    # rather than assuming the two always match.
    answer_lang_name = lang_names[config.get("answer_lang", config["question_lang"])]

    # One "<Language> Context: ..." section per passage, in configured order.
    contexts_text = "".join(
        f"{lang_names[lang]} Context: {contexts[lang]}\n\n"
        for lang in config["context_langs"]
    )

    return f"""
I need help with a question answering task using multiple languages.

{contexts_text}Question ({question_lang_name}): {question}

Your answer must be in {answer_lang_name}.
Your answer must contain only words from the contexts.
Your answer must be a single noun phrase.
If the question is 'how many' your answer must be a single numeral.
If the question is 'who' your answer should only contain names or nouns.
"""

# Function to generate an answer
def generate_answer(input_text, max_new_tokens=100, temperature=0.7, top_p=0.9, do_sample=True):
    """Run the module-level ``model`` on one user prompt and return its reply.

    Args:
        input_text: The user message (for example the output of
            ``create_chat_input``).
        max_new_tokens: Upper bound on generated tokens.
        temperature, top_p: Sampling parameters. They are only passed to
            ``generate`` when ``do_sample`` is true.
        do_sample: Sample (the default, as before) or decode greedily.

    Returns:
        The generated continuation only (prompt tokens removed), decoded
        without special tokens and whitespace-stripped. Leading ``user`` or
        ``system`` role words at the start of the text or of any line are
        removed.
    """
    messages = [{"role": "user", "content": input_text}]
    if tokenizer.chat_template:
        # add_generation_prompt=True appends the assistant-turn header, so the
        # model writes an answer instead of continuing or echoing the user turn.
        formatted_input = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    else:
        # Simple ChatML-style fallback if the tokenizer has no chat template
        formatted_input = f"<|im_start|>user\n{input_text}<|im_end|>\n<|im_start|>assistant\n"

    # Put the tensors on the model's device; keep attention_mask alongside input_ids.
    inputs = tokenizer(formatted_input, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[1]

    generation_kwargs = {"max_new_tokens": max_new_tokens, "do_sample": do_sample}
    if do_sample:
        generation_kwargs.update(temperature=temperature, top_p=top_p)

    with torch.no_grad():
        outputs = model.generate(**inputs, **generation_kwargs)

    # Decode only the newly generated tokens.
    response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    response = response.strip()

    # Defensive cleanup of role words that may leak into the decoded text:
    # first at the very start, then at the start of any later line.
    response = re.sub(r'^(user|system)\s+', '', response)
    response = re.sub(r'\n(user|system)\s+', '\n', response)

    return response

# Select the base dataset based on the configuration
base_dataset = dataset_dict[config['base_dataset']]
base_split = 'validation'  # We always use the validation split


def _load_saved_results(path):
    """Return every result row already stored in the JSONL file at ``path``.

    Blank or unparsable lines (e.g. a half-written last line from an
    interrupted run) are skipped. A missing file yields an empty list.
    """
    rows = []
    if not os.path.exists(path):
        return rows
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                row = json.loads(line)
            except json.JSONDecodeError:
                continue
            if isinstance(row, dict):
                rows.append(row)
    return rows


def _build_id_index(lang):
    """Map each example id of ``dataset_dict[lang][base_split]`` to its first row index.

    This is built once per language, so aligning an example becomes a dict
    lookup instead of a full scan of the other dataset for every example.
    """
    index = {}
    for row, ex_id in enumerate(dataset_dict[lang][base_split]['id']):
        index.setdefault(ex_id, row)  # keep the first match, like a linear scan
    return index


def _aligned_field(lang, field, example, idx):
    """Return ``field`` of the current example in language ``lang``.

    The base dataset's own row is used directly. Other languages are matched
    by example id, falling back to the same row index if the id is not found.
    """
    if lang == config['base_dataset']:
        return example[field]
    row = id_index[lang].get(example['id'], idx)
    return dataset_dict[lang][base_split][row][field]


def _save_csv():
    """Rewrite ``output_csv`` with earlier runs' rows plus this session's rows.

    Returns the number of rows written (nothing is written when there are none).
    """
    rows = previous_results + results
    if rows:
        pd.DataFrame(rows).to_csv(output_csv, index=False)
    return len(rows)


# Id -> row index for every non-base language this configuration reads from.
needed_langs = {config['question_lang'], config['answer_lang'], *config['context_langs']}
id_index = {
    lang: _build_id_index(lang)
    for lang in needed_langs
    if lang != config['base_dataset']
}

# Rows written by earlier (possibly interrupted) runs. The CSV is rebuilt from
# these plus the new rows, so resuming never drops earlier results from it.
previous_results = _load_saved_results(output_path)

# Process each example and save results incrementally
results = []
total_examples = len(base_dataset[base_split])
if max_samples is not None:
    total_examples = min(total_examples, max_samples)

print(f"Processing {total_examples} examples...")

for idx, example in tqdm(enumerate(base_dataset[base_split]), total=total_examples):
    if idx >= total_examples:
        break

    example_id = example['id']

    # Skip if already processed
    if example_id in existing_ids:
        print(f"Skipping already processed example ID: {example_id}")
        continue

    try:
        # Question, contexts and gold answer, each in its configured language
        question = _aligned_field(config['question_lang'], 'question', example, idx)
        contexts = {
            lang: _aligned_field(lang, 'context', example, idx)
            for lang in config['context_langs']
        }
        original_answer = _aligned_field(config['answer_lang'], 'answers', example, idx)

        # Create chat input and generate the model answer
        input_text = create_chat_input(question, contexts, config)
        model_answer = generate_answer(input_text)

        result = {
            'id': example_id,
            f'question_{config["question_lang"]}': question,
            'original_answer': original_answer,
            'model_answer': model_answer
        }
        results.append(result)

        # Save incrementally to file
        with open(output_path, 'a', encoding='utf-8') as f:
            f.write(json.dumps(result, ensure_ascii=False) + '\n')
        # Remember the id so re-running only this cell does not duplicate rows.
        existing_ids.add(example_id)

        # Refresh the CSV after each batch
        if len(results) % batch_size == 0:
            saved = _save_csv()
            print(f"Saved {saved} results to CSV")

    except Exception as e:
        # Best effort: report, keep what we have, and move on to the next example.
        print(f"Error processing example {example_id}: {str(e)}")
        _save_csv()

# Always write the final CSV, even if the last example was skipped or failed.
saved = _save_csv()
print(f"Saved {saved} results to CSV")

# Final DataFrame (this session's new results only)
final_df = pd.DataFrame(results)
print(f"Total processed examples: {len(final_df)}")

# Display the first few rows
if len(final_df) > 0:
    print("\nSample results:")
    display(final_df.head())