# 02 - Prepare Messages for Training

This notebook handles loading CSV messages and preparing them for model fine-tuning.

## Overview
- **Load standardized CSV messages** from notebook 01
- **Normalize speakers** using user configuration (user → A:, others → B:)
- **Merge consecutive messages** from the same speaker when they are less than 3 minutes apart
- **Segment conversations** (≥ 30 min gaps)
- **Generate rolling windows** whose context ends with a B: message and whose target is the following A: reply
- **Format as JSONL** for fine-tuning following PREPARE.md specifications
- **Split into train/validation sets** chronologically within each conversation (90% / 10%)

## Input Data
- **CSV files** with columns: `timestamp`, `sender`, `message`
- From `data/cleaned/` folder (generated by notebook 01)
- **User configuration** from `config/user_config.json`

## Output Data
- **Training dataset**: `data/processed/train.jsonl`
- **Validation dataset**: `data/processed/val.jsonl`
- **Dataset statistics** and metrics
- **PREPARE.md compliant** training samples

## Key Features
- **User-configurable** speaker identification
- **PREPARE.md compliant** windowing (context ends with B:, target is A:)
- **Chronological data split** within each conversation (90% train, 10% validation)


In [None]:
import os
import json
import re
import random
from datetime import datetime, timedelta
from typing import List, Dict, Any, Tuple
from collections import Counter
import pandas as pd

def load_user_config(config_path: str = 'config/user_config.json') -> Dict[str, Any]:
    """Load the configuration that identifies the user's own messages.

    If ``config_path`` does not exist, a default configuration is written to
    it (creating the parent directory when the path has one) and returned.

    Args:
        config_path: Path of the JSON configuration file.

    Returns:
        The configuration dict. The default one has the keys
        ``user_identifiers``, ``user_label``, ``other_label`` and
        ``description``.
    """
    if not os.path.exists(config_path):
        # Create default config if it doesn't exist
        default_config = {
            "user_identifiers": ["Antonin", "antonin", "Anto", "anto"],
            "user_label": "A:",
            "other_label": "B:",
            "description": "Configuration for identifying the user's messages in chat data."
        }
        config_dir = os.path.dirname(config_path)
        # A bare filename has dirname '' and os.makedirs('') raises
        # FileNotFoundError, so only create a directory when there is one.
        if config_dir:
            os.makedirs(config_dir, exist_ok=True)
        with open(config_path, 'w', encoding='utf-8') as f:
            json.dump(default_config, f, indent=2)
        print(f"Created default user config at {config_path}")
        return default_config

    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)

    print(f"Loaded user config: {config['user_identifiers']} -> {config['user_label']}")
    return config

def load_csv_messages(data_dir: str = 'data/cleaned') -> Tuple[Dict[str, List[Dict[str, Any]]], Dict[str, str]]:
    """Load per-conversation CSV messages from ``data_dir``, keeping conversations separate.

    Every ``<name>_messages.csv`` file becomes one conversation called
    ``<name>``. An optional ``conversation_languages.csv`` (columns
    ``conversation`` and ``language``) maps conversation names to languages.

    Args:
        data_dir: Directory containing the cleaned CSV files.

    Returns:
        ``(conversations, conversation_languages)``:
        ``conversations`` maps conversation name -> list of message dicts
        (one per CSV row, plus ``conversation`` and ``language`` keys) sorted
        by timestamp. ``conversation_languages`` maps name -> language.
    """
    suffix = '_messages.csv'
    conversations = {}
    conversation_languages = {}

    # Load language mapping
    language_file = os.path.join(data_dir, 'conversation_languages.csv')
    if os.path.exists(language_file):
        lang_df = pd.read_csv(language_file)
        conversation_languages = dict(zip(lang_df['conversation'], lang_df['language']))

    # os.listdir order is arbitrary; sorting makes the conversation order
    # (and therefore the downstream train/val split) deterministic.
    for filename in sorted(os.listdir(data_dir)):
        if not filename.endswith(suffix):
            continue
        file_path = os.path.join(data_dir, filename)
        # Keep every cell as text and treat only empty cells as missing.
        # pandas' default NA list would otherwise turn chat messages such as
        # "NA", "null" or "None" into NaN, and numeric-looking senders (phone
        # numbers) into ints that break string handling downstream.
        df = pd.read_csv(file_path, dtype=str, keep_default_na=False, na_values=[''])

        # Convert to list of dictionaries
        messages = df.to_dict('records')

        # Strip only the trailing suffix; str.replace would also remove any
        # earlier occurrence inside the conversation name.
        conv_name = filename[:-len(suffix)]
        language = conversation_languages.get(conv_name, 'unknown')
        for msg in messages:
            msg['conversation'] = conv_name
            msg['language'] = language

        # Sort by timestamp. Rows with a missing timestamp go last instead of
        # raising TypeError when NaN is compared with a string. This assumes the
        # timestamps share one ISO-8601 format, so string order is chronological.
        messages.sort(key=lambda m: (pd.isna(m['timestamp']), str(m['timestamp'])))
        conversations[conv_name] = messages

    return conversations, conversation_languages

def normalize_speakers(messages: List[Dict[str, Any]], user_config: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Replace each message's sender with the configured user/other label.

    A sender counts as the user when any entry of
    ``user_config['user_identifiers']`` appears in it as a whole word,
    case-insensitively. For example, ``"Anto"`` matches ``"Anto"`` and
    ``"Anto Dupont"`` but not ``"Santos"`` or ``"Antoine"``.

    Args:
        messages: Message dicts with a ``sender`` key. They are not modified.
        user_config: Dict with ``user_identifiers``, ``user_label`` and
            ``other_label``.

    Returns:
        Shallow copies of ``messages`` whose ``sender`` is replaced by
        ``user_label`` or ``other_label``.
    """
    user_label = user_config['user_label']
    other_label = user_config['other_label']

    # Match whole words only. A plain substring test lets short identifiers
    # such as "anto" match unrelated names ("Santos"), silently labelling other
    # people's messages as the user's. Empty identifiers are dropped because
    # they would match every sender.
    identifiers = [str(ident).strip() for ident in user_config['user_identifiers']]
    identifiers = [ident for ident in identifiers if ident]
    user_pattern = None
    if identifiers:
        alternatives = '|'.join(re.escape(ident) for ident in identifiers)
        user_pattern = re.compile(rf'(?<!\w)(?:{alternatives})(?!\w)', re.IGNORECASE)

    normalized_messages = []
    for msg in messages:
        normalized_msg = msg.copy()
        raw_sender = msg['sender']
        # Missing senders (NaN from the CSV) become '' instead of crashing on .strip().
        sender = '' if pd.isna(raw_sender) else str(raw_sender).strip()
        is_user = user_pattern is not None and user_pattern.search(sender) is not None
        normalized_msg['sender'] = user_label if is_user else other_label
        normalized_messages.append(normalized_msg)

    return normalized_messages

def merge_consecutive_messages(messages: List[Dict[str, Any]], max_gap_minutes: int = 3) -> List[Dict[str, Any]]:
    """Merge runs of same-speaker messages sent less than ``max_gap_minutes`` apart.

    Merged texts are joined with a single space, and empty texts are skipped.
    A merged message keeps the fields of the first message in its run, except
    ``timestamp``. That field is updated to the last merged message's
    timestamp, so the gap is always measured between neighbouring messages.
    Missing (NaN/None) texts are treated as ''. Two messages are never merged
    when either timestamp cannot be parsed or compared.

    Args:
        messages: Chronologically ordered message dicts with ``sender``,
            ``message`` and ISO-8601 ``timestamp`` keys. They are not modified.
        max_gap_minutes: Exclusive upper bound on the gap between merged messages.

    Returns:
        A new list of copied (and possibly merged) message dicts. An empty
        input is returned as is.
    """
    if not messages:
        return messages

    max_gap_seconds = max_gap_minutes * 60

    def parse_time(value: Any) -> Any:
        """Return a datetime for an ISO-8601 string, or None if it cannot be parsed."""
        if not isinstance(value, str):
            return None
        try:
            # fromisoformat() only understands a trailing 'Z' from Python 3.11 on.
            return datetime.fromisoformat(value.replace('Z', '+00:00'))
        except ValueError:
            return None

    def clean_text(value: Any) -> str:
        """Return the message text, mapping missing values (NaN/None) to ''."""
        return '' if pd.isna(value) else str(value)

    def within_gap(earlier: Dict[str, Any], later: Dict[str, Any]) -> bool:
        """True when ``later`` was sent less than the allowed gap after ``earlier``."""
        start = parse_time(earlier.get('timestamp'))
        end = parse_time(later.get('timestamp'))
        if start is None or end is None:
            return False
        try:
            return (end - start).total_seconds() < max_gap_seconds
        except TypeError:
            # One timestamp is timezone-aware and the other is naive.
            return False

    merged_messages = []
    current_message = dict(messages[0], message=clean_text(messages[0]['message']))

    for msg in messages[1:]:
        if current_message['sender'] == msg['sender'] and within_gap(current_message, msg):
            # Join only the non-empty texts, separated by one space.
            parts = (current_message['message'], clean_text(msg['message']))
            current_message['message'] = ' '.join(part for part in parts if part)
            # Keep the latest timestamp so the next gap is measured from this message.
            current_message['timestamp'] = msg['timestamp']
        else:
            merged_messages.append(current_message)
            current_message = dict(msg, message=clean_text(msg['message']))

    # Flush the last run.
    merged_messages.append(current_message)
    return merged_messages

def segment_conversations(messages: List[Dict[str, Any]], gap_minutes: int = 30) -> List[List[Dict[str, Any]]]:
    """Split a chronological message list into segments at long silences.

    A new segment starts whenever two neighbouring messages are at least
    ``gap_minutes`` apart. If either timestamp is missing, unparsable, or
    not comparable with the other (naive vs timezone-aware), the message
    stays in the current segment.

    Args:
        messages: Chronologically ordered message dicts with an ISO-8601
            ``timestamp`` key.
        gap_minutes: Minimum gap (inclusive) that starts a new segment.

    Returns:
        A list of non-empty segments, each a list of the original message
        dicts. Returns ``[]`` for empty input.
    """
    if not messages:
        return []

    gap_seconds = gap_minutes * 60

    def parse_time(value: Any) -> Any:
        """Return a datetime for an ISO-8601 string, or None if it cannot be parsed."""
        if not isinstance(value, str):
            return None
        try:
            # fromisoformat() only understands a trailing 'Z' from Python 3.11 on.
            return datetime.fromisoformat(value.replace('Z', '+00:00'))
        except ValueError:
            return None

    # Parse each timestamp once instead of twice (once as "previous", once as "current").
    times = [parse_time(msg.get('timestamp')) for msg in messages]

    segments = []
    current_segment = [messages[0]]
    for prev_time, curr_time, msg in zip(times, times[1:], messages[1:]):
        starts_new_segment = False
        if prev_time is not None and curr_time is not None:
            try:
                starts_new_segment = (curr_time - prev_time).total_seconds() >= gap_seconds
            except TypeError:
                # Naive vs timezone-aware timestamps: keep them in the same segment.
                starts_new_segment = False

        if starts_new_segment:
            segments.append(current_segment)
            current_segment = [msg]
        else:
            current_segment.append(msg)

    # The last segment always contains at least one message.
    segments.append(current_segment)
    return segments

def process_single_conversation(messages: List[Dict[str, Any]], user_config: Dict[str, Any],
                                merge_gap_minutes: int = 3,
                                segment_gap_minutes: int = 30) -> List[List[Dict[str, Any]]]:
    """Normalize speakers, merge same-speaker bursts, then segment one conversation.

    Args:
        messages: Chronologically ordered raw message dicts of one conversation.
        user_config: Speaker configuration passed to ``normalize_speakers``.
        merge_gap_minutes: Consecutive same-speaker messages less than this
            many minutes apart are merged.
        segment_gap_minutes: A gap of at least this many minutes starts a new
            segment.

    Returns:
        A list of segments, each a list of processed message dicts. The
        previous annotation claimed a flat list of messages, which was wrong.
    """
    normalized = normalize_speakers(messages, user_config)
    merged = merge_consecutive_messages(normalized, max_gap_minutes=merge_gap_minutes)
    return segment_conversations(merged, gap_minutes=segment_gap_minutes)

def generate_rolling_windows(segments: List[List[Dict[str, Any]]],
                             window_sizes: Tuple[int, ...] = (6, 8, 10, 12),
                             stride: int = 3,
                             user_label: str = 'A:',
                             other_label: str = 'B:') -> List[Dict[str, Any]]:
    """Generate context/target windows where the other speaker talks last and the user replies.

    For each segment and each window size ``w``, windows start at
    ``0, stride, 2*stride, ...`` and cover ``segment[start:start + w]``. A
    window is kept only when its last message is from ``other_label`` and
    the following message (the target) is from ``user_label``.

    Within a segment, windows are returned in chronological order of their
    target message, with ties ordered by window size. Taking the tail of the
    result therefore really selects the latest windows.

    Args:
        segments: Conversation segments, i.e. lists of message dicts with
            ``sender`` and ``conversation`` keys.
        window_sizes: Context lengths (number of messages) to try. Any
            sequence of ints is accepted; sizes below 1 are ignored.
        stride: Step between consecutive window starts. Must be >= 1.
        user_label: Sender label of the user, i.e. the speaker to predict.
        other_label: Sender label of the other speaker.

    Returns:
        A list of dicts with the keys ``window``, ``next_message``,
        ``window_size``, ``segment_id`` (the ``id()`` of the segment list)
        and ``conversation``.

    Raises:
        ValueError: If ``stride`` is less than 1.
    """
    if stride < 1:
        raise ValueError(f"stride must be >= 1, got {stride}")

    windows = []
    for segment in segments:
        if len(segment) < 2:  # Need at least one context message and a target
            continue

        # Collect (target_index, window_size, start) for every valid window,
        # then emit them sorted by target so the output is chronological.
        candidates = []
        for window_size in window_sizes:
            if window_size < 1:
                continue
            # The range end guarantees segment[start + window_size] exists.
            for start in range(0, len(segment) - window_size, stride):
                target_index = start + window_size
                if (segment[target_index - 1]['sender'] == other_label and
                        segment[target_index]['sender'] == user_label):
                    candidates.append((target_index, window_size, start))
        candidates.sort()

        for target_index, window_size, start in candidates:
            window = segment[start:target_index]
            windows.append({
                'window': window,
                'next_message': segment[target_index],
                'window_size': window_size,
                'segment_id': id(segment),
                'conversation': window[0]['conversation'],
            })

    return windows

def format_window_as_jsonl(window_data: Dict[str, Any], eos_token: str = '</s>') -> str:
    """Serialize a window as one JSONL line for base-model completion training.

    The ``text`` field contains one ``"<sender> <message>"`` line per context
    message. It is followed by the target's line, ending with ``eos_token``.

    Args:
        window_data: Dict with ``window`` (a list of message dicts) and
            ``next_message`` (the target message dict, whose ``sender`` label
            is used for the final line).
        eos_token: End-of-sequence marker appended after the target text.
            NOTE(review): the default ``'</s>'`` must match the EOS token of
            the tokenizer being fine-tuned; confirm for the target model.

    Returns:
        A JSON object string ``{"text": ...}`` without a trailing newline.
        Non-ASCII characters are kept as-is.
    """
    window = window_data['window']
    next_message = window_data['next_message']

    chat_text = '\n'.join(f"{msg['sender']} {msg['message']}" for msg in window)

    # Use the target's own label instead of a hard-coded 'A:' so a custom
    # user label stays consistent with the context lines.
    target_line = f"{next_message['sender']} {next_message['message']}{eos_token}"
    sample = {"text": f"{chat_text}\n{target_line}"}

    return json.dumps(sample, ensure_ascii=False)

def create_instruct_seeds(messages: List[Dict[str, Any]], user_config: Dict[str, Any], sample_rate: float = 0.1) -> List[str]:
    """Create prompt-only "instruct seed" samples from a random subset of the user's replies.

    About ``sample_rate`` of the user's messages are sampled with the
    ``random`` module (at least one, if any exist). For each sampled message,
    the closest earlier message from the other speaker in the same
    conversation becomes the ``<seed>`` line. The sample text ends with the
    bare user label, with no reply and no end-of-sequence marker. Sampled
    messages without such a predecessor are skipped.

    Args:
        messages: Chronologically ordered message dicts with ``sender`` and
            ``message`` keys, and optionally ``conversation``.
        user_config: Dict with ``user_label`` and ``other_label``.
        sample_rate: Fraction of the user's messages to sample.

    Returns:
        JSON strings of the form ``{"text": ...}``, in chronological order.
    """
    user_label = user_config['user_label']
    other_label = user_config['other_label']

    # Work with positions rather than dicts: list.index() costs O(n) per
    # lookup and returns the first *equal* dict, which picks the wrong
    # context when two messages are identical.
    user_indices = [i for i, msg in enumerate(messages) if msg['sender'] == user_label]
    if not user_indices:
        return []

    num_samples = min(max(1, int(len(user_indices) * sample_rate)), len(user_indices))
    # Sorted so the seeds come out in chronological order.
    sampled_indices = sorted(random.sample(user_indices, num_samples))

    seeds = []
    for msg_index in sampled_indices:
        conversation = messages[msg_index].get('conversation')
        prev_other_message = None
        for j in range(msg_index - 1, -1, -1):
            candidate = messages[j]
            # Never borrow context from a different conversation.
            if candidate.get('conversation') != conversation:
                break
            if candidate['sender'] == other_label:
                prev_other_message = candidate
                break

        if prev_other_message is not None:
            seed_text = f"<sys>Write a realistic text chat. Keep it short.</sys>\n<seed>{other_label} {prev_other_message['message']}</seed>\n{user_label}"
            seeds.append(json.dumps({"text": seed_text}, ensure_ascii=False))

    return seeds

def split_dataset(messages: List[Dict[str, Any]], train_ratio: float = 0.9) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """Cut an already-ordered list into a training head and a validation tail.

    The first ``int(len(messages) * train_ratio)`` items go to training and
    the remainder to validation. Item order is preserved, so a chronological
    input yields a chronological split.

    Returns:
        ``(train_messages, val_messages)`` as new lists.
    """
    cut = int(len(messages) * train_ratio)
    return messages[:cut], messages[cut:]

def save_jsonl_dataset(samples: List[str], filename: str, output_dir: str = 'data/processed') -> None:
    """Write pre-serialized JSONL samples, one per line, to ``output_dir/filename``.

    The output directory is created if needed. An existing file is overwritten.

    Args:
        samples: JSON strings without trailing newlines.
        filename: Name of the file inside ``output_dir``.
        output_dir: Destination directory. Defaults to the notebook's
            ``data/processed``; use '' for the current directory.
    """
    # os.makedirs('') raises FileNotFoundError, so skip it for the current directory.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    filepath = os.path.join(output_dir, filename)

    with open(filepath, 'w', encoding='utf-8') as f:
        f.writelines(sample + '\n' for sample in samples)

    print(f"Saved {len(samples)} samples to {filepath}")

def calculate_dataset_stats(windows: List[Dict[str, Any]], train_samples: List[str], val_samples: List[str]) -> None:
    """Print summary statistics for the generated windows and the train/val split.

    Args:
        windows: Window dicts; the ``segment_id`` and ``window`` keys are read.
        train_samples: Serialized training samples.
        val_samples: Serialized validation samples.
    """
    n_train = len(train_samples)
    n_val = len(val_samples)
    n_total = n_train + n_val

    print("\n=== DATASET STATISTICS ===")
    # Only segments that produced at least one window are visible from here.
    print(f"Number of segments with windows: {len({w['segment_id'] for w in windows})}")
    print(f"Number of training samples: {n_train}")
    print(f"Number of validation samples: {n_val}")

    # Context length = total number of characters of the messages in a window.
    context_lengths = sorted(
        sum(len(str(msg['message'])) for msg in window_data['window'])
        for window_data in windows
    )
    if context_lengths:
        mid = len(context_lengths) // 2
        if len(context_lengths) % 2:
            median_length = context_lengths[mid]
        else:
            # Even count: the median is the mean of the two middle values.
            median_length = (context_lengths[mid - 1] + context_lengths[mid]) / 2
        print(f"Median context length: {median_length:.1f} characters")

    print(f"Total samples: {n_total}")
    if n_total:
        print(f"Train/Val ratio: {n_train}/{n_val} ({n_train / n_total * 100:.1f}%/{n_val / n_total * 100:.1f}%)")
    else:
        # Avoid a ZeroDivisionError when nothing was generated.
        print("Train/Val ratio: 0/0 (no samples)")


## Step 1: Load CSV Messages

Load the standardized CSV messages from notebook 01 and prepare them for training.


In [None]:
# Load CSV messages from notebook 01 - keeping conversations separate
print("Loading CSV messages from notebook 01...")
conversations, conversation_languages = load_csv_messages()

print(f"\n=== LOADED DATA SUMMARY ===")
total_messages = sum(len(messages) for messages in conversations.values())
print(f"Total conversations: {len(conversations)}")
print(f"Total messages: {total_messages:,}")
# Counts come from conversation_languages.csv and may include conversations without a messages file.
print(f"Languages detected: {dict(Counter(conversation_languages.values()))}")

# Show conversation breakdown
print(f"\n=== CONVERSATION BREAKDOWN ===")
for conv_name, messages in conversations.items():
    language = conversation_languages.get(conv_name, 'unknown')
    print(f"  {conv_name}: {len(messages):,} messages ({language})")

# Show sample of loaded data
print(f"\n=== SAMPLE MESSAGES ===")
if conversations:
    first_conv = next(iter(conversations.values()))
    for i, msg in enumerate(first_conv[:5]):
        # str() guards against NaN (float) message cells, which cannot be sliced.
        print(f"{i+1}. [{msg['sender']}]: {str(msg['message'])[:100]}...")
else:
    # next(iter({}.values())) would raise StopIteration here.
    print("No *_messages.csv files found in data/cleaned - run notebook 01 first.")


## Step 2: Normalize Speakers and Process Messages

Normalize speaker names, merge consecutive messages from the same speaker, and segment each conversation at gaps of 30 minutes or more.


In [None]:
# Step 2: turn every loaded conversation into time-based segments of
# speaker-normalized, merged messages. `all_segments` (the flat list of all
# segments, across conversations) is what the window-generation step reads.

# Load user configuration
print("Loading user configuration...")
user_config = load_user_config()

# Process each conversation separately
# (so that no segment ever contains messages from two different conversations)
print("\nProcessing conversations separately...")
all_segments = []
total_original = 0
total_processed = 0

for conv_name, messages in conversations.items():
    print(f"\nProcessing {conv_name}...")
    
    # Process this conversation: normalize speakers, merge same-speaker
    # bursts, then split at long gaps.
    segments = process_single_conversation(messages, user_config)
    
    # Flatten segments to get message count (after merging, so it can be
    # smaller than the raw count)
    processed_messages = [msg for segment in segments for msg in segment]
    
    print(f"  {conv_name}: {len(messages)} → {len(processed_messages)} messages, {len(segments)} segments")
    
    total_original += len(messages)
    total_processed += len(processed_messages)
    all_segments.extend(segments)

print(f"\n=== PROCESSING SUMMARY ===")
print(f"Total messages: {total_original} → {total_processed}")
print(f"Total segments: {len(all_segments)}")

# Show segment statistics (segment length = number of merged messages)
segment_lengths = [len(seg) for seg in all_segments]
if segment_lengths:
    print(f"Segment lengths: min={min(segment_lengths)}, max={max(segment_lengths)}, avg={sum(segment_lengths)/len(segment_lengths):.1f}")

# Show sample of processed messages from first segment
print(f"\n=== SAMPLE PROCESSED MESSAGES ===")
if all_segments:
    first_segment = all_segments[0]
    print(f"Sample from conversation: {first_segment[0]['conversation']}")
    for i, msg in enumerate(first_segment[:5]):
        print(f"{i+1}. [{msg['sender']}]: {msg['message'][:100]}...")


## Step 3: Generate Training Windows

Generate rolling windows, format them as JSONL training samples, and create instruct-seed samples from about 10% of the user's replies.


In [None]:
# Step 3: build context/target windows from the segments, serialize them,
# and sample instruct seeds. `windows` is consumed by the split/save step.

# Generate rolling windows from segments
# NOTE(review): no label arguments are passed, so windows are matched on the
# 'A:'/'B:' sender labels; if user_config defines other labels, no windows
# will be found - confirm the config uses the defaults.
print("Generating rolling windows...")
windows = generate_rolling_windows(all_segments, window_sizes=[6, 8, 10, 12], stride=3)
print(f"✓ Generated {len(windows)} training windows")

# Show window statistics
# (this rebinds the module-level name `window_sizes` to one entry per window)
window_sizes = [w['window_size'] for w in windows]
window_size_counts = Counter(window_sizes)
print(f"Window size distribution: {dict(window_size_counts)}")

# Show conversation distribution in windows
conv_counts = Counter(w['conversation'] for w in windows)
print(f"Windows per conversation: {dict(conv_counts)}")

# Format windows as JSONL training samples
print("\nFormatting training samples...")
training_samples = []
for window_data in windows:
    sample = format_window_as_jsonl(window_data)
    training_samples.append(sample)

print(f"✓ Created {len(training_samples)} training samples")

# Create instruct seed samples (10% of user's replies)
# NOTE: the sample is drawn with the `random` module, so rerunning this
# cell produces a different set of seeds.
print("\nCreating instruct seed samples...")
# Flatten all processed messages for instruct seeds
all_processed_messages = [msg for segment in all_segments for msg in segment]
instruct_seeds = create_instruct_seeds(all_processed_messages, user_config, sample_rate=0.1)
print(f"✓ Created {len(instruct_seeds)} instruct seed samples")

# Combine regular training samples with instruct seeds
# (only used for the count printed below; the saved train/val files are
# built separately in the next step)
all_training_samples = training_samples + instruct_seeds
print(f"✓ Total training samples: {len(all_training_samples)}")

# Show sample training data
print(f"\n=== SAMPLE TRAINING DATA ===")
print("Regular training sample:")
print(training_samples[0] if training_samples else "No samples")
print("\nInstruct seed sample:")
print(instruct_seeds[0] if instruct_seeds else "No seeds")


## Step 4: Split Dataset and Save

Split each conversation's windows chronologically (90% train / 10% validation), add the instruct seeds, and save the training and validation files.


In [None]:
# Step 4: split the windows of each conversation chronologically
# (90% train / 10% val), attach the instruct seeds from Step 3, and write
# train.jsonl / val.jsonl.
print("\nSplitting windows by conversation...")

train_ratio = 0.9

# Exact position of every processed message, in conversation order.
# Windows hold references to the message dicts stored in all_segments, so
# looking up a window's target by identity gives its chronological position
# without depending on the timestamp string format.
message_position = {
    id(message): position
    for position, message in enumerate(m for segment in all_segments for m in segment)
}

# Split windows maintaining conversation boundaries
train_windows = []
val_windows = []

# Group windows by conversation (dict keeps first-seen order)
conv_windows = {}
for window in windows:
    conv_windows.setdefault(window['conversation'], []).append(window)

# Split each conversation's windows
for conv_name, conv_window_list in conv_windows.items():
    # Windows are not necessarily produced in time order (they can be
    # grouped by window size), so order them by target before taking the tail.
    conv_window_list.sort(key=lambda w: message_position[id(w['next_message'])])
    split_index = int(len(conv_window_list) * train_ratio)
    # Windows that share a target message must land on the same side;
    # otherwise a validation target would also be a training target.
    while (0 < split_index < len(conv_window_list)
           and conv_window_list[split_index]['next_message']
           is conv_window_list[split_index - 1]['next_message']):
        split_index += 1
    train_windows.extend(conv_window_list[:split_index])
    val_windows.extend(conv_window_list[split_index:])

# Generate training samples
print("\nFormatting samples...")
train_samples = [format_window_as_jsonl(w) for w in train_windows]
val_samples = [format_window_as_jsonl(w) for w in val_windows]

# Reuse the instruct seeds sampled (and printed) in Step 3. Calling
# create_instruct_seeds again would draw a new random sample, so the saved
# data would not match what was inspected above.
all_instruct_seeds = instruct_seeds

# Split instruct seeds 90/10 as well. Seeds are plain strings without a
# position, so this follows the order create_instruct_seeds returned them in.
train_instruct_count = int(len(all_instruct_seeds) * train_ratio)
train_instruct_seeds = all_instruct_seeds[:train_instruct_count]
val_instruct_seeds = all_instruct_seeds[train_instruct_count:]

# Combine samples
train_all_samples = train_samples + train_instruct_seeds
val_all_samples = val_samples + val_instruct_seeds

print(f"✓ Training samples: {len(train_all_samples):,} ({len(train_samples):,} windows + {len(train_instruct_seeds):,} seeds)")
print(f"✓ Validation samples: {len(val_all_samples):,} ({len(val_samples):,} windows + {len(val_instruct_seeds):,} seeds)")

# Show conversation distribution in splits
train_conv_counts = Counter(w['conversation'] for w in train_windows)
val_conv_counts = Counter(w['conversation'] for w in val_windows)
print(f"Train conversations: {dict(train_conv_counts)}")
print(f"Val conversations: {dict(val_conv_counts)}")

# Save datasets
print("\nSaving datasets...")
save_jsonl_dataset(train_all_samples, 'train.jsonl')
save_jsonl_dataset(val_all_samples, 'val.jsonl')

# Calculate and display statistics
calculate_dataset_stats(windows, train_all_samples, val_all_samples)

print(f"\n=== PROCESSING COMPLETE ===")
total_messages = sum(len(messages) for messages in conversations.values())
print(f"✓ Processed {total_messages:,} total messages from {len(conversations)} conversations")
print(f"✓ Maintained conversation boundaries throughout processing")
print(f"✓ Generated rolling windows per conversation")
print(f"✓ Split dataset maintaining conversation integrity")
print(f"✓ Saved train.jsonl and val.jsonl with proper target format")
print(f"✓ Ready for model fine-tuning")