# In [None]:
from transformers import pipeline
from cdp_data import datasets, CDPInstances
import numpy as np
from pathlib import Path
import pandas as pd
from tqdm import tqdm

# Seed NumPy's global RNG. pandas `.sample()` without a `random_state` falls back
# to this global state, so the sampling further down is repeatable.
np.random.seed(60)

# Directory holding the fine-tuned window classifier and its tokenizer.
MODEL_DIR = "../whole-pc-section-window-classifier"

########################################################################################

# Build the text-classification pipeline from MODEL_DIR (model and tokenizer are
# loaded from the same directory). padding/truncation are passed through the
# pipeline kwargs — presumably so over-long context windows are truncated to the
# model's max length instead of failing; confirm against the transformers version.
clf = pipeline(
    task="text-classification",
    tokenizer=MODEL_DIR,
    model=MODEL_DIR,
    truncation=True,
    padding=True,
)

########################################################################################
# Functions to handle segmentation evaluation

def create_positions_from_indices(session_annotations: pd.DataFrame, total_transcript_length: int) -> list[int]:
    """Convert comment-period start/end indices into a "positions" list.

    The positions format assigns every sentence of the transcript an integer id
    for the section it belongs to, e.g. ``[1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]``
    describes three sections. Every start and every end index is treated as a
    section boundary: sentences before the first (sorted) boundary get id 1,
    sentences from the first boundary up to the second get id 2, and so on.

    Args:
        session_annotations: Rows for one session with the columns
            ``period_start_sentence_index`` and ``period_end_sentence_index``.
            Negative values (``-1`` is used as "no comment period") and missing
            values are ignored rather than treated as boundaries.
        total_transcript_length: Number of sentences in the transcript. The
            returned list always has exactly this many entries; boundaries past
            the end of the transcript are clamped to it.

    Returns:
        One section id per sentence. Ids increase by one per boundary, so a
        zero-length section (duplicate boundaries, or a boundary at 0) consumes
        an id without contributing any sentences.
    """
    raw_indices = list(session_annotations["period_start_sentence_index"]) + list(
        session_annotations["period_end_sentence_index"]
    )

    # Drop the -1 "no period" sentinel (and NaN, for which `>= 0` is False) so a
    # mix of sentinel and real rows cannot create negative section lengths, and
    # clamp to the transcript so the output length always matches it.
    boundaries = sorted(
        min(int(index), total_transcript_length)
        for index in raw_indices
        if index >= 0
    )

    # No real boundaries: the whole transcript is a single section.
    if not boundaries:
        return [1] * total_transcript_length

    # The final section runs to the end of the transcript.
    boundaries.append(total_transcript_length)

    positions: list[int] = []
    previous_boundary = 0
    for section_id, boundary in enumerate(boundaries, start=1):
        positions.extend([section_id] * (boundary - previous_boundary))
        previous_boundary = boundary

    return positions


########################################################################################
# Fetch the Seattle session dataset between 2020-01-01 and 2024-01-01, asking
# cdp_data to also store each transcript as a CSV (`transcript_as_csv_path`).
# NOTE(review): `sample=5` presumably limits this to 5 sessions — confirm in cdp_data.
ds = datasets.get_session_dataset(
    CDPInstances.Seattle,
    start_datetime="2020-01-01",
    end_datetime="2024-01-01",
    store_transcript=True,
    store_transcript_as_csv=True,
    sample=5,
)

# Directory that receives the trimmed per-session transcript copies
storage_dir = Path("seattle-transcripts/")
storage_dir.mkdir(exist_ok=True)

# Write <session id>.csv for every session, keeping only the sentence index
# (renamed to sentence_index) and text, tagged with session and council ids.
for _, row in ds.iterrows():
    transcript = (
        pd.read_csv(row.transcript_as_csv_path)
        .loc[:, ["index", "text"]]
        .rename(columns={"index": "sentence_index"})
    )
    transcript["session_id"] = row["id"]
    transcript["council"] = CDPInstances.Seattle
    transcript_copy_path = storage_dir / f"{row['id']}.csv"
    transcript.to_csv(transcript_copy_path, index=False)

# Resolved (human) comment-period annotations used as the reference segmentation
annotations = pd.read_csv("training-data/whole-period-seg-seattle.csv")

def get_context_window_text(transcript, center_index, before=1, after=1) -> str:
    """Return the text of the sentences around ``center_index`` joined by spaces.

    The window covers row positions ``center_index - before`` through
    ``center_index + after`` inclusive, clipped to the transcript. Rows are
    selected positionally with ``.iloc``, not by index label.

    Args:
        transcript: DataFrame with a ``text`` column of strings (callers fill
            missing text with ``""`` first; NaN entries would break the join).
        center_index: Row position of the sentence the window is centred on.
        before: Number of preceding sentences to include (default 1).
        after: Number of following sentences to include (default 1).

    Returns:
        The joined window text with surrounding whitespace stripped.
    """
    start = max(center_index - before, 0)
    # iloc slice ends are exclusive and clipped to the frame length, so no
    # upper clamp is needed; clamping to len - 1 would drop the last sentence.
    stop = center_index + after + 1
    return " ".join(transcript.iloc[start:stop]["text"]).strip()

# Classify a random sample of annotated sessions:
#   1. Load each sampled session's transcript CSV from seattle-transcripts/.
#   2. Build a context window around every sentence and classify it.
#   3. Pair each predicted "comment-period-start" with the next
#      "comment-period-end" and record it with the same columns as the
#      annotations (session_id, period_start_sentence_index,
#      period_end_sentence_index), then save to classified-windows.csv.
# Converting these rows to positions/masses and comparing them with segeval is
# not performed in this block.

# Number of distinct annotated sessions to classify.
N_SAMPLED_SESSIONS = 5

# Annotations can hold several rows per session (create_positions_from_indices
# consumes multiple rows per session), so sample distinct session ids: sampling
# raw rows could pick one session twice and duplicate its predictions. Never
# request more sessions than exist (Series.sample raises otherwise).
unique_session_ids = annotations["session_id"].drop_duplicates()
sampled_session_ids = unique_session_ids.sample(
    n=min(N_SAMPLED_SESSIONS, len(unique_session_ids))
)

classified_windows = []
for session_id in tqdm(
    sampled_session_ids,
    desc="Processing sessions",
    total=len(sampled_session_ids),
):
    # NOTE(review): the transcript must already exist in seattle-transcripts/;
    # the export loop above only writes transcripts for the sessions in `ds`,
    # which are not necessarily the annotated sessions — confirm.
    transcript = pd.read_csv(f"seattle-transcripts/{session_id.strip()}.csv")
    transcript["text"] = transcript["text"].fillna("")

    # get_context_window_text slices with .iloc, so centre each window on the
    # sentence's row position rather than on its stored sentence_index label.
    context_windows = [
        get_context_window_text(transcript, position)
        for position in range(len(transcript))
    ]

    # Classify all context windows (no pipeline call for an empty transcript)
    results = clf(context_windows) if context_windows else []

    # Collect this session's predicted periods; indices are row positions.
    session_periods = []
    current_pc_start = None
    for i, result in enumerate(results):
        label = result["label"]
        if label == "comment-period-start":
            # A later start replaces an earlier start that was never closed.
            current_pc_start = i
        elif label == "comment-period-end" and current_pc_start is not None:
            # An end closes the open period; ends without an open start are ignored.
            session_periods.append({
                "session_id": session_id,
                "period_start_sentence_index": current_pc_start,
                "period_end_sentence_index": i,
            })
            current_pc_start = None

    # Add a -1 placeholder row when THIS session has no predicted periods.
    # (Checking the accumulated list instead would only ever add a placeholder
    # for the first such session.)
    if not session_periods:
        session_periods.append({
            "session_id": session_id,
            "period_start_sentence_index": -1,
            "period_end_sentence_index": -1,
        })

    classified_windows.extend(session_periods)

# Convert to dataframe (explicit columns keep the header even if nothing was
# sampled) and save
classified_windows = pd.DataFrame(
    classified_windows,
    columns=[
        "session_id",
        "period_start_sentence_index",
        "period_end_sentence_index",
    ],
)
classified_windows.to_csv("classified-windows.csv", index=False)