In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import os
import re
import random
import numpy as np

from tqdm import tqdm
from IPython.display import Audio
from scipy.io import wavfile

from bark.bark import SAMPLE_RATE, generate_audio, preload_models

In [2]:
preload_models(use_smaller_models=True)

In [2]:
def save_history(basepath: str, audio_array: np.ndarray, prompt: str, generation_data: list):
  np.savez(f"{basepath}.npz", semantic_prompt=generation_data[0], coarse_prompt=generation_data[1], fine_prompt=generation_data[2])
  with open(f"{basepath}.txt", "w", encoding="utf-8") as fp: fp.write(prompt)
  wavfile.write(f"{basepath}.wav", SAMPLE_RATE, audio_array)

def load_history(filepath: str):
  return np.load(filepath)

def load_voice(filepath: str):
  sample_rate, audio_array = wavfile.read(filepath)
  return audio_array

In [7]:
class Encoder(nn.Module):
  def __init__(self, inputs_dim, hidden_dim):
    super().__init__()
    self.num_layers = 4
    self.inputs_dim = inputs_dim
    self.hidden_dim = hidden_dim
    self.lstm = nn.LSTM(self.inputs_dim, self.hidden_dim, self.num_layers // 2, dropout=0.1, bidirectional=True, batch_first=True)
    self.h0 = nn.Parameter(torch.zeros(1), requires_grad=False)
    self.c0 = nn.Parameter(torch.zeros(1), requires_grad=False)

  def init_hidden(self, inputs):
    batch_size = inputs.size(0)
    h0 = self.h0.reshape((1, 1, 1)).repeat(self.num_layers, batch_size, self.hidden_dim)
    c0 = self.c0.reshape((1, 1, 1)).repeat(self.num_layers, batch_size, self.hidden_dim)
    return h0, c0

  def forward(self, inputs, hidden):
    outputs, hidden = self.lstm(inputs, hidden)
    return outputs, hidden

class Attention(nn.Module):
  def __init__(self, inputs_dim, hidden_dim):
    super().__init__()
    self.inputs_dim = inputs_dim
    self.hidden_dim = hidden_dim
    self.inputs_fc = nn.Linear(inputs_dim, hidden_dim)
    self.context_fc = nn.Conv1d(inputs_dim, hidden_dim, 1, 1)
    self.V = nn.Parameter(torch.FloatTensor(hidden_dim), requires_grad=True)
    self._inf = nn.Parameter(torch.FloatTensor([float("-inf")]), requires_grad=False)
    self.tanh = nn.Tanh()
    self.softmax = nn.Softmax(dim=1)
    nn.init.uniform_(self.V, -1, 1)

  def init_inf(self, mask_size):
    self.inf = self._inf.unsqueeze(1).expand(*mask_size)

  def forward(self, inputs, context, mask):
    i = self.inputs_fc(inputs).unsqueeze(2).expand(-1, -1, context.size(1))
    context = context.permute(0, 2, 1)
    c = self.context_fc(context)
    V = self.V.unsqueeze(0).expand(context.size(0), -1).unsqueeze(1)
    attention = torch.bmm(V, self.tanh(i + c)).squeeze(1)
    if len(attention[mask]) > 0: attention[mask] = self.inf[mask]
    alpha = self.softmax(attention)
    hidden_state = torch.bmm(c, alpha.unsqueeze(2)).squeeze(2)
    return hidden_state, alpha

class Decoder(nn.Module):
  def __init__(self, inputs_dim, hidden_dim):
    super().__init__()
    self.inputs_dim = inputs_dim
    self.hidden_dim = hidden_dim
    self.input_to_hidden = nn.Linear(self.inputs_dim, 4 * self.hidden_dim)
    self.hidden_to_hidden = nn.Linear(self.hidden_dim, 4 * self.hidden_dim)
    self.hidden_out = nn.Linear(self.hidden_dim * 2, self.hidden_dim)
    self.attention = Attention(self.hidden_dim, self.hidden_dim)
    self.mask = nn.Parameter(torch.ones(1), requires_grad=False)
    self.runner = nn.Parameter(torch.zeros(1), requires_grad=False)

  def forward(self, inputs, decoder_input, hidden, context):
    batch_size = inputs.size(0)
    inputs_len = inputs.size(1)
    mask = self.mask.repeat(inputs_len).unsqueeze(0).repeat(batch_size, 1)
    self.attention.init_inf(mask.size())
    runner = self.runner.repeat(inputs_len)
    for i in range(inputs_len): runner.data[i] = i
    runner = runner.unsqueeze(0).expand(batch_size, -1).long()
    outputs = []
    pointers = []

    def step(x, hidden):
      h, c = hidden
      gates = self.input_to_hidden(x) + self.hidden_to_hidden(h)
      input, forget, cell, out = gates.chunk(4, 1)
      input = F.sigmoid(input)
      forget = F.sigmoid(forget)
      cell = F.tanh(cell)
      out = F.sigmoid(out)
      c_t = (forget * c) + (input * cell)
      h_t = out * F.tanh(c_t)
      hidden_t, output = self.attention(h_t, context, torch.eq(mask, 0))
      hidden_t = F.tanh(self.hidden_out(torch.cat((hidden_t, h_t), 1)))
      return hidden_t, c_t, output

    for _ in range(inputs_len):
      h_t, c_t, outs = step(decoder_input, hidden)
      hidden = (h_t, c_t)
      masked_outs = outs * mask
      max_probs, indices = masked_outs.max(1)
      one_hot_pointers = (runner == indices.unsqueeze(1).expand(-1, outs.size()[1])).float()
      mask *= 1 - one_hot_pointers
      embedding_mask = one_hot_pointers.unsqueeze(2).expand(-1, -1, self.inputs_dim).byte()
      decoder_input = inputs[embedding_mask.data.bool()].view(batch_size, self.inputs_dim)
      outputs.append(outs.unsqueeze(0))
      pointers.append(indices.unsqueeze(1))

    outputs = torch.cat(outputs).permute(1, 0, 2)
    pointers = torch.cat(pointers, 1)
    return (outputs, pointers), hidden

class PointerNet(nn.Module):
  def __init__(self, inputs_dim):
    super().__init__()
    self.inputs_dim = inputs_dim
    self.embedding_dim = 256
    self.hidden_dim = 128
    self.embedding = nn.Sequential(
      nn.Conv1d(1, 16, 4),
      nn.MaxPool1d(2),
      nn.Conv1d(16, 64, 4),
      nn.MaxPool1d(2),
      nn.Conv1d(64, 256, 4),
      nn.MaxPool1d(2),
      nn.Flatten(),
      nn.Linear(8997, 1),
    )
    self.encoder = Encoder(self.embedding_dim, self.hidden_dim // 2)
    self.decoder = Decoder(self.embedding_dim, self.hidden_dim)
    self.decoder_input0 = nn.Parameter(torch.FloatTensor(self.embedding_dim), requires_grad=False)
    nn.init.uniform_(self.decoder_input0, -1, 1)

  def forward(self, inputs):
    batch_size = inputs.size(0)
    inputs_len = inputs.size(1)

    decoder_input0 = self.decoder_input0.unsqueeze(0).expand(batch_size, -1)

    inputs = inputs.view(batch_size * inputs_len, 1, -1)
    embedded_inputs = torch.cat([self.embedding(inp) for inp in inputs]).view(batch_size, inputs_len, -1)

    encoder_hidden0 = self.encoder.init_hidden(embedded_inputs)
    encoder_outputs, encoder_hidden = self.encoder(embedded_inputs, encoder_hidden0)
    decoder_hidden0 = (torch.cat(tuple(encoder_hidden[0][-2:]), dim=-1), torch.cat(tuple(encoder_hidden[1][-2:]), dim=-1))
    (outputs, pointers), decoder_hidden = self.decoder(embedded_inputs, decoder_input0, decoder_hidden0, encoder_outputs)
    return outputs, pointers

In [4]:
NUM_EPOCHS = 1
NUM_STEPS = 1000

BATCH_SIZE = 1

MAX_DURATION = 3000
DURATION = 3000
SHIFT = 10

EMBEDDING_DIM = int(DURATION * SAMPLE_RATE / 1000)

In [5]:
def decompose_voice(audio_array, period, max_period):
  return [audio_array[start:start + period].tolist() for start in range(0, max_period, period)]

In [8]:
device = "cuda" if torch.cuda.is_available() else "cpu"

originals = {}
voice_map = {}
for filename in os.listdir("./data/bark"):
  if not filename.endswith(".wav"): continue
  folder_name = filename.split(".wav")[0]
  originals[folder_name] = load_voice(f"./data/bark/{filename}")
  if not folder_name in voice_map: voice_map[folder_name] = []
  for fn in os.listdir(f"./data/bark/{folder_name}"):
    if not fn.endswith(".wav"): continue
    voice_map[folder_name].append(load_voice(f"./data/bark/{folder_name}/{fn}"))

num_frames = MAX_DURATION // DURATION
num_voices = sum([len(voices) for voices in voice_map.values()])
num_samples = 5
ptr_net = PointerNet(len(voice_map) * num_samples * num_frames)
optimizer = optim.Adam(ptr_net.parameters(), lr=1e-5)
loss_fn = nn.CrossEntropyLoss()
losses = []

ptr_net.train()

if device == "cuda": ptr_net.cuda()
for step in range(1, NUM_STEPS + 1):
  optimizer.zero_grad()
  processed = []
  labels = torch.zeros((1, len(voice_map) * num_samples * num_frames)).to(device)
  for i, voices in enumerate(voice_map.values()):
    for voice in random.choices(voices, k=num_samples):
      start = np.random.randint(0, len(voice) - int(MAX_DURATION * SAMPLE_RATE / 1000))
      processed.extend(decompose_voice(voice[start:], EMBEDDING_DIM, int(MAX_DURATION * SAMPLE_RATE / 1000)))
    labels[:,i * num_frames * num_samples:] += 1
  labels -= 1
  outputs, pointers = ptr_net(torch.tensor(processed).to(device).view(1, len(voice_map) * num_samples * num_frames, -1))
  loss = loss_fn(torch.floor(pointers / num_samples / num_frames), labels)
  loss.requires_grad = True
  # display(torch.floor(pointers * len(voice_map) / num_voices / num_frames), labels)
  losses.append(loss.item())
  loss.backward()
  optimizer.step()
  if step % 100 == 0:
    print(f"{step}/{NUM_STEPS}:\t{np.mean(losses):.4f}")
    losses = []
print("Training Ended!")

100/1000:	11.5129
200/1000:	11.5129
300/1000:	11.5129
400/1000:	11.5129
500/1000:	11.5129
600/1000:	11.5129
700/1000:	11.5129
800/1000:	11.5129
900/1000:	11.5129
1000/1000:	11.5129
Training Ended!


In [125]:
ptr_net.eval()
with torch.no_grad():
  processed = []
  labels = torch.zeros((1, len(voice_map) * num_samples * num_frames)).to(device)
  for i, voices in enumerate(voice_map.values()):
    for voice in random.choices(voices, k=num_samples):
      start = np.random.randint(0, len(voice) - int(MAX_DURATION * SAMPLE_RATE / 1000))
      processed.extend(decompose_voice(voice[start:], EMBEDDING_DIM, int(MAX_DURATION * SAMPLE_RATE / 1000)))
    labels[:,i * num_frames * num_samples:] += 1
  labels -= 1
  outputs, pointers = ptr_net(torch.tensor(processed).to(device).view(1, len(voice_map) * num_samples * num_frames, -1))
  display(num_frames * num_samples)
  display(pointers, labels)

5

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], device='cuda:0')

tensor([[0., 0., 0., 0., 0., 1., 1., 1., 1., 1.]], device='cuda:0')

In [8]:
def generate(basename, raw_prompt, index):
  basepath = "./data/bark/"
  text_prompt = re.sub(r"\s\s+", " ", re.sub(r"[\t\n]", "", raw_prompt)).strip()
  print(f"{basename}: prompt-{index}")
  audio_array, generation_data = generate_audio(text_prompt, basename)
  if basename not in os.listdir(basepath): os.mkdir(f"{basepath}/{basename}")
  save_history(f"{basepath}/{basename}/prompt-{index}", audio_array, text_prompt, generation_data)

In [9]:
raw_prompts = [
  "A pessimist is one who makes difficulties of his opportunities and an optimist is one who makes opportunities of his difficulties.",
  "Don't judge each day by the harvest you reap but by the seeds that you plant.",
  "Challenges are what make life interesting and overcoming them is what makes life meaningful.",
  "Happiness lies not in the mere possession of money; it lies in the joy of achievement, in the thrill of creative effort.",
  "I disapprove of what you say, but I will defend to the death your right to say it.",
  "If I looked compared to others far, is because I stand on giant's shoulder.",
  "Never argue with stupid people, they will drag you down to their level and then beat you with experience.",
  "The greatest glory in living lies not in never falling, but in rising every time we fall.",
  "When you look into the abyss, the abyss also looks into you.",
  "Whoever fights monsters should see to it that in the process he does not become a monster."
]

basename = "woman-1"

for i, raw_prompt in enumerate(raw_prompts):
  generate(basename, raw_prompt, i)

woman-1: prompt-0
history_prompt in gen: woman-1
woman-1
aa


100%|██████████| 100/100 [00:04<00:00, 24.21it/s]
100%|██████████| 22/22 [00:21<00:00,  1.04it/s]


woman-1: prompt-1
history_prompt in gen: woman-1
woman-1
aa


100%|██████████| 100/100 [00:03<00:00, 30.84it/s] 
100%|██████████| 19/19 [00:18<00:00,  1.05it/s]


woman-1: prompt-2
history_prompt in gen: woman-1
woman-1
aa


100%|██████████| 100/100 [00:03<00:00, 32.42it/s]
100%|██████████| 18/18 [00:17<00:00,  1.03it/s]


woman-1: prompt-3
history_prompt in gen: woman-1
woman-1
aa


100%|██████████| 100/100 [00:04<00:00, 20.84it/s]
100%|██████████| 24/24 [00:23<00:00,  1.02it/s]


woman-1: prompt-4
history_prompt in gen: woman-1
woman-1
aa


100%|██████████| 100/100 [00:02<00:00, 44.22it/s] 
100%|██████████| 14/14 [00:14<00:00,  1.02s/it]


woman-1: prompt-5
history_prompt in gen: woman-1
woman-1
aa


100%|██████████| 100/100 [00:01<00:00, 59.81it/s] 
100%|██████████| 11/11 [00:11<00:00,  1.01s/it]


woman-1: prompt-6
history_prompt in gen: woman-1
woman-1
aa


100%|██████████| 100/100 [00:03<00:00, 28.94it/s] 
100%|██████████| 20/20 [00:19<00:00,  1.03it/s]


woman-1: prompt-7
history_prompt in gen: woman-1
woman-1
aa


100%|██████████| 100/100 [00:04<00:00, 22.16it/s]
100%|██████████| 23/23 [00:23<00:00,  1.00s/it]


woman-1: prompt-8
history_prompt in gen: woman-1
woman-1
aa


100%|██████████| 100/100 [00:02<00:00, 36.20it/s]
100%|██████████| 16/16 [00:16<00:00,  1.00s/it]


woman-1: prompt-9
history_prompt in gen: woman-1
woman-1
aa


100%|██████████| 100/100 [00:01<00:00, 59.00it/s]
100%|██████████| 12/12 [00:11<00:00,  1.06it/s]
