In [1]:
import os
import sys

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [2]:
import math
import random
import wandb

from tqdm import tqdm
from multiprocess import set_start_method

import torch
from torch.utils.data import DataLoader, Dataset, ConcatDataset
from datasets import load_dataset, load_from_disk, Audio, concatenate_datasets
from datasets import Value, DatasetDict

import transformers
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM,  
    TrainingArguments, 
    default_data_collator, 
    get_scheduler
)
from accelerate import Accelerator

from speechtokenizer import SpeechTokenizer
from audiotools import AudioSignal

In [3]:
base_model = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
save_dir = "./results_asr_mixed_different_tokens"

data = "librispeech" # ["parler-tts", "tedlium", "librispeech"]

start_audio_token = "<soa>"
end_audio_token = "<eoa>"
end_sequence_token = "<eos>"
n_special_tokens = 3

n_codebooks = 3
max_seq_length = 2048

device = 0
load_processed = False
path_to_processed = "./data/processed/"
path_to_cache = "./data/cache/"
quantize_before_training = False


torch.cuda.set_device(f"cuda:{device}")
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

In [4]:
def freeze_entire_model(model):
    for n, p in model.named_parameters():
        p.requires_grad = False
    return model


def freeze(
    model,
    freeze_emb=False,
    freeze_ln=False,
    freeze_attn=False,
    freeze_ff=True,
    freeze_ff_layers=[5,6,7,8,9,12,23,14,18,19,20,0,25],  # None means all or no layers, depending on freeze_ff
    freeze_other=False,
):
    if freeze_ff_layers is not None and not isinstance(freeze_ff_layers, (list, set)):
        raise ValueError("freeze_ff_layers must be a list or set of layer indices")

    for name, p in model.named_parameters():
        name = name.lower()
        layer_index = None
        if 'mlp' in name:
            # Parse the layer index from the parameter name if possible
            tokens = name.split('.')
            for token in tokens:
                if token.isdigit():
                    layer_index = int(token)
                    break
        
        if 'ln' in name or 'norm' in name:
            p.requires_grad = not freeze_ln
        elif 'embeddings' in name:
            p.requires_grad = not freeze_emb
        elif 'mlp' in name:
            if freeze_ff_layers is None:
                # Apply general freeze_ff setting
                p.requires_grad = not freeze_ff
            else:
                # Apply specific layer freeze setting
                p.requires_grad = not (freeze_ff and layer_index in freeze_ff_layers)
        elif 'attn' in name:
            p.requires_grad = not freeze_attn
        else:
            p.requires_grad = not freeze_other
    return model

In [5]:
class Vikhr4oDataset(Dataset):
    def __init__(self, dataset, tokenizer, quantizer, asr: bool = False):
        self.dataset = dataset
        self.tokenizer = tokenizer
        # if true, sequences of type speech to text
        self.asr = asr 

        self.soa = tokenizer(start_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)
        self.eoa = tokenizer(end_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)
        self.eos = tokenizer(end_sequence_token, return_tensors="pt")["input_ids"][:, -1:].to(device)

        self.n_original_tokens = len(tokenizer) - 1024
        self.quantizer = quantizer 

    def __len__(self):
        return len(self.dataset)

    def quantize(self, example):
        audio_data, sample_rate = example["audio"]["array"], example["audio"]["sampling_rate"]
    
        # audio -> discrete codes
        audio = torch.tensor(audio_data).view(1, 1, len(audio_data)).float()
        audio = audio.to(device)
        codes = self.quantizer.encode(audio)
        codes = codes.squeeze(1)
    
        # Move tensor back to CPU and delete it to free GPU memory
        del audio
        torch.cuda.empty_cache()
    
        # increment tokens' ids 
        return codes + self.n_original_tokens

    def __getitem__(self, idx):
        row = self.dataset[idx]

        # get text tokens 
        text = row["text"]
        text_tokenized = self.tokenizer(text, return_tensors="pt")
        text_input_tokens = text_tokenized["input_ids"].to(device)

        # quantize audio 
        codes = self.quantize(row)

        if self.asr:
            raw_audio_tokens = codes[:1]
            audio_input_tokens = raw_audio_tokens.contiguous().view(1, -1)
        else:
            raw_audio_tokens = codes[:n_codebooks]
            audio_input_tokens = raw_audio_tokens.t().contiguous().view(1, -1)

        # determine number of audio tokens given max_seq_length 
        audio_length = min(max_seq_length - text_input_tokens.shape[-1] - n_special_tokens, audio_input_tokens.shape[-1])

        if not self.asr:
            audio_length -= audio_length % n_codebooks

        padding_size = max_seq_length - text_input_tokens.shape[-1] - audio_length - n_special_tokens
        padding = torch.zeros((1, padding_size), dtype=torch.int64, device=device)

        if self.asr:
            tokens = torch.cat([padding, self.soa, audio_input_tokens[:, :audio_length], self.eoa, text_input_tokens, self.eos], dim=1).squeeze(0)
        else:
            tokens = torch.cat([padding, text_input_tokens, self.soa, audio_input_tokens[:, :audio_length], self.eoa, self.eos], dim=1).squeeze(0)
            
        attention_mask = torch.cat([padding, torch.ones((1, max_seq_length - padding_size), device=device)], dim=1).squeeze(0)

        return {
            "input_ids": tokens, 
            "attention_mask": attention_mask, 
            "labels": tokens.clone(),
        }

In [6]:
tokenizer = AutoTokenizer.from_pretrained(base_model, cache_dir=".")
model = AutoModelForCausalLM.from_pretrained(base_model,
                                             attn_implementation="sdpa",
                                             cache_dir=".")

tokenizer.add_special_tokens({'additional_special_tokens': [start_audio_token, end_audio_token]})
n_tokens = len(tokenizer)

start_audio_token_id = tokenizer(start_audio_token)["input_ids"][-1]
end_audio_token_id = tokenizer(end_audio_token)["input_ids"][-1]

config_path = "./audiotokenizer/speechtokenizer_hubert_avg_config.json"
ckpt_path = "./audiotokenizer/SpeechTokenizer.pt"
quantizer = SpeechTokenizer.load_from_checkpoint(config_path, ckpt_path)
quantizer.eval()

for n, child in quantizer.named_children():
    child.to(device)
    child = freeze_entire_model(child)

codebook_size = quantizer.quantizer.bins

tokenizer.add_tokens([f"<audio_token_{i}>" for i in range(codebook_size)])

assert len(tokenizer) == n_tokens + codebook_size

model.resize_token_embeddings(len(tokenizer))

  WeightNorm.apply(module, name, dim)
  params = torch.load(ckpt_path, map_location='cpu')


Embedding(33026, 2048)

In [7]:
def get_audio_padding_tokens(quantizer):
    # create audio without any sounds 
    # seems to work better than radom padding if 
    # length of generated audio is not devisible by n_codebooks
    audio = torch.zeros((1, 1, 1))
    audio = audio.to(device)
    
    codes = quantizer.encode(audio)

    # Move tensor back to CPU and delete it to free GPU memory
    del audio
    torch.cuda.empty_cache()
    
    return {"audio_tokens": codes.squeeze(1)}
    


def decode_audio(tokens, quantizer, pad_tokens, n_original_tokens, n_codebooks):
    # find start and end indices of audio tokens 
    start = torch.nonzero(tokens == start_audio_token_id)
    end = torch.nonzero(tokens == end_audio_token_id)
    
    start = start[0, -1] + 1 if len(start) else 0
    end = end[0, -1] if len(end) else tokens.shape[-1]
    
    # substract length of original vocabulary -> tokens in range [0, 1024)
    audio_tokens = tokens[start:end] % n_original_tokens
    reminder = audio_tokens.shape[-1] % n_codebooks
    
    if reminder:
        # pad if last frame is incomplete 
        audio_tokens = torch.cat([audio_tokens, pad_tokens[reminder:]], dim=0)

    if n_codebooks > 1:
        transposed = audio_tokens.view(-1, n_codebooks).t()
    else:
        transposed = audio_tokens
    codes = transposed.view(n_codebooks, 1, -1).to(device)

    audio = quantizer.decode(codes).squeeze(0)

    del tokens 
    del audio_tokens 
    torch.cuda.empty_cache()
    
    return AudioSignal(audio.detach().cpu().numpy(), quantizer.sample_rate)

    

In [8]:
def prepare_librispeech():
    raw = load_dataset("openslr/librispeech_asr", "clean", cache_dir=".")
    processed = raw.remove_columns(["chapter_id"])
    processed = processed.cast_column('speaker_id', Value('string'))
    return processed 


def prepare_tedlium():
    raw = load_dataset("LIUM/tedlium", "release1", cache_dir=".")
    processed = raw.remove_columns(["gender"])
    return processed


def prepare_parler_tts():    
    raw_mls = load_dataset("parler-tts/mls_eng", cache_dir="/mnt/storage")
    processed_mls = raw_mls.remove_columns(["begin_time", "end_time", "speaker_id", "book_id", "audio_duration"])
    processed_mls = processed_mls.rename_column('transcript', 'text')

    return processed_mls


def prepare_synthetic():
    raw = load_dataset("homebrewltd/instruction-speech-encodec-v1", cache_dir=".")
    processed = raw.remove_columns(["answer", "length"])
    processed = processed.rename_column('prompt', 'text')

    return processed
    

load_processed = False

if not load_processed:
    print("Loadiing data")
    if data == "tedlium":
        dataset = prepare_tedlium()

        train_data = dataset["train"]
        val_data = dataset["validation"]
        
    elif data == "parler-tts":
        dataset = prepare_parler_tts()

        train_data = dataset["train"]
        val_data = dataset["dev"]
        
    elif data == "librispeech":
        dataset = prepare_librispeech()

        train_data = dataset["train.100"]
        val_data = dataset["validation"]

    elif data == "synthetic":
        dataset = prepare_synthetic()["train"]

        splits = dataset.train_test_split(test_size=0.1)
        train_data = splits["train"]
        val_data = splits["test"]
else:
    train_data = load_from_disk(os.path.join(path_to_processed, "train"))
    val_data = load_from_disk(os.path.join(path_to_processed, "val"))

Loadiing data


Loading dataset shards:   0%|          | 0/45 [00:00<?, ?it/s]

In [9]:
train_dataset_tts = Vikhr4oDataset(train_data, tokenizer, quantizer)
train_dataset_asr = Vikhr4oDataset(train_data, tokenizer, quantizer, asr=True)

val_dataset_tts = Vikhr4oDataset(val_data, tokenizer, quantizer)
val_dataset_asr = Vikhr4oDataset(val_data, tokenizer, quantizer, asr=True)

train_dataset = ConcatDataset([train_dataset_tts, train_dataset_asr])
val_dataset = ConcatDataset([val_dataset_tts, val_dataset_asr])

padding_tokens = get_audio_padding_tokens(quantizer)["audio_tokens"]

In [10]:
assert len(val_dataset) == len(val_dataset_asr) + len(val_dataset_tts)

In [11]:
val_dataset[-1]['input_ids']

tensor([    0,     0,     0,  ..., 16999,  1525, 29958], device='cuda:0')

In [12]:
# test audio decoding

input_ids_test = val_dataset[0]["input_ids"]
decoded = decode_audio(input_ids_test, quantizer, padding_tokens, n_tokens, n_codebooks)

input_ids_test2 = val_dataset[-1]["input_ids"]
decoded2 = decode_audio(input_ids_test2, quantizer, padding_tokens, n_tokens, 1)

decoded.write("tests/test.wav")
decoded2.write("tests/test2.wav")

<audiotools.core.audio_signal.AudioSignal at 0x75fb3090cbf0>

## No SFT Trainer

In [None]:
train_batch_size = 1
eval_batch_size = 2
learning_rate = 5e-4
gradient_accumulation_steps = 8
lr_scheduler_type = "cosine"
num_train_epochs = 5
num_warmup_steps = 10
checkpointing_steps = 1000
logging_steps = 20
weight_decay = 0.1
max_grad_norm = 0.25



def test_audio_generation(model, batch, n, quantizer, pad_tokens, n_original_tokens):
    inds = random.choices(range(len(batch)), k=n)
    audios = []
    
    for input_ids, attn in batch["input_ids"], batch["attention_mask"]:
        with torch.no_grad():
            ind = torch.nonzero(input_ids == start_audio_token_id)[0, -1]
            input_ids = input_ids[:ind+1].unsqueeze(0)
            attn = attn[:ind+1].unsqueeze(0).to(torch.float16)
            output = model.generate(input_ids=input_ids, attention_mask=attn ,max_length=max_seq_length)

        try:
            audio = decode_audio(output, quantizer, pad_tokens, n_original_tokens)
            audio_file = os.path.join(save_dir, "audio")
            os.makedirs(audio_file, exists_ok=True)
            audio_file = os.path.join(audio_file, f"audio_{ind + 1}.wav")
            aduio.write(audio_file)
            audios.append(audio_file)
        except:
            print("No audio generated.")
            pass

    return audios


def get_last_checkpoint():
    n_checkpoints = len(list(filter(lambda x: x.startswith("checkpoint"), os.listdir(save_dir))))
    return n_checkpoints + 1


def save_checkpoint(model, accelerator, tokenizer, optimizer, scheduler):
    accelerator.wait_for_everyone()
    state = model.state_dict()

    path = os.path.join(save_dir, f"checkpoint-{get_last_checkpoint() * checkpointing_steps}")
    
    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_model.save_pretrained(
        path, 
        state_dict=state, 
        is_main_process=accelerator.is_main_process, 
        save_function=accelerator.save, 
        save_embedding_layers=True
    )
    if accelerator.is_main_process:
        tokenizer.save_pretrained(path)
        torch.save(optimizer.state_dict(), os.path.join(path, "optimizer.pt"))
        torch.save(scheduler.state_dict(), os.path.join(path, "scheduler.pt"))


def train(model, dataloader, accelerator, optimizer, lr_scheduler, completed_steps, progress_bar, max_train_steps):
    model.gradient_checkpointing_enable()
    model.train()
    # model = freeze(model, freeze_ff_layers=None)
    total_loss = 0
    acc_loss = 0
    
    for step, batch in enumerate(dataloader): 
        with accelerator.accumulate(model):
            outputs = model(**batch)
            loss = outputs.loss
    
            last_loss = loss.detach().float()
            total_loss += last_loss
            acc_loss += last_loss 
            
            accelerator.backward(loss)
    
        if accelerator.sync_gradients:
            accelerator.clip_grad_norm_(model.parameters(), max_grad_norm)
            
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
    
            progress_bar.update(1)
            completed_steps += 1

            acc_loss = acc_loss / gradient_accumulation_steps

            accelerator.log({"loss": acc_loss.item()})
            acc_loss = 0
    
            if completed_steps % checkpointing_steps == 0:
                save_checkpoint(model, accelerator, tokenizer, optimizer, lr_scheduler)

            torch.cuda.empty_cache()

        if completed_steps >= max_train_steps:
                break

    return total_loss / len(dataloader), completed_steps


def eval(model, dataloader, accelerator, epoch, completed_steps, train_loss, quantizer, pad_tokens, n_original_tokens):
    model.eval()
    losses = []

    eval_progress_bar = tqdm(dataloader, desc=f"Evaluating Epoch {epoch}", leave=False)
    
    for batch in eval_progress_bar:
        with torch.no_grad():
            outputs = model(**batch)

        loss = outputs.loss
        losses.append(accelerator.gather_for_metrics(loss.repeat(eval_batch_size)))

    losses = torch.cat(losses)
    try:
        eval_loss = torch.mean(losses)
        perplexity = math.exp(eval_loss)
    except OverflowError:
        perplexity = float("inf")

    print(f"epoch {epoch}: perplexity: {perplexity} eval_loss: {eval_loss}")
    # audios = test_audio_generation(model, batch, 2, quantizer, pad_tokens, n_original_tokens)

    base_log = {
        "perplexity": perplexity,
        "eval_loss": eval_loss,
        "train_loss": train_loss.item() / len(train_dataloader),
        "epoch": epoch,
        "step": completed_steps,
    }
    # base_log.update({f"audio_{i+1}": audios[i] for i in range(len(audios))})

    accelerator.log(base_log, step=completed_steps)


if __name__ == "__main__":
    accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps, 
                              mixed_precision='no', 
                              log_with="wandb")

    os.makedirs(save_dir, exist_ok=True)

    train_dataloader = DataLoader(
        train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=train_batch_size
    )
    eval_dataloader = DataLoader(
        val_dataset, collate_fn=default_data_collator, batch_size=eval_batch_size
    )

    no_decay = ["bias", "layer_norm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad],
            "weight_decay": weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad],
            "weight_decay": 0.0,
        },
    ]
    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=learning_rate, fused=True)
    # optimizer = bnb.optim.Adam8bit(optimizer_grouped_parameters, min_8bit_size=16384)

    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
    max_train_steps = num_train_epochs * num_update_steps_per_epoch

    lr_scheduler = get_scheduler(
        name=lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=num_warmup_steps * accelerator.num_processes,
        num_training_steps=max_train_steps * accelerator.num_processes,
    )

    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
    )

    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
    max_train_steps = num_train_epochs * num_update_steps_per_epoch
    
    num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)

    accelerator.init_trackers("vikhr4o-llama-tiny", {"lr_scheduler_type": lr_scheduler_type})

    total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps

    print("***** Running training *****")
    print(f"  Num examples = {len(train_dataset)}")
    print(f"  Num Epochs = {num_train_epochs}")
    print(f"  Instantaneous batch size per device = {train_batch_size}")
    print(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    print(f"  Gradient Accumulation steps = {gradient_accumulation_steps}")
    print(f"  Total optimization steps = {max_train_steps}")
    
    progress_bar = tqdm(range(max_train_steps), disable=not accelerator.is_local_main_process)
    completed_steps = 0
    starting_epoch = 0

    for epoch in range(starting_epoch, num_train_epochs):
        train_loss, completed_steps = train(model, train_dataloader, accelerator, optimizer, lr_scheduler, completed_steps, progress_bar, max_train_steps)
        print(f"EPOCH {epoch + 1} train loss:", train_loss)
        eval(model, eval_dataloader, accelerator, epoch, completed_steps, train_loss, quantizer, padding_tokens, n_tokens + 1)

    save_checkpoint(model, accelerator, tokenizer, optimizer, lr_scheduler)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mksycheva[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112527088779542, max=1.0…

***** Running training *****
  Num examples = 57078
  Num Epochs = 5
  Instantaneous batch size per device = 1
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 8
  Total optimization steps = 35675


  0%|          | 0/35675 [00:00<?, ?it/s]

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
