In [1]:
import os
import argparse
import logging
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
import json
from tqdm import tqdm

In [2]:
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("rlaif_generation.log"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

def load_pubmedqa_labeled(path):
    """Load the PubMedQA dataset from a JSON file"""
    try:
        with open(path, "r") as f:
            data = json.load(f)
        logger.info(f"Successfully loaded PubMedQA dataset from {path}")
        return data
    except Exception as e:
        logger.error(f"Error loading PubMedQA dataset: {e}")
        return None

def setup_model(model_name, load_in_4bit=True):
    """Set up the tokenizer and model with specified configuration"""
    try:
        # Quantization parameters
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=load_in_4bit,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
        )

        # Tokenizer
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Model with quantization
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=bnb_config,
            device_map="auto",
            torch_dtype=torch.float16
        )

        logger.info(f"Successfully loaded model: {model_name}")
        return tokenizer, model
    except Exception as e:
        logger.error(f"Error setting up model: {e}")
        return None, None

In [12]:
def get_cot_template():
    """Return the chain-of-thought template"""
    return """Please analyze the following medical case step by step:

Context: {context}

Question: {question}

Let's think through this step by step:

1. First, identify the key information from the context:
2. Then, analyze the specific question being asked:
3. Next, consider the relevant medical concepts:
4. After that, evaluate the possible answers:
5. Finally, provide a comprehensive conclusion:

Answer:"""

def generate_answer(model, tokenizer, prompt, max_length=1024, max_new_tokens=512,
                   temperature=0.7, num_beams=4):
    """Generate an answer using the model"""
    try:
        inputs = tokenizer(prompt, return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu")

        # Checking if the input is too long
        input_length = inputs.input_ids.shape[1]
        if input_length > max_length:
            logger.warning(f"Input too long: {input_length} tokens (max: {max_length})")
            return None

        # Output
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            temperature=temperature,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.2,
            length_penalty=1.5,
            num_beams=num_beams,
            early_stopping=True
        )

        # Decode output
        answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return answer
    except Exception as e:
        logger.error(f"Error generating answer: {e}")
        return None

In [13]:
def process_generated_answer(answer):
    """Process the generated answer to extract the answer and chain of thought"""
    if not answer:
        return "", ""

    if "Answer:" in answer:
        cot_part = answer.split("Answer:")[0].strip()
        answer_part = answer.split("Answer:")[1].strip()
    else:
        cot_part = ""
        answer_part = answer.strip()

    return answer_part, cot_part

In [14]:
def generate_rlaif_data(dataset_path=None, use_huggingface=True,
                        model_name="microsoft/BioGPT-Large-PubMedQA",
                        output_file="rlaif_data.json",
                        save_interval=10, max_samples=None):
    """Generate RLAIF data from PubMedQA dataset"""

    # Dataset
    if use_huggingface:
        try:
            dataset = load_dataset("pubmed_qa", "pqa_labeled")
            train_data = dataset["train"]
            logger.info(f"Loaded dataset from Hugging Face: {len(train_data)} items")
        except Exception as e:
            logger.error(f"Error loading dataset from Hugging Face: {e}")
            return None
    else:
        if not dataset_path:
            logger.error("Dataset path required when not using Hugging Face")
            return None
        train_data = load_pubmedqa_labeled(dataset_path)
        if not train_data:
            return None

    # Model and tokenizer
    tokenizer, model = setup_model(model_name)
    if not tokenizer or not model:
        return None

    #CoT
    cot_template = get_cot_template()

    # RLAIF data
    rlaif_data = []

    data_items = train_data[:max_samples] if max_samples else train_data

    for idx, item in enumerate(tqdm(data_items, desc="Generating RLAIF data")):
        try:
            if use_huggingface:
                question = item["question"]
                context_data = item["context"]
                long_answer = item["long_answer"]
                contexts = context_data["contexts"]
                context = " ".join(contexts)
            else:
                question = item.get("question", "")
                contexts = item.get("contexts", [])
                context = " ".join(contexts)
                long_answer = item.get("long_answer", "")

            # Format the prompt
            prompt = cot_template.format(context=context, question=question)

            # Generate an answer
            answer = generate_answer(model, tokenizer, prompt)
            if not answer:
                continue

            # Process the generated answer
            answer_part, cot_part = process_generated_answer(answer)

            # Create a data point
            data_point = {
                "prompt": prompt,
                "chosen": {
                    "answer": answer_part,
                    "chain_of_thought": cot_part
                },
                "rejected": {
                    "answer": long_answer,
                    "chain_of_thought": ""
                },
                "metadata": {
                    "question": question,
                    "context": context,
                    "model": model_name
                }
            }

            rlaif_data.append(data_point)

            # Save at intervals
            if (idx + 1) % save_interval == 0:
                with open(output_file, "w") as f:
                    json.dump(rlaif_data, f, indent=2)
                logger.info(f"Saved {len(rlaif_data)} items to {output_file}")

        except Exception as e:
            logger.error(f"Error processing item {idx}: {e}")
            continue

    with open(output_file, "w") as f:
        json.dump(rlaif_data, f, indent=2)
    logger.info(f"Finished generating RLAIF data. Total items: {len(rlaif_data)}")

    return rlaif_data

In [15]:
def get_parser():
    """Set up command line argument parser"""
    parser = argparse.ArgumentParser(description='Generate RLAIF data from PubMedQA dataset')
    parser.add_argument('--dataset_path', type=str, default=None,
                        help='Path to local PubMedQA dataset JSON file')
    parser.add_argument('--use_huggingface', action='store_true',
                        help='Whether to use the Hugging Face dataset')
    parser.add_argument('--model_name', type=str, default="microsoft/BioGPT-Large-PubMedQA",
                        help='Name of the model to use')
    parser.add_argument('--output_file', type=str, default="rlaif_data.json",
                        help='Path to output file')
    parser.add_argument('--save_interval', type=int, default=10,
                        help='Number of items to process before saving')
    parser.add_argument('--max_samples', type=int, default=None,
                        help='Maximum number of samples to process')
    return parser

In [16]:
if __name__ == "__main__":
    parser = get_parser()
    args = parser.parse_args()

    generate_rlaif_data(
        dataset_path=args.dataset_path,
        use_huggingface=args.use_huggingface,
        model_name=args.model_name,
        output_file=args.output_file,
        save_interval=args.save_interval,
        max_samples=args.max_samples
    )

usage: colab_kernel_launcher.py [-h] [--dataset_path DATASET_PATH]
                                [--use_huggingface] [--model_name MODEL_NAME]
                                [--output_file OUTPUT_FILE]
                                [--save_interval SAVE_INTERVAL]
                                [--max_samples MAX_SAMPLES]
colab_kernel_launcher.py: error: unrecognized arguments: -f /root/.local/share/jupyter/runtime/kernel-8020627b-57ca-4b36-808f-a3432941434f.json


SystemExit: 2