In [1]:
%load_ext autoreload
%reload_ext autoreload
%autoreload 2

import numpy as np
import matplotlib.pyplot as plot

from pathlib import Path
import os
import sys

import torch
from torch import nn
from torch import functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

# handling paths
project_path = Path("C:/Users/aleks/git/frankenstein")
data_path = Path("D:\data_brain_to_text\competitionData")
utils_path = project_path / "utils"
sys.path.append(str(utils_path))


from data_utils import process_string, save_sentences_to_txt, load_sentences_from_txt

In [2]:
# --- Load pretrained Whisper components ------------------------------------
from dataclasses import dataclass
from typing import Any, Dict, List, Union

import evaluate
from torchsummary import summary
from transformers import (
    GenerationConfig,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    WhisperFeatureExtractor,
    WhisperForConditionalGeneration,
    WhisperTokenizer,
)

WHISPER_MODEL_NAME = "openai/whisper-small.en"

# Feature/label processing engines. The feature extractor is used by the data
# collator to batch the (already precomputed) brain features.
feature_extractor = WhisperFeatureExtractor.from_pretrained(WHISPER_MODEL_NAME)
tokenizer = WhisperTokenizer.from_pretrained(WHISPER_MODEL_NAME, task="transcribe")

# The pretrained seq2seq model that is fine-tuned below.
model = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_NAME)

# Word-error-rate metric consumed by compute_metrics.
metric = evaluate.load("wer")

summary(model);

Layer (type:depth-idx)                             Param #
├─WhisperModel: 1-1                                --
|    └─WhisperEncoder: 2-1                         --
|    |    └─Conv1d: 3-1                            185,088
|    |    └─Conv1d: 3-2                            1,770,240
|    |    └─Embedding: 3-3                         (1,152,000)
|    |    └─ModuleList: 3-4                        85,045,248
|    |    └─LayerNorm: 3-5                         1,536
|    └─WhisperDecoder: 2-2                         --
|    |    └─Embedding: 3-6                         39,831,552
|    |    └─WhisperPositionalEmbedding: 3-7        344,064
|    |    └─ModuleList: 3-8                        113,402,880
|    |    └─LayerNorm: 3-9                         1,536
├─Linear: 1-2                                      39,831,552
Total params: 281,565,696
Trainable params: 280,413,696
Non-trainable params: 1,152,000


## Load data and create dataset

In [3]:
%%time
features_train = np.load(data_path / "whisper_brain_arr_train.npy")
features_test = np.load(data_path / "whisper_brain_arr_test.npy")

print("Features train shape", features_train.shape)
print("Features test shape ", features_test.shape)

Features train shape (8800, 80, 3000)
Features test shape  (880, 80, 3000)
CPU times: total: 8.17 s
Wall time: 17.3 s


In [4]:
%%time
sentences_train = load_sentences_from_txt(data_path / "whisper_sentences_train.txt")
sentences_test = load_sentences_from_txt(data_path / "whisper_sentences_test.txt")

sentences_train[0]

CPU times: total: 0 ns
Wall time: 7.14 ms


'nuclear rockets can destroy airfields with ease'

In [5]:
class WhisperBrainDataset(Dataset):
    """Pairs precomputed brain-signal feature matrices with their transcripts.

    Each item is a dict with:
      * ``input_features`` -- ``brain_features[idx]`` as a float32 tensor
        (80 x 3000 per sample in this notebook). float32 lets items be fed to
        the model directly, not only through the collator.
      * ``labels`` -- 1-D tensor of token ids for ``sentences[idx]`` as
        produced by ``tokenizer`` (special tokens included; the collator strips
        a leading decoder-start token).

    Args:
        brain_features: indexable array-like with one feature matrix per sample
            (numpy array or memmap here).
        sentences: target strings, aligned index-by-index with ``brain_features``.
        tokenizer: HF-style tokenizer; called with ``return_tensors="pt"`` and
            expected to expose ``.input_ids``.

    Raises:
        ValueError: if ``brain_features`` and ``sentences`` differ in length.
    """

    def __init__(self, brain_features, sentences, tokenizer):
        if len(brain_features) != len(sentences):
            raise ValueError(
                f"got {len(brain_features)} feature samples but {len(sentences)} sentences"
            )
        self.brain_features = brain_features
        self.sentences = sentences
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.brain_features)

    def __getitem__(self, idx):
        # torch.tensor always copies, so read-only / memory-mapped arrays are
        # safe; the float32 cast matches the model weights whatever the on-disk dtype.
        input_features = torch.tensor(self.brain_features[idx], dtype=torch.float32)

        # [0] instead of .squeeze(): always yields a 1-D label vector, even for
        # a single-token encoding.
        labels = self.tokenizer(self.sentences[idx], return_tensors="pt").input_ids[0]

        return {
            "input_features": input_features,
            "labels": labels,
        }

In [6]:
# Training / evaluation datasets: each precomputed feature matrix is paired
# with its transcript and tokenized on access.
train_dataset = WhisperBrainDataset(
    brain_features=features_train, sentences=sentences_train, tokenizer=tokenizer
)
eval_dataset = WhisperBrainDataset(
    brain_features=features_test, sentences=sentences_test, tokenizer=tokenizer
)

### Create data collator

In [7]:
import torch

from dataclasses import dataclass
from typing import Any, Dict, List, Union

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Collate dataset items into one padded seq2seq batch.

    Attributes:
        feature_extractor: object whose ``pad`` batches the already-preprocessed
            ``input_features``.
        tokenizer: object whose ``pad`` right-pads the label id sequences and
            returns ``input_ids`` together with an ``attention_mask``.
        decoder_start_token_id: id removed from the front of the labels when
            every row of the batch starts with it.
    """

    feature_extractor: Any
    tokenizer: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Brain inputs need no further preprocessing; the extractor only stacks/pads them.
        batch = self.feature_extractor.pad(
            [{"input_features": item["input_features"]} for item in features],
            return_tensors="pt",
        )

        padded = self.tokenizer.pad(
            [{"input_ids": item["labels"]} for item in features],
            return_tensors="pt",
        )
        # Every position outside the attention mask becomes -100 so the loss skips it.
        is_padding = padded.attention_mask.ne(1)
        label_ids = padded["input_ids"].masked_fill(is_padding, -100)

        # When the tokenizer already put the start token at the front of every
        # row, drop it: it is added again later (presumably when the model
        # shifts the labels right to build decoder inputs).
        starts_with_bos = (label_ids[:, 0] == self.decoder_start_token_id).all().cpu().item()
        if starts_with_bos:
            label_ids = label_ids[:, 1:]

        batch["labels"] = label_ids
        return batch

In [8]:
# Collator handed to the Trainer; leading labels are compared against the
# model's configured decoder start token.
start_token_id = model.config.decoder_start_token_id
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    feature_extractor=feature_extractor,
    tokenizer=tokenizer,
    decoder_start_token_id=start_token_id,
)

### WER metric

In [9]:
def compute_metrics(pred):
    """Compute the word error rate (in percent) for a Seq2SeqTrainer evaluation.

    Args:
        pred: EvalPrediction-like object. ``pred.predictions`` holds generated
            token ids (``predict_with_generate=True``); ``pred.label_ids`` holds
            reference ids where -100 marks positions padded by the collator.

    Returns:
        dict: ``{"wer": 100 * WER}`` as computed by the module-level ``metric``.

    Decodes with the module-level ``tokenizer``. ``pred`` is not modified.
    """
    pad_token_id = tokenizer.pad_token_id

    # -100 is not a real token id and cannot be decoded, so map it to the pad
    # token (presumably dropped by skip_special_tokens). np.where builds new
    # arrays, so the Trainer's arrays are left untouched. Predictions are
    # cleaned too, defensively, in case -100 padding was used when batches of
    # different lengths were gathered.
    predictions = np.asarray(pred.predictions)
    references = np.asarray(pred.label_ids)
    pred_ids = np.where(predictions == -100, pad_token_id, predictions)
    label_ids = np.where(references == -100, pad_token_id, references)

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

## Training setup

### Set up Weights & Biases and Hugging Face login

In [10]:
import os
import wandb

# Never hard-code API keys in a notebook. The key comes from the WANDB_API_KEY
# environment variable; with key=None wandb should fall back to its stored
# credentials or an interactive prompt.
# NOTE(review): a key was previously committed in this cell -- revoke and rotate it.
wandb.login(key=os.environ.get("WANDB_API_KEY"))

[34m[1mwandb[0m: Currently logged in as: [33maltime[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: C:\Users\aleks\.netrc


True

In [11]:
import os
import huggingface_hub

# Token comes from the HF_TOKEN environment variable instead of being committed.
# push_to_hub=False in the training arguments, so this login is presumably only
# needed for Hub uploads.
# NOTE(review): the token previously hard-coded here had write permission -- revoke it.
huggingface_hub.login(token=os.environ.get("HF_TOKEN"))

Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to C:\Users\aleks\.cache\huggingface\token
Login successful


### Small detour - generation_max_length

Let's find the maximum number of tokens needed to describe our output sentences; this bounds `generation_max_length`.

In [12]:
def label_length_stats(sentences, tokenizer):
    """Return ``(max_len, max_idx, min_len, min_idx)`` of tokenized sentence lengths.

    Counts every id the tokenizer returns for a sentence (special tokens
    included), i.e. the same length as ``len(dataset[i]["labels"])``. Only the
    sentences are tokenized: indexing the dataset instead would also copy every
    feature matrix, which is what made the per-item scan slow.
    Ties resolve to the first occurrence.

    Raises:
        ValueError: if ``sentences`` is empty.
    """
    sentences = list(sentences)
    if not sentences:
        raise ValueError("cannot compute token statistics for an empty sentence list")
    lengths = [len(ids) for ids in tokenizer(sentences).input_ids]
    max_idx = max(range(len(lengths)), key=lengths.__getitem__)
    min_idx = min(range(len(lengths)), key=lengths.__getitem__)
    return lengths[max_idx], max_idx, lengths[min_idx], min_idx


for split_name, split_sentences in [("Train dataset", sentences_train), ("Test dataset", sentences_test)]:
    max_tokens, max_idx, min_tokens, min_idx = label_length_stats(split_sentences, tokenizer)
    print(f"{split_name}: max_tokens = {max_tokens}, min_tokens = {min_tokens}")
    print(f"Examples:\n  >{split_sentences[max_idx]}\n  >{split_sentences[min_idx]}")

# Result recorded when this scan was first run (counts include special tokens):
#   max_tokens (train) = 23, max_tokens (test) = 21
#   min_tokens (train) = 5,  min_tokens (test) = 6
# so generation_max_length=32 in the training arguments leaves headroom.

### Set training parameters

In [13]:
import math

# Checkpoints and logs for this run are written here.
experiment_path = data_path / "experiments" / WHISPER_MODEL_NAME / "experiment-2"
experiment_path.mkdir(parents=True, exist_ok=True)

batch_size = 16
# Target effective batch size (per-device batch x accumulation steps): halving
# batch_size doubles the accumulation steps. max(1, ...) keeps it valid if
# batch_size ever exceeds the target.
TARGET_EFFECTIVE_BATCH_SIZE = 16
grad_accum_steps = max(1, TARGET_EFFECTIVE_BATCH_SIZE // batch_size)
# Optimizer updates per epoch; ceil so a trailing partial batch still counts.
epoch_length = math.ceil(len(train_dataset) / (batch_size * grad_accum_steps))
# Evaluate and checkpoint twice per epoch. The two intervals are kept equal so
# every evaluation has a matching checkpoint for load_best_model_at_end.
eval_every = max(1, epoch_length // 2)

training_args = Seq2SeqTrainingArguments(
    output_dir=str(experiment_path),  # local directory for checkpoints/logs
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=grad_accum_steps,
    learning_rate=2.5e-5,
    num_train_epochs=5,
    warmup_steps=epoch_length,  # one epoch of warmup
    gradient_checkpointing=True,
    fp16=False,
    evaluation_strategy="steps",  # newer transformers releases name this `eval_strategy`
    per_device_eval_batch_size=batch_size,
    predict_with_generate=True,
    generation_max_length=32,  # longest tokenized sentence is 23 tokens
    save_steps=eval_every,
    eval_steps=eval_every,
    logging_steps=25,
    report_to=["wandb"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=False,
)


### Initialize trainer

In [15]:
# Wire the model, datasets, padding collator and WER metric together.
# (An unused `DataLoaderConfiguration(...)` line was removed from this cell:
# that name is never imported here, so it raised NameError on a fresh kernel.)
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    # Passing the feature extractor here presumably makes the Trainer save its
    # preprocessing config alongside each checkpoint.
    tokenizer=feature_extractor,
)


In [16]:
# Start training from the pretrained weights (no checkpoint to resume from).
trainer.train(resume_from_checkpoint=None)

  attn_output = torch.nn.functional.scaled_dot_product_attention(
`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`...


Step,Training Loss,Validation Loss,Wer
275,3.2555,3.167503,112.841033
550,3.3362,3.13711,107.0753
825,2.7333,3.049134,105.001819
1100,1.9005,2.067214,87.140778
1375,1.065,1.810285,76.591488
1650,1.0292,1.636373,69.206984


Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 357, 366, 438, 532, 685, 705, 796, 930, 1058, 1220, 1267, 1279, 1303, 1343, 1377, 1391, 1635, 1782, 1875, 2162, 2361, 2488, 3467, 4008, 4211, 4600, 4808, 5299, 5855, 6329, 7203, 9609, 9959, 10563, 10786, 11420, 11709, 11907, 13163, 13697, 13700, 14808, 15306, 16410, 16791, 17992, 19203, 19510, 20724, 22305, 22935, 27007, 30109, 30420, 33409, 34949, 40283, 40493, 40549, 47282, 49146, 50257, 50357, 50358, 50359, 50360, 50361], 'begin_suppress_tokens': [220, 50256]}
Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 357, 366, 438, 532, 685, 705, 796, 930, 1058, 1220, 1267, 1279, 1303, 1343, 1377, 1391, 1635, 1782, 1875, 2162, 2361, 2488, 3467, 4008, 4211, 4600, 4808, 5299, 5855, 6329, 7203, 9609, 9959, 10563, 10786, 1142

KeyboardInterrupt: 

In [33]:
# Greedy-decode a slice of held-out samples and print them next to the truth.
# Inputs follow whichever device the model is on, instead of a hard-coded
# "cuda" that breaks once the model is reloaded on the CPU. eval() makes sure
# the model is not left in training mode after an interrupted trainer.train().
model.eval()
for idx in range(20, 30):
    true_text = sentences_test[idx]
    print(f"True: {true_text}")
    input_tensor = eval_dataset[idx]['input_features'].to(model.device, dtype=torch.float).unsqueeze(0)
    ids = model.generate(input_tensor).cpu()
    pred_text = tokenizer.decode(ids[0], skip_special_tokens=True)
    print(f"Pred: {pred_text}\n")

True: to some extent predispositions are shaped by exposure to group environments
Pred: the same act usually depends on whether the prosecutor does not object

True: an adult male baboon's teeth are not suitable for eating shellfish
Pred: he had longed to talk about his own personal personal experience

True: in this context it would do well for us to bear in mind the vision of peace
Pred: in this instance a voice was heard through a voice that had been spoken of occasionally

True: you're boiling milk ain't you
Pred: who was america's daughter

True: rich looked for spotted hyenas and jaguars on the safari
Pred: frontiers should avoid casualties and take care of the employees

True: traffic frequently has failed to measure up to engineers' rosy estimates
Pred: dallas has also avoided major financial problems due to the lack of benefits

True: ralph prepared red snapper with fresh lemon sauce for dinner
Pred: toss pewter with cream cheese with cream cheese

True: did you buy any cordur

In [35]:
# NOTE(review): checkpoint-2500 does not appear in the training logs shown in
# this notebook -- confirm which run produced it.
checkpoint_path = experiment_path / "checkpoint-2500"
# Fail fast with a clear message; from_pretrained may otherwise treat a missing
# local path as a Hub repo id and fail less obviously.
if not checkpoint_path.is_dir():
    raise FileNotFoundError(f"No checkpoint directory at {checkpoint_path}")
# Rebinds `model` only; `trainer` keeps a reference to the previous model object.
model = WhisperForConditionalGeneration.from_pretrained(checkpoint_path)

In [37]:
# Use an explicit sample index instead of `idx` leaked from an earlier loop;
# 29 is the value the range(20, 30) loop leaves behind on a top-to-bottom run.
sample_idx = 29
input_tensor = eval_dataset[sample_idx]['input_features'].to(model.device, dtype=torch.float).unsqueeze(0)

ids = model.generate(input_tensor).cpu()

In [28]:
# Raw token ids from the greedy `model.generate` call (special tokens included).
ids

tensor([[50257, 50362,    72,   550,   284,  1414,   616,  5704,  6949, 50256]])

In [29]:
# Text of the single generated sequence, with special tokens removed.
tokenizer.batch_decode(ids, skip_special_tokens=True)[0]

'i had to pay my taxes anyway'

In [43]:
# Same sample decoded with 5-beam search, for comparison with greedy decoding.
beam_output = model.generate(input_tensor, num_beams=5)
out = beam_output.cpu()

In [46]:
# Decode the beam-search result `out`, not the greedy `ids`.
pred_text = tokenizer.decode(out[0], skip_special_tokens=True)
pred_text

'he had not been taken to visit with her mother'

In [48]:
# Beam-search (5 beams) predictions for a second slice of held-out samples.
# Inputs follow whichever device the model is on (a CPU-only tensor fails if
# the model is on the GPU). eval() guards against a model left in training mode.
model.eval()
for idx in range(30, 40):
    true_text = sentences_test[idx]
    print(f"True: {true_text}")
    input_tensor = eval_dataset[idx]['input_features'].to(model.device, dtype=torch.float).unsqueeze(0)
    ids = model.generate(input_tensor, num_beams=5).cpu()
    pred_text = tokenizer.decode(ids[0], skip_special_tokens=True)
    print(f"Pred: {pred_text}\n")

True: who authorized the unlimited expense account
Pred: all they knew was that my husband could go out

True: the family requests that flowers be omitted
Pred: the point you get is not from the standpoint

True: the museum hires musicians every evening
Pred: the boy slipped hard to get to the point

True: we'll serve rhubarb pie after rachel's talk
Pred: a lot of people have to do their part

True: they enjoy it when i audition
Pred: that's not how i feel a lot good

True: the avocado should have a give to it as you hold it when it is ripe
Pred: the earthquake ran out of the way and now we are at the end of time

True: energy suppliers usually deal with this by conducting regular inspections and trimming
Pred: in this case usually the first step is to get together and make sure everything is fine

True: the saw is broken so chop the wood instead
Pred: this is a blessing to have the joy and insight

True: the cleaned version is given below
Pred: the cowboys couldn't help but wail

True

In [39]:
# Display the module tree of whatever model is currently bound to `model`.
model

WhisperForConditionalGeneration(
  (model): WhisperModel(
    (encoder): WhisperEncoder(
      (conv1): Conv1d(80, 768, kernel_size=(3,), stride=(1,), padding=(1,))
      (conv2): Conv1d(768, 768, kernel_size=(3,), stride=(2,), padding=(1,))
      (embed_positions): Embedding(1500, 768)
      (layers): ModuleList(
        (0-11): 12 x WhisperEncoderLayer(
          (self_attn): WhisperSdpaAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=False)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
        

In [18]:
# Continue training from the checkpoint saved at step 1650.
trainer.train(resume_from_checkpoint=experiment_path / "checkpoint-1650")

There were missing keys in the checkpoint model loaded: ['proj_out.weight'].


Step,Training Loss,Validation Loss


KeyboardInterrupt: 