In [None]:
# Prints a status message only.
# NOTE(review): this cell comes before the "Setup" imports cell below, so it
# runs before json/re/random/typing/Counter are imported — consider moving it
# after the imports.
print('Setup complete.')

# Lab 02: Data Preparation and Augmentation for Fine-Tuning

## Learning Objectives
- Format datasets for different fine-tuning tasks (e.g., instruction-following, chat)
- Implement data cleaning techniques to improve data quality
- Apply data augmentation strategies to increase dataset size and diversity
- Understand how to handle class imbalance in classification tasks

## Setup

In [None]:
import json
import re
import random
from typing import List, Dict, Any
from collections import Counter

## Part 1: Data Formatting

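The format of your training data must match both the task and what your training framework expects. For example, instruction-tuning data is usually a single prompt/response string, while chat models expect a list of role-tagged messages.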

In [None]:
# Instruction-following format (e.g., Alpaca style)
def format_instruction(instruction: str, input_text: str, output: str) -> Dict[str, str]:
    """Render one example in the Alpaca-style instruction-following layout.

    Args:
        instruction: The task description shown to the model.
        input_text: Additional context for the task (may be empty).
        output: The target response the model should learn to produce.

    Returns:
        A dict with a single ``"text"`` key holding the full training string.
        Section headers are separated by blank lines.
    """
    # Interpolate exactly once. The previous version called ``.format()`` on the
    # already-interpolated f-string. That re-parsed any ``{``/``}`` in the data
    # as placeholders and raised KeyError/IndexError/ValueError.
    # It also lost the intended "\n\n" separators: they had been split into
    # backslash-newline continuations, which evaluate to nothing.
    text = (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        f"### Instruction:\n{instruction}\n\n"
        f"### Input:\n{input_text}\n\n"
        f"### Response:\n{output}"
    )
    return {"text": text}
# Demo: render a single English->French translation example and show it.
instruction_example = format_instruction(
    "Translate the following English text to French.",
    "Hello, how are you?",
    "Bonjour, comment ça va?",
)
print("--- Instruction-Following Format ---")
print(instruction_example["text"])

# Chat format (e.g., OpenAI's format)
def format_chat(messages: List[Dict[str, str]]) -> Dict[str, List[Dict[str, str]]]:
    """Wrap role/content messages in the chat-format envelope.

    The given list object is stored as-is (not copied) under ``"messages"``.
    """
    envelope = dict(messages=messages)
    return envelope

# Demo: a three-turn conversation in the chat-message format.
chat_example = format_chat([
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Who won the world series in 2020?"},
    {"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."}
])
# The leading "\n" separates this section from the previous output. The
# original split this literal across two physical lines, which is an
# unterminated-string SyntaxError.
print("\n--- Chat Format ---")
print(json.dumps(chat_example, indent=2))

## Part 2: Data Cleaning

In [None]:
# Sample of messy data
# Each row shows a different kind of noise for clean_data to deal with:
# surrounding whitespace, all-lowercase names, a "[REMOVE]" artifact marker,
# and an empty prompt.
messy_data = [
    {"prompt": "  What is 2+2?  ", "completion": "4.  "},  # leading/trailing whitespace
    {"prompt": "who is the first president of the USA?", "completion": "george washington"},  # lowercase proper name
    {"prompt": "Remove this part [REMOVE] and this", "completion": "This is good"},  # "[REMOVE]" artifact in prompt
    {"prompt": "", "completion": "Empty prompt is bad"},  # empty prompt; should be dropped
]

def clean_data(dataset: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """Return a cleaned copy of a prompt/completion dataset.

    Rules applied to each record, in order:
      1. Strip surrounding whitespace from prompt and completion. Missing or
         ``None`` fields are treated as empty strings.
      2. Remove a literal ``[REMOVE]`` marker and everything after it from
         the prompt.
      3. Drop the record if the prompt or completion is empty.
      4. Title-case the completion when the prompt mentions ``'president'``.
         This is a toy example of rule-based case normalization.

    Args:
        dataset: Records with ``"prompt"`` and ``"completion"`` keys.

    Returns:
        A new list of ``{"prompt", "completion"}`` dicts. The input records
        are not modified.
    """
    cleaned = []
    for item in dataset:
        # Treat missing/None fields as empty so the record is dropped below
        # instead of raising KeyError/AttributeError.
        prompt = (item.get('prompt') or '').strip()
        completion = (item.get('completion') or '').strip()

        # Rule 1: Remove the literal "[REMOVE]" marker and everything after it.
        # The old regex r'[REMOVE].*' was a character class matching any one of
        # R/E/M/O/V. It truncated prompts at the first such letter, so
        # "Remove this part ..." became "". str.partition matches the marker
        # literally and, unlike '.*', also spans newlines.
        prompt = prompt.partition('[REMOVE]')[0].strip()

        # Rule 2: Drop records with an empty prompt or completion. This runs
        # after artifact removal so that artifact-only prompts are dropped too.
        if not prompt or not completion:
            continue

        # Rule 3: Normalize case (example: capitalize completions).
        if 'president' in prompt:
            completion = completion.title()

        cleaned.append({"prompt": prompt, "completion": completion})
    return cleaned

# Run the cleaner over the sample rows and pretty-print the survivors.
cleaned_dataset = clean_data(messy_data)
print("--- Cleaned Dataset ---")
print(json.dumps(obj=cleaned_dataset, indent=2))

## Part 3: Data Augmentation

In [None]:
# Augmentation technique: Paraphrasing (mocked)
def paraphrase(text: str) -> str:
    """Return a canned paraphrase of ``text``, or ``text`` itself if none is known.

    This stands in for a model-based paraphraser: only two hard-coded prompts
    have rewrites.
    """
    known_rewrites = {
        "What is the capital of France?": "Which city is the capital of France?",
        "Tell me about photosynthesis.": "Explain the process of photosynthesis.",
    }
    if text in known_rewrites:
        return known_rewrites[text]
    # No rewrite available: fall back to the original text.
    return text

# Augmentation technique: Back-Translation (mocked)
def back_translate(text: str, lang_a='en', lang_b='fr') -> str:
    """Simulate a round-trip translation (lang_a -> lang_b -> lang_a).

    This is a mock built from two tiny lookup tables. ``lang_a`` and ``lang_b``
    are accepted but not used. Text without a known round-trip is returned
    unchanged.
    """
    en_to_fr = {"Hello world": "Bonjour le monde"}
    fr_to_en = {"Bonjour le monde": "Hi Earth"}
    # Forward leg: untranslatable text passes through as-is.
    pivot = en_to_fr.get(text, text)
    # Backward leg: fall back to the ORIGINAL text, not the pivot.
    if pivot in fr_to_en:
        return fr_to_en[pivot]
    return text

original_prompts = [
    {"prompt": "What is the capital of France?", "completion": "Paris"},
    {"prompt": "Hello world", "completion": "Hi!"}
]

# Keep every original, followed by its augmented variants.
# The mocked augmenters return their input unchanged when they have no
# rewrite. Appending their output blindly would add exact duplicates, which
# over-weights those examples in training without adding diversity. A variant
# is therefore only added when its prompt has not been seen yet.
augmented_data = []
seen_prompts = set()
for item in original_prompts:
    augmented_data.append(item)  # Add original
    seen_prompts.add(item['prompt'])
    # Add paraphrased and back-translated versions (when they are new).
    for augment in (paraphrase, back_translate):
        new_prompt = augment(item['prompt'])
        if new_prompt in seen_prompts:
            continue
        seen_prompts.add(new_prompt)
        augmented_data.append({"prompt": new_prompt, "completion": item['completion']})

print("--- Augmented Dataset ---")
print(json.dumps(augmented_data, indent=2))

## Part 4: Handling Class Imbalance

In [None]:
# Imbalanced dataset for a classification task
# Toy sentiment dataset, deliberately skewed toward "positive" (10 vs 2 rows)
# so balance_dataset has something to fix.
imbalanced_classification_data = [
    {"text": "This movie was amazing!", "label": "positive"}, # 10 positive
    {"text": "I loved it.", "label": "positive"},
    {"text": "Best film ever.", "label": "positive"},
    {"text": "So good.", "label": "positive"},
    {"text": "Incredible.", "label": "positive"},
    {"text": "Fantastic.", "label": "positive"},
    {"text": "A masterpiece.", "label": "positive"},
    {"text": "Highly recommended.", "label": "positive"},
    {"text": "Superb.", "label": "positive"},
    {"text": "Brilliant.", "label": "positive"},
    {"text": "This was terrible.", "label": "negative"}, # 2 negative
    {"text": "I hated it.", "label": "negative"}
]

def balance_dataset(
    dataset: List[Dict[str, Any]],
    *,
    rng: "random.Random | None" = None,
) -> List[Dict[str, Any]]:
    """Balance a classification dataset by random over-sampling.

    Every minority label is topped up to the size of the largest label by
    sampling its examples with replacement. All original examples are kept,
    and the combined list is shuffled.

    Args:
        dataset: Examples that each carry a ``"label"`` key.
        rng: Optional ``random.Random`` instance for reproducible sampling and
            shuffling. Defaults to the module-level ``random`` functions.

    Returns:
        A new list with an equal number of examples per label. Duplicated
        examples are the same dict objects as in ``dataset``, not copies.
        An empty ``dataset`` yields an empty list.
    """
    if not dataset:
        # max() over no labels would raise ValueError.
        return []
    sampler = random if rng is None else rng

    # Group once (first-seen label order), instead of re-scanning the whole
    # dataset for every minority label.
    by_label: Dict[Any, List[Dict[str, Any]]] = {}
    for item in dataset:
        by_label.setdefault(item['label'], []).append(item)

    max_count = max(len(items) for items in by_label.values())

    balanced_data = []
    for items_for_label in by_label.values():
        shortfall = max_count - len(items_for_label)
        if shortfall > 0:
            # Oversample by choosing with replacement.
            balanced_data.extend(sampler.choices(items_for_label, k=shortfall))

    balanced_data.extend(dataset)
    sampler.shuffle(balanced_data)
    return balanced_data

balanced_data = balance_dataset(imbalanced_classification_data)

print("--- Class Imbalance Handling ---")
# The outer f-string uses double quotes. Reusing the same quote character
# inside a replacement field is a SyntaxError before Python 3.12 (PEP 701).
print(f"Original counts: {Counter(item['label'] for item in imbalanced_classification_data)}")
print(f"Balanced counts: {Counter(item['label'] for item in balanced_data)}")

## Exercises

1. **Create a New Data Formatter**: Write a function to format data for a summarization task, with `{"text": "<document>", "summary": "<summary>"}` fields.
2. **Add More Cleaning Rules**: Extend the `clean_data` function to handle other issues, such as removing URLs or standardizing punctuation.
3. **Implement Under-sampling**: Write a function to balance the dataset by under-sampling the majority class instead of over-sampling the minority class. What are the pros and cons of this approach?

## Summary

You learned:
- How to format data for common fine-tuning tasks.
- The importance of cleaning data to remove noise and inconsistencies.
- How to augment your dataset to improve model generalization.
- A simple technique to handle class imbalance in classification datasets.