In [25]:
import torch
from core.model import GPT, GPTConfig
from core.tokenizer import Tokenizer, AudioTokenizer
from huggingface_hub import hf_hub_download
import numpy as np
import soundfile as sf
from IPython.display import Audio


In [2]:
# Build the GPT from the named base config, then the text and audio tokenizers.
# Checkpoint weights are loaded into `model` in a later cell.
base_config_name = 'EleutherAI/pythia-410m'
model_config = GPTConfig.from_pretrained(base_config_name)
model = GPT(model_config)
tokenizer, audio_tokenizer = Tokenizer(), AudioTokenizer()

number of parameters: 409.55M
Loading Audio Encoder


Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.


Loaded Audio Encoder, size: 1029


In [3]:
# Step 1: load the checkpoint weights into `model`.

model_name = 'alexedw/audio-clean-100-run-2'
model_checkpoint = '750'  # used as the Hub `revision` of the repo

checkpoint_path = hf_hub_download(model_name, "model_state.pt", revision=model_checkpoint)

# The checkpoint is downloaded, i.e. external data. weights_only=True restricts
# unpickling to tensors and plain containers instead of arbitrary Python objects.
state_dict = torch.load(checkpoint_path, map_location='cpu', weights_only=True)

# Strip the "_orig_mod." segment from parameter names so they match the plain
# module (presumably the checkpoint was saved from a torch.compile-wrapped model — confirm).
new_state_dict = {key.replace("_orig_mod.", ""): value for key, value in state_dict.items()}

# This notebook only runs inference: leave training mode so that any
# dropout-style layers are disabled during generation.
model.eval()

# Kept as the last expression so the cell still displays the key-matching result.
model.load_state_dict(new_state_dict)

<All keys matched successfully>

In [4]:
# Prepare the prompt: a text stream wrapped in start/end markers, plus two
# audio streams of the same length filled with the audio tokenizer's text
# placeholder id.
input_text = 'my name is alexander and I am so great'
tokenized_text = tokenizer.encode(input_text)

# Every piece is pinned to int64. Without an explicit dtype an empty encoding
# produces a float64 array (np.array([]) defaults to float), and the
# concatenation would then promote all token ids to floats.
full_text_tokens = np.concatenate([
    np.array([tokenizer.start_text_id], dtype=np.int64),
    np.asarray(tokenized_text, dtype=np.int64),
    np.array([tokenizer.end_text_id], dtype=np.int64),
])
full_audio_tokens = np.concatenate([
    np.array([audio_tokenizer.start_text_id], dtype=np.int64),
    np.full(len(tokenized_text), audio_tokenizer.text_id, dtype=np.int64),
    np.array([audio_tokenizer.end_text_id], dtype=np.int64),
])

# Add a leading batch dimension of size 1.
full_text_tokens = torch.tensor(full_text_tokens).unsqueeze(0)
# torch.tensor copies, so the two audio streams do not share storage.
# `full_audio_tokens` itself stays a 1-D numpy array: the post-generation
# slicing cell uses len(full_audio_tokens) as the prompt length.
full_audio_tokens_1 = torch.tensor(full_audio_tokens).unsqueeze(0)
full_audio_tokens_2 = torch.tensor(full_audio_tokens).unsqueeze(0)

In [58]:
# run generation
num_generation_steps = 200  # fourth positional argument of model.generate

# Sampling uses temperature/top_p, so fix the global torch seed to make the
# cell reproducible (assumes generate draws from torch's default RNG — confirm).
torch.manual_seed(0)

# Inference only: no autograd graph is needed for the generated tokens.
with torch.no_grad():
    output_text, output_audio_1, output_audio_2 = model.generate(
        full_text_tokens,
        full_audio_tokens_1,
        full_audio_tokens_2,
        num_generation_steps,
        temperature=1.0,
        top_p=0.95,
    )

100%|██████████| 200/200 [00:35<00:00,  5.59it/s]


In [59]:
# Select the generated audio positions: from one past the prompt length up to
# the length of the first text output row.
# NOTE(review): the extra +1 presumably skips a marker token — confirm.
start_audio = 1 + len(full_audio_tokens)
end_audio = len(output_text[0])
generated_span = slice(start_audio, end_audio)

# Take batch row 0 of each audio stream over that span.
audio_tokens_1, audio_tokens_2 = (
    output_audio_1[0, generated_span],
    output_audio_2[0, generated_span],
)

# Stack the two streams along a new leading dimension.
audio_tokens_stacked = torch.stack((audio_tokens_1, audio_tokens_2), dim=0)

In [60]:
# Decode the stacked audio tokens. This is inference only, so run it without
# autograd tracking.
with torch.no_grad():
    sound = audio_tokenizer.decode(audio_tokens_stacked)

In [61]:

# Write the decoded audio to a WAV file, then embed a player for that file.
wav_filename = "temp_audio2.wav"
sample_rate = 24000  # third argument of sf.write
waveform = sound.detach().numpy()
sf.write(wav_filename, waveform, sample_rate)
Audio(filename=wav_filename)

In [62]:
# random audio tokens test
# A local seeded generator makes this check reproducible without touching
# torch's global RNG state.
random_tokens_rng = torch.Generator().manual_seed(0)
audio_tokens_stacked_random = torch.randint(0, 1000, (2, 500), generator=random_tokens_rng)

# Decoding is inference only, so run it without autograd tracking.
with torch.no_grad():
    sound2 = audio_tokenizer.decode(audio_tokens_stacked_random)

wav_filename2 = "temp_audio22.wav"
sf.write(wav_filename2, sound2.detach().numpy(), 24000)
Audio(filename=wav_filename2)