In [1]:
#!g1.1
import json
import os
import torch
from tqdm import tqdm
import scipy.io.wavfile as wav

from torch import nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import Dataset, DataLoader, Sampler
from torch.utils.tensorboard import SummaryWriter

import numpy as np
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
import math

import torchaudio
from boltons.fileutils import iter_find_files
from collections import OrderedDict

In [2]:
#!g1.1
# %pip install pytorch_lightning --upgrade

In [3]:
#!g1.1
from models import ConvFeatureEncoder, SegmentsRepr, SegmentsEncoder, NegativeSampler, SegmentPredictor, FinModel, FinModel1
from utils import ConstrativeLoss, sample_negatives
# from trainer import Trainer
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger

In [4]:
#!g1.1
# import wandb

In [5]:
#!g1.1
def spectral_size(wav_len):
    """Return the output length of a fixed stack of five 1-D convolutions.

    Each ``(kernel, stride, padding)`` triple is applied in turn with the usual
    convolution output-length formula (dilation 1); the result of one layer is
    the input length of the next.
    """
    conv_stack = ((10, 5, 0), (8, 4, 0), (4, 2, 0), (4, 2, 0), (4, 2, 0))
    n_frames = wav_len
    for k, s, p in conv_stack:
        span = n_frames + 2 * p - 1 * (k - 1) - 1
        n_frames = math.floor(span / s + 1)
    return n_frames

In [6]:
#!g1.1
# Данный класс основан на https://github.com/felixkreuk/UnsupSeg/blob/master/dataloader.py

class WavPhnDataset(Dataset):
    """Dataset of ``.wav`` files paired with same-named ``.txt`` boundary files.

    Every ``*.wav`` found under ``path`` is one item. The label file is expected
    next to the audio with the extension swapped to ``.txt``; the first
    whitespace-separated token of each non-blank line is read as a number.
    """

    # Factor applied to every boundary value to build the ``phonemes`` tensor
    # (presumably the sample rate, i.e. seconds -> sample offsets; confirm).
    SAMPLE_RATE = 16000

    def __init__(self, path):
        """Collect every ``*.wav`` file found under ``path``."""
        super(WavPhnDataset, self).__init__()
        self.path = path
        self.data = list(iter_find_files(self.path, "*.wav"))

    @staticmethod
    def get_datasets(path):
        """Placeholder for subclasses that know how to build dataset splits."""
        raise NotImplementedError

    @classmethod
    def _read_boundaries(cls, phn_path):
        """Parse a label file into ``(times, phonemes)`` float tensors.

        ``times`` holds the first token of every non-blank line; ``phonemes``
        holds the same values multiplied by ``SAMPLE_RATE``. The last entry is
        dropped from both, as the original implementation did.

        The tokens are parsed with ``float`` rather than ``eval`` so that the
        content of a label file can never be executed as code.
        """
        with open(phn_path, "r") as f:
            values = [float(line.split()[0]) for line in f if line.strip()]
        times = torch.FloatTensor(values)[:-1]
        # Scaled in Python floats before the float32 conversion, like the original.
        phonemes = torch.FloatTensor([v * cls.SAMPLE_RATE for v in values])[:-1]
        return times, phonemes

    def process_file(self, wav_path):
        """Load one recording and its boundaries.

        Returns:
            ``(audio, times, phonemes, wav_path, filetext_id, phn_path)`` where
            ``audio`` is the first channel returned by ``torchaudio.load`` and
            ``filetext_id`` is the base name of the label file.
        """
        # Swap only the extension; str.replace would also rewrite ".wav" inside
        # a directory name.
        phn_path = os.path.splitext(wav_path)[0] + ".txt"
        filetext_id = os.path.basename(phn_path)

        audio, _sample_rate = torchaudio.load(wav_path)
        audio = audio[0]  # keep the first channel only

        times, phonemes = self._read_boundaries(phn_path)
        return audio, times, phonemes, wav_path, filetext_id, phn_path

    def spectral_size(self, wav_len):
        """Return the output length of the fixed five-layer 1-D convolution stack."""
        layers = [(10, 5, 0), (8, 4, 0), (4, 2, 0), (4, 2, 0), (4, 2, 0)]
        for kernel, stride, padding in layers:
            wav_len = math.floor((wav_len + 2*padding - 1*(kernel-1) - 1)/stride + 1)
        return wav_len

    def __getitem__(self, idx):
        """Return the item at ``idx`` as a dict consumed by ``collate_fn``."""
        signal, seg, _phonemes, fname, filetext_id, segment_file = self.process_file(self.data[idx])
        n_samples = len(signal)

        return {'audio_file': fname,
                'segment_file': segment_file,
                'id': filetext_id,
                'sample': signal,
                'length': n_samples,
                'spectral_size': self.spectral_size(n_samples),
                'boundaries': seg}

    def __len__(self):
        """Number of ``.wav`` files discovered at construction time."""
        return len(self.data)

In [7]:
#!g1.1
# Данный класс основан на https://github.com/felixkreuk/UnsupSeg/blob/master/dataloader.py

class TrainTestDataset(WavPhnDataset):
    """WavPhnDataset variant that builds train/val/test splits from one root folder."""

    def __init__(self, path):
        super().__init__(path)

    @staticmethod
    def get_datasets(path, val_ratio=0.1):
        """Return ``(train, val, test)`` datasets rooted at ``path``.

        ``<path>/TRAIN`` is randomly split into train and validation subsets
        (``val_ratio`` of the items go to validation); ``<path>/TEST`` is used
        as the test set unchanged.
        """
        full_train = TrainTestDataset(os.path.join(path, 'TRAIN'))
        test_dataset = TrainTestDataset(os.path.join(path, 'TEST'))

        n_train = int(len(full_train) * (1 - val_ratio))
        n_val = len(full_train) - n_train
        train_dataset, val_dataset = torch.utils.data.random_split(full_train, [n_train, n_val])

        # Both subsets get a ``path`` attribute pointing at the lower-case 'train' folder.
        for subset in (train_dataset, val_dataset):
            subset.path = os.path.join(path, 'train')

        return train_dataset, val_dataset, test_dataset

In [8]:
#!g1.1
def collate_fn(samples):
    """Zero-pad a list of dataset items into one batch.

    Args:
        samples: list of dicts with the keys ``'sample'`` (1-D signal),
            ``'length'``, ``'spectral_size'``, ``'boundaries'``, ``'id'``,
            ``'audio_file'`` and ``'segment_file'``.

    Returns:
        dict with
        ``batch`` -- ``(len(samples), max_length)`` tensor, signals right-padded with 0;
        ``lengths`` -- float tensor of the unpadded lengths;
        ``attention_mask`` -- same shape as ``batch``, 1 on real samples, 0 on padding;
        ``boundaries``, ``ids``, ``audio_files``, ``segment_files`` -- per-item lists;
        ``spectral_size`` -- float tensor of the per-item spectral sizes.
    """
    lengths = [sample['length'] for sample in samples]
    max_length = max(lengths)
    # torch.Tensor(...) in the original produced the default float dtype; keep it.
    dtype = torch.get_default_dtype()

    # Fill preallocated tensors instead of round-tripping every signal through a
    # Python list, which created one Python object per audio sample.
    batch = torch.zeros((len(samples), max_length), dtype=dtype)
    attention_mask = torch.zeros((len(samples), max_length), dtype=dtype)
    for row, (sample, length) in enumerate(zip(samples, lengths)):
        batch[row, :length] = torch.as_tensor(sample['sample'], dtype=dtype)
        attention_mask[row, :length] = 1

    return dict(batch=batch,
                lengths=torch.Tensor(lengths),
                attention_mask=attention_mask,
                boundaries=[sample['boundaries'] for sample in samples],
                ids=[sample['id'] for sample in samples],
                audio_files=[sample['audio_file'] for sample in samples],
                segment_files=[sample['segment_file'] for sample in samples],
                spectral_size=torch.Tensor([sample['spectral_size'] for sample in samples]))

# Загрузка Данных

In [9]:
#!g1.1
# One dataset per split folder under Buckeye_fin.
train_dataset, val_dataset, test_dataset = (
    WavPhnDataset(f'Buckeye_fin/{split}') for split in ('Train', 'Valid', 'Test')
)

In [10]:
#!g1.1
# Batches of 8 padded by collate_fn; only the training loader shuffles.
train_loader, val_loader, test_loader = (
    DataLoader(split_dataset, shuffle=do_shuffle, batch_size=8, collate_fn=collate_fn)
    for split_dataset, do_shuffle in (
        (train_dataset, True),
        (val_dataset, False),
        (test_dataset, False),
    )
)

In [11]:
#!g1.1
# Root output folder; each model's predictions are written to its own sub-folder here.
path_results = 'save_results_path_compare_buckeye'

In [12]:
#!g1.1


# Segment Model

In [24]:
#!g1.1
# Checkpoint restored with torch.load for the segment model (FinModel).
model_path = 'golos_model_segment_r_val_acc_200_edges_train_buckeye_model_segment-v2.ckpt'

In [25]:
#!g1.1
# Settings for the segment model; wrapped in Conf before being passed to the
# model constructor. The learning rate is scaled by the accumulation factor.
accumulate_grad_batches = 1
cfg = dict(
    optimizer="adam",
    momentum=0.9,
    learning_rate=0.0001 * accumulate_grad_batches,
    lr_anneal_gamma=1.0,
    lr_anneal_step=1000,
    # epochs=500,
    grad_clip=0.5,
    batch_size=8,
    conv_args={},
    mask_args={"segment": "first", "add_one": False},
    segm_enc_args={},
    segm_predictor_args={},
    loss_args={"n_negatives": 1, "loss_args": {"reduction": "mean"}},
    num_epoch=2,
)

In [26]:
#!g1.1
class Conf:
    """Attribute-style wrapper: every key of ``my_dict`` becomes an instance attribute."""

    def __init__(self, my_dict):
        for entry in my_dict.items():
            setattr(self, *entry)
            
# Expose the cfg entries as attributes; this object is what FinModel receives.
config = Conf(cfg)

In [27]:
#!g1.1
# Restore the trained segment model and write its predicted boundaries for
# every test utterance to <path_results>/model_segment/<utterance id>.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
name_path = 'model_segment'
output_dir = os.path.join(path_results, name_path)
os.makedirs(output_dir, exist_ok=True)  # created once; it does not depend on the batch

model = FinModel(config)
# map_location lets a checkpoint saved on GPU also open on a CPU-only host.
checkpoint = torch.load(model_path, map_location=device)

# The checkpoint keys carry a 'wav2vec_segm.' prefix (presumably added by the
# training wrapper -- confirm); strip it so the keys match the bare model.
state_dicts = OrderedDict(
    (key.replace('wav2vec_segm.', ''), value)
    for key, value in checkpoint['state_dict'].items()
)
model.load_state_dict(state_dicts)

model.eval()
model = model.to(device)

# Inference only: no_grad avoids building the autograd graph for every batch.
with torch.no_grad():
    for batch in tqdm(test_loader):
        x = batch['batch'].to(device)
        attention_mask = batch['attention_mask'].to(device)

        rr = model.compute_all(x, batch['boundaries'], num_epoch=0,
                               attention_mask=attention_mask, return_secs=True)
        secs_preds = rr[1]['secs_pred']
        # One file per utterance, named after its label file.
        for idd, secs_pred in zip(batch['ids'], secs_preds):
            with open(os.path.join(output_dir, idd), 'w', encoding="cp1251") as file:
                file.write(str(secs_pred))

To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at  ../aten/src/ATen/native/BinaryOps.cpp:467.)
  return torch.floor_divide(self, other)
100%|██████████| 692/692 [07:41<00:00,  1.50it/s]


In [28]:
#!g1.1


# Peak Detection Model

In [29]:
#!g1.1
# Checkpoint restored with torch.load for the peak-detection model (FinModel1).
model_path = 'golos_model_segment_r_val_acc_200_edges_train_buckeye_peak_detection-v1.ckpt'

In [30]:
#!g1.1
# Settings for the peak-detection model; wrapped in Conf before being passed
# to the model constructor. The learning rate is scaled by the accumulation factor.
accumulate_grad_batches = 1
cfg = dict(
    optimizer="adam",
    momentum=0.9,
    learning_rate=0.0001 * accumulate_grad_batches,
    lr_anneal_gamma=1.0,
    lr_anneal_step=1000,
    # epochs=500,
    grad_clip=0.5,
    batch_size=8,
    conv_args={},
    mask_args={"segment": "first", "add_one": False},
    segm_enc_args={},
    segm_predictor_args={},
    loss_args={"n_negatives": 1, "loss_args": {"reduction": "mean"}},
    num_epoch=2,
    use_projection=False,
)

In [31]:
#!g1.1
class Conf:
    """Attribute-style wrapper: every key of ``my_dict`` becomes an instance attribute."""

    def __init__(self, my_dict):
        for entry in my_dict.items():
            setattr(self, *entry)
            
# Expose the cfg entries as attributes; this object is what FinModel1 receives.
config = Conf(cfg)

In [32]:
#!g1.1
# Restore the trained peak-detection model and write its predicted boundaries
# for every test utterance to <path_results>/peak_detection/<utterance id>.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
name_path = 'peak_detection'
output_dir = os.path.join(path_results, name_path)
os.makedirs(output_dir, exist_ok=True)  # created once; it does not depend on the batch

model = FinModel1(config)
# map_location lets a checkpoint saved on GPU also open on a CPU-only host.
checkpoint = torch.load(model_path, map_location=device)

# The checkpoint keys carry a 'wav2vec_segm.' prefix (presumably added by the
# training wrapper -- confirm); strip it so the keys match the bare model.
state_dicts = OrderedDict(
    (key.replace('wav2vec_segm.', ''), value)
    for key, value in checkpoint['state_dict'].items()
)
model.load_state_dict(state_dicts)

model.eval()
model = model.to(device)

# Inference only: no_grad avoids building the autograd graph for every batch.
with torch.no_grad():
    for batch in tqdm(test_loader):
        x = batch['batch'].to(device)
        attention_mask = batch['attention_mask'].to(device)

        # spectral_size is passed exactly as the loader produced it (not moved to the device).
        rr = model.compute_all(x, batch['boundaries'], num_epoch=0,
                               attention_mask=attention_mask,
                               spectral_size=batch['spectral_size'],
                               return_secs=True)
        secs_preds = rr[1]['secs_pred']
        # One file per utterance, named after its label file.
        for idd, secs_pred in zip(batch['ids'], secs_preds):
            with open(os.path.join(output_dir, idd), 'w', encoding="cp1251") as file:
                file.write(str(list(secs_pred)))

100%|██████████| 692/692 [07:12<00:00,  1.60it/s]


In [33]:
#!g1.1


# Wav2Vec2 Model

In [14]:
#!g1.1
from modeling_segmentation import Wav2Vec2ModelForSegmentation

In [18]:
#!g1.1
# Checkpoint restored with torch.load for the Wav2Vec2ModelForSegmentation model.
model_path = 'buckeye_r_val_negs1-v2.ckpt'

In [19]:
#!g1.1


In [20]:
#!g1.1

# Restore the fine-tuned wav2vec2 segmentation model and write its predicted
# boundaries for every test utterance to <path_results>/wav2vec_model/<utterance id>.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
name_path = 'wav2vec_model'
output_dir = os.path.join(path_results, name_path)
os.makedirs(output_dir, exist_ok=True)  # created once; it does not depend on the batch

# Start from the public base weights, then overwrite them with the checkpoint.
model = Wav2Vec2ModelForSegmentation.from_pretrained("facebook/wav2vec2-base-960h")
# map_location lets a checkpoint saved on GPU also open on a CPU-only host.
checkpoint = torch.load(model_path, map_location=device)

# The checkpoint keys carry a 'wav2vec_segm.' prefix (presumably added by the
# training wrapper -- confirm); strip it so the keys match the bare model.
state_dicts = OrderedDict(
    (key.replace('wav2vec_segm.', ''), value)
    for key, value in checkpoint['state_dict'].items()
)
model.load_state_dict(state_dicts)

model.eval()
model = model.to(device)

# Inference only: no_grad avoids building the autograd graph for every batch.
with torch.no_grad():
    for batch in tqdm(test_loader):
        x = batch['batch'].to(device)
        attention_mask = batch['attention_mask'].to(device)

        rr = model.compute_all(x, batch['boundaries'], num_epoch=0,
                               attention_mask=attention_mask,
                               return_secs=True)
        secs_preds = rr[1]['secs_pred']
        # One file per utterance, named after its label file.
        for idd, secs_pred in zip(batch['ids'], secs_preds):
            with open(os.path.join(output_dir, idd), 'w', encoding="cp1251") as file:
                file.write(str(secs_pred))

Some weights of the model checkpoint at facebook/wav2vec2-base-960h were not used when initializing Wav2Vec2ModelForSegmentation: ['lm_head.bias', 'lm_head.weight']
- This IS expected if you are initializing Wav2Vec2ModelForSegmentation from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ModelForSegmentation from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ModelForSegmentation were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
To keep the current behavior, use torch.div(a, b, ro

In [21]:
#!g1.1


# Агрегация данных

In [22]:
#!g1.1
import pandas as pd

In [34]:
#!g1.1
# Данная функция основана на https://github.com/felixkreuk/UnsupSeg/blob/master/utils.py

class RMetrics1(nn.Module):
    """Boundary-detection metrics: precision, recall, F1 and R-value.

    A predicted boundary counts as a hit when it lies within ``tolerance``
    units of some reference boundary (and vice versa for recall).

    Args:
        eps: small constant added to denominators to avoid division by zero.
        tolerance: maximum distance for a predicted/reference pair to match.
        sampling_rate: audio sampling rate, used to convert seconds to frames.
    """

    def __init__(self, eps = 1e-5, tolerance = 2, sampling_rate = 16000):
        super(RMetrics1, self).__init__()
        self.tolerance = tolerance
        self.eps = eps
        self.sampling_rate = sampling_rate

    @staticmethod
    def _to_numpy(values):
        """Return tensors as int64 numpy arrays on the CPU; pass anything else through."""
        if isinstance(values, torch.Tensor):
            return values.long().detach().cpu().numpy()
        return values

    def _masked_numpy(self, b, attention_mask):
        """Convert ``b`` to numpy and zero it wherever ``attention_mask`` is 0."""
        b1 = self._to_numpy(b)
        if attention_mask is not None:
            b1 = b1 * self._to_numpy(attention_mask)
        return b1

    def calculate_stride(self, isz, conv_layers):
        """Propagate an input length through ``(kernel, stride)`` conv layers.

        Returns:
            ``(outsize, totstride, RFsize, ms_per_frame, ms_stride)``: the
            (unrounded) output length, the product of all strides, the
            receptive-field size in samples, and the last two expressed in
            milliseconds at ``sampling_rate``.
        """
        pad = 0
        insize = isz
        outsize = isz  # keeps the result defined when conv_layers is empty
        totstride = 1
        sec_per_frame = 1/self.sampling_rate

        for kernel, stride in conv_layers:
            outsize = (insize + 2*pad - 1*(kernel-1)-1) / stride + 1
            insize = outsize
            totstride = totstride * stride

        RFsize = isz - (outsize - 1) * totstride

        ms_per_frame = sec_per_frame*RFsize*1000
        ms_stride = sec_per_frame*totstride*1000
        return outsize, totstride, RFsize, ms_per_frame, ms_stride

    def get_frames(self, secs, stride):
        """Convert per-utterance boundary times (seconds) to integer frame indices."""
        frames = [[int(i*self.sampling_rate/stride) for i in sec] for sec in secs]
        return frames

    def make_true_boundaries(self, secs, boundaries, stride):
        """Build a 0/1 array shaped like ``boundaries`` with ones at the reference frames."""
        frames = self.get_frames(secs, stride)
        true_boundaries = torch.zeros(size = boundaries.shape)
        for num_frame, frame in enumerate(frames):
            for i in frame:
                true_boundaries[num_frame, i] = 1
        return true_boundaries.long().detach().numpy()

    def get_sec_bounds(self, b, stride, attention_mask = None):
        """Turn a 0/1 boundary matrix into per-row frame indices and times.

        Returns:
            ``(frames_pred, secs_pred)``: for every row, the indices where the
            (optionally masked) matrix equals 1, and those indices converted
            to seconds via ``stride / sampling_rate``.
        """
        b1 = self._masked_numpy(b, attention_mask)

        frames_pred = []
        secs_pred = []
        for row in range(b1.shape[0]):
            frames = np.where(b1[row, :] == 1)[0]
            secs = [frame*stride/self.sampling_rate for frame in frames]
            frames_pred.append(frames)
            secs_pred.append(secs)
        return frames_pred, secs_pred

    def get_precision_recall_frames(self, true_boundaries, b, attention_mask = None):
        """Exact-match frame-level recall, precision and F1 (in that order).

        Any non-zero entry is treated as a boundary. A score whose denominator
        is zero is reported as 0.0. Computed with numpy because the
        scikit-learn scorers the original called were never imported.
        """
        pred = np.asarray(self._masked_numpy(b, attention_mask)).flatten().astype(bool)
        true = np.asarray(self._to_numpy(true_boundaries)).flatten().astype(bool)

        tp = int(np.count_nonzero(true & pred))
        fp = int(np.count_nonzero(~true & pred))
        fn = int(np.count_nonzero(true & ~pred))

        recall = tp / (tp + fn) if (tp + fn) else 0.0
        pre = tp / (tp + fp) if (tp + fp) else 0.0
        f_score = 2 * tp / (2 * tp + fp + fn) if (2 * tp + fp + fn) else 0.0
        return recall, pre, f_score

    def get_stats(self, frames_true, frames_pred):
        """Count tolerance-based hits over all utterances.

        Returns:
            ``(precision_counter, recall_counter, pred_counter, gt_counter)``:
            predictions close to some reference, references close to some
            prediction, total predictions and total references.
        """
        # Taken from: https://github.com/felixkreuk/UnsupSeg/blob/68c2c7b9bd49f3fb8f51c5c2f4d5aa85f251eaa8/utils.py#L69
        precision_counter = 0
        recall_counter = 0
        pred_counter = 0
        gt_counter = 0

        for (y, yhat) in zip(frames_true, frames_pred):
            y_arr = np.asarray(y)
            yhat_arr = np.asarray(yhat)
            # With either side empty there can be no hits; the original crashed
            # on min() of an empty array when only the reference was empty.
            if y_arr.size and yhat_arr.size:
                dist = np.abs(y_arr[:, None] - yhat_arr[None, :])
                precision_counter += int(np.count_nonzero(dist.min(axis=0) <= self.tolerance))
                recall_counter += int(np.count_nonzero(dist.min(axis=1) <= self.tolerance))
            pred_counter += len(yhat)
            gt_counter += len(y)

        return precision_counter, recall_counter, pred_counter, gt_counter

    def calc_metr(self, precision_counter, recall_counter, pred_counter, gt_counter):
        """Turn the hit counters into ``(precision, recall, f1, rval)``."""
        # Taken from: https://github.com/felixkreuk/UnsupSeg/blob/68c2c7b9bd49f3fb8f51c5c2f4d5aa85f251eaa8/utils.py#L69
        EPS = 1e-7

        precision = precision_counter / (pred_counter + self.eps)
        recall = recall_counter / (gt_counter + self.eps)
        f1 = 2 * (precision * recall) / (precision + recall + self.eps)

        # Named over_seg so the local does not shadow the ``os`` module.
        over_seg = recall / (precision + EPS) - 1
        r1 = np.sqrt((1 - recall) ** 2 + over_seg ** 2)
        r2 = (-over_seg + recall - 1) / (np.sqrt(2))
        rval = 1 - (np.abs(r1) + np.abs(r2)) / 2

        return precision, recall, f1, rval

    def get_metrics(self, true_secs, b, seq_len, config, attention_mask = None,
                    return_secs=False):
        """Score a 0/1 boundary matrix ``b`` against reference times.

        ``config`` is the list of ``(kernel, stride)`` layers handed to
        ``calculate_stride``; its total stride converts seconds to frames.
        With ``return_secs=True`` the predicted times are appended to the
        returned ``(precision, recall, f1, rval)``.
        """
        outsize, totstride, RFsize, ms_per_frame, ms_stride = self.calculate_stride(seq_len, config)
        frames_true = self.get_frames(true_secs, totstride)
        frames_pred, secs_pred = self.get_sec_bounds(b, totstride, attention_mask = attention_mask)
        precision_counter, recall_counter, pred_counter, gt_counter = self.get_stats(frames_true, frames_pred)
        precision, recall, f1, rval = self.calc_metr(precision_counter, recall_counter, pred_counter, gt_counter)
        if return_secs:
            return precision, recall, f1, rval, secs_pred
        else:
            return precision, recall, f1, rval

    def get_metrics_secs(self, true_secs, secs_pred, totstride = 160):
        """Score predicted times against reference times after converting both to frames."""
        frames_true = self.get_frames(true_secs, totstride)
        frames_pred = self.get_frames(secs_pred, totstride)
        precision_counter, recall_counter, pred_counter, gt_counter = self.get_stats(frames_true, frames_pred)
        precision, recall, f1, rval = self.calc_metr(precision_counter, recall_counter, pred_counter, gt_counter)
        return precision, recall, f1, rval

    def get_metrics_secs1(self, true_secs, secs_pred, totstride = 160):
        """Score the given values directly, without frame conversion.

        ``tolerance`` is therefore applied in the unit of the inputs;
        ``totstride`` is accepted for signature parity but not used.
        """
        precision_counter, recall_counter, pred_counter, gt_counter = self.get_stats(true_secs, secs_pred)
        precision, recall, f1, rval = self.calc_metr(precision_counter, recall_counter, pred_counter, gt_counter)
        return precision, recall, f1, rval
    
    

In [35]:
#!g1.1
def read_true_file(segment_file):
    """Read the reference boundaries of one label file.

    The first whitespace-separated token of every non-blank line is parsed as
    a float; duplicates are removed and the values are returned in ascending
    order. The file is decoded as cp1251.

    ``float`` is used instead of ``eval`` so that the content of a label file
    is never executed as code.
    """
    with open(segment_file, 'r', encoding="cp1251") as file:
        boundaries = {float(line.split()[0]) for line in file if line.strip()}
    return sorted(boundaries)

In [36]:
#!g1.1
def read_pred_file(segment_file):
    """Read the predicted boundaries stored as a Python list literal.

    The file (decoded as cp1251) must contain the ``str()`` of a list of
    numbers, e.g. ``[0.1, 0.25]``. It is parsed with ``ast.literal_eval``,
    which accepts literals only, so file content is never executed as code.

    Raises:
        ValueError or SyntaxError: if the content is not a plain literal.
    """
    import ast  # local import: the notebook's import cell does not provide it

    with open(segment_file, 'r', encoding="cp1251") as file:
        text = file.read()
    return ast.literal_eval(text.strip())

In [37]:
#!g1.1
# Same output root the inference cells wrote their predictions to.
path_results = 'save_results_path_compare_buckeye'

# One sub-folder per evaluated model, named as in the inference cells.
folders = ['model_segment', 'peak_detection', 'wav2vec_model']

In [64]:
#!g1.1
# Score every model's saved predictions against the reference label files.
result_dataframes = []
metr = RMetrics1(tolerance=2)
totstride = 160
true_folder = 'Buckeye_fin/Test'

# The reference files are the same for every model, so list them once.
true_files = sorted(
    os.path.splitext(wav_file)[0] + '.txt'
    for wav_file in iter_find_files(true_folder, "*.wav")
)

for folder in folders:
    pred_folder = os.path.join(path_results, folder)

    secs_preds = []
    secs_trues = []

    for true_file in tqdm(true_files, desc=folder):
        # Predictions were saved under the label file's base name. Pair by name:
        # pairing by position relied on os.listdir of the prediction folder
        # returning files in the same order as the reference listing, which is
        # not guaranteed and silently scores mismatched utterances.
        pred_file = os.path.join(pred_folder, os.path.basename(true_file))

        secs_trues.append(read_true_file(true_file))
        secs_preds.append(read_pred_file(pred_file))

    precision, recall, f1, rval = metr.get_metrics_secs(secs_trues, secs_preds, totstride = totstride)

    # One-row frame per model; concatenated into the final table afterwards.
    datafr = pd.DataFrame([{'type': folder,
                            'precision': precision,
                            'recall': recall,
                            'f1': f1,
                            'rval': rval}])

    result_dataframes.append(datafr)

100%|██████████| 5536/5536 [00:03<00:00, 1576.05it/s]
100%|██████████| 5536/5536 [00:03<00:00, 1416.49it/s]
100%|██████████| 5536/5536 [00:03<00:00, 1395.00it/s]


In [65]:
#!g1.1
# Stack the per-model one-row frames into a single table with a fresh 0..n-1 index.
result_df = pd.concat(result_dataframes, ignore_index = True)

In [66]:
#!g1.1
result_df

Unnamed: 0,type,precision,recall,f1,rval
0,model_segment,0.586108,0.588653,0.587373,0.647347
1,peak_detection,0.577641,0.553406,0.565259,0.632658
2,wav2vec_model,0.590504,0.460075,0.517188,0.595521


In [67]:
#!g1.1
# Persist the comparison table; with pandas' default index=True the row index
# is written as an unnamed first column.
result_df.to_csv('results_compare_buckeye.csv')

In [63]:
#!g1.1
