In [None]:
from typing import List, Optional, Tuple

from whisper.tokenizer import get_tokenizer
from whisper import audio
import olmoasr as oa

In [None]:
# no timestamps
# trim audio
from typing import Tuple, Optional
import numpy as np
import torch
def preprocess_audio(audio_file: str, norm_end: Optional[str]) -> Tuple[torch.Tensor, np.ndarray]:
    """Loads audio samples, pads/trims them, and computes the log mel spectrogram.

    The array is read with ``np.load`` and scaled by ``1 / 32768`` (assumes the
    stored samples are int16 PCM — TODO confirm). When ``norm_end`` is given,
    audio after that time is cut off (16 samples per millisecond, i.e. assumes
    16 kHz audio) and the result is padded back out with
    ``audio.pad_or_trim``'s default length.

    Args:
        audio_file: Path to an array file loadable with ``np.load``.
        norm_end: Optional end time, in a format accepted by
            ``oa.utils.convert_to_milliseconds``; audio after it is dropped.
            If falsy, the audio is only padded/trimmed to the default length.

    Returns:
        A tuple ``(mel_spec, audio_arr)``: the log mel spectrogram and the
        padded/trimmed audio array.
    """
    audio_arr = np.load(audio_file).astype(np.float32) / 32768.0
    if norm_end:
        # number of samples to trim until (16 samples per ms)
        length = oa.utils.convert_to_milliseconds(norm_end) * 16
        # trim until end of text segment
        audio_arr = audio.pad_or_trim(audio_arr, length=length)
        # pad w/ silence back to the default length
        audio_arr = audio.pad_or_trim(audio_arr)
    else:
        # in case audio_arr isn't exactly 480K samples
        audio_arr = audio.pad_or_trim(audio_arr)
    mel_spec = audio.log_mel_spectrogram(audio_arr)

    return mel_spec, audio_arr

In [None]:
# Example SRT transcript (9 segments, ~29 s) used by the tokenization cells below.
sample = """1
00:00:00,000 --> 00:00:03,090
Men talk more in
male/female interaction

2
00:00:03,090 --> 00:00:06,110
and they interrupt all the time.
It's partly linked to power.

3
00:00:06,110 --> 00:00:09,510
In general people who think they're
more powerful interrupt more.

4
00:00:10,950 --> 00:00:14,110
So that would be something that
a sociolinguist would discover,

5
00:00:14,110 --> 00:00:16,230
different social groups
speak somewhat differently

6
00:00:16,870 --> 00:00:19,810
and there are people who look
at you know computational linguistics

7
00:00:19,810 --> 00:00:22,980
and the way you can use language
in artificial intelligence,

8
00:00:23,680 --> 00:00:26,830
teach machines to better
be able to translate.

9
00:00:28,420 --> 00:00:29,200
Many other areas."""

In [None]:
import whisper
def preprocess_text(
    transcript_string: str,
    transcript_file: str,
    tokenizer: whisper.tokenizer.Tokenizer,
) -> List[int]:
    """Tokenizes the full text of a transcript, without timestamps.

    Parses ``transcript_string`` with ``oa.utils.TranscriptReader``, extracts
    the text of all segments with ``reader.extract_text`` and encodes it as a
    single string. No padding, truncation, or special tokens are added.

    Args:
        transcript_string: Raw transcript contents (e.g. SRT text).
        transcript_file: Transcript file name; only its extension is used,
            passed as ``ext`` to the reader (``file_path=None`` is passed, so
            the contents come from ``transcript_string``).
        tokenizer: Whisper tokenizer used to encode the text.

    Returns:
        The list of token ids for the transcript text.
    """
    # transcript -> text
    reader = oa.utils.TranscriptReader(
        transcript_string=transcript_string,
        file_path=None,
        ext=transcript_file.split(".")[-1],
    )
    transcript, *_ = reader.read()

    transcript_text = reader.extract_text(transcript=transcript)
    return tokenizer.encode(transcript_text)

In [None]:
tokenizer = get_tokenizer(multilingual=False)
# Only the extension of transcript_file is read (passed as `ext` to
# TranscriptReader); the transcript text itself comes from `sample`.
no_timestamps_text_tokens = preprocess_text(
    transcript_string=sample,
    transcript_file="00:00:29,100_00:00:58,300.srt",
    tokenizer=tokenizer,
)

In [None]:
# Token ids of the whole transcript text encoded as one string (no timestamps).
no_timestamps_text_tokens

[10418,
 1561,
 517,
 287,
 198,
 22606,
 14,
 24724,
 10375,
 290,
 484,
 11313,
 477,
 262,
 640,
 13,
 198,
 1026,
 338,
 11476,
 6692,
 284,
 1176,
 13,
 554,
 2276,
 661,
 508,
 892,
 484,
 821,
 198,
 3549,
 3665,
 11313,
 517,
 13,
 1406,
 326,
 561,
 307,
 1223,
 326,
 198,
 64,
 1307,
 1669,
 6680,
 396,
 561,
 7073,
 11,
 1180,
 1919,
 2628,
 198,
 47350,
 6454,
 10338,
 290,
 612,
 389,
 661,
 508,
 804,
 198,
 265,
 345,
 760,
 31350,
 20280,
 3969,
 290,
 262,
 835,
 345,
 460,
 779,
 3303,
 198,
 259,
 11666,
 4430,
 11,
 4545,
 8217,
 284,
 1365,
 198,
 1350,
 1498,
 284,
 15772,
 13,
 4650,
 584,
 3006,
 13]

In [None]:
import whisper
def preprocess_text(
    transcript_string: str,
    transcript_file: str,
    tokenizer: whisper.tokenizer.Tokenizer,
) -> None:
    """Debug helper: parses a transcript and prints the parsed segments.

    Args:
        transcript_string: Raw transcript contents (e.g. SRT text).
        transcript_file: Transcript file name; only its extension is used,
            passed as ``ext`` to ``oa.utils.TranscriptReader``.
        tokenizer: Unused; kept so the signature matches the other versions.

    Returns:
        None. Prints the parsed transcript (in the recorded output, a dict
        mapping ``(start, end)`` timestamp strings to segment text).
    """
    # transcript -> parsed segments
    reader = oa.utils.TranscriptReader(
        transcript_string=transcript_string,
        file_path=None,
        ext=transcript_file.split(".")[-1],
    )
    transcript, *_ = reader.read()

    print(transcript)

In [None]:
# Only the extension of transcript_file is read (passed as `ext` to TranscriptReader).
preprocess_text(transcript_string=sample, transcript_file="00:00:29,100_00:00:58,300.srt", tokenizer=tokenizer)

{('00:00:00.000', '00:00:03.090'): 'Men talk more in\nmale/female interaction', ('00:00:03.090', '00:00:06.110'): "and they interrupt all the time.\nIt's partly linked to power.", ('00:00:06.110', '00:00:09.510'): "In general people who think they're\nmore powerful interrupt more.", ('00:00:10.950', '00:00:14.110'): 'So that would be something that\na sociolinguist would discover,', ('00:00:14.110', '00:00:16.230'): 'different social groups\nspeak somewhat differently', ('00:00:16.870', '00:00:19.810'): 'and there are people who look\nat you know computational linguistics', ('00:00:19.810', '00:00:22.980'): 'and the way you can use language\nin artificial intelligence,', ('00:00:23.680', '00:00:26.830'): 'teach machines to better\nbe able to translate.', ('00:00:28.420', '00:00:29.200'): 'Many other areas.'}


In [None]:
import whisper
def preprocess_text(
    transcript_string: str,
    transcript_file: str,
    tokenizer: whisper.tokenizer.Tokenizer,
) -> List[int]:
    """Tokenizes a transcript segment by segment, without timestamps.

    Each segment's text is stripped and encoded separately. Every segment
    except the last gets a trailing space, so the decoded text still has a
    space between segments. Because segments are encoded separately, the ids
    can differ from encoding the joined text (e.g. a standalone space token
    followed by a word without its leading space), even though both decode
    to the same string.

    Args:
        transcript_string: Raw transcript contents (e.g. SRT text).
        transcript_file: Transcript file name; only its extension is used,
            passed as ``ext`` to ``oa.utils.TranscriptReader``.
        tokenizer: Whisper tokenizer used to encode each segment.

    Returns:
        A flat list of token ids for all segments, in order.
    """
    # transcript -> per-segment text
    reader = oa.utils.TranscriptReader(
        transcript_string=transcript_string,
        file_path=None,
        ext=transcript_file.split(".")[-1],
    )
    transcript, *_ = reader.read()

    last_idx = len(transcript) - 1
    tokens = []
    for i, text in enumerate(transcript.values()):
        # trailing space keeps segments separated once decoded.
        # NOTE(review): putting the space at the *start* of the next segment
        # would likely match the joined-text tokenization more closely.
        suffix = " " if i < last_idx else ""
        tokens.extend(tokenizer.encode(text.strip() + suffix))
    return tokens

In [None]:
# Only the extension of transcript_file is read (passed as `ext` to TranscriptReader).
tokens = preprocess_text(transcript_string=sample, transcript_file="00:00:29,100_00:00:58,300.srt", tokenizer=tokenizer)

In [None]:
# Per-segment token ids: note the standalone space token (220) at segment boundaries.
tokens

[10418,
 1561,
 517,
 287,
 198,
 22606,
 14,
 24724,
 10375,
 220,
 392,
 484,
 11313,
 477,
 262,
 640,
 13,
 198,
 1026,
 338,
 11476,
 6692,
 284,
 1176,
 13,
 220,
 818,
 2276,
 661,
 508,
 892,
 484,
 821,
 198,
 3549,
 3665,
 11313,
 517,
 13,
 220,
 2396,
 326,
 561,
 307,
 1223,
 326,
 198,
 64,
 1307,
 1669,
 6680,
 396,
 561,
 7073,
 11,
 220,
 39799,
 1919,
 2628,
 198,
 47350,
 6454,
 10338,
 220,
 392,
 612,
 389,
 661,
 508,
 804,
 198,
 265,
 345,
 760,
 31350,
 20280,
 3969,
 220,
 392,
 262,
 835,
 345,
 460,
 779,
 3303,
 198,
 259,
 11666,
 4430,
 11,
 220,
 660,
 620,
 8217,
 284,
 1365,
 198,
 1350,
 1498,
 284,
 15772,
 13,
 220,
 7085,
 584,
 3006,
 13]

In [None]:
# Joined-text encoding vs per-segment encoding: the token ids are not identical.
no_timestamps_text_tokens == tokens

False

In [None]:
# Per-segment encoding produces more tokens than encoding the joined text.
len(tokens), len(no_timestamps_text_tokens)

(107, 98)

In [None]:
# Decode the joined-text token ids back to a string.
tokenizer.decode(no_timestamps_text_tokens)

"Men talk more in\nmale/female interaction and they interrupt all the time.\nIt's partly linked to power. In general people who think they're\nmore powerful interrupt more. So that would be something that\na sociolinguist would discover, different social groups\nspeak somewhat differently and there are people who look\nat you know computational linguistics and the way you can use language\nin artificial intelligence, teach machines to better\nbe able to translate. Many other areas."

In [None]:
# Despite different ids, both encodings decode to the same text.
tokenizer.decode(no_timestamps_text_tokens) == tokenizer.decode(tokens)

True

In [None]:
# Decode the per-segment token ids back to a string.
tokenizer.decode(tokens)

"Men talk more in\nmale/female interaction and they interrupt all the time.\nIt's partly linked to power. In general people who think they're\nmore powerful interrupt more. So that would be something that\na sociolinguist would discover, different social groups\nspeak somewhat differently and there are people who look\nat you know computational linguistics and the way you can use language\nin artificial intelligence, teach machines to better\nbe able to translate. Many other areas."

In [None]:
import whisper
from itertools import chain
def preprocess_text(
    transcript_string: str,
    transcript_file: str,
    tokenizer: whisper.tokenizer.Tokenizer,
) -> List[List[int]]:
    """Tokenizes each transcript segment and prints a token-count estimate.

    Each segment's stripped text is encoded separately (all but the last get
    a trailing space). The printed estimate counts a start and end timestamp
    per segment, one extra timestamp, and SOT + EOT.

    Args:
        transcript_string: Raw transcript contents (e.g. SRT text).
        transcript_file: Transcript file name; only its extension is used,
            passed as ``ext`` to ``oa.utils.TranscriptReader``.
        tokenizer: Whisper tokenizer used to encode each segment.

    Returns:
        One list of token ids per segment (not flattened).
    """
    # transcript -> per-segment text
    reader = oa.utils.TranscriptReader(
        transcript_string=transcript_string,
        file_path=None,
        ext=transcript_file.split(".")[-1],
    )
    transcript, *_ = reader.read()

    last_idx = len(transcript) - 1
    tokens = [
        tokenizer.encode(text.strip() + (" " if i < last_idx else ""))
        for i, text in enumerate(transcript.values())
    ]

    # start + end timestamp per segment, plus one extra (next_start) timestamp
    num_timestamp_tokens = (len(transcript) * 2) + 1
    print(f"{num_timestamp_tokens=}")
    num_text_tokens = sum(len(token_group) for token_group in tokens)
    print(f"{num_text_tokens=}")
    num_total_tokens = num_timestamp_tokens + num_text_tokens + 2 # sot + eot
    print(f"{num_total_tokens=}")
    return tokens

In [None]:
# Only the extension of transcript_file is read (passed as `ext` to TranscriptReader).
tokens = preprocess_text(transcript_string=sample, transcript_file="00:00:29,100_00:00:58,300.srt", tokenizer=tokenizer)

num_timestamp_tokens=19
num_text_tokens=107
num_total_tokens=128


In [None]:
import whisper
from itertools import chain
n_text_ctx = 448  # maximum number of text tokens per example
def preprocess_text(
    transcript_string: str,
    transcript_file: str,
    tokenizer: whisper.tokenizer.Tokenizer,
) -> List[int]:
    """Tokenizes a transcript, randomly with or without timestamp tokens.

    Each segment's stripped text is encoded separately (all but the last get
    a trailing space). With probability ~0.5 (``np.random.rand() > 0.5``),
    and only if the estimated length fits in ``n_text_ctx``, the output is::

        <sot> <start_0> text_0 <end_0> ... <start_k> text_k <end_k> <eot>

    where each timestamp token is ``tokenizer.timestamp_begin + ms // 20``
    (20 ms steps, floored). Otherwise the segment tokens are concatenated
    with no special tokens. Prints intermediate values for debugging.

    Args:
        transcript_string: Raw transcript contents (e.g. SRT text).
        transcript_file: Transcript file name; only its extension is used,
            passed as ``ext`` to ``oa.utils.TranscriptReader``.
        tokenizer: Whisper tokenizer providing ``encode``, ``sot_sequence``,
            ``timestamp_begin`` and ``eot``.

    Returns:
        A flat list of token ids.
    """
    # transcript -> per-segment tokens
    reader = oa.utils.TranscriptReader(
        transcript_string=transcript_string,
        file_path=None,
        ext=transcript_file.split(".")[-1],
    )
    transcript, *_ = reader.read()
    tokens = []
    for i, (timestamps, text) in enumerate(transcript.items()):
        if i < len(transcript) - 1:
            tokens.append(tokenizer.encode(text.strip() + " "))
        else:
            tokens.append(tokenizer.encode(text.strip()))

    # start + end timestamp per segment, +1 for a next_start timestamp that
    # this version does not emit (so the estimate is one token high)
    num_timestamp_tokens = (len(transcript) * 2) + 1
    print(f"{num_timestamp_tokens=}")
    num_text_tokens = sum(len(token_group) for token_group in tokens)
    print(f"{num_text_tokens=}")
    num_total_tokens = num_timestamp_tokens + num_text_tokens + 2 # sot + eot
    print(f"{num_total_tokens=}")

    last_idx = len(transcript) - 1
    # NOTE(review): np.random is unseeded here, so the branch taken varies per run
    if np.random.rand() > 0.5:
        if num_total_tokens <= n_text_ctx:
            new_tokens = []
            for i, timestamps in enumerate(transcript.keys()):
                start, end = timestamps
                print(f"{start=}")
                print(f"{end=}")
                start_ms = oa.utils.convert_to_milliseconds(start)
                print(f"{start_ms=}")
                print(f"{(start_ms // 20)=}")
                end_ms = oa.utils.convert_to_milliseconds(end)
                print(f"{end_ms=}")
                print(f"{(end_ms // 20)=}")
                # timestamp tokens are 20 ms apart; times are floored to that grid
                start_token_idx = [tokenizer.timestamp_begin + (start_ms // 20)]
                print(f"{start_token_idx=}")
                end_token_idx = [tokenizer.timestamp_begin + (end_ms // 20)]
                print(f"{end_token_idx=}")

                line_tokens = start_token_idx + tokens[i] + end_token_idx
                if i == 0:
                    line_tokens = [tokenizer.sot_sequence[0]] + line_tokens
                if i == last_idx:
                    # checked independently of i == 0 so a single-segment
                    # transcript also gets its EOT
                    line_tokens = line_tokens + [tokenizer.eot]

                new_tokens.extend(line_tokens)
            tokens = new_tokens
        else:
            tokens = list(chain(*tokens))
    else:
        tokens = list(chain(*tokens))
    print(f"{tokens=}")
    print(f"{len(tokens)=}")

    return tokens

In [None]:
# Only the extension of transcript_file is read (passed as `ext` to TranscriptReader).
tokens = preprocess_text(transcript_string=sample, transcript_file="00:00:29,100_00:00:58,300.srt", tokenizer=tokenizer)

num_timestamp_tokens=19
num_text_tokens=107
num_total_tokens=128
start='00:00:00.000'
end='00:00:03.090'
start_ms=0
(start_ms // 20)=0
3090
(end_ms // 20)=154
start_token_idx=[50363]
end_token_idx=[50517]
start='00:00:03.090'
end='00:00:06.110'
start_ms=3090
(start_ms // 20)=154
6110
(end_ms // 20)=305
start_token_idx=[50517]
end_token_idx=[50668]
start='00:00:06.110'
end='00:00:09.510'
start_ms=6110
(start_ms // 20)=305
9510
(end_ms // 20)=475
start_token_idx=[50668]
end_token_idx=[50838]
start='00:00:10.950'
end='00:00:14.110'
start_ms=10950
(start_ms // 20)=547
14110
(end_ms // 20)=705
start_token_idx=[50910]
end_token_idx=[51068]
start='00:00:14.110'
end='00:00:16.230'
start_ms=14110
(start_ms // 20)=705
16230
(end_ms // 20)=811
start_token_idx=[51068]
end_token_idx=[51174]
start='00:00:16.870'
end='00:00:19.810'
start_ms=16870
(start_ms // 20)=843
19810
(end_ms // 20)=990
start_token_idx=[51206]
end_token_idx=[51353]
start='00:00:19.810'
end='00:00:22.980'
start_ms=19810
(start_ms

In [None]:
import whisper
from itertools import chain
n_text_ctx = 448  # maximum number of text tokens per example
def preprocess_text(
    transcript_string: str,
    transcript_file: str,
    tokenizer: whisper.tokenizer.Tokenizer,
    text_timestamp: str,
    next_start: str,
) -> List[int]:
    """Tokenizes a transcript window, adding timestamp tokens when it fits.

    Each segment's stripped text is encoded separately (all but the last get
    a trailing space). If the estimated length fits in ``n_text_ctx``, the
    output is::

        <sot> <start_0> text_0 <end_0> ... <start_k> text_k <end_k> <next_start> <eot>

    where each timestamp token is ``tokenizer.timestamp_begin + ms // 20``
    (20 ms steps, floored) and ``<next_start>`` is clamped to 30 s.
    Otherwise the segment tokens are concatenated with no special tokens.
    Prints intermediate values for debugging.

    Args:
        transcript_string: Raw transcript contents (e.g. SRT text).
        transcript_file: Transcript file name; only its extension is used,
            passed as ``ext`` to ``oa.utils.TranscriptReader``.
        tokenizer: Whisper tokenizer providing ``encode``, ``sot_sequence``,
            ``timestamp_begin`` and ``eot``.
        text_timestamp: ``"<start>_<end>"`` string for the window (e.g.
            ``"00:00:29,100_00:00:58,300"``); the part before ``_``, with
            ``,`` replaced by ``.``, is taken as the window start.
        next_start: Absolute time whose offset from the window start (via
            ``oa.utils.calculate_difference``) becomes the last timestamp token.

    Returns:
        A flat list of token ids.
    """
    # transcript -> per-segment tokens
    reader = oa.utils.TranscriptReader(
        transcript_string=transcript_string,
        file_path=None,
        ext=transcript_file.split(".")[-1],
    )
    transcript, *_ = reader.read()
    tokens = []
    for i, (timestamps, text) in enumerate(transcript.items()):
        if i < len(transcript) - 1:
            tokens.append(tokenizer.encode(text.strip() + " "))
        else:
            tokens.append(tokenizer.encode(text.strip()))

    # start + end timestamp per segment, plus the final next_start timestamp
    num_timestamp_tokens = (len(transcript) * 2) + 1
    print(f"{num_timestamp_tokens=}")
    num_text_tokens = sum(len(token_group) for token_group in tokens)
    print(f"{num_text_tokens=}")
    num_total_tokens = num_timestamp_tokens + num_text_tokens + 2 # sot + eot
    print(f"{num_total_tokens=}")

    # rand() is in [0, 1), so this is effectively always true (forces the
    # timestamp path while debugging)
    if np.random.rand() > 0:
        if num_total_tokens <= n_text_ctx:
            new_tokens = []
            for i, timestamps in enumerate(transcript.keys()):
                start, end = timestamps
                print(f"{start=}")
                print(f"{end=}")
                start_ms = oa.utils.convert_to_milliseconds(start)
                print(f"{start_ms=}")
                print(f"{(start_ms // 20)=}")
                end_ms = oa.utils.convert_to_milliseconds(end)
                print(f"{end_ms=}")
                print(f"{(end_ms // 20)=}")
                # timestamp tokens are 20 ms apart; times are floored to that grid
                start_token_idx = [tokenizer.timestamp_begin + (start_ms // 20)]
                print(f"{start_token_idx=}")
                end_token_idx = [tokenizer.timestamp_begin + (end_ms // 20)]
                print(f"{end_token_idx=}")

                if i == 0:
                    line_tokens = [tokenizer.sot_sequence[0]] + start_token_idx + tokens[i] + end_token_idx
                else:
                    # every later segment, including the last one
                    line_tokens = start_token_idx + tokens[i] + end_token_idx

                new_tokens.extend(line_tokens)
            unnorm_start = text_timestamp.split("_")[0].replace(",", ".")
            print(f"{unnorm_start=}")
            norm_next_start = oa.utils.calculate_difference(unnorm_start, next_start)
            print(f"{norm_next_start=}")
            # NOTE(review): later cells use calculate_difference's result as
            # milliseconds directly; confirm it is a timestamp string here.
            next_start_ms = oa.utils.convert_to_milliseconds(norm_next_start)
            # clamp the final timestamp to 30 s
            next_start_token_idx = [tokenizer.timestamp_begin + (min(next_start_ms, 30000) // 20)]
            # extend (not append) so the result stays a flat list of ids
            new_tokens.extend(next_start_token_idx + [tokenizer.eot])
            tokens = new_tokens
        else:
            tokens = list(chain(*tokens))
    else:
        tokens = list(chain(*tokens))
    print(f"{tokens=}")
    print(f"{len(tokens)=}")

    return tokens

In [None]:
import whisper
from itertools import chain
n_text_ctx = 448  # maximum number of text tokens per example
def preprocess_text(
    transcript_string: str,
    transcript_file: str,
    tokenizer: whisper.tokenizer.Tokenizer,
    text_timestamp: str,
    next_start: str,
) -> List[int]:
    """Tokenizes a transcript window, adding timestamp tokens when it fits.

    Each segment's stripped text is encoded separately (all but the last get
    a trailing space). If the estimated length fits in ``n_text_ctx``, the
    output is::

        <sot> <start_0> text_0 <end_0> ... <start_k> text_k <end_k> <next_start> <eot>

    where each timestamp token is ``tokenizer.timestamp_begin + ms // 20``
    (20 ms steps, floored) and ``<next_start>`` is clamped to 30 s.
    Otherwise the segment tokens are concatenated with no special tokens.
    Prints intermediate values for debugging.

    Args:
        transcript_string: Raw transcript contents (e.g. SRT text).
        transcript_file: Transcript file name; only its extension is used,
            passed as ``ext`` to ``oa.utils.TranscriptReader``.
        tokenizer: Whisper tokenizer providing ``encode``, ``sot_sequence``,
            ``timestamp_begin`` and ``eot``.
        text_timestamp: ``"<start>_<end>"`` string for the window (e.g.
            ``"00:00:29,100_00:00:58,300"``); the part before ``_``, with
            ``,`` replaced by ``.``, is taken as the window start.
        next_start: Absolute time whose offset from the window start, as
            returned by ``oa.utils.calculate_difference`` (used here as
            milliseconds), becomes the last timestamp token.

    Returns:
        A flat list of token ids.
    """
    # transcript -> per-segment tokens
    reader = oa.utils.TranscriptReader(
        transcript_string=transcript_string,
        file_path=None,
        ext=transcript_file.split(".")[-1],
    )
    transcript, *_ = reader.read()
    tokens = []
    for i, (timestamps, text) in enumerate(transcript.items()):
        if i < len(transcript) - 1:
            tokens.append(tokenizer.encode(text.strip() + " "))
        else:
            tokens.append(tokenizer.encode(text.strip()))

    # start + end timestamp per segment, plus the final next_start timestamp
    num_timestamp_tokens = (len(transcript) * 2) + 1
    print(f"{num_timestamp_tokens=}")
    num_text_tokens = sum(len(token_group) for token_group in tokens)
    print(f"{num_text_tokens=}")
    num_total_tokens = num_timestamp_tokens + num_text_tokens + 2 # sot + eot
    print(f"{num_total_tokens=}")

    # rand() is in [0, 1), so this is effectively always true (forces the
    # timestamp path while debugging)
    if np.random.rand() > 0:
        if num_total_tokens <= n_text_ctx:
            new_tokens = []
            for i, timestamps in enumerate(transcript.keys()):
                start, end = timestamps
                print(f"{start=}")
                print(f"{end=}")
                start_ms = oa.utils.convert_to_milliseconds(start)
                print(f"{start_ms=}")
                print(f"{(start_ms // 20)=}")
                end_ms = oa.utils.convert_to_milliseconds(end)
                print(f"{end_ms=}")
                print(f"{(end_ms // 20)=}")
                # timestamp tokens are 20 ms apart; times are floored to that grid
                start_token_idx = [tokenizer.timestamp_begin + (start_ms // 20)]
                print(f"{start_token_idx=}")
                end_token_idx = [tokenizer.timestamp_begin + (end_ms // 20)]
                print(f"{end_token_idx=}")

                if i == 0:
                    line_tokens = [tokenizer.sot_sequence[0]] + start_token_idx + tokens[i] + end_token_idx
                else:
                    # every later segment, including the last one
                    line_tokens = start_token_idx + tokens[i] + end_token_idx

                new_tokens.extend(line_tokens)
            unnorm_start = text_timestamp.split("_")[0].replace(",", ".")
            print(f"{unnorm_start=}")
            next_start_ms = oa.utils.calculate_difference(unnorm_start, next_start)
            print(f"{next_start_ms=}")
            # clamp the final timestamp to 30 s
            next_start_token_idx = [tokenizer.timestamp_begin + (min(next_start_ms, 30000) // 20)]
            # extend (not append) so the result stays a flat list of ids
            new_tokens.extend(next_start_token_idx + [tokenizer.eot])
            tokens = new_tokens
        else:
            tokens = list(chain(*tokens))
    else:
        tokens = list(chain(*tokens))
    print(f"{tokens=}")
    print(f"{len(tokens)=}")

    return tokens

In [None]:
import whisper
from itertools import chain
n_text_ctx = 448  # maximum number of text tokens per example
def preprocess_text(
    transcript_string: str,
    transcript_file: str,
    tokenizer: whisper.tokenizer.Tokenizer,
    text_timestamp: str,
    next_start: str,
) -> List[int]:
    """Tokenizes a transcript window, adding timestamp tokens when it fits.

    Each segment's stripped text is encoded separately (all but the last get
    a trailing space). If the estimated length fits in ``n_text_ctx``, the
    output is::

        <sot> <start_0> text_0 <end_0> ... <start_k> text_k <end_k> <next_start> <eot>

    where each timestamp token is ``tokenizer.timestamp_begin + ms // 20``
    (20 ms steps, floored) and ``<next_start>`` is clamped to 30 s.
    Otherwise the segment tokens are concatenated with no special tokens.
    Prints intermediate values (including per-line token groups) for debugging.

    Args:
        transcript_string: Raw transcript contents (e.g. SRT text).
        transcript_file: Transcript file name; only its extension is used,
            passed as ``ext`` to ``oa.utils.TranscriptReader``.
        tokenizer: Whisper tokenizer providing ``encode``, ``sot_sequence``,
            ``timestamp_begin`` and ``eot``.
        text_timestamp: ``"<start>_<end>"`` string for the window (e.g.
            ``"00:00:29,100_00:00:58,300"``); the part before ``_``, with
            ``,`` replaced by ``.``, is taken as the window start.
        next_start: Absolute time whose offset from the window start, as
            returned by ``oa.utils.calculate_difference`` (used here as
            milliseconds), becomes the last timestamp token.

    Returns:
        A flat list of token ids.
    """
    # transcript -> per-segment tokens
    reader = oa.utils.TranscriptReader(
        transcript_string=transcript_string,
        file_path=None,
        ext=transcript_file.split(".")[-1],
    )
    transcript, *_ = reader.read()
    tokens = []
    for i, (timestamps, text) in enumerate(transcript.items()):
        if i < len(transcript) - 1:
            tokens.append(tokenizer.encode(text.strip() + " "))
        else:
            tokens.append(tokenizer.encode(text.strip()))

    # start + end timestamp per segment, plus the final next_start timestamp
    num_timestamp_tokens = (len(transcript) * 2) + 1
    print(f"{num_timestamp_tokens=}")
    num_text_tokens = sum(len(token_group) for token_group in tokens)
    print(f"{num_text_tokens=}")
    num_total_tokens = num_timestamp_tokens + num_text_tokens + 2 # sot + eot
    print(f"{num_total_tokens=}\n")

    # rand() is in [0, 1), so this is effectively always true (forces the
    # timestamp path while debugging)
    if np.random.rand() > 0:
        if num_total_tokens <= n_text_ctx:
            new_tokens = []
            for i, timestamps in enumerate(transcript.keys()):
                start, end = timestamps
                print(f"{start=}")
                print(f"{end=}")
                start_ms = oa.utils.convert_to_milliseconds(start)
                print(f"{start_ms=}")
                print(f"{(start_ms // 20)=}")
                end_ms = oa.utils.convert_to_milliseconds(end)
                print(f"{end_ms=}")
                print(f"{(end_ms // 20)=}")
                # timestamp tokens are 20 ms apart; times are floored to that grid
                start_token_idx = [tokenizer.timestamp_begin + (start_ms // 20)]
                print(f"{start_token_idx=}")
                end_token_idx = [tokenizer.timestamp_begin + (end_ms // 20)]
                print(f"{end_token_idx=}")

                if i == 0:
                    line_tokens = [tokenizer.sot_sequence[0]] + start_token_idx + tokens[i] + end_token_idx
                else:
                    # every later segment, including the last one
                    line_tokens = start_token_idx + tokens[i] + end_token_idx

                print(f"{len(line_tokens)=}")
                print(f"{line_tokens=}")
                new_tokens.extend(line_tokens)
                print(f"{len(new_tokens)=}\n")
            unnorm_start = text_timestamp.split("_")[0].replace(",", ".")
            print(f"{unnorm_start=}")
            next_start_ms = oa.utils.calculate_difference(unnorm_start, next_start)
            print(f"{next_start_ms=}")
            # clamp the final timestamp to 30 s
            next_start_token_idx = [tokenizer.timestamp_begin + (min(next_start_ms, 30000) // 20)]
            new_tokens.extend(next_start_token_idx + [tokenizer.eot])
            tokens = new_tokens
        else:
            tokens = list(chain(*tokens))
    else:
        tokens = list(chain(*tokens))
    print(f"{tokens=}")
    print(f"{len(tokens)=}")

    return tokens

In [None]:
# Only the extension of transcript_file is read (passed as `ext` to TranscriptReader);
# the window start comes from text_timestamp.
tokens = preprocess_text(
    transcript_string=sample,
    transcript_file="00:00:29,100_00:00:58,300.srt",
    tokenizer=tokenizer,
    text_timestamp="00:00:29,100_00:00:58,300",
    next_start="00:00:58.300",
)

num_timestamp_tokens=19
num_text_tokens=107
num_total_tokens=128

start='00:00:00.000'
end='00:00:03.090'
start_ms=0
(start_ms // 20)=0
end_ms=3090
(end_ms // 20)=154
start_token_idx=[50363]
end_token_idx=[50517]
len(line_tokens)=13
line_tokens=[50257, 50363, 10418, 1561, 517, 287, 198, 22606, 14, 24724, 10375, 220, 50517]
len(new_tokens)=13

start='00:00:03.090'
end='00:00:06.110'
start_ms=3090
(start_ms // 20)=154
end_ms=6110
(end_ms // 20)=305
start_token_idx=[50517]
end_token_idx=[50668]
len(line_tokens)=18
line_tokens=[50517, 392, 484, 11313, 477, 262, 640, 13, 198, 1026, 338, 11476, 6692, 284, 1176, 13, 220, 50668]
len(new_tokens)=31

start='00:00:06.110'
end='00:00:09.510'
start_ms=6110
(start_ms // 20)=305
end_ms=9510
(end_ms // 20)=475
start_token_idx=[50668]
end_token_idx=[50838]
len(line_tokens)=16
line_tokens=[50668, 818, 2276, 661, 508, 892, 484, 821, 198, 3549, 3665, 11313, 517, 13, 220, 50838]
len(new_tokens)=47

start='00:00:10.950'
end='00:00:14.110'
start_ms=10950
(st

In [None]:
import whisper
from itertools import chain
n_text_ctx = 448  # maximum number of text tokens per example
def preprocess_text(
    transcript_string: str,
    transcript_file: str,
    tokenizer: whisper.tokenizer.Tokenizer,
    text_timestamp: str,
    next_start: str,
) -> List[int]:
    """Tokenizes a transcript window, adding timestamp tokens when it fits.

    Each segment's stripped text is encoded separately (all but the last get
    a trailing space). If the estimated length fits in ``n_text_ctx``, the
    output is::

        <sot> <start_0> text_0 <end_0> ... <start_k> text_k <end_k> <next_start> <eot>

    where each timestamp token is ``tokenizer.timestamp_begin + ms // 20``
    (20 ms steps, floored) and ``<next_start>`` is clamped to 30 s.
    Otherwise the segment tokens are concatenated with no special tokens.
    Prints intermediate values (including per-line token groups) for debugging.

    Args:
        transcript_string: Raw transcript contents (e.g. SRT text).
        transcript_file: Transcript file name; only its extension is used,
            passed as ``ext`` to ``oa.utils.TranscriptReader``.
        tokenizer: Whisper tokenizer providing ``encode``, ``sot_sequence``,
            ``timestamp_begin`` and ``eot``.
        text_timestamp: ``"<start>_<end>"`` string for the window (e.g.
            ``"00:00:29,100_00:00:58,300"``); the part before ``_``, with
            ``,`` replaced by ``.``, is taken as the window start.
        next_start: Absolute time whose offset from the window start, as
            returned by ``oa.utils.calculate_difference`` (used here as
            milliseconds), becomes the last timestamp token.

    Returns:
        A flat list of token ids.
    """
    # transcript -> per-segment tokens
    reader = oa.utils.TranscriptReader(
        transcript_string=transcript_string,
        file_path=None,
        ext=transcript_file.split(".")[-1],
    )
    transcript, *_ = reader.read()
    tokens = []
    for i, (timestamps, text) in enumerate(transcript.items()):
        if i < len(transcript) - 1:
            tokens.append(tokenizer.encode(text.strip() + " "))
        else:
            tokens.append(tokenizer.encode(text.strip()))

    # start + end timestamp per segment, plus the final next_start timestamp
    num_timestamp_tokens = (len(transcript) * 2) + 1
    print(f"{num_timestamp_tokens=}")
    num_text_tokens = sum(len(token_group) for token_group in tokens)
    print(f"{num_text_tokens=}")
    num_total_tokens = num_timestamp_tokens + num_text_tokens + 2 # sot + eot
    print(f"{num_total_tokens=}\n")
    print(f"{([len(token_group) for token_group in tokens])=}")

    # rand() is in [0, 1), so this is effectively always true (forces the
    # timestamp path while debugging)
    if np.random.rand() > 0:
        if num_total_tokens <= n_text_ctx:
            new_tokens = []
            for i, timestamps in enumerate(transcript.keys()):
                start, end = timestamps
                print(f"{start=}")
                print(f"{end=}")
                start_ms = oa.utils.convert_to_milliseconds(start)
                print(f"{start_ms=}")
                print(f"{(start_ms // 20)=}")
                end_ms = oa.utils.convert_to_milliseconds(end)
                print(f"{end_ms=}")
                print(f"{(end_ms // 20)=}")
                # timestamp tokens are 20 ms apart; times are floored to that grid
                start_token_idx = [tokenizer.timestamp_begin + (start_ms // 20)]
                print(f"{start_token_idx=}")
                end_token_idx = [tokenizer.timestamp_begin + (end_ms // 20)]
                print(f"{end_token_idx=}")

                if i == 0:
                    line_tokens = [tokenizer.sot_sequence[0]] + start_token_idx + tokens[i] + end_token_idx
                else:
                    line_tokens = start_token_idx + tokens[i] + end_token_idx

                print(f"{len(line_tokens)=}")
                print(f"{line_tokens=}")
                new_tokens.extend(line_tokens)
                print(f"{len(new_tokens)=}\n")
            unnorm_start = text_timestamp.split("_")[0].replace(",", ".")
            print(f"{unnorm_start=}")
            next_start_ms = oa.utils.calculate_difference(unnorm_start, next_start)
            print(f"{next_start_ms=}")
            # clamp the final timestamp to 30 s
            next_start_token_idx = [tokenizer.timestamp_begin + (min(next_start_ms, 30000) // 20)]
            new_tokens.extend(next_start_token_idx + [tokenizer.eot])
            tokens = new_tokens
        else:
            # NOTE(review): this fallback adds no SOT/EOT, unlike the
            # timestamp path — confirm that is intended
            tokens = list(chain(*tokens))
    else:
        tokens = list(chain(*tokens))
    print(f"{tokens=}")
    print(f"{len(tokens)=}")

    return tokens

In [None]:
# Only the extension of transcript_file is read (passed as `ext` to TranscriptReader);
# the window start comes from text_timestamp.
tokens = preprocess_text(
    transcript_string=sample,
    transcript_file="00:00:29,100_00:00:58,300.srt",
    tokenizer=tokenizer,
    text_timestamp="00:00:29,100_00:00:58,300",
    next_start="00:00:58.300",
)

num_timestamp_tokens=19
num_text_tokens=107
num_total_tokens=128

([len(token_group) for token_group in tokens])=[10, 16, 14, 16, 8, 14, 13, 12, 4]
start='00:00:00.000'
end='00:00:03.090'
start_ms=0
(start_ms // 20)=0
end_ms=3090
(end_ms // 20)=154
start_token_idx=[50363]
end_token_idx=[50517]
len(line_tokens)=13
line_tokens=[50257, 50363, 10418, 1561, 517, 287, 198, 22606, 14, 24724, 10375, 220, 50517]
len(new_tokens)=13

start='00:00:03.090'
end='00:00:06.110'
start_ms=3090
(start_ms // 20)=154
end_ms=6110
(end_ms // 20)=305
start_token_idx=[50517]
end_token_idx=[50668]
len(line_tokens)=18
line_tokens=[50517, 392, 484, 11313, 477, 262, 640, 13, 198, 1026, 338, 11476, 6692, 284, 1176, 13, 220, 50668]
len(new_tokens)=31

start='00:00:06.110'
end='00:00:09.510'
start_ms=6110
(start_ms // 20)=305
end_ms=9510
(end_ms // 20)=475
start_token_idx=[50668]
end_token_idx=[50838]
len(line_tokens)=16
line_tokens=[50668, 818, 2276, 661, 508, 892, 484, 821, 198, 3549, 3665, 11313, 517, 13, 220, 508

In [None]:
# Decode keeping timestamp tokens; times are floored to 20 ms steps (e.g. 3.090 s shows as <|3.08|>).
print(tokenizer.decode_with_timestamps(tokens))

<|startoftranscript|><|0.00|>Men talk more in
male/female interaction <|3.08|><|3.08|>and they interrupt all the time.
It's partly linked to power. <|6.10|><|6.10|>In general people who think they're
more powerful interrupt more. <|9.50|><|10.94|>So that would be something that
a sociolinguist would discover, <|14.10|><|14.10|>different social groups
speak somewhat differently <|16.22|><|16.86|>and there are people who look
at you know computational linguistics <|19.80|><|19.80|>and the way you can use language
in artificial intelligence, <|22.98|><|23.68|>teach machines to better
be able to translate. <|26.82|><|28.42|>Many other areas.<|29.20|><|29.20|><|endoftext|>


In [None]:
# Plain decode: timestamp tokens are dropped, <|startoftranscript|> / <|endoftext|> remain.
print(tokenizer.decode(tokens))

<|startoftranscript|>Men talk more in
male/female interaction and they interrupt all the time.
It's partly linked to power. In general people who think they're
more powerful interrupt more. So that would be something that
a sociolinguist would discover, different social groups
speak somewhat differently and there are people who look
at you know computational linguistics and the way you can use language
in artificial intelligence, teach machines to better
be able to translate. Many other areas.<|endoftext|>


In [None]:
import whisper
from itertools import chain
from typing import List
n_text_ctx = 448
def preprocess_text(
    transcript_string: str,
    transcript_file: str,
    tokenizer: whisper.tokenizer.Tokenizer,
    text_timestamp: str,
    next_start: str,
) -> List[int]:
    """Tokenizes one transcript segment into a flat list of token ids.

    The transcript is parsed with ``oa.utils.TranscriptReader`` (its format is taken
    from the extension of ``transcript_file``) and every entry is encoded on its own.
    No padding to the context length happens here.

    Layouts:

    * timestamped, used only when the total fits in ``n_text_ctx`` tokens:
      ``sot_sequence <|start|> text <|end|> ... <|next|> <|endoftext|>``
    * plain: the entries' text tokens concatenated, with no special tokens at all.

    In this cell the timestamped layout is switched off (``np.random.rand() > 2`` is
    never true), so the plain layout is always returned.

    Args:
        transcript_string: Raw transcript contents (e.g. SRT text).
        transcript_file: Transcript path; only its extension is used.
        tokenizer: Whisper tokenizer used for encoding and special-token ids.
        text_timestamp: ``"<start>_<end>"`` span of the segment; only ``<start>`` is
            used (with ``,`` replaced by ``.``).
        next_start: Time given, together with the segment start, to
            ``oa.utils.calculate_difference``; the resulting offset in ms becomes the
            final ``<|next|>`` timestamp token.

    Returns:
        The list of token ids.
    """

    def timestamp_token(ms) -> List[int]:
        # One timestamp token per 20 ms, clamped to the 0-30000 ms (30 s) window.
        return [tokenizer.timestamp_begin + (min(max(ms, 0), 30000) // 20)]

    # transcript -> text
    reader = oa.utils.TranscriptReader(
        transcript_string=transcript_string,
        file_path=None,
        ext=transcript_file.split(".")[-1],
    )
    transcript, *_ = reader.read()

    # All entries but the last get a trailing space so they join into running text.
    last_idx = len(transcript) - 1
    tokens = [
        tokenizer.encode(text.strip() + (" " if i < last_idx else ""))
        for i, (_, text) in enumerate(transcript.items())
    ]

    num_timestamp_tokens = (len(transcript) * 2) + 1 # next_start timestamp
    print(f"{num_timestamp_tokens=}")
    num_text_tokens = sum(len(token_group) for token_group in tokens)
    print(f"{num_text_tokens=}")
    # Full sot sequence (its length depends on the tokenizer's language/task) + eot.
    num_total_tokens = num_timestamp_tokens + num_text_tokens + len(tokenizer.sot_sequence) + 1
    print(f"{num_total_tokens=}\n")
    print(f"{([len(token_group) for token_group in tokens])=}")

    # np.random.rand() lies in [0, 1), so `> 2` never holds: timestamped layout disabled.
    if np.random.rand() > 2 and num_total_tokens <= n_text_ctx:
        new_tokens = []
        for i, (start, end) in enumerate(transcript.keys()):
            print(f"{start=}")
            print(f"{end=}")
            start_ms = oa.utils.convert_to_milliseconds(start)
            print(f"{start_ms=}")
            print(f"{(start_ms // 20)=}")
            end_ms = oa.utils.convert_to_milliseconds(end)
            print(f"{end_ms=}")
            print(f"{(end_ms // 20)=}")
            start_token_idx = timestamp_token(start_ms)
            print(f"{start_token_idx=}")
            end_token_idx = timestamp_token(end_ms)
            print(f"{end_token_idx=}")

            # The sot sequence is emitted once, ahead of the first entry.
            prefix = list(tokenizer.sot_sequence) if i == 0 else []
            line_tokens = prefix + start_token_idx + tokens[i] + end_token_idx
            print(f"{len(line_tokens)=}")
            print(f"{line_tokens=}")
            new_tokens.extend(line_tokens)
            print(f"{len(new_tokens)=}\n")

        unnorm_start = text_timestamp.split("_")[0].replace(",", ".")
        print(f"{unnorm_start=}")
        next_start_ms = oa.utils.calculate_difference(unnorm_start, next_start)
        print(f"{next_start_ms=}")
        tokens = new_tokens + timestamp_token(next_start_ms) + [tokenizer.eot]
    else:
        # Plain layout: text tokens only, no sot/eot.
        tokens = list(chain.from_iterable(tokens))
    print(f"{tokens=}")
    print(f"{len(tokens)=}")

    return tokens

In [None]:
# Tokenize the 29.2 s sample segment; the file path only supplies the `.srt` extension.
tokens = preprocess_text(
    transcript_string=sample,
    transcript_file="/weka/huongn/oa_seg/00000000/7HOsQDD1Res/00:00:29,100_00:00:58,300.srt",
    tokenizer=tokenizer,
    text_timestamp="00:00:29,100_00:00:58,300",
    next_start="00:00:58.300",
)

num_timestamp_tokens=19
num_text_tokens=107
num_total_tokens=128

([len(token_group) for token_group in tokens])=[10, 16, 14, 16, 8, 14, 13, 12, 4]
tokens=[10418, 1561, 517, 287, 198, 22606, 14, 24724, 10375, 220, 392, 484, 11313, 477, 262, 640, 13, 198, 1026, 338, 11476, 6692, 284, 1176, 13, 220, 818, 2276, 661, 508, 892, 484, 821, 198, 3549, 3665, 11313, 517, 13, 220, 2396, 326, 561, 307, 1223, 326, 198, 64, 1307, 1669, 6680, 396, 561, 7073, 11, 220, 39799, 1919, 2628, 198, 47350, 6454, 10338, 220, 392, 612, 389, 661, 508, 804, 198, 265, 345, 760, 31350, 20280, 3969, 220, 392, 262, 835, 345, 460, 779, 3303, 198, 259, 11666, 4430, 11, 220, 660, 620, 8217, 284, 1365, 198, 1350, 1498, 284, 15772, 13, 220, 7085, 584, 3006, 13]
len(tokens)=107


In [None]:
# Plain-text view of the token ids returned above.
print(tokenizer.decode(tokens))

Men talk more in
male/female interaction and they interrupt all the time.
It's partly linked to power. In general people who think they're
more powerful interrupt more. So that would be something that
a sociolinguist would discover, different social groups
speak somewhat differently and there are people who look
at you know computational linguistics and the way you can use language
in artificial intelligence, teach machines to better
be able to translate. Many other areas.


In [None]:
import whisper
from itertools import chain
from typing import List
n_text_ctx = 448
def preprocess_text(
    transcript_string: str,
    transcript_file: str,
    tokenizer: whisper.tokenizer.Tokenizer,
    text_timestamp: str,
    next_start: str,
) -> List[int]:
    """Tokenizes one transcript segment into a flat list of target token ids.

    The transcript is parsed with ``oa.utils.TranscriptReader`` (its format is taken
    from the extension of ``transcript_file``) and every entry is encoded on its own.
    No padding to the context length happens here.

    Layouts:

    * timestamped, used only when the total fits in ``n_text_ctx`` tokens:
      ``sot_sequence <|start|> text <|end|> ... <|next|> <|endoftext|>``
    * plain: ``sot_sequence_including_notimestamps text <|endoftext|>``

    In this cell the timestamped layout is switched off (``np.random.rand() > 2`` is
    never true), so the plain layout is always returned.

    Args:
        transcript_string: Raw transcript contents (e.g. SRT text).
        transcript_file: Transcript path; only its extension is used.
        tokenizer: Whisper tokenizer used for encoding and special-token ids.
        text_timestamp: ``"<start>_<end>"`` span of the segment; only ``<start>`` is
            used (with ``,`` replaced by ``.``).
        next_start: Time given, together with the segment start, to
            ``oa.utils.calculate_difference``; the resulting offset in ms becomes the
            final ``<|next|>`` timestamp token.

    Returns:
        The list of token ids.
    """

    def timestamp_token(ms) -> List[int]:
        # One timestamp token per 20 ms, clamped to the 0-30000 ms (30 s) window.
        return [tokenizer.timestamp_begin + (min(max(ms, 0), 30000) // 20)]

    # transcript -> text
    reader = oa.utils.TranscriptReader(
        transcript_string=transcript_string,
        file_path=None,
        ext=transcript_file.split(".")[-1],
    )
    transcript, *_ = reader.read()

    # All entries but the last get a trailing space so they join into running text.
    last_idx = len(transcript) - 1
    tokens = [
        tokenizer.encode(text.strip() + (" " if i < last_idx else ""))
        for i, (_, text) in enumerate(transcript.items())
    ]

    num_timestamp_tokens = (len(transcript) * 2) + 1 # next_start timestamp
    print(f"{num_timestamp_tokens=}")
    num_text_tokens = sum(len(token_group) for token_group in tokens)
    print(f"{num_text_tokens=}")
    # Full sot sequence (its length depends on the tokenizer's language/task) + eot.
    num_total_tokens = num_timestamp_tokens + num_text_tokens + len(tokenizer.sot_sequence) + 1
    print(f"{num_total_tokens=}\n")
    print(f"{([len(token_group) for token_group in tokens])=}")

    # np.random.rand() lies in [0, 1), so `> 2` never holds: timestamped layout disabled.
    if np.random.rand() > 2 and num_total_tokens <= n_text_ctx:
        new_tokens = []
        for i, (start, end) in enumerate(transcript.keys()):
            print(f"{start=}")
            print(f"{end=}")
            start_ms = oa.utils.convert_to_milliseconds(start)
            print(f"{start_ms=}")
            print(f"{(start_ms // 20)=}")
            end_ms = oa.utils.convert_to_milliseconds(end)
            print(f"{end_ms=}")
            print(f"{(end_ms // 20)=}")
            start_token_idx = timestamp_token(start_ms)
            print(f"{start_token_idx=}")
            end_token_idx = timestamp_token(end_ms)
            print(f"{end_token_idx=}")

            # The sot sequence is emitted once, ahead of the first entry.
            prefix = list(tokenizer.sot_sequence) if i == 0 else []
            line_tokens = prefix + start_token_idx + tokens[i] + end_token_idx
            print(f"{len(line_tokens)=}")
            print(f"{line_tokens=}")
            new_tokens.extend(line_tokens)
            print(f"{len(new_tokens)=}\n")

        unnorm_start = text_timestamp.split("_")[0].replace(",", ".")
        print(f"{unnorm_start=}")
        next_start_ms = oa.utils.calculate_difference(unnorm_start, next_start)
        print(f"{next_start_ms=}")
        tokens = new_tokens + timestamp_token(next_start_ms) + [tokenizer.eot]
    else:
        tokens = (
            list(tokenizer.sot_sequence_including_notimestamps)
            + list(chain.from_iterable(tokens))
            + [tokenizer.eot]
        )
    print(f"{tokens=}")
    print(f"{len(tokens)=}")

    return tokens

In [None]:
# Tokenize the 29.2 s sample segment; the file path only supplies the `.srt` extension.
tokens = preprocess_text(
    transcript_string=sample,
    transcript_file="/weka/huongn/oa_seg/00000000/7HOsQDD1Res/00:00:29,100_00:00:58,300.srt",
    tokenizer=tokenizer,
    text_timestamp="00:00:29,100_00:00:58,300",
    next_start="00:00:58.300",
)

num_timestamp_tokens=19
num_text_tokens=107
num_total_tokens=128

([len(token_group) for token_group in tokens])=[10, 16, 14, 16, 8, 14, 13, 12, 4]
tokens=[50257, 50362, 10418, 1561, 517, 287, 198, 22606, 14, 24724, 10375, 220, 392, 484, 11313, 477, 262, 640, 13, 198, 1026, 338, 11476, 6692, 284, 1176, 13, 220, 818, 2276, 661, 508, 892, 484, 821, 198, 3549, 3665, 11313, 517, 13, 220, 2396, 326, 561, 307, 1223, 326, 198, 64, 1307, 1669, 6680, 396, 561, 7073, 11, 220, 39799, 1919, 2628, 198, 47350, 6454, 10338, 220, 392, 612, 389, 661, 508, 804, 198, 265, 345, 760, 31350, 20280, 3969, 220, 392, 262, 835, 345, 460, 779, 3303, 198, 259, 11666, 4430, 11, 220, 660, 620, 8217, 284, 1365, 198, 1350, 1498, 284, 15772, 13, 220, 7085, 584, 3006, 13, 50256]
len(tokens)=110


In [None]:
# Plain-text view of the ids above; special tokens such as <|notimestamps|> are shown.
print(tokenizer.decode(tokens))

<|startoftranscript|><|notimestamps|>Men talk more in
male/female interaction and they interrupt all the time.
It's partly linked to power. In general people who think they're
more powerful interrupt more. So that would be something that
a sociolinguist would discover, different social groups
speak somewhat differently and there are people who look
at you know computational linguistics and the way you can use language
in artificial intelligence, teach machines to better
be able to translate. Many other areas.<|endoftext|>


In [None]:
import whisper
from itertools import chain
from typing import List
n_text_ctx = 448
def preprocess_text(
    transcript_string: str,
    transcript_file: str,
    tokenizer: whisper.tokenizer.Tokenizer,
    text_timestamp: str,
    next_start: str,
) -> List[int]:
    """Tokenizes one transcript segment into a flat list of target token ids.

    The transcript is parsed with ``oa.utils.TranscriptReader`` (its format is taken
    from the extension of ``transcript_file``) and every entry is encoded on its own.
    No padding to the context length happens here.

    Layouts:

    * empty transcript, timestamped:
      ``sot_sequence <|0.00|> <|nospeech|> <|next|> <|endoftext|>``
    * empty transcript, plain:
      ``sot_sequence_including_notimestamps <|nospeech|> <|endoftext|>``
    * text, timestamped (only when the total fits in ``n_text_ctx`` tokens):
      ``sot_sequence <|start|> text <|end|> ... <|next|> <|endoftext|>``
    * text, plain: ``sot_sequence_including_notimestamps text <|endoftext|>``

    The choice is made by ``np.random.rand() > 0`` toggles, which fail only on an
    exact 0.0 draw, so in this cell the timestamped layouts are effectively always used.

    Args:
        transcript_string: Raw transcript contents (e.g. SRT text); may be empty.
        transcript_file: Transcript path; only its extension is used.
        tokenizer: Whisper tokenizer used for encoding and special-token ids.
        text_timestamp: ``"<start>_<end>"`` span of the segment; only ``<start>`` is
            used (with ``,`` replaced by ``.``).
        next_start: Time given, together with the segment start, to
            ``oa.utils.calculate_difference``; the resulting offset in ms becomes the
            ``<|next|>`` timestamp token.

    Returns:
        The list of token ids.
    """

    def timestamp_token(ms) -> List[int]:
        # One timestamp token per 20 ms, clamped to the 0-30000 ms (30 s) window.
        return [tokenizer.timestamp_begin + (min(max(ms, 0), 30000) // 20)]

    def next_start_token() -> List[int]:
        # Offset of `next_start` from the segment start, as a single timestamp token.
        unnorm_start = text_timestamp.split("_")[0].replace(",", ".")
        print(f"{unnorm_start=}")
        next_start_ms = oa.utils.calculate_difference(unnorm_start, next_start)
        print(f"{next_start_ms=}")
        return timestamp_token(next_start_ms)

    # transcript -> text
    reader = oa.utils.TranscriptReader(
        transcript_string=transcript_string,
        file_path=None,
        ext=transcript_file.split(".")[-1],
    )
    transcript, *_ = reader.read()

    if not transcript:
        # No speech in this segment.
        next_start_token_idx = next_start_token()
        if np.random.rand() > 0:
            tokens = (
                list(tokenizer.sot_sequence)
                + timestamp_token(0)
                + [tokenizer.no_speech]
                + next_start_token_idx
                + [tokenizer.eot]
            )
        else:
            tokens = list(tokenizer.sot_sequence_including_notimestamps) + [tokenizer.no_speech] + [tokenizer.eot]
    else:
        # All entries but the last get a trailing space so they join into running text.
        last_idx = len(transcript) - 1
        tokens = [
            tokenizer.encode(text.strip() + (" " if i < last_idx else ""))
            for i, (_, text) in enumerate(transcript.items())
        ]

        num_timestamp_tokens = (len(transcript) * 2) + 1 # next_start timestamp
        print(f"{num_timestamp_tokens=}")
        num_text_tokens = sum(len(token_group) for token_group in tokens)
        print(f"{num_text_tokens=}")
        # Full sot sequence (its length depends on the tokenizer's language/task) + eot.
        num_total_tokens = num_timestamp_tokens + num_text_tokens + len(tokenizer.sot_sequence) + 1
        print(f"{num_total_tokens=}\n")
        print(f"{([len(token_group) for token_group in tokens])=}")

        if np.random.rand() > 0 and num_total_tokens <= n_text_ctx:
            new_tokens = []
            for i, (start, end) in enumerate(transcript.keys()):
                print(f"{start=}")
                print(f"{end=}")
                start_ms = oa.utils.convert_to_milliseconds(start)
                print(f"{start_ms=}")
                print(f"{(start_ms // 20)=}")
                end_ms = oa.utils.convert_to_milliseconds(end)
                print(f"{end_ms=}")
                print(f"{(end_ms // 20)=}")
                start_token_idx = timestamp_token(start_ms)
                print(f"{start_token_idx=}")
                end_token_idx = timestamp_token(end_ms)
                print(f"{end_token_idx=}")

                # The sot sequence is emitted once, ahead of the first entry.
                prefix = list(tokenizer.sot_sequence) if i == 0 else []
                line_tokens = prefix + start_token_idx + tokens[i] + end_token_idx
                print(f"{len(line_tokens)=}")
                print(f"{line_tokens=}")
                new_tokens.extend(line_tokens)
                print(f"{len(new_tokens)=}\n")
            tokens = new_tokens + next_start_token() + [tokenizer.eot]
        else:
            tokens = (
                list(tokenizer.sot_sequence_including_notimestamps)
                + list(chain.from_iterable(tokens))
                + [tokenizer.eot]
            )
    print(f"{tokens=}")
    print(f"{len(tokens)=}")

    return tokens

In [None]:
# Empty transcript for the 1.05 s gap before the sample segment: exercises the no-speech path.
tokens = preprocess_text(
    transcript_string="",
    transcript_file="/weka/huongn/oa_seg/00000000/7HOsQDD1Res/00:00:28,050_00:00:29,100.srt",
    tokenizer=tokenizer,
    text_timestamp="00:00:28,050_00:00:29,100",
    next_start="00:00:29.100",
)

unnorm_start='00:00:28.050'
next_start_ms=1050
tokens=[50257, 50363, 50361, 50415, 50256]
len(tokens)=5


In [None]:
# Inspect the no-speech target with its timestamp tokens rendered.
print(tokenizer.decode_with_timestamps(tokens))

<|startoftranscript|><|0.00|><|nospeech|><|1.04|><|endoftext|>


In [None]:
import whisper
from itertools import chain
from typing import List
n_text_ctx = 448
def preprocess_text(
    transcript_string: str,
    transcript_file: str,
    tokenizer: whisper.tokenizer.Tokenizer,
    text_timestamp: str,
    next_start: str,
) -> List[int]:
    """Tokenizes one transcript segment into a flat list of target token ids.

    The transcript is parsed with ``oa.utils.TranscriptReader`` (its format is taken
    from the extension of ``transcript_file``) and every entry is encoded on its own.
    No padding to the context length happens here.

    Layouts:

    * empty transcript, timestamped:
      ``sot_sequence <|0.00|> <|nospeech|> <|next|> <|endoftext|>``
    * empty transcript, plain:
      ``sot_sequence_including_notimestamps <|nospeech|> <|endoftext|>``
    * text, timestamped (only when the total fits in ``n_text_ctx`` tokens):
      ``sot_sequence <|start|> text <|end|> ... <|next|> <|endoftext|>``
    * text, plain: ``sot_sequence_including_notimestamps text <|endoftext|>``

    In this cell the no-speech toggle (``np.random.rand() > 1``) never fires, so empty
    transcripts always get the plain layout. The text toggle (``> 0``) fails only on an
    exact 0.0 draw, so text effectively always gets the timestamped layout.

    Args:
        transcript_string: Raw transcript contents (e.g. SRT text); may be empty.
        transcript_file: Transcript path; only its extension is used.
        tokenizer: Whisper tokenizer used for encoding and special-token ids.
        text_timestamp: ``"<start>_<end>"`` span of the segment; only ``<start>`` is
            used (with ``,`` replaced by ``.``).
        next_start: Time given, together with the segment start, to
            ``oa.utils.calculate_difference``; the resulting offset in ms becomes the
            ``<|next|>`` timestamp token.

    Returns:
        The list of token ids.
    """

    def timestamp_token(ms) -> List[int]:
        # One timestamp token per 20 ms, clamped to the 0-30000 ms (30 s) window.
        return [tokenizer.timestamp_begin + (min(max(ms, 0), 30000) // 20)]

    def next_start_token() -> List[int]:
        # Offset of `next_start` from the segment start, as a single timestamp token.
        unnorm_start = text_timestamp.split("_")[0].replace(",", ".")
        print(f"{unnorm_start=}")
        next_start_ms = oa.utils.calculate_difference(unnorm_start, next_start)
        print(f"{next_start_ms=}")
        return timestamp_token(next_start_ms)

    # transcript -> text
    reader = oa.utils.TranscriptReader(
        transcript_string=transcript_string,
        file_path=None,
        ext=transcript_file.split(".")[-1],
    )
    transcript, *_ = reader.read()

    if not transcript:
        # No speech in this segment. The next-start token is computed (and printed)
        # even though the plain layout chosen below does not use it.
        next_start_token_idx = next_start_token()
        if np.random.rand() > 1:
            tokens = (
                list(tokenizer.sot_sequence)
                + timestamp_token(0)
                + [tokenizer.no_speech]
                + next_start_token_idx
                + [tokenizer.eot]
            )
        else:
            tokens = list(tokenizer.sot_sequence_including_notimestamps) + [tokenizer.no_speech] + [tokenizer.eot]
    else:
        # All entries but the last get a trailing space so they join into running text.
        last_idx = len(transcript) - 1
        tokens = [
            tokenizer.encode(text.strip() + (" " if i < last_idx else ""))
            for i, (_, text) in enumerate(transcript.items())
        ]

        num_timestamp_tokens = (len(transcript) * 2) + 1 # next_start timestamp
        print(f"{num_timestamp_tokens=}")
        num_text_tokens = sum(len(token_group) for token_group in tokens)
        print(f"{num_text_tokens=}")
        # Full sot sequence (its length depends on the tokenizer's language/task) + eot.
        num_total_tokens = num_timestamp_tokens + num_text_tokens + len(tokenizer.sot_sequence) + 1
        print(f"{num_total_tokens=}\n")
        print(f"{([len(token_group) for token_group in tokens])=}")

        if np.random.rand() > 0 and num_total_tokens <= n_text_ctx:
            new_tokens = []
            for i, (start, end) in enumerate(transcript.keys()):
                print(f"{start=}")
                print(f"{end=}")
                start_ms = oa.utils.convert_to_milliseconds(start)
                print(f"{start_ms=}")
                print(f"{(start_ms // 20)=}")
                end_ms = oa.utils.convert_to_milliseconds(end)
                print(f"{end_ms=}")
                print(f"{(end_ms // 20)=}")
                start_token_idx = timestamp_token(start_ms)
                print(f"{start_token_idx=}")
                end_token_idx = timestamp_token(end_ms)
                print(f"{end_token_idx=}")

                # The sot sequence is emitted once, ahead of the first entry.
                prefix = list(tokenizer.sot_sequence) if i == 0 else []
                line_tokens = prefix + start_token_idx + tokens[i] + end_token_idx
                print(f"{len(line_tokens)=}")
                print(f"{line_tokens=}")
                new_tokens.extend(line_tokens)
                print(f"{len(new_tokens)=}\n")
            tokens = new_tokens + next_start_token() + [tokenizer.eot]
        else:
            tokens = (
                list(tokenizer.sot_sequence_including_notimestamps)
                + list(chain.from_iterable(tokens))
                + [tokenizer.eot]
            )
    print(f"{tokens=}")
    print(f"{len(tokens)=}")

    return tokens

In [None]:
# Empty transcript for the 1.05 s gap before the sample segment: exercises the no-speech path.
tokens = preprocess_text(
    transcript_string="",
    transcript_file="/weka/huongn/oa_seg/00000000/7HOsQDD1Res/00:00:28,050_00:00:29,100.srt",
    tokenizer=tokenizer,
    text_timestamp="00:00:28,050_00:00:29,100",
    next_start="00:00:29.100",
)

unnorm_start='00:00:28.050'
next_start_ms=1050
tokens=[50257, 50362, 50361, 50256]
len(tokens)=4


In [None]:
# Inspect the no-speech target produced in the no-timestamp layout.
print(tokenizer.decode(tokens))

<|startoftranscript|><|notimestamps|><|nospeech|><|endoftext|>


In [None]:
from typing import Dict, List, Optional, Tuple, Union
def over_ctx_len(
    timestamps: List,
    transcript: Optional[Dict],
    language: Optional[str],
    n_text_ctx: int = 448,
) -> Tuple[bool, Optional[Union[str, Dict[str, bool]]]]:
    """Checks whether a transcript's tokens exceed the model's text context length.

    Each entry of ``transcript`` is encoded in ``timestamps`` order; all entries but
    the last get a trailing space. The token total is then counted for two layouts:

    * timestamped: sot sequence + (start, end) per entry + next-start + text + eot
    * no-timestamp: sot sequence incl. ``<|notimestamps|>`` + text + eot

    Args:
        timestamps: Keys of ``transcript``, in order.
        transcript: Mapping from timestamp key to entry text.
        language: Language for a multilingual tokenizer; ``None`` selects the
            English-only tokenizer.
        n_text_ctx: Context length (in tokens) a layout must fit in.

    Returns:
        ``(True, None)`` if neither layout fits; ``(True, "error")`` if building the
        tokenizer or encoding raised; otherwise
        ``(False, {"ts_mode": <fits>, "no_ts_mode": <fits>})``.
    """
    try:
        if language is None:
            tokenizer = get_tokenizer(multilingual=False)
        else:
            tokenizer = get_tokenizer(language=language, multilingual=True)

        last_idx = len(timestamps) - 1
        text_tokens = [
            tokenizer.encode(transcript[timestamp].strip() + (" " if i < last_idx else ""))
            for i, timestamp in enumerate(timestamps)
        ]

        num_timestamp_tokens = (len(timestamps) * 2) + 1 # next_start timestamp
        print(f"{num_timestamp_tokens=}")
        num_text_tokens = sum(len(token_group) for token_group in text_tokens)
        print(f"{num_text_tokens=}")
        # Count the tokenizer's actual prefix lengths rather than assuming 1 (sot) and
        # 2 (sot + notimestamps): multilingual tokenizers also add language/task tokens.
        num_tokens_ts_mode = num_timestamp_tokens + num_text_tokens + len(tokenizer.sot_sequence) + 1 # + eot
        print(f"{num_tokens_ts_mode=}")
        num_tokens_no_ts_mode = num_text_tokens + len(tokenizer.sot_sequence_including_notimestamps) + 1 # + eot
        print(f"{num_tokens_no_ts_mode=}")
    except Exception:
        # Deliberately best-effort: any tokenizer/encoding failure flags the sample.
        return True, "error"

    ts_fits = num_tokens_ts_mode <= n_text_ctx
    no_ts_fits = num_tokens_no_ts_mode <= n_text_ctx
    if not ts_fits and not no_ts_fits:
        return True, None
    return False, {"ts_mode": ts_fits, "no_ts_mode": no_ts_fits}

In [None]:
transcript_file = "/weka/huongn/oa_seg/00000000/7HOsQDD1Res/00:00:29,100_00:00:58,300.srt"
transcript_string = sample
# Parse the sample SRT, then check whether its tokens fit the text context.
reader = oa.utils.TranscriptReader(
    transcript_string=transcript_string,
    file_path=None,
    ext=transcript_file.rsplit(".", 1)[-1],
)
transcript, *_ = reader.read()
# temp: True when the transcript cannot be used; res: which layouts fit (or None / "error").
temp, res = over_ctx_len(
    timestamps=list(transcript.keys()),
    transcript=transcript,
    language=None,
)

num_timestamp_tokens=19
num_text_tokens=107
num_tokens_ts_mode=128
num_tokens_no_ts_mode=110
