In [None]:
import pandas as pd
import numpy as np
import ast
import random
from typing import List, Tuple, Optional
from tqdm import tqdm

In [None]:
def safe_eval(val):
    """Parse a string as a Python literal; return any non-string value as-is.

    Strings are handed to ``ast.literal_eval``, so only literal syntax
    (lists, tuples, numbers, strings, ...) is accepted and any error it
    raises for malformed input propagates to the caller.
    """
    return ast.literal_eval(val) if isinstance(val, str) else val

In [None]:
import pandas as pd
import numpy as np
import ast
import random
from typing import List, Tuple, Optional
from tqdm import tqdm

def safe_eval(val):
    """Turn the string form of a Python literal (e.g. a list) back into a value.

    Values that are not strings are passed through untouched. Strings are
    parsed with ``ast.literal_eval``; whatever it raises on malformed input
    is not caught here.
    """
    if not isinstance(val, str):
        # Already a Python object (or a non-string scalar): nothing to parse.
        return val
    return ast.literal_eval(val)

def find_valid_segment(docs: List[str], actions: List[str], segment_length: int) -> Optional[Tuple[List[str], List[str]]]:
    """
    Randomly pick a window of `segment_length` consecutive steps that contains
    a 'gen_summ' action immediately followed by a 'summ_gen' action.

    Every window position is examined. A window qualifies when:
      * somewhere inside it 'gen_summ' is directly followed by 'summ_gen',
      * its first action is not 'summ_gen', and
      * its last action is not 'gen_summ'.

    Args:
        docs: List of document IDs
        actions: List of corresponding actions
        segment_length: Desired length of the segment

    Returns:
        Tuple of (segment_docs, segment_actions) chosen uniformly at random
        (via `random.choice`) among the qualifying windows, or None if `docs`
        is shorter than `segment_length` or no window qualifies
    """
    total = len(docs)
    if total < segment_length:
        return None

    candidates = []
    for offset in range(total - segment_length + 1):
        stop = offset + segment_length
        window_actions = actions[offset:stop]

        # Look for the adjacent ('gen_summ', 'summ_gen') pair anywhere in the window.
        has_pair = any(
            window_actions[k:k + 2] == ['gen_summ', 'summ_gen']
            for k in range(len(window_actions) - 1)
        )
        if not has_pair:
            continue

        # Reject windows that open with 'summ_gen' or close with 'gen_summ'.
        # Safe to index here: a window containing the pair has at least two items.
        if window_actions[0] == 'summ_gen' or window_actions[-1] == 'gen_summ':
            continue

        candidates.append((docs[offset:stop], window_actions))

    if not candidates:
        return None
    return random.choice(candidates)

def _splice_segments(
    target_docs: List[str],
    target_actions: List[str],
    segments: List[Tuple[List[str], List[str]]],
    start: int,
    beta: int,
    gap: int,
) -> Tuple[List[str], List[str]]:
    """Overlay `segments` onto a target trajectory, beginning at index `start`.

    The result is the target prefix `[:start]`, followed by each segment in
    turn, each followed by up to `gap` steps of the target taken from the
    same index the output has reached. Nothing is appended after the last
    segment/gap, and the result never exceeds `beta` steps.
    """
    docs = list(target_docs[:start])
    actions = list(target_actions[:start])
    # `position` is both the current output length and the index into the
    # target from which the next gap is read (segments overwrite target steps).
    position = start

    for segment_docs, segment_actions in segments:
        if position >= beta:
            break

        # Truncate the segment so the output stays within `beta`.
        take = min(len(segment_docs), beta - position)
        docs.extend(segment_docs[:take])
        actions.extend(segment_actions[:take])
        position += take

        # Resume the target for `gap` steps, but only if a full gap still fits
        # under `beta` and the target has steps left at this position.
        if position + gap <= beta and position < len(target_docs):
            gap_end = min(position + gap, len(target_docs))
            docs.extend(target_docs[position:gap_end])
            actions.extend(target_actions[position:gap_end])
            position = gap_end

    # Final cap: the prefix alone may already be longer than `beta`.
    return docs[:beta], actions[:beta]


def augment_synthetic_data(
    input_csv: str,
    output_csv: str,
    beta: int = 1000,           # Maximum trajectory length
    num_sources: int = 4,       # Number of source trajectories to sample from
    min_segment_length: int = 50,
    max_segment_length: int = 100,
    gap: int = 3               # Gap between segment insertions
) -> None:
    """
    Augment trajectories with one segment from each source trajectory.

    For every row, up to `num_sources` other rows are sampled, one valid
    segment (see `find_valid_segment`) of random length is taken from each,
    and the segments are spliced into the row's trajectory at a random
    starting position. The row's 'Docs' and 'Action' columns are replaced by
    the string form of the augmented lists and 'Summaries' is set to the
    number of 'summ_gen' actions in the result. Rows for which no source
    yields a valid segment are left unchanged.

    Source segments are always taken from the trajectories as loaded from
    `input_csv`: all write-backs are deferred until every row has been
    processed, so the outcome does not depend on row order.

    Args:
        input_csv: Path to input CSV file (must provide 'Docs' and 'Action' columns)
        output_csv: Path to output CSV file (written without the index)
        beta: Maximum allowed trajectory length
        num_sources: Number of source trajectories to sample from (each provides
            at most one segment); capped at the number of other rows available
        min_segment_length: Minimum length of sampled segments
        max_segment_length: Maximum length of sampled segments (inclusive)
        gap: Gap between segment insertions
    """
    print(f"Loading data from `{input_csv}`")
    synthetic_df = pd.read_csv(input_csv)
    print("Data loaded successfully.")

    # Augmented rows are collected here and written back only after the loop,
    # so the frame is never modified while it is being iterated and sampled.
    pending_updates = []

    for i, row in tqdm(synthetic_df.iterrows(), total=synthetic_df.shape[0]):
        # Extract the target trajectory
        target_docs = safe_eval(row['Docs'])
        target_actions = safe_eval(row['Action'])

        # Sample source trajectories from the other rows. `sample` raises if
        # asked for more rows than exist, so cap the request.
        other_rows = synthetic_df.drop(index=i)
        n_sources = min(num_sources, len(other_rows))
        if n_sources == 0:
            continue
        source_trajectories = other_rows.sample(n=n_sources)

        # Extract one segment from each source trajectory
        segments = []
        for _, srow in source_trajectories.iterrows():
            source_docs = safe_eval(srow['Docs'])
            source_actions = safe_eval(srow['Action'])

            # Random length in [min_segment_length, max_segment_length]
            segment_length = np.random.randint(min_segment_length, max_segment_length + 1)

            segment = find_valid_segment(source_docs, source_actions, segment_length)
            if segment:
                segments.append(segment)

        if not segments:  # Skip if no valid segments found
            continue

        # Random starting position inside the target (0 for an empty target)
        starting_position = np.random.randint(0, max(1, len(target_docs)))

        augmented_docs, augmented_actions = _splice_segments(
            target_docs, target_actions, segments, starting_position, beta, gap
        )

        pending_updates.append(
            (i, str(augmented_docs), str(augmented_actions), augmented_actions.count('summ_gen'))
        )

    # Apply all augmentations now that sampling is finished.
    for i, docs_repr, actions_repr, summary_count in pending_updates:
        synthetic_df.at[i, 'Docs'] = docs_repr
        synthetic_df.at[i, 'Action'] = actions_repr
        synthetic_df.at[i, 'Summaries'] = summary_count

    # Save augmented dataset
    synthetic_df.to_csv(output_csv, index=False)
    print(f"Augmented dataset saved to `{output_csv}`")

# Example usage: runs only when this file is executed as a script.
if __name__ == "__main__":
    augment_synthetic_data(
        input_csv="synthetic-original.csv",
        output_csv="synthetic-original_D2_25_100.csv",
        min_segment_length=50,
        max_segment_length=100,
        num_sources=4,
        gap=25,
        beta=100,
    )