This notebook tackles class imbalance in the Quora duplicate question detection task by augmenting the training data with synthetic, semantically similar duplicate pairs. Two generators are used: an OpenAI chat model accessed through LangChain (configured as `gpt-4` in the code below), which writes new question pairs from a list of themes, and a T5 (Text-to-Text Transfer Transformer) paraphraser, which rewrites existing training questions. All synthetic examples are filtered with sentence-transformer embeddings (`all-MiniLM-L6-v2`) and cosine similarity to check semantic alignment and label quality.

The goal is to enrich the dataset with more positive (duplicate) examples and so reduce its original class imbalance. About 2,700 candidate duplicate pairs were collected, and 1,180 of them remained after the final similarity filter and deduplication. The total was limited by the high computational and memory demands of generation, embedding, and filtering.

# Imports

In [1]:
import json
import os
import time
import pandas as pd
import glob
import re
import ftfy
import sys
sys.path.append('..')

from langchain_openai import ChatOpenAI
from langchain.prompts import PromptTemplate
from sentence_transformers import SentenceTransformer, util
from src.augmentation.llm_augmentation import generate_augmented_pairs
from src.features.paraphraser_utils import load_model, batch_paraphrase, is_question
from tqdm import tqdm






# Augmentation

In [5]:
# Load OpenAI API key
with open("../creds.json") as f:
    creds = json.load(f)
os.environ["OPENAI_API_KEY"] = creds["OPENAI_API_KEY"]

# Initialize LLM and embeddings
llm = ChatOpenAI(model="gpt-4", temperature=0.4)
embed_model = SentenceTransformer("all-MiniLM-L6-v2")

prompt = PromptTemplate(
    input_variables=["theme"],
    template=(
        "Generate an original natural-sounding user question on the topic of \"{theme}\", "
        "followed by a semantically equivalent version of the same question with different wording. Use this format:\n\n"
        "Q1: <original question>\n"
        "Q2: <paraphrased version>\n"
    ),
)

themes = [
    "travel", "finance", "health", "technology", "education",
    "job search", "taxes", "e-commerce", "cybersecurity", "food"
]

df_aug = generate_augmented_pairs(
    themes,
    n_pairs=500,               
    similarity_threshold=0.75,
    prompt=prompt,
    llm=llm,
    embed_model=embed_model,
    delay=1.5
)
df_aug.to_csv("../data/processed/augmented_duplicate_pairs.csv", index=False)
print("Saved to augmented_duplicate_pairs.csv")

✓ 1: Passed filter (0.92)
✓ 2: Passed filter (0.80)
✓ 3: Passed filter (0.94)
✓ 4: Passed filter (0.88)
✗ 5: Failed filter (0.75)
✓ 6: Passed filter (0.76)
✗ 7: Failed filter (0.71)
✗ 8: Failed filter (0.70)
✓ 9: Passed filter (0.79)
✓ 10: Passed filter (0.88)
✓ 11: Passed filter (0.92)
✗ 12: Failed filter (0.72)
✓ 13: Passed filter (0.84)
✓ 14: Passed filter (0.87)
✓ 15: Passed filter (0.88)
✗ 16: Duplicate detected, skipping.
✓ 17: Passed filter (0.90)
✓ 18: Passed filter (0.82)
✓ 19: Passed filter (0.80)
✓ 20: Passed filter (0.86)
✓ 21: Passed filter (0.86)
✗ 22: Failed filter (0.75)
✓ 23: Passed filter (0.91)
✗ 24: Duplicate detected, skipping.
✓ 25: Passed filter (0.77)
✗ 26: Failed filter (0.60)
✓ 27: Passed filter (0.92)
✓ 28: Passed filter (0.94)
✗ 29: Failed filter (0.69)
✓ 30: Passed filter (0.86)
✓ 31: Passed filter (0.75)
✓ 32: Passed filter (0.75)
✓ 33: Passed filter (0.95)
✗ 34: Duplicate detected, skipping.
✓ 35: Passed filter (0.76)
✗ 36: Duplicate detected, skipping.
✓
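
`generate_augmented_pairs` is defined in `src/augmentation/llm_augmentation.py` and is not shown in this notebook. As a rough sketch of the logic the log above suggests (parse the `Q1:`/`Q2:` response, skip pairs already seen, keep pairs whose similarity clears the threshold), the following uses only the standard library. The helper names `parse_pair` and `accept_pair` are hypothetical, the similarity score is passed in rather than computed, and the inclusive threshold is an assumption, so this is not the project's actual implementation. The log prints scores rounded to two decimals, which is probably why one pair at 0.75 fails while another at 0.75 passes.

```python
import re

# Matches the "Q1: ... / Q2: ..." layout requested by the prompt template.
_PAIR_RE = re.compile(r"Q1:\s*(?P<q1>.+?)\s*\n\s*Q2:\s*(?P<q2>.+?)\s*$", re.DOTALL)


def parse_pair(response_text):
    """Return (q1, q2) from an LLM response, or None if the format is off."""
    match = _PAIR_RE.search(response_text.strip())
    if match is None:
        return None
    return match.group("q1").strip(), match.group("q2").strip()


def accept_pair(q1, q2, score, seen, threshold=0.75):
    """Decide whether a generated pair should be kept.

    `score` is assumed to be a cosine similarity computed elsewhere;
    `seen` is a set of normalized pairs that were already accepted.
    """
    key = (q1.lower(), q2.lower())
    if key in seen:
        return False  # duplicate of an earlier pair
    if score < threshold:  # assumption: the threshold is inclusive
        return False
    seen.add(key)
    return True


seen = set()
pair = parse_pair(
    "Q1: How do I file taxes online?\n"
    "Q2: What is the way to submit my taxes over the internet?"
)
kept = accept_pair(*pair, score=0.81, seen=seen)    # first time: accepted
repeat = accept_pair(*pair, score=0.81, seen=seen)  # same pair again: rejected
```

Passing the score in as an argument keeps the acceptance rule testable without loading an embedding model.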

In [None]:
# Generate a second batch of duplicate pairs (note: n_pairs=1000 for this run)
df_aug2 = generate_augmented_pairs(
    themes,
    n_pairs=1000,               
    similarity_threshold=0.75,
    prompt=prompt,
    llm=llm,
    embed_model=embed_model,
    delay=1.5
)
df_aug2.to_csv("../data/processed/2augmented_duplicate_pairs.csv", index=False)
print("Saved to 2augmented_duplicate_pairs.csv")

In [9]:
# Check for duplicates between the two batches and combine them
combined_df = pd.concat([df_aug, df_aug2], ignore_index=True)
combined_df.drop_duplicates(subset=["question1", "question2"], inplace=True)

combined_df.reset_index(drop=True, inplace=True)
combined_df.to_csv("../data/processed/augmented_duplicate_pairs_500.csv", index=False)
print("Final cleaned dataset with", len(combined_df), "pairs saved.")


Final cleaned dataset with 570 pairs saved.


In [3]:
# Adding to main training set
df_train = pd.read_csv("../data/processed/quora_train_cleaned.csv")
combined_df = pd.read_csv("../data/processed/augmented_duplicate_pairs_500.csv")

# Concatenate
df_full = pd.concat([df_train, combined_df], ignore_index=True)

# Remove accidental repeats (handle (q1, q2) and (q2, q1) as duplicates)
df_full['q1_q2'] = df_full.apply(
    lambda row: " || ".join(sorted([row['question1'], row['question2']])), axis=1
)
df_full = df_full.drop_duplicates(subset=['q1_q2'])
df_full = df_full.drop(columns=['q1_q2'])

# Shuffle 
df_full = df_full.sample(frac=1, random_state=42).reset_index(drop=True)
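
The `q1_q2` key above treats `(q1, q2)` and `(q2, q1)` as the same pair by sorting the two strings before joining them. A minimal standard-library sketch of the same idea, using a sorted tuple instead of a joined string (the names `pair_key` and `dedupe_pairs` are illustrative and not part of the project):

```python
def pair_key(q1, q2):
    """Order-insensitive key: (a, b) and (b, a) map to the same tuple."""
    return tuple(sorted((q1, q2)))


def dedupe_pairs(pairs):
    """Keep the first occurrence of each unordered question pair."""
    seen = set()
    unique = []
    for q1, q2 in pairs:
        key = pair_key(q1, q2)
        if key not in seen:
            seen.add(key)
            unique.append((q1, q2))
    return unique


pairs = [
    ("how do i learn python?", "what is the best way to learn python?"),
    ("what is the best way to learn python?", "how do i learn python?"),  # swapped repeat
    ("what is a vpn?", "how does a vpn work?"),
]
unique_pairs = dedupe_pairs(pairs)
```

A tuple key also sidesteps the unlikely case where the separator string `" || "` appears inside a question.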

# T5 Paraphraser

In [6]:
# Load model 
tokenizer, model, device = load_model()

# Sample subset
n_to_augment = 20000
train_sample = df_full.sample(n=n_to_augment, random_state=42).dropna(subset=['question1', 'question2']).reset_index(drop=True)

# Paraphrase
q1_paraphrased = batch_paraphrase(train_sample['question1'].tolist(), tokenizer, model, device)
q2_paraphrased = batch_paraphrase(train_sample['question2'].tolist(), tokenizer, model, device)

assert len(train_sample) == len(q1_paraphrased) == len(q2_paraphrased), "Mismatch in lengths!"

# Create directory for checkpoints
checkpoint_dir = "paraphrase_checkpoints2"
os.makedirs(checkpoint_dir, exist_ok=True)

# Augment and save
augmented_rows = []

for i in tqdm(range(len(train_sample))):
    row = train_sample.iloc[i].copy()
    
    # Partial paraphrase: replace one question at a time, keeping the label
    row1 = row.copy()
    row1['question1'] = q1_paraphrased[i]
    augmented_rows.append(row1)

    row2 = row.copy()
    row2['question2'] = q2_paraphrased[i]
    augmented_rows.append(row2)

    # Only positive duplicates fully paraphrased
    if row['is_duplicate'] == 1:
        row3 = row.copy()
        row3['question1'] = q1_paraphrased[i]
        row3['question2'] = q2_paraphrased[i]
        augmented_rows.append(row3)

    # Save in chunks
    if (i + 1) % 3000 == 0 or (i + 1) == len(train_sample):
        checkpoint_df = pd.DataFrame(augmented_rows)
        checkpoint_df = checkpoint_df[
            checkpoint_df['question1'].apply(is_question) & checkpoint_df['question2'].apply(is_question)
        ].drop_duplicates(subset=['question1', 'question2', 'is_duplicate']).reset_index(drop=True)

        checkpoint_path = os.path.join(checkpoint_dir, f"augmented_checkpoint2_{i+1}.csv")
        checkpoint_df.to_csv(checkpoint_path, index=False)
        augmented_rows = []  # Clear buffer


100%|██████████| 625/625 [15:32<00:00,  1.49s/it]
100%|██████████| 625/625 [15:43<00:00,  1.51s/it]
100%|██████████| 20000/20000 [00:06<00:00, 3326.37it/s]
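
The loop above turns each sampled row into two or three augmented rows: one with only `question1` paraphrased, one with only `question2` paraphrased, and, for duplicates only, one with both paraphrased. The progress bars are also consistent with a batch size of 32, since 20,000 questions over 625 batches is 32 per batch (the actual value is set inside `batch_paraphrase`, which is not shown). A small sketch of the expansion rule on plain dictionaries, with a hypothetical `expand_row` helper:

```python
def expand_row(row, q1_para, q2_para):
    """Return the augmented variants of one labeled question pair.

    Labels are copied unchanged, which assumes each paraphrase preserves
    the meaning of the question it replaces.
    """
    variants = [
        {**row, "question1": q1_para},  # only question1 paraphrased
        {**row, "question2": q2_para},  # only question2 paraphrased
    ]
    if row["is_duplicate"] == 1:
        # The fully paraphrased variant is generated for duplicates only.
        variants.append({**row, "question1": q1_para, "question2": q2_para})
    return variants


dup = {
    "question1": "how do i reset my password?",
    "question2": "how can i change a forgotten password?",
    "is_duplicate": 1,
}
non_dup = {**dup, "question2": "what is a strong password?", "is_duplicate": 0}

dup_variants = expand_row(
    dup, "how can i reset my password?", "how do i change a password i forgot?"
)
non_dup_variants = expand_row(
    non_dup, "how can i reset my password?", "what makes a password strong?"
)
```

Because the label is carried over untouched, any paraphrase that drifts in meaning produces a mislabeled row, which is why the later similarity-based relabeling step matters.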


# Preprocessing of the augmented data and merging with the original

In [None]:
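# Load and concatenate all augmented CSV files.
# Assumption: `all_files` was never defined in this notebook; here it is taken to be
# the T5 checkpoint CSVs written to `checkpoint_dir` ("paraphrase_checkpoints2") above.
all_files = sorted(glob.glob(os.path.join(checkpoint_dir, "*.csv")))
augmented_df = pd.concat([pd.read_csv(f) for f in all_files], ignore_index=True)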
print(f"Loaded augmented rows: {len(augmented_df)}")

# Load the LLM-generated duplicate pairs (570 rows, despite the "_500" file name)
combined_df = pd.read_csv("../data/processed/augmented_duplicate_pairs_500.csv")
print(f"Loaded extra file: {len(combined_df)}")

# Combine all augmented data
augmented_df = pd.concat([augmented_df, combined_df], ignore_index=True)
print(f"Total combined augmented rows: {len(augmented_df)}")

augmented_df = augmented_df.dropna(subset=['question1', 'question2']).reset_index(drop=True)
print(f"After NaN removal: {len(augmented_df)}")
augmented_df = augmented_df[augmented_df['is_duplicate'] == 1].reset_index(drop=True)
print(f"Kept only positive duplicates: {len(augmented_df)}")
dup_count = augmented_df.duplicated(subset=['question1', 'question2']).sum()
print(f"Duplicate (Q1, Q2) pairs in augmented_df: {dup_count}")


# Function to clean text: fix encoding, lowercase, trim whitespace
def preprocess_augmented(text):
    text = str(text)
    text = ftfy.fix_text(text)
    text = text.strip().lower()
    text = re.sub(r'\s+', ' ', text)
    return text

# Apply text cleaning to both questions
augmented_df['question1'] = augmented_df['question1'].apply(preprocess_augmented)
augmented_df['question2'] = augmented_df['question2'].apply(preprocess_augmented)

# Keep only English rows based on ASCII check
def is_english(text):
    try:
        text.encode('utf-8').decode('ascii')
        return True
    except UnicodeDecodeError:
        return False

augmented_df = augmented_df[
    augmented_df['question1'].apply(is_english) &
    augmented_df['question2'].apply(is_english)
].reset_index(drop=True)
print(f"After English filtering: {len(augmented_df)}")

# Recalculate `is_duplicate` using semantic similarity
sim_model = SentenceTransformer('all-MiniLM-L6-v2')

def recalculate_duplicate_label(q1, q2, threshold=0.8):
    emb1 = sim_model.encode(q1, convert_to_tensor=True)
    emb2 = sim_model.encode(q2, convert_to_tensor=True)
    similarity = util.pytorch_cos_sim(emb1, emb2)
    return 1 if similarity.item() > threshold else 0

# Apply label recalculation with progress bar
augmented_df['is_duplicate'] = [
    recalculate_duplicate_label(q1, q2)
    for q1, q2 in tqdm(zip(augmented_df['question1'], augmented_df['question2']), total=len(augmented_df))
]
# Remove NaNs
augmented_df = augmented_df.dropna(subset=['question1', 'question2']).reset_index(drop=True)

# Keep only positive duplicates
augmented_df = augmented_df[augmented_df['is_duplicate'] == 1].reset_index(drop=True)

# Check for q1-q2 duplicates (ignoring label)
dup_count = augmented_df.duplicated(subset=['question1', 'question2']).sum()
print(f"Duplicated q-pairs (ignoring label) deleted: {dup_count}")

# Then drop true duplicates
augmented_df = augmented_df.drop_duplicates(subset=['question1', 'question2', 'is_duplicate'])

# Merge with original dataset and remove exact duplicates
full_df = pd.concat([df_full, augmented_df], ignore_index=True).drop_duplicates(
    subset=['question1', 'question2', 'is_duplicate']
).sample(frac=1, random_state=42).reset_index(drop=True)

# Save the cleaned and combined datasets
os.makedirs("../data/processed", exist_ok=True)
augmented_df.to_csv("../data/processed/augmented_questions.csv", index=False)
full_df.to_csv("../data/processed/full_train_augmented.csv", index=False)

print("✅ Saved augmented and full dataset.")



Loaded augmented rows: 5197
Loaded extra file: 570
Total combined augmented rows: 5767
After NaN removal: 5767
Kept only positive duplicates: 2701
Duplicate (Q1, Q2) pairs in augmented_df: 50
After English filtering: 2693


100%|██████████| 2693/2693 [00:47<00:00, 56.30it/s]


Duplicated q-pairs (ignoring label) deleted: 100
✅ Saved augmented and full dataset.
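
The `is_english` check in the cell above is an ASCII test rather than true language detection. The sketch below uses the built-in `str.isascii()` (Python 3.7+), which gives the same answer as the encode/decode round trip for ordinary strings, to show what the filter accepts and rejects. Any question containing an accented letter, curly quote, or non-Latin script is dropped, while ASCII-only text in another language passes. The sample sentences are made up for illustration.

```python
def is_ascii_only(text):
    """Same test as `is_english` in the cell above, via str.isascii()."""
    return text.isascii()


samples = [
    "how do i learn python?",       # plain English
    "what is a café au lait?",      # English with an accented letter
    "como puedo aprender python?",  # Spanish written without accents
]
results = {text: is_ascii_only(text) for text in samples}
```

This matches the small drop seen in the output (2,701 to 2,693 rows): the check mostly removes rows with stray non-ASCII characters rather than non-English questions as such.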


In [17]:
print(len(augmented_df), len(full_df))


1180 324649
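
Assuming these counts come from the same run as the preprocessing cell, they imply that 1,280 of the 2,693 English-filtered pairs scored above the 0.8 threshold, and removing the 100 repeated pairs left 1,180. The relabeling rule itself is a cosine similarity compared against a threshold with a strict `>`. A dependency-free sketch on toy vectors follows; real `all-MiniLM-L6-v2` embeddings have 384 dimensions, and the vectors here are made up for illustration.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def relabel(sim, threshold=0.8):
    """Mirror the notebook's rule: 1 only if similarity is strictly above the threshold."""
    return 1 if sim > threshold else 0


close = cosine_similarity([1.0, 2.0, 3.0], [1.0, 2.0, 2.5])  # nearly parallel
far = cosine_similarity([1.0, 0.0, 0.0], [0.0, 1.0, 0.0])    # orthogonal
labels = (relabel(close), relabel(far))
```

In the notebook, each pair is encoded separately. `SentenceTransformer.encode` also accepts a list of sentences, so encoding all questions in batches and then comparing row by row would likely be faster, though that variant was not run here.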


In [18]:
for i, row in augmented_df[10:20].iterrows():
    print("Q1:", row["question1"])
    print("Q2:", row["question2"])
    print("Is Duplicate:", row["is_duplicate"])
    print("---")

Q1: have you ever seen someone die?
Q2: have you ever seen anyone die?
Is Duplicate: 1
---
Q1: what's the other thing you think about your mom?
Q2: what's that one thing that comes to your mind when you think about your mother?
Is Duplicate: 1
---
Q1: what are best books for ssc cgl?
Q2: are there any books for ssc cgl?
Is Duplicate: 1
---
Q1: why do i like girls soles?
Q2: i like girls bare soles?
Is Duplicate: 1
---
Q1: i know this is known question but how do i know if she likes me?
Q2: how would i know if she really likes me?
Is Duplicate: 1
---
Q1: does cannabis oil cure cancer?
Q2: how does cannabis affect cancer?
Is Duplicate: 1
---
Q1: how can one prepare for entrance exam of top ibdp schools?
Q2: how to prepare yourself for the ibdp school entrance exam?
Is Duplicate: 1
---
Q1: what all career options i have after ece?
Q2: career advice - i have done my be ece, what are the various career options for me now?
Is Duplicate: 1
---
Q1: my google account is disabled. how can i enab

In [21]:
# Compare class balance: augmented set (printed first) vs. original training set (printed second)
print("Class distribution:")
print(full_df["is_duplicate"].value_counts(normalize=True))
print(df_train["is_duplicate"].value_counts(normalize=True))

Class distribution:
is_duplicate
0    0.626843
1    0.373157
Name: proportion, dtype: float64
is_duplicate
0    0.63024
1    0.36976
Name: proportion, dtype: float64
