In [None]:
cd Prosody2Vec/

/home/dcor/niskhizov/Prosody2Vec


In [2]:
import torch

In [3]:
# import clearml
# clearml.browser_login()

In [4]:
# from clearml import Task
# task = Task.init(project_name="my project", task_name="my task")

In [5]:
from torch import nn
import torch 
import glob
from IPython.display import clear_output, display, Audio
import copy

In [6]:
import torch

In [7]:
import model

In [None]:
# acoustic = torch.hub.load("bshall/acoustic-model:main", "hubert_soft", trust_repo=True).cuda()
# acoustic = model.AcousticModel(discrete=True).cuda()

In [8]:
# data_dir = './Emotion Speech Dataset/'
# data_dir = '/home/dcor/niskhizov/Prosody2Vec/IEMOCAP_full_release/'
data_dir = '/home/dcor/niskhizov/Prosody2Vec/Emotion Speech Dataset/0018/'
# scan recursively for all .wav files in the data_dir
wav_files = glob.glob(data_dir + '/**/*.wav', recursive=True)



In [9]:
embeddings_dir = 'esd_female_018'

In [20]:
# create pytorch dataset that loads pairs of wav a and embeddings from iemocap_embeddings
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
import pickle
import torchaudio

class IemocapDataset(Dataset):
    def __init__(self, audio_files):
        self.audio_files = []
        self.embeddings_file = []

        for audio_file in audio_files:
            out_file = f"{embeddings_dir}/{audio_file.split('/')[-1].replace('.wav', '.pkl')}"
            if os.path.exists(out_file):                
                self.embeddings_file.append(out_file)
                self.audio_files.append(audio_file)

    def __len__(self):
        return len(self.embeddings_file)
    
    def __getitem__(self, idx):

        wav_path = self.audio_files[idx]

        out_file = self.embeddings_file[idx]

        with open(out_file, 'rb') as f:
            embd = pickle.load(f)

        wav,sr = torchaudio.load(wav_path)

        # take the first 3 seconds of the audio

        wav = wav[:, :3*sr]

        
        
        return wav, embd

In [183]:
def sinusoidal_positional_encoding(seq_len, hidden_dim):
    position = torch.arange(seq_len).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, hidden_dim, 2) * (-torch.log(torch.tensor(10000.0)) / hidden_dim))
    pe = torch.zeros(seq_len, hidden_dim)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe.unsqueeze(0) 

In [136]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionFusion(nn.Module):
    def __init__(self, prosody_dim, hidden_dim, num_heads=8):
        super(AttentionFusion, self).__init__()
        
        # Linear projections to match hidden_dim
        self.prosody_proj = nn.Linear(prosody_dim, hidden_dim)
        
        # Multi-head attention layer
        self.cross_attn = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=num_heads, batch_first=True)

        # Final feedforward layer (optional)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

    def forward(self, hubert_features, prosody_embedding):
        """
        Inputs:
        - hubert_features: (batch, seq_len, hubert_dim)
        - speaker_embedding: (batch, speaker_dim)
        - prosody_embedding: (batch, prosody_dim)

        Output:
        - attended_features: (batch, seq_len, hidden_dim)
        """

        # Project speaker and prosody to hidden_dim
        P_aligned = self.prosody_proj(prosody_embedding)  # (batch, hidden_dim)

        # Combine speaker & prosody information
        # Q = S_aligned + P_aligned  # Alternative: torch.cat([S_aligned, P_aligned], dim=-1)
        Q = P_aligned

        # Expand Q to match seq_len for attention
        Q = Q.unsqueeze(1).expand(-1, hubert_features.shape[1], -1)  # (batch, seq_len, hidden_dim)

        # Apply cross-attention: Q attends to H (HuBERT features)
        attn_output, _ = self.cross_attn(Q, hubert_features, hubert_features)  # (batch, seq_len, hidden_dim)

        # Pass through FFN for further feature refinement
        attended_features = self.ffn(attn_output)

        return attended_features

# Define input dimensions
prosody_dim = 1024   # Example emotion recognition prosody size
hidden_dim = 512   # Common dimension for fusion
seq_len = 100      # Example sequence length
batch_size = 16    # Example batch size

# Initialize model
fusion_model = AttentionFusion(prosody_dim, hidden_dim)

# Dummy Inputs
hubert_features = torch.randn(batch_size, seq_len, hidden_dim)
prosody_embedding = torch.randn(batch_size, prosody_dim)

# Forward pass
output = fusion_model(hubert_features, prosody_embedding)
print(output.shape)  # Expected: (batch_size, seq_len, hidden_dim)


torch.Size([16, 100, 512])


In [228]:


class FusionDecoder(nn.Module):
    def __init__(self, hidden_dim, acoustic):
        super(FusionDecoder, self).__init__()
        # self.attn = AttentionFusion(prosody_dim, hidden_dim)
        
        self.ff1 = nn.Linear(1024, 512).cuda()
        self.ff2 = nn.Linear(1024, 512).cuda()

        self.base_model = copy.deepcopy(acoustic)

    def forward(self, units, emo_vecs,  logmels):
        # units: (batch_size, time, hidden_dim)
        # emo_vecs: (batch_size, emo_vec_size)
        # spk_vecs: (batch_size, spk_vec_size)
        # logmels: (batch_size, time, n_mels)
        
        # batch_size, time, _ = units.shape
        
        # Apply attention
        o = self.base_model.encoder(units.cuda())

        o2 = self.ff1(emo_vecs.cuda()).unsqueeze(1).expand(-1,o.shape[1] , -1)
        # units_attn = self.attn(o, emo_vecs)  # (batch_size, time, hidden_dim)

        o3 = self.ff2(torch.cat([o, o2], dim=-1))

        d = self.base_model.decoder(o3, logmels)

        return d
    
    def generate(self, units, emo_vecs):
        # units: (batch_size, time, hidden_dim)
        # emo_vecs: (batch_size, emo_vec_size)
        # spk_vecs: (batch_size, spk_vec_size)
                
        # Apply attention
        o = self.base_model.encoder(units.cuda())

        # units_attn = self.attn(o, emo_vecs)
        o2 = self.ff1(emo_vecs.cuda()).unsqueeze(1).expand(-1,o.shape[1] , -1)
        # units_attn = self.attn(o, emo_vecs)  # (batch_size, time, hidden_dim)

        o3 = self.ff2(torch.cat([o, o2], dim=-1))


        d = self.base_model.decoder.generate(o3)
        
        return d


In [229]:
# load the latest decoder
# decoders = glob.glob('decoder_*.pth')
# decoders.sort()
# decoder.load_state_dict(torch.load(decoders[-1]))


In [230]:
ds = IemocapDataset(wav_files)

In [231]:
train_ds, test_ds = torch.utils.data.random_split(ds, [int(0.8*len(ds)), len(ds) - int(0.8*len(ds))])

In [232]:
# acoustic(units.cuda().unsqueeze(0), logmel.unsqueeze(0).transpose(1,2).cuda())

In [233]:
# acoustic.decoder(enc,logmel.unsqueeze(0).cuda().transpose(1,2))

In [234]:
ds[0][1]['logmel'].shape

torch.Size([128, 300])

In [235]:
ds[0][1]

{'discrite_units': tensor([ 6,  6,  6,  6,  6,  6, 96, 96, 96, 22, 22,  2,  2, 90, 90, 90, 56, 56,
         95, 97, 97, 20, 20,  1, 89, 89, 89, 89, 79, 79,  0,  0,  0,  0, 61, 61,
         53, 53, 10, 10, 12, 89, 89, 89, 89, 79, 79, 95, 85, 85, 85, 85, 85, 75,
         75, 75, 75, 75, 31, 31, 63, 66, 50, 50, 51, 51, 51, 20, 12, 58, 14, 59,
         13, 13,  8,  8,  0,  0,  0, 20, 33, 47, 28, 28, 86, 86, 18, 24, 17, 17,
         17, 18, 53, 53, 10, 10, 10, 13,  8,  8, 95, 98, 98, 98, 98,  9, 52, 52,
          1, 79, 31, 31, 66, 18, 18, 13, 13, 13, 84, 84, 95,  0,  0, 98, 98, 48,
         48, 48, 91, 91, 91, 91,  5,  5, 21, 21, 21, 21, 21, 45, 74, 74, 87, 45,
         87, 87, 87, 87,  6,  6]),
 'units': tensor([[ 0.2208, -0.0573, -0.3418,  ..., -0.1308, -0.3788,  0.4143],
         [ 0.2211, -0.0582, -0.3429,  ..., -0.1238, -0.3976,  0.4143],
         [ 0.2381, -0.0453, -0.3322,  ..., -0.2084, -0.3901,  0.4268],
         ...,
         [-0.0519, -0.2640, -0.7900,  ...,  0.4541, -0.7379,  0

In [236]:
# create collate function that will pad the sequences to the same length
def collate_fn(batch):
    wavs = [item[0][0] for item in batch]
    
    d_units, units, emo_vecs, spk_vecs, logmels = [], [], [], [], []
    for item in batch:
        d = item[1]['discrite_units']
        u = item[1]['units']
        mel = item[1]['logmel'].T

        d_units.append(d)
        units.append(u)
        emo_vecs.append(torch.tensor((item[1]['emo_vec'])))
        spk_vecs.append(item[1]['spk_vec'])

        mel  = mel[:u.size(0)*2,:]
        # print(mel.shape)
        mel = torch.nn.functional.pad(mel, (0,0,1,0))
        # print(mel.shape)

        logmels.append(mel)

    
    mels_lengths = torch.tensor([x.size(0) - 1 for x in logmels])
    units_lengths = torch.tensor([x.size(0) for x in units])

    d_units_padded = nn.utils.rnn.pad_sequence(d_units, batch_first=True, padding_value=-1)
    units_padded = nn.utils.rnn.pad_sequence(units, batch_first=True)
    logmels_padded = nn.utils.rnn.pad_sequence(logmels, batch_first=True)
    
    _,T,_ = units_padded.shape
    # pad the sequences

    wavs = nn.utils.rnn.pad_sequence(wavs, batch_first=True)

    
    return wavs, d_units_padded, units_padded, torch.stack(emo_vecs), torch.stack(spk_vecs), logmels_padded, mels_lengths, units_lengths

In [237]:
ds[1][1]['logmel'].shape

torch.Size([128, 300])

In [238]:
ds[0][1]['emo_vec'].shape

(1024,)

In [239]:
ds[0][1]['spk_vec'].shape

torch.Size([192])

In [255]:
train_dl = DataLoader(train_ds, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=10)
test_dl = DataLoader(test_ds, batch_size=32, shuffle=False, collate_fn=collate_fn, num_workers=10)

In [256]:
it = iter(train_dl)

In [257]:
b= next(it)

In [258]:
wavs,d_units, units_padded, emo_vecs, spk_vecs, logmels_padded, mels_lengths, units_lengths  = b


In [259]:
b[1].shape

torch.Size([1, 150])

In [260]:
acoustic = torch.hub.load("bshall/acoustic-model:main", "hubert_discrete", trust_repo=True).cuda()
decoder = FusionDecoder(512, acoustic).cuda()


Using cache found in /home/dcor/niskhizov/cache/hub/bshall_acoustic-model_main


In [261]:
from torch.optim import Adam
from torch.nn.functional import l1_loss

optimizer = Adam(decoder.parameters(), lr=1e-6)


In [262]:
from tqdm import tqdm_notebook,tqdm

In [263]:
decoder

FusionDecoder(
  (ff1): Linear(in_features=1024, out_features=512, bias=True)
  (ff2): Linear(in_features=1024, out_features=512, bias=True)
  (base_model): AcousticModel(
    (encoder): Encoder(
      (embedding): Embedding(101, 256)
      (prenet): PreNet(
        (net): Sequential(
          (0): Linear(in_features=256, out_features=256, bias=True)
          (1): ReLU()
          (2): Dropout(p=0.5, inplace=False)
          (3): Linear(in_features=256, out_features=256, bias=True)
          (4): ReLU()
          (5): Dropout(p=0.5, inplace=False)
        )
      )
      (convs): Sequential(
        (0): Conv1d(256, 512, kernel_size=(5,), stride=(1,), padding=(2,))
        (1): ReLU()
        (2): InstanceNorm1d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (3): ConvTranspose1d(512, 512, kernel_size=(4,), stride=(2,), padding=(1,))
        (4): Conv1d(512, 512, kernel_size=(5,), stride=(1,), padding=(2,))
        (5): ReLU()
        (6): InstanceNorm1

In [314]:
torch.save(decoder.state_dict(), f"decoder_simple2_working.pth")

In [316]:
for epoch in range(1000,300000):  
    decoder.train()
    for idx,batch in tqdm(enumerate(train_dl),total=len(train_dl)):
        wavs, d_units_padded, units_padded, emo_vecs, spk_vecs, logmels_padded, mels_lengths, units_lengths  = batch
        
        optimizer.zero_grad()

        out =  decoder(d_units_padded.cuda(),emo_vecs.cuda(),logmels_padded[:, 1:, :].cuda())
        # out = decoder(d_units_padded.cuda(), emo_vecs.cuda(), spk_vecs.cuda(), logmels_padded[:, :-1, :].cuda())
        # out = acoustic(units_padded.cuda(), logmels_padded[:, :-1, :].cuda())

        # target = hifigan(out[:1,:,:].transpose(1, 2))
        loss = l1_loss(out, logmels_padded[:, 1:, :].cuda(), reduction="none")
        loss = torch.sum(loss, dim=(1, 2)) / (out.size(-1) * mels_lengths.cuda())
        loss = torch.mean(loss)
        loss.backward()

        optimizer.step()

    if epoch % 100 == 0:
        print('Epoch:', epoch, 'Batch:', idx)
        print('Loss:', loss.item())

    if epoch % 1000 == 0:
        torch.save(decoder.state_dict(), f"decoder_simple2_{epoch}.pth")

100%|██████████| 331/331 [00:19<00:00, 16.61it/s]


Epoch: 1000 Batch: 330
Loss: 0.35373571515083313


100%|██████████| 331/331 [00:19<00:00, 16.84it/s]
100%|██████████| 331/331 [00:20<00:00, 16.17it/s]
100%|██████████| 331/331 [00:20<00:00, 16.37it/s]
100%|██████████| 331/331 [00:20<00:00, 16.47it/s]
100%|██████████| 331/331 [00:19<00:00, 16.88it/s]
100%|██████████| 331/331 [00:20<00:00, 16.45it/s]
100%|██████████| 331/331 [00:19<00:00, 17.16it/s]
100%|██████████| 331/331 [00:19<00:00, 16.71it/s]
100%|██████████| 331/331 [00:21<00:00, 15.65it/s]
100%|██████████| 331/331 [00:20<00:00, 16.10it/s]
100%|██████████| 331/331 [00:19<00:00, 17.07it/s]
100%|██████████| 331/331 [00:20<00:00, 16.32it/s]
100%|██████████| 331/331 [00:20<00:00, 16.54it/s]
100%|██████████| 331/331 [00:19<00:00, 17.03it/s]
100%|██████████| 331/331 [00:19<00:00, 16.94it/s]
100%|██████████| 331/331 [00:19<00:00, 16.56it/s]
100%|██████████| 331/331 [00:19<00:00, 16.75it/s]
100%|██████████| 331/331 [00:19<00:00, 17.38it/s]
100%|██████████| 331/331 [00:20<00:00, 16.21it/s]
100%|██████████| 331/331 [00:19<00:00, 16.68it/s]


Epoch: 1100 Batch: 330
Loss: 0.3617890775203705



100%|██████████| 331/331 [00:20<00:00, 16.44it/s]
100%|██████████| 331/331 [00:19<00:00, 16.86it/s]
100%|██████████| 331/331 [00:19<00:00, 16.73it/s]
100%|██████████| 331/331 [00:19<00:00, 16.91it/s]
100%|██████████| 331/331 [00:20<00:00, 15.96it/s]
100%|██████████| 331/331 [00:20<00:00, 16.24it/s]
100%|██████████| 331/331 [00:19<00:00, 17.21it/s]
100%|██████████| 331/331 [00:19<00:00, 16.90it/s]
100%|██████████| 331/331 [00:19<00:00, 17.30it/s]
100%|██████████| 331/331 [00:19<00:00, 17.05it/s]
100%|██████████| 331/331 [00:19<00:00, 16.72it/s]
100%|██████████| 331/331 [00:19<00:00, 17.03it/s]
100%|██████████| 331/331 [00:19<00:00, 16.84it/s]
100%|██████████| 331/331 [00:21<00:00, 15.75it/s]
100%|██████████| 331/331 [00:20<00:00, 16.26it/s]
100%|██████████| 331/331 [00:19<00:00, 16.87it/s]
100%|██████████| 331/331 [00:19<00:00, 16.90it/s]
100%|██████████| 331/331 [00:20<00:00, 16.45it/s]
100%|██████████| 331/331 [00:20<00:00, 16.11it/s]
100%|██████████| 331/331 [00:20<00:00, 16.38it/s]

Epoch: 1200 Batch: 330
Loss: 0.3385249972343445



100%|██████████| 331/331 [00:19<00:00, 16.62it/s]
100%|██████████| 331/331 [00:19<00:00, 16.75it/s]
100%|██████████| 331/331 [00:20<00:00, 16.11it/s]
100%|██████████| 331/331 [00:19<00:00, 16.80it/s]
100%|██████████| 331/331 [00:19<00:00, 16.96it/s]
100%|██████████| 331/331 [00:20<00:00, 16.42it/s]
100%|██████████| 331/331 [00:20<00:00, 15.97it/s]
100%|██████████| 331/331 [00:18<00:00, 18.31it/s]
100%|██████████| 331/331 [00:19<00:00, 16.71it/s]
100%|██████████| 331/331 [00:20<00:00, 16.30it/s]
100%|██████████| 331/331 [00:20<00:00, 16.20it/s]
100%|██████████| 331/331 [00:20<00:00, 16.55it/s]
100%|██████████| 331/331 [00:19<00:00, 16.81it/s]
100%|██████████| 331/331 [00:19<00:00, 16.91it/s]
100%|██████████| 331/331 [00:19<00:00, 16.81it/s]
100%|██████████| 331/331 [00:19<00:00, 17.13it/s]
100%|██████████| 331/331 [00:19<00:00, 16.76it/s]
100%|██████████| 331/331 [00:19<00:00, 16.57it/s]
100%|██████████| 331/331 [00:20<00:00, 16.47it/s]
100%|██████████| 331/331 [00:20<00:00, 16.05it/s]

Epoch: 1300 Batch: 330
Loss: 0.33405330777168274



100%|██████████| 331/331 [00:19<00:00, 16.98it/s]
100%|██████████| 331/331 [00:20<00:00, 16.16it/s]
100%|██████████| 331/331 [00:19<00:00, 17.20it/s]
100%|██████████| 331/331 [00:20<00:00, 16.39it/s]
100%|██████████| 331/331 [00:20<00:00, 16.47it/s]
100%|██████████| 331/331 [00:19<00:00, 16.89it/s]
100%|██████████| 331/331 [00:20<00:00, 16.35it/s]
100%|██████████| 331/331 [00:18<00:00, 18.06it/s]
100%|██████████| 331/331 [00:19<00:00, 16.69it/s]
100%|██████████| 331/331 [00:19<00:00, 16.63it/s]
100%|██████████| 331/331 [00:19<00:00, 16.61it/s]
100%|██████████| 331/331 [00:20<00:00, 16.20it/s]
100%|██████████| 331/331 [00:19<00:00, 16.73it/s]
100%|██████████| 331/331 [00:19<00:00, 16.75it/s]
100%|██████████| 331/331 [00:19<00:00, 16.58it/s]
100%|██████████| 331/331 [00:20<00:00, 16.33it/s]
100%|██████████| 331/331 [00:18<00:00, 18.01it/s]
100%|██████████| 331/331 [00:19<00:00, 16.60it/s]
100%|██████████| 331/331 [00:20<00:00, 15.93it/s]
100%|██████████| 331/331 [00:20<00:00, 16.28it/s]

KeyboardInterrupt: 

In [430]:
torch.save(decoder.state_dict(), f"decoder_simple2_{epoch}.pth")

In [65]:
import plotly.express as px

In [66]:
px.imshow(out[0].detach().cpu().numpy().T)

## Inference

In [36]:
ls -lash --sort time | grep decoder

 80M -rw-r--r--  1 niskhizov cs_dcor  80M Feb 26 23:01 decoder_20.pth
 80M -rw-r--r--  1 niskhizov cs_dcor  80M Feb 26 22:50 decoder_10.pth
 80M -rw-r--r--  1 niskhizov cs_dcor  80M Feb 26 22:39 decoder_0.pth
 80M -rw-r--r--  1 niskhizov cs_dcor  80M Feb 26 22:18 decoder_260.pth
 80M -rw-r--r--  1 niskhizov cs_dcor  80M Feb 26 22:16 decoder_250.pth
 80M -rw-r--r--  1 niskhizov cs_dcor  80M Feb 26 22:13 decoder_240.pth
 80M -rw-r--r--  1 niskhizov cs_dcor  80M Feb 26 22:11 decoder_230.pth
 80M -rw-r--r--  1 niskhizov cs_dcor  80M Feb 26 22:08 decoder_220.pth
 80M -rw-r--r--  1 niskhizov cs_dcor  80M Feb 26 22:05 decoder_210.pth
 80M -rw-r--r--  1 niskhizov cs_dcor  80M Feb 26 22:02 decoder_200.pth
 80M -rw-r--r--  1 niskhizov cs_dcor  80M Feb 26 22:00 decoder_190.pth
 80M -rw-r--r--  1 niskhizov cs_dcor  80M Feb 26 21:57 decoder_180.pth
 80M -rw-r--r--  1 niskhizov cs_dcor  80M Feb 26 21:54 decoder_170.pth
 80M -rw-r--r--  1 niskhizov cs_dcor  80M Feb 26 21:52 decoder_160.pth
 80M -rw-r

In [38]:
# load the latest decoder
decoders = glob.glob('decoder_*.pth')
# sort by last modification time
decoders.sort(key=os.path.getmtime)
decoder.load_state_dict(torch.load(decoders[-1]))
print(decoders[-1])
decoder.eval()


decoder_20.pth


  decoder.load_state_dict(torch.load(decoders[-1]))


Decoder(
  (attn): AttentionFusion(
    (speaker_proj): Linear(in_features=192, out_features=256, bias=True)
    (prosody_proj): Linear(in_features=1024, out_features=256, bias=True)
    (cross_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
    )
    (ffn): Sequential(
      (0): Linear(in_features=512, out_features=512, bias=True)
      (1): ReLU()
      (2): Linear(in_features=512, out_features=512, bias=True)
    )
  )
  (decoder_rnn): AcousticModel(
    (encoder): Encoder(
      (embedding): Embedding(101, 256)
      (prenet): PreNet(
        (net): Sequential(
          (0): Linear(in_features=256, out_features=256, bias=True)
          (1): ReLU()
          (2): Dropout(p=0.5, inplace=False)
          (3): Linear(in_features=256, out_features=256, bias=True)
          (4): ReLU()
          (5): Dropout(p=0.5, inplace=False)
        )
      )
      (convs): Sequential(
        (0): Conv1d(256, 512, kernel_

In [317]:
it = iter(test_dl)

In [318]:
batch = next(it)

In [320]:
hifigan = torch.hub.load("bshall/hifigan:main", "hifigan_hubert_discrete", trust_repo=True).cuda()

Using cache found in /home/dcor/niskhizov/cache/hub/bshall_hifigan_main

`torch.nn.utils.weight_norm` is deprecated in favor of `torch.nn.utils.parametrizations.weight_norm`.



In [332]:
# with torch.no_grad():
        
#         cont_units = decoder.decoder_rnn.encoder(d_units.cuda())

#         units_attn = decoder.attn(cont_units.cuda(), emo_vecs.cuda(), spk_vecs.cuda())  # (batch_size, time, hidden_dim)
        
   
#         o = decoder.decoder_rnn.decoder.generate(units_attn)
wavs,d_units, units_padded, emo_vecs, spk_vecs, logmels_padded, mels_lengths, units_lengths  = batch

decoder = decoder.eval()

with torch.no_grad():
        

        o = decoder.generate(d_units[4].unsqueeze(0).cuda(), emo_vecs[12].unsqueeze(0).cuda())
        

In [333]:
# acoustic = torch.hub.load("bshall/acoustic-model:main", "hubert_discrete", trust_repo=True).cuda()

# with torch.no_grad():
#     o = acoustic.generate(d_units.cuda())

In [334]:
import plotly.express as px
px.imshow(o[0].detach().cpu().numpy().T)

In [335]:
idx = 0
with torch.no_grad():
    target = hifigan(o[idx,:,:].unsqueeze(0).transpose(1, 2))

In [336]:
Audio(target[0].detach().cpu().numpy(), rate=16000)

In [345]:
Audio(wavs[22],rate=16000)

In [340]:
Audio(wavs[4],rate=16000)

In [346]:
hubert_discrete = torch.hub.load("bshall/hubert:main", "hubert_discrete", trust_repo=True).cuda()


Using cache found in /home/dcor/niskhizov/cache/hub/bshall_hubert_main


In [348]:
from funasr import AutoModel


In [349]:
model_id = "iic/emotion2vec_plus_large"

sed_model = AutoModel(
    model=model_id,
    hub="ms",  # "ms" or "modelscope" for China mainland users; "hf" or "huggingface" for other overseas users
)

Downloading Model to directory: /home/dcor/niskhizov/.cache/modelscope/hub/models/iic/emotion2vec_plus_large




Detect model requirements, begin to install it: /home/dcor/niskhizov/.cache/modelscope/hub/models/iic/emotion2vec_plus_large/requirements.txt
install model requirements successfully
ckpt: /home/dcor/niskhizov/.cache/modelscope/hub/models/iic/emotion2vec_plus_large/model.pt



You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.



init param, map: modality_encoders.AUDIO.extra_tokens from d2v_model.modality_encoders.AUDIO.extra_tokens in ckpt
init param, map: modality_encoders.AUDIO.alibi_scale from d2v_model.modality_encoders.AUDIO.alibi_scale in ckpt
init param, map: modality_encoders.AUDIO.local_encoder.conv_layers.0.0.weight from d2v_model.modality_encoders.AUDIO.local_encoder.conv_layers.0.0.weight in ckpt
init param, map: modality_encoders.AUDIO.local_encoder.conv_layers.0.2.1.weight from d2v_model.modality_encoders.AUDIO.local_encoder.conv_layers.0.2.1.weight in ckpt
init param, map: modality_encoders.AUDIO.local_encoder.conv_layers.0.2.1.bias from d2v_model.modality_encoders.AUDIO.local_encoder.conv_layers.0.2.1.bias in ckpt
init param, map: modality_encoders.AUDIO.local_encoder.conv_layers.1.0.weight from d2v_model.modality_encoders.AUDIO.local_encoder.conv_layers.1.0.weight in ckpt
init param, map: modality_encoders.AUDIO.local_encoder.conv_layers.1.2.1.weight from d2v_model.modality_encoders.AUDIO.loc

In [381]:
def extract_embedding(wav_path):
    wav, sr = torchaudio.load(wav_path)

    # take 3 seconds of audio

    with torch.inference_mode():
        # Extract speech units
        discrite_units = hubert_discrete.units(wav.unsqueeze(0).cuda())
        
        emo_vec = torch.tensor(sed_model.generate(wav, granularity="utterance", extract_embedding=True, disable_pbar =True)[0]['feats'])

    return discrite_units, emo_vec, wav



In [None]:
neutral_wavs = glob.glob('Emotion Speech Dataset/0018/Neutral/*.wav')

In [None]:
angry_wavs = glob.glob('Emotion Speech Dataset/0018/Angry/*.wav')

In [456]:
happy_wavs = glob.glob('Emotion Speech Dataset/0018/Happy/*.wav')

In [550]:
wav_a = "/home/dcor/niskhizov/Prosody2Vec/Emotion Speech Dataset/0018/Sad/0018_001305.wav"
wav_b = 'Emotion Speech Dataset/0018/Angry/0018_000578.wav'
wav_c = 'Emotion Speech Dataset/0018/Happy/0018_000752.wav'
wav_d = "/home/dcor/niskhizov/Prosody2Vec/Emotion Speech Dataset/0018/Surprise/0018_001674.wav"

embed_a = extract_embedding(wav_a)
embed_b = extract_embedding(wav_b)
embed_c = extract_embedding(wav_c)
embed_d = extract_embedding(wav_d)


In [551]:

decoder = decoder.eval()

with torch.no_grad():
        

        o = decoder.generate(embed_a[0].unsqueeze(0).cuda(), embed_d[1].unsqueeze(0).cuda())
        

In [552]:
with torch.no_grad():
    target = hifigan(o.transpose(1, 2)).cpu()[0][0]

In [553]:
Audio(target,rate = 16000)

In [554]:
Audio(embed_a[-1],rate = 16000)

In [555]:
Audio(embed_b[-1],rate = 16000)

In [493]:
with torch.no_grad():
        

        o = decoder.generate(embed_a[0].unsqueeze(0).cuda(), embed_c[1].unsqueeze(0).cuda())
        

In [494]:
with torch.no_grad():
    target = hifigan(o.transpose(1, 2)).cpu()[0][0]

In [495]:
Audio(target,rate = 16000)


In [496]:
Audio(embed_a[-1],rate = 16000)

In [497]:
Audio(embed_b[-1],rate = 16000)