In [1]:
# clone repo, install requirements, and load hubert checkpoint
!git clone https://github.com/0nutation/SpeechGPT.git
%cd SpeechGPT/speechgpt

!wget -P utils/speech2unit https://dl.fbaipublicfiles.com/hubert/mhubert_base_vp_en_es_fr_it3.pt
!wget -P utils/speech2unit https://dl.fbaipublicfiles.com/hubert/mhubert_base_vp_en_es_fr_it3_L11_km1000.bin

Cloning into 'SpeechGPT'...
remote: Enumerating objects: 236, done.[K
remote: Counting objects: 100% (197/197), done.[K
remote: Compressing objects: 100% (166/166), done.[K
remote: Total 236 (delta 69), reused 106 (delta 24), pack-reused 39[K
Receiving objects: 100% (236/236), 3.50 MiB | 32.87 MiB/s, done.
Resolving deltas: 100% (75/75), done.
/kaggle/working/SpeechGPT/speechgpt
--2024-05-29 14:12:03--  https://dl.fbaipublicfiles.com/hubert/mhubert_base_vp_en_es_fr_it3.pt
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 3.163.189.96, 3.163.189.108, 3.163.189.14, ...
Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|3.163.189.96|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1136474383 (1.1G) [binary/octet-stream]
Saving to: 'utils/speech2unit/mhubert_base_vp_en_es_fr_it3.pt'


2024-05-29 14:12:07 (256 MB/s) - 'utils/speech2unit/mhubert_base_vp_en_es_fr_it3.pt' saved [1136474383/1136474383]

--2024-05-29 14:12:08-- 

In [2]:
!pip install -qqq fairseq==0.12.2
!pip install -qqq fire
!pip install -qqq einops

In [3]:
from pathlib import Path

import json
import torch


DATA_PATH = Path("/kaggle/input/fastspeech-audio")
device = "cuda" if torch.cuda.is_available() else "cpu"


def _load_split_metadata(split):
    """Read and return the JSON metadata for one split ("train" or "test").

    The file is ``DATA_PATH / "audio" / f"{split}.json"``. Entries are consumed
    downstream via their "clean path", "aug path" and "text" keys.
    """
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(DATA_PATH / "audio" / f"{split}.json", encoding="utf-8") as f:
        return json.load(f)


train_metadata = _load_split_metadata("train")
test_metadata = _load_split_metadata("test")

In [4]:
from torch.utils.data import Dataset

import torch
import torchaudio


class FastSpeechAudioDataset(Dataset):
    """Eagerly tokenized (transcript, clean audio, augmented audio) examples.

    Every metadata entry must provide "clean path" and "aug path" (audio file
    paths relative to the data root) and "text". All tokenization happens once,
    in ``__init__``; ``__getitem__`` only indexes the cached list.

    Each item is a dict with the original metadata fields plus 1-D tensors of
    token ids:
      - "text_tokens":        the transcript
      - "clean_audio_tokens": the clean audio's tokenizer output
      - "aug_audio_tokens":   the augmented audio's tokenizer output
      - "clean_tokens" / "aug_tokens": text ids followed by the matching audio ids

    Args:
        metadata: list of dicts as described above. It is NOT modified.
        audio_tokenizer: callable mapping an audio file path (str) to something
            ``text_tokenizer`` can encode (presumably a string of discrete
            speech units — TODO confirm).
        text_tokenizer: HF-style tokenizer called as
            ``tok(text, return_tensors="pt")["input_ids"]``.
        data_path: root the relative audio paths are resolved against;
            defaults to the module-level ``DATA_PATH``.
    """

    def __init__(self, metadata, audio_tokenizer, text_tokenizer, data_path=None):
        # Shallow-copy every entry so tokenization results are not written back
        # into the caller's metadata (which would leak state across re-runs).
        self.data = [dict(elem) for elem in metadata]
        self.audio_tokenizer = audio_tokenizer
        self.text_tokenizer = text_tokenizer
        self.data_path = DATA_PATH if data_path is None else data_path

        self._prepare_data()

    def _encode(self, text):
        """Tokenize ``text`` and return its ids as a 1-D tensor."""
        # Take row 0 rather than `.squeeze()`: squeeze collapses a single-token
        # encoding into a 0-d tensor, which breaks torch.cat / len downstream.
        return self.text_tokenizer(text, return_tensors="pt")["input_ids"][0]

    def _prepare_data(self):
        """Tokenize every (copied) entry in place."""
        for elem in self.data:
            clean_audio_tokens = self.audio_tokenizer(
                str(self.data_path / elem["clean path"]),
            )
            aug_audio_tokens = self.audio_tokenizer(
                str(self.data_path / elem["aug path"]),
            )

            elem["clean_audio_tokens"] = self._encode(clean_audio_tokens)
            elem["aug_audio_tokens"] = self._encode(aug_audio_tokens)
            elem["text_tokens"] = self._encode(elem["text"])

            # The scored sequence is the transcript followed by the audio tokens.
            elem["clean_tokens"] = torch.cat([
                elem["text_tokens"], elem["clean_audio_tokens"],
            ])
            elem["aug_tokens"] = torch.cat([
                elem["text_tokens"], elem["aug_audio_tokens"],
            ])

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)

In [5]:
from torch.nn.utils.rnn import pad_sequence


def collate_fn(batch):
    """Right-pad a list of dataset items into batched clean/augmented tensors.

    Each item must carry "text_tokens", "clean_audio_tokens", "aug_audio_tokens",
    "clean_tokens" and "aug_tokens" (1-D tensors). Token sequences are padded
    with 0. The masks are float tensors that are 0 over the text prefix, 1 over
    the audio part, and 0 over padding.

    Returns:
        (clean_tokens, clean_mask, aug_tokens, aug_mask), each of shape
        (batch, longest sequence in that variant).
    """
    def _pad_variant(audio_key, tokens_key):
        # Build padded token ids and the audio-only mask for one variant.
        token_seqs = [item[tokens_key] for item in batch]
        mask_seqs = [
            torch.cat([
                torch.zeros(len(item["text_tokens"])),
                torch.ones(len(item[audio_key])),
            ])
            for item in batch
        ]
        padded_tokens = pad_sequence(token_seqs, batch_first=True, padding_value=0)
        padded_mask = pad_sequence(mask_seqs, batch_first=True, padding_value=0)
        return padded_tokens, padded_mask

    clean_tokens, clean_mask = _pad_variant("clean_audio_tokens", "clean_tokens")
    aug_tokens, aug_mask = _pad_variant("aug_audio_tokens", "aug_tokens")
    return clean_tokens, clean_mask, aug_tokens, aug_mask

In [6]:
from utils.speech2unit.speech2unit import Speech2Unit

from transformers import LlamaTokenizer
    

# Hub id of the SpeechGPT checkpoint whose tokenizer encodes both the
# transcripts and the Speech2Unit output.
SPEECHGPT_HUB_ID = "fnlp/SpeechGPT-7B-cm"
# Directory the setup cell downloads the mHuBERT files into.
checkpoint_dir = "utils/speech2unit/"

audio_tokenizer = Speech2Unit(ckpt_dir=checkpoint_dir)
text_tokenizer = LlamaTokenizer.from_pretrained(SPEECHGPT_HUB_ID)
text_tokenizer.model_max_length = 2048

2024-05-29 14:13:57.026940: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-05-29 14:13:57.027047: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-05-29 14:13:57.154962: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


tokenizer_config.json:   0%|          | 0.00/747 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/18.0k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/435 [00:00<?, ?B/s]

You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [7]:
# Only the first N_EXAMPLES entries of each split are used; tokenization runs
# eagerly when each FastSpeechAudioDataset is constructed.
N_EXAMPLES = 256

train_dataset, test_dataset = (
    FastSpeechAudioDataset(split[:N_EXAMPLES], audio_tokenizer, text_tokenizer)
    for split in (train_metadata, test_metadata)
)

In [8]:
from torch.utils.data import DataLoader


# The two loaders differ only in shuffling: training batches are reshuffled
# each epoch, while test batches keep dataset order.
_loader_options = dict(
    collate_fn=collate_fn,
    batch_size=16,
    pin_memory=True,
    num_workers=4,
)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_options)

In [9]:
from transformers import LlamaForCausalLM


# Number of decoder layers kept from SpeechGPT; the scoring model reads the
# hidden states produced after these layers.
NUM_LAYERS = 12

# Load the full SpeechGPT-7B-cm causal LM (a ~27 GB checkpoint per the download
# log), then keep only its inner decoder stack (`.model`, a LlamaModel without
# the LM head).
feature_model = LlamaForCausalLM.from_pretrained("fnlp/SpeechGPT-7B-cm", use_cache=False)
feature_model = feature_model.model
# Reuse the *next* layer's input RMSNorm as the final norm, so the truncated
# stack's output is normalized the way layer NUM_LAYERS would have received it.
# Must run before the slice below, which discards layers[NUM_LAYERS].
feature_model.norm = feature_model.layers[NUM_LAYERS].input_layernorm
# NOTE(review): feature_model.config.num_hidden_layers still reports the original
# depth after this slice — confirm nothing downstream relies on it.
feature_model.layers = feature_model.layers[:NUM_LAYERS]

config.json:   0%|          | 0.00/507 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/27.0G [00:00<?, ?B/s]

  return self.fget.__get__(instance, owner)()


In [10]:
from torch import nn


class SpeechGPTMOSPC(nn.Module):
    """Scores a token sequence with a (by default frozen) backbone plus an MLP head.

    The head maps every hidden state to a scalar. The per-sequence score is the
    mask-weighted mean of those scalars (e.g. averaged over audio positions only).

    Args:
        feature_model: module called as ``feature_model(tokens)`` and returning a
            mapping with "last_hidden_state" of shape (batch, seq, hidden_size).
        freeze: if True, the backbone gets no gradients and stays in eval mode,
            even when ``train()`` is called on this module.
        hidden_size: width of the backbone's hidden states (4096 for SpeechGPT-7B).
    """

    def __init__(self, feature_model, freeze=True, hidden_size=4096):
        super().__init__()

        self.feature_model = feature_model
        self.freeze = freeze
        self.head = nn.Sequential(
            nn.Linear(hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 1),
        )

        if freeze:
            self.feature_model.eval()
            for p in self.feature_model.parameters():
                p.requires_grad_(False)

    def train(self, mode=True):
        """Set train/eval mode, keeping a frozen backbone in eval mode."""
        # nn.Module.train recurses into children, which would otherwise switch the
        # frozen backbone back to training behaviour.
        super().train(mode)
        if self.freeze:
            self.feature_model.eval()
        return self

    def forward(self, tokens, mask):
        """Return one score per sequence, shape (batch,).

        Args:
            tokens: (batch, seq) token ids.
            mask: (batch, seq) weights (float, int or bool); positions with 0 are
                ignored. A row whose mask sums to 0 scores 0 instead of NaN.
        """
        # No attention_mask is passed; this assumes right padding, where a causal
        # backbone never lets real tokens attend to the pads. TODO confirm.
        if self.freeze:
            with torch.no_grad():
                x = self.feature_model(tokens)
        else:
            x = self.feature_model(tokens)
        x = x["last_hidden_state"]
        x = self.head(x).squeeze(-1)
        mask = mask.to(x.dtype)
        # Clamp the denominator so an all-zero mask row yields 0, not 0/0 = NaN.
        return (x * mask).sum(dim=1) / mask.sum(dim=1).clamp_min(1e-6)

In [11]:
def calc_rank_loss(clean_score, aug_score):
    """Pairwise ranking loss pushing clean scores above augmented scores.

    Computes ``mean(-log(sigmoid(clean_score - aug_score)))``. This is
    algebraically equal to the previous ``exp(d) / (1 + exp(d))`` form, but is
    written as ``softplus(aug_score - clean_score)`` so it stays finite. The
    naive form gives NaN once ``exp`` overflows (d above ~88 in float32) and inf
    once ``exp`` underflows to 0 for very negative d.

    Args:
        clean_score, aug_score: broadcastable tensors of scores.

    Returns:
        Scalar tensor: the mean loss over all elements.
    """
    return torch.nn.functional.softplus(aug_score - clean_score).mean()

In [12]:
# Wrap the truncated SpeechGPT backbone (frozen by default) with the scoring head
# and move it to the selected device. The bare `model` line displays the architecture.
model = SpeechGPTMOSPC(feature_model).to(device)
model

SpeechGPTMOSPC(
  (feature_model): LlamaModel(
    (embed_tokens): Embedding(33006, 4096, padding_idx=0)
    (layers): ModuleList(
      (0-11): 12 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNo

In [13]:
from kaggle_secrets import UserSecretsClient

import os
import wandb


# Fetch the W&B API key once from Kaggle secrets and expose it through the
# WANDB_API_KEY env var, so no credential is hard-coded in the notebook.
user_secrets = UserSecretsClient()
wandb_token = user_secrets.get_secret("wandb-token")
os.environ["WANDB_API_KEY"] = wandb_token

In [14]:
EPOCHS = 50
# Regression targets: clean audio is pushed toward CLEAN_TARGET and augmented
# audio toward AUG_TARGET (presumably the ends of a 1-5 MOS scale).
CLEAN_TARGET = 5.0
AUG_TARGET = 1.0


def compute_losses(model, batch):
    """Score one `collate_fn` batch and return its losses and scores.

    Returns a dict with the scalar tensors "loss" (total), "rank_loss",
    "clean_loss" and "aug_loss", plus the per-example (batch,) tensors
    "clean_score" and "aug_score".
    """
    clean_tokens, clean_mask, aug_tokens, aug_mask = (t.to(device) for t in batch)

    clean_score = model(clean_tokens, clean_mask)
    aug_score = model(aug_tokens, aug_mask)

    rank_loss = calc_rank_loss(clean_score, aug_score)
    clean_loss = ((CLEAN_TARGET - clean_score) ** 2).mean()
    aug_loss = ((AUG_TARGET - aug_score) ** 2).mean()
    return {
        "loss": clean_loss + aug_loss + rank_loss,
        "rank_loss": rank_loss,
        "clean_score": clean_score,
        "clean_loss": clean_loss,
        "aug_score": aug_score,
        "aug_loss": aug_loss,
    }


wandb.init(
    project="critic_project",
    name="Vanilla baseline model training",
)

# Frozen backbone parameters never receive gradients; give the optimizer only
# the trainable ones.
optimizer = torch.optim.Adam([p for p in model.parameters() if p.requires_grad])

for epoch in range(EPOCHS):
    # Training mode enables the head's Dropout; a frozen backbone stays in eval.
    model.train()
    if model.freeze:
        model.feature_model.eval()
    for batch in train_loader:
        out = compute_losses(model, batch)
        optimizer.zero_grad()
        out["loss"].backward()
        optimizer.step()
        wandb.log(
            {
                "train_loss": out["loss"].item(),
                "train_rank_loss": out["rank_loss"].item(),
                "train_clean_score": out["clean_score"].mean().item(),
                "train_clean_loss": out["clean_loss"].item(),
                "train_aug_score": out["aug_score"].mean().item(),
                "train_aug_loss": out["aug_loss"].item(),
            }
        )

    # Validation: disable Dropout, then average every metric per example. The
    # last batch may be smaller, so batch means are weighted by batch size.
    model.eval()
    totals = dict.fromkeys(
        ["loss", "rank_loss", "clean_score", "clean_loss", "aug_score", "aug_loss"],
        0.0,
    )
    with torch.inference_mode():
        for batch in test_loader:
            out = compute_losses(model, batch)
            bs = out["clean_score"].shape[0]
            for key in ("loss", "rank_loss", "clean_loss", "aug_loss"):
                # `.item()` keeps the accumulators as Python floats, not GPU tensors.
                totals[key] += out[key].item() * bs
            # Scores are per example, so they are summed directly.
            totals["clean_score"] += out["clean_score"].sum().item()
            totals["aug_score"] += out["aug_score"].sum().item()

    wandb.log(
        {f"valid_{key}": total / len(test_dataset) for key, total in totals.items()}
    )

wandb.finish()

[34m[1mwandb[0m: Currently logged in as: [33msitff_subset[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: wandb version 0.17.0 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade
[34m[1mwandb[0m: Tracking run with wandb version 0.16.6
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/kaggle/working/SpeechGPT/speechgpt/wandb/run-20240529_141909-ndlbis1i[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33mVanilla baseline model training[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/sitff_subset/critic_project[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/sitff_subset/critic_project/runs/ndlbis1i[0m
