# Dataset Enhancement with DeepSeek-R1

**Goal:** Generate MORE high-quality questions from the existing curated JSON files

**Model:** DeepSeek-R1-Distill-Llama-70B via vLLM

**Process:**
1. Load existing curated questions one by one
2. Use each one as an example to generate variations
3. Validate generated questions
4. Save enhanced dataset

In [None]:
# Configuration
# NOTE(review): both paths are absolute and machine-specific — adjust before running elsewhere.
CURATED_DIR = Path("/Users/777bhavyagoyal/Developer/UNSLOTHxAMDxHACk/MAIN_CURATED_JSON")
OUTPUT_DIR = Path("/Users/777bhavyagoyal/Developer/UNSLOTHxAMDxHACk/ENHANCED_CURATED_JSON")
# parents=True so a missing intermediate directory is created instead of
# aborting the notebook with FileNotFoundError.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# vLLM settings from your config
API_BASE = "http://localhost:8001/v1"
MODEL = "unsloth/DeepSeek-R1-Distill-Llama-70B"

# Generation settings
DATASET_MULTIPLIER = 2  # Variations requested per original question (N => Nx dataset)
TEMPERATURE = 0.4
TOP_P = 0.95
MAX_TOKENS = 4096
SLEEP_TIME = 0.1  # From your config; pause between API calls

print("✅ Configuration:")
print(f"  API: {API_BASE}")
print(f"  Model: {MODEL}")
# Report the configured multiplier rather than a hard-coded "2x", so the
# message stays truthful when DATASET_MULTIPLIER is changed.
print(f"  Dataset Multiplier: {DATASET_MULTIPLIER}x (will generate {DATASET_MULTIPLIER}x original size)")
print(f"  Temperature: {TEMPERATURE}")
print(f"  Input: {CURATED_DIR}")
print(f"  Output: {OUTPUT_DIR}")

In [None]:
# Configuration
# NOTE(review): both paths are absolute and machine-specific — adjust before running elsewhere.
CURATED_DIR = Path("/Users/777bhavyagoyal/Developer/UNSLOTHxAMDxHACk/MAIN_CURATED_JSON")
OUTPUT_DIR = Path("/Users/777bhavyagoyal/Developer/UNSLOTHxAMDxHACk/ENHANCED_CURATED_JSON")
# parents=True so a missing intermediate directory is created instead of
# aborting the notebook with FileNotFoundError.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# vLLM settings from your config
API_BASE = "http://localhost:8001/v1"
MODEL = "unsloth/DeepSeek-R1-Distill-Llama-70B"

# Generation settings
NUM_VARIATIONS_PER_QUESTION = 2  # Generate 2 variations per original question
TEMPERATURE = 0.4
TOP_P = 0.95
MAX_TOKENS = 4096
SLEEP_TIME = 0.1  # From your config; pause between API calls

print("✅ Configuration:")
print(f"  API: {API_BASE}")
print(f"  Model: {MODEL}")
print(f"  Variations per question: {NUM_VARIATIONS_PER_QUESTION}")
print(f"  Temperature: {TEMPERATURE}")
print(f"  Input: {CURATED_DIR}")
print(f"  Output: {OUTPUT_DIR}")

## Generation Prompt

In [None]:
# Prompt template for the generation step. It is filled in with str.format(),
# which substitutes three placeholders:
#   {num_new} - how many variations to request
#   {topic}   - topic of the example question
#   {example} - the example question serialized as JSON text
# Doubled braces ({{ and }}) are str.format() escapes and render as literal
# braces in the final prompt.
ENHANCEMENT_PROMPT = """Generate {num_new} NEW variations of this question for AMD AI Dev Day Hackathon.

CRITICAL REQUIREMENTS:
✓ Keep same topic ({topic})
✓ Keep same difficulty level
✓ Change entity names (use different letters/names)
✓ Change numbers/positions/relationships
✓ Keep same question structure
✓ SELF-CONTAINED: Include ALL info needed to solve
✓ NO coded relations (no symbols like * × +)
✓ Exactly 4 choices starting with "A) ", "B) ", "C) ", "D) "
✓ Answer is single capital letter: "A", "B", "C", or "D"
✓ Reasoning as SINGLE STRING with 5 steps
✓ 3-6 named entities
✓ At least 3 unique constraints

EXAMPLE QUESTION TO BASE ON:
{example}

Generate {num_new} NEW similar questions with:
- Different entity names
- Different numbers/positions
- Different specific constraints
- Same topic and difficulty

Each question MUST have these exact 7 fields:
{{
  "topic": "{topic}",
  "question": "Complete self-contained question text",
  "choices": ["A) option1", "B) option2", "C) option3", "D) option4"],
  "answer": "A" or "B" or "C" or "D",
  "explanation": "Brief explanation under 100 words",
  "reasoning": "Step 1: ... Step 2: ... Step 3: ... Step 4: ... Step 5: ...",
  "difficulty": "easy" or "medium" or "hard"
}}

CRITICAL OUTPUT RULES:
- Return ONLY JSON array: [{{}}, {{}}]
- NO markdown code blocks
- NO extra text before or after
- After </think> tag output [
- Reasoning must be SINGLE STRING not array
- Exactly 4 choices per question

Generate {num_new} new questions now:
"""

# Confirmation that the cell ran and the template is defined.
print("✅ Enhancement prompt ready")

## vLLM API Call Function

In [None]:
def _extract_question_array(text):
    """Return the first JSON array of objects embedded in *text*.

    Every ``[`` is tried as a possible array start, so bracketed prose before
    or after the real payload no longer breaks parsing (a greedy
    first-``[``-to-last-``]`` match would swallow that prose too).

    Returns:
        The decoded list (possibly empty), or None when *text* has no ``[``
        or only arrays that do not hold objects.

    Raises:
        json.JSONDecodeError: when candidates existed but none of them both
            decoded and held only objects; the last decode error is re-raised
            so the caller can report why parsing failed.
    """
    decoder = json.JSONDecoder()
    last_error = None
    start = text.find('[')
    while start != -1:
        try:
            candidate, _ = decoder.raw_decode(text, start)
        except json.JSONDecodeError as err:
            last_error = err
        else:
            # Skip arrays such as a nested "choices" list or a "[1]" citation;
            # only an array of objects can be the question payload.
            if all(isinstance(item, dict) for item in candidate):
                return candidate
        start = text.find('[', start + 1)
    if last_error is not None:
        raise last_error
    return None


def call_vllm(prompt):
    """Send *prompt* to the vLLM completions endpoint and parse the reply.

    Args:
        prompt: Full prompt text posted to ``{API_BASE}/completions``.

    Returns:
        The list of question objects decoded from the model output, or an
        empty list when the request fails or no JSON array can be extracted.
        Failures are printed rather than raised so that one bad response
        does not stop a long generation run.
    """
    try:
        response = requests.post(
            f"{API_BASE}/completions",
            json={
                "model": MODEL,
                "prompt": prompt,
                "temperature": TEMPERATURE,
                "top_p": TOP_P,
                "max_tokens": MAX_TOKENS
            },
            timeout=120
        )
        response.raise_for_status()

        output = response.json()['choices'][0]['text']

        # Clean output - keep only what follows the last thinking tag
        if '</think>' in output:
            output = output.split('</think>')[-1].strip()

        # Remove markdown code blocks
        output = output.replace('```json', '').replace('```', '').strip()

        # Extract JSON array
        questions = _extract_question_array(output)
        if questions is None:
            print("    ⚠️  No JSON array found in output")
            return []
        return questions

    except requests.exceptions.RequestException as e:
        print(f"    ❌ API Error: {e}")
        return []
    except json.JSONDecodeError as e:
        print(f"    ❌ JSON Parse Error: {e}")
        return []
    except Exception as e:
        # Deliberate catch-all at the API boundary (e.g. unexpected response
        # shape): report and continue with the next question.
        print(f"    ❌ Unexpected Error: {e}")
        return []

print("✅ API call function ready")

In [None]:
# Generate enhanced questions
seed(42)

# STEP 1: Count original questions
total_original = len(all_questions)
target_enhanced = total_original * DATASET_MULTIPLIER  # DATASET_MULTIPLIER x original
# Denominator for percentage reports, clamped so that an empty dataset (or a
# zero multiplier) reports 0% instead of raising ZeroDivisionError.
target_denominator = max(target_enhanced, 1)

print(f"\n{'='*60}")
print("📊 DATASET PLAN")
print(f"{'='*60}")
print(f"Original questions: {total_original}")
print(f"Target multiplier: {DATASET_MULTIPLIER}x")
print(f"Target enhanced questions: {target_enhanced}")
print(f"{'='*60}\n")

# STEP 2: Calculate how many variations per question
variations_per_question = DATASET_MULTIPLIER  # N variations per question = Nx dataset

enhanced_questions = []
total_generated = 0
total_valid = 0
total_attempted = 0

print("🔄 Starting enhancement...\n")
print(f"Strategy: Generate {variations_per_question} variations per question")
print(f"Processing {len(all_questions)} questions")
print(f"Expected output: ~{target_enhanced} new questions\n")
print("="*60)

for idx, original_q in enumerate(all_questions, 1):
    topic = original_q.get('topic')

    print(f"\n[{idx}/{len(all_questions)}] Processing {topic} question...")
    print(f"  Target: {variations_per_question} variations | Total so far: {total_valid}/{target_enhanced}")

    # Create prompt, using the original question as the worked example
    prompt = ENHANCEMENT_PROMPT.format(
        num_new=variations_per_question,
        topic=topic,
        example=json.dumps(original_q, indent=2, ensure_ascii=False)
    )

    # Call API
    total_attempted += 1
    new_questions = call_vllm(prompt)

    if new_questions:
        total_generated += len(new_questions)
        print(f"  📥 Generated {len(new_questions)} questions")

        # Validate each generated question
        for q in new_questions:
            # Model output is untrusted: reject a non-object entry here so it
            # cannot raise inside validation and abort the whole run.
            if not isinstance(q, dict):
                print("  ❌ Invalid: entry is not a JSON object")
                continue
            is_valid, reason = validate_question(q, topic)
            if is_valid:
                enhanced_questions.append(q)
                total_valid += 1
                print(f"  ✅ Valid question added (Total: {total_valid}/{target_enhanced})")
            else:
                print(f"  ❌ Invalid: {reason}")
    else:
        print("  ❌ No questions generated")

    # Progress report every 50 questions
    if idx % 50 == 0:
        print(f"\n{'='*60}")
        print("PROGRESS CHECKPOINT")
        print(f"{'='*60}")
        print(f"Processed: {idx}/{len(all_questions)} ({idx/len(all_questions)*100:.1f}%)")
        print(f"Valid questions: {total_valid}/{target_enhanced} ({total_valid/target_denominator*100:.1f}%)")
        print(f"Success rate: {(total_valid/max(total_generated,1))*100:.1f}%")
        print(f"{'='*60}")

    # Sleep to avoid overwhelming API
    sleep(SLEEP_TIME)

    # Stop if we've reached target (optional - remove if you want exact multiplier)
    # if total_valid >= target_enhanced:
    #     print(f"\n✅ Target reached! Stopping at {total_valid} questions")
    #     break

print(f"\n{'='*60}")
print("🎉 ENHANCEMENT COMPLETE!")
print(f"{'='*60}")
print(f"Original questions: {total_original}")
print(f"Target enhanced: {target_enhanced}")
print(f"API calls made: {total_attempted}")
print(f"Questions generated: {total_generated}")
print(f"Questions validated: {total_valid}")
print(f"Success rate: {(total_valid/max(total_generated,1))*100:.1f}%")
print(f"Target achievement: {(total_valid/target_denominator)*100:.1f}%")
print(f"{'='*60}")

In [None]:
def validate_question(q, expected_topic):
    """Check that a generated question satisfies the dataset format rules.

    The input comes straight from model output, so every field is
    type-checked before it is used; malformed input yields a failure result
    instead of an exception.

    Args:
        q: Candidate question, expected to be a dict with the seven required
            fields.
        expected_topic: Topic the question must declare.

    Returns:
        A ``(is_valid, reason)`` tuple. ``reason`` is ``"Valid"`` on success,
        otherwise a short description of the first rule that failed.
    """
    if not isinstance(q, dict):
        return False, "Question must be a JSON object"

    # Check all required fields
    required_fields = ['topic', 'question', 'choices', 'answer', 'explanation', 'reasoning', 'difficulty']
    if not all(field in q for field in required_fields):
        return False, "Missing required fields"

    # Check topic matches
    if q.get('topic') != expected_topic:
        return False, f"Topic mismatch: expected {expected_topic}, got {q.get('topic')}"

    # Check exactly 4 choices
    choices = q.get('choices')
    if not isinstance(choices, list) or len(choices) != 4:
        return False, "Must have exactly 4 choices"

    # Check choices format; a non-string choice fails rather than raising
    choice_prefixes = ['A)', 'B)', 'C)', 'D)']
    for position, (choice, prefix) in enumerate(zip(choices, choice_prefixes), 1):
        if not isinstance(choice, str) or not choice.startswith(prefix):
            return False, f"Choice {position} must start with {prefix}"

    # Check answer format (list membership compares with ==, so an
    # unhashable value such as a list is rejected without raising)
    if q.get('answer') not in ['A', 'B', 'C', 'D']:
        return False, "Answer must be A, B, C, or D"

    # Check reasoning is string not array
    reasoning = q.get('reasoning')
    if not isinstance(reasoning, str):
        return False, "Reasoning must be a single string, not array"

    # Check reasoning has 5 steps
    step_count = sum(1 for i in range(1, 6) if f'Step {i}:' in reasoning)
    if step_count < 5:
        return False, f"Reasoning must have 5 steps, found {step_count}"

    # Check difficulty
    if q.get('difficulty') not in ['easy', 'medium', 'hard']:
        return False, "Difficulty must be easy, medium, or hard"

    # Check question is text and not too short (self-contained check)
    question_text = q.get('question')
    if not isinstance(question_text, str):
        return False, "Question must be a string"
    if len(question_text) < 50:
        return False, "Question too short, likely not self-contained"

    return True, "Valid"

print("✅ Validation function ready")

## Load Existing Questions

In [None]:
# Load all curated questions from JSON files
all_questions = []
json_files = sorted(CURATED_DIR.glob("*.json"))

print(f"📂 Found {len(json_files)} JSON files in {CURATED_DIR}\n")

for json_file in json_files:
    try:
        # Explicit UTF-8 so decoding does not depend on the platform's
        # default locale encoding.
        with open(json_file, 'r', encoding='utf-8') as f:
            questions = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
        print(f"  ❌ {json_file.name}: Error - {e}")
        continue

    if not isinstance(questions, list):
        print(f"  ❌ {json_file.name}: Error - expected a JSON array of questions")
        continue

    # Filter valid questions. Each entry is type-checked individually so one
    # malformed entry no longer discards the whole file.
    valid = []
    for q in questions:
        if not isinstance(q, dict):
            continue
        choices = q.get('choices')
        if (isinstance(choices, list) and len(choices) == 4 and
                q.get('answer') in ('A', 'B', 'C', 'D') and
                q.get('topic') in ('blood_relations', 'seating_arrangement')):
            valid.append(q)

    all_questions.extend(valid)
    print(f"  ✅ {json_file.name}: Loaded {len(valid)} valid questions")

print(f"\n📊 Total curated questions loaded: {len(all_questions)}")

# Group by topic
by_topic = {}
for q in all_questions:
    by_topic.setdefault(q.get('topic'), []).append(q)

print("\n📋 Questions by topic:")
for topic, topic_questions in by_topic.items():
    print(f"  {topic}: {len(topic_questions)} questions")

## Generate Enhanced Questions

In [None]:
# Generate enhanced questions
seed(42)

enhanced_questions = []
total_generated = 0
total_valid = 0
total_attempted = 0

print("\n🔄 Starting enhancement...\n")
print(f"Processing {len(all_questions)} questions")
print(f"Generating {NUM_VARIATIONS_PER_QUESTION} variations per question")
print(f"Expected total: {len(all_questions) * NUM_VARIATIONS_PER_QUESTION} new questions\n")
print("="*60)

for idx, original_q in enumerate(all_questions, 1):
    topic = original_q.get('topic')

    print(f"\n[{idx}/{len(all_questions)}] Processing {topic} question...")

    # Create prompt, using the original question as the worked example
    prompt = ENHANCEMENT_PROMPT.format(
        num_new=NUM_VARIATIONS_PER_QUESTION,
        topic=topic,
        example=json.dumps(original_q, indent=2, ensure_ascii=False)
    )

    # Call API
    total_attempted += 1
    new_questions = call_vllm(prompt)

    if new_questions:
        total_generated += len(new_questions)
        print(f"  📥 Generated {len(new_questions)} questions")

        # Validate each generated question
        for q in new_questions:
            # Model output is untrusted: reject a non-object entry here so it
            # cannot raise inside validation and abort the whole run.
            if not isinstance(q, dict):
                print("  ❌ Invalid: entry is not a JSON object")
                continue
            is_valid, reason = validate_question(q, topic)
            if is_valid:
                enhanced_questions.append(q)
                total_valid += 1
                print("  ✅ Valid question added")
            else:
                print(f"  ❌ Invalid: {reason}")
    else:
        print("  ❌ No questions generated")

    # Progress report every 10 questions
    if idx % 10 == 0:
        print(f"\n{'='*60}")
        print(f"PROGRESS: {idx}/{len(all_questions)} processed")
        print(f"Generated: {total_generated} | Valid: {total_valid} | Success rate: {(total_valid/max(total_generated,1))*100:.1f}%")
        print(f"{'='*60}")

    # Sleep to avoid overwhelming API
    sleep(SLEEP_TIME)

print(f"\n{'='*60}")
print("🎉 ENHANCEMENT COMPLETE!")
print(f"{'='*60}")
print(f"Original questions: {len(all_questions)}")
print(f"API calls made: {total_attempted}")
print(f"Questions generated: {total_generated}")
print(f"Questions validated: {total_valid}")
print(f"Success rate: {(total_valid/max(total_generated,1))*100:.1f}%")
print(f"{'='*60}")

## Save Enhanced Dataset

In [None]:
# Save all enhanced questions
# Every file is written as UTF-8 explicitly: json.dump(..., ensure_ascii=False)
# emits non-ASCII characters verbatim, which raises UnicodeEncodeError when the
# platform's default encoding cannot represent them.
output_file = OUTPUT_DIR / "enhanced_questions.json"
with open(output_file, 'w', encoding='utf-8') as f:
    json.dump(enhanced_questions, f, indent=2, ensure_ascii=False)

print(f"✅ Saved {len(enhanced_questions)} enhanced questions")
print(f"   File: {output_file}")

# Save by topic
by_topic_enhanced = {}
for q in enhanced_questions:
    by_topic_enhanced.setdefault(q.get('topic'), []).append(q)

print("\n📁 Saving by topic:")
for topic, topic_questions in by_topic_enhanced.items():
    topic_file = OUTPUT_DIR / f"enhanced_{topic}.json"
    with open(topic_file, 'w', encoding='utf-8') as f:
        json.dump(topic_questions, f, indent=2, ensure_ascii=False)
    print(f"  ✅ {topic}: {len(topic_questions)} questions → {topic_file.name}")

# Combine original + enhanced
combined = all_questions + enhanced_questions
combined_file = OUTPUT_DIR / "combined_curated_enhanced.json"
with open(combined_file, 'w', encoding='utf-8') as f:
    json.dump(combined, f, indent=2, ensure_ascii=False)

print("\n✅ Combined dataset saved")
print(f"   File: {combined_file}")
print(f"   Total: {len(combined)} questions")

## Summary Report

In [None]:
# Final report: dataset growth overall and per topic, plus the files written.
print(f"\n{'='*60}")
print("📊 FINAL SUMMARY")
print(f"{'='*60}")
print("\n📈 Dataset Growth:")
print(f"  Original curated: {len(all_questions)}")
print(f"  Enhanced generated: {len(enhanced_questions)}")
print(f"  Total combined: {len(combined)}")
# With no original questions the ratio is undefined; report that instead of
# raising ZeroDivisionError at the very end of the run.
if all_questions:
    print(f"  Growth multiplier: {len(combined)/len(all_questions):.2f}x")
else:
    print("  Growth multiplier: n/a (no original questions)")

print("\n📋 By Topic (Original → Enhanced):")
for topic, originals in by_topic.items():
    # Every by_topic entry holds at least one question, so the division
    # below cannot hit zero.
    orig_count = len(originals)
    enh_count = len(by_topic_enhanced.get(topic, []))
    total = orig_count + enh_count
    print(f"  {topic}:")
    print(f"    Original: {orig_count}")
    print(f"    Enhanced: {enh_count}")
    print(f"    Total: {total} ({total/orig_count:.2f}x growth)")

print("\n💾 Output Files:")
print(f"  {OUTPUT_DIR}/")
print(f"    ├── enhanced_questions.json ({len(enhanced_questions)} questions)")
# Only the last entry of the listing gets the closing "└──" connector.
enhanced_topics = list(by_topic_enhanced)
combined_connector = "├──" if enhanced_topics else "└──"
print(f"    {combined_connector} combined_curated_enhanced.json ({len(combined)} questions)")
for position, topic in enumerate(enhanced_topics, 1):
    connector = "└──" if position == len(enhanced_topics) else "├──"
    print(f"    {connector} enhanced_{topic}.json ({len(by_topic_enhanced[topic])} questions)")

print(f"\n{'='*60}")
print("🎊 Enhancement Complete!")
print(f"{'='*60}")

## Sample Check

In [None]:
# Spot-check: print up to three randomly chosen generated questions and
# re-run validation on each one.
if not enhanced_questions:
    print("\n⚠️  No enhanced questions generated")
else:
    print("\n📝 Sample Enhanced Questions:\n")

    divider = '=' * 60
    sample_size = min(3, len(enhanced_questions))

    for number, picked in enumerate(sample(enhanced_questions, sample_size), start=1):
        print(divider)
        print(f"Sample {number}: {picked.get('topic')} ({picked.get('difficulty')})")
        print(divider)
        print(json.dumps(picked, indent=2, ensure_ascii=False))
        print()

        # Re-validate against the question's own declared topic
        passed, why = validate_question(picked, picked.get('topic'))
        print("✅ Validation: PASS" if passed else f"❌ Validation: FAIL - {why}")
        print()