# In [None]:
import pandas as pd
import numpy as np
import ast
import random
from typing import List, Tuple, Optional
from tqdm import tqdm

def safe_eval(val):
    """Turn a string holding a Python literal (e.g. a list) back into the object.

    Anything that is not a ``str`` is handed back untouched, so values that are
    already parsed pass straight through. Parsing goes through
    ``ast.literal_eval``, which accepts only literals and never executes code;
    a malformed string makes it raise (e.g. ``ValueError``/``SyntaxError``).
    """
    if not isinstance(val, str):
        return val
    return ast.literal_eval(val)

def find_valid_segment(docs: List[str], actions: List[str], seg_len: int) -> Optional[Tuple[List[str], List[str]]]:
    """
    Pick a random window of ``seg_len`` aligned steps that contains a summary pair.

    A window ``docs[s:s + seg_len]`` / ``actions[s:s + seg_len]`` is valid when
    its actions contain the consecutive pair ``['gen_summ', 'summ_gen']``
    somewhere, do not start with ``'summ_gen'`` and do not end with
    ``'gen_summ'`` (so no half of a pair is left dangling on either edge).

    Args:
        docs: Per-step documents of the source trajectory.
        actions: Per-step actions, aligned index-for-index with ``docs``.
        seg_len: Window length.

    Returns:
        A ``(seg_docs, seg_actions)`` tuple chosen with ``random.choice`` among
        all valid windows, or ``None`` if there is none.
    """
    # Only the aligned prefix is usable; this keeps seg_docs and seg_actions
    # the same length even if the two input lists differ in length.
    n = min(len(docs), len(actions))
    # A window shorter than 2 cannot hold the pair, and a non-positive length
    # would otherwise produce slices with a negative stop.
    if seg_len < 2 or n < seg_len:
        return None

    pair = ('gen_summ', 'summ_gen')
    valid_seg = []
    for start in range(n - seg_len + 1):
        seg_actions = actions[start:start + seg_len]
        # Edge constraints are independent of where the pair sits: test once.
        if seg_actions[0] == 'summ_gen' or seg_actions[-1] == 'gen_summ':
            continue
        if any(step == pair for step in zip(seg_actions, seg_actions[1:])):
            valid_seg.append((docs[start:start + seg_len], seg_actions))

    return random.choice(valid_seg) if valid_seg else None

def _pair_safe_cutoff(actions: List[str], limit: int) -> int:
    """Return a cut index <= ``limit`` that never splits a 'gen_summ'/'summ_gen' pair.

    If ``actions[:limit]`` would end on a 'gen_summ' whose 'summ_gen' sits at
    index ``limit``, the cut moves one step earlier so the pair stays whole
    (both halves end up on the same side of the cut).
    """
    if (0 < limit < len(actions)
            and actions[limit - 1] == 'gen_summ'
            and actions[limit] == 'summ_gen'):
        return limit - 1
    return limit


def _build_augmented(
    target_docs: List[str],
    target_actions: List[str],
    segments: List[Tuple[List[str], List[str]]],
    start_pos: int,
    beta: int,
    gap: int,
) -> Tuple[List[str], List[str]]:
    """Interleave ``segments`` into the target trajectory, capped at ``beta`` steps.

    Layout: ``target[:start_pos]``, segment 1, the next ``gap`` target steps,
    segment 2, the next ``gap`` target steps, ..., then the remaining target
    steps, all limited to ``beta`` steps in total. Every boundary (insertion
    point, gap end, truncation) goes through ``_pair_safe_cutoff`` so it never
    falls between a 'gen_summ' and the 'summ_gen' right after it.
    """
    # Do not open an insertion point inside a target gen_summ/summ_gen pair.
    start_pos = _pair_safe_cutoff(target_actions, start_pos)
    aug_docs = list(target_docs[:start_pos])
    aug_actions = list(target_actions[:start_pos])
    # Next target step not yet copied. Tracked separately from len(aug_docs),
    # which also counts the inserted source steps.
    target_pos = start_pos

    for seg_docs, seg_actions in segments:
        remain_space = beta - len(aug_docs)
        if remain_space <= 0:
            break
        # Whole segment if it fits; otherwise as much as fits without
        # splitting a pair at the cut.
        cutoff = _pair_safe_cutoff(seg_actions, min(len(seg_docs), remain_space))
        aug_docs.extend(seg_docs[:cutoff])
        aug_actions.extend(seg_actions[:cutoff])

        # Space the next insertion with ``gap`` original steps, if they fit.
        if gap > 0 and len(aug_docs) + gap <= beta and target_pos < len(target_docs):
            rest_actions = target_actions[target_pos:]
            take = _pair_safe_cutoff(rest_actions, min(gap, len(rest_actions)))
            aug_docs.extend(target_docs[target_pos:target_pos + take])
            aug_actions.extend(rest_actions[:take])
            target_pos += take

    # Keep the rest of the target trajectory (this is an insertion, not a
    # replacement), as far as beta allows.
    remain_space = beta - len(aug_docs)
    if remain_space > 0 and target_pos < len(target_docs):
        rest_actions = target_actions[target_pos:]
        take = _pair_safe_cutoff(rest_actions, min(remain_space, len(rest_actions)))
        aug_docs.extend(target_docs[target_pos:target_pos + take])
        aug_actions.extend(rest_actions[:take])

    # The target prefix alone may already exceed beta; trim it pair-safely.
    cut = _pair_safe_cutoff(aug_actions, max(beta, 0))
    return aug_docs[:cut], aug_actions[:cut]


def augment_synthetic_data(
    input_csv: str,
    output_csv: str,
    beta: int = 150,           # maximum trajectory length (for short trajectories)
    num_sources: int = 2,       # number of source trajectories to sample from
    min_seg_len: int = 3,
    max_seg_len: int = 6,
    gap: int = 2              # gap between segment insertions
) -> None:
    """
    Augment each trajectory in ``input_csv`` with segments borrowed from other rows.

    For every row, up to ``num_sources`` other rows are sampled. From each one,
    ``find_valid_segment`` picks a window of random length
    (``min_seg_len``..``max_seg_len``) that contains a ``['gen_summ', 'summ_gen']``
    pair. The windows are spliced into the row's own trajectory at a random
    position, each followed by ``gap`` original steps, and the remainder of the
    original trajectory follows. The result is capped at ``beta`` steps, and no
    insertion point or truncation falls between a 'gen_summ' and the
    'summ_gen' right after it.

    Columns 'Docs' and 'Action' are overwritten with the string form of the
    augmented lists, and 'Summaries' with the number of 'summ_gen' actions.
    Rows for which no valid window is found stay unchanged. Sources are always
    taken from the original, un-augmented trajectories.

    Args:
        input_csv: CSV whose 'Docs' and 'Action' columns hold list literals.
        output_csv: Destination path; written without the index.
        beta: Maximum length of an augmented trajectory.
        num_sources: Number of other rows to sample (capped at ``len(df) - 1``).
        min_seg_len: Minimum window length (inclusive).
        max_seg_len: Maximum window length (inclusive).
        gap: Number of original steps placed after each inserted window.
    """
    synthetic_df = pd.read_csv(input_csv)

    # Parse every trajectory up front. Rows are rewritten in place below, and
    # sources must come from the original data, not from rows already
    # augmented earlier in the loop.
    trajectories = [
        (safe_eval(docs), safe_eval(actions))
        for docs, actions in zip(synthetic_df['Docs'], synthetic_df['Action'])
    ]
    n_rows = len(trajectories)
    # Asking for more sources than there are other rows would fail; cap instead.
    n_sources = max(0, min(num_sources, n_rows - 1))

    for pos, label in enumerate(tqdm(synthetic_df.index, total=n_rows)):
        target_docs, target_actions = trajectories[pos]

        # Draw distinct positions among the other n_rows - 1 rows, then shift
        # those at or after ``pos`` by one so the target row is excluded.
        drawn = (
            np.random.choice(n_rows - 1, size=n_sources, replace=False)
            if n_sources
            else []
        )
        segments = []
        for d in drawn:
            src = int(d) if d < pos else int(d) + 1
            source_docs, source_actions = trajectories[src]
            seg_len = np.random.randint(min_seg_len, max_seg_len + 1)
            seg = find_valid_segment(source_docs, source_actions, seg_len)
            if seg:
                segments.append(seg)

        if not segments:
            # No valid window in any sampled source: keep this row unchanged.
            continue

        start_pos = np.random.randint(0, max(1, len(target_docs)))
        aug_docs, aug_actions = _build_augmented(
            target_docs, target_actions, segments, start_pos, beta, gap
        )

        synthetic_df.at[label, 'Docs'] = str(aug_docs)
        # Write back to the same column the actions were read from, so a
        # row's docs and actions stay in sync.
        synthetic_df.at[label, 'Action'] = str(aug_actions)
        synthetic_df.at[label, 'Summaries'] = aug_actions.count('summ_gen')

    synthetic_df.to_csv(output_csv, index=False)
    print(f"Augmented dataset saved to `{output_csv}`")


if __name__ == "__main__":
    # Settings for a direct script run: trajectories capped at 50 steps,
    # four source rows per target, segments of 2-6 steps, gap=25.
    run_config = dict(
        input_csv="input.csv",
        output_csv="output.csv",
        beta=50,
        num_sources=4,
        min_seg_len=2,
        max_seg_len=6,
        gap=25,
    )
    augment_synthetic_data(**run_config)