In [1]:
"""
Linguistic Alignment Metrics for Child-Adult Conversational Analysis

Computes three types of alignment between child and adult turns:
1. Lexical alignment: Overlap of lemmatized words
2. Syntactic alignment: Overlap of POS tag sequences
3. Semantic alignment: Cosine similarity of embeddings

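All metrics return scores in [0, 1]. Note that the lexical and syntactic
overlap rates divide by the combined token count of both turns, so for two
non-empty turns they never exceed 0.5.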
"""

import numpy as np
from typing import List, Tuple, Dict
import spacy
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer
from collections import Counter


  from tqdm.autonotebook import tqdm, trange


In [None]:


class LexicalAlignmentCalculator:
    """
    Computes lexical alignment based on lemma overlap rate.
    
    Alignment = |lemmas_child ∩ lemmas_adult| / (|lemmas_child| + |lemmas_adult|)

    The numerator counts *distinct* shared lemmas, while the denominator
    counts every retained token of both turns (repeats included). For two
    non-empty turns the score therefore never exceeds 0.5; 1.0 is returned
    only when neither turn has any retained lemma.
    """
    
    def __init__(self, model_name: str = "en_core_web_sm", 
                 exclude_stopwords: bool = True,
                 exclude_punctuation: bool = True):
        """
        Args:
            model_name: spaCy model to use
            exclude_stopwords: Whether to exclude stopwords from alignment
            exclude_punctuation: Whether to exclude punctuation
        """
        try:
            self.nlp = spacy.load(model_name)
        except OSError:
            print(f"Downloading spaCy model {model_name}...")
            import subprocess
            import sys
            # Run the download with the interpreter that is executing this
            # code, so the model is installed into the same environment.
            # If the download fails, the load below raises OSError as before.
            subprocess.run(
                [sys.executable, "-m", "spacy", "download", model_name]
            )
            self.nlp = spacy.load(model_name)
        
        self.exclude_stopwords = exclude_stopwords
        self.exclude_punctuation = exclude_punctuation
    
    def _get_lemmas(self, text: str) -> List[str]:
        """Extract lemmatized tokens from lower-cased text, applying the filters."""
        return [
            token.lemma_
            for token in self.nlp(text.lower())
            if not (self.exclude_punctuation and token.is_punct)
            and not (self.exclude_stopwords and token.is_stop)
        ]
    
    @staticmethod
    def _overlap_rate(child_lemmas: List[str], adult_lemmas: List[str]) -> float:
        """Score two already-extracted lemma lists with the overlap-rate formula."""
        if not child_lemmas and not adult_lemmas:
            return 1.0  # Both empty = perfect alignment
        if not child_lemmas or not adult_lemmas:
            return 0.0  # One empty = no alignment
        
        shared = set(child_lemmas) & set(adult_lemmas)
        return len(shared) / (len(child_lemmas) + len(adult_lemmas))
    
    def compute_alignment(self, child_turn: str, adult_turn: str) -> float:
        """
        Compute lexical alignment between child and adult turns.
        
        Args:
            child_turn: Child's utterance
            adult_turn: Adult's utterance
            
        Returns:
            Alignment score in [0, 1] (at most 0.5 when both turns
            contain at least one retained lemma)
        """
        child_lemmas = self._get_lemmas(child_turn)
        adult_lemmas = self._get_lemmas(adult_turn)
        return self._overlap_rate(child_lemmas, adult_lemmas)
    
    def compute_alignment_detailed(self, child_turn: str, adult_turn: str) -> Dict:
        """
        Compute alignment with detailed breakdown.
        
        Each turn is parsed only once; the score is derived from the same
        lemma lists that are reported in the result.
        
        Returns:
            Dictionary with alignment score and shared lemmas
            (shared lemmas are sorted so the output is reproducible)
        """
        child_lemmas = self._get_lemmas(child_turn)
        adult_lemmas = self._get_lemmas(adult_turn)
        shared_lemmas = sorted(set(child_lemmas) & set(adult_lemmas))
        
        return {
            'alignment': self._overlap_rate(child_lemmas, adult_lemmas),
            'child_lemmas': child_lemmas,
            'adult_lemmas': adult_lemmas,
            'shared_lemmas': shared_lemmas,
            'num_child_lemmas': len(child_lemmas),
            'num_adult_lemmas': len(adult_lemmas),
            'num_shared': len(shared_lemmas),
        }


class SyntacticAlignmentCalculator:
    """
    Computes syntactic alignment based on POS tag bigram overlap rate.
    
    Uses POS bigrams to capture sequential syntactic structure.

    Alignment = |distinct shared bigrams| / (|bigrams_child| + |bigrams_adult|)

    The denominator counts every bigram occurrence of both turns, so for two
    turns that each yield at least one bigram the score never exceeds 0.5;
    1.0 is returned only when neither turn yields a bigram.
    """
    
    def __init__(self, model_name: str = "en_core_web_sm"):
        """
        Args:
            model_name: spaCy model to use
        """
        try:
            self.nlp = spacy.load(model_name)
        except OSError:
            print(f"Downloading spaCy model {model_name}...")
            import subprocess
            import sys
            # Run the download with the interpreter that is executing this
            # code, so the model is installed into the same environment.
            # If the download fails, the load below raises OSError as before.
            subprocess.run(
                [sys.executable, "-m", "spacy", "download", model_name]
            )
            self.nlp = spacy.load(model_name)
    
    def _get_pos_tags(self, text: str) -> List[str]:
        """Extract coarse-grained (universal) POS tags, skipping punctuation."""
        return [token.pos_ for token in self.nlp(text) if not token.is_punct]
    
    @staticmethod
    def _bigrams_from_tags(pos_tags: List[str]) -> List[str]:
        """Join adjacent tags into "DET_NOUN"-style bigrams; fewer than two tags give []."""
        return [f"{first}_{second}" for first, second in zip(pos_tags, pos_tags[1:])]
    
    @staticmethod
    def _overlap_rate(child_bigrams: List[str], adult_bigrams: List[str]) -> float:
        """Score two already-extracted bigram lists with the overlap-rate formula."""
        if not child_bigrams and not adult_bigrams:
            return 1.0  # Both empty = perfect alignment
        if not child_bigrams or not adult_bigrams:
            return 0.0  # One empty = no alignment
        
        shared = set(child_bigrams) & set(adult_bigrams)
        return len(shared) / (len(child_bigrams) + len(adult_bigrams))
    
    def _get_pos_bigrams(self, text: str) -> List[str]:
        """Extract POS tag bigrams to capture sequential structure."""
        return self._bigrams_from_tags(self._get_pos_tags(text))
    
    def compute_alignment(self, child_turn: str, adult_turn: str) -> float:
        """
        Compute syntactic alignment between child and adult turns using POS bigrams.
        
        Args:
            child_turn: Child's utterance
            adult_turn: Adult's utterance
            
        Returns:
            Alignment score in [0, 1] (at most 0.5 when both turns
            yield at least one bigram)
        """
        child_bigrams = self._get_pos_bigrams(child_turn)
        adult_bigrams = self._get_pos_bigrams(adult_turn)
        return self._overlap_rate(child_bigrams, adult_bigrams)
    
    def compute_alignment_detailed(self, child_turn: str, adult_turn: str) -> Dict:
        """
        Compute alignment with detailed breakdown.
        
        Each turn is tagged only once; bigrams and the score are derived
        from those tags.
        
        Returns:
            Dictionary with alignment scores and POS tag information
            (shared bigrams are sorted so the output is reproducible)
        """
        child_tags = self._get_pos_tags(child_turn)
        adult_tags = self._get_pos_tags(adult_turn)
        
        child_bigrams = self._bigrams_from_tags(child_tags)
        adult_bigrams = self._bigrams_from_tags(adult_tags)
        shared_bigrams = sorted(set(child_bigrams) & set(adult_bigrams))
        
        return {
            'alignment': self._overlap_rate(child_bigrams, adult_bigrams),
            'child_pos_tags': child_tags,
            'adult_pos_tags': adult_tags,
            'child_pos_bigrams': child_bigrams,
            'adult_pos_bigrams': adult_bigrams,
            'shared_bigrams': shared_bigrams,
            'num_child_bigrams': len(child_bigrams),
            'num_adult_bigrams': len(adult_bigrams),
            'num_shared_bigrams': len(shared_bigrams),
        }


class SemanticAlignmentCalculator:
    """
    Computes semantic alignment using sentence embeddings.
    
    Alignment = cosine_similarity(embedding_child, embedding_adult)

    The similarity is clipped to [0, 1]. A pair in which either turn is
    empty or whitespace-only scores 0.0, both for single pairs and batches.
    """
    
    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        """
        Args:
            model_name: SentenceTransformer model to use
                       Options:
                       - "all-MiniLM-L6-v2" (fast, 384 dim)
                       - "all-mpnet-base-v2" (better quality, 768 dim)
                       - "paraphrase-MiniLM-L6-v2" (paraphrase detection)
        """
        print(f"Loading semantic model: {model_name}...")
        self.model = SentenceTransformer(model_name)
        self.model_name = model_name
    
    def _get_embedding(self, text: str) -> np.ndarray:
        """Get sentence embedding for text."""
        return self.model.encode(text, convert_to_numpy=True)
    
    @staticmethod
    def _rowwise_cosine(left: np.ndarray, right: np.ndarray) -> np.ndarray:
        """
        Cosine similarity between matching rows of two embedding arrays.
        
        1-D inputs are treated as a single row. A zero-norm row yields a
        similarity of 0 instead of a division error.
        """
        left = np.atleast_2d(np.asarray(left, dtype=np.float64))
        right = np.atleast_2d(np.asarray(right, dtype=np.float64))
        
        dots = np.sum(left * right, axis=1)
        norms = np.linalg.norm(left, axis=1) * np.linalg.norm(right, axis=1)
        
        similarities = np.zeros(dots.shape, dtype=np.float64)
        nonzero = norms > 0
        similarities[nonzero] = dots[nonzero] / norms[nonzero]
        return similarities
    
    def _score_embeddings(self, child_emb: np.ndarray, adult_emb: np.ndarray) -> float:
        """Clipped cosine similarity of one pair of embeddings."""
        similarity = self._rowwise_cosine(child_emb, adult_emb)[0]
        return float(np.clip(similarity, 0.0, 1.0))
    
    def compute_alignment(self, child_turn: str, adult_turn: str) -> float:
        """
        Compute semantic alignment between child and adult turns.
        
        Args:
            child_turn: Child's utterance
            adult_turn: Adult's utterance
            
        Returns:
            Alignment score in [0, 1]
            
        Note:
            Cosine similarity ranges over [-1, 1]; negative values are
            clipped to 0 so the score stays in [0, 1].
        """
        # Handle empty inputs
        if not child_turn.strip() or not adult_turn.strip():
            return 0.0
        
        child_emb = self._get_embedding(child_turn)
        adult_emb = self._get_embedding(adult_turn)
        return self._score_embeddings(child_emb, adult_emb)
    
    def compute_alignment_batch(self, child_turns: List[str], 
                                adult_turns: List[str]) -> np.ndarray:
        """
        Compute semantic alignment for multiple turn pairs efficiently.
        
        Pairs with an empty or whitespace-only turn score 0.0, matching
        compute_alignment; only the remaining pairs are encoded.
        
        Args:
            child_turns: List of child utterances
            adult_turns: List of adult utterances
            
        Returns:
            Array of alignment scores (float64), one per pair
            
        Raises:
            ValueError: If the two lists differ in length
        """
        if len(child_turns) != len(adult_turns):
            raise ValueError("Number of child and adult turns must match")
        
        child_turns = list(child_turns)
        adult_turns = list(adult_turns)
        scores = np.zeros(len(child_turns), dtype=np.float64)
        
        # Positions where both turns have content; everything else stays 0.0
        valid = [
            idx for idx, (child, adult) in enumerate(zip(child_turns, adult_turns))
            if child.strip() and adult.strip()
        ]
        if not valid:
            return scores
        
        # Batch encode for efficiency
        child_embs = self.model.encode(
            [child_turns[idx] for idx in valid], convert_to_numpy=True
        )
        adult_embs = self.model.encode(
            [adult_turns[idx] for idx in valid], convert_to_numpy=True
        )
        
        # One vectorised row-wise similarity instead of a per-pair loop
        scores[valid] = np.clip(
            self._rowwise_cosine(child_embs, adult_embs), 0.0, 1.0
        )
        return scores
    
    def compute_alignment_detailed(self, child_turn: str, adult_turn: str) -> Dict:
        """
        Compute alignment with embedding details.
        
        Each turn is encoded once; the score is derived from the returned
        embeddings. Blank turns are still encoded (so the embeddings are
        reported) but score 0.0, as in compute_alignment.
        
        Returns:
            Dictionary with alignment score and embeddings
        """
        child_emb = self._get_embedding(child_turn)
        adult_emb = self._get_embedding(adult_turn)
        
        if not child_turn.strip() or not adult_turn.strip():
            alignment = 0.0
        else:
            alignment = self._score_embeddings(child_emb, adult_emb)
        
        return {
            'alignment': alignment,
            'child_embedding': child_emb,
            'adult_embedding': adult_emb,
            'embedding_dim': len(child_emb),
            'model_name': self.model_name,
        }


class LinguisticAlignmentSuite:
    """
    Facade that bundles the lexical, syntactic and semantic calculators
    behind a single object.
    """
    
    def __init__(self, 
                 spacy_model: str = "en_core_web_sm",
                 semantic_model: str = "all-MiniLM-L6-v2",
                 exclude_stopwords: bool = True):
        """
        Build one calculator of each kind.
        
        Args:
            spacy_model: Model name handed to the lexical and syntactic calculators
            semantic_model: Model name handed to the semantic calculator
            exclude_stopwords: Forwarded to the lexical calculator
        """
        print("Initializing Linguistic Alignment Suite...")
        
        # Construction order: lexical, then syntactic, then semantic.
        self.lexical_calc = LexicalAlignmentCalculator(
            model_name=spacy_model, exclude_stopwords=exclude_stopwords
        )
        self.syntactic_calc = SyntacticAlignmentCalculator(model_name=spacy_model)
        self.semantic_calc = SemanticAlignmentCalculator(model_name=semantic_model)
        
        print("✓ Alignment suite ready!")
    
    def compute_all_alignments(self, child_turn: str, adult_turn: str) -> Dict[str, float]:
        """
        Score one child/adult pair with every calculator.
        
        Args:
            child_turn: Child's utterance
            adult_turn: Adult's utterance
            
        Returns:
            Dictionary mapping 'lexical_alignment', 'syntactic_alignment'
            and 'semantic_alignment' to the score each calculator returns
        """
        lexical = self.lexical_calc.compute_alignment(child_turn, adult_turn)
        syntactic = self.syntactic_calc.compute_alignment(child_turn, adult_turn)
        semantic = self.semantic_calc.compute_alignment(child_turn, adult_turn)
        return dict(
            lexical_alignment=lexical,
            syntactic_alignment=syntactic,
            semantic_alignment=semantic,
        )
    
    def compute_all_alignments_detailed(self, child_turn: str, adult_turn: str) -> Dict:
        """
        Collect the detailed breakdown of every calculator for one pair.
        
        Returns:
            Dictionary with keys 'lexical', 'syntactic' and 'semantic', each
            holding that calculator's detailed result
        """
        lexical = self.lexical_calc.compute_alignment_detailed(child_turn, adult_turn)
        syntactic = self.syntactic_calc.compute_alignment_detailed(child_turn, adult_turn)
        semantic = self.semantic_calc.compute_alignment_detailed(child_turn, adult_turn)
        return dict(lexical=lexical, syntactic=syntactic, semantic=semantic)
    
    def compute_batch(self, child_turns: List[str], adult_turns: List[str]) -> Dict[str, np.ndarray]:
        """
        Score many child/adult pairs at once.
        
        Returns:
            Dictionary of arrays with alignment scores, one entry per pair
            
        Raises:
            ValueError: If the two lists differ in length
        """
        if len(child_turns) != len(adult_turns):
            raise ValueError("Number of child and adult turns must match")
        
        # Lexical and syntactic scores are produced pair by pair ...
        lexical_scores = np.array(
            list(map(self.lexical_calc.compute_alignment, child_turns, adult_turns))
        )
        syntactic_scores = np.array(
            list(map(self.syntactic_calc.compute_alignment, child_turns, adult_turns))
        )
        
        # ... while the semantic calculator receives the whole batch.
        semantic_scores = self.semantic_calc.compute_alignment_batch(
            child_turns, adult_turns
        )
        
        return dict(
            lexical_alignment=lexical_scores,
            syntactic_alignment=syntactic_scores,
            semantic_alignment=semantic_scores,
        )



EXAMPLE 1: Individual Calculators

--- Lexical Alignment (Overlap Rate) ---

Child: I like dinosaurs! They're big and scary!
Adult: Dinosaurs are fascinating! Which dinosaur is your favorite?
Lexical alignment: 0.125

Child: How do birds fly?
Adult: Great question! Birds fly by flapping their wings to push air down.
Lexical alignment: 0.200

Child: I don't understand fractions.
Adult: Let me help you understand fractions with an example.
Lexical alignment: 0.286

Child: Can we play outside?
Adult: Not right now, but we can after you finish your homework.
Lexical alignment: 0.000

Detailed breakdown:
  Shared lemmas: ['dinosaur']
  Child lemmas: ['like', 'dinosaur', 'big', 'scary']
  Adult lemmas: ['dinosaur', 'fascinating', 'dinosaur', 'favorite']
  Overlap: 1 / (4 + 4) = 0.125

--- Syntactic Alignment (POS Bigrams) ---

Child: I like dinosaurs! They're big and scary!
Adult: Dinosaurs are fascinating! Which dinosaur is your favorite?
Syntactic alignment: 0.071

Child: How do birds fly?

In [None]:

# ============================================================================
# USAGE EXAMPLES
# ============================================================================

if __name__ == "__main__":
    # Demo driver: runs each calculator on a handful of sample exchanges and
    # prints the scores. Every section prints the same text as before; the
    # related lines of a section are simply grouped into one print call.
    
    # Sample child-adult conversation pairs used by every section below
    examples = [
        dict(
            child="I like dinosaurs! They're big and scary!",
            adult="Dinosaurs are fascinating! Which dinosaur is your favorite?",
        ),
        dict(
            child="How do birds fly?",
            adult="Great question! Birds fly by flapping their wings to push air down.",
        ),
        dict(
            child="I don't understand fractions.",
            adult="Let me help you understand fractions with an example.",
        ),
        dict(
            child="Can we play outside?",
            adult="Not right now, but we can after you finish your homework.",
        ),
    ]
    
    # Horizontal rule reused by every section banner
    rule = "=" * 70
    
    print(rule, "EXAMPLE 1: Individual Calculators", rule, sep="\n")
    
    # ---- 1. Lexical alignment (overlap rate) -------------------------------
    print("\n--- Lexical Alignment (Overlap Rate) ---")
    lexical_calc = LexicalAlignmentCalculator(exclude_stopwords=True)
    
    for ex in examples:
        score = lexical_calc.compute_alignment(ex['child'], ex['adult'])
        print(
            f"\nChild: {ex['child']}",
            f"Adult: {ex['adult']}",
            f"Lexical alignment: {score:.3f}",
            sep="\n",
        )
    
    # Breakdown for the first pair
    detailed = lexical_calc.compute_alignment_detailed(
        examples[0]['child'], examples[0]['adult']
    )
    shared_n = detailed['num_shared']
    child_n = detailed['num_child_lemmas']
    adult_n = detailed['num_adult_lemmas']
    print(
        "\nDetailed breakdown:",
        f"  Shared lemmas: {detailed['shared_lemmas']}",
        f"  Child lemmas: {detailed['child_lemmas']}",
        f"  Adult lemmas: {detailed['adult_lemmas']}",
        f"  Overlap: {shared_n} / ({child_n} + {adult_n}) = {detailed['alignment']:.3f}",
        sep="\n",
    )
    
    # ---- 2. Syntactic alignment (POS bigrams) ------------------------------
    print("\n" + rule, "--- Syntactic Alignment (POS Bigrams) ---", sep="\n")
    syntactic_calc = SyntacticAlignmentCalculator()
    
    for ex in examples:
        score = syntactic_calc.compute_alignment(ex['child'], ex['adult'])
        print(
            f"\nChild: {ex['child']}",
            f"Adult: {ex['adult']}",
            f"Syntactic alignment: {score:.3f}",
            sep="\n",
        )
    
    # Breakdown for the second pair
    detailed = syntactic_calc.compute_alignment_detailed(
        examples[1]['child'], examples[1]['adult']
    )
    shared_n = detailed['num_shared_bigrams']
    child_n = detailed['num_child_bigrams']
    adult_n = detailed['num_adult_bigrams']
    print(
        "\nDetailed breakdown:",
        f"  Child POS tags: {detailed['child_pos_tags']}",
        f"  Adult POS tags: {detailed['adult_pos_tags']}",
        f"  Child POS bigrams: {detailed['child_pos_bigrams']}",
        f"  Adult POS bigrams: {detailed['adult_pos_bigrams']}",
        f"  Shared bigrams: {detailed['shared_bigrams']}",
        f"  Overlap: {shared_n} / ({child_n} + {adult_n}) = {detailed['alignment']:.3f}",
        sep="\n",
    )
    
    # ---- 3. Semantic alignment ---------------------------------------------
    print("\n" + rule, "--- Semantic Alignment ---", sep="\n")
    semantic_calc = SemanticAlignmentCalculator(model_name="all-MiniLM-L6-v2")
    
    for ex in examples:
        score = semantic_calc.compute_alignment(ex['child'], ex['adult'])
        print(
            f"\nChild: {ex['child']}",
            f"Adult: {ex['adult']}",
            f"Semantic alignment: {score:.3f}",
            sep="\n",
        )
    
    # ---- 4. Unified suite --------------------------------------------------
    print("\n" + rule, "EXAMPLE 2: Unified Alignment Suite", rule, sep="\n")
    
    suite = LinguisticAlignmentSuite(
        spacy_model="en_core_web_sm",
        semantic_model="all-MiniLM-L6-v2",
        exclude_stopwords=True,
    )
    
    for ex in examples:
        alignments = suite.compute_all_alignments(ex['child'], ex['adult'])
        print(
            f"\nChild: {ex['child']}",
            f"Adult: {ex['adult']}",
            f"  Lexical:   {alignments['lexical_alignment']:.3f}",
            f"  Syntactic: {alignments['syntactic_alignment']:.3f}",
            f"  Semantic:  {alignments['semantic_alignment']:.3f}",
            sep="\n",
        )
    
    # ---- 5. Batch processing -----------------------------------------------
    print("\n" + rule, "EXAMPLE 3: Batch Processing", rule, sep="\n")
    
    child_turns = [ex['child'] for ex in examples]
    adult_turns = [ex['adult'] for ex in examples]
    
    batch_results = suite.compute_batch(child_turns, adult_turns)
    
    print(
        "\nBatch results:",
        f"Lexical alignments:   {batch_results['lexical_alignment']}",
        f"Syntactic alignments: {batch_results['syntactic_alignment']}",
        f"Semantic alignments:  {batch_results['semantic_alignment']}",
        sep="\n",
    )
    
    # ---- 6. Comprehensive analysis of the first pair -----------------------
    print("\n" + rule, "EXAMPLE 4: Comprehensive Detailed Analysis", rule, sep="\n")
    
    detailed_all = suite.compute_all_alignments_detailed(
        examples[0]['child'], examples[0]['adult']
    )

    print(
        f"\nChild: {examples[0]['child']}",
        f"Adult: {examples[0]['adult']}",
        f"\nLexical alignment: {detailed_all['lexical']['alignment']:.3f}",
        f"  Shared lemmas: {detailed_all['lexical']['shared_lemmas']}",
        f"\nSyntactic alignment: {detailed_all['syntactic']['alignment']:.3f}",
        f"  Shared POS bigrams: {detailed_all['syntactic']['shared_bigrams']}",
        f"\nSemantic alignment: {detailed_all['semantic']['alignment']:.3f}",
        f"  Embedding dimension: {detailed_all['semantic']['embedding_dim']}",
        sep="\n",
    )