In [1]:
import torch, sys, time, os, math, tqdm, numpy, soundfile, time, pickle
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast

from model.layers import TacotronSTFT
from model.ast_model import ASTModel
from model.wavenet import WaveNet
from model.JCU_MSD import JCU_MSD
from model.losses import DAMSoftmax, PredictionLoss, generator_loss, \
                         discriminator_loss, get_mel_loss, get_fm_loss

from Dataset.dataset import Fbank_DataLoader
from scipy.io.wavfile import read

from tools import *
from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)

In [2]:
train_loader = Fbank_DataLoader('Dataset/filelist_Vox1.txt')  #Batch = 32

In [3]:
len(train_loader)

2704

In [2]:
scaler = torch.cuda.amp.GradScaler()

In [7]:
class ASVframework(nn.Module):
    def __init__(self,            
                 input_tdim = 500, 
                 emb_dim = 768,
                 class_nums = 1299,
                 hidden_nums = 256
                ):
        
        super(ASVframework, self).__init__()
        # optim 
        self.grad_acc_step = 1
        self.grad_clip_thresh = 1
        
        # FT
        self.target_lenght = input_tdim
        self.stft = TacotronSTFT(filter_length=1024,
                                 hop_length=256,
                                 win_length=1024,
                                 sampling_rate=22050,
                                 mel_fmin=0, mel_fmax=8000)      
        
        # ASV   
        self.encoder = ASTModel(label_dim=1, fshape=128, tshape=2, fstride=128, tstride=1,
                       input_fdim=128, input_tdim=input_tdim, model_size='base',
                       pretrain_stage=False, load_pretrained_mdl_path='save_model/SSAST-Base-Frame-400.pth').cuda()
        self.closs = DAMSoftmax(emb_dim, class_nums).cuda() 
        self.ploss = PredictionLoss(emb_dim, hidden_nums).cuda() 
        self.asv_opt  = torch.optim.Adam([{'params': self.encoder.parameters()},
                                          # {'params': self.closs.parameters()},
                                          # {'params': self.ploss.parameters()},
                                          ], 
                                          lr=1e-4, 
                                          betas=(0.95, 0.99), weight_decay = 2e-5)
   
        self.ft_opt = torch.optim.Adam(self.ploss.parameters(), lr=1e-4, 
                                       betas=(0.9, 0.99), weight_decay = 2e-5)
        # self.ft_opt = torch.optim.Adam([{'params': self.encoder.parameters(),'lr': 1e-3},
        #                                 {'params': self.ploss.parameters()}], lr=1e-4, 
        #                                 betas=(0.9, 0.99), weight_decay = 2e-5)
        # self.scheduler = torch.optim.lr_scheduler.StepLR(self.optim, step_size = test_step, gamma=lr_decay)
        
        # FRN 
        # self.gen = WaveNet(gin_channels=1, upsample_conditional_features=True).cuda()
        # self.JCUMSD = JCU_MSD().cuda()
        # self.genloss = generator_loss
        # self.disloss = discriminator_loss
        # self.specloss = get_mel_loss
        # self.fmloss = get_fm_loss
        # self.g_opt = torch.optim.Adam(self.gen.parameters(), lr=1e-4, 
        #                               betas=(0.9, 0.99), weight_decay = 2e-5)
        # self.d_opt = torch.optim.Adam(self.JCUMSD.parameters(), lr=1e-3, 
        #                               betas=(0.9, 0.99), weight_decay = 2e-5) 

    def load_wav_to_torch(self, full_path):
        """
        Loads wavdata into torch array
        """
        sampling_rate, data = read(full_path)
        return torch.from_numpy(data).float(), sampling_rate

    def get_mel(self, audio):
        audio_norm = audio / 32768.0
        audio_norm = audio_norm.unsqueeze(0)
        audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
        melspec = self.stft.mel_spectrogram(audio_norm)
        melspec = torch.squeeze(melspec, 0).T
        '''
         mel = (D, T)
        '''
        #cut and pad
        n_frames = melspec.shape[0]
        p = self.target_lenght - n_frames
        if p > 0:
            m = torch.nn.ZeroPad2d((0, 0, 0, p))
            melspec = m(melspec)
        elif p < 0:
            melspec = melspec[0:self.target_lenght, :]       
            
        return melspec
     
        
    def model_update(self, model, step, loss, optimizer):
        # Backward
        loss = (loss / self.grad_acc_step).backward()
        if step % self.grad_acc_step == 0:
            # Clipping gradients to avoid gradient explosion
            nn.utils.clip_grad_norm_(model.parameters(), self.grad_clip_thresh)

            # Update weights
            optimizer.step()
            optimizer.zero_grad()
            
            

    def save_parameters(self, path):
        torch.save(self.state_dict(), path)
        
    def load_parameters(self, path):
        self_state = self.state_dict()
        loaded_state = torch.load(path)
        for name, param in loaded_state.items():
            origname = name
            if name not in self_state:
                name = name.replace("module.", "")
                if name not in self_state:
                    print("%s is not in the model."%origname)
                    continue
            if self_state[name].size() != loaded_state[origname].size():
                print("Wrong parameter length: %s, model: %s, loaded: %s"%(origname, self_state[name].size(), loaded_state[origname].size()))
                continue
            self_state[name].copy_(param)   
            
            
    def train_network(self, epoch, train_loader):
        self.train()
        
        ## Update the learning rate based on the current epcoh
        # self.scheduler.step(epoch - 1)
        index, top1, top2, spk_loss, estimation_loss, \
        G_loss, D_loss = 0, 0, 0, 0, 0, 0, 0
        lr = self.asv_opt.param_groups[0]['lr']        
            
        for i, (fbank_s, fbank_o, label, a) in enumerate(train_loader):
            num = i+1
            fbank_s = fbank_s.cuda()
            fbank_o = fbank_o.cuda()
            label   = label.cuda()
            a       = a.unsqueeze(-1).cuda()
            recon_a = torch.zeros_like(a).cuda()
            
            # train ASV system
            self.asv_opt.zero_grad()
            # fp16
            with autocast():      
                spk_emb = self.encoder(fbank_s, task='ft_emb')
                cls_loss, acc = self.closs(spk_emb, a, label)           
                # loss_asv_total = cls_loss + pred_loss      
                
            # optm ASV
            scaler.scale(cls_loss).backward()
            scaler.step(self.asv_opt)
            scaler.update()
            
    
    
    
#             # optm prediction  
#             self.ft_opt.zero_grad()
#             with autocast(): 
#                 spk_emb = self.encoder(fbank_s, task='ft_emb')
#                 cls_loss, acc, _ = self.closs(spk_emb, label)  
#                 pred_loss, _     = self.ploss(spk_emb, a, label)
#                 # loss_asv_total = cls_loss + pred_loss  
            
#             scaler.scale(pred_loss).backward()
#             scaler.step(self.ft_opt)
#             scaler.update()

            
            
            
            # optm prediction  
#             self.asv_opt.zero_grad()
#             with autocast(): 
#                 spk_emb = self.encoder(fbank_s, task='ft_emb')
#                 cls_loss , acc     = self.closs(spk_emb, label)  
#                 pred_loss, acc2, _ = self.ploss(spk_emb, a)
#                 loss_asv_total = 10*cls_loss + pred_loss  
            
#             scaler.scale(loss_asv_total).backward()
#             scaler.step(self.asv_opt)
#             scaler.update()
    

        
#             # train FRN
#             # train D
#             fbank_recon = self.gen(fbank_s.transpose(2,1), g=-pred_a).transpose(2,1)
#             D_real_cond, D_real_uncond = self.JCUMSD(fbank_o, fbank_s, label)
#             D_fake_cond, D_fake_uncond = self.JCUMSD(fbank_recon.detach(), fbank_s, label)
        
#             loss_D = self.disloss(D_real_cond, D_real_uncond, D_fake_cond, D_fake_uncond)
#             # optm JCU-MSD
#             self.JCUMSD.zero_grad()
#             loss_D.backward()
#             self.d_opt.step()
                
#             # train G
#             fbank_recon = self.gen(fbank_s.transpose(2,1), g=-pred_a).transpose(2,1)
#             recon_emb = self.encoder(fbank_recon, task='ft_emb')
#             # fool asv system
#             loss_asv, _ = self.ploss(recon_emb, recon_a)
        
#             D_real_cond, D_real_uncond = self.JCUMSD(fbank_o, fbank_s, label)
#             D_fake_cond, D_fake_uncond = self.JCUMSD(fbank_recon, fbank_s, label)

#             loss_adv   = self.genloss(D_fake_cond, D_fake_uncond) 
#             loss_recon = self.specloss(fbank_recon, fbank_o)          
#             loss_FM    = self.fmloss(D_real_cond, D_real_uncond, D_fake_cond, D_fake_uncond)
#             lambda_FM  = loss_recon.item() / loss_FM.item()     
#             loss_total = loss_asv + loss_adv + loss_recon + lambda_FM*loss_FM   
#             # optm G
#             self.gen.zero_grad()
#             loss_recon.backward()
#             self.g_opt.step()      
        
            # viusalizations     
            index += len(label)
            top1 += acc
            # top2 += acc2
            spk_loss += cls_loss.detach().cpu().numpy()
            # estimation_loss += pred_loss.detach().cpu().numpy()
#             G_loss += loss_total.detach().cpu().numpy()
#             D_loss += loss_D.detach().cpu().numpy()
            
            # %.nf 保留n位小数
            sys.stderr.write(time.strftime("%m-%d %H:%M:%S") + \
            " [%2d] Lr: %5f, Training: %.2f%%, "  %(epoch, lr, 100 * (num / train_loader.__len__())) + \
            # " Spk_loss: %.5f, pred_loss: %.5f, ACC: %2.2f%%\r"     % (spk_loss/(num), estimation_loss/num, top1/index*len(label))) # + \
            " Spk_loss: %.5f,  ACC: %2.2f%%\r"      % (spk_loss/(num), top1/index*len(label)))          
            # " pred_loss: %.5f, ACC2: %2.2f%%\r"  %(estimation_loss/num, top2/index*len(label))) # + \ 
            # "G_loss: %.5f, D_loss: %.5f\r"   %( G_loss/(num), D_loss/(num)))  # + \ 
            sys.stderr.flush()
            
        sys.stdout.write("\n") 
        # return spk_loss/num, estimation_loss/num, G_loss/num, D_loss/num, lr, top1/index*len(label)
        
       
    
#     def eval_prediction_a(self, path, a):
#         self.eval()
#         file_list = os.listdir(path)
#         mse_err = []
#         with torch.no_grad():
#             for file in file_list:
#                 audio, sr = self.load_wav_to_torch(os.path.join(path, file))
#                 mel = self.get_mel(audio).unsqueeze(0).cuda()
#                 spk_emb = self.encoder(mel, task='ft_emb')
#                 a = torch.tensor([a]).unsqueeze(0).cuda()
#                 _, pred_alpha = self.ploss(spk_emb, a)
#                 loss = F.l1_loss(pred_alpha, a)    
#                 mse_err.append(loss)
                
#         return sum(mse_err)/len(file_list)
    
    
    
    def eval_prediction(self, path, a):
        self.eval()
        file_list = os.listdir(path)
        mse_err = []
        with torch.no_grad():
            for file in file_list:
                try:
                    audio, sr = self.load_wav_to_torch(os.path.join(path, file))
                    mel = self.get_mel(audio).unsqueeze(0).cuda()
                    spk_emb = self.encoder(mel, task='ft_emb')
                
                    if len(file.split('_'))>1:
                        a = torch.tensor([float(file.split('_')[1])]).unsqueeze(0).cuda()
                        _, pred_alpha = self.ploss(spk_emb, a)
                    else:
                        a = torch.tensor([a]).unsqueeze(0).cuda()
                        _, pred_alpha = self.ploss(spk_emb, a)
                    
                    loss = F.l1_loss(pred_alpha, a)    
                    mse_err.append(loss)
                    
                except:
                    print(file)
        
        
        if len(file_list)>0:
            return sum(mse_err)/len(file_list)
        
        else:
            print(path)
            
            
    def eval_network(self, eval_list, eval_path):
        self.eval()
        files = []
        embeddings = {}
        lines = open(eval_list).read().splitlines()
        for line in lines:
            files.append(line.split()[1])
            files.append(line.split()[2])
        setfiles = list(set(files))
        setfiles.sort()

        for idx, file in tqdm.tqdm(enumerate(setfiles), total = len(setfiles)):
            audio, _  = self.load_wav_to_torch(os.path.join(eval_path, file))
            # Full utterance
            # data_1 = torch.FloatTensor(numpy.stack([audio],axis=0)).cuda()
            mel_1 = self.get_mel(audio).unsqueeze(0).cuda()
            # Spliited utterance matrix
            # max_audio = 300 * 160 + 240
            # if audio.shape[0] <= max_audio:
            #     shortage = max_audio - audio.shape[0]
            #     audio = numpy.pad(audio, (0, shortage), 'wrap')
            # feats = []
            # startframe = numpy.linspace(0, audio.shape[0]-max_audio, num=5)
            # for asf in startframe:
            #     feats.append(audio[int(asf):int(asf)+max_audio])
            # feats = numpy.stack(feats, axis = 0).astype(numpy.float)
            # # data_2 = torch.FloatTensor(feats).cuda()
            # mel_2 = self.get_mel(feats).unsqueeze(0).cuda()
            # Speaker embeddings
            with torch.no_grad():
                embedding_1 = self.encoder.forward(mel_1, task='ft_emb')
                embedding_1 = F.normalize(embedding_1, p=2, dim=1)
                embedding_2 = self.encoder.forward(mel_1, task='ft_emb')
                embedding_2 = F.normalize(embedding_2, p=2, dim=1)
            embeddings[file] = [embedding_1, embedding_2]
        scores, labels  = [], []

        for line in lines:
            embedding_11, embedding_12 = embeddings[line.split()[1]]
            embedding_21, embedding_22 = embeddings[line.split()[2]]
            # Compute the scores
            score_1 = torch.mean(torch.matmul(embedding_11, embedding_21.T)) # higher is positive
            score_2 = torch.mean(torch.matmul(embedding_12, embedding_22.T))
            score = (score_1 + score_2) / 2
            score = score.detach().cpu().numpy()
            scores.append(score)
            labels.append(int(line.split()[0]))

        # Coumpute EER and minDCF
        EER = tuneThresholdfromScore(scores, labels, [1, 0.1])[1]
        fnrs, fprs, thresholds = ComputeErrorRates(scores, labels)
        minDCF, _ = ComputeMinDcf(fnrs, fprs, thresholds, 0.05, 1, 1)

        return EER, minDCF

    
    
#     def eval_network(self, test_loader): 
#         self.eval()
#         with torch.no_grad():
#             for i, (fbank_s, _, label, a) in enumerate(test_loader):
#                 num = i+1
#                 fbank_s = fbank_s.cuda()
#                 label   = label.cuda()
#                 a       = a.unsqueeze(-1).cuda()

#                 spk_emb = self.encoder(fbank_s, task='ft_emb')
#                 cls_loss, acc = self.closs(spk_emb, a, label)      
#                 pred_loss, _ = self.ploss(spk_emb.detach(), a, mode='eval')         

#                 index += len(label)
#                 top1 += acc
#                 spk_loss += cls_loss.detach().cpu().numpy()
#                 estimation_loss += pred_loss.detach().cpu().numpy()
                
#                 sys.stderr.write(time.strftime("%m-%d %H:%M:%S") + \
#                 " [%2d] Lr: %5f, Training: %.2f%%, "  %(epoch, lr, 100 * (num / train_loader.__len__())) + \
#                 " Spk_loss: %.5f, pred_loss: %.5f, ACC: %2.2f%%\r"     % (spk_loss/(num), estimation_loss/num, top1/index*len(label))) # + \     
            
#                 sys.stderr.flush()       
#             sys.stdout.write("\n") 
            
        

In [8]:
framework = ASVframework()

now load a SSL pretrained models from save_model/SSAST-Base-Frame-400.pth
pretraining patch split stride: frequency=128, time=2
pretraining patch shape: frequency=128, time=2
pretraining patch array dimension: frequency=1, time=512
pretraining number of patches=512
fine-tuning patch split stride: frequncey=128, time=1
fine-tuning number of patches=499


In [9]:
framework.load_parameters('save_model/model_Vox1_k128_damsoft0030.model')

In [10]:
EER, minDCF = framework.eval_network('Dataset/filelist_Vox1_test.txt', '/root/autodl-tmp/vox1_test/wav')

100%|██████████| 4556/4556 [47:29<00:00,  1.60it/s]


In [11]:
EER

36.8359454508447

In [12]:
minDCF

0.8783329940972939

In [None]:
for i in range(100):
    # i = 0
    framework.train_network(i, train_loader)
    if i % 3 == 0:
        framework.save_parameters('save_model' + "/model_Vox1_k128_damsoft%04d.model"%(i+21))
    
    # i += 1

01-03 22:59:46 [ 0] Lr: 0.000100, Training: 100.00%,  Spk_loss: 36.75592,  ACC: 1.36%




01-04 00:36:19 [ 1] Lr: 0.000100, Training: 100.00%,  Spk_loss: 29.57802,  ACC: 9.75%




01-04 02:12:32 [ 2] Lr: 0.000100, Training: 100.00%,  Spk_loss: 21.78222,  ACC: 23.45%




01-04 03:48:32 [ 3] Lr: 0.000100, Training: 100.00%,  Spk_loss: 15.47819,  ACC: 39.02%




01-04 05:25:16 [ 4] Lr: 0.000100, Training: 100.00%,  Spk_loss: 11.00468,  ACC: 52.96%




01-04 07:02:15 [ 5] Lr: 0.000100, Training: 100.00%,  Spk_loss: 7.89179,  ACC: 63.99%




01-04 08:38:19 [ 6] Lr: 0.000100, Training: 100.00%,  Spk_loss: 5.75701,  ACC: 72.18%




01-04 10:14:17 [ 7] Lr: 0.000100, Training: 100.00%,  Spk_loss: 4.27659,  ACC: 78.18%




01-04 11:45:28 [ 8] Lr: 0.000100, Training: 100.00%,  Spk_loss: 3.27719,  ACC: 82.40%




01-04 13:13:51 [ 9] Lr: 0.000100, Training: 100.00%,  Spk_loss: 2.54629,  ACC: 85.68%




01-04 14:42:08 [10] Lr: 0.000100, Training: 100.00%,  Spk_loss: 2.06091,  ACC: 87.94%




01-04 16:00:19 [11] Lr: 0.000100, Training: 100.00%,  Spk_loss: 1.69758,  ACC: 89.71%




01-04 17:21:08 [12] Lr: 0.000100, Training: 100.00%,  Spk_loss: 1.42103,  ACC: 91.06%




01-04 18:41:27 [13] Lr: 0.000100, Training: 100.00%,  Spk_loss: 1.21485,  ACC: 92.09%




01-04 20:02:03 [14] Lr: 0.000100, Training: 100.00%,  Spk_loss: 1.05968,  ACC: 92.85%




01-04 21:22:11 [15] Lr: 0.000100, Training: 100.00%,  Spk_loss: 0.91572,  ACC: 93.69%




01-04 22:42:20 [16] Lr: 0.000100, Training: 100.00%,  Spk_loss: 0.81393,  ACC: 94.20%




01-04 23:57:54 [17] Lr: 0.000100, Training: 100.00%,  Spk_loss: 0.72627,  ACC: 94.65%




01-05 01:17:48 [18] Lr: 0.000100, Training: 100.00%,  Spk_loss: 0.65320,  ACC: 95.02%




01-05 02:36:52 [19] Lr: 0.000100, Training: 100.00%,  Spk_loss: 0.58648,  ACC: 95.44%




01-05 03:56:57 [20] Lr: 0.000100, Training: 100.00%,  Spk_loss: 0.53774,  ACC: 95.71%




01-05 05:16:50 [21] Lr: 0.000100, Training: 100.00%,  Spk_loss: 0.49248,  ACC: 95.95%




01-05 05:37:15 [22] Lr: 0.000100, Training: 25.41%,  Spk_loss: 0.46330,  ACC: 96.18%

In [None]:
for i in range(100):
    # i = 0
    framework.train_network(i, train_loader)
    if i % 3 == 0:
        framework.save_parameters('save_model' + "/model_Vox1_k128_damsoft%04d.model"%(i+21))
    
    # i += 1

01-05 09:02:15 [ 0] Lr: 0.000100, Training: 100.00%,  Spk_loss: 0.44813,  ACC: 96.22%




01-05 10:22:28 [ 1] Lr: 0.000100, Training: 100.00%,  Spk_loss: 0.40733,  ACC: 96.48%




01-05 11:43:24 [ 2] Lr: 0.000100, Training: 100.00%,  Spk_loss: 0.38142,  ACC: 96.69%




01-05 13:04:03 [ 3] Lr: 0.000100, Training: 100.00%,  Spk_loss: 0.35788,  ACC: 96.87%




01-05 14:23:06 [ 4] Lr: 0.000100, Training: 100.00%,  Spk_loss: 0.33918,  ACC: 97.04%




01-05 15:37:45 [ 5] Lr: 0.000100, Training: 100.00%,  Spk_loss: 0.30923,  ACC: 97.30%




01-05 16:52:25 [ 6] Lr: 0.000100, Training: 100.00%,  Spk_loss: 0.29039,  ACC: 97.43%




01-05 18:07:25 [ 7] Lr: 0.000100, Training: 100.00%,  Spk_loss: 0.27598,  ACC: 97.56%




01-05 19:17:52 [ 8] Lr: 0.000100, Training: 100.00%,  Spk_loss: 0.26592,  ACC: 97.65%




01-05 20:27:10 [ 9] Lr: 0.000100, Training: 100.00%,  Spk_loss: 0.26302,  ACC: 97.71%




01-05 20:41:07 [10] Lr: 0.000100, Training: 19.79%,  Spk_loss: 0.26099,  ACC: 97.73%

In [None]:
for i in range(100):
    # i = 0
    framework.train_network(i, train_loader)
    if i % 3 == 0:
        framework.save_parameters('save_model' + "/model_Vox1_k128_aamsoft_impact_ref%04d.model"%(i))
    
    # i += 1

12-18 23:18:03 [ 0] Lr: 0.000100, Training: 100.00%,  Spk_loss: 0.23555, pred_loss: 4.82882, ACC: 97.80%




12-19 00:26:55 [ 1] Lr: 0.000100, Training: 100.00%,  Spk_loss: 0.23555, pred_loss: 3.93283, ACC: 97.80%




12-19 01:33:31 [ 2] Lr: 0.000100, Training: 100.00%,  Spk_loss: 0.23555, pred_loss: 3.37557, ACC: 97.80%




12-19 07:18:04 [ 7] Lr: 0.000100, Training: 100.00%,  Spk_loss: 0.23555, pred_loss: 2.10564, ACC: 97.80%




12-19 08:26:17 [ 8] Lr: 0.000100, Training: 100.00%,  Spk_loss: 0.23555, pred_loss: 1.97830, ACC: 97.80%




12-19 09:33:09 [ 9] Lr: 0.000100, Training: 99.86%,  Spk_loss: 0.23580, pred_loss: 1.86810, ACC: 97.79%

In [None]:
1

In [None]:
for i in range(100):
    # i = 0
    framework.train_network(i, train_loader)
    if i % 3 == 0:
        framework.save_parameters('save_model' + "/model_Vox1_k128_jft_aamsoft%04d.model"%(i))
    
    # i += 1

12-17 20:31:47 [ 0] Lr: 0.000100, Training: 100.00%,  Spk_loss: 36.70421, pred_loss: 0.66339, ACC: 1.09%




12-17 21:38:46 [ 1] Lr: 0.000100, Training: 100.00%,  Spk_loss: 29.73396, pred_loss: 0.57681, ACC: 8.57%




12-17 22:45:03 [ 2] Lr: 0.000100, Training: 100.00%,  Spk_loss: 21.86984, pred_loss: 0.59950, ACC: 21.87%




12-17 23:51:26 [ 3] Lr: 0.000100, Training: 100.00%,  Spk_loss: 15.24692, pred_loss: 0.57982, ACC: 37.70%




12-18 00:58:16 [ 4] Lr: 0.000100, Training: 100.00%,  Spk_loss: 10.47455, pred_loss: 0.56114, ACC: 52.32%




12-18 02:06:19 [ 5] Lr: 0.000100, Training: 100.00%,  Spk_loss: 7.22164, pred_loss: 0.54357, ACC: 64.15%




12-18 03:12:44 [ 6] Lr: 0.000100, Training: 100.00%,  Spk_loss: 5.07079, pred_loss: 0.52807, ACC: 72.88%




IOPub message rate exceeded.0100, Training: 28.81%,  Spk_loss: 4.02351, pred_loss: 0.51706, ACC: 77.39%
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

12-18 07:39:47 [10] Lr: 0.000100, Training: 100.00%,  Spk_loss: 1.52531, pred_loss: 0.47363, ACC: 89.67%




12-18 08:46:19 [11] Lr: 0.000100, Training: 100.00%,  Spk_loss: 1.21968, pred_loss: 0.46459, ACC: 91.40%




12-18 09:52:48 [12] Lr: 0.000100, Training: 100.00%,  Spk_loss: 1.00200, pred_loss: 0.45941, ACC: 92.67%




12-18 10:59:07 [13] Lr: 0.000100, Training: 100.00%,  Spk_loss: 0.85502, pred_loss: 0.45211, ACC: 93.61%




12-18 12:05:39 [14] Lr: 0.000100, Training: 100.00%,  Spk_loss: 0.74324, pred_loss: 0.44568, ACC: 94.26%




12-18 13:12:17 [15] Lr: 0.000100, Training: 100.00%,  Spk_loss: 0.65091, pred_loss: 0.44202, ACC: 94.90%




12-18 13:39:32 [16] Lr: 0.000100, Training: 41.05%,  Spk_loss: 0.59690, pred_loss: 0.43926, ACC: 95.27%

In [None]:
1

In [None]:
for i in range(100):
    # i = 0
    framework.train_network(i, train_loader)
    if i % 3 == 0:
        framework.save_parameters('save_model' + "/model_Vox1_k128_jft_aamsoft%04d.model"%(i))
    
    # i += 1

12-17 18:21:34 [ 0] Lr: 0.000100, Training: 100.00%,  Spk_loss: 14.16132, pred_loss: 0.29113, ACC: 2.85%




12-17 19:22:11 [ 1] Lr: 0.000100, Training: 91.16%,  Spk_loss: 6.78183, pred_loss: 0.12628, ACC: 9.20%

In [None]:
for i in range(100):
    # i = 0
    framework.train_network(i, train_loader)
    if i % 3 == 0:
        framework.save_parameters('save_model' + "/model_Vox1_k128_aamsoft%04d.model"%(i))
    
    # i += 1

12-13 23:50:44 [ 0] Lr: 0.000100, Training: 100.00%,  Spk_loss: 35.90853, ACC: 1.75%




12-14 00:57:51 [ 1] Lr: 0.000100, Training: 100.00%,  Spk_loss: 28.21510, ACC: 11.27%




12-14 02:04:17 [ 2] Lr: 0.000100, Training: 100.00%,  Spk_loss: 20.47209, ACC: 25.51%




12-14 03:10:49 [ 3] Lr: 0.000100, Training: 100.00%,  Spk_loss: 14.30644, ACC: 41.32%




12-14 04:17:16 [ 4] Lr: 0.000100, Training: 100.00%,  Spk_loss: 9.91363, ACC: 55.42%




12-14 05:23:40 [ 5] Lr: 0.000100, Training: 100.00%,  Spk_loss: 6.92692, ACC: 66.48%




12-14 06:30:11 [ 6] Lr: 0.000100, Training: 100.00%,  Spk_loss: 4.94606, ACC: 74.47%




12-14 07:36:29 [ 7] Lr: 0.000100, Training: 100.00%,  Spk_loss: 3.62408, ACC: 80.20%




12-14 08:42:48 [ 8] Lr: 0.000100, Training: 100.00%,  Spk_loss: 2.76512, ACC: 84.13%




12-14 09:49:13 [ 9] Lr: 0.000100, Training: 100.00%,  Spk_loss: 2.14160, ACC: 87.14%




12-14 10:55:29 [10] Lr: 0.000100, Training: 100.00%,  Spk_loss: 1.71462, ACC: 89.24%




12-14 12:01:47 [11] Lr: 0.000100, Training: 100.00%,  Spk_loss: 1.42189, ACC: 90.68%




12-14 12:01:59 [12] Lr: 0.000100, Training: 0.15%,  Spk_loss: 1.38264, ACC: 91.32%

In [None]:
for i in range(100):
    # i = 0
    framework.train_network(i, train_loader)
    if i % 4 == 0:
        framework.save_parameters('save_model' + "/model_Vox1_k128_aamsoft%04d.model"%(i))
    
    # i += 1

12-14 22:28:51 [ 0] Lr: 0.000100, Training: 100.00%,  Spk_loss: 1.01007, ACC: 92.88%




12-14 23:36:16 [ 1] Lr: 0.000100, Training: 100.00%,  Spk_loss: 0.87538, ACC: 93.55%




12-15 00:43:27 [ 2] Lr: 0.000100, Training: 100.00%,  Spk_loss: 0.76428, ACC: 94.20%




12-15 01:49:45 [ 3] Lr: 0.000100, Training: 100.00%,  Spk_loss: 0.65908, ACC: 94.81%




12-15 02:56:11 [ 4] Lr: 0.000100, Training: 100.00%,  Spk_loss: 0.59127, ACC: 95.20%




12-15 04:03:00 [ 5] Lr: 0.000100, Training: 100.00%,  Spk_loss: 0.50913, ACC: 95.68%




12-15 05:09:33 [ 6] Lr: 0.000100, Training: 100.00%,  Spk_loss: 0.46425, ACC: 95.88%




12-15 06:15:48 [ 7] Lr: 0.000100, Training: 100.00%,  Spk_loss: 0.41565, ACC: 96.19%




12-15 07:22:41 [ 8] Lr: 0.000100, Training: 100.00%,  Spk_loss: 0.37885, ACC: 96.47%




12-15 08:29:31 [ 9] Lr: 0.000100, Training: 100.00%,  Spk_loss: 0.34227, ACC: 96.73%




12-15 09:35:52 [10] Lr: 0.000100, Training: 100.00%,  Spk_loss: 0.31826, ACC: 96.93%




12-15 10:42:07 [11] Lr: 0.000100, Training: 100.00%,  Spk_loss: 0.28970, ACC: 97.23%




12-15 11:49:03 [12] Lr: 0.000100, Training: 100.00%,  Spk_loss: 0.28362, ACC: 97.32%




12-15 11:49:26 [13] Lr: 0.000100, Training: 0.44%,  Spk_loss: 0.29220, ACC: 97.40%

In [8]:
# opt pred
for i in range(100):
    # i = 0
    framework.train_network(i, train_loader)
    if i % 4 == 0:
        framework.save_parameters('save_model' + "/model_Vox1_k128_aamsoft%04d.model"%(i))
    
    # i += 1

12-15 18:10:55 [ 0] Lr: 0.000100, Training: 100.00%,  pred_loss: 3.89200, ACC: 97.80%




12-15 19:17:32 [ 1] Lr: 0.000100, Training: 100.00%,  pred_loss: 3.07436, ACC: 97.80%




12-15 20:24:06 [ 2] Lr: 0.000100, Training: 100.00%,  pred_loss: 2.77466, ACC: 97.80%




12-15 20:38:04 [ 3] Lr: 0.000100, Training: 20.15%,  pred_loss: 2.62424, ACC: 97.40%

KeyboardInterrupt: 

#### DAMSoftmax

In [283]:
s = 64
m = 0.5
c = 2
k = 17
eps = 1e-6
in_features = 768
out_features = 999

In [264]:
weight = nn.Parameter(torch.FloatTensor(k, in_features, out_features))
nn.init.xavier_uniform_(weight)
loss_fn = nn.CrossEntropyLoss()
factor = torch.tensor([8.1, 3.9, -8.1]).float().unsqueeze(-1)

In [265]:
# func_a = (m - 0.1*torch.pow(c, (factor/12)))
func_a = m
threshold = math.pi - func_a

In [266]:
func_a

0.35

In [267]:
threshold

2.791592653589793

In [268]:
inputs = torch.randn([3, 768]).float()

In [269]:
label = torch.empty(3, dtype = torch.long).random_(100)

In [270]:
label

tensor([45, 60, 46])

In [271]:
cos_theta = torch.bmm(
            F.normalize(inputs).unsqueeze(0).expand(k, *inputs.shape),  # k*b*f
            F.normalize(weight, dim=1),  # normalize in_features dim   # k*f*c
        )  # k*b*c

In [272]:
cos_theta = torch.max(cos_theta, dim=0)[0]  # b*c

In [273]:
cos_theta

tensor([[0.0897, 0.0811, 0.0832,  ..., 0.0789, 0.1000, 0.0571],
        [0.0481, 0.0675, 0.0630,  ..., 0.0747, 0.0484, 0.0512],
        [0.1124, 0.1120, 0.0558,  ..., 0.0740, 0.1463, 0.0433]],
       grad_fn=<MaxBackward0>)

In [274]:
theta = torch.acos(torch.clamp(cos_theta, -1.0 + eps, 1.0 - eps))

In [275]:
theta

tensor([[1.4809, 1.4896, 1.4875,  ..., 1.4918, 1.4706, 1.5137],
        [1.5227, 1.5032, 1.5078,  ..., 1.4960, 1.5223, 1.5195],
        [1.4582, 1.4586, 1.5150,  ..., 1.4967, 1.4240, 1.5274]],
       grad_fn=<AcosBackward0>)

In [276]:
one_hot = torch.zeros_like(cos_theta)
one_hot.scatter_(1, label.unsqueeze(0).T.long(), 1)

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [277]:
selected = torch.where(theta > threshold, torch.zeros_like(one_hot), one_hot)

In [278]:
selected

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [279]:
logits = torch.cos(torch.where(selected.bool(), theta + func_a, theta))

In [280]:
logits

tensor([[0.0897, 0.0811, 0.0832,  ..., 0.0789, 0.1000, 0.0571],
        [0.0481, 0.0675, 0.0630,  ..., 0.0747, 0.0484, 0.0512],
        [0.1124, 0.1120, 0.0558,  ..., 0.0740, 0.1463, 0.0433]],
       grad_fn=<CosBackward0>)

In [281]:
logits *= s

In [282]:
logits[0][45]

tensor(-17.4749, grad_fn=<SelectBackward0>)

In [159]:

loss = loss_fn(logits, label)
        
# prec1 = accuracy(logits.detach(), label.detach(), topk=(1,))[0]

    

In [160]:
loss

tensor(46.8266, grad_fn=<NllLossBackward0>)

### Eval

In [84]:
audio1, sr1 = framework.load_wav_to_torch('/root/autodl-tmp/Dataset/ASV/Vox1/train_scaled/id10044/tEiC2TFawPM/00006.wav')
audio2, sr2 = framework.load_wav_to_torch('/root/autodl-tmp/Dataset/ASV/Vox1/train_scaled/id10044/tEiC2TFawPM/00005_11.591_.wav')

In [85]:
mel1 = framework.get_mel(audio1).unsqueeze(0).cuda()
mel2 = framework.get_mel(audio2).unsqueeze(0).cuda()

In [86]:
spk_emb1 = framework.encoder(mel1, task='ft_emb')
spk_emb2 = framework.encoder(mel2, task='ft_emb')

In [87]:
simi = nn.CosineSimilarity(dim=-1)

In [88]:
cos_sim = simi(spk_emb1, spk_emb2)

In [89]:
cos_sim

tensor([0.0779], device='cuda:0', grad_fn=<SumBackward1>)

In [12]:
a1 = torch.tensor([-11.512]).float().unsqueeze(0).cuda()
a2 = torch.tensor([0.0]).float().unsqueeze(0).cuda()

In [13]:
pred_loss1, pred_a1 = framework.ploss(spk_emb1, a1)     
pred_loss2, pred_a2 = framework.ploss(spk_emb2, a2)     

In [14]:
print(pred_loss1, pred_a1)

tensor(1.6396, device='cuda:0', grad_fn=<SmoothL1LossBackward0>) tensor([[-9.3724]], device='cuda:0', grad_fn=<MulBackward0>)


In [15]:
print(pred_loss2, pred_a2 )

tensor(3.1937, device='cuda:0', grad_fn=<SmoothL1LossBackward0>) tensor([[3.6937]], device='cuda:0', grad_fn=<MulBackward0>)


In [90]:
spk_id1=torch.tensor([44]).long().cuda()
spk_id2=torch.tensor([44]).long().cuda()

In [91]:
closs1, acc1, logit1 = framework.closs(spk_emb1, spk_id1)
theta_n1 = framework.closs.forward_(spk_emb1)
closs2, acc2, logit2 = framework.closs(spk_emb2, spk_id2)
theta_n2 = framework.closs.forward_(spk_emb2)

In [93]:
acc1

tensor([100.], device='cuda:0')

In [37]:
logit1[:,27]

tensor([6.4253], device='cuda:0', grad_fn=<SelectBackward0>)

In [13]:
theta_n1[:, 6]

tensor([-1.1842], device='cuda:0', grad_fn=<SelectBackward0>)

In [14]:
theta_n2[:,6]

tensor([-1.1029], device='cuda:0', grad_fn=<SelectBackward0>)

In [23]:
logit[:,38]

tensor([38.4305], device='cuda:0', grad_fn=<SelectBackward0>)

In [24]:
label = F.softmax(logit, dim=-1)

In [26]:
label[:, 38]

tensor([1.0000], device='cuda:0', grad_fn=<SelectBackward0>)

In [22]:
theta_n

tensor([[-0.1277, -0.2082, -0.1075, -0.1524, -0.1981, -1.1253, -0.1961, -0.1244,
         -0.1468, -0.1429, -0.2088, -0.2758, -0.2269, -0.1990, -0.2318, -0.1773,
         -0.2411, -0.1388, -0.1311, -0.1614, -0.1783, -0.1951, -0.1518, -0.1650,
         -0.1961, -0.1863, -0.1781, -0.2543, -0.1587, -0.1361, -0.1094, -0.1182,
         -0.2540, -0.2285, -0.1894, -0.1634, -0.1488, -0.2225, -0.2157, -0.2639,
         -0.1566, -0.1488, -0.1758, -0.1054, -0.2924, -0.2266, -0.3419, -0.1714,
         -0.2621, -0.1709, -0.2408, -0.1810, -0.1785, -0.1999, -0.1231, -0.2149,
         -0.1401, -0.1771, -0.1794, -0.2929, -0.1307, -0.1406, -0.1548, -0.2056,
         -0.1716, -0.1192, -0.1645, -0.1361, -0.1850, -0.1706, -0.1471, -0.2651,
         -0.1673, -0.1561, -0.2333, -0.1830, -0.2413, -0.1460, -0.1618, -0.1394,
         -0.2380, -0.2417, -0.1471, -0.2011, -0.3049, -0.2327, -0.1736, -0.1780,
         -0.2123, -0.1151, -0.1630, -0.1581, -0.1308, -0.2459, -0.2080, -0.2723,
         -0.1845, -0.1451, -

In [None]:
filepath = 'F:/datasets/project3/AISHELL-3/scaled'
spk_list = os.listdir(filepath)
loss = []
for spk in spk_list:
    pred_loss = framework.eval_prediction(os.path.join(filepath, spk), spk)
    loss.append(pred_loss)

loss = sum(loss)
loss /= len(spk_list)

In [None]:
loss

In [None]:
loss

In [None]:
sum(loss) / 60

tensor(2.2414, device='cuda:0')

In [None]:
pred_loss = framework.eval_prediction('F:/datasets/project3/AISHELL-1/unseen_scaled/S0736/')

In [None]:
pred_loss

tensor(2.0710, device='cuda:0')

In [None]:
filepath = 'F:/datasets/project3/AISHELL-1/unseen_scaled/S0736/'

def eval_prediction(path):
    framework.eval()
    file_list = os.listdir(path)
    mse_err = []
    for file in file_list:
        audio, sr = framework.load_wav_to_torch(os.path.join(path, file))
        mel = framework.get_mel(audio).unsqueeze(0).cuda()
        spk_emb = framework.encoder(mel, task='ft_emb')
        a = torch.tensor(float(file.split('_')[1])).cuda()
        loss, pred_alpha = framework.ploss(spk_emb, a)
        mse_err.append(pow(loss, 0.5))
     
    return sum(mse_err)/len(file_list)

In [None]:
loss = eval_prediction(filepath)

  return F.mse_loss(input, target, reduction=self.reduction)


OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 23.99 GiB total capacity; 22.55 GiB already allocated; 0 bytes free; 23.00 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
audio, sr = framework.load_wav_to_torch('F:/datasets/project3/AISHELL-1/unseen_scaled/S0736/BAC009S0736W0122_-6.952_.wav')
mel = framework.get_mel(audio).unsqueeze(0).cuda()

In [None]:
file = 'BAC009S0736W0122_-6.952_.wav'

In [None]:
a = torch.tensor(float(file.split('_')[1])).cuda()

In [None]:
a.type()

'torch.cuda.FloatTensor'

In [None]:
float(file.split('_')[1])

-6.952

In [None]:
framework.eval()
spk_emb = framework.encoder(mel, task='ft_emb')

In [None]:
a = torch.tensor(-7).float().cuda()

In [None]:
a

tensor(-7., device='cuda:0')

In [None]:
# spk_id = framework.closs(spk_emb)
loss, pred_alpha = framework.ploss(spk_emb, a)

In [None]:
loss

tensor(0.0955, device='cuda:0', grad_fn=<MseLossBackward0>)

In [None]:
pred_alpha

tensor([[-7.3091]], device='cuda:0', grad_fn=<MulBackward0>)

In [6]:
framework.save_parameters('save_model' + "/model_%04d.model"%13)

In [None]:
while(1):
	## Training for one epoch
	loss, lr, acc = s.train_network(epoch = epoch, loader = trainLoader)

	## Evaluation every [test_step] epochs
	if epoch % args.test_step == 0:
		s.save_parameters(args.model_save_path + "/model_%04d.model"%epoch)
		EERs.append(s.eval_network(eval_list = args.eval_list, eval_path = args.eval_path)[0])
		print(time.strftime("%Y-%m-%d %H:%M:%S"), "%d epoch, ACC %2.2f%%, EER %2.2f%%, bestEER %2.2f%%"%(epoch, acc, EERs[-1], min(EERs)))
		score_file.write("%d epoch, LR %f, LOSS %f, ACC %2.2f%%, EER %2.2f%%, bestEER %2.2f%%\n"%(epoch, lr, loss, acc, EERs[-1], min(EERs)))
		score_file.flush()

	if epoch >= args.max_epoch:
		quit()

	epoch += 1

In [None]:
epoch=40
lr=1e-04
num=30
top1=90
index=30
spk_loss=26
estimation_loss=48
G_loss=32
D_loss=29

In [None]:
sys.stderr.write(time.strftime("%m-%d %H:%M:%S") + \
" [%2d] Lr: %5f, Training: %.2f%%, "    %(epoch, lr, 100 * (num / 7481)) + \
" Spk_loss: %.5f, ACC: %2.2f%%, "      %(spk_loss/(num), top1/index*32) + \
" pred_loss: %.5f, G_loss: %.5f, D_loss: %.5f\r"   %(estimation_loss/num, G_loss/(num), D_loss/(num)))
sys.stderr.flush()

12-03 23:36:11 [40] Lr: 0.000100, Training: 0.40%,  Spk_loss: 0.86667, ACC: 96.00%,  pred_loss: 1.60000, G_loss: 1.06667, D_loss: 0.96667