5/11 (Sat) | UF Measures

# Automatic Annotation of Temporal Features

## 1. Introduction

This notebook annotates temporal features using the proposed pipeline.
Before starting the automatic annotation, the following code block loads the required packages and defines global variables.
Note that this notebook uses the conventional BERT-based disfluency detector instead of the RoBERTa-based one.

In [1]:
from typing import List, Tuple, Generator
import sys, json, traceback
from pathlib import Path

import pandas as pd
import pickle as pkl
from textgrids import TextGrid
from rev_ai import Transcript

sys.path.append( # TODO: load these modules via the PYTHONPATH environment variable instead
    "/home/matsuura/Development/app/feature_extraction_api"
)
sys.path.append(
    "/home/matsuura/Development/app/feature_extraction_api/app/modules"
)

from app.utils.rev_utils import transcript_2_df, FILLER
from fluency import Turn, DisfluencyEnum, Annotator
from fluency.pipeline.utils.pause_location import PauseLocation

DATA_DIR = Path("/home/matsuura/Development/app/feature_extraction_api/experiment/data")

TASK = ["WoZ_Interview"]

---

## 2. Define Functions

This section defines functions to conduct automatic annotation.
The following code block defines a generator that yields the paths of the Rev transcript JSON files for a task.

In [2]:
def json_path_generator(task: str) -> Generator[Path, None, None]:
    load_dir = DATA_DIR / f"{task}/07_Rev_Json"

    for json_path in load_dir.glob("*.json"):
        yield json_path
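The generator above relies on the directory convention DATA_DIR/<task>/07_Rev_Json/*.json. Path.glob yields files in an order that depends on the file system, so a variant that sorts the paths may make reruns easier to compare. The sketch below is a hypothetical helper (iter_rev_json, not part of the original pipeline) that takes the data directory as a parameter so it can be exercised on a temporary directory.

```python
import tempfile
from pathlib import Path
from typing import Iterator


def iter_rev_json(data_dir: Path, task: str) -> Iterator[Path]:
    """Yield the Rev JSON paths of a task in a stable (sorted) order."""
    load_dir = data_dir / task / "07_Rev_Json"
    yield from sorted(load_dir.glob("*.json"))


# Demo on a throwaway directory that follows the same layout.
with tempfile.TemporaryDirectory() as tmp:
    json_dir = Path(tmp) / "WoZ_Interview" / "07_Rev_Json"
    json_dir.mkdir(parents=True)
    for name in ["b.json", "a.json", "notes.txt"]:
        (json_dir / name).touch()
    found = [p.name for p in iter_rev_json(Path(tmp), "WoZ_Interview")]

print(found)
```

Non-JSON files are ignored by the "*.json" pattern, just as in json_path_generator.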

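The following code block defines a function that preprocesses Rev transcripts. It returns a DataFrame of the lowercased words, the indices of the words immediately followed by a period (counted with fillers excluded), and the indices of the fillers.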

In [3]:
def preprocess_for_turn(
        rev_transcript: Transcript
) -> Tuple[pd.DataFrame, List[int], List[int]]:
    df_rev = transcript_2_df(rev_transcript)

    period_locations = [] # list of word ids immediately preceding a period
    filler_locations = [] # list of word ids of fillers
    idx = -1
    for i in df_rev.index:
        w = df_rev.at[i, "text"]
        t = df_rev.at[i, "type"]

        if t == "text":
            idx += 1
            if w.lower() in FILLER:
                filler_locations.append(idx)
            if " " in w:
                w = w.replace(" ", "_")
                df_rev.at[i, "text"] = w
        else:
            if w == ".":
                period_locations.append(idx - len(filler_locations))

    df_text = df_rev[df_rev["type"] != "punct"].reset_index()
    df_text["text"] = df_text["text"].str.lower()

    if -1 in period_locations:
        period_locations.remove(-1)
        
    period_locations = sorted(set(period_locations))

    return df_text, period_locations, filler_locations
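The index bookkeeping in preprocess_for_turn is the subtle part: a period's position is recorded as the index of the preceding word minus the number of fillers seen so far, presumably so that it still points to the right word once fillers are excluded from the word sequence. The sketch below reproduces that logic on a plain list of (text, type) pairs standing in for the Rev DataFrame. The function name locate_fillers_and_periods and the small filler set are illustrative assumptions; the notebook itself uses the FILLER set imported from app.utils.rev_utils.

```python
from typing import List, Tuple

# Illustrative filler inventory (assumption, not the notebook's FILLER set).
FILLERS = {"um", "uh", "er"}


def locate_fillers_and_periods(
    tokens: List[Tuple[str, str]]
) -> Tuple[List[int], List[int]]:
    """Return (period_locations, filler_locations) for (text, type) tokens."""
    period_locations: List[int] = []
    filler_locations: List[int] = []
    idx = -1  # index of the most recent "text" token
    for text, kind in tokens:
        if kind == "text":
            idx += 1
            if text.lower() in FILLERS:
                filler_locations.append(idx)
        elif text == ".":
            # Subtract the fillers seen so far so the index skips them.
            period_locations.append(idx - len(filler_locations))
    # Deduplicate, drop -1 (a period with no preceding non-filler word), sort.
    positions = set(period_locations)
    positions.discard(-1)
    return sorted(positions), filler_locations


tokens = [
    ("Um", "text"), (" ", "punct"), ("I", "text"), (" ", "punct"),
    ("like", "text"), (" ", "punct"), ("it", "text"), (".", "punct"),
    (" ", "punct"), ("Uh", "text"), (" ", "punct"), ("yes", "text"),
    (".", "punct"),
]
periods, fillers = locate_fillers_and_periods(tokens)
print(periods, fillers)
```

Here the fillers are at indices 0 and 4. Once they are set aside, the remaining words are i, like, it, yes, and the period indices 2 and 3 point to "it" and "yes", the words before each period.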

The following code block defines a function that converts a pandas DataFrame into a Turn object, marks the fillers as disfluencies, and splits the turn into clauses at the period locations.

In [4]:
def df_2_turn(
        df_rev: pd.DataFrame, 
        period_locations: List[int], 
        filler_locations: List[int]
) -> Turn:
    turn = Turn.from_DataFrame(df_rev, 0, word_col="text")

    disfluency_list = [DisfluencyEnum.FILLER for _ in filler_locations]
    turn.clauses[0].annotate_disfluency(filler_locations, disfluency_list)
    turn.reset_words()

    if len(period_locations) != 0:
        turn.separate_clause(0, period_locations[:-1])

    return turn
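In df_2_turn, the last period location is dropped before calling separate_clause. Assuming that separate_clause splits a clause after each given word index (an assumption about the fluency package, not confirmed here), splitting at the final period would only create an empty trailing clause, because that period already ends the turn. The hypothetical split_at function below illustrates the effect on a plain list of words; it does not reproduce the Turn API.

```python
from typing import List


def split_at(words: List[str], end_indices: List[int]) -> List[List[str]]:
    """Split words into segments, each ending at an index in end_indices (inclusive)."""
    segments: List[List[str]] = []
    start = 0
    for end in end_indices:
        segments.append(words[start:end + 1])
        start = end + 1
    segments.append(words[start:])  # whatever follows the last split point
    return segments


words = ["i", "like", "it", "yes"]
period_locations = [2, 3]

with_last_dropped = split_at(words, period_locations[:-1])
with_all_periods = split_at(words, period_locations)
print(with_last_dropped)
print(with_all_periods)
```

Dropping the final index yields two clauses, whereas using every period leaves an empty segment at the end.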

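The following code block defines a function that finds pauses from the forced-alignment (FA) timestamps. A silence of 0.25 seconds or longer is counted as a pause, and each pause is labeled as clause-external (between clauses) or clause-internal (within a clause).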

In [5]:
def find_pauses(turn: Turn) -> List[dict]:
    pauses = []
    prev_clause_end = turn.start_time
    for clause in turn.clauses:
        if clause.start_time - prev_clause_end >= 0.25:
            p = {
                "location": PauseLocation.CLAUSE_EXTERNAL,
                "start_time": prev_clause_end,
                "end_time": clause.start_time
            }
            pauses.append(p)

        prev_word_end = clause.start_time
        for wid, word in enumerate(clause.words):
            if clause.idx == 0 and wid == 0: # first word of the first clause
                if len(clause) == 1: # single-word clause: move on to the next clause
                    continue

                if word.idx == -1 and word.disfluency.name == "FILLER":
                    prev_word_end = clause.words[wid + 1].end_time
                    continue

            if word.start_time - prev_word_end >= 0.25:
                p = {
                    "location": PauseLocation.CLAUSE_INTERNAL,
                    "start_time": prev_word_end,
                    "end_time": word.start_time
                }
                pauses.append(p)

            prev_word_end = word.end_time

        prev_clause_end = clause.end_time

    return pauses
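The pause rule in find_pauses can be summarized as follows: a gap of at least 0.25 s between the end of one clause and the start of the next is a clause-external pause, and a gap of at least 0.25 s between consecutive words inside a clause is a clause-internal pause. The self-contained sketch below applies the same rule to plain start/end times, with string labels in place of PauseLocation. It deliberately omits the special handling of a turn-initial filler, and the Word class and classify_pauses name are assumptions made for illustration.

```python
from dataclasses import dataclass
from typing import List

PAUSE_THRESHOLD = 0.25  # seconds, the same threshold as find_pauses


@dataclass
class Word:
    start_time: float
    end_time: float


def classify_pauses(clauses: List[List[Word]], turn_start: float) -> List[dict]:
    """Label silent gaps >= PAUSE_THRESHOLD as clause-external or clause-internal."""
    pauses = []
    prev_clause_end = turn_start
    for words in clauses:
        clause_start, clause_end = words[0].start_time, words[-1].end_time
        # Gap between the previous clause (or turn start) and this clause.
        if clause_start - prev_clause_end >= PAUSE_THRESHOLD:
            pauses.append({"location": "CLAUSE_EXTERNAL",
                           "start_time": prev_clause_end, "end_time": clause_start})
        prev_word_end = clause_start
        # Gaps between consecutive words inside the clause.
        for word in words:
            if word.start_time - prev_word_end >= PAUSE_THRESHOLD:
                pauses.append({"location": "CLAUSE_INTERNAL",
                               "start_time": prev_word_end, "end_time": word.start_time})
            prev_word_end = word.end_time
        prev_clause_end = clause_end
    return pauses


clauses = [[Word(0.0, 0.3), Word(0.8, 1.0)], [Word(1.5, 1.9)]]
pauses = classify_pauses(clauses, turn_start=0.0)
print(pauses)
```

The 0.5 s gap inside the first clause is reported as clause-internal, and the 0.5 s gap before the second clause as clause-external.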

The following code block defines a function to annotate temporal features.

In [6]:
def annotate(
        transcript: Transcript,
        save_path: str, 
        annotator: Annotator
) -> Tuple[Turn, TextGrid]:
    df_transcript, period_locations, filler_locations = preprocess_for_turn(transcript)
    turn = df_2_turn(df_transcript, period_locations, filler_locations)

    turn.ignore_disfluency()
    turn = annotator(turn=turn)
    turn.show_disfluency()

    pauses = find_pauses(turn)
    grid = annotator.to_textgrid(turn, pauses, save_path=save_path)

    return turn, grid
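annotator.to_textgrid writes the turn and the detected pauses to a TextGrid file. How the fluency package lays out its tiers is not shown in this notebook. As a hypothetical illustration only: a Praat interval tier must cover its whole time range without gaps, so a list of pauses has to be padded with unlabeled intervals before it can form such a tier. The pauses_to_intervals helper below is an assumption-based sketch of that padding, not the package's implementation.

```python
from typing import List, Tuple


def pauses_to_intervals(
    pauses: List[dict], xmin: float, xmax: float
) -> List[Tuple[float, float, str]]:
    """Turn non-overlapping pauses into contiguous (start, end, label) intervals."""
    intervals: List[Tuple[float, float, str]] = []
    cursor = xmin
    for p in sorted(pauses, key=lambda p: p["start_time"]):
        if p["start_time"] > cursor:
            intervals.append((cursor, p["start_time"], ""))  # unlabeled filler interval
        intervals.append((p["start_time"], p["end_time"], p["location"]))
        cursor = p["end_time"]
    if cursor < xmax:
        intervals.append((cursor, xmax, ""))  # pad up to the end of the tier
    return intervals


pauses = [
    {"location": "CLAUSE_INTERNAL", "start_time": 0.3, "end_time": 0.8},
    {"location": "CLAUSE_EXTERNAL", "start_time": 1.0, "end_time": 1.5},
]
intervals = pauses_to_intervals(pauses, xmin=0.0, xmax=2.0)
for interval in intervals:
    print(interval)
```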

---

## 3. Annotate Temporal Features

This section annotates temporal features related to utterance fluency.
The following code block constructs an annotator object.

In [7]:
annotator = Annotator(process=["eos_detect", "pruning", "clause_detect"], disfluency_detector="bert")

Some weights of the model checkpoint at distilbert-base-cased were not used when initializing DistilBertForTokenClassification: ['vocab_layer_norm.weight', 'vocab_transform.weight', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_projector.weight', 'vocab_layer_norm.bias']
- This IS expected if you are initializing DistilBertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForTokenClassification were not initialized from the model checkpoint at distilbert-base-cased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this 

BERT model was selected!


The following code block annotates temporal features.

In [8]:
for task in TASK:
    save_dir = DATA_DIR / f"{task}/08_Auto_Annotation"

    for json_path in json_path_generator(task):
        with open(json_path, "r") as f:
            transcript = Transcript.from_json(json.load(f))

        turn_path = save_dir / f"{json_path.stem}_bert.pkl"    
        textgrid_path = save_dir / f"{json_path.stem}_bert.TextGrid"

        # if turn_path.exists() and textgrid_path.exists():
        #     continue

        try:
            turn, textgrid = annotate(transcript, str(textgrid_path), annotator)
        except Exception:
            # format_exc() formats the exception currently being handled; it takes no exception argument
            print(f"Unexpected error in {json_path}:\n{traceback.format_exc()}")
            continue

        with open(turn_path, "wb") as f:
            pkl.dump(turn, f)

  warn(f"specified idx {idx} is end word of clause")
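The loop above processes every transcript and reports failures instead of aborting the whole run. The sketch below isolates that pattern in a hypothetical run_batch helper: process stands in for annotate plus pickling, and the skip_existing flag corresponds to the commented-out check that skips transcripts whose outputs already exist (simplified here to check only the .pkl file). Note that traceback.format_exc() takes no exception argument; inside an except block it formats the exception currently being handled.

```python
import tempfile
import traceback
from pathlib import Path
from typing import Callable, Iterable, List


def run_batch(
    inputs: Iterable[Path],
    out_dir: Path,
    process: Callable[[Path], bytes],
    skip_existing: bool = False,
) -> List[Path]:
    """Process each input, write its result, and return the inputs that failed."""
    failed: List[Path] = []
    for in_path in inputs:
        out_path = out_dir / f"{in_path.stem}_bert.pkl"
        if skip_existing and out_path.exists():
            continue  # output already present: skip this input
        try:
            payload = process(in_path)
        except Exception:
            # format_exc() formats the exception currently being handled.
            print(f"Unexpected error in {in_path}:\n{traceback.format_exc()}")
            failed.append(in_path)
            continue  # keep going with the remaining inputs
        out_path.write_bytes(payload)
    return failed


def fake_process(path: Path) -> bytes:
    """Hypothetical stand-in for annotate() + pickling."""
    if path.stem == "bad":
        raise ValueError("simulated annotation failure")
    return b"ok"


inputs = [Path("good.json"), Path("bad.json")]
with tempfile.TemporaryDirectory() as tmp:
    out_dir = Path(tmp)
    failed = run_batch(inputs, out_dir, fake_process)
    written = sorted(p.name for p in out_dir.iterdir())
    # A rerun with skip_existing=True leaves good_bert.pkl untouched
    # and retries only the input that failed.
    failed_again = run_batch(inputs, out_dir, fake_process, skip_existing=True)
```

Collecting the failed paths, rather than only printing them, makes it straightforward to rerun just those transcripts later.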
