In [None]:
"""
SMS conversation generator using personas and relationships.
Uses CAMEL-style multi-agent simulation with distilabel + vLLM (Qwen3).

This notebook is for running in Google Colab.
The code has been refactored into modular files:
- config.py: Configuration constants
- text_utils.py: Message cleaning and validation  
- generator.py: Core generation logic

For local usage, run: python -m src.synthetic_data.generator
"""

# Install dependencies (for Colab)
# !pip install -q pydantic tqdm distilabel[vllm]

import json
import random
import re
from collections import Counter
from pathlib import Path

from pydantic import BaseModel
from tqdm import tqdm

# =============================================================================
# Configuration
# =============================================================================

# Input/output locations, resolved relative to the current working directory.
DATA_DIR = Path("data")
PERSONAS_FILE = DATA_DIR / "personas.json"            # read by load_data()
RELATIONSHIPS_FILE = DATA_DIR / "relationships.json"  # read by load_data()
OUTPUT_FILE = DATA_DIR / "conversations.json"         # default target of save_conversations()

# Model identifier handed to the vLLM wrapper when the LLM is created.
MODEL_NAME = "Qwen/Qwen3-4B-Instruct-2507"

# How many separate conversations to generate for one relationship, keyed by
# relationship type. generate_conversations() falls back to 1 for any type
# that is not listed here.
CONVERSATIONS_PER_RELATIONSHIP = {
    "partner": 8,
    "close_friends": 5,
    "friends": 3,
    "family": 4,
    "colleagues": 3,
    "professionals": 2,
    "businesses": 2,
    "neighbors": 2,
    "casual": 2,
    "other": 1,
}

# (min, max) number of messages in a conversation, keyed by relationship type.
# Both bounds are inclusive because get_turn_count() feeds them to
# random.randint(); unknown types fall back to (3, 8) there.
TURN_RANGES = {
    "partner": (8, 20),
    "close_friends": (6, 15),
    "friends": (4, 12),
    "family": (4, 10),
    "colleagues": (3, 8),
    "professionals": (2, 6),
    "businesses": (2, 4),
    "neighbors": (2, 6),
    "casual": (3, 8),
    "other": (2, 6),
}

# Scenario descriptions per relationship type. generate_conversations() picks
# them round-robin by conversation index, so a type with fewer scenarios than
# conversations reuses its scenarios. Types without an entry here (for example
# "other") fall back to the single scenario "General conversation".
SCENARIO_TEMPLATES = {
    "partner": [
        "Asking for help with something",
        "Planning dinner tonight",
        "Checking in during work day",
        "Discussing weekend plans",
        "Small argument about chores",
    ],
    "family": [
        "Asking for help with something",
        "Catching up after not talking for a few days",
        "Planning a family gathering",
        "Sharing news about a relative",
    ],
    "close_friends": [
        "Making plans to hang out",
        "Sharing gossip or news",
        "Venting about work or life",
        "Asking for help with something",
    ],
    "friends": [
        "Casual catch-up",
        "Sharing a meme or link",
        "Making loose plans",
        "Asking for advice"
    ],
    "colleagues": [
        "Quick work question",
        "Coordinating on a project",
        "Office gossip",
        "Asking for help with something",
    ],
    "professionals": [
        "Scheduling an appointment",
        "Following up on a service",
        "Asking for a quote",
        "Confirming arrival time",
    ],
    "businesses": [
        "Confirming a reservation",
        "Checking order status",
        "Asking about hours or availability",
    ],
    "neighbors": [
        "Asking about a package delivery",
        "Noise complaint (polite)",
        "Borrowing something",
    ],
    "casual": [
        "Planning a group activity",
        "Sharing hobby-related info",
    ],
}

# Human-readable description of the other party, keyed by relationship type.
# get_contact_identity() returns these (or "contact" for unknown types) and
# build_prompt() interpolates the result into "You're texting your ...".
RELATIONSHIP_IDENTITIES = {
    "partner": "spouse/romantic partner (you live together, you're a couple)",
    "close_friends": "close friend (you hang out often, very casual)",
    "friends": "friend",
    "family": "family member",
    "colleagues": "coworker",
    "professionals": "professional contact",
    "businesses": "business/service",
    "neighbors": "neighbor",
    "casual": "acquaintance",
    "other": "contact",
}

# =============================================================================
# Data Models
# =============================================================================

class Message(BaseModel):
    """A single text message inside a generated conversation."""

    # UUID of the persona that sent this message (main persona or contact).
    sender_uuid: str
    # Message body after clean_message() has been applied.
    text: str
    # Cumulative minutes since the conversation started; each message adds a
    # random gap to the previous message's offset.
    timestamp_offset_minutes: int


class Conversation(BaseModel):
    """One generated SMS exchange between a main persona and a contact."""

    # UUID of the persona who sends the first message (even-numbered turns).
    main_uuid: str
    # UUID of the persona on the other side (odd-numbered turns).
    contact_uuid: str
    # Relationship type label copied from the relationship record.
    relationship_type: str
    # Messages in the order they were generated.
    messages: list[Message]
    # Scenario description that was given to the model for this conversation.
    scenario: str


# =============================================================================
# Text Utilities
# =============================================================================

def clean_message(text: str) -> str:
    """Strip letter-style artifacts from a generated SMS and tidy whitespace.

    The steps run in this order:

    1. Drop a formal sign-off ("Best regards, Tom", "Cheers - Al") through the
       end of the text.
    2. Drop a trailing dash-prefixed signature ("... - Bob"); hyphen, en dash
       and em dash are all accepted.
    3. Drop bracketed placeholders that mention a name ("[Your Name]").
    4. Drop any remaining bare ``Name`` placeholder word.
    5. Drop a leading salutation ("Hi Sam, ").
    6. Drop non-ASCII characters, except code points from U+1F300 upwards so
       that emoji survive.
    7. Collapse whitespace runs to single spaces and trim both ends.

    Args:
        text: Raw model output for one message.

    Returns:
        The cleaned message, which may be an empty string.
    """
    # Dashes are spelled as escapes so the pattern cannot be corrupted by a
    # mis-decoded source file (U+2013 = en dash, U+2014 = em dash).
    text = re.sub(
        r'(?:Best regards|Regards|Warm regards|Sincerely|With (?:love|gratitude|thanks)|Cheers)'
        r',?\s*[-\u2013\u2014]?\s*\w+.*$',
        '',
        text,
        flags=re.IGNORECASE,
    )
    text = re.sub(r'\s*[-\u2013\u2014]\s*\w+\s*$', '', text)

    # Bracketed placeholders must be removed before the bare-word pass;
    # otherwise "[Your Name]" is reduced to "[Your ]" and left behind.
    # [^\]]* keeps the match inside a single pair of brackets.
    text = re.sub(r'\[[^\]]*(?:Name|name)[^\]]*\]', '', text)
    text = re.sub(r'\bName\b', '', text)

    text = re.sub(r'^(?:Dear|Hi|Hello|Hey)\s+[\w\s]+,\s*', '', text, flags=re.IGNORECASE)

    # Keep ASCII plus the emoji blocks. The cut-off is U+1F300 so that hand
    # emoji such as thumbs-up (U+1F44D), which should_end_conversation() looks
    # for, are not stripped.
    text = ''.join(c for c in text if ord(c) < 128 or ord(c) >= 0x1F300)

    return re.sub(r'\s+', ' ', text).strip()


def check_message_quality(text: str, history: list[dict], rel_type: str) -> tuple[bool, str]:
    """Decide whether a freshly generated message may be added to a conversation.

    Args:
        text: Candidate message.
        history: Earlier messages of the conversation; only the ``"text"`` of
            the most recent one is inspected.
        rel_type: Relationship type, used to reject formal wording between
            people who are close.

    Returns:
        ``(True, "")`` when the message is acceptable, otherwise
        ``(False, reason)`` where reason is one of ``"too_long"``,
        ``"too_short"``, ``"too_formal"``, ``"too_similar"`` or
        ``"repeated_time"``.
    """
    lowered = text.lower()

    # Length limits are measured in whitespace-separated words.
    word_total = len(text.split())
    if word_total > 35:
        return False, "too_long"
    if word_total < 2:
        return False, "too_short"

    # Letter-style closings are out of place between close contacts.
    if rel_type in ("partner", "close_friends", "friends", "family"):
        for marker in ["best regards", "sincerely", "warm regards", "respectfully"]:
            if marker in lowered:
                return False, "too_formal"

    if not history:
        return True, ""

    previous = history[-1]["text"].lower()

    def trigrams(sentence):
        tokens = sentence.split()
        return {' '.join(tokens[k:k + 3]) for k in range(len(tokens) - 2)}

    # Two or more shared three-word phrases means the reply parrots the
    # previous message.
    shared = trigrams(lowered) & trigrams(previous)
    if len(shared) >= 2:
        return False, "too_similar"

    # Reject a reply that restates exactly the same time references.
    time_pattern = r'(?:at|in|around)\s+\d+'
    current_times = re.findall(time_pattern, lowered)
    previous_times = re.findall(time_pattern, previous)
    if current_times and current_times == previous_times:
        return False, "repeated_time"

    return True, ""


def extract_key_phrases(text: str) -> set[str]:
    """Return the lower-cased time mentions and three-word phrases of a message.

    A time mention is "in", "at" or "around" followed by whitespace and
    digits; three-word phrases are every run of three consecutive
    whitespace-separated words.
    """
    lowered = text.lower()
    tokens = lowered.split()
    time_mentions = re.findall(r'(?:in|at|around)\s+\d+', lowered)
    trigrams = (' '.join(tokens[start:start + 3]) for start in range(len(tokens) - 2))
    return set(time_mentions).union(trigrams)


def should_end_conversation(history: list[dict]) -> bool:
    """Heuristically detect that a conversation has reached a natural end.

    Only the last four messages are inspected, and at least three messages
    must exist. The conversation is considered finished when any of these
    holds:

    * two or more recent messages contain a goodbye phrase;
    * two or more recent messages contain a confirmation phrase or a
      thumbs-up / OK-hand emoji;
    * the recent messages mention three or more clock times ("5pm", "7:30 am");
    * the same key phrase (longer than five characters) occurs in three or
      more recent messages;
    * with four or more messages, two consecutive message pairs both share
      most of their content words, i.e. the speakers are going in circles.

    Args:
        history: Messages so far; each item must provide a ``"text"`` string.

    Returns:
        True when generation for this conversation should stop.
    """
    if len(history) < 3:
        return False

    recent = history[-4:]
    last_msgs = [m['text'].lower() for m in recent]

    bye_patterns = ['bye', 'see you', 'see ya', 'ttyl', 'later', 'talk soon', 'cya', 'gotta go']
    if sum(1 for msg in last_msgs if any(p in msg for p in bye_patterns)) >= 2:
        return True

    # Emoji are written as escapes (thumbs-up U+1F44D, OK hand U+1F44C) so a
    # mis-decoded source file cannot turn them into unmatched garbage.
    confirm_patterns = [
        'sounds good', 'perfect', 'got it', 'will do', 'see you then',
        '\U0001F44D', '\U0001F44C',
    ]
    if sum(1 for msg in last_msgs if any(p in msg for p in confirm_patterns)) >= 2:
        return True

    # The text is already lower-cased; \b stops "5 amazing" counting as "5 am".
    recent_text = ' '.join(last_msgs)
    if len(re.findall(r'\d+(?::\d+)?\s*(?:am|pm)\b', recent_text)) >= 3:
        return True

    # extract_key_phrases() returns a set per message, so each count is the
    # number of recent messages that contain the phrase.
    phrase_counts = Counter()
    for msg in recent:
        phrase_counts.update(extract_key_phrases(msg['text']))
    if any(count >= 3 for phrase, count in phrase_counts.items() if len(phrase) > 5):
        return True

    if len(history) >= 4:
        stopwords = {'the', 'a', 'an', 'is', 'are', 'was', 'to', 'in', 'on', 'at', 'and', 'or', 'i', 'you', 'your', 'my'}

        def get_content_words(text):
            return set(w for w in text.lower().split() if w not in stopwords and len(w) > 2)

        def jaccard(left, right):
            # max(..., 1) guards the division when both sets are empty.
            return len(left & right) / max(len(left | right), 1)

        recent_words = [get_content_words(m['text']) for m in recent]
        # Needs a high-overlap pair whose preceding pair overlaps too.
        for i in range(1, len(recent_words) - 1):
            before, current, after = recent_words[i - 1], recent_words[i], recent_words[i + 1]
            if current and after and jaccard(current, after) > 0.6 and jaccard(before, current) > 0.5:
                return True

    return False


def infer_sms_style(persona: dict) -> str:
    """Describe a plausible texting style from a persona's age and education.

    Args:
        persona: Persona record. ``"age"`` (number) and ``"education_level"``
            (string) are read; a missing or null age is treated as 40 and a
            missing or null education level as unknown.

    Returns:
        Comma-separated style traits for insertion into the system prompt.
    """
    # dict.get's default only covers a missing key, so a JSON null would
    # otherwise reach the comparisons below as None and raise TypeError.
    age = persona.get("age")
    if age is None:
        age = 40
    # Lower-case so "Masters" / "Doctorate" match as well as "masters".
    education = str(persona.get("education_level") or "").lower()

    if age < 25:
        traits = ["lowercase", "heavy emoji use", "abbreviations (u, ur, rn, ngl)", "no punctuation"]
    elif age < 40:
        traits = ["casual", "occasional emoji", "some abbreviations"]
    elif age < 60:
        traits = ["proper sentences", "minimal emoji", "full words"]
    else:
        traits = ["formal", "complete sentences", "no emoji", "may sign off with name"]

    if "doctorate" in education or "masters" in education:
        traits.append("articulate vocabulary")

    return ", ".join(traits)


# =============================================================================
# Helper Functions
# =============================================================================

def get_turn_count(relationship_type: str) -> int:
    """Pick a random conversation length for the given relationship type.

    The inclusive bounds come from TURN_RANGES; types that are not listed
    there use (3, 8).
    """
    bounds = TURN_RANGES.get(relationship_type, (3, 8))
    shortest, longest = bounds
    return random.randint(shortest, longest)


def get_contact_identity(rel_type: str, service_type: str | None = None) -> str:
    """Get relationship description for prompt."""
    if service_type:
        return f"{service_type} (professional service)"
    return RELATIONSHIP_IDENTITIES.get(rel_type, "contact")


# =============================================================================
# Prompt Building
# =============================================================================

def build_prompt(persona: dict, is_main: bool, scenario: str,
                 contact_identity: str, service_type: str | None, history: list[dict],
                 turn_number: int = 0, total_turns: int = 10) -> tuple[list[dict], str]:
    """Build the chat messages for the persona whose turn it is.

    Args:
        persona: Record of the speaking persona. ``"persona"``,
            ``"professional_persona"``, ``"uuid"`` and the fields read by
            infer_sms_style() are used.
        is_main: True when the speaker is the main persona of the relationship.
        scenario: Situation description inserted into the system prompt.
        contact_identity: Description of the other party ("friend", ...).
        service_type: Service label for professional contacts, else None.
        history: Messages so far; items need ``"text"`` and ``"is_main"``.
        turn_number: Zero-based index of the message being generated.
        total_turns: Planned number of messages in the conversation.

    Returns:
        ``(messages, cache_key)``. ``messages`` is a system message plus a
        user message holding the last four history lines. ``cache_key`` is
        ``"<uuid>_<is_main>"``; callers can sort a batch by it so prompts for
        the same speaker sit next to each other.
    """
    style = infer_sms_style(persona)

    # A service contact speaks from its professional persona when one exists.
    # Fall back to the general persona (and finally to "") so a missing or
    # null field never reaches re.match() or slicing as None.
    persona_text = None
    if service_type and not is_main:
        persona_text = persona.get('professional_persona')
    if not persona_text:
        persona_text = persona.get('persona') or ''

    # The persona text is expected to start with the first name.
    name_match = re.match(r'^([A-Z][a-z]+)', persona_text)
    first_name = name_match.group(1) if name_match else "Person"

    cache_key = f"{persona.get('uuid', '')}_{is_main}"

    # For the last two planned messages, ask the model to wrap up.
    near_end = turn_number >= total_turns - 2
    ending_instruction = "\n6. WRAP UP: Send a brief closer and END the conversation." if near_end else ""

    # If the previous message named a time, forbid repeating it verbatim.
    recent_topics = ""
    if history:
        last_msg = history[-1]['text'].lower()
        time_match = re.search(r'(?:in|at|around)\s+\d+', last_msg)
        if time_match:
            recent_topics = f"\nDO NOT REPEAT: '{time_match.group()}' - this was already said."

    # Turn-specific instructions are appended at the end of the system prompt,
    # so the leading part stays identical across turns of the same speaker.
    system = f"""You are {first_name}.

BACKGROUND: {persona_text[:200]}...

TEXTING STYLE: {style}

---
You're texting your {contact_identity}.
SITUATION: {scenario}

CRITICAL RULES:
1. MAX 20 words. Real texts are SHORT.
2. NO greetings or sign-offs
3. NEVER repeat what was just said - move the conversation forward
4. Use contractions (I'm, don't, gonna)
5. Don't sign your name{ending_instruction}{recent_topics}

Write ONLY the text message. Nothing else."""

    # Label each recent line from the speaker's point of view.
    history_text = "\n".join(
        f"{'You' if m['is_main'] == is_main else 'Them'}: {m['text']}"
        for m in history[-4:]
    )

    return [
        {"role": "system", "content": system},
        {"role": "user", "content": f"{history_text or '(Start the conversation)'}\n\nYour text:"}
    ], cache_key


# =============================================================================
# Conversation Generation
# =============================================================================

def _generation_text(result: dict) -> str:
    """Return the first generation of an LLM result, unquoted and cleaned."""
    # A failed generation may come back as None; treat it as empty text so the
    # quality check / retry logic handles it instead of crashing on .strip().
    raw = result['generations'][0] or ""
    return clean_message(raw.strip().strip('"').strip("'"))


def _record_message(conv: dict, is_main_turn: bool, text: str) -> None:
    """Append a message to a conversation and advance its clock and turn."""
    # Partners and close friends reply within minutes; others take longer.
    gap = random.randint(1, 15) if conv["rel_type"] in ("partner", "close_friends") else random.randint(2, 60)
    conv["time_offset"] += gap
    conv["history"].append({
        "is_main": is_main_turn,
        "sender_uuid": conv["main"]["uuid"] if is_main_turn else conv["contact"]["uuid"],
        "text": text,
        "timestamp_offset_minutes": conv["time_offset"],
    })
    conv["current_turn"] += 1


def generate_conversations(personas: dict, relationships: list, llm, limit: int | None = None) -> list[Conversation]:
    """Generate all conversations, issuing one batched LLM call per turn.

    All conversations advance in lock-step: each round builds one prompt per
    still-active conversation, sends them as a single batch, validates the
    replies and retries rejected ones once in a second batch.

    Args:
        personas: Persona records keyed by UUID; each record needs ``"uuid"``.
        relationships: Records with ``"from_uuid"``, ``"to_uuid"``,
            ``"relationship_type"`` and optionally ``"service_type"``.
            Relationships whose personas are unknown are skipped.
        llm: Object whose ``generate(list_of_chat_messages)`` is assumed to
            return one dict per input, in order, with a ``"generations"`` list.
        limit: Only the first ``limit`` relationships are used; None means all.

    Returns:
        One Conversation per planned conversation. A conversation can hold
        fewer messages than planned when it ended naturally or a retry
        produced no usable text.
    """
    # ---- Plan: one state dict per conversation to generate ----------------
    convs = []
    for rel in relationships[:limit]:
        main = personas.get(rel["from_uuid"])
        contact = personas.get(rel["to_uuid"])
        if not main or not contact:
            continue

        rel_type = rel["relationship_type"]
        service_type = rel.get("service_type")
        contact_identity = get_contact_identity(rel_type, service_type)
        pair_key = f"{main['uuid']}_{contact['uuid']}"

        num_convs = CONVERSATIONS_PER_RELATIONSHIP.get(rel_type, 1)
        available_scenarios = SCENARIO_TEMPLATES.get(rel_type, ["General conversation"])

        for conv_idx in range(num_convs):
            # Round-robin over the scenarios available for this type.
            scenario = available_scenarios[conv_idx % len(available_scenarios)]
            if service_type:
                scenario = f"Contacting {service_type} - {scenario}"

            convs.append({
                "main": main,
                "contact": contact,
                "rel_type": rel_type,
                "scenario": scenario,
                "service_type": service_type,
                "contact_identity": contact_identity,
                "turns": get_turn_count(rel_type),
                "current_turn": 0,
                "history": [],
                "time_offset": 0,
                "pair_key": pair_key,
            })

    # Keep conversations of the same persona pair adjacent.
    convs.sort(key=lambda c: c["pair_key"])

    total_messages = sum(c["turns"] for c in convs)
    max_turns = max(c["turns"] for c in convs) if convs else 0

    print(f"Total conversations: {len(convs)}")
    print(f"Total messages to generate: {total_messages}")
    print(f"Max turns per conversation: {max_turns}")
    print(f"Unique persona pairs: {len(set(c['pair_key'] for c in convs))} (sorted for cache efficiency)")

    # ---- Generate: one round per turn --------------------------------------
    # The context manager closes the progress bar even if generation raises.
    with tqdm(total=total_messages, desc="Generating SMS", unit="msg") as pbar:
        for _ in range(max_turns):
            batch_data = []
            for i, conv in enumerate(convs):
                if conv["current_turn"] >= conv["turns"]:
                    continue
                if should_end_conversation(conv["history"]):
                    # Mark as finished so it is never scheduled again.
                    conv["current_turn"] = conv["turns"]
                    continue

                # Main persona speaks on even turns, the contact on odd turns.
                is_main_turn = (conv["current_turn"] % 2 == 0)
                persona = conv["main"] if is_main_turn else conv["contact"]

                messages, cache_key = build_prompt(
                    persona, is_main_turn, conv["scenario"],
                    conv["contact_identity"], conv["service_type"], conv["history"],
                    turn_number=conv["current_turn"], total_turns=conv["turns"]
                )
                batch_data.append((messages, cache_key, i, is_main_turn))

            if not batch_data:
                # Nothing is active any more (finished or just ended).
                break

            # Group prompts of the same speaker together within the batch.
            batch_data.sort(key=lambda x: x[1])
            batch_inputs = [d[0] for d in batch_data]
            batch_meta = [(d[2], d[3]) for d in batch_data]

            results = llm.generate(batch_inputs)

            retry_needed = []
            for (i, is_main_turn), result in zip(batch_meta, results):
                conv = convs[i]
                text = _generation_text(result)

                is_valid, issue = check_message_quality(text, conv["history"], conv["rel_type"])
                if not is_valid:
                    retry_needed.append((i, is_main_turn, issue))
                    continue

                _record_message(conv, is_main_turn, text)
                pbar.update(1)

            if not retry_needed:
                continue

            # ---- Single retry for rejected messages ------------------------
            retry_inputs = []
            retry_meta = []
            for i, is_main_turn, issue in retry_needed:
                conv = convs[i]
                persona = conv["main"] if is_main_turn else conv["contact"]

                messages, _ = build_prompt(
                    persona, is_main_turn, conv["scenario"],
                    conv["contact_identity"], conv["service_type"], conv["history"],
                    turn_number=conv["current_turn"], total_turns=conv["turns"]
                )
                messages[-1]["content"] += f"\n\n(Be different. Previous was {issue}. Keep it SHORT and FRESH.)"
                retry_inputs.append(messages)
                retry_meta.append((i, is_main_turn))

            retry_results = llm.generate(retry_inputs)

            for (i, is_main_turn), result in zip(retry_meta, retry_results):
                conv = convs[i]
                text = _generation_text(result)

                # The retry is accepted without a second quality check, but a
                # blank message is never recorded; the same speaker simply
                # tries again in the next round.
                if not text:
                    continue

                _record_message(conv, is_main_turn, text)
                pbar.update(1)

    return [
        Conversation(
            main_uuid=c["main"]["uuid"],
            contact_uuid=c["contact"]["uuid"],
            relationship_type=c["rel_type"],
            scenario=c["scenario"],
            messages=[
                Message(
                    sender_uuid=m["sender_uuid"],
                    text=m["text"],
                    timestamp_offset_minutes=m["timestamp_offset_minutes"]
                )
                for m in c["history"]
            ]
        )
        for c in convs
    ]


# =============================================================================
# Data I/O
# =============================================================================

def load_data() -> tuple[dict, list]:
    """Load personas and relationships from their JSON files.

    Returns:
        ``(personas, relationships)`` as parsed from PERSONAS_FILE and
        RELATIONSHIPS_FILE.

    Raises:
        FileNotFoundError: If either file does not exist.
        json.JSONDecodeError: If either file is not valid JSON.
    """
    # Read as UTF-8 explicitly: save_conversations() writes UTF-8, and the
    # platform default encoding is not UTF-8 everywhere.
    with open(PERSONAS_FILE, encoding="utf-8") as f:
        personas = json.load(f)
    with open(RELATIONSHIPS_FILE, encoding="utf-8") as f:
        relationships = json.load(f)
    return personas, relationships


def save_conversations(conversations: list[Conversation], path: Path = OUTPUT_FILE):
    """Write the conversations to ``path`` as an indented UTF-8 JSON array.

    The parent directory is created when missing, non-ASCII characters are
    written as-is rather than escaped, and a one-line summary is printed.
    """
    target_dir = path.parent
    target_dir.mkdir(parents=True, exist_ok=True)

    with open(path, "w", encoding="utf-8") as out:
        payload = [conv.model_dump() for conv in conversations]
        json.dump(payload, out, indent=2, ensure_ascii=False)

    print(f"Saved {len(conversations)} conversations to {path}")

In [None]:
# Load the input data once; later cells reuse `personas` and `relationships`.
personas, relationships = load_data()
print(f"Loaded {len(personas)} personas and {len(relationships)} relationships")


In [None]:
# distilabel's vLLM wrapper; imported in this cell rather than at the top of
# the notebook.
from distilabel.models import vLLM


# Initialize vLLM with Qwen3 - optimized for T4 (16GB VRAM)
llm = vLLM(
    model=MODEL_NAME,
    dtype="float16",  # T4 doesn't support bfloat16 - must be top-level param
    # Options collected under extra_kwargs (presumably forwarded to the vLLM
    # engine constructor - confirm against the distilabel docs).
    extra_kwargs={
        # GPU settings
        "tensor_parallel_size": 1,      # Single T4
        "gpu_memory_utilization": 0.92, # Leave headroom for CUDA kernels
        
        # Memory optimization
        "max_model_len": 2048,          # Reduce from 4096 - SMS are short
        "swap_space": 0,                # Disable CPU swap
        "enforce_eager": False,         # Use CUDA graphs (faster)
        
        # Batching optimization
        "max_num_seqs": 64,             # Max concurrent sequences
        "enable_chunked_prefill": True, # Better memory efficiency
        "enable_prefix_caching": True,  # Cache common prefixes (system prompts)
    },
    # NOTE(review): generate_conversations() calls llm.generate(...) directly
    # and passes no sampling arguments. Confirm that these generation_kwargs
    # are actually applied on that path (distilabel Tasks normally forward
    # them) and that the length limit is spelled "max_tokens" rather than
    # "max_new_tokens" for this wrapper.
    generation_kwargs={
        "max_tokens": 60,               # SMS are short - enforce brevity
        "temperature": 0.9,             # Higher for more natural variation
        "top_p": 0.95,
        # Stop at a blank line, at a speaker label, or at a formal sign-off.
        "stop": ["\n\n", "Them:", "You:", "THEM:", "ME:", "Best regards", "Regards,", "Sincerely"],
    },
)
# Load the model before the first generate() call.
llm.load()



In [None]:
# Run the full generation over every relationship (no `limit`) and write the
# result to the default OUTPUT_FILE.
print(f"Generating conversations for {len(relationships)} relationships...")
conversations = generate_conversations(personas, relationships, llm)
save_conversations(conversations)

In [None]:
# Print the first three generated conversations as a quick visual check.
# The decorative glyphs are written as escapes so they cannot be corrupted by
# a mis-decoded notebook file (they previously printed as mojibake).
rule = "\u2500" * 50        # box-drawing horizontal line
phone = "\U0001F4F1"        # mobile phone emoji
outgoing = "\u2192"         # right arrow: sent by the main persona
incoming = "\u2190"         # left arrow: sent by the contact

print("\n" + "="*60)
print("SAMPLE CONVERSATIONS")
print("="*60)
for conv in conversations[:3]:
    print(f"\n{rule}")
    print(f"{phone} {conv.relationship_type.upper()} | {conv.scenario}")
    print(rule)
    for msg in conv.messages:
        sender = outgoing if msg.sender_uuid == conv.main_uuid else incoming
        print(f"  {sender} [{msg.timestamp_offset_minutes:3d}m] {msg.text}")