In [1]:
import os

# Expose only GPU 0 to this process. This is set in the first cell, before the
# cell that imports torch, so the restriction is already in place at import time.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [2]:
import math
import random

from tqdm.notebook import tqdm

import torch
from torch.utils.data import DataLoader, Dataset, ConcatDataset
from datasets import load_dataset, load_from_disk
from datasets import Value

from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM,  
    default_data_collator, 
    get_scheduler
)
from accelerate import Accelerator

from peft import LoraConfig

import dac
from audiotools import AudioSignal

In [3]:
!nvidia-smi

Thu Aug 22 16:06:10 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03             Driver Version: 535.129.03   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A40                     On  | 00000000:53:00.0 Off |                    0 |
|  0%   33C    P0              58W / 300W |      7MiB / 46068MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [4]:
# Checkpoint name used for both the tokenizer and the causal LM.
base_model = "google/gemma-2-2b"
# Root directory for checkpoints (and generated audio samples).
save_dir = "./results_asr"

# Corpus loaded when `load_processed` is False.
data = "librispeech" # one of: "parler-tts", "tedlium", "librispeech"

# Marker tokens that delimit the parts of a training sequence.
start_audio_token = "<soa>"
end_audio_token = "<eoa>"
end_sequence_token = "<eos>"
end_frame_token = "<eof>"

# Number of DAC codebooks kept per audio frame.
n_codebooks = 8
# Every training example is padded / truncated to exactly this many tokens.
max_seq_length = 3072

# CUDA device index made current below.
device = 0
# When True, train/val splits are read from `path_to_processed` instead of being downloaded.
load_processed = False
path_to_processed = "./data/processed_2/"
# NOTE(review): `path_to_cache` and `quantize_before_training` are not referenced
# anywhere else in the visible notebook.
path_to_cache = "./data/cache/"
quantize_before_training = False


torch.cuda.set_device(f"cuda:{device}")

In [5]:
# LoRA adapter settings: rank 16 on the attention and MLP projection modules.
# NOTE(review): `lora_config` is never passed to a model in the visible notebook,
# so as written the run below trains the base weights rather than an adapter.
lora_config = LoraConfig(
    r=16,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)

In [6]:
class Vikhr4oDataset(Dataset):
    """Turns (text, audio) rows into fixed-length token sequences for a causal LM.

    Every item is left-padded to ``max_seq_length`` and laid out as

    * TTS (``asr=False``): ``[pad..., text, <soa>, audio frames, <eoa>, <eos>]``
    * ASR (``asr=True``):  ``[pad..., <soa>, audio frames, <eoa>, text, <eos>]``

    An audio frame is the ``n_codebooks`` codes of one time step followed by the
    ``<eof>`` marker. Codes from all codebooks are shifted by the same offset
    (``n_original_tokens``) into the token-id range appended to the tokenizer.

    Args:
        dataset: indexable collection of rows with a ``"text"`` string and an
            ``"audio"`` dict holding ``"array"`` and ``"sampling_rate"``.
        tokenizer: tokenizer that already contains the marker and audio tokens.
        quantizer: audio codec exposing ``device``, ``preprocess`` and ``encode``.
        asr: build speech-to-text sequences instead of text-to-speech ones.
    """

    # <soa>, <eoa> and <eos> are always present in a sequence
    _N_MARKERS = 3

    def __init__(self, dataset, tokenizer, quantizer, asr: bool = False):
        self.dataset = dataset
        self.tokenizer = tokenizer
        # if true, sequences of type speech to text
        self.asr = asr

        # each marker id is the last id produced when tokenizing its string
        self.soa = tokenizer(start_audio_token, return_tensors="pt")["input_ids"][:, -1:]
        self.eoa = tokenizer(end_audio_token, return_tensors="pt")["input_ids"][:, -1:]
        self.eos = tokenizer(end_sequence_token, return_tensors="pt")["input_ids"][:, -1:]
        self.eof = tokenizer(end_frame_token)["input_ids"][-1]

        # The audio tokens are the last `codebook_size` entries of the vocabulary;
        # fall back to the previously hard-coded 1024 if the codec does not say.
        self.n_original_tokens = len(tokenizer) - getattr(quantizer, "codebook_size", 1024)
        self.quantizer = quantizer

    def __len__(self):
        return len(self.dataset)

    def quantize(self, example):
        """Encode one row's waveform and return its codes as CPU token ids."""
        audio_data, sample_rate = example["audio"]["array"], example["audio"]["sampling_rate"]

        # audio -> discrete codes
        audio = torch.tensor(audio_data).view(1, 1, len(audio_data)).float()
        # use the instance's codec, not whatever global happens to be named `quantizer`
        audio = audio.to(self.quantizer.device)
        # inference only: do not build an autograd graph through the encoder
        with torch.no_grad():
            x = self.quantizer.preprocess(audio, sample_rate)
            _, codes, _, _, _ = self.quantizer.encode(x)
        codes = codes.to("cpu")

        # drop the device-side tensors and release cached GPU memory
        del audio
        del x
        torch.cuda.empty_cache()

        # increment tokens' ids
        return codes + self.n_original_tokens

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and ``labels`` of length ``max_seq_length``."""
        row = self.dataset[idx]

        # get text tokens; cap them so the three markers always fit
        text_input_tokens = self.tokenizer(row["text"], return_tensors="pt")["input_ids"]
        text_input_tokens = text_input_tokens[:, :max(max_seq_length - self._N_MARKERS, 0)]

        # quantize audio
        codes = self.quantize(row)
        raw_audio_tokens = codes[:, :n_codebooks]

        # add special tokens at the end of each frame
        n_frames = raw_audio_tokens.shape[-1]
        frame_ends = torch.full((1, 1, n_frames), self.eof, dtype=raw_audio_tokens.dtype)
        raw_audio_tokens = torch.cat([raw_audio_tokens, frame_ends], dim=1)

        # permute: (n_codebooks + 1, n_frames) -> (n_frames, n_codebooks + 1), then flatten
        audio_input_tokens = raw_audio_tokens.permute(2, 0, 1).contiguous().view(1, -1)

        # determine number of audio tokens given max_seq_length (whole frames only)
        audio_budget = max_seq_length - text_input_tokens.shape[-1] - self._N_MARKERS
        audio_length = min(audio_budget, audio_input_tokens.shape[-1])
        audio_length -= audio_length % (n_codebooks + 1)
        audio_part = audio_input_tokens[:, :audio_length]

        # whatever is left of the budget becomes left padding (token id 0)
        padding_size = audio_budget - audio_length
        padding = torch.zeros((1, padding_size), dtype=torch.int64)

        if self.asr:
            parts = [padding, self.soa, audio_part, self.eoa, text_input_tokens, self.eos]
        else:
            parts = [padding, text_input_tokens, self.soa, audio_part, self.eoa, self.eos]
        tokens = torch.cat(parts, dim=1).squeeze(0)

        # keep the mask integer-typed like the padding it is concatenated with
        real_positions = torch.ones((1, max_seq_length - padding_size), dtype=torch.int64)
        attention_mask = torch.cat([padding, real_positions], dim=1).squeeze(0)

        # -100 is the ignore index of the LM loss: padding must not contribute to it
        labels = tokens.clone()
        labels[:padding_size] = -100

        return {
            "input_ids": tokens,
            "attention_mask": attention_mask,
            "labels": labels,
        }

In [7]:
# Load the base tokenizer and model; the whole model is placed on device 0.
tokenizer = AutoTokenizer.from_pretrained(base_model, cache_dir=".")
model = AutoModelForCausalLM.from_pretrained(base_model,
                                             device_map={"":0}, 
                                             attn_implementation="eager",
                                             cache_dir=".")

# Register the three structural markers. `end_sequence_token` is not added here
# (presumably it already exists in the base vocabulary -- verify).
tokenizer.add_special_tokens({'additional_special_tokens': [start_audio_token, end_audio_token, end_frame_token]})
# Vocabulary size BEFORE the audio tokens are appended: this is the offset that
# is later used to map token ids back to codec codes.
n_tokens = len(tokenizer)

# Marker ids, taken as the last id produced when tokenizing each marker string.
start_audio_token_id = tokenizer(start_audio_token)["input_ids"][-1]
end_audio_token_id = tokenizer(end_audio_token)["input_ids"][-1]
end_frame_token_id = tokenizer(end_frame_token)["input_ids"][-1]

# Audio codec; kept on CPU here and moved next to the model in the training cell.
quant_path = dac.utils.download(model_type="16khz")
quantizer = dac.DAC.load(quant_path, n_codebooks=n_codebooks).to("cpu")

# One new token per codebook entry. The dataset offsets the codes of every
# codebook by the same amount, so all codebooks share this single id range.
tokenizer.add_tokens([f"<audio_token_{i}>" for i in range(quantizer.codebook_size)])

assert len(tokenizer) == n_tokens + quantizer.codebook_size

# Grow the embedding matrix to cover the newly added tokens.
model.resize_token_embeddings(len(tokenizer))

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]



Embedding(257027, 2304, padding_idx=0)

In [8]:
def freeze(
    model,
    freeze_emb=False,
    freeze_ln=False,
    freeze_attn=False,
    freeze_ff=True,
    freeze_ff_layers=[5,6,7,8,9,12,23,14,18,19,20,0,25],  # None means all or no layers, depending on freeze_ff
    freeze_other=False,
):
    """Set ``requires_grad`` on every parameter according to its name.

    Parameters are classified by substring of their lower-cased name, checked in
    this order: normalisation (``ln`` / ``norm``), embeddings (``embed``),
    feed-forward (``mlp``), attention (``attn``), everything else.

    Args:
        model: module exposing ``named_parameters()``.
        freeze_emb: freeze embedding parameters.
        freeze_ln: freeze normalisation parameters.
        freeze_attn: freeze attention parameters.
        freeze_ff: freeze feed-forward parameters (see ``freeze_ff_layers``).
        freeze_ff_layers: layer indices whose feed-forward block is frozen when
            ``freeze_ff`` is true; all other feed-forward blocks stay trainable.
            ``None`` applies ``freeze_ff`` to every layer. Any list, set, tuple,
            frozenset or range of ints is accepted.
        freeze_other: freeze parameters that match none of the groups above.

    Returns:
        The same ``model`` object, modified in place.

    Raises:
        ValueError: if ``freeze_ff_layers`` is neither ``None`` nor a supported
            collection of indices.
    """
    if freeze_ff_layers is None:
        ff_layers = None
    elif isinstance(freeze_ff_layers, (list, set, tuple, frozenset, range)):
        # private copy: constant-time lookups, and the shared default is never touched
        ff_layers = frozenset(freeze_ff_layers)
    else:
        raise ValueError("freeze_ff_layers must be a list or set of layer indices")

    for name, p in model.named_parameters():
        name = name.lower()

        if 'ln' in name or 'norm' in name:
            p.requires_grad = not freeze_ln
        elif 'embed' in name:
            # 'embed' covers both `...embeddings...` and `embed_tokens`; matching only
            # 'embeddings' let the latter fall through to the `freeze_other` branch
            p.requires_grad = not freeze_emb
        elif 'mlp' in name:
            if ff_layers is None:
                # Apply general freeze_ff setting
                p.requires_grad = not freeze_ff
            else:
                # The layer index is the first purely numeric component of the name
                layer_index = next(
                    (int(part) for part in name.split('.') if part.isdigit()), None
                )
                # Apply specific layer freeze setting
                p.requires_grad = not (freeze_ff and layer_index in ff_layers)
        elif 'attn' in name:
            p.requires_grad = not freeze_attn
        else:
            p.requires_grad = not freeze_other
    return model


def quantize(example, rank, quantizer, n_original_tokens):
    """Encode one dataset row into audio token ids.

    The signature takes the process rank as second argument so that several
    worker processes can spread over the available GPUs.

    Args:
        example: row with ``example["audio"]["array"]`` and
            ``example["audio"]["sampling_rate"]``.
        rank: worker rank, or ``None`` when running in a single process.
        quantizer: codec exposing ``to``, ``device``, ``preprocess`` and ``encode``.
        n_original_tokens: offset added to every code to obtain a token id.

    Returns:
        ``{"audio_tokens": tensor}`` with the shifted codes on the CPU.
    """
    # repeat import to work with multiprocessing
    import torch

    # pick this worker's GPU; without any CUDA device leave the codec where it is
    # (the modulo below would otherwise divide by zero)
    n_gpus = torch.cuda.device_count()
    if n_gpus > 0:
        quantizer.to(f"cuda:{(rank or 0) % n_gpus}")
    audio_data, sample_rate = example["audio"]["array"], example["audio"]["sampling_rate"]

    # audio -> discrete codes
    audio = torch.tensor(audio_data).view(1, 1, len(audio_data)).float()
    audio = audio.to(quantizer.device)
    # inference only: no autograd graph through the encoder
    with torch.no_grad():
        x = quantizer.preprocess(audio, sample_rate)
        _, codes, _, _, _ = quantizer.encode(x)

    # Move the result back to CPU and drop the device tensors to free GPU memory
    codes = codes.to("cpu")
    del audio
    del x
    torch.cuda.empty_cache()

    # increment tokens' ids
    return {"audio_tokens": codes + n_original_tokens}


def get_audio_padding_tokens(quantizer):
    """Encode a single zero-valued sample and return its codes as padding.

    The codes of this silent input are used to complete the last frame when the
    number of generated audio tokens is not divisible by ``n_codebooks``; this
    seemed to work better than random padding.

    Returns:
        ``{"audio_tokens": codes}`` where ``codes`` is the first batch entry of
        the encoder output, transposed and with a leading size-1 axis squeezed
        (presumably one code per codebook for the single frame -- verify).
        The tensor stays on ``quantizer.device``.
    """
    # a one-sample waveform of silence, placed next to the codec
    silence = torch.zeros((1, 1, 1)).to(quantizer.device)

    prepared = quantizer.preprocess(silence, quantizer.sample_rate)
    _z, codes, _latents, _commit_loss, _codebook_loss = quantizer.encode(prepared)

    # drop the intermediate tensors and release cached GPU memory
    del silence
    del prepared
    torch.cuda.empty_cache()

    first_item = codes[0]
    return {"audio_tokens": first_item.t().squeeze(0)}
    


def decode_audio(tokens, quantizer, pad_tokens, n_original_tokens):
    """Decode the audio span of a token sequence back into a waveform.

    Args:
        tokens: tensor of token ids; it is flattened by the end-of-frame filter.
        quantizer: codec exposing ``device``, ``sample_rate``,
            ``quantizer.from_codes`` and ``decode``.
        pad_tokens: 1-D tensor of ``n_codebooks`` codes used to complete an
            unfinished last frame (may live on a different device than ``tokens``).
        n_original_tokens: offset of the audio tokens in the vocabulary; ids are
            reduced modulo this value to recover codes.

    Returns:
        ``AudioSignal`` holding the decoded audio at ``quantizer.sample_rate``.
    """
    # drop the end-of-frame markers; boolean indexing also flattens to 1-D
    tokens = tokens[tokens != end_frame_token_id]

    # find start and end indices of audio tokens (first occurrence of each marker)
    start = torch.nonzero(tokens == start_audio_token_id)
    end = torch.nonzero(tokens == end_audio_token_id)

    # without markers, treat the whole sequence as audio
    start = start[0, -1] + 1 if len(start) else 0
    end = end[0, -1] if len(end) else tokens.shape[-1]

    # map token ids back to codes: ids in [n, 2n) become id - n
    audio_tokens = tokens[start:end] % n_original_tokens
    reminder = audio_tokens.shape[-1] % n_codebooks

    if reminder:
        # pad if last frame is incomplete; the padding is moved to the tokens'
        # device first because generated ids and pad codes may not share one
        filler = pad_tokens[reminder:].to(audio_tokens.device)
        audio_tokens = torch.cat([audio_tokens, filler], dim=0)

    # (frames, codebooks) -> (1, codebooks, frames), next to the codec
    codes = audio_tokens.view(1, -1, n_codebooks).permute(0, 2, 1).to(quantizer.device)
    # inference only: no autograd graph through the decoder
    with torch.no_grad():
        z = quantizer.quantizer.from_codes(codes)[0]
        audio = quantizer.decode(z)

    del tokens
    del audio_tokens
    torch.cuda.empty_cache()

    return AudioSignal(audio.detach().cpu().numpy(), quantizer.sample_rate)

    

In [9]:
def prepare_librispeech():
    """Load the LibriSpeech "clean" configuration.

    The ``chapter_id`` column is dropped and ``speaker_id`` is cast to string.
    """
    return (
        load_dataset("openslr/librispeech_asr", "clean", cache_dir=".")
        .remove_columns(["chapter_id"])
        .cast_column('speaker_id', Value('string'))
    )


def prepare_tedlium():
    """Load TED-LIUM "release1" and drop its ``gender`` column."""
    tedlium = load_dataset("LIUM/tedlium", "release1", cache_dir=".")
    return tedlium.remove_columns(["gender"])


def prepare_parler_tts():
    """Load the parler-tts English MLS corpus with a ``text`` column.

    Timing, speaker, book and duration columns are removed and ``transcript``
    is renamed to ``text``.
    """
    # Disabled alternative source (LibriTTS-R filtered), kept for reference:
    # raw_libr =  load_dataset("parler-tts/libritts_r_filtered", "clean", cache_dir="/mnt/storage")
    # processed_libr = raw_libr.remove_columns(["chapter_id", "text_original", "speaker_id"])
    # processed_libr = processed_libr.rename_column('text_normalized', 'text')

    unused_columns = ["begin_time", "end_time", "speaker_id", "book_id", "audio_duration"]
    return (
        load_dataset("parler-tts/mls_eng", cache_dir="/mnt/storage")
        .remove_columns(unused_columns)
        .rename_column('transcript', 'text')
    )
    

if not load_processed:
    print("Loadiing data")

    # dataset name -> (loader, name of the training split, name of the validation split)
    dataset_sources = {
        "tedlium": (prepare_tedlium, "train", "validation"),
        "parler-tts": (prepare_parler_tts, "train", "dev"),
        "librispeech": (prepare_librispeech, "train.100", "validation"),
    }

    # fail here, with a clear message, instead of with a NameError on
    # `train_data` in a later cell
    if data not in dataset_sources:
        raise ValueError(
            f"Unknown dataset {data!r}; expected one of {sorted(dataset_sources)}"
        )

    prepare_dataset, train_split, val_split = dataset_sources[data]
    dataset = prepare_dataset()

    train_data = dataset[train_split]
    val_data = dataset[val_split]
else:
    # reuse splits that were previously saved to disk
    train_data = load_from_disk(os.path.join(path_to_processed, "train"))
    val_data = load_from_disk(os.path.join(path_to_processed, "val"))

Loadiing data


Loading dataset shards:   0%|          | 0/45 [00:00<?, ?it/s]

In [10]:
# Each split is wrapped twice: once as text -> audio (TTS, asr=False) and once
# as audio -> text (ASR, asr=True).
train_dataset_tts, train_dataset_asr = (
    Vikhr4oDataset(train_data, tokenizer, quantizer, asr=is_asr) for is_asr in (False, True)
)
val_dataset_tts, val_dataset_asr = (
    Vikhr4oDataset(val_data, tokenizer, quantizer, asr=is_asr) for is_asr in (False, True)
)

# Training and validation see both directions.
train_dataset = ConcatDataset([train_dataset_tts, train_dataset_asr])
val_dataset = ConcatDataset([val_dataset_tts, val_dataset_asr])

# Codes of a silent sample, used to complete unfinished frames when decoding.
padding_tokens = get_audio_padding_tokens(quantizer)["audio_tokens"]

In [11]:
# Sanity check: the concatenated validation set holds every ASR and every TTS item.
assert len(val_dataset) == len(val_dataset_asr) + len(val_dataset_tts)

In [12]:
# test audio decoding: round-trip one validation item, and decode an all-zero
# sequence as a reference for what "no signal" sounds like

# both files are written under tests/; create it so the cell runs on a fresh checkout
os.makedirs("tests", exist_ok=True)

input_ids_test = val_dataset[1]["input_ids"].unsqueeze(0)
decoded = decode_audio(input_ids_test, quantizer, padding_tokens, n_tokens)
noise = decode_audio(torch.zeros(input_ids_test.size(), dtype=torch.int64), quantizer, padding_tokens, n_tokens)

decoded.write("tests/test.wav")
noise.write("tests/noise.wav")

<audiotools.core.audio_signal.AudioSignal at 0x7f5480883a30>

## Training without SFT Trainer (custom Accelerate loop)

In [None]:
# Per-device batch sizes for the two dataloaders.
train_batch_size = 1
eval_batch_size = 2
learning_rate = 5e-4
# Micro-batches accumulated per optimizer step (passed to the Accelerator).
gradient_accumulation_steps = 8
lr_scheduler_type = "cosine"
num_train_epochs = 5
num_warmup_steps = 10
# A checkpoint is written every `checkpointing_steps` optimizer steps.
checkpointing_steps = 1000
# NOTE(review): `logging_steps` is not referenced in the visible code; the
# training loop logs on every optimizer step.
logging_steps = 20
weight_decay = 0.1
# Gradient-norm clipping threshold applied before each optimizer step.
max_grad_norm = 0.5



def test_audio_generation(model, batch, n, quantizer, pad_tokens, n_original_tokens):
    """Generate audio for up to ``n`` random samples of ``batch`` and save it.

    Each selected sequence is cut right after its first ``<soa>`` token, the
    model continues it up to ``max_seq_length`` tokens, and the result is
    decoded and written to ``<save_dir>/audio/audio_<sample index + 1>.wav``.

    Args:
        model: causal LM exposing ``generate``.
        batch: dict with 2-D ``"input_ids"`` and ``"attention_mask"`` tensors.
        n: number of samples to try (capped at the batch size).
        quantizer, pad_tokens, n_original_tokens: forwarded to ``decode_audio``.

    Returns:
        List of paths of the audio files that were written.
    """
    batch_size = batch["input_ids"].shape[0]
    # distinct sample indices; the batch dict itself has no meaningful length
    inds = random.sample(range(batch_size), k=min(n, batch_size))

    audio_dir = os.path.join(save_dir, "audio")
    os.makedirs(audio_dir, exist_ok=True)
    audios = []

    for sample_ind in inds:
        input_ids = batch["input_ids"][sample_ind]
        attn = batch["attention_mask"][sample_ind]

        soa_positions = torch.nonzero(input_ids == start_audio_token_id)
        if len(soa_positions) == 0:
            # nothing to prompt with: the sequence has no audio start marker
            print("No audio generated.")
            continue

        with torch.no_grad():
            # keep everything up to and including the first <soa> as the prompt
            ind = soa_positions[0, -1]
            prompt_ids = input_ids[:ind+1].unsqueeze(0)
            prompt_attn = attn[:ind+1].unsqueeze(0).to(torch.float16)
            output = model.generate(input_ids=prompt_ids, attention_mask=prompt_attn, max_length=max_seq_length)

        try:
            audio = decode_audio(output, quantizer, pad_tokens, n_original_tokens)
        except (RuntimeError, ValueError, IndexError) as err:
            # best effort: a malformed generation must not abort evaluation,
            # but programming errors are no longer silently swallowed
            print("No audio generated.")
            print(f"  decode_audio failed: {err}")
            continue

        audio_file = os.path.join(audio_dir, f"audio_{sample_ind + 1}.wav")
        audio.write(audio_file)
        audios.append(audio_file)

    return audios


def get_last_checkpoint():
    """Return the ordinal of the NEXT checkpoint to write.

    That is one more than the number of entries in ``save_dir`` whose name
    starts with ``"checkpoint"``.
    """
    existing = [entry for entry in os.listdir(save_dir) if entry.startswith("checkpoint")]
    return len(existing) + 1


def save_checkpoint(model, accelerator, tokenizer, optimizer, scheduler):
    """Write model, tokenizer, optimizer and scheduler state to a new checkpoint.

    The directory is ``<save_dir>/checkpoint-<k * checkpointing_steps>`` where
    ``k`` comes from ``get_last_checkpoint()``, i.e. it is derived from how many
    checkpoints already exist rather than from the current step counter.
    Tokenizer, optimizer and scheduler are written by the main process only.
    """
    # let every process reach this point before reading the weights
    accelerator.wait_for_everyone()
    state = model.state_dict()

    path = os.path.join(save_dir, f"checkpoint-{get_last_checkpoint() * checkpointing_steps}")

    accelerator.unwrap_model(model).save_pretrained(
        path,
        state_dict=state,
        is_main_process=accelerator.is_main_process,
        save_function=accelerator.save,
        save_embedding_layers=True,
    )

    if not accelerator.is_main_process:
        return

    tokenizer.save_pretrained(path)
    for component, filename in ((optimizer, "optimizer.pt"), (scheduler, "scheduler.pt")):
        torch.save(component.state_dict(), os.path.join(path, filename))


def train(model, dataloader, accelerator, optimizer, lr_scheduler, completed_steps, progress_bar, max_train_steps):
    """Run one training epoch with gradient accumulation.

    Args:
        model: prepared model; gradient checkpointing is enabled and ``freeze``
            is applied with ``freeze_ff_layers=None`` before the loop.
        dataloader: prepared training dataloader.
        accelerator: ``Accelerator`` driving accumulation, backward and logging.
        optimizer, lr_scheduler: prepared optimizer and scheduler.
        completed_steps: optimizer steps done before this epoch.
        progress_bar: object with ``update``; advanced once per optimizer step.
        max_train_steps: the epoch stops early once this many steps are done.

    Returns:
        ``(mean_loss, completed_steps)`` where ``mean_loss`` averages the loss
        over the micro-batches actually processed in this epoch.
    """
    model.gradient_checkpointing_enable()
    model.train()
    model = freeze(model, freeze_ff_layers=None)

    total_loss = 0
    acc_loss = 0
    seen_batches = 0   # micro-batches processed in this epoch
    acc_batches = 0    # micro-batches since the last optimizer step

    for batch in dataloader:
        with accelerator.accumulate(model):
            outputs = model(**batch)
            loss = outputs.loss

            last_loss = loss.detach().float()
            total_loss += last_loss
            acc_loss += last_loss
            seen_batches += 1
            acc_batches += 1

            accelerator.backward(loss)

        # true only on the micro-batch that closes an accumulation group
        if accelerator.sync_gradients:
            accelerator.clip_grad_norm_(model.parameters(), max_grad_norm)

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            progress_bar.update(1)
            completed_steps += 1

            # average over the micro-batches really accumulated: the last group
            # of an epoch can be shorter than gradient_accumulation_steps
            accelerator.log({"loss": (acc_loss / acc_batches).item()})
            acc_loss = 0
            acc_batches = 0

            if completed_steps % checkpointing_steps == 0:
                save_checkpoint(model, accelerator, tokenizer, optimizer, lr_scheduler)

            torch.cuda.empty_cache()

        if completed_steps >= max_train_steps:
            break

    # divide by what was processed, not by len(dataloader): the loop may stop early
    return total_loss / max(seen_batches, 1), completed_steps


def eval(model, dataloader, accelerator, epoch, completed_steps, train_loss, quantizer, pad_tokens, n_original_tokens):
    """Evaluate on ``dataloader`` and log loss, perplexity and the train loss.

    Args:
        model: prepared model; switched to eval mode.
        dataloader: prepared evaluation dataloader.
        accelerator: ``Accelerator`` used to gather losses and to log.
        epoch: epoch index, reported as-is.
        completed_steps: optimizer steps done so far; used as logging step.
        train_loss: mean training loss of the epoch (tensor or number), logged
            unchanged.
        quantizer, pad_tokens, n_original_tokens: only needed by the optional
            audio-generation check, which is currently disabled.
    """
    model.eval()
    losses = []

    eval_progress_bar = tqdm(dataloader, desc=f"Evaluating Epoch {epoch}", leave=False)

    for batch in eval_progress_bar:
        with torch.no_grad():
            outputs = model(**batch)

        loss = outputs.loss
        # one copy of the batch loss per sample, so the gathered mean is sample-weighted
        losses.append(accelerator.gather_for_metrics(loss.repeat(eval_batch_size)))

    losses = torch.cat(losses)
    # plain float: prints cleanly and is what the tracker should receive
    eval_loss = torch.mean(losses).item()
    try:
        perplexity = math.exp(eval_loss)
    except OverflowError:
        perplexity = float("inf")

    print(f"epoch {epoch}: perplexity: {perplexity} eval_loss: {eval_loss}")
    # audios = test_audio_generation(model, batch, 2, quantizer, pad_tokens, n_original_tokens)

    base_log = {
        "perplexity": perplexity,
        "eval_loss": eval_loss,
        # `train_loss` is already a per-batch mean; dividing it by the number of
        # batches again (as before) under-reported it by that factor
        "train_loss": float(train_loss),
        "epoch": epoch,
        "step": completed_steps,
    }
    # base_log.update({f"audio_{i+1}": audios[i] for i in range(len(audios))})

    accelerator.log(base_log, step=completed_steps)


if __name__ == "__main__":
    # bf16 mixed precision, gradient accumulation and wandb logging are all
    # handled by the Accelerator
    accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps, 
                              mixed_precision='bf16', 
                              log_with="wandb")
    # the dataset encodes audio on the fly, so the codec lives next to the model
    quantizer.to(model.device)
    os.makedirs(save_dir, exist_ok=True)

    train_dataloader = DataLoader(
        train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=train_batch_size
    )
    eval_dataloader = DataLoader(
        val_dataset, collate_fn=default_data_collator, batch_size=eval_batch_size
    )

    # Parameters excluded from weight decay. "norm.weight" is needed because the
    # loaded model names its normalisation weights `...layernorm.weight` and
    # `norm.weight`, which the "layer_norm.weight" pattern alone never matches.
    no_decay = ["bias", "layer_norm.weight", "norm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad],
            "weight_decay": weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad],
            "weight_decay": 0.0,
        },
    ]
    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=learning_rate, fused=True)
    # optimizer = bnb.optim.Adam8bit(optimizer_grouped_parameters, min_8bit_size=16384)

    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
    max_train_steps = num_train_epochs * num_update_steps_per_epoch

    # step counts are scaled by the number of processes, matching how the
    # prepared scheduler is advanced
    lr_scheduler = get_scheduler(
        name=lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=num_warmup_steps * accelerator.num_processes,
        num_training_steps=max_train_steps * accelerator.num_processes,
    )

    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
    )

    # the dataloader length can change once it is prepared, so recompute
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
    max_train_steps = num_train_epochs * num_update_steps_per_epoch
    
    num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)

    accelerator.init_trackers("vikhr4o-gemma-2", {"lr_scheduler_type": lr_scheduler_type})

    total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps

    print("***** Running training *****")
    print(f"  Num examples = {len(train_dataset)}")
    print(f"  Num Epochs = {num_train_epochs}")
    print(f"  Instantaneous batch size per device = {train_batch_size}")
    print(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    print(f"  Gradient Accumulation steps = {gradient_accumulation_steps}")
    print(f"  Total optimization steps = {max_train_steps}")
    
    progress_bar = tqdm(range(max_train_steps), disable=not accelerator.is_local_main_process)
    completed_steps = 0
    starting_epoch = 0

    for epoch in range(starting_epoch, num_train_epochs):
        train_loss, completed_steps = train(model, train_dataloader, accelerator, optimizer, lr_scheduler, completed_steps, progress_bar, max_train_steps)
        print(f"EPOCH {epoch + 1} train loss:", train_loss)
        # pass the same audio-token offset (`n_tokens`) that the decoding test
        # cell uses; `n_tokens + 1` would shift every decoded code by one
        eval(model, eval_dataloader, accelerator, epoch, completed_steps, train_loss, quantizer, padding_tokens, n_tokens)

    save_checkpoint(model, accelerator, tokenizer, optimizer, lr_scheduler)
    # flush and close the experiment trackers started by init_trackers
    accelerator.end_training()

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
[34m[1mwandb[0m: Currently logged in as: [33mksycheva[0m. Use [1m`wandb login --relogin`[0m to force relogin


***** Running training *****
  Num examples = 57078
  Num Epochs = 5
  Instantaneous batch size per device = 1
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 8
  Total optimization steps = 35675


  0%|          | 0/35675 [00:00<?, ?it/s]

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
