# Testing sample for the PallierListen2023 code

# In [None]:
from jra_utils import approx_match_samples
from jr_utils import match_list


# Audio file of each run, looked up as ``CHAPTER_PATHS[int(run) - 1]`` (runs
# are numbered from 1). Each file name encodes the (presumably chapter) range
# ``ch<first>-<last>.wav`` heard during that run.
CHAPTER_PATHS = [
    f"ch{first}-{last}.wav"
    for first, last in (
        (1, 3),
        (4, 6),
        (7, 9),
        (10, 12),
        (13, 14),
        (15, 19),
        (20, 22),
        (23, 25),
        (26, 27),
    )
]

# Per-run overrides of the ``approx_match_samples`` parameters used to align
# word annotations on the word triggers. Keys are ``(subject, run)`` as ints,
# values are ``(abs_tol, max_missing)``. Runs that are not listed fall back to
# the caller's default:
#     abs_tol, max_missing = TOL_MISSING_DICT.get((subject, run), (10, 5))
TOL_MISSING_DICT = {
    # Runs that need a looser matching tolerance.
    **dict.fromkeys(
        [
            (9, 6),
            (10, 6),
            (12, 5),
            (13, 3),
            (13, 7),
            (14, 9),
            (21, 6),
            (21, 8),
            (22, 4),
            (33, 2),
            (39, 5),
            (40, 2),
            (41, 1),
            (43, 4),
            (43, 5),
            (44, 9),
        ],
        (30, 5),
    ),
    # Default tolerance, but a larger ``max_missing``.
    (24, 2): (10, 20),
}


def _get_seq_id_path(self):
    return self.path / f"sourcedata/task-{self.task}_run-{self.run}_extra_info.tsv"

def _get_syntax_path(self):
    return (
        self.path
        / f"sourcedata/stimuli/run{self.run}_v2_0.25_0.5-tokenized.syntax.txt"
    )

def _get_word_info_path(self):
    return str(self._get_bids_path()).replace("meg.fif", "events.tsv")

def _load_events(self) -> pd.DataFrame:
    """
    Redefine this method in the subclasses
    for both listen and read
    """
    raise NotImplementedError


# # # # # actual studies # # # # #


class PallierListen2023(_Pallier2023Base):
    """Listening ("listen" task) variant of the Pallier 2023 recordings.

    Only ``task`` and ``_load_events`` are defined here; the attributes and
    helpers they use (``_load_raw``, ``_get_seq_id_path``, ``path``,
    ``subject``, ``session``, ``run``, ``timeline``) are expected to come from
    ``_Pallier2023Base``.
    """

    task: tp.ClassVar[str] = "listen"

    def _load_events(self) -> pd.DataFrame:
        """Rebuild the events DataFrame of the current run from its raw file.

        The original rich event DataFrame was stored inside the ``*.fif`` file
        as MNE annotations, each description being a Python expression that
        evaluates to one event dict. This method converts those annotations
        back into a DataFrame, then:

        * attaches ``sequence_id``, ``n_closing``, ``is_last_word`` and
          ``pos`` from the run's metadata file, matched word by word;
        * re-times the ``Word`` events on the steps of ``STI008`` that go to
          0 (the annotation onset is kept in ``unaligned_start``);
        * adds a ``Sound`` event spanning the two value-1 events of
          ``STI101`` (skipped with a warning when they cannot be read) and a
          ``Meg`` event whose ``filepath`` is a ``method:`` URI to ``_load_raw``.

        Returns
        -------
        pd.DataFrame
            All events sorted by ``start`` (seconds), with a fresh RangeIndex.

        Raises
        ------
        AssertionError
            If 95% or fewer of the annotated words are found in the metadata.
        """
        error_msg_prefix = (
            f"subject {self.subject}, session {self.session}, run {self.run}\n"
        )

        raw = self._load_raw(self.timeline)
        sfreq = raw.info["sfreq"]

        # -- Sound event: start/stop are the value-1 events of STI101. ------
        # Best-effort: if there are not exactly two such events, or the run
        # has no chapter file, warn and continue without a Sound event.
        sound_event = None
        sound_triggers = mne.find_events(raw, stim_channel="STI101", shortest_event=1)
        try:
            sound_triggers = sound_triggers[sound_triggers[:, 2] == 1]
            start, stop = sound_triggers[:, 0] / sfreq
            sound_event = dict(
                type="Sound",
                start=start,
                duration=stop - start,
                filepath=Path(self.path)
                / "sourcedata/stimuli/audio"
                / CHAPTER_PATHS[int(self.run) - 1],
            )
        except (ValueError, IndexError) as e:
            warnings.warn(
                f"No sound triggers found for subject {self.subject}, run {self.run}: {e}"
            )

        # -- Word-level events, decoded from the raw annotations. -----------
        events = []
        for annot in raw.annotations:
            description = annot.pop("description")
            # Skip MNE's BAD_ACQ_SKIP annotations: they hold no serialised event.
            if "BAD_ACQ_SKIP" in description:
                continue
            # NOTE(review): ``eval`` runs whatever code is stored in the .fif
            # annotations; acceptable only for trusted, self-produced files.
            # ``ast.literal_eval`` would be safer if descriptions are plain
            # literals -- confirm before switching.
            event = eval(description)
            event["condition"] = "sentence"
            event["type"] = event.pop("kind").capitalize()
            event["start"] = annot["onset"]
            event["duration"] = annot["duration"]
            event["stop"] = annot["onset"] + annot["duration"]
            event["language"] = "french"
            events.append(event)

        events_df = pd.DataFrame(events).rename(columns=dict(word="text"))
        # Empty words (" ") were included in the metadata files: treat them as
        # missing, then drop every row without text.
        events_df.loc[events_df["text"] == " ", "text"] = None
        events_df = events_df.dropna(subset=["text"]).reset_index(drop=True)

        # -- Per-word metadata, matched on the word strings. ----------------
        # NOTE(review): the file has a .tsv extension but is parsed with the
        # default comma separator -- confirm the on-disk format.
        metadata = pd.read_csv(self._get_seq_id_path())
        rows_events, rows_metadata = match_list(
            [str(word) for word in events_df["text"].values],
            [str(word) for word in metadata["word"].values],
        )
        match_ratio = len(rows_events) / len(events_df) if len(events_df) else 0.0
        assert match_ratio > 0.95, (
            error_msg_prefix
            + f"only {match_ratio} of the words were found in the metadata"
        )
        events_idx = events_df.index[rows_events]
        metadata_idx = metadata.index[rows_metadata]

        # Add sequence_id, n_closing, is_last_word and pos. ``.to_numpy()`` is
        # required: assigning a Series through ``.loc`` aligns on index labels,
        # which would give event row k the values of metadata row k instead
        # of those of its matched metadata row.
        events_df["word"] = events_df["text"]
        for col in ["sequence_id", "n_closing", "is_last_word", "pos"]:
            events_df.loc[events_idx, col] = metadata.loc[metadata_idx, col].to_numpy()

        # add train/test/val splits
        events_df = set_sentence_split(events_df)  # TODO

        # -- Re-align word onsets on the STI008 word triggers. --------------
        words = events_df.loc[events_df.type == "Word"]

        # find_stim_steps rows are [sample, value_from, value_to]; keep the
        # steps that go back to 0 (end of each trigger pulse).
        word_triggers = mne.find_stim_steps(raw, stim_channel="STI008")
        word_triggers = word_triggers[word_triggers[:, 2] == 0]

        abs_tol, max_missing = TOL_MISSING_DICT.get(
            (int(self.subject), int(self.run)), (10, 5)
        )
        # Both sequences are in samples: annotation onsets (seconds) are
        # converted with the recording's sfreq, matching the back-conversion
        # below (previously hard-coded as 1000, i.e. assuming 1 kHz).
        i, j = approx_match_samples(
            (words.start * sfreq).tolist(),
            word_triggers[:, 0],
            abs_tol=abs_tol,
            max_missing=max_missing,
        )
        found_ratio = len(i) / len(words) if len(words) else 0.0
        print(f"Found {found_ratio} of the words in the triggers")

        words = words.iloc[i, :]

        events_df["unaligned_start"] = events_df["start"]
        # NOTE(review): ``stop`` still holds the annotation-based value for
        # re-timed words -- confirm whether it should be shifted as well.
        events_df.loc[words.index, "start"] = word_triggers[j, 0] / sfreq

        # Add sentence / constituent info
        events_df = enrich_metadata(events_df)

        # -- Non-word events. -----------------------------------------------
        # Added last: a Sound row has no text, so it would otherwise be
        # removed by ``dropna(subset=["text"])`` and would be fed to the word
        # matching and enrichment steps above.
        meg = {
            "filepath": f"method:_load_raw?timeline={self.timeline}",
            "type": "Meg",
            "start": 0,
        }
        extra_events = [meg] if sound_event is None else [meg, sound_event]
        events_df = pd.concat([pd.DataFrame(extra_events), events_df])

        # sort by start
        return events_df.sort_values(by="start").reset_index(drop=True)

