# Lab 3.1.6: Dataset Preparation - Solutions

Complete solutions for dataset preparation exercises.

## Exercise 1: Multi-Format Dataset Converter

In [None]:
import json
from typing import List, Dict, Any, Optional
from dataclasses import dataclass, field

@dataclass
class UniversalConverter:
    """
    Universal dataset format converter supporting multiple formats.

    Source formats recognised by ``detect_format``: ``"sharegpt"``,
    ``"openai"``, ``"alpaca"``, ``"completion"`` and ``"preference"``.
    Every converter is a ``staticmethod`` that takes the records plus the
    name of their source format and returns a new list of records.
    """

    # OpenAI/ChatML role -> ShareGPT speaker. Roles not listed here
    # (e.g. "tool") fall back to "gpt".
    _ROLE_TO_SPEAKER = {"user": "human", "system": "system", "assistant": "gpt"}

    # ShareGPT speaker -> OpenAI/ChatML role. Some ShareGPT dumps already use
    # "user"/"assistant" as speaker names, so both spellings are accepted.
    _SPEAKER_TO_ROLE = {
        "human": "user",
        "user": "user",
        "system": "system",
        "gpt": "assistant",
        "assistant": "assistant",
    }

    @staticmethod
    def detect_format(data: List[Dict]) -> str:
        """
        Auto-detect dataset format from the keys of the first record.

        Returns one of "sharegpt", "openai", "alpaca", "completion",
        "preference", or "unknown" (also returned for empty ``data``).
        """
        if not data:
            return "unknown"
        sample = data[0]

        if "conversations" in sample:
            return "sharegpt"
        elif "messages" in sample:
            return "openai"
        elif "instruction" in sample:
            return "alpaca"
        elif "prompt" in sample and "completion" in sample:
            return "completion"
        elif "chosen" in sample and "rejected" in sample:
            return "preference"
        else:
            return "unknown"

    @staticmethod
    def to_sharegpt(data: List[Dict], source_format: str) -> List[Dict]:
        """
        Convert records to ShareGPT: ``{"conversations": [{"from", "value"}]}``.

        An Alpaca ``input`` is appended to the instruction after a blank line
        with an ``Input:`` label. OpenAI ``user`` messages become ``"human"``,
        ``system`` messages stay ``"system"`` and any other role becomes
        ``"gpt"``. Any other source format is assumed to already carry a
        ``"conversations"`` list, which is passed through.
        """
        converted = []

        for item in data:
            if source_format == "alpaca":
                conversations = [
                    {"from": "human", "value": item.get("instruction", "") +
                     (f"\n\nInput: {item['input']}" if item.get("input") else "")},
                    {"from": "gpt", "value": item.get("output", "")}
                ]
            elif source_format == "openai":
                conversations = [
                    {
                        "from": UniversalConverter._ROLE_TO_SPEAKER.get(msg["role"], "gpt"),
                        "value": msg["content"],
                    }
                    for msg in item.get("messages", [])
                ]
            elif source_format == "completion":
                conversations = [
                    {"from": "human", "value": item.get("prompt", "")},
                    {"from": "gpt", "value": item.get("completion", "")}
                ]
            else:
                conversations = item.get("conversations", [])

            converted.append({"conversations": conversations})

        return converted

    @staticmethod
    def to_alpaca(data: List[Dict], source_format: str) -> List[Dict]:
        """
        Convert records to Alpaca: ``{"instruction", "input", "output"}``.

        Only the first user turn and the first assistant turn are kept. Later
        turns and system messages are dropped, because Alpaca holds a single
        exchange. Records in any other source format are passed through
        unchanged.
        """
        converted = []
        speaker_to_role = UniversalConverter._SPEAKER_TO_ROLE

        for item in data:
            if source_format == "sharegpt":
                convs = item.get("conversations", [])
                human_msgs = [c["value"] for c in convs if speaker_to_role.get(c["from"]) == "user"]
                gpt_msgs = [c["value"] for c in convs if speaker_to_role.get(c["from"]) == "assistant"]

                converted.append({
                    "instruction": human_msgs[0] if human_msgs else "",
                    "input": "",
                    "output": gpt_msgs[0] if gpt_msgs else ""
                })
            elif source_format == "openai":
                msgs = item.get("messages", [])
                user_msgs = [m["content"] for m in msgs if m["role"] == "user"]
                asst_msgs = [m["content"] for m in msgs if m["role"] == "assistant"]

                converted.append({
                    "instruction": user_msgs[0] if user_msgs else "",
                    "input": "",
                    "output": asst_msgs[0] if asst_msgs else ""
                })
            elif source_format == "completion":
                converted.append({
                    "instruction": item.get("prompt", ""),
                    "input": "",
                    "output": item.get("completion", "")
                })
            else:
                converted.append(item)

        return converted

    @staticmethod
    def to_chatml(data: List[Dict], source_format: str, system_prompt: str = "") -> List[Dict]:
        """
        Convert records to ChatML/OpenAI messages: ``{"messages": [{"role", "content"}]}``.

        Args:
            data: records in ``source_format``.
            source_format: one of "alpaca", "sharegpt", "openai", "completion".
                Other formats produce only the optional system message.
            system_prompt: if non-empty, it is prepended as a system message,
                unless the converted record already contains its own system
                message. In that case the record's own message is kept and
                no duplicate is added.

        Returns:
            A new list of ``{"messages": [...]}`` records. Message dicts from
            an "openai" source are shallow-copied rather than shared.
        """
        converted = []
        speaker_to_role = UniversalConverter._SPEAKER_TO_ROLE

        for item in data:
            body = []

            if source_format == "alpaca":
                user_content = item.get("instruction", "")
                if item.get("input"):
                    user_content += f"\n\nInput: {item['input']}"
                body.append({"role": "user", "content": user_content})
                body.append({"role": "assistant", "content": item.get("output", "")})

            elif source_format == "sharegpt":
                for conv in item.get("conversations", []):
                    role = speaker_to_role.get(conv["from"], "assistant")
                    body.append({"role": role, "content": conv["value"]})

            elif source_format == "openai":
                # Already in messages form: copy so callers can mutate safely.
                body = [dict(msg) for msg in item.get("messages", [])]

            elif source_format == "completion":
                body.append({"role": "user", "content": item.get("prompt", "")})
                body.append({"role": "assistant", "content": item.get("completion", "")})

            messages = []
            has_own_system = any(msg.get("role") == "system" for msg in body)
            if system_prompt and not has_own_system:
                messages.append({"role": "system", "content": system_prompt})
            messages.extend(body)

            converted.append({"messages": messages})

        return converted

# Demo: convert a two-record Alpaca sample into ShareGPT and ChatML
converter = UniversalConverter()

# Sample Alpaca records (the second one uses the optional "input" field)
alpaca_data = [
    dict(instruction="Explain gravity", input="", output="Gravity is a force..."),
    dict(instruction="Translate", input="Hello world", output="Hola mundo"),
]

print(
    "Alpaca → ShareGPT:",
    json.dumps(converter.to_sharegpt(alpaca_data, "alpaca")[0], indent=2),
    sep="\n",
)

print(
    "\nAlpaca → ChatML:",
    json.dumps(converter.to_chatml(alpaca_data, "alpaca", "You are helpful.")[0], indent=2),
    sep="\n",
)

## Exercise 2: Quality Filter Pipeline

In [None]:
import re
from collections import Counter

class DataQualityPipeline:
    """
    Complete data quality filtering pipeline.

    Every ``filter_*`` method takes a single string and returns a
    ``(passed, reason)`` tuple, where ``reason`` is ``None`` when the text
    passes and a short label otherwise. ``process`` runs the filters in a
    fixed order (length, quality, content, language), stops at the first
    failure, and tallies failure reasons in ``self.stats``.
    """

    # Keys tried, in order, to extract the text of one conversation turn:
    # ShareGPT turns use "value", OpenAI/ChatML messages use "content".
    _TURN_TEXT_KEYS = ("value", "content", "text")

    def __init__(self):
        # "filtered" maps a failure reason to the number of samples it removed.
        self.stats = {
            "total": 0,
            "passed": 0,
            "filtered": Counter()
        }

    def filter_length(
        self,
        text: str,
        min_chars: int = 50,
        max_chars: int = 10000,
        min_words: int = 10,
        max_words: int = 2000
    ) -> tuple:
        """
        Filter by length.

        Rejects text whose character count falls outside
        ``[min_chars, max_chars]`` or whose whitespace-separated word count
        falls outside ``[min_words, max_words]`` (bounds inclusive).
        """
        if len(text) < min_chars:
            return False, "too_short_chars"
        if len(text) > max_chars:
            return False, "too_long_chars"

        words = text.split()
        if len(words) < min_words:
            return False, "too_few_words"
        if len(words) > max_words:
            return False, "too_many_words"

        return True, None

    def filter_quality(self, text: str) -> tuple:
        """
        Filter by content quality.

        Rejects text that is repetitive (under 30% unique words, checked only
        when there are more than 10 words), mostly non-alphabetic (under 50%
        letters, with spaces and punctuation counted as non-letters), or
        mostly upper-case (over 50% of all characters).
        """
        # Check for repetition
        words = text.lower().split()
        if len(words) > 10:
            unique_ratio = len(set(words)) / len(words)
            if unique_ratio < 0.3:
                return False, "too_repetitive"

        # Check for special character ratio
        alpha_chars = sum(c.isalpha() for c in text)
        if len(text) > 0 and alpha_chars / len(text) < 0.5:
            return False, "too_many_special_chars"

        # Check for excessive capitalization
        upper_chars = sum(c.isupper() for c in text)
        if len(text) > 0 and upper_chars / len(text) > 0.5:
            return False, "excessive_caps"

        return True, None

    def filter_content(self, text: str) -> tuple:
        """
        Filter inappropriate content.

        Rejects text containing more than 3 matches of any single placeholder
        pattern, or containing a canned-refusal phrase (case-insensitive).
        """
        # Placeholder patterns (expand as needed)
        placeholder_patterns = [
            r"\[.*?\]",  # [placeholder]
            r"\{.*?\}",  # {placeholder}
            r"<.*?>",    # <placeholder>
            r"___+",     # ____
        ]

        for pattern in placeholder_patterns:
            if len(re.findall(pattern, text)) > 3:
                return False, "too_many_placeholders"

        # Check for "I cannot" / "I'm sorry" patterns (plain substrings)
        refusal_patterns = [
            r"i cannot",
            r"i'm sorry",
            r"i am unable",
            r"as an ai",
        ]

        # Normalize the typographic apostrophe so "I’m sorry" also matches.
        text_lower = text.lower().replace("\u2019", "'")
        for pattern in refusal_patterns:
            if pattern in text_lower:
                return False, "contains_refusal"

        return True, None

    def filter_language(
        self,
        text: str,
        min_english_ratio: float = 0.8
    ) -> tuple:
        """
        Basic language filter (English).

        Heuristic only: rejects text whose share of ASCII characters is below
        ``min_english_ratio``. Non-English Latin-script text can still pass.
        """
        ascii_chars = sum(ord(c) < 128 for c in text)
        if len(text) > 0:
            ascii_ratio = ascii_chars / len(text)
            if ascii_ratio < min_english_ratio:
                return False, "non_english"

        return True, None

    @classmethod
    def _to_text(cls, value) -> str:
        """Flatten a record's text field (str, None, or list of turns) to one string."""
        if value is None:
            return ""
        if isinstance(value, list):
            # Use the turn text, not str(dict): the braces of a dict repr
            # would otherwise trip the "{...}" placeholder filter.
            parts = []
            for turn in value:
                if isinstance(turn, dict):
                    turn_text = next(
                        (turn[key] for key in cls._TURN_TEXT_KEYS if key in turn), ""
                    )
                    parts.append(str(turn_text))
                else:
                    parts.append(str(turn))
            return " ".join(parts)
        return value if isinstance(value, str) else str(value)

    def process(
        self,
        data: List[Dict],
        text_field: str = "text"
    ) -> List[Dict]:
        """
        Process dataset through all filters.

        Args:
            data: records to filter.
            text_field: key holding the text. A list value (e.g. a
                conversation) is flattened to the text of its turns. A
                missing or ``None`` value is treated as empty text.

        Returns:
            The records that passed every filter, in input order.
            ``self.stats`` is reset and refilled on every call.
        """
        filtered_data = []
        self.stats = {"total": len(data), "passed": 0, "filtered": Counter()}

        # Only the first failing filter is recorded for each sample.
        filters = (
            self.filter_length,
            self.filter_quality,
            self.filter_content,
            self.filter_language,
        )

        for item in data:
            text = self._to_text(item.get(text_field, ""))

            for filter_func in filters:
                passed, reason = filter_func(text)
                if not passed:
                    self.stats["filtered"][reason] += 1
                    break
            else:
                filtered_data.append(item)
                self.stats["passed"] += 1

        return filtered_data

    def report(self):
        """Print filtering report (safe to call on an empty dataset)."""
        total = self.stats["total"]
        passed = self.stats["passed"]
        pass_pct = 100 * passed / total if total else 0.0

        print("\nData Quality Report")
        print("=" * 40)
        print(f"Total samples: {total}")
        print(f"Passed: {passed} ({pass_pct:.1f}%)")
        print(f"Filtered: {total - passed}")
        print("\nFilter breakdown:")
        for reason, count in self.stats["filtered"].most_common():
            print(f"  {reason}: {count}")

# Demo: one sample per filter outcome; only the first and last should pass.
# The refusal and special-character samples are long enough (>= 50 chars,
# >= 10 words) to get past the length filter, so the filter each one is
# labelled with actually fires.
test_data = [
    {"text": "This is a good quality response about machine learning and its applications."},
    {"text": "Hi"},  # Too short
    {"text": "word " * 500},  # Repetitive
    {"text": "I cannot help with that request. As an AI language model, "
             "I am not able to provide this kind of information."},  # Refusal
    {"text": "!!!! @@@@ #### $$$$ %%%% ^^^^ &&&& **** ~~~~ ++++ ==== ????"},  # Special chars
    {"text": "This is another valid response with sufficient length and quality content."},
]

pipeline = DataQualityPipeline()
filtered = pipeline.process(test_data)
pipeline.report()

## Exercise 3: Preference Data Generator

In [None]:
class PreferenceDataGenerator:
    """
    Generate preference pairs for DPO/ORPO training.

    Also reshapes binary (good/bad) feedback into labelled examples for KTO.
    """

    @staticmethod
    def from_ratings(
        data: List[Dict],
        prompt_field: str = "prompt",
        response_field: str = "response",
        rating_field: str = "rating"
    ) -> List[Dict]:
        """
        Create pairs from rated responses.

        Higher rating = chosen, lower rating = rejected. Records are grouped
        by prompt. Within each group, the highest-rated response is paired
        against every strictly lower-rated one. When several responses tie
        for best, the first in input order is used and the others produce no
        pair. Prompts with fewer than two responses are skipped. Output is
        ordered by prompt.
        """
        from itertools import groupby

        # groupby only merges adjacent keys, so sort by prompt first.
        sorted_data = sorted(data, key=lambda x: x[prompt_field])
        pairs = []

        for prompt, group in groupby(sorted_data, key=lambda x: x[prompt_field]):
            responses = list(group)
            if len(responses) < 2:
                continue

            # Sort by rating (stable, so ties keep input order)
            responses.sort(key=lambda x: x[rating_field], reverse=True)

            # Create pairs: best vs each strictly worse
            best = responses[0]
            for worse in responses[1:]:
                if best[rating_field] > worse[rating_field]:
                    pairs.append({
                        "prompt": prompt,
                        "chosen": best[response_field],
                        "rejected": worse[response_field]
                    })

        return pairs

    @staticmethod
    def from_comparison(
        data: List[Dict],
        prompt_field: str = "prompt",
        response_a_field: str = "response_a",
        response_b_field: str = "response_b",
        preference_field: str = "preference"  # "a" or "b"
    ) -> List[Dict]:
        """
        Create pairs from A/B comparisons.

        The preference value is compared case-insensitively after ``str()``
        and whitespace stripping. Records whose preference is neither "a" nor
        "b" (e.g. "tie", empty, ``None``) carry no preference signal and are
        skipped, instead of silently being counted as a win for B.
        """
        pairs = []

        for item in data:
            choice = str(item[preference_field]).strip().lower()
            if choice == "a":
                chosen = item[response_a_field]
                rejected = item[response_b_field]
            elif choice == "b":
                chosen = item[response_b_field]
                rejected = item[response_a_field]
            else:
                continue

            pairs.append({
                "prompt": item[prompt_field],
                "chosen": chosen,
                "rejected": rejected
            })

        return pairs

    @staticmethod
    def from_binary_feedback(
        data: List[Dict],
        prompt_field: str = "prompt",
        response_field: str = "response",
        feedback_field: str = "feedback"  # "good" or "bad"
    ) -> Dict[str, List[Dict]]:
        """
        Organize data for KTO training (binary feedback).

        Feedback is normalised with ``str()``, stripped and lower-cased, so
        ``True``/``1`` work as well as strings. "good", "positive", "1" and
        "true" are desirable (``label=True``); anything else is undesirable
        (``label=False``).
        """
        positive_values = {"good", "positive", "1", "true"}
        desirable = []
        undesirable = []

        for item in data:
            entry = {
                "prompt": item[prompt_field],
                "completion": item[response_field]
            }

            if str(item[feedback_field]).strip().lower() in positive_values:
                entry["label"] = True
                desirable.append(entry)
            else:
                entry["label"] = False
                undesirable.append(entry)

        return {
            "desirable": desirable,
            "undesirable": undesirable
        }

# Demo: build DPO pairs from three rated answers to one prompt
generator = PreferenceDataGenerator()

# Three answers to the same prompt, rated 3 < 4 < 5
rated_data = [
    {"prompt": "What is AI?", "response": answer, "rating": score}
    for answer, score in [
        ("AI is artificial intelligence.", 3),
        ("AI refers to machine learning systems.", 4),
        ("Comprehensive explanation...", 5),
    ]
]

pairs = generator.from_ratings(rated_data)
print("From Ratings:", json.dumps(pairs, indent=2), sep="\n")

## Exercise 4: Deduplication

In [None]:
import hashlib
from difflib import SequenceMatcher

class DataDeduplicator:
    """
    Remove duplicate and near-duplicate samples.

    Each method keeps the first record of every (near-)duplicate group,
    preserves input order, prints a one-line summary, and returns the kept
    records. Texts are read with ``str(item.get(text_field, ""))``, so a
    missing field is treated as an empty string.
    """

    @staticmethod
    def exact_dedup(data: List[Dict], text_field: str = "text") -> List[Dict]:
        """
        Remove exact duplicates using hash.

        MD5 is used only as a compact fingerprint of the text, not for
        security.
        """
        seen = set()
        unique = []

        for item in data:
            text = str(item.get(text_field, ""))
            text_hash = hashlib.md5(text.encode()).hexdigest()

            if text_hash not in seen:
                seen.add(text_hash)
                unique.append(item)

        print(f"Exact dedup: {len(data)} → {len(unique)} ({len(data)-len(unique)} removed)")
        return unique

    @staticmethod
    def _similarity_at_least(a: str, b: str, threshold: float) -> bool:
        """Return True iff ``SequenceMatcher(None, a, b).ratio() >= threshold``."""
        matcher = SequenceMatcher(None, a, b)
        # real_quick_ratio() and quick_ratio() are cheap upper bounds on
        # ratio(). Checking them first skips the expensive full match for
        # clearly different pairs without changing the result.
        return (
            matcher.real_quick_ratio() >= threshold
            and matcher.quick_ratio() >= threshold
            and matcher.ratio() >= threshold
        )

    @staticmethod
    def fuzzy_dedup(
        data: List[Dict],
        text_field: str = "text",
        threshold: float = 0.9
    ) -> List[Dict]:
        """
        Remove near-duplicates using similarity threshold.

        A record is dropped if its ``difflib.SequenceMatcher`` ratio against
        any already-kept record is ``>= threshold``.

        Note: O(n²) complexity - use MinHash for large datasets.
        """
        unique = []
        kept_texts = []  # text of each kept record, computed once

        for item in data:
            text = str(item.get(text_field, ""))
            is_duplicate = any(
                DataDeduplicator._similarity_at_least(text, kept, threshold)
                for kept in kept_texts
            )

            if not is_duplicate:
                unique.append(item)
                kept_texts.append(text)

        print(f"Fuzzy dedup: {len(data)} → {len(unique)} ({len(data)-len(unique)} removed)")
        return unique

    @staticmethod
    def minhash_dedup(
        data: List[Dict],
        text_field: str = "text",
        threshold: float = 0.8,
        num_perm: int = 128
    ) -> List[Dict]:
        """
        Efficient near-duplicate detection using MinHash LSH.

        Each text is represented by its set of whitespace-separated words,
        so word order and repetition are ignored. LSH candidates are
        approximate, so each candidate's estimated Jaccard similarity is
        checked against ``threshold`` before a record is dropped.

        If ``datasketch`` is not installed, prints a hint and returns
        ``data`` unchanged.

        Requires: pip install datasketch
        """
        try:
            from datasketch import MinHash, MinHashLSH
        except ImportError:
            print("Install datasketch: pip install datasketch")
            return data

        lsh = MinHashLSH(threshold=threshold, num_perm=num_perm)
        kept_hashes = {}  # LSH key -> MinHash of each kept record
        unique = []

        for i, item in enumerate(data):
            text = str(item.get(text_field, ""))

            # Create MinHash
            mh = MinHash(num_perm=num_perm)
            for word in text.split():
                mh.update(word.encode('utf-8'))

            # Check for duplicates, discarding LSH false-positive candidates
            candidates = lsh.query(mh)
            if any(kept_hashes[key].jaccard(mh) >= threshold for key in candidates):
                continue

            key = f"doc_{i}"
            lsh.insert(key, mh)
            kept_hashes[key] = mh
            unique.append(item)

        print(f"MinHash dedup: {len(data)} → {len(unique)} ({len(data)-len(unique)} removed)")
        return unique

# Demo: an exact duplicate and a one-word variant of the same sentence
deduplicator = DataDeduplicator()

_fox = "The quick brown fox jumps over the lazy dog."
test_data = [
    {"text": _fox},
    {"text": _fox},  # exact duplicate of the first record
    {"text": _fox.replace("jumps", "jumped")},  # near duplicate
    {"text": "A completely different sentence about something else."},
]

# Cheap exact pass first, then the fuzzy pass on whatever survives it
deduped = deduplicator.fuzzy_dedup(deduplicator.exact_dedup(test_data), threshold=0.85)

## Key Takeaways

1. **Format Conversion**: Auto-detect the source format and convert between Alpaca, ShareGPT, OpenAI/ChatML messages, and prompt–completion records
2. **Quality Filtering**: Length, quality, content, and language checks
3. **Preference Data**: Build chosen/rejected pairs (DPO/ORPO) from ratings or A/B comparisons, and labeled desirable/undesirable examples (KTO) from binary feedback
4. **Deduplication**: Exact hashing for identical texts, fuzzy string matching for small datasets (O(n²)), and MinHash LSH for large datasets