# Импорты

In [17]:
import json
import os
import torch
from tqdm import tqdm
import scipy.io.wavfile as wav

from torch import nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import Dataset, DataLoader, Sampler
from torch.utils.tensorboard import SummaryWriter

import numpy as np

In [18]:
import transformers
import json
from torch.utils.data import Dataset, DataLoader, Sampler
import os

In [19]:
from models import ConvFeatureEncoder, SegmentsRepr, SegmentsEncoder, NegativeSampler, SegmentPredictor, FinModel
from utils import ConstrativeLoss, sample_negatives
from collections import OrderedDict
import shutil

In [20]:
from model_transformers import SegmentTransformer
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

In [21]:
transformers.__version__

'4.14.0'

In [22]:
class Dataset:
    
    def __init__(self, path, segment_path, manifest_path, edges_path = None, chars = 'segments_chars/', frames = 'secs/'):
        with open(manifest_path, 'r') as json_file:
            manifest = json.load(json_file)
        self.manifest = manifest
#         self.manifest = manifest[:1000]
        self.path = path
        self.segment_path = segment_path
        self.frames = os.path.join(chars, frames)
        self.edges_path = edges_path
    
    def __len__(self):
        return len(self.manifest)
    
    def __getitem__(self, ind):
#         print(self.manifest[ind])
        
        # Загрузка аудио-сигнала
        
        audio_filepath = self.manifest[ind]['audio_filepath']
#         print(audio_filepath)
        audio_file = os.path.join(self.path, audio_filepath)
#         print(audio_file)
        sampling_rate, signal = wav.read(audio_file)
#         signal, sampling_rate = torchaudio.load(path)
        
        # Обрезать тишину
        filetext_id = self.manifest[ind]['id']+'.txt'
        silero_filepath = audio_filepath.replace('.wav', '.txt')
        silero_file = os.path.join(self.edges_path, silero_filepath)
#         print(silero_file)
        
        with open(silero_file, 'r', encoding="cp1251") as file:
            tt = file.read()
        
        start = []
        end = []
        for line in tt.split('\n'):
            if len(line) > 0:
                start.append(eval(line.split()[0]))
                end.append(eval(line.split()[1]))
        start = min(start)
        end = max(end)
#         print(start)
#         print(end)
        signal = signal[start:end]
        
        # Загрузка разметки
        filetext_id = self.manifest[ind]['id']+'.txt'
        segment_filepath = audio_filepath.replace('.wav', '.txt').replace(filetext_id, '')
        
        segment_file = os.path.join(os.path.join(os.path.join(self.segment_path, segment_filepath), self.frames), filetext_id)
#         print(segment_file)
        with open(segment_file, 'r', encoding="cp1251") as file:
            tt = file.read()
        
        boundaries = set()
        mm = [i for i in tt.split('\n') if len(i)>0]
        for i in mm:
            boundaries.add(eval(i.split()[0]))
            boundaries.add(eval(i.split()[1]))
        boundaries = sorted(list(boundaries))
        
        return {'audio_file':os.path.join(self.path, audio_filepath), 
                'segment_file':segment_file, 
                'id':filetext_id, 'sample': signal, 
                'length': len(signal), 'boundaries': boundaries}
        

In [23]:
class Dataset_test:
    
    def __init__(self, path, segment_path, manifest_path, edges_path = None, chars = 'segments_chars/', frames = 'secs/'):
        with open(manifest_path, 'r') as json_file:
            manifest = json.load(json_file)
        self.manifest = manifest
#         self.manifest = manifest[:1000]
        self.path = path
        self.segment_path = segment_path
        self.frames = os.path.join(chars, frames)
        self.edges_path = edges_path
    
    def __len__(self):
        return len(self.manifest)
    
    def __getitem__(self, ind):
#         print(self.manifest[ind])
        
        # Загрузка аудио-сигнала
        
        audio_filepath = self.manifest[ind]['audio_filepath']
#         print(audio_filepath)
        audio_file = os.path.join(self.path, audio_filepath)
#         print(audio_file)
        sampling_rate, signal = wav.read(audio_file)
#         signal, sampling_rate = torchaudio.load(path)
        
        # Обрезать тишину
        filetext_id = self.manifest[ind]['id']+'.txt'
        silero_filepath = audio_filepath.replace('.wav', '.txt')
        silero_file = os.path.join(self.edges_path, silero_filepath).replace('files/', '')
#         print(silero_file)
        
        with open(silero_file, 'r', encoding="cp1251") as file:
            tt = file.read()
        
        start = []
        end = []
        for line in tt.split('\n'):
            if len(line) > 0:
                start.append(eval(line.split()[0]))
                end.append(eval(line.split()[1]))
        start = min(start)
        end = max(end)
#         print(start)
#         print(end)
        signal = signal[start:end]
        
        # Загрузка разметки
        filetext_id = self.manifest[ind]['id']+'.txt'
        segment_filepath = audio_filepath.replace('.wav', '.txt').replace(filetext_id, '')
        
        segment_file = os.path.join(os.path.join(os.path.join(self.segment_path, segment_filepath), self.frames), filetext_id).replace('files/', '')
#         print(segment_file)
        with open(segment_file, 'r', encoding="cp1251") as file:
            tt = file.read()
        
        boundaries = set()
        mm = [i for i in tt.split('\n') if len(i)>0]
        for i in mm:
            boundaries.add(eval(i.split()[0]))
            boundaries.add(eval(i.split()[1]))
        boundaries = sorted(list(boundaries))
        
        return {'audio_file':os.path.join(self.path, audio_filepath), 
                'segment_file':segment_file, 
                'id':filetext_id, 'sample': signal, 
                'length': len(signal), 'boundaries': boundaries}
        
#         return {'id':filetext_id, 'sample': signal, 'length': len(signal), 'boundaries': boundaries}

In [24]:
def collate_fn(samples):
    
    max_length = max([sample['length'] for sample in samples])
    boundaries = [sample['boundaries'] for sample in samples]
    samples1 = []
    lengths = []
    samplings = []
    attentions = []
    ids = []
    audio_files = []
    segment_files = []
    for sample in samples:
        to_add_l = max_length-sample['length']
        sample1 = list(sample['sample'])+[0]*to_add_l
        samples1.append(torch.Tensor(sample1).unsqueeze(0))
        lengths.append(sample['length'])
        ids.append(sample['id'])
        audio_files.append(sample['audio_file'])
        segment_files.append(sample['segment_file'])
        att_norm = torch.ones(size = (1, sample['length']))
        att_add = torch.zeros(size = (1, to_add_l))
        att = torch.cat([att_norm, att_add], dim = -1)
        attentions.append(att)
        
    batch = torch.cat(samples1)
    lengths = torch.Tensor(lengths)
    attention_mask = torch.cat(attentions, dim = 0)
    
    return dict(batch=batch, lengths=lengths, attention_mask=attention_mask, 
                boundaries=boundaries, ids=ids, 
                audio_files=audio_files, 
                segment_files=segment_files)

# Predict Train

In [17]:
train_dataset = Dataset('train/', 'segments_edges_init/', 'manifest_silero_edges_train1.json', 
                        edges_path = 'silero_edges',
                              chars = 'segments_chars/', frames = 'secs/')
val_dataset = Dataset('train/', 'segments_edges_init/', 'manifest_silero_edges_val1.json', 
                        edges_path = 'silero_edges',
                              chars = 'segments_chars/', frames = 'secs/')
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=8, collate_fn = collate_fn)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=8, collate_fn = collate_fn)

In [20]:
path_to_save = 'save_results_path'
path_audio = os.path.join(path_to_save, 'Audio')
path_segments = os.path.join(path_to_save, 'Segment')
path_results = os.path.join(path_to_save, 'Results')

In [21]:
os.makedirs(path_to_save, exist_ok=True) 
os.makedirs(path_audio, exist_ok=True) 
os.makedirs(path_segments, exist_ok=True) 
os.makedirs(path_results, exist_ok=True) 

In [22]:
os.listdir(path_to_save)

['Audio', 'Segment', 'Results']

In [23]:
import warnings
warnings.filterwarnings("ignore")

In [24]:
for batch in val_loader:
    break

In [28]:
from transformers_f.src.transformers.activations import ACT2FN
from transformers_f.src.transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2FeatureExtractor, Wav2Vec2Model

from transformers_f.src.transformers.models.wav2vec2.modeling_segmentation import (Wav2Vec2ModelForSegmentation,
                                                                                   SegmentsRepr)

In [29]:
segment_paths = ['golos_model_segment_r_val_transformers_acc_10_ep_500_r_val_edges_train.ckpt']
names_paths = ['edges']
thres = [0.05]

In [32]:
for model_path, name_path, thre in zip(segment_paths, names_paths, thres):
    os.makedirs(os.path.join(path_results, name_path), exist_ok=True) 
    
    wav2vec_segm = Wav2Vec2ModelForSegmentation.from_pretrained("facebook/wav2vec2-base-960h")
    wav2vec_segm.segment_mean = SegmentsRepr(thres = thre)
    
    checkpoint = torch.load(model_path)
#     checkpoint = torch.load(model_path, map_location='cpu')
    
    state_dicts = OrderedDict()
    for key, value in checkpoint['state_dict'].items():
        state_dicts[key.replace('wav2vec_segm.', '')] = value
    wav2vec_segm.load_state_dict(state_dicts)
    
    wav2vec_segm.eval()
    wav2vec_segm=wav2vec_segm.to('cuda')
    
    for batch in tqdm(train_loader):
        x = batch['batch']
        lengths = batch['lengths']
        attention_mask = batch['attention_mask']
        secs = batch['boundaries']
        ids = batch['ids']
        name_path = 'train'
        os.makedirs(os.path.join(path_results, name_path), exist_ok=True) 

        rr = wav2vec_segm.compute_all(x.to('cuda'), secs, num_epoch=0, attention_mask=attention_mask.to('cuda'), return_secs=True)
#         rr = wav2vec_segm.compute_all(x, secs, num_epoch=0, attention_mask=attention_mask, return_secs=True)
        secs_preds = rr[1]['secs_pred']
        for idd, secs in zip(ids, secs_preds):
            with open(os.path.join(os.path.join(path_results, name_path), idd), 'w', encoding="cp1251") as file:
                file.write(str(secs))

    for batch in tqdm(val_loader):
        x = batch['batch']
        lengths = batch['lengths']
        attention_mask = batch['attention_mask']
        secs = batch['boundaries']
        ids = batch['ids']
        name_path = 'val'
        os.makedirs(os.path.join(path_results, name_path), exist_ok=True) 

        rr = wav2vec_segm.compute_all(x.to('cuda'), secs, num_epoch=0, attention_mask=attention_mask.to('cuda'), return_secs=True)
#         rr = wav2vec_segm.compute_all(x, secs, num_epoch=0, attention_mask=attention_mask, return_secs=True)
        secs_preds = rr[1]['secs_pred']
        for idd, secs in zip(ids, secs_preds):
            with open(os.path.join(os.path.join(path_results, name_path), idd), 'w', encoding="cp1251") as file:
                file.write(str(secs))
                

HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=1596.0), HTML(value='')))

HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=377667514.0), HTML(value='')))





100%|██████████| 13623/13623 [1:26:31<00:00,  2.62it/s]


# Predict Test

In [9]:
class Dataset_test:
    
    def __init__(self, path, segment_path, manifest_path, edges_path = None, chars = 'segments_chars/', frames = 'secs/'):
        with open(manifest_path, 'r') as json_file:
            manifest = json.load(json_file)
        self.manifest = manifest
#         self.manifest = manifest[:1000]
        self.path = path
        self.segment_path = segment_path
        self.frames = os.path.join(chars, frames)
        self.edges_path = edges_path
    
    def __len__(self):
        return len(self.manifest)
    
    def __getitem__(self, ind):
#         print(self.manifest[ind])
        
        # Загрузка аудио-сигнала
        
        audio_filepath = self.manifest[ind]['audio_filepath']
#         print(audio_filepath)
        audio_file = os.path.join(self.path, audio_filepath)
#         print(audio_file)
        sampling_rate, signal = wav.read(audio_file)
#         signal, sampling_rate = torchaudio.load(path)
        
        # Обрезать тишину
        filetext_id = self.manifest[ind]['id']+'.txt'
        silero_filepath = audio_filepath.replace('.wav', '.txt')
        silero_file = os.path.join(self.edges_path, silero_filepath).replace('files/', '')
#         print(silero_file)
        
        with open(silero_file, 'r', encoding="cp1251") as file:
            tt = file.read()
        
        start = []
        end = []
        for line in tt.split('\n'):
            if len(line) > 0:
                start.append(eval(line.split()[0]))
                end.append(eval(line.split()[1]))
        start = min(start)
        end = max(end)
#         print(start)
#         print(end)
        signal = signal[start:end]
        
        # Загрузка разметки
        filetext_id = self.manifest[ind]['id']+'.txt'
        segment_filepath = audio_filepath.replace('.wav', '.txt').replace(filetext_id, '')
        
        segment_file = os.path.join(os.path.join(os.path.join(self.segment_path, segment_filepath), self.frames), filetext_id).replace('files/', '')
#         print(segment_file)
        with open(segment_file, 'r', encoding="cp1251") as file:
            tt = file.read()
        
        boundaries = set()
        mm = [i for i in tt.split('\n') if len(i)>0]
        for i in mm:
            boundaries.add(eval(i.split()[0]))
            boundaries.add(eval(i.split()[1]))
        boundaries = sorted(list(boundaries))
        
        return {'audio_file':os.path.join(self.path, audio_filepath), 
                'segment_file':segment_file, 
                'id':filetext_id, 'sample': signal, 
                'length': len(signal), 'boundaries': boundaries}
        
#         return {'id':filetext_id, 'sample': signal, 'length': len(signal), 'boundaries': boundaries}

In [10]:
test_dataset = Dataset_test('test/', 'segments_edges_init_test/', 'manifest_test.json', 
                        edges_path = 'silero_edges_test',
                              chars = 'segments_chars/', frames = 'secs/')

test_loader = DataLoader(test_dataset, shuffle=True, batch_size=8, collate_fn = collate_fn)

In [13]:
from transformers_f.src.transformers.activations import ACT2FN
from transformers_f.src.transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2FeatureExtractor, Wav2Vec2Model

from transformers_f.src.transformers.models.wav2vec2.modeling_segmentation import (Wav2Vec2ModelForSegmentation,
                                                                                   SegmentsRepr)

In [14]:
for batch in test_loader:
    break

In [16]:

segment_paths = ['golos_model_segment_r_val_transformers_acc_10_ep_500_r_val_edges_train.ckpt']
names_paths = ['edges']
thres = [0.05]

In [17]:
path_to_save = 'save_results_path'
path_results = os.path.join(path_to_save, 'Results')

In [18]:
os.makedirs(path_to_save, exist_ok=True) 
os.makedirs(path_results, exist_ok=True) 

In [20]:
import warnings
warnings.filterwarnings("ignore")

In [22]:
for model_path, name_path, thre in zip(segment_paths, names_paths, thres):
    os.makedirs(os.path.join(path_results, name_path), exist_ok=True) 
    
    wav2vec_segm = Wav2Vec2ModelForSegmentation.from_pretrained("facebook/wav2vec2-base-960h")
    wav2vec_segm.segment_mean = SegmentsRepr(thres = thre)
    
    checkpoint = torch.load(model_path)
#     checkpoint = torch.load(model_path, map_location='cpu')
    
    state_dicts = OrderedDict()
    for key, value in checkpoint['state_dict'].items():
        state_dicts[key.replace('wav2vec_segm.', '')] = value
    wav2vec_segm.load_state_dict(state_dicts)
    
    wav2vec_segm.eval()
    wav2vec_segm=wav2vec_segm.to('cuda')
    
    for batch in tqdm(test_loader):
        x = batch['batch']
        lengths = batch['lengths']
        attention_mask = batch['attention_mask']
        secs = batch['boundaries']
        ids = batch['ids']
        name_path = 'test2'
        os.makedirs(os.path.join(path_results, name_path), exist_ok=True) 

        rr = wav2vec_segm.compute_all(x.to('cuda'), secs, num_epoch=0, attention_mask=attention_mask.to('cuda'), return_secs=True)
#         rr = wav2vec_segm.compute_all(x, secs, num_epoch=0, attention_mask=attention_mask, return_secs=True)
        secs_preds = rr[1]['secs_pred']
        for idd, secs in zip(ids, secs_preds):
            with open(os.path.join(os.path.join(path_results, name_path), idd), 'w', encoding="cp1251") as file:
                file.write(str(secs))

HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=1596.0), HTML(value='')))

HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=377667514.0), HTML(value='')))





Some weights of the model checkpoint at facebook/wav2vec2-base-960h were not used when initializing Wav2Vec2ModelForSegmentation: ['lm_head.bias', 'lm_head.weight']
- This IS expected if you are initializing Wav2Vec2ModelForSegmentation from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ModelForSegmentation from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ModelForSegmentation were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
100%|██████████| 1476/1476 [08:07<00:00,  3.03it/s]


# Get results

In [25]:
import pandas as pd
import os

In [26]:
class RMetrics1(nn.Module):
    def __init__(self, eps = 1e-5, tolerance = 2, sampling_rate = 16000):
        super(RMetrics1, self).__init__()
        self.tolerance = tolerance
        self.eps = eps
        self.sampling_rate = sampling_rate
    
    def calculate_stride(self, isz, config):
        pad = 0
        insize = isz
        totstride = 1
        sec_per_frame = 1/self.sampling_rate

        for kernel, stride in zip(config.conv_kernel, config.conv_stride):
            outsize = (insize + 2*pad - 1*(kernel-1)-1) / stride + 1
            insize = outsize
            totstride = totstride * stride

        RFsize = isz - (outsize - 1) * totstride

        ms_per_frame = sec_per_frame*RFsize*1000
        ms_stride = sec_per_frame*totstride*1000
        return outsize, totstride, RFsize, ms_per_frame, ms_stride
        
    def get_frames(self, secs, stride):
        frames = [[int(i*self.sampling_rate/stride) for i in sec] for sec in secs]
        return frames
        
    def make_true_boundaries(self, secs, boundaries, stride):
        frames = self.get_frames(secs, stride)
        true_boundaries = torch.zeros(size = boundaries.shape)
        for num_frame, frame in enumerate(frames):
            for i in frame:
                true_boundaries[num_frame, i] = 1
        return true_boundaries.long().detach().numpy()
    
    def get_sec_bounds(self, b, stride, attention_mask = None):
        if type(b)==torch.Tensor:
            b1 = b.long().detach().cpu().numpy()
        else:
            b1 = b
        
        if attention_mask is not None:
            b1 = b1*attention_mask.long().detach().cpu().numpy()
            
        frames_pred = []
        secs_pred = []
        for i in range(b1.shape[0]):
            frames = np.where(b1[i, :] == 1)[0]
            secs = [i*stride/self.sampling_rate for i in frames]
            frames_pred.append(frames)
            secs_pred.append(secs)
        return frames_pred, secs_pred
    
    def get_precision_recall_frames(self, true_boundaries, b, attention_mask = None):
        if type(b)==torch.Tensor:
            b1 = b.long().detach().numpy()
        else:
            b1 = b
            
        if attention_mask is not None:
            b1 = b1*attention_mask.long().detach().cpu().numpy()
            
        recall = recall_score(true_boundaries.flatten(), b1.flatten())
        pre = precision_score(true_boundaries.flatten(), b1.flatten())
        f_score = f1_score(true_boundaries.flatten(), b1.flatten())
        return recall, pre, f_score
    
    def get_stats(self, frames_true, frames_pred):
        
        # Утащено отсюда: https://github.com/felixkreuk/UnsupSeg/blob/68c2c7b9bd49f3fb8f51c5c2f4d5aa85f251eaa8/utils.py#L69
        precision_counter = 0 
        recall_counter = 0
        pred_counter = 0 
        gt_counter = 0

        for (y, yhat) in zip(frames_true, frames_pred):
            for yhat_i in yhat:
                min_dist = np.abs(np.array(y) - yhat_i).min()
                precision_counter += (min_dist <= self.tolerance)
            for y_i in y:
                min_dist = np.abs(np.array(yhat) - y_i).min()
                recall_counter += (min_dist <= self.tolerance)
            pred_counter += len(yhat)
            gt_counter += len(y)

        return precision_counter, recall_counter, pred_counter, gt_counter
    
    def calc_metr(self, precision_counter, recall_counter, pred_counter, gt_counter):

        # Утащено отсюда: https://github.com/felixkreuk/UnsupSeg/blob/68c2c7b9bd49f3fb8f51c5c2f4d5aa85f251eaa8/utils.py#L69
        EPS = 1e-7

        precision = precision_counter / (pred_counter + self.eps)
        recall = recall_counter / (gt_counter + self.eps)
        f1 = 2 * (precision * recall) / (precision + recall + self.eps)

        os = recall / (precision + EPS) - 1
        r1 = np.sqrt((1 - recall) ** 2 + os ** 2)
        r2 = (-os + recall - 1) / (np.sqrt(2))
        rval = 1 - (np.abs(r1) + np.abs(r2)) / 2

        return precision, recall, f1, rval
    
    def get_metrics(self, true_secs, b, seq_len, config, attention_mask = None, 
                    return_secs=False):
        
        outsize, totstride, RFsize, ms_per_frame, ms_stride = self.calculate_stride(seq_len, config)
#         print(seq_len, outsize, totstride, RFsize, ms_per_frame, ms_stride)
        frames_true = self.get_frames(true_secs, totstride)
        frames_pred, secs_pred = self.get_sec_bounds(b, totstride, attention_mask = attention_mask)
        precision_counter, recall_counter, pred_counter, gt_counter = self.get_stats(frames_true, frames_pred)
        precision, recall, f1, rval = self.calc_metr(precision_counter, recall_counter, pred_counter, gt_counter)
        if return_secs:
            return precision, recall, f1, rval, secs_pred
        else:
            return precision, recall, f1, rval
        
    def get_metrics_secs(self, true_secs, secs_pred, totstride = 160):
        
        frames_true = self.get_frames(true_secs, totstride)
        frames_pred = self.get_frames(secs_pred, totstride)
        precision_counter, recall_counter, pred_counter, gt_counter = self.get_stats(frames_true, frames_pred)
        precision, recall, f1, rval = self.calc_metr(precision_counter, recall_counter, pred_counter, gt_counter)
        return precision, recall, f1, rval
    
    

In [40]:
def read_true_file(segment_file):
    with open(segment_file, 'r', encoding="cp1251") as file:
        tt = file.read()

    boundaries = set()
    mm = [i for i in tt.split('\n') if len(i)>0]
    for i in mm:
        boundaries.add(eval(i.split()[0]))
        boundaries.add(eval(i.split()[1]))
    boundaries = sorted(list(boundaries))
    return boundaries

In [41]:
def read_pred_file(segment_file):
    with open(segment_file, 'r', encoding="cp1251") as file:
        tt = file.read()
    boundaries = eval(tt)
    if len(boundaries)>1:
        return boundaries[1:]
    else:
        return boundaries

In [55]:
class Dataset_res:
    
    def __init__(self, path, segment_path, manifest_path, chars = 'segments_chars/', frames = 'secs/'):
        with open(manifest_path, 'r') as json_file:
            manifest = json.load(json_file)
        self.manifest = manifest
        self.path = path
        self.segment_path = segment_path
        self.frames = os.path.join(chars, frames)
    
    def __len__(self):
        return len(self.manifest)
    
    def __getitem__(self, ind):
#         print(self.manifest[ind])
        
        # Загрузка аудио-сигнала
        
        audio_filepath = self.manifest[ind]['audio_filepath']
#         print(audio_filepath)
        audio_file = os.path.join(self.path, audio_filepath)
    
        filetext_id = self.manifest[ind]['id']+'.txt'
        segment_filepath = audio_filepath.replace('.wav', '.txt').replace(filetext_id, '')
        
        segment_file = os.path.join(os.path.join(os.path.join(self.segment_path, segment_filepath), self.frames), filetext_id).replace('files/', '')
#         print(segment_file)
        with open(segment_file, 'r', encoding="cp1251") as file:
            tt = file.read()
        
        boundaries = set()
        mm = [i for i in tt.split('\n') if len(i)>0]
        for i in mm:
            boundaries.add(eval(i.split()[0]))
            boundaries.add(eval(i.split()[1]))
        boundaries = sorted(list(boundaries))
        
        return {'segment_file':segment_file, 
                'id':filetext_id,
                'boundaries': boundaries}
        

In [56]:
def collate_fn_res(samples):

    boundaries = [sample['boundaries'] for sample in samples]
    samples1 = []
    lengths = []
    samplings = []
    attentions = []
    ids = []
    audio_files = []
    segment_files = []
    for sample in samples:
        ids.append(sample['id'])
        audio_files.append(sample['audio_file'])
        segment_files.append(sample['segment_file'])
        
    
    return dict(batch=batch, lengths=lengths, 
                boundaries=boundaries, ids=ids, 
                audio_files=audio_files, 
                segment_files=segment_files)

In [57]:
train_dataset = Dataset_res('train/', 'segments_edges_init/', 'manifest_silero_edges_train3.json', 
                              chars = 'segments_chars/', frames = 'secs/')
val_dataset = Dataset_res('train/', 'segments_edges_init/', 'manifest_silero_edges_val1.json', 
                              chars = 'segments_chars/', frames = 'secs/')
test_dataset = Dataset_res('test/', 'segments_edges_init_test/', 'manifest_test.json', 
                              chars = 'segments_chars/', frames = 'secs/')

train_loader = DataLoader(train_dataset, shuffle=True, batch_size=8, collate_fn = collate_fn_res)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=8, collate_fn = collate_fn_res)
test_loader = DataLoader(test_dataset, shuffle=True, batch_size=8, collate_fn = collate_fn_res)

In [58]:
train_dataset[0]

{'segment_file': 'segments_edges_init/crowd/0/segments_chars/secs/77b380796d242cf5bc09cfb551cffecd.txt',
 'id': '77b380796d242cf5bc09cfb551cffecd.txt',
 'boundaries': [0.07375355450236964,
  0.3201990521327014,
  0.39602843601895743,
  0.43394312796208534,
  0.5287298578199053,
  0.5666445497630332]}

In [64]:
path_to_save = 'save_results_path'
path_results = os.path.join(path_to_save, 'Results')
folders = ['train', 'val', 'test2']

In [65]:
result_dataframes = []
metr = RMetrics1()
for folder, dataset in zip(folders, [train_dataset, val_dataset, test_dataset]):

    secs_preds = []
    secs_trues = []

    for i in tqdm(range(len(dataset))):
        
        idd = dataset[i]
        true_file = idd['segment_file']
        bound_true = idd['boundaries']
        ids = idd['id']
        
#         bound_true = read_true_file(true_files[num])
        bound_pred = read_pred_file(os.path.join(os.path.join(path_results, folder), ids))

        secs_trues.append(bound_true)
        secs_preds.append(bound_pred)
   

    precision, recall, f1, rval = metr.get_metrics_secs(secs_trues, secs_preds)
    
    datafr = pd.DataFrame([folder, 
                           precision, recall, 
                           f1, rval]).T.rename(columns = {0:'type', 
                                                          1:'precision',
                                                          2:'recall', 
                                                          3:'f1', 
                                                          4:'rval'})
    
    result_dataframes.append(datafr)

100%|██████████| 980538/980538 [1:30:42<00:00, 180.16it/s]
100%|██████████| 108957/108957 [09:00<00:00, 201.47it/s]
100%|██████████| 11804/11804 [00:24<00:00, 475.87it/s]


In [66]:
result_df = pd.concat(result_dataframes, ignore_index = True)

In [67]:
result_df

Unnamed: 0,type,precision,recall,f1,rval
0,train,0.843228,0.804486,0.823397,0.846699
1,val,0.843317,0.809125,0.825862,0.849283
2,test2,0.805539,0.794868,0.800163,0.829379


In [68]:
result_df.to_csv('results_golos.csv')