In [None]:
cd Prosody2Vec/

In [None]:
import torch
# Fetch the Tacotron 2 entry point from NVIDIA's torch.hub repo (fp16 math) and
# place it on the GPU in one expression; the following cells only inspect its
# sub-modules.
tacotron2 = torch.hub.load(
    'NVIDIA/DeepLearningExamples:torchhub',
    'nvidia_tacotron2',
    model_math='fp16',
).to('cuda')

In [None]:
tacotron2.embedding

In [None]:
tacotron2.encoder

In [None]:
1408 - 1024 

In [None]:
256 + 128

In [None]:
tacotron2.decoder

In [None]:
# import clearml
# clearml.browser_login()

In [None]:
# from clearml import Task
# task = Task.init(project_name="my project", task_name="my task")

In [None]:
from torch import nn
import torch 
import glob
from IPython.display import clear_output, display, Audio
import copy

In [None]:
import torch

In [None]:
import model

In [None]:

class Attention(nn.Module):
    """Re-weight a sequence along time using scores conditioned on two vectors.

    Each conditioning vector is linearly projected to ``hidden_dim`` and added
    to every time step of the input sequence. A tanh followed by a linear layer
    turns each combined frame into one scalar, a softmax over the time axis
    converts the scalars into weights, and the input sequence is multiplied by
    those weights. The sequence is scaled frame by frame, not pooled, so the
    output keeps the input's shape.
    """

    def __init__(self, hidden_dim, vec_a_size, vec_b_size):
        super().__init__()
        # Projections of the two conditioning vectors into the sequence's feature space.
        self.attn_a = nn.Linear(vec_a_size, hidden_dim)
        self.attn_b = nn.Linear(vec_b_size, hidden_dim)
        # Maps every combined frame to a single unnormalized score.
        self.attn_score = nn.Linear(hidden_dim, 1)

    def forward(self, matrix, vec_a, vec_b):
        """Return ``matrix`` scaled by per-frame attention weights.

        Args:
            matrix: (batch_size, time, hidden_dim)
            vec_a: (batch_size, vec_a_size)
            vec_b: (batch_size, vec_b_size)

        Returns:
            Tensor of shape (batch_size, time, hidden_dim).
        """
        # The unpacking also enforces that `matrix` is rank 3.
        _, steps, _ = matrix.shape

        def tile_over_time(projected):
            # (batch_size, hidden_dim) -> (batch_size, time, hidden_dim)
            return projected.unsqueeze(1).expand(-1, steps, -1)

        combined = matrix + tile_over_time(self.attn_a(vec_a)) + tile_over_time(self.attn_b(vec_b))

        # One score per frame, normalized across the time axis.
        scores = self.attn_score(torch.tanh(combined)).squeeze(-1)  # (batch_size, time)
        weights = torch.softmax(scores, dim=-1).unsqueeze(-1)  # (batch_size, time, 1)

        # Broadcast the weights over the feature axis.
        return matrix * weights

# Example Usage: run random inputs through Attention and show the output shape.
batch_size = 32
time = 10
hidden_dim = 64
vec_a_size = 16
vec_b_size = 16

matrix = torch.randn(batch_size, time, hidden_dim)
vec_a = torch.randn(batch_size, vec_a_size)
vec_b = torch.randn(batch_size, vec_b_size)

attn = Attention(hidden_dim, vec_a_size, vec_b_size)
output_matrix = attn(matrix, vec_a, vec_b)
# Expected: (batch_size, time, hidden_dim), i.e. the same shape as `matrix`.
print(output_matrix.shape)


In [None]:
# acoustic = torch.hub.load("bshall/acoustic-model:main", "hubert_soft", trust_repo=True).cuda()
# Build the acoustic model from the local `model` module (instead of the
# torch.hub entry point commented out above) and move it to the GPU.
# NOTE(review): no checkpoint is loaded in this cell, so the weights are
# presumably whatever AcousticModel() initializes itself with — confirm in model.py.
acoustic = model.AcousticModel().cuda()

In [None]:
# HiFi-GAN vocoder loaded from torch.hub and moved to the GPU; the inference
# cells call it on the generated mel output to obtain a waveform.
hifigan = torch.hub.load("bshall/hifigan:main", "hifigan_hubert_soft", trust_repo=True).cuda()

In [None]:
# data_dir = './Emotion Speech Dataset/'
# NOTE(review): absolute, machine-specific path; adjust it when running elsewhere.
data_dir = '/home/dcor/niskhizov/Prosody2Vec/IEMOCAP_full_release/'
# Scan recursively for all .wav files in data_dir. glob returns matches in
# arbitrary (filesystem-dependent) order, so sort them to keep dataset indices
# stable from one run to the next.
wav_files = sorted(glob.glob(data_dir + '/**/*.wav', recursive=True))



In [None]:
embeddings_dir = 'iemocap_embeddings'

In [None]:
# PyTorch dataset that loads (waveform, embeddings) pairs; the embeddings are read from iemocap_embeddings/
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
import pickle
import torchaudio

class IemocapDataset(Dataset):
    """Pairs each audio file with its precomputed embedding pickle.

    Only audio files that have a matching ``<basename>.pkl`` inside
    ``embeddings_dir`` are kept; the others are silently skipped.

    Args:
        audio_files: iterable of paths to audio files.
        embeddings_dir: directory holding one pickle per audio file, named
            after the audio file's basename. Defaults to
            ``'iemocap_embeddings'``, the location that used to be hard-coded.
        max_seconds: every waveform is cropped to this many leading seconds
            (default 3, the previous fixed value).
    """

    def __init__(self, audio_files, embeddings_dir='iemocap_embeddings', max_seconds=3):
        self.embeddings_dir = embeddings_dir
        self.max_seconds = max_seconds
        self.audio_files = []
        self.embeddings_file = []

        for audio_file in audio_files:
            # Swap only the real extension; str.replace('.wav', '.pkl') would
            # also rewrite a '.wav' occurring in the middle of the name.
            stem, _ = os.path.splitext(os.path.basename(audio_file))
            out_file = os.path.join(embeddings_dir, stem + '.pkl')
            if os.path.exists(out_file):
                self.embeddings_file.append(out_file)
                self.audio_files.append(audio_file)

    def __len__(self):
        """Number of audio files that have a matching embedding pickle."""
        return len(self.embeddings_file)

    def __getitem__(self, idx):
        """Return ``(wav, embd)``: the cropped waveform and the unpickled embedding object."""
        wav_path = self.audio_files[idx]
        out_file = self.embeddings_file[idx]

        # SECURITY: pickle.load can execute arbitrary code contained in the
        # file. Only load embedding files you generated yourself.
        with open(out_file, 'rb') as f:
            embd = pickle.load(f)

        wav, sr = torchaudio.load(wav_path)

        # Keep only the first `max_seconds` seconds; `sr` is samples per second.
        wav = wav[:, :int(self.max_seconds * sr)]

        return wav, embd

In [None]:
# Put the acoustic model in training mode. Decoder deep-copies `acoustic`
# below, so the copy starts out in this mode as well.
acoustic.train()


class Decoder(nn.Module):
    """Conditions unit sequences on emotion/speaker vectors and decodes them.

    The units are first re-weighted by an ``Attention`` module driven by the
    emotion and speaker vectors, then passed to a private copy of the given
    acoustic model.

    Args:
        hidden_dim: feature size of the unit sequence fed to the attention.
        acoustic: acoustic model to copy; it must expose ``generate(units)``.
        emo_vec_size: size of the emotion vectors (default 1024, the value
            that used to be hard-coded).
        spk_vec_size: size of the speaker vectors (default 192, the value
            that used to be hard-coded).
    """

    def __init__(self, hidden_dim, acoustic, emo_vec_size=1024, spk_vec_size=192):
        super().__init__()
        self.attn = Attention(hidden_dim, emo_vec_size, spk_vec_size)
        # Deep copy, so training this module never modifies the caller's model.
        self.decoder_rnn = copy.deepcopy(acoustic)

    def forward(self, units, emo_vecs, spk_vecs, logmels):
        """Return the acoustic model's output for the attention-weighted units.

        Args:
            units: (batch_size, time, hidden_dim)
            emo_vecs: (batch_size, emo_vec_size)
            spk_vecs: (batch_size, spk_vec_size)
            logmels: (batch_size, time, n_mels); currently unused, see the note below.
        """
        # Apply attention
        units_attn = self.attn(units, emo_vecs, spk_vecs)  # (batch_size, time, hidden_dim)
        # NOTE(review): `logmels` is accepted but never used, so decoding is
        # free-running via generate() rather than teacher-forced. Confirm this
        # is intended and that generate() in model.py keeps gradients enabled,
        # since this method is called from the training loop.
        return self.decoder_rnn.generate(units_attn)
        
# 256 is the hidden_dim handed to Attention; the emotion/speaker projections are
# added to `units` element-wise, so the last dimension of `units` must match it.
decoder = Decoder(256, acoustic).cuda()

In [None]:
# load the latest decoder
import re


def _checkpoint_epoch(path):
    """Return the integer epoch in a 'decoder_<epoch>.pth' filename, or -1 if there is none."""
    match = re.search(r'decoder_(\d+)\.pth$', path)
    return int(match.group(1)) if match else -1


# Order by numeric epoch: a plain string sort ranks 'decoder_90.pth' after
# 'decoder_100.pth', which would silently reload a stale checkpoint.
decoders = sorted(glob.glob('decoder_*.pth'), key=_checkpoint_epoch)
if not decoders:
    raise FileNotFoundError("no 'decoder_*.pth' checkpoint found in the current directory")
decoder.load_state_dict(torch.load(decoders[-1]))


In [None]:
ds = IemocapDataset(wav_files)

In [None]:
# 80/20 train/test split. The fixed generator seed makes the split repeatable
# across kernel restarts (given the same ordering of `ds`), so a reloaded
# checkpoint is not evaluated on samples it was trained on.
n_train = int(0.8 * len(ds))
train_ds, test_ds = torch.utils.data.random_split(
    ds,
    [n_train, len(ds) - n_train],
    generator=torch.Generator().manual_seed(0),
)

In [None]:
# acoustic(units.cuda().unsqueeze(0), logmel.unsqueeze(0).transpose(1,2).cuda())

In [None]:
# acoustic.decoder(enc,logmel.unsqueeze(0).cuda().transpose(1,2))

In [None]:
ds[0][1]['logmel'].shape

In [None]:
# create collate function that will pad the sequences to the same length
def collate_fn(batch, max_units=150, frames_per_unit=2):
    """Collate ``(wav, embedding_dict)`` samples into zero-padded batch tensors.

    Args:
        batch: list of ``(wav, embd)`` pairs. Only the first row of ``wav`` is
            used; ``embd`` must provide ``'units'``, ``'logmel'``, ``'emo_vec'``
            and ``'spk_vec'``.
        max_units: keep at most this many leading unit frames per sample
            (default 150, the previous fixed value).
        frames_per_unit: number of log-mel frames kept per retained unit frame
            (default 2, the previous fixed ratio).

    Returns:
        Tuple ``(wavs, units_padded, emo_vecs, spk_vecs, logmels_padded,
        mels_lengths, units_lengths)``. Sequences are padded with zeros to the
        longest item in the batch. Every log-mel sequence carries one extra
        all-zero frame at the front; ``mels_lengths`` does not count it.
    """
    wavs, units, emo_vecs, spk_vecs, logmels = [], [], [], [], []
    for item in batch:
        wav, embd = item[0], item[1]
        wavs.append(wav[0])

        u = embd['units'][:max_units, :]
        units.append(u)
        # as_tensor accepts lists, arrays and tensors alike without the
        # copy-construct warning torch.tensor() emits for tensor inputs.
        emo_vecs.append(torch.as_tensor(embd['emo_vec']))
        spk_vecs.append(embd['spk_vec'])

        # The log-mel is transposed so that time becomes the first axis, then
        # truncated to `frames_per_unit` frames per retained unit frame.
        mel = embd['logmel'].T[:u.size(0) * frames_per_unit, :]
        # Prepend one zero frame: the training loop feeds frames [:-1] as input
        # and compares the output against frames [1:].
        mel = torch.nn.functional.pad(mel, (0, 0, 1, 0))
        logmels.append(mel)

    # True lengths before batch padding (the prepended mel frame is excluded).
    mels_lengths = torch.tensor([x.size(0) - 1 for x in logmels])
    units_lengths = torch.tensor([x.size(0) for x in units])

    # pad the sequences
    units_padded = nn.utils.rnn.pad_sequence(units, batch_first=True)
    logmels_padded = nn.utils.rnn.pad_sequence(logmels, batch_first=True)
    wavs = nn.utils.rnn.pad_sequence(wavs, batch_first=True)

    return wavs, units_padded, torch.stack(emo_vecs), torch.stack(spk_vecs), logmels_padded, mels_lengths, units_lengths

In [None]:
ds[200][1]['logmel'].shape

In [None]:
ds[0][1]['units'].T.shape

In [None]:
# Both loaders share batch size, collation and worker count; only the training
# loader shuffles.
loader_kwargs = dict(batch_size=4, collate_fn=collate_fn, num_workers=10)
train_dl = DataLoader(train_ds, shuffle=True, **loader_kwargs)
test_dl = DataLoader(test_ds, shuffle=False, **loader_kwargs)

In [None]:
# it = iter(dl)

In [None]:
from torch.optim import Adam
from torch.nn.functional import l1_loss

optimizer = Adam(decoder.parameters(), lr=1e-4)


In [None]:
from tqdm import tqdm_notebook,tqdm

In [None]:
decoder.attn.attn_a.weight.data.norm()

In [None]:
# NOTE(review): `wavs` is first assigned inside the training loop further down,
# so on a fresh kernel this cell raises NameError.
wavs.shape

In [None]:
# NOTE(review): `units_padded` is first assigned inside the training loop
# further down, so on a fresh kernel this cell raises NameError.
units_padded.shape

In [None]:
# Train the decoder. The loop iterates over `train_dl`: the name `dl` used
# previously is never defined in this notebook and raised NameError on a fresh
# kernel.
for epoch in range(1000):
    decoder.train()
    for idx, batch in tqdm(enumerate(train_dl), total=len(train_dl)):
        wavs, units_padded, emo_vecs, spk_vecs, logmels_padded, mels_lengths, units_lengths = batch

        optimizer.zero_grad()

        # Frames [:-1] are passed as decoder input and frames [1:] are the
        # target, i.e. the target is the input shifted by one frame.
        out = decoder(units_padded.cuda(), emo_vecs.cuda(), spk_vecs.cuda(), logmels_padded[:, :-1, :].cuda())
        # out = acoustic(units_padded.cuda(), logmels_padded[:, :-1, :].cuda())

        # target = hifigan(out[:1,:,:].transpose(1, 2))
        # Per-utterance L1: summed over all frames and mel bins (padded frames
        # included), then divided by n_mels * true length; averaged over the batch.
        loss = l1_loss(out, logmels_padded[:, 1:, :].cuda(), reduction="none")
        loss = torch.sum(loss, dim=(1, 2)) / (out.size(-1) * mels_lengths.cuda())
        loss = torch.mean(loss)
        loss.backward()

        optimizer.step()

        if idx % 100 == 0:
            print('Epoch:', epoch, 'Batch:', idx)
            print('Loss:', loss.item())

    # Checkpoint every 10th epoch (including epoch 0).
    if epoch % 10 == 0:
        torch.save(decoder.state_dict(), f"decoder_{epoch}.pth")

In [None]:
import plotly.express as px

In [None]:
px.imshow(out[0].detach().cpu().numpy().T)

## Inference

In [None]:
# # load the latest decoder
# decoders = glob.glob('decoder_*.pth')
# decoders.sort()
# decoder.load_state_dict(torch.load(decoders[-1]))


In [None]:
# `dl` is not defined anywhere in this notebook. Use a dedicated loader over
# the held-out split. Its batch holds more than 10 items because the cells
# below pick batch items 1 and 10.
infer_dl = DataLoader(test_ds, batch_size=16, shuffle=False, collate_fn=collate_fn)
it = iter(infer_dl)

In [None]:
batch = next(it)

In [None]:
wavs, units_padded, emo_vecs, spk_vecs, logmels_padded, mels_lengths, units_lengths  = batch


In [None]:
decoder.eval()

In [None]:
# Combine the units of batch item 1 with the emotion and speaker vectors of
# batch item 10, each given a leading batch axis of size 1.
# NOTE(review): this requires the batch to hold at least 11 items.
with torch.no_grad():
    units_attn = decoder.attn(units_padded[1].unsqueeze(0).cuda(), emo_vecs[10].unsqueeze(0).cuda(), spk_vecs[10].unsqueeze(0).cuda())

In [None]:
with torch.no_grad():
    out = decoder.decoder_rnn.generate(units_attn.cuda())
    # out = acoustic.generate(units_padded.cuda())

In [None]:
import plotly.express as px


In [None]:
with torch.no_grad():
    rec= hifigan(out[0].T.unsqueeze(0))[0][0]

In [None]:
px.imshow(out[0].detach().cpu().numpy().T,origin='lower')

In [None]:
Audio(rec.squeeze().cpu().numpy(), rate=16000)

In [None]:
r = rec.cpu()

In [None]:
torchaudio.save('./rec.wav', r.unsqueeze(0).float(),sample_rate=16000)

In [None]:
import numpy as np

In [None]:
import torch

In [None]:
# NOTE(review): torch.ones() needs a size argument; as written this call raises
# TypeError. Leftover scratch cell — supply a shape or delete it.
torch.ones()