In [1]:
cd Prosody2Vec/

/home/dcor/niskhizov/Prosody2Vec


In [2]:
import torch

In [3]:
# import clearml
# clearml.browser_login()

In [4]:
# from clearml import Task
# task = Task.init(project_name="my project", task_name="my task")

In [5]:
from torch import nn
import torch 
import glob
from IPython.display import clear_output, display, Audio
import copy

In [6]:
import torch

In [7]:
import model

In [8]:

class Attention(nn.Module):
    def __init__(self, hidden_dim, vec_a_size, vec_b_size):
        super(Attention, self).__init__()
        self.attn_a = nn.Linear(vec_a_size, hidden_dim)  # Project vec_a to hidden_dim
        self.attn_b = nn.Linear(vec_b_size, hidden_dim)  # Project vec_b to hidden_dim
        self.attn_score = nn.Linear(hidden_dim, 1)  # Compute attention scores
    
    def forward(self, matrix, vec_a, vec_b):
        # matrix: (batch_size, time, hidden_dim)
        # vec_a: (batch_size, vec_a_size)
        # vec_b: (batch_size, vec_b_size)
        
        batch_size, time, hidden_dim = matrix.shape
        
        # Project vectors into hidden space
        a_proj = self.attn_a(vec_a).unsqueeze(1).expand(-1, time, -1)  # (batch_size, time, hidden_dim)
        b_proj = self.attn_b(vec_b).unsqueeze(1).expand(-1, time, -1)  # (batch_size, time, hidden_dim)
        
        # Compute attention scores
        attn_input = torch.tanh(matrix + a_proj + b_proj)  # Combine information
        attn_scores = self.attn_score(attn_input).squeeze(-1)  # (batch_size, time)
        attn_weights = torch.softmax(attn_scores, dim=-1).unsqueeze(-1)  # (batch_size, time, 1)
        
        # Apply attention to the matrix
        updated_matrix = matrix * attn_weights  # Element-wise weighting
        
        return updated_matrix  # (batch_size, time, hidden_dim)

# Example Usage
batch_size, time, hidden_dim, vec_a_size, vec_b_size = 32, 10, 64, 16, 16
matrix = torch.randn(batch_size, time, hidden_dim)
vec_a = torch.randn(batch_size, vec_a_size)
vec_b = torch.randn(batch_size, vec_b_size)

attn = Attention(hidden_dim, vec_a_size, vec_b_size)
output_matrix = attn(matrix, vec_a, vec_b)
print(output_matrix.shape)  # Should be (batch_size, time, hidden_dim)


torch.Size([32, 10, 64])


In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionFusion(nn.Module):
    def __init__(self, hubert_dim, speaker_dim, prosody_dim, hidden_dim, num_heads=8):
        super(AttentionFusion, self).__init__()
        
        # Linear projections to match hidden_dim
        self.speaker_proj = nn.Linear(speaker_dim, hidden_dim // 2)
        self.prosody_proj = nn.Linear(prosody_dim, hidden_dim // 2)
        
        # Multi-head attention layer
        self.cross_attn = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=num_heads, batch_first=True)

        # Final feedforward layer (optional)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

    def forward(self, hubert_features, prosody_embedding, speaker_embedding):
        """
        Inputs:
        - hubert_features: (batch, seq_len, hubert_dim)
        - speaker_embedding: (batch, speaker_dim)
        - prosody_embedding: (batch, prosody_dim)

        Output:
        - attended_features: (batch, seq_len, hidden_dim)
        """

        # Project speaker and prosody to hidden_dim
        S_aligned = self.speaker_proj(speaker_embedding)  # (batch, hidden_dim)
        P_aligned = self.prosody_proj(prosody_embedding)  # (batch, hidden_dim)

        # Combine speaker & prosody information
        # Q = S_aligned + P_aligned  # Alternative: torch.cat([S_aligned, P_aligned], dim=-1)
        Q = torch.cat([S_aligned, P_aligned], dim=-1)

        # Expand Q to match seq_len for attention
        Q = Q.unsqueeze(1).expand(-1, hubert_features.shape[1], -1)  # (batch, seq_len, hidden_dim)

        # Apply cross-attention: Q attends to H (HuBERT features)
        attn_output, _ = self.cross_attn(Q, hubert_features, hubert_features)  # (batch, seq_len, hidden_dim)

        # Pass through FFN for further feature refinement
        attended_features = self.ffn(attn_output)

        return attended_features

# Define input dimensions
hubert_dim = 512   # Example size from HuBERT output
speaker_dim = 192  # Example ECAPA-TDNN speaker embedding size
prosody_dim = 1024   # Example emotion recognition prosody size
hidden_dim = 512   # Common dimension for fusion
seq_len = 100      # Example sequence length
batch_size = 16    # Example batch size

# Initialize model
fusion_model = AttentionFusion(hubert_dim, speaker_dim, prosody_dim, hidden_dim)

# Dummy Inputs
hubert_features = torch.randn(batch_size, seq_len, hubert_dim)
speaker_embedding = torch.randn(batch_size, speaker_dim)
prosody_embedding = torch.randn(batch_size, prosody_dim)

# Forward pass
output = fusion_model(hubert_features, prosody_embedding, speaker_embedding)
print(output.shape)  # Expected: (batch_size, seq_len, hidden_dim)


torch.Size([16, 100, 512])


In [None]:
# acoustic = torch.hub.load("bshall/acoustic-model:main", "hubert_soft", trust_repo=True).cuda()
# acoustic = model.AcousticModel(discrete=True).cuda()

In [12]:
# data_dir = './Emotion Speech Dataset/'
data_dir = '/home/dcor/niskhizov/Prosody2Vec/IEMOCAP_full_release/'
# scan recursively for all .wav files in the data_dir
wav_files = glob.glob(data_dir + '/**/*.wav', recursive=True)



In [13]:
embeddings_dir = 'iemocap_embeddings_3sec'

In [14]:
# create pytorch dataset that loads pairs of wav a and embeddings from iemocap_embeddings
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
import pickle
import torchaudio

class IemocapDataset(Dataset):
    def __init__(self, audio_files):
        self.audio_files = []
        self.embeddings_file = []

        for audio_file in audio_files:
            out_file = f"{embeddings_dir}/{audio_file.split('/')[-1].replace('.wav', '.pkl')}"
            if os.path.exists(out_file):                
                self.embeddings_file.append(out_file)
                self.audio_files.append(audio_file)

    def __len__(self):
        return len(self.embeddings_file)
    
    def __getitem__(self, idx):

        wav_path = self.audio_files[idx]

        out_file = self.embeddings_file[idx]

        with open(out_file, 'rb') as f:
            embd = pickle.load(f)

        wav,sr = torchaudio.load(wav_path)

        # take the first 3 seconds of the audio

        wav = wav[:, :3*sr]

        
        
        return wav, embd

In [15]:
acoustic.train()


class Decoder(nn.Module):
    def __init__(self, hidden_dim, acoustic):
        super(Decoder, self).__init__()
        # self.attn = Attention(hidden_dim, 1024, 192)
        self.attn = AttentionFusion(hubert_dim, speaker_dim, prosody_dim, hidden_dim)

        self.decoder_rnn = copy.deepcopy(acoustic)

    def forward(self, units, emo_vecs, spk_vecs, logmels, teacher_forcing_ratio=0.5):
        # units: (batch_size, time, hidden_dim)
        # emo_vecs: (batch_size, emo_vec_size)
        # spk_vecs: (batch_size, spk_vec_size)
        # logmels: (batch_size, time, n_mels)
        
        # batch_size, time, _ = units.shape
        
        # Apply attention
        cont_units = self.decoder_rnn.encoder(units)

        units_attn = self.attn(cont_units, emo_vecs, spk_vecs)  # (batch_size, time, hidden_dim)
        
        to_techer_force = torch.rand(1)[0] < teacher_forcing_ratio
        if to_techer_force:
            return self.decoder_rnn.decoder(units_attn, logmels)
        else:
            return self.decoder_rnn.decoder.generate(units_attn)
        
decoder = Decoder(512, acoustic).cuda()

In [16]:
# load the latest decoder
# decoders = glob.glob('decoder_*.pth')
# decoders.sort()
# decoder.load_state_dict(torch.load(decoders[-1]))


In [17]:
ds = IemocapDataset(wav_files)

In [18]:
train_ds, test_ds = torch.utils.data.random_split(ds, [int(0.8*len(ds)), len(ds) - int(0.8*len(ds))])

In [19]:
# acoustic(units.cuda().unsqueeze(0), logmel.unsqueeze(0).transpose(1,2).cuda())

In [20]:
# acoustic.decoder(enc,logmel.unsqueeze(0).cuda().transpose(1,2))

In [21]:
ds[0][1]['logmel'].shape

torch.Size([128, 300])

In [22]:
ds[0][1]

{'discrite_units': tensor([ 6,  6,  6,  6,  6,  6,  6, 69, 69, 69, 69, 69, 69, 69, 74, 69, 69, 69,
         74, 69, 69, 74, 69, 19, 74, 19, 74, 90, 94, 59, 78, 63, 47, 47, 35, 35,
          0,  0, 73, 29, 76, 76, 28, 28, 28, 20, 20, 17, 17, 95, 85, 85, 85, 41,
         41, 29, 30, 92, 34, 34, 84, 84, 18, 82, 82, 43, 50, 50, 50, 51, 51, 51,
         20, 20, 78, 89, 79, 79, 85, 85, 75, 75, 75, 24, 56, 56, 95, 95,  0,  0,
         27, 33, 78, 36, 30, 34, 34, 78, 63, 43, 43, 43, 50, 50, 50, 50, 50, 50,
         51, 51, 51, 91,  5,  5,  5, 32, 32, 32, 37, 37, 37, 37, 37, 37, 37,  3,
          3,  3,  3,  3,  3, 45, 45, 45, 45, 74, 45, 45,  4, 69, 69,  4, 69, 69,
         69, 69, 69,  6,  6,  6]),
 'units': tensor([[ 0.1531,  0.0203, -0.5402,  ...,  0.2662, -0.0722,  0.3767],
         [ 0.1603,  0.0284, -0.5288,  ...,  0.2461, -0.0886,  0.3676],
         [ 0.1830,  0.0294, -0.5550,  ...,  0.2188, -0.1136,  0.3648],
         ...,
         [-0.2571, -0.1120, -0.5201,  ...,  0.8423, -0.0786,  0

In [23]:
# create collate function that will pad the sequences to the same length
def collate_fn(batch):
    wavs = [item[0][0] for item in batch]
    
    d_units, units, emo_vecs, spk_vecs, logmels = [], [], [], [], []
    for item in batch:
        d = item[1]['discrite_units']
        u = item[1]['units']
        mel = item[1]['logmel'].T

        d_units.append(d)
        units.append(u)
        emo_vecs.append(torch.tensor((item[1]['emo_vec'])))
        spk_vecs.append(item[1]['spk_vec'])

        mel  = mel[:u.size(0)*2,:]
        # print(mel.shape)
        mel = torch.nn.functional.pad(mel, (0,0,1,0))
        # print(mel.shape)

        logmels.append(mel)

    
    mels_lengths = torch.tensor([x.size(0) - 1 for x in logmels])
    units_lengths = torch.tensor([x.size(0) for x in units])

    d_units_padded = nn.utils.rnn.pad_sequence(d_units, batch_first=True, padding_value=-1)
    units_padded = nn.utils.rnn.pad_sequence(units, batch_first=True)
    logmels_padded = nn.utils.rnn.pad_sequence(logmels, batch_first=True)
    
    _,T,_ = units_padded.shape
    # pad the sequences

    wavs = nn.utils.rnn.pad_sequence(wavs, batch_first=True)

    
    return wavs, d_units_padded, units_padded, torch.stack(emo_vecs), torch.stack(spk_vecs), logmels_padded, mels_lengths, units_lengths

In [24]:
ds[1][1]['logmel'].shape

torch.Size([128, 300])

In [25]:
ds[0][1]['emo_vec'].shape

(1024,)

In [26]:
ds[0][1]['spk_vec'].shape

torch.Size([192])

In [27]:
train_dl = DataLoader(train_ds, batch_size=32, shuffle=True, collate_fn=collate_fn, num_workers=10)
test_dl = DataLoader(test_ds, batch_size=32, shuffle=False, collate_fn=collate_fn, num_workers=10)

In [28]:
it = iter(train_dl)

In [29]:
b= next(it)

In [30]:
b[1].shape

torch.Size([32, 150])

In [166]:
acoustic.encoder(b[1].cuda()).shape

In [167]:
from torch.optim import Adam
from torch.nn.functional import l1_loss

optimizer = Adam(decoder.parameters(), lr=1e-4)


In [168]:
from tqdm import tqdm_notebook,tqdm

In [None]:
for epoch in range(1000):  
    decoder.train()
    for idx,batch in tqdm(enumerate(train_dl),total=len(train_dl)):
        wavs, d_units_padded, units_padded, emo_vecs, spk_vecs, logmels_padded, mels_lengths, units_lengths  = batch
        
        optimizer.zero_grad()

        out = decoder(d_units_padded.cuda(), emo_vecs.cuda(), spk_vecs.cuda(), logmels_padded[:, :-1, :].cuda())
        # out = acoustic(units_padded.cuda(), logmels_padded[:, :-1, :].cuda())

        # target = hifigan(out[:1,:,:].transpose(1, 2))
        loss = l1_loss(out, logmels_padded[:, 1:, :].cuda(), reduction="none")
        loss = torch.sum(loss, dim=(1, 2)) / (out.size(-1) * mels_lengths.cuda())
        loss = torch.mean(loss)
        loss.backward()

        optimizer.step()

        if idx % 100 == 0:
            print('Epoch:', epoch, 'Batch:', idx)
            print('Loss:', loss.item())

    if epoch % 10 == 0:
        torch.save(decoder.state_dict(), f"decoder_{epoch}.pth")

In [65]:
import plotly.express as px

In [66]:
px.imshow(out[0].detach().cpu().numpy().T)

## Inference

In [36]:
ls -lash --sort time | grep decoder

 80M -rw-r--r--  1 niskhizov cs_dcor  80M Feb 26 23:01 decoder_20.pth
 80M -rw-r--r--  1 niskhizov cs_dcor  80M Feb 26 22:50 decoder_10.pth
 80M -rw-r--r--  1 niskhizov cs_dcor  80M Feb 26 22:39 decoder_0.pth
 80M -rw-r--r--  1 niskhizov cs_dcor  80M Feb 26 22:18 decoder_260.pth
 80M -rw-r--r--  1 niskhizov cs_dcor  80M Feb 26 22:16 decoder_250.pth
 80M -rw-r--r--  1 niskhizov cs_dcor  80M Feb 26 22:13 decoder_240.pth
 80M -rw-r--r--  1 niskhizov cs_dcor  80M Feb 26 22:11 decoder_230.pth
 80M -rw-r--r--  1 niskhizov cs_dcor  80M Feb 26 22:08 decoder_220.pth
 80M -rw-r--r--  1 niskhizov cs_dcor  80M Feb 26 22:05 decoder_210.pth
 80M -rw-r--r--  1 niskhizov cs_dcor  80M Feb 26 22:02 decoder_200.pth
 80M -rw-r--r--  1 niskhizov cs_dcor  80M Feb 26 22:00 decoder_190.pth
 80M -rw-r--r--  1 niskhizov cs_dcor  80M Feb 26 21:57 decoder_180.pth
 80M -rw-r--r--  1 niskhizov cs_dcor  80M Feb 26 21:54 decoder_170.pth
 80M -rw-r--r--  1 niskhizov cs_dcor  80M Feb 26 21:52 decoder_160.pth
 80M -rw-r

In [38]:
# load the latest decoder
decoders = glob.glob('decoder_*.pth')
# sort by last modification time
decoders.sort(key=os.path.getmtime)
decoder.load_state_dict(torch.load(decoders[-1]))
print(decoders[-1])
decoder.eval()


decoder_20.pth


  decoder.load_state_dict(torch.load(decoders[-1]))


Decoder(
  (attn): AttentionFusion(
    (speaker_proj): Linear(in_features=192, out_features=256, bias=True)
    (prosody_proj): Linear(in_features=1024, out_features=256, bias=True)
    (cross_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
    )
    (ffn): Sequential(
      (0): Linear(in_features=512, out_features=512, bias=True)
      (1): ReLU()
      (2): Linear(in_features=512, out_features=512, bias=True)
    )
  )
  (decoder_rnn): AcousticModel(
    (encoder): Encoder(
      (embedding): Embedding(101, 256)
      (prenet): PreNet(
        (net): Sequential(
          (0): Linear(in_features=256, out_features=256, bias=True)
          (1): ReLU()
          (2): Dropout(p=0.5, inplace=False)
          (3): Linear(in_features=256, out_features=256, bias=True)
          (4): ReLU()
          (5): Dropout(p=0.5, inplace=False)
        )
      )
      (convs): Sequential(
        (0): Conv1d(256, 512, kernel_

In [39]:
it = iter(test_dl)

In [40]:
batch = next(it)

In [41]:
wavs,d_units, units_padded, emo_vecs, spk_vecs, logmels_padded, mels_lengths, units_lengths  = batch


In [62]:
decoder = decoder.eval()

In [46]:
with torch.no_grad():
        
        cont_units = decoder.decoder_rnn.encoder(d_units.cuda())

        units_attn = decoder.attn(cont_units.cuda(), emo_vecs.cuda(), spk_vecs.cuda())  # (batch_size, time, hidden_dim)
        
   
        o = decoder.decoder_rnn.decoder.generate(units_attn)
        

In [52]:
acoustic = torch.hub.load("bshall/acoustic-model:main", "hubert_discrete", trust_repo=True).cuda()

with torch.no_grad():
    o = acoustic.generate(d_units.cuda())

Using cache found in /home/dcor/niskhizov/cache/hub/bshall_acoustic-model_main
Downloading: "https://github.com/bshall/acoustic-model/releases/download/v0.1/hubert-discrete-d49e1c77.pt" to /home/dcor/niskhizov/cache/hub/checkpoints/hubert-discrete-d49e1c77.pt
100%|██████████| 71.9M/71.9M [00:19<00:00, 3.97MB/s]


In [57]:
import plotly.express as px
px.imshow(o[10].detach().cpu().numpy().T)

In [54]:
hifigan = torch.hub.load("bshall/hifigan:main", "hifigan_hubert_discrete", trust_repo=True).cuda()

Using cache found in /home/dcor/niskhizov/cache/hub/bshall_hifigan_main

`torch.nn.utils.weight_norm` is deprecated in favor of `torch.nn.utils.parametrizations.weight_norm`.



In [67]:
with torch.no_grad():
    target = hifigan(o[10,:,:].unsqueeze(0).transpose(1, 2))

In [68]:
Audio(target[0].detach().cpu().numpy(), rate=16000)

In [69]:
Audio(wavs[10],rate=16000)

In [None]:
with torch.no_grad():
    units_attn = decoder.attn(units_padded[1].unsqueeze(0).cuda(), emo_vecs[10].unsqueeze(0).cuda(), spk_vecs[10].unsqueeze(0).cuda())

In [None]:
with torch.no_grad():
    out = decoder.decoder_rnn.generate(units_attn.cuda())
    # out = acoustic.generate(units_padded.cuda())

In [None]:
import plotly.express as px


In [None]:
with torch.no_grad():
    rec= hifigan(out[0].T.unsqueeze(0))[0][0]

In [None]:
out.shape

In [None]:
px.imshow(out[0].detach().cpu().numpy().T,origin='lower')

In [None]:
Audio(rec.squeeze().cpu().numpy(), rate=16000)

In [None]:
r = rec.cpu()

In [None]:
torchaudio.save('./rec.wav', r.unsqueeze(0).float(),sample_rate=16000)

In [None]:
import numpy as np

In [None]:
import torch

In [None]:
torch.ones()