In [1]:
import os
import sys
from pathlib import Path

# Make the project's ``src`` directory (found two levels above the notebook's
# working directory) importable from this notebook.
src_path = (Path.cwd().parent.parent / 'src').absolute()
sys.path.append(os.fspath(src_path))

In [None]:
from huggingsound import TrainingArguments, ModelArguments, SpeechRecognitionModel, TokenSet
from typing import NamedTuple, TypedDict
import numpy as np
from persistence.model import FileMetadata
from pathlib import Path
from persistence.db import Database
from persistence.file_metadata_repository import FileMetadataRepository
from features.transcriber import Transcriber

In [3]:
# Locations under <src>/finetuning: ``models`` is the directory handed to the
# model for loading/fine-tuning, ``data`` receives the preprocessed .wav files.
OUTPUT_DIR = src_path / 'finetuning' / 'models'
PREPROCESSED_DATA_DIR = src_path / 'finetuning' / 'data'

In [None]:
ROOT_DIR = Path('/home/flok3n/konrad'); DB_DIR = ROOT_DIR
db = Database(DB_DIR)
await db.init_db()

repo = FileMetadataRepository(db)

files = await repo.load_all_files()

In [None]:
# Keep only the files flagged as both transcript-analyzed and transcript-fixed,
# then display how many there are.
files_with_fixed_transcript = [
    meta
    for meta in files
    if meta.is_transcript_analyzed and meta.is_transcript_fixed
]
len(files_with_fixed_transcript)

In [6]:
async def preprocess_data(files: list[FileMetadata]):
    """Write a preprocessed ``.wav`` copy of every file into ``PREPROCESSED_DATA_DIR``.

    For each entry the source audio at ``ROOT_DIR / f.name`` is passed through
    ``Transcriber._get_preprocessed_audio_file`` and the returned buffer is
    written to ``PREPROCESSED_DATA_DIR / (f.name + '.wav')``, overwriting any
    existing file of that name.

    Args:
        files: Metadata of the files to preprocess; only ``name`` is read.
    """
    # NOTE(review): relies on a private Transcriber helper; it presumably
    # returns an in-memory bytes buffer (``seek``/``getvalue`` are used) — confirm.
    transcriber = Transcriber(None)
    for f in files:
        # Named ``audio_path`` so it does not shadow the notebook-level ``src_path``.
        audio_path = ROOT_DIR.joinpath(f.name)
        target_path = PREPROCESSED_DATA_DIR.joinpath(f.name + '.wav')
        # Create the destination directory if needed so a fresh checkout does
        # not fail with FileNotFoundError on the first write.
        target_path.parent.mkdir(parents=True, exist_ok=True)
        file_bytes = await transcriber._get_preprocessed_audio_file(audio_path)
        with open(target_path, 'wb') as target:
            file_bytes.seek(0)
            target.write(file_bytes.getvalue())

In [7]:
# Write the preprocessed .wav files for every file with an analyzed and fixed transcript.
await preprocess_data(files_with_fixed_transcript)

In [8]:
# Seed NumPy's global RNG, which get_train_eval_dataset below draws from.
np.random.seed(1234)

class TrainItem(TypedDict):
    """One example for fine-tuning: an audio file path and its transcription."""
    path: str  # path of the audio file
    transcription: str  # transcript text for that audio file

class Dataset(NamedTuple):
    """A train/eval partition of ``TrainItem`` examples."""
    train: list[TrainItem]  # examples assigned to training
    eval: list[TrainItem]  # examples assigned to evaluation


def get_train_eval_dataset(files: list[FileMetadata], eval_split=0.2) -> Dataset:
    """Split ``files`` into shuffled train and eval lists of ``TrainItem``s.

    Files with a non-empty transcript and files with an empty one are split
    separately: each group sends ``max(1, int(eval_split * n))`` randomly
    chosen items to eval and the remainder to train. Sampling and shuffling
    use NumPy's global RNG.

    Args:
        files: File metadata; ``name`` and ``transcript`` are read.
        eval_split: Fraction of each group assigned to eval.

    Returns:
        A ``Dataset`` whose ``train`` and ``eval`` lists are each shuffled.

    Raises:
        AssertionError: If either group contains fewer than two files.
    """
    with_text = [f for f in files if f.transcript != '']
    without_text = [f for f in files if f.transcript == '']

    # Each group must be able to contribute to both train and eval.
    assert len(without_text) > 1 and len(with_text) > 1

    train_items, eval_items = [], []
    for group in (with_text, without_text):
        group_items = [
            TrainItem(
                path=str(PREPROCESSED_DATA_DIR.joinpath(f.name + '.wav').absolute()),
                transcription=str(f.transcript),
            )
            for f in group
        ]
        eval_count = max(1, int(eval_split * len(group_items)))
        # Positions (within this group) that are held out for evaluation.
        held_out = set(np.random.choice(range(len(group_items)), eval_count, replace=False))
        eval_items.extend(item for pos, item in enumerate(group_items) if pos in held_out)
        train_items.extend(item for pos, item in enumerate(group_items) if pos not in held_out)

    np.random.shuffle(train_items)
    np.random.shuffle(eval_items)
    return Dataset(train=train_items, eval=eval_items)

In [9]:
# Build the train/eval split from the files with analyzed and fixed transcripts.
dataset = get_train_eval_dataset(files_with_fixed_transcript)

In [None]:
# A bare ``raise`` with no active exception fails with RuntimeError, which
# halts "Run All" at this cell — presumably a deliberate stop so the cells
# that follow are only run by hand.
raise

In [None]:
# Load the model stored in OUTPUT_DIR (the same directory passed to
# model.finetune(...) later in the notebook) on the CUDA device.
model = SpeechRecognitionModel(str(OUTPUT_DIR), device='cuda')

In [12]:
# Build a transcription engine around the loaded model.
# The lambda reads the global ``model`` when it is called, so the engine uses
# whatever ``model`` is bound to at that moment.
tb = Transcriber(None)
engine = tb.Engine(tb, lambda: model)

In [13]:
# Select the files whose name contains 'firany'.
firany_files = [meta for meta in files if 'firany' in meta.name]

In [None]:
transcription_results = []
for file in firany_files:
    transcription = await engine.transcribe(ROOT_DIR.joinpath(file.name))
    transcription_results.append(transcription)

In [None]:
# Print the stored transcript next to the new transcription for each file,
# followed by a blank line.
for original, retranscribed in zip(firany_files, transcription_results):
    print(f'before: {original.transcript}', f'after: {retranscribed}', '', sep='\n')

In [None]:
# A bare ``raise`` with no active exception fails with RuntimeError, which
# halts "Run All" at this cell — presumably a deliberate stop so the
# fine-tuning cells that follow are only run by hand.
raise

In [None]:
# Load the "jonatasgrosman/wav2vec2-large-xlsr-53-polish" checkpoint on the
# CUDA device; this rebinds the notebook-level ``model``.
model = SpeechRecognitionModel("jonatasgrosman/wav2vec2-large-xlsr-53-polish", device='cuda')

In [18]:
# Training configuration passed to model.finetune(...) in the next cell.
args = TrainingArguments(
    overwrite_output_dir=False,
    per_device_train_batch_size=4,
    learning_rate=3e-4,
    num_train_epochs=8,
    fp16=False,
)

In [None]:
# Fine-tune with OUTPUT_DIR as the first argument, reusing the loaded model's
# token set. The train and eval items are concatenated into one training set
# and no evaluation data is passed; the split-based ``eval_data`` argument is
# left commented out.
finetuned = model.finetune(
    str(OUTPUT_DIR),
    train_data=[*dataset.train, *dataset.eval],
    # eval_data=dataset.eval,
    eval_data=None,
    token_set=model.token_set,
    training_args=args
)

In [None]:
%load_ext tensorboard
%tensorboard --logdir finetuning/models