# In [3]:
# Import libraries
import logging
import os
import sys

# Setup: install the notebook's dependencies into the *running* interpreter.
# `sys.executable -m pip` is used rather than a bare `pip` shell command so the
# packages land in this kernel's environment even when a different `pip` comes
# first on PATH. Arguments are passed as a list, so no shell is involved.
import subprocess

_PIP_REQUIREMENTS = [
    "transformers",
    "librosa",
    "datasets==2.14.6",
    "evaluate",
    "jiwer",
    "gradio",
    "bitsandbytes==0.37",
    "accelerate",
    "geomloss",
    "torchaudio",
]
# PEFT is installed from the main branch, in a second step after the pinned packages.
_PEFT_REQUIREMENT = "git+https://github.com/huggingface/peft.git@main"

for _requirements in (_PIP_REQUIREMENTS, [_PEFT_REQUIREMENT]):
    _result = subprocess.run(
        [sys.executable, "-m", "pip", "install", "-q", *_requirements],
        check=False,  # best effort, like the previous os.system call
    )
    if _result.returncode != 0:
        # Surface the failure instead of silently ignoring pip's exit status.
        logging.warning(
            "pip install exited with status %s for: %s",
            _result.returncode,
            " ".join(_requirements),
        )

import json
import random
import numpy as np
from dataclasses import dataclass, field
from typing import Optional
from huggingface_hub import notebook_login

from datasets import load_dataset, DatasetDict
from transformers import (WhisperFeatureExtractor, 
                          WhisperTokenizer, 
                          WhisperProcessor,
                          WhisperModel,
                          WhisperForConditionalGeneration, 
                          Seq2SeqTrainingArguments, 
                          Seq2SeqTrainer, 
                          TrainerCallback, 
                          TrainingArguments, 
                          TrainerState, 
                          TrainerControl)
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union
from peft import (prepare_model_for_int8_training,
                  LoraConfig, 
                  PeftModel, 
                  LoraModel, 
                  LoraConfig, 
                  TaskType,
                  get_peft_model)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers.utils import check_min_version
import re

from trainer_utils import AlignmentSeq2SeqTrainer
from data_utils import (DataCollatorSpeechSeq2SeqWithPadding, 
                        load_sd_qa_dataset, 
                        filter_data)

import csv


# Expose only the first GPU on the machine to this process.
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})




# In [None]:
# Use the current CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.cuda.current_device()
else:
    device = 'cpu'

# Whisper checkpoint and task shared by the feature extractor, tokenizer and processor.
model_path = "openai/whisper-base"
task = "transcribe"
feature_extractor = WhisperFeatureExtractor.from_pretrained(model_path)
tokenizer = WhisperTokenizer.from_pretrained(model_path, task=task)
processor = WhisperProcessor.from_pretrained(model_path, task=task)

# Load the pre-trained checkpoint and place it on `device`.
# NOTE: model.hf_device_map is intentionally left unset (unclear what to map to).
model = WhisperForConditionalGeneration.from_pretrained(model_path).to(device)
model.config.forced_decoder_ids = None  # no tokens forced for decoder outputs
model.config.suppress_tokens = []

# Load the SD-QA data and filter it by the chosen source/target dialects
# (the filtering itself is implemented in data_utils.filter_data).
target_dialect = 'usa'
source_dialect = 'ind_n'
sd_qa = filter_data(
    load_sd_qa_dataset(),
    source=source_dialect,
    target=target_dialect,
)

# Show the first example of the 'dev' split as a quick sanity check.
print(sd_qa['dev'][0])

    # prepare data
def prepare_source_data(data):
    """Add log-Mel input features for both the source and the target dialect.

    For each of the module-level ``source_dialect`` and ``target_dialect``
    keys, the entry's ``"array"`` and ``"sampling_rate"`` are passed to the
    module-level ``feature_extractor`` and the first element of the returned
    ``input_features`` is stored on the example.

    Args:
        data: One dataset example, indexable by the two dialect names.

    Returns:
        The same ``data`` object with ``"source_input_features"`` and
        ``"target_input_features"`` added.
    """
    outputs_by_dialect = (
        ("source_input_features", source_dialect),
        ("target_input_features", target_dialect),
    )
    for feature_key, dialect in outputs_by_dialect:
        extracted = feature_extractor(
            data[dialect]["array"],
            sampling_rate=data[dialect]["sampling_rate"],
        )
        data[feature_key] = extracted.input_features[0]
    return data

    # move to gpu
    # run everything at once -> no for loop
def prepare_target_embeddings(data):
    """Attach the encoder's final hidden states for the target-dialect audio.

    Reads ``data["target_input_features"]``, runs it through the module-level
    ``model`` without tracking gradients, and stores the last entry of
    ``encoder_hidden_states`` under ``data["target_embeddings"]``.

    Works with both ``Dataset.map`` modes:

    * un-batched (one example, features without a batch axis): a batch axis is
      added for the forward pass and that example's embedding is stored;
    * batched (features with a leading batch axis): a list with one embedding
      per example is stored.

    Args:
        data: A dataset example (or batch) containing
            ``"target_input_features"``. Assumes per-example features are 2-D
            as produced by ``prepare_source_data`` -- TODO confirm.

    Returns:
        The same ``data`` object with ``"target_embeddings"`` added. The
        embeddings are returned on the CPU.
    """
    input_features = torch.as_tensor(data["target_input_features"])

    # A single example arrives without a batch axis; the model needs one.
    is_single_example = input_features.dim() == 2
    if is_single_example:
        input_features = input_features.unsqueeze(0)

    # The model may have been moved to a GPU; inputs must be on the same device
    # (and in the same dtype) or the forward pass fails.
    input_features = input_features.to(device=model.device, dtype=model.dtype)

    # Two decoder start tokens per example, sized to the actual batch instead of
    # a fixed batch of one. Only the encoder states are used below.
    decoder_input_ids = torch.full(
        (input_features.shape[0], 2),
        model.config.decoder_start_token_id,
        dtype=torch.long,
        device=model.device,
    )

    with torch.no_grad():
        outputs = model(
            input_features,
            decoder_input_ids=decoder_input_ids,
            output_hidden_states=True,
        )

    # Bring the result back to the CPU so it can be stored in the dataset.
    last_hidden_state = outputs.encoder_hidden_states[-1].cpu()

    if is_single_example:
        data["target_embeddings"] = last_hidden_state[0]
    else:
        # One embedding per example in the batch.
        data["target_embeddings"] = list(last_hidden_state)
    return data
# Take six examples (indices 10..15) from the 'dev' split as a small trial run.
sample = sd_qa['dev'].select(list(range(10, 16)))

# Step 1: feature extraction for both dialects, using two worker processes.
sample = sample.map(
    prepare_source_data,
    num_proc=2,
    desc="Extract features for source dialect",
)

# Step 2: encoder hidden states for the target dialect, in the main process.
sample = sample.map(
    prepare_target_embeddings,
    desc="Original hidden embeddings for target dialect",
)
