# Lab 3.1.4 Solutions: Dataset Preparation for Fine-Tuning

**Module:** 3.1 - Large Language Model Fine-Tuning  
**Difficulty:** ⭐⭐☆☆☆ (Beginner-Intermediate)  
**Exercises:** 3 (Multi-Format Converter, Chat Template Formatter, Data Quality Pipeline)

This notebook contains solutions for the dataset preparation exercises.

---

---

## Exercise 1 Solution: Multi-Format Dataset Converter

**Task:** Create a converter that handles Alpaca, ShareGPT, Dolly, Open Assistant, and custom (user-registered) formats.

In [None]:
import json
from typing import List, Dict, Any, Optional, Callable
from dataclasses import dataclass
from enum import Enum

class DatasetFormat(Enum):
    """Identifiers for the dataset layouts the converter distinguishes."""

    # Records with instruction / input / output fields
    ALPACA = "alpaca"
    # Records holding a list of role-tagged conversation messages
    SHAREGPT = "sharegpt"
    # Open Assistant records
    OASST = "oasst"
    # Databricks Dolly records
    DOLLY = "dolly"
    # Anything else; handled by user-defined parsers
    CUSTOM = "custom"


@dataclass
class ConversationTurn:
    """One message of a conversation: who spoke and what was said."""

    # Speaker label; this file uses 'system', 'user' and 'assistant'
    role: str
    # Text of the message
    content: str


@dataclass
class Conversation:
    """An ordered sequence of turns plus optional free-form metadata."""

    turns: List[ConversationTurn]
    metadata: Optional[Dict[str, Any]] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict view with 'conversations' and 'metadata' keys."""
        messages = []
        for turn in self.turns:
            messages.append({"role": turn.role, "content": turn.content})

        # Missing (or empty) metadata is reported as a fresh empty dict.
        extra = self.metadata if self.metadata else {}
        return {"conversations": messages, "metadata": extra}


class UniversalDatasetConverter:
    """
    Convert fine-tuning records from several source layouts into Conversation objects.

    Built-in parsers:
    - Alpaca (instruction/input/output)
    - ShareGPT (conversations)
    - Open Assistant (flattened prompt/response records or message trees)
    - Dolly (context/instruction/response)

    Other layouts can be supported by registering custom parsers.
    """

    def __init__(self):
        # name -> parser callable; tried in registration order for CUSTOM items
        self.custom_parsers: Dict[str, Callable] = {}

    def register_parser(
        self,
        name: str,
        parser: Callable[[Dict], Conversation]
    ):
        """
        Register a custom parser for a new format.

        Registering under an existing name replaces the previous parser.

        Args:
            name: Name of the format
            parser: Function that converts a dict to Conversation; it should
                raise an exception when the item does not match its format
        """
        self.custom_parsers[name] = parser

    def parse_alpaca(self, item: Dict) -> Conversation:
        """Parse Alpaca format (instruction, input, output).

        Missing keys are treated as empty strings. A non-empty ``input`` is
        appended to the instruction under an ``Input:`` label.
        """
        instruction = item.get("instruction", "")
        input_text = item.get("input", "")

        if input_text:
            user_content = f"{instruction}\n\nInput: {input_text}"
        else:
            user_content = instruction

        return Conversation(turns=[
            ConversationTurn(role="user", content=user_content),
            ConversationTurn(role="assistant", content=item.get("output", "")),
        ])

    def parse_sharegpt(self, item: Dict) -> Conversation:
        """Parse ShareGPT format (conversations list).

        Each message may use either ``from``/``value`` or ``role``/``content``
        keys. Known speaker labels are normalised to system/user/assistant;
        unknown labels are kept (lower-cased).
        """
        role_mapping = {
            "human": "user",
            "gpt": "assistant",
            "system": "system",
            "user": "user",
            "assistant": "assistant",
        }

        turns = []
        for conv in item.get("conversations", []):
            role = conv.get("from", conv.get("role", "user")).lower()
            role = role_mapping.get(role, role)
            content = conv.get("value", conv.get("content", ""))
            turns.append(ConversationTurn(role=role, content=content))

        return Conversation(turns=turns)

    def parse_dolly(self, item: Dict) -> Conversation:
        """Parse Dolly format (context, instruction, response).

        A non-empty ``context`` is placed before the instruction under a
        ``Context:`` label. The item's ``category`` (default ``"unknown"``)
        is stored in the conversation metadata.
        """
        context = item.get("context", "")
        instruction = item.get("instruction", "")

        if context:
            user_content = f"Context: {context}\n\n{instruction}"
        else:
            user_content = instruction

        return Conversation(
            turns=[
                ConversationTurn(role="user", content=user_content),
                ConversationTurn(role="assistant", content=item.get("response", "")),
            ],
            metadata={"category": item.get("category", "unknown")}
        )

    def parse_oasst(self, item: Dict) -> Conversation:
        """Parse Open Assistant format (flattened records or message trees).

        Items with a ``prompt`` key are treated as flattened prompt/response
        pairs. Otherwise the item is treated as the root of a message tree and
        only the main branch (the first reply at every level) is followed.
        """
        if "prompt" in item:
            # Flattened format
            return Conversation(turns=[
                ConversationTurn(role="user", content=item["prompt"]),
                ConversationTurn(role="assistant", content=item.get("response", "")),
            ])

        # Tree format: walk iteratively so that very deep reply chains cannot
        # exhaust the interpreter's recursion limit.
        turns = []
        node = item
        while True:
            role = "user" if node.get("role") == "prompter" else "assistant"
            turns.append(ConversationTurn(role=role, content=node.get("text", "")))

            replies = node.get("replies", [])
            if not replies:
                break
            node = replies[0]

        return Conversation(turns=turns)

    def detect_format(self, item: Dict) -> DatasetFormat:
        """Auto-detect the format of a dataset item from the keys it contains.

        Checks run in a fixed order (ShareGPT, Alpaca, Dolly, Open Assistant);
        anything that matches none of them is reported as CUSTOM.
        """
        if "conversations" in item:
            return DatasetFormat.SHAREGPT
        if "instruction" in item and "output" in item:
            return DatasetFormat.ALPACA
        if "instruction" in item and "response" in item:
            return DatasetFormat.DOLLY
        if "prompt" in item or "text" in item:
            return DatasetFormat.OASST
        return DatasetFormat.CUSTOM

    def convert(
        self,
        item: Dict,
        source_format: Optional[DatasetFormat] = None,
    ) -> Conversation:
        """
        Convert a single item to unified Conversation format.

        Args:
            item: Dataset item dictionary
            source_format: Format of the item (auto-detected if None)

        Returns:
            Conversation object

        Raises:
            ValueError: If the format is unknown, or the format is CUSTOM and
                every registered parser failed (the last parser error, if
                any, is attached as the exception's cause).
        """
        if source_format is None:
            source_format = self.detect_format(item)

        if source_format == DatasetFormat.ALPACA:
            return self.parse_alpaca(item)
        if source_format == DatasetFormat.SHAREGPT:
            return self.parse_sharegpt(item)
        if source_format == DatasetFormat.DOLLY:
            return self.parse_dolly(item)
        if source_format == DatasetFormat.OASST:
            return self.parse_oasst(item)
        if source_format == DatasetFormat.CUSTOM:
            # Try custom parsers in registration order; the first one that
            # returns without raising wins.
            last_error = None
            for parser in self.custom_parsers.values():
                try:
                    return parser(item)
                except Exception as exc:
                    # Only ordinary errors mean "this parser does not match";
                    # KeyboardInterrupt/SystemExit must propagate.
                    last_error = exc
            raise ValueError(
                "Unknown format and no custom parser matched"
            ) from last_error

        raise ValueError(f"Unknown format: {source_format}")

    def convert_dataset(
        self,
        items: List[Dict],
        source_format: Optional[DatasetFormat] = None,
    ) -> List[Conversation]:
        """Convert every item in a dataset, preserving order.

        When ``source_format`` is None the format is detected per item.
        """
        return [self.convert(item, source_format) for item in items]


# Demo: run one sample record per layout through detection and conversion
converter = UniversalDatasetConverter()

# Sample records, one for each layout exercised below
alpaca_item = dict(
    instruction="Explain what machine learning is.",
    input="",
    output="Machine learning is a branch of AI that enables computers to learn from data.",
)

sharegpt_item = {
    "conversations": [
        {"from": speaker, "value": text}
        for speaker, text in (
            ("human", "What is Python?"),
            ("gpt", "Python is a popular programming language."),
        )
    ]
}

dolly_item = dict(
    instruction="Summarize the context.",
    context="The quick brown fox jumps over the lazy dog.",
    response="A fox jumps over a dog.",
    category="summarization",
)

print("Format Detection and Conversion:")
print("=" * 50)

for label, record in zip(
    ("Alpaca", "ShareGPT", "Dolly"),
    (alpaca_item, sharegpt_item, dolly_item),
):
    fmt = converter.detect_format(record)
    converted = converter.convert(record)
    print(f"\n{label} -> {fmt.value}:")
    # Each turn is shown as its first 50 characters followed by an ellipsis
    for message in converted.turns:
        print(f"  [{message.role}]: {message.content[:50]}...")

---

## Exercise 2 Solution: Chat Template Formatter

**Task:** Create formatters for the major chat template formats (ChatML, Llama 3.1, Llama 2, Mistral, Vicuna, and Phi-3).

In [None]:
from typing import List, Dict, Optional
from abc import ABC, abstractmethod

class ChatTemplateFormatter(ABC):
    """Abstract interface shared by every chat template formatter."""

    @property
    @abstractmethod
    def name(self) -> str:
        """Display name of the template."""

    @abstractmethod
    def format(self, conversation: Conversation) -> str:
        """Render a conversation as a single string."""


class ChatMLFormatter(ChatTemplateFormatter):
    """Formatter for the ChatML layout (OpenAI style)."""

    @property
    def name(self) -> str:
        return "ChatML"

    def format(self, conversation: Conversation) -> str:
        """Wrap each turn in im_start/im_end markers, one block per line.

        An open assistant block is always appended so a model can continue
        generating from the returned string.
        """
        blocks = [
            f"<|im_start|>{turn.role}\n{turn.content}<|im_end|>"
            for turn in conversation.turns
        ]
        blocks.append("<|im_start|>assistant\n")
        return "\n".join(blocks)


class Llama3Formatter(ChatTemplateFormatter):
    """Formatter for the Llama 3.1 chat layout."""

    @property
    def name(self) -> str:
        return "Llama 3.1"

    def format(self, conversation: Conversation) -> str:
        """Emit a header block per turn, then an open assistant header.

        The turn's role is written into the header verbatim; no separators
        are inserted between blocks.
        """
        body = "".join(
            f"<|start_header_id|>{turn.role}<|end_header_id|>\n\n{turn.content}<|eot_id|>"
            for turn in conversation.turns
        )
        # Trailing assistant header leaves the string ready for generation
        generation_header = "<|start_header_id|>assistant<|end_header_id|>\n\n"
        return "<|begin_of_text|>" + body + generation_header


class Llama2Formatter(ChatTemplateFormatter):
    """Formatter for the Llama 2 chat layout."""

    def __init__(self, default_system: str = "You are a helpful assistant."):
        # Used whenever the conversation does not open with a system turn
        self.default_system = default_system

    @property
    def name(self) -> str:
        return "Llama 2"

    def format(self, conversation: Conversation) -> str:
        """Render user/assistant turns as [INST] blocks.

        A leading system turn (or the default system prompt) is embedded in
        the first user block. Turns with any other role are not emitted.
        """
        all_turns = conversation.turns

        # Split off a leading system turn, if there is one
        if all_turns and all_turns[0].role == "system":
            system_msg = all_turns[0].content
            remaining = all_turns[1:]
        else:
            system_msg = self.default_system
            remaining = all_turns

        pieces = ["<s>"]
        system_pending = True  # the system prompt goes into the first user block only
        for turn in remaining:
            if turn.role == "assistant":
                pieces.append(f" {turn.content} </s>")
            elif turn.role == "user":
                if system_pending:
                    pieces.append(f"[INST] <<SYS>>\n{system_msg}\n<</SYS>>\n\n{turn.content} [/INST]")
                    system_pending = False
                else:
                    pieces.append(f"<s>[INST] {turn.content} [/INST]")

        return "".join(pieces)


class MistralFormatter(ChatTemplateFormatter):
    """Formatter for the Mistral/Mixtral chat layout."""

    @property
    def name(self) -> str:
        return "Mistral"

    def format(self, conversation: Conversation) -> str:
        """Render user turns as [INST] blocks and assistant turns as completions.

        Only 'user' and 'assistant' turns are emitted; turns with any other
        role (including 'system') are skipped.
        """
        templates = {
            "user": "[INST] {} [/INST]",
            "assistant": "{}</s> ",
        }

        pieces = ["<s>"]
        for turn in conversation.turns:
            template = templates.get(turn.role)
            if template is not None:
                pieces.append(template.format(turn.content))

        return "".join(pieces)


class VicunaFormatter(ChatTemplateFormatter):
    """Formatter for the Vicuna chat layout."""

    @property
    def name(self) -> str:
        return "Vicuna"

    def format(self, conversation: Conversation) -> str:
        """Render turns as USER:/ASSISTANT: lines, ending with an open ASSISTANT: cue.

        System turns are written as bare text followed by a blank line; turns
        with any other unrecognised role are skipped.
        """
        templates = {
            "system": "{}\n\n",
            "user": "USER: {}\n",
            "assistant": "ASSISTANT: {}</s>\n",
        }

        pieces = []
        for turn in conversation.turns:
            template = templates.get(turn.role)
            if template is not None:
                pieces.append(template.format(turn.content))

        pieces.append("ASSISTANT:")
        return "".join(pieces)


class Phi3Formatter(ChatTemplateFormatter):
    """Formatter for the Phi-3 chat layout."""

    @property
    def name(self) -> str:
        return "Phi-3"

    def format(self, conversation: Conversation) -> str:
        """Render each known turn as a role tag, its content and an end marker.

        Only 'system', 'user' and 'assistant' turns are emitted; the result
        always ends with an open assistant tag.
        """
        known_roles = ("system", "user", "assistant")

        # For every supported role the emitted tag name equals the role itself
        pieces = [
            f"<|{turn.role}|>\n{turn.content}<|end|>\n"
            for turn in conversation.turns
            if turn.role in known_roles
        ]
        pieces.append("<|assistant|>\n")
        return "".join(pieces)


# Demo: render one sample conversation with every formatter
formatters = [
    ChatMLFormatter(),
    Llama3Formatter(),
    Llama2Formatter(),
    MistralFormatter(),
    VicunaFormatter(),
    Phi3Formatter(),
]

# Sample conversation: a system prompt followed by one user/assistant exchange
sample_conv = Conversation(
    turns=[
        ConversationTurn(role="system", content="You are a helpful AI assistant."),
        ConversationTurn(role="user", content="What is Python?"),
        ConversationTurn(role="assistant", content="Python is a programming language."),
    ]
)

print("Chat Template Comparison:")
print("=" * 60)

for template in formatters:
    print(f"\n--- {template.name} ---")
    rendered = template.format(sample_conv)
    # Show at most the first 300 characters, marking truncation with "..."
    preview = rendered
    if len(rendered) > 300:
        preview = rendered[:300] + "..."
    print(preview)

---

## Exercise 3 Solution: Data Quality Pipeline

**Task:** Build a complete data quality pipeline with filtering, deduplication, and validation.

In [None]:
import hashlib
import re
from typing import List, Dict, Set, Tuple, Callable
from dataclasses import dataclass, field
from collections import Counter

@dataclass
class QualityReport:
    """Summary counters describing one run of the data quality pipeline."""

    # Number of items before / after processing
    original_count: int
    final_count: int
    # Per-filter removal counts, keyed by filter name
    removed_by_filter: Dict[str, int] = field(default_factory=dict)
    # Items dropped by deduplication
    duplicates_removed: int = 0
    # Free-form warning messages; only their count appears in the summary
    warnings: List[str] = field(default_factory=list)

    def __str__(self):
        """Return a multi-line, human-readable summary of the report."""
        removed_total = self.original_count - self.final_count
        lines = [
            "Quality Report:",
            f"  Original: {self.original_count}",
            f"  Final: {self.final_count}",
            f"  Removed: {removed_total}",
            f"  - By filter: {self.removed_by_filter}",
            f"  - Duplicates: {self.duplicates_removed}",
            f"  Warnings: {len(self.warnings)}",
        ]
        return "\n".join(lines)


class DataQualityPipeline:
    """
    Data quality pipeline for fine-tuning datasets.

    Features:
    - Content filtering through user-supplied keep/drop predicates
    - Deduplication (exact and near-duplicate)
    - Per-run reporting of what was removed and why

    Type annotations that mention Conversation / QualityReport are written as
    forward references so the class can be defined regardless of the order in
    which notebook cells are executed.
    """

    def __init__(self):
        # (name, predicate) pairs, applied in insertion order
        self.filters: List[Tuple[str, Callable]] = []
        # Exact-duplicate hashes seen during the current process() run
        self.seen_hashes: Set[str] = set()

    def add_filter(
        self,
        name: str,
        filter_fn: "Callable[[Conversation], bool]"
    ):
        """
        Add a filter to the pipeline.

        Args:
            name: Name of the filter (used as the key in the report)
            filter_fn: Function that returns True to KEEP the item
        """
        self.filters.append((name, filter_fn))

    def compute_hash(self, conversation: "Conversation") -> str:
        """Compute an exact-duplicate fingerprint (MD5 hex digest).

        Roles and contents are joined with control-character separators so
        that text cannot "slide" across a turn boundary: turns ("ab", "c")
        and ("a", "bc") must not produce the same fingerprint.
        """
        payload = "\x1e".join(
            f"{t.role}\x1f{t.content}" for t in conversation.turns
        )
        return hashlib.md5(payload.encode()).hexdigest()

    def compute_minhash(self, text: str, num_hashes: int = 100) -> List[int]:
        """Compute a MinHash signature for near-duplicate detection.

        The text is lower-cased, split on whitespace and turned into 3-word
        shingles. Texts with only one or two words are treated as a single
        shingle so that distinct short texts do not all share one signature.
        Text with no words at all yields an all-zero signature.

        Hash values come from salted BLAKE2b rather than the builtin hash(),
        whose string hashing is randomised per interpreter run; signatures are
        therefore reproducible across processes. This costs one digest per
        (shingle, seed) pair.

        Args:
            text: Text to fingerprint
            num_hashes: Number of hash functions (signature length)

        Returns:
            List of ``num_hashes`` integers in the 32-bit unsigned range
        """
        words = text.lower().split()
        if len(words) >= 3:
            shingles = {" ".join(words[i:i + 3]) for i in range(len(words) - 2)}
        elif words:
            shingles = {" ".join(words)}
        else:
            return [0] * num_hashes

        # Encode once; the same bytes are hashed under every seed.
        encoded = [shingle.encode("utf-8") for shingle in shingles]

        signature = []
        for seed in range(num_hashes):
            salt = seed.to_bytes(8, "little")
            signature.append(min(
                int.from_bytes(
                    hashlib.blake2b(data, digest_size=4, salt=salt).digest(),
                    "little",
                )
                for data in encoded
            ))

        return signature

    def estimate_jaccard(self, sig1: List[int], sig2: List[int]) -> float:
        """Estimate Jaccard similarity from MinHash signatures.

        Returns the fraction of positions at which the signatures agree, or
        0.0 when they are empty or differ in length.
        """
        if not sig1 or len(sig1) != len(sig2):
            return 0.0
        matches = sum(1 for a, b in zip(sig1, sig2) if a == b)
        return matches / len(sig1)

    def process(
        self,
        conversations: "List[Conversation]",
        deduplicate: bool = True,
        near_duplicate_threshold: float = 0.8,
    ) -> "Tuple[List[Conversation], QualityReport]":
        """
        Process dataset through quality pipeline.

        Each conversation is checked against the filters in order; the first
        filter that rejects it is credited in the report. Survivors are then
        (optionally) checked for exact and near duplicates against everything
        kept so far, so the first occurrence wins.

        Args:
            conversations: Input conversations
            deduplicate: Whether to remove duplicates
            near_duplicate_threshold: Jaccard threshold for near-duplicates;
                a conversation is dropped when its estimated similarity to an
                already-kept one is strictly greater than this value

        Returns:
            Tuple of (filtered conversations, quality report)
        """
        report = QualityReport(
            original_count=len(conversations),
            final_count=0,
        )

        results = []
        self.seen_hashes = set()  # dedup state is per run, not per pipeline
        signatures = []  # MinHash signatures of the conversations kept so far

        for conv in conversations:
            # Apply filters; stop at the first one that rejects the item
            passed = True
            for filter_name, filter_fn in self.filters:
                if not filter_fn(conv):
                    report.removed_by_filter[filter_name] = (
                        report.removed_by_filter.get(filter_name, 0) + 1
                    )
                    passed = False
                    break

            if not passed:
                continue

            if deduplicate:
                # Exact dedup
                conv_hash = self.compute_hash(conv)
                if conv_hash in self.seen_hashes:
                    report.duplicates_removed += 1
                    continue
                self.seen_hashes.add(conv_hash)

                # Near-duplicate detection: pairwise comparison against every
                # kept signature
                content = " ".join(t.content for t in conv.turns)
                sig = self.compute_minhash(content)

                is_near_dup = any(
                    self.estimate_jaccard(sig, existing_sig) > near_duplicate_threshold
                    for existing_sig in signatures
                )
                if is_near_dup:
                    report.duplicates_removed += 1
                    continue

                signatures.append(sig)

            results.append(conv)

        report.final_count = len(results)
        return results, report


# Assemble the quality pipeline; every filter returns True to keep a conversation
pipeline = DataQualityPipeline()

# Filter: total characters across all turns must be at least 50
pipeline.add_filter(
    "min_length",
    lambda c: 50 <= sum(len(turn.content) for turn in c.turns),
)

# Filter: total characters across all turns must be at most 10000
pipeline.add_filter(
    "max_length",
    lambda c: 10000 >= sum(len(turn.content) for turn in c.turns),
)

# Filter: at least one assistant turn is present (user turns are not checked)
pipeline.add_filter(
    "has_response",
    lambda c: "assistant" in (turn.role for turn in c.turns),
)

# Filter: No excessive repetition
def check_repetition(
    conv: "Conversation",
    min_words: int = 10,
    max_ratio: float = 0.5,
) -> bool:
    """Return True to keep a conversation, False if any turn is too repetitive.

    A turn is considered repetitive when it has at least ``min_words``
    whitespace-separated words (case-insensitive) and its single most common
    word accounts for more than ``max_ratio`` of them. Shorter turns are never
    flagged, since a high ratio is normal in very short messages.

    Args:
        conv: Conversation whose turns are inspected
        min_words: Minimum word count (inclusive) for a turn to be checked
        max_ratio: Largest tolerated share of the most common word

    Returns:
        False as soon as one repetitive turn is found, otherwise True
    """
    for turn in conv.turns:
        words = turn.content.lower().split()
        # The bound is inclusive: a turn of exactly min_words words is checked.
        if not words or len(words) < min_words:
            continue
        top_count = Counter(words).most_common(1)[0][1]
        if top_count / len(words) > max_ratio:
            return False
    return True

pipeline.add_filter("no_repetition", check_repetition)

# Filter: reject conversations containing placeholder-looking text.
# NOTE(review): the pattern matches ANY [...], {...} or <...> span on a line,
# so ordinary code, markup or bracketed citations are rejected too - confirm
# this is intended for the target dataset.
pipeline.add_filter(
    "no_placeholders",
    lambda c: all(
        re.search(r'\[.*?\]|\{.*?\}|<.*?>', turn.content) is None
        for turn in c.turns
    ),
)

print("Quality Pipeline configured with filters:")
for filter_name, _predicate in pipeline.filters:
    print(f"  - {filter_name}")

In [None]:
# Exercise the pipeline on a small hand-built dataset

# Each entry illustrates one kind of quality issue (or a clean example)
test_conversations = [
    # Good conversation
    Conversation(
        turns=[
            ConversationTurn(role="user", content="What is machine learning?"),
            ConversationTurn(
                role="assistant",
                content="Machine learning is a branch of AI that enables computers to learn from data without being explicitly programmed.",
            ),
        ]
    ),
    # Too short
    Conversation(
        turns=[
            ConversationTurn(role="user", content="Hi"),
            ConversationTurn(role="assistant", content="Hello!"),
        ]
    ),
    # Missing response
    Conversation(
        turns=[
            ConversationTurn(role="user", content="This is a question without an answer."),
        ]
    ),
    # Duplicate of first
    Conversation(
        turns=[
            ConversationTurn(role="user", content="What is machine learning?"),
            ConversationTurn(
                role="assistant",
                content="Machine learning is a branch of AI that enables computers to learn from data without being explicitly programmed.",
            ),
        ]
    ),
    # Has placeholders
    Conversation(
        turns=[
            ConversationTurn(role="user", content="Tell me about [TOPIC]."),
            ConversationTurn(role="assistant", content="[RESPONSE HERE]"),
        ]
    ),
    # Excessive repetition
    Conversation(
        turns=[
            ConversationTurn(role="user", content="good good good good good good good good good good"),
            ConversationTurn(role="assistant", content="That's great!"),
        ]
    ),
    # Good conversation 2
    Conversation(
        turns=[
            ConversationTurn(role="user", content="Explain deep learning."),
            ConversationTurn(
                role="assistant",
                content="Deep learning is a subset of machine learning that uses neural networks with many layers to learn complex patterns from data.",
            ),
        ]
    ),
]

# Run every filter plus deduplication over the test data
filtered, report = pipeline.process(test_conversations, deduplicate=True)

print("\nQuality Pipeline Results:")
print("=" * 50)
print(report)

# List what survived, previewing the first 50 characters of each opening turn
print(f"\n\nKept {len(filtered)} conversations:")
for position, kept in enumerate(filtered, start=1):
    opening = kept.turns[0].content[:50]
    print(f"  {position}. {opening}...")

---

## Summary

These solutions demonstrate:

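1. **Universal Converter**: Convert Alpaca, ShareGPT, Dolly, and Open Assistant records (plus custom formats via registered parsers) into one unified conversation structure, with auto-detection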

2. **Chat Templates**: Format conversations for any model (Llama, Mistral, etc.)

3. **Quality Pipeline**: Clean data with filtering and deduplication

### Key Takeaways

- **Always validate** your data before fine-tuning
- **Deduplication** prevents overfitting on repeated examples
- **Chat templates** must match the model's training format
- **Quality > Quantity** for fine-tuning datasets