# Импорты

In [1]:
#!g1.1
import ast
import json
import os
import re

import numpy as np
import scipy.io.wavfile as wav
import torch
import torch.nn.functional as F
from torch import nn
from torch import optim
from torch.utils.data import Dataset, DataLoader, Sampler
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

In [2]:
#!g1.1
import transformers
import json
from torch.utils.data import Dataset, DataLoader, Sampler
import os

In [3]:
#!g1.1
from models import ConvFeatureEncoder, SegmentsRepr, SegmentsEncoder, NegativeSampler, SegmentPredictor, FinModel
from utils import ConstrativeLoss, sample_negatives
from collections import OrderedDict
import shutil

In [4]:
#!g1.1
from model_transformers import SegmentTransformer
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

In [5]:
#!g1.1
# Display the installed transformers version (shown in the cell output).
transformers.__version__

'4.14.0'

In [6]:
#!g1.1
class Dataset:
    """Utterance dataset: silence-trimmed waveform plus reference boundaries.

    NOTE(review): the name shadows ``torch.utils.data.Dataset`` imported at the
    top of the notebook; it is kept because other cells call ``Dataset(...)``.

    Args:
        path: Root directory the manifest's ``audio_filepath`` values are
            relative to.
        segment_path: Root directory of the reference segmentation files.
        manifest_path: JSON file holding a list of records; every record must
            provide the keys ``audio_filepath`` and ``id``.
        edges_path: Root directory of the speech-interval files used to trim
            silence. Required in practice: with the default ``None``,
            ``__getitem__`` fails inside ``os.path.join``.
        chars: First sub-directory appended to an utterance's segment directory.
        frames: Second sub-directory, appended after ``chars``.
    """

    def __init__(self, path, segment_path, manifest_path, edges_path = None, chars = 'segments_chars/', frames = 'secs/'):
        with open(manifest_path, 'r') as json_file:
            self.manifest = json.load(json_file)
        self.path = path
        self.segment_path = segment_path
        self.frames = os.path.join(chars, frames)
        self.edges_path = edges_path

    def __len__(self):
        """Return the number of manifest records."""
        return len(self.manifest)

    @staticmethod
    def _read_rows(file_path):
        """Return the whitespace-split fields of every non-blank line of ``file_path``."""
        with open(file_path, 'r', encoding="cp1251") as file:
            return [line.split() for line in file.read().split('\n') if line.strip()]

    def __getitem__(self, ind):
        """Load utterance ``ind``.

        Returns:
            dict with keys ``audio_file``, ``segment_file``, ``id`` (manifest id
            plus ``.txt``), ``sample`` (the trimmed signal), ``length``
            (``len(sample)``) and ``boundaries`` (sorted unique values taken
            from the first two columns of the segment file).

        Raises:
            ValueError: if the interval file contains no rows, or a field is
                not a plain Python literal.
        """
        record = self.manifest[ind]
        audio_filepath = record['audio_filepath']
        filetext_id = record['id'] + '.txt'
        relative_txt = audio_filepath.replace('.wav', '.txt')

        # Load the audio signal.
        audio_file = os.path.join(self.path, audio_filepath)
        _, signal = wav.read(audio_file)

        # Trim silence: keep the span from the earliest interval start to the
        # latest interval end. The values are used directly as slice indices.
        silero_file = os.path.join(self.edges_path, relative_txt)
        intervals = self._read_rows(silero_file)
        if not intervals:
            raise ValueError('no speech intervals found in ' + silero_file)
        # ast.literal_eval parses numeric literals exactly like the eval() it
        # replaces, but cannot execute code embedded in the file.
        start = min(ast.literal_eval(row[0]) for row in intervals)
        end = max(ast.literal_eval(row[1]) for row in intervals)
        signal = signal[start:end]

        # Load the reference annotation.
        segment_dir = os.path.join(self.segment_path, relative_txt.replace(filetext_id, ''))
        segment_file = os.path.join(os.path.join(segment_dir, self.frames), filetext_id)
        boundaries = set()
        for row in self._read_rows(segment_file):
            boundaries.add(ast.literal_eval(row[0]))
            boundaries.add(ast.literal_eval(row[1]))
        boundaries = sorted(boundaries)

        return {'audio_file': audio_file,
                'segment_file': segment_file,
                'id': filetext_id, 'sample': signal,
                'length': len(signal), 'boundaries': boundaries}
        

In [7]:
#!g1.1
class Dataset_test:
    """Test-set dataset: silence-trimmed waveform plus reference boundaries.

    Same item layout as ``Dataset``; the difference is that every occurrence
    of ``'files/'`` is removed from the interval-file and segment-file paths.
    NOTE(review): the removal is applied to the whole joined path, so a root
    directory containing ``'files/'`` would be altered too. An identical
    definition of this class appears again in the "Predict Test" section.

    Args:
        path: Root directory the manifest's ``audio_filepath`` values are
            relative to.
        segment_path: Root directory of the reference segmentation files.
        manifest_path: JSON file holding a list of records; every record must
            provide the keys ``audio_filepath`` and ``id``.
        edges_path: Root directory of the speech-interval files used to trim
            silence. Required in practice: with the default ``None``,
            ``__getitem__`` fails inside ``os.path.join``.
        chars: First sub-directory appended to an utterance's segment directory.
        frames: Second sub-directory, appended after ``chars``.
    """

    def __init__(self, path, segment_path, manifest_path, edges_path = None, chars = 'segments_chars/', frames = 'secs/'):
        with open(manifest_path, 'r') as json_file:
            self.manifest = json.load(json_file)
        self.path = path
        self.segment_path = segment_path
        self.frames = os.path.join(chars, frames)
        self.edges_path = edges_path

    def __len__(self):
        """Return the number of manifest records."""
        return len(self.manifest)

    @staticmethod
    def _read_rows(file_path):
        """Return the whitespace-split fields of every non-blank line of ``file_path``."""
        with open(file_path, 'r', encoding="cp1251") as file:
            return [line.split() for line in file.read().split('\n') if line.strip()]

    def __getitem__(self, ind):
        """Load utterance ``ind``.

        Returns:
            dict with keys ``audio_file``, ``segment_file``, ``id`` (manifest id
            plus ``.txt``), ``sample`` (the trimmed signal), ``length``
            (``len(sample)``) and ``boundaries`` (sorted unique values taken
            from the first two columns of the segment file).

        Raises:
            ValueError: if the interval file contains no rows, or a field is
                not a plain Python literal.
        """
        record = self.manifest[ind]
        audio_filepath = record['audio_filepath']
        filetext_id = record['id'] + '.txt'
        relative_txt = audio_filepath.replace('.wav', '.txt')

        # Load the audio signal.
        audio_file = os.path.join(self.path, audio_filepath)
        _, signal = wav.read(audio_file)

        # Trim silence: keep the span from the earliest interval start to the
        # latest interval end. The values are used directly as slice indices.
        silero_file = os.path.join(self.edges_path, relative_txt).replace('files/', '')
        intervals = self._read_rows(silero_file)
        if not intervals:
            raise ValueError('no speech intervals found in ' + silero_file)
        # ast.literal_eval parses numeric literals exactly like the eval() it
        # replaces, but cannot execute code embedded in the file.
        start = min(ast.literal_eval(row[0]) for row in intervals)
        end = max(ast.literal_eval(row[1]) for row in intervals)
        signal = signal[start:end]

        # Load the reference annotation.
        segment_dir = os.path.join(self.segment_path, relative_txt.replace(filetext_id, ''))
        segment_file = os.path.join(os.path.join(segment_dir, self.frames), filetext_id).replace('files/', '')
        boundaries = set()
        for row in self._read_rows(segment_file):
            boundaries.add(ast.literal_eval(row[0]))
            boundaries.add(ast.literal_eval(row[1]))
        boundaries = sorted(boundaries)

        return {'audio_file': audio_file,
                'segment_file': segment_file,
                'id': filetext_id, 'sample': signal,
                'length': len(signal), 'boundaries': boundaries}

In [8]:
#!g1.1
def collate_fn(samples):
    """Pad a list of dataset items into one batch.

    Args:
        samples: Items as returned by the dataset's ``__getitem__``; each must
            provide ``sample``, ``length``, ``boundaries``, ``id``,
            ``audio_file`` and ``segment_file``.

    Returns:
        dict with
        ``batch``: float32 tensor of shape ``(len(samples), max_length)``,
        zero-padded on the right;
        ``lengths``: float32 tensor with the unpadded length of each item;
        ``attention_mask``: float32 tensor shaped like ``batch``, 1.0 on real
        samples and 0.0 on padding;
        ``boundaries``, ``ids``, ``audio_files``, ``segment_files``: per-item
        lists in input order.
    """
    max_length = max([sample['length'] for sample in samples])
    lengths = torch.Tensor([sample['length'] for sample in samples])

    # Copy each waveform into a pre-zeroed row instead of converting it to a
    # Python list element by element.
    batch = torch.zeros(len(samples), max_length)
    for row, sample in enumerate(samples):
        batch[row, :sample['length']] = torch.as_tensor(sample['sample'], dtype=torch.float32)

    # Position j of row i is valid iff j < lengths[i].
    attention_mask = (torch.arange(max_length).unsqueeze(0) < lengths.unsqueeze(1)).float()

    return dict(batch=batch, lengths=lengths, attention_mask=attention_mask,
                boundaries=[sample['boundaries'] for sample in samples],
                ids=[sample['id'] for sample in samples],
                audio_files=[sample['audio_file'] for sample in samples],
                segment_files=[sample['segment_file'] for sample in samples])

# Predict Train

In [9]:
#!g1.1
# Train and validation datasets: both read audio from 'train/' and differ only
# in the manifest file.
train_dataset = Dataset('train/', 'segments_edges_init/', '10hours_train.json', 
                        edges_path = 'silero_edges',
                              chars = 'segments_chars/', frames = 'secs/')
val_dataset = Dataset('train/', 'segments_edges_init/', '10hours_val.json', 
                        edges_path = 'silero_edges',
                              chars = 'segments_chars/', frames = 'secs/')
# Batches of 8 padded by collate_fn; only the train loader shuffles.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=8, collate_fn = collate_fn)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=8, collate_fn = collate_fn)

In [10]:
#!g1.1


In [11]:
#!g1.1


In [12]:
#!g1.1
# Root output directory; predicted boundaries are written below its 'Results' sub-directory.
path_to_save = 'save_results_path_10_hours_conv'
path_results = os.path.join(path_to_save, 'Results')

In [13]:
#!g1.1
# Create the output directories; exist_ok makes the cell safe to re-run.
os.makedirs(path_to_save, exist_ok=True) 
os.makedirs(path_results, exist_ok=True) 

In [14]:
#!g1.1
# Show what is currently inside the output directory.
os.listdir(path_to_save)

['Audio', 'Segment', 'Results']

In [15]:
#!g1.1
import warnings
# NOTE(review): this silences every warning for the rest of the session,
# including deprecation and numerical warnings.
warnings.filterwarnings("ignore")

In [16]:
#!g1.1
# Fetch one batch from val_loader; it stays bound to the global name `batch`.
for batch in val_loader:
    break

In [17]:
#!g1.1


In [18]:
#!g1.1
from transformers_f.src.transformers.activations import ACT2FN
from transformers_f.src.transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2FeatureExtractor, Wav2Vec2Model

from transformers_f.src.transformers.models.wav2vec2.modeling_segmentation import (Wav2Vec2ModelForSegmentation,
                                                                                   SegmentsRepr)

In [19]:
#!g1.1
# Parallel lists zipped by the inference loop: checkpoint file, result
# sub-directory name, and a threshold.
# NOTE(review): the threshold (`thre`) is unpacked by the loop but never used there.
segment_paths = ['golos_model_segment_r_val_acc_200_edges_train_10_hours.ckpt']
names_paths = ['edges']
thres = [0.05]

In [20]:
#!g1.1


In [21]:
#!g1.1
# Only used in this cell, to scale the learning rate below.
accumulate_grad_batches = 10
# Settings handed to FinModel through the Conf wrapper.
# NOTE(review): this notebook only runs inference, so the optimisation entries
# are presumably unused here — confirm against models.py.
cfg = {'optimizer': "adam",
'momentum': 0.9,
'learning_rate': 0.0001*accumulate_grad_batches,
'lr_anneal_gamma': 1.0,
'lr_anneal_step': 1000,
# 'epochs': 500,
'grad_clip': 0.5,
'batch_size': 8,

# Per-component keyword arguments; all left empty here.
'conv_args': {},
'mask_args': {},
'segm_enc_args': {},
'segm_predictor_args': {},
'loss_args': {"n_negatives": 10, "loss_args": {"reduction": "mean"}},
'num_epoch': 2}

class Conf:
    """Attribute-style view of a configuration dictionary."""

    def __init__(self, my_dict):
        # Mirror every dictionary entry as an instance attribute of the same name.
        for name in my_dict:
            setattr(self, name, my_dict[name])
            
config = Conf(cfg)

In [22]:
#!g1.1


In [23]:
#!g1.1
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# For every checkpoint: restore the model, predict boundaries for the train and
# validation loaders, and write one file per utterance (named by its id) under
# path_results/<split>.
# NOTE(review): `thre` is unpacked but never used. Results go to fixed 'train'
# and 'val' folders, so several checkpoints would overwrite each other's output.
for model_path, name_path, thre in zip(segment_paths, names_paths, thres):
    os.makedirs(os.path.join(path_results, name_path), exist_ok=True)

    wav2vec_segm = FinModel(config)

    # map_location lets a checkpoint saved on one device be restored on another.
    checkpoint = torch.load(model_path, map_location=device)

    # Strip the 'wav2vec_segm.' prefix from the checkpoint keys so they match
    # the bare model's parameter names.
    state_dicts = OrderedDict()
    for key, value in checkpoint['state_dict'].items():
        state_dicts[key.replace('wav2vec_segm.', '')] = value
    wav2vec_segm.load_state_dict(state_dicts)

    wav2vec_segm.eval()
    wav2vec_segm = wav2vec_segm.to(device)

    for split_name, loader in (('train', train_loader), ('val', val_loader)):
        split_dir = os.path.join(path_results, split_name)
        os.makedirs(split_dir, exist_ok=True)

        for batch in tqdm(loader):
            # Inference only: do not build autograd graphs.
            with torch.no_grad():
                rr = wav2vec_segm.compute_all(batch['batch'].to(device), batch['boundaries'],
                                              num_epoch=0,
                                              attention_mask=batch['attention_mask'].to(device),
                                              return_secs=True)
            # One prediction list per utterance, stored as its str() representation.
            for idd, secs_pred in zip(batch['ids'], rr[1]['secs_pred']):
                with open(os.path.join(split_dir, idd), 'w', encoding="cp1251") as file:
                    file.write(str(secs_pred))
                

100%|██████████| 1050/1050 [04:09<00:00,  4.20it/s]
100%|██████████| 63/63 [00:13<00:00,  4.60it/s]


In [24]:
#!g1.1


# Predict Test

In [25]:
#!g1.1
class Dataset_test:
    """Test-set dataset: silence-trimmed waveform plus reference boundaries.

    Same item layout as ``Dataset``; the difference is that every occurrence
    of ``'files/'`` is removed from the interval-file and segment-file paths.
    NOTE(review): the removal is applied to the whole joined path, so a root
    directory containing ``'files/'`` would be altered too. This cell
    re-defines a class that an earlier cell already defines identically.

    Args:
        path: Root directory the manifest's ``audio_filepath`` values are
            relative to.
        segment_path: Root directory of the reference segmentation files.
        manifest_path: JSON file holding a list of records; every record must
            provide the keys ``audio_filepath`` and ``id``.
        edges_path: Root directory of the speech-interval files used to trim
            silence. Required in practice: with the default ``None``,
            ``__getitem__`` fails inside ``os.path.join``.
        chars: First sub-directory appended to an utterance's segment directory.
        frames: Second sub-directory, appended after ``chars``.
    """

    def __init__(self, path, segment_path, manifest_path, edges_path = None, chars = 'segments_chars/', frames = 'secs/'):
        with open(manifest_path, 'r') as json_file:
            self.manifest = json.load(json_file)
        self.path = path
        self.segment_path = segment_path
        self.frames = os.path.join(chars, frames)
        self.edges_path = edges_path

    def __len__(self):
        """Return the number of manifest records."""
        return len(self.manifest)

    @staticmethod
    def _read_rows(file_path):
        """Return the whitespace-split fields of every non-blank line of ``file_path``."""
        with open(file_path, 'r', encoding="cp1251") as file:
            return [line.split() for line in file.read().split('\n') if line.strip()]

    def __getitem__(self, ind):
        """Load utterance ``ind``.

        Returns:
            dict with keys ``audio_file``, ``segment_file``, ``id`` (manifest id
            plus ``.txt``), ``sample`` (the trimmed signal), ``length``
            (``len(sample)``) and ``boundaries`` (sorted unique values taken
            from the first two columns of the segment file).

        Raises:
            ValueError: if the interval file contains no rows, or a field is
                not a plain Python literal.
        """
        record = self.manifest[ind]
        audio_filepath = record['audio_filepath']
        filetext_id = record['id'] + '.txt'
        relative_txt = audio_filepath.replace('.wav', '.txt')

        # Load the audio signal.
        audio_file = os.path.join(self.path, audio_filepath)
        _, signal = wav.read(audio_file)

        # Trim silence: keep the span from the earliest interval start to the
        # latest interval end. The values are used directly as slice indices.
        silero_file = os.path.join(self.edges_path, relative_txt).replace('files/', '')
        intervals = self._read_rows(silero_file)
        if not intervals:
            raise ValueError('no speech intervals found in ' + silero_file)
        # ast.literal_eval parses numeric literals exactly like the eval() it
        # replaces, but cannot execute code embedded in the file.
        start = min(ast.literal_eval(row[0]) for row in intervals)
        end = max(ast.literal_eval(row[1]) for row in intervals)
        signal = signal[start:end]

        # Load the reference annotation.
        segment_dir = os.path.join(self.segment_path, relative_txt.replace(filetext_id, ''))
        segment_file = os.path.join(os.path.join(segment_dir, self.frames), filetext_id).replace('files/', '')
        boundaries = set()
        for row in self._read_rows(segment_file):
            boundaries.add(ast.literal_eval(row[0]))
            boundaries.add(ast.literal_eval(row[1]))
        boundaries = sorted(boundaries)

        return {'audio_file': audio_file,
                'segment_file': segment_file,
                'id': filetext_id, 'sample': signal,
                'length': len(signal), 'boundaries': boundaries}

In [26]:
#!g1.1
# Two test subsets (farfield and crowd): same audio and annotation roots,
# different manifests.
test_farfield_dataset = Dataset_test('test/', 'segments_edges_init_test/', 'manifest_test_farfield.json', 
                        edges_path = 'silero_edges_test',
                              chars = 'segments_chars/', frames = 'secs/')
test_crowd_dataset = Dataset_test('test/', 'segments_edges_init_test/', 'manifest_test_crowd.json', 
                        edges_path = 'silero_edges_test',
                              chars = 'segments_chars/', frames = 'secs/')

# NOTE(review): shuffle=True on test loaders; predictions are written per
# utterance id, so iteration order does not change the saved files.
test_farfield_loader = DataLoader(test_farfield_dataset, shuffle=True, batch_size=8, collate_fn = collate_fn)
test_crowd_loader = DataLoader(test_crowd_dataset, shuffle=True, batch_size=8, collate_fn = collate_fn)

In [27]:
#!g1.1


In [28]:
#!g1.1


In [29]:
#!g1.1
from transformers_f.src.transformers.activations import ACT2FN
from transformers_f.src.transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2FeatureExtractor, Wav2Vec2Model

from transformers_f.src.transformers.models.wav2vec2.modeling_segmentation import (Wav2Vec2ModelForSegmentation,
                                                                                   SegmentsRepr)

In [30]:
#!g1.1
# Fetch one batch from test_farfield_loader; it stays bound to the global name `batch`.
for batch in test_farfield_loader:
    break

In [31]:
#!g1.1


In [32]:
#!g1.1

# Parallel lists zipped by the inference loop: checkpoint file, result
# sub-directory name, and a threshold.
# NOTE(review): the threshold (`thre`) is unpacked by the loop but never used there.
segment_paths = ['golos_model_segment_r_val_acc_200_edges_train_10_hours.ckpt']
names_paths = ['edges']
thres = [0.05]

In [33]:
#!g1.1
# Root output directory; predicted boundaries are written below its 'Results' sub-directory.
path_to_save = 'save_results_path_10_hours_conv'
path_results = os.path.join(path_to_save, 'Results')

In [34]:
#!g1.1
# Create the output directories; exist_ok makes the cell safe to re-run.
os.makedirs(path_to_save, exist_ok=True) 
os.makedirs(path_results, exist_ok=True) 

In [89]:
#!g1.1
import warnings
# NOTE(review): this silences every warning for the rest of the session,
# including deprecation and numerical warnings.
warnings.filterwarnings("ignore")

In [36]:
#!g1.1
# Only used in this cell, to scale the learning rate below.
accumulate_grad_batches = 10
# Settings handed to FinModel through the Conf wrapper.
# NOTE(review): this notebook only runs inference, so the optimisation entries
# are presumably unused here — confirm against models.py.
cfg = {'optimizer': "adam",
'momentum': 0.9,
'learning_rate': 0.0001*accumulate_grad_batches,
'lr_anneal_gamma': 1.0,
'lr_anneal_step': 1000,
# 'epochs': 500,
'grad_clip': 0.5,
'batch_size': 8,

# Per-component keyword arguments; all left empty here.
'conv_args': {},
'mask_args': {},
'segm_enc_args': {},
'segm_predictor_args': {},
'loss_args': {"n_negatives": 10, "loss_args": {"reduction": "mean"}},
'num_epoch': 2}

class Conf:
    """Attribute-style view of a configuration dictionary."""

    def __init__(self, my_dict):
        # Mirror every dictionary entry as an instance attribute of the same name.
        for name in my_dict:
            setattr(self, name, my_dict[name])
            
config = Conf(cfg)

In [37]:
#!g1.1
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# For every checkpoint: restore the model, predict boundaries for the farfield
# and crowd test loaders, and write one file per utterance (named by its id)
# under path_results/<split>.
# NOTE(review): `thre` is unpacked but never used. Results go to fixed
# 'farfield' and 'crowd' folders, so several checkpoints would overwrite each
# other's output.
for model_path, name_path, thre in zip(segment_paths, names_paths, thres):
    os.makedirs(os.path.join(path_results, name_path), exist_ok=True)

    wav2vec_segm = FinModel(config)

    # map_location lets a checkpoint saved on one device be restored on another.
    checkpoint = torch.load(model_path, map_location=device)

    # Strip the 'wav2vec_segm.' prefix from the checkpoint keys so they match
    # the bare model's parameter names.
    state_dicts = OrderedDict()
    for key, value in checkpoint['state_dict'].items():
        state_dicts[key.replace('wav2vec_segm.', '')] = value
    wav2vec_segm.load_state_dict(state_dicts)

    wav2vec_segm.eval()
    wav2vec_segm = wav2vec_segm.to(device)

    for split_name, loader in (('farfield', test_farfield_loader), ('crowd', test_crowd_loader)):
        split_dir = os.path.join(path_results, split_name)
        os.makedirs(split_dir, exist_ok=True)

        for batch in tqdm(loader):
            # Inference only: do not build autograd graphs.
            with torch.no_grad():
                rr = wav2vec_segm.compute_all(batch['batch'].to(device), batch['boundaries'],
                                              num_epoch=0,
                                              attention_mask=batch['attention_mask'].to(device),
                                              return_secs=True)
            # One prediction list per utterance, stored as its str() representation.
            for idd, secs_pred in zip(batch['ids'], rr[1]['secs_pred']):
                with open(os.path.join(split_dir, idd), 'w', encoding="cp1251") as file:
                    file.write(str(secs_pred))

100%|██████████| 240/240 [00:46<00:00,  5.19it/s]
100%|██████████| 1237/1237 [04:14<00:00,  4.87it/s]


# Get results

In [78]:
#!g1.1
import pandas as pd
import os

In [79]:
#!g1.1
class RMetrics1(nn.Module):
    """Boundary-detection metrics (precision, recall, F1, R-value) with a frame tolerance.

    Args:
        eps: Small constant added to denominators in ``calc_metr``.
        tolerance: Largest distance, in frames, at which a predicted and a
            reference boundary still count as a match.
        sampling_rate: Samples per second, used to convert between seconds
            and frames.
    """

    def __init__(self, eps = 1e-5, tolerance = 2, sampling_rate = 16000):
        super(RMetrics1, self).__init__()
        self.tolerance = tolerance
        self.eps = eps
        self.sampling_rate = sampling_rate

    def calculate_stride(self, isz, conv_layers):
        """Propagate an input length through a stack of unpadded convolutions.

        Args:
            isz: Input length in samples.
            conv_layers: Iterable of ``(kernel, stride)`` pairs.

        Returns:
            Tuple ``(outsize, totstride, RFsize, ms_per_frame, ms_stride)``:
            output length, product of all strides, receptive-field size in
            samples, and the last two converted to milliseconds.
        """
        pad = 0
        insize = isz
        # Identity mapping by default, so an empty layer list is well defined.
        outsize = isz
        totstride = 1
        sec_per_frame = 1/self.sampling_rate

        for kernel, stride in conv_layers:
            outsize = (insize + 2*pad - 1*(kernel-1)-1) / stride + 1
            insize = outsize
            totstride = totstride * stride

        RFsize = isz - (outsize - 1) * totstride

        ms_per_frame = sec_per_frame*RFsize*1000
        ms_stride = sec_per_frame*totstride*1000
        return outsize, totstride, RFsize, ms_per_frame, ms_stride

    def get_frames(self, secs, stride):
        """Convert nested lists of times in seconds to (truncated) frame indices."""
        frames = [[int(i*self.sampling_rate/stride) for i in sec] for sec in secs]
        return frames

    def make_true_boundaries(self, secs, boundaries, stride):
        """Build a 0/1 array shaped like ``boundaries`` with ones at the reference frames.

        ``boundaries`` is only used for its shape; row ``n`` receives the
        frames derived from ``secs[n]``.
        """
        frames = self.get_frames(secs, stride)
        true_boundaries = torch.zeros(size = boundaries.shape)
        for num_frame, frame in enumerate(frames):
            for i in frame:
                true_boundaries[num_frame, i] = 1
        return true_boundaries.long().detach().numpy()

    @staticmethod
    def _to_numpy_mask(b, attention_mask = None):
        """Return ``b`` as a NumPy integer array, zeroed where ``attention_mask`` is 0."""
        if isinstance(b, torch.Tensor):
            # .cpu() so tensors living on the GPU can be converted as well.
            b = b.long().detach().cpu().numpy()
        if attention_mask is not None:
            b = b*attention_mask.long().detach().cpu().numpy()
        return b

    def get_sec_bounds(self, b, stride, attention_mask = None):
        """Extract predicted boundary positions from a 2-D 0/1 indicator.

        Returns:
            ``(frames_pred, secs_pred)``: per row, the indices where the
            indicator equals 1, and those indices converted to seconds.
        """
        b1 = self._to_numpy_mask(b, attention_mask)

        frames_pred = []
        secs_pred = []
        for row in range(b1.shape[0]):
            frames = np.where(b1[row, :] == 1)[0]
            frames_pred.append(frames)
            secs_pred.append([frame*stride/self.sampling_rate for frame in frames])
        return frames_pred, secs_pred

    def get_precision_recall_frames(self, true_boundaries, b, attention_mask = None):
        """Frame-level recall, precision and F1 of the positive (== 1) class.

        Computed with NumPy; the scikit-learn helpers the method used to call
        are not imported anywhere in this notebook. A score whose denominator
        is zero is reported as 0.0.

        Returns:
            ``(recall, precision, f_score)`` — note the order.
        """
        b1 = self._to_numpy_mask(b, attention_mask)

        y_true = np.asarray(true_boundaries).flatten() == 1
        y_pred = np.asarray(b1).flatten() == 1
        true_positives = int(np.sum(y_true & y_pred))
        n_true = int(y_true.sum())
        n_pred = int(y_pred.sum())

        recall = true_positives / n_true if n_true else 0.0
        pre = true_positives / n_pred if n_pred else 0.0
        f_score = 2 * true_positives / (n_true + n_pred) if (n_true + n_pred) else 0.0
        return recall, pre, f_score

    def get_stats(self, frames_true, frames_pred):
        """Count tolerance-based matches between reference and predicted boundaries.

        Returns:
            ``(precision_counter, recall_counter, pred_counter, gt_counter)``:
            predictions lying within ``tolerance`` frames of some reference,
            references lying within ``tolerance`` frames of some prediction,
            and the total numbers of predictions and references.
        """
        # Borrowed from: https://github.com/felixkreuk/UnsupSeg/blob/68c2c7b9bd49f3fb8f51c5c2f4d5aa85f251eaa8/utils.py#L69
        precision_counter = 0
        recall_counter = 0
        pred_counter = 0
        gt_counter = 0

        for (y, yhat) in zip(frames_true, frames_pred):
            y_arr = np.asarray(y)
            yhat_arr = np.asarray(yhat)
            pred_counter += len(yhat)
            gt_counter += len(y)
            # With no references or no predictions nothing can match; taking
            # .min() of an empty array would raise.
            if y_arr.size == 0 or yhat_arr.size == 0:
                continue
            distances = np.abs(y_arr[:, None] - yhat_arr[None, :])
            precision_counter += int((distances.min(axis=0) <= self.tolerance).sum())
            recall_counter += int((distances.min(axis=1) <= self.tolerance).sum())

        return precision_counter, recall_counter, pred_counter, gt_counter

    def calc_metr(self, precision_counter, recall_counter, pred_counter, gt_counter):
        """Turn match counts into ``(precision, recall, f1, rval)``."""
        # Borrowed from: https://github.com/felixkreuk/UnsupSeg/blob/68c2c7b9bd49f3fb8f51c5c2f4d5aa85f251eaa8/utils.py#L69
        EPS = 1e-7

        precision = precision_counter / (pred_counter + self.eps)
        recall = recall_counter / (gt_counter + self.eps)
        f1 = 2 * (precision * recall) / (precision + recall + self.eps)

        # Named `over_seg` rather than `os` to avoid shadowing the os module.
        over_seg = recall / (precision + EPS) - 1
        r1 = np.sqrt((1 - recall) ** 2 + over_seg ** 2)
        r2 = (-over_seg + recall - 1) / (np.sqrt(2))
        rval = 1 - (np.abs(r1) + np.abs(r2)) / 2

        return precision, recall, f1, rval

    def get_metrics(self, true_secs, b, seq_len, config, attention_mask = None,
                    return_secs=False):
        """Score a predicted 0/1 boundary indicator against reference times.

        Args:
            true_secs: Per-utterance lists of reference boundary times (seconds).
            b: Predicted boundary indicator, one row per utterance.
            seq_len: Input length passed to ``calculate_stride``.
            config: Iterable of ``(kernel, stride)`` pairs passed to
                ``calculate_stride``.
            attention_mask: Optional mask multiplied into ``b``.
            return_secs: When true, also return the predicted times in seconds.
        """
        outsize, totstride, RFsize, ms_per_frame, ms_stride = self.calculate_stride(seq_len, config)
        frames_true = self.get_frames(true_secs, totstride)
        frames_pred, secs_pred = self.get_sec_bounds(b, totstride, attention_mask = attention_mask)
        precision_counter, recall_counter, pred_counter, gt_counter = self.get_stats(frames_true, frames_pred)
        precision, recall, f1, rval = self.calc_metr(precision_counter, recall_counter, pred_counter, gt_counter)
        if return_secs:
            return precision, recall, f1, rval, secs_pred
        else:
            return precision, recall, f1, rval

    def get_metrics_secs(self, true_secs, secs_pred, totstride = 160):
        """Score predicted boundary times against reference times, both in seconds."""
        frames_true = self.get_frames(true_secs, totstride)
        frames_pred = self.get_frames(secs_pred, totstride)
        precision_counter, recall_counter, pred_counter, gt_counter = self.get_stats(frames_true, frames_pred)
        precision, recall, f1, rval = self.calc_metr(precision_counter, recall_counter, pred_counter, gt_counter)
        return precision, recall, f1, rval
    
    

In [80]:
#!g1.1
def read_true_file(segment_file):
    """Read reference boundaries from a cp1251-encoded segment file.

    Every non-blank line must start with two numeric fields; any further
    fields on the line are ignored.

    Returns:
        Sorted list of the unique values found in those two columns.
    """
    with open(segment_file, 'r', encoding="cp1251") as file:
        rows = [line.split() for line in file.read().split('\n') if line.strip()]

    boundaries = set()
    for row in rows:
        # ast.literal_eval parses the numeric literal like eval() did, but
        # cannot execute code embedded in the file.
        boundaries.add(ast.literal_eval(row[0]))
        boundaries.add(ast.literal_eval(row[1]))
    return sorted(boundaries)

In [81]:
#!g1.1
def read_pred_file(segment_file):
    """Read a prediction file written as the ``str()`` of a list of boundary times.

    Returns:
        The parsed list without its first element when it holds more than one
        value, otherwise the list unchanged.
        NOTE(review): presumably the first prediction is the utterance start
        and is dropped on purpose — confirm.
    """
    with open(segment_file, 'r', encoding="cp1251") as file:
        text = file.read().strip()
    # NumPy >= 2 prints scalars as e.g. ``np.float64(0.5)``; unwrap them so
    # files written under either NumPy version parse as plain literals.
    text = re.sub(r'np\.\w+\(([^()]*)\)', r'\1', text)
    # literal_eval instead of eval: the file content is never executed.
    boundaries = ast.literal_eval(text)
    if len(boundaries) > 1:
        return boundaries[1:]
    return boundaries

In [82]:
#!g1.1
class Dataset_res:
    """Annotation-only dataset: yields reference boundaries without loading audio.

    Args:
        path: Stored on the instance; not used by ``__getitem__``.
        segment_path: Root directory of the reference segmentation files.
        manifest_path: JSON file holding a list of records; every record must
            provide the keys ``audio_filepath`` and ``id``.
        chars: First sub-directory appended to an utterance's segment directory.
        frames: Second sub-directory, appended after ``chars``.
    """

    def __init__(self, path, segment_path, manifest_path, chars = 'segments_chars/', frames = 'secs/'):
        with open(manifest_path, 'r') as json_file:
            self.manifest = json.load(json_file)
        self.path = path
        self.segment_path = segment_path
        self.frames = os.path.join(chars, frames)

    def __len__(self):
        """Return the number of manifest records."""
        return len(self.manifest)

    def __getitem__(self, ind):
        """Return ``segment_file``, ``id`` and sorted unique ``boundaries`` for record ``ind``.

        NOTE(review): every occurrence of ``'files/'`` is removed from the
        whole joined segment path, including any in ``segment_path`` itself.
        """
        record = self.manifest[ind]
        filetext_id = record['id'] + '.txt'
        segment_filepath = record['audio_filepath'].replace('.wav', '.txt').replace(filetext_id, '')

        segment_dir = os.path.join(self.segment_path, segment_filepath)
        segment_file = os.path.join(os.path.join(segment_dir, self.frames), filetext_id).replace('files/', '')
        with open(segment_file, 'r', encoding="cp1251") as file:
            rows = [line.split() for line in file.read().split('\n') if line.strip()]

        boundaries = set()
        for row in rows:
            # ast.literal_eval parses the numeric literal like eval() did, but
            # cannot execute code embedded in the file.
            boundaries.add(ast.literal_eval(row[0]))
            boundaries.add(ast.literal_eval(row[1]))
        boundaries = sorted(boundaries)

        return {'segment_file': segment_file,
                'id': filetext_id,
                'boundaries': boundaries}
        

In [83]:
#!g1.1
def collate_fn_res(samples):
    """Collate annotation-only items for metric computation; no audio is batched.

    Args:
        samples: Items providing ``boundaries``, ``id`` and ``segment_file``;
            ``audio_file`` is optional because ``Dataset_res`` items do not
            carry it.

    Returns:
        dict with the same keys as before: ``batch`` is always ``None`` and
        ``lengths`` is always an empty list (there is no audio), while
        ``boundaries``, ``ids``, ``audio_files`` (``None`` where missing) and
        ``segment_files`` are per-item lists in input order.
    """
    # `batch` used to be read from an undefined local, i.e. whatever global
    # `batch` an earlier cell happened to leave behind.
    return dict(batch=None, lengths=[],
                boundaries=[sample['boundaries'] for sample in samples],
                ids=[sample['id'] for sample in samples],
                audio_files=[sample.get('audio_file') for sample in samples],
                segment_files=[sample['segment_file'] for sample in samples])

In [84]:
#!g1.1
# Rebuild all four datasets for the metrics section; the arguments match the
# ones used in the prediction sections above.
train_dataset = Dataset('train/', 'segments_edges_init/', '10hours_train.json', 
                        edges_path = 'silero_edges',
                              chars = 'segments_chars/', frames = 'secs/')
val_dataset = Dataset('train/', 'segments_edges_init/', '10hours_val.json', 
                        edges_path = 'silero_edges',
                              chars = 'segments_chars/', frames = 'secs/')

# test_dataset = Dataset_res('test/', 'segments_edges_init_test/', 'manifest_test.json', 
#                               chars = 'segments_chars/', frames = 'secs/')
test_farfield_dataset = Dataset_test('test/', 'segments_edges_init_test/', 'manifest_test_farfield.json', 
                        edges_path = 'silero_edges_test',
                              chars = 'segments_chars/', frames = 'secs/')
test_crowd_dataset = Dataset_test('test/', 'segments_edges_init_test/', 'manifest_test_crowd.json', 
                        edges_path = 'silero_edges_test',
                              chars = 'segments_chars/', frames = 'secs/')

# NOTE(review): the metrics cell below indexes the datasets directly and never
# iterates these loaders. The train/val loaders here use collate_fn_res,
# unlike the ones built in the "Predict Train" section.
test_farfield_loader = DataLoader(test_farfield_dataset, shuffle=True, batch_size=8, collate_fn = collate_fn)
test_crowd_loader = DataLoader(test_crowd_dataset, shuffle=True, batch_size=8, collate_fn = collate_fn)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=8, collate_fn = collate_fn_res)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=8, collate_fn = collate_fn_res)
# test_loader = DataLoader(test_dataset, shuffle=True, batch_size=8, collate_fn = collate_fn_res)

In [90]:
#!g1.1
# Root output directory and its 'Results' sub-directory holding the predictions.
path_to_save = 'save_results_path_10_hours_conv'
path_results = os.path.join(path_to_save, 'Results')
# Prediction sub-directories, in the same order as the datasets zipped with
# them in the metrics cell.
folders = ['train', 'val', 'farfield', 'crowd']

In [91]:
#!g1.1
# Compute precision / recall / F1 / R-value per split from the saved predictions.
result_dataframes = []
metr = RMetrics1()
# NOTE(review): `wav2vec_segm` is the model left in the kernel by the inference
# loops, so this cell fails on a fresh kernel unless those cells ran first.
# Only `totstride` is used below; it is the product of the layer strides and
# does not depend on the input length 100000 passed here.
outsize, totstride, RFsize, ms_per_frame, ms_stride = metr.calculate_stride(100000, 
                                                                            wav2vec_segm.conv_layers_list)
for folder, dataset in zip(folders, [train_dataset, val_dataset, test_farfield_dataset, test_crowd_dataset]):

    secs_preds = []
    secs_trues = []

    for i in tqdm(range(len(dataset))):
        
        # NOTE(review): dataset[i] also reads and trims the audio file,
        # although only the id and the reference boundaries are used here.
        idd = dataset[i]
        true_file = idd['segment_file']
        bound_true = idd['boundaries']
        ids = idd['id']
        
#         bound_true = read_true_file(true_files[num])
        # Predictions were saved as <path_results>/<folder>/<id>.
        bound_pred = read_pred_file(os.path.join(os.path.join(path_results, folder), ids))

        secs_trues.append(bound_true)
        secs_preds.append(bound_pred)
   

    # Counts are pooled over every utterance of the split before the ratios are taken.
    precision, recall, f1, rval = metr.get_metrics_secs(secs_trues, secs_preds, totstride = totstride)
    
    # One-row frame: the five values form a column, are transposed into a row,
    # and the positional column labels are renamed.
    datafr = pd.DataFrame([folder, 
                           precision, recall, 
                           f1, rval]).T.rename(columns = {0:'type', 
                                                          1:'precision',
                                                          2:'recall', 
                                                          3:'f1', 
                                                          4:'rval'})
    
    result_dataframes.append(datafr)

100%|██████████| 8398/8398 [00:32<00:00, 257.57it/s]
100%|██████████| 500/500 [00:06<00:00, 79.87it/s]
100%|██████████| 1915/1915 [00:18<00:00, 102.41it/s]
100%|██████████| 9889/9889 [00:53<00:00, 185.97it/s]


In [92]:
#!g1.1
# Stack the per-split one-row frames into a single table with a fresh 0..n-1 index.
result_df = pd.concat(result_dataframes, ignore_index = True)

In [93]:
#!g1.1
# Display the metrics table.
result_df

Unnamed: 0,type,precision,recall,f1,rval
0,train,0.634034,0.64674,0.640319,0.691105
1,val,0.627283,0.646142,0.636568,0.686697
2,farfield,0.569982,0.688024,0.62346,0.629251
3,crowd,0.566495,0.639342,0.600713,0.635576


In [94]:
#!g1.1
# Save the metrics table. The index is written too, so reading the file back
# yields an extra unnamed first column.
result_df.to_csv('results_golos_sep_test_conv1.csv')

In [None]:
#!g1.1
