In [4]:
import csv
import functools
import os
import pickle
import random
import re
import time

import numpy as np
import openai
import pandas as pd
from datasets import load_dataset
from tqdm import tqdm

In [5]:
# Prefer the OPENAI_API_KEY environment variable so a real key never has to be
# written into the notebook; the placeholder is kept as the fallback value.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY") or "<your_api_key>"
checkpoint_file = "gsm8k_checkpoint.pkl"

# Load checkpoint if available
# NOTE(security): pickle.load can execute arbitrary code, so only resume from a
# checkpoint file that this notebook itself produced.
if os.path.exists(checkpoint_file):
    with open(checkpoint_file, "rb") as f:
        checkpoint = pickle.load(f)
    # The file handle is no longer needed once the checkpoint dict is in memory.
    start_index = checkpoint.get("current_index", 0)
    perturbed_data = checkpoint.get("perturbed_data", [])
    print(f"Resuming from checkpoint at index {start_index}")
else:
    # No checkpoint on disk: start from the first example with no results yet.
    start_index = 0
    perturbed_data = []

Resuming from checkpoint at index 7480


In [43]:
# OpenAI client setup
from openai import OpenAI

def setup_openai_client():
    """Build an ``OpenAI`` client using the module-level ``OPENAI_API_KEY``."""
    api_client = OpenAI(api_key=OPENAI_API_KEY)
    return api_client


In [44]:
# Load dataset
def load_gsm8k_data():
    """Load the "main" configuration of ``openai/gsm8k`` and return its train split.

    Prints one message before loading and one afterwards reporting how many
    training examples were loaded.
    """
    print("Loading GSM8K dataset from Hugging Face...")
    train_split = load_dataset("openai/gsm8k", "main")["train"]
    example_count = len(train_split)
    print(f"Loaded {example_count} training examples")
    return train_split

In [45]:
def create_perturbation(client, input_sentence):
    """Request a perturbed rewrite of ``input_sentence`` from the chat model.

    Args:
        client: Object exposing ``chat.completions.create(...)``.
        input_sentence: Answer text that is embedded, quoted, in the prompt.

    Returns:
        The first choice's message content with surrounding whitespace
        stripped, or ``None`` when requesting or reading the completion raised
        (the error is printed and the call pauses for two seconds first).
    """
    instructions = f"""You are given an input answer: \"{input_sentence}\"
Your task is to generate a single, semantically equivalent but perturbed version of this answer that satisfies the following objectives:

Semantic Meaning — The overall intent, implications, and propositional content must remain unchanged.
Reasoning Consistency — If the answer encodes or implies reasoning steps, logical structure, or causal/temporal sequences, those must be preserved exactly, even if phrased differently.
Domain Fusion: Identify the answer's core syntactic elements (subject, verb, object, and key modifiers) and blend them with imagery, terminology, or metaphors from a semantically unrelated domain to enrich its meaning or style.
Lexical Perturbation: Introduce minor visual or lexical noise by replacing approximately 30% of the answer's tokens with alternatives (e.g., lookalike Unicode characters, obscure synonyms, or inflected variants).
Temporal Shift: Recast the answer to reflect a different time period by adjusting temporal references, using appropriate verb tenses, and selecting vocabulary that capture the stylistic nuances of the chosen era.

Steps to follow:
Step 1: Parse the input to extract its syntactic components.
Step 2: Select a distinct, unrelated target domain and identify 2-3 relevant concepts or expressions.
Step 3: Apply lexical perturbation to roughly 30% of the tokens.
Step 4: Shift the temporal style to suit the chosen period.
Step 5: Integrate all modifications into a fluent final version that preserves the original meaning of the input answer.

Output: Return only the final, fully perturbed answer."""

    conversation = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": instructions},
    ]

    try:
        # No max_tokens cap is passed, so the reply length is left to the API.
        completion = client.chat.completions.create(
            model="gpt-4o",
            messages=conversation,
            temperature=0.8,
        )
        # Reading the reply stays inside the try so a missing or None content
        # is reported through the same error path as a failed request.
        first_choice = completion.choices[0]
        return first_choice.message.content.strip()
    except Exception as exc:
        print(f"Error generating perturbation: {exc}")
        time.sleep(2)
        return None

In [46]:
# Retry decorator
def retry_with_exponential_backoff(
    func,
    initial_delay: float = 1,
    exponential_base: float = 2,
    jitter: bool = True,
    max_retries: int = 10,
    errors: tuple = (Exception,),
):
    """Wrap ``func`` so that matching failures are retried with growing delays.

    Args:
        func: Callable to invoke; its arguments are passed through unchanged.
        initial_delay: Starting delay in seconds. It is multiplied before the
            first sleep, so the first wait is already larger than this value.
        exponential_base: Factor the delay is multiplied by on every retry.
        jitter: When true, the factor is additionally scaled by a random value
            in ``[1, 2)`` on every retry.
        max_retries: Number of retries allowed after the first failed call.
        errors: Exception types that trigger a retry; anything else propagates
            immediately.

    Returns:
        A wrapper with the same call signature as ``func`` that carries over
        ``func``'s name and docstring.

    Raises:
        Exception: Once more than ``max_retries`` retries would be needed. The
            last underlying error is attached as ``__cause__``.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        num_retries = 0
        delay = initial_delay

        while True:
            try:
                return func(*args, **kwargs)
            except errors as e:
                num_retries += 1
                if num_retries > max_retries:
                    # Chain the last failure so the root cause stays visible.
                    raise Exception(
                        f"Maximum number of retries ({max_retries}) exceeded."
                    ) from e
                delay *= exponential_base * (1 + jitter * random.random())
                print(f"API error occurred: {str(e)}. Retrying in {delay:.2f} seconds...")
                time.sleep(delay)

    return wrapper

In [47]:
def save_to_csv(results, output_path):
    """Write ``results`` to ``output_path`` as CSV and return the DataFrame.

    ``results`` is converted with ``pd.DataFrame(results)``, written without
    the index column, and a confirmation line naming the path is printed.
    """
    frame = pd.DataFrame(results)
    frame.to_csv(output_path, index=False)
    print(f"Results saved to {output_path}")
    return frame

In [48]:
# Main Pipeline
# Perturbs every answer from start_index onward in batches, checkpointing the
# accumulated results after each batch so an interrupted run can be resumed.
client = setup_openai_client()
data = load_gsm8k_data()
dataset = [{"question": item["question"], "answer": item["answer"]} for item in data]

batch_size = 10  # You can adjust the batch size here

# Remember where this run began: start_index itself is advanced after every
# checkpoint so that re-running this cell in the same kernel continues from the
# last saved position instead of reprocessing (and duplicating) earlier batches.
run_start_index = start_index

for i in range(run_start_index, len(dataset), batch_size):
    batch = dataset[i:i + batch_size]
    # The final batch may be short; never record an index past the dataset end.
    next_index = min(i + batch_size, len(dataset))

    try:
        perturbed_batch = []
        for example in batch:
            perturbed_answer = create_perturbation(client, example['answer'])
            perturbed_batch.append({
                "original_question": example['question'],
                "original_answer": example['answer'],
                # A failed generation (None) is stored as an empty string.
                "perturbed_answer": perturbed_answer if perturbed_answer else ""
            })

        perturbed_data.extend(perturbed_batch)

        # Save checkpoint after each batch. Write to a temporary file first and
        # then swap it in, so an interruption mid-write cannot corrupt the
        # existing checkpoint.
        checkpoint = {
            "current_index": next_index,
            "perturbed_data": perturbed_data
        }
        tmp_checkpoint_file = checkpoint_file + ".tmp"
        with open(tmp_checkpoint_file, "wb") as f:
            pickle.dump(checkpoint, f)
        os.replace(tmp_checkpoint_file, checkpoint_file)
        start_index = next_index

        print(f"Checkpoint saved at index {next_index}")

        # Progress update after every 10 batches
        batches_done = (i - run_start_index) // batch_size + 1
        if batches_done % 10 == 0:
            print(f"Progress update: Completed {batches_done} batches")

    except Exception as e:
        # Best effort: report the failing batch and move on to the next one.
        print(f"Error at batch starting index {i}: {e}")
        continue



Loading GSM8K dataset from Hugging Face...
Loaded 7473 training examples
Checkpoint saved at index 580
Checkpoint saved at index 590
Checkpoint saved at index 600
Checkpoint saved at index 610
Checkpoint saved at index 620
Checkpoint saved at index 630
Checkpoint saved at index 640
Checkpoint saved at index 650
Checkpoint saved at index 660
Checkpoint saved at index 670
Progress update: Completed 10 batches
Checkpoint saved at index 680
Checkpoint saved at index 690
Checkpoint saved at index 700
Checkpoint saved at index 710
Checkpoint saved at index 720
Checkpoint saved at index 730
Checkpoint saved at index 740
Checkpoint saved at index 750
Checkpoint saved at index 760
Checkpoint saved at index 770
Progress update: Completed 20 batches
Checkpoint saved at index 780
Checkpoint saved at index 790
Checkpoint saved at index 800
Checkpoint saved at index 810
Checkpoint saved at index 820
Checkpoint saved at index 830
Checkpoint saved at index 840
Checkpoint saved at index 850
Checkpoint 

In [50]:
# Write every collected record to the output CSV, then report completion.
save_to_csv(results=perturbed_data, output_path="gsm8k_perturbed_answers.csv")
print("Finished all batches.")


Results saved to gsm8k_perturbed_answers.csv
Finished all batches.
