# Импорт

In [31]:
import json
import os
import torch
from tqdm import tqdm
import scipy.io.wavfile as wav

from torch import nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import Dataset, DataLoader, Sampler
from torch.utils.tensorboard import SummaryWriter
from boltons.fileutils import iter_find_files
import numpy as np

In [32]:
import transformers
import json
from torch.utils.data import Dataset, DataLoader, Sampler
import os

In [33]:
from models import ConvFeatureEncoder, SegmentsRepr, SegmentsEncoder, NegativeSampler, SegmentPredictor, FinModel
from utils import ConstrativeLoss, sample_negatives
from collections import OrderedDict
import shutil

In [34]:
from model_transformers import SegmentTransformer
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

In [35]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
torch.multiprocessing.set_sharing_strategy('file_system')
from tqdm import tqdm
import numpy as np
import os
from os.path import join, basename
from boltons.fileutils import iter_find_files
import soundfile as sf
import librosa
import pickle
from multiprocessing import Pool
import random
import torchaudio
import math
from torchaudio.datasets import LIBRISPEECH

In [36]:
# Seed torch's global RNG (this drives, e.g., the shuffling order of train_loader).
# NOTE(review): numpy and `random` are imported but not seeded here.
torch.manual_seed(0)

<torch._C.Generator at 0x7fd0a43590f0>

In [37]:
# Record which transformers version this notebook was executed with.
transformers.__version__

'4.14.0'

In [38]:
# %pip list

In [39]:
# This class is based on https://github.com/felixkreuk/UnsupSeg/blob/master/dataloader.py

class WavPhnDataset(Dataset):
    """Dataset of ``*.wav`` utterances paired with same-named ``*.txt`` label files.

    Every ``.wav`` found by ``iter_find_files`` under ``path`` must have a sibling
    ``.txt`` file. The first whitespace-separated column of each non-blank line is
    read as a time stamp (in seconds, judging by the sampling-rate conversion).
    The last line's time is dropped; presumably it marks the utterance end rather
    than a boundary (TODO confirm against the label format).
    """

    def __init__(self, path):
        self.path = path
        self.data = list(iter_find_files(self.path, "*.wav"))
        super(WavPhnDataset, self).__init__()

    @staticmethod
    def get_datasets(path):
        raise NotImplementedError

    @staticmethod
    def _label_path(wav_path):
        """Return the label-file path for ``wav_path``: same stem, ``.txt`` extension."""
        # splitext only touches the final extension; str.replace would also rewrite
        # any '.wav' that happens to occur in a directory name.
        return os.path.splitext(wav_path)[0] + ".txt"

    @staticmethod
    def _read_times(phn_path):
        """Parse the first column of every non-blank line of ``phn_path`` as a float."""
        with open(phn_path, "r") as f:
            # float() instead of eval(): only numbers are expected, and eval would
            # execute arbitrary expressions found in the file.
            return [float(cols[0]) for cols in (line.split() for line in f) if cols]

    def process_file(self, wav_path):
        """Load one utterance and its reference boundaries.

        Returns:
            (audio, times, sample_positions, wav_path, filetext_id, phn_path), where
            ``audio`` is the first channel of the waveform, ``times`` the boundary
            times as a FloatTensor, ``sample_positions`` the same boundaries
            multiplied by the file's sampling rate, and ``filetext_id`` the
            label file's base name (used as the utterance id).
        """
        phn_path = self._label_path(wav_path)
        filetext_id = os.path.basename(phn_path)

        # load audio; keep only the first channel
        audio, sr = torchaudio.load(wav_path)
        audio = audio[0]

        # load reference boundaries (last entry dropped, see class docstring)
        boundary_secs = self._read_times(phn_path)[:-1]
        times = torch.FloatTensor(boundary_secs)
        # Use the file's real sampling rate instead of assuming 16 kHz.
        sample_positions = torch.FloatTensor([t * sr for t in boundary_secs])

        return audio, times, sample_positions, wav_path, filetext_id, phn_path

    def __getitem__(self, idx):
        """Return one utterance as the dict consumed by ``collate_fn``."""
        signal, seg, _sample_positions, fname, filetext_id, segment_file = self.process_file(self.data[idx])

        return {'audio_file': fname,
                'segment_file': segment_file,
                'id': filetext_id,
                'sample': signal,
                'length': len(signal), 'boundaries': seg}

    def __len__(self):
        return len(self.data)

In [40]:
def collate_fn(samples):
    """Right-pad a list of dataset items into one batch.

    Args:
        samples: list of dicts as produced by ``WavPhnDataset.__getitem__``. Each
            needs a 1-D ``'sample'`` waveform of ``'length'`` values, plus
            ``'boundaries'``, ``'id'``, ``'audio_file'`` and ``'segment_file'``.

    Returns:
        dict with
          batch: float32 tensor (batch, max_length), zero-padded on the right;
          lengths: float tensor of the unpadded lengths;
          attention_mask: float32 tensor (batch, max_length), 1 on real samples
              and 0 on padding;
          boundaries, ids, audio_files, segment_files: per-item lists in input order.
    """
    max_length = max(sample['length'] for sample in samples)
    batch = torch.zeros(len(samples), max_length)
    attention_mask = torch.zeros(len(samples), max_length)

    for row, sample in enumerate(samples):
        n_valid = sample['length']
        # Copy into a preallocated buffer; converting the waveform to a Python list
        # (one object per audio sample) and back is orders of magnitude slower.
        batch[row, :n_valid] = torch.as_tensor(sample['sample'], dtype=torch.float32)
        attention_mask[row, :n_valid] = 1

    return dict(batch=batch,
                lengths=torch.Tensor([sample['length'] for sample in samples]),
                attention_mask=attention_mask,
                boundaries=[sample['boundaries'] for sample in samples],
                ids=[sample['id'] for sample in samples],
                audio_files=[sample['audio_file'] for sample in samples],
                segment_files=[sample['segment_file'] for sample in samples])

# Загрузка данных

In [13]:
# One dataset per Buckeye split; each scans its split folder for *.wav files.
train_dataset, val_dataset, test_dataset = (
    WavPhnDataset('Buckeye_fin/' + split_dir) for split_dir in ('Train', 'Valid', 'Test')
)

Buckeye_fin/Train
38656
Buckeye_fin/Valid
4267
Buckeye_fin/Test
4015


In [14]:
# Loaders of 8 zero-padded utterances each; only the train split is shuffled.
train_loader, val_loader, test_loader = (
    DataLoader(split_ds, shuffle=do_shuffle, batch_size=8, collate_fn=collate_fn)
    for split_ds, do_shuffle in ((train_dataset, True), (val_dataset, False), (test_dataset, False))
)

In [15]:
# Root output folder; per-split prediction files are written under <path_results>/<split>/.
path_to_save = 'save_results_path_buckeye'
path_results = os.path.join(path_to_save, 'Results')

In [16]:
# Create the output tree; exist_ok keeps this cell safe to re-run.
os.makedirs(path_to_save, exist_ok=True) 
os.makedirs(path_results, exist_ok=True) 

In [18]:
import warnings
# NOTE(review): this silences *every* warning for the rest of the session,
# including deprecation and numerical warnings; consider narrowing it with
# category=/module= once the noisy warnings are identified.
warnings.filterwarnings("ignore")

In [20]:
from transformers_f.src.transformers.activations import ACT2FN
from transformers_f.src.transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2FeatureExtractor, Wav2Vec2Model

from modeling_segmentation import (Wav2Vec2ModelForSegmentation, SegmentsRepr)

In [21]:
# Checkpoints to evaluate, zipped pairwise with the threshold given to SegmentsRepr(thres=...).
segment_paths = ['buckeye_r_val.ckpt']
thres = [0.05]

In [23]:
# Run every checkpoint over all three splits and store per-utterance boundary predictions.
# NOTE(review): all checkpoints write to the same <path_results>/<split>/ folders, so
# with more than one entry in segment_paths the last one wins.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


def _write_split_predictions(model, loader, split_name, device=DEVICE):
    """Run ``model.compute_all`` over ``loader`` and save one file per utterance.

    Each file is ``<path_results>/<split_name>/<id>`` and contains ``str()`` of the
    utterance's entry in ``compute_all(...)[1]['secs_pred']``, which
    ``read_pred_file`` parses back later.
    """
    out_dir = os.path.join(path_results, split_name)
    os.makedirs(out_dir, exist_ok=True)
    for batch in tqdm(loader, desc=split_name):
        # Reference boundaries are passed along as compute_all expects them;
        # only the predicted seconds are used here.
        rr = model.compute_all(batch['batch'].to(device), batch['boundaries'], num_epoch=0,
                               attention_mask=batch['attention_mask'].to(device),
                               return_secs=True)
        for utt_id, pred_secs in zip(batch['ids'], rr[1]['secs_pred']):
            with open(os.path.join(out_dir, utt_id), 'w', encoding="cp1251") as file:
                file.write(str(pred_secs))


for model_path, thre in zip(segment_paths, thres):
    wav2vec_segm = Wav2Vec2ModelForSegmentation.from_pretrained("facebook/wav2vec2-base-960h")
    wav2vec_segm.segment_mean = SegmentsRepr(thres=thre)

    checkpoint = torch.load(model_path, map_location=DEVICE)
    # Checkpoint keys carry a 'wav2vec_segm.' prefix (presumably the attribute name
    # in the training wrapper module); strip it to match this bare model.
    state_dicts = OrderedDict(
        (key.replace('wav2vec_segm.', ''), value) for key, value in checkpoint['state_dict'].items()
    )
    wav2vec_segm.load_state_dict(state_dicts)
    wav2vec_segm = wav2vec_segm.eval().to(DEVICE)

    # Pure inference: disable autograd so no graph/activations are kept for backward.
    with torch.no_grad():
        for split_name, split_loader in (('train', train_loader),
                                         ('valid', val_loader),
                                         ('test', test_loader)):
            _write_split_predictions(wav2vec_segm, split_loader, split_name)

HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=1596.0), HTML(value='')))

HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=377667514.0), HTML(value='')))





Some weights of the model checkpoint at facebook/wav2vec2-base-960h were not used when initializing Wav2Vec2ModelForSegmentation: ['lm_head.weight', 'lm_head.bias']
- This IS expected if you are initializing Wav2Vec2ModelForSegmentation from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ModelForSegmentation from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ModelForSegmentation were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
100%|██████████| 4832/4832 [1:01:45<00:00,  1.30it/s

# Get results

In [41]:
# This class is based on https://github.com/felixkreuk/UnsupSeg/blob/master/utils.py

class RMetrics1(nn.Module):
    """Boundary-detection metrics (precision, recall, F1, R-value) with a tolerance.

    Boundaries are mapped onto a frame grid: one frame per ``stride`` samples at
    ``sampling_rate`` Hz. A predicted boundary counts towards precision when some
    reference boundary lies within ``tolerance`` frames of it. A reference boundary
    counts towards recall when some prediction lies within ``tolerance`` frames of it.

    Args:
        eps: constant added to denominators to avoid division by zero.
        tolerance: maximum distance, in frames, for two boundaries to match.
        sampling_rate: sampling rate in Hz used to convert seconds to samples.
    """

    def __init__(self, eps = 1e-5, tolerance = 2, sampling_rate = 16000):
        super(RMetrics1, self).__init__()
        self.tolerance = tolerance
        self.eps = eps
        self.sampling_rate = sampling_rate

    def calculate_stride(self, isz, config):
        """Propagate an input length ``isz`` (in samples) through a conv stack.

        ``config`` must provide ``conv_kernel`` and ``conv_stride`` sequences; zero
        padding and dilation 1 are assumed.

        Returns:
            outsize: output length, computed with true (not floor) division, so that
                ``RFsize`` below equals the stack's exact receptive field;
            totstride: product of all strides, i.e. samples per output frame;
            RFsize: receptive field of one output frame, in samples;
            ms_per_frame: receptive field in milliseconds;
            ms_stride: frame hop in milliseconds.
        """
        pad = 0
        insize = isz
        outsize = isz  # defined even when the conv stack is empty
        totstride = 1
        sec_per_sample = 1/self.sampling_rate

        for kernel, stride in zip(config.conv_kernel, config.conv_stride):
            outsize = (insize + 2*pad - 1*(kernel-1)-1) / stride + 1
            insize = outsize
            totstride = totstride * stride

        RFsize = isz - (outsize - 1) * totstride

        ms_per_frame = sec_per_sample*RFsize*1000
        ms_stride = sec_per_sample*totstride*1000
        return outsize, totstride, RFsize, ms_per_frame, ms_stride

    def get_frames(self, secs, stride):
        """Convert nested lists of times (seconds) into frame indices, truncated toward zero."""
        return [[int(t*self.sampling_rate/stride) for t in utt_secs] for utt_secs in secs]

    def make_true_boundaries(self, secs, boundaries, stride):
        """Build a 0/1 int64 numpy matrix shaped like ``boundaries`` with ones at the frames of ``secs``.

        Raises IndexError if a reference frame lies outside ``boundaries``' width.
        """
        frames = self.get_frames(secs, stride)
        true_boundaries = torch.zeros(size = boundaries.shape)
        for num_frame, frame in enumerate(frames):
            for i in frame:
                true_boundaries[num_frame, i] = 1
        return true_boundaries.long().detach().numpy()

    def get_sec_bounds(self, b, stride, attention_mask = None):
        """Turn a (batch, frames) 0/1 boundary matrix into frame indices and times.

        Frames outside ``attention_mask`` (if given) are zeroed first.

        Returns:
            (frames_pred, secs_pred): per row, the array of frame indices equal to 1
            and the matching times in seconds.
        """
        if isinstance(b, torch.Tensor):
            b1 = b.long().detach().cpu().numpy()
        else:
            b1 = b

        if attention_mask is not None:
            b1 = b1*attention_mask.long().detach().cpu().numpy()

        frames_pred = []
        secs_pred = []
        for row in range(b1.shape[0]):
            frames = np.where(b1[row, :] == 1)[0]
            secs = [f*stride/self.sampling_rate for f in frames]
            frames_pred.append(frames)
            secs_pred.append(secs)
        return frames_pred, secs_pred

    def get_precision_recall_frames(self, true_boundaries, b, attention_mask = None):
        """Exact-frame (zero-tolerance) recall, precision and F1 of a 0/1 boundary matrix.

        Returns:
            (recall, precision, f_score). Each score is 0.0 when its denominator is 0.
        """
        if isinstance(b, torch.Tensor):
            # .cpu() so CUDA tensors work, as in get_sec_bounds
            b1 = b.long().detach().cpu().numpy()
        else:
            b1 = b

        if attention_mask is not None:
            b1 = b1*attention_mask.long().detach().cpu().numpy()

        # Binary scores computed with numpy (sklearn.metrics is not imported in this
        # notebook); same definitions as sklearn's binary recall/precision/F1.
        y_true = np.asarray(true_boundaries).flatten() == 1
        y_pred = np.asarray(b1).flatten() == 1
        tp = int(np.sum(y_true & y_pred))
        fp = int(np.sum(~y_true & y_pred))
        fn = int(np.sum(y_true & ~y_pred))
        recall = tp / (tp + fn) if tp + fn else 0.0
        pre = tp / (tp + fp) if tp + fp else 0.0
        f_score = 2 * tp / (2 * tp + fp + fn) if tp + fp + fn else 0.0
        return recall, pre, f_score

    def get_stats(self, frames_true, frames_pred):
        """Count tolerance matches between reference and predicted frames.

        Adapted from: https://github.com/felixkreuk/UnsupSeg/blob/68c2c7b9bd49f3fb8f51c5c2f4d5aa85f251eaa8/utils.py#L69

        Returns:
            (precision_counter, recall_counter, pred_counter, gt_counter).
        """
        precision_counter = 0
        recall_counter = 0
        pred_counter = 0
        gt_counter = 0

        for y, yhat in zip(frames_true, frames_pred):
            y = np.asarray(y)
            yhat = np.asarray(yhat)
            # If either side is empty there is no nearest neighbour: nothing on the
            # other side can be matched (and .min() on an empty array would raise).
            if y.size and yhat.size:
                for yhat_i in yhat:
                    precision_counter += int(np.abs(y - yhat_i).min() <= self.tolerance)
                for y_i in y:
                    recall_counter += int(np.abs(yhat - y_i).min() <= self.tolerance)
            pred_counter += yhat.size
            gt_counter += y.size

        return precision_counter, recall_counter, pred_counter, gt_counter

    def calc_metr(self, precision_counter, recall_counter, pred_counter, gt_counter):
        """Turn match counts into (precision, recall, f1, rval).

        Adapted from: https://github.com/felixkreuk/UnsupSeg/blob/68c2c7b9bd49f3fb8f51c5c2f4d5aa85f251eaa8/utils.py#L69
        """
        EPS = 1e-7

        precision = precision_counter / (pred_counter + self.eps)
        recall = recall_counter / (gt_counter + self.eps)
        f1 = 2 * (precision * recall) / (precision + recall + self.eps)

        # over-segmentation rate (named so it does not shadow the `os` module)
        over_seg = recall / (precision + EPS) - 1
        r1 = np.sqrt((1 - recall) ** 2 + over_seg ** 2)
        r2 = (-over_seg + recall - 1) / (np.sqrt(2))
        rval = 1 - (np.abs(r1) + np.abs(r2)) / 2

        return precision, recall, f1, rval

    def get_metrics(self, true_secs, b, seq_len, config, attention_mask = None,
                    return_secs=False):
        """Metrics for a 0/1 frame-boundary matrix ``b`` against reference times ``true_secs``.

        ``seq_len`` (input length in samples) and ``config`` (``conv_kernel`` /
        ``conv_stride``) determine the frame stride via ``calculate_stride``.
        Returns (precision, recall, f1, rval), plus the predicted times when
        ``return_secs`` is true.
        """
        outsize, totstride, RFsize, ms_per_frame, ms_stride = self.calculate_stride(seq_len, config)
        frames_true = self.get_frames(true_secs, totstride)
        frames_pred, secs_pred = self.get_sec_bounds(b, totstride, attention_mask = attention_mask)
        precision_counter, recall_counter, pred_counter, gt_counter = self.get_stats(frames_true, frames_pred)
        precision, recall, f1, rval = self.calc_metr(precision_counter, recall_counter, pred_counter, gt_counter)
        if return_secs:
            return precision, recall, f1, rval, secs_pred
        else:
            return precision, recall, f1, rval

    def get_metrics_secs(self, true_secs, secs_pred, totstride = 320):
        """Metrics for predicted vs. reference boundary times, both given in seconds.

        ``totstride`` is the frame hop in samples (320 samples = 20 ms at 16 kHz,
        presumably the wav2vec2 feature-encoder stride).
        """
        frames_true = self.get_frames(true_secs, totstride)
        frames_pred = self.get_frames(secs_pred, totstride)
        precision_counter, recall_counter, pred_counter, gt_counter = self.get_stats(frames_true, frames_pred)
        precision, recall, f1, rval = self.calc_metr(precision_counter, recall_counter, pred_counter, gt_counter)
        return precision, recall, f1, rval
    
    

In [66]:
def read_true_file(segment_file):
    """Read reference boundary times from a label file.

    The first whitespace-separated column of every non-blank line is parsed as a
    number. Duplicates are removed and the result is returned sorted ascending.

    NOTE(review): unlike WavPhnDataset.process_file, this keeps the last line's
    time as a boundary. Confirm which convention the evaluation should use.
    """
    with open(segment_file, 'r', encoding="cp1251") as file:
        # float() instead of eval(): only numbers are expected, and eval would execute
        # arbitrary expressions. Whitespace-only lines are skipped instead of raising.
        boundaries = {float(cols[0]) for cols in (line.split() for line in file) if cols}
    return sorted(boundaries)

In [67]:
def read_pred_file(segment_file):
    """Load a prediction file written by the inference loop (``str()`` of a list).

    When the list has more than one element, its first element is dropped.
    Presumably this removes a boundary at the utterance start; TODO confirm.
    """
    with open(segment_file, 'r', encoding="cp1251") as handle:
        content = handle.read()
    # NOTE(review): eval() executes the file's text. The files come from this notebook,
    # but ast.literal_eval would be safer if they never contain non-literal reprs
    # (e.g. numpy >= 2 prints 'np.float64(...)').
    predicted = eval(content)
    return predicted[1:] if len(predicted) > 1 else predicted

In [81]:
import os
import pandas as pd

In [82]:
# Paths are re-declared here, presumably so the evaluation section can be run on its
# own after a kernel restart, without repeating the expensive inference cell.
path_to_save = 'save_results_path_buckeye'
path_results = os.path.join(path_to_save, 'Results')

# Prediction sub-folders written by the inference loop, in train/valid/test order.
# (The label-file lists are built per split in the next cell; building them here
# for the train split only was dead work that re-scanned every train file.)
folders = ['train', 'valid', 'test']

In [83]:
# Score every split's saved predictions against its reference label files and
# collect one row of metrics per split.
result_records = []
metr = RMetrics1()

for folder, true_folder in zip(folders, ['Buckeye_fin/Train', 'Buckeye_fin/Valid', 'Buckeye_fin/Test']):
    pred_folder = os.path.join(path_results, folder)
    # Same wav -> label-file mapping as WavPhnDataset: same stem, '.txt' extension.
    true_files = [os.path.splitext(wav_path)[0] + '.txt'
                  for wav_path in iter_find_files(true_folder, "*.wav")]

    secs_trues = []
    secs_preds = []
    for true_file in tqdm(true_files, desc=folder):
        secs_trues.append(read_true_file(true_file))
        # Prediction files are named after the label file's base name (the utterance id).
        secs_preds.append(read_pred_file(os.path.join(pred_folder, os.path.basename(true_file))))

    precision, recall, f1, rval = metr.get_metrics_secs(secs_trues, secs_preds)
    result_records.append({'type': folder, 'precision': precision,
                           'recall': recall, 'f1': f1, 'rval': rval})

# Build the table once from records (numeric columns get a float dtype instead of object).
result_df = pd.DataFrame(result_records, columns=['type', 'precision', 'recall', 'f1', 'rval'])
result_df.to_csv('results_buckeye.csv')

100%|██████████| 38656/38656 [02:07<00:00, 302.62it/s]
100%|██████████| 4267/4267 [00:09<00:00, 438.94it/s]
100%|██████████| 4015/4015 [00:09<00:00, 438.15it/s]


In [84]:
# Per-split boundary metrics (the same table is saved to results_buckeye.csv).
result_df

Unnamed: 0,type,precision,recall,f1,rval
0,train,0.822638,0.681589,0.745495,0.767224
1,valid,0.856086,0.738185,0.792771,0.808212
2,test,0.8249,0.690123,0.751512,0.773051
